1use std::fmt::Debug;
2use crate::error::{OptimError, Result};
8use scirs2_core::ndarray::Array1;
9use scirs2_core::numeric::Float;
10use std::collections::{HashMap, VecDeque};
11
/// Strategy used to aggregate client updates when some clients may be
/// Byzantine (faulty or adversarial).
///
/// NOTE(review): in this file only `TrimmedMean` and `CoordinateWiseMedian`
/// have dedicated implementations; `ByzantineRobustAggregator::robust_aggregate`
/// falls back to a plain arithmetic mean for every other variant, so the
/// parameters of those variants are not read anywhere in this file.
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum ByzantineRobustMethod {
    /// Coordinate-wise trimmed mean. `trim_ratio` is the total fraction of
    /// clients discarded per coordinate, split evenly between the smallest
    /// and the largest values.
    TrimmedMean { trim_ratio: f64 },

    /// Coordinate-wise median of the client updates.
    CoordinateWiseMedian,

    /// Krum aggregation. `f` is presumably the number of Byzantine clients to
    /// tolerate — TODO confirm once an implementation exists.
    Krum { f: usize },

    /// Multi-Krum aggregation. `f` is presumably the tolerated number of
    /// Byzantine clients and `m` the number of selected updates — TODO confirm.
    MultiKrum { f: usize, m: usize },

    /// Bulyan aggregation. `f` is presumably the tolerated number of
    /// Byzantine clients — TODO confirm.
    Bulyan { f: usize },

    /// Centered clipping. `tau` is presumably the clipping radius — TODO confirm.
    CenteredClipping { tau: f64 },

    /// Federated averaging combined with outlier detection. `threshold` is
    /// presumably the outlier cut-off — TODO confirm.
    FedAvgOutlierDetection { threshold: f64 },

    /// Reputation-weighted aggregation. `reputation_decay` is presumably the
    /// per-round decay applied to reputations — TODO confirm.
    ReputationWeighted { reputation_decay: f64 },
}
39
40#[derive(Debug, Clone)]
42pub struct ByzantineRobustConfig {
43 pub method: ByzantineRobustMethod,
45
46 pub expected_byzantine_ratio: f64,
48
49 pub dynamic_detection: bool,
51
52 pub reputation_system: ReputationSystemConfig,
54
55 pub statistical_tests: StatisticalTestConfig,
57}
58
/// Parameters of the per-client reputation bookkeeping.
#[derive(Debug, Clone, PartialEq)]
pub struct ReputationSystemConfig {
    /// Whether the reputation system is switched on.
    /// NOTE(review): not consulted by any code in this file.
    pub enabled: bool,
    /// Reputation assumed for a client that has no recorded reputation yet.
    pub initial_reputation: f64,
    /// Reputation decay rate.
    /// NOTE(review): not read anywhere in this file.
    pub reputation_decay: f64,
    /// Lower bound a reputation is clamped to after an outlier penalty.
    pub min_reputation: f64,
    /// Amount subtracted from a client's reputation when it is flagged as an outlier.
    pub outlier_penalty: f64,
    /// Amount added to a client's reputation when it is not flagged
    /// (the result is capped at `1.0`).
    pub contribution_bonus: f64,
}
69
70#[derive(Debug, Clone)]
72pub struct StatisticalTestConfig {
73 pub enabled: bool,
74 pub test_type: StatisticalTestType,
75 pub significancelevel: f64,
76 pub window_size: usize,
77 pub adaptive_threshold: bool,
78}
79
/// Kind of statistical test used to flag outlying client updates.
///
/// NOTE(review): no code in this file branches on these variants;
/// `StatisticalAnalyzer::detect_outliers` always applies one fixed
/// distance-based rule.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum StatisticalTestType {
    /// Z-score test.
    ZScore,
    /// Modified Z-score test.
    ModifiedZScore,
    /// Interquartile-range test.
    IQRTest,
    /// Grubbs' test.
    GrubbsTest,
    /// Presumably Chauvenet's criterion; the misspelled identifier is kept so
    /// existing callers continue to compile.
    ChauventCriterion,
}
88
89pub struct ByzantineRobustAggregator<
91 T: Float + Debug + Default + Clone + Send + Sync + std::iter::Sum + 'static,
92> {
93 config: ByzantineRobustConfig,
94 client_reputations: HashMap<String, f64>,
95 outlier_history: VecDeque<OutlierDetectionResult>,
96 statistical_analyzer: StatisticalAnalyzer<T>,
97 robust_estimators: RobustEstimators<T>,
98}
99
100pub struct StatisticalAnalyzer<
102 T: Float + Debug + Default + Clone + Send + Sync + std::iter::Sum + 'static,
103> {
104 window_size: usize,
105 significancelevel: f64,
106 test_statistics: VecDeque<TestStatistic<T>>,
107}
108
109pub struct RobustEstimators<
111 T: Float + Debug + Default + Clone + Send + Sync + std::iter::Sum + 'static,
112> {
113 trimmed_mean_cache: HashMap<String, T>,
114 median_cache: HashMap<String, T>,
115 krum_scores: HashMap<String, f64>,
116}
117
/// Outcome of outlier detection for one client in one round.
#[derive(Debug, Clone, PartialEq)]
pub struct OutlierDetectionResult {
    /// Identifier of the client this result refers to.
    pub clientid: String,
    /// Round number passed to the detector.
    pub round: usize,
    /// Whether the client was flagged as an outlier.
    pub is_outlier: bool,
    /// Score the decision was based on; the detector in this file stores the
    /// client's mean Euclidean distance to the other clients' updates.
    pub outlier_score: f64,
    /// Name of the rule that produced this result
    /// (`"statistical_distance"` for the detector in this file).
    pub detection_method: String,
}
127
128#[derive(Debug, Clone)]
130pub struct TestStatistic<T: Float + Debug + Send + Sync + 'static> {
131 pub statistic_value: T,
132 pub p_value: f64,
133 pub test_type: StatisticalTestType,
134 pub clientid: String,
135}
136
/// Per-client privacy allocation passed alongside the client updates.
///
/// NOTE(review): `ByzantineRobustAggregator::robust_aggregate` accepts a map of
/// these but ignores it, so no field is read anywhere in this file.
#[derive(Debug, Clone, PartialEq)]
pub struct AdaptivePrivacyAllocation {
    /// Presumably the differential-privacy epsilon budget — TODO confirm.
    pub epsilon: f64,
    /// Presumably the differential-privacy delta — TODO confirm.
    pub delta: f64,
    /// Presumably the weight given to utility when trading off against
    /// privacy — TODO confirm.
    pub utility_weight: f64,
}
144
145impl<
146 T: Float
147 + Debug
148 + Default
149 + Clone
150 + Send
151 + Sync
152 + 'static
153 + std::iter::Sum
154 + scirs2_core::ndarray::ScalarOperand,
155 > ByzantineRobustAggregator<T>
156{
157 #[allow(dead_code)]
158 pub fn new() -> Result<Self> {
159 Ok(Self {
160 config: ByzantineRobustConfig::default(),
161 client_reputations: HashMap::new(),
162 outlier_history: VecDeque::with_capacity(1000),
163 statistical_analyzer: StatisticalAnalyzer::new(100, 0.05), robust_estimators: RobustEstimators::new(),
165 })
166 }
167
168 #[allow(dead_code)]
169 pub fn detect_byzantine_clients(
170 &mut self,
171 client_updates: &HashMap<String, Array1<T>>,
172 round: usize,
173 ) -> Result<Vec<OutlierDetectionResult>> {
174 self.statistical_analyzer
175 .detect_outliers(client_updates, round)
176 }
177
178 #[allow(dead_code)]
179 pub fn get_client_reputations(&self, clients: &[String]) -> HashMap<String, f64> {
180 let mut reputations = HashMap::new();
181 for client_id in clients {
182 let reputation = self
183 .client_reputations
184 .get(client_id)
185 .copied()
186 .unwrap_or(self.config.reputation_system.initial_reputation);
187 reputations.insert(client_id.clone(), reputation);
188 }
189 reputations
190 }
191
192 #[allow(dead_code)]
193 pub fn robust_aggregate(
194 &self,
195 clientupdates: &HashMap<String, Array1<T>>,
196 _allocations: &HashMap<String, AdaptivePrivacyAllocation>,
197 ) -> Result<Array1<T>> {
198 match self.config.method {
199 ByzantineRobustMethod::TrimmedMean { trim_ratio } => {
200 let mut estimators = RobustEstimators::new();
202 estimators.trimmed_mean(clientupdates, trim_ratio)
203 }
204 ByzantineRobustMethod::CoordinateWiseMedian => {
205 self.coordinate_wise_median(clientupdates)
206 }
207 _ => {
208 if let Some(first_update) = clientupdates.values().next() {
210 let mut result = Array1::zeros(first_update.len());
211 let count = T::from(clientupdates.len()).expect("unwrap failed");
212
213 for update in clientupdates.values() {
214 result = result + update;
215 }
216
217 Ok(result / count)
218 } else {
219 Err(OptimError::InvalidConfig("No client _updates".to_string()))
220 }
221 }
222 }
223 }
224
225 #[allow(dead_code)]
226 pub fn compute_robustness_factor(&self) -> Result<f64> {
227 let detected_byzantine = self
228 .outlier_history
229 .iter()
230 .filter(|result| result.is_outlier)
231 .count() as f64;
232
233 let total_evaluations = self.outlier_history.len() as f64;
234
235 if total_evaluations > 0.0 {
236 Ok(1.0 - (detected_byzantine / total_evaluations))
237 } else {
238 Ok(1.0)
239 }
240 }
241
242 fn coordinate_wise_median(
243 &self,
244 clientupdates: &HashMap<String, Array1<T>>,
245 ) -> Result<Array1<T>> {
246 if clientupdates.is_empty() {
247 return Err(OptimError::InvalidConfig(
248 "No client updates provided".to_string(),
249 ));
250 }
251
252 let first_update = clientupdates.values().next().expect("unwrap failed");
253 let dim = first_update.len();
254 let mut result = Array1::zeros(dim);
255
256 for coord_idx in 0..dim {
258 let mut coord_values: Vec<T> = clientupdates
259 .values()
260 .map(|update| update[coord_idx])
261 .collect();
262
263 coord_values.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
264
265 let median = if coord_values.len().is_multiple_of(2) {
266 let mid = coord_values.len() / 2;
267 (coord_values[mid - 1] + coord_values[mid])
268 / T::from(2.0).unwrap_or_else(|| T::zero())
269 } else {
270 coord_values[coord_values.len() / 2]
271 };
272
273 result[coord_idx] = median;
274 }
275
276 Ok(result)
277 }
278
279 pub fn config(&self) -> &ByzantineRobustConfig {
281 &self.config
282 }
283
284 pub fn update_client_reputation(&mut self, client_id: String, is_outlier: bool) {
286 let current_reputation = self
287 .client_reputations
288 .get(&client_id)
289 .copied()
290 .unwrap_or(self.config.reputation_system.initial_reputation);
291
292 let new_reputation = if is_outlier {
293 (current_reputation - self.config.reputation_system.outlier_penalty)
294 .max(self.config.reputation_system.min_reputation)
295 } else {
296 (current_reputation + self.config.reputation_system.contribution_bonus).min(1.0)
297 };
298
299 self.client_reputations.insert(client_id, new_reputation);
300 }
301}
302
303impl<T: Float + Debug + Default + Clone + Send + Sync + 'static + std::iter::Sum>
304 StatisticalAnalyzer<T>
305{
306 pub fn new(window_size: usize, significancelevel: f64) -> Self {
308 Self {
309 window_size,
310 significancelevel,
311 test_statistics: VecDeque::with_capacity(window_size),
312 }
313 }
314
315 pub fn detect_outliers(
317 &mut self,
318 clientupdates: &HashMap<String, Array1<T>>,
319 round: usize,
320 ) -> Result<Vec<OutlierDetectionResult>> {
321 let mut results = Vec::new();
322
323 if clientupdates.len() < 3 {
324 return Ok(results); }
326
327 let clientids: Vec<_> = clientupdates.keys().collect();
329 let mut distances = HashMap::new();
330
331 for &client_a in clientids.iter() {
332 let mut total_distance = T::zero();
333 let mut count = 0;
334
335 for &client_b in clientids.iter() {
336 if client_a != client_b {
337 let update_a = &clientupdates[client_a];
339 let update_b = &clientupdates[client_b];
340
341 let mut sum_sq_diff = T::zero();
343 for (a, b) in update_a.iter().zip(update_b.iter()) {
344 let diff = *a - *b;
345 sum_sq_diff = sum_sq_diff + diff * diff;
346 }
347
348 let distance = sum_sq_diff.sqrt();
349 total_distance = total_distance + distance;
350 count += 1;
351 }
352 }
353
354 if count > 0 {
355 let avg_distance = total_distance / T::from(count).unwrap_or_else(|| T::zero());
356 distances.insert(client_a, avg_distance);
357 }
358 }
359
360 if !distances.is_empty() {
362 let distances_vec: Vec<T> = distances.values().cloned().collect();
363 let mean_distance = distances_vec.iter().fold(T::zero(), |acc, &x| acc + x)
364 / T::from(distances_vec.len()).expect("unwrap failed");
365
366 let variance = distances_vec.iter().fold(T::zero(), |acc, &x| {
367 let diff = x - mean_distance;
368 acc + diff * diff
369 }) / T::from(distances_vec.len()).expect("unwrap failed");
370
371 let std_dev = variance.sqrt();
372 let threshold = mean_distance + T::from(1.0).unwrap_or_else(|| T::zero()) * std_dev; for (client_id, &distance) in &distances {
375 let is_outlier = distance > threshold;
376 results.push(OutlierDetectionResult {
377 clientid: client_id.to_string(),
378 round,
379 is_outlier,
380 outlier_score: distance.to_f64().unwrap_or(0.0),
381 detection_method: "statistical_distance".to_string(),
382 });
383 }
384 }
385
386 Ok(results)
387 }
388}
389
390impl<T: Float + Debug + Default + Clone + Send + Sync + 'static + std::iter::Sum>
391 RobustEstimators<T>
392{
393 pub fn new() -> Self {
395 Self {
396 trimmed_mean_cache: HashMap::new(),
397 median_cache: HashMap::new(),
398 krum_scores: HashMap::new(),
399 }
400 }
401
402 pub fn trimmed_mean(
404 &mut self,
405 clientupdates: &HashMap<String, Array1<T>>,
406 trim_ratio: f64,
407 ) -> Result<Array1<T>> {
408 if clientupdates.is_empty() {
409 return Err(OptimError::InvalidConfig(
410 "No client _updates provided".to_string(),
411 ));
412 }
413
414 let first_update = clientupdates.values().next().expect("unwrap failed");
415 let dim = first_update.len();
416
417 for update in clientupdates.values() {
419 if update.len() != dim {
420 return Err(OptimError::InvalidConfig(
421 "Client _updates have different dimensions".to_string(),
422 ));
423 }
424 }
425
426 let mut result = Array1::zeros(dim);
427 let num_clients = clientupdates.len();
428 let trim_count = ((num_clients as f64 * trim_ratio) / 2.0) as usize;
429
430 for coord_idx in 0..dim {
432 let mut coord_values: Vec<T> = clientupdates
433 .values()
434 .map(|update| update[coord_idx])
435 .collect();
436
437 coord_values.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
438
439 let trimmed_values = &coord_values[trim_count..coord_values.len() - trim_count];
441
442 if !trimmed_values.is_empty() {
443 let sum = trimmed_values.iter().fold(T::zero(), |acc, &x| acc + x);
444 result[coord_idx] = sum / T::from(trimmed_values.len()).expect("unwrap failed");
445 } else {
446 result[coord_idx] = T::zero();
447 }
448 }
449
450 Ok(result)
451 }
452}
453
454impl<T: Float + Debug + Default + Clone + Send + Sync + 'static + std::iter::Sum> Default
455 for RobustEstimators<T>
456{
457 fn default() -> Self {
458 Self::new()
459 }
460}
461
462impl Default for ByzantineRobustConfig {
463 fn default() -> Self {
464 Self {
465 method: ByzantineRobustMethod::TrimmedMean { trim_ratio: 0.1 },
466 expected_byzantine_ratio: 0.1,
467 dynamic_detection: true,
468 reputation_system: ReputationSystemConfig::default(),
469 statistical_tests: StatisticalTestConfig::default(),
470 }
471 }
472}
473
474impl Default for ReputationSystemConfig {
475 fn default() -> Self {
476 Self {
477 enabled: true,
478 initial_reputation: 1.0,
479 reputation_decay: 0.01,
480 min_reputation: 0.1,
481 outlier_penalty: 0.5,
482 contribution_bonus: 0.1,
483 }
484 }
485}
486
487impl Default for StatisticalTestConfig {
488 fn default() -> Self {
489 Self {
490 enabled: true,
491 test_type: StatisticalTestType::ZScore,
492 significancelevel: 0.05,
493 window_size: 100,
494 adaptive_threshold: true,
495 }
496 }
497}
498
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::Array1;

    /// The aggregator can be constructed with its default configuration.
    #[test]
    fn test_byzantine_robust_aggregator_creation() {
        let aggregator = ByzantineRobustAggregator::<f64>::new();
        assert!(aggregator.is_ok());
    }

    /// The trimmed mean discards the extreme client on each side.
    #[test]
    fn test_trimmed_mean_aggregation() {
        let mut estimators = RobustEstimators::<f64>::new();

        let mut client_updates = HashMap::new();
        client_updates.insert("client1".to_string(), Array1::from(vec![1.0, 2.0, 3.0]));
        client_updates.insert("client2".to_string(), Array1::from(vec![1.1, 2.1, 3.1]));
        // Outlier that must be trimmed away.
        client_updates.insert("client3".to_string(), Array1::from(vec![10.0, 20.0, 30.0]));
        client_updates.insert("client4".to_string(), Array1::from(vec![0.9, 1.9, 2.9]));

        // With 4 clients a ratio of 0.5 trims exactly one value per tail.
        // A ratio of 0.25 would round down to zero trimmed values and the
        // test would only exercise the plain mean.
        let result = estimators.trimmed_mean(&client_updates, 0.5);
        assert!(result.is_ok());

        let trimmed = result.expect("trimmed mean should succeed for valid input");
        // Only the two middle values of each coordinate remain.
        assert!((trimmed[0] - 1.05).abs() < 1e-12);
        assert!((trimmed[1] - 2.05).abs() < 1e-12);
        assert!((trimmed[2] - 3.05).abs() < 1e-12);
    }

    /// A client far away from the others is flagged as an outlier.
    #[test]
    fn test_outlier_detection() {
        let mut analyzer = StatisticalAnalyzer::<f64>::new(100, 0.05);

        let mut client_updates = HashMap::new();
        client_updates.insert("client1".to_string(), Array1::from(vec![1.0, 2.0, 3.0]));
        client_updates.insert("client2".to_string(), Array1::from(vec![1.1, 2.1, 3.1]));
        client_updates.insert(
            "client3".to_string(),
            Array1::from(vec![1000.0, 2000.0, 3000.0]),
        );

        let results = analyzer.detect_outliers(&client_updates, 1);
        assert!(results.is_ok());

        let detections = results.expect("detection should succeed for valid input");
        assert!(!detections.is_empty());

        let outlier_detected = detections
            .iter()
            .any(|r| r.clientid == "client3" && r.is_outlier);
        assert!(outlier_detected);
    }

    /// The median is taken independently for every coordinate.
    #[test]
    fn test_coordinate_wise_median() {
        let aggregator =
            ByzantineRobustAggregator::<f64>::new().expect("default aggregator should build");

        let mut client_updates = HashMap::new();
        client_updates.insert("client1".to_string(), Array1::from(vec![1.0, 4.0, 7.0]));
        client_updates.insert("client2".to_string(), Array1::from(vec![2.0, 5.0, 8.0]));
        client_updates.insert("client3".to_string(), Array1::from(vec![3.0, 6.0, 9.0]));

        let result = aggregator.coordinate_wise_median(&client_updates);
        assert!(result.is_ok());

        let median = result.expect("median should succeed for valid input");
        assert_eq!(median[0], 2.0);
        assert_eq!(median[1], 5.0);
        assert_eq!(median[2], 8.0);
    }

    /// Reputations start at the initial value, drop for outliers and are
    /// capped at 1.0 for well-behaved clients.
    #[test]
    fn test_reputation_system() {
        let mut aggregator =
            ByzantineRobustAggregator::<f64>::new().expect("default aggregator should build");

        let reputations = aggregator.get_client_reputations(&["client1".to_string()]);
        assert_eq!(reputations.get("client1"), Some(&1.0));

        aggregator.update_client_reputation("client1".to_string(), true);
        let updated_reputations = aggregator.get_client_reputations(&["client1".to_string()]);
        assert!(
            updated_reputations
                .get("client1")
                .expect("client1 must have a reputation")
                < &1.0
        );

        aggregator.update_client_reputation("client2".to_string(), false);
        let good_reputations = aggregator.get_client_reputations(&["client2".to_string()]);
        // Already at the cap, so the bonus leaves it at 1.0.
        assert_eq!(good_reputations.get("client2"), Some(&1.0));
    }
}