Skip to main content

ares/db/
tenants.rs

1use crate::db::PostgresClient;
2use crate::models::{ApiKey, Tenant, TenantContext, TenantTier};
3use crate::types::{AppError, Result};
4use chrono::{Datelike, TimeZone, Utc};
5use sha2::{Digest, Sha256};
6use sqlx::Row;
7use std::collections::HashMap;
8use std::sync::Arc;
9use tokio::sync::RwLock;
10
11pub struct TenantDb {
12    postgres: Arc<PostgresClient>,
13    monthly_cache: Arc<RwLock<HashMap<String, (i64, u64)>>>,
14    daily_cache: Arc<RwLock<HashMap<String, (i64, u64)>>>,
15}
16
17impl TenantDb {
18    pub fn new(postgres: Arc<PostgresClient>) -> Self {
19        Self {
20            postgres,
21            monthly_cache: Arc::new(RwLock::new(HashMap::new())),
22            daily_cache: Arc::new(RwLock::new(HashMap::new())),
23        }
24    }
25
26    pub fn pool(&self) -> &sqlx::PgPool {
27        &self.postgres.pool
28    }
29
30    pub async fn create_tenant(&self, name: String, tier: TenantTier) -> Result<Tenant> {
31        let id = uuid::Uuid::new_v4().to_string();
32        let tenant = Tenant::new(id.clone(), name, tier);
33
34        sqlx::query(
35            "INSERT INTO tenants (id, name, tier, created_at, updated_at) VALUES ($1, $2, $3, $4, $5)"
36        )
37        .bind(&tenant.id)
38        .bind(&tenant.name)
39        .bind(tenant.tier.as_str())
40        .bind(tenant.created_at)
41        .bind(tenant.updated_at)
42        .execute(&self.postgres.pool)
43        .await
44        .map_err(|e| AppError::Database(format!("Failed to create tenant: {}", e)))?;
45
46        Ok(tenant)
47    }
48
49    pub async fn list_tenants(&self) -> Result<Vec<Tenant>> {
50        let rows = sqlx::query(
51            "SELECT id, name, tier, created_at, updated_at FROM tenants ORDER BY created_at DESC",
52        )
53        .fetch_all(&self.postgres.pool)
54        .await
55        .map_err(|e| AppError::Database(format!("Failed to list tenants: {}", e)))?;
56
57        let mut tenants = Vec::new();
58        for row in rows {
59            let tier_str: String = row.get(2);
60            let tier = TenantTier::from_str(&tier_str).unwrap_or(TenantTier::Free);
61            tenants.push(Tenant {
62                id: row.get(0),
63                name: row.get(1),
64                tier,
65                created_at: row.get(3),
66                updated_at: row.get(4),
67            });
68        }
69
70        Ok(tenants)
71    }
72
73    pub async fn get_tenant(&self, tenant_id: &str) -> Result<Option<Tenant>> {
74        let row =
75            sqlx::query("SELECT id, name, tier, created_at, updated_at FROM tenants WHERE id = $1")
76                .bind(tenant_id)
77                .fetch_optional(&self.postgres.pool)
78                .await
79                .map_err(|e| AppError::Database(format!("Failed to get tenant: {}", e)))?;
80
81        if let Some(row) = row {
82            let tier_str: String = row.get(2);
83            let tier = TenantTier::from_str(&tier_str).unwrap_or(TenantTier::Free);
84            Ok(Some(Tenant {
85                id: row.get(0),
86                name: row.get(1),
87                tier,
88                created_at: row.get(3),
89                updated_at: row.get(4),
90            }))
91        } else {
92            Ok(None)
93        }
94    }
95
96    pub async fn create_api_key(&self, tenant_id: &str, name: String) -> Result<(ApiKey, String)> {
97        let id = uuid::Uuid::new_v4().to_string();
98        let raw_key = generate_api_key();
99        let key_prefix = format!("ares_{}", &raw_key[..8]);
100
101        let key_hash = hash_api_key(&raw_key);
102
103        let api_key = ApiKey::new(id, tenant_id.to_string(), key_hash, key_prefix, name);
104
105        sqlx::query(
106            "INSERT INTO api_keys (id, tenant_id, key_hash, key_prefix, name, is_active, created_at, expires_at) VALUES ($1, $2, $3, $4, $5, $6, $7, $8)"
107        )
108        .bind(&api_key.id)
109        .bind(&api_key.tenant_id)
110        .bind(&api_key.key_hash)
111        .bind(&api_key.key_prefix)
112        .bind(&api_key.name)
113        .bind(api_key.is_active as i32)
114        .bind(api_key.created_at)
115        .bind(api_key.expires_at)
116        .execute(&self.postgres.pool)
117        .await
118        .map_err(|e| AppError::Database(format!("Failed to create API key: {}", e)))?;
119
120        Ok((api_key, raw_key))
121    }
122
123    pub async fn list_api_keys(&self, tenant_id: &str) -> Result<Vec<ApiKey>> {
124        let rows = sqlx::query(
125            "SELECT id, tenant_id, key_hash, key_prefix, name, is_active, created_at, expires_at FROM api_keys WHERE tenant_id = $1 ORDER BY created_at DESC"
126        )
127        .bind(tenant_id)
128        .fetch_all(&self.postgres.pool)
129        .await
130        .map_err(|e| AppError::Database(format!("Failed to list API keys: {}", e)))?;
131
132        let mut keys = Vec::new();
133        for row in rows {
134            let expires_at: Option<i64> = row.get(7);
135            keys.push(ApiKey {
136                id: row.get(0),
137                tenant_id: row.get(1),
138                key_hash: row.get(2),
139                key_prefix: row.get(3),
140                name: row.get(4),
141                is_active: row.get::<i32, _>(5) != 0,
142                created_at: row.get(6),
143                expires_at,
144            });
145        }
146
147        Ok(keys)
148    }
149
150    pub async fn verify_api_key(&self, raw_key: &str) -> Result<Option<TenantContext>> {
151        let key_prefix = format!("ares_{}", &raw_key[5..13]);
152        let row = sqlx::query(
153            "SELECT ak.id, ak.tenant_id, ak.key_hash, ak.is_active, ak.expires_at, t.tier 
154             FROM api_keys ak 
155             JOIN tenants t ON ak.tenant_id = t.id 
156             WHERE ak.key_prefix = $1",
157        )
158        .bind(key_prefix)
159        .fetch_optional(&self.postgres.pool)
160        .await
161        .map_err(|e| AppError::Database(format!("Failed to lookup API key: {}", e)))?;
162
163        if let Some(row) = row {
164            let key_hash: String = row.get(2);
165            let is_active: i32 = row.get(3);
166            let expires_at: Option<i64> = row.get(4);
167            let tier_str: String = row.get(5);
168
169            if is_active == 0 {
170                return Ok(None);
171            }
172
173            if let Some(exp) = expires_at {
174                if Utc::now().timestamp() > exp {
175                    return Ok(None);
176                }
177            }
178
179            // Strip "ares_" prefix before hashing to match what create_api_key hashes
180            let key_without_prefix = raw_key.strip_prefix("ares_").unwrap_or(raw_key);
181            let input_hash = hash_api_key(key_without_prefix);
182            if input_hash != key_hash {
183                return Ok(None);
184            }
185
186            let tenant_id: String = row.get(1);
187            let tier = TenantTier::from_str(&tier_str).unwrap_or(TenantTier::Free);
188
189            Ok(Some(TenantContext::new(tenant_id, tier)))
190        } else {
191            Ok(None)
192        }
193    }
194
195    pub async fn get_monthly_requests(&self, tenant_id: &str) -> Result<u64> {
196        let cache_key = tenant_id.to_string();
197        let now = Utc::now();
198        let month_start = now
199            .date_naive()
200            .with_day(1)
201            .unwrap()
202            .and_hms_opt(0, 0, 0)
203            .unwrap()
204            .and_utc()
205            .timestamp();
206
207        {
208            let cache = self.monthly_cache.read().await;
209            if let Some((cached_month, count)) = cache.get(&cache_key) {
210                if *cached_month == month_start {
211                    return Ok(*count);
212                }
213            }
214        }
215
216        let row = sqlx::query(
217            "SELECT COALESCE(SUM(request_count)::bigint, 0) FROM monthly_usage_cache WHERE tenant_id = $1 AND usage_month >= $2"
218        )
219        .bind(tenant_id)
220        .bind(month_start)
221        .fetch_one(&self.postgres.pool)
222        .await
223        .map_err(|e| AppError::Database(format!("Failed to get monthly requests: {}", e)))?;
224
225        let count: i64 = row.try_get::<i64, _>(0).unwrap_or(0);
226        let count = count as u64;
227
228        {
229            let mut cache = self.monthly_cache.write().await;
230            cache.insert(cache_key, (month_start, count));
231        }
232
233        Ok(count)
234    }
235
236    pub async fn get_daily_requests(&self, tenant_id: &str) -> Result<u64> {
237        let cache_key = tenant_id.to_string();
238        let today = Utc::now()
239            .date_naive()
240            .and_hms_opt(0, 0, 0)
241            .unwrap()
242            .and_utc()
243            .timestamp();
244
245        {
246            let cache = self.daily_cache.read().await;
247            if let Some((cached_day, count)) = cache.get(&cache_key) {
248                if *cached_day == today {
249                    return Ok(*count);
250                }
251            }
252        }
253
254        let row = sqlx::query(
255            "SELECT COALESCE(SUM(request_count)::bigint, 0) FROM daily_rate_limits WHERE tenant_id = $1 AND usage_date >= $2"
256        )
257        .bind(tenant_id)
258        .bind(today)
259        .fetch_one(&self.postgres.pool)
260        .await
261        .map_err(|e| AppError::Database(format!("Failed to get daily requests: {}", e)))?;
262
263        let count: i64 = row.try_get::<i64, _>(0).unwrap_or(0);
264        let count = count as u64;
265
266        {
267            let mut cache = self.daily_cache.write().await;
268            cache.insert(cache_key, (today, count));
269        }
270
271        Ok(count)
272    }
273
274    pub async fn record_usage_event(
275        &self,
276        tenant_id: &str,
277        requests: u64,
278        tokens: u64,
279    ) -> Result<()> {
280        let now = Utc::now();
281        let today = now
282            .date_naive()
283            .and_hms_opt(0, 0, 0)
284            .unwrap()
285            .and_utc()
286            .timestamp();
287        let month_start = now
288            .date_naive()
289            .with_day(1)
290            .unwrap()
291            .and_hms_opt(0, 0, 0)
292            .unwrap()
293            .and_utc()
294            .timestamp();
295
296        sqlx::query(
297            "INSERT INTO usage_events (id, tenant_id, source, request_count, token_count, created_at) VALUES ($1, $2, 'http', $3, $4, $5)"
298        )
299        .bind(uuid::Uuid::new_v4().to_string())
300        .bind(tenant_id)
301        .bind(requests as i64)
302        .bind(tokens as i64)
303        .bind(now.timestamp())
304        .execute(&self.postgres.pool)
305        .await
306        .map_err(|e| AppError::Database(format!("Failed to record usage event: {}", e)))?;
307
308        sqlx::query(
309            "INSERT INTO monthly_usage_cache (tenant_id, usage_month, request_count, token_count) VALUES ($1, $2, $3, $4)
310             ON CONFLICT(tenant_id, usage_month) DO UPDATE SET 
311             request_count = monthly_usage_cache.request_count + $5, token_count = monthly_usage_cache.token_count + $6"
312        )
313        .bind(tenant_id)
314        .bind(month_start)
315        .bind(requests as i64)
316        .bind(tokens as i64)
317        .bind(requests as i64)
318        .bind(tokens as i64)
319        .execute(&self.postgres.pool)
320        .await
321        .map_err(|e| AppError::Database(format!("Failed to update monthly cache: {}", e)))?;
322
323        sqlx::query(
324            "INSERT INTO daily_rate_limits (tenant_id, usage_date, request_count) VALUES ($1, $2, $3)
325             ON CONFLICT(tenant_id, usage_date) DO UPDATE SET 
326             request_count = daily_rate_limits.request_count + $4"
327        )
328        .bind(tenant_id)
329        .bind(today)
330        .bind(requests as i64)
331        .bind(requests as i64)
332        .execute(&self.postgres.pool)
333        .await
334        .map_err(|e| AppError::Database(format!("Failed to update daily limit: {}", e)))?;
335
336        {
337            let mut cache = self.monthly_cache.write().await;
338            if let Some((month, count)) = cache.get_mut(tenant_id) {
339                if *month == month_start {
340                    *count += requests;
341                }
342            }
343        }
344
345        {
346            let mut cache = self.daily_cache.write().await;
347            if let Some((day, count)) = cache.get_mut(tenant_id) {
348                if *day == today {
349                    *count += requests;
350                }
351            }
352        }
353
354        Ok(())
355    }
356
357    pub async fn get_usage_summary(&self, tenant_id: &str) -> Result<UsageSummary> {
358        let monthly_requests = self.get_monthly_requests(tenant_id).await?;
359        let daily_requests = self.get_daily_requests(tenant_id).await?;
360
361        let now = Utc::now();
362        let month_start = now
363            .date_naive()
364            .with_day(1)
365            .unwrap()
366            .and_hms_opt(0, 0, 0)
367            .unwrap()
368            .and_utc()
369            .timestamp();
370
371        let row = sqlx::query(
372            "SELECT COALESCE(SUM(token_count)::bigint, 0) FROM monthly_usage_cache WHERE tenant_id = $1 AND usage_month >= $2"
373        )
374        .bind(tenant_id)
375        .bind(month_start)
376        .fetch_one(&self.postgres.pool)
377        .await
378        .map_err(|e| AppError::Database(format!("Failed to get monthly tokens: {}", e)))?;
379
380        let monthly_tokens: i64 = row.try_get::<i64, _>(0).unwrap_or(0);
381
382        Ok(UsageSummary {
383            monthly_requests,
384            monthly_tokens: monthly_tokens as u64,
385            daily_requests,
386        })
387    }
388
389    pub async fn revoke_api_key(&self, tenant_id: &str, key_id: &str) -> Result<()> {
390        let result =
391            sqlx::query("UPDATE api_keys SET is_active = 0 WHERE id = $1 AND tenant_id = $2")
392                .bind(key_id)
393                .bind(tenant_id)
394                .execute(&self.postgres.pool)
395                .await
396                .map_err(|e| AppError::Database(format!("Failed to revoke API key: {}", e)))?;
397
398        if result.rows_affected() == 0 {
399            return Err(AppError::NotFound(format!(
400                "API key '{}' not found for tenant '{}'",
401                key_id, tenant_id
402            )));
403        }
404        Ok(())
405    }
406
407    pub async fn update_tenant_quota(&self, tenant_id: &str, tier: TenantTier) -> Result<()> {
408        sqlx::query("UPDATE tenants SET tier = $1, updated_at = $2 WHERE id = $3")
409            .bind(tier.as_str())
410            .bind(Utc::now().timestamp())
411            .bind(tenant_id)
412            .execute(&self.postgres.pool)
413            .await
414            .map_err(|e| AppError::Database(format!("Failed to update tenant quota: {}", e)))?;
415
416        Ok(())
417    }
418}
419
420fn generate_api_key() -> String {
421    let bytes: Vec<u8> = (0..32).map(|_| rand::random::<u8>()).collect();
422    format!("ares_{}", hex::encode(bytes))
423}
424
425fn hash_api_key(raw_key: &str) -> String {
426    let mut hasher = Sha256::new();
427    hasher.update(raw_key.as_bytes());
428    hex::encode(hasher.finalize())
429}
430
431#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
432pub struct UsageSummary {
433    pub monthly_requests: u64,
434    pub monthly_tokens: u64,
435    pub daily_requests: u64,
436}
437
#[cfg(test)]
mod tests {
    use super::*;

    /// Generated keys are `ares_` + 64 hex characters (32 random bytes).
    #[test]
    fn test_generate_api_key() {
        let key = generate_api_key();
        assert!(key.starts_with("ares_"));
        assert_eq!(key.len(), 69);
        assert!(key["ares_".len()..].chars().all(|c| c.is_ascii_hexdigit()));
    }

    /// Two independently generated keys should never collide.
    #[test]
    fn test_generate_api_key_is_unique() {
        assert_ne!(generate_api_key(), generate_api_key());
    }

    /// The stored hash must be a deterministic 64-char hex digest, so verification can
    /// recompute it from the presented key.
    #[test]
    fn test_hash_api_key_is_deterministic_hex() {
        let key = generate_api_key();
        let hash = hash_api_key(&key);
        assert_eq!(hash, hash_api_key(&key));
        assert_eq!(hash.len(), 64);
        assert!(hash.chars().all(|c| c.is_ascii_hexdigit()));
    }

    /// Known-answer check: SHA-256("abc").
    #[test]
    fn test_hash_api_key_known_vector() {
        assert_eq!(
            hash_api_key("abc"),
            "ba7816bf8f01cfea414140de5dae2223b00361a396177a9cb410ff61f20015ad"
        );
    }
}