optirs_core/privacy/federated/
secure_aggregation.rs1use std::fmt::Debug;
2use super::super::moment_accountant::MomentsAccountant;
9use super::super::{AccountingMethod, DifferentialPrivacyConfig, NoiseMechanism, PrivacyBudget};
10use crate::error::{OptimError, Result};
11use scirs2_core::ndarray::{Array1, Array2};
12use scirs2_core::numeric::Float;
13use scirs2_core::random::Rng;
14use scirs2_core::random::{Random, Rng as SCRRng};
15use std::collections::{HashMap, VecDeque};
16use std::sync::Arc;
17
18#[derive(Debug, Clone)]
20pub struct SecureAggregationConfig {
21 pub enabled: bool,
23
24 pub min_clients: usize,
26
27 pub max_dropouts: usize,
29
30 pub masking_dimension: usize,
32
33 pub seed_sharing: SeedSharingMethod,
35
36 pub quantization_bits: Option<u8>,
38
39 pub aggregate_dp: bool,
41}
42
/// Scheme used to distribute the seeds from which aggregation masks are derived.
///
/// NOTE(review): `SecureAggregator` in this module stores this choice in its config
/// but does not branch on it; the variants describe intent only.
#[derive(Clone, Copy, Debug)]
pub enum SeedSharingMethod {
    /// Shamir secret sharing of mask seeds.
    ShamirSecretSharing,

    /// Threshold encryption of mask seeds.
    ThresholdEncryption,

    /// Distributed key generation among the participants.
    DistributedKeyGeneration,
}

56pub struct SecureAggregator<T: Float + Debug + Send + Sync + 'static> {
58 config: SecureAggregationConfig,
59 client_masks: HashMap<String, Array1<T>>,
60 shared_randomness: Arc<std::sync::Mutex<u64>>,
61 aggregation_threshold: usize,
62 round_keys: Vec<u64>,
63}
64
/// Description of one secure-aggregation round, as returned by
/// `SecureAggregator::prepare_round`.
#[derive(Clone, Debug)]
pub struct SecureAggregationPlan {
    /// Value of the aggregator's round counter assigned to this round.
    pub round_seed: u64,
    /// Ids of the clients selected for this round, in the order they were supplied.
    pub participating_clients: Vec<String>,
    /// Minimum number of client updates required to aggregate this round.
    pub min_threshold: usize,
    /// Whether per-client masking is in effect for this round.
    pub masking_enabled: bool,
}

74impl<T: Float + Debug + Send + Sync + 'static + scirs2_core::ndarray::ScalarOperand>
75 SecureAggregator<T>
76{
77 pub fn new(config: SecureAggregationConfig) -> Result<Self> {
78 let min_clients = config.min_clients;
79 Ok(Self {
80 config,
81 client_masks: HashMap::new(),
82 shared_randomness: Arc::new(std::sync::Mutex::new(0u64)),
83 aggregation_threshold: min_clients,
84 round_keys: Vec::new(),
85 })
86 }
87
88 pub fn prepare_round(&mut self, selectedclients: &[String]) -> Result<SecureAggregationPlan> {
89 let mut seed = self.shared_randomness.lock().unwrap();
91 *seed = seed.wrapping_add(1);
92 let round_seed = *seed;
93 self.round_keys.push(round_seed);
94
95 self.client_masks.clear();
97 for clientid in selectedclients.iter() {
98 let mut client_rng = Random::default();
99 let mask_size = self.config.masking_dimension;
100
101 let mask = Array1::from_iter(
102 (0..mask_size).map(|_| T::from(client_rng.gen_range(-1.0..1.0)).unwrap()),
103 );
104
105 self.client_masks.insert(clientid.clone(), mask);
106 }
107
108 Ok(SecureAggregationPlan {
109 round_seed,
110 participating_clients: selectedclients.to_vec(),
111 min_threshold: self.config.min_clients,
112 masking_enabled: true,
113 })
114 }
115
116 pub fn aggregate_with_masks(
117 &self,
118 clientupdates: &HashMap<String, Array1<T>>,
119 _aggregation_plan: &SecureAggregationPlan,
120 ) -> Result<Array1<T>> {
121 if clientupdates.len() < self.aggregation_threshold {
122 return Err(OptimError::InvalidConfig(
123 "Insufficient clients for secure aggregation".to_string(),
124 ));
125 }
126
127 let first_update = clientupdates.values().next().unwrap();
129 let mut aggregated = Array1::zeros(first_update.len());
130
131 for (clientid, update) in clientupdates {
132 if let Some(mask) = self.client_masks.get(clientid) {
133 let masked_update = if update.len() == mask.len() {
135 update + mask
136 } else {
137 update.clone() };
139 aggregated = aggregated + masked_update;
140 } else {
141 aggregated = aggregated + update;
142 }
143 }
144
145 let num_clients = T::from(clientupdates.len()).unwrap();
147 aggregated = aggregated / num_clients;
148
149 Ok(aggregated)
150 }
151
152 pub fn config(&self) -> &SecureAggregationConfig {
154 &self.config
155 }
156
157 pub fn aggregation_threshold(&self) -> usize {
159 self.aggregation_threshold
160 }
161
162 pub fn is_enabled(&self) -> bool {
164 self.config.enabled
165 }
166}
167
168impl Default for SecureAggregationConfig {
169 fn default() -> Self {
170 Self {
171 enabled: true,
172 min_clients: 10,
173 max_dropouts: 5,
174 masking_dimension: 1000,
175 seed_sharing: SeedSharingMethod::ShamirSecretSharing,
176 quantization_bits: None,
177 aggregate_dp: true,
178 }
179 }
180}
181
#[cfg(test)]
mod tests {
    use super::*;

    /// A hand-built, non-default configuration keeps the values it was given.
    #[test]
    fn test_secure_aggregation_config() {
        let cfg = SecureAggregationConfig {
            enabled: true,
            min_clients: 5,
            max_dropouts: 2,
            masking_dimension: 100,
            seed_sharing: SeedSharingMethod::ShamirSecretSharing,
            quantization_bits: Some(8),
            aggregate_dp: true,
        };

        assert!(cfg.enabled);
        assert_eq!(cfg.min_clients, 5);
        assert_eq!(cfg.max_dropouts, 2);
    }

    /// Construction from the default config succeeds and adopts `min_clients` as threshold.
    #[test]
    fn test_secure_aggregator_creation() {
        let cfg = SecureAggregationConfig::default();
        let expected_threshold = cfg.min_clients;

        let agg = SecureAggregator::<f64>::new(cfg)
            .expect("constructing from the default config should succeed");

        assert_eq!(agg.aggregation_threshold(), expected_threshold);
        assert!(agg.is_enabled());
    }

    /// Preparing a round lists every selected client and enables masking.
    #[test]
    fn test_secure_aggregation_plan() {
        let mut agg = SecureAggregator::<f64>::new(SecureAggregationConfig::default())
            .expect("constructing from the default config should succeed");

        let clients: Vec<String> = ["client1", "client2"]
            .iter()
            .map(|id| id.to_string())
            .collect();

        let plan = agg
            .prepare_round(&clients)
            .expect("preparing a round should succeed");

        assert_eq!(plan.participating_clients.len(), 2);
        assert!(plan.masking_enabled);
    }
}