Skip to main content

oxigdal_ml/
serving.rs

//! Model serving and deployment utilities
//!
//! This module provides production-ready model serving capabilities including
//! model versioning, A/B testing, canary deployments, and load balancing.
5
6use crate::error::{MlError, Result};
7// use crate::models::Model;
8use std::collections::HashMap;
9use std::path::PathBuf;
10use std::sync::{Arc, RwLock};
11use tracing::{debug, info};
12
13/// Model version information
14#[derive(Debug, Clone)]
15pub struct ModelVersion {
16    /// Version identifier
17    pub version: String,
18    /// Model file path
19    pub path: PathBuf,
20    /// Deployment timestamp
21    pub deployed_at: std::time::SystemTime,
22    /// Model metadata
23    pub metadata: HashMap<String, String>,
24    /// Performance metrics
25    pub metrics: VersionMetrics,
26}
27
/// Aggregated performance figures for a single model version.
///
/// The `Default` value has every counter and average set to zero.
#[derive(Debug, Clone, Default)]
pub struct VersionMetrics {
    /// Number of requests served in total
    pub requests: u64,
    /// Mean latency, in milliseconds
    pub avg_latency_ms: f32,
    /// Fraction of successful requests (0.0 to 1.0)
    pub success_rate: f32,
    /// Mean CPU usage, as a percentage
    pub avg_cpu_usage: f32,
    /// Mean memory usage, in MB
    pub avg_memory_mb: f32,
}
42
/// How a new model version is rolled out by [`ModelServer::deploy`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum DeploymentStrategy {
    /// Replace the current version directly
    Replace,
    /// Blue-green deployment
    BlueGreen,
    /// Canary deployment: start with a small traffic share and roll out gradually
    Canary {
        /// Traffic percentage sent to the canary at the start (0-100)
        initial_percent: u8,
        /// Size of each subsequent traffic increase
        step_percent: u8,
    },
    /// A/B testing between the current and the new version
    ABTest {
        /// Percentage of traffic sent to the new version
        split_percent: u8,
    },
    /// Shadow mode (no user-facing traffic)
    Shadow,
}
65
/// Tunable settings for a [`ModelServer`].
#[derive(Debug, Clone)]
pub struct ServerConfig {
    /// Upper bound on concurrently handled requests
    pub max_concurrent: usize,
    /// Per-request timeout, in milliseconds
    pub timeout_ms: u64,
    /// Whether request queuing is enabled
    pub enable_queue: bool,
    /// Maximum number of queued requests
    pub queue_size: usize,
    /// Whether health checks are enabled
    pub health_check: bool,
    /// Interval between health checks, in seconds
    pub health_check_interval_s: u64,
}

impl Default for ServerConfig {
    /// Returns the stock configuration: 100 concurrent requests, a 30 s
    /// timeout, a 1000-entry queue, and health checks every 30 s.
    fn default() -> Self {
        let max_concurrent = 100;
        let timeout_ms = 30_000;
        let queue_size = 1_000;
        let health_check_interval_s = 30;

        Self {
            max_concurrent,
            timeout_ms,
            enable_queue: true,
            queue_size,
            health_check: true,
            health_check_interval_s,
        }
    }
}
95
96/// Model server for production deployment
97pub struct ModelServer {
98    config: ServerConfig,
99    versions: Arc<RwLock<HashMap<String, ModelVersion>>>,
100    active_version: Arc<RwLock<String>>,
101    routing: Arc<RwLock<RoutingStrategy>>,
102}
103
/// Describes how traffic is distributed across model versions.
#[derive(Debug, Clone)]
enum RoutingStrategy {
    /// All traffic goes to one version
    Single {
        /// Identifier of that version
        version: String,
    },
    /// Traffic is split across versions by weight
    Weighted {
        /// Map from version identifier to its traffic percentage
        weights: HashMap<String, u8>,
    },
    /// Traffic is split between a stable version and a canary
    Canary {
        /// Identifier of the stable version
        stable: String,
        /// Identifier of the canary version
        canary: String,
        /// Percentage of traffic sent to the canary
        canary_percent: u8,
    },
}
127
128impl ModelServer {
129    /// Creates a new model server
130    #[must_use]
131    pub fn new(config: ServerConfig) -> Self {
132        info!("Initializing model server");
133        Self {
134            config,
135            versions: Arc::new(RwLock::new(HashMap::new())),
136            active_version: Arc::new(RwLock::new(String::new())),
137            routing: Arc::new(RwLock::new(RoutingStrategy::Single {
138                version: String::new(),
139            })),
140        }
141    }
142
143    /// Registers a new model version
144    ///
145    /// # Errors
146    /// Returns an error if version registration fails
147    pub fn register_version(
148        &mut self,
149        version_id: &str,
150        model_path: PathBuf,
151        metadata: HashMap<String, String>,
152    ) -> Result<()> {
153        info!("Registering model version: {}", version_id);
154
155        if !model_path.exists() {
156            return Err(MlError::InvalidConfig(format!(
157                "Model file not found: {}",
158                model_path.display()
159            )));
160        }
161
162        let version = ModelVersion {
163            version: version_id.to_string(),
164            path: model_path,
165            deployed_at: std::time::SystemTime::now(),
166            metadata,
167            metrics: VersionMetrics::default(),
168        };
169
170        if let Ok(mut versions) = self.versions.write() {
171            versions.insert(version_id.to_string(), version);
172        }
173
174        Ok(())
175    }
176
177    /// Deploys a model version using the specified strategy
178    ///
179    /// # Errors
180    /// Returns an error if deployment fails
181    pub fn deploy(&mut self, version_id: &str, strategy: DeploymentStrategy) -> Result<()> {
182        info!(
183            "Deploying version {} with strategy {:?}",
184            version_id, strategy
185        );
186
187        // Verify version exists
188        let version_exists = self
189            .versions
190            .read()
191            .map(|v| v.contains_key(version_id))
192            .unwrap_or(false);
193
194        if !version_exists {
195            return Err(MlError::InvalidConfig(format!(
196                "Version not found: {}",
197                version_id
198            )));
199        }
200
201        match strategy {
202            DeploymentStrategy::Replace => self.deploy_replace(version_id),
203            DeploymentStrategy::BlueGreen => self.deploy_blue_green(version_id),
204            DeploymentStrategy::Canary {
205                initial_percent,
206                step_percent,
207            } => self.deploy_canary(version_id, initial_percent, step_percent),
208            DeploymentStrategy::ABTest { split_percent } => {
209                self.deploy_ab_test(version_id, split_percent)
210            }
211            DeploymentStrategy::Shadow => self.deploy_shadow(version_id),
212        }
213    }
214
215    /// Rolls back to a previous version
216    ///
217    /// # Errors
218    /// Returns an error if rollback fails
219    pub fn rollback(&mut self, version_id: &str) -> Result<()> {
220        info!("Rolling back to version: {}", version_id);
221        self.deploy_replace(version_id)
222    }
223
224    /// Returns metrics for all versions
225    #[must_use]
226    pub fn version_metrics(&self) -> HashMap<String, VersionMetrics> {
227        self.versions
228            .read()
229            .map(|versions| {
230                versions
231                    .iter()
232                    .map(|(k, v)| (k.clone(), v.metrics.clone()))
233                    .collect()
234            })
235            .unwrap_or_default()
236    }
237
238    /// Returns the active version
239    #[must_use]
240    pub fn active_version(&self) -> String {
241        self.active_version
242            .read()
243            .map(|v| v.clone())
244            .unwrap_or_default()
245    }
246
247    /// Performs health check on active version
248    #[must_use]
249    pub fn health_check(&self) -> HealthStatus {
250        if !self.config.health_check {
251            return HealthStatus::Unknown;
252        }
253
254        // Check if any model is loaded
255        let has_models = self.versions.read().map(|v| !v.is_empty()).unwrap_or(false);
256
257        if !has_models {
258            return HealthStatus::Unhealthy;
259        }
260
261        // Check if active version exists
262        let active_version = self.active_version();
263        if active_version.is_empty() {
264            return HealthStatus::Degraded;
265        }
266
267        // Verify active version is in versions map
268        let version_exists = self
269            .versions
270            .read()
271            .map(|v| v.contains_key(&active_version))
272            .unwrap_or(false);
273
274        if !version_exists {
275            return HealthStatus::Unhealthy;
276        }
277
278        // Check memory usage (simple heuristic)
279        if let Ok(memory_info) = Self::get_memory_usage() {
280            // If memory usage > 90%, return degraded
281            if memory_info.usage_percent > 90.0 {
282                return HealthStatus::Degraded;
283            }
284            // If memory usage > 95%, return unhealthy
285            if memory_info.usage_percent > 95.0 {
286                return HealthStatus::Unhealthy;
287            }
288        }
289
290        HealthStatus::Healthy
291    }
292
293    /// Gets current memory usage information
294    fn get_memory_usage() -> Result<MemoryInfo> {
295        #[cfg(target_os = "linux")]
296        {
297            Self::get_memory_usage_linux()
298        }
299
300        #[cfg(target_os = "macos")]
301        {
302            Self::get_memory_usage_macos()
303        }
304
305        #[cfg(target_os = "windows")]
306        {
307            Self::get_memory_usage_windows()
308        }
309
310        #[cfg(not(any(target_os = "linux", target_os = "macos", target_os = "windows")))]
311        {
312            // Default fallback for unsupported platforms
313            Ok(MemoryInfo {
314                total_mb: 0,
315                used_mb: 0,
316                available_mb: 0,
317                usage_percent: 0.0,
318            })
319        }
320    }
321
322    #[cfg(target_os = "linux")]
323    fn get_memory_usage_linux() -> Result<MemoryInfo> {
324        use std::fs;
325
326        let meminfo = fs::read_to_string("/proc/meminfo")
327            .map_err(|e| MlError::InvalidConfig(format!("Failed to read meminfo: {}", e)))?;
328
329        let mut total = 0u64;
330        let mut available = 0u64;
331
332        for line in meminfo.lines() {
333            if let Some(rest) = line.strip_prefix("MemTotal:") {
334                total = rest
335                    .trim()
336                    .split_whitespace()
337                    .next()
338                    .and_then(|s| s.parse::<u64>().ok())
339                    .unwrap_or(0);
340            } else if let Some(rest) = line.strip_prefix("MemAvailable:") {
341                available = rest
342                    .trim()
343                    .split_whitespace()
344                    .next()
345                    .and_then(|s| s.parse::<u64>().ok())
346                    .unwrap_or(0);
347            }
348        }
349
350        let total_mb = total / 1024;
351        let available_mb = available / 1024;
352        let used_mb = total_mb.saturating_sub(available_mb);
353        let usage_percent = if total_mb > 0 {
354            (used_mb as f32 / total_mb as f32) * 100.0
355        } else {
356            0.0
357        };
358
359        Ok(MemoryInfo {
360            total_mb,
361            used_mb,
362            available_mb,
363            usage_percent,
364        })
365    }
366
367    #[cfg(target_os = "macos")]
368    fn get_memory_usage_macos() -> Result<MemoryInfo> {
369        // Simplified implementation for macOS
370        // In production, would use sysctl or vm_stat
371        Ok(MemoryInfo {
372            total_mb: 16384, // Placeholder
373            used_mb: 8192,   // Placeholder
374            available_mb: 8192,
375            usage_percent: 50.0,
376        })
377    }
378
379    #[cfg(target_os = "windows")]
380    fn get_memory_usage_windows() -> Result<MemoryInfo> {
381        // Simplified implementation for Windows
382        // In production, would use Windows API
383        Ok(MemoryInfo {
384            total_mb: 16384, // Placeholder
385            used_mb: 8192,   // Placeholder
386            available_mb: 8192,
387            usage_percent: 50.0,
388        })
389    }
390
391    // Private deployment methods
392
393    fn deploy_replace(&mut self, version_id: &str) -> Result<()> {
394        debug!("Deploying with replace strategy");
395
396        if let Ok(mut active) = self.active_version.write() {
397            *active = version_id.to_string();
398        }
399
400        if let Ok(mut routing) = self.routing.write() {
401            *routing = RoutingStrategy::Single {
402                version: version_id.to_string(),
403            };
404        }
405
406        info!("Version {} deployed successfully", version_id);
407        Ok(())
408    }
409
410    fn deploy_blue_green(&mut self, version_id: &str) -> Result<()> {
411        debug!("Deploying with blue-green strategy");
412
413        // In blue-green, we prepare the new version first
414        // Then switch traffic atomically
415        self.deploy_replace(version_id)
416    }
417
418    fn deploy_canary(
419        &mut self,
420        version_id: &str,
421        initial_percent: u8,
422        _step_percent: u8,
423    ) -> Result<()> {
424        debug!(
425            "Deploying with canary strategy ({}% initial)",
426            initial_percent
427        );
428
429        let stable_version = self.active_version();
430
431        if let Ok(mut routing) = self.routing.write() {
432            *routing = RoutingStrategy::Canary {
433                stable: stable_version,
434                canary: version_id.to_string(),
435                canary_percent: initial_percent,
436            };
437        }
438
439        info!("Canary deployment started for version {}", version_id);
440        Ok(())
441    }
442
443    fn deploy_ab_test(&mut self, version_id: &str, split_percent: u8) -> Result<()> {
444        debug!("Deploying with A/B test ({}% split)", split_percent);
445
446        let stable_version = self.active_version();
447        let mut weights = HashMap::new();
448        weights.insert(stable_version, 100 - split_percent);
449        weights.insert(version_id.to_string(), split_percent);
450
451        if let Ok(mut routing) = self.routing.write() {
452            *routing = RoutingStrategy::Weighted { weights };
453        }
454
455        info!("A/B test started for version {}", version_id);
456        Ok(())
457    }
458
459    fn deploy_shadow(&mut self, version_id: &str) -> Result<()> {
460        debug!("Deploying in shadow mode");
461
462        // Shadow mode: new version receives traffic but results are not used
463        info!("Version {} deployed in shadow mode", version_id);
464        Ok(())
465    }
466
467    /// Increases canary traffic percentage
468    ///
469    /// # Errors
470    /// Returns an error if not in canary mode
471    pub fn increase_canary_traffic(&mut self, increment: u8) -> Result<()> {
472        let mut routing = self
473            .routing
474            .write()
475            .map_err(|_| MlError::InvalidConfig("Failed to acquire routing lock".to_string()))?;
476
477        match &mut *routing {
478            RoutingStrategy::Canary { canary_percent, .. } => {
479                *canary_percent = (*canary_percent + increment).min(100);
480                info!("Increased canary traffic to {}%", canary_percent);
481                Ok(())
482            }
483            _ => Err(MlError::InvalidConfig(
484                "Not in canary deployment mode".to_string(),
485            )),
486        }
487    }
488
489    /// Promotes canary to stable
490    ///
491    /// # Errors
492    /// Returns an error if not in canary mode
493    pub fn promote_canary(&mut self) -> Result<()> {
494        let routing = self
495            .routing
496            .read()
497            .map_err(|_| MlError::InvalidConfig("Failed to acquire routing lock".to_string()))?;
498
499        if let RoutingStrategy::Canary { canary, .. } = &*routing {
500            let canary_version = canary.clone();
501            drop(routing); // Release read lock
502            self.deploy_replace(&canary_version)?;
503            info!("Canary promoted to stable");
504            Ok(())
505        } else {
506            Err(MlError::InvalidConfig(
507                "Not in canary deployment mode".to_string(),
508            ))
509        }
510    }
511}
512
/// Outcome of a server health check.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum HealthStatus {
    /// The service is healthy
    Healthy,
    /// The service is operational but degraded
    Degraded,
    /// The service is unhealthy
    Unhealthy,
    /// The health status is unknown
    Unknown,
}
525
/// Snapshot of system memory usage.
#[derive(Debug, Clone)]
struct MemoryInfo {
    /// Total memory, in MB
    total_mb: u64,
    /// Memory in use, in MB
    used_mb: u64,
    /// Memory available, in MB
    available_mb: u64,
    /// Share of memory in use, as a percentage
    usage_percent: f32,
}
538
#[cfg(test)]
mod tests {
    use super::*;

    /// The default configuration carries the documented limits.
    #[test]
    fn test_server_config_default() {
        let ServerConfig {
            max_concurrent,
            timeout_ms,
            enable_queue,
            ..
        } = ServerConfig::default();

        assert_eq!(max_concurrent, 100);
        assert_eq!(timeout_ms, 30000);
        assert!(enable_queue);
    }

    /// Every strategy variant can be constructed and debug-formatted.
    #[test]
    fn test_deployment_strategy_variants() {
        let strategies = [
            DeploymentStrategy::Replace,
            DeploymentStrategy::BlueGreen,
            DeploymentStrategy::Canary {
                initial_percent: 10,
                step_percent: 10,
            },
            DeploymentStrategy::ABTest { split_percent: 50 },
            DeploymentStrategy::Shadow,
        ];

        strategies.iter().for_each(|strategy| {
            // Construction and formatting are all that is verified here
            let _ = format!("{:?}", strategy);
        });
    }

    /// A freshly created server has no active version.
    #[test]
    fn test_model_server_creation() {
        let server = ModelServer::new(ServerConfig::default());
        assert_eq!(server.active_version(), "");
    }

    /// Health states compare by variant.
    #[test]
    fn test_health_status() {
        assert_eq!(HealthStatus::Healthy, HealthStatus::Healthy);
        assert_ne!(HealthStatus::Healthy, HealthStatus::Degraded);
    }

    /// Metrics fields hold the values they were built with.
    #[test]
    fn test_version_metrics() {
        let metrics = VersionMetrics {
            requests: 1000,
            avg_latency_ms: 50.0,
            success_rate: 0.99,
            avg_cpu_usage: 45.0,
            avg_memory_mb: 512.0,
        };

        assert_eq!(metrics.requests, 1000);
        let deviation = (metrics.success_rate - 0.99).abs();
        assert!(deviation < 0.01);
    }
}