// optirs_core/privacy/federated/composition_analyzer.rs

// Federated Composition Analyzer Module
//
// This module implements privacy composition analysis for federated learning,
// tracking privacy budget consumption across multiple rounds and providing
// several composition methods for differential privacy guarantees.

use crate::error::{OptimError, Result};
use std::collections::HashMap;

/// Federated composition methods used to combine per-round privacy loss.
///
/// The default is [`FederatedCompositionMethod::FederatedMomentsAccountant`].
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Hash)]
pub enum FederatedCompositionMethod {
    /// Basic composition.
    Basic,

    /// Advanced composition theorem.
    AdvancedComposition,

    /// Moments accountant for the federated setting.
    #[default]
    FederatedMomentsAccountant,

    /// Rényi differential privacy.
    RenyiDP,

    /// Zero-concentrated differential privacy.
    ZCDP,
}

30/// Federated composition analyzer
31pub struct FederatedCompositionAnalyzer {
32    method: FederatedCompositionMethod,
33    round_compositions: Vec<RoundComposition>,
34    client_compositions: HashMap<String, Vec<ClientComposition>>,
35}
36
37/// Round composition for privacy accounting
38#[derive(Debug, Clone)]
39pub struct RoundComposition {
40    pub round: usize,
41    pub participating_clients: usize,
42    pub epsilonconsumed: f64,
43    pub delta_consumed: f64,
44    pub amplification_applied: bool,
45    pub composition_method: FederatedCompositionMethod,
46}
47
/// Privacy contribution of a single client in a single round.
#[derive(Debug, Clone, PartialEq)]
pub struct ClientComposition {
    /// Identifier of the contributing client.
    pub clientid: String,
    /// Round index the contribution belongs to.
    pub round: usize,
    /// Epsilon attributed to this client for the round.
    pub epsilon_contribution: f64,
    /// Delta attributed to this client for the round.
    pub delta_contribution: f64,
}

57/// Composition statistics
58#[derive(Debug, Clone)]
59pub struct CompositionStats {
60    pub total_rounds: usize,
61    pub total_epsilon_consumed: f64,
62    pub total_delta_consumed: f64,
63    pub composition_method: FederatedCompositionMethod,
64    pub amplification_rounds: usize,
65}
66
67impl FederatedCompositionAnalyzer {
68    pub fn new(method: FederatedCompositionMethod) -> Self {
69        Self {
70            method,
71            round_compositions: Vec::new(),
72            client_compositions: HashMap::new(),
73        }
74    }
75
76    pub fn analyze_composition(&self, round: usize, epsilon: f64, delta: f64) -> Result<f64> {
77        match self.method {
78            FederatedCompositionMethod::Basic => Ok(epsilon * round as f64),
79            FederatedCompositionMethod::AdvancedComposition => {
80                // Simplified advanced composition
81                let k = round as f64;
82                let advanced_epsilon = (k * epsilon * epsilon
83                    + k.sqrt() * epsilon * (2.0 * (1.25 / delta).ln()).sqrt())
84                .sqrt();
85                Ok(advanced_epsilon)
86            }
87            FederatedCompositionMethod::FederatedMomentsAccountant => {
88                // Use existing moments accountant logic
89                Ok(epsilon * (round as f64).sqrt())
90            }
91            FederatedCompositionMethod::RenyiDP => {
92                // Simplified Renyi DP composition
93                Ok(epsilon * (round as f64).ln())
94            }
95            FederatedCompositionMethod::ZCDP => {
96                // Zero-concentrated DP composition
97                Ok(epsilon * (round as f64).sqrt())
98            }
99        }
100    }
101
102    pub fn add_round_composition(&mut self, composition: RoundComposition) {
103        self.round_compositions.push(composition);
104    }
105
106    pub fn add_client_composition(&mut self, client_id: String, composition: ClientComposition) {
107        self.client_compositions
108            .entry(client_id)
109            .or_default()
110            .push(composition);
111    }
112
113    pub fn get_composition_stats(&self) -> CompositionStats {
114        if self.round_compositions.is_empty() {
115            return CompositionStats::default();
116        }
117
118        let total_epsilon: f64 = self
119            .round_compositions
120            .iter()
121            .map(|comp| comp.epsilonconsumed)
122            .sum();
123
124        let total_delta: f64 = self
125            .round_compositions
126            .iter()
127            .map(|comp| comp.delta_consumed)
128            .sum();
129
130        CompositionStats {
131            total_rounds: self.round_compositions.len(),
132            total_epsilon_consumed: total_epsilon,
133            total_delta_consumed: total_delta,
134            composition_method: self.method,
135            amplification_rounds: self
136                .round_compositions
137                .iter()
138                .filter(|comp| comp.amplification_applied)
139                .count(),
140        }
141    }
142
143    /// Get current composition method
144    pub fn method(&self) -> FederatedCompositionMethod {
145        self.method
146    }
147
148    /// Get number of rounds tracked
149    pub fn rounds_count(&self) -> usize {
150        self.round_compositions.len()
151    }
152
153    /// Get client composition history for a specific client
154    pub fn get_client_compositions(&self, client_id: &str) -> Option<&Vec<ClientComposition>> {
155        self.client_compositions.get(client_id)
156    }
157
158    /// Get round compositions
159    pub fn get_round_compositions(&self) -> &Vec<RoundComposition> {
160        &self.round_compositions
161    }
162
163    /// Clear all composition history
164    pub fn clear_history(&mut self) {
165        self.round_compositions.clear();
166        self.client_compositions.clear();
167    }
168
169    /// Set composition method
170    pub fn set_method(&mut self, method: FederatedCompositionMethod) {
171        self.method = method;
172    }
173}
174
175impl Default for CompositionStats {
176    fn default() -> Self {
177        Self {
178            total_rounds: 0,
179            total_epsilon_consumed: 0.0,
180            total_delta_consumed: 0.0,
181            composition_method: FederatedCompositionMethod::default(),
182            amplification_rounds: 0,
183        }
184    }
185}
186
#[cfg(test)]
mod tests {
    use super::*;

    /// Fresh analyzer configured with the moments accountant.
    fn moments_analyzer() -> FederatedCompositionAnalyzer {
        FederatedCompositionAnalyzer::new(FederatedCompositionMethod::FederatedMomentsAccountant)
    }

    /// Round record with a fixed delta of 1e-5, tagged with the moments accountant.
    fn round_record(
        round: usize,
        participating_clients: usize,
        epsilonconsumed: f64,
        amplification_applied: bool,
    ) -> RoundComposition {
        RoundComposition {
            round,
            participating_clients,
            epsilonconsumed,
            delta_consumed: 1e-5,
            amplification_applied,
            composition_method: FederatedCompositionMethod::FederatedMomentsAccountant,
        }
    }

    #[test]
    fn test_federated_composition_analyzer() {
        let method = FederatedCompositionMethod::AdvancedComposition;
        let composed = FederatedCompositionAnalyzer::new(method)
            .analyze_composition(5, 0.1, 1e-5)
            .expect("unwrap failed");
        // Five composed rounds must cost more than a single 0.1-epsilon round.
        assert!(composed > 0.1);
    }

    #[test]
    fn test_composition_stats() {
        let mut analyzer = moments_analyzer();
        analyzer.add_round_composition(round_record(1, 10, 0.1, true));
        analyzer.add_round_composition(round_record(2, 12, 0.15, false));

        let stats = analyzer.get_composition_stats();
        assert_eq!(stats.total_rounds, 2);
        assert_eq!(stats.total_epsilon_consumed, 0.25);
        assert_eq!(stats.total_delta_consumed, 2e-5);
        assert_eq!(stats.amplification_rounds, 1);
    }

    #[test]
    fn test_basic_composition() {
        let composed = FederatedCompositionAnalyzer::new(FederatedCompositionMethod::Basic)
            .analyze_composition(3, 0.1, 1e-5)
            .expect("unwrap failed");
        // 3 * 0.1 is not exactly representable in f64, so compare with a tolerance.
        assert!((composed - 0.3).abs() < 1e-10);
    }

    #[test]
    fn test_client_composition_tracking() {
        let mut analyzer = moments_analyzer();
        let id = "client1";
        analyzer.add_client_composition(
            id.to_string(),
            ClientComposition {
                clientid: id.to_string(),
                round: 1,
                epsilon_contribution: 0.05,
                delta_contribution: 5e-6,
            },
        );

        let history = analyzer.get_client_compositions(id);
        assert!(history.is_some());
        assert_eq!(history.expect("unwrap failed").len(), 1);
    }

    #[test]
    fn test_clear_history() {
        let mut analyzer = moments_analyzer();
        analyzer.add_round_composition(round_record(1, 10, 0.1, true));
        assert_eq!(analyzer.rounds_count(), 1);

        analyzer.clear_history();
        assert_eq!(analyzer.rounds_count(), 0);
    }
}