optirs_core/privacy/
byzantine_tolerance.rs

1// Byzantine Fault Tolerance for Federated Learning
2//
3// This module implements Byzantine-robust aggregation algorithms that can
4// tolerate malicious participants in federated learning scenarios.
5
6#[allow(dead_code)]
7use crate::error::{OptimError, Result};
8use scirs2_core::ndarray::{Array1, Array2};
9use scirs2_core::numeric::Float;
10use std::cmp::Ordering;
11use std::collections::HashMap;
12use std::fmt::Debug;
13
14/// Type alias for validation rule function
15type RuleFn<T> = Box<dyn Fn(&Array1<T>) -> bool + Send + Sync>;
16
17/// Byzantine fault tolerant aggregator
18pub struct ByzantineTolerantAggregator<T: Float + Debug + Send + Sync + 'static> {
19    /// Configuration for Byzantine tolerance
20    config: ByzantineConfig,
21
22    /// Participant reputation scores
23    reputation_scores: HashMap<String, ReputationScore>,
24
25    /// History of participant behavior
26    behavior_history: HashMap<String, BehaviorHistory>,
27
28    /// Anomaly detection engine
29    anomaly_detector: AnomalyDetector<T>,
30
31    /// Statistical analysis engine
32    statistics_engine: StatisticalAnalysis<T>,
33
34    /// Gradient verification system
35    gradient_verifier: GradientVerifier<T>,
36}
37
38#[derive(Debug, Clone)]
39pub struct ByzantineConfig {
40    /// Maximum number of Byzantine participants to tolerate
41    pub max_byzantine: usize,
42
43    /// Minimum number of participants required
44    pub min_participants: usize,
45
46    /// Aggregation method for Byzantine tolerance
47    pub aggregation_method: ByzantineAggregationMethod,
48
49    /// Anomaly detection threshold
50    pub anomaly_threshold: f64,
51
52    /// Reputation decay factor
53    pub reputation_decay: f64,
54
55    /// Enable gradient verification
56    pub gradient_verification: bool,
57
58    /// Statistical outlier detection
59    pub outlier_detection: OutlierDetectionMethod,
60
61    /// Consensus threshold for decision making
62    pub consensus_threshold: f64,
63}
64
/// Aggregation rules the aggregator can dispatch to.
#[derive(Clone, Copy, Debug)]
pub enum ByzantineAggregationMethod {
    /// Per-coordinate mean taken after discarding the extreme values.
    TrimmedMean,

    /// Per-coordinate median of the submitted gradients.
    CoordinateMedian,

    /// Krum: pick the single gradient that lies closest to its nearest
    /// neighbours.
    Krum,

    /// Multi-Krum: average several of the best-scoring gradients under the
    /// Krum criterion.
    MultiKrum,

    /// Bulyan: Krum-based selection followed by a trimmed mean.
    Bulyan,

    /// FoolsGold: weighted average intended to defend against Sybil attacks.
    FoolsGold,

    /// FLAME: cluster the gradients and aggregate the largest cluster.
    FLAME,

    /// Median-based aggregation (handled as the coordinate-wise median).
    Median,

    /// Geometric median of the submitted gradients.
    GeometricMedian,
}
95
/// Statistical techniques that can be selected for outlier detection.
#[derive(Clone, Copy, Debug)]
pub enum OutlierDetectionMethod {
    /// Detection based on Z-scores.
    ZScore,

    /// Detection based on the interquartile range.
    IQR,

    /// Isolation forest.
    IsolationForest,

    /// Local outlier factor.
    LocalOutlierFactor,

    /// Mahalanobis distance.
    MahalanobisDistance,
}
114
115/// Reputation score for participants
116#[derive(Debug, Clone)]
117pub struct ReputationScore {
118    /// Current reputation score (0.0 to 1.0)
119    pub score: f64,
120
121    /// Number of successful aggregations
122    pub successful_aggregations: usize,
123
124    /// Number of detected anomalies
125    pub detected_anomalies: usize,
126
127    /// Average gradient quality score
128    pub gradient_quality: f64,
129
130    /// Consistency score across rounds
131    pub consistency_score: f64,
132
133    /// Trust level
134    pub trust_level: TrustLevel,
135}
136
/// Trust tier assigned to a participant from its reputation score.
#[derive(Clone, Copy, Debug)]
pub enum TrustLevel {
    /// Highly trusted participant.
    High,

    /// Moderately trusted participant.
    Medium,

    /// Participant with low trust.
    Low,

    /// Blacklisted participant; its gradients are dropped by the reputation
    /// pre-filter before any detection or aggregation happens.
    Blacklisted,
}
152
/// Per-participant record of observed behavior across rounds.
#[derive(Clone, Debug)]
pub struct BehaviorHistory {
    /// Gradient norms observed so far.
    pub gradient_norms: Vec<f64>,

    /// Gradient direction history, stored as cosine similarities.
    pub gradient_similarities: Vec<f64>,

    /// Participation pattern history.
    pub participation_pattern: Vec<bool>,

    /// Anomaly scores observed so far.
    pub anomaly_scores: Vec<f64>,

    /// Number of rounds the participant has taken part in.
    pub rounds_participated: usize,
}
171
172/// Anomaly detection engine
173pub struct AnomalyDetector<T: Float + Debug + Send + Sync + 'static> {
174    /// Detection threshold
175    threshold: f64,
176
177    /// Historical gradient statistics
178    gradient_stats: GradientStatistics<T>,
179
180    /// Pattern recognition model
181    pattern_model: PatternModel<T>,
182}
183
184/// Gradient statistics for anomaly detection
185#[derive(Debug, Clone)]
186pub struct GradientStatistics<T: Float + Debug + Send + Sync + 'static> {
187    /// Mean gradient
188    pub mean: Array1<T>,
189
190    /// Gradient covariance matrix
191    pub covariance: Array2<T>,
192
193    /// Historical gradient norms
194    pub norm_history: Vec<T>,
195
196    /// Gradient direction patterns
197    pub direction_patterns: Array2<T>,
198}
199
200/// Pattern recognition model for detecting malicious behavior
201pub struct PatternModel<T: Float + Debug + Send + Sync + 'static> {
202    /// Reference patterns for normal behavior
203    normal_patterns: Vec<Array1<T>>,
204
205    /// Reference patterns for attack behaviors
206    attack_patterns: Vec<Array1<T>>,
207
208    /// Pattern matching threshold
209    matching_threshold: f64,
210}
211
212/// Statistical analysis engine
213pub struct StatisticalAnalysis<T: Float + Debug + Send + Sync + 'static> {
214    /// Window size for statistical analysis
215    window_size: usize,
216
217    /// Statistical measures
218    measures: StatisticalMeasures<T>,
219}
220
221/// Statistical measures for gradient analysis
222#[derive(Debug, Clone)]
223pub struct StatisticalMeasures<T: Float + Debug + Send + Sync + 'static> {
224    /// Mean of gradients
225    pub mean: Array1<T>,
226
227    /// Standard deviation
228    pub std_dev: Array1<T>,
229
230    /// Median
231    pub median: Array1<T>,
232
233    /// Interquartile range
234    pub iqr: Array1<T>,
235
236    /// Skewness
237    pub skewness: Array1<T>,
238
239    /// Kurtosis
240    pub kurtosis: Array1<T>,
241}
242
243/// Gradient verification system
244pub struct GradientVerifier<T: Float + Debug + Send + Sync + 'static> {
245    /// Expected gradient properties
246    expected_properties: GradientProperties<T>,
247
248    /// Verification rules
249    verification_rules: Vec<VerificationRule<T>>,
250}
251
252/// Expected gradient properties
253#[derive(Debug, Clone)]
254pub struct GradientProperties<T: Float + Debug + Send + Sync + 'static> {
255    /// Expected norm range
256    pub norm_range: (T, T),
257
258    /// Expected sparsity
259    pub sparsity_threshold: f64,
260
261    /// Expected direction consistency
262    pub direction_consistency: f64,
263}
264
265/// Verification rule for gradients
266pub struct VerificationRule<T: Float + Debug + Send + Sync + 'static> {
267    /// Rule name
268    pub name: String,
269
270    /// Rule function
271    pub rule_fn: RuleFn<T>,
272
273    /// Rule weight in verification
274    pub weight: f64,
275}
276
277impl<T: Float + Debug + Send + Sync + 'static + scirs2_core::ndarray::ScalarOperand>
278    ByzantineTolerantAggregator<T>
279{
280    /// Create new Byzantine tolerant aggregator
281    pub fn new(config: ByzantineConfig) -> Self {
282        let anomaly_threshold = config.anomaly_threshold;
283        Self {
284            config,
285            reputation_scores: HashMap::new(),
286            behavior_history: HashMap::new(),
287            anomaly_detector: AnomalyDetector::new(anomaly_threshold),
288            statistics_engine: StatisticalAnalysis::new(100), // 100-round window
289            gradient_verifier: GradientVerifier::new(),
290        }
291    }
292
293    /// Perform Byzantine-robust aggregation
294    pub fn byzantine_robust_aggregate(
295        &mut self,
296        participant_gradients: &HashMap<String, Array1<T>>,
297    ) -> Result<ByzantineAggregationResult<T>> {
298        // Step 1: Pre-filtering based on reputation
299        let filtered_participants = self.filter_by_reputation(participant_gradients)?;
300
301        // Step 2: Anomaly detection
302        let anomaly_results = self.detect_anomalies(&filtered_participants)?;
303
304        // Step 3: Statistical outlier detection
305        let outlier_results = self.detect_statistical_outliers(&filtered_participants)?;
306
307        // Step 4: Gradient verification
308        let verification_results = if self.config.gradient_verification {
309            self.verify_gradients(&filtered_participants)?
310        } else {
311            HashMap::new()
312        };
313
314        // Step 5: Identify Byzantine participants
315        let byzantine_participants = self.identify_byzantine_participants(
316            &anomaly_results,
317            &outlier_results,
318            &verification_results,
319        )?;
320
321        // Step 6: Select honest participants
322        let honest_participants =
323            self.select_honest_participants(&filtered_participants, &byzantine_participants)?;
324
325        // Step 7: Perform robust aggregation
326        let aggregate = self.perform_robust_aggregation(&honest_participants)?;
327
328        // Step 8: Update participant reputations
329        self.update_reputations(&honest_participants, &byzantine_participants)?;
330
331        Ok(ByzantineAggregationResult {
332            aggregate,
333            honest_participants: honest_participants.keys().cloned().collect(),
334            byzantine_participants,
335            reputation_updates: self.get_reputation_updates(),
336            aggregation_method: self.config.aggregation_method,
337            confidence_score: self.calculate_confidence_score(&honest_participants),
338        })
339    }
340
341    /// Filter participants based on reputation scores
342    fn filter_by_reputation(
343        &self,
344        gradients: &HashMap<String, Array1<T>>,
345    ) -> Result<HashMap<String, Array1<T>>> {
346        let mut filtered = HashMap::new();
347
348        for (participant_id, gradient) in gradients {
349            if let Some(reputation) = self.reputation_scores.get(participant_id) {
350                if !matches!(reputation.trust_level, TrustLevel::Blacklisted) {
351                    filtered.insert(participant_id.clone(), gradient.clone());
352                }
353            } else {
354                // New participant - allow with medium trust
355                filtered.insert(participant_id.clone(), gradient.clone());
356            }
357        }
358
359        Ok(filtered)
360    }
361
362    /// Detect anomalies in gradients
363    fn detect_anomalies(
364        &mut self,
365        gradients: &HashMap<String, Array1<T>>,
366    ) -> Result<HashMap<String, AnomalyScore>> {
367        let mut anomaly_results = HashMap::new();
368
369        for (participant_id, gradient) in gradients {
370            let anomaly_score = self.anomaly_detector.detect_anomaly(gradient)?;
371            anomaly_results.insert(participant_id.clone(), anomaly_score);
372        }
373
374        Ok(anomaly_results)
375    }
376
377    /// Detect statistical outliers
378    fn detect_statistical_outliers(
379        &mut self,
380        gradients: &HashMap<String, Array1<T>>,
381    ) -> Result<HashMap<String, OutlierScore>> {
382        let mut outlier_results = HashMap::new();
383
384        // Collect all gradients for statistical analysis
385        let gradient_values: Vec<&Array1<T>> = gradients.values().collect();
386        let stats = self
387            .statistics_engine
388            .compute_statistics(&gradient_values)?;
389
390        for (participant_id, gradient) in gradients {
391            let outlier_score = self.compute_outlier_score(gradient, &stats)?;
392            outlier_results.insert(participant_id.clone(), outlier_score);
393        }
394
395        Ok(outlier_results)
396    }
397
398    /// Verify gradients using verification rules
399    fn verify_gradients(
400        &self,
401        gradients: &HashMap<String, Array1<T>>,
402    ) -> Result<HashMap<String, VerificationScore>> {
403        let mut verification_results = HashMap::new();
404
405        for (participant_id, gradient) in gradients {
406            let verification_score = self.gradient_verifier.verify_gradient(gradient)?;
407            verification_results.insert(participant_id.clone(), verification_score);
408        }
409
410        Ok(verification_results)
411    }
412
413    /// Identify Byzantine participants based on multiple criteria
414    fn identify_byzantine_participants(
415        &self,
416        anomaly_results: &HashMap<String, AnomalyScore>,
417        outlier_results: &HashMap<String, OutlierScore>,
418        verification_results: &HashMap<String, VerificationScore>,
419    ) -> Result<Vec<String>> {
420        let mut byzantine_participants = Vec::new();
421
422        for participant_id in anomaly_results.keys() {
423            let anomaly_score = anomaly_results.get(participant_id).unwrap();
424            let outlier_score = outlier_results.get(participant_id).unwrap();
425            let verification_score = verification_results.get(participant_id);
426
427            // Combine scores to determine if participant is Byzantine
428            let combined_score =
429                self.compute_byzantine_score(anomaly_score, outlier_score, verification_score);
430
431            if combined_score > self.config.anomaly_threshold {
432                byzantine_participants.push(participant_id.clone());
433            }
434        }
435
436        Ok(byzantine_participants)
437    }
438
439    /// Select honest participants for aggregation
440    fn select_honest_participants(
441        &self,
442        all_participants: &HashMap<String, Array1<T>>,
443        byzantine_participants: &[String],
444    ) -> Result<HashMap<String, Array1<T>>> {
445        let mut honest_participants = HashMap::new();
446
447        for (participant_id, gradient) in all_participants {
448            if !byzantine_participants.contains(participant_id) {
449                honest_participants.insert(participant_id.clone(), gradient.clone());
450            }
451        }
452
453        // Ensure we have enough honest _participants
454        if honest_participants.len() < self.config.min_participants {
455            return Err(OptimError::InvalidConfig(
456                "Insufficient honest _participants for aggregation".to_string(),
457            ));
458        }
459
460        Ok(honest_participants)
461    }
462
463    /// Perform robust aggregation using the configured method
464    fn perform_robust_aggregation(
465        &self,
466        honest_gradients: &HashMap<String, Array1<T>>,
467    ) -> Result<Array1<T>> {
468        match self.config.aggregation_method {
469            ByzantineAggregationMethod::TrimmedMean => {
470                self.trimmed_mean_aggregation(honest_gradients)
471            }
472            ByzantineAggregationMethod::CoordinateMedian => {
473                self.coordinate_median_aggregation(honest_gradients)
474            }
475            ByzantineAggregationMethod::Krum => self.krum_aggregation(honest_gradients),
476            ByzantineAggregationMethod::MultiKrum => self.multi_krum_aggregation(honest_gradients),
477            ByzantineAggregationMethod::Bulyan => self.bulyan_aggregation(honest_gradients),
478            ByzantineAggregationMethod::FoolsGold => self.fools_gold_aggregation(honest_gradients),
479            ByzantineAggregationMethod::FLAME => self.flame_aggregation(honest_gradients),
480            ByzantineAggregationMethod::Median => self.median_aggregation(honest_gradients),
481            ByzantineAggregationMethod::GeometricMedian => {
482                self.geometric_median_aggregation(honest_gradients)
483            }
484        }
485    }
486
487    /// Trimmed mean aggregation
488    fn trimmed_mean_aggregation(
489        &self,
490        gradients: &HashMap<String, Array1<T>>,
491    ) -> Result<Array1<T>> {
492        if gradients.is_empty() {
493            return Err(OptimError::InvalidConfig(
494                "No gradients to aggregate".to_string(),
495            ));
496        }
497
498        let values: Vec<&Array1<T>> = gradients.values().collect();
499        let first_gradient = values[0];
500        let dim = first_gradient.len();
501        let mut result = Array1::zeros(dim);
502
503        // For each coordinate, compute trimmed mean
504        for i in 0..dim {
505            let mut coord_values: Vec<T> = values.iter().map(|g| g[i]).collect();
506            coord_values.sort_by(|a, b| a.partial_cmp(b).unwrap_or(Ordering::Equal));
507
508            // Remove top and bottom 10% (trim parameter), but ensure at least 1 element is trimmed if we have outliers
509            let trim_count = std::cmp::max(1, (coord_values.len() as f64 * 0.1) as usize);
510            let start_idx = std::cmp::min(trim_count, coord_values.len() / 2);
511            let end_idx = std::cmp::max(coord_values.len() - trim_count, coord_values.len() / 2);
512            let trimmed_values = &coord_values[start_idx..end_idx];
513
514            if !trimmed_values.is_empty() {
515                let sum: T = trimmed_values
516                    .iter()
517                    .copied()
518                    .fold(T::zero(), |acc, x| acc + x);
519                result[i] = sum / T::from(trimmed_values.len()).unwrap();
520            }
521        }
522
523        Ok(result)
524    }
525
526    /// Coordinate-wise median aggregation
527    fn coordinate_median_aggregation(
528        &self,
529        gradients: &HashMap<String, Array1<T>>,
530    ) -> Result<Array1<T>> {
531        if gradients.is_empty() {
532            return Err(OptimError::InvalidConfig(
533                "No gradients to aggregate".to_string(),
534            ));
535        }
536
537        let values: Vec<&Array1<T>> = gradients.values().collect();
538        let first_gradient = values[0];
539        let dim = first_gradient.len();
540        let mut result = Array1::zeros(dim);
541
542        // For each coordinate, compute median
543        for i in 0..dim {
544            let mut coord_values: Vec<T> = values.iter().map(|g| g[i]).collect();
545            coord_values.sort_by(|a, b| a.partial_cmp(b).unwrap_or(Ordering::Equal));
546
547            let median = if coord_values.len().is_multiple_of(2) {
548                let mid = coord_values.len() / 2;
549                (coord_values[mid - 1] + coord_values[mid])
550                    / T::from(2.0).unwrap_or_else(|| T::zero())
551            } else {
552                coord_values[coord_values.len() / 2]
553            };
554
555            result[i] = median;
556        }
557
558        Ok(result)
559    }
560
561    /// Krum aggregation (select single most representative gradient)
562    fn krum_aggregation(&self, gradients: &HashMap<String, Array1<T>>) -> Result<Array1<T>> {
563        if gradients.is_empty() {
564            return Err(OptimError::InvalidConfig(
565                "No gradients to aggregate".to_string(),
566            ));
567        }
568
569        let participants: Vec<&String> = gradients.keys().collect();
570        let mut min_score = T::infinity();
571        let mut selected_gradient = None;
572
573        // For each gradient, compute Krum score
574        for (i, participant) in participants.iter().enumerate() {
575            let gradient = &gradients[*participant];
576            let mut score = T::zero();
577            let mut distances = Vec::new();
578
579            // Compute distances to all other gradients
580            for (j, other_participant) in participants.iter().enumerate() {
581                if i != j {
582                    let other_gradient = &gradients[*other_participant];
583                    let distance = self.compute_euclidean_distance(gradient, other_gradient)?;
584                    distances.push(distance);
585                }
586            }
587
588            // Sort distances and take sum of smallest (n - f - 2) distances
589            distances.sort_by(|a, b| a.partial_cmp(b).unwrap_or(Ordering::Equal));
590            let take_count = (participants.len() - self.config.max_byzantine - 2).max(1);
591
592            for &distance in distances.iter().take(take_count) {
593                score = score + distance;
594            }
595
596            if score < min_score {
597                min_score = score;
598                selected_gradient = Some(gradient.clone());
599            }
600        }
601
602        selected_gradient.ok_or_else(|| {
603            OptimError::InvalidConfig("Failed to select gradient with Krum".to_string())
604        })
605    }
606
607    /// Multi-Krum aggregation (select multiple representative gradients)
608    fn multi_krum_aggregation(&self, gradients: &HashMap<String, Array1<T>>) -> Result<Array1<T>> {
609        if gradients.is_empty() {
610            return Err(OptimError::InvalidConfig(
611                "No gradients to aggregate".to_string(),
612            ));
613        }
614
615        // Select top-k gradients using Krum scoring
616        let k = (gradients.len() - self.config.max_byzantine).max(1);
617        let selected_gradients = self.select_top_k_krum(gradients, k)?;
618
619        // Average the selected gradients
620        let first_gradient = selected_gradients.values().next().unwrap();
621        let mut result = Array1::zeros(first_gradient.len());
622
623        for gradient in selected_gradients.values() {
624            result = result + gradient;
625        }
626
627        result = result / T::from(selected_gradients.len()).unwrap();
628        Ok(result)
629    }
630
631    /// Bulyan aggregation
632    fn bulyan_aggregation(&self, gradients: &HashMap<String, Array1<T>>) -> Result<Array1<T>> {
633        // Bulyan combines Multi-Krum with trimmed mean
634        let selected_gradients =
635            self.select_top_k_krum(gradients, gradients.len() - self.config.max_byzantine)?;
636        self.trimmed_mean_aggregation(&selected_gradients)
637    }
638
639    /// FoolsGold aggregation (defend against Sybil attacks)
640    fn fools_gold_aggregation(&self, gradients: &HashMap<String, Array1<T>>) -> Result<Array1<T>> {
641        if gradients.is_empty() {
642            return Err(OptimError::InvalidConfig(
643                "No gradients to aggregate".to_string(),
644            ));
645        }
646
647        // Compute learning rates based on historical cosine similarities
648        let learning_rates = self.compute_fools_gold_weights(gradients)?;
649
650        // Weighted aggregation
651        let first_gradient = gradients.values().next().unwrap();
652        let mut result = Array1::zeros(first_gradient.len());
653        let mut total_weight = T::zero();
654
655        for (participant_id, gradient) in gradients {
656            let weight = learning_rates
657                .get(participant_id)
658                .copied()
659                .unwrap_or(T::one());
660            result = result + gradient * weight;
661            total_weight = total_weight + weight;
662        }
663
664        if total_weight > T::zero() {
665            result = result / total_weight;
666        }
667
668        Ok(result)
669    }
670
671    /// FLAME aggregation
672    fn flame_aggregation(&self, gradients: &HashMap<String, Array1<T>>) -> Result<Array1<T>> {
673        // FLAME uses clustering to identify and filter out Byzantine gradients
674        let clusters = self.cluster_gradients(gradients)?;
675        let largest_cluster = self.find_largest_cluster(&clusters)?;
676
677        // Aggregate gradients from the largest cluster
678        self.average_aggregation(&largest_cluster)
679    }
680
681    /// Simple median aggregation
682    fn median_aggregation(&self, gradients: &HashMap<String, Array1<T>>) -> Result<Array1<T>> {
683        self.coordinate_median_aggregation(gradients)
684    }
685
686    /// Geometric median aggregation
687    fn geometric_median_aggregation(
688        &self,
689        gradients: &HashMap<String, Array1<T>>,
690    ) -> Result<Array1<T>> {
691        if gradients.is_empty() {
692            return Err(OptimError::InvalidConfig(
693                "No gradients to aggregate".to_string(),
694            ));
695        }
696
697        // Geometric median using Weiszfeld's algorithm
698        let values: Vec<&Array1<T>> = gradients.values().collect();
699        let first_gradient = values[0];
700        let mut current = first_gradient.clone();
701
702        // Iterative algorithm
703        for _ in 0..100 {
704            // Maximum iterations
705            let mut numerator = Array1::zeros(current.len());
706            let mut denominator = T::zero();
707
708            for &gradient in &values {
709                let distance = self
710                    .compute_euclidean_distance(&current, gradient)
711                    .unwrap_or(T::one());
712                if distance > T::zero() {
713                    let weight = T::one() / distance;
714                    numerator = numerator + gradient * weight;
715                    denominator = denominator + weight;
716                }
717            }
718
719            if denominator > T::zero() {
720                let new_estimate = numerator / denominator;
721
722                // Check convergence
723                let change = self
724                    .compute_euclidean_distance(&current, &new_estimate)
725                    .unwrap_or(T::zero());
726                if change < T::from(1e-6).unwrap_or_else(|| T::zero()) {
727                    break;
728                }
729
730                current = new_estimate;
731            }
732        }
733
734        Ok(current)
735    }
736
737    /// Simple average aggregation
738    fn average_aggregation(&self, gradients: &HashMap<String, Array1<T>>) -> Result<Array1<T>> {
739        if gradients.is_empty() {
740            return Err(OptimError::InvalidConfig(
741                "No gradients to aggregate".to_string(),
742            ));
743        }
744
745        let first_gradient = gradients.values().next().unwrap();
746        let mut result = Array1::zeros(first_gradient.len());
747
748        for gradient in gradients.values() {
749            result = result + gradient;
750        }
751
752        result = result / T::from(gradients.len()).unwrap();
753        Ok(result)
754    }
755
756    /// Update participant reputations based on aggregation results
757    fn update_reputations(
758        &mut self,
759        honest_participants: &HashMap<String, Array1<T>>,
760        byzantine_participants: &[String],
761    ) -> Result<()> {
762        // Update honest _participants (increase reputation)
763        for participant_id in honest_participants.keys() {
764            let reputation = self
765                .reputation_scores
766                .entry(participant_id.clone())
767                .or_default();
768
769            reputation.successful_aggregations += 1;
770            reputation.score = (reputation.score * 0.9 + 0.1).min(1.0);
771
772            // Update trust level based on score
773            reputation.trust_level = match reputation.score {
774                s if s >= 0.8 => TrustLevel::High,
775                s if s >= 0.5 => TrustLevel::Medium,
776                _ => TrustLevel::Low,
777            };
778        }
779
780        // Update Byzantine _participants (decrease reputation)
781        for participant_id in byzantine_participants {
782            let reputation = self
783                .reputation_scores
784                .entry(participant_id.clone())
785                .or_default();
786
787            reputation.detected_anomalies += 1;
788            reputation.score = (reputation.score * 0.5).max(0.0);
789
790            // Update trust level
791            reputation.trust_level = if reputation.score < 0.1 {
792                TrustLevel::Blacklisted
793            } else if reputation.score < 0.3 {
794                TrustLevel::Low
795            } else {
796                TrustLevel::Medium
797            };
798        }
799
800        Ok(())
801    }
802
803    /// Compute Byzantine score combining multiple detection methods
804    fn compute_byzantine_score(
805        &self,
806        anomaly_score: &AnomalyScore,
807        outlier_score: &OutlierScore,
808        verification_score: Option<&VerificationScore>,
809    ) -> f64 {
810        let mut combined_score = 0.0;
811
812        // Weight anomaly score (40%)
813        combined_score += anomaly_score.score * 0.4;
814
815        // Weight outlier score (30%)
816        combined_score += outlier_score.score * 0.3;
817
818        // Weight verification score (30%)
819        if let Some(verification) = verification_score {
820            combined_score += (1.0 - verification.score) * 0.3;
821        }
822
823        combined_score
824    }
825
826    /// Calculate confidence score for aggregation
827    fn calculate_confidence_score(&self, honest_participants: &HashMap<String, Array1<T>>) -> f64 {
828        let honest_count = honest_participants.len() as f64;
829        let total_expected = (self.config.min_participants + self.config.max_byzantine) as f64;
830
831        (honest_count / total_expected).min(1.0)
832    }
833
834    /// Compute Euclidean distance between two gradients
835    fn compute_euclidean_distance(&self, a: &Array1<T>, b: &Array1<T>) -> Result<T> {
836        if a.len() != b.len() {
837            return Err(OptimError::InvalidConfig(
838                "Gradient dimensions don't match".to_string(),
839            ));
840        }
841
842        let mut sum = T::zero();
843        for (x, y) in a.iter().zip(b.iter()) {
844            let diff = *x - *y;
845            sum = sum + diff * diff;
846        }
847
848        Ok(sum.sqrt())
849    }
850
851    /// Select top-k gradients using Krum scoring
852    fn select_top_k_krum(
853        &self,
854        gradients: &HashMap<String, Array1<T>>,
855        k: usize,
856    ) -> Result<HashMap<String, Array1<T>>> {
857        let mut scores = Vec::new();
858        let participants: Vec<&String> = gradients.keys().collect();
859
860        // Compute Krum scores for all participants
861        for (i, participant) in participants.iter().enumerate() {
862            let gradient = &gradients[*participant];
863            let mut score = T::zero();
864            let mut distances = Vec::new();
865
866            for (j, other_participant) in participants.iter().enumerate() {
867                if i != j {
868                    let other_gradient = &gradients[*other_participant];
869                    let distance = self.compute_euclidean_distance(gradient, other_gradient)?;
870                    distances.push(distance);
871                }
872            }
873
874            distances.sort_by(|a, b| a.partial_cmp(b).unwrap_or(Ordering::Equal));
875            let take_count = (participants.len() - self.config.max_byzantine - 2).max(1);
876
877            for &distance in distances.iter().take(take_count) {
878                score = score + distance;
879            }
880
881            scores.push(((*participant).clone(), score));
882        }
883
884        // Sort by score and select top-k
885        scores.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(Ordering::Equal));
886
887        let mut selected = HashMap::new();
888        for (participant_id, _) in scores.into_iter().take(k) {
889            if let Some(gradient) = gradients.get(&participant_id) {
890                selected.insert(participant_id, gradient.clone());
891            }
892        }
893
894        Ok(selected)
895    }
896
897    /// Compute FoolsGold weights based on historical similarities
898    fn compute_fools_gold_weights(
899        &self,
900        gradients: &HashMap<String, Array1<T>>,
901    ) -> Result<HashMap<String, T>> {
902        let mut weights = HashMap::new();
903
904        for participant_id in gradients.keys() {
905            // Use reputation score as initial weight
906            let base_weight = if let Some(reputation) = self.reputation_scores.get(participant_id) {
907                T::from(reputation.score).unwrap_or_else(|| T::zero())
908            } else {
909                T::one()
910            };
911
912            weights.insert(participant_id.clone(), base_weight);
913        }
914
915        Ok(weights)
916    }
917
918    /// Cluster gradients for FLAME algorithm
919    fn cluster_gradients(
920        &self,
921        gradients: &HashMap<String, Array1<T>>,
922    ) -> Result<Vec<HashMap<String, Array1<T>>>> {
923        // Simple clustering based on cosine similarity
924        let mut clusters = Vec::new();
925        let mut unassigned: HashMap<String, Array1<T>> = gradients.clone();
926
927        while !unassigned.is_empty() {
928            let mut current_cluster = HashMap::new();
929
930            // Start new cluster with first unassigned gradient
931            let (first_id, first_gradient) = unassigned.iter().next().unwrap();
932            let first_id = first_id.clone();
933            let first_gradient = first_gradient.clone();
934
935            current_cluster.insert(first_id.clone(), first_gradient.clone());
936            unassigned.remove(&first_id);
937
938            // Add similar gradients to cluster
939            let mut to_remove = Vec::new();
940            for (participant_id, gradient) in &unassigned {
941                let similarity = self.compute_cosine_similarity(&first_gradient, gradient)?;
942                if similarity > T::from(0.8).unwrap_or_else(|| T::zero()) {
943                    // Similarity threshold
944                    current_cluster.insert(participant_id.clone(), gradient.clone());
945                    to_remove.push(participant_id.clone());
946                }
947            }
948
949            for id in to_remove {
950                unassigned.remove(&id);
951            }
952
953            clusters.push(current_cluster);
954        }
955
956        Ok(clusters)
957    }
958
959    /// Find the largest cluster
960    fn find_largest_cluster(
961        &self,
962        clusters: &[HashMap<String, Array1<T>>],
963    ) -> Result<HashMap<String, Array1<T>>> {
964        clusters
965            .iter()
966            .max_by_key(|cluster| cluster.len())
967            .cloned()
968            .ok_or_else(|| OptimError::InvalidConfig("No clusters found".to_string()))
969    }
970
971    /// Compute cosine similarity between two gradients
972    fn compute_cosine_similarity(&self, a: &Array1<T>, b: &Array1<T>) -> Result<T> {
973        if a.len() != b.len() {
974            return Err(OptimError::InvalidConfig(
975                "Gradient dimensions don't match".to_string(),
976            ));
977        }
978
979        let mut dot_product = T::zero();
980        let mut norm_a = T::zero();
981        let mut norm_b = T::zero();
982
983        for (x, y) in a.iter().zip(b.iter()) {
984            dot_product = dot_product + *x * *y;
985            norm_a = norm_a + *x * *x;
986            norm_b = norm_b + *y * *y;
987        }
988
989        norm_a = norm_a.sqrt();
990        norm_b = norm_b.sqrt();
991
992        if norm_a > T::zero() && norm_b > T::zero() {
993            Ok(dot_product / (norm_a * norm_b))
994        } else {
995            Ok(T::zero())
996        }
997    }
998
999    /// Compute outlier score for a gradient
1000    fn compute_outlier_score(
1001        &self,
1002        gradient: &Array1<T>,
1003        stats: &StatisticalMeasures<T>,
1004    ) -> Result<OutlierScore> {
1005        match self.config.outlier_detection {
1006            OutlierDetectionMethod::ZScore => {
1007                let mut max_z_score = T::zero();
1008
1009                for i in 0..gradient.len() {
1010                    if stats.std_dev[i] > T::zero() {
1011                        let z_score = ((gradient[i] - stats.mean[i]) / stats.std_dev[i]).abs();
1012                        if z_score > max_z_score {
1013                            max_z_score = z_score;
1014                        }
1015                    }
1016                }
1017
1018                Ok(OutlierScore {
1019                    score: max_z_score.to_f64().unwrap_or(0.0),
1020                    method: OutlierDetectionMethod::ZScore,
1021                    details: format!("Max Z-score: {:.4}", max_z_score.to_f64().unwrap_or(0.0)),
1022                })
1023            }
1024
1025            OutlierDetectionMethod::IQR => {
1026                let mut max_iqr_score = 0.0;
1027
1028                for i in 0..gradient.len() {
1029                    let q1 =
1030                        stats.median[i] - stats.iqr[i] / T::from(2.0).unwrap_or_else(|| T::zero());
1031                    let q3 =
1032                        stats.median[i] + stats.iqr[i] / T::from(2.0).unwrap_or_else(|| T::zero());
1033
1034                    if gradient[i] < q1 || gradient[i] > q3 {
1035                        let iqr_score = if gradient[i] < q1 {
1036                            (q1 - gradient[i]) / stats.iqr[i]
1037                        } else {
1038                            (gradient[i] - q3) / stats.iqr[i]
1039                        };
1040
1041                        let score = iqr_score.to_f64().unwrap_or(0.0);
1042                        if score > max_iqr_score {
1043                            max_iqr_score = score;
1044                        }
1045                    }
1046                }
1047
1048                Ok(OutlierScore {
1049                    score: max_iqr_score,
1050                    method: OutlierDetectionMethod::IQR,
1051                    details: format!("Max IQR score: {:.4}", max_iqr_score),
1052                })
1053            }
1054
1055            _ => {
1056                // Fallback to Z-score for other methods
1057                self.compute_outlier_score(gradient, stats)
1058            }
1059        }
1060    }
1061
1062    /// Get reputation updates
1063    fn get_reputation_updates(&self) -> HashMap<String, ReputationScore> {
1064        self.reputation_scores.clone()
1065    }
1066}
1067
/// Outcome of running anomaly detection on a participant's gradient.
#[derive(Clone, Debug)]
pub struct AnomalyScore {
    /// Anomaly level, where 0.0 means normal and 1.0 means highly anomalous.
    pub score: f64,

    /// Name of the detection method that produced the score.
    pub method: String,

    /// Free-form diagnostic text accompanying the score.
    pub details: String,
}
1080
1081/// Outlier score for participant
1082#[derive(Debug, Clone)]
1083pub struct OutlierScore {
1084    /// Outlier score
1085    pub score: f64,
1086
1087    /// Detection method used
1088    pub method: OutlierDetectionMethod,
1089
1090    /// Additional details
1091    pub details: String,
1092}
1093
/// Outcome of checking a gradient against the verification rules.
#[derive(Clone, Debug)]
pub struct VerificationScore {
    /// Overall verification score, where 0.0 means failed and 1.0 passed.
    pub score: f64,

    /// Score of each individual rule, keyed by rule name.
    pub rule_scores: HashMap<String, f64>,

    /// Whether the gradient passed verification overall.
    pub passed: bool,
}
1106
1107/// Byzantine aggregation result
1108#[derive(Debug, Clone)]
1109pub struct ByzantineAggregationResult<T: Float + Debug + Send + Sync + 'static> {
1110    /// Aggregated gradient
1111    pub aggregate: Array1<T>,
1112
1113    /// List of honest participants
1114    pub honest_participants: Vec<String>,
1115
1116    /// List of detected Byzantine participants
1117    pub byzantine_participants: Vec<String>,
1118
1119    /// Updated reputation scores
1120    pub reputation_updates: HashMap<String, ReputationScore>,
1121
1122    /// Aggregation method used
1123    pub aggregation_method: ByzantineAggregationMethod,
1124
1125    /// Confidence score of the aggregation
1126    pub confidence_score: f64,
1127}
1128
1129impl Default for ReputationScore {
1130    fn default() -> Self {
1131        Self::new()
1132    }
1133}
1134
1135impl ReputationScore {
1136    /// Create new reputation score with default values
1137    pub fn new() -> Self {
1138        Self {
1139            score: 0.7, // Start with medium trust
1140            successful_aggregations: 0,
1141            detected_anomalies: 0,
1142            gradient_quality: 0.5,
1143            consistency_score: 0.5,
1144            trust_level: TrustLevel::Medium,
1145        }
1146    }
1147}
1148
1149impl<T: Float + Debug + Send + Sync + 'static + scirs2_core::ndarray::ScalarOperand>
1150    AnomalyDetector<T>
1151{
1152    /// Create new anomaly detector
1153    pub fn new(threshold: f64) -> Self {
1154        Self {
1155            threshold,
1156            gradient_stats: GradientStatistics::new(),
1157            pattern_model: PatternModel::new(),
1158        }
1159    }
1160
1161    /// Detect anomaly in gradient
1162    pub fn detect_anomaly(&mut self, gradient: &Array1<T>) -> Result<AnomalyScore> {
1163        // Update gradient statistics
1164        self.gradient_stats.update(gradient)?;
1165
1166        // Compute anomaly score based on deviation from normal patterns
1167        let norm_deviation = self.compute_norm_deviation(gradient)?;
1168        let pattern_deviation = self.pattern_model.compute_pattern_deviation(gradient)?;
1169
1170        let combined_score = (norm_deviation + pattern_deviation) / 2.0;
1171
1172        Ok(AnomalyScore {
1173            score: combined_score,
1174            method: "Combined norm and pattern analysis".to_string(),
1175            details: format!(
1176                "Norm dev: {:.4}, Pattern dev: {:.4}",
1177                norm_deviation, pattern_deviation
1178            ),
1179        })
1180    }
1181
1182    /// Compute norm deviation score
1183    fn compute_norm_deviation(&self, gradient: &Array1<T>) -> Result<f64> {
1184        let gradient_norm = self.compute_l2_norm(gradient);
1185
1186        if self.gradient_stats.norm_history.is_empty() {
1187            return Ok(0.0);
1188        }
1189
1190        // Compute mean and std of historical norms
1191        let mean_norm = self
1192            .gradient_stats
1193            .norm_history
1194            .iter()
1195            .fold(T::zero(), |acc, &x| acc + x)
1196            / T::from(self.gradient_stats.norm_history.len()).unwrap();
1197
1198        let variance = self
1199            .gradient_stats
1200            .norm_history
1201            .iter()
1202            .map(|&x| {
1203                let diff = x - mean_norm;
1204                diff * diff
1205            })
1206            .fold(T::zero(), |acc, x| acc + x)
1207            / T::from(self.gradient_stats.norm_history.len()).unwrap();
1208
1209        let std_norm = variance.sqrt();
1210
1211        if std_norm > T::zero() {
1212            let z_score = ((gradient_norm - mean_norm) / std_norm).abs();
1213            Ok(z_score.to_f64().unwrap_or(0.0) / 3.0) // Normalize to [0,1] approximately
1214        } else {
1215            Ok(0.0)
1216        }
1217    }
1218
1219    /// Compute L2 norm of gradient
1220    fn compute_l2_norm(&self, gradient: &Array1<T>) -> T {
1221        gradient
1222            .iter()
1223            .map(|&x| x * x)
1224            .fold(T::zero(), |acc, x| acc + x)
1225            .sqrt()
1226    }
1227}
1228
1229impl<T: Float + Debug + Send + Sync + 'static + scirs2_core::ndarray::ScalarOperand> Default
1230    for GradientStatistics<T>
1231{
1232    fn default() -> Self {
1233        Self::new()
1234    }
1235}
1236
1237impl<T: Float + Debug + Send + Sync + 'static + scirs2_core::ndarray::ScalarOperand>
1238    GradientStatistics<T>
1239{
1240    /// Create new gradient statistics
1241    pub fn new() -> Self {
1242        Self {
1243            mean: Array1::zeros(0),
1244            covariance: Array2::zeros((0, 0)),
1245            norm_history: Vec::new(),
1246            direction_patterns: Array2::zeros((0, 0)),
1247        }
1248    }
1249
1250    /// Update statistics with new gradient
1251    pub fn update(&mut self, gradient: &Array1<T>) -> Result<()> {
1252        // Update norm history
1253        let norm = gradient
1254            .iter()
1255            .map(|&x| x * x)
1256            .fold(T::zero(), |acc, x| acc + x)
1257            .sqrt();
1258
1259        self.norm_history.push(norm);
1260
1261        // Keep only recent history (last 1000 gradients)
1262        if self.norm_history.len() > 1000 {
1263            self.norm_history.remove(0);
1264        }
1265
1266        // Initialize mean if this is the first gradient
1267        if self.mean.is_empty() {
1268            self.mean = gradient.clone();
1269        } else if self.mean.len() == gradient.len() {
1270            // Update running mean
1271            let alpha = T::from(0.01).unwrap_or_else(|| T::zero()); // Learning rate for running average
1272            self.mean = &self.mean * (T::one() - alpha) + gradient * alpha;
1273        }
1274
1275        Ok(())
1276    }
1277}
1278
1279impl<T: Float + Debug + Send + Sync + 'static> Default for PatternModel<T> {
1280    fn default() -> Self {
1281        Self::new()
1282    }
1283}
1284
1285impl<T: Float + Debug + Send + Sync + 'static> PatternModel<T> {
1286    /// Create new pattern model
1287    pub fn new() -> Self {
1288        Self {
1289            normal_patterns: Vec::new(),
1290            attack_patterns: Vec::new(),
1291            matching_threshold: 0.8,
1292        }
1293    }
1294
1295    /// Compute pattern deviation score
1296    pub fn compute_pattern_deviation(&self, gradient: &Array1<T>) -> Result<f64> {
1297        // If no patterns learned yet, return neutral score
1298        if self.normal_patterns.is_empty() {
1299            return Ok(0.5);
1300        }
1301
1302        // Find closest normal pattern
1303        let mut min_distance = T::infinity();
1304        for pattern in &self.normal_patterns {
1305            if pattern.len() == gradient.len() {
1306                let distance = self.compute_pattern_distance(gradient, pattern)?;
1307                if distance < min_distance {
1308                    min_distance = distance;
1309                }
1310            }
1311        }
1312
1313        // Normalize distance to [0,1] score
1314        let max_expected_distance = T::from(10.0).unwrap_or_else(|| T::zero()); // Tunable parameter
1315        let deviation_score = (min_distance / max_expected_distance).min(T::one());
1316
1317        Ok(deviation_score.to_f64().unwrap_or(0.5))
1318    }
1319
1320    /// Compute distance between gradient and pattern
1321    fn compute_pattern_distance(&self, gradient: &Array1<T>, pattern: &Array1<T>) -> Result<T> {
1322        if gradient.len() != pattern.len() {
1323            return Err(OptimError::InvalidConfig("Dimension mismatch".to_string()));
1324        }
1325
1326        let mut sum = T::zero();
1327        for (g, p) in gradient.iter().zip(pattern.iter()) {
1328            let diff = *g - *p;
1329            sum = sum + diff * diff;
1330        }
1331
1332        Ok(sum.sqrt())
1333    }
1334}
1335
1336impl<T: Float + Debug + Send + Sync + 'static> StatisticalAnalysis<T> {
1337    /// Create new statistical analysis engine
1338    pub fn new(_windowsize: usize) -> Self {
1339        Self {
1340            window_size: _windowsize,
1341            measures: StatisticalMeasures::new(),
1342        }
1343    }
1344
1345    /// Compute statistical measures for gradients
1346    pub fn compute_statistics(
1347        &mut self,
1348        gradients: &[&Array1<T>],
1349    ) -> Result<StatisticalMeasures<T>> {
1350        if gradients.is_empty() {
1351            return Err(OptimError::InvalidConfig(
1352                "No gradients provided".to_string(),
1353            ));
1354        }
1355
1356        let first_gradient = gradients[0];
1357        let dim = first_gradient.len();
1358
1359        let mut mean = Array1::zeros(dim);
1360        let mut median = Array1::zeros(dim);
1361        let mut std_dev = Array1::zeros(dim);
1362        let mut iqr = Array1::zeros(dim);
1363        let mut skewness = Array1::zeros(dim);
1364        let mut kurtosis = Array1::zeros(dim);
1365
1366        // Compute statistics for each dimension
1367        for i in 0..dim {
1368            let mut values: Vec<T> = gradients.iter().map(|g| g[i]).collect();
1369
1370            // Mean
1371            let sum: T = values.iter().copied().fold(T::zero(), |acc, x| acc + x);
1372            mean[i] = sum / T::from(values.len()).unwrap();
1373
1374            // Sort for median and IQR
1375            values.sort_by(|a, b| a.partial_cmp(b).unwrap_or(Ordering::Equal));
1376
1377            // Median
1378            median[i] = if values.len().is_multiple_of(2) {
1379                let mid = values.len() / 2;
1380                (values[mid - 1] + values[mid]) / T::from(2.0).unwrap_or_else(|| T::zero())
1381            } else {
1382                values[values.len() / 2]
1383            };
1384
1385            // Standard deviation
1386            let variance: T = gradients
1387                .iter()
1388                .map(|g| {
1389                    let diff = g[i] - mean[i];
1390                    diff * diff
1391                })
1392                .fold(T::zero(), |acc, x| acc + x)
1393                / T::from(gradients.len()).unwrap();
1394            std_dev[i] = variance.sqrt();
1395
1396            // Interquartile range
1397            let q1_idx = values.len() / 4;
1398            let q3_idx = 3 * values.len() / 4;
1399            iqr[i] = values[q3_idx] - values[q1_idx];
1400
1401            // Skewness and kurtosis (simplified calculations)
1402            skewness[i] = T::zero(); // Placeholder
1403            kurtosis[i] = T::zero(); // Placeholder
1404        }
1405
1406        self.measures = StatisticalMeasures {
1407            mean,
1408            std_dev,
1409            median,
1410            iqr,
1411            skewness,
1412            kurtosis,
1413        };
1414
1415        Ok(self.measures.clone())
1416    }
1417}
1418
1419impl<T: Float + Debug + Send + Sync + 'static> Default for StatisticalMeasures<T> {
1420    fn default() -> Self {
1421        Self::new()
1422    }
1423}
1424
1425impl<T: Float + Debug + Send + Sync + 'static> StatisticalMeasures<T> {
1426    /// Create new statistical measures
1427    pub fn new() -> Self {
1428        Self {
1429            mean: Array1::zeros(0),
1430            std_dev: Array1::zeros(0),
1431            median: Array1::zeros(0),
1432            iqr: Array1::zeros(0),
1433            skewness: Array1::zeros(0),
1434            kurtosis: Array1::zeros(0),
1435        }
1436    }
1437}
1438
1439impl<T: Float + Debug + Send + Sync + 'static> Default for GradientVerifier<T> {
1440    fn default() -> Self {
1441        Self::new()
1442    }
1443}
1444
1445impl<T: Float + Debug + Send + Sync + 'static> GradientVerifier<T> {
1446    /// Create new gradient verifier
1447    pub fn new() -> Self {
1448        let verification_rules = vec![VerificationRule {
1449            name: "Finite values".to_string(),
1450            rule_fn: Box::new(|gradient: &Array1<T>| gradient.iter().all(|&x| x.is_finite())),
1451            weight: 1.0,
1452        }];
1453
1454        Self {
1455            expected_properties: GradientProperties::new(),
1456            verification_rules,
1457        }
1458    }
1459
1460    /// Verify gradient using all rules
1461    pub fn verify_gradient(&self, gradient: &Array1<T>) -> Result<VerificationScore> {
1462        let mut rule_scores = HashMap::new();
1463        let mut total_weight = 0.0;
1464        let mut weighted_score = 0.0;
1465
1466        for rule in &self.verification_rules {
1467            let passed = (rule.rule_fn)(gradient);
1468            let score = if passed { 1.0 } else { 0.0 };
1469
1470            rule_scores.insert(rule.name.clone(), score);
1471            weighted_score += score * rule.weight;
1472            total_weight += rule.weight;
1473        }
1474
1475        let overall_score = if total_weight > 0.0 {
1476            weighted_score / total_weight
1477        } else {
1478            1.0
1479        };
1480
1481        Ok(VerificationScore {
1482            score: overall_score,
1483            rule_scores,
1484            passed: overall_score >= 0.8,
1485        })
1486    }
1487}
1488
1489impl<T: Float + Debug + Send + Sync + 'static> Default for GradientProperties<T> {
1490    fn default() -> Self {
1491        Self::new()
1492    }
1493}
1494
1495impl<T: Float + Debug + Send + Sync + 'static> GradientProperties<T> {
1496    /// Create new gradient properties
1497    pub fn new() -> Self {
1498        Self {
1499            norm_range: (T::zero(), T::from(100.0).unwrap_or_else(|| T::zero())),
1500            sparsity_threshold: 0.1,
1501            direction_consistency: 0.8,
1502        }
1503    }
1504}
1505
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::Array1;
    use std::collections::HashMap;

    /// Configuration shared by the aggregation tests: tolerate one Byzantine
    /// participant among at least three, with gradient verification off.
    fn aggregation_config(method: ByzantineAggregationMethod) -> ByzantineConfig {
        ByzantineConfig {
            max_byzantine: 1,
            min_participants: 3,
            aggregation_method: method,
            anomaly_threshold: 0.5,
            reputation_decay: 0.9,
            gradient_verification: false,
            outlier_detection: OutlierDetectionMethod::ZScore,
            consensus_threshold: 0.7,
        }
    }

    #[test]
    fn test_byzantine_config() {
        let config = ByzantineConfig {
            max_byzantine: 2,
            min_participants: 5,
            gradient_verification: true,
            ..aggregation_config(ByzantineAggregationMethod::Krum)
        };

        assert_eq!(config.max_byzantine, 2);
        assert_eq!(config.min_participants, 5);
    }

    #[test]
    fn test_reputation_score() {
        let mut reputation = ReputationScore::new();
        assert_eq!(reputation.score, 0.7);
        assert_eq!(reputation.successful_aggregations, 0);
        assert!(matches!(reputation.trust_level, TrustLevel::Medium));

        reputation.successful_aggregations += 1;
        reputation.score = 0.9;
        assert_eq!(reputation.successful_aggregations, 1);
    }

    #[test]
    fn test_trimmed_mean_aggregation() {
        let aggregator = ByzantineTolerantAggregator::new(aggregation_config(
            ByzantineAggregationMethod::TrimmedMean,
        ));

        // Four closely agreeing clients plus one outlier ("client5").
        let gradients = HashMap::from([
            ("client1".to_string(), Array1::from(vec![1.0, 2.0, 3.0])),
            ("client2".to_string(), Array1::from(vec![1.1, 2.1, 3.1])),
            ("client3".to_string(), Array1::from(vec![0.9, 1.9, 2.9])),
            ("client4".to_string(), Array1::from(vec![1.0, 2.0, 3.0])),
            ("client5".to_string(), Array1::from(vec![10.0, 20.0, 30.0])),
        ]);

        let result = aggregator.trimmed_mean_aggregation(&gradients).unwrap();

        // After trimming the outlier the result should be near [1.0, 2.0, 3.0].
        for (index, expected) in [1.0_f64, 2.0, 3.0].into_iter().enumerate() {
            assert!((result[index] - expected).abs() < 0.2);
        }
    }

    #[test]
    fn test_coordinate_median_aggregation() {
        let aggregator = ByzantineTolerantAggregator::new(aggregation_config(
            ByzantineAggregationMethod::CoordinateMedian,
        ));

        let gradients = HashMap::from([
            ("client1".to_string(), Array1::from(vec![1.0, 2.0, 3.0])),
            ("client2".to_string(), Array1::from(vec![2.0, 3.0, 4.0])),
            ("client3".to_string(), Array1::from(vec![3.0, 4.0, 5.0])),
        ]);

        let result = aggregator
            .coordinate_median_aggregation(&gradients)
            .unwrap();

        // The coordinate-wise median is [2.0, 3.0, 4.0].
        assert_eq!(result[0], 2.0);
        assert_eq!(result[1], 3.0);
        assert_eq!(result[2], 4.0);
    }

    #[test]
    fn test_euclidean_distance() {
        let aggregator =
            ByzantineTolerantAggregator::new(aggregation_config(ByzantineAggregationMethod::Krum));

        let a = Array1::from(vec![1.0, 2.0, 3.0]);
        let b = Array1::from(vec![4.0, 5.0, 6.0]);

        let distance = aggregator.compute_euclidean_distance(&a, &b).unwrap();

        // Each of the three components differs by 3.
        let expected = (3.0_f64.powi(2) + 3.0_f64.powi(2) + 3.0_f64.powi(2)).sqrt();

        assert!((distance - expected).abs() < 1e-10);
    }
}