quantum_gan/
quantum_gan.rs

1use quantrs2_ml::gan::{DiscriminatorType, GANEvaluationMetrics, GeneratorType, QuantumGAN};
2use quantrs2_ml::prelude::*;
3use scirs2_core::ndarray::{Array1, Array2};
4use scirs2_core::random::prelude::*;
5use std::time::Instant;
6
7fn main() -> Result<()> {
8    println!("Quantum Generative Adversarial Network Example");
9    println!("=============================================");
10
11    // GAN parameters
12    let num_qubits_gen = 6;
13    let num_qubits_disc = 6;
14    let latent_dim = 4;
15    let data_dim = 8;
16
17    println!("Creating Quantum GAN...");
18    println!("  Generator: {num_qubits_gen} qubits");
19    println!("  Discriminator: {num_qubits_disc} qubits");
20    println!("  Latent dimension: {latent_dim}");
21    println!("  Data dimension: {data_dim}");
22
23    // Create quantum GAN
24    let mut qgan = QuantumGAN::new(
25        num_qubits_gen,
26        num_qubits_disc,
27        latent_dim,
28        data_dim,
29        GeneratorType::HybridClassicalQuantum,
30        DiscriminatorType::HybridQuantumFeatures,
31    )?;
32
33    // Generate synthetic data for training
34    println!("Generating synthetic data for training...");
35    let real_data = generate_sine_wave_data(500, data_dim);
36
37    // Train GAN
38    println!("Training quantum GAN...");
39    let training_params = [
40        (50, 32, 0.01, 0.01, 1), // (epochs, batch_size, lr_gen, lr_disc, disc_steps)
41    ];
42
43    for (epochs, batch_size, lr_gen, lr_disc, disc_steps) in training_params {
44        println!("Training with parameters:");
45        println!("  Epochs: {epochs}");
46        println!("  Batch size: {batch_size}");
47        println!("  Generator learning rate: {lr_gen}");
48        println!("  Discriminator learning rate: {lr_disc}");
49        println!("  Discriminator steps per iteration: {disc_steps}");
50
51        let start = Instant::now();
52        let history = qgan.train(&real_data, epochs, batch_size, lr_gen, lr_disc, disc_steps)?;
53
54        println!("Training completed in {:.2?}", start.elapsed());
55        println!("Final losses:");
56        println!(
57            "  Generator: {:.4}",
58            history.gen_losses.last().unwrap_or(&0.0)
59        );
60        println!(
61            "  Discriminator: {:.4}",
62            history.disc_losses.last().unwrap_or(&0.0)
63        );
64    }
65
66    // Generate samples
67    println!("\nGenerating samples from trained GAN...");
68    let num_samples = 10;
69    let generated_samples = qgan.generate(num_samples)?;
70
71    println!("Generated {num_samples} samples");
72    println!("First sample:");
73    print_sample(
74        &generated_samples
75            .slice(scirs2_core::ndarray::s![0, ..])
76            .to_owned(),
77    );
78
79    // Evaluate GAN
80    println!("\nEvaluating GAN quality...");
81    let eval_metrics = qgan.evaluate(&real_data, num_samples)?;
82
83    println!("Evaluation metrics:");
84    println!(
85        "  Real data accuracy: {:.2}%",
86        eval_metrics.real_accuracy * 100.0
87    );
88    println!(
89        "  Fake data accuracy: {:.2}%",
90        eval_metrics.fake_accuracy * 100.0
91    );
92    println!(
93        "  Overall discriminator accuracy: {:.2}%",
94        eval_metrics.overall_accuracy * 100.0
95    );
96    println!("  JS Divergence: {:.4}", eval_metrics.js_divergence);
97
98    // Use physics-specific GAN
99    println!("\nCreating specialized particle physics GAN...");
100    let particle_gan = quantrs2_ml::gan::physics_gan::ParticleGAN::new(
101        num_qubits_gen,
102        num_qubits_disc,
103        latent_dim,
104        data_dim,
105    )?;
106
107    println!("Particle GAN created successfully");
108
109    Ok(())
110}
111
112// Generate synthetic sine wave data
113fn generate_sine_wave_data(num_samples: usize, data_dim: usize) -> Array2<f64> {
114    let mut data = Array2::zeros((num_samples, data_dim));
115
116    for i in 0..num_samples {
117        let x = (i as f64) / (num_samples as f64) * 2.0 * std::f64::consts::PI;
118
119        for j in 0..data_dim {
120            let freq = (j as f64 + 1.0) * 0.5;
121            data[[i, j]] = 0.1f64.mul_add(thread_rng().gen::<f64>(), (x * freq).sin());
122        }
123    }
124
125    data
126}
127
128// Print a sample vector
129fn print_sample(sample: &Array1<f64>) {
130    print!("  [");
131    for (i, &val) in sample.iter().enumerate() {
132        if i > 0 {
133            print!(", ");
134        }
135        print!("{val:.4}");
136    }
137    println!("]");
138}