1use axum::{
2 extract::Request,
3 middleware::Next,
4 response::Response,
5};
6use std::sync::Arc;
7use crate::db::tenants::TenantDb;
8
9pub async fn track_usage(
10 req: Request,
11 next: Next,
12) -> Response {
13 let tenant_id = req.extensions().get::<crate::models::TenantContext>().map(|c| c.tenant_id.clone());
14 let tenant_db = req.extensions().get::<Arc<TenantDb>>().cloned();
15
16 let response = next.run(req).await;
17
18 if let (Some(tid), Some(db)) = (tenant_id, tenant_db) {
19 let headers = response.headers().clone();
20 tokio::spawn(async move {
21 let _ = crate::middleware::usage::record_usage(&tid, &headers, db.as_ref()).await;
22 });
23 }
24
25 response
26}
27
28async fn record_usage(
29 tenant_id: &str,
30 headers: &axum::http::HeaderMap,
31 db: &TenantDb,
32) -> Result<(), Box<dyn std::error::Error>> {
33 let mut tokens = 0;
34 if let Some(t) = headers.get("x-input-tokens").and_then(|v| v.to_str().ok()).and_then(|v| v.parse::<i32>().ok()) { tokens += t; }
35 if let Some(t) = headers.get("x-output-tokens").and_then(|v| v.to_str().ok()).and_then(|v| v.parse::<i32>().ok()) { tokens += t; }
36
37 db.record_usage_event(tenant_id, 1, tokens as u64).await?;
38 Ok(())
39}