// quantrs2_ml/pytorch_api/conv.rs

//! Convolutional layers (`Conv1d`, `Conv3d`) for the PyTorch-like API.
2
3use super::{Parameter, QuantumModule};
4use crate::error::{MLError, Result};
5use crate::scirs2_integration::SciRS2Array;
6use scirs2_core::ndarray::{ArrayD, IxDyn};
7
8/// 1D Convolution layer
9pub struct QuantumConv1d {
10    weights: Parameter,
11    bias: Option<Parameter>,
12    in_channels: usize,
13    out_channels: usize,
14    kernel_size: usize,
15    stride: usize,
16    padding: usize,
17    dilation: usize,
18    training: bool,
19}
20
21impl QuantumConv1d {
22    /// Create new Conv1d
23    pub fn new(in_channels: usize, out_channels: usize, kernel_size: usize) -> Result<Self> {
24        let weight_data =
25            ArrayD::from_shape_fn(IxDyn(&[out_channels, in_channels, kernel_size]), |_| {
26                fastrand::f64() * 0.1 - 0.05
27            });
28
29        Ok(Self {
30            weights: Parameter::new(SciRS2Array::with_grad(weight_data), "weight"),
31            bias: None,
32            in_channels,
33            out_channels,
34            kernel_size,
35            stride: 1,
36            padding: 0,
37            dilation: 1,
38            training: true,
39        })
40    }
41
42    /// Set stride
43    pub fn stride(mut self, stride: usize) -> Self {
44        self.stride = stride;
45        self
46    }
47
48    /// Set padding
49    pub fn padding(mut self, padding: usize) -> Self {
50        self.padding = padding;
51        self
52    }
53
54    /// Add bias
55    pub fn with_bias(mut self) -> Self {
56        let bias_data = ArrayD::zeros(IxDyn(&[self.out_channels]));
57        self.bias = Some(Parameter::new(SciRS2Array::with_grad(bias_data), "bias"));
58        self
59    }
60}
61
62impl QuantumModule for QuantumConv1d {
63    fn forward(&mut self, input: &SciRS2Array) -> Result<SciRS2Array> {
64        let shape = input.data.shape();
65        if shape.len() != 3 {
66            return Err(MLError::InvalidConfiguration(
67                "Conv1d expects 3D input (batch, channels, length)".to_string(),
68            ));
69        }
70
71        let (batch, _, length) = (shape[0], shape[1], shape[2]);
72        let out_length = (length + 2 * self.padding - self.dilation * (self.kernel_size - 1) - 1)
73            / self.stride
74            + 1;
75
76        let mut output = ArrayD::zeros(IxDyn(&[batch, self.out_channels, out_length]));
77
78        for b in 0..batch {
79            for oc in 0..self.out_channels {
80                for ol in 0..out_length {
81                    let mut sum = 0.0;
82                    for ic in 0..self.in_channels {
83                        for k in 0..self.kernel_size {
84                            let il = ol * self.stride + k * self.dilation;
85                            if il < length + self.padding && il >= self.padding {
86                                let input_idx = il - self.padding;
87                                if input_idx < length {
88                                    sum += input.data[[b, ic, input_idx]]
89                                        * self.weights.data.data[[oc, ic, k]];
90                                }
91                            }
92                        }
93                    }
94                    output[[b, oc, ol]] = sum;
95                }
96            }
97        }
98
99        if let Some(ref bias) = self.bias {
100            for b in 0..batch {
101                for oc in 0..self.out_channels {
102                    for ol in 0..out_length {
103                        output[[b, oc, ol]] += bias.data.data[[oc]];
104                    }
105                }
106            }
107        }
108
109        Ok(SciRS2Array::new(output, input.requires_grad))
110    }
111
112    fn parameters(&self) -> Vec<Parameter> {
113        let mut params = vec![self.weights.clone()];
114        if let Some(ref bias) = self.bias {
115            params.push(bias.clone());
116        }
117        params
118    }
119
120    fn train(&mut self, mode: bool) {
121        self.training = mode;
122    }
123
124    fn training(&self) -> bool {
125        self.training
126    }
127
128    fn zero_grad(&mut self) {
129        self.weights.data.zero_grad();
130        if let Some(ref mut bias) = self.bias {
131            bias.data.zero_grad();
132        }
133    }
134
135    fn name(&self) -> &str {
136        "Conv1d"
137    }
138}
139
140/// 3D Convolution layer
141pub struct QuantumConv3d {
142    weights: Parameter,
143    bias: Option<Parameter>,
144    in_channels: usize,
145    out_channels: usize,
146    kernel_size: (usize, usize, usize),
147    stride: (usize, usize, usize),
148    padding: (usize, usize, usize),
149    training: bool,
150}
151
152impl QuantumConv3d {
153    /// Create new Conv3d
154    pub fn new(
155        in_channels: usize,
156        out_channels: usize,
157        kernel_size: (usize, usize, usize),
158    ) -> Result<Self> {
159        let weight_data = ArrayD::from_shape_fn(
160            IxDyn(&[
161                out_channels,
162                in_channels,
163                kernel_size.0,
164                kernel_size.1,
165                kernel_size.2,
166            ]),
167            |_| fastrand::f64() * 0.1 - 0.05,
168        );
169
170        Ok(Self {
171            weights: Parameter::new(SciRS2Array::with_grad(weight_data), "weight"),
172            bias: None,
173            in_channels,
174            out_channels,
175            kernel_size,
176            stride: (1, 1, 1),
177            padding: (0, 0, 0),
178            training: true,
179        })
180    }
181
182    /// Set stride
183    pub fn stride(mut self, stride: (usize, usize, usize)) -> Self {
184        self.stride = stride;
185        self
186    }
187
188    /// Set padding
189    pub fn padding(mut self, padding: (usize, usize, usize)) -> Self {
190        self.padding = padding;
191        self
192    }
193
194    /// Add bias
195    pub fn with_bias(mut self) -> Self {
196        let bias_data = ArrayD::zeros(IxDyn(&[self.out_channels]));
197        self.bias = Some(Parameter::new(SciRS2Array::with_grad(bias_data), "bias"));
198        self
199    }
200}
201
202impl QuantumModule for QuantumConv3d {
203    fn forward(&mut self, input: &SciRS2Array) -> Result<SciRS2Array> {
204        let shape = input.data.shape();
205        if shape.len() != 5 {
206            return Err(MLError::InvalidConfiguration(
207                "Conv3d expects 5D input".to_string(),
208            ));
209        }
210
211        let (batch, _, depth, height, width) = (shape[0], shape[1], shape[2], shape[3], shape[4]);
212        let out_d = (depth + 2 * self.padding.0 - self.kernel_size.0) / self.stride.0 + 1;
213        let out_h = (height + 2 * self.padding.1 - self.kernel_size.1) / self.stride.1 + 1;
214        let out_w = (width + 2 * self.padding.2 - self.kernel_size.2) / self.stride.2 + 1;
215
216        let mut output = ArrayD::zeros(IxDyn(&[batch, self.out_channels, out_d, out_h, out_w]));
217
218        for b in 0..batch {
219            for oc in 0..self.out_channels {
220                for od in 0..out_d {
221                    for oh in 0..out_h {
222                        for ow in 0..out_w {
223                            let mut sum = 0.0;
224                            for ic in 0..self.in_channels {
225                                for kd in 0..self.kernel_size.0 {
226                                    for kh in 0..self.kernel_size.1 {
227                                        for kw in 0..self.kernel_size.2 {
228                                            let id = od * self.stride.0 + kd;
229                                            let ih = oh * self.stride.1 + kh;
230                                            let iw = ow * self.stride.2 + kw;
231                                            if id < depth && ih < height && iw < width {
232                                                sum += input.data[[b, ic, id, ih, iw]]
233                                                    * self.weights.data.data[[oc, ic, kd, kh, kw]];
234                                            }
235                                        }
236                                    }
237                                }
238                            }
239                            output[[b, oc, od, oh, ow]] = sum;
240                        }
241                    }
242                }
243            }
244        }
245
246        if let Some(ref bias) = self.bias {
247            for b in 0..batch {
248                for oc in 0..self.out_channels {
249                    for od in 0..out_d {
250                        for oh in 0..out_h {
251                            for ow in 0..out_w {
252                                output[[b, oc, od, oh, ow]] += bias.data.data[[oc]];
253                            }
254                        }
255                    }
256                }
257            }
258        }
259
260        Ok(SciRS2Array::new(output, input.requires_grad))
261    }
262
263    fn parameters(&self) -> Vec<Parameter> {
264        let mut params = vec![self.weights.clone()];
265        if let Some(ref bias) = self.bias {
266            params.push(bias.clone());
267        }
268        params
269    }
270
271    fn train(&mut self, mode: bool) {
272        self.training = mode;
273    }
274
275    fn training(&self) -> bool {
276        self.training
277    }
278
279    fn zero_grad(&mut self) {
280        self.weights.data.zero_grad();
281        if let Some(ref mut bias) = self.bias {
282            bias.data.zero_grad();
283        }
284    }
285
286    fn name(&self) -> &str {
287        "Conv3d"
288    }
289}