quantum_gan/
quantum_gan.rs

1#![allow(
2    clippy::pedantic,
3    clippy::unnecessary_wraps,
4    clippy::needless_range_loop,
5    clippy::useless_vec,
6    clippy::needless_collect,
7    clippy::too_many_arguments
8)]
9use quantrs2_ml::gan::{DiscriminatorType, GANEvaluationMetrics, GeneratorType, QuantumGAN};
10use quantrs2_ml::prelude::*;
11use scirs2_core::ndarray::{Array1, Array2};
12use scirs2_core::random::prelude::*;
13use std::time::Instant;
14
15fn main() -> Result<()> {
16    println!("Quantum Generative Adversarial Network Example");
17    println!("=============================================");
18
19    // GAN parameters
20    let num_qubits_gen = 6;
21    let num_qubits_disc = 6;
22    let latent_dim = 4;
23    let data_dim = 8;
24
25    println!("Creating Quantum GAN...");
26    println!("  Generator: {num_qubits_gen} qubits");
27    println!("  Discriminator: {num_qubits_disc} qubits");
28    println!("  Latent dimension: {latent_dim}");
29    println!("  Data dimension: {data_dim}");
30
31    // Create quantum GAN
32    let mut qgan = QuantumGAN::new(
33        num_qubits_gen,
34        num_qubits_disc,
35        latent_dim,
36        data_dim,
37        GeneratorType::HybridClassicalQuantum,
38        DiscriminatorType::HybridQuantumFeatures,
39    )?;
40
41    // Generate synthetic data for training
42    println!("Generating synthetic data for training...");
43    let real_data = generate_sine_wave_data(500, data_dim);
44
45    // Train GAN
46    println!("Training quantum GAN...");
47    let training_params = [
48        (50, 32, 0.01, 0.01, 1), // (epochs, batch_size, lr_gen, lr_disc, disc_steps)
49    ];
50
51    for (epochs, batch_size, lr_gen, lr_disc, disc_steps) in training_params {
52        println!("Training with parameters:");
53        println!("  Epochs: {epochs}");
54        println!("  Batch size: {batch_size}");
55        println!("  Generator learning rate: {lr_gen}");
56        println!("  Discriminator learning rate: {lr_disc}");
57        println!("  Discriminator steps per iteration: {disc_steps}");
58
59        let start = Instant::now();
60        let history = qgan.train(&real_data, epochs, batch_size, lr_gen, lr_disc, disc_steps)?;
61
62        println!("Training completed in {:.2?}", start.elapsed());
63        println!("Final losses:");
64        println!(
65            "  Generator: {:.4}",
66            history.gen_losses.last().unwrap_or(&0.0)
67        );
68        println!(
69            "  Discriminator: {:.4}",
70            history.disc_losses.last().unwrap_or(&0.0)
71        );
72    }
73
74    // Generate samples
75    println!("\nGenerating samples from trained GAN...");
76    let num_samples = 10;
77    let generated_samples = qgan.generate(num_samples)?;
78
79    println!("Generated {num_samples} samples");
80    println!("First sample:");
81    print_sample(
82        &generated_samples
83            .slice(scirs2_core::ndarray::s![0, ..])
84            .to_owned(),
85    );
86
87    // Evaluate GAN
88    println!("\nEvaluating GAN quality...");
89    let eval_metrics = qgan.evaluate(&real_data, num_samples)?;
90
91    println!("Evaluation metrics:");
92    println!(
93        "  Real data accuracy: {:.2}%",
94        eval_metrics.real_accuracy * 100.0
95    );
96    println!(
97        "  Fake data accuracy: {:.2}%",
98        eval_metrics.fake_accuracy * 100.0
99    );
100    println!(
101        "  Overall discriminator accuracy: {:.2}%",
102        eval_metrics.overall_accuracy * 100.0
103    );
104    println!("  JS Divergence: {:.4}", eval_metrics.js_divergence);
105
106    // Use physics-specific GAN
107    println!("\nCreating specialized particle physics GAN...");
108    let particle_gan = quantrs2_ml::gan::physics_gan::ParticleGAN::new(
109        num_qubits_gen,
110        num_qubits_disc,
111        latent_dim,
112        data_dim,
113    )?;
114
115    println!("Particle GAN created successfully");
116
117    Ok(())
118}
119
120// Generate synthetic sine wave data
121fn generate_sine_wave_data(num_samples: usize, data_dim: usize) -> Array2<f64> {
122    let mut data = Array2::zeros((num_samples, data_dim));
123
124    for i in 0..num_samples {
125        let x = (i as f64) / (num_samples as f64) * 2.0 * std::f64::consts::PI;
126
127        for j in 0..data_dim {
128            let freq = (j as f64 + 1.0) * 0.5;
129            data[[i, j]] = 0.1f64.mul_add(thread_rng().gen::<f64>(), (x * freq).sin());
130        }
131    }
132
133    data
134}
135
136// Print a sample vector
137fn print_sample(sample: &Array1<f64>) {
138    print!("  [");
139    for (i, &val) in sample.iter().enumerate() {
140        if i > 0 {
141            print!(", ");
142        }
143        print!("{val:.4}");
144    }
145    println!("]");
146}