Skip to main content

optirs_core/distributed/
mod.rs

1// Distributed optimization support
2//
3// This module provides support for distributed training including parameter averaging,
4// gradient compression, and communication optimization for multi-node/multi-GPU training.
5
6use crate::error::{OptimError, Result};
7use scirs2_core::ndarray::{Array, Dimension, ScalarOperand, Zip};
8use scirs2_core::numeric::Float;
9use std::collections::HashMap;
10use std::fmt::Debug;
11
/// Parameter averaging strategies for distributed training.
///
/// Selects how a `ParameterAverager` combines per-node parameter sets into a
/// single global set on each averaging step.
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum AveragingStrategy {
    /// Simple arithmetic mean over all submitted node parameter sets
    Arithmetic,
    /// Weighted average using per-node weights (intended to reflect local data sizes)
    WeightedByData,
    /// Weighted average using per-node weights (intended to reflect computation times)
    WeightedByTime,
    /// Federated averaging (FedAvg); currently delegates to weighted averaging
    Federated,
    /// Momentum-based averaging: buffer = momentum * buffer + (1 - momentum) * round mean
    Momentum {
        /// Momentum factor; fraction of the previous buffer retained each step
        momentum: f64,
    },
    /// Exponentially weighted moving average of the per-round means
    ExponentialMovingAverage {
        /// Decay factor; fraction of the previous average retained each step
        decay: f64,
    },
}
34
/// Distributed parameter averager.
///
/// Accumulates parameter sets submitted by multiple nodes and maintains a
/// single averaged copy according to the configured [`AveragingStrategy`].
#[derive(Debug)]
pub struct ParameterAverager<A: Float, D: Dimension> {
    /// Current averaged parameters (one array per model parameter tensor)
    averaged_params: Vec<Array<A, D>>,
    /// Averaging strategy applied on each `average_parameters` call
    strategy: AveragingStrategy,
    /// Node weights for weighted averaging (initialized uniform on `initialize`)
    node_weights: HashMap<usize, A>,
    /// Number of participating nodes; valid node IDs are `0..numnodes`
    numnodes: usize,
    /// Momentum buffer; `Some` only when the strategy is `Momentum`
    momentum_buffer: Option<Vec<Array<A, D>>>,
    /// Number of completed averaging steps (also used for EMA bookkeeping)
    step_count: usize,
    /// Whether `initialize` has run and parameter shapes are known
    initialized: bool,
}
53
54impl<A: Float + ScalarOperand + Debug + Send + Sync, D: Dimension + Send + Sync>
55    ParameterAverager<A, D>
56{
57    /// Create a new parameter averager
58    pub fn new(strategy: AveragingStrategy, numnodes: usize) -> Self {
59        Self {
60            averaged_params: Vec::new(),
61            strategy,
62            node_weights: HashMap::new(),
63            numnodes,
64            momentum_buffer: None,
65            step_count: 0,
66            initialized: false,
67        }
68    }
69
70    /// Initialize averager with parameter shapes
71    pub fn initialize(&mut self, params: &[Array<A, D>]) -> Result<()> {
72        if self.initialized {
73            return Err(OptimError::InvalidConfig(
74                "Parameter averager already initialized".to_string(),
75            ));
76        }
77
78        self.averaged_params = params.to_vec();
79
80        // Initialize momentum buffer if needed
81        if matches!(self.strategy, AveragingStrategy::Momentum { .. }) {
82            self.momentum_buffer = Some(params.iter().map(|p| Array::zeros(p.raw_dim())).collect());
83        }
84
85        // Initialize uniform weights
86        let uniform_weight = A::one() / A::from(self.numnodes).expect("unwrap failed");
87        for nodeid in 0..self.numnodes {
88            self.node_weights.insert(nodeid, uniform_weight);
89        }
90
91        self.initialized = true;
92        Ok(())
93    }
94
95    /// Set weight for a specific node
96    pub fn set_node_weight(&mut self, nodeid: usize, weight: A) -> Result<()> {
97        if nodeid >= self.numnodes {
98            return Err(OptimError::InvalidConfig(format!(
99                "Node ID {} exceeds number of nodes {}",
100                nodeid, self.numnodes
101            )));
102        }
103        self.node_weights.insert(nodeid, weight);
104        Ok(())
105    }
106
107    /// Average parameters from multiple nodes
108    pub fn average_parameters(
109        &mut self,
110        nodeparameters: &[(usize, Vec<Array<A, D>>)],
111    ) -> Result<()> {
112        if !self.initialized {
113            if let Some((_, first_params)) = nodeparameters.first() {
114                self.initialize(first_params)?;
115            } else {
116                return Err(OptimError::InvalidConfig(
117                    "No _parameters provided for initialization".to_string(),
118                ));
119            }
120        }
121
122        // Validate input
123        for (nodeid, params) in nodeparameters {
124            if *nodeid >= self.numnodes {
125                return Err(OptimError::InvalidConfig(format!(
126                    "Node ID {} exceeds number of nodes {}",
127                    nodeid, self.numnodes
128                )));
129            }
130            if params.len() != self.averaged_params.len() {
131                return Err(OptimError::DimensionMismatch(format!(
132                    "Expected {} parameter arrays, got {}",
133                    self.averaged_params.len(),
134                    params.len()
135                )));
136            }
137        }
138
139        self.step_count += 1;
140
141        match self.strategy {
142            AveragingStrategy::Arithmetic => {
143                self.arithmetic_average(nodeparameters)?;
144            }
145            AveragingStrategy::WeightedByData | AveragingStrategy::WeightedByTime => {
146                self.weighted_average(nodeparameters)?;
147            }
148            AveragingStrategy::Federated => {
149                self.federated_average(nodeparameters)?;
150            }
151            AveragingStrategy::Momentum { momentum } => {
152                self.momentum_average(nodeparameters, momentum)?;
153            }
154            AveragingStrategy::ExponentialMovingAverage { decay } => {
155                self.ema_average(nodeparameters, decay)?;
156            }
157        }
158
159        Ok(())
160    }
161
162    /// Simple arithmetic averaging
163    fn arithmetic_average(&mut self, nodeparameters: &[(usize, Vec<Array<A, D>>)]) -> Result<()> {
164        // Reset averaged _parameters
165        for param in &mut self.averaged_params {
166            param.fill(A::zero());
167        }
168
169        let numnodes = A::from(nodeparameters.len()).expect("unwrap failed");
170
171        // Sum all _parameters
172        for (_node_id, params) in nodeparameters {
173            for (avg_param, param) in self.averaged_params.iter_mut().zip(params.iter()) {
174                Zip::from(avg_param).and(param).for_each(|avg, &p| {
175                    *avg = *avg + p;
176                });
177            }
178        }
179
180        // Divide by number of nodes
181        for param in &mut self.averaged_params {
182            param.mapv_inplace(|x| x / numnodes);
183        }
184
185        Ok(())
186    }
187
188    /// Weighted averaging using node weights
189    fn weighted_average(&mut self, nodeparameters: &[(usize, Vec<Array<A, D>>)]) -> Result<()> {
190        // Reset averaged _parameters
191        for param in &mut self.averaged_params {
192            param.fill(A::zero());
193        }
194
195        // Compute total weight
196        let total_weight: A = nodeparameters
197            .iter()
198            .map(|(nodeid, _)| self.node_weights.get(nodeid).copied().unwrap_or(A::zero()))
199            .fold(A::zero(), |acc, w| acc + w);
200
201        if total_weight <= A::zero() {
202            return Err(OptimError::InvalidConfig(
203                "Total node weights must be > 0".to_string(),
204            ));
205        }
206
207        // Weighted sum
208        for (nodeid, params) in nodeparameters {
209            let weight = self.node_weights.get(nodeid).copied().unwrap_or(A::zero()) / total_weight;
210
211            for (avg_param, param) in self.averaged_params.iter_mut().zip(params.iter()) {
212                Zip::from(avg_param).and(param).for_each(|avg, &p| {
213                    *avg = *avg + weight * p;
214                });
215            }
216        }
217
218        Ok(())
219    }
220
221    /// Federated averaging (similar to weighted but with special handling)
222    fn federated_average(&mut self, nodeparameters: &[(usize, Vec<Array<A, D>>)]) -> Result<()> {
223        // For simplicity, use weighted averaging with data-based weights
224        // In practice, this would consider local dataset sizes and update frequencies
225        self.weighted_average(nodeparameters)
226    }
227
228    /// Momentum-based averaging
229    fn momentum_average(
230        &mut self,
231        nodeparameters: &[(usize, Vec<Array<A, D>>)],
232        momentum: f64,
233    ) -> Result<()> {
234        let momentum_factor = A::from(momentum).expect("unwrap failed");
235        let one_minus_momentum = A::one() - momentum_factor;
236
237        // First compute arithmetic average of incoming _parameters
238        let mut current_average: Vec<Array<A, D>> = self
239            .averaged_params
240            .iter()
241            .map(|param| Array::zeros(param.raw_dim()))
242            .collect();
243
244        let numnodes = A::from(nodeparameters.len()).expect("unwrap failed");
245        for (_node_id, params) in nodeparameters {
246            for (avg_param, param) in current_average.iter_mut().zip(params.iter()) {
247                Zip::from(avg_param).and(param).for_each(|avg, &p| {
248                    *avg = *avg + p / numnodes;
249                });
250            }
251        }
252
253        // Apply momentum update
254        if let Some(ref mut momentum_buf) = self.momentum_buffer {
255            for ((avg_param, current_param), momentum_param) in self
256                .averaged_params
257                .iter_mut()
258                .zip(current_average.iter())
259                .zip(momentum_buf.iter_mut())
260            {
261                // Update momentum buffer first
262                Zip::from(&mut *momentum_param)
263                    .and(current_param)
264                    .for_each(|mom, &curr| {
265                        *mom = momentum_factor * *mom + one_minus_momentum * curr;
266                    });
267
268                // Copy momentum buffer to averaged params
269                avg_param.assign(&*momentum_param);
270            }
271        }
272
273        Ok(())
274    }
275
276    /// Exponential moving average
277    fn ema_average(
278        &mut self,
279        nodeparameters: &[(usize, Vec<Array<A, D>>)],
280        decay: f64,
281    ) -> Result<()> {
282        let decay_factor = A::from(decay).expect("unwrap failed");
283        let one_minus_decay = A::one() - decay_factor;
284
285        // First compute arithmetic average of incoming _parameters
286        let mut current_average: Vec<Array<A, D>> = self
287            .averaged_params
288            .iter()
289            .map(|param| Array::zeros(param.raw_dim()))
290            .collect();
291
292        let numnodes = A::from(nodeparameters.len()).expect("unwrap failed");
293        for (_node_id, params) in nodeparameters {
294            for (avg_param, param) in current_average.iter_mut().zip(params.iter()) {
295                Zip::from(avg_param).and(param).for_each(|avg, &p| {
296                    *avg = *avg + p / numnodes;
297                });
298            }
299        }
300
301        // Apply EMA update
302        for (avg_param, current_param) in
303            self.averaged_params.iter_mut().zip(current_average.iter())
304        {
305            Zip::from(avg_param)
306                .and(current_param)
307                .for_each(|avg, &curr| {
308                    *avg = decay_factor * *avg + one_minus_decay * curr;
309                });
310        }
311
312        Ok(())
313    }
314
315    /// Get current averaged parameters
316    pub fn get_averaged_parameters(&self) -> &[Array<A, D>] {
317        &self.averaged_params
318    }
319
320    /// Get cloned averaged parameters
321    pub fn get_averaged_parameters_cloned(&self) -> Vec<Array<A, D>> {
322        self.averaged_params.clone()
323    }
324
325    /// Reset averager state
326    pub fn reset(&mut self) {
327        self.step_count = 0;
328        for param in &mut self.averaged_params {
329            param.fill(A::zero());
330        }
331        if let Some(ref mut momentum_buf) = self.momentum_buffer {
332            for buf in momentum_buf {
333                buf.fill(A::zero());
334            }
335        }
336    }
337
338    /// Get step count
339    pub fn step_count(&self) -> usize {
340        self.step_count
341    }
342
343    /// Get number of nodes
344    pub fn numnodes(&self) -> usize {
345        self.numnodes
346    }
347
348    /// Get averaging strategy
349    pub fn strategy(&self) -> AveragingStrategy {
350        self.strategy
351    }
352
353    /// Check if initialized
354    pub fn is_initialized(&self) -> bool {
355        self.initialized
356    }
357}
358
/// Synchronous parameter server for distributed training.
///
/// Collects per-node parameter updates, and once enough nodes have submitted
/// (`expected_updates_per_round`), aggregates them via a
/// [`ParameterAverager`] into the global parameters.
#[derive(Debug)]
pub struct ParameterServer<A: Float, D: Dimension> {
    /// Parameter averager performing the actual aggregation
    averager: ParameterAverager<A, D>,
    /// Current global parameters (refreshed after each aggregation)
    global_parameters: Vec<Array<A, D>>,
    /// Total submissions received per node (across all rounds)
    update_counts: HashMap<usize, usize>,
    /// Number of pending node updates that triggers an aggregation
    expected_updates_per_round: usize,
    /// Number of completed aggregation rounds
    current_round: usize,
    /// Pending (not yet aggregated) updates, keyed by node ID
    pending_updates: HashMap<usize, Vec<Array<A, D>>>,
}
375
376impl<A: Float + ScalarOperand + Debug + Send + Sync, D: Dimension + Send + Sync>
377    ParameterServer<A, D>
378{
379    /// Create a new parameter server
380    pub fn new(
381        strategy: AveragingStrategy,
382        numnodes: usize,
383        expected_updates_per_round: usize,
384    ) -> Self {
385        Self {
386            averager: ParameterAverager::new(strategy, numnodes),
387            global_parameters: Vec::new(),
388            update_counts: HashMap::new(),
389            expected_updates_per_round,
390            current_round: 0,
391            pending_updates: HashMap::new(),
392        }
393    }
394
395    /// Initialize with global parameters
396    pub fn initialize(&mut self, initialparams: &[Array<A, D>]) -> Result<()> {
397        self.averager.initialize(initialparams)?;
398        self.global_parameters = initialparams.to_vec();
399
400        // Initialize update counts
401        for nodeid in 0..self.averager.numnodes() {
402            self.update_counts.insert(nodeid, 0);
403        }
404
405        Ok(())
406    }
407
408    /// Submit parameter update from a node
409    pub fn submit_update(&mut self, nodeid: usize, parameters: Vec<Array<A, D>>) -> Result<bool> {
410        if nodeid >= self.averager.numnodes() {
411            return Err(OptimError::InvalidConfig(format!(
412                "Node ID {} exceeds number of nodes {}",
413                nodeid,
414                self.averager.numnodes()
415            )));
416        }
417
418        // Store the update
419        self.pending_updates.insert(nodeid, parameters);
420        *self.update_counts.entry(nodeid).or_insert(0) += 1;
421
422        // Check if we have enough updates for this round
423        let ready_for_aggregation = self.pending_updates.len() >= self.expected_updates_per_round;
424
425        if ready_for_aggregation {
426            self.aggregate_and_update()?;
427        }
428
429        Ok(ready_for_aggregation)
430    }
431
432    /// Force aggregation with current pending updates
433    pub fn force_aggregation(&mut self) -> Result<()> {
434        if !self.pending_updates.is_empty() {
435            self.aggregate_and_update()?;
436        }
437        Ok(())
438    }
439
440    /// Internal aggregation and update
441    fn aggregate_and_update(&mut self) -> Result<()> {
442        // Convert pending updates to the format expected by averager
443        let node_params: Vec<(usize, Vec<Array<A, D>>)> = self.pending_updates.drain().collect();
444
445        // Perform averaging
446        self.averager.average_parameters(&node_params)?;
447
448        // Update global parameters
449        self.global_parameters = self.averager.get_averaged_parameters_cloned();
450
451        // Increment round
452        self.current_round += 1;
453
454        Ok(())
455    }
456
457    /// Get current global parameters
458    pub fn get_global_parameters(&self) -> &[Array<A, D>] {
459        &self.global_parameters
460    }
461
462    /// Get cloned global parameters
463    pub fn get_global_parameters_cloned(&self) -> Vec<Array<A, D>> {
464        self.global_parameters.clone()
465    }
466
467    /// Get current round number
468    pub fn current_round(&self) -> usize {
469        self.current_round
470    }
471
472    /// Get update count for a node
473    pub fn get_update_count(&self, nodeid: usize) -> usize {
474        self.update_counts.get(&nodeid).copied().unwrap_or(0)
475    }
476
477    /// Get number of pending updates
478    pub fn pending_updates_count(&self) -> usize {
479        self.pending_updates.len()
480    }
481
482    /// Set node weight for weighted averaging
483    pub fn set_node_weight(&mut self, nodeid: usize, weight: A) -> Result<()> {
484        self.averager.set_node_weight(nodeid, weight)
485    }
486
487    /// Reset server state
488    pub fn reset(&mut self) {
489        self.averager.reset();
490        self.update_counts.clear();
491        self.pending_updates.clear();
492        self.current_round = 0;
493
494        for nodeid in 0..self.averager.numnodes() {
495            self.update_counts.insert(nodeid, 0);
496        }
497    }
498}
499
/// Distributed training coordinator.
///
/// Wraps a [`ParameterServer`] and drives communication rounds, tracking
/// convergence and stopping conditions.
#[derive(Debug)]
pub struct DistributedCoordinator<A: Float, D: Dimension> {
    /// Underlying parameter server performing aggregation
    parameter_server: ParameterServer<A, D>,
    /// Communication rounds completed so far
    communication_rounds: usize,
    /// Relative parameter-change threshold below which training is converged
    convergence_threshold: A,
    /// Maximum rounds before `should_continue` is forced to false
    max_rounds: usize,
    /// Per-round training statistics (convergence history, snapshots)
    training_stats: TrainingStats<A>,
}
514
515impl<A: Float + ScalarOperand + Debug + Send + Sync, D: Dimension + Send + Sync>
516    DistributedCoordinator<A, D>
517{
518    /// Create a new distributed coordinator
519    pub fn new(
520        strategy: AveragingStrategy,
521        numnodes: usize,
522        expected_updates_per_round: usize,
523        max_rounds: usize,
524    ) -> Self {
525        Self {
526            parameter_server: ParameterServer::new(strategy, numnodes, expected_updates_per_round),
527            communication_rounds: 0,
528            convergence_threshold: A::from(1e-6).expect("unwrap failed"),
529            max_rounds,
530            training_stats: TrainingStats::new(),
531        }
532    }
533
534    /// Initialize coordinator
535    pub fn initialize(&mut self, initialparams: &[Array<A, D>]) -> Result<()> {
536        self.parameter_server.initialize(initialparams)?;
537        self.training_stats
538            .record_round(0, A::zero(), initialparams);
539        Ok(())
540    }
541
542    /// Execute a communication round
543    pub fn communication_round(
544        &mut self,
545        node_updates: Vec<(usize, Vec<Array<A, D>>)>,
546    ) -> Result<CommunicationResult<A, D>> {
547        let mut aggregated = false;
548
549        // Submit all _updates
550        for (nodeid, params) in node_updates {
551            aggregated = self.parameter_server.submit_update(nodeid, params)? || aggregated;
552        }
553
554        // Force aggregation if not done automatically
555        if !aggregated {
556            self.parameter_server.force_aggregation()?;
557            aggregated = true;
558        }
559
560        if aggregated {
561            self.communication_rounds += 1;
562
563            // Check convergence
564            let currentparams = self.parameter_server.get_global_parameters();
565            let convergence_metric = self.compute_convergence_metric(currentparams);
566
567            self.training_stats.record_round(
568                self.communication_rounds,
569                convergence_metric,
570                currentparams,
571            );
572
573            let converged = convergence_metric < self.convergence_threshold;
574            let max_rounds_reached = self.communication_rounds >= self.max_rounds;
575
576            Ok(CommunicationResult {
577                round: self.communication_rounds,
578                global_parameters: self.parameter_server.get_global_parameters_cloned(),
579                converged,
580                should_continue: !converged && !max_rounds_reached,
581                convergence_metric,
582                stats: self.training_stats.clone(),
583            })
584        } else {
585            Ok(CommunicationResult {
586                round: self.communication_rounds,
587                global_parameters: self.parameter_server.get_global_parameters_cloned(),
588                converged: false,
589                should_continue: true,
590                convergence_metric: A::infinity(),
591                stats: self.training_stats.clone(),
592            })
593        }
594    }
595
596    /// Set convergence threshold
597    pub fn set_convergence_threshold(&mut self, threshold: A) {
598        self.convergence_threshold = threshold;
599    }
600
601    /// Get parameter server reference
602    pub fn parameter_server(&self) -> &ParameterServer<A, D> {
603        &self.parameter_server
604    }
605
606    /// Get mutable parameter server reference
607    pub fn parameter_server_mut(&mut self) -> &mut ParameterServer<A, D> {
608        &mut self.parameter_server
609    }
610
611    /// Compute convergence metric (parameter change magnitude)
612    fn compute_convergence_metric(&self, currentparams: &[Array<A, D>]) -> A {
613        if let Some(prev_params) = self.training_stats.get_previous_parameters::<D>() {
614            let mut total_change = A::zero();
615            let mut total_norm = A::zero();
616
617            for (curr, prev) in currentparams.iter().zip(prev_params.iter()) {
618                for (&c, &p) in curr.iter().zip(prev.iter()) {
619                    let diff = c - p;
620                    total_change = total_change + diff * diff;
621                    total_norm = total_norm + c * c;
622                }
623            }
624
625            if total_norm > A::zero() {
626                (total_change / total_norm).sqrt()
627            } else {
628                A::zero()
629            }
630        } else {
631            A::infinity()
632        }
633    }
634}
635
/// Result of a communication round.
///
/// Returned by [`DistributedCoordinator::communication_round`]; carries the
/// updated global state plus the stopping signals.
#[derive(Debug, Clone)]
pub struct CommunicationResult<A: Float, D: Dimension> {
    /// Round number at the time this result was produced
    pub round: usize,
    /// Updated global parameters (cloned snapshot)
    pub global_parameters: Vec<Array<A, D>>,
    /// Whether the convergence metric fell below the threshold
    pub converged: bool,
    /// Whether training should continue (not converged and under max rounds)
    pub should_continue: bool,
    /// Convergence metric value; infinity when it could not be computed
    pub convergence_metric: A,
    /// Snapshot of the training statistics at this round
    pub stats: TrainingStats<A>,
}
652
/// Training statistics for distributed training.
#[derive(Debug, Clone)]
pub struct TrainingStats<A: Float> {
    /// Convergence metric per recorded round, oldest first
    convergence_history: Vec<A>,
    /// Round index recorded alongside each convergence entry
    round_times: Vec<usize>,
    /// Previous parameters for convergence computation.
    /// NOTE(review): currently a placeholder (one zero byte per parameter
    /// array, see `record_round`); no real serialization is performed.
    previous_parameters: Option<Vec<u8>>, // Serialized for memory efficiency
}
663
664impl<A: Float + Send + Sync> TrainingStats<A> {
665    /// Create new training stats
666    pub fn new() -> Self {
667        Self {
668            convergence_history: Vec::new(),
669            round_times: Vec::new(),
670            previous_parameters: None,
671        }
672    }
673
674    /// Record a training round
675    pub fn record_round<D: Dimension>(
676        &mut self,
677        round: usize,
678        convergence_metric: A,
679        parameters: &[Array<A, D>],
680    ) {
681        self.convergence_history.push(convergence_metric);
682        self.round_times.push(round);
683
684        // Store simplified representation of parameters for convergence computation
685        // In practice, you might want a more sophisticated serialization
686        self.previous_parameters = Some(vec![0u8; parameters.len()]);
687    }
688
689    /// Get convergence history
690    pub fn convergence_history(&self) -> &[A] {
691        &self.convergence_history
692    }
693
694    /// Get latest convergence metric
695    pub fn latest_convergence(&self) -> Option<A> {
696        self.convergence_history.last().copied()
697    }
698
699    /// Get number of rounds
700    pub fn num_rounds(&self) -> usize {
701        self.round_times.len()
702    }
703
704    /// Get previous parameters (simplified)
705    fn get_previous_parameters<D: Dimension>(&self) -> Option<Vec<Array<A, D>>> {
706        // Simplified implementation - in practice you'd deserialize properly
707        None
708    }
709}
710
impl<A: Float + Send + Sync> Default for TrainingStats<A> {
    /// Equivalent to [`TrainingStats::new`]: empty history, no stored snapshot.
    fn default() -> Self {
        Self::new()
    }
}
716
/// Gradient compression strategies for communication optimization.
///
/// Used by `GradientCompressor` to reduce the bytes exchanged between nodes.
#[derive(Debug, Clone, PartialEq)]
pub enum CompressionStrategy {
    /// No compression (raw passthrough serialization)
    None,
    /// Top-K sparsification (keep only the K largest-magnitude gradients)
    TopK {
        /// Number of top gradients to keep
        k: usize,
    },
    /// Random-K sparsification (keep K randomly chosen gradients)
    RandomK {
        /// Number of random gradients to keep
        k: usize,
    },
    /// Threshold-based sparsification (keep gradients above the threshold)
    Threshold {
        /// Threshold value for gradient magnitude
        threshold: f64,
    },
    /// Quantization to a reduced bit width
    Quantization {
        /// Number of bits per quantized value
        bits: u8,
    },
    /// Error feedback compression: wraps a base strategy and accumulates the
    /// compression error so it can be re-injected on later steps
    ErrorFeedback {
        /// Base compression strategy to apply
        base_strategy: Box<CompressionStrategy>,
        /// Whether to enable error compensation
        error_compensation: bool,
    },
    /// Gradient clipping applied before a base compression strategy
    ClippedCompression {
        /// Base compression strategy to apply after clipping
        base_strategy: Box<CompressionStrategy>,
        /// Symmetric clip bound: gradients are clamped to [-clip_value, clip_value]
        clip_value: f64,
    },
}
757
/// Compressed gradient representation.
///
/// Produced by `GradientCompressor::compress`; carries everything needed to
/// reconstruct the gradients via `decompress`.
#[derive(Debug, Clone)]
pub struct CompressedGradient<A: Float> {
    /// Compressed payload bytes (format depends on the strategy)
    pub data: Vec<u8>,
    /// Metadata describing how `data` was produced
    pub metadata: CompressionMetadata<A>,
    /// Original shape of each gradient array, for reconstruction
    pub shapes: Vec<Vec<usize>>,
}
768
/// Compression metadata accompanying a [`CompressedGradient`].
#[derive(Debug, Clone)]
pub struct CompressionMetadata<A: Float> {
    /// Compression strategy used to produce the payload
    pub strategy: CompressionStrategy,
    /// Compression ratio achieved (original size / compressed size — TODO confirm orientation)
    pub compression_ratio: f64,
    /// Number of non-zero elements retained (for sparse methods)
    pub nnz_count: usize,
    /// Quantization scale factors (for quantization methods)
    pub scale_factors: Vec<A>,
    /// Additional strategy-specific data
    pub extra_data: Vec<u8>,
}
783
/// Gradient compression engine.
///
/// Applies a [`CompressionStrategy`] to gradient arrays, optionally keeping
/// an error-feedback state so compression error is carried into later steps.
#[derive(Debug)]
pub struct GradientCompressor<A: Float, D: Dimension> {
    /// Compression strategy applied by `compress`
    strategy: CompressionStrategy,
    /// Accumulated compression error per gradient array; `Some` only after
    /// `initialize_error_state` is called
    error_state: Option<Vec<Array<A, D>>>,
    /// Compression statistics (type defined elsewhere in this module — records
    /// original/compressed sizes per `compress` call)
    stats: CompressionStats,
}
794
795impl<A: Float + ScalarOperand + Debug + Send + Sync, D: Dimension + Send + Sync>
796    GradientCompressor<A, D>
797{
798    /// Create a new gradient compressor
799    pub fn new(strategy: CompressionStrategy) -> Self {
800        Self {
801            strategy,
802            error_state: None,
803            stats: CompressionStats::new(),
804        }
805    }
806
807    /// Initialize error state for error feedback compression
808    pub fn initialize_error_state(&mut self, gradientshapes: &[Array<A, D>]) {
809        self.error_state = Some(
810            gradientshapes
811                .iter()
812                .map(|g| Array::zeros(g.raw_dim()))
813                .collect(),
814        );
815    }
816
817    /// Compress gradients
818    pub fn compress(&mut self, gradients: &[Array<A, D>]) -> Result<CompressedGradient<A>> {
819        // Apply error feedback if enabled
820        let mut working_gradients: Vec<Array<A, D>> =
821            if let Some(ref mut error_state) = self.error_state {
822                gradients
823                    .iter()
824                    .zip(error_state.iter())
825                    .map(|(grad, error)| grad + error)
826                    .collect()
827            } else {
828                gradients.to_vec()
829            };
830
831        let (compressed_data, metadata) = match &self.strategy {
832            CompressionStrategy::None => self.compress_none(&working_gradients)?,
833            CompressionStrategy::TopK { k } => self.compress_topk(&working_gradients, *k)?,
834            CompressionStrategy::RandomK { k } => self.compress_randomk(&working_gradients, *k)?,
835            CompressionStrategy::Threshold { threshold } => self.compress_threshold(
836                &working_gradients,
837                A::from(*threshold).expect("unwrap failed"),
838            )?,
839            CompressionStrategy::Quantization { bits } => {
840                self.compress_quantization(&working_gradients, *bits)?
841            }
842            CompressionStrategy::ErrorFeedback { base_strategy, .. } => {
843                // Recursively apply base strategy
844                let mut temp_compressor = GradientCompressor::new((**base_strategy).clone());
845                let compressed = temp_compressor.compress(&working_gradients)?;
846                let decompressed = temp_compressor.decompress(&compressed)?;
847
848                // Update error state
849                if let Some(ref mut error_state) = self.error_state {
850                    for ((original, decompressed), error) in gradients
851                        .iter()
852                        .zip(decompressed.iter())
853                        .zip(error_state.iter_mut())
854                    {
855                        *error = original - decompressed;
856                    }
857                }
858
859                (compressed.data, compressed.metadata)
860            }
861            CompressionStrategy::ClippedCompression {
862                base_strategy,
863                clip_value,
864            } => {
865                // Clip gradients first
866                let clip_val = A::from(*clip_value).expect("unwrap failed");
867                for grad in &mut working_gradients {
868                    grad.mapv_inplace(|x| {
869                        if x > clip_val {
870                            clip_val
871                        } else if x < -clip_val {
872                            -clip_val
873                        } else {
874                            x
875                        }
876                    });
877                }
878
879                // Apply base compression strategy
880                let mut temp_compressor = GradientCompressor::new((**base_strategy).clone());
881                let compressed = temp_compressor.compress(&working_gradients)?;
882                (compressed.data, compressed.metadata)
883            }
884        };
885
886        // Collect shape information
887        let shapes = gradients.iter().map(|g| g.shape().to_vec()).collect();
888
889        let result = CompressedGradient {
890            data: compressed_data,
891            metadata,
892            shapes,
893        };
894
895        // Update statistics
896        let original_size = self.calculate_size(gradients);
897        let compressed_size = result.data.len();
898        self.stats
899            .record_compression(original_size, compressed_size);
900
901        Ok(result)
902    }
903
904    /// Decompress gradients
905    pub fn decompress(&self, compressed: &CompressedGradient<A>) -> Result<Vec<Array<A, D>>> {
906        match &compressed.metadata.strategy {
907            CompressionStrategy::None => self.decompress_none(compressed),
908            CompressionStrategy::TopK { .. } => self.decompress_sparse(compressed),
909            CompressionStrategy::RandomK { .. } => self.decompress_sparse(compressed),
910            CompressionStrategy::Threshold { .. } => self.decompress_sparse(compressed),
911            CompressionStrategy::Quantization { bits } => {
912                self.decompress_quantization(compressed, *bits)
913            }
914            CompressionStrategy::ErrorFeedback { base_strategy, .. } => {
915                let temp_compressor = GradientCompressor::new((**base_strategy).clone());
916                temp_compressor.decompress(compressed)
917            }
918            CompressionStrategy::ClippedCompression { base_strategy, .. } => {
919                let temp_compressor = GradientCompressor::new((**base_strategy).clone());
920                temp_compressor.decompress(compressed)
921            }
922        }
923    }
924
925    /// Compress with no compression (passthrough)
926    fn compress_none(
927        &self,
928        gradients: &[Array<A, D>],
929    ) -> Result<(Vec<u8>, CompressionMetadata<A>)> {
930        let mut data = Vec::new();
931
932        // Simple serialization: store all gradient values sequentially
933        for grad in gradients {
934            for &val in grad.iter() {
935                data.extend_from_slice(&val.to_f64().expect("unwrap failed").to_le_bytes());
936            }
937        }
938
939        let metadata = CompressionMetadata {
940            strategy: CompressionStrategy::None,
941            compression_ratio: 1.0,
942            nnz_count: gradients.iter().map(|g| g.len()).sum(),
943            scale_factors: Vec::new(),
944            extra_data: Vec::new(),
945        };
946
947        Ok((data, metadata))
948    }
949
950    /// Compress using Top-K sparsification
951    fn compress_topk(
952        &self,
953        gradients: &[Array<A, D>],
954        k: usize,
955    ) -> Result<(Vec<u8>, CompressionMetadata<A>)> {
956        let mut indices = Vec::new();
957        let mut values = Vec::new();
958        let mut total_elements = 0;
959
960        for (grad_idx, grad) in gradients.iter().enumerate() {
961            total_elements += grad.len();
962
963            // Collect (value, index) pairs
964            let mut value_indices: Vec<(A, usize)> = grad
965                .iter()
966                .enumerate()
967                .map(|(i, &val)| (val.abs(), i))
968                .collect();
969
970            // Sort by absolute value (descending)
971            value_indices.sort_by(|a, b| b.0.partial_cmp(&a.0).expect("unwrap failed"));
972
973            // Take top k elements
974            let k_local = k.min(value_indices.len());
975            for (_, orig_idx) in value_indices.iter().take(k_local) {
976                indices.push((grad_idx as u32, *orig_idx as u32));
977                values.push(grad.iter().nth(*orig_idx).copied().expect("unwrap failed"));
978            }
979        }
980
981        // Serialize sparse representation
982        let mut data = Vec::new();
983
984        // Store number of sparse elements
985        data.extend_from_slice(&(indices.len() as u32).to_le_bytes());
986
987        // Store indices and values
988        for ((grad_idx, elem_idx), value) in indices.iter().zip(values.iter()) {
989            data.extend_from_slice(&grad_idx.to_le_bytes());
990            data.extend_from_slice(&elem_idx.to_le_bytes());
991            data.extend_from_slice(&value.to_f64().expect("unwrap failed").to_le_bytes());
992        }
993
994        let metadata = CompressionMetadata {
995            strategy: CompressionStrategy::TopK { k },
996            compression_ratio: data.len() as f64 / (total_elements * 8) as f64,
997            nnz_count: indices.len(),
998            scale_factors: Vec::new(),
999            extra_data: Vec::new(),
1000        };
1001
1002        Ok((data, metadata))
1003    }
1004
1005    /// Compress using Random-K sparsification
1006    fn compress_randomk(
1007        &self,
1008        gradients: &[Array<A, D>],
1009        k: usize,
1010    ) -> Result<(Vec<u8>, CompressionMetadata<A>)> {
1011        let mut indices = Vec::new();
1012        let mut values = Vec::new();
1013        let mut total_elements = 0;
1014
1015        for (grad_idx, grad) in gradients.iter().enumerate() {
1016            total_elements += grad.len();
1017
1018            // Random sampling of k indices
1019            let k_local = k.min(grad.len());
1020            let mut selected_indices: Vec<usize> = (0..grad.len()).collect();
1021
1022            // Simple random selection (deterministic for testing)
1023            for i in 0..k_local {
1024                let swap_idx = i + ((grad_idx + i) % (grad.len() - i));
1025                selected_indices.swap(i, swap_idx);
1026            }
1027
1028            for &idx in selected_indices.iter().take(k_local) {
1029                indices.push((grad_idx as u32, idx as u32));
1030                values.push(grad.iter().nth(idx).copied().expect("unwrap failed"));
1031            }
1032        }
1033
1034        // Serialize sparse representation (same format as Top-K)
1035        let mut data = Vec::new();
1036        data.extend_from_slice(&(indices.len() as u32).to_le_bytes());
1037
1038        for ((grad_idx, elem_idx), value) in indices.iter().zip(values.iter()) {
1039            data.extend_from_slice(&grad_idx.to_le_bytes());
1040            data.extend_from_slice(&elem_idx.to_le_bytes());
1041            data.extend_from_slice(&value.to_f64().expect("unwrap failed").to_le_bytes());
1042        }
1043
1044        let metadata = CompressionMetadata {
1045            strategy: CompressionStrategy::RandomK { k },
1046            compression_ratio: data.len() as f64 / (total_elements * 8) as f64,
1047            nnz_count: indices.len(),
1048            scale_factors: Vec::new(),
1049            extra_data: Vec::new(),
1050        };
1051
1052        Ok((data, metadata))
1053    }
1054
1055    /// Compress using threshold-based sparsification
1056    fn compress_threshold(
1057        &self,
1058        gradients: &[Array<A, D>],
1059        threshold: A,
1060    ) -> Result<(Vec<u8>, CompressionMetadata<A>)> {
1061        let mut indices = Vec::new();
1062        let mut values = Vec::new();
1063        let mut total_elements = 0;
1064
1065        for (grad_idx, grad) in gradients.iter().enumerate() {
1066            total_elements += grad.len();
1067
1068            for (elem_idx, &val) in grad.iter().enumerate() {
1069                if val.abs() > threshold {
1070                    indices.push((grad_idx as u32, elem_idx as u32));
1071                    values.push(val);
1072                }
1073            }
1074        }
1075
1076        // Serialize sparse representation
1077        let mut data = Vec::new();
1078        data.extend_from_slice(&(indices.len() as u32).to_le_bytes());
1079
1080        for ((grad_idx, elem_idx), value) in indices.iter().zip(values.iter()) {
1081            data.extend_from_slice(&grad_idx.to_le_bytes());
1082            data.extend_from_slice(&elem_idx.to_le_bytes());
1083            data.extend_from_slice(&value.to_f64().expect("unwrap failed").to_le_bytes());
1084        }
1085
1086        let metadata = CompressionMetadata {
1087            strategy: CompressionStrategy::Threshold {
1088                threshold: threshold.to_f64().expect("unwrap failed"),
1089            },
1090            compression_ratio: data.len() as f64 / (total_elements * 8) as f64,
1091            nnz_count: indices.len(),
1092            scale_factors: Vec::new(),
1093            extra_data: Vec::new(),
1094        };
1095
1096        Ok((data, metadata))
1097    }
1098
1099    /// Compress using quantization
1100    fn compress_quantization(
1101        &self,
1102        gradients: &[Array<A, D>],
1103        bits: u8,
1104    ) -> Result<(Vec<u8>, CompressionMetadata<A>)> {
1105        if bits > 32 {
1106            return Err(OptimError::InvalidConfig(
1107                "Quantization bits must be <= 32".to_string(),
1108            ));
1109        }
1110
1111        let mut data = Vec::new();
1112        let mut scale_factors = Vec::new();
1113        let levels = (1u64 << bits) - 1;
1114
1115        for grad in gradients {
1116            // Find min and max values for this gradient
1117            let min_val = grad.iter().fold(A::infinity(), |acc, &x| acc.min(x));
1118            let max_val = grad.iter().fold(A::neg_infinity(), |acc, &x| acc.max(x));
1119
1120            let range = max_val - min_val;
1121            let scale = if range > A::zero() {
1122                range / A::from(levels).expect("unwrap failed")
1123            } else {
1124                A::one()
1125            };
1126
1127            scale_factors.push(scale);
1128
1129            // Quantize each value
1130            for &val in grad.iter() {
1131                let normalized = (val - min_val) / scale;
1132                let quantized = normalized.to_u64().expect("unwrap failed").min(levels) as u32;
1133
1134                // Store quantized value
1135                match bits {
1136                    1..=8 => data.push(quantized as u8),
1137                    9..=16 => data.extend_from_slice(&(quantized as u16).to_le_bytes()),
1138                    17..=32 => data.extend_from_slice(&quantized.to_le_bytes()),
1139                    _ => unreachable!(),
1140                }
1141            }
1142
1143            // Store min value for reconstruction
1144            data.extend_from_slice(&min_val.to_f64().expect("unwrap failed").to_le_bytes());
1145        }
1146
1147        let total_elements: usize = gradients.iter().map(|g| g.len()).sum();
1148        let metadata = CompressionMetadata {
1149            strategy: CompressionStrategy::Quantization { bits },
1150            compression_ratio: data.len() as f64 / (total_elements * 8) as f64,
1151            nnz_count: total_elements,
1152            scale_factors,
1153            extra_data: Vec::new(),
1154        };
1155
1156        Ok((data, metadata))
1157    }
1158
1159    /// Decompress uncompressed data
1160    fn decompress_none(&self, compressed: &CompressedGradient<A>) -> Result<Vec<Array<A, D>>> {
1161        let mut result = Vec::new();
1162        let mut data_offset = 0;
1163
1164        for shape in &compressed.shapes {
1165            let num_elements: usize = shape.iter().product();
1166            let mut values = Vec::with_capacity(num_elements);
1167
1168            for _ in 0..num_elements {
1169                if data_offset + 8 > compressed.data.len() {
1170                    return Err(OptimError::InvalidConfig(
1171                        "Insufficient data for decompression".to_string(),
1172                    ));
1173                }
1174
1175                let bytes = &compressed.data[data_offset..data_offset + 8];
1176                let value = f64::from_le_bytes(bytes.try_into().expect("unwrap failed"));
1177                values.push(A::from(value).expect("unwrap failed"));
1178                data_offset += 8;
1179            }
1180
1181            // Create a dynamic array first, then convert to the target dimension type
1182            let dynamic_array = Array::from_shape_vec(shape.as_slice(), values).map_err(|_| {
1183                OptimError::InvalidConfig("Invalid shape for reconstruction".to_string())
1184            })?;
1185            let array = dynamic_array.into_dimensionality::<D>().map_err(|_| {
1186                OptimError::InvalidConfig("Dimension conversion failed".to_string())
1187            })?;
1188            result.push(array);
1189        }
1190
1191        Ok(result)
1192    }
1193
    /// Decompress sparse representation
    ///
    /// Wire format: a `u32` count header followed by 16-byte records of
    /// `(grad_idx: u32, elem_idx: u32, value: f64)`, all little-endian.
    /// Positions not present in the payload remain zero. Used by TopK,
    /// RandomK, and Threshold strategies, which share this format.
    fn decompress_sparse(&self, compressed: &CompressedGradient<A>) -> Result<Vec<Array<A, D>>> {
        let mut result = Vec::new();

        // Initialize zero arrays (one per recorded shape); sparse records
        // overwrite the non-zero positions below.
        for shape in &compressed.shapes {
            let dynamic_array = Array::zeros(shape.as_slice());
            let array = dynamic_array.into_dimensionality::<D>().map_err(|_| {
                OptimError::InvalidConfig("Dimension conversion failed for zero array".to_string())
            })?;
            result.push(array);
        }

        // Read number of sparse elements (u32 header)
        if compressed.data.len() < 4 {
            return Err(OptimError::InvalidConfig(
                "Invalid compressed data format".to_string(),
            ));
        }

        let num_elements =
            u32::from_le_bytes(compressed.data[0..4].try_into().expect("unwrap failed")) as usize;
        let mut data_offset = 4;

        // Restore sparse elements; each record is 4 + 4 + 8 = 16 bytes.
        for _ in 0..num_elements {
            if data_offset + 16 > compressed.data.len() {
                return Err(OptimError::InvalidConfig(
                    "Insufficient data for sparse decompression".to_string(),
                ));
            }

            let grad_idx = u32::from_le_bytes(
                compressed.data[data_offset..data_offset + 4]
                    .try_into()
                    .expect("unwrap failed"),
            ) as usize;
            let elem_idx = u32::from_le_bytes(
                compressed.data[data_offset + 4..data_offset + 8]
                    .try_into()
                    .expect("unwrap failed"),
            ) as usize;
            let value_bytes = &compressed.data[data_offset + 8..data_offset + 16];
            let value = A::from(f64::from_le_bytes(
                value_bytes.try_into().expect("unwrap failed"),
            ))
            .expect("unwrap failed");

            data_offset += 16;

            // Validate indices coming off the wire before writing.
            if grad_idx >= result.len() {
                return Err(OptimError::InvalidConfig(
                    "Invalid gradient index in compressed data".to_string(),
                ));
            }

            // NOTE(review): `iter_mut().nth(elem_idx)` is an O(n) walk per
            // record for generic dimensionality, so restoration is O(n*nnz);
            // a candidate for flat indexing if this ever becomes hot.
            if let Some(elem) = result[grad_idx].iter_mut().nth(elem_idx) {
                *elem = value;
            } else {
                return Err(OptimError::InvalidConfig(
                    "Invalid element index in compressed data".to_string(),
                ));
            }
        }

        Ok(result)
    }
1261
    /// Decompress quantized data
    ///
    /// Mirrors `compress_quantization`: per gradient, reads one quantized
    /// code per element (1, 2, or 4 bytes depending on `bits`), then the
    /// gradient's min value as an `f64`, and reconstructs each value as
    /// `min + code * scale` using the per-gradient scale factor stored in
    /// the metadata.
    fn decompress_quantization(
        &self,
        compressed: &CompressedGradient<A>,
        bits: u8,
    ) -> Result<Vec<Array<A, D>>> {
        let mut result = Vec::new();
        let mut data_offset = 0;
        let _levels = (1u64 << bits) - 1;

        for (grad_idx, shape) in compressed.shapes.iter().enumerate() {
            let num_elements: usize = shape.iter().product();
            let mut values = Vec::with_capacity(num_elements);

            // Read quantized values; the code width must match the width
            // chosen by compress_quantization for the same `bits`.
            for _ in 0..num_elements {
                let quantized = match bits {
                    1..=8 => {
                        if data_offset >= compressed.data.len() {
                            return Err(OptimError::InvalidConfig(
                                "Insufficient quantized data".to_string(),
                            ));
                        }
                        let val = compressed.data[data_offset] as u32;
                        data_offset += 1;
                        val
                    }
                    9..=16 => {
                        if data_offset + 2 > compressed.data.len() {
                            return Err(OptimError::InvalidConfig(
                                "Insufficient quantized data".to_string(),
                            ));
                        }
                        let val = u16::from_le_bytes(
                            compressed.data[data_offset..data_offset + 2]
                                .try_into()
                                .expect("unwrap failed"),
                        ) as u32;
                        data_offset += 2;
                        val
                    }
                    17..=32 => {
                        if data_offset + 4 > compressed.data.len() {
                            return Err(OptimError::InvalidConfig(
                                "Insufficient quantized data".to_string(),
                            ));
                        }
                        let val = u32::from_le_bytes(
                            compressed.data[data_offset..data_offset + 4]
                                .try_into()
                                .expect("unwrap failed"),
                        );
                        data_offset += 4;
                        val
                    }
                    _ => {
                        // bits == 0 or > 32: not a valid wire format.
                        return Err(OptimError::InvalidConfig(
                            "Invalid quantization bits".to_string(),
                        ))
                    }
                };

                values.push(quantized);
            }

            // Read min value (f64, little-endian) appended after the codes.
            if data_offset + 8 > compressed.data.len() {
                return Err(OptimError::InvalidConfig(
                    "Missing min value for quantization".to_string(),
                ));
            }
            let min_bytes = &compressed.data[data_offset..data_offset + 8];
            let min_val = A::from(f64::from_le_bytes(
                min_bytes.try_into().expect("unwrap failed"),
            ))
            .expect("unwrap failed");
            data_offset += 8;

            // Get the per-gradient scale factor from the metadata.
            let scale = if grad_idx < compressed.metadata.scale_factors.len() {
                compressed.metadata.scale_factors[grad_idx]
            } else {
                return Err(OptimError::InvalidConfig(
                    "Missing scale factor for quantization".to_string(),
                ));
            };

            // Dequantize values: value = min + code * scale.
            let dequantized_values: Vec<A> = values
                .into_iter()
                .map(|q| min_val + A::from(q).expect("unwrap failed") * scale)
                .collect();

            let dynamic_array = Array::from_shape_vec(shape.as_slice(), dequantized_values)
                .map_err(|_| {
                    OptimError::InvalidConfig(
                        "Invalid shape for quantized reconstruction".to_string(),
                    )
                })?;
            let array = dynamic_array.into_dimensionality::<D>().map_err(|_| {
                OptimError::InvalidConfig(
                    "Dimension conversion failed for quantized array".to_string(),
                )
            })?;
            result.push(array);
        }

        Ok(result)
    }
1371
1372    /// Calculate size of gradients in bytes
1373    fn calculate_size(&self, gradients: &[Array<A, D>]) -> usize {
1374        gradients
1375            .iter()
1376            .map(|g| g.len() * std::mem::size_of::<A>())
1377            .sum()
1378    }
1379
    /// Get compression statistics
    ///
    /// Read-only view of the counters accumulated by `compress` calls.
    pub fn stats(&self) -> &CompressionStats {
        &self.stats
    }
1384
1385    /// Reset compression statistics
1386    pub fn reset_stats(&mut self) {
1387        self.stats = CompressionStats::new();
1388    }
1389}
1390
/// Compression statistics
///
/// Accumulated across the lifetime of a `GradientCompressor` (until
/// `reset_stats` is called). Ratios are compressed/original bytes, so
/// lower values mean better compression.
#[derive(Debug, Clone)]
pub struct CompressionStats {
    /// Total compressions performed
    pub compressions_count: usize,
    /// Total original bytes
    pub total_original_bytes: usize,
    /// Total compressed bytes
    pub total_compressed_bytes: usize,
    /// Average compression ratio (total compressed / total original bytes)
    pub average_compression_ratio: f64,
    /// Best (smallest) compression ratio achieved; `f64::INFINITY` until
    /// the first compression is recorded
    pub best_compression_ratio: f64,
    /// Worst (largest) compression ratio achieved; 0.0 until the first
    /// compression is recorded
    pub worst_compression_ratio: f64,
}
1407
1408impl CompressionStats {
1409    /// Create new compression statistics
1410    pub fn new() -> Self {
1411        Self {
1412            compressions_count: 0,
1413            total_original_bytes: 0,
1414            total_compressed_bytes: 0,
1415            average_compression_ratio: 0.0,
1416            best_compression_ratio: f64::INFINITY,
1417            worst_compression_ratio: 0.0,
1418        }
1419    }
1420
1421    /// Record a compression operation
1422    pub fn record_compression(&mut self, original_bytes: usize, compressedbytes: usize) {
1423        self.compressions_count += 1;
1424        self.total_original_bytes += original_bytes;
1425        self.total_compressed_bytes += compressedbytes;
1426
1427        let ratio = if original_bytes > 0 {
1428            compressedbytes as f64 / original_bytes as f64
1429        } else {
1430            1.0
1431        };
1432
1433        self.best_compression_ratio = self.best_compression_ratio.min(ratio);
1434        self.worst_compression_ratio = self.worst_compression_ratio.max(ratio);
1435
1436        self.average_compression_ratio = if self.total_original_bytes > 0 {
1437            self.total_compressed_bytes as f64 / self.total_original_bytes as f64
1438        } else {
1439            0.0
1440        };
1441    }
1442
1443    /// Get overall compression ratio
1444    pub fn overall_compression_ratio(&self) -> f64 {
1445        self.average_compression_ratio
1446    }
1447
1448    /// Get bandwidth savings (as percentage)
1449    pub fn bandwidth_savings(&self) -> f64 {
1450        (1.0 - self.average_compression_ratio) * 100.0
1451    }
1452}
1453
impl Default for CompressionStats {
    /// Equivalent to [`CompressionStats::new`].
    fn default() -> Self {
        Self::new()
    }
}
1459
1460#[cfg(test)]
1461mod tests {
1462    use super::*;
1463    use approx::assert_relative_eq;
1464    use scirs2_core::ndarray::Array1;
1465
1466    #[test]
1467    fn test_arithmetic_averaging() {
1468        let mut averager: ParameterAverager<f64, scirs2_core::ndarray::Ix1> =
1469            ParameterAverager::new(AveragingStrategy::Arithmetic, 3);
1470
1471        let params1 = vec![Array1::from_vec(vec![1.0, 2.0])];
1472        let params2 = vec![Array1::from_vec(vec![3.0, 4.0])];
1473        let params3 = vec![Array1::from_vec(vec![5.0, 6.0])];
1474
1475        let nodeparameters = vec![(0, params1), (1, params2), (2, params3)];
1476
1477        averager
1478            .average_parameters(&nodeparameters)
1479            .expect("unwrap failed");
1480
1481        let result = averager.get_averaged_parameters();
1482        assert_relative_eq!(result[0][0], 3.0, epsilon = 1e-6); // (1+3+5)/3
1483        assert_relative_eq!(result[0][1], 4.0, epsilon = 1e-6); // (2+4+6)/3
1484    }
1485
1486    #[test]
1487    fn test_weighted_averaging() {
1488        let mut averager: ParameterAverager<f64, scirs2_core::ndarray::Ix1> =
1489            ParameterAverager::new(AveragingStrategy::WeightedByData, 2);
1490
1491        // Initialize first to avoid overwriting weights
1492        let params1 = vec![Array1::from_vec(vec![2.0])];
1493        let params2 = vec![Array1::from_vec(vec![6.0])];
1494        let nodeparameters = vec![(0, params1.clone()), (1, params2.clone())];
1495        averager.initialize(&params1).expect("unwrap failed");
1496
1497        // Set different weights after initialization
1498        averager.set_node_weight(0, 0.75).expect("unwrap failed"); // 75% weight
1499        averager.set_node_weight(1, 0.25).expect("unwrap failed"); // 25% weight
1500
1501        averager
1502            .average_parameters(&nodeparameters)
1503            .expect("unwrap failed");
1504
1505        let result = averager.get_averaged_parameters();
1506        // Weighted average: 0.75 * 2.0 + 0.25 * 6.0 = 1.5 + 1.5 = 3.0
1507        assert_relative_eq!(result[0][0], 3.0, epsilon = 1e-6);
1508    }
1509
    #[test]
    fn test_momentum_averaging() {
        // Momentum averaging blends the fresh average into a running buffer
        // that starts at zero, so early results undershoot the true mean.
        // NOTE(review): the exact update rule (presumably
        // buf = momentum*buf + (1-momentum)*avg) lives in average_parameters;
        // the ranges below only bound the observable behavior.
        let mut averager: ParameterAverager<f64, scirs2_core::ndarray::Ix1> =
            ParameterAverager::new(AveragingStrategy::Momentum { momentum: 0.9 }, 2);

        let params1 = vec![Array1::from_vec(vec![1.0])];
        let params2 = vec![Array1::from_vec(vec![3.0])];

        // First update: average = (1+3)/2 = 2.0, momentum buffer starts at 0, so result = 0.1 * 2.0 = 0.2
        let node_parameters1 = vec![(0, params1.clone()), (1, params2.clone())];
        averager
            .average_parameters(&node_parameters1)
            .expect("unwrap failed");

        let result1 = averager.get_averaged_parameters();
        // First result should be small due to zero initialization
        assert!(result1[0][0] >= 0.0 && result1[0][0] <= 0.5);

        // Several more updates to let momentum build up
        for _ in 0..10 {
            let nodeparameters = vec![(0, params1.clone()), (1, params2.clone())];
            averager
                .average_parameters(&nodeparameters)
                .expect("unwrap failed");
        }

        let final_result = averager.get_averaged_parameters();
        // After many updates, momentum should gradually converge towards the average (2.0)
        // But with momentum=0.9, it builds up slowly, so we use a broader range
        assert!(final_result[0][0] > 0.5 && final_result[0][0] < 2.5);
    }
1541
1542    #[test]
1543    fn test_parameter_server() {
1544        let mut server = ParameterServer::new(AveragingStrategy::Arithmetic, 2, 2);
1545
1546        let initialparams = vec![Array1::from_vec(vec![0.0, 0.0])];
1547        server.initialize(&initialparams).expect("unwrap failed");
1548
1549        // Submit updates from both nodes
1550        let update1 = vec![Array1::from_vec(vec![1.0, 2.0])];
1551        let update2 = vec![Array1::from_vec(vec![3.0, 4.0])];
1552
1553        let ready1 = server.submit_update(0, update1).expect("unwrap failed");
1554        assert!(!ready1); // Not ready yet, waiting for second node
1555
1556        let ready2 = server.submit_update(1, update2).expect("unwrap failed");
1557        assert!(ready2); // Ready after both nodes submitted
1558
1559        let global_params = server.get_global_parameters();
1560        assert_relative_eq!(global_params[0][0], 2.0, epsilon = 1e-6); // (1+3)/2
1561        assert_relative_eq!(global_params[0][1], 3.0, epsilon = 1e-6); // (2+4)/2
1562
1563        assert_eq!(server.current_round(), 1);
1564    }
1565
1566    #[test]
1567    fn test_distributed_coordinator() {
1568        let mut coordinator = DistributedCoordinator::new(
1569            AveragingStrategy::Arithmetic,
1570            2,  // 2 nodes
1571            2,  // expect 2 updates per round
1572            10, // max 10 rounds
1573        );
1574
1575        let initialparams = vec![Array1::from_vec(vec![0.0])];
1576        coordinator
1577            .initialize(&initialparams)
1578            .expect("unwrap failed");
1579
1580        // Simulate training rounds
1581        for round in 1..=3 {
1582            let update1 = vec![Array1::from_vec(vec![round as f64])];
1583            let update2 = vec![Array1::from_vec(vec![(round * 2) as f64])];
1584
1585            let node_updates = vec![(0, update1), (1, update2)];
1586
1587            let result = coordinator
1588                .communication_round(node_updates)
1589                .expect("unwrap failed");
1590
1591            assert_eq!(result.round, round);
1592            assert!(result.should_continue);
1593            assert!(!result.converged); // Unlikely to converge with these updates
1594
1595            // Check that global parameters are updated
1596            assert!(result.global_parameters[0][0] > 0.0);
1597        }
1598    }
1599
    #[test]
    fn test_averaging_strategies() {
        // Test arithmetic and federated strategies that should produce expected ranges.
        // These are stateless with respect to previous rounds, so a single call
        // must land inside the convex hull of the inputs [1.0, 3.0].
        let simple_strategies = vec![
            AveragingStrategy::Arithmetic,
            AveragingStrategy::WeightedByData,
            AveragingStrategy::Federated,
        ];

        for strategy in simple_strategies {
            let mut averager: ParameterAverager<f64, scirs2_core::ndarray::Ix1> =
                ParameterAverager::new(strategy, 2);

            let params1 = vec![Array1::from_vec(vec![1.0])];
            let params2 = vec![Array1::from_vec(vec![3.0])];

            let nodeparameters = vec![(0, params1), (1, params2)];

            averager
                .average_parameters(&nodeparameters)
                .expect("unwrap failed");
            let result = averager.get_averaged_parameters();
            assert!(result[0][0] >= 1.0 && result[0][0] <= 3.0);
        }

        // Test momentum and EMA strategies separately (they start from zero state),
        // so their first output can undershoot the input range.
        let stateful_strategies = vec![
            AveragingStrategy::Momentum { momentum: 0.9 },
            AveragingStrategy::ExponentialMovingAverage { decay: 0.9 },
        ];

        for strategy in stateful_strategies {
            let mut averager: ParameterAverager<f64, scirs2_core::ndarray::Ix1> =
                ParameterAverager::new(strategy, 2);

            let params1 = vec![Array1::from_vec(vec![1.0])];
            let params2 = vec![Array1::from_vec(vec![3.0])];

            let nodeparameters = vec![(0, params1), (1, params2)];

            averager
                .average_parameters(&nodeparameters)
                .expect("unwrap failed");
            let result = averager.get_averaged_parameters();
            // First result from momentum/EMA will be smaller due to zero initialization
            assert!(result[0][0] >= 0.0 && result[0][0] <= 3.0);
        }
    }
1648
1649    #[test]
1650    fn test_node_weight_validation() {
1651        let mut averager: ParameterAverager<f64, scirs2_core::ndarray::Ix1> =
1652            ParameterAverager::new(AveragingStrategy::WeightedByData, 2);
1653
1654        // Valid node ID
1655        assert!(averager.set_node_weight(0, 0.5).is_ok());
1656        assert!(averager.set_node_weight(1, 0.5).is_ok());
1657
1658        // Invalid node ID
1659        assert!(averager.set_node_weight(2, 0.5).is_err());
1660    }
1661
1662    #[test]
1663    fn test_parameter_dimension_validation() {
1664        let mut averager: ParameterAverager<f64, scirs2_core::ndarray::Ix1> =
1665            ParameterAverager::new(AveragingStrategy::Arithmetic, 2);
1666
1667        let params1 = vec![Array1::from_vec(vec![1.0, 2.0])];
1668        let params2 = vec![Array1::from_vec(vec![3.0])]; // Wrong dimension
1669
1670        let nodeparameters = vec![(0, params1), (1, params2)];
1671
1672        // Should fail due to dimension mismatch - currently panics instead of returning error
1673        let result = std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| {
1674            averager.average_parameters(&nodeparameters)
1675        }));
1676
1677        // Either it returns an error or panics due to dimension mismatch
1678        assert!(result.is_err() || (result.is_ok() && result.expect("unwrap failed").is_err()));
1679    }
1680
1681    #[test]
1682    fn test_training_stats() {
1683        let mut stats = TrainingStats::new();
1684
1685        assert_eq!(stats.num_rounds(), 0);
1686        assert!(stats.latest_convergence().is_none());
1687
1688        let params = vec![Array1::from_vec(vec![1.0])];
1689        stats.record_round(1, 0.5, &params);
1690
1691        assert_eq!(stats.num_rounds(), 1);
1692        assert_eq!(stats.latest_convergence(), Some(0.5));
1693        assert_eq!(stats.convergence_history(), &[0.5]);
1694    }
1695
1696    #[test]
1697    fn test_gradient_compression_none() {
1698        let mut compressor = GradientCompressor::new(CompressionStrategy::None);
1699
1700        let gradients = vec![
1701            Array1::from_vec(vec![1.0, 2.0, 3.0]),
1702            Array1::from_vec(vec![4.0, 5.0]),
1703        ];
1704
1705        let compressed = compressor.compress(&gradients).expect("unwrap failed");
1706        assert_eq!(compressed.metadata.strategy, CompressionStrategy::None);
1707        assert_eq!(compressed.metadata.compression_ratio, 1.0);
1708
1709        let decompressed = compressor.decompress(&compressed).expect("unwrap failed");
1710        assert_eq!(decompressed.len(), 2);
1711        assert_eq!(
1712            decompressed[0].as_slice().expect("unwrap failed"),
1713            &[1.0, 2.0, 3.0]
1714        );
1715        assert_eq!(
1716            decompressed[1].as_slice().expect("unwrap failed"),
1717            &[4.0, 5.0]
1718        );
1719    }
1720
1721    #[test]
1722    fn test_gradient_compression_topk() {
1723        let mut compressor = GradientCompressor::new(CompressionStrategy::TopK { k: 2 });
1724
1725        let gradients = vec![Array1::from_vec(vec![0.1, 3.0, 0.2, 4.0, 0.05])];
1726
1727        let compressed = compressor.compress(&gradients).expect("unwrap failed");
1728        assert!(compressed.metadata.compression_ratio < 1.0);
1729        assert_eq!(compressed.metadata.nnz_count, 2); // Top 2 elements
1730
1731        let decompressed = compressor.decompress(&compressed).expect("unwrap failed");
1732        assert_eq!(decompressed.len(), 1);
1733
1734        // Should have only the top 2 elements (4.0 and 3.0), others should be 0
1735        let result = &decompressed[0];
1736        assert_eq!(result[1], 3.0); // Original position of 3.0
1737        assert_eq!(result[3], 4.0); // Original position of 4.0
1738        assert_eq!(result[0], 0.0); // Should be zeroed
1739        assert_eq!(result[2], 0.0); // Should be zeroed
1740        assert_eq!(result[4], 0.0); // Should be zeroed
1741    }
1742
1743    #[test]
1744    fn test_gradient_compression_threshold() {
1745        let mut compressor =
1746            GradientCompressor::new(CompressionStrategy::Threshold { threshold: 1.0 });
1747
1748        let gradients = vec![Array1::from_vec(vec![0.5, 2.0, 0.8, 3.0, 0.3])];
1749
1750        let compressed = compressor.compress(&gradients).expect("unwrap failed");
1751        assert!(compressed.metadata.compression_ratio < 1.0);
1752        assert_eq!(compressed.metadata.nnz_count, 2); // Elements > 1.0: 2.0 and 3.0
1753
1754        let decompressed = compressor.decompress(&compressed).expect("unwrap failed");
1755        let result = &decompressed[0];
1756
1757        // Only elements > 1.0 should remain
1758        assert_eq!(result[0], 0.0); // 0.5 < 1.0
1759        assert_eq!(result[1], 2.0); // 2.0 > 1.0
1760        assert_eq!(result[2], 0.0); // 0.8 < 1.0
1761        assert_eq!(result[3], 3.0); // 3.0 > 1.0
1762        assert_eq!(result[4], 0.0); // 0.3 < 1.0
1763    }
1764
1765    #[test]
1766    fn test_gradient_compression_quantization() {
1767        let mut compressor = GradientCompressor::new(CompressionStrategy::Quantization { bits: 8 });
1768
1769        let gradients = vec![Array1::from_vec(vec![1.0, 2.0, 3.0, 4.0])];
1770
1771        let compressed = compressor.compress(&gradients).expect("unwrap failed");
1772        assert!(compressed.metadata.compression_ratio < 1.0); // Should use less space with 8-bit quantization
1773
1774        let decompressed = compressor.decompress(&compressed).expect("unwrap failed");
1775        let result = &decompressed[0];
1776
1777        // Values should be approximately restored (with quantization error)
1778        assert!((result[0] - 1.0).abs() < 0.1);
1779        assert!((result[1] - 2.0).abs() < 0.1);
1780        assert!((result[2] - 3.0).abs() < 0.1);
1781        assert!((result[3] - 4.0).abs() < 0.1);
1782    }
1783
1784    #[test]
1785    fn test_gradient_compression_randomk() {
1786        let mut compressor = GradientCompressor::new(CompressionStrategy::RandomK { k: 3 });
1787
1788        // Use a larger array to make compression effective
1789        let gradients = vec![Array1::from_vec(vec![
1790            1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0,
1791        ])];
1792
1793        let compressed = compressor.compress(&gradients).expect("unwrap failed");
1794        // With 3 out of 10 elements, compression should be effective
1795        assert!(compressed.metadata.compression_ratio < 1.0);
1796        assert_eq!(compressed.metadata.nnz_count, 3); // Exactly 3 elements should be kept
1797
1798        let decompressed = compressor.decompress(&compressed).expect("unwrap failed");
1799        let result = &decompressed[0];
1800
1801        // Exactly 3 elements should be non-zero
1802        let non_zero_count = result.iter().filter(|&&x| x != 0.0).count();
1803        assert_eq!(non_zero_count, 3);
1804    }
1805
1806    #[test]
1807    fn test_gradient_compression_error_feedback() {
1808        let base_strategy = CompressionStrategy::TopK { k: 2 };
1809        let strategy = CompressionStrategy::ErrorFeedback {
1810            base_strategy: Box::new(base_strategy),
1811            error_compensation: true,
1812        };
1813
1814        let mut compressor = GradientCompressor::new(strategy);
1815
1816        let gradients = vec![Array1::from_vec(vec![1.0, 2.0, 3.0, 4.0])];
1817
1818        // Initialize error state
1819        compressor.initialize_error_state(&gradients);
1820
1821        // First compression
1822        let compressed1 = compressor.compress(&gradients).expect("unwrap failed");
1823        let decompressed1 = compressor.decompress(&compressed1).expect("unwrap failed");
1824
1825        // Second compression (should include error feedback)
1826        let compressed2 = compressor.compress(&gradients).expect("unwrap failed");
1827        let decompressed2 = compressor.decompress(&compressed2).expect("unwrap failed");
1828
1829        // Both should be valid compressions
1830        assert_eq!(decompressed1.len(), 1);
1831        assert_eq!(decompressed2.len(), 1);
1832    }
1833
1834    #[test]
1835    fn test_gradient_compression_clipped() {
1836        let base_strategy = CompressionStrategy::TopK { k: 3 };
1837        let strategy = CompressionStrategy::ClippedCompression {
1838            base_strategy: Box::new(base_strategy),
1839            clip_value: 2.5,
1840        };
1841
1842        let mut compressor = GradientCompressor::new(strategy);
1843
1844        let gradients = vec![Array1::from_vec(vec![1.0, 5.0, -3.0, 2.0])];
1845
1846        let compressed = compressor.compress(&gradients).expect("unwrap failed");
1847        let decompressed = compressor.decompress(&compressed).expect("unwrap failed");
1848
1849        let result = &decompressed[0];
1850
1851        // Values should be clipped to [-2.5, 2.5] and then top-k applied
1852        for &val in result.iter() {
1853            if val != 0.0 {
1854                // Non-zero values from top-k
1855                assert!((-2.5..=2.5).contains(&val));
1856            }
1857        }
1858    }
1859
1860    #[test]
1861    fn test_compression_stats() {
1862        let mut stats = CompressionStats::new();
1863
1864        assert_eq!(stats.compressions_count, 0);
1865        assert_eq!(stats.overall_compression_ratio(), 0.0);
1866
1867        // Record some compressions
1868        stats.record_compression(1000, 500); // 50% compression
1869        assert_eq!(stats.compressions_count, 1);
1870        assert_relative_eq!(stats.overall_compression_ratio(), 0.5, epsilon = 1e-6);
1871        assert_relative_eq!(stats.bandwidth_savings(), 50.0, epsilon = 1e-6);
1872
1873        stats.record_compression(1000, 250); // 25% compression
1874        assert_eq!(stats.compressions_count, 2);
1875        assert_relative_eq!(stats.overall_compression_ratio(), 0.375, epsilon = 1e-6); // (500+250)/(1000+1000)
1876        assert_relative_eq!(stats.bandwidth_savings(), 62.5, epsilon = 1e-6);
1877
1878        assert_relative_eq!(stats.best_compression_ratio, 0.25, epsilon = 1e-6);
1879        assert_relative_eq!(stats.worst_compression_ratio, 0.5, epsilon = 1e-6);
1880    }
1881
1882    #[test]
1883    fn test_compression_roundtrip() {
1884        let strategies = vec![
1885            CompressionStrategy::None,
1886            CompressionStrategy::TopK { k: 2 },
1887            CompressionStrategy::RandomK { k: 2 },
1888            CompressionStrategy::Threshold { threshold: 1.5 },
1889            CompressionStrategy::Quantization { bits: 4 },
1890        ];
1891
1892        let gradients = vec![
1893            Array1::from_vec(vec![1.0, 2.5, 0.5, 3.0]),
1894            Array1::from_vec(vec![0.1, 4.0]),
1895        ];
1896
1897        for strategy in strategies {
1898            let mut compressor = GradientCompressor::new(strategy.clone());
1899
1900            let compressed = compressor.compress(&gradients).expect("unwrap failed");
1901            let decompressed = compressor.decompress(&compressed).expect("unwrap failed");
1902
1903            // Should decompress to same number of arrays
1904            assert_eq!(decompressed.len(), gradients.len());
1905
1906            // Shapes should match
1907            for (orig, decomp) in gradients.iter().zip(decompressed.iter()) {
1908                assert_eq!(orig.shape(), decomp.shape());
1909            }
1910
1911            // For lossless strategies, values should match exactly
1912            match strategy {
1913                CompressionStrategy::None => {
1914                    for (orig, decomp) in gradients.iter().zip(decompressed.iter()) {
1915                        for (&o, &d) in orig.iter().zip(decomp.iter()) {
1916                            assert_relative_eq!(o, d, epsilon = 1e-10);
1917                        }
1918                    }
1919                }
1920                _ => {
1921                    // For lossy strategies, just check that we get reasonable values
1922                    for decomp in &decompressed {
1923                        assert!(decomp.iter().all(|&x| x.is_finite()));
1924                    }
1925                }
1926            }
1927        }
1928    }
1929
1930    #[test]
1931    fn test_compression_invalid_configs() {
1932        // Invalid quantization bits
1933        let strategy = CompressionStrategy::Quantization { bits: 64 };
1934        let mut compressor = GradientCompressor::new(strategy);
1935
1936        let gradients = vec![Array1::from_vec(vec![1.0, 2.0])];
1937        assert!(compressor.compress(&gradients).is_err());
1938
1939        // Invalid decompression data
1940        let valid_compressor: GradientCompressor<f64, scirs2_core::ndarray::Ix1> =
1941            GradientCompressor::new(CompressionStrategy::None);
1942        let invalid_compressed = CompressedGradient {
1943            data: vec![1, 2, 3], // Insufficient data
1944            metadata: CompressionMetadata {
1945                strategy: CompressionStrategy::None,
1946                compression_ratio: 1.0,
1947                nnz_count: 1,
1948                scale_factors: vec![],
1949                extra_data: vec![],
1950            },
1951            shapes: vec![vec![2]],
1952        };
1953
1954        assert!(valid_compressor.decompress(&invalid_compressed).is_err());
1955    }
1956
1957    #[test]
1958    fn test_distributed_with_compression() {
1959        // Test parameter server with compressed gradients
1960        let mut server = ParameterServer::new(AveragingStrategy::Arithmetic, 2, 2);
1961        let initialparams = vec![Array1::from_vec(vec![0.0, 0.0])];
1962        server.initialize(&initialparams).expect("unwrap failed");
1963
1964        let mut compressor = GradientCompressor::new(CompressionStrategy::TopK { k: 1 });
1965
1966        // Create gradients and compress them
1967        let gradients1 = vec![Array1::from_vec(vec![1.0, 3.0])]; // Top-1 should keep 3.0
1968        let gradients2 = vec![Array1::from_vec(vec![2.0, 1.0])]; // Top-1 should keep 2.0
1969
1970        let compressed1 = compressor.compress(&gradients1).expect("unwrap failed");
1971        let compressed2 = compressor.compress(&gradients2).expect("unwrap failed");
1972
1973        let decompressed1 = compressor.decompress(&compressed1).expect("unwrap failed");
1974        let decompressed2 = compressor.decompress(&compressed2).expect("unwrap failed");
1975
1976        // Submit decompressed gradients to server
1977        server
1978            .submit_update(0, decompressed1)
1979            .expect("unwrap failed");
1980        server
1981            .submit_update(1, decompressed2)
1982            .expect("unwrap failed");
1983
1984        let global_params = server.get_global_parameters();
1985
1986        // Should have averaged the compressed gradients
1987        // Node 0 contributes [0, 3.0], Node 1 contributes [2.0, 0]
1988        // Average: [1.0, 1.5]
1989        assert_relative_eq!(global_params[0][0], 1.0, epsilon = 1e-6);
1990        assert_relative_eq!(global_params[0][1], 1.5, epsilon = 1e-6);
1991    }
1992}