oxi_ai/fallback_chain.rs
1//! Fallback chain for ordered model failover.
2//!
3//! A `FallbackChain` manages an ordered list of models to try sequentially
4//! when a request fails. This enables automatic failover to backup models
5//! without requiring manual intervention.
6//!
7//! # Usage
8//!
9//! ```ignore
10//! use oxi_ai::fallback_chain::FallbackChain;
11//!
12//! // Create from provider/model strings
13//! let chain = FallbackChain::from_ids(&[
14//! "anthropic/claude-sonnet-4-20250514",
15//! "openai/gpt-4o",
16//! ])?;
17//!
18//! // Iterate through models
19//! for model in chain.iter() {
20//! println!("Trying: {}", model.name);
21//! }
22//!
23//! // Get the next model after a failure
24//! if let Some(next) = chain.next("anthropic/claude-sonnet-4-20250514") {
25//! println!("Fallback to: {}", next.name);
26//! }
27//! ```
28
29use crate::model_db::{get_model_entry, ModelEntry};
30
31/// An ordered chain of models for sequential fallback on failure.
32///
33/// When a model request fails, the chain allows easy iteration to the next
34/// available model in priority order. This is useful for implementing
35/// automatic failover strategies.
36///
37/// # Example
38///
39/// ```ignore
40/// use oxi_ai::fallback_chain::FallbackChain;
41///
42/// // From string IDs
43/// let chain = FallbackChain::from_ids(&["openai/gpt-4o", "google/gemini-2.0-flash"])?;
44///
45/// // Direct construction
46/// let models = vec![model1, model2];
47/// let chain = FallbackChain::new(models);
48///
49/// // Find next model after failure
50/// if let Some(next) = chain.next("openai/gpt-4o") {
51/// // Use next model...
52/// }
53/// ```
54#[derive(Debug, Clone, PartialEq)]
55pub struct FallbackChain {
56 /// The ordered list of model entries.
57 models: Vec<&'static ModelEntry>,
58 /// The original provider/model strings for reference.
59 names: Vec<String>,
60}
61
62impl Default for FallbackChain {
63 /// Creates a default fallback chain with cheap, reliable models.
64 ///
65 /// The default chain includes models from multiple providers to ensure
66 /// redundancy and cost efficiency. These are selected based on:
67 /// - Low input cost
68 /// - Wide context window
69 /// - Vision support for versatility
70 fn default() -> Self {
71 // Default chain: prioritize cheap models from different providers
72 // Order: cheapest first, then progressively more expensive
73 let default_ids = [
74 // Free/very cheap models
75 "google/gemini-2.0-flash",
76 "openai/gpt-4o-mini",
77 "anthropic/claude-3-5-haiku-20241022",
78 // Mid-tier reliable models
79 "openai/gpt-4o",
80 "anthropic/claude-sonnet-4-20250514",
81 // Premium models as last resort
82 "anthropic/claude-opus-4-20250514",
83 ];
84
85 Self::from_ids(&default_ids).expect("Default fallback chain should always be valid")
86 }
87}
88
89impl FallbackChain {
90 /// Creates a new fallback chain from an ordered list of models.
91 ///
92 /// # Arguments
93 ///
94 /// * `models` - A vector of model entries in priority order (first = highest priority)
95 ///
96 /// # Example
97 ///
98 /// ```ignore
99 /// use oxi_ai::model_db::get_model_entry;
100 ///
101 /// let models = vec![
102 /// get_model_entry("openai", "gpt-4o").unwrap(),
103 /// get_model_entry("anthropic", "claude-sonnet-4-20250514").unwrap(),
104 /// ];
105 /// let chain = FallbackChain::new(models);
106 /// ```
107 pub fn new(models: Vec<&'static ModelEntry>) -> Self {
108 let names: Vec<String> = models
109 .iter()
110 .map(|m| format!("{}/{}", m.provider, m.id))
111 .collect();
112
113 Self { models, names }
114 }
115
116 /// Creates a fallback chain from "provider/model" ID strings.
117 ///
118 /// Each string must be in the format `"provider/model-id"`, for example:
119 /// - `"anthropic/claude-sonnet-4-20250514"`
120 /// - `"openai/gpt-4o"`
121 /// - `"google/gemini-2.0-flash"`
122 ///
123 /// # Arguments
124 ///
125 /// * `ids` - Slice of strings in `"provider/model"` format
126 ///
127 /// # Errors
128 ///
129 /// Returns a `FallbackChainError` if any model ID cannot be found in the database.
130 ///
131 /// # Example
132 ///
133 /// ```ignore
134 /// let chain = FallbackChain::from_ids(&[
135 /// "anthropic/claude-sonnet-4-20250514",
136 /// "openai/gpt-4o",
137 /// ])?;
138 /// ```
139 pub fn from_ids(ids: &[&str]) -> Result<Self, FallbackChainError> {
140 let mut models: Vec<&'static ModelEntry> = Vec::with_capacity(ids.len());
141 let mut names: Vec<String> = Vec::with_capacity(ids.len());
142
143 for id in ids {
144 let (provider, model_id) = match id.split_once('/') {
145 Some((p, m)) => (p, m),
146 None => {
147 return Err(FallbackChainError::InvalidFormat {
148 id: id.to_string(),
149 reason: "Expected 'provider/model' format".to_string(),
150 });
151 }
152 };
153
154 match get_model_entry(provider, model_id) {
155 Some(entry) => {
156 models.push(entry);
157 names.push(id.to_string());
158 }
159 None => {
160 return Err(FallbackChainError::ModelNotFound {
161 id: id.to_string(),
162 provider: provider.to_string(),
163 model_id: model_id.to_string(),
164 });
165 }
166 }
167 }
168
169 Ok(Self { models, names })
170 }
171
172 /// Returns the next model in the chain after the current one.
173 ///
174 /// # Arguments
175 ///
176 /// * `current` - The current model ID in `"provider/model"` format
177 ///
178 /// # Returns
179 ///
180 /// * `Some(&ModelEntry)` - The next model in the chain
181 /// * `None` - If the current model is not in the chain, or it's the last model
182 ///
183 /// # Example
184 ///
185 /// ```ignore
186 /// let chain = FallbackChain::from_ids(&["a", "b", "c"])?;
187 ///
188 /// assert_eq!(chain.next("a").map(|m| m.id), Some("b"));
189 /// assert_eq!(chain.next("b").map(|m| m.id), Some("c"));
190 /// assert_eq!(chain.next("c"), None); // Last in chain
191 /// assert_eq!(chain.next("unknown"), None); // Not in chain
192 /// ```
193 pub fn next(&self, current: &str) -> Option<&'static ModelEntry> {
194 let index = self.index_of(current)?;
195 let next_index = index + 1;
196
197 if next_index < self.models.len() {
198 Some(self.models[next_index])
199 } else {
200 None
201 }
202 }
203
204 /// Returns the index of a model in the chain.
205 ///
206 /// # Arguments
207 ///
208 /// * `model_id` - The model ID in `"provider/model"` format
209 ///
210 /// # Returns
211 ///
212 /// * `Some(usize)` - The zero-based position in the chain
213 /// * `None` - If the model is not in the chain
214 ///
215 /// # Example
216 ///
217 /// ```ignore
218 /// let chain = FallbackChain::from_ids(&["a", "b", "c"])?;
219 ///
220 /// assert_eq!(chain.index_of("a"), Some(0));
221 /// assert_eq!(chain.index_of("b"), Some(1));
222 /// assert_eq!(chain.index_of("c"), Some(2));
223 /// assert_eq!(chain.index_of("unknown"), None);
224 /// ```
225 pub fn index_of(&self, model_id: &str) -> Option<usize> {
226 self.names.iter().position(|n| n == model_id)
227 }
228
229 /// Returns an iterator over the model entries in the chain.
230 ///
231 /// # Example
232 ///
233 /// ```ignore
234 /// let chain = FallbackChain::from_ids(&["a", "b", "c"])?;
235 ///
236 /// for model in chain.iter() {
237 /// println!("Model: {} ({})", model.name, model.provider);
238 /// }
239 /// ```
240 pub fn iter(&self) -> impl Iterator<Item = &'static ModelEntry> + '_ {
241 self.models.iter().copied()
242 }
243
244 /// Returns `true` if the chain contains no models.
245 ///
246 /// # Example
247 ///
248 /// ```ignore
249 /// let chain = FallbackChain::new(vec![]);
250 /// assert!(chain.is_empty());
251 ///
252 /// let chain = FallbackChain::from_ids(&["a"])?;
253 /// assert!(!chain.is_empty());
254 /// ```
255 pub fn is_empty(&self) -> bool {
256 self.models.is_empty()
257 }
258
259 /// Returns the number of models in the chain.
260 ///
261 /// # Example
262 ///
263 /// ```ignore
264 /// let chain = FallbackChain::from_ids(&["a", "b", "c"])?;
265 /// assert_eq!(chain.len(), 3);
266 /// ```
267 pub fn len(&self) -> usize {
268 self.models.len()
269 }
270
271 /// Returns a slice of all model entries.
272 ///
273 /// # Example
274 ///
275 /// ```ignore
276 /// let chain = FallbackChain::from_ids(&["a", "b", "c"])?;
277 /// let models: Vec<_> = chain.models();
278 /// ```
279 pub fn models(&self) -> &[&'static ModelEntry] {
280 &self.models
281 }
282
283 /// Returns the model ID strings that were used to create the chain.
284 ///
285 /// # Example
286 ///
287 /// ```ignore
288 /// let chain = FallbackChain::from_ids(&["openai/gpt-4o", "anthropic/claude-sonnet-4"])?;
289 /// assert_eq!(chain.names(), &["openai/gpt-4o", "anthropic/claude-sonnet-4"]);
290 /// ```
291 pub fn names(&self) -> &[String] {
292 &self.names
293 }
294
295 /// Returns the first model in the chain, if any.
296 ///
297 /// # Example
298 ///
299 /// ```ignore
300 /// let chain = FallbackChain::from_ids(&["a", "b"])?;
301 /// assert_eq!(chain.first().map(|m| m.id), Some("a"));
302 ///
303 /// let empty: FallbackChain = FallbackChain::new(vec![]);
304 /// assert_eq!(empty.first(), None);
305 /// ```
306 pub fn first(&self) -> Option<&'static ModelEntry> {
307 self.models.first().copied()
308 }
309
310 /// Returns the last model in the chain, if any.
311 ///
312 /// # Example
313 ///
314 /// ```ignore
315 /// let chain = FallbackChain::from_ids(&["a", "b"])?;
316 /// assert_eq!(chain.last().map(|m| m.id), Some("b"));
317 ///
318 /// let empty: FallbackChain = FallbackChain::new(vec![]);
319 /// assert_eq!(empty.last(), None);
320 /// ```
321 pub fn last(&self) -> Option<&'static ModelEntry> {
322 self.models.last().copied()
323 }
324
325 /// Checks if the chain contains a specific model.
326 ///
327 /// # Arguments
328 ///
329 /// * `model_id` - The model ID in `"provider/model"` format
330 ///
331 /// # Example
332 ///
333 /// ```ignore
334 /// let chain = FallbackChain::from_ids(&["a", "b"])?;
335 /// assert!(chain.contains("a"));
336 /// assert!(!chain.contains("c"));
337 /// ```
338 pub fn contains(&self, model_id: &str) -> bool {
339 self.index_of(model_id).is_some()
340 }
341
342 /// Creates a new chain with models after (and including) the given model.
343 ///
344 /// This is useful for continuing fallback after a model succeeds but you
345 /// want to track the remaining options.
346 ///
347 /// # Arguments
348 ///
349 /// * `model_id` - The model ID to start from (inclusive)
350 ///
351 /// # Returns
352 ///
353 /// * `Some(FallbackChain)` - The remaining models from the starting point
354 /// * `None` - If the model is not in the chain
355 ///
356 /// # Example
357 ///
358 /// ```ignore
359 /// let chain = FallbackChain::from_ids(&["a", "b", "c"])?;
360 /// let remaining = chain.from_inclusive("b")?;
361 /// assert_eq!(remaining.names(), &["b", "c"]);
362 /// ```
363 pub fn from_inclusive(&self, model_id: &str) -> Option<Self> {
364 let start_index = self.index_of(model_id)?;
365
366 let models: Vec<_> = self.models[start_index..].to_vec();
367 let names: Vec<_> = self.names[start_index..].to_vec();
368
369 Some(Self { models, names })
370 }
371
372 /// Creates a new chain with models after (excluding) the given model.
373 ///
374 /// # Arguments
375 ///
376 /// * `model_id` - The model ID to skip
377 ///
378 /// # Returns
379 ///
380 /// * `Some(FallbackChain)` - The remaining models after the given model
381 /// * `None` - If the model is not in the chain or is the last model
382 ///
383 /// # Example
384 ///
385 /// ```ignore
386 /// let chain = FallbackChain::from_ids(&["a", "b", "c"])?;
387 /// let remaining = chain.from_after("b")?;
388 /// assert_eq!(remaining.names(), &["c"]);
389 /// ```
390 pub fn from_after(&self, model_id: &str) -> Option<Self> {
391 let start_index = self.index_of(model_id)?;
392 let next_index = start_index + 1;
393
394 if next_index >= self.models.len() {
395 return None;
396 }
397
398 let models: Vec<_> = self.models[next_index..].to_vec();
399 let names: Vec<_> = self.names[next_index..].to_vec();
400
401 Some(Self { models, names })
402 }
403}
404
405/// Errors that can occur when creating a fallback chain.
406#[derive(Debug, Clone, PartialEq, thiserror::Error)]
407pub enum FallbackChainError {
408 /// The model ID format is invalid (expected "provider/model").
409 #[error("Invalid model ID format '{id}': {reason}")]
410 InvalidFormat {
411 /// The malformed model ID.
412 id: String,
413 /// Explanation of why the format is invalid.
414 reason: String,
415 },
416
417 /// The model was not found in the model database.
418 #[error("Model not found: {provider}/{model_id}")]
419 ModelNotFound {
420 /// The full model ID that was requested.
421 id: String,
422 /// The provider that was searched.
423 provider: String,
424 /// The model ID that was not found.
425 model_id: String,
426 },
427}
428
#[cfg(test)]
mod tests {
    use super::*;
    use crate::model_db::get_model_entry;

    // Full "provider/model" IDs of the three models most tests chain together.
    const GPT_4O: &str = "openai/gpt-4o";
    const SONNET_4: &str = "anthropic/claude-sonnet-4-20250514";
    const GEMINI_FLASH: &str = "google/gemini-2.0-flash";

    /// Chain shared by most tests: GPT-4o, then Claude Sonnet 4, then Gemini 2.0 Flash.
    fn three_model_chain() -> FallbackChain {
        FallbackChain::from_ids(&[GPT_4O, SONNET_4, GEMINI_FLASH]).unwrap()
    }

    #[test]
    fn test_from_ids_valid() {
        let chain = FallbackChain::from_ids(&[SONNET_4]).unwrap();

        assert_eq!(chain.len(), 1);
        assert_eq!(chain.first().unwrap().id, "claude-sonnet-4-20250514");
    }

    #[test]
    fn test_from_ids_multiple() {
        let chain = three_model_chain();

        assert_eq!(chain.len(), 3);
        assert_eq!(chain.first().unwrap().id, "gpt-4o");
        assert_eq!(chain.last().unwrap().id, "gemini-2.0-flash");
    }

    #[test]
    fn test_from_ids_invalid_format() {
        let outcome = FallbackChain::from_ids(&["invalid-no-slash"]);

        assert!(matches!(
            outcome,
            Err(FallbackChainError::InvalidFormat { .. })
        ));
    }

    #[test]
    fn test_from_ids_not_found() {
        let outcome = FallbackChain::from_ids(&["nonexistent-provider/nonexistent-model"]);

        assert!(matches!(
            outcome,
            Err(FallbackChainError::ModelNotFound { .. })
        ));
    }

    #[test]
    fn test_new_direct() {
        let entry = get_model_entry("openai", "gpt-4o").unwrap();
        let chain = FallbackChain::new(vec![entry]);

        assert_eq!(chain.len(), 1);
        assert_eq!(chain.first().unwrap().id, "gpt-4o");
    }

    #[test]
    fn test_default_chain() {
        let chain = FallbackChain::default();

        // The built-in chain must offer several models...
        assert!(!chain.is_empty());
        assert!(chain.len() >= 3);

        // ...and therefore has a first one to start from.
        assert!(chain.first().is_some());
    }

    #[test]
    fn test_next() {
        let chain = three_model_chain();

        assert_eq!(chain.next(GPT_4O).unwrap().id, "claude-sonnet-4-20250514");
        assert_eq!(chain.next(SONNET_4).unwrap().id, "gemini-2.0-flash");
        // Nothing follows the last model, and unknown names are not in the chain.
        assert_eq!(chain.next(GEMINI_FLASH), None);
        assert_eq!(chain.next("unknown"), None);
    }

    #[test]
    fn test_index_of() {
        let chain = three_model_chain();

        assert_eq!(chain.index_of(GPT_4O), Some(0));
        assert_eq!(chain.index_of(SONNET_4), Some(1));
        assert_eq!(chain.index_of(GEMINI_FLASH), Some(2));
        assert_eq!(chain.index_of("unknown"), None);
    }

    #[test]
    fn test_contains() {
        let chain = FallbackChain::from_ids(&[GPT_4O, SONNET_4]).unwrap();

        assert!(chain.contains(GPT_4O));
        assert!(chain.contains(SONNET_4));
        assert!(!chain.contains(GEMINI_FLASH));
    }

    #[test]
    fn test_iter() {
        let chain = three_model_chain();

        let ids: Vec<_> = chain.iter().map(|m| m.id).collect();

        assert_eq!(
            ids,
            vec!["gpt-4o", "claude-sonnet-4-20250514", "gemini-2.0-flash"]
        );
    }

    #[test]
    fn test_is_empty() {
        let empty: FallbackChain = FallbackChain::new(vec![]);
        assert!(empty.is_empty());

        let non_empty = FallbackChain::from_ids(&[GPT_4O]).unwrap();
        assert!(!non_empty.is_empty());
    }

    #[test]
    fn test_models_and_names() {
        let chain = FallbackChain::from_ids(&[GPT_4O]).unwrap();

        assert_eq!(chain.models().len(), 1);
        assert_eq!(chain.names(), &["openai/gpt-4o"]);
    }

    #[test]
    fn test_from_inclusive() {
        let chain = three_model_chain();

        let remaining = chain.from_inclusive(SONNET_4).unwrap();
        assert_eq!(
            remaining.names(),
            &[
                "anthropic/claude-sonnet-4-20250514",
                "google/gemini-2.0-flash"
            ]
        );

        assert!(chain.from_inclusive("unknown").is_none());
    }

    #[test]
    fn test_from_after() {
        let chain = three_model_chain();

        let remaining = chain.from_after(SONNET_4).unwrap();
        assert_eq!(remaining.names(), &["google/gemini-2.0-flash"]);

        // No model comes after the last one, and unknown names are not in the chain.
        assert!(chain.from_after(GEMINI_FLASH).is_none());
        assert!(chain.from_after("unknown").is_none());
    }

    #[test]
    fn test_first_last() {
        let chain = three_model_chain();
        assert_eq!(chain.first().unwrap().id, "gpt-4o");
        assert_eq!(chain.last().unwrap().id, "gemini-2.0-flash");

        let empty: FallbackChain = FallbackChain::new(vec![]);
        assert_eq!(empty.first(), None);
        assert_eq!(empty.last(), None);
    }

    #[test]
    fn test_debug_format() {
        let chain = FallbackChain::from_ids(&[GPT_4O]).unwrap();

        let rendered = format!("{:?}", chain);

        assert!(rendered.contains("FallbackChain"));
    }
}
642}