quantrs2_ml/torchquantum/
pooling.rs

//! Quantum pooling layers for dimensionality reduction
//!
//! This module provides quantum pooling operations that reduce the number of qubits
//! while preserving important quantum features, analogous to pooling in classical CNNs.
//!
//! # Layers
//!
//! - **QMaxPool**: Keeps the qubit with the highest measurement probability in each pooling window
//! - **QAvgPool**: Averages measurements over each pooling window
//!
//! # Features
//!
//! - Configurable pool size and stride
//! - Non-trainable dimensionality reduction (no parameters)
//! - Compatible with the TorchQuantum training framework
//!
//! # Example
//!
//! ```ignore
//! use quantrs2_ml::torchquantum::pooling::QMaxPool;
//!
//! // Create pooling layer: 8 qubits, pool size 2, stride 2
//! let pool = QMaxPool::new(8, 2, 2)?;
//! println!("Output qubits: {}", pool.output_size());
//! ```
26
27use crate::error::{MLError, Result as MLResult};
28
29/// Quantum pooling layer using maximum measurement probability
30///
31/// Reduces the number of qubits by measuring subsets and keeping
32/// the qubit with the highest measurement probability in each pool.
33#[derive(Debug, Clone)]
34pub struct QMaxPool {
35    /// Number of input wires
36    n_wires: usize,
37    /// Pool size (number of qubits per pool)
38    pool_size: usize,
39    /// Stride (step size for pooling windows)
40    stride: usize,
41    /// Layer name for debugging
42    name: String,
43}
44
45impl QMaxPool {
46    /// Create a new quantum max pooling layer
47    ///
48    /// # Arguments
49    /// * `n_wires` - Number of input qubits
50    /// * `pool_size` - Size of each pooling window
51    /// * `stride` - Step size for pooling windows
52    ///
53    /// # Example
54    /// ```ignore
55    /// // 8-qubit input, pool size 2, stride 2
56    /// let pool = QMaxPool::new(8, 2, 2)?;
57    /// // Will pool: (0,1), (2,3), (4,5), (6,7) -> 4 output qubits
58    /// ```
59    pub fn new(n_wires: usize, pool_size: usize, stride: usize) -> MLResult<Self> {
60        if pool_size > n_wires {
61            return Err(MLError::InvalidConfiguration(format!(
62                "Pool size {} exceeds number of wires {}",
63                pool_size, n_wires
64            )));
65        }
66
67        if stride == 0 {
68            return Err(MLError::InvalidConfiguration(
69                "Stride must be greater than 0".to_string(),
70            ));
71        }
72
73        Ok(Self {
74            n_wires,
75            pool_size,
76            stride,
77            name: format!("QMaxPool(size={}, stride={})", pool_size, stride),
78        })
79    }
80
81    /// Get the positions where pooling windows start
82    pub fn pool_positions(&self) -> Vec<usize> {
83        let mut positions = Vec::new();
84        let mut pos = 0;
85
86        while pos + self.pool_size <= self.n_wires {
87            positions.push(pos);
88            pos += self.stride;
89        }
90
91        positions
92    }
93
94    /// Get the qubit indices for a specific pool
95    pub fn pool_qubits(&self, position: usize) -> Vec<usize> {
96        (position..position + self.pool_size).collect()
97    }
98
99    /// Calculate the number of output qubits after pooling
100    pub fn output_size(&self) -> usize {
101        self.pool_positions().len()
102    }
103}
104
105impl QMaxPool {
106    /// Get the total number of trainable parameters (always 0 for pooling)
107    pub fn n_parameters(&self) -> usize {
108        0
109    }
110}
111
112/// Quantum average pooling layer
113///
114/// Reduces the number of qubits by applying averaging operations
115/// over pools of qubits, typically using measurement statistics.
116#[derive(Debug, Clone)]
117pub struct QAvgPool {
118    /// Number of input wires
119    n_wires: usize,
120    /// Pool size (number of qubits per pool)
121    pool_size: usize,
122    /// Stride (step size for pooling windows)
123    stride: usize,
124    /// Layer name for debugging
125    name: String,
126}
127
128impl QAvgPool {
129    /// Create a new quantum average pooling layer
130    ///
131    /// # Arguments
132    /// * `n_wires` - Number of input qubits
133    /// * `pool_size` - Size of each pooling window
134    /// * `stride` - Step size for pooling windows
135    ///
136    /// # Example
137    /// ```ignore
138    /// // 8-qubit input, pool size 2, stride 2
139    /// let pool = QAvgPool::new(8, 2, 2)?;
140    /// // Will average: (0,1), (2,3), (4,5), (6,7) -> 4 output qubits
141    /// ```
142    pub fn new(n_wires: usize, pool_size: usize, stride: usize) -> MLResult<Self> {
143        if pool_size > n_wires {
144            return Err(MLError::InvalidConfiguration(format!(
145                "Pool size {} exceeds number of wires {}",
146                pool_size, n_wires
147            )));
148        }
149
150        if stride == 0 {
151            return Err(MLError::InvalidConfiguration(
152                "Stride must be greater than 0".to_string(),
153            ));
154        }
155
156        Ok(Self {
157            n_wires,
158            pool_size,
159            stride,
160            name: format!("QAvgPool(size={}, stride={})", pool_size, stride),
161        })
162    }
163
164    /// Get the positions where pooling windows start
165    pub fn pool_positions(&self) -> Vec<usize> {
166        let mut positions = Vec::new();
167        let mut pos = 0;
168
169        while pos + self.pool_size <= self.n_wires {
170            positions.push(pos);
171            pos += self.stride;
172        }
173
174        positions
175    }
176
177    /// Get the qubit indices for a specific pool
178    pub fn pool_qubits(&self, position: usize) -> Vec<usize> {
179        (position..position + self.pool_size).collect()
180    }
181
182    /// Calculate the number of output qubits after pooling
183    pub fn output_size(&self) -> usize {
184        self.pool_positions().len()
185    }
186}
187
188impl QAvgPool {
189    /// Get the total number of trainable parameters (always 0 for pooling)
190    pub fn n_parameters(&self) -> usize {
191        0
192    }
193}
194
#[cfg(test)]
mod tests {
    //! Unit tests for the pooling-window geometry of `QMaxPool` and `QAvgPool`.

    use super::*;

    #[test]
    fn test_qmaxpool_creation() {
        let pool = QMaxPool::new(8, 2, 2).unwrap();
        assert_eq!(pool.n_wires, 8);
        assert_eq!(pool.pool_size, 2);
        assert_eq!(pool.stride, 2);
        assert_eq!(pool.n_parameters(), 0);
    }

    #[test]
    fn test_qmaxpool_positions() {
        let pool = QMaxPool::new(8, 2, 2).unwrap();
        let positions = pool.pool_positions();
        assert_eq!(positions, vec![0, 2, 4, 6]);
    }

    #[test]
    fn test_qmaxpool_qubits() {
        let pool = QMaxPool::new(8, 2, 2).unwrap();
        let qubits = pool.pool_qubits(4);
        assert_eq!(qubits, vec![4, 5]);
    }

    #[test]
    fn test_qmaxpool_output_size() {
        let pool = QMaxPool::new(8, 2, 2).unwrap();
        assert_eq!(pool.output_size(), 4);
    }

    #[test]
    fn test_qmaxpool_invalid_pool_size() {
        let result = QMaxPool::new(4, 6, 2);
        assert!(result.is_err());
    }

    #[test]
    fn test_qmaxpool_zero_stride() {
        assert!(QMaxPool::new(8, 2, 0).is_err());
    }

    #[test]
    fn test_qavgpool_creation() {
        let pool = QAvgPool::new(8, 2, 2).unwrap();
        assert_eq!(pool.n_wires, 8);
        assert_eq!(pool.pool_size, 2);
        assert_eq!(pool.stride, 2);
        assert_eq!(pool.n_parameters(), 0);
    }

    #[test]
    fn test_qavgpool_positions() {
        let pool = QAvgPool::new(8, 2, 2).unwrap();
        let positions = pool.pool_positions();
        assert_eq!(positions, vec![0, 2, 4, 6]);
    }

    #[test]
    fn test_qavgpool_qubits() {
        let pool = QAvgPool::new(8, 2, 2).unwrap();
        assert_eq!(pool.pool_qubits(6), vec![6, 7]);
    }

    #[test]
    fn test_qavgpool_output_size() {
        let pool = QAvgPool::new(8, 2, 2).unwrap();
        assert_eq!(pool.output_size(), 4);
    }

    #[test]
    fn test_qavgpool_invalid_config() {
        assert!(QAvgPool::new(4, 6, 2).is_err());
        assert!(QAvgPool::new(8, 2, 0).is_err());
    }

    #[test]
    fn test_partial_and_overlapping_windows() {
        // 8 wires, window 3, stride 2: windows start at 0, 2, 4; qubit 7
        // cannot fill a whole window, so no window starts at 6.
        let max = QMaxPool::new(8, 3, 2).unwrap();
        assert_eq!(max.pool_positions(), vec![0, 2, 4]);
        assert_eq!(max.output_size(), 3);

        // Stride smaller than the window: overlapping windows.
        let avg = QAvgPool::new(8, 3, 1).unwrap();
        assert_eq!(avg.pool_positions(), vec![0, 1, 2, 3, 4, 5]);
        assert_eq!(avg.output_size(), 6);
    }

    #[test]
    fn test_window_as_wide_as_input() {
        let pool = QMaxPool::new(4, 4, 1).unwrap();
        assert_eq!(pool.pool_positions(), vec![0]);
        assert_eq!(pool.output_size(), 1);
        assert_eq!(pool.pool_qubits(0), vec![0, 1, 2, 3]);
    }

    #[test]
    fn test_geometry_invariants() {
        // For every small valid configuration, both layers agree. Each window
        // has exactly `pool_size` qubits, all in range, and `output_size`
        // matches the number of windows.
        for n_wires in 1..=9usize {
            for pool_size in 1..=n_wires {
                for stride in 1..=n_wires + 1 {
                    let max = QMaxPool::new(n_wires, pool_size, stride).unwrap();
                    let avg = QAvgPool::new(n_wires, pool_size, stride).unwrap();

                    let positions = max.pool_positions();
                    assert!(!positions.is_empty());
                    assert_eq!(positions, avg.pool_positions());
                    assert_eq!(max.output_size(), positions.len());
                    assert_eq!(avg.output_size(), positions.len());

                    for &start in &positions {
                        let qubits = max.pool_qubits(start);
                        assert_eq!(qubits.len(), pool_size);
                        assert!(qubits.iter().all(|&q| q < n_wires));
                    }
                }
            }
        }
    }
}