llm_link/models/mod.rs

1use serde::{Deserialize, Serialize};
2use anyhow::{Result, anyhow};
3use std::collections::HashMap;
4
5#[derive(Debug, Clone, Serialize, Deserialize)]
6pub struct ModelInfo {
7    pub id: String,
8    pub name: String,
9    pub description: String,
10}
11
12#[derive(Debug, Clone, Serialize, Deserialize)]
13pub struct ProviderModels {
14    pub models: Vec<ModelInfo>,
15}
16
17/// Models configuration using HashMap for flexible provider support
18/// This allows any provider to be added in models.yaml without code changes
19#[derive(Debug, Clone, Serialize, Deserialize)]
20pub struct ModelsConfig {
21    #[serde(flatten)]
22    pub providers: HashMap<String, ProviderModels>,
23}
24
25impl ModelsConfig {
26    /// Load models configuration from embedded YAML
27    pub fn load_embedded() -> Result<Self> {
28        // Load embedded models.yaml from src/models/models.yaml
29        let content = include_str!("models.yaml");
30
31        let config: ModelsConfig = serde_yaml::from_str(content)
32            .map_err(|e| anyhow!("Failed to parse embedded models config: {}", e))?;
33
34        Ok(config)
35    }
36
37    /// Load models configuration with fallback to default
38    pub fn load_with_fallback() -> Self {
39        // Try to load from embedded YAML first
40        match Self::load_embedded() {
41            Ok(config) => {
42                tracing::info!("✅ Successfully loaded models from embedded YAML");
43                config
44            }
45            Err(e) => {
46                tracing::warn!("⚠️ Failed to load models from YAML, using defaults: {}", e);
47                Self::default()
48            }
49        }
50    }
51
52    /// Get models for a specific provider
53    pub fn get_models_for_provider(&self, provider: &str) -> Vec<ModelInfo> {
54        self.providers
55            .get(&provider.to_lowercase())
56            .map(|p| p.models.clone())
57            .unwrap_or_default()
58    }
59
60    /// Get all provider names
61    #[allow(dead_code)]
62    pub fn get_all_providers(&self) -> Vec<String> {
63        self.providers.keys().cloned().collect()
64    }
65}
66
67impl Default for ModelsConfig {
68    fn default() -> Self {
69        let mut providers = HashMap::new();
70
71        // OpenAI
72        providers.insert("openai".to_string(), ProviderModels {
73            models: vec![
74                ModelInfo {
75                    id: "gpt-4o".to_string(),
76                    name: "GPT-4o".to_string(),
77                    description: "GPT-4 Omni model".to_string(),
78                },
79                ModelInfo {
80                    id: "gpt-4".to_string(),
81                    name: "GPT-4".to_string(),
82                    description: "Most capable GPT-4 model".to_string(),
83                },
84                ModelInfo {
85                    id: "gpt-3.5-turbo".to_string(),
86                    name: "GPT-3.5 Turbo".to_string(),
87                    description: "Fast and efficient model".to_string(),
88                },
89            ],
90        });
91
92        // Anthropic
93        providers.insert("anthropic".to_string(), ProviderModels {
94            models: vec![
95                ModelInfo {
96                    id: "claude-3-5-sonnet-20241022".to_string(),
97                    name: "Claude 3.5 Sonnet".to_string(),
98                    description: "Latest Claude 3.5 Sonnet model".to_string(),
99                },
100                ModelInfo {
101                    id: "claude-3-haiku-20240307".to_string(),
102                    name: "Claude 3 Haiku".to_string(),
103                    description: "Fast Claude 3 model".to_string(),
104                },
105            ],
106        });
107
108        // Zhipu
109        providers.insert("zhipu".to_string(), ProviderModels {
110            models: vec![
111                ModelInfo {
112                    id: "glm-4-flash".to_string(),
113                    name: "GLM-4 Flash".to_string(),
114                    description: "Fast GLM-4 model".to_string(),
115                },
116                ModelInfo {
117                    id: "glm-4".to_string(),
118                    name: "GLM-4".to_string(),
119                    description: "Standard GLM-4 model".to_string(),
120                },
121            ],
122        });
123
124        // Ollama
125        providers.insert("ollama".to_string(), ProviderModels {
126            models: vec![
127                ModelInfo {
128                    id: "llama3.2".to_string(),
129                    name: "Llama 3.2".to_string(),
130                    description: "Latest Llama model".to_string(),
131                },
132                ModelInfo {
133                    id: "llama2".to_string(),
134                    name: "Llama 2".to_string(),
135                    description: "Stable Llama 2 model".to_string(),
136                },
137            ],
138        });
139
140        // Aliyun
141        providers.insert("aliyun".to_string(), ProviderModels {
142            models: vec![
143                ModelInfo {
144                    id: "qwen-turbo".to_string(),
145                    name: "Qwen Turbo".to_string(),
146                    description: "Fast Qwen model".to_string(),
147                },
148                ModelInfo {
149                    id: "qwen-plus".to_string(),
150                    name: "Qwen Plus".to_string(),
151                    description: "Enhanced Qwen model".to_string(),
152                },
153            ],
154        });
155
156        // Volcengine
157        providers.insert("volcengine".to_string(), ProviderModels {
158            models: vec![
159                ModelInfo {
160                    id: "doubao-pro-32k".to_string(),
161                    name: "Doubao Pro".to_string(),
162                    description: "Volcengine Doubao model".to_string(),
163                },
164            ],
165        });
166
167        // Tencent
168        providers.insert("tencent".to_string(), ProviderModels {
169            models: vec![
170                ModelInfo {
171                    id: "hunyuan-lite".to_string(),
172                    name: "Hunyuan Lite".to_string(),
173                    description: "Tencent Hunyuan Lite model".to_string(),
174                },
175            ],
176        });
177
178        // Longcat
179        providers.insert("longcat".to_string(), ProviderModels {
180            models: vec![
181                ModelInfo {
182                    id: "LongCat-Flash-Chat".to_string(),
183                    name: "LongCat Flash Chat".to_string(),
184                    description: "High-performance general dialogue model".to_string(),
185                },
186            ],
187        });
188
189        Self { providers }
190    }
191}