use crate::error::ModelError;
use crate::neural_network::Tensor;
use crate::neural_network::layer::TrainingParameters;
use crate::neural_network::layer::helper_function::calculate_output_shape_1d_pooling;
use crate::neural_network::layer::layer_weight::LayerWeight;
use crate::neural_network::layer::pooling_layer::input_validation_function::{
validate_input_shape_dims, validate_pool_size_1d, validate_stride_1d,
};
use crate::neural_network::neural_network_trait::Layer;
use ndarray::Array3;
use rayon::iter::{IntoParallelIterator, ParallelIterator};
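/// Workload threshold for the per-(batch, channel) loops; the actual
/// parallel-vs-sequential decision lives in `execute_parallel_or_sequential!`.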
const MAX_POOLING_1D_PARALLEL_THRESHOLD: usize = 32;
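/// 1D max pooling layer.
///
/// Slides a window of `pool_size` elements along the last axis of a
/// `[batch_size, channels, length]` tensor with step `stride`, keeping the
/// maximum of each window. The index of every maximum is cached during the
/// forward pass so the backward pass can route gradients to the winning
/// positions only.
///
/// A minimal usage sketch (assuming `Tensor` is `ndarray::ArrayD<f32>`, as
/// the `into_dyn()` calls below suggest):
///
/// ```ignore
/// use ndarray::Array3;
///
/// // One sample, one channel, length 4; windows of 2 with stride 2.
/// let mut pool = MaxPooling1D::new(2, 2, vec![1, 1, 4]).unwrap();
/// let input = Array3::from_shape_vec((1, 1, 4), vec![1.0, 3.0, 2.0, 4.0])
///     .unwrap()
///     .into_dyn();
/// let output = pool.forward(&input).unwrap();
/// // Windows [1.0, 3.0] and [2.0, 4.0] reduce to [3.0, 4.0].
/// ```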
pub struct MaxPooling1D {
    /// Length of each pooling window.
    pool_size: usize,
    /// Step between the start positions of consecutive windows.
    stride: usize,
    /// Expected input shape as `[batch_size, channels, length]`.
    input_shape: Vec<usize>,
    /// Input cached by `forward` for use in `backward`.
    input_cache: Option<Tensor>,
    /// Index (along the length axis) of the maximum in each pooling window,
    /// recorded by `forward` so `backward` can route gradients.
    max_positions: Option<Array3<usize>>,
}
impl MaxPooling1D {
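    /// Creates a new `MaxPooling1D` layer.
    ///
    /// `input_shape` must be `[batch_size, channels, length]`; validation of
    /// the shape, pool size, and stride is delegated to the `validate_*`
    /// helpers, which report failures as `ModelError`.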
pub fn new(
pool_size: usize,
stride: usize,
input_shape: Vec<usize>,
) -> Result<Self, ModelError> {
validate_input_shape_dims(&input_shape, 3, "MaxPooling1D")?;
validate_pool_size_1d(pool_size, input_shape[2])?;
validate_stride_1d(stride)?;
Ok(MaxPooling1D {
pool_size,
stride,
input_shape,
input_cache: None,
max_positions: None,
})
}
}
impl Layer for MaxPooling1D {
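    /// Forward pass: max-pools `input` along its last axis.
    ///
    /// Expects a 3D tensor of shape `[batch_size, channels, length]` and
    /// returns one of shape
    /// `[batch_size, channels, (length - pool_size) / stride + 1]`.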
fn forward(&mut self, input: &Tensor) -> Result<Tensor, ModelError> {
        if input.ndim() != 3 {
            return Err(ModelError::InputValidationError(format!(
                "MaxPooling1D expects a 3D input of shape [batch_size, channels, length], got {}D",
                input.ndim()
            )));
        }
        let batch_size = input.shape()[0];
        let channels = input.shape()[1];
        let length = input.shape()[2];
        // Guard against usize underflow below: `new` validated the shape it
        // was given, not the tensor actually passed to `forward`.
        if length < self.pool_size {
            return Err(ModelError::InputValidationError(format!(
                "input length {} is smaller than pool_size {}",
                length, self.pool_size
            )));
        }
        // Cache the input for the backward pass.
        self.input_cache = Some(input.clone());
        let output_length = (length - self.pool_size) / self.stride + 1;
let mut output = Array3::<f32>::zeros((batch_size, channels, output_length));
let mut max_positions = Array3::<usize>::zeros((batch_size, channels, output_length));
let pool_size = self.pool_size;
let stride = self.stride;
        // Pools one (batch, channel) pair: for every output position, scan the
        // window and record both the maximum value and its index.
        let compute_pooling = |b: usize, c: usize| {
            let mut channel_output = Vec::with_capacity(output_length);
            let mut channel_max_pos = Vec::with_capacity(output_length);
for i in 0..output_length {
let start_idx = i * stride;
let end_idx = start_idx + pool_size;
let mut max_val = input[[b, c, start_idx]];
let mut max_idx = start_idx;
for j in (start_idx + 1)..end_idx {
if input[[b, c, j]] > max_val {
max_val = input[[b, c, j]];
max_idx = j;
}
}
channel_output.push((i, max_val));
channel_max_pos.push((i, max_idx));
}
((b, c), (channel_output, channel_max_pos))
};
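        // Pool every (batch, channel) pair, with the parallel/sequential
        // choice against MAX_POOLING_1D_PARALLEL_THRESHOLD made by
        // `execute_parallel_or_sequential!`.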
let results: Vec<_> = execute_parallel_or_sequential!(
batch_size,
channels,
MAX_POOLING_1D_PARALLEL_THRESHOLD,
compute_pooling
);
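        // Scatter the per-channel results back into the dense output and
        // argmax-position buffers.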
for ((b, c), (channel_output, channel_max_pos)) in results {
for (i, val) in channel_output {
output[[b, c, i]] = val;
}
for (i, pos) in channel_max_pos {
max_positions[[b, c, i]] = pos;
}
}
self.max_positions = Some(max_positions);
Ok(output.into_dyn())
}
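    /// Backward pass: routes each upstream gradient entry to the input
    /// position that produced the corresponding maximum; every other position
    /// receives zero gradient.
    ///
    /// Continuing the sketch from the type-level docs (input
    /// `[1.0, 3.0, 2.0, 4.0]`, maxima at indices 1 and 3):
    ///
    /// ```ignore
    /// // An upstream gradient of [[[0.5, 0.25]]] is scattered back as
    /// // [[[0.0, 0.5, 0.0, 0.25]]].
    /// let grad_input = pool.backward(&grad_output).unwrap();
    /// ```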
fn backward(&mut self, grad_output: &Tensor) -> Result<Tensor, ModelError> {
        let input = self.input_cache.as_ref().ok_or_else(|| {
            ModelError::ProcessingError(
                "No cached input for MaxPooling1D; forward() must be called before backward()"
                    .to_string(),
            )
        })?;
        let max_positions = self.max_positions.as_ref().ok_or_else(|| {
            ModelError::ProcessingError(
                "No cached max positions for MaxPooling1D; forward() must be called before backward()"
                    .to_string(),
            )
        })?;
let batch_size = input.shape()[0];
let channels = input.shape()[1];
let length = input.shape()[2];
let output_length = grad_output.shape()[2];
let mut grad_input = Array3::<f32>::zeros((batch_size, channels, length));
        // Accumulates the gradient for one (batch, channel) pair. `+=` (rather
        // than `=`) is required because overlapping windows (stride < pool_size)
        // can select the same input position more than once.
        let compute_gradient = |b: usize, c: usize| {
            let mut channel_grad = vec![0.0; length];
for i in 0..output_length {
let max_idx = max_positions[[b, c, i]];
channel_grad[max_idx] += grad_output[[b, c, i]];
}
((b, c), channel_grad)
};
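        // Same parallel/sequential dispatch as in forward().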
let results: Vec<_> = execute_parallel_or_sequential!(
batch_size,
channels,
MAX_POOLING_1D_PARALLEL_THRESHOLD,
compute_gradient
);
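        // Fold the per-channel gradient vectors into the dense gradient tensor.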
merge_gradients_1d!(grad_input, results, length);
Ok(grad_input.into_dyn())
}
fn layer_type(&self) -> &str {
"MaxPooling1D"
}
layer_functions_1d_pooling!();
}