quantrs2_ml/
adversarial.rs

1//! Quantum Adversarial Training
2//!
3//! This module implements adversarial training methods for quantum neural networks,
4//! including adversarial attack generation, robust training procedures, and
5//! defense mechanisms against quantum adversarial examples.
6
7use crate::autodiff::optimizers::Optimizer;
8use crate::error::{MLError, Result};
9use crate::optimization::OptimizationMethod;
10use crate::qnn::{QNNLayerType, QuantumNeuralNetwork};
11use quantrs2_circuit::builder::{Circuit, Simulator};
12use quantrs2_core::gate::{
13    single::{RotationX, RotationY, RotationZ},
14    GateOp,
15};
16use quantrs2_sim::statevector::StateVectorSimulator;
17use scirs2_core::ndarray::{s, Array1, Array2, Array3, Axis};
18use scirs2_core::random::prelude::*;
19use std::collections::HashMap;
20use std::f64::consts::PI;
21
/// Types of adversarial attacks for quantum models.
///
/// Each variant carries the hyper-parameters of a single attack; the trainer dispatches on the
/// variant to turn a classical input vector into a perturbed copy. `PartialEq` is derived so
/// attack configurations can be compared (e.g. in tests or when de-duplicating attack lists).
#[derive(Debug, Clone, PartialEq)]
pub enum QuantumAttackType {
    /// Fast Gradient Sign Method adapted for quantum circuits.
    FGSM {
        /// Size of the single signed-gradient step applied to every input component.
        epsilon: f64,
    },

    /// Projected Gradient Descent attack.
    PGD {
        /// Radius of the ball the accumulated perturbation is projected back onto.
        epsilon: f64,
        /// Step size of each signed-gradient iteration.
        alpha: f64,
        /// Number of gradient iterations.
        num_steps: usize,
    },

    /// Quantum Parameter Shift attack.
    ParameterShift {
        /// Scale of the random shift; multiplied by π/2 to obtain the maximum shift.
        shift_magnitude: f64,
        /// Input indices to perturb; `None` perturbs every component.
        target_parameters: Option<Vec<usize>>,
    },

    /// Quantum State Perturbation attack.
    StatePerturbation {
        /// Amplitude of the random perturbation.
        perturbation_strength: f64,
        /// Perturbation basis: `"pauli_x"`, `"pauli_y"`; any other value behaves like `"pauli_z"`.
        basis: String,
    },

    /// Quantum Circuit Manipulation attack.
    CircuitManipulation {
        /// Per-component probability of injecting an additional random error.
        gate_error_rate: f64,
        /// Coherence time; every component is damped by `exp(-1 / coherence_time)`.
        coherence_time: f64,
    },

    /// Universal Adversarial Perturbation for quantum inputs.
    UniversalPerturbation {
        /// Amplitude of the fixed sinusoidal perturbation pattern.
        perturbation_budget: f64,
        /// NOTE(review): not consulted by the per-input attack implementation.
        success_rate_threshold: f64,
    },
}
59
60/// Defense strategies against quantum adversarial attacks
61#[derive(Debug, Clone)]
62pub enum QuantumDefenseStrategy {
63    /// Adversarial training with generated examples
64    AdversarialTraining {
65        attack_types: Vec<QuantumAttackType>,
66        adversarial_ratio: f64,
67    },
68
69    /// Quantum error correction as defense
70    QuantumErrorCorrection {
71        code_type: String,
72        correction_threshold: f64,
73    },
74
75    /// Input preprocessing and sanitization
76    InputPreprocessing {
77        noise_addition: f64,
78        feature_squeezing: bool,
79    },
80
81    /// Ensemble defense with multiple quantum models
82    EnsembleDefense {
83        num_models: usize,
84        diversity_metric: String,
85    },
86
87    /// Certified defense with provable bounds
88    CertifiedDefense {
89        smoothing_variance: f64,
90        confidence_level: f64,
91    },
92
93    /// Randomized circuit defense
94    RandomizedCircuit {
95        randomization_strength: f64,
96        num_random_layers: usize,
97    },
98}
99
100/// Adversarial example for quantum neural networks
101#[derive(Debug, Clone)]
102pub struct QuantumAdversarialExample {
103    /// Original input
104    pub original_input: Array1<f64>,
105
106    /// Adversarial input
107    pub adversarial_input: Array1<f64>,
108
109    /// Original prediction
110    pub original_prediction: Array1<f64>,
111
112    /// Adversarial prediction
113    pub adversarial_prediction: Array1<f64>,
114
115    /// True label
116    pub true_label: usize,
117
118    /// Perturbation magnitude
119    pub perturbation_norm: f64,
120
121    /// Attack success (caused misclassification)
122    pub attack_success: bool,
123
124    /// Attack metadata
125    pub metadata: HashMap<String, f64>,
126}
127
128/// Quantum adversarial trainer
129pub struct QuantumAdversarialTrainer {
130    /// Base quantum model
131    model: QuantumNeuralNetwork,
132
133    /// Defense strategy
134    defense_strategy: QuantumDefenseStrategy,
135
136    /// Training configuration
137    config: AdversarialTrainingConfig,
138
139    /// Attack history
140    attack_history: Vec<QuantumAdversarialExample>,
141
142    /// Robustness metrics
143    robustness_metrics: RobustnessMetrics,
144
145    /// Ensemble models (for ensemble defense)
146    ensemble_models: Vec<QuantumNeuralNetwork>,
147}
148
149/// Configuration for adversarial training
150#[derive(Debug, Clone)]
151pub struct AdversarialTrainingConfig {
152    /// Number of training epochs
153    pub epochs: usize,
154
155    /// Batch size
156    pub batch_size: usize,
157
158    /// Learning rate
159    pub learning_rate: f64,
160
161    /// Adversarial example generation frequency
162    pub adversarial_frequency: usize,
163
164    /// Maximum perturbation budget
165    pub max_perturbation: f64,
166
167    /// Robustness evaluation interval
168    pub eval_interval: usize,
169
170    /// Early stopping criteria
171    pub early_stopping: Option<EarlyStoppingCriteria>,
172}
173
/// Early stopping criteria for adversarial training.
///
/// Checked after each periodic robustness evaluation; training halts when robust accuracy
/// stops improving or when either accuracy falls below its minimum.
#[derive(Debug, Clone)]
pub struct EarlyStoppingCriteria {
    /// Training stops if clean accuracy drops below this value.
    pub min_clean_accuracy: f64,

    /// Training stops if robust accuracy drops below this value.
    pub min_robust_accuracy: f64,

    /// Number of consecutive robustness evaluations (one every `eval_interval` epochs)
    /// without an improvement in robust accuracy before training stops.
    pub patience: usize,
}
186
187/// Robustness metrics
188#[derive(Debug, Clone)]
189pub struct RobustnessMetrics {
190    /// Clean accuracy (on unperturbed data)
191    pub clean_accuracy: f64,
192
193    /// Robust accuracy (on adversarial examples)
194    pub robust_accuracy: f64,
195
196    /// Average perturbation norm for successful attacks
197    pub avg_perturbation_norm: f64,
198
199    /// Attack success rate
200    pub attack_success_rate: f64,
201
202    /// Certified accuracy (for certified defenses)
203    pub certified_accuracy: Option<f64>,
204
205    /// Per-attack-type metrics
206    pub per_attack_metrics: HashMap<String, AttackMetrics>,
207}
208
/// Metrics for specific attack types.
///
/// Summarises how one kind of attack fared against the model.
#[derive(Debug, Clone)]
pub struct AttackMetrics {
    /// Fraction of attempts that succeeded.
    pub success_rate: f64,

    /// Mean perturbation magnitude.
    pub avg_perturbation: f64,

    /// Mean drop in model confidence caused by the attack.
    pub avg_confidence_drop: f64,
}
221
222impl QuantumAdversarialTrainer {
223    /// Create a new quantum adversarial trainer
224    pub fn new(
225        model: QuantumNeuralNetwork,
226        defense_strategy: QuantumDefenseStrategy,
227        config: AdversarialTrainingConfig,
228    ) -> Self {
229        Self {
230            model,
231            defense_strategy,
232            config,
233            attack_history: Vec::new(),
234            robustness_metrics: RobustnessMetrics {
235                clean_accuracy: 0.0,
236                robust_accuracy: 0.0,
237                avg_perturbation_norm: 0.0,
238                attack_success_rate: 0.0,
239                certified_accuracy: None,
240                per_attack_metrics: HashMap::new(),
241            },
242            ensemble_models: Vec::new(),
243        }
244    }
245
246    /// Train the model with adversarial training
247    pub fn train(
248        &mut self,
249        train_data: &Array2<f64>,
250        train_labels: &Array1<usize>,
251        val_data: &Array2<f64>,
252        val_labels: &Array1<usize>,
253        optimizer: &mut dyn Optimizer,
254    ) -> Result<Vec<f64>> {
255        println!("Starting quantum adversarial training...");
256
257        let mut losses = Vec::new();
258        let mut patience_counter = 0;
259        let mut best_robust_accuracy = 0.0;
260
261        // Initialize ensemble if needed
262        self.initialize_ensemble()?;
263
264        for epoch in 0..self.config.epochs {
265            let mut epoch_loss = 0.0;
266            let num_batches =
267                (train_data.nrows() + self.config.batch_size - 1) / self.config.batch_size;
268
269            for batch_idx in 0..num_batches {
270                let batch_start = batch_idx * self.config.batch_size;
271                let batch_end = (batch_start + self.config.batch_size).min(train_data.nrows());
272
273                let batch_data = train_data.slice(s![batch_start..batch_end, ..]).to_owned();
274                let batch_labels = train_labels.slice(s![batch_start..batch_end]).to_owned();
275
276                // Generate adversarial examples if needed
277                let (final_data, final_labels) = if epoch % self.config.adversarial_frequency == 0 {
278                    self.generate_adversarial_batch(&batch_data, &batch_labels)?
279                } else {
280                    (batch_data, batch_labels)
281                };
282
283                // Compute loss and update model
284                let batch_loss = self.train_batch(&final_data, &final_labels, optimizer)?;
285                epoch_loss += batch_loss;
286            }
287
288            epoch_loss /= num_batches as f64;
289            losses.push(epoch_loss);
290
291            // Evaluate robustness periodically
292            if epoch % self.config.eval_interval == 0 {
293                self.evaluate_robustness(val_data, val_labels)?;
294
295                println!(
296                    "Epoch {}: Loss = {:.4}, Clean Acc = {:.3}, Robust Acc = {:.3}",
297                    epoch,
298                    epoch_loss,
299                    self.robustness_metrics.clean_accuracy,
300                    self.robustness_metrics.robust_accuracy
301                );
302
303                // Early stopping check
304                if let Some(ref criteria) = self.config.early_stopping {
305                    if self.robustness_metrics.robust_accuracy > best_robust_accuracy {
306                        best_robust_accuracy = self.robustness_metrics.robust_accuracy;
307                        patience_counter = 0;
308                    } else {
309                        patience_counter += 1;
310                    }
311
312                    if patience_counter >= criteria.patience {
313                        println!("Early stopping triggered at epoch {}", epoch);
314                        break;
315                    }
316
317                    if self.robustness_metrics.clean_accuracy < criteria.min_clean_accuracy
318                        || self.robustness_metrics.robust_accuracy < criteria.min_robust_accuracy
319                    {
320                        println!("Minimum performance criteria not met, stopping training");
321                        break;
322                    }
323                }
324            }
325        }
326
327        // Final robustness evaluation
328        self.evaluate_robustness(val_data, val_labels)?;
329
330        Ok(losses)
331    }
332
333    /// Generate adversarial examples using specified attack
334    pub fn generate_adversarial_examples(
335        &self,
336        data: &Array2<f64>,
337        labels: &Array1<usize>,
338        attack_type: QuantumAttackType,
339    ) -> Result<Vec<QuantumAdversarialExample>> {
340        let mut adversarial_examples = Vec::new();
341
342        for (i, (input, &label)) in data.outer_iter().zip(labels.iter()).enumerate() {
343            let adversarial_example = self.generate_single_adversarial_example(
344                &input.to_owned(),
345                label,
346                attack_type.clone(),
347            )?;
348
349            adversarial_examples.push(adversarial_example);
350        }
351
352        Ok(adversarial_examples)
353    }
354
355    /// Generate a single adversarial example
356    fn generate_single_adversarial_example(
357        &self,
358        input: &Array1<f64>,
359        true_label: usize,
360        attack_type: QuantumAttackType,
361    ) -> Result<QuantumAdversarialExample> {
362        // Get original prediction
363        let original_prediction = self.model.forward(input)?;
364
365        let adversarial_input = match attack_type {
366            QuantumAttackType::FGSM { epsilon } => self.fgsm_attack(input, true_label, epsilon)?,
367            QuantumAttackType::PGD {
368                epsilon,
369                alpha,
370                num_steps,
371            } => self.pgd_attack(input, true_label, epsilon, alpha, num_steps)?,
372            QuantumAttackType::ParameterShift {
373                shift_magnitude,
374                target_parameters,
375            } => self.parameter_shift_attack(input, shift_magnitude, target_parameters)?,
376            QuantumAttackType::StatePerturbation {
377                perturbation_strength,
378                ref basis,
379            } => self.state_perturbation_attack(input, perturbation_strength, basis)?,
380            QuantumAttackType::CircuitManipulation {
381                gate_error_rate,
382                coherence_time,
383            } => self.circuit_manipulation_attack(input, gate_error_rate, coherence_time)?,
384            QuantumAttackType::UniversalPerturbation {
385                perturbation_budget,
386                success_rate_threshold,
387            } => self.universal_perturbation_attack(input, perturbation_budget)?,
388        };
389
390        // Get adversarial prediction
391        let adversarial_prediction = self.model.forward(&adversarial_input)?;
392
393        // Compute perturbation norm
394        let perturbation = &adversarial_input - input;
395        let perturbation_norm = perturbation.mapv(|x| (x as f64).powi(2)).sum().sqrt();
396
397        // Check attack success
398        let original_class = original_prediction
399            .iter()
400            .enumerate()
401            .max_by(|a, b| a.1.partial_cmp(b.1).unwrap())
402            .map(|(i, _)| i)
403            .unwrap_or(0);
404
405        let adversarial_class = adversarial_prediction
406            .iter()
407            .enumerate()
408            .max_by(|a, b| a.1.partial_cmp(b.1).unwrap())
409            .map(|(i, _)| i)
410            .unwrap_or(0);
411
412        let attack_success = original_class != adversarial_class;
413
414        Ok(QuantumAdversarialExample {
415            original_input: input.clone(),
416            adversarial_input,
417            original_prediction,
418            adversarial_prediction,
419            true_label,
420            perturbation_norm,
421            attack_success,
422            metadata: HashMap::new(),
423        })
424    }
425
426    /// Fast Gradient Sign Method (FGSM) for quantum circuits
427    fn fgsm_attack(
428        &self,
429        input: &Array1<f64>,
430        true_label: usize,
431        epsilon: f64,
432    ) -> Result<Array1<f64>> {
433        // Compute gradient of loss w.r.t. input
434        let gradient = self.compute_input_gradient(input, true_label)?;
435
436        // Apply FGSM perturbation
437        let perturbation = gradient.mapv(|g| epsilon * g.signum());
438        let adversarial_input = input + &perturbation;
439
440        // Clip to valid range [0, 1] for quantum inputs
441        Ok(adversarial_input.mapv(|x| x.max(0.0).min(1.0)))
442    }
443
444    /// Projected Gradient Descent (PGD) attack
445    fn pgd_attack(
446        &self,
447        input: &Array1<f64>,
448        true_label: usize,
449        epsilon: f64,
450        alpha: f64,
451        num_steps: usize,
452    ) -> Result<Array1<f64>> {
453        let mut adversarial_input = input.clone();
454
455        for _ in 0..num_steps {
456            // Compute gradient
457            let gradient = self.compute_input_gradient(&adversarial_input, true_label)?;
458
459            // Take gradient step
460            let perturbation = gradient.mapv(|g| alpha * g.signum());
461            adversarial_input = &adversarial_input + &perturbation;
462
463            // Project back to epsilon ball
464            let total_perturbation = &adversarial_input - input;
465            let perturbation_norm = total_perturbation.mapv(|x| (x as f64).powi(2)).sum().sqrt();
466
467            if perturbation_norm > epsilon {
468                let scaling = epsilon / perturbation_norm;
469                adversarial_input = input + &(total_perturbation * scaling);
470            }
471
472            // Clip to valid range
473            adversarial_input = adversarial_input.mapv(|x| x.max(0.0).min(1.0));
474        }
475
476        Ok(adversarial_input)
477    }
478
479    /// Parameter shift attack targeting quantum circuit parameters
480    fn parameter_shift_attack(
481        &self,
482        input: &Array1<f64>,
483        shift_magnitude: f64,
484        target_parameters: Option<Vec<usize>>,
485    ) -> Result<Array1<f64>> {
486        // This attack modifies the input to exploit parameter shift rules
487        let mut adversarial_input = input.clone();
488
489        // Apply parameter shift-inspired perturbations
490        for i in 0..adversarial_input.len() {
491            if let Some(ref targets) = target_parameters {
492                if !targets.contains(&i) {
493                    continue;
494                }
495            }
496
497            // Use parameter shift rule: f(x + π/2) - f(x - π/2)
498            let shift = shift_magnitude * (PI / 2.0);
499            adversarial_input[i] += shift * (2.0 * thread_rng().gen::<f64>() - 1.0);
500        }
501
502        Ok(adversarial_input.mapv(|x| x.max(0.0).min(1.0)))
503    }
504
505    /// Quantum state perturbation attack
506    fn state_perturbation_attack(
507        &self,
508        input: &Array1<f64>,
509        perturbation_strength: f64,
510        basis: &str,
511    ) -> Result<Array1<f64>> {
512        let mut adversarial_input = input.clone();
513
514        match basis {
515            "pauli_x" => {
516                // Apply X-basis perturbations
517                for i in 0..adversarial_input.len() {
518                    let angle = adversarial_input[i] * PI;
519                    let perturbed_angle =
520                        angle + perturbation_strength * (2.0 * thread_rng().gen::<f64>() - 1.0);
521                    adversarial_input[i] = perturbed_angle / PI;
522                }
523            }
524            "pauli_y" => {
525                // Apply Y-basis perturbations
526                for i in 0..adversarial_input.len() {
527                    adversarial_input[i] +=
528                        perturbation_strength * (2.0 * thread_rng().gen::<f64>() - 1.0);
529                }
530            }
531            "pauli_z" | _ => {
532                // Apply Z-basis perturbations (default)
533                for i in 0..adversarial_input.len() {
534                    let phase_shift =
535                        perturbation_strength * (2.0 * thread_rng().gen::<f64>() - 1.0);
536                    adversarial_input[i] =
537                        (adversarial_input[i] + phase_shift / (2.0 * PI)).fract();
538                }
539            }
540        }
541
542        Ok(adversarial_input.mapv(|x| x.max(0.0).min(1.0)))
543    }
544
545    /// Circuit manipulation attack (simulating hardware errors)
546    fn circuit_manipulation_attack(
547        &self,
548        input: &Array1<f64>,
549        gate_error_rate: f64,
550        coherence_time: f64,
551    ) -> Result<Array1<f64>> {
552        let mut adversarial_input = input.clone();
553
554        // Simulate decoherence effects
555        for i in 0..adversarial_input.len() {
556            // Apply T1 decay
557            let t1_factor = (-1.0 / coherence_time).exp();
558            adversarial_input[i] *= t1_factor;
559
560            // Add gate errors
561            if thread_rng().gen::<f64>() < gate_error_rate {
562                adversarial_input[i] += 0.1 * (2.0 * thread_rng().gen::<f64>() - 1.0);
563            }
564        }
565
566        Ok(adversarial_input.mapv(|x| x.max(0.0).min(1.0)))
567    }
568
569    /// Universal adversarial perturbation attack
570    fn universal_perturbation_attack(
571        &self,
572        input: &Array1<f64>,
573        perturbation_budget: f64,
574    ) -> Result<Array1<f64>> {
575        // Apply a learned universal perturbation (simplified)
576        let mut adversarial_input = input.clone();
577
578        // Generate universal perturbation pattern
579        for i in 0..adversarial_input.len() {
580            let universal_component =
581                perturbation_budget * (2.0 * PI * i as f64 / adversarial_input.len() as f64).sin();
582            adversarial_input[i] += universal_component;
583        }
584
585        Ok(adversarial_input.mapv(|x| x.max(0.0).min(1.0)))
586    }
587
588    /// Compute gradient of loss with respect to input
589    fn compute_input_gradient(
590        &self,
591        input: &Array1<f64>,
592        true_label: usize,
593    ) -> Result<Array1<f64>> {
594        // Placeholder for gradient computation
595        // In practice, this would use automatic differentiation
596        let mut gradient = Array1::zeros(input.len());
597
598        // Finite difference approximation
599        let h = 1e-5;
600        let original_output = self.model.forward(input)?;
601        let original_loss = self.compute_loss(&original_output, true_label);
602
603        for i in 0..input.len() {
604            let mut perturbed_input = input.clone();
605            perturbed_input[i] += h;
606
607            let perturbed_output = self.model.forward(&perturbed_input)?;
608            let perturbed_loss = self.compute_loss(&perturbed_output, true_label);
609
610            gradient[i] = (perturbed_loss - original_loss) / h;
611        }
612
613        Ok(gradient)
614    }
615
616    /// Compute loss for a given output and true label
617    fn compute_loss(&self, output: &Array1<f64>, true_label: usize) -> f64 {
618        // Cross-entropy loss
619        let predicted_prob = output[true_label].max(1e-10);
620        -predicted_prob.ln()
621    }
622
623    /// Generate adversarial training batch
624    fn generate_adversarial_batch(
625        &self,
626        data: &Array2<f64>,
627        labels: &Array1<usize>,
628    ) -> Result<(Array2<f64>, Array1<usize>)> {
629        match &self.defense_strategy {
630            QuantumDefenseStrategy::AdversarialTraining {
631                attack_types,
632                adversarial_ratio,
633            } => {
634                let num_adversarial = (data.nrows() as f64 * adversarial_ratio) as usize;
635                let mut combined_data = data.clone();
636                let mut combined_labels = labels.clone();
637
638                // Generate adversarial examples
639                for i in 0..num_adversarial {
640                    let idx = i % data.nrows();
641                    let input = data.row(idx).to_owned();
642                    let label = labels[idx];
643
644                    // Randomly select attack type
645                    let attack_type = attack_types[fastrand::usize(0..attack_types.len())].clone();
646                    let adversarial_example =
647                        self.generate_single_adversarial_example(&input, label, attack_type)?;
648
649                    // Add to batch (replace original example)
650                    combined_data
651                        .row_mut(idx)
652                        .assign(&adversarial_example.adversarial_input);
653                }
654
655                Ok((combined_data, combined_labels))
656            }
657            _ => Ok((data.clone(), labels.clone())),
658        }
659    }
660
661    /// Train on a single batch
662    fn train_batch(
663        &mut self,
664        data: &Array2<f64>,
665        labels: &Array1<usize>,
666        optimizer: &mut dyn Optimizer,
667    ) -> Result<f64> {
668        // Simplified training step
669        let mut total_loss = 0.0;
670
671        for (input, &label) in data.outer_iter().zip(labels.iter()) {
672            let output = self.model.forward(&input.to_owned())?;
673            let loss = self.compute_loss(&output, label);
674            total_loss += loss;
675
676            // Compute gradients and update (simplified)
677            // In practice, this would use proper backpropagation
678        }
679
680        Ok(total_loss / data.nrows() as f64)
681    }
682
683    /// Initialize ensemble for ensemble defense
684    fn initialize_ensemble(&mut self) -> Result<()> {
685        if let QuantumDefenseStrategy::EnsembleDefense { num_models, .. } = &self.defense_strategy {
686            for _ in 0..*num_models {
687                // Create model with slight variations
688                let model = self.model.clone();
689                self.ensemble_models.push(model);
690            }
691        }
692        Ok(())
693    }
694
695    /// Evaluate robustness on validation set
696    fn evaluate_robustness(
697        &mut self,
698        val_data: &Array2<f64>,
699        val_labels: &Array1<usize>,
700    ) -> Result<()> {
701        let mut clean_correct = 0;
702        let mut robust_correct = 0;
703        let mut total_perturbation = 0.0;
704        let mut successful_attacks = 0;
705
706        // Test with different attack types
707        let test_attacks = vec![
708            QuantumAttackType::FGSM { epsilon: 0.1 },
709            QuantumAttackType::PGD {
710                epsilon: 0.1,
711                alpha: 0.01,
712                num_steps: 10,
713            },
714        ];
715
716        for (input, &label) in val_data.outer_iter().zip(val_labels.iter()) {
717            let input_owned = input.to_owned();
718
719            // Clean accuracy
720            let clean_output = self.model.forward(&input_owned)?;
721            let clean_pred = clean_output
722                .iter()
723                .enumerate()
724                .max_by(|a, b| a.1.partial_cmp(b.1).unwrap())
725                .map(|(i, _)| i)
726                .unwrap_or(0);
727
728            if clean_pred == label {
729                clean_correct += 1;
730            }
731
732            // Test robustness against attacks
733            let mut robust_for_this_input = true;
734            for attack_type in &test_attacks {
735                let adversarial_example = self.generate_single_adversarial_example(
736                    &input_owned,
737                    label,
738                    attack_type.clone(),
739                )?;
740
741                total_perturbation += adversarial_example.perturbation_norm;
742
743                if adversarial_example.attack_success {
744                    successful_attacks += 1;
745                    robust_for_this_input = false;
746                }
747            }
748
749            if robust_for_this_input {
750                robust_correct += 1;
751            }
752        }
753
754        let num_samples = val_data.nrows();
755        let num_attack_tests = num_samples * test_attacks.len();
756
757        self.robustness_metrics.clean_accuracy = clean_correct as f64 / num_samples as f64;
758        self.robustness_metrics.robust_accuracy = robust_correct as f64 / num_samples as f64;
759        self.robustness_metrics.avg_perturbation_norm =
760            total_perturbation / num_attack_tests as f64;
761        self.robustness_metrics.attack_success_rate =
762            successful_attacks as f64 / num_attack_tests as f64;
763
764        Ok(())
765    }
766
767    /// Apply defense strategy to input
768    pub fn apply_defense(&self, input: &Array1<f64>) -> Result<Array1<f64>> {
769        match &self.defense_strategy {
770            QuantumDefenseStrategy::InputPreprocessing {
771                noise_addition,
772                feature_squeezing,
773            } => {
774                let mut defended_input = input.clone();
775
776                // Add noise
777                for i in 0..defended_input.len() {
778                    defended_input[i] += noise_addition * (2.0 * thread_rng().gen::<f64>() - 1.0);
779                }
780
781                // Feature squeezing
782                if *feature_squeezing {
783                    defended_input = defended_input.mapv(|x| (x * 8.0).round() / 8.0);
784                }
785
786                Ok(defended_input.mapv(|x| x.max(0.0).min(1.0)))
787            }
788            QuantumDefenseStrategy::RandomizedCircuit {
789                randomization_strength,
790                ..
791            } => {
792                let mut defended_input = input.clone();
793
794                // Add random perturbations to simulate circuit randomization
795                for i in 0..defended_input.len() {
796                    let random_shift =
797                        randomization_strength * (2.0 * thread_rng().gen::<f64>() - 1.0);
798                    defended_input[i] += random_shift;
799                }
800
801                Ok(defended_input.mapv(|x| x.max(0.0).min(1.0)))
802            }
803            _ => Ok(input.clone()),
804        }
805    }
806
807    /// Get robustness metrics
808    pub fn get_robustness_metrics(&self) -> &RobustnessMetrics {
809        &self.robustness_metrics
810    }
811
812    /// Get attack history
813    pub fn get_attack_history(&self) -> &[QuantumAdversarialExample] {
814        &self.attack_history
815    }
816
817    /// Perform certified defense analysis
818    pub fn certified_defense_analysis(
819        &self,
820        data: &Array2<f64>,
821        smoothing_variance: f64,
822        num_samples: usize,
823    ) -> Result<f64> {
824        let mut certified_correct = 0;
825
826        for input in data.outer_iter() {
827            let input_owned = input.to_owned();
828
829            // Sample multiple noisy versions
830            let mut predictions = Vec::new();
831            for _ in 0..num_samples {
832                let mut noisy_input = input_owned.clone();
833                for i in 0..noisy_input.len() {
834                    let noise = fastrand::f64() * smoothing_variance;
835                    noisy_input[i] += noise;
836                }
837
838                let output = self.model.forward(&noisy_input)?;
839                let pred = output
840                    .iter()
841                    .enumerate()
842                    .max_by(|a, b| a.1.partial_cmp(b.1).unwrap())
843                    .map(|(i, _)| i)
844                    .unwrap_or(0);
845
846                predictions.push(pred);
847            }
848
849            // Check if prediction is certified (majority vote is stable)
850            let mut counts = vec![0; 10]; // Assume max 10 classes
851            for &pred in &predictions {
852                if pred < counts.len() {
853                    counts[pred] += 1;
854                }
855            }
856
857            let max_count = counts.iter().max().unwrap_or(&0);
858            let certification_threshold = (num_samples as f64 * 0.6) as usize;
859
860            if *max_count >= certification_threshold {
861                certified_correct += 1;
862            }
863        }
864
865        Ok(certified_correct as f64 / data.nrows() as f64)
866    }
867}
868
869/// Helper function to create default adversarial training config
870pub fn create_default_adversarial_config() -> AdversarialTrainingConfig {
871    AdversarialTrainingConfig {
872        epochs: 100,
873        batch_size: 32,
874        learning_rate: 0.001,
875        adversarial_frequency: 2,
876        max_perturbation: 0.1,
877        eval_interval: 10,
878        early_stopping: Some(EarlyStoppingCriteria {
879            min_clean_accuracy: 0.7,
880            min_robust_accuracy: 0.5,
881            patience: 20,
882        }),
883    }
884}
885
886/// Helper function to create comprehensive defense strategy
887pub fn create_comprehensive_defense() -> QuantumDefenseStrategy {
888    QuantumDefenseStrategy::AdversarialTraining {
889        attack_types: vec![
890            QuantumAttackType::FGSM { epsilon: 0.1 },
891            QuantumAttackType::PGD {
892                epsilon: 0.1,
893                alpha: 0.01,
894                num_steps: 7,
895            },
896            QuantumAttackType::ParameterShift {
897                shift_magnitude: 0.05,
898                target_parameters: None,
899            },
900        ],
901        adversarial_ratio: 0.5,
902    }
903}
904
/// Unit tests for adversarial examples, the FGSM attack, input defenses and
/// robustness metrics.
#[cfg(test)]
mod tests {
    use super::*;
    use crate::qnn::QNNLayerType;

    /// Euclidean (L2) norm of a 1-D array.
    ///
    /// Taking `&Array1<f64>` fixes the element type, so callers built from
    /// untyped float literals need no `as f64` casts.
    fn l2_norm(v: &Array1<f64>) -> f64 {
        v.mapv(|x| x.powi(2)).sum().sqrt()
    }

    /// A hand-built adversarial example keeps its fields and a positive
    /// perturbation norm.
    #[test]
    fn test_adversarial_example_creation() {
        let original_input = Array1::from_vec(vec![0.5, 0.3, 0.8, 0.2]);
        let adversarial_input = Array1::from_vec(vec![0.6, 0.4, 0.7, 0.3]);
        let original_prediction = Array1::from_vec(vec![0.8, 0.2]);
        let adversarial_prediction = Array1::from_vec(vec![0.3, 0.7]);

        let perturbation = &adversarial_input - &original_input;
        let perturbation_norm = l2_norm(&perturbation);

        let example = QuantumAdversarialExample {
            original_input,
            adversarial_input,
            original_prediction,
            adversarial_prediction,
            true_label: 0,
            perturbation_norm,
            attack_success: true,
            metadata: HashMap::new(),
        };

        assert!(example.attack_success);
        assert!(example.perturbation_norm > 0.0);
    }

    /// FGSM keeps the input length, perturbs the input, and keeps every
    /// value inside [0, 1].
    #[test]
    fn test_fgsm_attack() {
        let layers = vec![
            QNNLayerType::EncodingLayer { num_features: 4 },
            QNNLayerType::VariationalLayer { num_params: 8 },
            QNNLayerType::MeasurementLayer {
                measurement_basis: "computational".to_string(),
            },
        ];

        let model = QuantumNeuralNetwork::new(layers, 4, 4, 2)
            .expect("QNN construction with a valid layer stack should succeed");
        let defense = create_comprehensive_defense();
        let config = create_default_adversarial_config();

        let trainer = QuantumAdversarialTrainer::new(model, defense, config);

        let input = Array1::from_vec(vec![0.5, 0.3, 0.8, 0.2]);
        let adversarial_input = trainer
            .fgsm_attack(&input, 0, 0.1)
            .expect("FGSM attack should succeed on a valid input");

        assert_eq!(adversarial_input.len(), input.len());

        // The attack must actually move the input.
        let perturbation = &adversarial_input - &input;
        assert!(l2_norm(&perturbation) > 0.0);

        // Every adversarial feature must stay in the valid [0, 1] range.
        for &val in adversarial_input.iter() {
            assert!(
                (0.0..=1.0).contains(&val),
                "adversarial value {val} is outside [0, 1]"
            );
        }
    }

    /// Input preprocessing (noise + feature squeezing) changes the input
    /// without changing its length.
    #[test]
    fn test_defense_application() {
        let layers = vec![
            QNNLayerType::EncodingLayer { num_features: 4 },
            QNNLayerType::VariationalLayer { num_params: 8 },
        ];

        let model = QuantumNeuralNetwork::new(layers, 4, 4, 2)
            .expect("QNN construction with a valid layer stack should succeed");

        let defense = QuantumDefenseStrategy::InputPreprocessing {
            noise_addition: 0.05,
            feature_squeezing: true,
        };

        let config = create_default_adversarial_config();
        let trainer = QuantumAdversarialTrainer::new(model, defense, config);

        let input = Array1::from_vec(vec![0.51, 0.32, 0.83, 0.24]);
        let defended_input = trainer
            .apply_defense(&input)
            .expect("input-preprocessing defense should succeed");

        assert_eq!(defended_input.len(), input.len());

        // The defense must have altered the input.
        let difference = (&defended_input - &input).mapv(f64::abs).sum();
        assert!(difference > 0.0);
    }

    /// `RobustnessMetrics` stores the values it was built with.
    #[test]
    fn test_robustness_metrics() {
        let metrics = RobustnessMetrics {
            clean_accuracy: 0.85,
            robust_accuracy: 0.65,
            avg_perturbation_norm: 0.12,
            attack_success_rate: 0.35,
            certified_accuracy: Some(0.55),
            per_attack_metrics: HashMap::new(),
        };

        // Exact equality is fine here: the values are stored, not computed.
        assert_eq!(metrics.clean_accuracy, 0.85);
        assert_eq!(metrics.robust_accuracy, 0.65);
        assert!(metrics.robust_accuracy < metrics.clean_accuracy);
        assert!(metrics.attack_success_rate < 0.5);
    }
}