use crate::arrow::{TensorDtype, TensorMetadata};
use ipfrs_core::Cid;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use thiserror::Error;
/// Errors produced by the gradient compression, aggregation, verification,
/// and differential-privacy operations in this module.
#[derive(Debug, Error)]
pub enum GradientError {
/// A gradient's shape/length did not match what the operation expected.
#[error("Shape mismatch: expected {expected:?}, got {actual:?}")]
ShapeMismatch {
expected: Vec<usize>,
actual: Vec<usize>,
},
/// A stored checksum did not match the recomputed one.
#[error("Checksum verification failed")]
ChecksumFailed,
/// A compression parameter (keep ratio / k) was out of range.
#[error("Invalid compression ratio: {0}")]
InvalidCompressionRatio(f32),
/// An aggregation was attempted over zero gradients.
#[error("Empty gradient set")]
EmptyGradientSet,
/// The tensor dtype is not supported by the operation.
#[error("Incompatible dtype: {0:?}")]
IncompatibleDtype(TensorDtype),
/// A gradient value exceeded the configured z-score threshold.
#[error("Outlier detected at index {index}: value {value}")]
OutlierDetected { index: usize, value: f32 },
/// Catch-all validation failure; the message explains the cause.
#[error("Invalid gradient: {0}")]
InvalidGradient(String),
}
/// A gradient stored in coordinate (COO-like) form: `values[i]` is the entry
/// at flat offset `indices[i]` of a dense tensor with the given `shape`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SparseGradient {
    /// Flat positions of the stored entries.
    pub indices: Vec<usize>,
    /// Stored entry values, parallel to `indices`.
    pub values: Vec<f32>,
    /// Logical dense shape of the gradient tensor.
    pub shape: Vec<usize>,
    /// Tensor metadata (always Float32 for sparse gradients).
    pub metadata: TensorMetadata,
}

impl SparseGradient {
    /// Builds a sparse gradient, synthesizing default Float32 metadata.
    pub fn new(indices: Vec<usize>, values: Vec<f32>, shape: Vec<usize>) -> Self {
        Self {
            metadata: TensorMetadata {
                name: "sparse_gradient".to_string(),
                shape: shape.clone(),
                dtype: TensorDtype::Float32,
                strides: None,
                custom: HashMap::new(),
            },
            indices,
            values,
            shape,
        }
    }

    /// Count of explicitly stored entries.
    pub fn nnz(&self) -> usize {
        self.indices.len()
    }

    /// Number of elements of the dense tensor (product of `shape`).
    pub fn total_elements(&self) -> usize {
        self.shape.iter().product()
    }

    /// Fraction of dense elements that are NOT stored.
    pub fn sparsity_ratio(&self) -> f32 {
        let dense_count = self.total_elements() as f32;
        1.0 - self.nnz() as f32 / dense_count
    }

    /// Materializes the dense tensor; out-of-range indices are skipped.
    pub fn to_dense(&self) -> Vec<f32> {
        let total = self.total_elements();
        let mut dense = vec![0.0f32; total];
        for (pos, val) in self.indices.iter().copied().zip(self.values.iter().copied()) {
            if pos < total {
                dense[pos] = val;
            }
        }
        dense
    }

    /// Validates that all indices are in range and that `indices`/`values`
    /// have matching lengths.
    pub fn verify_shape(&self) -> Result<(), GradientError> {
        let total = self.total_elements();
        if let Some(&bad) = self.indices.iter().find(|&&idx| idx >= total) {
            return Err(GradientError::InvalidGradient(format!(
                "Index {} out of bounds for shape {:?}",
                bad, self.shape
            )));
        }
        if self.indices.len() != self.values.len() {
            return Err(GradientError::InvalidGradient(format!(
                "Indices length {} != values length {}",
                self.indices.len(),
                self.values.len()
            )));
        }
        Ok(())
    }
}
/// A gradient quantized to 8 bits with an affine (scale + offset) encoding.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct QuantizedGradient {
    /// Quantized payload, one i8 per original f32.
    pub quantized_values: Vec<i8>,
    /// Step size between adjacent quantization levels.
    pub scale: f32,
    /// Minimum of the original values (offset of level -128).
    pub min_val: f32,
    /// Logical dense shape.
    pub shape: Vec<usize>,
    /// Tensor metadata (dtype Int8).
    pub metadata: TensorMetadata,
}

impl QuantizedGradient {
    /// Quantizes a dense f32 gradient into 8-bit form.
    pub fn from_dense(values: &[f32], shape: Vec<usize>) -> Self {
        let (quantized_values, scale, min_val) = Self::quantize_i8(values);
        Self {
            metadata: TensorMetadata {
                name: "quantized_gradient".to_string(),
                shape: shape.clone(),
                dtype: TensorDtype::Int8,
                strides: None,
                custom: HashMap::new(),
            },
            quantized_values,
            scale,
            min_val,
            shape,
        }
    }

    /// Maps values into [-128, 127]: level = round((v - min) / scale - 128).
    fn quantize_i8(values: &[f32]) -> (Vec<i8>, f32, f32) {
        if values.is_empty() {
            return (Vec::new(), 1.0, 0.0);
        }
        let mut lo = f32::INFINITY;
        let mut hi = f32::NEG_INFINITY;
        for &v in values {
            lo = lo.min(v);
            hi = hi.max(v);
        }
        // Constant input gets unit scale to avoid dividing by (near) zero.
        let scale = if (hi - lo).abs() < 1e-8 {
            1.0
        } else {
            (hi - lo) / 255.0
        };
        let mut quantized = Vec::with_capacity(values.len());
        for &v in values {
            let level = (v - lo) / scale - 128.0;
            quantized.push(level.round().clamp(-128.0, 127.0) as i8);
        }
        (quantized, scale, lo)
    }

    /// Reconstructs approximate f32 values (inverse of `quantize_i8`).
    pub fn to_dense(&self) -> Vec<f32> {
        let mut dense = Vec::with_capacity(self.quantized_values.len());
        for &q in &self.quantized_values {
            let level = q as f32 + 128.0;
            dense.push(level * self.scale + self.min_val);
        }
        dense
    }

    /// Original f32 byte count divided by quantized byte count
    /// (i8 payload plus the 8 bytes of `scale` + `min_val`).
    pub fn compression_ratio(&self) -> f32 {
        let dense_bytes = self.quantized_values.len() * 4;
        let packed_bytes = self.quantized_values.len() + 8;
        dense_bytes as f32 / packed_bytes as f32
    }
}
/// A set of per-layer gradients computed against a specific base model.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GradientDelta {
// CID of the model these gradients were computed from (custom serde
// functions because `Cid` has no derived Serialize/Deserialize).
#[serde(serialize_with = "crate::serialize_cid")]
#[serde(deserialize_with = "crate::deserialize_cid")]
pub base_model: Cid,
// Gradient per layer name; the representation may differ per layer.
pub layer_gradients: HashMap<String, LayerGradient>,
// Integrity checksum over the layer gradients.
pub checksum: u64,
// Unix timestamp (seconds) of creation.
pub timestamp: i64,
}
/// One layer's gradient in any of the supported representations.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum LayerGradient {
    /// Fully dense gradient values.
    Dense { values: Vec<f32>, shape: Vec<usize> },
    /// Sparse index/value representation.
    Sparse(SparseGradient),
    /// 8-bit quantized representation.
    Quantized(QuantizedGradient),
}

impl LayerGradient {
    /// Logical tensor shape of this gradient.
    pub fn shape(&self) -> &[usize] {
        match self {
            LayerGradient::Dense { shape, .. } => shape,
            LayerGradient::Sparse(sg) => &sg.shape,
            LayerGradient::Quantized(qg) => &qg.shape,
        }
    }

    /// Materializes the gradient as a dense `Vec<f32>`.
    pub fn to_dense(&self) -> Vec<f32> {
        match self {
            LayerGradient::Dense { values, .. } => values.clone(),
            LayerGradient::Sparse(sg) => sg.to_dense(),
            LayerGradient::Quantized(qg) => qg.to_dense(),
        }
    }

    /// Approximate in-memory payload size in bytes.
    ///
    /// Fix: sparse indices are `usize` (8 bytes on 64-bit targets), but the
    /// previous code counted them at 4 bytes each, under-reporting sparse
    /// memory use. Element sizes now come from `size_of` so the estimate is
    /// correct on any target.
    pub fn memory_size(&self) -> usize {
        use std::mem::size_of;
        match self {
            LayerGradient::Dense { values, .. } => values.len() * size_of::<f32>(),
            LayerGradient::Sparse(sg) => {
                sg.indices.len() * size_of::<usize>() + sg.values.len() * size_of::<f32>()
            }
            // i8 payload plus the `scale` and `min_val` f32 fields.
            LayerGradient::Quantized(qg) => {
                qg.quantized_values.len() * size_of::<i8>() + 2 * size_of::<f32>()
            }
        }
    }
}
impl GradientDelta {
    /// Creates an empty delta against the given base model CID, stamped
    /// with the current time.
    pub fn new(base_model: Cid) -> Self {
        Self {
            base_model,
            layer_gradients: HashMap::new(),
            checksum: 0,
            timestamp: chrono::Utc::now().timestamp(),
        }
    }

    /// Adds (or replaces) a dense gradient for `layer_name` and refreshes
    /// the checksum.
    pub fn add_dense_gradient(&mut self, layer_name: String, values: Vec<f32>, shape: Vec<usize>) {
        self.layer_gradients
            .insert(layer_name, LayerGradient::Dense { values, shape });
        self.update_checksum();
    }

    /// Adds (or replaces) a sparse gradient and refreshes the checksum.
    pub fn add_sparse_gradient(&mut self, layer_name: String, gradient: SparseGradient) {
        self.layer_gradients
            .insert(layer_name, LayerGradient::Sparse(gradient));
        self.update_checksum();
    }

    /// Adds (or replaces) a quantized gradient and refreshes the checksum.
    pub fn add_quantized_gradient(&mut self, layer_name: String, gradient: QuantizedGradient) {
        self.layer_gradients
            .insert(layer_name, LayerGradient::Quantized(gradient));
        self.update_checksum();
    }

    /// Computes the checksum over all layer gradients. Layers are hashed in
    /// sorted-name order so HashMap iteration order cannot change the
    /// result; only the first 100 dense values per layer are sampled to
    /// bound cost.
    fn compute_checksum(&self) -> u64 {
        use std::collections::hash_map::DefaultHasher;
        use std::hash::{Hash, Hasher};
        let mut hasher = DefaultHasher::new();
        self.layer_gradients.len().hash(&mut hasher);
        let mut sorted_layers: Vec<_> = self.layer_gradients.iter().collect();
        sorted_layers.sort_by_key(|(name, _)| *name);
        for (name, gradient) in sorted_layers {
            name.hash(&mut hasher);
            gradient.shape().hash(&mut hasher);
            let dense = gradient.to_dense();
            let sample_size = dense.len().min(100);
            for &v in dense.iter().take(sample_size) {
                // Hash the bit pattern: f32 does not implement Hash.
                v.to_bits().hash(&mut hasher);
            }
        }
        hasher.finish()
    }

    /// Recomputes and stores the checksum.
    fn update_checksum(&mut self) {
        self.checksum = self.compute_checksum();
    }

    /// Verifies the stored checksum against a fresh recomputation.
    ///
    /// Fix: the previous implementation cloned the entire delta (deep-copying
    /// every gradient) just to re-run the hash; it now recomputes directly.
    pub fn verify_checksum(&self) -> Result<(), GradientError> {
        if self.compute_checksum() == self.checksum {
            Ok(())
        } else {
            Err(GradientError::ChecksumFailed)
        }
    }

    /// Sum of the approximate in-memory sizes of all layer gradients.
    pub fn total_memory_size(&self) -> usize {
        self.layer_gradients.values().map(|g| g.memory_size()).sum()
    }
}
/// Stateless gradient compression strategies.
pub struct GradientCompressor;

impl GradientCompressor {
    /// Keeps only the `k` entries with the largest magnitudes.
    ///
    /// Fails with `InvalidCompressionRatio` when `k` is zero or exceeds the
    /// number of values.
    pub fn top_k(
        values: &[f32],
        shape: Vec<usize>,
        k: usize,
    ) -> Result<SparseGradient, GradientError> {
        if k == 0 || k > values.len() {
            return Err(GradientError::InvalidCompressionRatio(
                k as f32 / values.len() as f32,
            ));
        }
        // Rank positions by |value|, descending (stable sort preserves the
        // original ordering of equal magnitudes).
        let mut ranked: Vec<(usize, f32)> = values
            .iter()
            .enumerate()
            .map(|(i, &v)| (i, v.abs()))
            .collect();
        ranked.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
        ranked.truncate(k);
        let indices: Vec<usize> = ranked.iter().map(|&(i, _)| i).collect();
        let sparse_values: Vec<f32> = indices.iter().map(|&i| values[i]).collect();
        Ok(SparseGradient::new(indices, sparse_values, shape))
    }

    /// Keeps every entry whose magnitude is at least `threshold`.
    pub fn threshold(values: &[f32], shape: Vec<usize>, threshold: f32) -> SparseGradient {
        let (indices, kept): (Vec<usize>, Vec<f32>) = values
            .iter()
            .enumerate()
            .filter(|(_, v)| v.abs() >= threshold)
            .map(|(i, &v)| (i, v))
            .unzip();
        SparseGradient::new(indices, kept, shape)
    }

    /// 8-bit quantization (delegates to `QuantizedGradient`).
    pub fn quantize(values: &[f32], shape: Vec<usize>) -> QuantizedGradient {
        QuantizedGradient::from_dense(values, shape)
    }

    /// Keeps each entry independently with probability `keep_ratio`.
    ///
    /// Survivors are rescaled by `1 / keep_ratio` so the expected value of
    /// the reconstructed gradient matches the input (unbiased estimator).
    pub fn random_sparsification(
        values: &[f32],
        shape: Vec<usize>,
        keep_ratio: f32,
    ) -> Result<SparseGradient, GradientError> {
        use rand::Rng;
        if keep_ratio <= 0.0 || keep_ratio > 1.0 {
            return Err(GradientError::InvalidCompressionRatio(keep_ratio));
        }
        let mut rng = rand::rng();
        let mut indices = Vec::new();
        let mut kept = Vec::new();
        for (i, &v) in values.iter().enumerate() {
            if rng.random::<f32>() < keep_ratio {
                indices.push(i);
                kept.push(v / keep_ratio);
            }
        }
        Ok(SparseGradient::new(indices, kept, shape))
    }
}
/// Stateless aggregation routines for dense gradients.
pub struct GradientAggregator;

impl GradientAggregator {
    /// Element-wise mean over equally-weighted gradients.
    ///
    /// Fails on an empty set or on any length mismatch.
    pub fn average(gradients: &[Vec<f32>]) -> Result<Vec<f32>, GradientError> {
        let len = match gradients.first() {
            Some(first) => first.len(),
            None => return Err(GradientError::EmptyGradientSet),
        };
        if let Some(bad) = gradients.iter().find(|g| g.len() != len) {
            return Err(GradientError::ShapeMismatch {
                expected: vec![len],
                actual: vec![bad.len()],
            });
        }
        let count = gradients.len() as f32;
        let mut acc = vec![0.0f32; len];
        for gradient in gradients {
            for (slot, &v) in acc.iter_mut().zip(gradient) {
                *slot += v / count;
            }
        }
        Ok(acc)
    }

    /// Element-wise weighted mean; weights are normalized to sum to one.
    ///
    /// Fails on an empty set, a gradient/weight count mismatch, a length
    /// mismatch, or a zero weight sum.
    pub fn weighted_average(
        gradients: &[Vec<f32>],
        weights: &[f32],
    ) -> Result<Vec<f32>, GradientError> {
        if gradients.is_empty() {
            return Err(GradientError::EmptyGradientSet);
        }
        if gradients.len() != weights.len() {
            return Err(GradientError::InvalidGradient(format!(
                "Gradient count {} != weight count {}",
                gradients.len(),
                weights.len()
            )));
        }
        let len = gradients[0].len();
        if let Some(bad) = gradients.iter().find(|g| g.len() != len) {
            return Err(GradientError::ShapeMismatch {
                expected: vec![len],
                actual: vec![bad.len()],
            });
        }
        let weight_sum: f32 = weights.iter().sum();
        if weight_sum == 0.0 {
            return Err(GradientError::InvalidGradient(
                "Sum of weights is zero".to_string(),
            ));
        }
        let mut acc = vec![0.0f32; len];
        for (gradient, &weight) in gradients.iter().zip(weights) {
            let normalized = weight / weight_sum;
            for (slot, &v) in acc.iter_mut().zip(gradient) {
                *slot += v * normalized;
            }
        }
        Ok(acc)
    }

    /// Classic momentum update: `momentum_factor * m + g`, element-wise.
    pub fn apply_momentum(
        current_gradient: &[f32],
        previous_momentum: &[f32],
        momentum_factor: f32,
    ) -> Result<Vec<f32>, GradientError> {
        if current_gradient.len() != previous_momentum.len() {
            return Err(GradientError::ShapeMismatch {
                expected: vec![previous_momentum.len()],
                actual: vec![current_gradient.len()],
            });
        }
        let mut out = Vec::with_capacity(current_gradient.len());
        for (&g, &m) in current_gradient.iter().zip(previous_momentum) {
            out.push(momentum_factor * m + g);
        }
        Ok(out)
    }
}
/// Stateless validation and conditioning helpers for dense gradients.
pub struct GradientVerifier;

impl GradientVerifier {
    /// Checks that `gradient` has exactly as many elements as
    /// `expected_shape` implies (product of dimensions).
    pub fn verify_shape(gradient: &[f32], expected_shape: &[usize]) -> Result<(), GradientError> {
        let expected_size: usize = expected_shape.iter().product();
        if gradient.len() != expected_size {
            return Err(GradientError::ShapeMismatch {
                expected: expected_shape.to_vec(),
                actual: vec![gradient.len()],
            });
        }
        Ok(())
    }

    /// Rejects gradients containing any value whose z-score exceeds
    /// `std_threshold`. Empty and constant gradients pass trivially.
    pub fn detect_outliers(gradient: &[f32], std_threshold: f32) -> Result<(), GradientError> {
        if gradient.is_empty() {
            return Ok(());
        }
        let n = gradient.len() as f32;
        let mean = gradient.iter().sum::<f32>() / n;
        let variance = gradient.iter().map(|&v| (v - mean).powi(2)).sum::<f32>() / n;
        let std_dev = variance.sqrt();
        // Fix: the old code divided by a zero std_dev for constant input,
        // producing NaN z-scores that only passed because NaN comparisons
        // are false. Treat zero spread as "no outliers" explicitly.
        if std_dev == 0.0 {
            return Ok(());
        }
        for (i, &v) in gradient.iter().enumerate() {
            let z_score = (v - mean).abs() / std_dev;
            if z_score > std_threshold {
                return Err(GradientError::OutlierDetected { index: i, value: v });
            }
        }
        Ok(())
    }

    /// Rejects gradients containing NaN or infinite values.
    pub fn verify_finite(gradient: &[f32]) -> Result<(), GradientError> {
        for (i, &v) in gradient.iter().enumerate() {
            if !v.is_finite() {
                return Err(GradientError::InvalidGradient(format!(
                    "Non-finite value at index {}: {}",
                    i, v
                )));
            }
        }
        Ok(())
    }

    /// Euclidean (L2) norm of the gradient.
    pub fn l2_norm(gradient: &[f32]) -> f32 {
        gradient.iter().map(|&v| v * v).sum::<f32>().sqrt()
    }

    /// Scales the gradient in place so its L2 norm does not exceed
    /// `max_norm`; gradients already within the bound are untouched.
    pub fn clip_by_norm(gradient: &mut [f32], max_norm: f32) {
        let norm = Self::l2_norm(gradient);
        if norm > max_norm {
            let scale = max_norm / norm;
            for v in gradient.iter_mut() {
                *v *= scale;
            }
        }
    }
}
/// Tracks consumption of an (epsilon, delta) differential-privacy budget.
#[derive(Debug, Clone, Copy)]
pub struct PrivacyBudget {
    /// Total epsilon granted.
    pub epsilon: f64,
    /// Failure probability delta.
    pub delta: f64,
    /// Epsilon still available to spend.
    pub remaining_epsilon: f64,
}

impl PrivacyBudget {
    /// Creates a budget with the full epsilon available.
    pub fn new(epsilon: f64, delta: f64) -> Self {
        Self {
            epsilon,
            delta,
            remaining_epsilon: epsilon,
        }
    }

    /// Spends `epsilon_used`, failing if it exceeds what remains.
    pub fn consume(&mut self, epsilon_used: f64) -> Result<(), GradientError> {
        if epsilon_used > self.remaining_epsilon {
            return Err(GradientError::InvalidGradient(format!(
                "Insufficient privacy budget: need {}, have {}",
                epsilon_used, self.remaining_epsilon
            )));
        }
        self.remaining_epsilon -= epsilon_used;
        Ok(())
    }

    /// True once no epsilon remains.
    pub fn is_exhausted(&self) -> bool {
        self.remaining_epsilon <= 0.0
    }

    /// Remaining epsilon as a fraction of the original grant.
    pub fn remaining_fraction(&self) -> f64 {
        self.remaining_epsilon / self.epsilon
    }
}
/// Noise distribution used when applying differential privacy.
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum DPMechanism {
/// Gaussian noise (uses both epsilon and delta — see `add_gaussian_noise`).
Gaussian,
/// Laplacian noise (scaled by sensitivity / epsilon).
Laplacian,
}
/// Applies differentially-private noise to gradients while tracking
/// consumption of a privacy budget.
pub struct DifferentialPrivacy {
// Remaining (epsilon, delta) budget; debited on each noise application.
budget: PrivacyBudget,
// Sensitivity of the gradient query; scales the noise magnitude.
sensitivity: f64,
// Which noise distribution `apply_dp_sgd` uses.
mechanism: DPMechanism,
}
impl DifferentialPrivacy {
pub fn new(epsilon: f64, delta: f64, sensitivity: f64, mechanism: DPMechanism) -> Self {
Self {
budget: PrivacyBudget::new(epsilon, delta),
sensitivity,
mechanism,
}
}
pub fn add_gaussian_noise(&mut self, gradient: &mut [f32]) -> Result<(), GradientError> {
use rand::Rng;
if self.budget.is_exhausted() {
return Err(GradientError::InvalidGradient(
"Privacy budget exhausted".to_string(),
));
}
let ln_term = (1.25 / self.budget.delta).ln();
let sigma = self.sensitivity * (2.0 * ln_term).sqrt() / self.budget.epsilon;
let mut rng = rand::rng();
for v in gradient.iter_mut() {
let noise: f64 = rng.random_range(-1.0..1.0);
let gaussian_noise = sigma * noise;
*v += gaussian_noise as f32;
}
self.budget.consume(self.budget.epsilon / 100.0)?;
Ok(())
}
pub fn add_laplacian_noise(&mut self, gradient: &mut [f32]) -> Result<(), GradientError> {
use rand::Rng;
if self.budget.is_exhausted() {
return Err(GradientError::InvalidGradient(
"Privacy budget exhausted".to_string(),
));
}
let scale = self.sensitivity / self.budget.epsilon;
let mut rng = rand::rng();
for v in gradient.iter_mut() {
let u: f64 = rng.random_range(-0.5..0.5);
let laplacian_noise = -scale * u.signum() * (1.0 - 2.0 * u.abs()).ln();
*v += laplacian_noise as f32;
}
self.budget.consume(self.budget.epsilon / 100.0)?;
Ok(())
}
pub fn apply_dp_sgd(
&mut self,
gradient: &mut [f32],
clip_norm: f32,
) -> Result<(), GradientError> {
GradientVerifier::clip_by_norm(gradient, clip_norm);
match self.mechanism {
DPMechanism::Gaussian => self.add_gaussian_noise(gradient)?,
DPMechanism::Laplacian => self.add_laplacian_noise(gradient)?,
}
Ok(())
}
pub fn remaining_budget(&self) -> f64 {
self.budget.remaining_epsilon
}
pub fn is_budget_exhausted(&self) -> bool {
self.budget.is_exhausted()
}
pub fn get_privacy_params(&self) -> (f64, f64) {
(self.budget.epsilon, self.budget.delta)
}
pub fn calculate_noise_multiplier(epsilon: f64, delta: f64, sensitivity: f64) -> f64 {
let ln_term = (1.25 / delta).ln();
sensitivity * (2.0 * ln_term).sqrt() / epsilon
}
}
/// Gate that refuses to aggregate until a minimum number of participants
/// have contributed.
pub struct SecureAggregation {
    /// Minimum contributors required before aggregation is allowed.
    min_participants: usize,
    /// Contributors registered so far in the current round.
    participant_count: usize,
}

impl SecureAggregation {
    /// Creates a gate requiring `min_participants` contributors.
    pub fn new(min_participants: usize) -> Self {
        Self {
            min_participants,
            participant_count: 0,
        }
    }

    /// Registers one more contributor.
    pub fn add_participant(&mut self) {
        self.participant_count += 1;
    }

    /// Whether enough contributors have registered.
    pub fn can_aggregate(&self) -> bool {
        self.participant_count >= self.min_participants
    }

    /// Averages the gradients, failing if the participant quorum is unmet.
    pub fn aggregate_secure(&self, gradients: &[Vec<f32>]) -> Result<Vec<f32>, GradientError> {
        if self.can_aggregate() {
            GradientAggregator::average(gradients)
        } else {
            Err(GradientError::InvalidGradient(format!(
                "Not enough participants: need {}, have {}",
                self.min_participants, self.participant_count
            )))
        }
    }

    /// Clears the participant counter for a new round.
    pub fn reset(&mut self) {
        self.participant_count = 0;
    }

    /// Current number of registered contributors.
    pub fn participant_count(&self) -> usize {
        self.participant_count
    }
}
/// Lifecycle state of a federated-learning client.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum ClientState {
/// Registered but not currently training.
Idle,
/// Local training in progress.
Training,
/// Local training finished successfully.
Completed,
/// Local training failed.
Failed,
}
/// Per-client bookkeeping for federated training.
#[derive(Debug, Clone)]
pub struct ClientInfo {
    /// Unique client identifier.
    pub client_id: String,
    /// Current lifecycle state.
    pub state: ClientState,
    /// Number of local training samples the client holds.
    pub sample_count: usize,
    /// Unix timestamp (seconds) of the last state change.
    pub last_update: i64,
}

impl ClientInfo {
    /// Creates a client record in the `Idle` state, stamped with "now".
    pub fn new(client_id: String, sample_count: usize) -> Self {
        Self {
            client_id,
            state: ClientState::Idle,
            sample_count,
            last_update: chrono::Utc::now().timestamp(),
        }
    }

    /// Moves to `state` and refreshes the update timestamp.
    fn transition(&mut self, state: ClientState) {
        self.state = state;
        self.last_update = chrono::Utc::now().timestamp();
    }

    /// Marks the client as actively training.
    pub fn start_training(&mut self) {
        self.transition(ClientState::Training);
    }

    /// Marks local training as finished successfully.
    pub fn complete_training(&mut self) {
        self.transition(ClientState::Completed);
    }

    /// Marks local training as failed.
    pub fn mark_failed(&mut self) {
        self.transition(ClientState::Failed);
    }
}
/// Record of one federated training round.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct FederatedRound {
// Zero-based round index.
pub round_num: usize,
// Number of clients expected to participate this round.
pub client_count: usize,
// CID of the global model distributed at round start (custom serde
// functions because `Cid` has no derived Serialize/Deserialize).
#[serde(serialize_with = "crate::serialize_cid")]
#[serde(deserialize_with = "crate::deserialize_cid")]
pub global_model: Cid,
// Aggregated gradient, set once the round is completed.
pub aggregated_gradient: Option<Vec<f32>>,
// Unix timestamp (seconds) when the round started.
pub start_time: i64,
// Unix timestamp (seconds) when the round completed, if it has.
pub end_time: Option<i64>,
// How many clients have reported completion so far.
pub completed_count: usize,
}
impl FederatedRound {
    /// Starts a round record stamped with the current time.
    pub fn new(round_num: usize, global_model: Cid, client_count: usize) -> Self {
        Self {
            round_num,
            client_count,
            global_model,
            aggregated_gradient: None,
            start_time: chrono::Utc::now().timestamp(),
            end_time: None,
            completed_count: 0,
        }
    }

    /// Records that one more client finished its local training.
    pub fn mark_client_completed(&mut self) {
        self.completed_count += 1;
    }

    /// True once every expected client has reported completion.
    pub fn is_complete(&self) -> bool {
        self.completed_count >= self.client_count
    }

    /// Stores the aggregated result and stamps the end time.
    pub fn complete(&mut self, aggregated_gradient: Vec<f32>) {
        self.aggregated_gradient = Some(aggregated_gradient);
        self.end_time = Some(chrono::Utc::now().timestamp());
    }

    /// Wall-clock duration in seconds, available once the round has ended.
    pub fn duration(&self) -> Option<i64> {
        let started = self.start_time;
        self.end_time.map(|ended| ended - started)
    }
}
/// Detects training convergence by watching the coefficient of variation
/// (std / |mean|) of the most recent `window_size` loss values.
pub struct ConvergenceDetector {
    /// How many recent losses form the observation window.
    window_size: usize,
    /// Retained losses, oldest first; never longer than `window_size`.
    loss_history: Vec<f64>,
    /// Maximum relative spread that still counts as converged.
    threshold: f64,
}

impl ConvergenceDetector {
    /// Creates a detector over a window of `window_size` losses with the
    /// given relative-spread `threshold`.
    pub fn new(window_size: usize, threshold: f64) -> Self {
        Self {
            window_size,
            loss_history: Vec::new(),
            threshold,
        }
    }

    /// Records a loss, discarding the oldest once the window is full.
    pub fn add_loss(&mut self, loss: f64) {
        self.loss_history.push(loss);
        while self.loss_history.len() > self.window_size {
            self.loss_history.remove(0);
        }
    }

    /// True once a full window is available and its relative spread falls
    /// below the threshold; a near-zero mean is treated as converged.
    pub fn has_converged(&self) -> bool {
        let n = self.loss_history.len();
        if n < self.window_size {
            return false;
        }
        let window = &self.loss_history[n - self.window_size..];
        let mean = window.iter().sum::<f64>() / window.len() as f64;
        if mean.abs() < 1e-10 {
            return true;
        }
        let variance =
            window.iter().map(|&x| (x - mean).powi(2)).sum::<f64>() / window.len() as f64;
        variance.sqrt() / mean.abs() < self.threshold
    }

    /// Most recently recorded loss, if any.
    pub fn latest_loss(&self) -> Option<f64> {
        self.loss_history.last().copied()
    }

    /// Clears all recorded losses.
    pub fn reset(&mut self) {
        self.loss_history.clear();
    }

    /// Full retained loss history (oldest first).
    pub fn history(&self) -> &[f64] {
        &self.loss_history
    }
}
/// Orchestrates federated training rounds: enforces client quorums, caps
/// the number of rounds, and tracks convergence of reported losses.
pub struct ModelSyncProtocol {
    /// Index of the next round to start.
    current_round: usize,
    /// Hard cap on the number of rounds.
    max_rounds: usize,
    /// Minimum clients required to start a round.
    min_clients_per_round: usize,
    /// Records of all started rounds, indexed by round number.
    rounds: Vec<FederatedRound>,
    /// Convergence tracker fed by per-round losses.
    convergence: ConvergenceDetector,
}

impl ModelSyncProtocol {
    /// Builds a protocol with the given limits and convergence criteria.
    pub fn new(
        max_rounds: usize,
        min_clients_per_round: usize,
        convergence_window: usize,
        convergence_threshold: f64,
    ) -> Self {
        Self {
            current_round: 0,
            max_rounds,
            min_clients_per_round,
            rounds: Vec::new(),
            convergence: ConvergenceDetector::new(convergence_window, convergence_threshold),
        }
    }

    /// Starts a new round and returns its round number; fails when the
    /// client quorum is not met or the round cap is reached.
    pub fn start_round(
        &mut self,
        global_model: Cid,
        client_count: usize,
    ) -> Result<usize, GradientError> {
        if client_count < self.min_clients_per_round {
            return Err(GradientError::InvalidGradient(format!(
                "Not enough clients: need {}, got {}",
                self.min_clients_per_round, client_count
            )));
        }
        if self.current_round >= self.max_rounds {
            return Err(GradientError::InvalidGradient(format!(
                "Maximum rounds reached: {}",
                self.max_rounds
            )));
        }
        let round_num = self.current_round;
        self.rounds
            .push(FederatedRound::new(round_num, global_model, client_count));
        self.current_round = round_num + 1;
        Ok(round_num)
    }

    /// Finalizes `round_num` with its aggregated gradient and records the
    /// round loss for convergence tracking.
    pub fn complete_round(
        &mut self,
        round_num: usize,
        aggregated_gradient: Vec<f32>,
        loss: f64,
    ) -> Result<(), GradientError> {
        let round = self.rounds.get_mut(round_num).ok_or_else(|| {
            GradientError::InvalidGradient(format!("Invalid round number: {}", round_num))
        })?;
        round.complete(aggregated_gradient);
        self.convergence.add_loss(loss);
        Ok(())
    }

    /// True while rounds remain and training has not converged.
    pub fn should_continue(&self) -> bool {
        self.current_round < self.max_rounds && !self.convergence.has_converged()
    }

    /// Whether the loss history satisfies the convergence criterion.
    pub fn has_converged(&self) -> bool {
        self.convergence.has_converged()
    }

    /// Index of the next round to start (== rounds started so far).
    pub fn current_round(&self) -> usize {
        self.current_round
    }

    /// Number of rounds started.
    pub fn total_rounds(&self) -> usize {
        self.rounds.len()
    }

    /// Record of a specific round, if it was started.
    pub fn get_round(&self, round_num: usize) -> Option<&FederatedRound> {
        self.rounds.get(round_num)
    }

    /// Most recent reported loss.
    pub fn latest_loss(&self) -> Option<f64> {
        self.convergence.latest_loss()
    }

    /// Configured round cap.
    pub fn max_rounds(&self) -> usize {
        self.max_rounds
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_sparse_gradient() {
        let indices = vec![0, 5, 10];
        let values = vec![1.0, 2.0, 3.0];
        let shape = vec![20];
        let sparse = SparseGradient::new(indices.clone(), values.clone(), shape);
        assert_eq!(sparse.nnz(), 3);
        assert_eq!(sparse.total_elements(), 20);
        assert!((sparse.sparsity_ratio() - 0.85).abs() < 0.01);
        let dense = sparse.to_dense();
        assert_eq!(dense.len(), 20);
        assert_eq!(dense[0], 1.0);
        assert_eq!(dense[5], 2.0);
        assert_eq!(dense[10], 3.0);
    }

    #[test]
    fn test_quantized_gradient() {
        let values = vec![1.0, 2.0, 3.0, 4.0, 5.0];
        let shape = vec![5];
        let quantized = QuantizedGradient::from_dense(&values, shape);
        let dequantized = quantized.to_dense();
        for (i, (orig, deq)) in values.iter().zip(&dequantized).enumerate() {
            let error = (orig - deq).abs();
            assert!(
                error < 0.02,
                "Value {} mismatch: orig={}, deq={}, error={}",
                i,
                orig,
                deq,
                error
            );
        }
    }

    #[test]
    fn test_gradient_delta() {
        let base_cid = Cid::default();
        let mut delta = GradientDelta::new(base_cid);
        delta.add_dense_gradient("layer1".to_string(), vec![1.0, 2.0, 3.0], vec![3]);
        delta.add_dense_gradient("layer2".to_string(), vec![4.0, 5.0], vec![2]);
        assert_eq!(delta.layer_gradients.len(), 2);
        assert!(delta.verify_checksum().is_ok());
    }

    #[test]
    fn test_top_k_compression() {
        let values = vec![1.0, 5.0, 2.0, 8.0, 3.0];
        let shape = vec![5];
        let sparse = GradientCompressor::top_k(&values, shape, 2).unwrap();
        assert_eq!(sparse.nnz(), 2);
        assert!(sparse.values.contains(&8.0));
        assert!(sparse.values.contains(&5.0));
    }

    #[test]
    fn test_threshold_compression() {
        let values = vec![0.1, 5.0, 0.2, 8.0, 0.3];
        let shape = vec![5];
        let sparse = GradientCompressor::threshold(&values, shape, 1.0);
        assert_eq!(sparse.nnz(), 2);
        assert!(sparse.values.contains(&5.0));
        assert!(sparse.values.contains(&8.0));
    }

    #[test]
    fn test_gradient_averaging() {
        let g1 = vec![1.0, 2.0, 3.0];
        let g2 = vec![3.0, 4.0, 5.0];
        let gradients = vec![g1, g2];
        let avg = GradientAggregator::average(&gradients).unwrap();
        assert_eq!(avg, vec![2.0, 3.0, 4.0]);
    }

    #[test]
    fn test_weighted_averaging() {
        let g1 = vec![1.0, 2.0, 3.0];
        let g2 = vec![3.0, 4.0, 5.0];
        let gradients = vec![g1, g2];
        let weights = vec![0.25, 0.75];
        let avg = GradientAggregator::weighted_average(&gradients, &weights).unwrap();
        assert!((avg[0] - 2.5).abs() < 0.01);
        assert!((avg[1] - 3.5).abs() < 0.01);
        assert!((avg[2] - 4.5).abs() < 0.01);
    }

    #[test]
    fn test_momentum() {
        let current = vec![1.0, 2.0, 3.0];
        let previous = vec![0.5, 1.0, 1.5];
        // Fix: the first argument was mojibake "¤t" (an HTML-entity-mangled
        // "&current"), which did not compile.
        let result = GradientAggregator::apply_momentum(&current, &previous, 0.9).unwrap();
        assert!((result[0] - 1.45).abs() < 0.01);
        assert!((result[1] - 2.9).abs() < 0.01);
        assert!((result[2] - 4.35).abs() < 0.01);
    }

    #[test]
    fn test_gradient_verification() {
        let gradient = vec![1.0, 2.0, 3.0, 4.0];
        assert!(GradientVerifier::verify_shape(&gradient, &[4]).is_ok());
        assert!(GradientVerifier::verify_shape(&gradient, &[2, 2]).is_ok());
        assert!(GradientVerifier::verify_shape(&gradient, &[5]).is_err());
        assert!(GradientVerifier::verify_finite(&gradient).is_ok());
        let invalid = vec![1.0, f32::NAN, 3.0];
        assert!(GradientVerifier::verify_finite(&invalid).is_err());
    }

    #[test]
    fn test_gradient_clipping() {
        let mut gradient = vec![3.0, 4.0];
        GradientVerifier::clip_by_norm(&mut gradient, 2.5);
        let norm = GradientVerifier::l2_norm(&gradient);
        assert!((norm - 2.5).abs() < 0.01);
    }

    #[test]
    fn test_privacy_budget() {
        let mut budget = PrivacyBudget::new(1.0, 1e-5);
        assert_eq!(budget.remaining_epsilon, 1.0);
        assert!(!budget.is_exhausted());
        budget.consume(0.5).unwrap();
        assert_eq!(budget.remaining_epsilon, 0.5);
        assert!((budget.remaining_fraction() - 0.5).abs() < 1e-6);
        budget.consume(0.5).unwrap();
        assert!(budget.is_exhausted());
        assert!(budget.consume(0.1).is_err());
    }

    #[test]
    fn test_differential_privacy_gaussian() {
        let mut dp = DifferentialPrivacy::new(1.0, 1e-5, 1.0, DPMechanism::Gaussian);
        let mut gradient = vec![1.0, 2.0, 3.0, 4.0];
        let original = gradient.clone();
        dp.add_gaussian_noise(&mut gradient).unwrap();
        assert_ne!(gradient, original);
        assert!(GradientVerifier::verify_finite(&gradient).is_ok());
        assert!(dp.remaining_budget() < 1.0);
    }

    #[test]
    fn test_differential_privacy_laplacian() {
        let mut dp = DifferentialPrivacy::new(1.0, 1e-5, 1.0, DPMechanism::Laplacian);
        let mut gradient = vec![1.0, 2.0, 3.0, 4.0];
        let original = gradient.clone();
        dp.add_laplacian_noise(&mut gradient).unwrap();
        assert_ne!(gradient, original);
        assert!(GradientVerifier::verify_finite(&gradient).is_ok());
        assert!(dp.remaining_budget() < 1.0);
    }

    #[test]
    fn test_dp_sgd() {
        let mut dp = DifferentialPrivacy::new(1.0, 1e-5, 1.0, DPMechanism::Gaussian);
        let mut gradient = vec![3.0, 4.0, 5.0, 6.0];
        let original_norm = GradientVerifier::l2_norm(&gradient);
        dp.apply_dp_sgd(&mut gradient, 5.0).unwrap();
        let new_norm = GradientVerifier::l2_norm(&gradient);
        assert!(original_norm != new_norm);
        assert!(GradientVerifier::verify_finite(&gradient).is_ok());
    }

    #[test]
    fn test_privacy_budget_exhaustion() {
        let mut dp = DifferentialPrivacy::new(1.0, 1e-5, 1.0, DPMechanism::Gaussian);
        let mut gradient = vec![1.0, 2.0];
        let mut successful_calls = 0;
        for _ in 0..200 {
            if dp.add_gaussian_noise(&mut gradient).is_ok() {
                successful_calls += 1;
            } else {
                break;
            }
        }
        assert!(
            (90..=110).contains(&successful_calls),
            "Expected ~100 calls, got {}",
            successful_calls
        );
        let remaining = dp.remaining_budget();
        assert!(
            remaining < 0.02,
            "Expected nearly exhausted budget, got {}",
            remaining
        );
        let mut new_gradient = vec![1.0, 2.0];
        let result = dp.add_gaussian_noise(&mut new_gradient);
        let _ = result;
    }

    #[test]
    fn test_noise_multiplier_calculation() {
        let epsilon = 1.0;
        let delta = 1e-5;
        let sensitivity = 1.0;
        let multiplier =
            DifferentialPrivacy::calculate_noise_multiplier(epsilon, delta, sensitivity);
        assert!(multiplier > 0.0);
        assert!(multiplier < 10.0);
        let multiplier_high_eps =
            DifferentialPrivacy::calculate_noise_multiplier(10.0, delta, sensitivity);
        assert!(multiplier_high_eps < multiplier);
    }

    #[test]
    fn test_secure_aggregation() {
        let mut aggregator = SecureAggregation::new(3);
        assert_eq!(aggregator.participant_count(), 0);
        assert!(!aggregator.can_aggregate());
        aggregator.add_participant();
        aggregator.add_participant();
        assert!(!aggregator.can_aggregate());
        aggregator.add_participant();
        assert!(aggregator.can_aggregate());
        let g1 = vec![1.0, 2.0, 3.0];
        let g2 = vec![2.0, 3.0, 4.0];
        let g3 = vec![3.0, 4.0, 5.0];
        let gradients = vec![g1, g2, g3];
        let result = aggregator.aggregate_secure(&gradients).unwrap();
        assert!((result[0] - 2.0).abs() < 0.01);
        assert!((result[1] - 3.0).abs() < 0.01);
        assert!((result[2] - 4.0).abs() < 0.01);
        aggregator.reset();
        assert_eq!(aggregator.participant_count(), 0);
    }

    #[test]
    fn test_secure_aggregation_insufficient_participants() {
        let aggregator = SecureAggregation::new(5);
        let g1 = vec![1.0, 2.0];
        let g2 = vec![3.0, 4.0];
        let gradients = vec![g1, g2];
        let result = aggregator.aggregate_secure(&gradients);
        assert!(result.is_err());
    }

    #[test]
    fn test_dp_mechanism_types() {
        let gaussian = DPMechanism::Gaussian;
        let laplacian = DPMechanism::Laplacian;
        assert_eq!(gaussian, DPMechanism::Gaussian);
        assert_eq!(laplacian, DPMechanism::Laplacian);
        assert_ne!(gaussian, laplacian);
    }

    #[test]
    fn test_client_info() {
        let mut client = ClientInfo::new("client1".to_string(), 1000);
        assert_eq!(client.client_id, "client1");
        assert_eq!(client.state, ClientState::Idle);
        assert_eq!(client.sample_count, 1000);
        client.start_training();
        assert_eq!(client.state, ClientState::Training);
        client.complete_training();
        assert_eq!(client.state, ClientState::Completed);
        client.mark_failed();
        assert_eq!(client.state, ClientState::Failed);
    }

    #[test]
    fn test_federated_round() {
        let model_cid = Cid::default();
        let mut round = FederatedRound::new(0, model_cid, 5);
        assert_eq!(round.round_num, 0);
        assert_eq!(round.client_count, 5);
        assert_eq!(round.completed_count, 0);
        assert!(!round.is_complete());
        for _ in 0..5 {
            round.mark_client_completed();
        }
        assert_eq!(round.completed_count, 5);
        assert!(round.is_complete());
        let gradient = vec![1.0, 2.0, 3.0];
        round.complete(gradient.clone());
        assert_eq!(round.aggregated_gradient, Some(gradient));
        assert!(round.end_time.is_some());
        assert!(round.duration().is_some());
    }

    #[test]
    fn test_convergence_detector() {
        let mut detector = ConvergenceDetector::new(3, 0.01);
        detector.add_loss(1.0);
        detector.add_loss(0.99);
        detector.add_loss(0.98);
        assert!(detector.has_converged());
        assert_eq!(detector.latest_loss(), Some(0.98));
        assert_eq!(detector.history().len(), 3);
        detector.reset();
        assert_eq!(detector.history().len(), 0);
    }

    #[test]
    fn test_convergence_detector_not_converged() {
        let mut detector = ConvergenceDetector::new(3, 0.01);
        detector.add_loss(1.0);
        detector.add_loss(0.5);
        detector.add_loss(1.5);
        assert!(!detector.has_converged());
    }

    #[test]
    fn test_model_sync_protocol() {
        let mut protocol = ModelSyncProtocol::new(10, 3, 3, 0.01);
        assert_eq!(protocol.current_round(), 0);
        assert_eq!(protocol.max_rounds(), 10);
        assert!(protocol.should_continue());
        let model_cid = Cid::default();
        let round_num = protocol.start_round(model_cid, 5).unwrap();
        assert_eq!(round_num, 0);
        assert_eq!(protocol.current_round(), 1);
        assert_eq!(protocol.total_rounds(), 1);
        let gradient = vec![1.0, 2.0, 3.0];
        protocol
            .complete_round(round_num, gradient.clone(), 1.0)
            .unwrap();
        assert_eq!(protocol.latest_loss(), Some(1.0));
        let round = protocol.get_round(0).unwrap();
        assert_eq!(round.round_num, 0);
        assert_eq!(round.aggregated_gradient, Some(gradient));
    }

    #[test]
    fn test_model_sync_protocol_convergence() {
        let mut protocol = ModelSyncProtocol::new(10, 2, 3, 0.01);
        let model_cid = Cid::default();
        for i in 0..3 {
            protocol.start_round(model_cid, 3).unwrap();
            let gradient = vec![1.0, 2.0];
            let loss = 1.0 - (i as f64 * 0.001);
            protocol.complete_round(i, gradient, loss).unwrap();
        }
        assert!(protocol.has_converged());
        assert!(!protocol.should_continue());
    }

    #[test]
    fn test_model_sync_protocol_max_rounds() {
        let mut protocol = ModelSyncProtocol::new(2, 1, 3, 0.01);
        let model_cid = Cid::default();
        protocol.start_round(model_cid, 2).unwrap();
        protocol.start_round(model_cid, 2).unwrap();
        let result = protocol.start_round(model_cid, 2);
        assert!(result.is_err());
    }

    #[test]
    fn test_model_sync_protocol_min_clients() {
        let mut protocol = ModelSyncProtocol::new(10, 5, 3, 0.01);
        let model_cid = Cid::default();
        let result = protocol.start_round(model_cid, 3);
        assert!(result.is_err());
        let result = protocol.start_round(model_cid, 5);
        assert!(result.is_ok());
    }

    #[test]
    fn test_client_state_enum() {
        let idle = ClientState::Idle;
        let training = ClientState::Training;
        let completed = ClientState::Completed;
        let failed = ClientState::Failed;
        assert_ne!(idle, training);
        assert_ne!(training, completed);
        assert_ne!(completed, failed);
        assert_eq!(idle, ClientState::Idle);
    }
}