use crate::error::ModelError;
use crate::neural_network::Tensor;
use crate::neural_network::layer::TrainingParameters;
use crate::neural_network::layer::helper_function::calculate_output_shape_2d_pooling;
use crate::neural_network::layer::layer_weight::LayerWeight;
use crate::neural_network::layer::pooling_layer::input_validation_function::{
validate_input_shape_dims, validate_pool_size_2d, validate_strides_2d,
};
use crate::neural_network::neural_network_trait::Layer;
use ndarray::ArrayD;
use rayon::iter::{IntoParallelIterator, ParallelIterator};
/// Threshold handed to `execute_parallel_or_sequential!` together with the batch size and
/// channel count in both the forward and backward passes of `AveragePooling2D`.
// NOTE(review): presumably the macro switches from sequential to rayon-parallel execution once
// the amount of (batch, channel) work reaches this value — confirm against the macro definition.
const AVERAGE_POOLING_2D_PARALLEL_THRESHOLD: usize = 32;
/// 2D average-pooling layer.
///
/// Each output element is the mean of the input values covered by a `pool_size` window that is
/// moved across the two trailing axes of a 4-D input in steps of `strides`. The layer code reads
/// the input axes as `[batch, channels, height, width]`.
pub struct AveragePooling2D {
/// Pooling window extent along (height, width).
pool_size: (usize, usize),
/// Step between consecutive windows along (height, width); set to `pool_size` when the
/// constructor receives `None`.
strides: (usize, usize),
/// Input shape supplied at construction time (checked by `validate_input_shape_dims` with an
/// expected dimension count of 4).
input_shape: Vec<usize>,
/// Clone of the tensor passed to the most recent `forward` call; `None` until `forward` has
/// run. `backward` reads the input shape from it.
input_cache: Option<Tensor>,
}
impl AveragePooling2D {
    /// Creates an average-pooling layer.
    ///
    /// # Parameters
    ///
    /// - `pool_size` - window extent along (height, width)
    /// - `input_shape` - shape of the expected input; validated against a dimension count of 4
    /// - `strides` - step along (height, width); `None` falls back to `pool_size`
    ///
    /// # Errors
    ///
    /// Propagates the first error reported by the input-shape, pool-size or stride validators,
    /// in that order.
    pub fn new(
        pool_size: (usize, usize),
        input_shape: Vec<usize>,
        strides: Option<(usize, usize)>,
    ) -> Result<Self, ModelError> {
        // A missing stride means the window advances by its own size.
        let strides = match strides {
            Some(explicit) => explicit,
            None => pool_size,
        };

        validate_input_shape_dims(&input_shape, 4, "AveragePooling2D")?;
        validate_pool_size_2d(pool_size)?;
        validate_strides_2d(strides)?;

        Ok(Self {
            pool_size,
            strides,
            input_shape,
            input_cache: None,
        })
    }

    /// Computes the pooled tensor for `input`, which is indexed as
    /// `[batch, channels, height, width]`.
    ///
    /// Window positions that fall outside the input are ignored, so a window that overlaps the
    /// border is averaged over the in-bounds elements only; a window with no in-bounds element
    /// yields `0.0`.
    fn avg_pool(&self, input: &Tensor) -> Tensor {
        let in_dims = input.shape();
        let batch_size = in_dims[0];
        let channels = in_dims[1];

        let out_dims = calculate_output_shape_2d_pooling(in_dims, self.pool_size, self.strides);
        let mut pooled = ArrayD::zeros(out_dims.clone());

        let (pool_h, pool_w) = self.pool_size;
        let (stride_h, stride_w) = self.strides;

        // Produces every `(row, col, mean)` triple of one (batch, channel) plane.
        let pool_plane = |b: usize, c: usize| {
            let mut cells = Vec::new();
            for out_row in 0..out_dims[2] {
                let row_begin = out_row * stride_h;
                // Clamp the window to the input height instead of testing each row.
                let row_end = (row_begin + pool_h).min(in_dims[2]);
                for out_col in 0..out_dims[3] {
                    let col_begin = out_col * stride_w;
                    let col_end = (col_begin + pool_w).min(in_dims[3]);

                    let mut total = 0.0;
                    let mut taken = 0;
                    for row in row_begin..row_end {
                        for col in col_begin..col_end {
                            total += input[[b, c, row, col]];
                            taken += 1;
                        }
                    }

                    let mean = if taken > 0 {
                        total / taken as f32
                    } else {
                        0.0
                    };
                    cells.push((out_row, out_col, mean));
                }
            }
            ((b, c), cells)
        };

        let per_plane: Vec<_> = execute_parallel_or_sequential!(
            batch_size,
            channels,
            AVERAGE_POOLING_2D_PARALLEL_THRESHOLD,
            pool_plane
        );

        // Scatter the per-plane results into the output tensor.
        for ((b, c), cells) in per_plane {
            for (row, col, mean) in cells {
                pooled[[b, c, row, col]] = mean;
            }
        }
        pooled
    }
}
impl Layer for AveragePooling2D {
    /// Runs the forward pass.
    ///
    /// A clone of `input` is stored in the layer so that `backward` can recover the input shape.
    ///
    /// # Errors
    ///
    /// Returns `ModelError::InputValidationError` when `input` is not 4-dimensional.
    fn forward(&mut self, input: &Tensor) -> Result<Tensor, ModelError> {
        if input.ndim() != 4 {
            return Err(ModelError::InputValidationError(
                "input tensor is not 4D".to_string(),
            ));
        }
        self.input_cache = Some(input.clone());
        Ok(self.avg_pool(input))
    }

    /// Runs the backward pass: every output gradient is split evenly over the in-bounds input
    /// positions of the window that produced it, and overlapping windows accumulate.
    ///
    /// # Errors
    ///
    /// - `ModelError::ProcessingError` when `forward` has not been called yet.
    /// - `ModelError::InputValidationError` when `grad_output` is not 4-dimensional or its batch
    ///   and channel counts differ from those of the cached input. Without this check such a
    ///   gradient would either panic on indexing or be silently truncated.
    fn backward(&mut self, grad_output: &Tensor) -> Result<Tensor, ModelError> {
        let input = match self.input_cache.as_ref() {
            Some(cached) => cached,
            None => {
                return Err(ModelError::ProcessingError(
                    "Forward pass has not been run yet".to_string(),
                ));
            }
        };

        let input_shape = input.shape();
        let batch_size = input_shape[0];
        let channels = input_shape[1];
        let height = input_shape[2];
        let width = input_shape[3];

        let output_shape = grad_output.shape();
        if output_shape.len() != 4
            || output_shape[0] != batch_size
            || output_shape[1] != channels
        {
            return Err(ModelError::InputValidationError(format!(
                "gradient tensor shape {:?} does not match the cached input \
                 (expected 4 dimensions with batch size {} and {} channels)",
                output_shape, batch_size, channels
            )));
        }

        let mut input_grad = ArrayD::zeros(input_shape.to_vec());
        let pool_size = self.pool_size;
        let strides = self.strides;

        // Builds the flattened (row-major, `height * width`) gradient of one (batch, channel)
        // plane.
        let compute_gradient = |b: usize, c: usize| {
            let mut spatial_grad = vec![0.0f32; height * width];
            for i in 0..output_shape[2] {
                let i_start = i * strides.0;
                // Clamp the window to the input so border windows only cover real elements.
                let i_end = (i_start + pool_size.0).min(height);
                if i_start >= i_end {
                    continue;
                }
                for j in 0..output_shape[3] {
                    let j_start = j * strides.1;
                    let j_end = (j_start + pool_size.1).min(width);
                    if j_start >= j_end {
                        continue;
                    }

                    // Same divisor the forward pass used: the number of in-bounds elements.
                    let count = (i_end - i_start) * (j_end - j_start);
                    let grad_per_element = grad_output[[b, c, i, j]] / count as f32;

                    for i_pos in i_start..i_end {
                        let row_offset = i_pos * width;
                        for cell in &mut spatial_grad[row_offset + j_start..row_offset + j_end] {
                            *cell += grad_per_element;
                        }
                    }
                }
            }
            ((b, c), spatial_grad)
        };

        let results: Vec<_> = execute_parallel_or_sequential!(
            batch_size,
            channels,
            AVERAGE_POOLING_2D_PARALLEL_THRESHOLD,
            compute_gradient
        );
        merge_gradients_2d!(input_grad, results, height, width);
        Ok(input_grad)
    }

    /// Returns the layer's type name, `"AveragePooling2D"`.
    fn layer_type(&self) -> &str {
        "AveragePooling2D"
    }

    // NOTE(review): presumably expands to the remaining `Layer` items shared by the 2D pooling
    // layers — confirm against the macro definition.
    layer_functions_2d_pooling!();
}