ipfrs_tensorlogic/
gradient.rs

1//! Gradient storage and management for federated learning
2//!
3//! This module provides:
4//! - Gradient delta format (differences from base model)
5//! - Gradient compression (sparsification, quantization, top-k)
6//! - Gradient aggregation (averaging, weighted, momentum)
7//! - Gradient verification (checksum, shape, outliers)
8
9use crate::arrow::{TensorDtype, TensorMetadata};
10use ipfrs_core::Cid;
11use serde::{Deserialize, Serialize};
12use std::collections::HashMap;
13use thiserror::Error;
14
15/// Errors that can occur during gradient operations
16#[derive(Debug, Error)]
17pub enum GradientError {
18    #[error("Shape mismatch: expected {expected:?}, got {actual:?}")]
19    ShapeMismatch {
20        expected: Vec<usize>,
21        actual: Vec<usize>,
22    },
23
24    #[error("Checksum verification failed")]
25    ChecksumFailed,
26
27    #[error("Invalid compression ratio: {0}")]
28    InvalidCompressionRatio(f32),
29
30    #[error("Empty gradient set")]
31    EmptyGradientSet,
32
33    #[error("Incompatible dtype: {0:?}")]
34    IncompatibleDtype(TensorDtype),
35
36    #[error("Outlier detected at index {index}: value {value}")]
37    OutlierDetected { index: usize, value: f32 },
38
39    #[error("Invalid gradient: {0}")]
40    InvalidGradient(String),
41}
42
43/// Sparse gradient representation
44#[derive(Debug, Clone, Serialize, Deserialize)]
45pub struct SparseGradient {
46    /// Indices of non-zero elements (flattened)
47    pub indices: Vec<usize>,
48    /// Non-zero gradient values
49    pub values: Vec<f32>,
50    /// Original tensor shape
51    pub shape: Vec<usize>,
52    /// Metadata
53    pub metadata: TensorMetadata,
54}
55
56impl SparseGradient {
57    /// Create a new sparse gradient
58    pub fn new(indices: Vec<usize>, values: Vec<f32>, shape: Vec<usize>) -> Self {
59        let metadata = TensorMetadata {
60            name: "sparse_gradient".to_string(),
61            shape: shape.clone(),
62            dtype: TensorDtype::Float32,
63            strides: None,
64            custom: HashMap::new(),
65        };
66
67        Self {
68            indices,
69            values,
70            shape,
71            metadata,
72        }
73    }
74
75    /// Get the number of non-zero elements
76    pub fn nnz(&self) -> usize {
77        self.indices.len()
78    }
79
80    /// Get the total number of elements
81    pub fn total_elements(&self) -> usize {
82        self.shape.iter().product()
83    }
84
85    /// Get the sparsity ratio (0.0 = dense, 1.0 = all zeros)
86    pub fn sparsity_ratio(&self) -> f32 {
87        1.0 - (self.nnz() as f32 / self.total_elements() as f32)
88    }
89
90    /// Convert to dense representation
91    pub fn to_dense(&self) -> Vec<f32> {
92        let total = self.total_elements();
93        let mut dense = vec![0.0; total];
94
95        for (&idx, &val) in self.indices.iter().zip(&self.values) {
96            if idx < total {
97                dense[idx] = val;
98            }
99        }
100
101        dense
102    }
103
104    /// Verify shape consistency
105    pub fn verify_shape(&self) -> Result<(), GradientError> {
106        let total = self.total_elements();
107
108        for &idx in &self.indices {
109            if idx >= total {
110                return Err(GradientError::InvalidGradient(format!(
111                    "Index {} out of bounds for shape {:?}",
112                    idx, self.shape
113                )));
114            }
115        }
116
117        if self.indices.len() != self.values.len() {
118            return Err(GradientError::InvalidGradient(format!(
119                "Indices length {} != values length {}",
120                self.indices.len(),
121                self.values.len()
122            )));
123        }
124
125        Ok(())
126    }
127}
128
129/// Quantized gradient (reduced precision)
130#[derive(Debug, Clone, Serialize, Deserialize)]
131pub struct QuantizedGradient {
132    /// Quantized values (e.g., int8)
133    pub quantized_values: Vec<i8>,
134    /// Scale factor for dequantization
135    pub scale: f32,
136    /// Minimum value for dequantization
137    pub min_val: f32,
138    /// Original tensor shape
139    pub shape: Vec<usize>,
140    /// Metadata
141    pub metadata: TensorMetadata,
142}
143
144impl QuantizedGradient {
145    /// Quantize a dense gradient to int8
146    pub fn from_dense(values: &[f32], shape: Vec<usize>) -> Self {
147        let (quantized_values, scale, min_val) = Self::quantize_i8(values);
148
149        let metadata = TensorMetadata {
150            name: "quantized_gradient".to_string(),
151            shape: shape.clone(),
152            dtype: TensorDtype::Int8,
153            strides: None,
154            custom: HashMap::new(),
155        };
156
157        Self {
158            quantized_values,
159            scale,
160            min_val,
161            shape,
162            metadata,
163        }
164    }
165
166    /// Quantize f32 values to i8
167    fn quantize_i8(values: &[f32]) -> (Vec<i8>, f32, f32) {
168        if values.is_empty() {
169            return (Vec::new(), 1.0, 0.0);
170        }
171
172        let min_val = values.iter().copied().fold(f32::INFINITY, f32::min);
173        let max_val = values.iter().copied().fold(f32::NEG_INFINITY, f32::max);
174
175        // Avoid division by zero
176        let scale = if (max_val - min_val).abs() < 1e-8 {
177            1.0
178        } else {
179            (max_val - min_val) / 255.0
180        };
181
182        let quantized = values
183            .iter()
184            .map(|&v| {
185                // Map [min_val, max_val] to [0, 255], then shift to [-128, 127]
186                let normalized = (v - min_val) / scale;
187                (normalized - 128.0).round().clamp(-128.0, 127.0) as i8
188            })
189            .collect();
190
191        (quantized, scale, min_val)
192    }
193
194    /// Dequantize to f32 values
195    pub fn to_dense(&self) -> Vec<f32> {
196        self.quantized_values
197            .iter()
198            .map(|&q| {
199                // Shift from [-128, 127] to [0, 255], then scale back
200                let normalized = (q as f32) + 128.0;
201                normalized * self.scale + self.min_val
202            })
203            .collect()
204    }
205
206    /// Get compression ratio
207    pub fn compression_ratio(&self) -> f32 {
208        // f32 = 4 bytes, i8 = 1 byte, plus scale and min_val
209        let original_size = self.quantized_values.len() * 4;
210        let compressed_size = self.quantized_values.len() + 8; // 4 bytes scale + 4 bytes min_val
211        original_size as f32 / compressed_size as f32
212    }
213}
214
215/// Gradient delta (difference from base model)
216#[derive(Debug, Clone, Serialize, Deserialize)]
217pub struct GradientDelta {
218    /// Base model CID
219    #[serde(serialize_with = "crate::serialize_cid")]
220    #[serde(deserialize_with = "crate::deserialize_cid")]
221    pub base_model: Cid,
222    /// Layer name to gradient mapping
223    pub layer_gradients: HashMap<String, LayerGradient>,
224    /// Checksum for verification
225    pub checksum: u64,
226    /// Timestamp
227    pub timestamp: i64,
228}
229
230/// Gradient for a single layer
231#[derive(Debug, Clone, Serialize, Deserialize)]
232pub enum LayerGradient {
233    /// Dense gradient
234    Dense { values: Vec<f32>, shape: Vec<usize> },
235    /// Sparse gradient
236    Sparse(SparseGradient),
237    /// Quantized gradient
238    Quantized(QuantizedGradient),
239}
240
241impl LayerGradient {
242    /// Get the shape of the gradient
243    pub fn shape(&self) -> &[usize] {
244        match self {
245            LayerGradient::Dense { shape, .. } => shape,
246            LayerGradient::Sparse(sg) => &sg.shape,
247            LayerGradient::Quantized(qg) => &qg.shape,
248        }
249    }
250
251    /// Convert to dense representation
252    pub fn to_dense(&self) -> Vec<f32> {
253        match self {
254            LayerGradient::Dense { values, .. } => values.clone(),
255            LayerGradient::Sparse(sg) => sg.to_dense(),
256            LayerGradient::Quantized(qg) => qg.to_dense(),
257        }
258    }
259
260    /// Get memory size in bytes
261    pub fn memory_size(&self) -> usize {
262        match self {
263            LayerGradient::Dense { values, .. } => values.len() * 4,
264            LayerGradient::Sparse(sg) => sg.indices.len() * 4 + sg.values.len() * 4,
265            LayerGradient::Quantized(qg) => qg.quantized_values.len() + 8,
266        }
267    }
268}
269
270impl GradientDelta {
271    /// Create a new gradient delta
272    pub fn new(base_model: Cid) -> Self {
273        Self {
274            base_model,
275            layer_gradients: HashMap::new(),
276            checksum: 0,
277            timestamp: chrono::Utc::now().timestamp(),
278        }
279    }
280
281    /// Add a dense gradient for a layer
282    pub fn add_dense_gradient(&mut self, layer_name: String, values: Vec<f32>, shape: Vec<usize>) {
283        self.layer_gradients
284            .insert(layer_name, LayerGradient::Dense { values, shape });
285        self.update_checksum();
286    }
287
288    /// Add a sparse gradient for a layer
289    pub fn add_sparse_gradient(&mut self, layer_name: String, gradient: SparseGradient) {
290        self.layer_gradients
291            .insert(layer_name, LayerGradient::Sparse(gradient));
292        self.update_checksum();
293    }
294
295    /// Add a quantized gradient for a layer
296    pub fn add_quantized_gradient(&mut self, layer_name: String, gradient: QuantizedGradient) {
297        self.layer_gradients
298            .insert(layer_name, LayerGradient::Quantized(gradient));
299        self.update_checksum();
300    }
301
302    /// Compute checksum for verification
303    fn update_checksum(&mut self) {
304        use std::collections::hash_map::DefaultHasher;
305        use std::hash::{Hash, Hasher};
306
307        let mut hasher = DefaultHasher::new();
308
309        // Hash layer count
310        self.layer_gradients.len().hash(&mut hasher);
311
312        // Hash each layer's data
313        let mut sorted_layers: Vec<_> = self.layer_gradients.iter().collect();
314        sorted_layers.sort_by_key(|(name, _)| *name);
315
316        for (name, gradient) in sorted_layers {
317            name.hash(&mut hasher);
318            gradient.shape().hash(&mut hasher);
319
320            // Hash a sample of values for efficiency
321            let dense = gradient.to_dense();
322            let sample_size = dense.len().min(100);
323            for &v in dense.iter().take(sample_size) {
324                v.to_bits().hash(&mut hasher);
325            }
326        }
327
328        self.checksum = hasher.finish();
329    }
330
331    /// Verify checksum
332    pub fn verify_checksum(&self) -> Result<(), GradientError> {
333        let mut temp = self.clone();
334        temp.update_checksum();
335
336        if temp.checksum == self.checksum {
337            Ok(())
338        } else {
339            Err(GradientError::ChecksumFailed)
340        }
341    }
342
343    /// Get total memory size in bytes
344    pub fn total_memory_size(&self) -> usize {
345        self.layer_gradients.values().map(|g| g.memory_size()).sum()
346    }
347}
348
349/// Gradient compression utilities
350pub struct GradientCompressor;
351
352impl GradientCompressor {
353    /// Compress gradient using top-k sparsification
354    pub fn top_k(
355        values: &[f32],
356        shape: Vec<usize>,
357        k: usize,
358    ) -> Result<SparseGradient, GradientError> {
359        if k == 0 || k > values.len() {
360            return Err(GradientError::InvalidCompressionRatio(
361                k as f32 / values.len() as f32,
362            ));
363        }
364
365        // Get indices of top-k absolute values
366        let mut indexed_values: Vec<(usize, f32)> = values
367            .iter()
368            .enumerate()
369            .map(|(i, &v)| (i, v.abs()))
370            .collect();
371
372        indexed_values.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
373        indexed_values.truncate(k);
374
375        let mut indices = Vec::with_capacity(k);
376        let mut sparse_values = Vec::with_capacity(k);
377
378        for (idx, _) in indexed_values {
379            indices.push(idx);
380            sparse_values.push(values[idx]);
381        }
382
383        Ok(SparseGradient::new(indices, sparse_values, shape))
384    }
385
386    /// Compress gradient using threshold-based sparsification
387    pub fn threshold(values: &[f32], shape: Vec<usize>, threshold: f32) -> SparseGradient {
388        let mut indices = Vec::new();
389        let mut sparse_values = Vec::new();
390
391        for (i, &v) in values.iter().enumerate() {
392            if v.abs() >= threshold {
393                indices.push(i);
394                sparse_values.push(v);
395            }
396        }
397
398        SparseGradient::new(indices, sparse_values, shape)
399    }
400
401    /// Compress gradient using quantization
402    pub fn quantize(values: &[f32], shape: Vec<usize>) -> QuantizedGradient {
403        QuantizedGradient::from_dense(values, shape)
404    }
405
406    /// Compress gradient using random sparsification
407    pub fn random_sparsification(
408        values: &[f32],
409        shape: Vec<usize>,
410        keep_ratio: f32,
411    ) -> Result<SparseGradient, GradientError> {
412        use rand::Rng;
413
414        if keep_ratio <= 0.0 || keep_ratio > 1.0 {
415            return Err(GradientError::InvalidCompressionRatio(keep_ratio));
416        }
417
418        let mut rng = rand::rng();
419        let mut indices = Vec::new();
420        let mut sparse_values = Vec::new();
421
422        for (i, &v) in values.iter().enumerate() {
423            if rng.random::<f32>() < keep_ratio {
424                indices.push(i);
425                sparse_values.push(v / keep_ratio); // Compensate for dropout
426            }
427        }
428
429        Ok(SparseGradient::new(indices, sparse_values, shape))
430    }
431}
432
433/// Gradient aggregation for federated learning
434pub struct GradientAggregator;
435
436impl GradientAggregator {
437    /// Average multiple gradients (unweighted)
438    pub fn average(gradients: &[Vec<f32>]) -> Result<Vec<f32>, GradientError> {
439        if gradients.is_empty() {
440            return Err(GradientError::EmptyGradientSet);
441        }
442
443        let len = gradients[0].len();
444
445        // Verify all gradients have the same length
446        for g in gradients.iter() {
447            if g.len() != len {
448                return Err(GradientError::ShapeMismatch {
449                    expected: vec![len],
450                    actual: vec![g.len()],
451                });
452            }
453        }
454
455        let mut result = vec![0.0; len];
456        let count = gradients.len() as f32;
457
458        for gradient in gradients {
459            for (i, &v) in gradient.iter().enumerate() {
460                result[i] += v / count;
461            }
462        }
463
464        Ok(result)
465    }
466
467    /// Weighted average of gradients
468    pub fn weighted_average(
469        gradients: &[Vec<f32>],
470        weights: &[f32],
471    ) -> Result<Vec<f32>, GradientError> {
472        if gradients.is_empty() {
473            return Err(GradientError::EmptyGradientSet);
474        }
475
476        if gradients.len() != weights.len() {
477            return Err(GradientError::InvalidGradient(format!(
478                "Gradient count {} != weight count {}",
479                gradients.len(),
480                weights.len()
481            )));
482        }
483
484        let len = gradients[0].len();
485
486        // Verify all gradients have the same length
487        for g in gradients.iter() {
488            if g.len() != len {
489                return Err(GradientError::ShapeMismatch {
490                    expected: vec![len],
491                    actual: vec![g.len()],
492                });
493            }
494        }
495
496        let weight_sum: f32 = weights.iter().sum();
497        if weight_sum == 0.0 {
498            return Err(GradientError::InvalidGradient(
499                "Sum of weights is zero".to_string(),
500            ));
501        }
502
503        let mut result = vec![0.0; len];
504
505        for (gradient, &weight) in gradients.iter().zip(weights) {
506            let normalized_weight = weight / weight_sum;
507            for (i, &v) in gradient.iter().enumerate() {
508                result[i] += v * normalized_weight;
509            }
510        }
511
512        Ok(result)
513    }
514
515    /// Apply momentum to gradient
516    pub fn apply_momentum(
517        current_gradient: &[f32],
518        previous_momentum: &[f32],
519        momentum_factor: f32,
520    ) -> Result<Vec<f32>, GradientError> {
521        if current_gradient.len() != previous_momentum.len() {
522            return Err(GradientError::ShapeMismatch {
523                expected: vec![previous_momentum.len()],
524                actual: vec![current_gradient.len()],
525            });
526        }
527
528        let result = current_gradient
529            .iter()
530            .zip(previous_momentum)
531            .map(|(&g, &m)| momentum_factor * m + g)
532            .collect();
533
534        Ok(result)
535    }
536}
537
538/// Gradient verification utilities
539pub struct GradientVerifier;
540
541impl GradientVerifier {
542    /// Verify gradient shape matches expected shape
543    pub fn verify_shape(gradient: &[f32], expected_shape: &[usize]) -> Result<(), GradientError> {
544        let expected_size: usize = expected_shape.iter().product();
545
546        if gradient.len() != expected_size {
547            return Err(GradientError::ShapeMismatch {
548                expected: expected_shape.to_vec(),
549                actual: vec![gradient.len()],
550            });
551        }
552
553        Ok(())
554    }
555
556    /// Detect outliers in gradient (values beyond threshold standard deviations)
557    pub fn detect_outliers(gradient: &[f32], std_threshold: f32) -> Result<(), GradientError> {
558        if gradient.is_empty() {
559            return Ok(());
560        }
561
562        // Calculate mean
563        let mean = gradient.iter().sum::<f32>() / gradient.len() as f32;
564
565        // Calculate standard deviation
566        let variance =
567            gradient.iter().map(|&v| (v - mean).powi(2)).sum::<f32>() / gradient.len() as f32;
568        let std_dev = variance.sqrt();
569
570        // Check for outliers
571        for (i, &v) in gradient.iter().enumerate() {
572            let z_score = (v - mean).abs() / std_dev;
573            if z_score > std_threshold {
574                return Err(GradientError::OutlierDetected { index: i, value: v });
575            }
576        }
577
578        Ok(())
579    }
580
581    /// Verify gradient is not NaN or Inf
582    pub fn verify_finite(gradient: &[f32]) -> Result<(), GradientError> {
583        for (i, &v) in gradient.iter().enumerate() {
584            if !v.is_finite() {
585                return Err(GradientError::InvalidGradient(format!(
586                    "Non-finite value at index {}: {}",
587                    i, v
588                )));
589            }
590        }
591
592        Ok(())
593    }
594
595    /// Compute L2 norm of gradient
596    pub fn l2_norm(gradient: &[f32]) -> f32 {
597        gradient.iter().map(|&v| v * v).sum::<f32>().sqrt()
598    }
599
600    /// Clip gradient by norm
601    pub fn clip_by_norm(gradient: &mut [f32], max_norm: f32) {
602        let norm = Self::l2_norm(gradient);
603
604        if norm > max_norm {
605            let scale = max_norm / norm;
606            for v in gradient.iter_mut() {
607                *v *= scale;
608            }
609        }
610    }
611}
612
613/// Privacy budget for differential privacy
614#[derive(Debug, Clone, Copy)]
615pub struct PrivacyBudget {
616    /// Epsilon (privacy loss parameter)
617    pub epsilon: f64,
618    /// Delta (failure probability)
619    pub delta: f64,
620    /// Remaining epsilon
621    pub remaining_epsilon: f64,
622}
623
624impl PrivacyBudget {
625    /// Create a new privacy budget
626    pub fn new(epsilon: f64, delta: f64) -> Self {
627        Self {
628            epsilon,
629            delta,
630            remaining_epsilon: epsilon,
631        }
632    }
633
634    /// Consume some privacy budget
635    pub fn consume(&mut self, epsilon_used: f64) -> Result<(), GradientError> {
636        if epsilon_used > self.remaining_epsilon {
637            return Err(GradientError::InvalidGradient(format!(
638                "Insufficient privacy budget: need {}, have {}",
639                epsilon_used, self.remaining_epsilon
640            )));
641        }
642
643        self.remaining_epsilon -= epsilon_used;
644        Ok(())
645    }
646
647    /// Check if budget is exhausted
648    pub fn is_exhausted(&self) -> bool {
649        self.remaining_epsilon <= 0.0
650    }
651
652    /// Get the fraction of budget remaining
653    pub fn remaining_fraction(&self) -> f64 {
654        self.remaining_epsilon / self.epsilon
655    }
656}
657
/// Which noise distribution is added to gradients for differential privacy.
#[derive(Clone, Copy, Debug, PartialEq)]
pub enum DPMechanism {
    /// Gaussian noise calibrated from epsilon, delta and sensitivity.
    Gaussian,
    /// Laplace noise calibrated from epsilon and sensitivity.
    Laplacian,
}
666
667/// Differential privacy for gradient protection
668pub struct DifferentialPrivacy {
669    /// Privacy budget
670    budget: PrivacyBudget,
671    /// Sensitivity (L2 norm bound for gradients)
672    sensitivity: f64,
673    /// Mechanism type
674    mechanism: DPMechanism,
675}
676
677impl DifferentialPrivacy {
678    /// Create a new differential privacy instance
679    pub fn new(epsilon: f64, delta: f64, sensitivity: f64, mechanism: DPMechanism) -> Self {
680        Self {
681            budget: PrivacyBudget::new(epsilon, delta),
682            sensitivity,
683            mechanism,
684        }
685    }
686
687    /// Add Gaussian noise to gradient (for DP-SGD)
688    /// Calibrated according to sensitivity and privacy parameters
689    pub fn add_gaussian_noise(&mut self, gradient: &mut [f32]) -> Result<(), GradientError> {
690        use rand::Rng;
691
692        if self.budget.is_exhausted() {
693            return Err(GradientError::InvalidGradient(
694                "Privacy budget exhausted".to_string(),
695            ));
696        }
697
698        // Calculate noise scale using Gaussian mechanism
699        // σ = sensitivity * sqrt(2 * ln(1.25/δ)) / ε
700        let ln_term = (1.25 / self.budget.delta).ln();
701        let sigma = self.sensitivity * (2.0 * ln_term).sqrt() / self.budget.epsilon;
702
703        let mut rng = rand::rng();
704
705        // Add Gaussian noise to each element
706        for v in gradient.iter_mut() {
707            let noise: f64 = rng.random_range(-1.0..1.0);
708            let gaussian_noise = sigma * noise;
709            *v += gaussian_noise as f32;
710        }
711
712        // Consume privacy budget (simplified - in practice, this depends on composition)
713        self.budget.consume(self.budget.epsilon / 100.0)?;
714
715        Ok(())
716    }
717
718    /// Add Laplacian noise to gradient
719    /// Calibrated according to L1 sensitivity and privacy parameters
720    pub fn add_laplacian_noise(&mut self, gradient: &mut [f32]) -> Result<(), GradientError> {
721        use rand::Rng;
722
723        if self.budget.is_exhausted() {
724            return Err(GradientError::InvalidGradient(
725                "Privacy budget exhausted".to_string(),
726            ));
727        }
728
729        // Calculate noise scale using Laplacian mechanism
730        // b = sensitivity / ε
731        let scale = self.sensitivity / self.budget.epsilon;
732
733        let mut rng = rand::rng();
734
735        // Add Laplacian noise to each element
736        for v in gradient.iter_mut() {
737            let u: f64 = rng.random_range(-0.5..0.5);
738            let laplacian_noise = -scale * u.signum() * (1.0 - 2.0 * u.abs()).ln();
739            *v += laplacian_noise as f32;
740        }
741
742        // Consume privacy budget
743        self.budget.consume(self.budget.epsilon / 100.0)?;
744
745        Ok(())
746    }
747
748    /// Apply DP-SGD (Differential Private Stochastic Gradient Descent)
749    /// This clips gradients and adds noise
750    pub fn apply_dp_sgd(
751        &mut self,
752        gradient: &mut [f32],
753        clip_norm: f32,
754    ) -> Result<(), GradientError> {
755        // Step 1: Clip gradient to bound sensitivity
756        GradientVerifier::clip_by_norm(gradient, clip_norm);
757
758        // Step 2: Add calibrated noise
759        match self.mechanism {
760            DPMechanism::Gaussian => self.add_gaussian_noise(gradient)?,
761            DPMechanism::Laplacian => self.add_laplacian_noise(gradient)?,
762        }
763
764        Ok(())
765    }
766
767    /// Get remaining privacy budget
768    pub fn remaining_budget(&self) -> f64 {
769        self.budget.remaining_epsilon
770    }
771
772    /// Check if privacy budget is exhausted
773    pub fn is_budget_exhausted(&self) -> bool {
774        self.budget.is_exhausted()
775    }
776
777    /// Get privacy parameters
778    pub fn get_privacy_params(&self) -> (f64, f64) {
779        (self.budget.epsilon, self.budget.delta)
780    }
781
782    /// Calculate noise multiplier for given privacy parameters
783    /// Used in DP-SGD implementations
784    pub fn calculate_noise_multiplier(epsilon: f64, delta: f64, sensitivity: f64) -> f64 {
785        // σ = sensitivity * sqrt(2 * ln(1.25/δ)) / ε
786        let ln_term = (1.25 / delta).ln();
787        sensitivity * (2.0 * ln_term).sqrt() / epsilon
788    }
789}
790
791/// Secure aggregation for federated learning (simplified)
792pub struct SecureAggregation {
793    /// Minimum number of participants required
794    min_participants: usize,
795    /// Current participant count
796    participant_count: usize,
797}
798
799impl SecureAggregation {
800    /// Create a new secure aggregation instance
801    pub fn new(min_participants: usize) -> Self {
802        Self {
803            min_participants,
804            participant_count: 0,
805        }
806    }
807
808    /// Add a participant
809    pub fn add_participant(&mut self) {
810        self.participant_count += 1;
811    }
812
813    /// Check if we have enough participants
814    pub fn can_aggregate(&self) -> bool {
815        self.participant_count >= self.min_participants
816    }
817
818    /// Aggregate gradients securely
819    /// In a real implementation, this would use cryptographic techniques
820    /// like secret sharing, homomorphic encryption, or secure multi-party computation
821    pub fn aggregate_secure(&self, gradients: &[Vec<f32>]) -> Result<Vec<f32>, GradientError> {
822        if !self.can_aggregate() {
823            return Err(GradientError::InvalidGradient(format!(
824                "Not enough participants: need {}, have {}",
825                self.min_participants, self.participant_count
826            )));
827        }
828
829        // For now, use simple averaging
830        // In production, this would:
831        // 1. Use secret sharing to split gradients
832        // 2. Aggregate encrypted shares
833        // 3. Reconstruct only the sum
834        GradientAggregator::average(gradients)
835    }
836
837    /// Reset participant count
838    pub fn reset(&mut self) {
839        self.participant_count = 0;
840    }
841
842    /// Get participant count
843    pub fn participant_count(&self) -> usize {
844        self.participant_count
845    }
846}
847
/// Lifecycle stage of a federated-learning client.
#[derive(Clone, Debug, Eq, PartialEq)]
pub enum ClientState {
    /// Waiting for work.
    Idle,
    /// Currently running local training.
    Training,
    /// Finished its local training.
    Completed,
    /// Failed or dropped out.
    Failed,
}
860
861/// Client information in federated learning
862#[derive(Debug, Clone)]
863pub struct ClientInfo {
864    /// Client ID
865    pub client_id: String,
866    /// Client state
867    pub state: ClientState,
868    /// Number of samples the client has
869    pub sample_count: usize,
870    /// Last update timestamp
871    pub last_update: i64,
872}
873
874impl ClientInfo {
875    /// Create a new client info
876    pub fn new(client_id: String, sample_count: usize) -> Self {
877        Self {
878            client_id,
879            state: ClientState::Idle,
880            sample_count,
881            last_update: chrono::Utc::now().timestamp(),
882        }
883    }
884
885    /// Mark client as training
886    pub fn start_training(&mut self) {
887        self.state = ClientState::Training;
888        self.last_update = chrono::Utc::now().timestamp();
889    }
890
891    /// Mark client as completed
892    pub fn complete_training(&mut self) {
893        self.state = ClientState::Completed;
894        self.last_update = chrono::Utc::now().timestamp();
895    }
896
897    /// Mark client as failed
898    pub fn mark_failed(&mut self) {
899        self.state = ClientState::Failed;
900        self.last_update = chrono::Utc::now().timestamp();
901    }
902}
903
904/// Federated learning round
905#[derive(Debug, Clone, Serialize, Deserialize)]
906pub struct FederatedRound {
907    /// Round number
908    pub round_num: usize,
909    /// Clients participating in this round (stored as count for serialization)
910    pub client_count: usize,
911    /// Global model CID for this round
912    #[serde(serialize_with = "crate::serialize_cid")]
913    #[serde(deserialize_with = "crate::deserialize_cid")]
914    pub global_model: Cid,
915    /// Aggregated gradient for this round (if computed)
916    pub aggregated_gradient: Option<Vec<f32>>,
917    /// Round start timestamp
918    pub start_time: i64,
919    /// Round end timestamp (if completed)
920    pub end_time: Option<i64>,
921    /// Completed client count
922    pub completed_count: usize,
923}
924
925impl FederatedRound {
926    /// Create a new federated round
927    pub fn new(round_num: usize, global_model: Cid, client_count: usize) -> Self {
928        Self {
929            round_num,
930            client_count,
931            global_model,
932            aggregated_gradient: None,
933            start_time: chrono::Utc::now().timestamp(),
934            end_time: None,
935            completed_count: 0,
936        }
937    }
938
939    /// Mark a client as completed
940    pub fn mark_client_completed(&mut self) {
941        self.completed_count += 1;
942    }
943
944    /// Check if round is complete
945    pub fn is_complete(&self) -> bool {
946        self.completed_count >= self.client_count
947    }
948
949    /// Complete the round
950    pub fn complete(&mut self, aggregated_gradient: Vec<f32>) {
951        self.aggregated_gradient = Some(aggregated_gradient);
952        self.end_time = Some(chrono::Utc::now().timestamp());
953    }
954
955    /// Get round duration in seconds
956    pub fn duration(&self) -> Option<i64> {
957        self.end_time.map(|end| end - self.start_time)
958    }
959}
960
/// Convergence detection for federated learning.
///
/// Keeps a sliding window of the most recent loss values. Training is reported as converged
/// once the window is full and the losses' standard deviation is small relative to their mean.
pub struct ConvergenceDetector {
    /// Window size for convergence detection
    window_size: usize,
    /// Recent loss values (at most `window_size` of them)
    loss_history: Vec<f64>,
    /// Convergence threshold (relative change)
    threshold: f64,
}

impl ConvergenceDetector {
    /// Create a detector over the last `window_size` losses with the given threshold.
    pub fn new(window_size: usize, threshold: f64) -> Self {
        Self {
            window_size,
            loss_history: Vec::new(),
            threshold,
        }
    }

    /// Append a loss, discarding the oldest one once the window is over-full.
    pub fn add_loss(&mut self, loss: f64) {
        self.loss_history.push(loss);

        if self.loss_history.len() > self.window_size {
            self.loss_history.remove(0);
        }
    }

    /// Whether training has converged.
    ///
    /// Returns `false` until the window is full. After that it compares
    /// `std_dev / |mean|` against `threshold`. When the mean is essentially zero that ratio is
    /// undefined, so the absolute `std_dev` is compared instead. This stops losses oscillating
    /// around zero (e.g. `[-1, 1, -1, 1]`) from being reported as converged, while a window of
    /// all-zero losses still is.
    pub fn has_converged(&self) -> bool {
        if self.loss_history.len() < self.window_size {
            return false;
        }

        let recent = &self.loss_history[self.loss_history.len() - self.window_size..];
        let n = recent.len() as f64;
        let mean = recent.iter().sum::<f64>() / n;
        let std_dev = (recent.iter().map(|&x| (x - mean).powi(2)).sum::<f64>() / n).sqrt();

        if mean.abs() < 1e-10 {
            return std_dev < self.threshold;
        }

        std_dev / mean.abs() < self.threshold
    }

    /// Most recently added loss, if any.
    pub fn latest_loss(&self) -> Option<f64> {
        self.loss_history.last().copied()
    }

    /// Clear the loss history.
    pub fn reset(&mut self) {
        self.loss_history.clear();
    }

    /// Losses currently in the window, oldest first.
    pub fn history(&self) -> &[f64] {
        &self.loss_history
    }
}
1028
1029/// Model synchronization protocol for federated learning
1030pub struct ModelSyncProtocol {
1031    /// Current round number
1032    current_round: usize,
1033    /// Maximum number of rounds
1034    max_rounds: usize,
1035    /// Minimum number of clients per round
1036    min_clients_per_round: usize,
1037    /// Round history
1038    rounds: Vec<FederatedRound>,
1039    /// Convergence detector
1040    convergence: ConvergenceDetector,
1041}
1042
1043impl ModelSyncProtocol {
1044    /// Create a new model synchronization protocol
1045    pub fn new(
1046        max_rounds: usize,
1047        min_clients_per_round: usize,
1048        convergence_window: usize,
1049        convergence_threshold: f64,
1050    ) -> Self {
1051        Self {
1052            current_round: 0,
1053            max_rounds,
1054            min_clients_per_round,
1055            rounds: Vec::new(),
1056            convergence: ConvergenceDetector::new(convergence_window, convergence_threshold),
1057        }
1058    }
1059
1060    /// Start a new round
1061    pub fn start_round(
1062        &mut self,
1063        global_model: Cid,
1064        client_count: usize,
1065    ) -> Result<usize, GradientError> {
1066        if client_count < self.min_clients_per_round {
1067            return Err(GradientError::InvalidGradient(format!(
1068                "Not enough clients: need {}, got {}",
1069                self.min_clients_per_round, client_count
1070            )));
1071        }
1072
1073        if self.current_round >= self.max_rounds {
1074            return Err(GradientError::InvalidGradient(format!(
1075                "Maximum rounds reached: {}",
1076                self.max_rounds
1077            )));
1078        }
1079
1080        let round = FederatedRound::new(self.current_round, global_model, client_count);
1081        self.rounds.push(round);
1082        self.current_round += 1;
1083
1084        Ok(self.current_round - 1)
1085    }
1086
1087    /// Complete the current round
1088    pub fn complete_round(
1089        &mut self,
1090        round_num: usize,
1091        aggregated_gradient: Vec<f32>,
1092        loss: f64,
1093    ) -> Result<(), GradientError> {
1094        if round_num >= self.rounds.len() {
1095            return Err(GradientError::InvalidGradient(format!(
1096                "Invalid round number: {}",
1097                round_num
1098            )));
1099        }
1100
1101        self.rounds[round_num].complete(aggregated_gradient);
1102        self.convergence.add_loss(loss);
1103
1104        Ok(())
1105    }
1106
1107    /// Check if training should continue
1108    pub fn should_continue(&self) -> bool {
1109        self.current_round < self.max_rounds && !self.convergence.has_converged()
1110    }
1111
1112    /// Check if training has converged
1113    pub fn has_converged(&self) -> bool {
1114        self.convergence.has_converged()
1115    }
1116
1117    /// Get the current round number
1118    pub fn current_round(&self) -> usize {
1119        self.current_round
1120    }
1121
1122    /// Get the total number of rounds
1123    pub fn total_rounds(&self) -> usize {
1124        self.rounds.len()
1125    }
1126
1127    /// Get round information
1128    pub fn get_round(&self, round_num: usize) -> Option<&FederatedRound> {
1129        self.rounds.get(round_num)
1130    }
1131
1132    /// Get the latest loss
1133    pub fn latest_loss(&self) -> Option<f64> {
1134        self.convergence.latest_loss()
1135    }
1136
1137    /// Get max rounds
1138    pub fn max_rounds(&self) -> usize {
1139        self.max_rounds
1140    }
1141}
1142
// Unit tests for the gradient, compression, aggregation, verification,
// differential-privacy and federated-learning types defined in this module.
#[cfg(test)]
mod tests {
    use super::*;

    // SparseGradient: nnz/total_elements/sparsity_ratio, and `to_dense` places
    // each stored value back at its flattened index.
    #[test]
    fn test_sparse_gradient() {
        let indices = vec![0, 5, 10];
        let values = vec![1.0, 2.0, 3.0];
        let shape = vec![20];

        let sparse = SparseGradient::new(indices.clone(), values.clone(), shape);

        assert_eq!(sparse.nnz(), 3);
        assert_eq!(sparse.total_elements(), 20);
        assert!((sparse.sparsity_ratio() - 0.85).abs() < 0.01);

        let dense = sparse.to_dense();
        assert_eq!(dense.len(), 20);
        assert_eq!(dense[0], 1.0);
        assert_eq!(dense[5], 2.0);
        assert_eq!(dense[10], 3.0);
    }

    // QuantizedGradient: quantize/dequantize round trip stays within a small
    // absolute error for a narrow value range.
    #[test]
    fn test_quantized_gradient() {
        let values = vec![1.0, 2.0, 3.0, 4.0, 5.0];
        let shape = vec![5];

        let quantized = QuantizedGradient::from_dense(&values, shape);
        let dequantized = quantized.to_dense();

        // Check that dequantization is approximately correct
        // For a small range like [1,5] with 256 quantization levels,
        // we expect good precision
        for (i, (orig, deq)) in values.iter().zip(&dequantized).enumerate() {
            let error = (orig - deq).abs();
            // Allow for quantization error (scale = 4/255 ≈ 0.0157)
            assert!(
                error < 0.02,
                "Value {} mismatch: orig={}, deq={}, error={}",
                i,
                orig,
                deq,
                error
            );
        }
    }

    // GradientDelta: layers are added per name and the checksum verifies.
    #[test]
    fn test_gradient_delta() {
        let base_cid = Cid::default();
        let mut delta = GradientDelta::new(base_cid);

        delta.add_dense_gradient("layer1".to_string(), vec![1.0, 2.0, 3.0], vec![3]);
        delta.add_dense_gradient("layer2".to_string(), vec![4.0, 5.0], vec![2]);

        assert_eq!(delta.layer_gradients.len(), 2);
        assert!(delta.verify_checksum().is_ok());
    }

    // Top-k compression keeps exactly the k largest values.
    #[test]
    fn test_top_k_compression() {
        let values = vec![1.0, 5.0, 2.0, 8.0, 3.0];
        let shape = vec![5];

        let sparse = GradientCompressor::top_k(&values, shape, 2).unwrap();

        assert_eq!(sparse.nnz(), 2);
        assert!(sparse.values.contains(&8.0));
        assert!(sparse.values.contains(&5.0));
    }

    // Threshold compression drops values below the threshold.
    #[test]
    fn test_threshold_compression() {
        let values = vec![0.1, 5.0, 0.2, 8.0, 0.3];
        let shape = vec![5];

        let sparse = GradientCompressor::threshold(&values, shape, 1.0);

        assert_eq!(sparse.nnz(), 2);
        assert!(sparse.values.contains(&5.0));
        assert!(sparse.values.contains(&8.0));
    }

    // Plain averaging is element-wise.
    #[test]
    fn test_gradient_averaging() {
        let g1 = vec![1.0, 2.0, 3.0];
        let g2 = vec![3.0, 4.0, 5.0];
        let gradients = vec![g1, g2];

        let avg = GradientAggregator::average(&gradients).unwrap();

        assert_eq!(avg, vec![2.0, 3.0, 4.0]);
    }

    // Weighted averaging applies per-gradient weights element-wise.
    #[test]
    fn test_weighted_averaging() {
        let g1 = vec![1.0, 2.0, 3.0];
        let g2 = vec![3.0, 4.0, 5.0];
        let gradients = vec![g1, g2];
        let weights = vec![0.25, 0.75];

        let avg = GradientAggregator::weighted_average(&gradients, &weights).unwrap();

        // Expected: 0.25 * [1,2,3] + 0.75 * [3,4,5] = [2.5, 3.5, 4.5]
        assert!((avg[0] - 2.5).abs() < 0.01);
        assert!((avg[1] - 3.5).abs() < 0.01);
        assert!((avg[2] - 4.5).abs() < 0.01);
    }

    // Momentum combines the previous gradient (scaled) with the current one.
    #[test]
    fn test_momentum() {
        let current = vec![1.0, 2.0, 3.0];
        let previous = vec![0.5, 1.0, 1.5];

        let result = GradientAggregator::apply_momentum(&current, &previous, 0.9).unwrap();

        // Expected: 0.9 * previous + current
        assert!((result[0] - 1.45).abs() < 0.01);
        assert!((result[1] - 2.9).abs() < 0.01);
        assert!((result[2] - 4.35).abs() < 0.01);
    }

    // Shape verification accepts any shape with a matching element count;
    // finiteness verification rejects NaN.
    #[test]
    fn test_gradient_verification() {
        let gradient = vec![1.0, 2.0, 3.0, 4.0];

        // Test shape verification
        assert!(GradientVerifier::verify_shape(&gradient, &[4]).is_ok());
        assert!(GradientVerifier::verify_shape(&gradient, &[2, 2]).is_ok());
        assert!(GradientVerifier::verify_shape(&gradient, &[5]).is_err());

        // Test finite verification
        assert!(GradientVerifier::verify_finite(&gradient).is_ok());

        let invalid = vec![1.0, f32::NAN, 3.0];
        assert!(GradientVerifier::verify_finite(&invalid).is_err());
    }

    // Clipping by norm rescales the gradient to the requested L2 norm.
    #[test]
    fn test_gradient_clipping() {
        let mut gradient = vec![3.0, 4.0]; // L2 norm = 5.0

        GradientVerifier::clip_by_norm(&mut gradient, 2.5);

        let norm = GradientVerifier::l2_norm(&gradient);
        assert!((norm - 2.5).abs() < 0.01);
    }

    // PrivacyBudget: consumption decreases the remaining epsilon and fails
    // once the budget is exhausted.
    #[test]
    fn test_privacy_budget() {
        let mut budget = PrivacyBudget::new(1.0, 1e-5);

        assert_eq!(budget.remaining_epsilon, 1.0);
        assert!(!budget.is_exhausted());

        // Consume some budget
        budget.consume(0.5).unwrap();
        assert_eq!(budget.remaining_epsilon, 0.5);
        assert!((budget.remaining_fraction() - 0.5).abs() < 1e-6);

        // Consume remaining budget
        budget.consume(0.5).unwrap();
        assert!(budget.is_exhausted());

        // Should fail when budget is exhausted
        assert!(budget.consume(0.1).is_err());
    }

    // Gaussian DP noise changes the gradient, keeps it finite and consumes
    // budget. The inequality check is probabilistic (noise is random).
    #[test]
    fn test_differential_privacy_gaussian() {
        let mut dp = DifferentialPrivacy::new(1.0, 1e-5, 1.0, DPMechanism::Gaussian);
        let mut gradient = vec![1.0, 2.0, 3.0, 4.0];
        let original = gradient.clone();

        dp.add_gaussian_noise(&mut gradient).unwrap();

        // Gradient should be modified (with very high probability)
        assert_ne!(gradient, original);

        // Values should still be finite
        assert!(GradientVerifier::verify_finite(&gradient).is_ok());

        // Budget should be consumed
        assert!(dp.remaining_budget() < 1.0);
    }

    // Laplacian DP noise: same expectations as the Gaussian test above.
    #[test]
    fn test_differential_privacy_laplacian() {
        let mut dp = DifferentialPrivacy::new(1.0, 1e-5, 1.0, DPMechanism::Laplacian);
        let mut gradient = vec![1.0, 2.0, 3.0, 4.0];
        let original = gradient.clone();

        dp.add_laplacian_noise(&mut gradient).unwrap();

        // Gradient should be modified (with very high probability)
        assert_ne!(gradient, original);

        // Values should still be finite
        assert!(GradientVerifier::verify_finite(&gradient).is_ok());

        // Budget should be consumed
        assert!(dp.remaining_budget() < 1.0);
    }

    // DP-SGD (clip then noise) changes the gradient norm and keeps values finite.
    #[test]
    fn test_dp_sgd() {
        let mut dp = DifferentialPrivacy::new(1.0, 1e-5, 1.0, DPMechanism::Gaussian);
        let mut gradient = vec![3.0, 4.0, 5.0, 6.0]; // L2 norm > 5.0
        let original_norm = GradientVerifier::l2_norm(&gradient);

        dp.apply_dp_sgd(&mut gradient, 5.0).unwrap();

        // Gradient should be clipped and noised
        let new_norm = GradientVerifier::l2_norm(&gradient);

        // After clipping and noise, norm might be around 5.0 but not exact due to noise
        // Just check it's different from original
        assert!(original_norm != new_norm);

        // Values should still be finite
        assert!(GradientVerifier::verify_finite(&gradient).is_ok());
    }

    // Repeated noise additions eventually exhaust the privacy budget; the
    // expected call count is checked with a tolerance band.
    #[test]
    fn test_privacy_budget_exhaustion() {
        let mut dp = DifferentialPrivacy::new(1.0, 1e-5, 1.0, DPMechanism::Gaussian);
        let mut gradient = vec![1.0, 2.0];

        // Consume budget multiple times
        // Each call consumes epsilon/100 = 0.01, so we need 100 calls to exhaust budget of 1.0
        let mut successful_calls = 0;
        for _ in 0..200 {
            if dp.add_gaussian_noise(&mut gradient).is_ok() {
                successful_calls += 1;
            } else {
                // Budget exhausted, break
                break;
            }
        }

        // Should have made ~100 successful calls before budget exhaustion
        assert!(
            (90..=110).contains(&successful_calls),
            "Expected ~100 calls, got {}",
            successful_calls
        );

        // Budget should be very low or exhausted (allow small epsilon for floating point errors)
        let remaining = dp.remaining_budget();
        assert!(
            remaining < 0.02,
            "Expected nearly exhausted budget, got {}",
            remaining
        );

        // Should fail when trying to consume more than remaining
        let mut new_gradient = vec![1.0, 2.0];
        let result = dp.add_gaussian_noise(&mut new_gradient);
        // Might succeed if there's a tiny bit of budget left, or fail if exhausted
        // Either way is acceptable at this point
        // NOTE(review): this outcome is intentionally not asserted.
        let _ = result;
    }

    // The noise multiplier is positive, bounded, and shrinks as epsilon grows.
    #[test]
    fn test_noise_multiplier_calculation() {
        let epsilon = 1.0;
        let delta = 1e-5;
        let sensitivity = 1.0;

        let multiplier =
            DifferentialPrivacy::calculate_noise_multiplier(epsilon, delta, sensitivity);

        // Noise multiplier should be positive and reasonable
        assert!(multiplier > 0.0);
        assert!(multiplier < 10.0); // Sanity check

        // For higher epsilon (less privacy), noise should be lower
        let multiplier_high_eps =
            DifferentialPrivacy::calculate_noise_multiplier(10.0, delta, sensitivity);
        assert!(multiplier_high_eps < multiplier);
    }

    // SecureAggregation: aggregation is only allowed once the minimum number of
    // participants has joined; the result matches the plain average; reset
    // clears participants.
    #[test]
    fn test_secure_aggregation() {
        let mut aggregator = SecureAggregation::new(3);

        assert_eq!(aggregator.participant_count(), 0);
        assert!(!aggregator.can_aggregate());

        // Add participants
        aggregator.add_participant();
        aggregator.add_participant();
        assert!(!aggregator.can_aggregate());

        aggregator.add_participant();
        assert!(aggregator.can_aggregate());

        // Test aggregation
        let g1 = vec![1.0, 2.0, 3.0];
        let g2 = vec![2.0, 3.0, 4.0];
        let g3 = vec![3.0, 4.0, 5.0];
        let gradients = vec![g1, g2, g3];

        let result = aggregator.aggregate_secure(&gradients).unwrap();

        // Should be average of the three gradients
        assert!((result[0] - 2.0).abs() < 0.01);
        assert!((result[1] - 3.0).abs() < 0.01);
        assert!((result[2] - 4.0).abs() < 0.01);

        // Reset
        aggregator.reset();
        assert_eq!(aggregator.participant_count(), 0);
    }

    // Secure aggregation errors when too few participants have joined.
    #[test]
    fn test_secure_aggregation_insufficient_participants() {
        let aggregator = SecureAggregation::new(5);

        let g1 = vec![1.0, 2.0];
        let g2 = vec![3.0, 4.0];
        let gradients = vec![g1, g2];

        // Should fail because we don't have enough participants
        let result = aggregator.aggregate_secure(&gradients);
        assert!(result.is_err());
    }

    // DPMechanism variants compare by equality.
    #[test]
    fn test_dp_mechanism_types() {
        let gaussian = DPMechanism::Gaussian;
        let laplacian = DPMechanism::Laplacian;

        assert_eq!(gaussian, DPMechanism::Gaussian);
        assert_eq!(laplacian, DPMechanism::Laplacian);
        assert_ne!(gaussian, laplacian);
    }

    // ClientInfo state transitions: Idle -> Training -> Completed -> Failed.
    #[test]
    fn test_client_info() {
        let mut client = ClientInfo::new("client1".to_string(), 1000);

        assert_eq!(client.client_id, "client1");
        assert_eq!(client.state, ClientState::Idle);
        assert_eq!(client.sample_count, 1000);

        client.start_training();
        assert_eq!(client.state, ClientState::Training);

        client.complete_training();
        assert_eq!(client.state, ClientState::Completed);

        client.mark_failed();
        assert_eq!(client.state, ClientState::Failed);
    }

    // FederatedRound: completion counting, is_complete, and finalization
    // (aggregated gradient stored, end time and duration available).
    #[test]
    fn test_federated_round() {
        let model_cid = Cid::default();
        let mut round = FederatedRound::new(0, model_cid, 5);

        assert_eq!(round.round_num, 0);
        assert_eq!(round.client_count, 5);
        assert_eq!(round.completed_count, 0);
        assert!(!round.is_complete());

        // Mark clients as completed
        for _ in 0..5 {
            round.mark_client_completed();
        }

        assert_eq!(round.completed_count, 5);
        assert!(round.is_complete());

        // Complete the round
        let gradient = vec![1.0, 2.0, 3.0];
        round.complete(gradient.clone());

        assert_eq!(round.aggregated_gradient, Some(gradient));
        assert!(round.end_time.is_some());
        assert!(round.duration().is_some());
    }

    // ConvergenceDetector reports convergence for a slowly changing window of
    // losses and exposes/reset its history.
    #[test]
    fn test_convergence_detector() {
        let mut detector = ConvergenceDetector::new(3, 0.01);

        // Add loss values that are converging
        detector.add_loss(1.0);
        detector.add_loss(0.99);
        detector.add_loss(0.98);

        assert!(detector.has_converged());
        assert_eq!(detector.latest_loss(), Some(0.98));
        assert_eq!(detector.history().len(), 3);

        // Reset
        detector.reset();
        assert_eq!(detector.history().len(), 0);
    }

    // ConvergenceDetector does not report convergence for widely varying losses.
    #[test]
    fn test_convergence_detector_not_converged() {
        let mut detector = ConvergenceDetector::new(3, 0.01);

        // Add loss values that are NOT converging
        detector.add_loss(1.0);
        detector.add_loss(0.5);
        detector.add_loss(1.5);

        assert!(!detector.has_converged());
    }

    // ModelSyncProtocol: starting a round advances the counter; completing it
    // records the loss and stores the aggregated gradient on the round.
    #[test]
    fn test_model_sync_protocol() {
        let mut protocol = ModelSyncProtocol::new(10, 3, 3, 0.01);

        assert_eq!(protocol.current_round(), 0);
        assert_eq!(protocol.max_rounds(), 10);
        assert!(protocol.should_continue());

        // Start round 0
        let model_cid = Cid::default();
        let round_num = protocol.start_round(model_cid, 5).unwrap();

        assert_eq!(round_num, 0);
        assert_eq!(protocol.current_round(), 1);
        assert_eq!(protocol.total_rounds(), 1);

        // Complete round 0
        let gradient = vec![1.0, 2.0, 3.0];
        protocol
            .complete_round(round_num, gradient.clone(), 1.0)
            .unwrap();

        assert_eq!(protocol.latest_loss(), Some(1.0));

        // Get round info
        let round = protocol.get_round(0).unwrap();
        assert_eq!(round.round_num, 0);
        assert_eq!(round.aggregated_gradient, Some(gradient));
    }

    // ModelSyncProtocol stops (should_continue == false) once losses converge,
    // even though max_rounds has not been reached.
    #[test]
    fn test_model_sync_protocol_convergence() {
        let mut protocol = ModelSyncProtocol::new(10, 2, 3, 0.01);

        let model_cid = Cid::default();

        // Run multiple rounds with converging loss
        for i in 0..3 {
            protocol.start_round(model_cid, 3).unwrap();
            let gradient = vec![1.0, 2.0];
            let loss = 1.0 - (i as f64 * 0.001);
            protocol.complete_round(i, gradient, loss).unwrap();
        }

        // Should have converged
        assert!(protocol.has_converged());
        assert!(!protocol.should_continue());
    }

    // ModelSyncProtocol refuses to start more than max_rounds rounds.
    #[test]
    fn test_model_sync_protocol_max_rounds() {
        let mut protocol = ModelSyncProtocol::new(2, 1, 3, 0.01);

        let model_cid = Cid::default();

        // Start 2 rounds (max)
        protocol.start_round(model_cid, 2).unwrap();
        protocol.start_round(model_cid, 2).unwrap();

        // Should fail to start a third round
        let result = protocol.start_round(model_cid, 2);
        assert!(result.is_err());
    }

    // ModelSyncProtocol refuses to start a round with too few clients.
    #[test]
    fn test_model_sync_protocol_min_clients() {
        let mut protocol = ModelSyncProtocol::new(10, 5, 3, 0.01);

        let model_cid = Cid::default();

        // Should fail with too few clients
        let result = protocol.start_round(model_cid, 3);
        assert!(result.is_err());

        // Should succeed with enough clients
        let result = protocol.start_round(model_cid, 5);
        assert!(result.is_ok());
    }

    // ClientState variants compare by equality.
    #[test]
    fn test_client_state_enum() {
        let idle = ClientState::Idle;
        let training = ClientState::Training;
        let completed = ClientState::Completed;
        let failed = ClientState::Failed;

        assert_ne!(idle, training);
        assert_ne!(training, completed);
        assert_ne!(completed, failed);
        assert_eq!(idle, ClientState::Idle);
    }
}