oxirs_embed/federated_learning/
aggregation.rs1use super::config::AggregationStrategy;
8use super::participant::LocalUpdate;
9use anyhow::Result;
10use scirs2_core::ndarray_ext::Array2;
11use serde::{Deserialize, Serialize};
12use std::collections::HashMap;
13use uuid::Uuid;
14
/// Engine that combines per-participant model updates into a single global
/// update, according to a configurable strategy, weighting scheme, and
/// outlier-detection policy.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AggregationEngine {
    /// Strategy used to combine updates (FedAvg, robust, hierarchical, ...).
    pub strategy: AggregationStrategy,
    /// Free-form numeric tuning parameters; not read by the aggregation
    /// methods visible in this file.
    pub parameters: HashMap<String, f64>,
    /// How per-participant aggregation weights are derived.
    pub weighting_scheme: WeightingScheme,
    /// Configuration for detecting and filtering anomalous updates before
    /// aggregation.
    pub outlier_detection: OutlierDetection,
}
27
28impl AggregationEngine {
29 pub fn new(strategy: AggregationStrategy) -> Self {
31 Self {
32 strategy,
33 parameters: HashMap::new(),
34 weighting_scheme: WeightingScheme::SampleSize,
35 outlier_detection: OutlierDetection::default(),
36 }
37 }
38
39 pub fn with_weighting_scheme(mut self, scheme: WeightingScheme) -> Self {
41 self.weighting_scheme = scheme;
42 self
43 }
44
45 pub fn with_outlier_detection(mut self, detection: OutlierDetection) -> Self {
47 self.outlier_detection = detection;
48 self
49 }
50
51 pub fn aggregate_updates(
53 &self,
54 updates: &[LocalUpdate],
55 ) -> Result<HashMap<String, Array2<f32>>> {
56 if updates.is_empty() {
57 return Ok(HashMap::new());
58 }
59
60 let filtered_updates = if self.outlier_detection.enabled {
62 self.filter_outliers(updates)?
63 } else {
64 updates.to_vec()
65 };
66
67 let weights = self.calculate_weights(&filtered_updates)?;
69
70 match self.strategy {
72 AggregationStrategy::FederatedAveraging => {
73 self.federated_averaging(&filtered_updates, &weights)
74 }
75 AggregationStrategy::WeightedAveraging => {
76 self.weighted_averaging(&filtered_updates, &weights)
77 }
78 AggregationStrategy::SecureAggregation => {
79 self.secure_aggregation(&filtered_updates, &weights)
80 }
81 AggregationStrategy::RobustAggregation => {
82 self.robust_aggregation(&filtered_updates, &weights)
83 }
84 AggregationStrategy::PersonalizedAggregation => {
85 self.personalized_aggregation(&filtered_updates, &weights)
86 }
87 AggregationStrategy::HierarchicalAggregation => {
88 self.hierarchical_aggregation(&filtered_updates, &weights)
89 }
90 }
91 }
92
93 fn federated_averaging(
95 &self,
96 updates: &[LocalUpdate],
97 weights: &HashMap<Uuid, f64>,
98 ) -> Result<HashMap<String, Array2<f32>>> {
99 self.weighted_averaging(updates, weights)
100 }
101
102 fn weighted_averaging(
104 &self,
105 updates: &[LocalUpdate],
106 weights: &HashMap<Uuid, f64>,
107 ) -> Result<HashMap<String, Array2<f32>>> {
108 let mut aggregated = HashMap::new();
109 let total_weight: f64 = weights.values().sum();
110
111 if total_weight == 0.0 {
112 return Err(anyhow::anyhow!("Total weight is zero"));
113 }
114
115 if let Some(first_update) = updates.first() {
117 for (param_name, param_values) in &first_update.parameter_updates {
118 aggregated.insert(param_name.clone(), Array2::zeros(param_values.raw_dim()));
119 }
120 }
121
122 for update in updates {
124 let weight = weights.get(&update.participant_id).unwrap_or(&0.0) / total_weight;
125
126 for (param_name, param_values) in &update.parameter_updates {
127 if let Some(aggregated_param) = aggregated.get_mut(param_name) {
128 *aggregated_param = &*aggregated_param + &(param_values * weight as f32);
129 }
130 }
131 }
132
133 Ok(aggregated)
134 }
135
136 fn secure_aggregation(
138 &self,
139 updates: &[LocalUpdate],
140 weights: &HashMap<Uuid, f64>,
141 ) -> Result<HashMap<String, Array2<f32>>> {
142 self.weighted_averaging(updates, weights)
145 }
146
147 fn robust_aggregation(
149 &self,
150 updates: &[LocalUpdate],
151 _weights: &HashMap<Uuid, f64>,
152 ) -> Result<HashMap<String, Array2<f32>>> {
153 let mut aggregated = HashMap::new();
154
155 if let Some(first_update) = updates.first() {
156 for param_name in first_update.parameter_updates.keys() {
157 let param_matrices: Vec<&Array2<f32>> = updates
159 .iter()
160 .filter_map(|update| update.parameter_updates.get(param_name))
161 .collect();
162
163 if param_matrices.is_empty() {
164 continue;
165 }
166
167 let aggregated_param = if param_matrices.len() > 2 {
169 self.krum_aggregation(¶m_matrices)?
170 } else {
171 self.median_aggregation(¶m_matrices)?
173 };
174
175 aggregated.insert(param_name.clone(), aggregated_param);
176 }
177 }
178
179 Ok(aggregated)
180 }
181
182 fn personalized_aggregation(
184 &self,
185 updates: &[LocalUpdate],
186 weights: &HashMap<Uuid, f64>,
187 ) -> Result<HashMap<String, Array2<f32>>> {
188 self.weighted_averaging(updates, weights)
191 }
192
193 fn hierarchical_aggregation(
195 &self,
196 updates: &[LocalUpdate],
197 weights: &HashMap<Uuid, f64>,
198 ) -> Result<HashMap<String, Array2<f32>>> {
199 self.weighted_averaging(updates, weights)
202 }
203
204 fn krum_aggregation(&self, matrices: &[&Array2<f32>]) -> Result<Array2<f32>> {
206 if matrices.is_empty() {
207 return Err(anyhow::anyhow!("No matrices to aggregate"));
208 }
209
210 let mut best_idx = 0;
212 let mut min_distance = f64::INFINITY;
213
214 for i in 0..matrices.len() {
215 let mut total_distance = 0.0;
216 for j in 0..matrices.len() {
217 if i != j {
218 total_distance += self.matrix_distance(matrices[i], matrices[j]);
219 }
220 }
221 if total_distance < min_distance {
222 min_distance = total_distance;
223 best_idx = i;
224 }
225 }
226
227 Ok(matrices[best_idx].clone())
228 }
229
230 fn median_aggregation(&self, matrices: &[&Array2<f32>]) -> Result<Array2<f32>> {
232 if matrices.is_empty() {
233 return Err(anyhow::anyhow!("No matrices to aggregate"));
234 }
235
236 let shape = matrices[0].raw_dim();
237 let mut result = Array2::zeros(shape);
238
239 for i in 0..shape[0] {
241 for j in 0..shape[1] {
242 let mut values: Vec<f32> = matrices.iter().map(|m| m[[i, j]]).collect();
243 values.sort_by(|a, b| a.partial_cmp(b).unwrap());
244
245 let median = if values.len() % 2 == 0 {
246 (values[values.len() / 2 - 1] + values[values.len() / 2]) / 2.0
247 } else {
248 values[values.len() / 2]
249 };
250
251 result[[i, j]] = median;
252 }
253 }
254
255 Ok(result)
256 }
257
258 fn matrix_distance(&self, a: &Array2<f32>, b: &Array2<f32>) -> f64 {
260 (a - b)
261 .iter()
262 .map(|x| (*x as f64) * (*x as f64))
263 .sum::<f64>()
264 .sqrt()
265 }
266
267 fn calculate_weights(&self, updates: &[LocalUpdate]) -> Result<HashMap<Uuid, f64>> {
269 let mut weights = HashMap::new();
270
271 match &self.weighting_scheme {
272 WeightingScheme::Uniform => {
273 let uniform_weight = 1.0 / updates.len() as f64;
274 for update in updates {
275 weights.insert(update.participant_id, uniform_weight);
276 }
277 }
278 WeightingScheme::SampleSize => {
279 let total_samples: usize = updates.iter().map(|u| u.num_samples).sum();
280 if total_samples > 0 {
281 for update in updates {
282 let weight = update.num_samples as f64 / total_samples as f64;
283 weights.insert(update.participant_id, weight);
284 }
285 }
286 }
287 WeightingScheme::DataQuality => {
288 let total_accuracy: f64 = updates
290 .iter()
291 .map(|u| u.training_stats.local_accuracy)
292 .sum();
293 if total_accuracy > 0.0 {
294 for update in updates {
295 let weight = update.training_stats.local_accuracy / total_accuracy;
296 weights.insert(update.participant_id, weight);
297 }
298 }
299 }
300 WeightingScheme::ComputeContribution => {
301 let total_compute: f64 = updates
303 .iter()
304 .map(|u| 1.0 / (u.training_stats.training_time_seconds + 1.0))
305 .sum();
306 if total_compute > 0.0 {
307 for update in updates {
308 let weight = (1.0 / (update.training_stats.training_time_seconds + 1.0))
309 / total_compute;
310 weights.insert(update.participant_id, weight);
311 }
312 }
313 }
314 WeightingScheme::TrustScore => {
315 let uniform_weight = 1.0 / updates.len() as f64;
318 for update in updates {
319 weights.insert(update.participant_id, uniform_weight);
320 }
321 }
322 WeightingScheme::Custom {
323 weights: custom_weights,
324 } => {
325 for update in updates {
326 let weight = custom_weights.get(&update.participant_id).unwrap_or(&0.0);
327 weights.insert(update.participant_id, *weight);
328 }
329 }
330 }
331
332 Ok(weights)
333 }
334
335 fn filter_outliers(&self, updates: &[LocalUpdate]) -> Result<Vec<LocalUpdate>> {
337 match self.outlier_detection.method {
338 OutlierDetectionMethod::StatisticalDistance => {
339 self.filter_statistical_outliers(updates)
340 }
341 OutlierDetectionMethod::Clustering => self.filter_clustering_outliers(updates),
342 OutlierDetectionMethod::IsolationForest => {
343 self.filter_isolation_forest_outliers(updates)
344 }
345 OutlierDetectionMethod::ByzantineDetection => self.filter_byzantine_outliers(updates),
346 }
347 }
348
349 fn filter_statistical_outliers(&self, updates: &[LocalUpdate]) -> Result<Vec<LocalUpdate>> {
351 if updates.len() < 3 {
352 return Ok(updates.to_vec());
353 }
354
355 let mut distances = Vec::new();
357 for i in 0..updates.len() {
358 let mut total_distance = 0.0;
359 for j in 0..updates.len() {
360 if i != j {
361 total_distance += self.calculate_update_distance(&updates[i], &updates[j]);
362 }
363 }
364 distances.push((i, total_distance / (updates.len() - 1) as f64));
365 }
366
367 let mean_distance: f64 =
369 distances.iter().map(|(_, d)| *d).sum::<f64>() / distances.len() as f64;
370 let variance: f64 = distances
371 .iter()
372 .map(|(_, d)| (d - mean_distance).powi(2))
373 .sum::<f64>()
374 / distances.len() as f64;
375 let std_dev = variance.sqrt();
376
377 let threshold = mean_distance + self.outlier_detection.threshold * std_dev;
379 let filtered_indices: Vec<usize> = distances
380 .iter()
381 .filter(|(_, d)| *d <= threshold)
382 .map(|(i, _)| *i)
383 .collect();
384
385 Ok(filtered_indices
386 .iter()
387 .map(|&i| updates[i].clone())
388 .collect())
389 }
390
391 fn calculate_update_distance(&self, update1: &LocalUpdate, update2: &LocalUpdate) -> f64 {
393 let mut total_distance = 0.0;
394 let mut param_count = 0;
395
396 for (param_name, param1) in &update1.parameter_updates {
397 if let Some(param2) = update2.parameter_updates.get(param_name) {
398 total_distance += self.matrix_distance(param1, param2);
399 param_count += 1;
400 }
401 }
402
403 if param_count > 0 {
404 total_distance / param_count as f64
405 } else {
406 0.0
407 }
408 }
409
410 fn filter_clustering_outliers(&self, updates: &[LocalUpdate]) -> Result<Vec<LocalUpdate>> {
412 self.filter_statistical_outliers(updates)
415 }
416
417 fn filter_isolation_forest_outliers(
419 &self,
420 updates: &[LocalUpdate],
421 ) -> Result<Vec<LocalUpdate>> {
422 self.filter_statistical_outliers(updates)
425 }
426
427 fn filter_byzantine_outliers(&self, updates: &[LocalUpdate]) -> Result<Vec<LocalUpdate>> {
429 self.filter_statistical_outliers(updates)
432 }
433}
434
/// How per-participant aggregation weights are computed.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum WeightingScheme {
    /// Every participant gets the same weight.
    Uniform,
    /// Weight proportional to the participant's reported sample count.
    SampleSize,
    /// Weight proportional to the participant's local accuracy.
    DataQuality,
    /// Weight inversely related to reported training time (faster = more).
    ComputeContribution,
    /// Weight by trust score; currently falls back to uniform weighting in
    /// the engine (no trust scores are tracked yet).
    TrustScore,
    /// Caller-supplied fixed weights keyed by participant id; missing
    /// participants receive weight 0.
    Custom { weights: HashMap<Uuid, f64> },
}
451
/// Configuration for detecting anomalous participant updates before
/// aggregation.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct OutlierDetection {
    /// Whether outlier filtering runs at all.
    pub enabled: bool,
    /// Which detection algorithm to apply.
    pub method: OutlierDetectionMethod,
    /// Sensitivity in standard deviations above the mean pairwise distance.
    pub threshold: f64,
    /// What to do with detected outliers.
    // NOTE(review): the engine's statistical filter currently always
    // excludes outliers regardless of this field — confirm intended use.
    pub outlier_action: OutlierAction,
}
464
465impl Default for OutlierDetection {
466 fn default() -> Self {
467 Self {
468 enabled: true,
469 method: OutlierDetectionMethod::StatisticalDistance,
470 threshold: 2.0,
471 outlier_action: OutlierAction::ReduceWeight,
472 }
473 }
474}
475
/// Algorithm used to flag anomalous updates.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum OutlierDetectionMethod {
    /// Mean pairwise distance vs. a standard-deviation threshold.
    StatisticalDistance,
    /// Clustering-based detection (currently behaves like StatisticalDistance).
    Clustering,
    /// Isolation-forest detection (currently behaves like StatisticalDistance).
    IsolationForest,
    /// Byzantine-participant detection (currently behaves like
    /// StatisticalDistance).
    ByzantineDetection,
}
488
/// Response to take when an update is classified as an outlier.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum OutlierAction {
    /// Remove the update from the aggregation round.
    Exclude,
    /// Keep the update but lower its aggregation weight.
    ReduceWeight,
    /// Switch to a robust aggregation strategy for the round.
    RobustAggregation,
    /// Keep the update and mark it for manual inspection.
    FlagForReview,
}
501
/// Summary metrics describing one aggregation round.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AggregationStats {
    /// Number of participants whose updates were aggregated.
    pub num_participants: usize,
    /// Number of updates filtered out as outliers.
    pub num_outliers: usize,
    /// Total count of aggregated parameters.
    pub total_parameters: usize,
    /// Wall-clock time spent aggregating, in seconds.
    pub aggregation_time_seconds: f64,
    /// Agreement measure among participant updates (semantics defined by the
    /// producer of this struct).
    pub consensus_measure: f64,
    /// Differential-privacy budget spent this round, if applicable.
    pub privacy_budget_consumed: f64,
}