Skip to main content

entrenar/train/transformer_trainer/
elastic.rs

//! Elastic training — dynamic worker add/remove during training.
//!
//! Extends the distributed training infrastructure with the ability to:
//! - Add new workers mid-training (scale up)
//! - Gracefully remove workers (scale down)
//! - Continue training after a worker failure (fault tolerance)
//!
//! # Protocol
//!
//! ## Worker join (mid-training)
//! 1. New worker sends a `JoinRequest` carrying an `epoch_reached` field
//! 2. Coordinator pauses at the next step boundary
//! 3. Coordinator sends the current weights to the new worker
//! 4. Coordinator adjusts `world_size` and shard assignments
//! 5. Training resumes with the new worker participating
//!
//! ## Worker leave (graceful)
//! 1. Worker sends a `LeaveRequest`
//! 2. Coordinator removes the worker from the pool
//! 3. Coordinator redistributes shards
//! 4. Training continues with the remaining workers
//!
//! ## Worker failure (ungraceful)
//! 1. Coordinator detects a heartbeat timeout
//! 2. Coordinator marks the worker as failed
//! 3. Coordinator redistributes shards
//! 4. If fewer than `min_workers` workers remain active, training pauses
//! 5. Training resumes once enough workers are available
//!
//! # Contract (C-ELASTIC-001)
//!
//! - Adding/removing workers does not change model weights
//! - Data sharding is rebalanced to maintain disjointness (C-SHARD-001)
//! - All active workers hold identical weights after a rebalance
35
36use std::time::Instant;
37
/// Lifecycle state of a worker in the elastic pool.
///
/// A graceful exit follows `Syncing` → `Active` → `Draining` → `Left`.
/// A worker whose heartbeat times out becomes `Failed`.
///
/// The enum is a plain tag, so it derives `Copy` and `Hash`. States can then be
/// passed by value and used as set/map keys.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum WorkerState {
    /// Worker is actively participating in training.
    Active,
    /// Worker has been accepted but is still syncing weights.
    Syncing,
    /// Worker has been marked for removal (will leave at the next step boundary).
    Draining,
    /// Worker has failed (heartbeat timeout).
    Failed,
    /// Worker has gracefully left.
    Left,
}
52
53/// Information about a worker in the elastic pool.
54#[derive(Debug, Clone)]
55pub struct ElasticWorker {
56    /// Worker ID (assigned at join time, stable across reconfigurations)
57    pub worker_id: u32,
58    /// Node identifier
59    pub node_id: String,
60    /// Current state
61    pub state: WorkerState,
62    /// Number of GPUs on this worker
63    pub gpu_count: u32,
64    /// Backend type
65    pub backend: String,
66    /// When the worker joined
67    pub joined_at: Instant,
68    /// Step at which the worker joined (for data shard calculation)
69    pub joined_at_step: usize,
70    /// Last heartbeat time
71    pub last_heartbeat: Instant,
72}
73
74/// Elastic training coordinator.
75///
76/// Manages a pool of workers that can dynamically grow or shrink.
77/// Tracks worker state and handles reconfiguration events.
78#[derive(Debug)]
79pub struct ElasticCoordinator {
80    /// All known workers (active, syncing, draining, failed, left)
81    workers: Vec<ElasticWorker>,
82    /// Next worker ID to assign
83    next_worker_id: u32,
84    /// Minimum workers required for training (pause if below)
85    min_workers: usize,
86    /// Maximum workers allowed
87    max_workers: usize,
88    /// Current training step
89    current_step: usize,
90    /// Whether a reconfiguration is pending
91    reconfig_pending: bool,
92    /// Heartbeat timeout (milliseconds)
93    heartbeat_timeout_ms: u64,
94}
95
96impl ElasticCoordinator {
97    /// Create a new elastic coordinator.
98    pub fn new(min_workers: usize, max_workers: usize, heartbeat_timeout_ms: u64) -> Self {
99        Self {
100            workers: Vec::new(),
101            next_worker_id: 0,
102            min_workers,
103            max_workers,
104            current_step: 0,
105            reconfig_pending: false,
106            heartbeat_timeout_ms,
107        }
108    }
109
110    /// Add a new worker to the pool.
111    ///
112    /// Returns the assigned worker ID, or None if pool is full.
113    pub fn add_worker(&mut self, node_id: String, gpu_count: u32, backend: String) -> Option<u32> {
114        if self.active_count() >= self.max_workers {
115            return None;
116        }
117
118        let worker_id = self.next_worker_id;
119        self.next_worker_id += 1;
120        let now = Instant::now();
121
122        self.workers.push(ElasticWorker {
123            worker_id,
124            node_id,
125            state: WorkerState::Syncing,
126            gpu_count,
127            backend,
128            joined_at: now,
129            joined_at_step: self.current_step,
130            last_heartbeat: now,
131        });
132
133        self.reconfig_pending = true;
134        Some(worker_id)
135    }
136
137    /// Mark a worker as active (weight sync complete).
138    pub fn activate_worker(&mut self, worker_id: u32) -> bool {
139        if let Some(w) = self.workers.iter_mut().find(|w| w.worker_id == worker_id) {
140            if w.state == WorkerState::Syncing {
141                w.state = WorkerState::Active;
142                return true;
143            }
144        }
145        false
146    }
147
148    /// Request graceful removal of a worker.
149    pub fn remove_worker(&mut self, worker_id: u32) -> bool {
150        if let Some(w) = self.workers.iter_mut().find(|w| w.worker_id == worker_id) {
151            if w.state == WorkerState::Active {
152                w.state = WorkerState::Draining;
153                self.reconfig_pending = true;
154                return true;
155            }
156        }
157        false
158    }
159
160    /// Complete removal of a draining worker.
161    pub fn finalize_removal(&mut self, worker_id: u32) -> bool {
162        if let Some(w) = self.workers.iter_mut().find(|w| w.worker_id == worker_id) {
163            if w.state == WorkerState::Draining {
164                w.state = WorkerState::Left;
165                return true;
166            }
167        }
168        false
169    }
170
171    /// Check for failed workers based on heartbeat timeout.
172    ///
173    /// Returns list of worker IDs that have failed.
174    pub fn check_heartbeats(&mut self) -> Vec<u32> {
175        let now = Instant::now();
176        let timeout = std::time::Duration::from_millis(self.heartbeat_timeout_ms);
177        let mut failed = Vec::new();
178
179        for w in &mut self.workers {
180            if w.state == WorkerState::Active && now.duration_since(w.last_heartbeat) > timeout {
181                w.state = WorkerState::Failed;
182                failed.push(w.worker_id);
183                self.reconfig_pending = true;
184            }
185        }
186
187        failed
188    }
189
190    /// Update heartbeat for a worker.
191    pub fn update_heartbeat(&mut self, worker_id: u32) {
192        if let Some(w) = self.workers.iter_mut().find(|w| w.worker_id == worker_id) {
193            w.last_heartbeat = Instant::now();
194        }
195    }
196
197    /// Number of active workers.
198    pub fn active_count(&self) -> usize {
199        self.workers.iter().filter(|w| w.state == WorkerState::Active).count()
200    }
201
202    /// Whether training should be paused (below minimum workers).
203    pub fn should_pause(&self) -> bool {
204        self.active_count() < self.min_workers
205    }
206
207    /// Whether a reconfiguration is needed.
208    pub fn needs_reconfig(&self) -> bool {
209        self.reconfig_pending
210    }
211
212    /// Clear the reconfiguration flag.
213    pub fn clear_reconfig(&mut self) {
214        self.reconfig_pending = false;
215    }
216
217    /// Get list of active worker IDs.
218    pub fn active_worker_ids(&self) -> Vec<u32> {
219        self.workers
220            .iter()
221            .filter(|w| w.state == WorkerState::Active)
222            .map(|w| w.worker_id)
223            .collect()
224    }
225
226    /// Get all workers (for status display).
227    pub fn all_workers(&self) -> &[ElasticWorker] {
228        &self.workers
229    }
230
231    /// Update step counter.
232    pub fn set_step(&mut self, step: usize) {
233        self.current_step = step;
234    }
235
236    /// Get current effective world size (active workers only).
237    pub fn effective_world_size(&self) -> usize {
238        self.active_count()
239    }
240
241    /// Compute shard assignments for active workers.
242    ///
243    /// Returns (worker_id, shard_start, shard_end) for each active worker.
244    pub fn compute_shards(&self, total_samples: usize) -> Vec<(u32, usize, usize)> {
245        let active: Vec<u32> = self.active_worker_ids();
246        let n = active.len();
247        if n == 0 {
248            return Vec::new();
249        }
250
251        let shard_size = total_samples / n;
252        let remainder = total_samples % n;
253
254        active
255            .iter()
256            .enumerate()
257            .map(|(i, &wid)| {
258                let start = if i < remainder {
259                    i * (shard_size + 1)
260                } else {
261                    remainder * (shard_size + 1) + (i - remainder) * shard_size
262                };
263                let end = if i < remainder { start + shard_size + 1 } else { start + shard_size };
264                (wid, start, end)
265            })
266            .collect()
267    }
268}
269
#[cfg(test)]
mod tests {
    // Unit tests for the elastic coordinator: lifecycle transitions, capacity,
    // heartbeats, the reconfiguration flag and data sharding.
    use super::*;

    #[test]
    fn test_elastic_coordinator_basic() {
        let mut coord = ElasticCoordinator::new(1, 8, 30000);
        assert_eq!(coord.active_count(), 0);
        assert!(coord.should_pause());

        let id = coord.add_worker("node-1".into(), 1, "cuda".into());
        assert_eq!(id, Some(0));
        assert_eq!(coord.active_count(), 0); // still syncing

        assert!(coord.activate_worker(0));
        assert_eq!(coord.active_count(), 1);
        assert!(!coord.should_pause());
    }

    #[test]
    fn test_elastic_add_remove() {
        let mut coord = ElasticCoordinator::new(1, 4, 30000);

        // Add 3 workers
        assert_eq!(coord.add_worker("n1".into(), 1, "cuda".into()), Some(0));
        assert_eq!(coord.add_worker("n2".into(), 1, "cuda".into()), Some(1));
        assert_eq!(coord.add_worker("n3".into(), 2, "wgpu".into()), Some(2));
        assert!(coord.activate_worker(0));
        assert!(coord.activate_worker(1));
        assert!(coord.activate_worker(2));
        assert_eq!(coord.active_count(), 3);

        // Remove one
        assert!(coord.remove_worker(1));
        assert_eq!(coord.active_count(), 2); // draining doesn't count as active
        assert!(coord.finalize_removal(1));
        assert_eq!(coord.active_count(), 2);
        assert_eq!(coord.all_workers()[1].state, WorkerState::Left);
        assert_eq!(coord.active_worker_ids(), vec![0, 2]);
    }

    #[test]
    fn test_elastic_max_workers() {
        let mut coord = ElasticCoordinator::new(1, 2, 30000);
        coord.add_worker("n1".into(), 1, "cuda".into());
        coord.activate_worker(0);
        coord.add_worker("n2".into(), 1, "cuda".into());
        coord.activate_worker(1);

        // Pool full: the request is rejected and nothing is recorded.
        let id = coord.add_worker("n3".into(), 1, "cuda".into());
        assert_eq!(id, None);
        assert_eq!(coord.all_workers().len(), 2);
    }

    #[test]
    fn test_elastic_shard_computation() {
        let mut coord = ElasticCoordinator::new(1, 4, 30000);
        for i in 0..3 {
            coord.add_worker(format!("n{i}"), 1, "cuda".into());
            coord.activate_worker(i as u32);
        }

        let shards = coord.compute_shards(100);
        assert_eq!(shards.len(), 3);
        // Shards are assigned to active workers in join order.
        assert_eq!(shards.iter().map(|s| s.0).collect::<Vec<_>>(), vec![0, 1, 2]);

        // 100 / 3 = 33 rem 1 → first gets 34, others 33
        let (_, s0, e0) = shards[0];
        let (_, s1, e1) = shards[1];
        let (_, s2, e2) = shards[2];

        assert_eq!(s0, 0);
        assert_eq!(e0, 34);
        assert_eq!(s1, 34);
        assert_eq!(e1, 67);
        assert_eq!(s2, 67);
        assert_eq!(e2, 100);

        // Complete coverage
        assert_eq!(e0 - s0 + e1 - s1 + e2 - s2, 100);
    }

    #[test]
    fn test_elastic_shard_disjointness() {
        // C-ELASTIC-001: shards are disjoint and complete
        let mut coord = ElasticCoordinator::new(1, 8, 30000);
        for i in 0..5 {
            coord.add_worker(format!("n{i}"), 1, "cuda".into());
            coord.activate_worker(i as u32);
        }

        let total = 10007; // prime, to test remainder handling
        let shards = coord.compute_shards(total);

        let mut covered = vec![false; total];
        for (_, start, end) in &shards {
            for i in *start..*end {
                assert!(!covered[i], "sample {i} covered by multiple shards");
                covered[i] = true;
            }
        }
        assert!(covered.iter().all(|&c| c), "not all samples covered");
    }

    #[test]
    fn test_elastic_shards_cover_range_for_many_sizes() {
        // C-SHARD-001: for any worker count and sample count, shards are contiguous,
        // disjoint, cover 0..total exactly, and differ in size by at most one.
        for workers in 1..=6u32 {
            let mut coord = ElasticCoordinator::new(1, 8, 30000);
            for i in 0..workers {
                coord.add_worker(format!("n{i}"), 1, "cuda".into());
                coord.activate_worker(i);
            }

            for total in 0..=25 {
                let shards = coord.compute_shards(total);
                assert_eq!(shards.len(), workers as usize);

                let mut expected_start = 0;
                for &(_, start, end) in &shards {
                    assert_eq!(start, expected_start, "gap or overlap at sample {start}");
                    assert!(end >= start);
                    expected_start = end;
                }
                assert_eq!(expected_start, total, "shards do not end at {total}");

                let sizes: Vec<usize> = shards.iter().map(|&(_, s, e)| e - s).collect();
                let largest = *sizes.iter().max().unwrap();
                let smallest = *sizes.iter().min().unwrap();
                assert!(largest - smallest <= 1, "unbalanced shards {sizes:?}");
            }
        }
    }

    #[test]
    fn test_elastic_reconfig_flag() {
        let mut coord = ElasticCoordinator::new(1, 4, 30000);
        assert!(!coord.needs_reconfig());

        coord.add_worker("n1".into(), 1, "cuda".into());
        assert!(coord.needs_reconfig());

        coord.clear_reconfig();
        assert!(!coord.needs_reconfig());
    }

    #[test]
    fn test_elastic_should_pause() {
        let mut coord = ElasticCoordinator::new(2, 4, 30000);
        assert!(coord.should_pause()); // 0 < 2

        coord.add_worker("n1".into(), 1, "cuda".into());
        coord.activate_worker(0);
        assert!(coord.should_pause()); // 1 < 2

        coord.add_worker("n2".into(), 1, "cuda".into());
        coord.activate_worker(1);
        assert!(!coord.should_pause()); // 2 >= 2
    }

    #[test]
    fn test_elastic_effective_world_size() {
        let mut coord = ElasticCoordinator::new(1, 4, 30000);
        coord.add_worker("n1".into(), 1, "cuda".into());
        coord.add_worker("n2".into(), 1, "cuda".into());
        coord.activate_worker(0);
        coord.activate_worker(1);

        assert_eq!(coord.effective_world_size(), 2);

        coord.remove_worker(0);
        assert_eq!(coord.effective_world_size(), 1);

        // Finalizing the removal does not change the world size again.
        coord.finalize_removal(0);
        assert_eq!(coord.effective_world_size(), 1);
    }

    // ── Additional coverage tests ─────────────────────────────────

    #[test]
    fn test_elastic_activate_non_syncing_worker() {
        let mut coord = ElasticCoordinator::new(1, 4, 30000);
        coord.add_worker("n1".into(), 1, "cuda".into());
        coord.activate_worker(0);
        // Activating an already-active worker should return false
        assert!(!coord.activate_worker(0));
    }

    #[test]
    fn test_elastic_activate_nonexistent_worker() {
        let mut coord = ElasticCoordinator::new(1, 4, 30000);
        assert!(!coord.activate_worker(999));
    }

    #[test]
    fn test_elastic_remove_non_active_worker() {
        let mut coord = ElasticCoordinator::new(1, 4, 30000);
        coord.add_worker("n1".into(), 1, "cuda".into());
        // Worker is still Syncing, not Active; remove should fail
        assert!(!coord.remove_worker(0));
        assert_eq!(coord.all_workers()[0].state, WorkerState::Syncing);
    }

    #[test]
    fn test_elastic_remove_nonexistent_worker() {
        let mut coord = ElasticCoordinator::new(1, 4, 30000);
        assert!(!coord.remove_worker(999));
    }

    #[test]
    fn test_elastic_finalize_removal_not_draining() {
        let mut coord = ElasticCoordinator::new(1, 4, 30000);
        coord.add_worker("n1".into(), 1, "cuda".into());
        coord.activate_worker(0);
        // Worker is Active, not Draining; finalize should fail
        assert!(!coord.finalize_removal(0));
        assert_eq!(coord.all_workers()[0].state, WorkerState::Active);
    }

    #[test]
    fn test_elastic_finalize_nonexistent_worker() {
        let mut coord = ElasticCoordinator::new(1, 4, 30000);
        assert!(!coord.finalize_removal(999));
    }

    #[test]
    fn test_elastic_update_heartbeat() {
        let mut coord = ElasticCoordinator::new(1, 4, 30000);
        coord.add_worker("n1".into(), 1, "cuda".into());
        coord.activate_worker(0);

        let before = coord.all_workers()[0].last_heartbeat;
        coord.update_heartbeat(0);
        // `Instant` is monotonic, so a refreshed heartbeat never moves backwards.
        assert!(coord.all_workers()[0].last_heartbeat >= before);

        // Nonexistent worker: no-op
        coord.update_heartbeat(999);
        assert_eq!(coord.all_workers().len(), 1);
    }

    #[test]
    fn test_elastic_check_heartbeats_no_timeout() {
        let mut coord = ElasticCoordinator::new(1, 4, 30000);
        coord.add_worker("n1".into(), 1, "cuda".into());
        coord.activate_worker(0);
        // Immediately checking should not time out (30s timeout)
        let failed = coord.check_heartbeats();
        assert!(failed.is_empty());
        assert_eq!(coord.active_count(), 1);
    }

    #[test]
    fn test_elastic_check_heartbeats_instant_timeout() {
        // Timeout of 0ms means every active worker will immediately fail
        let mut coord = ElasticCoordinator::new(1, 4, 0);
        coord.add_worker("n1".into(), 1, "cuda".into());
        coord.activate_worker(0);
        coord.add_worker("n2".into(), 1, "cuda".into());
        coord.activate_worker(1);

        // Small delay to ensure the 0ms timeout is exceeded
        std::thread::sleep(std::time::Duration::from_millis(1));

        let failed = coord.check_heartbeats();
        assert_eq!(failed, vec![0, 1]);
        assert!(coord.needs_reconfig());
        assert_eq!(coord.active_count(), 0);
        assert!(coord.should_pause());
        assert!(coord.all_workers().iter().all(|w| w.state == WorkerState::Failed));

        // A failed worker cannot be re-activated, drained or finalized.
        assert!(!coord.activate_worker(0));
        assert!(!coord.remove_worker(0));
        assert!(!coord.finalize_removal(0));
    }

    #[test]
    fn test_elastic_set_step() {
        let mut coord = ElasticCoordinator::new(1, 4, 30000);
        coord.set_step(42);
        // `joined_at_step` records the step that was current when the worker joined.
        coord.add_worker("n1".into(), 1, "cuda".into());
        assert_eq!(coord.all_workers()[0].joined_at_step, 42);
    }

    #[test]
    fn test_elastic_compute_shards_empty() {
        let coord = ElasticCoordinator::new(1, 4, 30000);
        let shards = coord.compute_shards(100);
        assert!(shards.is_empty());
    }

    #[test]
    fn test_elastic_compute_shards_single_worker() {
        let mut coord = ElasticCoordinator::new(1, 4, 30000);
        coord.add_worker("n1".into(), 1, "cuda".into());
        coord.activate_worker(0);
        let shards = coord.compute_shards(100);
        assert_eq!(shards.len(), 1);
        assert_eq!(shards[0], (0, 0, 100));
    }

    #[test]
    fn test_elastic_compute_shards_even_division() {
        let mut coord = ElasticCoordinator::new(1, 4, 30000);
        for i in 0..4 {
            coord.add_worker(format!("n{i}"), 1, "cuda".into());
            coord.activate_worker(i as u32);
        }
        let shards = coord.compute_shards(100);
        assert_eq!(shards.len(), 4);
        // 100 / 4 = 25 each
        for (_, start, end) in &shards {
            assert_eq!(end - start, 25);
        }
    }

    #[test]
    fn test_elastic_compute_shards_zero_samples() {
        let mut coord = ElasticCoordinator::new(1, 4, 30000);
        coord.add_worker("n1".into(), 1, "cuda".into());
        coord.activate_worker(0);
        let shards = coord.compute_shards(0);
        assert_eq!(shards.len(), 1);
        assert_eq!(shards[0], (0, 0, 0));
    }

    #[test]
    fn test_elastic_all_workers() {
        let mut coord = ElasticCoordinator::new(1, 8, 30000);
        coord.add_worker("n1".into(), 2, "cuda".into());
        coord.add_worker("n2".into(), 4, "wgpu".into());

        let all = coord.all_workers();
        assert_eq!(all.len(), 2);
        assert_eq!(all[0].node_id, "n1");
        assert_eq!(all[0].gpu_count, 2);
        assert_eq!(all[0].backend, "cuda");
        assert_eq!(all[0].state, WorkerState::Syncing);
        assert_eq!(all[1].node_id, "n2");
        assert_eq!(all[1].gpu_count, 4);
    }

    #[test]
    fn test_elastic_active_worker_ids() {
        let mut coord = ElasticCoordinator::new(1, 4, 30000);
        coord.add_worker("n1".into(), 1, "cuda".into());
        coord.add_worker("n2".into(), 1, "cuda".into());
        coord.add_worker("n3".into(), 1, "cuda".into());
        coord.activate_worker(0);
        coord.activate_worker(2);
        // Worker 1 still syncing

        let active = coord.active_worker_ids();
        assert_eq!(active, vec![0, 2]);
    }

    #[test]
    fn test_elastic_worker_state_transitions() {
        let mut coord = ElasticCoordinator::new(1, 4, 30000);
        coord.add_worker("n1".into(), 1, "cuda".into());

        // Syncing -> Active
        assert_eq!(coord.all_workers()[0].state, WorkerState::Syncing);
        assert!(coord.activate_worker(0));
        assert_eq!(coord.all_workers()[0].state, WorkerState::Active);

        // Active -> Draining
        assert!(coord.remove_worker(0));
        assert_eq!(coord.all_workers()[0].state, WorkerState::Draining);

        // Draining -> Left
        assert!(coord.finalize_removal(0));
        assert_eq!(coord.all_workers()[0].state, WorkerState::Left);
    }

    #[test]
    fn test_elastic_worker_id_increments() {
        let mut coord = ElasticCoordinator::new(1, 8, 30000);
        let id0 = coord.add_worker("n1".into(), 1, "cuda".into());
        let id1 = coord.add_worker("n2".into(), 1, "cuda".into());
        let id2 = coord.add_worker("n3".into(), 1, "cuda".into());
        assert_eq!(id0, Some(0));
        assert_eq!(id1, Some(1));
        assert_eq!(id2, Some(2));
    }

    #[test]
    fn test_elastic_clear_reconfig_then_add() {
        let mut coord = ElasticCoordinator::new(1, 4, 30000);
        coord.add_worker("n1".into(), 1, "cuda".into());
        assert!(coord.needs_reconfig());
        coord.clear_reconfig();
        assert!(!coord.needs_reconfig());

        // Adding another worker sets reconfig again
        coord.add_worker("n2".into(), 1, "cuda".into());
        assert!(coord.needs_reconfig());
    }

    #[test]
    fn test_elastic_remove_sets_reconfig() {
        let mut coord = ElasticCoordinator::new(1, 4, 30000);
        coord.add_worker("n1".into(), 1, "cuda".into());
        coord.activate_worker(0);
        coord.clear_reconfig();
        assert!(!coord.needs_reconfig());

        coord.remove_worker(0);
        assert!(coord.needs_reconfig());
    }

    #[test]
    fn test_worker_state_eq() {
        assert_eq!(WorkerState::Active, WorkerState::Active);
        assert_eq!(WorkerState::Syncing, WorkerState::Syncing);
        assert_eq!(WorkerState::Draining, WorkerState::Draining);
        assert_eq!(WorkerState::Failed, WorkerState::Failed);
        assert_eq!(WorkerState::Left, WorkerState::Left);
        assert_ne!(WorkerState::Active, WorkerState::Syncing);
        assert_ne!(WorkerState::Draining, WorkerState::Failed);
    }

    #[test]
    fn test_elastic_worker_clone() {
        let mut coord = ElasticCoordinator::new(1, 4, 30000);
        coord.add_worker("n1".into(), 2, "wgpu".into());
        let worker = coord.all_workers()[0].clone();
        assert_eq!(worker.node_id, "n1");
        assert_eq!(worker.gpu_count, 2);
        assert_eq!(worker.backend, "wgpu");
    }
}