llmkit-tower 0.1.0

Tower middleware (retry, rate limit, cost tracking, tracing) for llmkit-rs
Documentation
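//! Provider fallback chaining: tries each provider in order (primary → secondary → …),
//! moving on to the next one when a call fails with a retryable error (e.g. a timeout).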

use std::sync::Arc;

use async_trait::async_trait;
use llmkit_core::{
    ChatRequest, ChatResponse, ChatStream, CostEstimate, EmbedRequest, EmbedResponse, LlmError,
    LlmProvider, LlmResult,
};

/// Tries providers in order, advancing to the next on a retryable failure.
///
/// Non-retryable errors (invalid request, auth, budget) short-circuit — there
/// is no point trying the next provider if the request itself is malformed.
///
/// When no provider succeeds, every error seen along the way is returned
/// together inside `LlmError::AllProvidersFailed`, in the order the providers
/// were attempted.
///
/// `name`, `model` and `estimate_cost` are answered by the first (primary)
/// provider only.
pub struct FallbackProvider {
    // Ordered chain; index 0 is the primary provider.
    providers: Vec<Arc<dyn LlmProvider>>,
}

impl FallbackProvider {
    /// Build a chain from an ordered list of providers (primary first).
    pub fn new(providers: Vec<Arc<dyn LlmProvider>>) -> Self {
        assert!(!providers.is_empty(), "FallbackProvider requires at least one provider");
        Self { providers }
    }

    fn should_advance(err: &LlmError) -> bool {
        err.is_retryable()
    }
}

#[async_trait]
impl LlmProvider for FallbackProvider {
    // All three request methods below follow the same policy:
    //
    // * every provider except the last is given a *clone* of the request;
    //   a retryable error is logged, recorded, and the next provider is tried;
    // * a non-retryable error stops the chain immediately;
    // * the last provider receives the request by value (no clone needed) and
    //   any error from it ends the chain;
    // * on failure the caller always gets `LlmError::AllProvidersFailed` with
    //   the errors in attempt order.

    /// Sends a chat request through the chain and returns the first success.
    ///
    /// # Errors
    ///
    /// Returns `LlmError::AllProvidersFailed` containing every error observed,
    /// either because a provider returned a non-retryable error or because the
    /// last provider in the chain failed.
    async fn chat(&self, req: ChatRequest) -> LlmResult<ChatResponse> {
        let (final_provider, earlier) = match self.providers.split_last() {
            Some(parts) => parts,
            // `new` rejects an empty chain, so this arm is purely defensive:
            // it avoids an index underflow instead of panicking.
            None => return Err(LlmError::AllProvidersFailed(Vec::new())),
        };

        let mut errors = Vec::new();
        for p in earlier {
            match p.chat(req.clone()).await {
                Ok(resp) => return Ok(resp),
                Err(e) if Self::should_advance(&e) => {
                    tracing::warn!(provider = p.name(), error = %e, "falling back to next provider");
                    errors.push(e);
                }
                Err(e) => {
                    // Non-retryable: another provider would fail the same way.
                    errors.push(e);
                    return Err(LlmError::AllProvidersFailed(errors));
                }
            }
        }

        // Last attempt: nothing left to fall back to, so the request is moved.
        match final_provider.chat(req).await {
            Ok(resp) => Ok(resp),
            Err(e) => {
                errors.push(e);
                Err(LlmError::AllProvidersFailed(errors))
            }
        }
    }

    /// Opens a chat stream on the first provider that accepts the request.
    ///
    /// Fallback only applies to errors returned while *opening* the stream;
    /// once a stream has been handed back it is returned to the caller as-is.
    ///
    /// # Errors
    ///
    /// Returns `LlmError::AllProvidersFailed` containing every error observed,
    /// either because a provider returned a non-retryable error or because the
    /// last provider in the chain failed.
    async fn chat_stream(&self, req: ChatRequest) -> LlmResult<ChatStream> {
        let (final_provider, earlier) = match self.providers.split_last() {
            Some(parts) => parts,
            // Defensive only; `new` rejects an empty chain.
            None => return Err(LlmError::AllProvidersFailed(Vec::new())),
        };

        let mut errors = Vec::new();
        for p in earlier {
            match p.chat_stream(req.clone()).await {
                Ok(stream) => return Ok(stream),
                Err(e) if Self::should_advance(&e) => {
                    tracing::warn!(provider = p.name(), error = %e, "falling back to next provider");
                    errors.push(e);
                }
                Err(e) => {
                    errors.push(e);
                    return Err(LlmError::AllProvidersFailed(errors));
                }
            }
        }

        match final_provider.chat_stream(req).await {
            Ok(stream) => Ok(stream),
            Err(e) => {
                errors.push(e);
                Err(LlmError::AllProvidersFailed(errors))
            }
        }
    }

    /// Sends an embedding request through the chain and returns the first success.
    ///
    /// # Errors
    ///
    /// Returns `LlmError::AllProvidersFailed` containing every error observed,
    /// either because a provider returned a non-retryable error or because the
    /// last provider in the chain failed.
    async fn embed(&self, req: EmbedRequest) -> LlmResult<EmbedResponse> {
        let (final_provider, earlier) = match self.providers.split_last() {
            Some(parts) => parts,
            // Defensive only; `new` rejects an empty chain.
            None => return Err(LlmError::AllProvidersFailed(Vec::new())),
        };

        let mut errors = Vec::new();
        for p in earlier {
            match p.embed(req.clone()).await {
                Ok(resp) => return Ok(resp),
                Err(e) if Self::should_advance(&e) => {
                    tracing::warn!(provider = p.name(), error = %e, "falling back to next provider");
                    errors.push(e);
                }
                Err(e) => {
                    errors.push(e);
                    return Err(LlmError::AllProvidersFailed(errors));
                }
            }
        }

        match final_provider.embed(req).await {
            Ok(resp) => Ok(resp),
            Err(e) => {
                errors.push(e);
                Err(LlmError::AllProvidersFailed(errors))
            }
        }
    }

    /// Name of the primary (first) provider in the chain.
    fn name(&self) -> &'static str {
        self.providers[0].name()
    }

    /// Model identifier reported by the primary (first) provider.
    fn model(&self) -> &str {
        self.providers[0].model()
    }

    /// Cost estimate from the primary (first) provider only; it does not
    /// account for a request that ends up being served by a fallback.
    fn estimate_cost(&self, req: &ChatRequest) -> Option<CostEstimate> {
        self.providers[0].estimate_cost(req)
    }
}