llmkit-tower 0.1.0

Tower middleware (retry, rate limit, cost tracking, tracing) for llmkit-rs
Documentation
//! Exponential-backoff retry layer.

use std::sync::Arc;
use std::time::Duration;

use async_trait::async_trait;
use llmkit_core::{
    ChatRequest, ChatResponse, ChatStream, CostEstimate, EmbedRequest, EmbedResponse, LlmError,
    LlmProvider, LlmResult,
};

use crate::layer::LlmLayer;

/// Configures exponential-backoff retries for retryable errors.
///
/// Build one with [`RetryLayer::exponential`] and optionally tighten the delay
/// cap with [`RetryLayer::max_backoff`]. The value is `Copy`, so a single
/// configuration can be applied to several providers.
#[derive(Debug, Clone, Copy)]
pub struct RetryLayer {
    /// Total number of tries (the first call plus retries); the constructor
    /// clamps this to at least 1.
    attempts: u32,
    /// Delay before the first retry; multiplied by `2^attempt` for later ones.
    base: Duration,
    /// Upper bound on any single delay, including a server-supplied
    /// retry-after. The constructor sets it to 30 seconds.
    max_backoff: Duration,
}

impl RetryLayer {
    /// Creates a layer that makes up to `attempts` tries in total, doubling the
    /// delay after each failure starting from `base`.
    ///
    /// An `attempts` of zero is treated as one so the wrapped call is always
    /// issued at least once. The delay cap starts at 30 seconds; adjust it with
    /// [`RetryLayer::max_backoff`].
    pub fn exponential(attempts: u32, base: Duration) -> Self {
        let attempts = if attempts == 0 { 1 } else { attempts };
        let max_backoff = Duration::from_secs(30);
        Self { attempts, base, max_backoff }
    }

    /// Returns the same configuration with the per-retry delay capped at `max`.
    pub fn max_backoff(self, max: Duration) -> Self {
        Self { max_backoff: max, ..self }
    }
}

impl LlmLayer for RetryLayer {
    type Provider = Retry;

    /// Wraps `inner` in a [`Retry`] provider that carries this configuration.
    fn layer(self, inner: Arc<dyn LlmProvider>) -> Self::Provider {
        let cfg = self;
        Retry { inner, cfg }
    }
}

/// Provider produced by [`RetryLayer`].
///
/// Forwards every call to the wrapped provider and re-issues calls that fail
/// with a retryable error, waiting between tries as configured.
pub struct Retry {
    /// The wrapped provider that actually performs each request.
    inner: Arc<dyn LlmProvider>,
    /// Retry configuration copied from the layer that built this provider.
    cfg: RetryLayer,
}

impl Retry {
    /// Computes how long to wait after the zero-based `attempt` failed with `err`.
    ///
    /// A rate-limit error that carries a retry-after value uses that value;
    /// every other error uses `base * 2^attempt` with saturating arithmetic.
    /// Either way the result never exceeds the configured `max_backoff`.
    fn backoff(&self, attempt: u32, err: &LlmError) -> Duration {
        let cap = self.cfg.max_backoff;
        let delay = match err {
            // Prefer the server's own hint over our computed schedule.
            LlmError::RateLimited { retry_after: Some(hint), .. } => *hint,
            _ => self.cfg.base.saturating_mul(2u32.saturating_pow(attempt)),
        };
        delay.min(cap)
    }
}

impl Retry {
    /// Shared retry loop behind `chat`, `chat_stream` and `embed`.
    ///
    /// Invokes `call` with a fresh clone of `req` up to `cfg.attempts` times.
    /// A success is returned immediately. A retryable error is logged at debug
    /// level (tagged with `op`) and followed by a backoff sleep, as long as
    /// another attempt remains. A non-retryable error, or a retryable error on
    /// the final attempt, is returned to the caller unchanged.
    async fn retry_call<R, T, F, Fut>(&self, op: &'static str, req: R, call: F) -> LlmResult<T>
    where
        R: Clone,
        F: Fn(R) -> Fut,
        Fut: std::future::Future<Output = LlmResult<T>>,
    {
        for attempt in 0..self.cfg.attempts {
            match call(req.clone()).await {
                Ok(value) => return Ok(value),
                Err(e) if e.is_retryable() && attempt + 1 < self.cfg.attempts => {
                    tracing::debug!(
                        provider = self.inner.name(),
                        op,
                        attempt,
                        error = %e,
                        "retrying"
                    );
                    tokio::time::sleep(self.backoff(attempt, &e)).await;
                }
                Err(e) => return Err(e),
            }
        }
        // Only reachable when `attempts` is zero: the last permitted attempt
        // always returns from inside the loop, so no earlier error is pending.
        Err(LlmError::Other("retry exhausted".into()))
    }
}

#[async_trait]
impl LlmProvider for Retry {
    /// Sends a chat request, retrying retryable failures with backoff.
    async fn chat(&self, req: ChatRequest) -> LlmResult<ChatResponse> {
        self.retry_call("chat", req, |r| self.inner.chat(r)).await
    }

    /// Opens a chat stream, retrying retryable failures with backoff.
    ///
    /// Only establishing the stream is retried; an error yielded by the stream
    /// after it has been returned is not handled here.
    async fn chat_stream(&self, req: ChatRequest) -> LlmResult<ChatStream> {
        self.retry_call("chat_stream", req, |r| self.inner.chat_stream(r)).await
    }

    /// Sends an embedding request, retrying retryable failures with backoff.
    async fn embed(&self, req: EmbedRequest) -> LlmResult<EmbedResponse> {
        self.retry_call("embed", req, |r| self.inner.embed(r)).await
    }

    /// Name reported by the wrapped provider.
    fn name(&self) -> &'static str {
        self.inner.name()
    }

    /// Model reported by the wrapped provider.
    fn model(&self) -> &str {
        self.inner.model()
    }

    /// Cost estimate from the wrapped provider; retries are not factored in.
    fn estimate_cost(&self, req: &ChatRequest) -> Option<CostEstimate> {
        self.inner.estimate_cost(req)
    }
}