1use serde::{Deserialize, Serialize};
2use anyhow::{Result, anyhow};
3use std::collections::HashMap;
4
5#[derive(Debug, Clone, Serialize, Deserialize)]
6pub struct ModelInfo {
7 pub id: String,
8 pub name: String,
9 pub description: String,
10}
11
12#[derive(Debug, Clone, Serialize, Deserialize)]
13pub struct ProviderModels {
14 pub models: Vec<ModelInfo>,
15}
16
17#[derive(Debug, Clone, Serialize, Deserialize)]
20pub struct ModelsConfig {
21 #[serde(flatten)]
22 pub providers: HashMap<String, ProviderModels>,
23}
24
25impl ModelsConfig {
26 pub fn load_embedded() -> Result<Self> {
28 let content = include_str!("models.yaml");
30
31 let config: ModelsConfig = serde_yaml::from_str(content)
32 .map_err(|e| anyhow!("Failed to parse embedded models config: {}", e))?;
33
34 Ok(config)
35 }
36
37 pub fn load_with_fallback() -> Self {
39 match Self::load_embedded() {
41 Ok(config) => {
42 tracing::info!("✅ Successfully loaded models from embedded YAML");
43 config
44 }
45 Err(e) => {
46 tracing::warn!("⚠️ Failed to load models from YAML, using defaults: {}", e);
47 Self::default()
48 }
49 }
50 }
51
52 pub fn get_models_for_provider(&self, provider: &str) -> Vec<ModelInfo> {
54 self.providers
55 .get(&provider.to_lowercase())
56 .map(|p| p.models.clone())
57 .unwrap_or_default()
58 }
59
60 #[allow(dead_code)]
62 pub fn get_all_providers(&self) -> Vec<String> {
63 self.providers.keys().cloned().collect()
64 }
65}
66
67impl Default for ModelsConfig {
68 fn default() -> Self {
69 let mut providers = HashMap::new();
70
71 providers.insert("openai".to_string(), ProviderModels {
73 models: vec![
74 ModelInfo {
75 id: "gpt-4o".to_string(),
76 name: "GPT-4o".to_string(),
77 description: "GPT-4 Omni model".to_string(),
78 },
79 ModelInfo {
80 id: "gpt-4".to_string(),
81 name: "GPT-4".to_string(),
82 description: "Most capable GPT-4 model".to_string(),
83 },
84 ModelInfo {
85 id: "gpt-3.5-turbo".to_string(),
86 name: "GPT-3.5 Turbo".to_string(),
87 description: "Fast and efficient model".to_string(),
88 },
89 ],
90 });
91
92 providers.insert("anthropic".to_string(), ProviderModels {
94 models: vec![
95 ModelInfo {
96 id: "claude-3-5-sonnet-20241022".to_string(),
97 name: "Claude 3.5 Sonnet".to_string(),
98 description: "Latest Claude 3.5 Sonnet model".to_string(),
99 },
100 ModelInfo {
101 id: "claude-3-haiku-20240307".to_string(),
102 name: "Claude 3 Haiku".to_string(),
103 description: "Fast Claude 3 model".to_string(),
104 },
105 ],
106 });
107
108 providers.insert("zhipu".to_string(), ProviderModels {
110 models: vec![
111 ModelInfo {
112 id: "glm-4-flash".to_string(),
113 name: "GLM-4 Flash".to_string(),
114 description: "Fast GLM-4 model".to_string(),
115 },
116 ModelInfo {
117 id: "glm-4".to_string(),
118 name: "GLM-4".to_string(),
119 description: "Standard GLM-4 model".to_string(),
120 },
121 ],
122 });
123
124 providers.insert("ollama".to_string(), ProviderModels {
126 models: vec![
127 ModelInfo {
128 id: "llama3.2".to_string(),
129 name: "Llama 3.2".to_string(),
130 description: "Latest Llama model".to_string(),
131 },
132 ModelInfo {
133 id: "llama2".to_string(),
134 name: "Llama 2".to_string(),
135 description: "Stable Llama 2 model".to_string(),
136 },
137 ],
138 });
139
140 providers.insert("aliyun".to_string(), ProviderModels {
142 models: vec![
143 ModelInfo {
144 id: "qwen-turbo".to_string(),
145 name: "Qwen Turbo".to_string(),
146 description: "Fast Qwen model".to_string(),
147 },
148 ModelInfo {
149 id: "qwen-plus".to_string(),
150 name: "Qwen Plus".to_string(),
151 description: "Enhanced Qwen model".to_string(),
152 },
153 ],
154 });
155
156 providers.insert("volcengine".to_string(), ProviderModels {
158 models: vec![
159 ModelInfo {
160 id: "doubao-pro-32k".to_string(),
161 name: "Doubao Pro".to_string(),
162 description: "Volcengine Doubao model".to_string(),
163 },
164 ],
165 });
166
167 providers.insert("tencent".to_string(), ProviderModels {
169 models: vec![
170 ModelInfo {
171 id: "hunyuan-lite".to_string(),
172 name: "Hunyuan Lite".to_string(),
173 description: "Tencent Hunyuan Lite model".to_string(),
174 },
175 ],
176 });
177
178 providers.insert("longcat".to_string(), ProviderModels {
180 models: vec![
181 ModelInfo {
182 id: "LongCat-Flash-Chat".to_string(),
183 name: "LongCat Flash Chat".to_string(),
184 description: "High-performance general dialogue model".to_string(),
185 },
186 ],
187 });
188
189 Self { providers }
190 }
191}