burn-flex 0.21.0

A fast, portable CPU backend for the Burn framework
Documentation
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
//! Matrix multiplication via gemm crate.
//!
//! Optimizations:
//! - Strided gemm for f32/f64/f16 avoids copying non-contiguous tensors
//! - Enables parallelism for large matrices (with rayon feature)
//! - Batched matmul parallelized across batch dimension

use alloc::vec;
use alloc::vec::Vec;
use burn_backend::{DType, Element};
use burn_std::{Bytes, Shape, bf16, f16};

use crate::{FlexTensor, Layout};

/// Types that can be used with gemm-based matmul.
/// Only implement for types that `gemm::gemm` dispatches on via TypeId (f32, f64, f16).
///
/// `gemm_call` passes `zero()` and `one()` to `gemm::gemm` as its two scalar
/// coefficients.
trait GemmScalar: Element + bytemuck::Pod {
    /// Returns the additive identity (`0`) of the scalar type.
    fn zero() -> Self;
    /// Returns the multiplicative identity (`1`) of the scalar type.
    fn one() -> Self;
}

/// Single-precision scalars for the gemm kernels.
impl GemmScalar for f32 {
    /// Returns `0.0` as an `f32`.
    fn zero() -> Self {
        0.0_f32
    }

    /// Returns `1.0` as an `f32`.
    fn one() -> Self {
        1.0_f32
    }
}

/// Double-precision scalars for the gemm kernels.
impl GemmScalar for f64 {
    /// Returns `0.0` as an `f64`.
    fn zero() -> Self {
        0.0_f64
    }

    /// Returns `1.0` as an `f64`.
    fn one() -> Self {
        1.0_f64
    }
}

/// Half-precision scalars for the gemm kernels.
impl GemmScalar for f16 {
    /// Returns positive zero as an `f16`.
    fn zero() -> Self {
        f16::ZERO
    }

    /// Returns one as an `f16`.
    fn one() -> Self {
        f16::ONE
    }
}

/// Multiplies two matrix extents, panicking instead of wrapping on overflow.
///
/// # Panics
/// Panics with a `matmul: matrix size overflow` message when `a * b` does not
/// fit in a `usize`.
#[inline]
fn checked_size(a: usize, b: usize) -> usize {
    match a.checked_mul(b) {
        Some(size) => size,
        None => panic!("matmul: matrix size overflow: {a} * {b}"),
    }
}

/// Minimum `M * N * K` operation count at which a single gemm call runs in parallel.
/// 192^3 is roughly 7M ops, sitting between 128x128 (serial) and 256x256 (parallel).
const PARALLEL_THRESHOLD: usize = 192_usize.pow(3);

/// Minimum total operation count (summed over all batches) for batch-level parallelism.
/// Applies when each individual matrix is small but the combined work is large.
#[cfg(feature = "rayon")]
const BATCH_PARALLEL_THRESHOLD: usize = 128_usize.pow(3); // ~2M ops total

/// Chooses the gemm parallelism mode for an `m x k` by `k x n` product.
///
/// Products below `PARALLEL_THRESHOLD` operations always run serially. Larger
/// ones use rayon when the `rayon` feature is enabled and stay serial otherwise.
fn get_parallelism(m: usize, n: usize, k: usize) -> gemm::Parallelism {
    let ops = m.saturating_mul(n).saturating_mul(k);
    if ops < PARALLEL_THRESHOLD {
        return gemm::Parallelism::None;
    }

    #[cfg(feature = "rayon")]
    {
        // 0 = use all available threads
        gemm::Parallelism::Rayon(0)
    }
    #[cfg(not(feature = "rayon"))]
    {
        gemm::Parallelism::None
    }
}

/// Float matrix multiplication entry point: validates shapes, then dispatches on dtype.
///
/// # Panics
/// Panics when the operand dtypes differ, when either operand has fewer than
/// two dimensions, when the contracted dimensions disagree, or when the dtype
/// is not F32, F64, F16 or BF16.
pub fn matmul(lhs: FlexTensor, rhs: FlexTensor) -> FlexTensor {
    assert_eq!(lhs.dtype(), rhs.dtype(), "matmul: dtype mismatch");

    // Shape checks live in their own scope so the borrows end before the
    // tensors are moved into a kernel.
    {
        let lhs_dims = lhs.layout().shape();
        let rhs_dims = rhs.layout().shape();
        let (lhs_ndim, rhs_ndim) = (lhs_dims.num_dims(), rhs_dims.num_dims());

        assert!(lhs_ndim >= 2, "matmul requires at least 2D tensors");
        assert!(rhs_ndim >= 2, "matmul requires at least 2D tensors");

        // lhs[..., M, K] x rhs[..., K, N]: the contracted axis must agree.
        let inner_lhs = lhs_dims[lhs_ndim - 1];
        let inner_rhs = rhs_dims[rhs_ndim - 2];
        assert_eq!(inner_lhs, inner_rhs, "matmul: inner dimensions must match");
    }

    match lhs.dtype() {
        DType::F32 => matmul_gemm::<f32>(lhs, rhs),
        DType::F64 => matmul_gemm::<f64>(lhs, rhs),
        DType::F16 => matmul_gemm::<f16>(lhs, rhs),
        DType::BF16 => matmul_bf16(lhs, rhs),
        _ => panic!("matmul: unsupported dtype {:?}", lhs.dtype()),
    }
}

/// Reads the strides of the two trailing (matrix) dimensions of a layout.
///
/// Returns `(row_stride, col_stride)`, i.e. the second-to-last and last
/// entries of the layout's stride list.
fn get_2d_strides(layout: &Layout) -> (isize, isize) {
    let all = layout.strides();
    let rank = all.len();
    (all[rank - 2], all[rank - 1])
}

/// Compute broadcast batch dimensions for batched matmul.
///
/// The shorter batch shape is left-padded with 1s, then each pair of dimensions
/// is broadcast: equal sizes are kept, and a size of 1 stretches to the other
/// size (including a size of 0).
///
/// Returns `(broadcast_shape, lhs_strides, rhs_strides)` where the strides map
/// an output batch index to an input batch offset, counted in matrices and
/// assuming the input batches are laid out contiguously. A broadcast (size 1)
/// dimension gets stride 0.
///
/// # Panics
/// Panics when a pair of batch dimensions differs and neither is 1. This is
/// checked in release builds too, because callers turn these strides into
/// offsets for slice indexing and raw-pointer reads.
fn broadcast_batch_dims(
    lhs_batch: &[usize],
    rhs_batch: &[usize],
) -> (Vec<usize>, Vec<usize>, Vec<usize>) {
    let max_len = lhs_batch.len().max(rhs_batch.len());
    // Number of implicit leading 1s on each side.
    let lhs_pad = max_len - lhs_batch.len();
    let rhs_pad = max_len - rhs_batch.len();

    let mut broadcast_shape = vec![1usize; max_len];
    let mut lhs_strides = vec![0usize; max_len];
    let mut rhs_strides = vec![0usize; max_len];

    // Walk from the innermost batch dimension outwards, accumulating the
    // contiguous stride of each input as we go.
    let mut lhs_stride = 1usize;
    let mut rhs_stride = 1usize;
    for i in (0..max_len).rev() {
        let ld = if i < lhs_pad { 1 } else { lhs_batch[i - lhs_pad] };
        let rd = if i < rhs_pad { 1 } else { rhs_batch[i - rhs_pad] };
        assert!(
            ld == rd || ld == 1 || rd == 1,
            "matmul: batch dimensions not broadcastable: {:?} vs {:?}",
            lhs_batch,
            rhs_batch
        );

        // A size-1 dimension takes the other side's size. Unlike `max`, this
        // keeps a zero-sized dimension at zero when it meets a 1.
        broadcast_shape[i] = if ld == 1 { rd } else { ld };

        // Stride stays 0 if the dimension is 1 (broadcast), otherwise actual stride.
        if ld != 1 {
            lhs_strides[i] = lhs_stride;
        }
        if rd != 1 {
            rhs_strides[i] = rhs_stride;
        }
        lhs_stride *= ld;
        rhs_stride *= rd;
    }

    (broadcast_shape, lhs_strides, rhs_strides)
}

/// Maps a flat (row-major) output batch index to an input batch offset.
///
/// `b` is decomposed into per-dimension coordinates over `broadcast_shape`,
/// innermost dimension first, and each coordinate is weighted by the matching
/// entry of `strides`.
#[inline]
fn batch_index_to_offset(b: usize, broadcast_shape: &[usize], strides: &[usize]) -> usize {
    let mut rest = b;
    let mut total = 0;
    for (axis, &extent) in broadcast_shape.iter().enumerate().rev() {
        let coord = rest % extent;
        total += coord * strides[axis];
        rest /= extent;
    }
    total
}

/// Builds element-level batch strides for one operand of a broadcast matmul.
///
/// The operand's batch dimensions are right-aligned against a broadcast shape
/// of `broadcast_len` dimensions. Each aligned dimension larger than 1 takes
/// its stride from `layout_strides`, so non-contiguous (transposed/sliced)
/// tensors need no copy. Padded or size-1 (broadcast) dimensions keep stride 0.
fn broadcast_batch_elem_strides(
    batch_shape: &[usize],
    layout_strides: &[isize],
    broadcast_len: usize,
) -> Vec<isize> {
    let batch_ndim = batch_shape.len();
    debug_assert!(broadcast_len >= batch_ndim);

    // Number of leading broadcast dimensions this operand does not have.
    let pad = broadcast_len as isize - batch_ndim as isize;

    let mut strides = vec![0isize; broadcast_len];
    for (pos, slot) in strides.iter_mut().enumerate() {
        let aligned = pos as isize - pad;
        if aligned < 0 {
            continue;
        }
        let dim = aligned as usize;
        if batch_shape[dim] > 1 {
            *slot = layout_strides[dim];
        }
    }
    strides
}

/// Maps a flat (row-major) batch index to a signed element offset.
///
/// `b` is decomposed into per-dimension coordinates over `broadcast_shape`,
/// innermost dimension first, and each coordinate is weighted by the matching
/// entry of `elem_strides`.
#[inline]
fn batch_elem_offset(b: usize, broadcast_shape: &[usize], elem_strides: &[isize]) -> isize {
    let mut rest = b;
    let mut total: isize = 0;
    for (axis, &extent) in broadcast_shape.iter().enumerate().rev() {
        let coord = rest % extent;
        total += coord as isize * elem_strides[axis];
        rest /= extent;
    }
    total
}

// ============================================================================
// Generic gemm-based matmul (f32, f64, f16)
// ============================================================================

/// Routes a gemm-backed matmul to the 2D kernel when both operands are plain
/// matrices, and to the batched kernel for every other rank combination.
fn matmul_gemm<T: GemmScalar>(lhs: FlexTensor, rhs: FlexTensor) -> FlexTensor {
    let ranks = (
        lhs.layout().shape().num_dims(),
        rhs.layout().shape().num_dims(),
    );

    match ranks {
        (2, 2) => matmul_2d_strided::<T>(lhs, rhs),
        _ => matmul_batched_gemm::<T>(lhs, rhs),
    }
}

/// 2D matmul with strided support: [M, K] x [K, N] -> [M, N]
///
/// Both operands are read through their layout strides and start offsets, so
/// no contiguous copy of the inputs is made here. The result is written into a
/// newly created `[M, N]` tensor of dtype `T::dtype()`.
fn matmul_2d_strided<T: GemmScalar>(lhs: FlexTensor, rhs: FlexTensor) -> FlexTensor {
    let lhs_shape = lhs.layout().shape();
    let rhs_shape = rhs.layout().shape();

    let m = lhs_shape[0];
    let k = lhs_shape[1];
    let n = rhs_shape[1];

    let (lhs_row_stride, lhs_col_stride) = get_2d_strides(lhs.layout());
    let (rhs_row_stride, rhs_col_stride) = get_2d_strides(rhs.layout());

    let lhs_data: &[T] = lhs.storage();
    let rhs_data: &[T] = rhs.storage();
    // NOTE(review): the pointer arithmetic assumes each layout's start offset
    // lies inside its storage slice; nothing in this function checks that.
    let lhs_ptr = unsafe { lhs_data.as_ptr().add(lhs.layout().start_offset()) };
    let rhs_ptr = unsafe { rhs_data.as_ptr().add(rhs.layout().start_offset()) };

    let out_shape = Shape::from(vec![m, n]);
    let mut output = FlexTensor::empty(out_shape, T::dtype());
    let out_data: &mut [T] = output.storage_mut();

    let parallelism = get_parallelism(m, n, k);

    // The output is addressed with column stride 1 and row stride `n`, i.e. as
    // a row-major [M, N] matrix starting at the beginning of `out_data`.
    // NOTE(review): presumably `FlexTensor::empty` provides at least `m * n`
    // elements with offset 0 — confirm against its implementation.
    unsafe {
        gemm_call(
            m,
            n,
            k,
            out_data.as_mut_ptr(),
            1,
            n as isize,
            lhs_ptr,
            lhs_col_stride,
            lhs_row_stride,
            rhs_ptr,
            rhs_col_stride,
            rhs_row_stride,
            parallelism,
        );
    }

    output
}

/// Strided gemm call for one matrix. Wraps `gemm::gemm` with GemmScalar zero/one.
///
/// `m`, `n`, `k` are the matrix dimensions used by the callers in this file as
/// `[m, k] x [k, n] -> [m, n]`. For each of `out`, `lhs` and `rhs`, the `_cs`
/// argument is the column stride and the `_rs` argument is the row stride, both
/// in elements.
///
/// # Safety
/// The caller must ensure that every element reachable from `out`, `lhs` and
/// `rhs` through the given dimensions and strides is valid for the access
/// `gemm::gemm` performs (writes for `out`, reads for `lhs` and `rhs`).
#[inline]
#[allow(clippy::too_many_arguments)]
unsafe fn gemm_call<T: GemmScalar>(
    m: usize,
    n: usize,
    k: usize,
    out: *mut T,
    out_cs: isize,
    out_rs: isize,
    lhs: *const T,
    lhs_cs: isize,
    lhs_rs: isize,
    rhs: *const T,
    rhs_cs: isize,
    rhs_rs: isize,
    parallelism: gemm::Parallelism,
) {
    // NOTE(review): the positional meaning of the literals below follows the
    // gemm crate's documented signature — presumably `read_dst` after the
    // output strides, then `alpha`/`beta` after the rhs strides, then the three
    // conjugation flags. Confirm against the pinned gemm version.
    unsafe {
        gemm::gemm(
            m,
            n,
            k,
            out,
            out_cs,
            out_rs,
            false,
            lhs,
            lhs_cs,
            lhs_rs,
            rhs,
            rhs_cs,
            rhs_rs,
            T::zero(),
            T::one(),
            false,
            false,
            false,
            parallelism,
        );
    }
}

/// Batched matmul: [B..., M, K] x [B..., K, N] -> [B..., M, N]
/// Supports broadcasting on batch dimensions and strided (non-contiguous) inputs.
///
/// Each output batch index is mapped back to an element offset in each input
/// using that input's own layout strides (0 for broadcast dimensions), and one
/// gemm call is issued per output matrix. The output matrices are written
/// back to back, `m * n` elements apart.
fn matmul_batched_gemm<T: GemmScalar>(lhs: FlexTensor, rhs: FlexTensor) -> FlexTensor {
    let lhs_shape = lhs.layout().shape();
    let rhs_shape = rhs.layout().shape();
    let lhs_rank = lhs_shape.num_dims();
    let rhs_rank = rhs_shape.num_dims();

    let m = lhs_shape[lhs_rank - 2];
    let k = lhs_shape[lhs_rank - 1];
    let n = rhs_shape[rhs_rank - 1];

    // Everything before the trailing two (matrix) dimensions is batch.
    let lhs_batch: Vec<usize> = lhs_shape[..lhs_rank - 2].to_vec();
    let rhs_batch: Vec<usize> = rhs_shape[..rhs_rank - 2].to_vec();

    // Only the broadcast shape is needed here; the matrix-count strides are
    // discarded in favour of the element-level strides computed below.
    let (broadcast_shape, _, _) = broadcast_batch_dims(&lhs_batch, &rhs_batch);
    let batch_size: usize = broadcast_shape.iter().product();
    let broadcast_len = broadcast_shape.len();

    let lhs_batch_strides =
        broadcast_batch_elem_strides(&lhs_batch, lhs.layout().strides(), broadcast_len);
    let rhs_batch_strides =
        broadcast_batch_elem_strides(&rhs_batch, rhs.layout().strides(), broadcast_len);

    let (lhs_row_stride, lhs_col_stride) = get_2d_strides(lhs.layout());
    let (rhs_row_stride, rhs_col_stride) = get_2d_strides(rhs.layout());

    let out_matrix_size = checked_size(m, n);

    let mut out_dims = broadcast_shape.clone();
    out_dims.push(m);
    out_dims.push(n);
    let out_shape = Shape::from(out_dims);

    let mut output = FlexTensor::empty(out_shape, T::dtype());

    let lhs_data: &[T] = lhs.storage();
    let rhs_data: &[T] = rhs.storage();
    let lhs_start = lhs.layout().start_offset() as isize;
    let rhs_start = rhs.layout().start_offset() as isize;
    let out_data: &mut [T] = output.storage_mut();

    let per_matrix_ops = m.saturating_mul(n).saturating_mul(k);

    // Closure: run gemm for one batch slice at the given pointers
    let run_one = |out_ptr: *mut T, b: usize, parallelism: gemm::Parallelism| {
        let lhs_off = lhs_start + batch_elem_offset(b, &broadcast_shape, &lhs_batch_strides);
        let rhs_off = rhs_start + batch_elem_offset(b, &broadcast_shape, &rhs_batch_strides);
        // NOTE(review): soundness relies on the computed offsets and strides
        // staying inside the input storage, and on `out_ptr` addressing `m * n`
        // writable elements; neither is checked here.
        unsafe {
            gemm_call::<T>(
                m,
                n,
                k,
                out_ptr,
                1,
                n as isize,
                lhs_data.as_ptr().offset(lhs_off),
                lhs_col_stride,
                lhs_row_stride,
                rhs_data.as_ptr().offset(rhs_off),
                rhs_col_stride,
                rhs_row_stride,
                parallelism,
            );
        }
    };

    // Strategy:
    // 1. Large matrices: let gemm parallelize internally
    // 2. Small matrices, large batch: parallelize batch loop
    // 3. Small total work: single-threaded
    #[cfg(feature = "rayon")]
    {
        let total_ops = batch_size.saturating_mul(per_matrix_ops);
        // With at least 4 batches and enough total work, splitting across
        // batches is chosen even when a single matrix is large.
        let prefer_batch_parallel = batch_size >= 4 && total_ops >= BATCH_PARALLEL_THRESHOLD;

        if per_matrix_ops >= PARALLEL_THRESHOLD && !prefer_batch_parallel {
            let parallelism = gemm::Parallelism::Rayon(0);
            for b in 0..batch_size {
                run_one(out_data[b * out_matrix_size..].as_mut_ptr(), b, parallelism);
            }
        } else if total_ops >= BATCH_PARALLEL_THRESHOLD && batch_size > 1 {
            use rayon::prelude::*;

            // The chunk size is non-zero on this branch: `total_ops` can only
            // reach the threshold when `m`, `n` and `k` are all non-zero.
            out_data
                .par_chunks_mut(out_matrix_size)
                .enumerate()
                .for_each(|(b, out_chunk)| {
                    run_one(out_chunk.as_mut_ptr(), b, gemm::Parallelism::None);
                });
        } else {
            for b in 0..batch_size {
                run_one(
                    out_data[b * out_matrix_size..].as_mut_ptr(),
                    b,
                    gemm::Parallelism::None,
                );
            }
        }
    }

    #[cfg(not(feature = "rayon"))]
    {
        // Only the rayon strategy uses this value; consume it here so the
        // non-rayon build does not report it as unused.
        let _ = per_matrix_ops;
        for b in 0..batch_size {
            run_one(
                out_data[b * out_matrix_size..].as_mut_ptr(),
                b,
                gemm::Parallelism::None,
            );
        }
    }

    output
}

// ============================================================================
// bf16 matmul (via f32 conversion)
// ============================================================================

/// bf16 matmul computed in f32.
///
/// Each operand is made contiguous and widened to an f32 tensor, the product
/// is computed by the f32 gemm path, and the result is narrowed back to bf16
/// while keeping the layout of the f32 result.
fn matmul_bf16(lhs: FlexTensor, rhs: FlexTensor) -> FlexTensor {
    // Contiguous bf16 tensor -> contiguous f32 tensor of the same shape.
    let widen = |tensor: FlexTensor| -> FlexTensor {
        let tensor = tensor.to_contiguous();
        let widened: Vec<f32> = tensor
            .storage::<bf16>()
            .iter()
            .map(|v| v.to_f32())
            .collect();
        FlexTensor::new(
            Bytes::from_elems(widened),
            Layout::contiguous(tensor.layout().shape().clone()),
            DType::F32,
        )
    };

    // Compute matmul in f32
    let product = matmul_gemm::<f32>(widen(lhs), widen(rhs));

    // Convert f32 -> bf16
    let narrowed: Vec<bf16> = product
        .storage::<f32>()
        .iter()
        .copied()
        .map(bf16::from_f32)
        .collect();

    FlexTensor::new(
        Bytes::from_elems(narrowed),
        product.layout().clone(),
        DType::BF16,
    )
}

// ============================================================================
// Integer matmul (naive, with optional SIMD for i32)
// ============================================================================

/// Integer matrix multiplication entry point: validates shapes, then dispatches on dtype.
///
/// # Panics
/// Panics when the operand dtypes differ, when either operand has fewer than
/// two dimensions, when the contracted dimensions disagree, or when the dtype
/// is neither I32 nor I64.
pub fn int_matmul(lhs: FlexTensor, rhs: FlexTensor) -> FlexTensor {
    assert_eq!(lhs.dtype(), rhs.dtype(), "int_matmul: dtype mismatch");

    // Shape checks live in their own scope so the borrows end before the
    // tensors are moved into a kernel.
    {
        let lhs_dims = lhs.layout().shape();
        let rhs_dims = rhs.layout().shape();
        let (lhs_ndim, rhs_ndim) = (lhs_dims.num_dims(), rhs_dims.num_dims());

        assert!(lhs_ndim >= 2, "int_matmul requires at least 2D tensors");
        assert!(rhs_ndim >= 2, "int_matmul requires at least 2D tensors");

        let inner_lhs = lhs_dims[lhs_ndim - 1];
        let inner_rhs = rhs_dims[rhs_ndim - 2];
        assert_eq!(inner_lhs, inner_rhs, "int_matmul: inner dimensions must match");
    }

    match lhs.dtype() {
        DType::I32 => matmul_i32(lhs, rhs),
        DType::I64 => matmul_i64(lhs, rhs),
        _ => panic!("int_matmul: unsupported dtype {:?}", lhs.dtype()),
    }
}

/// i32 matmul front end: makes both operands contiguous, then picks the 2D
/// kernel when both are plain matrices and the batched kernel otherwise.
fn matmul_i32(lhs: FlexTensor, rhs: FlexTensor) -> FlexTensor {
    let lhs = lhs.to_contiguous();
    let rhs = rhs.to_contiguous();

    let ranks = (
        lhs.layout().shape().num_dims(),
        rhs.layout().shape().num_dims(),
    );

    match ranks {
        (2, 2) => matmul_2d_i32(&lhs, &rhs),
        _ => matmul_batched_i32(lhs, rhs),
    }
}

/// 2D i32 matmul: [M, K] x [K, N] -> [M, N]
///
/// The right-hand matrix is first copied into transposed `[N, K]` order so
/// that every output element is a dot product of two contiguous slices.
/// Storage is indexed as row-major starting at element 0.
fn matmul_2d_i32(lhs: &FlexTensor, rhs: &FlexTensor) -> FlexTensor {
    let lhs_dims = lhs.layout().shape();
    let rhs_dims = rhs.layout().shape();
    let (rows, inner, cols) = (lhs_dims[0], lhs_dims[1], rhs_dims[1]);

    let a: &[i32] = lhs.storage();
    let b: &[i32] = rhs.storage();

    // b_t[c][t] = b[t][c]
    let mut b_t = vec![0i32; inner * cols];
    for c in 0..cols {
        for t in 0..inner {
            b_t[c * inner + t] = b[t * cols + c];
        }
    }

    let mut product = vec![0i32; rows * cols];
    for r in 0..rows {
        let a_row = &a[r * inner..(r + 1) * inner];
        for c in 0..cols {
            let b_col = &b_t[c * inner..(c + 1) * inner];
            product[r * cols + c] = dot_i32(a_row, b_col);
        }
    }

    FlexTensor::new(
        Bytes::from_elems(product),
        Layout::contiguous(Shape::from(vec![rows, cols])),
        DType::I32,
    )
}

/// Dot product for i32 slices. Uses macerator SIMD when the `simd` feature is enabled.
///
/// Exactly one of the two cfg-gated blocks below survives compilation and
/// becomes the function's return expression. The equal-length requirement is
/// only checked in builds with debug assertions enabled.
#[inline]
fn dot_i32(a: &[i32], b: &[i32]) -> i32 {
    debug_assert_eq!(a.len(), b.len());

    #[cfg(feature = "simd")]
    {
        dot_i32_simd(a, b)
    }

    #[cfg(not(feature = "simd"))]
    {
        dot_i32_scalar(a, b)
    }
}

/// Scalar fallback for the i32 dot product, using wrapping arithmetic.
///
/// Walks `a.len()` element pairs; panics if `b` is shorter than `a`.
#[cfg(not(feature = "simd"))]
#[inline]
fn dot_i32_scalar(a: &[i32], b: &[i32]) -> i32 {
    // Trim `b` to `a`'s length up front; this is where a too-short `b` panics.
    let b = &b[..a.len()];
    a.iter()
        .zip(b)
        .fold(0i32, |acc, (&x, &y)| acc.wrapping_add(x.wrapping_mul(y)))
}

/// SIMD i32 dot product built on macerator.
///
/// Processes `a` and `b` in vector-width blocks, then finishes the leftover
/// elements with a scalar wrapping multiply-add loop.
#[cfg(feature = "simd")]
#[macerator::with_simd]
fn dot_i32_simd<S: macerator::Simd>(a: &[i32], b: &[i32]) -> i32 {
    use macerator::{Scalar, VMulAdd, vload_unaligned};

    let lanes = i32::lanes::<S>();
    let len = a.len();
    // Largest multiple of `lanes` that does not exceed `len`.
    let simd_len = len / lanes * lanes;
    let mut acc = 0i32.splat::<S>();

    let mut i = 0;
    while i < simd_len {
        // Each load starts at `i` with `i + lanes <= simd_len <= a.len()`.
        // NOTE(review): the load from `b` is only in bounds if `b` is at least
        // as long as `a`; this function does not check that itself.
        let va = unsafe { vload_unaligned(a.as_ptr().add(i)) };
        let vb = unsafe { vload_unaligned(b.as_ptr().add(i)) };
        // NOTE(review): presumably a per-lane `va * vb + acc`; confirm the
        // overflow behaviour against the macerator documentation.
        acc = i32::vmul_add(va, vb, acc);
        i += lanes;
    }

    // Collapse the vector accumulator, then handle the tail elements.
    let mut sum = acc.reduce_add();
    while i < len {
        sum = sum.wrapping_add(a[i].wrapping_mul(b[i]));
        i += 1;
    }
    sum
}

/// Batched i32 matmul: [B..., M, K] x [B..., K, N] -> [B..., M, N]
///
/// Uses naive triple-loop with SIMD dot product and batch-level parallelism.
///
/// Both inputs are indexed as contiguous row-major storage starting at element
/// 0 (the caller, `matmul_i32`, makes them contiguous first). Batch dimensions
/// are broadcast; each right-hand matrix is transposed once up front so every
/// output element is a dot product of two contiguous slices.
///
/// # Panics
/// Panics if the batch dimensions are not broadcastable or if any buffer size
/// overflows `usize`.
fn matmul_batched_i32(lhs: FlexTensor, rhs: FlexTensor) -> FlexTensor {
    let lhs_shape = lhs.layout().shape();
    let rhs_shape = rhs.layout().shape();
    let lhs_rank = lhs_shape.num_dims();
    let rhs_rank = rhs_shape.num_dims();

    let m = lhs_shape[lhs_rank - 2];
    let k = lhs_shape[lhs_rank - 1];
    let n = rhs_shape[rhs_rank - 1];

    let lhs_batch: Vec<usize> = lhs_shape[..lhs_rank - 2].to_vec();
    let rhs_batch: Vec<usize> = rhs_shape[..rhs_rank - 2].to_vec();

    let (broadcast_shape, lhs_strides, rhs_strides) = broadcast_batch_dims(&lhs_batch, &rhs_batch);

    let batch_size: usize = broadcast_shape.iter().product();
    let rhs_batch_size: usize = rhs_batch.iter().product();
    let lhs_matrix_size = checked_size(m, k);
    let rhs_matrix_size = checked_size(k, n);
    let out_matrix_size = checked_size(m, n);

    let mut out_dims = broadcast_shape.clone();
    out_dims.push(m);
    out_dims.push(n);
    let out_shape = Shape::from(out_dims);

    let lhs_data: &[i32] = lhs.storage();
    let rhs_data: &[i32] = rhs.storage();

    // Transpose rhs per actual rhs batch: [B_rhs, K, N] -> [B_rhs, N, K].
    // A transposed matrix has the same element count, so source and
    // destination matrices share the same base offset.
    let mut rhs_transposed = vec![0i32; checked_size(rhs_batch_size, rhs_matrix_size)];
    for b in 0..rhs_batch_size {
        let base = b * rhs_matrix_size;
        for i in 0..k {
            for j in 0..n {
                rhs_transposed[base + j * k + i] = rhs_data[base + i * n + j];
            }
        }
    }

    let mut output = vec![0i32; checked_size(batch_size, out_matrix_size)];

    // Computes one [M, N] output matrix for flat batch index `b`.
    let run_one = |b: usize, out_slice: &mut [i32]| {
        let lhs_batch_idx = batch_index_to_offset(b, &broadcast_shape, &lhs_strides);
        let rhs_batch_idx = batch_index_to_offset(b, &broadcast_shape, &rhs_strides);
        let lhs_offset = lhs_batch_idx * lhs_matrix_size;
        let rhs_t_offset = rhs_batch_idx * rhs_matrix_size;

        let lhs_slice = &lhs_data[lhs_offset..lhs_offset + lhs_matrix_size];
        let rhs_t_slice = &rhs_transposed[rhs_t_offset..rhs_t_offset + rhs_matrix_size];

        for i in 0..m {
            let lhs_row = &lhs_slice[i * k..(i + 1) * k];
            for j in 0..n {
                let rhs_col = &rhs_t_slice[j * k..(j + 1) * k];
                out_slice[i * n + j] = dot_i32(lhs_row, rhs_col);
            }
        }
    };

    #[cfg(feature = "rayon")]
    {
        use rayon::prelude::*;
        // `par_chunks_mut` panics on a chunk size of 0, which happens when M or
        // N is 0. The output is empty in that case, so there is nothing to do.
        if out_matrix_size > 0 {
            output
                .par_chunks_mut(out_matrix_size)
                .enumerate()
                .for_each(|(b, out_slice)| run_one(b, out_slice));
        }
    }

    #[cfg(not(feature = "rayon"))]
    {
        for b in 0..batch_size {
            let offset = b * out_matrix_size;
            run_one(b, &mut output[offset..offset + out_matrix_size]);
        }
    }

    FlexTensor::new(
        Bytes::from_elems(output),
        Layout::contiguous(out_shape),
        DType::I32,
    )
}

/// i64 matmul front end: makes both operands contiguous, then picks the 2D
/// kernel when both are plain matrices and the batched kernel otherwise.
fn matmul_i64(lhs: FlexTensor, rhs: FlexTensor) -> FlexTensor {
    let lhs = lhs.to_contiguous();
    let rhs = rhs.to_contiguous();

    let ranks = (
        lhs.layout().shape().num_dims(),
        rhs.layout().shape().num_dims(),
    );

    match ranks {
        (2, 2) => matmul_2d_i64(&lhs, &rhs),
        _ => matmul_batched_i64(lhs, rhs),
    }
}

/// 2D i64 matmul: [M, K] x [K, N] -> [M, N]
///
/// Plain triple loop with wrapping arithmetic. Storage is indexed as row-major
/// starting at element 0.
fn matmul_2d_i64(lhs: &FlexTensor, rhs: &FlexTensor) -> FlexTensor {
    let lhs_dims = lhs.layout().shape();
    let rhs_dims = rhs.layout().shape();
    let (rows, inner, cols) = (lhs_dims[0], lhs_dims[1], rhs_dims[1]);

    let a: &[i64] = lhs.storage();
    let b: &[i64] = rhs.storage();

    let mut product = vec![0i64; rows * cols];
    for r in 0..rows {
        for c in 0..cols {
            product[r * cols + c] = (0..inner).fold(0i64, |acc, t| {
                acc.wrapping_add(a[r * inner + t].wrapping_mul(b[t * cols + c]))
            });
        }
    }

    FlexTensor::new(
        Bytes::from_elems(product),
        Layout::contiguous(Shape::from(vec![rows, cols])),
        DType::I64,
    )
}

/// Batched i64 matmul: [B..., M, K] x [B..., K, N] -> [B..., M, N], with
/// broadcasting over the batch dimensions.
///
/// Plain triple loop per output matrix using wrapping arithmetic. Both inputs
/// are indexed as contiguous row-major storage starting at element 0.
fn matmul_batched_i64(lhs: FlexTensor, rhs: FlexTensor) -> FlexTensor {
    let lhs_shape = lhs.layout().shape();
    let rhs_shape = rhs.layout().shape();
    let lhs_rank = lhs_shape.num_dims();
    let rhs_rank = rhs_shape.num_dims();

    let (m, k, n) = (
        lhs_shape[lhs_rank - 2],
        lhs_shape[lhs_rank - 1],
        rhs_shape[rhs_rank - 1],
    );

    let lhs_batch: Vec<usize> = lhs_shape[..lhs_rank - 2].to_vec();
    let rhs_batch: Vec<usize> = rhs_shape[..rhs_rank - 2].to_vec();

    // Output batch shape plus per-input strides counted in whole matrices.
    let (broadcast_shape, lhs_strides, rhs_strides) = broadcast_batch_dims(&lhs_batch, &rhs_batch);

    let batch_size: usize = broadcast_shape.iter().product();
    let lhs_matrix_size = checked_size(m, k);
    let rhs_matrix_size = checked_size(k, n);
    let out_matrix_size = checked_size(m, n);

    let mut out_dims = broadcast_shape.clone();
    out_dims.extend([m, n]);
    let out_shape = Shape::from(out_dims);

    let a: &[i64] = lhs.storage();
    let b: &[i64] = rhs.storage();

    let mut product = vec![0i64; batch_size * out_matrix_size];

    for batch in 0..batch_size {
        let a_base = batch_index_to_offset(batch, &broadcast_shape, &lhs_strides) * lhs_matrix_size;
        let b_base = batch_index_to_offset(batch, &broadcast_shape, &rhs_strides) * rhs_matrix_size;
        let out_base = batch * out_matrix_size;

        for r in 0..m {
            for c in 0..n {
                product[out_base + r * n + c] = (0..k).fold(0i64, |acc, t| {
                    acc.wrapping_add(a[a_base + r * k + t].wrapping_mul(b[b_base + t * n + c]))
                });
            }
        }
    }

    FlexTensor::new(
        Bytes::from_elems(product),
        Layout::contiguous(out_shape),
        DType::I64,
    )
}

// ============================================================================
// Tests
// ============================================================================

// Tests kept here exercise flex-specific behavior of the matmul kernel:
// dtype-specific storage paths (F64, F16, BF16) that the generic
// FloatElem-parameterized backend-tests cannot reach. Plain contiguous
// F32/I32/I64 matmul, stride-through-matmul variants (transposed /
// swap_dims / broadcast-transposed), and other generic coverage have
// been migrated to burn-backend-tests so they run against every backend.
// When adding new tests, keep them here only if they probe a flex
// dtype-storage path; otherwise add them to
// crates/burn-backend-tests/tests/tensor/float/ops/matmul.rs.
#[cfg(test)]
mod tests {
    use alloc::vec;
    use burn_backend::TensorData;
    use burn_backend::ops::FloatTensorOps;
    use burn_std::{bf16, f16};

    use crate::{Flex, FlexTensor};

    /// F64 path: [[1, 2], [3, 4]] x [[5, 6], [7, 8]] = [[19, 22], [43, 50]].
    #[test]
    fn test_matmul_f64() {
        let lhs = FlexTensor::from_data(TensorData::new(vec![1.0f64, 2.0, 3.0, 4.0], [2, 2]));
        let rhs = FlexTensor::from_data(TensorData::new(vec![5.0f64, 6.0, 7.0, 8.0], [2, 2]));

        let result = Flex::float_matmul(lhs, rhs);
        let values: Vec<f64> = result.into_data().to_vec().unwrap();

        assert_eq!(values, vec![19.0, 22.0, 43.0, 50.0]);
    }

    /// F16 path: same product as the F64 test, compared with a 0.1 tolerance.
    #[test]
    fn test_matmul_f16() {
        let lhs_vals: Vec<f16> = [1.0f32, 2.0, 3.0, 4.0]
            .iter()
            .copied()
            .map(f16::from_f32)
            .collect();
        let rhs_vals: Vec<f16> = [5.0f32, 6.0, 7.0, 8.0]
            .iter()
            .copied()
            .map(f16::from_f32)
            .collect();

        let lhs = FlexTensor::from_data(TensorData::new(lhs_vals, [2, 2]));
        let rhs = FlexTensor::from_data(TensorData::new(rhs_vals, [2, 2]));

        let result = Flex::float_matmul(lhs, rhs);
        let values: Vec<f16> = result.into_data().to_vec().unwrap();

        let expected = [19.0f32, 22.0, 43.0, 50.0];
        for (a, e) in values.iter().zip(expected.iter()) {
            assert!((a.to_f32() - e).abs() < 0.1, "f16 matmul mismatch");
        }
    }

    /// BF16 path: same product as the F64 test, compared with a 0.5 tolerance.
    #[test]
    fn test_matmul_bf16() {
        let lhs_vals: Vec<bf16> = [1.0f32, 2.0, 3.0, 4.0]
            .iter()
            .copied()
            .map(bf16::from_f32)
            .collect();
        let rhs_vals: Vec<bf16> = [5.0f32, 6.0, 7.0, 8.0]
            .iter()
            .copied()
            .map(bf16::from_f32)
            .collect();

        let lhs = FlexTensor::from_data(TensorData::new(lhs_vals, [2, 2]));
        let rhs = FlexTensor::from_data(TensorData::new(rhs_vals, [2, 2]));

        let result = Flex::float_matmul(lhs, rhs);
        let values: Vec<bf16> = result.into_data().to_vec().unwrap();

        let expected = [19.0f32, 22.0, 43.0, 50.0];
        for (a, e) in values.iter().zip(expected.iter()) {
            assert!((a.to_f32() - e).abs() < 0.5, "bf16 matmul mismatch");
        }
    }

    /// Multiplying by a transposed view must give the same result as
    /// multiplying by a contiguous copy of that view.
    #[test]
    fn test_matmul_batched_transposed_f64() {
        // Non-contiguous (swap_dims) batched matmul on the F64 dtype path.
        let q_data = TensorData::new(vec![1.0f64, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0], [2, 2, 2]);
        let k_data = TensorData::new(vec![1.0f64, 0.0, 0.0, 1.0, 2.0, 0.0, 0.0, 2.0], [2, 2, 2]);

        let q = FlexTensor::from_data(q_data.clone());
        let k = FlexTensor::from_data(k_data.clone());
        let k_t = k.transpose(1, 2);
        let result = Flex::float_matmul(q, k_t);

        // Reference: identical inputs, but the transposed operand is copied
        // into contiguous storage first.
        let q2 = FlexTensor::from_data(q_data);
        let k2 = FlexTensor::from_data(k_data)
            .transpose(1, 2)
            .to_contiguous();
        let expected = Flex::float_matmul(q2, k2);

        let values: Vec<f64> = result.into_data().to_vec().unwrap();
        let expected: Vec<f64> = expected.into_data().to_vec().unwrap();
        assert_eq!(values, expected);
    }

    /// F16 variant of the transposed-view versus contiguous-copy comparison.
    #[test]
    fn test_matmul_batched_transposed_f16() {
        // Non-contiguous (swap_dims) batched matmul on the F16 dtype path.
        let f = f16::from_f32;
        let q_data = TensorData::new(
            vec![
                f(1.0),
                f(2.0),
                f(3.0),
                f(4.0),
                f(5.0),
                f(6.0),
                f(7.0),
                f(8.0),
            ],
            [2, 2, 2],
        );
        let k_data = TensorData::new(
            vec![
                f(1.0),
                f(0.0),
                f(0.0),
                f(1.0),
                f(2.0),
                f(0.0),
                f(0.0),
                f(2.0),
            ],
            [2, 2, 2],
        );

        let q = FlexTensor::from_data(q_data.clone());
        let k = FlexTensor::from_data(k_data.clone());
        let k_t = k.transpose(1, 2);
        let result = Flex::float_matmul(q, k_t);

        // Reference: identical inputs, but the transposed operand is copied
        // into contiguous storage first.
        let q2 = FlexTensor::from_data(q_data);
        let k2 = FlexTensor::from_data(k_data)
            .transpose(1, 2)
            .to_contiguous();
        let expected = Flex::float_matmul(q2, k2);

        let values: Vec<f16> = result.into_data().to_vec().unwrap();
        let expected: Vec<f16> = expected.into_data().to_vec().unwrap();
        assert_eq!(values, expected);
    }
}