use std::fmt;
use std::time::Duration;
/// Why a single provider attempt failed.
///
/// The variant decides whether [`run`] moves on to the next provider; see
/// [`AttemptError::is_retryable`]. Every payload type is `Eq`, so the enum
/// derives `Eq` as well as `PartialEq`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum AttemptError {
    /// Failure reported as a transport problem; the payload is a free-form
    /// description (rendered as `transport: {msg}`).
    Transport(String),
    /// HTTP error response with its status `code` and response `body`.
    /// NOTE(review): by its name `code` is expected to be 5xx, but the type
    /// does not enforce that.
    Status5xx { code: u16, body: String },
    /// The attempt timed out; the payload is the duration reported in the
    /// message (rendered in milliseconds).
    Timeout(Duration),
    /// A failure that must not be retried on another provider; [`run`] stops
    /// at the first one of these.
    NonRetryable(String),
}
impl AttemptError {
    /// Reports whether a failed attempt should fall through to the next provider.
    ///
    /// Transport failures, 5xx responses and timeouts are retryable; only
    /// [`AttemptError::NonRetryable`] is not. The match is exhaustive on
    /// purpose so that adding a variant forces a decision here.
    pub fn is_retryable(&self) -> bool {
        match self {
            AttemptError::NonRetryable(_) => false,
            AttemptError::Transport(_)
            | AttemptError::Status5xx { .. }
            | AttemptError::Timeout(_) => true,
        }
    }
}
impl fmt::Display for AttemptError {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match self {
AttemptError::Transport(msg) => write!(f, "transport: {msg}"),
AttemptError::Status5xx { code, body } => write!(f, "http {code}: {body}"),
AttemptError::Timeout(d) => write!(f, "timeout after {}ms", d.as_millis()),
AttemptError::NonRetryable(msg) => write!(f, "non-retryable: {msg}"),
}
}
}
/// Outcome of [`run`] when some provider produced a response.
#[derive(Debug, Clone, PartialEq)]
pub struct FailoverSuccess<R> {
    /// Name of the provider whose attempt succeeded.
    pub provider: String,
    /// Value returned by that successful attempt.
    pub response: R,
    /// `(provider, error)` for every provider tried before the successful
    /// one, in the order they were tried; empty if the first one succeeded.
    pub prior_errors: Vec<(String, AttemptError)>,
}
/// Outcome of [`run`] when no provider produced a response.
#[derive(Debug, Clone, PartialEq)]
pub struct FailoverExhausted {
    /// `(provider, error)` for every provider that was actually tried, in
    /// order. Empty when [`run`] was given no providers.
    pub attempts: Vec<(String, AttemptError)>,
}
impl fmt::Display for FailoverExhausted {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
write!(f, "all providers failed:")?;
for (provider, err) in &self.attempts {
write!(f, " [{provider}: {err}]")?;
}
Ok(())
}
}
/// Tries each provider in `providers`, in order, until one attempt succeeds.
///
/// `attempt` is called with the provider name. On the first `Ok`, the
/// response is returned together with the errors of every provider tried
/// before it (see [`FailoverSuccess`]).
///
/// # Errors
///
/// Returns [`FailoverExhausted`] carrying every `(provider, error)` pair
/// collected so far when either:
/// - an attempt fails with an error for which [`AttemptError::is_retryable`]
///   is `false` — the remaining providers are not tried; or
/// - every provider failed with a retryable error. An empty `providers`
///   slice lands here with an empty `attempts` list.
pub fn run<R, F>(
    providers: &[&str],
    mut attempt: F,
) -> Result<FailoverSuccess<R>, FailoverExhausted>
where
    F: FnMut(&str) -> Result<R, AttemptError>,
{
    let mut trail = Vec::new();
    for &name in providers {
        let failure = match attempt(name) {
            Ok(response) => {
                return Ok(FailoverSuccess {
                    provider: name.to_owned(),
                    response,
                    prior_errors: trail,
                });
            }
            Err(failure) => failure,
        };
        // Classify before the error is moved into the trail.
        let give_up = !failure.is_retryable();
        trail.push((name.to_owned(), failure));
        if give_up {
            break;
        }
    }
    Err(FailoverExhausted { attempts: trail })
}
/// Parses a comma-separated list of provider names.
///
/// Each segment is trimmed of surrounding whitespace and empty segments are
/// skipped. A repeated name is kept only at its first position; comparison is
/// exact string equality after trimming (so it is case-sensitive).
///
/// Returns `None` when no non-empty name remains (e.g. `""` or `" , , "`).
pub fn parse_using_clause(raw: &str) -> Option<Vec<String>> {
    let mut names: Vec<String> = Vec::new();
    let candidates = raw.split(',').map(str::trim).filter(|s| !s.is_empty());
    for candidate in candidates {
        // Linear scan preserves first-seen order; presumably these lists are short.
        let already_listed = names.iter().any(|known| known == candidate);
        if !already_listed {
            names.push(candidate.to_owned());
        }
    }
    Some(names).filter(|list| !list.is_empty())
}
#[cfg(test)]
mod tests {
    //! Unit tests for error classification, the failover loop and
    //! `USING`-style provider-list parsing.
    use super::*;
    use std::cell::RefCell;

    #[test]
    fn transport_is_retryable() {
        assert!(AttemptError::Transport("dns".into()).is_retryable());
    }

    #[test]
    fn status_5xx_is_retryable() {
        assert!(AttemptError::Status5xx {
            code: 502,
            body: "bad gateway".into()
        }
        .is_retryable());
    }

    #[test]
    fn timeout_is_retryable() {
        assert!(AttemptError::Timeout(Duration::from_secs(30)).is_retryable());
    }

    #[test]
    fn non_retryable_is_not_retryable() {
        assert!(!AttemptError::NonRetryable("401 unauthorized".into()).is_retryable());
    }

    #[test]
    fn attempt_error_display_is_exact() {
        assert_eq!(AttemptError::Transport("dns".into()).to_string(), "transport: dns");
        assert_eq!(
            AttemptError::Status5xx {
                code: 503,
                body: "down".into()
            }
            .to_string(),
            "http 503: down"
        );
        assert_eq!(
            AttemptError::Timeout(Duration::from_millis(1500)).to_string(),
            "timeout after 1500ms"
        );
        assert_eq!(
            AttemptError::NonRetryable("401".into()).to_string(),
            "non-retryable: 401"
        );
    }

    #[test]
    fn first_provider_succeeds_no_prior_errors() {
        let providers = ["groq", "openai", "anthropic"];
        let calls = RefCell::new(Vec::<String>::new());
        let result = run(&providers, |p| {
            calls.borrow_mut().push(p.to_string());
            Ok::<_, AttemptError>(format!("answer from {p}"))
        });
        let ok = result.expect("should succeed");
        assert_eq!(ok.provider, "groq");
        assert_eq!(ok.response, "answer from groq");
        assert!(ok.prior_errors.is_empty());
        // Later providers must not be contacted once one succeeds.
        assert_eq!(*calls.borrow(), vec!["groq"]);
    }

    #[test]
    fn second_provider_succeeds_after_5xx() {
        let providers = ["groq", "openai"];
        let calls = RefCell::new(0u32);
        let result = run(&providers, |p| {
            *calls.borrow_mut() += 1;
            if p == "groq" {
                Err(AttemptError::Status5xx {
                    code: 502,
                    body: "bad gateway".into(),
                })
            } else {
                Ok(format!("answer from {p}"))
            }
        });
        let ok = result.expect("should succeed");
        assert_eq!(ok.provider, "openai");
        assert_eq!(ok.response, "answer from openai");
        assert_eq!(*calls.borrow(), 2);
        assert_eq!(ok.prior_errors.len(), 1);
        assert_eq!(ok.prior_errors[0].0, "groq");
        assert!(matches!(
            ok.prior_errors[0].1,
            AttemptError::Status5xx { code: 502, .. }
        ));
    }

    #[test]
    fn third_provider_succeeds_after_transport_and_timeout() {
        let providers = ["groq", "openai", "anthropic"];
        let result = run(&providers, |p| match p {
            "groq" => Err(AttemptError::Transport("connection reset".into())),
            "openai" => Err(AttemptError::Timeout(Duration::from_secs(30))),
            _ => Ok(format!("answer from {p}")),
        });
        let ok = result.expect("should succeed");
        assert_eq!(ok.provider, "anthropic");
        assert_eq!(ok.prior_errors.len(), 2);
        assert_eq!(ok.prior_errors[0].0, "groq");
        assert_eq!(ok.prior_errors[1].0, "openai");
        assert!(matches!(ok.prior_errors[0].1, AttemptError::Transport(_)));
        assert!(matches!(ok.prior_errors[1].1, AttemptError::Timeout(_)));
    }

    #[test]
    fn all_retryable_failures_exhausts_with_full_attempt_list() {
        let providers = ["groq", "openai", "anthropic"];
        let result = run::<String, _>(&providers, |p| {
            Err(AttemptError::Status5xx {
                code: 503,
                body: format!("{p} unavailable"),
            })
        });
        let exhausted = result.expect_err("should exhaust");
        assert_eq!(exhausted.attempts.len(), 3);
        assert_eq!(exhausted.attempts[0].0, "groq");
        assert_eq!(exhausted.attempts[1].0, "openai");
        assert_eq!(exhausted.attempts[2].0, "anthropic");
        // Each recorded error must belong to the provider it is paired with.
        assert_eq!(
            exhausted.attempts[2].1,
            AttemptError::Status5xx {
                code: 503,
                body: "anthropic unavailable".into(),
            }
        );
    }

    #[test]
    fn non_retryable_short_circuits_without_trying_remaining() {
        let providers = ["groq", "openai", "anthropic"];
        let calls = RefCell::new(0u32);
        let result = run::<String, _>(&providers, |p| {
            *calls.borrow_mut() += 1;
            if p == "groq" {
                Err(AttemptError::NonRetryable("401 unauthorized".into()))
            } else {
                panic!("must not call sibling providers after non-retryable")
            }
        });
        let exhausted = result.expect_err("should short-circuit");
        assert_eq!(*calls.borrow(), 1);
        assert_eq!(exhausted.attempts.len(), 1);
        assert_eq!(exhausted.attempts[0].0, "groq");
        assert!(matches!(
            exhausted.attempts[0].1,
            AttemptError::NonRetryable(_)
        ));
    }

    #[test]
    fn non_retryable_after_retryable_preserves_full_trail() {
        let providers = ["groq", "openai", "anthropic"];
        let calls = RefCell::new(Vec::<String>::new());
        let result = run::<String, _>(&providers, |p| {
            calls.borrow_mut().push(p.to_string());
            match p {
                "groq" => Err(AttemptError::Status5xx {
                    code: 502,
                    body: "bad".into(),
                }),
                "openai" => Err(AttemptError::NonRetryable("401".into())),
                _ => panic!("anthropic must not be called"),
            }
        });
        let exhausted = result.expect_err("should fail");
        assert_eq!(*calls.borrow(), vec!["groq", "openai"]);
        assert_eq!(exhausted.attempts.len(), 2);
        assert_eq!(exhausted.attempts[0].0, "groq");
        assert!(matches!(
            exhausted.attempts[0].1,
            AttemptError::Status5xx { code: 502, .. }
        ));
        assert_eq!(exhausted.attempts[1].0, "openai");
        assert!(matches!(
            exhausted.attempts[1].1,
            AttemptError::NonRetryable(_)
        ));
    }

    #[test]
    fn empty_provider_list_returns_empty_exhausted() {
        let providers: [&str; 0] = [];
        let result = run::<String, _>(&providers, |_| panic!("must not be called"));
        let exhausted = result.expect_err("empty list yields exhausted");
        assert!(exhausted.attempts.is_empty());
    }

    /// `run` forwards only the provider name, so request identity across
    /// attempts comes from the closure capturing one request. This checks that
    /// each provider is called exactly once, in order, and that the captured
    /// request is unchanged on every call.
    #[test]
    fn attempt_fn_is_invoked_once_per_provider_with_same_request() {
        #[derive(Clone, PartialEq, Debug)]
        struct Req {
            seed: u64,
            temperature: f32,
            strict: bool,
        }
        let req = Req {
            seed: 42,
            temperature: 0.0,
            strict: true,
        };
        let providers = ["groq", "openai"];
        let seen = RefCell::new(Vec::<(String, Req)>::new());
        let result = run::<(), _>(&providers, |p| {
            seen.borrow_mut().push((p.to_string(), req.clone()));
            Err(AttemptError::Transport("retry".into()))
        });
        assert!(result.is_err());
        let seen = seen.borrow();
        let names: Vec<&str> = seen.iter().map(|(p, _)| p.as_str()).collect();
        assert_eq!(names, providers);
        assert!(seen.iter().all(|(_, r)| *r == req));
    }

    #[test]
    fn parse_using_simple() {
        assert_eq!(
            parse_using_clause("groq,openai"),
            Some(vec!["groq".into(), "openai".into()])
        );
    }

    #[test]
    fn parse_using_trims_whitespace() {
        assert_eq!(
            parse_using_clause(" groq , openai , anthropic "),
            Some(vec!["groq".into(), "openai".into(), "anthropic".into()])
        );
    }

    #[test]
    fn parse_using_drops_empty_segments() {
        assert_eq!(
            parse_using_clause("groq,,openai,"),
            Some(vec!["groq".into(), "openai".into()])
        );
    }

    #[test]
    fn parse_using_dedupes_preserving_first_occurrence() {
        assert_eq!(
            parse_using_clause("groq,openai,groq"),
            Some(vec!["groq".into(), "openai".into()])
        );
    }

    #[test]
    fn parse_using_dedupes_after_trimming() {
        assert_eq!(
            parse_using_clause(" groq ,groq,  openai"),
            Some(vec!["groq".into(), "openai".into()])
        );
    }

    #[test]
    fn parse_using_empty_returns_none() {
        assert_eq!(parse_using_clause(""), None);
        assert_eq!(parse_using_clause(" , , "), None);
    }

    #[test]
    fn parse_using_single_provider() {
        assert_eq!(parse_using_clause("groq"), Some(vec!["groq".into()]));
    }

    #[test]
    fn exhausted_display_lists_each_attempt() {
        let exhausted = FailoverExhausted {
            attempts: vec![
                ("groq".into(), AttemptError::Transport("dns".into())),
                (
                    "openai".into(),
                    AttemptError::Status5xx {
                        code: 502,
                        body: "bad".into(),
                    },
                ),
            ],
        };
        let s = format!("{exhausted}");
        assert_eq!(
            s,
            "all providers failed: [groq: transport: dns] [openai: http 502: bad]"
        );
    }
}