1use crate::error::{OptimError, Result};
7use scirs2_core::ndarray::{Array, Dimension, ScalarOperand, Zip};
8use scirs2_core::numeric::Float;
9use std::collections::HashMap;
10use std::fmt::Debug;
11
/// Strategy used to combine parameters gathered from multiple nodes.
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum AveragingStrategy {
    /// Plain arithmetic mean over all submitted parameter sets.
    Arithmetic,
    /// Weighted mean using per-node weights (intended to be proportional to
    /// local data size; shares the weighted-mean implementation).
    WeightedByData,
    /// Weighted mean using per-node weights (intended to be proportional to
    /// local training time; shares the weighted-mean implementation).
    WeightedByTime,
    /// Federated averaging; currently delegates to the weighted mean
    /// (see `ParameterAverager::federated_average`).
    Federated,
    /// Arithmetic mean smoothed through a persistent momentum buffer.
    Momentum {
        /// Momentum coefficient; larger values smooth more across rounds.
        momentum: f64,
    },
    /// Exponential moving average of successive round means.
    ExponentialMovingAverage {
        /// Weight kept for the previous average each round.
        decay: f64,
    },
}
34
/// Maintains a running average of model parameters submitted by multiple nodes.
#[derive(Debug)]
pub struct ParameterAverager<A: Float, D: Dimension> {
    /// Current averaged parameter arrays (shapes fixed at initialization).
    averaged_params: Vec<Array<A, D>>,
    /// How per-node submissions are combined.
    strategy: AveragingStrategy,
    /// Per-node averaging weights, keyed by node ID.
    node_weights: HashMap<usize, A>,
    /// Total number of participating nodes.
    numnodes: usize,
    /// Momentum buffer; allocated only for the `Momentum` strategy.
    momentum_buffer: Option<Vec<Array<A, D>>>,
    /// Number of completed averaging steps.
    step_count: usize,
    /// Whether parameter shapes and weights have been set up.
    initialized: bool,
}
53
54impl<A: Float + ScalarOperand + Debug + Send + Sync, D: Dimension + Send + Sync>
55 ParameterAverager<A, D>
56{
57 pub fn new(strategy: AveragingStrategy, numnodes: usize) -> Self {
59 Self {
60 averaged_params: Vec::new(),
61 strategy,
62 node_weights: HashMap::new(),
63 numnodes,
64 momentum_buffer: None,
65 step_count: 0,
66 initialized: false,
67 }
68 }
69
70 pub fn initialize(&mut self, params: &[Array<A, D>]) -> Result<()> {
72 if self.initialized {
73 return Err(OptimError::InvalidConfig(
74 "Parameter averager already initialized".to_string(),
75 ));
76 }
77
78 self.averaged_params = params.to_vec();
79
80 if matches!(self.strategy, AveragingStrategy::Momentum { .. }) {
82 self.momentum_buffer = Some(params.iter().map(|p| Array::zeros(p.raw_dim())).collect());
83 }
84
85 let uniform_weight = A::one() / A::from(self.numnodes).expect("unwrap failed");
87 for nodeid in 0..self.numnodes {
88 self.node_weights.insert(nodeid, uniform_weight);
89 }
90
91 self.initialized = true;
92 Ok(())
93 }
94
95 pub fn set_node_weight(&mut self, nodeid: usize, weight: A) -> Result<()> {
97 if nodeid >= self.numnodes {
98 return Err(OptimError::InvalidConfig(format!(
99 "Node ID {} exceeds number of nodes {}",
100 nodeid, self.numnodes
101 )));
102 }
103 self.node_weights.insert(nodeid, weight);
104 Ok(())
105 }
106
107 pub fn average_parameters(
109 &mut self,
110 nodeparameters: &[(usize, Vec<Array<A, D>>)],
111 ) -> Result<()> {
112 if !self.initialized {
113 if let Some((_, first_params)) = nodeparameters.first() {
114 self.initialize(first_params)?;
115 } else {
116 return Err(OptimError::InvalidConfig(
117 "No _parameters provided for initialization".to_string(),
118 ));
119 }
120 }
121
122 for (nodeid, params) in nodeparameters {
124 if *nodeid >= self.numnodes {
125 return Err(OptimError::InvalidConfig(format!(
126 "Node ID {} exceeds number of nodes {}",
127 nodeid, self.numnodes
128 )));
129 }
130 if params.len() != self.averaged_params.len() {
131 return Err(OptimError::DimensionMismatch(format!(
132 "Expected {} parameter arrays, got {}",
133 self.averaged_params.len(),
134 params.len()
135 )));
136 }
137 }
138
139 self.step_count += 1;
140
141 match self.strategy {
142 AveragingStrategy::Arithmetic => {
143 self.arithmetic_average(nodeparameters)?;
144 }
145 AveragingStrategy::WeightedByData | AveragingStrategy::WeightedByTime => {
146 self.weighted_average(nodeparameters)?;
147 }
148 AveragingStrategy::Federated => {
149 self.federated_average(nodeparameters)?;
150 }
151 AveragingStrategy::Momentum { momentum } => {
152 self.momentum_average(nodeparameters, momentum)?;
153 }
154 AveragingStrategy::ExponentialMovingAverage { decay } => {
155 self.ema_average(nodeparameters, decay)?;
156 }
157 }
158
159 Ok(())
160 }
161
162 fn arithmetic_average(&mut self, nodeparameters: &[(usize, Vec<Array<A, D>>)]) -> Result<()> {
164 for param in &mut self.averaged_params {
166 param.fill(A::zero());
167 }
168
169 let numnodes = A::from(nodeparameters.len()).expect("unwrap failed");
170
171 for (_node_id, params) in nodeparameters {
173 for (avg_param, param) in self.averaged_params.iter_mut().zip(params.iter()) {
174 Zip::from(avg_param).and(param).for_each(|avg, &p| {
175 *avg = *avg + p;
176 });
177 }
178 }
179
180 for param in &mut self.averaged_params {
182 param.mapv_inplace(|x| x / numnodes);
183 }
184
185 Ok(())
186 }
187
188 fn weighted_average(&mut self, nodeparameters: &[(usize, Vec<Array<A, D>>)]) -> Result<()> {
190 for param in &mut self.averaged_params {
192 param.fill(A::zero());
193 }
194
195 let total_weight: A = nodeparameters
197 .iter()
198 .map(|(nodeid, _)| self.node_weights.get(nodeid).copied().unwrap_or(A::zero()))
199 .fold(A::zero(), |acc, w| acc + w);
200
201 if total_weight <= A::zero() {
202 return Err(OptimError::InvalidConfig(
203 "Total node weights must be > 0".to_string(),
204 ));
205 }
206
207 for (nodeid, params) in nodeparameters {
209 let weight = self.node_weights.get(nodeid).copied().unwrap_or(A::zero()) / total_weight;
210
211 for (avg_param, param) in self.averaged_params.iter_mut().zip(params.iter()) {
212 Zip::from(avg_param).and(param).for_each(|avg, &p| {
213 *avg = *avg + weight * p;
214 });
215 }
216 }
217
218 Ok(())
219 }
220
221 fn federated_average(&mut self, nodeparameters: &[(usize, Vec<Array<A, D>>)]) -> Result<()> {
223 self.weighted_average(nodeparameters)
226 }
227
228 fn momentum_average(
230 &mut self,
231 nodeparameters: &[(usize, Vec<Array<A, D>>)],
232 momentum: f64,
233 ) -> Result<()> {
234 let momentum_factor = A::from(momentum).expect("unwrap failed");
235 let one_minus_momentum = A::one() - momentum_factor;
236
237 let mut current_average: Vec<Array<A, D>> = self
239 .averaged_params
240 .iter()
241 .map(|param| Array::zeros(param.raw_dim()))
242 .collect();
243
244 let numnodes = A::from(nodeparameters.len()).expect("unwrap failed");
245 for (_node_id, params) in nodeparameters {
246 for (avg_param, param) in current_average.iter_mut().zip(params.iter()) {
247 Zip::from(avg_param).and(param).for_each(|avg, &p| {
248 *avg = *avg + p / numnodes;
249 });
250 }
251 }
252
253 if let Some(ref mut momentum_buf) = self.momentum_buffer {
255 for ((avg_param, current_param), momentum_param) in self
256 .averaged_params
257 .iter_mut()
258 .zip(current_average.iter())
259 .zip(momentum_buf.iter_mut())
260 {
261 Zip::from(&mut *momentum_param)
263 .and(current_param)
264 .for_each(|mom, &curr| {
265 *mom = momentum_factor * *mom + one_minus_momentum * curr;
266 });
267
268 avg_param.assign(&*momentum_param);
270 }
271 }
272
273 Ok(())
274 }
275
276 fn ema_average(
278 &mut self,
279 nodeparameters: &[(usize, Vec<Array<A, D>>)],
280 decay: f64,
281 ) -> Result<()> {
282 let decay_factor = A::from(decay).expect("unwrap failed");
283 let one_minus_decay = A::one() - decay_factor;
284
285 let mut current_average: Vec<Array<A, D>> = self
287 .averaged_params
288 .iter()
289 .map(|param| Array::zeros(param.raw_dim()))
290 .collect();
291
292 let numnodes = A::from(nodeparameters.len()).expect("unwrap failed");
293 for (_node_id, params) in nodeparameters {
294 for (avg_param, param) in current_average.iter_mut().zip(params.iter()) {
295 Zip::from(avg_param).and(param).for_each(|avg, &p| {
296 *avg = *avg + p / numnodes;
297 });
298 }
299 }
300
301 for (avg_param, current_param) in
303 self.averaged_params.iter_mut().zip(current_average.iter())
304 {
305 Zip::from(avg_param)
306 .and(current_param)
307 .for_each(|avg, &curr| {
308 *avg = decay_factor * *avg + one_minus_decay * curr;
309 });
310 }
311
312 Ok(())
313 }
314
315 pub fn get_averaged_parameters(&self) -> &[Array<A, D>] {
317 &self.averaged_params
318 }
319
320 pub fn get_averaged_parameters_cloned(&self) -> Vec<Array<A, D>> {
322 self.averaged_params.clone()
323 }
324
325 pub fn reset(&mut self) {
327 self.step_count = 0;
328 for param in &mut self.averaged_params {
329 param.fill(A::zero());
330 }
331 if let Some(ref mut momentum_buf) = self.momentum_buffer {
332 for buf in momentum_buf {
333 buf.fill(A::zero());
334 }
335 }
336 }
337
338 pub fn step_count(&self) -> usize {
340 self.step_count
341 }
342
343 pub fn numnodes(&self) -> usize {
345 self.numnodes
346 }
347
348 pub fn strategy(&self) -> AveragingStrategy {
350 self.strategy
351 }
352
353 pub fn is_initialized(&self) -> bool {
355 self.initialized
356 }
357}
358
/// Central server that collects per-node parameter updates and publishes an
/// aggregated global model.
#[derive(Debug)]
pub struct ParameterServer<A: Float, D: Dimension> {
    /// Underlying averaging engine.
    averager: ParameterAverager<A, D>,
    /// Most recently aggregated global parameters.
    global_parameters: Vec<Array<A, D>>,
    /// Cumulative number of updates submitted per node (across all rounds).
    update_counts: HashMap<usize, usize>,
    /// Number of distinct pending updates that triggers aggregation.
    expected_updates_per_round: usize,
    /// Number of completed aggregation rounds.
    current_round: usize,
    /// Updates received in the current round, keyed by node ID
    /// (a resubmission overwrites the node's previous pending update).
    pending_updates: HashMap<usize, Vec<Array<A, D>>>,
}
375
impl<A: Float + ScalarOperand + Debug + Send + Sync, D: Dimension + Send + Sync>
    ParameterServer<A, D>
{
    /// Creates a server that aggregates once `expected_updates_per_round`
    /// distinct nodes have pending submissions.
    pub fn new(
        strategy: AveragingStrategy,
        numnodes: usize,
        expected_updates_per_round: usize,
    ) -> Self {
        Self {
            averager: ParameterAverager::new(strategy, numnodes),
            global_parameters: Vec::new(),
            update_counts: HashMap::new(),
            expected_updates_per_round,
            current_round: 0,
            pending_updates: HashMap::new(),
        }
    }

    /// Seeds the averager and the global model, and zeroes per-node counters.
    ///
    /// # Errors
    /// Propagates `InvalidConfig` from the averager (e.g. double initialization).
    pub fn initialize(&mut self, initialparams: &[Array<A, D>]) -> Result<()> {
        self.averager.initialize(initialparams)?;
        self.global_parameters = initialparams.to_vec();

        for nodeid in 0..self.averager.numnodes() {
            self.update_counts.insert(nodeid, 0);
        }

        Ok(())
    }

    /// Accepts one node's parameters; returns `Ok(true)` when this submission
    /// triggered an aggregation.
    ///
    /// A node submitting twice in the same round overwrites its pending entry
    /// but still increments its cumulative update count.
    ///
    /// # Errors
    /// Returns `InvalidConfig` when `nodeid` is out of range.
    pub fn submit_update(&mut self, nodeid: usize, parameters: Vec<Array<A, D>>) -> Result<bool> {
        if nodeid >= self.averager.numnodes() {
            return Err(OptimError::InvalidConfig(format!(
                "Node ID {} exceeds number of nodes {}",
                nodeid,
                self.averager.numnodes()
            )));
        }

        self.pending_updates.insert(nodeid, parameters);
        *self.update_counts.entry(nodeid).or_insert(0) += 1;

        // Aggregate as soon as enough distinct nodes have reported.
        let ready_for_aggregation = self.pending_updates.len() >= self.expected_updates_per_round;

        if ready_for_aggregation {
            self.aggregate_and_update()?;
        }

        Ok(ready_for_aggregation)
    }

    /// Aggregates whatever is pending, even below the per-round threshold.
    /// No-op when nothing is pending.
    pub fn force_aggregation(&mut self) -> Result<()> {
        if !self.pending_updates.is_empty() {
            self.aggregate_and_update()?;
        }
        Ok(())
    }

    /// Drains pending updates into the averager and refreshes the global model.
    fn aggregate_and_update(&mut self) -> Result<()> {
        // Drain so the next round starts empty; HashMap iteration order is
        // arbitrary, but the averaging strategies sum contributions, so the
        // result is order-independent up to floating-point rounding.
        let node_params: Vec<(usize, Vec<Array<A, D>>)> = self.pending_updates.drain().collect();

        self.averager.average_parameters(&node_params)?;

        self.global_parameters = self.averager.get_averaged_parameters_cloned();

        self.current_round += 1;

        Ok(())
    }

    /// Borrowed view of the current global parameters.
    pub fn get_global_parameters(&self) -> &[Array<A, D>] {
        &self.global_parameters
    }

    /// Owned copy of the current global parameters.
    pub fn get_global_parameters_cloned(&self) -> Vec<Array<A, D>> {
        self.global_parameters.clone()
    }

    /// Number of completed aggregation rounds.
    pub fn current_round(&self) -> usize {
        self.current_round
    }

    /// Cumulative number of updates submitted by `nodeid` (0 if unknown).
    pub fn get_update_count(&self, nodeid: usize) -> usize {
        self.update_counts.get(&nodeid).copied().unwrap_or(0)
    }

    /// Number of nodes with an update pending in the current round.
    pub fn pending_updates_count(&self) -> usize {
        self.pending_updates.len()
    }

    /// Overrides the averaging weight of one node (see `ParameterAverager`).
    pub fn set_node_weight(&mut self, nodeid: usize, weight: A) -> Result<()> {
        self.averager.set_node_weight(nodeid, weight)
    }

    /// Clears rounds, counters, and pending updates; the averager keeps its
    /// shapes/weights but its values are zeroed. Counters are re-seeded to 0.
    pub fn reset(&mut self) {
        self.averager.reset();
        self.update_counts.clear();
        self.pending_updates.clear();
        self.current_round = 0;

        for nodeid in 0..self.averager.numnodes() {
            self.update_counts.insert(nodeid, 0);
        }
    }
}
499
/// High-level driver for synchronous distributed training rounds, tracking
/// convergence on top of a `ParameterServer`.
#[derive(Debug)]
pub struct DistributedCoordinator<A: Float, D: Dimension> {
    /// Server holding and aggregating the global model.
    parameter_server: ParameterServer<A, D>,
    /// Number of completed communication rounds.
    communication_rounds: usize,
    /// Relative parameter change below which training counts as converged.
    convergence_threshold: A,
    /// Hard cap on communication rounds.
    max_rounds: usize,
    /// Per-round convergence statistics.
    training_stats: TrainingStats<A>,
}
514
515impl<A: Float + ScalarOperand + Debug + Send + Sync, D: Dimension + Send + Sync>
516 DistributedCoordinator<A, D>
517{
518 pub fn new(
520 strategy: AveragingStrategy,
521 numnodes: usize,
522 expected_updates_per_round: usize,
523 max_rounds: usize,
524 ) -> Self {
525 Self {
526 parameter_server: ParameterServer::new(strategy, numnodes, expected_updates_per_round),
527 communication_rounds: 0,
528 convergence_threshold: A::from(1e-6).expect("unwrap failed"),
529 max_rounds,
530 training_stats: TrainingStats::new(),
531 }
532 }
533
534 pub fn initialize(&mut self, initialparams: &[Array<A, D>]) -> Result<()> {
536 self.parameter_server.initialize(initialparams)?;
537 self.training_stats
538 .record_round(0, A::zero(), initialparams);
539 Ok(())
540 }
541
542 pub fn communication_round(
544 &mut self,
545 node_updates: Vec<(usize, Vec<Array<A, D>>)>,
546 ) -> Result<CommunicationResult<A, D>> {
547 let mut aggregated = false;
548
549 for (nodeid, params) in node_updates {
551 aggregated = self.parameter_server.submit_update(nodeid, params)? || aggregated;
552 }
553
554 if !aggregated {
556 self.parameter_server.force_aggregation()?;
557 aggregated = true;
558 }
559
560 if aggregated {
561 self.communication_rounds += 1;
562
563 let currentparams = self.parameter_server.get_global_parameters();
565 let convergence_metric = self.compute_convergence_metric(currentparams);
566
567 self.training_stats.record_round(
568 self.communication_rounds,
569 convergence_metric,
570 currentparams,
571 );
572
573 let converged = convergence_metric < self.convergence_threshold;
574 let max_rounds_reached = self.communication_rounds >= self.max_rounds;
575
576 Ok(CommunicationResult {
577 round: self.communication_rounds,
578 global_parameters: self.parameter_server.get_global_parameters_cloned(),
579 converged,
580 should_continue: !converged && !max_rounds_reached,
581 convergence_metric,
582 stats: self.training_stats.clone(),
583 })
584 } else {
585 Ok(CommunicationResult {
586 round: self.communication_rounds,
587 global_parameters: self.parameter_server.get_global_parameters_cloned(),
588 converged: false,
589 should_continue: true,
590 convergence_metric: A::infinity(),
591 stats: self.training_stats.clone(),
592 })
593 }
594 }
595
596 pub fn set_convergence_threshold(&mut self, threshold: A) {
598 self.convergence_threshold = threshold;
599 }
600
601 pub fn parameter_server(&self) -> &ParameterServer<A, D> {
603 &self.parameter_server
604 }
605
606 pub fn parameter_server_mut(&mut self) -> &mut ParameterServer<A, D> {
608 &mut self.parameter_server
609 }
610
611 fn compute_convergence_metric(&self, currentparams: &[Array<A, D>]) -> A {
613 if let Some(prev_params) = self.training_stats.get_previous_parameters::<D>() {
614 let mut total_change = A::zero();
615 let mut total_norm = A::zero();
616
617 for (curr, prev) in currentparams.iter().zip(prev_params.iter()) {
618 for (&c, &p) in curr.iter().zip(prev.iter()) {
619 let diff = c - p;
620 total_change = total_change + diff * diff;
621 total_norm = total_norm + c * c;
622 }
623 }
624
625 if total_norm > A::zero() {
626 (total_change / total_norm).sqrt()
627 } else {
628 A::zero()
629 }
630 } else {
631 A::infinity()
632 }
633 }
634}
635
/// Outcome of one coordinator communication round.
#[derive(Debug, Clone)]
pub struct CommunicationResult<A: Float, D: Dimension> {
    /// Round number this result belongs to.
    pub round: usize,
    /// Snapshot of the global parameters after aggregation.
    pub global_parameters: Vec<Array<A, D>>,
    /// Whether the convergence metric fell below the threshold.
    pub converged: bool,
    /// False once converged or the maximum round count is reached.
    pub should_continue: bool,
    /// Relative parameter change this round (infinity when unavailable).
    pub convergence_metric: A,
    /// Copy of the coordinator's statistics up to and including this round.
    pub stats: TrainingStats<A>,
}
652
/// Per-round statistics gathered during distributed training.
#[derive(Debug, Clone)]
pub struct TrainingStats<A: Float> {
    /// Convergence metric recorded for each round, oldest first.
    convergence_history: Vec<A>,
    /// Round numbers matching `convergence_history` entries.
    round_times: Vec<usize>,
    /// Placeholder snapshot of the previous round's parameters.
    /// NOTE(review): only a zeroed byte vector is stored (see `record_round`);
    /// real parameter values are never captured.
    previous_parameters: Option<Vec<u8>>,
}
663
impl<A: Float + Send + Sync> TrainingStats<A> {
    /// Creates an empty statistics record.
    pub fn new() -> Self {
        Self {
            convergence_history: Vec::new(),
            round_times: Vec::new(),
            previous_parameters: None,
        }
    }

    /// Records the outcome of one communication round.
    ///
    /// NOTE(review): the parameter snapshot is a placeholder — only a zeroed
    /// byte vector of length `parameters.len()` is stored, not the actual
    /// values, so previous parameters cannot be reconstructed later.
    pub fn record_round<D: Dimension>(
        &mut self,
        round: usize,
        convergence_metric: A,
        parameters: &[Array<A, D>],
    ) {
        self.convergence_history.push(convergence_metric);
        self.round_times.push(round);

        self.previous_parameters = Some(vec![0u8; parameters.len()]);
    }

    /// Convergence metric of every recorded round, oldest first.
    pub fn convergence_history(&self) -> &[A] {
        &self.convergence_history
    }

    /// Convergence metric of the most recent round, if any.
    pub fn latest_convergence(&self) -> Option<A> {
        self.convergence_history.last().copied()
    }

    /// Number of rounds recorded so far.
    pub fn num_rounds(&self) -> usize {
        self.round_times.len()
    }

    /// Previous round's parameters.
    ///
    /// NOTE(review): stub — always returns `None` (see `record_round`), which
    /// forces convergence metrics derived from it to infinity.
    fn get_previous_parameters<D: Dimension>(&self) -> Option<Vec<Array<A, D>>> {
        None
    }
}
710
711impl<A: Float + Send + Sync> Default for TrainingStats<A> {
712 fn default() -> Self {
713 Self::new()
714 }
715}
716
/// Gradient compression scheme for reducing communication volume.
#[derive(Debug, Clone, PartialEq)]
pub enum CompressionStrategy {
    /// No compression; raw little-endian `f64` serialization.
    None,
    /// Keep the `k` largest-magnitude entries per gradient array.
    TopK {
        /// Number of entries retained per array.
        k: usize,
    },
    /// Keep `k` deterministically pseudo-randomly chosen entries per array.
    RandomK {
        /// Number of entries retained per array.
        k: usize,
    },
    /// Keep entries whose magnitude strictly exceeds `threshold`.
    Threshold {
        /// Magnitude cutoff.
        threshold: f64,
    },
    /// Uniform min/max quantization to `bits` bits per value.
    Quantization {
        /// Bit width of each quantized value.
        bits: u8,
    },
    /// Wrap a base strategy with error feedback: compression residuals are
    /// accumulated and re-added to subsequent gradients.
    ErrorFeedback {
        /// Strategy applied to the error-compensated gradients.
        base_strategy: Box<CompressionStrategy>,
        /// NOTE(review): currently ignored by `GradientCompressor::compress`
        /// — confirm intended semantics.
        error_compensation: bool,
    },
    /// Clamp gradients into `[-clip_value, clip_value]` before applying the
    /// base strategy.
    ClippedCompression {
        /// Strategy applied after clipping.
        base_strategy: Box<CompressionStrategy>,
        /// Symmetric clipping bound.
        clip_value: f64,
    },
}
757
/// Serialized, compressed gradients plus everything needed to decompress them.
#[derive(Debug, Clone)]
pub struct CompressedGradient<A: Float> {
    /// Strategy-specific wire bytes.
    pub data: Vec<u8>,
    /// Strategy, compression ratio, and auxiliary decompression data.
    pub metadata: CompressionMetadata<A>,
    /// Original shape of each gradient array, in submission order.
    pub shapes: Vec<Vec<usize>>,
}
768
/// Bookkeeping describing how a `CompressedGradient` was produced.
#[derive(Debug, Clone)]
pub struct CompressionMetadata<A: Float> {
    /// Strategy that produced the payload (used to pick the decompressor).
    pub strategy: CompressionStrategy,
    /// Compressed size divided by uncompressed size (8 bytes per element).
    pub compression_ratio: f64,
    /// Number of values actually stored.
    pub nnz_count: usize,
    /// Per-array quantization scales (quantization strategy only).
    pub scale_factors: Vec<A>,
    /// Reserved for strategy-specific extras; currently always empty.
    pub extra_data: Vec<u8>,
}
783
/// Applies a `CompressionStrategy` to gradients, optionally with
/// error-feedback state, and tracks compression statistics.
#[derive(Debug)]
pub struct GradientCompressor<A: Float, D: Dimension> {
    /// Configured compression scheme.
    strategy: CompressionStrategy,
    /// Accumulated compression residuals (one array per gradient); present
    /// only after `initialize_error_state` is called.
    error_state: Option<Vec<Array<A, D>>>,
    /// Running compression-size statistics.
    stats: CompressionStats,
}
794
795impl<A: Float + ScalarOperand + Debug + Send + Sync, D: Dimension + Send + Sync>
796 GradientCompressor<A, D>
797{
798 pub fn new(strategy: CompressionStrategy) -> Self {
800 Self {
801 strategy,
802 error_state: None,
803 stats: CompressionStats::new(),
804 }
805 }
806
807 pub fn initialize_error_state(&mut self, gradientshapes: &[Array<A, D>]) {
809 self.error_state = Some(
810 gradientshapes
811 .iter()
812 .map(|g| Array::zeros(g.raw_dim()))
813 .collect(),
814 );
815 }
816
817 pub fn compress(&mut self, gradients: &[Array<A, D>]) -> Result<CompressedGradient<A>> {
819 let mut working_gradients: Vec<Array<A, D>> =
821 if let Some(ref mut error_state) = self.error_state {
822 gradients
823 .iter()
824 .zip(error_state.iter())
825 .map(|(grad, error)| grad + error)
826 .collect()
827 } else {
828 gradients.to_vec()
829 };
830
831 let (compressed_data, metadata) = match &self.strategy {
832 CompressionStrategy::None => self.compress_none(&working_gradients)?,
833 CompressionStrategy::TopK { k } => self.compress_topk(&working_gradients, *k)?,
834 CompressionStrategy::RandomK { k } => self.compress_randomk(&working_gradients, *k)?,
835 CompressionStrategy::Threshold { threshold } => self.compress_threshold(
836 &working_gradients,
837 A::from(*threshold).expect("unwrap failed"),
838 )?,
839 CompressionStrategy::Quantization { bits } => {
840 self.compress_quantization(&working_gradients, *bits)?
841 }
842 CompressionStrategy::ErrorFeedback { base_strategy, .. } => {
843 let mut temp_compressor = GradientCompressor::new((**base_strategy).clone());
845 let compressed = temp_compressor.compress(&working_gradients)?;
846 let decompressed = temp_compressor.decompress(&compressed)?;
847
848 if let Some(ref mut error_state) = self.error_state {
850 for ((original, decompressed), error) in gradients
851 .iter()
852 .zip(decompressed.iter())
853 .zip(error_state.iter_mut())
854 {
855 *error = original - decompressed;
856 }
857 }
858
859 (compressed.data, compressed.metadata)
860 }
861 CompressionStrategy::ClippedCompression {
862 base_strategy,
863 clip_value,
864 } => {
865 let clip_val = A::from(*clip_value).expect("unwrap failed");
867 for grad in &mut working_gradients {
868 grad.mapv_inplace(|x| {
869 if x > clip_val {
870 clip_val
871 } else if x < -clip_val {
872 -clip_val
873 } else {
874 x
875 }
876 });
877 }
878
879 let mut temp_compressor = GradientCompressor::new((**base_strategy).clone());
881 let compressed = temp_compressor.compress(&working_gradients)?;
882 (compressed.data, compressed.metadata)
883 }
884 };
885
886 let shapes = gradients.iter().map(|g| g.shape().to_vec()).collect();
888
889 let result = CompressedGradient {
890 data: compressed_data,
891 metadata,
892 shapes,
893 };
894
895 let original_size = self.calculate_size(gradients);
897 let compressed_size = result.data.len();
898 self.stats
899 .record_compression(original_size, compressed_size);
900
901 Ok(result)
902 }
903
904 pub fn decompress(&self, compressed: &CompressedGradient<A>) -> Result<Vec<Array<A, D>>> {
906 match &compressed.metadata.strategy {
907 CompressionStrategy::None => self.decompress_none(compressed),
908 CompressionStrategy::TopK { .. } => self.decompress_sparse(compressed),
909 CompressionStrategy::RandomK { .. } => self.decompress_sparse(compressed),
910 CompressionStrategy::Threshold { .. } => self.decompress_sparse(compressed),
911 CompressionStrategy::Quantization { bits } => {
912 self.decompress_quantization(compressed, *bits)
913 }
914 CompressionStrategy::ErrorFeedback { base_strategy, .. } => {
915 let temp_compressor = GradientCompressor::new((**base_strategy).clone());
916 temp_compressor.decompress(compressed)
917 }
918 CompressionStrategy::ClippedCompression { base_strategy, .. } => {
919 let temp_compressor = GradientCompressor::new((**base_strategy).clone());
920 temp_compressor.decompress(compressed)
921 }
922 }
923 }
924
925 fn compress_none(
927 &self,
928 gradients: &[Array<A, D>],
929 ) -> Result<(Vec<u8>, CompressionMetadata<A>)> {
930 let mut data = Vec::new();
931
932 for grad in gradients {
934 for &val in grad.iter() {
935 data.extend_from_slice(&val.to_f64().expect("unwrap failed").to_le_bytes());
936 }
937 }
938
939 let metadata = CompressionMetadata {
940 strategy: CompressionStrategy::None,
941 compression_ratio: 1.0,
942 nnz_count: gradients.iter().map(|g| g.len()).sum(),
943 scale_factors: Vec::new(),
944 extra_data: Vec::new(),
945 };
946
947 Ok((data, metadata))
948 }
949
950 fn compress_topk(
952 &self,
953 gradients: &[Array<A, D>],
954 k: usize,
955 ) -> Result<(Vec<u8>, CompressionMetadata<A>)> {
956 let mut indices = Vec::new();
957 let mut values = Vec::new();
958 let mut total_elements = 0;
959
960 for (grad_idx, grad) in gradients.iter().enumerate() {
961 total_elements += grad.len();
962
963 let mut value_indices: Vec<(A, usize)> = grad
965 .iter()
966 .enumerate()
967 .map(|(i, &val)| (val.abs(), i))
968 .collect();
969
970 value_indices.sort_by(|a, b| b.0.partial_cmp(&a.0).expect("unwrap failed"));
972
973 let k_local = k.min(value_indices.len());
975 for (_, orig_idx) in value_indices.iter().take(k_local) {
976 indices.push((grad_idx as u32, *orig_idx as u32));
977 values.push(grad.iter().nth(*orig_idx).copied().expect("unwrap failed"));
978 }
979 }
980
981 let mut data = Vec::new();
983
984 data.extend_from_slice(&(indices.len() as u32).to_le_bytes());
986
987 for ((grad_idx, elem_idx), value) in indices.iter().zip(values.iter()) {
989 data.extend_from_slice(&grad_idx.to_le_bytes());
990 data.extend_from_slice(&elem_idx.to_le_bytes());
991 data.extend_from_slice(&value.to_f64().expect("unwrap failed").to_le_bytes());
992 }
993
994 let metadata = CompressionMetadata {
995 strategy: CompressionStrategy::TopK { k },
996 compression_ratio: data.len() as f64 / (total_elements * 8) as f64,
997 nnz_count: indices.len(),
998 scale_factors: Vec::new(),
999 extra_data: Vec::new(),
1000 };
1001
1002 Ok((data, metadata))
1003 }
1004
1005 fn compress_randomk(
1007 &self,
1008 gradients: &[Array<A, D>],
1009 k: usize,
1010 ) -> Result<(Vec<u8>, CompressionMetadata<A>)> {
1011 let mut indices = Vec::new();
1012 let mut values = Vec::new();
1013 let mut total_elements = 0;
1014
1015 for (grad_idx, grad) in gradients.iter().enumerate() {
1016 total_elements += grad.len();
1017
1018 let k_local = k.min(grad.len());
1020 let mut selected_indices: Vec<usize> = (0..grad.len()).collect();
1021
1022 for i in 0..k_local {
1024 let swap_idx = i + ((grad_idx + i) % (grad.len() - i));
1025 selected_indices.swap(i, swap_idx);
1026 }
1027
1028 for &idx in selected_indices.iter().take(k_local) {
1029 indices.push((grad_idx as u32, idx as u32));
1030 values.push(grad.iter().nth(idx).copied().expect("unwrap failed"));
1031 }
1032 }
1033
1034 let mut data = Vec::new();
1036 data.extend_from_slice(&(indices.len() as u32).to_le_bytes());
1037
1038 for ((grad_idx, elem_idx), value) in indices.iter().zip(values.iter()) {
1039 data.extend_from_slice(&grad_idx.to_le_bytes());
1040 data.extend_from_slice(&elem_idx.to_le_bytes());
1041 data.extend_from_slice(&value.to_f64().expect("unwrap failed").to_le_bytes());
1042 }
1043
1044 let metadata = CompressionMetadata {
1045 strategy: CompressionStrategy::RandomK { k },
1046 compression_ratio: data.len() as f64 / (total_elements * 8) as f64,
1047 nnz_count: indices.len(),
1048 scale_factors: Vec::new(),
1049 extra_data: Vec::new(),
1050 };
1051
1052 Ok((data, metadata))
1053 }
1054
1055 fn compress_threshold(
1057 &self,
1058 gradients: &[Array<A, D>],
1059 threshold: A,
1060 ) -> Result<(Vec<u8>, CompressionMetadata<A>)> {
1061 let mut indices = Vec::new();
1062 let mut values = Vec::new();
1063 let mut total_elements = 0;
1064
1065 for (grad_idx, grad) in gradients.iter().enumerate() {
1066 total_elements += grad.len();
1067
1068 for (elem_idx, &val) in grad.iter().enumerate() {
1069 if val.abs() > threshold {
1070 indices.push((grad_idx as u32, elem_idx as u32));
1071 values.push(val);
1072 }
1073 }
1074 }
1075
1076 let mut data = Vec::new();
1078 data.extend_from_slice(&(indices.len() as u32).to_le_bytes());
1079
1080 for ((grad_idx, elem_idx), value) in indices.iter().zip(values.iter()) {
1081 data.extend_from_slice(&grad_idx.to_le_bytes());
1082 data.extend_from_slice(&elem_idx.to_le_bytes());
1083 data.extend_from_slice(&value.to_f64().expect("unwrap failed").to_le_bytes());
1084 }
1085
1086 let metadata = CompressionMetadata {
1087 strategy: CompressionStrategy::Threshold {
1088 threshold: threshold.to_f64().expect("unwrap failed"),
1089 },
1090 compression_ratio: data.len() as f64 / (total_elements * 8) as f64,
1091 nnz_count: indices.len(),
1092 scale_factors: Vec::new(),
1093 extra_data: Vec::new(),
1094 };
1095
1096 Ok((data, metadata))
1097 }
1098
1099 fn compress_quantization(
1101 &self,
1102 gradients: &[Array<A, D>],
1103 bits: u8,
1104 ) -> Result<(Vec<u8>, CompressionMetadata<A>)> {
1105 if bits > 32 {
1106 return Err(OptimError::InvalidConfig(
1107 "Quantization bits must be <= 32".to_string(),
1108 ));
1109 }
1110
1111 let mut data = Vec::new();
1112 let mut scale_factors = Vec::new();
1113 let levels = (1u64 << bits) - 1;
1114
1115 for grad in gradients {
1116 let min_val = grad.iter().fold(A::infinity(), |acc, &x| acc.min(x));
1118 let max_val = grad.iter().fold(A::neg_infinity(), |acc, &x| acc.max(x));
1119
1120 let range = max_val - min_val;
1121 let scale = if range > A::zero() {
1122 range / A::from(levels).expect("unwrap failed")
1123 } else {
1124 A::one()
1125 };
1126
1127 scale_factors.push(scale);
1128
1129 for &val in grad.iter() {
1131 let normalized = (val - min_val) / scale;
1132 let quantized = normalized.to_u64().expect("unwrap failed").min(levels) as u32;
1133
1134 match bits {
1136 1..=8 => data.push(quantized as u8),
1137 9..=16 => data.extend_from_slice(&(quantized as u16).to_le_bytes()),
1138 17..=32 => data.extend_from_slice(&quantized.to_le_bytes()),
1139 _ => unreachable!(),
1140 }
1141 }
1142
1143 data.extend_from_slice(&min_val.to_f64().expect("unwrap failed").to_le_bytes());
1145 }
1146
1147 let total_elements: usize = gradients.iter().map(|g| g.len()).sum();
1148 let metadata = CompressionMetadata {
1149 strategy: CompressionStrategy::Quantization { bits },
1150 compression_ratio: data.len() as f64 / (total_elements * 8) as f64,
1151 nnz_count: total_elements,
1152 scale_factors,
1153 extra_data: Vec::new(),
1154 };
1155
1156 Ok((data, metadata))
1157 }
1158
1159 fn decompress_none(&self, compressed: &CompressedGradient<A>) -> Result<Vec<Array<A, D>>> {
1161 let mut result = Vec::new();
1162 let mut data_offset = 0;
1163
1164 for shape in &compressed.shapes {
1165 let num_elements: usize = shape.iter().product();
1166 let mut values = Vec::with_capacity(num_elements);
1167
1168 for _ in 0..num_elements {
1169 if data_offset + 8 > compressed.data.len() {
1170 return Err(OptimError::InvalidConfig(
1171 "Insufficient data for decompression".to_string(),
1172 ));
1173 }
1174
1175 let bytes = &compressed.data[data_offset..data_offset + 8];
1176 let value = f64::from_le_bytes(bytes.try_into().expect("unwrap failed"));
1177 values.push(A::from(value).expect("unwrap failed"));
1178 data_offset += 8;
1179 }
1180
1181 let dynamic_array = Array::from_shape_vec(shape.as_slice(), values).map_err(|_| {
1183 OptimError::InvalidConfig("Invalid shape for reconstruction".to_string())
1184 })?;
1185 let array = dynamic_array.into_dimensionality::<D>().map_err(|_| {
1186 OptimError::InvalidConfig("Dimension conversion failed".to_string())
1187 })?;
1188 result.push(array);
1189 }
1190
1191 Ok(result)
1192 }
1193
    /// Decodes the shared sparse wire format (TopK/RandomK/Threshold) into
    /// dense arrays: zero-filled arrays with the stored values scattered back.
    ///
    /// Wire layout: `count: u32 LE`, then `count` 16-byte records of
    /// `(grad_idx: u32, elem_idx: u32, value: f64)`, all little-endian.
    fn decompress_sparse(&self, compressed: &CompressedGradient<A>) -> Result<Vec<Array<A, D>>> {
        let mut result = Vec::new();

        // Start from all-zero arrays of the recorded shapes.
        for shape in &compressed.shapes {
            let dynamic_array = Array::zeros(shape.as_slice());
            let array = dynamic_array.into_dimensionality::<D>().map_err(|_| {
                OptimError::InvalidConfig("Dimension conversion failed for zero array".to_string())
            })?;
            result.push(array);
        }

        // Need at least the 4-byte record count.
        if compressed.data.len() < 4 {
            return Err(OptimError::InvalidConfig(
                "Invalid compressed data format".to_string(),
            ));
        }

        let num_elements =
            u32::from_le_bytes(compressed.data[0..4].try_into().expect("unwrap failed")) as usize;
        let mut data_offset = 4;

        for _ in 0..num_elements {
            // Each record is 16 bytes: 4 (grad idx) + 4 (elem idx) + 8 (value).
            if data_offset + 16 > compressed.data.len() {
                return Err(OptimError::InvalidConfig(
                    "Insufficient data for sparse decompression".to_string(),
                ));
            }

            let grad_idx = u32::from_le_bytes(
                compressed.data[data_offset..data_offset + 4]
                    .try_into()
                    .expect("unwrap failed"),
            ) as usize;
            let elem_idx = u32::from_le_bytes(
                compressed.data[data_offset + 4..data_offset + 8]
                    .try_into()
                    .expect("unwrap failed"),
            ) as usize;
            let value_bytes = &compressed.data[data_offset + 8..data_offset + 16];
            let value = A::from(f64::from_le_bytes(
                value_bytes.try_into().expect("unwrap failed"),
            ))
            .expect("unwrap failed");

            data_offset += 16;

            if grad_idx >= result.len() {
                return Err(OptimError::InvalidConfig(
                    "Invalid gradient index in compressed data".to_string(),
                ));
            }

            // `nth` walks the array in logical iteration order to `elem_idx`;
            // a missing position means the index was out of bounds.
            if let Some(elem) = result[grad_idx].iter_mut().nth(elem_idx) {
                *elem = value;
            } else {
                return Err(OptimError::InvalidConfig(
                    "Invalid element index in compressed data".to_string(),
                ));
            }
        }

        Ok(result)
    }
1261
    /// Inverse of `compress_quantization`.
    ///
    /// Per array the wire layout is: one quantized integer per element (width
    /// chosen from `bits`), then the array's minimum as a little-endian `f64`.
    /// Values are restored as `min + q * scale`, with scales taken from
    /// `metadata.scale_factors`.
    fn decompress_quantization(
        &self,
        compressed: &CompressedGradient<A>,
        bits: u8,
    ) -> Result<Vec<Array<A, D>>> {
        let mut result = Vec::new();
        let mut data_offset = 0;
        // Kept for symmetry with compression; not needed when dequantizing.
        // NOTE(review): `bits >= 64` would overflow this shift; upstream
        // validation of `bits` is assumed.
        let _levels = (1u64 << bits) - 1;

        for (grad_idx, shape) in compressed.shapes.iter().enumerate() {
            let num_elements: usize = shape.iter().product();
            let mut values = Vec::with_capacity(num_elements);

            // Read `num_elements` quantized integers in the width implied by `bits`.
            for _ in 0..num_elements {
                let quantized = match bits {
                    1..=8 => {
                        if data_offset >= compressed.data.len() {
                            return Err(OptimError::InvalidConfig(
                                "Insufficient quantized data".to_string(),
                            ));
                        }
                        let val = compressed.data[data_offset] as u32;
                        data_offset += 1;
                        val
                    }
                    9..=16 => {
                        if data_offset + 2 > compressed.data.len() {
                            return Err(OptimError::InvalidConfig(
                                "Insufficient quantized data".to_string(),
                            ));
                        }
                        let val = u16::from_le_bytes(
                            compressed.data[data_offset..data_offset + 2]
                                .try_into()
                                .expect("unwrap failed"),
                        ) as u32;
                        data_offset += 2;
                        val
                    }
                    17..=32 => {
                        if data_offset + 4 > compressed.data.len() {
                            return Err(OptimError::InvalidConfig(
                                "Insufficient quantized data".to_string(),
                            ));
                        }
                        let val = u32::from_le_bytes(
                            compressed.data[data_offset..data_offset + 4]
                                .try_into()
                                .expect("unwrap failed"),
                        );
                        data_offset += 4;
                        val
                    }
                    _ => {
                        return Err(OptimError::InvalidConfig(
                            "Invalid quantization bits".to_string(),
                        ))
                    }
                };

                values.push(quantized);
            }

            // Trailing per-array minimum written by the compressor.
            if data_offset + 8 > compressed.data.len() {
                return Err(OptimError::InvalidConfig(
                    "Missing min value for quantization".to_string(),
                ));
            }
            let min_bytes = &compressed.data[data_offset..data_offset + 8];
            let min_val = A::from(f64::from_le_bytes(
                min_bytes.try_into().expect("unwrap failed"),
            ))
            .expect("unwrap failed");
            data_offset += 8;

            // One scale factor per array, recorded during compression.
            let scale = if grad_idx < compressed.metadata.scale_factors.len() {
                compressed.metadata.scale_factors[grad_idx]
            } else {
                return Err(OptimError::InvalidConfig(
                    "Missing scale factor for quantization".to_string(),
                ));
            };

            // Dequantize: value = min + q * scale.
            let dequantized_values: Vec<A> = values
                .into_iter()
                .map(|q| min_val + A::from(q).expect("unwrap failed") * scale)
                .collect();

            let dynamic_array = Array::from_shape_vec(shape.as_slice(), dequantized_values)
                .map_err(|_| {
                    OptimError::InvalidConfig(
                        "Invalid shape for quantized reconstruction".to_string(),
                    )
                })?;
            let array = dynamic_array.into_dimensionality::<D>().map_err(|_| {
                OptimError::InvalidConfig(
                    "Dimension conversion failed for quantized array".to_string(),
                )
            })?;
            result.push(array);
        }

        Ok(result)
    }
1371
1372 fn calculate_size(&self, gradients: &[Array<A, D>]) -> usize {
1374 gradients
1375 .iter()
1376 .map(|g| g.len() * std::mem::size_of::<A>())
1377 .sum()
1378 }
1379
    /// Returns a read-only view of the accumulated compression statistics.
    pub fn stats(&self) -> &CompressionStats {
        &self.stats
    }
1384
    /// Resets all recorded compression statistics to their initial state.
    pub fn reset_stats(&mut self) {
        self.stats = CompressionStats::new();
    }
1389}
1390
/// Aggregate statistics about gradient compression activity.
#[derive(Debug, Clone)]
pub struct CompressionStats {
    /// Number of compression operations recorded so far.
    pub compressions_count: usize,
    /// Sum of original (uncompressed) sizes across all operations, in bytes.
    pub total_original_bytes: usize,
    /// Sum of compressed sizes across all operations, in bytes.
    pub total_compressed_bytes: usize,
    /// Byte-weighted overall ratio: total compressed / total original
    /// (despite the name, not a mean of per-operation ratios).
    pub average_compression_ratio: f64,
    /// Smallest (best) per-operation ratio seen; `f64::INFINITY` before any record.
    pub best_compression_ratio: f64,
    /// Largest (worst) per-operation ratio seen; `0.0` before any record.
    pub worst_compression_ratio: f64,
}
1407
1408impl CompressionStats {
1409 pub fn new() -> Self {
1411 Self {
1412 compressions_count: 0,
1413 total_original_bytes: 0,
1414 total_compressed_bytes: 0,
1415 average_compression_ratio: 0.0,
1416 best_compression_ratio: f64::INFINITY,
1417 worst_compression_ratio: 0.0,
1418 }
1419 }
1420
1421 pub fn record_compression(&mut self, original_bytes: usize, compressedbytes: usize) {
1423 self.compressions_count += 1;
1424 self.total_original_bytes += original_bytes;
1425 self.total_compressed_bytes += compressedbytes;
1426
1427 let ratio = if original_bytes > 0 {
1428 compressedbytes as f64 / original_bytes as f64
1429 } else {
1430 1.0
1431 };
1432
1433 self.best_compression_ratio = self.best_compression_ratio.min(ratio);
1434 self.worst_compression_ratio = self.worst_compression_ratio.max(ratio);
1435
1436 self.average_compression_ratio = if self.total_original_bytes > 0 {
1437 self.total_compressed_bytes as f64 / self.total_original_bytes as f64
1438 } else {
1439 0.0
1440 };
1441 }
1442
1443 pub fn overall_compression_ratio(&self) -> f64 {
1445 self.average_compression_ratio
1446 }
1447
1448 pub fn bandwidth_savings(&self) -> f64 {
1450 (1.0 - self.average_compression_ratio) * 100.0
1451 }
1452}
1453
1454impl Default for CompressionStats {
1455 fn default() -> Self {
1456 Self::new()
1457 }
1458}
1459
1460#[cfg(test)]
1461mod tests {
1462 use super::*;
1463 use approx::assert_relative_eq;
1464 use scirs2_core::ndarray::Array1;
1465
1466 #[test]
1467 fn test_arithmetic_averaging() {
1468 let mut averager: ParameterAverager<f64, scirs2_core::ndarray::Ix1> =
1469 ParameterAverager::new(AveragingStrategy::Arithmetic, 3);
1470
1471 let params1 = vec![Array1::from_vec(vec![1.0, 2.0])];
1472 let params2 = vec![Array1::from_vec(vec![3.0, 4.0])];
1473 let params3 = vec![Array1::from_vec(vec![5.0, 6.0])];
1474
1475 let nodeparameters = vec![(0, params1), (1, params2), (2, params3)];
1476
1477 averager
1478 .average_parameters(&nodeparameters)
1479 .expect("unwrap failed");
1480
1481 let result = averager.get_averaged_parameters();
1482 assert_relative_eq!(result[0][0], 3.0, epsilon = 1e-6); assert_relative_eq!(result[0][1], 4.0, epsilon = 1e-6); }
1485
1486 #[test]
1487 fn test_weighted_averaging() {
1488 let mut averager: ParameterAverager<f64, scirs2_core::ndarray::Ix1> =
1489 ParameterAverager::new(AveragingStrategy::WeightedByData, 2);
1490
1491 let params1 = vec![Array1::from_vec(vec![2.0])];
1493 let params2 = vec![Array1::from_vec(vec![6.0])];
1494 let nodeparameters = vec![(0, params1.clone()), (1, params2.clone())];
1495 averager.initialize(¶ms1).expect("unwrap failed");
1496
1497 averager.set_node_weight(0, 0.75).expect("unwrap failed"); averager.set_node_weight(1, 0.25).expect("unwrap failed"); averager
1502 .average_parameters(&nodeparameters)
1503 .expect("unwrap failed");
1504
1505 let result = averager.get_averaged_parameters();
1506 assert_relative_eq!(result[0][0], 3.0, epsilon = 1e-6);
1508 }
1509
1510 #[test]
1511 fn test_momentum_averaging() {
1512 let mut averager: ParameterAverager<f64, scirs2_core::ndarray::Ix1> =
1513 ParameterAverager::new(AveragingStrategy::Momentum { momentum: 0.9 }, 2);
1514
1515 let params1 = vec![Array1::from_vec(vec![1.0])];
1516 let params2 = vec![Array1::from_vec(vec![3.0])];
1517
1518 let node_parameters1 = vec![(0, params1.clone()), (1, params2.clone())];
1520 averager
1521 .average_parameters(&node_parameters1)
1522 .expect("unwrap failed");
1523
1524 let result1 = averager.get_averaged_parameters();
1525 assert!(result1[0][0] >= 0.0 && result1[0][0] <= 0.5);
1527
1528 for _ in 0..10 {
1530 let nodeparameters = vec![(0, params1.clone()), (1, params2.clone())];
1531 averager
1532 .average_parameters(&nodeparameters)
1533 .expect("unwrap failed");
1534 }
1535
1536 let final_result = averager.get_averaged_parameters();
1537 assert!(final_result[0][0] > 0.5 && final_result[0][0] < 2.5);
1540 }
1541
1542 #[test]
1543 fn test_parameter_server() {
1544 let mut server = ParameterServer::new(AveragingStrategy::Arithmetic, 2, 2);
1545
1546 let initialparams = vec![Array1::from_vec(vec![0.0, 0.0])];
1547 server.initialize(&initialparams).expect("unwrap failed");
1548
1549 let update1 = vec![Array1::from_vec(vec![1.0, 2.0])];
1551 let update2 = vec![Array1::from_vec(vec![3.0, 4.0])];
1552
1553 let ready1 = server.submit_update(0, update1).expect("unwrap failed");
1554 assert!(!ready1); let ready2 = server.submit_update(1, update2).expect("unwrap failed");
1557 assert!(ready2); let global_params = server.get_global_parameters();
1560 assert_relative_eq!(global_params[0][0], 2.0, epsilon = 1e-6); assert_relative_eq!(global_params[0][1], 3.0, epsilon = 1e-6); assert_eq!(server.current_round(), 1);
1564 }
1565
1566 #[test]
1567 fn test_distributed_coordinator() {
1568 let mut coordinator = DistributedCoordinator::new(
1569 AveragingStrategy::Arithmetic,
1570 2, 2, 10, );
1574
1575 let initialparams = vec![Array1::from_vec(vec![0.0])];
1576 coordinator
1577 .initialize(&initialparams)
1578 .expect("unwrap failed");
1579
1580 for round in 1..=3 {
1582 let update1 = vec![Array1::from_vec(vec![round as f64])];
1583 let update2 = vec![Array1::from_vec(vec![(round * 2) as f64])];
1584
1585 let node_updates = vec![(0, update1), (1, update2)];
1586
1587 let result = coordinator
1588 .communication_round(node_updates)
1589 .expect("unwrap failed");
1590
1591 assert_eq!(result.round, round);
1592 assert!(result.should_continue);
1593 assert!(!result.converged); assert!(result.global_parameters[0][0] > 0.0);
1597 }
1598 }
1599
1600 #[test]
1601 fn test_averaging_strategies() {
1602 let simple_strategies = vec![
1604 AveragingStrategy::Arithmetic,
1605 AveragingStrategy::WeightedByData,
1606 AveragingStrategy::Federated,
1607 ];
1608
1609 for strategy in simple_strategies {
1610 let mut averager: ParameterAverager<f64, scirs2_core::ndarray::Ix1> =
1611 ParameterAverager::new(strategy, 2);
1612
1613 let params1 = vec![Array1::from_vec(vec![1.0])];
1614 let params2 = vec![Array1::from_vec(vec![3.0])];
1615
1616 let nodeparameters = vec![(0, params1), (1, params2)];
1617
1618 averager
1619 .average_parameters(&nodeparameters)
1620 .expect("unwrap failed");
1621 let result = averager.get_averaged_parameters();
1622 assert!(result[0][0] >= 1.0 && result[0][0] <= 3.0);
1623 }
1624
1625 let stateful_strategies = vec![
1627 AveragingStrategy::Momentum { momentum: 0.9 },
1628 AveragingStrategy::ExponentialMovingAverage { decay: 0.9 },
1629 ];
1630
1631 for strategy in stateful_strategies {
1632 let mut averager: ParameterAverager<f64, scirs2_core::ndarray::Ix1> =
1633 ParameterAverager::new(strategy, 2);
1634
1635 let params1 = vec![Array1::from_vec(vec![1.0])];
1636 let params2 = vec![Array1::from_vec(vec![3.0])];
1637
1638 let nodeparameters = vec![(0, params1), (1, params2)];
1639
1640 averager
1641 .average_parameters(&nodeparameters)
1642 .expect("unwrap failed");
1643 let result = averager.get_averaged_parameters();
1644 assert!(result[0][0] >= 0.0 && result[0][0] <= 3.0);
1646 }
1647 }
1648
1649 #[test]
1650 fn test_node_weight_validation() {
1651 let mut averager: ParameterAverager<f64, scirs2_core::ndarray::Ix1> =
1652 ParameterAverager::new(AveragingStrategy::WeightedByData, 2);
1653
1654 assert!(averager.set_node_weight(0, 0.5).is_ok());
1656 assert!(averager.set_node_weight(1, 0.5).is_ok());
1657
1658 assert!(averager.set_node_weight(2, 0.5).is_err());
1660 }
1661
1662 #[test]
1663 fn test_parameter_dimension_validation() {
1664 let mut averager: ParameterAverager<f64, scirs2_core::ndarray::Ix1> =
1665 ParameterAverager::new(AveragingStrategy::Arithmetic, 2);
1666
1667 let params1 = vec![Array1::from_vec(vec![1.0, 2.0])];
1668 let params2 = vec![Array1::from_vec(vec![3.0])]; let nodeparameters = vec![(0, params1), (1, params2)];
1671
1672 let result = std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| {
1674 averager.average_parameters(&nodeparameters)
1675 }));
1676
1677 assert!(result.is_err() || (result.is_ok() && result.expect("unwrap failed").is_err()));
1679 }
1680
1681 #[test]
1682 fn test_training_stats() {
1683 let mut stats = TrainingStats::new();
1684
1685 assert_eq!(stats.num_rounds(), 0);
1686 assert!(stats.latest_convergence().is_none());
1687
1688 let params = vec![Array1::from_vec(vec![1.0])];
1689 stats.record_round(1, 0.5, ¶ms);
1690
1691 assert_eq!(stats.num_rounds(), 1);
1692 assert_eq!(stats.latest_convergence(), Some(0.5));
1693 assert_eq!(stats.convergence_history(), &[0.5]);
1694 }
1695
1696 #[test]
1697 fn test_gradient_compression_none() {
1698 let mut compressor = GradientCompressor::new(CompressionStrategy::None);
1699
1700 let gradients = vec![
1701 Array1::from_vec(vec![1.0, 2.0, 3.0]),
1702 Array1::from_vec(vec![4.0, 5.0]),
1703 ];
1704
1705 let compressed = compressor.compress(&gradients).expect("unwrap failed");
1706 assert_eq!(compressed.metadata.strategy, CompressionStrategy::None);
1707 assert_eq!(compressed.metadata.compression_ratio, 1.0);
1708
1709 let decompressed = compressor.decompress(&compressed).expect("unwrap failed");
1710 assert_eq!(decompressed.len(), 2);
1711 assert_eq!(
1712 decompressed[0].as_slice().expect("unwrap failed"),
1713 &[1.0, 2.0, 3.0]
1714 );
1715 assert_eq!(
1716 decompressed[1].as_slice().expect("unwrap failed"),
1717 &[4.0, 5.0]
1718 );
1719 }
1720
1721 #[test]
1722 fn test_gradient_compression_topk() {
1723 let mut compressor = GradientCompressor::new(CompressionStrategy::TopK { k: 2 });
1724
1725 let gradients = vec![Array1::from_vec(vec![0.1, 3.0, 0.2, 4.0, 0.05])];
1726
1727 let compressed = compressor.compress(&gradients).expect("unwrap failed");
1728 assert!(compressed.metadata.compression_ratio < 1.0);
1729 assert_eq!(compressed.metadata.nnz_count, 2); let decompressed = compressor.decompress(&compressed).expect("unwrap failed");
1732 assert_eq!(decompressed.len(), 1);
1733
1734 let result = &decompressed[0];
1736 assert_eq!(result[1], 3.0); assert_eq!(result[3], 4.0); assert_eq!(result[0], 0.0); assert_eq!(result[2], 0.0); assert_eq!(result[4], 0.0); }
1742
1743 #[test]
1744 fn test_gradient_compression_threshold() {
1745 let mut compressor =
1746 GradientCompressor::new(CompressionStrategy::Threshold { threshold: 1.0 });
1747
1748 let gradients = vec![Array1::from_vec(vec![0.5, 2.0, 0.8, 3.0, 0.3])];
1749
1750 let compressed = compressor.compress(&gradients).expect("unwrap failed");
1751 assert!(compressed.metadata.compression_ratio < 1.0);
1752 assert_eq!(compressed.metadata.nnz_count, 2); let decompressed = compressor.decompress(&compressed).expect("unwrap failed");
1755 let result = &decompressed[0];
1756
1757 assert_eq!(result[0], 0.0); assert_eq!(result[1], 2.0); assert_eq!(result[2], 0.0); assert_eq!(result[3], 3.0); assert_eq!(result[4], 0.0); }
1764
1765 #[test]
1766 fn test_gradient_compression_quantization() {
1767 let mut compressor = GradientCompressor::new(CompressionStrategy::Quantization { bits: 8 });
1768
1769 let gradients = vec![Array1::from_vec(vec![1.0, 2.0, 3.0, 4.0])];
1770
1771 let compressed = compressor.compress(&gradients).expect("unwrap failed");
1772 assert!(compressed.metadata.compression_ratio < 1.0); let decompressed = compressor.decompress(&compressed).expect("unwrap failed");
1775 let result = &decompressed[0];
1776
1777 assert!((result[0] - 1.0).abs() < 0.1);
1779 assert!((result[1] - 2.0).abs() < 0.1);
1780 assert!((result[2] - 3.0).abs() < 0.1);
1781 assert!((result[3] - 4.0).abs() < 0.1);
1782 }
1783
1784 #[test]
1785 fn test_gradient_compression_randomk() {
1786 let mut compressor = GradientCompressor::new(CompressionStrategy::RandomK { k: 3 });
1787
1788 let gradients = vec![Array1::from_vec(vec![
1790 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0,
1791 ])];
1792
1793 let compressed = compressor.compress(&gradients).expect("unwrap failed");
1794 assert!(compressed.metadata.compression_ratio < 1.0);
1796 assert_eq!(compressed.metadata.nnz_count, 3); let decompressed = compressor.decompress(&compressed).expect("unwrap failed");
1799 let result = &decompressed[0];
1800
1801 let non_zero_count = result.iter().filter(|&&x| x != 0.0).count();
1803 assert_eq!(non_zero_count, 3);
1804 }
1805
1806 #[test]
1807 fn test_gradient_compression_error_feedback() {
1808 let base_strategy = CompressionStrategy::TopK { k: 2 };
1809 let strategy = CompressionStrategy::ErrorFeedback {
1810 base_strategy: Box::new(base_strategy),
1811 error_compensation: true,
1812 };
1813
1814 let mut compressor = GradientCompressor::new(strategy);
1815
1816 let gradients = vec![Array1::from_vec(vec![1.0, 2.0, 3.0, 4.0])];
1817
1818 compressor.initialize_error_state(&gradients);
1820
1821 let compressed1 = compressor.compress(&gradients).expect("unwrap failed");
1823 let decompressed1 = compressor.decompress(&compressed1).expect("unwrap failed");
1824
1825 let compressed2 = compressor.compress(&gradients).expect("unwrap failed");
1827 let decompressed2 = compressor.decompress(&compressed2).expect("unwrap failed");
1828
1829 assert_eq!(decompressed1.len(), 1);
1831 assert_eq!(decompressed2.len(), 1);
1832 }
1833
1834 #[test]
1835 fn test_gradient_compression_clipped() {
1836 let base_strategy = CompressionStrategy::TopK { k: 3 };
1837 let strategy = CompressionStrategy::ClippedCompression {
1838 base_strategy: Box::new(base_strategy),
1839 clip_value: 2.5,
1840 };
1841
1842 let mut compressor = GradientCompressor::new(strategy);
1843
1844 let gradients = vec![Array1::from_vec(vec![1.0, 5.0, -3.0, 2.0])];
1845
1846 let compressed = compressor.compress(&gradients).expect("unwrap failed");
1847 let decompressed = compressor.decompress(&compressed).expect("unwrap failed");
1848
1849 let result = &decompressed[0];
1850
1851 for &val in result.iter() {
1853 if val != 0.0 {
1854 assert!((-2.5..=2.5).contains(&val));
1856 }
1857 }
1858 }
1859
1860 #[test]
1861 fn test_compression_stats() {
1862 let mut stats = CompressionStats::new();
1863
1864 assert_eq!(stats.compressions_count, 0);
1865 assert_eq!(stats.overall_compression_ratio(), 0.0);
1866
1867 stats.record_compression(1000, 500); assert_eq!(stats.compressions_count, 1);
1870 assert_relative_eq!(stats.overall_compression_ratio(), 0.5, epsilon = 1e-6);
1871 assert_relative_eq!(stats.bandwidth_savings(), 50.0, epsilon = 1e-6);
1872
1873 stats.record_compression(1000, 250); assert_eq!(stats.compressions_count, 2);
1875 assert_relative_eq!(stats.overall_compression_ratio(), 0.375, epsilon = 1e-6); assert_relative_eq!(stats.bandwidth_savings(), 62.5, epsilon = 1e-6);
1877
1878 assert_relative_eq!(stats.best_compression_ratio, 0.25, epsilon = 1e-6);
1879 assert_relative_eq!(stats.worst_compression_ratio, 0.5, epsilon = 1e-6);
1880 }
1881
1882 #[test]
1883 fn test_compression_roundtrip() {
1884 let strategies = vec![
1885 CompressionStrategy::None,
1886 CompressionStrategy::TopK { k: 2 },
1887 CompressionStrategy::RandomK { k: 2 },
1888 CompressionStrategy::Threshold { threshold: 1.5 },
1889 CompressionStrategy::Quantization { bits: 4 },
1890 ];
1891
1892 let gradients = vec![
1893 Array1::from_vec(vec![1.0, 2.5, 0.5, 3.0]),
1894 Array1::from_vec(vec![0.1, 4.0]),
1895 ];
1896
1897 for strategy in strategies {
1898 let mut compressor = GradientCompressor::new(strategy.clone());
1899
1900 let compressed = compressor.compress(&gradients).expect("unwrap failed");
1901 let decompressed = compressor.decompress(&compressed).expect("unwrap failed");
1902
1903 assert_eq!(decompressed.len(), gradients.len());
1905
1906 for (orig, decomp) in gradients.iter().zip(decompressed.iter()) {
1908 assert_eq!(orig.shape(), decomp.shape());
1909 }
1910
1911 match strategy {
1913 CompressionStrategy::None => {
1914 for (orig, decomp) in gradients.iter().zip(decompressed.iter()) {
1915 for (&o, &d) in orig.iter().zip(decomp.iter()) {
1916 assert_relative_eq!(o, d, epsilon = 1e-10);
1917 }
1918 }
1919 }
1920 _ => {
1921 for decomp in &decompressed {
1923 assert!(decomp.iter().all(|&x| x.is_finite()));
1924 }
1925 }
1926 }
1927 }
1928 }
1929
1930 #[test]
1931 fn test_compression_invalid_configs() {
1932 let strategy = CompressionStrategy::Quantization { bits: 64 };
1934 let mut compressor = GradientCompressor::new(strategy);
1935
1936 let gradients = vec![Array1::from_vec(vec![1.0, 2.0])];
1937 assert!(compressor.compress(&gradients).is_err());
1938
1939 let valid_compressor: GradientCompressor<f64, scirs2_core::ndarray::Ix1> =
1941 GradientCompressor::new(CompressionStrategy::None);
1942 let invalid_compressed = CompressedGradient {
1943 data: vec![1, 2, 3], metadata: CompressionMetadata {
1945 strategy: CompressionStrategy::None,
1946 compression_ratio: 1.0,
1947 nnz_count: 1,
1948 scale_factors: vec![],
1949 extra_data: vec![],
1950 },
1951 shapes: vec![vec![2]],
1952 };
1953
1954 assert!(valid_compressor.decompress(&invalid_compressed).is_err());
1955 }
1956
1957 #[test]
1958 fn test_distributed_with_compression() {
1959 let mut server = ParameterServer::new(AveragingStrategy::Arithmetic, 2, 2);
1961 let initialparams = vec![Array1::from_vec(vec![0.0, 0.0])];
1962 server.initialize(&initialparams).expect("unwrap failed");
1963
1964 let mut compressor = GradientCompressor::new(CompressionStrategy::TopK { k: 1 });
1965
1966 let gradients1 = vec![Array1::from_vec(vec![1.0, 3.0])]; let gradients2 = vec![Array1::from_vec(vec![2.0, 1.0])]; let compressed1 = compressor.compress(&gradients1).expect("unwrap failed");
1971 let compressed2 = compressor.compress(&gradients2).expect("unwrap failed");
1972
1973 let decompressed1 = compressor.decompress(&compressed1).expect("unwrap failed");
1974 let decompressed2 = compressor.decompress(&compressed2).expect("unwrap failed");
1975
1976 server
1978 .submit_update(0, decompressed1)
1979 .expect("unwrap failed");
1980 server
1981 .submit_update(1, decompressed2)
1982 .expect("unwrap failed");
1983
1984 let global_params = server.get_global_parameters();
1985
1986 assert_relative_eq!(global_params[0][0], 1.0, epsilon = 1e-6);
1990 assert_relative_eq!(global_params[0][1], 1.5, epsilon = 1e-6);
1991 }
1992}