use burn_backend::DType;
use cubecl::prelude::InputScalar;
use super::{MaskFillStrategy, mask_where::MaskWhereStrategy};
use crate::{CubeRuntime, tensor::CubeTensor};
/// Runs `super::mask_fill` with a strategy chosen automatically from `tensor`.
///
/// [`MaskFillStrategy::Inplace`] is selected only when both `tensor.can_mut()`
/// and `tensor.is_nonoverlapping()` report `true`; in every other case
/// [`MaskFillStrategy::Readonly`] is used.
///
/// `mask`, `value` and `dtype_bool` are forwarded to `super::mask_fill`
/// unchanged, and its result is returned as-is.
pub(crate) fn mask_fill_auto<R: CubeRuntime>(
    tensor: CubeTensor<R>,
    mask: CubeTensor<R>,
    value: InputScalar,
    dtype_bool: DType,
) -> CubeTensor<R> {
    // `&&` short-circuits: the overlap check is skipped when `can_mut()` is false.
    let in_place_ok = tensor.can_mut() && tensor.is_nonoverlapping();

    let strategy = if in_place_ok {
        MaskFillStrategy::Inplace
    } else {
        MaskFillStrategy::Readonly
    };

    super::mask_fill(tensor, mask, value, strategy, dtype_bool)
}
/// Runs `super::mask_where` with a strategy chosen automatically from the operands.
///
/// The strategy is picked in this order of preference:
/// 1. [`MaskWhereStrategy::InplaceLhs`] when `tensor.can_mut_broadcast(&value)` holds;
/// 2. [`MaskWhereStrategy::InplaceRhs`] when `value.can_mut_broadcast(&tensor)` holds;
/// 3. [`MaskWhereStrategy::Readonly`] otherwise.
///
/// `mask` and `dtype_bool` are forwarded to `super::mask_where` unchanged, and
/// its result is returned as-is.
pub(crate) fn mask_where_auto<R: CubeRuntime>(
    tensor: CubeTensor<R>,
    mask: CubeTensor<R>,
    value: CubeTensor<R>,
    dtype_bool: DType,
) -> CubeTensor<R> {
    // The lhs is probed first; the rhs is only probed when the lhs check fails.
    if tensor.can_mut_broadcast(&value) {
        return super::mask_where(
            tensor,
            mask,
            value,
            MaskWhereStrategy::InplaceLhs,
            dtype_bool,
        );
    }

    let fallback = if value.can_mut_broadcast(&tensor) {
        MaskWhereStrategy::InplaceRhs
    } else {
        MaskWhereStrategy::Readonly
    };

    super::mask_where(tensor, mask, value, fallback, dtype_bool)
}