use alloc::boxed::Box;
#[cfg(feature = "simd")]
use alloc::vec;
use alloc::vec::Vec;
use burn_backend::{DType, Element};
use burn_std::{Bytes, Shape, bf16, f16};
use bytemuck::Pod;
use crate::strided_index::StridedIter;
use crate::{FlexTensor, Layout};
#[cfg(feature = "simd")]
use crate::simd;
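/// Comparison operator hint passed alongside the comparison closures so the
/// `f32` path can select a vectorized kernel when the `simd` feature is on.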
#[derive(Clone, Copy)]
pub enum CompareOp {
Greater,
GreaterEqual,
Lower,
LowerEqual,
Equal,
NotEqual,
}
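/// Compares two float tensors element-wise after broadcasting, producing a
/// bool tensor of canonical `0`/`1` bytes. Half-precision inputs are widened
/// to `f32` for the comparison; `simd_hint` enables the SIMD fast path for
/// `f32` inputs. Panics on non-float dtypes.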
pub fn compare<F32Cmp, F64Cmp>(
lhs: FlexTensor,
rhs: FlexTensor,
f32_cmp: F32Cmp,
f64_cmp: F64Cmp,
simd_hint: Option<CompareOp>,
) -> FlexTensor
where
F32Cmp: Fn(f32, f32) -> bool + Copy,
F64Cmp: Fn(f64, f64) -> bool + Copy,
{
debug_assert_eq!(lhs.dtype(), rhs.dtype(), "compare: dtype mismatch");
let (lhs, rhs) = crate::ops::expand::broadcast_binary(lhs, rhs);
let dtype = lhs.dtype();
match dtype {
DType::F32 => compare_f32(lhs, &rhs, f32_cmp, simd_hint),
DType::F64 => compare_typed(lhs, &rhs, f64_cmp),
DType::F16 => compare_typed(lhs, &rhs, |a: f16, b: f16| f32_cmp(a.to_f32(), b.to_f32())),
DType::BF16 => compare_typed(lhs, &rhs, |a: bf16, b: bf16| {
f32_cmp(a.to_f32(), b.to_f32())
}),
_ => panic!("compare: unsupported dtype {:?}", dtype),
}
}
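/// `f32` comparison with SIMD fast paths: fully contiguous operands go
/// through the vectorized slice kernel, common 2-D broadcast layouts are
/// handled by `try_broadcast_cmp_f32`, and everything else falls back to the
/// scalar strided path.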
#[cfg(feature = "simd")]
fn compare_f32<Cmp>(
lhs: FlexTensor,
rhs: &FlexTensor,
cmp: Cmp,
simd_hint: Option<CompareOp>,
) -> FlexTensor
where
Cmp: Fn(f32, f32) -> bool,
{
if let (Some((l_start, l_end)), Some((r_start, r_end))) = (
lhs.layout().contiguous_offsets(),
rhs.layout().contiguous_offsets(),
) && let Some(simd_op) = simd_hint.map(compare_op_to_simd)
{
let shape = lhs.layout().shape().clone();
let lhs_storage: &[f32] = lhs.storage();
let rhs_storage: &[f32] = rhs.storage();
let l_slice = &lhs_storage[l_start..l_end];
let r_slice = &rhs_storage[r_start..r_end];
let mut result = vec![0u8; l_slice.len()];
simd::cmp_f32(l_slice, r_slice, &mut result, simd_op);
return make_bool_tensor(result, shape);
}
if lhs.layout().num_dims() == 2
&& let Some(simd_op) = simd_hint.map(compare_op_to_simd)
&& let Some((result, shape)) = try_broadcast_cmp_f32(&lhs, rhs, simd_op)
{
return make_bool_tensor(result, shape);
}
compare_typed(lhs, rhs, cmp)
}
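/// Recognizes three 2-D broadcast layouts and compares them with SIMD
/// kernels: a column-broadcast `lhs` (one scalar per row) against a
/// contiguous `rhs`, a contiguous `lhs` against a row-broadcast `rhs`, and a
/// column broadcast against a row broadcast. Returns `None` if no layout
/// matches.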
#[cfg(feature = "simd")]
fn try_broadcast_cmp_f32(
lhs: &FlexTensor,
rhs: &FlexTensor,
op: simd::CmpOp,
) -> Option<(Vec<u8>, Shape)> {
let lhs_strides = lhs.layout().strides();
let rhs_strides = rhs.layout().strides();
let shape = lhs.layout().shape().clone();
let [rows, cols] = shape[..] else {
return None;
};
if lhs_strides[1] == 0 && rhs_strides == [cols as isize, 1] {
let lhs_storage: &[f32] = lhs.storage();
let rhs_storage: &[f32] = rhs.storage();
let l_offset = lhs.layout().start_offset() as isize;
let l_stride = lhs_strides[0];
let r_offset = rhs.layout().start_offset();
let mut result = vec![0u8; rows * cols];
for row in 0..rows {
let a_val = lhs_storage[(l_offset + row as isize * l_stride) as usize];
let r_row_start = r_offset + row * cols;
let r_slice = &rhs_storage[r_row_start..r_row_start + cols];
let out_start = row * cols;
simd::cmp_scalar_f32(
r_slice,
a_val,
&mut result[out_start..out_start + cols],
swap_cmp_op(op),
);
}
return Some((result, shape));
}
if rhs_strides[0] == 0 && lhs_strides == [cols as isize, 1] {
let lhs_storage: &[f32] = lhs.storage();
let rhs_storage: &[f32] = rhs.storage();
let l_offset = lhs.layout().start_offset();
let r_offset = rhs.layout().start_offset() as isize;
let r_stride = rhs_strides[1];
let rhs_row: Vec<f32> = (0..cols)
.map(|j| rhs_storage[(r_offset + j as isize * r_stride) as usize])
.collect();
let mut result = vec![0u8; rows * cols];
for row in 0..rows {
let l_row_start = l_offset + row * cols;
let l_slice = &lhs_storage[l_row_start..l_row_start + cols];
let out_start = row * cols;
            // Both rows are contiguous `cols`-element slices, so the
            // vectorized slice kernel applies directly; `op` needs no swap
            // because `lhs` stays on the left.
            simd::cmp_f32(
                l_slice,
                &rhs_row,
                &mut result[out_start..out_start + cols],
                op,
            );
}
return Some((result, shape));
}
if lhs_strides[1] == 0 && rhs_strides[0] == 0 {
let lhs_storage: &[f32] = lhs.storage();
let rhs_storage: &[f32] = rhs.storage();
let l_offset = lhs.layout().start_offset() as isize;
let l_stride = lhs_strides[0];
let r_offset = rhs.layout().start_offset() as isize;
let r_stride = rhs_strides[1];
let rhs_row: Vec<f32> = (0..cols)
.map(|j| rhs_storage[(r_offset + j as isize * r_stride) as usize])
.collect();
let mut result = vec![0u8; rows * cols];
for row in 0..rows {
let a_val = lhs_storage[(l_offset + row as isize * l_stride) as usize];
let out_start = row * cols;
simd::cmp_scalar_f32(
&rhs_row,
a_val,
&mut result[out_start..out_start + cols],
swap_cmp_op(op),
);
}
return Some((result, shape));
}
None
}
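/// Mirrors a comparison so its operands can be swapped: `a OP b` is
/// equivalent to `b swap(OP) a`. Used where `cmp_scalar_f32` places the
/// slice on the left-hand side but the scalar is logically the `lhs`.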
#[cfg(feature = "simd")]
fn swap_cmp_op(op: simd::CmpOp) -> simd::CmpOp {
match op {
        simd::CmpOp::Gt => simd::CmpOp::Lt,
        simd::CmpOp::Ge => simd::CmpOp::Le,
        simd::CmpOp::Lt => simd::CmpOp::Gt,
        simd::CmpOp::Le => simd::CmpOp::Ge,
        simd::CmpOp::Eq => simd::CmpOp::Eq,
        simd::CmpOp::Ne => simd::CmpOp::Ne,
}
}
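/// Scalar fallback for `f32` tensor comparison when the `simd` feature is
/// disabled.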
#[cfg(not(feature = "simd"))]
fn compare_f32<Cmp>(
lhs: FlexTensor,
rhs: &FlexTensor,
cmp: Cmp,
_simd_hint: Option<CompareOp>,
) -> FlexTensor
where
Cmp: Fn(f32, f32) -> bool,
{
compare_typed(lhs, rhs, cmp)
}
#[cfg(feature = "simd")]
fn compare_op_to_simd(op: CompareOp) -> simd::CmpOp {
match op {
CompareOp::Greater => simd::CmpOp::Gt,
CompareOp::GreaterEqual => simd::CmpOp::Ge,
CompareOp::Lower => simd::CmpOp::Lt,
CompareOp::LowerEqual => simd::CmpOp::Le,
CompareOp::Equal => simd::CmpOp::Eq,
CompareOp::NotEqual => simd::CmpOp::Ne,
}
}
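/// Compares each element of a float tensor against a scalar, producing a
/// bool tensor. Half-precision inputs are widened to `f32`.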
pub fn compare_elem<F32Cmp, F64Cmp>(
lhs: FlexTensor,
rhs: f64,
f32_cmp: F32Cmp,
f64_cmp: F64Cmp,
simd_hint: Option<CompareOp>,
) -> FlexTensor
where
F32Cmp: Fn(f32, f32) -> bool + Copy,
F64Cmp: Fn(f64, f64) -> bool + Copy,
{
let dtype = lhs.dtype();
match dtype {
DType::F32 => compare_elem_f32(lhs, rhs as f32, f32_cmp, simd_hint),
DType::F64 => compare_elem_typed(lhs, rhs, f64_cmp),
DType::F16 => {
let scalar = f16::from_f64(rhs);
compare_elem_typed(lhs, scalar, |a: f16, b: f16| {
f32_cmp(a.to_f32(), b.to_f32())
})
}
DType::BF16 => {
let scalar = bf16::from_f64(rhs);
compare_elem_typed(lhs, scalar, |a: bf16, b: bf16| {
f32_cmp(a.to_f32(), b.to_f32())
})
}
_ => panic!("compare_elem: unsupported dtype {:?}", dtype),
}
}
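/// `f32`-vs-scalar comparison with a SIMD fast path for contiguous layouts.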
#[cfg(feature = "simd")]
fn compare_elem_f32<Cmp>(
lhs: FlexTensor,
rhs: f32,
cmp: Cmp,
simd_hint: Option<CompareOp>,
) -> FlexTensor
where
Cmp: Fn(f32, f32) -> bool,
{
if let Some((start, end)) = lhs.layout().contiguous_offsets()
&& let Some(simd_op) = simd_hint.map(compare_op_to_simd)
{
let shape = lhs.layout().shape().clone();
let lhs_storage: &[f32] = lhs.storage();
let l_slice = &lhs_storage[start..end];
let mut result = vec![0u8; l_slice.len()];
simd::cmp_scalar_f32(l_slice, rhs, &mut result, simd_op);
return make_bool_tensor(result, shape);
}
compare_elem_typed(lhs, rhs, cmp)
}
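/// `f32`-vs-scalar comparison when the `simd` feature is disabled.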
#[cfg(not(feature = "simd"))]
fn compare_elem_f32<Cmp>(
lhs: FlexTensor,
rhs: f32,
cmp: Cmp,
_simd_hint: Option<CompareOp>,
) -> FlexTensor
where
Cmp: Fn(f32, f32) -> bool,
{
compare_elem_typed(lhs, rhs, cmp)
}
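/// Scalar tensor-vs-tensor comparison: contiguous operands are zipped
/// directly, 2-D strided operands go through `apply_2d_strided`, and
/// arbitrary layouts fall back to `StridedIter`.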
fn compare_typed<E, Cmp>(lhs: FlexTensor, rhs: &FlexTensor, cmp: Cmp) -> FlexTensor
where
E: Element + Pod,
Cmp: Fn(E, E) -> bool,
{
let shape = lhs.layout().shape().clone();
let lhs_storage: &[E] = lhs.storage();
let rhs_storage: &[E] = rhs.storage();
let result: Vec<u8> = match (
lhs.layout().contiguous_offsets(),
rhs.layout().contiguous_offsets(),
) {
(Some((l_start, l_end)), Some((r_start, r_end))) => {
let l_slice = &lhs_storage[l_start..l_end];
let r_slice = &rhs_storage[r_start..r_end];
l_slice
.iter()
.zip(r_slice)
.map(|(&a, &b)| cmp(a, b) as u8)
.collect()
}
_ if lhs.layout().num_dims() == 2 => crate::ops::binary::apply_2d_strided(
lhs_storage,
rhs_storage,
lhs.layout(),
rhs.layout(),
|a, b| cmp(a, b) as u8,
),
_ => {
let lhs_iter = StridedIter::new(lhs.layout());
let rhs_iter = StridedIter::new(rhs.layout());
lhs_iter
.zip(rhs_iter)
.map(|(li, ri)| cmp(lhs_storage[li], rhs_storage[ri]) as u8)
.collect()
}
};
make_bool_tensor(result, shape)
}
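/// Scalar tensor-vs-scalar comparison over any strided layout.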
fn compare_elem_typed<E, Cmp>(lhs: FlexTensor, rhs: E, cmp: Cmp) -> FlexTensor
where
    E: Element + Pod,
Cmp: Fn(E, E) -> bool,
{
let shape = lhs.layout().shape().clone();
let lhs_storage: &[E] = lhs.storage();
let result: Vec<u8> = match lhs.layout().contiguous_offsets() {
Some((start, end)) => lhs_storage[start..end]
.iter()
.map(|&a| cmp(a, rhs) as u8)
.collect(),
None => StridedIter::new(lhs.layout())
.map(|idx| cmp(lhs_storage[idx], rhs) as u8)
.collect(),
};
make_bool_tensor(result, shape)
}
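/// Wraps a buffer of `0`/`1` bytes as a contiguous native bool tensor.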
fn make_bool_tensor(data: Vec<u8>, shape: Shape) -> FlexTensor {
let bytes = Bytes::from_elems(data);
FlexTensor::new(
bytes,
Layout::contiguous(shape),
DType::Bool(burn_std::BoolStore::Native),
)
}
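// Public float comparison entry points. Each passes a `CompareOp` hint so
// `f32` inputs can take the SIMD path.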
pub fn greater(lhs: FlexTensor, rhs: FlexTensor) -> FlexTensor {
compare(
lhs,
rhs,
|a, b| a > b,
|a, b| a > b,
Some(CompareOp::Greater),
)
}
pub fn greater_elem(lhs: FlexTensor, rhs: f64) -> FlexTensor {
compare_elem(
lhs,
rhs,
|a, b| a > b,
|a, b| a > b,
Some(CompareOp::Greater),
)
}
pub fn greater_equal(lhs: FlexTensor, rhs: FlexTensor) -> FlexTensor {
compare(
lhs,
rhs,
|a, b| a >= b,
|a, b| a >= b,
Some(CompareOp::GreaterEqual),
)
}
pub fn greater_equal_elem(lhs: FlexTensor, rhs: f64) -> FlexTensor {
compare_elem(
lhs,
rhs,
|a, b| a >= b,
|a, b| a >= b,
Some(CompareOp::GreaterEqual),
)
}
pub fn lower(lhs: FlexTensor, rhs: FlexTensor) -> FlexTensor {
compare(lhs, rhs, |a, b| a < b, |a, b| a < b, Some(CompareOp::Lower))
}
pub fn lower_elem(lhs: FlexTensor, rhs: f64) -> FlexTensor {
compare_elem(lhs, rhs, |a, b| a < b, |a, b| a < b, Some(CompareOp::Lower))
}
pub fn lower_equal(lhs: FlexTensor, rhs: FlexTensor) -> FlexTensor {
compare(
lhs,
rhs,
|a, b| a <= b,
|a, b| a <= b,
Some(CompareOp::LowerEqual),
)
}
pub fn lower_equal_elem(lhs: FlexTensor, rhs: f64) -> FlexTensor {
compare_elem(
lhs,
rhs,
|a, b| a <= b,
|a, b| a <= b,
Some(CompareOp::LowerEqual),
)
}
pub fn equal(lhs: FlexTensor, rhs: FlexTensor) -> FlexTensor {
compare(
lhs,
rhs,
|a, b| a == b,
|a, b| a == b,
Some(CompareOp::Equal),
)
}
pub fn equal_elem(lhs: FlexTensor, rhs: f64) -> FlexTensor {
compare_elem(
lhs,
rhs,
|a, b| a == b,
|a, b| a == b,
Some(CompareOp::Equal),
)
}
pub fn not_equal(lhs: FlexTensor, rhs: FlexTensor) -> FlexTensor {
compare(
lhs,
rhs,
|a, b| a != b,
|a, b| a != b,
Some(CompareOp::NotEqual),
)
}
pub fn not_equal_elem(lhs: FlexTensor, rhs: f64) -> FlexTensor {
compare_elem(
lhs,
rhs,
|a, b| a != b,
|a, b| a != b,
Some(CompareOp::NotEqual),
)
}
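/// Compares two `i64` tensors element-wise after broadcasting.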
fn compare_int<Cmp>(lhs: FlexTensor, rhs: FlexTensor, cmp: Cmp) -> FlexTensor
where
Cmp: Fn(i64, i64) -> bool,
{
debug_assert_eq!(lhs.dtype(), DType::I64, "compare_int: expected I64 dtype");
debug_assert_eq!(rhs.dtype(), DType::I64, "compare_int: expected I64 dtype");
let (lhs, rhs) = crate::ops::expand::broadcast_binary(lhs, rhs);
let shape = lhs.layout().shape().clone();
let lhs_storage: &[i64] = lhs.storage();
let rhs_storage: &[i64] = rhs.storage();
let result: Vec<u8> = match (
lhs.layout().contiguous_offsets(),
rhs.layout().contiguous_offsets(),
) {
(Some((l_start, l_end)), Some((r_start, r_end))) => {
let l_slice = &lhs_storage[l_start..l_end];
let r_slice = &rhs_storage[r_start..r_end];
l_slice
.iter()
.zip(r_slice)
.map(|(&a, &b)| cmp(a, b) as u8)
.collect()
}
_ => {
let lhs_iter = StridedIter::new(lhs.layout());
let rhs_iter = StridedIter::new(rhs.layout());
lhs_iter
.zip(rhs_iter)
.map(|(li, ri)| cmp(lhs_storage[li], rhs_storage[ri]) as u8)
.collect()
}
};
make_bool_tensor(result, shape)
}
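/// Compares each element of an `i64` tensor against a scalar.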
fn compare_int_elem<Cmp>(lhs: FlexTensor, rhs: i64, cmp: Cmp) -> FlexTensor
where
Cmp: Fn(i64, i64) -> bool,
{
debug_assert_eq!(
lhs.dtype(),
DType::I64,
"compare_int_elem: expected I64 dtype"
);
let shape = lhs.layout().shape().clone();
let lhs_storage: &[i64] = lhs.storage();
let result: Vec<u8> = match lhs.layout().contiguous_offsets() {
Some((start, end)) => lhs_storage[start..end]
.iter()
.map(|&a| cmp(a, rhs) as u8)
.collect(),
None => StridedIter::new(lhs.layout())
.map(|idx| cmp(lhs_storage[idx], rhs) as u8)
.collect(),
};
make_bool_tensor(result, shape)
}
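// Public integer (`i64`) comparison entry points.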
pub fn int_greater(lhs: FlexTensor, rhs: FlexTensor) -> FlexTensor {
compare_int(lhs, rhs, |a, b| a > b)
}
pub fn int_greater_elem(lhs: FlexTensor, rhs: i64) -> FlexTensor {
compare_int_elem(lhs, rhs, |a, b| a > b)
}
pub fn int_greater_equal(lhs: FlexTensor, rhs: FlexTensor) -> FlexTensor {
compare_int(lhs, rhs, |a, b| a >= b)
}
pub fn int_greater_equal_elem(lhs: FlexTensor, rhs: i64) -> FlexTensor {
compare_int_elem(lhs, rhs, |a, b| a >= b)
}
pub fn int_lower(lhs: FlexTensor, rhs: FlexTensor) -> FlexTensor {
compare_int(lhs, rhs, |a, b| a < b)
}
pub fn int_lower_elem(lhs: FlexTensor, rhs: i64) -> FlexTensor {
compare_int_elem(lhs, rhs, |a, b| a < b)
}
pub fn int_lower_equal(lhs: FlexTensor, rhs: FlexTensor) -> FlexTensor {
compare_int(lhs, rhs, |a, b| a <= b)
}
pub fn int_lower_equal_elem(lhs: FlexTensor, rhs: i64) -> FlexTensor {
compare_int_elem(lhs, rhs, |a, b| a <= b)
}
pub fn int_equal(lhs: FlexTensor, rhs: FlexTensor) -> FlexTensor {
compare_int(lhs, rhs, |a, b| a == b)
}
pub fn int_equal_elem(lhs: FlexTensor, rhs: i64) -> FlexTensor {
compare_int_elem(lhs, rhs, |a, b| a == b)
}
pub fn int_not_equal(lhs: FlexTensor, rhs: FlexTensor) -> FlexTensor {
compare_int(lhs, rhs, |a, b| a != b)
}
pub fn int_not_equal_elem(lhs: FlexTensor, rhs: i64) -> FlexTensor {
compare_int_elem(lhs, rhs, |a, b| a != b)
}
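/// Element-wise inequality of two bool tensors. Assumes bool storage holds
/// canonical `0`/`1` bytes (as produced by the ops in this module), so raw
/// bytes can be compared directly.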
pub fn bool_not_equal(lhs: FlexTensor, rhs: FlexTensor) -> FlexTensor {
let (lhs, rhs) = crate::ops::expand::broadcast_binary(lhs, rhs);
let shape = lhs.layout().shape().clone();
let lhs_data: &[u8] = lhs.bytes();
let rhs_data: &[u8] = rhs.bytes();
let result: Vec<u8> = match (
lhs.layout().contiguous_offsets(),
rhs.layout().contiguous_offsets(),
) {
(Some((ls, le)), Some((rs, re))) => lhs_data[ls..le]
.iter()
.zip(&rhs_data[rs..re])
.map(|(&a, &b)| if a != b { 1 } else { 0 })
.collect(),
_ => {
let lhs = lhs.to_contiguous();
let rhs = rhs.to_contiguous();
lhs.bytes()
.iter()
.zip(rhs.bytes())
.map(|(&a, &b)| if a != b { 1 } else { 0 })
.collect()
}
};
FlexTensor::new(
Bytes::from_elems(result),
Layout::contiguous(shape),
DType::Bool(burn_std::BoolStore::Native),
)
}
pub fn bool_not_equal_elem(lhs: FlexTensor, rhs: bool) -> FlexTensor {
let rhs_val: u8 = if rhs { 1 } else { 0 };
let shape = lhs.layout().shape().clone();
let lhs = lhs.to_contiguous();
let data: &[u8] = lhs.bytes();
let result: Vec<u8> = data
.iter()
.map(|&a| if a != rhs_val { 1 } else { 0 })
.collect();
FlexTensor::new(
Bytes::from_elems(result),
Layout::contiguous(shape),
DType::Bool(burn_std::BoolStore::Native),
)
}
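// Reductions to bool: `any`/`all` over the whole tensor or along one dim.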
pub fn any_float(tensor: FlexTensor) -> FlexTensor {
let has_any = match tensor.dtype() {
DType::F32 => iter_elements::<f32>(&tensor).any(|x| x != 0.0),
DType::F64 => iter_elements::<f64>(&tensor).any(|x| x != 0.0),
DType::F16 => iter_elements::<f16>(&tensor).any(|x: f16| x.to_f32() != 0.0),
DType::BF16 => iter_elements::<bf16>(&tensor).any(|x: bf16| x.to_f32() != 0.0),
_ => panic!("any_float: unsupported dtype {:?}", tensor.dtype()),
};
bool_scalar(has_any)
}
pub fn any_float_dim(tensor: FlexTensor, dim: usize) -> FlexTensor {
reduce_bool_dim(&tensor, dim, false, |a, b| a || b)
}
pub fn all_float(tensor: FlexTensor) -> FlexTensor {
let all = match tensor.dtype() {
DType::F32 => iter_elements::<f32>(&tensor).all(|x| x != 0.0),
DType::F64 => iter_elements::<f64>(&tensor).all(|x| x != 0.0),
DType::F16 => iter_elements::<f16>(&tensor).all(|x: f16| x.to_f32() != 0.0),
DType::BF16 => iter_elements::<bf16>(&tensor).all(|x: bf16| x.to_f32() != 0.0),
_ => panic!("all_float: unsupported dtype {:?}", tensor.dtype()),
};
bool_scalar(all)
}
pub fn all_float_dim(tensor: FlexTensor, dim: usize) -> FlexTensor {
reduce_bool_dim(&tensor, dim, true, |a, b| a && b)
}
pub fn any_int(tensor: FlexTensor) -> FlexTensor {
let has_any = match tensor.dtype() {
DType::I64 => iter_elements::<i64>(&tensor).any(|x| x != 0),
DType::I32 => iter_elements::<i32>(&tensor).any(|x| x != 0),
_ => panic!("any_int: unsupported dtype {:?}", tensor.dtype()),
};
bool_scalar(has_any)
}
pub fn any_int_dim(tensor: FlexTensor, dim: usize) -> FlexTensor {
reduce_bool_dim_int(&tensor, dim, false, |a, b| a || b)
}
pub fn all_int(tensor: FlexTensor) -> FlexTensor {
let all = match tensor.dtype() {
DType::I64 => iter_elements::<i64>(&tensor).all(|x| x != 0),
DType::I32 => iter_elements::<i32>(&tensor).all(|x| x != 0),
_ => panic!("all_int: unsupported dtype {:?}", tensor.dtype()),
};
bool_scalar(all)
}
pub fn all_int_dim(tensor: FlexTensor, dim: usize) -> FlexTensor {
reduce_bool_dim_int(&tensor, dim, true, |a, b| a && b)
}
pub fn any_bool(tensor: FlexTensor) -> FlexTensor {
let tensor = tensor.to_contiguous();
let data: &[u8] = tensor.bytes();
bool_scalar(data.iter().any(|&x| x != 0))
}
pub fn any_bool_dim(tensor: FlexTensor, dim: usize) -> FlexTensor {
reduce_bool_dim_raw(&tensor, dim, false, |a, b| a || b)
}
pub fn all_bool(tensor: FlexTensor) -> FlexTensor {
let tensor = tensor.to_contiguous();
let data: &[u8] = tensor.bytes();
bool_scalar(data.iter().all(|&x| x != 0))
}
pub fn all_bool_dim(tensor: FlexTensor, dim: usize) -> FlexTensor {
reduce_bool_dim_raw(&tensor, dim, true, |a, b| a && b)
}
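/// Builds a single-element bool tensor holding `val`.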
fn bool_scalar(val: bool) -> FlexTensor {
let byte: u8 = if val { 1 } else { 0 };
FlexTensor::new(
Bytes::from_elems(alloc::vec![byte]),
Layout::contiguous(Shape::from(alloc::vec![1])),
DType::Bool(burn_std::BoolStore::Native),
)
}
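/// Iterates a tensor's elements in logical order, taking the direct slice
/// path for contiguous layouts.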
fn iter_elements<'a, E: Element + Pod + 'a>(
tensor: &'a FlexTensor,
) -> Box<dyn Iterator<Item = E> + 'a> {
let data: &[E] = tensor.storage();
match tensor.layout().contiguous_offsets() {
Some((start, end)) => Box::new(data[start..end].iter().copied()),
None => Box::new(StridedIter::new(tensor.layout()).map(move |idx| data[idx])),
}
}
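/// Reduces `dim` to size 1 with a boolean combine, reading truthiness
/// through `is_nonzero`. Expects a contiguous tensor starting at offset 0.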
fn reduce_bool_dim_with(
tensor: &FlexTensor,
dim: usize,
init: bool,
combine: fn(bool, bool) -> bool,
is_nonzero: impl Fn(usize) -> bool,
) -> FlexTensor {
debug_assert!(tensor.is_contiguous() && tensor.layout().start_offset() == 0);
let shape = tensor.layout().shape();
let ndims = shape.num_dims();
assert!(dim < ndims);
let dim_size = shape[dim];
let mut out_shape: Vec<usize> = shape.to_vec();
out_shape[dim] = 1;
let outer_size: usize = shape[..dim].iter().product();
let inner_size: usize = shape[dim + 1..].iter().product();
let out_size = outer_size.max(1) * inner_size.max(1);
let mut result: Vec<u8> = Vec::with_capacity(out_size);
for outer in 0..outer_size.max(1) {
for inner in 0..inner_size.max(1) {
let mut acc = init;
for d in 0..dim_size {
let idx = outer * dim_size * inner_size + d * inner_size + inner;
acc = combine(acc, is_nonzero(idx));
}
result.push(if acc { 1 } else { 0 });
}
}
FlexTensor::new(
Bytes::from_elems(result),
Layout::contiguous(Shape::from(out_shape)),
DType::Bool(burn_std::BoolStore::Native),
)
}
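/// Boolean dim-reduction dispatched over float dtypes.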
fn reduce_bool_dim(
tensor: &FlexTensor,
dim: usize,
init: bool,
combine: fn(bool, bool) -> bool,
) -> FlexTensor {
let tensor = tensor.to_contiguous();
match tensor.dtype() {
DType::F32 => {
let data: &[f32] = tensor.storage();
reduce_bool_dim_with(&tensor, dim, init, combine, |idx| data[idx] != 0.0)
}
DType::F64 => {
let data: &[f64] = tensor.storage();
reduce_bool_dim_with(&tensor, dim, init, combine, |idx| data[idx] != 0.0)
}
DType::F16 => {
let data: &[f16] = tensor.storage();
reduce_bool_dim_with(&tensor, dim, init, combine, |idx| data[idx].to_f32() != 0.0)
}
DType::BF16 => {
let data: &[bf16] = tensor.storage();
reduce_bool_dim_with(&tensor, dim, init, combine, |idx| data[idx].to_f32() != 0.0)
}
_ => panic!("reduce_bool_dim: unsupported dtype {:?}", tensor.dtype()),
}
}
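/// Boolean dim-reduction dispatched over integer dtypes.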
fn reduce_bool_dim_int(
    tensor: &FlexTensor,
    dim: usize,
    init: bool,
    combine: fn(bool, bool) -> bool,
) -> FlexTensor {
    let tensor = tensor.to_contiguous();
    // Match on dtype like the float dispatcher above: `any_int`/`all_int`
    // accept both I64 and I32, so reading the storage unconditionally as
    // `i64` would misinterpret I32 tensors.
    match tensor.dtype() {
        DType::I64 => {
            let data: &[i64] = tensor.storage();
            reduce_bool_dim_with(&tensor, dim, init, combine, |idx| data[idx] != 0)
        }
        DType::I32 => {
            let data: &[i32] = tensor.storage();
            reduce_bool_dim_with(&tensor, dim, init, combine, |idx| data[idx] != 0)
        }
        _ => panic!("reduce_bool_dim_int: unsupported dtype {:?}", tensor.dtype()),
    }
}
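/// Boolean dim-reduction over raw bool bytes.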
fn reduce_bool_dim_raw(
tensor: &FlexTensor,
dim: usize,
init: bool,
combine: fn(bool, bool) -> bool,
) -> FlexTensor {
let tensor = tensor.to_contiguous();
let data: &[u8] = tensor.bytes();
reduce_bool_dim_with(&tensor, dim, init, combine, |idx| data[idx] != 0)
}
#[cfg(test)]
mod tests {
use super::*;
use burn_backend::TensorData;
#[test]
fn test_greater() {
let lhs = FlexTensor::from_data(TensorData::new(vec![1.0f32, 2.0, 3.0], [3]));
let rhs = FlexTensor::from_data(TensorData::new(vec![2.0f32, 2.0, 1.0], [3]));
let result = greater(lhs, rhs);
let data: &[u8] = result.bytes();
        assert_eq!(data, &[0, 0, 1]);
    }
#[test]
fn test_greater_elem() {
let lhs = FlexTensor::from_data(TensorData::new(vec![1.0f32, 2.0, 3.0], [3]));
let result = greater_elem(lhs, 2.0);
let data: &[u8] = result.bytes();
        assert_eq!(data, &[0, 0, 1]);
    }
#[test]
fn test_equal() {
let lhs = FlexTensor::from_data(TensorData::new(vec![1.0f32, 2.0, 3.0], [3]));
let rhs = FlexTensor::from_data(TensorData::new(vec![1.0f32, 3.0, 3.0], [3]));
let result = equal(lhs, rhs);
let data: &[u8] = result.bytes();
        assert_eq!(data, &[1, 0, 1]);
    }
fn tensor_2d(data: Vec<f32>, rows: usize, cols: usize) -> FlexTensor {
FlexTensor::from_data(TensorData::new(data, vec![rows, cols]))
}
#[test]
fn test_greater_transposed() {
let lhs = tensor_2d(vec![1.0, 2.0, 3.0, 4.0], 2, 2);
let lhs = lhs.transpose(0, 1);
assert!(!lhs.is_contiguous());
let rhs = tensor_2d(vec![2.0, 2.0, 2.0, 2.0], 2, 2);
let result = greater(lhs, rhs);
let data: &[u8] = result.bytes();
assert_eq!(data, &[0, 1, 0, 1]);
}
#[test]
fn test_equal_flipped_1d() {
let lhs = FlexTensor::from_data(TensorData::new(vec![1.0f32, 2.0, 3.0, 4.0], [4]));
let lhs = crate::ops::flip::flip(lhs, &[0]);
assert!(lhs.layout().strides()[0] < 0);
let rhs = FlexTensor::from_data(TensorData::new(vec![4.0f32, 2.0, 2.0, 1.0], [4]));
let result = equal(lhs, rhs);
let data: &[u8] = result.bytes();
assert_eq!(data, &[1, 0, 1, 1]);
}
#[test]
fn test_lower_flipped_2d() {
let lhs = tensor_2d(vec![1.0, 2.0, 3.0, 4.0], 2, 2);
let lhs = crate::ops::flip::flip(lhs, &[0]);
assert!(lhs.layout().strides()[0] < 0);
let rhs = tensor_2d(vec![2.0, 5.0, 2.0, 1.0], 2, 2);
let result = lower(lhs, rhs);
let data: &[u8] = result.bytes();
assert_eq!(data, &[0, 1, 1, 0]);
}
#[test]
fn test_greater_elem_flipped() {
let lhs = FlexTensor::from_data(TensorData::new(vec![1.0f32, 2.0, 3.0, 4.0], [4]));
let lhs = crate::ops::flip::flip(lhs, &[0]);
assert!(lhs.layout().strides()[0] < 0);
let result = greater_elem(lhs, 2.5);
let data: &[u8] = result.bytes();
assert_eq!(data, &[1, 1, 0, 0]);
}
#[test]
fn test_equal_both_transposed() {
let lhs = tensor_2d(vec![1.0, 2.0, 3.0, 4.0], 2, 2).transpose(0, 1);
let rhs = tensor_2d(vec![1.0, 3.0, 2.0, 4.0], 2, 2).transpose(0, 1);
assert!(!lhs.is_contiguous());
assert!(!rhs.is_contiguous());
let result = equal(lhs, rhs);
let data: &[u8] = result.bytes();
assert_eq!(data, &[1, 0, 0, 1]);
}
#[test]
fn test_not_equal_narrowed() {
let lhs =
FlexTensor::from_data(TensorData::new(vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0], [6]));
let lhs = lhs.narrow(0, 1, 4);
let rhs = FlexTensor::from_data(TensorData::new(vec![2.0f32, 2.0, 4.0, 4.0], [4]));
let result = not_equal(lhs, rhs);
let data: &[u8] = result.bytes();
assert_eq!(data, &[0, 1, 0, 1]);
}
#[test]
fn test_int_greater_flipped() {
let lhs = FlexTensor::from_data(TensorData::new(vec![1i64, 2, 3, 4], [4]));
let lhs = crate::ops::flip::flip(lhs, &[0]);
assert!(lhs.layout().strides()[0] < 0);
let rhs = FlexTensor::from_data(TensorData::new(vec![3i64, 3, 3, 3], [4]));
let result = int_greater(lhs, rhs);
let data: &[u8] = result.bytes();
assert_eq!(data, &[1, 0, 0, 0]);
}
#[test]
fn test_lower_flipped_both_axes() {
let lhs = tensor_2d(vec![1.0, 2.0, 3.0, 4.0], 2, 2);
let lhs = crate::ops::flip::flip(lhs, &[0, 1]);
assert!(lhs.layout().strides()[0] < 0);
assert!(lhs.layout().strides()[1] < 0);
let rhs = tensor_2d(vec![3.0, 3.0, 3.0, 3.0], 2, 2);
let result = lower(lhs, rhs);
let data: &[u8] = result.bytes();
assert_eq!(data, &[0, 0, 1, 1]);
}
#[test]
fn test_any_float_dim_transposed() {
let tensor = tensor_2d(vec![0.0, 1.0, 0.0, 0.0], 2, 2);
let transposed = tensor.transpose(0, 1);
assert!(!transposed.is_contiguous());
let result = any_float_dim(transposed, 1);
let data: &[u8] = result.bytes();
        assert_eq!(data, &[0, 1]);
    }
#[test]
fn test_any_float_dim_narrowed() {
let tensor = FlexTensor::from_data(TensorData::new(
vec![0.0f32, 5.0, 0.0, 3.0, 0.0, 0.0],
[3, 2],
));
        let narrowed = tensor.narrow(0, 0, 2);
        let result = any_float_dim(narrowed, 1);
let data: &[u8] = result.bytes();
assert_eq!(data, &[1, 1]);
}
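    // A few extra checks sketched against the paths above. The first one
    // assumes `broadcast_binary` expands a [2, 1] operand against a [2, 2]
    // operand into the column-broadcast layout the SIMD fast path handles.
    #[test]
    fn test_greater_broadcast_col() {
        let lhs = FlexTensor::from_data(TensorData::new(vec![1.0f32, 4.0], [2, 1]));
        let rhs = tensor_2d(vec![2.0, 0.5, 3.0, 5.0], 2, 2);
        let result = greater(lhs, rhs);
        let data: &[u8] = result.bytes();
        // Row 0 compares 1.0 to [2.0, 0.5]; row 1 compares 4.0 to [3.0, 5.0].
        assert_eq!(data, &[0, 1, 1, 0]);
    }
    #[test]
    fn test_bool_not_equal() {
        // Inputs come from comparisons, so they hold canonical 0/1 bytes.
        let lhs = FlexTensor::from_data(TensorData::new(vec![1.0f32, 2.0, 3.0, 4.0], [4]));
        let rhs = FlexTensor::from_data(TensorData::new(vec![3.0f32, 2.0, 1.0, 4.0], [4]));
        let result = bool_not_equal(greater_elem(lhs, 2.0), greater_elem(rhs, 2.0));
        let data: &[u8] = result.bytes();
        // [0, 0, 1, 1] vs [1, 0, 0, 1].
        assert_eq!(data, &[1, 0, 1, 0]);
    }
    #[test]
    fn test_any_int_dim() {
        let tensor = FlexTensor::from_data(TensorData::new(vec![0i64, 0, 5, 0], [2, 2]));
        let result = any_int_dim(tensor, 1);
        let data: &[u8] = result.bytes();
        assert_eq!(data, &[0, 1]);
    }
    #[test]
    fn test_all_float_dim() {
        let tensor = tensor_2d(vec![1.0, 2.0, 0.0, 3.0], 2, 2);
        let result = all_float_dim(tensor, 1);
        let data: &[u8] = result.bytes();
        assert_eq!(data, &[1, 0]);
    }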
}