Skip to main content

peat_protocol/storage/
model_distribution.rs

1//! Model distribution API (ADR-025 Phase 4, ADR-022 Integration)
2//!
3//! Model-specific distribution API integrating with Edge MLOps architecture.
4//! Builds on `FileDistribution` to provide model-aware distribution with
5//! variant selection, convergence tracking, and rollback capabilities.
6//!
7//! # Architecture
8//!
9//! ```text
10//! ┌─────────────────────────────────────────────────────────────┐
11//! │                 ModelDistribution Trait                      │
12//! │  distribute_model() / convergence_status() / rollback()     │
13//! └──────────────────────────┬──────────────────────────────────┘
14//!                            │
15//!            ┌───────────────┴───────────────┐
16//!            ▼                               ▼
17//! ┌──────────────────────┐       ┌──────────────────────┐
18//! │  FileDistribution    │       │  ModelRegistry       │
19//! │  (blob transfer)     │       │  (version tracking)  │
20//! └──────────────────────┘       └──────────────────────┘
21//! ```
22//!
23//! # Usage
24//!
25//! ```ignore
26//! use peat_protocol::storage::{
27//!     ModelDistribution, DistributionScope, TransferPriority,
28//! };
29//!
30//! // Distribute model to capable nodes
31//! let handle = distribution.distribute_model(
32//!     "target_recognition",
33//!     "4.2.1",
34//!     DistributionScope::Capable {
35//!         min_gpu_gb: Some(4.0),
36//!         cpu_arch: Some("aarch64".into()),
37//!         min_storage_mb: Some(500),
38//!     },
39//!     TransferPriority::High,
40//! ).await?;
41//!
42//! // Check convergence status
43//! let status = distribution.convergence_status(
44//!     "target_recognition",
45//!     "4.2.1",
46//! ).await?;
47//!
48//! println!("Converged: {}/{}", status.converged, status.total_platforms);
49//! ```
50
51use super::file_distribution::{DistributionHandle, DistributionScope, TransferPriority};
52use anyhow::Result;
53use chrono::{DateTime, Utc};
54use serde::{Deserialize, Serialize};
55use std::collections::HashMap;
56use std::time::Duration;
57use tokio::sync::RwLock;
58
59// ============================================================================
60// Types
61// ============================================================================
62
63/// Handle to track a model distribution operation
64#[derive(Clone, Debug, Serialize, Deserialize)]
65pub struct ModelDistributionHandle {
66    /// Model identifier
67    pub model_id: String,
68    /// Target version
69    pub version: String,
70    /// Selected variant for this distribution
71    pub variant_id: String,
72    /// Underlying file distribution handle
73    pub distribution_handle: DistributionHandle,
74    /// When distribution was initiated
75    pub initiated_at: DateTime<Utc>,
76}
77
78/// Status of model convergence across the formation
79#[derive(Clone, Debug, Serialize, Deserialize)]
80pub struct ModelConvergenceStatus {
81    /// Model identifier
82    pub model_id: String,
83    /// Target version we're converging to
84    pub target_version: String,
85    /// Total number of target platforms
86    pub total_platforms: usize,
87    /// Platforms that have target version AND it's operational
88    pub converged: usize,
89    /// Platforms currently receiving/deploying the model
90    pub in_progress: usize,
91    /// Platforms not yet started
92    pub pending: usize,
93    /// Platforms where distribution/deployment failed
94    pub failed: usize,
95    /// Distribution of versions across platforms (version -> count)
96    pub version_distribution: HashMap<String, usize>,
97    /// What's blocking convergence on specific nodes
98    pub blockers: Vec<ConvergenceBlocker>,
99    /// Estimated time to full convergence (if calculable)
100    #[serde(skip_serializing_if = "Option::is_none")]
101    pub estimated_completion: Option<Duration>,
102}
103
104impl ModelConvergenceStatus {
105    /// Create new convergence status
106    pub fn new(model_id: &str, target_version: &str, total_platforms: usize) -> Self {
107        Self {
108            model_id: model_id.to_string(),
109            target_version: target_version.to_string(),
110            total_platforms,
111            converged: 0,
112            in_progress: 0,
113            pending: total_platforms,
114            failed: 0,
115            version_distribution: HashMap::new(),
116            blockers: Vec::new(),
117            estimated_completion: None,
118        }
119    }
120
121    /// Check if convergence is complete (all platforms converged or failed)
122    pub fn is_complete(&self) -> bool {
123        self.converged + self.failed >= self.total_platforms
124    }
125
126    /// Check if convergence succeeded (all platforms have target version)
127    pub fn is_success(&self) -> bool {
128        self.converged >= self.total_platforms && self.failed == 0
129    }
130
131    /// Calculate convergence progress (0.0 to 1.0)
132    pub fn convergence_progress(&self) -> f64 {
133        if self.total_platforms == 0 {
134            return 1.0;
135        }
136        self.converged as f64 / self.total_platforms as f64
137    }
138}
139
140/// What's blocking convergence on a specific node
141#[derive(Clone, Debug, Serialize, Deserialize)]
142pub struct ConvergenceBlocker {
143    /// Node that's blocked
144    pub node_id: String,
145    /// Why it's blocked
146    pub reason: BlockerReason,
147    /// When the block was first detected
148    pub since: DateTime<Utc>,
149    /// Additional context
150    #[serde(skip_serializing_if = "Option::is_none")]
151    pub details: Option<String>,
152}
153
154impl ConvergenceBlocker {
155    /// Create a new blocker
156    pub fn new(node_id: &str, reason: BlockerReason) -> Self {
157        Self {
158            node_id: node_id.to_string(),
159            reason,
160            since: Utc::now(),
161            details: None,
162        }
163    }
164
165    /// Add details to the blocker
166    pub fn with_details(mut self, details: &str) -> Self {
167        self.details = Some(details.to_string());
168        self
169    }
170}
171
172/// Reason a node is blocked from convergence
173#[derive(Clone, Debug, Serialize, Deserialize, PartialEq, Eq)]
174pub enum BlockerReason {
175    /// Node is network partitioned
176    NetworkPartition,
177    /// Insufficient storage space
178    InsufficientStorage,
179    /// Insufficient GPU memory for model
180    InsufficientGpuMemory,
181    /// File transfer failed
182    TransferFailed,
183    /// Model deployment/loading failed
184    DeploymentFailed,
185    /// Node doesn't meet capability requirements
186    IncompatibleCapabilities,
187    /// Node is currently busy with another operation
188    NodeBusy,
189    /// Unknown/other reason
190    Unknown,
191}
192
193impl std::fmt::Display for BlockerReason {
194    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
195        match self {
196            Self::NetworkPartition => write!(f, "Network partition"),
197            Self::InsufficientStorage => write!(f, "Insufficient storage"),
198            Self::InsufficientGpuMemory => write!(f, "Insufficient GPU memory"),
199            Self::TransferFailed => write!(f, "Transfer failed"),
200            Self::DeploymentFailed => write!(f, "Deployment failed"),
201            Self::IncompatibleCapabilities => write!(f, "Incompatible capabilities"),
202            Self::NodeBusy => write!(f, "Node busy"),
203            Self::Unknown => write!(f, "Unknown"),
204        }
205    }
206}
207
208/// Node's model deployment status
209#[derive(Clone, Debug, Serialize, Deserialize)]
210pub struct NodeModelStatus {
211    /// Node identifier
212    pub node_id: String,
213    /// Currently deployed model version (if any)
214    #[serde(skip_serializing_if = "Option::is_none")]
215    pub current_version: Option<String>,
216    /// Deployed variant ID
217    #[serde(skip_serializing_if = "Option::is_none")]
218    pub variant_id: Option<String>,
219    /// Model operational status
220    pub operational_status: ModelOperationalStatus,
221    /// Last status update time
222    pub last_updated: DateTime<Utc>,
223}
224
225/// Operational status of a deployed model
226#[derive(Clone, Debug, Default, Serialize, Deserialize, PartialEq, Eq)]
227pub enum ModelOperationalStatus {
228    /// No model deployed
229    #[default]
230    NotDeployed,
231    /// Model is being downloaded
232    Downloading,
233    /// Model is being loaded into runtime
234    Loading,
235    /// Model is operational and serving inference
236    Operational,
237    /// Model is loaded but degraded (e.g., high latency)
238    Degraded,
239    /// Model failed to load or crashed
240    Failed,
241}
242
243/// Variant selection criteria for model distribution
244#[derive(Clone, Debug, Default, Serialize, Deserialize)]
245pub struct VariantSelector {
246    /// Preferred precision (e.g., "float16", "int8")
247    #[serde(skip_serializing_if = "Option::is_none")]
248    pub preferred_precision: Option<String>,
249    /// Required execution providers
250    #[serde(skip_serializing_if = "Vec::is_empty", default)]
251    pub required_providers: Vec<String>,
252    /// Maximum model size in bytes
253    #[serde(skip_serializing_if = "Option::is_none")]
254    pub max_size_bytes: Option<u64>,
255}
256
257// ============================================================================
258// ModelDistribution Trait
259// ============================================================================
260
261/// Model distribution service for AI/ML model deployment
262///
263/// Provides model-specific distribution with variant selection,
264/// convergence tracking, and rollback capabilities.
265#[async_trait::async_trait]
266pub trait ModelDistribution: Send + Sync {
267    /// Distribute a model version to target platforms
268    ///
269    /// Selects appropriate variant based on target capabilities and initiates
270    /// distribution. Variant selection considers GPU memory, CPU architecture,
271    /// and available execution providers.
272    ///
273    /// # Arguments
274    ///
275    /// * `model_id` - Model identifier (e.g., "target_recognition")
276    /// * `version` - Semantic version (e.g., "4.2.1")
277    /// * `scope` - Target platforms (all, formation, specific nodes, capable)
278    /// * `priority` - Transfer priority
279    ///
280    /// # Returns
281    ///
282    /// Handle for tracking distribution progress
283    async fn distribute_model(
284        &self,
285        model_id: &str,
286        version: &str,
287        scope: DistributionScope,
288        priority: TransferPriority,
289    ) -> Result<ModelDistributionHandle>;
290
291    /// Distribute a model with explicit variant selection
292    ///
293    /// Use when automatic variant selection is not desired.
294    async fn distribute_model_variant(
295        &self,
296        model_id: &str,
297        version: &str,
298        variant_id: &str,
299        scope: DistributionScope,
300        priority: TransferPriority,
301    ) -> Result<ModelDistributionHandle>;
302
303    /// Distribute model delta (differential update)
304    ///
305    /// For large models, only transfer changed chunks between versions.
306    /// Requires target platforms to have `from_version` locally.
307    ///
308    /// # Note
309    ///
310    /// Delta updates use content-defined chunking to minimize transfer size.
311    /// If target doesn't have `from_version`, falls back to full distribution.
312    async fn distribute_model_delta(
313        &self,
314        model_id: &str,
315        from_version: &str,
316        to_version: &str,
317        scope: DistributionScope,
318    ) -> Result<ModelDistributionHandle>;
319
320    /// Get convergence status for a model version
321    ///
322    /// Returns detailed status of how many platforms have converged to
323    /// the target version, what's blocking others, and estimated completion.
324    async fn convergence_status(
325        &self,
326        model_id: &str,
327        target_version: &str,
328    ) -> Result<ModelConvergenceStatus>;
329
330    /// Initiate rollback to a previous version
331    ///
332    /// Distributes the previous version to all platforms that have the
333    /// current (problematic) version.
334    async fn rollback(
335        &self,
336        model_id: &str,
337        to_version: &str,
338        scope: DistributionScope,
339    ) -> Result<ModelDistributionHandle>;
340
341    /// Get model status on a specific node
342    async fn node_model_status(
343        &self,
344        model_id: &str,
345        node_id: &str,
346    ) -> Result<Option<NodeModelStatus>>;
347
348    /// List all nodes with a specific model version
349    async fn nodes_with_version(
350        &self,
351        model_id: &str,
352        version: &str,
353    ) -> Result<Vec<NodeModelStatus>>;
354
355    /// Cancel an in-progress model distribution
356    async fn cancel(&self, handle: &ModelDistributionHandle) -> Result<()>;
357
358    /// Subscribe to convergence status updates
359    async fn subscribe_convergence(
360        &self,
361        model_id: &str,
362        target_version: &str,
363    ) -> Result<tokio::sync::broadcast::Receiver<ModelConvergenceStatus>>;
364}
365
366// ============================================================================
367// In-Memory Model Registry (for tracking deployed versions)
368// ============================================================================
369
370/// Tracks which model versions are deployed on which nodes
371#[derive(Debug, Default)]
372pub struct ModelDeploymentTracker {
373    /// Node model statuses: node_id -> model_id -> status
374    node_statuses: RwLock<HashMap<String, HashMap<String, NodeModelStatus>>>,
375    /// Active distributions: distribution_id -> handle
376    active_distributions: RwLock<HashMap<String, ModelDistributionHandle>>,
377    /// Convergence status channels: (model_id, version) -> broadcast sender
378    #[allow(dead_code)] // For future subscribe_convergence implementation
379    convergence_channels:
380        RwLock<HashMap<(String, String), tokio::sync::broadcast::Sender<ModelConvergenceStatus>>>,
381}
382
383impl ModelDeploymentTracker {
384    /// Create a new deployment tracker
385    pub fn new() -> Self {
386        Self::default()
387    }
388
389    /// Update node's model status
390    pub async fn update_node_status(&self, status: NodeModelStatus) {
391        let mut statuses = self.node_statuses.write().await;
392        let node_models = statuses.entry(status.node_id.clone()).or_default();
393
394        // Extract model_id from status if we can infer it
395        if let Some(ref version) = status.current_version {
396            // We need a model_id - for now use a placeholder approach
397            // In real usage, the status would include the model_id
398            node_models.insert(version.clone(), status);
399        }
400    }
401
402    /// Get status for a specific node and model
403    pub async fn get_node_status(&self, model_id: &str, node_id: &str) -> Option<NodeModelStatus> {
404        let statuses = self.node_statuses.read().await;
405        statuses
406            .get(node_id)
407            .and_then(|models| models.get(model_id))
408            .cloned()
409    }
410
411    /// Get all nodes with a specific model version
412    pub async fn get_nodes_with_version(
413        &self,
414        model_id: &str,
415        version: &str,
416    ) -> Vec<NodeModelStatus> {
417        let statuses = self.node_statuses.read().await;
418        statuses
419            .values()
420            .filter_map(|models| models.get(model_id))
421            .filter(|status| status.current_version.as_deref() == Some(version))
422            .cloned()
423            .collect()
424    }
425
426    /// Register an active distribution
427    pub async fn register_distribution(&self, handle: ModelDistributionHandle) {
428        let mut distributions = self.active_distributions.write().await;
429        distributions.insert(handle.distribution_handle.distribution_id.clone(), handle);
430    }
431
432    /// Get active distribution by ID
433    pub async fn get_distribution(&self, distribution_id: &str) -> Option<ModelDistributionHandle> {
434        let distributions = self.active_distributions.read().await;
435        distributions.get(distribution_id).cloned()
436    }
437
438    /// Remove completed distribution
439    pub async fn complete_distribution(&self, distribution_id: &str) {
440        let mut distributions = self.active_distributions.write().await;
441        distributions.remove(distribution_id);
442    }
443
444    /// Calculate convergence status for a model version
445    pub async fn calculate_convergence(
446        &self,
447        model_id: &str,
448        target_version: &str,
449        total_platforms: usize,
450    ) -> ModelConvergenceStatus {
451        let statuses = self.node_statuses.read().await;
452
453        let mut status = ModelConvergenceStatus::new(model_id, target_version, total_platforms);
454        let mut version_counts: HashMap<String, usize> = HashMap::new();
455
456        for (node_id, models) in statuses.iter() {
457            if let Some(node_status) = models.get(model_id) {
458                if let Some(ref version) = node_status.current_version {
459                    *version_counts.entry(version.clone()).or_default() += 1;
460
461                    if version == target_version {
462                        match node_status.operational_status {
463                            ModelOperationalStatus::Operational => {
464                                status.converged += 1;
465                                status.pending = status.pending.saturating_sub(1);
466                            }
467                            ModelOperationalStatus::Downloading
468                            | ModelOperationalStatus::Loading => {
469                                status.in_progress += 1;
470                                status.pending = status.pending.saturating_sub(1);
471                            }
472                            ModelOperationalStatus::Failed => {
473                                status.failed += 1;
474                                status.pending = status.pending.saturating_sub(1);
475                                status.blockers.push(ConvergenceBlocker::new(
476                                    node_id,
477                                    BlockerReason::DeploymentFailed,
478                                ));
479                            }
480                            ModelOperationalStatus::Degraded => {
481                                // Count as converged but note degradation
482                                status.converged += 1;
483                                status.pending = status.pending.saturating_sub(1);
484                            }
485                            ModelOperationalStatus::NotDeployed => {
486                                // Still pending
487                            }
488                        }
489                    }
490                }
491            }
492        }
493
494        status.version_distribution = version_counts;
495        status
496    }
497}
498
499// ============================================================================
500// Tests
501// ============================================================================
502
#[cfg(test)]
mod tests {
    use super::*;

    /// A fresh status has everything pending and nothing converged.
    #[test]
    fn test_convergence_status_creation() {
        let fresh = ModelConvergenceStatus::new("target_recognition", "4.2.1", 10);

        assert_eq!(fresh.model_id, "target_recognition");
        assert_eq!(fresh.target_version, "4.2.1");
        assert_eq!(fresh.total_platforms, 10);
        assert_eq!(fresh.converged, 0);
        assert_eq!(fresh.pending, 10);
        assert!(!fresh.is_complete());
        assert!(!fresh.is_success());
        assert_eq!(fresh.convergence_progress(), 0.0);
    }

    /// Progress follows the converged count up to full convergence.
    #[test]
    fn test_convergence_progress() {
        let mut rollout = ModelConvergenceStatus::new("model", "1.0", 10);

        // Halfway there
        rollout.converged = 5;
        rollout.pending = 5;
        assert_eq!(rollout.convergence_progress(), 0.5);
        assert!(!rollout.is_complete());

        // Fully converged
        rollout.converged = 10;
        rollout.pending = 0;
        assert_eq!(rollout.convergence_progress(), 1.0);
        assert!(rollout.is_complete());
        assert!(rollout.is_success());
    }

    /// Failures complete a rollout but prevent it from being a success.
    #[test]
    fn test_convergence_with_failures() {
        let mut rollout = ModelConvergenceStatus::new("model", "1.0", 10);
        rollout.converged = 8;
        rollout.failed = 2;
        rollout.pending = 0;

        assert!(rollout.is_complete());
        assert!(!rollout.is_success()); // Not success because of failures
    }

    /// The builder keeps node, reason and the attached details.
    #[test]
    fn test_blocker_creation() {
        let gpu_blocker = ConvergenceBlocker::new("node-1", BlockerReason::InsufficientGpuMemory)
            .with_details("Required 8GB, available 4GB");

        assert_eq!(gpu_blocker.node_id, "node-1");
        assert_eq!(gpu_blocker.reason, BlockerReason::InsufficientGpuMemory);
        assert_eq!(
            gpu_blocker.details.as_deref(),
            Some("Required 8GB, available 4GB")
        );
    }

    /// `Display` produces the expected labels.
    #[test]
    fn test_blocker_reason_display() {
        assert_eq!(
            BlockerReason::NetworkPartition.to_string(),
            "Network partition"
        );
        assert_eq!(
            BlockerReason::InsufficientStorage.to_string(),
            "Insufficient storage"
        );
        assert_eq!(BlockerReason::TransferFailed.to_string(), "Transfer failed");
    }

    /// A status record keeps the values it was built with.
    #[test]
    fn test_node_model_status() {
        let node = NodeModelStatus {
            node_id: "node-1".to_string(),
            current_version: Some("4.2.1".to_string()),
            variant_id: Some("fp16-cuda".to_string()),
            operational_status: ModelOperationalStatus::Operational,
            last_updated: Utc::now(),
        };

        assert_eq!(node.current_version.as_deref(), Some("4.2.1"));
        assert_eq!(
            node.operational_status,
            ModelOperationalStatus::Operational
        );
    }

    /// An empty tracker knows no nodes and no versions.
    #[tokio::test]
    async fn test_deployment_tracker() {
        let tracker = ModelDeploymentTracker::new();

        // No status initially
        assert!(tracker.get_node_status("model-1", "node-1").await.is_none());

        // Empty nodes list
        assert!(tracker
            .get_nodes_with_version("model-1", "1.0")
            .await
            .is_empty());
    }

    /// A handle keeps the model identity it was built with.
    #[test]
    fn test_model_distribution_handle() {
        use super::super::blob_traits::BlobHash;

        let file_handle = DistributionHandle::new(
            BlobHash::from_hex("abc123"),
            DistributionScope::AllNodes,
            TransferPriority::High,
        );
        let handle = ModelDistributionHandle {
            model_id: "target_recognition".to_string(),
            version: "4.2.1".to_string(),
            variant_id: "fp16-cuda".to_string(),
            distribution_handle: file_handle,
            initiated_at: Utc::now(),
        };

        assert_eq!(handle.model_id, "target_recognition");
        assert_eq!(handle.version, "4.2.1");
        assert_eq!(handle.variant_id, "fp16-cuda");
    }
}