use serde::{Deserialize, Serialize};
use std::collections::HashMap;
/// A routing profile: one primary model plus optional fallback settings.
///
/// Every optional field that is `None` is left out of the serialized output
/// entirely (via `skip_serializing_if`) instead of being written as a null.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ModelCoverageProfile {
    /// Identifier of the primary model. The presets on `ModelGroups` use
    /// `vendor/model` strings such as `"openai/gpt-4o"`.
    pub primary: String,

    /// Identifiers of fallback models.
    // NOTE(review): presumably tried in list order when the primary is
    // unavailable; confirm against the consumer of this config.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub fallbacks: Option<Vec<String>>,

    /// Flag controlling automatic fallback.
    // NOTE(review): the exact trigger for a fallback is not defined in this
    // file; verify with the router implementation.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub auto_fallback: Option<bool>,

    /// Latency threshold in milliseconds.
    // NOTE(review): what happens once the threshold is exceeded is decided
    // by the consumer, not here.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub latency_threshold_ms: Option<u32>,

    /// Fail-fast flag.
    // NOTE(review): assumed to mean "error out instead of trying fallbacks";
    // confirm before relying on it.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub fail_fast: Option<bool>,

    /// Free-form JSON values keyed by string.
    // NOTE(review): keys are presumably provider names or option names;
    // the schema is not visible in this file.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub provider_options: Option<HashMap<String, serde_json::Value>>,
}
/// Profile selection: one of three named strategies, or a fully specified
/// custom [`ModelCoverageProfile`].
///
/// Variant names are renamed to `snake_case` for (de)serialization, so with
/// serde's default externally tagged representation the unit variants appear
/// in JSON as strings (e.g. `"lowest_latency"`) and the custom variant as
/// `{"custom": { ... }}`.
///
/// `Clone` is derived so that values of this enum (and configs that embed
/// them) can be duplicated; the only payload type, `ModelCoverageProfile`,
/// is itself `Clone`.
// NOTE(review): how the three named strategies are interpreted is decided by
// whatever consumes this value; nothing in this file defines their behaviour.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum PredefinedModelCoverageProfile {
    /// Named strategy, serialized as `lowest_latency`.
    LowestLatency,
    /// Named strategy, serialized as `lowest_cost`.
    LowestCost,
    /// Named strategy, serialized as `highest_quality`.
    HighestQuality,
    /// Caller-supplied profile, serialized under the `custom` tag.
    Custom(ModelCoverageProfile),
}
/// Router configuration: the chosen coverage profile plus optional provider
/// preferences.
///
/// Only `Serialize` is derived, so this type can be written out but not
/// parsed back through serde.
#[derive(Debug, Serialize)]
pub struct RouterConfig {
    /// Which coverage profile to use.
    pub profile: PredefinedModelCoverageProfile,

    /// Provider preferences; left out of the serialized output when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub provider_preferences: Option<crate::types::provider::ProviderPreferences>,
}
/// Namespace type for ready-made [`ModelCoverageProfile`] presets.
///
/// It carries no data; the presets are exposed as associated functions.
pub struct ModelGroups;

impl ModelGroups {
    /// The `general` preset: `openai/gpt-4o` as primary, three fallbacks,
    /// and a 10 000 ms latency threshold.
    pub fn general() -> ModelCoverageProfile {
        Self::preset(
            "openai/gpt-4o",
            &[
                "anthropic/claude-3-opus-20240229",
                "anthropic/claude-3-sonnet-20240229",
                "google/gemini-1.5-pro",
            ],
            10_000,
        )
    }

    /// The `code` preset: `anthropic/claude-3-opus-20240229` as primary, two
    /// fallbacks, and an 8 000 ms latency threshold.
    pub fn code() -> ModelCoverageProfile {
        Self::preset(
            "anthropic/claude-3-opus-20240229",
            &["openai/gpt-4o", "google/gemini-1.5-pro"],
            8_000,
        )
    }

    /// The `long_context` preset: `anthropic/claude-3-opus-20240229` as
    /// primary, two fallbacks, and a 15 000 ms latency threshold.
    pub fn long_context() -> ModelCoverageProfile {
        Self::preset(
            "anthropic/claude-3-opus-20240229",
            &["google/gemini-1.5-pro", "openai/gpt-4-turbo"],
            15_000,
        )
    }

    /// Builds a preset from the three values that differ between presets.
    ///
    /// All presets share the remaining settings: `auto_fallback` is
    /// `Some(true)`, `fail_fast` is `Some(false)`, and no provider options
    /// are set. `fallbacks` keeps the order of the given slice.
    fn preset(
        primary: &str,
        fallbacks: &[&str],
        latency_threshold_ms: u32,
    ) -> ModelCoverageProfile {
        ModelCoverageProfile {
            primary: primary.to_owned(),
            fallbacks: Some(fallbacks.iter().map(|&id| id.to_owned()).collect()),
            auto_fallback: Some(true),
            latency_threshold_ms: Some(latency_threshold_ms),
            fail_fast: Some(false),
            provider_options: None,
        }
    }
}