optirs_core/privacy/federated/
composition_analyzer.rs1use crate::error::{OptimError, Result};
8use std::collections::HashMap;
9
/// Strategy used by `FederatedCompositionAnalyzer` to turn a per-round
/// epsilon into a cumulative epsilon over several rounds.
///
/// Each variant selects one branch of
/// `FederatedCompositionAnalyzer::analyze_composition`. The enum is a
/// fieldless `Copy` value, so it is cheap to pass around. It also derives
/// equality and hashing, so it can be compared and used as a map key.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Hash)]
pub enum FederatedCompositionMethod {
    /// Linear (basic) composition of per-round costs.
    Basic,

    /// Advanced composition theorem bound.
    AdvancedComposition,

    /// Moments-accountant style estimate; this is the default method.
    #[default]
    FederatedMomentsAccountant,

    /// Rényi-DP style estimate.
    RenyiDP,

    /// Zero-concentrated DP style estimate.
    ZCDP,
}
29
30pub struct FederatedCompositionAnalyzer {
32 method: FederatedCompositionMethod,
33 round_compositions: Vec<RoundComposition>,
34 client_compositions: HashMap<String, Vec<ClientComposition>>,
35}
36
37#[derive(Debug, Clone)]
39pub struct RoundComposition {
40 pub round: usize,
41 pub participating_clients: usize,
42 pub epsilonconsumed: f64,
43 pub delta_consumed: f64,
44 pub amplification_applied: bool,
45 pub composition_method: FederatedCompositionMethod,
46}
47
/// A single client's contribution to the privacy budget in one round.
///
/// `FederatedCompositionAnalyzer::add_client_composition` files the record
/// under its own `client_id` argument and does not read `clientid`.
/// Callers should keep the two consistent.
#[derive(Debug, Clone, PartialEq)]
pub struct ClientComposition {
    /// Identifier of the contributing client.
    pub clientid: String,
    /// Round index the contribution belongs to.
    pub round: usize,
    /// Epsilon attributed to this client in this round.
    pub epsilon_contribution: f64,
    /// Delta attributed to this client in this round.
    pub delta_contribution: f64,
}
56
57#[derive(Debug, Clone)]
59pub struct CompositionStats {
60 pub total_rounds: usize,
61 pub total_epsilon_consumed: f64,
62 pub total_delta_consumed: f64,
63 pub composition_method: FederatedCompositionMethod,
64 pub amplification_rounds: usize,
65}
66
67impl FederatedCompositionAnalyzer {
68 pub fn new(method: FederatedCompositionMethod) -> Self {
69 Self {
70 method,
71 round_compositions: Vec::new(),
72 client_compositions: HashMap::new(),
73 }
74 }
75
76 pub fn analyze_composition(&self, round: usize, epsilon: f64, delta: f64) -> Result<f64> {
77 match self.method {
78 FederatedCompositionMethod::Basic => Ok(epsilon * round as f64),
79 FederatedCompositionMethod::AdvancedComposition => {
80 let k = round as f64;
82 let advanced_epsilon = (k * epsilon * epsilon
83 + k.sqrt() * epsilon * (2.0 * (1.25 / delta).ln()).sqrt())
84 .sqrt();
85 Ok(advanced_epsilon)
86 }
87 FederatedCompositionMethod::FederatedMomentsAccountant => {
88 Ok(epsilon * (round as f64).sqrt())
90 }
91 FederatedCompositionMethod::RenyiDP => {
92 Ok(epsilon * (round as f64).ln())
94 }
95 FederatedCompositionMethod::ZCDP => {
96 Ok(epsilon * (round as f64).sqrt())
98 }
99 }
100 }
101
102 pub fn add_round_composition(&mut self, composition: RoundComposition) {
103 self.round_compositions.push(composition);
104 }
105
106 pub fn add_client_composition(&mut self, client_id: String, composition: ClientComposition) {
107 self.client_compositions
108 .entry(client_id)
109 .or_default()
110 .push(composition);
111 }
112
113 pub fn get_composition_stats(&self) -> CompositionStats {
114 if self.round_compositions.is_empty() {
115 return CompositionStats::default();
116 }
117
118 let total_epsilon: f64 = self
119 .round_compositions
120 .iter()
121 .map(|comp| comp.epsilonconsumed)
122 .sum();
123
124 let total_delta: f64 = self
125 .round_compositions
126 .iter()
127 .map(|comp| comp.delta_consumed)
128 .sum();
129
130 CompositionStats {
131 total_rounds: self.round_compositions.len(),
132 total_epsilon_consumed: total_epsilon,
133 total_delta_consumed: total_delta,
134 composition_method: self.method,
135 amplification_rounds: self
136 .round_compositions
137 .iter()
138 .filter(|comp| comp.amplification_applied)
139 .count(),
140 }
141 }
142
143 pub fn method(&self) -> FederatedCompositionMethod {
145 self.method
146 }
147
148 pub fn rounds_count(&self) -> usize {
150 self.round_compositions.len()
151 }
152
153 pub fn get_client_compositions(&self, client_id: &str) -> Option<&Vec<ClientComposition>> {
155 self.client_compositions.get(client_id)
156 }
157
158 pub fn get_round_compositions(&self) -> &Vec<RoundComposition> {
160 &self.round_compositions
161 }
162
163 pub fn clear_history(&mut self) {
165 self.round_compositions.clear();
166 self.client_compositions.clear();
167 }
168
169 pub fn set_method(&mut self, method: FederatedCompositionMethod) {
171 self.method = method;
172 }
173}
174
175impl Default for CompositionStats {
176 fn default() -> Self {
177 Self {
178 total_rounds: 0,
179 total_epsilon_consumed: 0.0,
180 total_delta_consumed: 0.0,
181 composition_method: FederatedCompositionMethod::default(),
182 amplification_rounds: 0,
183 }
184 }
185}
186
#[cfg(test)]
mod tests {
    use super::*;

    /// Absolute tolerance for comparing accumulated floating-point budgets.
    const TOL: f64 = 1e-12;

    /// Builds a moments-accountant round record with a fixed delta of 1e-5.
    fn moments_round(
        round: usize,
        participating_clients: usize,
        epsilonconsumed: f64,
        amplification_applied: bool,
    ) -> RoundComposition {
        RoundComposition {
            round,
            participating_clients,
            epsilonconsumed,
            delta_consumed: 1e-5,
            amplification_applied,
            composition_method: FederatedCompositionMethod::FederatedMomentsAccountant,
        }
    }

    /// Builds a client record for `client1` in the given round.
    fn client1_record(round: usize) -> ClientComposition {
        ClientComposition {
            clientid: "client1".to_string(),
            round,
            epsilon_contribution: 0.05,
            delta_contribution: 5e-6,
        }
    }

    #[test]
    fn test_federated_composition_analyzer() {
        let analyzer =
            FederatedCompositionAnalyzer::new(FederatedCompositionMethod::AdvancedComposition);

        let epsilon = analyzer
            .analyze_composition(5, 0.1, 1e-5)
            .expect("advanced composition should accept valid parameters");
        assert!(epsilon.is_finite(), "composed epsilon must be finite, got {epsilon}");
        assert!(epsilon > 0.1, "five rounds must cost more than one, got {epsilon}");
    }

    #[test]
    fn test_composition_stats() {
        let mut analyzer = FederatedCompositionAnalyzer::new(
            FederatedCompositionMethod::FederatedMomentsAccountant,
        );
        analyzer.add_round_composition(moments_round(1, 10, 0.1, true));
        analyzer.add_round_composition(moments_round(2, 12, 0.15, false));

        let stats = analyzer.get_composition_stats();
        assert_eq!(stats.total_rounds, 2);
        assert!(
            (stats.total_epsilon_consumed - 0.25).abs() < TOL,
            "unexpected total epsilon {}",
            stats.total_epsilon_consumed
        );
        assert!(
            (stats.total_delta_consumed - 2e-5).abs() < TOL,
            "unexpected total delta {}",
            stats.total_delta_consumed
        );
        assert_eq!(stats.amplification_rounds, 1);
        assert!(matches!(
            stats.composition_method,
            FederatedCompositionMethod::FederatedMomentsAccountant
        ));
    }

    #[test]
    fn test_basic_composition() {
        let analyzer = FederatedCompositionAnalyzer::new(FederatedCompositionMethod::Basic);
        let epsilon = analyzer
            .analyze_composition(3, 0.1, 1e-5)
            .expect("basic composition should accept valid parameters");
        assert!((epsilon - 0.3).abs() < 1e-10, "expected 0.3, got {epsilon}");
    }

    #[test]
    fn test_client_composition_tracking() {
        let mut analyzer = FederatedCompositionAnalyzer::new(
            FederatedCompositionMethod::FederatedMomentsAccountant,
        );

        analyzer.add_client_composition("client1".to_string(), client1_record(1));
        let compositions = analyzer
            .get_client_compositions("client1")
            .expect("client1 should have recorded compositions");
        assert_eq!(compositions.len(), 1);
        assert_eq!(compositions[0].round, 1);

        // A second record for the same client is appended, not overwritten.
        analyzer.add_client_composition("client1".to_string(), client1_record(2));
        assert_eq!(
            analyzer.get_client_compositions("client1").map(Vec::len),
            Some(2)
        );
        assert!(analyzer.get_client_compositions("unknown").is_none());
    }

    #[test]
    fn test_clear_history() {
        let mut analyzer = FederatedCompositionAnalyzer::new(
            FederatedCompositionMethod::FederatedMomentsAccountant,
        );
        analyzer.add_round_composition(moments_round(1, 10, 0.1, true));
        analyzer.add_client_composition("client1".to_string(), client1_record(1));
        assert_eq!(analyzer.rounds_count(), 1);

        analyzer.clear_history();
        assert_eq!(analyzer.rounds_count(), 0);
        assert!(analyzer.get_round_compositions().is_empty());
        assert!(analyzer.get_client_compositions("client1").is_none());
    }

    #[test]
    fn test_default_method() {
        assert!(matches!(
            FederatedCompositionMethod::default(),
            FederatedCompositionMethod::FederatedMomentsAccountant
        ));
    }

    #[test]
    fn test_set_method() {
        let mut analyzer = FederatedCompositionAnalyzer::new(FederatedCompositionMethod::Basic);
        assert!(matches!(analyzer.method(), FederatedCompositionMethod::Basic));

        analyzer.set_method(FederatedCompositionMethod::ZCDP);
        assert!(matches!(analyzer.method(), FederatedCompositionMethod::ZCDP));
    }
}