use crate::array::Array;
use crate::error::{NumRs2Error, Result};
use num_traits::{Float, NumCast, Zero};
use scirs2_core::parallel_ops::*;
use super::basic::PARALLEL_THRESHOLD;
/// Maps a multi-dimensional index to a flat offset by taking the dot product
/// of `indices` with `strides`.
///
/// # Panics
///
/// Panics if `strides` has fewer entries than `indices`.
#[inline]
fn indices_to_flat_idx(indices: &[usize], strides: &[usize]) -> usize {
    let mut offset = 0;
    for (dim, &idx) in indices.iter().enumerate() {
        offset += idx * strides[dim];
    }
    offset
}
/// Returns row-major (C-order) strides for `shape`.
///
/// The last dimension has stride 1. Every earlier stride is the product of
/// all extents after it. An empty shape yields an empty stride vector.
fn compute_strides(shape: &[usize]) -> Vec<usize> {
    let mut strides = vec![1usize; shape.len()];
    // Walk outward from the innermost axis. Each stride is built from the
    // already-final stride of the axis just inside it.
    for axis in (1..shape.len()).rev() {
        strides[axis - 1] = strides[axis] * shape[axis];
    }
    strides
}
/// Computes the arithmetic mean of `array`, ignoring NaN values.
///
/// * `axis == None`: averages every non-NaN element and returns a one-element
///   array. `keepdims` is not consulted on this path.
/// * `axis == Some(ax)`: averages along dimension `ax`. The result has the
///   input's shape with `ax` removed (a 1-D input yields shape `[1]`). When
///   `keepdims` is true, it has the input's shape with `ax` set to 1.
///
/// A reduction that sees no non-NaN values (including an empty axis) yields
/// NaN. Work is split across threads once a threshold is reached: the element
/// count for `axis == None`, or the number of output values for
/// `axis == Some(_)`, compared against `PARALLEL_THRESHOLD`.
///
/// # Errors
///
/// Returns `NumRs2Error::InvalidOperation` if `ax >= array.ndim()`.
pub fn nanmean<T: Float + Clone + Zero + NumCast + std::fmt::Display + Send + Sync>(
    array: &Array<T>,
    axis: Option<usize>,
    keepdims: bool,
) -> Result<Array<T>> {
    match axis {
        None => {
            let data = array.to_vec();
            if data.len() >= PARALLEL_THRESHOLD {
                let (sum, count) = data
                    .par_iter()
                    .filter(|x| !x.is_nan())
                    .fold(
                        || (T::zero(), 0usize),
                        |(sum, count), &x| (sum + x, count + 1),
                    )
                    .reduce(
                        || (T::zero(), 0usize),
                        |(sum1, count1), (sum2, count2)| (sum1 + sum2, count1 + count2),
                    );
                if count == 0 {
                    Ok(Array::from_vec(vec![T::nan()]))
                } else {
                    let mean = sum / T::from(count).expect("count should be representable");
                    Ok(Array::from_vec(vec![mean]))
                }
            } else {
                let filtered: Vec<T> = data.into_iter().filter(|x| !x.is_nan()).collect();
                if filtered.is_empty() {
                    Ok(Array::from_vec(vec![T::nan()]))
                } else {
                    let sum = filtered.iter().fold(T::zero(), |acc, &x| acc + x);
                    let mean = sum
                        / T::from(filtered.len()).expect("filtered length should be representable");
                    Ok(Array::from_vec(vec![mean]))
                }
            }
        }
        Some(ax) => {
            let shape = array.shape();
            let ndim = array.ndim();
            if ax >= ndim {
                return Err(NumRs2Error::InvalidOperation(format!(
                    "axis {} is out of bounds for array of dimension {}",
                    ax, ndim
                )));
            }
            let axis_size = shape[ax];
            let data = array.to_vec();
            let strides = compute_strides(&shape);
            let axis_stride = strides[ax];
            let mut out_shape: Vec<usize> = shape
                .iter()
                .enumerate()
                .filter(|(i, _)| *i != ax)
                .map(|(_, &s)| s)
                .collect();
            if out_shape.is_empty() {
                out_shape.push(1);
            }
            let out_size: usize = out_shape.iter().product();

            // Flat offset of the first element in the lane that feeds output
            // position `out_idx`. Output positions are enumerated row-major,
            // the same layout `compute_strides` uses. So the LAST remaining
            // dimension varies fastest and must be peeled off first.
            let lane_start = |out_idx: usize| -> usize {
                let mut indices = vec![0usize; ndim];
                let mut rem = out_idx;
                for i in (0..ndim).rev().filter(|&i| i != ax) {
                    indices[i] = rem % shape[i];
                    rem /= shape[i];
                }
                indices_to_flat_idx(&indices, &strides)
            };
            let reduce_lane = |out_idx: usize| -> T {
                let start = lane_start(out_idx);
                let (sum, count) = (0..axis_size)
                    .map(|j| data[start + j * axis_stride])
                    .filter(|v| !v.is_nan())
                    .fold((T::zero(), 0usize), |(s, c), v| (s + v, c + 1));
                if count == 0 {
                    T::nan()
                } else {
                    sum / T::from(count).expect("count should be representable")
                }
            };
            let result_data: Vec<T> = if out_size >= PARALLEL_THRESHOLD {
                (0..out_size).into_par_iter().map(reduce_lane).collect()
            } else {
                (0..out_size).map(reduce_lane).collect()
            };

            let result = Array::from_vec(result_data);
            if keepdims {
                // Same rank as the input, with the reduced axis collapsed to 1.
                // Deriving it from `shape` avoids the padded `[1]` of 1-D inputs.
                let keep_shape: Vec<usize> = shape
                    .iter()
                    .enumerate()
                    .map(|(i, &s)| if i == ax { 1 } else { s })
                    .collect();
                Ok(result.reshape(&keep_shape))
            } else {
                Ok(result.reshape(&out_shape))
            }
        }
    }
}
/// Computes the standard deviation of `array`, ignoring NaN values.
///
/// This is the element-wise square root of `nanvar` called with the same
/// `axis`, `ddof` and `keepdims`. Any error from `nanvar` (for example an
/// out-of-bounds axis) is propagated unchanged.
pub fn nanstd<T: Float + Clone + Zero + NumCast + std::fmt::Display + Send + Sync>(
    array: &Array<T>,
    axis: Option<usize>,
    ddof: Option<usize>,
    keepdims: bool,
) -> Result<Array<T>> {
    nanvar(array, axis, ddof, keepdims).map(|variance| variance.map(|v| v.sqrt()))
}
/// Computes the variance of `array`, ignoring NaN values.
///
/// Uses two passes: first the mean of the non-NaN values, then the sum of
/// squared deviations from it. The sum is divided by `count - ddof`, where
/// `ddof` defaults to 0. If `count <= ddof`, including when every value is
/// NaN, the result is NaN.
///
/// * `axis == None`: reduces over every element and returns a one-element
///   array. `keepdims` is not consulted on this path.
/// * `axis == Some(ax)`: reduces along dimension `ax`. The result has the
///   input's shape with `ax` removed (a 1-D input yields shape `[1]`). When
///   `keepdims` is true, it has the input's shape with `ax` set to 1.
///
/// # Errors
///
/// Returns `NumRs2Error::InvalidOperation` if `ax >= array.ndim()`.
pub fn nanvar<T: Float + Clone + Zero + NumCast + std::fmt::Display + Send + Sync>(
    array: &Array<T>,
    axis: Option<usize>,
    ddof: Option<usize>,
    keepdims: bool,
) -> Result<Array<T>> {
    let ddof_val = ddof.unwrap_or(0);
    match axis {
        None => {
            let data = array.to_vec();
            if data.len() >= PARALLEL_THRESHOLD {
                let (sum, count) = data
                    .par_iter()
                    .filter(|x| !x.is_nan())
                    .fold(
                        || (T::zero(), 0usize),
                        |(sum, count), &x| (sum + x, count + 1),
                    )
                    .reduce(
                        || (T::zero(), 0usize),
                        |(sum1, count1), (sum2, count2)| (sum1 + sum2, count1 + count2),
                    );
                if count <= ddof_val {
                    Ok(Array::from_vec(vec![T::nan()]))
                } else {
                    let mean = sum / T::from(count).expect("count should be representable");
                    let sum_squared_diff = data
                        .par_iter()
                        .filter(|x| !x.is_nan())
                        .map(|&x| (x - mean) * (x - mean))
                        .reduce(|| T::zero(), |acc, x| acc + x);
                    let variance = sum_squared_diff
                        / T::from(count - ddof_val).expect("count-ddof should be representable");
                    Ok(Array::from_vec(vec![variance]))
                }
            } else {
                let filtered: Vec<T> = data.into_iter().filter(|x| !x.is_nan()).collect();
                if filtered.len() <= ddof_val {
                    Ok(Array::from_vec(vec![T::nan()]))
                } else {
                    let mean = filtered.iter().fold(T::zero(), |acc, &x| acc + x)
                        / T::from(filtered.len()).expect("filtered length should be representable");
                    let sum_squared_diff = filtered
                        .iter()
                        .fold(T::zero(), |acc, &x| acc + (x - mean) * (x - mean));
                    let variance = sum_squared_diff
                        / T::from(filtered.len() - ddof_val)
                            .expect("filtered len-ddof should be representable");
                    Ok(Array::from_vec(vec![variance]))
                }
            }
        }
        Some(ax) => {
            let shape = array.shape();
            let ndim = array.ndim();
            if ax >= ndim {
                return Err(NumRs2Error::InvalidOperation(format!(
                    "axis {} is out of bounds for array of dimension {}",
                    ax, ndim
                )));
            }
            let axis_size = shape[ax];
            let data = array.to_vec();
            let strides = compute_strides(&shape);
            let axis_stride = strides[ax];
            let mut out_shape: Vec<usize> = shape
                .iter()
                .enumerate()
                .filter(|(i, _)| *i != ax)
                .map(|(_, &s)| s)
                .collect();
            if out_shape.is_empty() {
                out_shape.push(1);
            }
            let out_size: usize = out_shape.iter().product();

            // Flat offset of the first element in the lane that feeds output
            // position `out_idx`. Output positions are enumerated row-major,
            // the same layout `compute_strides` uses. So the LAST remaining
            // dimension varies fastest and must be peeled off first.
            let lane_start = |out_idx: usize| -> usize {
                let mut indices = vec![0usize; ndim];
                let mut rem = out_idx;
                for i in (0..ndim).rev().filter(|&i| i != ax) {
                    indices[i] = rem % shape[i];
                    rem /= shape[i];
                }
                indices_to_flat_idx(&indices, &strides)
            };
            let var_lane = |out_idx: usize| -> T {
                let start = lane_start(out_idx);
                let (sum, count) = (0..axis_size)
                    .map(|j| data[start + j * axis_stride])
                    .filter(|v| !v.is_nan())
                    .fold((T::zero(), 0usize), |(s, c), v| (s + v, c + 1));
                if count <= ddof_val {
                    return T::nan();
                }
                let mean = sum / T::from(count).expect("count should be representable");
                let sum_sq_diff = (0..axis_size)
                    .map(|j| data[start + j * axis_stride])
                    .filter(|v| !v.is_nan())
                    .fold(T::zero(), |acc, v| {
                        let diff = v - mean;
                        acc + diff * diff
                    });
                sum_sq_diff / T::from(count - ddof_val).expect("count-ddof should be representable")
            };
            // Each lane is still reduced serially, so parallelising over lanes
            // does not change any per-lane result.
            let result_data: Vec<T> = if out_size >= PARALLEL_THRESHOLD {
                (0..out_size).into_par_iter().map(var_lane).collect()
            } else {
                (0..out_size).map(var_lane).collect()
            };

            let result = Array::from_vec(result_data);
            if keepdims {
                // Same rank as the input, with the reduced axis collapsed to 1.
                // Deriving it from `shape` avoids the padded `[1]` of 1-D inputs.
                let keep_shape: Vec<usize> = shape
                    .iter()
                    .enumerate()
                    .map(|(i, &s)| if i == ax { 1 } else { s })
                    .collect();
                Ok(result.reshape(&keep_shape))
            } else {
                Ok(result.reshape(&out_shape))
            }
        }
    }
}
/// Computes the minimum of `array`, ignoring NaN values.
///
/// * `axis == None`: reduces over every element and returns a one-element
///   array. `keepdims` is not consulted on this path.
/// * `axis == Some(ax)`: reduces along dimension `ax`. The result has the
///   input's shape with `ax` removed (a 1-D input yields shape `[1]`). When
///   `keepdims` is true, it has the input's shape with `ax` set to 1.
///
/// A reduction that sees no non-NaN values (including an empty axis) yields
/// NaN.
///
/// # Errors
///
/// Returns `NumRs2Error::InvalidOperation` if `ax >= array.ndim()`.
pub fn nanmin<T: Float + Clone + Zero + NumCast + std::fmt::Display>(
    array: &Array<T>,
    axis: Option<usize>,
    keepdims: bool,
) -> Result<Array<T>> {
    match axis {
        None => {
            let data = array.to_vec();
            let filtered: Vec<T> = data.into_iter().filter(|x| !x.is_nan()).collect();
            if filtered.is_empty() {
                Ok(Array::from_vec(vec![T::nan()]))
            } else {
                let min_val = filtered.iter().fold(filtered[0], |acc, &x| acc.min(x));
                Ok(Array::from_vec(vec![min_val]))
            }
        }
        Some(ax) => {
            let shape = array.shape();
            let ndim = array.ndim();
            if ax >= ndim {
                return Err(NumRs2Error::InvalidOperation(format!(
                    "axis {} is out of bounds for array of dimension {}",
                    ax, ndim
                )));
            }
            let axis_size = shape[ax];
            let data = array.to_vec();
            let strides = compute_strides(&shape);
            let axis_stride = strides[ax];
            let mut out_shape: Vec<usize> = shape
                .iter()
                .enumerate()
                .filter(|(i, _)| *i != ax)
                .map(|(_, &s)| s)
                .collect();
            if out_shape.is_empty() {
                out_shape.push(1);
            }
            let out_size: usize = out_shape.iter().product();

            // Flat offset of the first element in the lane that feeds output
            // position `out_idx`. Output positions are enumerated row-major,
            // the same layout `compute_strides` uses. So the LAST remaining
            // dimension varies fastest and must be peeled off first.
            let lane_start = |out_idx: usize| -> usize {
                let mut indices = vec![0usize; ndim];
                let mut rem = out_idx;
                for i in (0..ndim).rev().filter(|&i| i != ax) {
                    indices[i] = rem % shape[i];
                    rem /= shape[i];
                }
                indices_to_flat_idx(&indices, &strides)
            };
            let lane_min = |out_idx: usize| -> T {
                let start = lane_start(out_idx);
                (0..axis_size)
                    .map(|j| data[start + j * axis_stride])
                    .filter(|v| !v.is_nan())
                    .fold(None, |best: Option<T>, v| Some(best.map_or(v, |b| b.min(v))))
                    .unwrap_or_else(T::nan)
            };
            let result_data: Vec<T> = (0..out_size).map(lane_min).collect();

            let result = Array::from_vec(result_data);
            if keepdims {
                // Same rank as the input, with the reduced axis collapsed to 1.
                // Deriving it from `shape` avoids the padded `[1]` of 1-D inputs.
                let keep_shape: Vec<usize> = shape
                    .iter()
                    .enumerate()
                    .map(|(i, &s)| if i == ax { 1 } else { s })
                    .collect();
                Ok(result.reshape(&keep_shape))
            } else {
                Ok(result.reshape(&out_shape))
            }
        }
    }
}
/// Computes the maximum of `array`, ignoring NaN values.
///
/// * `axis == None`: reduces over every element and returns a one-element
///   array. `keepdims` is not consulted on this path.
/// * `axis == Some(ax)`: reduces along dimension `ax`. The result has the
///   input's shape with `ax` removed (a 1-D input yields shape `[1]`). When
///   `keepdims` is true, it has the input's shape with `ax` set to 1.
///
/// A reduction that sees no non-NaN values (including an empty axis) yields
/// NaN.
///
/// # Errors
///
/// Returns `NumRs2Error::InvalidOperation` if `ax >= array.ndim()`.
pub fn nanmax<T: Float + Clone + Zero + NumCast + std::fmt::Display>(
    array: &Array<T>,
    axis: Option<usize>,
    keepdims: bool,
) -> Result<Array<T>> {
    match axis {
        None => {
            let data = array.to_vec();
            let filtered: Vec<T> = data.into_iter().filter(|x| !x.is_nan()).collect();
            if filtered.is_empty() {
                Ok(Array::from_vec(vec![T::nan()]))
            } else {
                let max_val = filtered.iter().fold(filtered[0], |acc, &x| acc.max(x));
                Ok(Array::from_vec(vec![max_val]))
            }
        }
        Some(ax) => {
            let shape = array.shape();
            let ndim = array.ndim();
            if ax >= ndim {
                return Err(NumRs2Error::InvalidOperation(format!(
                    "axis {} is out of bounds for array of dimension {}",
                    ax, ndim
                )));
            }
            let axis_size = shape[ax];
            let data = array.to_vec();
            let strides = compute_strides(&shape);
            let axis_stride = strides[ax];
            let mut out_shape: Vec<usize> = shape
                .iter()
                .enumerate()
                .filter(|(i, _)| *i != ax)
                .map(|(_, &s)| s)
                .collect();
            if out_shape.is_empty() {
                out_shape.push(1);
            }
            let out_size: usize = out_shape.iter().product();

            // Flat offset of the first element in the lane that feeds output
            // position `out_idx`. Output positions are enumerated row-major,
            // the same layout `compute_strides` uses. So the LAST remaining
            // dimension varies fastest and must be peeled off first.
            let lane_start = |out_idx: usize| -> usize {
                let mut indices = vec![0usize; ndim];
                let mut rem = out_idx;
                for i in (0..ndim).rev().filter(|&i| i != ax) {
                    indices[i] = rem % shape[i];
                    rem /= shape[i];
                }
                indices_to_flat_idx(&indices, &strides)
            };
            let lane_max = |out_idx: usize| -> T {
                let start = lane_start(out_idx);
                (0..axis_size)
                    .map(|j| data[start + j * axis_stride])
                    .filter(|v| !v.is_nan())
                    .fold(None, |best: Option<T>, v| Some(best.map_or(v, |b| b.max(v))))
                    .unwrap_or_else(T::nan)
            };
            let result_data: Vec<T> = (0..out_size).map(lane_max).collect();

            let result = Array::from_vec(result_data);
            if keepdims {
                // Same rank as the input, with the reduced axis collapsed to 1.
                // Deriving it from `shape` avoids the padded `[1]` of 1-D inputs.
                let keep_shape: Vec<usize> = shape
                    .iter()
                    .enumerate()
                    .map(|(i, &s)| if i == ax { 1 } else { s })
                    .collect();
                Ok(result.reshape(&keep_shape))
            } else {
                Ok(result.reshape(&out_shape))
            }
        }
    }
}
/// Computes the sum of `array`, treating NaN values as absent.
///
/// * `axis == None`: sums every non-NaN element and returns a one-element
///   array. `keepdims` is not consulted on this path.
/// * `axis == Some(ax)`: sums along dimension `ax`. The result has the
///   input's shape with `ax` removed (a 1-D input yields shape `[1]`). When
///   `keepdims` is true, it has the input's shape with `ax` set to 1.
///
/// A reduction with no non-NaN values yields 0. Work is split across threads
/// once a threshold is reached: the element count for `axis == None`, or the
/// number of output values for `axis == Some(_)`, compared against
/// `PARALLEL_THRESHOLD`.
///
/// # Errors
///
/// Returns `NumRs2Error::InvalidOperation` if `ax >= array.ndim()`.
pub fn nansum<T: Float + Clone + Zero + NumCast + std::fmt::Display + Send + Sync>(
    array: &Array<T>,
    axis: Option<usize>,
    keepdims: bool,
) -> Result<Array<T>> {
    match axis {
        None => {
            let data = array.to_vec();
            let sum = if data.len() >= PARALLEL_THRESHOLD {
                data.par_iter()
                    .filter(|x| !x.is_nan())
                    .cloned()
                    .reduce(|| T::zero(), |acc, x| acc + x)
            } else {
                data.iter()
                    .fold(T::zero(), |acc, &x| if x.is_nan() { acc } else { acc + x })
            };
            Ok(Array::from_vec(vec![sum]))
        }
        Some(ax) => {
            let shape = array.shape();
            let ndim = array.ndim();
            if ax >= ndim {
                return Err(NumRs2Error::InvalidOperation(format!(
                    "axis {} is out of bounds for array of dimension {}",
                    ax, ndim
                )));
            }
            let axis_size = shape[ax];
            let data = array.to_vec();
            let strides = compute_strides(&shape);
            let axis_stride = strides[ax];
            let mut out_shape: Vec<usize> = shape
                .iter()
                .enumerate()
                .filter(|(i, _)| *i != ax)
                .map(|(_, &s)| s)
                .collect();
            if out_shape.is_empty() {
                out_shape.push(1);
            }
            let out_size: usize = out_shape.iter().product();

            // Flat offset of the first element in the lane that feeds output
            // position `out_idx`. Output positions are enumerated row-major,
            // the same layout `compute_strides` uses. So the LAST remaining
            // dimension varies fastest and must be peeled off first.
            let lane_start = |out_idx: usize| -> usize {
                let mut indices = vec![0usize; ndim];
                let mut rem = out_idx;
                for i in (0..ndim).rev().filter(|&i| i != ax) {
                    indices[i] = rem % shape[i];
                    rem /= shape[i];
                }
                indices_to_flat_idx(&indices, &strides)
            };
            let sum_lane = |out_idx: usize| -> T {
                let start = lane_start(out_idx);
                (0..axis_size)
                    .map(|j| data[start + j * axis_stride])
                    .filter(|v| !v.is_nan())
                    .fold(T::zero(), |acc, v| acc + v)
            };
            let result_data: Vec<T> = if out_size >= PARALLEL_THRESHOLD {
                (0..out_size).into_par_iter().map(sum_lane).collect()
            } else {
                (0..out_size).map(sum_lane).collect()
            };

            let result = Array::from_vec(result_data);
            if keepdims {
                // Same rank as the input, with the reduced axis collapsed to 1.
                // Deriving it from `shape` avoids the padded `[1]` of 1-D inputs.
                let keep_shape: Vec<usize> = shape
                    .iter()
                    .enumerate()
                    .map(|(i, &s)| if i == ax { 1 } else { s })
                    .collect();
                Ok(result.reshape(&keep_shape))
            } else {
                Ok(result.reshape(&out_shape))
            }
        }
    }
}
/// Computes the product of `array`, treating NaN values as absent.
///
/// * `axis == None`: multiplies every non-NaN element and returns a
///   one-element array. `keepdims` is not consulted on this path.
/// * `axis == Some(ax)`: multiplies along dimension `ax`. The result has the
///   input's shape with `ax` removed (a 1-D input yields shape `[1]`). When
///   `keepdims` is true, it has the input's shape with `ax` set to 1.
///
/// A reduction with no non-NaN values yields 1. Work is split across threads
/// once a threshold is reached: the element count for `axis == None`, or the
/// number of output values for `axis == Some(_)`, compared against
/// `PARALLEL_THRESHOLD`.
///
/// # Errors
///
/// Returns `NumRs2Error::InvalidOperation` if `ax >= array.ndim()`.
pub fn nanprod<T: Float + Clone + Zero + NumCast + std::fmt::Display + Send + Sync>(
    array: &Array<T>,
    axis: Option<usize>,
    keepdims: bool,
) -> Result<Array<T>> {
    match axis {
        None => {
            let data = array.to_vec();
            let product = if data.len() >= PARALLEL_THRESHOLD {
                data.par_iter()
                    .filter(|x| !x.is_nan())
                    .cloned()
                    .reduce(|| T::one(), |acc, x| acc * x)
            } else {
                data.iter()
                    .fold(T::one(), |acc, &x| if x.is_nan() { acc } else { acc * x })
            };
            Ok(Array::from_vec(vec![product]))
        }
        Some(ax) => {
            let shape = array.shape();
            let ndim = array.ndim();
            if ax >= ndim {
                return Err(NumRs2Error::InvalidOperation(format!(
                    "axis {} is out of bounds for array of dimension {}",
                    ax, ndim
                )));
            }
            let axis_size = shape[ax];
            let data = array.to_vec();
            let strides = compute_strides(&shape);
            let axis_stride = strides[ax];
            let mut out_shape: Vec<usize> = shape
                .iter()
                .enumerate()
                .filter(|(i, _)| *i != ax)
                .map(|(_, &s)| s)
                .collect();
            if out_shape.is_empty() {
                out_shape.push(1);
            }
            let out_size: usize = out_shape.iter().product();

            // Flat offset of the first element in the lane that feeds output
            // position `out_idx`. Output positions are enumerated row-major,
            // the same layout `compute_strides` uses. So the LAST remaining
            // dimension varies fastest and must be peeled off first.
            let lane_start = |out_idx: usize| -> usize {
                let mut indices = vec![0usize; ndim];
                let mut rem = out_idx;
                for i in (0..ndim).rev().filter(|&i| i != ax) {
                    indices[i] = rem % shape[i];
                    rem /= shape[i];
                }
                indices_to_flat_idx(&indices, &strides)
            };
            let prod_lane = |out_idx: usize| -> T {
                let start = lane_start(out_idx);
                (0..axis_size)
                    .map(|j| data[start + j * axis_stride])
                    .filter(|v| !v.is_nan())
                    .fold(T::one(), |acc, v| acc * v)
            };
            let result_data: Vec<T> = if out_size >= PARALLEL_THRESHOLD {
                (0..out_size).into_par_iter().map(prod_lane).collect()
            } else {
                (0..out_size).map(prod_lane).collect()
            };

            let result = Array::from_vec(result_data);
            if keepdims {
                // Same rank as the input, with the reduced axis collapsed to 1.
                // Deriving it from `shape` avoids the padded `[1]` of 1-D inputs.
                let keep_shape: Vec<usize> = shape
                    .iter()
                    .enumerate()
                    .map(|(i, &s)| if i == ax { 1 } else { s })
                    .collect();
                Ok(result.reshape(&keep_shape))
            } else {
                Ok(result.reshape(&out_shape))
            }
        }
    }
}