use std::collections::HashMap;
use std::sync::{Arc, RwLock};
use tenflowers_core::{Result, TensorError};
use super::types::{CheckpointState, StreamingConfig, WorkerHealth, WorkerMetrics, WorkerStatus};
/// Coordinates a set of streaming workers: tracks per-rank sample
/// assignments, health/heartbeat records, checkpoint state, and
/// load-balancing metrics. Each map is keyed by worker rank and wrapped
/// in `Arc<RwLock<_>>` so the coordinator can be shared across threads.
pub struct StreamCoordinator {
// Streaming configuration; validated in `new` via `config.validate()`.
pub(super) config: StreamingConfig,
// rank -> sample indices assigned to that worker (set in `register_worker`).
pub(super) worker_assignments: Arc<RwLock<HashMap<usize, Vec<usize>>>>,
// rank -> latest health record (heartbeat, throughput, status).
pub(super) worker_health: Arc<RwLock<HashMap<usize, WorkerHealth>>>,
// rank -> most recently registered checkpoint.
pub(super) global_checkpoints: Arc<RwLock<HashMap<usize, CheckpointState>>>,
// rank -> balancing metrics. Not read by the methods visible in this
// file; presumably consumed by sibling modules via `pub(super)` — verify.
pub(super) balancing_metrics: Arc<RwLock<HashMap<usize, WorkerMetrics>>>,
}
impl StreamCoordinator {
pub fn new(config: StreamingConfig) -> Result<Self> {
config.validate()?;
Ok(Self {
config,
worker_assignments: Arc::new(RwLock::new(HashMap::new())),
worker_health: Arc::new(RwLock::new(HashMap::new())),
global_checkpoints: Arc::new(RwLock::new(HashMap::new())),
balancing_metrics: Arc::new(RwLock::new(HashMap::new())),
})
}
pub fn register_worker(&self, rank: usize, indices: Vec<usize>) -> Result<()> {
let mut assignments = self
.worker_assignments
.write()
.map_err(|e| TensorError::invalid_operation_simple(format!("Lock error: {}", e)))?;
assignments.insert(rank, indices);
let mut health = self
.worker_health
.write()
.map_err(|e| TensorError::invalid_operation_simple(format!("Lock error: {}", e)))?;
health.insert(
rank,
WorkerHealth {
rank,
status: WorkerStatus::Active,
last_heartbeat: Self::current_timestamp(),
samples_processed: 0,
average_throughput: 0.0,
},
);
Ok(())
}
pub fn update_worker_health(
&self,
rank: usize,
samples_processed: u64,
throughput: f64,
) -> Result<()> {
let mut health = self
.worker_health
.write()
.map_err(|e| TensorError::invalid_operation_simple(format!("Lock error: {}", e)))?;
if let Some(worker_health) = health.get_mut(&rank) {
worker_health.last_heartbeat = Self::current_timestamp();
worker_health.samples_processed = samples_processed;
worker_health.average_throughput = throughput;
worker_health.status = Self::determine_worker_status(throughput);
}
Ok(())
}
pub fn get_worker_health(&self, rank: usize) -> Result<Option<WorkerHealth>> {
let health = self
.worker_health
.read()
.map_err(|e| TensorError::invalid_operation_simple(format!("Lock error: {}", e)))?;
Ok(health.get(&rank).cloned())
}
pub fn register_checkpoint(&self, rank: usize, checkpoint: CheckpointState) -> Result<()> {
let mut checkpoints = self
.global_checkpoints
.write()
.map_err(|e| TensorError::invalid_operation_simple(format!("Lock error: {}", e)))?;
checkpoints.insert(rank, checkpoint);
Ok(())
}
pub fn get_checkpoint(&self, rank: usize) -> Result<Option<CheckpointState>> {
let checkpoints = self
.global_checkpoints
.read()
.map_err(|e| TensorError::invalid_operation_simple(format!("Lock error: {}", e)))?;
Ok(checkpoints.get(&rank).cloned())
}
pub fn rebalance_if_needed(&self) -> Result<bool> {
if !self.config.dynamic_balancing {
return Ok(false);
}
let health = self
.worker_health
.read()
.map_err(|e| TensorError::invalid_operation_simple(format!("Lock error: {}", e)))?;
let workers: Vec<_> = health.values().collect();
if workers.is_empty() {
return Ok(false);
}
let avg_throughput: f64 =
workers.iter().map(|w| w.average_throughput).sum::<f64>() / workers.len() as f64;
let variance: f64 = workers
.iter()
.map(|w| {
let diff = w.average_throughput - avg_throughput;
diff * diff
})
.sum::<f64>()
/ workers.len() as f64;
let std_dev = variance.sqrt();
let coefficient_of_variation = if avg_throughput > 0.0 {
std_dev / avg_throughput
} else {
0.0
};
let rebalanced = coefficient_of_variation > 0.2;
Ok(rebalanced)
}
fn current_timestamp() -> u64 {
std::time::SystemTime::now()
.duration_since(std::time::UNIX_EPOCH)
.map(|d| d.as_secs())
.unwrap_or(0)
}
fn determine_worker_status(throughput: f64) -> WorkerStatus {
if throughput <= 0.0 {
WorkerStatus::Failed
} else if throughput < 10.0 {
WorkerStatus::Slow
} else {
WorkerStatus::Active
}
}
}