optirs_core/privacy/federated/
secure_aggregation.rs1use std::fmt::Debug;
2use super::super::moment_accountant::MomentsAccountant;
9use super::super::{AccountingMethod, DifferentialPrivacyConfig, NoiseMechanism, PrivacyBudget};
10use crate::error::{OptimError, Result};
11use scirs2_core::ndarray::{Array1, Array2};
12use scirs2_core::numeric::Float;
13use scirs2_core::random::Rng;
14use scirs2_core::random::{Random, Rng as SCRRng};
15use std::collections::{HashMap, VecDeque};
16use std::sync::Arc;
17
18#[derive(Debug, Clone)]
20pub struct SecureAggregationConfig {
21 pub enabled: bool,
23
24 pub min_clients: usize,
26
27 pub max_dropouts: usize,
29
30 pub masking_dimension: usize,
32
33 pub seed_sharing: SeedSharingMethod,
35
36 pub quantization_bits: Option<u8>,
38
39 pub aggregate_dp: bool,
41}
42
/// Strategy for distributing masking seeds among the participating clients.
///
/// NOTE(review): the variants act only as labels in this file; nothing here
/// dispatches on them.
///
/// Equality and hashing are derived so the method can be compared directly
/// (e.g. in tests or config validation) and used as a map/set key.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum SeedSharingMethod {
    /// Seeds distributed via Shamir secret sharing.
    ShamirSecretSharing,

    /// Seeds protected with threshold encryption.
    ThresholdEncryption,

    /// Seeds derived through distributed key generation.
    DistributedKeyGeneration,
}
55
56pub struct SecureAggregator<T: Float + Debug + Send + Sync + 'static> {
58 config: SecureAggregationConfig,
59 client_masks: HashMap<String, Array1<T>>,
60 shared_randomness: Arc<std::sync::Mutex<u64>>,
61 aggregation_threshold: usize,
62 round_keys: Vec<u64>,
63}
64
/// Description of one aggregation round, produced by `SecureAggregator::prepare_round`.
#[derive(Debug, Clone)]
pub struct SecureAggregationPlan {
    /// Value of the aggregator's round counter for this round.
    pub round_seed: u64,
    /// Client ids selected for the round, in the order they were supplied.
    pub participating_clients: Vec<String>,
    /// Minimum client count, copied from the configuration.
    pub min_threshold: usize,
    /// Masking flag reported for this round.
    pub masking_enabled: bool,
}
73
74impl<T: Float + Debug + Send + Sync + 'static + scirs2_core::ndarray::ScalarOperand>
75 SecureAggregator<T>
76{
77 pub fn new(config: SecureAggregationConfig) -> Result<Self> {
78 let min_clients = config.min_clients;
79 Ok(Self {
80 config,
81 client_masks: HashMap::new(),
82 shared_randomness: Arc::new(std::sync::Mutex::new(0u64)),
83 aggregation_threshold: min_clients,
84 round_keys: Vec::new(),
85 })
86 }
87
88 pub fn prepare_round(&mut self, selectedclients: &[String]) -> Result<SecureAggregationPlan> {
89 let mut seed = self.shared_randomness.lock().expect("lock poisoned");
91 *seed = seed.wrapping_add(1);
92 let round_seed = *seed;
93 self.round_keys.push(round_seed);
94
95 self.client_masks.clear();
97 for clientid in selectedclients.iter() {
98 let mut client_rng = Random::default();
99 let mask_size = self.config.masking_dimension;
100
101 let mask = Array1::from_iter(
102 (0..mask_size)
103 .map(|_| T::from(client_rng.gen_range(-1.0..1.0)).expect("unwrap failed")),
104 );
105
106 self.client_masks.insert(clientid.clone(), mask);
107 }
108
109 Ok(SecureAggregationPlan {
110 round_seed,
111 participating_clients: selectedclients.to_vec(),
112 min_threshold: self.config.min_clients,
113 masking_enabled: true,
114 })
115 }
116
117 pub fn aggregate_with_masks(
118 &self,
119 clientupdates: &HashMap<String, Array1<T>>,
120 _aggregation_plan: &SecureAggregationPlan,
121 ) -> Result<Array1<T>> {
122 if clientupdates.len() < self.aggregation_threshold {
123 return Err(OptimError::InvalidConfig(
124 "Insufficient clients for secure aggregation".to_string(),
125 ));
126 }
127
128 let first_update = clientupdates.values().next().expect("unwrap failed");
130 let mut aggregated = Array1::zeros(first_update.len());
131
132 for (clientid, update) in clientupdates {
133 if let Some(mask) = self.client_masks.get(clientid) {
134 let masked_update = if update.len() == mask.len() {
136 update + mask
137 } else {
138 update.clone() };
140 aggregated = aggregated + masked_update;
141 } else {
142 aggregated = aggregated + update;
143 }
144 }
145
146 let num_clients = T::from(clientupdates.len()).expect("unwrap failed");
148 aggregated = aggregated / num_clients;
149
150 Ok(aggregated)
151 }
152
153 pub fn config(&self) -> &SecureAggregationConfig {
155 &self.config
156 }
157
158 pub fn aggregation_threshold(&self) -> usize {
160 self.aggregation_threshold
161 }
162
163 pub fn is_enabled(&self) -> bool {
165 self.config.enabled
166 }
167}
168
169impl Default for SecureAggregationConfig {
170 fn default() -> Self {
171 Self {
172 enabled: true,
173 min_clients: 10,
174 max_dropouts: 5,
175 masking_dimension: 1000,
176 seed_sharing: SeedSharingMethod::ShamirSecretSharing,
177 quantization_bits: None,
178 aggregate_dp: true,
179 }
180 }
181}
182
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::Array1;

    #[test]
    fn test_secure_aggregation_config() {
        let config = SecureAggregationConfig {
            enabled: true,
            min_clients: 5,
            max_dropouts: 2,
            masking_dimension: 100,
            seed_sharing: SeedSharingMethod::ShamirSecretSharing,
            quantization_bits: Some(8),
            aggregate_dp: true,
        };

        assert!(config.enabled);
        assert_eq!(config.min_clients, 5);
        assert_eq!(config.max_dropouts, 2);
    }

    #[test]
    fn test_secure_aggregator_creation() {
        let config = SecureAggregationConfig::default();
        let aggregator = SecureAggregator::<f64>::new(config.clone());

        assert!(aggregator.is_ok());
        let agg = aggregator.expect("unwrap failed");
        assert_eq!(agg.aggregation_threshold(), config.min_clients);
        assert!(agg.is_enabled());
    }

    #[test]
    fn test_secure_aggregation_plan() {
        let config = SecureAggregationConfig::default();
        let mut aggregator = SecureAggregator::<f64>::new(config).expect("unwrap failed");

        let clients = vec!["client1".to_string(), "client2".to_string()];
        let plan = aggregator.prepare_round(&clients);

        assert!(plan.is_ok());
        let plan = plan.expect("unwrap failed");
        assert_eq!(plan.participating_clients.len(), 2);
        assert!(plan.masking_enabled);
    }

    /// Aggregation must refuse to run with fewer updates than `min_clients`.
    #[test]
    fn test_aggregate_rejects_too_few_clients() {
        let config = SecureAggregationConfig {
            min_clients: 3,
            ..SecureAggregationConfig::default()
        };
        let mut aggregator =
            SecureAggregator::<f64>::new(config).expect("aggregator construction should succeed");
        let clients = vec!["client1".to_string()];
        let plan = aggregator
            .prepare_round(&clients)
            .expect("round preparation should succeed");

        let mut updates = HashMap::new();
        updates.insert("client1".to_string(), Array1::from_vec(vec![1.0, 2.0]));
        assert!(aggregator.aggregate_with_masks(&updates, &plan).is_err());
    }

    /// Clients with no mask for the round are averaged as-is.
    #[test]
    fn test_aggregate_unmasked_clients_is_plain_mean() {
        let config = SecureAggregationConfig {
            min_clients: 2,
            ..SecureAggregationConfig::default()
        };
        let aggregator =
            SecureAggregator::<f64>::new(config).expect("aggregator construction should succeed");
        let plan = SecureAggregationPlan {
            round_seed: 0,
            participating_clients: Vec::new(),
            min_threshold: 2,
            masking_enabled: false,
        };

        let mut updates = HashMap::new();
        updates.insert("a".to_string(), Array1::from_vec(vec![1.0, 2.0, 3.0]));
        updates.insert("b".to_string(), Array1::from_vec(vec![3.0, 4.0, 5.0]));
        let mean = aggregator
            .aggregate_with_masks(&updates, &plan)
            .expect("aggregation should succeed");
        assert_eq!(mean, Array1::from_vec(vec![2.0, 3.0, 4.0]));
    }
}