use crate::error::{OptimError, Result};
use std::collections::{HashMap, VecDeque};
/// Feature switches and tuning knobs for [`PrivacyAmplificationAnalyzer`].
#[derive(Debug, Clone, PartialEq)]
pub struct AmplificationConfig {
    /// Master switch; when `false`, `compute_amplification_factor` returns `1.0`
    /// and records nothing.
    pub enabled: bool,
    /// Multiplier applied to `sqrt(sampling_probability)` when the sampling
    /// probability is below `1.0`.
    pub subsampling_factor: f64,
    /// Enables the shuffling term computed by `compute_shuffling_amplification`.
    pub shuffling_enabled: bool,
    /// Enables the `1 + 0.1 * ln(round)` multiplier for rounds after the first.
    pub multi_round_amplification: bool,
    /// Enables the diversity term computed by `compute_heterogeneous_amplification`.
    pub heterogeneous_amplification: bool,
}
/// Tracks per-round subsampling amplification factors and optional per-client
/// factors for a federated-learning privacy analysis.
#[derive(Debug, Clone)]
pub struct PrivacyAmplificationAnalyzer {
    /// Active configuration; replaceable via `update_config`.
    config: AmplificationConfig,
    /// Bounded history of per-round subsampling events, oldest first.
    subsampling_history: VecDeque<SubsamplingEvent>,
    /// Per-client amplification factors keyed by client identifier.
    amplification_factors: HashMap<String, f64>,
}
/// One recorded call to `compute_amplification_factor`.
#[derive(Debug, Clone, PartialEq)]
pub struct SubsamplingEvent {
    /// Round number passed by the caller.
    pub round: usize,
    /// Sampling probability passed by the caller.
    pub sampling_rate: f64,
    /// Approximate number of clients sampled, derived from `sampling_rate`
    /// and `total_clients`.
    pub clients_sampled: usize,
    /// Population size the `clients_sampled` figure is expressed against.
    pub total_clients: usize,
    /// Amplification factor computed for this round.
    pub amplificationfactor: f64,
}
/// Summary statistics over the recorded subsampling history.
#[derive(Debug, Clone, PartialEq)]
pub struct AmplificationStats {
    /// Number of history events the statistics were computed from.
    pub rounds_analyzed: usize,
    /// Mean of the recorded amplification factors.
    pub avg_amplification_factor: f64,
    /// Largest recorded amplification factor.
    pub max_amplification_factor: f64,
    /// Smallest recorded amplification factor.
    pub min_amplification_factor: f64,
    /// `avg_amplification_factor - 1.0`.
    pub total_privacy_saved: f64,
}
impl PrivacyAmplificationAnalyzer {
    /// Maximum number of events kept in the subsampling history; the oldest
    /// event is evicted once this limit is reached.
    const HISTORY_CAPACITY: usize = 1000;

    /// Nominal population size recorded as `total_clients` in every
    /// [`SubsamplingEvent`]. The real population is not passed to
    /// [`Self::compute_amplification_factor`], so `clients_sampled` is
    /// expressed relative to this fixed size.
    const NOMINAL_TOTAL_CLIENTS: usize = 1000;

    /// Creates an analyzer with `config`, an empty history and no per-client
    /// factors.
    pub fn new(config: AmplificationConfig) -> Self {
        Self {
            config,
            subsampling_history: VecDeque::with_capacity(Self::HISTORY_CAPACITY),
            amplification_factors: HashMap::new(),
        }
    }

    /// Computes the amplification factor for one round and records it in the
    /// subsampling history.
    ///
    /// The factor is the product of:
    /// * a subsampling term: `sqrt(sampling_probability) * config.subsampling_factor`
    ///   when `sampling_probability < 1.0`, otherwise `1.0`;
    /// * a multi-round term: `1 + 0.1 * ln(round)` when
    ///   `config.multi_round_amplification` is set and `round > 1`, otherwise `1.0`.
    ///
    /// The product is clamped to at least `1.0`. The clamped value is both
    /// returned and stored, so [`Self::get_amplification_stats`] describes exactly
    /// the factors callers received. A NaN product (e.g. from a negative
    /// probability) also collapses to `1.0`, because `f64::max` ignores NaN.
    ///
    /// NOTE(review): these are heuristic multipliers, not a formal
    /// differential-privacy amplification bound — confirm they match the
    /// intended analysis.
    ///
    /// When `config.enabled` is `false`, returns `Ok(1.0)` and records nothing.
    ///
    /// # Errors
    ///
    /// Currently always returns `Ok`.
    pub fn compute_amplification_factor(
        &mut self,
        sampling_probability: f64,
        round: usize,
    ) -> Result<f64> {
        if !self.config.enabled {
            return Ok(1.0);
        }

        let subsampling_factor = if sampling_probability < 1.0 {
            sampling_probability.sqrt() * self.config.subsampling_factor
        } else {
            1.0
        };
        let multi_round_factor = if self.config.multi_round_amplification && round > 1 {
            1.0 + 0.1 * (round as f64).ln()
        } else {
            1.0
        };
        let total_amplification = (subsampling_factor * multi_round_factor).max(1.0);

        // Clamp so `clients_sampled` never exceeds `total_clients`; a NaN rate
        // casts to 0.
        let clients_sampled = (sampling_probability.clamp(0.0, 1.0)
            * Self::NOMINAL_TOTAL_CLIENTS as f64) as usize;

        // Evict before pushing so the deque never outgrows its preallocated
        // capacity.
        if self.subsampling_history.len() >= Self::HISTORY_CAPACITY {
            self.subsampling_history.pop_front();
        }
        self.subsampling_history.push_back(SubsamplingEvent {
            round,
            sampling_rate: sampling_probability,
            clients_sampled,
            total_clients: Self::NOMINAL_TOTAL_CLIENTS,
            amplificationfactor: total_amplification,
        });

        Ok(total_amplification)
    }

    /// Summarises the recorded subsampling history.
    ///
    /// Returns `AmplificationStats::default()` when no rounds have been recorded.
    pub fn get_amplification_stats(&self) -> AmplificationStats {
        let count = self.subsampling_history.len();
        if count == 0 {
            return AmplificationStats::default();
        }

        // Single pass: running sum, maximum and minimum.
        let (sum, hi, lo) = self
            .subsampling_history
            .iter()
            .map(|event| event.amplificationfactor)
            .fold(
                (0.0_f64, f64::NEG_INFINITY, f64::INFINITY),
                |(sum, hi, lo), factor| (sum + factor, hi.max(factor), lo.min(factor)),
            );
        let avg_amplification = sum / count as f64;

        AmplificationStats {
            rounds_analyzed: count,
            avg_amplification_factor: avg_amplification,
            max_amplification_factor: hi,
            min_amplification_factor: lo,
            total_privacy_saved: avg_amplification - 1.0,
        }
    }

    /// Records (or overwrites) the amplification factor for `client_id`.
    pub fn add_client_amplification(&mut self, client_id: String, factor: f64) {
        self.amplification_factors.insert(client_id, factor);
    }

    /// Returns the factor previously recorded for `client_id`, if any.
    pub fn get_client_amplification(&self, client_id: &str) -> Option<f64> {
        self.amplification_factors.get(client_id).copied()
    }

    /// Shuffling term: `1 + 0.1 * sqrt(num_clients)`.
    ///
    /// Returns `1.0` when shuffling is disabled or fewer than two clients
    /// participate.
    pub fn compute_shuffling_amplification(&self, num_clients: usize) -> f64 {
        if !self.config.shuffling_enabled || num_clients < 2 {
            return 1.0;
        }
        1.0 + 0.1 * (num_clients as f64).sqrt()
    }

    /// Heterogeneity term: `1 + 0.05 * sqrt(mean(client_diversities))`.
    ///
    /// Returns `1.0` when the feature is disabled or the slice is empty. A
    /// negative or NaN mean is treated as zero diversity, which avoids a NaN
    /// factor.
    ///
    /// NOTE(review): diversities are presumably non-negative scores — confirm
    /// with callers.
    pub fn compute_heterogeneous_amplification(&self, client_diversities: &[f64]) -> f64 {
        if !self.config.heterogeneous_amplification || client_diversities.is_empty() {
            return 1.0;
        }
        let avg_diversity =
            client_diversities.iter().sum::<f64>() / client_diversities.len() as f64;
        1.0 + 0.05 * avg_diversity.max(0.0).sqrt()
    }

    /// Returns the active configuration.
    pub fn config(&self) -> &AmplificationConfig {
        &self.config
    }

    /// Number of events currently held in the subsampling history.
    pub fn rounds_analyzed(&self) -> usize {
        self.subsampling_history.len()
    }

    /// Clears both the subsampling history and all per-client factors.
    pub fn clear_history(&mut self) {
        self.subsampling_history.clear();
        self.amplification_factors.clear();
    }

    /// Whether amplification is enabled in the active configuration.
    pub fn is_enabled(&self) -> bool {
        self.config.enabled
    }

    /// Read-only view of the recorded subsampling events, oldest first.
    pub fn get_subsampling_history(&self) -> &VecDeque<SubsamplingEvent> {
        &self.subsampling_history
    }

    /// Replaces the configuration; the existing history is left untouched.
    pub fn update_config(&mut self, config: AmplificationConfig) {
        self.config = config;
    }

    /// Product of the subsampling, shuffling and heterogeneity terms.
    ///
    /// The subsampling term is computed via
    /// [`Self::compute_amplification_factor`], so this call also records a
    /// history event when amplification is enabled. A `None` diversity slice
    /// contributes a neutral `1.0`.
    ///
    /// # Errors
    ///
    /// Propagates any error from [`Self::compute_amplification_factor`].
    pub fn compute_combined_amplification(
        &mut self,
        sampling_probability: f64,
        round: usize,
        num_clients: usize,
        client_diversities: Option<&[f64]>,
    ) -> Result<f64> {
        let subsampling_amp = self.compute_amplification_factor(sampling_probability, round)?;
        let shuffling_amp = self.compute_shuffling_amplification(num_clients);
        let heterogeneous_amp = client_diversities
            .map_or(1.0, |diversities| self.compute_heterogeneous_amplification(diversities));
        Ok(subsampling_amp * shuffling_amp * heterogeneous_amp)
    }
}
impl Default for AmplificationConfig {
fn default() -> Self {
Self {
enabled: true,
subsampling_factor: 1.0,
shuffling_enabled: false,
multi_round_amplification: true,
heterogeneous_amplification: false,
}
}
}
impl Default for AmplificationStats {
fn default() -> Self {
Self {
rounds_analyzed: 0,
avg_amplification_factor: 1.0,
max_amplification_factor: 1.0,
min_amplification_factor: 1.0,
total_privacy_saved: 0.0,
}
}
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_amplification_analyzer_creation() {
        let config = AmplificationConfig::default();
        let analyzer = PrivacyAmplificationAnalyzer::new(config);
        assert!(analyzer.is_enabled());
        assert_eq!(analyzer.rounds_analyzed(), 0);
    }

    #[test]
    fn test_compute_amplification_factor() {
        let config = AmplificationConfig::default();
        let mut analyzer = PrivacyAmplificationAnalyzer::new(config);
        let factor = analyzer.compute_amplification_factor(0.1, 1);
        assert!(factor.is_ok());
        let amp_factor = factor.expect("unwrap failed");
        assert!(amp_factor >= 1.0);
        assert_eq!(analyzer.rounds_analyzed(), 1);
    }

    #[test]
    fn test_amplification_with_disabled_config() {
        let config = AmplificationConfig {
            enabled: false,
            ..Default::default()
        };
        let mut analyzer = PrivacyAmplificationAnalyzer::new(config);
        let factor = analyzer.compute_amplification_factor(0.1, 1);
        assert!(factor.is_ok());
        assert_eq!(factor.expect("unwrap failed"), 1.0);
    }

    #[test]
    fn test_amplification_stats() {
        let config = AmplificationConfig::default();
        let mut analyzer = PrivacyAmplificationAnalyzer::new(config);
        analyzer
            .compute_amplification_factor(0.1, 1)
            .expect("unwrap failed");
        analyzer
            .compute_amplification_factor(0.2, 2)
            .expect("unwrap failed");
        analyzer
            .compute_amplification_factor(0.15, 3)
            .expect("unwrap failed");
        let stats = analyzer.get_amplification_stats();
        assert_eq!(stats.rounds_analyzed, 3);
        assert!(stats.avg_amplification_factor > 0.0);
        assert!(stats.max_amplification_factor >= stats.min_amplification_factor);
    }

    #[test]
    fn test_client_specific_amplification() {
        let config = AmplificationConfig::default();
        let mut analyzer = PrivacyAmplificationAnalyzer::new(config);
        analyzer.add_client_amplification("client1".to_string(), 1.5);
        analyzer.add_client_amplification("client2".to_string(), 1.3);
        assert_eq!(analyzer.get_client_amplification("client1"), Some(1.5));
        assert_eq!(analyzer.get_client_amplification("client2"), Some(1.3));
        assert_eq!(analyzer.get_client_amplification("client3"), None);
    }

    #[test]
    fn test_shuffling_amplification() {
        let config = AmplificationConfig {
            shuffling_enabled: true,
            ..Default::default()
        };
        let analyzer = PrivacyAmplificationAnalyzer::new(config);
        let amp_factor = analyzer.compute_shuffling_amplification(100);
        assert!(amp_factor > 1.0);
        let no_amp_factor = analyzer.compute_shuffling_amplification(1);
        assert_eq!(no_amp_factor, 1.0);
    }

    #[test]
    fn test_heterogeneous_amplification() {
        let config = AmplificationConfig {
            heterogeneous_amplification: true,
            ..Default::default()
        };
        let analyzer = PrivacyAmplificationAnalyzer::new(config);
        let diversities = [0.1, 0.3, 0.5, 0.7, 0.9];
        let amp_factor = analyzer.compute_heterogeneous_amplification(&diversities);
        assert!(amp_factor > 1.0);
        let no_amp_factor = analyzer.compute_heterogeneous_amplification(&[]);
        assert_eq!(no_amp_factor, 1.0);
    }

    #[test]
    fn test_combined_amplification() {
        let config = AmplificationConfig {
            shuffling_enabled: true,
            heterogeneous_amplification: true,
            ..Default::default()
        };
        let mut analyzer = PrivacyAmplificationAnalyzer::new(config);
        let diversities = [0.2, 0.4, 0.6];
        let combined =
            analyzer.compute_combined_amplification(0.1, 1, 10, Some(&diversities[..]));
        assert!(combined.is_ok());
        let factor = combined.expect("unwrap failed");
        assert!(factor > 1.0);
    }

    #[test]
    fn test_clear_history() {
        let config = AmplificationConfig::default();
        let mut analyzer = PrivacyAmplificationAnalyzer::new(config);
        analyzer
            .compute_amplification_factor(0.1, 1)
            .expect("unwrap failed");
        analyzer.add_client_amplification("client1".to_string(), 1.5);
        assert_eq!(analyzer.rounds_analyzed(), 1);
        assert!(analyzer.get_client_amplification("client1").is_some());
        analyzer.clear_history();
        assert_eq!(analyzer.rounds_analyzed(), 0);
        assert!(analyzer.get_client_amplification("client1").is_none());
    }

    #[test]
    fn test_history_is_bounded() {
        let mut analyzer = PrivacyAmplificationAnalyzer::new(AmplificationConfig::default());
        for round in 1..=1005 {
            analyzer
                .compute_amplification_factor(0.5, round)
                .expect("amplification should succeed");
        }
        // Only the most recent 1000 rounds are retained.
        assert_eq!(analyzer.rounds_analyzed(), 1000);
        assert_eq!(
            analyzer.get_subsampling_history().front().map(|e| e.round),
            Some(6)
        );
    }

    #[test]
    fn test_update_config_disables_amplification() {
        let mut analyzer = PrivacyAmplificationAnalyzer::new(AmplificationConfig::default());
        analyzer.update_config(AmplificationConfig {
            enabled: false,
            ..Default::default()
        });
        assert!(!analyzer.is_enabled());
        let factor = analyzer
            .compute_amplification_factor(0.1, 1)
            .expect("amplification should succeed");
        assert_eq!(factor, 1.0);
        assert_eq!(analyzer.rounds_analyzed(), 0);
    }
}