use crate::error::{CoreError, CoreResult, ErrorContext};
use std::sync::{
atomic::{AtomicU64, Ordering},
Arc, Mutex,
};
/// Acquires the parameter mutex, converting a poisoned lock into a
/// `CoreError::ComputationError` so callers deal with a `CoreResult`
/// instead of a panic-propagating `PoisonError`.
fn lock_params(m: &Mutex<Vec<f64>>) -> CoreResult<std::sync::MutexGuard<'_, Vec<f64>>> {
    match m.lock() {
        Ok(guard) => Ok(guard),
        Err(_) => Err(CoreError::ComputationError(ErrorContext::new(
            "parameter server: mutex poisoned",
        ))),
    }
}
/// Shared parameter store for distributed SGD.
///
/// Supports both asynchronous per-worker updates (`push_gradient`) and a
/// synchronous averaged step (`sync_step`); every completed update bumps
/// `version` so workers can detect staleness.
pub struct ParameterServer {
// Parameter vector; all reads/writes go through this lock.
params: Arc<Mutex<Vec<f64>>>,
// Count of completed update steps (written with Release, read with Acquire).
version: Arc<AtomicU64>,
// Fixed worker count; validates worker ids and sync_step batch sizes.
n_workers: usize,
// Step size applied to (averaged) gradients.
learning_rate: f64,
}
impl ParameterServer {
    /// Builds a server holding `initial_params`, expecting gradients from
    /// `n_workers` workers and applying them with learning rate `lr`.
    ///
    /// # Panics
    /// Panics if `n_workers` is zero or `initial_params` is empty.
    pub fn new(initial_params: Vec<f64>, n_workers: usize, lr: f64) -> Self {
        assert!(n_workers > 0, "n_workers must be > 0");
        assert!(!initial_params.is_empty(), "initial_params must not be empty");
        Self {
            params: Arc::new(Mutex::new(initial_params)),
            version: Arc::new(AtomicU64::new(0)),
            n_workers,
            learning_rate: lr,
        }
    }

    /// Returns a snapshot (clone) of the current parameter vector.
    ///
    /// Fix: if the mutex is poisoned (a writer panicked mid-update), the data
    /// is recovered via `PoisonError::into_inner` instead of silently
    /// returning an empty `Vec`. The old fallback made a poisoned lock
    /// indistinguishable from "no parameters" and hid the failure entirely.
    pub fn get_params(&self) -> Vec<f64> {
        match self.params.lock() {
            Ok(guard) => guard.clone(),
            // Updates are element-wise writes, so even a partially applied
            // step leaves a structurally valid Vec we can still hand out.
            Err(poisoned) => poisoned.into_inner().clone(),
        }
    }

    /// Number of completed update steps (`push_gradient` or `sync_step`).
    pub fn version(&self) -> u64 {
        self.version.load(Ordering::Acquire)
    }

    /// Configured worker count.
    pub fn n_workers(&self) -> usize {
        self.n_workers
    }

    /// Asynchronous update: immediately applies one worker's gradient
    /// (`params -= lr * gradient`) and bumps the version.
    ///
    /// # Errors
    /// Returns `CoreError::ValueError` if `worker_id >= n_workers` or the
    /// gradient length does not match the parameter length, and
    /// `CoreError::ComputationError` if the parameter mutex is poisoned.
    pub fn push_gradient(&self, worker_id: usize, gradient: &[f64]) -> CoreResult<()> {
        // worker_id is only validated; updates are not tracked per worker.
        if worker_id >= self.n_workers {
            return Err(CoreError::ValueError(ErrorContext::new(format!(
                "worker_id {worker_id} >= n_workers {}",
                self.n_workers
            ))));
        }
        let mut params = lock_params(&self.params)?;
        if gradient.len() != params.len() {
            return Err(CoreError::ValueError(ErrorContext::new(format!(
                "gradient length {} does not match parameter length {}",
                gradient.len(),
                params.len()
            ))));
        }
        let lr = self.learning_rate;
        for (p, &g) in params.iter_mut().zip(gradient.iter()) {
            *p -= lr * g;
        }
        // Publish the step after the parameters are written; readers pair
        // this Release with the Acquire load in `version()`.
        self.version.fetch_add(1, Ordering::Release);
        Ok(())
    }

    /// Synchronous update: averages one gradient per worker and applies a
    /// single step (`params -= lr * mean(gradients)`), bumping the version
    /// exactly once.
    ///
    /// # Errors
    /// Returns `CoreError::ValueError` if the number of gradients differs
    /// from `n_workers` or any gradient has the wrong length, and
    /// `CoreError::ComputationError` if the parameter mutex is poisoned.
    pub fn sync_step(&self, gradients: Vec<Vec<f64>>) -> CoreResult<()> {
        if gradients.len() != self.n_workers {
            return Err(CoreError::ValueError(ErrorContext::new(format!(
                "expected {} gradients, got {}",
                self.n_workers,
                gradients.len()
            ))));
        }
        let mut params = lock_params(&self.params)?;
        let param_len = params.len();
        let n = self.n_workers as f64;
        // Validate every gradient before touching the parameters so a bad
        // input cannot leave a half-applied step behind.
        for (i, g) in gradients.iter().enumerate() {
            if g.len() != param_len {
                return Err(CoreError::ValueError(ErrorContext::new(format!(
                    "gradient {i} has length {} but param vector has length {param_len}",
                    g.len()
                ))));
            }
        }
        let mut avg_grad = vec![0.0_f64; param_len];
        for g in &gradients {
            for (acc, &val) in avg_grad.iter_mut().zip(g.iter()) {
                *acc += val;
            }
        }
        for acc in avg_grad.iter_mut() {
            *acc /= n;
        }
        let lr = self.learning_rate;
        for (p, &avg) in params.iter_mut().zip(avg_grad.iter()) {
            *p -= lr * avg;
        }
        self.version.fetch_add(1, Ordering::Release);
        Ok(())
    }
}
/// Selects the `k` entries of `gradient` with the largest absolute value.
///
/// Returns `(indices, values)` where `indices` is sorted ascending and
/// `values[i] == gradient[indices[i]]`. `k` is clipped to `gradient.len()`;
/// an empty gradient yields empty outputs.
///
/// # Panics
/// Panics if `k == 0`.
pub fn top_k_sparsify(gradient: &[f64], k: usize) -> (Vec<usize>, Vec<f64>) {
    assert!(k > 0, "k must be > 0");
    if gradient.is_empty() {
        return (Vec::new(), Vec::new());
    }
    let keep = k.min(gradient.len());
    // Pair each magnitude with its position so a partial sort can rank them.
    let mut ranked: Vec<(f64, usize)> = gradient
        .iter()
        .enumerate()
        .map(|(pos, &x)| (x.abs(), pos))
        .collect();
    let pivot = gradient.len() - keep;
    if pivot > 0 {
        // O(n) partition: after this, ranked[pivot..] holds the `keep`
        // largest magnitudes (NaN compares as Equal, as in a stable no-op).
        ranked.select_nth_unstable_by(pivot, |lhs, rhs| {
            lhs.0.partial_cmp(&rhs.0).unwrap_or(std::cmp::Ordering::Equal)
        });
    }
    // Recover original positions/values, then restore ascending index order.
    let mut picked: Vec<(usize, f64)> = ranked[pivot..]
        .iter()
        .map(|&(_, pos)| (pos, gradient[pos]))
        .collect();
    picked.sort_unstable_by_key(|&(pos, _)| pos);
    let mut indices = Vec::with_capacity(picked.len());
    let mut values = Vec::with_capacity(picked.len());
    for (pos, val) in picked {
        indices.push(pos);
        values.push(val);
    }
    (indices, values)
}
/// Expands a sparse `(indices, values)` gradient into a dense vector of
/// length `size`, with zeros at every position not listed in `indices`.
///
/// # Panics
/// Panics if `indices` and `values` differ in length, or if any index is
/// `>= size`.
pub fn decompress_gradient(indices: &[usize], values: &[f64], size: usize) -> Vec<f64> {
    assert_eq!(
        indices.len(),
        values.len(),
        "indices and values must have the same length"
    );
    let mut dense = vec![0.0_f64; size];
    indices.iter().zip(values.iter()).for_each(|(&idx, &val)| {
        assert!(idx < size, "index {idx} out of bounds for size {size}");
        dense[idx] = val;
    });
    dense
}
/// Top-k gradient compressor with error feedback: the part of each gradient
/// that is not transmitted is carried over (as a residual) into the next
/// compression step, so no mass is permanently dropped.
pub struct ErrorFeedbackCompressor {
// Residual from previous steps; added to each incoming gradient.
error_buffer: Vec<f64>,
// Number of entries transmitted per step.
k: usize,
}
impl ErrorFeedbackCompressor {
    /// Creates a compressor tracking `size` parameters and transmitting the
    /// `k` largest-magnitude corrected entries per step.
    ///
    /// # Panics
    /// Panics if `size` or `k` is zero.
    pub fn new(size: usize, k: usize) -> Self {
        assert!(size > 0, "size must be > 0");
        assert!(k > 0, "k must be > 0");
        Self {
            error_buffer: vec![0.0_f64; size],
            k,
        }
    }

    /// Read-only view of the current residual buffer.
    pub fn error_buffer(&self) -> &[f64] {
        &self.error_buffer
    }

    /// Compresses `gradient` (plus the carried-over residual) to its top-k
    /// entries and updates the residual with whatever was not transmitted.
    ///
    /// # Panics
    /// Panics if `gradient` does not match the compressor size.
    pub fn compress_and_feedback(&mut self, gradient: &[f64]) -> (Vec<usize>, Vec<f64>) {
        assert_eq!(
            gradient.len(),
            self.error_buffer.len(),
            "gradient length must match compressor size"
        );
        // Fold the incoming gradient into the residual in place, turning the
        // buffer into the error-corrected gradient.
        for (slot, &g) in self.error_buffer.iter_mut().zip(gradient.iter()) {
            *slot += g;
        }
        let (indices, values) = top_k_sparsify(&self.error_buffer, self.k);
        // Transmitted entries are sent exactly, so their residual is zero;
        // untransmitted entries keep their corrected value as the new error.
        for &idx in &indices {
            self.error_buffer[idx] = 0.0;
        }
        (indices, values)
    }

    /// Euclidean (L2) norm of the residual buffer.
    pub fn error_norm(&self) -> f64 {
        self.error_buffer
            .iter()
            .fold(0.0_f64, |acc, &v| acc + v * v)
            .sqrt()
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::Arc;
    use std::thread;

    #[test]
    fn test_push_gradient_single_worker() {
        let ps = ParameterServer::new(vec![0.0_f64, 0.0, 0.0, 0.0], 2, 0.1);
        ps.push_gradient(0, &[1.0, 2.0, 3.0, 4.0])
            .expect("push failed");
        let params = ps.get_params();
        let expected = [-0.1_f64, -0.2, -0.3, -0.4];
        for (p, &e) in params.iter().zip(expected.iter()) {
            assert!((p - e).abs() < 1e-10, "got {p}, expected {e}");
        }
    }

    #[test]
    fn test_push_gradient_increments_version() {
        let ps = ParameterServer::new(vec![0.0_f64; 2], 1, 0.01);
        assert_eq!(ps.version(), 0);
        ps.push_gradient(0, &[1.0, 1.0]).expect("push failed");
        assert_eq!(ps.version(), 1);
        ps.push_gradient(0, &[1.0, 1.0]).expect("push failed");
        assert_eq!(ps.version(), 2);
    }

    #[test]
    fn test_push_gradient_concurrent() {
        let n = 4_usize;
        let ps = Arc::new(ParameterServer::new(vec![0.0_f64; 3], n, 1.0));
        let handles: Vec<_> = (0..n)
            .map(|id| {
                let ps_ref = Arc::clone(&ps);
                thread::spawn(move || ps_ref.push_gradient(id, &[1.0, 1.0, 1.0]))
            })
            .collect();
        for h in handles {
            h.join().expect("thread panic").expect("push error");
        }
        assert_eq!(ps.version(), n as u64);
        let params = ps.get_params();
        // Fix: source contained mojibake `¶ms` here; intended `&params`.
        for p in &params {
            assert!((p - -4.0_f64).abs() < 1e-10, "got {p}");
        }
    }

    #[test]
    fn test_push_gradient_invalid_worker_id() {
        let ps = ParameterServer::new(vec![0.0_f64; 2], 2, 0.1);
        assert!(ps.push_gradient(5, &[1.0, 2.0]).is_err());
    }

    #[test]
    fn test_push_gradient_wrong_length() {
        let ps = ParameterServer::new(vec![0.0_f64; 4], 2, 0.1);
        assert!(ps.push_gradient(0, &[1.0, 2.0]).is_err());
    }

    #[test]
    fn test_sync_step_averages_gradients() {
        let ps = ParameterServer::new(vec![0.0_f64; 4], 2, 1.0);
        let gradients = vec![vec![2.0_f64; 4], vec![4.0_f64; 4]];
        ps.sync_step(gradients).expect("sync_step failed");
        let params = ps.get_params();
        // Fix: source contained mojibake `¶ms` here; intended `&params`.
        for p in &params {
            assert!((p - -3.0_f64).abs() < 1e-10, "got {p}");
        }
    }

    #[test]
    fn test_sync_step_wrong_worker_count() {
        let ps = ParameterServer::new(vec![0.0_f64; 2], 3, 0.1);
        let gradients = vec![vec![1.0_f64; 2], vec![1.0_f64; 2]];
        assert!(ps.sync_step(gradients).is_err());
    }

    #[test]
    fn test_sync_step_increments_version() {
        let ps = ParameterServer::new(vec![0.0_f64; 2], 2, 0.1);
        let gradients = vec![vec![0.0_f64; 2], vec![0.0_f64; 2]];
        ps.sync_step(gradients).expect("sync_step failed");
        assert_eq!(ps.version(), 1);
    }

    #[test]
    fn test_top_k_sparsify_selects_largest_abs() {
        let grad = vec![0.1_f64, -0.5, 0.3, -0.8, 0.2];
        let (idx, vals) = top_k_sparsify(&grad, 2);
        assert_eq!(idx, vec![1, 3]);
        assert_eq!(vals, vec![-0.5, -0.8]);
    }

    #[test]
    fn test_top_k_sparsify_k_equals_length() {
        let grad = vec![1.0_f64, 2.0, 3.0];
        let (idx, vals) = top_k_sparsify(&grad, 3);
        assert_eq!(idx.len(), 3);
        assert_eq!(vals.len(), 3);
    }

    #[test]
    fn test_top_k_sparsify_k_exceeds_length_clips_to_len() {
        let grad = vec![1.0_f64, 2.0];
        let (idx, vals) = top_k_sparsify(&grad, 100);
        assert_eq!(idx.len(), 2);
        assert_eq!(vals.len(), 2);
    }

    #[test]
    fn test_top_k_sparsify_indices_sorted() {
        let grad: Vec<f64> = (0..20).rev().map(|i| i as f64).collect();
        let (idx, _vals) = top_k_sparsify(&grad, 5);
        for w in idx.windows(2) {
            assert!(w[0] < w[1], "indices must be in ascending order");
        }
    }

    #[test]
    fn test_decompress_gradient_round_trip() {
        let grad = vec![0.0_f64, 0.5, 0.0, -0.3, 0.0];
        let (idx, vals) = top_k_sparsify(&grad, 2);
        let dense = decompress_gradient(&idx, &vals, grad.len());
        for (a, b) in dense.iter().zip(grad.iter()) {
            assert!((a - b).abs() < 1e-10, "dense[i]={a} vs grad[i]={b}");
        }
    }

    #[test]
    fn test_decompress_gradient_zeros_for_missing() {
        let idx = vec![2_usize, 4];
        let vals = vec![1.5_f64, -0.7];
        let dense = decompress_gradient(&idx, &vals, 6);
        assert_eq!(dense[0], 0.0);
        assert_eq!(dense[1], 0.0);
        assert!((dense[2] - 1.5).abs() < 1e-10);
        assert_eq!(dense[3], 0.0);
        assert!((dense[4] - -0.7).abs() < 1e-10);
        assert_eq!(dense[5], 0.0);
    }

    #[test]
    fn test_error_feedback_compressor_basic() {
        let mut comp = ErrorFeedbackCompressor::new(5, 2);
        let grad = vec![0.1_f64, -0.5, 0.3, -0.8, 0.2];
        let (idx, vals) = comp.compress_and_feedback(&grad);
        assert_eq!(idx.len(), 2);
        assert_eq!(vals.len(), 2);
        let err = comp.error_buffer();
        let transmitted = decompress_gradient(&idx, &vals, 5);
        for (i, (&e, &g)) in err.iter().zip(grad.iter()).enumerate() {
            let expected_err = g - transmitted[i];
            assert!((e - expected_err).abs() < 1e-10);
        }
    }

    /// Error feedback guarantees the residual stays bounded, not that it
    /// shrinks monotonically — renamed from `..._decreases_over_steps`,
    /// which over-promised what the assertion checked.
    #[test]
    fn test_error_feedback_error_norm_stays_finite() {
        let size = 10_usize;
        let k = 3_usize;
        let mut comp = ErrorFeedbackCompressor::new(size, k);
        let grad: Vec<f64> = (0..size).map(|i| (i as f64) * 0.01).collect();
        let mut norms = Vec::new();
        for _ in 0..20 {
            comp.compress_and_feedback(&grad);
            norms.push(comp.error_norm());
        }
        // Check every step, not just the last (the old version collected
        // all norms but only inspected the final one).
        for norm in &norms {
            assert!(norm.is_finite(), "error norm diverged: {norm}");
        }
    }

    #[test]
    fn test_error_feedback_accumulates_residual() {
        let mut comp = ErrorFeedbackCompressor::new(2, 1);
        let grad = vec![1.0_f64, 0.9];
        let (idx1, _vals1) = comp.compress_and_feedback(&grad);
        assert_eq!(idx1.len(), 1);
        let (idx2, _vals2) = comp.compress_and_feedback(&grad);
        assert_eq!(idx2.len(), 1);
        let covered: std::collections::HashSet<usize> =
            idx1.iter().chain(idx2.iter()).cloned().collect();
        assert_eq!(covered.len(), 2, "both elements should be covered over 2 steps");
    }

    #[test]
    fn test_error_feedback_zero_gradient_keeps_zero_error() {
        let mut comp = ErrorFeedbackCompressor::new(4, 2);
        let grad = vec![0.0_f64; 4];
        comp.compress_and_feedback(&grad);
        for &e in comp.error_buffer() {
            assert_eq!(e, 0.0);
        }
    }
}