burn_jit/ops/bool_ops.rs

1use crate::{element::BoolElement, kernel, FloatElement, IntElement, JitBackend, JitRuntime};
2use burn_tensor::ops::{BoolTensor, Device, FloatTensor, IntTensor};
3use burn_tensor::{ops::BoolTensorOps, Shape, TensorData};
4use std::ops::Range;
5
6use super::{expand, permute};
7
8impl<R, F, I, BT> BoolTensorOps<Self> for JitBackend<R, F, I, BT>
9where
10    R: JitRuntime,
11    F: FloatElement,
12    I: IntElement,
13    BT: BoolElement,
14{
15    fn bool_empty(shape: Shape, device: &Device<Self>) -> BoolTensor<Self> {
16        super::empty::<R, BT>(shape, device)
17    }
18
19    async fn bool_into_data(tensor: BoolTensor<Self>) -> TensorData {
20        super::bool_into_data::<R, BT>(tensor).await
21    }
22
23    fn bool_from_data(data: TensorData, device: &Device<Self>) -> BoolTensor<Self> {
24        let data: TensorData = TensorData::new(data.iter::<BT>().collect(), data.shape);
25        super::from_data::<R, BT>(data, device)
26    }
27
28    fn bool_into_int(tensor: BoolTensor<Self>) -> IntTensor<Self> {
29        kernel::bool_cast::<R, BT, I>(tensor)
30    }
31
32    fn bool_device(tensor: &BoolTensor<Self>) -> Device<Self> {
33        tensor.device.clone()
34    }
35
36    fn bool_to_device(tensor: BoolTensor<Self>, device: &Device<Self>) -> BoolTensor<Self> {
37        super::to_device(tensor, device)
38    }
39
40    fn bool_reshape(tensor: BoolTensor<Self>, shape: Shape) -> BoolTensor<Self> {
41        super::reshape(tensor, shape)
42    }
43
44    fn bool_slice(tensor: BoolTensor<Self>, ranges: &[Range<usize>]) -> BoolTensor<Self> {
45        kernel::slice::<R, BT>(tensor, ranges)
46    }
47
48    fn bool_slice_assign(
49        tensor: BoolTensor<Self>,
50        ranges: &[Range<usize>],
51        value: BoolTensor<Self>,
52    ) -> BoolTensor<Self> {
53        kernel::slice_assign::<R, BT>(tensor, ranges, value)
54    }
55
56    fn bool_equal(lhs: BoolTensor<Self>, rhs: BoolTensor<Self>) -> BoolTensor<Self> {
57        kernel::equal::<R, BT, BT>(lhs, rhs)
58    }
59
60    fn bool_not(tensor: BoolTensor<Self>) -> BoolTensor<Self> {
61        kernel::equal_elem::<R, BT, BT>(tensor, BT::false_val())
62    }
63
64    fn bool_into_float(tensor: BoolTensor<Self>) -> FloatTensor<Self> {
65        kernel::bool_cast::<R, BT, F>(tensor)
66    }
67
68    fn bool_swap_dims(mut tensor: BoolTensor<Self>, dim1: usize, dim2: usize) -> BoolTensor<Self> {
69        tensor.strides.swap(dim1, dim2);
70        tensor.shape.dims.swap(dim1, dim2);
71
72        tensor
73    }
74
75    fn bool_repeat_dim(tensor: BoolTensor<Self>, dim: usize, times: usize) -> BoolTensor<Self> {
76        kernel::repeat_dim::<R, BT>(tensor, dim, times)
77    }
78
79    fn bool_permute(tensor: BoolTensor<Self>, axes: &[usize]) -> BoolTensor<Self> {
80        permute(tensor, axes)
81    }
82
83    fn bool_expand(tensor: BoolTensor<Self>, shape: Shape) -> BoolTensor<Self> {
84        expand(tensor, shape)
85    }
86
87    fn bool_flip(tensor: BoolTensor<Self>, axes: &[usize]) -> BoolTensor<Self> {
88        kernel::flip::<R, BT, BT>(tensor, axes)
89    }
90}