kaccy_ai/
model_version.rs

1//! Model version tracking and management
2//!
3//! This module provides version tracking for AI models, including performance metrics,
4//! deployment history, and version comparison capabilities.
5//!
6//! # Examples
7//!
8//! ```
9//! use kaccy_ai::model_version::{ModelRegistry, ModelVersion, ModelMetrics};
10//!
11//! let mut registry = ModelRegistry::new();
12//!
13//! // Register a new model version
14//! let version = ModelVersion::new("gpt-4-turbo", "20240301", "GPT-4 Turbo March 2024");
15//! registry.register_version(version).unwrap();
16//!
17//! // Record metrics
18//! let metrics = ModelMetrics::new(95.5, 1250, 0.03);
19//! registry.update_metrics("gpt-4-turbo", "20240301", metrics).unwrap();
20//! ```
21
22use serde::{Deserialize, Serialize};
23use std::collections::HashMap;
24
25use crate::error::{AiError, Result};
26
27/// Model performance metrics
28#[derive(Debug, Clone, Serialize, Deserialize)]
29pub struct ModelMetrics {
30    /// Accuracy percentage (0-100)
31    pub accuracy: f64,
32    /// Average tokens per request
33    pub avg_tokens: u32,
34    /// Average cost per request (USD)
35    pub avg_cost: f64,
36    /// Total requests processed
37    pub total_requests: u64,
38    /// Total errors encountered
39    pub total_errors: u64,
40    /// Average latency in milliseconds
41    pub avg_latency_ms: f64,
42}
43
44impl ModelMetrics {
45    /// Create new metrics
46    #[must_use]
47    pub fn new(accuracy: f64, avg_tokens: u32, avg_cost: f64) -> Self {
48        Self {
49            accuracy: accuracy.clamp(0.0, 100.0),
50            avg_tokens,
51            avg_cost,
52            total_requests: 0,
53            total_errors: 0,
54            avg_latency_ms: 0.0,
55        }
56    }
57
58    /// Create empty metrics
59    #[must_use]
60    pub fn empty() -> Self {
61        Self {
62            accuracy: 0.0,
63            avg_tokens: 0,
64            avg_cost: 0.0,
65            total_requests: 0,
66            total_errors: 0,
67            avg_latency_ms: 0.0,
68        }
69    }
70
71    /// Update with new request data
72    pub fn record_request(&mut self, tokens: u32, cost: f64, latency_ms: f64, success: bool) {
73        let n = self.total_requests as f64;
74
75        // Update running averages
76        self.avg_tokens = ((f64::from(self.avg_tokens) * n + f64::from(tokens)) / (n + 1.0)) as u32;
77        self.avg_cost = (self.avg_cost * n + cost) / (n + 1.0);
78        self.avg_latency_ms = (self.avg_latency_ms * n + latency_ms) / (n + 1.0);
79
80        self.total_requests += 1;
81        if !success {
82            self.total_errors += 1;
83        }
84
85        // Update accuracy
86        let success_count = self.total_requests - self.total_errors;
87        self.accuracy = (success_count as f64 / self.total_requests as f64) * 100.0;
88    }
89
90    /// Get error rate percentage
91    #[must_use]
92    pub fn error_rate(&self) -> f64 {
93        if self.total_requests == 0 {
94            0.0
95        } else {
96            (self.total_errors as f64 / self.total_requests as f64) * 100.0
97        }
98    }
99
100    /// Calculate cost-effectiveness score (accuracy per dollar)
101    #[must_use]
102    pub fn cost_effectiveness(&self) -> f64 {
103        if self.avg_cost == 0.0 {
104            0.0
105        } else {
106            self.accuracy / self.avg_cost
107        }
108    }
109}
110
111/// Model version information
112#[derive(Debug, Clone, Serialize, Deserialize)]
113pub struct ModelVersion {
114    /// Model family name (e.g., "gpt-4-turbo", "claude-3-opus")
115    pub model_name: String,
116    /// Version identifier (e.g., "20240301", "v1.2.3")
117    pub version: String,
118    /// Human-readable description
119    pub description: String,
120    /// Performance metrics
121    pub metrics: ModelMetrics,
122    /// Release date
123    pub release_date: chrono::DateTime<chrono::Utc>,
124    /// Deprecation date (if any)
125    pub deprecated_at: Option<chrono::DateTime<chrono::Utc>>,
126    /// Whether this version is currently active
127    pub active: bool,
128    /// Tags for categorization
129    pub tags: Vec<String>,
130}
131
132impl ModelVersion {
133    /// Create a new model version
134    pub fn new(
135        model_name: impl Into<String>,
136        version: impl Into<String>,
137        description: impl Into<String>,
138    ) -> Self {
139        Self {
140            model_name: model_name.into(),
141            version: version.into(),
142            description: description.into(),
143            metrics: ModelMetrics::empty(),
144            release_date: chrono::Utc::now(),
145            deprecated_at: None,
146            active: true,
147            tags: Vec::new(),
148        }
149    }
150
151    /// Add tags
152    #[must_use]
153    pub fn with_tags(mut self, tags: Vec<String>) -> Self {
154        self.tags = tags;
155        self
156    }
157
158    /// Set release date
159    #[must_use]
160    pub fn with_release_date(mut self, date: chrono::DateTime<chrono::Utc>) -> Self {
161        self.release_date = date;
162        self
163    }
164
165    /// Set initial metrics
166    #[must_use]
167    pub fn with_metrics(mut self, metrics: ModelMetrics) -> Self {
168        self.metrics = metrics;
169        self
170    }
171
172    /// Mark as deprecated
173    pub fn deprecate(&mut self) {
174        self.deprecated_at = Some(chrono::Utc::now());
175        self.active = false;
176    }
177
178    /// Check if deprecated
179    #[must_use]
180    pub fn is_deprecated(&self) -> bool {
181        self.deprecated_at.is_some()
182    }
183
184    /// Get unique identifier
185    #[must_use]
186    pub fn id(&self) -> String {
187        format!("{}:{}", self.model_name, self.version)
188    }
189}
190
191/// Model registry for tracking versions
192#[derive(Debug, Clone, Serialize, Deserialize)]
193pub struct ModelRegistry {
194    versions: HashMap<String, ModelVersion>,
195    active_versions: HashMap<String, String>, // model_name -> active_version
196}
197
198impl Default for ModelRegistry {
199    fn default() -> Self {
200        Self::new()
201    }
202}
203
204impl ModelRegistry {
205    /// Create a new model registry
206    #[must_use]
207    pub fn new() -> Self {
208        Self {
209            versions: HashMap::new(),
210            active_versions: HashMap::new(),
211        }
212    }
213
214    /// Register a new model version
215    pub fn register_version(&mut self, version: ModelVersion) -> Result<()> {
216        let id = version.id();
217        let model_name = version.model_name.clone();
218
219        // Set as active version if it's the first or explicitly active
220        if version.active {
221            self.active_versions
222                .insert(model_name, version.version.clone());
223        }
224
225        self.versions.insert(id, version);
226        Ok(())
227    }
228
229    /// Get a specific version
230    #[must_use]
231    pub fn get_version(&self, model_name: &str, version: &str) -> Option<&ModelVersion> {
232        let id = format!("{model_name}:{version}");
233        self.versions.get(&id)
234    }
235
236    /// Get mutable version
237    pub fn get_version_mut(
238        &mut self,
239        model_name: &str,
240        version: &str,
241    ) -> Option<&mut ModelVersion> {
242        let id = format!("{model_name}:{version}");
243        self.versions.get_mut(&id)
244    }
245
246    /// Get active version for a model
247    #[must_use]
248    pub fn get_active_version(&self, model_name: &str) -> Option<&ModelVersion> {
249        let version = self.active_versions.get(model_name)?;
250        self.get_version(model_name, version)
251    }
252
253    /// Set active version for a model
254    pub fn set_active_version(&mut self, model_name: &str, version: &str) -> Result<()> {
255        let id = format!("{model_name}:{version}");
256
257        if !self.versions.contains_key(&id) {
258            return Err(AiError::NotFound(format!("Model version {id} not found")));
259        }
260
261        self.active_versions
262            .insert(model_name.to_string(), version.to_string());
263        Ok(())
264    }
265
266    /// Get all versions for a model
267    #[must_use]
268    pub fn get_model_versions(&self, model_name: &str) -> Vec<&ModelVersion> {
269        self.versions
270            .values()
271            .filter(|v| v.model_name == model_name)
272            .collect()
273    }
274
275    /// Update metrics for a version
276    pub fn update_metrics(
277        &mut self,
278        model_name: &str,
279        version: &str,
280        metrics: ModelMetrics,
281    ) -> Result<()> {
282        let id = format!("{model_name}:{version}");
283
284        let model = self
285            .versions
286            .get_mut(&id)
287            .ok_or_else(|| AiError::NotFound(format!("Model version {id} not found")))?;
288
289        model.metrics = metrics;
290        Ok(())
291    }
292
293    /// Record a request for a model version
294    pub fn record_request(
295        &mut self,
296        model_name: &str,
297        version: &str,
298        tokens: u32,
299        cost: f64,
300        latency_ms: f64,
301        success: bool,
302    ) -> Result<()> {
303        let id = format!("{model_name}:{version}");
304
305        let model = self
306            .versions
307            .get_mut(&id)
308            .ok_or_else(|| AiError::NotFound(format!("Model version {id} not found")))?;
309
310        model
311            .metrics
312            .record_request(tokens, cost, latency_ms, success);
313        Ok(())
314    }
315
316    /// Compare two model versions
317    #[must_use]
318    pub fn compare_versions(
319        &self,
320        model1: &str,
321        version1: &str,
322        model2: &str,
323        version2: &str,
324    ) -> Option<VersionComparison> {
325        let v1 = self.get_version(model1, version1)?;
326        let v2 = self.get_version(model2, version2)?;
327
328        Some(VersionComparison {
329            version1: v1.clone(),
330            version2: v2.clone(),
331            accuracy_diff: v1.metrics.accuracy - v2.metrics.accuracy,
332            cost_diff: v1.metrics.avg_cost - v2.metrics.avg_cost,
333            latency_diff: v1.metrics.avg_latency_ms - v2.metrics.avg_latency_ms,
334        })
335    }
336
337    /// Get best performing version for a model (by accuracy)
338    #[must_use]
339    pub fn get_best_version(&self, model_name: &str) -> Option<&ModelVersion> {
340        self.get_model_versions(model_name)
341            .into_iter()
342            .max_by(|a, b| a.metrics.accuracy.partial_cmp(&b.metrics.accuracy).unwrap())
343    }
344
345    /// Get most cost-effective version for a model
346    #[must_use]
347    pub fn get_most_cost_effective(&self, model_name: &str) -> Option<&ModelVersion> {
348        self.get_model_versions(model_name)
349            .into_iter()
350            .max_by(|a, b| {
351                a.metrics
352                    .cost_effectiveness()
353                    .partial_cmp(&b.metrics.cost_effectiveness())
354                    .unwrap()
355            })
356    }
357
358    /// Get all model names
359    #[must_use]
360    pub fn list_models(&self) -> Vec<String> {
361        let mut models: Vec<String> = self
362            .versions
363            .values()
364            .map(|v| v.model_name.clone())
365            .collect();
366        models.sort();
367        models.dedup();
368        models
369    }
370
371    /// Deprecate a version
372    pub fn deprecate_version(&mut self, model_name: &str, version: &str) -> Result<()> {
373        let id = format!("{model_name}:{version}");
374
375        let model = self
376            .versions
377            .get_mut(&id)
378            .ok_or_else(|| AiError::NotFound(format!("Model version {id} not found")))?;
379
380        model.deprecate();
381
382        // If this was the active version, clear it
383        if let Some(active) = self.active_versions.get(model_name) {
384            if active == version {
385                self.active_versions.remove(model_name);
386            }
387        }
388
389        Ok(())
390    }
391
392    /// Save registry to file
393    pub fn save_to_file(&self, path: impl AsRef<std::path::Path>) -> Result<()> {
394        let json = serde_json::to_string_pretty(self)
395            .map_err(|e| AiError::Internal(format!("Failed to serialize registry: {e}")))?;
396
397        std::fs::write(path, json)
398            .map_err(|e| AiError::Internal(format!("Failed to write registry: {e}")))?;
399
400        Ok(())
401    }
402
403    /// Load registry from file
404    pub fn load_from_file(path: impl AsRef<std::path::Path>) -> Result<Self> {
405        let json = std::fs::read_to_string(path)
406            .map_err(|e| AiError::Internal(format!("Failed to read registry: {e}")))?;
407
408        let registry: ModelRegistry = serde_json::from_str(&json)
409            .map_err(|e| AiError::Internal(format!("Failed to deserialize registry: {e}")))?;
410
411        Ok(registry)
412    }
413
414    /// Get total number of versions
415    #[must_use]
416    pub fn len(&self) -> usize {
417        self.versions.len()
418    }
419
420    /// Check if empty
421    #[must_use]
422    pub fn is_empty(&self) -> bool {
423        self.versions.is_empty()
424    }
425}
426
427/// Comparison between two model versions
428#[derive(Debug, Clone, Serialize, Deserialize)]
429pub struct VersionComparison {
430    pub version1: ModelVersion,
431    pub version2: ModelVersion,
432    /// Accuracy difference (version1 - version2)
433    pub accuracy_diff: f64,
434    /// Cost difference (version1 - version2)
435    pub cost_diff: f64,
436    /// Latency difference (version1 - version2)
437    pub latency_diff: f64,
438}
439
440impl VersionComparison {
441    /// Get recommendation based on comparison
442    #[must_use]
443    pub fn recommendation(&self) -> &'static str {
444        if self.accuracy_diff > 5.0 && self.cost_diff < 0.01 {
445            "Version 1 is significantly more accurate with similar cost"
446        } else if self.accuracy_diff < -5.0 && self.cost_diff > -0.01 {
447            "Version 2 is significantly more accurate with similar cost"
448        } else if self.cost_diff < -0.005 && self.accuracy_diff.abs() < 2.0 {
449            "Version 1 is more cost-effective with similar accuracy"
450        } else if self.cost_diff > 0.005 && self.accuracy_diff.abs() < 2.0 {
451            "Version 2 is more cost-effective with similar accuracy"
452        } else {
453            "Versions have similar performance characteristics"
454        }
455    }
456}
457
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_model_metrics_creation() {
        let metrics = ModelMetrics::new(95.5, 1000, 0.02);
        assert_eq!(metrics.accuracy, 95.5);
        assert_eq!(metrics.avg_tokens, 1000);
        assert_eq!(metrics.avg_cost, 0.02);
    }

    #[test]
    fn test_model_metrics_record_request() {
        let mut metrics = ModelMetrics::empty();

        metrics.record_request(1000, 0.02, 150.0, true);
        assert_eq!(metrics.total_requests, 1);
        assert_eq!(metrics.accuracy, 100.0);

        metrics.record_request(1200, 0.025, 180.0, false);
        assert_eq!(metrics.total_requests, 2);
        assert_eq!(metrics.total_errors, 1);
        assert_eq!(metrics.accuracy, 50.0);
    }

    #[test]
    fn test_model_version_creation() {
        let version = ModelVersion::new("gpt-4-turbo", "20240301", "GPT-4 Turbo March");
        assert_eq!(version.model_name, "gpt-4-turbo");
        assert_eq!(version.version, "20240301");
        assert!(version.active);
        assert!(!version.is_deprecated());
    }

    #[test]
    fn test_model_registry() {
        let mut registry = ModelRegistry::new();

        let version = ModelVersion::new("gpt-4-turbo", "20240301", "Test version");
        registry.register_version(version).unwrap();

        assert_eq!(registry.len(), 1);

        let retrieved = registry.get_version("gpt-4-turbo", "20240301");
        assert!(retrieved.is_some());
    }

    #[test]
    fn test_active_version() {
        let mut registry = ModelRegistry::new();

        let v1 = ModelVersion::new("gpt-4", "v1", "Version 1");
        let v2 = ModelVersion::new("gpt-4", "v2", "Version 2");

        registry.register_version(v1).unwrap();
        registry.register_version(v2).unwrap();

        registry.set_active_version("gpt-4", "v2").unwrap();

        let active = registry.get_active_version("gpt-4").unwrap();
        assert_eq!(active.version, "v2");
    }

    #[test]
    fn test_version_comparison() {
        let mut registry = ModelRegistry::new();

        let v1 = ModelVersion::new("gpt-4", "v1", "V1")
            .with_metrics(ModelMetrics::new(90.0, 1000, 0.02));
        let v2 = ModelVersion::new("gpt-4", "v2", "V2")
            .with_metrics(ModelMetrics::new(95.0, 1200, 0.025));

        registry.register_version(v1).unwrap();
        registry.register_version(v2).unwrap();

        let comparison = registry
            .compare_versions("gpt-4", "v1", "gpt-4", "v2")
            .unwrap();
        assert!(comparison.accuracy_diff < 0.0); // v2 is more accurate
    }

    #[test]
    fn test_deprecation() {
        let mut registry = ModelRegistry::new();

        let version = ModelVersion::new("gpt-3.5", "old", "Old version");
        registry.register_version(version).unwrap();

        registry.deprecate_version("gpt-3.5", "old").unwrap();

        let deprecated = registry.get_version("gpt-3.5", "old").unwrap();
        assert!(deprecated.is_deprecated());
        assert!(!deprecated.active);

        // Deprecating the active version must also clear the active selection.
        assert!(registry.get_active_version("gpt-3.5").is_none());
    }

    #[test]
    fn test_best_version() {
        let mut registry = ModelRegistry::new();

        registry
            .register_version(
                ModelVersion::new("claude", "v1", "V1")
                    .with_metrics(ModelMetrics::new(85.0, 1000, 0.02)),
            )
            .unwrap();

        registry
            .register_version(
                ModelVersion::new("claude", "v2", "V2")
                    .with_metrics(ModelMetrics::new(95.0, 1200, 0.03)),
            )
            .unwrap();

        let best = registry.get_best_version("claude").unwrap();
        assert_eq!(best.version, "v2");
    }

    #[test]
    fn test_registry_persistence() {
        let mut registry = ModelRegistry::new();

        registry
            .register_version(ModelVersion::new("test-model", "v1.0", "Test"))
            .unwrap();

        // Use the platform temp dir and a per-process file name. A hard-coded
        // `/tmp/...` path does not exist on Windows and collides when several test
        // processes run at the same time.
        let temp_path = std::env::temp_dir().join(format!(
            "model_registry_test_{}.json",
            std::process::id()
        ));

        let load_result = registry
            .save_to_file(&temp_path)
            .and_then(|()| ModelRegistry::load_from_file(&temp_path));

        // Clean up before asserting so a failing assertion does not leave the file behind.
        let _ = std::fs::remove_file(&temp_path);

        let loaded = load_result.unwrap();
        assert_eq!(loaded.len(), 1);
        assert!(loaded.get_version("test-model", "v1.0").is_some());
    }
}