use arc_swap::ArcSwap;
use axum::{
body::Body,
http::{Request, StatusCode},
};
use crabllm_core::{
BoxFuture, BoxStream, ChatCompletionRequest, ChatCompletionResponse, Choice, ContentBlock,
Error, FinishReason, GatewayConfig, KvPairs, Message, OpenAiUsage, Prefix, Provider, Role,
Storage,
};
use crabllm_provider::{Deployment, ProviderRegistry};
use crabllm_proxy::{AppState, UsageEvent, router};
use std::{
collections::HashMap,
sync::{Arc, RwLock},
time::Duration,
};
use tokio::sync::broadcast;
use tower::ServiceExt;
struct FakeProvider;
impl Provider for FakeProvider {
async fn chat_completion(
&self,
request: &ChatCompletionRequest,
) -> Result<ChatCompletionResponse, Error> {
Ok(ChatCompletionResponse {
id: "chatcmpl-test".into(),
object: "chat.completion".into(),
created: 0,
model: request.model.clone(),
choices: vec![Choice {
index: 0,
message: Message {
role: Role::Assistant,
content: vec![ContentBlock::text("hi")],
},
finish_reason: Some(FinishReason::Stop),
logprobs: None,
}],
usage: Some(OpenAiUsage {
prompt_tokens: 11,
completion_tokens: 22,
total_tokens: 33,
completion_tokens_details: None,
prompt_cache_hit_tokens: None,
prompt_cache_miss_tokens: None,
}),
system_fingerprint: None,
})
}
async fn chat_completion_stream(
&self,
_request: &ChatCompletionRequest,
) -> Result<BoxStream<'static, Result<crabllm_core::ChatCompletionChunk, Error>>, Error> {
Err(Error::not_implemented("stream"))
}
async fn anthropic_messages(
&self,
_request: &crabllm_core::AnthropicRequest,
) -> Result<crabllm_core::AnthropicResponse, Error> {
Err(Error::not_implemented("anthropic_messages"))
}
async fn anthropic_messages_stream(
&self,
_request: &crabllm_core::AnthropicRequest,
) -> Result<BoxStream<'static, Result<crabllm_core::AnthropicStreamEvent, Error>>, Error> {
Err(Error::not_implemented("anthropic_messages_stream"))
}
async fn gemini_generate_content_stream(
&self,
_model: &str,
_request: &crabllm_core::GeminiRequest,
) -> Result<BoxStream<'static, Result<crabllm_core::GeminiResponse, Error>>, Error> {
Err(Error::not_implemented("gemini streaming"))
}
}
struct FakeStorage;
impl Storage for FakeStorage {
fn get(&self, _key: &[u8]) -> BoxFuture<'_, Result<Option<Vec<u8>>, Error>> {
Box::pin(async { Ok(None) })
}
fn set(&self, _key: &[u8], _value: Vec<u8>) -> BoxFuture<'_, Result<(), Error>> {
Box::pin(async { Ok(()) })
}
fn increment(&self, _key: &[u8], _delta: i64) -> BoxFuture<'_, Result<i64, Error>> {
Box::pin(async { Ok(0) })
}
fn list(&self, _prefix: &Prefix) -> BoxFuture<'_, Result<KvPairs, Error>> {
Box::pin(async { Ok(KvPairs::default()) })
}
fn delete(&self, _key: &[u8]) -> BoxFuture<'_, Result<(), Error>> {
Box::pin(async { Ok(()) })
}
}
/// Builds a [`GatewayConfig`] with every collection empty and every optional
/// setting unset. Only `shutdown_timeout` (30) and `openapi` (`true`) carry
/// non-empty values.
fn empty_config() -> GatewayConfig {
    GatewayConfig {
        // Plain scalar settings.
        listen: String::new(),
        shutdown_timeout: 30,
        openapi: true,

        // Collections: nothing configured.
        providers: HashMap::new(),
        aliases: HashMap::new(),
        models: HashMap::new(),
        keys: Vec::new(),

        // Optional sections: all absent.
        extensions: None,
        storage: None,
        cloud_models: None,
        admin_token: None,
    }
}
/// Assembles an [`AppState`] whose registry holds one [`FakeProvider`]
/// deployment under the model name `"fake-model"` (mapped to `"fake"` in the
/// model-to-provider table) and whose `usage_events` is `Some(tx)`.
fn build_state(tx: broadcast::Sender<UsageEvent>) -> AppState<FakeStorage, FakeProvider> {
    // The single deployment backing the fake model.
    let deployment = Arc::new(Deployment {
        provider: FakeProvider,
        weight: 1,
        max_retries: 0,
        timeout: Duration::from_secs(5),
        retry_deadline: Duration::from_secs(15),
    });

    let providers: HashMap<String, Vec<Arc<Deployment<FakeProvider>>>> =
        HashMap::from([("fake-model".to_string(), vec![deployment])]);
    let model_providers = HashMap::from([("fake-model".to_string(), "fake".to_string())]);

    let registry = ProviderRegistry::new(providers, HashMap::new(), model_providers);

    AppState {
        registry: Arc::new(ArcSwap::from_pointee(registry)),
        config: empty_config(),
        extensions: Arc::new(Vec::new()),
        storage: Arc::new(FakeStorage),
        key_map: Arc::new(RwLock::new(HashMap::new())),
        usage_events: Some(tx),
    }
}
#[tokio::test]
async fn chat_completion_emits_one_usage_event() {
let (tx, mut rx) = broadcast::channel::<UsageEvent>(16);
let state = build_state(tx);
let app = router(state, vec![]);
let body = serde_json::json!({
"model": "fake-model",
"messages": [{ "role": "user", "content": [{ "type": "text", "text": "hi" }] }],
});
let request = Request::builder()
.method("POST")
.uri("/v1/chat/completions")
.header("content-type", "application/json")
.body(Body::from(body.to_string()))
.unwrap();
let response = app.oneshot(request).await.unwrap();
assert_eq!(response.status(), StatusCode::OK);
let event = tokio::time::timeout(Duration::from_secs(1), rx.recv())
.await
.expect("timed out waiting for UsageEvent")
.expect("broadcast closed");
assert_eq!(event.endpoint, "chat.completions");
assert_eq!(event.model, "fake-model");
assert_eq!(event.provider, "fake");
assert_eq!(event.usage.prompt_tokens(), 11);
assert_eq!(event.usage.completion_tokens(), 22);
assert_eq!(event.usage.cache_read_tokens, 0);
assert_eq!(event.status, 200);
assert!(event.error.is_none());
match tokio::time::timeout(Duration::from_millis(50), rx.recv()).await {
Err(_) => {} Ok(Err(broadcast::error::RecvError::Closed)) => {} Ok(Err(broadcast::error::RecvError::Lagged(_))) => {} Ok(Ok(evt)) => panic!("unexpected second event: {:?}", evt),
}
}
/// With `usage_events` set to `None`, a chat completion request for
/// `"fake-model"` must still be served with `200 OK`.
#[tokio::test]
async fn none_usage_events_is_zero_cost() {
    // Request payload and HTTP envelope.
    let payload = serde_json::json!({
        "model": "fake-model",
        "messages": [{ "role": "user", "content": [{ "type": "text", "text": "hi" }] }],
    });
    let http_request = Request::builder()
        .method("POST")
        .uri("/v1/chat/completions")
        .header("content-type", "application/json")
        .body(Body::from(payload.to_string()))
        .unwrap();

    // Registry with a single fake deployment for "fake-model".
    let deployment = Arc::new(Deployment {
        provider: FakeProvider,
        weight: 1,
        max_retries: 0,
        timeout: Duration::from_secs(5),
        retry_deadline: Duration::from_secs(15),
    });
    let deployments: HashMap<String, Vec<Arc<Deployment<FakeProvider>>>> =
        HashMap::from([("fake-model".to_string(), vec![deployment])]);
    let provider_names = HashMap::from([("fake-model".to_string(), "fake".to_string())]);
    let registry = ProviderRegistry::new(deployments, HashMap::new(), provider_names);

    // Application state with usage broadcasting switched off.
    let state = AppState::<FakeStorage, FakeProvider> {
        registry: Arc::new(ArcSwap::from_pointee(registry)),
        config: empty_config(),
        extensions: Arc::new(Vec::new()),
        storage: Arc::new(FakeStorage),
        key_map: Arc::new(RwLock::new(HashMap::new())),
        usage_events: None,
    };

    let reply = router(state, vec![]).oneshot(http_request).await.unwrap();
    assert_eq!(reply.status(), StatusCode::OK);
}