//! `llm_link/models/mod.rs` — embedded model catalog for supported LLM providers.

1use serde::{Deserialize, Serialize};
2use anyhow::{Result, anyhow};
3use std::collections::HashMap;
4
/// Metadata for a single LLM model as listed in `models.yaml`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ModelInfo {
    // Provider-specific model identifier (the value used in API requests).
    pub id: String,
    // Human-readable display name.
    pub name: String,
    // Short free-text description of the model.
    pub description: String,
    // Whether the model supports tool/function calling;
    // defaults to `false` when the field is absent from the YAML.
    #[serde(default)]
    pub supports_tools: bool,
    // Maximum context window in tokens; defaults to 4096 when absent
    // (see `default_context_length`).
    #[serde(default = "default_context_length")]
    pub context_length: u32,
}
15
/// Serde default for [`ModelInfo::context_length`]: a conservative
/// 4096-token context window.
fn default_context_length() -> u32 {
    4_096
}
19
/// The set of models exposed by a single provider.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ProviderModels {
    // All models offered by this provider.
    pub models: Vec<ModelInfo>,
}
24
/// Models configuration using HashMap for flexible provider support
/// This allows any provider to be added in models.yaml without code changes
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ModelsConfig {
    // Provider name -> its model list. `flatten` means the YAML top level
    // maps provider names directly to model lists (no wrapper key).
    // NOTE(review): lookups in `get_models_for_provider` lowercase the key,
    // so provider keys are presumably expected to be lowercase — confirm
    // against models.yaml.
    #[serde(flatten)]
    pub providers: HashMap<String, ProviderModels>,
}
32
33impl ModelsConfig {
34    /// Load models configuration from embedded YAML
35    pub fn load_embedded() -> Result<Self> {
36        // Load embedded models.yaml from src/models/models.yaml
37        let content = include_str!("models.yaml");
38
39        let config: ModelsConfig = serde_yaml::from_str(content)
40            .map_err(|e| anyhow!("Failed to parse embedded models config: {}", e))?;
41
42        Ok(config)
43    }
44
45    /// Load models configuration with fallback to default
46    pub fn load_with_fallback() -> Self {
47        // Try to load from embedded YAML first
48        match Self::load_embedded() {
49            Ok(config) => {
50                tracing::info!("✅ Successfully loaded models from embedded YAML");
51                config
52            }
53            Err(e) => {
54                tracing::warn!("⚠️ Failed to load models from YAML, using defaults: {}", e);
55                Self::default()
56            }
57        }
58    }
59
60    /// Get models for a specific provider
61    pub fn get_models_for_provider(&self, provider: &str) -> Vec<ModelInfo> {
62        self.providers
63            .get(&provider.to_lowercase())
64            .map(|p| p.models.clone())
65            .unwrap_or_default()
66    }
67
68    /// Get all provider names
69    #[allow(dead_code)]
70    pub fn get_all_providers(&self) -> Vec<String> {
71        self.providers.keys().cloned().collect()
72    }
73}
74
75impl Default for ModelsConfig {
76    fn default() -> Self {
77        let mut providers = HashMap::new();
78
79        // OpenAI
80        providers.insert("openai".to_string(), ProviderModels {
81            models: vec![
82                ModelInfo {
83                    id: "gpt-4o".to_string(),
84                    name: "GPT-4o".to_string(),
85                    description: "GPT-4 Omni model".to_string(),
86                    supports_tools: true,
87                    context_length: 128000,
88                },
89                ModelInfo {
90                    id: "gpt-4".to_string(),
91                    name: "GPT-4".to_string(),
92                    description: "Most capable GPT-4 model".to_string(),
93                    supports_tools: true,
94                    context_length: 8192,
95                },
96                ModelInfo {
97                    id: "gpt-3.5-turbo".to_string(),
98                    name: "GPT-3.5 Turbo".to_string(),
99                    description: "Fast and efficient model".to_string(),
100                    supports_tools: true,
101                    context_length: 16385,
102                },
103            ],
104        });
105
106        // Anthropic
107        providers.insert("anthropic".to_string(), ProviderModels {
108            models: vec![
109                ModelInfo {
110                    id: "claude-3-5-sonnet-20241022".to_string(),
111                    name: "Claude 3.5 Sonnet".to_string(),
112                    description: "Latest Claude 3.5 Sonnet model".to_string(),
113                    supports_tools: true,
114                    context_length: 200000,
115                },
116                ModelInfo {
117                    id: "claude-3-haiku-20240307".to_string(),
118                    name: "Claude 3 Haiku".to_string(),
119                    description: "Fast Claude 3 model".to_string(),
120                    supports_tools: true,
121                    context_length: 200000,
122                },
123            ],
124        });
125
126        // Zhipu
127        providers.insert("zhipu".to_string(), ProviderModels {
128            models: vec![
129                ModelInfo {
130                    id: "glm-4-flash".to_string(),
131                    name: "GLM-4 Flash".to_string(),
132                    description: "Fast GLM-4 model".to_string(),
133                    supports_tools: true,
134                    context_length: 128000,
135                },
136                ModelInfo {
137                    id: "glm-4".to_string(),
138                    name: "GLM-4".to_string(),
139                    description: "Standard GLM-4 model".to_string(),
140                    supports_tools: true,
141                    context_length: 128000,
142                },
143            ],
144        });
145
146        // Ollama
147        providers.insert("ollama".to_string(), ProviderModels {
148            models: vec![
149                ModelInfo {
150                    id: "llama3.2".to_string(),
151                    name: "Llama 3.2".to_string(),
152                    description: "Latest Llama model".to_string(),
153                    supports_tools: false,
154                    context_length: 128000,
155                },
156                ModelInfo {
157                    id: "llama2".to_string(),
158                    name: "Llama 2".to_string(),
159                    description: "Stable Llama 2 model".to_string(),
160                    supports_tools: false,
161                    context_length: 4096,
162                },
163            ],
164        });
165
166        // Aliyun
167        providers.insert("aliyun".to_string(), ProviderModels {
168            models: vec![
169                ModelInfo {
170                    id: "qwen-turbo".to_string(),
171                    name: "Qwen Turbo".to_string(),
172                    description: "Fast Qwen model".to_string(),
173                    supports_tools: true,
174                    context_length: 8192,
175                },
176                ModelInfo {
177                    id: "qwen-plus".to_string(),
178                    name: "Qwen Plus".to_string(),
179                    description: "Enhanced Qwen model".to_string(),
180                    supports_tools: true,
181                    context_length: 32768,
182                },
183            ],
184        });
185
186        // Volcengine
187        providers.insert("volcengine".to_string(), ProviderModels {
188            models: vec![
189                ModelInfo {
190                    id: "doubao-pro-32k".to_string(),
191                    name: "Doubao Pro".to_string(),
192                    description: "Volcengine Doubao model".to_string(),
193                    supports_tools: true,
194                    context_length: 32768,
195                },
196            ],
197        });
198
199        // Tencent
200        providers.insert("tencent".to_string(), ProviderModels {
201            models: vec![
202                ModelInfo {
203                    id: "hunyuan-lite".to_string(),
204                    name: "Hunyuan Lite".to_string(),
205                    description: "Tencent Hunyuan Lite model".to_string(),
206                    supports_tools: true,
207                    context_length: 256000,
208                },
209            ],
210        });
211
212        // Longcat
213        providers.insert("longcat".to_string(), ProviderModels {
214            models: vec![
215                ModelInfo {
216                    id: "LongCat-Flash-Chat".to_string(),
217                    name: "LongCat Flash Chat".to_string(),
218                    description: "High-performance general dialogue model".to_string(),
219                    supports_tools: false,
220                    context_length: 4096,
221                },
222            ],
223        });
224
225        Self { providers }
226    }
227}