1use std::sync::Arc;
4use std::time::Duration;
5
6use async_trait::async_trait;
7use llmkit_core::{
8 ChatRequest, ChatResponse, ChatStream, CostEstimate, EmbedRequest, EmbedResponse, LlmError,
9 LlmProvider, LlmResult,
10};
11
12use crate::layer::LlmLayer;
13
/// Retry policy: how many times a call may be attempted and how long to wait
/// between attempts.
///
/// Construct it with `RetryLayer::exponential` and, if needed, adjust the delay
/// cap with the `max_backoff` builder method.
#[derive(Clone, Copy, Debug)]
pub struct RetryLayer {
    /// Total number of calls allowed per request, first try included (never 0
    /// when built through `exponential`).
    attempts: u32,
    /// Starting delay that the exponential schedule scales up from.
    base: Duration,
    /// Upper bound applied to every delay.
    max_backoff: Duration,
}

impl RetryLayer {
    /// Builds a policy that allows up to `attempts` calls, with delays growing
    /// from `base`.
    ///
    /// An `attempts` of `0` is raised to `1`. The delay cap starts at 30 seconds.
    pub fn exponential(attempts: u32, base: Duration) -> Self {
        // Zero attempts would never invoke the wrapped provider; clamp to one.
        let attempts = if attempts == 0 { 1 } else { attempts };
        let max_backoff = Duration::from_secs(30);
        Self { attempts, base, max_backoff }
    }

    /// Returns a copy of this policy whose delay cap is `max`.
    pub fn max_backoff(self, max: Duration) -> Self {
        Self { max_backoff: max, ..self }
    }
}
34
35impl LlmLayer for RetryLayer {
36 type Provider = Retry;
37 fn layer(self, inner: Arc<dyn LlmProvider>) -> Retry {
38 Retry { inner, cfg: self }
39 }
40}
41
42pub struct Retry {
44 inner: Arc<dyn LlmProvider>,
45 cfg: RetryLayer,
46}
47
48impl Retry {
49 fn backoff(&self, attempt: u32, err: &LlmError) -> Duration {
50 if let LlmError::RateLimited { retry_after: Some(d), .. } = err {
52 return (*d).min(self.cfg.max_backoff);
53 }
54 let mult = 2u32.saturating_pow(attempt);
55 self.cfg.base.saturating_mul(mult).min(self.cfg.max_backoff)
56 }
57}
58
59#[async_trait]
60impl LlmProvider for Retry {
61 async fn chat(&self, req: ChatRequest) -> LlmResult<ChatResponse> {
62 let mut last = None;
63 for attempt in 0..self.cfg.attempts {
64 match self.inner.chat(req.clone()).await {
65 Ok(resp) => return Ok(resp),
66 Err(e) if e.is_retryable() && attempt + 1 < self.cfg.attempts => {
67 tracing::debug!(provider = self.inner.name(), attempt, error = %e, "retrying");
68 tokio::time::sleep(self.backoff(attempt, &e)).await;
69 last = Some(e);
70 }
71 Err(e) => return Err(e),
72 }
73 }
74 Err(last.unwrap_or_else(|| LlmError::Other("retry exhausted".into())))
75 }
76
77 async fn chat_stream(&self, req: ChatRequest) -> LlmResult<ChatStream> {
78 let mut last = None;
80 for attempt in 0..self.cfg.attempts {
81 match self.inner.chat_stream(req.clone()).await {
82 Ok(s) => return Ok(s),
83 Err(e) if e.is_retryable() && attempt + 1 < self.cfg.attempts => {
84 tokio::time::sleep(self.backoff(attempt, &e)).await;
85 last = Some(e);
86 }
87 Err(e) => return Err(e),
88 }
89 }
90 Err(last.unwrap_or_else(|| LlmError::Other("retry exhausted".into())))
91 }
92
93 async fn embed(&self, req: EmbedRequest) -> LlmResult<EmbedResponse> {
94 let mut last = None;
95 for attempt in 0..self.cfg.attempts {
96 match self.inner.embed(req.clone()).await {
97 Ok(r) => return Ok(r),
98 Err(e) if e.is_retryable() && attempt + 1 < self.cfg.attempts => {
99 tokio::time::sleep(self.backoff(attempt, &e)).await;
100 last = Some(e);
101 }
102 Err(e) => return Err(e),
103 }
104 }
105 Err(last.unwrap_or_else(|| LlmError::Other("retry exhausted".into())))
106 }
107
108 fn name(&self) -> &'static str {
109 self.inner.name()
110 }
111
112 fn model(&self) -> &str {
113 self.inner.model()
114 }
115
116 fn estimate_cost(&self, req: &ChatRequest) -> Option<CostEstimate> {
117 self.inner.estimate_cost(req)
118 }
119}