Skip to main content

crabllm_core/
retrying.rs

1use crate::{
2    AnthropicRequest, AnthropicResponse, AnthropicStreamEvent, AudioSpeechRequest, BoxStream,
3    ChatCompletionChunk, ChatCompletionRequest, ChatCompletionResponse, EmbeddingRequest,
4    EmbeddingResponse, Error, GeminiRequest, GeminiResponse, ImageRequest, MultipartField,
5    Provider,
6};
7use rand::Rng;
8use std::{future::Future, time::Duration};
9
/// Retries performed after the first attempt unless overridden.
const DEFAULT_MAX_RETRIES: u32 = 2;
/// Default per-attempt deadline.
const DEFAULT_TIMEOUT: Duration = Duration::from_millis(30_000);
/// Default upper bound on an upstream `retry_after` hint the wrapper will wait out.
const DEFAULT_MAX_RETRY_AFTER: Duration = Duration::from_millis(60_000);
/// Backoff used before the first retry; doubled after every retry.
const INITIAL_BACKOFF: Duration = Duration::from_millis(100);
14
15/// A `Provider` wrapper that retries transient failures with exponential
16/// backoff and full jitter, and bounds each attempt with a per-call timeout.
17///
18/// 429s whose `retry_after` exceeds `max_retry_after` are propagated
19/// immediately — the upstream is signalling a wait longer than this wrapper
20/// is willing to block for.
21#[derive(Debug, Clone)]
22pub struct Retrying<P: Provider> {
23    inner: P,
24    max_retries: u32,
25    timeout: Duration,
26    max_retry_after: Duration,
27}
28
29impl<P: Provider> Retrying<P> {
30    /// Wrap a provider with the default retry policy
31    /// (2 retries, 30s timeout, 60s max retry-after, 100ms initial backoff).
32    pub fn new(inner: P) -> Self {
33        Self {
34            inner,
35            max_retries: DEFAULT_MAX_RETRIES,
36            timeout: DEFAULT_TIMEOUT,
37            max_retry_after: DEFAULT_MAX_RETRY_AFTER,
38        }
39    }
40
41    /// Override the maximum `Retry-After` duration this wrapper will honor.
42    /// 429s above this threshold are propagated as non-retryable.
43    pub fn max_retry_after(mut self, d: Duration) -> Self {
44        self.max_retry_after = d;
45        self
46    }
47
48    /// Whether this error should be retried. Transient errors are retryable
49    /// unless they carry a `retry_after` that exceeds the threshold.
50    fn should_retry(&self, e: &Error) -> bool {
51        if !e.is_transient() {
52            return false;
53        }
54        !matches!(e.retry_after(), Some(ra) if ra > self.max_retry_after)
55    }
56
57    async fn timed<T>(
58        &self,
59        fut: impl Future<Output = Result<T, Error>> + Send,
60    ) -> Result<T, Error> {
61        if self.timeout.is_zero() {
62            return fut.await;
63        }
64        let Ok(result) = tokio::time::timeout(self.timeout, fut).await else {
65            return Err(Error::Timeout);
66        };
67        result
68    }
69}
70
71impl<P: Provider> Provider for Retrying<P> {
72    async fn chat_completion(
73        &self,
74        request: &ChatCompletionRequest,
75    ) -> Result<ChatCompletionResponse, Error> {
76        let mut backoff = INITIAL_BACKOFF;
77        let mut last_err = None;
78        for _ in 0..=self.max_retries {
79            match self.timed(self.inner.chat_completion(request)).await {
80                Ok(resp) => return Ok(resp),
81                Err(e) if self.should_retry(&e) => {
82                    let sleep = e.retry_after().unwrap_or_else(|| jittered(backoff));
83                    last_err = Some(e);
84                    tokio::time::sleep(sleep).await;
85                    backoff *= 2;
86                }
87                Err(e) => return Err(e),
88            }
89        }
90        Err(last_err.expect("retry loop exited without producing an error"))
91    }
92
93    async fn chat_completion_stream(
94        &self,
95        request: &ChatCompletionRequest,
96    ) -> Result<BoxStream<'static, Result<ChatCompletionChunk, Error>>, Error> {
97        let mut backoff = INITIAL_BACKOFF;
98        let mut last_err = None;
99        for _ in 0..=self.max_retries {
100            match self.timed(self.inner.chat_completion_stream(request)).await {
101                Ok(stream) => return Ok(stream),
102                Err(e) if self.should_retry(&e) => {
103                    let sleep = e.retry_after().unwrap_or_else(|| jittered(backoff));
104                    last_err = Some(e);
105                    tokio::time::sleep(sleep).await;
106                    backoff *= 2;
107                }
108                Err(e) => return Err(e),
109            }
110        }
111        Err(last_err.expect("retry loop exited without producing an error"))
112    }
113
114    async fn anthropic_messages(
115        &self,
116        request: &AnthropicRequest,
117    ) -> Result<AnthropicResponse, Error> {
118        let mut backoff = INITIAL_BACKOFF;
119        let mut last_err = None;
120        for _ in 0..=self.max_retries {
121            match self.timed(self.inner.anthropic_messages(request)).await {
122                Ok(resp) => return Ok(resp),
123                Err(e) if self.should_retry(&e) => {
124                    let sleep = e.retry_after().unwrap_or_else(|| jittered(backoff));
125                    last_err = Some(e);
126                    tokio::time::sleep(sleep).await;
127                    backoff *= 2;
128                }
129                Err(e) => return Err(e),
130            }
131        }
132        Err(last_err.expect("retry loop exited without producing an error"))
133    }
134
135    async fn anthropic_messages_stream(
136        &self,
137        request: &AnthropicRequest,
138    ) -> Result<BoxStream<'static, Result<AnthropicStreamEvent, Error>>, Error> {
139        let mut backoff = INITIAL_BACKOFF;
140        let mut last_err = None;
141        for _ in 0..=self.max_retries {
142            match self
143                .timed(self.inner.anthropic_messages_stream(request))
144                .await
145            {
146                Ok(stream) => return Ok(stream),
147                Err(e) if self.should_retry(&e) => {
148                    let sleep = e.retry_after().unwrap_or_else(|| jittered(backoff));
149                    last_err = Some(e);
150                    tokio::time::sleep(sleep).await;
151                    backoff *= 2;
152                }
153                Err(e) => return Err(e),
154            }
155        }
156        Err(last_err.expect("retry loop exited without producing an error"))
157    }
158
159    async fn gemini_generate_content_stream(
160        &self,
161        model: &str,
162        request: &GeminiRequest,
163    ) -> Result<BoxStream<'static, Result<GeminiResponse, Error>>, Error> {
164        let mut backoff = INITIAL_BACKOFF;
165        let mut last_err = None;
166        for _ in 0..=self.max_retries {
167            match self
168                .timed(self.inner.gemini_generate_content_stream(model, request))
169                .await
170            {
171                Ok(stream) => return Ok(stream),
172                Err(e) if self.should_retry(&e) => {
173                    let sleep = e.retry_after().unwrap_or_else(|| jittered(backoff));
174                    last_err = Some(e);
175                    tokio::time::sleep(sleep).await;
176                    backoff *= 2;
177                }
178                Err(e) => return Err(e),
179            }
180        }
181        Err(last_err.expect("retry loop exited without producing an error"))
182    }
183
184    async fn embedding(&self, request: &EmbeddingRequest) -> Result<EmbeddingResponse, Error> {
185        self.inner.embedding(request).await
186    }
187
188    async fn image_generation(
189        &self,
190        request: &ImageRequest,
191    ) -> Result<(bytes::Bytes, String), Error> {
192        self.inner.image_generation(request).await
193    }
194
195    async fn audio_speech(
196        &self,
197        request: &AudioSpeechRequest,
198    ) -> Result<(bytes::Bytes, String), Error> {
199        self.inner.audio_speech(request).await
200    }
201
202    async fn audio_transcription(
203        &self,
204        model: &str,
205        fields: &[MultipartField],
206    ) -> Result<(bytes::Bytes, String), Error> {
207        self.inner.audio_transcription(model, fields).await
208    }
209}
210
211/// Full jitter: random duration in [backoff/2, backoff].
212fn jittered(backoff: Duration) -> Duration {
213    let lo = backoff.as_millis() as u64 / 2;
214    let hi = backoff.as_millis() as u64;
215    if lo >= hi {
216        return backoff;
217    }
218    Duration::from_millis(rand::rng().random_range(lo..=hi))
219}