1use crate::error::{InferenceError, InferenceResult};
7use serde::{Deserialize, Serialize};
8use std::collections::HashMap;
9use std::sync::{Arc, RwLock};
10use std::time::{Duration, Instant};
11
12#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Serialize, Deserialize)]
14pub struct ModelVersion {
15 pub major: u32,
17 pub minor: u32,
19 pub patch: u32,
21}
22
23impl ModelVersion {
24 pub fn new(major: u32, minor: u32, patch: u32) -> Self {
26 Self {
27 major,
28 minor,
29 patch,
30 }
31 }
32
33 pub fn parse(s: &str) -> InferenceResult<Self> {
35 let parts: Vec<&str> = s.split('.').collect();
36 if parts.len() != 3 {
37 return Err(InferenceError::ForwardError(format!(
38 "Invalid version format: {}",
39 s
40 )));
41 }
42
43 let major = parts[0].parse().map_err(|_| {
44 InferenceError::ForwardError(format!("Invalid major version: {}", parts[0]))
45 })?;
46 let minor = parts[1].parse().map_err(|_| {
47 InferenceError::ForwardError(format!("Invalid minor version: {}", parts[1]))
48 })?;
49 let patch = parts[2].parse().map_err(|_| {
50 InferenceError::ForwardError(format!("Invalid patch version: {}", parts[2]))
51 })?;
52
53 Ok(Self::new(major, minor, patch))
54 }
55
56 pub fn is_compatible_with(&self, other: &ModelVersion) -> bool {
58 self.major == other.major
59 }
60
61 }
63
64impl std::fmt::Display for ModelVersion {
65 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
66 write!(f, "{}.{}.{}", self.major, self.minor, self.patch)
67 }
68}
69
70#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
72pub enum HealthStatus {
73 Healthy,
75 Degraded,
77 Unhealthy,
79}
80
81#[derive(Debug, Clone)]
83pub struct HealthCheck {
84 pub status: HealthStatus,
86 pub timestamp: Instant,
88 pub avg_latency_ms: f64,
90 pub error_rate: f64,
92 pub request_count: usize,
94 pub details: String,
96}
97
98impl HealthCheck {
99 pub fn new(status: HealthStatus) -> Self {
101 Self {
102 status,
103 timestamp: Instant::now(),
104 avg_latency_ms: 0.0,
105 error_rate: 0.0,
106 request_count: 0,
107 details: String::new(),
108 }
109 }
110
111 pub fn is_usable(&self) -> bool {
113 matches!(self.status, HealthStatus::Healthy | HealthStatus::Degraded)
114 }
115}
116
117#[derive(Debug, Clone, Serialize, Deserialize)]
119pub struct ModelMetadata {
120 pub id: String,
122 pub version: ModelVersion,
124 pub architecture: String,
126 pub created_at: String,
128 pub checksum: Option<String>,
130 pub extra: HashMap<String, String>,
132}
133
134impl ModelMetadata {
135 pub fn new(
137 id: impl Into<String>,
138 version: ModelVersion,
139 architecture: impl Into<String>,
140 ) -> Self {
141 Self {
142 id: id.into(),
143 version,
144 architecture: architecture.into(),
145 created_at: chrono::Utc::now().to_rfc3339(),
146 checksum: None,
147 extra: HashMap::new(),
148 }
149 }
150
151 pub fn with_checksum(mut self, checksum: impl Into<String>) -> Self {
153 self.checksum = Some(checksum.into());
154 self
155 }
156
157 pub fn add_metadata(mut self, key: impl Into<String>, value: impl Into<String>) -> Self {
159 self.extra.insert(key.into(), value.into());
160 self
161 }
162}
163
164#[derive(Debug, Clone, Default)]
166pub struct ModelStats {
167 pub total_requests: usize,
169 pub total_errors: usize,
171 pub total_latency_ms: f64,
173 pub last_health_check: Option<HealthCheck>,
175}
176
177impl ModelStats {
178 pub fn record_success(&mut self, latency_ms: f64) {
180 self.total_requests += 1;
181 self.total_latency_ms += latency_ms;
182 }
183
184 pub fn record_error(&mut self, latency_ms: f64) {
186 self.total_requests += 1;
187 self.total_errors += 1;
188 self.total_latency_ms += latency_ms;
189 }
190
191 pub fn avg_latency_ms(&self) -> f64 {
193 if self.total_requests == 0 {
194 0.0
195 } else {
196 self.total_latency_ms / self.total_requests as f64
197 }
198 }
199
200 pub fn error_rate(&self) -> f64 {
202 if self.total_requests == 0 {
203 0.0
204 } else {
205 self.total_errors as f64 / self.total_requests as f64
206 }
207 }
208
209 pub fn to_health_check(&self) -> HealthCheck {
211 let error_rate = self.error_rate();
212 let avg_latency = self.avg_latency_ms();
213
214 let status = if error_rate > 0.5 {
215 HealthStatus::Unhealthy
216 } else if error_rate > 0.1 || avg_latency > 1000.0 {
217 HealthStatus::Degraded
218 } else {
219 HealthStatus::Healthy
220 };
221
222 HealthCheck {
223 status,
224 timestamp: Instant::now(),
225 avg_latency_ms: avg_latency,
226 error_rate,
227 request_count: self.total_requests,
228 details: format!(
229 "Requests: {}, Errors: {}, Avg Latency: {:.2}ms",
230 self.total_requests, self.total_errors, avg_latency
231 ),
232 }
233 }
234}
235
236#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
238pub enum FallbackStrategy {
239 NextVersion,
241 PreviousStable,
243 SpecificVersion,
245 NoFallback,
247}
248
249#[derive(Debug, Clone, Serialize, Deserialize)]
251pub struct VersioningConfig {
252 pub fallback_strategy: FallbackStrategy,
254 pub health_check_interval: Duration,
256 pub max_error_rate: f64,
258 pub max_latency_ms: f64,
260 pub auto_recovery: bool,
262}
263
264impl Default for VersioningConfig {
265 fn default() -> Self {
266 Self {
267 fallback_strategy: FallbackStrategy::PreviousStable,
268 health_check_interval: Duration::from_secs(60),
269 max_error_rate: 0.1,
270 max_latency_ms: 1000.0,
271 auto_recovery: true,
272 }
273 }
274}
275
276struct VersionedModelEntry {
278 metadata: ModelMetadata,
279 stats: ModelStats,
280 is_stable: bool,
281}
282
283pub struct ModelVersionManager {
285 models: Arc<RwLock<HashMap<String, Vec<VersionedModelEntry>>>>,
287 active_versions: Arc<RwLock<HashMap<String, ModelVersion>>>,
289 config: VersioningConfig,
291}
292
293impl ModelVersionManager {
294 pub fn new(config: VersioningConfig) -> Self {
296 Self {
297 models: Arc::new(RwLock::new(HashMap::new())),
298 active_versions: Arc::new(RwLock::new(HashMap::new())),
299 config,
300 }
301 }
302
303 pub fn register_version(
305 &self,
306 metadata: ModelMetadata,
307 is_stable: bool,
308 ) -> InferenceResult<()> {
309 let mut models = self.models.write().map_err(|e| {
310 InferenceError::LockError(format!("Failed to acquire write lock: {}", e))
311 })?;
312 let entries = models.entry(metadata.id.clone()).or_default();
313
314 if entries
316 .iter()
317 .any(|e| e.metadata.version == metadata.version)
318 {
319 return Err(InferenceError::ForwardError(format!(
320 "Version {} already registered for model {}",
321 metadata.version, metadata.id
322 )));
323 }
324
325 entries.push(VersionedModelEntry {
326 metadata,
327 stats: ModelStats::default(),
328 is_stable,
329 });
330
331 entries.sort_by(|a, b| b.metadata.version.cmp(&a.metadata.version));
333
334 Ok(())
335 }
336
337 pub fn get_active_version(&self, model_id: &str) -> Option<ModelVersion> {
339 let active = self.active_versions.read().ok()?;
340 active.get(model_id).cloned()
341 }
342
343 pub fn set_active_version(&self, model_id: &str, version: ModelVersion) -> InferenceResult<()> {
345 let models = self.models.read().map_err(|e| {
347 InferenceError::LockError(format!("Failed to acquire read lock: {}", e))
348 })?;
349 let entries = models
350 .get(model_id)
351 .ok_or_else(|| InferenceError::ForwardError(format!("Model {} not found", model_id)))?;
352
353 if !entries.iter().any(|e| e.metadata.version == version) {
354 return Err(InferenceError::ForwardError(format!(
355 "Version {} not found for model {}",
356 version, model_id
357 )));
358 }
359
360 let mut active = self.active_versions.write().map_err(|e| {
361 InferenceError::LockError(format!("Failed to acquire write lock: {}", e))
362 })?;
363 active.insert(model_id.to_string(), version);
364
365 Ok(())
366 }
367
368 pub fn record_request(
370 &self,
371 model_id: &str,
372 version: &ModelVersion,
373 latency_ms: f64,
374 is_error: bool,
375 ) -> InferenceResult<()> {
376 let mut models = self.models.write().map_err(|e| {
377 InferenceError::LockError(format!("Failed to acquire write lock: {}", e))
378 })?;
379 let entries = models
380 .get_mut(model_id)
381 .ok_or_else(|| InferenceError::ForwardError(format!("Model {} not found", model_id)))?;
382
383 let entry = entries
384 .iter_mut()
385 .find(|e| &e.metadata.version == version)
386 .ok_or_else(|| {
387 InferenceError::ForwardError(format!(
388 "Version {} not found for model {}",
389 version, model_id
390 ))
391 })?;
392
393 if is_error {
394 entry.stats.record_error(latency_ms);
395 } else {
396 entry.stats.record_success(latency_ms);
397 }
398
399 Ok(())
400 }
401
402 pub fn health_check(
404 &self,
405 model_id: &str,
406 version: &ModelVersion,
407 ) -> InferenceResult<HealthCheck> {
408 let mut models = self.models.write().map_err(|e| {
409 InferenceError::LockError(format!("Failed to acquire write lock: {}", e))
410 })?;
411 let entries = models
412 .get_mut(model_id)
413 .ok_or_else(|| InferenceError::ForwardError(format!("Model {} not found", model_id)))?;
414
415 let entry = entries
416 .iter_mut()
417 .find(|e| &e.metadata.version == version)
418 .ok_or_else(|| {
419 InferenceError::ForwardError(format!(
420 "Version {} not found for model {}",
421 version, model_id
422 ))
423 })?;
424
425 let health_check = entry.stats.to_health_check();
426 entry.stats.last_health_check = Some(health_check.clone());
427
428 Ok(health_check)
429 }
430
431 pub fn get_fallback_version(
433 &self,
434 model_id: &str,
435 current_version: &ModelVersion,
436 ) -> Option<ModelVersion> {
437 let models = self.models.read().ok()?;
438 let entries = models.get(model_id)?;
439
440 match self.config.fallback_strategy {
441 FallbackStrategy::NextVersion => {
442 entries
444 .iter()
445 .filter(|e| e.metadata.version < *current_version)
446 .map(|e| e.metadata.version.clone())
447 .next()
448 }
449 FallbackStrategy::PreviousStable => {
450 entries
452 .iter()
453 .filter(|e| e.is_stable && e.metadata.version != *current_version)
454 .map(|e| e.metadata.version.clone())
455 .next()
456 }
457 FallbackStrategy::NoFallback => None,
458 FallbackStrategy::SpecificVersion => {
459 None
461 }
462 }
463 }
464
465 pub fn list_versions(&self, model_id: &str) -> Vec<ModelVersion> {
467 let models = match self.models.read() {
468 Ok(m) => m,
469 Err(_) => return Vec::new(),
470 };
471 models
472 .get(model_id)
473 .map(|entries| entries.iter().map(|e| e.metadata.version.clone()).collect())
474 .unwrap_or_default()
475 }
476
477 pub fn get_metadata(&self, model_id: &str, version: &ModelVersion) -> Option<ModelMetadata> {
479 let models = self.models.read().ok()?;
480 models.get(model_id).and_then(|entries| {
481 entries
482 .iter()
483 .find(|e| &e.metadata.version == version)
484 .map(|e| e.metadata.clone())
485 })
486 }
487
488 pub fn get_stats(&self, model_id: &str, version: &ModelVersion) -> Option<ModelStats> {
490 let models = self.models.read().ok()?;
491 models.get(model_id).and_then(|entries| {
492 entries
493 .iter()
494 .find(|e| &e.metadata.version == version)
495 .map(|e| e.stats.clone())
496 })
497 }
498}
499
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a manager with the given fallback strategy and default thresholds.
    fn manager_with(strategy: FallbackStrategy) -> ModelVersionManager {
        ModelVersionManager::new(VersioningConfig {
            fallback_strategy: strategy,
            ..Default::default()
        })
    }

    /// Registers `version` of "test-model" on `manager`.
    fn register(manager: &ModelVersionManager, version: &ModelVersion, is_stable: bool) {
        let metadata = ModelMetadata::new("test-model", version.clone(), "transformer");
        manager.register_version(metadata, is_stable).unwrap();
    }

    #[test]
    fn test_model_version_creation() {
        let version = ModelVersion::new(1, 2, 3);
        assert_eq!(version.major, 1);
        assert_eq!(version.minor, 2);
        assert_eq!(version.patch, 3);
    }

    #[test]
    fn test_model_version_parse() {
        let version = ModelVersion::parse("1.2.3").unwrap();
        assert_eq!(version.major, 1);
        assert_eq!(version.minor, 2);
        assert_eq!(version.patch, 3);
    }

    #[test]
    fn test_model_version_parse_invalid() {
        assert!(ModelVersion::parse("1.2").is_err());
        assert!(ModelVersion::parse("1.2.x").is_err());
    }

    #[test]
    fn test_model_version_parse_rejects_malformed() {
        assert!(ModelVersion::parse("").is_err());
        assert!(ModelVersion::parse("1.2.3.4").is_err());
        assert!(ModelVersion::parse("1..3").is_err());
        assert!(ModelVersion::parse("-1.0.0").is_err());
    }

    #[test]
    fn test_model_version_display_round_trip() {
        let version = ModelVersion::new(3, 14, 159);
        assert_eq!(version.to_string(), "3.14.159");
        assert_eq!(ModelVersion::parse(&version.to_string()).unwrap(), version);
    }

    #[test]
    fn test_model_version_compatibility() {
        let v1 = ModelVersion::new(1, 2, 3);
        let v2 = ModelVersion::new(1, 3, 0);
        let v3 = ModelVersion::new(2, 0, 0);

        assert!(v1.is_compatible_with(&v2));
        assert!(!v1.is_compatible_with(&v3));
    }

    #[test]
    fn test_model_version_ordering() {
        let v1 = ModelVersion::new(1, 0, 0);
        let v2 = ModelVersion::new(1, 1, 0);
        let v3 = ModelVersion::new(2, 0, 0);

        assert!(v1 < v2);
        assert!(v2 < v3);
        assert!(v1 < v3);
    }

    #[test]
    fn test_health_status() {
        let check = HealthCheck::new(HealthStatus::Healthy);
        assert!(check.is_usable());

        let check = HealthCheck::new(HealthStatus::Degraded);
        assert!(check.is_usable());

        let check = HealthCheck::new(HealthStatus::Unhealthy);
        assert!(!check.is_usable());
    }

    #[test]
    fn test_metadata_builders() {
        let metadata = ModelMetadata::new("m", ModelVersion::new(1, 0, 0), "cnn")
            .with_checksum("abc123")
            .add_metadata("owner", "team-a");
        assert_eq!(metadata.checksum, Some("abc123".to_string()));
        assert_eq!(metadata.extra.get("owner"), Some(&"team-a".to_string()));
        assert_eq!(metadata.architecture, "cnn");
    }

    #[test]
    fn test_model_stats() {
        let mut stats = ModelStats::default();

        stats.record_success(100.0);
        stats.record_success(200.0);
        stats.record_error(300.0);

        assert_eq!(stats.total_requests, 3);
        assert_eq!(stats.total_errors, 1);
        assert_eq!(stats.avg_latency_ms(), 200.0);
        assert_eq!(stats.error_rate(), 1.0 / 3.0);
    }

    #[test]
    fn test_model_stats_health_check() {
        let mut stats = ModelStats::default();

        stats.record_success(50.0);
        stats.record_success(60.0);
        let check = stats.to_health_check();
        assert_eq!(check.status, HealthStatus::Healthy);

        stats.record_error(100.0);
        stats.record_error(100.0);
        stats.record_error(100.0);
        let check = stats.to_health_check();
        assert_eq!(check.status, HealthStatus::Unhealthy);
    }

    #[test]
    fn test_version_manager_register() {
        let manager = ModelVersionManager::new(VersioningConfig::default());

        let metadata = ModelMetadata::new("test-model", ModelVersion::new(1, 0, 0), "transformer");
        manager.register_version(metadata, true).unwrap();

        let versions = manager.list_versions("test-model");
        assert_eq!(versions.len(), 1);
    }

    #[test]
    fn test_version_manager_rejects_duplicate_version() {
        let manager = manager_with(FallbackStrategy::PreviousStable);
        let v1 = ModelVersion::new(1, 0, 0);
        register(&manager, &v1, true);

        let duplicate = ModelMetadata::new("test-model", v1, "transformer");
        assert!(manager.register_version(duplicate, false).is_err());
        assert_eq!(manager.list_versions("test-model").len(), 1);
    }

    #[test]
    fn test_version_manager_lists_newest_first() {
        let manager = manager_with(FallbackStrategy::PreviousStable);
        for version in &[
            ModelVersion::new(1, 0, 0),
            ModelVersion::new(2, 0, 0),
            ModelVersion::new(1, 5, 0),
        ] {
            register(&manager, version, true);
        }

        assert_eq!(
            manager.list_versions("test-model"),
            vec![
                ModelVersion::new(2, 0, 0),
                ModelVersion::new(1, 5, 0),
                ModelVersion::new(1, 0, 0),
            ]
        );
        assert!(manager.list_versions("unknown-model").is_empty());
    }

    #[test]
    fn test_version_manager_active_version() {
        let manager = ModelVersionManager::new(VersioningConfig::default());

        let v1 = ModelVersion::new(1, 0, 0);
        let metadata = ModelMetadata::new("test-model", v1.clone(), "transformer");
        manager.register_version(metadata, true).unwrap();

        manager.set_active_version("test-model", v1.clone()).unwrap();

        let active = manager.get_active_version("test-model");
        assert_eq!(active, Some(v1));
    }

    #[test]
    fn test_version_manager_unknown_lookups_fail() {
        let manager = manager_with(FallbackStrategy::PreviousStable);
        let v1 = ModelVersion::new(1, 0, 0);
        register(&manager, &v1, true);
        let missing = ModelVersion::new(9, 9, 9);

        assert!(manager.set_active_version("test-model", missing.clone()).is_err());
        assert!(manager.set_active_version("unknown-model", v1.clone()).is_err());
        assert!(manager.record_request("test-model", &missing, 1.0, false).is_err());
        assert!(manager.record_request("unknown-model", &v1, 1.0, false).is_err());
        assert!(manager.health_check("unknown-model", &v1).is_err());
        assert_eq!(manager.get_active_version("test-model"), None);
        assert!(manager.get_stats("test-model", &missing).is_none());
        assert!(manager.get_metadata("test-model", &missing).is_none());
    }

    #[test]
    fn test_version_manager_record_request() {
        let manager = ModelVersionManager::new(VersioningConfig::default());

        let v1 = ModelVersion::new(1, 0, 0);
        let metadata = ModelMetadata::new("test-model", v1.clone(), "transformer");
        manager.register_version(metadata, true).unwrap();

        manager.record_request("test-model", &v1, 100.0, false).unwrap();
        manager.record_request("test-model", &v1, 200.0, true).unwrap();

        let stats = manager.get_stats("test-model", &v1).unwrap();
        assert_eq!(stats.total_requests, 2);
        assert_eq!(stats.total_errors, 1);
    }

    #[test]
    fn test_version_manager_fallback() {
        let manager = manager_with(FallbackStrategy::PreviousStable);

        let v1 = ModelVersion::new(1, 0, 0);
        let v2 = ModelVersion::new(1, 1, 0);

        let metadata1 = ModelMetadata::new("test-model", v1.clone(), "transformer");
        let metadata2 = ModelMetadata::new("test-model", v2.clone(), "transformer");

        manager.register_version(metadata1, true).unwrap();
        manager.register_version(metadata2, false).unwrap();

        let fallback = manager.get_fallback_version("test-model", &v2);
        assert_eq!(fallback, Some(v1));
    }

    #[test]
    fn test_version_manager_next_version_fallback() {
        let manager = manager_with(FallbackStrategy::NextVersion);
        let v1 = ModelVersion::new(1, 0, 0);
        let v11 = ModelVersion::new(1, 1, 0);
        let v2 = ModelVersion::new(2, 0, 0);
        register(&manager, &v1, true);
        register(&manager, &v11, false);
        register(&manager, &v2, false);

        assert_eq!(manager.get_fallback_version("test-model", &v2), Some(v11.clone()));
        assert_eq!(manager.get_fallback_version("test-model", &v11), Some(v1.clone()));
        assert_eq!(manager.get_fallback_version("test-model", &v1), None);
        assert_eq!(manager.get_fallback_version("unknown-model", &v2), None);
    }

    #[test]
    fn test_version_manager_no_fallback() {
        let manager = manager_with(FallbackStrategy::NoFallback);
        let v1 = ModelVersion::new(1, 0, 0);
        let v2 = ModelVersion::new(2, 0, 0);
        register(&manager, &v1, true);
        register(&manager, &v2, true);

        assert_eq!(manager.get_fallback_version("test-model", &v2), None);
    }

    #[test]
    fn test_version_manager_health_check() {
        let manager = ModelVersionManager::new(VersioningConfig::default());

        let v1 = ModelVersion::new(1, 0, 0);
        let metadata = ModelMetadata::new("test-model", v1.clone(), "transformer");
        manager.register_version(metadata, true).unwrap();

        manager.record_request("test-model", &v1, 50.0, false).unwrap();
        manager.record_request("test-model", &v1, 60.0, false).unwrap();

        let health = manager.health_check("test-model", &v1).unwrap();
        assert_eq!(health.status, HealthStatus::Healthy);

        // The result is also stored on the version's stats.
        let stored = manager
            .get_stats("test-model", &v1)
            .unwrap()
            .last_health_check
            .map(|check| check.status);
        assert_eq!(stored, Some(HealthStatus::Healthy));
    }
}