use serde::{Deserialize, Serialize};
use std::collections::HashMap;
/// Describes model coverage for routing: a primary model plus optional
/// fallback behavior and per-provider options.
///
/// All optional fields are omitted from serialized output when `None`
/// (via `skip_serializing_if`), keeping emitted JSON minimal.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ModelCoverageProfile {
    /// Identifier of the preferred model, e.g. "openai/gpt-4o".
    pub primary: String,
    /// Ordered list of models to try if the primary is unavailable.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub fallbacks: Option<Vec<String>>,
    /// Whether fallbacks are attempted automatically.
    /// NOTE(review): enforcement happens elsewhere (router) — confirm.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub auto_fallback: Option<bool>,
    /// Latency budget in milliseconds; presumably the threshold beyond
    /// which a fallback is considered — verify against the router.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub latency_threshold_ms: Option<u32>,
    /// If true, presumably abort on first failure instead of trying
    /// fallbacks — TODO confirm with the consuming code.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub fail_fast: Option<bool>,
    /// Free-form provider-specific options as arbitrary JSON values.
    /// NOTE(review): key semantics (provider name?) not shown here.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub provider_options: Option<HashMap<String, serde_json::Value>>,
}
/// A coverage profile chosen from a predefined set, or supplied as a
/// fully custom [`ModelCoverageProfile`].
///
/// With serde's default externally-tagged representation plus
/// `rename_all = "snake_case"`, unit variants serialize as plain strings
/// (e.g. `"lowest_latency"`) and `Custom` as an object under the
/// `"custom"` key — see the serialization tests in this file.
// `Clone` added for consistency: `ModelCoverageProfile` (the only
// variant payload) is already `Clone`, and without it a selected
// profile cannot be duplicated by callers.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum PredefinedModelCoverageProfile {
    /// Optimize for latency (variant name states intent; selection
    /// logic lives in the router).
    LowestLatency,
    /// Optimize for cost.
    LowestCost,
    /// Optimize for output quality.
    HighestQuality,
    /// Use the caller-provided coverage profile verbatim.
    Custom(ModelCoverageProfile),
}
/// Top-level routing configuration: the coverage profile to apply plus
/// optional provider preferences.
///
/// NOTE(review): derives `Serialize` only, unlike the neighboring types
/// which are also `Deserialize` — confirm this config is write-only
/// (sent to the API, never parsed back).
#[derive(Debug, Serialize)]
pub struct RouterConfig {
    /// Profile selecting which model(s) requests are routed to.
    pub profile: PredefinedModelCoverageProfile,
    /// Provider-level preferences; omitted from JSON when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub provider_preferences: Option<crate::models::provider_preferences::ProviderPreferences>,
}
/// Namespace for predefined [`ModelCoverageProfile`]s covering common
/// workloads (general chat, code, long context).
pub struct ModelGroups;

impl ModelGroups {
    /// Shared constructor for all predefined groups. Every group in
    /// this file enables auto-fallback, disables fail-fast, and sets no
    /// provider options; only the model list and latency budget vary,
    /// so those are the parameters.
    fn build(primary: &str, fallbacks: &[&str], latency_threshold_ms: u32) -> ModelCoverageProfile {
        ModelCoverageProfile {
            primary: primary.to_string(),
            fallbacks: Some(fallbacks.iter().map(|m| m.to_string()).collect()),
            auto_fallback: Some(true),
            latency_threshold_ms: Some(latency_threshold_ms),
            fail_fast: Some(false),
            provider_options: None,
        }
    }

    /// General-purpose group: GPT-4o primary with Claude and Gemini
    /// fallbacks and a 10s latency budget.
    pub fn general() -> ModelCoverageProfile {
        Self::build(
            "openai/gpt-4o",
            &[
                "anthropic/claude-3-opus-20240229",
                "anthropic/claude-3-sonnet-20240229",
                "google/gemini-1.5-pro",
            ],
            10000,
        )
    }

    /// Code-oriented group: Claude 3 Opus primary with a tighter 8s
    /// latency budget.
    pub fn code() -> ModelCoverageProfile {
        Self::build(
            "anthropic/claude-3-opus-20240229",
            &["openai/gpt-4o", "google/gemini-1.5-pro"],
            8000,
        )
    }

    /// Long-context group: Claude 3 Opus primary with a generous 15s
    /// latency budget for large prompts.
    pub fn long_context() -> ModelCoverageProfile {
        Self::build(
            "anthropic/claude-3-opus-20240229",
            &["google/gemini-1.5-pro", "openai/gpt-4-turbo"],
            15000,
        )
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Every predefined group must name a primary model and at least
    /// one non-empty fallback, with auto-fallback on and fail-fast off.
    #[test]
    fn model_groups_have_non_empty_primary_and_fallbacks() {
        let groups = [
            ModelGroups::general(),
            ModelGroups::code(),
            ModelGroups::long_context(),
        ];
        for p in groups {
            assert!(!p.primary.is_empty(), "primary model must be non-empty",);

            let fb = p.fallbacks.expect("fallbacks should be Some");
            assert!(
                !fb.is_empty(),
                "fallbacks list must contain at least one model",
            );
            assert!(
                fb.iter().all(|name| !name.is_empty()),
                "no fallback model name may be empty",
            );

            assert_eq!(p.auto_fallback, Some(true));
            assert!(p.latency_threshold_ms.is_some_and(|ms| ms > 0));
            assert_eq!(p.fail_fast, Some(false));
        }
    }

    /// Predefined (unit) variants serialize as snake_case strings.
    #[test]
    fn router_config_serializes_predefined_profile_as_snake_case() {
        let config = RouterConfig {
            profile: PredefinedModelCoverageProfile::LowestLatency,
            provider_preferences: None,
        };

        let value = serde_json::to_value(&config).unwrap();
        assert_eq!(value["profile"], "lowest_latency");
    }

    /// The `Custom` variant serializes as an externally tagged object.
    #[test]
    fn router_config_serializes_custom_profile_as_object() {
        let config = RouterConfig {
            profile: PredefinedModelCoverageProfile::Custom(ModelGroups::general()),
            provider_preferences: None,
        };

        let value = serde_json::to_value(&config).unwrap();
        assert!(
            value["profile"]["custom"].is_object(),
            "Custom variant should serialize with nested ModelCoverageProfile object: {}",
            value,
        );
        assert_eq!(value["profile"]["custom"]["primary"], "openai/gpt-4o");
    }
}