crabtalk_daemon/provider.rs
1//! `Retrying<P>` — a `Provider` wrapper that adds exponential-backoff retry
2//! and per-call timeout on top of any inner provider.
3//!
4//! This restores the retry/timeout semantics that lived in the old
5//! `crates/model::Provider` wrapper before the trait migration. It is a
6//! deployment-layer concern owned by the daemon — `wcore::Model<P>` does
7//! not retry, since not every consumer (e.g. an in-process MLX provider)
8//! wants the same retry policy.
9//!
10//! Per-provider retry config (`max_retries` / `timeout` on individual
11//! `ProviderDef` entries) is not threaded through yet; the wrapper applies
12//! a single set of defaults to every dispatch. Restoring per-provider
13//! config is a follow-up — see TODO below.
14
15use crabllm_core::{
16 AudioSpeechRequest, BoxStream, ChatCompletionChunk, ChatCompletionRequest,
17 ChatCompletionResponse, EmbeddingRequest, EmbeddingResponse, Error, ImageRequest,
18 MultipartField, Provider,
19};
20use rand::Rng;
21use std::time::Duration;
22
/// Number of retries after the first attempt (so up to 3 calls in total).
/// Matches the old `crates/model::Provider` default.
const DEFAULT_MAX_RETRIES: u32 = 2;
/// Per-attempt deadline. Matches the old `crates/model::Provider` default.
const DEFAULT_TIMEOUT: Duration = Duration::from_secs(30);
/// Backoff before the first retry; it doubles before each later retry.
const INITIAL_BACKOFF: Duration = Duration::from_millis(100);
27
28/// A `Provider` wrapper that retries transient failures with exponential
29/// backoff and full jitter, and bounds each attempt with a per-call timeout.
30///
31/// **Scope:** the retry policy applies to `chat_completion` only. Streaming
32/// (`chat_completion_stream`) skips retry — the connection is already
33/// established and clients consuming chunk-by-chunk handle their own
34/// resumption — but still bounds the connection-establishment phase with
35/// the same timeout. The non-chat methods (`embedding`, `image_generation`,
36/// `audio_speech`, `audio_transcription`) are bare pass-throughs without
37/// retry or timeout, because the daemon's current protocol doesn't expose
38/// these endpoints. If a future daemon feature needs them, extend this
39/// wrapper's scope at that point.
40#[derive(Debug, Clone)]
41pub struct Retrying<P: Provider> {
42 inner: P,
43 max_retries: u32,
44 timeout: Duration,
45}
46
47impl<P: Provider> Retrying<P> {
48 /// Wrap a provider with the default retry policy
49 /// (2 retries, 30s timeout, 100ms initial backoff).
50 pub fn new(inner: P) -> Self {
51 Self {
52 inner,
53 max_retries: DEFAULT_MAX_RETRIES,
54 timeout: DEFAULT_TIMEOUT,
55 }
56 }
57
58 // TODO: per-provider retry config. The old crates/model wrapper read
59 // `ProviderDef.max_retries` and `ProviderDef.timeout` per provider.
60 // Restoring that requires either a config-aware wrapper at the
61 // Deployment level (in crabllm) or a per-call config struct.
62}
63
64impl<P: Provider> Provider for Retrying<P> {
65 async fn chat_completion(
66 &self,
67 request: &ChatCompletionRequest,
68 ) -> Result<ChatCompletionResponse, Error> {
69 let mut backoff = INITIAL_BACKOFF;
70 let mut last_err = None;
71 for _ in 0..=self.max_retries {
72 let result = if self.timeout.is_zero() {
73 self.inner.chat_completion(request).await
74 } else {
75 match tokio::time::timeout(self.timeout, self.inner.chat_completion(request)).await
76 {
77 Ok(r) => r,
78 Err(_) => Err(Error::Timeout),
79 }
80 };
81 match result {
82 Ok(resp) => return Ok(resp),
83 Err(e) if e.is_transient() => {
84 last_err = Some(e);
85 tokio::time::sleep(jittered(backoff)).await;
86 backoff *= 2;
87 }
88 Err(e) => return Err(e),
89 }
90 }
91 Err(last_err.expect("retry loop exited without producing an error"))
92 }
93
94 async fn chat_completion_stream(
95 &self,
96 request: &ChatCompletionRequest,
97 ) -> Result<BoxStream<'static, Result<ChatCompletionChunk, Error>>, Error> {
98 // Streaming does not retry — connection establishment is the only
99 // failure mode we could meaningfully retry, and chunks already
100 // streaming would be lost on retry. Apply the timeout to stream
101 // open only.
102 if self.timeout.is_zero() {
103 self.inner.chat_completion_stream(request).await
104 } else {
105 match tokio::time::timeout(self.timeout, self.inner.chat_completion_stream(request))
106 .await
107 {
108 Ok(r) => r,
109 Err(_) => Err(Error::Timeout),
110 }
111 }
112 }
113
114 async fn embedding(&self, request: &EmbeddingRequest) -> Result<EmbeddingResponse, Error> {
115 self.inner.embedding(request).await
116 }
117
118 async fn image_generation(
119 &self,
120 request: &ImageRequest,
121 ) -> Result<(bytes::Bytes, String), Error> {
122 self.inner.image_generation(request).await
123 }
124
125 async fn audio_speech(
126 &self,
127 request: &AudioSpeechRequest,
128 ) -> Result<(bytes::Bytes, String), Error> {
129 self.inner.audio_speech(request).await
130 }
131
132 async fn audio_transcription(
133 &self,
134 model: &str,
135 fields: &[MultipartField],
136 ) -> Result<(bytes::Bytes, String), Error> {
137 self.inner.audio_transcription(model, fields).await
138 }
139}
140
141/// Full jitter: random duration in [backoff/2, backoff].
142fn jittered(backoff: Duration) -> Duration {
143 let lo = backoff.as_millis() as u64 / 2;
144 let hi = backoff.as_millis() as u64;
145 if lo >= hi {
146 return backoff;
147 }
148 Duration::from_millis(rand::rng().random_range(lo..=hi))
149}