1use serde::{Deserialize, Serialize};
2use anyhow::{Result, anyhow};
3use std::collections::HashMap;
4
/// Metadata describing a single model offered by a provider.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ModelInfo {
    // Provider-side model identifier (the value used in API requests,
    // e.g. "gpt-4o" — see the defaults in `ModelsConfig::default`).
    pub id: String,
    // Human-readable display name.
    pub name: String,
    // Short, human-readable description of the model.
    pub description: String,
    // Whether the model supports tool/function calling.
    // Defaults to `false` when the field is absent from the YAML.
    #[serde(default)]
    pub supports_tools: bool,
    // Maximum context window size (presumably in tokens — matches the
    // magnitudes used in `ModelsConfig::default`). Defaults to 4096
    // when absent from the YAML (see `default_context_length`).
    #[serde(default = "default_context_length")]
    pub context_length: u32,
}
15
/// Serde default for [`ModelInfo::context_length`]: a conservative
/// 4096-token context window, used when the YAML omits the field.
fn default_context_length() -> u32 {
    4_096
}
19
/// The set of models exposed by a single provider.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ProviderModels {
    // All models available from this provider.
    pub models: Vec<ModelInfo>,
}
24
/// Top-level model catalog: maps a provider name to its models.
/// Keys are expected to be lowercase (lookups in
/// `get_models_for_provider` lowercase the query, and the built-in
/// defaults insert lowercase keys).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ModelsConfig {
    // `flatten` means the YAML places provider names directly at the
    // document root rather than under a `providers:` key.
    #[serde(flatten)]
    pub providers: HashMap<String, ProviderModels>,
}
32
33impl ModelsConfig {
34 pub fn load_embedded() -> Result<Self> {
36 let content = include_str!("models.yaml");
38
39 let config: ModelsConfig = serde_yaml::from_str(content)
40 .map_err(|e| anyhow!("Failed to parse embedded models config: {}", e))?;
41
42 Ok(config)
43 }
44
45 pub fn load_with_fallback() -> Self {
47 match Self::load_embedded() {
49 Ok(config) => {
50 tracing::info!("✅ Successfully loaded models from embedded YAML");
51 config
52 }
53 Err(e) => {
54 tracing::warn!("⚠️ Failed to load models from YAML, using defaults: {}", e);
55 Self::default()
56 }
57 }
58 }
59
60 pub fn get_models_for_provider(&self, provider: &str) -> Vec<ModelInfo> {
62 self.providers
63 .get(&provider.to_lowercase())
64 .map(|p| p.models.clone())
65 .unwrap_or_default()
66 }
67
68 #[allow(dead_code)]
70 pub fn get_all_providers(&self) -> Vec<String> {
71 self.providers.keys().cloned().collect()
72 }
73}
74
75impl Default for ModelsConfig {
76 fn default() -> Self {
77 let mut providers = HashMap::new();
78
79 providers.insert("openai".to_string(), ProviderModels {
81 models: vec![
82 ModelInfo {
83 id: "gpt-4o".to_string(),
84 name: "GPT-4o".to_string(),
85 description: "GPT-4 Omni model".to_string(),
86 supports_tools: true,
87 context_length: 128000,
88 },
89 ModelInfo {
90 id: "gpt-4".to_string(),
91 name: "GPT-4".to_string(),
92 description: "Most capable GPT-4 model".to_string(),
93 supports_tools: true,
94 context_length: 8192,
95 },
96 ModelInfo {
97 id: "gpt-3.5-turbo".to_string(),
98 name: "GPT-3.5 Turbo".to_string(),
99 description: "Fast and efficient model".to_string(),
100 supports_tools: true,
101 context_length: 16385,
102 },
103 ],
104 });
105
106 providers.insert("anthropic".to_string(), ProviderModels {
108 models: vec![
109 ModelInfo {
110 id: "claude-3-5-sonnet-20241022".to_string(),
111 name: "Claude 3.5 Sonnet".to_string(),
112 description: "Latest Claude 3.5 Sonnet model".to_string(),
113 supports_tools: true,
114 context_length: 200000,
115 },
116 ModelInfo {
117 id: "claude-3-haiku-20240307".to_string(),
118 name: "Claude 3 Haiku".to_string(),
119 description: "Fast Claude 3 model".to_string(),
120 supports_tools: true,
121 context_length: 200000,
122 },
123 ],
124 });
125
126 providers.insert("zhipu".to_string(), ProviderModels {
128 models: vec![
129 ModelInfo {
130 id: "glm-4-flash".to_string(),
131 name: "GLM-4 Flash".to_string(),
132 description: "Fast GLM-4 model".to_string(),
133 supports_tools: true,
134 context_length: 128000,
135 },
136 ModelInfo {
137 id: "glm-4".to_string(),
138 name: "GLM-4".to_string(),
139 description: "Standard GLM-4 model".to_string(),
140 supports_tools: true,
141 context_length: 128000,
142 },
143 ],
144 });
145
146 providers.insert("ollama".to_string(), ProviderModels {
148 models: vec![
149 ModelInfo {
150 id: "llama3.2".to_string(),
151 name: "Llama 3.2".to_string(),
152 description: "Latest Llama model".to_string(),
153 supports_tools: false,
154 context_length: 128000,
155 },
156 ModelInfo {
157 id: "llama2".to_string(),
158 name: "Llama 2".to_string(),
159 description: "Stable Llama 2 model".to_string(),
160 supports_tools: false,
161 context_length: 4096,
162 },
163 ],
164 });
165
166 providers.insert("aliyun".to_string(), ProviderModels {
168 models: vec![
169 ModelInfo {
170 id: "qwen-turbo".to_string(),
171 name: "Qwen Turbo".to_string(),
172 description: "Fast Qwen model".to_string(),
173 supports_tools: true,
174 context_length: 8192,
175 },
176 ModelInfo {
177 id: "qwen-plus".to_string(),
178 name: "Qwen Plus".to_string(),
179 description: "Enhanced Qwen model".to_string(),
180 supports_tools: true,
181 context_length: 32768,
182 },
183 ],
184 });
185
186 providers.insert("volcengine".to_string(), ProviderModels {
188 models: vec![
189 ModelInfo {
190 id: "doubao-pro-32k".to_string(),
191 name: "Doubao Pro".to_string(),
192 description: "Volcengine Doubao model".to_string(),
193 supports_tools: true,
194 context_length: 32768,
195 },
196 ],
197 });
198
199 providers.insert("tencent".to_string(), ProviderModels {
201 models: vec![
202 ModelInfo {
203 id: "hunyuan-lite".to_string(),
204 name: "Hunyuan Lite".to_string(),
205 description: "Tencent Hunyuan Lite model".to_string(),
206 supports_tools: true,
207 context_length: 256000,
208 },
209 ],
210 });
211
212 providers.insert("longcat".to_string(), ProviderModels {
214 models: vec![
215 ModelInfo {
216 id: "LongCat-Flash-Chat".to_string(),
217 name: "LongCat Flash Chat".to_string(),
218 description: "High-performance general dialogue model".to_string(),
219 supports_tools: false,
220 context_length: 4096,
221 },
222 ],
223 });
224
225 Self { providers }
226 }
227}