1use std::fmt;
55use std::time::Duration;
56
/// Failure of a single attempt against one provider.
///
/// [`AttemptError::is_retryable`] decides what [`run`] does next: `Transport`,
/// `Status5xx` and `Timeout` let it move on to the next provider, while
/// `NonRetryable` stops the failover loop immediately.
#[derive(Debug, Clone, PartialEq)]
pub enum AttemptError {
    /// Transport-level failure; the payload is a free-form message
    /// (rendered as `transport: <msg>`).
    Transport(String),
    /// The provider answered with an HTTP status code and body, rendered as
    /// `http <code>: <body>`.
    /// NOTE(review): nothing in this type checks that `code` is really in
    /// 500..=599 — confirm callers only construct it for 5xx responses.
    Status5xx { code: u16, body: String },
    /// The attempt timed out; `Display` renders the duration in whole
    /// milliseconds (`timeout after <ms>ms`).
    Timeout(Duration),
    /// A failure for which `is_retryable` returns `false`, so `run` records it
    /// and does not try any remaining provider.
    NonRetryable(String),
}
82
83impl AttemptError {
84 pub fn is_retryable(&self) -> bool {
86 matches!(
87 self,
88 AttemptError::Transport(_) | AttemptError::Status5xx { .. } | AttemptError::Timeout(_)
89 )
90 }
91}
92
93impl fmt::Display for AttemptError {
94 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
95 match self {
96 AttemptError::Transport(msg) => write!(f, "transport: {msg}"),
97 AttemptError::Status5xx { code, body } => write!(f, "http {code}: {body}"),
98 AttemptError::Timeout(d) => write!(f, "timeout after {}ms", d.as_millis()),
99 AttemptError::NonRetryable(msg) => write!(f, "non-retryable: {msg}"),
100 }
101 }
102}
103
/// Successful outcome of [`run`]: some provider's attempt returned `Ok`.
#[derive(Debug, Clone, PartialEq)]
pub struct FailoverSuccess<R> {
    /// Name of the provider whose attempt succeeded.
    pub provider: String,
    /// The value returned by that provider's attempt.
    pub response: R,
    /// `(provider, error)` for every provider tried before the successful one,
    /// in the order they were tried; empty when the first provider succeeded.
    /// All of these are retryable errors, because a non-retryable error makes
    /// `run` return an error instead of continuing.
    pub prior_errors: Vec<(String, AttemptError)>,
}
113
114#[derive(Debug, Clone, PartialEq)]
118pub struct FailoverExhausted {
119 pub attempts: Vec<(String, AttemptError)>,
120}
121
122impl fmt::Display for FailoverExhausted {
123 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
124 write!(f, "all providers failed:")?;
125 for (provider, err) in &self.attempts {
126 write!(f, " [{provider}: {err}]")?;
127 }
128 Ok(())
129 }
130}
131
132pub fn run<R, F>(
143 providers: &[&str],
144 mut attempt: F,
145) -> Result<FailoverSuccess<R>, FailoverExhausted>
146where
147 F: FnMut(&str) -> Result<R, AttemptError>,
148{
149 let mut prior: Vec<(String, AttemptError)> = Vec::new();
150
151 for provider in providers {
152 match attempt(provider) {
153 Ok(response) => {
154 return Ok(FailoverSuccess {
155 provider: (*provider).to_string(),
156 response,
157 prior_errors: prior,
158 });
159 }
160 Err(err) => {
161 let retryable = err.is_retryable();
162 prior.push(((*provider).to_string(), err));
163 if !retryable {
164 return Err(FailoverExhausted { attempts: prior });
165 }
166 }
167 }
168 }
169
170 Err(FailoverExhausted { attempts: prior })
171}
172
/// Parses a comma-separated provider list such as `" groq , openai "`.
///
/// Each segment is trimmed and empty segments are skipped. Repeated names are
/// dropped, keeping the position of their first appearance. Comparison is
/// exact, so names that differ only in case are kept distinct.
///
/// Returns `None` when no non-empty name remains.
pub fn parse_using_clause(raw: &str) -> Option<Vec<String>> {
    let mut unique: Vec<String> = Vec::new();
    raw.split(',')
        .map(str::trim)
        .filter(|candidate| !candidate.is_empty())
        .for_each(|candidate| {
            // Linear scan is fine: provider lists are expected to be short.
            let already_seen = unique.iter().any(|kept| kept == candidate);
            if !already_seen {
                unique.push(candidate.to_owned());
            }
        });
    Some(unique).filter(|names| !names.is_empty())
}
197
// Unit tests for error classification, the failover loop, USING-clause
// parsing, and the exhausted-error rendering.
#[cfg(test)]
mod tests {
    use super::*;
    // Used by the attempt closures below to record calls so the tests can
    // inspect them after `run` returns.
    use std::cell::RefCell;

    // --- AttemptError::is_retryable classification ---

    #[test]
    fn transport_is_retryable() {
        assert!(AttemptError::Transport("dns".into()).is_retryable());
    }

    #[test]
    fn status_5xx_is_retryable() {
        assert!(AttemptError::Status5xx {
            code: 502,
            body: "bad gateway".into()
        }
        .is_retryable());
    }

    #[test]
    fn timeout_is_retryable() {
        assert!(AttemptError::Timeout(Duration::from_secs(30)).is_retryable());
    }

    #[test]
    fn non_retryable_is_not_retryable() {
        assert!(!AttemptError::NonRetryable("401 unauthorized".into()).is_retryable());
    }

    // --- run: success paths ---

    // The first provider answers, so nothing is recorded as a prior error.
    #[test]
    fn first_provider_succeeds_no_prior_errors() {
        let providers = ["groq", "openai", "anthropic"];
        let result = run(&providers, |p| {
            Ok::<_, AttemptError>(format!("answer from {p}"))
        });
        let ok = result.expect("should succeed");
        assert_eq!(ok.provider, "groq");
        assert_eq!(ok.response, "answer from groq");
        assert!(ok.prior_errors.is_empty());
    }

    // A retryable 5xx on the first provider falls through to the second;
    // exactly two calls are made and the 5xx is kept in `prior_errors`.
    #[test]
    fn second_provider_succeeds_after_5xx() {
        let providers = ["groq", "openai"];
        let calls = RefCell::new(0u32);
        let result = run(&providers, |p| {
            *calls.borrow_mut() += 1;
            if p == "groq" {
                Err(AttemptError::Status5xx {
                    code: 502,
                    body: "bad gateway".into(),
                })
            } else {
                Ok(format!("answer from {p}"))
            }
        });
        let ok = result.expect("should succeed");
        assert_eq!(ok.provider, "openai");
        assert_eq!(ok.response, "answer from openai");
        assert_eq!(*calls.borrow(), 2);
        assert_eq!(ok.prior_errors.len(), 1);
        assert_eq!(ok.prior_errors[0].0, "groq");
    }

    // Two different retryable kinds in a row are both recorded, in order.
    #[test]
    fn third_provider_succeeds_after_transport_and_timeout() {
        let providers = ["groq", "openai", "anthropic"];
        let result = run(&providers, |p| match p {
            "groq" => Err(AttemptError::Transport("connection reset".into())),
            "openai" => Err(AttemptError::Timeout(Duration::from_secs(30))),
            _ => Ok(format!("answer from {p}")),
        });
        let ok = result.expect("should succeed");
        assert_eq!(ok.provider, "anthropic");
        assert_eq!(ok.prior_errors.len(), 2);
        assert!(matches!(ok.prior_errors[0].1, AttemptError::Transport(_)));
        assert!(matches!(ok.prior_errors[1].1, AttemptError::Timeout(_)));
    }

    // --- run: failure paths ---

    // Every provider fails retryably: all are tried and listed in order.
    #[test]
    fn all_retryable_failures_exhausts_with_full_attempt_list() {
        let providers = ["groq", "openai", "anthropic"];
        let result = run::<String, _>(&providers, |p| {
            Err(AttemptError::Status5xx {
                code: 503,
                body: format!("{p} unavailable"),
            })
        });
        let exhausted = result.expect_err("should exhaust");
        assert_eq!(exhausted.attempts.len(), 3);
        assert_eq!(exhausted.attempts[0].0, "groq");
        assert_eq!(exhausted.attempts[1].0, "openai");
        assert_eq!(exhausted.attempts[2].0, "anthropic");
    }

    // A non-retryable error on the first provider stops the loop; the panic in
    // the else-branch proves no sibling provider is called.
    #[test]
    fn non_retryable_short_circuits_without_trying_remaining() {
        let providers = ["groq", "openai", "anthropic"];
        let calls = RefCell::new(0u32);
        let result = run::<String, _>(&providers, |p| {
            *calls.borrow_mut() += 1;
            if p == "groq" {
                Err(AttemptError::NonRetryable("401 unauthorized".into()))
            } else {
                panic!("must not call sibling providers after non-retryable")
            }
        });
        let exhausted = result.expect_err("should short-circuit");
        assert_eq!(*calls.borrow(), 1);
        assert_eq!(exhausted.attempts.len(), 1);
        assert_eq!(exhausted.attempts[0].0, "groq");
        assert!(matches!(
            exhausted.attempts[0].1,
            AttemptError::NonRetryable(_)
        ));
    }

    // A retryable failure followed by a non-retryable one keeps both entries
    // and never reaches the third provider.
    #[test]
    fn non_retryable_after_retryable_preserves_full_trail() {
        let providers = ["groq", "openai", "anthropic"];
        let calls = RefCell::new(Vec::<String>::new());
        let result = run::<String, _>(&providers, |p| {
            calls.borrow_mut().push(p.to_string());
            match p {
                "groq" => Err(AttemptError::Status5xx {
                    code: 502,
                    body: "bad".into(),
                }),
                "openai" => Err(AttemptError::NonRetryable("401".into())),
                _ => panic!("anthropic must not be called"),
            }
        });
        let exhausted = result.expect_err("should fail");
        assert_eq!(*calls.borrow(), vec!["groq", "openai"]);
        assert_eq!(exhausted.attempts.len(), 2);
    }

    // An empty provider slice yields an error with no attempts and never
    // invokes the closure.
    #[test]
    fn empty_provider_list_returns_empty_exhausted() {
        let providers: [&str; 0] = [];
        let result = run::<String, _>(&providers, |_| panic!("must not be called"));
        let exhausted = result.expect_err("empty list yields exhausted");
        assert!(exhausted.attempts.is_empty());
    }

    // NOTE(review): the closure ignores the provider and always pushes a clone
    // of the same `req`, so `seen[0] == seen[1]` holds by construction. What
    // this test effectively verifies is that both providers were called once.
    #[test]
    fn attempt_fn_is_invoked_with_identical_inputs() {
        #[derive(Clone, PartialEq, Debug)]
        struct Req {
            seed: u64,
            temperature: f32,
            strict: bool,
        }
        let req = Req {
            seed: 42,
            temperature: 0.0,
            strict: true,
        };
        let providers = ["groq", "openai"];
        let seen = RefCell::new(Vec::<Req>::new());
        let _ = run::<(), _>(&providers, |_| {
            seen.borrow_mut().push(req.clone());
            Err(AttemptError::Transport("retry".into()))
        });
        let seen = seen.borrow();
        assert_eq!(seen.len(), 2);
        assert_eq!(seen[0], seen[1]);
    }

    // --- parse_using_clause ---

    #[test]
    fn parse_using_simple() {
        assert_eq!(
            parse_using_clause("groq,openai"),
            Some(vec!["groq".into(), "openai".into()])
        );
    }

    #[test]
    fn parse_using_trims_whitespace() {
        assert_eq!(
            parse_using_clause(" groq , openai , anthropic "),
            Some(vec!["groq".into(), "openai".into(), "anthropic".into()])
        );
    }

    #[test]
    fn parse_using_drops_empty_segments() {
        assert_eq!(
            parse_using_clause("groq,,openai,"),
            Some(vec!["groq".into(), "openai".into()])
        );
    }

    #[test]
    fn parse_using_dedupes_preserving_first_occurrence() {
        assert_eq!(
            parse_using_clause("groq,openai,groq"),
            Some(vec!["groq".into(), "openai".into()])
        );
    }

    // Both an empty string and a string of only separators/whitespace yield None.
    #[test]
    fn parse_using_empty_returns_none() {
        assert_eq!(parse_using_clause(""), None);
        assert_eq!(parse_using_clause(" , , "), None);
    }

    #[test]
    fn parse_using_single_provider() {
        assert_eq!(parse_using_clause("groq"), Some(vec!["groq".into()]));
    }

    // --- FailoverExhausted Display ---

    // Only checks that each provider name and the status code appear; the
    // exact layout is not pinned here.
    #[test]
    fn exhausted_display_lists_each_attempt() {
        let exhausted = FailoverExhausted {
            attempts: vec![
                ("groq".into(), AttemptError::Transport("dns".into())),
                (
                    "openai".into(),
                    AttemptError::Status5xx {
                        code: 502,
                        body: "bad".into(),
                    },
                ),
            ],
        };
        let s = format!("{exhausted}");
        assert!(s.contains("groq"));
        assert!(s.contains("openai"));
        assert!(s.contains("502"));
    }
}
451}