use ::ndarray::{Array, ArrayView, Axis, Dimension, Ix1, Ix2};
use num_traits::{Float, FromPrimitive};
/// Arithmetic mean of a 2-D view, over every element or along one axis.
///
/// * `axis == None` — mean of all elements, returned as a one-element array.
/// * `axis == Some(Axis(0))` — one mean per column (output length = number of columns).
/// * `axis == Some(Axis(1))` — one mean per row (output length = number of rows).
///
/// Elements are summed sequentially in lane order and the sum is divided by the lane length.
///
/// # Errors
///
/// Returns `Err` if `array` is empty, if the axis index is not 0 or 1, or if an element
/// count cannot be converted into `T`.
#[allow(dead_code)]
pub fn mean_2d<T>(
    array: &ArrayView<T, Ix2>,
    axis: Option<Axis>,
) -> Result<Array<T, Ix1>, &'static str>
where
    T: Clone + Float + FromPrimitive,
{
    if array.is_empty() {
        return Err("Cannot compute mean of an empty array");
    }
    match axis {
        Some(ax) => match ax.index() {
            0 | 1 => {
                // Every lane along `ax` has the same length, so one divisor serves all lanes.
                let n = T::from_usize(array.len_of(ax))
                    .ok_or("Cannot convert array length to T")?;
                // `map_axis` produces exactly one output per lane, in lane order.
                Ok(array.map_axis(ax, |lane| {
                    lane.iter().fold(T::zero(), |acc, &v| acc + v) / n
                }))
            }
            _ => Err("Axis index out of bounds for 2D array"),
        },
        None => {
            let sum = array.iter().fold(T::zero(), |acc, &v| acc + v);
            let count = T::from_usize(array.len()).ok_or("Cannot convert array length to T")?;
            Ok(Array::from_elem(1, sum / count))
        }
    }
}
/// Median of a 2-D view, over every element or along one axis.
///
/// * `axis == None` — median of all elements, returned as a one-element array.
/// * `axis == Some(Axis(0))` — one median per column.
/// * `axis == Some(Axis(1))` — one median per row.
///
/// For an even number of values the two middle values are averaged. Values are sorted
/// with `partial_cmp`, treating incomparable pairs (NaN) as equal.
///
/// # Errors
///
/// Returns `Err` if `array` is empty, if the axis index is not 0 or 1, or if `2.0` cannot
/// be converted into `T`.
#[allow(dead_code)]
pub fn median_2d<T>(
    array: &ArrayView<T, Ix2>,
    axis: Option<Axis>,
) -> Result<Array<T, Ix1>, &'static str>
where
    T: Clone + Float + FromPrimitive,
{
    /// Sorts `values` in place and returns their median; `values` must be non-empty.
    fn median_in_place<E: Float>(values: &mut [E], two: E) -> E {
        values.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
        let mid = values.len() / 2;
        if values.len().is_multiple_of(2) {
            (values[mid - 1] + values[mid]) / two
        } else {
            values[mid]
        }
    }

    if array.is_empty() {
        return Err("Cannot compute median of an empty array");
    }
    let two = T::from_f64(2.0).ok_or("Cannot convert 2.0 into element type")?;
    match axis {
        Some(ax) => match ax.index() {
            // `map_axis` produces exactly one output per lane (column for 0, row for 1).
            0 | 1 => Ok(array.map_axis(ax, |lane| {
                let mut values: Vec<T> = lane.iter().copied().collect();
                median_in_place(&mut values, two)
            })),
            _ => Err("Axis index out of bounds for 2D array"),
        },
        None => {
            let mut values: Vec<T> = array.iter().copied().collect();
            Ok(Array::from_elem(1, median_in_place(&mut values, two)))
        }
    }
}
/// Standard deviation of a 2-D view: the element-wise square root of [`variance_2d`]
/// computed with the same `axis` and `ddof`.
///
/// # Errors
///
/// Propagates every error returned by [`variance_2d`].
#[allow(dead_code)]
pub fn std_dev_2d<T>(
    array: &ArrayView<T, Ix2>,
    axis: Option<Axis>,
    ddof: usize,
) -> Result<Array<T, Ix1>, &'static str>
where
    T: Clone + Float + FromPrimitive,
{
    variance_2d(array, axis, ddof).map(|variances| variances.mapv(T::sqrt))
}
/// Variance of a 2-D view with `ddof` delta degrees of freedom.
///
/// * `axis == None` — variance of all elements, returned as a one-element array.
/// * `axis == Some(Axis(0))` — one variance per column.
/// * `axis == Some(Axis(1))` — one variance per row.
///
/// Each variance is `sum((x - mean)^2) / (count - ddof)`, where `count` is the number of
/// values being reduced and `mean` is their arithmetic mean.
///
/// # Errors
///
/// Returns `Err` if `array` is empty, if the axis index is not 0 or 1, if `ddof` is not
/// smaller than the number of values being reduced, or if a count cannot be converted
/// into `T`.
#[allow(dead_code)]
pub fn variance_2d<T>(
    array: &ArrayView<T, Ix2>,
    axis: Option<Axis>,
    ddof: usize,
) -> Result<Array<T, Ix1>, &'static str>
where
    T: Clone + Float + FromPrimitive,
{
    const NOT_ENOUGH: &str = "Not enough data points for variance calculation with given ddof";
    if array.is_empty() {
        return Err("Cannot compute variance of an empty array");
    }
    match axis {
        Some(ax) => {
            if ax.index() > 1 {
                return Err("Axis index out of bounds for 2D array");
            }
            // All lanes along `ax` share one length, so the mean divisor and the
            // ddof-adjusted divisor are computed once.
            let lane_len = array.len_of(ax);
            if lane_len <= ddof {
                return Err(NOT_ENOUGH);
            }
            let n = T::from_usize(lane_len).ok_or("Cannot convert array length to T")?;
            let divisor =
                T::from_usize(lane_len - ddof).ok_or("Cannot convert array length to T")?;
            // The lane mean is computed here rather than via `mean_2d`, and `map_axis`
            // writes exactly one variance per lane.
            Ok(array.map_axis(ax, |lane| {
                let mean = lane.iter().fold(T::zero(), |acc, &v| acc + v) / n;
                let sum_sq_diff = lane.iter().fold(T::zero(), |acc, &v| {
                    let diff = v - mean;
                    acc + diff * diff
                });
                sum_sq_diff / divisor
            }))
        }
        None => {
            let total_elements = array.len();
            if total_elements <= ddof {
                return Err(NOT_ENOUGH);
            }
            let count =
                T::from_usize(total_elements).ok_or("Cannot convert array length to T")?;
            let global_mean = array.iter().fold(T::zero(), |acc, &v| acc + v) / count;
            let sum_sq_diff = array.iter().fold(T::zero(), |acc, &v| {
                let diff = v - global_mean;
                acc + diff * diff
            });
            let divisor = T::from_usize(total_elements - ddof)
                .ok_or("Cannot convert array length to T")?;
            Ok(Array::from_elem(1, sum_sq_diff / divisor))
        }
    }
}
/// Minimum of a 2-D view, over every element or along one axis.
///
/// * `None` — smallest element overall, as a one-element array.
/// * `Some(Axis(0))` — one minimum per column; `Some(Axis(1))` — one per row.
///
/// The scan starts from the first value of each lane and replaces the running value only
/// when a later value compares strictly smaller (`<`).
///
/// # Errors
///
/// Returns `Err` for an empty array or an axis index other than 0 or 1.
#[allow(dead_code)]
pub fn min_2d<T>(
    array: &ArrayView<T, Ix2>,
    axis: Option<Axis>,
) -> Result<Array<T, Ix1>, &'static str>
where
    T: Clone + Float,
{
    if array.is_empty() {
        return Err("Cannot compute minimum of an empty array");
    }
    let smaller = |best: T, candidate: T| if candidate < best { candidate } else { best };
    let Some(ax) = axis else {
        let seed = array[[0, 0]];
        return Ok(Array::from_elem(1, array.iter().copied().fold(seed, smaller)));
    };
    if ax.index() > 1 {
        return Err("Axis index out of bounds for 2D array");
    }
    Ok(array.map_axis(ax, |lane| lane.iter().skip(1).copied().fold(lane[0], smaller)))
}
/// Maximum of a 2-D view, over every element or along one axis.
///
/// * `None` — largest element overall, as a one-element array.
/// * `Some(Axis(0))` — one maximum per column; `Some(Axis(1))` — one per row.
///
/// The scan starts from the first value of each lane and replaces the running value only
/// when a later value compares strictly larger (`>`).
///
/// # Errors
///
/// Returns `Err` for an empty array or an axis index other than 0 or 1.
#[allow(dead_code)]
pub fn max_2d<T>(
    array: &ArrayView<T, Ix2>,
    axis: Option<Axis>,
) -> Result<Array<T, Ix1>, &'static str>
where
    T: Clone + Float,
{
    if array.is_empty() {
        return Err("Cannot compute maximum of an empty array");
    }
    let larger = |best: T, candidate: T| if candidate > best { candidate } else { best };
    let Some(ax) = axis else {
        let seed = array[[0, 0]];
        return Ok(Array::from_elem(1, array.iter().copied().fold(seed, larger)));
    };
    if ax.index() > 1 {
        return Err("Axis index out of bounds for 2D array");
    }
    Ok(array.map_axis(ax, |lane| lane.iter().skip(1).copied().fold(lane[0], larger)))
}
/// Sum of a 2-D view, over every element or along one axis.
///
/// * `axis == None` — sum of all elements, returned as a one-element array.
/// * `axis == Some(Axis(0))` — one sum per column (output length = number of columns).
/// * `axis == Some(Axis(1))` — one sum per row (output length = number of rows).
///
/// Values are accumulated sequentially, starting from zero, in lane order.
///
/// # Errors
///
/// Returns `Err` if `array` is empty or if the axis index is not 0 or 1.
#[allow(dead_code)]
pub fn sum_2d<T>(
    array: &ArrayView<T, Ix2>,
    axis: Option<Axis>,
) -> Result<Array<T, Ix1>, &'static str>
where
    T: Clone + Float,
{
    if array.is_empty() {
        return Err("Cannot compute sum of an empty array");
    }
    match axis {
        Some(ax) => match ax.index() {
            // `map_axis` produces exactly one output per lane, in lane order.
            0 | 1 => Ok(array.map_axis(ax, |lane| {
                lane.iter().fold(T::zero(), |acc, &v| acc + v)
            })),
            _ => Err("Axis index out of bounds for 2D array"),
        },
        None => {
            let sum = array.iter().fold(T::zero(), |acc, &v| acc + v);
            Ok(Array::from_elem(1, sum))
        }
    }
}
/// Product of a 2-D view, over every element or along one axis.
///
/// * `None` — product of all elements, as a one-element array.
/// * `Some(Axis(0))` — one product per column; `Some(Axis(1))` — one per row.
///
/// Each product starts from one and multiplies the values in lane order.
///
/// # Errors
///
/// Returns `Err` for an empty array or an axis index other than 0 or 1.
#[allow(dead_code)]
pub fn product_2d<T>(
    array: &ArrayView<T, Ix2>,
    axis: Option<Axis>,
) -> Result<Array<T, Ix1>, &'static str>
where
    T: Clone + Float,
{
    if array.is_empty() {
        return Err("Cannot compute product of an empty array");
    }
    let Some(ax) = axis else {
        let total = array.iter().fold(T::one(), |acc, &x| acc * x);
        return Ok(Array::from_elem(1, total));
    };
    if ax.index() > 1 {
        return Err("Axis index out of bounds for 2D array");
    }
    Ok(array.map_axis(ax, |lane| lane.iter().fold(T::one(), |acc, &x| acc * x)))
}
/// `q`-th percentile (0..=100) of a 2-D view, over every element or along one axis.
///
/// * `axis == None` — percentile of all elements, returned as a one-element array.
/// * `axis == Some(Axis(0))` — one percentile per column.
/// * `axis == Some(Axis(1))` — one percentile per row.
///
/// Uses linear interpolation between the two closest ranks of the sorted values
/// (position `q / 100 * (len - 1)`). Values are sorted with `partial_cmp`, treating
/// incomparable pairs (NaN) as equal.
///
/// # Errors
///
/// Returns `Err` if `array` is empty, if `q` is outside `0..=100`, if the axis index is
/// not 0 or 1, or if an interpolation weight cannot be converted into `T`.
#[allow(dead_code)]
pub fn percentile_2d<T>(
    array: &ArrayView<T, Ix2>,
    q: f64,
    axis: Option<Axis>,
) -> Result<Array<T, Ix1>, &'static str>
where
    T: Clone + Float + FromPrimitive,
{
    /// Sorts `values` in place and returns the interpolated `q`-th percentile;
    /// `values` must be non-empty.
    fn sorted_percentile<E: Float + FromPrimitive>(
        values: &mut [E],
        q: f64,
    ) -> Result<E, &'static str> {
        values.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
        let pos = (q / 100.0) * (values.len() as f64 - 1.0);
        let idx_low = pos.floor() as usize;
        let idx_high = pos.ceil() as usize;
        if idx_low == idx_high {
            return Ok(values[idx_low]);
        }
        let weight_high = pos - idx_low as f64;
        let weight_low = 1.0 - weight_high;
        let w_low = E::from_f64(weight_low)
            .ok_or("Cannot convert percentile weight into element type")?;
        let w_high = E::from_f64(weight_high)
            .ok_or("Cannot convert percentile weight into element type")?;
        Ok(values[idx_low] * w_low + values[idx_high] * w_high)
    }

    if array.is_empty() {
        return Err("Cannot compute percentile of an empty array");
    }
    if !(0.0..=100.0).contains(&q) {
        return Err("Percentile must be between 0 and 100");
    }
    match axis {
        Some(ax) => {
            let (rows, cols) = (array.shape()[0], array.shape()[1]);
            match ax.index() {
                0 => {
                    let mut result = Array::<T, Ix1>::zeros(cols);
                    // One scratch buffer is reused for every column.
                    let mut column = Vec::with_capacity(rows);
                    for j in 0..cols {
                        column.clear();
                        column.extend(array.column(j).iter().copied());
                        result[j] = sorted_percentile(&mut column, q)?;
                    }
                    Ok(result)
                }
                1 => {
                    let mut result = Array::<T, Ix1>::zeros(rows);
                    // One scratch buffer is reused for every row.
                    let mut row = Vec::with_capacity(cols);
                    for i in 0..rows {
                        row.clear();
                        row.extend(array.row(i).iter().copied());
                        result[i] = sorted_percentile(&mut row, q)?;
                    }
                    Ok(result)
                }
                _ => Err("Axis index out of bounds for 2D array"),
            }
        }
        None => {
            let mut values: Vec<T> = array.iter().copied().collect();
            Ok(Array::from_elem(1, sorted_percentile(&mut values, q)?))
        }
    }
}
/// Arithmetic mean of an N-dimensional view, over every element or along one axis.
///
/// * `axis == None` — mean of all elements, returned as a one-element array.
/// * `axis == Some(ax)` — one mean per lane along `ax` (computed by `mean_axis`),
///   flattened to 1-D in row-major order of the remaining axes. A 1-D input therefore
///   yields a one-element array.
///
/// # Errors
///
/// Returns `Err` if the array is empty, if `ax` is not a valid axis of `array`, if the
/// axis mean or the reshape fails, or if the element count cannot be converted into `T`.
#[allow(dead_code)]
pub fn mean<T, D>(
    array: &ArrayView<T, D>,
    axis: Option<Axis>,
) -> Result<Array<T, Ix1>, &'static str>
where
    T: Clone + Float + FromPrimitive,
    D: Dimension + crate::ndarray::RemoveAxis,
{
    if array.is_empty() {
        return Err("Cannot compute mean of an empty array");
    }
    match axis {
        Some(ax) => {
            if ax.index() >= array.ndim() {
                return Err("Axis index out of bounds");
            }
            let result = array
                .mean_axis(ax)
                .ok_or("Failed to compute mean along axis")?;
            let flat_result = result
                .to_shape((result.len(),))
                .map_err(|_| "Failed to reshape result to 1D")?;
            Ok(flat_result.into_owned())
        }
        None => {
            let sum = array.iter().fold(T::zero(), |acc, &v| acc + v);
            let count = T::from_usize(array.len()).ok_or("Cannot convert array length to T")?;
            Ok(Array::from_elem(1, sum / count))
        }
    }
}
/// Median of an N-dimensional view, over every element or along one axis.
///
/// * `axis == None` — median of all elements, returned as a one-element array.
/// * `axis == Some(ax)` — one median per lane along `ax`, flattened to 1-D in row-major
///   order of the remaining axes.
///
/// For an even count the two middle values are averaged. Values are sorted with
/// `partial_cmp`, treating incomparable pairs (NaN) as equal.
///
/// # Errors
///
/// Returns `Err` if the array is empty, if `ax` is not a valid axis, if `2.0` cannot be
/// converted into `T` when it is needed, or if the reshape fails.
#[allow(dead_code)]
pub fn median<T, D>(
    array: &ArrayView<T, D>,
    axis: Option<Axis>,
) -> Result<Array<T, Ix1>, &'static str>
where
    T: Clone + Float + FromPrimitive,
    D: Dimension + crate::ndarray::RemoveAxis,
{
    if array.is_empty() {
        return Err("Cannot compute median of an empty array");
    }
    let Some(ax) = axis else {
        let mut all: Vec<T> = array.iter().copied().collect();
        all.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
        let half = all.len() / 2;
        let middle = if all.len().is_multiple_of(2) {
            let two = T::from_f64(2.0).ok_or("Cannot convert 2.0 into element type")?;
            (all[half - 1] + all[half]) / two
        } else {
            all[half]
        };
        return Ok(Array::from_elem(1, middle));
    };
    if ax.index() >= array.ndim() {
        return Err("Axis index out of bounds");
    }
    let two = T::from_f64(2.0).ok_or("Cannot convert 2.0 into element type")?;
    let per_lane = array.map_axis(ax, |lane| {
        let mut sorted: Vec<T> = lane.iter().copied().collect();
        sorted.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
        let half = sorted.len() / 2;
        match sorted.len().is_multiple_of(2) {
            true => (sorted[half - 1] + sorted[half]) / two,
            false => sorted[half],
        }
    });
    let flat = per_lane
        .to_shape((per_lane.len(),))
        .map_err(|_| "Failed to reshape median result to 1D")?;
    Ok(flat.into_owned())
}
/// Variance of an N-dimensional view with `ddof` delta degrees of freedom.
///
/// * `axis == None` — variance of all elements, returned as a one-element array.
/// * `axis == Some(ax)` — one variance per lane along `ax` (computed by `var_axis`),
///   flattened to 1-D in row-major order of the remaining axes.
///
/// The divisor is `count - ddof`, where `count` is the number of values being reduced.
///
/// # Errors
///
/// Returns `Err` if the array is empty, if `ax` is not a valid axis, if `ddof` is not
/// smaller than the number of values being reduced, if a count cannot be converted into
/// `T`, or if the reshape fails.
#[allow(dead_code)]
pub fn variance<T, D>(
    array: &ArrayView<T, D>,
    axis: Option<Axis>,
    ddof: usize,
) -> Result<Array<T, Ix1>, &'static str>
where
    T: Clone + Float + FromPrimitive,
    D: Dimension + crate::ndarray::RemoveAxis,
{
    if array.is_empty() {
        return Err("Cannot compute variance of an empty array");
    }
    match axis {
        Some(ax) => {
            if ax.index() >= array.ndim() {
                return Err("Axis index out of bounds");
            }
            // `var_axis` panics when `ddof` exceeds the lane length and divides by zero
            // when they are equal; reject both here, as the `None` branch does.
            if array.len_of(ax) <= ddof {
                return Err("Not enough data points for variance calculation with given ddof");
            }
            let ddof_t = T::from_usize(ddof).ok_or("Cannot convert ddof to T")?;
            let result = array.var_axis(ax, ddof_t);
            let flat_result = result
                .to_shape((result.len(),))
                .map_err(|_| "Failed to reshape variance result to 1D")?;
            Ok(flat_result.into_owned())
        }
        None => {
            let total_elements = array.len();
            if total_elements <= ddof {
                return Err("Not enough data points for variance calculation with given ddof");
            }
            let global_mean = mean(array, None)?[0];
            let sum_sq_diff = array.iter().fold(T::zero(), |acc, &val| {
                let diff = val - global_mean;
                acc + diff * diff
            });
            let divisor = T::from_usize(total_elements - ddof)
                .ok_or("Cannot convert array length to T")?;
            Ok(Array::from_elem(1, sum_sq_diff / divisor))
        }
    }
}
/// Standard deviation of an N-dimensional view: the element-wise square root of
/// [`variance`] computed with the same `axis` and `ddof`.
///
/// # Errors
///
/// Propagates every error returned by [`variance`].
#[allow(dead_code)]
pub fn std_dev<T, D>(
    array: &ArrayView<T, D>,
    axis: Option<Axis>,
    ddof: usize,
) -> Result<Array<T, Ix1>, &'static str>
where
    T: Clone + Float + FromPrimitive,
    D: Dimension + crate::ndarray::RemoveAxis,
{
    variance(array, axis, ddof).map(|variances| variances.mapv(T::sqrt))
}
/// Minimum of an N-dimensional view, over every element or along one axis.
///
/// * `axis == None` — smallest element overall, as a one-element array.
/// * `axis == Some(ax)` — one minimum per lane along `ax`, flattened to 1-D in row-major
///   order of the remaining axes.
///
/// Both paths use the same scan: start from the first value and replace the running
/// value only when a later value compares strictly smaller (`<`). NaN and signed-zero
/// handling is therefore identical with and without an axis.
///
/// # Errors
///
/// Returns `Err` if the array is empty, if `ax` is not a valid axis, or if the reshape
/// fails.
#[allow(dead_code)]
pub fn min<T, D>(array: &ArrayView<T, D>, axis: Option<Axis>) -> Result<Array<T, Ix1>, &'static str>
where
    T: Clone + Float,
    D: Dimension + crate::ndarray::RemoveAxis,
{
    if array.is_empty() {
        return Err("Cannot compute minimum of an empty array");
    }
    let smaller = |best: T, v: T| if v < best { v } else { best };
    match axis {
        Some(ax) => {
            if ax.index() >= array.ndim() {
                return Err("Axis index out of bounds");
            }
            let result = array.map_axis(ax, |lane| {
                let mut values = lane.iter().copied();
                // A non-empty array has a non-zero length on every axis.
                let seed = values.next().expect("lanes of a non-empty array are non-empty");
                values.fold(seed, smaller)
            });
            let flat_result = result
                .to_shape((result.len(),))
                .map_err(|_| "Failed to reshape minimum result to 1D")?;
            Ok(flat_result.into_owned())
        }
        None => {
            let mut values = array.iter().copied();
            let seed = values.next().expect("array was checked to be non-empty");
            Ok(Array::from_elem(1, values.fold(seed, smaller)))
        }
    }
}
/// Maximum of an N-dimensional view, over every element or along one axis.
///
/// * `axis == None` — largest element overall, as a one-element array.
/// * `axis == Some(ax)` — one maximum per lane along `ax`, flattened to 1-D in row-major
///   order of the remaining axes.
///
/// Both paths use the same scan: start from the first value and replace the running
/// value only when a later value compares strictly larger (`>`). NaN and signed-zero
/// handling is therefore identical with and without an axis.
///
/// # Errors
///
/// Returns `Err` if the array is empty, if `ax` is not a valid axis, or if the reshape
/// fails.
#[allow(dead_code)]
pub fn max<T, D>(array: &ArrayView<T, D>, axis: Option<Axis>) -> Result<Array<T, Ix1>, &'static str>
where
    T: Clone + Float,
    D: Dimension + crate::ndarray::RemoveAxis,
{
    if array.is_empty() {
        return Err("Cannot compute maximum of an empty array");
    }
    let larger = |best: T, v: T| if v > best { v } else { best };
    match axis {
        Some(ax) => {
            if ax.index() >= array.ndim() {
                return Err("Axis index out of bounds");
            }
            let result = array.map_axis(ax, |lane| {
                let mut values = lane.iter().copied();
                // A non-empty array has a non-zero length on every axis.
                let seed = values.next().expect("lanes of a non-empty array are non-empty");
                values.fold(seed, larger)
            });
            let flat_result = result
                .to_shape((result.len(),))
                .map_err(|_| "Failed to reshape maximum result to 1D")?;
            Ok(flat_result.into_owned())
        }
        None => {
            let mut values = array.iter().copied();
            let seed = values.next().expect("array was checked to be non-empty");
            Ok(Array::from_elem(1, values.fold(seed, larger)))
        }
    }
}
/// `q`-th percentile (0..=100) of an N-dimensional view, over every element or along one
/// axis.
///
/// * `axis == None` — percentile of all elements, returned as a one-element array.
/// * `axis == Some(ax)` — one percentile per lane along `ax`, flattened to 1-D in
///   row-major order of the remaining axes.
///
/// Uses linear interpolation between the two closest ranks of the sorted values
/// (position `q / 100 * (len - 1)`). Values are sorted with `partial_cmp`, treating
/// incomparable pairs (NaN) as equal.
///
/// # Errors
///
/// Returns `Err` if the array is empty, if `q` is outside `0..=100`, if `ax` is not a
/// valid axis, if an interpolation weight cannot be converted into `T`, or if the reshape
/// fails.
#[allow(dead_code)]
pub fn percentile<T, D>(
    array: &ArrayView<T, D>,
    q: f64,
    axis: Option<Axis>,
) -> Result<Array<T, Ix1>, &'static str>
where
    T: Clone + Float + FromPrimitive,
    D: Dimension + crate::ndarray::RemoveAxis,
{
    /// The two ranks around the `q`-th percentile in a sorted sequence of a fixed length,
    /// plus the linear-interpolation weights already converted into the element type.
    struct Interpolation<E> {
        low: usize,
        high: usize,
        w_low: E,
        w_high: E,
    }

    impl<E: Float + FromPrimitive> Interpolation<E> {
        fn new(q: f64, len: usize) -> Result<Self, &'static str> {
            let pos = (q / 100.0) * (len as f64 - 1.0);
            let low = pos.floor() as usize;
            let high = pos.ceil() as usize;
            let weight_high = pos - low as f64;
            let weight_low = 1.0 - weight_high;
            let convert = |w: f64| {
                E::from_f64(w).ok_or("Cannot convert percentile weight into element type")
            };
            Ok(Self {
                low,
                high,
                w_low: convert(weight_low)?,
                w_high: convert(weight_high)?,
            })
        }

        /// Sorts `values` (whose length must match the one given to `new`) and returns
        /// the interpolated percentile.
        fn apply(&self, values: &mut [E]) -> E {
            values.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
            if self.low == self.high {
                values[self.low]
            } else {
                values[self.low] * self.w_low + values[self.high] * self.w_high
            }
        }
    }

    if array.is_empty() {
        return Err("Cannot compute percentile of an empty array");
    }
    if !(0.0..=100.0).contains(&q) {
        return Err("Percentile must be between 0 and 100");
    }
    match axis {
        Some(ax) => {
            if ax.index() >= array.ndim() {
                return Err("Axis index out of bounds");
            }
            // Every lane along `ax` has the same (non-zero) length, so the ranks and
            // weights are computed once here and conversion failures are reported as
            // errors instead of being replaced inside the closure.
            let interp = Interpolation::<T>::new(q, array.len_of(ax))?;
            let reduced = array.map_axis(ax, |lane| {
                let mut values: Vec<T> = lane.iter().copied().collect();
                interp.apply(&mut values)
            });
            let flat_result = reduced
                .to_shape((reduced.len(),))
                .map_err(|_| "Failed to reshape percentile result to 1D")?;
            Ok(flat_result.into_owned())
        }
        None => {
            let mut values: Vec<T> = array.iter().copied().collect();
            let interp = Interpolation::<T>::new(q, values.len())?;
            Ok(Array::from_elem(1, interp.apply(&mut values)))
        }
    }
}
#[cfg(test)]
mod axis_reduction_tests {
    // Tests for the N-dimensional `median` and `percentile` reductions, with and without
    // an axis.
    use super::{median, percentile};
    use ::ndarray::{array, Array2, Array3, Axis};

    /// Absolute-tolerance float comparison used by the assertions below.
    fn approx_eq(a: f64, b: f64) -> bool {
        (a - b).abs() < 1e-9
    }

    /// A (2, 3, 4) array whose element at `[i, j, k]` equals `12*i + 4*j + k`.
    fn counting_cube() -> Array3<f64> {
        Array3::from_shape_fn((2, 3, 4), |(i, j, k)| (12 * i + 4 * j + k) as f64)
    }

    #[test]
    fn test_median_2d_axis0_columns() {
        let a = array![[1.0_f64, 2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 8.0, 9.0]];
        let result = median(&a.view(), Some(Axis(0))).expect("axis median");
        assert_eq!(result.len(), 3);
        assert!(approx_eq(result[0], 4.0));
        assert!(approx_eq(result[1], 5.0));
        assert!(approx_eq(result[2], 6.0));
    }

    #[test]
    fn test_median_2d_axis1_rows() {
        let a = array![[1.0_f64, 2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 8.0, 9.0]];
        let result = median(&a.view(), Some(Axis(1))).expect("axis median");
        assert_eq!(result.len(), 3);
        assert!(approx_eq(result[0], 2.0));
        assert!(approx_eq(result[1], 5.0));
        assert!(approx_eq(result[2], 8.0));
    }

    #[test]
    fn test_median_2d_axis1_even_lane_averaging() {
        let a = array![[1.0_f64, 2.0, 3.0, 4.0], [10.0, 20.0, 30.0, 40.0]];
        let result = median(&a.view(), Some(Axis(1))).expect("axis median");
        assert_eq!(result.len(), 2);
        assert!(approx_eq(result[0], 2.5));
        assert!(approx_eq(result[1], 25.0));
    }

    #[test]
    fn test_median_3d_each_axis() {
        let a = counting_cube();

        let med0 = median(&a.view(), Some(Axis(0))).expect("axis median");
        assert_eq!(med0.len(), 12);
        for j in 0..3 {
            for k in 0..4 {
                // Lane is {v, v + 12} with v = 4j + k, so the median is v + 6.
                let expected = (4 * j + k) as f64 + 6.0;
                assert!(approx_eq(med0[4 * j + k], expected), "med0[{j},{k}]");
            }
        }

        let med1 = median(&a.view(), Some(Axis(1))).expect("axis median");
        assert_eq!(med1.len(), 8);
        for i in 0..2 {
            for k in 0..4 {
                // Lane is {v, v + 4, v + 8} with v = 12i + k, so the median is v + 4.
                let expected = (12 * i + k) as f64 + 4.0;
                assert!(approx_eq(med1[4 * i + k], expected), "med1[{i},{k}]");
            }
        }

        let med2 = median(&a.view(), Some(Axis(2))).expect("axis median");
        assert_eq!(med2.len(), 6);
        for i in 0..2 {
            for j in 0..3 {
                // Lane is {v, v + 1, v + 2, v + 3} with v = 12i + 4j, so the median is v + 1.5.
                let expected = (12 * i + 4 * j) as f64 + 1.5;
                assert!(approx_eq(med2[3 * i + j], expected), "med2[{i},{j}]");
            }
        }
    }

    #[test]
    fn test_median_no_axis() {
        let even = array![[3.0_f64, 1.0], [2.0, 4.0]];
        let result = median(&even.view(), None).expect("global median");
        assert_eq!(result.len(), 1);
        assert!(approx_eq(result[0], 2.5));

        let odd = array![[5.0_f64, 1.0, 3.0]];
        let result = median(&odd.view(), None).expect("global median");
        assert_eq!(result.len(), 1);
        assert!(approx_eq(result[0], 3.0));
    }

    #[test]
    fn test_median_rejects_empty() {
        let a = Array2::<f64>::zeros((0, 3));
        let err = median(&a.view(), None).expect_err("must reject empty input");
        assert!(err.contains("empty"));
    }

    #[test]
    fn test_median_axis_out_of_bounds() {
        let a = array![[1.0_f64, 2.0], [3.0, 4.0]];
        let err = median(&a.view(), Some(Axis(5))).expect_err("must reject out-of-bounds");
        assert!(err.contains("out of bounds"));
    }

    #[test]
    fn test_percentile_2d_axis0_q50_matches_median() {
        let a = array![[1.0_f64, 2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 8.0, 9.0]];
        let pct = percentile(&a.view(), 50.0, Some(Axis(0))).expect("axis pct");
        let med = median(&a.view(), Some(Axis(0))).expect("axis median");
        assert_eq!(pct.len(), med.len());
        for (p, m) in pct.iter().zip(med.iter()) {
            assert!(approx_eq(*p, *m));
        }
    }

    #[test]
    fn test_percentile_2d_axis1_extrema() {
        let a = array![[1.0_f64, 7.0, 5.0, 3.0], [10.0, 0.0, 2.0, 8.0]];
        let pct_min = percentile(&a.view(), 0.0, Some(Axis(1))).expect("min");
        assert_eq!(pct_min.len(), 2);
        assert!(approx_eq(pct_min[0], 1.0));
        assert!(approx_eq(pct_min[1], 0.0));
        let pct_max = percentile(&a.view(), 100.0, Some(Axis(1))).expect("max");
        assert_eq!(pct_max.len(), 2);
        assert!(approx_eq(pct_max[0], 7.0));
        assert!(approx_eq(pct_max[1], 10.0));
    }

    #[test]
    fn test_percentile_2d_axis1_quartiles() {
        let a = array![[1.0_f64, 7.0, 5.0, 3.0]];
        let q25 = percentile(&a.view(), 25.0, Some(Axis(1))).expect("q25");
        let q75 = percentile(&a.view(), 75.0, Some(Axis(1))).expect("q75");
        assert!(approx_eq(q25[0], 2.5), "got {}", q25[0]);
        assert!(approx_eq(q75[0], 5.5), "got {}", q75[0]);
    }

    #[test]
    fn test_percentile_3d_axis2() {
        let a = Array3::from_shape_fn((2, 2, 5), |(_, _, k)| k as f64);
        let pct = percentile(&a.view(), 50.0, Some(Axis(2))).expect("axis pct");
        assert_eq!(pct.len(), 4);
        for v in pct.iter() {
            assert!(approx_eq(*v, 2.0));
        }
    }

    #[test]
    fn test_percentile_no_axis_interpolates() {
        // Sorted values are 1, 2, 3, 4; position = q / 100 * 3.
        let a = array![[4.0_f64, 1.0], [3.0, 2.0]];
        let at = |q: f64| percentile(&a.view(), q, None).expect("global pct")[0];
        assert!(approx_eq(at(0.0), 1.0));
        assert!(approx_eq(at(25.0), 1.75));
        assert!(approx_eq(at(50.0), 2.5));
        assert!(approx_eq(at(100.0), 4.0));
    }

    #[test]
    fn test_percentile_invalid_q_axis() {
        let a = array![[1.0_f64, 2.0], [3.0, 4.0]];
        let err = percentile(&a.view(), -1.0, Some(Axis(0))).expect_err("must reject negative q");
        assert!(err.contains("between 0 and 100"));
        let err2 = percentile(&a.view(), 101.0, Some(Axis(1))).expect_err("must reject q > 100");
        assert!(err2.contains("between 0 and 100"));
    }

    #[test]
    fn test_percentile_axis_out_of_bounds() {
        let a = array![[1.0_f64, 2.0], [3.0, 4.0]];
        let err =
            percentile(&a.view(), 50.0, Some(Axis(5))).expect_err("must reject out-of-bounds axis");
        assert!(err.contains("out of bounds"));
    }
}