use std::fmt::Debug;
use super::super::moment_accountant::MomentsAccountant;
use super::super::{AccountingMethod, DifferentialPrivacyConfig, NoiseMechanism, PrivacyBudget};
use crate::error::{OptimError, Result};
use scirs2_core::ndarray::{Array1, Array2};
use scirs2_core::numeric::Float;
use scirs2_core::random::Rng;
use scirs2_core::random::{Random, Rng as SCRRng};
use std::collections::{HashMap, VecDeque};
use std::sync::Arc;
#[derive(Debug, Clone)]
pub struct SecureAggregationConfig {
pub enabled: bool,
pub min_clients: usize,
pub max_dropouts: usize,
pub masking_dimension: usize,
pub seed_sharing: SeedSharingMethod,
pub quantization_bits: Option<u8>,
pub aggregate_dp: bool,
}
#[derive(Debug, Clone, Copy)]
pub enum SeedSharingMethod {
ShamirSecretSharing,
ThresholdEncryption,
DistributedKeyGeneration,
}
pub struct SecureAggregator<T: Float + Debug + Send + Sync + 'static> {
config: SecureAggregationConfig,
client_masks: HashMap<String, Array1<T>>,
shared_randomness: Arc<std::sync::Mutex<u64>>,
aggregation_threshold: usize,
round_keys: Vec<u64>,
}
#[derive(Debug, Clone)]
pub struct SecureAggregationPlan {
pub round_seed: u64,
pub participating_clients: Vec<String>,
pub min_threshold: usize,
pub masking_enabled: bool,
}
impl<T: Float + Debug + Send + Sync + 'static + scirs2_core::ndarray::ScalarOperand>
SecureAggregator<T>
{
pub fn new(config: SecureAggregationConfig) -> Result<Self> {
let min_clients = config.min_clients;
Ok(Self {
config,
client_masks: HashMap::new(),
shared_randomness: Arc::new(std::sync::Mutex::new(0u64)),
aggregation_threshold: min_clients,
round_keys: Vec::new(),
})
}
pub fn prepare_round(&mut self, selectedclients: &[String]) -> Result<SecureAggregationPlan> {
let mut seed = self.shared_randomness.lock().expect("lock poisoned");
*seed = seed.wrapping_add(1);
let round_seed = *seed;
self.round_keys.push(round_seed);
self.client_masks.clear();
for clientid in selectedclients.iter() {
let mut client_rng = Random::default();
let mask_size = self.config.masking_dimension;
let mask = Array1::from_iter(
(0..mask_size)
.map(|_| T::from(client_rng.gen_range(-1.0..1.0)).expect("unwrap failed")),
);
self.client_masks.insert(clientid.clone(), mask);
}
Ok(SecureAggregationPlan {
round_seed,
participating_clients: selectedclients.to_vec(),
min_threshold: self.config.min_clients,
masking_enabled: true,
})
}
pub fn aggregate_with_masks(
&self,
clientupdates: &HashMap<String, Array1<T>>,
_aggregation_plan: &SecureAggregationPlan,
) -> Result<Array1<T>> {
if clientupdates.len() < self.aggregation_threshold {
return Err(OptimError::InvalidConfig(
"Insufficient clients for secure aggregation".to_string(),
));
}
let first_update = clientupdates.values().next().expect("unwrap failed");
let mut aggregated = Array1::zeros(first_update.len());
for (clientid, update) in clientupdates {
if let Some(mask) = self.client_masks.get(clientid) {
let masked_update = if update.len() == mask.len() {
update + mask
} else {
update.clone() };
aggregated = aggregated + masked_update;
} else {
aggregated = aggregated + update;
}
}
let num_clients = T::from(clientupdates.len()).expect("unwrap failed");
aggregated = aggregated / num_clients;
Ok(aggregated)
}
pub fn config(&self) -> &SecureAggregationConfig {
&self.config
}
pub fn aggregation_threshold(&self) -> usize {
self.aggregation_threshold
}
pub fn is_enabled(&self) -> bool {
self.config.enabled
}
}
impl Default for SecureAggregationConfig {
fn default() -> Self {
Self {
enabled: true,
min_clients: 10,
max_dropouts: 5,
masking_dimension: 1000,
seed_sharing: SeedSharingMethod::ShamirSecretSharing,
quantization_bits: None,
aggregate_dp: true,
}
}
}
#[cfg(test)]
mod tests {
use super::*;
use scirs2_core::ndarray::Array1;
#[test]
fn test_secure_aggregation_config() {
let config = SecureAggregationConfig {
enabled: true,
min_clients: 5,
max_dropouts: 2,
masking_dimension: 100,
seed_sharing: SeedSharingMethod::ShamirSecretSharing,
quantization_bits: Some(8),
aggregate_dp: true,
};
assert!(config.enabled);
assert_eq!(config.min_clients, 5);
assert_eq!(config.max_dropouts, 2);
}
#[test]
fn test_secure_aggregator_creation() {
let config = SecureAggregationConfig::default();
let aggregator = SecureAggregator::<f64>::new(config.clone());
assert!(aggregator.is_ok());
let agg = aggregator.expect("unwrap failed");
assert_eq!(agg.aggregation_threshold(), config.min_clients);
assert!(agg.is_enabled());
}
#[test]
fn test_secure_aggregation_plan() {
let config = SecureAggregationConfig::default();
let mut aggregator = SecureAggregator::<f64>::new(config).expect("unwrap failed");
let clients = vec!["client1".to_string(), "client2".to_string()];
let plan = aggregator.prepare_round(&clients);
assert!(plan.is_ok());
let plan = plan.expect("unwrap failed");
assert_eq!(plan.participating_clients.len(), 2);
assert!(plan.masking_enabled);
}
}