burn_ndarray/ops/
bool_tensor.rs1use alloc::vec;
3use alloc::vec::Vec;
4use burn_tensor::ops::{BoolTensorOps, FloatTensor, IntTensorOps};
5use burn_tensor::{ElementConversion, TensorMetadata};
6use core::ops::Range;
7use ndarray::{IntoDimension, Zip};
8
9use crate::element::{FloatNdArrayElement, IntNdArrayElement, QuantElement};
11use crate::{new_tensor_float, NdArrayDevice};
12use crate::{tensor::NdArrayTensor, NdArray};
13
14use burn_tensor::{backend::Backend, Shape, TensorData};
16
17use super::NdArrayOps;
18
19impl<E: FloatNdArrayElement, I: IntNdArrayElement, Q: QuantElement> BoolTensorOps<Self>
20 for NdArray<E, I, Q>
21{
22 fn bool_from_data(data: TensorData, _device: &NdArrayDevice) -> NdArrayTensor<bool> {
23 NdArrayTensor::from_data(data)
24 }
25
26 async fn bool_into_data(tensor: NdArrayTensor<bool>) -> TensorData {
27 let shape = tensor.shape();
28 let values = tensor.array.into_iter().collect();
29 TensorData::new(values, shape)
30 }
31
32 fn bool_to_device(tensor: NdArrayTensor<bool>, _device: &NdArrayDevice) -> NdArrayTensor<bool> {
33 tensor
34 }
35
36 fn bool_reshape(tensor: NdArrayTensor<bool>, shape: Shape) -> NdArrayTensor<bool> {
37 NdArrayOps::reshape(tensor, shape)
38 }
39
40 fn bool_slice(tensor: NdArrayTensor<bool>, ranges: &[Range<usize>]) -> NdArrayTensor<bool> {
41 NdArrayOps::slice(tensor, ranges)
42 }
43
44 fn bool_into_int(tensor: NdArrayTensor<bool>) -> NdArrayTensor<I> {
45 let shape = tensor.shape();
46 let values = tensor.array.into_iter().collect();
47 NdArray::<E, I>::int_from_data(
48 TensorData::new(values, shape).convert::<I>(),
49 &NdArrayDevice::Cpu,
50 )
51 }
52
53 fn bool_device(_tensor: &NdArrayTensor<bool>) -> <NdArray<E> as Backend>::Device {
54 NdArrayDevice::Cpu
55 }
56
57 fn bool_empty(shape: Shape, _device: &<NdArray<E> as Backend>::Device) -> NdArrayTensor<bool> {
58 let values = vec![false; shape.num_elements()];
59 NdArrayTensor::from_data(TensorData::new(values, shape))
60 }
61
62 fn bool_slice_assign(
63 tensor: NdArrayTensor<bool>,
64 ranges: &[Range<usize>],
65 value: NdArrayTensor<bool>,
66 ) -> NdArrayTensor<bool> {
67 NdArrayOps::slice_assign(tensor, ranges, value)
68 }
69
70 fn bool_cat(tensors: Vec<NdArrayTensor<bool>>, dim: usize) -> NdArrayTensor<bool> {
71 NdArrayOps::cat(tensors, dim)
72 }
73
74 fn bool_equal(lhs: NdArrayTensor<bool>, rhs: NdArrayTensor<bool>) -> NdArrayTensor<bool> {
75 let output = Zip::from(&lhs.array)
76 .and(&rhs.array)
77 .map_collect(|&lhs_val, &rhs_val| (lhs_val == rhs_val))
78 .into_shared();
79 NdArrayTensor::new(output)
80 }
81
82 fn bool_not(tensor: NdArrayTensor<bool>) -> NdArrayTensor<bool> {
83 let array = tensor.array.mapv(|a| !a).into_shared();
84 NdArrayTensor { array }
85 }
86
87 fn bool_into_float(tensor: NdArrayTensor<bool>) -> FloatTensor<Self> {
88 new_tensor_float!(NdArrayTensor {
89 array: tensor.array.mapv(|a| (a as i32).elem()).into_shared(),
90 })
91 }
92
93 fn bool_swap_dims(
94 tensor: NdArrayTensor<bool>,
95 dim1: usize,
96 dim2: usize,
97 ) -> NdArrayTensor<bool> {
98 NdArrayOps::swap_dims(tensor, dim1, dim2)
99 }
100
101 fn bool_permute(tensor: NdArrayTensor<bool>, axes: &[usize]) -> NdArrayTensor<bool> {
102 let array = tensor.array.permuted_axes(axes.into_dimension());
103 NdArrayTensor { array }
104 }
105
106 fn bool_expand(tensor: NdArrayTensor<bool>, shape: Shape) -> NdArrayTensor<bool> {
107 NdArrayOps::expand(tensor, shape)
108 }
109
110 fn bool_flip(tensor: NdArrayTensor<bool>, axes: &[usize]) -> NdArrayTensor<bool> {
111 NdArrayOps::flip(tensor, axes)
112 }
113}