// quantrs2_ml/torchquantum/pooling.rs
use crate::error::{MLError, Result as MLResult};

29#[derive(Debug, Clone)]
34pub struct QMaxPool {
35 n_wires: usize,
37 pool_size: usize,
39 stride: usize,
41 name: String,
43}
44
45impl QMaxPool {
46 pub fn new(n_wires: usize, pool_size: usize, stride: usize) -> MLResult<Self> {
60 if pool_size > n_wires {
61 return Err(MLError::InvalidConfiguration(format!(
62 "Pool size {} exceeds number of wires {}",
63 pool_size, n_wires
64 )));
65 }
66
67 if stride == 0 {
68 return Err(MLError::InvalidConfiguration(
69 "Stride must be greater than 0".to_string(),
70 ));
71 }
72
73 Ok(Self {
74 n_wires,
75 pool_size,
76 stride,
77 name: format!("QMaxPool(size={}, stride={})", pool_size, stride),
78 })
79 }
80
81 pub fn pool_positions(&self) -> Vec<usize> {
83 let mut positions = Vec::new();
84 let mut pos = 0;
85
86 while pos + self.pool_size <= self.n_wires {
87 positions.push(pos);
88 pos += self.stride;
89 }
90
91 positions
92 }
93
94 pub fn pool_qubits(&self, position: usize) -> Vec<usize> {
96 (position..position + self.pool_size).collect()
97 }
98
99 pub fn output_size(&self) -> usize {
101 self.pool_positions().len()
102 }
103}
104
105impl QMaxPool {
106 pub fn n_parameters(&self) -> usize {
108 0
109 }
110}
111
112#[derive(Debug, Clone)]
117pub struct QAvgPool {
118 n_wires: usize,
120 pool_size: usize,
122 stride: usize,
124 name: String,
126}
127
128impl QAvgPool {
129 pub fn new(n_wires: usize, pool_size: usize, stride: usize) -> MLResult<Self> {
143 if pool_size > n_wires {
144 return Err(MLError::InvalidConfiguration(format!(
145 "Pool size {} exceeds number of wires {}",
146 pool_size, n_wires
147 )));
148 }
149
150 if stride == 0 {
151 return Err(MLError::InvalidConfiguration(
152 "Stride must be greater than 0".to_string(),
153 ));
154 }
155
156 Ok(Self {
157 n_wires,
158 pool_size,
159 stride,
160 name: format!("QAvgPool(size={}, stride={})", pool_size, stride),
161 })
162 }
163
164 pub fn pool_positions(&self) -> Vec<usize> {
166 let mut positions = Vec::new();
167 let mut pos = 0;
168
169 while pos + self.pool_size <= self.n_wires {
170 positions.push(pos);
171 pos += self.stride;
172 }
173
174 positions
175 }
176
177 pub fn pool_qubits(&self, position: usize) -> Vec<usize> {
179 (position..position + self.pool_size).collect()
180 }
181
182 pub fn output_size(&self) -> usize {
184 self.pool_positions().len()
185 }
186}
187
188impl QAvgPool {
189 pub fn n_parameters(&self) -> usize {
191 0
192 }
193}
194
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_qmaxpool_creation() {
        let pool = QMaxPool::new(8, 2, 2).unwrap();
        assert_eq!(pool.n_wires, 8);
        assert_eq!(pool.pool_size, 2);
        assert_eq!(pool.stride, 2);
        assert_eq!(pool.n_parameters(), 0);
    }

    #[test]
    fn test_qmaxpool_positions() {
        let pool = QMaxPool::new(8, 2, 2).unwrap();
        let positions = pool.pool_positions();
        assert_eq!(positions, vec![0, 2, 4, 6]);
    }

    #[test]
    fn test_qmaxpool_qubits() {
        let pool = QMaxPool::new(8, 2, 2).unwrap();
        let qubits = pool.pool_qubits(4);
        assert_eq!(qubits, vec![4, 5]);
    }

    #[test]
    fn test_qmaxpool_output_size() {
        let pool = QMaxPool::new(8, 2, 2).unwrap();
        assert_eq!(pool.output_size(), 4);
    }

    #[test]
    fn test_qmaxpool_invalid_pool_size() {
        let result = QMaxPool::new(4, 6, 2);
        assert!(result.is_err());
    }

    #[test]
    fn test_qmaxpool_zero_stride() {
        assert!(QMaxPool::new(4, 2, 0).is_err());
    }

    /// Stride smaller than the window produces overlapping windows.
    #[test]
    fn test_qmaxpool_overlapping_windows() {
        let pool = QMaxPool::new(5, 3, 1).unwrap();
        assert_eq!(pool.pool_positions(), vec![0, 1, 2]);
        assert_eq!(pool.output_size(), 3);
    }

    /// Wires that cannot fill a whole window at the end are not pooled.
    #[test]
    fn test_qmaxpool_ragged_tail_dropped() {
        let pool = QMaxPool::new(7, 2, 3).unwrap();
        assert_eq!(pool.pool_positions(), vec![0, 3]);
        assert_eq!(pool.output_size(), 2);
    }

    /// A window as wide as the register yields exactly one pool.
    #[test]
    fn test_qmaxpool_full_width_window() {
        let pool = QMaxPool::new(4, 4, 1).unwrap();
        assert_eq!(pool.pool_positions(), vec![0]);
        assert_eq!(pool.pool_qubits(0), vec![0, 1, 2, 3]);
        assert_eq!(pool.output_size(), 1);
    }

    #[test]
    fn test_qavgpool_creation() {
        let pool = QAvgPool::new(8, 2, 2).unwrap();
        assert_eq!(pool.n_wires, 8);
        assert_eq!(pool.pool_size, 2);
        assert_eq!(pool.stride, 2);
        assert_eq!(pool.n_parameters(), 0);
    }

    #[test]
    fn test_qavgpool_positions() {
        let pool = QAvgPool::new(8, 2, 2).unwrap();
        let positions = pool.pool_positions();
        assert_eq!(positions, vec![0, 2, 4, 6]);
    }

    #[test]
    fn test_qavgpool_qubits() {
        let pool = QAvgPool::new(8, 2, 2).unwrap();
        assert_eq!(pool.pool_qubits(6), vec![6, 7]);
    }

    #[test]
    fn test_qavgpool_output_size() {
        let pool = QAvgPool::new(8, 2, 2).unwrap();
        assert_eq!(pool.output_size(), 4);
    }

    #[test]
    fn test_qavgpool_invalid_pool_size() {
        assert!(QAvgPool::new(4, 6, 2).is_err());
    }

    #[test]
    fn test_qavgpool_zero_stride() {
        assert!(QAvgPool::new(4, 2, 0).is_err());
    }

    /// Stride smaller than the window produces overlapping windows.
    #[test]
    fn test_qavgpool_overlapping_windows() {
        let pool = QAvgPool::new(6, 2, 1).unwrap();
        assert_eq!(pool.pool_positions(), vec![0, 1, 2, 3, 4]);
        assert_eq!(pool.output_size(), 5);
    }
}