use axum::{
Json, Router,
extract::{Query, State},
http::StatusCode,
response::{IntoResponse, Response},
routing::get,
};
use crabllm_core::{
ApiError, BoxFuture, ChatCompletionChunk, ChatCompletionRequest, ChatCompletionResponse, Error,
Prefix, RequestContext, Storage, cost, storage_key,
};
use serde::{Deserialize, Serialize};
use std::{collections::HashMap, sync::Arc, time::SystemTime};
/// Storage key prefix under which every audit record is written and listed.
///
/// Used by `AuditLogger::write_record` to build keys, by `logs_handler` to list
/// records, and reported through `Extension::prefix`. Changing these bytes means
/// records already written under the old prefix are no longer listed by the admin
/// endpoint (assuming the storage backend persists across restarts).
const PREFIX: Prefix = *b"alog";
/// Extension that stores audit records for chat-completion outcomes (non-streaming
/// responses, stream chunks that carry usage, and errors) and serves them back via
/// `GET /v1/admin/logs` (see `admin_routes`).
pub struct AuditLogger {
    /// Backend that audit records are written to and listed from.
    storage: Arc<dyn Storage>,
    /// Pricing table keyed by model name; models with no entry are costed at 0.
    pricing: HashMap<String, crabllm_core::PricingConfig>,
}
impl AuditLogger {
    /// Builds an audit logger that writes to `storage` and prices requests with `pricing`.
    ///
    /// `_config` is accepted for interface compatibility but is not read, and this
    /// constructor currently never returns `Err`.
    pub fn new(
        _config: &serde_json::Value,
        storage: Arc<dyn Storage>,
        pricing: HashMap<String, crabllm_core::PricingConfig>,
    ) -> Result<Self, String> {
        let logger = Self { pricing, storage };
        Ok(logger)
    }

    /// Returns a router exposing `GET /v1/admin/logs`, with this logger's storage
    /// handle as the router state.
    pub fn admin_routes(&self) -> Router {
        let shared_storage = Arc::clone(&self.storage);
        Router::new()
            .route("/v1/admin/logs", get(logs_handler))
            .with_state(shared_storage)
    }

    /// Cost of a request as an integer: the value returned by `cost` multiplied by
    /// 1,000,000 and rounded to the nearest whole number.
    ///
    /// Returns 0 when `model` has no entry in the pricing table.
    fn cost_micros(&self, model: &str, prompt: u32, completion: u32) -> i64 {
        let Some(model_pricing) = self.pricing.get(model) else {
            return 0;
        };
        let amount = cost(model_pricing, prompt, completion);
        (amount * 1_000_000.0).round() as i64
    }

    /// Persists `record` on a detached background task; the caller never waits.
    ///
    /// The storage key is `PREFIX` plus a suffix of the 8-byte big-endian timestamp
    /// followed by the request-id bytes. Presumably the big-endian encoding is meant
    /// to make keys sort chronologically on bytewise-ordered backends; confirm.
    /// Serialization and write failures are logged with `tracing::warn!` and
    /// otherwise dropped. `tokio::spawn` requires a running Tokio runtime.
    fn write_record(&self, record: AuditRecord) {
        let id_bytes = record.request_id.as_bytes();
        let mut suffix = Vec::with_capacity(8 + id_bytes.len());
        suffix.extend_from_slice(&record.timestamp.to_be_bytes());
        suffix.extend_from_slice(id_bytes);
        let key = storage_key(&PREFIX, &suffix);

        let storage = Arc::clone(&self.storage);
        tokio::spawn(async move {
            let value = match serde_json::to_vec(&record) {
                Ok(bytes) => bytes,
                Err(e) => {
                    tracing::warn!("audit: failed to serialize record: {e}");
                    return;
                }
            };
            if let Err(e) = storage.set(&key, value).await {
                tracing::warn!("audit: failed to write record: {e}");
            }
        });
    }
}
/// Current wall-clock time in whole milliseconds since the Unix epoch.
///
/// A system clock set before the epoch yields 0 rather than an error.
fn now_millis() -> i64 {
    let elapsed = match SystemTime::now().duration_since(SystemTime::UNIX_EPOCH) {
        Ok(since_epoch) => since_epoch,
        Err(_) => Default::default(),
    };
    elapsed.as_millis() as i64
}
fn error_status(e: &Error) -> u16 {
match e {
Error::Provider { status, .. } => *status,
Error::Timeout => 504,
_ => 500,
}
}
impl crabllm_core::Extension for AuditLogger {
fn name(&self) -> &str {
"audit"
}
fn prefix(&self) -> Prefix {
PREFIX
}
fn on_response(
&self,
ctx: &RequestContext,
_request: &ChatCompletionRequest,
response: &ChatCompletionResponse,
) -> BoxFuture<'_, ()> {
let (prompt, completion) = response
.usage
.as_ref()
.map(|u| (Some(u.prompt_tokens), Some(u.completion_tokens)))
.unwrap_or((None, None));
let cost_micros = match (prompt, completion) {
(Some(p), Some(c)) => self.cost_micros(&ctx.model, p, c),
_ => 0,
};
self.write_record(AuditRecord {
request_id: ctx.request_id.clone(),
timestamp: now_millis(),
key_name: ctx.key_name.clone().unwrap_or_default(),
model: ctx.model.clone(),
provider: ctx.provider.clone(),
prompt_tokens: prompt,
completion_tokens: completion,
cost_micros,
latency_ms: ctx.started_at.elapsed().as_millis() as u64,
status: 200,
});
Box::pin(async {})
}
fn on_chunk(&self, ctx: &RequestContext, chunk: &ChatCompletionChunk) -> BoxFuture<'_, ()> {
if let Some(ref usage) = chunk.usage {
let cost_micros =
self.cost_micros(&ctx.model, usage.prompt_tokens, usage.completion_tokens);
self.write_record(AuditRecord {
request_id: ctx.request_id.clone(),
timestamp: now_millis(),
key_name: ctx.key_name.clone().unwrap_or_default(),
model: ctx.model.clone(),
provider: ctx.provider.clone(),
prompt_tokens: Some(usage.prompt_tokens),
completion_tokens: Some(usage.completion_tokens),
cost_micros,
latency_ms: ctx.started_at.elapsed().as_millis() as u64,
status: 200,
});
}
Box::pin(async {})
}
fn on_error(&self, ctx: &RequestContext, error: &Error) -> BoxFuture<'_, ()> {
self.write_record(AuditRecord {
request_id: ctx.request_id.clone(),
timestamp: now_millis(),
key_name: ctx.key_name.clone().unwrap_or_default(),
model: ctx.model.clone(),
provider: ctx.provider.clone(),
prompt_tokens: None,
completion_tokens: None,
cost_micros: 0,
latency_ms: ctx.started_at.elapsed().as_millis() as u64,
status: error_status(error),
});
Box::pin(async {})
}
}
/// One audit log entry, stored as JSON in `Storage` and returned by `GET /v1/admin/logs`.
#[derive(Debug, Clone, Serialize, Deserialize)]
struct AuditRecord {
    /// Request identifier from the request context; also appended to the storage key.
    request_id: String,
    /// Wall-clock creation time of the record, in milliseconds since the Unix epoch.
    timestamp: i64,
    /// API key name from the request context, or an empty string when there was none.
    key_name: String,
    /// Model name taken from the request context.
    model: String,
    /// Provider name taken from the request context.
    provider: String,
    /// Prompt token count; omitted from the JSON when unknown (errors, or no usage).
    #[serde(skip_serializing_if = "Option::is_none")]
    prompt_tokens: Option<u32>,
    /// Completion token count; omitted from the JSON when unknown.
    #[serde(skip_serializing_if = "Option::is_none")]
    completion_tokens: Option<u32>,
    /// Output of `cost(...)` scaled by 1,000,000 and rounded; 0 when usage or
    /// pricing is unavailable.
    cost_micros: i64,
    /// Milliseconds elapsed since `ctx.started_at` when the record was built.
    latency_ms: u64,
    /// 200 for successful responses and usage-bearing chunks. For errors, the value
    /// from `error_status`: provider status, 504 on timeout, otherwise 500.
    status: u16,
}
/// Query-string filters accepted by `GET /v1/admin/logs`; every field is optional.
#[derive(Deserialize)]
struct LogQuery {
    /// Keep only records whose `key_name` equals this value exactly.
    #[serde(default)]
    key: Option<String>,
    /// Keep only records whose `model` equals this value exactly.
    #[serde(default)]
    model: Option<String>,
    /// Inclusive lower bound on `timestamp` (milliseconds since the Unix epoch).
    #[serde(default)]
    since: Option<i64>,
    /// Inclusive upper bound on `timestamp` (milliseconds since the Unix epoch).
    #[serde(default)]
    until: Option<i64>,
    /// Maximum number of records returned; falls back to `default_limit()` (100).
    #[serde(default = "default_limit")]
    limit: usize,
}
/// Page size used by `/v1/admin/logs` when the `limit` query parameter is omitted.
fn default_limit() -> usize {
    const DEFAULT_LOG_LIMIT: usize = 100;
    DEFAULT_LOG_LIMIT
}
async fn logs_handler(
State(storage): State<Arc<dyn Storage>>,
Query(query): Query<LogQuery>,
) -> Response {
let pairs = match storage.list(&PREFIX).await {
Ok(p) => p,
Err(e) => {
return (
StatusCode::INTERNAL_SERVER_ERROR,
Json(ApiError::new(e.to_string(), "server_error")),
)
.into_response();
}
};
let mut records: Vec<AuditRecord> = pairs
.into_iter()
.filter_map(|(_k, v)| serde_json::from_slice(&v).ok())
.filter(|r: &AuditRecord| {
if let Some(ref key) = query.key
&& &r.key_name != key
{
return false;
}
if let Some(ref model) = query.model
&& &r.model != model
{
return false;
}
if let Some(since) = query.since
&& r.timestamp < since
{
return false;
}
if let Some(until) = query.until
&& r.timestamp > until
{
return false;
}
true
})
.collect();
records.sort_by(|a, b| b.timestamp.cmp(&a.timestamp));
records.truncate(query.limit);
Json(records).into_response()
}