1use serde::{Deserialize, Serialize};
7
8#[derive(Debug, Clone, Serialize, Deserialize)]
31#[serde(untagged)]
32pub enum ModelRef {
33 Shorthand(String),
35 Structured {
37 provider: Provider,
39 model: ModelConfig,
41 #[serde(skip_serializing_if = "Option::is_none")]
43 speed: Option<String>,
44 },
45}
46
47#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
51#[serde(rename_all = "snake_case")]
52pub enum Provider {
53 Gemini,
55 Openai,
57 Anthropic,
59 Ollama,
61 OpenaiCompatible,
63}
64
65#[derive(Debug, Clone, Serialize, Deserialize)]
85#[serde(untagged)]
86pub enum ModelConfig {
87 Name(String),
89 Compatible {
91 model: String,
93 base_url: String,
95 api_key: String,
97 },
98}
99
#[cfg(test)]
mod tests {
    use super::*;

    /// Deserializes a JSON value into a [`ModelRef`], panicking if it does not parse.
    fn parse_ref(value: serde_json::Value) -> ModelRef {
        serde_json::from_value(value).expect("JSON should deserialize into a ModelRef")
    }

    /// Returns the name inside a `Shorthand` reference, panicking on any other variant.
    fn expect_shorthand(model_ref: ModelRef) -> String {
        match model_ref {
            ModelRef::Shorthand(name) => name,
            other => panic!("expected Shorthand variant, got {:?}", other),
        }
    }

    /// Returns the fields of a `Structured` reference, panicking on any other variant.
    fn expect_structured(model_ref: ModelRef) -> (Provider, ModelConfig, Option<String>) {
        match model_ref {
            ModelRef::Structured { provider, model, speed } => (provider, model, speed),
            other => panic!("expected Structured variant, got {:?}", other),
        }
    }

    /// Returns the name inside a `Name` config, panicking on any other variant.
    fn expect_name(config: ModelConfig) -> String {
        match config {
            ModelConfig::Name(name) => name,
            other => panic!("expected Name variant, got {:?}", other),
        }
    }

    /// Returns `(model, base_url, api_key)` of a `Compatible` config, panicking otherwise.
    fn expect_compatible(config: ModelConfig) -> (String, String, String) {
        match config {
            ModelConfig::Compatible { model, base_url, api_key } => (model, base_url, api_key),
            other => panic!("expected Compatible variant, got {:?}", other),
        }
    }

    #[test]
    fn test_shorthand_parses() {
        let model_ref = parse_ref(serde_json::json!("gemini-2.5-flash"));
        assert_eq!(expect_shorthand(model_ref), "gemini-2.5-flash");
    }

    #[test]
    fn test_shorthand_round_trip() {
        let original = ModelRef::Shorthand("gpt-4.1".to_string());
        let json = serde_json::to_value(&original).unwrap();
        assert_eq!(json, serde_json::json!("gpt-4.1"));

        assert_eq!(expect_shorthand(parse_ref(json)), "gpt-4.1");
    }

    #[test]
    fn test_structured_parses() {
        let (provider, model, speed) = expect_structured(parse_ref(serde_json::json!({
            "provider": "openai",
            "model": "gpt-4.1"
        })));
        assert_eq!(provider, Provider::Openai);
        assert_eq!(expect_name(model), "gpt-4.1");
        assert_eq!(speed, None);
    }

    #[test]
    fn test_structured_with_speed() {
        let (provider, model, speed) = expect_structured(parse_ref(serde_json::json!({
            "provider": "gemini",
            "model": "gemini-2.5-flash",
            "speed": "fast"
        })));
        assert_eq!(provider, Provider::Gemini);
        assert_eq!(expect_name(model), "gemini-2.5-flash");
        assert_eq!(speed, Some("fast".to_string()));
    }

    #[test]
    fn test_openai_compatible_with_base_url() {
        let (provider, model, speed) = expect_structured(parse_ref(serde_json::json!({
            "provider": "openai_compatible",
            "model": {
                "model": "deepseek-chat",
                "base_url": "https://api.deepseek.com/v1",
                "api_key": "sk-test-key-123"
            }
        })));
        assert_eq!(provider, Provider::OpenaiCompatible);
        let (model, base_url, api_key) = expect_compatible(model);
        assert_eq!(model, "deepseek-chat");
        assert_eq!(base_url, "https://api.deepseek.com/v1");
        assert_eq!(api_key, "sk-test-key-123");
        assert_eq!(speed, None);
    }

    #[test]
    fn test_provider_serialization() {
        let cases = [
            (Provider::Gemini, "gemini"),
            (Provider::Openai, "openai"),
            (Provider::Anthropic, "anthropic"),
            (Provider::Ollama, "ollama"),
            (Provider::OpenaiCompatible, "openai_compatible"),
        ];
        for &(provider, wire) in &cases {
            assert_eq!(serde_json::to_value(provider).unwrap(), serde_json::json!(wire));
            // The wire name must also deserialize back to the same variant.
            let parsed: Provider = serde_json::from_value(serde_json::json!(wire)).unwrap();
            assert_eq!(parsed, provider);
        }
    }

    #[test]
    fn test_model_config_name() {
        let config: ModelConfig =
            serde_json::from_value(serde_json::json!("claude-3.5-sonnet")).unwrap();
        assert_eq!(expect_name(config), "claude-3.5-sonnet");
    }

    #[test]
    fn test_model_config_compatible() {
        let config: ModelConfig = serde_json::from_value(serde_json::json!({
            "model": "local-llama",
            "base_url": "http://localhost:11434/v1",
            "api_key": "ollama"
        }))
        .unwrap();

        let (model, base_url, api_key) = expect_compatible(config);
        assert_eq!(model, "local-llama");
        assert_eq!(base_url, "http://localhost:11434/v1");
        assert_eq!(api_key, "ollama");
    }

    #[test]
    fn test_structured_speed_omitted_in_serialization() {
        let model_ref = ModelRef::Structured {
            provider: Provider::Anthropic,
            model: ModelConfig::Name("claude-3.5-sonnet".to_string()),
            speed: None,
        };
        let json = serde_json::to_value(&model_ref).unwrap();
        assert!(!json.as_object().unwrap().contains_key("speed"));
    }

    #[test]
    fn test_structured_speed_serialized_when_present() {
        let model_ref = ModelRef::Structured {
            provider: Provider::Gemini,
            model: ModelConfig::Name("gemini-2.5-flash".to_string()),
            speed: Some("fast".to_string()),
        };
        assert_eq!(
            serde_json::to_value(&model_ref).unwrap(),
            serde_json::json!({
                "provider": "gemini",
                "model": "gemini-2.5-flash",
                "speed": "fast"
            })
        );
    }

    #[test]
    fn test_structured_compatible_round_trip() {
        let original = ModelRef::Structured {
            provider: Provider::OpenaiCompatible,
            model: ModelConfig::Compatible {
                model: "local-llama".to_string(),
                base_url: "http://localhost:11434/v1".to_string(),
                api_key: "ollama".to_string(),
            },
            speed: None,
        };
        let json = serde_json::to_value(&original).unwrap();
        assert_eq!(
            json,
            serde_json::json!({
                "provider": "openai_compatible",
                "model": {
                    "model": "local-llama",
                    "base_url": "http://localhost:11434/v1",
                    "api_key": "ollama"
                }
            })
        );

        let (provider, model, speed) = expect_structured(parse_ref(json));
        assert_eq!(provider, Provider::OpenaiCompatible);
        let (model, base_url, api_key) = expect_compatible(model);
        assert_eq!(model, "local-llama");
        assert_eq!(base_url, "http://localhost:11434/v1");
        assert_eq!(api_key, "ollama");
        assert_eq!(speed, None);
    }

    #[test]
    fn test_invalid_model_ref_rejected() {
        // Unknown provider: `Structured` fails and an object cannot be `Shorthand`.
        let unknown_provider = serde_json::json!({ "provider": "mistral", "model": "x" });
        assert!(serde_json::from_value::<ModelRef>(unknown_provider).is_err());

        // Neither a string nor an object.
        assert!(serde_json::from_value::<ModelRef>(serde_json::json!(42)).is_err());
    }

    #[test]
    fn test_compatible_missing_field_rejected() {
        let missing_api_key = serde_json::json!({
            "model": "local-llama",
            "base_url": "http://localhost:11434/v1"
        });
        assert!(serde_json::from_value::<ModelConfig>(missing_api_key).is_err());
    }
}