1use alloc::vec;
4use alloc::vec::Vec;
5use burn_backend::{
6 DType, ExecutionError, TensorData,
7 ops::{BoolTensorOps, IntTensorOps},
8 tensor::{BoolTensor, Device, FloatTensor, IntTensor},
9};
10use burn_std::{Bytes, FloatDType, IntDType, Shape, Slice, bf16, f16};
11
12use crate::{Flex, FlexTensor, Layout};
13
14impl BoolTensorOps<Flex> for Flex {
15 fn bool_from_data(data: TensorData, _device: &Device<Flex>) -> BoolTensor<Flex> {
16 FlexTensor::from_data(data)
17 }
18
19 async fn bool_into_data(tensor: BoolTensor<Flex>) -> Result<TensorData, ExecutionError> {
20 Ok(tensor.into_data())
21 }
22
23 fn bool_device(_tensor: &BoolTensor<Flex>) -> Device<Flex> {
24 Default::default()
25 }
26
27 fn bool_to_device(tensor: BoolTensor<Flex>, _device: &Device<Flex>) -> BoolTensor<Flex> {
28 tensor
29 }
30
31 fn bool_cat(tensors: Vec<BoolTensor<Flex>>, dim: usize) -> BoolTensor<Flex> {
32 crate::ops::cat::cat(tensors, dim)
33 }
34
35 fn bool_reshape(tensor: BoolTensor<Flex>, shape: Shape) -> BoolTensor<Flex> {
36 tensor.reshape(shape)
37 }
38
39 fn bool_slice(tensor: BoolTensor<Flex>, slices: &[Slice]) -> BoolTensor<Flex> {
40 crate::ops::slice::slice(tensor, slices)
41 }
42
43 fn bool_empty(
44 shape: Shape,
45 _device: &Device<Flex>,
46 dtype: burn_std::BoolDType,
47 ) -> BoolTensor<Flex> {
48 FlexTensor::empty(shape, DType::from(dtype))
49 }
50
51 fn bool_slice_assign(
52 tensor: BoolTensor<Flex>,
53 slices: &[Slice],
54 value: BoolTensor<Flex>,
55 ) -> BoolTensor<Flex> {
56 crate::ops::slice::slice_assign(tensor, slices, value)
57 }
58
59 fn bool_into_int(tensor: BoolTensor<Flex>, out_dtype: burn_std::IntDType) -> IntTensor<Flex> {
60 let tensor = tensor.to_contiguous();
61 let shape = tensor.layout().shape().clone();
62 let out_dt = DType::from(out_dtype);
63 let bools = tensor.bytes();
64
65 macro_rules! convert {
66 ($int_ty:ty) => {{
67 let data: Vec<$int_ty> =
68 bools.iter().map(|&x| if x != 0 { 1 } else { 0 }).collect();
69 FlexTensor::new(Bytes::from_elems(data), Layout::contiguous(shape), out_dt)
70 }};
71 }
72
73 match out_dtype {
74 IntDType::I64 => convert!(i64),
75 IntDType::I32 => convert!(i32),
76 IntDType::I16 => convert!(i16),
77 IntDType::I8 => convert!(i8),
78 IntDType::U64 => convert!(u64),
79 IntDType::U32 => convert!(u32),
80 IntDType::U16 => convert!(u16),
81 IntDType::U8 => convert!(u8),
82 }
83 }
84
85 fn bool_into_float(
86 tensor: BoolTensor<Flex>,
87 out_dtype: burn_std::FloatDType,
88 ) -> FloatTensor<Flex> {
89 let tensor = tensor.to_contiguous();
90 let shape = tensor.layout().shape().clone();
91 let out_dt = DType::from(out_dtype);
92 let bools = tensor.bytes();
93
94 match out_dtype {
95 FloatDType::F64 => {
96 let data: Vec<f64> = bools
97 .iter()
98 .map(|&x| if x != 0 { 1.0 } else { 0.0 })
99 .collect();
100 FlexTensor::new(Bytes::from_elems(data), Layout::contiguous(shape), out_dt)
101 }
102 FloatDType::F32 | FloatDType::Flex32 => {
103 let data: Vec<f32> = bools
104 .iter()
105 .map(|&x| if x != 0 { 1.0 } else { 0.0 })
106 .collect();
107 FlexTensor::new(Bytes::from_elems(data), Layout::contiguous(shape), out_dt)
108 }
109 FloatDType::F16 => {
110 let one = f16::from_f32(1.0);
111 let zero = f16::from_f32(0.0);
112 let data: Vec<f16> = bools
113 .iter()
114 .map(|&x| if x != 0 { one } else { zero })
115 .collect();
116 FlexTensor::new(Bytes::from_elems(data), Layout::contiguous(shape), out_dt)
117 }
118 FloatDType::BF16 => {
119 let one = bf16::from_f32(1.0);
120 let zero = bf16::from_f32(0.0);
121 let data: Vec<bf16> = bools
122 .iter()
123 .map(|&x| if x != 0 { one } else { zero })
124 .collect();
125 FlexTensor::new(Bytes::from_elems(data), Layout::contiguous(shape), out_dt)
126 }
127 }
128 }
129
130 fn bool_swap_dims(tensor: BoolTensor<Flex>, dim1: usize, dim2: usize) -> BoolTensor<Flex> {
131 tensor.transpose(dim1, dim2)
132 }
133
134 fn bool_permute(tensor: BoolTensor<Flex>, axes: &[usize]) -> BoolTensor<Flex> {
135 tensor.permute(axes)
136 }
137
138 fn bool_flip(tensor: BoolTensor<Flex>, axes: &[usize]) -> BoolTensor<Flex> {
139 crate::ops::flip::flip(tensor, axes)
140 }
141
142 fn bool_equal(lhs: BoolTensor<Flex>, rhs: BoolTensor<Flex>) -> BoolTensor<Flex> {
143 use crate::strided_index::StridedIter;
144
145 let (lhs, rhs) = crate::ops::expand::broadcast_binary(lhs, rhs);
151
152 let out_dtype = burn_std::BoolDType::from(lhs.dtype());
153 let shape = lhs.layout().shape().clone();
154 let lhs_storage: &[u8] = lhs.bytes();
155 let rhs_storage: &[u8] = rhs.bytes();
156
157 let result: Vec<u8> = match (
158 lhs.layout().contiguous_offsets(),
159 rhs.layout().contiguous_offsets(),
160 ) {
161 (Some((l_start, l_end)), Some((r_start, r_end))) => {
162 let l_slice = &lhs_storage[l_start..l_end];
163 let r_slice = &rhs_storage[r_start..r_end];
164 l_slice
165 .iter()
166 .zip(r_slice)
167 .map(|(&a, &b)| (a == b) as u8)
168 .collect()
169 }
170 _ => {
171 let lhs_iter = StridedIter::new(lhs.layout());
172 let rhs_iter = StridedIter::new(rhs.layout());
173 lhs_iter
174 .zip(rhs_iter)
175 .map(|(li, ri)| (lhs_storage[li] == rhs_storage[ri]) as u8)
176 .collect()
177 }
178 };
179
180 crate::ops::comparison::make_bool_tensor(result, shape, out_dtype)
181 }
182
183 fn bool_not(mut tensor: BoolTensor<Flex>) -> BoolTensor<Flex> {
184 use crate::strided_index::StridedIter;
185
186 debug_assert!(
187 matches!(
188 tensor.dtype(),
189 DType::Bool(burn_std::BoolStore::Native | burn_std::BoolStore::U8)
190 ),
191 "bool_not: only Bool(Native) and Bool(U8) are supported, got {:?}",
192 tensor.dtype()
193 );
194
195 if tensor.is_unique()
199 && tensor.layout().is_contiguous()
200 && tensor.layout().start_offset() == 0
201 {
202 let storage = tensor.storage_mut::<u8>();
203 crate::simd::bool_not_inplace_u8(storage);
204 return tensor;
205 }
206
207 let out_dtype = burn_std::BoolDType::from(tensor.dtype());
210 let shape = tensor.layout().shape().clone();
211 let storage: &[u8] = tensor.bytes();
212
213 let result: Vec<u8> = match tensor.layout().contiguous_offsets() {
214 Some((start, end)) => {
215 let slice = &storage[start..end];
216 let mut out = vec![0u8; slice.len()];
217 crate::simd::bool_not_u8(slice, &mut out);
218 out
219 }
220 None => StridedIter::new(tensor.layout())
221 .map(|idx| (storage[idx] == 0) as u8)
222 .collect(),
223 };
224
225 crate::ops::comparison::make_bool_tensor(result, shape, out_dtype)
226 }
227
228 fn bool_and(lhs: BoolTensor<Flex>, rhs: BoolTensor<Flex>) -> BoolTensor<Flex> {
229 bool_binary_op_simd(lhs, rhs, BoolBinaryOp::And)
230 }
231
232 fn bool_or(lhs: BoolTensor<Flex>, rhs: BoolTensor<Flex>) -> BoolTensor<Flex> {
233 bool_binary_op_simd(lhs, rhs, BoolBinaryOp::Or)
234 }
235
236 fn bool_xor(lhs: BoolTensor<Flex>, rhs: BoolTensor<Flex>) -> BoolTensor<Flex> {
237 bool_binary_op_simd(lhs, rhs, BoolBinaryOp::Xor)
238 }
239
240 fn bool_expand(tensor: BoolTensor<Flex>, shape: Shape) -> BoolTensor<Flex> {
241 crate::ops::expand::expand(tensor, shape)
242 }
243
244 fn bool_zeros(
246 shape: Shape,
247 device: &Device<Flex>,
248 dtype: burn_std::BoolDType,
249 ) -> BoolTensor<Flex> {
250 Self::bool_empty(shape, device, dtype)
251 }
252
253 fn bool_ones(
254 shape: Shape,
255 _device: &Device<Flex>,
256 dtype: burn_std::BoolDType,
257 ) -> BoolTensor<Flex> {
258 let num_elements = shape.num_elements();
259 let data = vec![1u8; num_elements];
260 crate::ops::comparison::make_bool_tensor(data, shape, dtype)
261 }
262
263 fn bool_mask_where(
264 tensor: BoolTensor<Flex>,
265 mask: BoolTensor<Flex>,
266 value: BoolTensor<Flex>,
267 ) -> BoolTensor<Flex> {
268 crate::ops::mask::mask_where_bool(tensor, mask, value)
269 }
270
271 fn bool_mask_fill(
272 tensor: BoolTensor<Flex>,
273 mask: BoolTensor<Flex>,
274 value: burn_backend::Scalar,
275 ) -> BoolTensor<Flex> {
276 let value: bool = value.elem();
277 crate::ops::mask::mask_fill_bool(tensor, mask, value)
278 }
279
280 fn bool_gather(
281 dim: usize,
282 tensor: BoolTensor<Flex>,
283 indices: IntTensor<Flex>,
284 ) -> BoolTensor<Flex> {
285 crate::ops::gather_scatter::gather_bool(tensor, dim, indices)
286 }
287
288 fn bool_scatter_or(
289 dim: usize,
290 tensor: BoolTensor<Flex>,
291 indices: IntTensor<Flex>,
292 value: BoolTensor<Flex>,
293 ) -> BoolTensor<Flex> {
294 crate::ops::gather_scatter::scatter_or(tensor, dim, indices, value)
295 }
296
297 fn bool_equal_elem(lhs: BoolTensor<Flex>, rhs: burn_backend::Scalar) -> BoolTensor<Flex> {
298 use crate::strided_index::StridedIter;
299
300 let out_dtype = burn_std::BoolDType::from(lhs.dtype());
301 let shape = lhs.layout().shape().clone();
302 let storage: &[u8] = lhs.bytes();
303 let rhs_bool: bool = rhs.elem();
304 let rhs_val = rhs_bool as u8;
305
306 let result: Vec<u8> = match lhs.layout().contiguous_offsets() {
307 Some((start, end)) => storage[start..end]
308 .iter()
309 .map(|&v| (v == rhs_val) as u8)
310 .collect(),
311 None => StridedIter::new(lhs.layout())
312 .map(|idx| (storage[idx] == rhs_val) as u8)
313 .collect(),
314 };
315
316 crate::ops::comparison::make_bool_tensor(result, shape, out_dtype)
317 }
318
319 fn bool_unfold(
320 tensor: BoolTensor<Flex>,
321 dim: usize,
322 size: usize,
323 step: usize,
324 ) -> BoolTensor<Flex> {
325 crate::ops::unfold::unfold_bool(tensor, dim, size, step)
326 }
327
328 fn bool_not_equal(lhs: BoolTensor<Flex>, rhs: BoolTensor<Flex>) -> BoolTensor<Flex> {
329 let out_dtype = burn_std::BoolDType::from(lhs.dtype());
330 crate::ops::comparison::bool_not_equal(lhs, rhs, out_dtype)
331 }
332
333 fn bool_not_equal_elem(lhs: BoolTensor<Flex>, rhs: burn_backend::Scalar) -> BoolTensor<Flex> {
334 let out_dtype = burn_std::BoolDType::from(lhs.dtype());
335 let rhs: bool = rhs.elem();
336 crate::ops::comparison::bool_not_equal_elem(lhs, rhs, out_dtype)
337 }
338
339 fn bool_any(tensor: BoolTensor<Flex>) -> BoolTensor<Flex> {
340 let out_dtype = burn_std::BoolDType::from(tensor.dtype());
341 crate::ops::comparison::any_bool(tensor, out_dtype)
342 }
343
344 fn bool_any_dim(tensor: BoolTensor<Flex>, dim: usize) -> BoolTensor<Flex> {
345 let out_dtype = burn_std::BoolDType::from(tensor.dtype());
346 crate::ops::comparison::any_bool_dim(tensor, dim, out_dtype)
347 }
348
349 fn bool_all(tensor: BoolTensor<Flex>) -> BoolTensor<Flex> {
350 let out_dtype = burn_std::BoolDType::from(tensor.dtype());
351 crate::ops::comparison::all_bool(tensor, out_dtype)
352 }
353
354 fn bool_all_dim(tensor: BoolTensor<Flex>, dim: usize) -> BoolTensor<Flex> {
355 let out_dtype = burn_std::BoolDType::from(tensor.dtype());
356 crate::ops::comparison::all_bool_dim(tensor, dim, out_dtype)
357 }
358
359 fn bool_select(
360 tensor: BoolTensor<Flex>,
361 dim: usize,
362 indices: IntTensor<Flex>,
363 ) -> BoolTensor<Flex> {
364 crate::ops::gather_scatter::select::<u8>(tensor, dim, indices)
365 }
366
367 fn bool_select_or(
368 tensor: BoolTensor<Flex>,
369 dim: usize,
370 indices: IntTensor<Flex>,
371 value: BoolTensor<Flex>,
372 ) -> BoolTensor<Flex> {
373 let mut result = crate::ops::gather_scatter::select_add::<u8>(tensor, dim, indices, value);
374 let storage: &mut [u8] = result.storage_mut();
376 for v in storage.iter_mut() {
377 if *v > 1 {
378 *v = 1;
379 }
380 }
381 result
382 }
383
384 fn bool_transpose(tensor: BoolTensor<Flex>) -> BoolTensor<Flex> {
385 let ndims = tensor.layout().num_dims();
386 if ndims < 2 {
387 return tensor;
388 }
389 tensor.transpose(ndims - 2, ndims - 1)
390 }
391
392 fn bool_repeat_dim(tensor: BoolTensor<Flex>, dim: usize, times: usize) -> BoolTensor<Flex> {
393 crate::ops::repeat_dim::repeat_dim(tensor, dim, times)
394 }
395
396 async fn bool_argwhere(tensor: BoolTensor<Flex>, out_dtype: IntDType) -> IntTensor<Flex> {
397 let tensor = tensor.to_contiguous();
398 let shape = tensor.layout().shape().clone();
399 let ndims = shape.num_dims();
400 let data: &[u8] = tensor.storage();
401 let n = shape.num_elements();
402
403 let count = data[..n].iter().filter(|&&v| v != 0).count();
404 let mut coords: Vec<isize> = Vec::with_capacity(count * ndims);
405 let strides = crate::layout::contiguous_strides_usize(&shape);
406
407 for (flat_idx, &val) in data[..n].iter().enumerate() {
408 if val != 0 {
409 let mut remaining = flat_idx;
410 for &s in &strides {
411 coords.push((remaining / s) as isize);
412 remaining %= s;
413 }
414 }
415 }
416
417 let out_shape = Shape::from(vec![count, ndims]);
418 let result = FlexTensor::new(
419 Bytes::from_elems(coords),
420 Layout::contiguous(out_shape),
421 crate::ops::INDEX_DTYPE,
422 );
423 if result.dtype() != DType::from(out_dtype) {
424 Flex::int_cast(result, out_dtype)
425 } else {
426 result
427 }
428 }
429}
430
/// Which element-wise boolean operation `bool_binary_op_simd` should apply.
#[derive(Clone, Copy)]
enum BoolBinaryOp {
    /// Logical AND (bitwise `&` on u8 storage).
    And,
    /// Logical OR (bitwise `|` on u8 storage).
    Or,
    /// Logical XOR (bitwise `^` on u8 storage).
    Xor,
}
438
/// Element-wise boolean binary op with broadcasting.
///
/// Strategy, in order of preference:
/// 1. reuse `lhs` storage in place (unique reference, contiguous from offset 0),
/// 2. reuse `rhs` storage in place (same conditions; valid because And/Or/Xor
///    are all commutative),
/// 3. SIMD into a fresh buffer when both operands are contiguous,
/// 4. scalar strided fallback otherwise.
fn bool_binary_op_simd(lhs: FlexTensor, rhs: FlexTensor, op: BoolBinaryOp) -> FlexTensor {
    use crate::strided_index::StridedIter;

    debug_assert_eq!(lhs.dtype(), rhs.dtype(), "bool_binary_op: dtype mismatch");

    // Broadcast both operands to a common shape first; offsets below refer to
    // the broadcast layouts.
    let (mut lhs, mut rhs) = crate::ops::expand::broadcast_binary(lhs, rhs);

    let out_dtype = burn_std::BoolDType::from(lhs.dtype());
    let shape = lhs.layout().shape().clone();
    let l_offsets = lhs.layout().contiguous_offsets();
    let r_offsets = rhs.layout().contiguous_offsets();

    // Path 1: mutate lhs in place. Requires sole ownership of its storage and
    // a contiguous layout starting at byte 0 so `storage_mut` covers exactly
    // the element range.
    if lhs.is_unique()
        && let (Some((0, l_end)), Some((r_start, r_end))) = (l_offsets, r_offsets)
    {
        // Borrow rhs (shared) before taking the mutable borrow on lhs.
        let rhs_storage: &[u8] = rhs.bytes();
        let r_slice = &rhs_storage[r_start..r_end];
        let lhs_storage: &mut [u8] = lhs.storage_mut();
        let l_slice = &mut lhs_storage[..l_end];

        match op {
            BoolBinaryOp::And => crate::simd::bool_and_inplace_u8(l_slice, r_slice),
            BoolBinaryOp::Or => crate::simd::bool_or_inplace_u8(l_slice, r_slice),
            BoolBinaryOp::Xor => crate::simd::bool_xor_inplace_u8(l_slice, r_slice),
        }
        return lhs;
    }

    // Path 2: mutate rhs in place instead. Operand order is swapped in the
    // SIMD calls, which is sound because all three ops are commutative.
    if rhs.is_unique()
        && let (Some((l_start, l_end)), Some((0, r_end))) = (l_offsets, r_offsets)
    {
        let lhs_storage: &[u8] = lhs.bytes();
        let l_slice = &lhs_storage[l_start..l_end];
        let rhs_storage: &mut [u8] = rhs.storage_mut();
        let r_slice = &mut rhs_storage[..r_end];

        match op {
            BoolBinaryOp::And => crate::simd::bool_and_inplace_u8(r_slice, l_slice),
            BoolBinaryOp::Or => crate::simd::bool_or_inplace_u8(r_slice, l_slice),
            BoolBinaryOp::Xor => crate::simd::bool_xor_inplace_u8(r_slice, l_slice),
        }
        return rhs;
    }

    let lhs_storage: &[u8] = lhs.bytes();
    let rhs_storage: &[u8] = rhs.bytes();

    let result: Vec<u8> = match (l_offsets, r_offsets) {
        // Path 3: both contiguous (but shared or offset) — SIMD out-of-place.
        (Some((l_start, l_end)), Some((r_start, r_end))) => {
            let l_slice = &lhs_storage[l_start..l_end];
            let r_slice = &rhs_storage[r_start..r_end];
            let mut out = vec![0u8; l_slice.len()];
            match op {
                BoolBinaryOp::And => crate::simd::bool_and_u8(l_slice, r_slice, &mut out),
                BoolBinaryOp::Or => crate::simd::bool_or_u8(l_slice, r_slice, &mut out),
                BoolBinaryOp::Xor => crate::simd::bool_xor_u8(l_slice, r_slice, &mut out),
            }
            out
        }
        // Path 4: at least one operand is strided — walk both layouts and
        // apply the bitwise op per element.
        _ => {
            let lhs_iter = StridedIter::new(lhs.layout());
            let rhs_iter = StridedIter::new(rhs.layout());
            match op {
                BoolBinaryOp::And => lhs_iter
                    .zip(rhs_iter)
                    .map(|(li, ri)| lhs_storage[li] & rhs_storage[ri])
                    .collect(),
                BoolBinaryOp::Or => lhs_iter
                    .zip(rhs_iter)
                    .map(|(li, ri)| lhs_storage[li] | rhs_storage[ri])
                    .collect(),
                BoolBinaryOp::Xor => lhs_iter
                    .zip(rhs_iter)
                    .map(|(li, ri)| lhs_storage[li] ^ rhs_storage[ri])
                    .collect(),
            }
        }
    };

    crate::ops::comparison::make_bool_tensor(result, shape, out_dtype)
}
529
#[cfg(test)]
mod tests {
    use alloc::vec;
    use burn_backend::TensorData;
    use burn_backend::ops::BoolTensorOps;
    use burn_std::{FloatDType, IntDType};

    use crate::{Flex, FlexTensor};

    #[test]
    fn test_bool_into_int_u8() {
        // Casting bools to u8 must map true -> 1 and false -> 0.
        let source = FlexTensor::from_data(TensorData::from([true, false, true]));
        let cast = Flex::bool_into_int(source, IntDType::U8);
        assert_eq!(cast.dtype(), burn_backend::DType::U8);
        let elems: Vec<u8> = cast.into_data().to_vec().unwrap();
        assert_eq!(elems, vec![1u8, 0, 1]);
    }

    #[test]
    fn test_bool_into_float_f64() {
        // Casting bools to f64 must map true -> 1.0 and false -> 0.0.
        let source = FlexTensor::from_data(TensorData::from([true, false, true]));
        let cast = Flex::bool_into_float(source, FloatDType::F64);
        assert_eq!(cast.dtype(), burn_backend::DType::F64);
        let elems: Vec<f64> = cast.into_data().to_vec().unwrap();
        assert_eq!(elems, vec![1.0f64, 0.0, 1.0]);
    }
}