llmkit-rs 0.1.0

Unified multi-provider async LLM client for Rust — OpenAI, Anthropic, Ollama, with Tower middleware
Documentation
//! Tower middleware unit tests against a mock provider (no network).

use std::sync::atomic::{AtomicU32, Ordering};
use std::sync::Arc;
use std::time::Duration;

use async_trait::async_trait;
use llmkit::prelude::*;
use llmkit::{
    ChatStream, CostTrackingLayer, EmbedRequest, EmbedResponse, FinishReason, ModelPricing,
    TokenUsage,
};

/// A configurable mock that fails a fixed number of times, then succeeds.
/// Configurable test double for `LlmProvider`.
///
/// The first `fail_times` chat calls fail; after that every call succeeds.
/// No network access is involved.
struct MockProvider {
    /// Model name echoed back in chat and embed responses.
    model: String,
    /// Usage reported on every successful chat response.
    usage: TokenUsage,
    /// Number of upcoming `chat` calls that will still fail.
    fail_times: AtomicU32,
    /// Counts `chat` invocations. It is shared behind an `Arc` so a test can keep a
    /// clone after the provider has been moved into a client.
    calls: Arc<AtomicU32>,
}

impl MockProvider {
    /// Builds a provider for `model` that never fails.
    ///
    /// Every response reports `TokenUsage::new(1_000_000, 1_000_000)`, which the cost
    /// tests rely on to get round dollar figures.
    fn new(model: &str) -> Self {
        let usage = TokenUsage::new(1_000_000, 1_000_000);
        Self {
            model: model.to_owned(),
            usage,
            calls: Arc::new(AtomicU32::new(0)),
            fail_times: AtomicU32::new(0),
        }
    }

    /// Builds a provider whose next `fail_times` chat calls return an error
    /// before it starts succeeding.
    fn failing(model: &str, fail_times: u32) -> Self {
        Self {
            fail_times: AtomicU32::new(fail_times),
            ..Self::new(model)
        }
    }
}

#[async_trait]
impl LlmProvider for MockProvider {
    /// Records the call, then returns `LlmError::Timeout` while the failure budget
    /// is non-zero. Otherwise returns a fixed assistant message `"ok"` carrying this
    /// mock's model and usage.
    async fn chat(&self, _req: ChatRequest) -> LlmResult<ChatResponse> {
        self.calls.fetch_add(1, Ordering::SeqCst);

        // Consume one failure atomically. A separate `load` followed by `fetch_sub`
        // lets two concurrent callers both observe 1 and both decrement. That wraps
        // the counter to u32::MAX and makes the mock fail forever. `checked_sub`
        // leaves the budget at 0 once it is exhausted.
        let consumed_failure = self
            .fail_times
            .fetch_update(Ordering::SeqCst, Ordering::SeqCst, |left| left.checked_sub(1))
            .is_ok();
        if consumed_failure {
            return Err(LlmError::Timeout);
        }

        Ok(ChatResponse {
            id: "mock".into(),
            provider: "mock".into(),
            model: self.model.clone(),
            message: Message::assistant("ok"),
            finish_reason: FinishReason::Stop,
            tool_calls: vec![],
            usage: self.usage,
            cost: None,
            latency_ms: 0,
        })
    }

    /// Streaming is not modelled by this mock; it always reports `Unsupported`.
    async fn chat_stream(&self, _req: ChatRequest) -> LlmResult<ChatStream> {
        Err(LlmError::Unsupported("mock stream".into()))
    }

    /// Returns a single 4-dimensional zero vector, whatever the request contains.
    async fn embed(&self, _req: EmbedRequest) -> LlmResult<EmbedResponse> {
        Ok(EmbedResponse {
            provider: "mock".into(),
            model: self.model.clone(),
            embeddings: vec![vec![0.0; 4]],
            usage: TokenUsage::default(),
        })
    }

    fn name(&self) -> &'static str {
        "mock"
    }

    fn model(&self) -> &str {
        &self.model
    }
}

#[tokio::test(start_paused = true)]
async fn retry_recovers_after_transient_failures() {
    let mock = MockProvider::failing("claude-opus-4-8", 2);
    let calls = mock.calls.clone();

    let client = LlmClientBuilder::new()
        .provider(mock)
        .layer(RetryLayer::exponential(3, Duration::from_millis(10)))
        .build()
        .unwrap();

    let resp = client.chat_once(ChatRequest::builder().user("hi").build()).await.unwrap();
    assert_eq!(resp.text().as_deref(), Some("ok"));
    assert_eq!(calls.load(Ordering::SeqCst), 3); // 2 failures + 1 success
}

#[tokio::test(start_paused = true)]
async fn retry_gives_up_and_returns_error() {
    let mock = MockProvider::failing("gpt-4o", 5);

    let client = LlmClientBuilder::new()
        .provider(mock)
        .layer(RetryLayer::exponential(2, Duration::from_millis(10)))
        .build()
        .unwrap();

    let err = client.chat_once(ChatRequest::builder().user("hi").build()).await.unwrap_err();
    assert!(matches!(err, LlmError::Timeout));
}

#[tokio::test]
async fn cost_tracking_accumulates_session_cost() {
    let layer = CostTrackingLayer::new();
    let session = layer.session_cost();

    // The model is priced at $5/$25 per MTok.
    let client = LlmClientBuilder::new()
        .provider(MockProvider::new("claude-opus-4-8"))
        .layer(layer)
        .track_cost(session)
        .build()
        .unwrap();

    let ask = || ChatRequest::builder().user("x").build();

    // Each call costs 1M prompt + 1M completion => $5 + $25 = $30, so the running
    // total should read $30 after the first call and $60 after the second.
    for expected_total in [30.0, 60.0] {
        client.chat_once(ask()).await.unwrap();
        let spent = client.session_cost_usd();
        assert!((spent - expected_total).abs() < 1e-6);
    }
}

#[tokio::test]
async fn budget_cap_refuses_once_exceeded() {
    // A $20 budget: the first $30 call is allowed through, after which the
    // session is over budget and further calls must be refused.
    let budgeted = CostTrackingLayer::with_budget(20.0);
    let client = LlmClientBuilder::new()
        .provider(MockProvider::new("claude-opus-4-8"))
        .layer(budgeted)
        .build()
        .unwrap();

    let first = ChatRequest::builder().user("x").build();
    client.chat_once(first).await.unwrap();

    let second = ChatRequest::builder().user("x").build();
    let refusal = client.chat_once(second).await.unwrap_err();
    assert!(
        matches!(refusal, LlmError::BudgetExceeded(_)),
        "expected BudgetExceeded, got {:?}",
        refusal
    );
}

#[tokio::test]
async fn fallback_advances_on_retryable_error() {
    // A failure budget of u32::MAX makes the primary effectively always time out.
    // A small budget such as 10 would let it recover after enough attempts.
    let primary = MockProvider::failing("gpt-4o", u32::MAX);
    let primary_calls = primary.calls.clone();
    let secondary = MockProvider::new("claude-opus-4-8");
    let secondary_calls = secondary.calls.clone();

    let client = LlmClientBuilder::new()
        .provider(primary)
        .fallback(secondary)
        .build()
        .unwrap();

    let resp = client.chat_once(ChatRequest::builder().user("hi").build()).await.unwrap();
    assert_eq!(resp.model, "claude-opus-4-8");
    // The primary must actually have been tried before the client fell back.
    assert!(primary_calls.load(Ordering::SeqCst) >= 1, "primary provider was never called");
    assert_eq!(secondary_calls.load(Ordering::SeqCst), 1);
}

#[tokio::test]
async fn rate_limit_allows_within_capacity() {
    let client = LlmClientBuilder::new()
        .provider(MockProvider::new("gpt-4o-mini"))
        .layer(RateLimitLayer::token_bucket(1_000_000, Duration::from_secs(60)))
        .build()
        .unwrap();

    // A small request sits well within the bucket and must not block. Bound it with
    // a timeout so a limiter that wrongly waits fails the test instead of hanging it.
    let resp = tokio::time::timeout(
        Duration::from_secs(5),
        client.chat_once(ChatRequest::builder().user("hi").build()),
    )
    .await
    .expect("rate limiter blocked a request that fits within the bucket")
    .unwrap();
    assert_eq!(resp.text().as_deref(), Some("ok"));
}

#[test]
fn pricing_table_matches_expectations() {
    // Opus is listed at $5 input / $25 output per million tokens.
    let expected = ModelPricing { output_per_mtok: 25.0, input_per_mtok: 5.0 };
    let looked_up = llmkit::pricing::pricing_for("claude-opus-4-8");
    assert_eq!(looked_up, Some(expected));
}