1use anyhow::{anyhow, Result};
19use scirs2_core::ndarray_ext::Array1;
20use serde::{Deserialize, Serialize};
21use std::collections::HashMap;
22use std::path::PathBuf;
23use std::sync::Arc;
24use std::time::{Duration, Instant};
25use tokio::sync::RwLock;
26use tracing::{debug, info, warn};
27
28#[derive(Debug, Clone, Serialize, Deserialize)]
30pub struct ModelServingConfig {
31 pub model_registry_path: PathBuf,
33 pub enable_ab_testing: bool,
35 pub ab_test_split: f64,
37 pub warmup_samples: usize,
39 pub auto_rollback_threshold: f64,
41 pub model_cache_size: usize,
43 pub enable_versioning: bool,
45}
46
47impl Default for ModelServingConfig {
48 fn default() -> Self {
49 Self {
50 model_registry_path: PathBuf::from("/tmp/oxirs_models"),
51 enable_ab_testing: true,
52 ab_test_split: 0.5,
53 warmup_samples: 100,
54 auto_rollback_threshold: 0.1,
55 model_cache_size: 5,
56 enable_versioning: true,
57 }
58 }
59}
60
61#[derive(Debug, Clone, Serialize, Deserialize)]
63pub struct ModelVersion {
64 pub version_id: String,
65 pub model_type: ModelType,
66 pub deployed_at: chrono::DateTime<chrono::Utc>,
67 pub status: ModelStatus,
68 pub metrics: ModelMetrics,
69}
70
71#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
73pub enum ModelType {
74 TransformerOptimizer,
76 CostEstimator,
78 JoinOptimizer,
80 CardinalityEstimator,
82}
83
84#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
86pub enum ModelStatus {
87 Loading,
88 Warming,
89 Serving,
90 ABTesting,
91 Rollback,
92 Deprecated,
93}
94
95#[derive(Debug, Clone, Serialize, Deserialize)]
97pub struct ModelMetrics {
98 pub total_requests: u64,
99 pub successful_requests: u64,
100 pub failed_requests: u64,
101 pub average_latency_ms: f64,
102 pub p95_latency_ms: f64,
103 pub p99_latency_ms: f64,
104 pub error_rate: f64,
105}
106
107impl Default for ModelMetrics {
108 fn default() -> Self {
109 Self {
110 total_requests: 0,
111 successful_requests: 0,
112 failed_requests: 0,
113 average_latency_ms: 0.0,
114 p95_latency_ms: 0.0,
115 p99_latency_ms: 0.0,
116 error_rate: 0.0,
117 }
118 }
119}
120
121#[derive(Debug, Clone, Serialize, Deserialize)]
123pub struct TransformerConfig {
124 pub input_dim: usize,
125 pub hidden_dim: usize,
126 pub num_heads: usize,
127 pub num_layers: usize,
128 pub dropout: f64,
129}
130
131pub struct QueryTransformerModel {
133 #[allow(dead_code)]
134 config: TransformerConfig,
135 version: String,
136 parameters: Vec<Array1<f64>>,
137}
138
139impl QueryTransformerModel {
140 pub fn new(config: TransformerConfig, version: String) -> Self {
142 let parameters = (0..config.num_layers)
144 .map(|_| Array1::from_vec(vec![0.0; config.hidden_dim]))
145 .collect();
146
147 Self {
148 config,
149 version,
150 parameters,
151 }
152 }
153
154 pub fn optimize_query(&self, query_embedding: &[f64]) -> Result<Vec<f64>> {
156 let mut output = query_embedding.to_vec();
159
160 for param in &self.parameters {
162 for (i, val) in output.iter_mut().enumerate() {
163 if i < param.len() {
164 *val += param[i] * 0.1;
165 }
166 }
167 }
168
169 Ok(output)
170 }
171
172 pub fn version(&self) -> &str {
174 &self.version
175 }
176}
177
178pub struct MLModelServing {
180 config: ModelServingConfig,
181 models: Arc<RwLock<HashMap<String, Arc<QueryTransformerModel>>>>,
182 active_versions: Arc<RwLock<HashMap<ModelType, String>>>,
183 ab_test_config: Arc<RwLock<Option<ABTestConfig>>>,
184 version_metrics: Arc<RwLock<HashMap<String, ModelMetrics>>>,
185}
186
187#[derive(Debug, Clone, Serialize, Deserialize)]
189pub struct ABTestConfig {
190 pub control_version: String,
191 pub treatment_version: String,
192 pub traffic_split: f64,
193 pub started_at: chrono::DateTime<chrono::Utc>,
194 pub min_samples: usize,
195}
196
197impl MLModelServing {
198 pub fn new(config: ModelServingConfig) -> Self {
200 Self {
201 config,
202 models: Arc::new(RwLock::new(HashMap::new())),
203 active_versions: Arc::new(RwLock::new(HashMap::new())),
204 ab_test_config: Arc::new(RwLock::new(None)),
205 version_metrics: Arc::new(RwLock::new(HashMap::new())),
206 }
207 }
208
209 pub async fn deploy_model(
211 &self,
212 version_id: String,
213 model_type: ModelType,
214 model: Arc<QueryTransformerModel>,
215 ) -> Result<()> {
216 info!(
217 "Deploying model version: {} (type: {:?})",
218 version_id, model_type
219 );
220
221 {
223 let mut models = self.models.write().await;
224 models.insert(version_id.clone(), model);
225 }
226
227 {
229 let mut metrics = self.version_metrics.write().await;
230 metrics.insert(version_id.clone(), ModelMetrics::default());
231 }
232
233 self.warmup_model(&version_id).await?;
235
236 {
238 let mut active = self.active_versions.write().await;
239 if !active.contains_key(&model_type) {
240 active.insert(model_type.clone(), version_id.clone());
241 info!("Set {} as active version for {:?}", version_id, model_type);
242 }
243 }
244
245 info!("Model {} deployed successfully", version_id);
246 Ok(())
247 }
248
249 async fn warmup_model(&self, version_id: &str) -> Result<()> {
251 debug!("Warming up model: {}", version_id);
252
253 let models = self.models.read().await;
254 let model = models
255 .get(version_id)
256 .ok_or_else(|| anyhow!("Model not found: {}", version_id))?;
257
258 for _ in 0..self.config.warmup_samples {
260 let sample = vec![0.5; 128]; let _output = model.optimize_query(&sample)?;
262 }
263
264 debug!("Model warmup completed: {}", version_id);
265 Ok(())
266 }
267
268 pub async fn start_ab_test(
270 &self,
271 control_version: String,
272 treatment_version: String,
273 traffic_split: f64,
274 ) -> Result<()> {
275 if !self.config.enable_ab_testing {
276 return Err(anyhow!("A/B testing is not enabled"));
277 }
278
279 {
281 let models = self.models.read().await;
282 if !models.contains_key(&control_version) {
283 return Err(anyhow!("Control version not found: {}", control_version));
284 }
285 if !models.contains_key(&treatment_version) {
286 return Err(anyhow!(
287 "Treatment version not found: {}",
288 treatment_version
289 ));
290 }
291 }
292
293 let ab_config = ABTestConfig {
294 control_version: control_version.clone(),
295 treatment_version: treatment_version.clone(),
296 traffic_split,
297 started_at: chrono::Utc::now(),
298 min_samples: 1000,
299 };
300
301 {
302 let mut ab_test = self.ab_test_config.write().await;
303 *ab_test = Some(ab_config);
304 }
305
306 info!(
307 "A/B test started: control={}, treatment={}, split={}",
308 control_version, treatment_version, traffic_split
309 );
310 Ok(())
311 }
312
313 pub async fn serve_prediction(
315 &self,
316 model_type: ModelType,
317 query_embedding: &[f64],
318 request_id: &str,
319 ) -> Result<Vec<f64>> {
320 let start_time = Instant::now();
321
322 let version_id = self.select_model_version(&model_type, request_id).await?;
324
325 let result = {
327 let models = self.models.read().await;
328 let model = models
329 .get(&version_id)
330 .ok_or_else(|| anyhow!("Model not found: {}", version_id))?;
331
332 model.optimize_query(query_embedding)
333 };
334
335 let latency = start_time.elapsed();
336
337 self.update_metrics(&version_id, &result, latency).await;
339
340 if self.config.auto_rollback_threshold > 0.0 {
342 self.check_auto_rollback(&version_id).await?;
343 }
344
345 result
346 }
347
348 async fn select_model_version(
350 &self,
351 model_type: &ModelType,
352 request_id: &str,
353 ) -> Result<String> {
354 let ab_test = self.ab_test_config.read().await;
356
357 if let Some(ref config) = *ab_test {
358 let hash = request_id
360 .bytes()
361 .fold(0u64, |acc, b| acc.wrapping_add(b as u64));
362 let ratio = (hash % 100) as f64 / 100.0;
363
364 let version = if ratio < config.traffic_split {
365 config.treatment_version.clone()
366 } else {
367 config.control_version.clone()
368 };
369
370 Ok(version)
371 } else {
372 let active = self.active_versions.read().await;
374 active
375 .get(model_type)
376 .cloned()
377 .ok_or_else(|| anyhow!("No active version for model type: {:?}", model_type))
378 }
379 }
380
381 async fn update_metrics(&self, version_id: &str, result: &Result<Vec<f64>>, latency: Duration) {
383 let mut metrics_map = self.version_metrics.write().await;
384 if let Some(metrics) = metrics_map.get_mut(version_id) {
385 metrics.total_requests += 1;
386
387 if result.is_ok() {
388 metrics.successful_requests += 1;
389 } else {
390 metrics.failed_requests += 1;
391 }
392
393 let latency_ms = latency.as_secs_f64() * 1000.0;
395 metrics.average_latency_ms =
396 (metrics.average_latency_ms * (metrics.total_requests - 1) as f64 + latency_ms)
397 / metrics.total_requests as f64;
398
399 metrics.error_rate = metrics.failed_requests as f64 / metrics.total_requests as f64;
401 }
402 }
403
404 async fn check_auto_rollback(&self, version_id: &str) -> Result<()> {
406 let metrics_map = self.version_metrics.read().await;
407
408 if let Some(metrics) = metrics_map.get(version_id) {
409 if metrics.total_requests > 100
410 && metrics.error_rate > self.config.auto_rollback_threshold
411 {
412 warn!(
413 "Auto-rollback triggered for {}: error_rate={:.2}%",
414 version_id,
415 metrics.error_rate * 100.0
416 );
417
418 }
421 }
422
423 Ok(())
424 }
425
426 pub async fn get_ab_test_results(&self) -> Result<ABTestResults> {
428 let ab_test = self.ab_test_config.read().await;
429
430 let config = ab_test
431 .as_ref()
432 .ok_or_else(|| anyhow!("No active A/B test"))?;
433
434 let metrics_map = self.version_metrics.read().await;
435
436 let control_metrics = metrics_map
437 .get(&config.control_version)
438 .cloned()
439 .unwrap_or_default();
440
441 let treatment_metrics = metrics_map
442 .get(&config.treatment_version)
443 .cloned()
444 .unwrap_or_default();
445
446 let improvement = if control_metrics.average_latency_ms > 0.0 {
448 ((control_metrics.average_latency_ms - treatment_metrics.average_latency_ms)
449 / control_metrics.average_latency_ms)
450 * 100.0
451 } else {
452 0.0
453 };
454
455 let is_significant = control_metrics.total_requests >= config.min_samples as u64
456 && treatment_metrics.total_requests >= config.min_samples as u64;
457
458 Ok(ABTestResults {
459 control_version: config.control_version.clone(),
460 treatment_version: config.treatment_version.clone(),
461 control_metrics,
462 treatment_metrics,
463 improvement_percentage: improvement,
464 is_significant,
465 })
466 }
467
468 pub async fn promote_version(&self, version_id: String, model_type: ModelType) -> Result<()> {
470 {
472 let models = self.models.read().await;
473 if !models.contains_key(&version_id) {
474 return Err(anyhow!("Model version not found: {}", version_id));
475 }
476 }
477
478 {
480 let mut active = self.active_versions.write().await;
481 active.insert(model_type.clone(), version_id.clone());
482 }
483
484 info!("Promoted {} to production for {:?}", version_id, model_type);
485 Ok(())
486 }
487
488 pub async fn get_metrics(&self, version_id: &str) -> Result<ModelMetrics> {
490 let metrics_map = self.version_metrics.read().await;
491 metrics_map
492 .get(version_id)
493 .cloned()
494 .ok_or_else(|| anyhow!("Metrics not found for version: {}", version_id))
495 }
496
497 pub async fn list_models(&self) -> Vec<String> {
499 let models = self.models.read().await;
500 models.keys().cloned().collect()
501 }
502}
503
504#[derive(Debug, Clone, Serialize, Deserialize)]
506pub struct ABTestResults {
507 pub control_version: String,
508 pub treatment_version: String,
509 pub control_metrics: ModelMetrics,
510 pub treatment_metrics: ModelMetrics,
511 pub improvement_percentage: f64,
512 pub is_significant: bool,
513}
514
#[cfg(test)]
mod tests {
    use super::*;

    /// Transformer shape shared by every model built in these tests.
    fn transformer_config() -> TransformerConfig {
        TransformerConfig {
            input_dim: 128,
            hidden_dim: 256,
            num_heads: 8,
            num_layers: 4,
            dropout: 0.1,
        }
    }

    /// Builds a deployable placeholder model tagged with `version`.
    fn build_model(version: &str) -> Arc<QueryTransformerModel> {
        Arc::new(QueryTransformerModel::new(
            transformer_config(),
            version.to_string(),
        ))
    }

    #[tokio::test]
    async fn test_model_serving_creation() {
        let serving = MLModelServing::new(ModelServingConfig::default());
        let deployed = serving.list_models().await;
        assert!(deployed.is_empty());
    }

    #[tokio::test]
    async fn test_model_deployment() {
        let serving = MLModelServing::new(ModelServingConfig {
            warmup_samples: 10,
            ..Default::default()
        });

        serving
            .deploy_model(
                "v1.0.0".to_string(),
                ModelType::TransformerOptimizer,
                build_model("v1.0.0"),
            )
            .await
            .unwrap();

        let deployed = serving.list_models().await;
        assert_eq!(deployed, vec!["v1.0.0".to_string()]);
    }

    #[tokio::test]
    async fn test_model_prediction() {
        let serving = MLModelServing::new(ModelServingConfig {
            warmup_samples: 10,
            enable_ab_testing: false,
            ..Default::default()
        });

        serving
            .deploy_model(
                "v1.0.0".to_string(),
                ModelType::TransformerOptimizer,
                build_model("v1.0.0"),
            )
            .await
            .unwrap();

        let embedding = vec![0.5; 128];
        let prediction = serving
            .serve_prediction(ModelType::TransformerOptimizer, &embedding, "req-123")
            .await
            .unwrap();

        assert!(!prediction.is_empty());
    }

    #[tokio::test]
    async fn test_ab_testing() {
        let serving = MLModelServing::new(ModelServingConfig {
            warmup_samples: 5,
            enable_ab_testing: true,
            ..Default::default()
        });

        for version in ["v1.0.0", "v2.0.0"] {
            serving
                .deploy_model(
                    version.to_string(),
                    ModelType::TransformerOptimizer,
                    build_model(version),
                )
                .await
                .unwrap();
        }

        serving
            .start_ab_test("v1.0.0".to_string(), "v2.0.0".to_string(), 0.5)
            .await
            .unwrap();

        let embedding = vec![0.5; 128];
        for request_id in (0..20).map(|i| format!("request-{}", i)) {
            serving
                .serve_prediction(ModelType::TransformerOptimizer, &embedding, &request_id)
                .await
                .unwrap();
        }

        let control = serving.get_metrics("v1.0.0").await.unwrap();
        let treatment = serving.get_metrics("v2.0.0").await.unwrap();

        assert_eq!(
            control.total_requests + treatment.total_requests,
            20,
            "Total requests should be 20"
        );
        assert!(
            control.total_requests > 0 || treatment.total_requests > 0,
            "At least one version should receive requests"
        );
    }
}