Skip to main content

oxi_ai/
fallback_chain.rs

1//! Fallback chain for ordered model failover.
2//!
3//! A `FallbackChain` manages an ordered list of models to try sequentially
4//! when a request fails. This enables automatic failover to backup models
5//! without requiring manual intervention.
6//!
7//! # Usage
8//!
9//! ```ignore
10//! use oxi_ai::fallback_chain::FallbackChain;
11//!
12//! // Create from provider/model strings
13//! let chain = FallbackChain::from_ids(&[
14//!     "anthropic/claude-sonnet-4-20250514",
15//!     "openai/gpt-4o",
16//! ])?;
17//!
18//! // Iterate through models
19//! for model in chain.iter() {
20//!     println!("Trying: {}", model.name);
21//! }
22//!
23//! // Get the next model after a failure
24//! if let Some(next) = chain.next("anthropic/claude-sonnet-4-20250514") {
25//!     println!("Fallback to: {}", next.name);
26//! }
27//! ```
28
29use crate::model_db::{get_model_entry, ModelEntry};
30
31/// An ordered chain of models for sequential fallback on failure.
32///
33/// When a model request fails, the chain allows easy iteration to the next
34/// available model in priority order. This is useful for implementing
35/// automatic failover strategies.
36///
37/// # Example
38///
39/// ```ignore
40/// use oxi_ai::fallback_chain::FallbackChain;
41///
42/// // From string IDs
43/// let chain = FallbackChain::from_ids(&["openai/gpt-4o", "google/gemini-2.0-flash"])?;
44///
45/// // Direct construction
46/// let models = vec![model1, model2];
47/// let chain = FallbackChain::new(models);
48///
49/// // Find next model after failure
50/// if let Some(next) = chain.next("openai/gpt-4o") {
51///     // Use next model...
52/// }
53/// ```
54#[derive(Debug, Clone, PartialEq)]
55pub struct FallbackChain {
56    /// The ordered list of model entries.
57    models: Vec<&'static ModelEntry>,
58    /// The original provider/model strings for reference.
59    names: Vec<String>,
60}
61
62impl Default for FallbackChain {
63    /// Creates a default fallback chain with cheap, reliable models.
64    ///
65    /// The default chain includes models from multiple providers to ensure
66    /// redundancy and cost efficiency. These are selected based on:
67    /// - Low input cost
68    /// - Wide context window
69    /// - Vision support for versatility
70    fn default() -> Self {
71        // Default chain: prioritize cheap models from different providers
72        // Order: cheapest first, then progressively more expensive
73        let default_ids = [
74            // Free/very cheap models
75            "google/gemini-2.0-flash",
76            "openai/gpt-4o-mini",
77            "anthropic/claude-3-5-haiku-20241022",
78            // Mid-tier reliable models
79            "openai/gpt-4o",
80            "anthropic/claude-sonnet-4-20250514",
81            // Premium models as last resort
82            "anthropic/claude-opus-4-20250514",
83        ];
84
85        Self::from_ids(&default_ids).expect("Default fallback chain should always be valid")
86    }
87}
88
89impl FallbackChain {
90    /// Creates a new fallback chain from an ordered list of models.
91    ///
92    /// # Arguments
93    ///
94    /// * `models` - A vector of model entries in priority order (first = highest priority)
95    ///
96    /// # Example
97    ///
98    /// ```ignore
99    /// use oxi_ai::model_db::get_model_entry;
100    ///
101    /// let models = vec![
102    ///     get_model_entry("openai", "gpt-4o").unwrap(),
103    ///     get_model_entry("anthropic", "claude-sonnet-4-20250514").unwrap(),
104    /// ];
105    /// let chain = FallbackChain::new(models);
106    /// ```
107    pub fn new(models: Vec<&'static ModelEntry>) -> Self {
108        let names: Vec<String> = models
109            .iter()
110            .map(|m| format!("{}/{}", m.provider, m.id))
111            .collect();
112
113        Self { models, names }
114    }
115
116    /// Creates a fallback chain from "provider/model" ID strings.
117    ///
118    /// Each string must be in the format `"provider/model-id"`, for example:
119    /// - `"anthropic/claude-sonnet-4-20250514"`
120    /// - `"openai/gpt-4o"`
121    /// - `"google/gemini-2.0-flash"`
122    ///
123    /// # Arguments
124    ///
125    /// * `ids` - Slice of strings in `"provider/model"` format
126    ///
127    /// # Errors
128    ///
129    /// Returns a `FallbackChainError` if any model ID cannot be found in the database.
130    ///
131    /// # Example
132    ///
133    /// ```ignore
134    /// let chain = FallbackChain::from_ids(&[
135    ///     "anthropic/claude-sonnet-4-20250514",
136    ///     "openai/gpt-4o",
137    /// ])?;
138    /// ```
139    pub fn from_ids(ids: &[&str]) -> Result<Self, FallbackChainError> {
140        let mut models: Vec<&'static ModelEntry> = Vec::with_capacity(ids.len());
141        let mut names: Vec<String> = Vec::with_capacity(ids.len());
142
143        for id in ids {
144            let (provider, model_id) = match id.split_once('/') {
145                Some((p, m)) => (p, m),
146                None => {
147                    return Err(FallbackChainError::InvalidFormat {
148                        id: id.to_string(),
149                        reason: "Expected 'provider/model' format".to_string(),
150                    });
151                }
152            };
153
154            match get_model_entry(provider, model_id) {
155                Some(entry) => {
156                    models.push(entry);
157                    names.push(id.to_string());
158                }
159                None => {
160                    return Err(FallbackChainError::ModelNotFound {
161                        id: id.to_string(),
162                        provider: provider.to_string(),
163                        model_id: model_id.to_string(),
164                    });
165                }
166            }
167        }
168
169        Ok(Self { models, names })
170    }
171
172    /// Returns the next model in the chain after the current one.
173    ///
174    /// # Arguments
175    ///
176    /// * `current` - The current model ID in `"provider/model"` format
177    ///
178    /// # Returns
179    ///
180    /// * `Some(&ModelEntry)` - The next model in the chain
181    /// * `None` - If the current model is not in the chain, or it's the last model
182    ///
183    /// # Example
184    ///
185    /// ```ignore
186    /// let chain = FallbackChain::from_ids(&["a", "b", "c"])?;
187    ///
188    /// assert_eq!(chain.next("a").map(|m| m.id), Some("b"));
189    /// assert_eq!(chain.next("b").map(|m| m.id), Some("c"));
190    /// assert_eq!(chain.next("c"), None); // Last in chain
191    /// assert_eq!(chain.next("unknown"), None); // Not in chain
192    /// ```
193    pub fn next(&self, current: &str) -> Option<&'static ModelEntry> {
194        let index = self.index_of(current)?;
195        let next_index = index + 1;
196
197        if next_index < self.models.len() {
198            Some(self.models[next_index])
199        } else {
200            None
201        }
202    }
203
204    /// Returns the index of a model in the chain.
205    ///
206    /// # Arguments
207    ///
208    /// * `model_id` - The model ID in `"provider/model"` format
209    ///
210    /// # Returns
211    ///
212    /// * `Some(usize)` - The zero-based position in the chain
213    /// * `None` - If the model is not in the chain
214    ///
215    /// # Example
216    ///
217    /// ```ignore
218    /// let chain = FallbackChain::from_ids(&["a", "b", "c"])?;
219    ///
220    /// assert_eq!(chain.index_of("a"), Some(0));
221    /// assert_eq!(chain.index_of("b"), Some(1));
222    /// assert_eq!(chain.index_of("c"), Some(2));
223    /// assert_eq!(chain.index_of("unknown"), None);
224    /// ```
225    pub fn index_of(&self, model_id: &str) -> Option<usize> {
226        self.names.iter().position(|n| n == model_id)
227    }
228
229    /// Returns an iterator over the model entries in the chain.
230    ///
231    /// # Example
232    ///
233    /// ```ignore
234    /// let chain = FallbackChain::from_ids(&["a", "b", "c"])?;
235    ///
236    /// for model in chain.iter() {
237    ///     println!("Model: {} ({})", model.name, model.provider);
238    /// }
239    /// ```
240    pub fn iter(&self) -> impl Iterator<Item = &'static ModelEntry> + '_ {
241        self.models.iter().copied()
242    }
243
244    /// Returns `true` if the chain contains no models.
245    ///
246    /// # Example
247    ///
248    /// ```ignore
249    /// let chain = FallbackChain::new(vec![]);
250    /// assert!(chain.is_empty());
251    ///
252    /// let chain = FallbackChain::from_ids(&["a"])?;
253    /// assert!(!chain.is_empty());
254    /// ```
255    pub fn is_empty(&self) -> bool {
256        self.models.is_empty()
257    }
258
259    /// Returns the number of models in the chain.
260    ///
261    /// # Example
262    ///
263    /// ```ignore
264    /// let chain = FallbackChain::from_ids(&["a", "b", "c"])?;
265    /// assert_eq!(chain.len(), 3);
266    /// ```
267    pub fn len(&self) -> usize {
268        self.models.len()
269    }
270
271    /// Returns a slice of all model entries.
272    ///
273    /// # Example
274    ///
275    /// ```ignore
276    /// let chain = FallbackChain::from_ids(&["a", "b", "c"])?;
277    /// let models: Vec<_> = chain.models();
278    /// ```
279    pub fn models(&self) -> &[&'static ModelEntry] {
280        &self.models
281    }
282
283    /// Returns the model ID strings that were used to create the chain.
284    ///
285    /// # Example
286    ///
287    /// ```ignore
288    /// let chain = FallbackChain::from_ids(&["openai/gpt-4o", "anthropic/claude-sonnet-4"])?;
289    /// assert_eq!(chain.names(), &["openai/gpt-4o", "anthropic/claude-sonnet-4"]);
290    /// ```
291    pub fn names(&self) -> &[String] {
292        &self.names
293    }
294
295    /// Returns the first model in the chain, if any.
296    ///
297    /// # Example
298    ///
299    /// ```ignore
300    /// let chain = FallbackChain::from_ids(&["a", "b"])?;
301    /// assert_eq!(chain.first().map(|m| m.id), Some("a"));
302    ///
303    /// let empty: FallbackChain = FallbackChain::new(vec![]);
304    /// assert_eq!(empty.first(), None);
305    /// ```
306    pub fn first(&self) -> Option<&'static ModelEntry> {
307        self.models.first().copied()
308    }
309
310    /// Returns the last model in the chain, if any.
311    ///
312    /// # Example
313    ///
314    /// ```ignore
315    /// let chain = FallbackChain::from_ids(&["a", "b"])?;
316    /// assert_eq!(chain.last().map(|m| m.id), Some("b"));
317    ///
318    /// let empty: FallbackChain = FallbackChain::new(vec![]);
319    /// assert_eq!(empty.last(), None);
320    /// ```
321    pub fn last(&self) -> Option<&'static ModelEntry> {
322        self.models.last().copied()
323    }
324
325    /// Checks if the chain contains a specific model.
326    ///
327    /// # Arguments
328    ///
329    /// * `model_id` - The model ID in `"provider/model"` format
330    ///
331    /// # Example
332    ///
333    /// ```ignore
334    /// let chain = FallbackChain::from_ids(&["a", "b"])?;
335    /// assert!(chain.contains("a"));
336    /// assert!(!chain.contains("c"));
337    /// ```
338    pub fn contains(&self, model_id: &str) -> bool {
339        self.index_of(model_id).is_some()
340    }
341
342    /// Creates a new chain with models after (and including) the given model.
343    ///
344    /// This is useful for continuing fallback after a model succeeds but you
345    /// want to track the remaining options.
346    ///
347    /// # Arguments
348    ///
349    /// * `model_id` - The model ID to start from (inclusive)
350    ///
351    /// # Returns
352    ///
353    /// * `Some(FallbackChain)` - The remaining models from the starting point
354    /// * `None` - If the model is not in the chain
355    ///
356    /// # Example
357    ///
358    /// ```ignore
359    /// let chain = FallbackChain::from_ids(&["a", "b", "c"])?;
360    /// let remaining = chain.from_inclusive("b")?;
361    /// assert_eq!(remaining.names(), &["b", "c"]);
362    /// ```
363    pub fn from_inclusive(&self, model_id: &str) -> Option<Self> {
364        let start_index = self.index_of(model_id)?;
365
366        let models: Vec<_> = self.models[start_index..].to_vec();
367        let names: Vec<_> = self.names[start_index..].to_vec();
368
369        Some(Self { models, names })
370    }
371
372    /// Creates a new chain with models after (excluding) the given model.
373    ///
374    /// # Arguments
375    ///
376    /// * `model_id` - The model ID to skip
377    ///
378    /// # Returns
379    ///
380    /// * `Some(FallbackChain)` - The remaining models after the given model
381    /// * `None` - If the model is not in the chain or is the last model
382    ///
383    /// # Example
384    ///
385    /// ```ignore
386    /// let chain = FallbackChain::from_ids(&["a", "b", "c"])?;
387    /// let remaining = chain.from_after("b")?;
388    /// assert_eq!(remaining.names(), &["c"]);
389    /// ```
390    pub fn from_after(&self, model_id: &str) -> Option<Self> {
391        let start_index = self.index_of(model_id)?;
392        let next_index = start_index + 1;
393
394        if next_index >= self.models.len() {
395            return None;
396        }
397
398        let models: Vec<_> = self.models[next_index..].to_vec();
399        let names: Vec<_> = self.names[next_index..].to_vec();
400
401        Some(Self { models, names })
402    }
403}
404
405/// Errors that can occur when creating a fallback chain.
406#[derive(Debug, Clone, PartialEq, thiserror::Error)]
407pub enum FallbackChainError {
408    /// The model ID format is invalid (expected "provider/model").
409    #[error("Invalid model ID format '{id}': {reason}")]
410    InvalidFormat {
411        /// The malformed model ID.
412        id: String,
413        /// Explanation of why the format is invalid.
414        reason: String,
415    },
416
417    /// The model was not found in the model database.
418    #[error("Model not found: {provider}/{model_id}")]
419    ModelNotFound {
420        /// The full model ID that was requested.
421        id: String,
422        /// The provider that was searched.
423        provider: String,
424        /// The model ID that was not found.
425        model_id: String,
426    },
427}
428
#[cfg(test)]
mod tests {
    use super::*;
    use crate::model_db::get_model_entry;

    // Full "provider/model" IDs used throughout the tests.
    const GPT_4O: &str = "openai/gpt-4o";
    const SONNET_4: &str = "anthropic/claude-sonnet-4-20250514";
    const GEMINI_FLASH: &str = "google/gemini-2.0-flash";

    /// Builds the three-model fixture: GPT-4o, then Sonnet 4, then Gemini Flash.
    fn three_model_chain() -> FallbackChain {
        FallbackChain::from_ids(&[GPT_4O, SONNET_4, GEMINI_FLASH]).unwrap()
    }

    #[test]
    fn test_from_ids_valid() {
        let chain = FallbackChain::from_ids(&[SONNET_4]).unwrap();

        assert_eq!(chain.len(), 1);
        assert_eq!(chain.first().unwrap().id, "claude-sonnet-4-20250514");
    }

    #[test]
    fn test_from_ids_multiple() {
        let chain = three_model_chain();

        assert_eq!(chain.len(), 3);
        assert_eq!(chain.first().unwrap().id, "gpt-4o");
        assert_eq!(chain.last().unwrap().id, "gemini-2.0-flash");
    }

    #[test]
    fn test_from_ids_invalid_format() {
        // No '/' separator at all.
        let outcome = FallbackChain::from_ids(&["invalid-no-slash"]);

        assert!(matches!(
            outcome,
            Err(FallbackChainError::InvalidFormat { .. })
        ));
    }

    #[test]
    fn test_from_ids_not_found() {
        // Well-formed, but absent from the model database.
        let outcome = FallbackChain::from_ids(&["nonexistent-provider/nonexistent-model"]);

        assert!(matches!(
            outcome,
            Err(FallbackChainError::ModelNotFound { .. })
        ));
    }

    #[test]
    fn test_new_direct() {
        let entry = get_model_entry("openai", "gpt-4o").unwrap();
        let chain = FallbackChain::new(vec![entry]);

        assert_eq!(chain.len(), 1);
        assert_eq!(chain.first().unwrap().id, "gpt-4o");
    }

    #[test]
    fn test_default_chain() {
        let chain = FallbackChain::default();

        // The default chain should offer several models.
        assert!(!chain.is_empty());
        assert!(chain.len() >= 3);

        // There must be a highest-priority model to start from.
        assert!(chain.first().is_some());
    }

    #[test]
    fn test_next() {
        let chain = three_model_chain();

        assert_eq!(chain.next(GPT_4O).unwrap().id, "claude-sonnet-4-20250514");
        assert_eq!(chain.next(SONNET_4).unwrap().id, "gemini-2.0-flash");
        // Nothing follows the last model, and unknown names are not in the chain.
        assert_eq!(chain.next(GEMINI_FLASH), None);
        assert_eq!(chain.next("unknown"), None);
    }

    #[test]
    fn test_index_of() {
        let chain = three_model_chain();

        assert_eq!(chain.index_of(GPT_4O), Some(0));
        assert_eq!(chain.index_of(SONNET_4), Some(1));
        assert_eq!(chain.index_of(GEMINI_FLASH), Some(2));
        assert_eq!(chain.index_of("unknown"), None);
    }

    #[test]
    fn test_contains() {
        let chain = FallbackChain::from_ids(&[GPT_4O, SONNET_4]).unwrap();

        assert!(chain.contains(GPT_4O));
        assert!(chain.contains(SONNET_4));
        assert!(!chain.contains(GEMINI_FLASH));
    }

    #[test]
    fn test_iter() {
        let chain = three_model_chain();

        let ids: Vec<_> = chain.iter().map(|m| m.id).collect();

        assert_eq!(
            ids,
            vec!["gpt-4o", "claude-sonnet-4-20250514", "gemini-2.0-flash"]
        );
    }

    #[test]
    fn test_is_empty() {
        let empty: FallbackChain = FallbackChain::new(vec![]);
        assert!(empty.is_empty());

        let single = FallbackChain::from_ids(&[GPT_4O]).unwrap();
        assert!(!single.is_empty());
    }

    #[test]
    fn test_models_and_names() {
        let chain = FallbackChain::from_ids(&[GPT_4O]).unwrap();

        assert_eq!(chain.models().len(), 1);
        assert_eq!(chain.names(), &["openai/gpt-4o"]);
    }

    #[test]
    fn test_from_inclusive() {
        let chain = three_model_chain();

        let tail = chain.from_inclusive(SONNET_4).unwrap();
        assert_eq!(
            tail.names(),
            &[
                "anthropic/claude-sonnet-4-20250514",
                "google/gemini-2.0-flash"
            ]
        );

        assert!(chain.from_inclusive("unknown").is_none());
    }

    #[test]
    fn test_from_after() {
        let chain = three_model_chain();

        let tail = chain.from_after(SONNET_4).unwrap();
        assert_eq!(tail.names(), &["google/gemini-2.0-flash"]);

        // No model comes after the last one.
        assert!(chain.from_after(GEMINI_FLASH).is_none());
        assert!(chain.from_after("unknown").is_none());
    }

    #[test]
    fn test_first_last() {
        let chain = three_model_chain();

        assert_eq!(chain.first().unwrap().id, "gpt-4o");
        assert_eq!(chain.last().unwrap().id, "gemini-2.0-flash");

        let empty: FallbackChain = FallbackChain::new(vec![]);
        assert_eq!(empty.first(), None);
        assert_eq!(empty.last(), None);
    }

    #[test]
    fn test_debug_format() {
        let chain = FallbackChain::from_ids(&[GPT_4O]).unwrap();

        let rendered = format!("{:?}", chain);

        assert!(rendered.contains("FallbackChain"));
    }
}