quantum_gan/
quantum_gan.rs

1#![allow(clippy::pedantic, clippy::unnecessary_wraps)]
2use quantrs2_ml::gan::{DiscriminatorType, GANEvaluationMetrics, GeneratorType, QuantumGAN};
3use quantrs2_ml::prelude::*;
4use scirs2_core::ndarray::{Array1, Array2};
5use scirs2_core::random::prelude::*;
6use std::time::Instant;
7
8fn main() -> Result<()> {
9    println!("Quantum Generative Adversarial Network Example");
10    println!("=============================================");
11
12    // GAN parameters
13    let num_qubits_gen = 6;
14    let num_qubits_disc = 6;
15    let latent_dim = 4;
16    let data_dim = 8;
17
18    println!("Creating Quantum GAN...");
19    println!("  Generator: {num_qubits_gen} qubits");
20    println!("  Discriminator: {num_qubits_disc} qubits");
21    println!("  Latent dimension: {latent_dim}");
22    println!("  Data dimension: {data_dim}");
23
24    // Create quantum GAN
25    let mut qgan = QuantumGAN::new(
26        num_qubits_gen,
27        num_qubits_disc,
28        latent_dim,
29        data_dim,
30        GeneratorType::HybridClassicalQuantum,
31        DiscriminatorType::HybridQuantumFeatures,
32    )?;
33
34    // Generate synthetic data for training
35    println!("Generating synthetic data for training...");
36    let real_data = generate_sine_wave_data(500, data_dim);
37
38    // Train GAN
39    println!("Training quantum GAN...");
40    let training_params = [
41        (50, 32, 0.01, 0.01, 1), // (epochs, batch_size, lr_gen, lr_disc, disc_steps)
42    ];
43
44    for (epochs, batch_size, lr_gen, lr_disc, disc_steps) in training_params {
45        println!("Training with parameters:");
46        println!("  Epochs: {epochs}");
47        println!("  Batch size: {batch_size}");
48        println!("  Generator learning rate: {lr_gen}");
49        println!("  Discriminator learning rate: {lr_disc}");
50        println!("  Discriminator steps per iteration: {disc_steps}");
51
52        let start = Instant::now();
53        let history = qgan.train(&real_data, epochs, batch_size, lr_gen, lr_disc, disc_steps)?;
54
55        println!("Training completed in {:.2?}", start.elapsed());
56        println!("Final losses:");
57        println!(
58            "  Generator: {:.4}",
59            history.gen_losses.last().unwrap_or(&0.0)
60        );
61        println!(
62            "  Discriminator: {:.4}",
63            history.disc_losses.last().unwrap_or(&0.0)
64        );
65    }
66
67    // Generate samples
68    println!("\nGenerating samples from trained GAN...");
69    let num_samples = 10;
70    let generated_samples = qgan.generate(num_samples)?;
71
72    println!("Generated {num_samples} samples");
73    println!("First sample:");
74    print_sample(
75        &generated_samples
76            .slice(scirs2_core::ndarray::s![0, ..])
77            .to_owned(),
78    );
79
80    // Evaluate GAN
81    println!("\nEvaluating GAN quality...");
82    let eval_metrics = qgan.evaluate(&real_data, num_samples)?;
83
84    println!("Evaluation metrics:");
85    println!(
86        "  Real data accuracy: {:.2}%",
87        eval_metrics.real_accuracy * 100.0
88    );
89    println!(
90        "  Fake data accuracy: {:.2}%",
91        eval_metrics.fake_accuracy * 100.0
92    );
93    println!(
94        "  Overall discriminator accuracy: {:.2}%",
95        eval_metrics.overall_accuracy * 100.0
96    );
97    println!("  JS Divergence: {:.4}", eval_metrics.js_divergence);
98
99    // Use physics-specific GAN
100    println!("\nCreating specialized particle physics GAN...");
101    let particle_gan = quantrs2_ml::gan::physics_gan::ParticleGAN::new(
102        num_qubits_gen,
103        num_qubits_disc,
104        latent_dim,
105        data_dim,
106    )?;
107
108    println!("Particle GAN created successfully");
109
110    Ok(())
111}
112
113// Generate synthetic sine wave data
114fn generate_sine_wave_data(num_samples: usize, data_dim: usize) -> Array2<f64> {
115    let mut data = Array2::zeros((num_samples, data_dim));
116
117    for i in 0..num_samples {
118        let x = (i as f64) / (num_samples as f64) * 2.0 * std::f64::consts::PI;
119
120        for j in 0..data_dim {
121            let freq = (j as f64 + 1.0) * 0.5;
122            data[[i, j]] = 0.1f64.mul_add(thread_rng().gen::<f64>(), (x * freq).sin());
123        }
124    }
125
126    data
127}
128
129// Print a sample vector
130fn print_sample(sample: &Array1<f64>) {
131    print!("  [");
132    for (i, &val) in sample.iter().enumerate() {
133        if i > 0 {
134            print!(", ");
135        }
136        print!("{val:.4}");
137    }
138    println!("]");
139}