use alloc::vec;
use alloc::vec::Vec;
use burn_backend::{DType, Element};
use burn_std::{Bytes, Shape, bf16, f16};
use crate::strided_index::StridedIter;
use crate::{FlexTensor, Layout};
use super::{INDEX_DTYPE, float_storage_as_f32};
#[cfg(feature = "simd")]
use crate::simd::{aligned, kernels};
#[cfg(feature = "rayon")]
use rayon::prelude::*;

/// Index tensors store positions as `isize` (`INDEX_DTYPE`), so the reduced
/// dimension's length must be representable as `isize`.
#[inline(always)]
fn assert_dim_fits_isize(dim_size: usize, dim: usize) {
    assert!(
        dim_size <= isize::MAX as usize,
        "dimension {dim} has size {dim_size} which exceeds isize::MAX"
    );
}
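/// Reinterprets the low-order bytes of an `i64` accumulator as the smaller POD
/// element type `E`, i.e. a wrapping narrowing cast used by the widening
/// integer reductions below.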
fn truncate_i64_to_pod<E: bytemuck::Pod>(value: i64) -> E {
let bytes = value.to_ne_bytes();
let size = core::mem::size_of::<E>();
debug_assert!(size <= core::mem::size_of::<i64>());
let offset = if cfg!(target_endian = "big") {
core::mem::size_of::<i64>() - size
} else {
0
};
bytemuck::pod_read_unaligned(&bytes[offset..offset + size])
}
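/// Sums every element into a shape `[1]` tensor of the input dtype. Integer
/// types narrower than 64 bits accumulate in `i64` and wrap back to the input
/// dtype; f16/bf16 accumulate in `f32` before being narrowed.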
pub fn sum(tensor: FlexTensor) -> FlexTensor {
match tensor.dtype() {
DType::F32 => sum_f32(&tensor),
DType::F64 => sum_impl::<f64>(&tensor),
DType::F16 => reduce_scalar_half(&tensor, |a, b| a + b, 0.0, f16::to_f32, f16::from_f32),
DType::BF16 => reduce_scalar_half(&tensor, |a, b| a + b, 0.0, bf16::to_f32, bf16::from_f32),
DType::I8 => sum_impl_widening::<i8>(&tensor),
DType::I16 => sum_impl_widening::<i16>(&tensor),
DType::I32 => sum_impl_widening::<i32>(&tensor),
DType::I64 => sum_impl::<i64>(&tensor),
DType::U8 => sum_impl_widening::<u8>(&tensor),
DType::U16 => sum_impl_widening::<u16>(&tensor),
DType::U32 => sum_impl_widening::<u32>(&tensor),
DType::U64 => sum_impl::<u64>(&tensor),
_ => panic!("sum: unsupported dtype {:?}", tensor.dtype()),
}
}
fn sum_f32(tensor: &FlexTensor) -> FlexTensor {
let result = match tensor.layout().contiguous_offsets() {
Some((start, end)) => {
let data: &[f32] = tensor.storage();
let slice = &data[start..end];
sum_f32_contiguous(slice)
}
None => {
let data: &[f32] = tensor.storage();
let elem_count = tensor.layout().num_elements();
if data.len() == elem_count {
sum_f32_contiguous(data)
} else {
StridedIter::new(tensor.layout()).map(|idx| data[idx]).sum()
}
}
};
let bytes = Bytes::from_elems(vec![result]);
FlexTensor::new(bytes, Layout::contiguous(Shape::from(vec![1])), DType::F32)
}
#[inline]
fn sum_f32_contiguous(data: &[f32]) -> f32 {
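    // Only very large inputs are split across rayon threads; everything else
    // takes a single-threaded (optionally SIMD) pass.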
#[cfg(feature = "rayon")]
if data.len() >= 4 * 1024 * 1024 {
return sum_f32_parallel(data);
}
#[cfg(feature = "simd")]
{
kernels::sum_f32(data)
}
#[cfg(not(feature = "simd"))]
{
data.iter().copied().sum()
}
}
#[cfg(feature = "rayon")]
#[inline]
fn sum_f32_parallel(data: &[f32]) -> f32 {
const CHUNK_SIZE: usize = 64 * 1024;
data.par_chunks(CHUNK_SIZE)
.map(|chunk| {
#[cfg(feature = "simd")]
{
kernels::sum_f32(chunk)
}
#[cfg(not(feature = "simd"))]
{
chunk.iter().copied().sum::<f32>()
}
})
.sum()
}
fn sum_impl<E: Element + bytemuck::Pod + Default + core::iter::Sum>(
tensor: &FlexTensor,
) -> FlexTensor {
let result: E = match tensor.layout().contiguous_offsets() {
Some((start, end)) => {
let data: &[E] = tensor.storage();
data[start..end].iter().copied().sum()
}
None => {
let data: &[E] = tensor.storage();
StridedIter::new(tensor.layout()).map(|idx| data[idx]).sum()
}
};
let bytes = Bytes::from_elems(vec![result]);
FlexTensor::new(
bytes,
Layout::contiguous(Shape::from(vec![1])),
tensor.dtype(),
)
}
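/// Generates a full-tensor reduction for narrow integer types: every element is
/// widened to `i64`, folded with the given wrapping operation, and the total is
/// truncated back to the element type.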
macro_rules! widening_scalar_reduce {
($name:ident, $fold:expr, $init:expr) => {
fn $name<E>(tensor: &FlexTensor) -> FlexTensor
where
E: Element + bytemuck::Pod + Default,
i64: From<E>,
{
let total: i64 = match tensor.layout().contiguous_offsets() {
Some((start, end)) => {
let data: &[E] = tensor.storage();
data[start..end]
.iter()
.fold($init, |acc, x| ($fold)(acc, i64::from(*x)))
}
None => {
let data: &[E] = tensor.storage();
StridedIter::new(tensor.layout())
.fold($init, |acc, idx| ($fold)(acc, i64::from(data[idx])))
}
};
let result: E = truncate_i64_to_pod(total);
let bytes = Bytes::from_elems(vec![result]);
FlexTensor::new(
bytes,
Layout::contiguous(Shape::from(vec![1])),
tensor.dtype(),
)
}
};
}
widening_scalar_reduce!(
sum_impl_widening,
|acc: i64, x: i64| acc.wrapping_add(x),
0i64
);
widening_scalar_reduce!(
prod_impl_widening,
|acc: i64, x: i64| acc.wrapping_mul(x),
1i64
);
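/// Full-tensor reduction for f16/bf16: each element is converted to `f32`,
/// folded with `fold` starting from `init`, and the accumulator is converted
/// back to the half-precision type at the end.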
fn reduce_scalar_half<E>(
tensor: &FlexTensor,
fold: fn(f32, f32) -> f32,
init: f32,
to_f32: fn(E) -> f32,
from_f32: fn(f32) -> E,
) -> FlexTensor
where
E: Element + bytemuck::Pod,
{
let result: f32 = match tensor.layout().contiguous_offsets() {
Some((start, end)) => {
let data: &[E] = tensor.storage();
data[start..end]
.iter()
.fold(init, |acc, x| fold(acc, to_f32(*x)))
}
None => {
let data: &[E] = tensor.storage();
StridedIter::new(tensor.layout()).fold(init, |acc, idx| fold(acc, to_f32(data[idx])))
}
};
let bytes = Bytes::from_elems(vec![from_f32(result)]);
FlexTensor::new(bytes, Layout::contiguous(Shape::from(vec![1])), E::dtype())
}
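/// Sums along `dim`, keeping that dimension with size 1 in the output shape.
/// Narrow integer types accumulate in `i64` (wrapping); f16/bf16 accumulate in `f32`.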
pub fn sum_dim(tensor: FlexTensor, dim: usize) -> FlexTensor {
match tensor.dtype() {
DType::F32 => reduce_dim_f32(&tensor, dim, ReduceOp::Sum),
DType::F64 => reduce_dim_impl::<f64, _>(&tensor, dim, 0.0, |acc, x| acc + x),
DType::F16 => reduce_dim_half(
&tensor,
dim,
0.0,
|acc, x| acc + x,
f16::to_f32,
f16::from_f32,
),
DType::BF16 => reduce_dim_half(
&tensor,
dim,
0.0,
|acc, x| acc + x,
bf16::to_f32,
bf16::from_f32,
),
DType::I8 => reduce_dim_widening::<i8, _>(&tensor, dim, 0, |acc, x| acc.wrapping_add(x)),
DType::I16 => reduce_dim_widening::<i16, _>(&tensor, dim, 0, |acc, x| acc.wrapping_add(x)),
DType::I32 => reduce_dim_widening::<i32, _>(&tensor, dim, 0, |acc, x| acc.wrapping_add(x)),
DType::I64 => reduce_dim_impl::<i64, _>(&tensor, dim, 0, |acc, x| acc + x),
DType::U8 => reduce_dim_widening::<u8, _>(&tensor, dim, 0, |acc, x| acc.wrapping_add(x)),
DType::U16 => reduce_dim_widening::<u16, _>(&tensor, dim, 0, |acc, x| acc.wrapping_add(x)),
DType::U32 => reduce_dim_widening::<u32, _>(&tensor, dim, 0, |acc, x| acc.wrapping_add(x)),
DType::U64 => reduce_dim_impl::<u64, _>(&tensor, dim, 0, |acc, x| acc + x),
_ => panic!("sum_dim: unsupported dtype {:?}", tensor.dtype()),
}
}
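/// Mean along `dim` (kept with size 1). Implemented as `sum_dim` followed by an
/// element-wise divide, except for f16/bf16, which average in `f32` so the
/// intermediate sum cannot overflow the half-precision range. Integer means use
/// truncating division.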
pub fn mean_dim(tensor: FlexTensor, dim: usize) -> FlexTensor {
let dim_size = tensor.layout().shape()[dim];
assert!(
dim_size > 0,
"mean_dim: cannot take mean of empty dimension"
);
let dtype = tensor.dtype();
match dtype {
DType::F16 => return mean_dim_half::<f16>(&tensor, dim),
DType::BF16 => return mean_dim_half::<bf16>(&tensor, dim),
_ => {}
}
let sum_result = sum_dim(tensor, dim);
match dtype {
DType::F32 => scalar_div::<f32>(sum_result, dim_size as f32),
DType::F64 => scalar_div::<f64>(sum_result, dim_size as f64),
DType::I8 => {
let divisor = dim_size as i32;
let mut tensor = sum_result;
let data: &mut [i8] = tensor.storage_mut();
for x in data.iter_mut() {
*x = ((*x as i32) / divisor) as i8;
}
tensor
}
DType::I16 => {
let divisor = dim_size as i32;
let mut tensor = sum_result;
let data: &mut [i16] = tensor.storage_mut();
for x in data.iter_mut() {
*x = ((*x as i32) / divisor) as i16;
}
tensor
}
DType::I32 => scalar_div::<i32>(sum_result, dim_size as i32),
DType::I64 => scalar_div::<i64>(sum_result, dim_size as i64),
DType::U8 => {
let divisor = dim_size as u32;
let mut tensor = sum_result;
let data: &mut [u8] = tensor.storage_mut();
for x in data.iter_mut() {
*x = ((*x as u32) / divisor) as u8;
}
tensor
}
DType::U16 => {
let divisor = dim_size as u32;
let mut tensor = sum_result;
let data: &mut [u16] = tensor.storage_mut();
for x in data.iter_mut() {
*x = ((*x as u32) / divisor) as u16;
}
tensor
}
DType::U32 => scalar_div::<u32>(sum_result, dim_size as u32),
DType::U64 => scalar_div::<u64>(sum_result, dim_size as u64),
_ => panic!("mean_dim: unsupported dtype {:?}", dtype),
}
}
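/// Multiplies every element into a shape `[1]` tensor of the input dtype, with
/// the same widening rules as `sum`.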
pub fn prod(tensor: FlexTensor) -> FlexTensor {
match tensor.dtype() {
DType::F32 => prod_impl::<f32>(&tensor),
DType::F64 => prod_impl::<f64>(&tensor),
DType::F16 => reduce_scalar_half(&tensor, |a, b| a * b, 1.0, f16::to_f32, f16::from_f32),
DType::BF16 => reduce_scalar_half(&tensor, |a, b| a * b, 1.0, bf16::to_f32, bf16::from_f32),
DType::I8 => prod_impl_widening::<i8>(&tensor),
DType::I16 => prod_impl_widening::<i16>(&tensor),
DType::I32 => prod_impl_widening::<i32>(&tensor),
DType::I64 => prod_impl::<i64>(&tensor),
DType::U8 => prod_impl_widening::<u8>(&tensor),
DType::U16 => prod_impl_widening::<u16>(&tensor),
DType::U32 => prod_impl_widening::<u32>(&tensor),
DType::U64 => prod_impl::<u64>(&tensor),
_ => panic!("prod: unsupported dtype {:?}", tensor.dtype()),
}
}
fn prod_impl<E: Element + bytemuck::Pod + Default + core::iter::Product>(
tensor: &FlexTensor,
) -> FlexTensor {
let result: E = match tensor.layout().contiguous_offsets() {
Some((start, end)) => {
let data: &[E] = tensor.storage();
data[start..end].iter().copied().product()
}
None => {
let data: &[E] = tensor.storage();
StridedIter::new(tensor.layout())
.map(|idx| data[idx])
.product()
}
};
let bytes = Bytes::from_elems(vec![result]);
FlexTensor::new(
bytes,
Layout::contiguous(Shape::from(vec![1])),
tensor.dtype(),
)
}
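/// Product along `dim` (kept with size 1), using wrapping `i64` accumulation for
/// narrow integer types and `f32` accumulation for f16/bf16.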
pub fn prod_dim(tensor: FlexTensor, dim: usize) -> FlexTensor {
match tensor.dtype() {
DType::F32 => reduce_dim_f32(&tensor, dim, ReduceOp::Prod),
DType::F64 => reduce_dim_impl::<f64, _>(&tensor, dim, 1.0, |acc, x| acc * x),
DType::F16 => reduce_dim_half(
&tensor,
dim,
1.0,
|acc, x| acc * x,
f16::to_f32,
f16::from_f32,
),
DType::BF16 => reduce_dim_half(
&tensor,
dim,
1.0,
|acc, x| acc * x,
bf16::to_f32,
bf16::from_f32,
),
DType::I8 => reduce_dim_widening::<i8, _>(&tensor, dim, 1, |acc, x| acc.wrapping_mul(x)),
DType::I16 => reduce_dim_widening::<i16, _>(&tensor, dim, 1, |acc, x| acc.wrapping_mul(x)),
DType::I32 => reduce_dim_widening::<i32, _>(&tensor, dim, 1, |acc, x| acc.wrapping_mul(x)),
DType::I64 => reduce_dim_impl::<i64, _>(&tensor, dim, 1, |acc, x| acc * x),
DType::U8 => reduce_dim_widening::<u8, _>(&tensor, dim, 1, |acc, x| acc.wrapping_mul(x)),
DType::U16 => reduce_dim_widening::<u16, _>(&tensor, dim, 1, |acc, x| acc.wrapping_mul(x)),
DType::U32 => reduce_dim_widening::<u32, _>(&tensor, dim, 1, |acc, x| acc.wrapping_mul(x)),
DType::U64 => reduce_dim_impl::<u64, _>(&tensor, dim, 1, |acc, x| acc * x),
_ => panic!("prod_dim: unsupported dtype {:?}", tensor.dtype()),
}
}
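/// Maximum over all elements, returned as a shape `[1]` tensor of the input dtype.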
pub fn max(tensor: FlexTensor) -> FlexTensor {
match tensor.dtype() {
DType::F32 => max_f32_reduce(&tensor),
DType::F64 => max_impl::<f64>(&tensor),
DType::F16 => reduce_scalar_half(
&tensor,
f32::max,
f32::NEG_INFINITY,
f16::to_f32,
f16::from_f32,
),
DType::BF16 => reduce_scalar_half(
&tensor,
f32::max,
f32::NEG_INFINITY,
bf16::to_f32,
bf16::from_f32,
),
DType::I8 => max_impl::<i8>(&tensor),
DType::I16 => max_impl::<i16>(&tensor),
DType::I32 => max_impl::<i32>(&tensor),
DType::I64 => max_impl::<i64>(&tensor),
DType::U8 => max_impl::<u8>(&tensor),
DType::U16 => max_impl::<u16>(&tensor),
DType::U32 => max_impl::<u32>(&tensor),
DType::U64 => max_impl::<u64>(&tensor),
_ => panic!("max: unsupported dtype {:?}", tensor.dtype()),
}
}
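/// Minimum over all elements, returned as a shape `[1]` tensor of the input dtype.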
pub fn min(tensor: FlexTensor) -> FlexTensor {
match tensor.dtype() {
DType::F32 => min_f32_reduce(&tensor),
DType::F64 => min_impl::<f64>(&tensor),
DType::F16 => {
reduce_scalar_half(&tensor, f32::min, f32::INFINITY, f16::to_f32, f16::from_f32)
}
DType::BF16 => reduce_scalar_half(
&tensor,
f32::min,
f32::INFINITY,
bf16::to_f32,
bf16::from_f32,
),
DType::I8 => min_impl::<i8>(&tensor),
DType::I16 => min_impl::<i16>(&tensor),
DType::I32 => min_impl::<i32>(&tensor),
DType::I64 => min_impl::<i64>(&tensor),
DType::U8 => min_impl::<u8>(&tensor),
DType::U16 => min_impl::<u16>(&tensor),
DType::U32 => min_impl::<u32>(&tensor),
DType::U64 => min_impl::<u64>(&tensor),
_ => panic!("min: unsupported dtype {:?}", tensor.dtype()),
}
}
fn max_f32_reduce(tensor: &FlexTensor) -> FlexTensor {
let result = match tensor.layout().contiguous_offsets() {
Some((start, end)) => {
let data: &[f32] = tensor.storage();
max_f32_contiguous(&data[start..end])
}
None => {
let data: &[f32] = tensor.storage();
let elem_count = tensor.layout().num_elements();
if data.len() == elem_count {
max_f32_contiguous(data)
} else {
StridedIter::new(tensor.layout())
.map(|idx| data[idx])
.reduce(|a, b| if a >= b { a } else { b })
.expect("max: tensor must not be empty")
}
}
};
let bytes = Bytes::from_elems(vec![result]);
FlexTensor::new(bytes, Layout::contiguous(Shape::from(vec![1])), DType::F32)
}
#[inline]
fn max_f32_contiguous(data: &[f32]) -> f32 {
#[cfg(feature = "simd")]
{
kernels::max_f32(data)
}
#[cfg(not(feature = "simd"))]
{
data.iter()
.copied()
.reduce(|a, b| if a >= b { a } else { b })
.expect("max: tensor must not be empty")
}
}
fn min_f32_reduce(tensor: &FlexTensor) -> FlexTensor {
let result = match tensor.layout().contiguous_offsets() {
Some((start, end)) => {
let data: &[f32] = tensor.storage();
min_f32_contiguous(&data[start..end])
}
None => {
let data: &[f32] = tensor.storage();
let elem_count = tensor.layout().num_elements();
if data.len() == elem_count {
min_f32_contiguous(data)
} else {
StridedIter::new(tensor.layout())
.map(|idx| data[idx])
.reduce(|a, b| if a <= b { a } else { b })
.expect("min: tensor must not be empty")
}
}
};
let bytes = Bytes::from_elems(vec![result]);
FlexTensor::new(bytes, Layout::contiguous(Shape::from(vec![1])), DType::F32)
}
#[inline]
fn min_f32_contiguous(data: &[f32]) -> f32 {
#[cfg(feature = "simd")]
{
kernels::min_f32(data)
}
#[cfg(not(feature = "simd"))]
{
data.iter()
.copied()
.reduce(|a, b| if a <= b { a } else { b })
.expect("min: tensor must not be empty")
}
}
fn max_impl<E: Element + bytemuck::Pod + PartialOrd>(tensor: &FlexTensor) -> FlexTensor {
let result: E = match tensor.layout().contiguous_offsets() {
Some((start, end)) => {
let data: &[E] = tensor.storage();
data[start..end]
.iter()
.copied()
.reduce(|a, b| if a >= b { a } else { b })
.expect("max: tensor must not be empty")
}
None => {
let data: &[E] = tensor.storage();
StridedIter::new(tensor.layout())
.map(|idx| data[idx])
.reduce(|a, b| if a >= b { a } else { b })
.expect("max: tensor must not be empty")
}
};
let bytes = Bytes::from_elems(vec![result]);
FlexTensor::new(
bytes,
Layout::contiguous(Shape::from(vec![1])),
tensor.dtype(),
)
}
fn min_impl<E: Element + bytemuck::Pod + PartialOrd>(tensor: &FlexTensor) -> FlexTensor {
let result: E = match tensor.layout().contiguous_offsets() {
Some((start, end)) => {
let data: &[E] = tensor.storage();
data[start..end]
.iter()
.copied()
.reduce(|a, b| if a <= b { a } else { b })
.expect("min: tensor must not be empty")
}
None => {
let data: &[E] = tensor.storage();
StridedIter::new(tensor.layout())
.map(|idx| data[idx])
.reduce(|a, b| if a <= b { a } else { b })
.expect("min: tensor must not be empty")
}
};
let bytes = Bytes::from_elems(vec![result]);
FlexTensor::new(
bytes,
Layout::contiguous(Shape::from(vec![1])),
tensor.dtype(),
)
}
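/// Index of the maximum along `dim`, returned as an `INDEX_DTYPE` tensor with
/// that dimension kept at size 1. For float inputs the first NaN in a row takes
/// precedence, matching the dedicated f32 last-dimension fast paths.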
pub fn argmax(tensor: FlexTensor, dim: usize) -> FlexTensor {
assert_dim_fits_isize(tensor.layout().shape()[dim], dim);
if tensor.dtype() == DType::F32 && dim == tensor.layout().shape().num_dims() - 1 {
#[cfg(feature = "simd")]
if tensor.layout().shape()[dim] >= EXTREMUM_SIMD_ROW_THRESHOLD {
return extremum_indices_f32_last_simd(&tensor, dim, kernels::max_f32);
}
return extremum_indices_f32_last_scalar(&tensor, dim, |a, b| a > b);
}
match tensor.dtype() {
DType::F32 => {
extremum_dim_with_indices::<f32, _>(&tensor, dim, |a, b| {
!b.is_nan() && (a.is_nan() || a > b)
})
.1
}
DType::F64 => {
extremum_dim_with_indices::<f64, _>(&tensor, dim, |a, b| {
!b.is_nan() && (a.is_nan() || a > b)
})
.1
}
DType::F16 => {
extremum_dim_with_indices_half::<f16, _>(
&tensor,
dim,
|a, b| !b.is_nan() && (a.is_nan() || a > b),
f16::to_f32,
f16::from_f32,
)
.1
}
DType::BF16 => {
extremum_dim_with_indices_half::<bf16, _>(
&tensor,
dim,
|a, b| !b.is_nan() && (a.is_nan() || a > b),
bf16::to_f32,
bf16::from_f32,
)
.1
}
DType::I8 => extremum_dim_with_indices::<i8, _>(&tensor, dim, |a, b| a > b).1,
DType::I16 => extremum_dim_with_indices::<i16, _>(&tensor, dim, |a, b| a > b).1,
DType::I32 => extremum_dim_with_indices::<i32, _>(&tensor, dim, |a, b| a > b).1,
DType::I64 => extremum_dim_with_indices::<i64, _>(&tensor, dim, |a, b| a > b).1,
_ => panic!("argmax: unsupported dtype {:?}", tensor.dtype()),
}
}
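/// Index of the minimum along `dim`, returned as an `INDEX_DTYPE` tensor with
/// that dimension kept at size 1. NaN handling mirrors `argmax`.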
pub fn argmin(tensor: FlexTensor, dim: usize) -> FlexTensor {
assert_dim_fits_isize(tensor.layout().shape()[dim], dim);
if tensor.dtype() == DType::F32 && dim == tensor.layout().shape().num_dims() - 1 {
#[cfg(feature = "simd")]
if tensor.layout().shape()[dim] >= EXTREMUM_SIMD_ROW_THRESHOLD {
return extremum_indices_f32_last_simd(&tensor, dim, kernels::min_f32);
}
return extremum_indices_f32_last_scalar(&tensor, dim, |a, b| a < b);
}
match tensor.dtype() {
DType::F32 => {
extremum_dim_with_indices::<f32, _>(&tensor, dim, |a, b| {
!b.is_nan() && (a.is_nan() || a < b)
})
.1
}
DType::F64 => {
extremum_dim_with_indices::<f64, _>(&tensor, dim, |a, b| {
!b.is_nan() && (a.is_nan() || a < b)
})
.1
}
DType::F16 => {
extremum_dim_with_indices_half::<f16, _>(
&tensor,
dim,
|a, b| !b.is_nan() && (a.is_nan() || a < b),
f16::to_f32,
f16::from_f32,
)
.1
}
DType::BF16 => {
extremum_dim_with_indices_half::<bf16, _>(
&tensor,
dim,
|a, b| !b.is_nan() && (a.is_nan() || a < b),
bf16::to_f32,
bf16::from_f32,
)
.1
}
DType::I8 => extremum_dim_with_indices::<i8, _>(&tensor, dim, |a, b| a < b).1,
DType::I16 => extremum_dim_with_indices::<i16, _>(&tensor, dim, |a, b| a < b).1,
DType::I32 => extremum_dim_with_indices::<i32, _>(&tensor, dim, |a, b| a < b).1,
DType::I64 => extremum_dim_with_indices::<i64, _>(&tensor, dim, |a, b| a < b).1,
_ => panic!("argmin: unsupported dtype {:?}", tensor.dtype()),
}
}
#[derive(Clone, Copy)]
enum ReduceOp {
Sum,
Prod,
}
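/// f32 sum/prod along a single dimension. Dispatches between several layouts:
/// contiguous last-dimension rows, leading or middle dimensions with contiguous
/// inner elements (scatter-add), a unit-stride reduced dimension, plain
/// contiguous storage, and a fully strided fallback.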
fn reduce_dim_f32(tensor: &FlexTensor, dim: usize, op: ReduceOp) -> FlexTensor {
let ndims = tensor.layout().shape().num_dims();
assert!(
dim < ndims,
"dim {} out of bounds for {} dimensions",
dim,
ndims
);
let outer_dims = dim;
let inner_dims = ndims - dim - 1;
let needs_copy = !tensor.is_contiguous() && (outer_dims > 1 || inner_dims > 1);
let tensor = if needs_copy {
tensor.to_contiguous()
} else {
tensor.clone()
};
let shape = tensor.layout().shape();
let strides = tensor.layout().strides();
let dim_size = shape[dim];
let mut out_shape: Vec<usize> = shape.to_vec();
out_shape[dim] = 1;
let out_size: usize = out_shape.iter().product();
if out_size == 0 {
return FlexTensor::new(
Bytes::from_elems(Vec::<f32>::new()),
Layout::contiguous(Shape::from(out_shape)),
DType::F32,
);
}
let outer_size: usize = shape[..dim].iter().product();
let inner_size: usize = shape[dim + 1..].iter().product();
let data: &[f32] = tensor.storage();
let start_offset = tensor.layout().start_offset();
let dim_stride = strides[dim];
let (init, reduce_fn): (f32, fn(f32, f32) -> f32) = match op {
ReduceOp::Sum => (0.0, |a, b| a + b),
ReduceOp::Prod => (1.0, |a, b| a * b),
};
let has_negative_strides = strides.iter().any(|&s| s < 0);
let inner_contiguous = !has_negative_strides && (dim + 1 >= ndims || strides[ndims - 1] == 1);
let result: Vec<f32> = if inner_contiguous && dim == ndims - 1 && dim_stride == 1 {
reduce_last_dim_f32(data, start_offset, outer_size, dim_size, strides, dim, op)
} else if dim == 0 && inner_contiguous && matches!(op, ReduceOp::Sum) {
reduce_first_dim_f32(data, start_offset, dim_size, inner_size, dim_stride)
} else if dim > 0 && dim < ndims - 1 && inner_contiguous && matches!(op, ReduceOp::Sum) {
let outer_stride = strides[dim - 1];
reduce_middle_dim_f32(
data,
start_offset,
outer_size,
dim_size,
inner_size,
outer_stride,
dim_stride,
)
} else if dim_stride == 1 && matches!(op, ReduceOp::Sum) && outer_size == 1 {
#[cfg(feature = "simd")]
{
let mut result = vec![0.0f32; inner_size];
kernels::sum_rows_f32(
&data[start_offset..],
&mut result,
                inner_size,
                dim_size,
            );
result
}
#[cfg(not(feature = "simd"))]
{
let inner_stride: isize = if dim + 1 < ndims { strides[dim + 1] } else { 1 };
let mut result = Vec::with_capacity(out_size);
for inner in 0..inner_size {
let base = (start_offset as isize + inner as isize * inner_stride) as usize;
let slice = &data[base..base + dim_size];
result.push(slice.iter().copied().sum());
}
result
}
} else if dim_stride == 1 && matches!(op, ReduceOp::Sum) {
let outer_stride: isize = if dim > 0 { strides[dim - 1] } else { 0 };
let inner_stride: isize = if dim + 1 < ndims { strides[dim + 1] } else { 1 };
let mut result = Vec::with_capacity(out_size);
for outer in 0..outer_size {
for inner in 0..inner_size {
let base = (start_offset as isize
+ outer as isize * outer_stride
+ inner as isize * inner_stride) as usize;
let slice = &data[base..base + dim_size];
#[cfg(feature = "simd")]
let acc = kernels::sum_f32(slice);
#[cfg(not(feature = "simd"))]
let acc = slice.iter().copied().sum();
result.push(acc);
}
}
result
} else if tensor.is_contiguous() {
let mut result = Vec::with_capacity(out_size);
for outer in 0..outer_size {
for inner in 0..inner_size {
let mut acc = init;
for d in 0..dim_size {
let idx = start_offset + outer * dim_size * inner_size + d * inner_size + inner;
acc = reduce_fn(acc, data[idx]);
}
result.push(acc);
}
}
result
} else {
let outer_stride: isize = if dim > 0 { strides[dim - 1] } else { 0 };
let inner_stride: isize = if dim + 1 < ndims { strides[dim + 1] } else { 1 };
let mut result = Vec::with_capacity(out_size);
for outer in 0..outer_size {
for inner in 0..inner_size {
let base = start_offset as isize
+ outer as isize * outer_stride
+ inner as isize * inner_stride;
let mut acc = init;
for d in 0..dim_size {
let idx = (base + d as isize * dim_stride) as usize;
acc = reduce_fn(acc, data[idx]);
}
result.push(acc);
}
}
result
};
let bytes = Bytes::from_elems(result);
FlexTensor::new(
bytes,
Layout::contiguous(Shape::from(out_shape)),
DType::F32,
)
}
#[inline]
fn reduce_middle_dim_f32(
data: &[f32],
start_offset: usize,
    outer_size: usize,
    dim_size: usize,
    inner_size: usize,
    outer_stride: isize,
dim_stride: isize,
) -> Vec<f32> {
let out_size = outer_size * inner_size;
#[cfg(feature = "simd")]
{
let mut result = aligned::alloc_aligned_zeroed::<f32>(out_size);
kernels::scatter_add_batched(
&data[start_offset..],
&mut result,
outer_size,
dim_size,
inner_size,
outer_stride as usize,
dim_stride as usize,
);
aligned::to_vec(result)
}
#[cfg(not(feature = "simd"))]
{
let mut result = vec![0.0f32; out_size];
let start = start_offset as isize;
for batch in 0..outer_size {
let batch_start = (start + batch as isize * outer_stride) as usize;
let out_batch_start = batch * inner_size;
for row in 0..dim_size {
let row_start = (batch_start as isize + row as isize * dim_stride) as usize;
for c in 0..inner_size {
result[out_batch_start + c] += data[row_start + c];
}
}
}
result
}
}
#[inline]
fn reduce_first_dim_f32(
data: &[f32],
start_offset: usize,
    dim_size: usize,
    inner_size: usize,
    dim_stride: isize,
) -> Vec<f32> {
#[cfg(feature = "simd")]
{
let mut result = aligned::alloc_aligned_zeroed::<f32>(inner_size);
kernels::scatter_add_f32(
&data[start_offset..],
&mut result,
dim_size,
inner_size,
dim_stride as usize,
);
aligned::to_vec(result)
}
#[cfg(not(feature = "simd"))]
{
let mut result = vec![0.0f32; inner_size];
let start = start_offset as isize;
for row in 0..dim_size {
let row_start = (start + row as isize * dim_stride) as usize;
for c in 0..inner_size {
result[c] += data[row_start + c];
}
}
result
}
}
#[inline]
fn reduce_last_dim_f32(
data: &[f32],
start_offset: usize,
outer_size: usize,
dim_size: usize,
strides: &[isize],
dim: usize,
op: ReduceOp,
) -> Vec<f32> {
let outer_stride: isize = if dim > 0 {
strides[dim - 1]
} else {
dim_size as isize
};
let rows = outer_size;
#[cfg(feature = "simd")]
if matches!(op, ReduceOp::Sum) && outer_stride == dim_size as isize {
let mut result = vec![0.0f32; rows];
kernels::sum_rows_f32(&data[start_offset..], &mut result, rows, dim_size);
return result;
}
let mut result = Vec::with_capacity(rows);
for outer in 0..rows {
let row_start = (start_offset as isize + outer as isize * outer_stride) as usize;
let row = &data[row_start..row_start + dim_size];
let val = match op {
ReduceOp::Sum => {
#[cfg(feature = "simd")]
{
kernels::sum_f32(row)
}
#[cfg(not(feature = "simd"))]
{
row.iter().copied().sum()
}
}
ReduceOp::Prod => row.iter().copied().product(),
};
result.push(val);
}
result
}
fn reduce_dim_impl<E, F>(tensor: &FlexTensor, dim: usize, init: E, reduce_fn: F) -> FlexTensor
where
E: Element + bytemuck::Pod + Copy,
F: Fn(E, E) -> E,
{
let ndims = tensor.layout().shape().num_dims();
assert!(
dim < ndims,
"dim {} out of bounds for {} dimensions",
dim,
ndims
);
let outer_dims = dim;
let inner_dims = ndims - dim - 1;
let needs_copy = !tensor.is_contiguous() && (outer_dims > 1 || inner_dims > 1);
let tensor = if needs_copy {
tensor.to_contiguous()
} else {
tensor.clone()
};
let shape = tensor.layout().shape();
let strides = tensor.layout().strides();
let dim_size = shape[dim];
let mut out_shape: Vec<usize> = shape.to_vec();
out_shape[dim] = 1;
let out_size: usize = out_shape.iter().product();
if out_size == 0 {
return FlexTensor::new(
Bytes::from_elems(Vec::<E>::new()),
Layout::contiguous(Shape::from(out_shape)),
tensor.dtype(),
);
}
let outer_size: usize = shape[..dim].iter().product();
let inner_size: usize = shape[dim + 1..].iter().product();
let data: &[E] = tensor.storage();
let start_offset = tensor.layout().start_offset();
let mut result: Vec<E> = Vec::with_capacity(out_size);
if tensor.is_contiguous() {
for outer in 0..outer_size {
for inner in 0..inner_size {
let mut acc = init;
for d in 0..dim_size {
let idx = start_offset + outer * dim_size * inner_size + d * inner_size + inner;
acc = reduce_fn(acc, data[idx]);
}
result.push(acc);
}
}
} else {
let dim_stride = strides[dim];
let outer_stride: isize = if dim > 0 { strides[dim - 1] } else { 0 };
let inner_stride: isize = if dim + 1 < ndims { strides[dim + 1] } else { 1 };
for outer in 0..outer_size {
for inner in 0..inner_size {
let base = start_offset as isize
+ outer as isize * outer_stride
+ inner as isize * inner_stride;
let mut acc = init;
for d in 0..dim_size {
let idx = (base + d as isize * dim_stride) as usize;
acc = reduce_fn(acc, data[idx]);
}
result.push(acc);
}
}
}
let bytes = Bytes::from_elems(result);
FlexTensor::new(
bytes,
Layout::contiguous(Shape::from(out_shape)),
tensor.dtype(),
)
}
fn reduce_dim_widening<E, F>(tensor: &FlexTensor, dim: usize, init: i64, reduce_fn: F) -> FlexTensor
where
E: Element + bytemuck::Pod,
i64: From<E>,
F: Fn(i64, i64) -> i64,
{
let tensor = tensor.to_contiguous();
let shape = tensor.layout().shape();
let ndims = shape.num_dims();
assert!(
dim < ndims,
"dim {} out of bounds for {} dimensions",
dim,
ndims
);
let dim_size = shape[dim];
let mut out_shape: Vec<usize> = shape.to_vec();
out_shape[dim] = 1;
let out_size: usize = out_shape.iter().product();
if out_size == 0 {
return FlexTensor::new(
Bytes::from_elems(Vec::<E>::new()),
Layout::contiguous(Shape::from(out_shape)),
tensor.dtype(),
);
}
let outer_size: usize = shape[..dim].iter().product();
let inner_size: usize = shape[dim + 1..].iter().product();
let data: &[E] = tensor.storage();
let start_offset = tensor.layout().start_offset();
let mut result: Vec<E> = Vec::with_capacity(out_size);
for outer in 0..outer_size {
for inner in 0..inner_size {
let mut acc = init;
for d in 0..dim_size {
let idx = start_offset + outer * dim_size * inner_size + d * inner_size + inner;
acc = reduce_fn(acc, i64::from(data[idx]));
}
let val: E = truncate_i64_to_pod(acc);
result.push(val);
}
}
let bytes = Bytes::from_elems(result);
FlexTensor::new(
bytes,
Layout::contiguous(Shape::from(out_shape)),
tensor.dtype(),
)
}
fn reduce_dim_half<E, F>(
tensor: &FlexTensor,
dim: usize,
init: f32,
reduce_fn: F,
to_f32: fn(E) -> f32,
from_f32: fn(f32) -> E,
) -> FlexTensor
where
E: Element + bytemuck::Pod,
F: Fn(f32, f32) -> f32,
{
let tensor = tensor.to_contiguous();
let shape = tensor.layout().shape();
let ndims = shape.num_dims();
assert!(
dim < ndims,
"dim {} out of bounds for {} dimensions",
dim,
ndims
);
let dim_size = shape[dim];
let mut out_shape: Vec<usize> = shape.to_vec();
out_shape[dim] = 1;
let out_size: usize = out_shape.iter().product();
if out_size == 0 {
return FlexTensor::new(
Bytes::from_elems(Vec::<E>::new()),
Layout::contiguous(Shape::from(out_shape)),
E::dtype(),
);
}
let outer_size: usize = shape[..dim].iter().product();
let inner_size: usize = shape[dim + 1..].iter().product();
let data: &[E] = tensor.storage();
let start_offset = tensor.layout().start_offset();
let mut result: Vec<E> = Vec::with_capacity(out_size);
for outer in 0..outer_size {
for inner in 0..inner_size {
let mut acc = init;
for d in 0..dim_size {
let idx = start_offset + outer * dim_size * inner_size + d * inner_size + inner;
acc = reduce_fn(acc, to_f32(data[idx]));
}
result.push(from_f32(acc));
}
}
let bytes = Bytes::from_elems(result);
FlexTensor::new(
bytes,
Layout::contiguous(Shape::from(out_shape)),
E::dtype(),
)
}
fn sum_dim_contiguous_f32(
data: &[f32],
outer_size: usize,
dim_size: usize,
inner_size: usize,
) -> Vec<f32> {
if outer_size == 0 || inner_size == 0 {
return Vec::new();
}
if inner_size == 1 {
let rows = outer_size;
#[cfg(feature = "simd")]
{
let mut result = vec![0.0f32; rows];
kernels::sum_rows_f32(data, &mut result, rows, dim_size);
return result;
}
#[cfg(not(feature = "simd"))]
{
return (0..rows)
.map(|i| data[i * dim_size..(i + 1) * dim_size].iter().sum())
.collect();
}
}
if outer_size == 1 {
#[cfg(feature = "simd")]
{
let mut result = aligned::alloc_aligned_zeroed::<f32>(inner_size);
kernels::scatter_add_f32(data, &mut result, dim_size, inner_size, inner_size);
return aligned::to_vec(result);
}
#[cfg(not(feature = "simd"))]
{
let mut result = vec![0.0f32; inner_size];
for row in 0..dim_size {
let row_start = row * inner_size;
for c in 0..inner_size {
result[c] += data[row_start + c];
}
}
return result;
}
}
let out_size = outer_size * inner_size;
#[cfg(feature = "simd")]
{
let mut result = aligned::alloc_aligned_zeroed::<f32>(out_size);
kernels::scatter_add_batched(
data,
&mut result,
outer_size,
dim_size,
inner_size,
dim_size * inner_size,
inner_size,
);
aligned::to_vec(result)
}
#[cfg(not(feature = "simd"))]
{
let mut result = vec![0.0f32; out_size];
for outer in 0..outer_size {
let out_base = outer * inner_size;
for d in 0..dim_size {
let in_base = outer * dim_size * inner_size + d * inner_size;
for c in 0..inner_size {
result[out_base + c] += data[in_base + c];
}
}
}
result
}
}
fn mean_dim_half<E>(tensor: &FlexTensor, dim: usize) -> FlexTensor
where
E: Element + bytemuck::Pod,
{
let tensor = tensor.to_contiguous();
let shape = tensor.layout().shape();
let ndims = shape.num_dims();
assert!(
dim < ndims,
"dim {} out of bounds for {} dimensions",
dim,
ndims
);
let dim_size = shape[dim];
assert!(
dim_size > 0,
"mean_dim: cannot take mean of empty dimension"
);
let mut out_shape: Vec<usize> = shape.to_vec();
out_shape[dim] = 1;
let outer_size: usize = shape[..dim].iter().product();
let inner_size: usize = shape[dim + 1..].iter().product();
let data = float_storage_as_f32(&tensor);
let divisor = dim_size as f32;
let sums = sum_dim_contiguous_f32(&data, outer_size, dim_size, inner_size);
let result: Vec<E> = sums
.into_iter()
.map(|s| E::from_elem(s / divisor))
.collect();
let bytes = Bytes::from_elems(result);
FlexTensor::new(
bytes,
Layout::contiguous(Shape::from(out_shape)),
E::dtype(),
)
}
fn mean_scalar_half<E>(tensor: &FlexTensor) -> FlexTensor
where
E: Element + bytemuck::Pod,
{
let tensor = tensor.to_contiguous();
let n = tensor.layout().num_elements();
let data = float_storage_as_f32(&tensor);
let acc = sum_f32_contiguous(&data);
let mean = acc / (n as f32);
let bytes = Bytes::from_elems(vec![E::from_elem(mean)]);
FlexTensor::new(bytes, Layout::contiguous(Shape::from(vec![1])), E::dtype())
}
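/// Mean over all elements as a shape `[1]` tensor. f16/bf16 accumulate in `f32`;
/// f32/f64 divide the result of `sum` by the element count.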
pub fn mean(tensor: FlexTensor) -> FlexTensor {
let dtype = tensor.dtype();
match dtype {
DType::F16 => return mean_scalar_half::<f16>(&tensor),
DType::BF16 => return mean_scalar_half::<bf16>(&tensor),
_ => {}
}
let n = tensor.layout().num_elements();
let sum_result = sum(tensor);
match dtype {
DType::F32 => scalar_div::<f32>(sum_result, n as f32),
DType::F64 => scalar_div::<f64>(sum_result, n as f64),
_ => panic!("mean: unsupported dtype {:?}", dtype),
}
}
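/// Maximum along `dim` (kept with size 1), with an f32 fast path on the last
/// dimension. NaN propagates for float inputs. Panics if the dimension is empty.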
pub fn max_dim(tensor: FlexTensor, dim: usize) -> FlexTensor {
assert!(
tensor.layout().shape()[dim] > 0,
"max_dim: dimension {dim} has size 0"
);
if tensor.dtype() == DType::F32 && dim == tensor.layout().shape().num_dims() - 1 {
#[cfg(feature = "simd")]
if tensor.layout().shape()[dim] >= EXTREMUM_SIMD_ROW_THRESHOLD {
return extremum_dim_f32_last_simd(&tensor, dim, kernels::max_f32);
}
return extremum_f32_last_scalar(&tensor, dim, |a, b| a > b);
}
match tensor.dtype() {
DType::F32 => {
extremum_dim::<f32, _>(&tensor, dim, |a, b| !b.is_nan() && (a.is_nan() || a > b))
}
DType::F64 => {
extremum_dim::<f64, _>(&tensor, dim, |a, b| !b.is_nan() && (a.is_nan() || a > b))
}
DType::F16 => extremum_dim_half::<f16, _>(
&tensor,
dim,
|a, b| !b.is_nan() && (a.is_nan() || a > b),
f16::to_f32,
f16::from_f32,
),
DType::BF16 => extremum_dim_half::<bf16, _>(
&tensor,
dim,
|a, b| !b.is_nan() && (a.is_nan() || a > b),
bf16::to_f32,
bf16::from_f32,
),
DType::I64 => extremum_dim::<i64, _>(&tensor, dim, |a, b| a > b),
DType::I32 => extremum_dim::<i32, _>(&tensor, dim, |a, b| a > b),
DType::I16 => extremum_dim::<i16, _>(&tensor, dim, |a, b| a > b),
DType::I8 => extremum_dim::<i8, _>(&tensor, dim, |a, b| a > b),
DType::U64 => extremum_dim::<u64, _>(&tensor, dim, |a, b| a > b),
DType::U32 => extremum_dim::<u32, _>(&tensor, dim, |a, b| a > b),
DType::U16 => extremum_dim::<u16, _>(&tensor, dim, |a, b| a > b),
DType::U8 => extremum_dim::<u8, _>(&tensor, dim, |a, b| a > b),
_ => panic!("max_dim: unsupported dtype {:?}", tensor.dtype()),
}
}
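/// Minimum along `dim` (kept with size 1), with an f32 fast path on the last
/// dimension. NaN propagates for float inputs. Panics if the dimension is empty.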
pub fn min_dim(tensor: FlexTensor, dim: usize) -> FlexTensor {
assert!(
tensor.layout().shape()[dim] > 0,
"min_dim: dimension {dim} has size 0"
);
if tensor.dtype() == DType::F32 && dim == tensor.layout().shape().num_dims() - 1 {
#[cfg(feature = "simd")]
if tensor.layout().shape()[dim] >= EXTREMUM_SIMD_ROW_THRESHOLD {
return extremum_dim_f32_last_simd(&tensor, dim, kernels::min_f32);
}
return extremum_f32_last_scalar(&tensor, dim, |a, b| a < b);
}
match tensor.dtype() {
DType::F32 => {
extremum_dim::<f32, _>(&tensor, dim, |a, b| !b.is_nan() && (a.is_nan() || a < b))
}
DType::F64 => {
extremum_dim::<f64, _>(&tensor, dim, |a, b| !b.is_nan() && (a.is_nan() || a < b))
}
DType::F16 => extremum_dim_half::<f16, _>(
&tensor,
dim,
|a, b| !b.is_nan() && (a.is_nan() || a < b),
f16::to_f32,
f16::from_f32,
),
DType::BF16 => extremum_dim_half::<bf16, _>(
&tensor,
dim,
|a, b| !b.is_nan() && (a.is_nan() || a < b),
bf16::to_f32,
bf16::from_f32,
),
DType::I64 => extremum_dim::<i64, _>(&tensor, dim, |a, b| a < b),
DType::I32 => extremum_dim::<i32, _>(&tensor, dim, |a, b| a < b),
DType::I16 => extremum_dim::<i16, _>(&tensor, dim, |a, b| a < b),
DType::I8 => extremum_dim::<i8, _>(&tensor, dim, |a, b| a < b),
DType::U64 => extremum_dim::<u64, _>(&tensor, dim, |a, b| a < b),
DType::U32 => extremum_dim::<u32, _>(&tensor, dim, |a, b| a < b),
DType::U16 => extremum_dim::<u16, _>(&tensor, dim, |a, b| a < b),
DType::U8 => extremum_dim::<u8, _>(&tensor, dim, |a, b| a < b),
_ => panic!("min_dim: unsupported dtype {:?}", tensor.dtype()),
}
}
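/// Maximum along `dim` together with the position of the winning element,
/// returned as `(values, indices)` tensors that both keep `dim` at size 1.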
pub fn max_dim_with_indices(tensor: FlexTensor, dim: usize) -> (FlexTensor, FlexTensor) {
let dim_len = tensor.layout().shape()[dim];
assert!(
dim_len > 0,
"max_dim_with_indices: dimension {dim} has size 0"
);
assert_dim_fits_isize(dim_len, dim);
if tensor.dtype() == DType::F32 && dim == tensor.layout().shape().num_dims() - 1 {
#[cfg(feature = "simd")]
if tensor.layout().shape()[dim] >= EXTREMUM_SIMD_ROW_THRESHOLD {
return extremum_dim_with_indices_f32_last_simd(&tensor, dim, kernels::max_f32);
}
return extremum_with_indices_f32_last_scalar(&tensor, dim, |a, b| a > b);
}
match tensor.dtype() {
DType::F32 => extremum_dim_with_indices::<f32, _>(&tensor, dim, |a, b| {
!b.is_nan() && (a.is_nan() || a > b)
}),
DType::F64 => extremum_dim_with_indices::<f64, _>(&tensor, dim, |a, b| {
!b.is_nan() && (a.is_nan() || a > b)
}),
DType::F16 => extremum_dim_with_indices_half::<f16, _>(
&tensor,
dim,
|a, b| !b.is_nan() && (a.is_nan() || a > b),
f16::to_f32,
f16::from_f32,
),
DType::BF16 => extremum_dim_with_indices_half::<bf16, _>(
&tensor,
dim,
|a, b| !b.is_nan() && (a.is_nan() || a > b),
bf16::to_f32,
bf16::from_f32,
),
DType::I64 => extremum_dim_with_indices::<i64, _>(&tensor, dim, |a, b| a > b),
DType::I32 => extremum_dim_with_indices::<i32, _>(&tensor, dim, |a, b| a > b),
DType::I16 => extremum_dim_with_indices::<i16, _>(&tensor, dim, |a, b| a > b),
DType::I8 => extremum_dim_with_indices::<i8, _>(&tensor, dim, |a, b| a > b),
DType::U64 => extremum_dim_with_indices::<u64, _>(&tensor, dim, |a, b| a > b),
DType::U32 => extremum_dim_with_indices::<u32, _>(&tensor, dim, |a, b| a > b),
DType::U16 => extremum_dim_with_indices::<u16, _>(&tensor, dim, |a, b| a > b),
DType::U8 => extremum_dim_with_indices::<u8, _>(&tensor, dim, |a, b| a > b),
_ => panic!(
"max_dim_with_indices: unsupported dtype {:?}",
tensor.dtype()
),
}
}
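/// Minimum along `dim` together with the position of the winning element,
/// returned as `(values, indices)` tensors that both keep `dim` at size 1.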
pub fn min_dim_with_indices(tensor: FlexTensor, dim: usize) -> (FlexTensor, FlexTensor) {
let dim_len = tensor.layout().shape()[dim];
assert!(
dim_len > 0,
"min_dim_with_indices: dimension {dim} has size 0"
);
assert_dim_fits_isize(dim_len, dim);
if tensor.dtype() == DType::F32 && dim == tensor.layout().shape().num_dims() - 1 {
#[cfg(feature = "simd")]
if tensor.layout().shape()[dim] >= EXTREMUM_SIMD_ROW_THRESHOLD {
return extremum_dim_with_indices_f32_last_simd(&tensor, dim, kernels::min_f32);
}
return extremum_with_indices_f32_last_scalar(&tensor, dim, |a, b| a < b);
}
match tensor.dtype() {
DType::F32 => extremum_dim_with_indices::<f32, _>(&tensor, dim, |a, b| {
!b.is_nan() && (a.is_nan() || a < b)
}),
DType::F64 => extremum_dim_with_indices::<f64, _>(&tensor, dim, |a, b| {
!b.is_nan() && (a.is_nan() || a < b)
}),
DType::F16 => extremum_dim_with_indices_half::<f16, _>(
&tensor,
dim,
|a, b| !b.is_nan() && (a.is_nan() || a < b),
f16::to_f32,
f16::from_f32,
),
DType::BF16 => extremum_dim_with_indices_half::<bf16, _>(
&tensor,
dim,
|a, b| !b.is_nan() && (a.is_nan() || a < b),
bf16::to_f32,
bf16::from_f32,
),
DType::I64 => extremum_dim_with_indices::<i64, _>(&tensor, dim, |a, b| a < b),
DType::I32 => extremum_dim_with_indices::<i32, _>(&tensor, dim, |a, b| a < b),
DType::I16 => extremum_dim_with_indices::<i16, _>(&tensor, dim, |a, b| a < b),
DType::I8 => extremum_dim_with_indices::<i8, _>(&tensor, dim, |a, b| a < b),
DType::U64 => extremum_dim_with_indices::<u64, _>(&tensor, dim, |a, b| a < b),
DType::U32 => extremum_dim_with_indices::<u32, _>(&tensor, dim, |a, b| a < b),
DType::U16 => extremum_dim_with_indices::<u16, _>(&tensor, dim, |a, b| a < b),
DType::U8 => extremum_dim_with_indices::<u8, _>(&tensor, dim, |a, b| a < b),
_ => panic!(
"min_dim_with_indices: unsupported dtype {:?}",
tensor.dtype()
),
}
}
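/// Minimum total element count before extremum rows are processed in parallel with rayon.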
#[cfg(feature = "rayon")]
const EXTREMUM_PARALLEL_THRESHOLD: usize = 32 * 1024;
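/// Minimum last-dimension length before the f32 extremum fast paths route to the SIMD kernels.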
#[cfg(feature = "simd")]
const EXTREMUM_SIMD_ROW_THRESHOLD: usize = 512;
fn extremum_f32_last_scalar<F>(tensor: &FlexTensor, dim: usize, is_better: F) -> FlexTensor
where
F: Fn(f32, f32) -> bool + Send + Sync,
{
let tensor = tensor.to_contiguous();
let shape = tensor.layout().shape();
let dim_size = shape[dim];
let outer_size: usize = shape[..dim].iter().product();
let data: &[f32] = tensor.storage();
let start = tensor.layout().start_offset();
let mut out_shape: Vec<usize> = shape.to_vec();
out_shape[dim] = 1;
let reduce_row = |outer: usize| -> f32 {
let row_start = start + outer * dim_size;
let row = &data[row_start..row_start + dim_size];
let mut best = row[0];
for &v in row {
if v.is_nan() {
return f32::NAN;
}
if is_better(v, best) {
best = v;
}
}
best
};
#[cfg(feature = "rayon")]
let values: Vec<f32> = if outer_size * dim_size >= EXTREMUM_PARALLEL_THRESHOLD {
(0..outer_size).into_par_iter().map(&reduce_row).collect()
} else {
(0..outer_size).map(reduce_row).collect()
};
#[cfg(not(feature = "rayon"))]
let values: Vec<f32> = (0..outer_size).map(reduce_row).collect();
FlexTensor::new(
Bytes::from_elems(values),
Layout::contiguous(Shape::from(out_shape)),
DType::F32,
)
}
fn extremum_indices_f32_last_scalar<F>(tensor: &FlexTensor, dim: usize, is_better: F) -> FlexTensor
where
F: Fn(f32, f32) -> bool + Send + Sync,
{
let tensor = tensor.to_contiguous();
let shape = tensor.layout().shape();
let dim_size = shape[dim];
let outer_size: usize = shape[..dim].iter().product();
let data: &[f32] = tensor.storage();
let start = tensor.layout().start_offset();
let mut out_shape: Vec<usize> = shape.to_vec();
out_shape[dim] = 1;
let find_row = |outer: usize| -> isize {
let row_start = start + outer * dim_size;
let row = &data[row_start..row_start + dim_size];
let mut best = row[0];
let mut best_idx: isize = 0;
for (i, &v) in row.iter().enumerate() {
if v.is_nan() {
return i as isize;
}
if is_better(v, best) {
best = v;
best_idx = i as isize;
}
}
best_idx
};
#[cfg(feature = "rayon")]
let indices: Vec<isize> = if outer_size * dim_size >= EXTREMUM_PARALLEL_THRESHOLD {
(0..outer_size).into_par_iter().map(find_row).collect()
} else {
(0..outer_size).map(find_row).collect()
};
#[cfg(not(feature = "rayon"))]
let indices: Vec<isize> = (0..outer_size).map(find_row).collect();
FlexTensor::new(
Bytes::from_elems(indices),
Layout::contiguous(Shape::from(out_shape)),
INDEX_DTYPE,
)
}
fn extremum_with_indices_f32_last_scalar<F>(
tensor: &FlexTensor,
dim: usize,
is_better: F,
) -> (FlexTensor, FlexTensor)
where
F: Fn(f32, f32) -> bool + Send + Sync,
{
let tensor = tensor.to_contiguous();
let shape = tensor.layout().shape();
let dim_size = shape[dim];
let outer_size: usize = shape[..dim].iter().product();
let data: &[f32] = tensor.storage();
let start = tensor.layout().start_offset();
let mut out_shape: Vec<usize> = shape.to_vec();
out_shape[dim] = 1;
let find_row = |outer: usize| -> (f32, isize) {
let row_start = start + outer * dim_size;
let row = &data[row_start..row_start + dim_size];
let mut best = row[0];
let mut best_idx: isize = 0;
for (i, &v) in row.iter().enumerate() {
if v.is_nan() {
return (f32::NAN, i as isize);
}
if is_better(v, best) {
best = v;
best_idx = i as isize;
}
}
(best, best_idx)
};
#[cfg(feature = "rayon")]
let (values, indices): (Vec<f32>, Vec<isize>) =
if outer_size * dim_size >= EXTREMUM_PARALLEL_THRESHOLD {
(0..outer_size).into_par_iter().map(find_row).unzip()
} else {
(0..outer_size).map(find_row).unzip()
};
#[cfg(not(feature = "rayon"))]
let (values, indices): (Vec<f32>, Vec<isize>) = (0..outer_size).map(find_row).unzip();
let val_tensor = FlexTensor::new(
Bytes::from_elems(values),
Layout::contiguous(Shape::from(out_shape.clone())),
DType::F32,
);
let idx_tensor = FlexTensor::new(
Bytes::from_elems(indices),
Layout::contiguous(Shape::from(out_shape)),
INDEX_DTYPE,
);
(val_tensor, idx_tensor)
}
#[cfg(feature = "simd")]
fn extremum_dim_f32_last_simd(
tensor: &FlexTensor,
dim: usize,
simd_reduce: fn(&[f32]) -> f32,
) -> FlexTensor {
let tensor = tensor.to_contiguous();
let shape = tensor.layout().shape();
let dim_size = shape[dim];
let outer_size: usize = shape[..dim].iter().product();
let data: &[f32] = tensor.storage();
let start = tensor.layout().start_offset();
let mut out_shape: Vec<usize> = shape.to_vec();
out_shape[dim] = 1;
let reduce_row = |outer: usize| -> f32 {
let row_start = start + outer * dim_size;
let row = &data[row_start..row_start + dim_size];
let ext = simd_reduce(row);
if ext.is_nan() {
return f32::NAN;
}
for &v in row {
if v.is_nan() {
return f32::NAN;
}
}
ext
};
#[cfg(feature = "rayon")]
let values: Vec<f32> = if outer_size * dim_size >= EXTREMUM_PARALLEL_THRESHOLD {
(0..outer_size).into_par_iter().map(reduce_row).collect()
} else {
(0..outer_size).map(reduce_row).collect()
};
#[cfg(not(feature = "rayon"))]
let values: Vec<f32> = (0..outer_size).map(reduce_row).collect();
FlexTensor::new(
Bytes::from_elems(values),
Layout::contiguous(Shape::from(out_shape)),
DType::F32,
)
}
#[cfg(feature = "simd")]
fn extremum_dim_with_indices_f32_last_simd(
tensor: &FlexTensor,
dim: usize,
simd_reduce: fn(&[f32]) -> f32,
) -> (FlexTensor, FlexTensor) {
let tensor = tensor.to_contiguous();
let shape = tensor.layout().shape();
let dim_size = shape[dim];
let outer_size: usize = shape[..dim].iter().product();
let data: &[f32] = tensor.storage();
let start = tensor.layout().start_offset();
let mut out_shape: Vec<usize> = shape.to_vec();
out_shape[dim] = 1;
let find_row = |outer: usize| -> (f32, isize) {
let row_start = start + outer * dim_size;
let row = &data[row_start..row_start + dim_size];
let ext = simd_reduce(row);
for (i, &v) in row.iter().enumerate() {
if v.is_nan() {
return (f32::NAN, i as isize);
}
if v == ext {
return (ext, i as isize);
}
}
(ext, 0)
};
#[cfg(feature = "rayon")]
let (values, indices): (Vec<f32>, Vec<isize>) =
if outer_size * dim_size >= EXTREMUM_PARALLEL_THRESHOLD {
(0..outer_size).into_par_iter().map(find_row).unzip()
} else {
(0..outer_size).map(find_row).unzip()
};
#[cfg(not(feature = "rayon"))]
let (values, indices): (Vec<f32>, Vec<isize>) = (0..outer_size).map(find_row).unzip();
let val_tensor = FlexTensor::new(
Bytes::from_elems(values),
Layout::contiguous(Shape::from(out_shape.clone())),
DType::F32,
);
let idx_tensor = FlexTensor::new(
Bytes::from_elems(indices),
Layout::contiguous(Shape::from(out_shape)),
INDEX_DTYPE,
);
(val_tensor, idx_tensor)
}
#[cfg(feature = "simd")]
fn extremum_indices_f32_last_simd(
tensor: &FlexTensor,
dim: usize,
simd_reduce: fn(&[f32]) -> f32,
) -> FlexTensor {
let tensor = tensor.to_contiguous();
let shape = tensor.layout().shape();
let dim_size = shape[dim];
let outer_size: usize = shape[..dim].iter().product();
let data: &[f32] = tensor.storage();
let start = tensor.layout().start_offset();
let mut out_shape: Vec<usize> = shape.to_vec();
out_shape[dim] = 1;
let find_row = |outer: usize| -> isize {
let row_start = start + outer * dim_size;
let row = &data[row_start..row_start + dim_size];
let ext = simd_reduce(row);
for (i, &v) in row.iter().enumerate() {
if v.is_nan() || v == ext {
return i as isize;
}
}
0
};
#[cfg(feature = "rayon")]
let indices: Vec<isize> = if outer_size * dim_size >= EXTREMUM_PARALLEL_THRESHOLD {
(0..outer_size).into_par_iter().map(find_row).collect()
} else {
(0..outer_size).map(find_row).collect()
};
#[cfg(not(feature = "rayon"))]
let indices: Vec<isize> = (0..outer_size).map(find_row).collect();
FlexTensor::new(
Bytes::from_elems(indices),
Layout::contiguous(Shape::from(out_shape)),
INDEX_DTYPE,
)
}
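/// Generic per-dimension extremum over a contiguous copy; `is_better(candidate, best)`
/// decides whether a new element replaces the running value.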
fn extremum_dim<E, F>(tensor: &FlexTensor, dim: usize, is_better: F) -> FlexTensor
where
E: Element + bytemuck::Pod + Send + Sync,
F: Fn(E, E) -> bool + Send + Sync,
{
let tensor = tensor.to_contiguous();
let shape = tensor.layout().shape();
let ndims = shape.num_dims();
assert!(dim < ndims);
let dim_size = shape[dim];
let mut out_shape: Vec<usize> = shape.to_vec();
out_shape[dim] = 1;
let outer_size: usize = shape[..dim].iter().product::<usize>();
let inner_size: usize = shape[dim + 1..].iter().product::<usize>();
let out_size = outer_size * inner_size;
let data: &[E] = tensor.storage();
let start_offset = tensor.layout().start_offset();
let find = |flat_idx: usize| -> E {
let outer = flat_idx / inner_size;
let inner = flat_idx % inner_size;
let base = start_offset + outer * dim_size * inner_size + inner;
let mut best = data[base];
for d in 1..dim_size {
let val = data[base + d * inner_size];
if is_better(val, best) {
best = val;
}
}
best
};
#[cfg(feature = "rayon")]
let values: Vec<E> = if out_size * dim_size >= EXTREMUM_PARALLEL_THRESHOLD {
(0..out_size).into_par_iter().map(&find).collect()
} else {
(0..out_size).map(find).collect()
};
#[cfg(not(feature = "rayon"))]
let values: Vec<E> = (0..out_size).map(find).collect();
FlexTensor::new(
Bytes::from_elems(values),
Layout::contiguous(Shape::from(out_shape)),
E::dtype(),
)
}
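/// Like `extremum_dim`, but also records the winning position along `dim` in an
/// `INDEX_DTYPE` tensor.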
fn extremum_dim_with_indices<E, F>(
tensor: &FlexTensor,
dim: usize,
is_better: F,
) -> (FlexTensor, FlexTensor)
where
E: Element + bytemuck::Pod + Send + Sync,
F: Fn(E, E) -> bool + Send + Sync,
{
let tensor = tensor.to_contiguous();
let shape = tensor.layout().shape();
let ndims = shape.num_dims();
assert!(dim < ndims);
let dim_size = shape[dim];
let mut out_shape: Vec<usize> = shape.to_vec();
out_shape[dim] = 1;
let outer_size: usize = shape[..dim].iter().product::<usize>();
let inner_size: usize = shape[dim + 1..].iter().product::<usize>();
let out_size = outer_size * inner_size;
let data: &[E] = tensor.storage();
let start_offset = tensor.layout().start_offset();
let find = |flat_idx: usize| -> (E, isize) {
let outer = flat_idx / inner_size;
let inner = flat_idx % inner_size;
let base = start_offset + outer * dim_size * inner_size + inner;
let mut best = data[base];
let mut best_idx: isize = 0;
for d in 1..dim_size {
let val = data[base + d * inner_size];
if is_better(val, best) {
best = val;
best_idx = d as isize;
}
}
(best, best_idx)
};
#[cfg(feature = "rayon")]
let (values, indices): (Vec<E>, Vec<isize>) =
if out_size * dim_size >= EXTREMUM_PARALLEL_THRESHOLD {
(0..out_size).into_par_iter().map(&find).unzip()
} else {
(0..out_size).map(find).unzip()
};
#[cfg(not(feature = "rayon"))]
let (values, indices): (Vec<E>, Vec<isize>) = (0..out_size).map(find).unzip();
let val_tensor = FlexTensor::new(
Bytes::from_elems(values),
Layout::contiguous(Shape::from(out_shape.clone())),
E::dtype(),
);
let idx_tensor = FlexTensor::new(
Bytes::from_elems(indices),
Layout::contiguous(Shape::from(out_shape)),
INDEX_DTYPE,
);
(val_tensor, idx_tensor)
}
fn extremum_dim_half<E, F>(
tensor: &FlexTensor,
dim: usize,
is_better: F,
to_f32: fn(E) -> f32,
from_f32: fn(f32) -> E,
) -> FlexTensor
where
E: Element + bytemuck::Pod + Send + Sync,
F: Fn(f32, f32) -> bool + Send + Sync,
{
let tensor = tensor.to_contiguous();
let shape = tensor.layout().shape();
let ndims = shape.num_dims();
assert!(dim < ndims);
let dim_size = shape[dim];
let mut out_shape: Vec<usize> = shape.to_vec();
out_shape[dim] = 1;
let outer_size: usize = shape[..dim].iter().product::<usize>();
let inner_size: usize = shape[dim + 1..].iter().product::<usize>();
let out_size = outer_size * inner_size;
let data: &[E] = tensor.storage();
let start_offset = tensor.layout().start_offset();
let find = |flat_idx: usize| -> E {
let outer = flat_idx / inner_size;
let inner = flat_idx % inner_size;
let base = start_offset + outer * dim_size * inner_size + inner;
let mut best = to_f32(data[base]);
for d in 1..dim_size {
let val = to_f32(data[base + d * inner_size]);
if is_better(val, best) {
best = val;
}
}
from_f32(best)
};
#[cfg(feature = "rayon")]
let values: Vec<E> = if out_size * dim_size >= EXTREMUM_PARALLEL_THRESHOLD {
(0..out_size).into_par_iter().map(&find).collect()
} else {
(0..out_size).map(find).collect()
};
#[cfg(not(feature = "rayon"))]
let values: Vec<E> = (0..out_size).map(find).collect();
FlexTensor::new(
Bytes::from_elems(values),
Layout::contiguous(Shape::from(out_shape)),
E::dtype(),
)
}
fn extremum_dim_with_indices_half<E, F>(
tensor: &FlexTensor,
dim: usize,
is_better: F,
to_f32: fn(E) -> f32,
from_f32: fn(f32) -> E,
) -> (FlexTensor, FlexTensor)
where
E: Element + bytemuck::Pod + Send + Sync,
F: Fn(f32, f32) -> bool + Send + Sync,
{
let tensor = tensor.to_contiguous();
let shape = tensor.layout().shape();
let ndims = shape.num_dims();
assert!(dim < ndims);
let dim_size = shape[dim];
let mut out_shape: Vec<usize> = shape.to_vec();
out_shape[dim] = 1;
let outer_size: usize = shape[..dim].iter().product::<usize>();
let inner_size: usize = shape[dim + 1..].iter().product::<usize>();
let out_size = outer_size * inner_size;
let data: &[E] = tensor.storage();
let start_offset = tensor.layout().start_offset();
let find = |flat_idx: usize| -> (E, isize) {
let outer = flat_idx / inner_size;
let inner = flat_idx % inner_size;
let base = start_offset + outer * dim_size * inner_size + inner;
let mut best = to_f32(data[base]);
let mut best_idx: isize = 0;
for d in 1..dim_size {
let val = to_f32(data[base + d * inner_size]);
if is_better(val, best) {
best = val;
best_idx = d as isize;
}
}
(from_f32(best), best_idx)
};
#[cfg(feature = "rayon")]
let (values, indices): (Vec<E>, Vec<isize>) =
if out_size * dim_size >= EXTREMUM_PARALLEL_THRESHOLD {
(0..out_size).into_par_iter().map(&find).unzip()
} else {
(0..out_size).map(find).unzip()
};
#[cfg(not(feature = "rayon"))]
let (values, indices): (Vec<E>, Vec<isize>) = (0..out_size).map(find).unzip();
let val_tensor = FlexTensor::new(
Bytes::from_elems(values),
Layout::contiguous(Shape::from(out_shape.clone())),
E::dtype(),
);
let idx_tensor = FlexTensor::new(
Bytes::from_elems(indices),
Layout::contiguous(Shape::from(out_shape)),
INDEX_DTYPE,
);
(val_tensor, idx_tensor)
}
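/// Divides every element in place by `divisor`; used to turn summed values into means.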
fn scalar_div<E: Element + bytemuck::Pod + core::ops::Div<Output = E> + Copy>(
mut tensor: FlexTensor,
divisor: E,
) -> FlexTensor {
let data: &mut [E] = tensor.storage_mut();
for x in data.iter_mut() {
*x = *x / divisor;
}
tensor
}
#[cfg(test)]
mod tests {
use alloc::vec;
use burn_backend::TensorData;
use burn_backend::ops::{FloatTensorOps, IntTensorOps};
use burn_std::{bf16, f16};
use crate::{Flex, FlexTensor};
#[test]
fn test_mean_f16_overflow_intermediate_sum() {
let data: Vec<f16> = (0..1024).map(|i| f16::from_f32(i as f32)).collect();
let tensor = FlexTensor::from_data(TensorData::new(data, [1024]));
let result = Flex::float_mean(tensor);
let result_data = result.into_data();
let values: &[f16] = bytemuck::cast_slice(&result_data.bytes);
assert_eq!(values.len(), 1);
let mean = values[0].to_f32();
assert!(mean.is_finite(), "mean overflowed to {mean}");
assert!((mean - 511.5).abs() < 0.5, "expected ~511.5, got {mean}");
}
#[test]
fn test_mean_dim_f16_zero_outer_dim() {
let data: Vec<f16> = Vec::new();
let tensor = FlexTensor::from_data(TensorData::new(data, [0, 4]));
let result = Flex::float_mean_dim(tensor, 1);
assert_eq!(result.layout().shape().to_vec(), vec![0, 1]);
let result_data = result.into_data();
let values: &[f16] = bytemuck::cast_slice(&result_data.bytes);
assert!(values.is_empty());
}
#[test]
fn test_sum_dim_f32_zero_outer_dim() {
let data: Vec<f32> = Vec::new();
let tensor = FlexTensor::from_data(TensorData::new(data, [0, 4]));
let result = Flex::float_sum_dim(tensor, 1);
assert_eq!(result.layout().shape().to_vec(), vec![0, 1]);
assert!(result.into_data().bytes.is_empty());
}
#[test]
fn test_sum_dim_f32_zero_inner_dim() {
let data: Vec<f32> = Vec::new();
let tensor = FlexTensor::from_data(TensorData::new(data, [3, 0]));
let result = Flex::float_sum_dim(tensor, 0);
assert_eq!(result.layout().shape().to_vec(), vec![1, 0]);
assert!(result.into_data().bytes.is_empty());
}
#[test]
fn test_sum_dim_f64_zero_outer_dim() {
let data: Vec<f64> = Vec::new();
let tensor = FlexTensor::from_data(TensorData::new(data, [0, 4]));
let result = Flex::float_sum_dim(tensor, 1);
assert_eq!(result.layout().shape().to_vec(), vec![0, 1]);
assert!(result.into_data().bytes.is_empty());
}
#[test]
fn test_sum_dim_i8_zero_outer_dim() {
let data: Vec<i8> = Vec::new();
let tensor = FlexTensor::from_data(TensorData::new(data, [0, 4]));
let result = Flex::int_sum_dim(tensor, 1);
assert_eq!(result.layout().shape().to_vec(), vec![0, 1]);
assert!(result.into_data().bytes.is_empty());
}
#[test]
fn test_sum_dim_bf16_zero_outer_dim() {
let data: Vec<bf16> = Vec::new();
let tensor = FlexTensor::from_data(TensorData::new(data, [0, 4]));
let result = Flex::float_sum_dim(tensor, 1);
assert_eq!(result.layout().shape().to_vec(), vec![0, 1]);
assert!(result.into_data().bytes.is_empty());
}
#[test]
fn test_mean_dim_i8_large_dimension() {
let mut data: Vec<i8> = vec![0i8; 200];
data[0] = 100;
let tensor = FlexTensor::from_data(TensorData::new(data, [1, 200]));
let result = Flex::int_mean_dim(tensor, 1);
let result_data = result.into_data();
let values: Vec<i8> = bytemuck::cast_slice(&result_data.bytes).to_vec();
assert_eq!(values, vec![0]);
}
#[test]
fn test_mean_dim_i16_large_dimension() {
let mut data: Vec<i16> = vec![0i16; 40000];
data[0] = 32000;
let tensor = FlexTensor::from_data(TensorData::new(data, [1, 40000]));
let result = Flex::int_mean_dim(tensor, 1);
let result_data = result.into_data();
let values: Vec<i16> = bytemuck::cast_slice(&result_data.bytes).to_vec();
assert_eq!(values, vec![0]);
}
#[test]
fn test_sum_i32() {
let data: Vec<i32> = vec![1, 2, 3, 4, 5];
let tensor = FlexTensor::from_data(TensorData::new(data, [5]));
let result = Flex::int_sum(tensor);
assert_eq!(result.layout().shape().to_vec(), vec![1]);
let result_data = result.into_data();
let values: Vec<i32> = bytemuck::cast_slice(&result_data.bytes).to_vec();
assert_eq!(values, vec![15]);
}
#[test]
fn test_sum_dim_i32() {
let data: Vec<i32> = vec![1, 2, 3, 4, 5, 6];
let tensor = FlexTensor::from_data(TensorData::new(data, [2, 3]));
let result = Flex::int_sum_dim(tensor, 1);
assert_eq!(result.layout().shape().to_vec(), vec![2, 1]);
let result_data = result.into_data();
let values: Vec<i32> = bytemuck::cast_slice(&result_data.bytes).to_vec();
assert_eq!(values, vec![6, 15]);
}
#[test]
fn test_argmax_i32() {
let data: Vec<i32> = vec![1, 5, 3, 2, 4];
let tensor = FlexTensor::from_data(TensorData::new(data, [5]));
let result = Flex::int_argmax(tensor, 0);
assert_eq!(result.layout().shape().to_vec(), vec![1]);
let result_data = result.into_data();
#[cfg(target_pointer_width = "64")]
let values: Vec<i64> = bytemuck::cast_slice(&result_data.bytes).to_vec();
#[cfg(target_pointer_width = "32")]
let values: Vec<i64> = bytemuck::cast_slice::<u8, i32>(&result_data.bytes)
.iter()
.map(|&v| v as i64)
.collect();
assert_eq!(values, vec![1]);
}
#[test]
#[should_panic(expected = "dimension 0 has size 0")]
fn test_max_dim_zero_size_panics() {
let tensor = FlexTensor::from_data(TensorData::new(Vec::<f32>::new(), [0, 3]));
Flex::float_max_dim(tensor, 0);
}
#[test]
#[should_panic(expected = "dimension 1 has size 0")]
fn test_min_dim_zero_size_panics() {
let tensor = FlexTensor::from_data(TensorData::new(Vec::<f32>::new(), [3, 0]));
Flex::float_min_dim(tensor, 1);
}
#[test]
fn test_sum_u32() {
let tensor = FlexTensor::from_data(TensorData::new(vec![10u32, 20, 30], [3]));
let result = Flex::int_sum(tensor);
let data: Vec<u32> = result.into_data().to_vec().unwrap();
assert_eq!(data, vec![60]);
}
#[test]
fn test_sum_u64() {
let tensor = FlexTensor::from_data(TensorData::new(vec![100u64, 200, 300], [3]));
let result = Flex::int_sum(tensor);
let data: Vec<u64> = result.into_data().to_vec().unwrap();
assert_eq!(data, vec![600]);
}
#[test]
fn test_sum_dim_u8() {
let tensor = FlexTensor::from_data(TensorData::new(vec![1u8, 2, 3, 4], [2, 2]));
let result = Flex::int_sum_dim(tensor, 1);
let data: Vec<u8> = result.into_data().to_vec().unwrap();
assert_eq!(data, vec![3, 7]);
}
#[test]
fn test_prod_u16() {
let tensor = FlexTensor::from_data(TensorData::new(vec![2u16, 3, 5], [3]));
let result = Flex::int_prod(tensor);
let data: Vec<u16> = result.into_data().to_vec().unwrap();
assert_eq!(data, vec![30]);
}
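    // A small additional check of the widening product path for i32 (values are
    // chosen so the i64 accumulator fits the output dtype without wrapping); it
    // calls the module-level `prod_dim` directly rather than a backend trait.
    #[test]
    fn test_prod_dim_i32_widening() {
        let data: Vec<i32> = vec![1, 2, 3, 4, 5, 6];
        let tensor = FlexTensor::from_data(TensorData::new(data, [2, 3]));
        let result = super::prod_dim(tensor, 1);
        assert_eq!(result.layout().shape().to_vec(), vec![2, 1]);
        let result_data = result.into_data();
        let values: Vec<i32> = bytemuck::cast_slice(&result_data.bytes).to_vec();
        assert_eq!(values, vec![6, 120]);
    }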
#[test]
fn test_max_u32() {
let tensor = FlexTensor::from_data(TensorData::new(vec![5u32, 100, 42], [3]));
let result = Flex::int_max(tensor);
let data: Vec<u32> = result.into_data().to_vec().unwrap();
assert_eq!(data, vec![100]);
}
#[test]
fn test_min_u8() {
let tensor = FlexTensor::from_data(TensorData::new(vec![5u8, 1, 42], [3]));
let result = Flex::int_min(tensor);
let data: Vec<u8> = result.into_data().to_vec().unwrap();
assert_eq!(data, vec![1]);
}
#[test]
fn test_max_dim_u64() {
let tensor = FlexTensor::from_data(TensorData::new(vec![10u64, 20, 30, 5], [2, 2]));
let result = Flex::int_max_dim(tensor, 1);
let data: Vec<u64> = result.into_data().to_vec().unwrap();
assert_eq!(data, vec![20, 30]);
}
#[test]
fn test_min_dim_u16() {
let tensor = FlexTensor::from_data(TensorData::new(vec![10u16, 2, 30, 5], [2, 2]));
let result = Flex::int_min_dim(tensor, 1);
let data: Vec<u16> = result.into_data().to_vec().unwrap();
assert_eq!(data, vec![2, 5]);
}
#[test]
fn test_mean_dim_u8() {
let tensor = FlexTensor::from_data(TensorData::new(vec![10u8, 20, 30, 40], [2, 2]));
let result = Flex::int_mean_dim(tensor, 1);
let data: Vec<u8> = result.into_data().to_vec().unwrap();
assert_eq!(data, vec![15, 35]);
}
#[test]
fn test_max_dim_with_indices_u32() {
let tensor = FlexTensor::from_data(TensorData::new(vec![5u32, 10, 3, 8], [2, 2]));
let (values, indices) = Flex::int_max_dim_with_indices(tensor, 1);
let vals: Vec<u32> = values.into_data().to_vec().unwrap();
let idxs: Vec<isize> = bytemuck::cast_slice(&indices.into_data().bytes).to_vec();
assert_eq!(vals, vec![10, 8]);
assert_eq!(idxs, vec![1, 1]);
}
#[test]
fn test_argmax_scalar_and_simd_paths_agree_on_leading_nan() {
let short =
FlexTensor::from_data(TensorData::new(vec![f32::NAN, f32::NAN, f32::NAN], [1, 3]));
let short_idxs: Vec<isize> =
bytemuck::cast_slice(&super::argmax(short, 1).into_data().bytes).to_vec();
let mut long_data = alloc::vec![1.0f32; 600];
long_data[0] = f32::NAN;
long_data[1] = f32::NAN;
long_data[300] = 5.0;
let long = FlexTensor::from_data(TensorData::new(long_data, [1, 600]));
let long_idxs: Vec<isize> =
bytemuck::cast_slice(&super::argmax(long, 1).into_data().bytes).to_vec();
assert_eq!(short_idxs, vec![0], "scalar path");
assert_eq!(long_idxs, vec![0], "SIMD path");
}
}