//! [`KernelRegistry`] — lazy compilation and caching of Metal compute pipelines.
//!
//! MSL shader source is embedded at compile time via `include_str!`. On first
//! access, the source is compiled into a Metal library, the named function is
//! extracted, and a `ComputePipelineState` is created and cached. Subsequent
//! calls return the cached pipeline.
use std::collections::HashMap;
use metal::{ComputePipelineDescriptor, ComputePipelineState, FunctionConstantValues, MTLDataType};
use crate::error::{MlxError, Result};
// MTLDataType numeric values (from metal-rs argument.rs, confirmed in Apple Metal spec):
// Int = 29
// Bool = 53
// These are used when calling set_constant_value_at_index so the Metal runtime
// knows how wide each constant value is.
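// A minimal sketch of how one constant would be passed (illustrative only —
// the exact `set_constant_value_at_index` argument order follows the
// metal-rs `FunctionConstantValues` API of the crate version in use):
//
//     let fcv = FunctionConstantValues::new();
//     let nsg: i32 = 4; // e.g. FC_mul_mv_nsg at function_constant(600)
//     fcv.set_constant_value_at_index(
//         &nsg as *const i32 as *const std::ffi::c_void,
//         MTLDataType::Int,
//         600,
//     );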
/// Registry that lazily compiles and caches Metal compute pipelines from
/// embedded MSL source.
///
/// # Usage
///
/// ```ignore
/// let mut registry = KernelRegistry::new();
/// let pipeline = registry.get_pipeline("elementwise_add_f32", device.metal_device())?;
/// encoder.encode(&pipeline, &buffers, grid, tg);
/// ```
///
/// # Thread Safety
///
/// `get_pipeline` takes `&mut self` (it inserts newly compiled pipelines
/// into the cache), so a registry cannot be shared across threads as-is.
/// If you need concurrent access, wrap it in a `Mutex` (minimal sketch
/// below) or use one registry per thread.
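///
/// A minimal sketch of the `Mutex` approach (the `device`, `encoder`,
/// `buffers`, `grid`, and `tg` handles are illustrative, as in the usage
/// example above):
///
/// ```ignore
/// use std::sync::Mutex;
///
/// let registry = Mutex::new(KernelRegistry::new());
/// {
///     let mut guard = registry.lock().expect("registry mutex poisoned");
///     let pipeline = guard.get_pipeline("rms_norm_f32", device.metal_device())?;
///     encoder.encode(pipeline, &buffers, grid, tg);
/// } // guard drops here; the compiled pipeline stays cached
/// ```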
pub struct KernelRegistry {
/// Cached pipelines keyed by kernel function name.
cache: HashMap<String, ComputePipelineState>,
/// MSL source text keyed by kernel function name.
///
/// Populated at construction time with all embedded shader sources.
sources: HashMap<String, &'static str>,
}
impl KernelRegistry {
/// Create a new registry with all embedded shader sources pre-registered.
///
/// No compilation happens here — shaders are compiled lazily on first use.
pub fn new() -> Self {
let mut sources = HashMap::new();
// Register embedded shader sources.
sources.insert(
"placeholder".into(),
include_str!("shaders/placeholder.metal"),
);
sources.insert(
"quantized_matmul".into(),
include_str!("shaders/quantized_matmul.metal"),
);
sources.insert(
"quantized_matmul_simd".into(),
include_str!("shaders/quantized_matmul.metal"),
);
sources.insert(
"quantized_matmul_simd_bf16".into(),
include_str!("shaders/quantized_matmul.metal"),
);
sources.insert(
"quantized_matmul_simd_bf16_expert".into(),
include_str!("shaders/quantized_matmul.metal"),
);
// GGML block-format quantized mat-vec kernels (ADR-006 Phase 3)
let ggml_src: &'static str =
include_str!("shaders/quantized_matmul_ggml.metal");
sources.insert("kernel_mul_mv_q4_0_f32".into(), ggml_src);
sources.insert("kernel_mul_mv_q8_0_f32".into(), ggml_src);
// ADR-028 iter-368: peer-style NSG=4 NR=2 variant (128 threads/TG).
sources.insert("kernel_mul_mv_q8_0_f32_nr2".into(), ggml_src);
sources.insert("kernel_mul_mv_q6_K_f32".into(), ggml_src);
// ADR-028 iter-309 — q6_K mat-vec with nr0=2 + cached yl[16]
// (peer-pattern port of llama.cpp's `kernel_mul_mv_q6_K_f32_impl`
// with N_R0_Q6_K=2; 4 rows/TG vs baseline's 2). Env-gated via
// `HF2Q_Q6K_MV_NR2=1` in the dispatcher.
sources.insert("kernel_mul_mv_q6_K_f32_nr2".into(), ggml_src);
// ADR-022 Phase 1 — Q5_1 / IQ4_NL dense mat-vec.
sources.insert("kernel_mul_mv_q5_1_f32".into(), ggml_src);
sources.insert("kernel_mul_mv_iq4_nl_f32".into(), ggml_src);
// ADR-013 P7 — Q4_K dense decode mat-vec (port of llama.cpp's
// kernel_mul_mv_q4_K_f32 at ggml-metal.metal:7715-7821).
sources.insert("kernel_mul_mv_q4_K_f32".into(), ggml_src);
// ADR-022 Phase 2 — Q5_K dense mv kernel.
sources.insert("kernel_mul_mv_q5_K_f32".into(), ggml_src);
// GGML block-format quantized matrix-matrix kernels
// (ADR-011 Phase 3 Wave P3a: port of llama.cpp's kernel_mul_mm_<q>_f32).
// Used at prefill m > 8 to reuse each weight tile across a 32-row
// block via threadgroup-staged simdgroup MMA, instead of re-reading
// every block per prompt-token as the mv kernel does.
let ggml_mm_src: &'static str =
include_str!("shaders/quantized_matmul_mm.metal");
sources.insert("kernel_mul_mm_q4_0_f32".into(), ggml_mm_src);
sources.insert("kernel_mul_mm_q8_0_f32".into(), ggml_mm_src);
sources.insert("kernel_mul_mm_q6_K_f32".into(), ggml_mm_src);
// ADR-022 Phase 1 — dense Q5_1 / IQ4_NL mm.
sources.insert("kernel_mul_mm_q5_1_f32".into(), ggml_mm_src);
sources.insert("kernel_mul_mm_iq4_nl_f32".into(), ggml_mm_src);
// ADR-022 Phase 2 — dense Q5_K mm.
sources.insert("kernel_mul_mm_q5_K_f32".into(), ggml_mm_src);
// ADR-022 Phase 3 — dense Q4_K mm.
sources.insert("kernel_mul_mm_q4_K_f32".into(), ggml_mm_src);
// GGML block-format quantized matrix-matrix kernels — tensor API
// variant (ADR-011 Phase 3 Wave P3b-tensor: port of llama.cpp's
// kernel_mul_mm_impl `#ifdef GGML_METAL_HAS_TENSOR` branch).
// Uses Apple's MetalPerformancePrimitives `tensor_ops::matmul2d`
// primitive, which on M3+ dispatches to hardware tensor cores for
// 2-3x the effective FLOP throughput of the simdgroup MMA path.
// Only compiled on devices where the tensor API is available; a
// runtime probe (see `MlxDevice::has_tensor`) gates compilation so
// non-tensor devices transparently fall back to the non-tensor
// `kernel_mul_mm_<q>_f32` kernels.
let ggml_mm_tensor_src: &'static str =
include_str!("shaders/quantized_matmul_mm_tensor.metal");
sources.insert("kernel_mul_mm_q4_0_tensor_f32".into(), ggml_mm_tensor_src);
sources.insert("kernel_mul_mm_q4_0_tensor_bf16_perm021".into(), ggml_mm_tensor_src);
sources.insert("kernel_mul_mm_q6_K_tensor_bf16_perm021".into(), ggml_mm_tensor_src);
sources.insert("kernel_mul_mm_q8_0_tensor_f32".into(), ggml_mm_tensor_src);
sources.insert("kernel_mul_mm_q6_K_tensor_f32".into(), ggml_mm_tensor_src);
// ADR-022 Phase 1 — Q5_1 / IQ4_NL tensor mm.
sources.insert("kernel_mul_mm_q5_1_tensor_f32".into(), ggml_mm_tensor_src);
sources.insert("kernel_mul_mm_iq4_nl_tensor_f32".into(), ggml_mm_tensor_src);
// ADR-022 Phase 2 — Q5_K tensor mm.
sources.insert("kernel_mul_mm_q5_K_tensor_f32".into(), ggml_mm_tensor_src);
// ADR-022 Phase 3 — Q4_K tensor mm + Q8_0 perm021.
sources.insert("kernel_mul_mm_q4_K_tensor_f32".into(), ggml_mm_tensor_src);
sources.insert("kernel_mul_mm_q8_0_tensor_bf16_perm021".into(), ggml_mm_tensor_src);
// ADR-029 iter-30 H29-speed — F16-weight V2 large-tile mm.
// Same source file as the V2 quantized variants; reads F16 weight
// directly from device memory (no per-call dequant). Used when
// MlxQWeight.f16_shadow is populated and m > MM_ROUTING_THRESHOLD.
sources.insert("hf2q_mul_mm_tensor_v2_f16".into(), ggml_mm_tensor_src);
// ADR-029 iter-36 H28-D — F16-weight perm021 mm for O-projection.
// Same source file; reads F16 weight from MlxQWeight.f16_shadow when
// populated, bypassing the per-call quantized dequant. B-stage
// (bfloat permuted [n_heads, seq_len, head_dim] input) is byte-
// identical to the quantized variant.
sources.insert("kernel_mul_mm_f16_tensor_bf16_perm021".into(), ggml_mm_tensor_src);
// ADR-029 iter-23 H28-A — V2 large-tile tensor mm (NRA=64 M, NRB=128 N).
// Same source file as V1 tensor mm; distinct kernel host names so the
// dispatcher can pick V1 vs V2 at runtime via HF2Q_LARGE_TILE_MM.
sources.insert("kernel_mul_mm_q4_0_tensor_v2_f32".into(), ggml_mm_tensor_src);
sources.insert("kernel_mul_mm_q8_0_tensor_v2_f32".into(), ggml_mm_tensor_src);
sources.insert("kernel_mul_mm_q6_K_tensor_v2_f32".into(), ggml_mm_tensor_src);
sources.insert("kernel_mul_mm_q5_1_tensor_v2_f32".into(), ggml_mm_tensor_src);
sources.insert("kernel_mul_mm_iq4_nl_tensor_v2_f32".into(), ggml_mm_tensor_src);
sources.insert("kernel_mul_mm_q5_K_tensor_v2_f32".into(), ggml_mm_tensor_src);
sources.insert("kernel_mul_mm_q4_K_tensor_v2_f32".into(), ggml_mm_tensor_src);
// ADR-029 iter-28 H29 — whole-tensor dequant from block_q → F16.
// Used at model load to materialize an F16 shadow of attn/dense MLP
// weights so the runtime dispatch can use kernel_mul_mm_f16_f32_*
// (peer's gemma4 pattern). Trades ~1 GB resident memory for 2-3×
// faster per-call dense matmul at prefill.
let dequant_to_f16_src: &'static str =
include_str!("shaders/dequant_to_f16.metal");
sources.insert("hf2q_dequant_q4_0_to_f16".into(), dequant_to_f16_src);
sources.insert("hf2q_dequant_q8_0_to_f16".into(), dequant_to_f16_src);
sources.insert("hf2q_dequant_q5_1_to_f16".into(), dequant_to_f16_src);
sources.insert("hf2q_dequant_iq4_nl_to_f16".into(), dequant_to_f16_src);
sources.insert("hf2q_dequant_q4_K_to_f16".into(), dequant_to_f16_src);
sources.insert("hf2q_dequant_q5_K_to_f16".into(), dequant_to_f16_src);
sources.insert("hf2q_dequant_q6_K_to_f16".into(), dequant_to_f16_src);
// ADR-022 Phase 1 P1.7 — Q5_1 / IQ4_NL mul_mv_ext r1 family.
// Eight instantiations (2 types × 4 r1ptg widths). Each PSO is
// additionally specialized at PSO-compile time with FC_mul_mv_nsg
// (function_constant 600) and FC_mul_mv_nxpsg (function_constant 601).
let mul_mv_ext_src: &'static str = include_str!("shaders/mul_mv_ext.metal");
sources.insert("kernel_mul_mv_ext_q5_1_f32_r1_2".into(), mul_mv_ext_src);
sources.insert("kernel_mul_mv_ext_q5_1_f32_r1_3".into(), mul_mv_ext_src);
sources.insert("kernel_mul_mv_ext_q5_1_f32_r1_4".into(), mul_mv_ext_src);
sources.insert("kernel_mul_mv_ext_q5_1_f32_r1_5".into(), mul_mv_ext_src);
sources.insert("kernel_mul_mv_ext_iq4_nl_f32_r1_2".into(), mul_mv_ext_src);
sources.insert("kernel_mul_mv_ext_iq4_nl_f32_r1_3".into(), mul_mv_ext_src);
sources.insert("kernel_mul_mv_ext_iq4_nl_f32_r1_4".into(), mul_mv_ext_src);
sources.insert("kernel_mul_mv_ext_iq4_nl_f32_r1_5".into(), mul_mv_ext_src);
// ADR-022 Phase 4 — Q4_0 / Q8_0 / Q4_K / Q5_K / Q6_K mv_ext.
// 5 types × 4 r1ptg widths = 20 instantiations.
for r1 in [2, 3, 4, 5] {
    for ty in ["q4_0", "q8_0", "q4_K", "q5_K", "q6_K"] {
        let name = format!("kernel_mul_mv_ext_{ty}_f32_r1_{r1}");
        sources.insert(name, mul_mv_ext_src);
    }
}
// Dense bf16×f32 → f32 tensor-API matmul (non-flash-attention
// prefill Q@K^T and scores@V, modeled on llama.cpp's
// kernel_mul_mm_bf16_f32 with the GGML_METAL_HAS_TENSOR branch
// active). Tile geometry and write-back identical to the
// quantized tensor kernel; only the A-stage copy (bfloat →
// bfloat, no dequantize) differs.
let dense_mm_bf16_tensor_src: &'static str =
include_str!("shaders/dense_mm_bf16_tensor.metal");
sources.insert("hf2q_dense_mm_bf16_f32_tensor".into(), dense_mm_bf16_tensor_src);
// ADR-029 iter-80 H60: V2 large-tile variant (NRA=64, NRB=128).
// Same source file (`dense_mm_bf16_tensor.metal`) — second host_name
// entry resolves to the V2 kernel appended at the bottom of that
// file. Picked at dispatch time when HF2Q_LARGE_TILE_MM=1.
sources.insert("hf2q_dense_mm_bf16_f32_tensor_v2".into(), dense_mm_bf16_tensor_src);
// Dense f32×f32 → f32 tensor-API matmul (F32-everywhere
// sibling of dense_mm_bf16_tensor). Used by hf2q's ADR-005
// iter-118 BF16-vs-F32 ViT attention A/B diagnostic to remove
// the BF16 K-stage cast as a confounding variable. Port of
// llama.cpp's kernel_mul_mm_f32_f32 specialization
// (ggml-metal.metal:10098) on the GGML_METAL_HAS_TENSOR
// branch. Same tile geometry (NR0=64 NR1=32 NK=32) but
// float-everywhere shmem staging.
let dense_mm_f32_f32_tensor_src: &'static str =
include_str!("shaders/dense_mm_f32_f32.metal");
sources.insert("hf2q_dense_mm_f32_f32_tensor".into(), dense_mm_f32_f32_tensor_src);
// Dense f16×f32 → f32 tensor-API matmul (F16-staging sibling
// of dense_mm_bf16_tensor). Used by hf2q's ADR-005 Phase 2c
// iter-128 gemma4v ViT precision-parity path: every mmproj
// weight is stored as F16 in GGUF, peer's `kernel_mul_mm_f16_f32`
// (`ggml-metal.metal:10099`) stages BOTH A and B as `half` in
// shmem and computes on `simdgroup_half8x8`. Matches the peer's
// per-element rounding budget exactly (F16's 10-bit mantissa vs
// BF16's 7-bit), closing the compounding ~1.16x-per-block error
// cascade that iter-127 numerically bisected to BF16 staging. Same tile
// geometry as the BF16 sibling (NR0=64 NR1=32 NK=32, 8 KB
// shmem) — half and bfloat share 16-bit storage.
let dense_mm_f16_tensor_src: &'static str =
include_str!("shaders/dense_mm_f16_tensor.metal");
sources.insert("hf2q_dense_mm_f16_f32_tensor".into(), dense_mm_f16_tensor_src);
// Dense bf16×f32 → f32 GEMV (matrix-vector multiply) — optimized
// for M=1 single-token decode. Port of llama.cpp's
// kernel_mul_mv_bf16_f32_4 (bfloat4-vectorized GEMV kernel).
// Used in apply_linear_projection_f32 when seq_len=1 and the
// weight matrix is BF16, replacing the MM kernel (~2× faster for
// M=1 due to better memory bandwidth utilization per thread).
let dense_gemv_bf16_src: &'static str =
include_str!("shaders/dense_gemv_bf16.metal");
sources.insert("hf2q_dense_gemv_bf16_f32_4".into(), dense_gemv_bf16_src);
// Fused scale-mask-softmax for the non-flash-attention prefill
// path. One row-local threadgroup per (head, query) pair
// replaces three separate dispatches (scale, mask-add, softmax);
// reads a bf16 mask (-INF at masked positions, matching
// flash_attn_prefill_mask.metal) that is shared across heads.
let scale_mask_softmax_src: &'static str =
include_str!("shaders/scale_mask_softmax.metal");
sources.insert("scale_mask_softmax_f32".into(), scale_mask_softmax_src);
// ADR-029 iter-93 H71: float4-vectorized variant for peer parity
// with kernel_soft_max_f32_4. Same source file; v4 host_name resolves
// to the second kernel appended at the bottom of scale_mask_softmax.metal.
sources.insert("scale_mask_softmax_f32_v4".into(), scale_mask_softmax_src);
// Expert-routed (MoE) quantized matmul kernel (Story 2.1)
sources.insert(
"quantized_matmul_id".into(),
include_str!("shaders/quantized_matmul_id.metal"),
);
// Expert-routed (MoE) GGML block-format quantized matmul kernels
let ggml_id_src: &'static str =
include_str!("shaders/quantized_matmul_id_ggml.metal");
sources.insert("kernel_mul_mv_id_q4_0_f32".into(), ggml_id_src);
sources.insert("kernel_mul_mv_id_q8_0_f32".into(), ggml_id_src);
// ADR-013 P7 — Q4_K MoE expert-routed mat-vec (port of
// llama.cpp's kernel_mul_mv_id_q4_K_f32 at ggml-metal.metal:10349).
sources.insert("kernel_mul_mv_id_q4_K_f32".into(), ggml_id_src);
sources.insert("kernel_mul_mv_id_q5_K_f32".into(), ggml_id_src);
sources.insert("kernel_mul_mv_id_q6_K_f32".into(), ggml_id_src);
// ADR-028 iter-321 — q6_K _id with nr0=2 + cached yl[16]
// (peer-pattern port mirroring iter-309's non-_id variant).
// Env-gated via HF2Q_Q6K_ID_MV_NR2=1 in dispatch_id_mv.
sources.insert("kernel_mul_mv_id_q6_K_f32_nr2".into(), ggml_id_src);
// ADR-029 iter-6 — q8_0 _id with nr0=2 + nsg=4 cross-SG reduce
// (peer-pattern port; peer N_R0_Q8_0=2 + N_SG_Q8_0=4 in
// /opt/llama.cpp/ggml/src/ggml-metal/ggml-metal-impl.h:27,40).
// Env-gated via HF2Q_Q8_0_ID_MV_NR2=1 in dispatch_id_mv.
sources.insert("kernel_mul_mv_id_q8_0_f32_nr2".into(), ggml_id_src);
// ADR-022 Phase 1 — Q5_1 / IQ4_NL MoE expert-routed mat-vec.
sources.insert("kernel_mul_mv_id_q5_1_f32".into(), ggml_id_src);
sources.insert("kernel_mul_mv_id_iq4_nl_f32".into(), ggml_id_src);
// Fused-SwiGLU mv_id variants (ADR-012 §Optimize / Task #15):
// computes y[r][n] = sum_k(dequant(W[expert][n][k]) * silu(gate[r][k]) * up[r][k])
// in one dispatch — replaces silu_mul + expert_down sequence.
sources.insert("kernel_mul_mv_id_q4_0_f32_swiglu".into(), ggml_id_src);
// Expert-routed (MoE) GGML block-format QUANTIZED MATRIX-MATRIX kernels
// (ADR-011 Phase 3 Wave P3a: port of llama.cpp's
// `kernel_mul_mm_id_map0_ne20_N` + `kernel_mul_mm_id_<q>_f32`).
// Two-stage dispatch: map0 regroups the token-to-expert table into
// per-expert routed-token lists, then mm_id stages a 64x32 expert
// weight tile into threadgroup shmem and reuses it across a 32-row
// block of that expert's routed tokens.
let ggml_id_mm_src: &'static str =
include_str!("shaders/quantized_matmul_id_mm.metal");
sources.insert("kernel_mul_mm_id_map0_ne20_1".into(), ggml_id_mm_src);
sources.insert("kernel_mul_mm_id_map0_ne20_8".into(), ggml_id_mm_src);
sources.insert("kernel_mul_mm_id_q4_0_f32".into(), ggml_id_mm_src);
sources.insert("kernel_mul_mm_id_q8_0_f32".into(), ggml_id_mm_src);
sources.insert("kernel_mul_mm_id_q6_K_f32".into(), ggml_id_mm_src);
// ADR-013 P16 — Q4_K mm_id (port of llama.cpp ggml-metal.metal:10169).
sources.insert("kernel_mul_mm_id_q4_K_f32".into(), ggml_id_mm_src);
// ADR-022 Phase 1 P1.6 — Q5_1 / IQ4_NL mm_id template instantiations.
sources.insert("kernel_mul_mm_id_q5_1_f32".into(), ggml_id_mm_src);
sources.insert("kernel_mul_mm_id_iq4_nl_f32".into(), ggml_id_mm_src);
// ADR-022 Phase 2 — Q5_K mm_id template instantiation.
sources.insert("kernel_mul_mm_id_q5_K_f32".into(), ggml_id_mm_src);
// MoE-routed quantized matrix-matrix kernels — tensor API variant
// (ADR-011 Phase 3 Wave P3b-tensor). Uses the MPP tensor_ops
// matmul2d primitive for hardware-tensor-core MMA on M3+. Only
// the mm_id kernel is ported — map0 is a short pre-pass (not
// matmul) and continues to use the simdgroup version.
let ggml_id_mm_tensor_src: &'static str =
include_str!("shaders/quantized_matmul_id_mm_tensor.metal");
sources.insert("kernel_mul_mm_id_q4_0_tensor_f32".into(), ggml_id_mm_tensor_src);
sources.insert("kernel_mul_mm_id_q8_0_tensor_f32".into(), ggml_id_mm_tensor_src);
sources.insert("kernel_mul_mm_id_q6_K_tensor_f32".into(), ggml_id_mm_tensor_src);
// ADR-013 P16 — Q4_K tensor-API mm_id.
sources.insert("kernel_mul_mm_id_q4_K_tensor_f32".into(), ggml_id_mm_tensor_src);
// ADR-022 Phase 1 P1.6 — Q5_1 / IQ4_NL tensor-API mm_id.
sources.insert("kernel_mul_mm_id_q5_1_tensor_f32".into(), ggml_id_mm_tensor_src);
sources.insert("kernel_mul_mm_id_iq4_nl_tensor_f32".into(), ggml_id_mm_tensor_src);
// ADR-022 Phase 2 — Q5_K tensor-API mm_id.
sources.insert("kernel_mul_mm_id_q5_K_tensor_f32".into(), ggml_id_mm_tensor_src);
// Embedding kernels (Story 1.5)
let embedding_src: &'static str = include_str!("shaders/embedding.metal");
sources.insert("embedding_gather_4bit".into(), embedding_src);
sources.insert("embedding_gather_6bit".into(), embedding_src);
// MoE gate kernel (Story 1.5)
let moe_gate_src: &'static str = include_str!("shaders/moe_gate.metal");
sources.insert("moe_gate".into(), moe_gate_src);
// MoE dispatch kernels (Story 1.5)
let moe_dispatch_src: &'static str = include_str!("shaders/moe_dispatch.metal");
sources.insert("fused_gelu_mul".into(), moe_dispatch_src);
sources.insert("moe_swiglu_fused".into(), moe_dispatch_src);
sources.insert("moe_swiglu_batch".into(), moe_dispatch_src);
sources.insert("moe_swiglu_seq".into(), moe_dispatch_src);
sources.insert("moe_accumulate".into(), moe_dispatch_src);
sources.insert("moe_weighted_sum".into(), moe_dispatch_src);
sources.insert("moe_weighted_sum_seq".into(), moe_dispatch_src);
sources.insert("zero_buffer".into(), moe_dispatch_src);
sources.insert("naive_matvec_f32".into(), moe_dispatch_src);
sources.insert("moe_gather_topk_weights".into(), moe_dispatch_src);
// bf16 variants (Phase 2 bf16 activation path)
sources.insert("fused_gelu_mul_bf16".into(), moe_dispatch_src);
sources.insert("moe_swiglu_seq_bf16".into(), moe_dispatch_src);
sources.insert("moe_weighted_sum_seq_bf16_input".into(), moe_dispatch_src);
// ADR-020 iter-11h-e3a: backward kernels for moe_weighted_sum_seq.
sources.insert(
"moe_weighted_sum_seq_backward_outputs_f32".into(),
moe_dispatch_src,
);
sources.insert(
"moe_weighted_sum_seq_backward_weights_f32".into(),
moe_dispatch_src,
);
// ADR-020 iter-11h-e3b: fused backward kernel for moe_swiglu_seq.
sources.insert(
"moe_swiglu_seq_backward_f32".into(),
moe_dispatch_src,
);
// Batched KV cache copy kernels
let kv_cache_src: &'static str = include_str!("shaders/kv_cache_copy.metal");
sources.insert("kv_cache_copy_batch_f32".into(), kv_cache_src);
sources.insert("kv_cache_copy_batch_f32_to_f16".into(), kv_cache_src);
sources.insert("kv_cache_copy_seq_f32".into(), kv_cache_src);
sources.insert("kv_cache_copy_seq_f32_to_f16".into(), kv_cache_src);
// Wave P4.11 — fused K+V copy variants
sources.insert("kv_cache_copy_seq_f32_kv_dual".into(), kv_cache_src);
sources.insert("kv_cache_copy_seq_f32_to_f16_kv_dual".into(), kv_cache_src);
// ADR-028 iter-145 — fused single-position K+V copy variants (decode shape)
sources.insert("kv_cache_copy_batch_f32_kv_dual".into(), kv_cache_src);
sources.insert("kv_cache_copy_batch_f32_to_f16_kv_dual".into(), kv_cache_src);
// bf16-source KV cache copy (Phase 2 bf16 activation path)
sources.insert("kv_cache_copy_seq_bf16".into(), kv_cache_src);
// Elementwise and transpose kernels (Story 1.5)
let elementwise_src: &'static str = include_str!("shaders/elementwise.metal");
sources.insert("elementwise_add_f32".into(), elementwise_src);
sources.insert("elementwise_add_f16".into(), elementwise_src);
sources.insert("elementwise_mul_f32".into(), elementwise_src);
sources.insert("elementwise_mul_f16".into(), elementwise_src);
sources.insert("elementwise_add_bf16".into(), elementwise_src);
sources.insert("elementwise_mul_bf16".into(), elementwise_src);
sources.insert("cast_f16_to_f32".into(), elementwise_src);
sources.insert("cast_f32_to_f16".into(), elementwise_src);
sources.insert("cast_bf16_to_f32".into(), elementwise_src);
sources.insert("cast_f32_to_bf16".into(), elementwise_src);
sources.insert("scalar_mul_bf16".into(), elementwise_src);
sources.insert("scalar_mul_f32".into(), elementwise_src);
sources.insert("embedding_gather_scale_f32".into(), elementwise_src);
sources.insert("embedding_gather_scale_batch_f32".into(), elementwise_src);
sources.insert("permute_021_bf16".into(), elementwise_src);
sources.insert("transpose_last2_bf16".into(), elementwise_src);
sources.insert("transpose_last2_f16".into(), elementwise_src);
sources.insert("permute_021_f32".into(), elementwise_src);
sources.insert("permute_021_bf16_to_f32".into(), elementwise_src);
sources.insert("transpose_2d_f32".into(), elementwise_src);
sources.insert("transpose_2d_f16".into(), elementwise_src);
// Attention kernels (Story 1.3)
let sdpa_src: &'static str = include_str!("shaders/sdpa.metal");
sources.insert("sdpa".into(), sdpa_src);
sources.insert("sdpa_bf16".into(), sdpa_src);
let sdpa_sliding_src: &'static str = include_str!("shaders/sdpa_sliding.metal");
sources.insert("sdpa_sliding".into(), sdpa_sliding_src);
sources.insert("sdpa_sliding_bf16".into(), sdpa_sliding_src);
// Flash-attention tiled prefill kernel (ADR-011 Phase 1).
// Ten entry points; all backed by the same shader source.
// Pipelines are compiled with function constants via
// `get_pipeline_with_bool_constants` — not `get_pipeline`.
let flash_attn_prefill_src: &'static str =
include_str!("shaders/flash_attn_prefill.metal");
// D=256 variants (BQ=32, BK=16, WM=4, WN=1 — 128 threads/threadgroup)
sources.insert(
"steel_attention_float32_bq32_bk16_bd256_wm4_wn1_maskfloat32".into(),
flash_attn_prefill_src,
);
sources.insert(
"steel_attention_float32_bq32_bk16_bd256_wm4_wn1_maskbool_".into(),
flash_attn_prefill_src,
);
sources.insert(
"steel_attention_bfloat16_bq32_bk16_bd256_wm4_wn1_maskbfloat16".into(),
flash_attn_prefill_src,
);
sources.insert(
"steel_attention_bfloat16_bq32_bk16_bd256_wm4_wn1_maskbool_".into(),
flash_attn_prefill_src,
);
sources.insert(
"steel_attention_float16_bq32_bk16_bd256_wm4_wn1_maskfloat16".into(),
flash_attn_prefill_src,
);
sources.insert(
"steel_attention_float16_bq32_bk16_bd256_wm4_wn1_maskbool_".into(),
flash_attn_prefill_src,
);
// D=512 variants (BQ=8, BK=8, WM=1, WN=1 — 32 threads/threadgroup)
// NOTE: f32 at D=512 is NOT instantiated — threadgroup memory exceeds
// the 32 KB Metal limit (candle sdpa.rs:86-94).
sources.insert(
"steel_attention_bfloat16_bq8_bk8_bd512_wm1_wn1_maskbfloat16".into(),
flash_attn_prefill_src,
);
sources.insert(
"steel_attention_bfloat16_bq8_bk8_bd512_wm1_wn1_maskbool_".into(),
flash_attn_prefill_src,
);
sources.insert(
"steel_attention_float16_bq8_bk8_bd512_wm1_wn1_maskfloat16".into(),
flash_attn_prefill_src,
);
sources.insert(
"steel_attention_float16_bq8_bk8_bd512_wm1_wn1_maskbool_".into(),
flash_attn_prefill_src,
);
// Flash attention vector kernels — SIMD-vectorized decode-path SDPA
// (ported from llama.cpp flash_attn_ext_vec)
let flash_attn_vec_src: &'static str =
include_str!("shaders/flash_attn_vec.metal");
sources.insert("flash_attn_vec_dk256".into(), flash_attn_vec_src);
sources.insert("flash_attn_vec_dk512".into(), flash_attn_vec_src);
sources.insert("flash_attn_vec_reduce_dk256".into(), flash_attn_vec_src);
sources.insert("flash_attn_vec_reduce_dk512".into(), flash_attn_vec_src);
// F16 KV variants (Phase 4a)
sources.insert("flash_attn_vec_f16kv_dk256".into(), flash_attn_vec_src);
sources.insert("flash_attn_vec_f16kv_dk512".into(), flash_attn_vec_src);
// RoPE, normalization, activation kernels (Story 1.4)
let rope_src: &'static str = include_str!("shaders/rope.metal");
sources.insert("rope_f32".into(), rope_src);
sources.insert("rope_f16".into(), rope_src);
sources.insert("rope_bf16".into(), rope_src);
sources.insert("rope_neox_bf16".into(), rope_src);
sources.insert("rope_neox_f32".into(), rope_src);
let rms_norm_src: &'static str = include_str!("shaders/rms_norm.metal");
sources.insert("rms_norm_f32".into(), rms_norm_src);
// ADR-028 iter-310 — float4 + simd_sum variants (peer-pattern,
// ported from llama.cpp kernel_rms_norm_fuse_impl<float4, 1>).
// Env-gated via HF2Q_RMS_NORM_V2=1 in the dispatchers.
sources.insert("rms_norm_f32_v2".into(), rms_norm_src);
sources.insert("rms_norm_no_scale_f32_v2".into(), rms_norm_src);
sources.insert("rms_norm_f16".into(), rms_norm_src);
sources.insert("rms_norm_bf16".into(), rms_norm_src);
sources.insert("rms_norm_no_scale_bf16".into(), rms_norm_src);
sources.insert("rms_norm_no_scale_f32".into(), rms_norm_src);
sources.insert("rms_norm_no_scale_f32_dual".into(), rms_norm_src);
sources.insert("rms_norm_f32_triple".into(), rms_norm_src);
sources.insert("fused_post_attn_triple_norm_f32".into(), rms_norm_src);
// ADR-028 iter-370: V2 (float4 + simd_sum) variant of triple_norm.
sources.insert("fused_post_attn_triple_norm_f32_v2".into(), rms_norm_src);
// ADR-028 iter-217: fused post-FF norm 2 + end-of-layer FINAL
// (combines 2 sequential fused_norm_add dispatches into 1 kernel).
sources.insert("fused_post_ff_norm2_endlayer_f32".into(), rms_norm_src);
// ADR-028 iter-362: V2 (float4 + simd_sum) variant of the above.
// Same math, 75% fewer barriers per dispatch (4 vs 16 at tg=256).
sources.insert("fused_post_ff_norm2_endlayer_f32_v2".into(), rms_norm_src);
// ADR-028 iter-367: V2 fusion of moe_weighted_sum INTO Path A end-of-layer.
// Eliminates 1 dispatch + moe_accum round-trip from gemma4 decode default.
sources.insert("fused_moe_wsum_post_ff_norm2_endlayer_f32_v2".into(), rms_norm_src);
sources.insert("rms_norm_no_scale_f32_dual_perm".into(), rms_norm_src);
// Fused RMS norm + elementwise multiply kernels (Phase 4e.2)
sources.insert("rms_norm_mul_f32".into(), rms_norm_src);
sources.insert("rms_norm_mul_f16".into(), rms_norm_src);
sources.insert("rms_norm_mul_bf16".into(), rms_norm_src);
// L2 norm kernels (ADR-013 Decision 3 — Gated DeltaNet Q/K norm)
let l2_norm_src: &'static str = include_str!("shaders/l2_norm.metal");
sources.insert("l2_norm_f32".into(), l2_norm_src);
sources.insert("l2_norm_f16".into(), l2_norm_src);
sources.insert("l2_norm_bf16".into(), l2_norm_src);
// ADR-015 iter59a — fused L2 norm + scalar multiply (DN q-path).
sources.insert("l2_norm_scale_f32".into(), l2_norm_src);
// Cumulative-sum kernels (ADR-013 Decision 4 — DeltaNet decay-mask base)
let cumsum_src: &'static str = include_str!("shaders/cumsum.metal");
sources.insert("cumsum_f32".into(), cumsum_src);
sources.insert("cumsum_bf16".into(), cumsum_src);
// SSM conv kernels (ADR-013 Decision 7 — DeltaNet 1D causal conv + SiLU)
let ssm_conv_src: &'static str = include_str!("shaders/ssm_conv.metal");
sources.insert("ssm_conv_forward_f32".into(), ssm_conv_src);
sources.insert("ssm_conv_forward_bf16".into(), ssm_conv_src);
sources.insert("ssm_conv_state_update_f32".into(), ssm_conv_src);
sources.insert("ssm_conv_state_update_bf16".into(), ssm_conv_src);
// Tri-solve kernels (ADR-013 Decision 5 — chunked DeltaNet debug path)
let tri_solve_src: &'static str = include_str!("shaders/tri_solve.metal");
sources.insert("tri_solve_lower_unit_f32".into(), tri_solve_src);
sources.insert("tri_solve_lower_unit_bf16".into(), tri_solve_src);
// Rope-multi kernels (ADR-013 Decision 10 — IMROPE for Qwen3.5)
let rope_multi_src: &'static str = include_str!("shaders/rope_multi.metal");
sources.insert("rope_multi_f32".into(), rope_multi_src);
sources.insert("rope_multi_bf16".into(), rope_multi_src);
// Gated DeltaNet fused kernel (ADR-013 Decision 6 — centerpiece)
let gdn_src: &'static str = include_str!("shaders/gated_delta_net.metal");
sources.insert("gated_delta_net_f32".into(), gdn_src);
// ADR-015 iter56 — decode-only `simd_sum` variant. Three NSG-templated
// host names share the same source; selection is by D_k via
// `dispatch_gated_delta_net_decode`. Drop-in for the fused kernel
// above when n_tokens=1.
let gdn_decode_src: &'static str =
include_str!("shaders/gated_delta_net_decode.metal");
sources.insert("gated_delta_net_decode_f32_1".into(), gdn_decode_src);
sources.insert("gated_delta_net_decode_f32_2".into(), gdn_decode_src);
sources.insert("gated_delta_net_decode_f32_4".into(), gdn_decode_src);
// Wave 5b — chunk-parallel inter-chunk state-recurrence kernel
// (the one new kernel in the chunk-parallel pipeline; spec source:
// arXiv 2412.06464 §4 + FLA chunk_delta_h.py:43-298).
let gdn_chunk_src: &'static str =
include_str!("shaders/gated_delta_net_chunk.metal");
sources.insert(
"gated_delta_net_chunk_inter_state_bf16".into(),
gdn_chunk_src,
);
// Wave 5b.1 iter 2 — chunk_scaled_dot_kkt kernel (input-side of
// the chunk pipeline; spec source: FLA chunk_scaled_dot_kkt.py:36-99).
let gdn_kkt_src: &'static str =
include_str!("shaders/gated_delta_net_kkt.metal");
sources.insert("gated_delta_net_kkt_bf16".into(), gdn_kkt_src);
// Wave 5b.1 iter 2 — recompute_w_u_fwd kernel (applies post-solve A
// to (β·v) and (β·k·exp(g)) to produce w and u; spec source: FLA
// wy_fast.py:29-117).
let gdn_recompute_wu_src: &'static str =
include_str!("shaders/gated_delta_net_recompute_wu.metal");
sources.insert(
"gated_delta_net_recompute_wu_bf16".into(),
gdn_recompute_wu_src,
);
// Wave 5b.1 iter 3 — chunk_fwd_o kernel (per-chunk output: closes
// the chunk pipeline; spec source: FLA chunk_o.py:42-138).
let gdn_chunk_o_src: &'static str =
include_str!("shaders/gated_delta_net_chunk_o.metal");
sources.insert("gated_delta_net_chunk_o_bf16".into(), gdn_chunk_o_src);
// Wave 5b.1 iter 4 — orchestrator helper kernels:
// chunk_local_cumsum_g_f32 — per-chunk prefix sum on g [B, T, H]
// chunk_tri_solve_invert_f32 — per-chunk-block (I + A_strict)^-1
// on FLA's [B, T, H, BT] layout.
let chunk_local_cumsum_g_src: &'static str =
include_str!("shaders/chunk_local_cumsum_g.metal");
sources.insert(
"chunk_local_cumsum_g_f32".into(),
chunk_local_cumsum_g_src,
);
let chunk_tri_solve_invert_src: &'static str =
include_str!("shaders/chunk_gated_delta_rule_tri_solve_invert.metal");
sources.insert(
"chunk_tri_solve_invert_f32".into(),
chunk_tri_solve_invert_src,
);
// Sigmoid-gated elementwise multiply (ADR-013 Decision 9 — full-attn output gate)
let sigmoid_mul_src: &'static str = include_str!("shaders/sigmoid_mul.metal");
sources.insert("sigmoid_mul_f32".into(), sigmoid_mul_src);
sources.insert("sigmoid_mul_bf16".into(), sigmoid_mul_src);
let silu_mul_src: &'static str = include_str!("shaders/silu_mul.metal");
sources.insert("silu_mul_f32".into(), silu_mul_src);
let compute_g_beta_src: &'static str = include_str!("shaders/compute_g_beta.metal");
sources.insert("compute_g_beta_f32".into(), compute_g_beta_src);
let ssm_norm_gate_src: &'static str = include_str!("shaders/ssm_norm_gate.metal");
sources.insert("ssm_norm_gate_f32".into(), ssm_norm_gate_src);
let gelu_src: &'static str = include_str!("shaders/gelu.metal");
sources.insert("gelu_f32".into(), gelu_src);
sources.insert("gelu_f16".into(), gelu_src);
sources.insert("gelu_bf16".into(), gelu_src);
let softmax_src: &'static str = include_str!("shaders/softmax.metal");
sources.insert("softmax_f32".into(), softmax_src);
sources.insert("softmax_f16".into(), softmax_src);
sources.insert("softmax_bf16".into(), softmax_src);
let softmax_backward_src: &'static str =
include_str!("shaders/softmax_backward.metal");
sources.insert("softmax_backward_f32".into(), softmax_backward_src);
let log_elementwise_src: &'static str =
include_str!("shaders/log_elementwise.metal");
sources.insert("log_f32".into(), log_elementwise_src);
sources.insert("log_backward_f32".into(), log_elementwise_src);
let row_sum_src: &'static str = include_str!("shaders/row_sum.metal");
sources.insert("row_sum_f32".into(), row_sum_src);
sources.insert("row_sum_backward_f32".into(), row_sum_src);
// ADR-020 iter-10a: GGUF-legacy quantize-dequantize round-trip kernels
// (Q4_0 + Q8_0). Used by hf2q's dynamic_quant Track 1 to produce
// W_low / W_high for the gradient-Taylor sensitivity formula.
let qdq_legacy_src: &'static str = include_str!("shaders/qdq_legacy.metal");
sources.insert("qdq_q4_0_f32".into(), qdq_legacy_src);
sources.insert("qdq_q8_0_f32".into(), qdq_legacy_src);
// ADR-020 iter-10b: RMSNorm reverse-mode autograd kernels.
// r_inv helper is reused by both backward kernels; dx and dw cover
// the full backward identity for `y = x * rsqrt(mean(x²) + eps) * w`.
let rms_norm_backward_src: &'static str =
include_str!("shaders/rms_norm_backward.metal");
sources.insert(
"rms_norm_compute_rms_inv_f32".into(),
rms_norm_backward_src,
);
sources.insert("rms_norm_backward_dx_f32".into(), rms_norm_backward_src);
sources.insert("rms_norm_backward_dw_f32".into(), rms_norm_backward_src);
// ADR-020 iter-11a: 2-D row-major slice + concat-by-column kernels.
// Used by hf2q's multi-head SDPA on GpuTape (slice Q/K/V into
// per-head views, run per-head SDPA, concat per-head contexts
// back to full attention output).
let slice_concat_2d_src: &'static str =
include_str!("shaders/slice_concat_2d.metal");
sources.insert("slice_2d_cols_f32".into(), slice_concat_2d_src);
sources.insert("copy_2d_cols_into_f32".into(), slice_concat_2d_src);
// ADR-020 iter-11b: SiLU forward + backward kernels for GpuTape
// SwiGLU FFN composition.
let silu_backward_src: &'static str =
include_str!("shaders/silu_backward.metal");
sources.insert("silu_f32".into(), silu_backward_src);
sources.insert("silu_backward_f32".into(), silu_backward_src);
// ADR-020 iter-11d: FP32 embedding lookup + scatter-add backward.
let embedding_autograd_src: &'static str =
include_str!("shaders/embedding_autograd.metal");
sources.insert("embedding_lookup_f32".into(), embedding_autograd_src);
sources.insert(
"embedding_scatter_add_f32".into(),
embedding_autograd_src,
);
// ADR-020 iter-13a: Adam optimizer step kernel for Track 2
// DWQ-proper training loop.
let adam_update_src: &'static str =
include_str!("shaders/adam_update.metal");
sources.insert("adam_update_f32".into(), adam_update_src);
// ADR-020 iter-13b: differentiable affine qdq kernels for the
// DWQ-proper training loop. Init + forward + backward (scales,
// biases) — q_int is FROZEN, scales+biases learnable.
let qdq_affine_src: &'static str =
include_str!("shaders/qdq_affine.metal");
sources.insert("qdq_affine_init_f32".into(), qdq_affine_src);
sources.insert("qdq_affine_forward_f32".into(), qdq_affine_src);
sources.insert(
"qdq_affine_backward_scales_f32".into(),
qdq_affine_src,
);
sources.insert(
"qdq_affine_backward_biases_f32".into(),
qdq_affine_src,
);
// ADR-020 iter-15: fused affine quantized matmul for DWQ inference.
// Per-element kernel; one thread per (m, n) output element.
// Tiled + simdgroup-MMA variant lands in iter-15b.
let qmm_affine_src: &'static str =
include_str!("shaders/qmm_affine.metal");
sources.insert("qmm_affine_t_f32".into(), qmm_affine_src);
// ADR-020 iter-15b: tiled variant — 16x16 thread block with
// cooperative-load X/W tiles in threadgroup-shared memory for
// 2-5x speedup over the per-element kernel.
let qmm_affine_tiled_src: &'static str =
include_str!("shaders/qmm_affine_tiled.metal");
sources.insert(
"qmm_affine_t_f32_tiled".into(),
qmm_affine_tiled_src,
);
// ADR-020 iter-15c: simdgroup-MMA variant — uses Apple GPU
// hardware `simdgroup_matrix<float, 8, 8>` MMA for the inner
// reduction. Per-tile it is an algorithmic 8× over the scalar
// tiled kernel; this lands as ~3-4× wall-clock after launch and
// load amortization.
let qmm_affine_simd_src: &'static str =
include_str!("shaders/qmm_affine_simd.metal");
sources.insert(
"qmm_affine_t_f32_simd".into(),
qmm_affine_simd_src,
);
// ADR-020 iter-15c-2: 4-simdgroup-per-TG variant — 32×32
// output tile, 4 simdgroups arranged as 2×2 grid each owning
// a 16×16 sub-tile = 4 simdgroup_matrix accumulators. Same
// math as 15c-1, fuller warp-pool exploitation.
let qmm_affine_simd4_src: &'static str =
include_str!("shaders/qmm_affine_simd4.metal");
sources.insert(
"qmm_affine_t_f32_simd4".into(),
qmm_affine_simd4_src,
);
// ADR-020 iter-15c-2b: gs=64 variant (mlx-lm dynamic_quant
// canonical default). Same 4-simdgroup geometry, BK=64
// instead of 32 (= 8 sub-K-tiles per K-step instead of 4).
let qmm_affine_simd4_gs64_src: &'static str =
include_str!("shaders/qmm_affine_simd4_gs64.metal");
sources.insert(
"qmm_affine_t_f32_simd4_gs64".into(),
qmm_affine_simd4_gs64_src,
);
// ADR-020 AC#5 Iter A: packed-U32 dense affine matmul (bits=4,
// gs=32) — production decode/prefill kernel for serving DWQ
// safetensors directly without a load-time unpack pass.
let qmm_affine_t_packed_simd4_b4_src: &'static str =
include_str!("shaders/qmm_affine_t_packed_simd4_b4.metal");
sources.insert(
"qmm_affine_t_packed_simd4_b4".into(),
qmm_affine_t_packed_simd4_b4_src,
);
// ADR-020 iter-11h-b: training-mode causal depthwise 1D
// convolution (forward + backward dx + backward dw). Used by
// GpuTape autograd for differentiable Qwen3.5MoE forward
// (GatedDeltaNet's conv1d step).
let conv1d_dwc_src: &'static str =
include_str!("shaders/conv1d_depthwise_causal.metal");
sources.insert(
"conv1d_depthwise_causal_forward_f32".into(),
conv1d_dwc_src,
);
sources.insert(
"conv1d_depthwise_causal_backward_dx_f32".into(),
conv1d_dwc_src,
);
sources.insert(
"conv1d_depthwise_causal_backward_dw_f32".into(),
conv1d_dwc_src,
);
// ADR-020 iter-11h-c1: elementwise exp forward + backward.
// Building block for GatedDeltaNet's alpha = exp(-g) state-decay.
let exp_src: &'static str =
include_str!("shaders/exp_elementwise.metal");
sources.insert("exp_f32".into(), exp_src);
sources.insert("exp_backward_f32".into(), exp_src);
// ADR-020 iter-11h-c2: vector outer product (forward + dlhs +
// drhs). Building block for gated_delta_update's
// outer(delta, k) state-update term.
let outer_src: &'static str =
include_str!("shaders/outer_product.metal");
sources.insert("outer_product_f32".into(), outer_src);
sources.insert("outer_product_backward_lhs_f32".into(), outer_src);
sources.insert("outer_product_backward_rhs_f32".into(), outer_src);
// ADR-020 iter-11h-e1: take_along_axis (gather) + scatter-backward.
// Building block for MoE router on GpuTape.
let taa_src: &'static str =
include_str!("shaders/take_along_axis.metal");
sources.insert("take_along_axis_f32".into(), taa_src);
sources.insert("take_along_axis_backward_f32".into(), taa_src);
// ADR-020 iter-11h-misc-1: elementwise divide forward + backward.
let div_src: &'static str =
include_str!("shaders/divide_elementwise.metal");
sources.insert("divide_f32".into(), div_src);
sources.insert("divide_backward_f32".into(), div_src);
// ADR-020 iter-11h-misc-3: elementwise sqrt forward + backward.
let sqrt_src: &'static str =
include_str!("shaders/sqrt_elementwise.metal");
sources.insert("sqrt_f32".into(), sqrt_src);
sources.insert("sqrt_backward_f32".into(), sqrt_src);
let softcap_src: &'static str = include_str!("shaders/softcap.metal");
sources.insert("softcap_f32".into(), softcap_src);
sources.insert("softcap_f16".into(), softcap_src);
sources.insert("softcap_bf16".into(), softcap_src);
// Fused norm-add kernels — Gemma4 post-attention / post-FFN ordering:
// normed = rms_norm(input, weight, eps); output = residual + normed
let fused_norm_add_src: &'static str =
include_str!("shaders/fused_norm_add_bf16.metal");
sources.insert("fused_norm_add_bf16".into(), fused_norm_add_src);
sources.insert("fused_norm_add_no_weight_bf16".into(), fused_norm_add_src);
// Fused head-norm + RoPE f32 kernel — replaces separate rms_norm + rope_neox_f32
let fused_hnr_f32_src: &'static str =
include_str!("shaders/fused_head_norm_rope_f32.metal");
sources.insert("fused_head_norm_rope_f32".into(), fused_hnr_f32_src);
// ADR-028 iter-337 — float4 + simd_sum Phase 1 variant. Phases
// 2-4 byte-identical to v1; race-fix barrier preserved. Env-gated
// via HF2Q_FUSED_HEAD_NORM_ROPE_V2 (default ON, opt-out via =0).
sources.insert("fused_head_norm_rope_f32_v2".into(), fused_hnr_f32_src);
// Fused head-norm + RoPE bf16 kernels (single-token + batch prefill)
// Both entry points live in the same .metal file.
let fused_hnr_bf16_src: &'static str =
include_str!("shaders/fused_head_norm_rope_bf16.metal");
sources.insert("fused_head_norm_rope_bf16".into(), fused_hnr_bf16_src);
sources.insert("fused_head_norm_rope_batch_bf16".into(), fused_hnr_bf16_src);
// Fused norm-add f32 kernels — post-attention / post-FFN / end-of-layer
let fused_norm_add_f32_src: &'static str =
include_str!("shaders/fused_norm_add_f32.metal");
sources.insert("fused_norm_add_f32".into(), fused_norm_add_f32_src);
// ADR-028 iter-331 — float4 + simd_sum variant (peer-pattern,
// ported from llama.cpp kernel_rms_norm_fuse_impl<float4, 3>).
// Env-gated via HF2Q_FUSED_NORM_ADD_V2=1 in the dispatcher
// (default ON since iter-331; opt-out via =0/false/off).
sources.insert("fused_norm_add_f32_v2".into(), fused_norm_add_f32_src);
sources.insert("fused_residual_norm_f32".into(), fused_norm_add_f32_src);
sources.insert("fused_residual_norm_scalar_f32".into(), fused_norm_add_f32_src);
sources.insert("fused_moe_routing_f32".into(), fused_norm_add_f32_src);
// ADR-028 iter-363: V2 (simd_max + simd_sum) variant of MoE routing.
sources.insert("fused_moe_routing_f32_v2".into(), fused_norm_add_f32_src);
sources.insert("fused_moe_routing_batch_f32".into(), fused_norm_add_f32_src);
sources.insert("fused_norm_add_scalar_f32".into(), fused_norm_add_f32_src);
sources.insert("fused_moe_wsum_norm_add_f32".into(), fused_norm_add_f32_src);
sources.insert("fused_moe_wsum_dnorm_add_f32".into(), fused_norm_add_f32_src);
// Argsort kernel (Story 2.3) — MoE top-K routing
let argsort_src: &'static str = include_str!("shaders/argsort.metal");
sources.insert("argsort_desc_f32".into(), argsort_src);
// Gather / index_select kernel (Story 2.4)
let gather_src: &'static str = include_str!("shaders/gather.metal");
sources.insert("gather_f32".into(), gather_src);
// F32 KV cache copy kernel (Session merge S1+S2)
let kv_cache_copy_src: &'static str =
include_str!("shaders/kv_cache_copy.metal");
sources.insert("kv_cache_copy".into(), kv_cache_copy_src);
sources.insert("kv_cache_copy_f32".into(), kv_cache_copy_src);
// Strided copy kernel (Story 2.5)
let copy_src: &'static str = include_str!("shaders/copy.metal");
sources.insert("strided_copy_f32".into(), copy_src);
sources.insert("offset_copy_f32".into(), copy_src);
// Fused-QKV split kernel (ADR-005 W-5b.18 — replaces hf2q CPU
// download → triple-loop split → 3× upload round-trip in
// gpu_delta_net::layer_qkv_deinterleave).
let qkv_split_src: &'static str = include_str!("shaders/qkv_split.metal");
sources.insert("qkv_split_f32".into(), qkv_split_src);
// Tiled-GQA broadcast kernel (ADR-005 W-5b.19 — replaces hf2q CPU
// tiled-replicate at gpu_delta_net::apply_gated_delta_net_chunk
// GQA pre-expansion, ~497 ms / 10.4 ms-per-layer at PP4106).
let repeat_tiled_src: &'static str =
include_str!("shaders/repeat_tiled.metal");
sources.insert("repeat_tiled_f32".into(), repeat_tiled_src);
// Dense F16 GEMM kernel (Story 2.6) — lm_head projection
let dense_gemm_src: &'static str = include_str!("shaders/dense_gemm.metal");
sources.insert("dense_gemm_f16".into(), dense_gemm_src);
sources.insert("dense_matvec_f16".into(), dense_gemm_src);
sources.insert("dense_matvec_f16w_f32io".into(), dense_gemm_src);
// BF16-weight mat-vec: BF16 weights × F32 input → F32 output (decode lm_head)
sources.insert("dense_matvec_bf16w_f32io".into(), dense_gemm_src);
// Pure F32 mat-vec: F32 weights × F32 input → F32 output (decode lm_head)
sources.insert("dense_matvec_f32".into(), dense_gemm_src);
// Standalone FWHT for TurboQuant pre/post-rotation (SIMD shuffle, zero barriers)
let fwht_src: &'static str = include_str!("shaders/fwht_standalone.metal");
sources.insert("fwht_standalone_f32_d256".into(), fwht_src);
sources.insert("fwht_standalone_f32_d512".into(), fwht_src);
// ADR-007 iter-14 D1 SRHT variants: sign pre-mult (for Q) + sign undo (for output)
sources.insert("fwht_sign_premult_f32_d256".into(), fwht_src);
sources.insert("fwht_sign_premult_f32_d512".into(), fwht_src);
sources.insert("fwht_sign_undo_f32_d256".into(), fwht_src);
sources.insert("fwht_sign_undo_f32_d512".into(), fwht_src);
// Fast Hadamard quantize (SIMD shuffle, zero barriers)
let hq_fast_src: &'static str = include_str!("shaders/hadamard_quantize_kv_fast.metal");
sources.insert("hadamard_quantize_kv_fast_d256".into(), hq_fast_src);
sources.insert("hadamard_quantize_kv_fast_d512".into(), hq_fast_src);
// ADR-028 iter-485 (Phase 7d / H4): fused K+V single-position 4-bit encoder.
sources.insert("hadamard_quantize_kv_fast_dual_d256".into(), hq_fast_src);
sources.insert("hadamard_quantize_kv_fast_dual_d512".into(), hq_fast_src);
// Track B (iter-21): higher-bit (5/6-bit) quantize kernels (byte-packed)
sources.insert("hadamard_quantize_kv_hb_d256".into(), hq_fast_src);
sources.insert("hadamard_quantize_kv_hb_d512".into(), hq_fast_src);
// ADR-028 iter-148: fused K+V single-position HB encoder
sources.insert("hadamard_quantize_kv_hb_dual_d256".into(), hq_fast_src);
sources.insert("hadamard_quantize_kv_hb_dual_d512".into(), hq_fast_src);
// ADR-028 Phase 10e.5 (iter-351): no-FWHT V quantize for hybrid path.
// Same byte-packed Lloyd-Max codebook output, but skips the Hadamard
// rotation so dequant in SDPA recovers raw V (no FWHT-undo needed).
sources.insert("kv_quantize_v_no_fwht_d256".into(), hq_fast_src);
sources.insert("kv_quantize_v_no_fwht_d512".into(), hq_fast_src);
// ADR-028 Phase 10c.5 (iter-354): fused F16-K-copy + V-no-FWHT-encode.
// Saves 30 KV-write dispatches/decode-token at gemma4 30L by combining
// the per-layer K-cast and V-encode into a single dispatch (Z-dim).
sources.insert("kv_copy_kf16_quantize_v_no_fwht_d256".into(), hq_fast_src);
sources.insert("kv_copy_kf16_quantize_v_no_fwht_d512".into(), hq_fast_src);
// iter-20 Leg F: TQ KV dequantize kernel (nibbles+norms → F32)
let tq_dq_src: &'static str = include_str!("shaders/tq_dequantize_kv.metal");
sources.insert("tq_dequantize_kv".into(), tq_dq_src);
// Track B (iter-21): higher-bit dequantize kernel (byte-packed indices)
sources.insert("tq_dequantize_hb_kv".into(), tq_dq_src);
// ADR-027 Phase B iter-30 (hf2q sub-sub-iter 23c-β.1): sequence-batch
// dequant variant. Same MSL source; new kernel entry point
// `tq_dequantize_hb_kv_seq` reads positions [start_pos..start_pos+n_tokens)
// in one dispatch (one threadgroup per (kv_head, position)). Unblocks
// hf2q's TQ-aware prefill SDPA path (current per-position kernel
// requires cur_len separate dispatches).
sources.insert("tq_dequantize_hb_kv_seq".into(), tq_dq_src);
// iter-24: native higher-bit (5/6/8-bit) TQ SDPA kernel (byte-packed K/V)
let tq_hb_src: &'static str = include_str!("shaders/flash_attn_vec_tq_hb.metal");
sources.insert("flash_attn_vec_tq_hb_dk256".into(), tq_hb_src);
sources.insert("flash_attn_vec_tq_hb_dk512".into(), tq_hb_src);
// ADR-028 §iter-485 (Phase 7d H3): fused TQ-HB reduce + FWHT-sign-undo.
// Combines flash_attn_vec_reduce + fwht_sign_undo_f32 into a single
// dispatch, saving 1 dispatch + 1 forced barrier per layer per decode
// token. Gated by env flag `HF2Q_TQ_HB_OUT_FUSED=1` in forward_mlx.rs.
let reduce_undo_src: &'static str = include_str!("shaders/flash_attn_vec_reduce_tq_hb_undo.metal");
sources.insert("flash_attn_vec_reduce_tq_hb_undo_dk256".into(), reduce_undo_src);
sources.insert("flash_attn_vec_reduce_tq_hb_undo_dk512".into(), reduce_undo_src);
// ADR-028 Phase 10d (iter-349): hybrid F16-K + TQ-HB-V SDPA kernel.
// Same V-side codebook as flash_attn_vec_tq_hb (5/6/8-bit Lloyd-Max);
// K-side reads F16 dense — peer-equivalent layout, no codebook lookup.
let hybrid_src: &'static str = include_str!("shaders/flash_attn_vec_hybrid.metal");
sources.insert("flash_attn_vec_hybrid_dk256".into(), hybrid_src);
sources.insert("flash_attn_vec_hybrid_dk512".into(), hybrid_src);
// ADR-029 CFA cfa-20260512-fa-peer-port (iter-122): verbatim llama.cpp peer port.
// F16-K + F16-V, DK=DV=256, NWG=1, NSG=1, NE=1. No function constants — baked.
let peer_port_src: &'static str = include_str!("shaders/flash_attn_vec_peer_port_f16.metal");
sources.insert("flash_attn_vec_peer_port_f16_dk256_dv256".into(), peer_port_src);
// ADR-029 iter-134: peer reduce kernel (verbatim port of ggml-metal.metal 7235-7275).
// Pairs with the NWG=32 vec kernel to match peer's actual runtime dispatch.
let peer_port_reduce_src: &'static str =
include_str!("shaders/flash_attn_vec_peer_port_f16_reduce.metal");
sources.insert(
"flash_attn_vec_peer_port_f16_reduce_dv256_nwg32".into(),
peer_port_reduce_src,
);
// ADR-029 iter-135: NWG=32 variant of the verbatim peer port. Same body as
// flash_attn_vec_peer_port_f16.metal with NWG=1→32. Pairs with iter-134 reduce kernel.
let peer_port_nwg32_src: &'static str =
include_str!("shaders/flash_attn_vec_peer_port_f16_nwg32.metal");
sources.insert(
"flash_attn_vec_peer_port_f16_nwg32_dk256_dv256".into(),
peer_port_nwg32_src,
);
// GPU sampling kernels — eliminate logits readback (Phase 6)
let argmax_src: &'static str = include_str!("shaders/argmax.metal");
sources.insert("argmax_f32".into(), argmax_src);
let softmax_sample_src: &'static str =
include_str!("shaders/softmax_sample.metal");
sources.insert("softmax_sample_f32".into(), softmax_sample_src);
// Top-K kernel for Q8 rerank: avoids full-logits readback.
let top_k_src: &'static str = include_str!("shaders/top_k.metal");
sources.insert("top_k_f32".into(), top_k_src);
// MoE GPU routing + weighted reduce (ADR-013 P13.3 perf).
// Replaces CPU softmax+topk round-trip and CPU weighted accumulate.
let moe_stk_src: &'static str =
include_str!("shaders/moe_softmax_topk.metal");
sources.insert("moe_softmax_topk_f32".into(), moe_stk_src);
let moe_wr_src: &'static str =
include_str!("shaders/moe_weighted_reduce.metal");
sources.insert("moe_weighted_reduce_f32".into(), moe_wr_src);
let sdpa_decode_src: &'static str =
include_str!("shaders/sdpa_decode.metal");
sources.insert("sdpa_decode".into(), sdpa_decode_src);
Self {
cache: HashMap::new(),
sources,
}
}
/// Register a shader source at runtime (useful for testing and dynamic
/// kernel generation).
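    ///
    /// Sketch (hypothetical kernel; the registered name must match the MSL
    /// kernel function name, since `get_pipeline` looks the function up by
    /// the same string):
    ///
    /// ```ignore
    /// let mut registry = KernelRegistry::new();
    /// registry.register_source("my_test_kernel", MY_TEST_KERNEL_SRC);
    /// assert!(!registry.is_cached("my_test_kernel")); // compiled lazily
    /// ```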
pub fn register_source(&mut self, name: impl Into<String>, source: &'static str) {
let name = name.into();
// Invalidate any cached pipeline for this name since the source changed.
self.cache.remove(&name);
self.sources.insert(name, source);
}
/// Get a compiled compute pipeline for the named kernel function.
///
/// On first call for a given name, this compiles the MSL source into a
/// Metal library, extracts the named function, and creates a
/// `ComputePipelineState`. Subsequent calls return the cached pipeline.
///
/// # Errors
///
/// * `MlxError::KernelNotFound` — no source registered for this name.
/// * `MlxError::ShaderCompilationError` — MSL compilation or pipeline
/// creation failed.
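    ///
    /// Usage sketch (assumes a Metal device is present; `argmax_f32` is one
    /// of the sources registered by `new()`):
    ///
    /// ```ignore
    /// let device = metal::Device::system_default().expect("Metal device");
    /// let mut registry = KernelRegistry::new();
    /// // First call compiles the MSL source; later calls are cache hits.
    /// let pipeline = registry.get_pipeline("argmax_f32", &device)?;
    /// assert!(registry.is_cached("argmax_f32"));
    /// ```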
pub fn get_pipeline(
&mut self,
name: &str,
device: &metal::DeviceRef,
) -> Result<&ComputePipelineState> {
if !self.cache.contains_key(name) {
// Slow path: compile the shader.
let source = self.sources.get(name).ok_or_else(|| {
MlxError::KernelNotFound(name.to_string())
})?;
let compile_opts = metal::CompileOptions::new();
let library = device
.new_library_with_source(source, &compile_opts)
.map_err(|msg| MlxError::ShaderCompilationError {
name: name.to_string(),
message: msg,
})?;
let function = library
.get_function(name, None)
.map_err(|msg| MlxError::ShaderCompilationError {
name: name.to_string(),
message: msg,
})?;
// Build the pipeline through a descriptor so we can attach a
// human-readable label. The label propagates into Instruments /
// xctrace Metal System Trace as the per-pipeline identifier
// (`metal-object-label` schema), giving us per-kernel attribution
// instead of the generic "Compute Command 0" placeholder.
//
// `MTLComputePipelineState.label` is read-only after creation per
// the Apple Metal spec; the only supported way to set it is via
// the descriptor before pipeline creation. ADR-015 iter9b.
let descriptor = ComputePipelineDescriptor::new();
descriptor.set_compute_function(Some(&function));
descriptor.set_label(name);
// ADR-028 iter-376: threadGroupSizeIsMultipleOfThreadExecutionWidth
// hint allows the Metal compiler to skip bounds checks and use more
// aggressive codegen. Opt-in via HF2Q_PIPELINE_TG_MULT_HINT=1.
// SAFETY: every dispatched threadgroup MUST be a multiple of 32 at
// runtime — Apple specifies undefined behavior otherwise. Our hot
// kernels use tg_size ∈ {32, 64, 256, 1024} (all multiples of 32).
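            // Compliant dispatch sketch (illustrative; the second argument,
            // threads-per-threadgroup, is the value that must be a multiple
            // of 32):
            //   encoder.dispatch_thread_groups(grid, MTLSize::new(256, 1, 1));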
if std::env::var("HF2Q_PIPELINE_TG_MULT_HINT").ok().as_deref() == Some("1") {
descriptor.set_thread_group_size_is_multiple_of_thread_execution_width(true);
}
let pipeline = device
.new_compute_pipeline_state(&descriptor)
.map_err(|msg| MlxError::ShaderCompilationError {
name: name.to_string(),
message: msg,
})?;
self.cache.insert(name.to_string(), pipeline);
}
// At this point the pipeline is guaranteed to be in the cache.
// We use `ok_or_else` instead of `expect` to satisfy the no-panic policy.
self.cache.get(name).ok_or_else(|| {
MlxError::KernelNotFound(name.to_string())
})
}
/// Get a compiled compute pipeline for the named kernel, specialized with
/// Metal function constants (both bool and i32 in one call).
///
/// `bool_constants` contains `(index, value)` pairs mapping to
/// `[[function_constant(index)]]` bool declarations in the MSL shader.
/// `int_constants` contains `(index, value)` pairs mapping to
/// `[[function_constant(index)]]` int (int32_t) declarations in the MSL
/// shader.
///
/// Pipelines are cached by a composite key:
/// `"<name>|<index>:b<0|1>|...|<index>:i<value>|..."`. The 'b' prefix
/// marks bool entries and the 'i' prefix marks i32 entries, making the
/// format unambiguous regardless of constant ordering. Distinct
/// `(name, constants)` combinations each compile to a separate pipeline;
/// the slow compilation path runs at most once per unique combination.
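    ///
    /// Key shape by example (the kernel name is illustrative; it matches the
    /// labelling example further below):
    ///
    /// ```text
    /// name  = "kernel_mul_mv_q4_0_f32"
    /// bools = [(0, true)]   ints = [(3, 32)]
    /// key   = "kernel_mul_mv_q4_0_f32|0:b1|3:i32"
    /// ```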
///
/// # Errors
///
/// * `MlxError::KernelNotFound` — no source registered for this name.
/// * `MlxError::ShaderCompilationError` — MSL compilation, function
/// specialisation, or pipeline creation failed.
pub fn get_pipeline_with_constants(
&mut self,
name: &str,
device: &metal::DeviceRef,
bool_constants: &[(usize, bool)],
int_constants: &[(usize, i32)],
) -> Result<&ComputePipelineState> {
// Build a composite cache key so distinct constant combinations each
// compile to their own pipeline. Bool entries use the 'b' type marker
// and i32 entries use 'i'; this prevents a collision between, e.g.,
// bool index 5 value 1 and int index 5 value 1.
let mut cache_key = name.to_string();
for &(index, value) in bool_constants {
cache_key.push('|');
cache_key.push_str(&index.to_string());
cache_key.push_str(if value { ":b1" } else { ":b0" });
}
for &(index, value) in int_constants {
cache_key.push('|');
cache_key.push_str(&index.to_string());
cache_key.push(':');
cache_key.push('i');
cache_key.push_str(&value.to_string());
}
if !self.cache.contains_key(&cache_key) {
// Slow path: compile the shader with function constant specialisation.
let source = self.sources.get(name).ok_or_else(|| {
MlxError::KernelNotFound(name.to_string())
})?;
let compile_opts = metal::CompileOptions::new();
let library = device
.new_library_with_source(source, &compile_opts)
.map_err(|msg| MlxError::ShaderCompilationError {
name: name.to_string(),
message: msg,
})?;
// Build the FunctionConstantValues object with all bool and i32
// constants. Metal's set_constant_value_at_index reads the value
// through a raw pointer; the pointed-to bytes must match the size
// declared in the MSL shader (1 byte for bool, 4 bytes for int).
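            // Illustrative MSL counterparts (hypothetical names; see the real
            // declaration in INT_FC_TEST_SHADER in the tests below):
            //   constant bool use_fast_path [[function_constant(0)]];
            //   constant int  block_size    [[function_constant(3)]];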
let fcv = FunctionConstantValues::new();
for &(index, value) in bool_constants {
// MTLDataType::Bool = 53 (metal-rs argument.rs).
// The Metal runtime reads it as an Objective-C BOOL (uint8_t).
let v: u8 = if value { 1 } else { 0 };
fcv.set_constant_value_at_index(
(&v as *const u8).cast::<std::ffi::c_void>(),
MTLDataType::Bool,
index as u64,
);
}
for &(index, value) in int_constants {
// MTLDataType::Int = 29 (metal-rs argument.rs).
// The Metal runtime reads 4 bytes as a signed 32-bit integer,
// matching the Metal shader type `constant int`.
fcv.set_constant_value_at_index(
(&value as *const i32).cast::<std::ffi::c_void>(),
MTLDataType::Int,
index as u64,
);
}
let function = library
.get_function(name, Some(fcv))
.map_err(|msg| MlxError::ShaderCompilationError {
name: name.to_string(),
message: msg,
})?;
// Label this specialisation with the full composite cache key
// (e.g. `kernel_mul_mv_q4_0_f32|0:b1|3:i32`) so xctrace Metal
// System Trace shows each function-constant variant as a distinct
// pipeline. Without this, all specialisations share a generic
// "Compute Command 0" identifier and we cannot attribute µs/token
// to a specific (kernel, constants) combination. ADR-015 iter9b.
let descriptor = ComputePipelineDescriptor::new();
descriptor.set_compute_function(Some(&function));
descriptor.set_label(&cache_key);
// ADR-028 iter-376: same hint as primary pipeline path.
if std::env::var("HF2Q_PIPELINE_TG_MULT_HINT").ok().as_deref() == Some("1") {
descriptor.set_thread_group_size_is_multiple_of_thread_execution_width(true);
}
let pipeline = device
.new_compute_pipeline_state(&descriptor)
.map_err(|msg| MlxError::ShaderCompilationError {
name: name.to_string(),
message: msg,
})?;
self.cache.insert(cache_key.clone(), pipeline);
}
self.cache.get(&cache_key).ok_or_else(|| {
MlxError::KernelNotFound(name.to_string())
})
}
/// Get a compiled compute pipeline for the named kernel, specialized with
/// Metal bool function constants.
///
/// The `bool_constants` slice contains `(index, value)` pairs. Each pair
/// maps to a `[[function_constant(index)]]` declaration in the MSL shader.
///
/// This is a thin wrapper around [`get_pipeline_with_constants`] that
/// passes an empty `int_constants` slice. Existing callers continue to
/// work without modification; the cache-key format for pure-bool pipelines
/// is compatible (bool entries carry the 'b' type marker, which is the
/// only format ever written by this wrapper).
///
/// # Errors
///
/// * `MlxError::KernelNotFound` — no source registered for this name.
/// * `MlxError::ShaderCompilationError` — MSL compilation, function
/// specialisation, or pipeline creation failed.
pub fn get_pipeline_with_bool_constants(
&mut self,
name: &str,
device: &metal::DeviceRef,
bool_constants: &[(usize, bool)],
) -> Result<&ComputePipelineState> {
self.get_pipeline_with_constants(name, device, bool_constants, &[])
}
/// Check if a pipeline for the given name is already compiled and cached.
pub fn is_cached(&self, name: &str) -> bool {
self.cache.contains_key(name)
}
/// Number of compiled pipelines currently in the cache.
pub fn cached_count(&self) -> usize {
self.cache.len()
}
/// Number of registered shader sources.
pub fn source_count(&self) -> usize {
self.sources.len()
}
}
impl Default for KernelRegistry {
fn default() -> Self {
Self::new()
}
}
#[cfg(test)]
mod tests {
use super::*;
    /// Minimal Metal shader that uses a single int function constant.
    ///
    /// The kernel reads `test_N` and writes it to the first element of the
    /// output buffer, so the constant is genuinely *used* and Metal demands
    /// a value for it at specialisation time.
    ///
    /// The shader is intentionally trivial: the tests below only need it to
    /// *compile* into distinct specialisations (e.g. N=4 vs N=8); the kernel
    /// logic itself is never executed.
const INT_FC_TEST_SHADER: &str = r#"
#include <metal_stdlib>
using namespace metal;
constant int test_N [[function_constant(100)]];
kernel void int_fc_test_kernel(
device int* out [[buffer(0)]],
uint tid [[thread_position_in_grid]])
{
if (tid == 0) {
out[0] = test_N;
}
}
"#;
/// Verify that `get_pipeline_with_constants` produces distinct cached
/// pipelines for different i32 function-constant values, and that
/// `get_pipeline_with_bool_constants` (the backward-compat wrapper) still
/// works correctly with the new 'b'-prefixed cache-key format.
///
    /// This test requires a real Metal device; `Device::system_default()`
    /// fails fast via `expect` on hosts without Metal support.
#[test]
fn test_int_fc_distinct_pipelines_and_bool_compat() {
let device = metal::Device::system_default()
.expect("no Metal device — run on Apple Silicon or x86 Mac with Metal support");
let mut registry = KernelRegistry::new();
// Register the inline test shader under a name that cannot collide with
// any production kernel.
registry.register_source("int_fc_test_kernel", INT_FC_TEST_SHADER);
// Compile with N=4.
let p4_ptr = registry
.get_pipeline_with_constants(
"int_fc_test_kernel",
&device,
&[], // no bool constants
&[(100, 4_i32)], // int constant index 100 = 4
)
.expect("pipeline N=4 should compile") as *const _;
        // Record the cache size after the N=4 compile. `new()` registers
        // sources but compiles nothing, so the assertions below check
        // relative growth rather than absolute counts.
let count_after_n4 = registry.cached_count();
// Compile with N=8 — must produce a SEPARATE pipeline.
let p8_ptr = registry
.get_pipeline_with_constants(
"int_fc_test_kernel",
&device,
&[],
&[(100, 8_i32)],
)
.expect("pipeline N=8 should compile") as *const _;
// Cache must have grown by exactly 1.
assert_eq!(
registry.cached_count(),
count_after_n4 + 1,
"N=8 must produce a new cache entry"
);
// The two pipelines must be distinct objects in the cache.
assert_ne!(
p4_ptr, p8_ptr,
"N=4 and N=8 specialisations must be separate ComputePipelineState objects"
);
// A second call with N=4 must return the SAME pipeline (cache hit, no
// new compilation).
let p4_again_ptr = registry
.get_pipeline_with_constants(
"int_fc_test_kernel",
&device,
&[],
&[(100, 4_i32)],
)
.expect("pipeline N=4 cache hit should succeed") as *const _;
assert_eq!(
registry.cached_count(),
count_after_n4 + 1,
"repeated N=4 call must be a cache hit, not a new entry"
);
assert_eq!(
p4_ptr, p4_again_ptr,
"repeated N=4 call must return the same pipeline pointer"
);
        // Verify backward compatibility: get_pipeline_with_bool_constants must
        // still route through get_pipeline_with_constants and produce a cached
        // pipeline without panicking.
        //
        // We register a separate bare kernel with no function constants at
        // all. Reusing INT_FC_TEST_SHADER here would fail: its kernel reads
        // test_N unconditionally, so Metal requires a value for that constant
        // at specialisation time and the empty constants slices would be
        // rejected. A bare kernel keeps the focus on the wrapper's call path
        // and its 'b'-prefixed cache-key format.
const BARE_SHADER: &str = r#"
#include <metal_stdlib>
using namespace metal;
kernel void bare_kernel(device int* out [[buffer(0)]], uint tid [[thread_position_in_grid]]) {
if (tid == 0) { out[0] = 42; }
}
"#;
registry.register_source("bare_kernel", BARE_SHADER);
let count_before_bool = registry.cached_count();
let _bool_pipeline = registry
.get_pipeline_with_bool_constants("bare_kernel", &device, &[])
.expect("bool-constants wrapper with empty slice must succeed");
assert_eq!(
registry.cached_count(),
count_before_bool + 1,
"bool-constants wrapper must insert one new cache entry"
);
}
/// Verify that the `MTLComputePipelineState.label` produced by
/// `get_pipeline` and `get_pipeline_with_constants` actually propagates
/// from the descriptor to the resulting pipeline state.
///
/// This is the in-process smoke check for ADR-015 iter9b: we cannot
/// reach into xctrace from Rust, but we can read back the same `label`
/// property xctrace consumes via `ComputePipelineStateRef::label()`.
/// If labels are missing or wrong here, the MST trace will also show
/// generic identifiers — so this test gates the iter9 retry's
/// per-Q4_0-kernel attribution.
#[test]
fn test_pipeline_labels_propagate_for_mst() {
let device = metal::Device::system_default()
.expect("no Metal device — run on Apple Silicon or x86 Mac with Metal support");
let mut registry = KernelRegistry::new();
// Reuse the same trivial shaders as the int-FC test.
registry.register_source("int_fc_test_kernel", INT_FC_TEST_SHADER);
const BARE_SHADER_LABEL_TEST: &str = r#"
#include <metal_stdlib>
using namespace metal;
kernel void label_smoke_kernel(device int* out [[buffer(0)]], uint tid [[thread_position_in_grid]]) {
if (tid == 0) { out[0] = 7; }
}
"#;
registry.register_source("label_smoke_kernel", BARE_SHADER_LABEL_TEST);
// Plain get_pipeline path — label must equal the kernel name.
// Capture as owned String so the cache borrow is released before
// the next get_pipeline_with_constants call below.
let plain_label = registry
.get_pipeline("label_smoke_kernel", &device)
.expect("plain pipeline must compile")
.label()
.to_string();
assert_eq!(
plain_label, "label_smoke_kernel",
"get_pipeline must label the pipeline with the kernel name (xctrace MST attribution)"
);
// Constants path — label must equal the composite cache key so each
// function-constant variant is individually attributable in MST.
// We capture the label as an owned String to release the borrow on
// the cache before fetching the next specialisation.
let label_v7 = registry
.get_pipeline_with_constants(
"int_fc_test_kernel",
&device,
&[],
&[(100, 7_i32)],
)
.expect("specialised pipeline must compile")
.label()
.to_string();
assert_eq!(
label_v7, "int_fc_test_kernel|100:i7",
"get_pipeline_with_constants must label with the cache_key so each \
specialisation is distinct in xctrace MST"
);
// A second specialisation must produce a different label.
let label_v13 = registry
.get_pipeline_with_constants(
"int_fc_test_kernel",
&device,
&[],
&[(100, 13_i32)],
)
.expect("second specialised pipeline must compile")
.label()
.to_string();
assert_eq!(label_v13, "int_fc_test_kernel|100:i13");
assert_ne!(
label_v7, label_v13,
"distinct constant values must yield distinct pipeline labels"
);
}
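    /// Companion check for `register_source`'s documented cache invalidation:
    /// re-registering a name must evict any stale cached pipeline so the next
    /// `get_pipeline` call recompiles from the new source. A minimal sketch
    /// under the same assumption as the tests above (a real Metal device is
    /// present); the kernel itself is never executed.
    #[test]
    fn test_register_source_invalidates_cached_pipeline() {
        const INVALIDATE_SHADER: &str = r#"
        #include <metal_stdlib>
        using namespace metal;
        kernel void invalidate_test_kernel(device int* out [[buffer(0)]], uint tid [[thread_position_in_grid]]) {
            if (tid == 0) { out[0] = 1; }
        }
        "#;
        let device = metal::Device::system_default()
            .expect("no Metal device — run on Apple Silicon or x86 Mac with Metal support");
        let mut registry = KernelRegistry::new();
        registry.register_source("invalidate_test_kernel", INVALIDATE_SHADER);
        registry
            .get_pipeline("invalidate_test_kernel", &device)
            .expect("pipeline must compile");
        assert!(registry.is_cached("invalidate_test_kernel"));
        // Re-registering under the same name must evict the cached pipeline;
        // the replacement source is identical here and is never compiled.
        registry.register_source("invalidate_test_kernel", INVALIDATE_SHADER);
        assert!(
            !registry.is_cached("invalidate_test_kernel"),
            "register_source must invalidate the stale cached pipeline"
        );
    }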
}