Skip to main content

optirs_core/privacy/federated/
byzantine_aggregation.rs

1use std::fmt::Debug;
2// Byzantine Robust Aggregation Module
3//
4// This module implements Byzantine-robust aggregation algorithms for federated learning,
5// providing protection against malicious clients and outlier detection mechanisms.
6
7use crate::error::{OptimError, Result};
8use scirs2_core::ndarray::Array1;
9use scirs2_core::numeric::Float;
10use std::collections::{HashMap, VecDeque};
11
/// Byzantine-robust aggregation algorithms.
///
/// Within this file only `TrimmedMean` and `CoordinateWiseMedian` have dedicated
/// implementations in `ByzantineRobustAggregator::robust_aggregate`; every other
/// variant currently falls back to a plain, non-robust average there.
///
/// `PartialEq` is derived so configurations can be compared in tests and
/// validation code (`Eq` is not possible because of the `f64` payloads).
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum ByzantineRobustMethod {
    /// Trimmed mean aggregation.
    ///
    /// `trim_ratio` is the total fraction of per-coordinate values to discard,
    /// split evenly between the smallest and the largest values.
    TrimmedMean { trim_ratio: f64 },

    /// Coordinate-wise median
    CoordinateWiseMedian,

    /// Krum aggregation.
    ///
    /// NOTE(review): `f` is presumably the assumed number of Byzantine clients;
    /// it is not read by any code in this file — confirm.
    Krum { f: usize },

    /// Multi-Krum aggregation.
    ///
    /// NOTE(review): `f` is presumably the assumed number of Byzantine clients
    /// and `m` the number of selected updates; neither is read in this file.
    MultiKrum { f: usize, m: usize },

    /// Bulyan aggregation.
    ///
    /// NOTE(review): `f` is not read by any code in this file — confirm meaning.
    Bulyan { f: usize },

    /// Centered clipping.
    ///
    /// NOTE(review): `tau` is presumably the clipping radius; not read in this file.
    CenteredClipping { tau: f64 },

    /// FedAvg with outlier detection.
    ///
    /// NOTE(review): `threshold` is not read by any code in this file.
    FedAvgOutlierDetection { threshold: f64 },

    /// Robust aggregation with reputation.
    ///
    /// NOTE(review): `reputation_decay` is not read by any code in this file.
    ReputationWeighted { reputation_decay: f64 },
}
39
/// Byzantine robustness configuration
///
/// Bundles the aggregation method with the settings of the reputation system
/// and of the statistical outlier tests.
#[derive(Debug, Clone)]
pub struct ByzantineRobustConfig {
    /// Aggregation method
    pub method: ByzantineRobustMethod,

    /// Expected share of Byzantine clients, expressed as a ratio rather than a
    /// count (the default configuration uses 0.1).
    /// NOTE(review): not read by any code in this file — confirm where it is used.
    pub expected_byzantine_ratio: f64,

    /// Enable dynamic Byzantine detection.
    /// NOTE(review): not consulted by the detection code in this file.
    pub dynamic_detection: bool,

    /// Reputation system settings
    pub reputation_system: ReputationSystemConfig,

    /// Statistical tests for outlier detection
    pub statistical_tests: StatisticalTestConfig,
}
58
/// Reputation system configuration.
///
/// `PartialEq` is derived so configurations can be compared by value.
#[derive(Debug, Clone, PartialEq)]
pub struct ReputationSystemConfig {
    /// Whether the reputation system is switched on.
    /// NOTE(review): not consulted by any code in this file.
    pub enabled: bool,
    /// Reputation assumed for a client that has no recorded score yet.
    pub initial_reputation: f64,
    /// NOTE(review): presumably a per-round decay rate; not read in this file — confirm.
    pub reputation_decay: f64,
    /// Floor below which an outlier penalty cannot push a client's reputation.
    pub min_reputation: f64,
    /// Amount subtracted from a client's reputation when it is reported as an outlier.
    pub outlier_penalty: f64,
    /// Amount added to a client's reputation when it is reported as well-behaved
    /// (the aggregator caps the resulting score at 1.0).
    pub contribution_bonus: f64,
}
69
/// Statistical test configuration for outlier detection
///
/// NOTE(review): the outlier-detection logic visible in this file does not
/// branch on any of these values — confirm where they are consumed.
#[derive(Debug, Clone)]
pub struct StatisticalTestConfig {
    /// Whether statistical testing is switched on.
    pub enabled: bool,
    /// Which statistical test to apply.
    pub test_type: StatisticalTestType,
    /// Significance level for the test (the default configuration uses 0.05).
    pub significancelevel: f64,
    /// Window size (the default configuration uses 100).
    /// NOTE(review): presumably the number of past test statistics to retain — confirm.
    pub window_size: usize,
    /// NOTE(review): presumably toggles an adaptive detection threshold — confirm.
    pub adaptive_threshold: bool,
}
79
/// Kind of statistical test used for outlier detection.
///
/// The detection code in this file does not dispatch on the variant; the value
/// is only stored in configuration and in `TestStatistic`.
///
/// `PartialEq`, `Eq` and `Hash` are derived so the variants can be compared and
/// used as map/set keys.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum StatisticalTestType {
    /// Z-score test.
    ZScore,
    /// Modified Z-score test.
    ModifiedZScore,
    /// Interquartile-range test.
    IQRTest,
    /// Grubbs' test.
    GrubbsTest,
    /// NOTE(review): the name looks like a misspelling of "Chauvenet's
    /// criterion"; it is kept as-is because renaming would break callers.
    ChauventCriterion,
}
88
/// Byzantine-robust aggregation engine
///
/// Owns the configuration, the per-client reputation scores and the helper
/// objects used for outlier detection and robust estimation.
pub struct ByzantineRobustAggregator<
    T: Float + Debug + Default + Clone + Send + Sync + std::iter::Sum + 'static,
> {
    /// Active configuration; exposed read-only through `config()`.
    config: ByzantineRobustConfig,
    /// Reputation score per client id. Clients without an entry are treated as
    /// having `config.reputation_system.initial_reputation`.
    client_reputations: HashMap<String, f64>,
    /// Detection results consulted by `compute_robustness_factor`.
    outlier_history: VecDeque<OutlierDetectionResult>,
    /// Analyzer that `detect_byzantine_clients` delegates to.
    statistical_analyzer: StatisticalAnalyzer<T>,
    /// Robust estimator state.
    /// NOTE(review): `robust_aggregate` builds its own temporary estimator
    /// instead of using this field — confirm whether that is intended.
    robust_estimators: RobustEstimators<T>,
}
99
100/// Statistical analyzer for outlier detection
101pub struct StatisticalAnalyzer<
102    T: Float + Debug + Default + Clone + Send + Sync + std::iter::Sum + 'static,
103> {
104    window_size: usize,
105    significancelevel: f64,
106    test_statistics: VecDeque<TestStatistic<T>>,
107}
108
109/// Robust estimators for aggregation
110pub struct RobustEstimators<
111    T: Float + Debug + Default + Clone + Send + Sync + std::iter::Sum + 'static,
112> {
113    trimmed_mean_cache: HashMap<String, T>,
114    median_cache: HashMap<String, T>,
115    krum_scores: HashMap<String, f64>,
116}
117
/// Outlier detection result for a single client in a single round.
///
/// `PartialEq` is derived so results can be compared directly in tests.
#[derive(Debug, Clone, PartialEq)]
pub struct OutlierDetectionResult {
    /// Identifier of the client the verdict applies to.
    pub clientid: String,
    /// Round number supplied by the caller of the detection routine.
    pub round: usize,
    /// Whether the client was flagged as an outlier.
    pub is_outlier: bool,
    /// Score behind the verdict; the detector in this file stores the client's
    /// average distance to the other clients here.
    pub outlier_score: f64,
    /// Name of the detection method that produced the verdict.
    pub detection_method: String,
}
127
/// Test statistic for outlier detection
///
/// NOTE(review): no code in this file constructs this type; the field
/// descriptions below are taken from the field names and types only.
#[derive(Debug, Clone)]
pub struct TestStatistic<T: Float + Debug + Send + Sync + 'static> {
    /// Value of the computed statistic.
    pub statistic_value: T,
    /// p-value associated with the statistic.
    pub p_value: f64,
    /// Kind of test the statistic belongs to.
    pub test_type: StatisticalTestType,
    /// Identifier of the client the statistic was computed for.
    pub clientid: String,
}
136
/// Placeholder for adaptive privacy allocation.
///
/// The aggregation code in this file accepts a map of these values but does
/// not read them. `PartialEq` is derived so allocations can be compared by value.
#[derive(Debug, Clone, PartialEq)]
pub struct AdaptivePrivacyAllocation {
    /// NOTE(review): presumably the differential-privacy epsilon budget — confirm.
    pub epsilon: f64,
    /// NOTE(review): presumably the differential-privacy delta parameter — confirm.
    pub delta: f64,
    /// NOTE(review): presumably a weight trading utility against privacy — confirm.
    pub utility_weight: f64,
}
144
145impl<
146        T: Float
147            + Debug
148            + Default
149            + Clone
150            + Send
151            + Sync
152            + 'static
153            + std::iter::Sum
154            + scirs2_core::ndarray::ScalarOperand,
155    > ByzantineRobustAggregator<T>
156{
157    #[allow(dead_code)]
158    pub fn new() -> Result<Self> {
159        Ok(Self {
160            config: ByzantineRobustConfig::default(),
161            client_reputations: HashMap::new(),
162            outlier_history: VecDeque::with_capacity(1000),
163            statistical_analyzer: StatisticalAnalyzer::new(100, 0.05), // window_size=100, significancelevel=0.05
164            robust_estimators: RobustEstimators::new(),
165        })
166    }
167
168    #[allow(dead_code)]
169    pub fn detect_byzantine_clients(
170        &mut self,
171        client_updates: &HashMap<String, Array1<T>>,
172        round: usize,
173    ) -> Result<Vec<OutlierDetectionResult>> {
174        self.statistical_analyzer
175            .detect_outliers(client_updates, round)
176    }
177
178    #[allow(dead_code)]
179    pub fn get_client_reputations(&self, clients: &[String]) -> HashMap<String, f64> {
180        let mut reputations = HashMap::new();
181        for client_id in clients {
182            let reputation = self
183                .client_reputations
184                .get(client_id)
185                .copied()
186                .unwrap_or(self.config.reputation_system.initial_reputation);
187            reputations.insert(client_id.clone(), reputation);
188        }
189        reputations
190    }
191
192    #[allow(dead_code)]
193    pub fn robust_aggregate(
194        &self,
195        clientupdates: &HashMap<String, Array1<T>>,
196        _allocations: &HashMap<String, AdaptivePrivacyAllocation>,
197    ) -> Result<Array1<T>> {
198        match self.config.method {
199            ByzantineRobustMethod::TrimmedMean { trim_ratio } => {
200                // Use robust estimators for trimmed mean
201                let mut estimators = RobustEstimators::new();
202                estimators.trimmed_mean(clientupdates, trim_ratio)
203            }
204            ByzantineRobustMethod::CoordinateWiseMedian => {
205                self.coordinate_wise_median(clientupdates)
206            }
207            _ => {
208                // Default to simple averaging for other methods
209                if let Some(first_update) = clientupdates.values().next() {
210                    let mut result = Array1::zeros(first_update.len());
211                    let count = T::from(clientupdates.len()).expect("unwrap failed");
212
213                    for update in clientupdates.values() {
214                        result = result + update;
215                    }
216
217                    Ok(result / count)
218                } else {
219                    Err(OptimError::InvalidConfig("No client _updates".to_string()))
220                }
221            }
222        }
223    }
224
225    #[allow(dead_code)]
226    pub fn compute_robustness_factor(&self) -> Result<f64> {
227        let detected_byzantine = self
228            .outlier_history
229            .iter()
230            .filter(|result| result.is_outlier)
231            .count() as f64;
232
233        let total_evaluations = self.outlier_history.len() as f64;
234
235        if total_evaluations > 0.0 {
236            Ok(1.0 - (detected_byzantine / total_evaluations))
237        } else {
238            Ok(1.0)
239        }
240    }
241
242    fn coordinate_wise_median(
243        &self,
244        clientupdates: &HashMap<String, Array1<T>>,
245    ) -> Result<Array1<T>> {
246        if clientupdates.is_empty() {
247            return Err(OptimError::InvalidConfig(
248                "No client updates provided".to_string(),
249            ));
250        }
251
252        let first_update = clientupdates.values().next().expect("unwrap failed");
253        let dim = first_update.len();
254        let mut result = Array1::zeros(dim);
255
256        // For each coordinate, compute median across all clients
257        for coord_idx in 0..dim {
258            let mut coord_values: Vec<T> = clientupdates
259                .values()
260                .map(|update| update[coord_idx])
261                .collect();
262
263            coord_values.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
264
265            let median = if coord_values.len().is_multiple_of(2) {
266                let mid = coord_values.len() / 2;
267                (coord_values[mid - 1] + coord_values[mid])
268                    / T::from(2.0).unwrap_or_else(|| T::zero())
269            } else {
270                coord_values[coord_values.len() / 2]
271            };
272
273            result[coord_idx] = median;
274        }
275
276        Ok(result)
277    }
278
279    /// Get current configuration
280    pub fn config(&self) -> &ByzantineRobustConfig {
281        &self.config
282    }
283
284    /// Update client reputation
285    pub fn update_client_reputation(&mut self, client_id: String, is_outlier: bool) {
286        let current_reputation = self
287            .client_reputations
288            .get(&client_id)
289            .copied()
290            .unwrap_or(self.config.reputation_system.initial_reputation);
291
292        let new_reputation = if is_outlier {
293            (current_reputation - self.config.reputation_system.outlier_penalty)
294                .max(self.config.reputation_system.min_reputation)
295        } else {
296            (current_reputation + self.config.reputation_system.contribution_bonus).min(1.0)
297        };
298
299        self.client_reputations.insert(client_id, new_reputation);
300    }
301}
302
303impl<T: Float + Debug + Default + Clone + Send + Sync + 'static + std::iter::Sum>
304    StatisticalAnalyzer<T>
305{
306    /// Create new statistical analyzer
307    pub fn new(window_size: usize, significancelevel: f64) -> Self {
308        Self {
309            window_size,
310            significancelevel,
311            test_statistics: VecDeque::with_capacity(window_size),
312        }
313    }
314
315    /// Detect outliers using statistical tests
316    pub fn detect_outliers(
317        &mut self,
318        clientupdates: &HashMap<String, Array1<T>>,
319        round: usize,
320    ) -> Result<Vec<OutlierDetectionResult>> {
321        let mut results = Vec::new();
322
323        if clientupdates.len() < 3 {
324            return Ok(results); // Need at least 3 clients for meaningful analysis
325        }
326
327        // Simple outlier detection based on pairwise distances
328        let clientids: Vec<_> = clientupdates.keys().collect();
329        let mut distances = HashMap::new();
330
331        for &client_a in clientids.iter() {
332            let mut total_distance = T::zero();
333            let mut count = 0;
334
335            for &client_b in clientids.iter() {
336                if client_a != client_b {
337                    // Skip self comparison
338                    let update_a = &clientupdates[client_a];
339                    let update_b = &clientupdates[client_b];
340
341                    // Compute Euclidean distance
342                    let mut sum_sq_diff = T::zero();
343                    for (a, b) in update_a.iter().zip(update_b.iter()) {
344                        let diff = *a - *b;
345                        sum_sq_diff = sum_sq_diff + diff * diff;
346                    }
347
348                    let distance = sum_sq_diff.sqrt();
349                    total_distance = total_distance + distance;
350                    count += 1;
351                }
352            }
353
354            if count > 0 {
355                let avg_distance = total_distance / T::from(count).unwrap_or_else(|| T::zero());
356                distances.insert(client_a, avg_distance);
357            }
358        }
359
360        // Detect outliers based on distance threshold
361        if !distances.is_empty() {
362            let distances_vec: Vec<T> = distances.values().cloned().collect();
363            let mean_distance = distances_vec.iter().fold(T::zero(), |acc, &x| acc + x)
364                / T::from(distances_vec.len()).expect("unwrap failed");
365
366            let variance = distances_vec.iter().fold(T::zero(), |acc, &x| {
367                let diff = x - mean_distance;
368                acc + diff * diff
369            }) / T::from(distances_vec.len()).expect("unwrap failed");
370
371            let std_dev = variance.sqrt();
372            let threshold = mean_distance + T::from(1.0).unwrap_or_else(|| T::zero()) * std_dev; // 1-sigma threshold (more sensitive)
373
374            for (client_id, &distance) in &distances {
375                let is_outlier = distance > threshold;
376                results.push(OutlierDetectionResult {
377                    clientid: client_id.to_string(),
378                    round,
379                    is_outlier,
380                    outlier_score: distance.to_f64().unwrap_or(0.0),
381                    detection_method: "statistical_distance".to_string(),
382                });
383            }
384        }
385
386        Ok(results)
387    }
388}
389
390impl<T: Float + Debug + Default + Clone + Send + Sync + 'static + std::iter::Sum>
391    RobustEstimators<T>
392{
393    /// Create new robust estimators
394    pub fn new() -> Self {
395        Self {
396            trimmed_mean_cache: HashMap::new(),
397            median_cache: HashMap::new(),
398            krum_scores: HashMap::new(),
399        }
400    }
401
402    /// Compute trimmed mean of client updates
403    pub fn trimmed_mean(
404        &mut self,
405        clientupdates: &HashMap<String, Array1<T>>,
406        trim_ratio: f64,
407    ) -> Result<Array1<T>> {
408        if clientupdates.is_empty() {
409            return Err(OptimError::InvalidConfig(
410                "No client _updates provided".to_string(),
411            ));
412        }
413
414        let first_update = clientupdates.values().next().expect("unwrap failed");
415        let dim = first_update.len();
416
417        // Verify all _updates have same dimension
418        for update in clientupdates.values() {
419            if update.len() != dim {
420                return Err(OptimError::InvalidConfig(
421                    "Client _updates have different dimensions".to_string(),
422                ));
423            }
424        }
425
426        let mut result = Array1::zeros(dim);
427        let num_clients = clientupdates.len();
428        let trim_count = ((num_clients as f64 * trim_ratio) / 2.0) as usize;
429
430        // For each coordinate, compute trimmed mean
431        for coord_idx in 0..dim {
432            let mut coord_values: Vec<T> = clientupdates
433                .values()
434                .map(|update| update[coord_idx])
435                .collect();
436
437            coord_values.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
438
439            // Remove extreme values
440            let trimmed_values = &coord_values[trim_count..coord_values.len() - trim_count];
441
442            if !trimmed_values.is_empty() {
443                let sum = trimmed_values.iter().fold(T::zero(), |acc, &x| acc + x);
444                result[coord_idx] = sum / T::from(trimmed_values.len()).expect("unwrap failed");
445            } else {
446                result[coord_idx] = T::zero();
447            }
448        }
449
450        Ok(result)
451    }
452}
453
454impl<T: Float + Debug + Default + Clone + Send + Sync + 'static + std::iter::Sum> Default
455    for RobustEstimators<T>
456{
457    fn default() -> Self {
458        Self::new()
459    }
460}
461
462impl Default for ByzantineRobustConfig {
463    fn default() -> Self {
464        Self {
465            method: ByzantineRobustMethod::TrimmedMean { trim_ratio: 0.1 },
466            expected_byzantine_ratio: 0.1,
467            dynamic_detection: true,
468            reputation_system: ReputationSystemConfig::default(),
469            statistical_tests: StatisticalTestConfig::default(),
470        }
471    }
472}
473
474impl Default for ReputationSystemConfig {
475    fn default() -> Self {
476        Self {
477            enabled: true,
478            initial_reputation: 1.0,
479            reputation_decay: 0.01,
480            min_reputation: 0.1,
481            outlier_penalty: 0.5,
482            contribution_bonus: 0.1,
483        }
484    }
485}
486
487impl Default for StatisticalTestConfig {
488    fn default() -> Self {
489        Self {
490            enabled: true,
491            test_type: StatisticalTestType::ZScore,
492            significancelevel: 0.05,
493            window_size: 100,
494            adaptive_threshold: true,
495        }
496    }
497}
498
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::Array1;

    /// The default constructor must succeed.
    #[test]
    fn test_byzantine_robust_aggregator_creation() {
        let aggregator = ByzantineRobustAggregator::<f64>::new();
        assert!(aggregator.is_ok());
    }

    /// Trimming one value from each end must remove the outlier completely.
    #[test]
    fn test_trimmed_mean_aggregation() {
        let mut estimators = RobustEstimators::<f64>::new();

        let mut client_updates = HashMap::new();
        client_updates.insert("client1".to_string(), Array1::from(vec![1.0, 2.0, 3.0]));
        client_updates.insert("client2".to_string(), Array1::from(vec![1.1, 2.1, 3.1]));
        client_updates.insert("client3".to_string(), Array1::from(vec![10.0, 20.0, 30.0])); // Outlier
        client_updates.insert("client4".to_string(), Array1::from(vec![0.9, 1.9, 2.9]));

        // trim_ratio is the total fraction trimmed: 4 clients * 0.5 / 2 = one
        // value removed from each end. (A ratio of 0.25 would round down to
        // zero trimmed values and leave the outlier in the mean.)
        let result = estimators.trimmed_mean(&client_updates, 0.5);
        assert!(result.is_ok());

        let trimmed = result.expect("trimmed mean of valid updates must succeed");
        // Only client1 and client2 remain for every coordinate.
        assert!((trimmed[0] - 1.05).abs() < 1e-9);
        assert!((trimmed[1] - 2.05).abs() < 1e-9);
        assert!((trimmed[2] - 3.05).abs() < 1e-9);
    }

    /// A client far away from the others must be flagged.
    #[test]
    fn test_outlier_detection() {
        let mut analyzer = StatisticalAnalyzer::<f64>::new(100, 0.05);

        let mut client_updates = HashMap::new();
        client_updates.insert("client1".to_string(), Array1::from(vec![1.0, 2.0, 3.0]));
        client_updates.insert("client2".to_string(), Array1::from(vec![1.1, 2.1, 3.1]));
        client_updates.insert(
            "client3".to_string(),
            Array1::from(vec![1000.0, 2000.0, 3000.0]),
        ); // Very clear outlier

        let results = analyzer.detect_outliers(&client_updates, 1);
        assert!(results.is_ok());

        let detections = results.expect("detection on valid updates must succeed");
        assert!(!detections.is_empty());

        // Check if the outlier was detected
        let outlier_detected = detections
            .iter()
            .any(|r| r.clientid == "client3" && r.is_outlier);
        assert!(outlier_detected);
    }

    /// Odd client count: the median is the central value of each coordinate.
    #[test]
    fn test_coordinate_wise_median() {
        let aggregator =
            ByzantineRobustAggregator::<f64>::new().expect("default aggregator must build");

        let mut client_updates = HashMap::new();
        client_updates.insert("client1".to_string(), Array1::from(vec![1.0, 4.0, 7.0]));
        client_updates.insert("client2".to_string(), Array1::from(vec![2.0, 5.0, 8.0]));
        client_updates.insert("client3".to_string(), Array1::from(vec![3.0, 6.0, 9.0]));

        let result = aggregator.coordinate_wise_median(&client_updates);
        assert!(result.is_ok());

        let median = result.expect("median of valid updates must succeed");
        assert_eq!(median[0], 2.0); // Median of [1, 2, 3]
        assert_eq!(median[1], 5.0); // Median of [4, 5, 6]
        assert_eq!(median[2], 8.0); // Median of [7, 8, 9]
    }

    /// Even client count: the median is the mean of the two central values.
    #[test]
    fn test_coordinate_wise_median_even_count() {
        let aggregator =
            ByzantineRobustAggregator::<f64>::new().expect("default aggregator must build");

        let mut client_updates = HashMap::new();
        client_updates.insert("client1".to_string(), Array1::from(vec![1.0]));
        client_updates.insert("client2".to_string(), Array1::from(vec![2.0]));
        client_updates.insert("client3".to_string(), Array1::from(vec![3.0]));
        client_updates.insert("client4".to_string(), Array1::from(vec![4.0]));

        let median = aggregator
            .coordinate_wise_median(&client_updates)
            .expect("median of valid updates must succeed");
        assert_eq!(median[0], 2.5); // Mean of the central values 2 and 3
    }

    /// Reputation starts at the initial value, drops on outliers and is capped at 1.0.
    #[test]
    fn test_reputation_system() {
        let mut aggregator =
            ByzantineRobustAggregator::<f64>::new().expect("default aggregator must build");

        // Test initial reputation
        let reputations = aggregator.get_client_reputations(&["client1".to_string()]);
        assert_eq!(reputations.get("client1"), Some(&1.0));

        // Test reputation penalty for outlier
        aggregator.update_client_reputation("client1".to_string(), true);
        let updated_reputations = aggregator.get_client_reputations(&["client1".to_string()]);
        assert!(updated_reputations.get("client1").expect("client1 was requested") < &1.0);

        // Test reputation bonus for good behavior
        aggregator.update_client_reputation("client2".to_string(), false);
        let good_reputations = aggregator.get_client_reputations(&["client2".to_string()]);
        assert_eq!(good_reputations.get("client2"), Some(&1.0)); // Should stay at max (1.0)
    }
}