use std::time::Instant;
/// Lifecycle state of a worker managed by [`ElasticCoordinator`].
///
/// Transitions performed by the coordinator:
/// - `Syncing -> Active` via [`ElasticCoordinator::activate_worker`]
/// - `Active -> Draining` via [`ElasticCoordinator::remove_worker`]
/// - `Draining -> Left` via [`ElasticCoordinator::finalize_removal`]
/// - `Active -> Failed` via [`ElasticCoordinator::check_heartbeats`]
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum WorkerState {
    /// Counted in the world size and assigned a data shard.
    Active,
    /// Registered by `add_worker` and waiting for `activate_worker`.
    Syncing,
    /// Removal requested; excluded from the world size until finalized.
    Draining,
    /// Missed its heartbeat deadline while active; the coordinator never leaves this state.
    Failed,
    /// Removal finalized; the coordinator never leaves this state.
    Left,
}

/// Bookkeeping record for one worker tracked by [`ElasticCoordinator`].
#[derive(Debug, Clone)]
pub struct ElasticWorker {
    /// Coordinator-assigned identifier, unique within one coordinator.
    pub worker_id: u32,
    /// Caller-supplied node identifier (not checked for uniqueness).
    pub node_id: String,
    /// Current lifecycle state.
    pub state: WorkerState,
    /// GPU count as supplied to `add_worker`.
    pub gpu_count: u32,
    /// Backend name as supplied to `add_worker` (e.g. `"cuda"`, `"wgpu"`).
    pub backend: String,
    /// Monotonic-clock instant at which the worker was registered.
    pub joined_at: Instant,
    /// Value of the coordinator's current step (see `set_step`) at registration.
    pub joined_at_step: usize,
    /// Instant of the most recent heartbeat; initially the registration instant.
    pub last_heartbeat: Instant,
}

/// Tracks membership of an elastic training group and derives the world size and
/// data shards from the set of currently active workers.
///
/// Worker records are never deleted: workers that fail or leave remain visible in
/// [`all_workers`](Self::all_workers) with their final state.
#[derive(Debug)]
pub struct ElasticCoordinator {
    /// Every registered worker, in registration order.
    workers: Vec<ElasticWorker>,
    /// Identifier handed to the next registered worker.
    next_worker_id: u32,
    /// Below this many active workers, [`should_pause`](Self::should_pause) returns true.
    min_workers: usize,
    /// Upper bound on workers that are active or still syncing.
    max_workers: usize,
    /// Step recorded as `joined_at_step` for newly registered workers.
    current_step: usize,
    /// Set when membership changes; cleared explicitly via `clear_reconfig`.
    reconfig_pending: bool,
    /// An active worker silent for longer than this many milliseconds is marked failed.
    heartbeat_timeout_ms: u64,
}

impl ElasticCoordinator {
    /// Creates an empty coordinator.
    ///
    /// `min_workers` is the pause threshold, `max_workers` caps how many workers may
    /// be active or syncing at once, and `heartbeat_timeout_ms` is the silence after
    /// which [`check_heartbeats`](Self::check_heartbeats) declares an active worker failed.
    pub fn new(min_workers: usize, max_workers: usize, heartbeat_timeout_ms: u64) -> Self {
        Self {
            workers: Vec::new(),
            next_worker_id: 0,
            min_workers,
            max_workers,
            current_step: 0,
            reconfig_pending: false,
            heartbeat_timeout_ms,
        }
    }

    /// Registers a new worker in the `Syncing` state and flags a reconfiguration.
    ///
    /// Returns the new worker's id, or `None` when `max_workers` workers are already
    /// active or syncing (or, practically unreachable, when the `u32` id space is
    /// exhausted). A rejected call registers nothing.
    pub fn add_worker(&mut self, node_id: String, gpu_count: u32, backend: String) -> Option<u32> {
        // Syncing workers must count toward the cap: they become Active through
        // `activate_worker`, which performs no capacity check of its own.
        if self.admitted_count() >= self.max_workers {
            return None;
        }
        let worker_id = self.next_worker_id;
        // Refuse rather than wrap around and hand out a duplicate id.
        self.next_worker_id = worker_id.checked_add(1)?;
        let now = Instant::now();
        self.workers.push(ElasticWorker {
            worker_id,
            node_id,
            state: WorkerState::Syncing,
            gpu_count,
            backend,
            joined_at: now,
            joined_at_step: self.current_step,
            last_heartbeat: now,
        });
        self.reconfig_pending = true;
        Some(worker_id)
    }

    /// Moves a `Syncing` worker to `Active`.
    ///
    /// Returns `false` (and changes nothing) if the id is unknown or the worker is
    /// not syncing.
    pub fn activate_worker(&mut self, worker_id: u32) -> bool {
        self.transition(worker_id, WorkerState::Syncing, WorkerState::Active)
    }

    /// Begins removing an `Active` worker by moving it to `Draining`, and flags a
    /// reconfiguration.
    ///
    /// Returns `false` (and changes nothing) if the id is unknown or the worker is
    /// not active.
    pub fn remove_worker(&mut self, worker_id: u32) -> bool {
        let moved = self.transition(worker_id, WorkerState::Active, WorkerState::Draining);
        if moved {
            self.reconfig_pending = true;
        }
        moved
    }

    /// Completes a removal by moving a `Draining` worker to `Left`.
    ///
    /// Returns `false` (and changes nothing) if the id is unknown or the worker is
    /// not draining.
    pub fn finalize_removal(&mut self, worker_id: u32) -> bool {
        self.transition(worker_id, WorkerState::Draining, WorkerState::Left)
    }

    /// Marks as `Failed` every `Active` worker whose last heartbeat is older than the
    /// configured timeout, and returns their ids in registration order.
    ///
    /// Flags a reconfiguration when at least one worker failed. Workers in any other
    /// state (including `Syncing`) are not checked.
    pub fn check_heartbeats(&mut self) -> Vec<u32> {
        let now = Instant::now();
        let timeout = std::time::Duration::from_millis(self.heartbeat_timeout_ms);
        let mut failed = Vec::new();
        for worker in &mut self.workers {
            if worker.state == WorkerState::Active
                && now.duration_since(worker.last_heartbeat) > timeout
            {
                worker.state = WorkerState::Failed;
                failed.push(worker.worker_id);
            }
        }
        if !failed.is_empty() {
            self.reconfig_pending = true;
        }
        failed
    }

    /// Records a heartbeat for `worker_id` at the current instant.
    ///
    /// Unknown ids are ignored. The worker's state is left untouched, so a heartbeat
    /// does not bring a `Failed` worker back.
    pub fn update_heartbeat(&mut self, worker_id: u32) {
        if let Some(worker) = self.find_mut(worker_id) {
            worker.last_heartbeat = Instant::now();
        }
    }

    /// Number of workers currently in the `Active` state.
    pub fn active_count(&self) -> usize {
        self.workers.iter().filter(|w| w.state == WorkerState::Active).count()
    }

    /// Whether fewer than `min_workers` workers are active.
    pub fn should_pause(&self) -> bool {
        self.active_count() < self.min_workers
    }

    /// Whether membership changed since the flag was last cleared.
    pub fn needs_reconfig(&self) -> bool {
        self.reconfig_pending
    }

    /// Clears the pending-reconfiguration flag.
    pub fn clear_reconfig(&mut self) {
        self.reconfig_pending = false;
    }

    /// Ids of all `Active` workers, in registration order.
    pub fn active_worker_ids(&self) -> Vec<u32> {
        self.workers
            .iter()
            .filter(|w| w.state == WorkerState::Active)
            .map(|w| w.worker_id)
            .collect()
    }

    /// Every worker ever registered, in any state, in registration order.
    pub fn all_workers(&self) -> &[ElasticWorker] {
        &self.workers
    }

    /// Sets the step recorded as `joined_at_step` for workers registered afterwards.
    pub fn set_step(&mut self, step: usize) {
        self.current_step = step;
    }

    /// World size used for training: the number of active workers.
    pub fn effective_world_size(&self) -> usize {
        self.active_count()
    }

    /// Splits `0..total_samples` into contiguous, disjoint half-open ranges
    /// `(worker_id, start, end)`, one per active worker in registration order.
    ///
    /// The ranges cover `0..total_samples` exactly. Shard sizes differ by at most one,
    /// and the first `total_samples % n` workers receive the extra sample. Returns an
    /// empty vector when no worker is active.
    pub fn compute_shards(&self, total_samples: usize) -> Vec<(u32, usize, usize)> {
        let active = self.active_worker_ids();
        let n = active.len();
        if n == 0 {
            return Vec::new();
        }
        let base = total_samples / n;
        let extra = total_samples % n;
        active
            .into_iter()
            .enumerate()
            .map(|(i, worker_id)| {
                // Each earlier shard contributed `base` samples, plus one more for
                // each of the first `extra` shards.
                let start = i * base + i.min(extra);
                let len = base + usize::from(i < extra);
                (worker_id, start, start + len)
            })
            .collect()
    }

    /// Looks up a worker record by id.
    fn find_mut(&mut self, worker_id: u32) -> Option<&mut ElasticWorker> {
        self.workers.iter_mut().find(|w| w.worker_id == worker_id)
    }

    /// Moves `worker_id` from state `from` to state `to`; returns whether it did.
    fn transition(&mut self, worker_id: u32, from: WorkerState, to: WorkerState) -> bool {
        match self.find_mut(worker_id) {
            Some(worker) if worker.state == from => {
                worker.state = to;
                true
            }
            _ => false,
        }
    }

    /// Workers occupying a slot against `max_workers`: those active or still syncing.
    fn admitted_count(&self) -> usize {
        self.workers
            .iter()
            .filter(|w| matches!(w.state, WorkerState::Active | WorkerState::Syncing))
            .count()
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_elastic_coordinator_basic() {
        let mut coord = ElasticCoordinator::new(1, 8, 30000);
        assert_eq!(coord.active_count(), 0);
        assert!(coord.should_pause());
        let id = coord.add_worker("node-1".into(), 1, "cuda".into());
        assert_eq!(id, Some(0));
        // A freshly added worker is still syncing and does not count yet.
        assert_eq!(coord.active_count(), 0);
        assert!(coord.activate_worker(0));
        assert_eq!(coord.active_count(), 1);
        assert!(!coord.should_pause());
    }

    #[test]
    fn test_elastic_add_remove() {
        let mut coord = ElasticCoordinator::new(1, 4, 30000);
        coord.add_worker("n1".into(), 1, "cuda".into());
        coord.add_worker("n2".into(), 1, "cuda".into());
        coord.add_worker("n3".into(), 2, "wgpu".into());
        for id in 0..3 {
            assert!(coord.activate_worker(id));
        }
        assert_eq!(coord.active_count(), 3);
        // A draining worker leaves the active set immediately...
        assert!(coord.remove_worker(1));
        assert_eq!(coord.active_count(), 2);
        // ...and finalizing the removal does not change the count again.
        assert!(coord.finalize_removal(1));
        assert_eq!(coord.active_count(), 2);
        assert_eq!(coord.active_worker_ids(), vec![0, 2]);
    }

    #[test]
    fn test_elastic_max_workers() {
        let mut coord = ElasticCoordinator::new(1, 2, 30000);
        coord.add_worker("n1".into(), 1, "cuda".into());
        coord.activate_worker(0);
        coord.add_worker("n2".into(), 1, "cuda".into());
        coord.activate_worker(1);
        let id = coord.add_worker("n3".into(), 1, "cuda".into());
        assert_eq!(id, None);
        // A rejected join must not register a worker.
        assert_eq!(coord.all_workers().len(), 2);
        assert_eq!(coord.active_count(), 2);
    }

    #[test]
    fn test_elastic_shard_computation() {
        let mut coord = ElasticCoordinator::new(1, 4, 30000);
        for i in 0..3 {
            coord.add_worker(format!("n{i}"), 1, "cuda".into());
            coord.activate_worker(i as u32);
        }
        let shards = coord.compute_shards(100);
        assert_eq!(shards.len(), 3);
        let (_, s0, e0) = shards[0];
        let (_, s1, e1) = shards[1];
        let (_, s2, e2) = shards[2];
        assert_eq!(s0, 0);
        assert_eq!(e0, 34);
        assert_eq!(s1, 34);
        assert_eq!(e1, 67);
        assert_eq!(s2, 67);
        assert_eq!(e2, 100);
        assert_eq!(e0 - s0 + e1 - s1 + e2 - s2, 100);
    }

    #[test]
    fn test_elastic_shard_disjointness() {
        let mut coord = ElasticCoordinator::new(1, 8, 30000);
        for i in 0..5 {
            coord.add_worker(format!("n{i}"), 1, "cuda".into());
            coord.activate_worker(i as u32);
        }
        // A prime total guarantees an uneven split across the five workers.
        let total = 10007;
        let shards = coord.compute_shards(total);
        let mut covered = vec![false; total];
        for (_, start, end) in &shards {
            for i in *start..*end {
                assert!(!covered[i], "sample {i} covered by multiple shards");
                covered[i] = true;
            }
        }
        assert!(covered.iter().all(|&c| c), "not all samples covered");
    }

    #[test]
    fn test_elastic_reconfig_flag() {
        let mut coord = ElasticCoordinator::new(1, 4, 30000);
        assert!(!coord.needs_reconfig());
        coord.add_worker("n1".into(), 1, "cuda".into());
        assert!(coord.needs_reconfig());
        coord.clear_reconfig();
        assert!(!coord.needs_reconfig());
    }

    #[test]
    fn test_elastic_should_pause() {
        let mut coord = ElasticCoordinator::new(2, 4, 30000);
        assert!(coord.should_pause());
        coord.add_worker("n1".into(), 1, "cuda".into());
        coord.activate_worker(0);
        assert!(coord.should_pause());
        coord.add_worker("n2".into(), 1, "cuda".into());
        coord.activate_worker(1);
        assert!(!coord.should_pause());
    }

    #[test]
    fn test_elastic_effective_world_size() {
        let mut coord = ElasticCoordinator::new(1, 4, 30000);
        coord.add_worker("n1".into(), 1, "cuda".into());
        coord.add_worker("n2".into(), 1, "cuda".into());
        coord.activate_worker(0);
        coord.activate_worker(1);
        assert_eq!(coord.effective_world_size(), 2);
        assert!(coord.remove_worker(0));
        assert_eq!(coord.effective_world_size(), 1);
    }

    #[test]
    fn test_elastic_activate_non_syncing_worker() {
        let mut coord = ElasticCoordinator::new(1, 4, 30000);
        coord.add_worker("n1".into(), 1, "cuda".into());
        assert!(coord.activate_worker(0));
        assert!(!coord.activate_worker(0));
    }

    #[test]
    fn test_elastic_activate_nonexistent_worker() {
        let mut coord = ElasticCoordinator::new(1, 4, 30000);
        assert!(!coord.activate_worker(999));
    }

    #[test]
    fn test_elastic_remove_non_active_worker() {
        let mut coord = ElasticCoordinator::new(1, 4, 30000);
        coord.add_worker("n1".into(), 1, "cuda".into());
        assert!(!coord.remove_worker(0));
        // A refused removal leaves the worker untouched.
        assert_eq!(coord.all_workers()[0].state, WorkerState::Syncing);
    }

    #[test]
    fn test_elastic_remove_nonexistent_worker() {
        let mut coord = ElasticCoordinator::new(1, 4, 30000);
        assert!(!coord.remove_worker(999));
    }

    #[test]
    fn test_elastic_finalize_removal_not_draining() {
        let mut coord = ElasticCoordinator::new(1, 4, 30000);
        coord.add_worker("n1".into(), 1, "cuda".into());
        coord.activate_worker(0);
        assert!(!coord.finalize_removal(0));
        assert_eq!(coord.all_workers()[0].state, WorkerState::Active);
    }

    #[test]
    fn test_elastic_finalize_nonexistent_worker() {
        let mut coord = ElasticCoordinator::new(1, 4, 30000);
        assert!(!coord.finalize_removal(999));
    }

    #[test]
    fn test_elastic_update_heartbeat() {
        let mut coord = ElasticCoordinator::new(1, 4, 30000);
        coord.add_worker("n1".into(), 1, "cuda".into());
        coord.activate_worker(0);
        let before = coord.all_workers()[0].last_heartbeat;
        coord.update_heartbeat(0);
        assert!(coord.all_workers()[0].last_heartbeat >= before);
        // Unknown ids are ignored: no panic and no new worker record.
        coord.update_heartbeat(999);
        assert_eq!(coord.all_workers().len(), 1);
        assert_eq!(coord.all_workers()[0].state, WorkerState::Active);
    }

    #[test]
    fn test_elastic_check_heartbeats_no_timeout() {
        let mut coord = ElasticCoordinator::new(1, 4, 30000);
        coord.add_worker("n1".into(), 1, "cuda".into());
        coord.activate_worker(0);
        let failed = coord.check_heartbeats();
        assert!(failed.is_empty());
        assert_eq!(coord.active_count(), 1);
    }

    #[test]
    fn test_elastic_check_heartbeats_instant_timeout() {
        let mut coord = ElasticCoordinator::new(1, 4, 0);
        coord.add_worker("n1".into(), 1, "cuda".into());
        coord.activate_worker(0);
        coord.add_worker("n2".into(), 1, "cuda".into());
        coord.activate_worker(1);
        // With a zero timeout any measurable delay exceeds the deadline.
        std::thread::sleep(std::time::Duration::from_millis(1));
        let failed = coord.check_heartbeats();
        assert_eq!(failed.len(), 2);
        assert!(coord.needs_reconfig());
        assert_eq!(coord.active_count(), 0);
        assert!(coord.all_workers().iter().all(|w| w.state == WorkerState::Failed));
    }

    #[test]
    fn test_elastic_set_step() {
        let mut coord = ElasticCoordinator::new(1, 4, 30000);
        coord.set_step(42);
        coord.add_worker("n1".into(), 1, "cuda".into());
        assert_eq!(coord.all_workers()[0].joined_at_step, 42);
    }

    #[test]
    fn test_elastic_compute_shards_empty() {
        let coord = ElasticCoordinator::new(1, 4, 30000);
        let shards = coord.compute_shards(100);
        assert!(shards.is_empty());
    }

    #[test]
    fn test_elastic_compute_shards_single_worker() {
        let mut coord = ElasticCoordinator::new(1, 4, 30000);
        coord.add_worker("n1".into(), 1, "cuda".into());
        coord.activate_worker(0);
        let shards = coord.compute_shards(100);
        assert_eq!(shards.len(), 1);
        assert_eq!(shards[0], (0, 0, 100));
    }

    #[test]
    fn test_elastic_compute_shards_even_division() {
        let mut coord = ElasticCoordinator::new(1, 4, 30000);
        for i in 0..4 {
            coord.add_worker(format!("n{i}"), 1, "cuda".into());
            coord.activate_worker(i as u32);
        }
        let shards = coord.compute_shards(100);
        assert_eq!(shards.len(), 4);
        for (_, start, end) in &shards {
            assert_eq!(end - start, 25);
        }
    }

    #[test]
    fn test_elastic_compute_shards_zero_samples() {
        let mut coord = ElasticCoordinator::new(1, 4, 30000);
        coord.add_worker("n1".into(), 1, "cuda".into());
        coord.activate_worker(0);
        let shards = coord.compute_shards(0);
        assert_eq!(shards.len(), 1);
        assert_eq!(shards[0], (0, 0, 0));
    }

    #[test]
    fn test_elastic_all_workers() {
        let mut coord = ElasticCoordinator::new(1, 8, 30000);
        coord.add_worker("n1".into(), 2, "cuda".into());
        coord.add_worker("n2".into(), 4, "wgpu".into());
        let all = coord.all_workers();
        assert_eq!(all.len(), 2);
        assert_eq!(all[0].node_id, "n1");
        assert_eq!(all[0].gpu_count, 2);
        assert_eq!(all[0].backend, "cuda");
        assert_eq!(all[0].state, WorkerState::Syncing);
        assert_eq!(all[1].node_id, "n2");
        assert_eq!(all[1].gpu_count, 4);
    }

    #[test]
    fn test_elastic_active_worker_ids() {
        let mut coord = ElasticCoordinator::new(1, 4, 30000);
        coord.add_worker("n1".into(), 1, "cuda".into());
        coord.add_worker("n2".into(), 1, "cuda".into());
        coord.add_worker("n3".into(), 1, "cuda".into());
        coord.activate_worker(0);
        coord.activate_worker(2);
        let active = coord.active_worker_ids();
        assert_eq!(active, vec![0, 2]);
    }

    #[test]
    fn test_elastic_worker_state_transitions() {
        let mut coord = ElasticCoordinator::new(1, 4, 30000);
        coord.add_worker("n1".into(), 1, "cuda".into());
        assert_eq!(coord.all_workers()[0].state, WorkerState::Syncing);
        assert!(coord.activate_worker(0));
        assert_eq!(coord.all_workers()[0].state, WorkerState::Active);
        assert!(coord.remove_worker(0));
        assert_eq!(coord.all_workers()[0].state, WorkerState::Draining);
        assert!(coord.finalize_removal(0));
        assert_eq!(coord.all_workers()[0].state, WorkerState::Left);
    }

    #[test]
    fn test_elastic_worker_id_increments() {
        let mut coord = ElasticCoordinator::new(1, 8, 30000);
        let id0 = coord.add_worker("n1".into(), 1, "cuda".into());
        let id1 = coord.add_worker("n2".into(), 1, "cuda".into());
        let id2 = coord.add_worker("n3".into(), 1, "cuda".into());
        assert_eq!(id0, Some(0));
        assert_eq!(id1, Some(1));
        assert_eq!(id2, Some(2));
    }

    #[test]
    fn test_elastic_clear_reconfig_then_add() {
        let mut coord = ElasticCoordinator::new(1, 4, 30000);
        coord.add_worker("n1".into(), 1, "cuda".into());
        assert!(coord.needs_reconfig());
        coord.clear_reconfig();
        assert!(!coord.needs_reconfig());
        coord.add_worker("n2".into(), 1, "cuda".into());
        assert!(coord.needs_reconfig());
    }

    #[test]
    fn test_elastic_remove_sets_reconfig() {
        let mut coord = ElasticCoordinator::new(1, 4, 30000);
        coord.add_worker("n1".into(), 1, "cuda".into());
        coord.activate_worker(0);
        coord.clear_reconfig();
        assert!(!coord.needs_reconfig());
        assert!(coord.remove_worker(0));
        assert!(coord.needs_reconfig());
    }

    #[test]
    fn test_worker_state_eq() {
        assert_eq!(WorkerState::Active, WorkerState::Active);
        assert_eq!(WorkerState::Syncing, WorkerState::Syncing);
        assert_eq!(WorkerState::Draining, WorkerState::Draining);
        assert_eq!(WorkerState::Failed, WorkerState::Failed);
        assert_eq!(WorkerState::Left, WorkerState::Left);
        assert_ne!(WorkerState::Active, WorkerState::Syncing);
        assert_ne!(WorkerState::Draining, WorkerState::Failed);
    }

    #[test]
    fn test_elastic_worker_clone() {
        let mut coord = ElasticCoordinator::new(1, 4, 30000);
        coord.add_worker("n1".into(), 2, "wgpu".into());
        let worker = coord.all_workers()[0].clone();
        assert_eq!(worker.node_id, "n1");
        assert_eq!(worker.gpu_count, 2);
        assert_eq!(worker.backend, "wgpu");
    }
}