1use std::time::Instant;
37
/// Lifecycle state of a worker tracked by `ElasticCoordinator`.
///
/// The coordinator moves workers along `Syncing -> Active -> Draining -> Left`,
/// and from `Active` to `Failed` when a heartbeat times out.
///
/// The enum is a plain tag with no payload, so it is `Copy` and `Hash` in
/// addition to being comparable.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum WorkerState {
    /// Counted as part of the active set (world size, sharding, heartbeat checks).
    Active,
    /// State given to a newly added worker until it is activated.
    Syncing,
    /// Set by `remove_worker` on an active worker; cleared by `finalize_removal`.
    Draining,
    /// Set by `check_heartbeats` when an active worker exceeds the heartbeat timeout.
    Failed,
    /// Terminal state set by `finalize_removal` on a draining worker.
    Left,
}
52
53#[derive(Debug, Clone)]
55pub struct ElasticWorker {
56 pub worker_id: u32,
58 pub node_id: String,
60 pub state: WorkerState,
62 pub gpu_count: u32,
64 pub backend: String,
66 pub joined_at: Instant,
68 pub joined_at_step: usize,
70 pub last_heartbeat: Instant,
72}
73
74#[derive(Debug)]
79pub struct ElasticCoordinator {
80 workers: Vec<ElasticWorker>,
82 next_worker_id: u32,
84 min_workers: usize,
86 max_workers: usize,
88 current_step: usize,
90 reconfig_pending: bool,
92 heartbeat_timeout_ms: u64,
94}
95
96impl ElasticCoordinator {
97 pub fn new(min_workers: usize, max_workers: usize, heartbeat_timeout_ms: u64) -> Self {
99 Self {
100 workers: Vec::new(),
101 next_worker_id: 0,
102 min_workers,
103 max_workers,
104 current_step: 0,
105 reconfig_pending: false,
106 heartbeat_timeout_ms,
107 }
108 }
109
110 pub fn add_worker(&mut self, node_id: String, gpu_count: u32, backend: String) -> Option<u32> {
114 if self.active_count() >= self.max_workers {
115 return None;
116 }
117
118 let worker_id = self.next_worker_id;
119 self.next_worker_id += 1;
120 let now = Instant::now();
121
122 self.workers.push(ElasticWorker {
123 worker_id,
124 node_id,
125 state: WorkerState::Syncing,
126 gpu_count,
127 backend,
128 joined_at: now,
129 joined_at_step: self.current_step,
130 last_heartbeat: now,
131 });
132
133 self.reconfig_pending = true;
134 Some(worker_id)
135 }
136
137 pub fn activate_worker(&mut self, worker_id: u32) -> bool {
139 if let Some(w) = self.workers.iter_mut().find(|w| w.worker_id == worker_id) {
140 if w.state == WorkerState::Syncing {
141 w.state = WorkerState::Active;
142 return true;
143 }
144 }
145 false
146 }
147
148 pub fn remove_worker(&mut self, worker_id: u32) -> bool {
150 if let Some(w) = self.workers.iter_mut().find(|w| w.worker_id == worker_id) {
151 if w.state == WorkerState::Active {
152 w.state = WorkerState::Draining;
153 self.reconfig_pending = true;
154 return true;
155 }
156 }
157 false
158 }
159
160 pub fn finalize_removal(&mut self, worker_id: u32) -> bool {
162 if let Some(w) = self.workers.iter_mut().find(|w| w.worker_id == worker_id) {
163 if w.state == WorkerState::Draining {
164 w.state = WorkerState::Left;
165 return true;
166 }
167 }
168 false
169 }
170
171 pub fn check_heartbeats(&mut self) -> Vec<u32> {
175 let now = Instant::now();
176 let timeout = std::time::Duration::from_millis(self.heartbeat_timeout_ms);
177 let mut failed = Vec::new();
178
179 for w in &mut self.workers {
180 if w.state == WorkerState::Active && now.duration_since(w.last_heartbeat) > timeout {
181 w.state = WorkerState::Failed;
182 failed.push(w.worker_id);
183 self.reconfig_pending = true;
184 }
185 }
186
187 failed
188 }
189
190 pub fn update_heartbeat(&mut self, worker_id: u32) {
192 if let Some(w) = self.workers.iter_mut().find(|w| w.worker_id == worker_id) {
193 w.last_heartbeat = Instant::now();
194 }
195 }
196
197 pub fn active_count(&self) -> usize {
199 self.workers.iter().filter(|w| w.state == WorkerState::Active).count()
200 }
201
202 pub fn should_pause(&self) -> bool {
204 self.active_count() < self.min_workers
205 }
206
207 pub fn needs_reconfig(&self) -> bool {
209 self.reconfig_pending
210 }
211
212 pub fn clear_reconfig(&mut self) {
214 self.reconfig_pending = false;
215 }
216
217 pub fn active_worker_ids(&self) -> Vec<u32> {
219 self.workers
220 .iter()
221 .filter(|w| w.state == WorkerState::Active)
222 .map(|w| w.worker_id)
223 .collect()
224 }
225
226 pub fn all_workers(&self) -> &[ElasticWorker] {
228 &self.workers
229 }
230
231 pub fn set_step(&mut self, step: usize) {
233 self.current_step = step;
234 }
235
236 pub fn effective_world_size(&self) -> usize {
238 self.active_count()
239 }
240
241 pub fn compute_shards(&self, total_samples: usize) -> Vec<(u32, usize, usize)> {
245 let active: Vec<u32> = self.active_worker_ids();
246 let n = active.len();
247 if n == 0 {
248 return Vec::new();
249 }
250
251 let shard_size = total_samples / n;
252 let remainder = total_samples % n;
253
254 active
255 .iter()
256 .enumerate()
257 .map(|(i, &wid)| {
258 let start = if i < remainder {
259 i * (shard_size + 1)
260 } else {
261 remainder * (shard_size + 1) + (i - remainder) * shard_size
262 };
263 let end = if i < remainder { start + shard_size + 1 } else { start + shard_size };
264 (wid, start, end)
265 })
266 .collect()
267 }
268}
269
/// Unit tests for `ElasticCoordinator`, `ElasticWorker` and `WorkerState`.
#[cfg(test)]
mod tests {
    use super::*;

    /// A fresh coordinator is paused until its first worker is activated.
    #[test]
    fn test_elastic_coordinator_basic() {
        let mut coord = ElasticCoordinator::new(1, 8, 30000);
        assert_eq!(coord.active_count(), 0);
        assert!(coord.should_pause());

        let id = coord.add_worker("node-1".into(), 1, "cuda".into());
        assert_eq!(id, Some(0));
        // Newly added workers are syncing, not active.
        assert_eq!(coord.active_count(), 0);

        assert!(coord.activate_worker(0));
        assert_eq!(coord.active_count(), 1);
        assert!(!coord.should_pause());
    }

    /// Draining a worker removes it from the active set immediately.
    #[test]
    fn test_elastic_add_remove() {
        let mut coord = ElasticCoordinator::new(1, 4, 30000);

        coord.add_worker("n1".into(), 1, "cuda".into());
        coord.add_worker("n2".into(), 1, "cuda".into());
        coord.add_worker("n3".into(), 2, "wgpu".into());
        coord.activate_worker(0);
        coord.activate_worker(1);
        coord.activate_worker(2);
        assert_eq!(coord.active_count(), 3);

        coord.remove_worker(1);
        assert_eq!(coord.active_count(), 2);

        // Finalizing does not change the active count any further.
        coord.finalize_removal(1);
        assert_eq!(coord.active_count(), 2);
    }

    /// Admission is refused once `max_workers` workers are active.
    #[test]
    fn test_elastic_max_workers() {
        let mut coord = ElasticCoordinator::new(1, 2, 30000);
        coord.add_worker("n1".into(), 1, "cuda".into());
        coord.activate_worker(0);
        coord.add_worker("n2".into(), 1, "cuda".into());
        coord.activate_worker(1);

        let id = coord.add_worker("n3".into(), 1, "cuda".into());
        assert_eq!(id, None);
    }

    /// 100 samples over 3 workers: the first worker gets the extra sample.
    #[test]
    fn test_elastic_shard_computation() {
        let mut coord = ElasticCoordinator::new(1, 4, 30000);
        for i in 0..3u32 {
            coord.add_worker(format!("n{i}"), 1, "cuda".into());
            coord.activate_worker(i);
        }

        let shards = coord.compute_shards(100);
        assert_eq!(shards.len(), 3);

        let (_, s0, e0) = shards[0];
        let (_, s1, e1) = shards[1];
        let (_, s2, e2) = shards[2];

        assert_eq!(s0, 0);
        assert_eq!(e0, 34);
        assert_eq!(s1, 34);
        assert_eq!(e1, 67);
        assert_eq!(s2, 67);
        assert_eq!(e2, 100);

        // The shard lengths add up to the full data set.
        assert_eq!(e0 - s0 + e1 - s1 + e2 - s2, 100);
    }

    /// Shards cover every sample exactly once, even for an awkward total.
    #[test]
    fn test_elastic_shard_disjointness() {
        let mut coord = ElasticCoordinator::new(1, 8, 30000);
        for i in 0..5u32 {
            coord.add_worker(format!("n{i}"), 1, "cuda".into());
            coord.activate_worker(i);
        }

        let total = 10007;
        let shards = coord.compute_shards(total);

        let mut covered = vec![false; total];
        for (_, start, end) in &shards {
            for i in *start..*end {
                assert!(!covered[i], "sample {i} covered by multiple shards");
                covered[i] = true;
            }
        }
        assert!(covered.iter().all(|&c| c), "not all samples covered");
    }

    /// Adding a worker raises the reconfig flag; clearing lowers it.
    #[test]
    fn test_elastic_reconfig_flag() {
        let mut coord = ElasticCoordinator::new(1, 4, 30000);
        assert!(!coord.needs_reconfig());

        coord.add_worker("n1".into(), 1, "cuda".into());
        assert!(coord.needs_reconfig());

        coord.clear_reconfig();
        assert!(!coord.needs_reconfig());
    }

    /// With `min_workers == 2` the coordinator stays paused until two workers are active.
    #[test]
    fn test_elastic_should_pause() {
        let mut coord = ElasticCoordinator::new(2, 4, 30000);
        assert!(coord.should_pause());

        coord.add_worker("n1".into(), 1, "cuda".into());
        coord.activate_worker(0);
        assert!(coord.should_pause());

        coord.add_worker("n2".into(), 1, "cuda".into());
        coord.activate_worker(1);
        assert!(!coord.should_pause());
    }

    /// The effective world size follows the active count.
    #[test]
    fn test_elastic_effective_world_size() {
        let mut coord = ElasticCoordinator::new(1, 4, 30000);
        coord.add_worker("n1".into(), 1, "cuda".into());
        coord.add_worker("n2".into(), 1, "cuda".into());
        coord.activate_worker(0);
        coord.activate_worker(1);

        assert_eq!(coord.effective_world_size(), 2);

        coord.remove_worker(0);
        assert_eq!(coord.effective_world_size(), 1);
    }

    /// Activating an already active worker is rejected.
    #[test]
    fn test_elastic_activate_non_syncing_worker() {
        let mut coord = ElasticCoordinator::new(1, 4, 30000);
        coord.add_worker("n1".into(), 1, "cuda".into());
        assert!(coord.activate_worker(0));
        assert!(!coord.activate_worker(0));
    }

    #[test]
    fn test_elastic_activate_nonexistent_worker() {
        let mut coord = ElasticCoordinator::new(1, 4, 30000);
        assert!(!coord.activate_worker(999));
    }

    /// A worker that is still syncing cannot be removed.
    #[test]
    fn test_elastic_remove_non_active_worker() {
        let mut coord = ElasticCoordinator::new(1, 4, 30000);
        coord.add_worker("n1".into(), 1, "cuda".into());
        assert!(!coord.remove_worker(0));
    }

    #[test]
    fn test_elastic_remove_nonexistent_worker() {
        let mut coord = ElasticCoordinator::new(1, 4, 30000);
        assert!(!coord.remove_worker(999));
    }

    /// Only draining workers can be finalized.
    #[test]
    fn test_elastic_finalize_removal_not_draining() {
        let mut coord = ElasticCoordinator::new(1, 4, 30000);
        coord.add_worker("n1".into(), 1, "cuda".into());
        coord.activate_worker(0);
        assert!(!coord.finalize_removal(0));
    }

    #[test]
    fn test_elastic_finalize_nonexistent_worker() {
        let mut coord = ElasticCoordinator::new(1, 4, 30000);
        assert!(!coord.finalize_removal(999));
    }

    /// A heartbeat never moves the timestamp backwards, and an unknown id is
    /// ignored without touching existing workers.
    #[test]
    fn test_elastic_update_heartbeat() {
        let mut coord = ElasticCoordinator::new(1, 4, 30000);
        coord.add_worker("n1".into(), 1, "cuda".into());
        coord.activate_worker(0);

        let before = coord.all_workers()[0].last_heartbeat;
        coord.update_heartbeat(0);
        let after = coord.all_workers()[0].last_heartbeat;
        assert!(after >= before);

        coord.update_heartbeat(999);
        assert_eq!(coord.all_workers().len(), 1);
        assert_eq!(coord.all_workers()[0].last_heartbeat, after);
    }

    /// With a 30 s timeout a freshly added worker cannot have timed out.
    #[test]
    fn test_elastic_check_heartbeats_no_timeout() {
        let mut coord = ElasticCoordinator::new(1, 4, 30000);
        coord.add_worker("n1".into(), 1, "cuda".into());
        coord.activate_worker(0);
        let failed = coord.check_heartbeats();
        assert!(failed.is_empty());
    }

    /// With a zero timeout any measurable silence fails every active worker.
    #[test]
    fn test_elastic_check_heartbeats_instant_timeout() {
        let mut coord = ElasticCoordinator::new(1, 4, 0);
        coord.add_worker("n1".into(), 1, "cuda".into());
        coord.activate_worker(0);
        coord.add_worker("n2".into(), 1, "cuda".into());
        coord.activate_worker(1);

        // The comparison is strict, so let a non-zero amount of time pass.
        std::thread::sleep(std::time::Duration::from_millis(1));

        let failed = coord.check_heartbeats();
        assert_eq!(failed.len(), 2);
        assert!(coord.needs_reconfig());
        assert_eq!(coord.active_count(), 0);
    }

    /// Workers record the step that was current when they joined.
    #[test]
    fn test_elastic_set_step() {
        let mut coord = ElasticCoordinator::new(1, 4, 30000);
        coord.set_step(42);
        coord.add_worker("n1".into(), 1, "cuda".into());
        assert_eq!(coord.all_workers()[0].joined_at_step, 42);
    }

    #[test]
    fn test_elastic_compute_shards_empty() {
        let coord = ElasticCoordinator::new(1, 4, 30000);
        let shards = coord.compute_shards(100);
        assert!(shards.is_empty());
    }

    #[test]
    fn test_elastic_compute_shards_single_worker() {
        let mut coord = ElasticCoordinator::new(1, 4, 30000);
        coord.add_worker("n1".into(), 1, "cuda".into());
        coord.activate_worker(0);
        let shards = coord.compute_shards(100);
        assert_eq!(shards.len(), 1);
        assert_eq!(shards[0], (0, 0, 100));
    }

    #[test]
    fn test_elastic_compute_shards_even_division() {
        let mut coord = ElasticCoordinator::new(1, 4, 30000);
        for i in 0..4u32 {
            coord.add_worker(format!("n{i}"), 1, "cuda".into());
            coord.activate_worker(i);
        }
        let shards = coord.compute_shards(100);
        assert_eq!(shards.len(), 4);
        for (_, start, end) in &shards {
            assert_eq!(end - start, 25);
        }
    }

    #[test]
    fn test_elastic_compute_shards_zero_samples() {
        let mut coord = ElasticCoordinator::new(1, 4, 30000);
        coord.add_worker("n1".into(), 1, "cuda".into());
        coord.activate_worker(0);
        let shards = coord.compute_shards(0);
        assert_eq!(shards.len(), 1);
        assert_eq!(shards[0], (0, 0, 0));
    }

    /// `all_workers` exposes every record, including workers still syncing.
    #[test]
    fn test_elastic_all_workers() {
        let mut coord = ElasticCoordinator::new(1, 8, 30000);
        coord.add_worker("n1".into(), 2, "cuda".into());
        coord.add_worker("n2".into(), 4, "wgpu".into());

        let all = coord.all_workers();
        assert_eq!(all.len(), 2);
        assert_eq!(all[0].node_id, "n1");
        assert_eq!(all[0].gpu_count, 2);
        assert_eq!(all[0].backend, "cuda");
        assert_eq!(all[0].state, WorkerState::Syncing);
        assert_eq!(all[1].node_id, "n2");
        assert_eq!(all[1].gpu_count, 4);
    }

    /// Only activated workers are listed; worker 1 is left syncing.
    #[test]
    fn test_elastic_active_worker_ids() {
        let mut coord = ElasticCoordinator::new(1, 4, 30000);
        coord.add_worker("n1".into(), 1, "cuda".into());
        coord.add_worker("n2".into(), 1, "cuda".into());
        coord.add_worker("n3".into(), 1, "cuda".into());
        coord.activate_worker(0);
        coord.activate_worker(2);
        let active = coord.active_worker_ids();
        assert_eq!(active, vec![0, 2]);
    }

    /// Walks one worker through Syncing -> Active -> Draining -> Left.
    #[test]
    fn test_elastic_worker_state_transitions() {
        let mut coord = ElasticCoordinator::new(1, 4, 30000);
        coord.add_worker("n1".into(), 1, "cuda".into());

        assert_eq!(coord.all_workers()[0].state, WorkerState::Syncing);

        coord.activate_worker(0);
        assert_eq!(coord.all_workers()[0].state, WorkerState::Active);

        coord.remove_worker(0);
        assert_eq!(coord.all_workers()[0].state, WorkerState::Draining);

        coord.finalize_removal(0);
        assert_eq!(coord.all_workers()[0].state, WorkerState::Left);
    }

    #[test]
    fn test_elastic_worker_id_increments() {
        let mut coord = ElasticCoordinator::new(1, 8, 30000);
        let id0 = coord.add_worker("n1".into(), 1, "cuda".into());
        let id1 = coord.add_worker("n2".into(), 1, "cuda".into());
        let id2 = coord.add_worker("n3".into(), 1, "cuda".into());
        assert_eq!(id0, Some(0));
        assert_eq!(id1, Some(1));
        assert_eq!(id2, Some(2));
    }

    /// The reconfig flag is raised again by a later addition.
    #[test]
    fn test_elastic_clear_reconfig_then_add() {
        let mut coord = ElasticCoordinator::new(1, 4, 30000);
        coord.add_worker("n1".into(), 1, "cuda".into());
        assert!(coord.needs_reconfig());
        coord.clear_reconfig();
        assert!(!coord.needs_reconfig());

        coord.add_worker("n2".into(), 1, "cuda".into());
        assert!(coord.needs_reconfig());
    }

    #[test]
    fn test_elastic_remove_sets_reconfig() {
        let mut coord = ElasticCoordinator::new(1, 4, 30000);
        coord.add_worker("n1".into(), 1, "cuda".into());
        coord.activate_worker(0);
        coord.clear_reconfig();
        assert!(!coord.needs_reconfig());

        coord.remove_worker(0);
        assert!(coord.needs_reconfig());
    }

    #[test]
    fn test_worker_state_eq() {
        assert_eq!(WorkerState::Active, WorkerState::Active);
        assert_eq!(WorkerState::Syncing, WorkerState::Syncing);
        assert_eq!(WorkerState::Draining, WorkerState::Draining);
        assert_eq!(WorkerState::Failed, WorkerState::Failed);
        assert_eq!(WorkerState::Left, WorkerState::Left);
        assert_ne!(WorkerState::Active, WorkerState::Syncing);
        assert_ne!(WorkerState::Draining, WorkerState::Failed);
    }

    #[test]
    fn test_elastic_worker_clone() {
        let mut coord = ElasticCoordinator::new(1, 4, 30000);
        coord.add_worker("n1".into(), 2, "wgpu".into());
        let worker = coord.all_workers()[0].clone();
        assert_eq!(worker.node_id, "n1");
        assert_eq!(worker.gpu_count, 2);
        assert_eq!(worker.backend, "wgpu");
    }
}