use crate::array::Array;
use crate::error::{NumRs2Error, Result};
use num_traits::{Float, NumCast};
/// Computes the mode (most frequent value) of `array` together with its frequency.
///
/// Returns a pair of one-element arrays `(mode, count)`, where `count` is the
/// number of occurrences converted to `T`. When several values share the highest
/// frequency, the smallest of them is returned. Values are compared with `==`, so
/// e.g. `-0.0` and `0.0` are counted as the same value.
///
/// # Arguments
///
/// * `array` - Input array; its elements are read via `to_vec()`.
/// * `axis` - Optional axis. It is validated against `array.shape()`, but the
///   mode is currently computed over the flattened array regardless of its value.
/// * `nan_policy` - How NaN values are handled (defaults to `"propagate"`):
///   * `"propagate"` - if any NaN is present, returns `(NaN, 0)`.
///   * `"omit"` - NaN values are ignored.
///   * `"raise"` - any NaN produces an error.
///
/// # Errors
///
/// * [`NumRs2Error::DimensionMismatch`] if `axis` is `Some` and out of bounds.
/// * [`NumRs2Error::InvalidOperation`] if the array is empty, if `nan_policy` is
///   not one of the accepted strings, if NaN is found under `"raise"`, if no
///   non-NaN value remains under `"omit"`, or if the count cannot be converted to `T`.
pub fn mode<T>(
    array: &Array<T>,
    axis: Option<usize>,
    nan_policy: Option<&str>,
) -> Result<(Array<T>, Array<T>)>
where
    T: Float + Clone + PartialOrd + std::fmt::Display + NumCast,
{
    use std::cmp::Ordering;

    /// Returns the smallest most-frequent value in `values` and its frequency.
    ///
    /// `values` must be non-empty and NaN-free; it is sorted in place so that
    /// equal values form contiguous runs that can be counted exactly.
    fn most_frequent<T: Float>(values: &mut [T]) -> (T, usize) {
        // NaNs were filtered out by the caller, so `partial_cmp` always succeeds.
        values.sort_by(|a, b| a.partial_cmp(b).unwrap_or(Ordering::Equal));
        let mut best = (values[0], 0);
        let mut start = 0;
        while start < values.len() {
            let current = values[start];
            let run = values[start..]
                .iter()
                .take_while(|&&v| v == current)
                .count();
            // Strict `>` keeps the first (i.e. smallest) value among ties.
            if run > best.1 {
                best = (current, run);
            }
            start += run;
        }
        best
    }

    // Validate the axis first so an out-of-bounds axis is reported before any
    // data-dependent error. NOTE(review): the reduction below still runs over
    // the flattened array; per-axis modes are not implemented here.
    if let Some(axis_val) = axis {
        let ndim = array.shape().len();
        if axis_val >= ndim {
            return Err(NumRs2Error::DimensionMismatch(format!(
                "Axis {} out of bounds for array of dimension {}",
                axis_val, ndim
            )));
        }
    }

    let data = array.to_vec();
    if data.is_empty() {
        return Err(NumRs2Error::InvalidOperation(
            "Cannot compute mode of empty array".to_string(),
        ));
    }

    let mut values: Vec<T> = match nan_policy.unwrap_or("propagate") {
        "propagate" => {
            if data.iter().any(|x| x.is_nan()) {
                return Ok((
                    Array::from_vec(vec![T::nan()]),
                    Array::from_vec(vec![T::zero()]),
                ));
            }
            data
        }
        "omit" => data.into_iter().filter(|x| !x.is_nan()).collect(),
        "raise" => {
            if data.iter().any(|x| x.is_nan()) {
                return Err(NumRs2Error::InvalidOperation(
                    "NaN values found in array with nan_policy='raise'".to_string(),
                ));
            }
            data
        }
        other => {
            return Err(NumRs2Error::InvalidOperation(format!(
                "Invalid nan_policy '{}'. Use 'propagate', 'omit', or 'raise'",
                other
            )));
        }
    };

    if values.is_empty() {
        return Err(NumRs2Error::InvalidOperation(
            "No valid (non-NaN) values found".to_string(),
        ));
    }

    // Exact grouping by value equality (sort + run-length count) instead of a
    // lossy fixed-precision string key, which merged distinct tiny values.
    let (mode_value, max_count) = most_frequent(&mut values);
    let mode_count = <T as NumCast>::from(max_count).ok_or_else(|| {
        NumRs2Error::InvalidOperation(
            "Mode count is not representable in the element type".to_string(),
        )
    })?;

    Ok((
        Array::from_vec(vec![mode_value]),
        Array::from_vec(vec![mode_count]),
    ))
}