1use crate::db::tenants::TenantDb;
2use axum::{extract::Request, middleware::Next, response::Response};
3use std::sync::Arc;
4
5pub async fn track_usage(req: Request, next: Next) -> Response {
6 let tenant_id = req
7 .extensions()
8 .get::<crate::models::TenantContext>()
9 .map(|c| c.tenant_id.clone());
10 let tenant_db = req.extensions().get::<Arc<TenantDb>>().cloned();
11
12 let response = next.run(req).await;
13
14 if let (Some(tid), Some(db)) = (tenant_id, tenant_db) {
15 let headers = response.headers().clone();
16 let pool = db.pool().clone();
17 tokio::spawn(async move {
18 let _ = crate::middleware::usage::record_usage(&tid, &headers, &pool).await;
19 });
20 }
21
22 response
23}
24
25async fn record_usage(
26 tenant_id: &str,
27 headers: &axum::http::HeaderMap,
28 pool: &sqlx::PgPool,
29) -> Result<(), Box<dyn std::error::Error>> {
30 let input_tokens: i64 = headers
31 .get("x-input-tokens")
32 .and_then(|v| v.to_str().ok())
33 .and_then(|v| v.parse::<i64>().ok())
34 .unwrap_or(0);
35 let output_tokens: i64 = headers
36 .get("x-output-tokens")
37 .and_then(|v| v.to_str().ok())
38 .and_then(|v| v.parse::<i64>().ok())
39 .unwrap_or(0);
40 let token_count = input_tokens + output_tokens;
41
42 let model_name: Option<String> = headers
43 .get("x-model-name")
44 .and_then(|v| v.to_str().ok())
45 .map(|v| v.to_string());
46 let agent_name: Option<String> = headers
47 .get("x-agent-name")
48 .and_then(|v| v.to_str().ok())
49 .map(|v| v.to_string());
50 let provider_name: Option<String> = headers
51 .get("x-provider-name")
52 .and_then(|v| v.to_str().ok())
53 .map(|v| v.to_string());
54
55 sqlx::query(
61 "INSERT INTO usage_events (id, tenant_id, source, request_count, token_count, input_tokens, output_tokens, model_name, agent_name, provider_name, created_at) VALUES ($1, $2, 'http', $3, $4, $5, $6, $7, $8, $9, $10)",
62 )
63 .bind(uuid::Uuid::new_v4().to_string())
64 .bind(tenant_id)
65 .bind(1_i32)
66 .bind(token_count)
67 .bind(input_tokens)
68 .bind(output_tokens)
69 .bind(model_name)
70 .bind(agent_name)
71 .bind(provider_name)
72 .bind(chrono::Utc::now().timestamp())
73 .execute(pool)
74 .await?;
75 Ok(())
76}