1use serde::{Deserialize, Serialize};
35
36#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, Serialize, Deserialize)]
41pub enum ModelProvider {
42 OpenRouter,
45
46 Ollama,
49
50 Mlx,
53}
54
55impl std::fmt::Display for ModelProvider {
56 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
57 match self {
58 ModelProvider::OpenRouter => write!(f, "OpenRouter"),
59 ModelProvider::Ollama => write!(f, "Ollama"),
60 ModelProvider::Mlx => write!(f, "MLX"),
61 }
62 }
63}
64
65#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
70pub struct AiModel {
71 pub display_name: String,
74
75 pub identifier: String,
81
82 pub provider: ModelProvider,
84
85 pub is_free: bool,
88
89 pub context_window: u32,
92}
93
94impl AiModel {
95 #[must_use]
104 pub fn available_models() -> Vec<AiModel> {
105 vec![
106 AiModel {
110 display_name: "Devstral 2".to_string(),
111 identifier: "mistralai/devstral-2512:free".to_string(),
112 provider: ModelProvider::OpenRouter,
113 is_free: true,
114 context_window: 262_000,
115 },
116 AiModel {
117 display_name: "Mistral Small 3.1".to_string(),
118 identifier: "mistralai/mistral-small-3.1-24b-instruct:free".to_string(),
119 provider: ModelProvider::OpenRouter,
120 is_free: true,
121 context_window: 128_000,
122 },
123 AiModel {
127 display_name: "Grok Code Fast".to_string(),
128 identifier: "x-ai/grok-code-fast-1".to_string(),
129 provider: ModelProvider::OpenRouter,
130 is_free: false,
131 context_window: 256_000,
132 },
133 AiModel {
134 display_name: "Claude Sonnet 4.5".to_string(),
135 identifier: "anthropic/claude-sonnet-4.5".to_string(),
136 provider: ModelProvider::OpenRouter,
137 is_free: false,
138 context_window: 1_000_000,
139 },
140 AiModel {
144 display_name: "Mistral 7B (Local)".to_string(),
145 identifier: "mistral:7b".to_string(),
146 provider: ModelProvider::Ollama,
147 is_free: true,
148 context_window: 32_000,
149 },
150 ]
151 }
152
153 #[must_use]
167 pub fn default_free() -> AiModel {
168 Self::available_models()
169 .into_iter()
170 .find(|m| m.is_free && m.provider == ModelProvider::OpenRouter)
171 .expect("Registry must contain at least one free OpenRouter model")
172 }
173
174 #[must_use]
187 pub fn for_provider(provider: ModelProvider) -> Vec<AiModel> {
188 Self::available_models()
189 .into_iter()
190 .filter(|m| m.provider == provider)
191 .collect()
192 }
193
194 #[must_use]
217 pub fn find_by_identifier(identifier: &str) -> Option<AiModel> {
218 Self::available_models()
219 .into_iter()
220 .find(|m| m.identifier == identifier)
221 }
222}
223
// Unit tests for the model registry: registry invariants, provider filtering,
// identifier lookup, display names and serde round-trips.
#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::HashSet;

    #[test]
    fn test_available_models_not_empty() {
        let models = AiModel::available_models();
        assert!(
            !models.is_empty(),
            "Registry must contain at least one model"
        );
    }

    #[test]
    fn test_available_models_have_unique_identifiers() {
        let models = AiModel::available_models();
        // HashSet::insert returns false on a repeat, giving an O(n) check.
        let mut seen = HashSet::new();
        for model in &models {
            assert!(
                seen.insert(model.identifier.as_str()),
                "Duplicate identifier: {}",
                model.identifier
            );
        }
    }

    #[test]
    fn test_default_free_is_free() {
        let model = AiModel::default_free();
        assert!(model.is_free, "Default model must be free");
    }

    #[test]
    fn test_default_free_is_openrouter() {
        let model = AiModel::default_free();
        assert_eq!(
            model.provider,
            ModelProvider::OpenRouter,
            "Default model must be from OpenRouter"
        );
    }

    #[test]
    fn test_for_provider_openrouter() {
        let models = AiModel::for_provider(ModelProvider::OpenRouter);
        assert!(!models.is_empty(), "OpenRouter should have models");
        assert!(
            models
                .iter()
                .all(|m| m.provider == ModelProvider::OpenRouter),
            "All returned models should be from OpenRouter"
        );
    }

    #[test]
    fn test_for_provider_ollama() {
        let models = AiModel::for_provider(ModelProvider::Ollama);
        assert!(!models.is_empty(), "Ollama should have models");
        assert!(
            models.iter().all(|m| m.provider == ModelProvider::Ollama),
            "All returned models should be from Ollama"
        );
    }

    #[test]
    fn test_for_provider_mlx_empty() {
        let models = AiModel::for_provider(ModelProvider::Mlx);
        assert!(
            models.is_empty(),
            "MLX should have no models in Phase 1 (reserved for future)"
        );
    }

    #[test]
    fn test_for_provider_partitions_registry() {
        // Every registry entry belongs to exactly one provider.
        let total: usize = [
            ModelProvider::OpenRouter,
            ModelProvider::Ollama,
            ModelProvider::Mlx,
        ]
        .iter()
        .map(|&provider| AiModel::for_provider(provider).len())
        .sum();
        assert_eq!(
            total,
            AiModel::available_models().len(),
            "Per-provider lists should together cover the whole registry"
        );
    }

    #[test]
    fn test_find_by_identifier_devstral() {
        let model = AiModel::find_by_identifier("mistralai/devstral-2512:free")
            .expect("Should find Devstral model");
        assert_eq!(model.display_name, "Devstral 2");
        assert!(model.is_free);
    }

    #[test]
    fn test_find_by_identifier_claude() {
        let model = AiModel::find_by_identifier("anthropic/claude-sonnet-4.5")
            .expect("Should find Claude model");
        assert_eq!(model.display_name, "Claude Sonnet 4.5");
        assert!(!model.is_free);
    }

    #[test]
    fn test_every_model_is_findable_by_identifier() {
        for model in AiModel::available_models() {
            assert_eq!(
                AiModel::find_by_identifier(&model.identifier).as_ref(),
                Some(&model),
                "Lookup should return the registry entry for {}",
                model.identifier
            );
        }
    }

    #[test]
    fn test_find_by_identifier_not_found() {
        let model = AiModel::find_by_identifier("nonexistent/model");
        assert!(model.is_none(), "Should not find nonexistent model");
    }

    #[test]
    fn test_find_by_identifier_case_sensitive() {
        let model = AiModel::find_by_identifier("MISTRALAI/DEVSTRAL-2512:FREE");
        assert!(
            model.is_none(),
            "Identifier lookup should be case-sensitive"
        );
    }

    #[test]
    fn test_model_provider_display() {
        assert_eq!(ModelProvider::OpenRouter.to_string(), "OpenRouter");
        assert_eq!(ModelProvider::Ollama.to_string(), "Ollama");
        assert_eq!(ModelProvider::Mlx.to_string(), "MLX");
    }

    #[test]
    fn test_free_models_have_reasonable_context() {
        let free_models = AiModel::available_models()
            .into_iter()
            .filter(|m| m.is_free)
            .collect::<Vec<_>>();

        assert!(!free_models.is_empty(), "Should have free models");
        for model in free_models {
            assert!(
                model.context_window >= 32_000,
                "Free model {} should have at least 32K context",
                model.display_name
            );
        }
    }

    #[test]
    fn test_paid_models_have_larger_context() {
        let paid_models = AiModel::available_models()
            .into_iter()
            .filter(|m| !m.is_free)
            .collect::<Vec<_>>();

        assert!(!paid_models.is_empty(), "Should have paid models");
        for model in paid_models {
            assert!(
                model.context_window >= 256_000,
                "Paid model {} should have at least 256K context",
                model.display_name
            );
        }
    }

    #[test]
    fn test_model_serialization() {
        // Round-trip every registry entry, not just the default one.
        for model in AiModel::available_models() {
            let json = serde_json::to_string(&model).expect("Should serialize");
            let deserialized: AiModel =
                serde_json::from_str(&json).expect("Should deserialize");
            assert_eq!(model, deserialized);
        }
    }

    #[test]
    fn test_model_provider_serialization() {
        for provider in [
            ModelProvider::OpenRouter,
            ModelProvider::Ollama,
            ModelProvider::Mlx,
        ] {
            let json = serde_json::to_string(&provider).expect("Should serialize");
            let deserialized: ModelProvider =
                serde_json::from_str(&json).expect("Should deserialize");
            assert_eq!(provider, deserialized);
        }
    }
}