Skip to main content

llmkit_tower/
retry.rs

1//! Exponential-backoff retry layer.
2
3use std::sync::Arc;
4use std::time::Duration;
5
6use async_trait::async_trait;
7use llmkit_core::{
8    ChatRequest, ChatResponse, ChatStream, CostEstimate, EmbedRequest, EmbedResponse, LlmError,
9    LlmProvider, LlmResult,
10};
11
12use crate::layer::LlmLayer;
13
/// Configures exponential-backoff retries for retryable errors.
///
/// Build one with [`RetryLayer::exponential`], optionally cap the delay with
/// [`RetryLayer::max_backoff`], then apply it through [`LlmLayer::layer`] to
/// obtain a [`Retry`] provider.
#[derive(Debug, Clone, Copy)]
pub struct RetryLayer {
    /// Total number of tries, the first call included; always at least 1.
    attempts: u32,
    /// Delay before the first retry; doubled for each subsequent retry.
    base: Duration,
    /// Upper bound applied to every delay between tries.
    max_backoff: Duration,
}

impl RetryLayer {
    /// Cap used for the backoff delay unless overridden via [`RetryLayer::max_backoff`].
    const DEFAULT_MAX_BACKOFF: Duration = Duration::from_secs(30);

    /// `attempts` total tries with exponential backoff starting at `base`.
    ///
    /// The delay before retry `k` (0-based) is `base * 2^k`, capped at 30 seconds
    /// unless changed with [`RetryLayer::max_backoff`]. An `attempts` value of 0 is
    /// treated as 1, i.e. a single try with no retries.
    #[must_use]
    pub fn exponential(attempts: u32, base: Duration) -> Self {
        Self {
            attempts: attempts.max(1),
            base,
            max_backoff: Self::DEFAULT_MAX_BACKOFF,
        }
    }

    /// Cap the backoff delay.
    ///
    /// Returns the updated configuration. `RetryLayer` is `Copy` and this takes
    /// `self` by value, so ignoring the return value leaves the caller's copy
    /// unchanged — hence `#[must_use]`.
    #[must_use = "`max_backoff` returns the updated configuration; the receiver is not modified"]
    pub fn max_backoff(mut self, max: Duration) -> Self {
        self.max_backoff = max;
        self
    }
}
34
35impl LlmLayer for RetryLayer {
36    type Provider = Retry;
37    fn layer(self, inner: Arc<dyn LlmProvider>) -> Retry {
38        Retry { inner, cfg: self }
39    }
40}
41
42/// Provider produced by [`RetryLayer`].
43pub struct Retry {
44    inner: Arc<dyn LlmProvider>,
45    cfg: RetryLayer,
46}
47
48impl Retry {
49    fn backoff(&self, attempt: u32, err: &LlmError) -> Duration {
50        // Honour a server-supplied retry-after when present.
51        if let LlmError::RateLimited { retry_after: Some(d), .. } = err {
52            return (*d).min(self.cfg.max_backoff);
53        }
54        let mult = 2u32.saturating_pow(attempt);
55        self.cfg.base.saturating_mul(mult).min(self.cfg.max_backoff)
56    }
57}
58
59#[async_trait]
60impl LlmProvider for Retry {
61    async fn chat(&self, req: ChatRequest) -> LlmResult<ChatResponse> {
62        let mut last = None;
63        for attempt in 0..self.cfg.attempts {
64            match self.inner.chat(req.clone()).await {
65                Ok(resp) => return Ok(resp),
66                Err(e) if e.is_retryable() && attempt + 1 < self.cfg.attempts => {
67                    tracing::debug!(provider = self.inner.name(), attempt, error = %e, "retrying");
68                    tokio::time::sleep(self.backoff(attempt, &e)).await;
69                    last = Some(e);
70                }
71                Err(e) => return Err(e),
72            }
73        }
74        Err(last.unwrap_or_else(|| LlmError::Other("retry exhausted".into())))
75    }
76
77    async fn chat_stream(&self, req: ChatRequest) -> LlmResult<ChatStream> {
78        // Retry only the connection establishment, not mid-stream failures.
79        let mut last = None;
80        for attempt in 0..self.cfg.attempts {
81            match self.inner.chat_stream(req.clone()).await {
82                Ok(s) => return Ok(s),
83                Err(e) if e.is_retryable() && attempt + 1 < self.cfg.attempts => {
84                    tokio::time::sleep(self.backoff(attempt, &e)).await;
85                    last = Some(e);
86                }
87                Err(e) => return Err(e),
88            }
89        }
90        Err(last.unwrap_or_else(|| LlmError::Other("retry exhausted".into())))
91    }
92
93    async fn embed(&self, req: EmbedRequest) -> LlmResult<EmbedResponse> {
94        let mut last = None;
95        for attempt in 0..self.cfg.attempts {
96            match self.inner.embed(req.clone()).await {
97                Ok(r) => return Ok(r),
98                Err(e) if e.is_retryable() && attempt + 1 < self.cfg.attempts => {
99                    tokio::time::sleep(self.backoff(attempt, &e)).await;
100                    last = Some(e);
101                }
102                Err(e) => return Err(e),
103            }
104        }
105        Err(last.unwrap_or_else(|| LlmError::Other("retry exhausted".into())))
106    }
107
108    fn name(&self) -> &'static str {
109        self.inner.name()
110    }
111
112    fn model(&self) -> &str {
113        self.inner.model()
114    }
115
116    fn estimate_cost(&self, req: &ChatRequest) -> Option<CostEstimate> {
117        self.inner.estimate_cost(req)
118    }
119}