Function mask_where 

pub fn mask_where<R: CubeRuntime>(
    input: CubeTensor<R>,
    mask: CubeTensor<R>,
    value: CubeTensor<R>,
    strategy: MaskWhereStrategy,
    dtype_bool: DType,
) -> CubeTensor<R>

Execute the mask where kernel with the given strategy.
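This page only gives the signature, so the sketch below is a CPU reference for the element-wise "mask where" semantics. It assumes `mask`, `input` and `value` have equal shapes and are flattened into contiguous slices. For every element, the output takes the value element where the mask is true and keeps the input element otherwise. The helper names `mask_where_reference` and `mask_where_inplace` are hypothetical and are not part of the cubecl API. The sketch does not model broadcasting, `dtype_bool`, or GPU execution.

```rust
/// Hypothetical CPU reference for mask-where (not the cubecl kernel):
/// out[i] = if mask[i] { value[i] } else { input[i] }.
/// Assumes all three slices have the same length (no broadcasting).
fn mask_where_reference<T: Copy>(input: &[T], mask: &[bool], value: &[T]) -> Vec<T> {
    assert_eq!(input.len(), mask.len(), "input and mask lengths differ");
    assert_eq!(input.len(), value.len(), "input and value lengths differ");
    input
        .iter()
        .zip(mask)
        .zip(value)
        .map(|((&x, &m), &v)| if m { v } else { x })
        .collect()
}

/// Hypothetical in-place variant: overwrites `input` where `mask` is true,
/// so no separate output buffer is allocated.
fn mask_where_inplace<T: Copy>(input: &mut [T], mask: &[bool], value: &[T]) {
    assert_eq!(input.len(), mask.len(), "input and mask lengths differ");
    assert_eq!(input.len(), value.len(), "input and value lengths differ");
    for ((x, &m), &v) in input.iter_mut().zip(mask).zip(value) {
        if m {
            *x = v;
        }
    }
}

fn main() {
    let input = [1.0f32, 2.0, 3.0, 4.0];
    let mask = [true, false, true, false];
    let value = [10.0f32, 20.0, 30.0, 40.0];

    let out = mask_where_reference(&input, &mask, &value);
    println!("{:?}", out); // [10.0, 2.0, 30.0, 4.0]

    let mut buf = input;
    mask_where_inplace(&mut buf, &mask, &value);
    assert_eq!(buf.to_vec(), out);
}
```

The two helpers show one plausible reason a `strategy` parameter exists. A kernel can either write into a fresh output tensor or reuse one of its input buffers when that buffer is no longer needed elsewhere. The variants of `MaskWhereStrategy` are not documented on this page, so treat that mapping as an assumption.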