oxirs_embed/
model_registry.rs

1//! Model Registry and Versioning System
2//!
3//! This module provides a comprehensive model lifecycle management system including
4//! versioning, deployment, performance tracking, and A/B testing capabilities.
5
6use crate::{EmbeddingModel, ModelConfig};
7use anyhow::{anyhow, Result};
8use chrono::{DateTime, Utc};
9// Removed unused import
10use serde::{Deserialize, Serialize};
11use std::collections::HashMap;
12use std::path::PathBuf;
13use std::sync::Arc;
14use tokio::sync::RwLock;
15// Removed unused import
16// Removed unused import
17use uuid::Uuid;
18
19/// Model version information
20#[derive(Debug, Clone, Serialize, Deserialize)]
21pub struct ModelVersion {
22    pub version_id: Uuid,
23    pub model_id: Uuid,
24    pub version_number: String,
25    pub created_at: DateTime<Utc>,
26    pub created_by: String,
27    pub description: String,
28    pub tags: Vec<String>,
29    pub metrics: HashMap<String, f64>,
30    pub config: ModelConfig,
31    pub is_production: bool,
32    pub is_deprecated: bool,
33}
34
35/// Model deployment status
36#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq)]
37pub enum DeploymentStatus {
38    NotDeployed,
39    Deploying,
40    Deployed,
41    Failed,
42    Retiring,
43    Retired,
44}
45
46/// Model deployment information
47#[derive(Debug, Clone, Serialize, Deserialize)]
48pub struct ModelDeployment {
49    pub deployment_id: Uuid,
50    pub version_id: Uuid,
51    pub status: DeploymentStatus,
52    pub deployed_at: Option<DateTime<Utc>>,
53    pub endpoint: Option<String>,
54    pub resource_allocation: ResourceAllocation,
55    pub health_check_url: Option<String>,
56    pub rollback_version: Option<Uuid>,
57}
58
59/// Resource allocation for model deployment
60#[derive(Debug, Clone, Serialize, Deserialize)]
61pub struct ResourceAllocation {
62    pub cpu_cores: f32,
63    pub memory_gb: f32,
64    pub gpu_count: u32,
65    pub gpu_memory_gb: f32,
66    pub max_concurrent_requests: usize,
67}
68
69impl Default for ResourceAllocation {
70    fn default() -> Self {
71        Self {
72            cpu_cores: 2.0,
73            memory_gb: 4.0,
74            gpu_count: 0,
75            gpu_memory_gb: 0.0,
76            max_concurrent_requests: 100,
77        }
78    }
79}
80
81/// A/B test configuration
82#[derive(Debug, Clone, Serialize, Deserialize)]
83pub struct ABTestConfig {
84    pub test_id: Uuid,
85    pub name: String,
86    pub description: String,
87    pub version_a: Uuid,
88    pub version_b: Uuid,
89    pub traffic_split: f32, // Percentage going to version B (0.0 - 1.0)
90    pub started_at: DateTime<Utc>,
91    pub ends_at: Option<DateTime<Utc>>,
92    pub metrics_to_track: Vec<String>,
93    pub is_active: bool,
94}
95
96/// Model performance metrics
97#[derive(Debug, Clone, Serialize, Deserialize)]
98pub struct PerformanceMetrics {
99    pub timestamp: DateTime<Utc>,
100    pub latency_p50_ms: f64,
101    pub latency_p95_ms: f64,
102    pub latency_p99_ms: f64,
103    pub throughput_qps: f64,
104    pub error_rate: f64,
105    pub cpu_utilization: f64,
106    pub memory_utilization: f64,
107    pub gpu_utilization: Option<f64>,
108    pub cache_hit_rate: f64,
109}
110
111/// Model registry for managing model lifecycle
112pub struct ModelRegistry {
113    models: Arc<RwLock<HashMap<Uuid, ModelMetadata>>>,
114    versions: Arc<RwLock<HashMap<Uuid, ModelVersion>>>,
115    deployments: Arc<RwLock<HashMap<Uuid, ModelDeployment>>>,
116    ab_tests: Arc<RwLock<HashMap<Uuid, ABTestConfig>>>,
117    performance_history: Arc<RwLock<HashMap<Uuid, Vec<PerformanceMetrics>>>>,
118    #[allow(dead_code)]
119    storage_path: PathBuf,
120}
121
122/// Model metadata
123#[derive(Debug, Clone, Serialize, Deserialize)]
124pub struct ModelMetadata {
125    pub model_id: Uuid,
126    pub name: String,
127    pub model_type: String,
128    pub created_at: DateTime<Utc>,
129    pub updated_at: DateTime<Utc>,
130    pub owner: String,
131    pub description: String,
132    pub versions: Vec<Uuid>,
133    pub production_version: Option<Uuid>,
134    pub staging_version: Option<Uuid>,
135}
136
137impl ModelRegistry {
138    /// Create a new model registry
139    pub fn new(storage_path: PathBuf) -> Self {
140        Self {
141            models: Arc::new(RwLock::new(HashMap::new())),
142            versions: Arc::new(RwLock::new(HashMap::new())),
143            deployments: Arc::new(RwLock::new(HashMap::new())),
144            ab_tests: Arc::new(RwLock::new(HashMap::new())),
145            performance_history: Arc::new(RwLock::new(HashMap::new())),
146            storage_path,
147        }
148    }
149
150    /// Register a new model
151    pub async fn register_model(
152        &self,
153        name: String,
154        model_type: String,
155        owner: String,
156        description: String,
157    ) -> Result<Uuid> {
158        let model_id = Uuid::new_v4();
159        let metadata = ModelMetadata {
160            model_id,
161            name,
162            model_type,
163            created_at: Utc::now(),
164            updated_at: Utc::now(),
165            owner,
166            description,
167            versions: Vec::new(),
168            production_version: None,
169            staging_version: None,
170        };
171
172        self.models.write().await.insert(model_id, metadata);
173        Ok(model_id)
174    }
175
176    /// Register a new model version
177    pub async fn register_version(
178        &self,
179        model_id: Uuid,
180        version_number: String,
181        created_by: String,
182        description: String,
183        config: ModelConfig,
184        metrics: HashMap<String, f64>,
185    ) -> Result<Uuid> {
186        let version_id = Uuid::new_v4();
187
188        // Verify model exists
189        let mut models = self.models.write().await;
190        let model = models
191            .get_mut(&model_id)
192            .ok_or_else(|| anyhow!("Model not found: {}", model_id))?;
193
194        let version = ModelVersion {
195            version_id,
196            model_id,
197            version_number,
198            created_at: Utc::now(),
199            created_by,
200            description,
201            tags: Vec::new(),
202            metrics,
203            config,
204            is_production: false,
205            is_deprecated: false,
206        };
207
208        model.versions.push(version_id);
209        model.updated_at = Utc::now();
210
211        self.versions.write().await.insert(version_id, version);
212        Ok(version_id)
213    }
214
215    /// Deploy a model version
216    pub async fn deploy_version(
217        &self,
218        version_id: Uuid,
219        resource_allocation: ResourceAllocation,
220    ) -> Result<Uuid> {
221        // Verify version exists
222        if !self.versions.read().await.contains_key(&version_id) {
223            return Err(anyhow!("Version not found: {}", version_id));
224        }
225
226        let deployment_id = Uuid::new_v4();
227        let deployment = ModelDeployment {
228            deployment_id,
229            version_id,
230            status: DeploymentStatus::Deploying,
231            deployed_at: None,
232            endpoint: None,
233            resource_allocation,
234            health_check_url: None,
235            rollback_version: None,
236        };
237
238        self.deployments
239            .write()
240            .await
241            .insert(deployment_id, deployment);
242
243        // Start deployment process (in real implementation)
244        self.start_deployment(deployment_id).await?;
245
246        Ok(deployment_id)
247    }
248
249    /// Start deployment process
250    async fn start_deployment(&self, deployment_id: Uuid) -> Result<()> {
251        // In a real implementation, this would:
252        // 1. Allocate resources
253        // 2. Load model weights
254        // 3. Start serving infrastructure
255        // 4. Configure load balancer
256        // 5. Run health checks
257
258        // For now, simulate deployment
259        tokio::time::sleep(tokio::time::Duration::from_secs(2)).await;
260
261        let mut deployments = self.deployments.write().await;
262        if let Some(deployment) = deployments.get_mut(&deployment_id) {
263            deployment.status = DeploymentStatus::Deployed;
264            deployment.deployed_at = Some(Utc::now());
265            deployment.endpoint = Some(format!("https://api.oxirs.ai/v1/embed/{deployment_id}"));
266            deployment.health_check_url = Some(format!(
267                "https://api.oxirs.ai/v1/embed/{deployment_id}/health"
268            ));
269        }
270
271        Ok(())
272    }
273
274    /// Promote version to production
275    pub async fn promote_to_production(&self, version_id: Uuid) -> Result<()> {
276        let versions = self.versions.read().await;
277        let version = versions
278            .get(&version_id)
279            .ok_or_else(|| anyhow!("Version not found: {}", version_id))?;
280
281        let model_id = version.model_id;
282        drop(versions);
283
284        let mut models = self.models.write().await;
285        let model = models
286            .get_mut(&model_id)
287            .ok_or_else(|| anyhow!("Model not found: {}", model_id))?;
288
289        // Mark previous production version as non-production
290        if let Some(prev_prod) = model.production_version {
291            let mut versions = self.versions.write().await;
292            if let Some(prev_version) = versions.get_mut(&prev_prod) {
293                prev_version.is_production = false;
294            }
295        }
296
297        model.production_version = Some(version_id);
298        model.updated_at = Utc::now();
299
300        let mut versions = self.versions.write().await;
301        if let Some(version) = versions.get_mut(&version_id) {
302            version.is_production = true;
303        }
304
305        Ok(())
306    }
307
308    /// Create A/B test
309    pub async fn create_ab_test(
310        &self,
311        name: String,
312        description: String,
313        version_a: Uuid,
314        version_b: Uuid,
315        traffic_split: f32,
316        duration_hours: Option<u32>,
317    ) -> Result<Uuid> {
318        // Verify both versions exist
319        let versions = self.versions.read().await;
320        if !versions.contains_key(&version_a) {
321            return Err(anyhow!("Version A not found: {}", version_a));
322        }
323        if !versions.contains_key(&version_b) {
324            return Err(anyhow!("Version B not found: {}", version_b));
325        }
326        drop(versions);
327
328        if !(0.0..=1.0).contains(&traffic_split) {
329            return Err(anyhow!("Traffic split must be between 0.0 and 1.0"));
330        }
331
332        let test_id = Uuid::new_v4();
333        let ab_test = ABTestConfig {
334            test_id,
335            name,
336            description,
337            version_a,
338            version_b,
339            traffic_split,
340            started_at: Utc::now(),
341            ends_at: duration_hours.map(|h| Utc::now() + chrono::Duration::hours(h as i64)),
342            metrics_to_track: vec![
343                "latency_p95".to_string(),
344                "accuracy".to_string(),
345                "error_rate".to_string(),
346            ],
347            is_active: true,
348        };
349
350        self.ab_tests.write().await.insert(test_id, ab_test);
351        Ok(test_id)
352    }
353
354    /// Record performance metrics
355    pub async fn record_performance(
356        &self,
357        version_id: Uuid,
358        metrics: PerformanceMetrics,
359    ) -> Result<()> {
360        let mut history = self.performance_history.write().await;
361        history
362            .entry(version_id)
363            .or_insert_with(Vec::new)
364            .push(metrics);
365
366        // Keep only last 1000 metrics per version
367        if let Some(vec) = history.get_mut(&version_id) {
368            if vec.len() > 1000 {
369                vec.drain(0..vec.len() - 1000);
370            }
371        }
372
373        Ok(())
374    }
375
376    /// Get model metadata
377    pub async fn get_model(&self, model_id: Uuid) -> Result<ModelMetadata> {
378        self.models
379            .read()
380            .await
381            .get(&model_id)
382            .cloned()
383            .ok_or_else(|| anyhow!("Model not found: {}", model_id))
384    }
385
386    /// Get version info
387    pub async fn get_version(&self, version_id: Uuid) -> Result<ModelVersion> {
388        self.versions
389            .read()
390            .await
391            .get(&version_id)
392            .cloned()
393            .ok_or_else(|| anyhow!("Version not found: {}", version_id))
394    }
395
396    /// Get deployment info
397    pub async fn get_deployment(&self, deployment_id: Uuid) -> Result<ModelDeployment> {
398        self.deployments
399            .read()
400            .await
401            .get(&deployment_id)
402            .cloned()
403            .ok_or_else(|| anyhow!("Deployment not found: {}", deployment_id))
404    }
405
406    /// Get performance history
407    pub async fn get_performance_history(
408        &self,
409        version_id: Uuid,
410        limit: Option<usize>,
411    ) -> Result<Vec<PerformanceMetrics>> {
412        let history = self.performance_history.read().await;
413        let metrics = history
414            .get(&version_id)
415            .ok_or_else(|| anyhow!("No performance history for version: {}", version_id))?;
416
417        let limit = limit.unwrap_or(100);
418        let start = metrics.len().saturating_sub(limit);
419
420        Ok(metrics[start..].to_vec())
421    }
422
423    /// Rollback deployment
424    pub async fn rollback_deployment(&self, deployment_id: Uuid) -> Result<()> {
425        let (rollback_version, resource_allocation) = {
426            let deployments = self.deployments.read().await;
427            let deployment = deployments
428                .get(&deployment_id)
429                .ok_or_else(|| anyhow!("Deployment not found: {}", deployment_id))?;
430
431            if let Some(rollback_version) = deployment.rollback_version {
432                (rollback_version, deployment.resource_allocation.clone())
433            } else {
434                return Err(anyhow!("No rollback version configured"));
435            }
436        };
437
438        // Deploy the rollback version
439        self.deploy_version(rollback_version, resource_allocation)
440            .await?;
441
442        // Mark current deployment as retired
443        let mut deployments = self.deployments.write().await;
444        if let Some(deployment) = deployments.get_mut(&deployment_id) {
445            deployment.status = DeploymentStatus::Retired;
446        }
447
448        Ok(())
449    }
450
451    /// List all models
452    pub async fn list_models(&self) -> Vec<ModelMetadata> {
453        self.models.read().await.values().cloned().collect()
454    }
455
456    /// List versions for a model
457    pub async fn list_versions(&self, model_id: Uuid) -> Result<Vec<ModelVersion>> {
458        let models = self.models.read().await;
459        let model = models
460            .get(&model_id)
461            .ok_or_else(|| anyhow!("Model not found: {}", model_id))?;
462
463        let version_ids = model.versions.clone();
464        drop(models);
465
466        let versions = self.versions.read().await;
467        let mut result = Vec::new();
468
469        for version_id in version_ids {
470            if let Some(version) = versions.get(&version_id) {
471                result.push(version.clone());
472            }
473        }
474
475        Ok(result)
476    }
477
478    /// Get active A/B tests
479    pub async fn get_active_ab_tests(&self) -> Vec<ABTestConfig> {
480        self.ab_tests
481            .read()
482            .await
483            .values()
484            .filter(|test| test.is_active)
485            .cloned()
486            .collect()
487    }
488
489    /// End A/B test
490    pub async fn end_ab_test(&self, test_id: Uuid) -> Result<()> {
491        let mut ab_tests = self.ab_tests.write().await;
492        let test = ab_tests
493            .get_mut(&test_id)
494            .ok_or_else(|| anyhow!("A/B test not found: {}", test_id))?;
495
496        test.is_active = false;
497        test.ends_at = Some(Utc::now());
498
499        Ok(())
500    }
501}
502
503/// Model serving infrastructure
504pub struct ModelServer {
505    registry: Arc<ModelRegistry>,
506    #[allow(dead_code)]
507    loaded_models: Arc<RwLock<HashMap<Uuid, Box<dyn EmbeddingModel>>>>,
508    warm_up_cache: Arc<RwLock<HashMap<Uuid, Vec<String>>>>,
509}
510
511impl ModelServer {
512    pub fn new(registry: Arc<ModelRegistry>) -> Self {
513        Self {
514            registry,
515            loaded_models: Arc::new(RwLock::new(HashMap::new())),
516            warm_up_cache: Arc::new(RwLock::new(HashMap::new())),
517        }
518    }
519
520    /// Load model into memory
521    pub async fn load_model(&self, _version_id: Uuid) -> Result<()> {
522        // In real implementation, this would load the actual model
523        // For now, we just mark it as loaded
524        Ok(())
525    }
526
527    /// Warm up model with sample inputs
528    pub async fn warm_up_model(&self, version_id: Uuid, samples: Vec<String>) -> Result<()> {
529        self.warm_up_cache.write().await.insert(version_id, samples);
530
531        // In real implementation, run inference on samples to warm up caches
532        Ok(())
533    }
534
535    /// Get model for inference
536    pub async fn get_model(&self, _version_id: Uuid) -> Result<Arc<Box<dyn EmbeddingModel>>> {
537        // In real implementation, return loaded model
538        Err(anyhow!("Model loading not implemented"))
539    }
540
541    /// Route request based on A/B test
542    pub async fn route_request(&self, test_id: Uuid) -> Result<Uuid> {
543        let ab_tests = self.registry.ab_tests.read().await;
544        let test = ab_tests
545            .get(&test_id)
546            .ok_or_else(|| anyhow!("A/B test not found: {}", test_id))?;
547
548        // Simple random routing based on traffic split
549        let random = {
550            use scirs2_core::random::{Random, Rng};
551            let mut random = Random::default();
552            random.random::<f32>()
553        };
554        Ok(if random < test.traffic_split {
555            test.version_b
556        } else {
557            test.version_a
558        })
559    }
560}
561
#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::tempdir;

    /// Registers a version with default config and no metrics, returning its id.
    async fn register_plain_version(
        registry: &ModelRegistry,
        model_id: Uuid,
        version_number: &str,
        description: &str,
    ) -> Uuid {
        registry
            .register_version(
                model_id,
                version_number.to_string(),
                "test-user".to_string(),
                description.to_string(),
                ModelConfig::default(),
                HashMap::new(),
            )
            .await
            .unwrap()
    }

    /// Builds a performance sample whose latency fields all equal `latency_ms`,
    /// so individual samples can be told apart in assertions.
    fn sample_metrics(latency_ms: f64) -> PerformanceMetrics {
        PerformanceMetrics {
            timestamp: Utc::now(),
            latency_p50_ms: latency_ms,
            latency_p95_ms: latency_ms,
            latency_p99_ms: latency_ms,
            throughput_qps: 100.0,
            error_rate: 0.0,
            cpu_utilization: 0.5,
            memory_utilization: 0.5,
            gpu_utilization: None,
            cache_hit_rate: 0.9,
        }
    }

    #[tokio::test]
    async fn test_model_registry_lifecycle() {
        let temp_dir = tempdir().unwrap();
        let registry = ModelRegistry::new(temp_dir.path().to_path_buf());

        // Register model
        let model_id = registry
            .register_model(
                "test-model".to_string(),
                "TransformerEmbedding".to_string(),
                "test-user".to_string(),
                "Test model".to_string(),
            )
            .await
            .unwrap();

        // Register version
        let config = ModelConfig::default();
        let mut metrics = HashMap::new();
        metrics.insert("accuracy".to_string(), 0.95);

        let version_id = registry
            .register_version(
                model_id,
                "1.0.0".to_string(),
                "test-user".to_string(),
                "Initial version".to_string(),
                config,
                metrics,
            )
            .await
            .unwrap();

        let listed = registry.list_versions(model_id).await.unwrap();
        assert_eq!(listed.len(), 1);
        assert_eq!(listed[0].version_id, version_id);
        assert!(!listed[0].is_production);

        // Deploy version. `deploy_version` awaits the whole (simulated)
        // deployment, so no extra waiting is needed before checking status.
        let deployment_id = registry
            .deploy_version(version_id, ResourceAllocation::default())
            .await
            .unwrap();

        // Check deployment status
        let deployment = registry.get_deployment(deployment_id).await.unwrap();
        assert_eq!(deployment.status, DeploymentStatus::Deployed);
        assert_eq!(deployment.version_id, version_id);
        assert!(deployment.deployed_at.is_some());
        assert!(deployment.endpoint.is_some());
        assert!(deployment.health_check_url.is_some());

        // Promote to production
        registry.promote_to_production(version_id).await.unwrap();

        let model = registry.get_model(model_id).await.unwrap();
        assert_eq!(model.production_version, Some(version_id));

        let version = registry.get_version(version_id).await.unwrap();
        assert!(version.is_production);
    }

    #[tokio::test]
    async fn test_ab_testing() {
        let temp_dir = tempdir().unwrap();
        let registry = ModelRegistry::new(temp_dir.path().to_path_buf());

        // Register model and two versions
        let model_id = registry
            .register_model(
                "ab-test-model".to_string(),
                "GNNEmbedding".to_string(),
                "test-user".to_string(),
                "AB test model".to_string(),
            )
            .await
            .unwrap();

        let version_a = register_plain_version(&registry, model_id, "1.0.0", "Version A").await;
        let version_b = register_plain_version(&registry, model_id, "1.1.0", "Version B").await;

        // An out-of-range traffic split is rejected and creates nothing.
        let rejected = registry
            .create_ab_test(
                "Invalid split".to_string(),
                "Split above 1.0".to_string(),
                version_a,
                version_b,
                1.5,
                None,
            )
            .await;
        assert!(rejected.is_err());
        assert!(registry.get_active_ab_tests().await.is_empty());

        // Create A/B test
        let test_id = registry
            .create_ab_test(
                "Performance test".to_string(),
                "Testing new model version".to_string(),
                version_a,
                version_b,
                0.3,      // 30% traffic to version B
                Some(24), // 24 hour test
            )
            .await
            .unwrap();

        // Check active tests
        let active_tests = registry.get_active_ab_tests().await;
        assert_eq!(active_tests.len(), 1);
        assert_eq!(active_tests[0].test_id, test_id);
        assert!(active_tests[0].ends_at.is_some());

        // End test
        registry.end_ab_test(test_id).await.unwrap();

        let active_tests = registry.get_active_ab_tests().await;
        assert_eq!(active_tests.len(), 0);
    }

    #[tokio::test]
    async fn test_performance_history_is_bounded() {
        let temp_dir = tempdir().unwrap();
        let registry = ModelRegistry::new(temp_dir.path().to_path_buf());
        let version_id = Uuid::new_v4();

        // No samples recorded yet: history lookup fails.
        assert!(registry
            .get_performance_history(version_id, None)
            .await
            .is_err());

        // Record five more samples than the retention cap of 1000.
        for i in 0..1005u32 {
            registry
                .record_performance(version_id, sample_metrics(f64::from(i)))
                .await
                .unwrap();
        }

        // Only the newest 1000 samples remain, oldest first.
        let all = registry
            .get_performance_history(version_id, Some(2000))
            .await
            .unwrap();
        assert_eq!(all.len(), 1000);
        assert_eq!(all[0].latency_p50_ms, 5.0);
        assert_eq!(all[999].latency_p50_ms, 1004.0);

        // Without an explicit limit the newest 100 samples are returned.
        let recent = registry
            .get_performance_history(version_id, None)
            .await
            .unwrap();
        assert_eq!(recent.len(), 100);
        assert_eq!(recent[99].latency_p50_ms, 1004.0);
    }
}