use super::{NnResult, PaddingMode};
use crate::error::NumRs2Error;
use scirs2_core::ndarray::{
Array, Array1, Array2, Array3, Array4, ArrayView, ArrayView2, Axis, ScalarOperand,
};
use scirs2_core::numeric::Float;
use scirs2_core::simd_ops::SimdUnifiedOps;
/// 2-D max pooling over a single-channel input, using only windows that fit
/// entirely inside `x` (no padding).
///
/// Output element `[i, j]` is the maximum of the `pool_size` window whose
/// top-left corner sits at `(i * stride.0, j * stride.1)`. The output shape is
/// `((h - ph) / sh + 1, (w - pw) / sw + 1)`.
///
/// NaN elements never compare greater than the running maximum, so they are
/// effectively ignored; a window made only of NaNs yields `-inf`.
///
/// # Errors
///
/// Returns [`NumRs2Error::InvalidOperation`] if a pool dimension or stride is
/// zero, or if the pool window is larger than the input in either dimension.
pub fn max_pool2d<T>(
    x: &ArrayView2<T>,
    pool_size: (usize, usize),
    stride: (usize, usize),
) -> NnResult<Array2<T>>
where
    T: Float + SimdUnifiedOps,
{
    let (h, w) = (x.nrows(), x.ncols());
    let (ph, pw) = pool_size;
    let (sh, sw) = stride;
    if ph == 0 || pw == 0 {
        return Err(NumRs2Error::InvalidOperation(
            "Pool size must be positive".to_string(),
        ));
    }
    if sh == 0 || sw == 0 {
        return Err(NumRs2Error::InvalidOperation(
            "Stride must be positive".to_string(),
        ));
    }
    // Without this guard `h - ph` / `w - pw` underflow: a panic in debug builds
    // and a wrapped, enormous output shape in release builds.
    if ph > h || pw > w {
        return Err(NumRs2Error::InvalidOperation(format!(
            "Pool size {:?} exceeds input size {:?}",
            pool_size,
            (h, w)
        )));
    }
    let out_h = (h - ph) / sh + 1;
    let out_w = (w - pw) / sw + 1;

    // Every window lies fully inside `x` given the check above, so no per-element
    // bounds test is needed.
    let result = Array2::from_shape_fn((out_h, out_w), |(i, j)| {
        let (row0, col0) = (i * sh, j * sw);
        let mut max_val = T::neg_infinity();
        for r in row0..row0 + ph {
            for c in col0..col0 + pw {
                let val = x[[r, c]];
                if val > max_val {
                    max_val = val;
                }
            }
        }
        max_val
    });
    Ok(result)
}
/// 2-D average pooling over a single-channel input, using only windows that fit
/// entirely inside `x` (no padding).
///
/// Output element `[i, j]` is the arithmetic mean of the `pool_size` window
/// whose top-left corner sits at `(i * stride.0, j * stride.1)`. The output
/// shape is `((h - ph) / sh + 1, (w - pw) / sw + 1)`.
///
/// # Errors
///
/// Returns [`NumRs2Error::InvalidOperation`] if a pool dimension or stride is
/// zero, or if the pool window is larger than the input in either dimension.
/// Returns [`NumRs2Error::ConversionError`] if the window area cannot be
/// represented as `T`.
pub fn avg_pool2d<T>(
    x: &ArrayView2<T>,
    pool_size: (usize, usize),
    stride: (usize, usize),
) -> NnResult<Array2<T>>
where
    T: Float + SimdUnifiedOps,
{
    let (h, w) = (x.nrows(), x.ncols());
    let (ph, pw) = pool_size;
    let (sh, sw) = stride;
    if ph == 0 || pw == 0 {
        return Err(NumRs2Error::InvalidOperation(
            "Pool size must be positive".to_string(),
        ));
    }
    if sh == 0 || sw == 0 {
        return Err(NumRs2Error::InvalidOperation(
            "Stride must be positive".to_string(),
        ));
    }
    // Without this guard `h - ph` / `w - pw` underflow: a panic in debug builds
    // and a wrapped, enormous output shape in release builds.
    if ph > h || pw > w {
        return Err(NumRs2Error::InvalidOperation(format!(
            "Pool size {:?} exceeds input size {:?}",
            pool_size,
            (h, w)
        )));
    }
    let out_h = (h - ph) / sh + 1;
    let out_w = (w - pw) / sw + 1;

    // Every window is fully in bounds, so each one holds exactly `ph * pw`
    // elements; convert that divisor once instead of per window.
    let area = T::from(ph * pw)
        .ok_or_else(|| NumRs2Error::ConversionError("Failed to convert count".to_string()))?;

    let result = Array2::from_shape_fn((out_h, out_w), |(i, j)| {
        let (row0, col0) = (i * sh, j * sw);
        let mut sum = T::zero();
        for r in row0..row0 + ph {
            for c in col0..col0 + pw {
                sum = sum + x[[r, c]];
            }
        }
        sum / area
    });
    Ok(result)
}
/// Adaptive 2-D average pooling: resamples `x` to exactly `output_size` by
/// averaging a variable-sized bin of input elements for each output element.
///
/// Output row `i` averages input rows `floor(i * in_h / out_h) ..
/// ceil((i + 1) * in_h / out_h)`, and columns are handled the same way. This is
/// the conventional adaptive-pooling bin definition (as in PyTorch's
/// `AdaptiveAvgPool2d`). Bins may overlap when the sizes do not divide evenly,
/// and they are never empty, even when `output_size` exceeds the input size.
///
/// # Errors
///
/// Returns [`NumRs2Error::InvalidOperation`] if either output dimension is zero
/// or the input has no elements. Returns [`NumRs2Error::ConversionError`] if a
/// bin's element count cannot be represented as `T`.
pub fn adaptive_avg_pool2d<T>(x: &ArrayView2<T>, output_size: (usize, usize)) -> NnResult<Array2<T>>
where
    T: Float + SimdUnifiedOps,
{
    let (in_h, in_w) = (x.nrows(), x.ncols());
    let (out_h, out_w) = output_size;
    if out_h == 0 || out_w == 0 {
        return Err(NumRs2Error::InvalidOperation(
            "Output size must be positive".to_string(),
        ));
    }
    // An empty input has nothing to average; previously this silently produced zeros.
    if in_h == 0 || in_w == 0 {
        return Err(NumRs2Error::InvalidOperation(
            "Input must be non-empty".to_string(),
        ));
    }
    let mut result = Array2::zeros((out_h, out_w));
    for i in 0..out_h {
        // A ceiling end guarantees `h0 < h1 <= in_h`. A floor end would produce
        // empty bins (left at 0) whenever out_h > in_h.
        let h0 = (i * in_h) / out_h;
        let h1 = ((i + 1) * in_h).div_ceil(out_h);
        for j in 0..out_w {
            let w0 = (j * in_w) / out_w;
            let w1 = ((j + 1) * in_w).div_ceil(out_w);
            let mut sum = T::zero();
            for r in h0..h1 {
                for c in w0..w1 {
                    sum = sum + x[[r, c]];
                }
            }
            let count = T::from((h1 - h0) * (w1 - w0)).ok_or_else(|| {
                NumRs2Error::ConversionError("Failed to convert count".to_string())
            })?;
            result[[i, j]] = sum / count;
        }
    }
    Ok(result)
}
/// Global average pooling: the arithmetic mean of every element of `x`.
///
/// # Errors
///
/// Returns [`NumRs2Error::InvalidOperation`] if `x` has no elements (the mean
/// would otherwise be `0 / 0 = NaN`). Returns [`NumRs2Error::ConversionError`]
/// if the element count cannot be represented as `T`.
pub fn global_avg_pool<T>(x: &ArrayView2<T>) -> NnResult<T>
where
    T: Float + SimdUnifiedOps,
{
    if x.is_empty() {
        return Err(NumRs2Error::InvalidOperation(
            "Cannot average an empty input".to_string(),
        ));
    }
    let sum = x.sum();
    let count = T::from(x.len())
        .ok_or_else(|| NumRs2Error::ConversionError("Failed to convert size".to_string()))?;
    Ok(sum / count)
}
/// Global max pooling: the largest non-NaN element of `x`.
///
/// NaN elements are ignored. Infinite values are legitimate results, so an input
/// containing `+inf` returns `+inf` and an all-`-inf` input returns `-inf`.
///
/// # Errors
///
/// Returns [`NumRs2Error::InvalidOperation`] if `x` is empty or contains only
/// NaNs, i.e. when there is no element to take the maximum of.
pub fn global_max_pool<T>(x: &ArrayView2<T>) -> NnResult<T>
where
    T: Float + SimdUnifiedOps,
{
    // Reducing over an `Option` distinguishes "no candidate" from a genuine
    // infinite maximum. The old `is_finite` check conflated the two.
    x.iter()
        .copied()
        .filter(|&v| !v.is_nan())
        .reduce(|acc, v| if v > acc { v } else { acc })
        .ok_or_else(|| NumRs2Error::InvalidOperation("No valid maximum found".to_string()))
}
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_abs_diff_eq;
    use scirs2_core::ndarray::Array2;

    /// Absolute tolerance shared by every floating-point comparison below.
    const EPS: f64 = 1e-6;

    #[test]
    fn test_max_pool2d() {
        // 4x4 grid holding 0..16 in row-major order; each 2x2 block's max is its
        // bottom-right entry.
        let input = Array2::from_shape_fn((4, 4), |(r, c)| (r * 4 + c) as f64);
        let pooled = max_pool2d(&input.view(), (2, 2), (2, 2)).unwrap();
        assert_eq!(pooled.dim(), (2, 2));

        let expected = [[5.0, 7.0], [13.0, 15.0]];
        for (r, row) in expected.iter().enumerate() {
            for (c, &want) in row.iter().enumerate() {
                assert_abs_diff_eq!(pooled[[r, c]], want, epsilon = EPS);
            }
        }
    }

    #[test]
    fn test_avg_pool2d() {
        // Averaging a constant field must reproduce the constant.
        let ones = Array2::from_elem((4, 4), 1.0);
        let pooled = avg_pool2d(&ones.view(), (2, 2), (2, 2)).unwrap();
        assert_eq!(pooled.dim(), (2, 2));
        pooled
            .iter()
            .for_each(|&v| assert_abs_diff_eq!(v, 1.0, epsilon = EPS));
    }

    #[test]
    fn test_global_avg_pool() {
        let twos = Array2::from_elem((3, 3), 2.0);
        let mean = global_avg_pool(&twos.view()).unwrap();
        assert_abs_diff_eq!(mean, 2.0, epsilon = EPS);
    }
}