candle-metal-kernels 0.10.2

Metal kernels for Candle
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
1001
1002
1003
1004
1005
1006
1007
1008
1009
1010
1011
1012
1013
1014
1015
1016
1017
1018
1019
1020
1021
1022
1023
1024
1025
1026
1027
1028
1029
1030
1031
1032
1033
1034
1035
1036
1037
1038
1039
1040
1041
1042
1043
1044
1045
1046
1047
1048
1049
1050
1051
1052
1053
1054
1055
1056
1057
1058
1059
1060
1061
1062
1063
1064
1065
1066
1067
1068
1069
1070
1071
1072
1073
1074
1075
1076
1077
1078
1079
1080
1081
1082
1083
1084
1085
1086
1087
1088
1089
1090
1091
1092
1093
1094
1095
1096
1097
1098
1099
1100
1101
1102
1103
1104
1105
1106
1107
1108
1109
1110
1111
1112
1113
1114
1115
1116
1117
1118
1119
1120
1121
1122
1123
1124
1125
1126
1127
1128
1129
1130
1131
1132
1133
1134
1135
1136
1137
1138
1139
1140
1141
1142
1143
1144
1145
1146
1147
1148
1149
1150
1151
1152
1153
1154
1155
1156
1157
1158
1159
1160
1161
1162
1163
1164
1165
1166
1167
1168
1169
1170
1171
1172
1173
1174
1175
1176
1177
1178
1179
1180
1181
1182
1183
1184
1185
1186
1187
1188
1189
1190
1191
1192
1193
1194
1195
1196
1197
1198
1199
1200
1201
1202
1203
1204
1205
1206
1207
1208
1209
1210
1211
1212
1213
1214
1215
1216
1217
1218
1219
1220
1221
1222
1223
1224
1225
1226
1227
1228
1229
1230
1231
1232
1233
1234
1235
1236
1237
1238
1239
1240
1241
1242
1243
1244
1245
1246
1247
1248
1249
1250
1251
1252
1253
1254
1255
1256
1257
1258
1259
1260
1261
1262
1263
1264
1265
1266
1267
1268
1269
1270
1271
1272
1273
1274
1275
1276
1277
1278
1279
1280
1281
1282
1283
1284
1285
1286
1287
1288
1289
1290
1291
1292
1293
1294
1295
1296
1297
1298
1299
1300
1301
1302
1303
1304
1305
1306
1307
1308
1309
1310
1311
1312
1313
1314
1315
1316
1317
1318
1319
1320
1321
1322
1323
1324
1325
1326
1327
1328
1329
1330
1331
1332
1333
1334
1335
1336
1337
1338
1339
1340
1341
1342
1343
1344
1345
1346
1347
1348
1349
1350
1351
1352
1353
1354
1355
1356
1357
1358
1359
1360
1361
1362
1363
1364
1365
1366
1367
1368
1369
1370
1371
1372
1373
1374
1375
1376
1377
1378
1379
1380
1381
1382
1383
1384
1385
1386
1387
1388
1389
1390
1391
1392
1393
1394
1395
1396
1397
1398
1399
1400
1401
1402
1403
1404
1405
1406
1407
1408
1409
1410
1411
1412
1413
1414
1415
1416
1417
1418
1419
1420
1421
1422
1423
1424
1425
1426
1427
1428
1429
1430
1431
1432
1433
1434
1435
1436
1437
1438
1439
1440
1441
1442
1443
1444
1445
1446
1447
1448
1449
1450
1451
1452
1453
1454
1455
1456
1457
1458
1459
1460
1461
1462
1463
1464
1465
1466
1467
1468
1469
1470
1471
1472
1473
1474
1475
1476
1477
1478
1479
1480
1481
1482
1483
1484
1485
1486
1487
1488
1489
1490
1491
1492
1493
1494
1495
1496
1497
1498
1499
1500
1501
1502
1503
1504
1505
1506
1507
1508
1509
1510
1511
1512
1513
1514
1515
1516
1517
1518
1519
1520
1521
1522
1523
1524
1525
1526
1527
1528
1529
1530
1531
1532
1533
1534
1535
1536
1537
1538
1539
1540
1541
1542
1543
1544
1545
1546
1547
1548
1549
1550
1551
1552
1553
1554
1555
1556
1557
1558
1559
#include <metal_stdlib>
#include <metal_limits>
using namespace metal;

// Ceiling division of the runtime value `x` by the compile-time divisor `Y`.
template<uint Y>
constexpr uint div_ceil(uint x) {
    return (x % Y != 0) ? (x / Y + 1) : (x / Y);
}

// Ceiling division X / Y where both operands are compile-time constants.
template<uint X, uint Y>
constexpr uint div_ceil() {
    return (X % Y != 0) ? (X / Y + 1) : (X / Y);
}

// Returns ceil(8 / sizeof(T)): how many elements of type T fit in 8 bytes,
// rounded up. Presumably the per-thread element count used by callers
// outside this view.
template<typename T>
constexpr uint work_per_thread() {
    return div_ceil<8u, sizeof(T)>();
}

// Maps 0 to 1 and leaves every other value unchanged.
METAL_FUNC uint nonzero(uint n) {
    if (n != 0) {
        return n;
    }
    return 1;
}

// Compile-time variant: yields 1 when N is 0, otherwise N.
template<uint N>
constexpr uint nonzero() {
    return N != 0 ? N : 1;
}

// Returns vec_elements<T>::value, with 0 replaced by 1, narrowed to ushort.
template<typename T>
constexpr ushort granularity() {
    return static_cast<ushort>(nonzero<vec_elements<T>::value>());
}

// Smallest power of two that is >= x.
//
// x == 0 and x == 1 both return 1. The x == 0 case is handled explicitly
// because x - 1 would wrap to 0xFFFFFFFF, making the shift count 32, which is
// undefined. The literal is unsigned so that a result of 2^31 does not
// overflow a signed int.
//
// Precondition: x <= 2^31 (larger inputs have no representable result).
METAL_FUNC uint next_p2(uint x) {
    if (x <= 1) {
        return 1;
    }
    return 1u << (32 - clz(x - 1));
}

// Largest power of two that is <= x.
//
// Returns 0 for x == 0: there is no such power of two, and without the guard
// clz(0) == 32 would make the shift count negative, which is undefined.
// The literal is unsigned so that a result of 2^31 does not overflow a
// signed int.
METAL_FUNC uint prev_p2(uint x) {
    if (x == 0) {
        return 0;
    }
    return 1u << (31 - clz(x));
}

// Byte budget used to cap element counts below (presumably the threadgroup
// memory limit; confirm against the target device).
constant uint MAX_SHARED_MEM = 32767;

// Clamps `n` to at most ceil(MAX_SHARED_MEM / sizeof(T)) elements.
template<typename T>
METAL_FUNC uint max_shared_mem(uint n) {
    const uint limit = div_ceil<MAX_SHARED_MEM, sizeof(T)>();
    return min(n, limit);
}


// Recursive functor that converts a linear index into a strided offset for
// the first D dimensions of `dims` / `strides`.
//
// Each level peels dimension D-1 (the fastest-varying of the D it covers)
// and forwards the quotient to the indexer for D-1 dimensions.
template<ushort D, typename IndexT>
struct strided_indexer {
    constant const IndexT *dims;
    constant const IndexT *strides;
    strided_indexer<D - 1, IndexT> next {dims, strides};

    METAL_FUNC IndexT operator()(IndexT idx) const {
        const IndexT extent = dims[D - 1];
        const IndexT coord = idx % extent;
        const IndexT remaining = idx / extent;
        return coord * strides[D - 1] + next(remaining);
    }
};

// Recursion terminator: with a single dimension left, the remaining index is
// scaled by strides[0]. `dims` is kept only so the aggregate has the same
// layout as the general template; it is not read here.
template<typename IndexT>
struct strided_indexer<1, IndexT> {
    constant const IndexT *dims;
    constant const IndexT *strides;

    METAL_FUNC IndexT operator()(IndexT idx) const {
        const IndexT outer_stride = strides[0];
        return idx * outer_stride;
    }
};

// Strided offset for tensors whose rank exceeds the unrolled depth D.
//
// The trailing dimensions (num_dims-1 down to D) are peeled with a runtime
// loop, innermost first; the remaining quotient then indexes the leading D
// dimensions through the unrolled strided_indexer<D>.
//
// The previous loop peeled dims[num_dims-1-D .. 0] instead, so the trailing
// dimensions were never consumed and the leading ones were applied twice,
// producing wrong offsets for every rank > D.
//
// Parameters:
//   idx      - linear element index
//   num_dims - rank of the tensor; expected to be >= D (for a smaller rank the
//              unrolled indexer reads D entries of dims/strides regardless)
//   dims     - extents, outermost first
//   strides  - per-dimension strides, same order as dims
// Returns the strided offset corresponding to idx.
template<ushort D, typename IndexT>
METAL_FUNC IndexT get_strided_idx_fallback(
    IndexT idx,
    constant const IndexT &num_dims,
    constant const IndexT *dims,
    constant const IndexT *strides
) {
    strided_indexer<D, IndexT> next {dims, strides};

    IndexT strided_i = 0;
    // dim_idx is one past the dimension being peeled so the loop condition
    // stays well-defined for unsigned IndexT.
    for (IndexT dim_idx = num_dims; dim_idx > D; dim_idx--) {
        const IndexT dim = dims[dim_idx - 1];
        strided_i += (idx % dim) * strides[dim_idx - 1];
        idx /= dim;
    }
    return strided_i + next(idx);
}

// Dispatches on the runtime rank to a fully unrolled strided_indexer for
// ranks 1 through 4; every other rank goes through
// get_strided_idx_fallback<4>.
template<typename IndexT>
METAL_FUNC IndexT get_strided_index_t(
    IndexT idx,
    constant const IndexT &num_dims,
    constant const IndexT *dims,
    constant const IndexT *strides
) {
    switch (num_dims) {
        case 1: {
            strided_indexer<1, IndexT> rank1 {dims, strides};
            return rank1(idx);
        }
        case 2: {
            strided_indexer<2, IndexT> rank2 {dims, strides};
            return rank2(idx);
        }
        case 3: {
            strided_indexer<3, IndexT> rank3 {dims, strides};
            return rank3(idx);
        }
        case 4: {
            strided_indexer<4, IndexT> rank4 {dims, strides};
            return rank4(idx);
        }
        default:
            return get_strided_idx_fallback<4, IndexT>(idx, num_dims, dims, strides);
    }
}

// Maps a linear element index to a source offset. STRIDED selects between the
// identity mapping (false) and a dims/strides lookup (true). The primary
// template only exposes the index type alias `I`.
template<typename IndexT, bool STRIDED>
struct indexer_t {
    using I = IndexT;
};

// Contiguous layout: the linear index is already the offset.
template<typename IndexT>
struct indexer_t<IndexT, false> {
    using I = IndexT;

    // Extent of the last dimension; defaults to 0 when not supplied.
    const IndexT last_dim = 0;

    METAL_FUNC IndexT operator()(IndexT i) const {
        return i;
    }
};

// Strided layout: the offset is computed from dims/strides.
template<typename IndexT>
struct indexer_t<IndexT, true> {
    using I = IndexT;

    constant const IndexT &num_dims;
    constant const IndexT *dims;
    constant const IndexT *strides;
    // Extent of the last dimension, supplied by the caller.
    const IndexT last_dim;

    METAL_FUNC IndexT operator()(IndexT i) const {
        const IndexT offset = get_strided_index_t(i, num_dims, dims, strides);
        return offset;
    }
};

// Division functor. The generic overload uses operator/; float, half and
// (when available) bfloat have dedicated overloads that call Metal's divide
// routines.
struct Divide {
    template<typename T>
    METAL_FUNC T operator()(T num, T den) {
        return num / den;
    }

    METAL_FUNC float operator()(float num, float den) {
        return fast::divide(num, den);
    }

    METAL_FUNC half operator()(half num, half den) {
        return divide(num, den);
    }

    #if defined(__HAVE_BFLOAT__)
    METAL_FUNC bfloat operator()(bfloat num, bfloat den) {
        return static_cast<bfloat>(fast::divide(num, den));
    }
    #endif
};

// Exponential functor. The generic and float overloads call fast::exp, the
// half overload calls exp, and the bfloat overload (when available) calls
// fast::exp and casts the result back to bfloat.
struct Exp {
    template<typename T>
    METAL_FUNC T operator()(T x) {
        return fast::exp(x);
    }

    METAL_FUNC float operator()(float x) {
        return fast::exp(x);
    }

    METAL_FUNC half operator()(half x) {
        return exp(x);
    }

    #if defined(__HAVE_BFLOAT__)
    METAL_FUNC bfloat operator()(bfloat x) {
        return static_cast<bfloat>(fast::exp(x));
    }
    #endif
};


// Keeps track of the index of the value in the reduction operation (argmin, argmax, etc.)
// and the value itself. The index is also used to break ties in the reduction operation.
template <typename T>
struct indexed {
    // Position of `val`; the comparison operators prefer the lower index on equal values.
    uint i;
    // The element value being compared.
    T val;

    // Defaulted constructor qualified for the threadgroup address space
    // (presumably required so `threadgroup indexed<T>` arrays can be declared).
    constexpr indexed<T>() threadgroup = default;
};

// Type trait: `value` is true only for instantiations of indexed<T>.
// Primary template: any other type is not indexed.
template <typename T>
struct is_indexed_type {
    static constant constexpr bool value = false;
};

// Variable-template shorthand for is_indexed_type<T>::value.
template <typename T>
constexpr constant bool is_indexed_t = is_indexed_type<T>::value;

// Specialization matching indexed<T>.
template <typename T>
struct is_indexed_type<indexed<T>> {
    static constant constexpr bool value = true;
};

// Negation of is_indexed_t; used to select the non-indexed specializations.
template <typename T>
constexpr constant bool not_indexed_t = !is_indexed_t<T>;

// Orders by value; when the values are equal the lower index compares as less.
template<typename T>
constexpr METAL_FUNC bool operator<(indexed<T> lhs, indexed<T> rhs) {
    return (lhs.val == rhs.val) ? (lhs.i < rhs.i) : (lhs.val < rhs.val);
}

// Orders by value; when the values are equal the LOWER index compares as
// greater, so a max-style reduction keeps the first occurrence.
template<typename T>
constexpr METAL_FUNC bool operator>(indexed<T> lhs, indexed<T> rhs) {
    return (lhs.val == rhs.val) ? (lhs.i < rhs.i) : (lhs.val > rhs.val);
}

// Limits for indexed<T>, pairing index 0 with the corresponding limit of T.
// Min<T>::init() and Max<T>::init() call numeric_limits<T>::max()/lowest();
// this specialization is presumably what makes those calls resolve for
// indexed types (assumes numeric_limits is built on _numeric_limits_impl in
// <metal_limits> - TODO confirm).
template<typename T>
struct _numeric_limits_impl<indexed<T>> {
    // Index 0 with the lowest finite value of T.
    static constexpr METAL_FUNC indexed<T> lowest() {
        return indexed<T>{ 0, numeric_limits<T>::lowest() };
    }

    // Index 0 with the maximum value of T.
    static constexpr METAL_FUNC indexed<T> max() {
        return indexed<T>{ 0, numeric_limits<T>::max() };
    }
};

#if __METAL_VERSION__ >= 220
// 64-bit overload: reinterprets the value as a uint2, shuffles that, and
// reinterprets the result back to int64_t.
METAL_FUNC int64_t simd_shuffle_down(int64_t data, uint16_t delta) {
    const uint2 halves = as_type<uint2>(data);
    const uint2 shuffled = simd_shuffle_down(halves, delta);
    return as_type<int64_t>(shuffled);
}
#endif


#if defined(__HAVE_BFLOAT__)
// Metal provides no simd_shuffle_down for bfloat, so the bits are shuffled
// as a ushort and reinterpreted back.
METAL_FUNC bfloat simd_shuffle_down(bfloat value, ushort delta) {
    const ushort bits = as_type<ushort>(value);
    const ushort shuffled = simd_shuffle_down(bits, delta);
    return as_type<bfloat>(shuffled);
}
#endif

// Shuffles an indexed<T> by shuffling its index and its value separately
// with the same delta.
template <typename T>
METAL_FUNC indexed<T> simd_shuffle_down(indexed<T> iv, ushort delta) {
    const uint shuffled_i = simd_shuffle_down(iv.i, delta);
    const T shuffled_val = simd_shuffle_down(iv.val, delta);
    return indexed<T> { shuffled_i, shuffled_val };
}

// Sum reduction: identity 0, SIMD-group form simd_sum, pairwise form a + b.
template<typename T>
struct Sum {
    // Identity element of the reduction.
    static constexpr METAL_FUNC T init() {
        return 0;
    }

    // Reduces across the SIMD-group with the built-in simd_sum.
    static METAL_FUNC T simd_op(T lane_value) {
        return simd_sum(lane_value);
    }

    // Combines two partial results.
    template<typename V>
    METAL_FUNC V operator()(V lhs, V rhs) {
        return lhs + rhs;
    }
};

// Product reduction: identity 1, SIMD-group form simd_product, pairwise form a * b.
template<typename T>
struct Mul {
    // Identity element of the reduction.
    static constexpr METAL_FUNC T init() {
        return 1;
    }

    // Reduces across the SIMD-group with the built-in simd_product.
    static METAL_FUNC T simd_op(T lane_value) {
        return simd_product(lane_value);
    }

    // Combines two partial results.
    template<typename V>
    METAL_FUNC V operator()(V lhs, V rhs) {
        return lhs * rhs;
    }
};

// Minimum reduction. Starts from numeric_limits<T>::max(); the SIMD-group
// form is simd_min. The generic pairwise form uses operator<, with dedicated
// overloads for the built-in scalar types.
template<typename T>
struct Min {
    static constexpr METAL_FUNC T init() {
        return numeric_limits<T>::max();
    }

    static METAL_FUNC T simd_op(T lane_value) {
        return simd_min(lane_value);
    }

    // Generic fallback (also used for indexed<T> via its operator<).
    template<typename V>
    METAL_FUNC V operator()(V lhs, V rhs) {
        if (lhs < rhs) {
            return lhs;
        }
        return rhs;
    }

    METAL_FUNC float operator()(float lhs, float rhs) {
        return fast::min(lhs, rhs);
    }

    METAL_FUNC half operator()(half lhs, half rhs) {
        return min(lhs, rhs);
    }

    METAL_FUNC uint operator()(uint lhs, uint rhs) {
        return min(lhs, rhs);
    }

    METAL_FUNC uchar operator()(uchar lhs, uchar rhs) {
        return min(lhs, rhs);
    }

    #if __METAL_VERSION__ >= 220
    METAL_FUNC long operator()(long lhs, long rhs) {
        return min(lhs, rhs);
    }
    #endif

    #if defined(__HAVE_BFLOAT__)
    // bfloat is compared in float precision and cast back.
    METAL_FUNC bfloat operator()(bfloat lhs, bfloat rhs) {
        return static_cast<bfloat>(fast::min(static_cast<float>(lhs), static_cast<float>(rhs)));
    }
    #endif
};

// Maximum reduction. Starts from numeric_limits<T>::lowest(); the SIMD-group
// form is simd_max. The generic pairwise form uses operator>, with dedicated
// overloads for the built-in scalar types.
template<typename T>
struct Max {
    static constexpr METAL_FUNC T init() {
        return numeric_limits<T>::lowest();
    }

    static METAL_FUNC T simd_op(T lane_value) {
        return simd_max(lane_value);
    }

    // Generic fallback (also used for indexed<T> via its operator>).
    template<typename V>
    METAL_FUNC V operator()(V lhs, V rhs) {
        if (lhs > rhs) {
            return lhs;
        }
        return rhs;
    }

    METAL_FUNC float operator()(float lhs, float rhs) {
        return fast::max(lhs, rhs);
    }

    METAL_FUNC half operator()(half lhs, half rhs) {
        return max(lhs, rhs);
    }

    METAL_FUNC uint operator()(uint lhs, uint rhs) {
        return max(lhs, rhs);
    }

    METAL_FUNC uchar operator()(uchar lhs, uchar rhs) {
        return max(lhs, rhs);
    }

    #if __METAL_VERSION__ >= 220
    METAL_FUNC long operator()(long lhs, long rhs) {
        return max(lhs, rhs);
    }
    #endif

    #if defined(__HAVE_BFLOAT__)
    // bfloat is compared in float precision and cast back.
    METAL_FUNC bfloat operator()(bfloat lhs, bfloat rhs) {
        return static_cast<bfloat>(fast::max(static_cast<float>(lhs), static_cast<float>(rhs)));
    }
    #endif
};

// Shorthand for the Metal stdlib trait __is_valid_simdgroup_type<T>.
template <typename T>
constexpr constant bool is_simd_t = __is_valid_simdgroup_type<T>::value;

// Trait marking the types this file can reduce at SIMD-group level, i.e. the
// types for which a simdgroup_reducer specialization is enabled.
// Primary template: not valid unless a specialization below matches.
template <typename T, typename _E = void>
struct is_valid_simd_type {
    static constant constexpr bool value = false;
};

// Variable-template shorthand for is_valid_simd_type<T>::value.
template <typename T>
constexpr constant bool is_valid_simd_t = is_valid_simd_type<T>::value;

// Any type the Metal stdlib already accepts for SIMD-group functions.
template <typename T>
struct is_valid_simd_type<T, typename metal::enable_if_t<is_simd_t<T>>> {
    static constant constexpr bool value = true;
};

// indexed<T> is valid whenever its value type T is.
template <typename T>
struct is_valid_simd_type<indexed<T>, typename metal::enable_if_t<is_valid_simd_t<T>>> {
    static constant constexpr bool value = true;
};

#if __METAL_VERSION__ >= 220
// int64_t is enabled explicitly; this file supplies a simd_shuffle_down
// overload for it under the same version guard.
template <>
struct is_valid_simd_type<int64_t> {
    static constant constexpr bool value = true;
};
#endif

#if defined(__HAVE_BFLOAT__)
// bfloat is enabled explicitly; this file supplies a simd_shuffle_down
// overload for it under the same guard.
template <>
struct is_valid_simd_type<bfloat> {
    static constant constexpr bool value = true;
};
#endif

// Trait: true when the operator type has a usable built-in SIMD-group
// reduction (its static simd_op). simdgroup_reducer uses it to choose between
// calling OP::simd_op and a manual shuffle-down reduction.
// Primary template: no built-in form.
template <typename T, typename _E = void>
struct is_simd_op {
    static constant constexpr bool value = false;
};
// Sum/Mul/Min/Max qualify only when their element type is accepted by the
// Metal stdlib SIMD-group functions (is_simd_t<T>).
template <typename T>
struct is_simd_op<Sum<T>, typename metal::enable_if_t<is_simd_t<T>>> {
    static constant constexpr bool value = true;
};
template <typename T>
struct is_simd_op<Mul<T>, typename metal::enable_if_t<is_simd_t<T>>> {
    static constant constexpr bool value = true;
};
template <typename T>
struct is_simd_op<Min<T>, typename metal::enable_if_t<is_simd_t<T>>> {
    static constant constexpr bool value = true;
};
template <typename T>
struct is_simd_op<Max<T>, typename metal::enable_if_t<is_simd_t<T>>> {
    static constant constexpr bool value = true;
};

// Helper struct for applying operators.
// The overloaded operator() function is used to apply an operator to two values.
template<typename OP, typename T>
struct operation;

// Primary template, used for plain (scalar) values: forwards both operands to
// the wrapped operator. Partial specializations elsewhere in this file add
// extra overloads for wrapper types.
template<typename OP, typename T>
struct operation {
    // Wrapped operator instance (e.g. Sum<T>, Max<T>).
    OP op;

    METAL_FUNC T operator()(T a, T b) {
        return op(a, b);
    }
};

// Specialization for indexed<T>: besides combining two indexed values, it can
// fold in a raw value together with the index it was read at.
template<typename OP, typename T>
struct operation<OP, indexed<T>> {
    OP op;

    // Combines two indexed values with the wrapped operator.
    METAL_FUNC indexed<T> operator()(indexed<T> a, indexed<T> b) {
        return op(a, b);
    }

    // Wraps (idx, b) into an indexed<T> and combines it with `a`.
    METAL_FUNC indexed<T> operator()(indexed<T> a, T b, uint idx) {
        const indexed<T> candidate { idx, b };
        return (*this)(a, candidate);
    }
};

// Reads this thread's share of the source elements and folds them into a
// single per-thread partial result using `operation<OP, R>`.
// Handles both indexed and non-indexed result types through two partial
// specializations selected on R via the trailing `_E` parameter.
//
// Template parameters:
//   T         - element type of the source buffer
//   R         - accumulator/result type
//   OP        - reduction operator
//   BLOCKSIZE - step between consecutive elements visited by one thread
//   Indexer   - functor mapping a linear index to a source offset
//   IndexT    - integer type used for indices
template<
    typename T,
    typename R,
    typename OP,
    ushort BLOCKSIZE,
    typename Indexer,
    typename IndexT,
    typename _E = void
>
struct loader;

// Loader for non-indexed accumulator types.
// Starting at `tid + offset`, visits every BLOCKSIZE-th element up to
// min(el_per_block + offset, src_numel) and folds each source value into the
// running result.
template<
    typename T,
    typename R,
    typename OP,
    ushort BLOCKSIZE,
    typename Indexer,
    typename IndexT
>
struct loader<T, R, OP, BLOCKSIZE, Indexer, IndexT, typename metal::enable_if_t<not_indexed_t<R>>> {
    operation<OP, R> operate;

    // value   - initial accumulator
    // indexer - maps a linear index to an offset into `src`
    // Returns the accumulator after folding in this thread's elements.
    METAL_FUNC R operator()(
        R value,
        Indexer indexer,
        constant IndexT &src_numel,
        constant IndexT &el_per_block,
        device const T *src,
        const IndexT offset,
        const uint tid
    ) {
        const IndexT first = tid + offset;
        // Never read past the end of the block or of the source.
        const IndexT last = min(el_per_block + offset, src_numel);

        R acc = value;
        #pragma clang loop unroll(full)
        for (IndexT pos = first; pos < last; pos += BLOCKSIZE) {
            acc = operate(acc, src[indexer(pos)]);
        }
        return acc;
    }
};

// Loader for indexed accumulator types (arg-reductions).
// Same traversal as the non-indexed loader, but each value is folded in
// together with its position inside the last dimension
// (linear index modulo indexer.last_dim).
template<
    typename T,
    typename R,
    typename OP,
    ushort BLOCKSIZE,
    typename Indexer,
    typename IndexT
>
struct loader<T, R, OP, BLOCKSIZE, Indexer, IndexT, typename metal::enable_if_t<is_indexed_t<R>>> {
    operation<OP, R> operate;

    // value   - initial accumulator
    // indexer - maps a linear index to an offset into `src`; its last_dim
    //           must be non-zero because it is used as a modulus
    // Returns the accumulator after folding in this thread's elements.
    METAL_FUNC R operator()(
        R value,
        Indexer indexer,
        constant IndexT &src_numel,
        constant IndexT &el_per_block,
        device const T *src,
        const IndexT offset,
        const uint tid
    ) {
        const IndexT first = tid + offset;
        // Never read past the end of the block or of the source.
        const IndexT last = min(el_per_block + offset, src_numel);

        R acc = value;
        #pragma clang loop unroll(full)
        for (IndexT pos = first; pos < last; pos += BLOCKSIZE) {
            acc = operate(acc, src[indexer(pos)], pos % indexer.last_dim);
        }
        return acc;
    }
};

// Reduces one value per thread across a SIMD-group. Two partial
// specializations exist, chosen through the trailing `_E` parameter: one that
// calls the operator's built-in simd_op, and one that reduces manually with
// simd_shuffle_down.
template<
    typename OP,
    ushort BLOCKSIZE,
    typename T,
    typename _E = void
>
struct simdgroup_reducer;

// Specialization for built-in simd operations.
// Enabled when is_simd_op<OP> is true and T is a valid SIMD type; delegates
// entirely to OP::simd_op.
template<typename OP, ushort BLOCKSIZE, typename T>
struct simdgroup_reducer<OP, BLOCKSIZE, T, typename metal::enable_if_t<is_simd_op<OP>::value && is_valid_simd_t<T>>> {
    METAL_FUNC T operator()(T value) {
        return OP::simd_op(value);
    }
};

// Specialization for custom (non-built-in) simd operations.
// Combines each thread's value with the value `delta` lanes further down,
// halving delta from 16 to 1. Steps whose span exceeds BLOCKSIZE are skipped
// at compile time (BLOCKSIZE is a template constant).
template<typename OP, ushort BLOCKSIZE, typename T>
struct simdgroup_reducer<OP, BLOCKSIZE, T, typename metal::enable_if_t<!is_simd_op<OP>::value && is_valid_simd_t<T>>> {
    operation<OP, T> op;

    METAL_FUNC T operator()(T value) {
        if (BLOCKSIZE >= 32) value = op(value, simd_shuffle_down(value, 16));
        if (BLOCKSIZE >= 16) value = op(value, simd_shuffle_down(value,  8));
        if (BLOCKSIZE >=  8) value = op(value, simd_shuffle_down(value,  4));
        if (BLOCKSIZE >=  4) value = op(value, simd_shuffle_down(value,  2));
        if (BLOCKSIZE >=  2) value = op(value, simd_shuffle_down(value,  1));
        return value;
    }
};

// Reduces one value per thread to a single value for the whole threadgroup.
//
// For BLOCKSIZE >= 64 the values are staged in threadgroup memory and halved
// tree-style until 64 entries remain; the first 32 threads then combine the
// last two halves and finish with a SIMD-group reduction. For smaller block
// sizes only the SIMD-group reduction runs and `shared` is not touched.
//
// Callers in this file only consume the result returned to tid 0.
//
// The threadgroup barriers use mem_flags::mem_threadgroup: after each barrier
// threads read `shared` entries written by other threads, which needs a
// threadgroup-memory fence and not just an execution barrier (mem_none).
template<typename T, typename OP, ushort BLOCKSIZE>
struct block_reducer {
    simdgroup_reducer<OP, BLOCKSIZE, T> simd_reduce;
    operation<OP, T> operate;
    // Scratch storage with at least BLOCKSIZE entries.
    threadgroup T *shared;

    block_reducer(threadgroup T shared[BLOCKSIZE]) {
        this->shared = shared;
    }

    // value - this thread's partial result
    // tid   - index of the thread within the threadgroup
    METAL_FUNC T operator()(T value, const uint tid) {
        if (BLOCKSIZE >= 64) {
            // Only store in threadgroup shared memory if needed.
            shared[tid] = value;
            // Make every thread's store visible before any thread reads it.
            threadgroup_barrier(mem_flags::mem_threadgroup);
        }

        #pragma clang loop unroll(full)
        for (ushort s = BLOCKSIZE / 2; s >= 64; s >>= 1) {
            if (tid < s) shared[tid] = operate(shared[tid], shared[tid + s]);
            // The next round reads what this round wrote.
            threadgroup_barrier(mem_flags::mem_threadgroup);
        }
        if (tid < 32) {
            // Last shared memory reduce can be done without tid < s check.
            if (BLOCKSIZE >= 64) {
                value = operate(shared[tid], shared[tid + 32]);
                simdgroup_barrier(mem_flags::mem_none);
            }
            // Remaining 32 threads can be reduced with simdgroup_reduce.
            value = simd_reduce(value);
        }
        return value;
    }
};

// Writes a threadgroup's final result to the destination buffer. Specialized
// on whether T is an indexed type; `_E` is the SFINAE selector.
template<typename T, typename _E = void>
struct storer;

// Storer for plain values: thread 0 writes the value itself to dst[dst_id].
template<typename T>
struct storer<T, typename metal::enable_if_t<not_indexed_t<T>>> {
    device T *dst;
    const uint tid;
    const uint dst_id;

    METAL_FUNC void operator()(T value) {
        // Only the first thread of the threadgroup stores.
        if (tid != 0) {
            return;
        }
        dst[dst_id] = value;
    }
};

// Storer for indexed values: thread 0 writes only the index (`value.i`) to
// dst[dst_id]; the value itself is discarded.
template<typename T>
struct storer<T, typename metal::enable_if_t<is_indexed_t<T>>> {
    device uint *dst;
    const uint tid;
    const uint dst_id;

    METAL_FUNC void operator()(T value) {
        // Only the first thread of the threadgroup stores.
        if (tid != 0) {
            return;
        }
        dst[dst_id] = value.i;
    }
};

// Inspired by "Optimizing Parallel Reduction in CUDA" by Mark Harris
//
// Value reduction for one threadgroup: each thread folds its strided share of
// the block's elements, the partials are combined by block_reducer, and
// thread 0 writes the result to dst[dst_id].
//
// Template parameters:
//   T, R      - source element type and accumulator/result type
//   OP        - reduction operator (provides init() and operator())
//   BLOCKSIZE - number of threads per threadgroup this instance is built for
//   Indexer   - linear-index to source-offset functor
//   IndexT    - index type; normally deduced from src_numel. The default
//               refers to the indexer's `I` alias, the only index typedef
//               indexer_t declares.
// Parameters:
//   src_numel    - total number of source elements
//   el_per_block - number of elements reduced by one threadgroup
//   shared       - threadgroup scratch with BLOCKSIZE entries
template<
    typename T,
    typename R,
    typename OP,
    ushort BLOCKSIZE,
    typename Indexer,
    typename IndexT = typename Indexer::I
>
METAL_FUNC void reduce(
    Indexer indexer,
    constant IndexT &src_numel,
    constant IndexT &el_per_block,
    device const T *src,
    device R *dst,
    threadgroup R shared[BLOCKSIZE],
    uint tid [[ thread_index_in_threadgroup ]],
    uint dst_id [[ threadgroup_position_in_grid ]]
) {
    loader<T, R, OP, BLOCKSIZE, Indexer, IndexT> load;
    block_reducer<R, OP, BLOCKSIZE> reduce(shared);
    storer<R> store { dst, tid, dst_id };

    // Calculate offset for the threadgroup of current thread
    const IndexT offset = dst_id * el_per_block;

    // Fold this thread's share of the source into a per-thread partial
    auto value = load(OP::init(), indexer, src_numel, el_per_block, src, offset, tid);

    // Complete reduction
    R result = reduce(value, tid);

    store(result);
}

// Dispatches on the threadgroup size to a compile-time BLOCKSIZE.
// Expands to a switch over max_shared_mem<T>(block_dim) with one CASE_MACRO
// expansion per power of two from 1024 down to 1. Expects `block_dim` to be
// in scope at the expansion site. There is no default case: if the switch
// value is not one of the listed powers of two, nothing is executed.
#define reduce_switch(CASE_MACRO, OP, T, R, INDEXER)    \
    switch (max_shared_mem<T>(block_dim)) {             \
        CASE_MACRO(OP, T, R, 1024, INDEXER)             \
        CASE_MACRO(OP, T, R,  512, INDEXER)             \
        CASE_MACRO(OP, T, R,  256, INDEXER)             \
        CASE_MACRO(OP, T, R,  128, INDEXER)             \
        CASE_MACRO(OP, T, R,   64, INDEXER)             \
        CASE_MACRO(OP, T, R,   32, INDEXER)             \
        CASE_MACRO(OP, T, R,   16, INDEXER)             \
        CASE_MACRO(OP, T, R,    8, INDEXER)             \
        CASE_MACRO(OP, T, R,    4, INDEXER)             \
        CASE_MACRO(OP, T, R,    2, INDEXER)             \
        CASE_MACRO(OP, T, R,    1, INDEXER)             \
    }

// One switch case of a value reduction for block size N.
// Declares N entries of threadgroup scratch and calls the value `reduce`
// overload. The scratch is typed with the result type R because that is what
// `reduce` takes for its `shared` parameter.
// Expects src_numel, el_per_block, src, dst, tid and dst_id to be in scope at
// the expansion site.
#define reduce_case(OP, T, R, N, INDEXER)                               \
case N: {                                                               \
    threadgroup R shared[N];                                            \
    reduce<T, R, OP<R>, N>(                                             \
        INDEXER, src_numel, el_per_block, src, dst, shared, tid, dst_id \
    );                                                                  \
    break;                                                              \
}

// Defines kernel NAME: a value reduction with operator OP over element type T
// that reads the source through the non-strided indexer (offset == linear
// index). `num_dims` and `dims` are part of the kernel signature but are not
// read by this body.
#define impl_reduce_inner(OP, NAME, T)              \
kernel void NAME(                                   \
    constant uint &src_numel,                       \
    constant uint &num_dims,                        \
    constant uint *dims,                            \
    constant uint &el_per_block,                    \
    device const T *src,                            \
    device T *dst,                                  \
    uint tid [[ thread_index_in_threadgroup ]],     \
    uint dst_id [[ threadgroup_position_in_grid ]], \
    uint block_dim [[ threads_per_threadgroup ]]    \
) {                                                 \
    indexer_t<uint, false> indexer;                 \
    reduce_switch(reduce_case, OP, T, T, indexer)   \
}

// Defines kernel NAME_strided: the same value reduction as impl_reduce_inner,
// but source offsets are computed from `dims`/`strides` through the strided
// indexer. The indexer's last_dim is set to dims[num_dims - 1], so num_dims
// must be at least 1.
#define impl_reduce_strided(OP, NAME, T)            \
kernel void NAME##_strided(                         \
    constant uint &src_numel,                       \
    constant uint &num_dims,                        \
    constant uint *dims,                            \
    constant uint *strides,                         \
    constant uint &el_per_block,                    \
    device const T *src,                            \
    device T *dst,                                  \
    uint tid [[ thread_index_in_threadgroup ]],     \
    uint dst_id [[ threadgroup_position_in_grid ]], \
    uint block_dim [[ threads_per_threadgroup ]]    \
) {                                                 \
    indexer_t<uint, true> indexer {                 \
        num_dims, dims, strides, dims[num_dims - 1] \
    };                                              \
    reduce_switch(reduce_case, OP, T, T, indexer)   \
}

// Defines both value-reduction kernels for (OP, T): NAME and NAME_strided.
#define impl_reduce(OP, NAME, T)                    \
impl_reduce_inner(OP, NAME, T)                      \
impl_reduce_strided(OP, NAME, T)

// Arg-reduction (argmin/argmax style) for one threadgroup. Works like the
// value reduction but accumulates indexed<T> pairs and stores only the index
// of the winning element to dst[dst_id].
//
// Template parameters:
//   T           - source element type
//   ReductionOp - operator over indexed<T> (provides init() and operator())
//   BLOCKSIZE   - number of threads per threadgroup this instance is built for
//   Indexer     - linear-index to source-offset functor; its last_dim is used
//                 by the loader to derive each element's index
//   IndexT      - index type; normally deduced from src_numel. The default
//                 refers to the indexer's `I` alias, the only index typedef
//                 indexer_t declares.
// Parameters:
//   src_numel    - total number of source elements
//   el_per_block - number of elements reduced by one threadgroup
//   shared       - threadgroup scratch with BLOCKSIZE entries
template<
    typename T,
    typename ReductionOp,
    ushort BLOCKSIZE,
    typename Indexer,
    typename IndexT = typename Indexer::I
>
METAL_FUNC void reduce(
    Indexer indexer,
    constant IndexT &src_numel,
    constant IndexT &el_per_block,
    device const T *src,
    device uint *dst,
    threadgroup indexed<T> shared[BLOCKSIZE],
    uint tid [[ thread_index_in_threadgroup ]],
    uint dst_id [[ threadgroup_position_in_grid ]]
) {
    using I = indexed<T>;
    loader<T, I, ReductionOp, BLOCKSIZE, Indexer, IndexT> load;
    block_reducer<I, ReductionOp, BLOCKSIZE> reduce(shared);
    storer<I> store { dst, tid, dst_id };

    // Calculate offset for the threadgroup of current thread.
    // Typed as IndexT to match the loader's `offset` parameter.
    const IndexT offset = dst_id * el_per_block;

    // Fold this thread's share of the source into a per-thread partial
    auto value = load(
        ReductionOp::init(),
        indexer,
        src_numel,
        el_per_block,
        src,
        offset,
        tid
    );

    // Complete reduction
    I result = reduce(value, tid);

    // Return index of reduce result
    store(result);
}

// One switch case of an arg-reduction for block size N.
// Declares N entries of threadgroup scratch of indexed<R> and calls the
// arg `reduce` overload with the indexer passed as INDEXER (rather than
// relying on a variable that happens to be called `indexer` at the expansion
// site).
// Expects src_numel, el_per_block, src, dst, tid and dst_id to be in scope at
// the expansion site.
#define arg_reduce_case(OP, T, R, N, INDEXER)           \
case N: {                                               \
    using I = indexed<R>;                               \
    threadgroup I shared[N];                            \
    reduce<T, OP<I>, N>(                                \
        INDEXER,                                        \
        src_numel,                                      \
        el_per_block,                                   \
        src,                                            \
        dst,                                            \
        shared,                                         \
        tid,                                            \
        dst_id);                                        \
    break;                                              \
}

// Defines kernel NAME: an arg-reduction with operator OP over element type T
// that reads the source through the non-strided indexer. The indexer's
// last_dim is set to dims[num_dims - 1] (so num_dims must be at least 1); the
// loader uses it to turn a linear index into an index within the last
// dimension. The result written to `dst` is a uint index.
//
// The macro ends at the closing brace; no line continuation follows it, so
// the definition cannot absorb whatever line comes next.
#define impl_arg_reduce_inner(OP, NAME, T)              \
kernel void NAME(                                       \
    constant uint &src_numel,                           \
    constant uint &num_dims,                            \
    constant uint *dims,                                \
    constant uint &el_per_block,                        \
    device const T *src,                                \
    device uint *dst,                                   \
    uint tid [[ thread_index_in_threadgroup ]],         \
    uint dst_id [[ threadgroup_position_in_grid ]],     \
    uint block_dim [[ threads_per_threadgroup ]]        \
) {                                                     \
    indexer_t<uint, false> indexer {                    \
        dims[num_dims - 1]                              \
    };                                                  \
    reduce_switch(arg_reduce_case, OP, T, T, indexer)   \
}

// Defines kernel NAME_strided: the same arg-reduction as
// impl_arg_reduce_inner, but source offsets are computed from
// `dims`/`strides` through the strided indexer. last_dim is
// dims[num_dims - 1], so num_dims must be at least 1.
#define impl_arg_reduce_strided(OP, NAME, T)            \
kernel void NAME##_strided(                             \
    constant uint &src_numel,                           \
    constant uint &num_dims,                            \
    constant uint *dims,                                \
    constant uint *strides,                             \
    constant uint &el_per_block,                        \
    device const T *src,                                \
    device uint *dst,                                   \
    uint tid [[ thread_index_in_threadgroup ]],         \
    uint dst_id [[ threadgroup_position_in_grid ]],     \
    uint block_dim [[ threads_per_threadgroup ]]        \
) {                                                     \
    indexer_t<uint, true> indexer {                     \
        num_dims, dims, strides, dims[num_dims - 1]     \
    };                                                  \
    reduce_switch(arg_reduce_case, OP, T, T, indexer)   \
}

// Defines both arg-reduction kernels for (OP, T): NAME and NAME_strided.
#define impl_arg_reduce(OP, NAME, T)                    \
impl_arg_reduce_inner(OP, NAME, T)                      \
impl_arg_reduce_strided(OP, NAME, T)

// Contains the intermediate results for the online softmax calculation.
// m: max
// d: sum of the exponentials
template <typename T>
struct MD {
    // Running maximum, kept in the element type.
    T m;
    // Running sum of exponentials, always accumulated in float.
    float d;

    // Defaulted constructors: one unqualified and one qualified for the
    // threadgroup address space (presumably so `threadgroup MD<T>` objects
    // can be declared).
    constexpr MD<T>() = default;
    constexpr MD<T>() threadgroup = default;
};

// operation specialization for the softmax accumulator MD<T>: besides merging
// two accumulators, it can fold in a single raw value.
template<typename OP, typename T>
struct operation<OP, MD<T>> {
    OP op;

    // Merges two partial accumulators with the wrapped operator.
    METAL_FUNC MD<T> operator()(MD<T> a, MD<T> b) {
        return op(a, b);
    }

    // A single element is an accumulator with max == b and sum == 1.
    METAL_FUNC MD<T> operator()(MD<T> a, T b) {
        const MD<T> single { b, static_cast<T>(1.0) };
        return (*this)(a, single);
    }
};

// Shuffles an MD<T> by shuffling its max and its sum separately with the
// same delta.
template <typename T>
METAL_FUNC MD<T> simd_shuffle_down(MD<T> md, ushort delta) {
    const T shuffled_m = simd_shuffle_down(md.m, delta);
    const float shuffled_d = simd_shuffle_down(md.d, delta);
    return MD<T> { shuffled_m, shuffled_d };
}

// Enable simd_shuffle_down for softmax MD
// MD<T> counts as a valid SIMD type whenever T does, which enables the
// shuffle-based simdgroup_reducer specialization for it.
template <typename T>
struct is_valid_simd_type<MD<T>, typename metal::enable_if_t<is_valid_simd_t<T>>> {
    static constant constexpr bool value = true;
};

// Reduction operator that merges two online-softmax accumulators.
// The merged max is the larger of the two; the sum belonging to the smaller
// max is rescaled by exp(smaller.m - bigger.m) before being added.
template<typename T>
struct MDReduceOp {
    Exp fast_exp;

    // Starting accumulator: lowest value of T as max, zero sum.
    static constexpr METAL_FUNC MD<T> init() {
        return MD<T>{ numeric_limits<T>::lowest(), 0 };
    }

    METAL_FUNC MD<T> operator()(MD<T> a, MD<T> b) {
        // On equal maxima `b` is treated as the bigger one.
        MD<T> bigger_m = b;
        MD<T> smaller_m = a;
        if (a.m > b.m) {
            bigger_m = a;
            smaller_m = b;
        }

        MD<T> merged;
        merged.d = bigger_m.d + smaller_m.d * fast_exp(smaller_m.m - bigger_m.m);
        merged.m = bigger_m.m;
        return merged;
    }
};

// Second softmax pass: turns the reduced (max, sum) pair into probabilities.
// Each thread starts at `thread_id` and advances by BLOCKSIZE until it reaches
// `stop_idx`, writing exp(src[i] - max) / sum to dst[i].
template<typename T, ushort BLOCKSIZE>
struct finalize_softmax {
    Divide fast_divide;
    Exp fast_exp;

    METAL_FUNC void operator()(
        device const T *src,
        device T *dst,
        threadgroup MD<T> &md_total,
        const uint thread_id,
        const uint stop_idx
    ) {
        // Both values are loop invariant: read them from threadgroup memory
        // once and divide once instead of once per element.
        const T row_max = md_total.m;
        const float inv_sum = fast_divide(1.0, md_total.d);

        for (uint i = thread_id; i < stop_idx; i += BLOCKSIZE) {
            dst[i] = static_cast<T>(fast_exp(src[i] - row_max) * inv_sum);
        }
    }
};


// Welford's algorithm approach for an online softmax implementation.
// Same as the Online normalizer calculation for softmax: https://arxiv.org/pdf/1805.02867.pdf
//
// One threadgroup handles one block of `el_per_block` elements starting at
// dst_id * el_per_block:
//   1. every thread folds its share of the block into a (max, sum) partial,
//   2. the partials are reduced across the threadgroup,
//   3. thread 0 publishes the result through `md_total`,
//   4. every thread writes exp(x - max) / sum for its share of the block.
//
// src_numel:    total number of elements in `src`; caps the end of the block.
// el_per_block: number of elements handled by each threadgroup.
// src / dst:    input and output buffers, indexed identically.
// shared:       threadgroup scratch handed to the block reducer.
// md_total:     threadgroup slot holding the final (max, sum) pair.
// tid / dst_id: thread index in the threadgroup / threadgroup index in grid.
template<typename T, ushort BLOCKSIZE>
METAL_FUNC void softmax(
    constant uint &src_numel,
    constant uint &el_per_block,
    device const T *src,
    device T *dst,
    threadgroup MD<T> shared[BLOCKSIZE],
    threadgroup MD<T> &md_total,

    uint tid [[ thread_index_in_threadgroup ]],
    uint dst_id [[ threadgroup_position_in_grid ]]
) {
    using MDReduceOp = MDReduceOp<T>;
    using Indexer = indexer_t<uint, false>;
    Indexer indexer;
    loader<T, MD<T>, MDReduceOp, BLOCKSIZE, Indexer, uint> load;
    block_reducer<MD<T>, MDReduceOp, BLOCKSIZE> reduce(shared);
    finalize_softmax<T, BLOCKSIZE> softmax_finalize;

    // Calculate offset for the threadgroup of current thread;
    const uint offset = dst_id * el_per_block;

    // Calculate partial result for current thread
    MD<T> md_partial = MD<T> { numeric_limits<T>::lowest(), 0 };
    md_partial = load(
        md_partial,
        indexer,
        src_numel,
        el_per_block,
        src,
        offset,
        tid
    );

    // Reduce in shared memory
    MD<T> md = reduce(md_partial, tid);

    // Thread 0 publishes the reduced pair for the whole threadgroup.
    // md_total lives in threadgroup memory, so the barrier must also order
    // threadgroup memory accesses: mem_none is an execution-only barrier and
    // does not guarantee that the other threads observe this write.
    if (tid == 0) md_total = md;
    threadgroup_barrier(mem_flags::mem_threadgroup);

    // Finalize softmax
    const uint thread_id = tid + offset;
    const uint stop_idx = min(el_per_block + offset, src_numel);
    softmax_finalize(src, dst, md_total, thread_id, stop_idx);
}

// One `case N:` arm for the switch in impl_softmax. It declares the
// threadgroup storage for block size N (N reduction slots plus the slot for
// the final result) and calls softmax<T, N>. The names src_numel,
// el_per_block, src, dst, tid and dst_id are taken from the scope in which
// the macro is expanded.
#define softmax_case(T, N)                              \
case N: {                                               \
    threadgroup MD<T> shared[N];                        \
    threadgroup MD<T> md_total;                         \
    softmax<T, N>(                                      \
        src_numel,                                      \
        el_per_block,                                   \
        src,                                            \
        dst,                                            \
        shared,                                         \
        md_total,                                       \
        tid,                                            \
        dst_id);                                        \
    break;                                              \
}

// Defines the softmax kernel entry point NAME for element type T.
// The kernel switches on max_shared_mem<T>(block_dim) and forwards to the
// softmax<T, N> instantiation for the matching power of two N in 1024..1.
// There is no default arm: any other value falls through without writing dst.
// NOTE(review): max_shared_mem is defined earlier in the file; presumably it
// maps the launch's threads-per-threadgroup to a supported block size --
// confirm there.
#define impl_softmax(NAME, T)                           \
kernel void NAME(                                       \
    constant uint &src_numel,                           \
    constant uint &el_per_block,                        \
    device const T *src,                                \
    device T *dst,                                      \
    uint tid [[ thread_index_in_threadgroup ]],         \
    uint dst_id [[ threadgroup_position_in_grid ]],     \
    uint block_dim [[ threads_per_threadgroup ]]        \
) {                                                     \
    switch (max_shared_mem<T>(block_dim)) {             \
        softmax_case(T, 1024);                          \
        softmax_case(T,  512);                          \
        softmax_case(T,  256);                          \
        softmax_case(T,  128);                          \
        softmax_case(T,   64);                          \
        softmax_case(T,   32);                          \
        softmax_case(T,   16);                          \
        softmax_case(T,    8);                          \
        softmax_case(T,    4);                          \
        softmax_case(T,    2);                          \
        softmax_case(T,    1);                          \
    }                                                   \
}


// RMS normalisation of one block of `el_to_sum_per_block` elements using a
// shared-memory tree reduction. All arithmetic is done in float and converted
// back to T on store.
//
// src_numel:           total number of elements in `src`; caps the block end.
// el_to_sum_per_block: elements per block; also the divisor of the mean.
// alpha:               optional per-element scale, indexed relative to the
//                      block start; skipped when null.
// eps:                 added to the mean of squares before the square root.
// id:                  unused.
// tid / dst_id:        thread index in the threadgroup / block index.
// block_dim:           step between the elements one thread handles; the
//                      halving reduction only covers every slot when it is a
//                      power of two.
// shared_memory:       threadgroup scratch, indexed by tid.
template<typename T>
METAL_FUNC void rmsnorm(
    constant size_t &src_numel,
    constant size_t &el_to_sum_per_block,
    device const T *src,
    device T *dst,
    device const T *alpha,
    constant float &eps,
    uint id,
    uint tid,
    uint dst_id,
    uint block_dim,
    threadgroup float * shared_memory
) {
    const size_t block_begin = dst_id * el_to_sum_per_block;
    const size_t block_end = min(block_begin + el_to_sum_per_block, src_numel);

    // Per-thread partial sum of squares.
    float partial = 0;
    for (size_t i = block_begin + tid; i < block_end; i += block_dim) {
        const float x = float(src[i]);
        partial = partial + x * x;
    }
    shared_memory[tid] = partial;

    threadgroup_barrier(mem_flags::mem_threadgroup);

    // Pairwise tree reduction; the total ends up in shared_memory[0].
    for (uint half_width = block_dim / 2; half_width > 0; half_width >>= 1) {
        if (tid < half_width) {
            shared_memory[tid] = shared_memory[tid] + shared_memory[tid + half_width];
        }
        threadgroup_barrier(mem_flags::mem_threadgroup);
    }

    /* wait for shared_memory[0] to be filled */
    threadgroup_barrier(mem_flags::mem_threadgroup);

    const float norm = sqrt(shared_memory[0] / float(el_to_sum_per_block) + eps);
    const float inv_norm = 1.0f / norm;

    // Scale every element, applying alpha when one was supplied.
    for (size_t i = block_begin + tid; i < block_end; i += block_dim) {
        float scaled = float(src[i]) * inv_norm;
        if (alpha != nullptr) {
            scaled *= float(alpha[i - block_begin]);
        }
        dst[i] = T(scaled);
    }
}

// Accumulator for the RMS-norm reduction.
// count: number of elements folded in (RMSLoadOp adds 1 per element).
// mean:  despite the name this is a running SUM of squares -- RMSLoadOp adds
//        b.mean * b.mean to it, and rms_norm divides the reduced value by the
//        element count afterwards.
// The member order matters: the struct is brace-initialised as { count, mean }.
template<typename T>
struct RMS {
    uint count;
    T mean;

    // Defaulted constructors; the second overload carries the threadgroup
    // address-space qualifier.
    constexpr RMS<T>() = default;
    constexpr RMS<T>() threadgroup = default;
};

// Load-time functor for RMS norm: folds one raw sample into an accumulator.
template<typename T>
struct RMSLoadOp {
    // Identity element: no samples, zero sum.
    static constexpr METAL_FUNC RMS<T> init() {
        return RMS<T>{ 0, 0 };
    }

    // `b.mean` carries the raw sample; its square is added to the running sum
    // held in `a.mean` and the sample count goes up by one. `b.count` is not
    // read.
    METAL_FUNC RMS<T> operator()(RMS<T> a, RMS<T> b) {
        const T sample = b.mean;
        RMS<T> acc = a;
        acc.mean = a.mean + (sample * sample);
        acc.count = a.count + 1;
        return acc;
    }
};

// Reduce-time functor for RMS norm: merges two partial accumulators.
//
// RMS<T>::mean holds a running sum of squares (see RMSLoadOp), so merging two
// partials is a plain addition of the sums and of the counts.
//
// The previous version carried a leftover Welford-style correction term,
// delta * delta * a.count * (b.count / new_count), with the ratio computed in
// unsigned integer arithmetic. That ratio is 0 whenever a.count > 0, and
// a.count is 0 otherwise, so the term never contributed. It did however
// divide by zero when both partials were empty, and it could turn the result
// into NaN (inf * 0) when delta * delta overflowed. It has been removed.
template<typename T>
struct RMSReduceOp {
    // Identity element: no samples, zero sum.
    static constexpr METAL_FUNC RMS<T> init() {
        return { 0, 0 };
    }

    METAL_FUNC RMS<T> operator()(RMS<T> a, RMS<T> b) {
        a.mean += b.mean;
        a.count += b.count;
        return a;
    }
};

// Specialisation of `operation` for the RMS-norm accumulator RMS<T>.
// It wraps a functor OP and offers two ways of folding into `a`: another
// RMS<T> accumulator, or a single raw value of any type U convertible to T.
template<typename OP, typename T>
struct operation<OP, RMS<T>> {
    OP op;

    // Combine two accumulators by delegating to the wrapped functor.
    METAL_FUNC RMS<T> operator()(RMS<T> a, RMS<T> b) {
        return op(a, b);
    }

    // Fold one raw value into `a`. The value is lifted to an accumulator with
    // count 0 and the converted value stored in `mean`.
    template<typename U>
    METAL_FUNC RMS<T> operator()(RMS<T> a, U b) {
        RMS<T> sample { 0, static_cast<T>(b) };
        return op(a, sample);
    }
};

// simd_shuffle_down overload for RMS<T>: shuffles each member separately
// (count first, then mean) and returns the reassembled accumulator.
template <typename T>
METAL_FUNC RMS<T> simd_shuffle_down(RMS<T> rms, ushort delta) {
    RMS<T> shifted;
    shifted.count = simd_shuffle_down(rms.count, delta);
    shifted.mean = simd_shuffle_down(rms.mean, delta);
    return shifted;
}

// Partial specialisation of is_valid_simd_type for RMS<T>; it is selected
// (through enable_if_t) only when is_valid_simd_t<T> holds for the element
// type, and then reports value == true.
// NOTE(review): presumably this lets the block reducer use the
// simd_shuffle_down overload for RMS<T> -- confirm against the reducer.
template <typename T>
struct is_valid_simd_type<RMS<T>, typename metal::enable_if_t<is_valid_simd_t<T>>> {
    static constant constexpr bool value = true;
};

// Kernels
// RMS normalisation of one block of `el_per_block` elements.
//
// The sum of squares is accumulated in float (RMS<float>) whatever T is.
// Thread 0 turns it into the reciprocal root-mean-square and publishes it
// through `total`; every thread then scales its share of the block in T.
//
// src_numel:    total number of elements in `src`; caps the block end.
// el_per_block: elements per threadgroup; also the divisor of the mean.
// alpha:        optional per-element scale, indexed relative to the block
//               start; skipped when null.
// eps:          added to the mean of squares before the reciprocal root.
// shared:       threadgroup scratch handed to the block reducer.
// total:        threadgroup slot holding rsqrt(mean_of_squares + eps).
// tid / dst_id: thread index in the threadgroup / threadgroup index in grid.
template<
    typename T,
    ushort BLOCKSIZE
>
METAL_FUNC void rms_norm(
    constant uint &src_numel,
    constant uint &el_per_block,
    device const T *src,
    device T *dst,
    device const T *alpha,
    constant float &eps,
    threadgroup RMS<float> shared[BLOCKSIZE],
    threadgroup float &total,

    uint tid [[ thread_index_in_threadgroup ]],
    uint dst_id [[ threadgroup_position_in_grid ]]
) {
    using Indexer = indexer_t<uint, false>;
    using Partial = RMS<float>;
    Indexer indexer;
    Divide fast_divide;
    loader<T, Partial, RMSLoadOp<float>, BLOCKSIZE,  Indexer, uint> load;
    block_reducer<Partial, RMSReduceOp<float>, BLOCKSIZE> reduce(shared);

    // Range of the block owned by this threadgroup, and this thread's first
    // element inside it.
    const uint offset = dst_id * el_per_block;
    const uint stop_idx = min(el_per_block + offset, src_numel);
    const uint first = tid + offset;

    // Per-thread partial, then the reduction across the threadgroup.
    Partial partial = load(
        RMSLoadOp<float>::init(),
        indexer,
        src_numel,
        el_per_block,
        src,
        offset,
        tid
    );
    const Partial reduced = reduce(partial, tid);

    if (tid == 0) {
        total = rsqrt(fast_divide(reduced.mean, float(el_per_block)) + eps);
    }
    threadgroup_barrier(mem_flags::mem_threadgroup);

    // `total` is no longer written after the barrier; convert it once.
    const T scale = static_cast<T>(total);

    if (alpha != nullptr) {
        #pragma clang loop unroll(full)
        for (uint i = first; i < stop_idx; i += BLOCKSIZE) {
            T scaled = src[i] * scale;
            scaled *= alpha[i - offset];
            dst[i] = scaled;
        }
    } else {
        #pragma clang loop unroll(full)
        for (uint i = first; i < stop_idx; i += BLOCKSIZE) {
            dst[i] = src[i] * scale;
        }
    }
}


// One `case N:` arm for the switch in impl_rms_norm. It declares the
// threadgroup storage for block size N (N float accumulators plus the slot
// for the final scale) and calls rms_norm<T, N>. The names src_numel,
// el_per_block, src, dst, alpha, eps, tid and dst_id are taken from the scope
// in which the macro is expanded.
#define rms_norm_case(T, N)                             \
case N: {                                               \
    threadgroup RMS<float> shared[N];                   \
    threadgroup float total;                            \
    rms_norm<T, N>(                                     \
        src_numel,                                      \
        el_per_block,                                   \
        src,                                            \
        dst,                                            \
        alpha,                                          \
        eps,                                            \
        shared,                                         \
        total,                                          \
        tid,                                            \
        dst_id);                                        \
    break;                                              \
}

// Defines the RMS-norm kernel entry point NAME for element type T.
// The kernel switches on max_shared_mem<float>(block_dim) -- float, not T,
// matching the RMS<float> accumulators -- and forwards to the rms_norm<T, N>
// instantiation for the matching power of two N in 1024..1.
// There is no default arm: any other value falls through without writing dst.
// NOTE(review): max_shared_mem is defined earlier in the file; confirm what
// it returns for a given threads-per-threadgroup value.
#define impl_rms_norm(NAME, T)                          \
kernel void NAME(                                       \
    constant uint &src_numel,                           \
    constant uint &el_per_block,                        \
    device const T *src,                                \
    device T *dst,                                      \
    device const T *alpha,                              \
    constant float &eps,                                \
    uint tid [[ thread_index_in_threadgroup ]],         \
    uint dst_id [[ threadgroup_position_in_grid ]],     \
    uint block_dim [[ threads_per_threadgroup ]]        \
) {                                                     \
    switch (max_shared_mem<float>(block_dim)) {         \
        rms_norm_case(T, 1024);                         \
        rms_norm_case(T,  512);                         \
        rms_norm_case(T,  256);                         \
        rms_norm_case(T,  128);                         \
        rms_norm_case(T,   64);                         \
        rms_norm_case(T,   32);                         \
        rms_norm_case(T,   16);                         \
        rms_norm_case(T,    8);                         \
        rms_norm_case(T,    4);                         \
        rms_norm_case(T,    2);                         \
        rms_norm_case(T,    1);                         \
    }                                                   \
}

// Accumulator for the layer-norm reduction (running mean / variance).
// count: number of elements folded in.
// mean:  running mean of those elements.
// m2:    running sum of squared deviations from the mean; layer_norm divides
//        it by count to obtain the variance.
// The member order matters: the struct is brace-initialised as
// { count, mean, m2 }.
template<typename T>
struct LayerNormValue {
    uint count;
    T mean;
    T m2;

    // Defaulted constructors; the second overload carries the threadgroup
    // address-space qualifier.
    constexpr LayerNormValue<T>() = default;
    constexpr LayerNormValue<T>() threadgroup = default;
};

// Load-time functor for layer norm: folds one raw sample into the running
// mean / m2 pair with Welford's single-sample update.
template<typename T>
struct LNLoadOp {
    // Identity element: no samples, zero mean, zero m2.
    static constexpr METAL_FUNC LayerNormValue<T> init() {
        return LayerNormValue<T>{ 0, 0, 0 };
    }

    // `b.mean` carries the raw sample; `b.count` and `b.m2` are not read.
    METAL_FUNC LayerNormValue<T> operator()(LayerNormValue<T> a, LayerNormValue<T> b) {
        const T sample = b.mean;

        a.count += 1;
        // Deviation from the mean before and after the mean is updated.
        const T dev_before = sample - a.mean;
        a.mean += dev_before / a.count;
        const T dev_after = sample - a.mean;
        a.m2 += dev_before * dev_after;
        return a;
    }
};

// Reduce-time functor for layer norm: merges two partial (count, mean, m2)
// accumulators with the pairwise form of Welford's update.
template<typename T>
struct LNReduceOp {
    // Identity element: no samples, zero mean, zero m2.
    static constexpr METAL_FUNC LayerNormValue<T> init() {
        return LayerNormValue<T>{ 0, 0, 0 };
    }

    METAL_FUNC LayerNormValue<T> operator()(LayerNormValue<T> a, LayerNormValue<T> b) {
        // An empty right-hand side changes nothing; returning early also keeps
        // the weight below from becoming 0 / 0 when both sides are empty.
        if (b.count == 0) {
            return a;
        }

        const uint merged_count = a.count + b.count;
        const T b_weight = b.count / T(merged_count);
        const T mean_gap = b.mean - a.mean;

        LayerNormValue<T> merged;
        merged.count = merged_count;
        merged.mean = a.mean + mean_gap * b_weight;
        merged.m2 = a.m2 + (b.m2 + mean_gap * mean_gap * a.count * b_weight);
        return merged;
    }
};

// Specialisation of `operation` for the layer-norm accumulator. It wraps a
// functor OP and offers two ways of folding into `a`: another accumulator, or
// a single raw value of any type U convertible to T.
template<typename OP, typename T>
struct operation<OP, LayerNormValue<T>> {
    OP op;

    // Combine two accumulators by delegating to the wrapped functor.
    METAL_FUNC LayerNormValue<T> operator()(LayerNormValue<T> a, LayerNormValue<T> b) {
        return op(a, b);
    }

    // Fold one raw value into `a`. The value is lifted to an accumulator with
    // count 0 and the converted value stored in both `mean` and `m2`.
    template<typename U>
    METAL_FUNC LayerNormValue<T> operator()(LayerNormValue<T> a, U b) {
        LayerNormValue<T> sample { 0, static_cast<T>(b), static_cast<T>(b) };
        return op(a, sample);
    }
};

// simd_shuffle_down overload for LayerNormValue<T>: shuffles each member
// separately (count, mean, m2 in that order) and returns the reassembled
// accumulator.
template <typename T>
METAL_FUNC LayerNormValue<T> simd_shuffle_down(LayerNormValue<T> lnv, ushort delta) {
    LayerNormValue<T> shifted;
    shifted.count = simd_shuffle_down(lnv.count, delta);
    shifted.mean = simd_shuffle_down(lnv.mean, delta);
    shifted.m2 = simd_shuffle_down(lnv.m2, delta);
    return shifted;
}

// Partial specialisation of is_valid_simd_type for LayerNormValue<T>; it is
// selected (through enable_if_t) only when is_valid_simd_t<T> holds for the
// element type, and then reports value == true.
// NOTE(review): presumably this lets the block reducer use the
// simd_shuffle_down overload for LayerNormValue<T> -- confirm against the
// reducer.
template <typename T>
struct is_valid_simd_type<LayerNormValue<T>, typename metal::enable_if_t<is_valid_simd_t<T>>> {
    static constant constexpr bool value = true;
};

// Kernels
// Layer normalisation of one block of `el_per_block` elements.
//
// Mean and variance are accumulated in float (LayerNormValue<float>) whatever
// T is. Thread 0 publishes the mean through `mu` and the reciprocal standard
// deviation rsqrt(m2 / count + eps) through `sigma`; every thread then
// normalises its share of the block in T and applies the optional affine
// parameters.
//
// src_numel:    total number of elements in `src`; caps the block end.
// el_per_block: number of elements handled by each threadgroup.
// alpha / beta: optional per-element scale and shift, indexed relative to the
//               block start. Either one, both, or neither may be null.
// eps:          added to the variance before the reciprocal root.
// shared:       threadgroup scratch handed to the block reducer.
// mu / sigma:   threadgroup slots holding the mean and 1 / stddev.
// tid / dst_id: thread index in the threadgroup / threadgroup index in grid.
// lane_id:      unused.
template<
    typename T,
    ushort BLOCKSIZE
>
METAL_FUNC void layer_norm(
    constant uint &src_numel,
    constant uint &el_per_block,
    device const T *src,
    device T *dst,
    device const T *alpha,
    device const T *beta,
    constant float &eps,
    threadgroup LayerNormValue<float> shared[BLOCKSIZE],
    threadgroup float &mu,
    threadgroup float &sigma,

    uint tid [[ thread_index_in_threadgroup ]],
    uint dst_id [[ threadgroup_position_in_grid ]],
    uint lane_id [[thread_index_in_simdgroup]]
) {
    using Indexer = indexer_t<uint, false>;
    Indexer indexer;
    Divide fast_divide;
    loader<T, LayerNormValue<float>, LNLoadOp<float>, BLOCKSIZE,  Indexer, uint> load;
    block_reducer<LayerNormValue<float>, LNReduceOp<float>, BLOCKSIZE> reduce(shared);

    // Calculate offset for the threadgroup of current thread
    const uint offset = dst_id * el_per_block;
    const uint stop_idx = min(el_per_block + offset, src_numel);
    const uint idx = tid + offset;

    // Load with reduction from global memory. The seed comes from the functor
    // the loader is instantiated with; it is the same all-zero value as
    // LNReduceOp<float>::init().
    LayerNormValue<float> result = load(
        LNLoadOp<float>::init(),
        indexer,
        src_numel,
        el_per_block,
        src,
        offset,
        tid
    );

    // Complete reduction
    result = reduce(result, tid);
    if (tid == 0) {
        mu = result.mean;
        sigma = rsqrt(fast_divide(result.m2, float(result.count)) + eps);
    }
    threadgroup_barrier(mem_flags::mem_threadgroup);

    // mu and sigma are no longer written after the barrier; convert them once.
    const T mean = static_cast<T>(mu);
    const T inv_std = static_cast<T>(sigma);

    if (alpha == nullptr && beta == nullptr) {
        // No affine parameters at all. Without this branch the beta-only path
        // below would dereference a null beta.
        #pragma clang loop unroll(full)
        for (uint i = idx; i < stop_idx; i += BLOCKSIZE) {
            dst[i] = (src[i] - mean) * inv_std;
        }
    } else if (alpha == nullptr) {
        // Shift only.
        #pragma clang loop unroll(full)
        for (uint i = idx; i < stop_idx; i += BLOCKSIZE) {
            T normalized = (src[i] - mean) * inv_std;
            dst[i] = normalized + beta[i - offset];
        }
    } else if (beta == nullptr) {
        // Scale only.
        #pragma clang loop unroll(full)
        for (uint i = idx; i < stop_idx; i += BLOCKSIZE) {
            T normalized = (src[i] - mean) * inv_std;
            dst[i] = normalized * alpha[i - offset];
        }
    } else {
        // Scale and shift, fused.
        #pragma clang loop unroll(full)
        for (uint i = idx; i < stop_idx; i += BLOCKSIZE) {
            T normalized = (src[i] - mean) * inv_std;
            dst[i] = static_cast<T>(fma(normalized, alpha[i - offset], beta[i - offset]));
        }
    }
}

// One `case N:` arm for the switch in impl_layer_norm. It declares the
// threadgroup storage for block size N (N float accumulators plus the mu and
// sigma slots) and calls layer_norm<T, N>. The names src_numel, el_per_block,
// src, dst, alpha, beta, eps, tid, dst_id and lane_id are taken from the
// scope in which the macro is expanded.
#define layer_norm_case(T, N)                           \
case N: {                                               \
    threadgroup LayerNormValue<float> shared[N];        \
    threadgroup float mu;                               \
    threadgroup float sigma;                            \
    layer_norm<T, N>(                                   \
        src_numel,                                      \
        el_per_block,                                   \
        src,                                            \
        dst,                                            \
        alpha,                                          \
        beta,                                           \
        eps,                                            \
        shared,                                         \
        mu,                                             \
        sigma,                                          \
        tid,                                            \
        dst_id,                                         \
        lane_id);                                       \
    break;                                              \
}

// Defines the layer-norm kernel entry point NAME for element type T.
// The kernel switches on max_shared_mem<float>(block_dim) -- float, not T,
// matching the LayerNormValue<float> accumulators -- and forwards to the
// layer_norm<T, N> instantiation for the matching power of two N in 1024..1.
// There is no default arm: any other value falls through without writing dst.
// NOTE(review): max_shared_mem is defined earlier in the file; confirm what
// it returns for a given threads-per-threadgroup value.
#define impl_layer_norm(NAME, T)                        \
kernel void NAME(                                       \
    constant uint &src_numel,                           \
    constant uint &el_per_block,                        \
    device const T *src,                                \
    device T *dst,                                      \
    device const T *alpha,                              \
    device const T *beta,                               \
    constant float &eps,                                \
    uint tid [[ thread_index_in_threadgroup ]],         \
    uint dst_id [[ threadgroup_position_in_grid ]],     \
    uint lane_id [[thread_index_in_simdgroup]],         \
    uint block_dim [[ threads_per_threadgroup ]]        \
) {                                                     \
    switch (max_shared_mem<float>(block_dim)) {         \
        layer_norm_case(T, 1024);                       \
        layer_norm_case(T,  512);                       \
        layer_norm_case(T,  256);                       \
        layer_norm_case(T,  128);                       \
        layer_norm_case(T,   64);                       \
        layer_norm_case(T,   32);                       \
        layer_norm_case(T,   16);                       \
        layer_norm_case(T,    8);                       \
        layer_norm_case(T,    4);                       \
        layer_norm_case(T,    2);                       \
        layer_norm_case(T,    1);                       \
    }                                                   \
}

// Rotary embedding, interleaved layout: thread `tid` rotates the adjacent
// pair (src[2 * tid], src[2 * tid + 1]) by the angle whose cosine and sine
// are read from cos[rope_idx] and sin[rope_idx].
//
// bh, td:   the kernel covers bh * td elements; threads past that return
//           without touching memory.
// stride_b: when non-zero, (2 * tid) / stride_b selects an additional offset
//           of td / 2 entries into the cos / sin tables; when zero the same
//           td / 2 entries are reused for every pair.
// tid:      index of the pair handled by this thread.
template<typename T>
METAL_FUNC void ropei(
    constant size_t &bh,
    constant size_t &td,
    constant size_t &stride_b,
    device const T *src,
    device const T *cos,
    device const T *sin,
    device T *dst,
    uint tid
) {
    // Widen before doubling: `2 * tid` evaluated in 32-bit uint would wrap
    // for tid >= 2^31 and defeat both the bounds check and the indexing.
    const size_t even = 2 * size_t(tid);
    if (even >= bh * td) {
        return;
    }
    const size_t odd = even + 1;

    size_t rope_idx = tid % (td / 2);
    if (stride_b > 0) {
        const size_t b_idx = even / stride_b;
        rope_idx += b_idx * (td / 2);
    }
    const T c = cos[rope_idx];
    const T s = sin[rope_idx];

    // Read both inputs before writing so the second output never observes
    // the first store, even if dst aliases src.
    const T x0 = src[even];
    const T x1 = src[odd];
    dst[even] = x0 * c - x1 * s;
    dst[odd] = x0 * s + x1 * c;
}

// Rotary embedding, split-half layout: thread `idx` rotates the pair
// (src[i1], src[i1 + d / 2]) by the angle whose cosine and sine are read
// from cos[i_cs] and sin[i_cs].
//
// bh, td:   the kernel covers bh * td elements; threads past that return
//           without touching memory.
// d:        length of the innermost dimension; its two halves are paired.
// stride_b: when non-zero, (2 * idx) / stride_b selects an additional offset
//           of td / 2 entries into the cos / sin tables.
// idx:      index of the pair handled by this thread.
template<typename T>
METAL_FUNC void rope(
    constant size_t &bh,
    constant size_t &td,
    constant size_t &d,
    constant size_t &stride_b,
    device const T *src,
    device const T *cos,
    device const T *sin,
    device T *dst,
    uint idx
) {
    // Widen before doubling: `2 * idx` evaluated in 32-bit uint would wrap
    // for idx >= 2^31 and defeat the bounds check and the batch lookup.
    const size_t pair = size_t(idx);
    if (2 * pair >= bh * td) {
        return;
    }

    // Decompose the pair index into (bh, t, d / 2) coordinates.
    const size_t i_bh = pair / (td / 2);
    const size_t i_td = pair - (td / 2) * i_bh;
    const size_t i_t = i_td / (d / 2);
    const size_t i_d = i_td - (d / 2) * i_t;

    const size_t i1 = i_bh * td + i_t * d + i_d;
    const size_t i2 = i1 + d / 2;

    size_t i_cs = i_t * (d / 2) + i_d;
    if (stride_b > 0) {
        const size_t b_idx = (2 * pair) / stride_b;
        i_cs += b_idx * (td / 2);
    }
    const T c = cos[i_cs];
    const T s = sin[i_cs];

    // Read both inputs before writing so the second output never observes
    // the first store, even if dst aliases src.
    const T x1 = src[i1];
    const T x2 = src[i2];
    dst[i1] = x1 * c - x2 * s;
    dst[i2] = x1 * s + x2 * c;
}

// Rotary embedding, split-half layout, for data indexed as (b, t, h, d):
// thread `idx` rotates the pair (src[i1], src[i1 + d / 2]) by the angle whose
// cosine and sine are read from cos[i_cs] and sin[i_cs]. The table index
// depends on the t coordinate and the position inside the half, not on h.
//
// b, t, h, d: the kernel covers b * t * h * d elements; threads past that
//             return without touching memory.
// stride_b:   when non-zero, (2 * idx) / stride_b selects an additional
//             offset of (t * d) / 2 entries into the cos / sin tables.
// idx:        index of the pair handled by this thread.
template<typename T>
METAL_FUNC void rope_thd(
    constant size_t &b,
    constant size_t &t,
    constant size_t &h,
    constant size_t &d,
    constant size_t &stride_b,
    device const T *src,
    device const T *cos,
    device const T *sin,
    device T *dst,
    uint idx
) {
    // Widen before doubling: `2 * idx` evaluated in 32-bit uint would wrap
    // for idx >= 2^31 and defeat the bounds check and the batch lookup.
    const size_t pair = size_t(idx);
    if (2 * pair >= b * t * h * d) {
        return;
    }

    const size_t i_bth = pair / (d / 2);
    const size_t i_d = pair - (d / 2) * i_bth;
    const size_t i_t = (i_bth / h) % t;
    const size_t i1 = i_bth * d + i_d;
    const size_t i2 = i1 + d / 2;

    size_t i_cs = i_t * (d / 2) + i_d;
    if (stride_b > 0) {
        const size_t b_idx = (2 * pair) / stride_b;
        i_cs += b_idx * ((t * d) / 2);
    }
    const T c = cos[i_cs];
    const T s = sin[i_cs];

    // Read both inputs before writing so the second output never observes
    // the first store, even if dst aliases src.
    const T x1 = src[i1];
    const T x2 = src[i2];
    dst[i1] = x1 * c - x2 * s;
    dst[i2] = x1 * s + x2 * c;
}

// Defines the three rotary-embedding kernel entry points for TYPENAME:
//   FN_NAME_I   -> ropei<TYPENAME>    (adjacent pairs 2 * tid, 2 * tid + 1)
//   FN_NAME     -> rope<TYPENAME>     (pairs i1, i1 + d / 2)
//   FN_NAME_THD -> rope_thd<TYPENAME> (pairs i1, i1 + d / 2, b/t/h/d sizes)
// Each kernel forwards its buffers unchanged and passes
// thread_position_in_grid as the pair index.
// Note that the last line of the macro ends in a continuation backslash, so
// the line that follows the definition must stay blank.
#define ROPE(FN_NAME, FN_NAME_I, FN_NAME_THD, TYPENAME) \
kernel void FN_NAME_I( \
    constant size_t &bh, \
    constant size_t &td, \
    constant size_t &stride_b, \
    device const TYPENAME *src,  \
    device const TYPENAME *cos,  \
    device const TYPENAME *sin,  \
    device TYPENAME *dst, \
    uint tid [[ thread_position_in_grid ]] \
) { \
    ropei<TYPENAME>(bh, td, stride_b, src, cos, sin, dst, tid); \
}\
kernel void FN_NAME( \
    constant size_t &bh, \
    constant size_t &td, \
    constant size_t &d, \
    constant size_t &stride_b, \
    device const TYPENAME *src,  \
    device const TYPENAME *cos,  \
    device const TYPENAME *sin,  \
    device TYPENAME *dst, \
    uint idx [[ thread_position_in_grid ]] \
) { \
    rope<TYPENAME>(bh, td, d, stride_b, src, cos, sin, dst, idx); \
}\
kernel void FN_NAME_THD( \
    constant size_t &b, \
    constant size_t &t, \
    constant size_t &h, \
    constant size_t &d, \
    constant size_t &stride_b, \
    device const TYPENAME *src,  \
    device const TYPENAME *cos,  \
    device const TYPENAME *sin,  \
    device TYPENAME *dst, \
    uint idx [[ thread_position_in_grid ]] \
) { \
    rope_thd<TYPENAME>(b, t, h, d, stride_b, src, cos, sin, dst, idx); \
}\

// Kernel instantiations. Each macro invocation below defines one or more
// kernel entry points; the first identifier argument(s) are the kernel names.

// Normalisation and rotary-embedding kernels for float and half.
impl_rms_norm(rmsnorm_f32, float)
impl_rms_norm(rmsnorm_f16, half)
impl_layer_norm(layernorm_f32, float)
impl_layer_norm(layernorm_f16, half)
ROPE(rope_f32, rope_i_f32, rope_thd_f32, float)
ROPE(rope_f16, rope_i_f16, rope_thd_f16, half)

// Sum / Mul / Max / Min reductions for float, uint, half and uint8_t.
// NOTE(review): impl_reduce is defined earlier in the file.
impl_reduce(Sum, fast_sum_f32, float)
impl_reduce(Sum, fast_sum_u32, uint)
impl_reduce(Sum, fast_sum_f16, half)
impl_reduce(Sum, fast_sum_u8, uint8_t)

impl_reduce(Mul, fast_mul_f32, float)
impl_reduce(Mul, fast_mul_u32, uint)
impl_reduce(Mul, fast_mul_f16, half)
impl_reduce(Mul, fast_mul_u8, uint8_t)

impl_reduce(Max, fast_max_f32, float)
impl_reduce(Max, fast_max_u32, uint)
impl_reduce(Max, fast_max_f16, half)
impl_reduce(Max, fast_max_u8, uint8_t)

impl_reduce(Min, fast_min_f32, float)
impl_reduce(Min, fast_min_u32, uint)
impl_reduce(Min, fast_min_f16, half)
impl_reduce(Min, fast_min_u8, uint8_t)

// Argmin / argmax reductions for the same four element types.
impl_arg_reduce(Min, fast_argmin_f32, float)
impl_arg_reduce(Min, fast_argmin_f16, half)
impl_arg_reduce(Min, fast_argmin_u32, uint)
impl_arg_reduce(Min, fast_argmin_u8, uint8_t)

impl_arg_reduce(Max, fast_argmax_f32, float)
impl_arg_reduce(Max, fast_argmax_f16, half)
impl_arg_reduce(Max, fast_argmax_u32, uint)
impl_arg_reduce(Max, fast_argmax_u8, uint8_t)

// Softmax for float and half.
impl_softmax(softmax_f32, float)
impl_softmax(softmax_f16, half)

// int64_t variants are only emitted when __METAL_VERSION__ >= 220.
#if __METAL_VERSION__ >= 220
impl_reduce(Sum, fast_sum_i64, int64_t)
impl_reduce(Mul, fast_mul_i64, int64_t)
impl_reduce(Min, fast_min_i64, int64_t)
impl_reduce(Max, fast_max_i64, int64_t)

impl_arg_reduce(Min, fast_argmin_i64, int64_t)
impl_arg_reduce(Max, fast_argmax_i64, int64_t)
#endif

// bfloat variants of every kernel family are only emitted when
// __HAVE_BFLOAT__ is defined.
#if defined(__HAVE_BFLOAT__)
impl_reduce(Sum, fast_sum_bf16, bfloat)
impl_reduce(Mul, fast_mul_bf16, bfloat)
impl_reduce(Max, fast_max_bf16, bfloat)
impl_reduce(Min, fast_min_bf16, bfloat)

impl_arg_reduce(Min, fast_argmin_bf16, bfloat)
impl_arg_reduce(Max, fast_argmax_bf16, bfloat)

impl_softmax(softmax_bf16, bfloat)

impl_rms_norm(rmsnorm_bf16, bfloat)
impl_layer_norm(layernorm_bf16, bfloat)
ROPE(rope_bf16, rope_i_bf16, rope_thd_bf16, bfloat)
#endif