use super::file_distribution::{DistributionHandle, DistributionScope, TransferPriority};
use anyhow::Result;
use chrono::{DateTime, Utc};
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::time::Duration;
use tokio::sync::RwLock;
/// Handle returned when a model distribution is initiated.
///
/// Ties the model/version/variant being shipped to the underlying
/// blob-transfer [`DistributionHandle`] so callers can later query or
/// cancel the transfer.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct ModelDistributionHandle {
    /// Identifier of the model being distributed.
    pub model_id: String,
    /// Version of the model being distributed.
    pub version: String,
    /// Variant being shipped (e.g. "fp16-cuda").
    pub variant_id: String,
    /// Handle into the lower-level blob distribution layer.
    pub distribution_handle: DistributionHandle,
    /// Timestamp at which the distribution was initiated.
    pub initiated_at: DateTime<Utc>,
}
/// Aggregated, fleet-wide view of how far a model rollout has progressed
/// toward `target_version`.
///
/// Counter fields partition `total_platforms`; see
/// `ModelDeploymentTracker::calculate_convergence` for how nodes are
/// bucketed.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct ModelConvergenceStatus {
    /// Model being rolled out.
    pub model_id: String,
    /// Version the fleet should converge on.
    pub target_version: String,
    /// Total number of platforms expected to receive the model.
    pub total_platforms: usize,
    /// Platforms operational on the target version.
    pub converged: usize,
    /// Platforms currently downloading or loading the target version.
    pub in_progress: usize,
    /// Platforms not yet started (or with no status on record).
    pub pending: usize,
    /// Platforms whose deployment of the target version failed.
    pub failed: usize,
    /// Count of platforms per currently-deployed version string.
    pub version_distribution: HashMap<String, usize>,
    /// Per-node reasons preventing convergence.
    pub blockers: Vec<ConvergenceBlocker>,
    /// Estimated time to completion, when known; omitted from JSON when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub estimated_completion: Option<Duration>,
}
impl ModelConvergenceStatus {
    /// Create a fresh status for a rollout: every platform starts out
    /// pending and all other counters at zero.
    pub fn new(model_id: &str, target_version: &str, total_platforms: usize) -> Self {
        Self {
            model_id: model_id.to_owned(),
            target_version: target_version.to_owned(),
            total_platforms,
            pending: total_platforms,
            converged: 0,
            in_progress: 0,
            failed: 0,
            version_distribution: HashMap::default(),
            blockers: Vec::default(),
            estimated_completion: None,
        }
    }

    /// True once every platform is accounted for as either converged or
    /// failed (in-progress/pending nodes mean the rollout is still running).
    pub fn is_complete(&self) -> bool {
        self.total_platforms <= self.converged + self.failed
    }

    /// True when the whole fleet converged with zero failures.
    pub fn is_success(&self) -> bool {
        self.failed == 0 && self.converged >= self.total_platforms
    }

    /// Fraction of platforms converged, in `[0.0, 1.0]`.
    /// An empty fleet is trivially fully converged.
    pub fn convergence_progress(&self) -> f64 {
        match self.total_platforms {
            0 => 1.0,
            total => self.converged as f64 / total as f64,
        }
    }
}
/// A per-node condition currently preventing a rollout from converging.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct ConvergenceBlocker {
    /// Node affected by the blocker.
    pub node_id: String,
    /// Why the node is blocked.
    pub reason: BlockerReason,
    /// When the blocker was first recorded.
    pub since: DateTime<Utc>,
    /// Optional human-readable detail; omitted from JSON when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub details: Option<String>,
}
impl ConvergenceBlocker {
    /// Record a new blocker for `node_id`, timestamped now, with no details.
    pub fn new(node_id: &str, reason: BlockerReason) -> Self {
        let since = Utc::now();
        Self {
            node_id: node_id.to_owned(),
            reason,
            since,
            details: None,
        }
    }

    /// Builder-style helper: attach a human-readable explanation.
    pub fn with_details(mut self, details: &str) -> Self {
        self.details = Some(details.to_owned());
        self
    }
}
/// Classification of why a node cannot converge on the target version.
#[derive(Clone, Debug, Serialize, Deserialize, PartialEq, Eq)]
pub enum BlockerReason {
    /// Node is unreachable due to a network split.
    NetworkPartition,
    /// Not enough disk space for the model artifact.
    InsufficientStorage,
    /// Not enough GPU memory to load the model.
    InsufficientGpuMemory,
    /// The blob transfer to the node failed.
    TransferFailed,
    /// The model transferred but failed to deploy.
    DeploymentFailed,
    /// Node capabilities do not match any available variant.
    IncompatibleCapabilities,
    /// Node declined the deployment because it is busy.
    NodeBusy,
    /// Catch-all for unclassified blockers.
    Unknown,
}
impl std::fmt::Display for BlockerReason {
    /// Human-readable label for each blocker reason.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let label = match self {
            Self::NetworkPartition => "Network partition",
            Self::InsufficientStorage => "Insufficient storage",
            Self::InsufficientGpuMemory => "Insufficient GPU memory",
            Self::TransferFailed => "Transfer failed",
            Self::DeploymentFailed => "Deployment failed",
            Self::IncompatibleCapabilities => "Incompatible capabilities",
            Self::NodeBusy => "Node busy",
            Self::Unknown => "Unknown",
        };
        f.write_str(label)
    }
}
/// Last known state of one model on one node.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct NodeModelStatus {
    /// Node reporting the status.
    pub node_id: String,
    /// Version currently present on the node, if any; omitted from JSON when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub current_version: Option<String>,
    /// Deployed variant (e.g. "fp16-cuda"), if known; omitted from JSON when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub variant_id: Option<String>,
    /// Lifecycle state of the model on this node.
    pub operational_status: ModelOperationalStatus,
    /// When this status was last refreshed.
    pub last_updated: DateTime<Utc>,
}
/// Lifecycle state of a model deployment on a single node.
///
/// `ModelDeploymentTracker::calculate_convergence` buckets these as:
/// `Operational`/`Degraded` -> converged, `Downloading`/`Loading` ->
/// in progress, `Failed` -> failed, `NotDeployed` -> pending.
#[derive(Clone, Debug, Default, Serialize, Deserialize, PartialEq, Eq)]
pub enum ModelOperationalStatus {
    /// No artifact deployed on the node (the default).
    #[default]
    NotDeployed,
    /// Artifact transfer in progress.
    Downloading,
    /// Artifact present, model being loaded.
    Loading,
    /// Model loaded and serving normally.
    Operational,
    /// Model serving but impaired.
    Degraded,
    /// Deployment or load failed.
    Failed,
}
/// Criteria for choosing among a model's published variants.
#[derive(Clone, Debug, Default, Serialize, Deserialize)]
pub struct VariantSelector {
    /// Preferred numeric precision — presumably strings like "fp16"/"int8";
    /// TODO(review): confirm accepted values against the variant catalog.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub preferred_precision: Option<String>,
    /// Execution providers the variant must support; skipped in JSON when empty.
    #[serde(skip_serializing_if = "Vec::is_empty", default)]
    pub required_providers: Vec<String>,
    /// Upper bound on artifact size in bytes, if any.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub max_size_bytes: Option<u64>,
}
/// Fleet-wide distribution and rollout lifecycle for ML models.
///
/// Implementations ship model artifacts via the underlying blob
/// distribution layer and report rollout progress per node and fleet-wide.
#[async_trait::async_trait]
pub trait ModelDistribution: Send + Sync {
    /// Distribute `model_id`@`version` to the nodes selected by `scope`
    /// at the given transfer `priority`.
    async fn distribute_model(
        &self,
        model_id: &str,
        version: &str,
        scope: DistributionScope,
        priority: TransferPriority,
    ) -> Result<ModelDistributionHandle>;

    /// As [`Self::distribute_model`], but ships a specific variant
    /// (e.g. "fp16-cuda") identified by `variant_id`.
    async fn distribute_model_variant(
        &self,
        model_id: &str,
        version: &str,
        variant_id: &str,
        scope: DistributionScope,
        priority: TransferPriority,
    ) -> Result<ModelDistributionHandle>;

    /// Distribute only the delta between `from_version` and `to_version`
    /// to nodes in `scope`.
    async fn distribute_model_delta(
        &self,
        model_id: &str,
        from_version: &str,
        to_version: &str,
        scope: DistributionScope,
    ) -> Result<ModelDistributionHandle>;

    /// Fleet-wide convergence snapshot for `model_id` toward `target_version`.
    async fn convergence_status(
        &self,
        model_id: &str,
        target_version: &str,
    ) -> Result<ModelConvergenceStatus>;

    /// Roll nodes in `scope` back to `to_version`; returns the handle of the
    /// distribution performing the rollback.
    async fn rollback(
        &self,
        model_id: &str,
        to_version: &str,
        scope: DistributionScope,
    ) -> Result<ModelDistributionHandle>;

    /// Status of `model_id` on a single node; `Ok(None)` when the node has
    /// no record for that model.
    async fn node_model_status(
        &self,
        model_id: &str,
        node_id: &str,
    ) -> Result<Option<NodeModelStatus>>;

    /// All nodes currently reporting `version` of `model_id`.
    async fn nodes_with_version(
        &self,
        model_id: &str,
        version: &str,
    ) -> Result<Vec<NodeModelStatus>>;

    /// Cancel an in-flight distribution identified by `handle`.
    async fn cancel(&self, handle: &ModelDistributionHandle) -> Result<()>;

    /// Subscribe to convergence updates for `model_id` toward
    /// `target_version` via a broadcast channel.
    async fn subscribe_convergence(
        &self,
        model_id: &str,
        target_version: &str,
    ) -> Result<tokio::sync::broadcast::Receiver<ModelConvergenceStatus>>;
}
/// In-memory bookkeeping for model deployments across nodes.
#[derive(Debug, Default)]
pub struct ModelDeploymentTracker {
    /// node_id -> (model_id -> latest status). NOTE(review): the read paths
    /// key the inner map by model id, but `update_node_status` inserts under
    /// `current_version` — see that method for the mismatch.
    node_statuses: RwLock<HashMap<String, HashMap<String, NodeModelStatus>>>,
    /// distribution_id -> handle of an in-flight distribution.
    active_distributions: RwLock<HashMap<String, ModelDistributionHandle>>,
    /// (model_id, target_version) -> broadcast sender for convergence
    /// updates. Not yet wired up anywhere, hence `dead_code`.
    #[allow(dead_code)] convergence_channels:
    RwLock<HashMap<(String, String), tokio::sync::broadcast::Sender<ModelConvergenceStatus>>>,
}
impl ModelDeploymentTracker {
    /// Construct an empty tracker.
    pub fn new() -> Self {
        Self::default()
    }

    /// Record `status` for model `model_id` on the node named in
    /// `status.node_id`, replacing any previous entry for that model.
    ///
    /// This is the write path whose keying matches the model-id keyed reads
    /// (`get_node_status`, `get_nodes_with_version`, `calculate_convergence`).
    pub async fn update_model_status(&self, model_id: &str, status: NodeModelStatus) {
        let mut statuses = self.node_statuses.write().await;
        statuses
            .entry(status.node_id.clone())
            .or_default()
            .insert(model_id.to_string(), status);
    }

    /// Legacy write path, kept byte-compatible for existing callers.
    ///
    /// BUG: every read path keys the inner map by *model id*, but this method
    /// inserts under `status.current_version`, so entries stored here are only
    /// retrievable when the version string happens to equal the model id.
    /// Statuses with `current_version == None` are silently dropped. Prefer
    /// [`Self::update_model_status`], which takes the model id explicitly;
    /// this method's behavior is preserved unchanged.
    pub async fn update_node_status(&self, status: NodeModelStatus) {
        let mut statuses = self.node_statuses.write().await;
        let node_models = statuses.entry(status.node_id.clone()).or_default();
        if let Some(ref version) = status.current_version {
            node_models.insert(version.clone(), status);
        }
    }

    /// Latest recorded status of `model_id` on `node_id`, if any.
    pub async fn get_node_status(&self, model_id: &str, node_id: &str) -> Option<NodeModelStatus> {
        let statuses = self.node_statuses.read().await;
        statuses
            .get(node_id)
            .and_then(|models| models.get(model_id))
            .cloned()
    }

    /// All node statuses whose `current_version` equals `version` for
    /// `model_id`.
    pub async fn get_nodes_with_version(
        &self,
        model_id: &str,
        version: &str,
    ) -> Vec<NodeModelStatus> {
        let statuses = self.node_statuses.read().await;
        statuses
            .values()
            .filter_map(|models| models.get(model_id))
            .filter(|status| status.current_version.as_deref() == Some(version))
            .cloned()
            .collect()
    }

    /// Track an in-flight distribution, keyed by its distribution id.
    pub async fn register_distribution(&self, handle: ModelDistributionHandle) {
        let mut distributions = self.active_distributions.write().await;
        distributions.insert(handle.distribution_handle.distribution_id.clone(), handle);
    }

    /// Look up an in-flight distribution by id.
    pub async fn get_distribution(&self, distribution_id: &str) -> Option<ModelDistributionHandle> {
        let distributions = self.active_distributions.read().await;
        distributions.get(distribution_id).cloned()
    }

    /// Forget a distribution once it has finished (or been cancelled).
    pub async fn complete_distribution(&self, distribution_id: &str) {
        let mut distributions = self.active_distributions.write().await;
        distributions.remove(distribution_id);
    }

    /// Derive a convergence snapshot for `model_id` toward `target_version`
    /// from the per-node statuses currently on record.
    ///
    /// Counting rules (per node that reports a version for this model):
    /// * target version + `Operational` or `Degraded` -> `converged`
    ///   (NOTE(review): counting `Degraded` as converged means `is_success`
    ///   can report success while nodes are degraded — confirm intended);
    /// * target version + `Downloading`/`Loading` -> `in_progress`;
    /// * target version + `Failed` -> `failed` plus a `DeploymentFailed` blocker;
    /// * any other version, `NotDeployed`, or no record at all -> left `pending`.
    pub async fn calculate_convergence(
        &self,
        model_id: &str,
        target_version: &str,
        total_platforms: usize,
    ) -> ModelConvergenceStatus {
        let statuses = self.node_statuses.read().await;
        let mut status = ModelConvergenceStatus::new(model_id, target_version, total_platforms);
        let mut version_counts: HashMap<String, usize> = HashMap::new();
        for (node_id, models) in statuses.iter() {
            if let Some(node_status) = models.get(model_id) {
                if let Some(ref version) = node_status.current_version {
                    *version_counts.entry(version.clone()).or_default() += 1;
                    if version == target_version {
                        match node_status.operational_status {
                            ModelOperationalStatus::Operational
                            | ModelOperationalStatus::Degraded => {
                                status.converged += 1;
                                status.pending = status.pending.saturating_sub(1);
                            }
                            ModelOperationalStatus::Downloading
                            | ModelOperationalStatus::Loading => {
                                status.in_progress += 1;
                                status.pending = status.pending.saturating_sub(1);
                            }
                            ModelOperationalStatus::Failed => {
                                status.failed += 1;
                                status.pending = status.pending.saturating_sub(1);
                                status.blockers.push(ConvergenceBlocker::new(
                                    node_id,
                                    BlockerReason::DeploymentFailed,
                                ));
                            }
                            // Version matches but nothing is actually
                            // deployed: the node stays pending.
                            ModelOperationalStatus::NotDeployed => {}
                        }
                    }
                }
            }
        }
        status.version_distribution = version_counts;
        status
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_convergence_status_creation() {
        let cs = ModelConvergenceStatus::new("target_recognition", "4.2.1", 10);
        assert_eq!(cs.model_id, "target_recognition");
        assert_eq!(cs.target_version, "4.2.1");
        assert_eq!(cs.total_platforms, 10);
        assert_eq!(cs.converged, 0);
        assert_eq!(cs.pending, 10);
        // A fresh status is neither complete nor successful.
        assert!(!cs.is_complete());
        assert!(!cs.is_success());
        assert_eq!(cs.convergence_progress(), 0.0);
    }

    #[test]
    fn test_convergence_progress() {
        let mut cs = ModelConvergenceStatus::new("model", "1.0", 10);
        // Halfway: five converged, five still pending.
        cs.converged = 5;
        cs.pending = 5;
        assert_eq!(cs.convergence_progress(), 0.5);
        assert!(!cs.is_complete());
        // Fully converged.
        cs.converged = 10;
        cs.pending = 0;
        assert_eq!(cs.convergence_progress(), 1.0);
        assert!(cs.is_complete());
        assert!(cs.is_success());
    }

    #[test]
    fn test_convergence_with_failures() {
        let mut cs = ModelConvergenceStatus::new("model", "1.0", 10);
        cs.converged = 8;
        cs.failed = 2;
        cs.pending = 0;
        // Every platform accounted for, but failures preclude success.
        assert!(cs.is_complete());
        assert!(!cs.is_success());
    }

    #[test]
    fn test_blocker_creation() {
        let blocker = ConvergenceBlocker::new("node-1", BlockerReason::InsufficientGpuMemory)
            .with_details("Required 8GB, available 4GB");
        assert_eq!(blocker.node_id, "node-1");
        assert_eq!(blocker.reason, BlockerReason::InsufficientGpuMemory);
        assert_eq!(
            blocker.details.as_deref(),
            Some("Required 8GB, available 4GB")
        );
    }

    #[test]
    fn test_blocker_reason_display() {
        // `to_string` goes through the same `Display` impl as `format!`.
        assert_eq!(
            BlockerReason::NetworkPartition.to_string(),
            "Network partition"
        );
        assert_eq!(
            BlockerReason::InsufficientStorage.to_string(),
            "Insufficient storage"
        );
        assert_eq!(
            BlockerReason::TransferFailed.to_string(),
            "Transfer failed"
        );
    }

    #[test]
    fn test_node_model_status() {
        let node = NodeModelStatus {
            node_id: "node-1".to_string(),
            current_version: Some("4.2.1".to_string()),
            variant_id: Some("fp16-cuda".to_string()),
            operational_status: ModelOperationalStatus::Operational,
            last_updated: Utc::now(),
        };
        assert_eq!(node.current_version.as_deref(), Some("4.2.1"));
        assert_eq!(
            node.operational_status,
            ModelOperationalStatus::Operational
        );
    }

    #[tokio::test]
    async fn test_deployment_tracker() {
        // An empty tracker answers every query with "nothing recorded".
        let tracker = ModelDeploymentTracker::new();
        assert!(tracker.get_node_status("model-1", "node-1").await.is_none());
        assert!(tracker
            .get_nodes_with_version("model-1", "1.0")
            .await
            .is_empty());
    }

    #[test]
    fn test_model_distribution_handle() {
        use super::super::blob_traits::BlobHash;
        let dist = ModelDistributionHandle {
            model_id: "target_recognition".to_string(),
            version: "4.2.1".to_string(),
            variant_id: "fp16-cuda".to_string(),
            distribution_handle: DistributionHandle::new(
                BlobHash::from_hex("abc123"),
                DistributionScope::AllNodes,
                TransferPriority::High,
            ),
            initiated_at: Utc::now(),
        };
        assert_eq!(dist.model_id, "target_recognition");
        assert_eq!(dist.version, "4.2.1");
        assert_eq!(dist.variant_id, "fp16-cuda");
    }
}