use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use crate::federated_learning_v2_backup::types::*;
use trustformers_core::{Result, CoreError, Tensor};
/// Tunable parameters for a secure-aggregation round.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SecureAggregationConfig {
/// Which aggregation protocol `SecureAggregator::aggregate` dispatches to.
pub protocol: SecureAggregationProtocol,
/// Minimum number of buffered updates required before aggregation runs.
pub min_participants: u32,
/// Hard cap on buffered participant updates.
pub max_participants: u32,
/// Observed dropout rate above which `apply_dropout_resilience` rescales
/// the aggregate (see `SecureAggregator::apply_dropout_resilience`).
pub dropout_tolerance: f64,
/// When true, the FATE path quantizes the aggregate bytes.
pub use_quantization: bool,
/// Bit width used by `apply_quantization` to rescale byte values.
pub quantization_bits: u8,
/// When true, the Flamingo path applies `apply_secure_shuffling`.
pub secure_shuffling: bool,
/// NOTE(review): not read by any method in this file — presumably consumed
/// by callers performing result verification; confirm before removing.
pub verification_threshold: f64,
}
/// Per-participant weighting used when combining model updates.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AggregationWeights {
/// Weight per participant id; participants absent from this map are
/// treated as weight 1.0 by `SecureAggregator::basic_secure_aggregation`.
pub participant_weights: HashMap<String, f64>,
/// Sum of all participant weights, maintained by `recalculate_total_weight`.
pub total_weight: f64,
/// How `SecureAggregator::apply_normalization` rescales the aggregate.
pub normalization_strategy: WeightNormalizationStrategy,
}
/// Strategy applied to the aggregated bytes after weighted summation.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum WeightNormalizationStrategy {
/// Leave the aggregate unchanged.
None,
/// Divide every byte by `AggregationWeights::total_weight`.
SumToOne,
/// Divide every byte by the number of buffered participant updates.
ByParticipantCount,
/// Divide by a data-size estimate (currently a hard-coded placeholder).
ByDataSize,
/// Divide by an update-quality estimate (currently a hard-coded placeholder).
ByUpdateQuality,
}
/// Stateful aggregator that buffers participant updates/masks and combines
/// them according to the configured protocol.
#[derive(Debug)]
pub struct SecureAggregator {
// Round configuration captured at construction time.
config: SecureAggregationConfig,
// Serialized model update per participant id.
participant_updates: HashMap<String, Vec<u8>>,
// Per-participant masks XOR-ed into the aggregate by mask-aware protocols.
participant_masks: HashMap<String, Vec<u8>>,
// Lifecycle state; driven by `aggregate` and `reset`.
aggregation_state: AggregationState,
}
/// Lifecycle of one aggregation round:
/// `WaitingForUpdates` -> `Computing` -> `Completed` | `Failed`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum AggregationState {
/// Accepting participant updates (initial state, restored by `reset`).
WaitingForUpdates,
/// `aggregate` is running the selected protocol.
Computing,
/// Aggregation finished successfully.
Completed,
/// Aggregation returned an error.
Failed,
}
impl SecureAggregator {
    /// Creates an aggregator in the `WaitingForUpdates` state with no
    /// buffered updates or masks.
    pub fn new(config: &SecureAggregationConfig) -> Result<Self> {
        Ok(Self {
            config: config.clone(),
            participant_updates: HashMap::new(),
            participant_masks: HashMap::new(),
            aggregation_state: AggregationState::WaitingForUpdates,
        })
    }

    /// Buffers a serialized model update for `participant_id`.
    ///
    /// Re-submitting for a participant overwrites the previous update and
    /// never counts against the capacity limit.
    ///
    /// # Errors
    /// Fails when accepting a *new* participant would exceed
    /// `config.max_participants`.
    pub fn add_participant_update(&mut self, participant_id: String, update: Vec<u8>) -> Result<()> {
        // Only a genuinely new participant grows the set; a re-submission at
        // capacity must not be rejected.
        if !self.participant_updates.contains_key(&participant_id)
            && self.participant_updates.len() >= self.config.max_participants as usize
        {
            return Err(TrustformersError::InvalidConfiguration(
                "Maximum participants exceeded".to_string(),
            )
            .into());
        }
        self.participant_updates.insert(participant_id, update);
        Ok(())
    }

    /// Buffers a mask for `participant_id`; masks are XOR-ed into the
    /// aggregate by the mask-aware protocols (see `apply_secure_masks`).
    pub fn add_participant_mask(&mut self, participant_id: String, mask: Vec<u8>) -> Result<()> {
        self.participant_masks.insert(participant_id, mask);
        Ok(())
    }

    /// Runs the configured protocol over the buffered updates and returns the
    /// aggregated byte vector, transitioning the state to `Completed` or
    /// `Failed` accordingly.
    ///
    /// # Errors
    /// Fails when fewer than `config.min_participants` updates are buffered,
    /// or when the selected protocol rejects its inputs.
    pub fn aggregate(&mut self, weights: &AggregationWeights) -> Result<Vec<u8>> {
        if self.participant_updates.len() < self.config.min_participants as usize {
            return Err(TrustformersError::InvalidConfiguration(
                "Insufficient participants for aggregation".to_string(),
            )
            .into());
        }
        self.aggregation_state = AggregationState::Computing;
        let result = match self.config.protocol {
            SecureAggregationProtocol::BasicSecureAggregation => {
                self.basic_secure_aggregation(weights)
            }
            SecureAggregationProtocol::FederatedSecureAggregation => {
                self.federated_secure_aggregation(weights)
            }
            SecureAggregationProtocol::PrivateFederatedLearning => {
                self.private_federated_learning_aggregation(weights)
            }
            SecureAggregationProtocol::SecAggPlus => self.secagg_plus_aggregation(weights),
            SecureAggregationProtocol::Flamingo => self.flamingo_aggregation(weights),
            SecureAggregationProtocol::FATE => self.fate_aggregation(weights),
        };
        match result {
            Ok(aggregated) => {
                self.aggregation_state = AggregationState::Completed;
                Ok(aggregated)
            }
            Err(e) => {
                self.aggregation_state = AggregationState::Failed;
                Err(e)
            }
        }
    }

    /// Weighted byte-wise sum of all buffered updates, then normalization.
    ///
    /// NOTE(review): this operates on raw serialized bytes with saturating
    /// arithmetic — a simulation of aggregation, not field arithmetic over
    /// decoded parameters; confirm downstream consumers expect this.
    fn basic_secure_aggregation(&self, weights: &AggregationWeights) -> Result<Vec<u8>> {
        if self.participant_updates.is_empty() {
            return Err(TrustformersError::InvalidConfiguration(
                "No participant updates available".to_string(),
            )
            .into());
        }
        // Defensive: unreachable after the is_empty check above.
        let first_update = self
            .participant_updates
            .values()
            .next()
            .ok_or_else(|| TrustformersError::other("No participant updates available".to_string()))?;
        let update_size = first_update.len();
        let mut aggregated = vec![0u8; update_size];
        for (participant_id, update) in &self.participant_updates {
            // Participants without an explicit weight contribute with weight 1.
            let weight = weights.participant_weights.get(participant_id).unwrap_or(&1.0);
            if update.len() != update_size {
                return Err(TrustformersError::InvalidConfiguration(
                    "Update size mismatch".to_string(),
                )
                .into());
            }
            for (i, &byte) in update.iter().enumerate() {
                // The float->u8 cast saturates, so oversized weights clamp
                // to 255 rather than wrapping.
                let weighted_value = (byte as f64 * weight) as u8;
                aggregated[i] = aggregated[i].saturating_add(weighted_value);
            }
        }
        self.apply_normalization(&mut aggregated, weights)?;
        Ok(aggregated)
    }

    /// Basic aggregation plus XOR-ing of any buffered participant masks.
    fn federated_secure_aggregation(&self, weights: &AggregationWeights) -> Result<Vec<u8>> {
        let mut aggregated = self.basic_secure_aggregation(weights)?;
        if !self.participant_masks.is_empty() {
            self.apply_secure_masks(&mut aggregated)?;
        }
        Ok(aggregated)
    }

    /// Masked aggregation plus a privacy transformation pass.
    fn private_federated_learning_aggregation(&self, weights: &AggregationWeights) -> Result<Vec<u8>> {
        let mut aggregated = self.federated_secure_aggregation(weights)?;
        self.apply_privacy_transformations(&mut aggregated)?;
        Ok(aggregated)
    }

    /// Private aggregation plus dropout compensation.
    fn secagg_plus_aggregation(&self, weights: &AggregationWeights) -> Result<Vec<u8>> {
        let mut aggregated = self.private_federated_learning_aggregation(weights)?;
        self.apply_dropout_resilience(&mut aggregated)?;
        Ok(aggregated)
    }

    /// SecAgg+ aggregation plus optional secure shuffling.
    fn flamingo_aggregation(&self, weights: &AggregationWeights) -> Result<Vec<u8>> {
        let mut aggregated = self.secagg_plus_aggregation(weights)?;
        if self.config.secure_shuffling {
            self.apply_secure_shuffling(&mut aggregated)?;
        }
        Ok(aggregated)
    }

    /// Flamingo aggregation plus FATE-specific post-processing (quantization).
    fn fate_aggregation(&self, weights: &AggregationWeights) -> Result<Vec<u8>> {
        let mut aggregated = self.flamingo_aggregation(weights)?;
        self.apply_fate_optimizations(&mut aggregated)?;
        Ok(aggregated)
    }

    /// XORs every size-matching participant mask into the aggregate.
    /// Masks of the wrong length are silently skipped (best-effort).
    fn apply_secure_masks(&self, aggregated: &mut [u8]) -> Result<()> {
        for mask in self.participant_masks.values() {
            if mask.len() != aggregated.len() {
                continue;
            }
            for (i, &mask_byte) in mask.iter().enumerate() {
                aggregated[i] ^= mask_byte;
            }
        }
        Ok(())
    }

    /// Adds a fixed perturbation to every byte.
    ///
    /// NOTE(review): the "noise" is the constant 1 — a placeholder, not
    /// differential privacy; replace with calibrated noise before relying on
    /// any privacy guarantee.
    fn apply_privacy_transformations(&self, aggregated: &mut [u8]) -> Result<()> {
        for byte in aggregated.iter_mut() {
            let noise = 1u8;
            *byte = byte.saturating_add(noise);
        }
        Ok(())
    }

    /// Rescales the aggregate when the observed dropout rate (relative to
    /// `max_participants`) exceeds the configured tolerance.
    fn apply_dropout_resilience(&self, aggregated: &mut [u8]) -> Result<()> {
        let dropout_rate =
            1.0 - (self.participant_updates.len() as f64 / self.config.max_participants as f64);
        if dropout_rate > self.config.dropout_tolerance {
            for byte in aggregated.iter_mut() {
                // Saturating float->u8 cast caps the compensation at 255.
                *byte = (*byte as f64 * (1.0 + dropout_rate)) as u8;
            }
        }
        Ok(())
    }

    /// Permutes the aggregate bytes.
    ///
    /// NOTE(review): sequential adjacent swaps amount to a deterministic
    /// rotate-left by one — not a cryptographic shuffle; a real
    /// implementation needs a keyed random permutation.
    fn apply_secure_shuffling(&self, aggregated: &mut [u8]) -> Result<()> {
        let len = aggregated.len();
        for i in 0..len {
            let j = (i + 1) % len;
            aggregated.swap(i, j);
        }
        Ok(())
    }

    /// FATE post-processing: quantization when enabled in the config.
    fn apply_fate_optimizations(&self, aggregated: &mut [u8]) -> Result<()> {
        if self.config.use_quantization {
            self.apply_quantization(aggregated)?;
        }
        Ok(())
    }

    /// Rescales each byte from 0..=255 onto 0..=(2^bits - 1).
    fn apply_quantization(&self, aggregated: &mut [u8]) -> Result<()> {
        // Clamp to the meaningful range for byte-valued data: a shift by
        // >= 32 would panic, bits in 25..=31 overflow the u32 multiply below,
        // and 0 bits would collapse every value to zero.
        let bits = u32::from(self.config.quantization_bits).clamp(1, 8);
        let levels = (1u32 << bits) - 1;
        for byte in aggregated.iter_mut() {
            // Identity mapping when bits == 8 (levels == 255).
            *byte = (u32::from(*byte) * levels / 255) as u8;
        }
        Ok(())
    }

    /// Applies the strategy from `weights.normalization_strategy` in place.
    ///
    /// # Errors
    /// `SumToOne` fails when `total_weight` is not positive (dividing by zero
    /// would otherwise saturate every byte to 255 via the float cast).
    fn apply_normalization(&self, aggregated: &mut [u8], weights: &AggregationWeights) -> Result<()> {
        match weights.normalization_strategy {
            WeightNormalizationStrategy::None => {}
            WeightNormalizationStrategy::SumToOne => {
                if weights.total_weight <= 0.0 {
                    return Err(TrustformersError::InvalidConfiguration(
                        "Total weight must be positive for SumToOne normalization".to_string(),
                    )
                    .into());
                }
                for byte in aggregated.iter_mut() {
                    *byte = (*byte as f64 / weights.total_weight) as u8;
                }
            }
            WeightNormalizationStrategy::ByParticipantCount => {
                // Callers reach this only after the non-empty check in
                // basic_secure_aggregation, so count >= 1.
                let count = self.participant_updates.len() as f64;
                for byte in aggregated.iter_mut() {
                    *byte = (*byte as f64 / count) as u8;
                }
            }
            WeightNormalizationStrategy::ByDataSize => {
                // Placeholder average data size; wire up real per-participant
                // sizes when available.
                let avg_data_size = 1000.0;
                for byte in aggregated.iter_mut() {
                    *byte = (*byte as f64 / avg_data_size) as u8;
                }
            }
            WeightNormalizationStrategy::ByUpdateQuality => {
                // Placeholder average quality score.
                let avg_quality = 0.8;
                for byte in aggregated.iter_mut() {
                    *byte = (*byte as f64 / avg_quality) as u8;
                }
            }
        }
        Ok(())
    }

    /// Current lifecycle state of the round.
    pub fn get_state(&self) -> AggregationState {
        self.aggregation_state
    }

    /// Number of participants with a buffered update.
    pub fn get_participant_count(&self) -> usize {
        self.participant_updates.len()
    }

    /// Shallow sanity check of an aggregation result: non-empty and the same
    /// length as the buffered updates.
    ///
    /// NOTE(review): does not consult `config.verification_threshold` —
    /// confirm whether threshold-based verification belongs here.
    pub fn verify_aggregation(&self, aggregated_result: &[u8]) -> Result<bool> {
        if aggregated_result.is_empty() {
            return Ok(false);
        }
        if let Some(first_update) = self.participant_updates.values().next() {
            if aggregated_result.len() != first_update.len() {
                return Ok(false);
            }
        }
        Ok(true)
    }

    /// Clears all buffered updates/masks and returns to `WaitingForUpdates`.
    pub fn reset(&mut self) {
        self.participant_updates.clear();
        self.participant_masks.clear();
        self.aggregation_state = AggregationState::WaitingForUpdates;
    }
}
impl AggregationWeights {
    /// Creates an empty weight set using the given normalization strategy.
    pub fn new(normalization_strategy: WeightNormalizationStrategy) -> Self {
        Self {
            participant_weights: HashMap::new(),
            total_weight: 0.0,
            normalization_strategy,
        }
    }

    /// Inserts (or replaces) a participant's weight and refreshes the total.
    pub fn add_participant(&mut self, participant_id: String, weight: f64) {
        self.participant_weights.insert(participant_id, weight);
        self.recalculate_total_weight();
    }

    /// Replaces the weight of an existing participant.
    ///
    /// # Errors
    /// Fails when `participant_id` has no weight registered.
    pub fn update_participant_weight(&mut self, participant_id: &str, weight: f64) -> Result<()> {
        if self.participant_weights.contains_key(participant_id) {
            self.participant_weights.insert(participant_id.to_string(), weight);
            self.recalculate_total_weight();
            Ok(())
        } else {
            // `.into()` keeps error construction consistent with the rest of
            // the file and converts into the alias's error type.
            Err(TrustformersError::InvalidConfiguration(format!(
                "Participant {} not found",
                participant_id
            ))
            .into())
        }
    }

    /// Removes a participant's weight and refreshes the total.
    ///
    /// # Errors
    /// Fails when `participant_id` has no weight registered.
    pub fn remove_participant(&mut self, participant_id: &str) -> Result<()> {
        if self.participant_weights.remove(participant_id).is_some() {
            self.recalculate_total_weight();
            Ok(())
        } else {
            Err(TrustformersError::InvalidConfiguration(format!(
                "Participant {} not found",
                participant_id
            ))
            .into())
        }
    }

    /// Recomputes `total_weight` as the sum of all registered weights.
    fn recalculate_total_weight(&mut self) {
        self.total_weight = self.participant_weights.values().sum();
    }

    /// Weight registered for `participant_id`, if any.
    pub fn get_participant_weight(&self, participant_id: &str) -> Option<f64> {
        self.participant_weights.get(participant_id).copied()
    }

    /// Weights rescaled so they sum to 1.0. When the total is exactly zero
    /// the raw (unnormalized) weights are returned unchanged.
    pub fn get_normalized_weights(&self) -> HashMap<String, f64> {
        if self.total_weight == 0.0 {
            return self.participant_weights.clone();
        }
        self.participant_weights
            .iter()
            .map(|(id, &weight)| (id.clone(), weight / self.total_weight))
            .collect()
    }
}
impl Default for SecureAggregationConfig {
fn default() -> Self {
Self {
protocol: SecureAggregationProtocol::default(),
min_participants: 2,
max_participants: 1000,
dropout_tolerance: 0.3,
use_quantization: true,
quantization_bits: 8,
secure_shuffling: true,
verification_threshold: 0.95,
}
}
}
impl Default for WeightNormalizationStrategy {
fn default() -> Self {
Self::SumToOne
}
}