use ndarray::{Array2, Array3, ArrayView3};
use numpy::Element;
use rayon::prelude::*;
use reducers::axis;
use reducers::reducers_1d::{
lmedian_valid_in_place, median_valid_in_place, percentiles_valid_in_place, var_mean_valid,
Kind, Number,
};
use reducers::{Float, ScanPolicy};
/// Minimum per-frame pixel count (`H * W`) at which the per-pixel stack
/// kernels in this file take their rayon parallel branch instead of the
/// sequential loop (see `should_parallel_stack`).
const DEFAULT_STACK_PARALLEL_HW_THRESHOLD: usize = 10_000;
/// Reduction applied along axis 0 (the stack axis) by `combine_axis0`.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum CombineKind {
    Mean,
    Median,
    LMedian,
    Sum,
    Min,
    Max,
    Variance,
    /// Weighted average; it has no `Kind` counterpart and no textual alias
    /// in [`CombineKind::parse`].
    NanAverage,
}

impl CombineKind {
    /// Resolves a reducer name or alias to a `CombineKind`.
    ///
    /// Matching ignores ASCII case. Returns `None` for any unrecognised
    /// name; note that no alias maps to `NanAverage`.
    pub fn parse(s: &str) -> Option<Self> {
        // All aliases are lowercase ASCII, so an ASCII case-insensitive
        // comparison is equivalent to lowercasing the input first.
        const ALIASES: &[(&str, CombineKind)] = &[
            ("mean", CombineKind::Mean),
            ("average", CombineKind::Mean),
            ("avg", CombineKind::Mean),
            ("median", CombineKind::Median),
            ("med", CombineKind::Median),
            ("medi", CombineKind::Median),
            ("lmedian", CombineKind::LMedian),
            ("lmed", CombineKind::LMedian),
            ("lmd", CombineKind::LMedian),
            ("sum", CombineKind::Sum),
            ("min", CombineKind::Min),
            ("max", CombineKind::Max),
            ("variance", CombineKind::Variance),
            ("var", CombineKind::Variance),
        ];
        ALIASES
            .iter()
            .find(|(alias, _)| s.eq_ignore_ascii_case(alias))
            .map(|&(_, kind)| kind)
    }

    /// Maps this kind onto the corresponding `Kind`, or `None` for
    /// `NanAverage`, which has no mapping.
    fn reducers_kind(self) -> Option<Kind> {
        let mapped = match self {
            Self::NanAverage => return None,
            Self::Mean => Kind::Mean,
            Self::Median => Kind::Median,
            Self::LMedian => Kind::LMedian,
            Self::Sum => Kind::Sum,
            Self::Min => Kind::Min,
            Self::Max => Kind::Max,
            Self::Variance => Kind::Var,
        };
        Some(mapped)
    }
}
/// Returns `(n, h, w, h * w)` for a 3-D view: the three axis lengths
/// followed by the product of the last two.
fn axis0_shape<T>(arr: &ArrayView3<T>) -> (usize, usize, usize, usize) {
    let (frames, rows, cols) = arr.dim();
    (frames, rows, cols, rows * cols)
}
/// Borrows the view's elements as a flat slice in standard (C, row-major)
/// order.
///
/// The kernels in this file index the returned slice as
/// `pixel + frame * (h * w)`, which is only valid when axis 0 is the
/// slowest-varying axis. `as_slice()` is used rather than
/// `as_slice_memory_order()` because the latter also succeeds for
/// Fortran-order or axis-permuted contiguous views, whose memory order does
/// not match that indexing.
///
/// `name` identifies the calling kernel in the panic message.
///
/// # Panics
///
/// Panics if the view is not contiguous in standard layout.
fn contiguous_data<'a, T>(arr: &'a ArrayView3<T>, name: &str) -> &'a [T] {
    arr.as_slice()
        .unwrap_or_else(|| panic!("{name} requires contiguous C-order arrays"))
}
/// Wraps a flat vector as an `(h, w)` array.
///
/// # Panics
///
/// Panics if `Array2::from_shape_vec` rejects `values` for shape `(h, w)`.
fn array2_from_vec<T>(h: usize, w: usize, values: Vec<T>) -> Array2<T> {
    let reshaped = Array2::from_shape_vec((h, w), values);
    reshaped.expect("reducers axis output has H*W elements")
}
/// Decides whether a stack kernel should take its parallel branch.
///
/// True only for a non-empty stack (`n > 0`) whose per-frame pixel count
/// `hw` reaches `DEFAULT_STACK_PARALLEL_HW_THRESHOLD`.
#[inline]
fn should_parallel_stack(hw: usize, n: usize) -> bool {
    if n == 0 {
        return false;
    }
    hw >= DEFAULT_STACK_PARALLEL_HW_THRESHOLD
}
/// Gathers the finite samples of one pixel across the stack into `buf`.
///
/// Reads `data[pixel + k * hw]` for `k` in `0..n`, copies each finite value
/// to the front of `buf` in encounter order, and returns how many were
/// kept. Entries of `buf` past the returned count are left untouched.
///
/// # Panics
///
/// Panics if a read index falls outside `data` or if `buf` is too short to
/// hold the kept values.
#[inline]
fn compact_finite_strided<T: Float>(
    data: &[T],
    pixel: usize,
    n: usize,
    hw: usize,
    buf: &mut [T],
) -> usize {
    let mut kept = 0;
    for frame in 0..n {
        let sample = data[pixel + frame * hw];
        if !sample.is_finite() {
            continue;
        }
        buf[kept] = sample;
        kept += 1;
    }
    kept
}
/// Median of one pixel's finite samples across the stack.
///
/// Compacts the finite samples into `buf`, hands that prefix to
/// `median_valid_in_place`, and converts the result with `T::from_f64`.
/// `buf` is scratch space and is overwritten.
#[inline]
fn median_strided<T: Float>(data: &[T], pixel: usize, n: usize, hw: usize, buf: &mut [T]) -> T {
    let kept = compact_finite_strided(data, pixel, n, hw, buf);
    let finite = &mut buf[..kept];
    let median = median_valid_in_place(finite);
    T::from_f64(median)
}
/// Lower-median counterpart of `median_strided`.
///
/// Compacts the finite samples into `buf`, hands that prefix to
/// `lmedian_valid_in_place`, and converts the result with `T::from_f64`.
/// `buf` is scratch space and is overwritten.
#[inline]
fn lmedian_strided<T: Float>(data: &[T], pixel: usize, n: usize, hw: usize, buf: &mut [T]) -> T {
    let kept = compact_finite_strided(data, pixel, n, hw, buf);
    let finite = &mut buf[..kept];
    let lower_median = lmedian_valid_in_place(finite);
    T::from_f64(lower_median)
}
/// Variance and mean of one pixel's finite samples across the stack.
///
/// Compacts the finite samples into `buf`, passes that prefix and `ddof`
/// to `var_mean_valid`, and returns `(variance, mean)` converted with
/// `T::from_f64`. `buf` is scratch space and is overwritten.
#[inline]
fn var_mean_strided<T: Float>(
    data: &[T],
    pixel: usize,
    n: usize,
    hw: usize,
    ddof: usize,
    buf: &mut [T],
) -> (T, T) {
    let kept = compact_finite_strided(data, pixel, n, hw, buf);
    let finite = &buf[..kept];
    let (variance, mean) = var_mean_valid(finite, ddof);
    let variance = T::from_f64(variance);
    let mean = T::from_f64(mean);
    (variance, mean)
}
/// Per-pixel median (or lower median when `lower` is true) along axis 0.
///
/// Each output pixel is computed from the finite samples at that pixel
/// across all frames. The work runs on the rayon pool when
/// `should_parallel_stack` says so, otherwise in a plain loop; both paths
/// reuse an `n`-element scratch buffer (one per rayon init call in the
/// parallel case).
///
/// # Panics
///
/// Panics if `arr` is rejected by `contiguous_data`.
fn median_axis0_stack<T: Float>(arr: &ArrayView3<T>, lower: bool) -> Array2<T> {
    let (n, h, w, hw) = axis0_shape(arr);
    let data = contiguous_data(arr, "median kernel");
    let mut plane = vec![T::nan(); hw];

    let fill = |scratch: &mut Vec<T>, (pixel, slot): (usize, &mut T)| {
        let reduced = if lower {
            lmedian_strided(data, pixel, n, hw, scratch)
        } else {
            median_strided(data, pixel, n, hw, scratch)
        };
        *slot = reduced;
    };

    if should_parallel_stack(hw, n) {
        plane
            .par_iter_mut()
            .enumerate()
            .for_each_init(|| vec![T::zero(); n], fill);
    } else {
        let mut scratch = vec![T::zero(); n];
        for item in plane.iter_mut().enumerate() {
            fill(&mut scratch, item);
        }
    }

    array2_from_vec(h, w, plane)
}
/// Reduces a `(N, H, W)` stack along axis 0 into an `(H, W)` array.
///
/// * `Median` / `LMedian` are delegated to `median_axis0_stack`; `weights`
///   and `ddof` are not used for them.
/// * `NanAverage` calls `axis::weighted_axis0` with `weights` and
///   `ScanPolicy::SkipNonFinite`; `ddof` is not used.
/// * Every other kind is mapped through `CombineKind::reducers_kind` and
///   passed to `axis::reduce_axis0` with `ddof` and
///   `ScanPolicy::SkipNonFinite`; `weights` is not used.
///
/// # Panics
///
/// Panics if `arr` is rejected by `contiguous_data`, if `kind` is
/// `NanAverage` and `weights` is `None`, or if `weights.len()` differs
/// from the stack size `N`.
pub fn combine_axis0<T: Float>(
    arr: &ArrayView3<T>,
    kind: CombineKind,
    weights: Option<&[f64]>,
    ddof: usize,
) -> Array2<T> {
    let (n, h, w, hw) = axis0_shape(arr);

    // The median kernels perform their own contiguity check, so they are
    // dispatched before the flat slice is taken here.
    if let CombineKind::Median | CombineKind::LMedian = kind {
        return median_axis0_stack(arr, kind == CombineKind::LMedian);
    }

    let data = contiguous_data(arr, "combine kernel");

    let values: Vec<T> = match kind {
        CombineKind::NanAverage => {
            let weights = weights.expect("nanaverage requires weights");
            assert_eq!(weights.len(), n, "weights length must equal stack size N");
            let weighted =
                axis::weighted_axis0(data, weights, true, n, hw, ScanPolicy::SkipNonFinite);
            weighted.values.into_iter().map(T::from_f64).collect()
        }
        other => {
            let reducers_kind = other
                .reducers_kind()
                .expect("non-weighted combine kind has reducers mapping");
            axis::reduce_axis0(
                data,
                n,
                hw,
                reducers_kind,
                ddof,
                ScanPolicy::SkipNonFinite,
            )
        }
    };

    array2_from_vec(h, w, values)
}
/// Per-pixel variance and mean along axis 0, returned as
/// `(variance, mean)` arrays of shape `(H, W)`.
///
/// Each pixel is computed by `var_mean_strided` from the finite samples at
/// that pixel, with `ddof` forwarded unchanged. The work runs on the rayon
/// pool when `should_parallel_stack` says so, otherwise in a plain loop.
///
/// # Panics
///
/// Panics if `arr` is rejected by `contiguous_data`.
pub fn variance_mean_axis0<T: Float>(arr: &ArrayView3<T>, ddof: usize) -> (Array2<T>, Array2<T>) {
    let (n, h, w, hw) = axis0_shape(arr);
    let data = contiguous_data(arr, "variance kernel");
    let mut var_plane = vec![T::nan(); hw];
    let mut mean_plane = vec![T::nan(); hw];

    let fill = |scratch: &mut Vec<T>, (pixel, (var_px, mean_px)): (usize, (&mut T, &mut T))| {
        let (variance, mean) = var_mean_strided(data, pixel, n, hw, ddof, scratch);
        *var_px = variance;
        *mean_px = mean;
    };

    if should_parallel_stack(hw, n) {
        var_plane
            .par_iter_mut()
            .zip(mean_plane.par_iter_mut())
            .enumerate()
            .for_each_init(|| vec![T::zero(); n], fill);
    } else {
        let mut scratch = vec![T::zero(); n];
        for item in var_plane.iter_mut().zip(mean_plane.iter_mut()).enumerate() {
            fill(&mut scratch, item);
        }
    }

    let variances = array2_from_vec(h, w, var_plane);
    let means = array2_from_vec(h, w, mean_plane);
    (variances, means)
}
/// Per-pixel percentiles along axis 0, returned with shape `(H, W, Q)`
/// where `Q = qs.len()`.
///
/// For each pixel the finite samples are compacted into a scratch buffer
/// and passed, together with `qs`, to `percentiles_valid_in_place`, which
/// writes into that pixel's `Q`-element output chunk. Results are computed
/// as `f64` and converted with `T::from_f64`. When `qs` is empty no kernel
/// runs and the result has a zero-length last axis.
///
/// # Panics
///
/// Panics if `arr` is rejected by `contiguous_data`.
pub fn percentiles_axis0<T: Float>(arr: &ArrayView3<T>, qs: &[f64]) -> Array3<T> {
    let (n, h, w, hw) = axis0_shape(arr);
    let nq = qs.len();
    let data = contiguous_data(arr, "percentile kernel");

    // Pixel-major layout: the `nq` results of pixel `p` occupy
    // `flat[p * nq..(p + 1) * nq]`.
    let mut flat = vec![f64::NAN; hw * nq];

    // Chunking by zero is invalid, so the kernels are skipped when no
    // percentiles were requested.
    if nq != 0 {
        let fill = |scratch: &mut Vec<T>, (pixel, out_q): (usize, &mut [f64])| {
            let kept = compact_finite_strided(data, pixel, n, hw, scratch);
            percentiles_valid_in_place(&mut scratch[..kept], qs, out_q);
        };

        if should_parallel_stack(hw, n) {
            flat.par_chunks_mut(nq)
                .enumerate()
                .for_each_init(|| vec![T::zero(); n], fill);
        } else {
            let mut scratch = vec![T::zero(); n];
            // `flat.len()` is an exact multiple of `nq`, so exact chunks
            // cover every element.
            for item in flat.chunks_exact_mut(nq).enumerate() {
                fill(&mut scratch, item);
            }
        }
    }

    let converted: Vec<T> = flat.into_iter().map(T::from_f64).collect();
    Array3::from_shape_vec((h, w, nq), converted).expect("percentile output has H*W*Q elements")
}
/// Lower median along axis 0 for `Number` element types, returned as an
/// `(H, W)` array.
///
/// Delegates to `axis::reduce_axis0_number_exact` with `Kind::LMedian`.
///
/// # Panics
///
/// Panics if `arr` is rejected by `contiguous_data`.
pub fn lmedian_axis0_ord<T>(arr: &ArrayView3<T>) -> Array2<T>
where
    T: Copy + Element + Number,
{
    let (frames, rows, cols, hw) = axis0_shape(arr);
    let flat = contiguous_data(arr, "lmedian kernel");
    array2_from_vec(
        rows,
        cols,
        axis::reduce_axis0_number_exact(flat, frames, hw, Kind::LMedian),
    )
}