1use std::fmt;
55use std::time::Duration;
56
/// Why a single attempt against one provider failed.
///
/// The variant decides whether a failover run may move on to the next
/// provider: every variant except [`AttemptError::NonRetryable`] is treated as
/// retryable.
///
/// `Eq` is derived alongside `PartialEq` because every payload (`String`,
/// `u16`, `Duration`) has total equality.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum AttemptError {
    /// Transport-level failure, carrying a free-form description.
    Transport(String),
    /// A response with a server-error status. `code` is stored as given; this
    /// type does not check that it lies in the 5xx range.
    Status5xx { code: u16, body: String },
    /// The attempt timed out; the payload is the duration reported with it.
    Timeout(Duration),
    /// A failure that must not be retried against another provider, carrying a
    /// free-form description.
    NonRetryable(String),
}
82
83impl AttemptError {
84 pub fn is_retryable(&self) -> bool {
86 matches!(
87 self,
88 AttemptError::Transport(_) | AttemptError::Status5xx { .. } | AttemptError::Timeout(_)
89 )
90 }
91}
92
93impl fmt::Display for AttemptError {
94 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
95 match self {
96 AttemptError::Transport(msg) => write!(f, "transport: {msg}"),
97 AttemptError::Status5xx { code, body } => write!(f, "http {code}: {body}"),
98 AttemptError::Timeout(d) => write!(f, "timeout after {}ms", d.as_millis()),
99 AttemptError::NonRetryable(msg) => write!(f, "non-retryable: {msg}"),
100 }
101 }
102}
103
/// Outcome of a failover run in which one provider produced a response.
///
/// `R` is whatever type the caller's attempt function returns on success.
#[derive(Debug, Clone, PartialEq)]
pub struct FailoverSuccess<R> {
    /// Name of the provider whose attempt returned `Ok`.
    pub provider: String,
    /// The value that provider's attempt returned.
    pub response: R,
    /// `(provider, error)` pairs for the providers that were tried and failed
    /// before the successful one, in the order they were attempted. Empty when
    /// the first provider succeeded.
    pub prior_errors: Vec<(String, AttemptError)>,
}
113
/// Outcome of a failover run that ended without any provider succeeding.
///
/// `run` produces this when every provider failed, when a non-retryable error
/// stopped the run early, or when the provider list was empty.
#[derive(Debug, Clone, PartialEq)]
pub struct FailoverExhausted {
    /// `(provider, error)` pairs for every attempt that was made, in the order
    /// the providers were tried. Empty when no provider was attempted.
    pub attempts: Vec<(String, AttemptError)>,
}
121
122impl fmt::Display for FailoverExhausted {
123 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
124 write!(f, "all providers failed:")?;
125 for (provider, err) in &self.attempts {
126 write!(f, " [{provider}: {err}]")?;
127 }
128 Ok(())
129 }
130}
131
132pub fn run<R, F>(providers: &[&str], mut attempt: F) -> Result<FailoverSuccess<R>, FailoverExhausted>
143where
144 F: FnMut(&str) -> Result<R, AttemptError>,
145{
146 let mut prior: Vec<(String, AttemptError)> = Vec::new();
147
148 for provider in providers {
149 match attempt(provider) {
150 Ok(response) => {
151 return Ok(FailoverSuccess {
152 provider: (*provider).to_string(),
153 response,
154 prior_errors: prior,
155 });
156 }
157 Err(err) => {
158 let retryable = err.is_retryable();
159 prior.push(((*provider).to_string(), err));
160 if !retryable {
161 return Err(FailoverExhausted { attempts: prior });
162 }
163 }
164 }
165 }
166
167 Err(FailoverExhausted { attempts: prior })
168}
169
/// Parses a comma-separated provider list into an ordered, de-duplicated
/// vector of names.
///
/// Each comma-separated segment is trimmed of surrounding whitespace; segments
/// that are empty after trimming are ignored. A name that has already been
/// seen (exact, case-sensitive match) is skipped, so the first occurrence
/// decides its position.
///
/// Returns `None` when no name remains, e.g. for `""` or `" , , "`.
pub fn parse_using_clause(raw: &str) -> Option<Vec<String>> {
    let names = raw
        .split(',')
        .map(str::trim)
        .filter(|segment| !segment.is_empty())
        .fold(Vec::<String>::new(), |mut unique, segment| {
            // Linear scan keeps first-seen order; provider lists are short.
            if unique.iter().all(|seen| seen != segment) {
                unique.push(segment.to_owned());
            }
            unique
        });

    Some(names).filter(|names| !names.is_empty())
}
190
#[cfg(test)]
mod tests {
    use super::*;
    use std::cell::RefCell;

    /// Builds the `Some(vec![..])` value expected from `parse_using_clause`.
    fn parsed(names: &[&str]) -> Option<Vec<String>> {
        Some(names.iter().map(|&name| name.to_owned()).collect())
    }

    // ---- AttemptError::is_retryable -------------------------------------

    #[test]
    fn transport_is_retryable() {
        let err = AttemptError::Transport("dns".into());
        assert!(err.is_retryable());
    }

    #[test]
    fn status_5xx_is_retryable() {
        let err = AttemptError::Status5xx {
            code: 502,
            body: "bad gateway".into(),
        };
        assert!(err.is_retryable());
    }

    #[test]
    fn timeout_is_retryable() {
        let err = AttemptError::Timeout(Duration::from_secs(30));
        assert!(err.is_retryable());
    }

    #[test]
    fn non_retryable_is_not_retryable() {
        let err = AttemptError::NonRetryable("401 unauthorized".into());
        assert!(!err.is_retryable());
    }

    // ---- run: success paths ---------------------------------------------

    #[test]
    fn first_provider_succeeds_no_prior_errors() {
        let chain = ["groq", "openai", "anthropic"];
        let outcome = run(&chain, |name| {
            Ok::<_, AttemptError>(format!("answer from {name}"))
        });
        let success = outcome.expect("should succeed");
        assert_eq!(success.provider, "groq");
        assert_eq!(success.response, "answer from groq");
        assert!(success.prior_errors.is_empty());
    }

    #[test]
    fn second_provider_succeeds_after_5xx() {
        let chain = ["groq", "openai"];
        // The closure is consumed by `run`, so a plain counter can be read afterwards.
        let mut calls = 0u32;
        let outcome = run(&chain, |name| {
            calls += 1;
            match name {
                "groq" => Err(AttemptError::Status5xx {
                    code: 502,
                    body: "bad gateway".into(),
                }),
                _ => Ok(format!("answer from {name}")),
            }
        });
        let success = outcome.expect("should succeed");
        assert_eq!(success.provider, "openai");
        assert_eq!(success.response, "answer from openai");
        assert_eq!(calls, 2);
        assert_eq!(success.prior_errors.len(), 1);
        assert_eq!(success.prior_errors[0].0, "groq");
    }

    #[test]
    fn third_provider_succeeds_after_transport_and_timeout() {
        let chain = ["groq", "openai", "anthropic"];
        let outcome = run(&chain, |name| {
            if name == "groq" {
                Err(AttemptError::Transport("connection reset".into()))
            } else if name == "openai" {
                Err(AttemptError::Timeout(Duration::from_secs(30)))
            } else {
                Ok(format!("answer from {name}"))
            }
        });
        let success = outcome.expect("should succeed");
        assert_eq!(success.provider, "anthropic");
        let trail = &success.prior_errors;
        assert_eq!(trail.len(), 2);
        assert!(matches!(trail[0].1, AttemptError::Transport(_)));
        assert!(matches!(trail[1].1, AttemptError::Timeout(_)));
    }

    // ---- run: failure paths ---------------------------------------------

    #[test]
    fn all_retryable_failures_exhausts_with_full_attempt_list() {
        let chain = ["groq", "openai", "anthropic"];
        let outcome = run::<String, _>(&chain, |name| {
            Err(AttemptError::Status5xx {
                code: 503,
                body: format!("{name} unavailable"),
            })
        });
        let exhausted = outcome.expect_err("should exhaust");
        // Comparing the whole name list checks both the count and the order.
        let tried: Vec<&str> = exhausted
            .attempts
            .iter()
            .map(|(name, _)| name.as_str())
            .collect();
        assert_eq!(tried, ["groq", "openai", "anthropic"]);
    }

    #[test]
    fn non_retryable_short_circuits_without_trying_remaining() {
        let chain = ["groq", "openai", "anthropic"];
        let mut calls = 0u32;
        let outcome = run::<String, _>(&chain, |name| {
            calls += 1;
            match name {
                "groq" => Err(AttemptError::NonRetryable("401 unauthorized".into())),
                _ => panic!("must not call sibling providers after non-retryable"),
            }
        });
        let exhausted = outcome.expect_err("should short-circuit");
        assert_eq!(calls, 1);
        assert_eq!(exhausted.attempts.len(), 1);
        let (provider, err) = &exhausted.attempts[0];
        assert_eq!(provider, "groq");
        assert!(matches!(err, AttemptError::NonRetryable(_)));
    }

    #[test]
    fn non_retryable_after_retryable_preserves_full_trail() {
        let chain = ["groq", "openai", "anthropic"];
        let order = RefCell::new(Vec::<String>::new());
        let outcome = run::<String, _>(&chain, |name| {
            order.borrow_mut().push(name.to_owned());
            if name == "groq" {
                Err(AttemptError::Status5xx {
                    code: 502,
                    body: "bad".into(),
                })
            } else if name == "openai" {
                Err(AttemptError::NonRetryable("401".into()))
            } else {
                panic!("anthropic must not be called")
            }
        });
        let exhausted = outcome.expect_err("should fail");
        assert_eq!(*order.borrow(), vec!["groq", "openai"]);
        assert_eq!(exhausted.attempts.len(), 2);
    }

    #[test]
    fn empty_provider_list_returns_empty_exhausted() {
        let chain: [&str; 0] = [];
        let outcome = run::<String, _>(&chain, |_| panic!("must not be called"));
        let exhausted = outcome.expect_err("empty list yields exhausted");
        assert!(exhausted.attempts.is_empty());
    }

    #[test]
    fn attempt_fn_is_invoked_with_identical_inputs() {
        // Stand-in for a request payload the closure captures and reuses per attempt.
        #[derive(Clone, PartialEq, Debug)]
        struct Req {
            seed: u64,
            temperature: f32,
            strict: bool,
        }
        let template = Req {
            seed: 42,
            temperature: 0.0,
            strict: true,
        };
        let chain = ["groq", "openai"];
        let observed = RefCell::new(Vec::<Req>::new());
        let _ = run::<(), _>(&chain, |_| {
            observed.borrow_mut().push(template.clone());
            Err(AttemptError::Transport("retry".into()))
        });
        let observed = observed.into_inner();
        assert_eq!(observed.len(), 2);
        assert_eq!(observed[0], observed[1]);
    }

    // ---- parse_using_clause ---------------------------------------------

    #[test]
    fn parse_using_simple() {
        assert_eq!(parse_using_clause("groq,openai"), parsed(&["groq", "openai"]));
    }

    #[test]
    fn parse_using_trims_whitespace() {
        assert_eq!(
            parse_using_clause(" groq , openai , anthropic "),
            parsed(&["groq", "openai", "anthropic"])
        );
    }

    #[test]
    fn parse_using_drops_empty_segments() {
        assert_eq!(
            parse_using_clause("groq,,openai,"),
            parsed(&["groq", "openai"])
        );
    }

    #[test]
    fn parse_using_dedupes_preserving_first_occurrence() {
        assert_eq!(
            parse_using_clause("groq,openai,groq"),
            parsed(&["groq", "openai"])
        );
    }

    #[test]
    fn parse_using_empty_returns_none() {
        for blank in ["", " , , "] {
            assert_eq!(parse_using_clause(blank), None);
        }
    }

    #[test]
    fn parse_using_single_provider() {
        assert_eq!(parse_using_clause("groq"), parsed(&["groq"]));
    }

    // ---- Display --------------------------------------------------------

    #[test]
    fn exhausted_display_lists_each_attempt() {
        let exhausted = FailoverExhausted {
            attempts: vec![
                ("groq".into(), AttemptError::Transport("dns".into())),
                (
                    "openai".into(),
                    AttemptError::Status5xx {
                        code: 502,
                        body: "bad".into(),
                    },
                ),
            ],
        };
        let rendered = exhausted.to_string();
        for needle in ["groq", "openai", "502"] {
            assert!(rendered.contains(needle));
        }
    }
}