1use crate::{element::BoolElement, kernel, FloatElement, IntElement, JitBackend, JitRuntime};
2use burn_tensor::ops::{BoolTensor, Device, FloatTensor, IntTensor};
3use burn_tensor::{ops::BoolTensorOps, Shape, TensorData};
4use std::ops::Range;
5
6use super::{expand, permute};
7
8impl<R, F, I, BT> BoolTensorOps<Self> for JitBackend<R, F, I, BT>
9where
10 R: JitRuntime,
11 F: FloatElement,
12 I: IntElement,
13 BT: BoolElement,
14{
15 fn bool_empty(shape: Shape, device: &Device<Self>) -> BoolTensor<Self> {
16 super::empty::<R, BT>(shape, device)
17 }
18
19 async fn bool_into_data(tensor: BoolTensor<Self>) -> TensorData {
20 super::bool_into_data::<R, BT>(tensor).await
21 }
22
23 fn bool_from_data(data: TensorData, device: &Device<Self>) -> BoolTensor<Self> {
24 let data: TensorData = TensorData::new(data.iter::<BT>().collect(), data.shape);
25 super::from_data::<R, BT>(data, device)
26 }
27
28 fn bool_into_int(tensor: BoolTensor<Self>) -> IntTensor<Self> {
29 kernel::bool_cast::<R, BT, I>(tensor)
30 }
31
32 fn bool_device(tensor: &BoolTensor<Self>) -> Device<Self> {
33 tensor.device.clone()
34 }
35
36 fn bool_to_device(tensor: BoolTensor<Self>, device: &Device<Self>) -> BoolTensor<Self> {
37 super::to_device(tensor, device)
38 }
39
40 fn bool_reshape(tensor: BoolTensor<Self>, shape: Shape) -> BoolTensor<Self> {
41 super::reshape(tensor, shape)
42 }
43
44 fn bool_slice(tensor: BoolTensor<Self>, ranges: &[Range<usize>]) -> BoolTensor<Self> {
45 kernel::slice::<R, BT>(tensor, ranges)
46 }
47
48 fn bool_slice_assign(
49 tensor: BoolTensor<Self>,
50 ranges: &[Range<usize>],
51 value: BoolTensor<Self>,
52 ) -> BoolTensor<Self> {
53 kernel::slice_assign::<R, BT>(tensor, ranges, value)
54 }
55
56 fn bool_equal(lhs: BoolTensor<Self>, rhs: BoolTensor<Self>) -> BoolTensor<Self> {
57 kernel::equal::<R, BT, BT>(lhs, rhs)
58 }
59
60 fn bool_not(tensor: BoolTensor<Self>) -> BoolTensor<Self> {
61 kernel::equal_elem::<R, BT, BT>(tensor, BT::false_val())
62 }
63
64 fn bool_into_float(tensor: BoolTensor<Self>) -> FloatTensor<Self> {
65 kernel::bool_cast::<R, BT, F>(tensor)
66 }
67
68 fn bool_swap_dims(mut tensor: BoolTensor<Self>, dim1: usize, dim2: usize) -> BoolTensor<Self> {
69 tensor.strides.swap(dim1, dim2);
70 tensor.shape.dims.swap(dim1, dim2);
71
72 tensor
73 }
74
75 fn bool_repeat_dim(tensor: BoolTensor<Self>, dim: usize, times: usize) -> BoolTensor<Self> {
76 kernel::repeat_dim::<R, BT>(tensor, dim, times)
77 }
78
79 fn bool_permute(tensor: BoolTensor<Self>, axes: &[usize]) -> BoolTensor<Self> {
80 permute(tensor, axes)
81 }
82
83 fn bool_expand(tensor: BoolTensor<Self>, shape: Shape) -> BoolTensor<Self> {
84 expand(tensor, shape)
85 }
86
87 fn bool_flip(tensor: BoolTensor<Self>, axes: &[usize]) -> BoolTensor<Self> {
88 kernel::flip::<R, BT, BT>(tensor, axes)
89 }
90}