use super::config::AggregationStrategy;
use super::participant::LocalUpdate;
use anyhow::Result;
use scirs2_core::ndarray_ext::Array2;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use uuid::Uuid;
/// Aggregates model parameter updates from federated-learning participants.
///
/// Combines per-participant `LocalUpdate`s into a single map of parameter
/// matrices according to the configured strategy, weighting contributions
/// via `weighting_scheme` and optionally screening anomalous updates first.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AggregationEngine {
    /// Which aggregation algorithm to run (FedAvg, robust, secure, ...).
    pub strategy: AggregationStrategy,
    /// Free-form numeric tuning knobs; not read by any method in this file.
    pub parameters: HashMap<String, f64>,
    /// How participant contributions are weighted during averaging.
    pub weighting_scheme: WeightingScheme,
    /// Configuration for pre-aggregation outlier filtering.
    pub outlier_detection: OutlierDetection,
}
impl AggregationEngine {
/// Creates an engine for `strategy` with default settings: sample-size
/// weighting, default outlier detection, and no extra parameters.
pub fn new(strategy: AggregationStrategy) -> Self {
    let parameters = HashMap::new();
    let weighting_scheme = WeightingScheme::SampleSize;
    let outlier_detection = OutlierDetection::default();
    Self {
        strategy,
        parameters,
        weighting_scheme,
        outlier_detection,
    }
}
pub fn with_weighting_scheme(mut self, scheme: WeightingScheme) -> Self {
self.weighting_scheme = scheme;
self
}
pub fn with_outlier_detection(mut self, detection: OutlierDetection) -> Self {
self.outlier_detection = detection;
self
}
/// Aggregates a batch of participant updates into one parameter map.
///
/// Steps: (1) short-circuit on an empty batch, (2) optionally drop outlier
/// updates, (3) compute per-participant weights, (4) dispatch to the
/// algorithm selected by `self.strategy`.
///
/// # Errors
/// Propagates failures from outlier filtering, weighting, or the chosen
/// aggregation routine (e.g. when the weights sum to zero).
pub fn aggregate_updates(
    &self,
    updates: &[LocalUpdate],
) -> Result<HashMap<String, Array2<f32>>> {
    if updates.is_empty() {
        return Ok(HashMap::new());
    }
    // Optionally screen out anomalous updates before aggregating.
    let survivors = if self.outlier_detection.enabled {
        self.filter_outliers(updates)?
    } else {
        updates.to_vec()
    };
    let weights = self.calculate_weights(&survivors)?;
    // Dispatch to the strategy chosen at construction time.
    match self.strategy {
        AggregationStrategy::FederatedAveraging => self.federated_averaging(&survivors, &weights),
        AggregationStrategy::WeightedAveraging => self.weighted_averaging(&survivors, &weights),
        AggregationStrategy::SecureAggregation => self.secure_aggregation(&survivors, &weights),
        AggregationStrategy::RobustAggregation => self.robust_aggregation(&survivors, &weights),
        AggregationStrategy::PersonalizedAggregation => {
            self.personalized_aggregation(&survivors, &weights)
        }
        AggregationStrategy::HierarchicalAggregation => {
            self.hierarchical_aggregation(&survivors, &weights)
        }
    }
}
/// Federated averaging (FedAvg): a weighted average of client updates.
/// With sample-size weights this is exactly weighted averaging, so the
/// implementation delegates to `weighted_averaging`.
fn federated_averaging(
    &self,
    updates: &[LocalUpdate],
    weights: &HashMap<Uuid, f64>,
) -> Result<HashMap<String, Array2<f32>>> {
    self.weighted_averaging(updates, weights)
}
/// Computes the weight-normalized average of all parameter updates.
///
/// Each participant's matrices are scaled by `weight / total_weight` and
/// summed element-wise per parameter name.
///
/// Fix: the accumulator is now seeded from the union of parameter names
/// across *all* updates — previously only the first update's names were
/// used, so a parameter absent from the first update was silently dropped.
///
/// # Errors
/// Returns an error when the weights sum to zero (nothing to normalize by).
fn weighted_averaging(
    &self,
    updates: &[LocalUpdate],
    weights: &HashMap<Uuid, f64>,
) -> Result<HashMap<String, Array2<f32>>> {
    let total_weight: f64 = weights.values().sum();
    if total_weight == 0.0 {
        return Err(anyhow::anyhow!("Total weight is zero"));
    }
    // Seed a zero matrix for every parameter seen in any update.
    let mut aggregated: HashMap<String, Array2<f32>> = HashMap::new();
    for update in updates {
        for (param_name, param_values) in &update.parameter_updates {
            aggregated
                .entry(param_name.clone())
                .or_insert_with(|| Array2::zeros(param_values.raw_dim()));
        }
    }
    // Accumulate each update scaled by its normalized weight; participants
    // missing from `weights` contribute nothing.
    for update in updates {
        let weight = weights.get(&update.participant_id).unwrap_or(&0.0) / total_weight;
        for (param_name, param_values) in &update.parameter_updates {
            if let Some(acc) = aggregated.get_mut(param_name) {
                // NOTE(review): assumes every update reports the same shape
                // per parameter name — ndarray panics on a mismatch; confirm.
                *acc = &*acc + &(param_values * weight as f32);
            }
        }
    }
    Ok(aggregated)
}
/// Placeholder for secure aggregation (e.g. masking / secret sharing).
/// No cryptographic protection is applied yet: this currently performs a
/// plain weighted average. TODO: implement a real secure-aggregation protocol.
fn secure_aggregation(
    &self,
    updates: &[LocalUpdate],
    weights: &HashMap<Uuid, f64>,
) -> Result<HashMap<String, Array2<f32>>> {
    self.weighted_averaging(updates, weights)
}
/// Byzantine-robust aggregation: per parameter, picks a representative via
/// Krum-style selection (more than 2 candidates) or an element-wise median
/// (2 or fewer), limiting the influence of anomalous updates.
///
/// Fix: the call arguments were mojibake-corrupted (`¶m_matrices`,
/// i.e. `&param` mangled into `¶`), which did not compile; they are
/// restored to `&param_matrices`.
///
/// NOTE(review): only parameter names present in the *first* update are
/// aggregated — confirm whether that restriction is intended.
fn robust_aggregation(
    &self,
    updates: &[LocalUpdate],
    _weights: &HashMap<Uuid, f64>,
) -> Result<HashMap<String, Array2<f32>>> {
    let mut aggregated = HashMap::new();
    if let Some(first_update) = updates.first() {
        for param_name in first_update.parameter_updates.keys() {
            // Collect this parameter's matrix from every update that has it.
            let param_matrices: Vec<&Array2<f32>> = updates
                .iter()
                .filter_map(|update| update.parameter_updates.get(param_name))
                .collect();
            if param_matrices.is_empty() {
                continue;
            }
            let aggregated_param = if param_matrices.len() > 2 {
                self.krum_aggregation(&param_matrices)?
            } else {
                self.median_aggregation(&param_matrices)?
            };
            aggregated.insert(param_name.clone(), aggregated_param);
        }
    }
    Ok(aggregated)
}
/// Placeholder for personalized aggregation (per-participant models).
/// Currently identical to plain weighted averaging. TODO: implement
/// participant-specific mixing.
fn personalized_aggregation(
    &self,
    updates: &[LocalUpdate],
    weights: &HashMap<Uuid, f64>,
) -> Result<HashMap<String, Array2<f32>>> {
    self.weighted_averaging(updates, weights)
}
/// Placeholder for hierarchical (multi-tier) aggregation.
/// Currently identical to plain weighted averaging. TODO: aggregate within
/// groups first, then across groups.
fn hierarchical_aggregation(
    &self,
    updates: &[LocalUpdate],
    weights: &HashMap<Uuid, f64>,
) -> Result<HashMap<String, Array2<f32>>> {
    self.weighted_averaging(updates, weights)
}
/// Krum-style selection: returns a clone of the candidate whose summed
/// distance to all other candidates is smallest, i.e. the most "central"
/// update. Ties keep the earliest candidate, as before.
///
/// # Errors
/// Fails when `matrices` is empty.
fn krum_aggregation(&self, matrices: &[&Array2<f32>]) -> Result<Array2<f32>> {
    if matrices.is_empty() {
        return Err(anyhow::anyhow!("No matrices to aggregate"));
    }
    // Score each candidate by its total distance to every other candidate.
    let scores: Vec<f64> = matrices
        .iter()
        .enumerate()
        .map(|(i, m)| {
            matrices
                .iter()
                .enumerate()
                .filter(|(j, _)| *j != i)
                .map(|(_, other)| self.matrix_distance(m, other))
                .sum()
        })
        .collect();
    // `min_by` returns the first minimum, matching the old loop's behavior.
    let best_idx = scores
        .iter()
        .enumerate()
        .min_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))
        .map(|(i, _)| i)
        .expect("matrices is non-empty");
    Ok(matrices[best_idx].clone())
}
/// Element-wise median across the candidate matrices: `result[[i, j]]` is
/// the median of all candidates' `[i, j]` entries (mean of the two middle
/// values for an even count), making each cell robust to outlier values.
///
/// Improvement: sorts with `sort_unstable_by(f32::total_cmp)` — a
/// deterministic total order (NaNs sort last) with no allocation — instead
/// of the previous `partial_cmp(..).unwrap_or(Equal)`, which left NaN
/// ordering arbitrary.
///
/// NOTE(review): indexes every matrix at `[i, j]` of the first matrix's
/// shape; assumes all candidates share that shape (panics otherwise).
///
/// # Errors
/// Fails when `matrices` is empty.
fn median_aggregation(&self, matrices: &[&Array2<f32>]) -> Result<Array2<f32>> {
    if matrices.is_empty() {
        return Err(anyhow::anyhow!("No matrices to aggregate"));
    }
    let shape = matrices[0].raw_dim();
    let mut result = Array2::zeros(shape);
    for i in 0..shape[0] {
        for j in 0..shape[1] {
            let mut values: Vec<f32> = matrices.iter().map(|m| m[[i, j]]).collect();
            values.sort_unstable_by(f32::total_cmp);
            let mid = values.len() / 2;
            result[[i, j]] = if values.len() % 2 == 0 {
                (values[mid - 1] + values[mid]) / 2.0
            } else {
                values[mid]
            };
        }
    }
    Ok(result)
}
/// Frobenius (L2) distance between two matrices: sqrt(sum((a - b)^2)),
/// accumulated in f64 for precision. Panics if the shapes are incompatible
/// (ndarray subtraction), same as before.
fn matrix_distance(&self, a: &Array2<f32>, b: &Array2<f32>) -> f64 {
    let diff = a - b;
    let sum_of_squares: f64 = diff.iter().map(|&x| f64::from(x) * f64::from(x)).sum();
    sum_of_squares.sqrt()
}
/// Computes a per-participant weight map under the configured scheme.
/// Built-in schemes normalize the weights to sum to 1.0; `Custom` uses the
/// caller-supplied weights as-is (missing participants get 0).
///
/// Robustness fixes over the previous version:
/// - an empty `updates` slice now returns an empty map instead of
///   producing infinite uniform weights (`1.0 / 0`);
/// - when a scheme's normalizer is zero (e.g. every participant reported
///   zero samples or zero accuracy), it falls back to uniform weights
///   instead of returning an empty map, which previously made downstream
///   averaging fail with "Total weight is zero".
fn calculate_weights(&self, updates: &[LocalUpdate]) -> Result<HashMap<Uuid, f64>> {
    if updates.is_empty() {
        return Ok(HashMap::new());
    }
    // Equal weight for everyone; also the zero-normalizer fallback.
    let uniform = |updates: &[LocalUpdate]| -> HashMap<Uuid, f64> {
        let w = 1.0 / updates.len() as f64;
        updates.iter().map(|u| (u.participant_id, w)).collect()
    };
    let weights = match &self.weighting_scheme {
        WeightingScheme::Uniform => uniform(updates),
        WeightingScheme::SampleSize => {
            let total: usize = updates.iter().map(|u| u.num_samples).sum();
            if total > 0 {
                updates
                    .iter()
                    .map(|u| (u.participant_id, u.num_samples as f64 / total as f64))
                    .collect()
            } else {
                uniform(updates)
            }
        }
        WeightingScheme::DataQuality => {
            let total: f64 = updates
                .iter()
                .map(|u| u.training_stats.local_accuracy)
                .sum();
            if total > 0.0 {
                updates
                    .iter()
                    .map(|u| (u.participant_id, u.training_stats.local_accuracy / total))
                    .collect()
            } else {
                uniform(updates)
            }
        }
        WeightingScheme::ComputeContribution => {
            // Faster training (smaller time) earns a larger weight.
            let contribution =
                |u: &LocalUpdate| 1.0 / (u.training_stats.training_time_seconds + 1.0);
            let total: f64 = updates.iter().map(|u| contribution(u)).sum();
            if total > 0.0 {
                updates
                    .iter()
                    .map(|u| (u.participant_id, contribution(u) / total))
                    .collect()
            } else {
                uniform(updates)
            }
        }
        // TODO: wire in real trust scores; currently equivalent to Uniform.
        WeightingScheme::TrustScore => uniform(updates),
        WeightingScheme::Custom {
            weights: custom_weights,
        } => updates
            .iter()
            .map(|u| {
                let w = custom_weights.get(&u.participant_id).unwrap_or(&0.0);
                (u.participant_id, *w)
            })
            .collect(),
    };
    Ok(weights)
}
/// Dispatches outlier filtering to the configured detection method and
/// returns the surviving updates.
/// NOTE(review): every method currently routes to the statistical-distance
/// filter; the clustering, isolation-forest, and Byzantine variants are
/// placeholders.
fn filter_outliers(&self, updates: &[LocalUpdate]) -> Result<Vec<LocalUpdate>> {
    match self.outlier_detection.method {
        OutlierDetectionMethod::StatisticalDistance => {
            self.filter_statistical_outliers(updates)
        }
        OutlierDetectionMethod::Clustering => self.filter_clustering_outliers(updates),
        OutlierDetectionMethod::IsolationForest => {
            self.filter_isolation_forest_outliers(updates)
        }
        OutlierDetectionMethod::ByzantineDetection => self.filter_byzantine_outliers(updates),
    }
}
/// Drops updates whose mean distance to all peers exceeds
/// `mean + threshold * std_dev` of those per-update mean distances.
/// Fewer than 3 updates are passed through unchanged (too few to judge).
fn filter_statistical_outliers(&self, updates: &[LocalUpdate]) -> Result<Vec<LocalUpdate>> {
    if updates.len() < 3 {
        return Ok(updates.to_vec());
    }
    let n = updates.len();
    // Mean distance from each update to every other update.
    let avg_dists: Vec<f64> = (0..n)
        .map(|i| {
            let total: f64 = (0..n)
                .filter(|&j| j != i)
                .map(|j| self.calculate_update_distance(&updates[i], &updates[j]))
                .sum();
            total / (n - 1) as f64
        })
        .collect();
    let mean = avg_dists.iter().sum::<f64>() / n as f64;
    let variance = avg_dists.iter().map(|d| (d - mean).powi(2)).sum::<f64>() / n as f64;
    let cutoff = mean + self.outlier_detection.threshold * variance.sqrt();
    // Keep every update whose mean distance is within the cutoff.
    Ok(updates
        .iter()
        .zip(&avg_dists)
        .filter(|(_, &d)| d <= cutoff)
        .map(|(u, _)| u.clone())
        .collect())
}
/// Mean matrix distance over the parameter names the two updates share;
/// 0.0 when they share no parameters.
fn calculate_update_distance(&self, update1: &LocalUpdate, update2: &LocalUpdate) -> f64 {
    let (sum, count) = update1
        .parameter_updates
        .iter()
        .filter_map(|(name, p1)| {
            update2
                .parameter_updates
                .get(name)
                .map(|p2| self.matrix_distance(p1, p2))
        })
        .fold((0.0_f64, 0_usize), |(s, c), d| (s + d, c + 1));
    if count == 0 {
        0.0
    } else {
        sum / count as f64
    }
}
/// Placeholder: clustering-based detection is not implemented yet and
/// falls back to the statistical-distance filter.
fn filter_clustering_outliers(&self, updates: &[LocalUpdate]) -> Result<Vec<LocalUpdate>> {
    self.filter_statistical_outliers(updates)
}
/// Placeholder: isolation-forest detection is not implemented yet and
/// falls back to the statistical-distance filter.
fn filter_isolation_forest_outliers(
    &self,
    updates: &[LocalUpdate],
) -> Result<Vec<LocalUpdate>> {
    self.filter_statistical_outliers(updates)
}
/// Placeholder: Byzantine-participant detection is not implemented yet and
/// falls back to the statistical-distance filter.
fn filter_byzantine_outliers(&self, updates: &[LocalUpdate]) -> Result<Vec<LocalUpdate>> {
    self.filter_statistical_outliers(updates)
}
}
/// How each participant's contribution is weighted during aggregation.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum WeightingScheme {
    /// Equal weight for every participant.
    Uniform,
    /// Proportional to each update's `num_samples`.
    SampleSize,
    /// Proportional to each update's reported local accuracy.
    DataQuality,
    /// Inversely related to training time (faster training weighs more).
    ComputeContribution,
    /// Intended to weight by trust score; currently behaves like `Uniform`.
    TrustScore,
    /// Caller-provided per-participant weights (missing entries count as 0).
    Custom { weights: HashMap<Uuid, f64> },
}
/// Configuration for pre-aggregation outlier screening.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct OutlierDetection {
    /// Whether filtering runs at all before aggregation.
    pub enabled: bool,
    /// Which detection algorithm to use.
    pub method: OutlierDetectionMethod,
    /// Number of standard deviations above the mean distance that marks an
    /// update as an outlier.
    pub threshold: f64,
    /// What to do with detected outliers.
    /// NOTE(review): the aggregation code currently always excludes
    /// filtered updates regardless of this setting — confirm intent.
    pub outlier_action: OutlierAction,
}
impl Default for OutlierDetection {
    /// Defaults: detection enabled, statistical-distance method, a 2-sigma
    /// threshold, and weight reduction as the nominal action.
    fn default() -> Self {
        Self {
            enabled: true,
            method: OutlierDetectionMethod::StatisticalDistance,
            threshold: 2.0,
            outlier_action: OutlierAction::ReduceWeight,
        }
    }
}
/// Algorithm used to identify outlier updates.
/// NOTE(review): all variants currently route to the statistical-distance
/// filter; the others are placeholders.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum OutlierDetectionMethod {
    /// Mean pairwise distance vs. mean + k * std-dev cutoff.
    StatisticalDistance,
    /// Clustering-based detection (not yet implemented).
    Clustering,
    /// Isolation-forest detection (not yet implemented).
    IsolationForest,
    /// Byzantine-participant detection (not yet implemented).
    ByzantineDetection,
}
/// What to do with updates flagged as outliers.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum OutlierAction {
    /// Drop the update entirely.
    Exclude,
    /// Keep the update but lower its aggregation weight.
    ReduceWeight,
    /// Switch to a robust aggregation rule instead of filtering.
    RobustAggregation,
    /// Keep the update and flag it for human review.
    FlagForReview,
}
/// Summary metrics describing one aggregation round.
/// NOTE(review): not produced by any code in this file — presumably filled
/// in by a caller; confirm where it is populated.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AggregationStats {
    /// Number of participants whose updates were aggregated.
    pub num_participants: usize,
    /// Number of updates removed as outliers.
    pub num_outliers: usize,
    /// Total parameter count across all aggregated matrices.
    pub total_parameters: usize,
    /// Wall-clock time spent aggregating, in seconds.
    pub aggregation_time_seconds: f64,
    /// Agreement measure among participant updates (higher = more consensus).
    pub consensus_measure: f64,
    /// Differential-privacy budget spent this round, if applicable.
    pub privacy_budget_consumed: f64,
}