quantum_gan/
quantum_gan.rs1use quantrs2_ml::gan::{DiscriminatorType, GANEvaluationMetrics, GeneratorType, QuantumGAN};
2use quantrs2_ml::prelude::*;
3use scirs2_core::ndarray::{Array1, Array2};
4use scirs2_core::random::prelude::*;
5use std::time::Instant;
6
7fn main() -> Result<()> {
8 println!("Quantum Generative Adversarial Network Example");
9 println!("=============================================");
10
11 let num_qubits_gen = 6;
13 let num_qubits_disc = 6;
14 let latent_dim = 4;
15 let data_dim = 8;
16
17 println!("Creating Quantum GAN...");
18 println!(" Generator: {num_qubits_gen} qubits");
19 println!(" Discriminator: {num_qubits_disc} qubits");
20 println!(" Latent dimension: {latent_dim}");
21 println!(" Data dimension: {data_dim}");
22
23 let mut qgan = QuantumGAN::new(
25 num_qubits_gen,
26 num_qubits_disc,
27 latent_dim,
28 data_dim,
29 GeneratorType::HybridClassicalQuantum,
30 DiscriminatorType::HybridQuantumFeatures,
31 )?;
32
33 println!("Generating synthetic data for training...");
35 let real_data = generate_sine_wave_data(500, data_dim);
36
37 println!("Training quantum GAN...");
39 let training_params = [
40 (50, 32, 0.01, 0.01, 1), ];
42
43 for (epochs, batch_size, lr_gen, lr_disc, disc_steps) in training_params {
44 println!("Training with parameters:");
45 println!(" Epochs: {epochs}");
46 println!(" Batch size: {batch_size}");
47 println!(" Generator learning rate: {lr_gen}");
48 println!(" Discriminator learning rate: {lr_disc}");
49 println!(" Discriminator steps per iteration: {disc_steps}");
50
51 let start = Instant::now();
52 let history = qgan.train(&real_data, epochs, batch_size, lr_gen, lr_disc, disc_steps)?;
53
54 println!("Training completed in {:.2?}", start.elapsed());
55 println!("Final losses:");
56 println!(
57 " Generator: {:.4}",
58 history.gen_losses.last().unwrap_or(&0.0)
59 );
60 println!(
61 " Discriminator: {:.4}",
62 history.disc_losses.last().unwrap_or(&0.0)
63 );
64 }
65
66 println!("\nGenerating samples from trained GAN...");
68 let num_samples = 10;
69 let generated_samples = qgan.generate(num_samples)?;
70
71 println!("Generated {num_samples} samples");
72 println!("First sample:");
73 print_sample(
74 &generated_samples
75 .slice(scirs2_core::ndarray::s![0, ..])
76 .to_owned(),
77 );
78
79 println!("\nEvaluating GAN quality...");
81 let eval_metrics = qgan.evaluate(&real_data, num_samples)?;
82
83 println!("Evaluation metrics:");
84 println!(
85 " Real data accuracy: {:.2}%",
86 eval_metrics.real_accuracy * 100.0
87 );
88 println!(
89 " Fake data accuracy: {:.2}%",
90 eval_metrics.fake_accuracy * 100.0
91 );
92 println!(
93 " Overall discriminator accuracy: {:.2}%",
94 eval_metrics.overall_accuracy * 100.0
95 );
96 println!(" JS Divergence: {:.4}", eval_metrics.js_divergence);
97
98 println!("\nCreating specialized particle physics GAN...");
100 let particle_gan = quantrs2_ml::gan::physics_gan::ParticleGAN::new(
101 num_qubits_gen,
102 num_qubits_disc,
103 latent_dim,
104 data_dim,
105 )?;
106
107 println!("Particle GAN created successfully");
108
109 Ok(())
110}
111
112fn generate_sine_wave_data(num_samples: usize, data_dim: usize) -> Array2<f64> {
114 let mut data = Array2::zeros((num_samples, data_dim));
115
116 for i in 0..num_samples {
117 let x = (i as f64) / (num_samples as f64) * 2.0 * std::f64::consts::PI;
118
119 for j in 0..data_dim {
120 let freq = (j as f64 + 1.0) * 0.5;
121 data[[i, j]] = 0.1f64.mul_add(thread_rng().gen::<f64>(), (x * freq).sin());
122 }
123 }
124
125 data
126}
127
128fn print_sample(sample: &Array1<f64>) {
130 print!(" [");
131 for (i, &val) in sample.iter().enumerate() {
132 if i > 0 {
133 print!(", ");
134 }
135 print!("{val:.4}");
136 }
137 println!("]");
138}