1use crate::error::{MLError, Result as MLResult};
28
/// Configuration of a 1-D convolution-style layer over a line of qubit wires.
///
/// A window of `kernel_size` consecutive wires is slid along `n_wires` wires in
/// steps of `stride`; every window position owns `n_params_per_kernel`
/// parameters.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct QConv1D {
    /// Number of wires the layer spans.
    n_wires: usize,
    /// Number of consecutive wires covered by one kernel window.
    kernel_size: usize,
    /// Distance, in wires, between the starts of neighbouring windows.
    stride: usize,
    /// Number of parameters owned by each kernel position.
    n_params_per_kernel: usize,
    /// Total parameter count (`n_params_per_kernel` per kernel position),
    /// cached at construction.
    n_parameters: usize,
    /// Human-readable description, e.g. `QConv1D(kernel=3, stride=1)`.
    name: String,
}
49
50impl QConv1D {
51 pub fn new(
67 n_wires: usize,
68 kernel_size: usize,
69 stride: usize,
70 n_params_per_kernel: usize,
71 ) -> MLResult<Self> {
72 if kernel_size > n_wires {
73 return Err(MLError::InvalidConfiguration(format!(
74 "Kernel size {} exceeds number of wires {}",
75 kernel_size, n_wires
76 )));
77 }
78
79 if stride == 0 {
80 return Err(MLError::InvalidConfiguration(
81 "Stride must be greater than 0".to_string(),
82 ));
83 }
84
85 let n_kernels = (n_wires - kernel_size) / stride + 1;
87 let n_parameters = n_kernels * n_params_per_kernel;
88
89 Ok(Self {
90 n_wires,
91 kernel_size,
92 stride,
93 n_params_per_kernel,
94 n_parameters,
95 name: format!("QConv1D(kernel={}, stride={})", kernel_size, stride),
96 })
97 }
98
99 pub fn kernel_positions(&self) -> Vec<usize> {
101 let mut positions = Vec::new();
102 let mut pos = 0;
103
104 while pos + self.kernel_size <= self.n_wires {
105 positions.push(pos);
106 pos += self.stride;
107 }
108
109 positions
110 }
111
112 pub fn kernel_qubits(&self, position: usize) -> Vec<usize> {
114 (position..position + self.kernel_size).collect()
115 }
116}
117
118impl QConv1D {
119 pub fn n_parameters(&self) -> usize {
121 self.n_parameters
122 }
123}
124
/// Configuration of a 2-D convolution-style layer over a grid of qubit wires.
///
/// Wires are laid out row-major on a `width × height` grid. A
/// `kernel_width × kernel_height` window is slid over it with steps of
/// `stride_x` columns and `stride_y` rows; every window position owns
/// `n_params_per_kernel` parameters.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct QConv2D {
    /// Number of columns in the wire grid.
    width: usize,
    /// Number of rows in the wire grid.
    height: usize,
    /// Number of columns covered by one kernel window.
    kernel_width: usize,
    /// Number of rows covered by one kernel window.
    kernel_height: usize,
    /// Horizontal distance between neighbouring window origins.
    stride_x: usize,
    /// Vertical distance between neighbouring window origins.
    stride_y: usize,
    /// Number of parameters owned by each kernel position.
    n_params_per_kernel: usize,
    /// Total parameter count (`n_params_per_kernel` per kernel position),
    /// cached at construction.
    n_parameters: usize,
    /// Human-readable description, e.g. `QConv2D(kernel=2×2, stride=(1,1))`.
    name: String,
}
150
151impl QConv2D {
152 pub fn new(
171 width: usize,
172 height: usize,
173 kernel_width: usize,
174 kernel_height: usize,
175 stride_x: usize,
176 stride_y: usize,
177 n_params_per_kernel: usize,
178 ) -> MLResult<Self> {
179 if kernel_width > width {
180 return Err(MLError::InvalidConfiguration(format!(
181 "Kernel width {} exceeds grid width {}",
182 kernel_width, width
183 )));
184 }
185
186 if kernel_height > height {
187 return Err(MLError::InvalidConfiguration(format!(
188 "Kernel height {} exceeds grid height {}",
189 kernel_height, height
190 )));
191 }
192
193 if stride_x == 0 || stride_y == 0 {
194 return Err(MLError::InvalidConfiguration(
195 "Strides must be greater than 0".to_string(),
196 ));
197 }
198
199 let n_kernels_x = (width - kernel_width) / stride_x + 1;
201 let n_kernels_y = (height - kernel_height) / stride_y + 1;
202 let n_kernels = n_kernels_x * n_kernels_y;
203 let n_parameters = n_kernels * n_params_per_kernel;
204
205 Ok(Self {
206 width,
207 height,
208 kernel_width,
209 kernel_height,
210 stride_x,
211 stride_y,
212 n_params_per_kernel,
213 n_parameters,
214 name: format!(
215 "QConv2D(kernel={}×{}, stride=({},{}))",
216 kernel_width, kernel_height, stride_x, stride_y
217 ),
218 })
219 }
220
221 pub fn kernel_positions(&self) -> Vec<(usize, usize)> {
223 let mut positions = Vec::new();
224 let mut y = 0;
225
226 while y + self.kernel_height <= self.height {
227 let mut x = 0;
228 while x + self.kernel_width <= self.width {
229 positions.push((x, y));
230 x += self.stride_x;
231 }
232 y += self.stride_y;
233 }
234
235 positions
236 }
237
238 pub fn kernel_qubits(&self, position: (usize, usize)) -> Vec<(usize, usize)> {
241 let (x0, y0) = position;
242 let mut qubits = Vec::new();
243
244 for y in y0..y0 + self.kernel_height {
245 for x in x0..x0 + self.kernel_width {
246 qubits.push((x, y));
247 }
248 }
249
250 qubits
251 }
252
253 pub fn coords_to_index(&self, x: usize, y: usize) -> usize {
255 y * self.width + x
256 }
257
258 pub fn index_to_coords(&self, index: usize) -> (usize, usize) {
260 (index % self.width, index / self.width)
261 }
262
263 pub fn n_wires(&self) -> usize {
265 self.width * self.height
266 }
267}
268
269impl QConv2D {
270 pub fn n_parameters(&self) -> usize {
272 self.n_parameters
273 }
274}
275
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_qconv1d_creation() {
        let conv = QConv1D::new(8, 3, 1, 6).unwrap();
        assert_eq!(conv.n_wires, 8);
        assert_eq!(conv.kernel_size, 3);
        assert_eq!(conv.stride, 1);
        assert_eq!(conv.n_params_per_kernel, 6);
        // (8 - 3) / 1 + 1 = 6 kernel positions, 6 parameters each.
        assert_eq!(conv.n_parameters(), 36);
    }

    #[test]
    fn test_qconv1d_kernel_positions() {
        let conv = QConv1D::new(8, 3, 2, 4).unwrap();
        let positions = conv.kernel_positions();
        assert_eq!(positions, vec![0, 2, 4]);
        // Every position owns one block of parameters.
        assert_eq!(conv.n_parameters(), positions.len() * 4);
    }

    #[test]
    fn test_qconv1d_kernel_covers_all_wires() {
        let conv = QConv1D::new(5, 5, 1, 2).unwrap();
        assert_eq!(conv.kernel_positions(), vec![0]);
        assert_eq!(conv.kernel_qubits(0), vec![0, 1, 2, 3, 4]);
        assert_eq!(conv.n_parameters(), 2);
    }

    #[test]
    fn test_qconv1d_kernel_qubits() {
        let conv = QConv1D::new(8, 3, 1, 4).unwrap();
        let qubits = conv.kernel_qubits(2);
        assert_eq!(qubits, vec![2, 3, 4]);
    }

    #[test]
    fn test_qconv1d_invalid_kernel_size() {
        let result = QConv1D::new(4, 6, 1, 4);
        assert!(result.is_err());
    }

    #[test]
    fn test_qconv1d_zero_stride() {
        let result = QConv1D::new(8, 3, 0, 4);
        assert!(result.is_err());
    }

    #[test]
    fn test_qconv2d_creation() {
        let conv = QConv2D::new(4, 4, 2, 2, 1, 1, 8).unwrap();
        assert_eq!(conv.width, 4);
        assert_eq!(conv.height, 4);
        assert_eq!(conv.kernel_width, 2);
        assert_eq!(conv.kernel_height, 2);
        // 3 × 3 kernel positions, 8 parameters each.
        assert_eq!(conv.n_parameters(), 72);
    }

    #[test]
    fn test_qconv2d_kernel_positions() {
        let conv = QConv2D::new(4, 4, 2, 2, 1, 1, 8).unwrap();
        let positions = conv.kernel_positions();
        assert_eq!(positions.len(), 9);
        // Row by row, with x varying fastest.
        assert_eq!(
            positions,
            vec![
                (0, 0),
                (1, 0),
                (2, 0),
                (0, 1),
                (1, 1),
                (2, 1),
                (0, 2),
                (1, 2),
                (2, 2),
            ]
        );
    }

    #[test]
    fn test_qconv2d_strided_positions() {
        let conv = QConv2D::new(5, 4, 2, 2, 2, 2, 3).unwrap();
        // x origins 0, 2 (4 + 2 > 5); y origins 0, 2 (4 + 2 > 4).
        assert_eq!(conv.kernel_positions(), vec![(0, 0), (2, 0), (0, 2), (2, 2)]);
        assert_eq!(conv.n_parameters(), 12);
    }

    #[test]
    fn test_qconv2d_kernel_qubits() {
        let conv = QConv2D::new(4, 4, 2, 2, 1, 1, 8).unwrap();
        let qubits = conv.kernel_qubits((1, 1));
        assert_eq!(qubits, vec![(1, 1), (2, 1), (1, 2), (2, 2)]);
    }

    #[test]
    fn test_qconv2d_coords_conversion() {
        let conv = QConv2D::new(4, 4, 2, 2, 1, 1, 8).unwrap();

        assert_eq!(conv.coords_to_index(0, 0), 0);
        assert_eq!(conv.coords_to_index(3, 0), 3);
        assert_eq!(conv.coords_to_index(0, 1), 4);
        assert_eq!(conv.coords_to_index(3, 3), 15);

        assert_eq!(conv.index_to_coords(0), (0, 0));
        assert_eq!(conv.index_to_coords(5), (1, 1));
        assert_eq!(conv.index_to_coords(15), (3, 3));
    }

    #[test]
    fn test_qconv2d_coords_round_trip_non_square() {
        let conv = QConv2D::new(3, 2, 1, 1, 1, 1, 1).unwrap();
        assert_eq!(conv.n_wires(), 6);
        assert_eq!(conv.coords_to_index(2, 1), 5);
        assert_eq!(conv.index_to_coords(4), (1, 1));
        for index in 0..conv.n_wires() {
            let (x, y) = conv.index_to_coords(index);
            assert_eq!(conv.coords_to_index(x, y), index);
        }
    }

    #[test]
    fn test_qconv2d_invalid_kernel() {
        let result = QConv2D::new(4, 4, 5, 2, 1, 1, 8);
        assert!(result.is_err());
    }

    #[test]
    fn test_qconv2d_invalid_kernel_height() {
        let result = QConv2D::new(4, 4, 2, 5, 1, 1, 8);
        assert!(result.is_err());
    }

    #[test]
    fn test_qconv2d_zero_stride() {
        let result = QConv2D::new(4, 4, 2, 2, 0, 1, 8);
        assert!(result.is_err());
    }

    #[test]
    fn test_qconv2d_zero_stride_y() {
        let result = QConv2D::new(4, 4, 2, 2, 1, 0, 8);
        assert!(result.is_err());
    }
}