rustyml 0.11.0

A high-performance machine learning and deep learning library written in pure Rust, offering ML algorithms and neural-network support.
Documentation
use crate::error::ModelError;
use crate::neural_network::Tensor;
use crate::neural_network::layer::TrainingParameters;
use crate::neural_network::layer::helper_function::calculate_output_shape_3d_pooling;
use crate::neural_network::layer::layer_weight::LayerWeight;
use crate::neural_network::layer::pooling_layer::input_validation_function::{
    validate_input_shape_dims, validate_pool_size_3d, validate_strides_3d,
};
use crate::neural_network::neural_network_trait::Layer;
use ndarray::{Array, IxDyn};
use rayon::iter::{IntoParallelIterator, ParallelIterator};

/// Workload size, measured as `batch_size * channels`, at which pooling switches from a
/// sequential loop to parallel execution.
///
/// Pairs counts at or above this value are processed in parallel; smaller workloads stay
/// sequential so they do not pay the overhead of spawning parallel work.
const AVERAGE_POOLING_3D_PARALLEL_THRESHOLD: usize = 32;

/// Three-dimensional average pooling layer.
///
/// Slides a `(depth, height, width)` window over the spatial axes of every
/// `(batch, channel)` slice. Each window is replaced by the arithmetic mean of the values it
/// covers. Windows are clipped to the input bounds, and the mean is taken over the elements
/// actually covered rather than the nominal window volume.
///
/// Expected input layout: `[batch_size, channels, depth, height, width]`.
/// Resulting layout: `[batch_size, channels, pooled_depth, pooled_height, pooled_width]`, where
///
/// - `pooled_depth = (depth - pool_size_d) / stride_d + 1`
/// - `pooled_height = (height - pool_size_h) / stride_h + 1`
/// - `pooled_width = (width - pool_size_w) / stride_w + 1`
///
/// # Fields
///
/// - `pool_size` - Window extent as `(depth, height, width)`
/// - `strides` - Window step as `(depth_stride, height_stride, width_stride)`
/// - `input_shape` - Input shape supplied when the layer was constructed
/// - `input_cache` - Copy of the most recent forward-pass input, read by `backward`
///
/// # Examples
///
/// ```rust
/// use rustyml::neural_network::sequential::Sequential;
/// use rustyml::neural_network::layer::*;
/// use rustyml::neural_network::optimizer::*;
/// use rustyml::neural_network::loss_function::*;
/// use ndarray::Array5;
///
/// let mut model = Sequential::new();
///
/// // 2×2×2 windows moved by 2 along every axis halve each spatial dimension.
/// model.add(AveragePooling3D::new(
///     (2, 2, 2),
///     vec![1, 16, 32, 32, 32], // [batch, channels, depth, height, width]
///     Some((2, 2, 2)),
/// ).unwrap());
///
/// model.compile(
///     RMSprop::new(0.001, 0.9, 1e-8).unwrap(),
///     MeanSquaredError::new(),
/// );
///
/// // One 16-channel 32×32×32 volume (e.g. a medical scan) with a simple spatial gradient.
/// let input_data = Array5::from_shape_fn((1, 16, 32, 32, 32), |(_b, c, d, h, w)| {
///     ((d + h + w) as f32 * 0.1) + (c as f32 * 0.01)
/// }).into_dyn();
///
/// // Targets must have the pooled shape [1, 16, 16, 16, 16].
/// let target_data = Array5::ones((1, 16, 16, 16, 16)).into_dyn();
///
/// model.summary();
/// model.fit(&input_data, &target_data, 5).unwrap();
///
/// let predictions = model.predict(&input_data).unwrap();
/// println!("Output shape after average pooling: {:?}", predictions.shape());
/// // Expected: [1, 16, 16, 16, 16]
/// ```
///
/// # Performance
///
/// Work is split per `(batch, channel)` pair. It runs in parallel once
/// `batch_size * channels` reaches `AVERAGE_POOLING_3D_PARALLEL_THRESHOLD` (32), and
/// sequentially below that.
pub struct AveragePooling3D {
    pool_size: (usize, usize, usize),
    strides: (usize, usize, usize),
    input_shape: Vec<usize>,
    input_cache: Option<Tensor>,
}

impl AveragePooling3D {
    /// Builds an `AveragePooling3D` layer after validating its configuration.
    ///
    /// When `strides` is `None`, the window advances by its own extent along every axis.
    /// That is, the strides default to `pool_size`, giving non-overlapping windows.
    ///
    /// # Parameters
    ///
    /// - `pool_size` - Window extent as `(depth, height, width)`
    /// - `input_shape` - Expected input shape `[batch_size, channels, depth, height, width]`
    /// - `strides` - Optional step sizes `(depth, height, width)`; defaults to `pool_size`
    ///
    /// # Returns
    ///
    /// - `Result<AveragePooling3D, ModelError>` - The configured layer, with an empty input cache
    ///
    /// # Errors
    ///
    /// - `ModelError::InputValidationError` - If `input_shape` is not 5D, `pool_size` has a zero
    ///   dimension, or any stride is zero
    pub fn new(
        pool_size: (usize, usize, usize),
        input_shape: Vec<usize>,
        strides: Option<(usize, usize, usize)>,
    ) -> Result<Self, ModelError> {
        let effective_strides = match strides {
            Some(explicit) => explicit,
            None => pool_size,
        };

        // Checks run in a fixed order; the first failing check determines the error returned.
        validate_input_shape_dims(&input_shape, 5, "AveragePooling3D")?;
        validate_pool_size_3d(pool_size)?;
        validate_strides_3d(effective_strides)?;

        let layer = Self {
            pool_size,
            strides: effective_strides,
            input_shape,
            input_cache: None,
        };
        Ok(layer)
    }
}

impl Layer for AveragePooling3D {
    fn forward(&mut self, input: &Tensor) -> Result<Tensor, ModelError> {
        // Validate input is 5D
        if input.ndim() != 5 {
            return Err(ModelError::InputValidationError(
                "input tensor is not 5D".to_string(),
            ));
        }

        let input_shape = input.shape();

        // Cache input for backpropagation
        self.input_cache = Some(input.clone());

        let output_shape =
            calculate_output_shape_3d_pooling(input_shape, self.pool_size, self.strides);
        let mut output = Array::zeros(IxDyn(&output_shape));

        let batch_size = input_shape[0];
        let channels = input_shape[1];
        let input_depth = input_shape[2];
        let input_height = input_shape[3];
        let input_width = input_shape[4];

        let output_depth = output_shape[2];
        let output_height = output_shape[3];
        let output_width = output_shape[4];

        // Copy needed values to avoid capturing self in closure
        let pool_size = self.pool_size;
        let strides = self.strides;

        // Helper closure to compute pooling for a single (batch, channel) pair
        let compute_pooling = |b: usize, c: usize| {
            let mut local_results = Vec::new();

            for od in 0..output_depth {
                for oh in 0..output_height {
                    for ow in 0..output_width {
                        let start_d = od * strides.0;
                        let start_h = oh * strides.1;
                        let start_w = ow * strides.2;

                        let end_d = (start_d + pool_size.0).min(input_depth);
                        let end_h = (start_h + pool_size.1).min(input_height);
                        let end_w = (start_w + pool_size.2).min(input_width);

                        // Calculate average value within the pooling window
                        let mut sum = 0.0;
                        let mut count = 0;
                        for d in start_d..end_d {
                            for h in start_h..end_h {
                                for w in start_w..end_w {
                                    sum += input[[b, c, d, h, w]];
                                    count += 1;
                                }
                            }
                        }

                        // Divide by actual number of elements, not theoretical pool size
                        let pooled_value = if count > 0 { sum / count as f32 } else { 0.0 };
                        local_results.push(((od, oh, ow), pooled_value));
                    }
                }
            }

            ((b, c), local_results)
        };

        // Choose parallel or sequential execution based on workload size
        let results: Vec<_> = execute_parallel_or_sequential!(
            batch_size,
            channels,
            AVERAGE_POOLING_3D_PARALLEL_THRESHOLD,
            compute_pooling
        );

        // Merge results into output tensor
        for ((b, c), local_results) in results {
            for ((od, oh, ow), value) in local_results {
                output[[b, c, od, oh, ow]] = value;
            }
        }

        Ok(output)
    }

    fn backward(&mut self, grad_output: &Tensor) -> Result<Tensor, ModelError> {
        let input = self.input_cache.as_ref().ok_or_else(|| {
            ModelError::ProcessingError("Forward pass has not been run yet".to_string())
        })?;

        let input_shape = input.shape();
        let mut grad_input = Array::zeros(IxDyn(input_shape));

        let batch_size = input_shape[0];
        let channels = input_shape[1];
        let input_depth = input_shape[2];
        let input_height = input_shape[3];
        let input_width = input_shape[4];

        let output_depth = grad_output.shape()[2];
        let output_height = grad_output.shape()[3];
        let output_width = grad_output.shape()[4];

        // Copy needed values to avoid capturing self in closure
        let pool_size = self.pool_size;
        let strides = self.strides;

        // Helper closure to compute gradient for a single (batch, channel) pair
        let compute_gradient = |b: usize, c: usize| {
            // Allocate only the spatial volume for this channel
            let spatial_volume = input_depth * input_height * input_width;
            let mut spatial_grad = vec![0.0f32; spatial_volume];

            for od in 0..output_depth {
                for oh in 0..output_height {
                    for ow in 0..output_width {
                        let start_d = od * strides.0;
                        let start_h = oh * strides.1;
                        let start_w = ow * strides.2;

                        let end_d = (start_d + pool_size.0).min(input_depth);
                        let end_h = (start_h + pool_size.1).min(input_height);
                        let end_w = (start_w + pool_size.2).min(input_width);

                        // Calculate actual number of elements in this pooling window
                        let actual_count =
                            ((end_d - start_d) * (end_h - start_h) * (end_w - start_w)) as f32;

                        // Distribute gradient evenly to all elements in the pooling window
                        let grad_value = if actual_count > 0.0 {
                            grad_output[[b, c, od, oh, ow]] / actual_count
                        } else {
                            0.0
                        };

                        for d in start_d..end_d {
                            for h in start_h..end_h {
                                for w in start_w..end_w {
                                    let idx = d * input_height * input_width + h * input_width + w;
                                    spatial_grad[idx] += grad_value;
                                }
                            }
                        }
                    }
                }
            }

            ((b, c), spatial_grad)
        };

        // Choose parallel or sequential execution based on workload size
        let results: Vec<_> = execute_parallel_or_sequential!(
            batch_size,
            channels,
            AVERAGE_POOLING_3D_PARALLEL_THRESHOLD,
            compute_gradient
        );

        // Merge parallel computation results into output gradient tensor
        merge_gradients_3d!(grad_input, results, input_depth, input_height, input_width);

        Ok(grad_input)
    }

    fn layer_type(&self) -> &str {
        "AveragePooling3D"
    }

    layer_functions_3d_pooling!();
}