use crate::autodiff::optimizers::Optimizer;
use crate::error::{MLError, Result};
use crate::optimization::OptimizationMethod;
use crate::qnn::{QNNLayerType, QuantumNeuralNetwork};
use ndarray::{s, Array1, Array2, Array3, Axis};
use quantrs2_circuit::builder::{Circuit, Simulator};
use quantrs2_core::gate::{
    single::{RotationX, RotationY, RotationZ},
    GateOp,
};
use quantrs2_sim::statevector::StateVectorSimulator;
use std::collections::HashMap;
use std::f64::consts::PI;

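/// Types of adversarial attacks against quantum neural networks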
#[derive(Debug, Clone)]
pub enum QuantumAttackType {
    /// Fast Gradient Sign Method adapted to quantum inputs
    FGSM { epsilon: f64 },

    /// Projected Gradient Descent: iterative FGSM with projection
    PGD {
        epsilon: f64,
        alpha: f64,
        num_steps: usize,
    },

    /// Random shifts applied to encoded parameters, optionally targeted
    ParameterShift {
        shift_magnitude: f64,
        target_parameters: Option<Vec<usize>>,
    },

    /// Perturbation of the encoded quantum state in a given Pauli basis
    StatePerturbation {
        perturbation_strength: f64,
        basis: String,
    },

    /// Hardware-level tampering via gate errors and decoherence
    CircuitManipulation {
        gate_error_rate: f64,
        coherence_time: f64,
    },

    /// Input-agnostic perturbation applied under a fixed budget
    UniversalPerturbation {
        perturbation_budget: f64,
        success_rate_threshold: f64,
    },
}

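/// Defense strategies for hardening quantum models against the attacks above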
#[derive(Debug, Clone)]
pub enum QuantumDefenseStrategy {
    /// Train on a mix of clean and adversarial examples
    AdversarialTraining {
        attack_types: Vec<QuantumAttackType>,
        adversarial_ratio: f64,
    },

    /// Apply quantum error correction to suppress perturbations
    QuantumErrorCorrection {
        code_type: String,
        correction_threshold: f64,
    },

    /// Sanitize inputs via noise injection and feature squeezing
    InputPreprocessing {
        noise_addition: f64,
        feature_squeezing: bool,
    },

    /// Aggregate predictions from a diverse model ensemble
    EnsembleDefense {
        num_models: usize,
        diversity_metric: String,
    },

    /// Randomized-smoothing-style certified defense
    CertifiedDefense {
        smoothing_variance: f64,
        confidence_level: f64,
    },

    /// Randomized layers that obscure gradient information
    RandomizedCircuit {
        randomization_strength: f64,
        num_random_layers: usize,
    },
}

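/// A single adversarial example, pairing a clean input with its perturbed
/// counterpart and the model's predictions on both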
#[derive(Debug, Clone)]
pub struct QuantumAdversarialExample {
    /// Original (clean) input
    pub original_input: Array1<f64>,

    /// Perturbed adversarial input
    pub adversarial_input: Array1<f64>,

    /// Model prediction on the clean input
    pub original_prediction: Array1<f64>,

    /// Model prediction on the adversarial input
    pub adversarial_prediction: Array1<f64>,

    /// Ground-truth class label
    pub true_label: usize,

    /// L2 norm of the applied perturbation
    pub perturbation_norm: f64,

    /// Whether the attack changed the predicted class
    pub attack_success: bool,

    /// Additional attack metadata
    pub metadata: HashMap<String, f64>,
}

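/// Trainer that hardens a [`QuantumNeuralNetwork`] against the attacks in
/// [`QuantumAttackType`] using a configurable [`QuantumDefenseStrategy`].
///
/// A minimal construction sketch (layer setup, data, and optimizer are
/// placeholders; the helper functions are defined at the bottom of this
/// module):
///
/// ```ignore
/// let model = QuantumNeuralNetwork::new(layers, 4, 4, 2)?;
/// let mut trainer = QuantumAdversarialTrainer::new(
///     model,
///     create_comprehensive_defense(),
///     create_default_adversarial_config(),
/// );
/// let losses = trainer.train(&train_x, &train_y, &val_x, &val_y, &mut optimizer)?;
/// ```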
pub struct QuantumAdversarialTrainer {
    /// Model being trained
    model: QuantumNeuralNetwork,

    /// Active defense strategy
    defense_strategy: QuantumDefenseStrategy,

    /// Training hyperparameters
    config: AdversarialTrainingConfig,

    /// Record of generated adversarial examples
    attack_history: Vec<QuantumAdversarialExample>,

    /// Most recent robustness evaluation
    robustness_metrics: RobustnessMetrics,

    /// Models used by the ensemble defense
    ensemble_models: Vec<QuantumNeuralNetwork>,
}

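/// Hyperparameters controlling the adversarial training loop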
#[derive(Debug, Clone)]
pub struct AdversarialTrainingConfig {
    /// Number of training epochs
    pub epochs: usize,

    /// Mini-batch size
    pub batch_size: usize,

    /// Learning rate
    pub learning_rate: f64,

    /// Mix adversarial examples into batches every this many epochs
    pub adversarial_frequency: usize,

    /// Maximum allowed perturbation magnitude
    pub max_perturbation: f64,

    /// Evaluate robustness every this many epochs
    pub eval_interval: usize,

    /// Optional early-stopping criteria
    pub early_stopping: Option<EarlyStoppingCriteria>,
}

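/// Thresholds that end training early when robustness stops improving or
/// accuracy floors are violated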
#[derive(Debug, Clone)]
pub struct EarlyStoppingCriteria {
    /// Minimum acceptable accuracy on clean inputs
    pub min_clean_accuracy: f64,

    /// Minimum acceptable accuracy under attack
    pub min_robust_accuracy: f64,

    /// Evaluations to wait for improvement before stopping
    pub patience: usize,
}

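/// Aggregate robustness statistics gathered during evaluation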
#[derive(Debug, Clone)]
pub struct RobustnessMetrics {
    /// Accuracy on clean inputs
    pub clean_accuracy: f64,

    /// Accuracy on adversarially perturbed inputs
    pub robust_accuracy: f64,

    /// Average perturbation norm across attacks
    pub avg_perturbation_norm: f64,

    /// Fraction of attacks that changed the prediction
    pub attack_success_rate: f64,

    /// Certified accuracy, when a certified defense has been analyzed
    pub certified_accuracy: Option<f64>,

    /// Per-attack breakdown of the metrics
    pub per_attack_metrics: HashMap<String, AttackMetrics>,
}

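/// Robustness statistics broken down for a single attack type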
#[derive(Debug, Clone)]
pub struct AttackMetrics {
    /// Fraction of successful attacks
    pub success_rate: f64,

    /// Average perturbation norm
    pub avg_perturbation: f64,

    /// Average drop in prediction confidence
    pub avg_confidence_drop: f64,
}

impl QuantumAdversarialTrainer {
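    /// Creates a trainer with an empty attack history and zeroed metrics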
    pub fn new(
        model: QuantumNeuralNetwork,
        defense_strategy: QuantumDefenseStrategy,
        config: AdversarialTrainingConfig,
    ) -> Self {
        Self {
            model,
            defense_strategy,
            config,
            attack_history: Vec::new(),
            robustness_metrics: RobustnessMetrics {
                clean_accuracy: 0.0,
                robust_accuracy: 0.0,
                avg_perturbation_norm: 0.0,
                attack_success_rate: 0.0,
                certified_accuracy: None,
                per_attack_metrics: HashMap::new(),
            },
            ensemble_models: Vec::new(),
        }
    }

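    /// Runs the adversarial training loop. Every `adversarial_frequency`-th
    /// epoch the mini-batches are augmented with adversarial examples,
    /// robustness is re-evaluated every `eval_interval` epochs, and early
    /// stopping tracks the best robust accuracy. Returns per-epoch losses.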
    pub fn train(
        &mut self,
        train_data: &Array2<f64>,
        train_labels: &Array1<usize>,
        val_data: &Array2<f64>,
        val_labels: &Array1<usize>,
        optimizer: &mut dyn Optimizer,
    ) -> Result<Vec<f64>> {
        println!("Starting quantum adversarial training...");

        let mut losses = Vec::new();
        let mut patience_counter = 0;
        let mut best_robust_accuracy = 0.0;

        // Set up ensemble models if the defense strategy requires them
        self.initialize_ensemble()?;

        for epoch in 0..self.config.epochs {
            let mut epoch_loss = 0.0;
            let num_batches =
                (train_data.nrows() + self.config.batch_size - 1) / self.config.batch_size;

            for batch_idx in 0..num_batches {
                let batch_start = batch_idx * self.config.batch_size;
                let batch_end = (batch_start + self.config.batch_size).min(train_data.nrows());

                let batch_data = train_data.slice(s![batch_start..batch_end, ..]).to_owned();
                let batch_labels = train_labels.slice(s![batch_start..batch_end]).to_owned();

                // Mix adversarial examples into the batch on scheduled epochs
                let (final_data, final_labels) = if epoch % self.config.adversarial_frequency == 0 {
                    self.generate_adversarial_batch(&batch_data, &batch_labels)?
                } else {
                    (batch_data, batch_labels)
                };

                let batch_loss = self.train_batch(&final_data, &final_labels, optimizer)?;
                epoch_loss += batch_loss;
            }

            epoch_loss /= num_batches as f64;
            losses.push(epoch_loss);

            // Periodic robustness evaluation
            if epoch % self.config.eval_interval == 0 {
                self.evaluate_robustness(val_data, val_labels)?;

                println!(
                    "Epoch {}: Loss = {:.4}, Clean Acc = {:.3}, Robust Acc = {:.3}",
                    epoch,
                    epoch_loss,
                    self.robustness_metrics.clean_accuracy,
                    self.robustness_metrics.robust_accuracy
                );

                // Early stopping on stalled robust accuracy or missed minimums
                if let Some(ref criteria) = self.config.early_stopping {
                    if self.robustness_metrics.robust_accuracy > best_robust_accuracy {
                        best_robust_accuracy = self.robustness_metrics.robust_accuracy;
                        patience_counter = 0;
                    } else {
                        patience_counter += 1;
                    }

                    if patience_counter >= criteria.patience {
                        println!("Early stopping triggered at epoch {}", epoch);
                        break;
                    }

                    if self.robustness_metrics.clean_accuracy < criteria.min_clean_accuracy
                        || self.robustness_metrics.robust_accuracy < criteria.min_robust_accuracy
                    {
                        println!("Minimum performance criteria not met, stopping training");
                        break;
                    }
                }
            }
        }

        // Final robustness evaluation
        self.evaluate_robustness(val_data, val_labels)?;

        Ok(losses)
    }

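    /// Generates one adversarial example per input using the given attack type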
    pub fn generate_adversarial_examples(
        &self,
        data: &Array2<f64>,
        labels: &Array1<usize>,
        attack_type: QuantumAttackType,
    ) -> Result<Vec<QuantumAdversarialExample>> {
        let mut adversarial_examples = Vec::new();

        for (input, &label) in data.outer_iter().zip(labels.iter()) {
            let adversarial_example = self.generate_single_adversarial_example(
                &input.to_owned(),
                label,
                attack_type.clone(),
            )?;

            adversarial_examples.push(adversarial_example);
        }

        Ok(adversarial_examples)
    }

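    /// Applies a single attack to one input, recording clean and adversarial
    /// predictions, the L2 perturbation norm, and whether the predicted class
    /// flipped (attack success)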
    fn generate_single_adversarial_example(
        &self,
        input: &Array1<f64>,
        true_label: usize,
        attack_type: QuantumAttackType,
    ) -> Result<QuantumAdversarialExample> {
        let original_prediction = self.model.forward(input)?;

        let adversarial_input = match attack_type {
            QuantumAttackType::FGSM { epsilon } => self.fgsm_attack(input, true_label, epsilon)?,
            QuantumAttackType::PGD {
                epsilon,
                alpha,
                num_steps,
            } => self.pgd_attack(input, true_label, epsilon, alpha, num_steps)?,
            QuantumAttackType::ParameterShift {
                shift_magnitude,
                target_parameters,
            } => self.parameter_shift_attack(input, shift_magnitude, target_parameters)?,
            QuantumAttackType::StatePerturbation {
                perturbation_strength,
                ref basis,
            } => self.state_perturbation_attack(input, perturbation_strength, basis)?,
            QuantumAttackType::CircuitManipulation {
                gate_error_rate,
                coherence_time,
            } => self.circuit_manipulation_attack(input, gate_error_rate, coherence_time)?,
            QuantumAttackType::UniversalPerturbation {
                perturbation_budget,
                // The success-rate threshold applies to dataset-wide
                // optimization, which this per-input attack does not perform
                success_rate_threshold: _,
            } => self.universal_perturbation_attack(input, perturbation_budget)?,
        };

        let adversarial_prediction = self.model.forward(&adversarial_input)?;

        // L2 norm of the applied perturbation
        let perturbation = &adversarial_input - input;
        let perturbation_norm = perturbation.mapv(|x| x * x).sum().sqrt();

        // Argmax of the clean and adversarial predictions
        let original_class = original_prediction
            .iter()
            .enumerate()
            .max_by(|a, b| a.1.partial_cmp(b.1).unwrap())
            .map(|(i, _)| i)
            .unwrap_or(0);

        let adversarial_class = adversarial_prediction
            .iter()
            .enumerate()
            .max_by(|a, b| a.1.partial_cmp(b.1).unwrap())
            .map(|(i, _)| i)
            .unwrap_or(0);

        // The attack succeeds if it flips the predicted class
        let attack_success = original_class != adversarial_class;

        Ok(QuantumAdversarialExample {
            original_input: input.clone(),
            adversarial_input,
            original_prediction,
            adversarial_prediction,
            true_label,
            perturbation_norm,
            attack_success,
            metadata: HashMap::new(),
        })
    }

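    /// Fast Gradient Sign Method: `x_adv = clip(x + epsilon * sign(grad_x L), 0, 1)`,
    /// with the input gradient estimated by finite differences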
    fn fgsm_attack(
        &self,
        input: &Array1<f64>,
        true_label: usize,
        epsilon: f64,
    ) -> Result<Array1<f64>> {
        let gradient = self.compute_input_gradient(input, true_label)?;

        // Single step of size epsilon in the direction of the gradient sign
        let perturbation = gradient.mapv(|g| epsilon * g.signum());
        let adversarial_input = input + &perturbation;

        // Clip to the valid [0, 1] encoding range
        Ok(adversarial_input.mapv(|x| x.clamp(0.0, 1.0)))
    }

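    /// Projected Gradient Descent: repeated gradient-sign steps of size
    /// `alpha`, each followed by projection onto the L2 ball of radius
    /// `epsilon` around the original input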
    fn pgd_attack(
        &self,
        input: &Array1<f64>,
        true_label: usize,
        epsilon: f64,
        alpha: f64,
        num_steps: usize,
    ) -> Result<Array1<f64>> {
        let mut adversarial_input = input.clone();

        for _ in 0..num_steps {
            let gradient = self.compute_input_gradient(&adversarial_input, true_label)?;

            // Gradient-sign step of size alpha
            let perturbation = gradient.mapv(|g| alpha * g.signum());
            adversarial_input = &adversarial_input + &perturbation;

            // Project back onto the L2 ball of radius epsilon around the input
            let total_perturbation = &adversarial_input - input;
            let perturbation_norm = total_perturbation.mapv(|x| x * x).sum().sqrt();

            if perturbation_norm > epsilon {
                let scaling = epsilon / perturbation_norm;
                adversarial_input = input + &(total_perturbation * scaling);
            }

            // Keep features in the valid [0, 1] encoding range
            adversarial_input = adversarial_input.mapv(|x| x.clamp(0.0, 1.0));
        }

        Ok(adversarial_input)
    }

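    /// Parameter-shift-style attack: random shifts of magnitude up to
    /// `shift_magnitude * pi / 2`, optionally restricted to `target_parameters`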
    fn parameter_shift_attack(
        &self,
        input: &Array1<f64>,
        shift_magnitude: f64,
        target_parameters: Option<Vec<usize>>,
    ) -> Result<Array1<f64>> {
        let mut adversarial_input = input.clone();

        for i in 0..adversarial_input.len() {
            // Restrict the attack to the targeted features, if any
            if let Some(ref targets) = target_parameters {
                if !targets.contains(&i) {
                    continue;
                }
            }

            // Random shift of magnitude up to shift_magnitude * pi / 2
            let shift = shift_magnitude * (PI / 2.0);
            adversarial_input[i] += shift * (2.0 * rand::random::<f64>() - 1.0);
        }

        Ok(adversarial_input.mapv(|x| x.clamp(0.0, 1.0)))
    }

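    /// Perturbs the encoded state in the requested basis: rotation-angle noise
    /// for `"pauli_x"`, additive amplitude noise for `"pauli_y"`, and wrapped
    /// phase noise for `"pauli_z"` (or any unrecognized basis)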
    fn state_perturbation_attack(
        &self,
        input: &Array1<f64>,
        perturbation_strength: f64,
        basis: &str,
    ) -> Result<Array1<f64>> {
        let mut adversarial_input = input.clone();

        match basis {
            "pauli_x" => {
                // Perturb the rotation angles that encode each feature
                for i in 0..adversarial_input.len() {
                    let angle = adversarial_input[i] * PI;
                    let perturbed_angle =
                        angle + perturbation_strength * (2.0 * rand::random::<f64>() - 1.0);
                    adversarial_input[i] = perturbed_angle / PI;
                }
            }
            "pauli_y" => {
                // Additive amplitude noise
                for i in 0..adversarial_input.len() {
                    adversarial_input[i] +=
                        perturbation_strength * (2.0 * rand::random::<f64>() - 1.0);
                }
            }
            // "pauli_z" and any unrecognized basis: wrapped phase noise
            _ => {
                for i in 0..adversarial_input.len() {
                    let phase_shift = perturbation_strength * (2.0 * rand::random::<f64>() - 1.0);
                    adversarial_input[i] =
                        (adversarial_input[i] + phase_shift / (2.0 * PI)).fract();
                }
            }
        }

        Ok(adversarial_input.mapv(|x| x.clamp(0.0, 1.0)))
    }

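    /// Simulates hardware-level tampering: exponential amplitude damping set
    /// by `coherence_time`, plus random offsets injected at `gate_error_rate`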
    fn circuit_manipulation_attack(
        &self,
        input: &Array1<f64>,
        gate_error_rate: f64,
        coherence_time: f64,
    ) -> Result<Array1<f64>> {
        let mut adversarial_input = input.clone();

        for i in 0..adversarial_input.len() {
            // Amplitude damping: exp(-t / T1) with unit evolution time
            let t1_factor = (-1.0 / coherence_time).exp();
            adversarial_input[i] *= t1_factor;

            // Random gate error with the given probability
            if rand::random::<f64>() < gate_error_rate {
                adversarial_input[i] += 0.1 * (2.0 * rand::random::<f64>() - 1.0);
            }
        }

        Ok(adversarial_input.mapv(|x| x.clamp(0.0, 1.0)))
    }

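    /// Applies a fixed, input-independent perturbation under the given budget.
    /// A true universal perturbation would be optimized over a dataset; the
    /// deterministic sinusoid here is a simple stand-in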
    fn universal_perturbation_attack(
        &self,
        input: &Array1<f64>,
        perturbation_budget: f64,
    ) -> Result<Array1<f64>> {
        let mut adversarial_input = input.clone();

        // One period of a sine across the feature vector, scaled to the budget
        for i in 0..adversarial_input.len() {
            let universal_component =
                perturbation_budget * (2.0 * PI * i as f64 / adversarial_input.len() as f64).sin();
            adversarial_input[i] += universal_component;
        }

        Ok(adversarial_input.mapv(|x| x.clamp(0.0, 1.0)))
    }

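    /// Estimates the input gradient of the loss by forward finite differences,
    /// `dL/dx_i ≈ (L(x + h * e_i) - L(x)) / h`, at the cost of one extra
    /// forward pass per feature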
    fn compute_input_gradient(
        &self,
        input: &Array1<f64>,
        true_label: usize,
    ) -> Result<Array1<f64>> {
        let mut gradient = Array1::zeros(input.len());

        // Forward finite differences with step h
        let h = 1e-5;
        let original_output = self.model.forward(input)?;
        let original_loss = self.compute_loss(&original_output, true_label);

        for i in 0..input.len() {
            let mut perturbed_input = input.clone();
            perturbed_input[i] += h;

            let perturbed_output = self.model.forward(&perturbed_input)?;
            let perturbed_loss = self.compute_loss(&perturbed_output, true_label);

            gradient[i] = (perturbed_loss - original_loss) / h;
        }

        Ok(gradient)
    }

    /// Cross-entropy loss for the true class, clamped away from log(0)
    fn compute_loss(&self, output: &Array1<f64>, true_label: usize) -> f64 {
        let predicted_prob = output[true_label].max(1e-10);
        -predicted_prob.ln()
    }

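    /// Replaces a fraction (`adversarial_ratio`) of the batch with adversarial
    /// counterparts when the defense is adversarial training; labels stay
    /// unchanged because the attacks are untargeted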
    fn generate_adversarial_batch(
        &self,
        data: &Array2<f64>,
        labels: &Array1<usize>,
    ) -> Result<(Array2<f64>, Array1<usize>)> {
        match &self.defense_strategy {
            QuantumDefenseStrategy::AdversarialTraining {
                attack_types,
                adversarial_ratio,
            } => {
                let num_adversarial = (data.nrows() as f64 * adversarial_ratio) as usize;
                let mut combined_data = data.clone();
                let combined_labels = labels.clone();

                for i in 0..num_adversarial {
                    let idx = i % data.nrows();
                    let input = data.row(idx).to_owned();
                    let label = labels[idx];

                    // Pick a random attack type from the configured set
                    let attack_type = attack_types[fastrand::usize(0..attack_types.len())].clone();
                    let adversarial_example =
                        self.generate_single_adversarial_example(&input, label, attack_type)?;

                    // Overwrite the clean row with its adversarial counterpart
                    combined_data
                        .row_mut(idx)
                        .assign(&adversarial_example.adversarial_input);
                }

                Ok((combined_data, combined_labels))
            }
            _ => Ok((data.clone(), labels.clone())),
        }
    }

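    /// Computes the average loss over one batch. Parameter updates through the
    /// optimizer are not wired up yet; see the note in the body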
    fn train_batch(
        &mut self,
        data: &Array2<f64>,
        labels: &Array1<usize>,
        _optimizer: &mut dyn Optimizer,
    ) -> Result<f64> {
        let mut total_loss = 0.0;

        for (input, &label) in data.outer_iter().zip(labels.iter()) {
            let output = self.model.forward(&input.to_owned())?;
            let loss = self.compute_loss(&output, label);
            total_loss += loss;

            // TODO: backpropagate the loss and apply a parameter update via
            // the optimizer; only the forward pass is implemented here.
        }

        Ok(total_loss / data.nrows() as f64)
    }

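    /// Populates the model ensemble when the defense strategy calls for one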
    fn initialize_ensemble(&mut self) -> Result<()> {
        if let QuantumDefenseStrategy::EnsembleDefense { num_models, .. } = &self.defense_strategy {
            for _ in 0..*num_models {
                // The clones are identical copies; real diversity would
                // require re-initializing or perturbing each model's parameters
                let model = self.model.clone();
                self.ensemble_models.push(model);
            }
        }
        Ok(())
    }

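    /// Measures clean and robust accuracy on the validation set against a
    /// fixed battery of FGSM and PGD attacks and updates `robustness_metrics`.
    /// An input counts as robust only if every test attack fails on it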
    fn evaluate_robustness(
        &mut self,
        val_data: &Array2<f64>,
        val_labels: &Array1<usize>,
    ) -> Result<()> {
        let mut clean_correct = 0;
        let mut robust_correct = 0;
        let mut total_perturbation = 0.0;
        let mut successful_attacks = 0;

        // Fixed battery of attacks used for evaluation
        let test_attacks = vec![
            QuantumAttackType::FGSM { epsilon: 0.1 },
            QuantumAttackType::PGD {
                epsilon: 0.1,
                alpha: 0.01,
                num_steps: 10,
            },
        ];

        for (input, &label) in val_data.outer_iter().zip(val_labels.iter()) {
            let input_owned = input.to_owned();

            // Clean accuracy
            let clean_output = self.model.forward(&input_owned)?;
            let clean_pred = clean_output
                .iter()
                .enumerate()
                .max_by(|a, b| a.1.partial_cmp(b.1).unwrap())
                .map(|(i, _)| i)
                .unwrap_or(0);

            if clean_pred == label {
                clean_correct += 1;
            }

            // An input is robust only if every test attack fails
            let mut robust_for_this_input = true;
            for attack_type in &test_attacks {
                let adversarial_example = self.generate_single_adversarial_example(
                    &input_owned,
                    label,
                    attack_type.clone(),
                )?;

                total_perturbation += adversarial_example.perturbation_norm;

                if adversarial_example.attack_success {
                    successful_attacks += 1;
                    robust_for_this_input = false;
                }
            }

            if robust_for_this_input {
                robust_correct += 1;
            }
        }

        let num_samples = val_data.nrows();
        let num_attack_tests = num_samples * test_attacks.len();

        self.robustness_metrics.clean_accuracy = clean_correct as f64 / num_samples as f64;
        self.robustness_metrics.robust_accuracy = robust_correct as f64 / num_samples as f64;
        self.robustness_metrics.avg_perturbation_norm =
            total_perturbation / num_attack_tests as f64;
        self.robustness_metrics.attack_success_rate =
            successful_attacks as f64 / num_attack_tests as f64;

        Ok(())
    }

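    /// Applies inference-time defenses (noise injection, feature squeezing, or
    /// input randomization) to a single input; strategies without an
    /// input-side component return the input unchanged.
    ///
    /// A minimal usage sketch (`trainer` built as in the unit tests below):
    ///
    /// ```ignore
    /// let defended = trainer.apply_defense(&input)?;
    /// assert_eq!(defended.len(), input.len());
    /// ```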
    pub fn apply_defense(&self, input: &Array1<f64>) -> Result<Array1<f64>> {
        match &self.defense_strategy {
            QuantumDefenseStrategy::InputPreprocessing {
                noise_addition,
                feature_squeezing,
            } => {
                let mut defended_input = input.clone();

                // Random noise injection masks small adversarial perturbations
                for i in 0..defended_input.len() {
                    defended_input[i] += noise_addition * (2.0 * rand::random::<f64>() - 1.0);
                }

                // Feature squeezing: quantize each feature to 8 levels (3 bits)
                if *feature_squeezing {
                    defended_input = defended_input.mapv(|x| (x * 8.0).round() / 8.0);
                }

                Ok(defended_input.mapv(|x| x.clamp(0.0, 1.0)))
            }
            QuantumDefenseStrategy::RandomizedCircuit {
                randomization_strength,
                ..
            } => {
                let mut defended_input = input.clone();

                // Random input shifts stand in for circuit-level randomization
                for i in 0..defended_input.len() {
                    let random_shift = randomization_strength * (2.0 * rand::random::<f64>() - 1.0);
                    defended_input[i] += random_shift;
                }

                Ok(defended_input.mapv(|x| x.clamp(0.0, 1.0)))
            }
            _ => Ok(input.clone()),
        }
    }

    /// Returns the most recent robustness metrics
    pub fn get_robustness_metrics(&self) -> &RobustnessMetrics {
        &self.robustness_metrics
    }

    /// Returns the history of generated adversarial examples
    pub fn get_attack_history(&self) -> &[QuantumAdversarialExample] {
        &self.attack_history
    }

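    /// Randomized-smoothing-style analysis: classifies `num_samples` noisy
    /// copies of each input and counts the input as certified when one class
    /// wins at least 60% of the votes. Note that the noise is uniform in
    /// `[0, smoothing_variance)` rather than the Gaussian of standard
    /// randomized smoothing, so the result is a heuristic, not a formal
    /// certificate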
    pub fn certified_defense_analysis(
        &self,
        data: &Array2<f64>,
        smoothing_variance: f64,
        num_samples: usize,
    ) -> Result<f64> {
        let mut certified_correct = 0;

        for input in data.outer_iter() {
            let input_owned = input.to_owned();

            // Classify noisy copies of the input
            let mut predictions = Vec::new();
            for _ in 0..num_samples {
                let mut noisy_input = input_owned.clone();
                for i in 0..noisy_input.len() {
                    let noise = fastrand::f64() * smoothing_variance;
                    noisy_input[i] += noise;
                }

                let output = self.model.forward(&noisy_input)?;
                let pred = output
                    .iter()
                    .enumerate()
                    .max_by(|a, b| a.1.partial_cmp(b.1).unwrap())
                    .map(|(i, _)| i)
                    .unwrap_or(0);

                predictions.push(pred);
            }

            // Majority vote over the noisy predictions (assumes at most 10 classes)
            let mut counts = vec![0; 10];
            for &pred in &predictions {
                if pred < counts.len() {
                    counts[pred] += 1;
                }
            }

            let max_count = counts.iter().max().unwrap_or(&0);
            let certification_threshold = (num_samples as f64 * 0.6) as usize;

            if *max_count >= certification_threshold {
                certified_correct += 1;
            }
        }

        Ok(certified_correct as f64 / data.nrows() as f64)
    }
}

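/// Returns a sensible default configuration for adversarial training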
pub fn create_default_adversarial_config() -> AdversarialTrainingConfig {
    AdversarialTrainingConfig {
        epochs: 100,
        batch_size: 32,
        learning_rate: 0.001,
        adversarial_frequency: 2,
        max_perturbation: 0.1,
        eval_interval: 10,
        early_stopping: Some(EarlyStoppingCriteria {
            min_clean_accuracy: 0.7,
            min_robust_accuracy: 0.5,
            patience: 20,
        }),
    }
}

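/// Builds an adversarial-training defense that mixes FGSM, PGD, and
/// parameter-shift attacks into half of each training batch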
pub fn create_comprehensive_defense() -> QuantumDefenseStrategy {
    QuantumDefenseStrategy::AdversarialTraining {
        attack_types: vec![
            QuantumAttackType::FGSM { epsilon: 0.1 },
            QuantumAttackType::PGD {
                epsilon: 0.1,
                alpha: 0.01,
                num_steps: 7,
            },
            QuantumAttackType::ParameterShift {
                shift_magnitude: 0.05,
                target_parameters: None,
            },
        ],
        adversarial_ratio: 0.5,
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::qnn::QNNLayerType;

    #[test]
    fn test_adversarial_example_creation() {
        let original_input = Array1::from_vec(vec![0.5, 0.3, 0.8, 0.2]);
        let adversarial_input = Array1::from_vec(vec![0.6, 0.4, 0.7, 0.3]);
        let original_prediction = Array1::from_vec(vec![0.8, 0.2]);
        let adversarial_prediction = Array1::from_vec(vec![0.3, 0.7]);

        let perturbation = &adversarial_input - &original_input;
        let perturbation_norm = perturbation.mapv(|x| x * x).sum().sqrt();

        let example = QuantumAdversarialExample {
            original_input,
            adversarial_input,
            original_prediction,
            adversarial_prediction,
            true_label: 0,
            perturbation_norm,
            attack_success: true,
            metadata: HashMap::new(),
        };

        assert!(example.attack_success);
        assert!(example.perturbation_norm > 0.0);
    }

    #[test]
    fn test_fgsm_attack() {
        let layers = vec![
            QNNLayerType::EncodingLayer { num_features: 4 },
            QNNLayerType::VariationalLayer { num_params: 8 },
            QNNLayerType::MeasurementLayer {
                measurement_basis: "computational".to_string(),
            },
        ];

        let model = QuantumNeuralNetwork::new(layers, 4, 4, 2).unwrap();
        let defense = create_comprehensive_defense();
        let config = create_default_adversarial_config();

        let trainer = QuantumAdversarialTrainer::new(model, defense, config);

        let input = Array1::from_vec(vec![0.5, 0.3, 0.8, 0.2]);
        let adversarial_input = trainer.fgsm_attack(&input, 0, 0.1).unwrap();

        assert_eq!(adversarial_input.len(), input.len());

        // The attack should move the input by a nonzero amount
        let perturbation = &adversarial_input - &input;
        let perturbation_norm = perturbation.mapv(|x| x * x).sum().sqrt();
        assert!(perturbation_norm > 0.0);

        // All features must stay within the valid [0, 1] range
        for &val in adversarial_input.iter() {
            assert!(val >= 0.0 && val <= 1.0);
        }
    }

    #[test]
    fn test_defense_application() {
        let layers = vec![
            QNNLayerType::EncodingLayer { num_features: 4 },
            QNNLayerType::VariationalLayer { num_params: 8 },
        ];

        let model = QuantumNeuralNetwork::new(layers, 4, 4, 2).unwrap();

        let defense = QuantumDefenseStrategy::InputPreprocessing {
            noise_addition: 0.05,
            feature_squeezing: true,
        };

        let config = create_default_adversarial_config();
        let trainer = QuantumAdversarialTrainer::new(model, defense, config);

        let input = Array1::from_vec(vec![0.51, 0.32, 0.83, 0.24]);
        let defended_input = trainer.apply_defense(&input).unwrap();

        assert_eq!(defended_input.len(), input.len());

        // Noise injection and squeezing should change the input
        let difference = (&defended_input - &input).mapv(|x| x.abs()).sum();
        assert!(difference > 0.0);
    }

    #[test]
    fn test_robustness_metrics() {
        let metrics = RobustnessMetrics {
            clean_accuracy: 0.85,
            robust_accuracy: 0.65,
            avg_perturbation_norm: 0.12,
            attack_success_rate: 0.35,
            certified_accuracy: Some(0.55),
            per_attack_metrics: HashMap::new(),
        };

        assert_eq!(metrics.clean_accuracy, 0.85);
        assert_eq!(metrics.robust_accuracy, 0.65);
        assert!(metrics.robust_accuracy < metrics.clean_accuracy);
        assert!(metrics.attack_success_rate < 0.5);
    }
}