use alloc::vec::Vec;
use burn_backend::Element;
use burn_std::{Bytes, bf16, f16};
use crate::{FlexTensor, Layout};
#[cfg(feature = "simd")]
#[inline]
/// Returns a `Vec<T>` of length `len` whose contents are **uninitialized**.
///
/// Used as the output buffer for the SIMD kernels below, which overwrite
/// every element; skipping the zero-fill avoids a redundant memset.
///
/// Contract: the caller must write all `len` elements before reading any
/// of them — reading uninitialized memory is undefined behavior.
fn uninit_vec<T: Copy>(len: usize) -> Vec<T> {
    let mut v = Vec::with_capacity(len);
    // Intentional: we deliberately expose uninitialized capacity for speed;
    // every caller immediately fills the whole buffer via a SIMD kernel.
    #[allow(clippy::uninit_vec)]
    unsafe {
        // SAFETY: capacity is exactly `len`, and `T: Copy` guarantees no
        // `Drop` glue runs over the uninitialized elements. Per the contract
        // above, callers initialize all elements before any read.
        v.set_len(len);
    }
    v
}
/// Scalar mask-fill: every element whose corresponding `mask` byte is
/// non-zero is replaced by `value`; the operands are first broadcast to a
/// common shape and made contiguous. The output keeps `tensor`'s dtype.
pub fn mask_fill<T>(tensor: FlexTensor, mask: FlexTensor, value: T) -> FlexTensor
where
    T: Element + bytemuck::Pod + Copy,
{
    let dtype = tensor.dtype();
    let (tensor, mask) = crate::ops::expand::broadcast_binary(tensor, mask);
    let tensor = tensor.to_contiguous();
    let mask = mask.to_contiguous();
    let shape = tensor.layout().shape().clone();

    let elems: &[T] = tensor.storage();
    let flags: &[u8] = mask.bytes();

    let mut filled: Vec<T> = Vec::with_capacity(elems.len());
    for (&elem, &flag) in elems.iter().zip(flags) {
        filled.push(if flag == 0 { elem } else { value });
    }

    FlexTensor::new(Bytes::from_elems(filled), Layout::contiguous(shape), dtype)
}
/// `mask_fill` specialized for `f32`: dispatches to the SIMD kernel when
/// the `simd` feature is enabled, otherwise falls back to the scalar path.
pub fn mask_fill_f32(tensor: FlexTensor, mask: FlexTensor, value: f32) -> FlexTensor {
    #[cfg(feature = "simd")]
    {
        let dtype = tensor.dtype();
        let (lhs, msk) = crate::ops::expand::broadcast_binary(tensor, mask);
        let lhs = lhs.to_contiguous();
        let msk = msk.to_contiguous();
        let shape = lhs.layout().shape().clone();
        // Buffer is fully overwritten by the kernel; see `uninit_vec`.
        let mut out = uninit_vec::<f32>(lhs.storage::<f32>().len());
        crate::simd::mask_fill_f32(lhs.storage(), msk.bytes(), value, &mut out);
        FlexTensor::new(Bytes::from_elems(out), Layout::contiguous(shape), dtype)
    }
    #[cfg(not(feature = "simd"))]
    {
        mask_fill(tensor, mask, value)
    }
}
/// `mask_fill` specialized for `f64`: dispatches to the SIMD kernel when
/// the `simd` feature is enabled, otherwise falls back to the scalar path.
pub fn mask_fill_f64(tensor: FlexTensor, mask: FlexTensor, value: f64) -> FlexTensor {
    #[cfg(feature = "simd")]
    {
        let dtype = tensor.dtype();
        let (lhs, msk) = crate::ops::expand::broadcast_binary(tensor, mask);
        let lhs = lhs.to_contiguous();
        let msk = msk.to_contiguous();
        let shape = lhs.layout().shape().clone();
        // Buffer is fully overwritten by the kernel; see `uninit_vec`.
        let mut out = uninit_vec::<f64>(lhs.storage::<f64>().len());
        crate::simd::mask_fill_f64(lhs.storage(), msk.bytes(), value, &mut out);
        FlexTensor::new(Bytes::from_elems(out), Layout::contiguous(shape), dtype)
    }
    #[cfg(not(feature = "simd"))]
    {
        mask_fill(tensor, mask, value)
    }
}
/// `mask_fill` for `f16` tensors; always takes the scalar path (this file
/// defines no `f16` SIMD dispatch).
pub fn mask_fill_f16(tensor: FlexTensor, mask: FlexTensor, value: f16) -> FlexTensor {
    mask_fill::<f16>(tensor, mask, value)
}
/// `mask_fill` for `bf16` tensors; always takes the scalar path (this file
/// defines no `bf16` SIMD dispatch).
pub fn mask_fill_bf16(tensor: FlexTensor, mask: FlexTensor, value: bf16) -> FlexTensor {
    mask_fill::<bf16>(tensor, mask, value)
}
/// `mask_fill` specialized for `i64`: dispatches to the SIMD kernel when
/// the `simd` feature is enabled, otherwise falls back to the scalar path.
pub fn mask_fill_i64(tensor: FlexTensor, mask: FlexTensor, value: i64) -> FlexTensor {
    #[cfg(feature = "simd")]
    {
        let dtype = tensor.dtype();
        let (lhs, msk) = crate::ops::expand::broadcast_binary(tensor, mask);
        let lhs = lhs.to_contiguous();
        let msk = msk.to_contiguous();
        let shape = lhs.layout().shape().clone();
        // Buffer is fully overwritten by the kernel; see `uninit_vec`.
        let mut out = uninit_vec::<i64>(lhs.storage::<i64>().len());
        crate::simd::mask_fill_i64(lhs.storage(), msk.bytes(), value, &mut out);
        FlexTensor::new(Bytes::from_elems(out), Layout::contiguous(shape), dtype)
    }
    #[cfg(not(feature = "simd"))]
    {
        mask_fill(tensor, mask, value)
    }
}
/// `mask_fill` for `u64` tensors; always takes the scalar path (this file
/// defines no `u64` SIMD dispatch, unlike the `i64` variant).
pub fn mask_fill_u64(tensor: FlexTensor, mask: FlexTensor, value: u64) -> FlexTensor {
    mask_fill::<u64>(tensor, mask, value)
}
/// `mask_fill` over a boolean tensor: selected byte positions are set to
/// `value as u8`. The result is rebuilt as a bool tensor with the dtype
/// derived from the input's dtype.
pub fn mask_fill_bool(tensor: FlexTensor, mask: FlexTensor, value: bool) -> FlexTensor {
    let out_dtype = burn_std::BoolDType::from(tensor.dtype());
    // Prologue shared by both the SIMD and scalar paths.
    let (tensor, mask) = crate::ops::expand::broadcast_binary(tensor, mask);
    let tensor = tensor.to_contiguous();
    let mask = mask.to_contiguous();
    let shape = tensor.layout().shape().clone();
    #[cfg(feature = "simd")]
    {
        // Buffer is fully overwritten by the kernel; see `uninit_vec`.
        let mut out = uninit_vec::<u8>(tensor.bytes().len());
        crate::simd::mask_fill_u8(tensor.bytes(), mask.bytes(), value as u8, &mut out);
        crate::ops::comparison::make_bool_tensor(out, shape, out_dtype)
    }
    #[cfg(not(feature = "simd"))]
    {
        let fill = value as u8;
        let src = tensor.bytes();
        let mut out: Vec<u8> = Vec::with_capacity(src.len());
        for (&elem, &flag) in src.iter().zip(mask.bytes()) {
            out.push(if flag != 0 { fill } else { elem });
        }
        crate::ops::comparison::make_bool_tensor(out, shape, out_dtype)
    }
}
/// Scalar mask-select: where the `mask` byte is non-zero the element comes
/// from `value`, otherwise from `tensor`. All three operands are broadcast
/// to a common shape and made contiguous first; output keeps `tensor`'s dtype.
pub fn mask_where<T>(tensor: FlexTensor, mask: FlexTensor, value: FlexTensor) -> FlexTensor
where
    T: Element + bytemuck::Pod + Copy,
{
    let dtype = tensor.dtype();
    let (tensor, mask, value) = broadcast_three(tensor, mask, value);
    let shape = tensor.layout().shape().clone();

    let lhs: &[T] = tensor.storage();
    let rhs: &[T] = value.storage();
    let flags: &[u8] = mask.bytes();

    let mut picked: Vec<T> = Vec::with_capacity(lhs.len());
    for ((&t, &v), &flag) in lhs.iter().zip(rhs).zip(flags) {
        picked.push(if flag == 0 { t } else { v });
    }

    FlexTensor::new(Bytes::from_elems(picked), Layout::contiguous(shape), dtype)
}
/// Broadcasts `tensor`, `mask`, and `value` to their common shape and
/// returns all three as contiguous tensors.
fn broadcast_three(
    tensor: FlexTensor,
    mask: FlexTensor,
    value: FlexTensor,
) -> (FlexTensor, FlexTensor, FlexTensor) {
    // Fold the three shapes into a single broadcast target.
    let shape =
        crate::ops::expand::broadcast_shape(tensor.layout().shape(), mask.layout().shape());
    let shape = crate::ops::expand::broadcast_shape(&shape, value.layout().shape());

    // Expand only when the operand is not already at the target shape,
    // then force a contiguous layout.
    let expand_to = |t: FlexTensor| {
        let t = if t.layout().shape() == &shape {
            t
        } else {
            crate::ops::expand::expand(t, shape.clone())
        };
        t.to_contiguous()
    };

    (expand_to(tensor), expand_to(mask), expand_to(value))
}
/// `mask_where` specialized for `f32`: dispatches to the SIMD kernel when
/// the `simd` feature is enabled, otherwise falls back to the scalar path.
pub fn mask_where_f32(tensor: FlexTensor, mask: FlexTensor, value: FlexTensor) -> FlexTensor {
    #[cfg(feature = "simd")]
    {
        let dtype = tensor.dtype();
        let (lhs, msk, rhs) = broadcast_three(tensor, mask, value);
        let shape = lhs.layout().shape().clone();
        // Buffer is fully overwritten by the kernel; see `uninit_vec`.
        let mut out = uninit_vec::<f32>(lhs.storage::<f32>().len());
        crate::simd::mask_where_f32(lhs.storage(), msk.bytes(), rhs.storage(), &mut out);
        FlexTensor::new(Bytes::from_elems(out), Layout::contiguous(shape), dtype)
    }
    #[cfg(not(feature = "simd"))]
    {
        mask_where::<f32>(tensor, mask, value)
    }
}
/// `mask_where` specialized for `f64`: dispatches to the SIMD kernel when
/// the `simd` feature is enabled, otherwise falls back to the scalar path.
pub fn mask_where_f64(tensor: FlexTensor, mask: FlexTensor, value: FlexTensor) -> FlexTensor {
    #[cfg(feature = "simd")]
    {
        let dtype = tensor.dtype();
        let (lhs, msk, rhs) = broadcast_three(tensor, mask, value);
        let shape = lhs.layout().shape().clone();
        // Buffer is fully overwritten by the kernel; see `uninit_vec`.
        let mut out = uninit_vec::<f64>(lhs.storage::<f64>().len());
        crate::simd::mask_where_f64(lhs.storage(), msk.bytes(), rhs.storage(), &mut out);
        FlexTensor::new(Bytes::from_elems(out), Layout::contiguous(shape), dtype)
    }
    #[cfg(not(feature = "simd"))]
    {
        mask_where::<f64>(tensor, mask, value)
    }
}
/// `mask_where` for `f16` tensors; always takes the scalar path (this file
/// defines no `f16` SIMD dispatch).
pub fn mask_where_f16(tensor: FlexTensor, mask: FlexTensor, value: FlexTensor) -> FlexTensor {
    mask_where::<f16>(tensor, mask, value)
}
/// `mask_where` for `bf16` tensors; always takes the scalar path (this file
/// defines no `bf16` SIMD dispatch).
pub fn mask_where_bf16(tensor: FlexTensor, mask: FlexTensor, value: FlexTensor) -> FlexTensor {
    mask_where::<bf16>(tensor, mask, value)
}
/// `mask_where` specialized for `i64`: dispatches to the SIMD kernel when
/// the `simd` feature is enabled, otherwise falls back to the scalar path.
pub fn mask_where_i64(tensor: FlexTensor, mask: FlexTensor, value: FlexTensor) -> FlexTensor {
    #[cfg(feature = "simd")]
    {
        let dtype = tensor.dtype();
        let (lhs, msk, rhs) = broadcast_three(tensor, mask, value);
        let shape = lhs.layout().shape().clone();
        // Buffer is fully overwritten by the kernel; see `uninit_vec`.
        let mut out = uninit_vec::<i64>(lhs.storage::<i64>().len());
        crate::simd::mask_where_i64(lhs.storage(), msk.bytes(), rhs.storage(), &mut out);
        FlexTensor::new(Bytes::from_elems(out), Layout::contiguous(shape), dtype)
    }
    #[cfg(not(feature = "simd"))]
    {
        mask_where::<i64>(tensor, mask, value)
    }
}
/// `mask_where` over boolean tensors: selected byte positions come from
/// `value`, the rest from `tensor`. The result is rebuilt as a bool tensor
/// with the dtype derived from the input's dtype.
pub fn mask_where_bool(tensor: FlexTensor, mask: FlexTensor, value: FlexTensor) -> FlexTensor {
    let out_dtype = burn_std::BoolDType::from(tensor.dtype());
    // Prologue shared by both the SIMD and scalar paths.
    let (tensor, mask, value) = broadcast_three(tensor, mask, value);
    let shape = tensor.layout().shape().clone();
    #[cfg(feature = "simd")]
    {
        // Buffer is fully overwritten by the kernel; see `uninit_vec`.
        let mut out = uninit_vec::<u8>(tensor.bytes().len());
        crate::simd::mask_where_u8(tensor.bytes(), mask.bytes(), value.bytes(), &mut out);
        crate::ops::comparison::make_bool_tensor(out, shape, out_dtype)
    }
    #[cfg(not(feature = "simd"))]
    {
        let lhs = tensor.bytes();
        let mut out: Vec<u8> = Vec::with_capacity(lhs.len());
        for ((&t, &v), &flag) in lhs.iter().zip(value.bytes()).zip(mask.bytes()) {
            out.push(if flag == 0 { t } else { v });
        }
        crate::ops::comparison::make_bool_tensor(out, shape, out_dtype)
    }
}