Skip to main content

ares/middleware/
usage.rs

1use crate::db::tenants::TenantDb;
2use axum::{extract::Request, middleware::Next, response::Response};
3use std::sync::Arc;
4
5pub async fn track_usage(req: Request, next: Next) -> Response {
6    let tenant_id = req
7        .extensions()
8        .get::<crate::models::TenantContext>()
9        .map(|c| c.tenant_id.clone());
10    let tenant_db = req.extensions().get::<Arc<TenantDb>>().cloned();
11
12    let response = next.run(req).await;
13
14    if let (Some(tid), Some(db)) = (tenant_id, tenant_db) {
15        let headers = response.headers().clone();
16        let pool = db.pool().clone();
17        tokio::spawn(async move {
18            let _ = crate::middleware::usage::record_usage(&tid, &headers, &pool).await;
19        });
20    }
21
22    response
23}
24
25async fn record_usage(
26	tenant_id: &str,
27	headers: &axum::http::HeaderMap,
28	pool: &sqlx::PgPool,
29) -> Result<(), Box<dyn std::error::Error>> {
30	let input_tokens: i64 = headers
31		.get("x-input-tokens")
32		.and_then(|v| v.to_str().ok())
33		.and_then(|v| v.parse::<i64>().ok())
34		.unwrap_or(0);
35	let output_tokens: i64 = headers
36		.get("x-output-tokens")
37		.and_then(|v| v.to_str().ok())
38		.and_then(|v| v.parse::<i64>().ok())
39		.unwrap_or(0);
40	let token_count = input_tokens + output_tokens;
41
42	let model_name: Option<String> = headers
43		.get("x-model-name")
44		.and_then(|v| v.to_str().ok())
45		.map(|v| v.to_string());
46	let agent_name: Option<String> = headers
47		.get("x-agent-name")
48		.and_then(|v| v.to_str().ok())
49		.map(|v| v.to_string());
50	let provider_name: Option<String> = headers
51		.get("x-provider-name")
52		.and_then(|v| v.to_str().ok())
53		.map(|v| v.to_string());
54
55	// Record usage event.
56	// Use runtime `sqlx::query` (not the `query!` macro) so downstream
57	// crates don't need DATABASE_URL at compile time or a `.sqlx` cache.
58	// Library crates that ship via crates.io cannot rely on a live DB
59	// or bundled cache being available to their consumers.
60	sqlx::query(
61		"INSERT INTO usage_events (id, tenant_id, source, request_count, token_count, input_tokens, output_tokens, model_name, agent_name, provider_name, created_at) VALUES ($1, $2, 'http', $3, $4, $5, $6, $7, $8, $9, $10)",
62	)
63	.bind(uuid::Uuid::new_v4().to_string())
64	.bind(tenant_id)
65	.bind(1_i32)
66	.bind(token_count)
67	.bind(input_tokens)
68	.bind(output_tokens)
69	.bind(model_name)
70	.bind(agent_name)
71	.bind(provider_name)
72	.bind(chrono::Utc::now().timestamp())
73	.execute(pool)
74	.await?;
75	Ok(())
76}