1use crate::{
8 error::QuantRS2Result, gate::multi::*, gate::single::*, gate::GateOp, qubit::QubitId,
9 variational::VariationalOptimizer,
10};
11use scirs2_core::ndarray::{Array1, Array2, Axis};
12use scirs2_core::random::{rngs::StdRng, Rng, SeedableRng};
13use serde::{Deserialize, Serialize};
14use std::collections::VecDeque;
15
16#[derive(Debug, Clone, Serialize, Deserialize)]
18pub struct QGANConfig {
19 pub generator_qubits: usize,
21 pub discriminator_qubits: usize,
23 pub latent_qubits: usize,
25 pub data_qubits: usize,
27 pub generator_lr: f64,
29 pub discriminator_lr: f64,
31 pub generator_depth: usize,
33 pub discriminator_depth: usize,
35 pub batch_size: usize,
37 pub max_iterations: usize,
39 pub generator_frequency: usize,
41 pub use_quantum_advantage: bool,
43 pub noise_type: NoiseType,
45 pub regularization: f64,
47 pub random_seed: Option<u64>,
49}
50
51impl Default for QGANConfig {
52 fn default() -> Self {
53 Self {
54 generator_qubits: 6,
55 discriminator_qubits: 6,
56 latent_qubits: 4,
57 data_qubits: 4,
58 generator_lr: 0.01,
59 discriminator_lr: 0.01,
60 generator_depth: 8,
61 discriminator_depth: 6,
62 batch_size: 16,
63 max_iterations: 1000,
64 generator_frequency: 1,
65 use_quantum_advantage: true,
66 noise_type: NoiseType::Gaussian,
67 regularization: 0.001,
68 random_seed: None,
69 }
70 }
71}
72
73#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
75pub enum NoiseType {
76 Gaussian,
78 Uniform,
80 QuantumSuperposition,
82 BasisStates,
84}
85
86pub struct QGAN {
88 config: QGANConfig,
90 generator: QuantumGenerator,
92 discriminator: QuantumDiscriminator,
94 training_stats: QGANTrainingStats,
96 rng: StdRng,
98 iteration: usize,
100}
101
102pub struct QuantumGenerator {
104 circuit: QuantumGeneratorCircuit,
106 parameters: Array1<f64>,
108 #[allow(dead_code)]
110 optimizer: VariationalOptimizer,
111 gradient_history: VecDeque<Array1<f64>>,
113}
114
115pub struct QuantumDiscriminator {
117 circuit: QuantumDiscriminatorCircuit,
119 parameters: Array1<f64>,
121 #[allow(dead_code)]
123 optimizer: VariationalOptimizer,
124 gradient_history: VecDeque<Array1<f64>>,
126}
127
128#[derive(Debug, Clone)]
130pub struct QuantumGeneratorCircuit {
131 latent_qubits: usize,
133 data_qubits: usize,
135 depth: usize,
137 total_qubits: usize,
139 noise_type: NoiseType,
141}
142
/// Layout of the discriminator ansatz: RY encoding of the sample on the data
/// qubits followed by `depth` layers of rotations and CRZ entanglers.
#[derive(Clone)]
#[derive(Debug)]
pub struct QuantumDiscriminatorCircuit {
    /// Qubits that encode the input sample.
    data_qubits: usize,
    /// Extra qubits beyond the data register.
    aux_qubits: usize,
    /// Number of variational layers.
    depth: usize,
    /// `data_qubits + aux_qubits`, computed at construction.
    total_qubits: usize,
}
155
/// Per-iteration history of a QGAN training run.
#[derive(Default)]
#[derive(Debug, Clone)]
pub struct QGANTrainingStats {
    /// Generator loss per iteration; `0.0` on iterations where the generator
    /// was not updated.
    pub generator_losses: Vec<f64>,
    /// Sum of the real-batch and fake-batch discriminator losses per iteration.
    pub discriminator_losses: Vec<f64>,
    /// Moment-matching similarity between real and fake batches per iteration.
    pub fidelities: Vec<f64>,
    /// Per-iteration wall-clock timings.
    pub iteration_times: Vec<f64>,
    /// Convergence diagnostics.
    /// NOTE(review): not populated by `QGAN::train`; confirm intended contents.
    pub convergence_metrics: Vec<f64>,
}
170
/// Metrics reported by a single `QGAN::train` call.
#[derive(Clone)]
#[derive(Debug)]
pub struct QGANIterationMetrics {
    /// Mean squared error between discriminator scores and the target `1.0`
    /// on the generator batch; `0.0` if the generator was not trained.
    pub generator_loss: f64,
    /// Sum of the mean binary cross-entropy on the real and fake batches.
    pub discriminator_loss: f64,
    /// Fraction of real samples scored above `0.5`.
    pub real_accuracy: f64,
    /// Fraction of fake samples not scored above `0.5`.
    pub fake_accuracy: f64,
    /// Moment-matching similarity between the real and fake batches.
    pub fidelity: f64,
    /// Zero-based index of the iteration that produced these metrics.
    pub iteration: usize,
}
187
188impl QGAN {
189 pub fn new(config: QGANConfig) -> QuantRS2Result<Self> {
191 let rng = match config.random_seed {
192 Some(seed) => StdRng::seed_from_u64(seed),
193 None => StdRng::from_seed([0; 32]), };
195
196 let generator = QuantumGenerator::new(&config)?;
197 let discriminator = QuantumDiscriminator::new(&config)?;
198
199 Ok(Self {
200 config,
201 generator,
202 discriminator,
203 training_stats: QGANTrainingStats::default(),
204 rng,
205 iteration: 0,
206 })
207 }
208
209 pub fn train(&mut self, real_data: &Array2<f64>) -> QuantRS2Result<QGANIterationMetrics> {
211 let batch_size = self.config.batch_size.min(real_data.nrows());
212
213 let real_batch = self.sample_real_data_batch(real_data, batch_size)?;
215
216 let fake_batch = self.generate_fake_data_batch(batch_size)?;
218
219 let (d_loss_real, d_loss_fake) = self.train_discriminator(&real_batch, &fake_batch)?;
221 let discriminator_loss = d_loss_real + d_loss_fake;
222
223 let generator_loss = if self.iteration % self.config.generator_frequency == 0 {
225 self.train_generator(batch_size)?
226 } else {
227 0.0
228 };
229
230 let real_accuracy = self.compute_discriminator_accuracy(&real_batch, true)?;
232 let fake_accuracy = self.compute_discriminator_accuracy(&fake_batch, false)?;
233 let fidelity = self.compute_fidelity(&real_batch, &fake_batch)?;
234
235 self.training_stats.generator_losses.push(generator_loss);
237 self.training_stats
238 .discriminator_losses
239 .push(discriminator_loss);
240 self.training_stats.fidelities.push(fidelity);
241
242 let metrics = QGANIterationMetrics {
243 generator_loss,
244 discriminator_loss,
245 real_accuracy,
246 fake_accuracy,
247 fidelity,
248 iteration: self.iteration,
249 };
250
251 self.iteration += 1;
252
253 Ok(metrics)
254 }
255
256 pub fn generate_data(&mut self, num_samples: usize) -> QuantRS2Result<Array2<f64>> {
258 self.generate_fake_data_batch(num_samples)
259 }
260
261 pub fn discriminate(&self, data: &Array2<f64>) -> QuantRS2Result<Array1<f64>> {
263 let num_samples = data.nrows();
264 let mut scores = Array1::zeros(num_samples);
265
266 for i in 0..num_samples {
267 let sample = data.row(i).to_owned();
268 scores[i] = self.discriminator.discriminate(&sample)?;
269 }
270
271 Ok(scores)
272 }
273
274 fn sample_real_data_batch(
276 &mut self,
277 real_data: &Array2<f64>,
278 batch_size: usize,
279 ) -> QuantRS2Result<Array2<f64>> {
280 let num_samples = real_data.nrows();
281 let mut batch = Array2::zeros((batch_size, real_data.ncols()));
282
283 for i in 0..batch_size {
284 let idx = self.rng.random_range(0..num_samples);
285 batch.row_mut(i).assign(&real_data.row(idx));
286 }
287
288 Ok(batch)
289 }
290
291 fn generate_fake_data_batch(&mut self, batch_size: usize) -> QuantRS2Result<Array2<f64>> {
293 let mut fake_batch = Array2::zeros((batch_size, self.config.data_qubits));
294
295 for i in 0..batch_size {
296 let noise = self.sample_noise()?;
297 let generated_sample = self.generator.generate(&noise)?;
298 fake_batch.row_mut(i).assign(&generated_sample);
299 }
300
301 Ok(fake_batch)
302 }
303
304 fn sample_noise(&mut self) -> QuantRS2Result<Array1<f64>> {
306 let mut noise = Array1::zeros(self.config.latent_qubits);
307
308 match self.config.noise_type {
309 NoiseType::Gaussian => {
310 for i in 0..self.config.latent_qubits {
311 noise[i] = self.rng.random::<f64>().mul_add(2.0, -1.0); }
313 }
314 NoiseType::Uniform => {
315 for i in 0..self.config.latent_qubits {
316 noise[i] = (self.rng.random::<f64>() * 2.0)
317 .mul_add(std::f64::consts::PI, -std::f64::consts::PI);
318 }
319 }
320 NoiseType::QuantumSuperposition => {
321 for i in 0..self.config.latent_qubits {
323 noise[i] = std::f64::consts::PI / 2.0; }
325 }
326 NoiseType::BasisStates => {
327 let state = self.rng.random_range(0..(1 << self.config.latent_qubits));
329 for i in 0..self.config.latent_qubits {
330 noise[i] = if (state >> i) & 1 == 1 {
331 std::f64::consts::PI
332 } else {
333 0.0
334 };
335 }
336 }
337 }
338
339 Ok(noise)
340 }
341
342 fn train_discriminator(
344 &mut self,
345 real_batch: &Array2<f64>,
346 fake_batch: &Array2<f64>,
347 ) -> QuantRS2Result<(f64, f64)> {
348 let d_loss_real = self
350 .discriminator
351 .train_batch(real_batch, &Array1::ones(real_batch.nrows()))?;
352
353 let d_loss_fake = self
355 .discriminator
356 .train_batch(fake_batch, &Array1::zeros(fake_batch.nrows()))?;
357
358 Ok((d_loss_real, d_loss_fake))
359 }
360
361 fn train_generator(&mut self, batch_size: usize) -> QuantRS2Result<f64> {
363 let fake_batch = self.generate_fake_data_batch(batch_size)?;
365
366 let discriminator_scores = self.discriminate(&fake_batch)?;
368
369 let targets = Array1::ones(batch_size);
371 let generator_loss =
372 self.generator
373 .train_adversarial(&fake_batch, &targets, &discriminator_scores)?;
374
375 Ok(generator_loss)
376 }
377
378 fn compute_discriminator_accuracy(
380 &self,
381 data: &Array2<f64>,
382 is_real: bool,
383 ) -> QuantRS2Result<f64> {
384 let scores = self.discriminate(data)?;
385 let threshold = 0.5;
386 let _target = if is_real { 1.0 } else { 0.0 };
387
388 let correct = scores
389 .iter()
390 .map(|&score| {
391 if (score > threshold) == is_real {
392 1.0
393 } else {
394 0.0
395 }
396 })
397 .sum::<f64>();
398
399 Ok(correct / data.nrows() as f64)
400 }
401
402 fn compute_fidelity(
404 &self,
405 real_batch: &Array2<f64>,
406 fake_batch: &Array2<f64>,
407 ) -> QuantRS2Result<f64> {
408 use crate::error::QuantRS2Error;
409 let real_mean = real_batch
411 .mean_axis(Axis(0))
412 .ok_or_else(|| QuantRS2Error::InvalidInput("Empty real batch".to_string()))?;
413 let fake_mean = fake_batch
414 .mean_axis(Axis(0))
415 .ok_or_else(|| QuantRS2Error::InvalidInput("Empty fake batch".to_string()))?;
416
417 let real_var = real_batch.var_axis(Axis(0), 0.0);
418 let fake_var = fake_batch.var_axis(Axis(0), 0.0);
419
420 let mean_diff = (&real_mean - &fake_mean).mapv(|x| x.powi(2)).sum().sqrt();
422 let var_diff = (&real_var - &fake_var).mapv(|x| x.powi(2)).sum().sqrt();
423
424 let fidelity = (-0.5 * (mean_diff + var_diff)).exp();
425
426 Ok(fidelity)
427 }
428
429 pub const fn get_training_stats(&self) -> &QGANTrainingStats {
431 &self.training_stats
432 }
433
434 pub fn has_converged(&self, tolerance: f64, window: usize) -> bool {
436 if self.training_stats.fidelities.len() < window {
437 return false;
438 }
439
440 let recent_fidelities =
441 &self.training_stats.fidelities[self.training_stats.fidelities.len() - window..];
442 let mean_fidelity = recent_fidelities.iter().sum::<f64>() / window as f64;
443
444 mean_fidelity > 1.0 - tolerance
445 }
446}
447
448impl QuantumGenerator {
449 fn new(config: &QGANConfig) -> QuantRS2Result<Self> {
451 let circuit = QuantumGeneratorCircuit::new(
452 config.latent_qubits,
453 config.data_qubits,
454 config.generator_depth,
455 config.noise_type,
456 )?;
457
458 let num_parameters = circuit.get_parameter_count();
459 let mut parameters = Array1::zeros(num_parameters);
460
461 let mut rng = match config.random_seed {
463 Some(seed) => StdRng::seed_from_u64(seed),
464 None => StdRng::from_seed([0; 32]),
465 };
466
467 for param in &mut parameters {
468 *param = rng.random_range(-std::f64::consts::PI..std::f64::consts::PI);
469 }
470
471 let optimizer = VariationalOptimizer::new(0.01, 0.9);
472 let gradient_history = VecDeque::with_capacity(10);
473
474 Ok(Self {
475 circuit,
476 parameters,
477 optimizer,
478 gradient_history,
479 })
480 }
481
482 fn generate(&self, noise: &Array1<f64>) -> QuantRS2Result<Array1<f64>> {
484 self.circuit.generate_data(noise, &self.parameters)
485 }
486
487 fn train_adversarial(
489 &mut self,
490 generated_data: &Array2<f64>,
491 targets: &Array1<f64>,
492 discriminator_scores: &Array1<f64>,
493 ) -> QuantRS2Result<f64> {
494 let batch_size = generated_data.nrows();
495 let mut total_loss = 0.0;
496
497 for i in 0..batch_size {
498 let data_sample = generated_data.row(i).to_owned();
499 let target = targets[i];
500 let score = discriminator_scores[i];
501
502 let loss = (score - target).powi(2);
504 total_loss += loss;
505
506 let gradients = self.circuit.compute_adversarial_gradients(
508 &data_sample,
509 target,
510 score,
511 &self.parameters,
512 )?;
513 self.update_parameters(&gradients, 0.01)?; }
515
516 Ok(total_loss / batch_size as f64)
517 }
518
519 fn update_parameters(
521 &mut self,
522 gradients: &Array1<f64>,
523 learning_rate: f64,
524 ) -> QuantRS2Result<()> {
525 let mut effective_gradients = gradients.clone();
527
528 if let Some(prev_gradients) = self.gradient_history.back() {
529 let momentum = 0.9;
530 effective_gradients = &effective_gradients + &(prev_gradients * momentum);
531 }
532
533 for (param, &grad) in self.parameters.iter_mut().zip(effective_gradients.iter()) {
535 *param -= learning_rate * grad;
536 }
537
538 self.gradient_history.push_back(effective_gradients);
540 if self.gradient_history.len() > 10 {
541 self.gradient_history.pop_front();
542 }
543
544 Ok(())
545 }
546}
547
548impl QuantumDiscriminator {
549 fn new(config: &QGANConfig) -> QuantRS2Result<Self> {
551 let circuit = QuantumDiscriminatorCircuit::new(
552 config.data_qubits,
553 config.discriminator_qubits - config.data_qubits,
554 config.discriminator_depth,
555 )?;
556
557 let num_parameters = circuit.get_parameter_count();
558 let mut parameters = Array1::zeros(num_parameters);
559
560 let mut rng = match config.random_seed {
562 Some(seed) => StdRng::seed_from_u64(seed),
563 None => StdRng::from_seed([0; 32]),
564 };
565
566 for param in &mut parameters {
567 *param = rng.random_range(-std::f64::consts::PI..std::f64::consts::PI);
568 }
569
570 let optimizer = VariationalOptimizer::new(0.01, 0.9);
571 let gradient_history = VecDeque::with_capacity(10);
572
573 Ok(Self {
574 circuit,
575 parameters,
576 optimizer,
577 gradient_history,
578 })
579 }
580
581 fn discriminate(&self, data: &Array1<f64>) -> QuantRS2Result<f64> {
583 self.circuit.discriminate_data(data, &self.parameters)
584 }
585
586 fn train_batch(
588 &mut self,
589 data_batch: &Array2<f64>,
590 targets: &Array1<f64>,
591 ) -> QuantRS2Result<f64> {
592 let batch_size = data_batch.nrows();
593 let mut total_loss = 0.0;
594
595 for i in 0..batch_size {
596 let data_sample = data_batch.row(i).to_owned();
597 let target = targets[i];
598
599 let prediction = self.discriminate(&data_sample)?;
601
602 let loss = -target.mul_add(prediction.ln(), (1.0 - target) * (1.0 - prediction).ln());
604 total_loss += loss;
605
606 let gradients = self.circuit.compute_discriminator_gradients(
608 &data_sample,
609 target,
610 prediction,
611 &self.parameters,
612 )?;
613 self.update_parameters(&gradients, 0.01)?; }
615
616 Ok(total_loss / batch_size as f64)
617 }
618
619 fn update_parameters(
621 &mut self,
622 gradients: &Array1<f64>,
623 learning_rate: f64,
624 ) -> QuantRS2Result<()> {
625 let mut effective_gradients = gradients.clone();
627
628 if let Some(prev_gradients) = self.gradient_history.back() {
629 let momentum = 0.9;
630 effective_gradients = &effective_gradients + &(prev_gradients * momentum);
631 }
632
633 for (param, &grad) in self.parameters.iter_mut().zip(effective_gradients.iter()) {
635 *param -= learning_rate * grad;
636 }
637
638 self.gradient_history.push_back(effective_gradients);
640 if self.gradient_history.len() > 10 {
641 self.gradient_history.pop_front();
642 }
643
644 Ok(())
645 }
646}
647
648impl QuantumGeneratorCircuit {
649 const fn new(
651 latent_qubits: usize,
652 data_qubits: usize,
653 depth: usize,
654 noise_type: NoiseType,
655 ) -> QuantRS2Result<Self> {
656 let total_qubits = latent_qubits + data_qubits;
657
658 Ok(Self {
659 latent_qubits,
660 data_qubits,
661 depth,
662 total_qubits,
663 noise_type,
664 })
665 }
666
667 const fn get_parameter_count(&self) -> usize {
669 let total_qubits = self.latent_qubits + self.data_qubits;
670 let rotations_per_layer = total_qubits * 3;
672 let entangling_per_layer = total_qubits; self.depth * (rotations_per_layer + entangling_per_layer)
674 }
675
676 fn generate_data(
678 &self,
679 noise: &Array1<f64>,
680 parameters: &Array1<f64>,
681 ) -> QuantRS2Result<Array1<f64>> {
682 let mut gates = Vec::new();
684
685 for i in 0..self.latent_qubits {
687 let noise_value = if i < noise.len() { noise[i] } else { 0.0 };
688 gates.push(Box::new(RotationY {
689 target: QubitId(i as u32),
690 theta: noise_value,
691 }) as Box<dyn GateOp>);
692 }
693
694 let mut param_idx = 0;
696 for _layer in 0..self.depth {
697 for qubit in 0..self.latent_qubits + self.data_qubits {
699 if param_idx + 2 < parameters.len() {
700 gates.push(Box::new(RotationX {
701 target: QubitId(qubit as u32),
702 theta: parameters[param_idx],
703 }) as Box<dyn GateOp>);
704 param_idx += 1;
705
706 gates.push(Box::new(RotationY {
707 target: QubitId(qubit as u32),
708 theta: parameters[param_idx],
709 }) as Box<dyn GateOp>);
710 param_idx += 1;
711
712 gates.push(Box::new(RotationZ {
713 target: QubitId(qubit as u32),
714 theta: parameters[param_idx],
715 }) as Box<dyn GateOp>);
716 param_idx += 1;
717 }
718 }
719
720 for qubit in 0..self.latent_qubits + self.data_qubits - 1 {
722 if param_idx < parameters.len() {
723 gates.push(Box::new(CRZ {
724 control: QubitId(qubit as u32),
725 target: QubitId((qubit + 1) as u32),
726 theta: parameters[param_idx],
727 }) as Box<dyn GateOp>);
728 param_idx += 1;
729 }
730 }
731 }
732
733 let generated_data = self.simulate_generation_circuit(&gates)?;
735
736 Ok(generated_data)
737 }
738
739 fn simulate_generation_circuit(
741 &self,
742 gates: &[Box<dyn GateOp>],
743 ) -> QuantRS2Result<Array1<f64>> {
744 let mut data = Array1::zeros(self.data_qubits);
746
747 let mut hash_value = 0u64;
748 for gate in gates {
749 if let Ok(matrix) = gate.matrix() {
750 for complex in &matrix {
751 hash_value = hash_value.wrapping_add((complex.re * 1000.0) as u64);
752 }
753 }
754 }
755
756 for i in 0..self.data_qubits {
758 let qubit_hash = hash_value.wrapping_add(i as u64);
759 data[i] = ((qubit_hash % 1000) as f64 / 1000.0).mul_add(2.0, -1.0); }
761
762 Ok(data)
763 }
764
765 fn compute_adversarial_gradients(
767 &self,
768 _data_sample: &Array1<f64>,
769 target: f64,
770 score: f64,
771 parameters: &Array1<f64>,
772 ) -> QuantRS2Result<Array1<f64>> {
773 let mut gradients = Array1::zeros(parameters.len());
774 let shift = std::f64::consts::PI / 2.0;
775
776 for i in 0..parameters.len() {
777 let mut params_plus = parameters.clone();
779 params_plus[i] += shift;
780 let data_plus = self.generate_data(&Array1::zeros(self.latent_qubits), ¶ms_plus)?;
781
782 let mut params_minus = parameters.clone();
783 params_minus[i] -= shift;
784 let data_minus =
785 self.generate_data(&Array1::zeros(self.latent_qubits), ¶ms_minus)?;
786
787 let loss_gradient = 2.0 * (score - target);
789
790 let data_diff = (&data_plus - &data_minus).sum() / 2.0;
792
793 gradients[i] = loss_gradient * data_diff;
794 }
795
796 Ok(gradients)
797 }
798}
799
800impl QuantumDiscriminatorCircuit {
801 const fn new(data_qubits: usize, aux_qubits: usize, depth: usize) -> QuantRS2Result<Self> {
803 let total_qubits = data_qubits + aux_qubits;
804
805 Ok(Self {
806 data_qubits,
807 aux_qubits,
808 depth,
809 total_qubits,
810 })
811 }
812
813 const fn get_parameter_count(&self) -> usize {
815 let total_qubits = self.data_qubits + self.aux_qubits;
816 let rotations_per_layer = total_qubits * 3;
817 let entangling_per_layer = total_qubits;
818 self.depth * (rotations_per_layer + entangling_per_layer)
819 }
820
821 fn discriminate_data(
823 &self,
824 data: &Array1<f64>,
825 parameters: &Array1<f64>,
826 ) -> QuantRS2Result<f64> {
827 let mut gates = Vec::new();
829
830 for i in 0..self.data_qubits {
832 let data_value = if i < data.len() { data[i] } else { 0.0 };
833 gates.push(Box::new(RotationY {
834 target: QubitId(i as u32),
835 theta: data_value * std::f64::consts::PI,
836 }) as Box<dyn GateOp>);
837 }
838
839 let mut param_idx = 0;
841 for _layer in 0..self.depth {
842 for qubit in 0..self.data_qubits + self.aux_qubits {
844 if param_idx + 2 < parameters.len() {
845 gates.push(Box::new(RotationX {
846 target: QubitId(qubit as u32),
847 theta: parameters[param_idx],
848 }) as Box<dyn GateOp>);
849 param_idx += 1;
850
851 gates.push(Box::new(RotationY {
852 target: QubitId(qubit as u32),
853 theta: parameters[param_idx],
854 }) as Box<dyn GateOp>);
855 param_idx += 1;
856
857 gates.push(Box::new(RotationZ {
858 target: QubitId(qubit as u32),
859 theta: parameters[param_idx],
860 }) as Box<dyn GateOp>);
861 param_idx += 1;
862 }
863 }
864
865 for qubit in 0..self.data_qubits + self.aux_qubits - 1 {
867 if param_idx < parameters.len() {
868 gates.push(Box::new(CRZ {
869 control: QubitId(qubit as u32),
870 target: QubitId((qubit + 1) as u32),
871 theta: parameters[param_idx],
872 }) as Box<dyn GateOp>);
873 param_idx += 1;
874 }
875 }
876 }
877
878 let probability = self.simulate_discrimination_circuit(&gates)?;
880
881 Ok(probability)
882 }
883
884 fn simulate_discrimination_circuit(&self, gates: &[Box<dyn GateOp>]) -> QuantRS2Result<f64> {
886 let mut hash_value = 0u64;
888
889 for gate in gates {
890 if let Ok(matrix) = gate.matrix() {
891 for complex in &matrix {
892 hash_value = hash_value.wrapping_add((complex.re * 1000.0) as u64);
893 hash_value = hash_value.wrapping_add((complex.im * 1000.0) as u64);
894 }
895 }
896 }
897
898 let probability = ((hash_value % 1000) as f64) / 1000.0;
900
901 Ok(probability)
902 }
903
904 fn compute_discriminator_gradients(
906 &self,
907 data_sample: &Array1<f64>,
908 target: f64,
909 prediction: f64,
910 parameters: &Array1<f64>,
911 ) -> QuantRS2Result<Array1<f64>> {
912 let mut gradients = Array1::zeros(parameters.len());
913 let shift = std::f64::consts::PI / 2.0;
914
915 for i in 0..parameters.len() {
916 let mut params_plus = parameters.clone();
918 params_plus[i] += shift;
919 let pred_plus = self.discriminate_data(data_sample, ¶ms_plus)?;
920
921 let mut params_minus = parameters.clone();
922 params_minus[i] -= shift;
923 let pred_minus = self.discriminate_data(data_sample, ¶ms_minus)?;
924
925 let pred_gradient = if prediction > 0.0 && prediction < 1.0 {
927 -target / prediction + (1.0 - target) / (1.0 - prediction)
928 } else {
929 0.0 };
931
932 gradients[i] = pred_gradient * (pred_plus - pred_minus) / 2.0;
933 }
934
935 Ok(gradients)
936 }
937}
938
#[cfg(test)]
mod tests {
    use super::*;
    use std::f64::consts::PI;

    #[test]
    fn test_qgan_creation() {
        let config = QGANConfig::default();
        let qgan = QGAN::new(config).expect("failed to create QGAN");

        assert_eq!(qgan.iteration, 0);
        assert!(qgan.training_stats.generator_losses.is_empty());
    }

    #[test]
    fn test_noise_generation() {
        let config = QGANConfig::default();
        let mut qgan = QGAN::new(config).expect("failed to create QGAN");

        let noise = qgan.sample_noise().expect("failed to sample noise");
        assert_eq!(noise.len(), qgan.config.latent_qubits);
        assert!(noise.iter().all(|x| x.is_finite()));

        qgan.config.noise_type = NoiseType::Uniform;
        let uniform_noise = qgan.sample_noise().expect("failed to sample uniform noise");
        assert_eq!(uniform_noise.len(), qgan.config.latent_qubits);
        assert!(uniform_noise.iter().all(|x| (-PI..=PI).contains(x)));

        qgan.config.noise_type = NoiseType::QuantumSuperposition;
        let quantum_noise = qgan.sample_noise().expect("failed to sample quantum noise");
        assert_eq!(quantum_noise.len(), qgan.config.latent_qubits);
        assert!(quantum_noise.iter().all(|&x| x == PI / 2.0));

        qgan.config.noise_type = NoiseType::BasisStates;
        let basis_noise = qgan.sample_noise().expect("failed to sample basis noise");
        assert_eq!(basis_noise.len(), qgan.config.latent_qubits);
        assert!(basis_noise.iter().all(|&x| x == 0.0 || x == PI));
    }

    #[test]
    fn test_data_generation() {
        let config = QGANConfig::default();
        let mut qgan = QGAN::new(config).expect("failed to create QGAN");

        let generated_data = qgan.generate_data(5).expect("failed to generate data");
        assert_eq!(generated_data.nrows(), 5);
        assert_eq!(generated_data.ncols(), qgan.config.data_qubits);
    }

    #[test]
    fn test_seeded_generation_is_reproducible() {
        let config = QGANConfig {
            random_seed: Some(42),
            ..Default::default()
        };
        let mut first = QGAN::new(config.clone()).expect("failed to create QGAN");
        let mut second = QGAN::new(config).expect("failed to create QGAN");

        let a = first.generate_data(3).expect("failed to generate data");
        let b = second.generate_data(3).expect("failed to generate data");
        assert_eq!(a, b);
    }

    #[test]
    fn test_discrimination() {
        let config = QGANConfig::default();
        let qgan = QGAN::new(config).expect("failed to create QGAN");

        let data = Array2::from_shape_fn((3, qgan.config.data_qubits), |(i, j)| {
            (i as f64 + j as f64) / 10.0
        });

        let scores = qgan.discriminate(&data).expect("failed to discriminate");
        assert_eq!(scores.len(), 3);
        assert!(scores.iter().all(|score| (0.0..=1.0).contains(score)));
    }

    #[test]
    fn test_qgan_training_step() {
        let config = QGANConfig {
            batch_size: 4,
            ..Default::default()
        };
        let mut qgan = QGAN::new(config).expect("failed to create QGAN");

        let real_data = Array2::from_shape_fn((10, qgan.config.data_qubits), |(i, j)| {
            ((i + j) as f64).sin()
        });

        let metrics = qgan.train(&real_data).expect("failed to train QGAN");

        assert_eq!(metrics.iteration, 0);
        assert!((0.0..=1.0).contains(&metrics.fidelity));
        assert_eq!(qgan.iteration, 1);
        assert_eq!(qgan.training_stats.generator_losses.len(), 1);
        assert_eq!(qgan.training_stats.discriminator_losses.len(), 1);
    }

    #[test]
    fn test_convergence_check() {
        let config = QGANConfig::default();
        let mut qgan = QGAN::new(config).expect("failed to create QGAN");

        qgan.training_stats.fidelities.extend(std::iter::repeat(0.95).take(10));

        // Mean 0.95 exceeds 1 - 0.1 but not 1 - 0.01.
        assert!(qgan.has_converged(0.1, 5));
        assert!(!qgan.has_converged(0.01, 5));
        // Not enough history for a window of 20.
        assert!(!qgan.has_converged(0.1, 20));
    }

    #[test]
    fn test_quantum_generator_circuit() {
        let circuit = QuantumGeneratorCircuit::new(3, 2, 4, NoiseType::Gaussian)
            .expect("failed to create generator circuit");
        let param_count = circuit.get_parameter_count();
        assert!(param_count > 0);

        let noise = Array1::from_vec(vec![0.5, -0.5, 0.0]);
        let parameters = Array1::zeros(param_count);

        let generated_data = circuit
            .generate_data(&noise, &parameters)
            .expect("failed to generate data");
        assert_eq!(generated_data.len(), 2);
    }

    #[test]
    fn test_quantum_discriminator_circuit() {
        let circuit = QuantumDiscriminatorCircuit::new(3, 2, 4)
            .expect("failed to create discriminator circuit");
        let param_count = circuit.get_parameter_count();
        assert!(param_count > 0);

        let data = Array1::from_vec(vec![0.5, -0.5, 0.0]);
        let parameters = Array1::zeros(param_count);

        let score = circuit
            .discriminate_data(&data, &parameters)
            .expect("failed to discriminate data");
        assert!((0.0..=1.0).contains(&score));
    }

    #[test]
    fn test_fidelity_computation() {
        let config = QGANConfig::default();
        let qgan = QGAN::new(config).expect("failed to create QGAN");

        let data1 = Array2::ones((5, 3));
        let data2 = Array2::ones((5, 3));
        let fidelity = qgan
            .compute_fidelity(&data1, &data2)
            .expect("failed to compute fidelity");
        assert!(fidelity > 0.9);

        let data3 = Array2::zeros((5, 3));
        let fidelity2 = qgan
            .compute_fidelity(&data1, &data3)
            .expect("failed to compute fidelity");
        assert!(fidelity2 < fidelity);
    }
}