Trait dfdx::tensor_ops::ChooseFrom
pub trait ChooseFrom<Lhs, Rhs>: HasErr {
type Output;
// Required method
fn try_choose(self, lhs: Lhs, rhs: Rhs) -> Result<Self::Output, Self::Err>;
// Provided method
fn choose(self, lhs: Lhs, rhs: Rhs) -> Self::Output { ... }
}
Expand description
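Chooses values element-wise from two tensors using a boolean mask: where the
mask is true the element is taken from the first tensor (lhs), otherwise from
the second (rhs). Equivalent to torch.where in PyTorch.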
let cond: Tensor<Rank1<3>, bool, _> = dev.tensor([true, false, true]);
let a: Tensor<Rank1<3>, f32, _> = dev.tensor([1.0, 2.0, 3.0]);
let b: Tensor<Rank1<3>, f32, _> = dev.tensor([-1.0, -2.0, -3.0]);
let c = cond.choose(a, b);
assert_eq!(c.array(), [1.0, -2.0, 3.0]);
Required Associated Types§
Required Methods§
fn try_choose(self, lhs: Lhs, rhs: Rhs) -> Result<Self::Output, Self::Err>
Fallible version of choose: returns the selected output wrapped in a Result, with failures reported as Self::Err.