quantrs2_ml/
adversarial.rs

1//! Quantum Adversarial Training
2//!
3//! This module implements adversarial training methods for quantum neural networks,
4//! including adversarial attack generation, robust training procedures, and
5//! defense mechanisms against quantum adversarial examples.
6
7use crate::autodiff::optimizers::Optimizer;
8use crate::error::{MLError, Result};
9use crate::optimization::OptimizationMethod;
10use crate::qnn::{QNNLayerType, QuantumNeuralNetwork};
11use ndarray::{s, Array1, Array2, Array3, Axis};
12use quantrs2_circuit::builder::{Circuit, Simulator};
13use quantrs2_core::gate::{
14    single::{RotationX, RotationY, RotationZ},
15    GateOp,
16};
17use quantrs2_sim::statevector::StateVectorSimulator;
18use std::collections::HashMap;
19use std::f64::consts::PI;
20
/// Types of adversarial attacks for quantum models.
///
/// Each variant carries the hyper-parameters of one attack. The attacks operate
/// on the classical input vector fed to the quantum model; every attack in this
/// module clips its output back into `[0, 1]`.
#[derive(Debug, Clone, PartialEq)]
pub enum QuantumAttackType {
    /// Fast Gradient Sign Method adapted for quantum circuits.
    ///
    /// A single step of size `epsilon` per feature, in the direction of the sign
    /// of the loss gradient with respect to the input.
    FGSM { epsilon: f64 },

    /// Projected Gradient Descent attack.
    ///
    /// `num_steps` sign-gradient steps of size `alpha`. After each step the total
    /// perturbation is projected back onto the ball of radius `epsilon`.
    PGD {
        epsilon: f64,
        alpha: f64,
        num_steps: usize,
    },

    /// Quantum Parameter Shift attack.
    ///
    /// Adds uniform random offsets of at most `shift_magnitude * π/2` to the input
    /// features listed in `target_parameters` (these are input-feature indices).
    /// `None` targets every feature.
    ParameterShift {
        shift_magnitude: f64,
        target_parameters: Option<Vec<usize>>,
    },

    /// Quantum State Perturbation attack.
    ///
    /// `basis` selects the perturbation style: `"pauli_x"`, `"pauli_y"`, or any
    /// other value for the Z-basis (phase) perturbation.
    StatePerturbation {
        perturbation_strength: f64,
        basis: String,
    },

    /// Quantum Circuit Manipulation attack simulating hardware noise.
    ///
    /// Features are damped by a decoherence factor derived from `coherence_time`.
    /// Each feature additionally receives a random error with probability
    /// `gate_error_rate`.
    CircuitManipulation {
        gate_error_rate: f64,
        coherence_time: f64,
    },

    /// Universal Adversarial Perturbation for quantum inputs.
    ///
    /// Adds a fixed sinusoidal pattern of amplitude `perturbation_budget`.
    /// `success_rate_threshold` is currently not used by the attack implementation.
    UniversalPerturbation {
        perturbation_budget: f64,
        success_rate_threshold: f64,
    },
}
58
59/// Defense strategies against quantum adversarial attacks
60#[derive(Debug, Clone)]
61pub enum QuantumDefenseStrategy {
62    /// Adversarial training with generated examples
63    AdversarialTraining {
64        attack_types: Vec<QuantumAttackType>,
65        adversarial_ratio: f64,
66    },
67
68    /// Quantum error correction as defense
69    QuantumErrorCorrection {
70        code_type: String,
71        correction_threshold: f64,
72    },
73
74    /// Input preprocessing and sanitization
75    InputPreprocessing {
76        noise_addition: f64,
77        feature_squeezing: bool,
78    },
79
80    /// Ensemble defense with multiple quantum models
81    EnsembleDefense {
82        num_models: usize,
83        diversity_metric: String,
84    },
85
86    /// Certified defense with provable bounds
87    CertifiedDefense {
88        smoothing_variance: f64,
89        confidence_level: f64,
90    },
91
92    /// Randomized circuit defense
93    RandomizedCircuit {
94        randomization_strength: f64,
95        num_random_layers: usize,
96    },
97}
98
99/// Adversarial example for quantum neural networks
100#[derive(Debug, Clone)]
101pub struct QuantumAdversarialExample {
102    /// Original input
103    pub original_input: Array1<f64>,
104
105    /// Adversarial input
106    pub adversarial_input: Array1<f64>,
107
108    /// Original prediction
109    pub original_prediction: Array1<f64>,
110
111    /// Adversarial prediction
112    pub adversarial_prediction: Array1<f64>,
113
114    /// True label
115    pub true_label: usize,
116
117    /// Perturbation magnitude
118    pub perturbation_norm: f64,
119
120    /// Attack success (caused misclassification)
121    pub attack_success: bool,
122
123    /// Attack metadata
124    pub metadata: HashMap<String, f64>,
125}
126
127/// Quantum adversarial trainer
128pub struct QuantumAdversarialTrainer {
129    /// Base quantum model
130    model: QuantumNeuralNetwork,
131
132    /// Defense strategy
133    defense_strategy: QuantumDefenseStrategy,
134
135    /// Training configuration
136    config: AdversarialTrainingConfig,
137
138    /// Attack history
139    attack_history: Vec<QuantumAdversarialExample>,
140
141    /// Robustness metrics
142    robustness_metrics: RobustnessMetrics,
143
144    /// Ensemble models (for ensemble defense)
145    ensemble_models: Vec<QuantumNeuralNetwork>,
146}
147
148/// Configuration for adversarial training
149#[derive(Debug, Clone)]
150pub struct AdversarialTrainingConfig {
151    /// Number of training epochs
152    pub epochs: usize,
153
154    /// Batch size
155    pub batch_size: usize,
156
157    /// Learning rate
158    pub learning_rate: f64,
159
160    /// Adversarial example generation frequency
161    pub adversarial_frequency: usize,
162
163    /// Maximum perturbation budget
164    pub max_perturbation: f64,
165
166    /// Robustness evaluation interval
167    pub eval_interval: usize,
168
169    /// Early stopping criteria
170    pub early_stopping: Option<EarlyStoppingCriteria>,
171}
172
/// Early stopping criteria for adversarial training.
///
/// The criteria are checked only when a robustness evaluation runs, i.e. every
/// `eval_interval` epochs.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct EarlyStoppingCriteria {
    /// Clean-accuracy floor.
    ///
    /// Training stops when an evaluation reports clean accuracy below this value.
    pub min_clean_accuracy: f64,

    /// Robust-accuracy floor.
    ///
    /// Training stops when an evaluation reports robust accuracy below this value.
    pub min_robust_accuracy: f64,

    /// Number of consecutive robustness evaluations (not epochs) without an
    /// improvement in robust accuracy before training stops.
    pub patience: usize,
}
185
186/// Robustness metrics
187#[derive(Debug, Clone)]
188pub struct RobustnessMetrics {
189    /// Clean accuracy (on unperturbed data)
190    pub clean_accuracy: f64,
191
192    /// Robust accuracy (on adversarial examples)
193    pub robust_accuracy: f64,
194
195    /// Average perturbation norm for successful attacks
196    pub avg_perturbation_norm: f64,
197
198    /// Attack success rate
199    pub attack_success_rate: f64,
200
201    /// Certified accuracy (for certified defenses)
202    pub certified_accuracy: Option<f64>,
203
204    /// Per-attack-type metrics
205    pub per_attack_metrics: HashMap<String, AttackMetrics>,
206}
207
/// Metrics for a specific attack type.
///
/// `Default` yields all-zero metrics.
#[derive(Debug, Clone, Default, PartialEq)]
pub struct AttackMetrics {
    /// Fraction of attack attempts that succeeded.
    pub success_rate: f64,

    /// Average perturbation norm of the attack.
    pub avg_perturbation: f64,

    /// Average drop in model confidence caused by the attack.
    pub avg_confidence_drop: f64,
}
220
221impl QuantumAdversarialTrainer {
222    /// Create a new quantum adversarial trainer
223    pub fn new(
224        model: QuantumNeuralNetwork,
225        defense_strategy: QuantumDefenseStrategy,
226        config: AdversarialTrainingConfig,
227    ) -> Self {
228        Self {
229            model,
230            defense_strategy,
231            config,
232            attack_history: Vec::new(),
233            robustness_metrics: RobustnessMetrics {
234                clean_accuracy: 0.0,
235                robust_accuracy: 0.0,
236                avg_perturbation_norm: 0.0,
237                attack_success_rate: 0.0,
238                certified_accuracy: None,
239                per_attack_metrics: HashMap::new(),
240            },
241            ensemble_models: Vec::new(),
242        }
243    }
244
245    /// Train the model with adversarial training
246    pub fn train(
247        &mut self,
248        train_data: &Array2<f64>,
249        train_labels: &Array1<usize>,
250        val_data: &Array2<f64>,
251        val_labels: &Array1<usize>,
252        optimizer: &mut dyn Optimizer,
253    ) -> Result<Vec<f64>> {
254        println!("Starting quantum adversarial training...");
255
256        let mut losses = Vec::new();
257        let mut patience_counter = 0;
258        let mut best_robust_accuracy = 0.0;
259
260        // Initialize ensemble if needed
261        self.initialize_ensemble()?;
262
263        for epoch in 0..self.config.epochs {
264            let mut epoch_loss = 0.0;
265            let num_batches =
266                (train_data.nrows() + self.config.batch_size - 1) / self.config.batch_size;
267
268            for batch_idx in 0..num_batches {
269                let batch_start = batch_idx * self.config.batch_size;
270                let batch_end = (batch_start + self.config.batch_size).min(train_data.nrows());
271
272                let batch_data = train_data.slice(s![batch_start..batch_end, ..]).to_owned();
273                let batch_labels = train_labels.slice(s![batch_start..batch_end]).to_owned();
274
275                // Generate adversarial examples if needed
276                let (final_data, final_labels) = if epoch % self.config.adversarial_frequency == 0 {
277                    self.generate_adversarial_batch(&batch_data, &batch_labels)?
278                } else {
279                    (batch_data, batch_labels)
280                };
281
282                // Compute loss and update model
283                let batch_loss = self.train_batch(&final_data, &final_labels, optimizer)?;
284                epoch_loss += batch_loss;
285            }
286
287            epoch_loss /= num_batches as f64;
288            losses.push(epoch_loss);
289
290            // Evaluate robustness periodically
291            if epoch % self.config.eval_interval == 0 {
292                self.evaluate_robustness(val_data, val_labels)?;
293
294                println!(
295                    "Epoch {}: Loss = {:.4}, Clean Acc = {:.3}, Robust Acc = {:.3}",
296                    epoch,
297                    epoch_loss,
298                    self.robustness_metrics.clean_accuracy,
299                    self.robustness_metrics.robust_accuracy
300                );
301
302                // Early stopping check
303                if let Some(ref criteria) = self.config.early_stopping {
304                    if self.robustness_metrics.robust_accuracy > best_robust_accuracy {
305                        best_robust_accuracy = self.robustness_metrics.robust_accuracy;
306                        patience_counter = 0;
307                    } else {
308                        patience_counter += 1;
309                    }
310
311                    if patience_counter >= criteria.patience {
312                        println!("Early stopping triggered at epoch {}", epoch);
313                        break;
314                    }
315
316                    if self.robustness_metrics.clean_accuracy < criteria.min_clean_accuracy
317                        || self.robustness_metrics.robust_accuracy < criteria.min_robust_accuracy
318                    {
319                        println!("Minimum performance criteria not met, stopping training");
320                        break;
321                    }
322                }
323            }
324        }
325
326        // Final robustness evaluation
327        self.evaluate_robustness(val_data, val_labels)?;
328
329        Ok(losses)
330    }
331
332    /// Generate adversarial examples using specified attack
333    pub fn generate_adversarial_examples(
334        &self,
335        data: &Array2<f64>,
336        labels: &Array1<usize>,
337        attack_type: QuantumAttackType,
338    ) -> Result<Vec<QuantumAdversarialExample>> {
339        let mut adversarial_examples = Vec::new();
340
341        for (i, (input, &label)) in data.outer_iter().zip(labels.iter()).enumerate() {
342            let adversarial_example = self.generate_single_adversarial_example(
343                &input.to_owned(),
344                label,
345                attack_type.clone(),
346            )?;
347
348            adversarial_examples.push(adversarial_example);
349        }
350
351        Ok(adversarial_examples)
352    }
353
354    /// Generate a single adversarial example
355    fn generate_single_adversarial_example(
356        &self,
357        input: &Array1<f64>,
358        true_label: usize,
359        attack_type: QuantumAttackType,
360    ) -> Result<QuantumAdversarialExample> {
361        // Get original prediction
362        let original_prediction = self.model.forward(input)?;
363
364        let adversarial_input = match attack_type {
365            QuantumAttackType::FGSM { epsilon } => self.fgsm_attack(input, true_label, epsilon)?,
366            QuantumAttackType::PGD {
367                epsilon,
368                alpha,
369                num_steps,
370            } => self.pgd_attack(input, true_label, epsilon, alpha, num_steps)?,
371            QuantumAttackType::ParameterShift {
372                shift_magnitude,
373                target_parameters,
374            } => self.parameter_shift_attack(input, shift_magnitude, target_parameters)?,
375            QuantumAttackType::StatePerturbation {
376                perturbation_strength,
377                ref basis,
378            } => self.state_perturbation_attack(input, perturbation_strength, basis)?,
379            QuantumAttackType::CircuitManipulation {
380                gate_error_rate,
381                coherence_time,
382            } => self.circuit_manipulation_attack(input, gate_error_rate, coherence_time)?,
383            QuantumAttackType::UniversalPerturbation {
384                perturbation_budget,
385                success_rate_threshold,
386            } => self.universal_perturbation_attack(input, perturbation_budget)?,
387        };
388
389        // Get adversarial prediction
390        let adversarial_prediction = self.model.forward(&adversarial_input)?;
391
392        // Compute perturbation norm
393        let perturbation = &adversarial_input - input;
394        let perturbation_norm = perturbation.mapv(|x| (x as f64).powi(2)).sum().sqrt();
395
396        // Check attack success
397        let original_class = original_prediction
398            .iter()
399            .enumerate()
400            .max_by(|a, b| a.1.partial_cmp(b.1).unwrap())
401            .map(|(i, _)| i)
402            .unwrap_or(0);
403
404        let adversarial_class = adversarial_prediction
405            .iter()
406            .enumerate()
407            .max_by(|a, b| a.1.partial_cmp(b.1).unwrap())
408            .map(|(i, _)| i)
409            .unwrap_or(0);
410
411        let attack_success = original_class != adversarial_class;
412
413        Ok(QuantumAdversarialExample {
414            original_input: input.clone(),
415            adversarial_input,
416            original_prediction,
417            adversarial_prediction,
418            true_label,
419            perturbation_norm,
420            attack_success,
421            metadata: HashMap::new(),
422        })
423    }
424
425    /// Fast Gradient Sign Method (FGSM) for quantum circuits
426    fn fgsm_attack(
427        &self,
428        input: &Array1<f64>,
429        true_label: usize,
430        epsilon: f64,
431    ) -> Result<Array1<f64>> {
432        // Compute gradient of loss w.r.t. input
433        let gradient = self.compute_input_gradient(input, true_label)?;
434
435        // Apply FGSM perturbation
436        let perturbation = gradient.mapv(|g| epsilon * g.signum());
437        let adversarial_input = input + &perturbation;
438
439        // Clip to valid range [0, 1] for quantum inputs
440        Ok(adversarial_input.mapv(|x| x.max(0.0).min(1.0)))
441    }
442
443    /// Projected Gradient Descent (PGD) attack
444    fn pgd_attack(
445        &self,
446        input: &Array1<f64>,
447        true_label: usize,
448        epsilon: f64,
449        alpha: f64,
450        num_steps: usize,
451    ) -> Result<Array1<f64>> {
452        let mut adversarial_input = input.clone();
453
454        for _ in 0..num_steps {
455            // Compute gradient
456            let gradient = self.compute_input_gradient(&adversarial_input, true_label)?;
457
458            // Take gradient step
459            let perturbation = gradient.mapv(|g| alpha * g.signum());
460            adversarial_input = &adversarial_input + &perturbation;
461
462            // Project back to epsilon ball
463            let total_perturbation = &adversarial_input - input;
464            let perturbation_norm = total_perturbation.mapv(|x| (x as f64).powi(2)).sum().sqrt();
465
466            if perturbation_norm > epsilon {
467                let scaling = epsilon / perturbation_norm;
468                adversarial_input = input + &(total_perturbation * scaling);
469            }
470
471            // Clip to valid range
472            adversarial_input = adversarial_input.mapv(|x| x.max(0.0).min(1.0));
473        }
474
475        Ok(adversarial_input)
476    }
477
478    /// Parameter shift attack targeting quantum circuit parameters
479    fn parameter_shift_attack(
480        &self,
481        input: &Array1<f64>,
482        shift_magnitude: f64,
483        target_parameters: Option<Vec<usize>>,
484    ) -> Result<Array1<f64>> {
485        // This attack modifies the input to exploit parameter shift rules
486        let mut adversarial_input = input.clone();
487
488        // Apply parameter shift-inspired perturbations
489        for i in 0..adversarial_input.len() {
490            if let Some(ref targets) = target_parameters {
491                if !targets.contains(&i) {
492                    continue;
493                }
494            }
495
496            // Use parameter shift rule: f(x + π/2) - f(x - π/2)
497            let shift = shift_magnitude * (PI / 2.0);
498            adversarial_input[i] += shift * (2.0 * rand::random::<f64>() - 1.0);
499        }
500
501        Ok(adversarial_input.mapv(|x| x.max(0.0).min(1.0)))
502    }
503
504    /// Quantum state perturbation attack
505    fn state_perturbation_attack(
506        &self,
507        input: &Array1<f64>,
508        perturbation_strength: f64,
509        basis: &str,
510    ) -> Result<Array1<f64>> {
511        let mut adversarial_input = input.clone();
512
513        match basis {
514            "pauli_x" => {
515                // Apply X-basis perturbations
516                for i in 0..adversarial_input.len() {
517                    let angle = adversarial_input[i] * PI;
518                    let perturbed_angle =
519                        angle + perturbation_strength * (2.0 * rand::random::<f64>() - 1.0);
520                    adversarial_input[i] = perturbed_angle / PI;
521                }
522            }
523            "pauli_y" => {
524                // Apply Y-basis perturbations
525                for i in 0..adversarial_input.len() {
526                    adversarial_input[i] +=
527                        perturbation_strength * (2.0 * rand::random::<f64>() - 1.0);
528                }
529            }
530            "pauli_z" | _ => {
531                // Apply Z-basis perturbations (default)
532                for i in 0..adversarial_input.len() {
533                    let phase_shift = perturbation_strength * (2.0 * rand::random::<f64>() - 1.0);
534                    adversarial_input[i] =
535                        (adversarial_input[i] + phase_shift / (2.0 * PI)).fract();
536                }
537            }
538        }
539
540        Ok(adversarial_input.mapv(|x| x.max(0.0).min(1.0)))
541    }
542
543    /// Circuit manipulation attack (simulating hardware errors)
544    fn circuit_manipulation_attack(
545        &self,
546        input: &Array1<f64>,
547        gate_error_rate: f64,
548        coherence_time: f64,
549    ) -> Result<Array1<f64>> {
550        let mut adversarial_input = input.clone();
551
552        // Simulate decoherence effects
553        for i in 0..adversarial_input.len() {
554            // Apply T1 decay
555            let t1_factor = (-1.0 / coherence_time).exp();
556            adversarial_input[i] *= t1_factor;
557
558            // Add gate errors
559            if rand::random::<f64>() < gate_error_rate {
560                adversarial_input[i] += 0.1 * (2.0 * rand::random::<f64>() - 1.0);
561            }
562        }
563
564        Ok(adversarial_input.mapv(|x| x.max(0.0).min(1.0)))
565    }
566
567    /// Universal adversarial perturbation attack
568    fn universal_perturbation_attack(
569        &self,
570        input: &Array1<f64>,
571        perturbation_budget: f64,
572    ) -> Result<Array1<f64>> {
573        // Apply a learned universal perturbation (simplified)
574        let mut adversarial_input = input.clone();
575
576        // Generate universal perturbation pattern
577        for i in 0..adversarial_input.len() {
578            let universal_component =
579                perturbation_budget * (2.0 * PI * i as f64 / adversarial_input.len() as f64).sin();
580            adversarial_input[i] += universal_component;
581        }
582
583        Ok(adversarial_input.mapv(|x| x.max(0.0).min(1.0)))
584    }
585
586    /// Compute gradient of loss with respect to input
587    fn compute_input_gradient(
588        &self,
589        input: &Array1<f64>,
590        true_label: usize,
591    ) -> Result<Array1<f64>> {
592        // Placeholder for gradient computation
593        // In practice, this would use automatic differentiation
594        let mut gradient = Array1::zeros(input.len());
595
596        // Finite difference approximation
597        let h = 1e-5;
598        let original_output = self.model.forward(input)?;
599        let original_loss = self.compute_loss(&original_output, true_label);
600
601        for i in 0..input.len() {
602            let mut perturbed_input = input.clone();
603            perturbed_input[i] += h;
604
605            let perturbed_output = self.model.forward(&perturbed_input)?;
606            let perturbed_loss = self.compute_loss(&perturbed_output, true_label);
607
608            gradient[i] = (perturbed_loss - original_loss) / h;
609        }
610
611        Ok(gradient)
612    }
613
614    /// Compute loss for a given output and true label
615    fn compute_loss(&self, output: &Array1<f64>, true_label: usize) -> f64 {
616        // Cross-entropy loss
617        let predicted_prob = output[true_label].max(1e-10);
618        -predicted_prob.ln()
619    }
620
621    /// Generate adversarial training batch
622    fn generate_adversarial_batch(
623        &self,
624        data: &Array2<f64>,
625        labels: &Array1<usize>,
626    ) -> Result<(Array2<f64>, Array1<usize>)> {
627        match &self.defense_strategy {
628            QuantumDefenseStrategy::AdversarialTraining {
629                attack_types,
630                adversarial_ratio,
631            } => {
632                let num_adversarial = (data.nrows() as f64 * adversarial_ratio) as usize;
633                let mut combined_data = data.clone();
634                let mut combined_labels = labels.clone();
635
636                // Generate adversarial examples
637                for i in 0..num_adversarial {
638                    let idx = i % data.nrows();
639                    let input = data.row(idx).to_owned();
640                    let label = labels[idx];
641
642                    // Randomly select attack type
643                    let attack_type = attack_types[fastrand::usize(0..attack_types.len())].clone();
644                    let adversarial_example =
645                        self.generate_single_adversarial_example(&input, label, attack_type)?;
646
647                    // Add to batch (replace original example)
648                    combined_data
649                        .row_mut(idx)
650                        .assign(&adversarial_example.adversarial_input);
651                }
652
653                Ok((combined_data, combined_labels))
654            }
655            _ => Ok((data.clone(), labels.clone())),
656        }
657    }
658
659    /// Train on a single batch
660    fn train_batch(
661        &mut self,
662        data: &Array2<f64>,
663        labels: &Array1<usize>,
664        optimizer: &mut dyn Optimizer,
665    ) -> Result<f64> {
666        // Simplified training step
667        let mut total_loss = 0.0;
668
669        for (input, &label) in data.outer_iter().zip(labels.iter()) {
670            let output = self.model.forward(&input.to_owned())?;
671            let loss = self.compute_loss(&output, label);
672            total_loss += loss;
673
674            // Compute gradients and update (simplified)
675            // In practice, this would use proper backpropagation
676        }
677
678        Ok(total_loss / data.nrows() as f64)
679    }
680
681    /// Initialize ensemble for ensemble defense
682    fn initialize_ensemble(&mut self) -> Result<()> {
683        if let QuantumDefenseStrategy::EnsembleDefense { num_models, .. } = &self.defense_strategy {
684            for _ in 0..*num_models {
685                // Create model with slight variations
686                let model = self.model.clone();
687                self.ensemble_models.push(model);
688            }
689        }
690        Ok(())
691    }
692
693    /// Evaluate robustness on validation set
694    fn evaluate_robustness(
695        &mut self,
696        val_data: &Array2<f64>,
697        val_labels: &Array1<usize>,
698    ) -> Result<()> {
699        let mut clean_correct = 0;
700        let mut robust_correct = 0;
701        let mut total_perturbation = 0.0;
702        let mut successful_attacks = 0;
703
704        // Test with different attack types
705        let test_attacks = vec![
706            QuantumAttackType::FGSM { epsilon: 0.1 },
707            QuantumAttackType::PGD {
708                epsilon: 0.1,
709                alpha: 0.01,
710                num_steps: 10,
711            },
712        ];
713
714        for (input, &label) in val_data.outer_iter().zip(val_labels.iter()) {
715            let input_owned = input.to_owned();
716
717            // Clean accuracy
718            let clean_output = self.model.forward(&input_owned)?;
719            let clean_pred = clean_output
720                .iter()
721                .enumerate()
722                .max_by(|a, b| a.1.partial_cmp(b.1).unwrap())
723                .map(|(i, _)| i)
724                .unwrap_or(0);
725
726            if clean_pred == label {
727                clean_correct += 1;
728            }
729
730            // Test robustness against attacks
731            let mut robust_for_this_input = true;
732            for attack_type in &test_attacks {
733                let adversarial_example = self.generate_single_adversarial_example(
734                    &input_owned,
735                    label,
736                    attack_type.clone(),
737                )?;
738
739                total_perturbation += adversarial_example.perturbation_norm;
740
741                if adversarial_example.attack_success {
742                    successful_attacks += 1;
743                    robust_for_this_input = false;
744                }
745            }
746
747            if robust_for_this_input {
748                robust_correct += 1;
749            }
750        }
751
752        let num_samples = val_data.nrows();
753        let num_attack_tests = num_samples * test_attacks.len();
754
755        self.robustness_metrics.clean_accuracy = clean_correct as f64 / num_samples as f64;
756        self.robustness_metrics.robust_accuracy = robust_correct as f64 / num_samples as f64;
757        self.robustness_metrics.avg_perturbation_norm =
758            total_perturbation / num_attack_tests as f64;
759        self.robustness_metrics.attack_success_rate =
760            successful_attacks as f64 / num_attack_tests as f64;
761
762        Ok(())
763    }
764
765    /// Apply defense strategy to input
766    pub fn apply_defense(&self, input: &Array1<f64>) -> Result<Array1<f64>> {
767        match &self.defense_strategy {
768            QuantumDefenseStrategy::InputPreprocessing {
769                noise_addition,
770                feature_squeezing,
771            } => {
772                let mut defended_input = input.clone();
773
774                // Add noise
775                for i in 0..defended_input.len() {
776                    defended_input[i] += noise_addition * (2.0 * rand::random::<f64>() - 1.0);
777                }
778
779                // Feature squeezing
780                if *feature_squeezing {
781                    defended_input = defended_input.mapv(|x| (x * 8.0).round() / 8.0);
782                }
783
784                Ok(defended_input.mapv(|x| x.max(0.0).min(1.0)))
785            }
786            QuantumDefenseStrategy::RandomizedCircuit {
787                randomization_strength,
788                ..
789            } => {
790                let mut defended_input = input.clone();
791
792                // Add random perturbations to simulate circuit randomization
793                for i in 0..defended_input.len() {
794                    let random_shift = randomization_strength * (2.0 * rand::random::<f64>() - 1.0);
795                    defended_input[i] += random_shift;
796                }
797
798                Ok(defended_input.mapv(|x| x.max(0.0).min(1.0)))
799            }
800            _ => Ok(input.clone()),
801        }
802    }
803
804    /// Get robustness metrics
805    pub fn get_robustness_metrics(&self) -> &RobustnessMetrics {
806        &self.robustness_metrics
807    }
808
809    /// Get attack history
810    pub fn get_attack_history(&self) -> &[QuantumAdversarialExample] {
811        &self.attack_history
812    }
813
814    /// Perform certified defense analysis
815    pub fn certified_defense_analysis(
816        &self,
817        data: &Array2<f64>,
818        smoothing_variance: f64,
819        num_samples: usize,
820    ) -> Result<f64> {
821        let mut certified_correct = 0;
822
823        for input in data.outer_iter() {
824            let input_owned = input.to_owned();
825
826            // Sample multiple noisy versions
827            let mut predictions = Vec::new();
828            for _ in 0..num_samples {
829                let mut noisy_input = input_owned.clone();
830                for i in 0..noisy_input.len() {
831                    let noise = fastrand::f64() * smoothing_variance;
832                    noisy_input[i] += noise;
833                }
834
835                let output = self.model.forward(&noisy_input)?;
836                let pred = output
837                    .iter()
838                    .enumerate()
839                    .max_by(|a, b| a.1.partial_cmp(b.1).unwrap())
840                    .map(|(i, _)| i)
841                    .unwrap_or(0);
842
843                predictions.push(pred);
844            }
845
846            // Check if prediction is certified (majority vote is stable)
847            let mut counts = vec![0; 10]; // Assume max 10 classes
848            for &pred in &predictions {
849                if pred < counts.len() {
850                    counts[pred] += 1;
851                }
852            }
853
854            let max_count = counts.iter().max().unwrap_or(&0);
855            let certification_threshold = (num_samples as f64 * 0.6) as usize;
856
857            if *max_count >= certification_threshold {
858                certified_correct += 1;
859            }
860        }
861
862        Ok(certified_correct as f64 / data.nrows() as f64)
863    }
864}
865
866/// Helper function to create default adversarial training config
867pub fn create_default_adversarial_config() -> AdversarialTrainingConfig {
868    AdversarialTrainingConfig {
869        epochs: 100,
870        batch_size: 32,
871        learning_rate: 0.001,
872        adversarial_frequency: 2,
873        max_perturbation: 0.1,
874        eval_interval: 10,
875        early_stopping: Some(EarlyStoppingCriteria {
876            min_clean_accuracy: 0.7,
877            min_robust_accuracy: 0.5,
878            patience: 20,
879        }),
880    }
881}
882
883/// Helper function to create comprehensive defense strategy
884pub fn create_comprehensive_defense() -> QuantumDefenseStrategy {
885    QuantumDefenseStrategy::AdversarialTraining {
886        attack_types: vec![
887            QuantumAttackType::FGSM { epsilon: 0.1 },
888            QuantumAttackType::PGD {
889                epsilon: 0.1,
890                alpha: 0.01,
891                num_steps: 7,
892            },
893            QuantumAttackType::ParameterShift {
894                shift_magnitude: 0.05,
895                target_parameters: None,
896            },
897        ],
898        adversarial_ratio: 0.5,
899    }
900}
901
#[cfg(test)]
mod tests {
    // `super::*` already brings in `QNNLayerType`, `Array1` and `HashMap`
    // via the parent module's `use` declarations.
    use super::*;

    /// Euclidean (L2) norm of a 1-D `f64` array.
    ///
    /// Shared by the tests below that check a perturbation is non-zero.
    fn l2_norm(v: &Array1<f64>) -> f64 {
        v.iter().map(|x| x * x).sum::<f64>().sqrt()
    }

    #[test]
    fn test_adversarial_example_creation() {
        let original_input = Array1::from_vec(vec![0.5, 0.3, 0.8, 0.2]);
        let adversarial_input = Array1::from_vec(vec![0.6, 0.4, 0.7, 0.3]);
        let original_prediction = Array1::from_vec(vec![0.8, 0.2]);
        let adversarial_prediction = Array1::from_vec(vec![0.3, 0.7]);

        let perturbation_norm = l2_norm(&(&adversarial_input - &original_input));

        let example = QuantumAdversarialExample {
            original_input,
            adversarial_input,
            original_prediction,
            adversarial_prediction,
            true_label: 0,
            perturbation_norm,
            attack_success: true,
            metadata: HashMap::new(),
        };

        assert!(example.attack_success);
        assert!(example.perturbation_norm > 0.0);
    }

    #[test]
    fn test_fgsm_attack() {
        let layers = vec![
            QNNLayerType::EncodingLayer { num_features: 4 },
            QNNLayerType::VariationalLayer { num_params: 8 },
            QNNLayerType::MeasurementLayer {
                measurement_basis: "computational".to_string(),
            },
        ];

        let model = QuantumNeuralNetwork::new(layers, 4, 4, 2)
            .expect("test QNN configuration should be valid");
        let defense = create_comprehensive_defense();
        let config = create_default_adversarial_config();

        let trainer = QuantumAdversarialTrainer::new(model, defense, config);

        let input = Array1::from_vec(vec![0.5, 0.3, 0.8, 0.2]);
        let adversarial_input = trainer
            .fgsm_attack(&input, 0, 0.1)
            .expect("FGSM attack should succeed on a valid input");

        assert_eq!(adversarial_input.len(), input.len());

        // The attack must actually move the input.
        let perturbation_norm = l2_norm(&(&adversarial_input - &input));
        assert!(perturbation_norm > 0.0);

        // Every adversarial feature must remain in [0, 1].
        for &val in adversarial_input.iter() {
            assert!(
                (0.0..=1.0).contains(&val),
                "adversarial value {val} is outside [0, 1]"
            );
        }
    }

    #[test]
    fn test_defense_application() {
        let layers = vec![
            QNNLayerType::EncodingLayer { num_features: 4 },
            QNNLayerType::VariationalLayer { num_params: 8 },
        ];

        let model = QuantumNeuralNetwork::new(layers, 4, 4, 2)
            .expect("test QNN configuration should be valid");

        let defense = QuantumDefenseStrategy::InputPreprocessing {
            noise_addition: 0.05,
            feature_squeezing: true,
        };

        let config = create_default_adversarial_config();
        let trainer = QuantumAdversarialTrainer::new(model, defense, config);

        let input = Array1::from_vec(vec![0.51, 0.32, 0.83, 0.24]);
        let defended_input = trainer
            .apply_defense(&input)
            .expect("input preprocessing defense should succeed");

        assert_eq!(defended_input.len(), input.len());

        // Check that defense was applied (input changed)
        let difference = (&defended_input - &input).mapv(f64::abs).sum();
        assert!(difference > 0.0);
    }

    #[test]
    fn test_robustness_metrics() {
        let metrics = RobustnessMetrics {
            clean_accuracy: 0.85,
            robust_accuracy: 0.65,
            avg_perturbation_norm: 0.12,
            attack_success_rate: 0.35,
            certified_accuracy: Some(0.55),
            per_attack_metrics: HashMap::new(),
        };

        assert_eq!(metrics.clean_accuracy, 0.85);
        assert_eq!(metrics.robust_accuracy, 0.65);
        assert!(metrics.robust_accuracy < metrics.clean_accuracy);
        assert!(metrics.attack_success_rate < 0.5);
    }
}