1pub mod fedprox;
7pub use fedprox::*;
8
9use crate::error::{OptimError, Result};
10use scirs2_core::ndarray::{Array, Dimension, ScalarOperand, Zip};
11use scirs2_core::numeric::Float;
12use std::collections::HashMap;
13use std::fmt::Debug;
14
/// Strategy used to combine model parameters collected from multiple nodes.
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum AveragingStrategy {
    /// Plain arithmetic mean over all submitting nodes.
    Arithmetic,
    /// Weighted mean using per-node weights (intended: proportional to local data size).
    WeightedByData,
    /// Weighted mean using per-node weights (intended: proportional to compute time).
    WeightedByTime,
    /// Federated averaging; implemented here as a weighted mean (see `federated_average`).
    Federated,
    /// Momentum-smoothed averaging of the per-round mean.
    Momentum {
        // Momentum coefficient; higher retains more history.
        momentum: f64,
    },
    /// Exponential moving average of the per-round mean.
    ExponentialMovingAverage {
        // Decay factor; higher retains more history.
        decay: f64,
    },
}
37
/// Maintains a running average of model parameters across `numnodes` nodes.
#[derive(Debug)]
pub struct ParameterAverager<A: Float, D: Dimension> {
    // Current averaged parameter arrays (one entry per parameter tensor).
    averaged_params: Vec<Array<A, D>>,
    // How node contributions are combined.
    strategy: AveragingStrategy,
    // Per-node weights used by the weighted strategies.
    node_weights: HashMap<usize, A>,
    // Total number of participating nodes.
    numnodes: usize,
    // Velocity buffers; allocated only for the `Momentum` strategy.
    momentum_buffer: Option<Vec<Array<A, D>>>,
    // Number of averaging steps performed so far.
    step_count: usize,
    // Set once `initialize` has run.
    initialized: bool,
}
56
57impl<A: Float + ScalarOperand + Debug + Send + Sync, D: Dimension + Send + Sync>
58 ParameterAverager<A, D>
59{
60 pub fn new(strategy: AveragingStrategy, numnodes: usize) -> Self {
62 Self {
63 averaged_params: Vec::new(),
64 strategy,
65 node_weights: HashMap::new(),
66 numnodes,
67 momentum_buffer: None,
68 step_count: 0,
69 initialized: false,
70 }
71 }
72
73 pub fn initialize(&mut self, params: &[Array<A, D>]) -> Result<()> {
75 if self.initialized {
76 return Err(OptimError::InvalidConfig(
77 "Parameter averager already initialized".to_string(),
78 ));
79 }
80
81 self.averaged_params = params.to_vec();
82
83 if matches!(self.strategy, AveragingStrategy::Momentum { .. }) {
85 self.momentum_buffer = Some(params.iter().map(|p| Array::zeros(p.raw_dim())).collect());
86 }
87
88 let uniform_weight = A::one() / A::from(self.numnodes).expect("unwrap failed");
90 for nodeid in 0..self.numnodes {
91 self.node_weights.insert(nodeid, uniform_weight);
92 }
93
94 self.initialized = true;
95 Ok(())
96 }
97
98 pub fn set_node_weight(&mut self, nodeid: usize, weight: A) -> Result<()> {
100 if nodeid >= self.numnodes {
101 return Err(OptimError::InvalidConfig(format!(
102 "Node ID {} exceeds number of nodes {}",
103 nodeid, self.numnodes
104 )));
105 }
106 self.node_weights.insert(nodeid, weight);
107 Ok(())
108 }
109
110 pub fn average_parameters(
112 &mut self,
113 nodeparameters: &[(usize, Vec<Array<A, D>>)],
114 ) -> Result<()> {
115 if !self.initialized {
116 if let Some((_, first_params)) = nodeparameters.first() {
117 self.initialize(first_params)?;
118 } else {
119 return Err(OptimError::InvalidConfig(
120 "No _parameters provided for initialization".to_string(),
121 ));
122 }
123 }
124
125 for (nodeid, params) in nodeparameters {
127 if *nodeid >= self.numnodes {
128 return Err(OptimError::InvalidConfig(format!(
129 "Node ID {} exceeds number of nodes {}",
130 nodeid, self.numnodes
131 )));
132 }
133 if params.len() != self.averaged_params.len() {
134 return Err(OptimError::DimensionMismatch(format!(
135 "Expected {} parameter arrays, got {}",
136 self.averaged_params.len(),
137 params.len()
138 )));
139 }
140 }
141
142 self.step_count += 1;
143
144 match self.strategy {
145 AveragingStrategy::Arithmetic => {
146 self.arithmetic_average(nodeparameters)?;
147 }
148 AveragingStrategy::WeightedByData | AveragingStrategy::WeightedByTime => {
149 self.weighted_average(nodeparameters)?;
150 }
151 AveragingStrategy::Federated => {
152 self.federated_average(nodeparameters)?;
153 }
154 AveragingStrategy::Momentum { momentum } => {
155 self.momentum_average(nodeparameters, momentum)?;
156 }
157 AveragingStrategy::ExponentialMovingAverage { decay } => {
158 self.ema_average(nodeparameters, decay)?;
159 }
160 }
161
162 Ok(())
163 }
164
165 fn arithmetic_average(&mut self, nodeparameters: &[(usize, Vec<Array<A, D>>)]) -> Result<()> {
167 for param in &mut self.averaged_params {
169 param.fill(A::zero());
170 }
171
172 let numnodes = A::from(nodeparameters.len()).expect("unwrap failed");
173
174 for (_node_id, params) in nodeparameters {
176 for (avg_param, param) in self.averaged_params.iter_mut().zip(params.iter()) {
177 Zip::from(avg_param).and(param).for_each(|avg, &p| {
178 *avg = *avg + p;
179 });
180 }
181 }
182
183 for param in &mut self.averaged_params {
185 param.mapv_inplace(|x| x / numnodes);
186 }
187
188 Ok(())
189 }
190
191 fn weighted_average(&mut self, nodeparameters: &[(usize, Vec<Array<A, D>>)]) -> Result<()> {
193 for param in &mut self.averaged_params {
195 param.fill(A::zero());
196 }
197
198 let total_weight: A = nodeparameters
200 .iter()
201 .map(|(nodeid, _)| self.node_weights.get(nodeid).copied().unwrap_or(A::zero()))
202 .fold(A::zero(), |acc, w| acc + w);
203
204 if total_weight <= A::zero() {
205 return Err(OptimError::InvalidConfig(
206 "Total node weights must be > 0".to_string(),
207 ));
208 }
209
210 for (nodeid, params) in nodeparameters {
212 let weight = self.node_weights.get(nodeid).copied().unwrap_or(A::zero()) / total_weight;
213
214 for (avg_param, param) in self.averaged_params.iter_mut().zip(params.iter()) {
215 Zip::from(avg_param).and(param).for_each(|avg, &p| {
216 *avg = *avg + weight * p;
217 });
218 }
219 }
220
221 Ok(())
222 }
223
    /// Federated averaging (FedAvg-style aggregation).
    ///
    /// Currently delegates to `weighted_average`; this matches FedAvg when the
    /// node weights have been set proportional to local dataset sizes.
    fn federated_average(&mut self, nodeparameters: &[(usize, Vec<Array<A, D>>)]) -> Result<()> {
        self.weighted_average(nodeparameters)
    }
230
231 fn momentum_average(
233 &mut self,
234 nodeparameters: &[(usize, Vec<Array<A, D>>)],
235 momentum: f64,
236 ) -> Result<()> {
237 let momentum_factor = A::from(momentum).expect("unwrap failed");
238 let one_minus_momentum = A::one() - momentum_factor;
239
240 let mut current_average: Vec<Array<A, D>> = self
242 .averaged_params
243 .iter()
244 .map(|param| Array::zeros(param.raw_dim()))
245 .collect();
246
247 let numnodes = A::from(nodeparameters.len()).expect("unwrap failed");
248 for (_node_id, params) in nodeparameters {
249 for (avg_param, param) in current_average.iter_mut().zip(params.iter()) {
250 Zip::from(avg_param).and(param).for_each(|avg, &p| {
251 *avg = *avg + p / numnodes;
252 });
253 }
254 }
255
256 if let Some(ref mut momentum_buf) = self.momentum_buffer {
258 for ((avg_param, current_param), momentum_param) in self
259 .averaged_params
260 .iter_mut()
261 .zip(current_average.iter())
262 .zip(momentum_buf.iter_mut())
263 {
264 Zip::from(&mut *momentum_param)
266 .and(current_param)
267 .for_each(|mom, &curr| {
268 *mom = momentum_factor * *mom + one_minus_momentum * curr;
269 });
270
271 avg_param.assign(&*momentum_param);
273 }
274 }
275
276 Ok(())
277 }
278
279 fn ema_average(
281 &mut self,
282 nodeparameters: &[(usize, Vec<Array<A, D>>)],
283 decay: f64,
284 ) -> Result<()> {
285 let decay_factor = A::from(decay).expect("unwrap failed");
286 let one_minus_decay = A::one() - decay_factor;
287
288 let mut current_average: Vec<Array<A, D>> = self
290 .averaged_params
291 .iter()
292 .map(|param| Array::zeros(param.raw_dim()))
293 .collect();
294
295 let numnodes = A::from(nodeparameters.len()).expect("unwrap failed");
296 for (_node_id, params) in nodeparameters {
297 for (avg_param, param) in current_average.iter_mut().zip(params.iter()) {
298 Zip::from(avg_param).and(param).for_each(|avg, &p| {
299 *avg = *avg + p / numnodes;
300 });
301 }
302 }
303
304 for (avg_param, current_param) in
306 self.averaged_params.iter_mut().zip(current_average.iter())
307 {
308 Zip::from(avg_param)
309 .and(current_param)
310 .for_each(|avg, &curr| {
311 *avg = decay_factor * *avg + one_minus_decay * curr;
312 });
313 }
314
315 Ok(())
316 }
317
    /// Borrows the current averaged parameter arrays.
    pub fn get_averaged_parameters(&self) -> &[Array<A, D>] {
        &self.averaged_params
    }

    /// Returns an owned copy of the current averaged parameter arrays.
    pub fn get_averaged_parameters_cloned(&self) -> Vec<Array<A, D>> {
        self.averaged_params.clone()
    }
327
328 pub fn reset(&mut self) {
330 self.step_count = 0;
331 for param in &mut self.averaged_params {
332 param.fill(A::zero());
333 }
334 if let Some(ref mut momentum_buf) = self.momentum_buffer {
335 for buf in momentum_buf {
336 buf.fill(A::zero());
337 }
338 }
339 }
340
    /// Number of averaging steps performed so far.
    pub fn step_count(&self) -> usize {
        self.step_count
    }

    /// Total number of participating nodes.
    pub fn numnodes(&self) -> usize {
        self.numnodes
    }

    /// The configured averaging strategy.
    pub fn strategy(&self) -> AveragingStrategy {
        self.strategy
    }

    /// Whether `initialize` has been called successfully.
    pub fn is_initialized(&self) -> bool {
        self.initialized
    }
360}
361
/// Central parameter server that collects per-node updates and publishes the
/// aggregated global parameters.
#[derive(Debug)]
pub struct ParameterServer<A: Float, D: Dimension> {
    // Underlying averaging engine.
    averager: ParameterAverager<A, D>,
    // Most recently aggregated global parameters.
    global_parameters: Vec<Array<A, D>>,
    // Total number of updates received per node (across all rounds).
    update_counts: HashMap<usize, usize>,
    // Number of distinct pending nodes that triggers aggregation.
    expected_updates_per_round: usize,
    // Completed aggregation rounds.
    current_round: usize,
    // Updates awaiting aggregation, keyed by node id (latest submission wins).
    pending_updates: HashMap<usize, Vec<Array<A, D>>>,
}
378
379impl<A: Float + ScalarOperand + Debug + Send + Sync, D: Dimension + Send + Sync>
380 ParameterServer<A, D>
381{
382 pub fn new(
384 strategy: AveragingStrategy,
385 numnodes: usize,
386 expected_updates_per_round: usize,
387 ) -> Self {
388 Self {
389 averager: ParameterAverager::new(strategy, numnodes),
390 global_parameters: Vec::new(),
391 update_counts: HashMap::new(),
392 expected_updates_per_round,
393 current_round: 0,
394 pending_updates: HashMap::new(),
395 }
396 }
397
398 pub fn initialize(&mut self, initialparams: &[Array<A, D>]) -> Result<()> {
400 self.averager.initialize(initialparams)?;
401 self.global_parameters = initialparams.to_vec();
402
403 for nodeid in 0..self.averager.numnodes() {
405 self.update_counts.insert(nodeid, 0);
406 }
407
408 Ok(())
409 }
410
411 pub fn submit_update(&mut self, nodeid: usize, parameters: Vec<Array<A, D>>) -> Result<bool> {
413 if nodeid >= self.averager.numnodes() {
414 return Err(OptimError::InvalidConfig(format!(
415 "Node ID {} exceeds number of nodes {}",
416 nodeid,
417 self.averager.numnodes()
418 )));
419 }
420
421 self.pending_updates.insert(nodeid, parameters);
423 *self.update_counts.entry(nodeid).or_insert(0) += 1;
424
425 let ready_for_aggregation = self.pending_updates.len() >= self.expected_updates_per_round;
427
428 if ready_for_aggregation {
429 self.aggregate_and_update()?;
430 }
431
432 Ok(ready_for_aggregation)
433 }
434
435 pub fn force_aggregation(&mut self) -> Result<()> {
437 if !self.pending_updates.is_empty() {
438 self.aggregate_and_update()?;
439 }
440 Ok(())
441 }
442
443 fn aggregate_and_update(&mut self) -> Result<()> {
445 let node_params: Vec<(usize, Vec<Array<A, D>>)> = self.pending_updates.drain().collect();
447
448 self.averager.average_parameters(&node_params)?;
450
451 self.global_parameters = self.averager.get_averaged_parameters_cloned();
453
454 self.current_round += 1;
456
457 Ok(())
458 }
459
    /// Borrows the latest aggregated global parameters.
    pub fn get_global_parameters(&self) -> &[Array<A, D>] {
        &self.global_parameters
    }

    /// Returns an owned copy of the latest aggregated global parameters.
    pub fn get_global_parameters_cloned(&self) -> Vec<Array<A, D>> {
        self.global_parameters.clone()
    }

    /// Number of completed aggregation rounds.
    pub fn current_round(&self) -> usize {
        self.current_round
    }

    /// Total updates ever received from `nodeid` (0 for unknown nodes).
    pub fn get_update_count(&self, nodeid: usize) -> usize {
        self.update_counts.get(&nodeid).copied().unwrap_or(0)
    }

    /// Number of node updates waiting for the next aggregation.
    pub fn pending_updates_count(&self) -> usize {
        self.pending_updates.len()
    }

    /// Sets the averaging weight for `nodeid` (delegates to the averager).
    pub fn set_node_weight(&mut self, nodeid: usize, weight: A) -> Result<()> {
        self.averager.set_node_weight(nodeid, weight)
    }
489
490 pub fn reset(&mut self) {
492 self.averager.reset();
493 self.update_counts.clear();
494 self.pending_updates.clear();
495 self.current_round = 0;
496
497 for nodeid in 0..self.averager.numnodes() {
498 self.update_counts.insert(nodeid, 0);
499 }
500 }
501}
502
/// High-level driver for synchronous distributed training rounds.
#[derive(Debug)]
pub struct DistributedCoordinator<A: Float, D: Dimension> {
    // Server that collects and aggregates node updates.
    parameter_server: ParameterServer<A, D>,
    // Number of completed communication rounds.
    communication_rounds: usize,
    // Relative parameter-change threshold below which training counts as converged.
    convergence_threshold: A,
    // Hard cap on communication rounds.
    max_rounds: usize,
    // Per-round convergence bookkeeping.
    training_stats: TrainingStats<A>,
}
517
518impl<A: Float + ScalarOperand + Debug + Send + Sync, D: Dimension + Send + Sync>
519 DistributedCoordinator<A, D>
520{
521 pub fn new(
523 strategy: AveragingStrategy,
524 numnodes: usize,
525 expected_updates_per_round: usize,
526 max_rounds: usize,
527 ) -> Self {
528 Self {
529 parameter_server: ParameterServer::new(strategy, numnodes, expected_updates_per_round),
530 communication_rounds: 0,
531 convergence_threshold: A::from(1e-6).expect("unwrap failed"),
532 max_rounds,
533 training_stats: TrainingStats::new(),
534 }
535 }
536
537 pub fn initialize(&mut self, initialparams: &[Array<A, D>]) -> Result<()> {
539 self.parameter_server.initialize(initialparams)?;
540 self.training_stats
541 .record_round(0, A::zero(), initialparams);
542 Ok(())
543 }
544
    /// Runs one communication round: submits every node update to the server,
    /// forces aggregation if the submission threshold was not reached, and
    /// reports convergence status.
    ///
    /// NOTE(review): because `aggregated` is unconditionally set to `true`
    /// after `force_aggregation` (even when there were no pending updates),
    /// the `else` branch below is unreachable as written, and an empty round
    /// still advances `communication_rounds` — confirm intent.
    pub fn communication_round(
        &mut self,
        node_updates: Vec<(usize, Vec<Array<A, D>>)>,
    ) -> Result<CommunicationResult<A, D>> {
        let mut aggregated = false;

        // Submit every update; any submission may trigger aggregation.
        for (nodeid, params) in node_updates {
            aggregated = self.parameter_server.submit_update(nodeid, params)? || aggregated;
        }

        if !aggregated {
            self.parameter_server.force_aggregation()?;
            aggregated = true;
        }

        if aggregated {
            self.communication_rounds += 1;

            let currentparams = self.parameter_server.get_global_parameters();
            let convergence_metric = self.compute_convergence_metric(currentparams);

            self.training_stats.record_round(
                self.communication_rounds,
                convergence_metric,
                currentparams,
            );

            let converged = convergence_metric < self.convergence_threshold;
            let max_rounds_reached = self.communication_rounds >= self.max_rounds;

            Ok(CommunicationResult {
                round: self.communication_rounds,
                global_parameters: self.parameter_server.get_global_parameters_cloned(),
                converged,
                should_continue: !converged && !max_rounds_reached,
                convergence_metric,
                stats: self.training_stats.clone(),
            })
        } else {
            Ok(CommunicationResult {
                round: self.communication_rounds,
                global_parameters: self.parameter_server.get_global_parameters_cloned(),
                converged: false,
                should_continue: true,
                convergence_metric: A::infinity(),
                stats: self.training_stats.clone(),
            })
        }
    }
598
    /// Overrides the convergence threshold (relative parameter change).
    pub fn set_convergence_threshold(&mut self, threshold: A) {
        self.convergence_threshold = threshold;
    }

    /// Borrows the underlying parameter server.
    pub fn parameter_server(&self) -> &ParameterServer<A, D> {
        &self.parameter_server
    }

    /// Mutably borrows the underlying parameter server.
    pub fn parameter_server_mut(&mut self) -> &mut ParameterServer<A, D> {
        &mut self.parameter_server
    }
613
    /// Relative parameter change between the current and previous round:
    /// `sqrt(sum((curr - prev)^2) / sum(curr^2))`.
    ///
    /// NOTE(review): `TrainingStats::get_previous_parameters` currently always
    /// returns `None` (snapshots are stubbed), so this metric is `infinity`
    /// every round and convergence can only be reached via `max_rounds` —
    /// confirm whether snapshot storage should be implemented.
    fn compute_convergence_metric(&self, currentparams: &[Array<A, D>]) -> A {
        if let Some(prev_params) = self.training_stats.get_previous_parameters::<D>() {
            let mut total_change = A::zero();
            let mut total_norm = A::zero();

            for (curr, prev) in currentparams.iter().zip(prev_params.iter()) {
                for (&c, &p) in curr.iter().zip(prev.iter()) {
                    let diff = c - p;
                    total_change = total_change + diff * diff;
                    total_norm = total_norm + c * c;
                }
            }

            if total_norm > A::zero() {
                (total_change / total_norm).sqrt()
            } else {
                A::zero()
            }
        } else {
            // No previous snapshot available: report "not converged".
            A::infinity()
        }
    }
637}
638
/// Outcome of a single communication round.
#[derive(Debug, Clone)]
pub struct CommunicationResult<A: Float, D: Dimension> {
    /// Index of the completed round.
    pub round: usize,
    /// Global parameters after aggregation.
    pub global_parameters: Vec<Array<A, D>>,
    /// Whether the convergence metric dropped below the threshold.
    pub converged: bool,
    /// Whether another round should run (not converged and under `max_rounds`).
    pub should_continue: bool,
    /// Relative parameter-change metric for this round.
    pub convergence_metric: A,
    /// Snapshot of the coordinator's training statistics.
    pub stats: TrainingStats<A>,
}
655
/// Per-round convergence bookkeeping for distributed training.
#[derive(Debug, Clone)]
pub struct TrainingStats<A: Float> {
    // Convergence metric recorded for each round.
    convergence_history: Vec<A>,
    // Round index recorded alongside each metric.
    round_times: Vec<usize>,
    // Placeholder for a serialized snapshot of the previous round's
    // parameters. NOTE(review): only a dummy byte vector is stored today (see
    // `record_round`), so no real snapshot ever exists — confirm intent.
    previous_parameters: Option<Vec<u8>>,
}
666
667impl<A: Float + Send + Sync> TrainingStats<A> {
668 pub fn new() -> Self {
670 Self {
671 convergence_history: Vec::new(),
672 round_times: Vec::new(),
673 previous_parameters: None,
674 }
675 }
676
    /// Records the convergence metric for `round`.
    ///
    /// NOTE(review): only a placeholder byte vector (one zero byte per
    /// parameter array) is stored for `previous_parameters`; the actual values
    /// are not serialized, so later convergence computations cannot recover
    /// them — confirm whether real snapshots were intended.
    pub fn record_round<D: Dimension>(
        &mut self,
        round: usize,
        convergence_metric: A,
        parameters: &[Array<A, D>],
    ) {
        self.convergence_history.push(convergence_metric);
        self.round_times.push(round);

        self.previous_parameters = Some(vec![0u8; parameters.len()]);
    }
691
    /// Full history of convergence metrics, one entry per recorded round.
    pub fn convergence_history(&self) -> &[A] {
        &self.convergence_history
    }

    /// Convergence metric of the most recently recorded round, if any.
    pub fn latest_convergence(&self) -> Option<A> {
        self.convergence_history.last().copied()
    }

    /// Number of rounds recorded so far.
    pub fn num_rounds(&self) -> usize {
        self.round_times.len()
    }
706
    /// Reconstructs the previous round's parameter snapshot.
    ///
    /// Stubbed: always returns `None` because `record_round` never stores a
    /// real snapshot, which forces `compute_convergence_metric` to report
    /// `infinity` on every round.
    fn get_previous_parameters<D: Dimension>(&self) -> Option<Vec<Array<A, D>>> {
        None
    }
712}
713
/// Equivalent to [`TrainingStats::new`].
impl<A: Float + Send + Sync> Default for TrainingStats<A> {
    fn default() -> Self {
        Self::new()
    }
}
719
/// Gradient compression scheme used to reduce communication volume.
#[derive(Debug, Clone, PartialEq)]
pub enum CompressionStrategy {
    /// No compression; raw `f64` serialization of every element.
    None,
    /// Keep only the `k` largest-magnitude elements per gradient array.
    TopK {
        // Number of elements retained per array.
        k: usize,
    },
    /// Keep `k` elements per array chosen by a deterministic pseudo-random
    /// selection (see `compress_randomk`).
    RandomK {
        // Number of elements retained per array.
        k: usize,
    },
    /// Keep every element whose magnitude exceeds the threshold.
    Threshold {
        // Magnitude cutoff.
        threshold: f64,
    },
    /// Uniform scalar quantization.
    Quantization {
        // Bits per quantized element.
        bits: u8,
    },
    /// Wrap a base strategy with error-feedback residual accumulation.
    ErrorFeedback {
        // Strategy applied to the error-compensated gradients.
        base_strategy: Box<CompressionStrategy>,
        // NOTE(review): currently unused by `GradientCompressor::compress`.
        error_compensation: bool,
    },
    /// Clamp gradients to `[-clip_value, clip_value]` before applying the
    /// base strategy.
    ClippedCompression {
        // Strategy applied after clipping.
        base_strategy: Box<CompressionStrategy>,
        // Symmetric clipping bound.
        clip_value: f64,
    },
}
760
/// A compressed gradient payload plus everything needed to decompress it.
#[derive(Debug, Clone)]
pub struct CompressedGradient<A: Float> {
    /// Strategy-specific serialized bytes.
    pub data: Vec<u8>,
    /// Compression parameters and statistics (drives decompression).
    pub metadata: CompressionMetadata<A>,
    /// Original shape of every gradient array, in order.
    pub shapes: Vec<Vec<usize>>,
}
771
/// Bookkeeping attached to a [`CompressedGradient`].
#[derive(Debug, Clone)]
pub struct CompressionMetadata<A: Float> {
    /// Strategy that produced the payload.
    pub strategy: CompressionStrategy,
    /// Compressed size divided by uncompressed size (8 bytes per element).
    pub compression_ratio: f64,
    /// Number of values explicitly stored.
    pub nnz_count: usize,
    /// Per-array scale factors (used by quantization).
    pub scale_factors: Vec<A>,
    /// Reserved for strategy-specific extras (currently always empty).
    pub extra_data: Vec<u8>,
}
786
/// Compresses and decompresses gradients for bandwidth-limited communication.
#[derive(Debug)]
pub struct GradientCompressor<A: Float, D: Dimension> {
    // Configured compression scheme.
    strategy: CompressionStrategy,
    // Accumulated error-feedback residuals (one buffer per gradient array).
    error_state: Option<Vec<Array<A, D>>>,
    // Running compression-size statistics.
    stats: CompressionStats,
}
797
798impl<A: Float + ScalarOperand + Debug + Send + Sync, D: Dimension + Send + Sync>
799 GradientCompressor<A, D>
800{
801 pub fn new(strategy: CompressionStrategy) -> Self {
803 Self {
804 strategy,
805 error_state: None,
806 stats: CompressionStats::new(),
807 }
808 }
809
810 pub fn initialize_error_state(&mut self, gradientshapes: &[Array<A, D>]) {
812 self.error_state = Some(
813 gradientshapes
814 .iter()
815 .map(|g| Array::zeros(g.raw_dim()))
816 .collect(),
817 );
818 }
819
820 pub fn compress(&mut self, gradients: &[Array<A, D>]) -> Result<CompressedGradient<A>> {
822 let mut working_gradients: Vec<Array<A, D>> =
824 if let Some(ref mut error_state) = self.error_state {
825 gradients
826 .iter()
827 .zip(error_state.iter())
828 .map(|(grad, error)| grad + error)
829 .collect()
830 } else {
831 gradients.to_vec()
832 };
833
834 let (compressed_data, metadata) = match &self.strategy {
835 CompressionStrategy::None => self.compress_none(&working_gradients)?,
836 CompressionStrategy::TopK { k } => self.compress_topk(&working_gradients, *k)?,
837 CompressionStrategy::RandomK { k } => self.compress_randomk(&working_gradients, *k)?,
838 CompressionStrategy::Threshold { threshold } => self.compress_threshold(
839 &working_gradients,
840 A::from(*threshold).expect("unwrap failed"),
841 )?,
842 CompressionStrategy::Quantization { bits } => {
843 self.compress_quantization(&working_gradients, *bits)?
844 }
845 CompressionStrategy::ErrorFeedback { base_strategy, .. } => {
846 let mut temp_compressor = GradientCompressor::new((**base_strategy).clone());
848 let compressed = temp_compressor.compress(&working_gradients)?;
849 let decompressed = temp_compressor.decompress(&compressed)?;
850
851 if let Some(ref mut error_state) = self.error_state {
853 for ((original, decompressed), error) in gradients
854 .iter()
855 .zip(decompressed.iter())
856 .zip(error_state.iter_mut())
857 {
858 *error = original - decompressed;
859 }
860 }
861
862 (compressed.data, compressed.metadata)
863 }
864 CompressionStrategy::ClippedCompression {
865 base_strategy,
866 clip_value,
867 } => {
868 let clip_val = A::from(*clip_value).expect("unwrap failed");
870 for grad in &mut working_gradients {
871 grad.mapv_inplace(|x| {
872 if x > clip_val {
873 clip_val
874 } else if x < -clip_val {
875 -clip_val
876 } else {
877 x
878 }
879 });
880 }
881
882 let mut temp_compressor = GradientCompressor::new((**base_strategy).clone());
884 let compressed = temp_compressor.compress(&working_gradients)?;
885 (compressed.data, compressed.metadata)
886 }
887 };
888
889 let shapes = gradients.iter().map(|g| g.shape().to_vec()).collect();
891
892 let result = CompressedGradient {
893 data: compressed_data,
894 metadata,
895 shapes,
896 };
897
898 let original_size = self.calculate_size(gradients);
900 let compressed_size = result.data.len();
901 self.stats
902 .record_compression(original_size, compressed_size);
903
904 Ok(result)
905 }
906
907 pub fn decompress(&self, compressed: &CompressedGradient<A>) -> Result<Vec<Array<A, D>>> {
909 match &compressed.metadata.strategy {
910 CompressionStrategy::None => self.decompress_none(compressed),
911 CompressionStrategy::TopK { .. } => self.decompress_sparse(compressed),
912 CompressionStrategy::RandomK { .. } => self.decompress_sparse(compressed),
913 CompressionStrategy::Threshold { .. } => self.decompress_sparse(compressed),
914 CompressionStrategy::Quantization { bits } => {
915 self.decompress_quantization(compressed, *bits)
916 }
917 CompressionStrategy::ErrorFeedback { base_strategy, .. } => {
918 let temp_compressor = GradientCompressor::new((**base_strategy).clone());
919 temp_compressor.decompress(compressed)
920 }
921 CompressionStrategy::ClippedCompression { base_strategy, .. } => {
922 let temp_compressor = GradientCompressor::new((**base_strategy).clone());
923 temp_compressor.decompress(compressed)
924 }
925 }
926 }
927
928 fn compress_none(
930 &self,
931 gradients: &[Array<A, D>],
932 ) -> Result<(Vec<u8>, CompressionMetadata<A>)> {
933 let mut data = Vec::new();
934
935 for grad in gradients {
937 for &val in grad.iter() {
938 data.extend_from_slice(&val.to_f64().expect("unwrap failed").to_le_bytes());
939 }
940 }
941
942 let metadata = CompressionMetadata {
943 strategy: CompressionStrategy::None,
944 compression_ratio: 1.0,
945 nnz_count: gradients.iter().map(|g| g.len()).sum(),
946 scale_factors: Vec::new(),
947 extra_data: Vec::new(),
948 };
949
950 Ok((data, metadata))
951 }
952
953 fn compress_topk(
955 &self,
956 gradients: &[Array<A, D>],
957 k: usize,
958 ) -> Result<(Vec<u8>, CompressionMetadata<A>)> {
959 let mut indices = Vec::new();
960 let mut values = Vec::new();
961 let mut total_elements = 0;
962
963 for (grad_idx, grad) in gradients.iter().enumerate() {
964 total_elements += grad.len();
965
966 let mut value_indices: Vec<(A, usize)> = grad
968 .iter()
969 .enumerate()
970 .map(|(i, &val)| (val.abs(), i))
971 .collect();
972
973 value_indices.sort_by(|a, b| b.0.partial_cmp(&a.0).expect("unwrap failed"));
975
976 let k_local = k.min(value_indices.len());
978 for (_, orig_idx) in value_indices.iter().take(k_local) {
979 indices.push((grad_idx as u32, *orig_idx as u32));
980 values.push(grad.iter().nth(*orig_idx).copied().expect("unwrap failed"));
981 }
982 }
983
984 let mut data = Vec::new();
986
987 data.extend_from_slice(&(indices.len() as u32).to_le_bytes());
989
990 for ((grad_idx, elem_idx), value) in indices.iter().zip(values.iter()) {
992 data.extend_from_slice(&grad_idx.to_le_bytes());
993 data.extend_from_slice(&elem_idx.to_le_bytes());
994 data.extend_from_slice(&value.to_f64().expect("unwrap failed").to_le_bytes());
995 }
996
997 let metadata = CompressionMetadata {
998 strategy: CompressionStrategy::TopK { k },
999 compression_ratio: data.len() as f64 / (total_elements * 8) as f64,
1000 nnz_count: indices.len(),
1001 scale_factors: Vec::new(),
1002 extra_data: Vec::new(),
1003 };
1004
1005 Ok((data, metadata))
1006 }
1007
    /// Keeps `k` pseudo-randomly selected elements per gradient array.
    ///
    /// NOTE(review): the selection is fully deterministic — a partial
    /// Fisher-Yates-style shuffle whose "random" index is derived from the
    /// gradient index and position (`(grad_idx + i) % (len - i)`), so the same
    /// inputs always pick the same elements. Confirm whether a real RNG was
    /// intended. Wire format matches `compress_topk`.
    fn compress_randomk(
        &self,
        gradients: &[Array<A, D>],
        k: usize,
    ) -> Result<(Vec<u8>, CompressionMetadata<A>)> {
        let mut indices = Vec::new();
        let mut values = Vec::new();
        let mut total_elements = 0;

        for (grad_idx, grad) in gradients.iter().enumerate() {
            total_elements += grad.len();

            let k_local = k.min(grad.len());
            let mut selected_indices: Vec<usize> = (0..grad.len()).collect();

            // Partial shuffle: position i swaps with a deterministically
            // chosen position in [i, len).
            for i in 0..k_local {
                let swap_idx = i + ((grad_idx + i) % (grad.len() - i));
                selected_indices.swap(i, swap_idx);
            }

            for &idx in selected_indices.iter().take(k_local) {
                indices.push((grad_idx as u32, idx as u32));
                values.push(grad.iter().nth(idx).copied().expect("unwrap failed"));
            }
        }

        let mut data = Vec::new();
        data.extend_from_slice(&(indices.len() as u32).to_le_bytes());

        for ((grad_idx, elem_idx), value) in indices.iter().zip(values.iter()) {
            data.extend_from_slice(&grad_idx.to_le_bytes());
            data.extend_from_slice(&elem_idx.to_le_bytes());
            data.extend_from_slice(&value.to_f64().expect("unwrap failed").to_le_bytes());
        }

        let metadata = CompressionMetadata {
            strategy: CompressionStrategy::RandomK { k },
            compression_ratio: data.len() as f64 / (total_elements * 8) as f64,
            nnz_count: indices.len(),
            scale_factors: Vec::new(),
            extra_data: Vec::new(),
        };

        Ok((data, metadata))
    }
1057
1058 fn compress_threshold(
1060 &self,
1061 gradients: &[Array<A, D>],
1062 threshold: A,
1063 ) -> Result<(Vec<u8>, CompressionMetadata<A>)> {
1064 let mut indices = Vec::new();
1065 let mut values = Vec::new();
1066 let mut total_elements = 0;
1067
1068 for (grad_idx, grad) in gradients.iter().enumerate() {
1069 total_elements += grad.len();
1070
1071 for (elem_idx, &val) in grad.iter().enumerate() {
1072 if val.abs() > threshold {
1073 indices.push((grad_idx as u32, elem_idx as u32));
1074 values.push(val);
1075 }
1076 }
1077 }
1078
1079 let mut data = Vec::new();
1081 data.extend_from_slice(&(indices.len() as u32).to_le_bytes());
1082
1083 for ((grad_idx, elem_idx), value) in indices.iter().zip(values.iter()) {
1084 data.extend_from_slice(&grad_idx.to_le_bytes());
1085 data.extend_from_slice(&elem_idx.to_le_bytes());
1086 data.extend_from_slice(&value.to_f64().expect("unwrap failed").to_le_bytes());
1087 }
1088
1089 let metadata = CompressionMetadata {
1090 strategy: CompressionStrategy::Threshold {
1091 threshold: threshold.to_f64().expect("unwrap failed"),
1092 },
1093 compression_ratio: data.len() as f64 / (total_elements * 8) as f64,
1094 nnz_count: indices.len(),
1095 scale_factors: Vec::new(),
1096 extra_data: Vec::new(),
1097 };
1098
1099 Ok((data, metadata))
1100 }
1101
1102 fn compress_quantization(
1104 &self,
1105 gradients: &[Array<A, D>],
1106 bits: u8,
1107 ) -> Result<(Vec<u8>, CompressionMetadata<A>)> {
1108 if bits > 32 {
1109 return Err(OptimError::InvalidConfig(
1110 "Quantization bits must be <= 32".to_string(),
1111 ));
1112 }
1113
1114 let mut data = Vec::new();
1115 let mut scale_factors = Vec::new();
1116 let levels = (1u64 << bits) - 1;
1117
1118 for grad in gradients {
1119 let min_val = grad.iter().fold(A::infinity(), |acc, &x| acc.min(x));
1121 let max_val = grad.iter().fold(A::neg_infinity(), |acc, &x| acc.max(x));
1122
1123 let range = max_val - min_val;
1124 let scale = if range > A::zero() {
1125 range / A::from(levels).expect("unwrap failed")
1126 } else {
1127 A::one()
1128 };
1129
1130 scale_factors.push(scale);
1131
1132 for &val in grad.iter() {
1134 let normalized = (val - min_val) / scale;
1135 let quantized = normalized.to_u64().expect("unwrap failed").min(levels) as u32;
1136
1137 match bits {
1139 1..=8 => data.push(quantized as u8),
1140 9..=16 => data.extend_from_slice(&(quantized as u16).to_le_bytes()),
1141 17..=32 => data.extend_from_slice(&quantized.to_le_bytes()),
1142 _ => unreachable!(),
1143 }
1144 }
1145
1146 data.extend_from_slice(&min_val.to_f64().expect("unwrap failed").to_le_bytes());
1148 }
1149
1150 let total_elements: usize = gradients.iter().map(|g| g.len()).sum();
1151 let metadata = CompressionMetadata {
1152 strategy: CompressionStrategy::Quantization { bits },
1153 compression_ratio: data.len() as f64 / (total_elements * 8) as f64,
1154 nnz_count: total_elements,
1155 scale_factors,
1156 extra_data: Vec::new(),
1157 };
1158
1159 Ok((data, metadata))
1160 }
1161
1162 fn decompress_none(&self, compressed: &CompressedGradient<A>) -> Result<Vec<Array<A, D>>> {
1164 let mut result = Vec::new();
1165 let mut data_offset = 0;
1166
1167 for shape in &compressed.shapes {
1168 let num_elements: usize = shape.iter().product();
1169 let mut values = Vec::with_capacity(num_elements);
1170
1171 for _ in 0..num_elements {
1172 if data_offset + 8 > compressed.data.len() {
1173 return Err(OptimError::InvalidConfig(
1174 "Insufficient data for decompression".to_string(),
1175 ));
1176 }
1177
1178 let bytes = &compressed.data[data_offset..data_offset + 8];
1179 let value = f64::from_le_bytes(bytes.try_into().expect("unwrap failed"));
1180 values.push(A::from(value).expect("unwrap failed"));
1181 data_offset += 8;
1182 }
1183
1184 let dynamic_array = Array::from_shape_vec(shape.as_slice(), values).map_err(|_| {
1186 OptimError::InvalidConfig("Invalid shape for reconstruction".to_string())
1187 })?;
1188 let array = dynamic_array.into_dimensionality::<D>().map_err(|_| {
1189 OptimError::InvalidConfig("Dimension conversion failed".to_string())
1190 })?;
1191 result.push(array);
1192 }
1193
1194 Ok(result)
1195 }
1196
    /// Inverse of the sparse formats (`TopK`/`RandomK`/`Threshold`): starts
    /// from all-zero arrays and scatters the stored `(grad_idx, elem_idx,
    /// value)` entries into them.
    fn decompress_sparse(&self, compressed: &CompressedGradient<A>) -> Result<Vec<Array<A, D>>> {
        let mut result = Vec::new();

        // Allocate a zero-filled array per recorded shape.
        for shape in &compressed.shapes {
            let dynamic_array = Array::zeros(shape.as_slice());
            let array = dynamic_array.into_dimensionality::<D>().map_err(|_| {
                OptimError::InvalidConfig("Dimension conversion failed for zero array".to_string())
            })?;
            result.push(array);
        }

        // The payload begins with a u32 entry count.
        if compressed.data.len() < 4 {
            return Err(OptimError::InvalidConfig(
                "Invalid compressed data format".to_string(),
            ));
        }

        let num_elements =
            u32::from_le_bytes(compressed.data[0..4].try_into().expect("unwrap failed")) as usize;
        let mut data_offset = 4;

        // Each entry is 16 bytes: u32 gradient index, u32 element index
        // (logical iteration order), f64 value.
        for _ in 0..num_elements {
            if data_offset + 16 > compressed.data.len() {
                return Err(OptimError::InvalidConfig(
                    "Insufficient data for sparse decompression".to_string(),
                ));
            }

            let grad_idx = u32::from_le_bytes(
                compressed.data[data_offset..data_offset + 4]
                    .try_into()
                    .expect("unwrap failed"),
            ) as usize;
            let elem_idx = u32::from_le_bytes(
                compressed.data[data_offset + 4..data_offset + 8]
                    .try_into()
                    .expect("unwrap failed"),
            ) as usize;
            let value_bytes = &compressed.data[data_offset + 8..data_offset + 16];
            let value = A::from(f64::from_le_bytes(
                value_bytes.try_into().expect("unwrap failed"),
            ))
            .expect("unwrap failed");

            data_offset += 16;

            if grad_idx >= result.len() {
                return Err(OptimError::InvalidConfig(
                    "Invalid gradient index in compressed data".to_string(),
                ));
            }

            // `nth` walks the array in logical order, matching how element
            // indices were produced during compression.
            if let Some(elem) = result[grad_idx].iter_mut().nth(elem_idx) {
                *elem = value;
            } else {
                return Err(OptimError::InvalidConfig(
                    "Invalid element index in compressed data".to_string(),
                ));
            }
        }

        Ok(result)
    }
1264
    /// Inverse of `compress_quantization`: reads the per-element codes for
    /// each array, then the trailing per-array `f64` minimum, and de-normalizes
    /// via `min + code * scale` using the scale factors from the metadata.
    fn decompress_quantization(
        &self,
        compressed: &CompressedGradient<A>,
        bits: u8,
    ) -> Result<Vec<Array<A, D>>> {
        let mut result = Vec::new();
        let mut data_offset = 0;
        let _levels = (1u64 << bits) - 1;

        for (grad_idx, shape) in compressed.shapes.iter().enumerate() {
            let num_elements: usize = shape.iter().product();
            let mut values = Vec::with_capacity(num_elements);

            // Read one code per element; width depends on `bits`.
            for _ in 0..num_elements {
                let quantized = match bits {
                    1..=8 => {
                        if data_offset >= compressed.data.len() {
                            return Err(OptimError::InvalidConfig(
                                "Insufficient quantized data".to_string(),
                            ));
                        }
                        let val = compressed.data[data_offset] as u32;
                        data_offset += 1;
                        val
                    }
                    9..=16 => {
                        if data_offset + 2 > compressed.data.len() {
                            return Err(OptimError::InvalidConfig(
                                "Insufficient quantized data".to_string(),
                            ));
                        }
                        let val = u16::from_le_bytes(
                            compressed.data[data_offset..data_offset + 2]
                                .try_into()
                                .expect("unwrap failed"),
                        ) as u32;
                        data_offset += 2;
                        val
                    }
                    17..=32 => {
                        if data_offset + 4 > compressed.data.len() {
                            return Err(OptimError::InvalidConfig(
                                "Insufficient quantized data".to_string(),
                            ));
                        }
                        let val = u32::from_le_bytes(
                            compressed.data[data_offset..data_offset + 4]
                                .try_into()
                                .expect("unwrap failed"),
                        );
                        data_offset += 4;
                        val
                    }
                    _ => {
                        return Err(OptimError::InvalidConfig(
                            "Invalid quantization bits".to_string(),
                        ))
                    }
                };

                values.push(quantized);
            }

            // Each array's codes are followed by its minimum value as an f64.
            if data_offset + 8 > compressed.data.len() {
                return Err(OptimError::InvalidConfig(
                    "Missing min value for quantization".to_string(),
                ));
            }
            let min_bytes = &compressed.data[data_offset..data_offset + 8];
            let min_val = A::from(f64::from_le_bytes(
                min_bytes.try_into().expect("unwrap failed"),
            ))
            .expect("unwrap failed");
            data_offset += 8;

            let scale = if grad_idx < compressed.metadata.scale_factors.len() {
                compressed.metadata.scale_factors[grad_idx]
            } else {
                return Err(OptimError::InvalidConfig(
                    "Missing scale factor for quantization".to_string(),
                ));
            };

            // De-normalize: value = min + code * scale.
            let dequantized_values: Vec<A> = values
                .into_iter()
                .map(|q| min_val + A::from(q).expect("unwrap failed") * scale)
                .collect();

            let dynamic_array = Array::from_shape_vec(shape.as_slice(), dequantized_values)
                .map_err(|_| {
                    OptimError::InvalidConfig(
                        "Invalid shape for quantized reconstruction".to_string(),
                    )
                })?;
            let array = dynamic_array.into_dimensionality::<D>().map_err(|_| {
                OptimError::InvalidConfig(
                    "Dimension conversion failed for quantized array".to_string(),
                )
            })?;
            result.push(array);
        }

        Ok(result)
    }
1374
1375 fn calculate_size(&self, gradients: &[Array<A, D>]) -> usize {
1377 gradients
1378 .iter()
1379 .map(|g| g.len() * std::mem::size_of::<A>())
1380 .sum()
1381 }
1382
    /// Returns a read-only view of the accumulated compression statistics.
    pub fn stats(&self) -> &CompressionStats {
        &self.stats
    }
1387
    /// Discards all recorded statistics, restoring the freshly-constructed state.
    pub fn reset_stats(&mut self) {
        self.stats = CompressionStats::new();
    }
1392}
1393
/// Running statistics about gradient-compression effectiveness.
///
/// Ratios are expressed as `compressed / original`, so smaller is better
/// (e.g. `0.25` means the payload shrank to a quarter of its size).
#[derive(Debug, Clone)]
pub struct CompressionStats {
    /// Number of compression operations recorded.
    pub compressions_count: usize,
    /// Total bytes across all original (uncompressed) payloads.
    pub total_original_bytes: usize,
    /// Total bytes across all compressed payloads.
    pub total_compressed_bytes: usize,
    /// Byte-weighted overall ratio `total_compressed_bytes / total_original_bytes`;
    /// `0.0` until at least one event with non-zero original size is recorded.
    pub average_compression_ratio: f64,
    /// Smallest (best) per-event ratio seen; `f64::INFINITY` before any record.
    pub best_compression_ratio: f64,
    /// Largest (worst) per-event ratio seen; `0.0` before any record.
    pub worst_compression_ratio: f64,
}

impl CompressionStats {
    /// Creates an empty statistics tracker.
    ///
    /// `best_compression_ratio` starts at `f64::INFINITY` and
    /// `worst_compression_ratio` at `0.0` so the first recorded event
    /// initializes both through plain `min`/`max`.
    pub fn new() -> Self {
        Self {
            compressions_count: 0,
            total_original_bytes: 0,
            total_compressed_bytes: 0,
            average_compression_ratio: 0.0,
            best_compression_ratio: f64::INFINITY,
            worst_compression_ratio: 0.0,
        }
    }

    /// Records one compression event.
    ///
    /// * `original_bytes` — payload size before compression.
    /// * `compressed_bytes` — payload size after compression.
    ///
    /// Updates the totals, the best/worst per-event ratios, and the
    /// byte-weighted average ratio. An event with `original_bytes == 0`
    /// contributes a neutral per-event ratio of `1.0`.
    pub fn record_compression(&mut self, original_bytes: usize, compressed_bytes: usize) {
        self.compressions_count += 1;
        self.total_original_bytes += original_bytes;
        self.total_compressed_bytes += compressed_bytes;

        // Per-event ratio; guard against division by zero for empty payloads.
        let ratio = if original_bytes > 0 {
            compressed_bytes as f64 / original_bytes as f64
        } else {
            1.0
        };

        self.best_compression_ratio = self.best_compression_ratio.min(ratio);
        self.worst_compression_ratio = self.worst_compression_ratio.max(ratio);

        // Byte-weighted average over all events (not a mean of per-event ratios).
        self.average_compression_ratio = if self.total_original_bytes > 0 {
            self.total_compressed_bytes as f64 / self.total_original_bytes as f64
        } else {
            0.0
        };
    }

    /// Overall byte-weighted compression ratio across all recorded events.
    pub fn overall_compression_ratio(&self) -> f64 {
        self.average_compression_ratio
    }

    /// Percentage of bandwidth saved: `(1 - overall_ratio) * 100`.
    pub fn bandwidth_savings(&self) -> f64 {
        (1.0 - self.average_compression_ratio) * 100.0
    }
}
1456
impl Default for CompressionStats {
    // Defaults to an empty tracker, identical to `CompressionStats::new`.
    fn default() -> Self {
        Self::new()
    }
}
1462
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_relative_eq;
    use scirs2_core::ndarray::Array1;

    #[test]
    fn test_arithmetic_averaging() {
        let mut averager: ParameterAverager<f64, scirs2_core::ndarray::Ix1> =
            ParameterAverager::new(AveragingStrategy::Arithmetic, 3);

        let node_params = vec![
            (0, vec![Array1::from_vec(vec![1.0, 2.0])]),
            (1, vec![Array1::from_vec(vec![3.0, 4.0])]),
            (2, vec![Array1::from_vec(vec![5.0, 6.0])]),
        ];

        averager
            .average_parameters(&node_params)
            .expect("unwrap failed");

        // Element-wise mean of the three nodes.
        let result = averager.get_averaged_parameters();
        assert_relative_eq!(result[0][0], 3.0, epsilon = 1e-6);
        assert_relative_eq!(result[0][1], 4.0, epsilon = 1e-6);
    }

    #[test]
    fn test_weighted_averaging() {
        let mut averager: ParameterAverager<f64, scirs2_core::ndarray::Ix1> =
            ParameterAverager::new(AveragingStrategy::WeightedByData, 2);

        let params1 = vec![Array1::from_vec(vec![2.0])];
        let params2 = vec![Array1::from_vec(vec![6.0])];
        let node_params = vec![(0, params1.clone()), (1, params2.clone())];

        averager.initialize(&params1).expect("unwrap failed");
        averager.set_node_weight(0, 0.75).expect("unwrap failed");
        averager.set_node_weight(1, 0.25).expect("unwrap failed");
        averager
            .average_parameters(&node_params)
            .expect("unwrap failed");

        // 0.75 * 2.0 + 0.25 * 6.0 = 3.0
        let result = averager.get_averaged_parameters();
        assert_relative_eq!(result[0][0], 3.0, epsilon = 1e-6);
    }

    #[test]
    fn test_momentum_averaging() {
        let mut averager: ParameterAverager<f64, scirs2_core::ndarray::Ix1> =
            ParameterAverager::new(AveragingStrategy::Momentum { momentum: 0.9 }, 2);

        let params1 = vec![Array1::from_vec(vec![1.0])];
        let params2 = vec![Array1::from_vec(vec![3.0])];

        let first_round = vec![(0, params1.clone()), (1, params2.clone())];
        averager
            .average_parameters(&first_round)
            .expect("unwrap failed");

        // With heavy momentum the first averaged value stays near zero.
        let result1 = averager.get_averaged_parameters();
        assert!(result1[0][0] >= 0.0 && result1[0][0] <= 0.5);

        // Repeated rounds pull the running average toward the node mean (2.0).
        for _ in 0..10 {
            let round = vec![(0, params1.clone()), (1, params2.clone())];
            averager.average_parameters(&round).expect("unwrap failed");
        }

        let final_result = averager.get_averaged_parameters();
        assert!(final_result[0][0] > 0.5 && final_result[0][0] < 2.5);
    }

    #[test]
    fn test_parameter_server() {
        let mut server = ParameterServer::new(AveragingStrategy::Arithmetic, 2, 2);

        let initial_params = vec![Array1::from_vec(vec![0.0, 0.0])];
        server.initialize(&initial_params).expect("unwrap failed");

        let update1 = vec![Array1::from_vec(vec![1.0, 2.0])];
        let update2 = vec![Array1::from_vec(vec![3.0, 4.0])];

        // A round completes only after both nodes have submitted.
        let ready1 = server.submit_update(0, update1).expect("unwrap failed");
        assert!(!ready1);
        let ready2 = server.submit_update(1, update2).expect("unwrap failed");
        assert!(ready2);

        // Global parameters become the arithmetic mean of both updates.
        let global_params = server.get_global_parameters();
        assert_relative_eq!(global_params[0][0], 2.0, epsilon = 1e-6);
        assert_relative_eq!(global_params[0][1], 3.0, epsilon = 1e-6);
        assert_eq!(server.current_round(), 1);
    }

    #[test]
    fn test_distributed_coordinator() {
        let mut coordinator =
            DistributedCoordinator::new(AveragingStrategy::Arithmetic, 2, 2, 10);

        let initial_params = vec![Array1::from_vec(vec![0.0])];
        coordinator
            .initialize(&initial_params)
            .expect("unwrap failed");

        for round in 1..=3 {
            let update1 = vec![Array1::from_vec(vec![round as f64])];
            let update2 = vec![Array1::from_vec(vec![(round * 2) as f64])];

            let node_updates = vec![(0, update1), (1, update2)];

            let result = coordinator
                .communication_round(node_updates)
                .expect("unwrap failed");

            // Each round advances the counter, keeps training alive,
            // and produces strictly positive averaged parameters.
            assert_eq!(result.round, round);
            assert!(result.should_continue);
            assert!(!result.converged);
            assert!(result.global_parameters[0][0] > 0.0);
        }
    }

    #[test]
    fn test_averaging_strategies() {
        // Stateless strategies: the result must lie between the node values.
        let simple_strategies = vec![
            AveragingStrategy::Arithmetic,
            AveragingStrategy::WeightedByData,
            AveragingStrategy::Federated,
        ];

        for strategy in simple_strategies {
            let mut averager: ParameterAverager<f64, scirs2_core::ndarray::Ix1> =
                ParameterAverager::new(strategy, 2);

            let node_params = vec![
                (0, vec![Array1::from_vec(vec![1.0])]),
                (1, vec![Array1::from_vec(vec![3.0])]),
            ];

            averager
                .average_parameters(&node_params)
                .expect("unwrap failed");
            let result = averager.get_averaged_parameters();
            assert!(result[0][0] >= 1.0 && result[0][0] <= 3.0);
        }

        // Stateful strategies blend from an initial state, so the first
        // result may land anywhere in [0, 3].
        let stateful_strategies = vec![
            AveragingStrategy::Momentum { momentum: 0.9 },
            AveragingStrategy::ExponentialMovingAverage { decay: 0.9 },
        ];

        for strategy in stateful_strategies {
            let mut averager: ParameterAverager<f64, scirs2_core::ndarray::Ix1> =
                ParameterAverager::new(strategy, 2);

            let node_params = vec![
                (0, vec![Array1::from_vec(vec![1.0])]),
                (1, vec![Array1::from_vec(vec![3.0])]),
            ];

            averager
                .average_parameters(&node_params)
                .expect("unwrap failed");
            let result = averager.get_averaged_parameters();
            assert!(result[0][0] >= 0.0 && result[0][0] <= 3.0);
        }
    }

    #[test]
    fn test_node_weight_validation() {
        let mut averager: ParameterAverager<f64, scirs2_core::ndarray::Ix1> =
            ParameterAverager::new(AveragingStrategy::WeightedByData, 2);

        // IDs 0 and 1 are valid for a two-node averager; 2 is out of range.
        assert!(averager.set_node_weight(0, 0.5).is_ok());
        assert!(averager.set_node_weight(1, 0.5).is_ok());
        assert!(averager.set_node_weight(2, 0.5).is_err());
    }

    #[test]
    fn test_parameter_dimension_validation() {
        let mut averager: ParameterAverager<f64, scirs2_core::ndarray::Ix1> =
            ParameterAverager::new(AveragingStrategy::Arithmetic, 2);

        // Mismatched lengths: two elements vs one.
        let node_params = vec![
            (0, vec![Array1::from_vec(vec![1.0, 2.0])]),
            (1, vec![Array1::from_vec(vec![3.0])]),
        ];

        // Either a returned Err or a panic is acceptable behavior here.
        let result = std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| {
            averager.average_parameters(&node_params)
        }));

        assert!(result.is_err() || (result.is_ok() && result.expect("unwrap failed").is_err()));
    }

    #[test]
    fn test_training_stats() {
        let mut stats = TrainingStats::new();

        // Fresh tracker is empty.
        assert_eq!(stats.num_rounds(), 0);
        assert!(stats.latest_convergence().is_none());

        let params = vec![Array1::from_vec(vec![1.0])];
        stats.record_round(1, 0.5, &params);

        assert_eq!(stats.num_rounds(), 1);
        assert_eq!(stats.latest_convergence(), Some(0.5));
        assert_eq!(stats.convergence_history(), &[0.5]);
    }

    #[test]
    fn test_gradient_compression_none() {
        let mut compressor = GradientCompressor::new(CompressionStrategy::None);

        let gradients = vec![
            Array1::from_vec(vec![1.0, 2.0, 3.0]),
            Array1::from_vec(vec![4.0, 5.0]),
        ];

        let compressed = compressor.compress(&gradients).expect("unwrap failed");
        assert_eq!(compressed.metadata.strategy, CompressionStrategy::None);
        assert_eq!(compressed.metadata.compression_ratio, 1.0);

        // The identity strategy must round-trip losslessly.
        let decompressed = compressor.decompress(&compressed).expect("unwrap failed");
        assert_eq!(decompressed.len(), 2);
        assert_eq!(
            decompressed[0].as_slice().expect("unwrap failed"),
            &[1.0, 2.0, 3.0]
        );
        assert_eq!(
            decompressed[1].as_slice().expect("unwrap failed"),
            &[4.0, 5.0]
        );
    }

    #[test]
    fn test_gradient_compression_topk() {
        let mut compressor = GradientCompressor::new(CompressionStrategy::TopK { k: 2 });

        let gradients = vec![Array1::from_vec(vec![0.1, 3.0, 0.2, 4.0, 0.05])];

        let compressed = compressor.compress(&gradients).expect("unwrap failed");
        assert!(compressed.metadata.compression_ratio < 1.0);
        assert_eq!(compressed.metadata.nnz_count, 2);

        let decompressed = compressor.decompress(&compressed).expect("unwrap failed");
        assert_eq!(decompressed.len(), 1);

        // Only the two largest-magnitude entries survive; the rest are zeroed.
        let result = &decompressed[0];
        assert_eq!(result[1], 3.0);
        assert_eq!(result[3], 4.0);
        assert_eq!(result[0], 0.0);
        assert_eq!(result[2], 0.0);
        assert_eq!(result[4], 0.0);
    }

    #[test]
    fn test_gradient_compression_threshold() {
        let mut compressor =
            GradientCompressor::new(CompressionStrategy::Threshold { threshold: 1.0 });

        let gradients = vec![Array1::from_vec(vec![0.5, 2.0, 0.8, 3.0, 0.3])];

        let compressed = compressor.compress(&gradients).expect("unwrap failed");
        assert!(compressed.metadata.compression_ratio < 1.0);
        assert_eq!(compressed.metadata.nnz_count, 2);

        // Only entries whose magnitude exceeds the threshold survive.
        let decompressed = compressor.decompress(&compressed).expect("unwrap failed");
        let result = &decompressed[0];
        assert_eq!(result[0], 0.0);
        assert_eq!(result[1], 2.0);
        assert_eq!(result[2], 0.0);
        assert_eq!(result[3], 3.0);
        assert_eq!(result[4], 0.0);
    }

    #[test]
    fn test_gradient_compression_quantization() {
        let mut compressor = GradientCompressor::new(CompressionStrategy::Quantization { bits: 8 });

        let gradients = vec![Array1::from_vec(vec![1.0, 2.0, 3.0, 4.0])];

        let compressed = compressor.compress(&gradients).expect("unwrap failed");
        assert!(compressed.metadata.compression_ratio < 1.0);

        // 8-bit quantization is lossy but stays within a small tolerance.
        let decompressed = compressor.decompress(&compressed).expect("unwrap failed");
        let result = &decompressed[0];
        assert!((result[0] - 1.0).abs() < 0.1);
        assert!((result[1] - 2.0).abs() < 0.1);
        assert!((result[2] - 3.0).abs() < 0.1);
        assert!((result[3] - 4.0).abs() < 0.1);
    }

    #[test]
    fn test_gradient_compression_randomk() {
        let mut compressor = GradientCompressor::new(CompressionStrategy::RandomK { k: 3 });

        let gradients = vec![Array1::from_vec(vec![
            1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0,
        ])];

        let compressed = compressor.compress(&gradients).expect("unwrap failed");
        assert!(compressed.metadata.compression_ratio < 1.0);
        assert_eq!(compressed.metadata.nnz_count, 3);

        // Exactly k randomly chosen entries survive decompression.
        let decompressed = compressor.decompress(&compressed).expect("unwrap failed");
        let non_zero_count = decompressed[0].iter().filter(|&&x| x != 0.0).count();
        assert_eq!(non_zero_count, 3);
    }

    #[test]
    fn test_gradient_compression_error_feedback() {
        let strategy = CompressionStrategy::ErrorFeedback {
            base_strategy: Box::new(CompressionStrategy::TopK { k: 2 }),
            error_compensation: true,
        };

        let mut compressor = GradientCompressor::new(strategy);

        let gradients = vec![Array1::from_vec(vec![1.0, 2.0, 3.0, 4.0])];

        compressor.initialize_error_state(&gradients);

        // Two consecutive rounds must both succeed while the error state
        // accumulates between them.
        let compressed1 = compressor.compress(&gradients).expect("unwrap failed");
        let decompressed1 = compressor.decompress(&compressed1).expect("unwrap failed");

        let compressed2 = compressor.compress(&gradients).expect("unwrap failed");
        let decompressed2 = compressor.decompress(&compressed2).expect("unwrap failed");

        assert_eq!(decompressed1.len(), 1);
        assert_eq!(decompressed2.len(), 1);
    }

    #[test]
    fn test_gradient_compression_clipped() {
        let strategy = CompressionStrategy::ClippedCompression {
            base_strategy: Box::new(CompressionStrategy::TopK { k: 3 }),
            clip_value: 2.5,
        };

        let mut compressor = GradientCompressor::new(strategy);

        let gradients = vec![Array1::from_vec(vec![1.0, 5.0, -3.0, 2.0])];

        let compressed = compressor.compress(&gradients).expect("unwrap failed");
        let decompressed = compressor.decompress(&compressed).expect("unwrap failed");

        // Every surviving value must be clipped into [-2.5, 2.5].
        let result = &decompressed[0];
        for &val in result.iter() {
            if val != 0.0 {
                assert!((-2.5..=2.5).contains(&val));
            }
        }
    }

    #[test]
    fn test_compression_stats() {
        let mut stats = CompressionStats::new();

        assert_eq!(stats.compressions_count, 0);
        assert_eq!(stats.overall_compression_ratio(), 0.0);

        // 1000 -> 500 bytes: per-event ratio 0.5.
        stats.record_compression(1000, 500);
        assert_eq!(stats.compressions_count, 1);
        assert_relative_eq!(stats.overall_compression_ratio(), 0.5, epsilon = 1e-6);
        assert_relative_eq!(stats.bandwidth_savings(), 50.0, epsilon = 1e-6);

        // 1000 -> 250 bytes: byte-weighted overall ratio becomes 750/2000.
        stats.record_compression(1000, 250);
        assert_eq!(stats.compressions_count, 2);
        assert_relative_eq!(stats.overall_compression_ratio(), 0.375, epsilon = 1e-6);
        assert_relative_eq!(stats.bandwidth_savings(), 62.5, epsilon = 1e-6);

        assert_relative_eq!(stats.best_compression_ratio, 0.25, epsilon = 1e-6);
        assert_relative_eq!(stats.worst_compression_ratio, 0.5, epsilon = 1e-6);
    }

    #[test]
    fn test_compression_roundtrip() {
        let strategies = vec![
            CompressionStrategy::None,
            CompressionStrategy::TopK { k: 2 },
            CompressionStrategy::RandomK { k: 2 },
            CompressionStrategy::Threshold { threshold: 1.5 },
            CompressionStrategy::Quantization { bits: 4 },
        ];

        let gradients = vec![
            Array1::from_vec(vec![1.0, 2.5, 0.5, 3.0]),
            Array1::from_vec(vec![0.1, 4.0]),
        ];

        for strategy in strategies {
            let mut compressor = GradientCompressor::new(strategy.clone());

            let compressed = compressor.compress(&gradients).expect("unwrap failed");
            let decompressed = compressor.decompress(&compressed).expect("unwrap failed");

            // Every strategy must preserve gradient count and shapes.
            assert_eq!(decompressed.len(), gradients.len());
            for (orig, decomp) in gradients.iter().zip(decompressed.iter()) {
                assert_eq!(orig.shape(), decomp.shape());
            }

            match strategy {
                // The identity strategy must round-trip exactly.
                CompressionStrategy::None => {
                    for (orig, decomp) in gradients.iter().zip(decompressed.iter()) {
                        for (&o, &d) in orig.iter().zip(decomp.iter()) {
                            assert_relative_eq!(o, d, epsilon = 1e-10);
                        }
                    }
                }
                // Lossy strategies only need to yield finite values.
                _ => {
                    for decomp in &decompressed {
                        assert!(decomp.iter().all(|&x| x.is_finite()));
                    }
                }
            }
        }
    }

    #[test]
    fn test_compression_invalid_configs() {
        // 64-bit quantization is not supported and must be rejected.
        let strategy = CompressionStrategy::Quantization { bits: 64 };
        let mut compressor = GradientCompressor::new(strategy);

        let gradients = vec![Array1::from_vec(vec![1.0, 2.0])];
        assert!(compressor.compress(&gradients).is_err());

        // Decompressing a malformed payload must fail rather than panic.
        let valid_compressor: GradientCompressor<f64, scirs2_core::ndarray::Ix1> =
            GradientCompressor::new(CompressionStrategy::None);
        let invalid_compressed = CompressedGradient {
            data: vec![1, 2, 3],
            metadata: CompressionMetadata {
                strategy: CompressionStrategy::None,
                compression_ratio: 1.0,
                nnz_count: 1,
                scale_factors: vec![],
                extra_data: vec![],
            },
            shapes: vec![vec![2]],
        };

        assert!(valid_compressor.decompress(&invalid_compressed).is_err());
    }

    #[test]
    fn test_distributed_with_compression() {
        let mut server = ParameterServer::new(AveragingStrategy::Arithmetic, 2, 2);
        let initial_params = vec![Array1::from_vec(vec![0.0, 0.0])];
        server.initialize(&initial_params).expect("unwrap failed");

        let mut compressor = GradientCompressor::new(CompressionStrategy::TopK { k: 1 });

        // Top-1 keeps only the largest-magnitude entry of each gradient:
        // [1.0, 3.0] -> [0.0, 3.0] and [2.0, 1.0] -> [2.0, 0.0].
        let gradients1 = vec![Array1::from_vec(vec![1.0, 3.0])];
        let gradients2 = vec![Array1::from_vec(vec![2.0, 1.0])];

        let compressed1 = compressor.compress(&gradients1).expect("unwrap failed");
        let compressed2 = compressor.compress(&gradients2).expect("unwrap failed");

        let decompressed1 = compressor.decompress(&compressed1).expect("unwrap failed");
        let decompressed2 = compressor.decompress(&compressed2).expect("unwrap failed");

        server
            .submit_update(0, decompressed1)
            .expect("unwrap failed");
        server
            .submit_update(1, decompressed2)
            .expect("unwrap failed");

        // Averaging the sparsified updates: ([0, 3] + [2, 0]) / 2 = [1.0, 1.5].
        let global_params = server.get_global_parameters();
        assert_relative_eq!(global_params[0][0], 1.0, epsilon = 1e-6);
        assert_relative_eq!(global_params[0][1], 1.5, epsilon = 1e-6);
    }
}