// oxirs_embed/federated_learning/aggregation.rs

//! Aggregation strategies for federated learning
//!
//! This module implements various aggregation methods for combining local model
//! updates from multiple participants in federated learning, including Byzantine-
//! resilient methods and robust aggregation techniques.
6
7use super::config::AggregationStrategy;
8use super::participant::LocalUpdate;
9use anyhow::Result;
10use scirs2_core::ndarray_ext::Array2;
11use serde::{Deserialize, Serialize};
12use std::collections::HashMap;
13use uuid::Uuid;
14
/// Aggregation engine for combining local updates from federated participants.
///
/// Behavior is configured along three axes: the [`AggregationStrategy`]
/// (how updates are combined), the [`WeightingScheme`] (how much each
/// participant counts), and an [`OutlierDetection`] policy applied before
/// aggregation.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AggregationEngine {
    /// Aggregation strategy used by `aggregate_updates`
    pub strategy: AggregationStrategy,
    /// Free-form strategy tuning parameters (name -> value); not consulted
    /// by the current aggregation paths
    pub parameters: HashMap<String, f64>,
    /// Weighting scheme for computing per-participant weights
    pub weighting_scheme: WeightingScheme,
    /// Outlier-detection configuration applied before aggregation
    pub outlier_detection: OutlierDetection,
}
27
28impl AggregationEngine {
29    /// Create new aggregation engine
30    pub fn new(strategy: AggregationStrategy) -> Self {
31        Self {
32            strategy,
33            parameters: HashMap::new(),
34            weighting_scheme: WeightingScheme::SampleSize,
35            outlier_detection: OutlierDetection::default(),
36        }
37    }
38
39    /// Configure weighting scheme
40    pub fn with_weighting_scheme(mut self, scheme: WeightingScheme) -> Self {
41        self.weighting_scheme = scheme;
42        self
43    }
44
45    /// Configure outlier detection
46    pub fn with_outlier_detection(mut self, detection: OutlierDetection) -> Self {
47        self.outlier_detection = detection;
48        self
49    }
50
51    /// Aggregate local updates from participants
52    pub fn aggregate_updates(
53        &self,
54        updates: &[LocalUpdate],
55    ) -> Result<HashMap<String, Array2<f32>>> {
56        if updates.is_empty() {
57            return Ok(HashMap::new());
58        }
59
60        // Detect and handle outliers
61        let filtered_updates = if self.outlier_detection.enabled {
62            self.filter_outliers(updates)?
63        } else {
64            updates.to_vec()
65        };
66
67        // Calculate weights for each participant
68        let weights = self.calculate_weights(&filtered_updates)?;
69
70        // Perform aggregation based on strategy
71        match self.strategy {
72            AggregationStrategy::FederatedAveraging => {
73                self.federated_averaging(&filtered_updates, &weights)
74            }
75            AggregationStrategy::WeightedAveraging => {
76                self.weighted_averaging(&filtered_updates, &weights)
77            }
78            AggregationStrategy::SecureAggregation => {
79                self.secure_aggregation(&filtered_updates, &weights)
80            }
81            AggregationStrategy::RobustAggregation => {
82                self.robust_aggregation(&filtered_updates, &weights)
83            }
84            AggregationStrategy::PersonalizedAggregation => {
85                self.personalized_aggregation(&filtered_updates, &weights)
86            }
87            AggregationStrategy::HierarchicalAggregation => {
88                self.hierarchical_aggregation(&filtered_updates, &weights)
89            }
90        }
91    }
92
93    /// Standard federated averaging
94    fn federated_averaging(
95        &self,
96        updates: &[LocalUpdate],
97        weights: &HashMap<Uuid, f64>,
98    ) -> Result<HashMap<String, Array2<f32>>> {
99        self.weighted_averaging(updates, weights)
100    }
101
102    /// Weighted averaging of updates
103    fn weighted_averaging(
104        &self,
105        updates: &[LocalUpdate],
106        weights: &HashMap<Uuid, f64>,
107    ) -> Result<HashMap<String, Array2<f32>>> {
108        let mut aggregated = HashMap::new();
109        let total_weight: f64 = weights.values().sum();
110
111        if total_weight == 0.0 {
112            return Err(anyhow::anyhow!("Total weight is zero"));
113        }
114
115        // Initialize aggregated parameters with zeros
116        if let Some(first_update) = updates.first() {
117            for (param_name, param_values) in &first_update.parameter_updates {
118                aggregated.insert(param_name.clone(), Array2::zeros(param_values.raw_dim()));
119            }
120        }
121
122        // Weighted sum of all updates
123        for update in updates {
124            let weight = weights.get(&update.participant_id).unwrap_or(&0.0) / total_weight;
125
126            for (param_name, param_values) in &update.parameter_updates {
127                if let Some(aggregated_param) = aggregated.get_mut(param_name) {
128                    *aggregated_param = &*aggregated_param + &(param_values * weight as f32);
129                }
130            }
131        }
132
133        Ok(aggregated)
134    }
135
136    /// Secure aggregation with privacy preservation
137    fn secure_aggregation(
138        &self,
139        updates: &[LocalUpdate],
140        weights: &HashMap<Uuid, f64>,
141    ) -> Result<HashMap<String, Array2<f32>>> {
142        // For now, use weighted averaging
143        // In a full implementation, this would use secure multi-party computation
144        self.weighted_averaging(updates, weights)
145    }
146
147    /// Robust aggregation resistant to Byzantine failures
148    fn robust_aggregation(
149        &self,
150        updates: &[LocalUpdate],
151        _weights: &HashMap<Uuid, f64>,
152    ) -> Result<HashMap<String, Array2<f32>>> {
153        let mut aggregated = HashMap::new();
154
155        if let Some(first_update) = updates.first() {
156            for param_name in first_update.parameter_updates.keys() {
157                // Collect all parameter values for this parameter
158                let param_matrices: Vec<&Array2<f32>> = updates
159                    .iter()
160                    .filter_map(|update| update.parameter_updates.get(param_name))
161                    .collect();
162
163                if param_matrices.is_empty() {
164                    continue;
165                }
166
167                // Apply robust aggregation (Krum algorithm approximation)
168                let aggregated_param = if param_matrices.len() > 2 {
169                    self.krum_aggregation(&param_matrices)?
170                } else {
171                    // Fallback to averaging for small number of participants
172                    self.median_aggregation(&param_matrices)?
173                };
174
175                aggregated.insert(param_name.clone(), aggregated_param);
176            }
177        }
178
179        Ok(aggregated)
180    }
181
182    /// Personalized aggregation for participant-specific models
183    fn personalized_aggregation(
184        &self,
185        updates: &[LocalUpdate],
186        weights: &HashMap<Uuid, f64>,
187    ) -> Result<HashMap<String, Array2<f32>>> {
188        // For global model, use weighted averaging
189        // Individual personalized models would be handled separately
190        self.weighted_averaging(updates, weights)
191    }
192
193    /// Hierarchical aggregation for multi-level federation
194    fn hierarchical_aggregation(
195        &self,
196        updates: &[LocalUpdate],
197        weights: &HashMap<Uuid, f64>,
198    ) -> Result<HashMap<String, Array2<f32>>> {
199        // Simplified hierarchical aggregation
200        // In practice, this would involve multiple levels of aggregation
201        self.weighted_averaging(updates, weights)
202    }
203
204    /// Krum aggregation for Byzantine resilience
205    fn krum_aggregation(&self, matrices: &[&Array2<f32>]) -> Result<Array2<f32>> {
206        if matrices.is_empty() {
207            return Err(anyhow::anyhow!("No matrices to aggregate"));
208        }
209
210        // Simplified Krum: find the matrix closest to others
211        let mut best_idx = 0;
212        let mut min_distance = f64::INFINITY;
213
214        for i in 0..matrices.len() {
215            let mut total_distance = 0.0;
216            for j in 0..matrices.len() {
217                if i != j {
218                    total_distance += self.matrix_distance(matrices[i], matrices[j]);
219                }
220            }
221            if total_distance < min_distance {
222                min_distance = total_distance;
223                best_idx = i;
224            }
225        }
226
227        Ok(matrices[best_idx].clone())
228    }
229
230    /// Median aggregation for robustness
231    fn median_aggregation(&self, matrices: &[&Array2<f32>]) -> Result<Array2<f32>> {
232        if matrices.is_empty() {
233            return Err(anyhow::anyhow!("No matrices to aggregate"));
234        }
235
236        let shape = matrices[0].raw_dim();
237        let mut result = Array2::zeros(shape);
238
239        // Element-wise median
240        for i in 0..shape[0] {
241            for j in 0..shape[1] {
242                let mut values: Vec<f32> = matrices.iter().map(|m| m[[i, j]]).collect();
243                values.sort_by(|a, b| a.partial_cmp(b).unwrap());
244
245                let median = if values.len() % 2 == 0 {
246                    (values[values.len() / 2 - 1] + values[values.len() / 2]) / 2.0
247                } else {
248                    values[values.len() / 2]
249                };
250
251                result[[i, j]] = median;
252            }
253        }
254
255        Ok(result)
256    }
257
258    /// Calculate distance between two matrices
259    fn matrix_distance(&self, a: &Array2<f32>, b: &Array2<f32>) -> f64 {
260        (a - b)
261            .iter()
262            .map(|x| (*x as f64) * (*x as f64))
263            .sum::<f64>()
264            .sqrt()
265    }
266
267    /// Calculate participant weights based on weighting scheme
268    fn calculate_weights(&self, updates: &[LocalUpdate]) -> Result<HashMap<Uuid, f64>> {
269        let mut weights = HashMap::new();
270
271        match &self.weighting_scheme {
272            WeightingScheme::Uniform => {
273                let uniform_weight = 1.0 / updates.len() as f64;
274                for update in updates {
275                    weights.insert(update.participant_id, uniform_weight);
276                }
277            }
278            WeightingScheme::SampleSize => {
279                let total_samples: usize = updates.iter().map(|u| u.num_samples).sum();
280                if total_samples > 0 {
281                    for update in updates {
282                        let weight = update.num_samples as f64 / total_samples as f64;
283                        weights.insert(update.participant_id, weight);
284                    }
285                }
286            }
287            WeightingScheme::DataQuality => {
288                // Use training accuracy as a proxy for data quality
289                let total_accuracy: f64 = updates
290                    .iter()
291                    .map(|u| u.training_stats.local_accuracy)
292                    .sum();
293                if total_accuracy > 0.0 {
294                    for update in updates {
295                        let weight = update.training_stats.local_accuracy / total_accuracy;
296                        weights.insert(update.participant_id, weight);
297                    }
298                }
299            }
300            WeightingScheme::ComputeContribution => {
301                // Use inverse of training time as compute contribution
302                let total_compute: f64 = updates
303                    .iter()
304                    .map(|u| 1.0 / (u.training_stats.training_time_seconds + 1.0))
305                    .sum();
306                if total_compute > 0.0 {
307                    for update in updates {
308                        let weight = (1.0 / (update.training_stats.training_time_seconds + 1.0))
309                            / total_compute;
310                        weights.insert(update.participant_id, weight);
311                    }
312                }
313            }
314            WeightingScheme::TrustScore => {
315                // Would require trust scores from participant management
316                // For now, fallback to uniform weighting
317                let uniform_weight = 1.0 / updates.len() as f64;
318                for update in updates {
319                    weights.insert(update.participant_id, uniform_weight);
320                }
321            }
322            WeightingScheme::Custom {
323                weights: custom_weights,
324            } => {
325                for update in updates {
326                    let weight = custom_weights.get(&update.participant_id).unwrap_or(&0.0);
327                    weights.insert(update.participant_id, *weight);
328                }
329            }
330        }
331
332        Ok(weights)
333    }
334
335    /// Filter outliers from updates
336    fn filter_outliers(&self, updates: &[LocalUpdate]) -> Result<Vec<LocalUpdate>> {
337        match self.outlier_detection.method {
338            OutlierDetectionMethod::StatisticalDistance => {
339                self.filter_statistical_outliers(updates)
340            }
341            OutlierDetectionMethod::Clustering => self.filter_clustering_outliers(updates),
342            OutlierDetectionMethod::IsolationForest => {
343                self.filter_isolation_forest_outliers(updates)
344            }
345            OutlierDetectionMethod::ByzantineDetection => self.filter_byzantine_outliers(updates),
346        }
347    }
348
349    /// Filter outliers using statistical distance
350    fn filter_statistical_outliers(&self, updates: &[LocalUpdate]) -> Result<Vec<LocalUpdate>> {
351        if updates.len() < 3 {
352            return Ok(updates.to_vec());
353        }
354
355        // Calculate pairwise distances between updates
356        let mut distances = Vec::new();
357        for i in 0..updates.len() {
358            let mut total_distance = 0.0;
359            for j in 0..updates.len() {
360                if i != j {
361                    total_distance += self.calculate_update_distance(&updates[i], &updates[j]);
362                }
363            }
364            distances.push((i, total_distance / (updates.len() - 1) as f64));
365        }
366
367        // Calculate mean and std of distances
368        let mean_distance: f64 =
369            distances.iter().map(|(_, d)| *d).sum::<f64>() / distances.len() as f64;
370        let variance: f64 = distances
371            .iter()
372            .map(|(_, d)| (d - mean_distance).powi(2))
373            .sum::<f64>()
374            / distances.len() as f64;
375        let std_dev = variance.sqrt();
376
377        // Filter outliers
378        let threshold = mean_distance + self.outlier_detection.threshold * std_dev;
379        let filtered_indices: Vec<usize> = distances
380            .iter()
381            .filter(|(_, d)| *d <= threshold)
382            .map(|(i, _)| *i)
383            .collect();
384
385        Ok(filtered_indices
386            .iter()
387            .map(|&i| updates[i].clone())
388            .collect())
389    }
390
391    /// Calculate distance between two updates
392    fn calculate_update_distance(&self, update1: &LocalUpdate, update2: &LocalUpdate) -> f64 {
393        let mut total_distance = 0.0;
394        let mut param_count = 0;
395
396        for (param_name, param1) in &update1.parameter_updates {
397            if let Some(param2) = update2.parameter_updates.get(param_name) {
398                total_distance += self.matrix_distance(param1, param2);
399                param_count += 1;
400            }
401        }
402
403        if param_count > 0 {
404            total_distance / param_count as f64
405        } else {
406            0.0
407        }
408    }
409
410    /// Filter outliers using clustering (simplified)
411    fn filter_clustering_outliers(&self, updates: &[LocalUpdate]) -> Result<Vec<LocalUpdate>> {
412        // Simplified clustering-based outlier detection
413        // In practice, this would use proper clustering algorithms
414        self.filter_statistical_outliers(updates)
415    }
416
417    /// Filter outliers using isolation forest (simplified)
418    fn filter_isolation_forest_outliers(
419        &self,
420        updates: &[LocalUpdate],
421    ) -> Result<Vec<LocalUpdate>> {
422        // Simplified isolation forest
423        // In practice, this would implement the full isolation forest algorithm
424        self.filter_statistical_outliers(updates)
425    }
426
427    /// Filter Byzantine failures
428    fn filter_byzantine_outliers(&self, updates: &[LocalUpdate]) -> Result<Vec<LocalUpdate>> {
429        // Simplified Byzantine detection
430        // In practice, this would implement sophisticated Byzantine detection
431        self.filter_statistical_outliers(updates)
432    }
433}
434
/// Weighting schemes for aggregation
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum WeightingScheme {
    /// Equal weight (1/n) for all participants
    Uniform,
    /// Weight proportional to the number of local samples (FedAvg-style)
    SampleSize,
    /// Weight by data quality; local training accuracy is used as the proxy
    DataQuality,
    /// Weight by compute contribution, derived from the inverse of local
    /// training time
    ComputeContribution,
    /// Weight by trust score (currently falls back to uniform weighting)
    TrustScore,
    /// Explicit per-participant weights; participants missing from the map
    /// receive weight 0.0
    Custom { weights: HashMap<Uuid, f64> },
}
451
/// Outlier detection configuration for robust aggregation
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct OutlierDetection {
    /// Enable outlier detection before aggregation
    pub enabled: bool,
    /// Detection method
    pub method: OutlierDetectionMethod,
    /// Outlier threshold, interpreted by the statistical-distance filter as
    /// the number of standard deviations above the mean pairwise distance
    pub threshold: f64,
    /// Action on detected outliers
    // NOTE(review): this field is not yet consulted by the aggregation
    // engine, which always excludes detected outliers.
    pub outlier_action: OutlierAction,
}
464
impl Default for OutlierDetection {
    /// Defaults: detection enabled, statistical-distance method, a
    /// threshold of 2.0 standard deviations, and weight reduction as the
    /// nominal outlier action.
    fn default() -> Self {
        Self {
            enabled: true,
            method: OutlierDetectionMethod::StatisticalDistance,
            threshold: 2.0,
            outlier_action: OutlierAction::ReduceWeight,
        }
    }
}
475
/// Outlier detection methods
///
/// Only `StatisticalDistance` has a dedicated implementation; the other
/// methods are simplified placeholders that currently delegate to it.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum OutlierDetectionMethod {
    /// Statistical distance-based (z-score of mean pairwise distance)
    StatisticalDistance,
    /// Clustering-based
    Clustering,
    /// Isolation forest
    IsolationForest,
    /// Byzantine failure detection
    ByzantineDetection,
}
488
/// Actions to take on detected outliers
// NOTE(review): recorded in configuration but not yet acted upon by the
// aggregation engine, which always excludes detected outliers.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum OutlierAction {
    /// Exclude the update from aggregation entirely
    Exclude,
    /// Keep the update but reduce its aggregation weight
    ReduceWeight,
    /// Switch to a robust aggregation strategy instead of excluding
    RobustAggregation,
    /// Keep the update but flag it for manual review
    FlagForReview,
}
501
/// Summary statistics for one aggregation round
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AggregationStats {
    /// Number of participants whose updates entered aggregation
    pub num_participants: usize,
    /// Number of outliers detected this round
    pub num_outliers: usize,
    /// Total number of parameters aggregated
    pub total_parameters: usize,
    /// Wall-clock aggregation time in seconds
    pub aggregation_time_seconds: f64,
    /// Measure of agreement between participant updates
    // NOTE(review): the semantics and range of this measure are defined by
    // the producer of these stats, not in this module — confirm upstream.
    pub consensus_measure: f64,
    /// Privacy budget consumed during this round
    pub privacy_budget_consumed: f64,
}