1use crate::error::{MLError, Result};
8use crate::qnn::QuantumNeuralNetwork;
9use quantrs2_circuit::prelude::Circuit;
10use quantrs2_sim::statevector::StateVectorSimulator;
11use scirs2_core::ndarray::{Array1, Array2};
12use scirs2_core::random::prelude::*;
13use std::fmt;
14
/// Architecture selector for the GAN generator.
///
/// Human-readable labels come from this type's `Display` impl.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum GeneratorType {
    /// Classical generator (label "Classical").
    Classical,

    /// Quantum-only generator (label "Quantum Only").
    QuantumOnly,

    /// Mixed classical/quantum generator (label "Hybrid Classical-Quantum").
    HybridClassicalQuantum,
}
27
/// Architecture selector for the GAN discriminator.
///
/// Human-readable labels come from this type's `Display` impl.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum DiscriminatorType {
    /// Classical discriminator (label "Classical").
    Classical,

    /// Quantum-only discriminator (label "Quantum Only").
    QuantumOnly,

    /// Hybrid variant labelled "Hybrid with Quantum Features".
    HybridQuantumFeatures,

    /// Hybrid variant labelled "Hybrid with Quantum Decision".
    HybridQuantumDecision,
}
43
/// Per-epoch loss record produced by GAN training.
///
/// `Default` yields an empty history (no epochs recorded yet).
#[derive(Debug, Clone, Default, PartialEq)]
pub struct GANTrainingHistory {
    /// Generator loss recorded for each epoch, in order.
    pub gen_losses: Vec<f64>,

    /// Discriminator loss recorded for each epoch, in order.
    pub disc_losses: Vec<f64>,
}
53
/// Summary statistics describing how well a trained GAN performs.
#[derive(Debug, Clone, PartialEq)]
pub struct GANEvaluationMetrics {
    /// Fraction of real samples the discriminator scored as real.
    pub real_accuracy: f64,

    /// Fraction of generated samples the discriminator scored as fake.
    pub fake_accuracy: f64,

    /// Fraction of all (real + generated) samples classified correctly.
    pub overall_accuracy: f64,

    /// Histogram-based Jensen-Shannon divergence between real and generated data.
    pub js_divergence: f64,
}
69
70pub trait Generator {
72 fn generate(&self, num_samples: usize) -> Result<Array2<f64>>;
74
75 fn generate_conditional(
77 &self,
78 num_samples: usize,
79 conditions: &[(usize, f64)],
80 ) -> Result<Array2<f64>>;
81
82 fn update(
84 &mut self,
85 latent_vectors: &Array2<f64>,
86 discriminator_outputs: &Array1<f64>,
87 learning_rate: f64,
88 ) -> Result<f64>;
89}
90
91pub trait Discriminator {
93 fn discriminate(&self, samples: &Array2<f64>) -> Result<Array1<f64>>;
95
96 fn predict_batch(&self, samples: &Array2<f64>) -> Result<Array1<f64>> {
98 self.discriminate(samples)
99 }
100
101 fn update(
103 &mut self,
104 real_samples: &Array2<f64>,
105 generated_samples: &Array2<f64>,
106 learning_rate: f64,
107 ) -> Result<f64>;
108}
109
110pub mod physics_gan {
112 use super::*;
113
114 pub struct ParticleGAN {
116 pub gan: QuantumGAN,
118
119 pub physics_params: PhysicsParameters,
121 }
122
123 #[derive(Debug, Clone)]
125 pub struct PhysicsParameters {
126 pub energy_scale: f64,
128
129 pub momentum_conservation: f64,
131
132 pub quantum_effects: bool,
134 }
135
136 impl ParticleGAN {
137 pub fn new(
139 num_qubits_gen: usize,
140 num_qubits_disc: usize,
141 latent_dim: usize,
142 data_dim: usize,
143 ) -> Result<Self> {
144 let gan = QuantumGAN::new(
146 num_qubits_gen,
147 num_qubits_disc,
148 latent_dim,
149 data_dim,
150 GeneratorType::HybridClassicalQuantum,
151 DiscriminatorType::HybridQuantumFeatures,
152 )?;
153
154 let physics_params = PhysicsParameters {
156 energy_scale: 100.0, momentum_conservation: 0.99,
158 quantum_effects: true,
159 };
160
161 Ok(ParticleGAN {
162 gan,
163 physics_params,
164 })
165 }
166
167 pub fn train(
169 &mut self,
170 particle_data: &Array2<f64>,
171 epochs: usize,
172 ) -> Result<&GANTrainingHistory> {
173 self.gan.train(
175 particle_data,
176 epochs,
177 32, 0.01, 0.01, 1, )
182 }
183
184 pub fn generate_particles(&self, num_particles: usize) -> Result<Array2<f64>> {
186 let raw_data = self.gan.generate(num_particles)?;
188
189 Ok(raw_data)
193 }
194 }
195}
196
197#[derive(Debug, Clone)]
199pub struct QuantumGenerator {
200 num_qubits: usize,
202
203 latent_dim: usize,
205
206 data_dim: usize,
208
209 generator_type: GeneratorType,
211
212 qnn: QuantumNeuralNetwork,
214}
215
216impl QuantumGenerator {
217 pub fn new(
219 num_qubits: usize,
220 latent_dim: usize,
221 data_dim: usize,
222 generator_type: GeneratorType,
223 ) -> Result<Self> {
224 let layers = vec![
226 crate::qnn::QNNLayerType::EncodingLayer {
227 num_features: latent_dim,
228 },
229 crate::qnn::QNNLayerType::VariationalLayer {
230 num_params: 2 * num_qubits,
231 },
232 crate::qnn::QNNLayerType::EntanglementLayer {
233 connectivity: "full".to_string(),
234 },
235 crate::qnn::QNNLayerType::VariationalLayer {
236 num_params: 2 * num_qubits,
237 },
238 crate::qnn::QNNLayerType::MeasurementLayer {
239 measurement_basis: "computational".to_string(),
240 },
241 ];
242
243 let qnn = QuantumNeuralNetwork::new(layers, num_qubits, latent_dim, data_dim)?;
244
245 Ok(QuantumGenerator {
246 num_qubits,
247 latent_dim,
248 data_dim,
249 generator_type,
250 qnn,
251 })
252 }
253}
254
255impl Generator for QuantumGenerator {
256 fn generate(&self, num_samples: usize) -> Result<Array2<f64>> {
257 let mut latent_vectors = Array2::zeros((num_samples, self.latent_dim));
259 for i in 0..num_samples {
260 for j in 0..self.latent_dim {
261 latent_vectors[[i, j]] = thread_rng().random::<f64>() * 2.0 - 1.0;
262 }
263 }
264
265 let mut samples = Array2::zeros((num_samples, self.data_dim));
268 for i in 0..num_samples {
269 for j in 0..self.data_dim {
270 let latent_sum = latent_vectors.row(i).sum();
272 samples[[i, j]] = (latent_sum + (j as f64) * 0.1).sin() * 0.5 + 0.5;
273 }
274 }
275
276 Ok(samples)
277 }
278
279 fn generate_conditional(
280 &self,
281 num_samples: usize,
282 conditions: &[(usize, f64)],
283 ) -> Result<Array2<f64>> {
284 let mut samples = self.generate(num_samples)?;
286
287 for &(feature_idx, value) in conditions {
289 if feature_idx < self.data_dim {
290 for i in 0..num_samples {
291 samples[[i, feature_idx]] = value;
292 }
293 }
294 }
295
296 Ok(samples)
297 }
298
299 fn update(
300 &mut self,
301 _latent_vectors: &Array2<f64>,
302 _discriminator_outputs: &Array1<f64>,
303 _learning_rate: f64,
304 ) -> Result<f64> {
305 Ok(0.5)
307 }
308}
309
310#[derive(Debug, Clone)]
312pub struct QuantumDiscriminator {
313 num_qubits: usize,
315
316 data_dim: usize,
318
319 discriminator_type: DiscriminatorType,
321
322 qnn: QuantumNeuralNetwork,
324}
325
326impl QuantumDiscriminator {
327 pub fn new(
329 num_qubits: usize,
330 data_dim: usize,
331 discriminator_type: DiscriminatorType,
332 ) -> Result<Self> {
333 let layers = vec![
335 crate::qnn::QNNLayerType::EncodingLayer {
336 num_features: data_dim,
337 },
338 crate::qnn::QNNLayerType::VariationalLayer {
339 num_params: 2 * num_qubits,
340 },
341 crate::qnn::QNNLayerType::EntanglementLayer {
342 connectivity: "full".to_string(),
343 },
344 crate::qnn::QNNLayerType::VariationalLayer {
345 num_params: 2 * num_qubits,
346 },
347 crate::qnn::QNNLayerType::MeasurementLayer {
348 measurement_basis: "computational".to_string(),
349 },
350 ];
351
352 let qnn = QuantumNeuralNetwork::new(
353 layers, num_qubits, data_dim, 1, )?;
355
356 Ok(QuantumDiscriminator {
357 num_qubits,
358 data_dim,
359 discriminator_type,
360 qnn,
361 })
362 }
363}
364
365impl Discriminator for QuantumDiscriminator {
366 fn discriminate(&self, samples: &Array2<f64>) -> Result<Array1<f64>> {
367 let num_samples = samples.nrows();
371 let mut outputs = Array1::zeros(num_samples);
372
373 for i in 0..num_samples {
374 let sum = samples.row(i).sum();
376 outputs[i] = (sum * 0.1).sin() * 0.5 + 0.5;
377 }
378
379 Ok(outputs)
380 }
381
382 fn update(
383 &mut self,
384 _real_samples: &Array2<f64>,
385 _generated_samples: &Array2<f64>,
386 _learning_rate: f64,
387 ) -> Result<f64> {
388 Ok(0.5)
390 }
391}
392
393#[derive(Debug, Clone)]
395pub struct QuantumGAN {
396 pub generator: QuantumGenerator,
398
399 pub discriminator: QuantumDiscriminator,
401
402 pub training_history: GANTrainingHistory,
404}
405
406impl QuantumGAN {
407 pub fn new(
409 num_qubits_gen: usize,
410 num_qubits_disc: usize,
411 latent_dim: usize,
412 data_dim: usize,
413 generator_type: GeneratorType,
414 discriminator_type: DiscriminatorType,
415 ) -> Result<Self> {
416 let generator =
417 QuantumGenerator::new(num_qubits_gen, latent_dim, data_dim, generator_type)?;
418
419 let discriminator =
420 QuantumDiscriminator::new(num_qubits_disc, data_dim, discriminator_type)?;
421
422 let training_history = GANTrainingHistory {
423 gen_losses: Vec::new(),
424 disc_losses: Vec::new(),
425 };
426
427 Ok(QuantumGAN {
428 generator,
429 discriminator,
430 training_history,
431 })
432 }
433
434 pub fn train(
436 &mut self,
437 real_data: &Array2<f64>,
438 epochs: usize,
439 batch_size: usize,
440 gen_learning_rate: f64,
441 disc_learning_rate: f64,
442 disc_steps: usize,
443 ) -> Result<&GANTrainingHistory> {
444 let mut gen_losses = Vec::with_capacity(epochs);
445 let mut disc_losses = Vec::with_capacity(epochs);
446
447 for _epoch in 0..epochs {
448 let mut disc_loss_sum = 0.0;
450 for _step in 0..disc_steps {
451 let fake_samples = self.generator.generate(batch_size)?;
453
454 let real_batch = sample_batch(real_data, batch_size)?;
456
457 let disc_loss =
459 self.discriminator
460 .update(&real_batch, &fake_samples, disc_learning_rate)?;
461 disc_loss_sum += disc_loss;
462 }
463 let avg_disc_loss = disc_loss_sum / disc_steps as f64;
464
465 let latent_vectors = Array2::zeros((batch_size, self.generator.latent_dim));
467 let fake_outputs = Array1::zeros(batch_size);
468 let gen_loss =
469 self.generator
470 .update(&latent_vectors, &fake_outputs, gen_learning_rate)?;
471
472 gen_losses.push(gen_loss);
474 disc_losses.push(avg_disc_loss);
475 }
476
477 self.training_history = GANTrainingHistory {
478 gen_losses,
479 disc_losses,
480 };
481
482 Ok(&self.training_history)
483 }
484
485 pub fn generate(&self, num_samples: usize) -> Result<Array2<f64>> {
487 self.generator.generate(num_samples)
488 }
489
490 pub fn generate_conditional(
492 &self,
493 num_samples: usize,
494 conditions: &[(usize, f64)],
495 ) -> Result<Array2<f64>> {
496 self.generator.generate_conditional(num_samples, conditions)
497 }
498
499 pub fn evaluate(
501 &self,
502 real_data: &Array2<f64>,
503 num_samples: usize,
504 ) -> Result<GANEvaluationMetrics> {
505 let fake_samples = self.generate(num_samples)?;
507
508 let real_preds = self.discriminator.predict_batch(real_data)?;
510 let real_correct = real_preds.iter().filter(|&&p| p > 0.5).count();
511 let real_accuracy = real_correct as f64 / real_preds.len() as f64;
512
513 let fake_preds = self.discriminator.predict_batch(&fake_samples)?;
515 let fake_correct = fake_preds.iter().filter(|&&p| p < 0.5).count();
516 let fake_accuracy = fake_correct as f64 / fake_preds.len() as f64;
517
518 let overall_correct = real_correct + fake_correct;
520 let overall_total = real_preds.len() + fake_preds.len();
521 let overall_accuracy = overall_correct as f64 / overall_total as f64;
522
523 let js_divergence = calculate_js_divergence(real_data, &fake_samples)?;
526
527 Ok(GANEvaluationMetrics {
528 real_accuracy,
529 fake_accuracy,
530 overall_accuracy,
531 js_divergence,
532 })
533 }
534}
535
536fn calculate_js_divergence(data1: &Array2<f64>, data2: &Array2<f64>) -> Result<f64> {
542 if data1.ncols() == 0 || data1.nrows() == 0 || data2.nrows() == 0 {
543 return Ok(0.0);
544 }
545
546 let n_bins: usize = 20;
547 let n_cols = data1.ncols().min(data2.ncols());
548 let mut total_js = 0.0;
549
550 for col in 0..n_cols {
551 let col1: Vec<f64> = data1.column(col).to_vec();
552 let col2: Vec<f64> = data2.column(col).to_vec();
553
554 let min_val = col1
555 .iter()
556 .chain(col2.iter())
557 .cloned()
558 .fold(f64::INFINITY, f64::min);
559 let max_val = col1
560 .iter()
561 .chain(col2.iter())
562 .cloned()
563 .fold(f64::NEG_INFINITY, f64::max);
564
565 if (max_val - min_val).abs() < 1e-14 {
566 continue;
568 }
569
570 let bin_width = (max_val - min_val) / n_bins as f64;
571 let mut hist1 = vec![0.0f64; n_bins];
572 let mut hist2 = vec![0.0f64; n_bins];
573
574 for &v in &col1 {
575 let bin = ((v - min_val) / bin_width) as usize;
576 let bin = bin.min(n_bins - 1);
577 hist1[bin] += 1.0;
578 }
579 for &v in &col2 {
580 let bin = ((v - min_val) / bin_width) as usize;
581 let bin = bin.min(n_bins - 1);
582 hist2[bin] += 1.0;
583 }
584
585 let n1 = col1.len() as f64;
586 let n2 = col2.len() as f64;
587 for i in 0..n_bins {
588 hist1[i] /= n1;
589 hist2[i] /= n2;
590 }
591
592 let mut js = 0.0f64;
594 for i in 0..n_bins {
595 let p = hist1[i];
596 let q = hist2[i];
597 let m = (p + q) * 0.5;
598 if m > 1e-14 {
599 if p > 1e-14 {
600 js += 0.5 * p * (p / m).ln();
601 }
602 if q > 1e-14 {
603 js += 0.5 * q * (q / m).ln();
604 }
605 }
606 }
607 total_js += js;
608 }
609
610 Ok(if n_cols > 0 {
611 total_js / n_cols as f64
612 } else {
613 0.0
614 })
615}
616
617fn sample_batch(data: &Array2<f64>, batch_size: usize) -> Result<Array2<f64>> {
619 let num_samples = data.nrows();
620 let mut batch = Array2::zeros((batch_size.min(num_samples), data.ncols()));
621
622 for i in 0..batch_size.min(num_samples) {
623 let idx = fastrand::usize(0..num_samples);
624 batch.row_mut(i).assign(&data.row(idx));
625 }
626
627 Ok(batch)
628}
629
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::Array2;

    /// Identical inputs produce identical histograms, so the divergence is exactly 0.
    #[test]
    fn test_js_divergence_identical() {
        let data = Array2::from_shape_vec((4, 2), vec![0.0, 1.0, 0.5, 0.5, 0.2, 0.8, 0.7, 0.3])
            .expect("array creation failed");
        let js = calculate_js_divergence(&data, &data).expect("divergence failed");
        assert!(js.abs() < 1e-12, "JS(p,p) should be 0, got {js}");
    }

    /// Fully disjoint distributions hit the natural-log upper bound ln 2.
    #[test]
    fn test_js_divergence_bounded() {
        let data1 =
            Array2::from_shape_vec((4, 1), vec![0.0, 0.0, 0.0, 0.0]).expect("array creation");
        let data2 =
            Array2::from_shape_vec((4, 1), vec![1.0, 1.0, 1.0, 1.0]).expect("array creation");
        let js = calculate_js_divergence(&data1, &data2).expect("divergence failed");
        assert!(
            (js - std::f64::consts::LN_2).abs() < 1e-12,
            "disjoint JS should be ln 2, got {js}"
        );
    }

    /// Empty inputs (no rows or no columns) report zero divergence.
    #[test]
    fn test_js_divergence_empty_inputs_are_zero() {
        let data = Array2::from_shape_vec((2, 2), vec![0.0, 1.0, 0.5, 0.5])
            .expect("array creation failed");
        let no_rows = Array2::<f64>::zeros((0, 2));
        let no_cols = Array2::<f64>::zeros((2, 0));

        assert_eq!(calculate_js_divergence(&no_rows, &data).expect("divergence"), 0.0);
        assert_eq!(calculate_js_divergence(&data, &no_rows).expect("divergence"), 0.0);
        assert_eq!(calculate_js_divergence(&no_cols, &data).expect("divergence"), 0.0);
    }

    /// A column that is the same constant in both inputs contributes nothing.
    #[test]
    fn test_js_divergence_constant_columns_are_zero() {
        let data = Array2::from_elem((3, 2), 0.5);
        let js = calculate_js_divergence(&data, &data).expect("divergence failed");
        assert_eq!(js, 0.0);
    }

    /// JS divergence is symmetric (up to floating-point summation order).
    #[test]
    fn test_js_divergence_is_symmetric() {
        let a = Array2::from_shape_vec((4, 1), vec![0.0, 0.1, 0.2, 0.9]).expect("array creation");
        let b = Array2::from_shape_vec((4, 1), vec![0.5, 0.6, 0.7, 0.8]).expect("array creation");
        let ab = calculate_js_divergence(&a, &b).expect("divergence failed");
        let ba = calculate_js_divergence(&b, &a).expect("divergence failed");
        assert!((ab - ba).abs() < 1e-12, "JS should be symmetric: {ab} vs {ba}");
        assert!(ab > 0.0 && ab <= std::f64::consts::LN_2 + 1e-12);
    }
}
653
654impl fmt::Display for GeneratorType {
655 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
656 match self {
657 GeneratorType::Classical => write!(f, "Classical"),
658 GeneratorType::QuantumOnly => write!(f, "Quantum Only"),
659 GeneratorType::HybridClassicalQuantum => write!(f, "Hybrid Classical-Quantum"),
660 }
661 }
662}
663
664impl fmt::Display for DiscriminatorType {
665 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
666 match self {
667 DiscriminatorType::Classical => write!(f, "Classical"),
668 DiscriminatorType::QuantumOnly => write!(f, "Quantum Only"),
669 DiscriminatorType::HybridQuantumFeatures => write!(f, "Hybrid with Quantum Features"),
670 DiscriminatorType::HybridQuantumDecision => write!(f, "Hybrid with Quantum Decision"),
671 }
672 }
673}