Skip to main content

burn_flex/ops/
bool.rs

1//! Bool tensor operations for the Flex backend.
2
3use alloc::vec;
4use alloc::vec::Vec;
5use burn_backend::{
6    DType, ExecutionError, TensorData,
7    ops::{BoolTensorOps, IntTensorOps},
8    tensor::{BoolTensor, Device, FloatTensor, IntTensor},
9};
10use burn_std::{Bytes, FloatDType, IntDType, Shape, Slice, bf16, f16};
11
12use crate::{Flex, FlexTensor, Layout};
13
14impl BoolTensorOps<Flex> for Flex {
15    fn bool_from_data(data: TensorData, _device: &Device<Flex>) -> BoolTensor<Flex> {
16        FlexTensor::from_data(data)
17    }
18
19    async fn bool_into_data(tensor: BoolTensor<Flex>) -> Result<TensorData, ExecutionError> {
20        Ok(tensor.into_data())
21    }
22
23    fn bool_device(_tensor: &BoolTensor<Flex>) -> Device<Flex> {
24        Default::default()
25    }
26
27    fn bool_to_device(tensor: BoolTensor<Flex>, _device: &Device<Flex>) -> BoolTensor<Flex> {
28        tensor
29    }
30
31    fn bool_cat(tensors: Vec<BoolTensor<Flex>>, dim: usize) -> BoolTensor<Flex> {
32        crate::ops::cat::cat(tensors, dim)
33    }
34
35    fn bool_reshape(tensor: BoolTensor<Flex>, shape: Shape) -> BoolTensor<Flex> {
36        tensor.reshape(shape)
37    }
38
39    fn bool_slice(tensor: BoolTensor<Flex>, slices: &[Slice]) -> BoolTensor<Flex> {
40        crate::ops::slice::slice(tensor, slices)
41    }
42
43    fn bool_empty(
44        shape: Shape,
45        _device: &Device<Flex>,
46        dtype: burn_std::BoolDType,
47    ) -> BoolTensor<Flex> {
48        FlexTensor::empty(shape, DType::from(dtype))
49    }
50
51    fn bool_slice_assign(
52        tensor: BoolTensor<Flex>,
53        slices: &[Slice],
54        value: BoolTensor<Flex>,
55    ) -> BoolTensor<Flex> {
56        crate::ops::slice::slice_assign(tensor, slices, value)
57    }
58
59    fn bool_into_int(tensor: BoolTensor<Flex>, out_dtype: burn_std::IntDType) -> IntTensor<Flex> {
60        let tensor = tensor.to_contiguous();
61        let shape = tensor.layout().shape().clone();
62        let out_dt = DType::from(out_dtype);
63        let bools = tensor.bytes();
64
65        macro_rules! convert {
66            ($int_ty:ty) => {{
67                let data: Vec<$int_ty> =
68                    bools.iter().map(|&x| if x != 0 { 1 } else { 0 }).collect();
69                FlexTensor::new(Bytes::from_elems(data), Layout::contiguous(shape), out_dt)
70            }};
71        }
72
73        match out_dtype {
74            IntDType::I64 => convert!(i64),
75            IntDType::I32 => convert!(i32),
76            IntDType::I16 => convert!(i16),
77            IntDType::I8 => convert!(i8),
78            IntDType::U64 => convert!(u64),
79            IntDType::U32 => convert!(u32),
80            IntDType::U16 => convert!(u16),
81            IntDType::U8 => convert!(u8),
82        }
83    }
84
85    fn bool_into_float(
86        tensor: BoolTensor<Flex>,
87        out_dtype: burn_std::FloatDType,
88    ) -> FloatTensor<Flex> {
89        let tensor = tensor.to_contiguous();
90        let shape = tensor.layout().shape().clone();
91        let out_dt = DType::from(out_dtype);
92        let bools = tensor.bytes();
93
94        match out_dtype {
95            FloatDType::F64 => {
96                let data: Vec<f64> = bools
97                    .iter()
98                    .map(|&x| if x != 0 { 1.0 } else { 0.0 })
99                    .collect();
100                FlexTensor::new(Bytes::from_elems(data), Layout::contiguous(shape), out_dt)
101            }
102            FloatDType::F32 | FloatDType::Flex32 => {
103                let data: Vec<f32> = bools
104                    .iter()
105                    .map(|&x| if x != 0 { 1.0 } else { 0.0 })
106                    .collect();
107                FlexTensor::new(Bytes::from_elems(data), Layout::contiguous(shape), out_dt)
108            }
109            FloatDType::F16 => {
110                let one = f16::from_f32(1.0);
111                let zero = f16::from_f32(0.0);
112                let data: Vec<f16> = bools
113                    .iter()
114                    .map(|&x| if x != 0 { one } else { zero })
115                    .collect();
116                FlexTensor::new(Bytes::from_elems(data), Layout::contiguous(shape), out_dt)
117            }
118            FloatDType::BF16 => {
119                let one = bf16::from_f32(1.0);
120                let zero = bf16::from_f32(0.0);
121                let data: Vec<bf16> = bools
122                    .iter()
123                    .map(|&x| if x != 0 { one } else { zero })
124                    .collect();
125                FlexTensor::new(Bytes::from_elems(data), Layout::contiguous(shape), out_dt)
126            }
127        }
128    }
129
130    fn bool_swap_dims(tensor: BoolTensor<Flex>, dim1: usize, dim2: usize) -> BoolTensor<Flex> {
131        tensor.transpose(dim1, dim2)
132    }
133
134    fn bool_permute(tensor: BoolTensor<Flex>, axes: &[usize]) -> BoolTensor<Flex> {
135        tensor.permute(axes)
136    }
137
138    fn bool_flip(tensor: BoolTensor<Flex>, axes: &[usize]) -> BoolTensor<Flex> {
139        crate::ops::flip::flip(tensor, axes)
140    }
141
142    fn bool_equal(lhs: BoolTensor<Flex>, rhs: BoolTensor<Flex>) -> BoolTensor<Flex> {
143        use crate::strided_index::StridedIter;
144
145        // Broadcast to a common shape before comparing. The contiguous fast
146        // path below uses `zip`, which silently truncates to the shorter
147        // operand; and the output shape is taken from lhs, so mismatched
148        // operands would otherwise produce a result vec shorter than the
149        // output layout claims.
150        let (lhs, rhs) = crate::ops::expand::broadcast_binary(lhs, rhs);
151
152        let out_dtype = burn_std::BoolDType::from(lhs.dtype());
153        let shape = lhs.layout().shape().clone();
154        let lhs_storage: &[u8] = lhs.bytes();
155        let rhs_storage: &[u8] = rhs.bytes();
156
157        let result: Vec<u8> = match (
158            lhs.layout().contiguous_offsets(),
159            rhs.layout().contiguous_offsets(),
160        ) {
161            (Some((l_start, l_end)), Some((r_start, r_end))) => {
162                let l_slice = &lhs_storage[l_start..l_end];
163                let r_slice = &rhs_storage[r_start..r_end];
164                l_slice
165                    .iter()
166                    .zip(r_slice)
167                    .map(|(&a, &b)| (a == b) as u8)
168                    .collect()
169            }
170            _ => {
171                let lhs_iter = StridedIter::new(lhs.layout());
172                let rhs_iter = StridedIter::new(rhs.layout());
173                lhs_iter
174                    .zip(rhs_iter)
175                    .map(|(li, ri)| (lhs_storage[li] == rhs_storage[ri]) as u8)
176                    .collect()
177            }
178        };
179
180        crate::ops::comparison::make_bool_tensor(result, shape, out_dtype)
181    }
182
183    fn bool_not(mut tensor: BoolTensor<Flex>) -> BoolTensor<Flex> {
184        use crate::strided_index::StridedIter;
185
186        debug_assert!(
187            matches!(
188                tensor.dtype(),
189                DType::Bool(burn_std::BoolStore::Native | burn_std::BoolStore::U8)
190            ),
191            "bool_not: only Bool(Native) and Bool(U8) are supported, got {:?}",
192            tensor.dtype()
193        );
194
195        // Fast path: in-place for unique, contiguous tensors at offset 0. This
196        // preserves the input tensor's dtype tag implicitly (the in-place SIMD
197        // ops flip bytes without touching the dtype tag).
198        if tensor.is_unique()
199            && tensor.layout().is_contiguous()
200            && tensor.layout().start_offset() == 0
201        {
202            let storage = tensor.storage_mut::<u8>();
203            crate::simd::bool_not_inplace_u8(storage);
204            return tensor;
205        }
206
207        // Allocating path for shared, non-contiguous, or offset tensors:
208        // preserve the input's bool dtype for the new tensor.
209        let out_dtype = burn_std::BoolDType::from(tensor.dtype());
210        let shape = tensor.layout().shape().clone();
211        let storage: &[u8] = tensor.bytes();
212
213        let result: Vec<u8> = match tensor.layout().contiguous_offsets() {
214            Some((start, end)) => {
215                let slice = &storage[start..end];
216                let mut out = vec![0u8; slice.len()];
217                crate::simd::bool_not_u8(slice, &mut out);
218                out
219            }
220            None => StridedIter::new(tensor.layout())
221                .map(|idx| (storage[idx] == 0) as u8)
222                .collect(),
223        };
224
225        crate::ops::comparison::make_bool_tensor(result, shape, out_dtype)
226    }
227
228    fn bool_and(lhs: BoolTensor<Flex>, rhs: BoolTensor<Flex>) -> BoolTensor<Flex> {
229        bool_binary_op_simd(lhs, rhs, BoolBinaryOp::And)
230    }
231
232    fn bool_or(lhs: BoolTensor<Flex>, rhs: BoolTensor<Flex>) -> BoolTensor<Flex> {
233        bool_binary_op_simd(lhs, rhs, BoolBinaryOp::Or)
234    }
235
236    fn bool_xor(lhs: BoolTensor<Flex>, rhs: BoolTensor<Flex>) -> BoolTensor<Flex> {
237        bool_binary_op_simd(lhs, rhs, BoolBinaryOp::Xor)
238    }
239
240    fn bool_expand(tensor: BoolTensor<Flex>, shape: Shape) -> BoolTensor<Flex> {
241        crate::ops::expand::expand(tensor, shape)
242    }
243
244    // Missing methods
245    fn bool_zeros(
246        shape: Shape,
247        device: &Device<Flex>,
248        dtype: burn_std::BoolDType,
249    ) -> BoolTensor<Flex> {
250        Self::bool_empty(shape, device, dtype)
251    }
252
253    fn bool_ones(
254        shape: Shape,
255        _device: &Device<Flex>,
256        dtype: burn_std::BoolDType,
257    ) -> BoolTensor<Flex> {
258        let num_elements = shape.num_elements();
259        let data = vec![1u8; num_elements];
260        crate::ops::comparison::make_bool_tensor(data, shape, dtype)
261    }
262
263    fn bool_mask_where(
264        tensor: BoolTensor<Flex>,
265        mask: BoolTensor<Flex>,
266        value: BoolTensor<Flex>,
267    ) -> BoolTensor<Flex> {
268        crate::ops::mask::mask_where_bool(tensor, mask, value)
269    }
270
271    fn bool_mask_fill(
272        tensor: BoolTensor<Flex>,
273        mask: BoolTensor<Flex>,
274        value: burn_backend::Scalar,
275    ) -> BoolTensor<Flex> {
276        let value: bool = value.elem();
277        crate::ops::mask::mask_fill_bool(tensor, mask, value)
278    }
279
280    fn bool_gather(
281        dim: usize,
282        tensor: BoolTensor<Flex>,
283        indices: IntTensor<Flex>,
284    ) -> BoolTensor<Flex> {
285        crate::ops::gather_scatter::gather_bool(tensor, dim, indices)
286    }
287
288    fn bool_scatter_or(
289        dim: usize,
290        tensor: BoolTensor<Flex>,
291        indices: IntTensor<Flex>,
292        value: BoolTensor<Flex>,
293    ) -> BoolTensor<Flex> {
294        crate::ops::gather_scatter::scatter_or(tensor, dim, indices, value)
295    }
296
297    fn bool_equal_elem(lhs: BoolTensor<Flex>, rhs: burn_backend::Scalar) -> BoolTensor<Flex> {
298        use crate::strided_index::StridedIter;
299
300        let out_dtype = burn_std::BoolDType::from(lhs.dtype());
301        let shape = lhs.layout().shape().clone();
302        let storage: &[u8] = lhs.bytes();
303        let rhs_bool: bool = rhs.elem();
304        let rhs_val = rhs_bool as u8;
305
306        let result: Vec<u8> = match lhs.layout().contiguous_offsets() {
307            Some((start, end)) => storage[start..end]
308                .iter()
309                .map(|&v| (v == rhs_val) as u8)
310                .collect(),
311            None => StridedIter::new(lhs.layout())
312                .map(|idx| (storage[idx] == rhs_val) as u8)
313                .collect(),
314        };
315
316        crate::ops::comparison::make_bool_tensor(result, shape, out_dtype)
317    }
318
319    fn bool_unfold(
320        tensor: BoolTensor<Flex>,
321        dim: usize,
322        size: usize,
323        step: usize,
324    ) -> BoolTensor<Flex> {
325        crate::ops::unfold::unfold_bool(tensor, dim, size, step)
326    }
327
328    fn bool_not_equal(lhs: BoolTensor<Flex>, rhs: BoolTensor<Flex>) -> BoolTensor<Flex> {
329        let out_dtype = burn_std::BoolDType::from(lhs.dtype());
330        crate::ops::comparison::bool_not_equal(lhs, rhs, out_dtype)
331    }
332
333    fn bool_not_equal_elem(lhs: BoolTensor<Flex>, rhs: burn_backend::Scalar) -> BoolTensor<Flex> {
334        let out_dtype = burn_std::BoolDType::from(lhs.dtype());
335        let rhs: bool = rhs.elem();
336        crate::ops::comparison::bool_not_equal_elem(lhs, rhs, out_dtype)
337    }
338
339    fn bool_any(tensor: BoolTensor<Flex>) -> BoolTensor<Flex> {
340        let out_dtype = burn_std::BoolDType::from(tensor.dtype());
341        crate::ops::comparison::any_bool(tensor, out_dtype)
342    }
343
344    fn bool_any_dim(tensor: BoolTensor<Flex>, dim: usize) -> BoolTensor<Flex> {
345        let out_dtype = burn_std::BoolDType::from(tensor.dtype());
346        crate::ops::comparison::any_bool_dim(tensor, dim, out_dtype)
347    }
348
349    fn bool_all(tensor: BoolTensor<Flex>) -> BoolTensor<Flex> {
350        let out_dtype = burn_std::BoolDType::from(tensor.dtype());
351        crate::ops::comparison::all_bool(tensor, out_dtype)
352    }
353
354    fn bool_all_dim(tensor: BoolTensor<Flex>, dim: usize) -> BoolTensor<Flex> {
355        let out_dtype = burn_std::BoolDType::from(tensor.dtype());
356        crate::ops::comparison::all_bool_dim(tensor, dim, out_dtype)
357    }
358
359    fn bool_select(
360        tensor: BoolTensor<Flex>,
361        dim: usize,
362        indices: IntTensor<Flex>,
363    ) -> BoolTensor<Flex> {
364        crate::ops::gather_scatter::select::<u8>(tensor, dim, indices)
365    }
366
367    fn bool_select_or(
368        tensor: BoolTensor<Flex>,
369        dim: usize,
370        indices: IntTensor<Flex>,
371        value: BoolTensor<Flex>,
372    ) -> BoolTensor<Flex> {
373        let mut result = crate::ops::gather_scatter::select_add::<u8>(tensor, dim, indices, value);
374        // Clamp to 0/1: select_add sums u8 values, but bool OR saturates at 1
375        let storage: &mut [u8] = result.storage_mut();
376        for v in storage.iter_mut() {
377            if *v > 1 {
378                *v = 1;
379            }
380        }
381        result
382    }
383
384    fn bool_transpose(tensor: BoolTensor<Flex>) -> BoolTensor<Flex> {
385        let ndims = tensor.layout().num_dims();
386        if ndims < 2 {
387            return tensor;
388        }
389        tensor.transpose(ndims - 2, ndims - 1)
390    }
391
392    fn bool_repeat_dim(tensor: BoolTensor<Flex>, dim: usize, times: usize) -> BoolTensor<Flex> {
393        crate::ops::repeat_dim::repeat_dim(tensor, dim, times)
394    }
395
396    async fn bool_argwhere(tensor: BoolTensor<Flex>, out_dtype: IntDType) -> IntTensor<Flex> {
397        let tensor = tensor.to_contiguous();
398        let shape = tensor.layout().shape().clone();
399        let ndims = shape.num_dims();
400        let data: &[u8] = tensor.storage();
401        let n = shape.num_elements();
402
403        let count = data[..n].iter().filter(|&&v| v != 0).count();
404        let mut coords: Vec<isize> = Vec::with_capacity(count * ndims);
405        let strides = crate::layout::contiguous_strides_usize(&shape);
406
407        for (flat_idx, &val) in data[..n].iter().enumerate() {
408            if val != 0 {
409                let mut remaining = flat_idx;
410                for &s in &strides {
411                    coords.push((remaining / s) as isize);
412                    remaining %= s;
413                }
414            }
415        }
416
417        let out_shape = Shape::from(vec![count, ndims]);
418        let result = FlexTensor::new(
419            Bytes::from_elems(coords),
420            Layout::contiguous(out_shape),
421            crate::ops::INDEX_DTYPE,
422        );
423        if result.dtype() != DType::from(out_dtype) {
424            Flex::int_cast(result, out_dtype)
425        } else {
426            result
427        }
428    }
429}
430
/// Element-wise logical operation dispatched by `bool_binary_op_simd`.
///
/// All three variants are commutative, which is what allows the in-place
/// fast paths to write into either operand.
#[derive(Copy, Clone)]
enum BoolBinaryOp {
    /// Logical conjunction (`a & b`).
    And,
    /// Logical disjunction (`a | b`).
    Or,
    /// Exclusive or (`a ^ b`).
    Xor,
}
438
439fn bool_binary_op_simd(lhs: FlexTensor, rhs: FlexTensor, op: BoolBinaryOp) -> FlexTensor {
440    use crate::strided_index::StridedIter;
441
442    debug_assert_eq!(lhs.dtype(), rhs.dtype(), "bool_binary_op: dtype mismatch");
443
444    // Broadcast to a common shape before dispatching. The scalar/SIMD helpers
445    // below assume equal-length operands; without this, mismatched shapes
446    // either silently keep the lhs shape or OOB-panic inside the helpers.
447    let (mut lhs, mut rhs) = crate::ops::expand::broadcast_binary(lhs, rhs);
448
449    // Preserve the input bool dtype (taken from lhs; rhs is assumed to match
450    // in dtype, checked above).
451    let out_dtype = burn_std::BoolDType::from(lhs.dtype());
452    let shape = lhs.layout().shape().clone();
453    let l_offsets = lhs.layout().contiguous_offsets();
454    let r_offsets = rhs.layout().contiguous_offsets();
455
456    // Fast path 1: lhs is unique and contiguous at offset 0 -> in-place on lhs
457    if lhs.is_unique()
458        && let (Some((0, l_end)), Some((r_start, r_end))) = (l_offsets, r_offsets)
459    {
460        let rhs_storage: &[u8] = rhs.bytes();
461        let r_slice = &rhs_storage[r_start..r_end];
462        let lhs_storage: &mut [u8] = lhs.storage_mut();
463        let l_slice = &mut lhs_storage[..l_end];
464
465        match op {
466            BoolBinaryOp::And => crate::simd::bool_and_inplace_u8(l_slice, r_slice),
467            BoolBinaryOp::Or => crate::simd::bool_or_inplace_u8(l_slice, r_slice),
468            BoolBinaryOp::Xor => crate::simd::bool_xor_inplace_u8(l_slice, r_slice),
469        }
470        return lhs;
471    }
472
473    // Fast path 2: rhs is unique and contiguous at offset 0 -> in-place on rhs
474    // (And/Or/Xor are commutative, so we can swap operands)
475    if rhs.is_unique()
476        && let (Some((l_start, l_end)), Some((0, r_end))) = (l_offsets, r_offsets)
477    {
478        let lhs_storage: &[u8] = lhs.bytes();
479        let l_slice = &lhs_storage[l_start..l_end];
480        let rhs_storage: &mut [u8] = rhs.storage_mut();
481        let r_slice = &mut rhs_storage[..r_end];
482
483        match op {
484            BoolBinaryOp::And => crate::simd::bool_and_inplace_u8(r_slice, l_slice),
485            BoolBinaryOp::Or => crate::simd::bool_or_inplace_u8(r_slice, l_slice),
486            BoolBinaryOp::Xor => crate::simd::bool_xor_inplace_u8(r_slice, l_slice),
487        }
488        return rhs;
489    }
490
491    // Allocating path: neither tensor is suitable for in-place
492    let lhs_storage: &[u8] = lhs.bytes();
493    let rhs_storage: &[u8] = rhs.bytes();
494
495    let result: Vec<u8> = match (l_offsets, r_offsets) {
496        (Some((l_start, l_end)), Some((r_start, r_end))) => {
497            let l_slice = &lhs_storage[l_start..l_end];
498            let r_slice = &rhs_storage[r_start..r_end];
499            let mut out = vec![0u8; l_slice.len()];
500            match op {
501                BoolBinaryOp::And => crate::simd::bool_and_u8(l_slice, r_slice, &mut out),
502                BoolBinaryOp::Or => crate::simd::bool_or_u8(l_slice, r_slice, &mut out),
503                BoolBinaryOp::Xor => crate::simd::bool_xor_u8(l_slice, r_slice, &mut out),
504            }
505            out
506        }
507        _ => {
508            let lhs_iter = StridedIter::new(lhs.layout());
509            let rhs_iter = StridedIter::new(rhs.layout());
510            match op {
511                BoolBinaryOp::And => lhs_iter
512                    .zip(rhs_iter)
513                    .map(|(li, ri)| lhs_storage[li] & rhs_storage[ri])
514                    .collect(),
515                BoolBinaryOp::Or => lhs_iter
516                    .zip(rhs_iter)
517                    .map(|(li, ri)| lhs_storage[li] | rhs_storage[ri])
518                    .collect(),
519                BoolBinaryOp::Xor => lhs_iter
520                    .zip(rhs_iter)
521                    .map(|(li, ri)| lhs_storage[li] ^ rhs_storage[ri])
522                    .collect(),
523            }
524        }
525    };
526
527    crate::ops::comparison::make_bool_tensor(result, shape, out_dtype)
528}
529
530// Tests kept here exercise flex-specific dtype storage selection via
531// explicit IntDType/FloatDType. Plain bool ops, bool-to-int/float
532// casts, and negative-stride (flipped) bool coverage have been migrated
533// to crates/burn-backend-tests/tests/tensor/bool/ops/{logical,cast}.rs
534// so they run against every backend. When adding new tests, keep them
535// here only if they probe flex dtype dispatch; otherwise add them
536// there.
537#[cfg(test)]
538mod tests {
539    use alloc::vec;
540    use burn_backend::TensorData;
541    use burn_backend::ops::BoolTensorOps;
542    use burn_std::{FloatDType, IntDType};
543
544    use crate::{Flex, FlexTensor};
545
546    #[test]
547    fn test_bool_into_int_u8() {
548        let t = FlexTensor::from_data(TensorData::from([true, false, true]));
549        let result = Flex::bool_into_int(t, IntDType::U8);
550        assert_eq!(result.dtype(), burn_backend::DType::U8);
551        let data: Vec<u8> = result.into_data().to_vec().unwrap();
552        assert_eq!(data, vec![1u8, 0, 1]);
553    }
554
555    #[test]
556    fn test_bool_into_float_f64() {
557        let t = FlexTensor::from_data(TensorData::from([true, false, true]));
558        let result = Flex::bool_into_float(t, FloatDType::F64);
559        assert_eq!(result.dtype(), burn_backend::DType::F64);
560        let data: Vec<f64> = result.into_data().to_vec().unwrap();
561        assert_eq!(data, vec![1.0f64, 0.0, 1.0]);
562    }
563}