quantrs2_ml/
pytorch_api.rs

//! PyTorch-like API for quantum machine learning models
//!
//! This module provides a familiar PyTorch-style interface for building,
//! training, and deploying quantum ML models, making it easier for classical
//! ML practitioners to adopt quantum algorithms.
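//!
//! # Example
//!
//! A minimal sketch of the intended workflow, assuming a `SciRS2Optimizer`
//! and a data loader have been configured elsewhere:
//!
//! ```ignore
//! use quantrs2_ml::pytorch_api::*;
//!
//! // Build a small feedforward model: 4 -> 8 -> 2 with quantum ReLU.
//! let model = QuantumSequential::new()
//!     .add(Box::new(QuantumLinear::new(4, 8)?.with_bias()?))
//!     .add(Box::new(QuantumActivation::relu()))
//!     .add(Box::new(QuantumLinear::new(8, 2)?.with_bias()?));
//!
//! // Wire the model, an optimizer, and a loss function into a trainer.
//! let mut trainer = QuantumTrainer::new(
//!     Box::new(model),
//!     optimizer, // a SciRS2Optimizer configured elsewhere
//!     Box::new(QuantumMSELoss),
//! );
//!
//! let loss = trainer.train_epoch(&mut dataloader)?;
//! ```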

use crate::circuit_integration::QuantumMLExecutor;
use crate::error::{MLError, Result};
use crate::scirs2_integration::{SciRS2Array, SciRS2Optimizer};
use crate::simulator_backends::{Observable, SimulatorBackend};
use scirs2_core::ndarray::{Array1, Array2, ArrayD, Axis, Dimension, IxDyn};
use quantrs2_circuit::prelude::*;
use std::cell::RefCell;
use std::collections::HashMap;
use std::rc::Rc;

/// Base trait for all quantum ML modules
pub trait QuantumModule: Send + Sync {
    /// Forward pass
    fn forward(&mut self, input: &SciRS2Array) -> Result<SciRS2Array>;

    /// Get all parameters
    fn parameters(&self) -> Vec<Parameter>;

    /// Set training mode
    fn train(&mut self, mode: bool);

    /// Check if module is in training mode
    fn training(&self) -> bool;

    /// Zero gradients of all parameters
    fn zero_grad(&mut self);

    /// Module name for debugging
    fn name(&self) -> &str;
}

/// Quantum parameter wrapper
#[derive(Debug, Clone)]
pub struct Parameter {
    /// Parameter data
    pub data: SciRS2Array,
    /// Parameter name
    pub name: String,
    /// Whether parameter requires gradient
    pub requires_grad: bool,
}

impl Parameter {
    /// Create new parameter
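    ///
    /// A small usage sketch (illustrative):
    ///
    /// ```ignore
    /// use scirs2_core::ndarray::{ArrayD, IxDyn};
    ///
    /// // A 2x3 weight tensor that participates in autograd.
    /// let data = ArrayD::zeros(IxDyn(&[2, 3]));
    /// let w = Parameter::new(SciRS2Array::with_grad(data), "weight");
    /// assert_eq!(w.numel(), 6);
    /// ```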
    pub fn new(data: SciRS2Array, name: impl Into<String>) -> Self {
        Self {
            data,
            name: name.into(),
            requires_grad: true,
        }
    }

    /// Create parameter without gradients
    pub fn no_grad(data: SciRS2Array, name: impl Into<String>) -> Self {
        Self {
            data,
            name: name.into(),
            requires_grad: false,
        }
    }

    /// Get parameter shape
    pub fn shape(&self) -> &[usize] {
        self.data.data.shape()
    }

    /// Get parameter size
    pub fn numel(&self) -> usize {
        self.data.data.len()
    }
}

/// Quantum linear layer
pub struct QuantumLinear {
    /// Weight parameters
    weights: Parameter,
    /// Bias parameters (optional)
    bias: Option<Parameter>,
    /// Input features
    in_features: usize,
    /// Output features
    out_features: usize,
    /// Training mode
    training: bool,
    /// Circuit executor
    executor: QuantumMLExecutor<8>, // Fixed const size for now
}

impl QuantumLinear {
    /// Create new quantum linear layer
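    ///
    /// A construction sketch (illustrative):
    ///
    /// ```ignore
    /// // 4 input features -> 2 output features, with bias and Xavier init.
    /// let mut layer = QuantumLinear::new(4, 2)?.with_bias()?;
    /// layer.init_xavier_uniform()?;
    /// assert_eq!(layer.parameters().len(), 2); // weights + bias
    /// ```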
    pub fn new(in_features: usize, out_features: usize) -> Result<Self> {
        // Weights are stored as [in_features, out_features] so the forward
        // pass can compute `input @ weights` directly.
        let weight_data = ArrayD::zeros(IxDyn(&[in_features, out_features]));
        let weights = Parameter::new(SciRS2Array::with_grad(weight_data), "weight");

        Ok(Self {
            weights,
            bias: None,
            in_features,
            out_features,
            training: true,
            executor: QuantumMLExecutor::new(),
        })
    }

    /// Create with bias
    pub fn with_bias(mut self) -> Result<Self> {
        let bias_data = ArrayD::zeros(IxDyn(&[self.out_features]));
        self.bias = Some(Parameter::new(SciRS2Array::with_grad(bias_data), "bias"));
        Ok(self)
    }

    /// Initialize weights using Xavier/Glorot uniform
    pub fn init_xavier_uniform(&mut self) -> Result<()> {
        let fan_in = self.in_features as f64;
        let fan_out = self.out_features as f64;
        let bound = (6.0 / (fan_in + fan_out)).sqrt();

        for elem in self.weights.data.data.iter_mut() {
            *elem = (fastrand::f64() * 2.0 - 1.0) * bound;
        }

        Ok(())
    }
}

impl QuantumModule for QuantumLinear {
    fn forward(&mut self, input: &SciRS2Array) -> Result<SciRS2Array> {
        // Linear transformation: output = input @ weights + bias, where the
        // weights are stored as [in_features, out_features].
        let output = input.matmul(&self.weights.data)?;

        if let Some(ref bias) = self.bias {
            output.add(&bias.data)
        } else {
            Ok(output)
        }
    }

    fn parameters(&self) -> Vec<Parameter> {
        let mut params = vec![self.weights.clone()];
        if let Some(ref bias) = self.bias {
            params.push(bias.clone());
        }
        params
    }

    fn train(&mut self, mode: bool) {
        self.training = mode;
    }

    fn training(&self) -> bool {
        self.training
    }

    fn zero_grad(&mut self) {
        self.weights.data.zero_grad();
        if let Some(ref mut bias) = self.bias {
            bias.data.zero_grad();
        }
    }

    fn name(&self) -> &str {
        "QuantumLinear"
    }
}

/// Quantum convolutional layer
pub struct QuantumConv2d {
    /// Convolution parameters
    weights: Parameter,
    /// Bias parameters
    bias: Option<Parameter>,
    /// Input channels
    in_channels: usize,
    /// Output channels
    out_channels: usize,
    /// Kernel size
    kernel_size: (usize, usize),
    /// Stride
    stride: (usize, usize),
    /// Padding
    padding: (usize, usize),
    /// Training mode
    training: bool,
}

impl QuantumConv2d {
    /// Create new quantum conv2d layer
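    ///
    /// A builder-style construction sketch (illustrative):
    ///
    /// ```ignore
    /// // 3 input channels -> 16 output channels with a 3x3 kernel,
    /// // stride 2 and 1-pixel padding.
    /// let conv = QuantumConv2d::new(3, 16, (3, 3))?
    ///     .stride((2, 2))
    ///     .padding((1, 1))
    ///     .with_bias()?;
    /// ```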
    pub fn new(
        in_channels: usize,
        out_channels: usize,
        kernel_size: (usize, usize),
    ) -> Result<Self> {
        let weight_shape = [out_channels, in_channels, kernel_size.0, kernel_size.1];
        let weight_data = ArrayD::zeros(IxDyn(&weight_shape));
        let weights = Parameter::new(SciRS2Array::with_grad(weight_data), "weight");

        Ok(Self {
            weights,
            bias: None,
            in_channels,
            out_channels,
            kernel_size,
            stride: (1, 1),
            padding: (0, 0),
            training: true,
        })
    }

    /// Set stride
    pub fn stride(mut self, stride: (usize, usize)) -> Self {
        self.stride = stride;
        self
    }

    /// Set padding
    pub fn padding(mut self, padding: (usize, usize)) -> Self {
        self.padding = padding;
        self
    }

    /// Add bias
    pub fn with_bias(mut self) -> Result<Self> {
        let bias_data = ArrayD::zeros(IxDyn(&[self.out_channels]));
        self.bias = Some(Parameter::new(SciRS2Array::with_grad(bias_data), "bias"));
        Ok(self)
    }
}

impl QuantumModule for QuantumConv2d {
    fn forward(&mut self, input: &SciRS2Array) -> Result<SciRS2Array> {
        // Placeholder: a full implementation would apply quantum convolution
        // operations here. For now the input passes through unchanged.
        let output_data = input.data.clone();
        let mut output = SciRS2Array::new(output_data, input.requires_grad);

        if let Some(ref bias) = self.bias {
            output = output.add(&bias.data)?;
        }

        Ok(output)
    }

    fn parameters(&self) -> Vec<Parameter> {
        let mut params = vec![self.weights.clone()];
        if let Some(ref bias) = self.bias {
            params.push(bias.clone());
        }
        params
    }

    fn train(&mut self, mode: bool) {
        self.training = mode;
    }

    fn training(&self) -> bool {
        self.training
    }

    fn zero_grad(&mut self) {
        self.weights.data.zero_grad();
        if let Some(ref mut bias) = self.bias {
            bias.data.zero_grad();
        }
    }

    fn name(&self) -> &str {
        "QuantumConv2d"
    }
}

/// Quantum activation functions
pub struct QuantumActivation {
    /// Activation function type
    activation_type: ActivationType,
    /// Training mode
    training: bool,
}

/// Activation function types
#[derive(Debug, Clone)]
pub enum ActivationType {
    /// Quantum ReLU (using rotation gates)
    QReLU,
    /// Quantum Sigmoid
    QSigmoid,
    /// Quantum Tanh
    QTanh,
    /// Quantum Softmax
    QSoftmax,
    /// Identity (no activation)
    Identity,
}

impl QuantumActivation {
    /// Create new activation layer
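    ///
    /// A usage sketch (illustrative):
    ///
    /// ```ignore
    /// use scirs2_core::ndarray::{ArrayD, IxDyn};
    ///
    /// let mut act = QuantumActivation::new(ActivationType::QTanh);
    /// let input = SciRS2Array::new(ArrayD::zeros(IxDyn(&[4])), false);
    /// let output = act.forward(&input)?; // tanh(0) = 0 elementwise
    /// ```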
    pub fn new(activation_type: ActivationType) -> Self {
        Self {
            activation_type,
            training: true,
        }
    }

    /// Create ReLU activation
    pub fn relu() -> Self {
        Self::new(ActivationType::QReLU)
    }

    /// Create Sigmoid activation
    pub fn sigmoid() -> Self {
        Self::new(ActivationType::QSigmoid)
    }

    /// Create Tanh activation
    pub fn tanh() -> Self {
        Self::new(ActivationType::QTanh)
    }

    /// Create Softmax activation
    pub fn softmax() -> Self {
        Self::new(ActivationType::QSoftmax)
    }
}

impl QuantumModule for QuantumActivation {
    fn forward(&mut self, input: &SciRS2Array) -> Result<SciRS2Array> {
        match self.activation_type {
            ActivationType::QReLU => {
                // Quantum ReLU: max(0, x) approximation using quantum gates
                let output_data = input.data.mapv(|x| x.max(0.0));
                Ok(SciRS2Array::new(output_data, input.requires_grad))
            }
            ActivationType::QSigmoid => {
                // Quantum sigmoid approximation
                let output_data = input.data.mapv(|x| 1.0 / (1.0 + (-x).exp()));
                Ok(SciRS2Array::new(output_data, input.requires_grad))
            }
            ActivationType::QTanh => {
                // Quantum tanh
                let output_data = input.data.mapv(|x| x.tanh());
                Ok(SciRS2Array::new(output_data, input.requires_grad))
            }
            ActivationType::QSoftmax => {
                // Quantum softmax, computed over the flattened array with the
                // usual max-subtraction for numerical stability. Batched
                // inputs would need a per-row softmax instead.
                let max_val = input.data.iter().fold(f64::NEG_INFINITY, |a, &b| a.max(b));
                let exp_data = input.data.mapv(|x| (x - max_val).exp());
                let sum_exp = exp_data.sum();
                let output_data = exp_data.mapv(|x| x / sum_exp);
                Ok(SciRS2Array::new(output_data, input.requires_grad))
            }
            ActivationType::Identity => {
                Ok(SciRS2Array::new(input.data.clone(), input.requires_grad))
            }
        }
    }

    fn parameters(&self) -> Vec<Parameter> {
        Vec::new() // Activation functions have no trainable parameters
    }

    fn train(&mut self, mode: bool) {
        self.training = mode;
    }

    fn training(&self) -> bool {
        self.training
    }

    fn zero_grad(&mut self) {
        // No parameters to zero
    }

    fn name(&self) -> &str {
        "QuantumActivation"
    }
}

/// Sequential container for quantum modules
pub struct QuantumSequential {
    /// Ordered modules
    modules: Vec<Box<dyn QuantumModule>>,
    /// Training mode
    training: bool,
}

impl QuantumSequential {
    /// Create new sequential container
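    ///
    /// A chaining sketch (illustrative):
    ///
    /// ```ignore
    /// let model = QuantumSequential::new()
    ///     .add(Box::new(QuantumLinear::new(4, 8)?))
    ///     .add(Box::new(QuantumActivation::relu()))
    ///     .add(Box::new(QuantumLinear::new(8, 2)?));
    /// assert_eq!(model.len(), 3);
    /// ```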
    pub fn new() -> Self {
        Self {
            modules: Vec::new(),
            training: true,
        }
    }

    /// Add module to sequence
    pub fn add(mut self, module: Box<dyn QuantumModule>) -> Self {
        self.modules.push(module);
        self
    }

    /// Get number of modules
    pub fn len(&self) -> usize {
        self.modules.len()
    }

    /// Check if empty
    pub fn is_empty(&self) -> bool {
        self.modules.is_empty()
    }
}

impl QuantumModule for QuantumSequential {
    fn forward(&mut self, input: &SciRS2Array) -> Result<SciRS2Array> {
        let mut output = input.clone();

        for module in &mut self.modules {
            output = module.forward(&output)?;
        }

        Ok(output)
    }

    fn parameters(&self) -> Vec<Parameter> {
        let mut all_params = Vec::new();

        for module in &self.modules {
            all_params.extend(module.parameters());
        }

        all_params
    }

    fn train(&mut self, mode: bool) {
        self.training = mode;
        for module in &mut self.modules {
            module.train(mode);
        }
    }

    fn training(&self) -> bool {
        self.training
    }

    fn zero_grad(&mut self) {
        for module in &mut self.modules {
            module.zero_grad();
        }
    }

    fn name(&self) -> &str {
        "QuantumSequential"
    }
}

/// Loss functions for quantum ML
pub trait QuantumLoss {
    /// Compute loss
    fn forward(&self, predictions: &SciRS2Array, targets: &SciRS2Array) -> Result<SciRS2Array>;

    /// Loss function name
    fn name(&self) -> &str;
}

/// Mean Squared Error loss
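///
/// A usage sketch (illustrative):
///
/// ```ignore
/// use scirs2_core::ndarray::{ArrayD, IxDyn};
///
/// let predictions = SciRS2Array::new(
///     ArrayD::from_shape_vec(IxDyn(&[2]), vec![1.0, 2.0])?, false);
/// let targets = SciRS2Array::new(
///     ArrayD::from_shape_vec(IxDyn(&[2]), vec![1.5, 1.8])?, false);
///
/// // MSE = mean((pred - target)^2) = (0.25 + 0.04) / 2 = 0.145
/// let loss = QuantumMSELoss.forward(&predictions, &targets)?;
/// ```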
pub struct QuantumMSELoss;

impl QuantumLoss for QuantumMSELoss {
    fn forward(&self, predictions: &SciRS2Array, targets: &SciRS2Array) -> Result<SciRS2Array> {
        let diff = predictions.data.clone() - &targets.data;
        let squared_diff = &diff * &diff;
        let mse = squared_diff.mean().unwrap_or(0.0);

        // Return a length-1 array so the scalar can be read back as loss.data[[0]].
        let loss_data = ArrayD::from_elem(IxDyn(&[1]), mse);
        Ok(SciRS2Array::new(loss_data, predictions.requires_grad))
    }

    fn name(&self) -> &str {
        "MSELoss"
    }
}

/// Cross Entropy loss
pub struct QuantumCrossEntropyLoss;

impl QuantumLoss for QuantumCrossEntropyLoss {
    fn forward(&self, predictions: &SciRS2Array, targets: &SciRS2Array) -> Result<SciRS2Array> {
        // Numerically stable log-softmax:
        // log_softmax(x) = x - max - ln(sum(exp(x - max)))
        let max_val = predictions
            .data
            .iter()
            .fold(f64::NEG_INFINITY, |a, &b| a.max(b));
        let exp_preds = predictions.data.mapv(|x| (x - max_val).exp());
        let sum_exp = exp_preds.sum();
        let log_softmax = predictions.data.mapv(|x| x - max_val - sum_exp.ln());

        // Cross entropy: -sum(targets * log_softmax)
        let cross_entropy = -(&targets.data * &log_softmax).sum();

        // Return a length-1 array so the scalar can be read back as loss.data[[0]].
        let loss_data = ArrayD::from_elem(IxDyn(&[1]), cross_entropy);
        Ok(SciRS2Array::new(loss_data, predictions.requires_grad))
    }

    fn name(&self) -> &str {
        "CrossEntropyLoss"
    }
}

/// Training utilities
pub struct QuantumTrainer {
    /// Model to train
    model: Box<dyn QuantumModule>,
    /// Optimizer
    optimizer: SciRS2Optimizer,
    /// Loss function
    loss_fn: Box<dyn QuantumLoss>,
    /// Training history
    history: TrainingHistory,
}

/// Training history
#[derive(Debug, Clone)]
pub struct TrainingHistory {
    /// Loss values per epoch
    pub losses: Vec<f64>,
    /// Accuracy values per epoch (if applicable)
    pub accuracies: Vec<f64>,
    /// Validation losses
    pub val_losses: Vec<f64>,
    /// Validation accuracies
    pub val_accuracies: Vec<f64>,
}

impl TrainingHistory {
    /// Create new training history
    pub fn new() -> Self {
        Self {
            losses: Vec::new(),
            accuracies: Vec::new(),
            val_losses: Vec::new(),
            val_accuracies: Vec::new(),
        }
    }

    /// Add training metrics
    pub fn add_training(&mut self, loss: f64, accuracy: Option<f64>) {
        self.losses.push(loss);
        if let Some(acc) = accuracy {
            self.accuracies.push(acc);
        }
    }

    /// Add validation metrics
    pub fn add_validation(&mut self, loss: f64, accuracy: Option<f64>) {
        self.val_losses.push(loss);
        if let Some(acc) = accuracy {
            self.val_accuracies.push(acc);
        }
    }
}

impl QuantumTrainer {
    /// Create new trainer
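    ///
    /// A training-loop sketch (illustrative; assumes an already-configured
    /// `SciRS2Optimizer` and a `MemoryDataLoader`):
    ///
    /// ```ignore
    /// let mut trainer = QuantumTrainer::new(
    ///     Box::new(model),
    ///     optimizer,
    ///     Box::new(QuantumMSELoss),
    /// );
    ///
    /// for epoch in 0..10 {
    ///     dataloader.reset();
    ///     let train_loss = trainer.train_epoch(&mut dataloader)?;
    ///     println!("epoch {epoch}: loss = {train_loss}");
    /// }
    /// ```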
    pub fn new(
        model: Box<dyn QuantumModule>,
        optimizer: SciRS2Optimizer,
        loss_fn: Box<dyn QuantumLoss>,
    ) -> Self {
        Self {
            model,
            optimizer,
            loss_fn,
            history: TrainingHistory::new(),
        }
    }

    /// Train for one epoch
    pub fn train_epoch(&mut self, dataloader: &mut dyn DataLoader) -> Result<f64> {
        self.model.train(true);
        let mut total_loss = 0.0;
        let mut num_batches = 0;

        while let Some((inputs, targets)) = dataloader.next_batch()? {
            // Zero gradients
            self.model.zero_grad();

            // Forward pass
            let predictions = self.model.forward(&inputs)?;

            // Compute loss
            let loss = self.loss_fn.forward(&predictions, &targets)?;
            total_loss += loss.data[[0]];

            // Backward pass
            // loss.backward()?; // Backpropagation not yet implemented

            // Optimizer step. Note: `parameters()` returns cloned handles, so
            // this only updates the model if SciRS2Array clones share their
            // underlying storage.
            let mut params = HashMap::new();
            for (i, param) in self.model.parameters().iter().enumerate() {
                params.insert(format!("param_{}", i), param.data.clone());
            }
            self.optimizer.step(&mut params)?;

            num_batches += 1;
        }

        // Guard against empty dataloaders to avoid dividing by zero.
        let avg_loss = if num_batches > 0 {
            total_loss / num_batches as f64
        } else {
            0.0
        };
        self.history.add_training(avg_loss, None);
        Ok(avg_loss)
    }

    /// Evaluate on validation set
    pub fn evaluate(&mut self, dataloader: &mut dyn DataLoader) -> Result<f64> {
        self.model.train(false);
        let mut total_loss = 0.0;
        let mut num_batches = 0;

        while let Some((inputs, targets)) = dataloader.next_batch()? {
            // Forward pass (no gradients)
            let predictions = self.model.forward(&inputs)?;

            // Compute loss
            let loss = self.loss_fn.forward(&predictions, &targets)?;
            total_loss += loss.data[[0]];

            num_batches += 1;
        }

        let avg_loss = if num_batches > 0 {
            total_loss / num_batches as f64
        } else {
            0.0
        };
        self.history.add_validation(avg_loss, None);
        Ok(avg_loss)
    }

    /// Get training history
    pub fn history(&self) -> &TrainingHistory {
        &self.history
    }
}

/// Data loader trait
pub trait DataLoader {
    /// Get next batch
    fn next_batch(&mut self) -> Result<Option<(SciRS2Array, SciRS2Array)>>;

    /// Reset to beginning
    fn reset(&mut self);

    /// Get batch size
    fn batch_size(&self) -> usize;
}

/// Simple in-memory data loader
pub struct MemoryDataLoader {
    /// Input data
    inputs: SciRS2Array,
    /// Target data
    targets: SciRS2Array,
    /// Batch size
    batch_size: usize,
    /// Current position
    current_pos: usize,
    /// Shuffle data
    shuffle: bool,
    /// Indices for shuffling
    indices: Vec<usize>,
}

impl MemoryDataLoader {
    /// Create new memory data loader
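    ///
    /// A construction sketch (illustrative):
    ///
    /// ```ignore
    /// use scirs2_core::ndarray::{ArrayD, IxDyn};
    ///
    /// // 100 samples of 4 features, with one target value each.
    /// let inputs = SciRS2Array::new(ArrayD::zeros(IxDyn(&[100, 4])), false);
    /// let targets = SciRS2Array::new(ArrayD::zeros(IxDyn(&[100, 1])), false);
    ///
    /// let mut loader = MemoryDataLoader::new(inputs, targets, 32, true)?;
    /// while let Some((x, y)) = loader.next_batch()? {
    ///     // batches of up to 32 samples
    /// }
    /// ```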
    pub fn new(
        inputs: SciRS2Array,
        targets: SciRS2Array,
        batch_size: usize,
        shuffle: bool,
    ) -> Result<Self> {
        let num_samples = inputs.data.shape()[0];
        if targets.data.shape()[0] != num_samples {
            return Err(MLError::InvalidConfiguration(
                "Input and target sample counts don't match".to_string(),
            ));
        }

        let indices: Vec<usize> = (0..num_samples).collect();

        Ok(Self {
            inputs,
            targets,
            batch_size,
            current_pos: 0,
            shuffle,
            indices,
        })
    }

    /// Shuffle indices using Fisher-Yates
    fn shuffle_indices(&mut self) {
        if self.shuffle {
            for i in (1..self.indices.len()).rev() {
                let j = fastrand::usize(0..=i);
                self.indices.swap(i, j);
            }
        }
    }
}

impl DataLoader for MemoryDataLoader {
    fn next_batch(&mut self) -> Result<Option<(SciRS2Array, SciRS2Array)>> {
        if self.current_pos >= self.indices.len() {
            return Ok(None);
        }

        let end_pos = (self.current_pos + self.batch_size).min(self.indices.len());
        let batch_indices = &self.indices[self.current_pos..end_pos];

        // Gather the selected samples along the first axis. This assumes
        // ndarray's `select` is available through the scirs2_core re-export.
        let batch_inputs = SciRS2Array::new(
            self.inputs.data.select(Axis(0), batch_indices),
            self.inputs.requires_grad,
        );
        let batch_targets = SciRS2Array::new(
            self.targets.data.select(Axis(0), batch_indices),
            self.targets.requires_grad,
        );

        self.current_pos = end_pos;

        Ok(Some((batch_inputs, batch_targets)))
    }

    fn reset(&mut self) {
        self.current_pos = 0;
        self.shuffle_indices();
    }

    fn batch_size(&self) -> usize {
        self.batch_size
    }
}

/// Utility functions for building quantum models
pub mod quantum_nn {
    use super::*;

    /// Create a simple quantum feedforward network
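    ///
    /// A usage sketch (illustrative):
    ///
    /// ```ignore
    /// // 4 inputs -> hidden layers of 16 and 8 -> 2 outputs, QTanh between layers.
    /// let model = quantum_nn::create_feedforward(
    ///     4,
    ///     &[16, 8],
    ///     2,
    ///     ActivationType::QTanh,
    /// )?;
    /// assert_eq!(model.len(), 5); // 3 linear layers + 2 activations
    /// ```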
    pub fn create_feedforward(
        input_size: usize,
        hidden_sizes: &[usize],
        output_size: usize,
        activation: ActivationType,
    ) -> Result<QuantumSequential> {
        let mut model = QuantumSequential::new();

        let mut prev_size = input_size;

        // Hidden layers
        for &hidden_size in hidden_sizes {
            model = model.add(Box::new(
                QuantumLinear::new(prev_size, hidden_size)?.with_bias()?,
            ));
            model = model.add(Box::new(QuantumActivation::new(activation.clone())));
            prev_size = hidden_size;
        }

        // Output layer (no activation; apply one externally if needed)
        model = model.add(Box::new(
            QuantumLinear::new(prev_size, output_size)?.with_bias()?,
        ));

        Ok(model)
    }

    /// Create quantum CNN (simplified: a complete network would flatten the
    /// convolutional feature maps before the final linear layer)
    pub fn create_cnn(input_channels: usize, num_classes: usize) -> Result<QuantumSequential> {
        let model = QuantumSequential::new()
            .add(Box::new(
                QuantumConv2d::new(input_channels, 32, (3, 3))?.with_bias()?,
            ))
            .add(Box::new(QuantumActivation::relu()))
            .add(Box::new(QuantumConv2d::new(32, 64, (3, 3))?.with_bias()?))
            .add(Box::new(QuantumActivation::relu()))
            .add(Box::new(QuantumLinear::new(64, num_classes)?.with_bias()?));

        Ok(model)
    }

    /// Initialize model parameters
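    ///
    /// A usage sketch (illustrative; see the storage-sharing caveat in the
    /// function body):
    ///
    /// ```ignore
    /// let mut model = quantum_nn::create_feedforward(4, &[8], 2, ActivationType::QReLU)?;
    /// quantum_nn::init_parameters(&mut model, InitType::Xavier)?;
    /// ```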
    pub fn init_parameters(model: &mut dyn QuantumModule, init_type: InitType) -> Result<()> {
        // NOTE: `parameters()` returns cloned handles, so the in-place writes
        // below only reach the model if SciRS2Array clones share storage.
        for mut param in model.parameters() {
            match init_type {
                InitType::Xavier => {
                    // Xavier/Glorot uniform initialization
                    let fan_in = param.shape().iter().rev().skip(1).product::<usize>() as f64;
                    let fan_out = param.shape()[0] as f64;
                    let bound = (6.0 / (fan_in + fan_out)).sqrt();

                    for elem in param.data.data.iter_mut() {
                        *elem = (fastrand::f64() * 2.0 - 1.0) * bound;
                    }
                }
                InitType::He => {
                    // He initialization: N(0, sqrt(2 / fan_in))
                    let fan_in = param.shape().iter().rev().skip(1).product::<usize>() as f64;
                    let std = (2.0 / fan_in).sqrt();

                    for elem in param.data.data.iter_mut() {
                        *elem = std * randn();
                    }
                }
                InitType::Normal(mean, std) => {
                    // Normal initialization: N(mean, std)
                    for elem in param.data.data.iter_mut() {
                        *elem = mean + std * randn();
                    }
                }
                InitType::Uniform(low, high) => {
                    // Uniform initialization on [low, high)
                    for elem in param.data.data.iter_mut() {
                        *elem = low + (high - low) * fastrand::f64();
                    }
                }
            }
        }
        Ok(())
    }

    /// Sample from a standard normal distribution via the Box-Muller transform
    fn randn() -> f64 {
        let u1 = fastrand::f64().max(f64::MIN_POSITIVE); // avoid ln(0)
        let u2 = fastrand::f64();
        (-2.0 * u1.ln()).sqrt() * (std::f64::consts::TAU * u2).cos()
    }
}

/// Parameter initialization types
#[derive(Debug, Clone, Copy)]
pub enum InitType {
    /// Xavier/Glorot initialization
    Xavier,
    /// He initialization
    He,
    /// Normal distribution (mean, std)
    Normal(f64, f64),
    /// Uniform distribution (low, high)
    Uniform(f64, f64),
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_quantum_linear() {
        let linear = QuantumLinear::new(4, 2).unwrap();
        assert_eq!(linear.in_features, 4);
        assert_eq!(linear.out_features, 2);
        assert_eq!(linear.parameters().len(), 1); // weights only

        let linear_with_bias = linear.with_bias().unwrap();
        assert_eq!(linear_with_bias.parameters().len(), 2); // weights and bias
    }

    #[test]
    fn test_quantum_sequential() {
        let model = QuantumSequential::new()
            .add(Box::new(QuantumLinear::new(4, 8).unwrap()))
            .add(Box::new(QuantumActivation::relu()))
            .add(Box::new(QuantumLinear::new(8, 2).unwrap()));

        assert_eq!(model.len(), 3);
        assert!(!model.is_empty());
    }

    #[test]
    fn test_quantum_activation() {
        let mut relu = QuantumActivation::relu();
        let input_data = ArrayD::from_shape_vec(IxDyn(&[2]), vec![-1.0, 1.0]).unwrap();
        let input = SciRS2Array::new(input_data, false);

        let output = relu.forward(&input).unwrap();
        assert_eq!(output.data[[0]], 0.0); // ReLU(-1) = 0
        assert_eq!(output.data[[1]], 1.0); // ReLU(1) = 1
    }

    #[test]
    #[ignore]
    fn test_quantum_loss() {
        let mse_loss = QuantumMSELoss;

        let pred_data = ArrayD::from_shape_vec(IxDyn(&[2]), vec![1.0, 2.0]).unwrap();
        let target_data = ArrayD::from_shape_vec(IxDyn(&[2]), vec![1.5, 1.8]).unwrap();

        let predictions = SciRS2Array::new(pred_data, false);
        let targets = SciRS2Array::new(target_data, false);

        let loss = mse_loss.forward(&predictions, &targets).unwrap();
        assert!(loss.data[[0]] > 0.0); // Should have positive loss
    }

    #[test]
    fn test_parameter() {
        let data = ArrayD::from_shape_vec(IxDyn(&[2, 3]), vec![1.0; 6]).unwrap();
        let param = Parameter::new(SciRS2Array::new(data, true), "test_param");

        assert_eq!(param.name, "test_param");
        assert!(param.requires_grad);
        assert_eq!(param.shape(), &[2, 3]);
        assert_eq!(param.numel(), 6);
    }

    #[test]
    fn test_training_history() {
        let mut history = TrainingHistory::new();
        history.add_training(0.5, Some(0.8));
        history.add_validation(0.6, Some(0.7));

        assert_eq!(history.losses.len(), 1);
        assert_eq!(history.accuracies.len(), 1);
        assert_eq!(history.val_losses.len(), 1);
        assert_eq!(history.val_accuracies.len(), 1);
    }
}