// llm_fallback_chain/lib.rs

//! Multi-provider failover for LLM calls.
//!
//! Wrap an ordered list of `(name, callable)` provider pairs. Each call tries
//! them in order. If a provider returns `Err`, the chain either falls back to
//! the next provider or returns that error to the caller immediately,
//! depending on the fallback predicate. On success you get back a
//! [`ChainResult`] with the return value, the winning provider name, and a
//! trace of every failed attempt.
//!
//! ```
//! use llm_fallback_chain::{FallbackChain, DynError};
//!
//! let chain = FallbackChain::<&str, String>::new(vec![
//!     ("anthropic", Box::new(|_p: &&str| -> Result<String, DynError> {
//!         Err("rate limited".into())
//!     }) as _),
//!     ("openai", Box::new(|p: &&str| -> Result<String, DynError> {
//!         Ok(format!("o:{}", p))
//!     }) as _),
//! ]).unwrap();
//!
//! let result = chain.call(&"hi").unwrap();
//! assert_eq!(result.value, "o:hi");
//! assert_eq!(result.provider, "openai");
//! assert_eq!(result.attempts.len(), 1);
//! ```
//!
//! Use a pluggable predicate to fall back only on certain errors:
//!
//! ```
//! use llm_fallback_chain::{FallbackChain, DynError};
//!
//! let chain = FallbackChain::<(), i32>::new(vec![
//!     ("a", Box::new(|_: &()| -> Result<i32, DynError> {
//!         Err("validation error".into())
//!     }) as _),
//!     ("b", Box::new(|_: &()| -> Result<i32, DynError> { Ok(1) }) as _),
//! ])
//! .unwrap()
//! .with_should_fall_back(|err| err.to_string().contains("rate"));
//!
//! // "validation error" does not contain "rate", so we do not fall back;
//! // the chain returns the provider's error as-is.
//! assert!(chain.call(&()).is_err());
//! ```

use std::error::Error as StdError;
use std::fmt;
use std::time::Instant;

/// Boxed, thread-safe dynamic error used throughout the chain.
///
/// Every provider returns this type, so each provider can fail with its own
/// concrete error while the chain handles all of them uniformly.
pub type DynError = Box<dyn StdError + Send + Sync>;

52/// One provider attempt within a chain call.
53#[derive(Debug)]
54pub struct Attempt {
55    /// Provider name as passed to [`FallbackChain::new`].
56    pub name: String,
57    /// The error the provider returned (`None` on success).
58    pub error: Option<DynError>,
59    /// Wall time the provider took, in milliseconds.
60    pub duration_ms: f64,
61}
62
63/// Outcome of a successful [`FallbackChain::call`].
64#[derive(Debug)]
65pub struct ChainResult<O> {
66    /// Whatever the winning provider returned.
67    pub value: O,
68    /// Name of the provider that succeeded.
69    pub provider: String,
70    /// Failed attempts that came before the success. Empty when the first
71    /// provider worked.
72    pub attempts: Vec<Attempt>,
73}
74
75/// Raised when every provider in the chain failed.
76#[derive(Debug)]
77pub struct AllProvidersFailed {
78    /// One [`Attempt`] per provider tried, in order.
79    pub attempts: Vec<Attempt>,
80}
81
82impl fmt::Display for AllProvidersFailed {
83    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
84        let names: Vec<&str> = self.attempts.iter().map(|a| a.name.as_str()).collect();
85        write!(f, "all providers failed: {}", names.join(", "))
86    }
87}
88
89impl StdError for AllProvidersFailed {}
90
91/// Sync provider callable: `(input) -> Result<O, DynError>`.
92pub type SyncProvider<I, O> = Box<dyn Fn(&I) -> Result<O, DynError> + Send + Sync>;
93
94/// Predicate deciding whether a given error should cause fallback.
95/// Default: any error causes fallback.
96pub type ShouldFallBack = Box<dyn Fn(&DynError) -> bool + Send + Sync>;
97
98/// Audit callback fired after a provider fails, before the next is tried.
99/// `(failed_name, error, next_name)`.
100pub type OnFallback = Box<dyn Fn(&str, &DynError, &str) + Send + Sync>;
101
102fn default_should_fall_back(_err: &DynError) -> bool {
103    true
104}
105
106/// Ordered list of LLM providers to try in sequence.
107///
108/// Each provider is a `(name, callable)` pair. [`call`](Self::call) tries them
109/// in order until one returns `Ok`. Returns a [`ChainResult`] describing what
110/// won and what failed. If every provider returns `Err`, an
111/// [`AllProvidersFailed`] error is returned.
112pub struct FallbackChain<I, O> {
113    providers: Vec<(String, SyncProvider<I, O>)>,
114    should_fall_back: ShouldFallBack,
115    on_fallback: Option<OnFallback>,
116}
117
118impl<I, O> FallbackChain<I, O> {
119    /// Build a new chain from an ordered list of `(name, provider)` pairs.
120    ///
121    /// Returns `Err` if `providers` is empty.
122    pub fn new<S: Into<String>>(
123        providers: Vec<(S, SyncProvider<I, O>)>,
124    ) -> Result<Self, &'static str> {
125        if providers.is_empty() {
126            return Err("providers must be a non-empty list");
127        }
128        let providers = providers
129            .into_iter()
130            .map(|(name, fn_)| (name.into(), fn_))
131            .collect();
132        Ok(Self {
133            providers,
134            should_fall_back: Box::new(default_should_fall_back),
135            on_fallback: None,
136        })
137    }
138
139    /// Set the predicate that decides whether a given error should cause
140    /// fallback. Default falls back on any error.
141    pub fn with_should_fall_back<F>(mut self, f: F) -> Self
142    where
143        F: Fn(&DynError) -> bool + Send + Sync + 'static,
144    {
145        self.should_fall_back = Box::new(f);
146        self
147    }
148
149    /// Set the audit callback fired after each fallback. Called with
150    /// `(failed_name, error, next_name)` before the next provider is tried.
151    pub fn with_on_fallback<F>(mut self, f: F) -> Self
152    where
153        F: Fn(&str, &DynError, &str) + Send + Sync + 'static,
154    {
155        self.on_fallback = Some(Box::new(f));
156        self
157    }
158
159    /// Skip a provider entirely when `predicate(&name)` returns `true`.
160    /// Useful when an upstream circuit breaker says a provider is open.
161    /// Returns `Err` if the filter removes every provider.
162    pub fn with_skip<P>(mut self, predicate: P) -> Result<Self, &'static str>
163    where
164        P: Fn(&str) -> bool,
165    {
166        self.providers.retain(|(name, _)| !predicate(name));
167        if self.providers.is_empty() {
168            return Err("with_skip removed all providers");
169        }
170        Ok(self)
171    }
172
173    /// Provider names in order.
174    pub fn names(&self) -> Vec<&str> {
175        self.providers.iter().map(|(n, _)| n.as_str()).collect()
176    }
177
178    /// Try each provider in order until one returns `Ok`.
179    pub fn call(&self, input: &I) -> Result<ChainResult<O>, DynError> {
180        let mut failures: Vec<Attempt> = Vec::new();
181        let last = self.providers.len() - 1;
182        for (i, (name, fn_)) in self.providers.iter().enumerate() {
183            let start = Instant::now();
184            match fn_(input) {
185                Ok(value) => {
186                    return Ok(ChainResult {
187                        value,
188                        provider: name.clone(),
189                        attempts: failures,
190                    });
191                }
192                Err(err) => {
193                    let elapsed = start.elapsed().as_secs_f64() * 1000.0;
194                    if !(self.should_fall_back)(&err) {
195                        return Err(err);
196                    }
197                    if i < last {
198                        if let Some(cb) = &self.on_fallback {
199                            let next_name = &self.providers[i + 1].0;
200                            cb(name, &err, next_name);
201                        }
202                    }
203                    failures.push(Attempt {
204                        name: name.clone(),
205                        error: Some(err),
206                        duration_ms: elapsed,
207                    });
208                }
209            }
210        }
211        Err(Box::new(AllProvidersFailed { attempts: failures }))
212    }
213}
214
#[cfg(feature = "tokio")]
mod async_chain {
    use super::{
        default_should_fall_back, AllProvidersFailed, Attempt, ChainResult, DynError, OnFallback,
        ShouldFallBack,
    };
    use futures::future::BoxFuture;
    use std::time::Instant;

    /// Async provider callable: `(input) -> Future<Result<O, DynError>>`.
    /// The returned future may borrow from the input for lifetime `'a`.
    pub type AsyncProvider<I, O> =
        Box<dyn for<'a> Fn(&'a I) -> BoxFuture<'a, Result<O, DynError>> + Send + Sync>;

    /// Helper to construct an [`AsyncProvider`] from a closure. Avoids the
    /// HRTB inference rough edge: the explicit function signature pins the
    /// `for<'a>` bound so callers can pass a plain `|input| async { ... }`
    /// without manual lifetime annotations.
    ///
    /// Because `Fut` must be `'static`, futures built through this helper
    /// cannot borrow from the input; copy or clone what the future needs
    /// before creating it.
    pub fn async_provider<I, O, F, Fut>(f: F) -> AsyncProvider<I, O>
    where
        F: for<'a> Fn(&'a I) -> Fut + Send + Sync + 'static,
        Fut: std::future::Future<Output = Result<O, DynError>> + Send + 'static,
        I: 'static,
    {
        Box::new(move |input: &I| {
            let fut = f(input);
            Box::pin(fut) as BoxFuture<'_, _>
        })
    }

    /// Async variant of [`crate::FallbackChain`]. Behaves the same way but
    /// awaits provider futures.
    pub struct AsyncFallbackChain<I, O> {
        /// Providers in the order they are tried, each paired with its name.
        providers: Vec<(String, AsyncProvider<I, O>)>,
        /// Decides whether an error moves the chain on to the next provider.
        should_fall_back: ShouldFallBack,
        /// Optional audit hook invoked between a failure and the next attempt.
        on_fallback: Option<OnFallback>,
    }

    impl<I: Send + Sync, O: Send> AsyncFallbackChain<I, O> {
        /// Build a new chain from an ordered list of `(name, provider)` pairs.
        ///
        /// # Errors
        ///
        /// Returns `Err` if `providers` is empty.
        pub fn new<S: Into<String>>(
            providers: Vec<(S, AsyncProvider<I, O>)>,
        ) -> Result<Self, &'static str> {
            if providers.is_empty() {
                return Err("providers must be a non-empty list");
            }
            Ok(Self {
                providers: providers
                    .into_iter()
                    .map(|(name, provider)| (name.into(), provider))
                    .collect(),
                should_fall_back: Box::new(default_should_fall_back),
                on_fallback: None,
            })
        }

        /// Set the predicate that decides whether a given error should cause
        /// fallback. The default falls back on any error.
        #[must_use]
        pub fn with_should_fall_back<F>(mut self, f: F) -> Self
        where
            F: Fn(&DynError) -> bool + Send + Sync + 'static,
        {
            self.should_fall_back = Box::new(f);
            self
        }

        /// Set the audit callback fired after each fallback. It is called
        /// with `(failed_name, error, next_name)` before the next provider is
        /// tried, and is not called when the last provider fails.
        #[must_use]
        pub fn with_on_fallback<F>(mut self, f: F) -> Self
        where
            F: Fn(&str, &DynError, &str) + Send + Sync + 'static,
        {
            self.on_fallback = Some(Box::new(f));
            self
        }

        /// Remove every provider for which `predicate(&name)` returns `true`.
        /// The filter is applied once, here; it is not re-evaluated on each
        /// [`call`](Self::call).
        ///
        /// # Errors
        ///
        /// Returns `Err` if the filter removes every provider.
        pub fn with_skip<P>(mut self, predicate: P) -> Result<Self, &'static str>
        where
            P: Fn(&str) -> bool,
        {
            self.providers.retain(|(name, _)| !predicate(name));
            if self.providers.is_empty() {
                return Err("with_skip removed all providers");
            }
            Ok(self)
        }

        /// Provider names in the order they will be tried.
        pub fn names(&self) -> Vec<&str> {
            self.providers
                .iter()
                .map(|(name, _)| name.as_str())
                .collect()
        }

        /// Try each provider in order, awaiting each one, until one returns
        /// `Ok`.
        ///
        /// # Errors
        ///
        /// * Returns a provider's own error, unchanged, as soon as the
        ///   fallback predicate rejects it; later providers are not tried.
        /// * Returns a boxed [`AllProvidersFailed`] when every provider
        ///   failed with an error the predicate accepted.
        pub async fn call(&self, input: &I) -> Result<ChainResult<O>, DynError> {
            let mut failures: Vec<Attempt> = Vec::new();
            for (i, (name, run)) in self.providers.iter().enumerate() {
                let start = Instant::now();
                let err = match run(input).await {
                    Ok(value) => {
                        return Ok(ChainResult {
                            value,
                            provider: name.clone(),
                            attempts: failures,
                        });
                    }
                    Err(err) => err,
                };
                // Capture the duration before running user callbacks so
                // their cost is not attributed to the provider.
                let duration_ms = start.elapsed().as_secs_f64() * 1000.0;
                if !(self.should_fall_back)(&err) {
                    return Err(err);
                }
                // `get(i + 1)` is `None` for the last provider, so the audit
                // callback only fires when there really is a next provider,
                // with no index arithmetic that could underflow or go out of
                // bounds.
                if let (Some(cb), Some((next_name, _))) =
                    (&self.on_fallback, self.providers.get(i + 1))
                {
                    cb(name, &err, next_name);
                }
                failures.push(Attempt {
                    name: name.clone(),
                    error: Some(err),
                    duration_ms,
                });
            }
            Err(Box::new(AllProvidersFailed { attempts: failures }))
        }
    }
}

#[cfg(feature = "tokio")]
pub use async_chain::{async_provider, AsyncFallbackChain, AsyncProvider};

#[cfg(feature = "serde")]
mod serde_impls {
    use super::Attempt;
    use serde::Serialize;

    /// Lossy serializable view of an [`Attempt`]. The error is recorded as its
    /// `Display` string because boxed `dyn Error` is not naturally `Serialize`.
    #[derive(Debug, Serialize)]
    pub struct AttemptView {
        /// Provider name, copied from the attempt.
        pub name: String,
        /// `Display` rendering of the attempt's error, if it had one.
        pub error: Option<String>,
        /// Time the provider took, in milliseconds, copied from the attempt.
        pub duration_ms: f64,
    }

    impl From<&Attempt> for AttemptView {
        /// Copies the name and duration and stringifies the error.
        fn from(attempt: &Attempt) -> Self {
            let error = attempt.error.as_ref().map(|e| e.to_string());
            Self {
                name: attempt.name.clone(),
                error,
                duration_ms: attempt.duration_ms,
            }
        }
    }
}

#[cfg(feature = "serde")]
pub use serde_impls::AttemptView;