use super::{Parameter, QuantumModule};
use crate::error::{MLError, Result};
use crate::scirs2_integration::SciRS2Array;
use scirs2_core::ndarray::{ArrayD, IxDyn};

/// 1D convolution layer over `SciRS2Array` inputs of shape
/// `(batch, channels, length)`.
pub struct QuantumConv1d {
    weights: Parameter,
    bias: Option<Parameter>,
    in_channels: usize,
    out_channels: usize,
    kernel_size: usize,
    stride: usize,
    padding: usize,
    dilation: usize,
    training: bool,
}

impl QuantumConv1d {
    pub fn new(in_channels: usize, out_channels: usize, kernel_size: usize) -> Result<Self> {
        let weight_data =
            ArrayD::from_shape_fn(IxDyn(&[out_channels, in_channels, kernel_size]), |_| {
                fastrand::f64() * 0.1 - 0.05
            });

        Ok(Self {
            weights: Parameter::new(SciRS2Array::with_grad(weight_data), "weight"),
            bias: None,
            in_channels,
            out_channels,
            kernel_size,
            stride: 1,
            padding: 0,
            dilation: 1,
            training: true,
        })
    }

    pub fn stride(mut self, stride: usize) -> Self {
        self.stride = stride;
        self
    }

    pub fn padding(mut self, padding: usize) -> Self {
        self.padding = padding;
        self
    }

    pub fn with_bias(mut self) -> Self {
        let bias_data = ArrayD::zeros(IxDyn(&[self.out_channels]));
        self.bias = Some(Parameter::new(SciRS2Array::with_grad(bias_data), "bias"));
        self
    }
}

impl QuantumModule for QuantumConv1d {
    fn forward(&mut self, input: &SciRS2Array) -> Result<SciRS2Array> {
        let shape = input.data.shape();
        if shape.len() != 3 {
            return Err(MLError::InvalidConfiguration(
                "Conv1d expects 3D input (batch, channels, length)".to_string(),
            ));
        }

        let (batch, _, length) = (shape[0], shape[1], shape[2]);

        // Guard against unsigned underflow when the dilated kernel is wider
        // than the padded input.
        let effective_kernel = self.dilation * (self.kernel_size - 1) + 1;
        if length + 2 * self.padding < effective_kernel {
            return Err(MLError::InvalidConfiguration(
                "Conv1d input length is too small for the kernel size, dilation, and padding"
                    .to_string(),
            ));
        }
        let out_length = (length + 2 * self.padding - effective_kernel) / self.stride + 1;

        let mut output = ArrayD::zeros(IxDyn(&[batch, self.out_channels, out_length]));

        for b in 0..batch {
            for oc in 0..self.out_channels {
                for ol in 0..out_length {
                    let mut sum = 0.0;
                    for ic in 0..self.in_channels {
                        for k in 0..self.kernel_size {
                            // Position in the zero-padded input; taps that fall
                            // inside the padding contribute nothing.
                            let il = ol * self.stride + k * self.dilation;
                            if il >= self.padding && il < length + self.padding {
                                let input_idx = il - self.padding;
                                sum += input.data[[b, ic, input_idx]]
                                    * self.weights.data.data[[oc, ic, k]];
                            }
                        }
                    }
                    output[[b, oc, ol]] = sum;
                }
            }
        }

        if let Some(ref bias) = self.bias {
            for b in 0..batch {
                for oc in 0..self.out_channels {
                    for ol in 0..out_length {
                        output[[b, oc, ol]] += bias.data.data[[oc]];
                    }
                }
            }
        }

        Ok(SciRS2Array::new(output, input.requires_grad))
    }

    fn parameters(&self) -> Vec<Parameter> {
        let mut params = vec![self.weights.clone()];
        if let Some(ref bias) = self.bias {
            params.push(bias.clone());
        }
        params
    }

    fn train(&mut self, mode: bool) {
        self.training = mode;
    }

    fn training(&self) -> bool {
        self.training
    }

    fn zero_grad(&mut self) {
        self.weights.data.zero_grad();
        if let Some(ref mut bias) = self.bias {
            bias.data.zero_grad();
        }
    }

    fn name(&self) -> &str {
        "Conv1d"
    }
}
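
// Usage sketch (illustrative only; it assumes `SciRS2Array::new` and the
// public `.data` field behave as they are used in `forward` above):
//
//     let mut conv = QuantumConv1d::new(2, 4, 3)?.padding(1).with_bias();
//     let x = SciRS2Array::new(ArrayD::zeros(IxDyn(&[8, 2, 16])), false);
//     let y = conv.forward(&x)?;
//     // With padding = 1, stride = 1, dilation = 1: y has shape (8, 4, 16).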

/// 3D convolution layer over `SciRS2Array` inputs of shape
/// `(batch, channels, depth, height, width)`.
pub struct QuantumConv3d {
    weights: Parameter,
    bias: Option<Parameter>,
    in_channels: usize,
    out_channels: usize,
    kernel_size: (usize, usize, usize),
    stride: (usize, usize, usize),
    padding: (usize, usize, usize),
    training: bool,
}

impl QuantumConv3d {
    pub fn new(
        in_channels: usize,
        out_channels: usize,
        kernel_size: (usize, usize, usize),
    ) -> Result<Self> {
        let weight_data = ArrayD::from_shape_fn(
            IxDyn(&[
                out_channels,
                in_channels,
                kernel_size.0,
                kernel_size.1,
                kernel_size.2,
            ]),
            |_| fastrand::f64() * 0.1 - 0.05,
        );

        Ok(Self {
            weights: Parameter::new(SciRS2Array::with_grad(weight_data), "weight"),
            bias: None,
            in_channels,
            out_channels,
            kernel_size,
            stride: (1, 1, 1),
            padding: (0, 0, 0),
            training: true,
        })
    }

    pub fn stride(mut self, stride: (usize, usize, usize)) -> Self {
        self.stride = stride;
        self
    }

    pub fn padding(mut self, padding: (usize, usize, usize)) -> Self {
        self.padding = padding;
        self
    }

    pub fn with_bias(mut self) -> Self {
        let bias_data = ArrayD::zeros(IxDyn(&[self.out_channels]));
        self.bias = Some(Parameter::new(SciRS2Array::with_grad(bias_data), "bias"));
        self
    }
}

impl QuantumModule for QuantumConv3d {
    fn forward(&mut self, input: &SciRS2Array) -> Result<SciRS2Array> {
        let shape = input.data.shape();
        if shape.len() != 5 {
            return Err(MLError::InvalidConfiguration(
                "Conv3d expects 5D input (batch, channels, depth, height, width)".to_string(),
            ));
        }

        let (batch, _, depth, height, width) = (shape[0], shape[1], shape[2], shape[3], shape[4]);

        // Guard against unsigned underflow when a kernel dimension is larger
        // than the corresponding padded input dimension.
        if depth + 2 * self.padding.0 < self.kernel_size.0
            || height + 2 * self.padding.1 < self.kernel_size.1
            || width + 2 * self.padding.2 < self.kernel_size.2
        {
            return Err(MLError::InvalidConfiguration(
                "Conv3d input is smaller than the kernel after padding".to_string(),
            ));
        }

        let out_d = (depth + 2 * self.padding.0 - self.kernel_size.0) / self.stride.0 + 1;
        let out_h = (height + 2 * self.padding.1 - self.kernel_size.1) / self.stride.1 + 1;
        let out_w = (width + 2 * self.padding.2 - self.kernel_size.2) / self.stride.2 + 1;

        let mut output = ArrayD::zeros(IxDyn(&[batch, self.out_channels, out_d, out_h, out_w]));

        for b in 0..batch {
            for oc in 0..self.out_channels {
                for od in 0..out_d {
                    for oh in 0..out_h {
                        for ow in 0..out_w {
                            let mut sum = 0.0;
                            for ic in 0..self.in_channels {
                                for kd in 0..self.kernel_size.0 {
                                    for kh in 0..self.kernel_size.1 {
                                        for kw in 0..self.kernel_size.2 {
                                            // Position in the zero-padded input.
                                            let id = od * self.stride.0 + kd;
                                            let ih = oh * self.stride.1 + kh;
                                            let iw = ow * self.stride.2 + kw;
                                            // Skip taps that fall in the zero padding.
                                            if id < self.padding.0
                                                || ih < self.padding.1
                                                || iw < self.padding.2
                                            {
                                                continue;
                                            }
                                            let (id, ih, iw) = (
                                                id - self.padding.0,
                                                ih - self.padding.1,
                                                iw - self.padding.2,
                                            );
                                            if id < depth && ih < height && iw < width {
                                                sum += input.data[[b, ic, id, ih, iw]]
                                                    * self.weights.data.data[[oc, ic, kd, kh, kw]];
                                            }
                                        }
                                    }
                                }
                            }
                            output[[b, oc, od, oh, ow]] = sum;
                        }
                    }
                }
            }
        }

        if let Some(ref bias) = self.bias {
            for b in 0..batch {
                for oc in 0..self.out_channels {
                    for od in 0..out_d {
                        for oh in 0..out_h {
                            for ow in 0..out_w {
                                output[[b, oc, od, oh, ow]] += bias.data.data[[oc]];
                            }
                        }
                    }
                }
            }
        }

        Ok(SciRS2Array::new(output, input.requires_grad))
    }

    fn parameters(&self) -> Vec<Parameter> {
        let mut params = vec![self.weights.clone()];
        if let Some(ref bias) = self.bias {
            params.push(bias.clone());
        }
        params
    }

    fn train(&mut self, mode: bool) {
        self.training = mode;
    }

    fn training(&self) -> bool {
        self.training
    }

    fn zero_grad(&mut self) {
        self.weights.data.zero_grad();
        if let Some(ref mut bias) = self.bias {
            bias.data.zero_grad();
        }
    }

    fn name(&self) -> &str {
        "Conv3d"
    }
}
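
// Minimal shape-check sketch for the two layers above. It assumes
// `SciRS2Array::new(data, requires_grad)` and the public `.data` field behave
// exactly as they are used in the `forward` implementations; adjust the
// assertions if the crate's array API differs.
#[cfg(test)]
mod conv_shape_tests {
    use super::*;

    #[test]
    fn conv1d_output_shape() {
        // (batch, in_channels, length) = (2, 1, 10), kernel 3, padding 1:
        // out_length = (10 + 2 - 3) / 1 + 1 = 10.
        let mut conv = QuantumConv1d::new(1, 2, 3).unwrap().padding(1).with_bias();
        let x = SciRS2Array::new(ArrayD::zeros(IxDyn(&[2, 1, 10])), false);
        let y = conv.forward(&x).unwrap();
        assert_eq!(y.data.shape(), &[2, 2, 10]);
    }

    #[test]
    fn conv3d_output_shape() {
        // (batch, in_channels, d, h, w) = (1, 1, 4, 4, 4), kernel (2, 2, 2),
        // no padding: each spatial output dimension is (4 - 2) / 1 + 1 = 3.
        let mut conv = QuantumConv3d::new(1, 2, (2, 2, 2)).unwrap();
        let x = SciRS2Array::new(ArrayD::zeros(IxDyn(&[1, 1, 4, 4, 4])), false);
        let y = conv.forward(&x).unwrap();
        assert_eq!(y.data.shape(), &[1, 2, 3, 3, 3]);
    }
}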