1use super::file_distribution::{DistributionHandle, DistributionScope, TransferPriority};
52use anyhow::Result;
53use chrono::{DateTime, Utc};
54use serde::{Deserialize, Serialize};
55use std::collections::HashMap;
56use std::time::Duration;
57use tokio::sync::RwLock;
58
/// Handle for a model distribution started through [`ModelDistribution`].
///
/// Pairs the model identity (id, version, variant) with the underlying
/// [`DistributionHandle`] from the file-distribution layer.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct ModelDistributionHandle {
    /// Identifier of the distributed model.
    pub model_id: String,
    /// Model version being distributed.
    pub version: String,
    /// Identifier of the model variant being distributed.
    pub variant_id: String,
    /// Underlying file-distribution handle; its `distribution_id` is the key
    /// used by `ModelDeploymentTracker::register_distribution`.
    pub distribution_handle: DistributionHandle,
    /// UTC timestamp at which the distribution was initiated.
    pub initiated_at: DateTime<Utc>,
}
77
/// Snapshot of how far a model has converged on a target version across platforms.
///
/// `ModelDeploymentTracker::calculate_convergence` fills the counters from the
/// node statuses it has recorded.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct ModelConvergenceStatus {
    /// Model being tracked.
    pub model_id: String,
    /// Version the platforms are expected to converge on.
    pub target_version: String,
    /// Number of platforms expected to run the model.
    pub total_platforms: usize,
    /// Platforms reporting the target version as `Operational` or `Degraded`.
    pub converged: usize,
    /// Platforms reporting the target version as `Downloading` or `Loading`.
    pub in_progress: usize,
    /// Platforms not accounted for by any other counter; starts at `total_platforms`.
    pub pending: usize,
    /// Platforms reporting the target version as `Failed`.
    pub failed: usize,
    /// Number of reporting nodes per reported version (any version, not only the target).
    pub version_distribution: HashMap<String, usize>,
    /// Nodes currently preventing convergence, each with a reason.
    pub blockers: Vec<ConvergenceBlocker>,
    /// Presumably the estimated remaining time until convergence — nothing in
    /// this file sets it. Omitted from serialized output when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub estimated_completion: Option<Duration>,
}
103
104impl ModelConvergenceStatus {
105 pub fn new(model_id: &str, target_version: &str, total_platforms: usize) -> Self {
107 Self {
108 model_id: model_id.to_string(),
109 target_version: target_version.to_string(),
110 total_platforms,
111 converged: 0,
112 in_progress: 0,
113 pending: total_platforms,
114 failed: 0,
115 version_distribution: HashMap::new(),
116 blockers: Vec::new(),
117 estimated_completion: None,
118 }
119 }
120
121 pub fn is_complete(&self) -> bool {
123 self.converged + self.failed >= self.total_platforms
124 }
125
126 pub fn is_success(&self) -> bool {
128 self.converged >= self.total_platforms && self.failed == 0
129 }
130
131 pub fn convergence_progress(&self) -> f64 {
133 if self.total_platforms == 0 {
134 return 1.0;
135 }
136 self.converged as f64 / self.total_platforms as f64
137 }
138}
139
/// A node that is currently preventing a model from converging.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct ConvergenceBlocker {
    /// Node that is blocked.
    pub node_id: String,
    /// Why the node is blocked.
    pub reason: BlockerReason,
    /// Timestamp attached to the blocker; `ConvergenceBlocker::new` sets it to
    /// the construction time (`Utc::now()`).
    pub since: DateTime<Utc>,
    /// Optional free-form explanation; omitted from serialized output when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub details: Option<String>,
}
153
154impl ConvergenceBlocker {
155 pub fn new(node_id: &str, reason: BlockerReason) -> Self {
157 Self {
158 node_id: node_id.to_string(),
159 reason,
160 since: Utc::now(),
161 details: None,
162 }
163 }
164
165 pub fn with_details(mut self, details: &str) -> Self {
167 self.details = Some(details.to_string());
168 self
169 }
170}
171
/// Why a node is blocking convergence.
///
/// `Display` renders a short human-readable label (e.g. "Network partition").
#[derive(Clone, Debug, Serialize, Deserialize, PartialEq, Eq)]
pub enum BlockerReason {
    /// Node is cut off by a network partition.
    NetworkPartition,
    /// Node lacks storage for the model.
    InsufficientStorage,
    /// Node lacks GPU memory for the model.
    InsufficientGpuMemory,
    /// Transferring the model to the node failed.
    TransferFailed,
    /// Deploying the model failed; `ModelDeploymentTracker::calculate_convergence`
    /// uses this for nodes whose target-version status is `Failed`.
    DeploymentFailed,
    /// Node capabilities do not match the model's requirements.
    IncompatibleCapabilities,
    /// Node is busy.
    NodeBusy,
    /// Reason not known.
    Unknown,
}
192
193impl std::fmt::Display for BlockerReason {
194 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
195 match self {
196 Self::NetworkPartition => write!(f, "Network partition"),
197 Self::InsufficientStorage => write!(f, "Insufficient storage"),
198 Self::InsufficientGpuMemory => write!(f, "Insufficient GPU memory"),
199 Self::TransferFailed => write!(f, "Transfer failed"),
200 Self::DeploymentFailed => write!(f, "Deployment failed"),
201 Self::IncompatibleCapabilities => write!(f, "Incompatible capabilities"),
202 Self::NodeBusy => write!(f, "Node busy"),
203 Self::Unknown => write!(f, "Unknown"),
204 }
205 }
206}
207
/// Status of one model on one node.
///
/// This type carries no model id; the caller tracks which model it belongs to.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct NodeModelStatus {
    /// Node the status belongs to.
    pub node_id: String,
    /// Model version on the node, if any; omitted from serialized output when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub current_version: Option<String>,
    /// Deployed variant identifier, if known; omitted from serialized output when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub variant_id: Option<String>,
    /// Lifecycle state of the model on the node.
    pub operational_status: ModelOperationalStatus,
    /// UTC timestamp of the last update to this status.
    pub last_updated: DateTime<Utc>,
}
224
/// Lifecycle state of a model on a node.
///
/// `ModelDeploymentTracker::calculate_convergence` counts `Operational` and
/// `Degraded` as converged, `Downloading` and `Loading` as in progress, and
/// `Failed` as failed; `NotDeployed` leaves the node pending.
#[derive(Clone, Debug, Default, Serialize, Deserialize, PartialEq, Eq)]
pub enum ModelOperationalStatus {
    /// Model is not deployed on the node (the `Default` value).
    #[default]
    NotDeployed,
    /// Model is being downloaded.
    Downloading,
    /// Model is being loaded.
    Loading,
    /// Model is running normally.
    Operational,
    /// Model is running in a degraded state.
    Degraded,
    /// Deployment failed.
    Failed,
}
242
/// Constraints for choosing a model variant.
///
/// NOTE(review): not used in this file; field meanings below are inferred from
/// their names — confirm against the consumer.
#[derive(Clone, Debug, Default, Serialize, Deserialize)]
pub struct VariantSelector {
    /// Preferred numeric precision (presumably values like "fp16" — format
    /// unconfirmed); omitted from serialized output when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub preferred_precision: Option<String>,
    /// Providers the variant must support; omitted from serialized output when
    /// empty and deserialized as empty when the field is missing.
    #[serde(skip_serializing_if = "Vec::is_empty", default)]
    pub required_providers: Vec<String>,
    /// Upper bound on variant size in bytes; omitted from serialized output when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub max_size_bytes: Option<u64>,
}
256
/// Model-level distribution operations layered over file distribution.
///
/// NOTE(review): no implementation is visible in this file; the method
/// descriptions below are inferred from names and signatures — confirm against
/// the implementing type(s).
#[async_trait::async_trait]
pub trait ModelDistribution: Send + Sync {
    /// Distributes `version` of `model_id` to the nodes in `scope` at the given
    /// transfer `priority`, returning a handle for the started distribution.
    async fn distribute_model(
        &self,
        model_id: &str,
        version: &str,
        scope: DistributionScope,
        priority: TransferPriority,
    ) -> Result<ModelDistributionHandle>;

    /// Like `distribute_model`, but for an explicitly chosen `variant_id`.
    async fn distribute_model_variant(
        &self,
        model_id: &str,
        version: &str,
        variant_id: &str,
        scope: DistributionScope,
        priority: TransferPriority,
    ) -> Result<ModelDistributionHandle>;

    /// Moves nodes in `scope` from `from_version` to `to_version` of `model_id`;
    /// presumably by transferring only the difference between the two versions
    /// — confirm. Unlike the other distribution methods, it takes no priority.
    async fn distribute_model_delta(
        &self,
        model_id: &str,
        from_version: &str,
        to_version: &str,
        scope: DistributionScope,
    ) -> Result<ModelDistributionHandle>;

    /// Returns a convergence snapshot of `model_id` towards `target_version`.
    async fn convergence_status(
        &self,
        model_id: &str,
        target_version: &str,
    ) -> Result<ModelConvergenceStatus>;

    /// Moves `model_id` back to `to_version` on the nodes in `scope`,
    /// returning a handle for the resulting distribution.
    async fn rollback(
        &self,
        model_id: &str,
        to_version: &str,
        scope: DistributionScope,
    ) -> Result<ModelDistributionHandle>;

    /// Returns the status of `model_id` on `node_id`; presumably `Ok(None)`
    /// means no status is known for that pair.
    async fn node_model_status(
        &self,
        model_id: &str,
        node_id: &str,
    ) -> Result<Option<NodeModelStatus>>;

    /// Returns the statuses of nodes reporting `version` of `model_id`.
    async fn nodes_with_version(
        &self,
        model_id: &str,
        version: &str,
    ) -> Result<Vec<NodeModelStatus>>;

    /// Cancels the distribution identified by `handle`.
    async fn cancel(&self, handle: &ModelDistributionHandle) -> Result<()>;

    /// Subscribes to convergence updates for `model_id` towards `target_version`,
    /// delivered as full `ModelConvergenceStatus` snapshots over a tokio
    /// broadcast channel.
    async fn subscribe_convergence(
        &self,
        model_id: &str,
        target_version: &str,
    ) -> Result<tokio::sync::broadcast::Receiver<ModelConvergenceStatus>>;
}
365
/// In-memory bookkeeping of per-node model statuses and active distributions.
///
/// All state sits behind `tokio::sync::RwLock`s, so the update/query methods
/// take `&self` and are `async`.
#[derive(Debug, Default)]
pub struct ModelDeploymentTracker {
    /// Node id -> (inner key -> status). The query methods treat the inner key
    /// as a model id; `update_node_status` inserts under the status's version
    /// string instead.
    node_statuses: RwLock<HashMap<String, HashMap<String, NodeModelStatus>>>,
    /// Distributions in flight, keyed by `distribution_handle.distribution_id`.
    active_distributions: RwLock<HashMap<String, ModelDistributionHandle>>,
    /// Broadcast senders for convergence updates; the key is presumably
    /// `(model_id, target_version)` — confirm.
    // Not read by any method of this type in this file, hence the allow.
    #[allow(dead_code)]
    convergence_channels:
        RwLock<HashMap<(String, String), tokio::sync::broadcast::Sender<ModelConvergenceStatus>>>,
}
382
383impl ModelDeploymentTracker {
384 pub fn new() -> Self {
386 Self::default()
387 }
388
389 pub async fn update_node_status(&self, status: NodeModelStatus) {
391 let mut statuses = self.node_statuses.write().await;
392 let node_models = statuses.entry(status.node_id.clone()).or_default();
393
394 if let Some(ref version) = status.current_version {
396 node_models.insert(version.clone(), status);
399 }
400 }
401
402 pub async fn get_node_status(&self, model_id: &str, node_id: &str) -> Option<NodeModelStatus> {
404 let statuses = self.node_statuses.read().await;
405 statuses
406 .get(node_id)
407 .and_then(|models| models.get(model_id))
408 .cloned()
409 }
410
411 pub async fn get_nodes_with_version(
413 &self,
414 model_id: &str,
415 version: &str,
416 ) -> Vec<NodeModelStatus> {
417 let statuses = self.node_statuses.read().await;
418 statuses
419 .values()
420 .filter_map(|models| models.get(model_id))
421 .filter(|status| status.current_version.as_deref() == Some(version))
422 .cloned()
423 .collect()
424 }
425
426 pub async fn register_distribution(&self, handle: ModelDistributionHandle) {
428 let mut distributions = self.active_distributions.write().await;
429 distributions.insert(handle.distribution_handle.distribution_id.clone(), handle);
430 }
431
432 pub async fn get_distribution(&self, distribution_id: &str) -> Option<ModelDistributionHandle> {
434 let distributions = self.active_distributions.read().await;
435 distributions.get(distribution_id).cloned()
436 }
437
438 pub async fn complete_distribution(&self, distribution_id: &str) {
440 let mut distributions = self.active_distributions.write().await;
441 distributions.remove(distribution_id);
442 }
443
444 pub async fn calculate_convergence(
446 &self,
447 model_id: &str,
448 target_version: &str,
449 total_platforms: usize,
450 ) -> ModelConvergenceStatus {
451 let statuses = self.node_statuses.read().await;
452
453 let mut status = ModelConvergenceStatus::new(model_id, target_version, total_platforms);
454 let mut version_counts: HashMap<String, usize> = HashMap::new();
455
456 for (node_id, models) in statuses.iter() {
457 if let Some(node_status) = models.get(model_id) {
458 if let Some(ref version) = node_status.current_version {
459 *version_counts.entry(version.clone()).or_default() += 1;
460
461 if version == target_version {
462 match node_status.operational_status {
463 ModelOperationalStatus::Operational => {
464 status.converged += 1;
465 status.pending = status.pending.saturating_sub(1);
466 }
467 ModelOperationalStatus::Downloading
468 | ModelOperationalStatus::Loading => {
469 status.in_progress += 1;
470 status.pending = status.pending.saturating_sub(1);
471 }
472 ModelOperationalStatus::Failed => {
473 status.failed += 1;
474 status.pending = status.pending.saturating_sub(1);
475 status.blockers.push(ConvergenceBlocker::new(
476 node_id,
477 BlockerReason::DeploymentFailed,
478 ));
479 }
480 ModelOperationalStatus::Degraded => {
481 status.converged += 1;
483 status.pending = status.pending.saturating_sub(1);
484 }
485 ModelOperationalStatus::NotDeployed => {
486 }
488 }
489 }
490 }
491 }
492 }
493
494 status.version_distribution = version_counts;
495 status
496 }
497}
498
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_convergence_status_creation() {
        let fresh = ModelConvergenceStatus::new("target_recognition", "4.2.1", 10);

        // A new status starts with every platform pending and nothing converged.
        assert_eq!(fresh.model_id, "target_recognition");
        assert_eq!(fresh.target_version, "4.2.1");
        assert_eq!(fresh.total_platforms, 10);
        assert_eq!(fresh.converged, 0);
        assert_eq!(fresh.pending, 10);
        assert!(!fresh.is_complete());
        assert!(!fresh.is_success());
        assert_eq!(fresh.convergence_progress(), 0.0);
    }

    #[test]
    fn test_convergence_progress() {
        let mut tracked = ModelConvergenceStatus::new("model", "1.0", 10);

        // Halfway: progress is reported but the rollout is not complete.
        tracked.converged = 5;
        tracked.pending = 5;
        assert_eq!(tracked.convergence_progress(), 0.5);
        assert!(!tracked.is_complete());

        // Everything converged.
        tracked.converged = 10;
        tracked.pending = 0;
        assert_eq!(tracked.convergence_progress(), 1.0);
        assert!(tracked.is_complete());
        assert!(tracked.is_success());
    }

    #[test]
    fn test_convergence_with_failures() {
        let tracked = ModelConvergenceStatus {
            converged: 8,
            failed: 2,
            pending: 0,
            ..ModelConvergenceStatus::new("model", "1.0", 10)
        };

        // Every platform is accounted for, but failures rule out success.
        assert!(tracked.is_complete());
        assert!(!tracked.is_success());
    }

    #[test]
    fn test_blocker_creation() {
        const DETAILS: &str = "Required 8GB, available 4GB";

        let blocker = ConvergenceBlocker::new("node-1", BlockerReason::InsufficientGpuMemory)
            .with_details(DETAILS);

        assert_eq!(blocker.node_id, "node-1");
        assert_eq!(blocker.reason, BlockerReason::InsufficientGpuMemory);
        assert_eq!(blocker.details.as_deref(), Some(DETAILS));
    }

    #[test]
    fn test_blocker_reason_display() {
        let cases = [
            (BlockerReason::NetworkPartition, "Network partition"),
            (BlockerReason::InsufficientStorage, "Insufficient storage"),
            (BlockerReason::TransferFailed, "Transfer failed"),
        ];

        for (reason, expected) in cases {
            assert_eq!(reason.to_string(), expected);
        }
    }

    #[test]
    fn test_node_model_status() {
        let snapshot = NodeModelStatus {
            node_id: "node-1".into(),
            current_version: Some("4.2.1".into()),
            variant_id: Some("fp16-cuda".into()),
            operational_status: ModelOperationalStatus::Operational,
            last_updated: Utc::now(),
        };

        assert_eq!(snapshot.current_version.as_deref(), Some("4.2.1"));
        assert_eq!(
            snapshot.operational_status,
            ModelOperationalStatus::Operational
        );
    }

    #[tokio::test]
    async fn test_deployment_tracker() {
        let tracker = ModelDeploymentTracker::new();

        // A fresh tracker knows nothing about any node or version.
        assert!(tracker.get_node_status("model-1", "node-1").await.is_none());
        assert!(tracker
            .get_nodes_with_version("model-1", "1.0")
            .await
            .is_empty());
    }

    #[test]
    fn test_model_distribution_handle() {
        use super::super::blob_traits::BlobHash;

        let inner = DistributionHandle::new(
            BlobHash::from_hex("abc123"),
            DistributionScope::AllNodes,
            TransferPriority::High,
        );
        let handle = ModelDistributionHandle {
            model_id: "target_recognition".to_string(),
            version: "4.2.1".to_string(),
            variant_id: "fp16-cuda".to_string(),
            distribution_handle: inner,
            initiated_at: Utc::now(),
        };

        assert_eq!(handle.model_id, "target_recognition");
        assert_eq!(handle.version, "4.2.1");
        assert_eq!(handle.variant_id, "fp16-cuda");
    }
}