1use crate::error::{ModelError, ModelResult};
15use scirs2_core::ndarray::Array1;
16use std::sync::{Arc, Condvar, Mutex};
17
/// Abstraction over gradient synchronization strategies.
///
/// Implementors take a worker's local gradient vector and replace it in
/// place with the globally synchronized version (e.g. the mean across
/// workers). The local implementation is a no-op; the threaded one runs an
/// all-reduce barrier.
pub trait GradientSync: Send {
    /// Synchronize `gradients` in place across all participating workers.
    fn sync_gradients(&self, gradients: &mut Array1<f32>) -> ModelResult<()>;

    /// Whether more than one worker participates. Defaults to `false`.
    fn is_distributed(&self) -> bool {
        false
    }

    /// Number of participating workers. Defaults to `1`.
    fn num_workers(&self) -> usize {
        1
    }
}
42
/// No-op synchronizer for single-process training: `sync_gradients` leaves
/// the caller's gradients untouched.
#[derive(Debug, Clone, Default)]
pub struct LocalGradientSync;
53
54impl LocalGradientSync {
55 pub fn new() -> Self {
57 Self
58 }
59}
60
61impl GradientSync for LocalGradientSync {
62 #[inline]
63 fn sync_gradients(&self, _gradients: &mut Array1<f32>) -> ModelResult<()> {
64 Ok(())
66 }
67
68 fn is_distributed(&self) -> bool {
69 false
70 }
71
72 fn num_workers(&self) -> usize {
73 1
74 }
75}
76
/// Mutable bookkeeping behind the reusable two-phase reduction barrier.
#[derive(Debug)]
struct BarrierState {
    /// Running element-wise sum of the gradients submitted so far in the
    /// current round; `None` until the first worker arrives.
    accumulator: Option<Vec<f32>>,
    /// Averaged result for the current round, set once all workers arrived.
    result: Option<Vec<f32>>,
    /// Number of workers that have submitted gradients this round.
    arrived: usize,
    /// Number of workers that have finished reading the result this round.
    departed: usize,
    /// Round counter; a bump signals waiters that the reduction completed.
    generation: usize,
}
101
102impl BarrierState {
103 fn new() -> Self {
104 Self {
105 accumulator: None,
106 result: None,
107 arrived: 0,
108 departed: 0,
109 generation: 0,
110 }
111 }
112}
113
/// Synchronization primitives shared by every worker handle in a group.
#[derive(Debug)]
struct SharedState {
    /// Barrier bookkeeping, guarded by a single mutex.
    inner: Mutex<BarrierState>,
    /// Signaled when the reduction for a round is complete.
    all_arrived: Condvar,
    /// Signaled when the last worker of a round has departed (state reset).
    all_departed: Condvar,
    /// Fixed number of participating workers.
    num_workers: usize,
}
121
122impl SharedState {
123 fn new(num_workers: usize) -> Self {
124 Self {
125 inner: Mutex::new(BarrierState::new()),
126 all_arrived: Condvar::new(),
127 all_departed: Condvar::new(),
128 num_workers,
129 }
130 }
131}
132
/// Per-worker handle to an in-process all-reduce gradient barrier.
///
/// Cloning shares the underlying barrier state *and* the worker id, so each
/// worker normally keeps exactly one handle from [`ThreadedGradientSync::new_workers`].
#[derive(Debug, Clone)]
pub struct ThreadedGradientSync {
    /// Barrier state shared by all handles in the group.
    shared: Arc<SharedState>,
    /// Zero-based id identifying this worker within the group.
    worker_id: usize,
}
150
151impl ThreadedGradientSync {
152 pub fn new_workers(num_workers: usize) -> Vec<Self> {
158 assert!(num_workers > 0, "num_workers must be at least 1");
159 let shared = Arc::new(SharedState::new(num_workers));
160 (0..num_workers)
161 .map(|id| Self {
162 shared: Arc::clone(&shared),
163 worker_id: id,
164 })
165 .collect()
166 }
167
168 pub fn worker_id(&self) -> usize {
170 self.worker_id
171 }
172}
173
impl GradientSync for ThreadedGradientSync {
    /// All-reduce average across the worker group.
    ///
    /// Each worker calls this once per step with its local gradients; on
    /// return `gradients` holds the element-wise mean over all workers.
    /// Implemented as a reusable two-phase (arrive/depart) barrier with a
    /// generation counter. The exact lock scopes and statement order are
    /// load-bearing — do not reorder.
    fn sync_gradients(&self, gradients: &mut Array1<f32>) -> ModelResult<()> {
        let n = gradients.len();
        let num_workers = self.shared.num_workers;

        // Phase 1: fold this worker's contribution into the shared
        // accumulator and register arrival.
        {
            let mut state =
                self.shared.inner.lock().map_err(|_| {
                    ModelError::load_error("gradient sync", "barrier mutex poisoned")
                })?;

            match state.accumulator.as_mut() {
                None => {
                    // First arriver this round seeds the accumulator.
                    state.accumulator = Some(gradients.iter().copied().collect());
                }
                Some(acc) => {
                    // Later arrivers must match the seeded length.
                    if acc.len() != n {
                        return Err(ModelError::dimension_mismatch(
                            "gradient sync",
                            acc.len(),
                            n,
                        ));
                    }
                    for (a, &g) in acc.iter_mut().zip(gradients.iter()) {
                        *a += g;
                    }
                }
            }
            state.arrived += 1;
        }

        // Phase 2: whoever observes the full arrival count finishes the
        // reduction; everyone else blocks until the generation bumps.
        {
            let mut state =
                self.shared.inner.lock().map_err(|_| {
                    ModelError::load_error("gradient sync", "barrier mutex poisoned")
                })?;

            if state.arrived == num_workers {
                // The lock was released between phases, so more than one
                // worker can reach this branch; `take()` makes the averaging
                // idempotent (only the first finds the accumulator).
                if let Some(acc) = state.accumulator.take() {
                    let scale = 1.0 / num_workers as f32;
                    state.result = Some(acc.iter().map(|&x| x * scale).collect());
                }
                state.generation = state.generation.wrapping_add(1);
                self.shared.all_arrived.notify_all();
            } else {
                // Wait keyed on the generation counter (not a boolean) so a
                // spurious wakeup cannot release us early.
                let gen_before = state.generation;
                let state = self
                    .shared
                    .all_arrived
                    .wait_while(state, |s| s.generation == gen_before)
                    .map_err(|_| {
                        ModelError::load_error("gradient sync", "condvar wait failed (arrived)")
                    })?;
                drop(state);
            }
        }

        // Phase 3: copy the averaged result back into the caller's buffer.
        {
            let state =
                self.shared.inner.lock().map_err(|_| {
                    ModelError::load_error("gradient sync", "barrier mutex poisoned")
                })?;
            if let Some(result) = state.result.as_ref() {
                for (g, &r) in gradients.iter_mut().zip(result.iter()) {
                    *g = r;
                }
            }
        }

        // Phase 4: departure barrier. The last worker out resets the state
        // for the next round; resetting only after every worker has read the
        // result prevents a fast worker's next round from clobbering it.
        let should_wait;
        {
            let mut state =
                self.shared.inner.lock().map_err(|_| {
                    ModelError::load_error("gradient sync", "barrier mutex poisoned")
                })?;

            state.departed += 1;
            if state.departed == num_workers {
                state.accumulator = None;
                state.result = None;
                state.arrived = 0;
                state.departed = 0;
                self.shared.all_departed.notify_all();
                should_wait = false;
            } else {
                should_wait = true;
            }
        }

        if should_wait {
            // If the last departer reset state between our unlock above and
            // this lock, `departed` is already 0 and the wait returns
            // immediately — no round can re-increment it until every worker
            // (including us) has re-arrived.
            let state =
                self.shared.inner.lock().map_err(|_| {
                    ModelError::load_error("gradient sync", "barrier mutex poisoned")
                })?;
            let _guard = self
                .shared
                .all_departed
                .wait_while(state, |s| s.departed != 0)
                .map_err(|_| {
                    ModelError::load_error("gradient sync", "condvar wait failed (departed)")
                })?;
        }

        Ok(())
    }

    /// Multiple workers participate.
    fn is_distributed(&self) -> bool {
        true
    }

    /// Size of the worker group this handle belongs to.
    fn num_workers(&self) -> usize {
        self.shared.num_workers
    }
}
306
307pub fn run_parallel_workers<F>(num_workers: usize, f: F) -> Vec<Array1<f32>>
329where
330 F: Fn(ThreadedGradientSync) -> Array1<f32> + Send + Sync + Clone + 'static,
331{
332 let syncs = ThreadedGradientSync::new_workers(num_workers);
333 let f = Arc::new(f);
334
335 let handles: Vec<_> = syncs
336 .into_iter()
337 .map(|sync| {
338 let f_clone = Arc::clone(&f);
339 std::thread::spawn(move || f_clone(sync))
340 })
341 .collect();
342
343 handles
344 .into_iter()
345 .map(|h| h.join().expect("worker thread panicked"))
346 .collect()
347}
348
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::Array1;

    /// The local (single-process) sync must leave gradients untouched and
    /// report a non-distributed, single-worker setup.
    #[test]
    fn test_local_gradient_sync_noop() {
        let sync = LocalGradientSync::new();
        let original = vec![1.0_f32, 2.0, 3.0, 4.0];
        let mut gradients = Array1::from_vec(original.clone());

        sync.sync_gradients(&mut gradients)
            .expect("local sync should not fail");

        for (g, o) in gradients.iter().zip(original.iter()) {
            assert!(
                (g - o).abs() < 1e-7,
                "LocalGradientSync must not modify gradients: got {g} expected {o}"
            );
        }

        assert!(!sync.is_distributed());
        assert_eq!(sync.num_workers(), 1);
    }

    /// Two threads submit different gradients; both must come back holding
    /// the element-wise mean after the barrier.
    #[test]
    fn test_threaded_gradient_sync_averaging() {
        let worker_grads = [vec![2.0_f32, 4.0], vec![4.0_f32, 8.0]];
        let expected = [3.0_f32, 6.0];

        let results = run_parallel_workers(2, move |sync| {
            let id = sync.worker_id();
            let mut grad = Array1::from_vec(worker_grads[id].clone());
            sync.sync_gradients(&mut grad)
                .expect("threaded sync should not fail");
            grad
        });

        for result in &results {
            for (r, e) in result.iter().zip(expected.iter()) {
                assert!(
                    (r - e).abs() < 1e-5,
                    "averaged gradient mismatch: got {r} expected {e}"
                );
            }
        }
    }

    /// Round-trips weights + bias through the checkpoint manager on disk.
    /// A nanosecond-stamped temp dir avoids collisions between concurrent
    /// test runs.
    #[test]
    fn test_checkpoint_save_load_weights() {
        use crate::checkpoint::CheckpointManager;
        use std::env::temp_dir;

        let dir = temp_dir().join(format!(
            "kizzasi_weights_test_{}",
            std::time::SystemTime::now()
                .duration_since(std::time::UNIX_EPOCH)
                .map(|d| d.as_nanos())
                .unwrap_or(0)
        ));

        let manager = CheckpointManager::new(&dir);

        let weights = Array1::from_vec(vec![1.0_f32, 2.0, 3.0, 4.0, 5.0]);
        let bias = 0.42_f32;
        let step = 100_usize;

        let path = manager
            .save_weights(&weights, bias, step)
            .expect("save_weights should succeed");

        let (loaded_weights, loaded_bias) =
            CheckpointManager::load_weights(&path).expect("load_weights should succeed");

        assert_eq!(loaded_weights.len(), weights.len());
        for (l, w) in loaded_weights.iter().zip(weights.iter()) {
            assert!((l - w).abs() < 1e-6, "weight mismatch: {l} vs {w}");
        }
        assert!((loaded_bias - bias).abs() < 1e-6, "bias mismatch");
    }
}
436
/// How gradients are combined across ranks each optimization step.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum GradientStrategy {
    /// Average gradients across all ranks; every rank applies the mean.
    /// This is the only strategy for which `DataParallelModel` allocates a
    /// shared gradient store.
    AllReduce,
    /// Reduce gradients onto a root rank only (presumably for a
    /// parameter-server style flow — not exercised by the in-process path).
    ReduceToRoot,
    /// Skip synchronization entirely; each rank applies its own gradients.
    NoSync,
}
451
/// Transport used to exchange gradients between ranks.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum CommBackend {
    /// All ranks live in the same process and share memory directly.
    InProcess,
    /// Placeholder for an out-of-process transport; currently unused.
    #[allow(dead_code)]
    External,
}
461
/// Static description of the distributed training topology for one rank.
#[derive(Debug, Clone)]
pub struct DistributedConfig {
    /// Total number of ranks participating.
    pub world_size: usize,
    /// This rank's zero-based id (expected to be `< world_size`; the
    /// gradient store rejects out-of-range ranks at push time).
    pub rank: usize,
    /// Gradient combination strategy.
    pub grad_strategy: GradientStrategy,
    /// Communication backend.
    pub backend: CommBackend,
}
474
475impl Default for DistributedConfig {
476 fn default() -> Self {
477 Self {
478 world_size: 1,
479 rank: 0,
480 grad_strategy: GradientStrategy::AllReduce,
481 backend: CommBackend::InProcess,
482 }
483 }
484}
485
/// Named gradient vector for one parameter tensor.
#[derive(Debug, Clone)]
pub struct GradientBuffer {
    /// Parameter name this gradient applies to (used as the lookup key by
    /// `sgd_step`).
    pub name: String,
    /// Flattened gradient values.
    pub gradients: Vec<f32>,
}
494
/// In-process mailbox where each rank deposits its gradients for a round.
///
/// Holds one `Option` slot per rank; `all_reduce_mean` requires every slot
/// to be filled before averaging.
pub struct SharedGradientStore {
    /// One slot per rank, `None` until that rank has pushed this round.
    buffers: Arc<Mutex<Vec<Option<Vec<GradientBuffer>>>>>,
    /// Number of rank slots.
    world_size: usize,
}
509
510impl SharedGradientStore {
511 pub fn new(world_size: usize) -> Self {
513 Self {
514 buffers: Arc::new(Mutex::new(vec![None; world_size])),
515 world_size,
516 }
517 }
518
519 pub fn push(&self, rank: usize, grads: Vec<GradientBuffer>) -> ModelResult<()> {
524 if rank >= self.world_size {
525 return Err(ModelError::load_error(
526 "distributed",
527 format!(
528 "rank {rank} out of bounds for world_size {}",
529 self.world_size
530 ),
531 ));
532 }
533 let mut guard = self
534 .buffers
535 .lock()
536 .map_err(|_| ModelError::load_error("distributed", "lock poisoned"))?;
537 guard[rank] = Some(grads);
538 Ok(())
539 }
540
541 pub fn all_reduce_mean(&self, _rank: usize) -> ModelResult<Vec<GradientBuffer>> {
549 let guard = self
550 .buffers
551 .lock()
552 .map_err(|_| ModelError::load_error("distributed", "lock poisoned"))?;
553 let all_filled = guard.iter().all(|b| b.is_some());
554 if !all_filled {
555 return Err(ModelError::load_error(
556 "distributed",
557 "not all ranks have submitted gradients",
558 ));
559 }
560 let grad_lists: Vec<Vec<GradientBuffer>> = guard.iter().filter_map(|b| b.clone()).collect();
561 drop(guard);
562 average_gradients(&grad_lists)
563 }
564
565 pub fn clear(&self) -> ModelResult<()> {
570 let mut guard = self
571 .buffers
572 .lock()
573 .map_err(|_| ModelError::load_error("distributed", "lock poisoned"))?;
574 for slot in guard.iter_mut() {
575 *slot = None;
576 }
577 Ok(())
578 }
579}
580
/// Data-parallel model wrapper: shared named weights plus optional
/// in-process gradient exchange.
pub struct DataParallelModel {
    /// Topology and strategy this instance was created with.
    config: DistributedConfig,
    /// Named weight vectors, shared behind a reader-writer lock.
    weights: Arc<std::sync::RwLock<std::collections::HashMap<String, Vec<f32>>>>,
    /// Present only for `AllReduce` with `world_size > 1`.
    grad_store: Option<SharedGradientStore>,
}
596
impl DataParallelModel {
    /// Build a data-parallel model around a named weight map.
    ///
    /// A shared gradient store is allocated only when the configuration
    /// requests `AllReduce` across more than one rank; otherwise gradients
    /// are applied locally without any exchange.
    pub fn new(
        weights: std::collections::HashMap<String, Vec<f32>>,
        config: DistributedConfig,
    ) -> Self {
        let grad_store =
            if config.grad_strategy == GradientStrategy::AllReduce && config.world_size > 1 {
                Some(SharedGradientStore::new(config.world_size))
            } else {
                None
            };
        Self {
            config,
            weights: Arc::new(std::sync::RwLock::new(weights)),
            grad_store,
        }
    }

    /// Clone of the current weight map.
    ///
    /// NOTE(review): a poisoned `RwLock` is silently mapped to an empty map
    /// here, so callers cannot distinguish "no weights" from a poisoned
    /// lock — confirm this best-effort behavior is intended.
    pub fn weights(&self) -> std::collections::HashMap<String, Vec<f32>> {
        self.weights.read().map(|g| g.clone()).unwrap_or_default()
    }

    /// Apply one SGD step using `local_grads`.
    ///
    /// With an all-reduce store, this rank's gradients are pushed and the
    /// mean over all ranks is applied; without one, the local gradients are
    /// applied directly.
    ///
    /// # Errors
    /// Propagates push/reduce errors (rank out of bounds, not all ranks
    /// submitted) and lock-poisoning errors.
    ///
    /// NOTE(review): the store is never cleared after the reduction, so a
    /// later round could average against stale entries if some ranks lag —
    /// verify the intended call discipline (every rank pushes each step
    /// before anyone clears).
    pub fn step(&self, local_grads: Vec<GradientBuffer>, learning_rate: f32) -> ModelResult<()> {
        let effective_grads = match &self.grad_store {
            Some(store) => {
                store.push(self.config.rank, local_grads)?;
                store.all_reduce_mean(self.config.rank)?
            }
            None => local_grads,
        };

        let mut guard = self
            .weights
            .write()
            .map_err(|_| ModelError::load_error("distributed", "weight RwLock poisoned"))?;
        sgd_step(&mut guard, &effective_grads, learning_rate)
    }

    /// No-op under the in-process backend: all ranks already share the same
    /// `Arc`'d weight map, so there is nothing to broadcast.
    pub fn broadcast_weights(&self) -> ModelResult<()> {
        Ok(())
    }
}
654
/// Round-robin assignment of sample indices `0..total` to `rank`.
///
/// Rank `r` owns every index congruent to `r` modulo `world_size`, so the
/// shards are disjoint and together cover `0..total`. A `world_size` of
/// zero is treated as one to avoid a zero stride; a `rank >= total` yields
/// an empty shard.
pub fn partition_indices(total: usize, world_size: usize, rank: usize) -> Vec<usize> {
    let stride = world_size.max(1);
    let mut owned = Vec::new();
    let mut idx = rank;
    while idx < total {
        owned.push(idx);
        idx += stride;
    }
    owned
}
666
667pub fn average_gradients(grad_lists: &[Vec<GradientBuffer>]) -> ModelResult<Vec<GradientBuffer>> {
675 if grad_lists.is_empty() {
676 return Ok(vec![]);
677 }
678 let n = grad_lists.len() as f32;
679 let template = &grad_lists[0];
680 let mut result = template.clone();
681 for (i, res_buf) in result.iter_mut().enumerate() {
682 for list in grad_lists.iter().skip(1) {
683 let other = list.get(i).ok_or_else(|| {
684 ModelError::load_error("distributed", "gradient list length mismatch")
685 })?;
686 if other.gradients.len() != res_buf.gradients.len() {
687 return Err(ModelError::dimension_mismatch(
688 "average_gradients",
689 res_buf.gradients.len(),
690 other.gradients.len(),
691 ));
692 }
693 for (r, o) in res_buf.gradients.iter_mut().zip(other.gradients.iter()) {
694 *r += o;
695 }
696 }
697 for v in res_buf.gradients.iter_mut() {
698 *v /= n;
699 }
700 }
701 Ok(result)
702}
703
704pub fn sgd_step(
712 weights: &mut std::collections::HashMap<String, Vec<f32>>,
713 gradients: &[GradientBuffer],
714 lr: f32,
715) -> ModelResult<()> {
716 for grad_buf in gradients {
717 if let Some(w) = weights.get_mut(&grad_buf.name) {
718 if w.len() != grad_buf.gradients.len() {
719 return Err(ModelError::dimension_mismatch(
720 "sgd_step",
721 w.len(),
722 grad_buf.gradients.len(),
723 ));
724 }
725 for (wi, &gi) in w.iter_mut().zip(grad_buf.gradients.iter()) {
726 *wi -= lr * gi;
727 }
728 }
729 }
730 Ok(())
731}
732
#[cfg(test)]
mod dp_tests {
    use super::*;

    // Round-robin sharding: three ranks partition 0..10 disjointly.
    #[test]
    fn test_partition_indices_basic() {
        let idx = partition_indices(10, 3, 0);
        assert_eq!(idx, vec![0, 3, 6, 9]);
        let idx1 = partition_indices(10, 3, 1);
        assert_eq!(idx1, vec![1, 4, 7]);
        let idx2 = partition_indices(10, 3, 2);
        assert_eq!(idx2, vec![2, 5, 8]);
    }

    // Positional averaging across two workers' buffer lists.
    #[test]
    fn test_average_gradients_two_workers() {
        let grads1 = vec![GradientBuffer {
            name: "w".to_string(),
            gradients: vec![1.0_f32, 2.0],
        }];
        let grads2 = vec![GradientBuffer {
            name: "w".to_string(),
            gradients: vec![3.0_f32, 4.0],
        }];
        let avg = average_gradients(&[grads1, grads2]).expect("average should succeed");
        assert!((avg[0].gradients[0] - 2.0).abs() < 1e-6);
        assert!((avg[0].gradients[1] - 3.0).abs() < 1e-6);
    }

    // w -= lr * g with lr = 1.0.
    #[test]
    fn test_sgd_step_updates_weights() {
        let mut weights = std::collections::HashMap::new();
        weights.insert("w".to_string(), vec![1.0_f32, 2.0, 3.0]);
        let grads = vec![GradientBuffer {
            name: "w".to_string(),
            gradients: vec![0.1_f32, 0.2, 0.3],
        }];
        sgd_step(&mut weights, &grads, 1.0).expect("sgd_step should succeed");
        assert!((weights["w"][0] - 0.9).abs() < 1e-6);
        assert!((weights["w"][1] - 1.8).abs() < 1e-6);
        assert!((weights["w"][2] - 2.7).abs() < 1e-6);
    }

    // Both ranks push, then the mean is computed from the store.
    #[test]
    fn test_shared_gradient_store_all_reduce() {
        let store = SharedGradientStore::new(2);
        let grads0 = vec![GradientBuffer {
            name: "w".to_string(),
            gradients: vec![1.0_f32, 2.0],
        }];
        let grads1 = vec![GradientBuffer {
            name: "w".to_string(),
            gradients: vec![3.0_f32, 4.0],
        }];
        store.push(0, grads0).expect("push rank 0");
        store.push(1, grads1).expect("push rank 1");
        let avg = store.all_reduce_mean(0).expect("all_reduce_mean");
        assert!((avg[0].gradients[0] - 2.0).abs() < 1e-6);
        assert!((avg[0].gradients[1] - 3.0).abs() < 1e-6);
    }

    // The weights() snapshot exposes the map the model was built with.
    #[test]
    fn test_data_parallel_model_weights_shared() {
        let mut weights = std::collections::HashMap::new();
        weights.insert("embed".to_string(), vec![0.1_f32; 16]);
        let model = DataParallelModel::new(weights, DistributedConfig::default());
        let w = model.weights();
        assert!(w.contains_key("embed"));
        assert_eq!(w["embed"].len(), 16);
    }

    // Defaults: single rank, all-reduce strategy, in-process backend.
    #[test]
    fn test_distributed_config_default() {
        let cfg = DistributedConfig::default();
        assert_eq!(cfg.world_size, 1);
        assert_eq!(cfg.rank, 0);
        assert_eq!(cfg.grad_strategy, GradientStrategy::AllReduce);
        assert_eq!(cfg.backend, CommBackend::InProcess);
    }

    // A single worker owns every index.
    #[test]
    fn test_partition_indices_single_worker() {
        let idx = partition_indices(5, 1, 0);
        assert_eq!(idx, vec![0, 1, 2, 3, 4]);
    }

    // Averaging one list is the identity.
    #[test]
    fn test_average_gradients_single() {
        let grads = vec![GradientBuffer {
            name: "w".to_string(),
            gradients: vec![2.0_f32, 4.0],
        }];
        let avg = average_gradients(&[grads]).expect("single-list average");
        assert_eq!(avg[0].gradients, vec![2.0_f32, 4.0]);
    }

    // With world_size 1 there is no grad store; step applies local grads.
    #[test]
    fn test_data_parallel_model_step_single_worker() {
        let mut weights = std::collections::HashMap::new();
        weights.insert("w".to_string(), vec![1.0_f32, 2.0]);
        let model = DataParallelModel::new(weights, DistributedConfig::default());
        let grads = vec![GradientBuffer {
            name: "w".to_string(),
            gradients: vec![0.5_f32, 0.5],
        }];
        model.step(grads, 0.1).expect("step should succeed");
        let w = model.weights();
        assert!((w["w"][0] - 0.95).abs() < 1e-6);
        assert!((w["w"][1] - 1.95).abs() < 1e-6);
    }

    // broadcast_weights is a documented no-op for the in-process backend.
    #[test]
    fn test_broadcast_weights_noop() {
        let weights = std::collections::HashMap::new();
        let model = DataParallelModel::new(weights, DistributedConfig::default());
        assert!(model.broadcast_weights().is_ok());
    }

    // After clear(), all_reduce_mean must fail until ranks push again.
    #[test]
    fn test_shared_gradient_store_clear() {
        let store = SharedGradientStore::new(1);
        let grads = vec![GradientBuffer {
            name: "w".to_string(),
            gradients: vec![1.0_f32],
        }];
        store.push(0, grads).expect("push");
        store.clear().expect("clear");
        assert!(store.all_reduce_mean(0).is_err());
    }
}