burn_ndarray/ops/bool_tensor.rs

1// Language
2use alloc::vec;
3use alloc::vec::Vec;
4use burn_tensor::ops::{BoolTensorOps, FloatTensor, IntTensorOps};
5use burn_tensor::{ElementConversion, TensorMetadata};
6use ndarray::IntoDimension;
7
8// Current crate
9use crate::element::{FloatNdArrayElement, IntNdArrayElement, QuantElement};
10use crate::{NdArray, execute_with_int_dtype, tensor::NdArrayTensor};
11use crate::{NdArrayDevice, SharedArray};
12
13// Workspace crates
14use burn_tensor::{Shape, TensorData, backend::Backend};
15
16use super::{NdArrayBoolOps, NdArrayOps};
17
18impl<E: FloatNdArrayElement, I: IntNdArrayElement, Q: QuantElement> BoolTensorOps<Self>
19    for NdArray<E, I, Q>
20where
21    NdArrayTensor: From<SharedArray<E>>,
22    NdArrayTensor: From<SharedArray<I>>,
23{
24    fn bool_from_data(data: TensorData, _device: &NdArrayDevice) -> NdArrayTensor {
25        if !data.dtype.is_bool() {
26            unimplemented!("Unsupported dtype for `bool_from_data`")
27        }
28        NdArrayTensor::from_data(data)
29    }
30
31    async fn bool_into_data(tensor: NdArrayTensor) -> TensorData {
32        tensor.into_data()
33    }
34
35    fn bool_to_device(tensor: NdArrayTensor, _device: &NdArrayDevice) -> NdArrayTensor {
36        tensor
37    }
38
39    fn bool_reshape(tensor: NdArrayTensor, shape: Shape) -> NdArrayTensor {
40        NdArrayOps::reshape(tensor.bool(), shape).into()
41    }
42
43    fn bool_slice(tensor: NdArrayTensor, slices: &[burn_tensor::Slice]) -> NdArrayTensor {
44        NdArrayOps::slice(tensor.bool(), slices).into()
45    }
46
47    fn bool_into_int(tensor: NdArrayTensor) -> NdArrayTensor {
48        let shape = tensor.shape();
49        let values = tensor.bool().into_iter().collect();
50        NdArray::<E, I>::int_from_data(
51            TensorData::new(values, shape).convert::<I>(),
52            &NdArrayDevice::Cpu,
53        )
54    }
55
56    fn bool_device(_tensor: &NdArrayTensor) -> <NdArray<E> as Backend>::Device {
57        NdArrayDevice::Cpu
58    }
59
60    fn bool_empty(shape: Shape, _device: &<NdArray<E> as Backend>::Device) -> NdArrayTensor {
61        Self::bool_zeros(shape, _device)
62    }
63
64    fn bool_zeros(shape: Shape, _device: &<NdArray<E> as Backend>::Device) -> NdArrayTensor {
65        let values = vec![false; shape.num_elements()];
66        NdArrayTensor::from_data(TensorData::new(values, shape))
67    }
68
69    fn bool_ones(shape: Shape, _device: &<NdArray<E> as Backend>::Device) -> NdArrayTensor {
70        let values = vec![true; shape.num_elements()];
71        NdArrayTensor::from_data(TensorData::new(values, shape))
72    }
73
74    fn bool_slice_assign(
75        tensor: NdArrayTensor,
76        slices: &[burn_tensor::Slice],
77        value: NdArrayTensor,
78    ) -> NdArrayTensor {
79        NdArrayOps::slice_assign(tensor.bool(), slices, value.bool()).into()
80    }
81
82    fn bool_cat(tensors: Vec<NdArrayTensor>, dim: usize) -> NdArrayTensor {
83        NdArrayOps::cat(tensors.into_iter().map(|it| it.bool()).collect(), dim).into()
84    }
85
86    fn bool_equal(lhs: NdArrayTensor, rhs: NdArrayTensor) -> NdArrayTensor {
87        NdArrayBoolOps::equal(lhs.bool(), rhs.bool()).into()
88    }
89
90    fn bool_not(tensor: NdArrayTensor) -> NdArrayTensor {
91        tensor.bool().mapv(|a| !a).into_shared().into()
92    }
93
94    fn bool_and(lhs: NdArrayTensor, rhs: NdArrayTensor) -> NdArrayTensor {
95        NdArrayBoolOps::and(lhs.bool(), rhs.bool()).into()
96    }
97
98    fn bool_or(lhs: NdArrayTensor, rhs: NdArrayTensor) -> NdArrayTensor {
99        NdArrayBoolOps::or(lhs.bool(), rhs.bool()).into()
100    }
101
102    fn bool_into_float(tensor: NdArrayTensor) -> FloatTensor<Self> {
103        let arr: SharedArray<E> = tensor.bool().mapv(|a| (a as i32).elem()).into_shared();
104        arr.into()
105    }
106
107    fn bool_swap_dims(tensor: NdArrayTensor, dim1: usize, dim2: usize) -> NdArrayTensor {
108        NdArrayOps::swap_dims(tensor.bool(), dim1, dim2).into()
109    }
110
111    fn bool_permute(tensor: NdArrayTensor, axes: &[usize]) -> NdArrayTensor {
112        tensor.bool().permuted_axes(axes.into_dimension()).into()
113    }
114
115    fn bool_expand(tensor: NdArrayTensor, shape: Shape) -> NdArrayTensor {
116        NdArrayOps::expand(tensor.bool(), shape).into()
117    }
118
119    fn bool_select(tensor: NdArrayTensor, dim: usize, indices: NdArrayTensor) -> NdArrayTensor {
120        execute_with_int_dtype!(indices, I, |indices: SharedArray<I>| -> NdArrayTensor {
121            let tensor_bool = tensor.bool();
122            let indices_vec: Vec<usize> = indices
123                .into_iter()
124                .map(|i| i.elem::<i64>() as usize)
125                .collect();
126
127            let selected = tensor_bool.select(ndarray::Axis(dim), &indices_vec);
128            selected.into_shared().into()
129        })
130    }
131
132    fn bool_select_assign(
133        tensor: NdArrayTensor,
134        dim: usize,
135        indices: NdArrayTensor,
136        value: NdArrayTensor,
137    ) -> NdArrayTensor {
138        execute_with_int_dtype!(indices, I, |indices: SharedArray<I>| -> NdArrayTensor {
139            let mut output_array = tensor.bool().into_owned();
140            let value_bool = value.bool();
141
142            for (index_value, index) in indices.into_iter().enumerate() {
143                let index_usize = index.elem::<i64>() as usize;
144                let mut view = output_array.index_axis_mut(ndarray::Axis(dim), index_usize);
145                let value_slice = value_bool.index_axis(ndarray::Axis(dim), index_value);
146                // For boolean tensors, select_assign should use logical OR operation
147                view.zip_mut_with(&value_slice, |a, b| *a = *a || *b);
148            }
149            output_array.into_shared().into()
150        })
151    }
152
153    fn bool_flip(tensor: NdArrayTensor, axes: &[usize]) -> NdArrayTensor {
154        NdArrayOps::flip(tensor.bool(), axes).into()
155    }
156
157    fn bool_unfold(tensor: NdArrayTensor, dim: usize, size: usize, step: usize) -> NdArrayTensor {
158        NdArrayOps::unfold(tensor.bool(), dim, size, step).into()
159    }
160}