1use crate::error::{MlError, Result};
7use std::collections::HashMap;
9use std::path::PathBuf;
10use std::sync::{Arc, RwLock};
11use tracing::{debug, info};
12
/// A model artifact registered with [`ModelServer`], plus its bookkeeping data.
#[derive(Debug, Clone)]
pub struct ModelVersion {
    /// Identifier the version was registered under.
    pub version: String,
    /// Path of the model file; `register_version` checks that it exists.
    pub path: PathBuf,
    /// Despite the name, this is set to `SystemTime::now()` when the version is
    /// *registered*, not when it is deployed.
    pub deployed_at: std::time::SystemTime,
    /// Caller-supplied key/value metadata, stored as given.
    pub metadata: HashMap<String, String>,
    /// Serving metrics for this version; starts at `VersionMetrics::default()`.
    pub metrics: VersionMetrics,
}
27
/// Aggregated serving statistics for one model version.
///
/// `register_version` initializes these to all-zero defaults.
/// NOTE(review): nothing in the visible code updates these values — confirm where they are collected.
#[derive(Debug, Clone, Default)]
pub struct VersionMetrics {
    /// Number of requests served.
    pub requests: u64,
    /// Mean request latency in milliseconds.
    pub avg_latency_ms: f32,
    /// Fraction of successful requests — presumably in 0.0..=1.0 (the tests use 0.99); confirm.
    pub success_rate: f32,
    /// Mean CPU usage — unit not established here (percent?); confirm.
    pub avg_cpu_usage: f32,
    /// Mean memory footprint in megabytes.
    pub avg_memory_mb: f32,
}
42
/// How [`ModelServer::deploy`] should roll out a registered version.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum DeploymentStrategy {
    /// Make the version active immediately and route all traffic to it.
    Replace,
    /// Blue/green rollout; currently delegates to the same logic as `Replace`.
    BlueGreen,
    /// Gradual rollout: the new version starts with `initial_percent` of traffic
    /// while the currently active version stays the stable one.
    Canary {
        /// Initial share of traffic, in percent, sent to the canary.
        initial_percent: u8,
        /// Intended per-step increase in percent; not consumed by `deploy` —
        /// traffic is raised via `ModelServer::increase_canary_traffic`.
        step_percent: u8,
    },
    /// Weighted split between the currently active version and the new one.
    ABTest {
        /// Share of traffic, in percent, sent to the new version.
        split_percent: u8,
    },
    /// Shadow deployment; currently only logs and leaves routing unchanged.
    Shadow,
}
65
/// Tunable settings for a [`ModelServer`].
///
/// Use `ServerConfig::default()` for the baseline values, or struct-update
/// syntax (`ServerConfig { timeout_ms: 5_000, ..Default::default() }`) to
/// override individual fields.
#[derive(Debug, Clone)]
pub struct ServerConfig {
    /// Upper bound on simultaneously processed requests.
    pub max_concurrent: usize,
    /// Per-request timeout in milliseconds.
    pub timeout_ms: u64,
    /// Whether requests beyond `max_concurrent` may wait in a queue.
    pub enable_queue: bool,
    /// Capacity of that queue when enabled.
    pub queue_size: usize,
    /// When false, `ModelServer::health_check` reports `Unknown` without checking anything.
    pub health_check: bool,
    /// Interval between health checks, in seconds.
    pub health_check_interval_s: u64,
}

impl Default for ServerConfig {
    /// Baseline configuration: 100 concurrent requests, a 30 s timeout,
    /// a 1000-entry queue, and health checks every 30 s.
    fn default() -> Self {
        let (max_concurrent, queue_size) = (100, 1_000);
        let timeout_ms = 30_000;
        let health_check_interval_s = 30;

        Self {
            health_check: true,
            enable_queue: true,
            max_concurrent,
            queue_size,
            timeout_ms,
            health_check_interval_s,
        }
    }
}
95
/// Holds registered model versions and decides how traffic is routed between them.
///
/// All mutable state sits behind `Arc<RwLock<..>>`, so readers such as
/// `active_version` and `health_check` only need `&self`.
pub struct ModelServer {
    /// Settings supplied at construction.
    config: ServerConfig,
    /// Registered versions, keyed by version id.
    versions: Arc<RwLock<HashMap<String, ModelVersion>>>,
    /// Id of the version currently considered active; empty until one is deployed.
    active_version: Arc<RwLock<String>>,
    /// Current traffic-routing configuration.
    routing: Arc<RwLock<RoutingStrategy>>,
}
103
/// Internal description of how requests are distributed across versions.
///
/// NOTE(review): request dispatch that consumes this is not visible here.
#[derive(Debug, Clone)]
enum RoutingStrategy {
    /// All traffic goes to one version (set by replace/blue-green/rollback/promote).
    Single {
        /// Target version id; empty before the first deployment.
        version: String,
    },
    /// Traffic share in percent per version id (set by A/B deployments).
    Weighted {
        /// Version id -> percentage of traffic.
        weights: HashMap<String, u8>,
    },
    /// Canary rollout in progress.
    Canary {
        /// Version that was active when the canary started.
        stable: String,
        /// Version being rolled out.
        canary: String,
        /// Share of traffic, in percent, sent to `canary`.
        canary_percent: u8,
    },
}
127
128impl ModelServer {
129 #[must_use]
131 pub fn new(config: ServerConfig) -> Self {
132 info!("Initializing model server");
133 Self {
134 config,
135 versions: Arc::new(RwLock::new(HashMap::new())),
136 active_version: Arc::new(RwLock::new(String::new())),
137 routing: Arc::new(RwLock::new(RoutingStrategy::Single {
138 version: String::new(),
139 })),
140 }
141 }
142
143 pub fn register_version(
148 &mut self,
149 version_id: &str,
150 model_path: PathBuf,
151 metadata: HashMap<String, String>,
152 ) -> Result<()> {
153 info!("Registering model version: {}", version_id);
154
155 if !model_path.exists() {
156 return Err(MlError::InvalidConfig(format!(
157 "Model file not found: {}",
158 model_path.display()
159 )));
160 }
161
162 let version = ModelVersion {
163 version: version_id.to_string(),
164 path: model_path,
165 deployed_at: std::time::SystemTime::now(),
166 metadata,
167 metrics: VersionMetrics::default(),
168 };
169
170 if let Ok(mut versions) = self.versions.write() {
171 versions.insert(version_id.to_string(), version);
172 }
173
174 Ok(())
175 }
176
177 pub fn deploy(&mut self, version_id: &str, strategy: DeploymentStrategy) -> Result<()> {
182 info!(
183 "Deploying version {} with strategy {:?}",
184 version_id, strategy
185 );
186
187 let version_exists = self
189 .versions
190 .read()
191 .map(|v| v.contains_key(version_id))
192 .unwrap_or(false);
193
194 if !version_exists {
195 return Err(MlError::InvalidConfig(format!(
196 "Version not found: {}",
197 version_id
198 )));
199 }
200
201 match strategy {
202 DeploymentStrategy::Replace => self.deploy_replace(version_id),
203 DeploymentStrategy::BlueGreen => self.deploy_blue_green(version_id),
204 DeploymentStrategy::Canary {
205 initial_percent,
206 step_percent,
207 } => self.deploy_canary(version_id, initial_percent, step_percent),
208 DeploymentStrategy::ABTest { split_percent } => {
209 self.deploy_ab_test(version_id, split_percent)
210 }
211 DeploymentStrategy::Shadow => self.deploy_shadow(version_id),
212 }
213 }
214
215 pub fn rollback(&mut self, version_id: &str) -> Result<()> {
220 info!("Rolling back to version: {}", version_id);
221 self.deploy_replace(version_id)
222 }
223
224 #[must_use]
226 pub fn version_metrics(&self) -> HashMap<String, VersionMetrics> {
227 self.versions
228 .read()
229 .map(|versions| {
230 versions
231 .iter()
232 .map(|(k, v)| (k.clone(), v.metrics.clone()))
233 .collect()
234 })
235 .unwrap_or_default()
236 }
237
238 #[must_use]
240 pub fn active_version(&self) -> String {
241 self.active_version
242 .read()
243 .map(|v| v.clone())
244 .unwrap_or_default()
245 }
246
247 #[must_use]
249 pub fn health_check(&self) -> HealthStatus {
250 if !self.config.health_check {
251 return HealthStatus::Unknown;
252 }
253
254 let has_models = self.versions.read().map(|v| !v.is_empty()).unwrap_or(false);
256
257 if !has_models {
258 return HealthStatus::Unhealthy;
259 }
260
261 let active_version = self.active_version();
263 if active_version.is_empty() {
264 return HealthStatus::Degraded;
265 }
266
267 let version_exists = self
269 .versions
270 .read()
271 .map(|v| v.contains_key(&active_version))
272 .unwrap_or(false);
273
274 if !version_exists {
275 return HealthStatus::Unhealthy;
276 }
277
278 if let Ok(memory_info) = Self::get_memory_usage() {
280 if memory_info.usage_percent > 90.0 {
282 return HealthStatus::Degraded;
283 }
284 if memory_info.usage_percent > 95.0 {
286 return HealthStatus::Unhealthy;
287 }
288 }
289
290 HealthStatus::Healthy
291 }
292
293 fn get_memory_usage() -> Result<MemoryInfo> {
295 #[cfg(target_os = "linux")]
296 {
297 Self::get_memory_usage_linux()
298 }
299
300 #[cfg(target_os = "macos")]
301 {
302 Self::get_memory_usage_macos()
303 }
304
305 #[cfg(target_os = "windows")]
306 {
307 Self::get_memory_usage_windows()
308 }
309
310 #[cfg(not(any(target_os = "linux", target_os = "macos", target_os = "windows")))]
311 {
312 Ok(MemoryInfo {
314 total_mb: 0,
315 used_mb: 0,
316 available_mb: 0,
317 usage_percent: 0.0,
318 })
319 }
320 }
321
322 #[cfg(target_os = "linux")]
323 fn get_memory_usage_linux() -> Result<MemoryInfo> {
324 use std::fs;
325
326 let meminfo = fs::read_to_string("/proc/meminfo")
327 .map_err(|e| MlError::InvalidConfig(format!("Failed to read meminfo: {}", e)))?;
328
329 let mut total = 0u64;
330 let mut available = 0u64;
331
332 for line in meminfo.lines() {
333 if let Some(rest) = line.strip_prefix("MemTotal:") {
334 total = rest
335 .trim()
336 .split_whitespace()
337 .next()
338 .and_then(|s| s.parse::<u64>().ok())
339 .unwrap_or(0);
340 } else if let Some(rest) = line.strip_prefix("MemAvailable:") {
341 available = rest
342 .trim()
343 .split_whitespace()
344 .next()
345 .and_then(|s| s.parse::<u64>().ok())
346 .unwrap_or(0);
347 }
348 }
349
350 let total_mb = total / 1024;
351 let available_mb = available / 1024;
352 let used_mb = total_mb.saturating_sub(available_mb);
353 let usage_percent = if total_mb > 0 {
354 (used_mb as f32 / total_mb as f32) * 100.0
355 } else {
356 0.0
357 };
358
359 Ok(MemoryInfo {
360 total_mb,
361 used_mb,
362 available_mb,
363 usage_percent,
364 })
365 }
366
367 #[cfg(target_os = "macos")]
368 fn get_memory_usage_macos() -> Result<MemoryInfo> {
369 Ok(MemoryInfo {
372 total_mb: 16384, used_mb: 8192, available_mb: 8192,
375 usage_percent: 50.0,
376 })
377 }
378
379 #[cfg(target_os = "windows")]
380 fn get_memory_usage_windows() -> Result<MemoryInfo> {
381 Ok(MemoryInfo {
384 total_mb: 16384, used_mb: 8192, available_mb: 8192,
387 usage_percent: 50.0,
388 })
389 }
390
391 fn deploy_replace(&mut self, version_id: &str) -> Result<()> {
394 debug!("Deploying with replace strategy");
395
396 if let Ok(mut active) = self.active_version.write() {
397 *active = version_id.to_string();
398 }
399
400 if let Ok(mut routing) = self.routing.write() {
401 *routing = RoutingStrategy::Single {
402 version: version_id.to_string(),
403 };
404 }
405
406 info!("Version {} deployed successfully", version_id);
407 Ok(())
408 }
409
410 fn deploy_blue_green(&mut self, version_id: &str) -> Result<()> {
411 debug!("Deploying with blue-green strategy");
412
413 self.deploy_replace(version_id)
416 }
417
418 fn deploy_canary(
419 &mut self,
420 version_id: &str,
421 initial_percent: u8,
422 _step_percent: u8,
423 ) -> Result<()> {
424 debug!(
425 "Deploying with canary strategy ({}% initial)",
426 initial_percent
427 );
428
429 let stable_version = self.active_version();
430
431 if let Ok(mut routing) = self.routing.write() {
432 *routing = RoutingStrategy::Canary {
433 stable: stable_version,
434 canary: version_id.to_string(),
435 canary_percent: initial_percent,
436 };
437 }
438
439 info!("Canary deployment started for version {}", version_id);
440 Ok(())
441 }
442
443 fn deploy_ab_test(&mut self, version_id: &str, split_percent: u8) -> Result<()> {
444 debug!("Deploying with A/B test ({}% split)", split_percent);
445
446 let stable_version = self.active_version();
447 let mut weights = HashMap::new();
448 weights.insert(stable_version, 100 - split_percent);
449 weights.insert(version_id.to_string(), split_percent);
450
451 if let Ok(mut routing) = self.routing.write() {
452 *routing = RoutingStrategy::Weighted { weights };
453 }
454
455 info!("A/B test started for version {}", version_id);
456 Ok(())
457 }
458
459 fn deploy_shadow(&mut self, version_id: &str) -> Result<()> {
460 debug!("Deploying in shadow mode");
461
462 info!("Version {} deployed in shadow mode", version_id);
464 Ok(())
465 }
466
467 pub fn increase_canary_traffic(&mut self, increment: u8) -> Result<()> {
472 let mut routing = self
473 .routing
474 .write()
475 .map_err(|_| MlError::InvalidConfig("Failed to acquire routing lock".to_string()))?;
476
477 match &mut *routing {
478 RoutingStrategy::Canary { canary_percent, .. } => {
479 *canary_percent = (*canary_percent + increment).min(100);
480 info!("Increased canary traffic to {}%", canary_percent);
481 Ok(())
482 }
483 _ => Err(MlError::InvalidConfig(
484 "Not in canary deployment mode".to_string(),
485 )),
486 }
487 }
488
489 pub fn promote_canary(&mut self) -> Result<()> {
494 let routing = self
495 .routing
496 .read()
497 .map_err(|_| MlError::InvalidConfig("Failed to acquire routing lock".to_string()))?;
498
499 if let RoutingStrategy::Canary { canary, .. } = &*routing {
500 let canary_version = canary.clone();
501 drop(routing); self.deploy_replace(&canary_version)?;
503 info!("Canary promoted to stable");
504 Ok(())
505 } else {
506 Err(MlError::InvalidConfig(
507 "Not in canary deployment mode".to_string(),
508 ))
509 }
510 }
511}
512
/// Result of [`ModelServer::health_check`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum HealthStatus {
    /// Every check performed by `health_check` passed.
    Healthy,
    /// Impaired but not failed — e.g. no version has been deployed yet, or
    /// memory usage is high.
    Degraded,
    /// Failed — e.g. no versions are registered, or the active version is not registered.
    Unhealthy,
    /// Health checking is disabled (`ServerConfig::health_check == false`).
    Unknown,
}
525
/// Snapshot of system memory usage used by the health check.
///
/// On macOS/Windows the values are fixed placeholders, and on other non-Linux
/// platforms they are zero.
#[derive(Debug, Clone)]
struct MemoryInfo {
    /// Total physical memory in MiB (kB from `/proc/meminfo` divided by 1024 on Linux).
    total_mb: u64,
    /// `total_mb - available_mb`, saturating at 0.
    used_mb: u64,
    /// Memory available for new allocations, in MiB.
    available_mb: u64,
    /// `used_mb / total_mb * 100`, or 0.0 when the total is unknown.
    usage_percent: f32,
}
538
#[cfg(test)]
mod tests {
    use super::*;

    /// The documented defaults must not drift.
    #[test]
    fn test_server_config_default() {
        let ServerConfig {
            max_concurrent,
            timeout_ms,
            enable_queue,
            ..
        } = ServerConfig::default();

        assert_eq!(max_concurrent, 100);
        assert_eq!(timeout_ms, 30000);
        assert!(enable_queue);
    }

    /// Every strategy variant can be built and rendered with `Debug`.
    #[test]
    fn test_deployment_strategy_variants() {
        let strategies = [
            DeploymentStrategy::Replace,
            DeploymentStrategy::BlueGreen,
            DeploymentStrategy::Canary {
                initial_percent: 10,
                step_percent: 10,
            },
            DeploymentStrategy::ABTest { split_percent: 50 },
            DeploymentStrategy::Shadow,
        ];

        let rendered: Vec<String> = strategies.iter().map(|s| format!("{:?}", s)).collect();
        assert_eq!(rendered.len(), strategies.len());
    }

    /// A fresh server has no active version.
    #[test]
    fn test_model_server_creation() {
        let server = ModelServer::new(ServerConfig::default());
        assert!(server.active_version().is_empty());
    }

    /// Equality on `HealthStatus` distinguishes variants.
    #[test]
    fn test_health_status() {
        let healthy = HealthStatus::Healthy;
        assert_eq!(healthy, HealthStatus::Healthy);
        assert_ne!(healthy, HealthStatus::Degraded);
    }

    /// Metrics fields hold the values they were constructed with.
    #[test]
    fn test_version_metrics() {
        let metrics = VersionMetrics {
            success_rate: 0.99,
            requests: 1000,
            avg_memory_mb: 512.0,
            avg_latency_ms: 50.0,
            avg_cpu_usage: 45.0,
        };

        assert_eq!(metrics.requests, 1000);
        assert!((metrics.success_rate - 0.99).abs() < 0.01);
    }
}