use crate::ndarray::compat::ArrayStatCompat;
use crate::ufuncs::core::{apply_reduction, register_ufunc, UFunc, UFuncKind};
use ::ndarray::{
Array, Array1, ArrayView, ArrayViewMut, Axis, Dimension, Ix1, IxDyn, ShapeBuilder,
};
use std::sync::Once;
/// Guards `init_reduction_ufuncs` so that its registration body runs at most once.
static INIT: Once = Once::new();

/// Registers this module's reduction ufuncs (`"sum"`, `"product"`, `"mean"`, `"std"`,
/// `"var"`, `"min"`, `"max"`) by passing each one, boxed, to `register_ufunc`.
///
/// Cheap to call repeatedly: `Once::call_once` executes the closure only on the first
/// call. In this module it is invoked lazily at the start of `mean`, `var`, `min` and `max`.
// NOTE(review): the function is called from `mean`, `var`, `min` and `max` in this module,
// so this `allow(dead_code)` looks unnecessary — confirm before removing.
#[allow(dead_code)]
fn init_reduction_ufuncs() {
    INIT.call_once(|| {
        // Results of `register_ufunc` are intentionally discarded. Presumably a failed
        // registration (e.g. a name already taken) is not fatal here — TODO confirm
        // against `register_ufunc`'s contract.
        let _ = register_ufunc(Box::new(SumUFunc));
        let _ = register_ufunc(Box::new(ProductUFunc));
        let _ = register_ufunc(Box::new(MeanUFunc));
        let _ = register_ufunc(Box::new(StdUFunc));
        let _ = register_ufunc(Box::new(VarUFunc));
        let _ = register_ufunc(Box::new(MinUFunc));
        let _ = register_ufunc(Box::new(MaxUFunc));
    });
}
/// Reduction ufunc that sums every element of its single input array.
///
/// The total (0.0 for an empty input) is written into the first element of `output`.
///
/// # Errors
/// - `"Sum requires exactly one input array"` if `inputs.len() != 1`.
/// - `"Output array is not contiguous"` if `output` cannot be viewed as a
///   standard-layout slice.
/// - `"Output array is empty"` if `output` has no element to receive the result.
pub struct SumUFunc;

impl UFunc for SumUFunc {
    fn name(&self) -> &str {
        "sum"
    }

    fn kind(&self) -> UFuncKind {
        UFuncKind::Reduction
    }

    fn apply(
        &self,
        inputs: &[ArrayView<f64, IxDyn>],
        output: &mut ArrayViewMut<f64, IxDyn>,
    ) -> Result<(), &'static str> {
        if inputs.len() != 1 {
            return Err("Sum requires exactly one input array");
        }
        let output1d = match output.as_slice_mut() {
            Some(slice) => slice,
            None => return Err("Output array is not contiguous"),
        };
        let total = inputs[0].iter().fold(0.0, |acc, &v| acc + v);
        // A contiguous output can still have zero length; report that instead of panicking.
        let slot = output1d.first_mut().ok_or("Output array is empty")?;
        *slot = total;
        Ok(())
    }
}
/// Reduction ufunc that multiplies together every element of its single input array.
///
/// The product (1.0 for an empty input) is written into the first element of `output`.
///
/// # Errors
/// - `"Product requires exactly one input array"` if `inputs.len() != 1`.
/// - `"Output array is not contiguous"` if `output` cannot be viewed as a
///   standard-layout slice.
/// - `"Output array is empty"` if `output` has no element to receive the result.
pub struct ProductUFunc;

impl UFunc for ProductUFunc {
    fn name(&self) -> &str {
        "product"
    }

    fn kind(&self) -> UFuncKind {
        UFuncKind::Reduction
    }

    fn apply(
        &self,
        inputs: &[ArrayView<f64, IxDyn>],
        output: &mut ArrayViewMut<f64, IxDyn>,
    ) -> Result<(), &'static str> {
        if inputs.len() != 1 {
            return Err("Product requires exactly one input array");
        }
        let output1d = match output.as_slice_mut() {
            Some(slice) => slice,
            None => return Err("Output array is not contiguous"),
        };
        let product = inputs[0].iter().fold(1.0, |acc, &v| acc * v);
        // A contiguous output can still have zero length; report that instead of panicking.
        let slot = output1d.first_mut().ok_or("Output array is empty")?;
        *slot = product;
        Ok(())
    }
}
/// Reduction ufunc computing the arithmetic mean of every element of its single input.
///
/// The mean is written into the first element of `output`.
///
/// # Errors
/// - `"Mean requires exactly one input array"` if `inputs.len() != 1`.
/// - `"Output array is not contiguous"` if `output` cannot be viewed as a
///   standard-layout slice.
/// - `"Cannot compute mean of empty array"` if the input has no elements.
/// - `"Output array is empty"` if `output` has no element to receive the result.
pub struct MeanUFunc;

impl UFunc for MeanUFunc {
    fn name(&self) -> &str {
        "mean"
    }

    fn kind(&self) -> UFuncKind {
        UFuncKind::Reduction
    }

    fn apply(
        &self,
        inputs: &[ArrayView<f64, IxDyn>],
        output: &mut ArrayViewMut<f64, IxDyn>,
    ) -> Result<(), &'static str> {
        if inputs.len() != 1 {
            return Err("Mean requires exactly one input array");
        }
        let output1d = match output.as_slice_mut() {
            Some(slice) => slice,
            None => return Err("Output array is not contiguous"),
        };
        let input_view = &inputs[0];
        let count = input_view.len();
        if count == 0 {
            return Err("Cannot compute mean of empty array");
        }
        let mean = input_view.iter().fold(0.0, |acc, &v| acc + v) / count as f64;
        // A contiguous output can still have zero length; report that instead of panicking.
        let slot = output1d.first_mut().ok_or("Output array is empty")?;
        *slot = mean;
        Ok(())
    }
}
/// Reduction ufunc computing the population standard deviation (divides by N) of every
/// element of its single input.
///
/// The result is written into the first element of `output`. It uses a two-pass
/// algorithm: first the mean, then the sum of squared deviations. This avoids the
/// cancellation error of the one-pass `E[x^2] - E[x]^2` formula.
///
/// # Errors
/// - `"Std requires exactly one input array"` if `inputs.len() != 1`.
/// - `"Output array is not contiguous"` if `output` cannot be viewed as a
///   standard-layout slice.
/// - `"Cannot compute standard deviation with less than 2 elements"` if the input has
///   fewer than two elements.
/// - `"Output array is empty"` if `output` has no element to receive the result.
// NOTE(review): the population std of a single element is well-defined (0.0), yet
// inputs of length 1 are rejected; kept as-is for caller compatibility.
pub struct StdUFunc;

impl UFunc for StdUFunc {
    fn name(&self) -> &str {
        "std"
    }

    fn kind(&self) -> UFuncKind {
        UFuncKind::Reduction
    }

    fn apply(
        &self,
        inputs: &[ArrayView<f64, IxDyn>],
        output: &mut ArrayViewMut<f64, IxDyn>,
    ) -> Result<(), &'static str> {
        if inputs.len() != 1 {
            return Err("Std requires exactly one input array");
        }
        let output1d = match output.as_slice_mut() {
            Some(slice) => slice,
            None => return Err("Output array is not contiguous"),
        };
        let input_view = &inputs[0];
        let count = input_view.len();
        if count <= 1 {
            return Err("Cannot compute standard deviation with less than 2 elements");
        }
        let n = count as f64;
        let mean = input_view.iter().fold(0.0, |acc, &v| acc + v) / n;
        // Summing squared deviations keeps the variance non-negative. The one-pass
        // `sum_sq / n - mean^2` can come out slightly negative, and sqrt then yields NaN.
        let variance = input_view.iter().fold(0.0, |acc, &v| {
            let d = v - mean;
            acc + d * d
        }) / n;
        // A contiguous output can still have zero length; report that instead of panicking.
        let slot = output1d.first_mut().ok_or("Output array is empty")?;
        *slot = variance.sqrt();
        Ok(())
    }
}
/// Reduction ufunc computing the population variance (divides by N) of every element of
/// its single input.
///
/// The result is written into the first element of `output`. It uses a two-pass
/// algorithm: first the mean, then the sum of squared deviations. This avoids the
/// cancellation error of the one-pass `E[x^2] - E[x]^2` formula.
///
/// # Errors
/// - `"Var requires exactly one input array"` if `inputs.len() != 1`.
/// - `"Output array is not contiguous"` if `output` cannot be viewed as a
///   standard-layout slice.
/// - `"Cannot compute variance with less than 2 elements"` if the input has fewer than
///   two elements.
/// - `"Output array is empty"` if `output` has no element to receive the result.
// NOTE(review): the population variance of a single element is well-defined (0.0), yet
// inputs of length 1 are rejected; kept as-is for caller compatibility.
pub struct VarUFunc;

impl UFunc for VarUFunc {
    fn name(&self) -> &str {
        "var"
    }

    fn kind(&self) -> UFuncKind {
        UFuncKind::Reduction
    }

    fn apply(
        &self,
        inputs: &[ArrayView<f64, IxDyn>],
        output: &mut ArrayViewMut<f64, IxDyn>,
    ) -> Result<(), &'static str> {
        if inputs.len() != 1 {
            return Err("Var requires exactly one input array");
        }
        let output1d = match output.as_slice_mut() {
            Some(slice) => slice,
            None => return Err("Output array is not contiguous"),
        };
        let input_view = &inputs[0];
        let count = input_view.len();
        if count <= 1 {
            return Err("Cannot compute variance with less than 2 elements");
        }
        let n = count as f64;
        let mean = input_view.iter().fold(0.0, |acc, &v| acc + v) / n;
        // Summing squared deviations keeps the result non-negative. The one-pass
        // `sum_sq / n - mean^2` suffers catastrophic cancellation for large, tightly
        // clustered values.
        let variance = input_view.iter().fold(0.0, |acc, &v| {
            let d = v - mean;
            acc + d * d
        }) / n;
        // A contiguous output can still have zero length; report that instead of panicking.
        let slot = output1d.first_mut().ok_or("Output array is empty")?;
        *slot = variance;
        Ok(())
    }
}
/// Reduction ufunc computing the minimum over every element of its single input.
///
/// NaN propagates: if any input element is NaN, the result is NaN, as with NumPy's
/// `np.min`. Among equal minima the first one encountered is kept. The result is
/// written into the first element of `output`.
///
/// # Errors
/// - `"Min requires exactly one input array"` if `inputs.len() != 1`.
/// - `"Output array is not contiguous"` if `output` cannot be viewed as a
///   standard-layout slice.
/// - `"Cannot compute minimum of empty array"` if the input has no elements.
/// - `"Output array is empty"` if `output` has no element to receive the result.
pub struct MinUFunc;

impl UFunc for MinUFunc {
    fn name(&self) -> &str {
        "min"
    }

    fn kind(&self) -> UFuncKind {
        UFuncKind::Reduction
    }

    fn apply(
        &self,
        inputs: &[ArrayView<f64, IxDyn>],
        output: &mut ArrayViewMut<f64, IxDyn>,
    ) -> Result<(), &'static str> {
        if inputs.len() != 1 {
            return Err("Min requires exactly one input array");
        }
        let output1d = match output.as_slice_mut() {
            Some(slice) => slice,
            None => return Err("Output array is not contiguous"),
        };
        let input_view = &inputs[0];
        if input_view.is_empty() {
            return Err("Cannot compute minimum of empty array");
        }
        // `v.is_nan()` lets a NaN replace the accumulator. Once the accumulator is NaN,
        // `v < acc` is always false, so the NaN sticks. Without this, an all-NaN input
        // reported +inf.
        let min_val = input_view.iter().fold(f64::INFINITY, |acc, &v| {
            if v < acc || v.is_nan() {
                v
            } else {
                acc
            }
        });
        // A contiguous output can still have zero length; report that instead of panicking.
        let slot = output1d.first_mut().ok_or("Output array is empty")?;
        *slot = min_val;
        Ok(())
    }
}
/// Reduction ufunc computing the maximum over every element of its single input.
///
/// NaN propagates: if any input element is NaN, the result is NaN, as with NumPy's
/// `np.max`. Among equal maxima the first one encountered is kept. The result is
/// written into the first element of `output`.
///
/// # Errors
/// - `"Max requires exactly one input array"` if `inputs.len() != 1`.
/// - `"Output array is not contiguous"` if `output` cannot be viewed as a
///   standard-layout slice.
/// - `"Cannot compute maximum of empty array"` if the input has no elements.
/// - `"Output array is empty"` if `output` has no element to receive the result.
pub struct MaxUFunc;

impl UFunc for MaxUFunc {
    fn name(&self) -> &str {
        "max"
    }

    fn kind(&self) -> UFuncKind {
        UFuncKind::Reduction
    }

    fn apply(
        &self,
        inputs: &[ArrayView<f64, IxDyn>],
        output: &mut ArrayViewMut<f64, IxDyn>,
    ) -> Result<(), &'static str> {
        if inputs.len() != 1 {
            return Err("Max requires exactly one input array");
        }
        let output1d = match output.as_slice_mut() {
            Some(slice) => slice,
            None => return Err("Output array is not contiguous"),
        };
        let input_view = &inputs[0];
        if input_view.is_empty() {
            return Err("Cannot compute maximum of empty array");
        }
        // `v.is_nan()` lets a NaN replace the accumulator. Once the accumulator is NaN,
        // `v > acc` is always false, so the NaN sticks. Without this, an all-NaN input
        // reported -inf.
        let max_val = input_view.iter().fold(f64::NEG_INFINITY, |acc, &v| {
            if v > acc || v.is_nan() {
                v
            } else {
                acc
            }
        });
        // A contiguous output can still have zero length; report that instead of panicking.
        let slot = output1d.first_mut().ok_or("Output array is empty")?;
        *slot = max_val;
        Ok(())
    }
}
/// Allocates a zero-filled, flattened output buffer for reducing `input`.
///
/// Returns `(buffer, shape)`:
/// - `axis = Some(ax)`: `shape` is `input.shape()` with dimension `ax` removed, and
///   `buffer` holds the product of those extents as zeros (1 zero when `shape` is empty).
/// - `axis = None`: a single zero and the shape `[1]`.
///
/// # Panics
/// Panics with "Axis index out of bounds" if `ax >= input.ndim()`.
// NOTE(review): no caller is visible in this module; `allow(dead_code)` keeps it
// warning-free until it is wired up or removed.
#[allow(dead_code)]
fn prepare_reduction_output<D>(
    input: &crate::ndarray::ArrayView<f64, D>,
    axis: Option<usize>,
) -> (Array<f64, Ix1>, Vec<usize>)
where
    D: Dimension,
{
    let ax = match axis {
        Some(ax) => ax,
        None => return (Array::<f64, Ix1>::zeros(1), vec![1]),
    };
    if ax >= input.ndim() {
        panic!("Axis index out of bounds");
    }
    let outshape: Vec<usize> = input
        .shape()
        .iter()
        .enumerate()
        .filter(|&(i, _)| i != ax)
        .map(|(_, &dim)| dim)
        .collect();
    let outputsize: usize = outshape.iter().product();
    (Array::<f64, Ix1>::zeros(outputsize), outshape)
}
/// Sum of `array`, over all elements (`axis = None`) or along a single axis.
///
/// - `axis = None`: a one-element array holding the grand total (0.0 if empty).
/// - `axis = Some(ax)`: the axis is summed out and the remaining array is flattened into
///   1-D in logical (row-major) order, independent of the input's memory layout.
///
/// # Panics
/// Panics if `ax` is not a valid axis of `array`.
#[allow(dead_code)]
pub fn sum<D>(array: &crate::ndarray::ArrayView<f64, D>, axis: Option<usize>) -> Array<f64, Ix1>
where
    D: Dimension + crate::ndarray::RemoveAxis,
{
    use ::ndarray::Axis;
    match axis {
        Some(ax) => {
            let reduced = array.sum_axis(Axis(ax));
            // Flatten through the logical iterator rather than the raw buffer. The raw
            // buffer is in memory order and may start at a non-zero offset.
            reduced.iter().copied().collect::<Array<f64, Ix1>>()
        }
        None => Array::from_elem(1, array.iter().sum::<f64>()),
    }
}
/// Product of `array`, over all elements (`axis = None`) or along a single axis.
///
/// - `axis = None`: a one-element array holding the product of every element (1.0 if
///   empty).
/// - `axis = Some(ax)`: each lane along `ax` is multiplied out and the result is
///   flattened into 1-D in logical (row-major) order, independent of memory layout.
///
/// # Panics
/// Panics if `ax` is not a valid axis of `array`.
#[allow(dead_code)]
pub fn product<D>(array: &crate::ndarray::ArrayView<f64, D>, axis: Option<usize>) -> Array<f64, Ix1>
where
    D: Dimension + crate::ndarray::RemoveAxis,
{
    use ::ndarray::Axis;
    match axis {
        Some(ax) => {
            let reduced = array.map_axis(Axis(ax), |lane| lane.iter().product::<f64>());
            // `map_axis` may return a non-standard layout (e.g. for Fortran-order or
            // negatively strided inputs), so its raw buffer is not necessarily in logical
            // order. Flatten through the logical iterator instead.
            reduced.iter().copied().collect::<Array<f64, Ix1>>()
        }
        None => Array::from_elem(1, array.iter().product::<f64>()),
    }
}
/// Arithmetic mean of `array`, over all elements (`axis = None`) or along one axis.
///
/// Computed as `sum(array, axis)` divided by the number of elements reduced: the length
/// of axis `ax`, or `array.len()` when `axis` is `None`. An empty reduction gives
/// 0.0 / 0.0, i.e. NaN.
///
/// # Panics
/// Panics if `ax` is not a valid axis of `array`.
#[allow(dead_code)]
pub fn mean<D>(array: &crate::ndarray::ArrayView<f64, D>, axis: Option<usize>) -> Array<f64, Ix1>
where
    D: Dimension + crate::ndarray::RemoveAxis,
{
    init_reduction_ufuncs();
    let totals = sum(array, axis);
    let divisor = match axis {
        Some(ax) => array.len_of(crate::ndarray::Axis(ax)) as f64,
        None => array.len() as f64,
    };
    totals.mapv(|total| total / divisor)
}
/// Standard deviation of `array`, over all elements (`axis = None`) or along one axis.
///
/// This is the element-wise square root of [`var`] for the same `array` and `axis`.
///
/// # Panics
/// Panics if `ax` is not a valid axis of `array`.
#[allow(dead_code)]
pub fn std<D>(array: &crate::ndarray::ArrayView<f64, D>, axis: Option<usize>) -> Array<f64, Ix1>
where
    D: Dimension + crate::ndarray::RemoveAxis,
{
    var(array, axis).mapv(f64::sqrt)
}
/// Population variance of `array` (divides by N, like NumPy's default `ddof = 0`), over
/// all elements (`axis = None`) or along a single axis.
///
/// Uses a two-pass formulation per lane: the mean first, then the mean squared deviation.
/// An empty reduction gives 0.0 / 0.0, i.e. NaN.
///
/// With `axis = Some(ax)` the per-lane variances are flattened into 1-D in logical
/// (row-major) order, independent of the input's memory layout.
///
/// # Panics
/// Panics if `ax` is not a valid axis of `array`.
#[allow(dead_code)]
pub fn var<D>(array: &crate::ndarray::ArrayView<f64, D>, axis: Option<usize>) -> Array<f64, Ix1>
where
    D: Dimension + crate::ndarray::RemoveAxis,
{
    use ::ndarray::Axis;
    init_reduction_ufuncs();
    match axis {
        Some(ax) => {
            let n = array.len_of(Axis(ax)) as f64;
            let reduced = array.map_axis(Axis(ax), |lane| {
                let m = lane.mean_or(0.0);
                lane.iter().map(|&x| (x - m).powi(2)).sum::<f64>() / n
            });
            // `map_axis` may return a non-standard layout, so its raw buffer is not
            // necessarily in logical order. Flatten through the logical iterator instead.
            reduced.iter().copied().collect::<Array<f64, Ix1>>()
        }
        None => {
            let mean_val = array.mean_or(0.0);
            let n = array.len() as f64;
            let var_val = array.iter().map(|&x| (x - mean_val).powi(2)).sum::<f64>() / n;
            Array::from_elem(1, var_val)
        }
    }
}
/// Minimum of `array`, over all elements (`axis = None`) or along a single axis.
///
/// NaN propagates: any lane containing NaN reduces to NaN, as with NumPy's `np.min`,
/// instead of panicking. Among equal minima the first one wins, as with
/// `Iterator::min_by`. An empty lane yields `f64::INFINITY`.
///
/// With `axis = Some(ax)` the per-lane minima are flattened into 1-D in logical
/// (row-major) order, independent of the input's memory layout.
///
/// # Panics
/// Panics if `ax` is not a valid axis of `array`.
#[allow(dead_code)]
pub fn min<D>(array: &crate::ndarray::ArrayView<f64, D>, axis: Option<usize>) -> Array<f64, Ix1>
where
    D: Dimension + crate::ndarray::RemoveAxis,
{
    use ::ndarray::Axis;

    /// NaN-propagating min step; the strict `<` keeps the first of equal minima.
    fn nan_min(acc: f64, x: f64) -> f64 {
        if x < acc || x.is_nan() {
            x
        } else {
            acc
        }
    }

    init_reduction_ufuncs();
    match axis {
        Some(ax) => {
            let reduced = array.map_axis(Axis(ax), |lane| {
                lane.iter().copied().fold(f64::INFINITY, nan_min)
            });
            // `map_axis` may return a non-standard layout, so its raw buffer is not
            // necessarily in logical order. Flatten through the logical iterator instead.
            reduced.iter().copied().collect::<Array<f64, Ix1>>()
        }
        None => Array::from_elem(1, array.iter().copied().fold(f64::INFINITY, nan_min)),
    }
}
/// Maximum of `array`, over all elements (`axis = None`) or along a single axis.
///
/// NaN propagates: any lane containing NaN reduces to NaN, as with NumPy's `np.max`,
/// instead of panicking. Among equal maxima the last one wins, as with
/// `Iterator::max_by`. An empty lane yields `f64::NEG_INFINITY`.
///
/// With `axis = Some(ax)` the per-lane maxima are flattened into 1-D in logical
/// (row-major) order, independent of the input's memory layout.
///
/// # Panics
/// Panics if `ax` is not a valid axis of `array`.
#[allow(dead_code)]
pub fn max<D>(array: &crate::ndarray::ArrayView<f64, D>, axis: Option<usize>) -> Array<f64, Ix1>
where
    D: Dimension + crate::ndarray::RemoveAxis,
{
    use ::ndarray::Axis;

    /// NaN-propagating max step; `>=` keeps the last of equal maxima.
    fn nan_max(acc: f64, x: f64) -> f64 {
        if x >= acc || x.is_nan() {
            x
        } else {
            acc
        }
    }

    init_reduction_ufuncs();
    match axis {
        Some(ax) => {
            let reduced = array.map_axis(Axis(ax), |lane| {
                lane.iter().copied().fold(f64::NEG_INFINITY, nan_max)
            });
            // `map_axis` may return a non-standard layout, so its raw buffer is not
            // necessarily in logical order. Flatten through the logical iterator instead.
            reduced.iter().copied().collect::<Array<f64, Ix1>>()
        }
        None => Array::from_elem(1, array.iter().copied().fold(f64::NEG_INFINITY, nan_max)),
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use ::ndarray::array;

    /// Shared 2x3 fixture: rows `[1, 2, 3]` and `[4, 5, 6]`.
    fn fixture() -> ::ndarray::Array2<f64> {
        array![[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]
    }

    /// Asserts that `actual` is strictly within `tol` of `expected`.
    fn assert_close(actual: f64, expected: f64, tol: f64) {
        assert!(
            (actual - expected).abs() < tol,
            "expected {expected} within {tol}, got {actual}"
        );
    }

    #[test]
    fn test_sum() {
        let a = fixture();
        let v = a.view();
        assert_eq!(sum(&v, None), array![21.0]);
        assert_eq!(sum(&v, Some(0)), array![5.0, 7.0, 9.0]);
        assert_eq!(sum(&v, Some(1)), array![6.0, 15.0]);
    }

    #[test]
    fn test_product() {
        let a = fixture();
        let v = a.view();
        assert_eq!(product(&v, None), array![720.0]);
        assert_eq!(product(&v, Some(0)), array![4.0, 10.0, 18.0]);
        assert_eq!(product(&v, Some(1)), array![6.0, 120.0]);
    }

    #[test]
    fn test_mean() {
        let a = fixture();
        let v = a.view();
        assert_eq!(mean(&v, None), array![3.5]);
        assert_eq!(mean(&v, Some(0)), array![2.5, 3.5, 4.5]);
        assert_eq!(mean(&v, Some(1)), array![2.0, 5.0]);
    }

    #[test]
    fn test_std() {
        let a = fixture();
        let v = a.view();
        assert_close(std(&v, None)[0], (35.0_f64 / 12.0).sqrt(), 1e-6);
        let per_column = std(&v, Some(0));
        assert_close(per_column[0], 1.5, 1e-10);
        assert_close(per_column[1], 1.5, 1e-10);
        assert_close(per_column[2], 1.5, 1e-10);
    }

    #[test]
    fn test_var() {
        let a = fixture();
        let v = a.view();
        assert_close(var(&v, None)[0], 35.0 / 12.0, 1e-10);
        assert_eq!(var(&v, Some(0)), array![2.25, 2.25, 2.25]);
    }

    #[test]
    fn test_min() {
        let a = fixture();
        let v = a.view();
        assert_eq!(min(&v, None), array![1.0]);
        assert_eq!(min(&v, Some(0)), array![1.0, 2.0, 3.0]);
        assert_eq!(min(&v, Some(1)), array![1.0, 4.0]);
    }

    #[test]
    fn test_max() {
        let a = fixture();
        let v = a.view();
        assert_eq!(max(&v, None), array![6.0]);
        assert_eq!(max(&v, Some(0)), array![4.0, 5.0, 6.0]);
        assert_eq!(max(&v, Some(1)), array![3.0, 6.0]);
    }
}