quantrs2_ml/pytorch_api/layers.rs

//! Basic layers for PyTorch-like API
//!
//! This module contains fundamental layers: Linear, Conv2d, Activations,
//! Normalization, Dropout, Pooling, and Embedding.

use super::{Parameter, QuantumModule};
use crate::circuit_integration::QuantumMLExecutor;
use crate::error::{MLError, Result};
use crate::scirs2_integration::SciRS2Array;
use scirs2_core::ndarray::{ArrayD, IxDyn};

// ============================================================================
// Linear Layer
// ============================================================================

/// Quantum linear layer
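///
/// # Example
///
/// A minimal usage sketch (marked `ignore`, so it is not compiled as a doctest),
/// assuming only the constructors and the `QuantumModule::forward` signature
/// defined in this module:
///
/// ```ignore
/// use scirs2_core::ndarray::{ArrayD, IxDyn};
///
/// let mut layer = QuantumLinear::new(4, 2)?.with_bias()?;
/// layer.init_xavier_uniform()?;
///
/// // Batch of one sample with 4 input features.
/// let input = SciRS2Array::new(ArrayD::zeros(IxDyn(&[1, 4])), false);
/// let output = layer.forward(&input)?;
/// ```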
pub struct QuantumLinear {
    /// Weight parameters, stored as `[out_features, in_features]`
    weights: Parameter,
    /// Bias parameters (optional)
    bias: Option<Parameter>,
    /// Input features
    pub in_features: usize,
    /// Output features
    pub out_features: usize,
    /// Training mode
    training: bool,
    /// Circuit executor
    executor: QuantumMLExecutor<8>,
}

impl QuantumLinear {
    /// Create a new quantum linear layer with zero-initialized weights
    pub fn new(in_features: usize, out_features: usize) -> Result<Self> {
        let weight_data = ArrayD::zeros(IxDyn(&[out_features, in_features]));
        let weights = Parameter::new(SciRS2Array::with_grad(weight_data), "weight");

        Ok(Self {
            weights,
            bias: None,
            in_features,
            out_features,
            training: true,
            executor: QuantumMLExecutor::new(),
        })
    }

    /// Add a zero-initialized bias term (builder style)
    pub fn with_bias(mut self) -> Result<Self> {
        let bias_data = ArrayD::zeros(IxDyn(&[self.out_features]));
        self.bias = Some(Parameter::new(SciRS2Array::with_grad(bias_data), "bias"));
        Ok(self)
    }

    /// Initialize weights using Xavier/Glorot uniform initialization:
    /// U(-b, b) with b = sqrt(6 / (fan_in + fan_out))
    pub fn init_xavier_uniform(&mut self) -> Result<()> {
        let fan_in = self.in_features as f64;
        let fan_out = self.out_features as f64;
        let bound = (6.0 / (fan_in + fan_out)).sqrt();

        for elem in self.weights.data.data.iter_mut() {
            *elem = (fastrand::f64() * 2.0 - 1.0) * bound;
        }

        Ok(())
    }
}

impl QuantumModule for QuantumLinear {
    fn forward(&mut self, input: &SciRS2Array) -> Result<SciRS2Array> {
        // Linear transform followed by an optional bias term.
        let output = input.matmul(&self.weights.data)?;

        if let Some(ref bias) = self.bias {
            output.add(&bias.data)
        } else {
            Ok(output)
        }
    }

    fn parameters(&self) -> Vec<Parameter> {
        let mut params = vec![self.weights.clone()];
        if let Some(ref bias) = self.bias {
            params.push(bias.clone());
        }
        params
    }

    fn train(&mut self, mode: bool) {
        self.training = mode;
    }

    fn training(&self) -> bool {
        self.training
    }

    fn zero_grad(&mut self) {
        self.weights.data.zero_grad();
        if let Some(ref mut bias) = self.bias {
            bias.data.zero_grad();
        }
    }

    fn name(&self) -> &str {
        "QuantumLinear"
    }
}

// ============================================================================
// Conv2d Layer
// ============================================================================

/// Quantum convolutional layer
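///
/// # Example
///
/// A construction sketch (not compiled as a doctest) showing the builder-style
/// configuration defined below; the names used match this module only:
///
/// ```ignore
/// let conv = QuantumConv2d::new(3, 16, (3, 3))?
///     .stride((2, 2))
///     .padding((1, 1))
///     .with_bias()?;
/// ```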
pub struct QuantumConv2d {
    /// Convolution parameters
    weights: Parameter,
    /// Bias parameters
    bias: Option<Parameter>,
    /// Input channels
    pub in_channels: usize,
    /// Output channels
    pub out_channels: usize,
    /// Kernel size
    pub kernel_size: (usize, usize),
    /// Stride
    pub stride: (usize, usize),
    /// Padding
    pub padding: (usize, usize),
    /// Training mode
    training: bool,
}

impl QuantumConv2d {
    /// Create a new quantum conv2d layer with zero-initialized weights
    pub fn new(
        in_channels: usize,
        out_channels: usize,
        kernel_size: (usize, usize),
    ) -> Result<Self> {
        let weight_shape = [out_channels, in_channels, kernel_size.0, kernel_size.1];
        let weight_data = ArrayD::zeros(IxDyn(&weight_shape));
        let weights = Parameter::new(SciRS2Array::with_grad(weight_data), "weight");

        Ok(Self {
            weights,
            bias: None,
            in_channels,
            out_channels,
            kernel_size,
            stride: (1, 1),
            padding: (0, 0),
            training: true,
        })
    }

    /// Set stride (builder style)
    pub fn stride(mut self, stride: (usize, usize)) -> Self {
        self.stride = stride;
        self
    }

    /// Set padding (builder style)
    pub fn padding(mut self, padding: (usize, usize)) -> Self {
        self.padding = padding;
        self
    }

    /// Add a zero-initialized bias term (builder style)
    pub fn with_bias(mut self) -> Result<Self> {
        let bias_data = ArrayD::zeros(IxDyn(&[self.out_channels]));
        self.bias = Some(Parameter::new(SciRS2Array::with_grad(bias_data), "bias"));
        Ok(self)
    }
}

impl QuantumModule for QuantumConv2d {
    fn forward(&mut self, input: &SciRS2Array) -> Result<SciRS2Array> {
        // NOTE: placeholder forward pass. The input is passed through unchanged
        // (plus bias, if present); no spatial convolution is performed yet.
        let output_data = input.data.clone();
        let mut output = SciRS2Array::new(output_data, input.requires_grad);

        if let Some(ref bias) = self.bias {
            output = output.add(&bias.data)?;
        }

        Ok(output)
    }

    fn parameters(&self) -> Vec<Parameter> {
        let mut params = vec![self.weights.clone()];
        if let Some(ref bias) = self.bias {
            params.push(bias.clone());
        }
        params
    }

    fn train(&mut self, mode: bool) {
        self.training = mode;
    }

    fn training(&self) -> bool {
        self.training
    }

    fn zero_grad(&mut self) {
        self.weights.data.zero_grad();
        if let Some(ref mut bias) = self.bias {
            bias.data.zero_grad();
        }
    }

    fn name(&self) -> &str {
        "QuantumConv2d"
    }
}

// ============================================================================
// Activation Functions
// ============================================================================

/// Activation function types
#[derive(Debug, Clone)]
pub enum ActivationType {
    /// Quantum ReLU (using rotation gates)
    QReLU,
    /// Quantum Sigmoid
    QSigmoid,
    /// Quantum Tanh
    QTanh,
    /// Quantum Softmax
    QSoftmax,
    /// Identity (no activation)
    Identity,
}

/// Quantum activation functions
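///
/// # Example
///
/// A usage sketch (not compiled as a doctest) built only from the constructors
/// defined below; `hidden` and `logits` stand in for arbitrary `SciRS2Array`
/// values:
///
/// ```ignore
/// let mut relu = QuantumActivation::relu();
/// let activated = relu.forward(&hidden)?; // element-wise max(x, 0)
///
/// let mut softmax = QuantumActivation::softmax();
/// let probs = softmax.forward(&logits)?; // normalized over the whole tensor
/// ```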
pub struct QuantumActivation {
    /// Activation function type
    activation_type: ActivationType,
    /// Training mode
    training: bool,
}

impl QuantumActivation {
    /// Create new activation layer
    pub fn new(activation_type: ActivationType) -> Self {
        Self {
            activation_type,
            training: true,
        }
    }

    /// Create ReLU activation
    pub fn relu() -> Self {
        Self::new(ActivationType::QReLU)
    }

    /// Create Sigmoid activation
    pub fn sigmoid() -> Self {
        Self::new(ActivationType::QSigmoid)
    }

    /// Create Tanh activation
    pub fn tanh() -> Self {
        Self::new(ActivationType::QTanh)
    }

    /// Create Softmax activation
    pub fn softmax() -> Self {
        Self::new(ActivationType::QSoftmax)
    }
}

impl QuantumModule for QuantumActivation {
    fn forward(&mut self, input: &SciRS2Array) -> Result<SciRS2Array> {
        match self.activation_type {
            ActivationType::QReLU => {
                let output_data = input.data.mapv(|x| x.max(0.0));
                Ok(SciRS2Array::new(output_data, input.requires_grad))
            }
            ActivationType::QSigmoid => {
                let output_data = input.data.mapv(|x| 1.0 / (1.0 + (-x).exp()));
                Ok(SciRS2Array::new(output_data, input.requires_grad))
            }
            ActivationType::QTanh => {
                let output_data = input.data.mapv(|x| x.tanh());
                Ok(SciRS2Array::new(output_data, input.requires_grad))
            }
            ActivationType::QSoftmax => {
                // Subtract the maximum for numerical stability; the result is
                // normalized over all elements of the tensor.
                let max_val = input.data.iter().fold(f64::NEG_INFINITY, |a, &b| a.max(b));
                let exp_data = input.data.mapv(|x| (x - max_val).exp());
                let sum_exp = exp_data.sum();
                let output_data = exp_data.mapv(|x| x / sum_exp);
                Ok(SciRS2Array::new(output_data, input.requires_grad))
            }
            ActivationType::Identity => {
                Ok(SciRS2Array::new(input.data.clone(), input.requires_grad))
            }
        }
    }

    fn parameters(&self) -> Vec<Parameter> {
        Vec::new()
    }

    fn train(&mut self, mode: bool) {
        self.training = mode;
    }

    fn training(&self) -> bool {
        self.training
    }

    fn zero_grad(&mut self) {}

    fn name(&self) -> &str {
        "QuantumActivation"
    }
}

// ============================================================================
// Extended Activation Functions
// ============================================================================

/// Parameter initialization types
#[derive(Debug, Clone, Copy)]
pub enum InitType {
    /// Xavier/Glorot initialization
    Xavier,
    /// He initialization
    He,
    /// Normal distribution (mean, std)
    Normal(f64, f64),
    /// Uniform distribution (low, high)
    Uniform(f64, f64),
}

/// Extended activation function types
#[derive(Debug, Clone)]
pub enum ExtendedActivation {
    /// GELU activation
    GELU,
    /// ELU activation
    ELU { alpha: f64 },
    /// LeakyReLU activation
    LeakyReLU { negative_slope: f64 },
    /// SiLU/Swish activation
    SiLU,
    /// PReLU activation
    PReLU { num_parameters: usize },
    /// Softplus activation
    Softplus { beta: f64 },
    /// Mish activation
    Mish,
    /// Hardswish activation
    Hardswish,
    /// Hardsigmoid activation
    Hardsigmoid,
}

/// Extended activation layer
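///
/// # Example
///
/// A construction sketch (not compiled as a doctest) using the helper
/// constructors defined below; `input` stands in for an arbitrary `SciRS2Array`:
///
/// ```ignore
/// let mut gelu = QuantumExtendedActivation::gelu();
/// let mut leaky = QuantumExtendedActivation::leaky_relu(0.01);
/// let out = gelu.forward(&leaky.forward(&input)?)?;
/// ```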
pub struct QuantumExtendedActivation {
    activation: ExtendedActivation,
    prelu_weights: Option<Parameter>,
    training: bool,
}

impl QuantumExtendedActivation {
    /// Create new extended activation
    pub fn new(activation: ExtendedActivation) -> Self {
        let prelu_weights = match &activation {
            ExtendedActivation::PReLU { num_parameters } => {
                // PReLU slopes are learnable; 0.25 is the conventional initial value.
                let data = ArrayD::from_elem(IxDyn(&[*num_parameters]), 0.25);
                Some(Parameter::new(SciRS2Array::with_grad(data), "weight"))
            }
            _ => None,
        };

        Self {
            activation,
            prelu_weights,
            training: true,
        }
    }

    /// Create GELU activation
    pub fn gelu() -> Self {
        Self::new(ExtendedActivation::GELU)
    }

    /// Create ELU activation
    pub fn elu(alpha: f64) -> Self {
        Self::new(ExtendedActivation::ELU { alpha })
    }

    /// Create LeakyReLU activation
    pub fn leaky_relu(negative_slope: f64) -> Self {
        Self::new(ExtendedActivation::LeakyReLU { negative_slope })
    }

    /// Create SiLU/Swish activation
    pub fn silu() -> Self {
        Self::new(ExtendedActivation::SiLU)
    }

    /// Create Mish activation
    pub fn mish() -> Self {
        Self::new(ExtendedActivation::Mish)
    }
}

impl QuantumModule for QuantumExtendedActivation {
    fn forward(&mut self, input: &SciRS2Array) -> Result<SciRS2Array> {
        let output_data = match &self.activation {
            // Tanh approximation of GELU:
            // 0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3)))
            ExtendedActivation::GELU => input.data.mapv(|x| {
                let sqrt_2_pi = (2.0 / std::f64::consts::PI).sqrt();
                0.5 * x * (1.0 + (sqrt_2_pi * (x + 0.044715 * x.powi(3))).tanh())
            }),
            ExtendedActivation::ELU { alpha } => {
                let a = *alpha;
                input
                    .data
                    .mapv(|x| if x >= 0.0 { x } else { a * (x.exp() - 1.0) })
            }
            ExtendedActivation::LeakyReLU { negative_slope } => {
                let slope = *negative_slope;
                input.data.mapv(|x| if x >= 0.0 { x } else { slope * x })
            }
            // SiLU/Swish: x * sigmoid(x)
            ExtendedActivation::SiLU => input.data.mapv(|x| x / (1.0 + (-x).exp())),
            ExtendedActivation::PReLU { .. } => {
                // Simplification: only the first learnable slope is applied,
                // even when `num_parameters > 1`.
                if let Some(ref weights) = self.prelu_weights {
                    let weight = weights.data.data[[0]];
                    input.data.mapv(|x| if x >= 0.0 { x } else { weight * x })
                } else {
                    input.data.mapv(|x| if x >= 0.0 { x } else { 0.25 * x })
                }
            }
            // Softplus: (1/beta) * ln(1 + exp(beta * x))
            ExtendedActivation::Softplus { beta } => {
                let b = *beta;
                input.data.mapv(|x| (1.0 / b) * (1.0 + (b * x).exp()).ln())
            }
            // Mish: x * tanh(softplus(x))
            ExtendedActivation::Mish => input.data.mapv(|x| x * ((1.0 + x.exp()).ln()).tanh()),
            ExtendedActivation::Hardswish => input.data.mapv(|x| {
                if x <= -3.0 {
                    0.0
                } else if x >= 3.0 {
                    x
                } else {
                    x * (x + 3.0) / 6.0
                }
            }),
            ExtendedActivation::Hardsigmoid => input.data.mapv(|x| {
                if x <= -3.0 {
                    0.0
                } else if x >= 3.0 {
                    1.0
                } else {
                    (x + 3.0) / 6.0
                }
            }),
        };
        Ok(SciRS2Array::new(output_data, input.requires_grad))
    }

    fn parameters(&self) -> Vec<Parameter> {
        self.prelu_weights.iter().cloned().collect()
    }

    fn train(&mut self, mode: bool) {
        self.training = mode;
    }

    fn training(&self) -> bool {
        self.training
    }

    fn zero_grad(&mut self) {
        if let Some(ref mut weights) = self.prelu_weights {
            weights.data.zero_grad();
        }
    }

    fn name(&self) -> &str {
        "ExtendedActivation"
    }
}

// ============================================================================
// Normalization Layers
// ============================================================================

/// Batch normalization layer
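///
/// Normalizes each feature over the batch dimension:
/// `y = weight * (x - mean) / sqrt(var + eps) + bias`,
/// where at training time `mean`/`var` are batch statistics and the running
/// estimates are updated as `running = (1 - momentum) * running + momentum * batch_stat`;
/// at evaluation time the running estimates are used instead.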
pub struct QuantumBatchNorm1d {
    num_features: usize,
    running_mean: Parameter,
    running_var: Parameter,
    weight: Parameter,
    bias: Parameter,
    eps: f64,
    momentum: f64,
    training: bool,
}

impl QuantumBatchNorm1d {
    /// Create new batch normalization layer
    pub fn new(num_features: usize) -> Self {
        let weight_data = ArrayD::ones(IxDyn(&[num_features]));
        let bias_data = ArrayD::zeros(IxDyn(&[num_features]));
        let mean_data = ArrayD::zeros(IxDyn(&[num_features]));
        let var_data = ArrayD::ones(IxDyn(&[num_features]));

        Self {
            num_features,
            running_mean: Parameter::no_grad(SciRS2Array::new(mean_data, false), "running_mean"),
            running_var: Parameter::no_grad(SciRS2Array::new(var_data, false), "running_var"),
            weight: Parameter::new(SciRS2Array::with_grad(weight_data), "weight"),
            bias: Parameter::new(SciRS2Array::with_grad(bias_data), "bias"),
            eps: 1e-5,
            momentum: 0.1,
            training: true,
        }
    }

    /// Set epsilon (builder style)
    pub fn eps(mut self, eps: f64) -> Self {
        self.eps = eps;
        self
    }

    /// Set momentum (builder style)
    pub fn momentum(mut self, momentum: f64) -> Self {
        self.momentum = momentum;
        self
    }
}

impl QuantumModule for QuantumBatchNorm1d {
    fn forward(&mut self, input: &SciRS2Array) -> Result<SciRS2Array> {
        let shape = input.data.shape();
        if shape.len() < 2 || shape[1] != self.num_features {
            return Err(MLError::InvalidConfiguration(format!(
                "Expected {} features, got shape {:?}",
                self.num_features, shape
            )));
        }

        let batch_size = shape[0];
        let mut output = input.data.clone();

        if self.training {
            for f in 0..self.num_features {
                // Batch mean for this feature.
                let mut sum = 0.0;
                for b in 0..batch_size {
                    sum += input.data[[b, f]];
                }
                let mean = sum / batch_size as f64;

                // Biased (population) variance for this feature.
                let mut var_sum = 0.0;
                for b in 0..batch_size {
                    let diff = input.data[[b, f]] - mean;
                    var_sum += diff * diff;
                }
                let var = var_sum / batch_size as f64;

                // Update running statistics with exponential moving averages.
                self.running_mean.data.data[[f]] =
                    (1.0 - self.momentum) * self.running_mean.data.data[[f]] + self.momentum * mean;
                self.running_var.data.data[[f]] =
                    (1.0 - self.momentum) * self.running_var.data.data[[f]] + self.momentum * var;

                let std = (var + self.eps).sqrt();
                for b in 0..batch_size {
                    output[[b, f]] = (input.data[[b, f]] - mean) / std;
                    output[[b, f]] =
                        output[[b, f]] * self.weight.data.data[[f]] + self.bias.data.data[[f]];
                }
            }
        } else {
            // Evaluation mode: normalize with the running statistics.
            for f in 0..self.num_features {
                let mean = self.running_mean.data.data[[f]];
                let var = self.running_var.data.data[[f]];
                let std = (var + self.eps).sqrt();

                for b in 0..batch_size {
                    output[[b, f]] = (input.data[[b, f]] - mean) / std;
                    output[[b, f]] =
                        output[[b, f]] * self.weight.data.data[[f]] + self.bias.data.data[[f]];
                }
            }
        }

        Ok(SciRS2Array::new(output, input.requires_grad))
    }

    fn parameters(&self) -> Vec<Parameter> {
        vec![self.weight.clone(), self.bias.clone()]
    }

    fn train(&mut self, mode: bool) {
        self.training = mode;
    }

    fn training(&self) -> bool {
        self.training
    }

    fn zero_grad(&mut self) {
        self.weight.data.zero_grad();
        self.bias.data.zero_grad();
    }

    fn name(&self) -> &str {
        "BatchNorm1d"
    }
}

/// Layer normalization
pub struct QuantumLayerNorm {
    normalized_shape: Vec<usize>,
    weight: Parameter,
    bias: Parameter,
    eps: f64,
    training: bool,
}

impl QuantumLayerNorm {
    /// Create new layer normalization
    pub fn new(normalized_shape: Vec<usize>) -> Self {
        let size: usize = normalized_shape.iter().product();
        let weight_data = ArrayD::ones(IxDyn(&[size]));
        let bias_data = ArrayD::zeros(IxDyn(&[size]));

        Self {
            normalized_shape,
            weight: Parameter::new(SciRS2Array::with_grad(weight_data), "weight"),
            bias: Parameter::new(SciRS2Array::with_grad(bias_data), "bias"),
            eps: 1e-5,
            training: true,
        }
    }
}

impl QuantumModule for QuantumLayerNorm {
    fn forward(&mut self, input: &SciRS2Array) -> Result<SciRS2Array> {
        // Simplification: mean and variance are computed over the whole tensor
        // rather than per sample over `normalized_shape`.
        let mean: f64 = input.data.iter().sum::<f64>() / input.data.len() as f64;
        let var: f64 =
            input.data.iter().map(|x| (x - mean).powi(2)).sum::<f64>() / input.data.len() as f64;
        let std = (var + self.eps).sqrt();

        let mut output = input.data.mapv(|x| (x - mean) / std);

        // Apply the elementwise affine parameters, cycling through them when the
        // input has more elements than `normalized_shape`.
        for (i, val) in output.iter_mut().enumerate() {
            let idx = i % self.weight.data.data.len();
            *val = *val * self.weight.data.data[[idx]] + self.bias.data.data[[idx]];
        }

        Ok(SciRS2Array::new(output, input.requires_grad))
    }

    fn parameters(&self) -> Vec<Parameter> {
        vec![self.weight.clone(), self.bias.clone()]
    }

    fn train(&mut self, mode: bool) {
        self.training = mode;
    }

    fn training(&self) -> bool {
        self.training
    }

    fn zero_grad(&mut self) {
        self.weight.data.zero_grad();
        self.bias.data.zero_grad();
    }

    fn name(&self) -> &str {
        "LayerNorm"
    }
}

// ============================================================================
// Dropout Layers
// ============================================================================

/// Dropout layer (inverted dropout: surviving elements are scaled by
/// `1 / (1 - p)` during training so that evaluation is a no-op)
pub struct QuantumDropout {
    p: f64,
    training: bool,
}

impl QuantumDropout {
    /// Create new dropout layer with drop probability `p`
    pub fn new(p: f64) -> Self {
        Self { p, training: true }
    }
}

impl QuantumModule for QuantumDropout {
    fn forward(&mut self, input: &SciRS2Array) -> Result<SciRS2Array> {
        if !self.training || self.p == 0.0 {
            return Ok(input.clone());
        }

        // Zero each element with probability p and rescale the survivors.
        let scale = 1.0 / (1.0 - self.p);
        let output = input.data.mapv(|x| {
            if fastrand::f64() < self.p {
                0.0
            } else {
                x * scale
            }
        });

        Ok(SciRS2Array::new(output, input.requires_grad))
    }

    fn parameters(&self) -> Vec<Parameter> {
        Vec::new()
    }

    fn train(&mut self, mode: bool) {
        self.training = mode;
    }

    fn training(&self) -> bool {
        self.training
    }

    fn zero_grad(&mut self) {}

    fn name(&self) -> &str {
        "Dropout"
    }
}

/// Dropout2d for convolutional layers
///
/// Simplification: elements are dropped independently (like [`QuantumDropout`])
/// rather than zeroing whole channels.
pub struct QuantumDropout2d {
    p: f64,
    training: bool,
}

impl QuantumDropout2d {
    /// Create new dropout2d layer with drop probability `p`
    pub fn new(p: f64) -> Self {
        Self { p, training: true }
    }
}

impl QuantumModule for QuantumDropout2d {
    fn forward(&mut self, input: &SciRS2Array) -> Result<SciRS2Array> {
        if !self.training || self.p == 0.0 {
            return Ok(input.clone());
        }

        let scale = 1.0 / (1.0 - self.p);
        let output = input.data.mapv(|x| {
            if fastrand::f64() < self.p {
                0.0
            } else {
                x * scale
            }
        });

        Ok(SciRS2Array::new(output, input.requires_grad))
    }

    fn parameters(&self) -> Vec<Parameter> {
        Vec::new()
    }

    fn train(&mut self, mode: bool) {
        self.training = mode;
    }

    fn training(&self) -> bool {
        self.training
    }

    fn zero_grad(&mut self) {}

    fn name(&self) -> &str {
        "Dropout2d"
    }
}

// ============================================================================
// Pooling Layers
// ============================================================================

/// Max pooling 2D
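///
/// For an input of shape `(batch, channels, height, width)` the output spatial
/// size follows the usual pooling arithmetic:
/// `out = (in + 2 * padding - kernel_size) / stride + 1` (integer division).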
pub struct QuantumMaxPool2d {
    kernel_size: (usize, usize),
    stride: (usize, usize),
    padding: (usize, usize),
    training: bool,
}

impl QuantumMaxPool2d {
    /// Create new max pooling layer (stride defaults to the kernel size)
    pub fn new(kernel_size: (usize, usize)) -> Self {
        Self {
            kernel_size,
            stride: kernel_size,
            padding: (0, 0),
            training: true,
        }
    }

    /// Set stride (builder style)
    pub fn stride(mut self, stride: (usize, usize)) -> Self {
        self.stride = stride;
        self
    }

    /// Set padding (builder style)
    pub fn padding(mut self, padding: (usize, usize)) -> Self {
        self.padding = padding;
        self
    }
}

impl QuantumModule for QuantumMaxPool2d {
    fn forward(&mut self, input: &SciRS2Array) -> Result<SciRS2Array> {
        let shape = input.data.shape();
        if shape.len() != 4 {
            return Err(MLError::InvalidConfiguration(
                "MaxPool2d expects 4D input (batch, channels, height, width)".to_string(),
            ));
        }

        let (batch, channels, height, width) = (shape[0], shape[1], shape[2], shape[3]);
        let out_height = (height + 2 * self.padding.0 - self.kernel_size.0) / self.stride.0 + 1;
        let out_width = (width + 2 * self.padding.1 - self.kernel_size.1) / self.stride.1 + 1;

        let mut output = ArrayD::zeros(IxDyn(&[batch, channels, out_height, out_width]));

        // NOTE: padding enlarges the output size above, but the pooling windows
        // below are not shifted by the padding offset; with the default padding
        // of (0, 0) the result is exact.
        for b in 0..batch {
            for c in 0..channels {
                for oh in 0..out_height {
                    for ow in 0..out_width {
                        let h_start = oh * self.stride.0;
                        let w_start = ow * self.stride.1;

                        let mut max_val = f64::NEG_INFINITY;
                        for kh in 0..self.kernel_size.0 {
                            for kw in 0..self.kernel_size.1 {
                                let h = h_start + kh;
                                let w = w_start + kw;
                                if h < height && w < width {
                                    max_val = max_val.max(input.data[[b, c, h, w]]);
                                }
                            }
                        }
                        output[[b, c, oh, ow]] = max_val;
                    }
                }
            }
        }

        Ok(SciRS2Array::new(output, input.requires_grad))
    }

    fn parameters(&self) -> Vec<Parameter> {
        Vec::new()
    }

    fn train(&mut self, mode: bool) {
        self.training = mode;
    }

    fn training(&self) -> bool {
        self.training
    }

    fn zero_grad(&mut self) {}

    fn name(&self) -> &str {
        "MaxPool2d"
    }
}

/// Average pooling 2D
pub struct QuantumAvgPool2d {
    kernel_size: (usize, usize),
    stride: (usize, usize),
    padding: (usize, usize),
    training: bool,
}

impl QuantumAvgPool2d {
    /// Create new average pooling layer (stride defaults to the kernel size)
    pub fn new(kernel_size: (usize, usize)) -> Self {
        Self {
            kernel_size,
            stride: kernel_size,
            padding: (0, 0),
            training: true,
        }
    }

    /// Set stride (builder style)
    pub fn stride(mut self, stride: (usize, usize)) -> Self {
        self.stride = stride;
        self
    }

    /// Set padding (builder style)
    pub fn padding(mut self, padding: (usize, usize)) -> Self {
        self.padding = padding;
        self
    }
}

impl QuantumModule for QuantumAvgPool2d {
    fn forward(&mut self, input: &SciRS2Array) -> Result<SciRS2Array> {
        let shape = input.data.shape();
        if shape.len() != 4 {
            return Err(MLError::InvalidConfiguration(
                "AvgPool2d expects 4D input (batch, channels, height, width)".to_string(),
            ));
        }

        let (batch, channels, height, width) = (shape[0], shape[1], shape[2], shape[3]);
        let out_height = (height + 2 * self.padding.0 - self.kernel_size.0) / self.stride.0 + 1;
        let out_width = (width + 2 * self.padding.1 - self.kernel_size.1) / self.stride.1 + 1;

        let mut output = ArrayD::zeros(IxDyn(&[batch, channels, out_height, out_width]));

        // NOTE: as in `QuantumMaxPool2d`, the pooling windows are not shifted by
        // the padding offset; the average is taken over the in-bounds elements.
        for b in 0..batch {
            for c in 0..channels {
                for oh in 0..out_height {
                    for ow in 0..out_width {
                        let h_start = oh * self.stride.0;
                        let w_start = ow * self.stride.1;

                        let mut sum = 0.0;
                        let mut count = 0;
                        for kh in 0..self.kernel_size.0 {
                            for kw in 0..self.kernel_size.1 {
                                let h = h_start + kh;
                                let w = w_start + kw;
                                if h < height && w < width {
                                    sum += input.data[[b, c, h, w]];
                                    count += 1;
                                }
                            }
                        }
                        output[[b, c, oh, ow]] = if count > 0 { sum / count as f64 } else { 0.0 };
                    }
                }
            }
        }

        Ok(SciRS2Array::new(output, input.requires_grad))
    }

    fn parameters(&self) -> Vec<Parameter> {
        Vec::new()
    }

    fn train(&mut self, mode: bool) {
        self.training = mode;
    }

    fn training(&self) -> bool {
        self.training
    }

    fn zero_grad(&mut self) {}

    fn name(&self) -> &str {
        "AvgPool2d"
    }
}

/// Adaptive average pooling 2D
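///
/// Each output cell `(oh, ow)` averages the input bin
/// `[oh * H / out_h, (oh + 1) * H / out_h) x [ow * W / out_w, (ow + 1) * W / out_w)`
/// (integer division), so any target output size is supported.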
pub struct QuantumAdaptiveAvgPool2d {
    output_size: (usize, usize),
    training: bool,
}

impl QuantumAdaptiveAvgPool2d {
    /// Create new adaptive average pooling layer
    pub fn new(output_size: (usize, usize)) -> Self {
        Self {
            output_size,
            training: true,
        }
    }
}

impl QuantumModule for QuantumAdaptiveAvgPool2d {
    fn forward(&mut self, input: &SciRS2Array) -> Result<SciRS2Array> {
        let shape = input.data.shape();
        if shape.len() != 4 {
            return Err(MLError::InvalidConfiguration(
                "AdaptiveAvgPool2d expects 4D input (batch, channels, height, width)".to_string(),
            ));
        }

        let (batch, channels, height, width) = (shape[0], shape[1], shape[2], shape[3]);
        let (out_h, out_w) = self.output_size;

        let mut output = ArrayD::zeros(IxDyn(&[batch, channels, out_h, out_w]));

        for b in 0..batch {
            for c in 0..channels {
                for oh in 0..out_h {
                    for ow in 0..out_w {
                        // Input bin covered by this output cell.
                        let h_start = (oh * height) / out_h;
                        let h_end = ((oh + 1) * height) / out_h;
                        let w_start = (ow * width) / out_w;
                        let w_end = ((ow + 1) * width) / out_w;

                        let mut sum = 0.0;
                        let mut count = 0;
                        for h in h_start..h_end {
                            for w in w_start..w_end {
                                sum += input.data[[b, c, h, w]];
                                count += 1;
                            }
                        }
                        output[[b, c, oh, ow]] = if count > 0 { sum / count as f64 } else { 0.0 };
                    }
                }
            }
        }

        Ok(SciRS2Array::new(output, input.requires_grad))
    }

    fn parameters(&self) -> Vec<Parameter> {
        Vec::new()
    }

    fn train(&mut self, mode: bool) {
        self.training = mode;
    }

    fn training(&self) -> bool {
        self.training
    }

    fn zero_grad(&mut self) {}

    fn name(&self) -> &str {
        "AdaptiveAvgPool2d"
    }
}

// ============================================================================
// Embedding Layer
// ============================================================================

/// Embedding layer for discrete inputs
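///
/// # Example
///
/// A lookup sketch (not compiled as a doctest) using only the methods defined
/// below:
///
/// ```ignore
/// let embedding = QuantumEmbedding::new(100, 16).padding_idx(0);
/// // Rows 3 and 7 of the weight matrix, returned as a [2, 16] array.
/// let vectors = embedding.get_embedding(&[3, 7])?;
/// ```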
pub struct QuantumEmbedding {
    num_embeddings: usize,
    embedding_dim: usize,
    weight: Parameter,
    padding_idx: Option<usize>,
    training: bool,
}

impl QuantumEmbedding {
    /// Create new embedding layer with weights drawn uniformly from [-1, 1)
    pub fn new(num_embeddings: usize, embedding_dim: usize) -> Self {
        let weight_data = ArrayD::zeros(IxDyn(&[num_embeddings, embedding_dim]));
        let mut weight = Parameter::new(SciRS2Array::with_grad(weight_data), "weight");

        for val in weight.data.data.iter_mut() {
            *val = fastrand::f64() * 2.0 - 1.0;
        }

        Self {
            num_embeddings,
            embedding_dim,
            weight,
            padding_idx: None,
            training: true,
        }
    }

    /// Set padding index and zero out the corresponding embedding row
    pub fn padding_idx(mut self, idx: usize) -> Self {
        self.padding_idx = Some(idx);
        for j in 0..self.embedding_dim {
            self.weight.data.data[[idx, j]] = 0.0;
        }
        self
    }

    /// Get embeddings for the given indices as a `[len, embedding_dim]` array
    pub fn get_embedding(&self, indices: &[usize]) -> Result<ArrayD<f64>> {
        let mut output = ArrayD::zeros(IxDyn(&[indices.len(), self.embedding_dim]));

        for (i, &idx) in indices.iter().enumerate() {
            if idx >= self.num_embeddings {
                return Err(MLError::InvalidConfiguration(format!(
                    "Index {} out of range for {} embeddings",
                    idx, self.num_embeddings
                )));
            }
            for j in 0..self.embedding_dim {
                output[[i, j]] = self.weight.data.data[[idx, j]];
            }
        }

        Ok(output)
    }
}

impl QuantumModule for QuantumEmbedding {
    fn forward(&mut self, input: &SciRS2Array) -> Result<SciRS2Array> {
        // Interpret the (floating point) input values as embedding indices.
        let indices: Vec<usize> = input.data.iter().map(|&x| x as usize).collect();
        let output = self.get_embedding(&indices)?;
        // Gradients are tracked for the output only while in training mode.
        Ok(SciRS2Array::new(output, self.training))
    }

    fn parameters(&self) -> Vec<Parameter> {
        vec![self.weight.clone()]
    }

    fn train(&mut self, mode: bool) {
        self.training = mode;
    }

    fn training(&self) -> bool {
        self.training
    }

    fn zero_grad(&mut self) {
        self.weight.data.zero_grad();
    }

    fn name(&self) -> &str {
        "Embedding"
    }
}

// ============================================================================
// Parameter Initialization
// ============================================================================

/// Initialize a parameter in place with the specified method
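///
/// # Example
///
/// A usage sketch (not compiled as a doctest) assuming a `Parameter` created as
/// elsewhere in this module:
///
/// ```ignore
/// let data = ArrayD::zeros(IxDyn(&[64, 32]));
/// let mut param = Parameter::new(SciRS2Array::with_grad(data), "weight");
/// init_weights(&mut param, InitType::He)?;
/// ```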
pub fn init_weights(param: &mut Parameter, init_type: InitType) -> Result<()> {
    let shape = param.data.data.shape().to_vec();
    // Convention: shape[0] = fan_out, shape[1] = fan_in (fall back to shape[0] for 1D).
    let fan_in = if shape.len() >= 2 { shape[1] } else { shape[0] };
    let fan_out = shape[0];

    match init_type {
        InitType::Xavier => {
            // Xavier/Glorot uniform: U(-b, b) with b = sqrt(6 / (fan_in + fan_out)).
            let bound = (6.0 / (fan_in + fan_out) as f64).sqrt();
            for elem in param.data.data.iter_mut() {
                *elem = (fastrand::f64() * 2.0 - 1.0) * bound;
            }
        }
        InitType::He => {
            // He normal: N(0, sqrt(2 / fan_in)), sampled via Box-Muller.
            let std = (2.0 / fan_in as f64).sqrt();
            for elem in param.data.data.iter_mut() {
                // 1 - u keeps the argument of ln() strictly positive.
                let u1: f64 = 1.0 - fastrand::f64();
                let u2: f64 = fastrand::f64();
                let normal = (-2.0 * u1.ln()).sqrt() * (2.0 * std::f64::consts::PI * u2).cos();
                *elem = normal * std;
            }
        }
        InitType::Normal(mean, std) => {
            for elem in param.data.data.iter_mut() {
                let u1: f64 = 1.0 - fastrand::f64();
                let u2: f64 = fastrand::f64();
                let normal = (-2.0 * u1.ln()).sqrt() * (2.0 * std::f64::consts::PI * u2).cos();
                *elem = mean + normal * std;
            }
        }
        InitType::Uniform(low, high) => {
            for elem in param.data.data.iter_mut() {
                *elem = low + (high - low) * fastrand::f64();
            }
        }
    }
    Ok(())
}