use std::cmp::{max_by, min_by};
use crate::array;
use crate::error::ActiveStorageError;
use crate::models::{self, Order, ReductionAxes};
use crate::operation::{Element, NumOperation};
use crate::types::Missing;
use axum::body::Bytes;
use ndarray::{ArrayView, Axis};
use zerocopy::AsBytes;
/// Builds a predicate that returns `true` for values that are *not* missing according to
/// `missing`, suitable for `Iterator::filter`.
///
/// * `MissingValue(v)`: valid when the value differs from `v`.
/// * `MissingValues(vs)`: valid when the value is not contained in `vs`.
/// * `ValidMin(lo)` / `ValidMax(hi)`: valid when the value is `>= lo` / `<= hi`.
/// * `ValidRange(lo, hi)`: valid when `lo <= value <= hi` (both bounds inclusive).
///
/// The returned closure borrows `missing`, so it cannot outlive it.
fn missing_filter<'a, T: Element>(missing: &'a Missing<T>) -> Box<dyn Fn(&T) -> bool + 'a> {
    match missing {
        Missing::MissingValue(sentinel) => Box::new(move |candidate: &T| candidate != sentinel),
        Missing::MissingValues(sentinels) => {
            Box::new(move |candidate: &T| !sentinels.contains(candidate))
        }
        Missing::ValidMin(lower) => Box::new(move |candidate: &T| candidate >= lower),
        Missing::ValidMax(upper) => Box::new(move |candidate: &T| candidate <= upper),
        Missing::ValidRange(lower, upper) => {
            Box::new(move |candidate: &T| candidate >= lower && candidate <= upper)
        }
    }
}
/// Returns how many elements of `array` are not missing according to `missing`.
fn count_non_missing<T: Element>(
    array: &ArrayView<T, ndarray::Dim<ndarray::IxDynImpl>>,
    missing: &Missing<T>,
) -> usize {
    let is_valid = missing_filter(missing);
    array.iter().filter(|&value| is_valid(value)).count()
}
/// Counts the non-missing elements of `array` over each axis listed in `axes`.
///
/// Returns `(counts, shape)`: the per-output-element counts, flattened in iteration order, and
/// the shape of the reduced array.
///
/// * With no axes nothing is reduced: each element maps to 1, or to 0 when it is missing.
/// * Otherwise the first axis is folded away by counting, and every later axis by summing the
///   partial counts. The index adjustment `axis - removed - 1` assumes `axes` is sorted in
///   ascending order without duplicates (presumably validated by the caller; TODO confirm).
///
/// # Panics
///
/// Panics if an axis is out of range for `array` (via `fold_axis`).
fn count_array_multi_axis<T: Element>(
    array: ndarray::ArrayView<T, ndarray::IxDyn>,
    axes: &[usize],
    missing: Option<Missing<T>>,
) -> (Vec<i64>, Vec<usize>) {
    // Contribution of a single element: 0 if flagged as missing, 1 otherwise.
    let weight = |val: &T| -> i64 {
        match &missing {
            Some(m) if m.is_missing(val) => 0,
            _ => 1,
        }
    };
    let counted = match axes.split_first() {
        None => array.map(weight),
        Some((first_axis, later_axes)) => {
            let mut partial = array
                .fold_axis(Axis(*first_axis), 0_i64, |acc, val| acc + weight(val))
                .into_dyn();
            for (removed, axis) in later_axes.iter().enumerate() {
                // `removed + 1` axes, all before this one, have already been folded away.
                partial = partial
                    .fold_axis(Axis(axis - removed - 1), 0_i64, |acc, count| acc + count)
                    .into_dyn();
            }
            partial
        }
    };
    let shape = counted.shape().to_vec();
    (counted.iter().copied().collect(), shape)
}
/// Operation that counts the elements of the selected data that are not missing.
pub struct Count {}

impl NumOperation for Count {
    /// Counts the non-missing elements of the selection, over all of it or along the requested
    /// axes.
    ///
    /// The response always has dtype `Int64`. Its body holds the count(s) as native-endian
    /// `i64` bytes, and the same values are returned as the response counts.
    ///
    /// # Errors
    ///
    /// Propagates failures from building the array, from converting the missing-value
    /// description to `T`, and from converting an overall count to `i64`.
    fn execute_t<T: Element>(
        request_data: &models::RequestData,
        mut data: Vec<u8>,
    ) -> Result<models::Response, ActiveStorageError> {
        let array = array::build_array::<T>(request_data, &mut data)?;
        let selection_info = array::build_slice_info::<T>(&request_data.selection, array.shape());
        let selected = array.slice(selection_info);
        let typed_missing: Option<Missing<T>> = match &request_data.missing {
            Some(missing) => Some(Missing::try_from(missing)?),
            None => None,
        };
        let (counts, shape) = match &request_data.axis {
            ReductionAxes::All => {
                let valid = match &typed_missing {
                    Some(missing) => count_non_missing(&selected, missing),
                    None => selected.len(),
                };
                (vec![i64::try_from(valid)?], Vec::new())
            }
            ReductionAxes::One(axis) => {
                count_array_multi_axis(selected.view(), &[*axis], typed_missing)
            }
            ReductionAxes::Multi(axes) => {
                count_array_multi_axis(selected.view(), axes, typed_missing)
            }
        };
        // The in-memory bytes of the i64 counts are their native-endian encoding.
        let body = Bytes::copy_from_slice(counts.as_bytes());
        Ok(models::Response::new(
            body,
            models::DType::Int64,
            shape,
            counts,
        ))
    }
}
/// Operation that returns the maximum of the selected data, ignoring missing values.
pub struct Max {}

/// Comparison function used with [`max_by`] to order two elements.
///
/// # Panics
///
/// Panics with "unexpected undefined order error for max" when the two values have no
/// defined order (`partial_cmp` returns `None`, e.g. when either is a float NaN).
fn max_element_pairwise<T: Element>(x: &&T, y: &&T) -> std::cmp::Ordering {
    x.partial_cmp(y)
        .unwrap_or_else(|| panic!("unexpected undefined order error for max"))
}
/// "Reduces" over no axes: pairs every element of `array` with a count, without combining
/// any elements.
///
/// The count is 0 for an element flagged as missing and 1 otherwise. The element value itself
/// is kept either way. When `order` is `Some(Order::F)`, the mapping is applied to the
/// transposed view. Iterating the result therefore visits the elements in column-major order
/// of `array`, and the returned array has the transposed shape.
fn reduction_over_zero_axes<T: Element>(
    array: &ndarray::ArrayView<T, ndarray::IxDyn>,
    missing: Option<Missing<T>>,
    order: &Option<Order>,
) -> ndarray::ArrayBase<ndarray::OwnedRepr<(T, i64)>, ndarray::IxDyn> {
    let pair_with_count = |val: &T| -> (T, i64) {
        let flagged = matches!(&missing, Some(m) if m.is_missing(val));
        (*val, if flagged { 0 } else { 1 })
    };
    if let Some(Order::F) = order {
        array.t().map(pair_with_count)
    } else {
        array.map(pair_with_count)
    }
}
/// Computes the maximum and the number of non-missing elements of `array` over `axes`.
///
/// Returns `(maxes, counts, shape)`: one maximum and one count per output element, flattened
/// in iteration order, plus the shape of the reduced array.
///
/// * With no axes nothing is reduced. Every element is returned with count 1, or 0 if it is
///   missing (see `reduction_over_zero_axes`), and `shape` is the input shape.
/// * Otherwise each axis in `axes` is folded away in turn. The index adjustment `axis - n - 1`
///   assumes `axes` is sorted in ascending order without duplicates (presumably validated by
///   the caller; TODO confirm).
///
/// Output elements whose lane holds only missing values have count 0 and contain the
/// placeholder `T::min_value()`.
///
/// # Panics
///
/// Panics if an axis is out of range for `array`, or if two non-missing values being compared
/// have no defined order (e.g. NaN).
fn max_array_multi_axis<T: Element>(
    array: ndarray::ArrayView<T, ndarray::IxDyn>,
    axes: &[usize],
    missing: Option<Missing<T>>,
    order: &Option<Order>,
) -> (Vec<T>, Vec<i64>, Vec<usize>) {
    let (result, shape) = match axes.first() {
        None => {
            let result = reduction_over_zero_axes(&array, missing, order);
            (result, array.shape().to_owned())
        }
        Some(first_axis) => {
            // Placeholder for lanes with no valid data. It is never compared with real values:
            // while the count is still 0, the first valid value replaces it outright. Comparing
            // against it would return the placeholder for data lying entirely below it, e.g.
            // all `-inf` floats when `min_value()` is a finite bound.
            let init = T::min_value();
            let mut result = array
                .fold_axis(Axis(*first_axis), (init, 0_i64), |(running_max, count), val| {
                    if matches!(&missing, Some(m) if m.is_missing(val)) {
                        (*running_max, *count)
                    } else if *count == 0 {
                        (*val, 1)
                    } else {
                        (*max_by(running_max, val, max_element_pairwise), count + 1)
                    }
                })
                .into_dyn();
            if let Some(remaining_axes) = axes.get(1..) {
                for (n, axis) in remaining_axes.iter().enumerate() {
                    // `n + 1` axes, all before this one, have already been folded away.
                    result = result
                        .fold_axis(
                            Axis(axis - n - 1),
                            (init, 0_i64),
                            |(global_max, total_count), (running_max, count)| {
                                let new_max = if *count == 0 {
                                    // This partial result saw no valid data; ignore its value.
                                    *global_max
                                } else if *total_count == 0 {
                                    *running_max
                                } else {
                                    *max_by(global_max, running_max, max_element_pairwise)
                                };
                                (new_max, total_count + count)
                            },
                        )
                        .into_dyn();
                }
            }
            let shape = result.shape().to_owned();
            (result, shape)
        }
    };
    let maxes = result.iter().map(|(max, _)| *max).collect::<Vec<T>>();
    let counts = result.iter().map(|(_, count)| *count).collect::<Vec<i64>>();
    (maxes, counts, shape)
}
impl NumOperation for Max {
    /// Computes the maximum of the selected data, skipping missing values.
    ///
    /// The body holds the maximum(s) as raw `T` bytes and the response keeps
    /// `request_data.dtype`. The response also carries the reduced shape and, per output
    /// element, the number of non-missing values that contributed to it. If every value is
    /// missing, the overall maximum is the placeholder `T::min_value()` with count 0.
    ///
    /// # Errors
    ///
    /// Propagates failures from building the array and from converting the missing-value
    /// description to `T`.
    ///
    /// # Panics
    ///
    /// Panics if two non-missing values being compared have no defined order (e.g. NaN).
    fn execute_t<T: Element>(
        request_data: &models::RequestData,
        mut data: Vec<u8>,
    ) -> Result<models::Response, ActiveStorageError> {
        let array = array::build_array::<T>(request_data, &mut data)?;
        let slice_info = array::build_slice_info::<T>(&request_data.selection, array.shape());
        let sliced = array.slice(slice_info);
        let typed_missing: Option<Missing<T>> = if let Some(missing) = &request_data.missing {
            let m = Missing::try_from(missing)?;
            Some(m)
        } else {
            None
        };
        let (body, counts, shape) = match &request_data.axis {
            ReductionAxes::One(axis) => {
                let (maxes, counts, shape) = max_array_multi_axis(
                    sliced.view(),
                    &[*axis],
                    typed_missing,
                    &request_data.order,
                );
                let body = Bytes::copy_from_slice(maxes.as_bytes());
                (body, counts, shape)
            }
            ReductionAxes::Multi(axes) => {
                let (maxes, counts, shape) =
                    max_array_multi_axis(sliced, axes, typed_missing, &request_data.order);
                let body = Bytes::copy_from_slice(maxes.as_bytes());
                (body, counts, shape)
            }
            ReductionAxes::All => {
                // The placeholder is only reported when no valid value exists (count 0). The
                // first valid value replaces it rather than being compared with it, so e.g.
                // all `-inf` float data yields `-inf` and not a finite `min_value()`.
                let (max, count) =
                    sliced.fold((T::min_value(), 0_i64), |(running_max, count), val| {
                        if matches!(&typed_missing, Some(m) if m.is_missing(val)) {
                            (running_max, count)
                        } else if count == 0 {
                            (*val, 1)
                        } else {
                            (*max_by(&running_max, val, max_element_pairwise), count + 1)
                        }
                    });
                let body = Bytes::copy_from_slice(max.as_bytes());
                (body, vec![count], vec![])
            }
        };
        Ok(models::Response::new(
            body,
            request_data.dtype,
            shape,
            counts,
        ))
    }
}
/// Operation that returns the minimum of the selected data, ignoring missing values.
pub struct Min {}

/// Comparison function used with [`min_by`] to order two elements.
///
/// # Panics
///
/// Panics with "unexpected undefined order error for min" when the two values have no
/// defined order (`partial_cmp` returns `None`, e.g. when either is a float NaN).
fn min_element_pairwise<T: Element>(x: &&T, y: &&T) -> std::cmp::Ordering {
    match x.partial_cmp(y) {
        Some(ordering) => ordering,
        None => panic!("unexpected undefined order error for min"),
    }
}
/// Computes the minimum and the number of non-missing elements of `array` over `axes`.
///
/// Returns `(mins, counts, shape)`: one minimum and one count per output element, flattened
/// in iteration order, plus the shape of the reduced array.
///
/// * With no axes nothing is reduced. Every element is returned with count 1, or 0 if it is
///   missing (see `reduction_over_zero_axes`), and `shape` is the input shape.
/// * Otherwise each axis in `axes` is folded away in turn. The index adjustment `axis - n - 1`
///   assumes `axes` is sorted in ascending order without duplicates (presumably validated by
///   the caller; TODO confirm).
///
/// Output elements whose lane holds only missing values have count 0 and contain the
/// placeholder `T::max_value()`.
///
/// # Panics
///
/// Panics if an axis is out of range for `array`, or if two non-missing values being compared
/// have no defined order (e.g. NaN).
fn min_array_multi_axis<T: Element>(
    array: ndarray::ArrayView<T, ndarray::IxDyn>,
    axes: &[usize],
    missing: Option<Missing<T>>,
    order: &Option<Order>,
) -> (Vec<T>, Vec<i64>, Vec<usize>) {
    let (result, shape) = match axes.first() {
        None => {
            let result = reduction_over_zero_axes(&array, missing, order);
            (result, array.shape().to_owned())
        }
        Some(first_axis) => {
            // Placeholder for lanes with no valid data. It is never compared with real values:
            // while the count is still 0, the first valid value replaces it outright. Comparing
            // against it would return the placeholder for data lying entirely above it, e.g.
            // all `+inf` floats when `max_value()` is a finite bound.
            let init = T::max_value();
            let mut result = array
                .fold_axis(Axis(*first_axis), (init, 0_i64), |(running_min, count), val| {
                    if matches!(&missing, Some(m) if m.is_missing(val)) {
                        (*running_min, *count)
                    } else if *count == 0 {
                        (*val, 1)
                    } else {
                        (*min_by(running_min, val, min_element_pairwise), count + 1)
                    }
                })
                .into_dyn();
            if let Some(remaining_axes) = axes.get(1..) {
                for (n, axis) in remaining_axes.iter().enumerate() {
                    // `n + 1` axes, all before this one, have already been folded away.
                    result = result
                        .fold_axis(
                            Axis(axis - n - 1),
                            (init, 0_i64),
                            |(global_min, total_count), (running_min, count)| {
                                let new_min = if *count == 0 {
                                    // This partial result saw no valid data; ignore its value.
                                    *global_min
                                } else if *total_count == 0 {
                                    *running_min
                                } else {
                                    *min_by(global_min, running_min, min_element_pairwise)
                                };
                                (new_min, total_count + count)
                            },
                        )
                        .into_dyn();
                }
            }
            let shape = result.shape().to_owned();
            (result, shape)
        }
    };
    let mins = result.iter().map(|(min, _)| *min).collect::<Vec<T>>();
    let counts = result.iter().map(|(_, count)| *count).collect::<Vec<i64>>();
    (mins, counts, shape)
}
impl NumOperation for Min {
    /// Computes the minimum of the selected data, skipping missing values.
    ///
    /// The body holds the minimum(s) as raw `T` bytes and the response keeps
    /// `request_data.dtype`. The response also carries the reduced shape and, per output
    /// element, the number of non-missing values that contributed to it. If every value is
    /// missing, the overall minimum is the placeholder `T::max_value()` with count 0.
    ///
    /// # Errors
    ///
    /// Propagates failures from building the array and from converting the missing-value
    /// description to `T`.
    ///
    /// # Panics
    ///
    /// Panics if two non-missing values being compared have no defined order (e.g. NaN).
    fn execute_t<T: Element>(
        request_data: &models::RequestData,
        mut data: Vec<u8>,
    ) -> Result<models::Response, ActiveStorageError> {
        let array = array::build_array::<T>(request_data, &mut data)?;
        let slice_info = array::build_slice_info::<T>(&request_data.selection, array.shape());
        let sliced = array.slice(slice_info);
        let typed_missing: Option<Missing<T>> = if let Some(missing) = &request_data.missing {
            let m = Missing::try_from(missing)?;
            Some(m)
        } else {
            None
        };
        let (body, counts, shape) = match &request_data.axis {
            ReductionAxes::One(axis) => {
                let (mins, counts, shape) = min_array_multi_axis(
                    sliced.view(),
                    &[*axis],
                    typed_missing,
                    &request_data.order,
                );
                let body = Bytes::copy_from_slice(mins.as_bytes());
                (body, counts, shape)
            }
            ReductionAxes::Multi(axes) => {
                let (mins, counts, shape) =
                    min_array_multi_axis(sliced, axes, typed_missing, &request_data.order);
                let body = Bytes::copy_from_slice(mins.as_bytes());
                (body, counts, shape)
            }
            ReductionAxes::All => {
                // The placeholder is only reported when no valid value exists (count 0). The
                // first valid value replaces it rather than being compared with it, so e.g.
                // all `+inf` float data yields `+inf` and not a finite `max_value()`.
                let (min, count) =
                    sliced.fold((T::max_value(), 0_i64), |(running_min, count), val| {
                        if matches!(&typed_missing, Some(m) if m.is_missing(val)) {
                            (running_min, count)
                        } else if count == 0 {
                            (*val, 1)
                        } else {
                            (*min_by(&running_min, val, min_element_pairwise), count + 1)
                        }
                    });
                let body = Bytes::copy_from_slice(min.as_bytes());
                (body, vec![count], vec![])
            }
        };
        Ok(models::Response::new(
            body,
            request_data.dtype,
            shape,
            counts,
        ))
    }
}
/// Operation that returns the selected data itself, with its shape and the number of
/// non-missing elements.
pub struct Select {}

impl NumOperation for Select {
    /// Returns the selected elements as raw `T` bytes in `request_data.dtype`.
    ///
    /// The response shape is the shape of the selection. The single count is the number of
    /// selected elements not flagged as missing; missing elements are still included in the
    /// body.
    ///
    /// # Errors
    ///
    /// Propagates failures from building the array, from converting the missing-value
    /// description to `T`, and from converting the count to `i64`.
    fn execute_t<T: Element>(
        request_data: &models::RequestData,
        mut data: Vec<u8>,
    ) -> Result<models::Response, ActiveStorageError> {
        let array = array::build_array::<T>(request_data, &mut data)?;
        let selection_info = array::build_slice_info::<T>(&request_data.selection, array.shape());
        let selected = array.slice(selection_info);
        let valid_count = match &request_data.missing {
            Some(missing) => count_non_missing(&selected, &Missing::<T>::try_from(missing)?),
            None => selected.len(),
        };
        let valid_count = i64::try_from(valid_count)?;
        let shape = selected.shape().to_vec();
        // A non-standard layout presumably means `build_array` produced a column-major array
        // (TODO confirm). In that case the elements are emitted by iterating the transposed
        // selection.
        let values: Vec<T> = if array.is_standard_layout() {
            selected.iter().copied().collect()
        } else {
            selected.t().iter().copied().collect()
        };
        Ok(models::Response::new(
            Bytes::copy_from_slice(values.as_bytes()),
            request_data.dtype,
            shape,
            vec![valid_count],
        ))
    }
}
/// Operation that returns the sum of the selected data, ignoring missing values.
pub struct Sum {}

/// Computes the sum and the number of non-missing elements of `array` over `axes`.
///
/// Returns `(sums, counts, shape)`: one sum and one count per output element, flattened in
/// iteration order, plus the shape of the reduced array.
///
/// * With no axes nothing is reduced. Every element is returned with count 1, or 0 if it is
///   missing (see `reduction_over_zero_axes`), and `shape` is the input shape.
/// * Otherwise each axis in `axes` is folded away in turn, starting from `T::zero()`. The index
///   adjustment `axis - removed - 1` assumes `axes` is sorted in ascending order without
///   duplicates (presumably validated by the caller; TODO confirm).
///
/// # Panics
///
/// Panics if an axis is out of range for `array` (via `fold_axis`).
fn sum_array_multi_axis<T: Element>(
    array: ndarray::ArrayView<T, ndarray::IxDyn>,
    axes: &[usize],
    missing: Option<Missing<T>>,
    order: &Option<Order>,
) -> (Vec<T>, Vec<i64>, Vec<usize>) {
    let (reduced, shape) = match axes.split_first() {
        // The reported shape is the input's, even when `order` transposed the mapped result.
        None => (
            reduction_over_zero_axes(&array, missing, order),
            array.shape().to_owned(),
        ),
        Some((first_axis, later_axes)) => {
            let mut partial = array
                .fold_axis(Axis(*first_axis), (T::zero(), 0_i64), |(sum, count), val| {
                    if matches!(&missing, Some(m) if m.is_missing(val)) {
                        (*sum, *count)
                    } else {
                        (*sum + *val, count + 1)
                    }
                })
                .into_dyn();
            for (removed, axis) in later_axes.iter().enumerate() {
                // `removed + 1` axes, all before this one, have already been folded away.
                partial = partial
                    .fold_axis(
                        Axis(axis - removed - 1),
                        (T::zero(), 0_i64),
                        |(total_sum, total_count), (sum, count)| {
                            (*total_sum + *sum, total_count + count)
                        },
                    )
                    .into_dyn();
            }
            let partial_shape = partial.shape().to_owned();
            (partial, partial_shape)
        }
    };
    let (sums, counts): (Vec<T>, Vec<i64>) = reduced.iter().copied().unzip();
    (sums, counts, shape)
}
impl NumOperation for Sum {
    /// Sums the selected data, skipping missing values.
    ///
    /// The body holds the sum(s) as raw `T` bytes and the response keeps
    /// `request_data.dtype`. The response also carries the reduced shape and, per output
    /// element, the number of non-missing values that contributed to it.
    ///
    /// # Errors
    ///
    /// Propagates failures from building the array and from converting the missing-value
    /// description to `T`.
    fn execute_t<T: Element>(
        request_data: &models::RequestData,
        mut data: Vec<u8>,
    ) -> Result<models::Response, ActiveStorageError> {
        let array = array::build_array::<T>(request_data, &mut data)?;
        let selection_info = array::build_slice_info::<T>(&request_data.selection, array.shape());
        let selected = array.slice(selection_info);
        let typed_missing: Option<Missing<T>> = match &request_data.missing {
            Some(missing) => Some(Missing::try_from(missing)?),
            None => None,
        };
        let order = &request_data.order;
        let (sums, counts, shape) = match &request_data.axis {
            ReductionAxes::One(axis) => {
                sum_array_multi_axis(selected.view(), &[*axis], typed_missing, order)
            }
            ReductionAxes::Multi(axes) => {
                sum_array_multi_axis(selected.view(), axes, typed_missing, order)
            }
            ReductionAxes::All => {
                let (total, valid) = selected.fold((T::zero(), 0_i64), |(acc, n), val| {
                    if matches!(&typed_missing, Some(m) if m.is_missing(val)) {
                        (acc, n)
                    } else {
                        (acc + *val, n + 1)
                    }
                });
                (vec![total], vec![valid], Vec::new())
            }
        };
        Ok(models::Response::new(
            Bytes::copy_from_slice(sums.as_bytes()),
            request_data.dtype,
            shape,
            counts,
        ))
    }
}
#[cfg(test)]
mod tests {
use super::*;
use crate::models::ReductionAxes;
use crate::operation::Operation;
use crate::test_utils;
use crate::types::DValue;
#[test]
fn count_i32_1d() {
let request_data = test_utils::get_test_request_data();
let data = vec![1, 2, 3, 4, 5, 6, 7, 8];
let response = Count::execute(&request_data, data).unwrap();
let expected = vec![2];
assert_eq!(expected.as_bytes(), response.body);
assert_eq!(8, response.body.len()); assert_eq!(models::DType::Int64, response.dtype);
assert_eq!(vec![0; 0], response.shape);
assert_eq!(expected, response.count);
}
#[test]
fn count_u32_1d_missing_value() {
let mut request_data = test_utils::get_test_request_data();
request_data.dtype = models::DType::Uint32;
request_data.missing = Some(Missing::MissingValue(0x04030201.into()));
let data = vec![1, 2, 3, 4, 5, 6, 7, 8];
let response = Count::execute(&request_data, data).unwrap();
let expected = vec![1];
assert_eq!(expected.as_bytes(), response.body);
assert_eq!(8, response.body.len()); assert_eq!(models::DType::Int64, response.dtype);
assert_eq!(vec![0; 0], response.shape);
assert_eq!(expected, response.count);
}
#[test]
fn max_i64_1d() {
let mut request_data = test_utils::get_test_request_data();
request_data.dtype = models::DType::Int64;
let data = vec![1, 2, 3, 4, 5, 6, 7, 8];
let response = Max::execute(&request_data, data).unwrap();
let expected: i64 = 0x0807060504030201;
assert_eq!(expected.as_bytes(), response.body);
assert_eq!(8, response.body.len());
assert_eq!(models::DType::Int64, response.dtype);
assert_eq!(vec![0; 0], response.shape);
assert_eq!(vec![1], response.count);
}
#[test]
fn max_i64_1d_missing_values() {
let mut request_data = test_utils::get_test_request_data();
request_data.dtype = models::DType::Int64;
request_data.missing = Some(Missing::MissingValues(vec![0x0807060504030201_i64.into()]));
let data = vec![1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16];
let response = Max::execute(&request_data, data).unwrap();
let expected: i64 = 0x100f0e0d0c0b0a09;
assert_eq!(expected.as_bytes(), response.body);
assert_eq!(8, response.body.len());
assert_eq!(models::DType::Int64, response.dtype);
assert_eq!(vec![0; 0], response.shape);
assert_eq!(vec![1], response.count);
}
#[test]
fn max_f32_1d_infinity() {
let mut request_data = test_utils::get_test_request_data();
request_data.dtype = models::DType::Float32;
let floats = [1.0, f32::INFINITY];
let data = floats.as_bytes();
let response = Max::execute(&request_data, data.into()).unwrap();
let expected = f32::INFINITY;
assert_eq!(expected.as_bytes(), response.body);
assert_eq!(4, response.body.len());
assert_eq!(models::DType::Float32, response.dtype);
assert_eq!(vec![0; 0], response.shape);
assert_eq!(vec![2], response.count);
}
#[test]
fn max_f32_1d_infinity_first() {
let mut request_data = test_utils::get_test_request_data();
request_data.dtype = models::DType::Float32;
let floats = [f32::INFINITY, 1.0];
let data = floats.as_bytes();
let response = Max::execute(&request_data, data.into()).unwrap();
let expected = f32::INFINITY;
assert_eq!(expected.as_bytes(), response.body);
assert_eq!(4, response.body.len());
assert_eq!(models::DType::Float32, response.dtype);
assert_eq!(vec![0; 0], response.shape);
assert_eq!(vec![2], response.count);
}
#[test]
fn min_u64_1d() {
let mut request_data = test_utils::get_test_request_data();
request_data.dtype = models::DType::Uint64;
let data = vec![1, 2, 3, 4, 5, 6, 7, 8];
let response = Min::execute(&request_data, data).unwrap();
let expected: u64 = 0x0807060504030201;
assert_eq!(expected.as_bytes(), response.body);
assert_eq!(8, response.body.len());
assert_eq!(models::DType::Uint64, response.dtype);
assert_eq!(vec![0; 0], response.shape);
assert_eq!(vec![1], response.count);
}
#[test]
fn min_i64_1d_valid_min() {
let mut request_data = test_utils::get_test_request_data();
request_data.dtype = models::DType::Int64;
request_data.missing = Some(Missing::ValidMin(0x0807060504030202_i64.into()));
let data = vec![1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16];
let response = Min::execute(&request_data, data).unwrap();
let expected: i64 = 0x100f0e0d0c0b0a09;
assert_eq!(expected.as_bytes(), response.body);
assert_eq!(8, response.body.len());
assert_eq!(models::DType::Int64, response.dtype);
assert_eq!(vec![0; 0], response.shape);
assert_eq!(vec![1], response.count);
}
#[test]
fn min_f32_1d_infinity() {
let mut request_data = test_utils::get_test_request_data();
request_data.dtype = models::DType::Float32;
let floats = [1.0, f32::INFINITY];
let data = floats.as_bytes();
let response = Min::execute(&request_data, data.into()).unwrap();
let expected = 1.0_f32;
assert_eq!(expected.as_bytes(), response.body);
assert_eq!(4, response.body.len());
assert_eq!(models::DType::Float32, response.dtype);
assert_eq!(vec![0; 0], response.shape);
assert_eq!(vec![2], response.count);
}
#[test]
fn min_f32_1d_infinity_first() {
let mut request_data = test_utils::get_test_request_data();
request_data.dtype = models::DType::Float32;
let floats = [f32::INFINITY, 1.0];
let data = floats.as_bytes();
let response = Min::execute(&request_data, data.into()).unwrap();
let expected = 1.0_f32;
assert_eq!(expected.as_bytes(), response.body);
assert_eq!(4, response.body.len());
assert_eq!(models::DType::Float32, response.dtype);
assert_eq!(vec![0; 0], response.shape);
assert_eq!(vec![2], response.count);
}
#[test]
#[should_panic(expected = "unexpected undefined order error for min")]
fn min_f32_1d_nan() {
let mut request_data = test_utils::get_test_request_data();
request_data.dtype = models::DType::Float32;
let floats = [1.0, f32::NAN];
let data = floats.as_bytes();
let response = Min::execute(&request_data, data.into()).unwrap();
let expected = 1.0_f32;
assert_eq!(expected.as_bytes(), response.body);
assert_eq!(4, response.body.len());
assert_eq!(models::DType::Float32, response.dtype);
assert_eq!(vec![0; 0], response.shape);
assert_eq!(vec![2], response.count);
}
#[test]
#[should_panic(expected = "unexpected undefined order error for min")]
fn min_f32_1d_nan_first() {
let mut request_data = test_utils::get_test_request_data();
request_data.dtype = models::DType::Float32;
let floats = [f32::NAN, 1.0];
let data = floats.as_bytes();
let response = Min::execute(&request_data, data.into()).unwrap();
let expected = 1.0_f32;
assert_eq!(expected.as_bytes(), response.body);
assert_eq!(4, response.body.len());
assert_eq!(models::DType::Float32, response.dtype);
assert_eq!(vec![0; 0], response.shape);
assert_eq!(vec![2], response.count);
}
#[test]
#[should_panic(expected = "unexpected undefined order error for min")]
fn min_f32_1d_nan_missing_value() {
let mut request_data = test_utils::get_test_request_data();
request_data.dtype = models::DType::Float32;
request_data.missing = Some(Missing::MissingValue(DValue::from_f64(42.0).unwrap()));
let floats = [1.0, f32::NAN];
let data = floats.as_bytes();
let response = Min::execute(&request_data, data.into()).unwrap();
let expected = 1.0_f32;
assert_eq!(expected.as_bytes(), response.body);
assert_eq!(4, response.body.len());
assert_eq!(models::DType::Float32, response.dtype);
assert_eq!(vec![0; 0], response.shape);
assert_eq!(vec![2], response.count);
}
#[test]
#[should_panic(expected = "unexpected undefined order error for min")]
fn min_f32_1d_nan_first_missing_value() {
let mut request_data = test_utils::get_test_request_data();
request_data.dtype = models::DType::Float32;
request_data.missing = Some(Missing::MissingValue(DValue::from_f64(42.0).unwrap()));
let floats = [f32::NAN, 1.0];
let data = floats.as_bytes();
let response = Min::execute(&request_data, data.into()).unwrap();
let expected = f32::NAN; assert_eq!(expected.as_bytes(), response.body);
assert_eq!(4, response.body.len());
assert_eq!(models::DType::Float32, response.dtype);
assert_eq!(vec![0; 0], response.shape);
assert_eq!(vec![2], response.count);
}
#[test]
fn select_f32_1d() {
let mut request_data = test_utils::get_test_request_data();
request_data.dtype = models::DType::Float32;
let data = vec![1, 2, 3, 4, 5, 6, 7, 8];
let response = Select::execute(&request_data, data).unwrap();
let expected: [u8; 8] = [1, 2, 3, 4, 5, 6, 7, 8];
assert_eq!(expected.as_bytes(), response.body);
assert_eq!(8, response.body.len());
assert_eq!(models::DType::Float32, response.dtype);
assert_eq!(vec![2], response.shape);
assert_eq!(vec![2], response.count);
}
#[test]
fn select_f32_1d_1ax() {
let mut request_data = test_utils::get_test_request_data();
request_data.dtype = models::DType::Float32;
request_data.axis = ReductionAxes::Multi(vec![0]);
let data = vec![1, 2, 3, 4, 5, 6, 7, 8];
let response = Select::execute(&request_data, data).unwrap();
let expected: [u8; 8] = [1, 2, 3, 4, 5, 6, 7, 8];
assert_eq!(expected.as_bytes(), response.body);
assert_eq!(8, response.body.len());
assert_eq!(models::DType::Float32, response.dtype);
assert_eq!(vec![2], response.shape);
assert_eq!(vec![2], response.count);
}
#[test]
fn select_f64_2d() {
let mut request_data = test_utils::get_test_request_data();
request_data.dtype = models::DType::Float64;
request_data.shape = Some(vec![2, 1]);
let data = vec![1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16];
let response = Select::execute(&request_data, data).unwrap();
let expected: [u8; 16] = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16];
assert_eq!(expected.as_bytes(), response.body);
assert_eq!(16, response.body.len());
assert_eq!(models::DType::Float64, response.dtype);
assert_eq!(vec![2, 1], response.shape);
assert_eq!(vec![2], response.count);
}
#[test]
fn select_f32_2d_with_selection() {
let mut request_data = test_utils::get_test_request_data();
request_data.dtype = models::DType::Float32;
request_data.shape = Some(vec![2, 2]);
request_data.selection = Some(vec![
models::Slice::new(0, 2, 1),
models::Slice::new(1, 2, 1),
]);
let data = vec![1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16];
let response = Select::execute(&request_data, data).unwrap();
let expected: [u8; 8] = [5, 6, 7, 8, 13, 14, 15, 16];
assert_eq!(expected.as_bytes(), response.body);
assert_eq!(8, response.body.len());
assert_eq!(models::DType::Float32, response.dtype);
assert_eq!(vec![2, 1], response.shape);
assert_eq!(vec![2], response.count);
}
#[test]
fn sum_u32_1d() {
let mut request_data = test_utils::get_test_request_data();
request_data.dtype = models::DType::Uint32;
let data = vec![1, 2, 3, 4, 5, 6, 7, 8];
let response = Sum::execute(&request_data, data).unwrap();
let expected: u32 = 0x04030201 + 0x08070605;
assert_eq!(expected.as_bytes(), response.body);
assert_eq!(4, response.body.len());
assert_eq!(models::DType::Uint32, response.dtype);
assert_eq!(vec![0; 0], response.shape);
assert_eq!(vec![2], response.count);
}
#[test]
fn sum_u32_1d_valid_max() {
let mut request_data = test_utils::get_test_request_data();
request_data.dtype = models::DType::Uint32;
request_data.missing = Some(Missing::ValidMax((0x08070605 - 1).into()));
let data = vec![1, 2, 3, 4, 5, 6, 7, 8];
let response = Sum::execute(&request_data, data).unwrap();
let expected: u32 = 0x04030201;
assert_eq!(expected.as_bytes(), response.body);
assert_eq!(4, response.body.len());
assert_eq!(models::DType::Uint32, response.dtype);
assert_eq!(vec![0; 0], response.shape);
assert_eq!(vec![1], response.count);
}
#[test]
fn sum_f32_1d_infinity() {
let mut request_data = test_utils::get_test_request_data();
request_data.dtype = models::DType::Float32;
let floats = [1.0, f32::INFINITY];
let data = floats.as_bytes();
let response = Sum::execute(&request_data, data.into()).unwrap();
let expected = f32::INFINITY;
assert_eq!(expected.as_bytes(), response.body);
assert_eq!(4, response.body.len());
assert_eq!(models::DType::Float32, response.dtype);
assert_eq!(vec![0; 0], response.shape);
assert_eq!(vec![2], response.count);
}
#[test]
fn sum_f64_1d_nan() {
let mut request_data = test_utils::get_test_request_data();
request_data.dtype = models::DType::Float64;
let floats = [f64::NAN, 1.0];
let data = floats.as_bytes();
let response = Sum::execute(&request_data, data.into()).unwrap();
let expected = f64::NAN;
assert_eq!(expected.as_bytes(), response.body);
assert_eq!(8, response.body.len());
assert_eq!(models::DType::Float64, response.dtype);
assert_eq!(vec![0; 0], response.shape);
assert_eq!(vec![2], response.count);
}
fn vec_from_bytes<T: zerocopy::AsBytes + zerocopy::FromBytes + Clone>(data: &Bytes) -> Vec<T> {
let mut data = data.to_vec();
let data = data.as_mut_slice();
let layout = zerocopy::LayoutVerified::<_, [T]>::new_slice(&mut data[..]).unwrap();
layout.into_mut_slice().to_vec()
}
#[test]
fn sum_u32_1d_axis_0() {
let mut request_data = test_utils::get_test_request_data();
request_data.dtype = models::DType::Uint32;
request_data.shape = Some(vec![2, 4]);
request_data.axis = ReductionAxes::One(0);
let data: Vec<u32> = vec![1, 2, 3, 4, 5, 6, 7, 8];
let response = Sum::execute(&request_data, data.as_bytes().into()).unwrap();
let result = vec_from_bytes::<u32>(&response.body);
let arr = ndarray::Array::from_shape_vec((2, 4), data).unwrap();
let expected = arr.sum_axis(Axis(0)).to_vec();
assert_eq!(result, expected);
assert_eq!(models::DType::Uint32, response.dtype);
assert_eq!(16, response.body.len()); assert_eq!(vec![4], response.shape);
assert_eq!(vec![2, 2, 2, 2], response.count);
}
#[test]
fn sum_u32_1d_axis_1_missing() {
let mut request_data = test_utils::get_test_request_data();
request_data.dtype = models::DType::Uint32;
request_data.shape = Some(vec![2, 4]);
request_data.axis = ReductionAxes::One(1);
request_data.missing = Some(Missing::MissingValue(0.into()));
let data: Vec<u32> = vec![0, 2, 3, 4, 5, 6, 7, 8];
let response = Sum::execute(&request_data, data.as_bytes().into()).unwrap();
let result = vec_from_bytes::<u32>(&response.body);
let arr = ndarray::Array::from_shape_vec((2, 4), data).unwrap();
let expected = arr.sum_axis(Axis(1)).to_vec();
assert_eq!(result, expected);
assert_eq!(models::DType::Uint32, response.dtype);
assert_eq!(8, response.body.len()); assert_eq!(vec![2], response.shape);
assert_eq!(vec![3, 4], response.count); }
#[test]
fn sum_f64_1d_axis_1_missing() {
let mut request_data = test_utils::get_test_request_data();
request_data.dtype = models::DType::Float64;
request_data.shape = Some(vec![2, 2, 2]);
request_data.axis = ReductionAxes::One(1);
request_data.missing = Some(Missing::MissingValue(0.into()));
let data: Vec<f64> = vec![0.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0];
let response = Sum::execute(&request_data, data.as_bytes().into()).unwrap();
let result = vec_from_bytes::<f64>(&response.body);
let result = ndarray::Array::from_shape_vec((2, 2), result).unwrap();
let arr = ndarray::Array::from_shape_vec((2, 2, 2), data).unwrap();
let expected = arr.sum_axis(Axis(1));
assert_eq!(result, expected);
assert_eq!(models::DType::Float64, response.dtype);
assert_eq!(32, response.body.len()); assert_eq!(vec![2, 2], response.shape);
assert_eq!(vec![1, 2, 2, 2], response.count); }
#[test]
fn partial_cmp_behaviour() {
assert_eq!(
f64::INFINITY.partial_cmp(&1.0),
Some(std::cmp::Ordering::Greater)
);
assert_eq!(f64::NAN.partial_cmp(&1.0), None);
assert_eq!(
f64::INFINITY.partial_cmp(&f64::NEG_INFINITY),
Some(std::cmp::Ordering::Greater)
);
}
#[test]
#[should_panic(expected = "assertion failed: axis.index() < self.ndim()")]
fn sum_multi_axis_2d_wrong_axis() {
let array = ndarray::Array::from_shape_vec((2, 2), (0..4).collect())
.unwrap()
.into_dyn();
let axes = vec![2];
let _ = sum_array_multi_axis(array.view(), &axes, None, &None);
}
#[test]
fn sum_multi_axis_2d_2ax() {
let array = ndarray::Array::from_shape_vec((2, 2), (0..4).collect())
.unwrap()
.into_dyn();
let axes = vec![0, 1];
let (sum, count, shape) = sum_array_multi_axis(array.view(), &axes, None, &None);
assert_eq!(sum, vec![6]);
assert_eq!(count, vec![4]);
assert_eq!(shape, Vec::<usize>::new());
}
#[test]
fn sum_multi_axis_2d_2ax_missing() {
let array = ndarray::Array::from_shape_vec((2, 2), (0..4).collect())
.unwrap()
.into_dyn();
let axes = vec![0, 1];
let missing = Missing::MissingValue(1);
let (sum, count, shape) = sum_array_multi_axis(array.view(), &axes, Some(missing), &None);
assert_eq!(sum, vec![5]);
assert_eq!(count, vec![3]);
assert_eq!(shape, Vec::<usize>::new());
}
#[test]
fn sum_multi_axis_2d_no_ax_some_missing() {
let axes = vec![];
let missing = Some(Missing::ValidMax(2));
let arr = ndarray::Array::from_shape_vec((2, 3), (0..6).collect())
.unwrap()
.into_dyn();
let (result, counts, shape) = sum_array_multi_axis(arr.view(), &axes, missing, &None);
assert_eq!(result, arr.iter().copied().collect::<Vec<i64>>());
assert_eq!(counts, vec![1, 1, 1, 0, 0, 0]);
assert_eq!(shape, arr.shape());
}
#[test]
fn sum_multi_axis_2d_no_ax_some_missing_f_order() {
let axes = vec![];
let missing = Some(Missing::ValidMax(2));
let arr = ndarray::Array::from_shape_vec((2, 3), (0..6).collect())
.unwrap()
.into_dyn();
let (result, counts, shape) =
sum_array_multi_axis(arr.view(), &axes, missing, &Some(Order::F));
assert_eq!(result, arr.t().iter().copied().collect::<Vec<i64>>());
assert_eq!(counts, vec![1, 0, 1, 0, 1, 0]);
assert_eq!(shape, arr.shape());
}
#[test]
fn sum_multi_axis_4d_1ax() {
let array = ndarray::Array::from_shape_vec((2, 3, 2, 1), (0..12).collect())
.unwrap()
.into_dyn();
let axes = vec![2];
let (sum, count, shape) = sum_array_multi_axis(array.view(), &axes, None, &None);
assert_eq!(sum, vec![1, 5, 9, 13, 17, 21]);
assert_eq!(count, vec![2, 2, 2, 2, 2, 2]);
assert_eq!(shape, vec![2, 3, 1]);
}
#[test]
fn sum_multi_axis_4d_3ax() {
let array = ndarray::Array::from_shape_vec((2, 3, 2, 1), (0..12).collect())
.unwrap()
.into_dyn();
let axes = vec![0, 1, 3];
let (sum, count, shape) = sum_array_multi_axis(array.view(), &axes, None, &None);
assert_eq!(sum, vec![30, 36]);
assert_eq!(count, vec![6, 6]);
assert_eq!(shape, vec![2]);
}
#[test]
#[should_panic(expected = "assertion failed: axis.index() < self.ndim()")]
fn min_multi_axis_2d_wrong_axis() {
let array = ndarray::Array::from_shape_vec((2, 2), (0..4).collect())
.unwrap()
.into_dyn();
let axes = vec![2];
let _ = min_array_multi_axis(array.view(), &axes, None, &None);
}
#[test]
fn min_multi_axis_2d_2ax() {
let axes = vec![0, 1];
let missing = None;
let arr = ndarray::Array::from_shape_vec((2, 3), (0..6).collect())
.unwrap()
.into_dyn();
let (result, counts, shape) = min_array_multi_axis(arr.view(), &axes, missing, &None);
assert_eq!(result, vec![0]);
assert_eq!(counts, vec![6]);
assert_eq!(shape, Vec::<usize>::new());
}
/// With no reduction axes the values pass through in C order; under `ValidMax(2)`
/// only 0, 1 and 2 are counted.
#[test]
fn min_multi_axis_2d_no_ax_some_missing() {
    let grid = ndarray::Array::from_shape_vec((2, 3), (0..6).collect())
        .unwrap()
        .into_dyn();
    let expected: Vec<i64> = grid.iter().copied().collect();
    let no_axes = Vec::new();
    let (values, counts, dims) =
        min_array_multi_axis(grid.view(), &no_axes, Some(Missing::ValidMax(2)), &None);
    assert_eq!(values, expected);
    assert_eq!(counts, [1, 1, 1, 0, 0, 0]);
    assert_eq!(dims, grid.shape());
}
/// Same as the C-order case, but `Order::F` flattens the output column-major,
/// so the counts interleave present and missing entries.
#[test]
fn min_multi_axis_2d_no_ax_some_missing_f_order() {
    let grid = ndarray::Array::from_shape_vec((2, 3), (0..6).collect())
        .unwrap()
        .into_dyn();
    // Column-major flattening of a 2-D array equals C-order iteration of its transpose.
    let expected: Vec<i64> = grid.t().iter().copied().collect();
    let no_axes = Vec::new();
    let (values, counts, dims) = min_array_multi_axis(
        grid.view(),
        &no_axes,
        Some(Missing::ValidMax(2)),
        &Some(Order::F),
    );
    assert_eq!(values, expected);
    assert_eq!(counts, [1, 0, 1, 0, 1, 0]);
    assert_eq!(dims, grid.shape());
}
/// Per-row minimum (axis 1 of a 2x3 input) with 0 marked missing: the first row
/// ignores its 0, so its minimum is 1 over two counted values.
#[test]
fn min_multi_axis_2d_1ax_missing() {
    let grid = ndarray::Array::from_shape_vec((2, 3), (0..6).collect())
        .unwrap()
        .into_dyn();
    let along_rows = vec![1];
    let (minima, counts, dims) =
        min_array_multi_axis(grid.view(), &along_rows, Some(Missing::MissingValue(0)), &None);
    assert_eq!(minima, [1, 3]);
    assert_eq!(counts, [2, 3]);
    assert_eq!(dims, [2]);
}
/// Reducing all but axis 2 splits evens from odds; treating 1 as missing raises
/// the odd minimum to 3 and drops its count to 5.
#[test]
fn min_multi_axis_4d_3ax_missing() {
    let data = ndarray::Array::from_shape_vec((2, 3, 2, 1), (0..12).collect())
        .unwrap()
        .into_dyn();
    let reduce_over = vec![0, 1, 3];
    let (minima, counts, dims) =
        min_array_multi_axis(data.view(), &reduce_over, Some(Missing::MissingValue(1)), &None);
    assert_eq!(minima, [0, 3]);
    assert_eq!(counts, [6, 5]);
    assert_eq!(dims, [2]);
}
/// Asking for axis 2 on a 2-D input must trip the axis-bounds assertion.
#[test]
#[should_panic(expected = "assertion failed: axis.index() < self.ndim()")]
fn max_multi_axis_2d_wrong_axis() {
    let square = ndarray::Array::from_shape_vec((2, 2), (0..4).collect())
        .unwrap()
        .into_dyn();
    let out_of_range = vec![2];
    let _ = max_array_multi_axis(square.view(), &out_of_range, None, &None);
}
/// Reducing over every axis collapses the input to a single maximum with an
/// empty (0-d) shape.
#[test]
fn max_multi_axis_2d_2ax() {
    let grid = ndarray::Array::from_shape_vec((2, 3), (0..6).collect())
        .unwrap()
        .into_dyn();
    let every_axis = vec![0, 1];
    let (maximum, counts, dims) = max_array_multi_axis(grid.view(), &every_axis, None, &None);
    assert_eq!(maximum, [5]);
    assert_eq!(counts, [6]);
    assert!(dims.is_empty());
}
/// With no reduction axes the values pass through in C order; under `ValidMax(2)`
/// only 0, 1 and 2 are counted.
#[test]
fn max_multi_axis_2d_no_ax_some_missing() {
    let grid = ndarray::Array::from_shape_vec((2, 3), (0..6).collect())
        .unwrap()
        .into_dyn();
    let expected: Vec<i64> = grid.iter().copied().collect();
    let no_axes = Vec::new();
    let (values, counts, dims) =
        max_array_multi_axis(grid.view(), &no_axes, Some(Missing::ValidMax(2)), &None);
    assert_eq!(values, expected);
    assert_eq!(counts, [1, 1, 1, 0, 0, 0]);
    assert_eq!(dims, grid.shape());
}
/// Same as the C-order case, but `Order::F` flattens the output column-major,
/// so the counts interleave present and missing entries.
#[test]
fn max_multi_axis_2d_no_ax_some_missing_f_order() {
    let grid = ndarray::Array::from_shape_vec((2, 3), (0..6).collect())
        .unwrap()
        .into_dyn();
    // Column-major flattening of a 2-D array equals C-order iteration of its transpose.
    let expected: Vec<i64> = grid.t().iter().copied().collect();
    let no_axes = Vec::new();
    let (values, counts, dims) = max_array_multi_axis(
        grid.view(),
        &no_axes,
        Some(Missing::ValidMax(2)),
        &Some(Order::F),
    );
    assert_eq!(values, expected);
    assert_eq!(counts, [1, 0, 1, 0, 1, 0]);
    assert_eq!(dims, grid.shape());
}
/// Per-row maximum (axis 1 of a 2x3 input) with 0 marked missing: the maxima are
/// unaffected, but the first row counts only two values.
#[test]
fn max_multi_axis_2d_1ax_missing() {
    let grid = ndarray::Array::from_shape_vec((2, 3), (0..6).collect())
        .unwrap()
        .into_dyn();
    let along_rows = vec![1];
    let (maxima, counts, dims) =
        max_array_multi_axis(grid.view(), &along_rows, Some(Missing::MissingValue(0)), &None);
    assert_eq!(maxima, [2, 5]);
    assert_eq!(counts, [2, 3]);
    assert_eq!(dims, [2]);
}
/// Reducing all but axis 2 splits evens from odds; treating 10 as missing lowers
/// the even maximum to 8 and drops its count to 5.
#[test]
fn max_multi_axis_4d_3ax_missing() {
    let data = ndarray::Array::from_shape_vec((2, 3, 2, 1), (0..12).collect())
        .unwrap()
        .into_dyn();
    let reduce_over = vec![0, 1, 3];
    let (maxima, counts, dims) =
        max_array_multi_axis(data.view(), &reduce_over, Some(Missing::MissingValue(10)), &None);
    assert_eq!(maxima, [8, 11]);
    assert_eq!(counts, [5, 6]);
    assert_eq!(dims, [2]);
}
/// Asking for axis 2 on a 2-D input must trip the axis-bounds assertion.
#[test]
#[should_panic(expected = "assertion failed: axis.index() < self.ndim()")]
fn count_multi_axis_2d_wrong_axis() {
    let square = ndarray::Array::from_shape_vec((2, 2), (0..4).collect())
        .unwrap()
        .into_dyn();
    let _ = count_array_multi_axis(square.view(), &[2], None);
}
/// Counting over both axes of a 2x3 input yields one total of 6 and an empty shape.
#[test]
fn count_multi_axis_2d_2ax() {
    let grid = ndarray::Array::from_shape_vec((2, 3), (0..6).collect())
        .unwrap()
        .into_dyn();
    let (tallies, dims) = count_array_multi_axis(grid.view(), &[0, 1], None);
    assert_eq!(tallies, [6]);
    assert!(dims.is_empty());
}
/// With no axes and no missing spec, every element counts once and the input
/// shape is preserved.
#[test]
fn count_multi_axis_2d_no_ax() {
    let grid = ndarray::Array::from_shape_vec((2, 3), (0..6).collect())
        .unwrap()
        .into_dyn();
    let (tallies, dims) = count_array_multi_axis(grid.view(), &[], None);
    assert_eq!(tallies, [1; 6]);
    assert_eq!(dims, grid.shape());
}
/// Per-row counts (axis 1 of a 2x3 input) with 0 marked missing: the first row
/// loses one element.
#[test]
fn count_multi_axis_2d_1ax_missing() {
    let grid = ndarray::Array::from_shape_vec((2, 3), (0..6).collect())
        .unwrap()
        .into_dyn();
    let (tallies, dims) =
        count_array_multi_axis(grid.view(), &[1], Some(Missing::MissingValue(0)));
    assert_eq!(tallies, [2, 3]);
    assert_eq!(dims, [2]);
}
/// Reducing all but axis 2 splits evens from odds; excluding 9, 10 and 11 removes
/// one even and two odd values.
#[test]
fn count_multi_axis_4d_3ax_multi_missing() {
    let data = ndarray::Array::from_shape_vec((2, 3, 2, 1), (0..12).collect())
        .unwrap()
        .into_dyn();
    let excluded = Missing::MissingValues(vec![9, 10, 11]);
    let (tallies, dims) = count_array_multi_axis(data.view(), &[0, 1, 3], Some(excluded));
    assert_eq!(tallies, [5, 4]);
    assert_eq!(dims, [2]);
}
/// Reducing all but axis 2 splits evens from odds; excluding 10 removes a single
/// even value.
#[test]
fn count_multi_axis_4d_3ax_missing() {
    let data = ndarray::Array::from_shape_vec((2, 3, 2, 1), (0..12).collect())
        .unwrap()
        .into_dyn();
    let (tallies, dims) =
        count_array_multi_axis(data.view(), &[0, 1, 3], Some(Missing::MissingValue(10)));
    assert_eq!(tallies, [5, 6]);
    assert_eq!(dims, [2]);
}
}