1use serde::{Deserialize, Serialize};
23use std::collections::HashMap;
24
25use crate::error::{AiError, Result};
26
27#[derive(Debug, Clone, Serialize, Deserialize)]
29pub struct ModelMetrics {
30 pub accuracy: f64,
32 pub avg_tokens: u32,
34 pub avg_cost: f64,
36 pub total_requests: u64,
38 pub total_errors: u64,
40 pub avg_latency_ms: f64,
42}
43
44impl ModelMetrics {
45 #[must_use]
47 pub fn new(accuracy: f64, avg_tokens: u32, avg_cost: f64) -> Self {
48 Self {
49 accuracy: accuracy.clamp(0.0, 100.0),
50 avg_tokens,
51 avg_cost,
52 total_requests: 0,
53 total_errors: 0,
54 avg_latency_ms: 0.0,
55 }
56 }
57
58 #[must_use]
60 pub fn empty() -> Self {
61 Self {
62 accuracy: 0.0,
63 avg_tokens: 0,
64 avg_cost: 0.0,
65 total_requests: 0,
66 total_errors: 0,
67 avg_latency_ms: 0.0,
68 }
69 }
70
71 pub fn record_request(&mut self, tokens: u32, cost: f64, latency_ms: f64, success: bool) {
73 let n = self.total_requests as f64;
74
75 self.avg_tokens = ((f64::from(self.avg_tokens) * n + f64::from(tokens)) / (n + 1.0)) as u32;
77 self.avg_cost = (self.avg_cost * n + cost) / (n + 1.0);
78 self.avg_latency_ms = (self.avg_latency_ms * n + latency_ms) / (n + 1.0);
79
80 self.total_requests += 1;
81 if !success {
82 self.total_errors += 1;
83 }
84
85 let success_count = self.total_requests - self.total_errors;
87 self.accuracy = (success_count as f64 / self.total_requests as f64) * 100.0;
88 }
89
90 #[must_use]
92 pub fn error_rate(&self) -> f64 {
93 if self.total_requests == 0 {
94 0.0
95 } else {
96 (self.total_errors as f64 / self.total_requests as f64) * 100.0
97 }
98 }
99
100 #[must_use]
102 pub fn cost_effectiveness(&self) -> f64 {
103 if self.avg_cost == 0.0 {
104 0.0
105 } else {
106 self.accuracy / self.avg_cost
107 }
108 }
109}
110
111#[derive(Debug, Clone, Serialize, Deserialize)]
113pub struct ModelVersion {
114 pub model_name: String,
116 pub version: String,
118 pub description: String,
120 pub metrics: ModelMetrics,
122 pub release_date: chrono::DateTime<chrono::Utc>,
124 pub deprecated_at: Option<chrono::DateTime<chrono::Utc>>,
126 pub active: bool,
128 pub tags: Vec<String>,
130}
131
132impl ModelVersion {
133 pub fn new(
135 model_name: impl Into<String>,
136 version: impl Into<String>,
137 description: impl Into<String>,
138 ) -> Self {
139 Self {
140 model_name: model_name.into(),
141 version: version.into(),
142 description: description.into(),
143 metrics: ModelMetrics::empty(),
144 release_date: chrono::Utc::now(),
145 deprecated_at: None,
146 active: true,
147 tags: Vec::new(),
148 }
149 }
150
151 #[must_use]
153 pub fn with_tags(mut self, tags: Vec<String>) -> Self {
154 self.tags = tags;
155 self
156 }
157
158 #[must_use]
160 pub fn with_release_date(mut self, date: chrono::DateTime<chrono::Utc>) -> Self {
161 self.release_date = date;
162 self
163 }
164
165 #[must_use]
167 pub fn with_metrics(mut self, metrics: ModelMetrics) -> Self {
168 self.metrics = metrics;
169 self
170 }
171
172 pub fn deprecate(&mut self) {
174 self.deprecated_at = Some(chrono::Utc::now());
175 self.active = false;
176 }
177
178 #[must_use]
180 pub fn is_deprecated(&self) -> bool {
181 self.deprecated_at.is_some()
182 }
183
184 #[must_use]
186 pub fn id(&self) -> String {
187 format!("{}:{}", self.model_name, self.version)
188 }
189}
190
191#[derive(Debug, Clone, Serialize, Deserialize)]
193pub struct ModelRegistry {
194 versions: HashMap<String, ModelVersion>,
195 active_versions: HashMap<String, String>, }
197
198impl Default for ModelRegistry {
199 fn default() -> Self {
200 Self::new()
201 }
202}
203
204impl ModelRegistry {
205 #[must_use]
207 pub fn new() -> Self {
208 Self {
209 versions: HashMap::new(),
210 active_versions: HashMap::new(),
211 }
212 }
213
214 pub fn register_version(&mut self, version: ModelVersion) -> Result<()> {
216 let id = version.id();
217 let model_name = version.model_name.clone();
218
219 if version.active {
221 self.active_versions
222 .insert(model_name, version.version.clone());
223 }
224
225 self.versions.insert(id, version);
226 Ok(())
227 }
228
229 #[must_use]
231 pub fn get_version(&self, model_name: &str, version: &str) -> Option<&ModelVersion> {
232 let id = format!("{model_name}:{version}");
233 self.versions.get(&id)
234 }
235
236 pub fn get_version_mut(
238 &mut self,
239 model_name: &str,
240 version: &str,
241 ) -> Option<&mut ModelVersion> {
242 let id = format!("{model_name}:{version}");
243 self.versions.get_mut(&id)
244 }
245
246 #[must_use]
248 pub fn get_active_version(&self, model_name: &str) -> Option<&ModelVersion> {
249 let version = self.active_versions.get(model_name)?;
250 self.get_version(model_name, version)
251 }
252
253 pub fn set_active_version(&mut self, model_name: &str, version: &str) -> Result<()> {
255 let id = format!("{model_name}:{version}");
256
257 if !self.versions.contains_key(&id) {
258 return Err(AiError::NotFound(format!("Model version {id} not found")));
259 }
260
261 self.active_versions
262 .insert(model_name.to_string(), version.to_string());
263 Ok(())
264 }
265
266 #[must_use]
268 pub fn get_model_versions(&self, model_name: &str) -> Vec<&ModelVersion> {
269 self.versions
270 .values()
271 .filter(|v| v.model_name == model_name)
272 .collect()
273 }
274
275 pub fn update_metrics(
277 &mut self,
278 model_name: &str,
279 version: &str,
280 metrics: ModelMetrics,
281 ) -> Result<()> {
282 let id = format!("{model_name}:{version}");
283
284 let model = self
285 .versions
286 .get_mut(&id)
287 .ok_or_else(|| AiError::NotFound(format!("Model version {id} not found")))?;
288
289 model.metrics = metrics;
290 Ok(())
291 }
292
293 pub fn record_request(
295 &mut self,
296 model_name: &str,
297 version: &str,
298 tokens: u32,
299 cost: f64,
300 latency_ms: f64,
301 success: bool,
302 ) -> Result<()> {
303 let id = format!("{model_name}:{version}");
304
305 let model = self
306 .versions
307 .get_mut(&id)
308 .ok_or_else(|| AiError::NotFound(format!("Model version {id} not found")))?;
309
310 model
311 .metrics
312 .record_request(tokens, cost, latency_ms, success);
313 Ok(())
314 }
315
316 #[must_use]
318 pub fn compare_versions(
319 &self,
320 model1: &str,
321 version1: &str,
322 model2: &str,
323 version2: &str,
324 ) -> Option<VersionComparison> {
325 let v1 = self.get_version(model1, version1)?;
326 let v2 = self.get_version(model2, version2)?;
327
328 Some(VersionComparison {
329 version1: v1.clone(),
330 version2: v2.clone(),
331 accuracy_diff: v1.metrics.accuracy - v2.metrics.accuracy,
332 cost_diff: v1.metrics.avg_cost - v2.metrics.avg_cost,
333 latency_diff: v1.metrics.avg_latency_ms - v2.metrics.avg_latency_ms,
334 })
335 }
336
337 #[must_use]
339 pub fn get_best_version(&self, model_name: &str) -> Option<&ModelVersion> {
340 self.get_model_versions(model_name)
341 .into_iter()
342 .max_by(|a, b| a.metrics.accuracy.partial_cmp(&b.metrics.accuracy).unwrap())
343 }
344
345 #[must_use]
347 pub fn get_most_cost_effective(&self, model_name: &str) -> Option<&ModelVersion> {
348 self.get_model_versions(model_name)
349 .into_iter()
350 .max_by(|a, b| {
351 a.metrics
352 .cost_effectiveness()
353 .partial_cmp(&b.metrics.cost_effectiveness())
354 .unwrap()
355 })
356 }
357
358 #[must_use]
360 pub fn list_models(&self) -> Vec<String> {
361 let mut models: Vec<String> = self
362 .versions
363 .values()
364 .map(|v| v.model_name.clone())
365 .collect();
366 models.sort();
367 models.dedup();
368 models
369 }
370
371 pub fn deprecate_version(&mut self, model_name: &str, version: &str) -> Result<()> {
373 let id = format!("{model_name}:{version}");
374
375 let model = self
376 .versions
377 .get_mut(&id)
378 .ok_or_else(|| AiError::NotFound(format!("Model version {id} not found")))?;
379
380 model.deprecate();
381
382 if let Some(active) = self.active_versions.get(model_name) {
384 if active == version {
385 self.active_versions.remove(model_name);
386 }
387 }
388
389 Ok(())
390 }
391
392 pub fn save_to_file(&self, path: impl AsRef<std::path::Path>) -> Result<()> {
394 let json = serde_json::to_string_pretty(self)
395 .map_err(|e| AiError::Internal(format!("Failed to serialize registry: {e}")))?;
396
397 std::fs::write(path, json)
398 .map_err(|e| AiError::Internal(format!("Failed to write registry: {e}")))?;
399
400 Ok(())
401 }
402
403 pub fn load_from_file(path: impl AsRef<std::path::Path>) -> Result<Self> {
405 let json = std::fs::read_to_string(path)
406 .map_err(|e| AiError::Internal(format!("Failed to read registry: {e}")))?;
407
408 let registry: ModelRegistry = serde_json::from_str(&json)
409 .map_err(|e| AiError::Internal(format!("Failed to deserialize registry: {e}")))?;
410
411 Ok(registry)
412 }
413
414 #[must_use]
416 pub fn len(&self) -> usize {
417 self.versions.len()
418 }
419
420 #[must_use]
422 pub fn is_empty(&self) -> bool {
423 self.versions.is_empty()
424 }
425}
426
427#[derive(Debug, Clone, Serialize, Deserialize)]
429pub struct VersionComparison {
430 pub version1: ModelVersion,
431 pub version2: ModelVersion,
432 pub accuracy_diff: f64,
434 pub cost_diff: f64,
436 pub latency_diff: f64,
438}
439
440impl VersionComparison {
441 #[must_use]
443 pub fn recommendation(&self) -> &'static str {
444 if self.accuracy_diff > 5.0 && self.cost_diff < 0.01 {
445 "Version 1 is significantly more accurate with similar cost"
446 } else if self.accuracy_diff < -5.0 && self.cost_diff > -0.01 {
447 "Version 2 is significantly more accurate with similar cost"
448 } else if self.cost_diff < -0.005 && self.accuracy_diff.abs() < 2.0 {
449 "Version 1 is more cost-effective with similar accuracy"
450 } else if self.cost_diff > 0.005 && self.accuracy_diff.abs() < 2.0 {
451 "Version 2 is more cost-effective with similar accuracy"
452 } else {
453 "Versions have similar performance characteristics"
454 }
455 }
456}
457
#[cfg(test)]
mod tests {
    use super::*;

    /// The constructor stores in-range values unchanged.
    #[test]
    fn test_model_metrics_creation() {
        let metrics = ModelMetrics::new(95.5, 1000, 0.02);
        assert_eq!(metrics.accuracy, 95.5);
        assert_eq!(metrics.avg_tokens, 1000);
        assert_eq!(metrics.avg_cost, 0.02);
    }

    /// Recording requests updates the counters and the derived accuracy.
    #[test]
    fn test_model_metrics_record_request() {
        let mut metrics = ModelMetrics::empty();

        metrics.record_request(1000, 0.02, 150.0, true);
        assert_eq!(metrics.total_requests, 1);
        assert_eq!(metrics.accuracy, 100.0);

        metrics.record_request(1200, 0.025, 180.0, false);
        assert_eq!(metrics.total_requests, 2);
        assert_eq!(metrics.total_errors, 1);
        assert_eq!(metrics.accuracy, 50.0);
    }

    /// A freshly created version is active and not deprecated.
    #[test]
    fn test_model_version_creation() {
        let version = ModelVersion::new("gpt-4-turbo", "20240301", "GPT-4 Turbo March");
        assert_eq!(version.model_name, "gpt-4-turbo");
        assert_eq!(version.version, "20240301");
        assert!(version.active);
        assert!(!version.is_deprecated());
    }

    /// A registered version can be looked up again.
    #[test]
    fn test_model_registry() {
        let mut registry = ModelRegistry::new();

        let version = ModelVersion::new("gpt-4-turbo", "20240301", "Test version");
        registry.register_version(version).unwrap();

        assert_eq!(registry.len(), 1);

        let retrieved = registry.get_version("gpt-4-turbo", "20240301");
        assert!(retrieved.is_some());
    }

    /// An explicit selection determines the active version.
    #[test]
    fn test_active_version() {
        let mut registry = ModelRegistry::new();

        let v1 = ModelVersion::new("gpt-4", "v1", "Version 1");
        let v2 = ModelVersion::new("gpt-4", "v2", "Version 2");

        registry.register_version(v1).unwrap();
        registry.register_version(v2).unwrap();

        registry.set_active_version("gpt-4", "v2").unwrap();

        let active = registry.get_active_version("gpt-4").unwrap();
        assert_eq!(active.version, "v2");
    }

    /// Differences are computed as first version minus second version.
    #[test]
    fn test_version_comparison() {
        let mut registry = ModelRegistry::new();

        let v1 = ModelVersion::new("gpt-4", "v1", "V1")
            .with_metrics(ModelMetrics::new(90.0, 1000, 0.02));
        let v2 = ModelVersion::new("gpt-4", "v2", "V2")
            .with_metrics(ModelMetrics::new(95.0, 1200, 0.025));

        registry.register_version(v1).unwrap();
        registry.register_version(v2).unwrap();

        let comparison = registry
            .compare_versions("gpt-4", "v1", "gpt-4", "v2")
            .unwrap();
        assert!(comparison.accuracy_diff < 0.0);
    }

    /// Deprecating through the registry marks the stored version.
    #[test]
    fn test_deprecation() {
        let mut registry = ModelRegistry::new();

        let version = ModelVersion::new("gpt-3.5", "old", "Old version");
        registry.register_version(version).unwrap();

        registry.deprecate_version("gpt-3.5", "old").unwrap();

        let deprecated = registry.get_version("gpt-3.5", "old").unwrap();
        assert!(deprecated.is_deprecated());
        assert!(!deprecated.active);
    }

    /// The version with the highest accuracy is reported as best.
    #[test]
    fn test_best_version() {
        let mut registry = ModelRegistry::new();

        registry
            .register_version(
                ModelVersion::new("claude", "v1", "V1")
                    .with_metrics(ModelMetrics::new(85.0, 1000, 0.02)),
            )
            .unwrap();

        registry
            .register_version(
                ModelVersion::new("claude", "v2", "V2")
                    .with_metrics(ModelMetrics::new(95.0, 1200, 0.03)),
            )
            .unwrap();

        let best = registry.get_best_version("claude").unwrap();
        assert_eq!(best.version, "v2");
    }

    /// A saved registry can be loaded back with the same number of entries.
    #[test]
    fn test_registry_persistence() {
        let mut registry = ModelRegistry::new();

        registry
            .register_version(ModelVersion::new("test-model", "v1.0", "Test"))
            .unwrap();

        // Use the platform temp directory and a per-process file name so the
        // test works outside Unix and does not collide with concurrent runs.
        let temp_path = std::env::temp_dir().join(format!(
            "model_registry_test_{}.json",
            std::process::id()
        ));

        let saved = registry.save_to_file(&temp_path);
        let loaded = ModelRegistry::load_from_file(&temp_path);

        // Clean up before asserting so a failure does not leave the file behind.
        let _ = std::fs::remove_file(&temp_path);

        saved.unwrap();
        assert_eq!(loaded.unwrap().len(), 1);
    }
}