use alloc::vec::Vec;
use burn_backend::{DType, Element};
use burn_std::{Bytes, bf16, f16};
use crate::{FlexTensor, Layout};
/// Replaces every element of `tensor` whose corresponding `mask` entry is
/// truthy (non-zero byte) with the scalar `value`.
///
/// The operands are broadcast to a common shape first; the result is a new
/// contiguous tensor that keeps `tensor`'s original dtype.
pub fn mask_fill<T>(tensor: FlexTensor, mask: FlexTensor, value: T) -> FlexTensor
where
    T: Element + bytemuck::Pod + Copy,
{
    let out_dtype = tensor.dtype();
    let (tensor, mask) = crate::ops::expand::broadcast_binary(tensor, mask);
    // Materialize both operands so a flat element-wise zip walks them in
    // matching logical order.
    let tensor = tensor.to_contiguous();
    let mask = mask.to_contiguous();
    let out_shape = tensor.layout().shape().clone();
    let elems: &[T] = tensor.storage();
    // NOTE(review): assumes the mask stores one byte per element — confirm
    // this holds for every supported bool representation.
    let flags: &[u8] = mask.bytes();
    let mut filled: Vec<T> = Vec::with_capacity(elems.len());
    for (&elem, &flag) in elems.iter().zip(flags) {
        filled.push(if flag == 0 { elem } else { value });
    }
    FlexTensor::new(
        Bytes::from_elems(filled),
        Layout::contiguous(out_shape),
        out_dtype,
    )
}
/// [`mask_fill`] monomorphized for `f32` elements.
pub fn mask_fill_f32(tensor: FlexTensor, mask: FlexTensor, value: f32) -> FlexTensor {
    mask_fill::<f32>(tensor, mask, value)
}
/// [`mask_fill`] monomorphized for `f64` elements.
pub fn mask_fill_f64(tensor: FlexTensor, mask: FlexTensor, value: f64) -> FlexTensor {
    mask_fill::<f64>(tensor, mask, value)
}
/// [`mask_fill`] monomorphized for `f16` elements.
pub fn mask_fill_f16(tensor: FlexTensor, mask: FlexTensor, value: f16) -> FlexTensor {
    mask_fill::<f16>(tensor, mask, value)
}
/// [`mask_fill`] monomorphized for `bf16` elements.
pub fn mask_fill_bf16(tensor: FlexTensor, mask: FlexTensor, value: bf16) -> FlexTensor {
    mask_fill::<bf16>(tensor, mask, value)
}
/// [`mask_fill`] monomorphized for `i64` elements.
pub fn mask_fill_i64(tensor: FlexTensor, mask: FlexTensor, value: i64) -> FlexTensor {
    mask_fill::<i64>(tensor, mask, value)
}
/// Boolean variant of [`mask_fill`]: where `mask` is truthy (non-zero byte)
/// the output takes `value`, elsewhere it keeps the input element.
///
/// Bools are processed at the byte level, and the result dtype is pinned to
/// the native bool store.
pub fn mask_fill_bool(tensor: FlexTensor, mask: FlexTensor, value: bool) -> FlexTensor {
    let (tensor, mask) = crate::ops::expand::broadcast_binary(tensor, mask);
    // Contiguous layouts let the byte slices line up element-for-element.
    let tensor = tensor.to_contiguous();
    let mask = mask.to_contiguous();
    let out_shape = tensor.layout().shape().clone();
    let elems: &[u8] = tensor.bytes();
    let flags: &[u8] = mask.bytes();
    let fill = u8::from(value);
    let mut out: Vec<u8> = Vec::with_capacity(elems.len());
    for (&elem, &flag) in elems.iter().zip(flags) {
        out.push(if flag == 0 { elem } else { fill });
    }
    FlexTensor::new(
        Bytes::from_elems(out),
        Layout::contiguous(out_shape),
        DType::Bool(burn_std::BoolStore::Native),
    )
}
/// Element-wise select: the output takes `value`'s element where `mask` is
/// truthy (non-zero byte) and `tensor`'s element otherwise.
///
/// All three operands are broadcast to one common shape; the result is a new
/// contiguous tensor carrying `tensor`'s dtype.
pub fn mask_where<T>(tensor: FlexTensor, mask: FlexTensor, value: FlexTensor) -> FlexTensor
where
    T: Element + bytemuck::Pod + Copy,
{
    let out_dtype = tensor.dtype();
    // Common broadcast shape of all three operands.
    let target_shape =
        crate::ops::expand::broadcast_shape(tensor.layout().shape(), mask.layout().shape());
    let target_shape = crate::ops::expand::broadcast_shape(&target_shape, value.layout().shape());
    // Expand an operand only when its shape differs from the target.
    let expand_to = |t: FlexTensor| {
        if t.layout().shape() == &target_shape {
            t
        } else {
            crate::ops::expand::expand(t, target_shape.clone())
        }
    };
    let tensor = expand_to(tensor).to_contiguous();
    let mask = expand_to(mask).to_contiguous();
    let value = expand_to(value).to_contiguous();
    let out_shape = tensor.layout().shape().clone();
    let keep: &[T] = tensor.storage();
    let flags: &[u8] = mask.bytes();
    let replace: &[T] = value.storage();
    let mut out: Vec<T> = Vec::with_capacity(keep.len());
    for ((&kept, &flag), &repl) in keep.iter().zip(flags).zip(replace) {
        out.push(if flag == 0 { kept } else { repl });
    }
    FlexTensor::new(
        Bytes::from_elems(out),
        Layout::contiguous(out_shape),
        out_dtype,
    )
}
/// [`mask_where`] monomorphized for `f32` elements.
pub fn mask_where_f32(tensor: FlexTensor, mask: FlexTensor, value: FlexTensor) -> FlexTensor {
mask_where::<f32>(tensor, mask, value)
}
/// [`mask_where`] monomorphized for `f64` elements.
pub fn mask_where_f64(tensor: FlexTensor, mask: FlexTensor, value: FlexTensor) -> FlexTensor {
mask_where::<f64>(tensor, mask, value)
}
/// [`mask_where`] monomorphized for `f16` elements.
pub fn mask_where_f16(tensor: FlexTensor, mask: FlexTensor, value: FlexTensor) -> FlexTensor {
mask_where::<f16>(tensor, mask, value)
}
/// [`mask_where`] monomorphized for `bf16` elements.
pub fn mask_where_bf16(tensor: FlexTensor, mask: FlexTensor, value: FlexTensor) -> FlexTensor {
mask_where::<bf16>(tensor, mask, value)
}
/// [`mask_where`] monomorphized for `i64` elements.
pub fn mask_where_i64(tensor: FlexTensor, mask: FlexTensor, value: FlexTensor) -> FlexTensor {
mask_where::<i64>(tensor, mask, value)
}
/// Boolean variant of [`mask_where`]: selects the byte from `value` where the
/// mask is truthy (non-zero byte) and from `tensor` otherwise.
///
/// The result dtype is pinned to the native bool store.
pub fn mask_where_bool(tensor: FlexTensor, mask: FlexTensor, value: FlexTensor) -> FlexTensor {
    // Common broadcast shape of all three operands.
    let target_shape =
        crate::ops::expand::broadcast_shape(tensor.layout().shape(), mask.layout().shape());
    let target_shape = crate::ops::expand::broadcast_shape(&target_shape, value.layout().shape());
    // Expand an operand only when its shape differs from the target.
    let expand_to = |t: FlexTensor| {
        if t.layout().shape() == &target_shape {
            t
        } else {
            crate::ops::expand::expand(t, target_shape.clone())
        }
    };
    let tensor = expand_to(tensor).to_contiguous();
    let mask = expand_to(mask).to_contiguous();
    let value = expand_to(value).to_contiguous();
    let out_shape = tensor.layout().shape().clone();
    let keep: &[u8] = tensor.bytes();
    let flags: &[u8] = mask.bytes();
    let replace: &[u8] = value.bytes();
    let mut out: Vec<u8> = Vec::with_capacity(keep.len());
    for ((&kept, &flag), &repl) in keep.iter().zip(flags).zip(replace) {
        out.push(if flag == 0 { kept } else { repl });
    }
    FlexTensor::new(
        Bytes::from_elems(out),
        Layout::contiguous(out_shape),
        DType::Bool(burn_std::BoolStore::Native),
    )
}
// Unit tests for the mask ops. Expected values encode the contract that the
// mask selects element-wise AFTER broadcasting, and that non-contiguous
// inputs (transposed, flipped, narrowed) are read in logical — not storage —
// order.
#[cfg(test)]
mod tests {
use super::*;
use burn_backend::TensorData;
// Basic 1-D fill: true mask positions take the fill value.
#[test]
fn test_mask_fill_f32() {
let tensor = FlexTensor::from_data(TensorData::new(vec![1.0f32, 2.0, 3.0, 4.0], [4]));
let mask = FlexTensor::from_data(TensorData::new(vec![true, false, true, false], [4]));
let result = mask_fill_f32(tensor, mask, 0.0);
let data: Vec<f32> = result.into_data().to_vec().unwrap();
assert_eq!(data, vec![0.0, 2.0, 0.0, 4.0]);
}
// Same-shape 2-D fill with a non-zero fill value.
#[test]
fn test_mask_fill_2d() {
let tensor = FlexTensor::from_data(TensorData::new(vec![1.0f32, 2.0, 3.0, 4.0], [2, 2]));
let mask = FlexTensor::from_data(TensorData::new(vec![true, false, false, true], [2, 2]));
let result = mask_fill_f32(tensor, mask, -1.0);
let data: Vec<f32> = result.into_data().to_vec().unwrap();
assert_eq!(data, vec![-1.0, 2.0, 3.0, -1.0]);
}
// mask_where picks from `value` where the mask is true, from `tensor` elsewhere.
#[test]
fn test_mask_where_f32() {
let tensor = FlexTensor::from_data(TensorData::new(vec![1.0f32, 2.0, 3.0, 4.0], [4]));
let mask = FlexTensor::from_data(TensorData::new(vec![true, false, true, false], [4]));
let value = FlexTensor::from_data(TensorData::new(vec![10.0f32, 20.0, 30.0, 40.0], [4]));
let result = mask_where_f32(tensor, mask, value);
let data: Vec<f32> = result.into_data().to_vec().unwrap();
assert_eq!(data, vec![10.0, 2.0, 30.0, 4.0]);
}
// A [3] mask broadcasts across both rows of a [2, 3] tensor.
#[test]
fn test_mask_fill_broadcast() {
let tensor = FlexTensor::from_data(TensorData::new(
vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0],
[2, 3],
));
let mask = FlexTensor::from_data(TensorData::new(vec![true, false, true], [3]));
let result = mask_fill_f32(tensor, mask, 0.0);
let data: Vec<f32> = result.into_data().to_vec().unwrap();
assert_eq!(data, vec![0.0, 2.0, 0.0, 0.0, 5.0, 0.0]);
}
// Transposed (non-contiguous) input: the mask applies in the transposed
// logical order, so expectations use the transposed element positions.
#[test]
fn test_mask_fill_transposed_tensor() {
let tensor = FlexTensor::from_data(TensorData::new(vec![1.0f32, 2.0, 3.0, 4.0], [2, 2]));
let tensor = tensor.transpose(0, 1);
assert!(!tensor.is_contiguous());
let mask = FlexTensor::from_data(TensorData::new(vec![true, false, false, true], [2, 2]));
let result = mask_fill_f32(tensor, mask, 0.0);
let data: Vec<f32> = result.into_data().to_vec().unwrap();
assert_eq!(data, vec![0.0, 3.0, 2.0, 0.0]);
}
// Flipped tensor (negative stride): fill applies to the reversed view.
#[test]
fn test_mask_fill_flipped_tensor() {
let tensor = FlexTensor::from_data(TensorData::new(vec![1.0f32, 2.0, 3.0, 4.0], [4]));
let tensor = crate::ops::flip::flip(tensor, &[0]);
assert!(tensor.layout().strides()[0] < 0);
let mask = FlexTensor::from_data(TensorData::new(vec![true, true, false, false], [4]));
let result = mask_fill_f32(tensor, mask, 0.0);
let data: Vec<f32> = result.into_data().to_vec().unwrap();
assert_eq!(data, vec![0.0, 0.0, 2.0, 1.0]);
}
// Flipped mask (negative stride): the mask is read in its flipped order.
#[test]
fn test_mask_fill_flipped_mask() {
let tensor = FlexTensor::from_data(TensorData::new(vec![1.0f32, 2.0, 3.0, 4.0], [4]));
let mask = FlexTensor::from_data(TensorData::new(vec![false, false, true, true], [4]));
let mask = crate::ops::flip::flip(mask, &[0]);
assert!(mask.layout().strides()[0] < 0);
let result = mask_fill_f32(tensor, mask, 0.0);
let data: Vec<f32> = result.into_data().to_vec().unwrap();
assert_eq!(data, vec![0.0, 0.0, 3.0, 4.0]);
}
// mask_where with a row-flipped 2-D tensor input.
#[test]
fn test_mask_where_flipped_2d() {
let tensor = FlexTensor::from_data(TensorData::new(vec![1.0f32, 2.0, 3.0, 4.0], [2, 2]));
let tensor = crate::ops::flip::flip(tensor, &[0]);
assert!(tensor.layout().strides()[0] < 0);
let mask = FlexTensor::from_data(TensorData::new(vec![true, false, false, true], [2, 2]));
let value = FlexTensor::from_data(TensorData::new(vec![10.0f32, 20.0, 30.0, 40.0], [2, 2]));
let result = mask_where_f32(tensor, mask, value);
let data: Vec<f32> = result.into_data().to_vec().unwrap();
assert_eq!(data, vec![10.0, 4.0, 1.0, 40.0]);
}
// Both tensor and mask flipped: each is read in its own flipped order.
#[test]
fn test_mask_fill_both_flipped() {
let tensor = FlexTensor::from_data(TensorData::new(vec![1.0f32, 2.0, 3.0, 4.0], [4]));
let tensor = crate::ops::flip::flip(tensor, &[0]);
assert!(tensor.layout().strides()[0] < 0);
let mask = FlexTensor::from_data(TensorData::new(vec![true, false, true, false], [4]));
let mask = crate::ops::flip::flip(mask, &[0]);
assert!(mask.layout().strides()[0] < 0);
let result = mask_fill_f32(tensor, mask, 0.0);
let data: Vec<f32> = result.into_data().to_vec().unwrap();
assert_eq!(data, vec![4.0, 0.0, 2.0, 0.0]);
}
// Narrowed view (offset storage): the mask aligns with the window [2..6).
#[test]
fn test_mask_fill_narrowed_tensor() {
let tensor =
FlexTensor::from_data(TensorData::new(vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0], [6]));
let tensor = tensor.narrow(0, 1, 4);
let mask = FlexTensor::from_data(TensorData::new(vec![true, false, false, true], [4]));
let result = mask_fill_f32(tensor, mask, 0.0);
let data: Vec<f32> = result.into_data().to_vec().unwrap();
assert_eq!(data, vec![0.0, 3.0, 4.0, 0.0]);
}
// Integer dtype path with a flipped input.
#[test]
fn test_mask_fill_i64_flipped() {
let tensor = FlexTensor::from_data(TensorData::new(vec![10i64, 20, 30, 40], [4]));
let tensor = crate::ops::flip::flip(tensor, &[0]);
assert!(tensor.layout().strides()[0] < 0);
let mask = FlexTensor::from_data(TensorData::new(vec![true, false, true, false], [4]));
let result = mask_fill_i64(tensor, mask, -1);
let data: Vec<i64> = result.into_data().to_vec().unwrap();
assert_eq!(data, vec![-1, 30, -1, 10]);
}
}