kizzasi_inference/
versioning.rs

1//! Model versioning and fallback management
2//!
3//! This module provides robust model versioning, health checking, and automatic
4//! fallback mechanisms for production inference systems.
5
6use crate::error::{InferenceError, InferenceResult};
7use serde::{Deserialize, Serialize};
8use std::collections::HashMap;
9use std::sync::{Arc, RwLock};
10use std::time::{Duration, Instant};
11
12/// Semantic version for models
13#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Serialize, Deserialize)]
14pub struct ModelVersion {
15    /// Major version (breaking changes)
16    pub major: u32,
17    /// Minor version (new features)
18    pub minor: u32,
19    /// Patch version (bug fixes)
20    pub patch: u32,
21}
22
23impl ModelVersion {
24    /// Create a new model version
25    pub fn new(major: u32, minor: u32, patch: u32) -> Self {
26        Self {
27            major,
28            minor,
29            patch,
30        }
31    }
32
33    /// Parse version from string (e.g., "1.2.3")
34    pub fn parse(s: &str) -> InferenceResult<Self> {
35        let parts: Vec<&str> = s.split('.').collect();
36        if parts.len() != 3 {
37            return Err(InferenceError::ForwardError(format!(
38                "Invalid version format: {}",
39                s
40            )));
41        }
42
43        let major = parts[0].parse().map_err(|_| {
44            InferenceError::ForwardError(format!("Invalid major version: {}", parts[0]))
45        })?;
46        let minor = parts[1].parse().map_err(|_| {
47            InferenceError::ForwardError(format!("Invalid minor version: {}", parts[1]))
48        })?;
49        let patch = parts[2].parse().map_err(|_| {
50            InferenceError::ForwardError(format!("Invalid patch version: {}", parts[2]))
51        })?;
52
53        Ok(Self::new(major, minor, patch))
54    }
55
56    /// Check if this version is compatible with another (same major version)
57    pub fn is_compatible_with(&self, other: &ModelVersion) -> bool {
58        self.major == other.major
59    }
60
61    // Note: to_string() is provided by Display trait implementation
62}
63
64impl std::fmt::Display for ModelVersion {
65    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
66        write!(f, "{}.{}.{}", self.major, self.minor, self.patch)
67    }
68}
69
70/// Health status of a model
71#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
72pub enum HealthStatus {
73    /// Model is healthy and operational
74    Healthy,
75    /// Model is degraded but operational
76    Degraded,
77    /// Model is unhealthy and should not be used
78    Unhealthy,
79}
80
81/// Health check result
82#[derive(Debug, Clone)]
83pub struct HealthCheck {
84    /// Health status
85    pub status: HealthStatus,
86    /// Timestamp of the check
87    pub timestamp: Instant,
88    /// Average latency (ms)
89    pub avg_latency_ms: f64,
90    /// Error rate (0.0 to 1.0)
91    pub error_rate: f64,
92    /// Number of requests processed
93    pub request_count: usize,
94    /// Additional details
95    pub details: String,
96}
97
98impl HealthCheck {
99    /// Create a new health check result
100    pub fn new(status: HealthStatus) -> Self {
101        Self {
102            status,
103            timestamp: Instant::now(),
104            avg_latency_ms: 0.0,
105            error_rate: 0.0,
106            request_count: 0,
107            details: String::new(),
108        }
109    }
110
111    /// Check if the model is usable
112    pub fn is_usable(&self) -> bool {
113        matches!(self.status, HealthStatus::Healthy | HealthStatus::Degraded)
114    }
115}
116
117/// Model metadata
118#[derive(Debug, Clone, Serialize, Deserialize)]
119pub struct ModelMetadata {
120    /// Model identifier
121    pub id: String,
122    /// Model version
123    pub version: ModelVersion,
124    /// Model architecture type
125    pub architecture: String,
126    /// Creation timestamp
127    pub created_at: String,
128    /// Model checksum (for integrity verification)
129    pub checksum: Option<String>,
130    /// Additional metadata
131    pub extra: HashMap<String, String>,
132}
133
134impl ModelMetadata {
135    /// Create new metadata
136    pub fn new(
137        id: impl Into<String>,
138        version: ModelVersion,
139        architecture: impl Into<String>,
140    ) -> Self {
141        Self {
142            id: id.into(),
143            version,
144            architecture: architecture.into(),
145            created_at: chrono::Utc::now().to_rfc3339(),
146            checksum: None,
147            extra: HashMap::new(),
148        }
149    }
150
151    /// Set checksum
152    pub fn with_checksum(mut self, checksum: impl Into<String>) -> Self {
153        self.checksum = Some(checksum.into());
154        self
155    }
156
157    /// Add extra metadata field
158    pub fn add_metadata(mut self, key: impl Into<String>, value: impl Into<String>) -> Self {
159        self.extra.insert(key.into(), value.into());
160        self
161    }
162}
163
164/// Statistics for a versioned model
165#[derive(Debug, Clone, Default)]
166pub struct ModelStats {
167    /// Total requests processed
168    pub total_requests: usize,
169    /// Total errors encountered
170    pub total_errors: usize,
171    /// Sum of latencies (for computing average)
172    pub total_latency_ms: f64,
173    /// Last health check
174    pub last_health_check: Option<HealthCheck>,
175}
176
177impl ModelStats {
178    /// Record a successful request
179    pub fn record_success(&mut self, latency_ms: f64) {
180        self.total_requests += 1;
181        self.total_latency_ms += latency_ms;
182    }
183
184    /// Record an error
185    pub fn record_error(&mut self, latency_ms: f64) {
186        self.total_requests += 1;
187        self.total_errors += 1;
188        self.total_latency_ms += latency_ms;
189    }
190
191    /// Get average latency
192    pub fn avg_latency_ms(&self) -> f64 {
193        if self.total_requests == 0 {
194            0.0
195        } else {
196            self.total_latency_ms / self.total_requests as f64
197        }
198    }
199
200    /// Get error rate
201    pub fn error_rate(&self) -> f64 {
202        if self.total_requests == 0 {
203            0.0
204        } else {
205            self.total_errors as f64 / self.total_requests as f64
206        }
207    }
208
209    /// Generate health check from current stats
210    pub fn to_health_check(&self) -> HealthCheck {
211        let error_rate = self.error_rate();
212        let avg_latency = self.avg_latency_ms();
213
214        let status = if error_rate > 0.5 {
215            HealthStatus::Unhealthy
216        } else if error_rate > 0.1 || avg_latency > 1000.0 {
217            HealthStatus::Degraded
218        } else {
219            HealthStatus::Healthy
220        };
221
222        HealthCheck {
223            status,
224            timestamp: Instant::now(),
225            avg_latency_ms: avg_latency,
226            error_rate,
227            request_count: self.total_requests,
228            details: format!(
229                "Requests: {}, Errors: {}, Avg Latency: {:.2}ms",
230                self.total_requests, self.total_errors, avg_latency
231            ),
232        }
233    }
234}
235
236/// Fallback strategy when primary model fails
237#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
238pub enum FallbackStrategy {
239    /// Use the next available version
240    NextVersion,
241    /// Use the previous stable version
242    PreviousStable,
243    /// Use a specific version
244    SpecificVersion,
245    /// Return an error (no fallback)
246    NoFallback,
247}
248
249/// Configuration for model versioning
250#[derive(Debug, Clone, Serialize, Deserialize)]
251pub struct VersioningConfig {
252    /// Fallback strategy
253    pub fallback_strategy: FallbackStrategy,
254    /// Health check interval
255    pub health_check_interval: Duration,
256    /// Maximum error rate before marking unhealthy
257    pub max_error_rate: f64,
258    /// Maximum latency before marking degraded (ms)
259    pub max_latency_ms: f64,
260    /// Automatic recovery enabled
261    pub auto_recovery: bool,
262}
263
264impl Default for VersioningConfig {
265    fn default() -> Self {
266        Self {
267            fallback_strategy: FallbackStrategy::PreviousStable,
268            health_check_interval: Duration::from_secs(60),
269            max_error_rate: 0.1,
270            max_latency_ms: 1000.0,
271            auto_recovery: true,
272        }
273    }
274}
275
276/// Versioned model entry
277struct VersionedModelEntry {
278    metadata: ModelMetadata,
279    stats: ModelStats,
280    is_stable: bool,
281}
282
283/// Model version manager
284pub struct ModelVersionManager {
285    /// Map of model ID to versions
286    models: Arc<RwLock<HashMap<String, Vec<VersionedModelEntry>>>>,
287    /// Currently active versions per model ID
288    active_versions: Arc<RwLock<HashMap<String, ModelVersion>>>,
289    /// Configuration
290    config: VersioningConfig,
291}
292
293impl ModelVersionManager {
294    /// Create a new version manager
295    pub fn new(config: VersioningConfig) -> Self {
296        Self {
297            models: Arc::new(RwLock::new(HashMap::new())),
298            active_versions: Arc::new(RwLock::new(HashMap::new())),
299            config,
300        }
301    }
302
303    /// Register a model version
304    pub fn register_version(
305        &self,
306        metadata: ModelMetadata,
307        is_stable: bool,
308    ) -> InferenceResult<()> {
309        let mut models = self.models.write().map_err(|e| {
310            InferenceError::LockError(format!("Failed to acquire write lock: {}", e))
311        })?;
312        let entries = models.entry(metadata.id.clone()).or_default();
313
314        // Check if version already exists
315        if entries
316            .iter()
317            .any(|e| e.metadata.version == metadata.version)
318        {
319            return Err(InferenceError::ForwardError(format!(
320                "Version {} already registered for model {}",
321                metadata.version, metadata.id
322            )));
323        }
324
325        entries.push(VersionedModelEntry {
326            metadata,
327            stats: ModelStats::default(),
328            is_stable,
329        });
330
331        // Sort by version (descending)
332        entries.sort_by(|a, b| b.metadata.version.cmp(&a.metadata.version));
333
334        Ok(())
335    }
336
337    /// Get the active version for a model
338    pub fn get_active_version(&self, model_id: &str) -> Option<ModelVersion> {
339        let active = self.active_versions.read().ok()?;
340        active.get(model_id).cloned()
341    }
342
343    /// Set the active version for a model
344    pub fn set_active_version(&self, model_id: &str, version: ModelVersion) -> InferenceResult<()> {
345        // Verify version exists
346        let models = self.models.read().map_err(|e| {
347            InferenceError::LockError(format!("Failed to acquire read lock: {}", e))
348        })?;
349        let entries = models
350            .get(model_id)
351            .ok_or_else(|| InferenceError::ForwardError(format!("Model {} not found", model_id)))?;
352
353        if !entries.iter().any(|e| e.metadata.version == version) {
354            return Err(InferenceError::ForwardError(format!(
355                "Version {} not found for model {}",
356                version, model_id
357            )));
358        }
359
360        let mut active = self.active_versions.write().map_err(|e| {
361            InferenceError::LockError(format!("Failed to acquire write lock: {}", e))
362        })?;
363        active.insert(model_id.to_string(), version);
364
365        Ok(())
366    }
367
368    /// Record a request for a model version
369    pub fn record_request(
370        &self,
371        model_id: &str,
372        version: &ModelVersion,
373        latency_ms: f64,
374        is_error: bool,
375    ) -> InferenceResult<()> {
376        let mut models = self.models.write().map_err(|e| {
377            InferenceError::LockError(format!("Failed to acquire write lock: {}", e))
378        })?;
379        let entries = models
380            .get_mut(model_id)
381            .ok_or_else(|| InferenceError::ForwardError(format!("Model {} not found", model_id)))?;
382
383        let entry = entries
384            .iter_mut()
385            .find(|e| &e.metadata.version == version)
386            .ok_or_else(|| {
387                InferenceError::ForwardError(format!(
388                    "Version {} not found for model {}",
389                    version, model_id
390                ))
391            })?;
392
393        if is_error {
394            entry.stats.record_error(latency_ms);
395        } else {
396            entry.stats.record_success(latency_ms);
397        }
398
399        Ok(())
400    }
401
402    /// Perform health check on a model version
403    pub fn health_check(
404        &self,
405        model_id: &str,
406        version: &ModelVersion,
407    ) -> InferenceResult<HealthCheck> {
408        let mut models = self.models.write().map_err(|e| {
409            InferenceError::LockError(format!("Failed to acquire write lock: {}", e))
410        })?;
411        let entries = models
412            .get_mut(model_id)
413            .ok_or_else(|| InferenceError::ForwardError(format!("Model {} not found", model_id)))?;
414
415        let entry = entries
416            .iter_mut()
417            .find(|e| &e.metadata.version == version)
418            .ok_or_else(|| {
419                InferenceError::ForwardError(format!(
420                    "Version {} not found for model {}",
421                    version, model_id
422                ))
423            })?;
424
425        let health_check = entry.stats.to_health_check();
426        entry.stats.last_health_check = Some(health_check.clone());
427
428        Ok(health_check)
429    }
430
431    /// Get fallback version for a model
432    pub fn get_fallback_version(
433        &self,
434        model_id: &str,
435        current_version: &ModelVersion,
436    ) -> Option<ModelVersion> {
437        let models = self.models.read().ok()?;
438        let entries = models.get(model_id)?;
439
440        match self.config.fallback_strategy {
441            FallbackStrategy::NextVersion => {
442                // Find next lower version
443                entries
444                    .iter()
445                    .filter(|e| e.metadata.version < *current_version)
446                    .map(|e| e.metadata.version.clone())
447                    .next()
448            }
449            FallbackStrategy::PreviousStable => {
450                // Find most recent stable version that's not current
451                entries
452                    .iter()
453                    .filter(|e| e.is_stable && e.metadata.version != *current_version)
454                    .map(|e| e.metadata.version.clone())
455                    .next()
456            }
457            FallbackStrategy::NoFallback => None,
458            FallbackStrategy::SpecificVersion => {
459                // Would need to be configured separately
460                None
461            }
462        }
463    }
464
465    /// List all versions for a model
466    pub fn list_versions(&self, model_id: &str) -> Vec<ModelVersion> {
467        let models = match self.models.read() {
468            Ok(m) => m,
469            Err(_) => return Vec::new(),
470        };
471        models
472            .get(model_id)
473            .map(|entries| entries.iter().map(|e| e.metadata.version.clone()).collect())
474            .unwrap_or_default()
475    }
476
477    /// Get metadata for a specific version
478    pub fn get_metadata(&self, model_id: &str, version: &ModelVersion) -> Option<ModelMetadata> {
479        let models = self.models.read().ok()?;
480        models.get(model_id).and_then(|entries| {
481            entries
482                .iter()
483                .find(|e| &e.metadata.version == version)
484                .map(|e| e.metadata.clone())
485        })
486    }
487
488    /// Get statistics for a specific version
489    pub fn get_stats(&self, model_id: &str, version: &ModelVersion) -> Option<ModelStats> {
490        let models = self.models.read().ok()?;
491        models.get(model_id).and_then(|entries| {
492            entries
493                .iter()
494                .find(|e| &e.metadata.version == version)
495                .map(|e| e.stats.clone())
496        })
497    }
498}
499
#[cfg(test)]
mod tests {
    use super::*;

    /// Build a manager with `config` and register each `(version, is_stable)`
    /// pair under the model id `"test-model"`.
    fn manager_with_versions(
        config: VersioningConfig,
        versions: &[(ModelVersion, bool)],
    ) -> ModelVersionManager {
        let manager = ModelVersionManager::new(config);
        for (version, is_stable) in versions {
            let metadata = ModelMetadata::new("test-model", version.clone(), "transformer");
            manager.register_version(metadata, *is_stable).unwrap();
        }
        manager
    }

    #[test]
    fn test_model_version_creation() {
        let version = ModelVersion::new(1, 2, 3);
        assert_eq!(version.major, 1);
        assert_eq!(version.minor, 2);
        assert_eq!(version.patch, 3);
    }

    #[test]
    fn test_model_version_parse() {
        let version = ModelVersion::parse("1.2.3").unwrap();
        assert_eq!(version.major, 1);
        assert_eq!(version.minor, 2);
        assert_eq!(version.patch, 3);
    }

    #[test]
    fn test_model_version_parse_invalid() {
        assert!(ModelVersion::parse("1.2").is_err());
        assert!(ModelVersion::parse("1.2.x").is_err());
    }

    #[test]
    fn test_model_version_parse_rejects_wrong_component_count() {
        assert!(ModelVersion::parse("").is_err());
        assert!(ModelVersion::parse("1.2.3.4").is_err());
    }

    #[test]
    fn test_model_version_display_round_trip() {
        let version = ModelVersion::new(3, 14, 15);
        assert_eq!(version.to_string(), "3.14.15");
        assert_eq!(ModelVersion::parse(&version.to_string()).unwrap(), version);
    }

    #[test]
    fn test_model_version_compatibility() {
        let v1 = ModelVersion::new(1, 2, 3);
        let v2 = ModelVersion::new(1, 3, 0);
        let v3 = ModelVersion::new(2, 0, 0);

        assert!(v1.is_compatible_with(&v2));
        assert!(!v1.is_compatible_with(&v3));
    }

    #[test]
    fn test_model_version_ordering() {
        let v1 = ModelVersion::new(1, 0, 0);
        let v2 = ModelVersion::new(1, 1, 0);
        let v3 = ModelVersion::new(2, 0, 0);

        assert!(v1 < v2);
        assert!(v2 < v3);
        assert!(v1 < v3);
    }

    #[test]
    fn test_health_status() {
        let check = HealthCheck::new(HealthStatus::Healthy);
        assert!(check.is_usable());

        let check = HealthCheck::new(HealthStatus::Degraded);
        assert!(check.is_usable());

        let check = HealthCheck::new(HealthStatus::Unhealthy);
        assert!(!check.is_usable());
    }

    #[test]
    fn test_model_stats() {
        let mut stats = ModelStats::default();

        stats.record_success(100.0);
        stats.record_success(200.0);
        stats.record_error(300.0);

        assert_eq!(stats.total_requests, 3);
        assert_eq!(stats.total_errors, 1);
        assert_eq!(stats.avg_latency_ms(), 200.0);
        assert_eq!(stats.error_rate(), 1.0 / 3.0);
    }

    #[test]
    fn test_model_stats_health_check() {
        let mut stats = ModelStats::default();

        // Low error rate, low latency -> Healthy
        stats.record_success(50.0);
        stats.record_success(60.0);
        let check = stats.to_health_check();
        assert_eq!(check.status, HealthStatus::Healthy);

        // High error rate -> Unhealthy
        stats.record_error(100.0);
        stats.record_error(100.0);
        stats.record_error(100.0);
        let check = stats.to_health_check();
        assert_eq!(check.status, HealthStatus::Unhealthy);
    }

    #[test]
    fn test_version_manager_register() {
        let config = VersioningConfig::default();
        let manager = ModelVersionManager::new(config);

        let metadata = ModelMetadata::new("test-model", ModelVersion::new(1, 0, 0), "transformer");

        manager.register_version(metadata, true).unwrap();

        let versions = manager.list_versions("test-model");
        assert_eq!(versions.len(), 1);
    }

    #[test]
    fn test_version_manager_rejects_duplicate_version() {
        let manager = manager_with_versions(
            VersioningConfig::default(),
            &[(ModelVersion::new(1, 0, 0), true)],
        );

        let duplicate = ModelMetadata::new("test-model", ModelVersion::new(1, 0, 0), "transformer");
        assert!(manager.register_version(duplicate, false).is_err());
        assert_eq!(manager.list_versions("test-model").len(), 1);
    }

    #[test]
    fn test_version_manager_lists_newest_first() {
        let manager = manager_with_versions(
            VersioningConfig::default(),
            &[
                (ModelVersion::new(1, 0, 0), true),
                (ModelVersion::new(2, 0, 0), false),
                (ModelVersion::new(1, 5, 0), true),
            ],
        );

        assert_eq!(
            manager.list_versions("test-model"),
            vec![
                ModelVersion::new(2, 0, 0),
                ModelVersion::new(1, 5, 0),
                ModelVersion::new(1, 0, 0),
            ]
        );
        assert!(manager.list_versions("unknown-model").is_empty());
    }

    #[test]
    fn test_version_manager_active_version() {
        let config = VersioningConfig::default();
        let manager = ModelVersionManager::new(config);

        let v1 = ModelVersion::new(1, 0, 0);
        let metadata = ModelMetadata::new("test-model", v1.clone(), "transformer");
        manager.register_version(metadata, true).unwrap();

        manager
            .set_active_version("test-model", v1.clone())
            .unwrap();

        let active = manager.get_active_version("test-model");
        assert_eq!(active, Some(v1));
    }

    #[test]
    fn test_version_manager_set_active_version_unknown() {
        let manager = manager_with_versions(
            VersioningConfig::default(),
            &[(ModelVersion::new(1, 0, 0), true)],
        );

        assert!(manager
            .set_active_version("test-model", ModelVersion::new(9, 9, 9))
            .is_err());
        assert!(manager
            .set_active_version("other-model", ModelVersion::new(1, 0, 0))
            .is_err());
        assert_eq!(manager.get_active_version("test-model"), None);
    }

    #[test]
    fn test_version_manager_record_request() {
        let config = VersioningConfig::default();
        let manager = ModelVersionManager::new(config);

        let v1 = ModelVersion::new(1, 0, 0);
        let metadata = ModelMetadata::new("test-model", v1.clone(), "transformer");
        manager.register_version(metadata, true).unwrap();

        manager
            .record_request("test-model", &v1, 100.0, false)
            .unwrap();
        manager
            .record_request("test-model", &v1, 200.0, true)
            .unwrap();

        let stats = manager.get_stats("test-model", &v1).unwrap();
        assert_eq!(stats.total_requests, 2);
        assert_eq!(stats.total_errors, 1);
    }

    #[test]
    fn test_version_manager_record_request_unknown_version() {
        let manager = manager_with_versions(
            VersioningConfig::default(),
            &[(ModelVersion::new(1, 0, 0), true)],
        );

        assert!(manager
            .record_request("test-model", &ModelVersion::new(2, 0, 0), 10.0, false)
            .is_err());
        assert!(manager
            .record_request("other-model", &ModelVersion::new(1, 0, 0), 10.0, false)
            .is_err());
        assert!(manager
            .get_stats("test-model", &ModelVersion::new(2, 0, 0))
            .is_none());
    }

    #[test]
    fn test_version_manager_fallback() {
        let config = VersioningConfig {
            fallback_strategy: FallbackStrategy::PreviousStable,
            ..Default::default()
        };
        let manager = ModelVersionManager::new(config);

        let v1 = ModelVersion::new(1, 0, 0);
        let v2 = ModelVersion::new(1, 1, 0);

        let metadata1 = ModelMetadata::new("test-model", v1.clone(), "transformer");
        let metadata2 = ModelMetadata::new("test-model", v2.clone(), "transformer");

        manager.register_version(metadata1, true).unwrap(); // Stable
        manager.register_version(metadata2, false).unwrap(); // Not stable

        let fallback = manager.get_fallback_version("test-model", &v2);
        assert_eq!(fallback, Some(v1));
    }

    #[test]
    fn test_version_manager_fallback_next_version() {
        let config = VersioningConfig {
            fallback_strategy: FallbackStrategy::NextVersion,
            ..Default::default()
        };
        let manager = manager_with_versions(
            config,
            &[
                (ModelVersion::new(1, 0, 0), true),
                (ModelVersion::new(1, 1, 0), false),
                (ModelVersion::new(2, 0, 0), false),
            ],
        );

        assert_eq!(
            manager.get_fallback_version("test-model", &ModelVersion::new(2, 0, 0)),
            Some(ModelVersion::new(1, 1, 0))
        );
        // Nothing lower than the oldest version.
        assert_eq!(
            manager.get_fallback_version("test-model", &ModelVersion::new(1, 0, 0)),
            None
        );
    }

    #[test]
    fn test_version_manager_no_fallback() {
        let config = VersioningConfig {
            fallback_strategy: FallbackStrategy::NoFallback,
            ..Default::default()
        };
        let manager = manager_with_versions(
            config,
            &[
                (ModelVersion::new(1, 0, 0), true),
                (ModelVersion::new(1, 1, 0), true),
            ],
        );

        assert_eq!(
            manager.get_fallback_version("test-model", &ModelVersion::new(1, 1, 0)),
            None
        );
    }

    #[test]
    fn test_version_manager_health_check() {
        let config = VersioningConfig::default();
        let manager = ModelVersionManager::new(config);

        let v1 = ModelVersion::new(1, 0, 0);
        let metadata = ModelMetadata::new("test-model", v1.clone(), "transformer");
        manager.register_version(metadata, true).unwrap();

        // Record some successful requests
        manager
            .record_request("test-model", &v1, 50.0, false)
            .unwrap();
        manager
            .record_request("test-model", &v1, 60.0, false)
            .unwrap();

        let health = manager.health_check("test-model", &v1).unwrap();
        assert_eq!(health.status, HealthStatus::Healthy);
    }

    #[test]
    fn test_version_manager_health_check_unhealthy() {
        let v1 = ModelVersion::new(1, 0, 0);
        let manager = manager_with_versions(VersioningConfig::default(), &[(v1.clone(), true)]);

        // 3 errors out of 4 requests -> error rate 0.75.
        manager.record_request("test-model", &v1, 10.0, false).unwrap();
        for _ in 0..3 {
            manager.record_request("test-model", &v1, 10.0, true).unwrap();
        }

        let health = manager.health_check("test-model", &v1).unwrap();
        assert_eq!(health.status, HealthStatus::Unhealthy);
        assert!(!health.is_usable());

        let stats = manager.get_stats("test-model", &v1).unwrap();
        assert_eq!(
            stats.last_health_check.map(|c| c.status),
            Some(HealthStatus::Unhealthy)
        );
    }
}
683}