1use crate::db::PostgresClient;
2use crate::models::{ApiKey, Tenant, TenantContext, TenantTier};
3use crate::types::{AppError, Result};
4use chrono::{Datelike, TimeZone, Utc};
5use sha2::{Digest, Sha256};
6use sqlx::Row;
7use std::collections::HashMap;
8use std::sync::Arc;
9use tokio::sync::RwLock;
10
11pub struct TenantDb {
12 postgres: Arc<PostgresClient>,
13 monthly_cache: Arc<RwLock<HashMap<String, (i64, u64)>>>,
14 daily_cache: Arc<RwLock<HashMap<String, (i64, u64)>>>,
15}
16
17impl TenantDb {
18 pub fn new(postgres: Arc<PostgresClient>) -> Self {
19 Self {
20 postgres,
21 monthly_cache: Arc::new(RwLock::new(HashMap::new())),
22 daily_cache: Arc::new(RwLock::new(HashMap::new())),
23 }
24 }
25
26 pub fn pool(&self) -> &sqlx::PgPool {
27 &self.postgres.pool
28 }
29
30 pub async fn create_tenant(&self, name: String, tier: TenantTier) -> Result<Tenant> {
31 let id = uuid::Uuid::new_v4().to_string();
32 let tenant = Tenant::new(id.clone(), name, tier);
33
34 sqlx::query(
35 "INSERT INTO tenants (id, name, tier, created_at, updated_at) VALUES ($1, $2, $3, $4, $5)"
36 )
37 .bind(&tenant.id)
38 .bind(&tenant.name)
39 .bind(tenant.tier.as_str())
40 .bind(tenant.created_at)
41 .bind(tenant.updated_at)
42 .execute(&self.postgres.pool)
43 .await
44 .map_err(|e| AppError::Database(format!("Failed to create tenant: {}", e)))?;
45
46 Ok(tenant)
47 }
48
49 pub async fn list_tenants(&self) -> Result<Vec<Tenant>> {
50 let rows = sqlx::query(
51 "SELECT id, name, tier, created_at, updated_at FROM tenants ORDER BY created_at DESC",
52 )
53 .fetch_all(&self.postgres.pool)
54 .await
55 .map_err(|e| AppError::Database(format!("Failed to list tenants: {}", e)))?;
56
57 let mut tenants = Vec::new();
58 for row in rows {
59 let tier_str: String = row.get(2);
60 let tier = TenantTier::from_str(&tier_str).unwrap_or(TenantTier::Free);
61 tenants.push(Tenant {
62 id: row.get(0),
63 name: row.get(1),
64 tier,
65 created_at: row.get(3),
66 updated_at: row.get(4),
67 });
68 }
69
70 Ok(tenants)
71 }
72
73 pub async fn get_tenant(&self, tenant_id: &str) -> Result<Option<Tenant>> {
74 let row =
75 sqlx::query("SELECT id, name, tier, created_at, updated_at FROM tenants WHERE id = $1")
76 .bind(tenant_id)
77 .fetch_optional(&self.postgres.pool)
78 .await
79 .map_err(|e| AppError::Database(format!("Failed to get tenant: {}", e)))?;
80
81 if let Some(row) = row {
82 let tier_str: String = row.get(2);
83 let tier = TenantTier::from_str(&tier_str).unwrap_or(TenantTier::Free);
84 Ok(Some(Tenant {
85 id: row.get(0),
86 name: row.get(1),
87 tier,
88 created_at: row.get(3),
89 updated_at: row.get(4),
90 }))
91 } else {
92 Ok(None)
93 }
94 }
95
96 pub async fn create_api_key(&self, tenant_id: &str, name: String) -> Result<(ApiKey, String)> {
97 let id = uuid::Uuid::new_v4().to_string();
98 let raw_key = generate_api_key();
99 let key_prefix = format!("ares_{}", &raw_key[..8]);
100
101 let key_hash = hash_api_key(&raw_key);
102
103 let api_key = ApiKey::new(id, tenant_id.to_string(), key_hash, key_prefix, name);
104
105 sqlx::query(
106 "INSERT INTO api_keys (id, tenant_id, key_hash, key_prefix, name, is_active, created_at, expires_at) VALUES ($1, $2, $3, $4, $5, $6, $7, $8)"
107 )
108 .bind(&api_key.id)
109 .bind(&api_key.tenant_id)
110 .bind(&api_key.key_hash)
111 .bind(&api_key.key_prefix)
112 .bind(&api_key.name)
113 .bind(api_key.is_active as i32)
114 .bind(api_key.created_at)
115 .bind(api_key.expires_at)
116 .execute(&self.postgres.pool)
117 .await
118 .map_err(|e| AppError::Database(format!("Failed to create API key: {}", e)))?;
119
120 Ok((api_key, raw_key))
121 }
122
123 pub async fn list_api_keys(&self, tenant_id: &str) -> Result<Vec<ApiKey>> {
124 let rows = sqlx::query(
125 "SELECT id, tenant_id, key_hash, key_prefix, name, is_active, created_at, expires_at FROM api_keys WHERE tenant_id = $1 ORDER BY created_at DESC"
126 )
127 .bind(tenant_id)
128 .fetch_all(&self.postgres.pool)
129 .await
130 .map_err(|e| AppError::Database(format!("Failed to list API keys: {}", e)))?;
131
132 let mut keys = Vec::new();
133 for row in rows {
134 let expires_at: Option<i64> = row.get(7);
135 keys.push(ApiKey {
136 id: row.get(0),
137 tenant_id: row.get(1),
138 key_hash: row.get(2),
139 key_prefix: row.get(3),
140 name: row.get(4),
141 is_active: row.get::<i32, _>(5) != 0,
142 created_at: row.get(6),
143 expires_at,
144 });
145 }
146
147 Ok(keys)
148 }
149
150 pub async fn verify_api_key(&self, raw_key: &str) -> Result<Option<TenantContext>> {
151 let key_prefix = format!("ares_{}", &raw_key[5..13]);
152 let row = sqlx::query(
153 "SELECT ak.id, ak.tenant_id, ak.key_hash, ak.is_active, ak.expires_at, t.tier
154 FROM api_keys ak
155 JOIN tenants t ON ak.tenant_id = t.id
156 WHERE ak.key_prefix = $1",
157 )
158 .bind(key_prefix)
159 .fetch_optional(&self.postgres.pool)
160 .await
161 .map_err(|e| AppError::Database(format!("Failed to lookup API key: {}", e)))?;
162
163 if let Some(row) = row {
164 let key_hash: String = row.get(2);
165 let is_active: i32 = row.get(3);
166 let expires_at: Option<i64> = row.get(4);
167 let tier_str: String = row.get(5);
168
169 if is_active == 0 {
170 return Ok(None);
171 }
172
173 if let Some(exp) = expires_at {
174 if Utc::now().timestamp() > exp {
175 return Ok(None);
176 }
177 }
178
179 let key_without_prefix = raw_key.strip_prefix("ares_").unwrap_or(raw_key);
181 let input_hash = hash_api_key(key_without_prefix);
182 if input_hash != key_hash {
183 return Ok(None);
184 }
185
186 let tenant_id: String = row.get(1);
187 let tier = TenantTier::from_str(&tier_str).unwrap_or(TenantTier::Free);
188
189 Ok(Some(TenantContext::new(tenant_id, tier)))
190 } else {
191 Ok(None)
192 }
193 }
194
195 pub async fn get_monthly_requests(&self, tenant_id: &str) -> Result<u64> {
196 let cache_key = tenant_id.to_string();
197 let now = Utc::now();
198 let month_start = now
199 .date_naive()
200 .with_day(1)
201 .unwrap()
202 .and_hms_opt(0, 0, 0)
203 .unwrap()
204 .and_utc()
205 .timestamp();
206
207 {
208 let cache = self.monthly_cache.read().await;
209 if let Some((cached_month, count)) = cache.get(&cache_key) {
210 if *cached_month == month_start {
211 return Ok(*count);
212 }
213 }
214 }
215
216 let row = sqlx::query(
217 "SELECT COALESCE(SUM(request_count)::bigint, 0) FROM monthly_usage_cache WHERE tenant_id = $1 AND usage_month >= $2"
218 )
219 .bind(tenant_id)
220 .bind(month_start)
221 .fetch_one(&self.postgres.pool)
222 .await
223 .map_err(|e| AppError::Database(format!("Failed to get monthly requests: {}", e)))?;
224
225 let count: i64 = row.try_get::<i64, _>(0).unwrap_or(0);
226 let count = count as u64;
227
228 {
229 let mut cache = self.monthly_cache.write().await;
230 cache.insert(cache_key, (month_start, count));
231 }
232
233 Ok(count)
234 }
235
236 pub async fn get_daily_requests(&self, tenant_id: &str) -> Result<u64> {
237 let cache_key = tenant_id.to_string();
238 let today = Utc::now()
239 .date_naive()
240 .and_hms_opt(0, 0, 0)
241 .unwrap()
242 .and_utc()
243 .timestamp();
244
245 {
246 let cache = self.daily_cache.read().await;
247 if let Some((cached_day, count)) = cache.get(&cache_key) {
248 if *cached_day == today {
249 return Ok(*count);
250 }
251 }
252 }
253
254 let row = sqlx::query(
255 "SELECT COALESCE(SUM(request_count)::bigint, 0) FROM daily_rate_limits WHERE tenant_id = $1 AND usage_date >= $2"
256 )
257 .bind(tenant_id)
258 .bind(today)
259 .fetch_one(&self.postgres.pool)
260 .await
261 .map_err(|e| AppError::Database(format!("Failed to get daily requests: {}", e)))?;
262
263 let count: i64 = row.try_get::<i64, _>(0).unwrap_or(0);
264 let count = count as u64;
265
266 {
267 let mut cache = self.daily_cache.write().await;
268 cache.insert(cache_key, (today, count));
269 }
270
271 Ok(count)
272 }
273
274 pub async fn record_usage_event(
275 &self,
276 tenant_id: &str,
277 requests: u64,
278 tokens: u64,
279 ) -> Result<()> {
280 let now = Utc::now();
281 let today = now
282 .date_naive()
283 .and_hms_opt(0, 0, 0)
284 .unwrap()
285 .and_utc()
286 .timestamp();
287 let month_start = now
288 .date_naive()
289 .with_day(1)
290 .unwrap()
291 .and_hms_opt(0, 0, 0)
292 .unwrap()
293 .and_utc()
294 .timestamp();
295
296 sqlx::query(
297 "INSERT INTO usage_events (id, tenant_id, source, request_count, token_count, created_at) VALUES ($1, $2, 'http', $3, $4, $5)"
298 )
299 .bind(uuid::Uuid::new_v4().to_string())
300 .bind(tenant_id)
301 .bind(requests as i64)
302 .bind(tokens as i64)
303 .bind(now.timestamp())
304 .execute(&self.postgres.pool)
305 .await
306 .map_err(|e| AppError::Database(format!("Failed to record usage event: {}", e)))?;
307
308 sqlx::query(
309 "INSERT INTO monthly_usage_cache (tenant_id, usage_month, request_count, token_count) VALUES ($1, $2, $3, $4)
310 ON CONFLICT(tenant_id, usage_month) DO UPDATE SET
311 request_count = monthly_usage_cache.request_count + $5, token_count = monthly_usage_cache.token_count + $6"
312 )
313 .bind(tenant_id)
314 .bind(month_start)
315 .bind(requests as i64)
316 .bind(tokens as i64)
317 .bind(requests as i64)
318 .bind(tokens as i64)
319 .execute(&self.postgres.pool)
320 .await
321 .map_err(|e| AppError::Database(format!("Failed to update monthly cache: {}", e)))?;
322
323 sqlx::query(
324 "INSERT INTO daily_rate_limits (tenant_id, usage_date, request_count) VALUES ($1, $2, $3)
325 ON CONFLICT(tenant_id, usage_date) DO UPDATE SET
326 request_count = daily_rate_limits.request_count + $4"
327 )
328 .bind(tenant_id)
329 .bind(today)
330 .bind(requests as i64)
331 .bind(requests as i64)
332 .execute(&self.postgres.pool)
333 .await
334 .map_err(|e| AppError::Database(format!("Failed to update daily limit: {}", e)))?;
335
336 {
337 let mut cache = self.monthly_cache.write().await;
338 if let Some((month, count)) = cache.get_mut(tenant_id) {
339 if *month == month_start {
340 *count += requests;
341 }
342 }
343 }
344
345 {
346 let mut cache = self.daily_cache.write().await;
347 if let Some((day, count)) = cache.get_mut(tenant_id) {
348 if *day == today {
349 *count += requests;
350 }
351 }
352 }
353
354 Ok(())
355 }
356
357 pub async fn get_usage_summary(&self, tenant_id: &str) -> Result<UsageSummary> {
358 let monthly_requests = self.get_monthly_requests(tenant_id).await?;
359 let daily_requests = self.get_daily_requests(tenant_id).await?;
360
361 let now = Utc::now();
362 let month_start = now
363 .date_naive()
364 .with_day(1)
365 .unwrap()
366 .and_hms_opt(0, 0, 0)
367 .unwrap()
368 .and_utc()
369 .timestamp();
370
371 let row = sqlx::query(
372 "SELECT COALESCE(SUM(token_count)::bigint, 0) FROM monthly_usage_cache WHERE tenant_id = $1 AND usage_month >= $2"
373 )
374 .bind(tenant_id)
375 .bind(month_start)
376 .fetch_one(&self.postgres.pool)
377 .await
378 .map_err(|e| AppError::Database(format!("Failed to get monthly tokens: {}", e)))?;
379
380 let monthly_tokens: i64 = row.try_get::<i64, _>(0).unwrap_or(0);
381
382 Ok(UsageSummary {
383 monthly_requests,
384 monthly_tokens: monthly_tokens as u64,
385 daily_requests,
386 })
387 }
388
389 pub async fn revoke_api_key(&self, tenant_id: &str, key_id: &str) -> Result<()> {
390 let result =
391 sqlx::query("UPDATE api_keys SET is_active = 0 WHERE id = $1 AND tenant_id = $2")
392 .bind(key_id)
393 .bind(tenant_id)
394 .execute(&self.postgres.pool)
395 .await
396 .map_err(|e| AppError::Database(format!("Failed to revoke API key: {}", e)))?;
397
398 if result.rows_affected() == 0 {
399 return Err(AppError::NotFound(format!(
400 "API key '{}' not found for tenant '{}'",
401 key_id, tenant_id
402 )));
403 }
404 Ok(())
405 }
406
407 pub async fn update_tenant_quota(&self, tenant_id: &str, tier: TenantTier) -> Result<()> {
408 sqlx::query("UPDATE tenants SET tier = $1, updated_at = $2 WHERE id = $3")
409 .bind(tier.as_str())
410 .bind(Utc::now().timestamp())
411 .bind(tenant_id)
412 .execute(&self.postgres.pool)
413 .await
414 .map_err(|e| AppError::Database(format!("Failed to update tenant quota: {}", e)))?;
415
416 Ok(())
417 }
418}
419
420fn generate_api_key() -> String {
421 let bytes: Vec<u8> = (0..32).map(|_| rand::random::<u8>()).collect();
422 format!("ares_{}", hex::encode(bytes))
423}
424
425fn hash_api_key(raw_key: &str) -> String {
426 let mut hasher = Sha256::new();
427 hasher.update(raw_key.as_bytes());
428 hex::encode(hasher.finalize())
429}
430
431#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
432pub struct UsageSummary {
433 pub monthly_requests: u64,
434 pub monthly_tokens: u64,
435 pub daily_requests: u64,
436}
437
#[cfg(test)]
mod tests {
    use super::*;

    /// A generated key is `ares_` followed by 64 characters (69 in total).
    #[test]
    fn test_generate_api_key() {
        let key = generate_api_key();
        assert!(key.starts_with("ares_"));
        assert_eq!(key.len(), 69);
    }

    /// Everything after the `ares_` marker is lowercase hexadecimal.
    #[test]
    fn test_generate_api_key_body_is_lowercase_hex() {
        let key = generate_api_key();
        let body = key.strip_prefix("ares_").expect("key carries the ares_ marker");
        assert!(body
            .chars()
            .all(|c| c.is_ascii_digit() || ('a'..='f').contains(&c)));
    }

    /// Two consecutive keys are not identical.
    #[test]
    fn test_generate_api_key_is_not_constant() {
        assert_ne!(generate_api_key(), generate_api_key());
    }

    /// `hash_api_key` agrees with the published SHA-256 test vectors.
    #[test]
    fn test_hash_api_key_known_vectors() {
        assert_eq!(
            hash_api_key(""),
            "e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855"
        );
        assert_eq!(
            hash_api_key("abc"),
            "ba7816bf8f01cfea414140de5dae2223b00361a396177a9cb410ff61f20015ad"
        );
    }
}