use crate::model_db::{get_model_entry, ModelEntry};
/// An ordered list of models to try, one after another, when a request fails.
///
/// `models` and `names` are always kept the same length, and `names[i]` is the
/// `"provider/model"` identifier used to look up `models[i]`.
#[derive(Clone, Debug, PartialEq)]
pub struct FallbackChain {
    /// Resolved model-database entries, in fallback order.
    models: Vec<&'static ModelEntry>,
    /// Identifier for each entry, parallel to `models`; used for lookups by name.
    names: Vec<String>,
}
impl Default for FallbackChain {
    /// Builds the built-in six-model chain, tried in the order listed below.
    ///
    /// # Panics
    ///
    /// Panics if any built-in identifier fails to resolve through
    /// [`FallbackChain::from_ids`]; that would mean this list and the model
    /// database have drifted apart.
    fn default() -> Self {
        const BUILTIN_CHAIN: [&str; 6] = [
            "google/gemini-2.0-flash",
            "openai/gpt-4o-mini",
            "anthropic/claude-3-5-haiku-20241022",
            "openai/gpt-4o",
            "anthropic/claude-sonnet-4-20250514",
            "anthropic/claude-opus-4-20250514",
        ];

        Self::from_ids(&BUILTIN_CHAIN).expect("Default fallback chain should always be valid")
    }
}
impl FallbackChain {
    /// Builds a chain directly from already-resolved model entries, keeping their order.
    ///
    /// The name of each link is derived as `format!("{}/{}", provider, id)` from
    /// the entry itself; name-based lookups such as [`index_of`](Self::index_of)
    /// must use that form.
    pub fn new(models: Vec<&'static ModelEntry>) -> Self {
        let names = models
            .iter()
            .map(|m| format!("{}/{}", m.provider, m.id))
            .collect();
        Self { models, names }
    }

    /// Resolves each `"provider/model"` identifier with `get_model_entry`,
    /// preserving the given order.
    ///
    /// Each identifier is split at its first `/`; everything after it
    /// (including further slashes) is passed on as the model id. The stored
    /// name of each link is the identifier exactly as given.
    ///
    /// # Errors
    ///
    /// Stops at the first identifier that fails and returns:
    /// - [`FallbackChainError::InvalidFormat`] if it contains no `/`, or its
    ///   provider or model part is empty;
    /// - [`FallbackChainError::ModelNotFound`] if `get_model_entry` returns `None`.
    pub fn from_ids(ids: &[&str]) -> Result<Self, FallbackChainError> {
        let mut models = Vec::with_capacity(ids.len());
        let mut names = Vec::with_capacity(ids.len());
        for &id in ids {
            let (provider, model_id) = Self::split_id(id)?;
            let entry = get_model_entry(provider, model_id).ok_or_else(|| {
                FallbackChainError::ModelNotFound {
                    id: id.to_string(),
                    provider: provider.to_string(),
                    model_id: model_id.to_string(),
                }
            })?;
            models.push(entry);
            names.push(id.to_string());
        }
        Ok(Self { models, names })
    }

    /// Splits `"provider/model"` at the first `/`, rejecting empty halves so they
    /// are reported as malformed instead of as an unknown model.
    fn split_id(id: &str) -> Result<(&str, &str), FallbackChainError> {
        let invalid = |reason: &str| FallbackChainError::InvalidFormat {
            id: id.to_string(),
            reason: reason.to_string(),
        };
        let (provider, model_id) = id
            .split_once('/')
            .ok_or_else(|| invalid("Expected 'provider/model' format"))?;
        if provider.is_empty() || model_id.is_empty() {
            return Err(invalid("Provider and model must both be non-empty"));
        }
        Ok((provider, model_id))
    }

    /// Returns the model that follows `current` in the chain.
    ///
    /// `current` is matched against [`names`](Self::names). Returns `None` when
    /// `current` is not in the chain or is its last link.
    pub fn next(&self, current: &str) -> Option<&'static ModelEntry> {
        let index = self.index_of(current)?;
        // `index < len`, so `index + 1` cannot overflow; `get` handles the last link.
        self.models.get(index + 1).copied()
    }

    /// Returns the position of the first link whose name equals `model_id`.
    pub fn index_of(&self, model_id: &str) -> Option<usize> {
        self.names.iter().position(|n| n == model_id)
    }

    /// Iterates over the resolved models in fallback order.
    pub fn iter(&self) -> impl Iterator<Item = &'static ModelEntry> + '_ {
        self.models.iter().copied()
    }

    /// Returns `true` if the chain has no links.
    pub fn is_empty(&self) -> bool {
        self.models.is_empty()
    }

    /// Returns the number of links in the chain.
    pub fn len(&self) -> usize {
        self.models.len()
    }

    /// Returns the resolved models in fallback order.
    pub fn models(&self) -> &[&'static ModelEntry] {
        &self.models
    }

    /// Returns the link names, parallel to [`models`](Self::models).
    pub fn names(&self) -> &[String] {
        &self.names
    }

    /// Returns the first model to try, or `None` for an empty chain.
    pub fn first(&self) -> Option<&'static ModelEntry> {
        self.models.first().copied()
    }

    /// Returns the final fallback, or `None` for an empty chain.
    pub fn last(&self) -> Option<&'static ModelEntry> {
        self.models.last().copied()
    }

    /// Returns `true` if some link's name equals `model_id`.
    pub fn contains(&self, model_id: &str) -> bool {
        self.index_of(model_id).is_some()
    }

    /// Returns a new chain starting at `model_id` (inclusive) and running to the end.
    ///
    /// Returns `None` if `model_id` is not in the chain.
    // The public name predates clippy's `from_*` convention; renaming would break callers.
    #[allow(clippy::wrong_self_convention)]
    pub fn from_inclusive(&self, model_id: &str) -> Option<Self> {
        let start = self.index_of(model_id)?;
        Some(self.tail(start))
    }

    /// Returns a new chain of the links strictly after `model_id`.
    ///
    /// Returns `None` if `model_id` is not in the chain or is its last link
    /// (an empty remainder is reported as `None`, not as an empty chain).
    // The public name predates clippy's `from_*` convention; renaming would break callers.
    #[allow(clippy::wrong_self_convention)]
    pub fn from_after(&self, model_id: &str) -> Option<Self> {
        let start = self.index_of(model_id)? + 1;
        (start < self.models.len()).then(|| self.tail(start))
    }

    /// Copies links `start..` into a new chain; callers guarantee `start <= len()`.
    fn tail(&self, start: usize) -> Self {
        Self {
            models: self.models[start..].to_vec(),
            names: self.names[start..].to_vec(),
        }
    }
}
#[derive(Debug, Clone, PartialEq, thiserror::Error)]
pub enum FallbackChainError {
#[error("Invalid model ID format '{id}': {reason}")]
InvalidFormat {
id: String,
reason: String,
},
#[error("Model not found: {provider}/{model_id}")]
ModelNotFound {
id: String,
provider: String,
model_id: String,
},
}
#[cfg(test)]
mod tests {
    //! Unit tests for [`FallbackChain`] against the real model database.
    use super::*;
    use crate::model_db::get_model_entry;

    #[test]
    fn test_from_ids_valid() {
        let chain = FallbackChain::from_ids(&["anthropic/claude-sonnet-4-20250514"]).unwrap();
        assert_eq!(chain.len(), 1);
        assert_eq!(chain.first().unwrap().id, "claude-sonnet-4-20250514");
    }

    #[test]
    fn test_from_ids_multiple() {
        let chain = FallbackChain::from_ids(&[
            "openai/gpt-4o",
            "anthropic/claude-sonnet-4-20250514",
            "google/gemini-2.0-flash",
        ])
        .unwrap();
        assert_eq!(chain.len(), 3);
        assert_eq!(chain.first().unwrap().id, "gpt-4o");
        assert_eq!(chain.last().unwrap().id, "gemini-2.0-flash");
    }

    #[test]
    fn test_from_ids_empty() {
        let chain = FallbackChain::from_ids(&[]).unwrap();
        assert!(chain.is_empty());
        assert_eq!(chain.len(), 0);
        assert_eq!(chain.first(), None);
    }

    #[test]
    fn test_from_ids_invalid_format() {
        let result = FallbackChain::from_ids(&["invalid-no-slash"]);
        assert!(matches!(
            result,
            Err(FallbackChainError::InvalidFormat { .. })
        ));
    }

    #[test]
    fn test_from_ids_not_found() {
        let result = FallbackChain::from_ids(&["nonexistent-provider/nonexistent-model"]);
        assert!(matches!(
            result,
            Err(FallbackChainError::ModelNotFound { .. })
        ));
    }

    #[test]
    fn test_from_ids_stops_at_first_error() {
        // The malformed id comes before the unknown one, so it must be the one reported.
        let result = FallbackChain::from_ids(&[
            "openai/gpt-4o",
            "no-slash",
            "nonexistent-provider/nonexistent-model",
        ]);
        assert!(matches!(
            result,
            Err(FallbackChainError::InvalidFormat { ref id, .. }) if id == "no-slash"
        ));
    }

    #[test]
    fn test_error_display() {
        let not_found = FallbackChainError::ModelNotFound {
            id: "acme/widget".to_string(),
            provider: "acme".to_string(),
            model_id: "widget".to_string(),
        };
        assert_eq!(not_found.to_string(), "Model not found: acme/widget");

        let invalid = FallbackChainError::InvalidFormat {
            id: "bad".to_string(),
            reason: "Expected 'provider/model' format".to_string(),
        };
        assert_eq!(
            invalid.to_string(),
            "Invalid model ID format 'bad': Expected 'provider/model' format"
        );
    }

    #[test]
    fn test_new_direct() {
        let model = get_model_entry("openai", "gpt-4o").unwrap();
        let chain = FallbackChain::new(vec![model]);
        assert_eq!(chain.len(), 1);
        assert_eq!(chain.first().unwrap().id, "gpt-4o");
    }

    #[test]
    fn test_default_chain() {
        let chain = FallbackChain::default();
        assert!(!chain.is_empty());
        assert!(chain.len() >= 3);
        let first = chain.first();
        assert!(first.is_some());
    }

    #[test]
    fn test_next() {
        let chain = FallbackChain::from_ids(&[
            "openai/gpt-4o",
            "anthropic/claude-sonnet-4-20250514",
            "google/gemini-2.0-flash",
        ])
        .unwrap();
        assert_eq!(
            chain.next("openai/gpt-4o").unwrap().id,
            "claude-sonnet-4-20250514"
        );
        assert_eq!(
            chain.next("anthropic/claude-sonnet-4-20250514").unwrap().id,
            "gemini-2.0-flash"
        );
        assert_eq!(chain.next("google/gemini-2.0-flash"), None);
        assert_eq!(chain.next("unknown"), None);
    }

    #[test]
    fn test_index_of() {
        let chain = FallbackChain::from_ids(&[
            "openai/gpt-4o",
            "anthropic/claude-sonnet-4-20250514",
            "google/gemini-2.0-flash",
        ])
        .unwrap();
        assert_eq!(chain.index_of("openai/gpt-4o"), Some(0));
        assert_eq!(
            chain.index_of("anthropic/claude-sonnet-4-20250514"),
            Some(1)
        );
        assert_eq!(chain.index_of("google/gemini-2.0-flash"), Some(2));
        assert_eq!(chain.index_of("unknown"), None);
    }

    #[test]
    fn test_contains() {
        let chain =
            FallbackChain::from_ids(&["openai/gpt-4o", "anthropic/claude-sonnet-4-20250514"])
                .unwrap();
        assert!(chain.contains("openai/gpt-4o"));
        assert!(chain.contains("anthropic/claude-sonnet-4-20250514"));
        assert!(!chain.contains("google/gemini-2.0-flash"));
    }

    #[test]
    fn test_iter() {
        let chain = FallbackChain::from_ids(&[
            "openai/gpt-4o",
            "anthropic/claude-sonnet-4-20250514",
            "google/gemini-2.0-flash",
        ])
        .unwrap();
        let ids: Vec<_> = chain.iter().map(|m| m.id).collect();
        assert_eq!(
            ids,
            vec!["gpt-4o", "claude-sonnet-4-20250514", "gemini-2.0-flash"]
        );
    }

    #[test]
    fn test_is_empty() {
        let empty: FallbackChain = FallbackChain::new(vec![]);
        assert!(empty.is_empty());
        let non_empty = FallbackChain::from_ids(&["openai/gpt-4o"]).unwrap();
        assert!(!non_empty.is_empty());
    }

    #[test]
    fn test_models_and_names() {
        let chain = FallbackChain::from_ids(&["openai/gpt-4o"]).unwrap();
        assert_eq!(chain.models().len(), 1);
        assert_eq!(chain.names(), &["openai/gpt-4o"]);
    }

    #[test]
    fn test_from_inclusive() {
        let chain = FallbackChain::from_ids(&[
            "openai/gpt-4o",
            "anthropic/claude-sonnet-4-20250514",
            "google/gemini-2.0-flash",
        ])
        .unwrap();
        let remaining = chain
            .from_inclusive("anthropic/claude-sonnet-4-20250514")
            .unwrap();
        assert_eq!(
            remaining.names(),
            &[
                "anthropic/claude-sonnet-4-20250514",
                "google/gemini-2.0-flash"
            ]
        );
        assert!(chain.from_inclusive("unknown").is_none());
    }

    #[test]
    fn test_from_inclusive_first_is_whole_chain() {
        let chain =
            FallbackChain::from_ids(&["openai/gpt-4o", "google/gemini-2.0-flash"]).unwrap();
        assert_eq!(chain.from_inclusive("openai/gpt-4o").unwrap(), chain);
    }

    #[test]
    fn test_from_after() {
        let chain = FallbackChain::from_ids(&[
            "openai/gpt-4o",
            "anthropic/claude-sonnet-4-20250514",
            "google/gemini-2.0-flash",
        ])
        .unwrap();
        let remaining = chain
            .from_after("anthropic/claude-sonnet-4-20250514")
            .unwrap();
        assert_eq!(remaining.names(), &["google/gemini-2.0-flash"]);
        assert!(chain.from_after("google/gemini-2.0-flash").is_none());
        assert!(chain.from_after("unknown").is_none());
    }

    #[test]
    fn test_first_last() {
        let chain = FallbackChain::from_ids(&[
            "openai/gpt-4o",
            "anthropic/claude-sonnet-4-20250514",
            "google/gemini-2.0-flash",
        ])
        .unwrap();
        assert_eq!(chain.first().unwrap().id, "gpt-4o");
        assert_eq!(chain.last().unwrap().id, "gemini-2.0-flash");
        let empty: FallbackChain = FallbackChain::new(vec![]);
        assert_eq!(empty.first(), None);
        assert_eq!(empty.last(), None);
    }

    #[test]
    fn test_debug_format() {
        let chain = FallbackChain::from_ids(&["openai/gpt-4o"]).unwrap();
        let debug_str = format!("{:?}", chain);
        assert!(debug_str.contains("FallbackChain"));
    }
}