
bool_cast

Function bool_cast 

pub fn bool_cast<R: CubeRuntime>(
    tensor: CubeTensor<R>,
    out_dtype: DType,
) -> CubeTensor<R>

Cast a bool tensor to the given element type.

This alternative to cast is necessary because bools are represented as u32 or u8, where any non-zero value means true. Depending on how the tensor was created, an element may hold an arbitrary non-zero bit pattern rather than exactly 1, so naively casting it would not necessarily yield 0 or 1.
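The difference can be illustrated on the CPU with plain Rust, without any tensor types. This is only a sketch of the idea, not the kernel this function runs: `naive_to_f32` and `bool_to_f32` are hypothetical helpers introduced here, and u32 storage with an f32 output is assumed for the example.

```rust
/// Hypothetical helper (not part of the library): a plain numeric cast
/// that passes the raw storage value through unchanged.
fn naive_to_f32(raw: u32) -> f32 {
    raw as f32
}

/// Hypothetical helper (not part of the library): treats any non-zero
/// storage value as true and maps it to exactly 1.0, otherwise 0.0.
fn bool_to_f32(raw: u32) -> f32 {
    if raw != 0 {
        1.0
    } else {
        0.0
    }
}

fn main() {
    // Assumed raw storage for a bool tensor; 2 and u32::MAX both mean "true".
    let raw: [u32; 4] = [0, 1, 2, u32::MAX];

    let naive: Vec<f32> = raw.iter().map(|&x| naive_to_f32(x)).collect();
    let normalized: Vec<f32> = raw.iter().map(|&x| bool_to_f32(x)).collect();

    println!("naive:      {:?}", naive);
    println!("normalized: {:?}", normalized);
}
```

With the naive cast, the elements stored as 2 and u32::MAX come out as large values instead of 1, while the normalizing version maps every non-zero element to 1.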