1use crate::arrow::{TensorDtype, TensorMetadata};
10use ipfrs_core::Cid;
11use serde::{Deserialize, Serialize};
12use std::collections::HashMap;
13use thiserror::Error;
14
15#[derive(Debug, Error)]
17pub enum GradientError {
18 #[error("Shape mismatch: expected {expected:?}, got {actual:?}")]
19 ShapeMismatch {
20 expected: Vec<usize>,
21 actual: Vec<usize>,
22 },
23
24 #[error("Checksum verification failed")]
25 ChecksumFailed,
26
27 #[error("Invalid compression ratio: {0}")]
28 InvalidCompressionRatio(f32),
29
30 #[error("Empty gradient set")]
31 EmptyGradientSet,
32
33 #[error("Incompatible dtype: {0:?}")]
34 IncompatibleDtype(TensorDtype),
35
36 #[error("Outlier detected at index {index}: value {value}")]
37 OutlierDetected { index: usize, value: f32 },
38
39 #[error("Invalid gradient: {0}")]
40 InvalidGradient(String),
41}
42
43#[derive(Debug, Clone, Serialize, Deserialize)]
45pub struct SparseGradient {
46 pub indices: Vec<usize>,
48 pub values: Vec<f32>,
50 pub shape: Vec<usize>,
52 pub metadata: TensorMetadata,
54}
55
56impl SparseGradient {
57 pub fn new(indices: Vec<usize>, values: Vec<f32>, shape: Vec<usize>) -> Self {
59 let metadata = TensorMetadata {
60 name: "sparse_gradient".to_string(),
61 shape: shape.clone(),
62 dtype: TensorDtype::Float32,
63 strides: None,
64 custom: HashMap::new(),
65 };
66
67 Self {
68 indices,
69 values,
70 shape,
71 metadata,
72 }
73 }
74
75 pub fn nnz(&self) -> usize {
77 self.indices.len()
78 }
79
80 pub fn total_elements(&self) -> usize {
82 self.shape.iter().product()
83 }
84
85 pub fn sparsity_ratio(&self) -> f32 {
87 1.0 - (self.nnz() as f32 / self.total_elements() as f32)
88 }
89
90 pub fn to_dense(&self) -> Vec<f32> {
92 let total = self.total_elements();
93 let mut dense = vec![0.0; total];
94
95 for (&idx, &val) in self.indices.iter().zip(&self.values) {
96 if idx < total {
97 dense[idx] = val;
98 }
99 }
100
101 dense
102 }
103
104 pub fn verify_shape(&self) -> Result<(), GradientError> {
106 let total = self.total_elements();
107
108 for &idx in &self.indices {
109 if idx >= total {
110 return Err(GradientError::InvalidGradient(format!(
111 "Index {} out of bounds for shape {:?}",
112 idx, self.shape
113 )));
114 }
115 }
116
117 if self.indices.len() != self.values.len() {
118 return Err(GradientError::InvalidGradient(format!(
119 "Indices length {} != values length {}",
120 self.indices.len(),
121 self.values.len()
122 )));
123 }
124
125 Ok(())
126 }
127}
128
129#[derive(Debug, Clone, Serialize, Deserialize)]
131pub struct QuantizedGradient {
132 pub quantized_values: Vec<i8>,
134 pub scale: f32,
136 pub min_val: f32,
138 pub shape: Vec<usize>,
140 pub metadata: TensorMetadata,
142}
143
144impl QuantizedGradient {
145 pub fn from_dense(values: &[f32], shape: Vec<usize>) -> Self {
147 let (quantized_values, scale, min_val) = Self::quantize_i8(values);
148
149 let metadata = TensorMetadata {
150 name: "quantized_gradient".to_string(),
151 shape: shape.clone(),
152 dtype: TensorDtype::Int8,
153 strides: None,
154 custom: HashMap::new(),
155 };
156
157 Self {
158 quantized_values,
159 scale,
160 min_val,
161 shape,
162 metadata,
163 }
164 }
165
166 fn quantize_i8(values: &[f32]) -> (Vec<i8>, f32, f32) {
168 if values.is_empty() {
169 return (Vec::new(), 1.0, 0.0);
170 }
171
172 let min_val = values.iter().copied().fold(f32::INFINITY, f32::min);
173 let max_val = values.iter().copied().fold(f32::NEG_INFINITY, f32::max);
174
175 let scale = if (max_val - min_val).abs() < 1e-8 {
177 1.0
178 } else {
179 (max_val - min_val) / 255.0
180 };
181
182 let quantized = values
183 .iter()
184 .map(|&v| {
185 let normalized = (v - min_val) / scale;
187 (normalized - 128.0).round().clamp(-128.0, 127.0) as i8
188 })
189 .collect();
190
191 (quantized, scale, min_val)
192 }
193
194 pub fn to_dense(&self) -> Vec<f32> {
196 self.quantized_values
197 .iter()
198 .map(|&q| {
199 let normalized = (q as f32) + 128.0;
201 normalized * self.scale + self.min_val
202 })
203 .collect()
204 }
205
206 pub fn compression_ratio(&self) -> f32 {
208 let original_size = self.quantized_values.len() * 4;
210 let compressed_size = self.quantized_values.len() + 8; original_size as f32 / compressed_size as f32
212 }
213}
214
215#[derive(Debug, Clone, Serialize, Deserialize)]
217pub struct GradientDelta {
218 #[serde(serialize_with = "crate::serialize_cid")]
220 #[serde(deserialize_with = "crate::deserialize_cid")]
221 pub base_model: Cid,
222 pub layer_gradients: HashMap<String, LayerGradient>,
224 pub checksum: u64,
226 pub timestamp: i64,
228}
229
230#[derive(Debug, Clone, Serialize, Deserialize)]
232pub enum LayerGradient {
233 Dense { values: Vec<f32>, shape: Vec<usize> },
235 Sparse(SparseGradient),
237 Quantized(QuantizedGradient),
239}
240
241impl LayerGradient {
242 pub fn shape(&self) -> &[usize] {
244 match self {
245 LayerGradient::Dense { shape, .. } => shape,
246 LayerGradient::Sparse(sg) => &sg.shape,
247 LayerGradient::Quantized(qg) => &qg.shape,
248 }
249 }
250
251 pub fn to_dense(&self) -> Vec<f32> {
253 match self {
254 LayerGradient::Dense { values, .. } => values.clone(),
255 LayerGradient::Sparse(sg) => sg.to_dense(),
256 LayerGradient::Quantized(qg) => qg.to_dense(),
257 }
258 }
259
260 pub fn memory_size(&self) -> usize {
262 match self {
263 LayerGradient::Dense { values, .. } => values.len() * 4,
264 LayerGradient::Sparse(sg) => sg.indices.len() * 4 + sg.values.len() * 4,
265 LayerGradient::Quantized(qg) => qg.quantized_values.len() + 8,
266 }
267 }
268}
269
270impl GradientDelta {
271 pub fn new(base_model: Cid) -> Self {
273 Self {
274 base_model,
275 layer_gradients: HashMap::new(),
276 checksum: 0,
277 timestamp: chrono::Utc::now().timestamp(),
278 }
279 }
280
281 pub fn add_dense_gradient(&mut self, layer_name: String, values: Vec<f32>, shape: Vec<usize>) {
283 self.layer_gradients
284 .insert(layer_name, LayerGradient::Dense { values, shape });
285 self.update_checksum();
286 }
287
288 pub fn add_sparse_gradient(&mut self, layer_name: String, gradient: SparseGradient) {
290 self.layer_gradients
291 .insert(layer_name, LayerGradient::Sparse(gradient));
292 self.update_checksum();
293 }
294
295 pub fn add_quantized_gradient(&mut self, layer_name: String, gradient: QuantizedGradient) {
297 self.layer_gradients
298 .insert(layer_name, LayerGradient::Quantized(gradient));
299 self.update_checksum();
300 }
301
302 fn update_checksum(&mut self) {
304 use std::collections::hash_map::DefaultHasher;
305 use std::hash::{Hash, Hasher};
306
307 let mut hasher = DefaultHasher::new();
308
309 self.layer_gradients.len().hash(&mut hasher);
311
312 let mut sorted_layers: Vec<_> = self.layer_gradients.iter().collect();
314 sorted_layers.sort_by_key(|(name, _)| *name);
315
316 for (name, gradient) in sorted_layers {
317 name.hash(&mut hasher);
318 gradient.shape().hash(&mut hasher);
319
320 let dense = gradient.to_dense();
322 let sample_size = dense.len().min(100);
323 for &v in dense.iter().take(sample_size) {
324 v.to_bits().hash(&mut hasher);
325 }
326 }
327
328 self.checksum = hasher.finish();
329 }
330
331 pub fn verify_checksum(&self) -> Result<(), GradientError> {
333 let mut temp = self.clone();
334 temp.update_checksum();
335
336 if temp.checksum == self.checksum {
337 Ok(())
338 } else {
339 Err(GradientError::ChecksumFailed)
340 }
341 }
342
343 pub fn total_memory_size(&self) -> usize {
345 self.layer_gradients.values().map(|g| g.memory_size()).sum()
346 }
347}
348
349pub struct GradientCompressor;
351
352impl GradientCompressor {
353 pub fn top_k(
355 values: &[f32],
356 shape: Vec<usize>,
357 k: usize,
358 ) -> Result<SparseGradient, GradientError> {
359 if k == 0 || k > values.len() {
360 return Err(GradientError::InvalidCompressionRatio(
361 k as f32 / values.len() as f32,
362 ));
363 }
364
365 let mut indexed_values: Vec<(usize, f32)> = values
367 .iter()
368 .enumerate()
369 .map(|(i, &v)| (i, v.abs()))
370 .collect();
371
372 indexed_values.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
373 indexed_values.truncate(k);
374
375 let mut indices = Vec::with_capacity(k);
376 let mut sparse_values = Vec::with_capacity(k);
377
378 for (idx, _) in indexed_values {
379 indices.push(idx);
380 sparse_values.push(values[idx]);
381 }
382
383 Ok(SparseGradient::new(indices, sparse_values, shape))
384 }
385
386 pub fn threshold(values: &[f32], shape: Vec<usize>, threshold: f32) -> SparseGradient {
388 let mut indices = Vec::new();
389 let mut sparse_values = Vec::new();
390
391 for (i, &v) in values.iter().enumerate() {
392 if v.abs() >= threshold {
393 indices.push(i);
394 sparse_values.push(v);
395 }
396 }
397
398 SparseGradient::new(indices, sparse_values, shape)
399 }
400
401 pub fn quantize(values: &[f32], shape: Vec<usize>) -> QuantizedGradient {
403 QuantizedGradient::from_dense(values, shape)
404 }
405
406 pub fn random_sparsification(
408 values: &[f32],
409 shape: Vec<usize>,
410 keep_ratio: f32,
411 ) -> Result<SparseGradient, GradientError> {
412 use rand::Rng;
413
414 if keep_ratio <= 0.0 || keep_ratio > 1.0 {
415 return Err(GradientError::InvalidCompressionRatio(keep_ratio));
416 }
417
418 let mut rng = rand::rng();
419 let mut indices = Vec::new();
420 let mut sparse_values = Vec::new();
421
422 for (i, &v) in values.iter().enumerate() {
423 if rng.random::<f32>() < keep_ratio {
424 indices.push(i);
425 sparse_values.push(v / keep_ratio); }
427 }
428
429 Ok(SparseGradient::new(indices, sparse_values, shape))
430 }
431}
432
433pub struct GradientAggregator;
435
436impl GradientAggregator {
437 pub fn average(gradients: &[Vec<f32>]) -> Result<Vec<f32>, GradientError> {
439 if gradients.is_empty() {
440 return Err(GradientError::EmptyGradientSet);
441 }
442
443 let len = gradients[0].len();
444
445 for g in gradients.iter() {
447 if g.len() != len {
448 return Err(GradientError::ShapeMismatch {
449 expected: vec![len],
450 actual: vec![g.len()],
451 });
452 }
453 }
454
455 let mut result = vec![0.0; len];
456 let count = gradients.len() as f32;
457
458 for gradient in gradients {
459 for (i, &v) in gradient.iter().enumerate() {
460 result[i] += v / count;
461 }
462 }
463
464 Ok(result)
465 }
466
467 pub fn weighted_average(
469 gradients: &[Vec<f32>],
470 weights: &[f32],
471 ) -> Result<Vec<f32>, GradientError> {
472 if gradients.is_empty() {
473 return Err(GradientError::EmptyGradientSet);
474 }
475
476 if gradients.len() != weights.len() {
477 return Err(GradientError::InvalidGradient(format!(
478 "Gradient count {} != weight count {}",
479 gradients.len(),
480 weights.len()
481 )));
482 }
483
484 let len = gradients[0].len();
485
486 for g in gradients.iter() {
488 if g.len() != len {
489 return Err(GradientError::ShapeMismatch {
490 expected: vec![len],
491 actual: vec![g.len()],
492 });
493 }
494 }
495
496 let weight_sum: f32 = weights.iter().sum();
497 if weight_sum == 0.0 {
498 return Err(GradientError::InvalidGradient(
499 "Sum of weights is zero".to_string(),
500 ));
501 }
502
503 let mut result = vec![0.0; len];
504
505 for (gradient, &weight) in gradients.iter().zip(weights) {
506 let normalized_weight = weight / weight_sum;
507 for (i, &v) in gradient.iter().enumerate() {
508 result[i] += v * normalized_weight;
509 }
510 }
511
512 Ok(result)
513 }
514
515 pub fn apply_momentum(
517 current_gradient: &[f32],
518 previous_momentum: &[f32],
519 momentum_factor: f32,
520 ) -> Result<Vec<f32>, GradientError> {
521 if current_gradient.len() != previous_momentum.len() {
522 return Err(GradientError::ShapeMismatch {
523 expected: vec![previous_momentum.len()],
524 actual: vec![current_gradient.len()],
525 });
526 }
527
528 let result = current_gradient
529 .iter()
530 .zip(previous_momentum)
531 .map(|(&g, &m)| momentum_factor * m + g)
532 .collect();
533
534 Ok(result)
535 }
536}
537
538pub struct GradientVerifier;
540
541impl GradientVerifier {
542 pub fn verify_shape(gradient: &[f32], expected_shape: &[usize]) -> Result<(), GradientError> {
544 let expected_size: usize = expected_shape.iter().product();
545
546 if gradient.len() != expected_size {
547 return Err(GradientError::ShapeMismatch {
548 expected: expected_shape.to_vec(),
549 actual: vec![gradient.len()],
550 });
551 }
552
553 Ok(())
554 }
555
556 pub fn detect_outliers(gradient: &[f32], std_threshold: f32) -> Result<(), GradientError> {
558 if gradient.is_empty() {
559 return Ok(());
560 }
561
562 let mean = gradient.iter().sum::<f32>() / gradient.len() as f32;
564
565 let variance =
567 gradient.iter().map(|&v| (v - mean).powi(2)).sum::<f32>() / gradient.len() as f32;
568 let std_dev = variance.sqrt();
569
570 for (i, &v) in gradient.iter().enumerate() {
572 let z_score = (v - mean).abs() / std_dev;
573 if z_score > std_threshold {
574 return Err(GradientError::OutlierDetected { index: i, value: v });
575 }
576 }
577
578 Ok(())
579 }
580
581 pub fn verify_finite(gradient: &[f32]) -> Result<(), GradientError> {
583 for (i, &v) in gradient.iter().enumerate() {
584 if !v.is_finite() {
585 return Err(GradientError::InvalidGradient(format!(
586 "Non-finite value at index {}: {}",
587 i, v
588 )));
589 }
590 }
591
592 Ok(())
593 }
594
595 pub fn l2_norm(gradient: &[f32]) -> f32 {
597 gradient.iter().map(|&v| v * v).sum::<f32>().sqrt()
598 }
599
600 pub fn clip_by_norm(gradient: &mut [f32], max_norm: f32) {
602 let norm = Self::l2_norm(gradient);
603
604 if norm > max_norm {
605 let scale = max_norm / norm;
606 for v in gradient.iter_mut() {
607 *v *= scale;
608 }
609 }
610 }
611}
612
613#[derive(Debug, Clone, Copy)]
615pub struct PrivacyBudget {
616 pub epsilon: f64,
618 pub delta: f64,
620 pub remaining_epsilon: f64,
622}
623
624impl PrivacyBudget {
625 pub fn new(epsilon: f64, delta: f64) -> Self {
627 Self {
628 epsilon,
629 delta,
630 remaining_epsilon: epsilon,
631 }
632 }
633
634 pub fn consume(&mut self, epsilon_used: f64) -> Result<(), GradientError> {
636 if epsilon_used > self.remaining_epsilon {
637 return Err(GradientError::InvalidGradient(format!(
638 "Insufficient privacy budget: need {}, have {}",
639 epsilon_used, self.remaining_epsilon
640 )));
641 }
642
643 self.remaining_epsilon -= epsilon_used;
644 Ok(())
645 }
646
647 pub fn is_exhausted(&self) -> bool {
649 self.remaining_epsilon <= 0.0
650 }
651
652 pub fn remaining_fraction(&self) -> f64 {
654 self.remaining_epsilon / self.epsilon
655 }
656}
657
/// Noise distribution used by `DifferentialPrivacy::apply_dp_sgd`.
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum DPMechanism {
    /// Additive Gaussian noise.
    Gaussian,
    /// Additive Laplace noise.
    Laplacian,
}
666
667pub struct DifferentialPrivacy {
669 budget: PrivacyBudget,
671 sensitivity: f64,
673 mechanism: DPMechanism,
675}
676
677impl DifferentialPrivacy {
678 pub fn new(epsilon: f64, delta: f64, sensitivity: f64, mechanism: DPMechanism) -> Self {
680 Self {
681 budget: PrivacyBudget::new(epsilon, delta),
682 sensitivity,
683 mechanism,
684 }
685 }
686
687 pub fn add_gaussian_noise(&mut self, gradient: &mut [f32]) -> Result<(), GradientError> {
690 use rand::Rng;
691
692 if self.budget.is_exhausted() {
693 return Err(GradientError::InvalidGradient(
694 "Privacy budget exhausted".to_string(),
695 ));
696 }
697
698 let ln_term = (1.25 / self.budget.delta).ln();
701 let sigma = self.sensitivity * (2.0 * ln_term).sqrt() / self.budget.epsilon;
702
703 let mut rng = rand::rng();
704
705 for v in gradient.iter_mut() {
707 let noise: f64 = rng.random_range(-1.0..1.0);
708 let gaussian_noise = sigma * noise;
709 *v += gaussian_noise as f32;
710 }
711
712 self.budget.consume(self.budget.epsilon / 100.0)?;
714
715 Ok(())
716 }
717
718 pub fn add_laplacian_noise(&mut self, gradient: &mut [f32]) -> Result<(), GradientError> {
721 use rand::Rng;
722
723 if self.budget.is_exhausted() {
724 return Err(GradientError::InvalidGradient(
725 "Privacy budget exhausted".to_string(),
726 ));
727 }
728
729 let scale = self.sensitivity / self.budget.epsilon;
732
733 let mut rng = rand::rng();
734
735 for v in gradient.iter_mut() {
737 let u: f64 = rng.random_range(-0.5..0.5);
738 let laplacian_noise = -scale * u.signum() * (1.0 - 2.0 * u.abs()).ln();
739 *v += laplacian_noise as f32;
740 }
741
742 self.budget.consume(self.budget.epsilon / 100.0)?;
744
745 Ok(())
746 }
747
748 pub fn apply_dp_sgd(
751 &mut self,
752 gradient: &mut [f32],
753 clip_norm: f32,
754 ) -> Result<(), GradientError> {
755 GradientVerifier::clip_by_norm(gradient, clip_norm);
757
758 match self.mechanism {
760 DPMechanism::Gaussian => self.add_gaussian_noise(gradient)?,
761 DPMechanism::Laplacian => self.add_laplacian_noise(gradient)?,
762 }
763
764 Ok(())
765 }
766
767 pub fn remaining_budget(&self) -> f64 {
769 self.budget.remaining_epsilon
770 }
771
772 pub fn is_budget_exhausted(&self) -> bool {
774 self.budget.is_exhausted()
775 }
776
777 pub fn get_privacy_params(&self) -> (f64, f64) {
779 (self.budget.epsilon, self.budget.delta)
780 }
781
782 pub fn calculate_noise_multiplier(epsilon: f64, delta: f64, sensitivity: f64) -> f64 {
785 let ln_term = (1.25 / delta).ln();
787 sensitivity * (2.0 * ln_term).sqrt() / epsilon
788 }
789}
790
791pub struct SecureAggregation {
793 min_participants: usize,
795 participant_count: usize,
797}
798
799impl SecureAggregation {
800 pub fn new(min_participants: usize) -> Self {
802 Self {
803 min_participants,
804 participant_count: 0,
805 }
806 }
807
808 pub fn add_participant(&mut self) {
810 self.participant_count += 1;
811 }
812
813 pub fn can_aggregate(&self) -> bool {
815 self.participant_count >= self.min_participants
816 }
817
818 pub fn aggregate_secure(&self, gradients: &[Vec<f32>]) -> Result<Vec<f32>, GradientError> {
822 if !self.can_aggregate() {
823 return Err(GradientError::InvalidGradient(format!(
824 "Not enough participants: need {}, have {}",
825 self.min_participants, self.participant_count
826 )));
827 }
828
829 GradientAggregator::average(gradients)
835 }
836
837 pub fn reset(&mut self) {
839 self.participant_count = 0;
840 }
841
842 pub fn participant_count(&self) -> usize {
844 self.participant_count
845 }
846}
847
/// Lifecycle state of a federated client, as set by the `ClientInfo` transition methods.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum ClientState {
    /// Registered, not currently training.
    Idle,
    /// Training is in progress.
    Training,
    /// Training finished.
    Completed,
    /// The client was marked as failed.
    Failed,
}
860
861#[derive(Debug, Clone)]
863pub struct ClientInfo {
864 pub client_id: String,
866 pub state: ClientState,
868 pub sample_count: usize,
870 pub last_update: i64,
872}
873
874impl ClientInfo {
875 pub fn new(client_id: String, sample_count: usize) -> Self {
877 Self {
878 client_id,
879 state: ClientState::Idle,
880 sample_count,
881 last_update: chrono::Utc::now().timestamp(),
882 }
883 }
884
885 pub fn start_training(&mut self) {
887 self.state = ClientState::Training;
888 self.last_update = chrono::Utc::now().timestamp();
889 }
890
891 pub fn complete_training(&mut self) {
893 self.state = ClientState::Completed;
894 self.last_update = chrono::Utc::now().timestamp();
895 }
896
897 pub fn mark_failed(&mut self) {
899 self.state = ClientState::Failed;
900 self.last_update = chrono::Utc::now().timestamp();
901 }
902}
903
904#[derive(Debug, Clone, Serialize, Deserialize)]
906pub struct FederatedRound {
907 pub round_num: usize,
909 pub client_count: usize,
911 #[serde(serialize_with = "crate::serialize_cid")]
913 #[serde(deserialize_with = "crate::deserialize_cid")]
914 pub global_model: Cid,
915 pub aggregated_gradient: Option<Vec<f32>>,
917 pub start_time: i64,
919 pub end_time: Option<i64>,
921 pub completed_count: usize,
923}
924
925impl FederatedRound {
926 pub fn new(round_num: usize, global_model: Cid, client_count: usize) -> Self {
928 Self {
929 round_num,
930 client_count,
931 global_model,
932 aggregated_gradient: None,
933 start_time: chrono::Utc::now().timestamp(),
934 end_time: None,
935 completed_count: 0,
936 }
937 }
938
939 pub fn mark_client_completed(&mut self) {
941 self.completed_count += 1;
942 }
943
944 pub fn is_complete(&self) -> bool {
946 self.completed_count >= self.client_count
947 }
948
949 pub fn complete(&mut self, aggregated_gradient: Vec<f32>) {
951 self.aggregated_gradient = Some(aggregated_gradient);
952 self.end_time = Some(chrono::Utc::now().timestamp());
953 }
954
955 pub fn duration(&self) -> Option<i64> {
957 self.end_time.map(|end| end - self.start_time)
958 }
959}
960
/// Declares convergence when the last `window_size` losses are stable relative to
/// their mean.
pub struct ConvergenceDetector {
    /// Number of most recent losses considered.
    window_size: usize,
    /// Most recent losses, oldest first; never longer than `window_size`.
    loss_history: Vec<f64>,
    /// Maximum coefficient of variation (std / |mean|) accepted as converged.
    threshold: f64,
}

impl ConvergenceDetector {
    /// Creates a detector with an empty history.
    pub fn new(window_size: usize, threshold: f64) -> Self {
        Self {
            window_size,
            loss_history: Vec::new(),
            threshold,
        }
    }

    /// Appends `loss`, discarding the oldest entry if the window is full.
    pub fn add_loss(&mut self, loss: f64) {
        self.loss_history.push(loss);

        let overflow = self.loss_history.len().saturating_sub(self.window_size);
        if overflow > 0 {
            self.loss_history.drain(..overflow);
        }
    }

    /// `true` when a full window is available and its population std divided by
    /// `|mean|` is below `threshold`. A window whose mean is within 1e-10 of zero
    /// always counts as converged.
    pub fn has_converged(&self) -> bool {
        let n = self.loss_history.len();
        if n < self.window_size {
            return false;
        }

        let window = &self.loss_history[n - self.window_size..];
        let count = window.len() as f64;
        let mean = window.iter().sum::<f64>() / count;

        if mean.abs() < 1e-10 {
            return true;
        }

        let variance = window.iter().map(|&x| (x - mean).powi(2)).sum::<f64>() / count;
        variance.sqrt() / mean.abs() < self.threshold
    }

    /// Most recently recorded loss, if any.
    pub fn latest_loss(&self) -> Option<f64> {
        self.loss_history.last().copied()
    }

    /// Clears the history.
    pub fn reset(&mut self) {
        self.loss_history.clear();
    }

    /// Recorded losses, oldest first.
    pub fn history(&self) -> &[f64] {
        &self.loss_history
    }
}
1028
1029pub struct ModelSyncProtocol {
1031 current_round: usize,
1033 max_rounds: usize,
1035 min_clients_per_round: usize,
1037 rounds: Vec<FederatedRound>,
1039 convergence: ConvergenceDetector,
1041}
1042
1043impl ModelSyncProtocol {
1044 pub fn new(
1046 max_rounds: usize,
1047 min_clients_per_round: usize,
1048 convergence_window: usize,
1049 convergence_threshold: f64,
1050 ) -> Self {
1051 Self {
1052 current_round: 0,
1053 max_rounds,
1054 min_clients_per_round,
1055 rounds: Vec::new(),
1056 convergence: ConvergenceDetector::new(convergence_window, convergence_threshold),
1057 }
1058 }
1059
1060 pub fn start_round(
1062 &mut self,
1063 global_model: Cid,
1064 client_count: usize,
1065 ) -> Result<usize, GradientError> {
1066 if client_count < self.min_clients_per_round {
1067 return Err(GradientError::InvalidGradient(format!(
1068 "Not enough clients: need {}, got {}",
1069 self.min_clients_per_round, client_count
1070 )));
1071 }
1072
1073 if self.current_round >= self.max_rounds {
1074 return Err(GradientError::InvalidGradient(format!(
1075 "Maximum rounds reached: {}",
1076 self.max_rounds
1077 )));
1078 }
1079
1080 let round = FederatedRound::new(self.current_round, global_model, client_count);
1081 self.rounds.push(round);
1082 self.current_round += 1;
1083
1084 Ok(self.current_round - 1)
1085 }
1086
1087 pub fn complete_round(
1089 &mut self,
1090 round_num: usize,
1091 aggregated_gradient: Vec<f32>,
1092 loss: f64,
1093 ) -> Result<(), GradientError> {
1094 if round_num >= self.rounds.len() {
1095 return Err(GradientError::InvalidGradient(format!(
1096 "Invalid round number: {}",
1097 round_num
1098 )));
1099 }
1100
1101 self.rounds[round_num].complete(aggregated_gradient);
1102 self.convergence.add_loss(loss);
1103
1104 Ok(())
1105 }
1106
1107 pub fn should_continue(&self) -> bool {
1109 self.current_round < self.max_rounds && !self.convergence.has_converged()
1110 }
1111
1112 pub fn has_converged(&self) -> bool {
1114 self.convergence.has_converged()
1115 }
1116
1117 pub fn current_round(&self) -> usize {
1119 self.current_round
1120 }
1121
1122 pub fn total_rounds(&self) -> usize {
1124 self.rounds.len()
1125 }
1126
1127 pub fn get_round(&self, round_num: usize) -> Option<&FederatedRound> {
1129 self.rounds.get(round_num)
1130 }
1131
1132 pub fn latest_loss(&self) -> Option<f64> {
1134 self.convergence.latest_loss()
1135 }
1136
1137 pub fn max_rounds(&self) -> usize {
1139 self.max_rounds
1140 }
1141}
1142
1143#[cfg(test)]
1144mod tests {
1145 use super::*;
1146
1147 #[test]
1148 fn test_sparse_gradient() {
1149 let indices = vec![0, 5, 10];
1150 let values = vec![1.0, 2.0, 3.0];
1151 let shape = vec![20];
1152
1153 let sparse = SparseGradient::new(indices.clone(), values.clone(), shape);
1154
1155 assert_eq!(sparse.nnz(), 3);
1156 assert_eq!(sparse.total_elements(), 20);
1157 assert!((sparse.sparsity_ratio() - 0.85).abs() < 0.01);
1158
1159 let dense = sparse.to_dense();
1160 assert_eq!(dense.len(), 20);
1161 assert_eq!(dense[0], 1.0);
1162 assert_eq!(dense[5], 2.0);
1163 assert_eq!(dense[10], 3.0);
1164 }
1165
1166 #[test]
1167 fn test_quantized_gradient() {
1168 let values = vec![1.0, 2.0, 3.0, 4.0, 5.0];
1169 let shape = vec![5];
1170
1171 let quantized = QuantizedGradient::from_dense(&values, shape);
1172 let dequantized = quantized.to_dense();
1173
1174 for (i, (orig, deq)) in values.iter().zip(&dequantized).enumerate() {
1178 let error = (orig - deq).abs();
1179 assert!(
1181 error < 0.02,
1182 "Value {} mismatch: orig={}, deq={}, error={}",
1183 i,
1184 orig,
1185 deq,
1186 error
1187 );
1188 }
1189 }
1190
1191 #[test]
1192 fn test_gradient_delta() {
1193 let base_cid = Cid::default();
1194 let mut delta = GradientDelta::new(base_cid);
1195
1196 delta.add_dense_gradient("layer1".to_string(), vec![1.0, 2.0, 3.0], vec![3]);
1197 delta.add_dense_gradient("layer2".to_string(), vec![4.0, 5.0], vec![2]);
1198
1199 assert_eq!(delta.layer_gradients.len(), 2);
1200 assert!(delta.verify_checksum().is_ok());
1201 }
1202
1203 #[test]
1204 fn test_top_k_compression() {
1205 let values = vec![1.0, 5.0, 2.0, 8.0, 3.0];
1206 let shape = vec![5];
1207
1208 let sparse = GradientCompressor::top_k(&values, shape, 2).unwrap();
1209
1210 assert_eq!(sparse.nnz(), 2);
1211 assert!(sparse.values.contains(&8.0));
1212 assert!(sparse.values.contains(&5.0));
1213 }
1214
1215 #[test]
1216 fn test_threshold_compression() {
1217 let values = vec![0.1, 5.0, 0.2, 8.0, 0.3];
1218 let shape = vec![5];
1219
1220 let sparse = GradientCompressor::threshold(&values, shape, 1.0);
1221
1222 assert_eq!(sparse.nnz(), 2);
1223 assert!(sparse.values.contains(&5.0));
1224 assert!(sparse.values.contains(&8.0));
1225 }
1226
1227 #[test]
1228 fn test_gradient_averaging() {
1229 let g1 = vec![1.0, 2.0, 3.0];
1230 let g2 = vec![3.0, 4.0, 5.0];
1231 let gradients = vec![g1, g2];
1232
1233 let avg = GradientAggregator::average(&gradients).unwrap();
1234
1235 assert_eq!(avg, vec![2.0, 3.0, 4.0]);
1236 }
1237
1238 #[test]
1239 fn test_weighted_averaging() {
1240 let g1 = vec![1.0, 2.0, 3.0];
1241 let g2 = vec![3.0, 4.0, 5.0];
1242 let gradients = vec![g1, g2];
1243 let weights = vec![0.25, 0.75];
1244
1245 let avg = GradientAggregator::weighted_average(&gradients, &weights).unwrap();
1246
1247 assert!((avg[0] - 2.5).abs() < 0.01);
1249 assert!((avg[1] - 3.5).abs() < 0.01);
1250 assert!((avg[2] - 4.5).abs() < 0.01);
1251 }
1252
#[test]
fn test_momentum() {
    // Expected: 0.9 * previous + current, element-wise.
    let current = vec![1.0, 2.0, 3.0];
    let previous = vec![0.5, 1.0, 1.5];

    let result = GradientAggregator::apply_momentum(&current, &previous, 0.9).unwrap();

    assert!((result[0] - 1.45).abs() < 0.01);
    assert!((result[1] - 2.9).abs() < 0.01);
    assert!((result[2] - 4.35).abs() < 0.01);
}
1265
1266 #[test]
1267 fn test_gradient_verification() {
1268 let gradient = vec![1.0, 2.0, 3.0, 4.0];
1269
1270 assert!(GradientVerifier::verify_shape(&gradient, &[4]).is_ok());
1272 assert!(GradientVerifier::verify_shape(&gradient, &[2, 2]).is_ok());
1273 assert!(GradientVerifier::verify_shape(&gradient, &[5]).is_err());
1274
1275 assert!(GradientVerifier::verify_finite(&gradient).is_ok());
1277
1278 let invalid = vec![1.0, f32::NAN, 3.0];
1279 assert!(GradientVerifier::verify_finite(&invalid).is_err());
1280 }
1281
1282 #[test]
1283 fn test_gradient_clipping() {
1284 let mut gradient = vec![3.0, 4.0]; GradientVerifier::clip_by_norm(&mut gradient, 2.5);
1287
1288 let norm = GradientVerifier::l2_norm(&gradient);
1289 assert!((norm - 2.5).abs() < 0.01);
1290 }
1291
1292 #[test]
1293 fn test_privacy_budget() {
1294 let mut budget = PrivacyBudget::new(1.0, 1e-5);
1295
1296 assert_eq!(budget.remaining_epsilon, 1.0);
1297 assert!(!budget.is_exhausted());
1298
1299 budget.consume(0.5).unwrap();
1301 assert_eq!(budget.remaining_epsilon, 0.5);
1302 assert!((budget.remaining_fraction() - 0.5).abs() < 1e-6);
1303
1304 budget.consume(0.5).unwrap();
1306 assert!(budget.is_exhausted());
1307
1308 assert!(budget.consume(0.1).is_err());
1310 }
1311
1312 #[test]
1313 fn test_differential_privacy_gaussian() {
1314 let mut dp = DifferentialPrivacy::new(1.0, 1e-5, 1.0, DPMechanism::Gaussian);
1315 let mut gradient = vec![1.0, 2.0, 3.0, 4.0];
1316 let original = gradient.clone();
1317
1318 dp.add_gaussian_noise(&mut gradient).unwrap();
1319
1320 assert_ne!(gradient, original);
1322
1323 assert!(GradientVerifier::verify_finite(&gradient).is_ok());
1325
1326 assert!(dp.remaining_budget() < 1.0);
1328 }
1329
1330 #[test]
1331 fn test_differential_privacy_laplacian() {
1332 let mut dp = DifferentialPrivacy::new(1.0, 1e-5, 1.0, DPMechanism::Laplacian);
1333 let mut gradient = vec![1.0, 2.0, 3.0, 4.0];
1334 let original = gradient.clone();
1335
1336 dp.add_laplacian_noise(&mut gradient).unwrap();
1337
1338 assert_ne!(gradient, original);
1340
1341 assert!(GradientVerifier::verify_finite(&gradient).is_ok());
1343
1344 assert!(dp.remaining_budget() < 1.0);
1346 }
1347
1348 #[test]
1349 fn test_dp_sgd() {
1350 let mut dp = DifferentialPrivacy::new(1.0, 1e-5, 1.0, DPMechanism::Gaussian);
1351 let mut gradient = vec![3.0, 4.0, 5.0, 6.0]; let original_norm = GradientVerifier::l2_norm(&gradient);
1353
1354 dp.apply_dp_sgd(&mut gradient, 5.0).unwrap();
1355
1356 let new_norm = GradientVerifier::l2_norm(&gradient);
1358
1359 assert!(original_norm != new_norm);
1362
1363 assert!(GradientVerifier::verify_finite(&gradient).is_ok());
1365 }
1366
1367 #[test]
1368 fn test_privacy_budget_exhaustion() {
1369 let mut dp = DifferentialPrivacy::new(1.0, 1e-5, 1.0, DPMechanism::Gaussian);
1370 let mut gradient = vec![1.0, 2.0];
1371
1372 let mut successful_calls = 0;
1375 for _ in 0..200 {
1376 if dp.add_gaussian_noise(&mut gradient).is_ok() {
1377 successful_calls += 1;
1378 } else {
1379 break;
1381 }
1382 }
1383
1384 assert!(
1386 (90..=110).contains(&successful_calls),
1387 "Expected ~100 calls, got {}",
1388 successful_calls
1389 );
1390
1391 let remaining = dp.remaining_budget();
1393 assert!(
1394 remaining < 0.02,
1395 "Expected nearly exhausted budget, got {}",
1396 remaining
1397 );
1398
1399 let mut new_gradient = vec![1.0, 2.0];
1401 let result = dp.add_gaussian_noise(&mut new_gradient);
1402 let _ = result;
1405 }
1406
1407 #[test]
1408 fn test_noise_multiplier_calculation() {
1409 let epsilon = 1.0;
1410 let delta = 1e-5;
1411 let sensitivity = 1.0;
1412
1413 let multiplier =
1414 DifferentialPrivacy::calculate_noise_multiplier(epsilon, delta, sensitivity);
1415
1416 assert!(multiplier > 0.0);
1418 assert!(multiplier < 10.0); let multiplier_high_eps =
1422 DifferentialPrivacy::calculate_noise_multiplier(10.0, delta, sensitivity);
1423 assert!(multiplier_high_eps < multiplier);
1424 }
1425
1426 #[test]
1427 fn test_secure_aggregation() {
1428 let mut aggregator = SecureAggregation::new(3);
1429
1430 assert_eq!(aggregator.participant_count(), 0);
1431 assert!(!aggregator.can_aggregate());
1432
1433 aggregator.add_participant();
1435 aggregator.add_participant();
1436 assert!(!aggregator.can_aggregate());
1437
1438 aggregator.add_participant();
1439 assert!(aggregator.can_aggregate());
1440
1441 let g1 = vec![1.0, 2.0, 3.0];
1443 let g2 = vec![2.0, 3.0, 4.0];
1444 let g3 = vec![3.0, 4.0, 5.0];
1445 let gradients = vec![g1, g2, g3];
1446
1447 let result = aggregator.aggregate_secure(&gradients).unwrap();
1448
1449 assert!((result[0] - 2.0).abs() < 0.01);
1451 assert!((result[1] - 3.0).abs() < 0.01);
1452 assert!((result[2] - 4.0).abs() < 0.01);
1453
1454 aggregator.reset();
1456 assert_eq!(aggregator.participant_count(), 0);
1457 }
1458
1459 #[test]
1460 fn test_secure_aggregation_insufficient_participants() {
1461 let aggregator = SecureAggregation::new(5);
1462
1463 let g1 = vec![1.0, 2.0];
1464 let g2 = vec![3.0, 4.0];
1465 let gradients = vec![g1, g2];
1466
1467 let result = aggregator.aggregate_secure(&gradients);
1469 assert!(result.is_err());
1470 }
1471
1472 #[test]
1473 fn test_dp_mechanism_types() {
1474 let gaussian = DPMechanism::Gaussian;
1475 let laplacian = DPMechanism::Laplacian;
1476
1477 assert_eq!(gaussian, DPMechanism::Gaussian);
1478 assert_eq!(laplacian, DPMechanism::Laplacian);
1479 assert_ne!(gaussian, laplacian);
1480 }
1481
1482 #[test]
1483 fn test_client_info() {
1484 let mut client = ClientInfo::new("client1".to_string(), 1000);
1485
1486 assert_eq!(client.client_id, "client1");
1487 assert_eq!(client.state, ClientState::Idle);
1488 assert_eq!(client.sample_count, 1000);
1489
1490 client.start_training();
1491 assert_eq!(client.state, ClientState::Training);
1492
1493 client.complete_training();
1494 assert_eq!(client.state, ClientState::Completed);
1495
1496 client.mark_failed();
1497 assert_eq!(client.state, ClientState::Failed);
1498 }
1499
1500 #[test]
1501 fn test_federated_round() {
1502 let model_cid = Cid::default();
1503 let mut round = FederatedRound::new(0, model_cid, 5);
1504
1505 assert_eq!(round.round_num, 0);
1506 assert_eq!(round.client_count, 5);
1507 assert_eq!(round.completed_count, 0);
1508 assert!(!round.is_complete());
1509
1510 for _ in 0..5 {
1512 round.mark_client_completed();
1513 }
1514
1515 assert_eq!(round.completed_count, 5);
1516 assert!(round.is_complete());
1517
1518 let gradient = vec![1.0, 2.0, 3.0];
1520 round.complete(gradient.clone());
1521
1522 assert_eq!(round.aggregated_gradient, Some(gradient));
1523 assert!(round.end_time.is_some());
1524 assert!(round.duration().is_some());
1525 }
1526
1527 #[test]
1528 fn test_convergence_detector() {
1529 let mut detector = ConvergenceDetector::new(3, 0.01);
1530
1531 detector.add_loss(1.0);
1533 detector.add_loss(0.99);
1534 detector.add_loss(0.98);
1535
1536 assert!(detector.has_converged());
1537 assert_eq!(detector.latest_loss(), Some(0.98));
1538 assert_eq!(detector.history().len(), 3);
1539
1540 detector.reset();
1542 assert_eq!(detector.history().len(), 0);
1543 }
1544
1545 #[test]
1546 fn test_convergence_detector_not_converged() {
1547 let mut detector = ConvergenceDetector::new(3, 0.01);
1548
1549 detector.add_loss(1.0);
1551 detector.add_loss(0.5);
1552 detector.add_loss(1.5);
1553
1554 assert!(!detector.has_converged());
1555 }
1556
1557 #[test]
1558 fn test_model_sync_protocol() {
1559 let mut protocol = ModelSyncProtocol::new(10, 3, 3, 0.01);
1560
1561 assert_eq!(protocol.current_round(), 0);
1562 assert_eq!(protocol.max_rounds(), 10);
1563 assert!(protocol.should_continue());
1564
1565 let model_cid = Cid::default();
1567 let round_num = protocol.start_round(model_cid, 5).unwrap();
1568
1569 assert_eq!(round_num, 0);
1570 assert_eq!(protocol.current_round(), 1);
1571 assert_eq!(protocol.total_rounds(), 1);
1572
1573 let gradient = vec![1.0, 2.0, 3.0];
1575 protocol
1576 .complete_round(round_num, gradient.clone(), 1.0)
1577 .unwrap();
1578
1579 assert_eq!(protocol.latest_loss(), Some(1.0));
1580
1581 let round = protocol.get_round(0).unwrap();
1583 assert_eq!(round.round_num, 0);
1584 assert_eq!(round.aggregated_gradient, Some(gradient));
1585 }
1586
1587 #[test]
1588 fn test_model_sync_protocol_convergence() {
1589 let mut protocol = ModelSyncProtocol::new(10, 2, 3, 0.01);
1590
1591 let model_cid = Cid::default();
1592
1593 for i in 0..3 {
1595 protocol.start_round(model_cid, 3).unwrap();
1596 let gradient = vec![1.0, 2.0];
1597 let loss = 1.0 - (i as f64 * 0.001);
1598 protocol.complete_round(i, gradient, loss).unwrap();
1599 }
1600
1601 assert!(protocol.has_converged());
1603 assert!(!protocol.should_continue());
1604 }
1605
1606 #[test]
1607 fn test_model_sync_protocol_max_rounds() {
1608 let mut protocol = ModelSyncProtocol::new(2, 1, 3, 0.01);
1609
1610 let model_cid = Cid::default();
1611
1612 protocol.start_round(model_cid, 2).unwrap();
1614 protocol.start_round(model_cid, 2).unwrap();
1615
1616 let result = protocol.start_round(model_cid, 2);
1618 assert!(result.is_err());
1619 }
1620
1621 #[test]
1622 fn test_model_sync_protocol_min_clients() {
1623 let mut protocol = ModelSyncProtocol::new(10, 5, 3, 0.01);
1624
1625 let model_cid = Cid::default();
1626
1627 let result = protocol.start_round(model_cid, 3);
1629 assert!(result.is_err());
1630
1631 let result = protocol.start_round(model_cid, 5);
1633 assert!(result.is_ok());
1634 }
1635
1636 #[test]
1637 fn test_client_state_enum() {
1638 let idle = ClientState::Idle;
1639 let training = ClientState::Training;
1640 let completed = ClientState::Completed;
1641 let failed = ClientState::Failed;
1642
1643 assert_ne!(idle, training);
1644 assert_ne!(training, completed);
1645 assert_ne!(completed, failed);
1646 assert_eq!(idle, ClientState::Idle);
1647 }
1648}