numrs2 0.3.3

A Rust implementation inspired by NumPy for numerical computing (NumRS2)
Documentation
//! Mode function
//!
//! This module provides the mode (most frequent value) calculation.

use crate::array::Array;
use crate::error::{NumRs2Error, Result};
use num_traits::{Float, NumCast};

/// Compute the mode (most frequent value) of an array.
///
/// The mode is the value that appears most often in a dataset. This implementation
/// returns the most frequent value along with its count. For arrays with multiple
/// modes (equally frequent values), the smallest value is returned.
///
/// Values are grouped by numeric equality (`==`) after sorting. Two values are
/// therefore counted together only when they compare equal; in particular `-0.0`
/// and `0.0` count as the same value. Distinct values are never merged, however
/// close they are.
///
/// # Arguments
///
/// * `array` - Input array
/// * `axis` - Axis along which to compute the mode (`None` for the flattened array).
///   The axis is checked against the array's dimensionality, but axis-wise reduction
///   is not implemented yet: a valid axis currently produces the same result as
///   `None` (the mode of the flattened array).
/// * `nan_policy` - How to handle NaN values (defaults to `"propagate"`):
///   - `"propagate"`: if any NaN is present, return a mode of NaN with a count of 0
///   - `"omit"`: ignore NaN values in the computation
///   - `"raise"`: return an error if any NaN is present
///
/// # Returns
///
/// A tuple of `(mode, count)`, each a one-element array:
/// - `mode`: the most frequent value
/// - `count`: the number of occurrences of that value, converted to `T`
///
/// # Errors
///
/// * [`NumRs2Error::DimensionMismatch`] if `axis` is out of bounds.
/// * [`NumRs2Error::InvalidOperation`] if the array is empty, if `nan_policy` is not
///   one of the accepted strings, if a NaN is found with `"raise"`, or if every value
///   is NaN with `"omit"`.
///
/// # Examples
///
/// ```
/// use numrs2::prelude::*;
/// use numrs2::stats::mode;
///
/// let a = Array::from_vec(vec![1.0, 2.0, 3.0, 2.0, 1.0, 1.0]);
/// let (mode_val, count) = mode(&a, None, None).expect("mode should succeed");
/// assert_eq!(mode_val.to_vec()[0], 1.0);  // 1.0 appears 3 times
/// assert_eq!(count.to_vec()[0], 3.0);     // Count is 3
///
/// // Example with multiple values having same frequency
/// let b = Array::from_vec(vec![1.0, 2.0, 3.0, 1.0, 2.0, 3.0]);
/// let (mode_val, count) = mode(&b, None, None).expect("mode should succeed");
/// assert_eq!(mode_val.to_vec()[0], 1.0);  // Smallest value with max frequency
/// assert_eq!(count.to_vec()[0], 2.0);     // Each appears 2 times
///
/// // Distinct values are never merged, even when they agree to many decimal places
/// let c = Array::from_vec(vec![1e-20, 2e-20, 5.0, 5.0]);
/// let (mode_val, count) = mode(&c, None, None).expect("mode should succeed");
/// assert_eq!(mode_val.to_vec()[0], 5.0);
/// assert_eq!(count.to_vec()[0], 2.0);
/// ```
pub fn mode<T>(
    array: &Array<T>,
    axis: Option<usize>,
    nan_policy: Option<&str>,
) -> Result<(Array<T>, Array<T>)>
where
    T: Float + Clone + PartialOrd + std::fmt::Display + NumCast,
{
    if let Some(axis_val) = axis {
        let ndim = array.shape().len();
        if axis_val >= ndim {
            return Err(NumRs2Error::DimensionMismatch(format!(
                "Axis {} out of bounds for array of dimension {}",
                axis_val, ndim
            )));
        }
        // Axis-wise reduction is not implemented yet: a valid axis falls through to
        // the flattened computation below, exactly as if `axis` were `None`.
    }

    let policy = nan_policy.unwrap_or("propagate");
    let data = array.to_vec();

    if data.is_empty() {
        return Err(NumRs2Error::InvalidOperation(
            "Cannot compute mode of empty array".to_string(),
        ));
    }

    // Apply the NaN policy; afterwards `values` is guaranteed NaN-free.
    let mut values: Vec<T> = match policy {
        "propagate" => {
            if data.iter().any(|x| x.is_nan()) {
                // NaN is contagious: report NaN as the mode with a zero count.
                return Ok((
                    Array::from_vec(vec![T::nan()]),
                    Array::from_vec(vec![T::zero()]),
                ));
            }
            data
        }
        "omit" => data.into_iter().filter(|x| !x.is_nan()).collect(),
        "raise" => {
            if data.iter().any(|x| x.is_nan()) {
                return Err(NumRs2Error::InvalidOperation(
                    "NaN values found in array with nan_policy='raise'".to_string(),
                ));
            }
            data
        }
        _ => {
            return Err(NumRs2Error::InvalidOperation(format!(
                "Invalid nan_policy '{}'. Use 'propagate', 'omit', or 'raise'",
                policy
            )));
        }
    };

    if values.is_empty() {
        return Err(NumRs2Error::InvalidOperation(
            "No valid (non-NaN) values found".to_string(),
        ));
    }

    // NaNs were rejected or filtered out above, so `partial_cmp` never returns `None`
    // here; the fallback only keeps the comparator total. A stable sort keeps the
    // first-seen representative among equal values (e.g. `-0.0` vs `0.0`).
    values.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));

    let (mode_value, max_count) = longest_equal_run(&values);
    let mode_count = T::from(max_count).ok_or_else(|| {
        NumRs2Error::InvalidOperation(format!(
            "Mode count {} is not representable in the element type",
            max_count
        ))
    })?;

    Ok((
        Array::from_vec(vec![mode_value]),
        Array::from_vec(vec![mode_count]),
    ))
}

/// Return the value heading the longest run of equal elements in `sorted`, together
/// with the length of that run.
///
/// `sorted` must be non-empty, sorted ascending and NaN-free. Runs are visited in
/// ascending order and only a strictly longer run replaces the current best, so
/// ties resolve to the smallest value.
fn longest_equal_run<T: Float>(sorted: &[T]) -> (T, usize) {
    let mut best = (sorted[0], 0usize);
    let mut start = 0;
    while start < sorted.len() {
        let value = sorted[start];
        // `.max(1)` guarantees progress even if a NaN slipped through (NaN != NaN).
        let run = sorted[start..]
            .iter()
            .take_while(|&&x| x == value)
            .count()
            .max(1);
        if run > best.1 {
            best = (value, run);
        }
        start += run;
    }
    best
}