/// Selects which reduction a caller wants to perform.
///
/// This is a plain, field-less tag: it is `Copy`, comparable with `==`, and
/// printable with `{:?}`. The code that interprets each variant is not part of
/// this file, so the per-variant notes below follow the conventional meaning
/// of each name — confirm against the consuming kernels.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ReduceOp {
    /// Presumably the sum of the reduced elements.
    Sum,
    /// Presumably the arithmetic mean of the reduced elements.
    Mean,
    /// Presumably the largest of the reduced elements.
    Max,
    /// Presumably the smallest of the reduced elements.
    Min,
    /// Presumably the product of the reduced elements.
    Prod,
    /// Presumably a logical AND over the reduced elements.
    All,
    /// Presumably a logical OR over the reduced elements.
    Any,
}
/// Selects a precision for accumulation.
///
/// The default value is [`AccumulationPrecision::Native`]. The enum is marked
/// `#[non_exhaustive]`, so code in other crates must include a wildcard arm
/// when matching on it.
///
/// How each variant is honoured is decided by code outside this file; the
/// per-variant notes below are read off the variant names — confirm against
/// the consumers.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
#[non_exhaustive]
pub enum AccumulationPrecision {
    /// The default choice; presumably "accumulate in the input's own type".
    #[default]
    Native,
    /// Presumably bfloat16 accumulation.
    BF16,
    /// Presumably 32-bit floating-point accumulation.
    FP32,
    /// Presumably 64-bit floating-point accumulation.
    FP64,
}
/// Computes the shape that results from reducing `input_shape` over `dims`.
///
/// Every axis whose index appears in `dims` is either collapsed to extent `1`
/// (`keepdim == true`) or dropped from the result (`keepdim == false`). All
/// other axes keep their extent and relative order.
///
/// Entries of `dims` that are not valid axis indices of `input_shape` match
/// nothing and are therefore ignored; repeated entries have no extra effect.
pub fn reduce_output_shape(input_shape: &[usize], dims: &[usize], keepdim: bool) -> Vec<usize> {
    let mut output = Vec::with_capacity(input_shape.len());
    for (axis, &extent) in input_shape.iter().enumerate() {
        if !dims.contains(&axis) {
            // Untouched axis: carried over as-is.
            output.push(extent);
        } else if keepdim {
            // Reduced axis kept as a placeholder of extent 1.
            output.push(1);
        }
        // Reduced axis without `keepdim`: omitted from the output.
    }
    output
}
/// Splits `shape` around axis `dim` into `(outer_size, reduce_size, inner_size)`.
///
/// * `outer_size` — product of the extents before `dim`, clamped to at least 1.
/// * `reduce_size` — the extent of axis `dim` itself (not clamped).
/// * `inner_size` — product of the extents after `dim`, clamped to at least 1.
///
/// # Panics
///
/// Panics if `dim` is not a valid index into `shape` (`dim >= shape.len()`).
#[inline]
pub fn compute_reduce_strides(shape: &[usize], dim: usize) -> (usize, usize, usize) {
    // An empty product is already 1, so the clamp only changes the result when
    // one of the multiplied extents is 0.
    // NOTE(review): that means a zero-length outer/inner axis is reported as
    // size 1 rather than 0 — confirm callers handle empty tensors separately.
    let clamped_product = |extents: &[usize]| extents.iter().product::<usize>().max(1);

    let outer_size = clamped_product(&shape[..dim]);
    let reduce_size = shape[dim];
    let inner_size = clamped_product(&shape[dim + 1..]);

    (outer_size, reduce_size, inner_size)
}
/// Single-axis convenience wrapper around [`reduce_output_shape`].
///
/// Returns the shape obtained by reducing `shape` over the one axis `dim`,
/// forwarding `keepdim` unchanged.
#[inline]
pub fn reduce_dim_output_shape(shape: &[usize], dim: usize, keepdim: bool) -> Vec<usize> {
    let single_axis = [dim];
    reduce_output_shape(shape, &single_axis, keepdim)
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Checks `reduce_output_shape` on a `[2, 3, 4]` input for one, two and
    /// all axes reduced, with and without `keepdim`.
    #[test]
    fn test_reduce_output_shape() {
        let shape = [2usize, 3, 4];
        // (reduced axes, keepdim, expected output shape)
        let cases: [(&[usize], bool, &[usize]); 6] = [
            (&[1], false, &[2, 4]),
            (&[1], true, &[2, 1, 4]),
            (&[0, 2], false, &[3]),
            (&[0, 2], true, &[1, 3, 1]),
            (&[0, 1, 2], false, &[]),
            (&[0, 1, 2], true, &[1, 1, 1]),
        ];

        for (dims, keepdim, expected) in cases {
            assert_eq!(
                reduce_output_shape(&shape, dims, keepdim),
                expected,
                "dims={dims:?}, keepdim={keepdim}"
            );
        }
    }

    /// Checks the `(outer, reduce, inner)` split of `[2, 3, 4]` at every axis.
    #[test]
    fn test_compute_reduce_strides() {
        let shape = [2usize, 3, 4];
        // Indexed by the axis being reduced.
        let expected = [(1, 2, 12), (2, 3, 4), (6, 4, 1)];

        for (dim, want) in expected.iter().enumerate() {
            assert_eq!(compute_reduce_strides(&shape, dim), *want, "dim={dim}");
        }
    }
}