Skip to main content

crabtalk_daemon/
provider.rs

1//! `Retrying<P>` — a `Provider` wrapper that adds exponential-backoff retry
2//! and per-call timeout on top of any inner provider.
3//!
4//! This restores the retry/timeout semantics that lived in the old
5//! `crates/model::Provider` wrapper before the trait migration. It is a
6//! deployment-layer concern owned by the daemon — `wcore::Model<P>` does
7//! not retry, since not every consumer (e.g. an in-process MLX provider)
8//! wants the same retry policy.
9//!
10//! Per-provider retry config (`max_retries` / `timeout` on individual
11//! `ProviderDef` entries) is not threaded through yet; the wrapper applies
12//! a single set of defaults to every dispatch. Restoring per-provider
13//! config is a follow-up — see TODO below.
14
15use crabllm_core::{
16    AudioSpeechRequest, BoxStream, ChatCompletionChunk, ChatCompletionRequest,
17    ChatCompletionResponse, EmbeddingRequest, EmbeddingResponse, Error, ImageRequest,
18    MultipartField, Provider,
19};
20use rand::Rng;
21use std::time::Duration;
22
/// Default values matching the old `crates/model::Provider` defaults.
///
/// Retry budget installed by `Retrying::new`: how many times a transiently
/// failing `chat_completion` is re-attempted after the first try.
const DEFAULT_MAX_RETRIES: u32 = 2;
/// Per-attempt timeout installed by `Retrying::new`.
const DEFAULT_TIMEOUT: Duration = Duration::from_secs(30);
/// Backoff the retry loop starts from; it is doubled after each retry.
const INITIAL_BACKOFF: Duration = Duration::from_millis(100);
27
28/// A `Provider` wrapper that retries transient failures with exponential
29/// backoff and full jitter, and bounds each attempt with a per-call timeout.
30///
31/// **Scope:** the retry policy applies to `chat_completion` only. Streaming
32/// (`chat_completion_stream`) skips retry — the connection is already
33/// established and clients consuming chunk-by-chunk handle their own
34/// resumption — but still bounds the connection-establishment phase with
35/// the same timeout. The non-chat methods (`embedding`, `image_generation`,
36/// `audio_speech`, `audio_transcription`) are bare pass-throughs without
37/// retry or timeout, because the daemon's current protocol doesn't expose
38/// these endpoints. If a future daemon feature needs them, extend this
39/// wrapper's scope at that point.
40#[derive(Debug, Clone)]
41pub struct Retrying<P: Provider> {
42    inner: P,
43    max_retries: u32,
44    timeout: Duration,
45}
46
47impl<P: Provider> Retrying<P> {
48    /// Wrap a provider with the default retry policy
49    /// (2 retries, 30s timeout, 100ms initial backoff).
50    pub fn new(inner: P) -> Self {
51        Self {
52            inner,
53            max_retries: DEFAULT_MAX_RETRIES,
54            timeout: DEFAULT_TIMEOUT,
55        }
56    }
57
58    // TODO: per-provider retry config. The old crates/model wrapper read
59    // `ProviderDef.max_retries` and `ProviderDef.timeout` per provider.
60    // Restoring that requires either a config-aware wrapper at the
61    // Deployment level (in crabllm) or a per-call config struct.
62}
63
64impl<P: Provider> Provider for Retrying<P> {
65    async fn chat_completion(
66        &self,
67        request: &ChatCompletionRequest,
68    ) -> Result<ChatCompletionResponse, Error> {
69        let mut backoff = INITIAL_BACKOFF;
70        let mut last_err = None;
71        for _ in 0..=self.max_retries {
72            let result = if self.timeout.is_zero() {
73                self.inner.chat_completion(request).await
74            } else {
75                match tokio::time::timeout(self.timeout, self.inner.chat_completion(request)).await
76                {
77                    Ok(r) => r,
78                    Err(_) => Err(Error::Timeout),
79                }
80            };
81            match result {
82                Ok(resp) => return Ok(resp),
83                Err(e) if e.is_transient() => {
84                    last_err = Some(e);
85                    tokio::time::sleep(jittered(backoff)).await;
86                    backoff *= 2;
87                }
88                Err(e) => return Err(e),
89            }
90        }
91        Err(last_err.expect("retry loop exited without producing an error"))
92    }
93
94    async fn chat_completion_stream(
95        &self,
96        request: &ChatCompletionRequest,
97    ) -> Result<BoxStream<'static, Result<ChatCompletionChunk, Error>>, Error> {
98        // Streaming does not retry — connection establishment is the only
99        // failure mode we could meaningfully retry, and chunks already
100        // streaming would be lost on retry. Apply the timeout to stream
101        // open only.
102        if self.timeout.is_zero() {
103            self.inner.chat_completion_stream(request).await
104        } else {
105            match tokio::time::timeout(self.timeout, self.inner.chat_completion_stream(request))
106                .await
107            {
108                Ok(r) => r,
109                Err(_) => Err(Error::Timeout),
110            }
111        }
112    }
113
114    async fn embedding(&self, request: &EmbeddingRequest) -> Result<EmbeddingResponse, Error> {
115        self.inner.embedding(request).await
116    }
117
118    async fn image_generation(
119        &self,
120        request: &ImageRequest,
121    ) -> Result<(bytes::Bytes, String), Error> {
122        self.inner.image_generation(request).await
123    }
124
125    async fn audio_speech(
126        &self,
127        request: &AudioSpeechRequest,
128    ) -> Result<(bytes::Bytes, String), Error> {
129        self.inner.audio_speech(request).await
130    }
131
132    async fn audio_transcription(
133        &self,
134        model: &str,
135        fields: &[MultipartField],
136    ) -> Result<(bytes::Bytes, String), Error> {
137        self.inner.audio_transcription(model, fields).await
138    }
139}
140
141/// Full jitter: random duration in [backoff/2, backoff].
142fn jittered(backoff: Duration) -> Duration {
143    let lo = backoff.as_millis() as u64 / 2;
144    let hi = backoff.as_millis() as u64;
145    if lo >= hi {
146        return backoff;
147    }
148    Duration::from_millis(rand::rng().random_range(lo..=hi))
149}