use std::sync::Arc;
use std::sync::atomic::{AtomicUsize, Ordering};
use crate::provider_api::{LlmError, LlmProvider, LlmRequest, LlmResponse};
/// Tries `operation` on each of `candidates` in turn, beginning at `start_index` and
/// wrapping around to the front, until one succeeds or fails with a non-retryable error.
///
/// Each candidate is attempted at most once per call. `start_index` may be any value;
/// it is reduced modulo `candidates.len()` before use.
///
/// Returns the outcome paired with the index (in `0..candidates.len()`) it refers to:
/// - the first `Ok`, with the index of the candidate that produced it;
/// - the first error for which `is_retryable` returns `false`, with that candidate's index;
/// - if every attempt failed with a retryable error, the error from the last candidate
///   tried, paired with the normalized start index.
///
/// # Panics
///
/// Panics if `candidates` is empty.
pub fn try_with_fallback<T, R, E>(
    candidates: &[T],
    start_index: usize,
    mut operation: impl FnMut(&T) -> Result<R, E>,
    mut is_retryable: impl FnMut(&E) -> bool,
) -> (Result<R, E>, usize) {
    let len = candidates.len();
    assert!(len > 0, "at least one candidate must exist");
    let start = start_index % len;

    let mut last_error = None;
    // Visit `start..len` then `0..start`: every candidate exactly once in rotation order,
    // without incrementing `start_index` itself (which overflowed for values near usize::MAX).
    for idx in (start..len).chain(0..start) {
        match operation(&candidates[idx]) {
            Ok(result) => return (Ok(result), idx),
            Err(e) => {
                if !is_retryable(&e) {
                    return (Err(e), idx);
                }
                last_error = Some(e);
            }
        }
    }

    let err = last_error.expect("the loop visits at least one candidate because len > 0");
    (Err(err), start)
}
/// An [`LlmProvider`] that delegates to a list of candidate providers, moving on to the
/// next candidate when a call fails with a retryable error.
pub struct FallbackLlmProvider {
/// Providers in failover order; `new` rejects an empty list.
candidates: Vec<Arc<dyn LlmProvider>>,
/// Index into `candidates` of the provider tried first on the next call.
current: AtomicUsize,
}
impl FallbackLlmProvider {
    /// Builds a failover provider over `candidates`, tried in the order given.
    ///
    /// The first candidate starts out as the active one.
    ///
    /// # Panics
    ///
    /// Panics if `candidates` is empty.
    pub fn new(candidates: Vec<Arc<dyn LlmProvider>>) -> Self {
        if candidates.is_empty() {
            panic!("FallbackLlmProvider requires at least one candidate");
        }
        let current = AtomicUsize::new(0);
        Self {
            candidates,
            current,
        }
    }

    /// Number of providers this wrapper can fail over between.
    #[must_use]
    pub fn candidate_count(&self) -> usize {
        let candidates = &self.candidates;
        candidates.len()
    }

    /// Index of the candidate that will be tried first on the next call.
    #[must_use]
    pub fn active_index(&self) -> usize {
        let current = &self.current;
        current.load(Ordering::Relaxed)
    }

    /// One `"name/model"` label per candidate, in candidate order.
    #[must_use]
    pub fn describe_candidates(&self) -> Vec<String> {
        let mut labels = Vec::with_capacity(self.candidates.len());
        for provider in &self.candidates {
            labels.push(format!("{}/{}", provider.name(), provider.model()));
        }
        labels
    }
}
impl LlmProvider for FallbackLlmProvider {
fn name(&self) -> &'static str {
self.candidates[self.current.load(Ordering::Relaxed)].name()
}
fn model(&self) -> &str {
self.candidates[self.current.load(Ordering::Relaxed)].model()
}
fn complete(&self, request: &LlmRequest) -> Result<LlmResponse, LlmError> {
let start = self.current.load(Ordering::Relaxed);
let (result, used_index) = try_with_fallback(
&self.candidates,
start,
|provider| provider.complete(request),
|e| e.retryable,
);
if result.is_ok() && used_index != start {
self.current.store(used_index, Ordering::Relaxed);
}
result
}
fn health_check(&self) -> Result<(), LlmError> {
let start = self.current.load(Ordering::Relaxed);
let (result, used_index) = try_with_fallback(
&self.candidates,
start,
|provider| provider.health_check(),
|e| e.retryable,
);
if result.is_ok() && used_index != start {
self.current.store(used_index, Ordering::Relaxed);
}
result
}
}
impl std::fmt::Debug for FallbackLlmProvider {
    // Renders the active provider as "name/model", then the label of every candidate.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let active = &self.candidates[self.current.load(Ordering::Relaxed)];
        let mut out = f.debug_struct("FallbackLlmProvider");
        out.field(
            "active",
            &format_args!("{}/{}", active.name(), active.model()),
        );
        out.field("candidates", &self.describe_candidates());
        out.finish()
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn try_with_fallback_returns_first_success() {
        let items = vec![1, 2, 3];
        let (result, idx) = try_with_fallback(
            &items,
            0,
            |&n| if n >= 1 { Ok(n * 10) } else { Err("fail") },
            |_| true,
        );
        assert_eq!(result, Ok(10));
        assert_eq!(idx, 0);
    }

    #[test]
    fn try_with_fallback_skips_retryable_errors() {
        let items = vec![1, 2, 3];
        let (result, idx) = try_with_fallback(
            &items,
            0,
            |&n| if n >= 3 { Ok(n * 10) } else { Err("retryable") },
            |_| true,
        );
        assert_eq!(result, Ok(30));
        assert_eq!(idx, 2);
    }

    #[test]
    fn try_with_fallback_stops_on_non_retryable() {
        let items = vec![1, 2, 3];
        let calls = std::cell::Cell::new(0);
        let (result, idx) = try_with_fallback(
            &items,
            0,
            |&n| {
                calls.set(calls.get() + 1);
                if n == 1 { Err("fatal") } else { Ok(n) }
            },
            |_| false,
        );
        assert_eq!(result, Err("fatal"));
        assert_eq!(idx, 0);
        // A non-retryable error must not fall through to later candidates.
        assert_eq!(calls.get(), 1);
    }

    #[test]
    fn try_with_fallback_wraps_around_from_start_index() {
        let items = vec![10, 20, 30];
        let (result, idx) = try_with_fallback(
            &items,
            2,
            |&n| if n == 10 { Ok(n) } else { Err("retry") },
            |_| true,
        );
        assert_eq!(result, Ok(10));
        assert_eq!(idx, 0);
    }

    #[test]
    fn try_with_fallback_reduces_start_index_modulo_len() {
        let items = vec![10, 20, 30];
        let (result, idx) = try_with_fallback(&items, 4, |&n| Ok::<i32, &str>(n), |_| true);
        assert_eq!(result, Ok(20));
        assert_eq!(idx, 1);
    }

    #[test]
    fn try_with_fallback_all_fail() {
        let items = vec![1, 2, 3];
        let (result, idx) = try_with_fallback(
            &items,
            1,
            |&n: &i32| -> Result<i32, i32> { Err(n) },
            |_| true,
        );
        // Starting at index 1 the order is 2, 3, 1: the last error tried is reported,
        // paired with the starting index.
        assert_eq!(result, Err(1));
        assert_eq!(idx, 1);
    }

    #[test]
    #[should_panic(expected = "at least one candidate must exist")]
    fn try_with_fallback_panics_on_empty_candidates() {
        let items: Vec<i32> = Vec::new();
        let _ = try_with_fallback(&items, 0, |&n| Ok::<i32, &str>(n), |_| true);
    }
}