Skip to main content

optirs_core/privacy/federated/
secure_aggregation.rs

1use std::fmt::Debug;
// Secure Aggregation Module
//
// This module implements secure aggregation protocols for federated learning,
// including secure multi-party computation techniques to aggregate client updates
// while preserving individual client privacy. The current implementation is a
// simplified model: the aggregator generates and holds the client masks itself.
7
8use super::super::moment_accountant::MomentsAccountant;
9use super::super::{AccountingMethod, DifferentialPrivacyConfig, NoiseMechanism, PrivacyBudget};
10use crate::error::{OptimError, Result};
11use scirs2_core::ndarray::{Array1, Array2};
12use scirs2_core::numeric::Float;
13use scirs2_core::random::Rng;
14use scirs2_core::random::{Random, Rng as SCRRng};
15use std::collections::{HashMap, VecDeque};
16use std::sync::Arc;
17
18/// Secure aggregation configuration
19#[derive(Debug, Clone)]
20pub struct SecureAggregationConfig {
21    /// Enable secure aggregation
22    pub enabled: bool,
23
24    /// Minimum number of clients for aggregation
25    pub min_clients: usize,
26
27    /// Maximum number of dropouts tolerated
28    pub max_dropouts: usize,
29
30    /// Masking vector dimension
31    pub masking_dimension: usize,
32
33    /// Random seed sharing method
34    pub seed_sharing: SeedSharingMethod,
35
36    /// Quantization bits for compressed aggregation
37    pub quantization_bits: Option<u8>,
38
39    /// Enable differential privacy on aggregated result
40    pub aggregate_dp: bool,
41}
42
/// Strategy used to distribute random seeds among participants in secure aggregation.
///
/// Derives equality and hashing so configurations can be compared and the method
/// can be used as a map/set key.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum SeedSharingMethod {
    /// Shamir secret sharing
    ShamirSecretSharing,

    /// Threshold encryption
    ThresholdEncryption,

    /// Distributed key generation
    DistributedKeyGeneration,
}
55
56/// Secure aggregation protocol implementation
57pub struct SecureAggregator<T: Float + Debug + Send + Sync + 'static> {
58    config: SecureAggregationConfig,
59    client_masks: HashMap<String, Array1<T>>,
60    shared_randomness: Arc<std::sync::Mutex<u64>>,
61    aggregation_threshold: usize,
62    round_keys: Vec<u64>,
63}
64
/// Description of one prepared secure-aggregation round.
///
/// Derives `PartialEq`/`Eq` so plans can be compared (e.g. in tests or when
/// checking that a client is reporting against the expected round).
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct SecureAggregationPlan {
    /// Value of the aggregator's round counter for this round.
    pub round_seed: u64,
    /// Ids of the clients selected for this round.
    pub participating_clients: Vec<String>,
    /// Minimum number of client updates required to aggregate.
    pub min_threshold: usize,
    /// Whether client updates are masked in this round.
    pub masking_enabled: bool,
}
73
74impl<T: Float + Debug + Send + Sync + 'static + scirs2_core::ndarray::ScalarOperand>
75    SecureAggregator<T>
76{
77    pub fn new(config: SecureAggregationConfig) -> Result<Self> {
78        let min_clients = config.min_clients;
79        Ok(Self {
80            config,
81            client_masks: HashMap::new(),
82            shared_randomness: Arc::new(std::sync::Mutex::new(0u64)),
83            aggregation_threshold: min_clients,
84            round_keys: Vec::new(),
85        })
86    }
87
88    pub fn prepare_round(&mut self, selectedclients: &[String]) -> Result<SecureAggregationPlan> {
89        // Generate round-specific keys
90        let mut seed = self.shared_randomness.lock().expect("lock poisoned");
91        *seed = seed.wrapping_add(1);
92        let round_seed = *seed;
93        self.round_keys.push(round_seed);
94
95        // Generate client masks (simplified)
96        self.client_masks.clear();
97        for clientid in selectedclients.iter() {
98            let mut client_rng = Random::default();
99            let mask_size = self.config.masking_dimension;
100
101            let mask = Array1::from_iter(
102                (0..mask_size)
103                    .map(|_| T::from(client_rng.gen_range(-1.0..1.0)).expect("unwrap failed")),
104            );
105
106            self.client_masks.insert(clientid.clone(), mask);
107        }
108
109        Ok(SecureAggregationPlan {
110            round_seed,
111            participating_clients: selectedclients.to_vec(),
112            min_threshold: self.config.min_clients,
113            masking_enabled: true,
114        })
115    }
116
117    pub fn aggregate_with_masks(
118        &self,
119        clientupdates: &HashMap<String, Array1<T>>,
120        _aggregation_plan: &SecureAggregationPlan,
121    ) -> Result<Array1<T>> {
122        if clientupdates.len() < self.aggregation_threshold {
123            return Err(OptimError::InvalidConfig(
124                "Insufficient clients for secure aggregation".to_string(),
125            ));
126        }
127
128        // Simplified secure aggregation (in practice, would use more sophisticated protocols)
129        let first_update = clientupdates.values().next().expect("unwrap failed");
130        let mut aggregated = Array1::zeros(first_update.len());
131
132        for (clientid, update) in clientupdates {
133            if let Some(mask) = self.client_masks.get(clientid) {
134                // Apply mask (simplified - real implementation would be more complex)
135                let masked_update = if update.len() == mask.len() {
136                    update + mask
137                } else {
138                    update.clone() // Fallback if dimensions don't match
139                };
140                aggregated = aggregated + masked_update;
141            } else {
142                aggregated = aggregated + update;
143            }
144        }
145
146        // Remove aggregated masks (simplified)
147        let num_clients = T::from(clientupdates.len()).expect("unwrap failed");
148        aggregated = aggregated / num_clients;
149
150        Ok(aggregated)
151    }
152
153    /// Get current configuration
154    pub fn config(&self) -> &SecureAggregationConfig {
155        &self.config
156    }
157
158    /// Get aggregation threshold
159    pub fn aggregation_threshold(&self) -> usize {
160        self.aggregation_threshold
161    }
162
163    /// Check if secure aggregation is enabled
164    pub fn is_enabled(&self) -> bool {
165        self.config.enabled
166    }
167}
168
169impl Default for SecureAggregationConfig {
170    fn default() -> Self {
171        Self {
172            enabled: true,
173            min_clients: 10,
174            max_dropouts: 5,
175            masking_dimension: 1000,
176            seed_sharing: SeedSharingMethod::ShamirSecretSharing,
177            quantization_bits: None,
178            aggregate_dp: true,
179        }
180    }
181}
182
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::Array1;

    /// Small configuration for fast tests: 3-element masks and a caller-chosen
    /// client threshold; all other settings come from `Default`.
    fn small_config(min_clients: usize) -> SecureAggregationConfig {
        SecureAggregationConfig {
            min_clients,
            masking_dimension: 3,
            ..SecureAggregationConfig::default()
        }
    }

    /// Client ids `client0`, `client1`, ... for `count` clients.
    fn client_ids(count: usize) -> Vec<String> {
        (0..count).map(|i| format!("client{}", i)).collect()
    }

    /// Gives every client in `clients` the same 3-element update.
    fn uniform_updates(clients: &[String]) -> HashMap<String, Array1<f64>> {
        clients
            .iter()
            .map(|id| (id.clone(), Array1::from_vec(vec![1.0, 2.0, 3.0])))
            .collect()
    }

    #[test]
    fn test_secure_aggregation_config() {
        let config = SecureAggregationConfig {
            enabled: true,
            min_clients: 5,
            max_dropouts: 2,
            masking_dimension: 100,
            seed_sharing: SeedSharingMethod::ShamirSecretSharing,
            quantization_bits: Some(8),
            aggregate_dp: true,
        };

        assert!(config.enabled);
        assert_eq!(config.min_clients, 5);
        assert_eq!(config.max_dropouts, 2);
    }

    #[test]
    fn test_secure_aggregator_creation() {
        let config = SecureAggregationConfig::default();
        let aggregator = SecureAggregator::<f64>::new(config.clone());

        assert!(aggregator.is_ok());
        let agg = aggregator.expect("unwrap failed");
        assert_eq!(agg.aggregation_threshold(), config.min_clients);
        assert!(agg.is_enabled());
    }

    #[test]
    fn test_secure_aggregation_plan() {
        let config = SecureAggregationConfig::default();
        let mut aggregator = SecureAggregator::<f64>::new(config).expect("unwrap failed");

        let clients = vec!["client1".to_string(), "client2".to_string()];
        let plan = aggregator.prepare_round(&clients);

        assert!(plan.is_ok());
        let plan = plan.expect("unwrap failed");
        assert_eq!(plan.participating_clients.len(), 2);
        assert!(plan.masking_enabled);
    }

    #[test]
    fn test_round_seed_advances_each_round() {
        let mut aggregator =
            SecureAggregator::<f64>::new(small_config(1)).expect("valid config");
        let clients = client_ids(1);

        let first = aggregator.prepare_round(&clients).expect("first round").round_seed;
        let second = aggregator.prepare_round(&clients).expect("second round").round_seed;

        assert_eq!(second, first.wrapping_add(1));
    }

    #[test]
    fn test_prepare_round_creates_one_mask_per_client() {
        let mut aggregator =
            SecureAggregator::<f64>::new(small_config(1)).expect("valid config");
        let clients = client_ids(4);

        aggregator.prepare_round(&clients).expect("plan");

        assert_eq!(aggregator.client_masks.len(), clients.len());
        for id in &clients {
            let mask = aggregator.client_masks.get(id).expect("mask for every client");
            assert_eq!(mask.len(), 3);
        }
    }

    #[test]
    fn test_aggregation_rejects_too_few_clients() {
        let mut aggregator =
            SecureAggregator::<f64>::new(small_config(3)).expect("valid config");
        let clients = client_ids(2);
        let plan = aggregator.prepare_round(&clients).expect("plan");

        let result = aggregator.aggregate_with_masks(&uniform_updates(&clients), &plan);

        assert!(result.is_err());
    }

    #[test]
    fn test_aggregation_preserves_update_dimension() {
        let mut aggregator =
            SecureAggregator::<f64>::new(small_config(2)).expect("valid config");
        let clients = client_ids(3);
        let plan = aggregator.prepare_round(&clients).expect("plan");

        let aggregated = aggregator
            .aggregate_with_masks(&uniform_updates(&clients), &plan)
            .expect("enough clients");

        assert_eq!(aggregated.len(), 3);
        assert!(aggregated.iter().all(|v| v.is_finite()));
    }
}