burn_tensor/tensor/api/
bool.rs

1use crate::{Bool, Int, Shape, Tensor, TensorData, TensorPrimitive, backend::Backend};
2use alloc::vec::Vec;
3
4use crate::try_read_sync;
5
/// Selects which region of a matrix a triangular mask is built around.
///
/// The enum is a plain, field-less flag, so it derives the cheap standard traits to allow
/// copying, comparing, and debug-printing it.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum TriPart {
    /// The upper triangular part of the matrix.
    Upper,

    /// The lower triangular part of the matrix.
    Lower,

    /// The diagonal of the matrix.
    Diagonal,
}
17
18impl<B, const D: usize> Tensor<B, D, Bool>
19where
20    B: Backend,
21{
22    /// Create a boolean tensor from data on the given device.
23    pub fn from_bool(data: TensorData, device: &B::Device) -> Self {
24        Self::new(B::bool_from_data(data.convert::<B::BoolElem>(), device))
25    }
26
27    /// Convert the bool tensor into an int tensor.
28    pub fn int(self) -> Tensor<B, D, Int> {
29        Tensor::new(B::bool_into_int(self.primitive))
30    }
31
32    /// Convert the bool tensor into an float tensor.
33    pub fn float(self) -> Tensor<B, D> {
34        Tensor::new(TensorPrimitive::Float(B::bool_into_float(self.primitive)))
35    }
36
37    /// Inverses boolean values.
38    pub fn bool_not(self) -> Self {
39        Tensor::new(B::bool_not(self.primitive))
40    }
41
42    /// Performs logical and (`&&`) on two boolean tensors
43    pub fn bool_and(self, rhs: Tensor<B, D, Bool>) -> Tensor<B, D, Bool> {
44        Tensor::new(B::bool_and(self.primitive, rhs.primitive))
45    }
46
47    /// Performs logical or (`||`) on two boolean tensors
48    pub fn bool_or(self, rhs: Tensor<B, D, Bool>) -> Tensor<B, D, Bool> {
49        Tensor::new(B::bool_or(self.primitive, rhs.primitive))
50    }
51
52    /// Compute the indices of the elements that are non-zero.
53    ///
54    /// # Returns
55    ///
56    /// A vector of tensors, one for each dimension of the given tensor, containing the indices of
57    /// the non-zero elements in that dimension.
58    pub fn nonzero(self) -> Vec<Tensor<B, 1, Int>> {
59        try_read_sync(self.nonzero_async())
60            .expect("Failed to read tensor data synchronously. Try using nonzero_async instead.")
61    }
62
63    /// Compute the indices of the elements that are non-zero.
64    ///
65    /// # Returns
66    ///
67    /// A vector of tensors, one for each dimension of the given tensor, containing the indices of
68    /// the non-zero elements in that dimension.
69    pub async fn nonzero_async(self) -> Vec<Tensor<B, 1, Int>> {
70        B::bool_nonzero(self.primitive)
71            .await
72            .into_iter()
73            .map(Tensor::new)
74            .collect()
75    }
76
77    /// Compute the indices of the elements that are true, grouped by element.
78    ///
79    /// # Returns
80    ///
81    /// A tensor containing the indices of all non-zero elements of the given tensor. Each row in the
82    /// result contains the indices of a non-zero element.
83    pub fn argwhere(self) -> Tensor<B, 2, Int> {
84        try_read_sync(self.argwhere_async())
85            .expect("Failed to read tensor data synchronously. Try using argwhere_async instead.")
86    }
87
88    /// Compute the indices of the elements that are true, grouped by element.
89    ///
90    /// # Returns
91    ///
92    /// A tensor containing the indices of all non-zero elements of the given tensor. Each row in the
93    /// result contains the indices of a non-zero element.
94    pub async fn argwhere_async(self) -> Tensor<B, 2, Int> {
95        Tensor::new(B::bool_argwhere(self.primitive).await)
96    }
97
98    /// Creates a mask for the upper, lower triangle, or diagonal of a matrix, which can be used to
99    /// fill the specified area with a value.
100    fn tri_mask<S: Into<Shape>>(
101        shape: S,
102        tri_part: TriPart,
103        offset: i64,
104        device: &B::Device,
105    ) -> Self {
106        let shape: Shape = shape.into();
107        let height = shape.dims[D - 2];
108        let width = shape.dims[D - 1];
109
110        // Generate row and column index tensors.
111        let row_indices: Tensor<B, 1, Int> = Tensor::arange(0..height as i64, device);
112        let col_indices: Tensor<B, 1, Int> = Tensor::arange(0..width as i64, device);
113
114        // Prepare shapes for broadcasting.
115        let mut row_shape = [1; D];
116        row_shape[D - 2] = height;
117        let mut col_shape = [1; D];
118        col_shape[D - 1] = width;
119
120        // Reshape for broadcasting.
121        let row_broadcast: Tensor<B, D, Int> = row_indices.reshape(Shape::new(row_shape));
122        let col_broadcast = col_indices.reshape(Shape::new(col_shape));
123
124        // Broadcasting trick to create a matrix that facilitates comparison for mask generation.
125        let matrix = row_broadcast.clone() - (col_broadcast.clone() - offset);
126
127        // Select the appropriate comparison function based on `tri_part`.
128        let compare = match tri_part {
129            TriPart::Upper => Tensor::greater_elem,
130            TriPart::Lower => Tensor::lower_elem,
131            TriPart::Diagonal => Tensor::not_equal_elem,
132        };
133
134        // Generate and return the mask by applying the comparison to the matrix.
135        compare(matrix, 0).unsqueeze()
136    }
137
138    /// Creates a mask for the upper triangle of a matrix, which can be used to fill the specified
139    /// area with a value.
140    ///
141    /// This function generates a boolean tensor representing the mask of the upper triangle of a matrix.
142    ///
143    /// # Arguments
144    ///
145    /// * `shape`: The shape of the matrix.
146    /// * `offset`: The offset from the diagonal, where 0 means the diagonal, and positive values shift
147    ///   towards the upper triangle.
148    /// * `device`: The device on which the tensor will be allocated.
149    ///
150    /// # Returns
151    ///
152    /// Returns a boolean tensor where `false` indicates the elements of the matrix that are part of the
153    /// upper triangle taking into account the specified `offset`. All other elements are `true`.
154    ///
155    /// # Example
156    /// ```rust
157    /// use burn_tensor::backend::Backend;
158    /// use burn_tensor::{Tensor, Bool};
159    ///
160    /// fn example<B: Backend>() {
161    ///   let mask = Tensor::<B, 2, Bool>::triu_mask([3, 3], 0, &Default::default());
162    ///   println!("{mask}");
163    ///   // [[false, false, false],
164    ///   //  [true, false, false],
165    ///   //  [true, true, false]]
166    /// }
167    /// ```
168    pub fn triu_mask<S: Into<Shape>>(shape: S, offset: i64, device: &B::Device) -> Self {
169        Self::tri_mask(shape, TriPart::Upper, offset, device)
170    }
171
172    /// Creates a mask for the lower triangle of a matrix, which can be used to fill the specified
173    /// area with a value.
174    ///
175    /// This function generates a boolean tensor representing the mask of the lower triangle of a matrix.
176    ///
177    /// # Arguments
178    ///
179    /// * `shape`: The shape of the matrix.
180    /// * `offset`: The offset from the diagonal, where 0 means the diagonal, and negative values shift
181    ///   towards the lower triangle.
182    /// * `device`: The device on which the tensor will be allocated.
183    ///
184    /// # Returns
185    ///
186    /// Returns a boolean tensor where `false` indicates the elements of the matrix that are part of the
187    /// lower triangle taking into account the specified `offset`. All other elements are `true`.
188    ///
189    /// # Example
190    /// ```rust
191    /// use burn_tensor::backend::Backend;
192    /// use burn_tensor::{Tensor, Bool};
193    ///
194    /// fn example<B: Backend>() {
195    ///   let mask = Tensor::<B, 2, Bool>::tril_mask([3, 3], 0, &Default::default());
196    ///   println!("{mask}");
197    ///   // [[false, true, true],
198    ///   //  [false, false, true],
199    ///   //  [false, false, false]]
200    /// }
201    /// ```
202    pub fn tril_mask<S: Into<Shape>>(shape: S, offset: i64, device: &B::Device) -> Self {
203        Self::tri_mask(shape, TriPart::Lower, offset, device)
204    }
205
206    /// Creates a mask for the diagonal of a matrix, which can be used to fill the specified
207    /// area with a value.
208    ///
209    /// This function generates a boolean tensor representing the mask of the diagonal of a matrix.
210    ///
211    /// # Arguments
212    ///
213    /// * `shape`: The shape of the matrix.
214    /// * `offset`: The offset from the diagonal, where 0 means the diagonal, and positive values shift
215    ///   towards the upper triangle.
216    /// * `device`: The device on which the tensor will be allocated.
217    ///
218    /// # Returns
219    ///
220    /// Returns a boolean tensor where `false` indicates the elements of the matrix that are part of the
221    /// diagonal. All other elements are `true`.
222    ///
223    /// # Example
224    /// ```rust
225    /// use burn_tensor::backend::Backend;
226    /// use burn_tensor::{Tensor, Bool};
227    ///
228    /// fn example<B: Backend>() {
229    ///   let mask = Tensor::<B, 2, Bool>::diag_mask([3, 3], 0, &Default::default());
230    ///   println!("{mask}");
231    ///   // [[false, true, true],
232    ///   //  [true, false, true],
233    ///   //  [true, true, false]]
234    /// }
235    /// ```
236    pub fn diag_mask<S: Into<Shape>>(shape: S, offset: i64, device: &B::Device) -> Self {
237        Self::tri_mask(shape, TriPart::Diagonal, offset, device)
238    }
239}