use scirs2_core::ndarray::{ArrayD, IxDyn};
/// Errors produced by the pooling operations in this module.
#[derive(Debug, Clone)]
pub enum PoolingError {
    /// A kernel dimension was 0; every kernel extent must be positive.
    InvalidKernelSize { size: usize },
    /// A stride dimension was 0; every stride must be positive.
    InvalidStride { stride: usize },
    /// Padding must be strictly smaller than the kernel size in the same dimension.
    InvalidPadding { padding: usize, kernel_size: usize },
    /// Input tensor has fewer dimensions than batch + channel + spatial dims require.
    InsufficientDimensions { ndim: usize, required: usize },
    /// Input tensor contains no elements.
    EmptyInput,
    /// Shapes of related tensors/arguments disagree; the message carries details.
    ShapeMismatch(String),
}
impl std::fmt::Display for PoolingError {
    /// Renders a human-readable description of the error.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Build the message first, then emit it with a single write.
        let message = match self {
            Self::InvalidKernelSize { size } => {
                format!("Invalid kernel size: {size} (must be > 0)")
            }
            Self::InvalidStride { stride } => {
                format!("Invalid stride: {stride} (must be > 0)")
            }
            Self::InvalidPadding {
                padding,
                kernel_size,
            } => format!("Invalid padding: {padding} (must be < kernel_size {kernel_size})"),
            Self::InsufficientDimensions { ndim, required } => {
                format!("Insufficient dimensions: got {ndim}, need at least {required}")
            }
            Self::EmptyInput => "Empty input tensor".to_string(),
            Self::ShapeMismatch(msg) => format!("Shape mismatch: {msg}"),
        };
        f.write_str(&message)
    }
}
// Marker impl so PoolingError can be boxed/propagated as a standard error trait object.
impl std::error::Error for PoolingError {}
/// Configuration for an N-dimensional pooling operation.
#[derive(Debug, Clone)]
pub struct PoolConfig {
    // Window extent per spatial dimension; every entry must be > 0.
    pub kernel_size: Vec<usize>,
    // Step between consecutive windows; defaults to `kernel_size` (non-overlapping).
    pub stride: Vec<usize>,
    // Zero-padding added on both sides of each dimension; each entry must be < kernel size.
    pub padding: Vec<usize>,
    // When true, output sizes use ceiling division so partial windows are kept.
    pub ceil_mode: bool,
}
impl PoolConfig {
    /// Creates a config with the given kernel. Stride defaults to the kernel
    /// size (non-overlapping windows), padding to zero, `ceil_mode` to false.
    pub fn new(kernel_size: Vec<usize>) -> Self {
        Self {
            stride: kernel_size.clone(),
            padding: vec![0; kernel_size.len()],
            kernel_size,
            ceil_mode: false,
        }
    }

    /// Builder-style override of the stride.
    pub fn with_stride(mut self, stride: Vec<usize>) -> Self {
        self.stride = stride;
        self
    }

    /// Builder-style override of the padding.
    pub fn with_padding(mut self, padding: Vec<usize>) -> Self {
        self.padding = padding;
        self
    }

    /// Builder-style override of `ceil_mode`.
    pub fn with_ceil_mode(mut self, ceil: bool) -> Self {
        self.ceil_mode = ceil;
        self
    }

    /// Output extent along spatial dimension `dim` for an input of length
    /// `input_size`, using the standard formula
    /// `(input + 2*pad - kernel) / stride + 1` with floor or ceiling
    /// division depending on `ceil_mode`. Returns 0 when the padded input
    /// is smaller than the kernel.
    pub fn output_size(&self, input_size: usize, dim: usize) -> usize {
        let k = self.kernel_size.get(dim).copied().unwrap_or(1);
        let s = self.effective_stride(dim);
        // Use the shared helper instead of duplicating the padding fallback.
        let p = self.effective_padding(dim);
        let numerator = input_size + 2 * p;
        if numerator < k {
            return 0;
        }
        let diff = numerator - k;
        if self.ceil_mode {
            diff.div_ceil(s) + 1
        } else {
            diff / s + 1
        }
    }

    /// Checks all kernel sizes (> 0), strides (> 0) and paddings (< kernel
    /// size in the same dimension).
    ///
    /// # Errors
    /// Returns the first violated constraint as a `PoolingError`.
    pub fn validate(&self) -> Result<(), PoolingError> {
        for &k in &self.kernel_size {
            if k == 0 {
                return Err(PoolingError::InvalidKernelSize { size: k });
            }
        }
        for &s in &self.stride {
            if s == 0 {
                return Err(PoolingError::InvalidStride { stride: s });
            }
        }
        for (i, &p) in self.padding.iter().enumerate() {
            let k = self.kernel_size.get(i).copied().unwrap_or(1);
            if p >= k {
                return Err(PoolingError::InvalidPadding {
                    padding: p,
                    kernel_size: k,
                });
            }
        }
        Ok(())
    }

    /// Number of spatial dimensions this config pools over.
    pub fn num_spatial_dims(&self) -> usize {
        self.kernel_size.len()
    }

    /// Stride for `dim`, falling back to the kernel size (then 1) when the
    /// stride vector is shorter than the kernel vector.
    fn effective_stride(&self, dim: usize) -> usize {
        self.stride
            .get(dim)
            .copied()
            .unwrap_or_else(|| self.kernel_size.get(dim).copied().unwrap_or(1))
    }

    /// Padding for `dim`, defaulting to 0 when unspecified.
    fn effective_padding(&self, dim: usize) -> usize {
        self.padding.get(dim).copied().unwrap_or(0)
    }
}
/// Rejects empty inputs and inputs lacking batch + channel + spatial dims.
fn validate_input(input: &ArrayD<f64>, num_spatial: usize) -> Result<(), PoolingError> {
    if input.is_empty() {
        return Err(PoolingError::EmptyInput);
    }
    // Pooling expects at least two leading (batch, channel) dimensions.
    let required = num_spatial + 2;
    match input.ndim() {
        ndim if ndim < required => Err(PoolingError::InsufficientDimensions { ndim, required }),
        _ => Ok(()),
    }
}
/// Builds the full output shape: leading dims are copied from the input,
/// trailing spatial dims come from `config.output_size`.
fn compute_output_shape(
    input_shape: &[usize],
    config: &PoolConfig,
) -> Result<Vec<usize>, PoolingError> {
    let num_spatial = config.num_spatial_dims();
    let outer_len = input_shape.len() - num_spatial;
    let mut out_shape: Vec<usize> = input_shape[..outer_len].to_vec();
    out_shape.extend(
        input_shape[outer_len..]
            .iter()
            .enumerate()
            .map(|(dim, &size)| config.output_size(size, dim)),
    );
    Ok(out_shape)
}
/// Number of independent outer (batch/channel) slices; the product of all
/// non-spatial dimensions (1 when there are none).
fn num_outer_slices(shape: &[usize], num_spatial: usize) -> usize {
    let outer = &shape[..shape.len() - num_spatial];
    outer.iter().copied().product()
}
/// Decodes a flat outer-slice index into per-dimension indices over the
/// non-spatial prefix of `shape` (row-major order).
fn flat_to_outer_indices(flat: usize, shape: &[usize], num_spatial: usize) -> Vec<usize> {
    let outer_dims = shape.len() - num_spatial;
    let mut indices = vec![0usize; outer_dims];
    let mut remaining = flat;
    for d in (0..outer_dims).rev() {
        let extent = shape[d];
        indices[d] = remaining % extent;
        remaining /= extent;
    }
    indices
}
/// Reads one element of `input`, addressed by separate outer (batch/channel)
/// and spatial coordinate slices.
fn get_spatial_value(
    input: &ArrayD<f64>,
    outer_indices: &[usize],
    spatial_indices: &[usize],
    num_spatial: usize,
) -> f64 {
    let ndim = input.ndim();
    let offset = ndim - num_spatial;
    // Assemble the full index: outer coords first, spatial coords after `offset`.
    let mut full = vec![0usize; ndim];
    full[..outer_indices.len()].copy_from_slice(outer_indices);
    full[offset..offset + spatial_indices.len()].copy_from_slice(spatial_indices);
    input[IxDyn(&full)]
}
/// Invokes `callback` once per output position, passing the output position
/// and the list of valid (un-padded) input coordinates covered by the window.
fn for_each_window<F>(
    input_spatial_shape: &[usize],
    config: &PoolConfig,
    output_spatial_shape: &[usize],
    mut callback: F,
) where
    F: FnMut(&[usize], Vec<(f64, Vec<usize>)>),
{
    // An empty output (some dimension is 0, possible when the padded input is
    // smaller than the kernel) has no windows at all. Without this guard the
    // loop below would still fire the callback once at the all-zeros position,
    // letting callers index a zero-sized output and panic.
    if output_spatial_shape.iter().any(|&d| d == 0) {
        return;
    }
    let num_spatial = config.num_spatial_dims();
    let mut out_pos = vec![0usize; num_spatial];
    loop {
        let mut window_values: Vec<(f64, Vec<usize>)> = Vec::new();
        collect_window_values(
            input_spatial_shape,
            config,
            &out_pos,
            num_spatial,
            0,
            &mut vec![0usize; num_spatial],
            &mut window_values,
        );
        callback(&out_pos, window_values);
        // Odometer-advance; stops after the last output position.
        if !advance_indices(&mut out_pos, output_spatial_shape) {
            break;
        }
    }
}
/// Recursively enumerates every position of the pooling window for output
/// position `out_pos`, pushing only positions that land inside the real
/// (un-padded) input into `results`. The f64 in each tuple is a placeholder.
fn collect_window_values(
    input_spatial_shape: &[usize],
    config: &PoolConfig,
    out_pos: &[usize],
    num_spatial: usize,
    dim: usize,
    current_input_pos: &mut Vec<usize>,
    results: &mut Vec<(f64, Vec<usize>)>,
) {
    if dim == num_spatial {
        // Base case: translate padded coordinates back to real input
        // coordinates, discarding any position in the padding halo.
        let mut unpadded = Vec::with_capacity(num_spatial);
        for d in 0..num_spatial {
            let pad = config.effective_padding(d);
            let padded_pos = current_input_pos[d];
            let in_bounds = padded_pos >= pad && padded_pos < input_spatial_shape[d] + pad;
            if !in_bounds {
                return;
            }
            unpadded.push(padded_pos - pad);
        }
        results.push((0.0, unpadded));
        return;
    }
    // Recursive case: sweep this dimension across the kernel extent.
    let start = out_pos[dim] * config.effective_stride(dim);
    let kernel = config.kernel_size.get(dim).copied().unwrap_or(1);
    for offset in 0..kernel {
        current_input_pos[dim] = start + offset;
        collect_window_values(
            input_spatial_shape,
            config,
            out_pos,
            num_spatial,
            dim + 1,
            current_input_pos,
            results,
        );
    }
}
/// Odometer-style increment of `indices` within `shape`. Returns false once
/// every position has been visited (indices wrap back to all zeros).
fn advance_indices(indices: &mut [usize], shape: &[usize]) -> bool {
    for (idx, &limit) in indices.iter_mut().zip(shape.iter()).rev() {
        *idx += 1;
        if *idx < limit {
            return true;
        }
        *idx = 0;
    }
    false
}
/// Row-major flattening of spatial coordinates into a single i64 index.
fn spatial_flat_index(spatial_indices: &[usize], spatial_shape: &[usize]) -> i64 {
    let mut flat = 0i64;
    let mut stride = 1i64;
    // Walk from the fastest-varying (last) dimension outward.
    for (&idx, &extent) in spatial_indices.iter().zip(spatial_shape.iter()).rev() {
        flat += idx as i64 * stride;
        stride *= extent as i64;
    }
    flat
}
/// Max pooling: each output element is the maximum over its window's valid
/// (un-padded) positions; a window with no valid position yields 0.0.
pub fn max_pool(input: &ArrayD<f64>, config: &PoolConfig) -> Result<ArrayD<f64>, PoolingError> {
    config.validate()?;
    let num_spatial = config.num_spatial_dims();
    validate_input(input, num_spatial)?;
    let shape = input.shape();
    let out_shape = compute_output_shape(shape, config)?;
    let split = shape.len() - num_spatial;
    let in_spatial: Vec<usize> = shape[split..].to_vec();
    let out_spatial: Vec<usize> = out_shape[split..].to_vec();
    let mut result = ArrayD::zeros(IxDyn(&out_shape));
    for outer in 0..num_outer_slices(shape, num_spatial) {
        let outer_idx = flat_to_outer_indices(outer, shape, num_spatial);
        for_each_window(&in_spatial, config, &out_spatial, |pos, window| {
            // Track the running maximum; `>` keeps NaN handling identical.
            let mut best = f64::NEG_INFINITY;
            for (_, coords) in &window {
                let v = get_spatial_value(input, &outer_idx, coords, num_spatial);
                if v > best {
                    best = v;
                }
            }
            let mut target: Vec<usize> = outer_idx.clone();
            target.extend_from_slice(pos);
            result[IxDyn(&target)] = if best == f64::NEG_INFINITY { 0.0 } else { best };
        });
    }
    Ok(result)
}
/// Max pooling that also records, per output element, the flat spatial index
/// of the winning input position (0 when the window had no valid position).
pub fn max_pool_with_indices(
    input: &ArrayD<f64>,
    config: &PoolConfig,
) -> Result<(ArrayD<f64>, ArrayD<i64>), PoolingError> {
    config.validate()?;
    let num_spatial = config.num_spatial_dims();
    validate_input(input, num_spatial)?;
    let shape = input.shape();
    let out_shape = compute_output_shape(shape, config)?;
    let split = shape.len() - num_spatial;
    let in_spatial: Vec<usize> = shape[split..].to_vec();
    let out_spatial: Vec<usize> = out_shape[split..].to_vec();
    let mut values = ArrayD::zeros(IxDyn(&out_shape));
    let mut argmax = ArrayD::zeros(IxDyn(&out_shape));
    for outer in 0..num_outer_slices(shape, num_spatial) {
        let outer_idx = flat_to_outer_indices(outer, shape, num_spatial);
        for_each_window(&in_spatial, config, &out_spatial, |pos, window| {
            let mut best = f64::NEG_INFINITY;
            let mut best_idx: i64 = -1;
            for (_, coords) in &window {
                let v = get_spatial_value(input, &outer_idx, coords, num_spatial);
                if v > best {
                    best = v;
                    best_idx = spatial_flat_index(coords, &in_spatial);
                }
            }
            // Empty/all-NaN window: fall back to 0.0 with index 0.
            if best == f64::NEG_INFINITY {
                best = 0.0;
                best_idx = 0;
            }
            let mut target: Vec<usize> = outer_idx.clone();
            target.extend_from_slice(pos);
            values[IxDyn(&target)] = best;
            argmax[IxDyn(&target)] = best_idx;
        });
    }
    Ok((values, argmax))
}
/// Average pooling: each output element is the mean over its window's valid
/// positions only (padding is excluded from the count); empty windows give 0.0.
pub fn avg_pool(input: &ArrayD<f64>, config: &PoolConfig) -> Result<ArrayD<f64>, PoolingError> {
    config.validate()?;
    let num_spatial = config.num_spatial_dims();
    validate_input(input, num_spatial)?;
    let shape = input.shape();
    let out_shape = compute_output_shape(shape, config)?;
    let split = shape.len() - num_spatial;
    let in_spatial: Vec<usize> = shape[split..].to_vec();
    let out_spatial: Vec<usize> = out_shape[split..].to_vec();
    let mut result = ArrayD::zeros(IxDyn(&out_shape));
    for outer in 0..num_outer_slices(shape, num_spatial) {
        let outer_idx = flat_to_outer_indices(outer, shape, num_spatial);
        for_each_window(&in_spatial, config, &out_spatial, |pos, window| {
            let total: f64 = window
                .iter()
                .map(|(_, coords)| get_spatial_value(input, &outer_idx, coords, num_spatial))
                .sum();
            let avg = if window.is_empty() {
                0.0
            } else {
                total / window.len() as f64
            };
            let mut target: Vec<usize> = outer_idx.clone();
            target.extend_from_slice(pos);
            result[IxDyn(&target)] = avg;
        });
    }
    Ok(result)
}
/// Power-average (Lp) pooling: computes `(mean(|x|^p))^(1/p)` over each
/// window's valid positions; empty windows give 0.0.
pub fn lp_pool(
    input: &ArrayD<f64>,
    config: &PoolConfig,
    p: f64,
) -> Result<ArrayD<f64>, PoolingError> {
    config.validate()?;
    let num_spatial = config.num_spatial_dims();
    validate_input(input, num_spatial)?;
    let shape = input.shape();
    let out_shape = compute_output_shape(shape, config)?;
    let split = shape.len() - num_spatial;
    let in_spatial: Vec<usize> = shape[split..].to_vec();
    let out_spatial: Vec<usize> = out_shape[split..].to_vec();
    let mut result = ArrayD::zeros(IxDyn(&out_shape));
    for outer in 0..num_outer_slices(shape, num_spatial) {
        let outer_idx = flat_to_outer_indices(outer, shape, num_spatial);
        for_each_window(&in_spatial, config, &out_spatial, |pos, window| {
            let norm = if window.is_empty() {
                0.0
            } else {
                let sum_pow: f64 = window
                    .iter()
                    .map(|(_, coords)| {
                        get_spatial_value(input, &outer_idx, coords, num_spatial)
                            .abs()
                            .powf(p)
                    })
                    .sum();
                (sum_pow / window.len() as f64).powf(1.0 / p)
            };
            let mut target: Vec<usize> = outer_idx.clone();
            target.extend_from_slice(pos);
            result[IxDyn(&target)] = norm;
        });
    }
    Ok(result)
}
/// Global max pooling: reduces an input of shape (batch, channels, *spatial)
/// to (batch, channels) by taking the maximum over all spatial positions.
///
/// # Errors
/// `EmptyInput` for an empty tensor; `InsufficientDimensions` when the input
/// has fewer than 3 dimensions.
pub fn global_max_pool(input: &ArrayD<f64>) -> Result<ArrayD<f64>, PoolingError> {
    if input.is_empty() {
        return Err(PoolingError::EmptyInput);
    }
    if input.ndim() < 3 {
        return Err(PoolingError::InsufficientDimensions {
            ndim: input.ndim(),
            required: 3,
        });
    }
    let shape = input.shape();
    let batch = shape[0];
    let channels = shape[1];
    let spatial_shape = &shape[2..];
    let spatial_size: usize = spatial_shape.iter().product();
    let mut output = ArrayD::zeros(IxDyn(&[batch, channels]));
    for b in 0..batch {
        for c in 0..channels {
            let mut max_val = f64::NEG_INFINITY;
            for s in 0..spatial_size {
                let spatial_idx = flat_to_spatial_indices(s, spatial_shape);
                let mut full_idx = vec![b, c];
                full_idx.extend_from_slice(&spatial_idx);
                let val = input[IxDyn(&full_idx)];
                if val > max_val {
                    max_val = val;
                }
            }
            // `>` never accepts NaN, so an all-NaN slice leaves -inf; map it
            // to 0.0 to keep the output finite (matches the windowed pools).
            if max_val == f64::NEG_INFINITY {
                max_val = 0.0;
            }
            output[IxDyn(&[b, c])] = max_val;
        }
    }
    Ok(output)
}
/// Global average pooling: reduces (batch, channels, *spatial) to
/// (batch, channels) by averaging over all spatial positions.
pub fn global_avg_pool(input: &ArrayD<f64>) -> Result<ArrayD<f64>, PoolingError> {
    if input.is_empty() {
        return Err(PoolingError::EmptyInput);
    }
    let ndim = input.ndim();
    if ndim < 3 {
        return Err(PoolingError::InsufficientDimensions { ndim, required: 3 });
    }
    let shape = input.shape();
    let (batch, channels) = (shape[0], shape[1]);
    let spatial_shape = &shape[2..];
    // Non-empty input guarantees spatial_size >= 1, so the division is safe.
    let spatial_size: usize = spatial_shape.iter().product();
    let mut output = ArrayD::zeros(IxDyn(&[batch, channels]));
    for b in 0..batch {
        for c in 0..channels {
            let mut total = 0.0;
            for flat in 0..spatial_size {
                let mut full_idx = vec![b, c];
                full_idx.extend_from_slice(&flat_to_spatial_indices(flat, spatial_shape));
                total += input[IxDyn(&full_idx)];
            }
            output[IxDyn(&[b, c])] = total / spatial_size as f64;
        }
    }
    Ok(output)
}
/// Decodes a row-major flat index into per-dimension spatial coordinates.
fn flat_to_spatial_indices(flat: usize, spatial_shape: &[usize]) -> Vec<usize> {
    let mut indices = vec![0usize; spatial_shape.len()];
    let mut remaining = flat;
    for d in (0..spatial_shape.len()).rev() {
        let extent = spatial_shape[d];
        indices[d] = remaining % extent;
        remaining /= extent;
    }
    indices
}
/// Adaptive average pooling: averages the input into exactly `output_size`
/// bins per spatial dimension, with bin `i` covering the input range
/// `[floor(i*in/out), floor((i+1)*in/out))`.
///
/// # Errors
/// `EmptyInput` for an empty tensor, `InsufficientDimensions` when the input
/// lacks batch + channel + spatial dims, and `ShapeMismatch` when any target
/// size is 0 (which would otherwise divide by zero below).
pub fn adaptive_avg_pool(
    input: &ArrayD<f64>,
    output_size: &[usize],
) -> Result<ArrayD<f64>, PoolingError> {
    if input.is_empty() {
        return Err(PoolingError::EmptyInput);
    }
    let num_spatial = output_size.len();
    if input.ndim() < num_spatial + 2 {
        return Err(PoolingError::InsufficientDimensions {
            ndim: input.ndim(),
            required: num_spatial + 2,
        });
    }
    // A zero target size would cause a usize division-by-zero panic when
    // computing bin ranges; reject it up front.
    if output_size.contains(&0) {
        return Err(PoolingError::ShapeMismatch(
            "adaptive output_size entries must be > 0".to_string(),
        ));
    }
    let shape = input.shape();
    let spatial_offset = shape.len() - num_spatial;
    let input_spatial: Vec<usize> = shape[spatial_offset..].to_vec();
    let mut out_shape: Vec<usize> = shape[..spatial_offset].to_vec();
    out_shape.extend_from_slice(output_size);
    let mut output = ArrayD::zeros(IxDyn(&out_shape));
    let n_outer = num_outer_slices(shape, num_spatial);
    for outer_flat in 0..n_outer {
        let outer_idx = flat_to_outer_indices(outer_flat, shape, num_spatial);
        let mut out_pos = vec![0usize; num_spatial];
        loop {
            // Compute this bin's half-open input range per dimension.
            let mut ranges: Vec<(usize, usize)> = Vec::with_capacity(num_spatial);
            for d in 0..num_spatial {
                let in_size = input_spatial[d];
                let out_sz = output_size[d];
                let start = (out_pos[d] * in_size) / out_sz;
                let end = ((out_pos[d] + 1) * in_size) / out_sz;
                ranges.push((start, end));
            }
            let mut sum = 0.0;
            let mut count = 0usize;
            let mut win_pos = vec![0usize; num_spatial];
            for d in 0..num_spatial {
                win_pos[d] = ranges[d].0;
            }
            // NOTE: when out > in a bin's range can be empty (start == end);
            // this loop still samples the single position at `start` once,
            // which reproduces the input value for oversampled outputs.
            loop {
                let val = get_spatial_value(input, &outer_idx, &win_pos, num_spatial);
                sum += val;
                count += 1;
                if !advance_within_ranges(&mut win_pos, &ranges) {
                    break;
                }
            }
            let avg = if count > 0 { sum / count as f64 } else { 0.0 };
            let mut full_idx: Vec<usize> = outer_idx.clone();
            full_idx.extend_from_slice(&out_pos);
            output[IxDyn(&full_idx)] = avg;
            if !advance_indices(&mut out_pos, output_size) {
                break;
            }
        }
    }
    Ok(output)
}
/// Odometer-style increment where each dimension wraps within its own
/// half-open range `[start, end)`. Returns false after the last position.
fn advance_within_ranges(indices: &mut [usize], ranges: &[(usize, usize)]) -> bool {
    for (idx, &(start, end)) in indices.iter_mut().zip(ranges.iter()).rev() {
        *idx += 1;
        if *idx < end {
            return true;
        }
        *idx = start;
    }
    false
}
/// Reverses max pooling: scatters each pooled value back to the spatial
/// position recorded in `indices`, leaving every other output element 0.0.
/// Negative or out-of-range indices are silently skipped.
///
/// # Errors
/// `ShapeMismatch` when `pooled` and `indices` disagree, when `output_size`
/// has a different rank than `pooled`, or when the leading (batch/channel)
/// dims of `output_size` differ from `pooled`'s; `EmptyInput` for an empty
/// pooled tensor.
pub fn max_unpool(
    pooled: &ArrayD<f64>,
    indices: &ArrayD<i64>,
    output_size: &[usize],
) -> Result<ArrayD<f64>, PoolingError> {
    if pooled.shape() != indices.shape() {
        return Err(PoolingError::ShapeMismatch(format!(
            "pooled shape {:?} != indices shape {:?}",
            pooled.shape(),
            indices.shape()
        )));
    }
    if pooled.is_empty() {
        return Err(PoolingError::EmptyInput);
    }
    let pooled_shape = pooled.shape();
    if output_size.len() != pooled_shape.len() {
        return Err(PoolingError::ShapeMismatch(format!(
            "output_size len {} != pooled ndim {}",
            output_size.len(),
            pooled_shape.len()
        )));
    }
    let num_spatial = pooled_shape.len().saturating_sub(2);
    let spatial_offset = pooled_shape.len() - num_spatial;
    // The leading (batch/channel) dims must agree; otherwise the scatter
    // writes below would index out of bounds and panic instead of erroring.
    if output_size[..spatial_offset] != pooled_shape[..spatial_offset] {
        return Err(PoolingError::ShapeMismatch(format!(
            "output_size outer dims {:?} != pooled outer dims {:?}",
            &output_size[..spatial_offset],
            &pooled_shape[..spatial_offset]
        )));
    }
    let output_spatial: Vec<usize> = output_size[spatial_offset..].to_vec();
    // Loop-invariant: hoisted out of the per-slice loop.
    let pooled_spatial: Vec<usize> = pooled_shape[spatial_offset..].to_vec();
    let mut output = ArrayD::zeros(IxDyn(output_size));
    let n_outer = num_outer_slices(pooled_shape, num_spatial);
    let output_spatial_total: usize = output_spatial.iter().product();
    for outer_flat in 0..n_outer {
        let outer_idx = flat_to_outer_indices(outer_flat, pooled_shape, num_spatial);
        let mut pos = vec![0usize; num_spatial];
        loop {
            let mut pooled_full: Vec<usize> = outer_idx.clone();
            pooled_full.extend_from_slice(&pos);
            let val = pooled[IxDyn(&pooled_full)];
            let idx = indices[IxDyn(&pooled_full)];
            // Skip invalid indices rather than panic; later windows can
            // overwrite earlier ones when indices collide.
            if idx >= 0 && (idx as usize) < output_spatial_total {
                let spatial_pos = flat_to_spatial_indices(idx as usize, &output_spatial);
                let mut out_full: Vec<usize> = outer_idx.clone();
                out_full.extend_from_slice(&spatial_pos);
                output[IxDyn(&out_full)] = val;
            }
            if !advance_indices(&mut pos, &pooled_spatial) {
                break;
            }
        }
    }
    Ok(output)
}
/// Summary statistics describing a pooling operation's geometry.
#[derive(Debug, Clone)]
pub struct PoolingStats {
    // Full input shape including batch/channel dims.
    pub input_shape: Vec<usize>,
    // Full output shape produced by the config on that input.
    pub output_shape: Vec<usize>,
    // Kernel extent per spatial dimension.
    pub kernel_size: Vec<usize>,
    // Effective stride per spatial dimension (after fallback resolution).
    pub stride: Vec<usize>,
    // Number of input elements each output element sees (product of kernel dims).
    pub receptive_field_size: usize,
    // Input spatial elements per output spatial element; infinity when output is empty.
    pub compression_ratio: f64,
    // Mean per-dimension fraction of a window overlapped by its neighbor, in [0, 1).
    pub overlap_ratio: f64,
}
impl PoolingStats {
    /// Derives the stats for applying `config` to an input of `input_shape`.
    ///
    /// # Errors
    /// Propagates config validation errors and rejects inputs with fewer
    /// than `num_spatial + 2` dimensions.
    pub fn compute(input_shape: &[usize], config: &PoolConfig) -> Result<Self, PoolingError> {
        config.validate()?;
        let num_spatial = config.num_spatial_dims();
        let required = num_spatial + 2;
        if input_shape.len() < required {
            return Err(PoolingError::InsufficientDimensions {
                ndim: input_shape.len(),
                required,
            });
        }
        let output_shape = compute_output_shape(input_shape, config)?;
        let split = input_shape.len() - num_spatial;
        let in_elems: usize = input_shape[split..].iter().product();
        let out_elems: usize = output_shape[split..].iter().product();
        // Ratio of spatial elements in to out; an empty output compresses infinitely.
        let compression_ratio = if out_elems > 0 {
            in_elems as f64 / out_elems as f64
        } else {
            f64::INFINITY
        };
        // Average over dimensions of the window fraction shared with the next
        // window: (k - s) / k, clamped at 0 for non-overlapping strides.
        let overlap_ratio = if num_spatial > 0 {
            let total: f64 = (0..num_spatial)
                .map(|d| {
                    let k = config.kernel_size.get(d).copied().unwrap_or(1) as f64;
                    let s = config.effective_stride(d) as f64;
                    ((k - s) / k).max(0.0)
                })
                .sum();
            total / num_spatial as f64
        } else {
            0.0
        };
        Ok(Self {
            input_shape: input_shape.to_vec(),
            output_shape,
            kernel_size: config.kernel_size.clone(),
            stride: (0..num_spatial)
                .map(|d| config.effective_stride(d))
                .collect(),
            receptive_field_size: config.kernel_size.iter().product(),
            compression_ratio,
            overlap_ratio,
        })
    }

    /// One-line human-readable description of the stats.
    pub fn summary(&self) -> String {
        format!(
            "Pooling: {:?} -> {:?}, kernel={:?}, stride={:?}, \
             receptive_field={}, compression={:.2}x, overlap={:.2}",
            self.input_shape,
            self.output_shape,
            self.kernel_size,
            self.stride,
            self.receptive_field_size,
            self.compression_ratio,
            self.overlap_ratio,
        )
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::ArrayD;

    // Builds a (1, 1, h, w) tensor from row-major data.
    fn make_4d(data: Vec<f64>, h: usize, w: usize) -> ArrayD<f64> {
        ArrayD::from_shape_vec(IxDyn(&[1, 1, h, w]), data)
            .expect("test tensor creation should succeed")
    }

    // 4 / 2 kernel with default stride = kernel → output 2 per dimension.
    #[test]
    fn test_pool_config_output_size() {
        let config = PoolConfig::new(vec![2, 2]);
        assert_eq!(config.output_size(4, 0), 2);
        assert_eq!(config.output_size(4, 1), 2);
    }

    // Padding of 1 on each side: (4 + 2 - 2) / 2 + 1 = 3.
    #[test]
    fn test_pool_config_output_size_with_padding() {
        let config = PoolConfig::new(vec![2, 2]).with_padding(vec![1, 1]);
        assert_eq!(config.output_size(4, 0), 3);
    }

    #[test]
    fn test_pool_config_validate_valid() {
        let config = PoolConfig::new(vec![2, 2]);
        assert!(config.validate().is_ok());
    }

    // Zero kernel size must be rejected with the specific error variant.
    #[test]
    fn test_pool_config_validate_zero_kernel() {
        let config = PoolConfig::new(vec![0, 2]);
        let err = config.validate();
        assert!(err.is_err());
        match err {
            Err(PoolingError::InvalidKernelSize { size: 0 }) => {}
            other => panic!("Expected InvalidKernelSize, got {:?}", other),
        }
    }

    // 2x2 max pooling over a 4x4 ramp: each output is the window's bottom-right max.
    #[test]
    fn test_max_pool_basic() {
        #[rustfmt::skip]
        let data = vec![
            1.0, 2.0, 3.0, 4.0,
            5.0, 6.0, 7.0, 8.0,
            9.0, 10.0, 11.0, 12.0,
            13.0, 14.0, 15.0, 16.0,
        ];
        let input = make_4d(data, 4, 4);
        let config = PoolConfig::new(vec![2, 2]);
        let output = max_pool(&input, &config).expect("max_pool should succeed");
        assert_eq!(output.shape(), &[1, 1, 2, 2]);
        assert_eq!(output[IxDyn(&[0, 0, 0, 0])], 6.0);
        assert_eq!(output[IxDyn(&[0, 0, 0, 1])], 8.0);
        assert_eq!(output[IxDyn(&[0, 0, 1, 0])], 14.0);
        assert_eq!(output[IxDyn(&[0, 0, 1, 1])], 16.0);
    }

    // Indices are flat row-major positions into the 4x4 spatial plane.
    #[test]
    fn test_max_pool_with_indices_correct() {
        #[rustfmt::skip]
        let data = vec![
            1.0, 2.0, 3.0, 4.0,
            5.0, 6.0, 7.0, 8.0,
            9.0, 10.0, 11.0, 12.0,
            13.0, 14.0, 15.0, 16.0,
        ];
        let input = make_4d(data, 4, 4);
        let config = PoolConfig::new(vec![2, 2]);
        let (output, indices) =
            max_pool_with_indices(&input, &config).expect("max_pool_with_indices should succeed");
        assert_eq!(output.shape(), &[1, 1, 2, 2]);
        assert_eq!(output[IxDyn(&[0, 0, 0, 0])], 6.0);
        assert_eq!(indices[IxDyn(&[0, 0, 0, 0])], 5);
        assert_eq!(output[IxDyn(&[0, 0, 0, 1])], 8.0);
        assert_eq!(indices[IxDyn(&[0, 0, 0, 1])], 7);
        assert_eq!(output[IxDyn(&[0, 0, 1, 0])], 14.0);
        assert_eq!(indices[IxDyn(&[0, 0, 1, 0])], 13);
        assert_eq!(output[IxDyn(&[0, 0, 1, 1])], 16.0);
        assert_eq!(indices[IxDyn(&[0, 0, 1, 1])], 15);
    }

    // 2x2 averages of the ramp: e.g. (1+2+5+6)/4 = 3.5.
    #[test]
    fn test_avg_pool_basic() {
        #[rustfmt::skip]
        let data = vec![
            1.0, 2.0, 3.0, 4.0,
            5.0, 6.0, 7.0, 8.0,
            9.0, 10.0, 11.0, 12.0,
            13.0, 14.0, 15.0, 16.0,
        ];
        let input = make_4d(data, 4, 4);
        let config = PoolConfig::new(vec![2, 2]);
        let output = avg_pool(&input, &config).expect("avg_pool should succeed");
        assert_eq!(output.shape(), &[1, 1, 2, 2]);
        assert!((output[IxDyn(&[0, 0, 0, 0])] - 3.5).abs() < 1e-10);
        assert!((output[IxDyn(&[0, 0, 0, 1])] - 5.5).abs() < 1e-10);
        assert!((output[IxDyn(&[0, 0, 1, 0])] - 11.5).abs() < 1e-10);
        assert!((output[IxDyn(&[0, 0, 1, 1])] - 13.5).abs() < 1e-10);
    }

    // Padding extends the output to 3x3; only the shape is checked here.
    #[test]
    fn test_avg_pool_padding() {
        let data = vec![1.0; 16];
        let input = make_4d(data, 4, 4);
        let config = PoolConfig::new(vec![2, 2]).with_padding(vec![1, 1]);
        let output = avg_pool(&input, &config).expect("avg_pool with padding should succeed");
        assert_eq!(output.shape(), &[1, 1, 3, 3]);
    }

    // L2 pool of [1,2,3,4]: sqrt((1+4+9+16)/4) = sqrt(7.5).
    #[test]
    fn test_lp_pool_p2() {
        #[rustfmt::skip]
        let data = vec![
            1.0, 2.0,
            3.0, 4.0,
        ];
        let input = make_4d(data, 2, 2);
        let config = PoolConfig::new(vec![2, 2]);
        let output = lp_pool(&input, &config, 2.0).expect("lp_pool p=2 should succeed");
        assert_eq!(output.shape(), &[1, 1, 1, 1]);
        let expected = (7.5_f64).sqrt();
        assert!((output[IxDyn(&[0, 0, 0, 0])] - expected).abs() < 1e-10);
    }

    // L1 pool uses absolute values: (1+2+3+4)/4 = 2.5.
    #[test]
    fn test_lp_pool_p1() {
        #[rustfmt::skip]
        let data = vec![
            1.0, -2.0,
            3.0, -4.0,
        ];
        let input = make_4d(data, 2, 2);
        let config = PoolConfig::new(vec![2, 2]);
        let output = lp_pool(&input, &config, 1.0).expect("lp_pool p=1 should succeed");
        assert_eq!(output.shape(), &[1, 1, 1, 1]);
        assert!((output[IxDyn(&[0, 0, 0, 0])] - 2.5).abs() < 1e-10);
    }

    #[test]
    fn test_global_max_pool_shape() {
        let input = ArrayD::zeros(IxDyn(&[1, 3, 4, 4]));
        let output = global_max_pool(&input).expect("global_max_pool should succeed");
        assert_eq!(output.shape(), &[1, 3]);
    }

    // Channel 2's only non-zero entry is -1.0, so its max stays at the zero fill.
    #[test]
    fn test_global_max_pool_values() {
        let mut input = ArrayD::zeros(IxDyn(&[1, 3, 4, 4]));
        input[IxDyn(&[0, 0, 2, 3])] = 42.0;
        input[IxDyn(&[0, 1, 0, 0])] = 99.0;
        input[IxDyn(&[0, 2, 3, 3])] = -1.0;
        let output = global_max_pool(&input).expect("global_max_pool should succeed");
        assert_eq!(output[IxDyn(&[0, 0])], 42.0);
        assert_eq!(output[IxDyn(&[0, 1])], 99.0);
        assert_eq!(output[IxDyn(&[0, 2])], 0.0);
    }

    #[test]
    fn test_global_avg_pool_shape() {
        let input = ArrayD::zeros(IxDyn(&[1, 3, 4, 4]));
        let output = global_avg_pool(&input).expect("global_avg_pool should succeed");
        assert_eq!(output.shape(), &[1, 3]);
    }

    // Channel 0 is all ones (mean 1.0); channel 1 is all twos (mean 2.0).
    #[test]
    fn test_global_avg_pool_values() {
        let mut input = ArrayD::ones(IxDyn(&[1, 2, 2, 2]));
        input[IxDyn(&[0, 1, 0, 0])] = 2.0;
        input[IxDyn(&[0, 1, 0, 1])] = 2.0;
        input[IxDyn(&[0, 1, 1, 0])] = 2.0;
        input[IxDyn(&[0, 1, 1, 1])] = 2.0;
        let output = global_avg_pool(&input).expect("global_avg_pool should succeed");
        assert!((output[IxDyn(&[0, 0])] - 1.0).abs() < 1e-10);
        assert!((output[IxDyn(&[0, 1])] - 2.0).abs() < 1e-10);
    }

    #[test]
    fn test_adaptive_avg_pool_output_size() {
        let input = ArrayD::ones(IxDyn(&[1, 1, 4, 4]));
        let output = adaptive_avg_pool(&input, &[2, 2]).expect("adaptive_avg_pool should succeed");
        assert_eq!(output.shape(), &[1, 1, 2, 2]);
    }

    // Target size equal to the input size must reproduce the input exactly.
    #[test]
    fn test_adaptive_avg_pool_identity() {
        #[rustfmt::skip]
        let data = vec![
            1.0, 2.0, 3.0, 4.0,
            5.0, 6.0, 7.0, 8.0,
            9.0, 10.0, 11.0, 12.0,
            13.0, 14.0, 15.0, 16.0,
        ];
        let input = make_4d(data.clone(), 4, 4);
        let output =
            adaptive_avg_pool(&input, &[4, 4]).expect("adaptive_avg_pool identity should succeed");
        assert_eq!(output.shape(), &[1, 1, 4, 4]);
        for (i, &v) in data.iter().enumerate() {
            let h = i / 4;
            let w = i % 4;
            assert!(
                (output[IxDyn(&[0, 0, h, w])] - v).abs() < 1e-10,
                "mismatch at ({}, {})",
                h,
                w
            );
        }
    }

    // Round-trip: pooled maxima scatter back to their original positions;
    // every position that was not a window max stays zero.
    #[test]
    fn test_max_unpool_basic() {
        #[rustfmt::skip]
        let data = vec![
            1.0, 2.0, 3.0, 4.0,
            5.0, 6.0, 7.0, 8.0,
            9.0, 10.0, 11.0, 12.0,
            13.0, 14.0, 15.0, 16.0,
        ];
        let input = make_4d(data, 4, 4);
        let config = PoolConfig::new(vec![2, 2]);
        let (pooled, indices) =
            max_pool_with_indices(&input, &config).expect("max_pool_with_indices should succeed");
        let unpooled =
            max_unpool(&pooled, &indices, &[1, 1, 4, 4]).expect("max_unpool should succeed");
        assert_eq!(unpooled.shape(), &[1, 1, 4, 4]);
        assert_eq!(unpooled[IxDyn(&[0, 0, 1, 1])], 6.0);
        assert_eq!(unpooled[IxDyn(&[0, 0, 1, 3])], 8.0);
        assert_eq!(unpooled[IxDyn(&[0, 0, 3, 1])], 14.0);
        assert_eq!(unpooled[IxDyn(&[0, 0, 3, 3])], 16.0);
        assert_eq!(unpooled[IxDyn(&[0, 0, 0, 0])], 0.0);
        assert_eq!(unpooled[IxDyn(&[0, 0, 2, 2])], 0.0);
    }

    // 16 spatial elements into 4 → compression 4x; non-overlapping → overlap 0.
    #[test]
    fn test_pooling_stats_compression() {
        let config = PoolConfig::new(vec![2, 2]);
        let stats =
            PoolingStats::compute(&[1, 1, 4, 4], &config).expect("stats compute should succeed");
        assert_eq!(stats.output_shape, vec![1, 1, 2, 2]);
        assert!((stats.compression_ratio - 4.0).abs() < 1e-10);
        assert_eq!(stats.receptive_field_size, 4);
        assert!((stats.overlap_ratio - 0.0).abs() < 1e-10);
    }

    #[test]
    fn test_pooling_stats_summary() {
        let config = PoolConfig::new(vec![2, 2]);
        let stats =
            PoolingStats::compute(&[1, 1, 4, 4], &config).expect("stats compute should succeed");
        let summary = stats.summary();
        assert!(!summary.is_empty());
        assert!(summary.contains("Pooling"));
    }

    // Every error variant must render to a non-empty Display message.
    #[test]
    fn test_pooling_error_display() {
        let errors = vec![
            PoolingError::InvalidKernelSize { size: 0 },
            PoolingError::InvalidStride { stride: 0 },
            PoolingError::InvalidPadding {
                padding: 3,
                kernel_size: 2,
            },
            PoolingError::InsufficientDimensions {
                ndim: 2,
                required: 4,
            },
            PoolingError::EmptyInput,
            PoolingError::ShapeMismatch("test".to_string()),
        ];
        for err in &errors {
            let msg = format!("{err}");
            assert!(!msg.is_empty(), "Error display should not be empty");
        }
    }
}