
kizzasi_model/distributed.rs

//! Distributed Training Support for kizzasi-model
//!
//! Provides gradient synchronization primitives for single-node and multi-threaded
//! distributed training simulation. The design follows an extensible trait-based
//! architecture so that real network-based all-reduce can be plugged in later.
//!
//! # Architecture
//!
//! - [`GradientSync`]: Core trait for gradient synchronization strategies.
//! - [`LocalGradientSync`]: No-op implementation for single-node training.
//! - [`ThreadedGradientSync`]: `Arc<Mutex>`-based all-reduce for multi-threaded simulation.
//! - [`run_parallel_workers`]: Helper to run a closure per worker in parallel threads.
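//!
//! # Example
//!
//! A minimal sketch of one all-reduce round across two simulated workers,
//! assuming this module is reachable as `kizzasi_model::distributed`:
//!
//! ```rust,ignore
//! use kizzasi_model::distributed::{run_parallel_workers, GradientSync};
//! use scirs2_core::ndarray::Array1;
//!
//! // Each worker deposits its local gradients; on return every worker holds the mean.
//! let results = run_parallel_workers(2, |sync| {
//!     let mut grad = Array1::from_vec(vec![(sync.worker_id() + 1) as f32; 4]);
//!     sync.sync_gradients(&mut grad).expect("sync");
//!     grad
//! });
//! assert_eq!(results.len(), 2);
//! ```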

use crate::error::{ModelError, ModelResult};
use scirs2_core::ndarray::Array1;
use std::sync::{Arc, Condvar, Mutex};

// ---------------------------------------------------------------------------
// GradientSync trait
// ---------------------------------------------------------------------------

/// Trait for gradient synchronization strategies.
///
/// Implementations are responsible for aggregating gradients across workers
/// (e.g., averaging in all-reduce) and writing the result back in-place.
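///
/// # Example
///
/// A minimal sketch of a training step written against the trait, so the same
/// code works with [`LocalGradientSync`] or [`ThreadedGradientSync`]
/// (`apply_step` is a hypothetical helper; the parameter update itself is elided):
///
/// ```rust,ignore
/// fn apply_step<S: GradientSync>(sync: &S, grads: &mut Array1<f32>) -> ModelResult<()> {
///     // Aggregate across workers (a no-op for LocalGradientSync).
///     sync.sync_gradients(grads)?;
///     // ...use the synchronized gradients to update the parameters...
///     Ok(())
/// }
/// ```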
pub trait GradientSync: Send {
    /// Synchronize (aggregate) gradients across all workers.
    ///
    /// On return `gradients` holds the post-synchronization values.
    fn sync_gradients(&self, gradients: &mut Array1<f32>) -> ModelResult<()>;

    /// Returns `true` if this sync implementation involves multiple workers.
    fn is_distributed(&self) -> bool {
        false
    }

    /// Number of workers participating in synchronization.
    fn num_workers(&self) -> usize {
        1
    }
}

// ---------------------------------------------------------------------------
// LocalGradientSync — no-op for single-node training
// ---------------------------------------------------------------------------

/// No-op gradient sync for single-node / single-threaded training.
///
/// `sync_gradients` is a pure identity operation; it leaves the gradient
/// array untouched and never allocates.
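///
/// # Example
///
/// A minimal sketch; the call succeeds and leaves the gradients unchanged:
///
/// ```rust,ignore
/// let sync = LocalGradientSync::new();
/// let mut grad = Array1::from_vec(vec![1.0_f32, 2.0]);
/// sync.sync_gradients(&mut grad)?;
/// assert_eq!(grad.to_vec(), vec![1.0_f32, 2.0]);
/// ```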
#[derive(Debug, Clone, Default)]
pub struct LocalGradientSync;

impl LocalGradientSync {
    /// Create a new `LocalGradientSync`.
    pub fn new() -> Self {
        Self
    }
}

impl GradientSync for LocalGradientSync {
    #[inline]
    fn sync_gradients(&self, _gradients: &mut Array1<f32>) -> ModelResult<()> {
        // Single-node: nothing to do.
        Ok(())
    }

    fn is_distributed(&self) -> bool {
        false
    }

    fn num_workers(&self) -> usize {
        1
    }
}

// ---------------------------------------------------------------------------
// ThreadedGradientSync — barrier + all-reduce over Arc<Mutex<>>
// ---------------------------------------------------------------------------

/// Mutable barrier state for a group of [`ThreadedGradientSync`] workers.
///
/// All workers in the same group share a single `SharedState`, which wraps
/// this struct in a `Mutex` and pairs it with two `Condvar`s so that
/// accumulation, averaging, read-back, and reset all happen under coordinated
/// locking with no races.
#[derive(Debug)]
struct BarrierState {
    /// Accumulated gradient sum; `None` before the first worker deposits.
    accumulator: Option<Vec<f32>>,
    /// Averaged result available for all workers to read back.
    result: Option<Vec<f32>>,
    /// How many workers have deposited their gradients this round.
    arrived: usize,
    /// How many workers have finished reading back the result.
    departed: usize,
    /// Generation counter — incremented when the averaging is done so waiters
    /// can distinguish this round from the next.
    generation: usize,
}

impl BarrierState {
    fn new() -> Self {
        Self {
            accumulator: None,
            result: None,
            arrived: 0,
            departed: 0,
            generation: 0,
        }
    }
}

/// Per-group shared barrier: the mutex-protected [`BarrierState`] plus the
/// two condition variables used for the arrival and departure phases.
#[derive(Debug)]
struct SharedState {
    inner: Mutex<BarrierState>,
    all_arrived: Condvar,
    all_departed: Condvar,
    num_workers: usize,
}

impl SharedState {
    fn new(num_workers: usize) -> Self {
        Self {
            inner: Mutex::new(BarrierState::new()),
            all_arrived: Condvar::new(),
            all_departed: Condvar::new(),
            num_workers,
        }
    }
}

/// All-reduce gradient synchronizer backed by `Arc<Mutex<>>` for multi-threaded
/// training simulation within a single process.
///
/// All workers that share the same underlying `SharedState` barrier must call
/// [`GradientSync::sync_gradients`] with arrays of the same length; otherwise an
/// error is returned. The synchronization algorithm is:
///
/// 1. Worker adds its gradients into the shared accumulator.
/// 2. The last arriving worker computes the element-wise mean, stores it as the
///    result, and signals all waiters.
/// 3. All workers copy the averaged result back into their local gradient buffer.
/// 4. The last departing worker resets state for the next round.
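///
/// # Example
///
/// A sketch of two workers sharing one barrier, each on its own thread
/// (see also [`run_parallel_workers`], which wraps this pattern):
///
/// ```rust,ignore
/// let syncs = ThreadedGradientSync::new_workers(2);
/// let handles: Vec<_> = syncs
///     .into_iter()
///     .map(|sync| {
///         std::thread::spawn(move || {
///             // Worker 0 contributes [1.0, 1.0], worker 1 contributes [2.0, 2.0].
///             let mut grad = Array1::from_vec(vec![sync.worker_id() as f32 + 1.0; 2]);
///             sync.sync_gradients(&mut grad).expect("sync");
///             grad // both workers now hold the average [1.5, 1.5]
///         })
///     })
///     .collect();
/// let results: Vec<_> = handles.into_iter().map(|h| h.join().unwrap()).collect();
/// ```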
#[derive(Debug, Clone)]
pub struct ThreadedGradientSync {
    shared: Arc<SharedState>,
    worker_id: usize,
}

impl ThreadedGradientSync {
    /// Create `num_workers` sync objects that share the same barrier state.
    ///
    /// # Panics
    ///
    /// Panics if `num_workers == 0`.
    pub fn new_workers(num_workers: usize) -> Vec<Self> {
        assert!(num_workers > 0, "num_workers must be at least 1");
        let shared = Arc::new(SharedState::new(num_workers));
        (0..num_workers)
            .map(|id| Self {
                shared: Arc::clone(&shared),
                worker_id: id,
            })
            .collect()
    }

    /// Return the worker index (0-based) for this instance.
    pub fn worker_id(&self) -> usize {
        self.worker_id
    }
}

impl GradientSync for ThreadedGradientSync {
    fn sync_gradients(&self, gradients: &mut Array1<f32>) -> ModelResult<()> {
        let n = gradients.len();
        let num_workers = self.shared.num_workers;

        // ----------------------------------------------------------------
        // Phase 1: deposit gradients into the shared accumulator.
        // ----------------------------------------------------------------
        {
            let mut state =
                self.shared.inner.lock().map_err(|_| {
                    ModelError::load_error("gradient sync", "barrier mutex poisoned")
                })?;

            match state.accumulator.as_mut() {
                None => {
                    state.accumulator = Some(gradients.iter().copied().collect());
                }
                Some(acc) => {
                    if acc.len() != n {
                        return Err(ModelError::dimension_mismatch(
                            "gradient sync",
                            acc.len(),
                            n,
                        ));
                    }
                    for (a, &g) in acc.iter_mut().zip(gradients.iter()) {
                        *a += g;
                    }
                }
            }
            state.arrived += 1;
        }

        // ----------------------------------------------------------------
        // Phase 2: barrier — wait until all workers have deposited; the
        // last worker computes the mean and signals everyone.
        // ----------------------------------------------------------------
        {
            let mut state =
                self.shared.inner.lock().map_err(|_| {
                    ModelError::load_error("gradient sync", "barrier mutex poisoned")
                })?;

            if state.arrived == num_workers {
                // All workers have deposited: the first worker to get here
                // takes the accumulator and publishes the averaged result.
                if let Some(acc) = state.accumulator.take() {
                    let scale = 1.0 / num_workers as f32;
                    state.result = Some(acc.iter().map(|&x| x * scale).collect());
                }
                state.generation = state.generation.wrapping_add(1);
                self.shared.all_arrived.notify_all();
            } else {
                let gen_before = state.generation;
                // Release lock and wait.
                let state = self
                    .shared
                    .all_arrived
                    .wait_while(state, |s| s.generation == gen_before)
                    .map_err(|_| {
                        ModelError::load_error("gradient sync", "condvar wait failed (arrived)")
                    })?;
                // `wait_while` returns the reacquired guard; drop it explicitly
                // so the lock is released before Phase 3.
                drop(state);
            }
        }

        // ----------------------------------------------------------------
        // Phase 3: read back the averaged result (result is now published).
        // ----------------------------------------------------------------
        {
            let state =
                self.shared.inner.lock().map_err(|_| {
                    ModelError::load_error("gradient sync", "barrier mutex poisoned")
                })?;
            if let Some(result) = state.result.as_ref() {
                for (g, &r) in gradients.iter_mut().zip(result.iter()) {
                    *g = r;
                }
            }
        }

        // ----------------------------------------------------------------
        // Phase 4: depart barrier — the last departing worker resets state
        // so the next round can begin. Earlier departing workers wait until
        // reset is complete to prevent fast workers from lapping.
        // ----------------------------------------------------------------
        let should_wait;
        {
            let mut state =
                self.shared.inner.lock().map_err(|_| {
                    ModelError::load_error("gradient sync", "barrier mutex poisoned")
                })?;

            state.departed += 1;
            if state.departed == num_workers {
                state.accumulator = None;
                state.result = None;
                state.arrived = 0;
                state.departed = 0;
                self.shared.all_departed.notify_all();
                should_wait = false;
            } else {
                should_wait = true;
            }
        }

        if should_wait {
            let state =
                self.shared.inner.lock().map_err(|_| {
                    ModelError::load_error("gradient sync", "barrier mutex poisoned")
                })?;
            let _guard = self
                .shared
                .all_departed
                .wait_while(state, |s| s.departed != 0)
                .map_err(|_| {
                    ModelError::load_error("gradient sync", "condvar wait failed (departed)")
                })?;
        }

        Ok(())
    }

    fn is_distributed(&self) -> bool {
        true
    }

    fn num_workers(&self) -> usize {
        self.shared.num_workers
    }
}

// ---------------------------------------------------------------------------
// run_parallel_workers helper
// ---------------------------------------------------------------------------

/// Run a closure on each of `num_workers` [`ThreadedGradientSync`] instances
/// in parallel threads, collecting the resulting gradient arrays.
///
/// This is primarily useful for testing all-reduce correctness:
///
/// ```rust,ignore
/// let results = run_parallel_workers(2, |sync| {
///     let mut grad = Array1::from_vec(vec![1.0, 2.0]);
///     sync.sync_gradients(&mut grad).unwrap();
///     grad
/// });
/// ```
///
/// # Type bounds
///
/// `F` must be `Send + Sync + Clone + 'static`; internally the closure is
/// wrapped in an `Arc` whose handle is cloned for each worker thread.
pub fn run_parallel_workers<F>(num_workers: usize, f: F) -> Vec<Array1<f32>>
where
    F: Fn(ThreadedGradientSync) -> Array1<f32> + Send + Sync + Clone + 'static,
{
    let syncs = ThreadedGradientSync::new_workers(num_workers);
    let f = Arc::new(f);

    let handles: Vec<_> = syncs
        .into_iter()
        .map(|sync| {
            let f_clone = Arc::clone(&f);
            std::thread::spawn(move || f_clone(sync))
        })
        .collect();

    handles
        .into_iter()
        .map(|h| h.join().expect("worker thread panicked"))
        .collect()
}

// ---------------------------------------------------------------------------
// Tests
// ---------------------------------------------------------------------------

#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::Array1;

    #[test]
    fn test_local_gradient_sync_noop() {
        let sync = LocalGradientSync::new();
        let original = vec![1.0_f32, 2.0, 3.0, 4.0];
        let mut gradients = Array1::from_vec(original.clone());

        sync.sync_gradients(&mut gradients)
            .expect("local sync should not fail");

        for (g, o) in gradients.iter().zip(original.iter()) {
            assert!(
                (g - o).abs() < 1e-7,
                "LocalGradientSync must not modify gradients: got {g} expected {o}"
            );
        }

        assert!(!sync.is_distributed());
        assert_eq!(sync.num_workers(), 1);
    }

    #[test]
    fn test_threaded_gradient_sync_averaging() {
        // Worker 0 has gradients [2.0, 4.0], worker 1 has [4.0, 8.0].
        // Expected average: [3.0, 6.0].
        let worker_grads = [vec![2.0_f32, 4.0], vec![4.0_f32, 8.0]];
        let expected = [3.0_f32, 6.0];

        let results = run_parallel_workers(2, move |sync| {
            let id = sync.worker_id();
            let mut grad = Array1::from_vec(worker_grads[id].clone());
            sync.sync_gradients(&mut grad)
                .expect("threaded sync should not fail");
            grad
        });

        for result in &results {
            for (r, e) in result.iter().zip(expected.iter()) {
                assert!(
                    (r - e).abs() < 1e-5,
                    "averaged gradient mismatch: got {r} expected {e}"
                );
            }
        }
    }

    #[test]
    fn test_checkpoint_save_load_weights() {
        use crate::checkpoint::CheckpointManager;
        use std::env::temp_dir;

        let dir = temp_dir().join(format!(
            "kizzasi_weights_test_{}",
            std::time::SystemTime::now()
                .duration_since(std::time::UNIX_EPOCH)
                .map(|d| d.as_nanos())
                .unwrap_or(0)
        ));

        let manager = CheckpointManager::new(&dir);

        let weights = Array1::from_vec(vec![1.0_f32, 2.0, 3.0, 4.0, 5.0]);
        let bias = 0.42_f32;
        let step = 100_usize;

        let path = manager
            .save_weights(&weights, bias, step)
            .expect("save_weights should succeed");

        let (loaded_weights, loaded_bias) =
            CheckpointManager::load_weights(&path).expect("load_weights should succeed");

        assert_eq!(loaded_weights.len(), weights.len());
        for (l, w) in loaded_weights.iter().zip(weights.iter()) {
            assert!((l - w).abs() < 1e-6, "weight mismatch: {l} vs {w}");
        }
        assert!((loaded_bias - bias).abs() < 1e-6, "bias mismatch");
    }
}

// ---------------------------------------------------------------------------
// Data-Parallel Infrastructure
// ---------------------------------------------------------------------------

/// Gradient averaging strategy for distributed training.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum GradientStrategy {
    /// Average gradients across all workers (AllReduce).
    AllReduce,
    /// Reduce to rank 0 only.
    ReduceToRoot,
    /// No gradient sync (for inference).
    NoSync,
}

/// Communication backend selection.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum CommBackend {
    /// In-process simulation — Pure Rust, no networking.
    InProcess,
    /// Placeholder for future external (NCCL/MPI) backend (C dependency, feature-gated).
    #[allow(dead_code)]
    External,
}

/// Configuration for distributed (data-parallel) training or inference.
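///
/// # Example
///
/// A sketch of a two-worker AllReduce configuration for rank 1; unspecified
/// fields fall back to [`DistributedConfig::default`]:
///
/// ```rust,ignore
/// let cfg = DistributedConfig {
///     world_size: 2,
///     rank: 1,
///     ..DistributedConfig::default()
/// };
/// ```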
#[derive(Debug, Clone)]
pub struct DistributedConfig {
    /// Total number of data-parallel workers.
    pub world_size: usize,
    /// This worker's rank (0..world_size).
    pub rank: usize,
    /// How gradients are aggregated across workers.
    pub grad_strategy: GradientStrategy,
    /// Communication backend (always InProcess for Pure Rust).
    pub backend: CommBackend,
}

impl Default for DistributedConfig {
    fn default() -> Self {
        Self {
            world_size: 1,
            rank: 0,
            grad_strategy: GradientStrategy::AllReduce,
            backend: CommBackend::InProcess,
        }
    }
}

/// Named gradient buffer for a single parameter tensor.
#[derive(Debug, Clone)]
pub struct GradientBuffer {
    /// Parameter name (must match the weight key in the model's weight map).
    pub name: String,
    /// Gradient values, same length as the corresponding weight tensor.
    pub gradients: Vec<f32>,
}

// ---------------------------------------------------------------------------
// SharedGradientStore
// ---------------------------------------------------------------------------

/// Thread-safe gradient store that simulates AllReduce across `world_size` ranks.
///
/// Each rank pushes its local gradients via [`SharedGradientStore::push`].
/// Once all ranks have pushed, any rank can call [`SharedGradientStore::all_reduce_mean`]
/// to obtain the element-wise average. Call [`SharedGradientStore::clear`] after
/// each optimiser step to reset state for the next iteration.
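///
/// # Example
///
/// A sketch of one push/reduce/clear round with two ranks (mirrors
/// `test_shared_gradient_store_all_reduce`):
///
/// ```rust,ignore
/// let store = SharedGradientStore::new(2);
/// store.push(0, vec![GradientBuffer { name: "w".into(), gradients: vec![1.0, 2.0] }])?;
/// store.push(1, vec![GradientBuffer { name: "w".into(), gradients: vec![3.0, 4.0] }])?;
/// let avg = store.all_reduce_mean(0)?; // avg[0].gradients == [2.0, 3.0]
/// store.clear()?;                      // reset before the next iteration
/// ```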
pub struct SharedGradientStore {
    buffers: Arc<Mutex<Vec<Option<Vec<GradientBuffer>>>>>,
    world_size: usize,
}

impl SharedGradientStore {
    /// Create a new store for `world_size` ranks.
    pub fn new(world_size: usize) -> Self {
        Self {
            buffers: Arc::new(Mutex::new(vec![None; world_size])),
            world_size,
        }
    }

    /// Submit gradient buffers from `rank`.
    ///
    /// # Errors
    /// Returns an error if the mutex is poisoned or `rank >= world_size`.
    pub fn push(&self, rank: usize, grads: Vec<GradientBuffer>) -> ModelResult<()> {
        if rank >= self.world_size {
            return Err(ModelError::load_error(
                "distributed",
                format!(
                    "rank {rank} out of bounds for world_size {}",
                    self.world_size
                ),
            ));
        }
        let mut guard = self
            .buffers
            .lock()
            .map_err(|_| ModelError::load_error("distributed", "lock poisoned"))?;
        guard[rank] = Some(grads);
        Ok(())
    }

    /// Return the element-wise mean once all ranks have pushed.
    ///
    /// This does not block: in the in-process simulation all ranks push from
    /// the same process/thread, so every buffer is expected to be filled
    /// before this is called.
    ///
    /// # Errors
    /// Returns an error if not all ranks have submitted yet, or on lock failure.
    pub fn all_reduce_mean(&self, _rank: usize) -> ModelResult<Vec<GradientBuffer>> {
        let guard = self
            .buffers
            .lock()
            .map_err(|_| ModelError::load_error("distributed", "lock poisoned"))?;
        let all_filled = guard.iter().all(|b| b.is_some());
        if !all_filled {
            return Err(ModelError::load_error(
                "distributed",
                "not all ranks have submitted gradients",
            ));
        }
        let grad_lists: Vec<Vec<GradientBuffer>> = guard.iter().filter_map(|b| b.clone()).collect();
        drop(guard);
        average_gradients(&grad_lists)
    }

    /// Clear all gradient buffers — call after each optimiser step.
    ///
    /// # Errors
    /// Returns an error on lock failure.
    pub fn clear(&self) -> ModelResult<()> {
        let mut guard = self
            .buffers
            .lock()
            .map_err(|_| ModelError::load_error("distributed", "lock poisoned"))?;
        for slot in guard.iter_mut() {
            *slot = None;
        }
        Ok(())
    }
}

// ---------------------------------------------------------------------------
// DataParallelModel
// ---------------------------------------------------------------------------

/// Data-parallel wrapper around a named weight map.
///
/// Simulates splitting a mini-batch across `world_size` workers, each
/// computing local gradients, then performing an AllReduce followed by an
/// SGD update. Because all workers live in the same process they share a
/// single `Arc<RwLock<HashMap>>` so weight broadcasts are free.
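///
/// # Example
///
/// A single-worker sketch (mirrors `test_data_parallel_model_step_single_worker`):
/// one SGD step with `lr = 0.1` moves `w` from `[1.0, 2.0]` to `[0.95, 1.95]`.
///
/// ```rust,ignore
/// let mut weights = std::collections::HashMap::new();
/// weights.insert("w".to_string(), vec![1.0_f32, 2.0]);
/// let model = DataParallelModel::new(weights, DistributedConfig::default());
/// let grads = vec![GradientBuffer { name: "w".into(), gradients: vec![0.5, 0.5] }];
/// model.step(grads, 0.1)?;
/// assert!((model.weights()["w"][0] - 0.95).abs() < 1e-6);
/// ```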
pub struct DataParallelModel {
    config: DistributedConfig,
    weights: Arc<std::sync::RwLock<std::collections::HashMap<String, Vec<f32>>>>,
    grad_store: Option<SharedGradientStore>,
}

impl DataParallelModel {
    /// Create a new data-parallel model with the given weight map and config.
    pub fn new(
        weights: std::collections::HashMap<String, Vec<f32>>,
        config: DistributedConfig,
    ) -> Self {
        let grad_store =
            if config.grad_strategy == GradientStrategy::AllReduce && config.world_size > 1 {
                Some(SharedGradientStore::new(config.world_size))
            } else {
                None
            };
        Self {
            config,
            weights: Arc::new(std::sync::RwLock::new(weights)),
            grad_store,
        }
    }

    /// Return a snapshot of the current weight map.
    pub fn weights(&self) -> std::collections::HashMap<String, Vec<f32>> {
        self.weights.read().map(|g| g.clone()).unwrap_or_default()
    }

    /// Apply a gradient update using the configured strategy.
    ///
    /// For `AllReduce` with `world_size > 1` this pushes local gradients to the
    /// [`SharedGradientStore`] and then applies the averaged result. For single-
    /// worker or `NoSync` modes the update is applied directly.
    ///
    /// # Errors
    /// Propagates gradient-store and weight-lock errors.
    pub fn step(&self, local_grads: Vec<GradientBuffer>, learning_rate: f32) -> ModelResult<()> {
        let effective_grads = match &self.grad_store {
            Some(store) => {
                store.push(self.config.rank, local_grads)?;
                store.all_reduce_mean(self.config.rank)?
            }
            None => local_grads,
        };

        let mut guard = self
            .weights
            .write()
            .map_err(|_| ModelError::load_error("distributed", "weight RwLock poisoned"))?;
        sgd_step(&mut guard, &effective_grads, learning_rate)
    }

    /// Broadcast weights from rank 0 to all ranks.
    ///
    /// In-process: all workers already share the same `Arc`, so this is a
    /// no-op that succeeds immediately.
    pub fn broadcast_weights(&self) -> ModelResult<()> {
        // In-process: shared Arc means all workers see the same data.
        Ok(())
    }
}

// ---------------------------------------------------------------------------
// Free functions
// ---------------------------------------------------------------------------

/// Partition `total` sample indices across `world_size` workers using round-robin.
///
/// Returns the indices owned by `rank`.
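///
/// # Example
///
/// Ten samples split across three workers (mirrors `test_partition_indices_basic`):
///
/// ```rust,ignore
/// assert_eq!(partition_indices(10, 3, 0), vec![0, 3, 6, 9]);
/// assert_eq!(partition_indices(10, 3, 1), vec![1, 4, 7]);
/// assert_eq!(partition_indices(10, 3, 2), vec![2, 5, 8]);
/// ```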
pub fn partition_indices(total: usize, world_size: usize, rank: usize) -> Vec<usize> {
    let step = world_size.max(1);
    (rank..total).step_by(step).collect()
}

/// Compute the element-wise mean of multiple gradient-buffer lists.
///
/// All lists must contain the same number of buffers, each with the same
/// gradient length.
///
/// # Errors
/// Returns an error if buffer lists are mismatched in length or gradient sizes differ.
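///
/// # Example
///
/// A sketch with two workers contributing gradients for the same parameter `"w"`
/// (mirrors `test_average_gradients_two_workers`):
///
/// ```rust,ignore
/// let a = vec![GradientBuffer { name: "w".into(), gradients: vec![1.0_f32, 2.0] }];
/// let b = vec![GradientBuffer { name: "w".into(), gradients: vec![3.0_f32, 4.0] }];
/// let avg = average_gradients(&[a, b])?;
/// assert_eq!(avg[0].gradients, vec![2.0_f32, 3.0]); // element-wise mean
/// ```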
pub fn average_gradients(grad_lists: &[Vec<GradientBuffer>]) -> ModelResult<Vec<GradientBuffer>> {
    if grad_lists.is_empty() {
        return Ok(vec![]);
    }
    let n = grad_lists.len() as f32;
    let template = &grad_lists[0];
    let mut result = template.clone();
    for (i, res_buf) in result.iter_mut().enumerate() {
        for list in grad_lists.iter().skip(1) {
            let other = list.get(i).ok_or_else(|| {
                ModelError::load_error("distributed", "gradient list length mismatch")
            })?;
            if other.gradients.len() != res_buf.gradients.len() {
                return Err(ModelError::dimension_mismatch(
                    "average_gradients",
                    res_buf.gradients.len(),
                    other.gradients.len(),
                ));
            }
            for (r, o) in res_buf.gradients.iter_mut().zip(other.gradients.iter()) {
                *r += o;
            }
        }
        for v in res_buf.gradients.iter_mut() {
            *v /= n;
        }
    }
    Ok(result)
}

/// Apply a vanilla SGD update: `weight -= lr * gradient`.
///
/// Only weights that appear in `gradients` are updated; missing parameter
/// names are silently skipped (sparse gradient support).
///
/// # Errors
/// Returns an error if gradient and weight lengths differ for any parameter.
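///
/// # Example
///
/// With `lr = 1.0` each weight decreases by exactly its gradient (mirrors
/// `test_sgd_step_updates_weights`):
///
/// ```rust,ignore
/// let mut weights = std::collections::HashMap::new();
/// weights.insert("w".to_string(), vec![1.0_f32, 2.0]);
/// let grads = vec![GradientBuffer { name: "w".into(), gradients: vec![0.1_f32, 0.2] }];
/// sgd_step(&mut weights, &grads, 1.0)?;
/// assert!((weights["w"][0] - 0.9).abs() < 1e-6);
/// ```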
pub fn sgd_step(
    weights: &mut std::collections::HashMap<String, Vec<f32>>,
    gradients: &[GradientBuffer],
    lr: f32,
) -> ModelResult<()> {
    for grad_buf in gradients {
        if let Some(w) = weights.get_mut(&grad_buf.name) {
            if w.len() != grad_buf.gradients.len() {
                return Err(ModelError::dimension_mismatch(
                    "sgd_step",
                    w.len(),
                    grad_buf.gradients.len(),
                ));
            }
            for (wi, &gi) in w.iter_mut().zip(grad_buf.gradients.iter()) {
                *wi -= lr * gi;
            }
        }
    }
    Ok(())
}

// ---------------------------------------------------------------------------
// Data-parallel tests
// ---------------------------------------------------------------------------

#[cfg(test)]
mod dp_tests {
    use super::*;

    #[test]
    fn test_partition_indices_basic() {
        let idx = partition_indices(10, 3, 0);
        assert_eq!(idx, vec![0, 3, 6, 9]);
        let idx1 = partition_indices(10, 3, 1);
        assert_eq!(idx1, vec![1, 4, 7]);
        let idx2 = partition_indices(10, 3, 2);
        assert_eq!(idx2, vec![2, 5, 8]);
    }

    #[test]
    fn test_average_gradients_two_workers() {
        let grads1 = vec![GradientBuffer {
            name: "w".to_string(),
            gradients: vec![1.0_f32, 2.0],
        }];
        let grads2 = vec![GradientBuffer {
            name: "w".to_string(),
            gradients: vec![3.0_f32, 4.0],
        }];
        let avg = average_gradients(&[grads1, grads2]).expect("average should succeed");
        assert!((avg[0].gradients[0] - 2.0).abs() < 1e-6);
        assert!((avg[0].gradients[1] - 3.0).abs() < 1e-6);
    }

    #[test]
    fn test_sgd_step_updates_weights() {
        let mut weights = std::collections::HashMap::new();
        weights.insert("w".to_string(), vec![1.0_f32, 2.0, 3.0]);
        let grads = vec![GradientBuffer {
            name: "w".to_string(),
            gradients: vec![0.1_f32, 0.2, 0.3],
        }];
        sgd_step(&mut weights, &grads, 1.0).expect("sgd_step should succeed");
        assert!((weights["w"][0] - 0.9).abs() < 1e-6);
        assert!((weights["w"][1] - 1.8).abs() < 1e-6);
        assert!((weights["w"][2] - 2.7).abs() < 1e-6);
    }

    #[test]
    fn test_shared_gradient_store_all_reduce() {
        let store = SharedGradientStore::new(2);
        let grads0 = vec![GradientBuffer {
            name: "w".to_string(),
            gradients: vec![1.0_f32, 2.0],
        }];
        let grads1 = vec![GradientBuffer {
            name: "w".to_string(),
            gradients: vec![3.0_f32, 4.0],
        }];
        store.push(0, grads0).expect("push rank 0");
        store.push(1, grads1).expect("push rank 1");
        let avg = store.all_reduce_mean(0).expect("all_reduce_mean");
        assert!((avg[0].gradients[0] - 2.0).abs() < 1e-6);
        assert!((avg[0].gradients[1] - 3.0).abs() < 1e-6);
    }

    #[test]
    fn test_data_parallel_model_weights_shared() {
        let mut weights = std::collections::HashMap::new();
        weights.insert("embed".to_string(), vec![0.1_f32; 16]);
        let model = DataParallelModel::new(weights, DistributedConfig::default());
        let w = model.weights();
        assert!(w.contains_key("embed"));
        assert_eq!(w["embed"].len(), 16);
    }

    #[test]
    fn test_distributed_config_default() {
        let cfg = DistributedConfig::default();
        assert_eq!(cfg.world_size, 1);
        assert_eq!(cfg.rank, 0);
        assert_eq!(cfg.grad_strategy, GradientStrategy::AllReduce);
        assert_eq!(cfg.backend, CommBackend::InProcess);
    }

    #[test]
    fn test_partition_indices_single_worker() {
        let idx = partition_indices(5, 1, 0);
        assert_eq!(idx, vec![0, 1, 2, 3, 4]);
    }

    #[test]
    fn test_average_gradients_single() {
        let grads = vec![GradientBuffer {
            name: "w".to_string(),
            gradients: vec![2.0_f32, 4.0],
        }];
        let avg = average_gradients(&[grads]).expect("single-list average");
        assert_eq!(avg[0].gradients, vec![2.0_f32, 4.0]);
    }

    #[test]
    fn test_data_parallel_model_step_single_worker() {
        let mut weights = std::collections::HashMap::new();
        weights.insert("w".to_string(), vec![1.0_f32, 2.0]);
        let model = DataParallelModel::new(weights, DistributedConfig::default());
        let grads = vec![GradientBuffer {
            name: "w".to_string(),
            gradients: vec![0.5_f32, 0.5],
        }];
        model.step(grads, 0.1).expect("step should succeed");
        let w = model.weights();
        assert!((w["w"][0] - 0.95).abs() < 1e-6);
        assert!((w["w"][1] - 1.95).abs() < 1e-6);
    }

    #[test]
    fn test_broadcast_weights_noop() {
        let weights = std::collections::HashMap::new();
        let model = DataParallelModel::new(weights, DistributedConfig::default());
        assert!(model.broadcast_weights().is_ok());
    }

    #[test]
    fn test_shared_gradient_store_clear() {
        let store = SharedGradientStore::new(1);
        let grads = vec![GradientBuffer {
            name: "w".to_string(),
            gradients: vec![1.0_f32],
        }];
        store.push(0, grads).expect("push");
        store.clear().expect("clear");
        // After clear, all_reduce_mean should fail (not all submitted)
        assert!(store.all_reduce_mean(0).is_err());
    }
}