Skip to main content

swink_agent/
fallback.rs

1//! Model fallback configuration.
2//!
3//! [`ModelFallback`] defines an ordered list of fallback models to try when
4//! the primary model exhausts its retry budget. Each entry pairs a
5//! [`ModelSpec`] with its corresponding [`StreamFn`], allowing fallback
6//! across providers.
7
8use std::sync::Arc;
9
10use crate::stream::StreamFn;
11use crate::types::ModelSpec;
12
13/// An ordered sequence of fallback models to attempt when the primary model
14/// (and its retries) are exhausted.
15///
16/// The agent tries each model in order, applying the configured
17/// [`RetryStrategy`](crate::RetryStrategy) independently for each model.
18/// When all fallback models are also exhausted the error propagates normally.
19///
20/// # Example
21///
22/// ```rust,no_run
23/// use swink_agent::{ModelFallback, ModelSpec};
24/// # use std::sync::Arc;
25/// # fn make_stream_fn() -> Arc<dyn swink_agent::StreamFn> { todo!() }
26///
27/// let fallback = ModelFallback::new(vec![
28///     (ModelSpec::new("openai", "gpt-4o-mini"), make_stream_fn()),
29///     (ModelSpec::new("anthropic", "claude-3-haiku-20240307"), make_stream_fn()),
30/// ]);
31/// ```
32#[derive(Clone)]
33pub struct ModelFallback {
34    models: Vec<(ModelSpec, Arc<dyn StreamFn>)>,
35}
36
37impl ModelFallback {
38    /// Create a new fallback chain from an ordered list of model/stream pairs.
39    #[must_use]
40    pub fn new(models: Vec<(ModelSpec, Arc<dyn StreamFn>)>) -> Self {
41        Self { models }
42    }
43
44    /// Returns the fallback models in order.
45    #[must_use]
46    pub fn models(&self) -> &[(ModelSpec, Arc<dyn StreamFn>)] {
47        &self.models
48    }
49
50    /// Returns `true` if the fallback chain is empty.
51    #[must_use]
52    pub fn is_empty(&self) -> bool {
53        self.models.is_empty()
54    }
55
56    /// Returns the number of fallback models.
57    #[must_use]
58    pub fn len(&self) -> usize {
59        self.models.len()
60    }
61}
62
63impl std::fmt::Debug for ModelFallback {
64    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
65        f.debug_struct("ModelFallback")
66            .field(
67                "models",
68                &self
69                    .models
70                    .iter()
71                    .map(|(m, _)| format!("{}:{}", m.provider, m.model_id))
72                    .collect::<Vec<_>>(),
73            )
74            .finish()
75    }
76}
77
#[cfg(test)]
mod tests {
    use super::*;

    /// A chain built from no entries is empty according to every accessor.
    #[test]
    fn empty_fallback() {
        let chain = ModelFallback::new(Vec::new());
        assert_eq!(chain.len(), 0);
        assert!(chain.is_empty());
        assert!(chain.models().is_empty());
    }

    /// The debug rendering names the type.
    #[test]
    fn debug_format() {
        let rendered = format!("{:?}", ModelFallback::new(Vec::new()));
        assert!(rendered.contains("ModelFallback"));
    }
}