burn-flex 0.21.0

A fast, portable CPU backend for the Burn framework
Documentation
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
1001
1002
1003
1004
1005
1006
1007
1008
1009
1010
1011
1012
1013
1014
1015
1016
1017
1018
1019
1020
1021
1022
1023
1024
1025
1026
1027
1028
1029
1030
1031
1032
1033
1034
1035
1036
1037
1038
1039
1040
1041
1042
1043
1044
1045
1046
1047
1048
1049
1050
1051
1052
1053
1054
1055
1056
1057
1058
1059
1060
1061
1062
1063
1064
1065
1066
1067
1068
1069
1070
1071
1072
1073
1074
1075
1076
1077
1078
1079
1080
1081
1082
1083
1084
1085
1086
1087
1088
1089
1090
1091
1092
1093
1094
1095
1096
1097
1098
1099
1100
1101
1102
1103
1104
1105
1106
1107
1108
1109
1110
1111
1112
1113
1114
1115
1116
1117
1118
1119
1120
1121
1122
1123
1124
1125
1126
1127
1128
1129
1130
1131
1132
1133
1134
1135
1136
1137
1138
1139
1140
1141
1142
1143
1144
1145
1146
1147
1148
1149
1150
1151
1152
1153
1154
1155
1156
1157
1158
1159
1160
1161
1162
1163
1164
1165
1166
1167
1168
1169
1170
1171
1172
1173
1174
1175
1176
1177
1178
1179
1180
1181
1182
1183
1184
1185
1186
1187
1188
1189
1190
1191
1192
1193
1194
1195
1196
1197
1198
1199
1200
1201
1202
1203
1204
1205
1206
1207
1208
1209
1210
1211
1212
1213
1214
1215
1216
1217
1218
1219
1220
1221
1222
1223
1224
1225
1226
1227
1228
1229
1230
1231
1232
1233
1234
1235
1236
1237
1238
1239
1240
1241
1242
1243
1244
1245
1246
1247
1248
1249
1250
1251
//! Gather and scatter operations for indexed tensor access.

use alloc::borrow::Cow;
use alloc::vec;
use alloc::vec::Vec;
use burn_backend::{DType, Element};
use burn_std::{Bytes, Shape};
use bytemuck::Pod;

#[cfg(feature = "rayon")]
use rayon::prelude::*;

use crate::{FlexTensor, Layout};

/// Read indices from a tensor as `isize`, the native offset type used by the
/// gather/scatter/select kernels in this module.
///
/// This is the internal index layer for burn-flex: every indexed op
/// ([`gather`], [`scatter_add`], [`select`], [`select_add`], and the
/// [`scatter_min`]/[`scatter_max`] variants) routes its index tensor through
/// this helper before touching the element buffer. Normalising to `isize`
/// lets the kernels use a single inner-loop signature regardless of how the
/// caller's index tensor was dtyped.
///
/// # Accepted widths
///
/// Any of the integer DTypes `I8`, `I16`, `I32`, `I64`, `U8`, `U16`, `U32`,
/// `U64` is accepted. This is intentional: burn-flex's default `IntElem` is
/// I32 rather than the I64 convention used by other backends, and users can
/// also pin index tensors to any width they want via
/// `Tensor::from_data(.., (&device, DType::Ix))`. Whichever width lands here
/// is converted to `isize` on the fly.
///
/// # Zero-copy vs. owned
///
/// The return type is `Cow<'_, [isize]>` because only one width is zero-copy:
/// the one matching the host pointer width. On 64-bit targets, I64 indices
/// can be borrowed directly via `bytemuck::cast_slice` (both are 8 bytes).
/// Every other width requires an owned `Vec<isize>` with an element-wise
/// cast. U64 indices additionally go through a `try_from` to surface values
/// that would wrap when cast to `isize`.
///
/// # History
///
/// Earlier versions of `int_gather`, `int_scatter_add`, `int_select`, and
/// `int_select_add` carried a `debug_assert_eq!(indices.dtype(), DType::I64,
/// ..)` that contradicted this helper's contract. The asserts were dropped
/// in tracel-ai/burn#4776 once it was confirmed that `read_indices` had
/// always handled every supported width correctly at runtime. If you're
/// tempted to re-add a dtype check here, don't - the float siblings
/// ([`gather_f32`], [`select_f32`], ...) already share this helper without a
/// check, and asymmetry between the int and float paths was what surfaced
/// the bug.
fn read_indices(tensor: &FlexTensor) -> Cow<'_, [isize]> {
    // Element-wise conversion into a freshly allocated `isize` buffer.
    fn convert<T: Copy>(src: &[T], cast: impl Fn(T) -> isize) -> Vec<isize> {
        src.iter().copied().map(cast).collect()
    }

    let converted: Vec<isize> = match tensor.dtype() {
        // Zero-copy: the index width equals the host pointer width.
        #[cfg(target_pointer_width = "64")]
        DType::I64 => {
            const { assert!(size_of::<i64>() == size_of::<isize>()) };
            return Cow::Borrowed(bytemuck::cast_slice(tensor.storage::<i64>()));
        }
        #[cfg(target_pointer_width = "32")]
        DType::I32 => {
            const { assert!(size_of::<i32>() == size_of::<isize>()) };
            return Cow::Borrowed(bytemuck::cast_slice(tensor.storage::<i32>()));
        }

        // Widths that may not fit in `isize`: convert with an explicit check.
        #[cfg(target_pointer_width = "32")]
        DType::I64 => convert(tensor.storage::<i64>(), |v| {
            isize::try_from(v)
                .unwrap_or_else(|_| panic!("read_indices: i64 index {v} out of isize range"))
        }),
        DType::U64 => convert(tensor.storage::<u64>(), |v| {
            isize::try_from(v)
                .unwrap_or_else(|_| panic!("read_indices: u64 index {v} out of isize range"))
        }),
        #[cfg(target_pointer_width = "32")]
        DType::U32 => convert(tensor.storage::<u32>(), |v| {
            isize::try_from(v)
                .unwrap_or_else(|_| panic!("read_indices: u32 index {v} out of isize range"))
        }),

        // Narrower widths: the cast is lossless.
        #[cfg(target_pointer_width = "64")]
        DType::I32 => convert(tensor.storage::<i32>(), |v| v as isize),
        #[cfg(target_pointer_width = "64")]
        DType::U32 => convert(tensor.storage::<u32>(), |v| v as isize),
        DType::I16 => convert(tensor.storage::<i16>(), |v| v as isize),
        DType::U16 => convert(tensor.storage::<u16>(), |v| v as isize),
        DType::I8 => convert(tensor.storage::<i8>(), |v| v as isize),
        DType::U8 => convert(tensor.storage::<u8>(), |v| v as isize),

        other => panic!("read_indices: unsupported index dtype {:?}", other),
    };

    Cow::Owned(converted)
}

/// Out-of-line panic path for [`checked_index`].
///
/// Marked cold and never inlined so the bounds check in the hot
/// gather/scatter/select loops stays a single compare-and-branch.
#[cold]
#[inline(never)]
fn index_oob(raw: isize, dim_size: usize) -> ! {
    panic!("index {raw} out of bounds for dimension of size {dim_size}")
}

/// Turn a raw `isize` index into a `usize` offset along a dimension of
/// length `dim_size`.
///
/// # Panics
///
/// Panics (through [`index_oob`]) when `raw` is negative or not smaller than
/// `dim_size`. Negative indices are rejected, not wrapped.
#[inline(always)]
fn checked_index(raw: isize, dim_size: usize) -> usize {
    match usize::try_from(raw) {
        Ok(idx) if idx < dim_size => idx,
        _ => index_oob(raw, dim_size),
    }
}

/// Gather values from `tensor` along `dim` using `indices`.
///
/// For a 2D tensor with `dim = 1`:
/// `output[i, j] = tensor[i, indices[i, j]]`
///
/// The output has the same shape as `indices`.
///
/// # Panics
///
/// - if `dim` is not a valid dimension of `tensor`,
/// - if `indices` does not have the same rank as `tensor`,
/// - if any dimension other than `dim` differs between `tensor` and `indices`,
/// - if any index is negative or `>= tensor_shape[dim]`,
/// - if `indices` has a dtype that [`read_indices`] does not accept.
pub fn gather<E: Element + Pod + Default + Copy + Send + Sync>(
    tensor: FlexTensor,
    dim: usize,
    indices: FlexTensor,
) -> FlexTensor {
    let tensor = tensor.to_contiguous();
    let indices = indices.to_contiguous();

    let tensor_shape = tensor.layout().shape();
    let indices_shape = indices.layout().shape();
    let ndims = tensor_shape.num_dims();

    assert!(
        dim < ndims,
        "dim {} out of bounds for {} dimensions",
        dim,
        ndims
    );
    // Without this, a lower-rank `indices` panics with an opaque out-of-range
    // access on `indices_shape[i]`, and a higher-rank one silently decomposes
    // output offsets with too few strides and reads the wrong elements.
    assert_eq!(
        indices_shape.num_dims(),
        ndims,
        "gather: indices rank {} must match tensor rank {}",
        indices_shape.num_dims(),
        ndims
    );

    // Validate shapes: all dims except `dim` must match between tensor and indices
    for i in 0..ndims {
        if i != dim {
            assert_eq!(
                tensor_shape[i], indices_shape[i],
                "gather: shape mismatch at dim {}: tensor {} vs indices {}",
                i, tensor_shape[i], indices_shape[i]
            );
        }
    }

    let tensor_data: &[E] = tensor.storage();
    let indices_data = read_indices(&indices);

    // Specialized 2D implementation for the common case.
    if ndims == 2 {
        let result = gather_2d::<E>(
            tensor_data,
            &indices_data,
            tensor_shape[0],
            tensor_shape[1],
            indices_shape[0],
            indices_shape[1],
            dim,
        );
        let bytes = Bytes::from_elems(result);
        return FlexTensor::new(bytes, Layout::contiguous(indices_shape.clone()), E::dtype());
    }

    // General N-D case. Row-major strides are only needed here.
    let tensor_strides: Vec<usize> = compute_strides(tensor_shape);
    let indices_strides: Vec<usize> = compute_strides(indices_shape);
    let dim_stride = tensor_strides[dim];
    let gather_dim_size = tensor_shape[dim];
    let output_size = indices_shape.num_elements();

    // One output element: validate the index, then map the output position
    // to its source offset. Shared by the parallel and serial builds.
    let gather_one = |out_idx: usize| -> E {
        let index_val = checked_index(indices_data[out_idx], gather_dim_size);
        let src_idx = compute_gather_index(
            out_idx,
            index_val,
            dim,
            dim_stride,
            &indices_strides,
            &tensor_strides,
            ndims,
        );
        tensor_data[src_idx]
    };

    #[cfg(feature = "rayon")]
    let result: Vec<E> = (0..output_size).into_par_iter().map(gather_one).collect();

    #[cfg(not(feature = "rayon"))]
    let result: Vec<E> = (0..output_size).map(gather_one).collect();

    let bytes = Bytes::from_elems(result);
    FlexTensor::new(bytes, Layout::contiguous(indices_shape.clone()), E::dtype())
}

/// Optimized 2D gather.
///
/// Produces `indices_rows * indices_cols` elements in row-major order:
/// - `dim == 0`: `out[i][j] = tensor[indices[i][j]][j]`
/// - `dim == 1`: `out[i][j] = tensor[i][indices[i][j]]`
///
/// The caller must already have checked that the non-gathered dimension
/// matches: `indices_cols == tensor_cols` for `dim == 0`, and
/// `indices_rows == tensor_rows` for `dim == 1`. Every index is validated with
/// [`checked_index`] against the gathered dimension and panics when it is out
/// of range.
#[inline]
fn gather_2d<E: Element + Pod + Default + Copy + Send + Sync>(
    tensor_data: &[E],
    indices_data: &[isize],
    tensor_rows: usize,
    tensor_cols: usize,
    indices_rows: usize,
    indices_cols: usize,
    dim: usize,
) -> Vec<E> {
    let output_size = indices_rows * indices_cols;
    debug_assert_eq!(indices_data.len(), output_size);
    let dim_size = if dim == 0 { tensor_rows } else { tensor_cols };

    let mut result = vec![E::default(); output_size];
    if output_size == 0 {
        // Nothing to gather. Returning here also keeps the row chunking below
        // away from a zero chunk size, which `chunks_mut` rejects.
        return result;
    }

    // Fill output row `i` from row `i` of the index matrix.
    let fill_row = |i: usize, out_row: &mut [E]| {
        let idx_row = &indices_data[i * indices_cols..(i + 1) * indices_cols];
        if dim == 0 {
            for (j, (dst, &raw)) in out_row.iter_mut().zip(idx_row).enumerate() {
                *dst = tensor_data[checked_index(raw, dim_size) * tensor_cols + j];
            }
        } else {
            let src_row = &tensor_data[i * tensor_cols..(i + 1) * tensor_cols];
            for (dst, &raw) in out_row.iter_mut().zip(idx_row) {
                *dst = src_row[checked_index(raw, dim_size)];
            }
        }
    };

    #[cfg(feature = "rayon")]
    {
        // Below this many output elements the serial loop wins.
        const PARALLEL_THRESHOLD: usize = 256 * 1024;
        if output_size >= PARALLEL_THRESHOLD {
            result
                .par_chunks_mut(indices_cols)
                .enumerate()
                .for_each(|(i, row)| fill_row(i, row));
            return result;
        }
    }

    for (i, row) in result.chunks_mut(indices_cols).enumerate() {
        fill_row(i, row);
    }
    result
}

/// Map a flat offset into the (contiguous) indices tensor to the flat offset
/// in the source tensor that it refers to (N-D gather/scatter).
///
/// `out_idx` is split into per-axis coordinates using `indices_strides`.
/// Every coordinate except the one on `dim` is carried over using
/// `tensor_strides`. The `dim` coordinate is replaced by the already
/// validated `index_val`, scaled by `dim_stride`.
#[inline]
fn compute_gather_index(
    out_idx: usize,
    index_val: usize,
    dim: usize,
    dim_stride: usize,
    indices_strides: &[usize],
    tensor_strides: &[usize],
    ndims: usize,
) -> usize {
    let mut rest = out_idx;
    let mut offset = index_val * dim_stride;

    for axis in 0..ndims {
        let step = indices_strides[axis];
        if axis == dim {
            // Coordinate on `dim` is discarded; `index_val` stands in for it.
            rest %= step;
            continue;
        }
        offset += (rest / step) * tensor_strides[axis];
        rest %= step;
    }

    offset
}

/// Scatter add: adds values to tensor at positions specified by indices.
///
/// Returns a copy of `tensor` in which, for every position `p` of `indices`,
/// `value[p]` has been added at the tensor position obtained by replacing
/// the `dim` coordinate of `p` with `indices[p]`. Repeated indices
/// accumulate.
///
/// For a 2D tensor with `dim = 1`:
/// `output[i, indices[i, j]] += value[i, j]`
///
/// # Panics
///
/// - if `dim` is not a valid dimension of `tensor`,
/// - if `indices` and `value` have different shapes,
/// - if `indices` does not have the same rank as `tensor`,
/// - if any dimension other than `dim` differs between `tensor` and `indices`,
/// - if any index is negative or `>= tensor_shape[dim]`,
/// - if `indices` has a dtype that [`read_indices`] does not accept.
pub fn scatter_add<E: Element + Pod + Default + Copy + core::ops::AddAssign + Send + Sync>(
    tensor: FlexTensor,
    dim: usize,
    indices: FlexTensor,
    value: FlexTensor,
) -> FlexTensor {
    let tensor = tensor.to_contiguous();
    let indices = indices.to_contiguous();
    let value = value.to_contiguous();

    let tensor_shape = tensor.layout().shape().clone();
    let indices_shape = indices.layout().shape();
    let value_shape = value.layout().shape();
    let ndims = tensor_shape.num_dims();

    assert!(
        dim < ndims,
        "dim {} out of bounds for {} dimensions",
        dim,
        ndims
    );
    assert_eq!(
        indices_shape,
        value_shape,
        "scatter_add: indices shape {:?} must match value shape {:?}",
        indices_shape.to_vec(),
        value_shape.to_vec()
    );
    // A rank mismatch would otherwise surface as an opaque out-of-range
    // panic, or as silently wrong destination offsets for higher-rank indices.
    assert_eq!(
        indices_shape.num_dims(),
        ndims,
        "scatter_add: indices rank {} must match tensor rank {}",
        indices_shape.num_dims(),
        ndims
    );

    for i in 0..ndims {
        if i != dim {
            assert_eq!(
                tensor_shape[i], indices_shape[i],
                "scatter_add: shape mismatch at dim {}: tensor {} vs indices {}",
                i, tensor_shape[i], indices_shape[i]
            );
        }
    }

    let tensor_data: &[E] = tensor.storage();
    let indices_data = read_indices(&indices);
    let value_data: &[E] = value.storage();

    let mut result: Vec<E> = tensor_data.to_vec();

    if ndims == 2 {
        scatter_add_2d(
            &mut result,
            &indices_data,
            value_data,
            tensor_shape[0],
            tensor_shape[1],
            indices_shape[0],
            indices_shape[1],
            dim,
        );
    } else {
        // General N-D case. Strides are only needed here.
        let tensor_strides: Vec<usize> = compute_strides(&tensor_shape);
        let indices_strides: Vec<usize> = compute_strides(indices_shape);
        let dim_stride = tensor_strides[dim];
        let scatter_dim_size = tensor_shape[dim];

        // Sequential on purpose: distinct index positions may target the same
        // destination element, so a parallel `+=` would race.
        for idx in 0..indices_shape.num_elements() {
            let index_val = checked_index(indices_data[idx], scatter_dim_size);
            let dst_idx = compute_gather_index(
                idx,
                index_val,
                dim,
                dim_stride,
                &indices_strides,
                &tensor_strides,
                ndims,
            );
            result[dst_idx] += value_data[idx];
        }
    }

    let bytes = Bytes::from_elems(result);
    FlexTensor::new(bytes, Layout::contiguous(tensor_shape), E::dtype())
}

/// Optimized 2D scatter_add implementation.
#[inline]
#[allow(clippy::too_many_arguments)]
fn scatter_add_2d<E: Copy + core::ops::AddAssign>(
    result: &mut [E],
    indices_data: &[isize],
    value_data: &[E],
    tensor_rows: usize,
    tensor_cols: usize,
    indices_rows: usize,
    indices_cols: usize,
    dim: usize,
) {
    let dim_size = if dim == 0 { tensor_rows } else { tensor_cols };
    if dim == 0 {
        for i in 0..indices_rows {
            for j in 0..indices_cols {
                let idx = i * indices_cols + j;
                let dst_row = checked_index(indices_data[idx], dim_size);
                result[dst_row * tensor_cols + j] += value_data[idx];
            }
        }
    } else {
        for i in 0..indices_rows {
            for j in 0..indices_cols {
                let idx = i * indices_cols + j;
                let dst_col = checked_index(indices_data[idx], dim_size);
                result[i * tensor_cols + dst_col] += value_data[idx];
            }
        }
    }
}

/// Select whole slices of `tensor` along `dim`, one slice per entry of a 1D
/// index tensor.
///
/// Where [`gather`] reads one element per index position, `select` copies
/// the entire slice at each index. For a 2D tensor with `dim = 0` and
/// `indices = [2, 0]`:
///
/// ```text
/// output[0, :] = tensor[2, :]
/// output[1, :] = tensor[0, :]
/// ```
///
/// The output shape is the tensor shape with `dim` replaced by the number of
/// indices.
///
/// # Panics
///
/// - if `dim` is not a valid dimension of `tensor`,
/// - if `indices` is not 1D,
/// - if any index is negative or `>= tensor_shape[dim]`,
/// - if `indices` has a dtype that [`read_indices`] does not accept.
pub fn select<E: Element + Pod + Default + Copy + Send + Sync>(
    tensor: FlexTensor,
    dim: usize,
    indices: FlexTensor,
) -> FlexTensor {
    let tensor = tensor.to_contiguous();
    let indices = indices.to_contiguous();

    let tensor_shape = tensor.layout().shape();
    let ndims = tensor_shape.num_dims();

    assert!(
        dim < ndims,
        "dim {} out of bounds for {} dimensions",
        dim,
        ndims
    );
    let index_rank = indices.layout().num_dims();
    assert_eq!(
        index_rank, 1,
        "select: indices must be 1D, got {} dims",
        index_rank
    );

    let source: &[E] = tensor.storage();
    let picks = read_indices(&indices);
    let pick_count = picks.len();
    let select_dim_size = tensor_shape[dim];

    let output_shape = {
        let mut dims = tensor_shape.to_vec();
        dims[dim] = pick_count;
        Shape::from(dims)
    };

    // 2D has a dedicated kernel with bulk row copies.
    if ndims == 2 {
        let data = select_2d::<E>(
            source,
            &picks,
            tensor_shape[0],
            tensor_shape[1],
            pick_count,
            dim,
        );
        return FlexTensor::new(
            Bytes::from_elems(data),
            Layout::contiguous(output_shape),
            E::dtype(),
        );
    }

    let tensor_strides: Vec<usize> = compute_strides(tensor_shape);
    let output_strides: Vec<usize> = compute_strides(&output_shape);
    let output_size = output_shape.num_elements();
    // Number of contiguous elements that make up one selected slice.
    let slice_size: usize = tensor_strides[dim];

    let per_element = dim == ndims - 1 || slice_size == 1;
    if per_element {
        // Decompose each output offset into coordinates, swap the `dim`
        // coordinate for the looked-up index, and re-linearise.
        let pick_one = |out_idx: usize| -> E {
            let mut remaining = out_idx;
            let mut src_idx = 0;
            for d in 0..ndims {
                let coord = remaining / output_strides[d];
                remaining %= output_strides[d];
                let src_coord = if d == dim {
                    checked_index(picks[coord], select_dim_size)
                } else {
                    coord
                };
                src_idx += src_coord * tensor_strides[d];
            }
            source[src_idx]
        };

        #[cfg(feature = "rayon")]
        let data: Vec<E> = (0..output_size).into_par_iter().map(pick_one).collect();

        #[cfg(not(feature = "rayon"))]
        let data: Vec<E> = (0..output_size).map(pick_one).collect();

        return FlexTensor::new(
            Bytes::from_elems(data),
            Layout::contiguous(output_shape),
            E::dtype(),
        );
    }

    // Bulk path: each (outer position, index) pair copies one contiguous
    // slice of `slice_size` elements.
    let outer_count: usize = tensor_shape[..dim].iter().product();
    let outer_axis = dim.saturating_sub(1);
    let (src_outer_stride, dst_outer_stride) =
        (tensor_strides[outer_axis], output_strides[outer_axis]);
    let (src_dim_stride, dst_dim_stride) = (tensor_strides[dim], output_strides[dim]);

    let mut data = vec![E::default(); output_size];
    for outer in 0..outer_count {
        let src_base = outer * src_outer_stride;
        let dst_base = outer * dst_outer_stride;
        for (slot, &raw) in picks.iter().enumerate() {
            let src_start = src_base + checked_index(raw, select_dim_size) * src_dim_stride;
            let dst_start = dst_base + slot * dst_dim_stride;
            data[dst_start..dst_start + slice_size]
                .copy_from_slice(&source[src_start..src_start + slice_size]);
        }
    }

    FlexTensor::new(
        Bytes::from_elems(data),
        Layout::contiguous(output_shape),
        E::dtype(),
    )
}

/// Optimized 2D select.
///
/// - `dim == 0`: each index copies one whole source row (`tensor_cols`
///   contiguous elements); output shape `[num_indices, tensor_cols]`.
/// - `dim == 1`: each output row takes the indexed columns of the matching
///   source row; output shape `[tensor_rows, num_indices]`.
///
/// `num_indices` must equal `indices_data.len()` (checked in debug builds).
/// Every index is validated with [`checked_index`] and panics when it is
/// negative or outside the selected dimension.
#[inline]
fn select_2d<E: Element + Pod + Default + Copy + Send + Sync>(
    tensor_data: &[E],
    indices_data: &[isize],
    tensor_rows: usize,
    tensor_cols: usize,
    num_indices: usize,
    dim: usize,
) -> Vec<E> {
    // The parallel dim == 0 path sizes an uninitialized buffer from
    // `num_indices` and fills it from `indices_data`; they must agree.
    debug_assert_eq!(num_indices, indices_data.len());

    let dim_size = if dim == 0 { tensor_rows } else { tensor_cols };
    let (output_rows, output_cols) = if dim == 0 {
        (num_indices, tensor_cols)
    } else {
        (tensor_rows, num_indices)
    };
    let output_size = output_rows * output_cols;

    #[cfg(feature = "rayon")]
    {
        // Minimum bytes of output before we consider rayon. Below this, a
        // single-threaded loop is faster because there is not enough work to
        // amortize the work-stealing dispatch overhead.
        const PARALLEL_THRESHOLD_BYTES: usize = 4 * 1024 * 1024;

        // Minimum elements per rayon task. Without batching, par_chunks_mut
        // creates one task per row whose dispatch overhead dominates the copy.
        const MIN_ELEMS_PER_TASK: usize = 64 * 1024;

        if output_size * size_of::<E>() >= PARALLEL_THRESHOLD_BYTES {
            // Past the threshold the output is non-empty, so `output_cols > 0`
            // and neither the division nor the chunk sizes below are zero.
            // Batch whole output rows so each task handles at least
            // MIN_ELEMS_PER_TASK elements.
            let rows_per_chunk = (MIN_ELEMS_PER_TASK / output_cols).max(1);
            let elems_per_chunk = rows_per_chunk * output_cols;

            if dim == 0 {
                let mut result: Vec<E> = Vec::with_capacity(output_size);
                // SAFETY: `result` gets exactly `output_size = num_indices *
                // tensor_cols` slots. `par_chunks_mut` splits them into disjoint
                // chunks whose lengths are multiples of `tensor_cols` (as is
                // `elems_per_chunk`), and every row of every chunk is fully
                // overwritten by `copy_from_slice` before `result` is returned,
                // so no uninitialized element is ever read. If an index check
                // panics mid-copy, the vector is dropped without reading its
                // contents, and `E: Pod` (hence `Copy`) has no drop glue.
                #[allow(clippy::uninit_vec)]
                unsafe {
                    result.set_len(output_size)
                };
                result
                    .par_chunks_mut(elems_per_chunk)
                    .enumerate()
                    .for_each(|(chunk_idx, dst_chunk)| {
                        let start_row = chunk_idx * rows_per_chunk;
                        for (r, dst_row) in dst_chunk.chunks_exact_mut(output_cols).enumerate() {
                            let src_start =
                                checked_index(indices_data[start_row + r], dim_size) * tensor_cols;
                            dst_row.copy_from_slice(
                                &tensor_data[src_start..src_start + tensor_cols],
                            );
                        }
                    });
                return result;
            }

            // dim == 1: per-element gather within each row, zero-init is fine.
            let mut result = vec![E::default(); output_size];
            result
                .par_chunks_mut(elems_per_chunk)
                .enumerate()
                .for_each(|(chunk_idx, dst_chunk)| {
                    let start_row = chunk_idx * rows_per_chunk;
                    for (r, dst_row) in dst_chunk.chunks_exact_mut(output_cols).enumerate() {
                        let src_base = (start_row + r) * tensor_cols;
                        for (dst, &idx) in dst_row.iter_mut().zip(indices_data) {
                            *dst = tensor_data[src_base + checked_index(idx, dim_size)];
                        }
                    }
                });
            return result;
        }
    }

    // Serial path: build the output by appending, so no element is ever
    // uninitialized and no zero-fill pass is needed.
    let mut result: Vec<E> = Vec::with_capacity(output_size);
    if dim == 0 {
        for &idx in indices_data {
            let src_start = checked_index(idx, dim_size) * tensor_cols;
            result.extend_from_slice(&tensor_data[src_start..src_start + tensor_cols]);
        }
    } else {
        for row in 0..output_rows {
            let src_row = &tensor_data[row * tensor_cols..(row + 1) * tensor_cols];
            result.extend(
                indices_data
                    .iter()
                    .map(|&idx| src_row[checked_index(idx, dim_size)]),
            );
        }
    }
    result
}

/// Select add: adds values back to tensor at positions specified by 1D indices.
///
/// `value` must have the same rank as `tensor`, with `indices.len()` entries along `dim` and the
/// tensor's extent along every other dimension. For each position of `value`, its coordinate `i`
/// along `dim` is replaced by `indices[i]` and the element is added into the output at that
/// position, so duplicate indices accumulate. The result is a new contiguous tensor with the
/// shape of `tensor`.
///
/// # Panics
///
/// - if `dim` is not a valid dimension of `tensor`;
/// - if `indices` is not 1D;
/// - if `value` does not have the same rank as `tensor`, or any extent mismatches;
/// - if an index is rejected by `checked_index` for the size of dimension `dim`.
pub fn select_add<E: Element + Pod + Default + Copy + core::ops::AddAssign + Send + Sync>(
    tensor: FlexTensor,
    dim: usize,
    indices: FlexTensor,
    value: FlexTensor,
) -> FlexTensor {
    let tensor = tensor.to_contiguous();
    let indices = indices.to_contiguous();
    let value = value.to_contiguous();

    let tensor_shape = tensor.layout().shape().clone();
    let value_shape = value.layout().shape();
    let ndims = tensor_shape.num_dims();

    assert!(
        dim < ndims,
        "dim {} out of bounds for {} dimensions",
        dim,
        ndims
    );
    assert_eq!(
        indices.layout().num_dims(),
        1,
        "select_add: indices must be 1D"
    );
    // Without this, a higher-rank `value` whose leading extents happen to match would pass the
    // per-dimension checks below, and its buffer would then be read with the wrong layout.
    assert_eq!(
        value_shape.num_dims(),
        ndims,
        "select_add: value rank {} should match tensor rank {}",
        value_shape.num_dims(),
        ndims
    );

    let tensor_data: &[E] = tensor.storage();
    let indices_data = read_indices(&indices);
    let value_data: &[E] = value.storage();
    let num_indices = indices_data.len();

    // Validate value shape
    for d in 0..ndims {
        if d == dim {
            assert_eq!(
                value_shape[d], num_indices,
                "select_add: value dim {} should be {} (num indices), got {}",
                d, num_indices, value_shape[d]
            );
        } else {
            assert_eq!(
                value_shape[d], tensor_shape[d],
                "select_add: value dim {} should match tensor dim {}, got {}",
                d, tensor_shape[d], value_shape[d]
            );
        }
    }

    // Start from a copy of the input; values are accumulated into it.
    let mut result: Vec<E> = tensor_data.to_vec();

    // Use optimized 2D implementation
    if ndims == 2 {
        select_add_2d(
            &mut result,
            &indices_data,
            value_data,
            tensor_shape[0],
            tensor_shape[1],
            num_indices,
            dim,
        );
        let bytes = Bytes::from_elems(result);
        return FlexTensor::new(bytes, Layout::contiguous(tensor_shape), E::dtype());
    }

    // General N-D case: split each flat value position into coordinates using the value
    // strides, swap the `dim` coordinate for its index, and re-linearize with tensor strides.
    let tensor_strides: Vec<usize> = compute_strides(&tensor_shape);
    let value_strides: Vec<usize> = compute_strides(value_shape);
    let select_add_dim_size = tensor_shape[dim];

    for (val_idx, &val) in value_data.iter().enumerate() {
        let mut remaining = val_idx;
        let mut dst_idx = 0;
        for d in 0..ndims {
            let coord = remaining / value_strides[d];
            remaining %= value_strides[d];
            if d == dim {
                let index_val = checked_index(indices_data[coord], select_add_dim_size);
                dst_idx += index_val * tensor_strides[d];
            } else {
                dst_idx += coord * tensor_strides[d];
            }
        }
        result[dst_idx] += val;
    }

    let bytes = Bytes::from_elems(result);
    FlexTensor::new(bytes, Layout::contiguous(tensor_shape), E::dtype())
}

/// Optimized 2D select_add, accumulating into `result` in place.
///
/// `result` is a row-major `tensor_rows x tensor_cols` buffer.
/// - When `dim == 0`, `value_data` is read as `num_indices` rows of `tensor_cols` elements.
///   Value row `i` is added into result row `indices_data[i]`.
/// - For any other `dim`, `value_data` is read as `tensor_rows` rows of `num_indices`
///   elements. Value column `j` is added into result column `indices_data[j]` in every row.
///
/// Each index is passed through `checked_index` with the size of the indexed dimension.
#[inline]
fn select_add_2d<E: Copy + core::ops::AddAssign>(
    result: &mut [E],
    indices_data: &[isize],
    value_data: &[E],
    tensor_rows: usize,
    tensor_cols: usize,
    num_indices: usize,
    dim: usize,
) {
    if dim != 0 {
        // Column scatter: every row receives one value per index, at the indexed column.
        for row in 0..tensor_rows {
            let dst_row_base = row * tensor_cols;
            let src_row_base = row * num_indices;
            for (offset, &raw) in indices_data.iter().enumerate() {
                let col = checked_index(raw, tensor_cols);
                result[dst_row_base + col] += value_data[src_row_base + offset];
            }
        }
        return;
    }

    // Row scatter: each index names the destination row for one whole value row.
    for (src_row, &raw) in indices_data.iter().enumerate() {
        let dst_base = checked_index(raw, tensor_rows) * tensor_cols;
        let src_base = src_row * tensor_cols;
        for col in 0..tensor_cols {
            result[dst_base + col] += value_data[src_base + col];
        }
    }
}

/// Compute row-major (C-order) strides for a shape.
///
/// The last axis has stride 1. Each earlier axis has the stride of the next axis multiplied by
/// that next axis's extent. An empty shape yields an empty stride list.
#[inline]
fn compute_strides(dims: &[usize]) -> Vec<usize> {
    let mut strides = vec![1usize; dims.len()];
    // Walk from the innermost axis outwards, deriving each stride from its successor.
    for axis in (1..dims.len()).rev() {
        strides[axis - 1] = strides[axis] * dims[axis];
    }
    strides
}

/// Multi-dimensional scatter: update `data` at locations specified by N-dimensional index tuples.
pub fn scatter_nd<
    E: Element + Pod + Default + Copy + core::ops::AddAssign + core::ops::Mul<Output = E> + PartialOrd,
>(
    data: FlexTensor,
    indices: FlexTensor,
    values: FlexTensor,
    reduction: burn_backend::tensor::IndexingUpdateOp,
) -> FlexTensor {
    use burn_backend::tensor::IndexingUpdateOp;

    let data = data.to_contiguous();
    let indices = indices.to_contiguous();
    let values = values.to_contiguous();

    let data_shape: Vec<usize> = data.layout().shape().to_vec();
    let idx_shape: Vec<usize> = indices.layout().shape().to_vec();
    let m = idx_shape.len();
    let k = idx_shape[m - 1];

    let num_indices: usize = idx_shape[..m - 1].iter().product();
    let slice_size: usize = data_shape[k..].iter().product();

    let data_data: &[E] = data.storage();
    let idx_data = read_indices(&indices);
    let val_data: &[E] = values.storage();

    let mut result: Vec<E> = data_data.to_vec();

    let strides = compute_strides(&data_shape);

    for n in 0..num_indices {
        let mut base_offset = 0usize;
        for j in 0..k {
            let idx_val = idx_data[n * k + j] as usize;
            base_offset += idx_val * strides[j];
        }

        let val_offset = n * slice_size;
        match reduction {
            IndexingUpdateOp::Assign => {
                result[base_offset..(base_offset + slice_size)]
                    .copy_from_slice(&val_data[val_offset..(val_offset + slice_size)]);
            }
            IndexingUpdateOp::Add => {
                for s in 0..slice_size {
                    result[base_offset + s] += val_data[val_offset + s];
                }
            }
            IndexingUpdateOp::Mul => {
                for s in 0..slice_size {
                    result[base_offset + s] = result[base_offset + s] * val_data[val_offset + s];
                }
            }
            IndexingUpdateOp::Min => {
                for s in 0..slice_size {
                    let b = val_data[val_offset + s];
                    if b < result[base_offset + s] {
                        result[base_offset + s] = b;
                    }
                }
            }
            IndexingUpdateOp::Max => {
                for s in 0..slice_size {
                    let b = val_data[val_offset + s];
                    if b > result[base_offset + s] {
                        result[base_offset + s] = b;
                    }
                }
            }
        }
    }

    let shape = Shape::from(data_shape);
    let bytes = Bytes::from_elems(result);
    FlexTensor::new(bytes, Layout::contiguous(shape), E::dtype())
}

/// Multi-dimensional gather: collect slices from `data` at locations specified by N-dimensional
/// index tuples.
pub fn gather_nd<E: Element + Pod + Default + Copy>(
    data: FlexTensor,
    indices: FlexTensor,
) -> FlexTensor {
    let data = data.to_contiguous();
    let indices = indices.to_contiguous();

    let data_shape: Vec<usize> = data.layout().shape().to_vec();
    let idx_shape: Vec<usize> = indices.layout().shape().to_vec();
    let m = idx_shape.len();
    let k = idx_shape[m - 1];

    let num_indices: usize = idx_shape[..m - 1].iter().product();
    let slice_size: usize = data_shape[k..].iter().product();

    let data_data: &[E] = data.storage();
    let idx_data = read_indices(&indices);

    let mut out_shape_vec: Vec<usize> = idx_shape[..m - 1].to_vec();
    out_shape_vec.extend_from_slice(&data_shape[k..]);

    let strides = compute_strides(&data_shape);

    let total = num_indices * slice_size;
    let mut result = vec![E::default(); total];

    for n in 0..num_indices {
        let mut base_offset = 0usize;
        for j in 0..k {
            let idx_val = idx_data[n * k + j] as usize;
            base_offset += idx_val * strides[j];
        }
        let out_offset = n * slice_size;
        result[out_offset..(out_offset + slice_size)]
            .copy_from_slice(&data_data[base_offset..(base_offset + slice_size)]);
    }

    let shape = Shape::from(out_shape_vec);
    let bytes = Bytes::from_elems(result);
    FlexTensor::new(bytes, Layout::contiguous(shape), E::dtype())
}

// Type-specific wrappers

pub fn gather_f32(tensor: FlexTensor, dim: usize, indices: FlexTensor) -> FlexTensor {
    gather::<f32>(tensor, dim, indices)
}

pub fn gather_f64(tensor: FlexTensor, dim: usize, indices: FlexTensor) -> FlexTensor {
    gather::<f64>(tensor, dim, indices)
}

pub fn gather_i64(tensor: FlexTensor, dim: usize, indices: FlexTensor) -> FlexTensor {
    gather::<i64>(tensor, dim, indices)
}

pub fn scatter_add_f32(
    tensor: FlexTensor,
    dim: usize,
    indices: FlexTensor,
    value: FlexTensor,
) -> FlexTensor {
    scatter_add::<f32>(tensor, dim, indices, value)
}

pub fn scatter_add_f64(
    tensor: FlexTensor,
    dim: usize,
    indices: FlexTensor,
    value: FlexTensor,
) -> FlexTensor {
    scatter_add::<f64>(tensor, dim, indices, value)
}

pub fn scatter_add_i64(
    tensor: FlexTensor,
    dim: usize,
    indices: FlexTensor,
    value: FlexTensor,
) -> FlexTensor {
    scatter_add::<i64>(tensor, dim, indices, value)
}

pub fn select_f32(tensor: FlexTensor, dim: usize, indices: FlexTensor) -> FlexTensor {
    select::<f32>(tensor, dim, indices)
}

pub fn select_f64(tensor: FlexTensor, dim: usize, indices: FlexTensor) -> FlexTensor {
    select::<f64>(tensor, dim, indices)
}

pub fn select_i64(tensor: FlexTensor, dim: usize, indices: FlexTensor) -> FlexTensor {
    select::<i64>(tensor, dim, indices)
}

pub fn select_add_f32(
    tensor: FlexTensor,
    dim: usize,
    indices: FlexTensor,
    value: FlexTensor,
) -> FlexTensor {
    select_add::<f32>(tensor, dim, indices, value)
}

pub fn select_add_f64(
    tensor: FlexTensor,
    dim: usize,
    indices: FlexTensor,
    value: FlexTensor,
) -> FlexTensor {
    select_add::<f64>(tensor, dim, indices, value)
}

pub fn select_add_i64(
    tensor: FlexTensor,
    dim: usize,
    indices: FlexTensor,
    value: FlexTensor,
) -> FlexTensor {
    select_add::<i64>(tensor, dim, indices, value)
}

// Bool-specific operations

pub fn gather_bool(tensor: FlexTensor, dim: usize, indices: FlexTensor) -> FlexTensor {
    gather::<u8>(tensor, dim, indices)
}

/// Scatter OR for bool tensors: ORs values into tensor at indexed positions.
///
/// Shape requirements:
/// - `indices` and `value` must have the same shape;
/// - that shape must have the same rank as `tensor`;
/// - it must match `tensor` on every dimension except `dim`.
///
/// Each `value` element is ORed into the output at the position obtained by replacing its
/// coordinate along `dim` with the corresponding index. Bool storage is read as one `u8` per
/// element, and the output keeps the input tensor's bool dtype.
///
/// # Panics
///
/// - if `dim` is not a valid dimension of `tensor`;
/// - if the shapes of `indices` and `value` differ;
/// - if `indices` has a different rank than `tensor`;
/// - if a non-`dim` extent differs between `tensor` and `indices`;
/// - if an index is rejected by `checked_index` for the size of dimension `dim`.
pub fn scatter_or(
    tensor: FlexTensor,
    dim: usize,
    indices: FlexTensor,
    value: FlexTensor,
) -> FlexTensor {
    // Preserve the input tensor's bool dtype for the output.
    let out_dtype = burn_std::BoolDType::from(tensor.dtype());
    let tensor = tensor.to_contiguous();
    let indices = indices.to_contiguous();
    let value = value.to_contiguous();

    let tensor_shape = tensor.layout().shape().clone();
    let indices_shape = indices.layout().shape();
    let value_shape = value.layout().shape();
    let ndims = tensor_shape.num_dims();

    assert!(
        dim < ndims,
        "dim {} out of bounds for {} dimensions",
        dim,
        ndims
    );
    assert_eq!(
        indices_shape,
        value_shape,
        "scatter_or: indices shape {:?} must match value shape {:?}",
        indices_shape.to_vec(),
        value_shape.to_vec()
    );
    // Without this, higher-rank indices whose leading extents match would pass the checks
    // below and be consumed with the wrong layout; lower-rank ones would hit a bare
    // index-out-of-bounds panic.
    assert_eq!(
        indices_shape.num_dims(),
        ndims,
        "scatter_or: indices rank {} must match tensor rank {}",
        indices_shape.num_dims(),
        ndims
    );

    for i in 0..ndims {
        if i != dim {
            assert_eq!(
                tensor_shape[i], indices_shape[i],
                "scatter_or: shape mismatch at dim {}: tensor {} vs indices {}",
                i, tensor_shape[i], indices_shape[i]
            );
        }
    }

    let tensor_data: &[u8] = tensor.storage();
    let indices_data = read_indices(&indices);
    let value_data: &[u8] = value.storage();

    let mut result: Vec<u8> = tensor_data.to_vec();

    let scatter_or_dim_size = tensor_shape[dim];

    // Use 2D specialized path
    if ndims == 2 {
        let tensor_cols = tensor_shape[1];
        let indices_rows = indices_shape[0];
        let indices_cols = indices_shape[1];

        if dim == 0 {
            for i in 0..indices_rows {
                for j in 0..indices_cols {
                    let idx = i * indices_cols + j;
                    let dst_row = checked_index(indices_data[idx], scatter_or_dim_size);
                    result[dst_row * tensor_cols + j] |= value_data[idx];
                }
            }
        } else {
            for i in 0..indices_rows {
                for j in 0..indices_cols {
                    let idx = i * indices_cols + j;
                    let dst_col = checked_index(indices_data[idx], scatter_or_dim_size);
                    result[i * tensor_cols + dst_col] |= value_data[idx];
                }
            }
        }
    } else {
        // General N-D path. The strides are only needed here, so they are computed here.
        let tensor_strides = compute_strides(&tensor_shape);
        let indices_strides = compute_strides(indices_shape);
        let num_elements = indices_shape.num_elements();
        let dim_stride = tensor_strides[dim];
        for idx in 0..num_elements {
            let index_val = checked_index(indices_data[idx], scatter_or_dim_size);
            let dst_idx = compute_gather_index(
                idx,
                index_val,
                dim,
                dim_stride,
                &indices_strides,
                &tensor_strides,
                ndims,
            );
            result[dst_idx] |= value_data[idx];
        }
    }

    crate::ops::comparison::make_bool_tensor(result, tensor_shape, out_dtype)
}

// Tests kept here probe flex-specific behavior: non-I64 index dtype
// acceptance through the internal `read_indices` path and the uninit-
// buffer + rayon-chunked `select` kernel. Plain gather/scatter/select
// tests (including an empty-indices edge case) live in
// crates/burn-backend-tests/tests/tensor/float/ops/{gather_scatter,select}.rs
// so every backend is exercised. When adding new tests, keep them here
// only if they probe flex internals; otherwise add them there.
#[cfg(test)]
mod tests {
    use super::*;
    use burn_backend::TensorData;

    /// `gather` must accept `i32` index tensors, not only `i64` ones.
    #[test]
    fn test_gather_with_i32_indices() {
        let input = TensorData::new(vec![10.0f32, 20.0, 30.0, 40.0], [4]);
        let idx = TensorData::new(vec![3i32, 0, 2], [3]);

        let gathered = gather::<f32>(FlexTensor::from_data(input), 0, FlexTensor::from_data(idx));
        let values: Vec<f32> = gathered.into_data().to_vec().unwrap();
        assert_eq!(values, vec![40.0, 10.0, 30.0]);
    }

    /// `select` along dim 0 must accept `i32` index tensors, not only `i64` ones.
    #[test]
    fn test_select_with_i32_indices() {
        let input = TensorData::new(vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0], [3, 2]);
        let idx = TensorData::new(vec![2i32, 0], [2]);

        let selected = select::<f32>(FlexTensor::from_data(input), 0, FlexTensor::from_data(idx));
        let values: Vec<f32> = selected.into_data().to_vec().unwrap();
        assert_eq!(values, vec![5.0, 6.0, 1.0, 2.0]);
    }

    /// Large `dim == 0` select, meant to exercise the uninit-buffer path and, with the `rayon`
    /// feature, presumably the chunked parallel path. The output is 1024 x 1024 f32 = 4 MiB;
    /// confirm against `PARALLEL_THRESHOLD_BYTES`.
    #[test]
    fn test_select_2d_dim0_large() {
        let rows = 2048;
        let cols = 1024;
        // Element k holds the value k (exact in f32 at this size), so expected values can be
        // recomputed from coordinates instead of keeping a second copy of the input.
        let values: Vec<f32> = (0..rows * cols).map(|k| k as f32).collect();
        let tensor = FlexTensor::from_data(TensorData::new(values, [rows, cols]));

        // Every other row, last row first.
        let picked: Vec<i64> = (0..rows as i64).rev().step_by(2).collect();
        let num_picked = picked.len();
        let indices = FlexTensor::from_data(TensorData::new(picked.clone(), [num_picked]));

        let result = select::<f32>(tensor, 0, indices);
        assert_eq!(result.layout().shape().to_vec(), vec![num_picked, cols]);
        let out: Vec<f32> = result.into_data().to_vec().unwrap();
        assert_eq!(out.len(), num_picked * cols);

        for (i, (actual, &row_idx)) in out.chunks_exact(cols).zip(&picked).enumerate() {
            let src_start = row_idx as usize * cols;
            let expected: Vec<f32> = (src_start..src_start + cols).map(|k| k as f32).collect();
            assert_eq!(
                actual,
                expected.as_slice(),
                "mismatch at output row {i} (src row {row_idx})"
            );
        }
    }
}