quantrs2_ml/pytorch_api.rs

//! PyTorch-like API for quantum machine learning models
//!
//! This module provides a familiar PyTorch-style interface for building,
//! training, and deploying quantum ML models, making it easier for classical
//! ML practitioners to adopt quantum algorithms.
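//!
//! # Example
//!
//! A minimal usage sketch (illustrative only; it assumes this module is
//! exposed as `quantrs2_ml::pytorch_api` and that `input` is a `SciRS2Array`
//! whose trailing dimension matches the first layer's input size):
//!
//! ```rust,ignore
//! use quantrs2_ml::pytorch_api::{quantum_nn, ActivationType, QuantumModule};
//!
//! // Build a 4 -> 8 -> 2 quantum feedforward network with QReLU activations.
//! let mut model = quantum_nn::create_feedforward(4, &[8], 2, ActivationType::QReLU)?;
//! let output = model.forward(&input)?;
//! ```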

use crate::circuit_integration::QuantumMLExecutor;
use crate::error::{MLError, Result};
use crate::scirs2_integration::{SciRS2Array, SciRS2Optimizer};
use crate::simulator_backends::{Observable, SimulatorBackend};
use quantrs2_circuit::prelude::*;
use scirs2_core::ndarray::{Array1, Array2, ArrayD, Axis, Dimension, IxDyn};
use std::cell::RefCell;
use std::collections::HashMap;
use std::rc::Rc;

/// Base trait for all quantum ML modules
pub trait QuantumModule: Send + Sync {
    /// Forward pass
    fn forward(&mut self, input: &SciRS2Array) -> Result<SciRS2Array>;

    /// Get all parameters
    fn parameters(&self) -> Vec<Parameter>;

    /// Set training mode
    fn train(&mut self, mode: bool);

    /// Check if module is in training mode
    fn training(&self) -> bool;

    /// Zero gradients of all parameters
    fn zero_grad(&mut self);

    /// Module name for debugging
    fn name(&self) -> &str;
}

/// Quantum parameter wrapper
#[derive(Debug, Clone)]
pub struct Parameter {
    /// Parameter data
    pub data: SciRS2Array,
    /// Parameter name
    pub name: String,
    /// Whether parameter requires gradient
    pub requires_grad: bool,
}

impl Parameter {
    /// Create new parameter
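    ///
    /// Illustrative sketch (assumes `SciRS2Array::with_grad` wraps an ndarray
    /// and enables gradient tracking, as used elsewhere in this module):
    ///
    /// ```rust,ignore
    /// let data = ArrayD::zeros(IxDyn(&[2, 3]));
    /// let p = Parameter::new(SciRS2Array::with_grad(data), "weight");
    /// assert_eq!(p.numel(), 6);
    /// ```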
    pub fn new(data: SciRS2Array, name: impl Into<String>) -> Self {
        Self {
            data,
            name: name.into(),
            requires_grad: true,
        }
    }

    /// Create parameter without gradients
    pub fn no_grad(data: SciRS2Array, name: impl Into<String>) -> Self {
        Self {
            data,
            name: name.into(),
            requires_grad: false,
        }
    }

    /// Get parameter shape
    pub fn shape(&self) -> &[usize] {
        self.data.data.shape()
    }

    /// Get parameter size
    pub fn numel(&self) -> usize {
        self.data.data.len()
    }
}

/// Quantum linear layer
pub struct QuantumLinear {
    /// Weight parameters
    weights: Parameter,
    /// Bias parameters (optional)
    bias: Option<Parameter>,
    /// Input features
    in_features: usize,
    /// Output features
    out_features: usize,
    /// Training mode
    training: bool,
    /// Circuit executor
    executor: QuantumMLExecutor<8>, // Fixed const size for now
}

impl QuantumLinear {
    /// Create new quantum linear layer
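    ///
    /// Usage sketch (illustrative; error handling elided):
    ///
    /// ```rust,ignore
    /// let mut layer = QuantumLinear::new(4, 2)?.with_bias()?;
    /// layer.init_xavier_uniform()?;
    /// ```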
    pub fn new(in_features: usize, out_features: usize) -> Result<Self> {
        let weight_data = ArrayD::zeros(IxDyn(&[out_features, in_features]));
        let weights = Parameter::new(SciRS2Array::with_grad(weight_data), "weight");

        Ok(Self {
            weights,
            bias: None,
            in_features,
            out_features,
            training: true,
            executor: QuantumMLExecutor::new(),
        })
    }

    /// Create with bias
    pub fn with_bias(mut self) -> Result<Self> {
        let bias_data = ArrayD::zeros(IxDyn(&[self.out_features]));
        self.bias = Some(Parameter::new(SciRS2Array::with_grad(bias_data), "bias"));
        Ok(self)
    }

    /// Initialize weights using Xavier/Glorot uniform
    pub fn init_xavier_uniform(&mut self) -> Result<()> {
        let fan_in = self.in_features as f64;
        let fan_out = self.out_features as f64;
        let bound = (6.0 / (fan_in + fan_out)).sqrt();

        for elem in self.weights.data.data.iter_mut() {
            *elem = (fastrand::f64() * 2.0 - 1.0) * bound;
        }

        Ok(())
    }
}

impl QuantumModule for QuantumLinear {
    fn forward(&mut self, input: &SciRS2Array) -> Result<SciRS2Array> {
        // Linear transformation: output = input @ weights (+ bias).
        // `weights` is stored PyTorch-style as [out_features, in_features], so
        // this assumes `matmul` resolves the orientation; otherwise the weight
        // matrix would need to be transposed first.
        let output = input.matmul(&self.weights.data)?;

        if let Some(ref bias) = self.bias {
            output.add(&bias.data)
        } else {
            Ok(output)
        }
    }

    fn parameters(&self) -> Vec<Parameter> {
        let mut params = vec![self.weights.clone()];
        if let Some(ref bias) = self.bias {
            params.push(bias.clone());
        }
        params
    }

    fn train(&mut self, mode: bool) {
        self.training = mode;
    }

    fn training(&self) -> bool {
        self.training
    }

    fn zero_grad(&mut self) {
        self.weights.data.zero_grad();
        if let Some(ref mut bias) = self.bias {
            bias.data.zero_grad();
        }
    }

    fn name(&self) -> &str {
        "QuantumLinear"
    }
}

/// Quantum convolutional layer
pub struct QuantumConv2d {
    /// Convolution parameters
    weights: Parameter,
    /// Bias parameters
    bias: Option<Parameter>,
    /// Input channels
    in_channels: usize,
    /// Output channels
    out_channels: usize,
    /// Kernel size
    kernel_size: (usize, usize),
    /// Stride
    stride: (usize, usize),
    /// Padding
    padding: (usize, usize),
    /// Training mode
    training: bool,
}

impl QuantumConv2d {
    /// Create new quantum conv2d layer
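    ///
    /// Builder-style usage sketch (illustrative):
    ///
    /// ```rust,ignore
    /// let conv = QuantumConv2d::new(3, 16, (3, 3))?
    ///     .stride((2, 2))
    ///     .padding((1, 1))
    ///     .with_bias()?;
    /// ```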
    pub fn new(
        in_channels: usize,
        out_channels: usize,
        kernel_size: (usize, usize),
    ) -> Result<Self> {
        let weight_shape = [out_channels, in_channels, kernel_size.0, kernel_size.1];
        let weight_data = ArrayD::zeros(IxDyn(&weight_shape));
        let weights = Parameter::new(SciRS2Array::with_grad(weight_data), "weight");

        Ok(Self {
            weights,
            bias: None,
            in_channels,
            out_channels,
            kernel_size,
            stride: (1, 1),
            padding: (0, 0),
            training: true,
        })
    }

    /// Set stride
    pub fn stride(mut self, stride: (usize, usize)) -> Self {
        self.stride = stride;
        self
    }

    /// Set padding
    pub fn padding(mut self, padding: (usize, usize)) -> Self {
        self.padding = padding;
        self
    }

    /// Add bias
    pub fn with_bias(mut self) -> Result<Self> {
        let bias_data = ArrayD::zeros(IxDyn(&[self.out_channels]));
        self.bias = Some(Parameter::new(SciRS2Array::with_grad(bias_data), "bias"));
        Ok(self)
    }
}

impl QuantumModule for QuantumConv2d {
    fn forward(&mut self, input: &SciRS2Array) -> Result<SciRS2Array> {
        // Quantum convolution implementation (simplified)
        // In practice, this would implement quantum convolution operations
        let output_data = input.data.clone(); // Placeholder
        let mut output = SciRS2Array::new(output_data, input.requires_grad);

        if let Some(ref bias) = self.bias {
            output = output.add(&bias.data)?;
        }

        Ok(output)
    }

    fn parameters(&self) -> Vec<Parameter> {
        let mut params = vec![self.weights.clone()];
        if let Some(ref bias) = self.bias {
            params.push(bias.clone());
        }
        params
    }

    fn train(&mut self, mode: bool) {
        self.training = mode;
    }

    fn training(&self) -> bool {
        self.training
    }

    fn zero_grad(&mut self) {
        self.weights.data.zero_grad();
        if let Some(ref mut bias) = self.bias {
            bias.data.zero_grad();
        }
    }

    fn name(&self) -> &str {
        "QuantumConv2d"
    }
}

/// Quantum activation functions
pub struct QuantumActivation {
    /// Activation function type
    activation_type: ActivationType,
    /// Training mode
    training: bool,
}

/// Activation function types
#[derive(Debug, Clone)]
pub enum ActivationType {
    /// Quantum ReLU (using rotation gates)
    QReLU,
    /// Quantum Sigmoid
    QSigmoid,
    /// Quantum Tanh
    QTanh,
    /// Quantum Softmax
    QSoftmax,
    /// Identity (no activation)
    Identity,
}

impl QuantumActivation {
    /// Create new activation layer
    pub fn new(activation_type: ActivationType) -> Self {
        Self {
            activation_type,
            training: true,
        }
    }

    /// Create ReLU activation
    pub fn relu() -> Self {
        Self::new(ActivationType::QReLU)
    }

    /// Create Sigmoid activation
    pub fn sigmoid() -> Self {
        Self::new(ActivationType::QSigmoid)
    }

    /// Create Tanh activation
    pub fn tanh() -> Self {
        Self::new(ActivationType::QTanh)
    }

    /// Create Softmax activation
    pub fn softmax() -> Self {
        Self::new(ActivationType::QSoftmax)
    }
}

impl QuantumModule for QuantumActivation {
    fn forward(&mut self, input: &SciRS2Array) -> Result<SciRS2Array> {
        // Note: these are classical stand-ins for the corresponding quantum
        // activations; a full implementation would realize them with
        // parameterized rotation gates.
        match self.activation_type {
            ActivationType::QReLU => {
                // ReLU: max(0, x)
                let output_data = input.data.mapv(|x| x.max(0.0));
                Ok(SciRS2Array::new(output_data, input.requires_grad))
            }
            ActivationType::QSigmoid => {
                // Sigmoid: 1 / (1 + e^(-x))
                let output_data = input.data.mapv(|x| 1.0 / (1.0 + (-x).exp()));
                Ok(SciRS2Array::new(output_data, input.requires_grad))
            }
            ActivationType::QTanh => {
                // Tanh
                let output_data = input.data.mapv(|x| x.tanh());
                Ok(SciRS2Array::new(output_data, input.requires_grad))
            }
            ActivationType::QSoftmax => {
                // Numerically stable softmax over the whole array: subtract
                // the global max before exponentiating, then normalize.
                let max_val = input.data.iter().fold(f64::NEG_INFINITY, |a, &b| a.max(b));
                let exp_data = input.data.mapv(|x| (x - max_val).exp());
                let sum_exp = exp_data.sum();
                let output_data = exp_data.mapv(|x| x / sum_exp);
                Ok(SciRS2Array::new(output_data, input.requires_grad))
            }
            ActivationType::Identity => {
                Ok(SciRS2Array::new(input.data.clone(), input.requires_grad))
            }
        }
    }

    fn parameters(&self) -> Vec<Parameter> {
        Vec::new() // Activation functions typically don't have parameters
    }

    fn train(&mut self, mode: bool) {
        self.training = mode;
    }

    fn training(&self) -> bool {
        self.training
    }

    fn zero_grad(&mut self) {
        // No parameters to zero
    }

    fn name(&self) -> &str {
        "QuantumActivation"
    }
}

/// Sequential container for quantum modules
pub struct QuantumSequential {
    /// Ordered modules
    modules: Vec<Box<dyn QuantumModule>>,
    /// Training mode
    training: bool,
}

impl QuantumSequential {
    /// Create new sequential container
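    ///
    /// Chaining sketch (illustrative; error handling elided):
    ///
    /// ```rust,ignore
    /// let model = QuantumSequential::new()
    ///     .add(Box::new(QuantumLinear::new(4, 8)?))
    ///     .add(Box::new(QuantumActivation::relu()))
    ///     .add(Box::new(QuantumLinear::new(8, 2)?));
    /// ```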
    pub fn new() -> Self {
        Self {
            modules: Vec::new(),
            training: true,
        }
    }

    /// Add module to sequence
    pub fn add(mut self, module: Box<dyn QuantumModule>) -> Self {
        self.modules.push(module);
        self
    }

    /// Get number of modules
    pub fn len(&self) -> usize {
        self.modules.len()
    }

    /// Check if empty
    pub fn is_empty(&self) -> bool {
        self.modules.is_empty()
    }
}

impl QuantumModule for QuantumSequential {
    fn forward(&mut self, input: &SciRS2Array) -> Result<SciRS2Array> {
        let mut output = input.clone();

        for module in &mut self.modules {
            output = module.forward(&output)?;
        }

        Ok(output)
    }

    fn parameters(&self) -> Vec<Parameter> {
        let mut all_params = Vec::new();

        for module in &self.modules {
            all_params.extend(module.parameters());
        }

        all_params
    }

    fn train(&mut self, mode: bool) {
        self.training = mode;
        for module in &mut self.modules {
            module.train(mode);
        }
    }

    fn training(&self) -> bool {
        self.training
    }

    fn zero_grad(&mut self) {
        for module in &mut self.modules {
            module.zero_grad();
        }
    }

    fn name(&self) -> &str {
        "QuantumSequential"
    }
}

/// Loss functions for quantum ML
pub trait QuantumLoss {
    /// Compute loss
    fn forward(&self, predictions: &SciRS2Array, targets: &SciRS2Array) -> Result<SciRS2Array>;

    /// Loss function name
    fn name(&self) -> &str;
}

/// Mean Squared Error loss
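///
/// Computes `mean((predictions - targets)^2)` over all elements; the result is
/// returned as a single-element array. Usage sketch (illustrative):
///
/// ```rust,ignore
/// let loss = QuantumMSELoss.forward(&predictions, &targets)?;
/// println!("mse = {}", loss.data[[0]]);
/// ```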
pub struct QuantumMSELoss;

impl QuantumLoss for QuantumMSELoss {
    fn forward(&self, predictions: &SciRS2Array, targets: &SciRS2Array) -> Result<SciRS2Array> {
        let diff = predictions.data.clone() - &targets.data;
        let squared_diff = &diff * &diff;
        let mse = squared_diff.mean().ok_or_else(|| {
            MLError::InvalidConfiguration("Cannot compute mean of empty array".to_string())
        })?;
        // Return the scalar as a 1-element array so callers can index it with
        // `loss.data[[0]]` (a zero-dimensional array would panic on that index).
        let loss_data = ArrayD::from_elem(IxDyn(&[1]), mse);
        Ok(SciRS2Array::new(loss_data, predictions.requires_grad))
    }

    fn name(&self) -> &str {
        "MSELoss"
    }
}

/// Cross Entropy loss
pub struct QuantumCrossEntropyLoss;

impl QuantumLoss for QuantumCrossEntropyLoss {
    fn forward(&self, predictions: &SciRS2Array, targets: &SciRS2Array) -> Result<SciRS2Array> {
        // Numerically stable softmax of the predictions (global max trick).
        let max_val = predictions
            .data
            .iter()
            .fold(f64::NEG_INFINITY, |a, &b| a.max(b));
        let exp_preds = predictions.data.mapv(|x| (x - max_val).exp());
        let sum_exp = exp_preds.sum();
        let softmax = exp_preds.mapv(|x| x / sum_exp);

        // Cross entropy against the (assumed one-hot or probability) targets:
        // -sum(targets * log(softmax)).
        let log_softmax = softmax.mapv(|x| x.ln());
        let cross_entropy = -(&targets.data * &log_softmax).sum();

        // 1-element array, matching QuantumMSELoss.
        let loss_data = ArrayD::from_elem(IxDyn(&[1]), cross_entropy);
        Ok(SciRS2Array::new(loss_data, predictions.requires_grad))
    }

    fn name(&self) -> &str {
        "CrossEntropyLoss"
    }
}

/// Training utilities
pub struct QuantumTrainer {
    /// Model to train
    model: Box<dyn QuantumModule>,
    /// Optimizer
    optimizer: SciRS2Optimizer,
    /// Loss function
    loss_fn: Box<dyn QuantumLoss>,
    /// Training history
    history: TrainingHistory,
}

/// Training history
#[derive(Debug, Clone)]
pub struct TrainingHistory {
    /// Loss values per epoch
    pub losses: Vec<f64>,
    /// Accuracy values per epoch (if applicable)
    pub accuracies: Vec<f64>,
    /// Validation losses
    pub val_losses: Vec<f64>,
    /// Validation accuracies
    pub val_accuracies: Vec<f64>,
}

impl TrainingHistory {
    /// Create new training history
    pub fn new() -> Self {
        Self {
            losses: Vec::new(),
            accuracies: Vec::new(),
            val_losses: Vec::new(),
            val_accuracies: Vec::new(),
        }
    }

    /// Add training metrics
    pub fn add_training(&mut self, loss: f64, accuracy: Option<f64>) {
        self.losses.push(loss);
        if let Some(acc) = accuracy {
            self.accuracies.push(acc);
        }
    }

    /// Add validation metrics
    pub fn add_validation(&mut self, loss: f64, accuracy: Option<f64>) {
        self.val_losses.push(loss);
        if let Some(acc) = accuracy {
            self.val_accuracies.push(acc);
        }
    }
}

impl QuantumTrainer {
    /// Create new trainer
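    ///
    /// Training-loop sketch (illustrative; `model`, `optimizer`, and `loader`
    /// are assumed to be constructed elsewhere):
    ///
    /// ```rust,ignore
    /// let mut trainer = QuantumTrainer::new(model, optimizer, Box::new(QuantumMSELoss));
    /// for epoch in 0..10 {
    ///     let loss = trainer.train_epoch(&mut loader)?;
    ///     println!("epoch {epoch}: loss = {loss:.4}");
    /// }
    /// ```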
    pub fn new(
        model: Box<dyn QuantumModule>,
        optimizer: SciRS2Optimizer,
        loss_fn: Box<dyn QuantumLoss>,
    ) -> Self {
        Self {
            model,
            optimizer,
            loss_fn,
            history: TrainingHistory::new(),
        }
    }

    /// Train for one epoch
    pub fn train_epoch(&mut self, dataloader: &mut dyn DataLoader) -> Result<f64> {
        self.model.train(true);
        let mut total_loss = 0.0;
        let mut num_batches = 0;

        while let Some((inputs, targets)) = dataloader.next_batch()? {
            // Zero gradients
            self.model.zero_grad();

            // Forward pass
            let predictions = self.model.forward(&inputs)?;

            // Compute loss
            let loss = self.loss_fn.forward(&predictions, &targets)?;
            total_loss += loss.data[[0]];

            // Backward pass
            // loss.backward()?; // Backpropagation not yet implemented

            // Optimizer step
            let mut params = HashMap::new();
            for (i, param) in self.model.parameters().iter().enumerate() {
                params.insert(format!("param_{}", i), param.data.clone());
            }
            self.optimizer.step(&mut params)?;

            num_batches += 1;
        }

        if num_batches == 0 {
            return Err(MLError::InvalidConfiguration(
                "DataLoader produced no batches".to_string(),
            ));
        }

        let avg_loss = total_loss / num_batches as f64;
        self.history.add_training(avg_loss, None);
        Ok(avg_loss)
    }

    /// Evaluate on validation set
    pub fn evaluate(&mut self, dataloader: &mut dyn DataLoader) -> Result<f64> {
        self.model.train(false);
        let mut total_loss = 0.0;
        let mut num_batches = 0;

        while let Some((inputs, targets)) = dataloader.next_batch()? {
            // Forward pass (no gradients)
            let predictions = self.model.forward(&inputs)?;

            // Compute loss
            let loss = self.loss_fn.forward(&predictions, &targets)?;
            total_loss += loss.data[[0]];

            num_batches += 1;
        }

        if num_batches == 0 {
            return Err(MLError::InvalidConfiguration(
                "DataLoader produced no batches".to_string(),
            ));
        }

        let avg_loss = total_loss / num_batches as f64;
        self.history.add_validation(avg_loss, None);
        Ok(avg_loss)
    }

    /// Get training history
    pub fn history(&self) -> &TrainingHistory {
        &self.history
    }
}

/// Data loader trait
pub trait DataLoader {
    /// Get next batch
    fn next_batch(&mut self) -> Result<Option<(SciRS2Array, SciRS2Array)>>;

    /// Reset to beginning
    fn reset(&mut self);

    /// Get batch size
    fn batch_size(&self) -> usize;
}

/// Simple in-memory data loader
pub struct MemoryDataLoader {
    /// Input data
    inputs: SciRS2Array,
    /// Target data
    targets: SciRS2Array,
    /// Batch size
    batch_size: usize,
    /// Current position
    current_pos: usize,
    /// Shuffle data
    shuffle: bool,
    /// Indices for shuffling
    indices: Vec<usize>,
}

impl MemoryDataLoader {
    /// Create new memory data loader
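    ///
    /// Construction sketch (illustrative; `inputs` and `targets` must share
    /// the same leading (sample) dimension):
    ///
    /// ```rust,ignore
    /// let mut loader = MemoryDataLoader::new(inputs, targets, 32, true)?;
    /// while let Some((x, y)) = loader.next_batch()? { /* ... */ }
    /// ```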
    pub fn new(
        inputs: SciRS2Array,
        targets: SciRS2Array,
        batch_size: usize,
        shuffle: bool,
    ) -> Result<Self> {
        let num_samples = inputs.data.shape()[0];
        if targets.data.shape()[0] != num_samples {
            return Err(MLError::InvalidConfiguration(
                "Input and target batch sizes don't match".to_string(),
            ));
        }

        let indices: Vec<usize> = (0..num_samples).collect();

        let mut loader = Self {
            inputs,
            targets,
            batch_size,
            current_pos: 0,
            shuffle,
            indices,
        };
        // Shuffle up front so the first epoch is shuffled too, not only after
        // `reset()`.
        loader.shuffle_indices();
        Ok(loader)
    }

    /// Shuffle indices
    fn shuffle_indices(&mut self) {
        if self.shuffle {
            // Simple shuffle using Fisher-Yates
            for i in (1..self.indices.len()).rev() {
                let j = fastrand::usize(0..=i);
                self.indices.swap(i, j);
            }
        }
    }
}

impl DataLoader for MemoryDataLoader {
    fn next_batch(&mut self) -> Result<Option<(SciRS2Array, SciRS2Array)>> {
        if self.current_pos >= self.indices.len() {
            return Ok(None);
        }

        let end_pos = (self.current_pos + self.batch_size).min(self.indices.len());
        let _batch_indices = &self.indices[self.current_pos..end_pos];

        // Extract batch data (placeholder: returns the full dataset; a real
        // implementation would gather the rows selected by `_batch_indices`)
        let batch_inputs = self.inputs.clone();
        let batch_targets = self.targets.clone();

        self.current_pos = end_pos;

        Ok(Some((batch_inputs, batch_targets)))
    }

    fn reset(&mut self) {
        self.current_pos = 0;
        self.shuffle_indices();
    }

    fn batch_size(&self) -> usize {
        self.batch_size
    }
}

/// Utility functions for building quantum models
pub mod quantum_nn {
    use super::*;

    /// Create a simple quantum feedforward network
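    ///
    /// Sketch (illustrative): a 4 -> 16 -> 8 -> 2 network with QTanh
    /// activations between the hidden layers:
    ///
    /// ```rust,ignore
    /// let model = create_feedforward(4, &[16, 8], 2, ActivationType::QTanh)?;
    /// assert_eq!(model.len(), 5); // 3 linear layers + 2 activations
    /// ```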
    pub fn create_feedforward(
        input_size: usize,
        hidden_sizes: &[usize],
        output_size: usize,
        activation: ActivationType,
    ) -> Result<QuantumSequential> {
        let mut model = QuantumSequential::new();

        let mut prev_size = input_size;

        // Hidden layers
        for &hidden_size in hidden_sizes {
            model = model.add(Box::new(
                QuantumLinear::new(prev_size, hidden_size)?.with_bias()?,
            ));
            model = model.add(Box::new(QuantumActivation::new(activation.clone())));
            prev_size = hidden_size;
        }

        // Output layer
        model = model.add(Box::new(
            QuantumLinear::new(prev_size, output_size)?.with_bias()?,
        ));

        Ok(model)
    }

    /// Create quantum CNN
    pub fn create_cnn(input_channels: usize, num_classes: usize) -> Result<QuantumSequential> {
        // Note: a complete model would flatten (and typically pool) the
        // convolutional features before the final linear layer; the conv
        // forward pass is currently a placeholder.
        let model = QuantumSequential::new()
            .add(Box::new(
                QuantumConv2d::new(input_channels, 32, (3, 3))?.with_bias()?,
            ))
            .add(Box::new(QuantumActivation::relu()))
            .add(Box::new(QuantumConv2d::new(32, 64, (3, 3))?.with_bias()?))
            .add(Box::new(QuantumActivation::relu()))
            .add(Box::new(QuantumLinear::new(64, num_classes)?.with_bias()?));

        Ok(model)
    }

    /// Initialize model parameters
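    ///
    /// Note: `parameters()` returns owned `Parameter` values, so whether these
    /// writes reach the model depends on `SciRS2Array`'s clone semantics
    /// (shared buffer vs. deep copy).
    ///
    /// Usage sketch (illustrative):
    ///
    /// ```rust,ignore
    /// init_parameters(&mut model, InitType::Xavier)?;
    /// init_parameters(&mut model, InitType::Normal(0.0, 0.01))?;
    /// ```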
    pub fn init_parameters(model: &mut dyn QuantumModule, init_type: InitType) -> Result<()> {
        for mut param in model.parameters() {
            match init_type {
                InitType::Xavier => {
                    // Xavier/Glorot uniform: bound = sqrt(6 / (fan_in + fan_out)),
                    // with fan_in the product of all non-leading dims and
                    // fan_out the leading (output) dim.
                    let fan_in = param.shape()[1..].iter().product::<usize>() as f64;
                    let fan_out = param.shape()[0] as f64;
                    let bound = (6.0 / (fan_in + fan_out)).sqrt();

                    for elem in param.data.data.iter_mut() {
                        *elem = (fastrand::f64() * 2.0 - 1.0) * bound;
                    }
                }
                InitType::He => {
                    // He initialization: N(0, 2 / fan_in)
                    let fan_in = param.shape()[1..].iter().product::<usize>() as f64;
                    let std = (2.0 / fan_in).sqrt();

                    for elem in param.data.data.iter_mut() {
                        *elem = normal_sample(0.0, std);
                    }
                }
                InitType::Normal(mean, std) => {
                    // Normal initialization
                    for elem in param.data.data.iter_mut() {
                        *elem = normal_sample(mean, std);
                    }
                }
                InitType::Uniform(low, high) => {
                    // Uniform initialization
                    for elem in param.data.data.iter_mut() {
                        *elem = low + (high - low) * fastrand::f64();
                    }
                }
            }
        }
        Ok(())
    }

    /// Draw a sample from N(mean, std^2) via the Box-Muller transform.
    fn normal_sample(mean: f64, std: f64) -> f64 {
        let u1 = fastrand::f64().max(f64::MIN_POSITIVE); // avoid ln(0)
        let u2 = fastrand::f64();
        mean + std * (-2.0 * u1.ln()).sqrt() * (std::f64::consts::TAU * u2).cos()
    }
}

/// Parameter initialization types
#[derive(Debug, Clone, Copy)]
pub enum InitType {
    /// Xavier/Glorot initialization
    Xavier,
    /// He initialization
    He,
    /// Normal distribution
    Normal(f64, f64), // mean, std
    /// Uniform distribution
    Uniform(f64, f64), // low, high
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_quantum_linear() {
        let linear = QuantumLinear::new(4, 2).expect("QuantumLinear creation should succeed");
        assert_eq!(linear.in_features, 4);
        assert_eq!(linear.out_features, 2);
        assert_eq!(linear.parameters().len(), 1); // weights only

        let linear_with_bias = linear.with_bias().expect("Adding bias should succeed");
        assert_eq!(linear_with_bias.parameters().len(), 2); // weights and bias
    }

    #[test]
    fn test_quantum_sequential() {
        let model = QuantumSequential::new()
            .add(Box::new(
                QuantumLinear::new(4, 8).expect("QuantumLinear creation should succeed"),
            ))
            .add(Box::new(QuantumActivation::relu()))
            .add(Box::new(
                QuantumLinear::new(8, 2).expect("QuantumLinear creation should succeed"),
            ));

        assert_eq!(model.len(), 3);
        assert!(!model.is_empty());
    }

    #[test]
    fn test_quantum_activation() {
        let mut relu = QuantumActivation::relu();
        let input_data = ArrayD::from_shape_vec(IxDyn(&[2]), vec![-1.0, 1.0])
            .expect("Valid shape for input data");
        let input = SciRS2Array::new(input_data, false);

        let output = relu.forward(&input).expect("Forward pass should succeed");
        assert_eq!(output.data[[0]], 0.0); // ReLU(-1) = 0
        assert_eq!(output.data[[1]], 1.0); // ReLU(1) = 1
    }

    #[test]
    fn test_quantum_loss() {
        let mse_loss = QuantumMSELoss;

        let pred_data = ArrayD::from_shape_vec(IxDyn(&[2]), vec![1.0, 2.0])
            .expect("Valid shape for predictions");
        let target_data =
            ArrayD::from_shape_vec(IxDyn(&[2]), vec![1.5, 1.8]).expect("Valid shape for targets");

        let predictions = SciRS2Array::new(pred_data, false);
        let targets = SciRS2Array::new(target_data, false);

        let loss = mse_loss
            .forward(&predictions, &targets)
            .expect("Loss computation should succeed");
        assert!(loss.data[[0]] > 0.0); // Should have positive loss
    }

    #[test]
    fn test_parameter() {
        let data = ArrayD::from_shape_vec(IxDyn(&[2, 3]), vec![1.0; 6])
            .expect("Valid shape for parameter data");
        let param = Parameter::new(SciRS2Array::new(data, true), "test_param");

        assert_eq!(param.name, "test_param");
        assert!(param.requires_grad);
        assert_eq!(param.shape(), &[2, 3]);
        assert_eq!(param.numel(), 6);
    }

    #[test]
    fn test_training_history() {
        let mut history = TrainingHistory::new();
        history.add_training(0.5, Some(0.8));
        history.add_validation(0.6, Some(0.7));

        assert_eq!(history.losses.len(), 1);
        assert_eq!(history.accuracies.len(), 1);
        assert_eq!(history.val_losses.len(), 1);
        assert_eq!(history.val_accuracies.len(), 1);
    }
}