use crate::autodiff::optimizers::Optimizer;
use crate::error::{MLError, Result};
use crate::optimization::OptimizationMethod;
use crate::qnn::{QNNLayerType, QuantumNeuralNetwork};
use quantrs2_circuit::builder::{Circuit, Simulator};
use quantrs2_core::gate::{
    single::{RotationX, RotationY, RotationZ},
    GateOp,
};
use quantrs2_sim::statevector::StateVectorSimulator;
use scirs2_core::ndarray::{s, Array1, Array2, Array3, Axis};
use scirs2_core::random::prelude::*;
use std::collections::HashMap;
use std::f64::consts::PI;

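/// Types of adversarial attacks on quantum neural networks.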
#[derive(Debug, Clone)]
pub enum QuantumAttackType {
    /// Fast Gradient Sign Method: single step along the sign of the loss gradient
    FGSM { epsilon: f64 },

    /// Projected Gradient Descent: iterative FGSM with projection onto an epsilon-ball
    PGD {
        epsilon: f64,
        alpha: f64,
        num_steps: usize,
    },

    /// Randomly shifts encoded parameters, optionally restricted to target indices
    ParameterShift {
        shift_magnitude: f64,
        target_parameters: Option<Vec<usize>>,
    },

    /// Perturbs the encoded quantum state in a given Pauli basis
    StatePerturbation {
        perturbation_strength: f64,
        basis: String,
    },

    /// Simulates hardware-level attacks via gate errors and decoherence
    CircuitManipulation {
        gate_error_rate: f64,
        coherence_time: f64,
    },

    /// Input-agnostic perturbation applied under a fixed budget
    UniversalPerturbation {
        perturbation_budget: f64,
        success_rate_threshold: f64,
    },
}

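/// Defense strategies for hardening quantum models against adversarial attacks.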
#[derive(Debug, Clone)]
pub enum QuantumDefenseStrategy {
    /// Train on a mix of clean and adversarially perturbed samples
    AdversarialTraining {
        attack_types: Vec<QuantumAttackType>,
        adversarial_ratio: f64,
    },

    /// Apply quantum error correction to suppress induced errors
    QuantumErrorCorrection {
        code_type: String,
        correction_threshold: f64,
    },

    /// Sanitize inputs via noise addition and feature squeezing
    InputPreprocessing {
        noise_addition: f64,
        feature_squeezing: bool,
    },

    /// Aggregate predictions from a set of diverse models
    EnsembleDefense {
        num_models: usize,
        diversity_metric: String,
    },

    /// Certified robustness via randomized smoothing
    CertifiedDefense {
        smoothing_variance: f64,
        confidence_level: f64,
    },

    /// Randomize circuit structure to obscure gradients
    RandomizedCircuit {
        randomization_strength: f64,
        num_random_layers: usize,
    },
}

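/// A single adversarial example together with the model's reaction to it.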
#[derive(Debug, Clone)]
pub struct QuantumAdversarialExample {
    /// Unperturbed input
    pub original_input: Array1<f64>,

    /// Perturbed (adversarial) input
    pub adversarial_input: Array1<f64>,

    /// Model prediction on the original input
    pub original_prediction: Array1<f64>,

    /// Model prediction on the adversarial input
    pub adversarial_prediction: Array1<f64>,

    /// Ground-truth class label
    pub true_label: usize,

    /// L2 norm of the perturbation
    pub perturbation_norm: f64,

    /// Whether the attack changed the predicted class
    pub attack_success: bool,

    /// Additional attack-specific metadata
    pub metadata: HashMap<String, f64>,
}

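/// Trainer that hardens a quantum neural network against adversarial attacks.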
pub struct QuantumAdversarialTrainer {
    /// Model under training
    model: QuantumNeuralNetwork,

    /// Active defense strategy
    defense_strategy: QuantumDefenseStrategy,

    /// Training configuration
    config: AdversarialTrainingConfig,

    /// Adversarial examples generated so far
    attack_history: Vec<QuantumAdversarialExample>,

    /// Latest robustness evaluation results
    robustness_metrics: RobustnessMetrics,

    /// Models used by the ensemble defense (empty otherwise)
    ensemble_models: Vec<QuantumNeuralNetwork>,
}

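/// Configuration for adversarial training.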
#[derive(Debug, Clone)]
pub struct AdversarialTrainingConfig {
    /// Number of training epochs
    pub epochs: usize,

    /// Samples per batch
    pub batch_size: usize,

    /// Learning rate
    pub learning_rate: f64,

    /// Generate adversarial batches every this many epochs
    pub adversarial_frequency: usize,

    /// Maximum allowed perturbation magnitude
    pub max_perturbation: f64,

    /// Evaluate robustness every this many epochs
    pub eval_interval: usize,

    /// Optional early-stopping criteria
    pub early_stopping: Option<EarlyStoppingCriteria>,
}

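/// Criteria for stopping training early.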
#[derive(Debug, Clone)]
pub struct EarlyStoppingCriteria {
    /// Minimum acceptable accuracy on clean data
    pub min_clean_accuracy: f64,

    /// Minimum acceptable accuracy under attack
    pub min_robust_accuracy: f64,

    /// Evaluations without improvement before stopping
    pub patience: usize,
}

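/// Aggregate robustness statistics collected during evaluation.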
#[derive(Debug, Clone)]
pub struct RobustnessMetrics {
    /// Accuracy on clean validation data
    pub clean_accuracy: f64,

    /// Accuracy on inputs that survive all test attacks
    pub robust_accuracy: f64,

    /// Average L2 norm of adversarial perturbations
    pub avg_perturbation_norm: f64,

    /// Fraction of attacks that changed the prediction
    pub attack_success_rate: f64,

    /// Certified accuracy from randomized smoothing, if computed
    pub certified_accuracy: Option<f64>,

    /// Per-attack-type breakdown
    pub per_attack_metrics: HashMap<String, AttackMetrics>,
}

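/// Metrics for a single attack type.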
#[derive(Debug, Clone)]
pub struct AttackMetrics {
    /// Fraction of successful attacks
    pub success_rate: f64,

    /// Average perturbation norm
    pub avg_perturbation: f64,

    /// Average drop in prediction confidence
    pub avg_confidence_drop: f64,
}

impl QuantumAdversarialTrainer {
    /// Creates a new adversarial trainer with empty history and zeroed metrics.
    pub fn new(
        model: QuantumNeuralNetwork,
        defense_strategy: QuantumDefenseStrategy,
        config: AdversarialTrainingConfig,
    ) -> Self {
        Self {
            model,
            defense_strategy,
            config,
            attack_history: Vec::new(),
            robustness_metrics: RobustnessMetrics {
                clean_accuracy: 0.0,
                robust_accuracy: 0.0,
                avg_perturbation_norm: 0.0,
                attack_success_rate: 0.0,
                certified_accuracy: None,
                per_attack_metrics: HashMap::new(),
            },
            ensemble_models: Vec::new(),
        }
    }

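    /// Trains the model, interleaving adversarial batches according to the
    /// configured frequency and evaluating robustness at regular intervals.
    /// Returns the per-epoch training losses.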
    pub fn train(
        &mut self,
        train_data: &Array2<f64>,
        train_labels: &Array1<usize>,
        val_data: &Array2<f64>,
        val_labels: &Array1<usize>,
        optimizer: &mut dyn Optimizer,
    ) -> Result<Vec<f64>> {
        println!("Starting quantum adversarial training...");

        let mut losses = Vec::new();
        let mut patience_counter = 0;
        let mut best_robust_accuracy = 0.0;

        // Prepare ensemble models if the ensemble defense is active
        self.initialize_ensemble()?;

        for epoch in 0..self.config.epochs {
            let mut epoch_loss = 0.0;
            let num_batches =
                (train_data.nrows() + self.config.batch_size - 1) / self.config.batch_size;

            for batch_idx in 0..num_batches {
                let batch_start = batch_idx * self.config.batch_size;
                let batch_end = (batch_start + self.config.batch_size).min(train_data.nrows());

                let batch_data = train_data.slice(s![batch_start..batch_end, ..]).to_owned();
                let batch_labels = train_labels.slice(s![batch_start..batch_end]).to_owned();

                // Mix in adversarial examples on scheduled epochs
                let (final_data, final_labels) = if epoch % self.config.adversarial_frequency == 0 {
                    self.generate_adversarial_batch(&batch_data, &batch_labels)?
                } else {
                    (batch_data, batch_labels)
                };

                let batch_loss = self.train_batch(&final_data, &final_labels, optimizer)?;
                epoch_loss += batch_loss;
            }

            epoch_loss /= num_batches as f64;
            losses.push(epoch_loss);

            if epoch % self.config.eval_interval == 0 {
                self.evaluate_robustness(val_data, val_labels)?;

                println!(
                    "Epoch {}: Loss = {:.4}, Clean Acc = {:.3}, Robust Acc = {:.3}",
                    epoch,
                    epoch_loss,
                    self.robustness_metrics.clean_accuracy,
                    self.robustness_metrics.robust_accuracy
                );

                if let Some(ref criteria) = self.config.early_stopping {
                    if self.robustness_metrics.robust_accuracy > best_robust_accuracy {
                        best_robust_accuracy = self.robustness_metrics.robust_accuracy;
                        patience_counter = 0;
                    } else {
                        patience_counter += 1;
                    }

                    if patience_counter >= criteria.patience {
                        println!("Early stopping triggered at epoch {}", epoch);
                        break;
                    }

                    if self.robustness_metrics.clean_accuracy < criteria.min_clean_accuracy
                        || self.robustness_metrics.robust_accuracy < criteria.min_robust_accuracy
                    {
                        println!("Minimum performance criteria not met, stopping training");
                        break;
                    }
                }
            }
        }

        // Final robustness evaluation
        self.evaluate_robustness(val_data, val_labels)?;

        Ok(losses)
    }

    /// Generates adversarial examples for a dataset using the given attack type.
    pub fn generate_adversarial_examples(
        &self,
        data: &Array2<f64>,
        labels: &Array1<usize>,
        attack_type: QuantumAttackType,
    ) -> Result<Vec<QuantumAdversarialExample>> {
        let mut adversarial_examples = Vec::new();

        for (input, &label) in data.outer_iter().zip(labels.iter()) {
            let adversarial_example = self.generate_single_adversarial_example(
                &input.to_owned(),
                label,
                attack_type.clone(),
            )?;

            adversarial_examples.push(adversarial_example);
        }

        Ok(adversarial_examples)
    }

    /// Generates a single adversarial example by dispatching to the chosen attack.
    fn generate_single_adversarial_example(
        &self,
        input: &Array1<f64>,
        true_label: usize,
        attack_type: QuantumAttackType,
    ) -> Result<QuantumAdversarialExample> {
        let original_prediction = self.model.forward(input)?;

        let adversarial_input = match attack_type {
            QuantumAttackType::FGSM { epsilon } => self.fgsm_attack(input, true_label, epsilon)?,
            QuantumAttackType::PGD {
                epsilon,
                alpha,
                num_steps,
            } => self.pgd_attack(input, true_label, epsilon, alpha, num_steps)?,
            QuantumAttackType::ParameterShift {
                shift_magnitude,
                target_parameters,
            } => self.parameter_shift_attack(input, shift_magnitude, target_parameters)?,
            QuantumAttackType::StatePerturbation {
                perturbation_strength,
                ref basis,
            } => self.state_perturbation_attack(input, perturbation_strength, basis)?,
            QuantumAttackType::CircuitManipulation {
                gate_error_rate,
                coherence_time,
            } => self.circuit_manipulation_attack(input, gate_error_rate, coherence_time)?,
            QuantumAttackType::UniversalPerturbation {
                perturbation_budget,
                success_rate_threshold: _,
            } => self.universal_perturbation_attack(input, perturbation_budget)?,
        };

        let adversarial_prediction = self.model.forward(&adversarial_input)?;

        // L2 norm of the perturbation
        let perturbation = &adversarial_input - input;
        let perturbation_norm = perturbation.mapv(|x| x.powi(2)).sum().sqrt();

        // The attack succeeds if it changes the predicted class
        let original_class = original_prediction
            .iter()
            .enumerate()
            .max_by(|a, b| a.1.partial_cmp(b.1).unwrap())
            .map(|(i, _)| i)
            .unwrap_or(0);

        let adversarial_class = adversarial_prediction
            .iter()
            .enumerate()
            .max_by(|a, b| a.1.partial_cmp(b.1).unwrap())
            .map(|(i, _)| i)
            .unwrap_or(0);

        let attack_success = original_class != adversarial_class;

        Ok(QuantumAdversarialExample {
            original_input: input.clone(),
            adversarial_input,
            original_prediction,
            adversarial_prediction,
            true_label,
            perturbation_norm,
            attack_success,
            metadata: HashMap::new(),
        })
    }

    /// Fast Gradient Sign Method: one step of magnitude `epsilon` along the
    /// sign of the loss gradient with respect to the input.
    fn fgsm_attack(
        &self,
        input: &Array1<f64>,
        true_label: usize,
        epsilon: f64,
    ) -> Result<Array1<f64>> {
        let gradient = self.compute_input_gradient(input, true_label)?;

        let perturbation = gradient.mapv(|g| epsilon * g.signum());
        let adversarial_input = input + &perturbation;

        // Clamp to the valid [0, 1] input range
        Ok(adversarial_input.mapv(|x| x.clamp(0.0, 1.0)))
    }

    /// Projected Gradient Descent: iterated sign-gradient steps of size `alpha`,
    /// projected back onto the L2 ball of radius `epsilon` around the input.
    fn pgd_attack(
        &self,
        input: &Array1<f64>,
        true_label: usize,
        epsilon: f64,
        alpha: f64,
        num_steps: usize,
    ) -> Result<Array1<f64>> {
        let mut adversarial_input = input.clone();

        for _ in 0..num_steps {
            let gradient = self.compute_input_gradient(&adversarial_input, true_label)?;

            let perturbation = gradient.mapv(|g| alpha * g.signum());
            adversarial_input = &adversarial_input + &perturbation;

            // Project back onto the epsilon-ball if the step left it
            let total_perturbation = &adversarial_input - input;
            let perturbation_norm = total_perturbation.mapv(|x| x.powi(2)).sum().sqrt();

            if perturbation_norm > epsilon {
                let scaling = epsilon / perturbation_norm;
                adversarial_input = input + &(total_perturbation * scaling);
            }

            // Keep values in the valid [0, 1] input range
            adversarial_input = adversarial_input.mapv(|x| x.clamp(0.0, 1.0));
        }

        Ok(adversarial_input)
    }

    /// Randomly shifts encoded parameters by up to `shift_magnitude * PI/2`,
    /// optionally restricted to the given target indices.
    fn parameter_shift_attack(
        &self,
        input: &Array1<f64>,
        shift_magnitude: f64,
        target_parameters: Option<Vec<usize>>,
    ) -> Result<Array1<f64>> {
        let mut adversarial_input = input.clone();

        for i in 0..adversarial_input.len() {
            // Skip indices outside the target set, if one was given
            if let Some(ref targets) = target_parameters {
                if !targets.contains(&i) {
                    continue;
                }
            }

            let shift = shift_magnitude * (PI / 2.0);
            adversarial_input[i] += shift * (2.0 * thread_rng().gen::<f64>() - 1.0);
        }

        Ok(adversarial_input.mapv(|x| x.clamp(0.0, 1.0)))
    }

    /// Perturbs the encoded quantum state in the given Pauli basis.
    fn state_perturbation_attack(
        &self,
        input: &Array1<f64>,
        perturbation_strength: f64,
        basis: &str,
    ) -> Result<Array1<f64>> {
        let mut adversarial_input = input.clone();

        match basis {
            "pauli_x" => {
                // Perturb the rotation angle encoded by each feature
                for i in 0..adversarial_input.len() {
                    let angle = adversarial_input[i] * PI;
                    let perturbed_angle =
                        angle + perturbation_strength * (2.0 * thread_rng().gen::<f64>() - 1.0);
                    adversarial_input[i] = perturbed_angle / PI;
                }
            }
            "pauli_y" => {
                for i in 0..adversarial_input.len() {
                    adversarial_input[i] +=
                        perturbation_strength * (2.0 * thread_rng().gen::<f64>() - 1.0);
                }
            }
            // "pauli_z" and any unrecognized basis: perturb as a phase shift
            _ => {
                for i in 0..adversarial_input.len() {
                    let phase_shift =
                        perturbation_strength * (2.0 * thread_rng().gen::<f64>() - 1.0);
                    adversarial_input[i] =
                        (adversarial_input[i] + phase_shift / (2.0 * PI)).fract();
                }
            }
        }

        Ok(adversarial_input.mapv(|x| x.clamp(0.0, 1.0)))
    }

    /// Simulates hardware-level attacks: amplitude damping from finite
    /// coherence time plus random gate errors.
    fn circuit_manipulation_attack(
        &self,
        input: &Array1<f64>,
        gate_error_rate: f64,
        coherence_time: f64,
    ) -> Result<Array1<f64>> {
        let mut adversarial_input = input.clone();

        // Amplitude damping factor for one time unit of evolution
        let t1_factor = (-1.0 / coherence_time).exp();

        for i in 0..adversarial_input.len() {
            adversarial_input[i] *= t1_factor;

            // Inject a random gate error with the given probability
            if thread_rng().gen::<f64>() < gate_error_rate {
                adversarial_input[i] += 0.1 * (2.0 * thread_rng().gen::<f64>() - 1.0);
            }
        }

        Ok(adversarial_input.mapv(|x| x.clamp(0.0, 1.0)))
    }

    /// Applies a fixed, input-agnostic sinusoidal perturbation within the budget.
    fn universal_perturbation_attack(
        &self,
        input: &Array1<f64>,
        perturbation_budget: f64,
    ) -> Result<Array1<f64>> {
        let mut adversarial_input = input.clone();

        for i in 0..adversarial_input.len() {
            let universal_component =
                perturbation_budget * (2.0 * PI * i as f64 / adversarial_input.len() as f64).sin();
            adversarial_input[i] += universal_component;
        }

        Ok(adversarial_input.mapv(|x| x.clamp(0.0, 1.0)))
    }

    /// Approximates the gradient of the loss with respect to the input using
    /// forward finite differences (one extra forward pass per feature).
    fn compute_input_gradient(
        &self,
        input: &Array1<f64>,
        true_label: usize,
    ) -> Result<Array1<f64>> {
        let mut gradient = Array1::zeros(input.len());

        let h = 1e-5;
        let original_output = self.model.forward(input)?;
        let original_loss = self.compute_loss(&original_output, true_label);

        for i in 0..input.len() {
            let mut perturbed_input = input.clone();
            perturbed_input[i] += h;

            let perturbed_output = self.model.forward(&perturbed_input)?;
            let perturbed_loss = self.compute_loss(&perturbed_output, true_label);

            gradient[i] = (perturbed_loss - original_loss) / h;
        }

        Ok(gradient)
    }

    /// Cross-entropy loss: negative log-likelihood of the true class,
    /// floored to avoid taking the log of zero.
    fn compute_loss(&self, output: &Array1<f64>, true_label: usize) -> f64 {
        let predicted_prob = output[true_label].max(1e-10);
        -predicted_prob.ln()
    }

    /// Replaces a fraction of the batch with adversarial versions of its rows,
    /// as prescribed by the adversarial-training defense.
    fn generate_adversarial_batch(
        &self,
        data: &Array2<f64>,
        labels: &Array1<usize>,
    ) -> Result<(Array2<f64>, Array1<usize>)> {
        match &self.defense_strategy {
            QuantumDefenseStrategy::AdversarialTraining {
                attack_types,
                adversarial_ratio,
            } => {
                let num_adversarial = (data.nrows() as f64 * adversarial_ratio) as usize;
                let mut combined_data = data.clone();
                let combined_labels = labels.clone();

                for i in 0..num_adversarial {
                    let idx = i % data.nrows();
                    let input = data.row(idx).to_owned();
                    let label = labels[idx];

                    // Pick a random attack type from the configured set
                    let attack_type = attack_types[fastrand::usize(0..attack_types.len())].clone();
                    let adversarial_example =
                        self.generate_single_adversarial_example(&input, label, attack_type)?;

                    // Overwrite the row in place with its adversarial version
                    combined_data
                        .row_mut(idx)
                        .assign(&adversarial_example.adversarial_input);
                }

                Ok((combined_data, combined_labels))
            }
            _ => Ok((data.clone(), labels.clone())),
        }
    }

    /// Computes the average loss over a batch. Gradient computation and
    /// parameter updates via the optimizer are not yet implemented here.
    fn train_batch(
        &mut self,
        data: &Array2<f64>,
        labels: &Array1<usize>,
        _optimizer: &mut dyn Optimizer,
    ) -> Result<f64> {
        let mut total_loss = 0.0;

        for (input, &label) in data.outer_iter().zip(labels.iter()) {
            let output = self.model.forward(&input.to_owned())?;
            let loss = self.compute_loss(&output, label);
            total_loss += loss;
        }

        Ok(total_loss / data.nrows() as f64)
    }

    /// Populates the ensemble for the ensemble defense. Currently clones the
    /// base model; diversity would come from training the copies independently.
    fn initialize_ensemble(&mut self) -> Result<()> {
        if let QuantumDefenseStrategy::EnsembleDefense { num_models, .. } = &self.defense_strategy {
            for _ in 0..*num_models {
                let model = self.model.clone();
                self.ensemble_models.push(model);
            }
        }
        Ok(())
    }

    /// Evaluates clean accuracy and robustness against a fixed suite of test
    /// attacks, updating `self.robustness_metrics`.
    fn evaluate_robustness(
        &mut self,
        val_data: &Array2<f64>,
        val_labels: &Array1<usize>,
    ) -> Result<()> {
        let mut clean_correct = 0;
        let mut robust_correct = 0;
        let mut total_perturbation = 0.0;
        let mut successful_attacks = 0;

        // Fixed attack suite used for evaluation
        let test_attacks = vec![
            QuantumAttackType::FGSM { epsilon: 0.1 },
            QuantumAttackType::PGD {
                epsilon: 0.1,
                alpha: 0.01,
                num_steps: 10,
            },
        ];

        for (input, &label) in val_data.outer_iter().zip(val_labels.iter()) {
            let input_owned = input.to_owned();

            // Clean prediction
            let clean_output = self.model.forward(&input_owned)?;
            let clean_pred = clean_output
                .iter()
                .enumerate()
                .max_by(|a, b| a.1.partial_cmp(b.1).unwrap())
                .map(|(i, _)| i)
                .unwrap_or(0);

            if clean_pred == label {
                clean_correct += 1;
            }

            // An input counts as robust only if every test attack fails
            let mut robust_for_this_input = true;
            for attack_type in &test_attacks {
                let adversarial_example = self.generate_single_adversarial_example(
                    &input_owned,
                    label,
                    attack_type.clone(),
                )?;

                total_perturbation += adversarial_example.perturbation_norm;

                if adversarial_example.attack_success {
                    successful_attacks += 1;
                    robust_for_this_input = false;
                }
            }

            if robust_for_this_input {
                robust_correct += 1;
            }
        }

        let num_samples = val_data.nrows();
        let num_attack_tests = num_samples * test_attacks.len();

        self.robustness_metrics.clean_accuracy = clean_correct as f64 / num_samples as f64;
        self.robustness_metrics.robust_accuracy = robust_correct as f64 / num_samples as f64;
        self.robustness_metrics.avg_perturbation_norm =
            total_perturbation / num_attack_tests as f64;
        self.robustness_metrics.attack_success_rate =
            successful_attacks as f64 / num_attack_tests as f64;

        Ok(())
    }

    /// Applies the configured input-level defense to an input before inference.
    pub fn apply_defense(&self, input: &Array1<f64>) -> Result<Array1<f64>> {
        match &self.defense_strategy {
            QuantumDefenseStrategy::InputPreprocessing {
                noise_addition,
                feature_squeezing,
            } => {
                let mut defended_input = input.clone();

                // Add uniform noise to mask small adversarial perturbations
                for i in 0..defended_input.len() {
                    defended_input[i] += noise_addition * (2.0 * thread_rng().gen::<f64>() - 1.0);
                }

                // Feature squeezing: quantize each feature to 8 levels (3 bits)
                if *feature_squeezing {
                    defended_input = defended_input.mapv(|x| (x * 8.0).round() / 8.0);
                }

                Ok(defended_input.mapv(|x| x.clamp(0.0, 1.0)))
            }
            QuantumDefenseStrategy::RandomizedCircuit {
                randomization_strength,
                ..
            } => {
                let mut defended_input = input.clone();

                for i in 0..defended_input.len() {
                    let random_shift =
                        randomization_strength * (2.0 * thread_rng().gen::<f64>() - 1.0);
                    defended_input[i] += random_shift;
                }

                Ok(defended_input.mapv(|x| x.clamp(0.0, 1.0)))
            }
            _ => Ok(input.clone()),
        }
    }

    /// Returns the most recent robustness metrics.
    pub fn get_robustness_metrics(&self) -> &RobustnessMetrics {
        &self.robustness_metrics
    }

    /// Returns the history of generated adversarial examples.
    pub fn get_attack_history(&self) -> &[QuantumAdversarialExample] {
        &self.attack_history
    }

    /// Estimates certified accuracy via randomized smoothing: each input is
    /// classified under repeated random noise, and certification requires a
    /// stable majority vote across samples.
    pub fn certified_defense_analysis(
        &self,
        data: &Array2<f64>,
        smoothing_variance: f64,
        num_samples: usize,
    ) -> Result<f64> {
        let mut certified_correct = 0;

        for input in data.outer_iter() {
            let input_owned = input.to_owned();

            // Classify under random noise
            let mut predictions = Vec::new();
            for _ in 0..num_samples {
                let mut noisy_input = input_owned.clone();
                for i in 0..noisy_input.len() {
                    // Zero-mean uniform noise; standard randomized smoothing
                    // would use Gaussian noise here
                    let noise = smoothing_variance * (2.0 * fastrand::f64() - 1.0);
                    noisy_input[i] += noise;
                }

                let output = self.model.forward(&noisy_input)?;
                let pred = output
                    .iter()
                    .enumerate()
                    .max_by(|a, b| a.1.partial_cmp(b.1).unwrap())
                    .map(|(i, _)| i)
                    .unwrap_or(0);

                predictions.push(pred);
            }

            // Tally votes (assumes at most 10 classes)
            let mut counts = vec![0; 10];
            for &pred in &predictions {
                if pred < counts.len() {
                    counts[pred] += 1;
                }
            }

            // Certify if the majority class wins at least 60% of the votes
            let max_count = counts.iter().max().unwrap_or(&0);
            let certification_threshold = (num_samples as f64 * 0.6) as usize;

            if *max_count >= certification_threshold {
                certified_correct += 1;
            }
        }

        Ok(certified_correct as f64 / data.nrows() as f64)
    }
}

/// Creates a default adversarial training configuration.
pub fn create_default_adversarial_config() -> AdversarialTrainingConfig {
    AdversarialTrainingConfig {
        epochs: 100,
        batch_size: 32,
        learning_rate: 0.001,
        adversarial_frequency: 2,
        max_perturbation: 0.1,
        eval_interval: 10,
        early_stopping: Some(EarlyStoppingCriteria {
            min_clean_accuracy: 0.7,
            min_robust_accuracy: 0.5,
            patience: 20,
        }),
    }
}

/// Creates an adversarial-training defense covering several attack types.
pub fn create_comprehensive_defense() -> QuantumDefenseStrategy {
    QuantumDefenseStrategy::AdversarialTraining {
        attack_types: vec![
            QuantumAttackType::FGSM { epsilon: 0.1 },
            QuantumAttackType::PGD {
                epsilon: 0.1,
                alpha: 0.01,
                num_steps: 7,
            },
            QuantumAttackType::ParameterShift {
                shift_magnitude: 0.05,
                target_parameters: None,
            },
        ],
        adversarial_ratio: 0.5,
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::qnn::QNNLayerType;

    #[test]
    fn test_adversarial_example_creation() {
        let original_input = Array1::from_vec(vec![0.5, 0.3, 0.8, 0.2]);
        let adversarial_input = Array1::from_vec(vec![0.6, 0.4, 0.7, 0.3]);
        let original_prediction = Array1::from_vec(vec![0.8, 0.2]);
        let adversarial_prediction = Array1::from_vec(vec![0.3, 0.7]);

        let perturbation = &adversarial_input - &original_input;
        let perturbation_norm = perturbation.mapv(|x| x.powi(2)).sum().sqrt();

        let example = QuantumAdversarialExample {
            original_input,
            adversarial_input,
            original_prediction,
            adversarial_prediction,
            true_label: 0,
            perturbation_norm,
            attack_success: true,
            metadata: HashMap::new(),
        };

        assert!(example.attack_success);
        assert!(example.perturbation_norm > 0.0);
    }

    #[test]
    fn test_fgsm_attack() {
        let layers = vec![
            QNNLayerType::EncodingLayer { num_features: 4 },
            QNNLayerType::VariationalLayer { num_params: 8 },
            QNNLayerType::MeasurementLayer {
                measurement_basis: "computational".to_string(),
            },
        ];

        let model = QuantumNeuralNetwork::new(layers, 4, 4, 2).unwrap();
        let defense = create_comprehensive_defense();
        let config = create_default_adversarial_config();

        let trainer = QuantumAdversarialTrainer::new(model, defense, config);

        let input = Array1::from_vec(vec![0.5, 0.3, 0.8, 0.2]);
        let adversarial_input = trainer.fgsm_attack(&input, 0, 0.1).unwrap();

        assert_eq!(adversarial_input.len(), input.len());

        // The attack must actually perturb the input
        let perturbation = &adversarial_input - &input;
        let perturbation_norm = perturbation.mapv(|x| x.powi(2)).sum().sqrt();
        assert!(perturbation_norm > 0.0);

        // All values must remain in the valid [0, 1] range
        for &val in adversarial_input.iter() {
            assert!(val >= 0.0 && val <= 1.0);
        }
    }

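    #[test]
    fn test_pgd_attack() {
        // A minimal sketch mirroring test_fgsm_attack above; the epsilon,
        // alpha, and step-count values are illustrative, not tuned.
        let layers = vec![
            QNNLayerType::EncodingLayer { num_features: 4 },
            QNNLayerType::VariationalLayer { num_params: 8 },
            QNNLayerType::MeasurementLayer {
                measurement_basis: "computational".to_string(),
            },
        ];

        let model = QuantumNeuralNetwork::new(layers, 4, 4, 2).unwrap();
        let defense = create_comprehensive_defense();
        let config = create_default_adversarial_config();
        let trainer = QuantumAdversarialTrainer::new(model, defense, config);

        let input = Array1::from_vec(vec![0.5, 0.3, 0.8, 0.2]);
        let adversarial_input = trainer.pgd_attack(&input, 0, 0.1, 0.01, 5).unwrap();

        assert_eq!(adversarial_input.len(), input.len());

        // The projection step must keep the total perturbation within epsilon
        let perturbation = &adversarial_input - &input;
        let perturbation_norm = perturbation.mapv(|x| x.powi(2)).sum().sqrt();
        assert!(perturbation_norm <= 0.1 + 1e-9);

        // All values must remain in the valid [0, 1] range
        for &val in adversarial_input.iter() {
            assert!(val >= 0.0 && val <= 1.0);
        }
    }
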
    #[test]
    fn test_defense_application() {
        let layers = vec![
            QNNLayerType::EncodingLayer { num_features: 4 },
            QNNLayerType::VariationalLayer { num_params: 8 },
        ];

        let model = QuantumNeuralNetwork::new(layers, 4, 4, 2).unwrap();

        let defense = QuantumDefenseStrategy::InputPreprocessing {
            noise_addition: 0.05,
            feature_squeezing: true,
        };

        let config = create_default_adversarial_config();
        let trainer = QuantumAdversarialTrainer::new(model, defense, config);

        let input = Array1::from_vec(vec![0.51, 0.32, 0.83, 0.24]);
        let defended_input = trainer.apply_defense(&input).unwrap();

        assert_eq!(defended_input.len(), input.len());

        // Noise addition and feature squeezing should change the input
        let difference = (&defended_input - &input).mapv(|x| x.abs()).sum();
        assert!(difference > 0.0);
    }

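    #[test]
    fn test_parameter_shift_attack_targets() {
        // A small sketch (not part of the original suite): with an explicit
        // target set, only the selected indices may be perturbed.
        let layers = vec![
            QNNLayerType::EncodingLayer { num_features: 4 },
            QNNLayerType::VariationalLayer { num_params: 8 },
        ];

        let model = QuantumNeuralNetwork::new(layers, 4, 4, 2).unwrap();
        let defense = create_comprehensive_defense();
        let config = create_default_adversarial_config();
        let trainer = QuantumAdversarialTrainer::new(model, defense, config);

        let input = Array1::from_vec(vec![0.5, 0.3, 0.8, 0.2]);
        let adversarial_input = trainer
            .parameter_shift_attack(&input, 0.1, Some(vec![0, 2]))
            .unwrap();

        // Untargeted indices must be left unchanged
        assert_eq!(adversarial_input[1], input[1]);
        assert_eq!(adversarial_input[3], input[3]);

        // All values must remain in the valid [0, 1] range
        for &val in adversarial_input.iter() {
            assert!(val >= 0.0 && val <= 1.0);
        }
    }
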
    #[test]
    fn test_robustness_metrics() {
        let metrics = RobustnessMetrics {
            clean_accuracy: 0.85,
            robust_accuracy: 0.65,
            avg_perturbation_norm: 0.12,
            attack_success_rate: 0.35,
            certified_accuracy: Some(0.55),
            per_attack_metrics: HashMap::new(),
        };

        assert_eq!(metrics.clean_accuracy, 0.85);
        assert_eq!(metrics.robust_accuracy, 0.65);
        assert!(metrics.robust_accuracy < metrics.clean_accuracy);
        assert!(metrics.attack_success_rate < 0.5);
    }
}