use crate::error::{MLError, Result as MLResult};
#[derive(Debug, Clone)]
pub struct QConv1D {
n_wires: usize,
kernel_size: usize,
stride: usize,
n_params_per_kernel: usize,
n_parameters: usize,
name: String,
}
impl QConv1D {
pub fn new(
n_wires: usize,
kernel_size: usize,
stride: usize,
n_params_per_kernel: usize,
) -> MLResult<Self> {
if kernel_size > n_wires {
return Err(MLError::InvalidConfiguration(format!(
"Kernel size {} exceeds number of wires {}",
kernel_size, n_wires
)));
}
if stride == 0 {
return Err(MLError::InvalidConfiguration(
"Stride must be greater than 0".to_string(),
));
}
let n_kernels = (n_wires - kernel_size) / stride + 1;
let n_parameters = n_kernels * n_params_per_kernel;
Ok(Self {
n_wires,
kernel_size,
stride,
n_params_per_kernel,
n_parameters,
name: format!("QConv1D(kernel={}, stride={})", kernel_size, stride),
})
}
pub fn kernel_positions(&self) -> Vec<usize> {
let mut positions = Vec::new();
let mut pos = 0;
while pos + self.kernel_size <= self.n_wires {
positions.push(pos);
pos += self.stride;
}
positions
}
pub fn kernel_qubits(&self, position: usize) -> Vec<usize> {
(position..position + self.kernel_size).collect()
}
}
impl QConv1D {
pub fn n_parameters(&self) -> usize {
self.n_parameters
}
}
#[derive(Debug, Clone)]
pub struct QConv2D {
width: usize,
height: usize,
kernel_width: usize,
kernel_height: usize,
stride_x: usize,
stride_y: usize,
n_params_per_kernel: usize,
n_parameters: usize,
name: String,
}
impl QConv2D {
pub fn new(
width: usize,
height: usize,
kernel_width: usize,
kernel_height: usize,
stride_x: usize,
stride_y: usize,
n_params_per_kernel: usize,
) -> MLResult<Self> {
if kernel_width > width {
return Err(MLError::InvalidConfiguration(format!(
"Kernel width {} exceeds grid width {}",
kernel_width, width
)));
}
if kernel_height > height {
return Err(MLError::InvalidConfiguration(format!(
"Kernel height {} exceeds grid height {}",
kernel_height, height
)));
}
if stride_x == 0 || stride_y == 0 {
return Err(MLError::InvalidConfiguration(
"Strides must be greater than 0".to_string(),
));
}
let n_kernels_x = (width - kernel_width) / stride_x + 1;
let n_kernels_y = (height - kernel_height) / stride_y + 1;
let n_kernels = n_kernels_x * n_kernels_y;
let n_parameters = n_kernels * n_params_per_kernel;
Ok(Self {
width,
height,
kernel_width,
kernel_height,
stride_x,
stride_y,
n_params_per_kernel,
n_parameters,
name: format!(
"QConv2D(kernel={}×{}, stride=({},{}))",
kernel_width, kernel_height, stride_x, stride_y
),
})
}
pub fn kernel_positions(&self) -> Vec<(usize, usize)> {
let mut positions = Vec::new();
let mut y = 0;
while y + self.kernel_height <= self.height {
let mut x = 0;
while x + self.kernel_width <= self.width {
positions.push((x, y));
x += self.stride_x;
}
y += self.stride_y;
}
positions
}
pub fn kernel_qubits(&self, position: (usize, usize)) -> Vec<(usize, usize)> {
let (x0, y0) = position;
let mut qubits = Vec::new();
for y in y0..y0 + self.kernel_height {
for x in x0..x0 + self.kernel_width {
qubits.push((x, y));
}
}
qubits
}
pub fn coords_to_index(&self, x: usize, y: usize) -> usize {
y * self.width + x
}
pub fn index_to_coords(&self, index: usize) -> (usize, usize) {
(index % self.width, index / self.width)
}
pub fn n_wires(&self) -> usize {
self.width * self.height
}
}
impl QConv2D {
pub fn n_parameters(&self) -> usize {
self.n_parameters
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_qconv1d_creation() {
let conv = QConv1D::new(8, 3, 1, 6).unwrap();
assert_eq!(conv.n_wires, 8);
assert_eq!(conv.kernel_size, 3);
assert_eq!(conv.stride, 1);
assert_eq!(conv.n_parameters(), 36); }
#[test]
fn test_qconv1d_kernel_positions() {
let conv = QConv1D::new(8, 3, 2, 4).unwrap();
let positions = conv.kernel_positions();
assert_eq!(positions, vec![0, 2, 4]);
}
#[test]
fn test_qconv1d_kernel_qubits() {
let conv = QConv1D::new(8, 3, 1, 4).unwrap();
let qubits = conv.kernel_qubits(2);
assert_eq!(qubits, vec![2, 3, 4]);
}
#[test]
fn test_qconv1d_invalid_kernel_size() {
let result = QConv1D::new(4, 6, 1, 4);
assert!(result.is_err());
}
#[test]
fn test_qconv1d_zero_stride() {
let result = QConv1D::new(8, 3, 0, 4);
assert!(result.is_err());
}
#[test]
fn test_qconv2d_creation() {
let conv = QConv2D::new(4, 4, 2, 2, 1, 1, 8).unwrap();
assert_eq!(conv.width, 4);
assert_eq!(conv.height, 4);
assert_eq!(conv.kernel_width, 2);
assert_eq!(conv.kernel_height, 2);
assert_eq!(conv.n_parameters(), 72); }
#[test]
fn test_qconv2d_kernel_positions() {
let conv = QConv2D::new(4, 4, 2, 2, 1, 1, 8).unwrap();
let positions = conv.kernel_positions();
assert_eq!(positions.len(), 9); assert_eq!(positions[0], (0, 0));
assert_eq!(positions[4], (1, 1));
assert_eq!(positions[8], (2, 2));
}
#[test]
fn test_qconv2d_kernel_qubits() {
let conv = QConv2D::new(4, 4, 2, 2, 1, 1, 8).unwrap();
let qubits = conv.kernel_qubits((1, 1));
assert_eq!(qubits, vec![(1, 1), (2, 1), (1, 2), (2, 2)]);
}
#[test]
fn test_qconv2d_coords_conversion() {
let conv = QConv2D::new(4, 4, 2, 2, 1, 1, 8).unwrap();
assert_eq!(conv.coords_to_index(0, 0), 0);
assert_eq!(conv.coords_to_index(3, 0), 3);
assert_eq!(conv.coords_to_index(0, 1), 4);
assert_eq!(conv.coords_to_index(3, 3), 15);
assert_eq!(conv.index_to_coords(0), (0, 0));
assert_eq!(conv.index_to_coords(5), (1, 1));
assert_eq!(conv.index_to_coords(15), (3, 3));
}
#[test]
fn test_qconv2d_invalid_kernel() {
let result = QConv2D::new(4, 4, 5, 2, 1, 1, 8);
assert!(result.is_err());
}
#[test]
fn test_qconv2d_zero_stride() {
let result = QConv2D::new(4, 4, 2, 2, 0, 1, 8);
assert!(result.is_err());
}
}