// quantrs2_ml/torchquantum/conv.rs

//! Quantum convolutional layers for quantum machine learning
//!
//! This module provides quantum analogues of classical convolutional neural network layers,
//! enabling spatial feature extraction on quantum states through parameterized local unitaries.
//!
//! # Layers
//!
//! - **QConv1D**: 1D quantum convolution with sliding window kernels
//! - **QConv2D**: 2D quantum convolution for grid-arranged qubits
//!
//! # Features
//!
//! - Parameterized unitaries applied to local qubit neighborhoods
//! - Configurable kernel size and stride
//! - Designed for use with the TorchQuantum training framework
//!
//! # Example
//!
//! ```ignore
//! use quantrs2_ml::torchquantum::conv::QConv1D;
//!
//! // Create 1D conv layer: 8 qubits, kernel size 3, stride 1, 6 parameters per kernel
//! let conv = QConv1D::new(8, 3, 1, 6)?;
//! println!("Parameters: {}", conv.n_parameters()); // 6 windows × 6 params = 36
//! ```
26
27use crate::error::{MLError, Result as MLResult};
28
/// Quantum convolutional layer operating on 1D wire sequences
///
/// Applies a parameterized unitary to sliding windows of qubits,
/// similar to classical convolutional neural networks but operating
/// on quantum states.
///
/// Construct it with `QConv1D::new`, which validates the window geometry and
/// pre-computes the total parameter count. Two layers compare equal when
/// every configuration field (including the debug name) matches.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct QConv1D {
    /// Number of input wires
    n_wires: usize,
    /// Kernel size (number of qubits per convolution window)
    kernel_size: usize,
    /// Stride (step size for sliding window)
    stride: usize,
    /// Number of parameters per kernel application
    n_params_per_kernel: usize,
    /// Total number of trainable parameters
    /// (number of window positions × `n_params_per_kernel`)
    n_parameters: usize,
    /// Layer name for debugging
    name: String,
}
49
50impl QConv1D {
51    /// Create a new 1D quantum convolutional layer
52    ///
53    /// # Arguments
54    /// * `n_wires` - Number of input qubits
55    /// * `kernel_size` - Size of the convolutional kernel (number of qubits)
56    /// * `stride` - Step size for the sliding window
57    /// * `n_params_per_kernel` - Number of rotation parameters per kernel
58    ///
59    /// # Example
60    /// ```ignore
61    /// // 8-qubit input, 3-qubit kernel, stride 1, 6 parameters per kernel
62    /// let conv = QConv1D::new(8, 3, 1, 6)?;
63    /// // Will apply kernel at positions: (0,1,2), (1,2,3), ..., (5,6,7)
64    /// // Total: 6 positions × 6 params = 36 parameters
65    /// ```
66    pub fn new(
67        n_wires: usize,
68        kernel_size: usize,
69        stride: usize,
70        n_params_per_kernel: usize,
71    ) -> MLResult<Self> {
72        if kernel_size > n_wires {
73            return Err(MLError::InvalidConfiguration(format!(
74                "Kernel size {} exceeds number of wires {}",
75                kernel_size, n_wires
76            )));
77        }
78
79        if stride == 0 {
80            return Err(MLError::InvalidConfiguration(
81                "Stride must be greater than 0".to_string(),
82            ));
83        }
84
85        // Calculate number of kernel applications
86        let n_kernels = (n_wires - kernel_size) / stride + 1;
87        let n_parameters = n_kernels * n_params_per_kernel;
88
89        Ok(Self {
90            n_wires,
91            kernel_size,
92            stride,
93            n_params_per_kernel,
94            n_parameters,
95            name: format!("QConv1D(kernel={}, stride={})", kernel_size, stride),
96        })
97    }
98
99    /// Get the positions where kernels will be applied
100    pub fn kernel_positions(&self) -> Vec<usize> {
101        let mut positions = Vec::new();
102        let mut pos = 0;
103
104        while pos + self.kernel_size <= self.n_wires {
105            positions.push(pos);
106            pos += self.stride;
107        }
108
109        positions
110    }
111
112    /// Get the qubit indices for a specific kernel position
113    pub fn kernel_qubits(&self, position: usize) -> Vec<usize> {
114        (position..position + self.kernel_size).collect()
115    }
116}
117
118impl QConv1D {
119    /// Get the total number of trainable parameters
120    pub fn n_parameters(&self) -> usize {
121        self.n_parameters
122    }
123}
124
/// Quantum convolutional layer operating on 2D qubit grids
///
/// Extends QConv1D to operate on 2D arrangements of qubits,
/// applying kernels to rectangular patches of the qubit lattice.
///
/// Construct it with `QConv2D::new`, which validates the kernel and stride
/// geometry and pre-computes the total parameter count. Two layers compare
/// equal when every configuration field (including the debug name) matches.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct QConv2D {
    /// Grid width (number of qubits in x direction)
    width: usize,
    /// Grid height (number of qubits in y direction)
    height: usize,
    /// Kernel width
    kernel_width: usize,
    /// Kernel height
    kernel_height: usize,
    /// Stride in x direction
    stride_x: usize,
    /// Stride in y direction
    stride_y: usize,
    /// Number of parameters per kernel application
    n_params_per_kernel: usize,
    /// Total number of trainable parameters
    /// (number of patch positions × `n_params_per_kernel`)
    n_parameters: usize,
    /// Layer name for debugging
    name: String,
}
150
151impl QConv2D {
152    /// Create a new 2D quantum convolutional layer
153    ///
154    /// # Arguments
155    /// * `width` - Grid width (qubits in x direction)
156    /// * `height` - Grid height (qubits in y direction)
157    /// * `kernel_width` - Kernel width
158    /// * `kernel_height` - Kernel height
159    /// * `stride_x` - Stride in x direction
160    /// * `stride_y` - Stride in y direction
161    /// * `n_params_per_kernel` - Number of rotation parameters per kernel
162    ///
163    /// # Example
164    /// ```ignore
165    /// // 4×4 qubit grid, 2×2 kernel, stride (1,1), 8 parameters per kernel
166    /// let conv = QConv2D::new(4, 4, 2, 2, 1, 1, 8)?;
167    /// // Will apply kernel at 9 positions (3×3 grid of positions)
168    /// // Total: 9 × 8 = 72 parameters
169    /// ```
170    pub fn new(
171        width: usize,
172        height: usize,
173        kernel_width: usize,
174        kernel_height: usize,
175        stride_x: usize,
176        stride_y: usize,
177        n_params_per_kernel: usize,
178    ) -> MLResult<Self> {
179        if kernel_width > width {
180            return Err(MLError::InvalidConfiguration(format!(
181                "Kernel width {} exceeds grid width {}",
182                kernel_width, width
183            )));
184        }
185
186        if kernel_height > height {
187            return Err(MLError::InvalidConfiguration(format!(
188                "Kernel height {} exceeds grid height {}",
189                kernel_height, height
190            )));
191        }
192
193        if stride_x == 0 || stride_y == 0 {
194            return Err(MLError::InvalidConfiguration(
195                "Strides must be greater than 0".to_string(),
196            ));
197        }
198
199        // Calculate number of kernel applications
200        let n_kernels_x = (width - kernel_width) / stride_x + 1;
201        let n_kernels_y = (height - kernel_height) / stride_y + 1;
202        let n_kernels = n_kernels_x * n_kernels_y;
203        let n_parameters = n_kernels * n_params_per_kernel;
204
205        Ok(Self {
206            width,
207            height,
208            kernel_width,
209            kernel_height,
210            stride_x,
211            stride_y,
212            n_params_per_kernel,
213            n_parameters,
214            name: format!(
215                "QConv2D(kernel={}×{}, stride=({},{}))",
216                kernel_width, kernel_height, stride_x, stride_y
217            ),
218        })
219    }
220
221    /// Get the 2D positions where kernels will be applied
222    pub fn kernel_positions(&self) -> Vec<(usize, usize)> {
223        let mut positions = Vec::new();
224        let mut y = 0;
225
226        while y + self.kernel_height <= self.height {
227            let mut x = 0;
228            while x + self.kernel_width <= self.width {
229                positions.push((x, y));
230                x += self.stride_x;
231            }
232            y += self.stride_y;
233        }
234
235        positions
236    }
237
238    /// Get the qubit coordinates for a specific kernel position
239    /// Returns (x, y) coordinates in the 2D grid
240    pub fn kernel_qubits(&self, position: (usize, usize)) -> Vec<(usize, usize)> {
241        let (x0, y0) = position;
242        let mut qubits = Vec::new();
243
244        for y in y0..y0 + self.kernel_height {
245            for x in x0..x0 + self.kernel_width {
246                qubits.push((x, y));
247            }
248        }
249
250        qubits
251    }
252
253    /// Convert 2D coordinates to 1D qubit index
254    pub fn coords_to_index(&self, x: usize, y: usize) -> usize {
255        y * self.width + x
256    }
257
258    /// Convert 1D qubit index to 2D coordinates
259    pub fn index_to_coords(&self, index: usize) -> (usize, usize) {
260        (index % self.width, index / self.width)
261    }
262
263    /// Total number of qubits in the grid
264    pub fn n_wires(&self) -> usize {
265        self.width * self.height
266    }
267}
268
269impl QConv2D {
270    /// Get the total number of trainable parameters
271    pub fn n_parameters(&self) -> usize {
272        self.n_parameters
273    }
274}
275
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_qconv1d_creation() {
        let conv = QConv1D::new(8, 3, 1, 6).unwrap();
        assert_eq!(conv.n_wires, 8);
        assert_eq!(conv.kernel_size, 3);
        assert_eq!(conv.stride, 1);
        assert_eq!(conv.n_parameters(), 36); // 6 kernels × 6 params
    }

    #[test]
    fn test_qconv1d_kernel_positions() {
        let conv = QConv1D::new(8, 3, 2, 4).unwrap();
        let positions = conv.kernel_positions();
        assert_eq!(positions, vec![0, 2, 4]);
    }

    #[test]
    fn test_qconv1d_kernel_qubits() {
        let conv = QConv1D::new(8, 3, 1, 4).unwrap();
        let qubits = conv.kernel_qubits(2);
        assert_eq!(qubits, vec![2, 3, 4]);
    }

    #[test]
    fn test_qconv1d_invalid_kernel_size() {
        let result = QConv1D::new(4, 6, 1, 4);
        assert!(result.is_err());
    }

    #[test]
    fn test_qconv1d_zero_stride() {
        let result = QConv1D::new(8, 3, 0, 4);
        assert!(result.is_err());
    }

    /// The number of window positions must agree with the parameter count
    /// computed in `new`, and every window must stay inside the register.
    #[test]
    fn test_qconv1d_positions_match_parameter_count() {
        let configs: [(usize, usize, usize, usize); 5] =
            [(8, 3, 1, 6), (8, 3, 2, 4), (7, 3, 2, 1), (5, 5, 3, 2), (9, 1, 4, 3)];
        for &(n_wires, kernel, stride, params) in &configs {
            let conv = QConv1D::new(n_wires, kernel, stride, params).unwrap();
            let positions = conv.kernel_positions();
            assert_eq!(positions.len() * params, conv.n_parameters());
            for &position in &positions {
                let qubits = conv.kernel_qubits(position);
                assert_eq!(qubits.len(), kernel);
                assert!(qubits.iter().all(|&q| q < n_wires));
            }
        }
    }

    #[test]
    fn test_qconv1d_kernel_covers_all_wires() {
        let conv = QConv1D::new(4, 4, 1, 2).unwrap();
        assert_eq!(conv.kernel_positions(), vec![0]);
        assert_eq!(conv.kernel_qubits(0), vec![0, 1, 2, 3]);
        assert_eq!(conv.n_parameters(), 2);
    }

    #[test]
    fn test_qconv2d_creation() {
        let conv = QConv2D::new(4, 4, 2, 2, 1, 1, 8).unwrap();
        assert_eq!(conv.width, 4);
        assert_eq!(conv.height, 4);
        assert_eq!(conv.kernel_width, 2);
        assert_eq!(conv.kernel_height, 2);
        assert_eq!(conv.n_parameters(), 72); // 9 kernels × 8 params
    }

    #[test]
    fn test_qconv2d_kernel_positions() {
        let conv = QConv2D::new(4, 4, 2, 2, 1, 1, 8).unwrap();
        let positions = conv.kernel_positions();
        assert_eq!(positions.len(), 9); // 3×3 grid
        assert_eq!(positions[0], (0, 0));
        assert_eq!(positions[4], (1, 1));
        assert_eq!(positions[8], (2, 2));
    }

    #[test]
    fn test_qconv2d_kernel_qubits() {
        let conv = QConv2D::new(4, 4, 2, 2, 1, 1, 8).unwrap();
        let qubits = conv.kernel_qubits((1, 1));
        assert_eq!(qubits, vec![(1, 1), (2, 1), (1, 2), (2, 2)]);
    }

    #[test]
    fn test_qconv2d_coords_conversion() {
        let conv = QConv2D::new(4, 4, 2, 2, 1, 1, 8).unwrap();

        // Test forward conversion
        assert_eq!(conv.coords_to_index(0, 0), 0);
        assert_eq!(conv.coords_to_index(3, 0), 3);
        assert_eq!(conv.coords_to_index(0, 1), 4);
        assert_eq!(conv.coords_to_index(3, 3), 15);

        // Test reverse conversion
        assert_eq!(conv.index_to_coords(0), (0, 0));
        assert_eq!(conv.index_to_coords(5), (1, 1));
        assert_eq!(conv.index_to_coords(15), (3, 3));
    }

    /// Non-square grid with unequal strides: positions are row-major and the
    /// parameter count matches the number of patches.
    #[test]
    fn test_qconv2d_non_square_grid() {
        let conv = QConv2D::new(5, 3, 2, 2, 2, 1, 3).unwrap();
        let positions = conv.kernel_positions();
        assert_eq!(positions, vec![(0, 0), (2, 0), (0, 1), (2, 1)]);
        assert_eq!(positions.len() * 3, conv.n_parameters());
        assert_eq!(conv.n_wires(), 15);
        for &position in &positions {
            assert!(conv
                .kernel_qubits(position)
                .iter()
                .all(|&(x, y)| x < 5 && y < 3));
        }
    }

    #[test]
    fn test_qconv2d_index_roundtrip() {
        let conv = QConv2D::new(5, 3, 2, 2, 1, 1, 1).unwrap();
        for index in 0..conv.n_wires() {
            let (x, y) = conv.index_to_coords(index);
            assert!(x < 5 && y < 3);
            assert_eq!(conv.coords_to_index(x, y), index);
        }
    }

    #[test]
    fn test_qconv2d_invalid_kernel() {
        let result = QConv2D::new(4, 4, 5, 2, 1, 1, 8);
        assert!(result.is_err());
    }

    #[test]
    fn test_qconv2d_invalid_kernel_height() {
        let result = QConv2D::new(4, 4, 2, 5, 1, 1, 8);
        assert!(result.is_err());
    }

    #[test]
    fn test_qconv2d_zero_stride() {
        let result = QConv2D::new(4, 4, 2, 2, 0, 1, 8);
        assert!(result.is_err());
    }

    #[test]
    fn test_qconv2d_zero_stride_y() {
        let result = QConv2D::new(4, 4, 2, 2, 1, 0, 8);
        assert!(result.is_err());
    }
}