use crate::error::{MLError, Result as MLResult};
/// Window layout for a quantum max-pooling layer.
///
/// A window of `pool_size` adjacent wires is slid across a register of
/// `n_wires` wires, advancing `stride` wires between consecutive windows.
/// Only windows that fit entirely inside the register are produced. This type
/// only computes window geometry and has no trainable parameters.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct QMaxPool {
    /// Number of wires in the register the layer acts on.
    n_wires: usize,
    /// Number of adjacent wires covered by each pooling window (at least 1).
    pool_size: usize,
    /// Offset, in wires, between the starts of consecutive windows (at least 1).
    stride: usize,
    /// Display label of the form `QMaxPool(size=<pool_size>, stride=<stride>)`.
    name: String,
}

impl QMaxPool {
    /// Creates a max-pooling layout over `n_wires` wires.
    ///
    /// # Errors
    ///
    /// Returns [`MLError::InvalidConfiguration`] when `pool_size` exceeds
    /// `n_wires`, when `stride` is zero, or when `pool_size` is zero.
    pub fn new(n_wires: usize, pool_size: usize, stride: usize) -> MLResult<Self> {
        if pool_size > n_wires {
            return Err(MLError::InvalidConfiguration(format!(
                "Pool size {} exceeds number of wires {}",
                pool_size, n_wires
            )));
        }
        if stride == 0 {
            return Err(MLError::InvalidConfiguration(
                "Stride must be greater than 0".to_string(),
            ));
        }
        // An empty window pools nothing. This is checked last so the errors
        // reported for the conditions above are unchanged.
        if pool_size == 0 {
            return Err(MLError::InvalidConfiguration(
                "Pool size must be greater than 0".to_string(),
            ));
        }
        Ok(Self {
            n_wires,
            pool_size,
            stride,
            name: format!("QMaxPool(size={}, stride={})", pool_size, stride),
        })
    }

    /// Returns the display label, e.g. `"QMaxPool(size=2, stride=2)"`.
    pub fn name(&self) -> &str {
        &self.name
    }

    /// Returns the first wire of every pooling window, in ascending order.
    ///
    /// Trailing wires that cannot fill a complete window are not covered.
    pub fn pool_positions(&self) -> Vec<usize> {
        // `step_by` advances with checked arithmetic, so a huge stride simply
        // ends the sequence instead of overflowing like `pos += stride` would.
        match self.last_position() {
            Some(last) => (0..=last).step_by(self.stride).collect(),
            None => Vec::new(),
        }
    }

    /// Returns the wires covered by the window that starts at `position`.
    ///
    /// `position` is not checked against [`Self::pool_positions`]; an
    /// arbitrary start can yield wires at or beyond `n_wires`.
    pub fn pool_qubits(&self, position: usize) -> Vec<usize> {
        (position..position + self.pool_size).collect()
    }

    /// Returns the number of pooling windows (equal to
    /// `self.pool_positions().len()`), computed without allocating.
    pub fn output_size(&self) -> usize {
        self.last_position().map_or(0, |last| last / self.stride + 1)
    }

    /// Returns the number of trainable parameters, which is always zero.
    pub fn n_parameters(&self) -> usize {
        0
    }

    /// Largest valid window start, or `None` when no window fits.
    fn last_position(&self) -> Option<usize> {
        self.n_wires.checked_sub(self.pool_size)
    }
}
/// Window layout for a quantum average-pooling layer.
///
/// A window of `pool_size` adjacent wires is slid across a register of
/// `n_wires` wires, advancing `stride` wires between consecutive windows.
/// Only windows that fit entirely inside the register are produced. This type
/// only computes window geometry and has no trainable parameters.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct QAvgPool {
    /// Number of wires in the register the layer acts on.
    n_wires: usize,
    /// Number of adjacent wires covered by each pooling window (at least 1).
    pool_size: usize,
    /// Offset, in wires, between the starts of consecutive windows (at least 1).
    stride: usize,
    /// Display label of the form `QAvgPool(size=<pool_size>, stride=<stride>)`.
    name: String,
}

impl QAvgPool {
    /// Creates an average-pooling layout over `n_wires` wires.
    ///
    /// # Errors
    ///
    /// Returns [`MLError::InvalidConfiguration`] when `pool_size` exceeds
    /// `n_wires`, when `stride` is zero, or when `pool_size` is zero.
    pub fn new(n_wires: usize, pool_size: usize, stride: usize) -> MLResult<Self> {
        if pool_size > n_wires {
            return Err(MLError::InvalidConfiguration(format!(
                "Pool size {} exceeds number of wires {}",
                pool_size, n_wires
            )));
        }
        if stride == 0 {
            return Err(MLError::InvalidConfiguration(
                "Stride must be greater than 0".to_string(),
            ));
        }
        // An empty window pools nothing. This is checked last so the errors
        // reported for the conditions above are unchanged.
        if pool_size == 0 {
            return Err(MLError::InvalidConfiguration(
                "Pool size must be greater than 0".to_string(),
            ));
        }
        Ok(Self {
            n_wires,
            pool_size,
            stride,
            name: format!("QAvgPool(size={}, stride={})", pool_size, stride),
        })
    }

    /// Returns the display label, e.g. `"QAvgPool(size=2, stride=2)"`.
    pub fn name(&self) -> &str {
        &self.name
    }

    /// Returns the first wire of every pooling window, in ascending order.
    ///
    /// Trailing wires that cannot fill a complete window are not covered.
    pub fn pool_positions(&self) -> Vec<usize> {
        // `step_by` advances with checked arithmetic, so a huge stride simply
        // ends the sequence instead of overflowing like `pos += stride` would.
        match self.last_position() {
            Some(last) => (0..=last).step_by(self.stride).collect(),
            None => Vec::new(),
        }
    }

    /// Returns the wires covered by the window that starts at `position`.
    ///
    /// `position` is not checked against [`Self::pool_positions`]; an
    /// arbitrary start can yield wires at or beyond `n_wires`.
    pub fn pool_qubits(&self, position: usize) -> Vec<usize> {
        (position..position + self.pool_size).collect()
    }

    /// Returns the number of pooling windows (equal to
    /// `self.pool_positions().len()`), computed without allocating.
    pub fn output_size(&self) -> usize {
        self.last_position().map_or(0, |last| last / self.stride + 1)
    }

    /// Returns the number of trainable parameters, which is always zero.
    pub fn n_parameters(&self) -> usize {
        0
    }

    /// Largest valid window start, or `None` when no window fits.
    fn last_position(&self) -> Option<usize> {
        self.n_wires.checked_sub(self.pool_size)
    }
}
// Unit tests for the window geometry of `QMaxPool` and `QAvgPool`.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_qmaxpool_creation() {
        let pool = QMaxPool::new(8, 2, 2).unwrap();
        assert_eq!(pool.n_wires, 8);
        assert_eq!(pool.pool_size, 2);
        assert_eq!(pool.stride, 2);
        assert_eq!(pool.n_parameters(), 0);
    }

    #[test]
    fn test_qmaxpool_positions() {
        let pool = QMaxPool::new(8, 2, 2).unwrap();
        let positions = pool.pool_positions();
        assert_eq!(positions, vec![0, 2, 4, 6]);
    }

    #[test]
    fn test_qmaxpool_qubits() {
        let pool = QMaxPool::new(8, 2, 2).unwrap();
        let qubits = pool.pool_qubits(4);
        assert_eq!(qubits, vec![4, 5]);
    }

    #[test]
    fn test_qmaxpool_output_size() {
        let pool = QMaxPool::new(8, 2, 2).unwrap();
        assert_eq!(pool.output_size(), 4);
    }

    #[test]
    fn test_qmaxpool_invalid_pool_size() {
        let result = QMaxPool::new(4, 6, 2);
        assert!(result.is_err());
    }

    #[test]
    fn test_qmaxpool_zero_stride() {
        assert!(QMaxPool::new(8, 2, 0).is_err());
    }

    #[test]
    fn test_qmaxpool_name() {
        let pool = QMaxPool::new(8, 2, 2).unwrap();
        assert_eq!(pool.name, "QMaxPool(size=2, stride=2)");
    }

    #[test]
    fn test_qmaxpool_overlapping_windows() {
        // stride < pool_size: consecutive windows share wires.
        let pool = QMaxPool::new(5, 3, 1).unwrap();
        assert_eq!(pool.pool_positions(), vec![0, 1, 2]);
        assert_eq!(pool.output_size(), 3);
    }

    #[test]
    fn test_qmaxpool_partial_window_dropped() {
        // Wire 6 cannot start a full window of 2 within 7 wires.
        let pool = QMaxPool::new(7, 2, 3).unwrap();
        assert_eq!(pool.pool_positions(), vec![0, 3]);
        assert_eq!(pool.output_size(), 2);
    }

    #[test]
    fn test_qmaxpool_pool_spans_all_wires() {
        let pool = QMaxPool::new(4, 4, 1).unwrap();
        assert_eq!(pool.pool_positions(), vec![0]);
        assert_eq!(pool.pool_qubits(0), vec![0, 1, 2, 3]);
        assert_eq!(pool.output_size(), 1);
    }

    #[test]
    fn test_qavgpool_creation() {
        let pool = QAvgPool::new(8, 2, 2).unwrap();
        assert_eq!(pool.n_wires, 8);
        assert_eq!(pool.pool_size, 2);
        assert_eq!(pool.stride, 2);
        assert_eq!(pool.n_parameters(), 0);
    }

    #[test]
    fn test_qavgpool_positions() {
        let pool = QAvgPool::new(8, 2, 2).unwrap();
        let positions = pool.pool_positions();
        assert_eq!(positions, vec![0, 2, 4, 6]);
    }

    #[test]
    fn test_qavgpool_output_size() {
        let pool = QAvgPool::new(8, 2, 2).unwrap();
        assert_eq!(pool.output_size(), 4);
    }

    #[test]
    fn test_qavgpool_qubits() {
        let pool = QAvgPool::new(8, 2, 2).unwrap();
        assert_eq!(pool.pool_qubits(6), vec![6, 7]);
    }

    #[test]
    fn test_qavgpool_invalid_pool_size() {
        assert!(QAvgPool::new(4, 6, 2).is_err());
    }

    #[test]
    fn test_qavgpool_zero_stride() {
        assert!(QAvgPool::new(8, 2, 0).is_err());
    }

    #[test]
    fn test_qavgpool_name() {
        let pool = QAvgPool::new(8, 2, 2).unwrap();
        assert_eq!(pool.name, "QAvgPool(size=2, stride=2)");
    }

    #[test]
    fn test_qavgpool_overlapping_windows() {
        let pool = QAvgPool::new(5, 3, 1).unwrap();
        assert_eq!(pool.pool_positions(), vec![0, 1, 2]);
        assert_eq!(pool.output_size(), 3);
    }
}