optirs_core/privacy/federated/
secure_aggregation.rs

1use std::fmt::Debug;
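// Secure Aggregation Module
//
// This module implements secure aggregation protocols for federated learning,
// aggregating client updates while aiming to preserve individual client privacy.
// The current protocol is a simplified simulation of secure multi-party
// computation (client masks are generated and applied by the aggregator itself)
// rather than a full cryptographic implementation.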
7
8use super::super::moment_accountant::MomentsAccountant;
9use super::super::{AccountingMethod, DifferentialPrivacyConfig, NoiseMechanism, PrivacyBudget};
10use crate::error::{OptimError, Result};
11use scirs2_core::ndarray::{Array1, Array2};
12use scirs2_core::numeric::Float;
13use scirs2_core::random::Rng;
14use scirs2_core::random::{Random, Rng as SCRRng};
15use std::collections::{HashMap, VecDeque};
16use std::sync::Arc;
17
18/// Secure aggregation configuration
19#[derive(Debug, Clone)]
20pub struct SecureAggregationConfig {
21    /// Enable secure aggregation
22    pub enabled: bool,
23
24    /// Minimum number of clients for aggregation
25    pub min_clients: usize,
26
27    /// Maximum number of dropouts tolerated
28    pub max_dropouts: usize,
29
30    /// Masking vector dimension
31    pub masking_dimension: usize,
32
33    /// Random seed sharing method
34    pub seed_sharing: SeedSharingMethod,
35
36    /// Quantization bits for compressed aggregation
37    pub quantization_bits: Option<u8>,
38
39    /// Enable differential privacy on aggregated result
40    pub aggregate_dp: bool,
41}
42
/// Methods available for sharing random seeds between aggregation participants.
#[derive(Clone, Copy, Debug)]
pub enum SeedSharingMethod {
    /// Shamir secret sharing.
    ShamirSecretSharing,

    /// Threshold encryption.
    ThresholdEncryption,

    /// Distributed key generation.
    DistributedKeyGeneration,
}
55
56/// Secure aggregation protocol implementation
57pub struct SecureAggregator<T: Float + Debug + Send + Sync + 'static> {
58    config: SecureAggregationConfig,
59    client_masks: HashMap<String, Array1<T>>,
60    shared_randomness: Arc<std::sync::Mutex<u64>>,
61    aggregation_threshold: usize,
62    round_keys: Vec<u64>,
63}
64
/// Description of one aggregation round, produced when the round is prepared.
#[derive(Clone, Debug)]
pub struct SecureAggregationPlan {
    /// Seed identifying this round.
    pub round_seed: u64,
    /// Ids of the clients selected for this round.
    pub participating_clients: Vec<String>,
    /// Minimum number of clients required for this round.
    pub min_threshold: usize,
    /// Whether client updates are masked during aggregation.
    pub masking_enabled: bool,
}
73
74impl<T: Float + Debug + Send + Sync + 'static + scirs2_core::ndarray::ScalarOperand>
75    SecureAggregator<T>
76{
77    pub fn new(config: SecureAggregationConfig) -> Result<Self> {
78        let min_clients = config.min_clients;
79        Ok(Self {
80            config,
81            client_masks: HashMap::new(),
82            shared_randomness: Arc::new(std::sync::Mutex::new(0u64)),
83            aggregation_threshold: min_clients,
84            round_keys: Vec::new(),
85        })
86    }
87
88    pub fn prepare_round(&mut self, selectedclients: &[String]) -> Result<SecureAggregationPlan> {
89        // Generate round-specific keys
90        let mut seed = self.shared_randomness.lock().unwrap();
91        *seed = seed.wrapping_add(1);
92        let round_seed = *seed;
93        self.round_keys.push(round_seed);
94
95        // Generate client masks (simplified)
96        self.client_masks.clear();
97        for clientid in selectedclients.iter() {
98            let mut client_rng = Random::default();
99            let mask_size = self.config.masking_dimension;
100
101            let mask = Array1::from_iter(
102                (0..mask_size).map(|_| T::from(client_rng.gen_range(-1.0..1.0)).unwrap()),
103            );
104
105            self.client_masks.insert(clientid.clone(), mask);
106        }
107
108        Ok(SecureAggregationPlan {
109            round_seed,
110            participating_clients: selectedclients.to_vec(),
111            min_threshold: self.config.min_clients,
112            masking_enabled: true,
113        })
114    }
115
116    pub fn aggregate_with_masks(
117        &self,
118        clientupdates: &HashMap<String, Array1<T>>,
119        _aggregation_plan: &SecureAggregationPlan,
120    ) -> Result<Array1<T>> {
121        if clientupdates.len() < self.aggregation_threshold {
122            return Err(OptimError::InvalidConfig(
123                "Insufficient clients for secure aggregation".to_string(),
124            ));
125        }
126
127        // Simplified secure aggregation (in practice, would use more sophisticated protocols)
128        let first_update = clientupdates.values().next().unwrap();
129        let mut aggregated = Array1::zeros(first_update.len());
130
131        for (clientid, update) in clientupdates {
132            if let Some(mask) = self.client_masks.get(clientid) {
133                // Apply mask (simplified - real implementation would be more complex)
134                let masked_update = if update.len() == mask.len() {
135                    update + mask
136                } else {
137                    update.clone() // Fallback if dimensions don't match
138                };
139                aggregated = aggregated + masked_update;
140            } else {
141                aggregated = aggregated + update;
142            }
143        }
144
145        // Remove aggregated masks (simplified)
146        let num_clients = T::from(clientupdates.len()).unwrap();
147        aggregated = aggregated / num_clients;
148
149        Ok(aggregated)
150    }
151
152    /// Get current configuration
153    pub fn config(&self) -> &SecureAggregationConfig {
154        &self.config
155    }
156
157    /// Get aggregation threshold
158    pub fn aggregation_threshold(&self) -> usize {
159        self.aggregation_threshold
160    }
161
162    /// Check if secure aggregation is enabled
163    pub fn is_enabled(&self) -> bool {
164        self.config.enabled
165    }
166}
167
168impl Default for SecureAggregationConfig {
169    fn default() -> Self {
170        Self {
171            enabled: true,
172            min_clients: 10,
173            max_dropouts: 5,
174            masking_dimension: 1000,
175            seed_sharing: SeedSharingMethod::ShamirSecretSharing,
176            quantization_bits: None,
177            aggregate_dp: true,
178        }
179    }
180}
181
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::Array1;
    use std::collections::HashMap;

    #[test]
    fn test_secure_aggregation_config() {
        let config = SecureAggregationConfig {
            enabled: true,
            min_clients: 5,
            max_dropouts: 2,
            masking_dimension: 100,
            seed_sharing: SeedSharingMethod::ShamirSecretSharing,
            quantization_bits: Some(8),
            aggregate_dp: true,
        };

        assert!(config.enabled);
        assert_eq!(config.min_clients, 5);
        assert_eq!(config.max_dropouts, 2);
    }

    #[test]
    fn test_secure_aggregator_creation() {
        let config = SecureAggregationConfig::default();
        let aggregator = SecureAggregator::<f64>::new(config.clone());

        assert!(aggregator.is_ok());
        let agg = aggregator.unwrap();
        assert_eq!(agg.aggregation_threshold(), config.min_clients);
        assert!(agg.is_enabled());
    }

    #[test]
    fn test_secure_aggregation_plan() {
        let config = SecureAggregationConfig::default();
        let mut aggregator = SecureAggregator::<f64>::new(config).unwrap();

        let clients = vec!["client1".to_string(), "client2".to_string()];
        let plan = aggregator.prepare_round(&clients);

        assert!(plan.is_ok());
        let plan = plan.unwrap();
        assert_eq!(plan.participating_clients.len(), 2);
        assert!(plan.masking_enabled);
    }

    /// Each prepared round advances the shared counter by exactly one, starting at 1.
    #[test]
    fn test_round_seed_advances_each_round() {
        let mut aggregator =
            SecureAggregator::<f64>::new(SecureAggregationConfig::default()).unwrap();
        let clients = vec!["client1".to_string()];

        let first = aggregator.prepare_round(&clients).unwrap().round_seed;
        let second = aggregator.prepare_round(&clients).unwrap().round_seed;

        assert_eq!(first, 1);
        assert_eq!(second, first.wrapping_add(1));
    }

    /// Aggregation must refuse to run with fewer updates than the configured minimum.
    #[test]
    fn test_aggregation_rejects_insufficient_clients() {
        let mut aggregator =
            SecureAggregator::<f64>::new(SecureAggregationConfig::default()).unwrap();
        let clients = vec!["client1".to_string()];
        let plan = aggregator.prepare_round(&clients).unwrap();

        let mut updates = HashMap::new();
        updates.insert("client1".to_string(), Array1::from_vec(vec![1.0, 2.0]));

        assert!(aggregator.aggregate_with_masks(&updates, &plan).is_err());
    }
}