pub mod fedprox;
pub use fedprox::*;
use crate::error::{OptimError, Result};
use scirs2_core::ndarray::{Array, Dimension, ScalarOperand, Zip};
use scirs2_core::numeric::Float;
use std::collections::HashMap;
use std::fmt::Debug;
/// Strategy used by [`ParameterAverager`] to combine the parameters submitted by nodes.
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum AveragingStrategy {
/// Unweighted mean over the nodes present in the current averaging call.
Arithmetic,
/// Weighted mean using the per-node weights stored in the averager.
WeightedByData,
/// Dispatched to the same weighted mean as `WeightedByData` in this module.
WeightedByTime,
/// Federated averaging; delegates to the same weighted mean as `WeightedByData`.
Federated,
/// Blends each round's unweighted mean into a momentum buffer:
/// `buffer = momentum * buffer + (1 - momentum) * mean`.
Momentum {
/// Fraction of the previous buffer retained each round.
momentum: f64,
},
/// Exponential moving average: `avg = decay * avg + (1 - decay) * mean`.
ExponentialMovingAverage {
/// Fraction of the previous average retained each round.
decay: f64,
},
}
/// Combines parameter arrays from several nodes according to an [`AveragingStrategy`].
#[derive(Debug)]
pub struct ParameterAverager<A: Float, D: Dimension> {
/// Current averaged parameters; empty until the averager is initialized.
averaged_params: Vec<Array<A, D>>,
/// Strategy chosen at construction time.
strategy: AveragingStrategy,
/// Per-node weights; filled with `1 / numnodes` for every node on initialization.
node_weights: HashMap<usize, A>,
/// Number of participating nodes; valid node ids are `0..numnodes`.
numnodes: usize,
/// Momentum state; only allocated when the strategy is `Momentum`.
momentum_buffer: Option<Vec<Array<A, D>>>,
/// Number of averaging calls that passed validation since construction or the last reset.
step_count: usize,
/// Whether `initialize` has completed.
initialized: bool,
}
impl<A: Float + ScalarOperand + Debug + Send + Sync, D: Dimension + Send + Sync>
ParameterAverager<A, D>
{
/// Creates an uninitialized averager for `numnodes` nodes using `strategy`.
///
/// No parameter storage exists until `initialize` runs, either directly or
/// through the first call to `average_parameters`.
pub fn new(strategy: AveragingStrategy, numnodes: usize) -> Self {
    let averaged_params = Vec::new();
    let node_weights = HashMap::new();
    Self {
        strategy,
        numnodes,
        averaged_params,
        node_weights,
        momentum_buffer: None,
        initialized: false,
        step_count: 0,
    }
}
/// Seeds the averager with `params` and assigns every node the uniform weight `1 / numnodes`.
///
/// For the `Momentum` strategy a zeroed momentum buffer with matching shapes is
/// allocated as well.
///
/// # Errors
/// Returns `OptimError::InvalidConfig` if the averager was already initialized.
pub fn initialize(&mut self, params: &[Array<A, D>]) -> Result<()> {
    if self.initialized {
        return Err(OptimError::InvalidConfig(
            "Parameter averager already initialized".to_string(),
        ));
    }

    self.averaged_params = params.to_vec();

    if let AveragingStrategy::Momentum { .. } = self.strategy {
        let zeroed: Vec<Array<A, D>> = params
            .iter()
            .map(|template| Array::zeros(template.raw_dim()))
            .collect();
        self.momentum_buffer = Some(zeroed);
    }

    let share = A::one() / A::from(self.numnodes).expect("unwrap failed");
    let node_total = self.numnodes;
    self.node_weights
        .extend((0..node_total).map(|nodeid| (nodeid, share)));

    self.initialized = true;
    Ok(())
}
/// Stores `weight` as the weight of node `nodeid`, replacing any previous value.
///
/// # Errors
/// Returns `OptimError::InvalidConfig` when `nodeid` is not below the configured node count.
pub fn set_node_weight(&mut self, nodeid: usize, weight: A) -> Result<()> {
    if nodeid < self.numnodes {
        self.node_weights.insert(nodeid, weight);
        Ok(())
    } else {
        Err(OptimError::InvalidConfig(format!(
            "Node ID {} exceeds number of nodes {}",
            nodeid, self.numnodes
        )))
    }
}
/// Folds one round of node parameters into the running average using the configured strategy.
///
/// If the averager has not been initialized yet, it is initialized from the
/// parameters of the first entry in `nodeparameters`.
///
/// # Errors
/// * `OptimError::InvalidConfig` when `nodeparameters` is empty (an empty round would
///   otherwise turn the averages into NaN through a division by zero) or when a node id
///   is out of range.
/// * `OptimError::DimensionMismatch` when a node supplies a different number of arrays,
///   or an array whose shape differs from the stored one (previously a panic).
pub fn average_parameters(
    &mut self,
    nodeparameters: &[(usize, Vec<Array<A, D>>)],
) -> Result<()> {
    let first_params = match nodeparameters.first() {
        Some((_, params)) => params,
        None if self.initialized => {
            return Err(OptimError::InvalidConfig(
                "No node parameters provided for averaging".to_string(),
            ));
        }
        None => {
            return Err(OptimError::InvalidConfig(
                "No _parameters provided for initialization".to_string(),
            ));
        }
    };
    if !self.initialized {
        self.initialize(first_params)?;
    }

    // Validate everything up front so no strategy starts mutating state on bad input.
    for (nodeid, params) in nodeparameters {
        if *nodeid >= self.numnodes {
            return Err(OptimError::InvalidConfig(format!(
                "Node ID {} exceeds number of nodes {}",
                nodeid, self.numnodes
            )));
        }
        if params.len() != self.averaged_params.len() {
            return Err(OptimError::DimensionMismatch(format!(
                "Expected {} parameter arrays, got {}",
                self.averaged_params.len(),
                params.len()
            )));
        }
        for (index, (expected, actual)) in
            self.averaged_params.iter().zip(params.iter()).enumerate()
        {
            if expected.shape() != actual.shape() {
                return Err(OptimError::DimensionMismatch(format!(
                    "Parameter {} from node {} has shape {:?}, expected {:?}",
                    index,
                    nodeid,
                    actual.shape(),
                    expected.shape()
                )));
            }
        }
    }

    self.step_count += 1;
    match self.strategy {
        AveragingStrategy::Arithmetic => self.arithmetic_average(nodeparameters)?,
        AveragingStrategy::WeightedByData | AveragingStrategy::WeightedByTime => {
            self.weighted_average(nodeparameters)?
        }
        AveragingStrategy::Federated => self.federated_average(nodeparameters)?,
        AveragingStrategy::Momentum { momentum } => {
            self.momentum_average(nodeparameters, momentum)?
        }
        AveragingStrategy::ExponentialMovingAverage { decay } => {
            self.ema_average(nodeparameters, decay)?
        }
    }
    Ok(())
}
/// Replaces the stored averages with the unweighted mean of the supplied node parameters.
///
/// The divisor is the number of entries in `nodeparameters`, not the configured node count.
fn arithmetic_average(&mut self, nodeparameters: &[(usize, Vec<Array<A, D>>)]) -> Result<()> {
    self.averaged_params
        .iter_mut()
        .for_each(|accumulator| accumulator.fill(A::zero()));

    let node_count = A::from(nodeparameters.len()).expect("unwrap failed");

    // First accumulate the element-wise sum over all nodes...
    for (_, node_arrays) in nodeparameters {
        for (accumulator, contribution) in self.averaged_params.iter_mut().zip(node_arrays) {
            Zip::from(accumulator)
                .and(contribution)
                .for_each(|sum, &value| *sum = *sum + value);
        }
    }

    // ...then scale the sum down to a mean.
    for accumulator in self.averaged_params.iter_mut() {
        accumulator.mapv_inplace(|sum| sum / node_count);
    }
    Ok(())
}
/// Replaces the stored averages with the weighted mean of the supplied node parameters.
///
/// Each node contributes with its stored weight divided by the sum of the weights of
/// the nodes present in this call; nodes without a stored weight count as zero.
///
/// # Errors
/// Returns `OptimError::InvalidConfig` when the total weight is not strictly positive
/// (including NaN). The check runs before any state is touched, so a failed call
/// leaves the previous averages intact.
fn weighted_average(&mut self, nodeparameters: &[(usize, Vec<Array<A, D>>)]) -> Result<()> {
    let raw_weights: Vec<A> = nodeparameters
        .iter()
        .map(|(nodeid, _)| self.node_weights.get(nodeid).copied().unwrap_or(A::zero()))
        .collect();
    let total_weight = raw_weights.iter().fold(A::zero(), |acc, &w| acc + w);
    if total_weight.is_nan() || total_weight <= A::zero() {
        return Err(OptimError::InvalidConfig(
            "Total node weights must be > 0".to_string(),
        ));
    }

    // Only clear the accumulator once the input is known to be usable.
    for param in &mut self.averaged_params {
        param.fill(A::zero());
    }
    for ((_, params), &raw_weight) in nodeparameters.iter().zip(raw_weights.iter()) {
        let weight = raw_weight / total_weight;
        for (avg_param, param) in self.averaged_params.iter_mut().zip(params.iter()) {
            Zip::from(avg_param).and(param).for_each(|avg, &p| {
                *avg = *avg + weight * p;
            });
        }
    }
    Ok(())
}
/// Federated averaging step; forwards to `weighted_average`, so nodes are combined
/// using their stored weights.
fn federated_average(&mut self, nodeparameters: &[(usize, Vec<Array<A, D>>)]) -> Result<()> {
self.weighted_average(nodeparameters)
}
/// Blends the unweighted mean of this round into the momentum buffer and publishes the
/// buffer as the new averaged parameters.
///
/// Update rule per element: `buffer = momentum * buffer + (1 - momentum) * mean`.
/// When no momentum buffer exists the stored averages are left unchanged.
fn momentum_average(
    &mut self,
    nodeparameters: &[(usize, Vec<Array<A, D>>)],
    momentum: f64,
) -> Result<()> {
    let retain = A::from(momentum).expect("unwrap failed");
    let blend = A::one() - retain;

    // Mean of the incoming parameters, accumulated as sum(p / n).
    let mut round_mean: Vec<Array<A, D>> = Vec::with_capacity(self.averaged_params.len());
    for template in &self.averaged_params {
        round_mean.push(Array::zeros(template.raw_dim()));
    }
    let node_count = A::from(nodeparameters.len()).expect("unwrap failed");
    for (_, node_arrays) in nodeparameters {
        for (mean, contribution) in round_mean.iter_mut().zip(node_arrays) {
            Zip::from(mean).and(contribution).for_each(|acc, &value| {
                *acc = *acc + value / node_count;
            });
        }
    }

    let buffers = match self.momentum_buffer.as_mut() {
        Some(buffers) => buffers,
        None => return Ok(()),
    };
    for ((published, mean), buffer) in self
        .averaged_params
        .iter_mut()
        .zip(&round_mean)
        .zip(buffers.iter_mut())
    {
        Zip::from(&mut *buffer).and(mean).for_each(|state, &fresh| {
            *state = retain * *state + blend * fresh;
        });
        published.assign(&*buffer);
    }
    Ok(())
}
/// Updates the stored averages as an exponential moving average of the per-round mean.
///
/// Update rule per element: `avg = decay * avg + (1 - decay) * mean`, where `mean` is
/// the unweighted mean over the entries of `nodeparameters`.
fn ema_average(
    &mut self,
    nodeparameters: &[(usize, Vec<Array<A, D>>)],
    decay: f64,
) -> Result<()> {
    let keep = A::from(decay).expect("unwrap failed");
    let admit = A::one() - keep;

    // Mean of the incoming parameters, accumulated as sum(p / n).
    let mut round_mean: Vec<Array<A, D>> = Vec::with_capacity(self.averaged_params.len());
    for template in &self.averaged_params {
        round_mean.push(Array::zeros(template.raw_dim()));
    }
    let node_count = A::from(nodeparameters.len()).expect("unwrap failed");
    for (_, node_arrays) in nodeparameters {
        for (mean, contribution) in round_mean.iter_mut().zip(node_arrays) {
            Zip::from(mean).and(contribution).for_each(|acc, &value| {
                *acc = *acc + value / node_count;
            });
        }
    }

    for (running, mean) in self.averaged_params.iter_mut().zip(&round_mean) {
        Zip::from(running).and(mean).for_each(|state, &fresh| {
            *state = keep * *state + admit * fresh;
        });
    }
    Ok(())
}
/// Borrows the current averaged parameters.
pub fn get_averaged_parameters(&self) -> &[Array<A, D>] {
    self.averaged_params.as_slice()
}
/// Returns an owned copy of the current averaged parameters.
pub fn get_averaged_parameters_cloned(&self) -> Vec<Array<A, D>> {
    self.averaged_params.to_vec()
}
/// Zeroes the step counter, the averaged parameters and (if present) the momentum buffer.
///
/// Shapes, node weights and the initialized flag are kept.
pub fn reset(&mut self) {
    self.step_count = 0;
    let zero = A::zero();
    self.averaged_params
        .iter_mut()
        .for_each(|param| param.fill(zero));
    if let Some(buffers) = self.momentum_buffer.as_mut() {
        buffers.iter_mut().for_each(|buffer| buffer.fill(zero));
    }
}
/// Number of averaging calls that passed validation since construction or the last `reset`.
pub fn step_count(&self) -> usize {
self.step_count
}
/// Number of nodes this averager was configured for.
pub fn numnodes(&self) -> usize {
self.numnodes
}
/// Averaging strategy selected at construction time.
pub fn strategy(&self) -> AveragingStrategy {
self.strategy
}
/// Whether `initialize` has completed for this averager.
pub fn is_initialized(&self) -> bool {
self.initialized
}
}
/// Collects per-node parameter updates and aggregates them into global parameters.
#[derive(Debug)]
pub struct ParameterServer<A: Float, D: Dimension> {
/// Averager that performs the actual aggregation.
averager: ParameterAverager<A, D>,
/// Result of the most recent aggregation (or the initial parameters).
global_parameters: Vec<Array<A, D>>,
/// Number of updates submitted by each node.
update_counts: HashMap<usize, usize>,
/// Number of distinct pending node updates that triggers an aggregation.
expected_updates_per_round: usize,
/// Number of aggregations performed so far.
current_round: usize,
/// Latest not-yet-aggregated update per node; a resubmission replaces the earlier one.
pending_updates: HashMap<usize, Vec<Array<A, D>>>,
}
impl<A: Float + ScalarOperand + Debug + Send + Sync, D: Dimension + Send + Sync>
ParameterServer<A, D>
{
/// Creates a server for `numnodes` nodes that aggregates with `strategy` once
/// `expected_updates_per_round` distinct nodes have pending updates.
pub fn new(
    strategy: AveragingStrategy,
    numnodes: usize,
    expected_updates_per_round: usize,
) -> Self {
    let averager = ParameterAverager::new(strategy, numnodes);
    Self {
        averager,
        expected_updates_per_round,
        current_round: 0,
        global_parameters: Vec::new(),
        update_counts: HashMap::new(),
        pending_updates: HashMap::new(),
    }
}
/// Initializes the averager and the global parameters from `initialparams` and sets
/// every node's update counter to zero.
///
/// # Errors
/// Propagates the averager's error when it was already initialized.
pub fn initialize(&mut self, initialparams: &[Array<A, D>]) -> Result<()> {
    self.averager.initialize(initialparams)?;
    self.global_parameters = initialparams.to_vec();
    let node_total = self.averager.numnodes();
    self.update_counts
        .extend((0..node_total).map(|nodeid| (nodeid, 0)));
    Ok(())
}
/// Queues `parameters` as the pending update of node `nodeid` and aggregates when enough
/// distinct nodes have reported.
///
/// Returns `Ok(true)` when this submission triggered an aggregation, `Ok(false)` otherwise.
///
/// # Errors
/// `OptimError::InvalidConfig` for an out-of-range node id; aggregation errors are propagated.
pub fn submit_update(&mut self, nodeid: usize, parameters: Vec<Array<A, D>>) -> Result<bool> {
    let node_limit = self.averager.numnodes();
    if nodeid >= node_limit {
        return Err(OptimError::InvalidConfig(format!(
            "Node ID {} exceeds number of nodes {}",
            nodeid, node_limit
        )));
    }

    self.pending_updates.insert(nodeid, parameters);
    let counter = self.update_counts.entry(nodeid).or_insert(0);
    *counter += 1;

    if self.pending_updates.len() < self.expected_updates_per_round {
        return Ok(false);
    }
    self.aggregate_and_update()?;
    Ok(true)
}
/// Aggregates whatever updates are pending, regardless of how many nodes have reported.
///
/// Does nothing when no update is pending.
pub fn force_aggregation(&mut self) -> Result<()> {
    if self.pending_updates.is_empty() {
        return Ok(());
    }
    self.aggregate_and_update()
}
/// Drains all pending updates, averages them and publishes the result as the new
/// global parameters, advancing the round counter.
///
/// Updates are handed to the averager in ascending node-id order. `HashMap::drain`
/// yields entries in arbitrary order, which would make both the floating-point
/// accumulation and the choice of the "first" node used for lazy initialization
/// vary between otherwise identical runs.
///
/// # Errors
/// Propagates errors from the averager; the drained updates are not restored.
fn aggregate_and_update(&mut self) -> Result<()> {
    let mut node_params: Vec<(usize, Vec<Array<A, D>>)> =
        self.pending_updates.drain().collect();
    node_params.sort_unstable_by_key(|entry| entry.0);
    self.averager.average_parameters(&node_params)?;
    self.global_parameters = self.averager.get_averaged_parameters_cloned();
    self.current_round += 1;
    Ok(())
}
/// Borrows the current global parameters.
pub fn get_global_parameters(&self) -> &[Array<A, D>] {
    self.global_parameters.as_slice()
}
/// Returns an owned copy of the current global parameters.
pub fn get_global_parameters_cloned(&self) -> Vec<Array<A, D>> {
    self.global_parameters.to_vec()
}
/// Number of aggregations performed since construction or the last `reset`.
pub fn current_round(&self) -> usize {
self.current_round
}
/// Number of updates submitted by `nodeid`; zero for nodes without a counter.
pub fn get_update_count(&self, nodeid: usize) -> usize {
    match self.update_counts.get(&nodeid) {
        Some(&count) => count,
        None => 0,
    }
}
/// Number of nodes with an update waiting to be aggregated.
pub fn pending_updates_count(&self) -> usize {
self.pending_updates.len()
}
/// Sets the weight of node `nodeid` on the underlying averager.
///
/// Errors from the averager (out-of-range node id) are returned unchanged.
pub fn set_node_weight(&mut self, nodeid: usize, weight: A) -> Result<()> {
self.averager.set_node_weight(nodeid, weight)
}
/// Resets the averager, discards pending updates, and zeroes the round counter and all
/// per-node update counters.
///
/// The stored global parameters are not modified here.
pub fn reset(&mut self) {
    self.averager.reset();
    self.pending_updates.clear();
    self.current_round = 0;
    self.update_counts = (0..self.averager.numnodes())
        .map(|nodeid| (nodeid, 0))
        .collect();
}
}
/// Drives communication rounds on top of a [`ParameterServer`] and tracks convergence.
#[derive(Debug)]
pub struct DistributedCoordinator<A: Float, D: Dimension> {
/// Server that stores and aggregates node updates.
parameter_server: ParameterServer<A, D>,
/// Number of communication rounds counted so far.
communication_rounds: usize,
/// A round counts as converged when its metric falls below this value (default `1e-6`).
convergence_threshold: A,
/// Round count at which `should_continue` becomes false.
max_rounds: usize,
/// History of recorded rounds.
training_stats: TrainingStats<A>,
}
impl<A: Float + ScalarOperand + Debug + Send + Sync, D: Dimension + Send + Sync>
DistributedCoordinator<A, D>
{
/// Creates a coordinator with a fresh parameter server, a convergence threshold of
/// `1e-6` and an empty statistics record.
pub fn new(
    strategy: AveragingStrategy,
    numnodes: usize,
    expected_updates_per_round: usize,
    max_rounds: usize,
) -> Self {
    let parameter_server = ParameterServer::new(strategy, numnodes, expected_updates_per_round);
    let default_threshold = A::from(1e-6).expect("unwrap failed");
    Self {
        parameter_server,
        max_rounds,
        communication_rounds: 0,
        convergence_threshold: default_threshold,
        training_stats: TrainingStats::new(),
    }
}
/// Initializes the parameter server with `initialparams` and records them as round 0
/// with a convergence metric of zero.
///
/// # Errors
/// Propagates the server's initialization error.
pub fn initialize(&mut self, initialparams: &[Array<A, D>]) -> Result<()> {
    self.parameter_server.initialize(initialparams)?;
    let baseline_metric = A::zero();
    self.training_stats
        .record_round(0, baseline_metric, initialparams);
    Ok(())
}
/// Submits one batch of node updates, aggregates them and reports the round outcome.
///
/// If the submissions did not trigger an aggregation by themselves but updates are
/// pending, an aggregation is forced. A round is only counted when an aggregation
/// actually happened; previously the flag was forced to `true`, which counted a round
/// (and recorded statistics) even when there was nothing to aggregate and left the
/// "no aggregation" branch unreachable.
///
/// # Returns
/// * After an aggregation: the new round number, the global parameters, the convergence
///   metric, and `should_continue = !converged && round < max_rounds`.
/// * Without an aggregation: the unchanged round number, `converged = false`, an
///   infinite metric, and `should_continue` while the round budget is not exhausted.
///
/// # Errors
/// Propagates errors from submitting or aggregating updates.
pub fn communication_round(
    &mut self,
    node_updates: Vec<(usize, Vec<Array<A, D>>)>,
) -> Result<CommunicationResult<A, D>> {
    let mut aggregated = false;
    for (nodeid, params) in node_updates {
        if self.parameter_server.submit_update(nodeid, params)? {
            aggregated = true;
        }
    }
    if !aggregated && self.parameter_server.pending_updates_count() > 0 {
        self.parameter_server.force_aggregation()?;
        aggregated = true;
    }

    if !aggregated {
        return Ok(CommunicationResult {
            round: self.communication_rounds,
            global_parameters: self.parameter_server.get_global_parameters_cloned(),
            converged: false,
            should_continue: self.communication_rounds < self.max_rounds,
            convergence_metric: A::infinity(),
            stats: self.training_stats.clone(),
        });
    }

    self.communication_rounds += 1;
    let currentparams = self.parameter_server.get_global_parameters();
    // The metric must be computed before recording, because recording replaces the
    // "previous parameters" snapshot it compares against.
    let convergence_metric = self.compute_convergence_metric(currentparams);
    self.training_stats.record_round(
        self.communication_rounds,
        convergence_metric,
        currentparams,
    );
    let converged = convergence_metric < self.convergence_threshold;
    let max_rounds_reached = self.communication_rounds >= self.max_rounds;
    Ok(CommunicationResult {
        round: self.communication_rounds,
        global_parameters: self.parameter_server.get_global_parameters_cloned(),
        converged,
        should_continue: !converged && !max_rounds_reached,
        convergence_metric,
        stats: self.training_stats.clone(),
    })
}
/// Replaces the threshold below which a round's convergence metric counts as converged.
pub fn set_convergence_threshold(&mut self, threshold: A) {
self.convergence_threshold = threshold;
}
/// Shared access to the underlying parameter server.
pub fn parameter_server(&self) -> &ParameterServer<A, D> {
&self.parameter_server
}
/// Mutable access to the underlying parameter server.
pub fn parameter_server_mut(&mut self) -> &mut ParameterServer<A, D> {
&mut self.parameter_server
}
/// Relative change between `currentparams` and the previously recorded parameters:
/// `sqrt(sum((c - p)^2) / sum(c^2))`.
///
/// Returns infinity when the statistics hold no previous parameters, and zero when the
/// squared norm of the current parameters is not positive.
fn compute_convergence_metric(&self, currentparams: &[Array<A, D>]) -> A {
    let prev_params = match self.training_stats.get_previous_parameters::<D>() {
        Some(previous) => previous,
        None => return A::infinity(),
    };

    let (squared_change, squared_norm) = currentparams
        .iter()
        .zip(prev_params.iter())
        .flat_map(|(curr, prev)| curr.iter().zip(prev.iter()))
        .fold((A::zero(), A::zero()), |(change, norm), (&c, &p)| {
            let delta = c - p;
            (change + delta * delta, norm + c * c)
        });

    if squared_norm > A::zero() {
        (squared_change / squared_norm).sqrt()
    } else {
        A::zero()
    }
}
}
/// Outcome of one call to `DistributedCoordinator::communication_round`.
#[derive(Debug, Clone)]
pub struct CommunicationResult<A: Float, D: Dimension> {
/// Communication round number this result belongs to.
pub round: usize,
/// Copy of the global parameters at the time the result was produced.
pub global_parameters: Vec<Array<A, D>>,
/// Whether the convergence metric fell below the coordinator's threshold.
pub converged: bool,
/// Whether the caller should run another round.
pub should_continue: bool,
/// Convergence metric computed for this round (infinite when unavailable).
pub convergence_metric: A,
/// Snapshot of the coordinator's training statistics.
pub stats: TrainingStats<A>,
}
#[derive(Debug, Clone)]
pub struct TrainingStats<A: Float> {
convergence_history: Vec<A>,
round_times: Vec<usize>,
previous_parameters: Option<Vec<u8>>, }
impl<A: Float + Send + Sync> TrainingStats<A> {
pub fn new() -> Self {
Self {
convergence_history: Vec::new(),
round_times: Vec::new(),
previous_parameters: None,
}
}
pub fn record_round<D: Dimension>(
&mut self,
round: usize,
convergence_metric: A,
parameters: &[Array<A, D>],
) {
self.convergence_history.push(convergence_metric);
self.round_times.push(round);
self.previous_parameters = Some(vec![0u8; parameters.len()]);
}
pub fn convergence_history(&self) -> &[A] {
&self.convergence_history
}
pub fn latest_convergence(&self) -> Option<A> {
self.convergence_history.last().copied()
}
pub fn num_rounds(&self) -> usize {
self.round_times.len()
}
fn get_previous_parameters<D: Dimension>(&self) -> Option<Vec<Array<A, D>>> {
None
}
}
/// `Default` yields the same empty record as `TrainingStats::new`.
impl<A: Float + Send + Sync> Default for TrainingStats<A> {
fn default() -> Self {
Self::new()
}
}
/// Gradient compression scheme applied by `GradientCompressor`.
#[derive(Debug, Clone, PartialEq)]
pub enum CompressionStrategy {
/// No compression: every element is serialized as an 8-byte little-endian `f64`.
None,
/// Keep the `k` largest-magnitude elements of each gradient array.
TopK {
/// Maximum number of elements kept per array.
k: usize,
},
/// Keep `k` elements of each gradient array chosen by an index shuffle.
RandomK {
/// Maximum number of elements kept per array.
k: usize,
},
/// Keep elements whose absolute value exceeds `threshold`.
Threshold {
/// Magnitude an element must exceed to be kept.
threshold: f64,
},
/// Uniform quantization between each array's minimum and maximum value.
Quantization {
/// Bits per quantized value; the encoder rejects values above 32.
bits: u8,
},
/// Compress with `base_strategy` and track the reconstruction error in the
/// compressor's error state.
ErrorFeedback {
/// Strategy that performs the actual compression.
base_strategy: Box<CompressionStrategy>,
/// NOTE(review): not read by any code visible in this file — confirm intended use.
error_compensation: bool,
},
/// Clamp every element to `[-clip_value, clip_value]` and then compress with `base_strategy`.
ClippedCompression {
/// Strategy that performs the actual compression.
base_strategy: Box<CompressionStrategy>,
/// Symmetric clipping bound.
clip_value: f64,
},
}
/// Serialized gradients together with the information needed to reconstruct them.
#[derive(Debug, Clone)]
pub struct CompressedGradient<A: Float> {
/// Strategy-specific byte payload.
pub data: Vec<u8>,
/// Description of how `data` was produced.
pub metadata: CompressionMetadata<A>,
/// Shape of every original gradient array, in input order.
pub shapes: Vec<Vec<usize>>,
}
/// Metadata describing one compressed gradient payload.
#[derive(Debug, Clone)]
pub struct CompressionMetadata<A: Float> {
/// Strategy that produced the payload; decompression dispatches on it.
pub strategy: CompressionStrategy,
/// Payload bytes divided by `8 * element count` (fixed at `1.0` for `None`).
pub compression_ratio: f64,
/// Number of elements stored in the payload.
pub nnz_count: usize,
/// One scale per gradient array for quantization; empty for the other strategies.
pub scale_factors: Vec<A>,
/// Left empty by every compressor in this file.
pub extra_data: Vec<u8>,
}
/// Compresses and decompresses lists of gradient arrays.
#[derive(Debug)]
pub struct GradientCompressor<A: Float, D: Dimension> {
/// Strategy used by `compress`.
strategy: CompressionStrategy,
/// Per-array error carried into the next `compress` call; `None` until
/// `initialize_error_state` is called.
error_state: Option<Vec<Array<A, D>>>,
/// Running compression statistics.
stats: CompressionStats,
}
impl<A: Float + ScalarOperand + Debug + Send + Sync, D: Dimension + Send + Sync>
GradientCompressor<A, D>
{
/// Creates a compressor for `strategy` with no error state and empty statistics.
pub fn new(strategy: CompressionStrategy) -> Self {
    let stats = CompressionStats::new();
    Self {
        strategy,
        stats,
        error_state: None,
    }
}
/// Allocates a zeroed error buffer for each array in `gradientshapes`, matching its
/// shape and replacing any existing error state.
pub fn initialize_error_state(&mut self, gradientshapes: &[Array<A, D>]) {
    let mut zeroed = Vec::with_capacity(gradientshapes.len());
    for template in gradientshapes {
        zeroed.push(Array::zeros(template.raw_dim()));
    }
    self.error_state = Some(zeroed);
}
/// Compresses `gradients` with the configured strategy and records size statistics.
///
/// When an error state exists, it is added to the gradients before compression. For
/// `ErrorFeedback` the error state is then replaced by the residual between these
/// compensated gradients and their reconstruction. The residual used to be taken from
/// the raw gradients, which silently discarded the error carried over from earlier
/// steps.
///
/// The returned metadata names the innermost non-wrapper strategy, because wrapper
/// strategies delegate to a temporary compressor for their base strategy.
/// The `error_compensation` flag of `ErrorFeedback` is not consulted here.
///
/// # Errors
/// Propagates errors from the strategy-specific compressors.
pub fn compress(&mut self, gradients: &[Array<A, D>]) -> Result<CompressedGradient<A>> {
    let mut working_gradients: Vec<Array<A, D>> = match self.error_state {
        Some(ref error_state) => gradients
            .iter()
            .zip(error_state.iter())
            .map(|(grad, error)| grad + error)
            .collect(),
        None => gradients.to_vec(),
    };

    let (compressed_data, metadata) = match &self.strategy {
        CompressionStrategy::None => self.compress_none(&working_gradients)?,
        CompressionStrategy::TopK { k } => self.compress_topk(&working_gradients, *k)?,
        CompressionStrategy::RandomK { k } => self.compress_randomk(&working_gradients, *k)?,
        CompressionStrategy::Threshold { threshold } => self.compress_threshold(
            &working_gradients,
            A::from(*threshold).expect("unwrap failed"),
        )?,
        CompressionStrategy::Quantization { bits } => {
            self.compress_quantization(&working_gradients, *bits)?
        }
        CompressionStrategy::ErrorFeedback { base_strategy, .. } => {
            let mut temp_compressor =
                GradientCompressor::<A, D>::new((**base_strategy).clone());
            let compressed = temp_compressor.compress(&working_gradients)?;
            let decompressed = temp_compressor.decompress(&compressed)?;
            if let Some(ref mut error_state) = self.error_state {
                // Residual of what was actually fed to the compressor.
                for ((compensated, restored), error) in working_gradients
                    .iter()
                    .zip(decompressed.iter())
                    .zip(error_state.iter_mut())
                {
                    *error = compensated - restored;
                }
            }
            (compressed.data, compressed.metadata)
        }
        CompressionStrategy::ClippedCompression {
            base_strategy,
            clip_value,
        } => {
            let clip_val = A::from(*clip_value).expect("unwrap failed");
            for grad in &mut working_gradients {
                grad.mapv_inplace(|x| {
                    if x > clip_val {
                        clip_val
                    } else if x < -clip_val {
                        -clip_val
                    } else {
                        x
                    }
                });
            }
            let mut temp_compressor =
                GradientCompressor::<A, D>::new((**base_strategy).clone());
            let compressed = temp_compressor.compress(&working_gradients)?;
            (compressed.data, compressed.metadata)
        }
    };

    let shapes = gradients.iter().map(|g| g.shape().to_vec()).collect();
    let result = CompressedGradient {
        data: compressed_data,
        metadata,
        shapes,
    };

    let original_size = self.calculate_size(gradients);
    let compressed_size = result.data.len();
    self.stats
        .record_compression(original_size, compressed_size);
    Ok(result)
}
pub fn decompress(&self, compressed: &CompressedGradient<A>) -> Result<Vec<Array<A, D>>> {
match &compressed.metadata.strategy {
CompressionStrategy::None => self.decompress_none(compressed),
CompressionStrategy::TopK { .. } => self.decompress_sparse(compressed),
CompressionStrategy::RandomK { .. } => self.decompress_sparse(compressed),
CompressionStrategy::Threshold { .. } => self.decompress_sparse(compressed),
CompressionStrategy::Quantization { bits } => {
self.decompress_quantization(compressed, *bits)
}
CompressionStrategy::ErrorFeedback { base_strategy, .. } => {
let temp_compressor = GradientCompressor::new((**base_strategy).clone());
temp_compressor.decompress(compressed)
}
CompressionStrategy::ClippedCompression { base_strategy, .. } => {
let temp_compressor = GradientCompressor::new((**base_strategy).clone());
temp_compressor.decompress(compressed)
}
}
}
/// Serializes every element of every gradient as an 8-byte little-endian `f64`,
/// arrays in input order and elements in iteration order.
fn compress_none(
    &self,
    gradients: &[Array<A, D>],
) -> Result<(Vec<u8>, CompressionMetadata<A>)> {
    let element_count: usize = gradients.iter().map(|g| g.len()).sum();

    let mut data = Vec::with_capacity(element_count * 8);
    for value in gradients.iter().flat_map(|grad| grad.iter()) {
        data.extend_from_slice(&value.to_f64().expect("unwrap failed").to_le_bytes());
    }

    Ok((
        data,
        CompressionMetadata {
            strategy: CompressionStrategy::None,
            compression_ratio: 1.0,
            nnz_count: element_count,
            scale_factors: Vec::new(),
            extra_data: Vec::new(),
        },
    ))
}
/// Keeps the `k` largest-magnitude elements of each gradient array.
///
/// Payload layout (little-endian): a `u32` entry count followed by, per entry, the
/// `u32` array index, the `u32` element index (in iteration order) and the value as `f64`.
///
/// NaN magnitudes are ranked last instead of panicking in the comparator, and every
/// value is captured during the ranking pass instead of being looked up again with an
/// O(n) `iter().nth()` per selected element.
fn compress_topk(
    &self,
    gradients: &[Array<A, D>],
    k: usize,
) -> Result<(Vec<u8>, CompressionMetadata<A>)> {
    let mut entries: Vec<(u32, u32, A)> = Vec::new();
    let mut total_elements = 0usize;

    for (grad_idx, grad) in gradients.iter().enumerate() {
        total_elements += grad.len();

        // (magnitude, element index, original value)
        let mut ranked: Vec<(A, usize, A)> = grad
            .iter()
            .enumerate()
            .map(|(i, &val)| (val.abs(), i, val))
            .collect();
        // Descending by magnitude; NaN sorts after every number so the comparator
        // stays a total order. The stable sort keeps ties in index order.
        ranked.sort_by(|a, b| match (a.0.is_nan(), b.0.is_nan()) {
            (true, true) => std::cmp::Ordering::Equal,
            (true, false) => std::cmp::Ordering::Greater,
            (false, true) => std::cmp::Ordering::Less,
            (false, false) => b.0.partial_cmp(&a.0).unwrap_or(std::cmp::Ordering::Equal),
        });

        entries.extend(
            ranked
                .iter()
                .take(k)
                .map(|&(_, elem_idx, val)| (grad_idx as u32, elem_idx as u32, val)),
        );
    }

    let mut data = Vec::with_capacity(4 + entries.len() * 16);
    data.extend_from_slice(&(entries.len() as u32).to_le_bytes());
    for (grad_idx, elem_idx, value) in &entries {
        data.extend_from_slice(&grad_idx.to_le_bytes());
        data.extend_from_slice(&elem_idx.to_le_bytes());
        data.extend_from_slice(&value.to_f64().expect("unwrap failed").to_le_bytes());
    }

    let metadata = CompressionMetadata {
        strategy: CompressionStrategy::TopK { k },
        compression_ratio: data.len() as f64 / (total_elements * 8) as f64,
        nnz_count: entries.len(),
        scale_factors: Vec::new(),
        extra_data: Vec::new(),
    };
    Ok((data, metadata))
}
/// Keeps `k` elements of each gradient array, selected by a partial index shuffle.
///
/// The shuffle is derived only from the array index and position, so the selection is
/// deterministic. Payload layout matches the other sparse encoders: a `u32` entry
/// count, then per entry the `u32` array index, `u32` element index and `f64` value.
///
/// Element values are copied into a flat buffer once per array so that each selected
/// index is an O(1) lookup instead of an O(n) `iter().nth()` walk.
fn compress_randomk(
    &self,
    gradients: &[Array<A, D>],
    k: usize,
) -> Result<(Vec<u8>, CompressionMetadata<A>)> {
    let mut indices = Vec::new();
    let mut values = Vec::new();
    let mut total_elements = 0;

    for (grad_idx, grad) in gradients.iter().enumerate() {
        let flat: Vec<A> = grad.iter().copied().collect();
        total_elements += flat.len();
        let k_local = k.min(flat.len());

        let mut selected_indices: Vec<usize> = (0..flat.len()).collect();
        for i in 0..k_local {
            // `i < k_local <= flat.len()`, so the modulus is never zero.
            let swap_idx = i + ((grad_idx + i) % (flat.len() - i));
            selected_indices.swap(i, swap_idx);
        }
        for &idx in selected_indices.iter().take(k_local) {
            indices.push((grad_idx as u32, idx as u32));
            values.push(flat[idx]);
        }
    }

    let mut data = Vec::with_capacity(4 + indices.len() * 16);
    data.extend_from_slice(&(indices.len() as u32).to_le_bytes());
    for ((grad_idx, elem_idx), value) in indices.iter().zip(values.iter()) {
        data.extend_from_slice(&grad_idx.to_le_bytes());
        data.extend_from_slice(&elem_idx.to_le_bytes());
        data.extend_from_slice(&value.to_f64().expect("unwrap failed").to_le_bytes());
    }

    let metadata = CompressionMetadata {
        strategy: CompressionStrategy::RandomK { k },
        compression_ratio: data.len() as f64 / (total_elements * 8) as f64,
        nnz_count: indices.len(),
        scale_factors: Vec::new(),
        extra_data: Vec::new(),
    };
    Ok((data, metadata))
}
/// Keeps every element whose absolute value is strictly greater than `threshold`.
///
/// Payload layout: a `u32` entry count, then per entry the `u32` array index, the `u32`
/// element index (in iteration order) and the value as `f64`, all little-endian.
fn compress_threshold(
    &self,
    gradients: &[Array<A, D>],
    threshold: A,
) -> Result<(Vec<u8>, CompressionMetadata<A>)> {
    let mut kept: Vec<(u32, u32, A)> = Vec::new();
    let mut total_elements = 0;

    for (grad_idx, grad) in gradients.iter().enumerate() {
        total_elements += grad.len();
        kept.extend(
            grad.iter()
                .enumerate()
                .filter(|(_, val)| val.abs() > threshold)
                .map(|(elem_idx, &val)| (grad_idx as u32, elem_idx as u32, val)),
        );
    }

    let mut data = Vec::new();
    data.extend_from_slice(&(kept.len() as u32).to_le_bytes());
    for (grad_idx, elem_idx, value) in &kept {
        data.extend_from_slice(&grad_idx.to_le_bytes());
        data.extend_from_slice(&elem_idx.to_le_bytes());
        data.extend_from_slice(&value.to_f64().expect("unwrap failed").to_le_bytes());
    }

    let compression_ratio = data.len() as f64 / (total_elements * 8) as f64;
    Ok((
        data,
        CompressionMetadata {
            strategy: CompressionStrategy::Threshold {
                threshold: threshold.to_f64().expect("unwrap failed"),
            },
            compression_ratio,
            nnz_count: kept.len(),
            scale_factors: Vec::new(),
            extra_data: Vec::new(),
        },
    ))
}
/// Uniformly quantizes each gradient array to `bits` bits between its minimum and maximum.
///
/// Per array the payload holds one code per element (1, 2 or 4 little-endian bytes for
/// up to 8, 16 or 32 bits) followed by the array minimum as `f64`; the per-array scale
/// is returned in `scale_factors`.
///
/// Codes are rounded to the nearest level rather than truncated, which halves the
/// worst-case reconstruction error.
///
/// # Errors
/// Returns `OptimError::InvalidConfig` when `bits` is outside `1..=32` (zero used to hit
/// `unreachable!()`) or when a value cannot be quantized because the array contains
/// NaN or infinite entries (previously a panic).
fn compress_quantization(
    &self,
    gradients: &[Array<A, D>],
    bits: u8,
) -> Result<(Vec<u8>, CompressionMetadata<A>)> {
    if bits > 32 {
        return Err(OptimError::InvalidConfig(
            "Quantization bits must be <= 32".to_string(),
        ));
    }
    if bits == 0 {
        return Err(OptimError::InvalidConfig(
            "Quantization bits must be >= 1".to_string(),
        ));
    }

    let mut data = Vec::new();
    let mut scale_factors = Vec::new();
    let levels = (1u64 << bits) - 1;

    for grad in gradients {
        let min_val = grad.iter().fold(A::infinity(), |acc, &x| acc.min(x));
        let max_val = grad.iter().fold(A::neg_infinity(), |acc, &x| acc.max(x));
        let range = max_val - min_val;
        // A constant (or empty) array has no range; any non-zero scale works then.
        let scale = if range > A::zero() {
            range / A::from(levels).expect("unwrap failed")
        } else {
            A::one()
        };
        scale_factors.push(scale);

        for &val in grad.iter() {
            let normalized = ((val - min_val) / scale).round();
            let quantized = match normalized.to_u64() {
                Some(level) => level.min(levels) as u32,
                None => {
                    return Err(OptimError::InvalidConfig(
                        "Cannot quantize non-finite gradient value".to_string(),
                    ));
                }
            };
            match bits {
                1..=8 => data.push(quantized as u8),
                9..=16 => data.extend_from_slice(&(quantized as u16).to_le_bytes()),
                _ => data.extend_from_slice(&quantized.to_le_bytes()),
            }
        }
        data.extend_from_slice(&min_val.to_f64().expect("unwrap failed").to_le_bytes());
    }

    let total_elements: usize = gradients.iter().map(|g| g.len()).sum();
    let metadata = CompressionMetadata {
        strategy: CompressionStrategy::Quantization { bits },
        compression_ratio: data.len() as f64 / (total_elements * 8) as f64,
        nnz_count: total_elements,
        scale_factors,
        extra_data: Vec::new(),
    };
    Ok((data, metadata))
}
/// Rebuilds arrays from an uncompressed payload of consecutive little-endian `f64`
/// values, consuming as many values per array as its recorded shape requires.
///
/// # Errors
/// Returns `OptimError::InvalidConfig` when the payload is too short, a shape does not
/// match the decoded values, or a shape cannot be converted to dimension type `D`.
fn decompress_none(&self, compressed: &CompressedGradient<A>) -> Result<Vec<Array<A, D>>> {
    let payload = &compressed.data;
    let mut cursor = 0usize;
    let mut arrays = Vec::with_capacity(compressed.shapes.len());

    for shape in &compressed.shapes {
        let expected: usize = shape.iter().product();
        let mut values = Vec::with_capacity(expected);
        while values.len() < expected {
            let end = cursor + 8;
            if end > payload.len() {
                return Err(OptimError::InvalidConfig(
                    "Insufficient data for decompression".to_string(),
                ));
            }
            let raw = f64::from_le_bytes(payload[cursor..end].try_into().expect("unwrap failed"));
            values.push(A::from(raw).expect("unwrap failed"));
            cursor = end;
        }

        let array = Array::from_shape_vec(shape.as_slice(), values)
            .map_err(|_| {
                OptimError::InvalidConfig("Invalid shape for reconstruction".to_string())
            })?
            .into_dimensionality::<D>()
            .map_err(|_| OptimError::InvalidConfig("Dimension conversion failed".to_string()))?;
        arrays.push(array);
    }
    Ok(arrays)
}
/// Rebuilds arrays from a sparse payload: zero-filled arrays of the recorded shapes with
/// the stored entries written back.
///
/// Expected layout (little-endian): a `u32` entry count, then per entry a `u32` array
/// index, a `u32` element index (row-major position) and the value as `f64`.
///
/// Entries are written through the array's contiguous slice, making each write O(1);
/// the previous `iter_mut().nth()` walked the array once per entry.
///
/// # Errors
/// Returns `OptimError::InvalidConfig` for a missing header, a truncated entry list,
/// out-of-range array or element indices, or shapes that cannot be converted to `D`.
fn decompress_sparse(&self, compressed: &CompressedGradient<A>) -> Result<Vec<Array<A, D>>> {
    const HEADER_BYTES: usize = 4;
    const ENTRY_BYTES: usize = 16;

    let mut result: Vec<Array<A, D>> = Vec::with_capacity(compressed.shapes.len());
    for shape in &compressed.shapes {
        let dynamic_array = Array::zeros(shape.as_slice());
        let array = dynamic_array.into_dimensionality::<D>().map_err(|_| {
            OptimError::InvalidConfig("Dimension conversion failed for zero array".to_string())
        })?;
        result.push(array);
    }

    let data = compressed.data.as_slice();
    if data.len() < HEADER_BYTES {
        return Err(OptimError::InvalidConfig(
            "Invalid compressed data format".to_string(),
        ));
    }
    let num_elements =
        u32::from_le_bytes(data[..HEADER_BYTES].try_into().expect("header is 4 bytes")) as usize;

    let mut data_offset = HEADER_BYTES;
    for _ in 0..num_elements {
        let entry = data
            .get(data_offset..data_offset + ENTRY_BYTES)
            .ok_or_else(|| {
                OptimError::InvalidConfig(
                    "Insufficient data for sparse decompression".to_string(),
                )
            })?;
        data_offset += ENTRY_BYTES;

        let grad_idx =
            u32::from_le_bytes(entry[0..4].try_into().expect("slice is 4 bytes")) as usize;
        let elem_idx =
            u32::from_le_bytes(entry[4..8].try_into().expect("slice is 4 bytes")) as usize;
        let value = A::from(f64::from_le_bytes(
            entry[8..16].try_into().expect("slice is 8 bytes"),
        ))
        .expect("unwrap failed");

        let target = result.get_mut(grad_idx).ok_or_else(|| {
            OptimError::InvalidConfig("Invalid gradient index in compressed data".to_string())
        })?;
        // Arrays created by `zeros` are contiguous in row-major order, so the flat
        // slice position equals the logical iteration position used by the encoders.
        let flat = target.as_slice_mut().ok_or_else(|| {
            OptimError::InvalidConfig("Decompression target is not contiguous".to_string())
        })?;
        match flat.get_mut(elem_idx) {
            Some(slot) => *slot = value,
            None => {
                return Err(OptimError::InvalidConfig(
                    "Invalid element index in compressed data".to_string(),
                ));
            }
        }
    }
    Ok(result)
}
/// Rebuilds arrays from a quantized payload.
///
/// Per array the payload holds one code per element (1, 2 or 4 little-endian bytes for
/// up to 8, 16 or 32 bits) followed by the array minimum as `f64`. Each value is
/// reconstructed as `min + code * scale`, with `scale` taken from the metadata.
///
/// `bits` is validated before anything else. The previous version evaluated
/// `1u64 << bits` first (into an unused variable), which overflows for `bits >= 64`.
///
/// # Errors
/// Returns `OptimError::InvalidConfig` when `bits` is outside `1..=32`, the payload is
/// truncated, a scale factor is missing, or a shape cannot be rebuilt as dimension `D`.
fn decompress_quantization(
    &self,
    compressed: &CompressedGradient<A>,
    bits: u8,
) -> Result<Vec<Array<A, D>>> {
    // Bytes used per quantized code.
    let width: usize = match bits {
        1..=8 => 1,
        9..=16 => 2,
        17..=32 => 4,
        _ => {
            return Err(OptimError::InvalidConfig(
                "Invalid quantization bits".to_string(),
            ));
        }
    };

    let data = compressed.data.as_slice();
    let mut result: Vec<Array<A, D>> = Vec::with_capacity(compressed.shapes.len());
    let mut data_offset = 0usize;

    for (grad_idx, shape) in compressed.shapes.iter().enumerate() {
        let num_elements: usize = shape.iter().product();
        let mut codes: Vec<u32> = Vec::with_capacity(num_elements);
        for _ in 0..num_elements {
            let chunk = data.get(data_offset..data_offset + width).ok_or_else(|| {
                OptimError::InvalidConfig("Insufficient quantized data".to_string())
            })?;
            // Little-endian decode: fold from the most significant byte down.
            let code = chunk
                .iter()
                .rev()
                .fold(0u32, |acc, &byte| (acc << 8) | u32::from(byte));
            codes.push(code);
            data_offset += width;
        }

        let min_bytes = data.get(data_offset..data_offset + 8).ok_or_else(|| {
            OptimError::InvalidConfig("Missing min value for quantization".to_string())
        })?;
        let min_val = A::from(f64::from_le_bytes(
            min_bytes.try_into().expect("slice is 8 bytes"),
        ))
        .expect("unwrap failed");
        data_offset += 8;

        let scale = match compressed.metadata.scale_factors.get(grad_idx) {
            Some(&scale) => scale,
            None => {
                return Err(OptimError::InvalidConfig(
                    "Missing scale factor for quantization".to_string(),
                ));
            }
        };

        let dequantized_values: Vec<A> = codes
            .into_iter()
            .map(|code| min_val + A::from(code).expect("unwrap failed") * scale)
            .collect();
        let dynamic_array = Array::from_shape_vec(shape.as_slice(), dequantized_values)
            .map_err(|_| {
                OptimError::InvalidConfig(
                    "Invalid shape for quantized reconstruction".to_string(),
                )
            })?;
        let array = dynamic_array.into_dimensionality::<D>().map_err(|_| {
            OptimError::InvalidConfig(
                "Dimension conversion failed for quantized array".to_string(),
            )
        })?;
        result.push(array);
    }
    Ok(result)
}
/// In-memory size of `gradients` in bytes: element count times `size_of::<A>()`.
fn calculate_size(&self, gradients: &[Array<A, D>]) -> usize {
    let element_bytes = std::mem::size_of::<A>();
    let mut total = 0usize;
    for grad in gradients {
        total += grad.len() * element_bytes;
    }
    total
}
/// Statistics accumulated by `compress` since construction or the last `reset_stats`.
pub fn stats(&self) -> &CompressionStats {
&self.stats
}
/// Discards all accumulated compression statistics.
pub fn reset_stats(&mut self) {
    let fresh = CompressionStats::new();
    self.stats = fresh;
}
}
/// Running totals describing how well gradients have been compressed.
///
/// Ratios are `compressed / original`, so smaller values mean better compression.
#[derive(Debug, Clone)]
pub struct CompressionStats {
    /// Number of recorded compressions.
    pub compressions_count: usize,
    /// Sum of the original sizes in bytes.
    pub total_original_bytes: usize,
    /// Sum of the compressed sizes in bytes.
    pub total_compressed_bytes: usize,
    /// Overall ratio `total_compressed_bytes / total_original_bytes` (0 when undefined).
    pub average_compression_ratio: f64,
    /// Smallest single-call ratio seen; `f64::INFINITY` before the first recording.
    pub best_compression_ratio: f64,
    /// Largest single-call ratio seen; `0.0` before the first recording.
    pub worst_compression_ratio: f64,
}

impl CompressionStats {
    /// Creates an empty record.
    pub fn new() -> Self {
        Self {
            compressions_count: 0,
            total_original_bytes: 0,
            total_compressed_bytes: 0,
            average_compression_ratio: 0.0,
            best_compression_ratio: f64::INFINITY,
            worst_compression_ratio: 0.0,
        }
    }

    /// Records one compression of `original_bytes` down to `compressedbytes`.
    ///
    /// A call with `original_bytes == 0` counts with a ratio of `1.0`.
    pub fn record_compression(&mut self, original_bytes: usize, compressedbytes: usize) {
        self.compressions_count += 1;
        self.total_original_bytes += original_bytes;
        self.total_compressed_bytes += compressedbytes;

        let ratio = if original_bytes > 0 {
            compressedbytes as f64 / original_bytes as f64
        } else {
            1.0
        };
        self.best_compression_ratio = self.best_compression_ratio.min(ratio);
        self.worst_compression_ratio = self.worst_compression_ratio.max(ratio);
        self.average_compression_ratio = if self.total_original_bytes > 0 {
            self.total_compressed_bytes as f64 / self.total_original_bytes as f64
        } else {
            0.0
        };
    }

    /// Overall ratio of compressed to original bytes across all recordings.
    pub fn overall_compression_ratio(&self) -> f64 {
        self.average_compression_ratio
    }

    /// Percentage of bytes saved relative to sending the original data.
    ///
    /// Returns `0.0` while no original bytes have been recorded; the undefined ratio
    /// used to be treated as zero, which reported a misleading 100% saving.
    pub fn bandwidth_savings(&self) -> f64 {
        if self.total_original_bytes == 0 {
            return 0.0;
        }
        (1.0 - self.average_compression_ratio) * 100.0
    }
}

impl Default for CompressionStats {
    fn default() -> Self {
        Self::new()
    }
}
#[cfg(test)]
mod tests {
use super::*;
use approx::assert_relative_eq;
use scirs2_core::ndarray::Array1;
/// Three nodes with two-element parameters average to the element-wise mean.
#[test]
fn test_arithmetic_averaging() {
    let mut averager: ParameterAverager<f64, scirs2_core::ndarray::Ix1> =
        ParameterAverager::new(AveragingStrategy::Arithmetic, 3);

    let nodeparameters: Vec<(usize, Vec<Array1<f64>>)> = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]
        .iter()
        .enumerate()
        .map(|(nodeid, values)| (nodeid, vec![Array1::from_vec(values.to_vec())]))
        .collect();

    averager
        .average_parameters(&nodeparameters)
        .expect("unwrap failed");

    let result = averager.get_averaged_parameters();
    let expected = [3.0, 4.0];
    for (index, want) in expected.iter().enumerate() {
        assert_relative_eq!(result[0][index], *want, epsilon = 1e-6);
    }
}
/// Weights 0.75 and 0.25 applied to 2.0 and 6.0 give 0.75 * 2 + 0.25 * 6 = 3.0.
#[test]
fn test_weighted_averaging() {
    let mut averager: ParameterAverager<f64, scirs2_core::ndarray::Ix1> =
        ParameterAverager::new(AveragingStrategy::WeightedByData, 2);
    let params1 = vec![Array1::from_vec(vec![2.0])];
    let params2 = vec![Array1::from_vec(vec![6.0])];
    let nodeparameters = vec![(0, params1.clone()), (1, params2.clone())];

    // The argument was a mis-decoded `&params1` (`&para` had become `¶`), which did not compile.
    averager.initialize(&params1).expect("unwrap failed");
    averager.set_node_weight(0, 0.75).expect("unwrap failed");
    averager.set_node_weight(1, 0.25).expect("unwrap failed");

    averager
        .average_parameters(&nodeparameters)
        .expect("unwrap failed");
    let result = averager.get_averaged_parameters();
    assert_relative_eq!(result[0][0], 3.0, epsilon = 1e-6);
}
/// With momentum 0.9 the first round moves the average only a small step from zero,
/// and repeated rounds bring it closer to the node mean.
#[test]
fn test_momentum_averaging() {
    let mut averager: ParameterAverager<f64, scirs2_core::ndarray::Ix1> =
        ParameterAverager::new(AveragingStrategy::Momentum { momentum: 0.9 }, 2);
    let params1 = vec![Array1::from_vec(vec![1.0])];
    let params2 = vec![Array1::from_vec(vec![3.0])];
    let make_round = || -> Vec<(usize, Vec<Array1<f64>>)> {
        vec![(0, params1.clone()), (1, params2.clone())]
    };

    averager
        .average_parameters(&make_round())
        .expect("unwrap failed");
    let after_first = averager.get_averaged_parameters()[0][0];
    assert!(after_first >= 0.0 && after_first <= 0.5);

    for _ in 0..10 {
        averager
            .average_parameters(&make_round())
            .expect("unwrap failed");
    }
    let after_many = averager.get_averaged_parameters()[0][0];
    assert!(after_many > 0.5 && after_many < 2.5);
}
/// A parameter server fed two updates: the first submission must not report
/// readiness, the second must, and the global parameters then hold the mean.
#[test]
fn test_parameter_server() {
    let mut server = ParameterServer::new(AveragingStrategy::Arithmetic, 2, 2);
    let initialparams = vec![Array1::from_vec(vec![0.0, 0.0])];
    server.initialize(&initialparams).expect("unwrap failed");

    let first_update = vec![Array1::from_vec(vec![1.0, 2.0])];
    let second_update = vec![Array1::from_vec(vec![3.0, 4.0])];

    let first_ready = server
        .submit_update(0, first_update)
        .expect("unwrap failed");
    assert!(!first_ready);
    let second_ready = server
        .submit_update(1, second_update)
        .expect("unwrap failed");
    assert!(second_ready);

    let global_params = server.get_global_parameters();
    // (1 + 3) / 2 = 2 and (2 + 4) / 2 = 3.
    assert_relative_eq!(global_params[0][0], 2.0, epsilon = 1e-6);
    assert_relative_eq!(global_params[0][1], 3.0, epsilon = 1e-6);
    assert_eq!(server.current_round(), 1);
}
/// Runs three communication rounds through a coordinator and checks the
/// per-round result: round number, continuation flag, convergence flag and
/// a positive first global parameter.
#[test]
fn test_distributed_coordinator() {
    let mut coordinator =
        DistributedCoordinator::new(AveragingStrategy::Arithmetic, 2, 2, 10);
    let initialparams = vec![Array1::from_vec(vec![0.0])];
    coordinator
        .initialize(&initialparams)
        .expect("unwrap failed");

    for round in 1..=3 {
        // Node 0 submits `round`, node 1 submits `2 * round`.
        let node_updates = vec![
            (0, vec![Array1::from_vec(vec![round as f64])]),
            (1, vec![Array1::from_vec(vec![(round * 2) as f64])]),
        ];
        let outcome = coordinator
            .communication_round(node_updates)
            .expect("unwrap failed");

        assert_eq!(outcome.round, round);
        assert!(outcome.should_continue);
        assert!(!outcome.converged);
        assert!(outcome.global_parameters[0][0] > 0.0);
    }
}
/// Smoke-tests each averaging strategy on the inputs 1.0 and 3.0.
///
/// The three simple strategies must stay within `[1, 3]`; the two
/// strategies constructed with a momentum/decay coefficient are only
/// required to stay within `[0, 3]`.
#[test]
fn test_averaging_strategies() {
    // Averages {1.0, 3.0} across two nodes with `strategy` and checks that
    // the result lies in `[lower_bound, 3.0]`.
    let check_strategy = |strategy: AveragingStrategy, lower_bound: f64| {
        let mut averager: ParameterAverager<f64, scirs2_core::ndarray::Ix1> =
            ParameterAverager::new(strategy, 2);
        let nodeparameters = vec![
            (0, vec![Array1::from_vec(vec![1.0])]),
            (1, vec![Array1::from_vec(vec![3.0])]),
        ];
        averager
            .average_parameters(&nodeparameters)
            .expect("unwrap failed");
        let result = averager.get_averaged_parameters();
        assert!(result[0][0] >= lower_bound && result[0][0] <= 3.0);
    };

    let simple_strategies = vec![
        AveragingStrategy::Arithmetic,
        AveragingStrategy::WeightedByData,
        AveragingStrategy::Federated,
    ];
    simple_strategies
        .into_iter()
        .for_each(|strategy| check_strategy(strategy, 1.0));

    let stateful_strategies = vec![
        AveragingStrategy::Momentum { momentum: 0.9 },
        AveragingStrategy::ExponentialMovingAverage { decay: 0.9 },
    ];
    stateful_strategies
        .into_iter()
        .for_each(|strategy| check_strategy(strategy, 0.0));
}
/// With two nodes configured, node ids 0 and 1 are accepted by
/// `set_node_weight` while id 2 is rejected.
#[test]
fn test_node_weight_validation() {
    let mut averager: ParameterAverager<f64, scirs2_core::ndarray::Ix1> =
        ParameterAverager::new(AveragingStrategy::WeightedByData, 2);

    for valid_id in 0..2 {
        assert!(averager.set_node_weight(valid_id, 0.5).is_ok());
    }
    // An id equal to the node count is out of range.
    assert!(averager.set_node_weight(2, 0.5).is_err());
}
/// Averaging parameters of mismatched lengths (2 vs 1) must not succeed:
/// either the call panics or it returns an error.
#[test]
fn test_parameter_dimension_validation() {
    let mut averager: ParameterAverager<f64, scirs2_core::ndarray::Ix1> =
        ParameterAverager::new(AveragingStrategy::Arithmetic, 2);
    let nodeparameters = vec![
        (0, vec![Array1::from_vec(vec![1.0, 2.0])]),
        (1, vec![Array1::from_vec(vec![3.0])]),
    ];

    // Catch a potential panic so that both failure modes can be accepted.
    let outcome = std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| {
        averager.average_parameters(&nodeparameters)
    }));
    let rejected = match outcome {
        // The call panicked.
        Err(_) => true,
        // The call returned; it must have returned an error.
        Ok(returned) => returned.is_err(),
    };
    assert!(rejected);
}
/// `TrainingStats` starts empty and reflects a single recorded round.
#[test]
fn test_training_stats() {
    let mut stats = TrainingStats::new();
    // Nothing has been recorded yet.
    assert_eq!(stats.num_rounds(), 0);
    assert!(stats.latest_convergence().is_none());

    let params = vec![Array1::from_vec(vec![1.0])];
    // Record round 1 with a convergence value of 0.5.
    stats.record_round(1, 0.5, &params);

    assert_eq!(stats.num_rounds(), 1);
    assert_eq!(stats.latest_convergence(), Some(0.5));
    assert_eq!(stats.convergence_history(), &[0.5]);
}
/// The `None` strategy must report a ratio of 1.0 and return the gradients
/// unchanged after a compress/decompress round trip.
#[test]
fn test_gradient_compression_none() {
    let expected: Vec<Vec<f64>> = vec![vec![1.0, 2.0, 3.0], vec![4.0, 5.0]];
    let gradients: Vec<Array1<f64>> = expected
        .iter()
        .cloned()
        .map(Array1::from_vec)
        .collect();

    let mut compressor = GradientCompressor::new(CompressionStrategy::None);
    let compressed = compressor.compress(&gradients).expect("unwrap failed");
    assert_eq!(compressed.metadata.strategy, CompressionStrategy::None);
    assert_eq!(compressed.metadata.compression_ratio, 1.0);

    let decompressed = compressor.decompress(&compressed).expect("unwrap failed");
    // The length check guarantees the zip below visits both arrays.
    assert_eq!(decompressed.len(), 2);
    for (restored, original) in decompressed.iter().zip(&expected) {
        assert_eq!(
            restored.as_slice().expect("unwrap failed"),
            original.as_slice()
        );
    }
}
/// Top-k with `k = 2` must keep the two largest entries (3.0 and 4.0) and
/// zero out everything else.
#[test]
fn test_gradient_compression_topk() {
    let gradients = vec![Array1::from_vec(vec![0.1, 3.0, 0.2, 4.0, 0.05])];
    let mut compressor = GradientCompressor::new(CompressionStrategy::TopK { k: 2 });

    let compressed = compressor.compress(&gradients).expect("unwrap failed");
    assert!(compressed.metadata.compression_ratio < 1.0);
    assert_eq!(compressed.metadata.nnz_count, 2);

    let decompressed = compressor.decompress(&compressed).expect("unwrap failed");
    assert_eq!(decompressed.len(), 1);
    let result = &decompressed[0];
    // Expected dense reconstruction, index by index.
    let expected = [0.0, 3.0, 0.0, 4.0, 0.0];
    for (index, &value) in expected.iter().enumerate() {
        assert_eq!(result[index], value);
    }
}
/// Threshold compression at 1.0 must keep only 2.0 and 3.0 and zero out
/// the three smaller entries.
#[test]
fn test_gradient_compression_threshold() {
    let gradients = vec![Array1::from_vec(vec![0.5, 2.0, 0.8, 3.0, 0.3])];
    let strategy = CompressionStrategy::Threshold { threshold: 1.0 };
    let mut compressor = GradientCompressor::new(strategy);

    let compressed = compressor.compress(&gradients).expect("unwrap failed");
    assert!(compressed.metadata.compression_ratio < 1.0);
    assert_eq!(compressed.metadata.nnz_count, 2);

    let decompressed = compressor.decompress(&compressed).expect("unwrap failed");
    let result = &decompressed[0];
    // Expected dense reconstruction, index by index.
    let expected = [0.0, 2.0, 0.0, 3.0, 0.0];
    for (index, &value) in expected.iter().enumerate() {
        assert_eq!(result[index], value);
    }
}
/// 8-bit quantization must compress the data and reproduce every value to
/// within an absolute error of 0.1.
#[test]
fn test_gradient_compression_quantization() {
    let gradients = vec![Array1::from_vec(vec![1.0, 2.0, 3.0, 4.0])];
    let mut compressor = GradientCompressor::new(CompressionStrategy::Quantization { bits: 8 });

    let compressed = compressor.compress(&gradients).expect("unwrap failed");
    assert!(compressed.metadata.compression_ratio < 1.0);

    let decompressed = compressor.decompress(&compressed).expect("unwrap failed");
    let result = &decompressed[0];
    // (index, original value) pairs to compare against.
    let originals = vec![(0usize, 1.0), (1, 2.0), (2, 3.0), (3, 4.0)];
    for (index, original) in originals {
        assert!((result[index] - original).abs() < 0.1);
    }
}
/// Random-k with `k = 3` on ten non-zero values must leave exactly three
/// non-zero entries after decompression.
#[test]
fn test_gradient_compression_randomk() {
    // The values 1.0, 2.0, ..., 10.0.
    let values: Vec<f64> = (1..=10i32).map(f64::from).collect();
    let gradients = vec![Array1::from_vec(values)];
    let mut compressor = GradientCompressor::new(CompressionStrategy::RandomK { k: 3 });

    let compressed = compressor.compress(&gradients).expect("unwrap failed");
    assert!(compressed.metadata.compression_ratio < 1.0);
    assert_eq!(compressed.metadata.nnz_count, 3);

    let decompressed = compressor.decompress(&compressed).expect("unwrap failed");
    let result = &decompressed[0];
    let non_zero_count = result.iter().filter(|value| **value != 0.0).count();
    assert_eq!(non_zero_count, 3);
}
/// Error-feedback compression wrapping top-k: two consecutive
/// compress/decompress passes over the same gradients must each yield one
/// array. Only the output length is checked.
#[test]
fn test_gradient_compression_error_feedback() {
    let strategy = CompressionStrategy::ErrorFeedback {
        base_strategy: Box::new(CompressionStrategy::TopK { k: 2 }),
        error_compensation: true,
    };
    let mut compressor = GradientCompressor::new(strategy);
    let gradients = vec![Array1::from_vec(vec![1.0, 2.0, 3.0, 4.0])];
    compressor.initialize_error_state(&gradients);

    let first_packed = compressor.compress(&gradients).expect("unwrap failed");
    let first_restored = compressor
        .decompress(&first_packed)
        .expect("unwrap failed");
    let second_packed = compressor.compress(&gradients).expect("unwrap failed");
    let second_restored = compressor
        .decompress(&second_packed)
        .expect("unwrap failed");

    assert_eq!(first_restored.len(), 1);
    assert_eq!(second_restored.len(), 1);
}
/// Clipped compression with `clip_value = 2.5`: every non-zero value that
/// survives the round trip must lie within `[-2.5, 2.5]`.
#[test]
fn test_gradient_compression_clipped() {
    let strategy = CompressionStrategy::ClippedCompression {
        base_strategy: Box::new(CompressionStrategy::TopK { k: 3 }),
        clip_value: 2.5,
    };
    let mut compressor = GradientCompressor::new(strategy);
    // 5.0 and -3.0 exceed the clip value in magnitude.
    let gradients = vec![Array1::from_vec(vec![1.0, 5.0, -3.0, 2.0])];

    let compressed = compressor.compress(&gradients).expect("unwrap failed");
    let decompressed = compressor.decompress(&compressed).expect("unwrap failed");

    let result = &decompressed[0];
    result
        .iter()
        .filter(|value| **value != 0.0)
        .for_each(|value| assert!((-2.5..=2.5).contains(value)));
}
/// `CompressionStats` bookkeeping across two recorded compressions.
#[test]
fn test_compression_stats() {
    let mut stats = CompressionStats::new();
    assert_eq!(stats.compressions_count, 0);
    assert_eq!(stats.overall_compression_ratio(), 0.0);

    // (original bytes, compressed bytes, expected count,
    //  expected cumulative ratio, expected savings in percent)
    // Cumulative ratios: 500 / 1000 = 0.5, then 750 / 2000 = 0.375.
    let steps = vec![(1000, 500, 1, 0.5, 50.0), (1000, 250, 2, 0.375, 62.5)];
    for (original_bytes, compressed_bytes, count, ratio, savings) in steps {
        stats.record_compression(original_bytes, compressed_bytes);
        assert_eq!(stats.compressions_count, count);
        assert_relative_eq!(stats.overall_compression_ratio(), ratio, epsilon = 1e-6);
        assert_relative_eq!(stats.bandwidth_savings(), savings, epsilon = 1e-6);
    }

    // Per-call ratios were 0.5 and 0.25.
    assert_relative_eq!(stats.best_compression_ratio, 0.25, epsilon = 1e-6);
    assert_relative_eq!(stats.worst_compression_ratio, 0.5, epsilon = 1e-6);
}
/// Round-trips the same gradients through several strategies.
///
/// Every strategy must preserve the number of arrays and their shapes. The
/// `None` strategy must also reproduce the values (to 1e-10); the others
/// only have to produce finite values.
#[test]
fn test_compression_roundtrip() {
    let gradients = vec![
        Array1::from_vec(vec![1.0, 2.5, 0.5, 3.0]),
        Array1::from_vec(vec![0.1, 4.0]),
    ];
    let strategies = vec![
        CompressionStrategy::None,
        CompressionStrategy::TopK { k: 2 },
        CompressionStrategy::RandomK { k: 2 },
        CompressionStrategy::Threshold { threshold: 1.5 },
        CompressionStrategy::Quantization { bits: 4 },
    ];

    for strategy in strategies {
        // Decide which check applies before the strategy is moved away.
        let is_lossless = matches!(strategy, CompressionStrategy::None);
        let mut compressor = GradientCompressor::new(strategy);
        let compressed = compressor.compress(&gradients).expect("unwrap failed");
        let decompressed = compressor.decompress(&compressed).expect("unwrap failed");

        assert_eq!(decompressed.len(), gradients.len());
        for (orig, decomp) in gradients.iter().zip(decompressed.iter()) {
            assert_eq!(orig.shape(), decomp.shape());
        }

        if is_lossless {
            for (orig, decomp) in gradients.iter().zip(decompressed.iter()) {
                for (&o, &d) in orig.iter().zip(decomp.iter()) {
                    assert_relative_eq!(o, d, epsilon = 1e-10);
                }
            }
        } else {
            for decomp in &decompressed {
                assert!(decomp.iter().all(|&x| x.is_finite()));
            }
        }
    }
}
/// Invalid inputs must be reported as errors: 64-bit quantization on
/// compress, and a hand-built inconsistent payload on decompress.
#[test]
fn test_compression_invalid_configs() {
    // Case 1: quantization with 64 bits is expected to be rejected.
    let gradients = vec![Array1::from_vec(vec![1.0, 2.0])];
    let mut quantizer = GradientCompressor::new(CompressionStrategy::Quantization { bits: 64 });
    assert!(quantizer.compress(&gradients).is_err());

    // Case 2: three payload bytes for a declared shape of two elements.
    // NOTE(review): presumably rejected because the byte count does not
    // match the element count; confirm against `decompress`.
    let metadata = CompressionMetadata {
        strategy: CompressionStrategy::None,
        compression_ratio: 1.0,
        nnz_count: 1,
        scale_factors: vec![],
        extra_data: vec![],
    };
    let invalid_compressed = CompressedGradient {
        data: vec![1, 2, 3],
        metadata,
        shapes: vec![vec![2]],
    };
    let valid_compressor: GradientCompressor<f64, scirs2_core::ndarray::Ix1> =
        GradientCompressor::new(CompressionStrategy::None);
    assert!(valid_compressor.decompress(&invalid_compressed).is_err());
}
/// End-to-end check: two nodes compress their gradients with top-1, the
/// decompressed updates go to a parameter server, and the global parameters
/// are the average of the sparsified updates.
#[test]
fn test_distributed_with_compression() {
    let mut server = ParameterServer::new(AveragingStrategy::Arithmetic, 2, 2);
    let initialparams = vec![Array1::from_vec(vec![0.0, 0.0])];
    server.initialize(&initialparams).expect("unwrap failed");

    let mut compressor = GradientCompressor::new(CompressionStrategy::TopK { k: 1 });
    let node_gradients = vec![
        vec![Array1::from_vec(vec![1.0, 3.0])],
        vec![Array1::from_vec(vec![2.0, 1.0])],
    ];

    // Compress both gradient sets first, then decompress both, in node order.
    let compressed: Vec<_> = node_gradients
        .iter()
        .map(|gradients| compressor.compress(gradients).expect("unwrap failed"))
        .collect();
    let decompressed: Vec<_> = compressed
        .iter()
        .map(|packed| compressor.decompress(packed).expect("unwrap failed"))
        .collect();

    // Node ids are 0 and 1, in submission order.
    for (nodeid, update) in (0..).zip(decompressed) {
        server.submit_update(nodeid, update).expect("unwrap failed");
    }

    let global_params = server.get_global_parameters();
    // The expected values correspond to averaging [0, 3] and [2, 0].
    assert_relative_eq!(global_params[0][0], 1.0, epsilon = 1e-6);
    assert_relative_eq!(global_params[0][1], 1.5, epsilon = 1e-6);
}
}