use crate::TruenoError;
use super::super::super::Matrix;
impl Matrix<f32> {
    /// Applies 2D max pooling and returns the pooled matrix.
    ///
    /// A window of `kernel = (height, width)` slides over the matrix in steps
    /// of `stride = (rows, cols)`. Each output element is the largest value
    /// inside the corresponding window. No padding is applied, so the output
    /// has `(rows - kh) / sh + 1` rows and `(cols - kw) / sw + 1` columns
    /// (integer division); trailing rows/columns that do not fill a whole
    /// window are ignored.
    ///
    /// Values are compared with [`f32::max`], which returns the other operand
    /// when one of them is NaN. NaN entries are therefore skipped, and a
    /// window consisting only of NaN yields `f32::NEG_INFINITY`.
    ///
    /// # Errors
    ///
    /// Returns [`TruenoError::InvalidInput`] if any kernel or stride dimension
    /// is zero, or if the kernel exceeds the matrix in either dimension.
    pub fn max_pool2d(
        &self,
        kernel: (usize, usize),
        stride: (usize, usize),
    ) -> Result<Matrix<f32>, TruenoError> {
        let (out_rows, out_cols) = pooled_shape(self, kernel, stride)?;
        let mut pooled = Matrix::new(out_rows, out_cols);
        for out_row in 0..out_rows {
            for out_col in 0..out_cols {
                let origin = (out_row * stride.0, out_col * stride.1);
                let peak = window_values(self, origin, kernel).fold(f32::NEG_INFINITY, f32::max);
                pooled.data[out_row * out_cols + out_col] = peak;
            }
        }
        Ok(pooled)
    }

    /// Applies 2D average pooling and returns the pooled matrix.
    ///
    /// A window of `kernel = (height, width)` slides over the matrix in steps
    /// of `stride = (rows, cols)`. Each output element is the arithmetic mean
    /// of the values inside the corresponding window. No padding is applied,
    /// so the output has `(rows - kh) / sh + 1` rows and
    /// `(cols - kw) / sw + 1` columns (integer division).
    ///
    /// # Errors
    ///
    /// Returns [`TruenoError::InvalidInput`] if any kernel or stride dimension
    /// is zero, or if the kernel exceeds the matrix in either dimension.
    pub fn avg_pool2d(
        &self,
        kernel: (usize, usize),
        stride: (usize, usize),
    ) -> Result<Matrix<f32>, TruenoError> {
        let (out_rows, out_cols) = pooled_shape(self, kernel, stride)?;
        let window_len = (kernel.0 * kernel.1) as f32;
        let mut pooled = Matrix::new(out_rows, out_cols);
        for out_row in 0..out_rows {
            for out_col in 0..out_cols {
                let origin = (out_row * stride.0, out_col * stride.1);
                // Single f32 accumulator seeded with +0.0, summed in row-major order.
                let total = window_values(self, origin, kernel).fold(0.0_f32, |acc, v| acc + v);
                pooled.data[out_row * out_cols + out_col] = total / window_len;
            }
        }
        Ok(pooled)
    }
}

/// Validates pooling parameters against `input` and returns the output
/// shape as `(rows, cols)`.
fn pooled_shape(
    input: &Matrix<f32>,
    kernel: (usize, usize),
    stride: (usize, usize),
) -> Result<(usize, usize), TruenoError> {
    let (kh, kw) = kernel;
    let (sh, sw) = stride;
    if [kh, kw, sh, sw].contains(&0) {
        return Err(TruenoError::InvalidInput(
            "Kernel and stride dimensions must be positive".into(),
        ));
    }
    if kh > input.rows || kw > input.cols {
        return Err(TruenoError::InvalidInput(format!(
            "Kernel size ({}, {}) larger than input ({}, {})",
            kh, kw, input.rows, input.cols
        )));
    }
    // The subtractions cannot underflow: the kernel fits inside the input.
    Ok(((input.rows - kh) / sh + 1, (input.cols - kw) / sw + 1))
}

/// Yields the values of one pooling window in row-major order.
///
/// `origin` is the `(row, col)` of the window's top-left element and
/// `kernel` is its `(height, width)`.
fn window_values<'a>(
    input: &'a Matrix<f32>,
    origin: (usize, usize),
    kernel: (usize, usize),
) -> impl Iterator<Item = f32> + 'a {
    let (top, left) = origin;
    let (kh, kw) = kernel;
    (top..top + kh).flat_map(move |row| {
        (left..left + kw).map(move |col| input.data[row * input.cols + col])
    })
}