Function dfdx::tensor_ops::value_mask
Sets t to value anywhere mask equals value.
PyTorch equivalent: t[mask == value] = value or torch.where(mask == value, value, t)
Example:
let t: Tensor1D<3> = Tensor1D::new([1.0, 2.0, 3.0]);
let m: Tensor1D<3> = Tensor1D::new([-1e10, 0.0, -1e10]);
let r = t.trace().value_mask(&m, -1e10);
assert_eq!(r.data(), &[-1e10, 2.0, -1e10]);
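
The element-wise rule described above can be sketched in plain Rust without dfdx. This is only an illustration of the semantics on fixed-size arrays, not the library's implementation: `value_mask_sketch` is a hypothetical helper name, and it ignores gradient tracking entirely.

```rust
/// Hypothetical helper (not part of dfdx): returns a copy of `t` in which
/// every position where `mask` equals `value` is overwritten with `value`.
fn value_mask_sketch<const N: usize>(
    t: &[f32; N],
    mask: &[f32; N],
    value: f32,
) -> [f32; N] {
    // Start from a copy of the input so unmasked positions keep their value.
    let mut out = *t;
    for (o, &m) in out.iter_mut().zip(mask.iter()) {
        // Exact float comparison, mirroring `mask == value` in the PyTorch form.
        if m == value {
            *o = value;
        }
    }
    out
}
```

Calling `value_mask_sketch(&[1.0, 2.0, 3.0], &[-1e10, 0.0, -1e10], -1e10)` mirrors the example above. Because the comparison is exact, a NaN entry in the mask never matches, since NaN is not equal to itself.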