use crate::{EmbeddingModel, ModelConfig};
use anyhow::{anyhow, Result};
use chrono::{DateTime, Utc};
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::path::PathBuf;
use std::sync::Arc;
use tokio::sync::RwLock;
use uuid::Uuid;
18
19#[derive(Debug, Clone, Serialize, Deserialize)]
21pub struct ModelVersion {
22 pub version_id: Uuid,
23 pub model_id: Uuid,
24 pub version_number: String,
25 pub created_at: DateTime<Utc>,
26 pub created_by: String,
27 pub description: String,
28 pub tags: Vec<String>,
29 pub metrics: HashMap<String, f64>,
30 pub config: ModelConfig,
31 pub is_production: bool,
32 pub is_deprecated: bool,
33}
34
35#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq)]
37pub enum DeploymentStatus {
38 NotDeployed,
39 Deploying,
40 Deployed,
41 Failed,
42 Retiring,
43 Retired,
44}
45
46#[derive(Debug, Clone, Serialize, Deserialize)]
48pub struct ModelDeployment {
49 pub deployment_id: Uuid,
50 pub version_id: Uuid,
51 pub status: DeploymentStatus,
52 pub deployed_at: Option<DateTime<Utc>>,
53 pub endpoint: Option<String>,
54 pub resource_allocation: ResourceAllocation,
55 pub health_check_url: Option<String>,
56 pub rollback_version: Option<Uuid>,
57}
58
59#[derive(Debug, Clone, Serialize, Deserialize)]
61pub struct ResourceAllocation {
62 pub cpu_cores: f32,
63 pub memory_gb: f32,
64 pub gpu_count: u32,
65 pub gpu_memory_gb: f32,
66 pub max_concurrent_requests: usize,
67}
68
69impl Default for ResourceAllocation {
70 fn default() -> Self {
71 Self {
72 cpu_cores: 2.0,
73 memory_gb: 4.0,
74 gpu_count: 0,
75 gpu_memory_gb: 0.0,
76 max_concurrent_requests: 100,
77 }
78 }
79}
80
81#[derive(Debug, Clone, Serialize, Deserialize)]
83pub struct ABTestConfig {
84 pub test_id: Uuid,
85 pub name: String,
86 pub description: String,
87 pub version_a: Uuid,
88 pub version_b: Uuid,
89 pub traffic_split: f32, pub started_at: DateTime<Utc>,
91 pub ends_at: Option<DateTime<Utc>>,
92 pub metrics_to_track: Vec<String>,
93 pub is_active: bool,
94}
95
96#[derive(Debug, Clone, Serialize, Deserialize)]
98pub struct PerformanceMetrics {
99 pub timestamp: DateTime<Utc>,
100 pub latency_p50_ms: f64,
101 pub latency_p95_ms: f64,
102 pub latency_p99_ms: f64,
103 pub throughput_qps: f64,
104 pub error_rate: f64,
105 pub cpu_utilization: f64,
106 pub memory_utilization: f64,
107 pub gpu_utilization: Option<f64>,
108 pub cache_hit_rate: f64,
109}
110
111pub struct ModelRegistry {
113 models: Arc<RwLock<HashMap<Uuid, ModelMetadata>>>,
114 versions: Arc<RwLock<HashMap<Uuid, ModelVersion>>>,
115 deployments: Arc<RwLock<HashMap<Uuid, ModelDeployment>>>,
116 ab_tests: Arc<RwLock<HashMap<Uuid, ABTestConfig>>>,
117 performance_history: Arc<RwLock<HashMap<Uuid, Vec<PerformanceMetrics>>>>,
118 #[allow(dead_code)]
119 storage_path: PathBuf,
120}
121
122#[derive(Debug, Clone, Serialize, Deserialize)]
124pub struct ModelMetadata {
125 pub model_id: Uuid,
126 pub name: String,
127 pub model_type: String,
128 pub created_at: DateTime<Utc>,
129 pub updated_at: DateTime<Utc>,
130 pub owner: String,
131 pub description: String,
132 pub versions: Vec<Uuid>,
133 pub production_version: Option<Uuid>,
134 pub staging_version: Option<Uuid>,
135}
136
137impl ModelRegistry {
138 pub fn new(storage_path: PathBuf) -> Self {
140 Self {
141 models: Arc::new(RwLock::new(HashMap::new())),
142 versions: Arc::new(RwLock::new(HashMap::new())),
143 deployments: Arc::new(RwLock::new(HashMap::new())),
144 ab_tests: Arc::new(RwLock::new(HashMap::new())),
145 performance_history: Arc::new(RwLock::new(HashMap::new())),
146 storage_path,
147 }
148 }
149
150 pub async fn register_model(
152 &self,
153 name: String,
154 model_type: String,
155 owner: String,
156 description: String,
157 ) -> Result<Uuid> {
158 let model_id = Uuid::new_v4();
159 let metadata = ModelMetadata {
160 model_id,
161 name,
162 model_type,
163 created_at: Utc::now(),
164 updated_at: Utc::now(),
165 owner,
166 description,
167 versions: Vec::new(),
168 production_version: None,
169 staging_version: None,
170 };
171
172 self.models.write().await.insert(model_id, metadata);
173 Ok(model_id)
174 }
175
176 pub async fn register_version(
178 &self,
179 model_id: Uuid,
180 version_number: String,
181 created_by: String,
182 description: String,
183 config: ModelConfig,
184 metrics: HashMap<String, f64>,
185 ) -> Result<Uuid> {
186 let version_id = Uuid::new_v4();
187
188 let mut models = self.models.write().await;
190 let model = models
191 .get_mut(&model_id)
192 .ok_or_else(|| anyhow!("Model not found: {}", model_id))?;
193
194 let version = ModelVersion {
195 version_id,
196 model_id,
197 version_number,
198 created_at: Utc::now(),
199 created_by,
200 description,
201 tags: Vec::new(),
202 metrics,
203 config,
204 is_production: false,
205 is_deprecated: false,
206 };
207
208 model.versions.push(version_id);
209 model.updated_at = Utc::now();
210
211 self.versions.write().await.insert(version_id, version);
212 Ok(version_id)
213 }
214
215 pub async fn deploy_version(
217 &self,
218 version_id: Uuid,
219 resource_allocation: ResourceAllocation,
220 ) -> Result<Uuid> {
221 if !self.versions.read().await.contains_key(&version_id) {
223 return Err(anyhow!("Version not found: {}", version_id));
224 }
225
226 let deployment_id = Uuid::new_v4();
227 let deployment = ModelDeployment {
228 deployment_id,
229 version_id,
230 status: DeploymentStatus::Deploying,
231 deployed_at: None,
232 endpoint: None,
233 resource_allocation,
234 health_check_url: None,
235 rollback_version: None,
236 };
237
238 self.deployments
239 .write()
240 .await
241 .insert(deployment_id, deployment);
242
243 self.start_deployment(deployment_id).await?;
245
246 Ok(deployment_id)
247 }
248
249 async fn start_deployment(&self, deployment_id: Uuid) -> Result<()> {
251 tokio::time::sleep(tokio::time::Duration::from_secs(2)).await;
260
261 let mut deployments = self.deployments.write().await;
262 if let Some(deployment) = deployments.get_mut(&deployment_id) {
263 deployment.status = DeploymentStatus::Deployed;
264 deployment.deployed_at = Some(Utc::now());
265 deployment.endpoint = Some(format!("https://api.oxirs.ai/v1/embed/{deployment_id}"));
266 deployment.health_check_url = Some(format!(
267 "https://api.oxirs.ai/v1/embed/{deployment_id}/health"
268 ));
269 }
270
271 Ok(())
272 }
273
274 pub async fn promote_to_production(&self, version_id: Uuid) -> Result<()> {
276 let versions = self.versions.read().await;
277 let version = versions
278 .get(&version_id)
279 .ok_or_else(|| anyhow!("Version not found: {}", version_id))?;
280
281 let model_id = version.model_id;
282 drop(versions);
283
284 let mut models = self.models.write().await;
285 let model = models
286 .get_mut(&model_id)
287 .ok_or_else(|| anyhow!("Model not found: {}", model_id))?;
288
289 if let Some(prev_prod) = model.production_version {
291 let mut versions = self.versions.write().await;
292 if let Some(prev_version) = versions.get_mut(&prev_prod) {
293 prev_version.is_production = false;
294 }
295 }
296
297 model.production_version = Some(version_id);
298 model.updated_at = Utc::now();
299
300 let mut versions = self.versions.write().await;
301 if let Some(version) = versions.get_mut(&version_id) {
302 version.is_production = true;
303 }
304
305 Ok(())
306 }
307
308 pub async fn create_ab_test(
310 &self,
311 name: String,
312 description: String,
313 version_a: Uuid,
314 version_b: Uuid,
315 traffic_split: f32,
316 duration_hours: Option<u32>,
317 ) -> Result<Uuid> {
318 let versions = self.versions.read().await;
320 if !versions.contains_key(&version_a) {
321 return Err(anyhow!("Version A not found: {}", version_a));
322 }
323 if !versions.contains_key(&version_b) {
324 return Err(anyhow!("Version B not found: {}", version_b));
325 }
326 drop(versions);
327
328 if !(0.0..=1.0).contains(&traffic_split) {
329 return Err(anyhow!("Traffic split must be between 0.0 and 1.0"));
330 }
331
332 let test_id = Uuid::new_v4();
333 let ab_test = ABTestConfig {
334 test_id,
335 name,
336 description,
337 version_a,
338 version_b,
339 traffic_split,
340 started_at: Utc::now(),
341 ends_at: duration_hours.map(|h| Utc::now() + chrono::Duration::hours(h as i64)),
342 metrics_to_track: vec![
343 "latency_p95".to_string(),
344 "accuracy".to_string(),
345 "error_rate".to_string(),
346 ],
347 is_active: true,
348 };
349
350 self.ab_tests.write().await.insert(test_id, ab_test);
351 Ok(test_id)
352 }
353
354 pub async fn record_performance(
356 &self,
357 version_id: Uuid,
358 metrics: PerformanceMetrics,
359 ) -> Result<()> {
360 let mut history = self.performance_history.write().await;
361 history
362 .entry(version_id)
363 .or_insert_with(Vec::new)
364 .push(metrics);
365
366 if let Some(vec) = history.get_mut(&version_id) {
368 if vec.len() > 1000 {
369 vec.drain(0..vec.len() - 1000);
370 }
371 }
372
373 Ok(())
374 }
375
376 pub async fn get_model(&self, model_id: Uuid) -> Result<ModelMetadata> {
378 self.models
379 .read()
380 .await
381 .get(&model_id)
382 .cloned()
383 .ok_or_else(|| anyhow!("Model not found: {}", model_id))
384 }
385
386 pub async fn get_version(&self, version_id: Uuid) -> Result<ModelVersion> {
388 self.versions
389 .read()
390 .await
391 .get(&version_id)
392 .cloned()
393 .ok_or_else(|| anyhow!("Version not found: {}", version_id))
394 }
395
396 pub async fn get_deployment(&self, deployment_id: Uuid) -> Result<ModelDeployment> {
398 self.deployments
399 .read()
400 .await
401 .get(&deployment_id)
402 .cloned()
403 .ok_or_else(|| anyhow!("Deployment not found: {}", deployment_id))
404 }
405
406 pub async fn get_performance_history(
408 &self,
409 version_id: Uuid,
410 limit: Option<usize>,
411 ) -> Result<Vec<PerformanceMetrics>> {
412 let history = self.performance_history.read().await;
413 let metrics = history
414 .get(&version_id)
415 .ok_or_else(|| anyhow!("No performance history for version: {}", version_id))?;
416
417 let limit = limit.unwrap_or(100);
418 let start = metrics.len().saturating_sub(limit);
419
420 Ok(metrics[start..].to_vec())
421 }
422
423 pub async fn rollback_deployment(&self, deployment_id: Uuid) -> Result<()> {
425 let (rollback_version, resource_allocation) = {
426 let deployments = self.deployments.read().await;
427 let deployment = deployments
428 .get(&deployment_id)
429 .ok_or_else(|| anyhow!("Deployment not found: {}", deployment_id))?;
430
431 if let Some(rollback_version) = deployment.rollback_version {
432 (rollback_version, deployment.resource_allocation.clone())
433 } else {
434 return Err(anyhow!("No rollback version configured"));
435 }
436 };
437
438 self.deploy_version(rollback_version, resource_allocation)
440 .await?;
441
442 let mut deployments = self.deployments.write().await;
444 if let Some(deployment) = deployments.get_mut(&deployment_id) {
445 deployment.status = DeploymentStatus::Retired;
446 }
447
448 Ok(())
449 }
450
451 pub async fn list_models(&self) -> Vec<ModelMetadata> {
453 self.models.read().await.values().cloned().collect()
454 }
455
456 pub async fn list_versions(&self, model_id: Uuid) -> Result<Vec<ModelVersion>> {
458 let models = self.models.read().await;
459 let model = models
460 .get(&model_id)
461 .ok_or_else(|| anyhow!("Model not found: {}", model_id))?;
462
463 let version_ids = model.versions.clone();
464 drop(models);
465
466 let versions = self.versions.read().await;
467 let mut result = Vec::new();
468
469 for version_id in version_ids {
470 if let Some(version) = versions.get(&version_id) {
471 result.push(version.clone());
472 }
473 }
474
475 Ok(result)
476 }
477
478 pub async fn get_active_ab_tests(&self) -> Vec<ABTestConfig> {
480 self.ab_tests
481 .read()
482 .await
483 .values()
484 .filter(|test| test.is_active)
485 .cloned()
486 .collect()
487 }
488
489 pub async fn end_ab_test(&self, test_id: Uuid) -> Result<()> {
491 let mut ab_tests = self.ab_tests.write().await;
492 let test = ab_tests
493 .get_mut(&test_id)
494 .ok_or_else(|| anyhow!("A/B test not found: {}", test_id))?;
495
496 test.is_active = false;
497 test.ends_at = Some(Utc::now());
498
499 Ok(())
500 }
501}
502
503pub struct ModelServer {
505 registry: Arc<ModelRegistry>,
506 #[allow(dead_code)]
507 loaded_models: Arc<RwLock<HashMap<Uuid, Box<dyn EmbeddingModel>>>>,
508 warm_up_cache: Arc<RwLock<HashMap<Uuid, Vec<String>>>>,
509}
510
511impl ModelServer {
512 pub fn new(registry: Arc<ModelRegistry>) -> Self {
513 Self {
514 registry,
515 loaded_models: Arc::new(RwLock::new(HashMap::new())),
516 warm_up_cache: Arc::new(RwLock::new(HashMap::new())),
517 }
518 }
519
520 pub async fn load_model(&self, _version_id: Uuid) -> Result<()> {
522 Ok(())
525 }
526
527 pub async fn warm_up_model(&self, version_id: Uuid, samples: Vec<String>) -> Result<()> {
529 self.warm_up_cache.write().await.insert(version_id, samples);
530
531 Ok(())
533 }
534
535 pub async fn get_model(&self, _version_id: Uuid) -> Result<Arc<Box<dyn EmbeddingModel>>> {
537 Err(anyhow!("Model loading not implemented"))
539 }
540
541 pub async fn route_request(&self, test_id: Uuid) -> Result<Uuid> {
543 let ab_tests = self.registry.ab_tests.read().await;
544 let test = ab_tests
545 .get(&test_id)
546 .ok_or_else(|| anyhow!("A/B test not found: {}", test_id))?;
547
548 let random = {
550 use scirs2_core::random::{Random, Rng};
551 let mut random = Random::default();
552 random.random::<f32>()
553 };
554 Ok(if random < test.traffic_split {
555 test.version_b
556 } else {
557 test.version_a
558 })
559 }
560}
561
#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::tempdir;

    /// Registers a model and a version, deploys the version and promotes it
    /// to production, checking the registry state after each step.
    #[tokio::test]
    async fn test_model_registry_lifecycle() {
        let temp_dir = tempdir().unwrap();
        let registry = ModelRegistry::new(temp_dir.path().to_path_buf());

        let model_id = registry
            .register_model(
                "test-model".to_string(),
                "TransformerEmbedding".to_string(),
                "test-user".to_string(),
                "Test model".to_string(),
            )
            .await
            .unwrap();

        let config = ModelConfig::default();
        let mut metrics = HashMap::new();
        metrics.insert("accuracy".to_string(), 0.95);

        let version_id = registry
            .register_version(
                model_id,
                "1.0.0".to_string(),
                "test-user".to_string(),
                "Initial version".to_string(),
                config,
                metrics,
            )
            .await
            .unwrap();

        let deployment_id = registry
            .deploy_version(version_id, ResourceAllocation::default())
            .await
            .unwrap();

        // `deploy_version` only returns once the deployment has been marked
        // `Deployed`, so no additional wait is needed before inspecting it.
        let deployment = registry.get_deployment(deployment_id).await.unwrap();
        assert_eq!(deployment.status, DeploymentStatus::Deployed);
        assert!(deployment.endpoint.is_some());
        assert!(deployment.health_check_url.is_some());
        assert!(deployment.deployed_at.is_some());

        registry.promote_to_production(version_id).await.unwrap();

        let model = registry.get_model(model_id).await.unwrap();
        assert_eq!(model.production_version, Some(version_id));

        let version = registry.get_version(version_id).await.unwrap();
        assert!(version.is_production);
    }

    /// Creates an A/B test between two versions, checks that it is listed as
    /// active, ends it and checks that it is no longer listed.
    #[tokio::test]
    async fn test_ab_testing() {
        let temp_dir = tempdir().unwrap();
        let registry = ModelRegistry::new(temp_dir.path().to_path_buf());

        let model_id = registry
            .register_model(
                "ab-test-model".to_string(),
                "GNNEmbedding".to_string(),
                "test-user".to_string(),
                "AB test model".to_string(),
            )
            .await
            .unwrap();

        let version_a = registry
            .register_version(
                model_id,
                "1.0.0".to_string(),
                "test-user".to_string(),
                "Version A".to_string(),
                ModelConfig::default(),
                HashMap::new(),
            )
            .await
            .unwrap();

        let version_b = registry
            .register_version(
                model_id,
                "1.1.0".to_string(),
                "test-user".to_string(),
                "Version B".to_string(),
                ModelConfig::default(),
                HashMap::new(),
            )
            .await
            .unwrap();

        // A traffic split outside 0.0..=1.0 is rejected and creates no test.
        let rejected = registry
            .create_ab_test(
                "Invalid split".to_string(),
                "Traffic split out of range".to_string(),
                version_a,
                version_b,
                1.5,
                None,
            )
            .await;
        assert!(rejected.is_err());

        let test_id = registry
            .create_ab_test(
                "Performance test".to_string(),
                "Testing new model version".to_string(),
                version_a,
                version_b,
                0.3,
                Some(24),
            )
            .await
            .unwrap();

        let active_tests = registry.get_active_ab_tests().await;
        assert_eq!(active_tests.len(), 1);
        assert_eq!(active_tests[0].test_id, test_id);

        registry.end_ab_test(test_id).await.unwrap();

        let active_tests = registry.get_active_ab_tests().await;
        assert_eq!(active_tests.len(), 0);
    }
}