rustyml 0.11.0

A high-performance machine learning & deep learning library in pure Rust, offering ML algorithms and neural network support
Documentation
use crate::error::ModelError;
use ndarray::{Array, ArrayBase, ArrayViewMut1, Axis, Data, Dimension};
use rayon::iter::{IntoParallelRefIterator, ParallelIterator};
use rayon::prelude::IntoParallelRefMutIterator;

/// Minimum workload at which standardization switches to rayon-based parallel execution.
///
/// Element counts below this value (for the global path and for per-lane statistics) are
/// processed sequentially, since for small inputs the cost of spawning and synchronizing
/// threads outweighs the gain. Lane-wise standardization also divides this value by 100 to
/// decide how many lanes are needed before lanes are processed in parallel.
const STANDARDIZE_PARALLEL_THRESHOLD: usize = 10_000;

/// Selects which groups of values `standardize` computes separate statistics for.
///
/// Each group gets its own mean and standard deviation; values are only ever compared
/// with other values in the same group.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum StandardizationAxis {
    /// Every row (each 1-D lane along the last axis) is standardized on its own.
    Row,
    /// Every column (each 1-D lane along the second-to-last axis) is standardized on its own.
    Column,
    /// All elements of the array are standardized together as one population.
    Global,
}

/// Standardizes data to zero mean and (approximately) unit variance
///
/// Returns a copy of `data` in which every group of values selected by `axis` has had
/// its own mean subtracted and has been divided by its own standard deviation. After the
/// transform each group has a mean of (approximately) 0 and a standard deviation slightly
/// below 1, because `epsilon` enlarges the divisor. A group whose values are all equal maps
/// to values at or very near 0 instead of producing a division by zero.
///
/// # Parameters
///
/// - `data` - Input array data as `ArrayBase` with arbitrary dimensions and f64 elements
/// - `axis` - Which groups to standardize:
///   - `Row`: each 1-D lane along the last axis
///   - `Column`: each 1-D lane along the second-to-last axis
///   - `Global`: the whole array as one group
/// - `epsilon` - Small positive stabilizer. Each group's divisor is computed as
///   `sqrt(variance + epsilon²)`, where the variance is the population variance (sum of
///   squared deviations divided by `n`). The divisor is therefore never smaller than
///   `epsilon`, so division by zero cannot occur.
///
/// # Returns
///
/// - `Result<Array<f64, D>, ModelError>` - Standardized array with same dimensions as input
///
/// # Examples
/// ```rust
/// use ndarray::array;
/// use rustyml::utility::standardize::{standardize, StandardizationAxis};
///
/// let data = array![[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]];
/// let result = standardize(&data, StandardizationAxis::Column, 1e-8).unwrap();
/// assert_eq!(result.shape(), data.shape());
/// // Column 0 is [1, 3, 5] with mean 3, so its standardized values sum to 0
/// assert!(result.column(0).sum().abs() < 1e-12);
/// ```
///
/// # Errors
///
/// - `ModelError::InputValidationError` - Returned when:
///   - the input array is empty;
///   - the input contains NaN or infinite values;
///   - `epsilon` is not positive and finite;
///   - `axis` is `Row` or `Column` and the input is 1-dimensional (those axes need at
///     least 2 dimensions).
/// - `ModelError::ProcessingError` - Only produced by the global path when there are no
///   values to standardize; the empty-input check performed first already rules this out
///
/// # Performance
///
/// - `Global`: statistics and the in-place rescale run in parallel once the element count
///   reaches `STANDARDIZE_PARALLEL_THRESHOLD` (10,000)
/// - `Row` / `Column`: lanes are processed in parallel once there are at least
///   `STANDARDIZE_PARALLEL_THRESHOLD / 100` (100) lanes. Each lane's own mean/variance
///   computation runs in parallel once the lane has at least 10,000 elements.
///
/// # Implementation Details
///
/// - The input is copied once (`to_owned`) and then standardized in place
/// - Mean and variance are computed with a two-pass algorithm (mean first, then the
///   squared deviations from it)
/// - NaN and infinite values in input will result in an error
pub fn standardize<S, D>(
    data: &ArrayBase<S, D>,
    axis: StandardizationAxis,
    epsilon: f64,
) -> Result<Array<f64, D>, ModelError>
where
    S: Data<Elem = f64>,
    D: Dimension,
{
    // Input validation
    if data.is_empty() {
        return Err(ModelError::InputValidationError(
            "Cannot standardize empty array".to_string(),
        ));
    }

    // Non-finite values would poison every mean/variance they take part in
    if data.iter().any(|&x| !x.is_finite()) {
        return Err(ModelError::InputValidationError(
            "Input contains NaN or infinite values".to_string(),
        ));
    }

    // A positive, finite epsilon guarantees a strictly positive divisor
    if epsilon <= 0.0 || !epsilon.is_finite() {
        return Err(ModelError::InputValidationError(
            "Epsilon must be positive and finite".to_string(),
        ));
    }

    let mut result = data.to_owned();

    match axis {
        StandardizationAxis::Global => {
            standardize_global(&mut result, epsilon)?;
        }
        StandardizationAxis::Row => {
            standardize_by_rows(&mut result, epsilon)?;
        }
        StandardizationAxis::Column => {
            standardize_by_columns(&mut result, epsilon)?;
        }
    }

    Ok(result)
}

/// Standardizes the whole array in place as a single population of values.
///
/// One mean and one standard deviation are computed over every element. Each element `x`
/// then becomes `(x - mean) / std_dev`, where `std_dev = sqrt(variance + epsilon²)` and the
/// variance divides by the element count. Work is spread over rayon once the element count
/// reaches `STANDARDIZE_PARALLEL_THRESHOLD`.
///
/// # Errors
///
/// - `ModelError::ProcessingError` - If `data` contains no elements
fn standardize_global<D>(data: &mut Array<f64, D>, epsilon: f64) -> Result<(), ModelError>
where
    D: Dimension,
{
    let element_count = data.len();
    if element_count == 0 {
        return Err(ModelError::ProcessingError(
            "No values to standardize".to_string(),
        ));
    }

    let n = element_count as f64;
    let use_parallel = element_count >= STANDARDIZE_PARALLEL_THRESHOLD;

    // First pass: the mean
    let total: f64 = if use_parallel {
        data.par_iter().sum()
    } else {
        data.iter().sum()
    };
    let mean = total / n;

    // Second pass: population variance around that mean
    let squared_deviation = |value: &f64| (value - mean).powi(2);
    let sum_of_squares: f64 = if use_parallel {
        data.par_iter().map(squared_deviation).sum()
    } else {
        data.iter().map(squared_deviation).sum()
    };

    // Folding epsilon² into the variance keeps the divisor positive for positive epsilon
    let std_dev = (sum_of_squares / n + epsilon * epsilon).sqrt();

    // Final pass: shift and scale every element in place
    let rescale = |value: f64| (value - mean) / std_dev;
    if use_parallel {
        data.par_mapv_inplace(rescale);
    } else {
        data.mapv_inplace(rescale);
    }

    Ok(())
}

/// Returns the `(mean, std_dev)` pair for `values`.
///
/// The standard deviation is the population one (squared deviations divided by `n`) with
/// `epsilon²` folded into the variance before the square root. This keeps it at or above
/// `epsilon`. An empty slice yields `(0.0, epsilon)`. Both passes run on rayon once the
/// slice holds at least `STANDARDIZE_PARALLEL_THRESHOLD` values.
fn compute_mean_and_std(values: &[f64], epsilon: f64) -> (f64, f64) {
    if values.is_empty() {
        return (0.0, epsilon);
    }

    let n = values.len() as f64;
    let parallel = values.len() >= STANDARDIZE_PARALLEL_THRESHOLD;

    // First pass: the mean
    let total: f64 = if parallel {
        values.par_iter().sum()
    } else {
        values.iter().sum()
    };
    let mean = total / n;

    // Second pass: sum of squared deviations from the mean
    let squared_deviation = |v: &f64| (v - mean).powi(2);
    let sum_of_squares: f64 = if parallel {
        values.par_iter().map(squared_deviation).sum()
    } else {
        values.iter().map(squared_deviation).sum()
    };

    let std_dev = (sum_of_squares / n + epsilon * epsilon).sqrt();
    (mean, std_dev)
}

/// Generic helper function to standardize lanes along a specified axis
///
/// Every 1-D lane of `data` that runs along `axis` is standardized on its own. Its mean and
/// standard deviation come from `compute_mean_and_std`, and each element `x` of the lane is
/// replaced by `(x - mean) / std_dev`.
///
/// # Parameters
///
/// - `data` - Array to standardize in place
/// - `axis` - Axis the lanes run along (last axis for rows, second-to-last for columns)
/// - `epsilon` - Stabilizer forwarded to `compute_mean_and_std`
/// - `operation_name` - Human-readable operation name used in the error message
///
/// # Errors
///
/// - `ModelError::InputValidationError` - If `data` has fewer than 2 dimensions
fn standardize_lanes<D>(
    data: &mut Array<f64, D>,
    axis: Axis,
    epsilon: f64,
    operation_name: &str,
) -> Result<(), ModelError>
where
    D: Dimension,
{
    let ndim = data.ndim();
    if ndim < 2 {
        return Err(ModelError::InputValidationError(format!(
            "{} requires at least 2 dimensions",
            operation_name
        )));
    }

    // Materialize the lane views so they can be handed out to worker threads
    let mut lanes: Vec<_> = data.lanes_mut(axis).into_iter().collect();

    // Standardizes one lane in place. Contiguous lanes (e.g. rows of a row-major array) are
    // read directly through `as_slice`, avoiding a per-lane allocation. Strided lanes are
    // copied into a temporary buffer first. Element order is the same either way, so the
    // computed statistics are identical.
    let process_lane = |lane: &mut ArrayViewMut1<f64>| {
        let (mean, std_dev) = match lane.as_slice() {
            Some(contiguous) => compute_mean_and_std(contiguous, epsilon),
            None => {
                let values: Vec<f64> = lane.iter().copied().collect();
                compute_mean_and_std(&values, epsilon)
            }
        };
        lane.mapv_inplace(|x| (x - mean) / std_dev);
    };

    // Parallelize across lanes only when there are enough of them (>= 100 with the
    // current threshold) to amortize rayon's scheduling overhead
    if lanes.len() >= STANDARDIZE_PARALLEL_THRESHOLD / 100 {
        lanes.par_iter_mut().for_each(process_lane);
    } else {
        lanes.iter_mut().for_each(process_lane);
    }

    Ok(())
}

/// Helper function to standardize each row independently
fn standardize_by_rows<D>(data: &mut Array<f64, D>, epsilon: f64) -> Result<(), ModelError>
where
    D: Dimension,
{
    let ndim = data.ndim();
    // Get the last axis (assuming it represents features/columns)
    let last_axis = Axis(ndim - 1);

    standardize_lanes(data, last_axis, epsilon, "Row standardization")
}

/// Standardizes every column of `data` on its own, in place.
///
/// Columns are the 1-D lanes that run along the second-to-last axis (the sample/row axis
/// of a 2-D array). Each one is shifted and scaled by its own mean and standard deviation.
///
/// # Errors
///
/// - `ModelError::InputValidationError` - If `data` has fewer than 2 dimensions, since
///   there is then no second-to-last axis
fn standardize_by_columns<D>(data: &mut Array<f64, D>, epsilon: f64) -> Result<(), ModelError>
where
    D: Dimension,
{
    match data.ndim().checked_sub(2) {
        Some(sample_axis_index) => standardize_lanes(
            data,
            Axis(sample_axis_index),
            epsilon,
            "Column standardization",
        ),
        None => Err(ModelError::InputValidationError(
            "Column standardization requires at least 2 dimensions".to_string(),
        )),
    }
}