use crate::error::{DistributedError, DistributedResult};
use async_trait::async_trait;
use candle_core::Tensor;
use kwaai_compression::{BlockwiseQuantizer, Compressor, QuantizedTensor};
use serde::{Deserialize, Serialize};
use std::time::Duration;
use tracing::{debug, info, warn};
/// Outcome of a single averaging round (see `ParameterAverager::step`).
///
/// NOTE(review): this type derives `Serialize`/`Deserialize`; with index-based serde
/// formats the variant order is part of the wire format, so keep it stable.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum AveragingResult {
/// The round completed.
Success {
/// Number of participants whose gradients were combined. The local-only
/// `DecentralizedAverager` in this file always reports 1.
peers_count: usize,
/// Compression ratio reported for the round. `DecentralizedAverager` always
/// reports 1.0 (presumably "no compression" — confirm).
compression_ratio: f32,
},
/// Nothing was averaged. `DecentralizedAverager::step` also returns this when no
/// gradients are pending, not only when peers are absent.
NoPeersAvailable,
/// NOTE(review): not constructed in this file; presumably a group is still forming,
/// with `ready_peers` out of `target_size` peers ready — confirm with the networked path.
InProgress {
ready_peers: usize,
target_size: usize,
},
/// The round failed; presumably the string is a human-readable reason (not
/// constructed in this file).
Failed(String),
}
/// Strategy for combining gradients into an averaged update.
///
/// Callers buffer gradients with [`accumulate`](Self::accumulate), combine them with
/// [`step`](Self::step), read the buffer through [`get_accumulated`](Self::get_accumulated)
/// and reset it with [`clear`](Self::clear).
#[async_trait]
pub trait ParameterAverager: Send + Sync {
    /// Adds one batch of gradient tensors to the local buffer.
    fn accumulate(&mut self, gradients: &[Tensor]) -> DistributedResult<()>;

    /// Read-only view of the tensors currently held in the buffer.
    fn get_accumulated(&self) -> &[Tensor];

    /// Discards every buffered tensor.
    fn clear(&mut self);

    /// Runs one averaging round over the buffered gradients and reports the outcome.
    async fn step(&mut self) -> DistributedResult<AveragingResult>;
}
/// Tuning knobs for `DecentralizedAverager`.
///
/// The `Default` impl yields: `group_size` 4, `match_timeout` 30 s, `exchange_timeout` 60 s,
/// `quantization_block_size` 64, `enable_compression` true.
#[derive(Debug, Clone)]
pub struct AveragingConfig {
/// Presumably the number of peers per averaging group — TODO confirm.
/// `DecentralizedAverager` currently only logs it.
pub group_size: usize,
/// Presumably the time allowed to find averaging partners; not read in this file.
pub match_timeout: Duration,
/// Presumably the time allowed for the gradient exchange; not read in this file.
pub exchange_timeout: Duration,
/// Block size handed to `BlockwiseQuantizer::new` when the averager is created.
pub quantization_block_size: usize,
/// Compression toggle. `DecentralizedAverager` currently only logs it;
/// `compress_gradients` does not consult it.
pub enable_compression: bool,
}
impl Default for AveragingConfig {
    /// Baseline configuration: groups of 4, a 30 s match timeout, a 60 s exchange timeout,
    /// 64-element quantization blocks, and compression switched on.
    fn default() -> Self {
        let group_size = 4;
        let quantization_block_size = 64;
        let enable_compression = true;
        let match_timeout = Duration::from_secs(30);
        let exchange_timeout = Duration::from_secs(60);

        Self {
            group_size,
            match_timeout,
            exchange_timeout,
            quantization_block_size,
            enable_compression,
        }
    }
}
/// Gradient averager that, in this file, averages only locally accumulated gradients
/// (its `step` always reports a single participant).
///
/// Holds the configuration it was built with, the running gradient buffer, a blockwise
/// quantizer for (de)compressing gradients, and a count of pending accumulations.
pub struct DecentralizedAverager {
// Presumably kept for future networked averaging. After construction nothing in this
// file reads `self.config` (`new` reads its fields before moving it in), so the
// dead-code lint is silenced deliberately.
#[allow(dead_code)]
config: AveragingConfig,
// Running element-wise sum of accumulated gradients; after a successful `step` it
// holds the averaged values instead.
accumulated: Vec<Tensor>,
// Quantizer constructed from `config.quantization_block_size`.
compressor: BlockwiseQuantizer,
// Number of `accumulate` calls since the last `step`/`clear` reset it to 0.
accumulation_count: usize,
}
impl DecentralizedAverager {
    /// Creates an averager with an empty accumulation buffer.
    ///
    /// The gradient quantizer is built with `config.quantization_block_size`.
    pub fn new(config: AveragingConfig) -> Self {
        info!(
            group_size = config.group_size,
            compression = config.enable_compression,
            "Creating DecentralizedAverager"
        );
        let compressor = BlockwiseQuantizer::new(config.quantization_block_size);
        Self {
            config,
            accumulated: Vec::new(),
            compressor,
            accumulation_count: 0,
        }
    }

    /// Quantizes every tensor in `gradients` with the blockwise quantizer, preserving order.
    ///
    /// # Errors
    ///
    /// Returns the first compression error, converted into a [`DistributedError`].
    pub fn compress_gradients(
        &self,
        gradients: &[Tensor],
    ) -> DistributedResult<Vec<QuantizedTensor>> {
        debug!("Compressing {} gradient tensors", gradients.len());
        gradients
            .iter()
            .map(|g| self.compressor.compress(g).map_err(DistributedError::from))
            .collect()
    }

    /// Dequantizes tensors produced by [`compress_gradients`](Self::compress_gradients),
    /// preserving order. The round trip is lossy (quantization).
    ///
    /// # Errors
    ///
    /// Returns the first decompression error, converted into a [`DistributedError`].
    pub fn decompress_gradients(
        &self,
        compressed: &[QuantizedTensor],
    ) -> DistributedResult<Vec<Tensor>> {
        debug!("Decompressing {} gradient tensors", compressed.len());
        compressed
            .iter()
            .map(|c| {
                self.compressor
                    .decompress(c)
                    .map_err(DistributedError::from)
            })
            .collect()
    }

    /// Element-wise mean of several gradient sets (for example, one set per peer).
    ///
    /// Tensor `i` of the result is the mean of tensor `i` across all sets.
    ///
    /// # Errors
    ///
    /// Returns [`DistributedError::AveragingFailed`] if `gradient_sets` is empty or if the
    /// sets do not all contain the same number of tensors. Tensor-op errors (e.g. shape
    /// mismatches between corresponding tensors) are propagated.
    pub fn average_gradients(
        &self,
        gradient_sets: &[Vec<Tensor>],
    ) -> DistributedResult<Vec<Tensor>> {
        let (first_set, rest) = match gradient_sets.split_first() {
            Some(parts) => parts,
            None => {
                warn!("average_gradients called with no gradient sets");
                return Err(DistributedError::AveragingFailed(
                    "No gradients to average".to_string(),
                ));
            }
        };

        // Every set must contribute the same number of tensors; otherwise `set[i]` below
        // would panic on a shorter set instead of reporting an error.
        if let Some((offset, set)) = rest
            .iter()
            .enumerate()
            .find(|(_, set)| set.len() != first_set.len())
        {
            return Err(DistributedError::AveragingFailed(format!(
                "Gradient set {} has {} tensors, expected {}",
                offset + 1,
                set.len(),
                first_set.len()
            )));
        }

        let num_sets = gradient_sets.len() as f64;
        first_set
            .iter()
            .enumerate()
            .map(|(i, first)| {
                let sum = rest
                    .iter()
                    .try_fold(first.clone(), |acc, set| &acc + &set[i])?;
                Ok((sum / num_sets)?)
            })
            .collect()
    }
}
#[async_trait]
impl ParameterAverager for DecentralizedAverager {
    /// Adds `gradients` to the current accumulation window.
    ///
    /// The first call of a window copies `gradients` in as the new buffer. A window starts
    /// on a fresh averager, after `clear`, or after a completed `step`. Later calls in the
    /// same window add element-wise.
    ///
    /// # Errors
    ///
    /// Returns [`DistributedError::AveragingFailed`] if the number of tensors differs from
    /// the buffered count. Tensor-addition errors (e.g. shape mismatch) are propagated. On
    /// error, both the buffer and the counter are left unchanged.
    fn accumulate(&mut self, gradients: &[Tensor]) -> DistributedResult<()> {
        // A zero count means no window is open. After `step`, the buffer holds the previous
        // round's *averaged* values. Adding new gradients on top of them, and later dividing
        // by a count of 1, would silently mix two rounds.
        if self.accumulation_count == 0 || self.accumulated.is_empty() {
            self.accumulated = gradients.to_vec();
        } else {
            if self.accumulated.len() != gradients.len() {
                return Err(DistributedError::AveragingFailed(format!(
                    "Gradient count mismatch: {} vs {}",
                    self.accumulated.len(),
                    gradients.len()
                )));
            }
            // Sum out of place so that a failing addition cannot leave the buffer
            // partially updated.
            let summed = self
                .accumulated
                .iter()
                .zip(gradients)
                .map(|(acc, grad)| acc + grad)
                .collect::<Result<Vec<_>, _>>()?;
            self.accumulated = summed;
        }
        self.accumulation_count += 1;
        debug!(
            count = self.accumulation_count,
            tensors = gradients.len(),
            "Accumulated gradients"
        );
        Ok(())
    }

    /// Averages the pending gradients locally.
    ///
    /// No peers are contacted. On success the result reports `peers_count: 1` and
    /// `compression_ratio: 1.0`. The averaged tensors stay readable through
    /// `get_accumulated`, and the window is closed (count reset to 0). Returns
    /// `NoPeersAvailable` when nothing is pending.
    ///
    /// # Errors
    ///
    /// Propagates tensor-division errors; the buffer is left unchanged in that case.
    async fn step(&mut self) -> DistributedResult<AveragingResult> {
        if self.accumulated.is_empty() {
            debug!("Averaging step: no accumulated gradients");
            return Ok(AveragingResult::NoPeersAvailable);
        }
        if self.accumulation_count == 0 {
            debug!("Averaging step: no accumulations");
            return Ok(AveragingResult::NoPeersAvailable);
        }

        let count = self.accumulation_count as f64;
        // Divide out of place so that a failure cannot leave a half-averaged buffer behind.
        let averaged = self
            .accumulated
            .iter()
            .map(|acc| acc.clone() / count)
            .collect::<Result<Vec<_>, _>>()?;
        self.accumulated = averaged;
        self.accumulation_count = 0;

        info!(peers = 1, "Averaging step completed");
        Ok(AveragingResult::Success {
            peers_count: 1,
            compression_ratio: 1.0,
        })
    }

    /// Current buffer contents: running sums before `step`, averages after it.
    fn get_accumulated(&self) -> &[Tensor] {
        &self.accumulated
    }

    /// Drops all buffered tensors and resets the accumulation counter.
    fn clear(&mut self) {
        self.accumulated.clear();
        self.accumulation_count = 0;
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use candle_core::Device;

    #[tokio::test]
    async fn test_averaging() {
        let config = AveragingConfig::default();
        let mut averager = DecentralizedAverager::new(config);
        let grad1 = Tensor::from_vec(vec![1.0f32, 2.0, 3.0], &[3], &Device::Cpu).unwrap();
        let grad2 = Tensor::from_vec(vec![4.0f32, 5.0, 6.0], &[3], &Device::Cpu).unwrap();
        averager.accumulate(&[grad1]).unwrap();
        averager.accumulate(&[grad2]).unwrap();
        let result = averager.step().await.unwrap();
        match result {
            AveragingResult::Success { peers_count, .. } => {
                assert_eq!(peers_count, 1);
                let averaged = averager.get_accumulated();
                let values: Vec<f32> = averaged[0].to_vec1().unwrap();
                assert!((values[0] - 2.5).abs() < 0.01);
                assert!((values[1] - 3.5).abs() < 0.01);
                assert!((values[2] - 4.5).abs() < 0.01);
            }
            other => panic!("Expected success, got {other:?}"),
        }
    }

    #[tokio::test]
    async fn test_step_with_no_accumulated_returns_no_peers() {
        let mut averager = DecentralizedAverager::new(AveragingConfig::default());
        let result = averager.step().await.unwrap();
        assert!(matches!(result, AveragingResult::NoPeersAvailable));
    }

    // A second step with nothing newly accumulated must not re-average the buffer.
    #[tokio::test]
    async fn test_second_step_without_new_gradients_returns_no_peers() {
        let mut averager = DecentralizedAverager::new(AveragingConfig::default());
        let g = Tensor::from_vec(vec![1.0f32, 2.0], &[2], &Device::Cpu).unwrap();
        averager.accumulate(&[g]).unwrap();
        assert!(matches!(
            averager.step().await.unwrap(),
            AveragingResult::Success { .. }
        ));
        assert!(matches!(
            averager.step().await.unwrap(),
            AveragingResult::NoPeersAvailable
        ));
    }

    #[test]
    fn test_clear_resets_state() {
        let mut averager = DecentralizedAverager::new(AveragingConfig::default());
        let grad = Tensor::from_vec(vec![1.0f32, 2.0], &[2], &Device::Cpu).unwrap();
        averager.accumulate(&[grad]).unwrap();
        assert!(!averager.get_accumulated().is_empty());
        averager.clear();
        assert!(averager.get_accumulated().is_empty());
    }

    // Same tensor count (1 vs 1) but different shapes: the error comes from the tensor add.
    #[test]
    fn test_accumulate_shape_mismatch_returns_error() {
        let mut averager = DecentralizedAverager::new(AveragingConfig::default());
        let g1 = Tensor::from_vec(vec![1.0f32, 2.0], &[2], &Device::Cpu).unwrap();
        let g2 = Tensor::from_vec(vec![1.0f32, 2.0, 3.0], &[3], &Device::Cpu).unwrap();
        averager.accumulate(&[g1]).unwrap();
        let result = averager.accumulate(&[g2]);
        assert!(result.is_err());
    }

    // Different tensor counts (1 vs 2): exercises the explicit count check, and the
    // buffer must be left untouched.
    #[test]
    fn test_accumulate_count_mismatch_returns_error() {
        let mut averager = DecentralizedAverager::new(AveragingConfig::default());
        let a = Tensor::from_vec(vec![1.0f32, 2.0], &[2], &Device::Cpu).unwrap();
        let b = Tensor::from_vec(vec![3.0f32, 4.0], &[2], &Device::Cpu).unwrap();
        averager.accumulate(&[a.clone()]).unwrap();
        assert!(averager.accumulate(&[a, b]).is_err());
        assert_eq!(averager.get_accumulated().len(), 1);
    }

    #[test]
    fn test_compress_decompress_roundtrip() {
        let averager = DecentralizedAverager::new(AveragingConfig::default());
        let g = Tensor::from_vec(
            (0..128).map(|i| i as f32 * 0.1).collect::<Vec<_>>(),
            &[128],
            &Device::Cpu,
        )
        .unwrap();
        let compressed = averager.compress_gradients(&[g.clone()]).unwrap();
        let recovered = averager.decompress_gradients(&compressed).unwrap();
        let orig: Vec<f32> = g.to_vec1().unwrap();
        let got: Vec<f32> = recovered[0].to_vec1().unwrap();
        for (o, r) in orig.iter().zip(got.iter()) {
            assert!((o - r).abs() < 0.5, "orig={o} recovered={r}");
        }
    }

    #[test]
    fn test_average_gradients_two_sets() {
        let averager = DecentralizedAverager::new(AveragingConfig::default());
        let a = Tensor::from_vec(vec![0.0f32, 4.0], &[2], &Device::Cpu).unwrap();
        let b = Tensor::from_vec(vec![2.0f32, 8.0], &[2], &Device::Cpu).unwrap();
        let result = averager.average_gradients(&[vec![a], vec![b]]).unwrap();
        let vals: Vec<f32> = result[0].to_vec1().unwrap();
        assert!((vals[0] - 1.0).abs() < 1e-4);
        assert!((vals[1] - 6.0).abs() < 1e-4);
    }

    #[test]
    fn test_average_gradients_empty_errors() {
        let averager = DecentralizedAverager::new(AveragingConfig::default());
        assert!(averager.average_gradients(&[]).is_err());
    }

    #[test]
    fn test_default_config_values() {
        let c = AveragingConfig::default();
        assert_eq!(c.group_size, 4);
        assert_eq!(c.match_timeout, Duration::from_secs(30));
        assert_eq!(c.exchange_timeout, Duration::from_secs(60));
        assert_eq!(c.quantization_block_size, 64);
        assert!(c.enable_compression);
    }
}