oxirs_embed/federated_learning/
privacy.rs

1//! Privacy-preserving mechanisms for federated learning
2//!
3//! This module implements differential privacy, gradient clipping, noise mechanisms,
4//! and privacy budget accounting for secure federated learning.
5
6use super::config::{NoiseMechanism, PrivacyConfig};
7use chrono::{DateTime, Utc};
8use scirs2_core::ndarray_ext::Array2;
9use serde::{Deserialize, Serialize};
10use std::collections::HashMap;
11use uuid::Uuid;
12
13/// Privacy engine for differential privacy
14#[derive(Debug, Clone, Serialize, Deserialize)]
15pub struct PrivacyEngine {
16    /// Privacy configuration
17    pub config: PrivacyConfig,
18    /// Privacy accountant
19    pub privacy_accountant: PrivacyAccountant,
20    /// Noise generator
21    pub noise_generator: NoiseGenerator,
22    /// Clipping mechanisms
23    pub clipping_mechanisms: ClippingMechanisms,
24}
25
26impl PrivacyEngine {
27    /// Create new privacy engine
28    pub fn new(config: PrivacyConfig) -> Self {
29        Self {
30            privacy_accountant: PrivacyAccountant::new(config.epsilon, config.delta),
31            noise_generator: NoiseGenerator::new(config.noise_mechanism.clone()),
32            clipping_mechanisms: ClippingMechanisms::new(config.clipping_threshold),
33            config,
34        }
35    }
36
37    /// Process gradients with privacy mechanisms
38    pub fn process_gradients(
39        &mut self,
40        gradients: &Array2<f32>,
41        participant_id: Uuid,
42    ) -> anyhow::Result<Array2<f32>> {
43        // Clip gradients first
44        let clipped_gradients = self.clipping_mechanisms.clip_gradients(gradients);
45
46        // Add noise for differential privacy
47        let noisy_gradients = if self.config.enable_differential_privacy {
48            self.noise_generator.add_noise(&clipped_gradients)
49        } else {
50            clipped_gradients
51        };
52
53        // Update privacy budget
54        let privacy_cost = self.calculate_privacy_cost(&noisy_gradients);
55        self.privacy_accountant
56            .consume_budget(participant_id, privacy_cost)?;
57
58        Ok(noisy_gradients)
59    }
60
61    /// Calculate privacy cost for an operation
62    fn calculate_privacy_cost(&self, _gradients: &Array2<f32>) -> f64 {
63        // Simplified privacy cost calculation
64        // In practice, this would be more sophisticated
65        self.config.local_epsilon / 100.0
66    }
67}
68
69/// Privacy budget accounting
70#[derive(Debug, Clone, Serialize, Deserialize)]
71pub struct PrivacyAccountant {
72    /// Total epsilon budget
73    pub total_epsilon: f64,
74    /// Used epsilon budget
75    pub used_epsilon: f64,
76    /// Delta parameter
77    pub delta: f64,
78    /// Privacy budget per participant
79    pub participant_budgets: HashMap<Uuid, f64>,
80    /// Budget tracking per round
81    pub round_budgets: Vec<f64>,
82    /// Budget history
83    pub budget_history: Vec<BudgetEntry>,
84}
85
86/// Budget entry for tracking
87#[derive(Debug, Clone, Serialize, Deserialize)]
88pub struct BudgetEntry {
89    /// Timestamp
90    pub timestamp: DateTime<Utc>,
91    /// Participant ID
92    pub participant_id: Option<Uuid>,
93    /// Privacy cost
94    pub privacy_cost: f64,
95    /// Operation type
96    pub operation: String,
97    /// Remaining budget
98    pub remaining_budget: f64,
99}
100
101impl PrivacyAccountant {
102    /// Create new privacy accountant
103    pub fn new(total_epsilon: f64, delta: f64) -> Self {
104        Self {
105            total_epsilon,
106            used_epsilon: 0.0,
107            delta,
108            participant_budgets: HashMap::new(),
109            round_budgets: Vec::new(),
110            budget_history: Vec::new(),
111        }
112    }
113
114    /// Consume privacy budget
115    pub fn consume_budget(&mut self, participant_id: Uuid, cost: f64) -> anyhow::Result<()> {
116        if self.used_epsilon + cost > self.total_epsilon {
117            return Err(anyhow::anyhow!("Privacy budget exceeded"));
118        }
119
120        self.used_epsilon += cost;
121        *self
122            .participant_budgets
123            .entry(participant_id)
124            .or_insert(0.0) += cost;
125
126        // Record budget entry
127        self.budget_history.push(BudgetEntry {
128            timestamp: Utc::now(),
129            participant_id: Some(participant_id),
130            privacy_cost: cost,
131            operation: "gradient_update".to_string(),
132            remaining_budget: self.total_epsilon - self.used_epsilon,
133        });
134
135        Ok(())
136    }
137
138    /// Get remaining budget
139    pub fn remaining_budget(&self) -> f64 {
140        self.total_epsilon - self.used_epsilon
141    }
142
143    /// Check if budget is available
144    pub fn is_budget_available(&self, required_budget: f64) -> bool {
145        self.remaining_budget() >= required_budget
146    }
147}
148
149/// Noise generation for differential privacy
150#[derive(Debug, Clone, Serialize, Deserialize)]
151pub struct NoiseGenerator {
152    /// Noise mechanism
153    pub mechanism: NoiseMechanism,
154    /// Noise scale
155    pub scale: f64,
156    /// Random seed
157    pub seed: Option<u64>,
158}
159
160impl NoiseGenerator {
161    /// Create new noise generator
162    pub fn new(mechanism: NoiseMechanism) -> Self {
163        Self {
164            mechanism,
165            scale: 1.0,
166            seed: None,
167        }
168    }
169
170    /// Set noise scale
171    pub fn with_scale(mut self, scale: f64) -> Self {
172        self.scale = scale;
173        self
174    }
175
176    /// Add noise to parameters for differential privacy
177    pub fn add_noise(&self, parameters: &Array2<f32>) -> Array2<f32> {
178        match self.mechanism {
179            NoiseMechanism::Gaussian => self.add_gaussian_noise(parameters),
180            NoiseMechanism::Laplace => self.add_laplace_noise(parameters),
181            NoiseMechanism::Exponential => self.add_exponential_noise(parameters),
182            NoiseMechanism::SparseVector => self.add_sparse_vector_noise(parameters),
183        }
184    }
185
186    /// Add Gaussian noise
187    fn add_gaussian_noise(&self, parameters: &Array2<f32>) -> Array2<f32> {
188        let noise = Array2::from_shape_fn(parameters.raw_dim(), |_| {
189            // Box-Muller transform for Gaussian noise
190            let u1: f32 = {
191                use scirs2_core::random::{Random, Rng};
192                let mut random = Random::default();
193                random.random::<f32>()
194            };
195            let u2: f32 = {
196                use scirs2_core::random::{Random, Rng};
197                let mut random = Random::default();
198                random.random::<f32>()
199            };
200            let z = (-2.0 * u1.ln()).sqrt() * (2.0 * std::f32::consts::PI * u2).cos();
201            z * self.scale as f32
202        });
203        parameters + &noise
204    }
205
206    /// Add Laplace noise
207    fn add_laplace_noise(&self, parameters: &Array2<f32>) -> Array2<f32> {
208        let noise = Array2::from_shape_fn(parameters.raw_dim(), |_| {
209            let u: f32 = {
210                use scirs2_core::random::{Random, Rng};
211                let mut random = Random::default();
212                random.random::<f32>() - 0.5
213            };
214            let sign = if u > 0.0 { 1.0 } else { -1.0 };
215            -sign * (1.0 - 2.0 * u.abs()).ln() * self.scale as f32
216        });
217        parameters + &noise
218    }
219
220    /// Add exponential mechanism noise (simplified)
221    fn add_exponential_noise(&self, parameters: &Array2<f32>) -> Array2<f32> {
222        // Simplified implementation - would need proper exponential mechanism
223        self.add_laplace_noise(parameters)
224    }
225
226    /// Add sparse vector technique noise
227    fn add_sparse_vector_noise(&self, parameters: &Array2<f32>) -> Array2<f32> {
228        // Simplified implementation - would need proper sparse vector technique
229        let mut result = self.add_gaussian_noise(parameters);
230
231        // Apply sparsity by zeroing out small values
232        result.mapv_inplace(|x| {
233            if x.abs() < self.scale as f32 * 0.1 {
234                0.0
235            } else {
236                x
237            }
238        });
239
240        result
241    }
242}
243
244/// Gradient clipping mechanisms
245#[derive(Debug, Clone, Serialize, Deserialize)]
246pub struct ClippingMechanisms {
247    /// Clipping threshold
248    pub threshold: f64,
249    /// Clipping method
250    pub method: ClippingMethod,
251    /// Adaptive clipping
252    pub adaptive_clipping: bool,
253    /// Adaptive threshold history
254    pub threshold_history: Vec<f64>,
255}
256
257/// Gradient clipping methods
258#[derive(Debug, Clone, Serialize, Deserialize)]
259pub enum ClippingMethod {
260    /// L2 norm clipping
261    L2Norm,
262    /// L1 norm clipping
263    L1Norm,
264    /// Element-wise clipping
265    ElementWise,
266    /// Adaptive clipping
267    Adaptive,
268}
269
270impl ClippingMechanisms {
271    /// Create new clipping mechanisms
272    pub fn new(threshold: f64) -> Self {
273        Self {
274            threshold,
275            method: ClippingMethod::L2Norm,
276            adaptive_clipping: false,
277            threshold_history: Vec::new(),
278        }
279    }
280
281    /// Set clipping method
282    pub fn with_method(mut self, method: ClippingMethod) -> Self {
283        self.method = method;
284        self
285    }
286
287    /// Enable adaptive clipping
288    pub fn with_adaptive_clipping(mut self, adaptive: bool) -> Self {
289        self.adaptive_clipping = adaptive;
290        self
291    }
292
293    /// Clip gradients based on configured method
294    pub fn clip_gradients(&mut self, gradients: &Array2<f32>) -> Array2<f32> {
295        let result = match self.method {
296            ClippingMethod::L2Norm => self.clip_l2_norm(gradients),
297            ClippingMethod::L1Norm => self.clip_l1_norm(gradients),
298            ClippingMethod::ElementWise => self.clip_element_wise(gradients),
299            ClippingMethod::Adaptive => self.clip_adaptive(gradients),
300        };
301
302        // Update threshold history for adaptive clipping
303        if self.adaptive_clipping {
304            let current_norm = self.calculate_norm(gradients);
305            self.threshold_history.push(current_norm);
306
307            // Adapt threshold based on history (simplified)
308            if self.threshold_history.len() > 10 {
309                let avg_norm: f64 = self.threshold_history.iter().sum::<f64>()
310                    / self.threshold_history.len() as f64;
311                self.threshold = avg_norm * 1.2; // Allow 20% above average
312                self.threshold_history.remove(0); // Keep only recent history
313            }
314        }
315
316        result
317    }
318
319    /// Calculate gradient norm
320    fn calculate_norm(&self, gradients: &Array2<f32>) -> f64 {
321        match self.method {
322            ClippingMethod::L2Norm | ClippingMethod::Adaptive => gradients
323                .iter()
324                .map(|x| (*x as f64) * (*x as f64))
325                .sum::<f64>()
326                .sqrt(),
327            ClippingMethod::L1Norm => gradients.iter().map(|x| (*x as f64).abs()).sum::<f64>(),
328            ClippingMethod::ElementWise => gradients
329                .iter()
330                .map(|x| (*x as f64).abs())
331                .fold(0.0, f64::max),
332        }
333    }
334
335    /// L2 norm clipping
336    fn clip_l2_norm(&self, gradients: &Array2<f32>) -> Array2<f32> {
337        let norm = gradients.iter().map(|x| x * x).sum::<f32>().sqrt();
338        if norm > self.threshold as f32 {
339            gradients * (self.threshold as f32 / norm)
340        } else {
341            gradients.clone()
342        }
343    }
344
345    /// L1 norm clipping
346    fn clip_l1_norm(&self, gradients: &Array2<f32>) -> Array2<f32> {
347        let norm = gradients.iter().map(|x| x.abs()).sum::<f32>();
348        if norm > self.threshold as f32 {
349            gradients * (self.threshold as f32 / norm)
350        } else {
351            gradients.clone()
352        }
353    }
354
355    /// Element-wise clipping
356    fn clip_element_wise(&self, gradients: &Array2<f32>) -> Array2<f32> {
357        gradients.mapv(|x| x.max(-self.threshold as f32).min(self.threshold as f32))
358    }
359
360    /// Adaptive clipping
361    fn clip_adaptive(&self, gradients: &Array2<f32>) -> Array2<f32> {
362        // Use L2 norm clipping with adaptive threshold
363        self.clip_l2_norm(gradients)
364    }
365}
366
367/// Privacy parameters for different mechanisms
368#[derive(Debug, Clone, Serialize, Deserialize)]
369pub struct PrivacyParams {
370    /// Epsilon value for differential privacy
371    pub epsilon: f64,
372    /// Delta value for differential privacy
373    pub delta: f64,
374    /// Sensitivity of the query/function
375    pub sensitivity: f64,
376    /// Composition method for privacy accounting
377    pub composition_method: CompositionMethod,
378}
379
380/// Methods for privacy composition
381#[derive(Debug, Clone, Serialize, Deserialize)]
382pub enum CompositionMethod {
383    /// Basic composition
384    Basic,
385    /// Advanced composition
386    Advanced,
387    /// Renyi differential privacy
388    RenyiDP { alpha: f64 },
389    /// Privacy loss distribution
390    PLD,
391    /// Gaussian differential privacy
392    GDP,
393}
394
395/// Privacy accountant with advanced composition
396#[derive(Debug, Clone, Serialize, Deserialize)]
397pub struct AdvancedPrivacyAccountant {
398    /// Privacy parameters
399    pub privacy_params: PrivacyParams,
400    /// Composition tracking
401    pub compositions: Vec<CompositionEntry>,
402    /// Current privacy guarantees
403    pub current_guarantees: PrivacyGuarantees,
404}
405
406/// Privacy composition entry
407#[derive(Debug, Clone, Serialize, Deserialize)]
408pub struct CompositionEntry {
409    /// Operation timestamp
410    pub timestamp: DateTime<Utc>,
411    /// Privacy cost
412    pub privacy_cost: (f64, f64), // (epsilon, delta)
413    /// Mechanism used
414    pub mechanism: String,
415    /// Participant involved
416    pub participant_id: Option<Uuid>,
417}
418
419/// Current privacy guarantees
420#[derive(Debug, Clone, Serialize, Deserialize)]
421pub struct PrivacyGuarantees {
422    /// Total epsilon consumed
423    pub total_epsilon: f64,
424    /// Total delta consumed
425    pub total_delta: f64,
426    /// Worst-case privacy loss
427    pub worst_case_loss: f64,
428    /// Expected privacy loss
429    pub expected_loss: f64,
430    /// Confidence intervals
431    pub confidence_intervals: HashMap<String, (f64, f64)>,
432}