1use crate::PREFIX_AUDIT;
2use axum::{
3 Json, Router,
4 extract::{Query, State},
5 http::StatusCode,
6 response::{IntoResponse, Response},
7 routing::get,
8};
9use crabllm_core::{ApiError, BoxFuture, Error, ModelInfo, RequestContext, Storage, storage_key};
10use serde::{Deserialize, Serialize};
11use std::{collections::HashMap, sync::Arc, time::SystemTime};
12
13pub struct AuditLogger {
14 storage: Arc<dyn Storage>,
15 models: HashMap<String, ModelInfo>,
16}
17
18impl AuditLogger {
19 pub fn new(
20 _config: &serde_json::Value,
21 storage: Arc<dyn Storage>,
22 models: HashMap<String, ModelInfo>,
23 ) -> Result<Self, String> {
24 Ok(Self { storage, models })
25 }
26
27 pub fn admin_routes(&self) -> Router {
28 Router::new()
29 .route("/v1/admin/logs", get(logs_handler))
30 .with_state(self.storage.clone())
31 }
32
33 fn cost_micros(&self, model: &str, provider: &str, usage: &crabllm_core::Usage) -> i64 {
34 let qualified = format!("{provider}/{model}");
35 self.models
36 .get(qualified.as_str())
37 .or_else(|| self.models.get(model))
38 .map(|info| (info.cost(usage) * 1_000_000.0).round() as i64)
39 .unwrap_or(0)
40 }
41
42 fn write_record(&self, record: AuditRecord) {
43 let ts_bytes = record.timestamp.to_be_bytes();
44 let mut suffix = Vec::with_capacity(8 + record.request_id.len());
45 suffix.extend_from_slice(&ts_bytes);
46 suffix.extend_from_slice(record.request_id.as_bytes());
47 let key = storage_key(&PREFIX_AUDIT, &suffix);
48
49 let storage = self.storage.clone();
50 tokio::spawn(async move {
52 match serde_json::to_vec(&record) {
53 Ok(value) => {
54 if let Err(e) = storage.set(&key, value).await {
55 tracing::warn!("audit: failed to write record: {e}");
56 }
57 }
58 Err(e) => tracing::warn!("audit: failed to serialize record: {e}"),
59 }
60 });
61 }
62}
63
/// Current wall-clock time in milliseconds since the Unix epoch.
///
/// Yields 0 if the system clock reports a time before the epoch.
fn now_millis() -> i64 {
    SystemTime::UNIX_EPOCH
        .elapsed()
        .map_or(0, |since_epoch| since_epoch.as_millis() as i64)
}
70
71fn error_status(e: &Error) -> u16 {
72 match e {
73 Error::Provider { status, .. } => *status,
74 Error::Timeout => 504,
75 _ => 500,
76 }
77}
78
79impl crabllm_core::Extension for AuditLogger {
80 fn name(&self) -> &str {
81 "audit"
82 }
83
84 fn prefix(&self) -> crabllm_core::Prefix {
85 PREFIX_AUDIT
86 }
87
88 fn on_response(
89 &self,
90 ctx: &RequestContext,
91 _raw_request: &[u8],
92 raw_response: &[u8],
93 ) -> BoxFuture<'_, ()> {
94 let usage = crabllm_core::Usage::from(raw_response);
95 let cost_micros = self.cost_micros(&ctx.model, &ctx.provider, &usage);
96
97 self.write_record(AuditRecord {
98 request_id: ctx.request_id.clone(),
99 timestamp: now_millis(),
100 principal: ctx.principal.clone().unwrap_or_default(),
101 model: ctx.model.clone(),
102 provider: ctx.provider.clone(),
103 prompt_tokens: if usage.total_tokens() > 0 {
104 Some(usage.prompt_tokens())
105 } else {
106 None
107 },
108 completion_tokens: if usage.total_tokens() > 0 {
109 Some(usage.completion_tokens())
110 } else {
111 None
112 },
113 cache_hit_tokens: if usage.cache_read_tokens > 0 {
114 Some(usage.cache_read_tokens)
115 } else {
116 None
117 },
118 cost_micros,
119 latency_ms: ctx.started_at.elapsed().as_millis() as u64,
120 status: 200,
121 error: None,
122 });
123
124 Box::pin(async {})
125 }
126
127 fn on_chunk(&self, ctx: &RequestContext, raw_chunk: &[u8]) -> BoxFuture<'_, ()> {
128 let usage = crabllm_core::Usage::from(raw_chunk);
129 if usage.total_tokens() > 0 {
130 let cost_micros = self.cost_micros(&ctx.model, &ctx.provider, &usage);
131
132 self.write_record(AuditRecord {
133 request_id: ctx.request_id.clone(),
134 timestamp: now_millis(),
135 principal: ctx.principal.clone().unwrap_or_default(),
136 model: ctx.model.clone(),
137 provider: ctx.provider.clone(),
138 prompt_tokens: Some(usage.prompt_tokens()),
139 completion_tokens: Some(usage.completion_tokens()),
140 cache_hit_tokens: if usage.cache_read_tokens > 0 {
141 Some(usage.cache_read_tokens)
142 } else {
143 None
144 },
145 cost_micros,
146 latency_ms: ctx.started_at.elapsed().as_millis() as u64,
147 status: 200,
148 error: None,
149 });
150 }
151
152 Box::pin(async {})
153 }
154
155 fn on_error(&self, ctx: &RequestContext, error: &Error) -> BoxFuture<'_, ()> {
156 self.write_record(AuditRecord {
157 request_id: ctx.request_id.clone(),
158 timestamp: now_millis(),
159 principal: ctx.principal.clone().unwrap_or_default(),
160 model: ctx.model.clone(),
161 provider: ctx.provider.clone(),
162 prompt_tokens: None,
163 completion_tokens: None,
164 cache_hit_tokens: None,
165 cost_micros: 0,
166 latency_ms: ctx.started_at.elapsed().as_millis() as u64,
167 status: error_status(error),
168 error: Some(error.to_string()),
169 });
170
171 Box::pin(async {})
172 }
173}
174
175#[derive(Debug, Clone, Serialize, Deserialize)]
176pub struct AuditRecord {
177 pub request_id: String,
178 pub timestamp: i64,
179 pub principal: String,
180 pub model: String,
181 pub provider: String,
182 #[serde(skip_serializing_if = "Option::is_none")]
183 pub prompt_tokens: Option<u32>,
184 #[serde(skip_serializing_if = "Option::is_none")]
185 pub completion_tokens: Option<u32>,
186 #[serde(default, skip_serializing_if = "Option::is_none")]
187 pub cache_hit_tokens: Option<u32>,
188 pub cost_micros: i64,
189 pub latency_ms: u64,
190 pub status: u16,
191 #[serde(default, skip_serializing_if = "Option::is_none")]
192 pub error: Option<String>,
193}
194
195#[derive(Deserialize)]
196struct LogQuery {
197 #[serde(default)]
198 key: Option<String>,
199 #[serde(default)]
200 model: Option<String>,
201 #[serde(default)]
202 since: Option<i64>,
203 #[serde(default)]
204 until: Option<i64>,
205 #[serde(default = "default_limit")]
206 limit: usize,
207}
208
/// Page size used by the admin logs route when the `limit` query parameter
/// is omitted.
const fn default_limit() -> usize {
    const DEFAULT_PAGE_SIZE: usize = 100;
    DEFAULT_PAGE_SIZE
}
212
213async fn logs_handler(
219 State(storage): State<Arc<dyn Storage>>,
220 Query(query): Query<LogQuery>,
221) -> Response {
222 let pairs = match storage.list(&PREFIX_AUDIT).await {
223 Ok(p) => p,
224 Err(e) => {
225 return (
226 StatusCode::INTERNAL_SERVER_ERROR,
227 Json(ApiError::new(e.to_string(), "server_error")),
228 )
229 .into_response();
230 }
231 };
232
233 let mut records: Vec<AuditRecord> = pairs
234 .into_iter()
235 .filter_map(|(_k, v)| serde_json::from_slice(&v).ok())
236 .filter(|r: &AuditRecord| {
237 if let Some(ref key) = query.key
238 && &r.principal != key
239 {
240 return false;
241 }
242 if let Some(ref model) = query.model
243 && &r.model != model
244 {
245 return false;
246 }
247 if let Some(since) = query.since
248 && r.timestamp < since
249 {
250 return false;
251 }
252 if let Some(until) = query.until
253 && r.timestamp > until
254 {
255 return false;
256 }
257 true
258 })
259 .collect();
260
261 records.sort_by_key(|b| std::cmp::Reverse(b.timestamp));
263 records.truncate(query.limit);
264
265 Json(records).into_response()
266}