pub fn value_mask<T: Tensor<Dtype = f32>>(
    t: T,
    mask: &T::NoTape,
    value: T::Dtype
) -> T
Expand description

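Returns `t` with every element set to `value` wherever the corresponding element of `mask` equals `value`; all other elements of `t` are left unchanged.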

PyTorch equivalent: `t[mask == value] = value`, or equivalently `torch.where(mask == value, value, t)`.

Example:

```rust
let t: Tensor1D<3> = Tensor1D::new([1.0, 2.0, 3.0]);
let m: Tensor1D<3> = Tensor1D::new([-1e10, 0.0, -1e10]);
let r = t.trace().value_mask(&m, -1e10);
assert_eq!(r.data(), &[-1e10, 2.0, -1e10]);
```