use crate::error::ModelError;
use crate::neural_network::Tensor;
use crate::neural_network::layer::TrainingParameters;
use crate::neural_network::layer::helper_function::calculate_output_shape_2d_pooling;
use crate::neural_network::layer::layer_weight::LayerWeight;
use crate::neural_network::layer::pooling_layer::input_validation_function::{
validate_input_shape_dims, validate_pool_size_2d, validate_strides_2d,
};
use crate::neural_network::neural_network_trait::Layer;
use ndarray::ArrayD;
use rayon::iter::{IntoParallelIterator, ParallelIterator};
/// Threshold handed to `execute_parallel_or_sequential!`, together with the batch size and the
/// channel count, by both the pooling pass and the gradient pass of `MaxPooling2D`.
// NOTE(review): presumably the amount of (batch, channel) work at which the macro switches from
// sequential to parallel execution — confirm against the macro definition.
const MAX_POOLING_2D_PARALLEL_THRESHOLD: usize = 32;
/// 2D max-pooling layer operating on 4D tensors indexed as `[batch, channel, row, column]`.
///
/// The forward pass records which input element won each pooling window so that the backward
/// pass can route gradients to exactly those elements.
pub struct MaxPooling2D {
/// Pooling window extent as `(rows, columns)`.
pool_size: (usize, usize),
/// Step between consecutive windows as `(rows, columns)`.
strides: (usize, usize),
/// Input shape supplied at construction time (checked there against 4 dimensions).
// NOTE(review): not read by any method visible in this file — possibly used by the
// `layer_functions_2d_pooling!` expansion; confirm.
input_shape: Vec<usize>,
/// Clone of the most recent forward input; the backward pass reads its shape.
input_cache: Option<Tensor>,
/// One entry per output element from the most recent forward pass:
/// `(batch, channel, out_row, out_col, in_row, in_col)`, where `(in_row, in_col)` is the
/// input position that produced the pooled value.
max_positions: Option<Vec<(usize, usize, usize, usize, usize, usize)>>,
}
impl MaxPooling2D {
    /// Creates a 2D max-pooling layer.
    ///
    /// # Parameters
    /// - `pool_size`: pooling window extent as `(rows, columns)`.
    /// - `input_shape`: expected input shape; checked by `validate_input_shape_dims` against
    ///   4 dimensions.
    /// - `strides`: step between windows as `(rows, columns)`; `None` reuses `pool_size`.
    ///
    /// # Errors
    /// Propagates the `ModelError` returned by the input-shape, pool-size or stride validators
    /// (checked in that order).
    pub fn new(
        pool_size: (usize, usize),
        input_shape: Vec<usize>,
        strides: Option<(usize, usize)>,
    ) -> Result<Self, ModelError> {
        let strides = strides.unwrap_or(pool_size);
        validate_input_shape_dims(&input_shape, 4, "MaxPooling2D")?;
        validate_pool_size_2d(pool_size)?;
        validate_strides_2d(strides)?;
        Ok(MaxPooling2D {
            pool_size,
            strides,
            input_shape,
            input_cache: None,
            max_positions: None,
        })
    }

    /// Computes the output shape for `input_shape` using this layer's pool size and strides.
    fn calculate_output_shape(&self, input_shape: &[usize]) -> Vec<usize> {
        calculate_output_shape_2d_pooling(input_shape, self.pool_size, self.strides)
    }

    /// Runs max pooling over a 4D `input` indexed as `[batch, channel, row, column]`.
    ///
    /// # Returns
    /// A tuple of:
    /// - the pooled tensor, and
    /// - one `(batch, channel, out_row, out_col, in_row, in_col)` entry per output element,
    ///   naming the input position that supplied the pooled value.
    ///
    /// Ties keep the first maximum encountered in row-major order. When no element of a window
    /// compares greater than negative infinity (all `-inf` or NaN), the pooled value stays
    /// `-inf` and the recorded position is the first in-bounds cell of that window, so the
    /// backward pass never routes gradient to a cell outside the window.
    fn max_pool(
        &self,
        input: &Tensor,
    ) -> (Tensor, Vec<(usize, usize, usize, usize, usize, usize)>) {
        let input_shape = input.shape();
        let batch_size = input_shape[0];
        let channels = input_shape[1];
        let (in_height, in_width) = (input_shape[2], input_shape[3]);
        let (pool_h, pool_w) = self.pool_size;
        let (stride_h, stride_w) = self.strides;

        let output_shape = self.calculate_output_shape(input_shape);
        let (out_height, out_width) = (output_shape[2], output_shape[3]);
        let cells_per_plane = out_height * out_width;
        let mut output = ArrayD::zeros(output_shape);

        // Pools a single (batch, channel) plane. The closure only borrows its captures so it can
        // be shared freely by whichever execution strategy the macro below selects.
        let compute_pooling = |b: usize, c: usize| {
            let mut batch_channel_output = Vec::with_capacity(cells_per_plane);
            let mut batch_channel_positions = Vec::with_capacity(cells_per_plane);
            for out_i in 0..out_height {
                let i_start = out_i * stride_h;
                // Clamp the window to the input so partially overhanging windows are tolerated.
                let i_end = (i_start + pool_h).min(in_height);
                for out_j in 0..out_width {
                    let j_start = out_j * stride_w;
                    let j_end = (j_start + pool_w).min(in_width);

                    let mut max_val = f32::NEG_INFINITY;
                    let mut max_pos: Option<(usize, usize)> = None;
                    for i_pos in i_start..i_end {
                        for j_pos in j_start..j_end {
                            let val = input[[b, c, i_pos, j_pos]];
                            if val > max_val {
                                max_val = val;
                                max_pos = Some((i_pos, j_pos));
                            } else if max_pos.is_none() {
                                // Nothing has beaten -inf yet: remember the first in-bounds
                                // cell as a fallback so the position stays inside the window.
                                max_pos = Some((i_pos, j_pos));
                            }
                        }
                    }

                    // An entirely out-of-bounds window has no cell at all; keep the historical
                    // (0, 0) placeholder in that case.
                    let (max_i, max_j) = max_pos.unwrap_or((0, 0));
                    batch_channel_output.push((out_i, out_j, max_val));
                    batch_channel_positions.push((b, c, out_i, out_j, max_i, max_j));
                }
            }
            ((b, c), (batch_channel_output, batch_channel_positions))
        };

        let results: Vec<_> = execute_parallel_or_sequential!(
            batch_size,
            channels,
            MAX_POOLING_2D_PARALLEL_THRESHOLD,
            compute_pooling
        );

        // Scatter the per-plane results into the output tensor and flatten the position lists.
        let mut max_positions = Vec::with_capacity(results.len() * cells_per_plane);
        for ((b, c), (outputs, positions)) in results {
            for (i, j, val) in outputs {
                output[[b, c, i, j]] = val;
            }
            max_positions.extend(positions);
        }
        (output, max_positions)
    }
}
impl Layer for MaxPooling2D {
    /// Applies max pooling to `input` and caches what the backward pass needs.
    ///
    /// Stores a clone of `input` and the winning input position of every output element.
    ///
    /// # Errors
    /// Returns `ModelError::InputValidationError` when `input` is not 4-dimensional.
    fn forward(&mut self, input: &Tensor) -> Result<Tensor, ModelError> {
        if input.ndim() != 4 {
            return Err(ModelError::InputValidationError(
                "input tensor is not 4D".to_string(),
            ));
        }
        self.input_cache = Some(input.clone());
        let (output, max_positions) = self.max_pool(input);
        self.max_positions = Some(max_positions);
        Ok(output)
    }

    /// Propagates `grad_output` back to the input of the most recent forward pass.
    ///
    /// Each output gradient is added to the input position that won its pooling window; every
    /// other input position receives zero. Overlapping windows that share a winner accumulate.
    ///
    /// # Errors
    /// - `ModelError::ProcessingError` when no forward pass has been run yet.
    /// - `ModelError::InputValidationError` when `grad_output` does not have the shape of the
    ///   forward output (previously this panicked or silently misused the gradient).
    fn backward(&mut self, grad_output: &Tensor) -> Result<Tensor, ModelError> {
        let (input, max_positions) = match (&self.input_cache, &self.max_positions) {
            (Some(input), Some(max_positions)) => (input, max_positions),
            _ => {
                return Err(ModelError::ProcessingError(
                    "Forward pass has not been run".to_string(),
                ));
            }
        };

        let input_shape = input.shape();
        let batch_size = input_shape[0];
        let channels = input_shape[1];
        let height = input_shape[2];
        let width = input_shape[3];

        // The gradient is indexed with the output coordinates recorded during the forward pass,
        // so it must have exactly the shape that pass produced.
        let expected_shape = self.calculate_output_shape(input_shape);
        if grad_output.shape() != expected_shape.as_slice() {
            return Err(ModelError::InputValidationError(format!(
                "gradient shape {:?} does not match forward output shape {:?}",
                grad_output.shape(),
                expected_shape
            )));
        }

        let mut input_gradients = ArrayD::zeros(input_shape.to_vec());

        // Group the recorded winners by (batch, channel) so each plane can be processed
        // independently of the others.
        let mut positions_by_bc: std::collections::HashMap<
            (usize, usize),
            Vec<(usize, usize, usize, usize)>,
        > = std::collections::HashMap::new();
        for &(b, c, out_i, out_j, in_i, in_j) in max_positions.iter() {
            positions_by_bc
                .entry((b, c))
                .or_default()
                .push((out_i, out_j, in_i, in_j));
        }

        // Builds the flattened (row-major, `height * width`) gradient of one plane.
        let compute_gradient = |b: usize, c: usize| {
            let mut spatial_grad = vec![0.0; height * width];
            if let Some(positions) = positions_by_bc.get(&(b, c)) {
                for &(out_i, out_j, in_i, in_j) in positions {
                    let flat_idx = in_i * width + in_j;
                    spatial_grad[flat_idx] += grad_output[[b, c, out_i, out_j]];
                }
            }
            ((b, c), spatial_grad)
        };

        let results: Vec<_> = execute_parallel_or_sequential!(
            batch_size,
            channels,
            MAX_POOLING_2D_PARALLEL_THRESHOLD,
            compute_gradient
        );
        merge_gradients_2d!(input_gradients, results, height, width);
        Ok(input_gradients)
    }

    /// Returns the layer's type name.
    fn layer_type(&self) -> &str {
        "MaxPooling2D"
    }

    // NOTE(review): presumably expands to the remaining `Layer` methods shared by the 2D pooling
    // layers — the macro is defined outside this file; confirm there.
    layer_functions_2d_pooling!();
}