Skip to main content

ares/middleware/
usage.rs

1use axum::{
2    extract::Request,
3    middleware::Next,
4    response::Response,
5};
6use std::sync::Arc;
7use crate::db::tenants::TenantDb;
8
9pub async fn track_usage(
10    req: Request,
11    next: Next,
12) -> Response {
13    let tenant_id = req.extensions().get::<crate::models::TenantContext>().map(|c| c.tenant_id.clone());
14    let tenant_db = req.extensions().get::<Arc<TenantDb>>().cloned();
15
16    let response = next.run(req).await;
17
18    if let (Some(tid), Some(db)) = (tenant_id, tenant_db) {
19        let headers = response.headers().clone();
20        tokio::spawn(async move {
21            let _ = crate::middleware::usage::record_usage(&tid, &headers, db.as_ref()).await;
22        });
23    }
24
25    response
26}
27
28async fn record_usage(
29    tenant_id: &str,
30    headers: &axum::http::HeaderMap,
31    db: &TenantDb,
32) -> Result<(), Box<dyn std::error::Error>> {
33    let mut tokens = 0;
34    if let Some(t) = headers.get("x-input-tokens").and_then(|v| v.to_str().ok()).and_then(|v| v.parse::<i32>().ok()) { tokens += t; }
35    if let Some(t) = headers.get("x-output-tokens").and_then(|v| v.to_str().ok()).and_then(|v| v.parse::<i32>().ok()) { tokens += t; }
36    
37    db.record_usage_event(tenant_id, 1, tokens as u64).await?;
38    Ok(())
39}