1use crate::{
8 error::QuantRS2Result, gate::multi::*, gate::single::*, gate::GateOp, qubit::QubitId,
9 variational::VariationalOptimizer,
10};
11use ndarray::{Array1, Array2, Axis};
12use rand::{rngs::StdRng, Rng, SeedableRng};
13use serde::{Deserialize, Serialize};
14use std::collections::VecDeque;
15
/// Hyper-parameters and circuit sizes used to build and train a [`QGAN`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct QGANConfig {
    /// Qubit count nominally assigned to the generator.
    /// NOTE(review): nothing in the visible code reads this field; the
    /// generator circuit is sized from `latent_qubits + data_qubits`.
    pub generator_qubits: usize,
    /// Total qubit count of the discriminator circuit. The auxiliary qubit
    /// count is derived from `discriminator_qubits - data_qubits`, so this is
    /// expected to be at least `data_qubits`.
    pub discriminator_qubits: usize,
    /// Length of the latent noise vector; also the number of generator qubits
    /// that receive a noise-encoding rotation.
    pub latent_qubits: usize,
    /// Width of every generated sample; also the number of discriminator
    /// qubits that receive a data-encoding rotation.
    pub data_qubits: usize,
    /// Generator learning rate.
    /// NOTE(review): not read by the visible code; parameter updates use a
    /// literal step size of 0.01.
    pub generator_lr: f64,
    /// Discriminator learning rate.
    /// NOTE(review): not read by the visible code; parameter updates use a
    /// literal step size of 0.01.
    pub discriminator_lr: f64,
    /// Number of variational layers in the generator circuit.
    pub generator_depth: usize,
    /// Number of variational layers in the discriminator circuit.
    pub discriminator_depth: usize,
    /// Upper bound on the number of samples drawn per training step (the
    /// step uses the smaller of this and the number of rows supplied).
    pub batch_size: usize,
    /// Intended cap on training iterations.
    /// NOTE(review): not read by the visible code; callers drive the loop.
    pub max_iterations: usize,
    /// The generator is trained on iterations whose index is a multiple of
    /// this value; should be at least 1.
    pub generator_frequency: usize,
    /// NOTE(review): not read by the visible code; purpose cannot be
    /// determined from this file.
    pub use_quantum_advantage: bool,
    /// Selects how latent noise angles are sampled.
    pub noise_type: NoiseType,
    /// Regularization strength.
    /// NOTE(review): not read by the visible code.
    pub regularization: f64,
    /// Seed for the random number generators. `None` selects a fixed
    /// all-zero seed rather than an entropy source.
    pub random_seed: Option<u64>,
}
50
51impl Default for QGANConfig {
52 fn default() -> Self {
53 Self {
54 generator_qubits: 6,
55 discriminator_qubits: 6,
56 latent_qubits: 4,
57 data_qubits: 4,
58 generator_lr: 0.01,
59 discriminator_lr: 0.01,
60 generator_depth: 8,
61 discriminator_depth: 6,
62 batch_size: 16,
63 max_iterations: 1000,
64 generator_frequency: 1,
65 use_quantum_advantage: true,
66 noise_type: NoiseType::Gaussian,
67 regularization: 0.001,
68 random_seed: None,
69 }
70 }
71}
72
/// Strategy used to sample the latent noise angles fed to the generator.
///
/// The exact sampling for each variant lives in `QGAN::sample_noise`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum NoiseType {
    /// Randomly drawn real-valued angles centred on zero.
    Gaussian,
    /// Angles drawn uniformly from `[-π, π)`.
    Uniform,
    /// Deterministic noise: every angle is fixed to `π/2`.
    QuantumSuperposition,
    /// Each angle is either `0` or `π`, chosen from a random bit pattern.
    BasisStates,
}
85
/// Quantum generative adversarial network: a generator/discriminator pair
/// plus the bookkeeping needed to train them step by step.
pub struct QGAN {
    /// Configuration the network was built from.
    config: QGANConfig,
    /// Produces fake samples from latent noise.
    generator: QuantumGenerator,
    /// Scores samples as real or generated.
    discriminator: QuantumDiscriminator,
    /// Per-iteration history appended to by `train`.
    training_stats: QGANTrainingStats,
    /// Random source for batch selection and noise sampling.
    rng: StdRng,
    /// Number of completed training steps.
    iteration: usize,
}
101
/// Generator half of the QGAN: a parameterized circuit and its trainable
/// parameters.
pub struct QuantumGenerator {
    /// Circuit description used to turn noise into samples.
    circuit: QuantumGeneratorCircuit,
    /// Trainable rotation angles, one entry per circuit parameter.
    parameters: Array1<f64>,
    // Constructed in `new` but never consulted by the update step in this
    // file, hence the `dead_code` allowance.
    #[allow(dead_code)]
    optimizer: VariationalOptimizer,
    /// Most recent effective (momentum-adjusted) gradients, oldest first;
    /// trimmed to 10 entries by `update_parameters`.
    gradient_history: VecDeque<Array1<f64>>,
}
114
/// Discriminator half of the QGAN: a parameterized circuit and its trainable
/// parameters.
pub struct QuantumDiscriminator {
    /// Circuit description used to score samples.
    circuit: QuantumDiscriminatorCircuit,
    /// Trainable rotation angles, one entry per circuit parameter.
    parameters: Array1<f64>,
    // Constructed in `new` but never consulted by the update step in this
    // file, hence the `dead_code` allowance.
    #[allow(dead_code)]
    optimizer: VariationalOptimizer,
    /// Most recent effective (momentum-adjusted) gradients, oldest first;
    /// trimmed to 10 entries by `update_parameters`.
    gradient_history: VecDeque<Array1<f64>>,
}
127
/// Shape description of the generator's variational circuit.
#[derive(Debug, Clone)]
pub struct QuantumGeneratorCircuit {
    /// Number of qubits that receive a noise-encoding rotation.
    latent_qubits: usize,
    /// Number of values produced per generated sample.
    data_qubits: usize,
    /// Number of variational layers.
    depth: usize,
    /// Set to `latent_qubits + data_qubits` by `new`.
    total_qubits: usize,
    /// Noise strategy the circuit was created with.
    /// NOTE(review): stored but not read by the visible code.
    noise_type: NoiseType,
}
142
/// Shape description of the discriminator's variational circuit.
#[derive(Debug, Clone)]
pub struct QuantumDiscriminatorCircuit {
    /// Number of qubits that receive a data-encoding rotation.
    data_qubits: usize,
    /// Number of additional qubits that carry no input data.
    aux_qubits: usize,
    /// Number of variational layers.
    depth: usize,
    /// Set to `data_qubits + aux_qubits` by `new`.
    total_qubits: usize,
}
155
/// History of per-iteration training measurements.
#[derive(Debug, Clone, Default)]
pub struct QGANTrainingStats {
    /// Generator loss of every training step (0.0 on steps where the
    /// generator was not trained).
    pub generator_losses: Vec<f64>,
    /// Sum of the discriminator's real-batch and fake-batch losses per step.
    pub discriminator_losses: Vec<f64>,
    /// Fidelity estimate between the real and fake batch per step.
    pub fidelities: Vec<f64>,
    /// Per-iteration timings.
    /// NOTE(review): never written by the visible code.
    pub iteration_times: Vec<f64>,
    /// Convergence metrics.
    /// NOTE(review): never written by the visible code.
    pub convergence_metrics: Vec<f64>,
}
170
/// Measurements returned by a single call to `QGAN::train`.
#[derive(Debug, Clone)]
pub struct QGANIterationMetrics {
    /// Generator loss for this step (0.0 if the generator was skipped).
    pub generator_loss: f64,
    /// Discriminator loss on the real batch plus its loss on the fake batch.
    pub discriminator_loss: f64,
    /// Fraction of real samples scored above 0.5.
    pub real_accuracy: f64,
    /// Fraction of fake samples scored at or below 0.5.
    pub fake_accuracy: f64,
    /// Fidelity estimate between the real and fake batch.
    pub fidelity: f64,
    /// Zero-based index of the step that produced these metrics.
    pub iteration: usize,
}
187
188impl QGAN {
189 pub fn new(config: QGANConfig) -> QuantRS2Result<Self> {
191 let rng = match config.random_seed {
192 Some(seed) => StdRng::seed_from_u64(seed),
193 None => StdRng::from_seed([0; 32]), };
195
196 let generator = QuantumGenerator::new(&config)?;
197 let discriminator = QuantumDiscriminator::new(&config)?;
198
199 Ok(Self {
200 config,
201 generator,
202 discriminator,
203 training_stats: QGANTrainingStats::default(),
204 rng,
205 iteration: 0,
206 })
207 }
208
209 pub fn train(&mut self, real_data: &Array2<f64>) -> QuantRS2Result<QGANIterationMetrics> {
211 let batch_size = self.config.batch_size.min(real_data.nrows());
212
213 let real_batch = self.sample_real_data_batch(real_data, batch_size)?;
215
216 let fake_batch = self.generate_fake_data_batch(batch_size)?;
218
219 let (d_loss_real, d_loss_fake) = self.train_discriminator(&real_batch, &fake_batch)?;
221 let discriminator_loss = d_loss_real + d_loss_fake;
222
223 let generator_loss = if self.iteration % self.config.generator_frequency == 0 {
225 self.train_generator(batch_size)?
226 } else {
227 0.0
228 };
229
230 let real_accuracy = self.compute_discriminator_accuracy(&real_batch, true)?;
232 let fake_accuracy = self.compute_discriminator_accuracy(&fake_batch, false)?;
233 let fidelity = self.compute_fidelity(&real_batch, &fake_batch)?;
234
235 self.training_stats.generator_losses.push(generator_loss);
237 self.training_stats
238 .discriminator_losses
239 .push(discriminator_loss);
240 self.training_stats.fidelities.push(fidelity);
241
242 let metrics = QGANIterationMetrics {
243 generator_loss,
244 discriminator_loss,
245 real_accuracy,
246 fake_accuracy,
247 fidelity,
248 iteration: self.iteration,
249 };
250
251 self.iteration += 1;
252
253 Ok(metrics)
254 }
255
256 pub fn generate_data(&mut self, num_samples: usize) -> QuantRS2Result<Array2<f64>> {
258 self.generate_fake_data_batch(num_samples)
259 }
260
261 pub fn discriminate(&self, data: &Array2<f64>) -> QuantRS2Result<Array1<f64>> {
263 let num_samples = data.nrows();
264 let mut scores = Array1::zeros(num_samples);
265
266 for i in 0..num_samples {
267 let sample = data.row(i).to_owned();
268 scores[i] = self.discriminator.discriminate(&sample)?;
269 }
270
271 Ok(scores)
272 }
273
274 fn sample_real_data_batch(
276 &mut self,
277 real_data: &Array2<f64>,
278 batch_size: usize,
279 ) -> QuantRS2Result<Array2<f64>> {
280 let num_samples = real_data.nrows();
281 let mut batch = Array2::zeros((batch_size, real_data.ncols()));
282
283 for i in 0..batch_size {
284 let idx = self.rng.random_range(0..num_samples);
285 batch.row_mut(i).assign(&real_data.row(idx));
286 }
287
288 Ok(batch)
289 }
290
291 fn generate_fake_data_batch(&mut self, batch_size: usize) -> QuantRS2Result<Array2<f64>> {
293 let mut fake_batch = Array2::zeros((batch_size, self.config.data_qubits));
294
295 for i in 0..batch_size {
296 let noise = self.sample_noise()?;
297 let generated_sample = self.generator.generate(&noise)?;
298 fake_batch.row_mut(i).assign(&generated_sample);
299 }
300
301 Ok(fake_batch)
302 }
303
304 fn sample_noise(&mut self) -> QuantRS2Result<Array1<f64>> {
306 let mut noise = Array1::zeros(self.config.latent_qubits);
307
308 match self.config.noise_type {
309 NoiseType::Gaussian => {
310 for i in 0..self.config.latent_qubits {
311 noise[i] = self.rng.random::<f64>() * 2.0 - 1.0; }
313 }
314 NoiseType::Uniform => {
315 for i in 0..self.config.latent_qubits {
316 noise[i] = self.rng.random::<f64>() * 2.0 * std::f64::consts::PI
317 - std::f64::consts::PI;
318 }
319 }
320 NoiseType::QuantumSuperposition => {
321 for i in 0..self.config.latent_qubits {
323 noise[i] = std::f64::consts::PI / 2.0; }
325 }
326 NoiseType::BasisStates => {
327 let state = self.rng.random_range(0..(1 << self.config.latent_qubits));
329 for i in 0..self.config.latent_qubits {
330 noise[i] = if (state >> i) & 1 == 1 {
331 std::f64::consts::PI
332 } else {
333 0.0
334 };
335 }
336 }
337 }
338
339 Ok(noise)
340 }
341
342 fn train_discriminator(
344 &mut self,
345 real_batch: &Array2<f64>,
346 fake_batch: &Array2<f64>,
347 ) -> QuantRS2Result<(f64, f64)> {
348 let d_loss_real = self
350 .discriminator
351 .train_batch(real_batch, &Array1::ones(real_batch.nrows()))?;
352
353 let d_loss_fake = self
355 .discriminator
356 .train_batch(fake_batch, &Array1::zeros(fake_batch.nrows()))?;
357
358 Ok((d_loss_real, d_loss_fake))
359 }
360
361 fn train_generator(&mut self, batch_size: usize) -> QuantRS2Result<f64> {
363 let fake_batch = self.generate_fake_data_batch(batch_size)?;
365
366 let discriminator_scores = self.discriminate(&fake_batch)?;
368
369 let targets = Array1::ones(batch_size);
371 let generator_loss =
372 self.generator
373 .train_adversarial(&fake_batch, &targets, &discriminator_scores)?;
374
375 Ok(generator_loss)
376 }
377
378 fn compute_discriminator_accuracy(
380 &self,
381 data: &Array2<f64>,
382 is_real: bool,
383 ) -> QuantRS2Result<f64> {
384 let scores = self.discriminate(data)?;
385 let threshold = 0.5;
386 let _target = if is_real { 1.0 } else { 0.0 };
387
388 let correct = scores
389 .iter()
390 .map(|&score| {
391 if (score > threshold) == is_real {
392 1.0
393 } else {
394 0.0
395 }
396 })
397 .sum::<f64>();
398
399 Ok(correct / data.nrows() as f64)
400 }
401
402 fn compute_fidelity(
404 &self,
405 real_batch: &Array2<f64>,
406 fake_batch: &Array2<f64>,
407 ) -> QuantRS2Result<f64> {
408 let real_mean = real_batch.mean_axis(Axis(0)).unwrap();
410 let fake_mean = fake_batch.mean_axis(Axis(0)).unwrap();
411
412 let real_var = real_batch.var_axis(Axis(0), 0.0);
413 let fake_var = fake_batch.var_axis(Axis(0), 0.0);
414
415 let mean_diff = (&real_mean - &fake_mean).mapv(|x| x.powi(2)).sum().sqrt();
417 let var_diff = (&real_var - &fake_var).mapv(|x| x.powi(2)).sum().sqrt();
418
419 let fidelity = (-0.5 * (mean_diff + var_diff)).exp();
420
421 Ok(fidelity)
422 }
423
424 pub fn get_training_stats(&self) -> &QGANTrainingStats {
426 &self.training_stats
427 }
428
429 pub fn has_converged(&self, tolerance: f64, window: usize) -> bool {
431 if self.training_stats.fidelities.len() < window {
432 return false;
433 }
434
435 let recent_fidelities =
436 &self.training_stats.fidelities[self.training_stats.fidelities.len() - window..];
437 let mean_fidelity = recent_fidelities.iter().sum::<f64>() / window as f64;
438
439 mean_fidelity > 1.0 - tolerance
440 }
441}
442
443impl QuantumGenerator {
444 fn new(config: &QGANConfig) -> QuantRS2Result<Self> {
446 let circuit = QuantumGeneratorCircuit::new(
447 config.latent_qubits,
448 config.data_qubits,
449 config.generator_depth,
450 config.noise_type,
451 )?;
452
453 let num_parameters = circuit.get_parameter_count();
454 let mut parameters = Array1::zeros(num_parameters);
455
456 let mut rng = match config.random_seed {
458 Some(seed) => StdRng::seed_from_u64(seed),
459 None => StdRng::from_seed([0; 32]),
460 };
461
462 for param in parameters.iter_mut() {
463 *param = rng.random_range(-std::f64::consts::PI..std::f64::consts::PI);
464 }
465
466 let optimizer = VariationalOptimizer::new(0.01, 0.9);
467 let gradient_history = VecDeque::with_capacity(10);
468
469 Ok(Self {
470 circuit,
471 parameters,
472 optimizer,
473 gradient_history,
474 })
475 }
476
477 fn generate(&self, noise: &Array1<f64>) -> QuantRS2Result<Array1<f64>> {
479 self.circuit.generate_data(noise, &self.parameters)
480 }
481
482 fn train_adversarial(
484 &mut self,
485 generated_data: &Array2<f64>,
486 targets: &Array1<f64>,
487 discriminator_scores: &Array1<f64>,
488 ) -> QuantRS2Result<f64> {
489 let batch_size = generated_data.nrows();
490 let mut total_loss = 0.0;
491
492 for i in 0..batch_size {
493 let data_sample = generated_data.row(i).to_owned();
494 let target = targets[i];
495 let score = discriminator_scores[i];
496
497 let loss = (score - target).powi(2);
499 total_loss += loss;
500
501 let gradients = self.circuit.compute_adversarial_gradients(
503 &data_sample,
504 target,
505 score,
506 &self.parameters,
507 )?;
508 self.update_parameters(&gradients, 0.01)?; }
510
511 Ok(total_loss / batch_size as f64)
512 }
513
514 fn update_parameters(
516 &mut self,
517 gradients: &Array1<f64>,
518 learning_rate: f64,
519 ) -> QuantRS2Result<()> {
520 let mut effective_gradients = gradients.clone();
522
523 if let Some(prev_gradients) = self.gradient_history.back() {
524 let momentum = 0.9;
525 effective_gradients = &effective_gradients + &(prev_gradients * momentum);
526 }
527
528 for (param, &grad) in self.parameters.iter_mut().zip(effective_gradients.iter()) {
530 *param -= learning_rate * grad;
531 }
532
533 self.gradient_history.push_back(effective_gradients);
535 if self.gradient_history.len() > 10 {
536 self.gradient_history.pop_front();
537 }
538
539 Ok(())
540 }
541}
542
543impl QuantumDiscriminator {
544 fn new(config: &QGANConfig) -> QuantRS2Result<Self> {
546 let circuit = QuantumDiscriminatorCircuit::new(
547 config.data_qubits,
548 config.discriminator_qubits - config.data_qubits,
549 config.discriminator_depth,
550 )?;
551
552 let num_parameters = circuit.get_parameter_count();
553 let mut parameters = Array1::zeros(num_parameters);
554
555 let mut rng = match config.random_seed {
557 Some(seed) => StdRng::seed_from_u64(seed),
558 None => StdRng::from_seed([0; 32]),
559 };
560
561 for param in parameters.iter_mut() {
562 *param = rng.random_range(-std::f64::consts::PI..std::f64::consts::PI);
563 }
564
565 let optimizer = VariationalOptimizer::new(0.01, 0.9);
566 let gradient_history = VecDeque::with_capacity(10);
567
568 Ok(Self {
569 circuit,
570 parameters,
571 optimizer,
572 gradient_history,
573 })
574 }
575
576 fn discriminate(&self, data: &Array1<f64>) -> QuantRS2Result<f64> {
578 self.circuit.discriminate_data(data, &self.parameters)
579 }
580
581 fn train_batch(
583 &mut self,
584 data_batch: &Array2<f64>,
585 targets: &Array1<f64>,
586 ) -> QuantRS2Result<f64> {
587 let batch_size = data_batch.nrows();
588 let mut total_loss = 0.0;
589
590 for i in 0..batch_size {
591 let data_sample = data_batch.row(i).to_owned();
592 let target = targets[i];
593
594 let prediction = self.discriminate(&data_sample)?;
596
597 let loss = -(target * prediction.ln() + (1.0 - target) * (1.0 - prediction).ln());
599 total_loss += loss;
600
601 let gradients = self.circuit.compute_discriminator_gradients(
603 &data_sample,
604 target,
605 prediction,
606 &self.parameters,
607 )?;
608 self.update_parameters(&gradients, 0.01)?; }
610
611 Ok(total_loss / batch_size as f64)
612 }
613
614 fn update_parameters(
616 &mut self,
617 gradients: &Array1<f64>,
618 learning_rate: f64,
619 ) -> QuantRS2Result<()> {
620 let mut effective_gradients = gradients.clone();
622
623 if let Some(prev_gradients) = self.gradient_history.back() {
624 let momentum = 0.9;
625 effective_gradients = &effective_gradients + &(prev_gradients * momentum);
626 }
627
628 for (param, &grad) in self.parameters.iter_mut().zip(effective_gradients.iter()) {
630 *param -= learning_rate * grad;
631 }
632
633 self.gradient_history.push_back(effective_gradients);
635 if self.gradient_history.len() > 10 {
636 self.gradient_history.pop_front();
637 }
638
639 Ok(())
640 }
641}
642
643impl QuantumGeneratorCircuit {
644 fn new(
646 latent_qubits: usize,
647 data_qubits: usize,
648 depth: usize,
649 noise_type: NoiseType,
650 ) -> QuantRS2Result<Self> {
651 let total_qubits = latent_qubits + data_qubits;
652
653 Ok(Self {
654 latent_qubits,
655 data_qubits,
656 depth,
657 total_qubits,
658 noise_type,
659 })
660 }
661
662 fn get_parameter_count(&self) -> usize {
664 let total_qubits = self.latent_qubits + self.data_qubits;
665 let rotations_per_layer = total_qubits * 3;
667 let entangling_per_layer = total_qubits; self.depth * (rotations_per_layer + entangling_per_layer)
669 }
670
671 fn generate_data(
673 &self,
674 noise: &Array1<f64>,
675 parameters: &Array1<f64>,
676 ) -> QuantRS2Result<Array1<f64>> {
677 let mut gates = Vec::new();
679
680 for i in 0..self.latent_qubits {
682 let noise_value = if i < noise.len() { noise[i] } else { 0.0 };
683 gates.push(Box::new(RotationY {
684 target: QubitId(i as u32),
685 theta: noise_value,
686 }) as Box<dyn GateOp>);
687 }
688
689 let mut param_idx = 0;
691 for _layer in 0..self.depth {
692 for qubit in 0..self.latent_qubits + self.data_qubits {
694 if param_idx + 2 < parameters.len() {
695 gates.push(Box::new(RotationX {
696 target: QubitId(qubit as u32),
697 theta: parameters[param_idx],
698 }) as Box<dyn GateOp>);
699 param_idx += 1;
700
701 gates.push(Box::new(RotationY {
702 target: QubitId(qubit as u32),
703 theta: parameters[param_idx],
704 }) as Box<dyn GateOp>);
705 param_idx += 1;
706
707 gates.push(Box::new(RotationZ {
708 target: QubitId(qubit as u32),
709 theta: parameters[param_idx],
710 }) as Box<dyn GateOp>);
711 param_idx += 1;
712 }
713 }
714
715 for qubit in 0..self.latent_qubits + self.data_qubits - 1 {
717 if param_idx < parameters.len() {
718 gates.push(Box::new(CRZ {
719 control: QubitId(qubit as u32),
720 target: QubitId((qubit + 1) as u32),
721 theta: parameters[param_idx],
722 }) as Box<dyn GateOp>);
723 param_idx += 1;
724 }
725 }
726 }
727
728 let generated_data = self.simulate_generation_circuit(&gates)?;
730
731 Ok(generated_data)
732 }
733
734 fn simulate_generation_circuit(
736 &self,
737 gates: &[Box<dyn GateOp>],
738 ) -> QuantRS2Result<Array1<f64>> {
739 let mut data = Array1::zeros(self.data_qubits);
741
742 let mut hash_value = 0u64;
743 for gate in gates {
744 if let Ok(matrix) = gate.matrix() {
745 for complex in &matrix {
746 hash_value = hash_value.wrapping_add((complex.re * 1000.0) as u64);
747 }
748 }
749 }
750
751 for i in 0..self.data_qubits {
753 let qubit_hash = hash_value.wrapping_add(i as u64);
754 data[i] = ((qubit_hash % 1000) as f64 / 1000.0) * 2.0 - 1.0; }
756
757 Ok(data)
758 }
759
760 fn compute_adversarial_gradients(
762 &self,
763 _data_sample: &Array1<f64>,
764 target: f64,
765 score: f64,
766 parameters: &Array1<f64>,
767 ) -> QuantRS2Result<Array1<f64>> {
768 let mut gradients = Array1::zeros(parameters.len());
769 let shift = std::f64::consts::PI / 2.0;
770
771 for i in 0..parameters.len() {
772 let mut params_plus = parameters.clone();
774 params_plus[i] += shift;
775 let data_plus = self.generate_data(&Array1::zeros(self.latent_qubits), ¶ms_plus)?;
776
777 let mut params_minus = parameters.clone();
778 params_minus[i] -= shift;
779 let data_minus =
780 self.generate_data(&Array1::zeros(self.latent_qubits), ¶ms_minus)?;
781
782 let loss_gradient = 2.0 * (score - target);
784
785 let data_diff = (&data_plus - &data_minus).sum() / 2.0;
787
788 gradients[i] = loss_gradient * data_diff;
789 }
790
791 Ok(gradients)
792 }
793}
794
795impl QuantumDiscriminatorCircuit {
796 fn new(data_qubits: usize, aux_qubits: usize, depth: usize) -> QuantRS2Result<Self> {
798 let total_qubits = data_qubits + aux_qubits;
799
800 Ok(Self {
801 data_qubits,
802 aux_qubits,
803 depth,
804 total_qubits,
805 })
806 }
807
808 fn get_parameter_count(&self) -> usize {
810 let total_qubits = self.data_qubits + self.aux_qubits;
811 let rotations_per_layer = total_qubits * 3;
812 let entangling_per_layer = total_qubits;
813 self.depth * (rotations_per_layer + entangling_per_layer)
814 }
815
816 fn discriminate_data(
818 &self,
819 data: &Array1<f64>,
820 parameters: &Array1<f64>,
821 ) -> QuantRS2Result<f64> {
822 let mut gates = Vec::new();
824
825 for i in 0..self.data_qubits {
827 let data_value = if i < data.len() { data[i] } else { 0.0 };
828 gates.push(Box::new(RotationY {
829 target: QubitId(i as u32),
830 theta: data_value * std::f64::consts::PI,
831 }) as Box<dyn GateOp>);
832 }
833
834 let mut param_idx = 0;
836 for _layer in 0..self.depth {
837 for qubit in 0..self.data_qubits + self.aux_qubits {
839 if param_idx + 2 < parameters.len() {
840 gates.push(Box::new(RotationX {
841 target: QubitId(qubit as u32),
842 theta: parameters[param_idx],
843 }) as Box<dyn GateOp>);
844 param_idx += 1;
845
846 gates.push(Box::new(RotationY {
847 target: QubitId(qubit as u32),
848 theta: parameters[param_idx],
849 }) as Box<dyn GateOp>);
850 param_idx += 1;
851
852 gates.push(Box::new(RotationZ {
853 target: QubitId(qubit as u32),
854 theta: parameters[param_idx],
855 }) as Box<dyn GateOp>);
856 param_idx += 1;
857 }
858 }
859
860 for qubit in 0..self.data_qubits + self.aux_qubits - 1 {
862 if param_idx < parameters.len() {
863 gates.push(Box::new(CRZ {
864 control: QubitId(qubit as u32),
865 target: QubitId((qubit + 1) as u32),
866 theta: parameters[param_idx],
867 }) as Box<dyn GateOp>);
868 param_idx += 1;
869 }
870 }
871 }
872
873 let probability = self.simulate_discrimination_circuit(&gates)?;
875
876 Ok(probability)
877 }
878
879 fn simulate_discrimination_circuit(&self, gates: &[Box<dyn GateOp>]) -> QuantRS2Result<f64> {
881 let mut hash_value = 0u64;
883
884 for gate in gates {
885 if let Ok(matrix) = gate.matrix() {
886 for complex in &matrix {
887 hash_value = hash_value.wrapping_add((complex.re * 1000.0) as u64);
888 hash_value = hash_value.wrapping_add((complex.im * 1000.0) as u64);
889 }
890 }
891 }
892
893 let probability = ((hash_value % 1000) as f64) / 1000.0;
895
896 Ok(probability)
897 }
898
899 fn compute_discriminator_gradients(
901 &self,
902 data_sample: &Array1<f64>,
903 target: f64,
904 prediction: f64,
905 parameters: &Array1<f64>,
906 ) -> QuantRS2Result<Array1<f64>> {
907 let mut gradients = Array1::zeros(parameters.len());
908 let shift = std::f64::consts::PI / 2.0;
909
910 for i in 0..parameters.len() {
911 let mut params_plus = parameters.clone();
913 params_plus[i] += shift;
914 let pred_plus = self.discriminate_data(data_sample, ¶ms_plus)?;
915
916 let mut params_minus = parameters.clone();
917 params_minus[i] -= shift;
918 let pred_minus = self.discriminate_data(data_sample, ¶ms_minus)?;
919
920 let pred_gradient = if prediction > 0.0 && prediction < 1.0 {
922 -target / prediction + (1.0 - target) / (1.0 - prediction)
923 } else {
924 0.0 };
926
927 gradients[i] = pred_gradient * (pred_plus - pred_minus) / 2.0;
928 }
929
930 Ok(gradients)
931 }
932}
933
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a QGAN from the default configuration.
    fn default_qgan() -> QGAN {
        QGAN::new(QGANConfig::default()).unwrap()
    }

    /// A freshly built QGAN has run no iterations and recorded no losses.
    #[test]
    fn test_qgan_creation() {
        let qgan = default_qgan();

        assert_eq!(qgan.iteration, 0);
        assert_eq!(qgan.training_stats.generator_losses.len(), 0);
    }

    /// Every checked noise type yields one angle per latent qubit.
    #[test]
    fn test_noise_generation() {
        let mut qgan = default_qgan();

        // The default configuration already uses Gaussian noise, so the
        // first round matches sampling before any reconfiguration.
        let noise_types = [
            NoiseType::Gaussian,
            NoiseType::Uniform,
            NoiseType::QuantumSuperposition,
        ];
        for &noise_type in noise_types.iter() {
            qgan.config.noise_type = noise_type;
            let noise = qgan.sample_noise().unwrap();
            assert_eq!(noise.len(), qgan.config.latent_qubits);
        }
    }

    /// Generated data has one row per requested sample and one column per
    /// data qubit.
    #[test]
    fn test_data_generation() {
        let mut qgan = default_qgan();

        let generated_data = qgan.generate_data(5).unwrap();

        assert_eq!(generated_data.nrows(), 5);
        assert_eq!(generated_data.ncols(), qgan.config.data_qubits);
    }

    /// The discriminator returns one score in [0, 1] per input row.
    #[test]
    fn test_discrimination() {
        let qgan = default_qgan();
        let data = Array2::from_shape_fn((3, qgan.config.data_qubits), |(row, col)| {
            (row as f64 + col as f64) / 10.0
        });

        let scores = qgan.discriminate(&data).unwrap();

        assert_eq!(scores.len(), 3);
        for score in scores.iter() {
            assert!((0.0..=1.0).contains(score));
        }
    }

    /// A single training step advances the iteration counter and records
    /// one entry per loss history.
    #[test]
    fn test_qgan_training_step() {
        let mut qgan = QGAN::new(QGANConfig {
            batch_size: 4,
            ..Default::default()
        })
        .unwrap();
        let real_data = Array2::from_shape_fn((10, qgan.config.data_qubits), |(row, col)| {
            ((row + col) as f64).sin()
        });

        let metrics = qgan.train(&real_data).unwrap();

        assert_eq!(metrics.iteration, 0);
        assert!((0.0..=1.0).contains(&metrics.fidelity));
        assert_eq!(qgan.iteration, 1);
        assert_eq!(qgan.training_stats.generator_losses.len(), 1);
        assert_eq!(qgan.training_stats.discriminator_losses.len(), 1);
    }

    /// A constant fidelity of 0.95 converges for tolerance 0.1 but not for
    /// tolerance 0.01.
    #[test]
    fn test_convergence_check() {
        let mut qgan = default_qgan();
        qgan.training_stats.fidelities = vec![0.95; 10];

        assert!(qgan.has_converged(0.1, 5));
        assert!(!qgan.has_converged(0.01, 5));
    }

    /// The generator circuit has parameters and emits one value per data
    /// qubit.
    #[test]
    fn test_quantum_generator_circuit() {
        let circuit = QuantumGeneratorCircuit::new(3, 2, 4, NoiseType::Gaussian).unwrap();
        let param_count = circuit.get_parameter_count();
        assert!(param_count > 0);

        let noise = Array1::from_vec(vec![0.5, -0.5, 0.0]);
        let parameters = Array1::zeros(param_count);

        let generated_data = circuit.generate_data(&noise, &parameters).unwrap();
        assert_eq!(generated_data.len(), 2);
    }

    /// The discriminator circuit has parameters and scores within [0, 1].
    #[test]
    fn test_quantum_discriminator_circuit() {
        let circuit = QuantumDiscriminatorCircuit::new(3, 2, 4).unwrap();
        let param_count = circuit.get_parameter_count();
        assert!(param_count > 0);

        let data = Array1::from_vec(vec![0.5, -0.5, 0.0]);
        let parameters = Array1::zeros(param_count);

        let score = circuit.discriminate_data(&data, &parameters).unwrap();
        assert!((0.0..=1.0).contains(&score));
    }

    /// Identical batches score a high fidelity; differing batches score
    /// lower.
    #[test]
    fn test_fidelity_computation() {
        let qgan = default_qgan();

        let ones_a = Array2::ones((5, 3));
        let ones_b = Array2::ones((5, 3));
        let same_fidelity = qgan.compute_fidelity(&ones_a, &ones_b).unwrap();
        assert!(same_fidelity > 0.9);

        let zeros = Array2::zeros((5, 3));
        let different_fidelity = qgan.compute_fidelity(&ones_a, &zeros).unwrap();
        assert!(different_fidelity < same_fidelity);
    }
}