use std::sync::atomic::{AtomicU32, Ordering};
use std::sync::Arc;
use std::time::Duration;
use async_trait::async_trait;
use llmkit::prelude::*;
use llmkit::{
ChatStream, CostTrackingLayer, EmbedRequest, EmbedResponse, FinishReason, ModelPricing,
TokenUsage,
};
/// Scriptable in-memory provider used by the tests in this file.
struct MockProvider {
    /// Shared call counter; tests clone the `Arc` before handing the provider
    /// to the client so they can still observe how often it was invoked.
    calls: Arc<AtomicU32>,
    /// Number of upcoming `chat` calls that should fail before one succeeds.
    fail_times: AtomicU32,
    /// Model name echoed back in responses and from `model()`.
    model: String,
    /// Token usage attached to every successful chat response.
    usage: TokenUsage,
}
impl MockProvider {
fn new(model: &str) -> Self {
Self {
fail_times: AtomicU32::new(0),
calls: Arc::new(AtomicU32::new(0)),
model: model.into(),
usage: TokenUsage::new(1_000_000, 1_000_000),
}
}
fn failing(model: &str, fail_times: u32) -> Self {
let m = Self::new(model);
m.fail_times.store(fail_times, Ordering::SeqCst);
m
}
}
#[async_trait]
impl LlmProvider for MockProvider {
    /// Records the call, then returns `LlmError::Timeout` while scripted
    /// failures remain. Otherwise it returns a canned `"ok"` assistant reply
    /// carrying the configured model and usage.
    async fn chat(&self, _req: ChatRequest) -> LlmResult<ChatResponse> {
        self.calls.fetch_add(1, Ordering::SeqCst);
        // Consume one scripted failure atomically. A separate `load` followed
        // by `fetch_sub` lets two concurrent callers both see `1` and both
        // decrement. The counter then wraps to `u32::MAX` and the mock never
        // succeeds again. `checked_sub` makes the update a no-op (Err) at 0.
        let consumed_failure = self
            .fail_times
            .fetch_update(Ordering::SeqCst, Ordering::SeqCst, |left| left.checked_sub(1))
            .is_ok();
        if consumed_failure {
            return Err(LlmError::Timeout);
        }
        Ok(ChatResponse {
            id: "mock".into(),
            provider: "mock".into(),
            model: self.model.clone(),
            message: Message::assistant("ok"),
            finish_reason: FinishReason::Stop,
            tool_calls: vec![],
            usage: self.usage,
            cost: None,
            latency_ms: 0,
        })
    }

    /// Streaming is not mocked; always returns `LlmError::Unsupported`.
    async fn chat_stream(&self, _req: ChatRequest) -> LlmResult<ChatStream> {
        Err(LlmError::Unsupported("mock stream".into()))
    }

    /// Returns a single 4-element zero vector regardless of the request, with
    /// default (empty) usage.
    async fn embed(&self, _req: EmbedRequest) -> LlmResult<EmbedResponse> {
        Ok(EmbedResponse {
            provider: "mock".into(),
            model: self.model.clone(),
            embeddings: vec![vec![0.0; 4]],
            usage: TokenUsage::default(),
        })
    }

    fn name(&self) -> &'static str {
        "mock"
    }

    fn model(&self) -> &str {
        &self.model
    }
}
#[tokio::test(start_paused = true)]
async fn retry_recovers_after_transient_failures() {
let mock = MockProvider::failing("claude-opus-4-8", 2);
let calls = mock.calls.clone();
let client = LlmClientBuilder::new()
.provider(mock)
.layer(RetryLayer::exponential(3, Duration::from_millis(10)))
.build()
.unwrap();
let resp = client.chat_once(ChatRequest::builder().user("hi").build()).await.unwrap();
assert_eq!(resp.text().as_deref(), Some("ok"));
assert_eq!(calls.load(Ordering::SeqCst), 3); }
#[tokio::test(start_paused = true)]
async fn retry_gives_up_and_returns_error() {
    // Five scripted failures exceed what this retry configuration tolerates,
    // so the final timeout must reach the caller unchanged.
    let client = LlmClientBuilder::new()
        .provider(MockProvider::failing("gpt-4o", 5))
        .layer(RetryLayer::exponential(2, Duration::from_millis(10)))
        .build()
        .unwrap();

    let outcome = client
        .chat_once(ChatRequest::builder().user("hi").build())
        .await;

    assert!(matches!(outcome, Err(LlmError::Timeout)));
}
#[tokio::test]
async fn cost_tracking_accumulates_session_cost() {
    // Each mock reply reports 1M/1M tokens. At the opus rates of $5 + $25 per
    // Mtok, every call should add $30 to the running session total.
    let tracker = CostTrackingLayer::new();
    let session = tracker.session_cost();
    let client = LlmClientBuilder::new()
        .provider(MockProvider::new("claude-opus-4-8"))
        .layer(tracker)
        .track_cost(session)
        .build()
        .unwrap();

    for expected_total in [30.0, 60.0] {
        client
            .chat_once(ChatRequest::builder().user("x").build())
            .await
            .unwrap();
        assert!((client.session_cost_usd() - expected_total).abs() < 1e-6);
    }
}
#[tokio::test]
async fn budget_cap_refuses_once_exceeded() {
    // One mock call (1M/1M tokens on opus) already costs more than the $20
    // cap. The first call is still admitted, but the next must be refused.
    let client = LlmClientBuilder::new()
        .provider(MockProvider::new("claude-opus-4-8"))
        .layer(CostTrackingLayer::with_budget(20.0))
        .build()
        .unwrap();
    let make_request = || ChatRequest::builder().user("x").build();

    client.chat_once(make_request()).await.unwrap();
    let rejected = client.chat_once(make_request()).await.unwrap_err();

    assert!(matches!(rejected, LlmError::BudgetExceeded(_)));
}
#[tokio::test]
async fn fallback_advances_on_retryable_error() {
let primary = MockProvider::failing("gpt-4o", 10); let secondary = MockProvider::new("claude-opus-4-8");
let secondary_calls = secondary.calls.clone();
let client = LlmClientBuilder::new()
.provider(primary)
.fallback(secondary)
.build()
.unwrap();
let resp = client.chat_once(ChatRequest::builder().user("hi").build()).await.unwrap();
assert_eq!(resp.model, "claude-opus-4-8");
assert_eq!(secondary_calls.load(Ordering::SeqCst), 1);
}
#[tokio::test]
async fn rate_limit_allows_within_capacity() {
    // A single request against a bucket sized 1_000_000 per 60 s (presumably
    // capacity per window) should pass straight through the limiter.
    let limiter = RateLimitLayer::token_bucket(1_000_000, Duration::from_secs(60));
    let client = LlmClientBuilder::new()
        .provider(MockProvider::new("gpt-4o-mini"))
        .layer(limiter)
        .build()
        .unwrap();

    let request = ChatRequest::builder().user("hi").build();
    let response = client.chat_once(request).await.unwrap();

    assert_eq!(response.text().as_deref(), Some("ok"));
}
#[test]
fn pricing_table_matches_expectations() {
    // Expected opus rates: $5 per million input tokens, $25 per million output.
    let expected = Some(ModelPricing {
        input_per_mtok: 5.0,
        output_per_mtok: 25.0,
    });
    let actual = llmkit::pricing::pricing_for("claude-opus-4-8");
    assert_eq!(actual, expected);
}