quantum_gan/
quantum_gan.rs1#![allow(
2 clippy::pedantic,
3 clippy::unnecessary_wraps,
4 clippy::needless_range_loop,
5 clippy::useless_vec,
6 clippy::needless_collect,
7 clippy::too_many_arguments
8)]
9use quantrs2_ml::gan::{DiscriminatorType, GANEvaluationMetrics, GeneratorType, QuantumGAN};
10use quantrs2_ml::prelude::*;
11use scirs2_core::ndarray::{Array1, Array2};
12use scirs2_core::random::prelude::*;
13use std::time::Instant;
14
15fn main() -> Result<()> {
16 println!("Quantum Generative Adversarial Network Example");
17 println!("=============================================");
18
19 let num_qubits_gen = 6;
21 let num_qubits_disc = 6;
22 let latent_dim = 4;
23 let data_dim = 8;
24
25 println!("Creating Quantum GAN...");
26 println!(" Generator: {num_qubits_gen} qubits");
27 println!(" Discriminator: {num_qubits_disc} qubits");
28 println!(" Latent dimension: {latent_dim}");
29 println!(" Data dimension: {data_dim}");
30
31 let mut qgan = QuantumGAN::new(
33 num_qubits_gen,
34 num_qubits_disc,
35 latent_dim,
36 data_dim,
37 GeneratorType::HybridClassicalQuantum,
38 DiscriminatorType::HybridQuantumFeatures,
39 )?;
40
41 println!("Generating synthetic data for training...");
43 let real_data = generate_sine_wave_data(500, data_dim);
44
45 println!("Training quantum GAN...");
47 let training_params = [
48 (50, 32, 0.01, 0.01, 1), ];
50
51 for (epochs, batch_size, lr_gen, lr_disc, disc_steps) in training_params {
52 println!("Training with parameters:");
53 println!(" Epochs: {epochs}");
54 println!(" Batch size: {batch_size}");
55 println!(" Generator learning rate: {lr_gen}");
56 println!(" Discriminator learning rate: {lr_disc}");
57 println!(" Discriminator steps per iteration: {disc_steps}");
58
59 let start = Instant::now();
60 let history = qgan.train(&real_data, epochs, batch_size, lr_gen, lr_disc, disc_steps)?;
61
62 println!("Training completed in {:.2?}", start.elapsed());
63 println!("Final losses:");
64 println!(
65 " Generator: {:.4}",
66 history.gen_losses.last().unwrap_or(&0.0)
67 );
68 println!(
69 " Discriminator: {:.4}",
70 history.disc_losses.last().unwrap_or(&0.0)
71 );
72 }
73
74 println!("\nGenerating samples from trained GAN...");
76 let num_samples = 10;
77 let generated_samples = qgan.generate(num_samples)?;
78
79 println!("Generated {num_samples} samples");
80 println!("First sample:");
81 print_sample(
82 &generated_samples
83 .slice(scirs2_core::ndarray::s![0, ..])
84 .to_owned(),
85 );
86
87 println!("\nEvaluating GAN quality...");
89 let eval_metrics = qgan.evaluate(&real_data, num_samples)?;
90
91 println!("Evaluation metrics:");
92 println!(
93 " Real data accuracy: {:.2}%",
94 eval_metrics.real_accuracy * 100.0
95 );
96 println!(
97 " Fake data accuracy: {:.2}%",
98 eval_metrics.fake_accuracy * 100.0
99 );
100 println!(
101 " Overall discriminator accuracy: {:.2}%",
102 eval_metrics.overall_accuracy * 100.0
103 );
104 println!(" JS Divergence: {:.4}", eval_metrics.js_divergence);
105
106 println!("\nCreating specialized particle physics GAN...");
108 let particle_gan = quantrs2_ml::gan::physics_gan::ParticleGAN::new(
109 num_qubits_gen,
110 num_qubits_disc,
111 latent_dim,
112 data_dim,
113 )?;
114
115 println!("Particle GAN created successfully");
116
117 Ok(())
118}
119
120fn generate_sine_wave_data(num_samples: usize, data_dim: usize) -> Array2<f64> {
122 let mut data = Array2::zeros((num_samples, data_dim));
123
124 for i in 0..num_samples {
125 let x = (i as f64) / (num_samples as f64) * 2.0 * std::f64::consts::PI;
126
127 for j in 0..data_dim {
128 let freq = (j as f64 + 1.0) * 0.5;
129 data[[i, j]] = 0.1f64.mul_add(thread_rng().gen::<f64>(), (x * freq).sin());
130 }
131 }
132
133 data
134}
135
136fn print_sample(sample: &Array1<f64>) {
138 print!(" [");
139 for (i, &val) in sample.iter().enumerate() {
140 if i > 0 {
141 print!(", ");
142 }
143 print!("{val:.4}");
144 }
145 println!("]");
146}