use anyhow::{anyhow, Result};
use scirs2_core::random::StdRng; use scirs2_core::random::*; use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use trustformers_core::tensor::Tensor;
/// Hyper-parameters for federated averaging (FedAvg).
///
/// `FedAvg` itself only reads `client_fraction`, `min_clients` and `max_clients` (for client
/// selection); the remaining fields are presumably consumed by client-side training code —
/// TODO confirm against callers.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct FedAvgConfig {
    /// Number of local epochs per round (not read by `FedAvg` itself).
    pub local_epochs: usize,
    /// Learning rate for local training (not read by `FedAvg` itself).
    pub local_learning_rate: f32,
    /// Fraction of available clients to select each round, before clamping.
    pub client_fraction: f32,
    /// Lower bound on the selection size; still capped by `max_clients` and availability.
    pub min_clients: usize,
    /// Upper bound on the selection size.
    pub max_clients: usize,
    /// Weight-decay coefficient for local training (not read by `FedAvg` itself).
    pub weight_decay: f32,
}
impl Default for FedAvgConfig {
fn default() -> Self {
Self {
local_epochs: 5,
local_learning_rate: 1e-3,
client_fraction: 0.1,
min_clients: 2,
max_clients: 100,
weight_decay: 0.0,
}
}
}
/// Configuration for FedProx: FedAvg settings plus the proximal coefficient.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct FedProxConfig {
    /// Settings forwarded to the inner FedAvg coordinator.
    pub fedavg_config: FedAvgConfig,
    /// Proximal coefficient `mu`: the penalty is `mu / 2 * sum(||client - global||^2)`.
    pub mu: f32,
}
impl Default for FedProxConfig {
fn default() -> Self {
Self {
fedavg_config: FedAvgConfig::default(),
mu: 0.01,
}
}
}
/// Parameters controlling how much noise `DifferentialPrivacy` adds.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DifferentialPrivacyConfig {
    /// Privacy budget; the noise scale is inversely proportional to it.
    pub epsilon: f32,
    /// Failure probability; only the Gaussian mechanism uses it.
    pub delta: f32,
    /// Sensitivity of the released quantity; the noise scale grows linearly with it.
    pub sensitivity: f32,
    /// Which noise distribution to sample from.
    pub noise_mechanism: NoiseMechanism,
}
impl Default for DifferentialPrivacyConfig {
fn default() -> Self {
Self {
epsilon: 1.0,
delta: 1e-5,
sensitivity: 1.0,
noise_mechanism: NoiseMechanism::Gaussian,
}
}
}
/// Noise distribution used by `DifferentialPrivacy`.
///
/// Fieldless, so it is `Copy` and comparable with `==` (no need for `matches!`).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub enum NoiseMechanism {
    /// Gaussian noise with `sigma = sqrt(2 ln(1.25 / delta)) * sensitivity / epsilon`.
    Gaussian,
    /// Laplace noise with scale `b = sensitivity / epsilon`.
    Laplace,
}
/// How `FedAvg::select_clients` picks participants for a round.
///
/// Fieldless, so it is `Copy`: a single strategy value can be reused across rounds.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub enum ClientSelectionStrategy {
    /// Uniform random sample without replacement.
    Random,
    /// Clients with the largest `data_size` first.
    DataSize,
    /// Clients with the highest `compute_capacity` first.
    ComputeCapacity,
    /// Clients with the highest `communication_quality` first.
    CommunicationQuality,
}
/// A client's description as seen by the selection logic.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ClientInfo {
    /// Identifier returned by selection and used to key updates and weights.
    pub client_id: String,
    /// Ranking key for `ClientSelectionStrategy::DataSize` (larger ranks first).
    pub data_size: usize,
    /// Ranking key for `ClientSelectionStrategy::ComputeCapacity` (higher ranks first).
    pub compute_capacity: f32,
    /// Ranking key for `ClientSelectionStrategy::CommunicationQuality` (higher ranks first).
    pub communication_quality: f32,
    /// Clients with `false` here are never selected.
    pub available: bool,
}
/// Server-side coordinator for federated averaging.
#[derive(Debug)]
pub struct FedAvg {
    /// Selection and training hyper-parameters.
    config: FedAvgConfig,
    /// Current global model; replaced on every successful aggregation.
    global_parameters: Vec<Tensor>,
    /// Per-client aggregation weights; clients missing from the map count as `1.0`.
    client_weights: HashMap<String, f32>,
    /// Number of successful aggregations so far.
    current_round: usize,
    /// Result of the most recent `select_clients` call.
    selected_clients: Vec<String>,
    /// Source of randomness for `ClientSelectionStrategy::Random`.
    rng: StdRng,
}
impl FedAvg {
    /// Creates a coordinator with no global parameters, no client weights and round `0`.
    ///
    /// The RNG used by [`ClientSelectionStrategy::Random`] is seeded with the constant `42`,
    /// so random selections are reproducible from run to run.
    pub fn new(config: FedAvgConfig) -> Self {
        Self {
            config,
            global_parameters: Vec::new(),
            client_weights: HashMap::new(),
            current_round: 0,
            selected_clients: Vec::new(),
            rng: StdRng::seed_from_u64(42),
        }
    }

    /// Replaces the stored global parameters with `parameters`.
    pub fn initialize_global_parameters(&mut self, parameters: Vec<Tensor>) {
        self.global_parameters = parameters;
    }

    /// Chooses which clients take part in the next round and remembers the choice.
    ///
    /// Only clients with `available == true` are eligible. The target count is
    /// `round(eligible * client_fraction)`. It is then raised to `min_clients`, capped by
    /// `max_clients`, and finally capped by the number of eligible clients.
    ///
    /// * `Random`: a uniform sample without replacement (partial Fisher–Yates shuffle).
    /// * `DataSize` / `ComputeCapacity` / `CommunicationQuality`: the highest-ranked clients
    ///   by that attribute. Ties keep input order, and NaN scores rank last.
    ///
    /// # Errors
    /// Returns an error if no client is available.
    pub fn select_clients(
        &mut self,
        available_clients: &[ClientInfo],
        strategy: ClientSelectionStrategy,
    ) -> Result<Vec<String>> {
        let available: Vec<&ClientInfo> =
            available_clients.iter().filter(|c| c.available).collect();
        if available.is_empty() {
            return Err(anyhow!("No available clients"));
        }
        // A negative or NaN product saturates to 0 in the `as usize` cast.
        let num_clients = (available.len() as f32 * self.config.client_fraction).round() as usize;
        let num_clients = num_clients
            .max(self.config.min_clients)
            .min(self.config.max_clients)
            .min(available.len());

        let selected = match strategy {
            ClientSelectionStrategy::Random => {
                // Partial Fisher–Yates: only the first `num_clients` slots need to be drawn.
                let mut indices: Vec<usize> = (0..available.len()).collect();
                for i in 0..num_clients {
                    let j = self.rng.random_range(i..indices.len());
                    indices.swap(i, j);
                }
                indices[..num_clients].iter().map(|&i| available[i].client_id.clone()).collect()
            },
            ClientSelectionStrategy::DataSize => {
                let mut ranked = available;
                // Stable sort: clients with equal data size keep their input order.
                ranked.sort_by_key(|c| std::cmp::Reverse(c.data_size));
                ranked.iter().take(num_clients).map(|c| c.client_id.clone()).collect()
            },
            ClientSelectionStrategy::ComputeCapacity => {
                Self::top_clients_by_score(&available, num_clients, |c| c.compute_capacity)
            },
            ClientSelectionStrategy::CommunicationQuality => {
                Self::top_clients_by_score(&available, num_clients, |c| c.communication_quality)
            },
        };
        self.selected_clients = selected;
        Ok(self.selected_clients.clone())
    }

    /// Returns the ids of the `count` highest-scoring clients, best first.
    ///
    /// NaN scores are mapped below every real score. This keeps the comparator a total order;
    /// `sort_by` may panic when given a comparator that is not a total order.
    fn top_clients_by_score(
        clients: &[&ClientInfo],
        count: usize,
        score: impl Fn(&ClientInfo) -> f32,
    ) -> Vec<String> {
        let mut ranked: Vec<(&ClientInfo, f32)> = clients
            .iter()
            .map(|&c| {
                let s = score(c);
                (c, if s.is_nan() { f32::NEG_INFINITY } else { s })
            })
            .collect();
        // Stable descending sort: ties keep their input order.
        ranked.sort_by(|(_, a), (_, b)| b.partial_cmp(a).unwrap_or(std::cmp::Ordering::Equal));
        ranked.into_iter().take(count).map(|(c, _)| c.client_id.clone()).collect()
    }

    /// Aggregation weight for `client_id`; clients without an explicit weight count as `1.0`.
    fn client_weight(&self, client_id: &str) -> f32 {
        self.client_weights.get(client_id).copied().unwrap_or(1.0)
    }

    /// Combines client updates into new global parameters by a weighted average.
    ///
    /// Each client's weight comes from [`set_client_weights`](Self::set_client_weights) and
    /// defaults to `1.0`. Weights are normalised by their sum over the participating clients.
    /// On success the result replaces the stored global parameters and the round counter
    /// advances. On error, state is left unchanged.
    ///
    /// # Errors
    /// Returns an error if:
    /// * `client_updates` is empty;
    /// * the total weight is zero or not finite;
    /// * clients disagree on the number of parameter tensors;
    /// * a tensor operation fails.
    pub fn aggregate_updates(
        &mut self,
        client_updates: HashMap<String, Vec<Tensor>>,
    ) -> Result<Vec<Tensor>> {
        if client_updates.is_empty() {
            return Err(anyhow!("No client updates to aggregate"));
        }

        // Visit clients in a fixed order. HashMap iteration order is randomised per process,
        // and floating-point summation is order-sensitive; sorting keeps results reproducible.
        let mut entries: Vec<(&String, &Vec<Tensor>)> = client_updates.iter().collect();
        entries.sort_unstable_by(|(a, _), (b, _)| a.cmp(b));

        let total_weight: f32 = entries.iter().map(|(id, _)| self.client_weight(id)).sum();
        if total_weight == 0.0 {
            return Err(anyhow!("Total client weight is zero"));
        }
        if !total_weight.is_finite() {
            return Err(anyhow!("Total client weight is not finite: {}", total_weight));
        }

        // Every client must send the same number of tensors. Extra tensors would otherwise
        // index past `aggregated`, and missing ones would be silently skipped.
        let param_count = entries[0].1.len();
        if let Some((client_id, updates)) = entries.iter().find(|(_, u)| u.len() != param_count)
        {
            return Err(anyhow!(
                "Client {} has {} parameters, expected {}",
                client_id,
                updates.len(),
                param_count
            ));
        }

        let mut aggregated = Vec::with_capacity(param_count);
        for first_param in entries[0].1 {
            aggregated.push(Tensor::zeros_like(first_param)?);
        }
        for (client_id, updates) in &entries {
            let weight = self.client_weight(client_id) / total_weight;
            for (acc, update) in aggregated.iter_mut().zip(updates.iter()) {
                let weighted_update = update.mul_scalar(weight)?;
                *acc = acc.add(&weighted_update)?;
            }
        }

        self.global_parameters = aggregated.clone();
        self.current_round += 1;
        Ok(aggregated)
    }

    /// Replaces all per-client aggregation weights.
    pub fn set_client_weights(&mut self, weights: HashMap<String, f32>) {
        self.client_weights = weights;
    }

    /// Current global model parameters.
    pub fn get_global_parameters(&self) -> &[Tensor] {
        &self.global_parameters
    }

    /// Number of successful aggregations performed so far.
    pub fn get_current_round(&self) -> usize {
        self.current_round
    }
}
/// FedProx coordinator: FedAvg aggregation plus a proximal penalty toward the global model.
#[derive(Debug)]
pub struct FedProx {
    /// Inner coordinator that performs client selection and aggregation.
    fedavg: FedAvg,
    /// Full FedProx configuration; `mu` drives the proximal term.
    config: FedProxConfig,
}
impl FedProx {
    /// Creates a FedProx coordinator.
    ///
    /// Selection and aggregation are delegated to an inner [`FedAvg`] built from
    /// `config.fedavg_config`.
    pub fn new(config: FedProxConfig) -> Self {
        Self {
            fedavg: FedAvg::new(config.fedavg_config.clone()),
            config,
        }
    }

    /// Returns an error unless both parameter lists contain the same number of tensors.
    fn ensure_same_count(client_params: &[Tensor], global_params: &[Tensor]) -> Result<()> {
        if client_params.len() != global_params.len() {
            return Err(anyhow!("Parameter count mismatch"));
        }
        Ok(())
    }

    /// Computes the proximal penalty `mu / 2 * sum_i norm_squared(client_i - global_i)`.
    ///
    /// # Errors
    /// Returns an error if the two lists differ in length or a tensor operation fails.
    pub fn compute_proximal_term(
        &self,
        client_params: &[Tensor],
        global_params: &[Tensor],
    ) -> Result<f32> {
        Self::ensure_same_count(client_params, global_params)?;
        let mut proximal_loss: f32 = 0.0;
        for (client_param, global_param) in client_params.iter().zip(global_params.iter()) {
            let diff = client_param.sub(global_param)?;
            let norm_sq = diff.norm_squared()?.to_scalar()?;
            proximal_loss += norm_sq;
        }
        Ok(self.config.mu * proximal_loss / 2.0)
    }

    /// Applies one gradient step of the proximal term in place:
    /// `w <- w - learning_rate * mu * (w - w_global)`.
    ///
    /// # Errors
    /// Returns an error if the two lists differ in length. The check runs before any tensor is
    /// modified; previously, zip silently truncated and left trailing parameters untouched.
    /// Also returns an error if a tensor operation fails.
    pub fn apply_proximal_update(
        &self,
        client_params: &mut [Tensor],
        global_params: &[Tensor],
        learning_rate: f32,
    ) -> Result<()> {
        Self::ensure_same_count(client_params, global_params)?;
        for (client_param, global_param) in client_params.iter_mut().zip(global_params.iter()) {
            let diff = client_param.sub(global_param)?;
            let proximal_grad = diff.mul_scalar(self.config.mu)?;
            let update = proximal_grad.mul_scalar(learning_rate)?;
            *client_param = client_param.sub(&update)?;
        }
        Ok(())
    }

    /// Replaces the global parameters held by the inner FedAvg coordinator.
    pub fn initialize_global_parameters(&mut self, parameters: Vec<Tensor>) {
        self.fedavg.initialize_global_parameters(parameters);
    }

    /// Sets per-client aggregation weights; clients without an entry count as `1.0`.
    pub fn set_client_weights(&mut self, weights: HashMap<String, f32>) {
        self.fedavg.set_client_weights(weights);
    }

    /// Delegates to [`FedAvg::select_clients`].
    pub fn select_clients(
        &mut self,
        available_clients: &[ClientInfo],
        strategy: ClientSelectionStrategy,
    ) -> Result<Vec<String>> {
        self.fedavg.select_clients(available_clients, strategy)
    }

    /// Delegates to [`FedAvg::aggregate_updates`].
    pub fn aggregate_updates(
        &mut self,
        client_updates: HashMap<String, Vec<Tensor>>,
    ) -> Result<Vec<Tensor>> {
        self.fedavg.aggregate_updates(client_updates)
    }

    /// Current global model parameters of the inner coordinator.
    pub fn get_global_parameters(&self) -> &[Tensor] {
        self.fedavg.get_global_parameters()
    }

    /// Number of successful aggregations performed by the inner coordinator.
    pub fn get_current_round(&self) -> usize {
        self.fedavg.get_current_round()
    }
}
/// Adds calibrated random noise to model parameters.
///
/// Derives `Debug`, consistent with the other coordinators in this module (`StdRng` and the
/// config are both `Debug`).
#[derive(Debug)]
pub struct DifferentialPrivacy {
    /// Privacy parameters and mechanism choice.
    config: DifferentialPrivacyConfig,
    /// Source of noise samples.
    rng: StdRng,
}
impl DifferentialPrivacy {
    /// Creates a noise generator for `config`.
    ///
    /// NOTE(review): the RNG is seeded with the constant `42`, so the noise sequence is the same
    /// on every run. That suits tests, but anyone who knows the seed can reproduce and subtract
    /// the noise. Consider an entropy-seeded RNG for real deployments.
    pub fn new(config: DifferentialPrivacyConfig) -> Self {
        Self {
            config,
            rng: StdRng::seed_from_u64(42),
        }
    }

    /// Adds freshly sampled noise to every tensor in `parameters`, in place.
    ///
    /// # Errors
    /// Returns an error if:
    /// * the configuration is invalid (checked before any tensor is modified);
    /// * the noise distribution cannot be built;
    /// * a tensor operation fails.
    pub fn add_noise(&mut self, parameters: &mut [Tensor]) -> Result<()> {
        let noise_scale = self.compute_noise_scale()?;
        for param in parameters.iter_mut() {
            let noise = self.generate_noise_tensor(param, noise_scale)?;
            *param = param.add(&noise)?;
        }
        Ok(())
    }

    /// Computes the noise scale for the configured mechanism.
    ///
    /// * Gaussian: `sigma = sqrt(2 ln(1.25 / delta)) * sensitivity / epsilon`. This is the
    ///   classic calibration from Dwork & Roth, Thm 3.22, whose proof assumes `epsilon < 1`.
    /// * Laplace: `b = sensitivity / epsilon`.
    ///
    /// # Errors
    /// Rejects:
    /// * an `epsilon` that is non-positive or non-finite;
    /// * a `sensitivity` that is negative or non-finite;
    /// * a Gaussian `delta` outside `(0, 1)`;
    /// * a resulting scale that is non-finite.
    ///
    /// Without these checks `epsilon = 0` yields infinite noise, and `delta > 1.25` yields a
    /// NaN scale.
    fn compute_noise_scale(&self) -> Result<f32> {
        let epsilon = self.config.epsilon;
        let delta = self.config.delta;
        let sensitivity = self.config.sensitivity;

        let epsilon_ok = epsilon > 0.0 && epsilon.is_finite();
        if !epsilon_ok {
            return Err(anyhow!("epsilon must be a positive finite number, got {}", epsilon));
        }
        let sensitivity_ok = sensitivity >= 0.0 && sensitivity.is_finite();
        if !sensitivity_ok {
            return Err(anyhow!(
                "sensitivity must be a non-negative finite number, got {}",
                sensitivity
            ));
        }

        let scale = match self.config.noise_mechanism {
            NoiseMechanism::Gaussian => {
                let delta_ok = delta > 0.0 && delta < 1.0;
                if !delta_ok {
                    return Err(anyhow!("delta must lie in (0, 1), got {}", delta));
                }
                let ln_term = (1.25 / delta).ln();
                (2.0 * ln_term).sqrt() * sensitivity / epsilon
            },
            NoiseMechanism::Laplace => sensitivity / epsilon,
        };
        if !scale.is_finite() {
            return Err(anyhow!("Noise scale overflowed: {}", scale));
        }
        Ok(scale)
    }

    /// Builds a noise tensor with the same shape as `reference`. `reference`'s values are
    /// ignored.
    ///
    /// Laplace samples are generated as a random sign times an exponential draw with rate
    /// `1 / scale`.
    fn generate_noise_tensor(&mut self, reference: &Tensor, scale: f32) -> Result<Tensor> {
        let shape = reference.shape();
        let numel: usize = shape.iter().product();
        let mut noise_data: Vec<f32> = Vec::with_capacity(numel);
        match self.config.noise_mechanism {
            NoiseMechanism::Gaussian => {
                use scirs2_core::random::{Distribution, Normal};
                let normal = Normal::new(0.0, scale)
                    .map_err(|e| anyhow!("Normal distribution error: {}", e))?;
                for _ in 0..numel {
                    noise_data.push(normal.sample(&mut self.rng));
                }
            },
            NoiseMechanism::Laplace => {
                use scirs2_core::random::{Distribution, Exp};
                let exp_dist = Exp::new(1.0 / scale)
                    .map_err(|e| anyhow!("Exponential distribution error: {}", e))?;
                for _ in 0..numel {
                    let sign = if self.rng.random::<bool>() { 1.0 } else { -1.0 };
                    let exp_sample = exp_dist.sample(&mut self.rng);
                    noise_data.push(sign * exp_sample);
                }
            },
        }
        Ok(Tensor::from_data(noise_data, &shape.to_vec())?)
    }
}
/// Threshold-gated aggregation of masked client updates.
#[derive(Debug)]
pub struct SecureAggregation {
    /// Minimum number of client updates required by `secure_aggregate`.
    threshold: usize,
    // `total_clients` is only recorded at construction (and inspected by tests); no
    // aggregation logic reads it yet, so silence the dead-code lint explicitly.
    #[allow(dead_code)]
    total_clients: usize,
}
impl SecureAggregation {
    /// Parameter shapes used by [`generate_masks`](Self::generate_masks).
    const DEFAULT_MASK_SHAPES: [&[usize]; 4] = [&[100, 50], &[50], &[50, 20], &[20]];

    /// Creates an aggregator that needs at least `threshold` of `total_clients` updates.
    ///
    /// # Errors
    /// Returns an error if `threshold > total_clients`.
    pub fn new(threshold: usize, total_clients: usize) -> Result<Self> {
        if threshold > total_clients {
            return Err(anyhow!("Threshold cannot exceed total clients"));
        }
        Ok(Self {
            threshold,
            total_clients,
        })
    }

    /// Derives a 32-byte RNG seed from the string `"{client_id}-{round}"`.
    ///
    /// The string's bytes are XOR-folded into the seed (`seed[i % 32] ^= byte`). For strings of
    /// at most 32 bytes this equals the earlier plain copy. Truncation, by contrast, dropped the
    /// round entirely for client ids of 32+ bytes, so their masks repeated every round. Folding
    /// always mixes the round in. Distinct long ids can still collide; this is not a hash.
    fn mask_seed(client_id: &str, round: usize) -> [u8; 32] {
        let mut seed = [0u8; 32];
        let material = format!("{}-{}", client_id, round);
        for (i, byte) in material.bytes().enumerate() {
            seed[i % seed.len()] ^= byte;
        }
        seed
    }

    /// Generates deterministic masks for `client_id` in `round`, using the default shapes
    /// `[100, 50]`, `[50]`, `[50, 20]` and `[20]`.
    ///
    /// NOTE(review): the seed depends only on the client id and round, so anyone can
    /// regenerate these masks. They offer no confidentiality without a shared secret.
    pub fn generate_masks(&self, client_id: &str, round: usize) -> Result<Vec<Tensor>> {
        self.generate_masks_for_shapes(client_id, round, &Self::DEFAULT_MASK_SHAPES[..])
    }

    /// Generates one mask per entry of `shapes`.
    ///
    /// Elements are drawn uniformly from `[-1, 1)` by an RNG seeded from
    /// `(client_id, round)`, so the same inputs always give the same masks.
    ///
    /// # Errors
    /// Returns an error if a tensor cannot be built from the generated data.
    pub fn generate_masks_for_shapes<S: AsRef<[usize]>>(
        &self,
        client_id: &str,
        round: usize,
        shapes: &[S],
    ) -> Result<Vec<Tensor>> {
        let mut rng = StdRng::from_seed(Self::mask_seed(client_id, round));
        let mut masks = Vec::with_capacity(shapes.len());
        for shape in shapes {
            let shape = shape.as_ref();
            let mask_size = shape.iter().product::<usize>();
            let mask_data: Vec<f32> =
                (0..mask_size).map(|_| rng.random_range(-1.0..1.0)).collect();
            masks.push(Tensor::from_data(mask_data, &shape.to_vec())?);
        }
        Ok(masks)
    }

    /// Averages the masked updates element-wise across clients.
    ///
    /// NOTE(review): masks are not removed here. Masks from `generate_masks` are independent
    /// per client and do not cancel, so unless callers unmask elsewhere the result still
    /// contains the averaged masks — confirm intended protocol.
    ///
    /// # Errors
    /// Returns an error if:
    /// * fewer than `threshold` updates are supplied;
    /// * clients disagree on parameter count or per-parameter shape;
    /// * a tensor operation fails.
    pub fn secure_aggregate(
        &self,
        masked_updates: HashMap<String, Vec<Tensor>>,
    ) -> Result<Vec<Tensor>> {
        if masked_updates.len() < self.threshold {
            return Err(anyhow!("Not enough clients for secure aggregation"));
        }

        // Fixed client order keeps the floating-point sums reproducible (HashMap order is
        // randomised per process).
        let mut entries: Vec<(&String, &Vec<Tensor>)> = masked_updates.iter().collect();
        entries.sort_unstable_by(|(a, _), (b, _)| a.cmp(b));

        let client_count = entries.len() as f32;
        let parameter_count = entries.first().map(|(_, update)| update.len()).unwrap_or(0);
        for (client_id, update) in &entries {
            if update.len() != parameter_count {
                return Err(anyhow!(
                    "Client {} has {} parameters, expected {}",
                    client_id,
                    update.len(),
                    parameter_count
                ));
            }
        }

        let mut result = Vec::with_capacity(parameter_count);
        for param_idx in 0..parameter_count {
            // `parameter_count > 0` implies at least one entry exists.
            let expected_shape = entries[0].1[param_idx].shape();
            let mut aggregated_param = Tensor::zeros(&expected_shape)?;
            for (client_id, update) in &entries {
                let param_update = &update[param_idx];
                if param_update.shape() != expected_shape {
                    return Err(anyhow!(
                        "Client {} parameter {} has shape {:?}, expected {:?}",
                        client_id,
                        param_idx,
                        param_update.shape(),
                        expected_shape
                    ));
                }
                aggregated_param = aggregated_param.add(param_update)?;
            }
            result.push(aggregated_param.div_scalar(client_count)?);
        }
        Ok(result)
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds an available client with the given ranking attributes.
    fn client(
        id: &str,
        data_size: usize,
        compute_capacity: f32,
        communication_quality: f32,
    ) -> ClientInfo {
        ClientInfo {
            client_id: id.to_string(),
            data_size,
            compute_capacity,
            communication_quality,
            available: true,
        }
    }

    #[test]
    fn test_fedavg_config_default() {
        let FedAvgConfig {
            local_epochs,
            client_fraction,
            min_clients,
            ..
        } = FedAvgConfig::default();
        assert_eq!(local_epochs, 5);
        assert_eq!(client_fraction, 0.1);
        assert_eq!(min_clients, 2);
    }

    #[test]
    fn test_fedprox_config_default() {
        let cfg = FedProxConfig::default();
        assert_eq!(cfg.fedavg_config.local_epochs, 5);
        assert_eq!(cfg.mu, 0.01);
    }

    #[test]
    fn test_differential_privacy_config() {
        let cfg = DifferentialPrivacyConfig::default();
        assert!(matches!(cfg.noise_mechanism, NoiseMechanism::Gaussian));
        assert_eq!((cfg.epsilon, cfg.delta), (1.0, 1e-5));
    }

    #[test]
    fn test_client_selection_strategies() {
        let clients = [client("client1", 100, 0.8, 0.9), client("client2", 200, 0.6, 0.7)];
        let mut fedavg = FedAvg::new(FedAvgConfig::default());
        for strategy in [ClientSelectionStrategy::Random, ClientSelectionStrategy::DataSize] {
            let selected = fedavg
                .select_clients(&clients, strategy)
                .expect("Operation failed in test");
            assert!(!selected.is_empty());
        }
    }

    #[test]
    fn test_secure_aggregation_creation() {
        let secure_agg = SecureAggregation::new(3, 5).expect("Construction failed");
        assert_eq!((secure_agg.threshold, secure_agg.total_clients), (3, 5));
        assert!(SecureAggregation::new(6, 5).is_err());
    }
}