1use serde::{Deserialize, Serialize};
2use anyhow::{Result, anyhow};
3
4#[derive(Debug, Clone, Serialize, Deserialize)]
5pub struct ModelInfo {
6 pub id: String,
7 pub name: String,
8 pub description: String,
9}
10
11#[derive(Debug, Clone, Serialize, Deserialize)]
12pub struct ProviderModels {
13 pub models: Vec<ModelInfo>,
14}
15
16#[derive(Debug, Clone, Serialize, Deserialize)]
17pub struct ModelsConfig {
18 pub openai: ProviderModels,
19 pub anthropic: ProviderModels,
20 pub zhipu: ProviderModels,
21 pub ollama: ProviderModels,
22 pub aliyun: ProviderModels,
23}
24
25impl ModelsConfig {
26 pub fn load_embedded() -> Result<Self> {
28 let content = include_str!("models.yaml");
30
31 let config: ModelsConfig = serde_yaml::from_str(content)
32 .map_err(|e| anyhow!("Failed to parse embedded models config: {}", e))?;
33
34 Ok(config)
35 }
36
37 pub fn load_with_fallback() -> Self {
39 if let Ok(config) = Self::load_embedded() {
41 return config;
42 }
43
44 Self::default()
46 }
47
48 pub fn get_models_for_provider(&self, provider: &str) -> Vec<ModelInfo> {
50 match provider.to_lowercase().as_str() {
51 "openai" => self.openai.models.clone(),
52 "anthropic" => self.anthropic.models.clone(),
53 "zhipu" => self.zhipu.models.clone(),
54 "ollama" => self.ollama.models.clone(),
55 "aliyun" => self.aliyun.models.clone(),
56 _ => vec![],
57 }
58 }
59}
60
61impl Default for ModelsConfig {
62 fn default() -> Self {
63 Self {
64 openai: ProviderModels {
65 models: vec![
66 ModelInfo {
67 id: "gpt-4o".to_string(),
68 name: "GPT-4o".to_string(),
69 description: "GPT-4 Omni model".to_string(),
70 },
71 ModelInfo {
72 id: "gpt-4".to_string(),
73 name: "GPT-4".to_string(),
74 description: "Most capable GPT-4 model".to_string(),
75 },
76 ModelInfo {
77 id: "gpt-3.5-turbo".to_string(),
78 name: "GPT-3.5 Turbo".to_string(),
79 description: "Fast and efficient model".to_string(),
80 },
81 ],
82 },
83 anthropic: ProviderModels {
84 models: vec![
85 ModelInfo {
86 id: "claude-3-5-sonnet-20241022".to_string(),
87 name: "Claude 3.5 Sonnet".to_string(),
88 description: "Latest Claude 3.5 Sonnet model".to_string(),
89 },
90 ModelInfo {
91 id: "claude-3-haiku-20240307".to_string(),
92 name: "Claude 3 Haiku".to_string(),
93 description: "Fast Claude 3 model".to_string(),
94 },
95 ],
96 },
97 zhipu: ProviderModels {
98 models: vec![
99 ModelInfo {
100 id: "glm-4-flash".to_string(),
101 name: "GLM-4 Flash".to_string(),
102 description: "Fast GLM-4 model".to_string(),
103 },
104 ModelInfo {
105 id: "glm-4".to_string(),
106 name: "GLM-4".to_string(),
107 description: "Standard GLM-4 model".to_string(),
108 },
109 ],
110 },
111 ollama: ProviderModels {
112 models: vec![
113 ModelInfo {
114 id: "llama3.2".to_string(),
115 name: "Llama 3.2".to_string(),
116 description: "Latest Llama model".to_string(),
117 },
118 ModelInfo {
119 id: "llama2".to_string(),
120 name: "Llama 2".to_string(),
121 description: "Stable Llama 2 model".to_string(),
122 },
123 ModelInfo {
124 id: "codellama".to_string(),
125 name: "Code Llama".to_string(),
126 description: "Code-specialized model".to_string(),
127 },
128 ModelInfo {
129 id: "mistral".to_string(),
130 name: "Mistral".to_string(),
131 description: "Mistral 7B model".to_string(),
132 },
133 ],
134 },
135 aliyun: ProviderModels {
136 models: vec![
137 ModelInfo {
138 id: "qwen-turbo".to_string(),
139 name: "Qwen Turbo".to_string(),
140 description: "Fast Qwen model".to_string(),
141 },
142 ModelInfo {
143 id: "qwen-plus".to_string(),
144 name: "Qwen Plus".to_string(),
145 description: "Enhanced Qwen model".to_string(),
146 },
147 ],
148 },
149 }
150 }
151}