quantum_gan/
quantum_gan.rs1#![allow(clippy::pedantic, clippy::unnecessary_wraps)]
2use quantrs2_ml::gan::{DiscriminatorType, GANEvaluationMetrics, GeneratorType, QuantumGAN};
3use quantrs2_ml::prelude::*;
4use scirs2_core::ndarray::{Array1, Array2};
5use scirs2_core::random::prelude::*;
6use std::time::Instant;
7
8fn main() -> Result<()> {
9 println!("Quantum Generative Adversarial Network Example");
10 println!("=============================================");
11
12 let num_qubits_gen = 6;
14 let num_qubits_disc = 6;
15 let latent_dim = 4;
16 let data_dim = 8;
17
18 println!("Creating Quantum GAN...");
19 println!(" Generator: {num_qubits_gen} qubits");
20 println!(" Discriminator: {num_qubits_disc} qubits");
21 println!(" Latent dimension: {latent_dim}");
22 println!(" Data dimension: {data_dim}");
23
24 let mut qgan = QuantumGAN::new(
26 num_qubits_gen,
27 num_qubits_disc,
28 latent_dim,
29 data_dim,
30 GeneratorType::HybridClassicalQuantum,
31 DiscriminatorType::HybridQuantumFeatures,
32 )?;
33
34 println!("Generating synthetic data for training...");
36 let real_data = generate_sine_wave_data(500, data_dim);
37
38 println!("Training quantum GAN...");
40 let training_params = [
41 (50, 32, 0.01, 0.01, 1), ];
43
44 for (epochs, batch_size, lr_gen, lr_disc, disc_steps) in training_params {
45 println!("Training with parameters:");
46 println!(" Epochs: {epochs}");
47 println!(" Batch size: {batch_size}");
48 println!(" Generator learning rate: {lr_gen}");
49 println!(" Discriminator learning rate: {lr_disc}");
50 println!(" Discriminator steps per iteration: {disc_steps}");
51
52 let start = Instant::now();
53 let history = qgan.train(&real_data, epochs, batch_size, lr_gen, lr_disc, disc_steps)?;
54
55 println!("Training completed in {:.2?}", start.elapsed());
56 println!("Final losses:");
57 println!(
58 " Generator: {:.4}",
59 history.gen_losses.last().unwrap_or(&0.0)
60 );
61 println!(
62 " Discriminator: {:.4}",
63 history.disc_losses.last().unwrap_or(&0.0)
64 );
65 }
66
67 println!("\nGenerating samples from trained GAN...");
69 let num_samples = 10;
70 let generated_samples = qgan.generate(num_samples)?;
71
72 println!("Generated {num_samples} samples");
73 println!("First sample:");
74 print_sample(
75 &generated_samples
76 .slice(scirs2_core::ndarray::s![0, ..])
77 .to_owned(),
78 );
79
80 println!("\nEvaluating GAN quality...");
82 let eval_metrics = qgan.evaluate(&real_data, num_samples)?;
83
84 println!("Evaluation metrics:");
85 println!(
86 " Real data accuracy: {:.2}%",
87 eval_metrics.real_accuracy * 100.0
88 );
89 println!(
90 " Fake data accuracy: {:.2}%",
91 eval_metrics.fake_accuracy * 100.0
92 );
93 println!(
94 " Overall discriminator accuracy: {:.2}%",
95 eval_metrics.overall_accuracy * 100.0
96 );
97 println!(" JS Divergence: {:.4}", eval_metrics.js_divergence);
98
99 println!("\nCreating specialized particle physics GAN...");
101 let particle_gan = quantrs2_ml::gan::physics_gan::ParticleGAN::new(
102 num_qubits_gen,
103 num_qubits_disc,
104 latent_dim,
105 data_dim,
106 )?;
107
108 println!("Particle GAN created successfully");
109
110 Ok(())
111}
112
113fn generate_sine_wave_data(num_samples: usize, data_dim: usize) -> Array2<f64> {
115 let mut data = Array2::zeros((num_samples, data_dim));
116
117 for i in 0..num_samples {
118 let x = (i as f64) / (num_samples as f64) * 2.0 * std::f64::consts::PI;
119
120 for j in 0..data_dim {
121 let freq = (j as f64 + 1.0) * 0.5;
122 data[[i, j]] = 0.1f64.mul_add(thread_rng().gen::<f64>(), (x * freq).sin());
123 }
124 }
125
126 data
127}
128
129fn print_sample(sample: &Array1<f64>) {
131 print!(" [");
132 for (i, &val) in sample.iter().enumerate() {
133 if i > 0 {
134 print!(", ");
135 }
136 print!("{val:.4}");
137 }
138 println!("]");
139}