pub trait ChooseFrom<Lhs, Rhs>: HasErr {
    type Output;

    // Required method
    fn try_choose(self, lhs: Lhs, rhs: Rhs) -> Result<Self::Output, Self::Err>;

    // Provided method
    fn choose(self, lhs: Lhs, rhs: Rhs) -> Self::Output { ... }
}
Expand description

Chooses values from two tensors using a boolean mask. Equivalent to `torch.where` in PyTorch.

let cond: Tensor<Rank1<3>, bool, _> = dev.tensor([true, false, true]);
let a: Tensor<Rank1<3>, f32, _> = dev.tensor([1.0, 2.0, 3.0]);
let b: Tensor<Rank1<3>, f32, _> = dev.tensor([-1.0, -2.0, -3.0]);
let c = cond.choose(a, b);
assert_eq!(c.array(), [1.0, -2.0, 3.0]);

Required Associated Types§

type Output

The type produced by `choose` (and wrapped in a `Result` by `try_choose`).

Required Methods§

source

fn try_choose(self, lhs: Lhs, rhs: Rhs) -> Result<Self::Output, Self::Err>

Fallible version of `choose`: returns `Result<Self::Output, Self::Err>` instead of the bare output.

Provided Methods§

source

fn choose(self, lhs: Lhs, rhs: Rhs) -> Self::Output

Constructs a new tensor containing the elements of `lhs` where `self` is `true`, and the elements of `rhs` where `self` is `false`.

Implementors§

source§

impl<S: Shape, E: Dtype, D: ChooseKernel<E>, LhsTape: Tape<E, D> + Merge<RhsTape>, RhsTape: Tape<E, D>> ChooseFrom<Tensor<S, E, D, LhsTape>, Tensor<S, E, D, RhsTape>> for Tensor<S, bool, D>

§

type Output = Tensor<S, E, D, LhsTape>