use crate::error::ModelError;
use crate::neural_network::Tensor;
use crate::neural_network::layer::TrainingParameters;
use crate::neural_network::layer::helper_function::calculate_output_shape_3d_pooling;
use crate::neural_network::layer::layer_weight::LayerWeight;
use crate::neural_network::layer::pooling_layer::input_validation_function::{
validate_input_shape_dims, validate_pool_size_3d, validate_strides_3d,
};
use crate::neural_network::neural_network_trait::Layer;
use ndarray::{Array, IxDyn};
use rayon::iter::{IntoParallelIterator, ParallelIterator};
/// Threshold handed to `execute_parallel_or_sequential!` together with the batch size and
/// channel count in both the forward and backward passes.
///
/// NOTE(review): presumably the amount of (batch, channel) work above which the macro switches
/// from sequential to parallel execution — confirm against the macro definition.
const AVERAGE_POOLING_3D_PARALLEL_THRESHOLD: usize = 32;
/// 3D average-pooling layer operating on 5D tensors indexed as
/// `[batch, channels, depth, height, width]`.
///
/// Each output element is the mean of the input values inside a pooling window; windows are
/// placed every `strides` elements along depth, height and width.
pub struct AveragePooling3D {
/// Pooling window extent as `(depth, height, width)`.
pool_size: (usize, usize, usize),
/// Step between consecutive windows as `(depth, height, width)`.
strides: (usize, usize, usize),
/// Input shape supplied at construction time (validated by `new`).
input_shape: Vec<usize>,
/// Copy of the most recent forward input; `None` until `forward` has been called.
/// `backward` reads it to recover the input dimensions.
input_cache: Option<Tensor>,
}
impl AveragePooling3D {
    /// Builds a new 3D average-pooling layer.
    ///
    /// # Parameters
    /// - `pool_size`: pooling window extent as `(depth, height, width)`.
    /// - `input_shape`: shape of the tensors the layer will receive; it is checked with
    ///   `validate_input_shape_dims(.., 5, ..)` (presumably requiring exactly 5 dimensions).
    /// - `strides`: step between windows; when `None`, the pooling window size is used.
    ///
    /// # Errors
    /// Propagates the `ModelError` produced by the first failing validator, checked in this
    /// order: input shape, pool size, strides.
    pub fn new(
        pool_size: (usize, usize, usize),
        input_shape: Vec<usize>,
        strides: Option<(usize, usize, usize)>,
    ) -> Result<Self, ModelError> {
        // Fall back to the window size when the caller does not specify strides.
        let effective_strides = match strides {
            Some(explicit) => explicit,
            None => pool_size,
        };

        validate_input_shape_dims(&input_shape, 5, "AveragePooling3D")?;
        validate_pool_size_3d(pool_size)?;
        validate_strides_3d(effective_strides)?;

        let layer = Self {
            pool_size,
            strides: effective_strides,
            input_shape,
            input_cache: None,
        };
        Ok(layer)
    }
}
impl Layer for AveragePooling3D {
    /// Applies 3D average pooling to `input`.
    ///
    /// `input` is indexed as `[batch, channels, depth, height, width]`. The output shape comes
    /// from `calculate_output_shape_3d_pooling`. Every output element is the mean of the input
    /// values inside its window; a window that would run past the input is clipped, and the
    /// mean is taken over the elements actually covered.
    ///
    /// A copy of `input` is cached so that `backward` can recover the input dimensions.
    ///
    /// # Errors
    /// Returns `ModelError::InputValidationError` when `input` is not 5-dimensional.
    fn forward(&mut self, input: &Tensor) -> Result<Tensor, ModelError> {
        if input.ndim() != 5 {
            return Err(ModelError::InputValidationError(
                "input tensor is not 5D".to_string(),
            ));
        }

        let input_shape = input.shape();
        self.input_cache = Some(input.clone());

        let output_shape =
            calculate_output_shape_3d_pooling(input_shape, self.pool_size, self.strides);
        let mut output = Array::zeros(IxDyn(&output_shape));

        let batch_size = input_shape[0];
        let channels = input_shape[1];
        let input_depth = input_shape[2];
        let input_height = input_shape[3];
        let input_width = input_shape[4];
        let output_depth = output_shape[2];
        let output_height = output_shape[3];
        let output_width = output_shape[4];
        let pool_size = self.pool_size;
        let strides = self.strides;

        // Pools one (batch, channel) volume and returns its values tagged with their output
        // coordinates, so the results can be written into `output` afterwards.
        let compute_pooling = |b: usize, c: usize| {
            // One entry per output cell; the output array of this size was already allocated.
            let mut local_results =
                Vec::with_capacity(output_depth * output_height * output_width);
            for od in 0..output_depth {
                for oh in 0..output_height {
                    for ow in 0..output_width {
                        let start_d = od * strides.0;
                        let start_h = oh * strides.1;
                        let start_w = ow * strides.2;
                        // Clip the window to the input bounds.
                        let end_d = (start_d + pool_size.0).min(input_depth);
                        let end_h = (start_h + pool_size.1).min(input_height);
                        let end_w = (start_w + pool_size.2).min(input_width);

                        let mut sum = 0.0;
                        let mut count = 0;
                        for d in start_d..end_d {
                            for h in start_h..end_h {
                                for w in start_w..end_w {
                                    sum += input[[b, c, d, h, w]];
                                    count += 1;
                                }
                            }
                        }

                        // An empty window contributes 0 instead of dividing by zero.
                        let pooled_value = if count > 0 { sum / count as f32 } else { 0.0 };
                        local_results.push(((od, oh, ow), pooled_value));
                    }
                }
            }
            ((b, c), local_results)
        };

        // NOTE(review): the macro presumably invokes the closure for every (batch, channel)
        // pair, in parallel or sequentially depending on the threshold — confirm against its
        // definition.
        let results: Vec<_> = execute_parallel_or_sequential!(
            batch_size,
            channels,
            AVERAGE_POOLING_3D_PARALLEL_THRESHOLD,
            compute_pooling
        );

        for ((b, c), local_results) in results {
            for ((od, oh, ow), value) in local_results {
                output[[b, c, od, oh, ow]] = value;
            }
        }

        Ok(output)
    }

    /// Propagates `grad_output` back through the pooling operation.
    ///
    /// The gradient of each output element is split equally among the input elements of its
    /// (clipped) window; contributions from overlapping windows are accumulated. The returned
    /// tensor has the shape of the input cached by the last `forward` call.
    ///
    /// # Errors
    /// - `ModelError::ProcessingError` when `forward` has not been run yet.
    /// - `ModelError::InputValidationError` when the shape of `grad_output` differs from the
    ///   shape `forward` produces for the cached input.
    fn backward(&mut self, grad_output: &Tensor) -> Result<Tensor, ModelError> {
        let input = self.input_cache.as_ref().ok_or_else(|| {
            ModelError::ProcessingError("Forward pass has not been run yet".to_string())
        })?;
        let input_shape = input.shape();

        // Reject a gradient whose shape does not match the forward output. Without this check
        // a tensor of the wrong rank or extent would panic on indexing (or underflow in the
        // window-size computation) or silently yield a wrong gradient.
        let expected_shape =
            calculate_output_shape_3d_pooling(input_shape, self.pool_size, self.strides);
        if grad_output.shape() != &expected_shape[..] {
            return Err(ModelError::InputValidationError(format!(
                "grad_output shape {:?} does not match expected output shape {:?}",
                grad_output.shape(),
                &expected_shape[..]
            )));
        }

        let mut grad_input = Array::zeros(IxDyn(input_shape));

        let batch_size = input_shape[0];
        let channels = input_shape[1];
        let input_depth = input_shape[2];
        let input_height = input_shape[3];
        let input_width = input_shape[4];
        let output_depth = grad_output.shape()[2];
        let output_height = grad_output.shape()[3];
        let output_width = grad_output.shape()[4];
        let pool_size = self.pool_size;
        let strides = self.strides;

        // Computes the input gradient of one (batch, channel) volume as a flat buffer laid out
        // as depth-major, then height, then width.
        let compute_gradient = |b: usize, c: usize| {
            let spatial_volume = input_depth * input_height * input_width;
            let mut spatial_grad = vec![0.0f32; spatial_volume];
            for od in 0..output_depth {
                for oh in 0..output_height {
                    for ow in 0..output_width {
                        let start_d = od * strides.0;
                        let start_h = oh * strides.1;
                        let start_w = ow * strides.2;
                        // Same clipping as in the forward pass.
                        let end_d = (start_d + pool_size.0).min(input_depth);
                        let end_h = (start_h + pool_size.1).min(input_height);
                        let end_w = (start_w + pool_size.2).min(input_width);

                        // Number of input elements that were averaged for this output cell.
                        let actual_count =
                            ((end_d - start_d) * (end_h - start_h) * (end_w - start_w)) as f32;
                        let grad_value = if actual_count > 0.0 {
                            grad_output[[b, c, od, oh, ow]] / actual_count
                        } else {
                            0.0
                        };

                        for d in start_d..end_d {
                            for h in start_h..end_h {
                                for w in start_w..end_w {
                                    let idx = d * input_height * input_width + h * input_width + w;
                                    // Accumulate: overlapping windows touch the same element.
                                    spatial_grad[idx] += grad_value;
                                }
                            }
                        }
                    }
                }
            }
            ((b, c), spatial_grad)
        };

        let results: Vec<_> = execute_parallel_or_sequential!(
            batch_size,
            channels,
            AVERAGE_POOLING_3D_PARALLEL_THRESHOLD,
            compute_gradient
        );

        // NOTE(review): presumably copies each flat per-(batch, channel) buffer into
        // `grad_input` — confirm against the macro definition.
        merge_gradients_3d!(grad_input, results, input_depth, input_height, input_width);

        Ok(grad_input)
    }

    /// Returns the layer's type name, `"AveragePooling3D"`.
    fn layer_type(&self) -> &str {
        "AveragePooling3D"
    }

    // NOTE(review): presumably expands to the remaining `Layer` trait methods shared by the
    // 3D pooling layers — confirm against the macro definition.
    layer_functions_3d_pooling!();
}