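//! `crabllm_proxy/ext/audit.rs` — audit logging extension.
//!
//! Writes an `AuditRecord` for completed, streamed, and failed requests and
//! serves the stored records through `GET /v1/admin/logs`.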

use crate::PREFIX_AUDIT;
use axum::{
    Json, Router,
    extract::{Query, State},
    http::StatusCode,
    response::{IntoResponse, Response},
    routing::get,
};
use crabllm_core::{ApiError, BoxFuture, Error, ModelInfo, RequestContext, Storage, storage_key};
use serde::{Deserialize, Serialize};
use std::{collections::HashMap, sync::Arc, time::SystemTime};

13pub struct AuditLogger {
14    storage: Arc<dyn Storage>,
15    models: HashMap<String, ModelInfo>,
16}
17
18impl AuditLogger {
19    pub fn new(
20        _config: &serde_json::Value,
21        storage: Arc<dyn Storage>,
22        models: HashMap<String, ModelInfo>,
23    ) -> Result<Self, String> {
24        Ok(Self { storage, models })
25    }
26
27    pub fn admin_routes(&self) -> Router {
28        Router::new()
29            .route("/v1/admin/logs", get(logs_handler))
30            .with_state(self.storage.clone())
31    }
32
33    fn cost_micros(&self, model: &str, provider: &str, usage: &crabllm_core::Usage) -> i64 {
34        let qualified = format!("{provider}/{model}");
35        self.models
36            .get(qualified.as_str())
37            .or_else(|| self.models.get(model))
38            .map(|info| (info.cost(usage) * 1_000_000.0).round() as i64)
39            .unwrap_or(0)
40    }
41
42    fn write_record(&self, record: AuditRecord) {
43        let ts_bytes = record.timestamp.to_be_bytes();
44        let mut suffix = Vec::with_capacity(8 + record.request_id.len());
45        suffix.extend_from_slice(&ts_bytes);
46        suffix.extend_from_slice(record.request_id.as_bytes());
47        let key = storage_key(&PREFIX_AUDIT, &suffix);
48
49        let storage = self.storage.clone();
50        // Fire-and-forget — audit logging must not block the response path.
51        tokio::spawn(async move {
52            match serde_json::to_vec(&record) {
53                Ok(value) => {
54                    if let Err(e) = storage.set(&key, value).await {
55                        tracing::warn!("audit: failed to write record: {e}");
56                    }
57                }
58                Err(e) => tracing::warn!("audit: failed to serialize record: {e}"),
59            }
60        });
61    }
62}
63
/// Current wall-clock time as milliseconds since the Unix epoch.
///
/// Returns `0` if the system clock reads earlier than the epoch, and saturates
/// at `i64::MAX` rather than wrapping if the millisecond count does not fit.
fn now_millis() -> i64 {
    let since_epoch = SystemTime::now()
        .duration_since(SystemTime::UNIX_EPOCH)
        .unwrap_or_default();
    i64::try_from(since_epoch.as_millis()).unwrap_or(i64::MAX)
}

71fn error_status(e: &Error) -> u16 {
72    match e {
73        Error::Provider { status, .. } => *status,
74        Error::Timeout => 504,
75        _ => 500,
76    }
77}
78
79impl crabllm_core::Extension for AuditLogger {
80    fn name(&self) -> &str {
81        "audit"
82    }
83
84    fn prefix(&self) -> crabllm_core::Prefix {
85        PREFIX_AUDIT
86    }
87
88    fn on_response(
89        &self,
90        ctx: &RequestContext,
91        _raw_request: &[u8],
92        raw_response: &[u8],
93    ) -> BoxFuture<'_, ()> {
94        let usage = crabllm_core::Usage::from(raw_response);
95        let cost_micros = self.cost_micros(&ctx.model, &ctx.provider, &usage);
96
97        self.write_record(AuditRecord {
98            request_id: ctx.request_id.clone(),
99            timestamp: now_millis(),
100            principal: ctx.principal.clone().unwrap_or_default(),
101            model: ctx.model.clone(),
102            provider: ctx.provider.clone(),
103            prompt_tokens: if usage.total_tokens() > 0 {
104                Some(usage.prompt_tokens())
105            } else {
106                None
107            },
108            completion_tokens: if usage.total_tokens() > 0 {
109                Some(usage.completion_tokens())
110            } else {
111                None
112            },
113            cache_hit_tokens: if usage.cache_read_tokens > 0 {
114                Some(usage.cache_read_tokens)
115            } else {
116                None
117            },
118            cost_micros,
119            latency_ms: ctx.started_at.elapsed().as_millis() as u64,
120            status: 200,
121            error: None,
122        });
123
124        Box::pin(async {})
125    }
126
127    fn on_chunk(&self, ctx: &RequestContext, raw_chunk: &[u8]) -> BoxFuture<'_, ()> {
128        let usage = crabllm_core::Usage::from(raw_chunk);
129        if usage.total_tokens() > 0 {
130            let cost_micros = self.cost_micros(&ctx.model, &ctx.provider, &usage);
131
132            self.write_record(AuditRecord {
133                request_id: ctx.request_id.clone(),
134                timestamp: now_millis(),
135                principal: ctx.principal.clone().unwrap_or_default(),
136                model: ctx.model.clone(),
137                provider: ctx.provider.clone(),
138                prompt_tokens: Some(usage.prompt_tokens()),
139                completion_tokens: Some(usage.completion_tokens()),
140                cache_hit_tokens: if usage.cache_read_tokens > 0 {
141                    Some(usage.cache_read_tokens)
142                } else {
143                    None
144                },
145                cost_micros,
146                latency_ms: ctx.started_at.elapsed().as_millis() as u64,
147                status: 200,
148                error: None,
149            });
150        }
151
152        Box::pin(async {})
153    }
154
155    fn on_error(&self, ctx: &RequestContext, error: &Error) -> BoxFuture<'_, ()> {
156        self.write_record(AuditRecord {
157            request_id: ctx.request_id.clone(),
158            timestamp: now_millis(),
159            principal: ctx.principal.clone().unwrap_or_default(),
160            model: ctx.model.clone(),
161            provider: ctx.provider.clone(),
162            prompt_tokens: None,
163            completion_tokens: None,
164            cache_hit_tokens: None,
165            cost_micros: 0,
166            latency_ms: ctx.started_at.elapsed().as_millis() as u64,
167            status: error_status(error),
168            error: Some(error.to_string()),
169        });
170
171        Box::pin(async {})
172    }
173}
174
175#[derive(Debug, Clone, Serialize, Deserialize)]
176pub struct AuditRecord {
177    pub request_id: String,
178    pub timestamp: i64,
179    pub principal: String,
180    pub model: String,
181    pub provider: String,
182    #[serde(skip_serializing_if = "Option::is_none")]
183    pub prompt_tokens: Option<u32>,
184    #[serde(skip_serializing_if = "Option::is_none")]
185    pub completion_tokens: Option<u32>,
186    #[serde(default, skip_serializing_if = "Option::is_none")]
187    pub cache_hit_tokens: Option<u32>,
188    pub cost_micros: i64,
189    pub latency_ms: u64,
190    pub status: u16,
191    #[serde(default, skip_serializing_if = "Option::is_none")]
192    pub error: Option<String>,
193}
194
195#[derive(Deserialize)]
196struct LogQuery {
197    #[serde(default)]
198    key: Option<String>,
199    #[serde(default)]
200    model: Option<String>,
201    #[serde(default)]
202    since: Option<i64>,
203    #[serde(default)]
204    until: Option<i64>,
205    #[serde(default = "default_limit")]
206    limit: usize,
207}
208
/// Number of records returned when the query does not specify `limit`.
fn default_limit() -> usize {
    100
}

213/// GET /v1/admin/logs — query audit log records.
214///
215/// Loads all records from storage and filters in memory. Acceptable for
216/// moderate volumes; high-throughput deployments should migrate to a
217/// dedicated time-series store.
218async fn logs_handler(
219    State(storage): State<Arc<dyn Storage>>,
220    Query(query): Query<LogQuery>,
221) -> Response {
222    let pairs = match storage.list(&PREFIX_AUDIT).await {
223        Ok(p) => p,
224        Err(e) => {
225            return (
226                StatusCode::INTERNAL_SERVER_ERROR,
227                Json(ApiError::new(e.to_string(), "server_error")),
228            )
229                .into_response();
230        }
231    };
232
233    let mut records: Vec<AuditRecord> = pairs
234        .into_iter()
235        .filter_map(|(_k, v)| serde_json::from_slice(&v).ok())
236        .filter(|r: &AuditRecord| {
237            if let Some(ref key) = query.key
238                && &r.principal != key
239            {
240                return false;
241            }
242            if let Some(ref model) = query.model
243                && &r.model != model
244            {
245                return false;
246            }
247            if let Some(since) = query.since
248                && r.timestamp < since
249            {
250                return false;
251            }
252            if let Some(until) = query.until
253                && r.timestamp > until
254            {
255                return false;
256            }
257            true
258        })
259        .collect();
260
261    // Newest first.
262    records.sort_by_key(|b| std::cmp::Reverse(b.timestamp));
263    records.truncate(query.limit);
264
265    Json(records).into_response()
266}