1#![allow(unsafe_op_in_unsafe_fn)]
24use crate::arena::Arena;
27use crate::op_registry::CpuKernel;
28use rlx_ir::op::{Activation, BinaryOp, CmpOp, ReduceOp};
29use rlx_ir::{Graph, NodeId, Op, Shape};
30use std::collections::HashMap;
31use std::sync::Arc;
32
/// One pre-resolved operation in a [`ThunkSchedule`].
///
/// Every variant is a plain record of `usize` location fields plus `u32` /
/// `f32` / `bool` shape and configuration fields. In `compile_thunks` the
/// location fields are filled from `node_offset`, i.e. the arena byte offset
/// of a graph node, or `usize::MAX` when the node has no arena buffer. An
/// absent optional operand (e.g. the bias of `FusedResidualLN` when
/// `has_bias` is false) is stored as offset `0`.
///
/// NOTE(review): the executor that interprets these variants is not part of
/// this chunk. The per-variant notes below are inferred from variant and
/// field names unless they cite `thunk_read_offsets` or `compile_thunks`;
/// confirm them against the executor before relying on them.
#[derive(Clone)]
pub enum Thunk {
    /// Placeholder: `compile_thunks` emits it for pure views and for
    /// `Input` / `Param` / `Constant` nodes; `ThunkSchedule::strip_nops`
    /// removes it.
    Nop,
    /// Presumably an f32 matrix multiply writing `c` from `a` (m×k) and
    /// `b` (k×n). `thunk_read_offsets` reports `a` and `b` as reads.
    Sgemm {
        a: usize,
        b: usize,
        c: usize,
        m: u32,
        k: u32,
        n: u32,
    },
    /// Presumably an f64 dense linear solve for `x` with an `n`×`n` system
    /// and `nrhs` right-hand sides. Reads `a` and `b` per
    /// `thunk_read_offsets`.
    DenseSolveF64 {
        a: usize,
        b: usize,
        x: usize,
        n: u32,
        nrhs: u32,
    },
    /// f32 counterpart of `DenseSolveF64` (same fields).
    DenseSolveF32 {
        a: usize,
        b: usize,
        x: usize,
        n: u32,
        nrhs: u32,
    },
    /// `DenseSolveF64` with a leading `batch` count.
    BatchedDenseSolveF64 {
        a: usize,
        b: usize,
        x: usize,
        batch: u32,
        n: u32,
        nrhs: u32,
    },
    /// f32 counterpart of `BatchedDenseSolveF64`.
    /// NOTE(review): unlike its F64 sibling, this variant has no arm in
    /// `thunk_read_offsets` and so reports no reads — confirm intent.
    BatchedDenseSolveF32 {
        a: usize,
        b: usize,
        x: usize,
        batch: u32,
        n: u32,
        nrhs: u32,
    },
    /// Presumably a batched f64 matrix multiply. Reads `a` and `b` per
    /// `thunk_read_offsets`.
    BatchedDgemmF64 {
        a: usize,
        b: usize,
        c: usize,
        batch: u32,
        m: u32,
        k: u32,
        n: u32,
    },
    /// Presumably a batched f32 matrix multiply. Reads `a` and `b` per
    /// `thunk_read_offsets`.
    BatchedSgemm {
        a: usize,
        b: usize,
        c: usize,
        batch: u32,
        m: u32,
        k: u32,
        n: u32,
    },
    /// Presumably the f64 counterpart of `Sgemm`.
    /// NOTE(review): has no arm in `thunk_read_offsets` (reports no reads).
    Dgemm {
        a: usize,
        b: usize,
        c: usize,
        m: u32,
        k: u32,
        n: u32,
    },
    /// f64 counterpart of `Transpose` (same fields).
    TransposeF64 {
        src: usize,
        dst: usize,
        in_total: u32,
        out_dims: Vec<u32>,
        in_strides: Vec<u32>,
    },
    /// Out-of-place activation `kind` over `len` elements, presumably f64.
    ActivationF64 {
        src: usize,
        dst: usize,
        len: u32,
        kind: Activation,
    },
    /// Presumably |z|² of complex input into a real f32 output. Reads `src`
    /// per `thunk_read_offsets`.
    ComplexNormSqF32 {
        src: usize,
        dst: usize,
        len: u32,
    },
    /// Backward of `ComplexNormSqF32`. Reads `z` and `g` per
    /// `thunk_read_offsets`; `dz` is presumably the output.
    ComplexNormSqBackwardF32 {
        z: usize,
        g: usize,
        dz: usize,
        len: u32,
    },
    /// Presumably complex conjugation. Reads `src` per `thunk_read_offsets`.
    ConjugateC64 { src: usize, dst: usize, len: u32 },
    /// Out-of-place activation `kind` over `len` elements, presumably complex.
    ActivationC64 {
        src: usize,
        dst: usize,
        len: u32,
        kind: Activation,
    },
    /// Presumably an f64 sum-reduction over the middle axis of an
    /// `outer` × `reduced` × `inner` view.
    ReduceSumF64 {
        src: usize,
        dst: usize,
        outer: u32,
        reduced: u32,
        inner: u32,
    },
    /// f64 counterpart of `Copy` (same fields).
    CopyF64 { src: usize, dst: usize, len: u32 },
    /// f64 counterpart of `BinaryFull` (same fields). Reads `lhs` and `rhs`
    /// per `thunk_read_offsets`.
    BinaryFullF64 {
        lhs: usize,
        rhs: usize,
        dst: usize,
        len: u32,
        lhs_len: u32,
        rhs_len: u32,
        op: BinaryOp,
        out_dims_bcast: Vec<u32>,
        bcast_lhs_strides: Vec<u32>,
        bcast_rhs_strides: Vec<u32>,
    },
    /// f64 counterpart of `Concat` (same fields). The first element of each
    /// `inputs` pair is reported as a read by `thunk_read_offsets`.
    ConcatF64 {
        dst: usize,
        outer: u32,
        inner: u32,
        total_axis: u32,
        inputs: Vec<(usize, u32)>,
    },
    /// Complex counterpart of `BinaryFull` (same fields). Reads `lhs` and
    /// `rhs` per `thunk_read_offsets`.
    BinaryFullC64 {
        lhs: usize,
        rhs: usize,
        dst: usize,
        len: u32,
        lhs_len: u32,
        rhs_len: u32,
        op: BinaryOp,
        out_dims_bcast: Vec<u32>,
        bcast_lhs_strides: Vec<u32>,
        bcast_rhs_strides: Vec<u32>,
    },
    /// Loop over a nested `body` schedule for `length` steps.
    /// `thunk_read_offsets` reports `outer_init_off` and the second element
    /// of every `xs_inputs` triple as reads.
    /// NOTE(review): the meaning of the triple elements in `xs_inputs` /
    /// `bcast_inputs` and of `num_checkpoints` is not visible here.
    Scan {
        body: Arc<ThunkSchedule>,
        body_init: Arc<Vec<u8>>, body_input_off: usize, body_output_off: usize, outer_init_off: usize, outer_final_off: usize, length: u32,
        carry_bytes: u32, save_trajectory: bool,
        xs_inputs: Arc<Vec<(usize, usize, u32)>>,
        bcast_inputs: Arc<Vec<(usize, usize, u32)>>,
        num_checkpoints: u32,
    },

    /// Backward pass of `Scan`, presumably producing the gradient w.r.t. the
    /// initial carry (`outer_dinit_off`). `thunk_read_offsets` reports
    /// `outer_init_off`, `outer_traj_off`, `outer_upstream_off` and the first
    /// element of every `outer_xs_offs` pair as reads.
    ScanBackward {
        body_vjp: Arc<ThunkSchedule>,
        body_init: Arc<Vec<u8>>,
        body_carry_in_off: usize, body_x_offs: Arc<Vec<usize>>, body_d_output_off: usize, body_dcarry_out_off: usize, outer_init_off: usize, outer_traj_off: usize, outer_upstream_off: usize, outer_xs_offs: Arc<Vec<(usize, u32)>>,
        outer_dinit_off: usize, length: u32,
        carry_bytes: u32,
        carry_elem_size: u32,
        save_trajectory: bool, num_checkpoints: u32,
        forward_body: Option<Arc<ThunkSchedule>>,
        forward_body_init: Option<Arc<Vec<u8>>>,
        forward_body_carry_in_off: usize,
        forward_body_output_off: usize,
        forward_body_x_offs: Arc<Vec<usize>>,
    },

    /// Backward pass of `Scan`, presumably producing the gradient w.r.t. the
    /// per-step inputs (`outer_dxs_off`). Reports the same reads as
    /// `ScanBackward` in `thunk_read_offsets`.
    ScanBackwardXs {
        body_vjp: Arc<ThunkSchedule>,
        body_init: Arc<Vec<u8>>,
        body_carry_in_off: usize,
        body_x_offs: Arc<Vec<usize>>,
        body_d_output_off: usize,
        body_dcarry_out_off: usize,
        body_dxs_out_off: usize, outer_init_off: usize,
        outer_traj_off: usize,
        outer_upstream_off: usize,
        outer_xs_offs: Arc<Vec<(usize, u32)>>,
        outer_dxs_off: usize, length: u32,
        carry_bytes: u32,
        carry_elem_size: u32,
        per_step_bytes: u32, save_trajectory: bool,
        num_checkpoints: u32,
        forward_body: Option<Arc<ThunkSchedule>>,
        forward_body_init: Option<Arc<Vec<u8>>>,
        forward_body_carry_in_off: usize,
        forward_body_output_off: usize,
        forward_body_x_offs: Arc<Vec<usize>>,
    },
    /// Invocation of a nested `body` schedule. `thunk_read_offsets` reports
    /// the second element of every `inputs` triple as a read.
    CustomFn {
        body: Arc<ThunkSchedule>,
        body_init: Arc<Vec<u8>>,
        inputs: Arc<Vec<(usize, usize, u32)>>,
        body_output_off: usize,
        outer_output_off: usize,
        out_bytes: u32,
    },
    /// Built by `compile_thunks` from `Op::FusedMatMulBiasAct`: `a`, `w`,
    /// `bias` come from inputs 0..=2 and `c` from the node itself; `n` is the
    /// last output dimension, `m` the remaining element count, and `k` the
    /// element count of input 0 divided by `m`.
    FusedMmBiasAct {
        a: usize,
        w: usize,
        bias: usize,
        c: usize,
        m: u32,
        k: u32,
        n: u32,
        act: Option<Activation>,
    },
    /// Built by `compile_thunks` from `Op::FusedResidualLN`: `h` is the last
    /// output dimension and `rows` the remaining element count; `bias` is `0`
    /// when `has_bias` is false.
    FusedResidualLN {
        x: usize,
        res: usize,
        bias: usize,
        g: usize,
        b: usize,
        out: usize,
        rows: u32,
        h: u32,
        eps: f32,
        has_bias: bool,
    },
    /// Same fields as `FusedResidualLN`, presumably with RMS normalisation.
    FusedResidualRmsNorm {
        x: usize,
        res: usize,
        bias: usize,
        g: usize,
        b: usize,
        out: usize,
        rows: u32,
        h: u32,
        eps: f32,
        has_bias: bool,
    },
    /// Presumably adds a length-`n` bias to each of `m` rows. Reads `src` and
    /// `bias` per `thunk_read_offsets`.
    BiasAdd {
        src: usize,
        bias: usize,
        dst: usize,
        m: u32,
        n: u32,
    },
    /// Element-wise binary `op`, presumably with broadcasting described by
    /// `out_dims_bcast` and the two stride vectors. Reads `lhs` and `rhs` per
    /// `thunk_read_offsets`.
    BinaryFull {
        lhs: usize,
        rhs: usize,
        dst: usize,
        len: u32,
        lhs_len: u32,
        rhs_len: u32,
        op: BinaryOp,
        out_dims_bcast: Vec<u32>,
        bcast_lhs_strides: Vec<u32>,
        bcast_rhs_strides: Vec<u32>,
    },
    /// Activation applied to a single buffer; `data` is reported as a read by
    /// `thunk_read_offsets`.
    ActivationInPlace {
        data: usize,
        len: u32,
        act: Activation,
    },
    /// Presumably a row lookup of `num_idx` indices into `table`. Reads
    /// `table` and `idx` per `thunk_read_offsets`.
    Gather {
        table: usize,
        table_len: u32,
        idx: usize,
        dst: usize,
        num_idx: u32,
        trailing: u32,
    },
    /// Presumably a strided slice copy of `outer` rows of `inner` elements,
    /// each `elem_bytes` wide. Reads `src` per `thunk_read_offsets`.
    Narrow {
        src: usize,
        dst: usize,
        outer: u32,
        src_stride: u32,
        dst_stride: u32,
        inner: u32,
        elem_bytes: u8,
    },
    /// Presumably a contiguous copy of `len` elements. Reads `src` per
    /// `thunk_read_offsets`.
    Copy { src: usize, dst: usize, len: u32 },
    /// Presumably layer normalisation over rows of width `h`. Reads `src`,
    /// `g`, `b` per `thunk_read_offsets`.
    LayerNorm {
        src: usize,
        g: usize,
        b: usize,
        dst: usize,
        rows: u32,
        h: u32,
        eps: f32,
    },
    /// Presumably group normalisation over an `n`×`c`×`h`×`w` tensor. Reads
    /// `src`, `g`, `b` per `thunk_read_offsets`.
    GroupNorm {
        src: usize,
        g: usize,
        b: usize,
        dst: usize,
        n: u32,
        c: u32,
        h: u32,
        w: u32,
        num_groups: u32,
        eps: f32,
    },
    /// Presumably channel-wise layer normalisation on an `n`×`c`×`h`×`w`
    /// tensor.
    /// NOTE(review): has no arm in `thunk_read_offsets` (reports no reads).
    LayerNorm2d {
        src: usize,
        g: usize,
        b: usize,
        dst: usize,
        n: u32,
        c: u32,
        h: u32,
        w: u32,
        eps: f32,
    },
    /// Presumably a 2-D transposed convolution; kernel (`kh`,`kw`), stride
    /// (`sh`,`sw`), padding (`ph`,`pw`), dilation (`dh`,`dw`).
    ConvTranspose2d {
        src: usize,
        weight: usize,
        dst: usize,
        n: u32,
        c_in: u32,
        h: u32,
        w_in: u32,
        c_out: u32,
        h_out: u32,
        w_out: u32,
        kh: u32,
        kw: u32,
        sh: u32,
        sw: u32,
        ph: u32,
        pw: u32,
        dh: u32,
        dw: u32,
        groups: u32,
    },
    /// Presumably 2× nearest-neighbour upsampling. Reads `src` per
    /// `thunk_read_offsets`.
    ResizeNearest2x {
        src: usize,
        dst: usize,
        n: u32,
        c: u32,
        h: u32,
        w: u32,
    },
    /// Presumably a 2-D axial rotary embedding. Reads `src` per
    /// `thunk_read_offsets`.
    AxialRope2d {
        src: usize,
        dst: usize,
        batch: u32,
        seq: u32,
        hidden: u32,
        end_x: u32,
        end_y: u32,
        head_dim: u32,
        num_heads: u32,
        theta: f32,
        repeat_factor: u32,
    },
    /// Presumably RMS normalisation over rows of width `h`. Reads `src`, `g`,
    /// `b` per `thunk_read_offsets`.
    RmsNorm {
        src: usize,
        g: usize,
        b: usize,
        dst: usize,
        rows: u32,
        h: u32,
        eps: f32,
    },
    /// Softmax over a single buffer; `data` is reported as a read by
    /// `thunk_read_offsets`.
    Softmax { data: usize, rows: u32, cols: u32 },
    /// Presumably a per-row cumulative sum. Reads `src` per
    /// `thunk_read_offsets`.
    Cumsum {
        src: usize,
        dst: usize,
        rows: u32,
        cols: u32,
        exclusive: bool,
    },
    /// Presumably a state-space selective scan. Reads `x`, `delta`, `a`, `b`,
    /// `c` per `thunk_read_offsets`.
    SelectiveScan {
        x: usize,
        delta: usize,
        a: usize,
        b: usize,
        c: usize,
        dst: usize,
        batch: u32,
        seq: u32,
        hidden: u32,
        state_size: u32,
    },

    /// Reads `q`, `k`, `v`, `g`, `beta` per `thunk_read_offsets`; `state` is
    /// reported as a read only when it is non-zero.
    GatedDeltaNet {
        q: usize,
        k: usize,
        v: usize,
        g: usize,
        beta: usize,
        state: usize,
        dst: usize,
        batch: u32,
        seq: u32,
        heads: u32,
        state_size: u32,
    },

    /// Presumably a 1×1 convolution with `hw` spatial positions. Reads `src`
    /// and `weight` per `thunk_read_offsets`.
    Conv2D1x1 {
        src: usize,
        weight: usize,
        dst: usize,
        n: u32,
        c_in: u32,
        c_out: u32,
        hw: u32,
    },

    /// Presumably a matmul against block-quantised weights. Reads `x`, `w_q`,
    /// `scale`, `zp` per `thunk_read_offsets`.
    DequantMatMul {
        x: usize,
        w_q: usize, scale: usize, zp: usize, dst: usize,
        m: u32,
        k: u32,
        n: u32,
        block_size: u32,
        is_asymmetric: bool,
    },

    /// Presumably a matmul against GGUF-quantised weights in `scheme`. Reads
    /// `x` and `w_q` per `thunk_read_offsets`.
    DequantMatMulGguf {
        x: usize, w_q: usize, dst: usize, m: u32,
        k: u32,
        n: u32,
        scheme: rlx_ir::quant::QuantScheme,
    },

    /// 4-bit counterpart of `DequantMatMul` (same fields and same reads).
    DequantMatMulInt4 {
        x: usize,
        w_q: usize,
        scale: usize,
        zp: usize,
        dst: usize,
        m: u32,
        k: u32,
        n: u32,
        block_size: u32,
        is_asymmetric: bool,
    },

    /// Presumably a matmul against FP8 weights; `e5m2` presumably selects the
    /// E5M2 encoding over E4M3. Reads `x`, `w_q`, `scale` per
    /// `thunk_read_offsets`.
    DequantMatMulFp8 {
        x: usize,
        w_q: usize,
        scale: usize,
        dst: usize,
        m: u32,
        k: u32,
        n: u32,
        e5m2: bool,
    },

    /// Presumably a matmul against NVFP4 weights. Reads `x`, `w_q`, `scale`,
    /// `global_scale` per `thunk_read_offsets`.
    DequantMatMulNvfp4 {
        x: usize,
        w_q: usize,
        scale: usize,
        global_scale: usize,
        dst: usize,
        m: u32,
        k: u32,
        n: u32,
    },

    /// Presumably a matmul with a rank-`r` LoRA adapter scaled by `scale`.
    /// Reads `x`, `w`, `a`, `b` per `thunk_read_offsets`.
    LoraMatMul {
        x: usize,
        w: usize,
        a: usize,
        b: usize,
        dst: usize,
        m: u32,
        k: u32,
        n: u32,
        r: u32,
        scale: f32,
    },
    /// Presumably token sampling from `batch` rows of `vocab` logits. Reads
    /// `logits` per `thunk_read_offsets`.
    Sample {
        logits: usize,
        dst: usize,
        batch: u32,
        vocab: u32,
        top_k: u32, top_p: f32, temperature: f32, seed: u64,
    },
    /// Presumably scaled-dot-product attention. Reads `q`, `k`, `v`, `mask`
    /// per `thunk_read_offsets`.
    /// NOTE(review): the layout selected by `bhsd` and the units of the
    /// `*_row_stride` fields are not visible here.
    Attention {
        q: usize,
        k: usize,
        v: usize,
        mask: usize,
        out: usize,
        batch: u32,
        seq: u32,
        kv_seq: u32,
        heads: u32,
        head_dim: u32,
        mask_kind: rlx_ir::op::MaskKind,
        q_row_stride: u32,
        k_row_stride: u32,
        v_row_stride: u32,
        bhsd: bool,
    },
    /// Backward of `Attention` w.r.t. the operand selected by `wrt`. Reads
    /// `q`, `k`, `v`, `dy` per `thunk_read_offsets`; `mask` is reported as a
    /// read only when it is non-zero.
    AttentionBackward {
        q: usize,
        k: usize,
        v: usize,
        dy: usize,
        mask: usize,
        out: usize,
        batch: u32,
        seq: u32,
        kv_seq: u32,
        heads: u32,
        head_dim: u32,
        mask_kind: rlx_ir::op::MaskKind,
        wrt: rlx_ir::op::AttentionBwdWrt,
        bhsd: bool,
    },
    /// Presumably a rotary position embedding. Reads `src`, `cos`, `sin` per
    /// `thunk_read_offsets`.
    Rope {
        src: usize,
        cos: usize,
        sin: usize,
        dst: usize,
        batch: u32,
        seq: u32,
        hidden: u32,
        head_dim: u32,
        n_rot: u32,
        cos_len: u32,
        src_row_stride: u32,
    },
    /// Presumably a fused attention block. Reads `hidden`, `qkv_w`, `out_w`,
    /// `mask`, `qkv_b`, `out_b`, `cos`, `sin` per `thunk_read_offsets`.
    FusedAttnBlock {
        hidden: usize,
        qkv_w: usize,
        out_w: usize,
        mask: usize,
        out: usize,
        qkv_b: usize,
        out_b: usize, cos: usize,
        sin: usize,
        cos_len: u32, batch: u32,
        seq: u32,
        hs: u32,
        nh: u32,
        dh: u32,
        has_bias: bool,
        has_rope: bool,
    },
    /// Presumably a whole fused BERT encoder layer (attention, two layer
    /// norms, two-layer feed-forward).
    FusedBertLayer {
        hidden: usize,
        qkv_w: usize,
        qkv_b: usize,
        out_w: usize,
        out_b: usize,
        mask: usize,
        ln1_g: usize,
        ln1_b: usize,
        eps1: f32,
        fc1_w: usize,
        fc1_b: usize,
        fc2_w: usize,
        fc2_b: usize,
        ln2_g: usize,
        ln2_b: usize,
        eps2: f32,
        out: usize,
        batch: u32,
        seq: u32,
        hs: u32,
        nh: u32,
        dh: u32,
        int_dim: u32,
    },
    /// Presumably a whole fused Nomic-style encoder layer (rotary attention
    /// plus a gated feed-forward using `fc11_w` / `fc12_w` / `fc2_w`).
    FusedNomicLayer {
        hidden: usize,
        qkv_w: usize,
        out_w: usize,
        mask: usize,
        cos: usize,
        sin: usize,
        cos_len: u32,
        ln1_g: usize,
        ln1_b: usize,
        eps1: f32,
        fc11_w: usize,
        fc12_w: usize,
        fc2_w: usize,
        ln2_g: usize,
        ln2_b: usize,
        eps2: f32,
        out: usize,
        batch: u32,
        seq: u32,
        hs: u32,
        nh: u32,
        dh: u32,
        int_dim: u32,
    },
    /// Presumably a fused SwiGLU over a buffer split into two halves of
    /// `n_half`. Reads `src` per `thunk_read_offsets`.
    FusedSwiGLU {
        src: usize,
        dst: usize,
        n_half: u32,
        total: u32,
        gate_first: bool,
    },
    /// Presumably concatenation along one axis. The first element of each
    /// `inputs` pair is reported as a read by `thunk_read_offsets`.
    Concat {
        dst: usize,
        outer: u32,
        inner: u32,
        total_axis: u32,
        inputs: Vec<(usize, u32)>,
    },
    /// Element-wise comparison `op` over `len` elements.
    Compare {
        lhs: usize,
        rhs: usize,
        dst: usize,
        len: u32,
        op: CmpOp,
    },
    /// Presumably reduction `op` over the middle axis of an
    /// `outer` × `reduced` × `inner` view.
    Reduce {
        src: usize,
        dst: usize,
        outer: u32,
        reduced: u32,
        inner: u32,
        op: ReduceOp,
    },
    /// Presumably top-`k` selection along an axis of size `axis_dim`.
    TopK {
        src: usize,
        dst: usize,
        outer: u32,
        axis_dim: u32,
        k: u32,
    },
    /// Presumably a per-row matmul against one of `num_experts` weight
    /// matrices selected by `expert_idx`.
    GroupedMatMul {
        input: usize,
        weight: usize,
        expert_idx: usize,
        dst: usize,
        m: u32,
        k_dim: u32,
        n: u32,
        num_experts: u32,
    },
    /// `GroupedMatMul` against GGUF-quantised expert weights in `scheme`.
    DequantGroupedMatMulGguf {
        input: usize,
        w_q: usize,
        expert_idx: usize,
        dst: usize,
        m: u32,
        k_dim: u32,
        n: u32,
        num_experts: u32,
        scheme: rlx_ir::quant::QuantScheme,
    },
    /// Presumably dequantises GGUF expert weights into `dst`.
    DequantMoEWeightsGguf {
        w_q: usize,
        dst: usize,
        k_dim: u32,
        n: u32,
        num_experts: u32,
        scheme: rlx_ir::quant::QuantScheme,
    },
    /// Presumably accumulates `num_updates` rows of `updates` into `dst` at
    /// positions given by `indices`.
    ScatterAdd {
        updates: usize,
        indices: usize,
        dst: usize,
        num_updates: u32,
        out_dim: u32,
        trailing: u32,
    },
    /// Presumably element-wise select between `on_true` and `on_false`.
    Where {
        cond: usize,
        on_true: usize,
        on_false: usize,
        dst: usize,
        len: u32,
    },
    /// Presumably a general permutation described by output dimensions and
    /// input strides.
    Transpose {
        src: usize,
        dst: usize,
        in_total: u32,
        out_dims: Vec<u32>,
        in_strides: Vec<u32>,
    },
    /// Presumably a gather along an interior axis of size `axis_dim`.
    GatherAxis {
        table: usize,
        idx: usize,
        dst: usize,
        outer: u32,
        axis_dim: u32,
        num_idx: u32,
        trailing: u32,
    },
    /// Presumably 2-D pooling; `kind` selects the reduction.
    Pool2D {
        src: usize,
        dst: usize,
        n: u32,
        c: u32,
        h: u32,
        w: u32,
        h_out: u32,
        w_out: u32,
        kh: u32,
        kw: u32,
        sh: u32,
        sw: u32,
        ph: u32,
        pw: u32,
        kind: ReduceOp,
    },
    /// Presumably a 2-D convolution; kernel (`kh`,`kw`), stride (`sh`,`sw`),
    /// padding (`ph`,`pw`), dilation (`dh`,`dw`).
    Conv2D {
        src: usize,
        weight: usize,
        dst: usize,
        n: u32,
        c_in: u32,
        h: u32,
        w: u32,
        c_out: u32,
        h_out: u32,
        w_out: u32,
        kh: u32,
        kw: u32,
        sh: u32,
        sw: u32,
        ph: u32,
        pw: u32,
        dh: u32,
        dw: u32,
        groups: u32,
    },

    /// Presumably an integer-quantised matmul with zero points and a
    /// requantisation multiplier `mult`.
    QMatMul {
        x: usize,
        w: usize,
        bias: usize,
        out: usize,
        m: u32,
        k: u32,
        n: u32,
        x_zp: i32,
        w_zp: i32,
        out_zp: i32,
        mult: f32,
    },

    /// Presumably an integer-quantised 2-D convolution (see `Conv2D` and
    /// `QMatMul` for the field groups).
    QConv2d {
        x: usize,
        w: usize,
        bias: usize,
        out: usize,
        n: u32,
        c_in: u32,
        h: u32,
        w_in: u32,
        c_out: u32,
        h_out: u32,
        w_out: u32,
        kh: u32,
        kw: u32,
        sh: u32,
        sw: u32,
        ph: u32,
        pw: u32,
        dh: u32,
        dw: u32,
        groups: u32,
        x_zp: i32,
        w_zp: i32,
        out_zp: i32,
        mult: f32,
    },

    /// Presumably float → quantised conversion with scales and zero points
    /// carried inline.
    Quantize {
        x: usize,
        q: usize,
        len: u32,
        chan_axis: u32,
        chan_dim: u32,
        inner: u32,
        scales: Vec<f32>,
        zero_points: Vec<i32>,
    },

    /// Presumably the inverse of `Quantize` (same fields, `q` → `x`).
    Dequantize {
        q: usize,
        x: usize,
        len: u32,
        chan_axis: u32,
        chan_dim: u32,
        inner: u32,
        scales: Vec<f32>,
        zero_points: Vec<i32>,
    },

    /// Presumably a quantise-then-dequantise round trip at `bits` precision.
    /// NOTE(review): the role of `state_off` is not visible here.
    FakeQuantize {
        x: usize,
        out: usize,
        len: u32,
        chan_axis: u32,
        chan_dim: u32,
        inner: u32,
        bits: u8,
        ste: rlx_ir::op::SteKind,
        scale_mode: rlx_ir::op::ScaleMode,
        state_off: Option<usize>,
    },

    /// Backward of `FakeQuantize` (`dy` → `dx`) using estimator `ste`.
    FakeQuantizeBackward {
        x: usize,
        dy: usize,
        dx: usize,
        len: u32,
        chan_axis: u32,
        chan_dim: u32,
        inner: u32,
        bits: u8,
        ste: rlx_ir::op::SteKind,
    },

    /// Presumably fake quantisation with a scale stored at `scale_off`.
    FakeQuantizeLSQ {
        x: usize,
        scale_off: usize,
        out: usize,
        len: u32,
        chan_axis: u32,
        chan_dim: u32,
        inner: u32,
        bits: u8,
    },

    /// Backward of `FakeQuantizeLSQ` w.r.t. the input (`dx`).
    FakeQuantizeLSQBackwardX {
        x: usize,
        scale_off: usize,
        dy: usize,
        dx: usize,
        len: u32,
        chan_axis: u32,
        chan_dim: u32,
        inner: u32,
        bits: u8,
    },

    /// Backward of `FakeQuantizeLSQ` w.r.t. the scale (`dscale`).
    FakeQuantizeLSQBackwardScale {
        x: usize,
        scale_off: usize,
        dy: usize,
        dscale: usize,
        len: u32,
        chan_axis: u32,
        chan_dim: u32,
        inner: u32,
        bits: u8,
    },

    /// Presumably the ReLU gradient (`x`, `dy` → `dx`).
    ReluBackward {
        x: usize,
        dy: usize,
        dx: usize,
        len: u32,
    },
    /// f64 counterpart of `ReluBackward` (same fields).
    ReluBackwardF64 {
        x: usize,
        dy: usize,
        dx: usize,
        len: u32,
    },

    /// Presumably the gradient of activation `kind` (`x`, `dy` → `dx`).
    ActivationBackward {
        x: usize,
        dy: usize,
        dx: usize,
        len: u32,
        kind: Activation,
    },
    /// f64 counterpart of `ActivationBackward` (same fields).
    ActivationBackwardF64 {
        x: usize,
        dy: usize,
        dx: usize,
        len: u32,
        kind: Activation,
    },

    /// Backward of `LayerNorm` w.r.t. the input (`dx`).
    LayerNormBackwardInput {
        x: usize,
        gamma: usize,
        dy: usize,
        dx: usize,
        rows: u32,
        h: u32,
        eps: f32,
    },

    /// Backward of `LayerNorm` w.r.t. gamma (`dgamma`).
    LayerNormBackwardGamma {
        x: usize,
        dy: usize,
        dgamma: usize,
        rows: u32,
        h: u32,
        eps: f32,
    },

    /// Backward of `RmsNorm` w.r.t. the input (`dx`).
    RmsNormBackwardInput {
        x: usize,
        gamma: usize,
        beta: usize,
        dy: usize,
        dx: usize,
        rows: u32,
        h: u32,
        eps: f32,
    },
    /// Backward of `RmsNorm` w.r.t. gamma (`dgamma`).
    RmsNormBackwardGamma {
        x: usize,
        gamma: usize,
        beta: usize,
        dy: usize,
        dgamma: usize,
        rows: u32,
        h: u32,
        eps: f32,
    },
    /// Backward of `RmsNorm` w.r.t. beta (`dbeta`).
    RmsNormBackwardBeta {
        x: usize,
        gamma: usize,
        beta: usize,
        dy: usize,
        dbeta: usize,
        rows: u32,
        h: u32,
        eps: f32,
    },
    /// Backward of `Rope` (`dy` → `dx`).
    RopeBackward {
        dy: usize,
        cos: usize,
        sin: usize,
        dx: usize,
        batch: u32,
        seq: u32,
        hidden: u32,
        head_dim: u32,
        n_rot: u32,
        cos_len: u32,
    },
    /// Backward of `Cumsum` (`dy` → `dx`).
    CumsumBackward {
        dy: usize,
        dx: usize,
        rows: u32,
        cols: u32,
        exclusive: bool,
    },
    /// Backward of `GatherAxis`, presumably scattering `dy` into `dst`.
    GatherBackward {
        dy: usize,
        indices: usize,
        dst: usize,
        outer: u32,
        axis_dim: u32,
        num_idx: u32,
        trailing: u32,
    },

    /// Backward of `GroupNorm` w.r.t. the input (`dx`).
    GroupNormBackwardInput {
        x: usize,
        gamma: usize,
        beta: usize,
        dy: usize,
        dx: usize,
        n: u32,
        c: u32,
        h: u32,
        w: u32,
        num_groups: u32,
        eps: f32,
    },
    /// Backward of `GroupNorm` w.r.t. gamma (`dgamma`).
    GroupNormBackwardGamma {
        x: usize,
        dy: usize,
        dgamma: usize,
        n: u32,
        c: u32,
        h: u32,
        w: u32,
        num_groups: u32,
        eps: f32,
    },
    /// Backward of `GroupNorm` w.r.t. beta (`dbeta`).
    GroupNormBackwardBeta {
        dy: usize,
        dbeta: usize,
        n: u32,
        c: u32,
        h: u32,
        w: u32,
    },

    /// Presumably the max-pool gradient (`x`, `dy` → `dx`).
    MaxPool2dBackward {
        x: usize,
        dy: usize,
        dx: usize,
        n: u32,
        c: u32,
        h: u32,
        w: u32,
        h_out: u32,
        w_out: u32,
        kh: u32,
        kw: u32,
        sh: u32,
        sw: u32,
        ph: u32,
        pw: u32,
    },

    /// Backward of `Conv2D` w.r.t. the input (`dx`); here `w` is the weight
    /// location and `w_in` the input width.
    Conv2dBackwardInput {
        dy: usize,
        w: usize,
        dx: usize,
        n: u32,
        c_in: u32,
        h: u32,
        w_in: u32,
        c_out: u32,
        h_out: u32,
        w_out: u32,
        kh: u32,
        kw: u32,
        sh: u32,
        sw: u32,
        ph: u32,
        pw: u32,
        dh: u32,
        dw: u32,
        groups: u32,
    },

    /// Backward of `Conv2D` w.r.t. the weight; here `dw` is the
    /// weight-gradient location and `dw_dil` presumably the width dilation.
    Conv2dBackwardWeight {
        x: usize,
        dy: usize,
        dw: usize,
        n: u32,
        c_in: u32,
        h: u32,
        w: u32,
        c_out: u32,
        h_out: u32,
        w_out: u32,
        kh: u32,
        kw: u32,
        sh: u32,
        sw: u32,
        ph: u32,
        pw: u32,
        dh: u32,
        dw_dil: u32,
        groups: u32,
    },

    /// Presumably softmax cross-entropy over `n` rows of `c` classes.
    SoftmaxCrossEntropy {
        logits: usize,
        labels: usize,
        dst: usize,
        n: u32,
        c: u32,
    },

    /// Backward of `SoftmaxCrossEntropy` (`d_loss` → `dlogits`).
    SoftmaxCrossEntropyBackward {
        logits: usize,
        labels: usize,
        d_loss: usize,
        dlogits: usize,
        n: u32,
        c: u32,
    },

    /// Call into a registered `CpuKernel`.
    /// NOTE(review): the meaning of the `(usize, u32, Shape)` triples and of
    /// `attrs` is defined by the kernel interface, which is not visible here.
    CustomOp {
        kernel: Arc<dyn CpuKernel>,
        inputs: Vec<(usize, u32, Shape)>, output: (usize, u32, Shape), attrs: Vec<u8>,
    },

    /// Presumably a one-shot Gaussian-splat render; each `*_off` / `*_len`
    /// pair names one buffer.
    GaussianSplatRender {
        positions_off: usize,
        positions_len: usize,
        scales_off: usize,
        scales_len: usize,
        rotations_off: usize,
        rotations_len: usize,
        opacities_off: usize,
        opacities_len: usize,
        colors_off: usize,
        colors_len: usize,
        sh_coeffs_off: usize,
        sh_coeffs_len: usize,
        meta_off: usize,
        dst_off: usize,
        dst_len: usize,
        width: u32,
        height: u32,
        tile_size: u32,
        radius_scale: f32,
        alpha_cutoff: f32,
        max_splat_steps: u32,
        transmittance_threshold: f32,
        max_list_entries: u32,
    },
    /// Backward of `GaussianSplatRender`, presumably writing gradients into
    /// the `packed_off` / `packed_len` buffer.
    GaussianSplatRenderBackward {
        positions_off: usize,
        positions_len: usize,
        scales_off: usize,
        scales_len: usize,
        rotations_off: usize,
        rotations_len: usize,
        opacities_off: usize,
        opacities_len: usize,
        colors_off: usize,
        colors_len: usize,
        sh_coeffs_off: usize,
        sh_coeffs_len: usize,
        meta_off: usize,
        d_loss_off: usize,
        d_loss_len: usize,
        packed_off: usize,
        packed_len: usize,
        width: u32,
        height: u32,
        tile_size: u32,
        radius_scale: f32,
        alpha_cutoff: f32,
        max_splat_steps: u32,
        transmittance_threshold: f32,
        max_list_entries: u32,
        loss_grad_clip: f32,
        sh_band: u32,
        max_anisotropy: f32,
    },
    /// Presumably the preparation stage of a two-stage splat render, writing
    /// the `prep_off` / `prep_len` buffer.
    GaussianSplatPrepare {
        positions_off: usize,
        positions_len: usize,
        scales_off: usize,
        scales_len: usize,
        rotations_off: usize,
        rotations_len: usize,
        opacities_off: usize,
        opacities_len: usize,
        colors_off: usize,
        colors_len: usize,
        sh_coeffs_off: usize,
        sh_coeffs_len: usize,
        meta_off: usize,
        meta_len: usize,
        prep_off: usize,
        prep_len: usize,
        width: u32,
        height: u32,
        tile_size: u32,
        radius_scale: f32,
        alpha_cutoff: f32,
        max_splat_steps: u32,
        transmittance_threshold: f32,
        max_list_entries: u32,
    },
    /// Presumably the rasterisation stage that consumes the buffer written
    /// by `GaussianSplatPrepare`.
    GaussianSplatRasterize {
        prep_off: usize,
        prep_len: usize,
        meta_off: usize,
        meta_len: usize,
        dst_off: usize,
        dst_len: usize,
        count: usize,
        width: u32,
        height: u32,
        tile_size: u32,
        alpha_cutoff: f32,
        max_splat_steps: u32,
        transmittance_threshold: f32,
        max_list_entries: u32,
    },
    /// Presumably `outer` independent 1-D FFTs of `n_complex` points each;
    /// `inverse` selects the direction and `dtype` the element type.
    Fft1d {
        src: usize,
        dst: usize,
        outer: u32,
        n_complex: u32,
        inverse: bool,
        dtype: rlx_ir::DType,
    },
}
1623
1624#[derive(Clone)]
1627pub struct ThunkSchedule {
1628 pub thunks: Vec<Thunk>,
1629 pub moe_resident: Option<std::sync::Arc<[bool]>>,
1631 pub moe_resident_layers: Option<std::sync::Arc<Vec<std::sync::Arc<[bool]>>>>,
1633 pub moe_topk_capture: Option<std::sync::Arc<crate::moe_topk_capture::MoeTopkCapture>>,
1635 pub mask_threshold: f32,
1637 pub mask_neg_inf: f32,
1638 pub score_skip: f32,
1639 pub compiled_fns: Vec<Arc<dyn Fn(*mut u8) + Send + Sync>>,
1645}
1646
1647impl ThunkSchedule {
1648 pub fn strip_nops(&mut self) {
1649 self.thunks.retain(|t| !matches!(t, Thunk::Nop));
1650 self.compiled_fns.clear();
1653 }
1654}
1655
1656fn node_offset(arena: &Arena, id: NodeId) -> usize {
1658 if arena.has_buffer(id) {
1659 arena.byte_offset(id)
1660 } else {
1661 usize::MAX
1662 }
1663}
1664
1665fn thunk_read_offsets(t: &Thunk) -> Vec<usize> {
1671 match t {
1672 Thunk::Sgemm { a, b, .. } => vec![*a, *b],
1673 Thunk::DenseSolveF64 { a, b, .. } => vec![*a, *b],
1674 Thunk::DenseSolveF32 { a, b, .. } => vec![*a, *b],
1675 Thunk::BatchedDenseSolveF64 { a, b, .. } => vec![*a, *b],
1676 Thunk::BatchedDgemmF64 { a, b, .. } => vec![*a, *b],
1677 Thunk::BatchedSgemm { a, b, .. } => vec![*a, *b],
1678 Thunk::FusedMmBiasAct { a, w, bias, .. } => vec![*a, *w, *bias],
1679 Thunk::BiasAdd { src, bias, .. } => vec![*src, *bias],
1680 Thunk::BinaryFull { lhs, rhs, .. } => vec![*lhs, *rhs],
1681 Thunk::BinaryFullF64 { lhs, rhs, .. } => vec![*lhs, *rhs],
1682 Thunk::BinaryFullC64 { lhs, rhs, .. } => vec![*lhs, *rhs],
1683 Thunk::ComplexNormSqF32 { src, .. } => vec![*src],
1684 Thunk::ComplexNormSqBackwardF32 { z, g, .. } => vec![*z, *g],
1685 Thunk::ConjugateC64 { src, .. } => vec![*src],
1686 Thunk::Scan {
1687 outer_init_off,
1688 xs_inputs,
1689 ..
1690 } => {
1691 let mut v = vec![*outer_init_off];
1692 for (_, outer_xs_off, _) in xs_inputs.iter() {
1693 v.push(*outer_xs_off);
1694 }
1695 v
1696 }
1697 Thunk::ScanBackward {
1698 outer_init_off,
1699 outer_traj_off,
1700 outer_upstream_off,
1701 outer_xs_offs,
1702 ..
1703 } => {
1704 let mut v = vec![*outer_init_off, *outer_traj_off, *outer_upstream_off];
1705 for (off, _) in outer_xs_offs.iter() {
1706 v.push(*off);
1707 }
1708 v
1709 }
1710 Thunk::ScanBackwardXs {
1711 outer_init_off,
1712 outer_traj_off,
1713 outer_upstream_off,
1714 outer_xs_offs,
1715 ..
1716 } => {
1717 let mut v = vec![*outer_init_off, *outer_traj_off, *outer_upstream_off];
1718 for (off, _) in outer_xs_offs.iter() {
1719 v.push(*off);
1720 }
1721 v
1722 }
1723 Thunk::CustomFn { inputs, .. } => {
1724 inputs.iter().map(|(_, outer_off, _)| *outer_off).collect()
1725 }
1726 Thunk::ActivationInPlace { data, .. } => vec![*data],
1727 Thunk::LayerNorm { src, g, b, .. } | Thunk::GroupNorm { src, g, b, .. } => {
1728 vec![*src, *g, *b]
1729 }
1730 Thunk::ResizeNearest2x { src, .. } => vec![*src],
1731 Thunk::AxialRope2d { src, .. } => vec![*src],
1732 Thunk::FusedResidualLN {
1733 x, res, bias, g, b, ..
1734 } => vec![*x, *res, *bias, *g, *b],
1735 Thunk::FusedResidualRmsNorm {
1736 x, res, bias, g, b, ..
1737 } => vec![*x, *res, *bias, *g, *b],
1738 Thunk::RmsNorm { src, g, b, .. } => vec![*src, *g, *b],
1739 Thunk::Softmax { data, .. } => vec![*data],
1740 Thunk::Cumsum { src, .. } => vec![*src],
1741 Thunk::Sample { logits, .. } => vec![*logits],
1742 Thunk::LoraMatMul { x, w, a, b, .. } => vec![*x, *w, *a, *b],
1743 Thunk::DequantMatMul {
1744 x, w_q, scale, zp, ..
1745 } => vec![*x, *w_q, *scale, *zp],
1746 Thunk::DequantMatMulGguf { x, w_q, .. } => vec![*x, *w_q],
1747 Thunk::DequantMatMulInt4 {
1748 x, w_q, scale, zp, ..
1749 } => vec![*x, *w_q, *scale, *zp],
1750 Thunk::DequantMatMulFp8 { x, w_q, scale, .. } => vec![*x, *w_q, *scale],
1751 Thunk::DequantMatMulNvfp4 {
1752 x,
1753 w_q,
1754 scale,
1755 global_scale,
1756 ..
1757 } => vec![*x, *w_q, *scale, *global_scale],
1758 Thunk::Conv2D1x1 { src, weight, .. } => vec![*src, *weight],
1759 Thunk::SelectiveScan {
1760 x, delta, a, b, c, ..
1761 } => vec![*x, *delta, *a, *b, *c],
1762 Thunk::GatedDeltaNet {
1763 q,
1764 k,
1765 v,
1766 g,
1767 beta,
1768 state,
1769 ..
1770 } => {
1771 let mut v = vec![*q, *k, *v, *g, *beta];
1772 if *state != 0 {
1773 v.push(*state);
1774 }
1775 v
1776 }
1777 Thunk::Attention { q, k, v, mask, .. } => vec![*q, *k, *v, *mask],
1778 Thunk::AttentionBackward {
1779 q, k, v, dy, mask, ..
1780 } => {
1781 let mut v = vec![*q, *k, *v, *dy];
1782 if *mask != 0 {
1783 v.push(*mask);
1784 }
1785 v
1786 }
1787 Thunk::Rope { src, cos, sin, .. } => vec![*src, *cos, *sin],
1788 Thunk::FusedAttnBlock {
1789 hidden,
1790 qkv_w,
1791 out_w,
1792 mask,
1793 qkv_b,
1794 out_b,
1795 cos,
1796 sin,
1797 ..
1798 } => vec![*hidden, *qkv_w, *out_w, *mask, *qkv_b, *out_b, *cos, *sin],
1799 Thunk::FusedSwiGLU { src, .. } => vec![*src],
1800 Thunk::Concat { inputs, .. } => inputs.iter().map(|(off, _)| *off).collect(),
1801 Thunk::ConcatF64 { inputs, .. } => inputs.iter().map(|(off, _)| *off).collect(),
1802 Thunk::Narrow { src, .. } => vec![*src],
1803 Thunk::Copy { src, .. } => vec![*src],
1804 Thunk::Gather { table, idx, .. } => vec![*table, *idx],
1805 _ => vec![],
1809 }
1810}
1811
/// Reference matmul `out = x · dequant(w)` against block-quantised i8 weights.
///
/// Indexing used by the loops:
/// * `x` is read as `m`×`k`, row-major (`x[i * k + p]`);
/// * `w_bytes` is read as `k`×`n`, row-major (`w_bytes[p * n + j]`);
/// * `scales` / `zps` are read as `[p / block_size][j]` with row stride `n`;
/// * `out` is written as `m`×`n`, row-major.
///
/// Each weight is dequantised as `(q - z) * s`, where `z` is taken from `zps`
/// only when `asym` is true (so `zps` may be empty for symmetric weights).
///
/// # Panics
/// Panics on out-of-bounds slice access, and on division by zero when
/// `block_size == 0` while `k > 0`. An empty reduction (`k == 0`) never
/// touches `block_size` and simply writes zeros.
#[allow(clippy::too_many_arguments)]
fn dequant_matmul_int8(
    x: &[f32],
    w_bytes: &[i8],
    scales: &[f32],
    zps: &[f32],
    out: &mut [f32],
    m: usize,
    k: usize,
    n: usize,
    block_size: usize,
    asym: bool,
) {
    for i in 0..m {
        for j in 0..n {
            let mut acc = 0f32;
            for p in 0..k {
                // One scale (and optional zero point) per block of
                // `block_size` consecutive rows of the weight, per column.
                let block = p / block_size;
                let s = scales[block * n + j];
                let z = if asym { zps[block * n + j] } else { 0.0 };
                let q = w_bytes[p * n + j] as f32;
                let dequantized = (q - z) * s;
                acc += x[i * k + p] * dequantized;
            }
            out[i * n + j] = acc;
        }
    }
}
1855
/// Reference matmul `out = x · dequant(w)` against block-quantised 4-bit
/// weights packed two per byte.
///
/// The weight at flat index `p * n + j` lives in byte `(p * n + j) / 2`: even
/// flat indices use the low nibble, odd ones the high nibble. The nibble is
/// used as an unsigned value and dequantised as `(nibble - z) * s`, with `s`
/// and `z` looked up at `[p / block_size][j]` (row stride `n`); `z` is read
/// from `zps` only when `asym` is true.
#[allow(clippy::too_many_arguments)]
fn dequant_matmul_int4(
    x: &[f32],
    w_bytes: &[u8],
    scales: &[f32],
    zps: &[f32],
    out: &mut [f32],
    m: usize,
    k: usize,
    n: usize,
    block_size: usize,
    asym: bool,
) {
    for row in 0..m {
        for col in 0..n {
            let mut sum = 0f32;
            for inner in 0..k {
                let qparam_idx = (inner / block_size) * n + col;
                let scale = scales[qparam_idx];
                let zero_point = if asym { zps[qparam_idx] } else { 0.0 };

                let flat = inner * n + col;
                let packed = w_bytes[flat / 2];
                let nibble = if flat % 2 == 0 {
                    packed & 0x0F
                } else {
                    packed >> 4
                };

                let weight = (nibble as f32 - zero_point) * scale;
                sum += x[row * k + inner] * weight;
            }
            out[row * n + col] = sum;
        }
    }
}
1889
/// Decodes one FP8 E4M3 byte (1 sign, 4 exponent, 3 mantissa bits, bias 7)
/// to `f32`.
///
/// This follows the OCP FP8 "E4M3" (a.k.a. E4M3FN) encoding: there are no
/// infinities, the only NaN pattern is exponent `0b1111` with mantissa
/// `0b111`, and the remaining exponent-15 patterns are ordinary finite values
/// (up to ±448). Treating the whole top exponent as Inf/NaN would turn every
/// weight of magnitude 256..=448 into a non-finite value.
///
/// Both zero encodings decode to `0.0`.
fn fp8_e4m3_to_f32(b: u8) -> f32 {
    let sign: f32 = if b & 0x80 != 0 { -1.0 } else { 1.0 };
    let exp = (b >> 3) & 0x0F;
    let mant = b & 0x07;
    match (exp, mant) {
        (0, 0) => 0.0,
        // Subnormal: mant/8 * 2^(1 - 7) == mant * 2^-9.
        (0, _) => sign * (mant as f32) * 2f32.powi(-9),
        // The single NaN pattern (per sign) of E4M3.
        (0x0F, 0x07) => f32::NAN,
        // Normal, including exponent 15 with mantissa 0..=6.
        _ => sign * (1.0 + mant as f32 / 8.0) * 2f32.powi(exp as i32 - 7),
    }
}
1909
/// Decodes one FP8 E5M2 byte (1 sign, 5 exponent, 2 mantissa bits, bias 15)
/// to `f32`.
///
/// Exponent `0b11111` encodes ±infinity (mantissa 0) or NaN (any other
/// mantissa); exponent 0 encodes zero or a subnormal. Both zero encodings
/// decode to `0.0`.
fn fp8_e5m2_to_f32(b: u8) -> f32 {
    let negative = b & 0x80 != 0;
    let sign: f32 = if negative { -1.0 } else { 1.0 };
    let exp = (b >> 2) & 0x1F;
    let mant = b & 0x03;
    match (exp, mant) {
        (0, 0) => 0.0,
        // Subnormal: mant/4 * 2^(1 - 15) == mant * 2^-16.
        (0, _) => sign * (mant as f32) * 2f32.powi(-16),
        (0x1F, 0) => sign * f32::INFINITY,
        (0x1F, _) => f32::NAN,
        _ => sign * (1.0 + mant as f32 / 4.0) * 2f32.powi(exp as i32 - 15),
    }
}
1929
1930#[allow(clippy::too_many_arguments)]
1931fn dequant_matmul_fp8(
1932 x: &[f32],
1933 w_bytes: &[u8],
1934 scales: &[f32],
1935 out: &mut [f32],
1936 m: usize,
1937 k: usize,
1938 n: usize,
1939 e5m2: bool,
1940) {
1941 let dequant = if e5m2 {
1942 fp8_e5m2_to_f32
1943 } else {
1944 fp8_e4m3_to_f32
1945 };
1946 for i in 0..m {
1947 for j in 0..n {
1948 let mut acc = 0f32;
1949 for p in 0..k {
1950 let w = dequant(w_bytes[p * n + j]);
1951 let s = scales.get(j).copied().unwrap_or(1.0);
1952 acc += x[i * k + p] * w * s;
1953 }
1954 out[i * n + j] = acc;
1955 }
1956 }
1957}
1958
1959#[allow(clippy::too_many_arguments)]
1960pub fn dequant_matmul_nvfp4(
1961 x: &[f32],
1962 w_bytes: &[u8],
1963 scale_bytes: &[u8],
1964 global_scale: f32,
1965 out: &mut [f32],
1966 m: usize,
1967 k: usize,
1968 n: usize,
1969) {
1970 use rlx_ir::{NVFP4_GROUP_SIZE, fp4_e2m1_to_f32, fp8_e4m3_scale_to_f32};
1971 let gs = NVFP4_GROUP_SIZE;
1972 for i in 0..m {
1973 for j in 0..n {
1974 let mut acc = 0f32;
1975 for p in 0..k {
1976 let byte_idx = (p * n + j) / 2;
1977 let nibble = if (p * n + j) & 1 == 0 {
1978 w_bytes[byte_idx] & 0x0F
1979 } else {
1980 w_bytes[byte_idx] >> 4
1981 };
1982 let block = p / gs;
1983 let scale = fp8_e4m3_scale_to_f32(scale_bytes[block * n + j]);
1984 let w = fp4_e2m1_to_f32(nibble) * scale * global_scale;
1985 acc += x[i * k + p] * w;
1986 }
1987 out[i * n + j] = acc;
1988 }
1989 }
1990}
1991
1992fn sample_row(
2001 logits: &[f32],
2002 top_k: usize,
2003 top_p: f32,
2004 temperature: f32,
2005 rng: &mut rlx_ir::Philox4x32,
2006) -> usize {
2007 let v = logits.len();
2008 if v == 0 {
2009 return 0;
2010 }
2011 let temp = temperature.max(1e-6);
2012 let mut scaled: Vec<f32> = logits.iter().map(|&x| x / temp).collect();
2014
2015 if top_k > 0 && top_k < v {
2017 let mut indexed: Vec<(usize, f32)> = scaled.iter().copied().enumerate().collect();
2019 indexed.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap());
2022 let cutoff = indexed[top_k - 1].1;
2023 for x in scaled.iter_mut() {
2024 if *x < cutoff {
2025 *x = f32::NEG_INFINITY;
2026 }
2027 }
2028 }
2029
2030 let mut max_l = f32::NEG_INFINITY;
2032 for &x in &scaled {
2033 if x > max_l {
2034 max_l = x;
2035 }
2036 }
2037 let mut sum = 0.0f32;
2038 for x in scaled.iter_mut() {
2039 *x = (*x - max_l).exp();
2040 sum += *x;
2041 }
2042 let inv = 1.0 / sum.max(f32::MIN_POSITIVE);
2043 for x in scaled.iter_mut() {
2044 *x *= inv;
2045 }
2046
2047 if top_p < 1.0 {
2050 let mut indexed: Vec<(usize, f32)> = scaled.iter().copied().enumerate().collect();
2051 indexed.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap());
2052 let mut cum = 0.0f32;
2053 let mut keep = vec![false; v];
2054 for (idx, p) in indexed.iter() {
2055 keep[*idx] = true;
2056 cum += *p;
2057 if cum >= top_p {
2058 break;
2059 }
2060 }
2061 let mut new_sum = 0.0f32;
2062 for (i, x) in scaled.iter_mut().enumerate() {
2063 if !keep[i] {
2064 *x = 0.0;
2065 }
2066 new_sum += *x;
2067 }
2068 let inv = 1.0 / new_sum.max(f32::MIN_POSITIVE);
2069 for x in scaled.iter_mut() {
2070 *x *= inv;
2071 }
2072 }
2073
2074 let r = rng.next_f32();
2076 let mut acc = 0.0f32;
2077 for (i, &p) in scaled.iter().enumerate() {
2078 acc += p;
2079 if r <= acc {
2080 return i;
2081 }
2082 }
2083 v - 1 }
2085
2086#[inline]
2090fn apply_synthetic_mask(
2091 scores: &mut [f32],
2092 q_seq: usize,
2093 k_seq: usize,
2094 kind: rlx_ir::op::MaskKind,
2095) {
2096 let neg = crate::config::RuntimeConfig::global().attn_mask_neg_inf;
2097 let q_offset = k_seq.saturating_sub(q_seq);
2098 match kind {
2099 rlx_ir::op::MaskKind::None | rlx_ir::op::MaskKind::Custom | rlx_ir::op::MaskKind::Bias => {}
2100 rlx_ir::op::MaskKind::Causal => {
2101 for qi in 0..q_seq {
2102 let abs_q = q_offset + qi;
2103 for ki in (abs_q + 1)..k_seq {
2104 scores[qi * k_seq + ki] = neg;
2105 }
2106 }
2107 }
2108 rlx_ir::op::MaskKind::SlidingWindow(w) => {
2109 for qi in 0..q_seq {
2110 let abs_q = q_offset + qi;
2111 let lo = abs_q.saturating_sub(w);
2112 for ki in 0..k_seq {
2113 if ki < lo || ki > abs_q {
2114 scores[qi * k_seq + ki] = neg;
2115 }
2116 }
2117 }
2118 }
2119 }
2120}
2121
2122pub fn compile_thunks(graph: &Graph, arena: &Arena) -> ThunkSchedule {
2124 let mut thunks = Vec::with_capacity(graph.len());
2125
2126 for node in graph.nodes() {
2127 if rlx_opt::is_pure_view(graph, node) {
2131 thunks.push(Thunk::Nop);
2132 continue;
2133 }
2134 let t = match &node.op {
2135 Op::Input { .. } | Op::Param { .. } | Op::Constant { .. } => Thunk::Nop,
2136
2137 Op::FusedMatMulBiasAct { activation } => {
2138 let shape = &node.shape;
2139 let n = shape.dim(shape.rank() - 1).unwrap_static();
2140 let total = shape.num_elements().unwrap();
2141 let m = total / n;
2142 let a_len = get_len(graph, node.inputs[0]);
2143 let k = a_len / m;
2144 Thunk::FusedMmBiasAct {
2145 a: node_offset(arena, node.inputs[0]),
2146 w: node_offset(arena, node.inputs[1]),
2147 bias: node_offset(arena, node.inputs[2]),
2148 c: node_offset(arena, node.id),
2149 m: m as u32,
2150 k: k as u32,
2151 n: n as u32,
2152 act: *activation,
2153 }
2154 }
2155
2156 Op::FusedResidualLN { has_bias, eps } => {
2157 let h = node.shape.dim(node.shape.rank() - 1).unwrap_static();
2158 let total = node.shape.num_elements().unwrap();
2159 let rows = total / h;
2160 let (g_idx, b_idx) = if *has_bias { (3, 4) } else { (2, 3) };
2161 Thunk::FusedResidualLN {
2162 x: node_offset(arena, node.inputs[0]),
2163 res: node_offset(arena, node.inputs[1]),
2164 bias: if *has_bias {
2165 node_offset(arena, node.inputs[2])
2166 } else {
2167 0
2168 },
2169 g: node_offset(arena, node.inputs[g_idx]),
2170 b: node_offset(arena, node.inputs[b_idx]),
2171 out: node_offset(arena, node.id),
2172 rows: rows as u32,
2173 h: h as u32,
2174 eps: *eps,
2175 has_bias: *has_bias,
2176 }
2177 }
2178
2179 Op::FusedResidualRmsNorm { has_bias, eps } => {
2180 let h = node.shape.dim(node.shape.rank() - 1).unwrap_static();
2181 let total = node.shape.num_elements().unwrap();
2182 let rows = total / h;
2183 let (g_idx, b_idx) = if *has_bias { (3, 4) } else { (2, 3) };
2184 Thunk::FusedResidualRmsNorm {
2185 x: node_offset(arena, node.inputs[0]),
2186 res: node_offset(arena, node.inputs[1]),
2187 bias: if *has_bias {
2188 node_offset(arena, node.inputs[2])
2189 } else {
2190 0
2191 },
2192 g: node_offset(arena, node.inputs[g_idx]),
2193 b: node_offset(arena, node.inputs[b_idx]),
2194 out: node_offset(arena, node.id),
2195 rows: rows as u32,
2196 h: h as u32,
2197 eps: *eps,
2198 has_bias: *has_bias,
2199 }
2200 }
2201
2202 Op::MatMul => {
2203 let shape = &node.shape;
2204 let a_shape = &graph.node(node.inputs[0]).shape;
2205 let b_shape = &graph.node(node.inputs[1]).shape;
2206 let n = shape.dim(shape.rank() - 1).unwrap_static();
2207
2208 let batched_3d = a_shape.rank() >= 3
2215 && b_shape.rank() == a_shape.rank()
2216 && shape.rank() == a_shape.rank()
2217 && {
2218 let mut ok = true;
2220 for d in 0..a_shape.rank() - 2 {
2221 if a_shape.dim(d) != b_shape.dim(d) || a_shape.dim(d) != shape.dim(d) {
2222 ok = false;
2223 break;
2224 }
2225 }
2226 ok
2227 };
2228 if batched_3d && shape.dtype() == rlx_ir::DType::F64 {
2229 let r = shape.rank();
2233 let mut batch_prod = 1usize;
2234 for d in 0..r - 2 {
2235 batch_prod *= shape.dim(d).unwrap_static();
2236 }
2237 let m_dim = shape.dim(r - 2).unwrap_static();
2238 let k_dim = a_shape.dim(r - 1).unwrap_static();
2239 debug_assert_eq!(k_dim, b_shape.dim(r - 2).unwrap_static());
2240 Thunk::BatchedDgemmF64 {
2241 a: node_offset(arena, node.inputs[0]),
2242 b: node_offset(arena, node.inputs[1]),
2243 c: node_offset(arena, node.id),
2244 batch: batch_prod as u32,
2245 m: m_dim as u32,
2246 k: k_dim as u32,
2247 n: n as u32,
2248 }
2249 } else if batched_3d && shape.dtype() == rlx_ir::DType::F32 {
2250 let r = shape.rank();
2253 let mut batch_prod = 1usize;
2254 for d in 0..r - 2 {
2255 batch_prod *= shape.dim(d).unwrap_static();
2256 }
2257 let m_dim = shape.dim(r - 2).unwrap_static();
2258 let k_dim = a_shape.dim(r - 1).unwrap_static();
2259 debug_assert_eq!(k_dim, b_shape.dim(r - 2).unwrap_static());
2260 Thunk::BatchedSgemm {
2261 a: node_offset(arena, node.inputs[0]),
2262 b: node_offset(arena, node.inputs[1]),
2263 c: node_offset(arena, node.id),
2264 batch: batch_prod as u32,
2265 m: m_dim as u32,
2266 k: k_dim as u32,
2267 n: n as u32,
2268 }
2269 } else {
2270 let total = shape.num_elements().unwrap();
2271 let m = total / n;
2272 let a_len = get_len(graph, node.inputs[0]);
2273 let k = a_len / m;
2274 match shape.dtype() {
2275 rlx_ir::DType::F64 => Thunk::Dgemm {
2276 a: node_offset(arena, node.inputs[0]),
2277 b: node_offset(arena, node.inputs[1]),
2278 c: node_offset(arena, node.id),
2279 m: m as u32,
2280 k: k as u32,
2281 n: n as u32,
2282 },
2283 _ => Thunk::Sgemm {
2284 a: node_offset(arena, node.inputs[0]),
2285 b: node_offset(arena, node.inputs[1]),
2286 c: node_offset(arena, node.id),
2287 m: m as u32,
2288 k: k as u32,
2289 n: n as u32,
2290 },
2291 }
2292 }
2293 }
2294
2295 Op::Binary(op) => {
2296 let lhs_len = get_len(graph, node.inputs[0]);
2297 let rhs_len = get_len(graph, node.inputs[1]);
2298 let out_len = node.shape.num_elements().unwrap();
2299 if node.shape.dtype() == rlx_ir::DType::C64 {
2300 match op {
2304 BinaryOp::Add | BinaryOp::Sub | BinaryOp::Mul | BinaryOp::Div => {}
2305 BinaryOp::Max | BinaryOp::Min | BinaryOp::Pow => panic!(
2306 "Op::Binary({op:?}) on DType::C64: complex \
2307 max/min/pow have no single natural definition \
2308 — caller should drop to 2N-real-block (see \
2309 spike-ac) and pick a convention there"
2310 ),
2311 }
2312 }
2313 let (out_dims_bcast, bcast_lhs_strides, bcast_rhs_strides) =
2317 if lhs_len == out_len && rhs_len == out_len {
2318 (Vec::new(), Vec::new(), Vec::new())
2319 } else {
2320 let lhs_dims = get_static_dims(graph, node.inputs[0]);
2321 let rhs_dims = get_static_dims(graph, node.inputs[1]);
2322 let out_dims_v = get_static_dims(graph, node.id);
2323 if lhs_dims.is_empty() || rhs_dims.is_empty() || out_dims_v.is_empty() {
2324 (Vec::new(), Vec::new(), Vec::new())
2329 } else {
2330 let ls = broadcast_strides(&lhs_dims, &out_dims_v);
2331 let rs = broadcast_strides(&rhs_dims, &out_dims_v);
2332 let od: Vec<u32> = out_dims_v.iter().map(|x| *x as u32).collect();
2333 (od, ls, rs)
2334 }
2335 };
2336 if node.shape.dtype() == rlx_ir::DType::C64 {
2337 Thunk::BinaryFullC64 {
2338 lhs: node_offset(arena, node.inputs[0]),
2339 rhs: node_offset(arena, node.inputs[1]),
2340 dst: node_offset(arena, node.id),
2341 len: out_len as u32,
2342 lhs_len: lhs_len as u32,
2343 rhs_len: rhs_len as u32,
2344 op: *op,
2345 out_dims_bcast,
2346 bcast_lhs_strides,
2347 bcast_rhs_strides,
2348 }
2349 } else if node.shape.dtype() == rlx_ir::DType::F64 {
2350 Thunk::BinaryFullF64 {
2353 lhs: node_offset(arena, node.inputs[0]),
2354 rhs: node_offset(arena, node.inputs[1]),
2355 dst: node_offset(arena, node.id),
2356 len: out_len as u32,
2357 lhs_len: lhs_len as u32,
2358 rhs_len: rhs_len as u32,
2359 op: *op,
2360 out_dims_bcast,
2361 bcast_lhs_strides,
2362 bcast_rhs_strides,
2363 }
2364 } else if matches!(op, BinaryOp::Add)
2365 && rhs_len < out_len
2366 && out_len % rhs_len == 0
2367 && is_trailing_bias_broadcast(
2368 graph.node(node.inputs[1]).shape.dims(),
2369 graph.node(node.id).shape.dims(),
2370 )
2371 {
2372 Thunk::BiasAdd {
2382 src: node_offset(arena, node.inputs[0]),
2383 bias: node_offset(arena, node.inputs[1]),
2384 dst: node_offset(arena, node.id),
2385 m: (out_len / rhs_len) as u32,
2386 n: rhs_len as u32,
2387 }
2388 } else {
2389 let lhs_len = get_len(graph, node.inputs[0]);
2390 Thunk::BinaryFull {
2391 lhs: node_offset(arena, node.inputs[0]),
2392 rhs: node_offset(arena, node.inputs[1]),
2393 dst: node_offset(arena, node.id),
2394 len: out_len as u32,
2395 lhs_len: lhs_len as u32,
2396 rhs_len: rhs_len as u32,
2397 op: *op,
2398 out_dims_bcast,
2399 bcast_lhs_strides,
2400 bcast_rhs_strides,
2401 }
2402 }
2403 }
2404
2405 Op::Activation(act) => {
2406 let len = node.shape.num_elements().unwrap();
2407 let in_off = node_offset(arena, node.inputs[0]);
2408 let out_off = node_offset(arena, node.id);
2409 if node.shape.dtype() == rlx_ir::DType::C64 {
2410 match act {
2415 Activation::Neg | Activation::Exp | Activation::Log | Activation::Sqrt => {}
2416 other => panic!(
2417 "Op::Activation({other:?}) on DType::C64: no \
2418 natural complex extension — supported on C64: \
2419 Neg, Exp, Log, Sqrt"
2420 ),
2421 }
2422 Thunk::ActivationC64 {
2423 src: in_off,
2424 dst: out_off,
2425 len: len as u32,
2426 kind: *act,
2427 }
2428 } else if node.shape.dtype() == rlx_ir::DType::F64 {
2429 Thunk::ActivationF64 {
2430 src: in_off,
2431 dst: out_off,
2432 len: len as u32,
2433 kind: *act,
2434 }
2435 } else if in_off == out_off {
2436 Thunk::ActivationInPlace {
2440 data: out_off,
2441 len: len as u32,
2442 act: *act,
2443 }
2444 } else {
2445 thunks.push(Thunk::Copy {
2449 src: in_off,
2450 dst: out_off,
2451 len: len as u32,
2452 });
2453 Thunk::ActivationInPlace {
2454 data: out_off,
2455 len: len as u32,
2456 act: *act,
2457 }
2458 }
2459 }
2460
2461 Op::Gather { axis } if *axis == 0 => {
2462 let table_shape = &graph.node(node.inputs[0]).shape;
2463 let table_total = table_shape.num_elements().unwrap();
2464 let trailing: usize = (1..table_shape.rank())
2465 .map(|i| table_shape.dim(i).unwrap_static())
2466 .product();
2467 let idx_len = get_len(graph, node.inputs[1]);
2468 Thunk::Gather {
2469 table: node_offset(arena, node.inputs[0]),
2470 table_len: table_total as u32,
2471 idx: node_offset(arena, node.inputs[1]),
2472 dst: node_offset(arena, node.id),
2473 num_idx: idx_len as u32,
2474 trailing: trailing as u32,
2475 }
2476 }
2477
2478 Op::Gather { axis } => {
2479 let table_shape = &graph.node(node.inputs[0]).shape;
2481 let rank = table_shape.rank();
2482 let outer: usize = (0..*axis)
2483 .map(|i| table_shape.dim(i).unwrap_static())
2484 .product::<usize>()
2485 .max(1);
2486 let trailing: usize = (*axis + 1..rank)
2487 .map(|i| table_shape.dim(i).unwrap_static())
2488 .product::<usize>()
2489 .max(1);
2490 let axis_dim = table_shape.dim(*axis).unwrap_static();
2491 let idx_len = get_len(graph, node.inputs[1]);
2492 Thunk::GatherAxis {
2493 table: node_offset(arena, node.inputs[0]),
2494 idx: node_offset(arena, node.inputs[1]),
2495 dst: node_offset(arena, node.id),
2496 outer: outer as u32,
2497 axis_dim: axis_dim as u32,
2498 num_idx: idx_len as u32,
2499 trailing: trailing as u32,
2500 }
2501 }
2502
2503 Op::Narrow { axis, start, len } => {
2504 let in_shape = &graph.node(node.inputs[0]).shape;
2505 let elem_bytes = in_shape.dtype().size_bytes() as u8;
2506 let rank = in_shape.rank();
2507 let outer: usize = (0..*axis)
2508 .map(|i| in_shape.dim(i).unwrap_static())
2509 .product::<usize>()
2510 .max(1);
2511 let inner: usize = (*axis + 1..rank)
2512 .map(|i| in_shape.dim(i).unwrap_static())
2513 .product::<usize>()
2514 .max(1);
2515 let in_axis = in_shape.dim(*axis).unwrap_static();
2516 let src_byte_offset =
2517 node_offset(arena, node.inputs[0]) + start * inner * elem_bytes as usize;
2518 Thunk::Narrow {
2519 src: src_byte_offset,
2520 dst: node_offset(arena, node.id),
2521 outer: outer as u32,
2522 src_stride: (in_axis * inner) as u32, dst_stride: (*len * inner) as u32, inner: (*len * inner) as u32, elem_bytes,
2526 }
2527 }
2528
2529 Op::Reshape { .. } | Op::Cast { .. } => {
2530 let len = node.shape.num_elements().unwrap();
2532 let src = node_offset(arena, node.inputs[0]);
2533 let dst = node_offset(arena, node.id);
2534 match node.shape.dtype() {
2535 rlx_ir::DType::F64 => Thunk::CopyF64 {
2536 src,
2537 dst,
2538 len: len as u32,
2539 },
2540 _ => Thunk::Copy {
2541 src,
2542 dst,
2543 len: len as u32,
2544 },
2545 }
2546 }
2547
2548 Op::Quantize {
2549 axis,
2550 scales,
2551 zero_points,
2552 } => {
2553 let (chan_axis, chan_dim, inner) = quant_layout(&node.shape, *axis);
2554 Thunk::Quantize {
2555 x: node_offset(arena, node.inputs[0]),
2556 q: node_offset(arena, node.id),
2557 len: node.shape.num_elements().unwrap() as u32,
2558 chan_axis: chan_axis as u32,
2559 chan_dim: chan_dim as u32,
2560 inner: inner as u32,
2561 scales: scales.clone(),
2562 zero_points: zero_points.clone(),
2563 }
2564 }
2565
2566 Op::FakeQuantize {
2567 bits,
2568 axis,
2569 ste,
2570 scale_mode,
2571 } => {
2572 let (chan_axis, chan_dim, inner) = quant_layout(&node.shape, *axis);
2573 let state_off = match scale_mode {
2574 rlx_ir::op::ScaleMode::PerBatch => None,
2575 rlx_ir::op::ScaleMode::EMA { .. } | rlx_ir::op::ScaleMode::Fixed => {
2576 debug_assert_eq!(
2578 node.inputs.len(),
2579 2,
2580 "EMA/Fixed FakeQuantize needs a state input"
2581 );
2582 Some(node_offset(arena, node.inputs[1]))
2583 }
2584 };
2585 Thunk::FakeQuantize {
2586 x: node_offset(arena, node.inputs[0]),
2587 out: node_offset(arena, node.id),
2588 len: node.shape.num_elements().unwrap() as u32,
2589 chan_axis: chan_axis as u32,
2590 chan_dim: chan_dim as u32,
2591 inner: inner as u32,
2592 bits: *bits,
2593 ste: *ste,
2594 scale_mode: *scale_mode,
2595 state_off,
2596 }
2597 }
2598
2599 Op::FakeQuantizeLSQ { bits, axis } => {
2600 let (chan_axis, chan_dim, inner) = quant_layout(&node.shape, *axis);
2601 Thunk::FakeQuantizeLSQ {
2602 x: node_offset(arena, node.inputs[0]),
2603 scale_off: node_offset(arena, node.inputs[1]),
2604 out: node_offset(arena, node.id),
2605 len: node.shape.num_elements().unwrap() as u32,
2606 chan_axis: chan_axis as u32,
2607 chan_dim: chan_dim as u32,
2608 inner: inner as u32,
2609 bits: *bits,
2610 }
2611 }
2612
2613 Op::FakeQuantizeLSQBackwardX { bits, axis } => {
2614 let (chan_axis, chan_dim, inner) = quant_layout(&node.shape, *axis);
2615 Thunk::FakeQuantizeLSQBackwardX {
2616 x: node_offset(arena, node.inputs[0]),
2617 scale_off: node_offset(arena, node.inputs[1]),
2618 dy: node_offset(arena, node.inputs[2]),
2619 dx: node_offset(arena, node.id),
2620 len: node.shape.num_elements().unwrap() as u32,
2621 chan_axis: chan_axis as u32,
2622 chan_dim: chan_dim as u32,
2623 inner: inner as u32,
2624 bits: *bits,
2625 }
2626 }
2627
2628 Op::FakeQuantizeLSQBackwardScale { bits, axis } => {
2629 let in_shape = &graph.node(node.inputs[0]).shape;
2632 let (chan_axis, chan_dim, inner) = quant_layout(in_shape, *axis);
2633 Thunk::FakeQuantizeLSQBackwardScale {
2634 x: node_offset(arena, node.inputs[0]),
2635 scale_off: node_offset(arena, node.inputs[1]),
2636 dy: node_offset(arena, node.inputs[2]),
2637 dscale: node_offset(arena, node.id),
2638 len: in_shape.num_elements().unwrap() as u32,
2639 chan_axis: chan_axis as u32,
2640 chan_dim: chan_dim as u32,
2641 inner: inner as u32,
2642 bits: *bits,
2643 }
2644 }
2645
2646 Op::FakeQuantizeBackward { bits, axis, ste } => {
2647 let (chan_axis, chan_dim, inner) = quant_layout(&node.shape, *axis);
2648 Thunk::FakeQuantizeBackward {
2649 x: node_offset(arena, node.inputs[0]),
2650 dy: node_offset(arena, node.inputs[1]),
2651 dx: node_offset(arena, node.id),
2652 len: node.shape.num_elements().unwrap() as u32,
2653 chan_axis: chan_axis as u32,
2654 chan_dim: chan_dim as u32,
2655 inner: inner as u32,
2656 bits: *bits,
2657 ste: *ste,
2658 }
2659 }
2660
2661 Op::Dequantize {
2662 axis,
2663 scales,
2664 zero_points,
2665 } => {
2666 let (chan_axis, chan_dim, inner) = quant_layout(&node.shape, *axis);
2667 Thunk::Dequantize {
2668 q: node_offset(arena, node.inputs[0]),
2669 x: node_offset(arena, node.id),
2670 len: node.shape.num_elements().unwrap() as u32,
2671 chan_axis: chan_axis as u32,
2672 chan_dim: chan_dim as u32,
2673 inner: inner as u32,
2674 scales: scales.clone(),
2675 zero_points: zero_points.clone(),
2676 }
2677 }
2678
2679 Op::Expand { .. } => {
2680 let in_shape = &graph.node(node.inputs[0]).shape;
2685 let out_shape = &node.shape;
2686 let in_rank = in_shape.rank();
2687 let out_rank = out_shape.rank();
2688 let pad = out_rank.saturating_sub(in_rank);
2690 let in_dims: Vec<usize> = (0..out_rank)
2691 .map(|i| {
2692 if i < pad {
2693 1
2694 } else {
2695 in_shape.dim(i - pad).unwrap_static()
2696 }
2697 })
2698 .collect();
2699 let mut in_strides_full = vec![1usize; out_rank];
2701 for d in (0..out_rank.saturating_sub(1)).rev() {
2702 in_strides_full[d] = in_strides_full[d + 1] * in_dims[d + 1];
2703 }
2704 let out_dims: Vec<u32> = (0..out_rank)
2705 .map(|i| out_shape.dim(i).unwrap_static() as u32)
2706 .collect();
2707 let in_strides: Vec<u32> = (0..out_rank)
2709 .map(|i| {
2710 if in_dims[i] == 1 && (out_dims[i] as usize) > 1 {
2711 0
2712 } else {
2713 in_strides_full[i] as u32
2714 }
2715 })
2716 .collect();
2717 let in_total = in_dims.iter().product::<usize>() as u32;
2718 let src = node_offset(arena, node.inputs[0]);
2719 let dst = node_offset(arena, node.id);
2720 match node.shape.dtype() {
2721 rlx_ir::DType::F64 => Thunk::TransposeF64 {
2722 src,
2723 dst,
2724 in_total,
2725 out_dims,
2726 in_strides,
2727 },
2728 _ => Thunk::Transpose {
2729 src,
2730 dst,
2731 in_total,
2732 out_dims,
2733 in_strides,
2734 },
2735 }
2736 }
2737
2738 Op::RmsNorm { eps, .. } => {
2739 let h = node.shape.dim(node.shape.rank() - 1).unwrap_static();
2740 let total = node.shape.num_elements().unwrap();
2741 Thunk::RmsNorm {
2742 src: node_offset(arena, node.inputs[0]),
2743 g: node_offset(arena, node.inputs[1]),
2744 b: node_offset(arena, node.inputs[2]),
2745 dst: node_offset(arena, node.id),
2746 rows: (total / h) as u32,
2747 h: h as u32,
2748 eps: *eps,
2749 }
2750 }
2751
2752 Op::LayerNorm { eps, .. } => {
2753 let h = node.shape.dim(node.shape.rank() - 1).unwrap_static();
2754 let total = node.shape.num_elements().unwrap();
2755 Thunk::LayerNorm {
2756 src: node_offset(arena, node.inputs[0]),
2757 g: node_offset(arena, node.inputs[1]),
2758 b: node_offset(arena, node.inputs[2]),
2759 dst: node_offset(arena, node.id),
2760 rows: (total / h) as u32,
2761 h: h as u32,
2762 eps: *eps,
2763 }
2764 }
2765
2766 Op::GroupNorm { num_groups, eps } => {
2767 let in_shape = &graph.node(node.inputs[0]).shape;
2768 Thunk::GroupNorm {
2769 src: node_offset(arena, node.inputs[0]),
2770 g: node_offset(arena, node.inputs[1]),
2771 b: node_offset(arena, node.inputs[2]),
2772 dst: node_offset(arena, node.id),
2773 n: in_shape.dim(0).unwrap_static() as u32,
2774 c: in_shape.dim(1).unwrap_static() as u32,
2775 h: in_shape.dim(2).unwrap_static() as u32,
2776 w: in_shape.dim(3).unwrap_static() as u32,
2777 num_groups: *num_groups as u32,
2778 eps: *eps,
2779 }
2780 }
2781
2782 Op::LayerNorm2d { eps } => {
2783 let in_shape = &graph.node(node.inputs[0]).shape;
2784 Thunk::LayerNorm2d {
2785 src: node_offset(arena, node.inputs[0]),
2786 g: node_offset(arena, node.inputs[1]),
2787 b: node_offset(arena, node.inputs[2]),
2788 dst: node_offset(arena, node.id),
2789 n: in_shape.dim(0).unwrap_static() as u32,
2790 c: in_shape.dim(1).unwrap_static() as u32,
2791 h: in_shape.dim(2).unwrap_static() as u32,
2792 w: in_shape.dim(3).unwrap_static() as u32,
2793 eps: *eps,
2794 }
2795 }
2796
2797 Op::ConvTranspose2d {
2798 kernel_size,
2799 stride,
2800 padding,
2801 dilation,
2802 output_padding: _,
2803 groups,
2804 } => {
2805 let in_shape = &graph.node(node.inputs[0]).shape;
2806 let out_shape = &node.shape;
2807 Thunk::ConvTranspose2d {
2808 src: node_offset(arena, node.inputs[0]),
2809 weight: node_offset(arena, node.inputs[1]),
2810 dst: node_offset(arena, node.id),
2811 n: in_shape.dim(0).unwrap_static() as u32,
2812 c_in: in_shape.dim(1).unwrap_static() as u32,
2813 h: in_shape.dim(2).unwrap_static() as u32,
2814 w_in: in_shape.dim(3).unwrap_static() as u32,
2815 c_out: out_shape.dim(1).unwrap_static() as u32,
2816 h_out: out_shape.dim(2).unwrap_static() as u32,
2817 w_out: out_shape.dim(3).unwrap_static() as u32,
2818 kh: kernel_size[0] as u32,
2819 kw: kernel_size[1] as u32,
2820 sh: stride.first().copied().unwrap_or(1) as u32,
2821 sw: stride.get(1).copied().unwrap_or(1) as u32,
2822 ph: padding.first().copied().unwrap_or(0) as u32,
2823 pw: padding.get(1).copied().unwrap_or(0) as u32,
2824 dh: dilation.first().copied().unwrap_or(1) as u32,
2825 dw: dilation.get(1).copied().unwrap_or(1) as u32,
2826 groups: *groups as u32,
2827 }
2828 }
2829
2830 Op::ResizeNearest2x => {
2831 let in_shape = &graph.node(node.inputs[0]).shape;
2832 Thunk::ResizeNearest2x {
2833 src: node_offset(arena, node.inputs[0]),
2834 dst: node_offset(arena, node.id),
2835 n: in_shape.dim(0).unwrap_static() as u32,
2836 c: in_shape.dim(1).unwrap_static() as u32,
2837 h: in_shape.dim(2).unwrap_static() as u32,
2838 w: in_shape.dim(3).unwrap_static() as u32,
2839 }
2840 }
2841
2842 Op::AxialRope2d {
2843 end_x,
2844 end_y,
2845 head_dim,
2846 num_heads,
2847 theta,
2848 repeat_factor,
2849 } => {
2850 let in_shape = &graph.node(node.inputs[0]).shape;
2851 let batch = in_shape.dim(0).unwrap_static() as u32;
2852 let seq = in_shape.dim(1).unwrap_static() as u32;
2853 let hidden = in_shape.dim(2).unwrap_static() as u32;
2854 Thunk::AxialRope2d {
2855 src: node_offset(arena, node.inputs[0]),
2856 dst: node_offset(arena, node.id),
2857 batch,
2858 seq,
2859 hidden,
2860 end_x: *end_x as u32,
2861 end_y: *end_y as u32,
2862 head_dim: *head_dim as u32,
2863 num_heads: *num_heads as u32,
2864 theta: *theta,
2865 repeat_factor: *repeat_factor as u32,
2866 }
2867 }
2868
2869 Op::Softmax { axis } => {
2870 let rank = node.shape.rank();
2871 let ax = if *axis < 0 {
2872 (rank as i32 + axis) as usize
2873 } else {
2874 *axis as usize
2875 };
2876 let cols = node.shape.dim(ax).unwrap_static();
2877 let total = node.shape.num_elements().unwrap();
2878 let in_off = node_offset(arena, node.inputs[0]);
2879 let out_off = node_offset(arena, node.id);
2880 if in_off != out_off {
2886 thunks.push(Thunk::Copy {
2887 src: in_off,
2888 dst: out_off,
2889 len: total as u32,
2890 });
2891 }
2892 Thunk::Softmax {
2893 data: out_off,
2894 rows: (total / cols) as u32,
2895 cols: cols as u32,
2896 }
2897 }
2898
2899 Op::SelectiveScan { state_size } => {
2900 let in_shape = &graph.node(node.inputs[0]).shape;
2901 let (batch, seq, hidden) = (
2902 in_shape.dim(0).unwrap_static(),
2903 in_shape.dim(1).unwrap_static(),
2904 in_shape.dim(2).unwrap_static(),
2905 );
2906 Thunk::SelectiveScan {
2907 x: node_offset(arena, node.inputs[0]),
2908 delta: node_offset(arena, node.inputs[1]),
2909 a: node_offset(arena, node.inputs[2]),
2910 b: node_offset(arena, node.inputs[3]),
2911 c: node_offset(arena, node.inputs[4]),
2912 dst: node_offset(arena, node.id),
2913 batch: batch as u32,
2914 seq: seq as u32,
2915 hidden: hidden as u32,
2916 state_size: *state_size as u32,
2917 }
2918 }
2919
2920 Op::GatedDeltaNet {
2921 state_size,
2922 carry_state,
2923 } => {
2924 let q_shape = &graph.node(node.inputs[0]).shape;
2925 let (batch, seq, heads) = (
2926 q_shape.dim(0).unwrap_static(),
2927 q_shape.dim(1).unwrap_static(),
2928 q_shape.dim(2).unwrap_static(),
2929 );
2930 let state_off = if *carry_state {
2931 node_offset(arena, node.inputs[5])
2932 } else {
2933 0
2934 };
2935 Thunk::GatedDeltaNet {
2936 q: node_offset(arena, node.inputs[0]),
2937 k: node_offset(arena, node.inputs[1]),
2938 v: node_offset(arena, node.inputs[2]),
2939 g: node_offset(arena, node.inputs[3]),
2940 beta: node_offset(arena, node.inputs[4]),
2941 state: state_off,
2942 dst: node_offset(arena, node.id),
2943 batch: batch as u32,
2944 seq: seq as u32,
2945 heads: heads as u32,
2946 state_size: *state_size as u32,
2947 }
2948 }
2949
2950 Op::QMatMul {
2951 x_zp,
2952 w_zp,
2953 out_zp,
2954 mult,
2955 } => {
2956 let x_shape = &graph.node(node.inputs[0]).shape;
2957 let w_shape = &graph.node(node.inputs[1]).shape;
2958 let m = x_shape.dim(0).unwrap_static();
2959 let k = x_shape.dim(1).unwrap_static();
2960 let n = w_shape.dim(1).unwrap_static();
2961 Thunk::QMatMul {
2962 x: node_offset(arena, node.inputs[0]),
2963 w: node_offset(arena, node.inputs[1]),
2964 bias: node_offset(arena, node.inputs[2]),
2965 out: node_offset(arena, node.id),
2966 m: m as u32,
2967 k: k as u32,
2968 n: n as u32,
2969 x_zp: *x_zp,
2970 w_zp: *w_zp,
2971 out_zp: *out_zp,
2972 mult: *mult,
2973 }
2974 }
2975
2976 Op::QConv2d {
2977 kernel_size,
2978 stride,
2979 padding,
2980 dilation,
2981 groups,
2982 x_zp,
2983 w_zp,
2984 out_zp,
2985 mult,
2986 } => {
2987 let in_shape = &graph.node(node.inputs[0]).shape;
2988 let w_shape = &graph.node(node.inputs[1]).shape;
2989 let out_shape = &node.shape;
2990 if kernel_size.len() == 2
2991 && in_shape.rank() == 4
2992 && w_shape.rank() == 4
2993 && out_shape.rank() == 4
2994 {
2995 Thunk::QConv2d {
2996 x: node_offset(arena, node.inputs[0]),
2997 w: node_offset(arena, node.inputs[1]),
2998 bias: node_offset(arena, node.inputs[2]),
2999 out: node_offset(arena, node.id),
3000 n: in_shape.dim(0).unwrap_static() as u32,
3001 c_in: in_shape.dim(1).unwrap_static() as u32,
3002 h: in_shape.dim(2).unwrap_static() as u32,
3003 w_in: in_shape.dim(3).unwrap_static() as u32,
3004 c_out: out_shape.dim(1).unwrap_static() as u32,
3005 h_out: out_shape.dim(2).unwrap_static() as u32,
3006 w_out: out_shape.dim(3).unwrap_static() as u32,
3007 kh: kernel_size[0] as u32,
3008 kw: kernel_size[1] as u32,
3009 sh: stride.first().copied().unwrap_or(1) as u32,
3010 sw: stride.get(1).copied().unwrap_or(1) as u32,
3011 ph: padding.first().copied().unwrap_or(0) as u32,
3012 pw: padding.get(1).copied().unwrap_or(0) as u32,
3013 dh: dilation.first().copied().unwrap_or(1) as u32,
3014 dw: dilation.get(1).copied().unwrap_or(1) as u32,
3015 groups: *groups as u32,
3016 x_zp: *x_zp,
3017 w_zp: *w_zp,
3018 out_zp: *out_zp,
3019 mult: *mult,
3020 }
3021 } else {
3022 Thunk::Nop
3023 }
3024 }
3025
3026 Op::DequantMatMul { scheme } => {
3027 use rlx_ir::quant::QuantScheme;
3028 let n = node.shape.dim(node.shape.rank() - 1).unwrap_static();
3029 let total = node.shape.num_elements().unwrap();
3030 let m = total / n.max(1);
3031 let x_total = graph.node(node.inputs[0]).shape.num_elements().unwrap();
3032 let k = x_total / m.max(1);
3033 if scheme.is_gguf() {
3034 Thunk::DequantMatMulGguf {
3035 x: node_offset(arena, node.inputs[0]),
3036 w_q: node_offset(arena, node.inputs[1]),
3037 dst: node_offset(arena, node.id),
3038 m: m as u32,
3039 k: k as u32,
3040 n: n as u32,
3041 scheme: *scheme,
3042 }
3043 } else {
3044 match scheme {
3045 QuantScheme::Nvfp4Block => Thunk::DequantMatMulNvfp4 {
3046 x: node_offset(arena, node.inputs[0]),
3047 w_q: node_offset(arena, node.inputs[1]),
3048 scale: node_offset(arena, node.inputs[2]),
3049 global_scale: node_offset(arena, node.inputs[3]),
3050 dst: node_offset(arena, node.id),
3051 m: m as u32,
3052 k: k as u32,
3053 n: n as u32,
3054 },
3055 QuantScheme::Int4Block { block_size } => Thunk::DequantMatMulInt4 {
3056 x: node_offset(arena, node.inputs[0]),
3057 w_q: node_offset(arena, node.inputs[1]),
3058 scale: node_offset(arena, node.inputs[2]),
3059 zp: node_offset(arena, node.inputs[3]),
3060 dst: node_offset(arena, node.id),
3061 m: m as u32,
3062 k: k as u32,
3063 n: n as u32,
3064 block_size: *block_size,
3065 is_asymmetric: false,
3066 },
3067 QuantScheme::Fp8E4m3 => Thunk::DequantMatMulFp8 {
3068 x: node_offset(arena, node.inputs[0]),
3069 w_q: node_offset(arena, node.inputs[1]),
3070 scale: node_offset(arena, node.inputs[2]),
3071 dst: node_offset(arena, node.id),
3072 m: m as u32,
3073 k: k as u32,
3074 n: n as u32,
3075 e5m2: false,
3076 },
3077 QuantScheme::Fp8E5m2 => Thunk::DequantMatMulFp8 {
3078 x: node_offset(arena, node.inputs[0]),
3079 w_q: node_offset(arena, node.inputs[1]),
3080 scale: node_offset(arena, node.inputs[2]),
3081 dst: node_offset(arena, node.id),
3082 m: m as u32,
3083 k: k as u32,
3084 n: n as u32,
3085 e5m2: true,
3086 },
3087 QuantScheme::Int8Block { block_size } => Thunk::DequantMatMul {
3088 x: node_offset(arena, node.inputs[0]),
3089 w_q: node_offset(arena, node.inputs[1]),
3090 scale: node_offset(arena, node.inputs[2]),
3091 zp: node_offset(arena, node.inputs[3]),
3092 dst: node_offset(arena, node.id),
3093 m: m as u32,
3094 k: k as u32,
3095 n: n as u32,
3096 block_size: *block_size,
3097 is_asymmetric: false,
3098 },
3099 QuantScheme::Int8BlockAsym { block_size } => Thunk::DequantMatMul {
3100 x: node_offset(arena, node.inputs[0]),
3101 w_q: node_offset(arena, node.inputs[1]),
3102 scale: node_offset(arena, node.inputs[2]),
3103 zp: node_offset(arena, node.inputs[3]),
3104 dst: node_offset(arena, node.id),
3105 m: m as u32,
3106 k: k as u32,
3107 n: n as u32,
3108 block_size: *block_size,
3109 is_asymmetric: true,
3110 },
3111 other => panic!(
3112 "DequantMatMul on CPU supports Int8/Int4/FP8/NVFP4 legacy or GGUF schemes; got {other}"
3113 ),
3114 }
3115 }
3116 }
3117
3118 Op::LoraMatMul { scale } => {
3119 let n = node.shape.dim(node.shape.rank() - 1).unwrap_static();
3121 let total = node.shape.num_elements().unwrap();
3122 let m = total / n.max(1);
3123 let x_total = graph.node(node.inputs[0]).shape.num_elements().unwrap();
3124 let k = x_total / m.max(1);
3125 let a_total = graph.node(node.inputs[2]).shape.num_elements().unwrap();
3126 let r = a_total / k.max(1);
3127 Thunk::LoraMatMul {
3128 x: node_offset(arena, node.inputs[0]),
3129 w: node_offset(arena, node.inputs[1]),
3130 a: node_offset(arena, node.inputs[2]),
3131 b: node_offset(arena, node.inputs[3]),
3132 dst: node_offset(arena, node.id),
3133 m: m as u32,
3134 k: k as u32,
3135 n: n as u32,
3136 r: r as u32,
3137 scale: *scale,
3138 }
3139 }
3140
3141 Op::Sample {
3142 top_k,
3143 top_p,
3144 temperature,
3145 seed,
3146 } => {
3147 let in_shape = &graph.node(node.inputs[0]).shape;
3148 let (batch, vocab) = if in_shape.rank() >= 2 {
3150 (
3151 in_shape.dim(0).unwrap_static(),
3152 in_shape.dim(in_shape.rank() - 1).unwrap_static(),
3153 )
3154 } else {
3155 (1, in_shape.num_elements().unwrap_or(0))
3156 };
3157 Thunk::Sample {
3158 logits: node_offset(arena, node.inputs[0]),
3159 dst: node_offset(arena, node.id),
3160 batch: batch as u32,
3161 vocab: vocab as u32,
3162 top_k: *top_k as u32,
3163 top_p: *top_p,
3164 temperature: *temperature,
3165 seed: *seed,
3166 }
3167 }
3168
3169 Op::Cumsum { axis, exclusive } => {
3170 let rank = node.shape.rank();
3175 let ax = if *axis < 0 {
3176 (rank as i32 + axis) as usize
3177 } else {
3178 *axis as usize
3179 };
3180 assert_eq!(
3181 ax,
3182 rank - 1,
3183 "Cumsum only supports the last axis on CPU today"
3184 );
3185 let cols = node.shape.dim(ax).unwrap_static();
3186 let total = node.shape.num_elements().unwrap();
3187 Thunk::Cumsum {
3188 src: node_offset(arena, node.inputs[0]),
3189 dst: node_offset(arena, node.id),
3190 rows: (total / cols) as u32,
3191 cols: cols as u32,
3192 exclusive: *exclusive,
3193 }
3194 }
3195
3196 Op::Attention {
3197 num_heads,
3198 head_dim,
3199 mask_kind,
3200 score_scale: _,
3201 attn_logit_softcap: _,
3202 } => {
3203 let q_shape = &graph.node(node.inputs[0]).shape;
3209 let k_shape = &graph.node(node.inputs[1]).shape;
3210 let rank = q_shape.rank();
3211 let (batch, seq, kv_seq, bhsd) = if rank == 4 {
3212 let d1 = q_shape.dim(1).unwrap_static();
3213 let d2 = q_shape.dim(2).unwrap_static();
3214 if d1 == *num_heads {
3215 (
3217 q_shape.dim(0).unwrap_static(),
3218 d2,
3219 k_shape.dim(2).unwrap_static(),
3220 true,
3221 )
3222 } else {
3223 (
3225 q_shape.dim(0).unwrap_static(),
3226 d1,
3227 k_shape.dim(1).unwrap_static(),
3228 false,
3229 )
3230 }
3231 } else if rank >= 3 {
3232 (
3233 q_shape.dim(0).unwrap_static(),
3234 q_shape.dim(1).unwrap_static(),
3235 k_shape.dim(1).unwrap_static(),
3236 false,
3237 )
3238 } else {
3239 (
3240 1,
3241 q_shape.dim(0).unwrap_static(),
3242 k_shape.dim(0).unwrap_static(),
3243 false,
3244 )
3245 };
3246 let mask_off = if matches!(
3247 mask_kind,
3248 rlx_ir::op::MaskKind::Custom | rlx_ir::op::MaskKind::Bias
3249 ) {
3250 node_offset(arena, node.inputs[3])
3251 } else {
3252 0
3253 };
3254 let hs = (*num_heads * *head_dim) as u32;
3255 Thunk::Attention {
3256 q: node_offset(arena, node.inputs[0]),
3257 k: node_offset(arena, node.inputs[1]),
3258 v: node_offset(arena, node.inputs[2]),
3259 mask: mask_off,
3260 out: node_offset(arena, node.id),
3261 batch: batch as u32,
3262 seq: seq as u32,
3263 kv_seq: kv_seq as u32,
3264 heads: *num_heads as u32,
3265 head_dim: *head_dim as u32,
3266 mask_kind: *mask_kind,
3267 q_row_stride: hs,
3271 k_row_stride: hs,
3272 v_row_stride: hs,
3273 bhsd,
3274 }
3275 }
3276
3277 Op::AttentionBackward {
3278 num_heads,
3279 head_dim,
3280 mask_kind,
3281 wrt,
3282 } => {
3283 let q_shape = &graph.node(node.inputs[0]).shape;
3284 let k_shape = &graph.node(node.inputs[1]).shape;
3285 let rank = q_shape.rank();
3286 let (batch, seq, kv_seq, bhsd) = if rank == 4 {
3287 let d1 = q_shape.dim(1).unwrap_static();
3288 let d2 = q_shape.dim(2).unwrap_static();
3289 if d1 == *num_heads {
3290 (
3291 q_shape.dim(0).unwrap_static(),
3292 d2,
3293 k_shape.dim(2).unwrap_static(),
3294 true,
3295 )
3296 } else {
3297 (
3298 q_shape.dim(0).unwrap_static(),
3299 d1,
3300 k_shape.dim(1).unwrap_static(),
3301 false,
3302 )
3303 }
3304 } else if rank >= 3 {
3305 (
3306 q_shape.dim(0).unwrap_static(),
3307 q_shape.dim(1).unwrap_static(),
3308 k_shape.dim(1).unwrap_static(),
3309 false,
3310 )
3311 } else {
3312 (
3313 1,
3314 q_shape.dim(0).unwrap_static(),
3315 k_shape.dim(0).unwrap_static(),
3316 false,
3317 )
3318 };
3319 let mask_off = if matches!(
3320 mask_kind,
3321 rlx_ir::op::MaskKind::Custom | rlx_ir::op::MaskKind::Bias
3322 ) {
3323 node_offset(arena, node.inputs[4])
3324 } else {
3325 0
3326 };
3327 Thunk::AttentionBackward {
3328 q: node_offset(arena, node.inputs[0]),
3329 k: node_offset(arena, node.inputs[1]),
3330 v: node_offset(arena, node.inputs[2]),
3331 dy: node_offset(arena, node.inputs[3]),
3332 mask: mask_off,
3333 out: node_offset(arena, node.id),
3334 batch: batch as u32,
3335 seq: seq as u32,
3336 kv_seq: kv_seq as u32,
3337 heads: *num_heads as u32,
3338 head_dim: *head_dim as u32,
3339 mask_kind: *mask_kind,
3340 wrt: *wrt,
3341 bhsd,
3342 }
3343 }
3344
3345 Op::FusedAttentionBlock {
3346 num_heads,
3347 head_dim,
3348 has_bias,
3349 has_rope,
3350 } => {
3351 let x_shape = &graph.node(node.inputs[0]).shape;
3352 let (batch, seq) = if x_shape.rank() >= 3 {
3353 (
3354 x_shape.dim(0).unwrap_static(),
3355 x_shape.dim(1).unwrap_static(),
3356 )
3357 } else {
3358 let total = x_shape.num_elements().unwrap();
3359 let s = x_shape.dim(x_shape.rank() - 2).unwrap_static();
3360 (total / (s * num_heads * head_dim), s)
3361 };
3362 let hs = (*num_heads * *head_dim) as u32;
3363 let mut idx = 4;
3365 let (qkv_b_off, out_b_off) = if *has_bias {
3366 let qb = node_offset(arena, node.inputs[idx]);
3367 let ob = node_offset(arena, node.inputs[idx + 1]);
3368 idx += 2;
3369 (qb, ob)
3370 } else {
3371 (0, 0)
3372 };
3373 let (cos_off, sin_off, cl) = if *has_rope {
3374 let c = node_offset(arena, node.inputs[idx]);
3375 let s = node_offset(arena, node.inputs[idx + 1]);
3376 let clen = get_len(graph, node.inputs[idx]);
3377 (c, s, clen as u32)
3378 } else {
3379 (0, 0, 0)
3380 };
3381
3382 Thunk::FusedAttnBlock {
3383 hidden: node_offset(arena, node.inputs[0]),
3384 qkv_w: node_offset(arena, node.inputs[1]),
3385 out_w: node_offset(arena, node.inputs[2]),
3386 mask: node_offset(arena, node.inputs[3]),
3387 out: node_offset(arena, node.id),
3388 qkv_b: qkv_b_off,
3389 out_b: out_b_off,
3390 cos: cos_off,
3391 sin: sin_off,
3392 cos_len: cl,
3393 batch: batch as u32,
3394 seq: seq as u32,
3395 hs,
3396 nh: *num_heads as u32,
3397 dh: *head_dim as u32,
3398 has_bias: *has_bias,
3399 has_rope: *has_rope,
3400 }
3401 }
3402
3403 Op::Rope { head_dim, n_rot } => {
3404 let x_shape = &graph.node(node.inputs[0]).shape;
3405 let (batch, seq, hidden) = if x_shape.rank() >= 3 {
3406 (
3407 x_shape.dim(0).unwrap_static(),
3408 x_shape.dim(1).unwrap_static(),
3409 x_shape.dim(2).unwrap_static(),
3410 )
3411 } else {
3412 let total = x_shape.num_elements().unwrap();
3413 (
3414 1,
3415 x_shape.dim(0).unwrap_static(),
3416 total / x_shape.dim(0).unwrap_static(),
3417 )
3418 };
3419 let cos_len = get_len(graph, node.inputs[1]);
3420 Thunk::Rope {
3421 src: node_offset(arena, node.inputs[0]),
3422 cos: node_offset(arena, node.inputs[1]),
3423 sin: node_offset(arena, node.inputs[2]),
3424 dst: node_offset(arena, node.id),
3425 batch: batch as u32,
3426 seq: seq as u32,
3427 hidden: hidden as u32,
3428 head_dim: *head_dim as u32,
3429 n_rot: *n_rot as u32,
3430 cos_len: cos_len as u32,
3431 src_row_stride: hidden as u32,
3435 }
3436 }
3437
3438 Op::FusedSwiGLU {
3439 cast_to: _,
3440 gate_first,
3441 } => {
3442 let n_half = node.shape.dim(node.shape.rank() - 1).unwrap_static();
3443 let total = node.shape.num_elements().unwrap();
3444 Thunk::FusedSwiGLU {
3445 src: node_offset(arena, node.inputs[0]),
3446 dst: node_offset(arena, node.id),
3447 n_half: n_half as u32,
3448 total: total as u32,
3449 gate_first: *gate_first,
3450 }
3451 }
3452
3453 Op::Conv {
3454 kernel_size,
3455 stride,
3456 padding,
3457 dilation,
3458 groups,
3459 } => {
3460 let in_shape = &graph.node(node.inputs[0]).shape;
3461 let w_shape = &graph.node(node.inputs[1]).shape;
3462 let out_shape = &node.shape;
3463 let is_1x1_simple = kernel_size.len() == 2
3467 && kernel_size[0] == 1
3468 && kernel_size[1] == 1
3469 && stride.iter().all(|&s| s == 1)
3470 && padding.iter().all(|&p| p == 0)
3471 && dilation.iter().all(|&d| d == 1)
3472 && *groups == 1;
3473 if is_1x1_simple && in_shape.rank() == 4 && out_shape.rank() == 4 {
3474 let n = in_shape.dim(0).unwrap_static();
3475 let c_in = in_shape.dim(1).unwrap_static();
3476 let c_out = out_shape.dim(1).unwrap_static();
3477 let h = in_shape.dim(2).unwrap_static();
3478 let w = in_shape.dim(3).unwrap_static();
3479 Thunk::Conv2D1x1 {
3480 src: node_offset(arena, node.inputs[0]),
3481 weight: node_offset(arena, node.inputs[1]),
3482 dst: node_offset(arena, node.id),
3483 n: n as u32,
3484 c_in: c_in as u32,
3485 c_out: c_out as u32,
3486 hw: (h * w) as u32,
3487 }
3488 } else if kernel_size.len() == 2
3489 && in_shape.rank() == 4
3490 && w_shape.rank() == 4
3491 && out_shape.rank() == 4
3492 {
3493 Thunk::Conv2D {
3494 src: node_offset(arena, node.inputs[0]),
3495 weight: node_offset(arena, node.inputs[1]),
3496 dst: node_offset(arena, node.id),
3497 n: in_shape.dim(0).unwrap_static() as u32,
3498 c_in: in_shape.dim(1).unwrap_static() as u32,
3499 h: in_shape.dim(2).unwrap_static() as u32,
3500 w: in_shape.dim(3).unwrap_static() as u32,
3501 c_out: out_shape.dim(1).unwrap_static() as u32,
3502 h_out: out_shape.dim(2).unwrap_static() as u32,
3503 w_out: out_shape.dim(3).unwrap_static() as u32,
3504 kh: kernel_size[0] as u32,
3505 kw: kernel_size[1] as u32,
3506 sh: stride.first().copied().unwrap_or(1) as u32,
3507 sw: stride.get(1).copied().unwrap_or(1) as u32,
3508 ph: padding.first().copied().unwrap_or(0) as u32,
3509 pw: padding.get(1).copied().unwrap_or(0) as u32,
3510 dh: dilation.first().copied().unwrap_or(1) as u32,
3511 dw: dilation.get(1).copied().unwrap_or(1) as u32,
3512 groups: *groups as u32,
3513 }
3514 } else {
3515 Thunk::Nop
3516 }
3517 }
3518
3519 Op::Pool {
3520 kind,
3521 kernel_size,
3522 stride,
3523 padding,
3524 } => {
3525 let in_shape = &graph.node(node.inputs[0]).shape;
3527 let out_shape = &node.shape;
3528 if kernel_size.len() == 2 && in_shape.rank() == 4 && out_shape.rank() == 4 {
3529 Thunk::Pool2D {
3530 src: node_offset(arena, node.inputs[0]),
3531 dst: node_offset(arena, node.id),
3532 n: in_shape.dim(0).unwrap_static() as u32,
3533 c: in_shape.dim(1).unwrap_static() as u32,
3534 h: in_shape.dim(2).unwrap_static() as u32,
3535 w: in_shape.dim(3).unwrap_static() as u32,
3536 h_out: out_shape.dim(2).unwrap_static() as u32,
3537 w_out: out_shape.dim(3).unwrap_static() as u32,
3538 kh: kernel_size[0] as u32,
3539 kw: kernel_size[1] as u32,
3540 sh: stride.first().copied().unwrap_or(1) as u32,
3541 sw: stride.get(1).copied().unwrap_or(1) as u32,
3542 ph: padding.first().copied().unwrap_or(0) as u32,
3543 pw: padding.get(1).copied().unwrap_or(0) as u32,
3544 kind: *kind,
3545 }
3546 } else {
3547 Thunk::Nop
3548 }
3549 }
3550
3551 Op::Transpose { perm } => {
3552 let in_shape = &graph.node(node.inputs[0]).shape;
3555 let in_rank = in_shape.rank();
3556 let in_dims: Vec<usize> = (0..in_rank)
3557 .map(|i| in_shape.dim(i).unwrap_static())
3558 .collect();
3559 let mut in_strides_full = vec![1usize; in_rank];
3561 for d in (0..in_rank.saturating_sub(1)).rev() {
3562 in_strides_full[d] = in_strides_full[d + 1] * in_dims[d + 1];
3563 }
3564 let out_dims: Vec<u32> = perm.iter().map(|&p| in_dims[p] as u32).collect();
3565 let in_strides: Vec<u32> =
3566 perm.iter().map(|&p| in_strides_full[p] as u32).collect();
3567 let in_total = in_dims.iter().product::<usize>() as u32;
3568 let src = node_offset(arena, node.inputs[0]);
3569 let dst = node_offset(arena, node.id);
3570 match node.shape.dtype() {
3571 rlx_ir::DType::F64 => Thunk::TransposeF64 {
3572 src,
3573 dst,
3574 in_total,
3575 out_dims,
3576 in_strides,
3577 },
3578 _ => Thunk::Transpose {
3579 src,
3580 dst,
3581 in_total,
3582 out_dims,
3583 in_strides,
3584 },
3585 }
3586 }
3587
3588 Op::ScatterAdd => {
3589 let upd_shape = &graph.node(node.inputs[0]).shape;
3592 let out_shape = &node.shape;
3593 let num_updates = upd_shape.dim(0).unwrap_static();
3594 let out_dim = out_shape.dim(0).unwrap_static();
3595 let trailing: usize = (1..out_shape.rank())
3596 .map(|i| out_shape.dim(i).unwrap_static())
3597 .product::<usize>()
3598 .max(1);
3599 Thunk::ScatterAdd {
3600 updates: node_offset(arena, node.inputs[0]),
3601 indices: node_offset(arena, node.inputs[1]),
3602 dst: node_offset(arena, node.id),
3603 num_updates: num_updates as u32,
3604 out_dim: out_dim as u32,
3605 trailing: trailing as u32,
3606 }
3607 }
3608
3609 Op::GroupedMatMul => {
3610 let in_shape = &graph.node(node.inputs[0]).shape;
3612 let w_shape = &graph.node(node.inputs[1]).shape;
3613 let m = in_shape.dim(in_shape.rank() - 2).unwrap_static();
3614 let k_dim = in_shape.dim(in_shape.rank() - 1).unwrap_static();
3615 let num_experts = w_shape.dim(0).unwrap_static();
3616 let n = w_shape.dim(2).unwrap_static();
3617 Thunk::GroupedMatMul {
3618 input: node_offset(arena, node.inputs[0]),
3619 weight: node_offset(arena, node.inputs[1]),
3620 expert_idx: node_offset(arena, node.inputs[2]),
3621 dst: node_offset(arena, node.id),
3622 m: m as u32,
3623 k_dim: k_dim as u32,
3624 n: n as u32,
3625 num_experts: num_experts as u32,
3626 }
3627 }
3628
3629 Op::DequantGroupedMatMul { scheme } => {
3630 let in_shape = &graph.node(node.inputs[0]).shape;
3631 let w_shape = &graph.node(node.inputs[1]).shape;
3632 let m = in_shape.dim(in_shape.rank() - 2).unwrap_static();
3633 let k_dim = in_shape.dim(in_shape.rank() - 1).unwrap_static();
3634 let out_shape = &node.shape;
3635 let n = out_shape.dim(out_shape.rank() - 1).unwrap_static();
3636 let block_elems = scheme.gguf_block_size() as usize;
3637 let block_bytes = scheme.gguf_block_bytes() as usize;
3638 let slab_bytes = (k_dim * n) / block_elems * block_bytes;
3639 let total_bytes = w_shape.num_elements().unwrap();
3640 let num_experts = total_bytes / slab_bytes.max(1);
3641 Thunk::DequantGroupedMatMulGguf {
3642 input: node_offset(arena, node.inputs[0]),
3643 w_q: node_offset(arena, node.inputs[1]),
3644 expert_idx: node_offset(arena, node.inputs[2]),
3645 dst: node_offset(arena, node.id),
3646 m: m as u32,
3647 k_dim: k_dim as u32,
3648 n: n as u32,
3649 num_experts: num_experts as u32,
3650 scheme: *scheme,
3651 }
3652 }
3653
3654 Op::DequantMoEWeights { scheme } => {
3655 let w_shape = &graph.node(node.inputs[0]).shape;
3656 let out_shape = &node.shape;
3657 let num_experts = out_shape.dim(0).unwrap_static();
3658 let k_dim = out_shape.dim(1).unwrap_static();
3659 let n = out_shape.dim(2).unwrap_static();
3660 let block_elems = scheme.gguf_block_size() as usize;
3661 let block_bytes = scheme.gguf_block_bytes() as usize;
3662 let slab_bytes = (k_dim * n) / block_elems * block_bytes;
3663 let total_bytes = w_shape.num_elements().unwrap();
3664 assert_eq!(
3665 total_bytes,
3666 num_experts * slab_bytes,
3667 "DequantMoEWeights packed bytes mismatch"
3668 );
3669 Thunk::DequantMoEWeightsGguf {
3670 w_q: node_offset(arena, node.inputs[0]),
3671 dst: node_offset(arena, node.id),
3672 k_dim: k_dim as u32,
3673 n: n as u32,
3674 num_experts: num_experts as u32,
3675 scheme: *scheme,
3676 }
3677 }
3678
3679 Op::TopK { k } => {
3680 let in_shape = &graph.node(node.inputs[0]).shape;
3681 let rank = in_shape.rank();
3682 let axis_dim = in_shape.dim(rank - 1).unwrap_static();
3683 let outer = in_shape.num_elements().unwrap() / axis_dim;
3684 Thunk::TopK {
3685 src: node_offset(arena, node.inputs[0]),
3686 dst: node_offset(arena, node.id),
3687 outer: outer as u32,
3688 axis_dim: axis_dim as u32,
3689 k: *k as u32,
3690 }
3691 }
3692
3693 Op::Reduce {
3694 op,
3695 axes,
3696 keep_dim: _,
3697 } => {
3698 let in_shape = &graph.node(node.inputs[0]).shape;
3704 let rank = in_shape.rank();
3705 let mut sorted = axes.clone();
3706 sorted.sort();
3707 sorted.dedup();
3708 let contiguous = sorted.windows(2).all(|w| w[1] == w[0] + 1)
3709 && !sorted.is_empty()
3710 && *sorted.last().unwrap() < rank;
3711 if !contiguous {
3712 Thunk::Nop
3713 } else {
3714 let first = sorted[0];
3715 let last = *sorted.last().unwrap();
3716 let outer: usize = (0..first)
3717 .map(|i| in_shape.dim(i).unwrap_static())
3718 .product::<usize>()
3719 .max(1);
3720 let reduced: usize = (first..=last)
3721 .map(|i| in_shape.dim(i).unwrap_static())
3722 .product();
3723 let inner: usize = (last + 1..rank)
3724 .map(|i| in_shape.dim(i).unwrap_static())
3725 .product::<usize>()
3726 .max(1);
3727 let src = node_offset(arena, node.inputs[0]);
3728 let dst = node_offset(arena, node.id);
3729 if node.shape.dtype() == rlx_ir::DType::F64 && matches!(op, ReduceOp::Sum) {
3730 Thunk::ReduceSumF64 {
3731 src,
3732 dst,
3733 outer: outer as u32,
3734 reduced: reduced as u32,
3735 inner: inner as u32,
3736 }
3737 } else {
3738 Thunk::Reduce {
3739 src,
3740 dst,
3741 outer: outer as u32,
3742 reduced: reduced as u32,
3743 inner: inner as u32,
3744 op: *op,
3745 }
3746 }
3747 }
3748 }
3749
3750 Op::Compare(cmp) => {
3751 let len = node.shape.num_elements().unwrap();
3752 Thunk::Compare {
3753 lhs: node_offset(arena, node.inputs[0]),
3754 rhs: node_offset(arena, node.inputs[1]),
3755 dst: node_offset(arena, node.id),
3756 len: len as u32,
3757 op: *cmp,
3758 }
3759 }
3760
3761 Op::Where => {
3762 let len = node.shape.num_elements().unwrap();
3763 Thunk::Where {
3764 cond: node_offset(arena, node.inputs[0]),
3765 on_true: node_offset(arena, node.inputs[1]),
3766 on_false: node_offset(arena, node.inputs[2]),
3767 dst: node_offset(arena, node.id),
3768 len: len as u32,
3769 }
3770 }
3771
3772 Op::ReluBackward => {
3773 let len: usize = (0..node.shape.rank())
3774 .map(|i| node.shape.dim(i).unwrap_static())
3775 .product();
3776 let x = node_offset(arena, node.inputs[0]);
3777 let dy = node_offset(arena, node.inputs[1]);
3778 let dx = node_offset(arena, node.id);
3779 match node.shape.dtype() {
3780 rlx_ir::DType::F64 => Thunk::ReluBackwardF64 {
3781 x,
3782 dy,
3783 dx,
3784 len: len as u32,
3785 },
3786 _ => Thunk::ReluBackward {
3787 x,
3788 dy,
3789 dx,
3790 len: len as u32,
3791 },
3792 }
3793 }
3794
3795 Op::ComplexNormSq => {
3796 let len: usize = (0..node.shape.rank())
3797 .map(|i| node.shape.dim(i).unwrap_static())
3798 .product();
3799 let src = node_offset(arena, node.inputs[0]);
3800 let dst = node_offset(arena, node.id);
3801 Thunk::ComplexNormSqF32 {
3802 src,
3803 dst,
3804 len: len as u32,
3805 }
3806 }
3807
3808 Op::ComplexNormSqBackward => {
3809 let len: usize = (0..node.shape.rank())
3810 .map(|i| node.shape.dim(i).unwrap_static())
3811 .product();
3812 let z = node_offset(arena, node.inputs[0]);
3813 let g = node_offset(arena, node.inputs[1]);
3814 let dz = node_offset(arena, node.id);
3815 Thunk::ComplexNormSqBackwardF32 {
3816 z,
3817 g,
3818 dz,
3819 len: len as u32,
3820 }
3821 }
3822
3823 Op::Conjugate => {
3824 let len: usize = (0..node.shape.rank())
3825 .map(|i| node.shape.dim(i).unwrap_static())
3826 .product();
3827 Thunk::ConjugateC64 {
3828 src: node_offset(arena, node.inputs[0]),
3829 dst: node_offset(arena, node.id),
3830 len: len as u32,
3831 }
3832 }
3833
3834 Op::ActivationBackward { kind } => {
3835 let len: usize = (0..node.shape.rank())
3836 .map(|i| node.shape.dim(i).unwrap_static())
3837 .product();
3838 let x = node_offset(arena, node.inputs[0]);
3839 let dy = node_offset(arena, node.inputs[1]);
3840 let dx = node_offset(arena, node.id);
3841 match node.shape.dtype() {
3842 rlx_ir::DType::F64 => Thunk::ActivationBackwardF64 {
3843 x,
3844 dy,
3845 dx,
3846 len: len as u32,
3847 kind: *kind,
3848 },
3849 _ => Thunk::ActivationBackward {
3850 x,
3851 dy,
3852 dx,
3853 len: len as u32,
3854 kind: *kind,
3855 },
3856 }
3857 }
3858
3859 Op::LayerNormBackwardInput { eps, .. } => {
3860 let h = node.shape.dim(node.shape.rank() - 1).unwrap_static();
3862 let total = node.shape.num_elements().unwrap();
3863 Thunk::LayerNormBackwardInput {
3864 x: node_offset(arena, node.inputs[0]),
3865 gamma: node_offset(arena, node.inputs[1]),
3866 dy: node_offset(arena, node.inputs[2]),
3867 dx: node_offset(arena, node.id),
3868 rows: (total / h) as u32,
3869 h: h as u32,
3870 eps: *eps,
3871 }
3872 }
3873
3874 Op::LayerNormBackwardGamma { eps, .. } => {
3875 let x_shape = &graph.node(node.inputs[0]).shape;
3876 let h = x_shape.dim(x_shape.rank() - 1).unwrap_static();
3877 let x_total = x_shape.num_elements().unwrap();
3878 Thunk::LayerNormBackwardGamma {
3879 x: node_offset(arena, node.inputs[0]),
3880 dy: node_offset(arena, node.inputs[1]),
3881 dgamma: node_offset(arena, node.id),
3882 rows: (x_total / h) as u32,
3883 h: h as u32,
3884 eps: *eps,
3885 }
3886 }
3887
3888 Op::RmsNormBackwardInput { eps, .. }
3889 | Op::RmsNormBackwardGamma { eps, .. }
3890 | Op::RmsNormBackwardBeta { eps, .. } => {
3891 let x_shape = &graph.node(node.inputs[0]).shape;
3892 let h = x_shape.dim(x_shape.rank() - 1).unwrap_static();
3893 let rows = (x_shape.num_elements().unwrap() / h) as u32;
3894 let off = |i: usize| node_offset(arena, node.inputs[i]);
3895 let common = (off(0), off(1), off(2), off(3), rows, h as u32, *eps);
3896 match &node.op {
3897 Op::RmsNormBackwardInput { .. } => Thunk::RmsNormBackwardInput {
3898 x: common.0,
3899 gamma: common.1,
3900 beta: common.2,
3901 dy: common.3,
3902 dx: node_offset(arena, node.id),
3903 rows: common.4,
3904 h: common.5,
3905 eps: common.6,
3906 },
3907 Op::RmsNormBackwardGamma { .. } => Thunk::RmsNormBackwardGamma {
3908 x: common.0,
3909 gamma: common.1,
3910 beta: common.2,
3911 dy: common.3,
3912 dgamma: node_offset(arena, node.id),
3913 rows: common.4,
3914 h: common.5,
3915 eps: common.6,
3916 },
3917 Op::RmsNormBackwardBeta { .. } => Thunk::RmsNormBackwardBeta {
3918 x: common.0,
3919 gamma: common.1,
3920 beta: common.2,
3921 dy: common.3,
3922 dbeta: node_offset(arena, node.id),
3923 rows: common.4,
3924 h: common.5,
3925 eps: common.6,
3926 },
3927 _ => unreachable!(),
3928 }
3929 }
3930
3931 Op::RopeBackward { head_dim, n_rot } => {
3932 let dy_shape = &graph.node(node.inputs[0]).shape;
3933 let (batch, seq, hidden) = if dy_shape.rank() >= 3 {
3934 (
3935 dy_shape.dim(0).unwrap_static(),
3936 dy_shape.dim(1).unwrap_static(),
3937 dy_shape.dim(2).unwrap_static(),
3938 )
3939 } else {
3940 (
3941 1,
3942 dy_shape.dim(0).unwrap_static(),
3943 dy_shape.dim(1).unwrap_static(),
3944 )
3945 };
3946 let cos_shape = &graph.node(node.inputs[1]).shape;
3947 let cos_len = cos_shape.num_elements().unwrap();
3948 Thunk::RopeBackward {
3949 dy: node_offset(arena, node.inputs[0]),
3950 cos: node_offset(arena, node.inputs[1]),
3951 sin: node_offset(arena, node.inputs[2]),
3952 dx: node_offset(arena, node.id),
3953 batch: batch as u32,
3954 seq: seq as u32,
3955 hidden: hidden as u32,
3956 head_dim: *head_dim as u32,
3957 n_rot: *n_rot as u32,
3958 cos_len: cos_len as u32,
3959 }
3960 }
3961
3962 Op::CumsumBackward { exclusive, .. } => {
3963 let dy_shape = &graph.node(node.inputs[0]).shape;
3964 let rank = dy_shape.rank();
3965 let cols = dy_shape.dim(rank - 1).unwrap_static();
3966 let rows = dy_shape.num_elements().unwrap() / cols;
3967 Thunk::CumsumBackward {
3968 dy: node_offset(arena, node.inputs[0]),
3969 dx: node_offset(arena, node.id),
3970 rows: rows as u32,
3971 cols: cols as u32,
3972 exclusive: *exclusive,
3973 }
3974 }
3975
3976 Op::GatherBackward { .. } => {
3977 let dy_shape = &graph.node(node.inputs[0]).shape;
3978 let idx_shape = &graph.node(node.inputs[1]).shape;
3979 let out_shape = &node.shape;
3980 let rank = out_shape.rank();
3981 let axis = match &node.op {
3982 Op::GatherBackward { axis } => *axis,
3983 _ => 0,
3984 };
3985 let axis_u = if axis < 0 {
3986 (rank as i32 + axis) as usize
3987 } else {
3988 axis as usize
3989 };
3990 let outer: usize = (0..axis_u)
3991 .map(|i| dy_shape.dim(i).unwrap_static())
3992 .product::<usize>()
3993 .max(1);
3994 let num_idx = idx_shape.dim(axis_u).unwrap_static();
3995 let trailing: usize = (axis_u + 1..dy_shape.rank())
3996 .map(|i| dy_shape.dim(i).unwrap_static())
3997 .product::<usize>()
3998 .max(1);
3999 let axis_dim = out_shape.dim(axis_u).unwrap_static();
4000 Thunk::GatherBackward {
4001 dy: node_offset(arena, node.inputs[0]),
4002 indices: node_offset(arena, node.inputs[1]),
4003 dst: node_offset(arena, node.id),
4004 outer: outer as u32,
4005 axis_dim: axis_dim as u32,
4006 num_idx: num_idx as u32,
4007 trailing: trailing as u32,
4008 }
4009 }
4010
4011 Op::GroupNormBackwardInput { num_groups, eps }
4012 | Op::GroupNormBackwardGamma { num_groups, eps }
4013 | Op::GroupNormBackwardBeta { num_groups, eps } => {
4014 let x_shape = &graph.node(node.inputs[0]).shape;
4015 let n = x_shape.dim(0).unwrap_static() as u32;
4016 let c = x_shape.dim(1).unwrap_static() as u32;
4017 let h = x_shape.dim(2).unwrap_static() as u32;
4018 let w = x_shape.dim(3).unwrap_static() as u32;
4019 match &node.op {
4020 Op::GroupNormBackwardInput { .. } => Thunk::GroupNormBackwardInput {
4021 x: node_offset(arena, node.inputs[0]),
4022 gamma: node_offset(arena, node.inputs[1]),
4023 beta: node_offset(arena, node.inputs[2]),
4024 dy: node_offset(arena, node.inputs[3]),
4025 dx: node_offset(arena, node.id),
4026 n,
4027 c,
4028 h,
4029 w,
4030 num_groups: *num_groups as u32,
4031 eps: *eps,
4032 },
4033 Op::GroupNormBackwardGamma { .. } => Thunk::GroupNormBackwardGamma {
4034 x: node_offset(arena, node.inputs[0]),
4035 dy: node_offset(arena, node.inputs[1]),
4036 dgamma: node_offset(arena, node.id),
4037 n,
4038 c,
4039 h,
4040 w,
4041 num_groups: *num_groups as u32,
4042 eps: *eps,
4043 },
4044 Op::GroupNormBackwardBeta { .. } => Thunk::GroupNormBackwardBeta {
4045 dy: node_offset(arena, node.inputs[1]),
4046 dbeta: node_offset(arena, node.id),
4047 n,
4048 c,
4049 h,
4050 w,
4051 },
4052 _ => unreachable!(),
4053 }
4054 }
4055
4056 Op::MaxPool2dBackward {
4057 kernel_size,
4058 stride,
4059 padding,
4060 } => {
4061 let x_shape = &graph.node(node.inputs[0]).shape;
4062 let dy_shape = &graph.node(node.inputs[1]).shape;
4063 if kernel_size.len() == 2 && x_shape.rank() == 4 && dy_shape.rank() == 4 {
4064 Thunk::MaxPool2dBackward {
4065 x: node_offset(arena, node.inputs[0]),
4066 dy: node_offset(arena, node.inputs[1]),
4067 dx: node_offset(arena, node.id),
4068 n: x_shape.dim(0).unwrap_static() as u32,
4069 c: x_shape.dim(1).unwrap_static() as u32,
4070 h: x_shape.dim(2).unwrap_static() as u32,
4071 w: x_shape.dim(3).unwrap_static() as u32,
4072 h_out: dy_shape.dim(2).unwrap_static() as u32,
4073 w_out: dy_shape.dim(3).unwrap_static() as u32,
4074 kh: kernel_size[0] as u32,
4075 kw: kernel_size[1] as u32,
4076 sh: stride.first().copied().unwrap_or(1) as u32,
4077 sw: stride.get(1).copied().unwrap_or(1) as u32,
4078 ph: padding.first().copied().unwrap_or(0) as u32,
4079 pw: padding.get(1).copied().unwrap_or(0) as u32,
4080 }
4081 } else {
4082 Thunk::Nop
4083 }
4084 }
4085
4086 Op::Conv2dBackwardInput {
4087 kernel_size,
4088 stride,
4089 padding,
4090 dilation,
4091 groups,
4092 } => {
4093 let dy_shape = &graph.node(node.inputs[0]).shape;
4094 let w_shape = &graph.node(node.inputs[1]).shape;
4095 let out_shape = &node.shape;
4096 if kernel_size.len() == 2
4097 && dy_shape.rank() == 4
4098 && w_shape.rank() == 4
4099 && out_shape.rank() == 4
4100 {
4101 Thunk::Conv2dBackwardInput {
4102 dy: node_offset(arena, node.inputs[0]),
4103 w: node_offset(arena, node.inputs[1]),
4104 dx: node_offset(arena, node.id),
4105 n: out_shape.dim(0).unwrap_static() as u32,
4106 c_in: out_shape.dim(1).unwrap_static() as u32,
4107 h: out_shape.dim(2).unwrap_static() as u32,
4108 w_in: out_shape.dim(3).unwrap_static() as u32,
4109 c_out: dy_shape.dim(1).unwrap_static() as u32,
4110 h_out: dy_shape.dim(2).unwrap_static() as u32,
4111 w_out: dy_shape.dim(3).unwrap_static() as u32,
4112 kh: kernel_size[0] as u32,
4113 kw: kernel_size[1] as u32,
4114 sh: stride.first().copied().unwrap_or(1) as u32,
4115 sw: stride.get(1).copied().unwrap_or(1) as u32,
4116 ph: padding.first().copied().unwrap_or(0) as u32,
4117 pw: padding.get(1).copied().unwrap_or(0) as u32,
4118 dh: dilation.first().copied().unwrap_or(1) as u32,
4119 dw: dilation.get(1).copied().unwrap_or(1) as u32,
4120 groups: *groups as u32,
4121 }
4122 } else {
4123 Thunk::Nop
4124 }
4125 }
4126
4127 Op::Conv2dBackwardWeight {
4128 kernel_size,
4129 stride,
4130 padding,
4131 dilation,
4132 groups,
4133 } => {
4134 let x_shape = &graph.node(node.inputs[0]).shape;
4135 let dy_shape = &graph.node(node.inputs[1]).shape;
4136 let dw_shape = &node.shape;
4137 if kernel_size.len() == 2
4138 && x_shape.rank() == 4
4139 && dy_shape.rank() == 4
4140 && dw_shape.rank() == 4
4141 {
4142 Thunk::Conv2dBackwardWeight {
4143 x: node_offset(arena, node.inputs[0]),
4144 dy: node_offset(arena, node.inputs[1]),
4145 dw: node_offset(arena, node.id),
4146 n: x_shape.dim(0).unwrap_static() as u32,
4147 c_in: x_shape.dim(1).unwrap_static() as u32,
4148 h: x_shape.dim(2).unwrap_static() as u32,
4149 w: x_shape.dim(3).unwrap_static() as u32,
4150 c_out: dy_shape.dim(1).unwrap_static() as u32,
4151 h_out: dy_shape.dim(2).unwrap_static() as u32,
4152 w_out: dy_shape.dim(3).unwrap_static() as u32,
4153 kh: kernel_size[0] as u32,
4154 kw: kernel_size[1] as u32,
4155 sh: stride.first().copied().unwrap_or(1) as u32,
4156 sw: stride.get(1).copied().unwrap_or(1) as u32,
4157 ph: padding.first().copied().unwrap_or(0) as u32,
4158 pw: padding.get(1).copied().unwrap_or(0) as u32,
4159 dh: dilation.first().copied().unwrap_or(1) as u32,
4160 dw_dil: dilation.get(1).copied().unwrap_or(1) as u32,
4161 groups: *groups as u32,
4162 }
4163 } else {
4164 Thunk::Nop
4165 }
4166 }
4167
4168 Op::SoftmaxCrossEntropyWithLogits => {
4169 let logits_shape = &graph.node(node.inputs[0]).shape;
4170 if logits_shape.rank() == 2 {
4171 Thunk::SoftmaxCrossEntropy {
4172 logits: node_offset(arena, node.inputs[0]),
4173 labels: node_offset(arena, node.inputs[1]),
4174 dst: node_offset(arena, node.id),
4175 n: logits_shape.dim(0).unwrap_static() as u32,
4176 c: logits_shape.dim(1).unwrap_static() as u32,
4177 }
4178 } else {
4179 Thunk::Nop
4180 }
4181 }
4182
4183 Op::SoftmaxCrossEntropyBackward => {
4184 let logits_shape = &graph.node(node.inputs[0]).shape;
4185 if logits_shape.rank() == 2 {
4186 Thunk::SoftmaxCrossEntropyBackward {
4187 logits: node_offset(arena, node.inputs[0]),
4188 labels: node_offset(arena, node.inputs[1]),
4189 d_loss: node_offset(arena, node.inputs[2]),
4190 dlogits: node_offset(arena, node.id),
4191 n: logits_shape.dim(0).unwrap_static() as u32,
4192 c: logits_shape.dim(1).unwrap_static() as u32,
4193 }
4194 } else {
4195 Thunk::Nop
4196 }
4197 }
4198
4199 Op::DenseSolve => {
4200 let a_shape = &graph.node(node.inputs[0]).shape;
4202 let n = a_shape.dim(0).unwrap_static();
4203 debug_assert_eq!(
4204 n,
4205 a_shape.dim(1).unwrap_static(),
4206 "DenseSolve: A must be square"
4207 );
4208 let b_elems = node.shape.num_elements().unwrap();
4209 let nrhs = b_elems / n;
4210 match node.shape.dtype() {
4211 rlx_ir::DType::F64 => Thunk::DenseSolveF64 {
4212 a: node_offset(arena, node.inputs[0]),
4213 b: node_offset(arena, node.inputs[1]),
4214 x: node_offset(arena, node.id),
4215 n: n as u32,
4216 nrhs: nrhs as u32,
4217 },
4218 rlx_ir::DType::F32 => Thunk::DenseSolveF32 {
4219 a: node_offset(arena, node.inputs[0]),
4220 b: node_offset(arena, node.inputs[1]),
4221 x: node_offset(arena, node.id),
4222 n: n as u32,
4223 nrhs: nrhs as u32,
4224 },
4225 other => panic!(
4226 "DenseSolve: F32 + F64 lowered; got {other:?}. \
4227 Add another variant when needed."
4228 ),
4229 }
4230 }
4231
4232 Op::BatchedDenseSolve => {
4233 let a_shape = &graph.node(node.inputs[0]).shape;
4235 assert_eq!(a_shape.rank(), 3, "BatchedDenseSolve: A rank must be 3");
4236 let batch = a_shape.dim(0).unwrap_static();
4237 let n = a_shape.dim(1).unwrap_static();
4238 debug_assert_eq!(
4239 n,
4240 a_shape.dim(2).unwrap_static(),
4241 "BatchedDenseSolve: A's last two dims must match"
4242 );
4243 let total = node.shape.num_elements().unwrap();
4244 let nrhs = total / (batch * n);
4245 match node.shape.dtype() {
4246 rlx_ir::DType::F32 => Thunk::BatchedDenseSolveF32 {
4247 a: node_offset(arena, node.inputs[0]),
4248 b: node_offset(arena, node.inputs[1]),
4249 x: node_offset(arena, node.id),
4250 batch: batch as u32,
4251 n: n as u32,
4252 nrhs: nrhs as u32,
4253 },
4254 rlx_ir::DType::F64 => Thunk::BatchedDenseSolveF64 {
4255 a: node_offset(arena, node.inputs[0]),
4256 b: node_offset(arena, node.inputs[1]),
4257 x: node_offset(arena, node.id),
4258 batch: batch as u32,
4259 n: n as u32,
4260 nrhs: nrhs as u32,
4261 },
4262 other => panic!("BatchedDenseSolve: F32 + F64 only, got {other:?}"),
4263 }
4264 }
4265
4266 Op::Scan {
4267 body,
4268 length,
4269 save_trajectory,
4270 num_bcast,
4271 num_xs,
4272 num_checkpoints,
4273 } => {
4274 assert!(
4275 *num_checkpoints == 0 || *num_checkpoints <= *length,
4276 "Op::Scan: num_checkpoints={} must be 0 or ≤ length={}",
4277 *num_checkpoints,
4278 *length
4279 );
4280 if *num_checkpoints != 0 && *num_checkpoints != *length {
4281 assert!(
4282 *save_trajectory,
4283 "Op::Scan: num_checkpoints<length only meaningful when save_trajectory=true"
4284 );
4285 }
4286 let body_plan = rlx_opt::memory::plan_memory(body);
4297 let _body_arena_size = body_plan.arena_size;
4298 let body_offsets: HashMap<NodeId, usize> = body_plan
4301 .assignments
4302 .iter()
4303 .map(|(id, slot)| (*id, slot.offset))
4304 .collect();
4305
4306 let mut body_inputs: Vec<NodeId> = body
4309 .nodes()
4310 .iter()
4311 .filter(|n| matches!(n.op, Op::Input { .. }))
4312 .map(|n| n.id)
4313 .collect();
4314 body_inputs.sort();
4315 let n_body_inputs = body_inputs.len();
4316 let expected = 1 + *num_bcast as usize + *num_xs as usize;
4317 if n_body_inputs != expected {
4318 let names: Vec<String> = body
4319 .nodes()
4320 .iter()
4321 .filter_map(|n| match &n.op {
4322 Op::Input { name } => Some(format!("{}={}", n.id, name)),
4323 _ => None,
4324 })
4325 .collect();
4326 panic!(
4327 "Op::Scan body has {} Op::Input nodes; expected {} \
4328 (1 carry + {} bcast + {} xs). Inputs by NodeId: [{}]",
4329 n_body_inputs,
4330 expected,
4331 *num_bcast,
4332 *num_xs,
4333 names.join(", ")
4334 );
4335 }
4336
4337 let body_input_id = body_inputs[0];
4338 let body_input_off = body_offsets[&body_input_id];
4339 let body_output_id = body
4340 .outputs
4341 .first()
4342 .copied()
4343 .expect("Op::Scan body must declare one output");
4344 let body_output_off = body_offsets[&body_output_id];
4345
4346 let mut body_arena = crate::arena::Arena::from_plan(body_plan);
4347 for n in body.nodes() {
4350 if let Op::Constant { data } = &n.op
4351 && body_arena.has_buffer(n.id)
4352 && !data.is_empty()
4353 {
4354 match n.shape.dtype() {
4355 rlx_ir::DType::F64 => {
4356 let off = body_arena.byte_offset(n.id);
4357 let buf = body_arena.raw_buf_mut();
4358 let nbytes = (buf.len() - off).min(data.len());
4359 buf[off..off + nbytes].copy_from_slice(&data[..nbytes]);
4360 }
4361 _ => {
4362 let buf = body_arena.slice_mut(n.id);
4363 let n_floats = data.len() / 4;
4364 let n_lim = buf.len().min(n_floats);
4365 for i in 0..n_lim {
4366 let bytes = [
4367 data[i * 4],
4368 data[i * 4 + 1],
4369 data[i * 4 + 2],
4370 data[i * 4 + 3],
4371 ];
4372 buf[i] = f32::from_le_bytes(bytes);
4373 }
4374 }
4375 }
4376 }
4377 }
4378 let body_init = body_arena.raw_buf().to_vec();
4379 let body_schedule = compile_thunks(body, &body_arena);
4380
4381 let carry_bytes = if *save_trajectory {
4386 let total = node
4387 .shape
4388 .size_bytes()
4389 .expect("Op::Scan trajectory output must have static shape");
4390 total / *length as usize
4391 } else {
4392 node.shape
4393 .size_bytes()
4394 .expect("Op::Scan carry must have static shape")
4395 };
4396
4397 let mut bcast_inputs: Vec<(usize, usize, u32)> =
4402 Vec::with_capacity(*num_bcast as usize);
4403 for i in 0..*num_bcast as usize {
4404 let body_b_id = body_inputs[1 + i];
4405 let body_b_off = body_offsets[&body_b_id];
4406 let outer_b_id = node.inputs[1 + i];
4407 let outer_b_off = node_offset(arena, outer_b_id);
4408 let outer_b_shape = &graph.node(outer_b_id).shape;
4409 let total = outer_b_shape
4410 .size_bytes()
4411 .expect("Op::Scan bcast must have static shape");
4412 bcast_inputs.push((body_b_off, outer_b_off, total as u32));
4413 }
4414
4415 let mut xs_inputs: Vec<(usize, usize, u32)> = Vec::with_capacity(*num_xs as usize);
4419 let xs_base = 1 + *num_bcast as usize;
4420 for i in 0..*num_xs as usize {
4421 let body_x_id = body_inputs[xs_base + i];
4422 let body_x_off = body_offsets[&body_x_id];
4423 let outer_xs_id = node.inputs[xs_base + i];
4424 let outer_xs_off = node_offset(arena, outer_xs_id);
4425 let outer_xs_shape = &graph.node(outer_xs_id).shape;
4426 let total = outer_xs_shape
4427 .size_bytes()
4428 .expect("Op::Scan xs must have static shape");
4429 let per_step = total / *length as usize;
4430 xs_inputs.push((body_x_off, outer_xs_off, per_step as u32));
4431 }
4432
4433 Thunk::Scan {
4434 body: Arc::new(body_schedule),
4435 body_init: Arc::new(body_init),
4436 body_input_off,
4437 body_output_off,
4438 outer_init_off: node_offset(arena, node.inputs[0]),
4439 outer_final_off: node_offset(arena, node.id),
4440 length: *length,
4441 carry_bytes: carry_bytes as u32,
4442 save_trajectory: *save_trajectory,
4443 xs_inputs: Arc::new(xs_inputs),
4444 bcast_inputs: Arc::new(bcast_inputs),
4445 num_checkpoints: *num_checkpoints,
4446 }
4447 }
4448
4449 Op::ScanBackward {
4450 body_vjp,
4451 length,
4452 save_trajectory,
4453 num_xs,
4454 num_checkpoints,
4455 forward_body,
4456 } => {
4457 let is_recursive = *num_checkpoints != 0 && *num_checkpoints != *length;
4458 if is_recursive {
4459 assert!(
4460 forward_body.is_some(),
4461 "Op::ScanBackward with num_checkpoints<length requires forward_body"
4462 );
4463 }
4464 let body_plan = rlx_opt::memory::plan_memory(body_vjp);
4472 let body_offsets: HashMap<NodeId, usize> = body_plan
4473 .assignments
4474 .iter()
4475 .map(|(id, slot)| (*id, slot.offset))
4476 .collect();
4477 let mut body_d_output_off: Option<usize> = None;
4478 let mut body_other_inputs: Vec<(NodeId, usize)> = Vec::new();
4479 for n in body_vjp.nodes() {
4480 if let Op::Input { name } = &n.op {
4481 let off = body_offsets[&n.id];
4482 if name == "d_output" {
4483 body_d_output_off = Some(off);
4484 } else {
4485 body_other_inputs.push((n.id, off));
4486 }
4487 }
4488 }
4489 body_other_inputs.sort_by_key(|(id, _)| *id);
4490 let body_d_output_off =
4491 body_d_output_off.expect("ScanBackward body_vjp missing 'd_output' Input");
4492 let expected_others = 1 + *num_xs as usize;
4493 assert_eq!(
4494 body_other_inputs.len(),
4495 expected_others,
4496 "ScanBackward body_vjp has {} non-d_output Inputs; \
4497 expected {} (1 carry + {} xs)",
4498 body_other_inputs.len(),
4499 expected_others,
4500 num_xs
4501 );
4502 let body_carry_in_off = body_other_inputs[0].1;
4503 let body_x_offs: Vec<usize> = body_other_inputs
4504 .iter()
4505 .skip(1)
4506 .map(|(_, off)| *off)
4507 .collect();
4508 let body_dcarry_out_off = body_offsets[&body_vjp.outputs[0]];
4509
4510 let mut body_arena = crate::arena::Arena::from_plan(body_plan);
4511 for n in body_vjp.nodes() {
4513 if let Op::Constant { data } = &n.op
4514 && body_arena.has_buffer(n.id)
4515 && !data.is_empty()
4516 {
4517 match n.shape.dtype() {
4518 rlx_ir::DType::F64 => {
4519 let off = body_arena.byte_offset(n.id);
4520 let buf = body_arena.raw_buf_mut();
4521 let nb = (buf.len() - off).min(data.len());
4522 buf[off..off + nb].copy_from_slice(&data[..nb]);
4523 }
4524 _ => {
4525 let buf = body_arena.slice_mut(n.id);
4526 let nf = data.len() / 4;
4527 let nl = buf.len().min(nf);
4528 for i in 0..nl {
4529 let bytes = [
4530 data[i * 4],
4531 data[i * 4 + 1],
4532 data[i * 4 + 2],
4533 data[i * 4 + 3],
4534 ];
4535 buf[i] = f32::from_le_bytes(bytes);
4536 }
4537 }
4538 }
4539 }
4540 }
4541 let body_init = body_arena.raw_buf().to_vec();
4542 let body_schedule = compile_thunks(body_vjp, &body_arena);
4543
4544 let carry_bytes = body_vjp
4546 .node(body_vjp.outputs[0])
4547 .shape
4548 .size_bytes()
4549 .expect("ScanBackward dcarry must be statically shaped");
4550 let carry_elem_size = body_vjp
4551 .node(body_vjp.outputs[0])
4552 .shape
4553 .dtype()
4554 .size_bytes() as u32;
4555
4556 let mut outer_xs_offs: Vec<(usize, u32)> = Vec::with_capacity(*num_xs as usize);
4559 for i in 0..*num_xs as usize {
4560 let outer_xs_id = node.inputs[3 + i];
4561 let outer_xs_off = node_offset(arena, outer_xs_id);
4562 let outer_xs_shape = &graph.node(outer_xs_id).shape;
4563 let total = outer_xs_shape
4564 .size_bytes()
4565 .expect("ScanBackward xs must have static shape");
4566 let per_step = total / *length as usize;
4567 outer_xs_offs.push((outer_xs_off, per_step as u32));
4568 }
4569
4570 let (fb_schedule, fb_init, fb_carry_in_off, fb_output_off, fb_x_offs) =
4575 if is_recursive {
4576 let fb = forward_body.as_ref().unwrap();
4577 let fb_plan = rlx_opt::memory::plan_memory(fb);
4578 let fb_offsets: HashMap<NodeId, usize> = fb_plan
4579 .assignments
4580 .iter()
4581 .map(|(id, slot)| (*id, slot.offset))
4582 .collect();
4583 let mut fb_inputs: Vec<NodeId> = fb
4584 .nodes()
4585 .iter()
4586 .filter(|n| matches!(n.op, Op::Input { .. }))
4587 .map(|n| n.id)
4588 .collect();
4589 fb_inputs.sort();
4590 let fb_carry = fb_offsets[&fb_inputs[0]];
4591 let fb_xs: Vec<usize> = (1..fb_inputs.len())
4592 .map(|i| fb_offsets[&fb_inputs[i]])
4593 .collect();
4594 let fb_out = fb_offsets[&fb.outputs[0]];
4595 let mut fb_arena = crate::arena::Arena::from_plan(fb_plan);
4596 for n in fb.nodes() {
4597 if let Op::Constant { data } = &n.op
4598 && fb_arena.has_buffer(n.id)
4599 && !data.is_empty()
4600 {
4601 let off = fb_arena.byte_offset(n.id);
4608 let buf = fb_arena.raw_buf_mut();
4609 let nb = (buf.len() - off).min(data.len());
4610 buf[off..off + nb].copy_from_slice(&data[..nb]);
4611 }
4612 }
4613 let fb_init_bytes = fb_arena.raw_buf().to_vec();
4614 let fb_sched = compile_thunks(fb, &fb_arena);
4615 (
4616 Some(Arc::new(fb_sched)),
4617 Some(Arc::new(fb_init_bytes)),
4618 fb_carry,
4619 fb_out,
4620 fb_xs,
4621 )
4622 } else {
4623 (None, None, 0, 0, Vec::new())
4624 };
4625
4626 Thunk::ScanBackward {
4627 body_vjp: Arc::new(body_schedule),
4628 body_init: Arc::new(body_init),
4629 body_carry_in_off,
4630 body_x_offs: Arc::new(body_x_offs),
4631 body_d_output_off,
4632 body_dcarry_out_off,
4633 outer_init_off: node_offset(arena, node.inputs[0]),
4634 outer_traj_off: node_offset(arena, node.inputs[1]),
4635 outer_upstream_off: node_offset(arena, node.inputs[2]),
4636 outer_xs_offs: Arc::new(outer_xs_offs),
4637 outer_dinit_off: node_offset(arena, node.id),
4638 length: *length,
4639 carry_bytes: carry_bytes as u32,
4640 carry_elem_size,
4641 save_trajectory: *save_trajectory,
4642 num_checkpoints: *num_checkpoints,
4643 forward_body: fb_schedule,
4644 forward_body_init: fb_init,
4645 forward_body_carry_in_off: fb_carry_in_off,
4646 forward_body_output_off: fb_output_off,
4647 forward_body_x_offs: Arc::new(fb_x_offs),
4648 }
4649 }
4650
4651 Op::ScanBackwardXs {
4652 body_vjp,
4653 length,
4654 save_trajectory,
4655 num_xs,
4656 xs_idx,
4657 num_checkpoints,
4658 forward_body,
4659 } => {
4660 assert!(
4661 *num_checkpoints == 0 || *num_checkpoints <= *length,
4662 "Op::ScanBackwardXs: num_checkpoints={} must be 0 or ≤ length={}",
4663 *num_checkpoints,
4664 *length
4665 );
4666 let is_recursive = *num_checkpoints != 0 && *num_checkpoints != *length;
4667 if is_recursive {
4668 assert!(
4669 forward_body.is_some(),
4670 "Op::ScanBackwardXs with num_checkpoints<length \
4671 requires forward_body"
4672 );
4673 }
4674 let body_plan = rlx_opt::memory::plan_memory(body_vjp);
4682 let body_offsets: HashMap<NodeId, usize> = body_plan
4683 .assignments
4684 .iter()
4685 .map(|(id, slot)| (*id, slot.offset))
4686 .collect();
4687 let mut body_d_output_off: Option<usize> = None;
4688 let mut body_other_inputs: Vec<(NodeId, usize)> = Vec::new();
4689 for n in body_vjp.nodes() {
4690 if let Op::Input { name } = &n.op {
4691 let off = body_offsets[&n.id];
4692 if name == "d_output" {
4693 body_d_output_off = Some(off);
4694 } else {
4695 body_other_inputs.push((n.id, off));
4696 }
4697 }
4698 }
4699 body_other_inputs.sort_by_key(|(id, _)| *id);
4700 let body_d_output_off =
4701 body_d_output_off.expect("ScanBackwardXs body_vjp missing 'd_output' Input");
4702 let expected_others = 1 + *num_xs as usize;
4703 assert_eq!(
4704 body_other_inputs.len(),
4705 expected_others,
4706 "ScanBackwardXs body_vjp has {} non-d_output Inputs; expected {}",
4707 body_other_inputs.len(),
4708 expected_others
4709 );
4710 let body_carry_in_off = body_other_inputs[0].1;
4711 let body_x_offs: Vec<usize> = body_other_inputs
4712 .iter()
4713 .skip(1)
4714 .map(|(_, off)| *off)
4715 .collect();
4716 let body_dcarry_out_off = body_offsets[&body_vjp.outputs[0]];
4717 let dxs_out_node = body_vjp.outputs[1 + *xs_idx as usize];
4718 let body_dxs_out_off = body_offsets[&dxs_out_node];
4719
4720 let mut body_arena = crate::arena::Arena::from_plan(body_plan);
4721 for n in body_vjp.nodes() {
4722 if let Op::Constant { data } = &n.op
4723 && body_arena.has_buffer(n.id)
4724 && !data.is_empty()
4725 {
4726 match n.shape.dtype() {
4727 rlx_ir::DType::F64 => {
4728 let off = body_arena.byte_offset(n.id);
4729 let buf = body_arena.raw_buf_mut();
4730 let nb = (buf.len() - off).min(data.len());
4731 buf[off..off + nb].copy_from_slice(&data[..nb]);
4732 }
4733 _ => {
4734 let buf = body_arena.slice_mut(n.id);
4735 let nf = data.len() / 4;
4736 let nl = buf.len().min(nf);
4737 for i in 0..nl {
4738 let bytes = [
4739 data[i * 4],
4740 data[i * 4 + 1],
4741 data[i * 4 + 2],
4742 data[i * 4 + 3],
4743 ];
4744 buf[i] = f32::from_le_bytes(bytes);
4745 }
4746 }
4747 }
4748 }
4749 }
4750 let body_init = body_arena.raw_buf().to_vec();
4751 let body_schedule = compile_thunks(body_vjp, &body_arena);
4752
4753 let carry_bytes = body_vjp
4754 .node(body_vjp.outputs[0])
4755 .shape
4756 .size_bytes()
4757 .expect("ScanBackwardXs dcarry must be statically shaped");
4758 let carry_elem_size = body_vjp
4759 .node(body_vjp.outputs[0])
4760 .shape
4761 .dtype()
4762 .size_bytes() as u32;
4763 let per_step_bytes = body_vjp
4764 .node(dxs_out_node)
4765 .shape
4766 .size_bytes()
4767 .expect("ScanBackwardXs dxs body output must be statically shaped");
4768
4769 let mut outer_xs_offs: Vec<(usize, u32)> = Vec::with_capacity(*num_xs as usize);
4770 for i in 0..*num_xs as usize {
4771 let outer_xs_id = node.inputs[3 + i];
4772 let outer_xs_off = node_offset(arena, outer_xs_id);
4773 let outer_xs_shape = &graph.node(outer_xs_id).shape;
4774 let total = outer_xs_shape
4775 .size_bytes()
4776 .expect("ScanBackwardXs xs must have static shape");
4777 let per_step = total / *length as usize;
4778 outer_xs_offs.push((outer_xs_off, per_step as u32));
4779 }
4780
4781 let (fb_schedule, fb_init, fb_carry_in_off, fb_output_off, fb_x_offs) =
4784 if is_recursive {
4785 let fb = forward_body.as_ref().unwrap();
4786 let fb_plan = rlx_opt::memory::plan_memory(fb);
4787 let fb_offsets: HashMap<NodeId, usize> = fb_plan
4788 .assignments
4789 .iter()
4790 .map(|(id, slot)| (*id, slot.offset))
4791 .collect();
4792 let mut fb_inputs: Vec<NodeId> = fb
4793 .nodes()
4794 .iter()
4795 .filter(|n| matches!(n.op, Op::Input { .. }))
4796 .map(|n| n.id)
4797 .collect();
4798 fb_inputs.sort();
4799 let fb_carry = fb_offsets[&fb_inputs[0]];
4800 let fb_xs: Vec<usize> = (1..fb_inputs.len())
4801 .map(|i| fb_offsets[&fb_inputs[i]])
4802 .collect();
4803 let fb_out = fb_offsets[&fb.outputs[0]];
4804 let mut fb_arena = crate::arena::Arena::from_plan(fb_plan);
4805 for n in fb.nodes() {
4806 if let Op::Constant { data } = &n.op
4807 && fb_arena.has_buffer(n.id)
4808 && !data.is_empty()
4809 {
4810 let off = fb_arena.byte_offset(n.id);
4817 let buf = fb_arena.raw_buf_mut();
4818 let nb = (buf.len() - off).min(data.len());
4819 buf[off..off + nb].copy_from_slice(&data[..nb]);
4820 }
4821 }
4822 let fb_init_bytes = fb_arena.raw_buf().to_vec();
4823 let fb_sched = compile_thunks(fb, &fb_arena);
4824 (
4825 Some(Arc::new(fb_sched)),
4826 Some(Arc::new(fb_init_bytes)),
4827 fb_carry,
4828 fb_out,
4829 fb_xs,
4830 )
4831 } else {
4832 (None, None, 0, 0, Vec::new())
4833 };
4834
4835 Thunk::ScanBackwardXs {
4836 body_vjp: Arc::new(body_schedule),
4837 body_init: Arc::new(body_init),
4838 body_carry_in_off,
4839 body_x_offs: Arc::new(body_x_offs),
4840 body_d_output_off,
4841 body_dcarry_out_off,
4842 body_dxs_out_off,
4843 outer_init_off: node_offset(arena, node.inputs[0]),
4844 outer_traj_off: node_offset(arena, node.inputs[1]),
4845 outer_upstream_off: node_offset(arena, node.inputs[2]),
4846 outer_xs_offs: Arc::new(outer_xs_offs),
4847 outer_dxs_off: node_offset(arena, node.id),
4848 length: *length,
4849 carry_bytes: carry_bytes as u32,
4850 carry_elem_size,
4851 per_step_bytes: per_step_bytes as u32,
4852 save_trajectory: *save_trajectory,
4853 num_checkpoints: *num_checkpoints,
4854 forward_body: fb_schedule,
4855 forward_body_init: fb_init,
4856 forward_body_carry_in_off: fb_carry_in_off,
4857 forward_body_output_off: fb_output_off,
4858 forward_body_x_offs: Arc::new(fb_x_offs),
4859 }
4860 }
4861
4862 Op::Concat { axis } => {
4863 let out_shape = &node.shape;
4867 let rank = out_shape.rank();
4868 let outer: usize = (0..*axis)
4869 .map(|i| out_shape.dim(i).unwrap_static())
4870 .product::<usize>()
4871 .max(1);
4872 let inner: usize = (*axis + 1..rank)
4873 .map(|i| out_shape.dim(i).unwrap_static())
4874 .product::<usize>()
4875 .max(1);
4876 let total_axis = out_shape.dim(*axis).unwrap_static();
4877 let inputs: Vec<(usize, u32)> = node
4878 .inputs
4879 .iter()
4880 .map(|&in_id| {
4881 let in_shape = &graph.node(in_id).shape;
4882 let in_axis = in_shape.dim(*axis).unwrap_static();
4883 (node_offset(arena, in_id), in_axis as u32)
4884 })
4885 .collect();
4886 let dst = node_offset(arena, node.id);
4887 match out_shape.dtype() {
4888 rlx_ir::DType::F64 => Thunk::ConcatF64 {
4889 dst,
4890 outer: outer as u32,
4891 inner: inner as u32,
4892 total_axis: total_axis as u32,
4893 inputs,
4894 },
4895 _ => Thunk::Concat {
4896 dst,
4897 outer: outer as u32,
4898 inner: inner as u32,
4899 total_axis: total_axis as u32,
4900 inputs,
4901 },
4902 }
4903 }
4904
4905 Op::GaussianSplatRender {
4906 width,
4907 height,
4908 tile_size,
4909 radius_scale,
4910 alpha_cutoff,
4911 max_splat_steps,
4912 transmittance_threshold,
4913 max_list_entries,
4914 } => {
4915 let elem_len =
4916 |id: NodeId| -> usize { graph.node(id).shape.num_elements().unwrap_or(0) };
4917 Thunk::GaussianSplatRender {
4918 positions_off: node_offset(arena, node.inputs[0]),
4919 positions_len: elem_len(node.inputs[0]),
4920 scales_off: node_offset(arena, node.inputs[1]),
4921 scales_len: elem_len(node.inputs[1]),
4922 rotations_off: node_offset(arena, node.inputs[2]),
4923 rotations_len: elem_len(node.inputs[2]),
4924 opacities_off: node_offset(arena, node.inputs[3]),
4925 opacities_len: elem_len(node.inputs[3]),
4926 colors_off: node_offset(arena, node.inputs[4]),
4927 colors_len: elem_len(node.inputs[4]),
4928 sh_coeffs_off: node_offset(arena, node.inputs[5]),
4929 sh_coeffs_len: elem_len(node.inputs[5]),
4930 meta_off: node_offset(arena, node.inputs[6]),
4931 dst_off: node_offset(arena, node.id),
4932 dst_len: node.shape.num_elements().unwrap_or(0),
4933 width: *width,
4934 height: *height,
4935 tile_size: *tile_size,
4936 radius_scale: *radius_scale,
4937 alpha_cutoff: *alpha_cutoff,
4938 max_splat_steps: *max_splat_steps,
4939 transmittance_threshold: *transmittance_threshold,
4940 max_list_entries: *max_list_entries,
4941 }
4942 }
4943
4944 Op::GaussianSplatRenderBackward {
4945 width,
4946 height,
4947 tile_size,
4948 radius_scale,
4949 alpha_cutoff,
4950 max_splat_steps,
4951 transmittance_threshold,
4952 max_list_entries,
4953 loss_grad_clip,
4954 sh_band,
4955 max_anisotropy,
4956 } => {
4957 let elem_len =
4958 |id: NodeId| -> usize { graph.node(id).shape.num_elements().unwrap_or(0) };
4959 Thunk::GaussianSplatRenderBackward {
4960 positions_off: node_offset(arena, node.inputs[0]),
4961 positions_len: elem_len(node.inputs[0]),
4962 scales_off: node_offset(arena, node.inputs[1]),
4963 scales_len: elem_len(node.inputs[1]),
4964 rotations_off: node_offset(arena, node.inputs[2]),
4965 rotations_len: elem_len(node.inputs[2]),
4966 opacities_off: node_offset(arena, node.inputs[3]),
4967 opacities_len: elem_len(node.inputs[3]),
4968 colors_off: node_offset(arena, node.inputs[4]),
4969 colors_len: elem_len(node.inputs[4]),
4970 sh_coeffs_off: node_offset(arena, node.inputs[5]),
4971 sh_coeffs_len: elem_len(node.inputs[5]),
4972 meta_off: node_offset(arena, node.inputs[6]),
4973 d_loss_off: node_offset(arena, node.inputs[7]),
4974 d_loss_len: elem_len(node.inputs[7]),
4975 packed_off: node_offset(arena, node.id),
4976 packed_len: node.shape.num_elements().unwrap_or(0),
4977 width: *width,
4978 height: *height,
4979 tile_size: *tile_size,
4980 radius_scale: *radius_scale,
4981 alpha_cutoff: *alpha_cutoff,
4982 max_splat_steps: *max_splat_steps,
4983 transmittance_threshold: *transmittance_threshold,
4984 max_list_entries: *max_list_entries,
4985 loss_grad_clip: *loss_grad_clip,
4986 sh_band: *sh_band,
4987 max_anisotropy: *max_anisotropy,
4988 }
4989 }
4990
4991 Op::GaussianSplatPrepare {
4992 width,
4993 height,
4994 tile_size,
4995 radius_scale,
4996 alpha_cutoff,
4997 max_splat_steps,
4998 transmittance_threshold,
4999 max_list_entries,
5000 } => {
5001 let elem_len =
5002 |id: NodeId| -> usize { graph.node(id).shape.num_elements().unwrap_or(0) };
5003 Thunk::GaussianSplatPrepare {
5004 positions_off: node_offset(arena, node.inputs[0]),
5005 positions_len: elem_len(node.inputs[0]),
5006 scales_off: node_offset(arena, node.inputs[1]),
5007 scales_len: elem_len(node.inputs[1]),
5008 rotations_off: node_offset(arena, node.inputs[2]),
5009 rotations_len: elem_len(node.inputs[2]),
5010 opacities_off: node_offset(arena, node.inputs[3]),
5011 opacities_len: elem_len(node.inputs[3]),
5012 colors_off: node_offset(arena, node.inputs[4]),
5013 colors_len: elem_len(node.inputs[4]),
5014 sh_coeffs_off: node_offset(arena, node.inputs[5]),
5015 sh_coeffs_len: elem_len(node.inputs[5]),
5016 meta_off: node_offset(arena, node.inputs[6]),
5017 meta_len: elem_len(node.inputs[6]),
5018 prep_off: node_offset(arena, node.id),
5019 prep_len: node.shape.num_elements().unwrap_or(0),
5020 width: *width,
5021 height: *height,
5022 tile_size: *tile_size,
5023 radius_scale: *radius_scale,
5024 alpha_cutoff: *alpha_cutoff,
5025 max_splat_steps: *max_splat_steps,
5026 transmittance_threshold: *transmittance_threshold,
5027 max_list_entries: *max_list_entries,
5028 }
5029 }
5030
5031 Op::GaussianSplatRasterize {
5032 width,
5033 height,
5034 tile_size,
5035 alpha_cutoff,
5036 max_splat_steps,
5037 transmittance_threshold,
5038 max_list_entries,
5039 } => {
5040 let elem_len =
5041 |id: NodeId| -> usize { graph.node(id).shape.num_elements().unwrap_or(0) };
5042 let prep_id = node.inputs[0];
5043 let count = match &graph.node(prep_id).op {
5044 rlx_ir::Op::GaussianSplatPrepare { .. } => {
5045 elem_len(graph.node(prep_id).inputs[0]) / 3
5046 }
5047 _ => 1,
5048 };
5049 Thunk::GaussianSplatRasterize {
5050 prep_off: node_offset(arena, prep_id),
5051 prep_len: elem_len(prep_id),
5052 meta_off: node_offset(arena, node.inputs[1]),
5053 meta_len: elem_len(node.inputs[1]),
5054 dst_off: node_offset(arena, node.id),
5055 dst_len: node.shape.num_elements().unwrap_or(0),
5056 count,
5057 width: *width,
5058 height: *height,
5059 tile_size: *tile_size,
5060 alpha_cutoff: *alpha_cutoff,
5061 max_splat_steps: *max_splat_steps,
5062 transmittance_threshold: *transmittance_threshold,
5063 max_list_entries: *max_list_entries,
5064 }
5065 }
5066
5067 Op::Custom { name, attrs, .. } => {
5068 let kernel = crate::op_registry::lookup_cpu_kernel(name).unwrap_or_else(|| {
5069 panic!(
5070 "compile_thunks: no CPU kernel registered for \
5071 Op::Custom('{name}'). Register one via \
5072 rlx_cpu::op_registry::register_cpu_kernel \
5073 before compiling on the CPU backend."
5074 )
5075 });
5076 let inputs_v: Vec<(usize, u32, Shape)> = node
5077 .inputs
5078 .iter()
5079 .map(|&in_id| {
5080 let s = graph.node(in_id).shape.clone();
5081 let len = s.num_elements().unwrap_or(0) as u32;
5082 (node_offset(arena, in_id), len, s)
5083 })
5084 .collect();
5085 let out_len = node.shape.num_elements().unwrap_or(0) as u32;
5086 Thunk::CustomOp {
5087 kernel,
5088 inputs: inputs_v,
5089 output: (node_offset(arena, node.id), out_len, node.shape.clone()),
5090 attrs: attrs.clone(),
5091 }
5092 }
5093
5094 Op::Fft { inverse } => {
5095 let shape = &node.shape;
5103 let last = shape.dim(shape.rank() - 1).unwrap_static();
5104 let n_complex = (last / 2) as u32;
5105 let total = shape.num_elements().unwrap_or(0);
5106 let outer = (total / last) as u32;
5107 let dtype = shape.dtype();
5108 assert!(
5109 matches!(dtype, rlx_ir::DType::F32 | rlx_ir::DType::F64),
5110 "Op::Fft on CPU requires F32 or F64, got {dtype:?}"
5111 );
5112 Thunk::Fft1d {
5113 src: node_offset(arena, node.inputs[0]),
5114 dst: node_offset(arena, node.id),
5115 outer,
5116 n_complex,
5117 inverse: *inverse,
5118 dtype,
5119 }
5120 }
5121
5122 Op::CustomFn {
5123 fwd_body,
5124 num_inputs,
5125 ..
5126 } => {
5127 let body_plan = rlx_opt::memory::plan_memory(fwd_body);
5133 let body_offsets: HashMap<NodeId, usize> = body_plan
5134 .assignments
5135 .iter()
5136 .map(|(id, slot)| (*id, slot.offset))
5137 .collect();
5138
5139 let mut body_input_ids: Vec<NodeId> = fwd_body
5140 .nodes()
5141 .iter()
5142 .filter(|n| matches!(n.op, Op::Input { .. }))
5143 .map(|n| n.id)
5144 .collect();
5145 body_input_ids.sort();
5146 assert_eq!(
5147 body_input_ids.len(),
5148 *num_inputs as usize,
5149 "Op::CustomFn fwd_body has {} Op::Input(s); declared num_inputs={}",
5150 body_input_ids.len(),
5151 *num_inputs,
5152 );
5153
5154 let mut body_arena = crate::arena::Arena::from_plan(body_plan);
5155 for n in fwd_body.nodes() {
5156 if let Op::Constant { data } = &n.op
5157 && body_arena.has_buffer(n.id)
5158 && !data.is_empty()
5159 {
5160 match n.shape.dtype() {
5161 rlx_ir::DType::F64 => {
5162 let off = body_arena.byte_offset(n.id);
5163 let buf = body_arena.raw_buf_mut();
5164 let nb = (buf.len() - off).min(data.len());
5165 buf[off..off + nb].copy_from_slice(&data[..nb]);
5166 }
5167 _ => {
5168 let buf = body_arena.slice_mut(n.id);
5169 let nf = data.len() / 4;
5170 let nl = buf.len().min(nf);
5171 for i in 0..nl {
5172 let bytes = [
5173 data[i * 4],
5174 data[i * 4 + 1],
5175 data[i * 4 + 2],
5176 data[i * 4 + 3],
5177 ];
5178 buf[i] = f32::from_le_bytes(bytes);
5179 }
5180 }
5181 }
5182 }
5183 }
5184 let body_init = body_arena.raw_buf().to_vec();
5185 let body_schedule = compile_thunks(fwd_body, &body_arena);
5186
5187 let inputs_v: Vec<(usize, usize, u32)> = (0..*num_inputs as usize)
5189 .map(|i| {
5190 let body_in = body_input_ids[i];
5191 let body_off = body_offsets[&body_in];
5192 let outer_in = node.inputs[i];
5193 let outer_off = node_offset(arena, outer_in);
5194 let bytes = graph
5195 .node(outer_in)
5196 .shape
5197 .size_bytes()
5198 .expect("Op::CustomFn primal input must have static shape");
5199 (body_off, outer_off, bytes as u32)
5200 })
5201 .collect();
5202
5203 let body_output_id = fwd_body
5204 .outputs
5205 .first()
5206 .copied()
5207 .expect("Op::CustomFn fwd_body must declare exactly one output");
5208 let body_output_off = body_offsets[&body_output_id];
5209 let out_bytes = node
5210 .shape
5211 .size_bytes()
5212 .expect("Op::CustomFn output must have static shape");
5213
5214 Thunk::CustomFn {
5215 body: Arc::new(body_schedule),
5216 body_init: Arc::new(body_init),
5217 inputs: Arc::new(inputs_v),
5218 body_output_off,
5219 outer_output_off: node_offset(arena, node.id),
5220 out_bytes: out_bytes as u32,
5221 }
5222 }
5223
5224 _ => Thunk::Nop,
5225 };
5226 thunks.push(t);
5227 }
5228
5229 let cfg = crate::config::RuntimeConfig::global();
5230 let mask_thr = cfg.mask_binary_threshold;
5231 let mask_neg = cfg.attn_mask_neg_inf;
5232 let score_skip = cfg.score_skip_threshold;
5233
5234 let compiled_fns: Vec<Arc<dyn Fn(*mut u8) + Send + Sync>> = thunks
5236 .iter()
5237 .filter(|t| !matches!(t, Thunk::Nop))
5238 .map(|thunk| {
5239 match thunk.clone() {
5240 Thunk::Nop => Arc::new(|_: *mut u8| {}) as Arc<dyn Fn(*mut u8) + Send + Sync>,
5241
5242 Thunk::Sgemm { a, b, c, m, k, n } => {
5243 let (m, k, n) = (m as usize, k as usize, n as usize);
5244 Arc::new(move |base: *mut u8| unsafe {
5245 crate::blas::sgemm(
5246 sl(a, base, m * k),
5247 sl(b, base, k * n),
5248 sl_mut(c, base, m * n),
5249 m,
5250 k,
5251 n,
5252 );
5253 })
5254 }
5255
5256 Thunk::DenseSolveF64 { a, b, x, n, nrhs } => {
5257 let (n_, nrhs_) = (n as usize, nrhs as usize);
5258 Arc::new(move |base: *mut u8| unsafe {
5259 let a_src = sl_f64(a, base, n_ * n_);
5260 let b_src = sl_f64(b, base, n_ * nrhs_);
5261 let mut a_scratch: Vec<f64> = a_src.to_vec();
5262 let mut x_buf: Vec<f64> = b_src.to_vec();
5263 let info = crate::blas::dgesv(&mut a_scratch, &mut x_buf, n_, nrhs_);
5264 if info != 0 {
5265 panic!("DenseSolveF64: singular (info={info})");
5266 }
5267 sl_mut_f64(x, base, n_ * nrhs_).copy_from_slice(&x_buf);
5268 })
5269 }
5270
5271 Thunk::DenseSolveF32 { a, b, x, n, nrhs } => {
5272 let (n_, nrhs_) = (n as usize, nrhs as usize);
5273 Arc::new(move |base: *mut u8| unsafe {
5274 let a_src = sl(a, base, n_ * n_);
5275 let b_src = sl(b, base, n_ * nrhs_);
5276 let mut a_scratch: Vec<f32> = a_src.to_vec();
5277 let mut x_buf: Vec<f32> = b_src.to_vec();
5278 let info = crate::blas::sgesv(&mut a_scratch, &mut x_buf, n_, nrhs_);
5279 if info != 0 {
5280 panic!("DenseSolveF32: singular (info={info})");
5281 }
5282 sl_mut(x, base, n_ * nrhs_).copy_from_slice(&x_buf);
5283 })
5284 }
5285
5286 Thunk::FusedMmBiasAct {
5287 a,
5288 w,
5289 bias,
5290 c,
5291 m,
5292 k,
5293 n,
5294 act,
5295 } => {
5296 let (m, k, n) = (m as usize, k as usize, n as usize);
5297 Arc::new(move |base: *mut u8| unsafe {
5298 let out = sl_mut(c, base, m * n);
5299 crate::blas::sgemm(sl(a, base, m * k), sl(w, base, k * n), out, m, k, n);
5300 match act {
5308 Some(Activation::Gelu) => {
5309 crate::kernels::par_bias_gelu(out, sl(bias, base, n), m, n)
5310 }
5311 Some(other) => {
5312 crate::blas::bias_add(out, sl(bias, base, n), m, n);
5313 apply_activation_inplace(out, other);
5314 }
5315 None => crate::blas::bias_add(out, sl(bias, base, n), m, n),
5316 }
5317 })
5318 }
5319
5320 Thunk::FusedResidualLN {
5321 x,
5322 res,
5323 bias,
5324 g,
5325 b,
5326 out,
5327 rows,
5328 h,
5329 eps,
5330 has_bias,
5331 } => {
5332 let (rows, h) = (rows as usize, h as usize);
5333 Arc::new(move |base: *mut u8| unsafe {
5334 let zero = vec![0f32; h]; let bi = if has_bias { sl(bias, base, h) } else { &zero };
5336 let xp = sl(x, base, rows * h).as_ptr() as usize;
5337 let rp = sl(res, base, rows * h).as_ptr() as usize;
5338 let op = sl_mut(out, base, rows * h).as_mut_ptr() as usize;
5339 let bp = bi.as_ptr() as usize;
5340 let gp = sl(g, base, h).as_ptr() as usize;
5341 let bbp = sl(b, base, h).as_ptr() as usize;
5342 crate::pool::par_for(rows, 4, &|off, cnt| {
5343 let xs = std::slice::from_raw_parts(
5344 (xp as *const f32).add(off * h),
5345 cnt * h,
5346 );
5347 let rs = std::slice::from_raw_parts(
5348 (rp as *const f32).add(off * h),
5349 cnt * h,
5350 );
5351 let os = std::slice::from_raw_parts_mut(
5352 (op as *mut f32).add(off * h),
5353 cnt * h,
5354 );
5355 let bi = std::slice::from_raw_parts(bp as *const f32, h);
5356 let g = std::slice::from_raw_parts(gp as *const f32, h);
5357 let b = std::slice::from_raw_parts(bbp as *const f32, h);
5358 crate::kernels::residual_bias_layer_norm(
5359 xs, rs, bi, g, b, os, cnt, h, eps,
5360 );
5361 });
5362 })
5363 }
5364
5365 Thunk::BiasAdd {
5366 src,
5367 bias,
5368 dst,
5369 m,
5370 n,
5371 } => {
5372 let (m, n) = (m as usize, n as usize);
5373 Arc::new(move |base: *mut u8| unsafe {
5374 let out = sl_mut(dst, base, m * n);
5375 out.copy_from_slice(sl(src, base, m * n));
5376 crate::blas::bias_add(out, sl(bias, base, n), m, n);
5377 })
5378 }
5379
5380 Thunk::Gather {
5381 table,
5382 table_len,
5383 idx,
5384 dst,
5385 num_idx,
5386 trailing,
5387 } => {
5388 let (ni, tr, tl) = (num_idx as usize, trailing as usize, table_len as usize);
5389 Arc::new(move |base: *mut u8| unsafe {
5390 let tab = sl(table, base, tl);
5391 let ids = sl(idx, base, ni);
5392 let out = sl_mut(dst, base, ni * tr);
5393 for i in 0..ni {
5394 let row = ids[i] as usize;
5395 out[i * tr..(i + 1) * tr]
5396 .copy_from_slice(&tab[row * tr..(row + 1) * tr]);
5397 }
5398 })
5399 }
5400
5401 Thunk::Narrow {
5402 src,
5403 dst,
5404 outer,
5405 src_stride,
5406 dst_stride,
5407 inner,
5408 elem_bytes,
5409 } => {
5410 narrow_thunk_closure(src, dst, outer, src_stride, dst_stride, inner, elem_bytes)
5411 }
5412
5413 Thunk::Copy { src, dst, len } => {
5414 let len = len as usize;
5415 Arc::new(move |base: *mut u8| unsafe {
5416 sl_mut(dst, base, len).copy_from_slice(sl(src, base, len));
5417 })
5418 }
5419
5420 Thunk::Softmax { data, rows, cols } => {
5421 let (rows, cols) = (rows as usize, cols as usize);
5422 Arc::new(move |base: *mut u8| unsafe {
5423 crate::naive::softmax(sl_mut(data, base, rows * cols), rows, cols);
5424 })
5425 }
5426
5427 Thunk::Cumsum {
5428 src,
5429 dst,
5430 rows,
5431 cols,
5432 exclusive,
5433 } => {
5434 let (rows, cols) = (rows as usize, cols as usize);
5435 Arc::new(move |base: *mut u8| unsafe {
5436 let s = sl(src, base, rows * cols);
5437 let d = sl_mut(dst, base, rows * cols);
5438 if exclusive {
5439 for r in 0..rows {
5440 let mut acc = 0.0f32;
5441 for c in 0..cols {
5442 d[r * cols + c] = acc;
5443 acc += s[r * cols + c];
5444 }
5445 }
5446 } else {
5447 for r in 0..rows {
5448 let mut acc = 0.0f32;
5449 for c in 0..cols {
5450 acc += s[r * cols + c];
5451 d[r * cols + c] = acc;
5452 }
5453 }
5454 }
5455 })
5456 }
5457
5458 Thunk::Sample {
5459 logits,
5460 dst,
5461 batch,
5462 vocab,
5463 top_k,
5464 top_p,
5465 temperature,
5466 seed,
5467 } => {
5468 let (b, v) = (batch as usize, vocab as usize);
5469 let k = (top_k as usize).min(v);
5470 Arc::new(move |base: *mut u8| unsafe {
5471 let lg = sl(logits, base, b * v);
5472 let out = sl_mut(dst, base, b);
5473 let mut rng =
5474 rlx_ir::Philox4x32::new(if seed == 0 { 0xDEADBEEF } else { seed });
5475 for bi in 0..b {
5476 let row = &lg[bi * v..(bi + 1) * v];
5477 out[bi] = sample_row(row, k, top_p, temperature, &mut rng) as f32;
5478 }
5479 })
5480 }
5481
5482 Thunk::DequantMatMul {
5483 x,
5484 w_q,
5485 scale,
5486 zp,
5487 dst,
5488 m,
5489 k,
5490 n,
5491 block_size,
5492 is_asymmetric,
5493 } => {
5494 let (m, k, n, bs) = (m as usize, k as usize, n as usize, block_size as usize);
5495 let n_blocks_per_col = k.div_ceil(bs);
5496 Arc::new(move |base: *mut u8| unsafe {
5497 let xs = sl(x, base, m * k);
5498 let raw = base.add(w_q);
5500 let w_bytes = std::slice::from_raw_parts(raw as *const i8, k * n);
5501 let scales = sl(scale, base, n_blocks_per_col * n);
5502 let zps = if is_asymmetric {
5503 sl(zp, base, n_blocks_per_col * n)
5504 } else {
5505 &[][..]
5506 };
5507 let out = sl_mut(dst, base, m * n);
5508 dequant_matmul_int8(
5509 xs,
5510 w_bytes,
5511 scales,
5512 zps,
5513 out,
5514 m,
5515 k,
5516 n,
5517 bs,
5518 is_asymmetric,
5519 );
5520 })
5521 }
5522
5523 Thunk::DequantMatMulGguf {
5524 x,
5525 w_q,
5526 dst,
5527 m,
5528 k,
5529 n,
5530 scheme,
5531 } => {
5532 let (m, k, n) = (m as usize, k as usize, n as usize);
5533 let block_bytes = scheme.gguf_block_bytes() as usize;
5534 let block_elems = scheme.gguf_block_size() as usize;
5535 let total_bytes = (k * n) / block_elems * block_bytes;
5536 Arc::new(move |base: *mut u8| unsafe {
5537 let xs = sl(x, base, m * k);
5538 let w_bytes =
5539 std::slice::from_raw_parts(base.add(w_q) as *const u8, total_bytes);
5540 let out = sl_mut(dst, base, m * n);
5541 crate::gguf_matmul::gguf_matmul_bt(xs, w_bytes, out, m, k, n, scheme);
5542 })
5543 }
5544
5545 Thunk::DequantMatMulInt4 {
5546 x,
5547 w_q,
5548 scale,
5549 zp,
5550 dst,
5551 m,
5552 k,
5553 n,
5554 block_size,
5555 is_asymmetric,
5556 } => {
5557 let (m, k, n, bs) = (m as usize, k as usize, n as usize, block_size as usize);
5558 let n_blocks = k.div_ceil(bs);
5559 Arc::new(move |base: *mut u8| unsafe {
5560 let xs = sl(x, base, m * k);
5561 let w_bytes = std::slice::from_raw_parts(
5562 base.add(w_q) as *const u8,
5563 (k * n).div_ceil(2),
5564 );
5565 let scales = sl(scale, base, n_blocks * n);
5566 let zps = if is_asymmetric {
5567 sl(zp, base, n_blocks * n)
5568 } else {
5569 &[][..]
5570 };
5571 let out = sl_mut(dst, base, m * n);
5572 dequant_matmul_int4(
5573 xs,
5574 w_bytes,
5575 scales,
5576 zps,
5577 out,
5578 m,
5579 k,
5580 n,
5581 bs,
5582 is_asymmetric,
5583 );
5584 })
5585 }
5586
5587 Thunk::DequantMatMulFp8 {
5588 x,
5589 w_q,
5590 scale,
5591 dst,
5592 m,
5593 k,
5594 n,
5595 e5m2,
5596 } => {
5597 let (m, k, n) = (m as usize, k as usize, n as usize);
5598 Arc::new(move |base: *mut u8| unsafe {
5599 let xs = sl(x, base, m * k);
5600 let w_bytes = std::slice::from_raw_parts(base.add(w_q) as *const u8, k * n);
5601 let scales = sl(scale, base, n);
5602 let out = sl_mut(dst, base, m * n);
5603 dequant_matmul_fp8(xs, w_bytes, scales, out, m, k, n, e5m2);
5604 })
5605 }
5606
5607 Thunk::DequantMatMulNvfp4 {
5608 x,
5609 w_q,
5610 scale,
5611 global_scale,
5612 dst,
5613 m,
5614 k,
5615 n,
5616 } => {
5617 let (m, k, n) = (m as usize, k as usize, n as usize);
5618 let n_scale = k.div_ceil(rlx_ir::NVFP4_GROUP_SIZE) * n;
5619 Arc::new(move |base: *mut u8| unsafe {
5620 let xs = sl(x, base, m * k);
5621 let w_bytes = std::slice::from_raw_parts(
5622 base.add(w_q) as *const u8,
5623 (k * n).div_ceil(2),
5624 );
5625 let scale_bytes =
5626 std::slice::from_raw_parts(base.add(scale) as *const u8, n_scale);
5627 let gs = sl(global_scale, base, 1)[0];
5628 let out = sl_mut(dst, base, m * n);
5629 dequant_matmul_nvfp4(xs, w_bytes, scale_bytes, gs, out, m, k, n);
5630 })
5631 }
5632
5633 Thunk::LoraMatMul {
5634 x,
5635 w,
5636 a,
5637 b,
5638 dst,
5639 m,
5640 k,
5641 n,
5642 r,
5643 scale,
5644 } => {
5645 let (m, k, n, r) = (m as usize, k as usize, n as usize, r as usize);
5646 Arc::new(move |base: *mut u8| unsafe {
5647 let xs = sl(x, base, m * k);
5648 let ws = sl(w, base, k * n);
5649 let a_s = sl(a, base, k * r);
5650 let bs = sl(b, base, r * n);
5651 let out = sl_mut(dst, base, m * n);
5652 crate::blas::sgemm(xs, ws, out, m, k, n);
5654 let mut tmp = vec![0f32; m * r];
5656 crate::blas::sgemm(xs, a_s, &mut tmp, m, k, r);
5657 if scale != 1.0 {
5661 for v in tmp.iter_mut() {
5662 *v *= scale;
5663 }
5664 }
5665 crate::blas::sgemm_accumulate(&tmp, bs, out, m, r, n);
5666 })
5667 }
5668
5669 Thunk::LayerNorm {
5670 src,
5671 g,
5672 b,
5673 dst,
5674 rows,
5675 h,
5676 eps,
5677 } => {
5678 let (rows, h) = (rows as usize, h as usize);
5679 Arc::new(move |base: *mut u8| unsafe {
5680 let inp = sl(src, base, rows * h);
5681 let gamma = sl(g, base, h);
5682 let beta = sl(b, base, h);
5683 let out = sl_mut(dst, base, rows * h);
5684 for row in 0..rows {
5685 crate::kernels::layer_norm_row(
5686 &inp[row * h..(row + 1) * h],
5687 gamma,
5688 beta,
5689 &mut out[row * h..(row + 1) * h],
5690 h,
5691 eps,
5692 );
5693 }
5694 })
5695 }
5696
5697 Thunk::Attention {
5698 q,
5699 k,
5700 v,
5701 mask,
5702 out,
5703 batch,
5704 seq,
5705 kv_seq: _,
5706 heads,
5707 head_dim,
5708 mask_kind,
5709 q_row_stride,
5710 k_row_stride,
5711 v_row_stride,
5712 bhsd,
5713 } => {
5714 let (b, s, nh, dh) = (
5715 batch as usize,
5716 seq as usize,
5717 heads as usize,
5718 head_dim as usize,
5719 );
5720 let hs = nh * dh;
5721 let qrs = q_row_stride as usize;
5722 let krs = k_row_stride as usize;
5723 let vrs = v_row_stride as usize;
5724 let scale = (dh as f32).powf(-0.5);
5725 Arc::new(move |base: *mut u8| unsafe {
5726 let (q_len, k_len, v_len, o_len) = if bhsd {
5731 let n = b * nh * s * dh;
5732 (n, n, n, n)
5733 } else {
5734 (b * s * qrs, b * s * krs, b * s * vrs, b * s * hs)
5735 };
5736 let q_d = sl(q, base, q_len);
5737 let k_d = sl(k, base, k_len);
5738 let v_d = sl(v, base, v_len);
5739 let m_d: &[f32] = match mask_kind {
5740 rlx_ir::op::MaskKind::Custom => sl(mask, base, b * s),
5741 rlx_ir::op::MaskKind::Bias => sl(mask, base, b * nh * s * s),
5742 _ => &[],
5743 };
5744 let o_d = sl_mut(out, base, o_len);
5745 let sdh = s * dh;
5746 let mut qh = vec![0f32; sdh];
5747 let mut kh = vec![0f32; sdh];
5748 let mut vh = vec![0f32; sdh];
5749 let mut sc = vec![0f32; s * s];
5750 let mut oh = vec![0f32; sdh];
5751 for bi in 0..b {
5752 for hi in 0..nh {
5753 for si in 0..s {
5754 let (q_off, k_off, v_off) = if bhsd {
5766 (
5767 bi * nh * s * dh + hi * s * dh + si * dh,
5768 bi * nh * s * dh + hi * s * dh + si * dh,
5769 bi * nh * s * dh + hi * s * dh + si * dh,
5770 )
5771 } else {
5772 (
5773 bi * s * qrs + si * qrs + hi * dh,
5774 bi * s * krs + si * krs + hi * dh,
5775 bi * s * vrs + si * vrs + hi * dh,
5776 )
5777 };
5778 qh[si * dh..(si + 1) * dh]
5779 .copy_from_slice(&q_d[q_off..q_off + dh]);
5780 kh[si * dh..(si + 1) * dh]
5781 .copy_from_slice(&k_d[k_off..k_off + dh]);
5782 vh[si * dh..(si + 1) * dh]
5783 .copy_from_slice(&v_d[v_off..v_off + dh]);
5784 }
5785 for qi in 0..s {
5786 for ki in 0..s {
5787 let mut dot = 0f32;
5788 for d in 0..dh {
5789 dot += qh[qi * dh + d] * kh[ki * dh + d];
5790 }
5791 sc[qi * s + ki] = dot * scale;
5792 }
5793 }
5794 match mask_kind {
5797 rlx_ir::op::MaskKind::None => {}
5798 rlx_ir::op::MaskKind::Causal => {
5799 for qi in 0..s {
5800 for ki in (qi + 1)..s {
5801 sc[qi * s + ki] = mask_neg;
5802 }
5803 }
5804 }
5805 rlx_ir::op::MaskKind::SlidingWindow(w) => {
5806 for qi in 0..s {
5807 let lo = qi.saturating_sub(w);
5808 for ki in 0..s {
5809 if ki < lo || ki > qi {
5810 sc[qi * s + ki] = mask_neg;
5811 }
5812 }
5813 }
5814 }
5815 rlx_ir::op::MaskKind::Custom => {
5816 for qi in 0..s {
5817 for ki in 0..s {
5818 if m_d[bi * s + ki] < mask_thr {
5819 sc[qi * s + ki] = mask_neg;
5820 }
5821 }
5822 }
5823 }
5824 rlx_ir::op::MaskKind::Bias => {
5825 let per_bh = s * s;
5826 let off = (bi * nh + hi) * per_bh;
5827 for i in 0..per_bh {
5828 sc[i] += m_d[off + i];
5829 }
5830 }
5831 }
5832 crate::naive::softmax(&mut sc, s, s);
5833 oh.fill(0.0);
5834 for qi in 0..s {
5835 for ki in 0..s {
5836 let w = sc[qi * s + ki];
5837 if w > score_skip {
5838 for d in 0..dh {
5839 oh[qi * dh + d] += w * vh[ki * dh + d];
5840 }
5841 }
5842 }
5843 }
5844 for si in 0..s {
5845 let off = if bhsd {
5846 bi * nh * s * dh + hi * s * dh + si * dh
5847 } else {
5848 bi * s * hs + si * hs + hi * dh
5849 };
5850 o_d[off..off + dh].copy_from_slice(&oh[si * dh..(si + 1) * dh]);
5851 }
5852 }
5853 }
5854 })
5855 }
5856
5857 Thunk::FusedSwiGLU {
5858 src,
5859 dst,
5860 n_half,
5861 total,
5862 gate_first,
5863 } => {
5864 let n = n_half as usize;
5865 let t = total as usize;
5866 let outer = t / n;
5867 let in_total = outer * 2 * n;
5868 Arc::new(move |base: *mut u8| unsafe {
5869 let inp = sl(src, base, in_total);
5870 let out = sl_mut(dst, base, t);
5871 for o in 0..outer {
5872 let in_row = &inp[o * 2 * n..(o + 1) * 2 * n];
5873 let out_row = &mut out[o * n..(o + 1) * n];
5874 for i in 0..n {
5875 let (up, gate) = if gate_first {
5876 (in_row[n + i], in_row[i])
5877 } else {
5878 (in_row[i], in_row[n + i])
5879 };
5880 out_row[i] = up * (gate / (1.0 + (-gate).exp()));
5881 }
5882 }
5883 })
5884 }
5885
5886 Thunk::Concat {
5887 dst,
5888 outer,
5889 inner,
5890 total_axis,
5891 inputs,
5892 } => {
5893 let outer = outer as usize;
5894 let inner = inner as usize;
5895 let total_axis = total_axis as usize;
5896 let out_total = outer * total_axis * inner;
5897 let mut layout: Vec<(usize, usize, usize)> = Vec::with_capacity(inputs.len());
5900 let mut cum: usize = 0;
5901 for (src_off, in_axis) in &inputs {
5902 let in_axis = *in_axis as usize;
5903 layout.push((*src_off, cum * inner, in_axis * inner));
5904 cum += in_axis;
5905 }
5906 Arc::new(move |base: *mut u8| unsafe {
5907 let out = sl_mut(dst, base, out_total);
5908 let row_stride = total_axis * inner;
5909 for (src_off, dst_col_off, copy_per_row) in &layout {
5910 let in_total = outer * *copy_per_row;
5911 let inp = sl(*src_off, base, in_total);
5912 for o in 0..outer {
5913 let dst_row_start = o * row_stride + *dst_col_off;
5914 let src_row_start = o * *copy_per_row;
5915 out[dst_row_start..dst_row_start + *copy_per_row].copy_from_slice(
5916 &inp[src_row_start..src_row_start + *copy_per_row],
5917 );
5918 }
5919 }
5920 })
5921 }
5922
5923 Thunk::CustomOp {
5924 kernel,
5925 inputs,
5926 output,
5927 attrs,
5928 } => {
5929 let kernel = kernel.clone();
5935 let attrs = attrs.clone();
5936 let inputs = inputs.clone();
5937 let (out_off, out_len, out_shape) = output.clone();
5938 Arc::new(move |base: *mut u8| unsafe {
5939 dispatch_custom_op(
5940 &*kernel, &inputs, out_off, out_len, &out_shape, &attrs, base,
5941 );
5942 })
5943 }
5944
5945 Thunk::GaussianSplatRender {
5946 positions_off,
5947 positions_len,
5948 scales_off,
5949 scales_len,
5950 rotations_off,
5951 rotations_len,
5952 opacities_off,
5953 opacities_len,
5954 colors_off,
5955 colors_len,
5956 sh_coeffs_off,
5957 sh_coeffs_len,
5958 meta_off,
5959 dst_off,
5960 dst_len,
5961 width,
5962 height,
5963 tile_size,
5964 radius_scale,
5965 alpha_cutoff,
5966 max_splat_steps,
5967 transmittance_threshold,
5968 max_list_entries,
5969 } => Arc::new(move |base: *mut u8| unsafe {
5970 crate::splat::execute_gaussian_splat_render(
5971 positions_off,
5972 positions_len,
5973 scales_off,
5974 scales_len,
5975 rotations_off,
5976 rotations_len,
5977 opacities_off,
5978 opacities_len,
5979 colors_off,
5980 colors_len,
5981 sh_coeffs_off,
5982 sh_coeffs_len,
5983 meta_off,
5984 dst_off,
5985 dst_len,
5986 width,
5987 height,
5988 tile_size,
5989 radius_scale,
5990 alpha_cutoff,
5991 max_splat_steps,
5992 transmittance_threshold,
5993 max_list_entries,
5994 base,
5995 );
5996 }),
5997
5998 Thunk::GaussianSplatRenderBackward {
5999 positions_off,
6000 positions_len,
6001 scales_off,
6002 scales_len,
6003 rotations_off,
6004 rotations_len,
6005 opacities_off,
6006 opacities_len,
6007 colors_off,
6008 colors_len,
6009 sh_coeffs_off,
6010 sh_coeffs_len,
6011 meta_off,
6012 d_loss_off,
6013 d_loss_len,
6014 packed_off,
6015 packed_len,
6016 width,
6017 height,
6018 tile_size,
6019 radius_scale,
6020 alpha_cutoff,
6021 max_splat_steps,
6022 transmittance_threshold,
6023 max_list_entries,
6024 loss_grad_clip,
6025 sh_band,
6026 max_anisotropy,
6027 } => Arc::new(move |base: *mut u8| unsafe {
6028 crate::splat::execute_gaussian_splat_render_backward(
6029 positions_off,
6030 positions_len,
6031 scales_off,
6032 scales_len,
6033 rotations_off,
6034 rotations_len,
6035 opacities_off,
6036 opacities_len,
6037 colors_off,
6038 colors_len,
6039 sh_coeffs_off,
6040 sh_coeffs_len,
6041 meta_off,
6042 d_loss_off,
6043 d_loss_len,
6044 packed_off,
6045 packed_len,
6046 width,
6047 height,
6048 tile_size,
6049 radius_scale,
6050 alpha_cutoff,
6051 max_splat_steps,
6052 transmittance_threshold,
6053 max_list_entries,
6054 loss_grad_clip,
6055 sh_band,
6056 max_anisotropy,
6057 base,
6058 );
6059 }),
6060
6061 Thunk::GaussianSplatPrepare {
6062 positions_off,
6063 positions_len,
6064 scales_off,
6065 scales_len,
6066 rotations_off,
6067 rotations_len,
6068 opacities_off,
6069 opacities_len,
6070 colors_off,
6071 colors_len,
6072 sh_coeffs_off,
6073 sh_coeffs_len,
6074 meta_off,
6075 meta_len,
6076 prep_off,
6077 prep_len,
6078 width,
6079 height,
6080 tile_size,
6081 radius_scale,
6082 alpha_cutoff,
6083 max_splat_steps,
6084 transmittance_threshold,
6085 max_list_entries,
6086 } => Arc::new(move |base: *mut u8| unsafe {
6087 crate::splat::execute_gaussian_splat_prepare(
6088 positions_off,
6089 positions_len,
6090 scales_off,
6091 scales_len,
6092 rotations_off,
6093 rotations_len,
6094 opacities_off,
6095 opacities_len,
6096 colors_off,
6097 colors_len,
6098 sh_coeffs_off,
6099 sh_coeffs_len,
6100 meta_off,
6101 meta_len,
6102 prep_off,
6103 prep_len,
6104 width,
6105 height,
6106 tile_size,
6107 radius_scale,
6108 alpha_cutoff,
6109 max_splat_steps,
6110 transmittance_threshold,
6111 max_list_entries,
6112 base,
6113 );
6114 }),
6115
6116 Thunk::GaussianSplatRasterize {
6117 prep_off,
6118 prep_len,
6119 meta_off,
6120 meta_len,
6121 dst_off,
6122 dst_len,
6123 count,
6124 width,
6125 height,
6126 tile_size,
6127 alpha_cutoff,
6128 max_splat_steps,
6129 transmittance_threshold,
6130 max_list_entries,
6131 } => Arc::new(move |base: *mut u8| unsafe {
6132 crate::splat::execute_gaussian_splat_rasterize(
6133 prep_off,
6134 prep_len,
6135 meta_off,
6136 meta_len,
6137 dst_off,
6138 dst_len,
6139 count,
6140 width,
6141 height,
6142 tile_size,
6143 alpha_cutoff,
6144 max_splat_steps,
6145 transmittance_threshold,
6146 max_list_entries,
6147 base,
6148 );
6149 }),
6150
6151 Thunk::Fft1d {
6152 src,
6153 dst,
6154 outer,
6155 n_complex,
6156 inverse,
6157 dtype,
6158 } => {
6159 let f: Arc<dyn Fn(*mut u8) + Send + Sync> = match dtype {
6160 rlx_ir::DType::F64 => Arc::new(move |base: *mut u8| unsafe {
6161 execute_fft1d_f64(
6162 src,
6163 dst,
6164 outer as usize,
6165 n_complex as usize,
6166 inverse,
6167 base,
6168 );
6169 }),
6170 rlx_ir::DType::F32 => Arc::new(move |base: *mut u8| unsafe {
6171 execute_fft1d_f32(
6172 src,
6173 dst,
6174 outer as usize,
6175 n_complex as usize,
6176 inverse,
6177 base,
6178 );
6179 }),
6180 other => panic!("Op::Fft on CPU requires F32/F64, got {other:?}"),
6181 };
6182 f
6183 }
6184
6185 _ => Arc::new(|_: *mut u8| {}),
6186 }
6187 })
6188 .collect();
6189
6190 let fuse_threshold: usize = rlx_ir::env::var("RLX_FUSE_ATTN_THRESHOLD")
6194 .and_then(|v| v.parse().ok())
6195 .unwrap_or(64);
6196 let should_fuse = thunks.iter().any(|t| match t {
6197 Thunk::Attention { batch, seq, .. } => {
6198 (*batch as usize) * (*seq as usize) <= fuse_threshold
6199 }
6200 _ => false,
6201 });
6202
6203 if should_fuse {
6204 let active: Vec<usize> = thunks
6206 .iter()
6207 .enumerate()
6208 .filter(|(_, t)| !matches!(t, Thunk::Nop))
6209 .map(|(i, _)| i)
6210 .collect();
6211
6212 let mut kill = vec![false; thunks.len()]; let mut insertions: Vec<(usize, Thunk)> = Vec::new(); let mut ai = 0;
6216 while ai < active.len() {
6217 let a = |off: usize| -> Option<(usize, &Thunk)> {
6219 active.get(ai + off).map(|&idx| (idx, &thunks[idx]))
6220 };
6221
6222 let matched = (|| {
6224 let (_i0, t0) = a(0)?;
6225 let (_, t1) = a(1)?;
6226 let (_, t2) = a(2)?;
6227 let (_, t3) = a(3)?;
6228
6229 let (hidden, qkv_w, qkv_b, has_b) = match t0 {
6231 Thunk::FusedMmBiasAct {
6232 a,
6233 w,
6234 bias,
6235 n: _,
6236 act: None,
6237 ..
6238 } => (*a, *w, *bias, true),
6239 Thunk::Sgemm { a, b, n: _, .. } => (*a, *b, 0, false),
6240 _ => return None,
6241 };
6242
6243 if !matches!(t1, Thunk::Narrow { .. }) {
6245 return None;
6246 }
6247 if !matches!(t2, Thunk::Narrow { .. }) {
6248 return None;
6249 }
6250 if !matches!(t3, Thunk::Narrow { .. }) {
6251 return None;
6252 }
6253
6254 let (has_rope, attn_ai, cos_off, sin_off, cl) = if let Some((
6256 _,
6257 Thunk::Rope {
6258 cos, sin, cos_len, ..
6259 },
6260 )) = a(4)
6261 {
6262 if matches!(a(5).map(|x| x.1), Some(Thunk::Rope { .. })) {
6263 if matches!(a(6).map(|x| x.1), Some(Thunk::Attention { .. })) {
6264 (true, 6, *cos, *sin, *cos_len)
6265 } else {
6266 return None;
6267 }
6268 } else {
6269 return None;
6270 }
6271 } else if matches!(a(4).map(|x| x.1), Some(Thunk::Attention { .. })) {
6272 (false, 4, 0, 0, 0)
6273 } else {
6274 return None;
6275 };
6276
6277 let (_attn_real_idx, attn_t) = a(attn_ai)?;
6278 let (batch, seq, heads, head_dim, mask) = match attn_t {
6279 Thunk::Attention {
6280 batch,
6281 seq,
6282 heads,
6283 head_dim,
6284 mask,
6285 ..
6286 } => (*batch, *seq, *heads, *head_dim, *mask),
6287 _ => return None,
6288 };
6289
6290 let (_out_real_idx, out_t) = a(attn_ai + 1)?;
6292 let (out_w, out_b, out_dst) = match out_t {
6293 Thunk::FusedMmBiasAct {
6294 w,
6295 bias,
6296 c,
6297 act: None,
6298 ..
6299 } => (*w, *bias, *c),
6300 Thunk::Sgemm { b: w, c, .. } => (*w, 0, *c),
6301 _ => return None,
6302 };
6303
6304 let hs = heads * head_dim;
6305 let total_active = attn_ai + 2; Some((
6308 total_active,
6309 Thunk::FusedAttnBlock {
6310 hidden,
6311 qkv_w,
6312 out_w,
6313 mask,
6314 out: out_dst,
6315 qkv_b: if has_b { qkv_b } else { 0 },
6316 out_b: if has_b { out_b } else { 0 },
6317 cos: cos_off,
6318 sin: sin_off,
6319 cos_len: cl,
6320 batch,
6321 seq,
6322 hs,
6323 nh: heads,
6324 dh: head_dim,
6325 has_bias: has_b,
6326 has_rope,
6327 },
6328 ))
6329 })();
6330
6331 if let Some((count, fused_thunk)) = matched {
6332 for off in 0..count {
6334 if let Some(&idx) = active.get(ai + off) {
6335 kill[idx] = true;
6336 }
6337 }
6338 insertions.push((active[ai], fused_thunk));
6340 ai += count;
6341 } else {
6342 ai += 1;
6343 }
6344 }
6345
6346 if !insertions.is_empty() {
6348 let mut new_thunks = Vec::with_capacity(thunks.len());
6349 let mut insert_idx = 0;
6350 for (i, t) in thunks.into_iter().enumerate() {
6351 if insert_idx < insertions.len() && insertions[insert_idx].0 == i {
6352 new_thunks.push(insertions[insert_idx].1.clone());
6353 insert_idx += 1;
6354 }
6355 if !kill[i] {
6356 new_thunks.push(t);
6357 }
6358 }
6359 if cfg.verbose >= 1 {
6360 eprintln!(
6361 "[rlx] fused_attention: {} attention blocks fused",
6362 insertions.len()
6363 );
6364 }
6365 thunks = new_thunks;
6366 }
6367 }
6368
6369 if should_fuse {
6374 let active: Vec<usize> = thunks
6375 .iter()
6376 .enumerate()
6377 .filter(|(_, t)| !matches!(t, Thunk::Nop))
6378 .map(|(i, _)| i)
6379 .collect();
6380
6381 let mut kill = vec![false; thunks.len()];
6382 let mut insertions: Vec<(usize, Thunk)> = Vec::new();
6383
6384 let a = |ai: usize| -> Option<&Thunk> { active.get(ai).map(|&i| &thunks[i]) };
6385
6386 let mut ai = 0;
6387 while ai < active.len() {
6388 let bert_match = (|| -> Option<usize> {
6390 let fab = a(ai)?;
6391 let rln1 = a(ai + 1)?;
6392 let ffn1 = a(ai + 2)?;
6393 let ffn2 = a(ai + 3)?;
6394 let rln2 = a(ai + 4)?;
6395
6396 let (hidden, qkv_w, qkv_b, out_w, out_b, mask, batch, seq, hs, nh, dh) = match fab {
6397 Thunk::FusedAttnBlock {
6398 hidden,
6399 qkv_w,
6400 qkv_b,
6401 out_w,
6402 out_b,
6403 mask,
6404 batch,
6405 seq,
6406 hs,
6407 nh,
6408 dh,
6409 has_bias: true,
6410 has_rope: false,
6411 ..
6412 } => (
6413 *hidden, *qkv_w, *qkv_b, *out_w, *out_b, *mask, *batch, *seq, *hs, *nh, *dh,
6414 ),
6415 _ => return None,
6416 };
6417 let (ln1_g, ln1_b, eps1) = match rln1 {
6418 Thunk::FusedResidualLN { g, b, eps, .. } => (*g, *b, *eps),
6419 _ => return None,
6420 };
6421 let (fc1_w, fc1_b, int_dim) = match ffn1 {
6422 Thunk::FusedMmBiasAct {
6423 w,
6424 bias,
6425 n,
6426 act: Some(Activation::Gelu),
6427 ..
6428 } => (*w, *bias, *n),
6429 _ => return None,
6430 };
6431 let (fc2_w, fc2_b) = match ffn2 {
6432 Thunk::FusedMmBiasAct {
6433 w, bias, act: None, ..
6434 } => (*w, *bias),
6435 _ => return None,
6436 };
6437 let (ln2_g, ln2_b, eps2, out) = match rln2 {
6438 Thunk::FusedResidualLN { g, b, eps, out, .. } => (*g, *b, *eps, *out),
6439 _ => return None,
6440 };
6441
6442 for off in 0..5 {
6443 kill[active[ai + off]] = true;
6444 }
6445 insertions.push((
6446 active[ai],
6447 Thunk::FusedBertLayer {
6448 hidden,
6449 qkv_w,
6450 qkv_b,
6451 out_w,
6452 out_b,
6453 mask,
6454 ln1_g,
6455 ln1_b,
6456 eps1,
6457 fc1_w,
6458 fc1_b,
6459 fc2_w,
6460 fc2_b,
6461 ln2_g,
6462 ln2_b,
6463 eps2,
6464 out,
6465 batch,
6466 seq,
6467 hs,
6468 nh,
6469 dh,
6470 int_dim,
6471 },
6472 ));
6473 Some(5)
6474 })();
6475 if let Some(n) = bert_match {
6476 ai += n;
6477 continue;
6478 }
6479
6480 #[allow(unreachable_code)]
6484 let nomic_match = (|| -> Option<usize> {
6485 return None; let fab = a(ai)?;
6487 let (hidden, qkv_w, out_w, mask, cos, sin, cos_len, batch, seq, hs, nh, dh) =
6488 match fab {
6489 Thunk::FusedAttnBlock {
6490 hidden,
6491 qkv_w,
6492 out_w,
6493 mask,
6494 cos,
6495 sin,
6496 cos_len,
6497 batch,
6498 seq,
6499 hs,
6500 nh,
6501 dh,
6502 has_bias: false,
6503 has_rope: true,
6504 ..
6505 } => (
6506 *hidden, *qkv_w, *out_w, *mask, *cos, *sin, *cos_len, *batch, *seq,
6507 *hs, *nh, *dh,
6508 ),
6509 _ => return None,
6510 };
6511 let (ln1_g, ln1_b, eps1) = match a(ai + 1)? {
6513 Thunk::FusedResidualLN { g, b, eps, .. } => (*g, *b, *eps),
6514 _ => return None,
6515 };
6516 let fused_fc_w = match a(ai + 2)? {
6518 Thunk::Sgemm { b: w, .. } => *w,
6519 _ => return None,
6520 };
6521 if !matches!(a(ai + 3)?, Thunk::Narrow { .. }) {
6523 return None;
6524 }
6525 if !matches!(a(ai + 4)?, Thunk::Narrow { .. }) {
6526 return None;
6527 }
6528 if !matches!(
6530 a(ai + 5)?,
6531 Thunk::ActivationInPlace {
6532 act: Activation::Silu,
6533 ..
6534 }
6535 ) {
6536 return None;
6537 }
6538 if !matches!(
6540 a(ai + 6)?,
6541 Thunk::BinaryFull {
6542 op: BinaryOp::Mul,
6543 ..
6544 }
6545 ) {
6546 return None;
6547 }
6548 let fc2_w = match a(ai + 7)? {
6550 Thunk::Sgemm { b: w, .. } => *w,
6551 _ => return None,
6552 };
6553 let int_dim = match a(ai + 3)? {
6555 Thunk::Narrow { inner, .. } => *inner,
6556 _ => return None,
6557 };
6558 let (ln2_g, ln2_b, eps2, out) = match a(ai + 8)? {
6560 Thunk::FusedResidualLN { g, b, eps, out, .. } => (*g, *b, *eps, *out),
6561 _ => return None,
6562 };
6563
6564 for off in 0..9 {
6565 kill[active[ai + off]] = true;
6566 }
6567 insertions.push((
6568 active[ai],
6569 Thunk::FusedNomicLayer {
6570 hidden,
6571 qkv_w,
6572 out_w,
6573 mask,
6574 cos,
6575 sin,
6576 cos_len,
6577 ln1_g,
6578 ln1_b,
6579 eps1,
6580 fc11_w: fused_fc_w,
6581 fc12_w: 0,
6582 fc2_w,
6583 ln2_g,
6584 ln2_b,
6585 eps2,
6586 out,
6587 batch,
6588 seq,
6589 hs,
6590 nh,
6591 dh,
6592 int_dim,
6593 },
6594 ));
6595 Some(9)
6596 })();
6597 if let Some(n) = nomic_match {
6598 ai += n;
6599 continue;
6600 }
6601
6602 ai += 1;
6603 }
6604
6605 if !insertions.is_empty() {
6606 let mut new_thunks = Vec::with_capacity(thunks.len());
6607 let mut ins_idx = 0;
6608 for (i, t) in thunks.into_iter().enumerate() {
6609 if ins_idx < insertions.len() && insertions[ins_idx].0 == i {
6610 new_thunks.push(insertions[ins_idx].1.clone());
6611 ins_idx += 1;
6612 }
6613 if !kill[i] {
6614 new_thunks.push(t);
6615 }
6616 }
6617 if cfg.verbose >= 1 {
6618 eprintln!(
6619 "[rlx] fused_layer: {} full transformer layers fused",
6620 insertions.len()
6621 );
6622 }
6623 thunks = new_thunks;
6624 }
6625 }
6626
6627 {
6639 let mut read_offsets: HashMap<usize, usize> = HashMap::new();
6642 for t in &thunks {
6643 for off in thunk_read_offsets(t) {
6644 *read_offsets.entry(off).or_insert(0) += 1;
6645 }
6646 }
6647
6648 let mut fused_count = 0usize;
6649 for i in 0..thunks.len().saturating_sub(1) {
6650 let narrow = match &thunks[i] {
6653 Thunk::Narrow { .. } => i,
6654 _ => continue,
6655 };
6656 let mut j = narrow + 1;
6658 while j < thunks.len() && matches!(thunks[j], Thunk::Nop) {
6659 j += 1;
6660 }
6661 if j >= thunks.len() {
6662 continue;
6663 }
6664 let (n_src, n_dst, n_src_stride) = match &thunks[narrow] {
6666 Thunk::Narrow {
6667 src,
6668 dst,
6669 src_stride,
6670 ..
6671 } => (*src, *dst, *src_stride),
6672 _ => continue,
6673 };
6674 let rope_reads_narrow = matches!(&thunks[j],
6675 Thunk::Rope { src, .. } if *src == n_dst);
6676 if !rope_reads_narrow {
6677 continue;
6678 }
6679 if read_offsets.get(&n_dst).copied().unwrap_or(0) != 1 {
6683 continue;
6684 }
6685
6686 if let Thunk::Rope {
6689 src,
6690 src_row_stride,
6691 ..
6692 } = &mut thunks[j]
6693 {
6694 *src = n_src;
6695 *src_row_stride = n_src_stride;
6696 }
6697 thunks[narrow] = Thunk::Nop;
6698 fused_count += 1;
6699 }
6700
6701 if fused_count > 0 && cfg.verbose >= 1 {
6702 eprintln!(
6703 "[rlx] fused_qk_rope: {} Narrow→Rope pairs collapsed",
6704 fused_count
6705 );
6706 }
6707 }
6708
6709 {
6721 let mut read_counts: HashMap<usize, usize> = HashMap::new();
6722 for t in &thunks {
6723 for off in thunk_read_offsets(t) {
6724 *read_counts.entry(off).or_insert(0) += 1;
6725 }
6726 }
6727 let mut dst_to_idx: HashMap<usize, usize> = HashMap::new();
6729 for (i, t) in thunks.iter().enumerate() {
6730 if let Thunk::Narrow { dst, .. } = t {
6731 dst_to_idx.insert(*dst, i);
6732 }
6733 }
6734
6735 let mut fused_count = 0usize;
6736 for i in 0..thunks.len() {
6737 let (q_off, k_off, v_off) = match &thunks[i] {
6738 Thunk::Attention { q, k, v, .. } => (*q, *k, *v),
6739 _ => continue,
6740 };
6741 let q_n = match dst_to_idx.get(&q_off).copied() {
6743 Some(x) => x,
6744 None => continue,
6745 };
6746 let k_n = match dst_to_idx.get(&k_off).copied() {
6747 Some(x) => x,
6748 None => continue,
6749 };
6750 let v_n = match dst_to_idx.get(&v_off).copied() {
6751 Some(x) => x,
6752 None => continue,
6753 };
6754 if read_counts.get(&q_off).copied().unwrap_or(0) != 1 {
6756 continue;
6757 }
6758 if read_counts.get(&k_off).copied().unwrap_or(0) != 1 {
6759 continue;
6760 }
6761 if read_counts.get(&v_off).copied().unwrap_or(0) != 1 {
6762 continue;
6763 }
6764
6765 let (q_src, q_stride) = match &thunks[q_n] {
6766 Thunk::Narrow {
6767 src, src_stride, ..
6768 } => (*src, *src_stride),
6769 _ => continue,
6770 };
6771 let (k_src, k_stride) = match &thunks[k_n] {
6772 Thunk::Narrow {
6773 src, src_stride, ..
6774 } => (*src, *src_stride),
6775 _ => continue,
6776 };
6777 let (v_src, v_stride) = match &thunks[v_n] {
6778 Thunk::Narrow {
6779 src, src_stride, ..
6780 } => (*src, *src_stride),
6781 _ => continue,
6782 };
6783
6784 if let Thunk::Attention {
6785 q,
6786 k,
6787 v,
6788 q_row_stride,
6789 k_row_stride,
6790 v_row_stride,
6791 ..
6792 } = &mut thunks[i]
6793 {
6794 *q = q_src;
6795 *k = k_src;
6796 *v = v_src;
6797 *q_row_stride = q_stride;
6798 *k_row_stride = k_stride;
6799 *v_row_stride = v_stride;
6800 }
6801 thunks[q_n] = Thunk::Nop;
6802 thunks[k_n] = Thunk::Nop;
6803 thunks[v_n] = Thunk::Nop;
6804 fused_count += 1;
6805 }
6806
6807 if fused_count > 0 && cfg.verbose >= 1 {
6808 eprintln!(
6809 "[rlx] fused_strided_attn: {} Narrow×3→Attention rewrites",
6810 fused_count
6811 );
6812 }
6813 }
6814
6815 ThunkSchedule {
6816 thunks,
6817 moe_resident: None,
6818 moe_resident_layers: None,
6819 moe_topk_capture: None,
6820 mask_threshold: cfg.mask_binary_threshold,
6821 mask_neg_inf: cfg.attn_mask_neg_inf,
6822 score_skip: cfg.score_skip_threshold,
6823 compiled_fns,
6824 }
6825}
6826
6827fn get_len(graph: &Graph, id: NodeId) -> usize {
6828 graph.node(id).shape.num_elements().unwrap_or(0)
6829}
6830
6831fn get_static_dims(graph: &Graph, id: NodeId) -> Vec<usize> {
6833 let dims = graph.node(id).shape.dims();
6834 let mut out = Vec::with_capacity(dims.len());
6835 for d in dims {
6836 if let Some(s) = match d {
6837 rlx_ir::Dim::Static(s) => Some(*s),
6838 _ => None,
6839 } {
6840 out.push(s);
6841 } else {
6842 return Vec::new();
6843 }
6844 }
6845 out
6846}
6847
6848fn is_trailing_bias_broadcast(rhs_dims: &[rlx_ir::Dim], out_dims: &[rlx_ir::Dim]) -> bool {
6866 if rhs_dims.len() > out_dims.len() {
6867 return false;
6868 }
6869 let off = out_dims.len() - rhs_dims.len();
6870 for i in 0..rhs_dims.len() {
6871 let r = match rhs_dims[i] {
6872 rlx_ir::Dim::Static(n) => n,
6873 _ => return false,
6874 };
6875 let o = match out_dims[off + i] {
6876 rlx_ir::Dim::Static(n) => n,
6877 _ => return false,
6878 };
6879 if r != o {
6880 return false;
6881 }
6882 }
6883 true
6884}
6885
/// Computes, for each output axis, the element stride to use when reading an
/// input of shape `in_dims` as if it were broadcast to `out_dims`.
///
/// `in_dims` is right-aligned against `out_dims`. Axes where the input has
/// extent 1 (including the implicit leading axes the input lacks) get
/// stride 0; every other axis gets the row-major stride of the input, i.e.
/// the product of the non-1 input extents to its right.
///
/// # Panics
///
/// - if `in_dims` has more axes than `out_dims`;
/// - if an input extent other than 1 differs from the matching output extent;
/// - if a stride does not fit in `u32` (previously this was silently
///   truncated by an `as` cast, which could turn a real stride into 0).
fn broadcast_strides(in_dims: &[usize], out_dims: &[usize]) -> Vec<u32> {
    let r_out = out_dims.len();
    let r_in = in_dims.len();
    assert!(
        r_in <= r_out,
        "broadcast: input rank {r_in} > output rank {r_out}"
    );
    let pad = r_out - r_in;
    // Padded leading axes behave like extent-1 axes and keep stride 0.
    let mut strides = vec![0u32; r_out];
    // Running product of the non-1 input extents seen so far (innermost first).
    let mut acc: usize = 1;
    for d in (pad..r_out).rev() {
        let in_size = in_dims[d - pad];
        if in_size == 1 {
            // Broadcast axis: stride stays 0 and the running product is unchanged.
            continue;
        }
        assert_eq!(
            in_size, out_dims[d],
            "broadcast: input dim {in_size} doesn't match output dim {} at axis {d}",
            out_dims[d]
        );
        // Guard the narrowing cast: a truncated stride would alias elements.
        assert!(
            acc <= u32::MAX as usize,
            "broadcast: stride {acc} at axis {d} does not fit in u32"
        );
        strides[d] = acc as u32;
        acc *= in_size;
    }
    strides
}
6912
6913pub fn execute_compiled(schedule: &ThunkSchedule, arena_buf: &mut [u8]) {
6917 let base = arena_buf.as_mut_ptr();
6918 for f in &schedule.compiled_fns {
6919 f(base);
6920 }
6921}
6922
6923pub fn execute_thunks_active(
6928 schedule: &ThunkSchedule,
6929 _arena_buf: &mut [u8],
6930 _actual: usize,
6931 _upper: usize,
6932) -> bool {
6933 let _ = schedule;
6934 false
6935}
6936
6937struct MoeResidencyGuard;
6939impl Drop for MoeResidencyGuard {
6940 fn drop(&mut self) {
6941 if let Some(stats) = crate::moe_residency::take_stats() {
6942 crate::moe_residency::stash_last_forward_stats(stats);
6943 } else {
6944 crate::moe_residency::clear_mask();
6945 }
6946 }
6947}
6948
6949pub fn execute_thunks(schedule: &ThunkSchedule, arena_buf: &mut [u8]) {
6950 crate::moe_residency::reset_gmm_counters();
6951 if let Some(layers) = schedule.moe_resident_layers.clone() {
6952 crate::moe_residency::set_per_layer_masks(Some(layers));
6953 } else {
6954 crate::moe_residency::set_mask(schedule.moe_resident.clone());
6955 }
6956 if let Some(cap) = schedule.moe_topk_capture.as_ref() {
6957 cap.clear();
6958 }
6959 let _moe_guard = MoeResidencyGuard;
6960 let base = arena_buf.as_mut_ptr();
6961 let mask_thr = schedule.mask_threshold;
6962 let mask_neg = schedule.mask_neg_inf;
6963 let score_thr = schedule.score_skip;
6964 let thunks = &schedule.thunks;
6965 let len = thunks.len();
6966
6967 let max_h = thunks
6969 .iter()
6970 .filter_map(|t| match t {
6971 Thunk::FusedResidualLN { h, .. }
6972 | Thunk::FusedResidualRmsNorm { h, .. }
6973 | Thunk::LayerNorm { h, .. } => Some(*h as usize),
6974 _ => None,
6975 })
6976 .max()
6977 .unwrap_or(0);
6978 let zero_bias = vec![0f32; max_h];
6979
6980 let max_sdpa = thunks
6983 .iter()
6984 .filter_map(|t| match t {
6985 Thunk::Attention {
6986 batch,
6987 seq,
6988 kv_seq,
6989 heads,
6990 head_dim,
6991 ..
6992 } => Some((
6993 *batch as usize,
6994 (*seq as usize).max(*kv_seq as usize),
6995 *heads as usize,
6996 *head_dim as usize,
6997 )),
6998 _ => None,
6999 })
7000 .fold((0, 0, 0, 0), |(mb, ms, mh, md), (b, s, h, d)| {
7001 (mb.max(b), ms.max(s), mh.max(h), md.max(d))
7002 });
7003 let (max_batch, max_seq, max_heads, _max_dh) = max_sdpa;
7004 let max_units = max_batch * max_heads;
7005 let mut sdpa_scores = vec![0f32; max_units * max_seq * max_seq];
7006
7007 let fl = thunks
7009 .iter()
7010 .filter_map(|t| match t {
7011 Thunk::FusedBertLayer {
7012 batch,
7013 seq,
7014 hs,
7015 int_dim,
7016 ..
7017 } => {
7018 let m = (*batch as usize) * (*seq as usize);
7019 let h = *hs as usize;
7020 let id = *int_dim as usize;
7021 Some((m, h, id, m * (*seq as usize)))
7022 }
7023 Thunk::FusedNomicLayer {
7024 batch,
7025 seq,
7026 hs,
7027 int_dim,
7028 ..
7029 } => {
7030 let m = (*batch as usize) * (*seq as usize);
7031 let h = *hs as usize;
7032 let id = *int_dim as usize;
7033 Some((m, h, id, m * (*seq as usize)))
7034 }
7035 _ => None,
7036 })
7037 .fold((0, 0, 0, 0), |(mm, mh, mi, ms), (m, h, id, ss)| {
7038 (mm.max(m), mh.max(h), mi.max(id), ms.max(ss))
7039 });
7040 let (fl_m, fl_h, fl_int, fl_ss) = fl;
7041 let mut fl_qkv = vec![0f32; fl_m * 3 * fl_h];
7042 let mut fl_attn = vec![0f32; fl_m * fl_h];
7043 let mut fl_res = vec![0f32; fl_m * fl_h];
7044 let mut fl_normed = vec![0f32; fl_m * fl_h];
7045 let mut fl_ffn = vec![0f32; fl_m * fl_int.max(2 * fl_int)]; let mut fl_sc = vec![0f32; fl_ss.max(1)];
7047
7048 for i in 0..len {
7049 let thunk = unsafe { thunks.get_unchecked(i) };
7050 match thunk {
7051 Thunk::Nop => {}
7052
7053 Thunk::GaussianSplatRender {
7054 positions_off,
7055 positions_len,
7056 scales_off,
7057 scales_len,
7058 rotations_off,
7059 rotations_len,
7060 opacities_off,
7061 opacities_len,
7062 colors_off,
7063 colors_len,
7064 sh_coeffs_off,
7065 sh_coeffs_len,
7066 meta_off,
7067 dst_off,
7068 dst_len,
7069 width,
7070 height,
7071 tile_size,
7072 radius_scale,
7073 alpha_cutoff,
7074 max_splat_steps,
7075 transmittance_threshold,
7076 max_list_entries,
7077 } => unsafe {
7078 crate::splat::execute_gaussian_splat_render(
7079 *positions_off,
7080 *positions_len,
7081 *scales_off,
7082 *scales_len,
7083 *rotations_off,
7084 *rotations_len,
7085 *opacities_off,
7086 *opacities_len,
7087 *colors_off,
7088 *colors_len,
7089 *sh_coeffs_off,
7090 *sh_coeffs_len,
7091 *meta_off,
7092 *dst_off,
7093 *dst_len,
7094 *width,
7095 *height,
7096 *tile_size,
7097 *radius_scale,
7098 *alpha_cutoff,
7099 *max_splat_steps,
7100 *transmittance_threshold,
7101 *max_list_entries,
7102 base,
7103 );
7104 },
7105
7106 Thunk::GaussianSplatRenderBackward {
7107 positions_off,
7108 positions_len,
7109 scales_off,
7110 scales_len,
7111 rotations_off,
7112 rotations_len,
7113 opacities_off,
7114 opacities_len,
7115 colors_off,
7116 colors_len,
7117 sh_coeffs_off,
7118 sh_coeffs_len,
7119 meta_off,
7120 d_loss_off,
7121 d_loss_len,
7122 packed_off,
7123 packed_len,
7124 width,
7125 height,
7126 tile_size,
7127 radius_scale,
7128 alpha_cutoff,
7129 max_splat_steps,
7130 transmittance_threshold,
7131 max_list_entries,
7132 loss_grad_clip,
7133 sh_band,
7134 max_anisotropy,
7135 } => unsafe {
7136 crate::splat::execute_gaussian_splat_render_backward(
7137 *positions_off,
7138 *positions_len,
7139 *scales_off,
7140 *scales_len,
7141 *rotations_off,
7142 *rotations_len,
7143 *opacities_off,
7144 *opacities_len,
7145 *colors_off,
7146 *colors_len,
7147 *sh_coeffs_off,
7148 *sh_coeffs_len,
7149 *meta_off,
7150 *d_loss_off,
7151 *d_loss_len,
7152 *packed_off,
7153 *packed_len,
7154 *width,
7155 *height,
7156 *tile_size,
7157 *radius_scale,
7158 *alpha_cutoff,
7159 *max_splat_steps,
7160 *transmittance_threshold,
7161 *max_list_entries,
7162 *loss_grad_clip,
7163 *sh_band,
7164 *max_anisotropy,
7165 base,
7166 );
7167 },
7168
7169 Thunk::GaussianSplatPrepare {
7170 positions_off,
7171 positions_len,
7172 scales_off,
7173 scales_len,
7174 rotations_off,
7175 rotations_len,
7176 opacities_off,
7177 opacities_len,
7178 colors_off,
7179 colors_len,
7180 sh_coeffs_off,
7181 sh_coeffs_len,
7182 meta_off,
7183 meta_len,
7184 prep_off,
7185 prep_len,
7186 width,
7187 height,
7188 tile_size,
7189 radius_scale,
7190 alpha_cutoff,
7191 max_splat_steps,
7192 transmittance_threshold,
7193 max_list_entries,
7194 } => unsafe {
7195 crate::splat::execute_gaussian_splat_prepare(
7196 *positions_off,
7197 *positions_len,
7198 *scales_off,
7199 *scales_len,
7200 *rotations_off,
7201 *rotations_len,
7202 *opacities_off,
7203 *opacities_len,
7204 *colors_off,
7205 *colors_len,
7206 *sh_coeffs_off,
7207 *sh_coeffs_len,
7208 *meta_off,
7209 *meta_len,
7210 *prep_off,
7211 *prep_len,
7212 *width,
7213 *height,
7214 *tile_size,
7215 *radius_scale,
7216 *alpha_cutoff,
7217 *max_splat_steps,
7218 *transmittance_threshold,
7219 *max_list_entries,
7220 base,
7221 );
7222 },
7223
7224 Thunk::GaussianSplatRasterize {
7225 prep_off,
7226 prep_len,
7227 meta_off,
7228 meta_len,
7229 dst_off,
7230 dst_len,
7231 count,
7232 width,
7233 height,
7234 tile_size,
7235 alpha_cutoff,
7236 max_splat_steps,
7237 transmittance_threshold,
7238 max_list_entries,
7239 } => unsafe {
7240 crate::splat::execute_gaussian_splat_rasterize(
7241 *prep_off,
7242 *prep_len,
7243 *meta_off,
7244 *meta_len,
7245 *dst_off,
7246 *dst_len,
7247 *count,
7248 *width,
7249 *height,
7250 *tile_size,
7251 *alpha_cutoff,
7252 *max_splat_steps,
7253 *transmittance_threshold,
7254 *max_list_entries,
7255 base,
7256 );
7257 },
7258
7259 Thunk::Fft1d {
7260 src,
7261 dst,
7262 outer,
7263 n_complex,
7264 inverse,
7265 dtype,
7266 } => unsafe {
7267 match dtype {
7268 rlx_ir::DType::F64 => execute_fft1d_f64(
7269 *src,
7270 *dst,
7271 *outer as usize,
7272 *n_complex as usize,
7273 *inverse,
7274 base,
7275 ),
7276 rlx_ir::DType::F32 => execute_fft1d_f32(
7277 *src,
7278 *dst,
7279 *outer as usize,
7280 *n_complex as usize,
7281 *inverse,
7282 base,
7283 ),
7284 other => panic!("Op::Fft on CPU requires F32/F64, got {other:?}"),
7285 }
7286 },
7287
7288 Thunk::CustomFn {
7292 body,
7293 body_init,
7294 inputs,
7295 body_output_off,
7296 outer_output_off,
7297 out_bytes,
7298 } => {
7299 let mut body_buf: Vec<u8> = (**body_init).clone();
7300 unsafe {
7301 for (body_in_off, outer_in_off, n_bytes) in inputs.iter() {
7302 let src = (base as *const u8).add(*outer_in_off);
7303 let dst = body_buf.as_mut_ptr().add(*body_in_off);
7304 std::ptr::copy_nonoverlapping(src, dst, *n_bytes as usize);
7305 }
7306 }
7307 execute_thunks(body, &mut body_buf);
7308 unsafe {
7309 let src = body_buf.as_ptr().add(*body_output_off);
7310 let dst = base.add(*outer_output_off);
7311 std::ptr::copy_nonoverlapping(src, dst, *out_bytes as usize);
7312 }
7313 }
7314
7315 Thunk::Sgemm { a, b, c, m, k, n } => {
7316 let (m, k, n) = (*m as usize, *k as usize, *n as usize);
7317 unsafe {
7318 crate::blas::sgemm_auto(
7319 sl(*a, base, m * k),
7320 sl(*b, base, k * n),
7321 sl_mut(*c, base, m * n),
7322 m,
7323 k,
7324 n,
7325 );
7326 }
7327 }
7328
7329 Thunk::DenseSolveF64 { a, b, x, n, nrhs } => {
7330 let (n_, nrhs_) = (*n as usize, *nrhs as usize);
7331 unsafe {
7337 let a_src = sl_f64(*a, base, n_ * n_);
7338 let b_src = sl_f64(*b, base, n_ * nrhs_);
7339 let mut a_scratch: Vec<f64> = a_src.to_vec();
7340 let mut x_buf: Vec<f64> = b_src.to_vec();
7341 let info = crate::blas::dgesv(&mut a_scratch, &mut x_buf, n_, nrhs_);
7342 if info != 0 {
7343 panic!(
7344 "DenseSolveF64: dgesv reported singular matrix \
7345 (info={info}, n={n_}, nrhs={nrhs_})"
7346 );
7347 }
7348 let dst = sl_mut_f64(*x, base, n_ * nrhs_);
7349 dst.copy_from_slice(&x_buf);
7350 }
7351 }
7352
7353 Thunk::DenseSolveF32 { a, b, x, n, nrhs } => {
7354 let (n_, nrhs_) = (*n as usize, *nrhs as usize);
7355 unsafe {
7356 let a_src = sl(*a, base, n_ * n_);
7357 let b_src = sl(*b, base, n_ * nrhs_);
7358 let mut a_scratch: Vec<f32> = a_src.to_vec();
7359 let mut x_buf: Vec<f32> = b_src.to_vec();
7360 let info = crate::blas::sgesv(&mut a_scratch, &mut x_buf, n_, nrhs_);
7361 if info != 0 {
7362 panic!(
7363 "DenseSolveF32: sgesv reported singular matrix \
7364 (info={info}, n={n_}, nrhs={nrhs_})"
7365 );
7366 }
7367 let dst = sl_mut(*x, base, n_ * nrhs_);
7368 dst.copy_from_slice(&x_buf);
7369 }
7370 }
7371
7372 Thunk::BatchedDenseSolveF64 {
7373 a,
7374 b,
7375 x,
7376 batch,
7377 n,
7378 nrhs,
7379 } => {
7380 let (b_, n_, nrhs_) = (*batch as usize, *n as usize, *nrhs as usize);
7387 let a_stride = n_ * n_;
7388 let b_stride = n_ * nrhs_;
7389 unsafe {
7390 let a_full = sl_f64(*a, base, b_ * a_stride);
7391 let b_full = sl_f64(*b, base, b_ * b_stride);
7392 let x_full = sl_mut_f64(*x, base, b_ * b_stride);
7393 for bi in 0..b_ {
7394 let mut a_scratch: Vec<f64> =
7395 a_full[bi * a_stride..(bi + 1) * a_stride].to_vec();
7396 let mut x_buf: Vec<f64> =
7397 b_full[bi * b_stride..(bi + 1) * b_stride].to_vec();
7398 let info = crate::blas::dgesv(&mut a_scratch, &mut x_buf, n_, nrhs_);
7399 if info != 0 {
7400 panic!(
7401 "BatchedDenseSolveF64: slice {bi} \
7402 singular (info={info}, n={n_}, nrhs={nrhs_})"
7403 );
7404 }
7405 x_full[bi * b_stride..(bi + 1) * b_stride].copy_from_slice(&x_buf);
7406 }
7407 }
7408 }
7409
7410 Thunk::BatchedDenseSolveF32 {
7411 a,
7412 b,
7413 x,
7414 batch,
7415 n,
7416 nrhs,
7417 } => {
7418 let (b_, n_, nrhs_) = (*batch as usize, *n as usize, *nrhs as usize);
7419 let a_stride = n_ * n_;
7420 let b_stride = n_ * nrhs_;
7421 unsafe {
7422 let a_full = sl(*a, base, b_ * a_stride);
7423 let b_full = sl(*b, base, b_ * b_stride);
7424 let x_full = sl_mut(*x, base, b_ * b_stride);
7425 for bi in 0..b_ {
7426 let mut a_scratch = a_full[bi * a_stride..(bi + 1) * a_stride].to_vec();
7427 let mut x_buf = b_full[bi * b_stride..(bi + 1) * b_stride].to_vec();
7428 let info = crate::blas::sgesv(&mut a_scratch, &mut x_buf, n_, nrhs_);
7429 if info != 0 {
7430 panic!("BatchedDenseSolveF32: slice {bi} singular (info={info})");
7431 }
7432 x_full[bi * b_stride..(bi + 1) * b_stride].copy_from_slice(&x_buf);
7433 }
7434 }
7435 }
7436
7437 Thunk::BatchedDgemmF64 {
7438 a,
7439 b,
7440 c,
7441 batch,
7442 m,
7443 k,
7444 n,
7445 } => {
7446 let (b_, m_, k_, n_) = (*batch as usize, *m as usize, *k as usize, *n as usize);
7447 let a_stride = m_ * k_;
7448 let b_stride = k_ * n_;
7449 let c_stride = m_ * n_;
7450 unsafe {
7451 let a_full = sl_f64(*a, base, b_ * a_stride);
7452 let b_full = sl_f64(*b, base, b_ * b_stride);
7453 let c_full = sl_mut_f64(*c, base, b_ * c_stride);
7454 for bi in 0..b_ {
7455 let a_slice = &a_full[bi * a_stride..(bi + 1) * a_stride];
7456 let b_slice = &b_full[bi * b_stride..(bi + 1) * b_stride];
7457 let c_slice = &mut c_full[bi * c_stride..(bi + 1) * c_stride];
7458 crate::blas::dgemm(a_slice, b_slice, c_slice, m_, k_, n_);
7459 }
7460 }
7461 }
7462
7463 Thunk::BatchedSgemm {
7464 a,
7465 b,
7466 c,
7467 batch,
7468 m,
7469 k,
7470 n,
7471 } => {
7472 let (b_, m_, k_, n_) = (*batch as usize, *m as usize, *k as usize, *n as usize);
7473 let a_stride = m_ * k_;
7474 let b_stride = k_ * n_;
7475 let c_stride = m_ * n_;
7476 unsafe {
7477 let a_full = sl(*a, base, b_ * a_stride);
7478 let b_full = sl(*b, base, b_ * b_stride);
7479 let c_full = sl_mut(*c, base, b_ * c_stride);
7480 for bi in 0..b_ {
7481 let a_slice = &a_full[bi * a_stride..(bi + 1) * a_stride];
7482 let b_slice = &b_full[bi * b_stride..(bi + 1) * b_stride];
7483 let c_slice = &mut c_full[bi * c_stride..(bi + 1) * c_stride];
7484 crate::blas::sgemm_auto(a_slice, b_slice, c_slice, m_, k_, n_);
7485 }
7486 }
7487 }
7488
7489 Thunk::Dgemm { a, b, c, m, k, n } => {
7490 let (m, k, n) = (*m as usize, *k as usize, *n as usize);
7491 unsafe {
7492 crate::blas::dgemm(
7493 sl_f64(*a, base, m * k),
7494 sl_f64(*b, base, k * n),
7495 sl_mut_f64(*c, base, m * n),
7496 m,
7497 k,
7498 n,
7499 );
7500 }
7501 }
7502
7503 Thunk::TransposeF64 {
7504 src,
7505 dst,
7506 in_total,
7507 out_dims,
7508 in_strides,
7509 } => unsafe {
7510 let inp = sl_f64(*src, base, *in_total as usize);
7511 let out_total: usize = out_dims.iter().map(|d| *d as usize).product();
7512 let out = sl_mut_f64(*dst, base, out_total);
7513 transpose_walk_f64(inp, out, out_dims, in_strides);
7514 },
7515
7516 Thunk::ActivationF64 {
7517 src,
7518 dst,
7519 len,
7520 kind,
7521 } => {
7522 let len = *len as usize;
7523 unsafe {
7524 let inp = sl_f64(*src, base, len);
7525 let out = sl_mut_f64(*dst, base, len);
7526 apply_activation_f64(inp, out, *kind);
7527 }
7528 }
7529
7530 Thunk::ReduceSumF64 {
7531 src,
7532 dst,
7533 outer,
7534 reduced,
7535 inner,
7536 } => {
7537 let (o, r, n) = (*outer as usize, *reduced as usize, *inner as usize);
7538 unsafe {
7539 let inp = sl_f64(*src, base, o * r * n);
7540 let out = sl_mut_f64(*dst, base, o * n);
7541 reduce_sum_f64(inp, out, o, r, n);
7542 }
7543 }
7544
7545 Thunk::CopyF64 { src, dst, len } => {
7546 let len = *len as usize;
7547 if *src == *dst { } else {
7549 unsafe {
7550 let s = sl_f64(*src, base, len);
7551 let d = sl_mut_f64(*dst, base, len);
7552 d.copy_from_slice(s);
7553 }
7554 }
7555 }
7556
7557 Thunk::BinaryFullF64 {
7558 lhs,
7559 rhs,
7560 dst,
7561 len,
7562 lhs_len,
7563 rhs_len,
7564 op,
7565 out_dims_bcast,
7566 bcast_lhs_strides,
7567 bcast_rhs_strides,
7568 } => {
7569 let len = *len as usize;
7570 let lhs_len = *lhs_len as usize;
7571 let rhs_len = *rhs_len as usize;
7572 unsafe {
7573 let l = sl_f64(*lhs, base, lhs_len);
7574 let r = sl_f64(*rhs, base, rhs_len);
7575 let d = sl_mut_f64(*dst, base, len);
7576 if lhs_len == len && rhs_len == len {
7577 for i in 0..len {
7578 d[i] = binary_op_f64(*op, l[i], r[i]);
7579 }
7580 } else if !out_dims_bcast.is_empty() {
7581 let rank = out_dims_bcast.len();
7585 let mut coords = vec![0u32; rank];
7586 for i in 0..len {
7587 let mut rem = i;
7588 for ax in (0..rank).rev() {
7589 let sz = out_dims_bcast[ax] as usize;
7590 coords[ax] = (rem % sz) as u32;
7591 rem /= sz;
7592 }
7593 let mut li: usize = 0;
7594 let mut ri: usize = 0;
7595 for ax in 0..rank {
7596 li += coords[ax] as usize * bcast_lhs_strides[ax] as usize;
7597 ri += coords[ax] as usize * bcast_rhs_strides[ax] as usize;
7598 }
7599 d[i] = binary_op_f64(*op, l[li], r[ri]);
7600 }
7601 } else {
7602 for i in 0..len {
7607 d[i] = binary_op_f64(*op, l[i % lhs_len], r[i % rhs_len]);
7608 }
7609 }
7610 }
7611 }
7612
7613 Thunk::BinaryFullC64 {
7614 lhs,
7615 rhs,
7616 dst,
7617 len,
7618 lhs_len,
7619 rhs_len,
7620 op,
7621 out_dims_bcast,
7622 bcast_lhs_strides,
7623 bcast_rhs_strides,
7624 } => {
7625 let n_out = *len as usize;
7631 let n_l = *lhs_len as usize;
7632 let n_r = *rhs_len as usize;
7633 unsafe {
7634 let l = sl(*lhs, base, 2 * n_l);
7635 let r = sl(*rhs, base, 2 * n_r);
7636 let d = sl_mut(*dst, base, 2 * n_out);
7637 let do_c64 = |a_re: f32, a_im: f32, b_re: f32, b_im: f32| -> (f32, f32) {
7638 match op {
7639 BinaryOp::Add => (a_re + b_re, a_im + b_im),
7640 BinaryOp::Sub => (a_re - b_re, a_im - b_im),
7641 BinaryOp::Mul => (a_re * b_re - a_im * b_im, a_re * b_im + a_im * b_re),
7642 BinaryOp::Div => {
7643 let denom = b_re * b_re + b_im * b_im;
7644 (
7645 (a_re * b_re + a_im * b_im) / denom,
7646 (a_im * b_re - a_re * b_im) / denom,
7647 )
7648 }
7649 BinaryOp::Max | BinaryOp::Min | BinaryOp::Pow => {
7650 unreachable!("C64 max/min/pow rejected at lowering")
7651 }
7652 }
7653 };
7654 if n_l == n_out && n_r == n_out {
7655 for i in 0..n_out {
7656 let (re, im) = do_c64(l[2 * i], l[2 * i + 1], r[2 * i], r[2 * i + 1]);
7657 d[2 * i] = re;
7658 d[2 * i + 1] = im;
7659 }
7660 } else if !out_dims_bcast.is_empty() {
7661 let rank = out_dims_bcast.len();
7665 let mut coords = vec![0u32; rank];
7666 for i in 0..n_out {
7667 let mut rem = i;
7668 for ax in (0..rank).rev() {
7669 let sz = out_dims_bcast[ax] as usize;
7670 coords[ax] = (rem % sz) as u32;
7671 rem /= sz;
7672 }
7673 let mut li: usize = 0;
7674 let mut ri: usize = 0;
7675 for ax in 0..rank {
7676 li += coords[ax] as usize * bcast_lhs_strides[ax] as usize;
7677 ri += coords[ax] as usize * bcast_rhs_strides[ax] as usize;
7678 }
7679 let (re, im) =
7680 do_c64(l[2 * li], l[2 * li + 1], r[2 * ri], r[2 * ri + 1]);
7681 d[2 * i] = re;
7682 d[2 * i + 1] = im;
7683 }
7684 } else {
7685 for i in 0..n_out {
7687 let li = if n_l == 1 { 0 } else { i % n_l };
7688 let ri = if n_r == 1 { 0 } else { i % n_r };
7689 let (re, im) =
7690 do_c64(l[2 * li], l[2 * li + 1], r[2 * ri], r[2 * ri + 1]);
7691 d[2 * i] = re;
7692 d[2 * i + 1] = im;
7693 }
7694 }
7695 }
7696 }
7697
7698 Thunk::ComplexNormSqF32 { src, dst, len } => {
7699 let n = *len as usize;
7700 unsafe {
7701 let s = sl(*src, base, 2 * n);
7702 let d = sl_mut(*dst, base, n);
7703 for i in 0..n {
7704 let re = s[2 * i];
7705 let im = s[2 * i + 1];
7706 d[i] = re * re + im * im;
7707 }
7708 }
7709 }
7710
7711 Thunk::ComplexNormSqBackwardF32 { z, g, dz, len } => {
7712 let n = *len as usize;
7715 unsafe {
7716 let zb = sl(*z, base, 2 * n);
7717 let gb = sl(*g, base, n);
7718 let db = sl_mut(*dz, base, 2 * n);
7719 for i in 0..n {
7720 let re = zb[2 * i];
7721 let im = zb[2 * i + 1];
7722 let gv = gb[i];
7723 db[2 * i] = gv * re;
7724 db[2 * i + 1] = gv * im;
7725 }
7726 }
7727 }
7728
7729 Thunk::ConjugateC64 { src, dst, len } => {
7730 let n = *len as usize;
7731 unsafe {
7732 let s = sl(*src, base, 2 * n);
7733 let d = sl_mut(*dst, base, 2 * n);
7734 for i in 0..n {
7735 d[2 * i] = s[2 * i];
7736 d[2 * i + 1] = -s[2 * i + 1];
7737 }
7738 }
7739 }
7740
7741 Thunk::ActivationC64 {
7742 src,
7743 dst,
7744 len,
7745 kind,
7746 } => {
7747 let n = *len as usize;
7748 unsafe {
7749 let s = sl(*src, base, 2 * n);
7750 let d = sl_mut(*dst, base, 2 * n);
7751 for i in 0..n {
7752 let a = s[2 * i];
7753 let b = s[2 * i + 1];
7754 let (re, im) = match kind {
7755 Activation::Neg => (-a, -b),
7756 Activation::Exp => {
7757 let ea = a.exp();
7759 (ea * b.cos(), ea * b.sin())
7760 }
7761 Activation::Log => {
7762 let r = (a * a + b * b).sqrt();
7764 (r.ln(), b.atan2(a))
7765 }
7766 Activation::Sqrt => {
7767 let r = (a * a + b * b).sqrt();
7770 let re = ((r + a) * 0.5).max(0.0).sqrt();
7771 let im_mag = ((r - a) * 0.5).max(0.0).sqrt();
7772 let im = if b >= 0.0 { im_mag } else { -im_mag };
7773 (re, im)
7774 }
7775 _ => unreachable!("non-C64 activation kind survived lowering"),
7776 };
7777 d[2 * i] = re;
7778 d[2 * i + 1] = im;
7779 }
7780 }
7781 }
7782
7783 Thunk::Scan {
7784 body,
7785 body_init,
7786 body_input_off,
7787 body_output_off,
7788 outer_init_off,
7789 outer_final_off,
7790 length,
7791 carry_bytes,
7792 save_trajectory,
7793 xs_inputs,
7794 bcast_inputs,
7795 num_checkpoints,
7796 } => {
7797 let cb = *carry_bytes as usize;
7798 let n_steps = *length as usize;
7799 let k_total = if *num_checkpoints == 0 || *num_checkpoints == *length {
7803 n_steps } else {
7805 *num_checkpoints as usize
7806 };
7807 let checkpoint_t_for_k = |k: usize| -> usize {
7808 if k_total == n_steps {
7809 k
7810 } else {
7811 ((k + 1) * n_steps)
7812 .div_ceil(k_total)
7813 .saturating_sub(1)
7814 .min(n_steps - 1)
7815 }
7816 };
7817 let mut next_k = 0usize;
7818
7819 let mut body_buf: Vec<u8> = (**body_init).clone();
7820 unsafe {
7821 std::ptr::copy_nonoverlapping(
7822 base.add(*outer_init_off),
7823 body_buf.as_mut_ptr().add(*body_input_off),
7824 cb,
7825 );
7826 for (body_b_off, outer_b_off, total_bytes) in bcast_inputs.iter() {
7830 std::ptr::copy_nonoverlapping(
7831 base.add(*outer_b_off),
7832 body_buf.as_mut_ptr().add(*body_b_off),
7833 *total_bytes as usize,
7834 );
7835 }
7836 }
7837 for t in 0..n_steps {
7838 for (body_x_off, outer_xs_off, per_step_bytes) in xs_inputs.iter() {
7839 let psb = *per_step_bytes as usize;
7840 unsafe {
7841 std::ptr::copy_nonoverlapping(
7842 base.add(*outer_xs_off + t * psb),
7843 body_buf.as_mut_ptr().add(*body_x_off),
7844 psb,
7845 );
7846 }
7847 }
7848
7849 execute_thunks(body, &mut body_buf);
7850
7851 if *save_trajectory && next_k < k_total && t == checkpoint_t_for_k(next_k) {
7852 unsafe {
7853 std::ptr::copy_nonoverlapping(
7854 body_buf.as_ptr().add(*body_output_off),
7855 base.add(*outer_final_off + next_k * cb),
7856 cb,
7857 );
7858 }
7859 next_k += 1;
7860 }
7861
7862 if *body_output_off != *body_input_off {
7863 body_buf
7864 .copy_within(*body_output_off..*body_output_off + cb, *body_input_off);
7865 }
7866 }
7867
7868 if !*save_trajectory {
7869 unsafe {
7871 std::ptr::copy_nonoverlapping(
7872 body_buf.as_ptr().add(*body_output_off),
7873 base.add(*outer_final_off),
7874 cb,
7875 );
7876 }
7877 }
7878 }
7879
7880 Thunk::ScanBackward {
7881 body_vjp,
7882 body_init,
7883 body_carry_in_off,
7884 body_x_offs,
7885 body_d_output_off,
7886 body_dcarry_out_off,
7887 outer_init_off,
7888 outer_traj_off,
7889 outer_upstream_off,
7890 outer_xs_offs,
7891 outer_dinit_off,
7892 length,
7893 carry_bytes,
7894 save_trajectory,
7895 num_checkpoints,
7896 forward_body,
7897 forward_body_init,
7898 forward_body_carry_in_off,
7899 forward_body_output_off,
7900 forward_body_x_offs,
7901 carry_elem_size,
7902 } => {
7903 let cb = *carry_bytes as usize;
7916 let n_steps = *length as usize;
7917 let k_total = *num_checkpoints as usize;
7918 let is_recursive = k_total != 0 && k_total != n_steps;
7919 let checkpoint_t_for_k = |k: usize| -> usize {
7920 ((k + 1) * n_steps)
7921 .div_ceil(k_total)
7922 .saturating_sub(1)
7923 .min(n_steps - 1)
7924 };
7925
7926 let mut fwd_buf: Vec<u8> = if is_recursive {
7927 (**forward_body_init.as_ref().unwrap()).clone()
7928 } else {
7929 Vec::new()
7930 };
7931
7932 let mut dcarry: Vec<u8> = vec![0u8; cb];
7933 if !*save_trajectory {
7934 unsafe {
7935 std::ptr::copy_nonoverlapping(
7936 base.add(*outer_upstream_off),
7937 dcarry.as_mut_ptr(),
7938 cb,
7939 );
7940 }
7941 }
7942
7943 let mut body_buf: Vec<u8> = (**body_init).clone();
7944
7945 let process_iter =
7950 |t: usize, carry_in: &[u8], dcarry: &mut Vec<u8>, body_buf: &mut Vec<u8>| {
7951 if *save_trajectory {
7952 unsafe {
7953 let up_off = *outer_upstream_off + t * cb;
7954 match *carry_elem_size {
7955 4 => {
7956 let up_ptr = base.add(up_off) as *const f32;
7957 let dc_ptr = dcarry.as_mut_ptr() as *mut f32;
7958 let n_elems = cb / 4;
7959 for i in 0..n_elems {
7960 *dc_ptr.add(i) += *up_ptr.add(i);
7961 }
7962 }
7963 8 => {
7964 let up_ptr = base.add(up_off) as *const f64;
7965 let dc_ptr = dcarry.as_mut_ptr() as *mut f64;
7966 let n_elems = cb / 8;
7967 for i in 0..n_elems {
7968 *dc_ptr.add(i) += *up_ptr.add(i);
7969 }
7970 }
7971 other => panic!(
7972 "ScanBackward: unsupported carry elem size {other} \
7973 (only f32/f64 carries are supported today)"
7974 ),
7975 }
7976 }
7977 }
7978 body_buf[*body_carry_in_off..*body_carry_in_off + cb]
7979 .copy_from_slice(carry_in);
7980 unsafe {
7981 for (i, body_x_off) in body_x_offs.iter().enumerate() {
7982 let (outer_xs_off, per_step_bytes) = outer_xs_offs[i];
7983 let psb = per_step_bytes as usize;
7984 std::ptr::copy_nonoverlapping(
7985 base.add(outer_xs_off + t * psb),
7986 body_buf.as_mut_ptr().add(*body_x_off),
7987 psb,
7988 );
7989 }
7990 std::ptr::copy_nonoverlapping(
7991 dcarry.as_ptr(),
7992 body_buf.as_mut_ptr().add(*body_d_output_off),
7993 cb,
7994 );
7995 }
7996 execute_thunks(body_vjp, body_buf);
7997 unsafe {
7998 std::ptr::copy_nonoverlapping(
7999 body_buf.as_ptr().add(*body_dcarry_out_off),
8000 dcarry.as_mut_ptr(),
8001 cb,
8002 );
8003 }
8004 };
8005
8006 if is_recursive {
8007 let leaf_threshold = 4usize;
8015 let fb_sched = forward_body.as_ref().unwrap();
8016 let fb_init = forward_body_init.as_ref().unwrap().as_slice();
8017 let mut segment_end = n_steps - 1;
8018 for seg_k in (0..k_total).rev() {
8019 let segment_start = if seg_k == 0 {
8020 0
8021 } else {
8022 checkpoint_t_for_k(seg_k - 1) + 1
8023 };
8024 let mut anchor: Vec<u8> = vec![0u8; cb];
8025 unsafe {
8026 let src = if seg_k == 0 {
8027 base.add(*outer_init_off)
8028 } else {
8029 base.add(*outer_traj_off + (seg_k - 1) * cb)
8030 };
8031 std::ptr::copy_nonoverlapping(src, anchor.as_mut_ptr(), cb);
8032 }
8033 let mut leaf_action = |t: usize, carry_in: &[u8]| {
8036 process_iter(t, carry_in, &mut dcarry, &mut body_buf);
8037 };
8038 unsafe {
8039 griewank_process_segment(
8040 segment_start,
8041 segment_end,
8042 &anchor,
8043 cb,
8044 fb_sched,
8045 fb_init,
8046 *forward_body_carry_in_off,
8047 *forward_body_output_off,
8048 forward_body_x_offs,
8049 base,
8050 outer_xs_offs,
8051 &mut fwd_buf,
8052 leaf_threshold,
8053 &mut leaf_action,
8054 );
8055 }
8056 if seg_k == 0 {
8057 break;
8058 }
8059 segment_end = segment_start - 1;
8060 }
8061 } else {
8062 let mut carry_buf: Vec<u8> = vec![0u8; cb];
8065 for t in (0..n_steps).rev() {
8066 unsafe {
8067 let src = if t == 0 {
8068 base.add(*outer_init_off)
8069 } else {
8070 base.add(*outer_traj_off + (t - 1) * cb)
8071 };
8072 std::ptr::copy_nonoverlapping(src, carry_buf.as_mut_ptr(), cb);
8073 }
8074 process_iter(t, &carry_buf, &mut dcarry, &mut body_buf);
8075 }
8076 }
8077
8078 unsafe {
8079 std::ptr::copy_nonoverlapping(dcarry.as_ptr(), base.add(*outer_dinit_off), cb);
8080 }
8081 }
8082
8083 Thunk::ScanBackwardXs {
8084 body_vjp,
8085 body_init,
8086 body_carry_in_off,
8087 body_x_offs,
8088 body_d_output_off,
8089 body_dcarry_out_off,
8090 body_dxs_out_off,
8091 outer_init_off,
8092 outer_traj_off,
8093 outer_upstream_off,
8094 outer_xs_offs,
8095 outer_dxs_off,
8096 length,
8097 carry_bytes,
8098 carry_elem_size,
8099 per_step_bytes,
8100 save_trajectory,
8101 num_checkpoints,
8102 forward_body,
8103 forward_body_init,
8104 forward_body_carry_in_off,
8105 forward_body_output_off,
8106 forward_body_x_offs,
8107 } => {
8108 let cb = *carry_bytes as usize;
8109 let psb = *per_step_bytes as usize;
8110 let n_steps = *length as usize;
8111 let k_total = *num_checkpoints as usize;
8112 let is_recursive = k_total != 0 && k_total != n_steps;
8113 let checkpoint_t_for_k = |k: usize| -> usize {
8114 ((k + 1) * n_steps)
8115 .div_ceil(k_total)
8116 .saturating_sub(1)
8117 .min(n_steps - 1)
8118 };
8119
8120 let mut fwd_buf: Vec<u8> = if is_recursive {
8124 (**forward_body_init.as_ref().unwrap()).clone()
8125 } else {
8126 Vec::new()
8127 };
8128 let mut seg_cache: Vec<u8> = Vec::new();
8129 let mut seg_start_t: usize = usize::MAX;
8130 let mut seg_count: usize = 0;
8131 let recompute_carry_t =
8132 |t: usize,
8133 dst: &mut [u8],
8134 fwd_buf: &mut Vec<u8>,
8135 seg_cache: &mut Vec<u8>,
8136 seg_start_t: &mut usize,
8137 seg_count: &mut usize| {
8138 if !is_recursive {
8139 unsafe {
8140 let src = if t == 0 {
8141 base.add(*outer_init_off)
8142 } else {
8143 base.add(*outer_traj_off + (t - 1) * cb)
8144 };
8145 std::ptr::copy_nonoverlapping(src, dst.as_mut_ptr(), cb);
8146 }
8147 return;
8148 }
8149 if *seg_start_t != usize::MAX
8150 && t >= *seg_start_t
8151 && t < *seg_start_t + *seg_count
8152 {
8153 let off = (t - *seg_start_t) * cb;
8154 dst.copy_from_slice(&seg_cache[off..off + cb]);
8155 return;
8156 }
8157 let seg_k = (0..k_total)
8158 .find(|&k| t <= checkpoint_t_for_k(k))
8159 .unwrap_or(k_total - 1);
8160 let (anchor_t, anchor_ptr): (usize, *const u8) = if seg_k == 0 {
8161 (0, unsafe { base.add(*outer_init_off) as *const u8 })
8162 } else {
8163 let prev_ck = checkpoint_t_for_k(seg_k - 1);
8164 (prev_ck + 1, unsafe {
8165 base.add(*outer_traj_off + (seg_k - 1) * cb) as *const u8
8166 })
8167 };
8168 let seg_end_t = checkpoint_t_for_k(seg_k);
8169 let seg_size = seg_end_t - anchor_t + 1;
8170
8171 fwd_buf.copy_from_slice(forward_body_init.as_ref().unwrap());
8172 unsafe {
8173 std::ptr::copy_nonoverlapping(
8174 anchor_ptr,
8175 fwd_buf.as_mut_ptr().add(*forward_body_carry_in_off),
8176 cb,
8177 );
8178 }
8179 seg_cache.resize(seg_size * cb, 0u8);
8180 seg_cache[0..cb].copy_from_slice(
8181 &fwd_buf[*forward_body_carry_in_off..*forward_body_carry_in_off + cb],
8182 );
8183 let fb_sched = forward_body.as_ref().unwrap();
8184 for i in 1..seg_size {
8185 let cur_iter = anchor_t + i - 1;
8186 for (idx, fb_x_off) in forward_body_x_offs.iter().enumerate() {
8187 let (outer_xs_off, x_psb) = outer_xs_offs[idx];
8188 let xb = x_psb as usize;
8189 unsafe {
8190 std::ptr::copy_nonoverlapping(
8191 base.add(outer_xs_off + cur_iter * xb),
8192 fwd_buf.as_mut_ptr().add(*fb_x_off),
8193 xb,
8194 );
8195 }
8196 }
8197 execute_thunks(fb_sched, fwd_buf);
8198 if *forward_body_output_off != *forward_body_carry_in_off {
8199 fwd_buf.copy_within(
8200 *forward_body_output_off..*forward_body_output_off + cb,
8201 *forward_body_carry_in_off,
8202 );
8203 }
8204 let cache_off = i * cb;
8205 seg_cache[cache_off..cache_off + cb].copy_from_slice(
8206 &fwd_buf
8207 [*forward_body_carry_in_off..*forward_body_carry_in_off + cb],
8208 );
8209 }
8210 *seg_start_t = anchor_t;
8211 *seg_count = seg_size;
8212
8213 let off = (t - anchor_t) * cb;
8214 dst.copy_from_slice(&seg_cache[off..off + cb]);
8215 };
8216
8217 let mut dcarry: Vec<u8> = vec![0u8; cb];
8218 if !*save_trajectory {
8219 unsafe {
8220 std::ptr::copy_nonoverlapping(
8221 base.add(*outer_upstream_off),
8222 dcarry.as_mut_ptr(),
8223 cb,
8224 );
8225 }
8226 }
8227
8228 let mut body_buf: Vec<u8> = (**body_init).clone();
8229
8230 for t in (0..n_steps).rev() {
8231 if *save_trajectory {
8232 unsafe {
8233 let up_off = *outer_upstream_off + t * cb;
8234 match *carry_elem_size {
8235 4 => {
8236 let up_ptr = base.add(up_off) as *const f32;
8237 let dc_ptr = dcarry.as_mut_ptr() as *mut f32;
8238 let n_elems = cb / 4;
8239 for i in 0..n_elems {
8240 *dc_ptr.add(i) += *up_ptr.add(i);
8241 }
8242 }
8243 8 => {
8244 let up_ptr = base.add(up_off) as *const f64;
8245 let dc_ptr = dcarry.as_mut_ptr() as *mut f64;
8246 let n_elems = cb / 8;
8247 for i in 0..n_elems {
8248 *dc_ptr.add(i) += *up_ptr.add(i);
8249 }
8250 }
8251 other => panic!(
8252 "ScanBackwardXs: unsupported carry elem size {other} \
8253 (only f32/f64 carries are supported today)"
8254 ),
8255 }
8256 }
8257 }
8258
8259 let carry_dst_start = *body_carry_in_off;
8263 {
8264 let carry_slice = &mut body_buf[carry_dst_start..carry_dst_start + cb];
8265 recompute_carry_t(
8266 t,
8267 carry_slice,
8268 &mut fwd_buf,
8269 &mut seg_cache,
8270 &mut seg_start_t,
8271 &mut seg_count,
8272 );
8273 }
8274 unsafe {
8275 for (i, body_x_off) in body_x_offs.iter().enumerate() {
8276 let (outer_xs_off, x_psb) = outer_xs_offs[i];
8277 let xb = x_psb as usize;
8278 std::ptr::copy_nonoverlapping(
8279 base.add(outer_xs_off + t * xb),
8280 body_buf.as_mut_ptr().add(*body_x_off),
8281 xb,
8282 );
8283 }
8284 std::ptr::copy_nonoverlapping(
8285 dcarry.as_ptr(),
8286 body_buf.as_mut_ptr().add(*body_d_output_off),
8287 cb,
8288 );
8289 }
8290
8291 execute_thunks(body_vjp, &mut body_buf);
8292
8293 unsafe {
8296 std::ptr::copy_nonoverlapping(
8297 body_buf.as_ptr().add(*body_dxs_out_off),
8298 base.add(*outer_dxs_off + t * psb),
8299 psb,
8300 );
8301 }
8302
8303 unsafe {
8305 std::ptr::copy_nonoverlapping(
8306 body_buf.as_ptr().add(*body_dcarry_out_off),
8307 dcarry.as_mut_ptr(),
8308 cb,
8309 );
8310 }
8311 }
8312 }
8313
8314 Thunk::FusedMmBiasAct {
8315 a,
8316 w,
8317 bias,
8318 c,
8319 m,
8320 k,
8321 n,
8322 act,
8323 } => {
8324 let (m, k, n) = (*m as usize, *k as usize, *n as usize);
8325 unsafe {
8326 let out = sl_mut(*c, base, m * n);
8327 crate::blas::sgemm_auto(sl(*a, base, m * k), sl(*w, base, k * n), out, m, k, n);
8328 match act {
8329 Some(Activation::Gelu) => {
8330 crate::kernels::par_bias_gelu(out, sl(*bias, base, n), m, n)
8331 }
8332 Some(other) => {
8333 crate::blas::bias_add(out, sl(*bias, base, n), m, n);
8334 apply_activation_inplace(out, *other);
8335 }
8336 None => crate::blas::bias_add(out, sl(*bias, base, n), m, n),
8337 }
8338 }
8339 }
8340
8341 Thunk::FusedResidualLN {
8342 x,
8343 res,
8344 bias,
8345 g,
8346 b,
8347 out,
8348 rows,
8349 h,
8350 eps,
8351 has_bias,
8352 } => {
8353 let (rows, h) = (*rows as usize, *h as usize);
8354 unsafe {
8355 let zero = &zero_bias[..h];
8356 let bi = if *has_bias { sl(*bias, base, h) } else { zero };
8357 let x_ptr = sl(*x, base, rows * h).as_ptr() as usize;
8358 let r_ptr = sl(*res, base, rows * h).as_ptr() as usize;
8359 let o_ptr = sl_mut(*out, base, rows * h).as_mut_ptr() as usize;
8360 let bi_ptr = bi.as_ptr() as usize;
8361 let g_ptr = sl(*g, base, h).as_ptr() as usize;
8362 let b_ptr = sl(*b, base, h).as_ptr() as usize;
8363 let e = *eps;
8364 crate::pool::par_for(rows, 4, &|off, cnt| {
8365 let xs =
8366 std::slice::from_raw_parts((x_ptr as *const f32).add(off * h), cnt * h);
8367 let rs =
8368 std::slice::from_raw_parts((r_ptr as *const f32).add(off * h), cnt * h);
8369 let os = std::slice::from_raw_parts_mut(
8370 (o_ptr as *mut f32).add(off * h),
8371 cnt * h,
8372 );
8373 let bi = std::slice::from_raw_parts(bi_ptr as *const f32, h);
8374 let g = std::slice::from_raw_parts(g_ptr as *const f32, h);
8375 let b = std::slice::from_raw_parts(b_ptr as *const f32, h);
8376 crate::kernels::residual_bias_layer_norm(xs, rs, bi, g, b, os, cnt, h, e);
8377 });
8378 }
8379 }
8380
8381 Thunk::FusedResidualRmsNorm {
8382 x,
8383 res,
8384 bias,
8385 g,
8386 b,
8387 out,
8388 rows,
8389 h,
8390 eps,
8391 has_bias,
8392 } => {
8393 let (rows, h) = (*rows as usize, *h as usize);
8394 unsafe {
8395 let zero = &zero_bias[..h];
8396 let bi = if *has_bias { sl(*bias, base, h) } else { zero };
8397 let x_ptr = sl(*x, base, rows * h).as_ptr() as usize;
8398 let r_ptr = sl(*res, base, rows * h).as_ptr() as usize;
8399 let o_ptr = sl_mut(*out, base, rows * h).as_mut_ptr() as usize;
8400 let bi_ptr = bi.as_ptr() as usize;
8401 let g_ptr = sl(*g, base, h).as_ptr() as usize;
8402 let b_ptr = sl(*b, base, h).as_ptr() as usize;
8403 let e = *eps;
8404 crate::pool::par_for(rows, 4, &|off, cnt| {
8405 let xs =
8406 std::slice::from_raw_parts((x_ptr as *const f32).add(off * h), cnt * h);
8407 let rs =
8408 std::slice::from_raw_parts((r_ptr as *const f32).add(off * h), cnt * h);
8409 let os = std::slice::from_raw_parts_mut(
8410 (o_ptr as *mut f32).add(off * h),
8411 cnt * h,
8412 );
8413 let bi = std::slice::from_raw_parts(bi_ptr as *const f32, h);
8414 let g = std::slice::from_raw_parts(g_ptr as *const f32, h);
8415 let b = std::slice::from_raw_parts(b_ptr as *const f32, h);
8416 crate::kernels::residual_bias_rms_norm(xs, rs, bi, g, b, os, cnt, h, e);
8417 });
8418 }
8419 }
8420
8421 Thunk::BiasAdd {
8422 src,
8423 bias,
8424 dst,
8425 m,
8426 n,
8427 } => {
8428 let (m, n) = (*m as usize, *n as usize);
8429 unsafe {
8430 let out = sl_mut(*dst, base, m * n);
8431 out.copy_from_slice(sl(*src, base, m * n));
8432 crate::blas::bias_add(out, sl(*bias, base, n), m, n);
8433 }
8434 }
8435
8436 Thunk::BinaryFull {
8437 lhs,
8438 rhs,
8439 dst,
8440 len,
8441 lhs_len,
8442 rhs_len,
8443 op,
8444 out_dims_bcast,
8445 bcast_lhs_strides,
8446 bcast_rhs_strides,
8447 } => {
8448 let len = *len as usize;
8449 let ll = (*lhs_len as usize).max(1);
8450 let rl = (*rhs_len as usize).max(1);
8451 unsafe {
8452 let l = sl(*lhs, base, ll);
8453 let r = sl(*rhs, base, rl);
8454 let o = sl_mut(*dst, base, len);
8455 if ll == len && rl == len {
8457 #[cfg(target_arch = "aarch64")]
8458 if matches!(op, BinaryOp::Add | BinaryOp::Mul) {
8459 use std::arch::aarch64::*;
8460 let chunks = len / 4;
8461 for c in 0..chunks {
8462 let off = c * 4;
8463 let vl = vld1q_f32(l.as_ptr().add(off));
8464 let vr = vld1q_f32(r.as_ptr().add(off));
8465 let res = match op {
8466 BinaryOp::Add => vaddq_f32(vl, vr),
8467 BinaryOp::Mul => vmulq_f32(vl, vr),
8468 _ => unreachable!(),
8469 };
8470 vst1q_f32(o.as_mut_ptr().add(off), res);
8471 }
8472 for i in (chunks * 4)..len {
8473 o[i] = match op {
8474 BinaryOp::Add => l[i] + r[i],
8475 BinaryOp::Mul => l[i] * r[i],
8476 _ => unreachable!(),
8477 };
8478 }
8479 continue;
8485 }
8486 }
8487 if !out_dims_bcast.is_empty() {
8488 let rank = out_dims_bcast.len();
8491 let mut coords = vec![0u32; rank];
8492 for i in 0..len {
8493 let mut rem = i;
8494 for ax in (0..rank).rev() {
8495 let sz = out_dims_bcast[ax] as usize;
8496 coords[ax] = (rem % sz) as u32;
8497 rem /= sz;
8498 }
8499 let mut li: usize = 0;
8500 let mut ri: usize = 0;
8501 for ax in 0..rank {
8502 li += coords[ax] as usize * bcast_lhs_strides[ax] as usize;
8503 ri += coords[ax] as usize * bcast_rhs_strides[ax] as usize;
8504 }
8505 o[i] = match op {
8506 BinaryOp::Add => l[li] + r[ri],
8507 BinaryOp::Sub => l[li] - r[ri],
8508 BinaryOp::Mul => l[li] * r[ri],
8509 BinaryOp::Div => l[li] / r[ri],
8510 BinaryOp::Max => l[li].max(r[ri]),
8511 BinaryOp::Min => l[li].min(r[ri]),
8512 BinaryOp::Pow => l[li].powf(r[ri]),
8513 };
8514 }
8515 } else {
8516 for i in 0..len {
8518 let li = if ll == 1 { 0 } else { i % ll };
8519 let ri = if rl == 1 { 0 } else { i % rl };
8520 o[i] = match op {
8521 BinaryOp::Add => l[li] + r[ri],
8522 BinaryOp::Sub => l[li] - r[ri],
8523 BinaryOp::Mul => l[li] * r[ri],
8524 BinaryOp::Div => l[li] / r[ri],
8525 BinaryOp::Max => l[li].max(r[ri]),
8526 BinaryOp::Min => l[li].min(r[ri]),
8527 BinaryOp::Pow => l[li].powf(r[ri]),
8528 };
8529 }
8530 }
8531 }
8532 }
8533
8534 Thunk::Gather {
8535 table,
8536 table_len,
8537 idx,
8538 dst,
8539 num_idx,
8540 trailing,
8541 } => {
8542 let (ni, tr) = (*num_idx as usize, *trailing as usize);
8543 unsafe {
8544 let tab = sl(*table, base, *table_len as usize);
8545 let ids = sl(*idx, base, ni);
8546 let out = sl_mut(*dst, base, ni * tr);
8547 for i in 0..ni {
8548 let row = ids[i] as usize;
8549 out[i * tr..(i + 1) * tr].copy_from_slice(&tab[row * tr..(row + 1) * tr]);
8550 }
8551 }
8552 }
8553
8554 Thunk::Narrow {
8555 src,
8556 dst,
8557 outer,
8558 src_stride,
8559 dst_stride,
8560 inner,
8561 elem_bytes,
8562 } => {
8563 let f = narrow_thunk_closure(
8564 *src,
8565 *dst,
8566 *outer,
8567 *src_stride,
8568 *dst_stride,
8569 *inner,
8570 *elem_bytes,
8571 );
8572 f(base);
8573 }
8574
8575 Thunk::Copy { src, dst, len } => {
8576 let len = *len as usize;
8577 unsafe {
8578 let s = sl(*src, base, len);
8579 let d = sl_mut(*dst, base, len);
8580 d.copy_from_slice(s);
8581 }
8582 }
8583
8584 Thunk::LayerNorm {
8585 src,
8586 g,
8587 b,
8588 dst,
8589 rows,
8590 h,
8591 eps,
8592 } => {
8593 let (rows, h) = (*rows as usize, *h as usize);
8594 unsafe {
8595 let input = sl(*src, base, rows * h);
8596 let gamma = sl(*g, base, h);
8597 let beta = sl(*b, base, h);
8598 let output = sl_mut(*dst, base, rows * h);
8599 if rows >= 4 && rows * h >= 30_000 {
8601 let i_ptr = input.as_ptr() as usize;
8602 let o_ptr = output.as_mut_ptr() as usize;
8603 let g_ptr = gamma.as_ptr() as usize;
8604 let b_ptr = beta.as_ptr() as usize;
8605 let e = *eps;
8606 crate::pool::par_for(rows, 4, &|off, cnt| {
8607 let inp = std::slice::from_raw_parts(
8608 (i_ptr as *const f32).add(off * h),
8609 cnt * h,
8610 );
8611 let out = std::slice::from_raw_parts_mut(
8612 (o_ptr as *mut f32).add(off * h),
8613 cnt * h,
8614 );
8615 let g = std::slice::from_raw_parts(g_ptr as *const f32, h);
8616 let b = std::slice::from_raw_parts(b_ptr as *const f32, h);
8617 for row in 0..cnt {
8618 crate::kernels::layer_norm_row(
8619 &inp[row * h..(row + 1) * h],
8620 g,
8621 b,
8622 &mut out[row * h..(row + 1) * h],
8623 h,
8624 e,
8625 );
8626 }
8627 });
8628 } else {
8629 for row in 0..rows {
8630 crate::kernels::layer_norm_row(
8631 &input[row * h..(row + 1) * h],
8632 gamma,
8633 beta,
8634 &mut output[row * h..(row + 1) * h],
8635 h,
8636 *eps,
8637 );
8638 }
8639 }
8640 }
8641 }
8642
8643 Thunk::GroupNorm {
8644 src,
8645 g,
8646 b,
8647 dst,
8648 n,
8649 c,
8650 h,
8651 w,
8652 num_groups,
8653 eps,
8654 } => {
8655 let (n, c, h, w) = (*n as usize, *c as usize, *h as usize, *w as usize);
8656 let plane = c * h * w;
8657 unsafe {
8658 for ni in 0..n {
8659 let input = sl(*src, base.add(ni * plane), plane);
8660 let gamma = sl(*g, base, c);
8661 let beta = sl(*b, base, c);
8662 let output = sl_mut(*dst, base.add(ni * plane), plane);
8663 crate::kernels::group_norm_nchw(
8664 input,
8665 gamma,
8666 beta,
8667 output,
8668 1,
8669 c,
8670 h,
8671 w,
8672 *num_groups as usize,
8673 *eps,
8674 );
8675 }
8676 }
8677 }
8678
8679 Thunk::LayerNorm2d {
8680 src,
8681 g,
8682 b,
8683 dst,
8684 n,
8685 c,
8686 h,
8687 w,
8688 eps,
8689 } => {
8690 let (n, c, h, w) = (*n as usize, *c as usize, *h as usize, *w as usize);
8691 let plane = c * h * w;
8692 unsafe {
8693 let input = sl(*src, base, n * plane);
8694 let gamma = sl(*g, base, c);
8695 let beta = sl(*b, base, c);
8696 let output = sl_mut(*dst, base, n * plane);
8697 crate::kernels::layer_norm2d_nchw(input, gamma, beta, output, n, c, h, w, *eps);
8698 }
8699 }
8700
8701 Thunk::ConvTranspose2d {
8702 src,
8703 weight,
8704 dst,
8705 n,
8706 c_in,
8707 h,
8708 w_in,
8709 c_out,
8710 h_out,
8711 w_out,
8712 kh,
8713 kw,
8714 sh,
8715 sw,
8716 ph,
8717 pw,
8718 dh,
8719 dw,
8720 groups,
8721 } => {
8722 let n = *n as usize;
8723 let c_in = *c_in as usize;
8724 let h = *h as usize;
8725 let w_in = *w_in as usize;
8726 let c_out = *c_out as usize;
8727 let h_out = *h_out as usize;
8728 let w_out = *w_out as usize;
8729 unsafe {
8730 let inp = sl(*src, base, n * c_in * h * w_in);
8731 let wt = sl(
8732 *weight,
8733 base,
8734 c_in * (c_out / *groups as usize) * (*kh as usize) * (*kw as usize),
8735 );
8736 let out = sl_mut(*dst, base, n * c_out * h_out * w_out);
8737 crate::kernels::conv_transpose2d_nchw(
8738 inp,
8739 wt,
8740 out,
8741 n,
8742 c_in,
8743 h,
8744 w_in,
8745 c_out,
8746 h_out,
8747 w_out,
8748 *kh as usize,
8749 *kw as usize,
8750 *sh as usize,
8751 *sw as usize,
8752 *ph as usize,
8753 *pw as usize,
8754 *dh as usize,
8755 *dw as usize,
8756 *groups as usize,
8757 );
8758 }
8759 }
8760
8761 Thunk::ResizeNearest2x {
8762 src,
8763 dst,
8764 n,
8765 c,
8766 h,
8767 w,
8768 } => {
8769 let (n, c, h, w) = (*n as usize, *c as usize, *h as usize, *w as usize);
8770 let in_plane = c * h * w;
8771 let out_plane = c * h * 2 * w * 2;
8772 unsafe {
8773 for ni in 0..n {
8774 let input = sl(*src, base.add(ni * in_plane), in_plane);
8775 let output = sl_mut(*dst, base.add(ni * out_plane), out_plane);
8776 crate::kernels::resize_nearest_2x_nchw(input, output, c, h, w);
8777 }
8778 }
8779 }
8780
8781 Thunk::AxialRope2d {
8782 src,
8783 dst,
8784 batch,
8785 seq,
8786 hidden,
8787 end_x,
8788 end_y,
8789 head_dim,
8790 num_heads,
8791 theta,
8792 repeat_factor,
8793 } => {
8794 let b = *batch as usize;
8795 let s = *seq as usize;
8796 let hdim = *head_dim as usize;
8797 let nh = *num_heads as usize;
8798 let plane = s * (*hidden as usize);
8799 unsafe {
8800 for bi in 0..b {
8801 let input = sl(*src, base.add(bi * plane), plane);
8802 let output = sl_mut(*dst, base.add(bi * plane), plane);
8803 let rotated = rlx_ir::ops::axial_rope2d::apply_axial_rope2d(
8804 input,
8805 nh,
8806 s,
8807 hdim,
8808 *end_x as usize,
8809 *end_y as usize,
8810 *theta,
8811 *repeat_factor as usize,
8812 );
8813 output.copy_from_slice(&rotated);
8814 }
8815 }
8816 }
8817
8818 Thunk::RmsNorm {
8819 src,
8820 g,
8821 b,
8822 dst,
8823 rows,
8824 h,
8825 eps,
8826 } => {
8827 let (rows, h) = (*rows as usize, *h as usize);
8828 unsafe {
8829 let input = sl(*src, base, rows * h);
8830 let gamma = sl(*g, base, h);
8831 let beta = sl(*b, base, h);
8832 let output = sl_mut(*dst, base, rows * h);
8833 let inv_h = 1.0 / h as f32;
8834 for row in 0..rows {
8835 let in_row = &input[row * h..(row + 1) * h];
8836 let out_row = &mut output[row * h..(row + 1) * h];
8837 let mut sumsq = 0f32;
8839 for &v in in_row {
8840 sumsq += v * v;
8841 }
8842 let inv_rms = (sumsq * inv_h + *eps).sqrt().recip();
8843 for i in 0..h {
8844 out_row[i] = in_row[i] * inv_rms * gamma[i] + beta[i];
8845 }
8846 }
8847 }
8848 }
8849
8850 Thunk::Softmax { data, rows, cols } => {
8851 let (rows, cols) = (*rows as usize, *cols as usize);
8852 unsafe {
8853 crate::kernels::neon_softmax(sl_mut(*data, base, rows * cols), rows, cols);
8854 }
8855 }
8856
8857 Thunk::Cumsum {
8858 src,
8859 dst,
8860 rows,
8861 cols,
8862 exclusive,
8863 } => {
8864 let (rows, cols) = (*rows as usize, *cols as usize);
8865 unsafe {
8866 let s = sl(*src, base, rows * cols);
8867 let d = sl_mut(*dst, base, rows * cols);
8868 if *exclusive {
8869 for r in 0..rows {
8870 let mut acc = 0.0f32;
8871 for c in 0..cols {
8872 d[r * cols + c] = acc;
8873 acc += s[r * cols + c];
8874 }
8875 }
8876 } else {
8877 for r in 0..rows {
8878 let mut acc = 0.0f32;
8879 for c in 0..cols {
8880 acc += s[r * cols + c];
8881 d[r * cols + c] = acc;
8882 }
8883 }
8884 }
8885 }
8886 }
8887
8888 Thunk::Sample {
8889 logits,
8890 dst,
8891 batch,
8892 vocab,
8893 top_k,
8894 top_p,
8895 temperature,
8896 seed,
8897 } => {
8898 let (b, v) = (*batch as usize, *vocab as usize);
8899 let k = (*top_k as usize).min(v);
8900 unsafe {
8901 let lg = sl(*logits, base, b * v);
8902 let out = sl_mut(*dst, base, b);
8903 let mut rng =
8904 rlx_ir::Philox4x32::new(if *seed == 0 { 0xDEADBEEF } else { *seed });
8905 for bi in 0..b {
8906 let row = &lg[bi * v..(bi + 1) * v];
8907 out[bi] = sample_row(row, k, *top_p, *temperature, &mut rng) as f32;
8908 }
8909 }
8910 }
8911
8912 Thunk::GatedDeltaNet {
8913 q,
8914 k,
8915 v,
8916 g,
8917 beta,
8918 state,
8919 dst,
8920 batch,
8921 seq,
8922 heads,
8923 state_size,
8924 } => unsafe {
8925 execute_gated_delta_net_f32(
8926 *q,
8927 *k,
8928 *v,
8929 *g,
8930 *beta,
8931 *state,
8932 *dst,
8933 *batch as usize,
8934 *seq as usize,
8935 *heads as usize,
8936 *state_size as usize,
8937 base,
8938 );
8939 },
8940
8941 Thunk::SelectiveScan {
8942 x,
8943 delta,
8944 a,
8945 b: bp,
8946 c: cp,
8947 dst,
8948 batch,
8949 seq,
8950 hidden,
8951 state_size,
8952 } => {
8953 let (b, s, h, n) = (
8954 *batch as usize,
8955 *seq as usize,
8956 *hidden as usize,
8957 *state_size as usize,
8958 );
8959 unsafe {
8960 let xs = sl(*x, base, b * s * h);
8961 let dt = sl(*delta, base, b * s * h);
8962 let am = sl(*a, base, h * n);
8963 let bm = sl(*bp, base, b * s * n);
8964 let cm = sl(*cp, base, b * s * n);
8965 let out = sl_mut(*dst, base, b * s * h);
8966
8967 let mut state = vec![0f32; h * n];
8971 for bi in 0..b {
8972 for v in state.iter_mut() {
8974 *v = 0.0;
8975 }
8976 for si in 0..s {
8977 let x_row = &xs[bi * s * h + si * h..bi * s * h + (si + 1) * h];
8978 let dt_row = &dt[bi * s * h + si * h..bi * s * h + (si + 1) * h];
8979 let b_row = &bm[bi * s * n + si * n..bi * s * n + (si + 1) * n];
8980 let c_row = &cm[bi * s * n + si * n..bi * s * n + (si + 1) * n];
8981 let out_row = &mut out[bi * s * h + si * h..bi * s * h + (si + 1) * h];
8982
8983 for ci in 0..h {
8984 let d = dt_row[ci];
8985 let xv = x_row[ci];
8986 let mut acc = 0f32;
8987 for ni in 0..n {
8988 let da = (d * am[ci * n + ni]).exp();
8990 state[ci * n + ni] =
8991 da * state[ci * n + ni] + d * b_row[ni] * xv;
8992 acc += c_row[ni] * state[ci * n + ni];
8993 }
8994 out_row[ci] = acc;
8995 }
8996 }
8997 }
8998 }
8999 }
9000
9001 Thunk::DequantMatMul {
9002 x,
9003 w_q,
9004 scale,
9005 zp,
9006 dst,
9007 m,
9008 k,
9009 n,
9010 block_size,
9011 is_asymmetric,
9012 } => {
9013 let (m, k, n, bs) = (*m as usize, *k as usize, *n as usize, *block_size as usize);
9014 let n_blocks = k.div_ceil(bs);
9015 unsafe {
9016 let xs = sl(*x, base, m * k);
9017 let w_bytes = std::slice::from_raw_parts(base.add(*w_q) as *const i8, k * n);
9018 let scales = sl(*scale, base, n_blocks * n);
9019 let zps = if *is_asymmetric {
9020 sl(*zp, base, n_blocks * n)
9021 } else {
9022 &[][..]
9023 };
9024 let out = sl_mut(*dst, base, m * n);
9025 dequant_matmul_int8(xs, w_bytes, scales, zps, out, m, k, n, bs, *is_asymmetric);
9026 }
9027 }
9028
9029 Thunk::DequantMatMulGguf {
9030 x,
9031 w_q,
9032 dst,
9033 m,
9034 k,
9035 n,
9036 scheme,
9037 } => {
9038 let (m, k, n) = (*m as usize, *k as usize, *n as usize);
9039 let block_bytes = scheme.gguf_block_bytes() as usize;
9040 let block_elems = scheme.gguf_block_size() as usize;
9041 debug_assert!(
9042 block_bytes > 0 && block_elems > 0,
9043 "non-GGUF scheme in GGUF arm"
9044 );
9045 debug_assert!(
9046 (k * n).is_multiple_of(block_elems),
9047 "k*n={} not aligned to GGUF block size {}",
9048 k * n,
9049 block_elems
9050 );
9051 let total_bytes = (k * n) / block_elems * block_bytes;
9052 unsafe {
9053 let xs = sl(*x, base, m * k);
9054 let w_bytes_ptr = base.add(*w_q) as *const u8;
9055 let w_bytes = std::slice::from_raw_parts(w_bytes_ptr, total_bytes);
9056 let out = sl_mut(*dst, base, m * n);
9057 crate::gguf_matmul::gguf_matmul_bt(xs, w_bytes, out, m, k, n, *scheme);
9058 }
9059 }
9060
9061 Thunk::DequantMatMulInt4 {
9062 x,
9063 w_q,
9064 scale,
9065 zp,
9066 dst,
9067 m,
9068 k,
9069 n,
9070 block_size,
9071 is_asymmetric,
9072 } => {
9073 let (m, k, n, bs) = (*m as usize, *k as usize, *n as usize, *block_size as usize);
9074 let n_blocks = k.div_ceil(bs);
9075 unsafe {
9076 let xs = sl(*x, base, m * k);
9077 let w_bytes = std::slice::from_raw_parts(
9078 base.add(*w_q) as *const u8,
9079 (k * n).div_ceil(2),
9080 );
9081 let scales = sl(*scale, base, n_blocks * n);
9082 let zps = if *is_asymmetric {
9083 sl(*zp, base, n_blocks * n)
9084 } else {
9085 &[][..]
9086 };
9087 let out = sl_mut(*dst, base, m * n);
9088 dequant_matmul_int4(xs, w_bytes, scales, zps, out, m, k, n, bs, *is_asymmetric);
9089 }
9090 }
9091
9092 Thunk::DequantMatMulFp8 {
9093 x,
9094 w_q,
9095 scale,
9096 dst,
9097 m,
9098 k,
9099 n,
9100 e5m2,
9101 } => {
9102 let (m, k, n) = (*m as usize, *k as usize, *n as usize);
9103 unsafe {
9104 let xs = sl(*x, base, m * k);
9105 let w_bytes = std::slice::from_raw_parts(base.add(*w_q) as *const u8, k * n);
9106 let scales = sl(*scale, base, n);
9107 let out = sl_mut(*dst, base, m * n);
9108 dequant_matmul_fp8(xs, w_bytes, scales, out, m, k, n, *e5m2);
9109 }
9110 }
9111
9112 Thunk::DequantMatMulNvfp4 {
9113 x,
9114 w_q,
9115 scale,
9116 global_scale,
9117 dst,
9118 m,
9119 k,
9120 n,
9121 } => {
9122 let (m, k, n) = (*m as usize, *k as usize, *n as usize);
9123 let n_scale = k.div_ceil(rlx_ir::NVFP4_GROUP_SIZE) * n;
9124 unsafe {
9125 let xs = sl(*x, base, m * k);
9126 let w_bytes = std::slice::from_raw_parts(
9127 base.add(*w_q) as *const u8,
9128 (k * n).div_ceil(2),
9129 );
9130 let scale_bytes =
9131 std::slice::from_raw_parts(base.add(*scale) as *const u8, n_scale);
9132 let gs = sl(*global_scale, base, 1)[0];
9133 let out = sl_mut(*dst, base, m * n);
9134 dequant_matmul_nvfp4(xs, w_bytes, scale_bytes, gs, out, m, k, n);
9135 }
9136 }
9137
9138 Thunk::LoraMatMul {
9139 x,
9140 w,
9141 a,
9142 b,
9143 dst,
9144 m,
9145 k,
9146 n,
9147 r,
9148 scale,
9149 } => {
9150 let (m, k, n, r) = (*m as usize, *k as usize, *n as usize, *r as usize);
9151 unsafe {
9152 let xs = sl(*x, base, m * k);
9153 let ws = sl(*w, base, k * n);
9154 let a_s = sl(*a, base, k * r);
9155 let bs = sl(*b, base, r * n);
9156 let out = sl_mut(*dst, base, m * n);
9157 crate::blas::sgemm(xs, ws, out, m, k, n);
9158 let mut tmp = vec![0f32; m * r];
9159 crate::blas::sgemm(xs, a_s, &mut tmp, m, k, r);
9160 if *scale != 1.0 {
9161 for v in tmp.iter_mut() {
9162 *v *= *scale;
9163 }
9164 }
9165 crate::blas::sgemm_accumulate(&tmp, bs, out, m, r, n);
9166 }
9167 }
9168
9169 Thunk::Attention {
9170 q,
9171 k,
9172 v,
9173 mask,
9174 out,
9175 batch,
9176 seq,
9177 kv_seq,
9178 heads,
9179 head_dim,
9180 mask_kind,
9181 q_row_stride,
9182 k_row_stride,
9183 v_row_stride,
9184 bhsd,
9185 } => {
9186 let (b, q_s, k_s, nh, dh) = (
9187 *batch as usize,
9188 *seq as usize,
9189 *kv_seq as usize,
9190 *heads as usize,
9191 *head_dim as usize,
9192 );
9193 let hs = nh * dh;
9194 let (qrs, krs, vrs) = if *bhsd {
9197 (dh, dh, dh)
9198 } else {
9199 (
9200 *q_row_stride as usize,
9201 *k_row_stride as usize,
9202 *v_row_stride as usize,
9203 )
9204 };
9205 let bhsd = *bhsd;
9206 let _ = (q_row_stride, k_row_stride, v_row_stride);
9207 let scale = (dh as f32).powf(-0.5);
9208 let ss = q_s * k_s;
9209 let cfg = crate::config::RuntimeConfig::global();
9210 unsafe {
9211 let q_len = if bhsd {
9218 b * nh * q_s * dh
9219 } else {
9220 b * q_s * qrs
9221 };
9222 let k_len = if bhsd {
9223 b * nh * k_s * dh
9224 } else {
9225 b * k_s * krs
9226 };
9227 let v_len = if bhsd {
9228 b * nh * k_s * dh
9229 } else {
9230 b * k_s * vrs
9231 };
9232 let q_data = sl(*q, base, q_len);
9233 let k_data = sl(*k, base, k_len);
9234 let v_data = sl(*v, base, v_len);
9235 let mask_data: &[f32] = match mask_kind {
9236 rlx_ir::op::MaskKind::Custom => sl(*mask, base, b * k_s),
9237 rlx_ir::op::MaskKind::Bias => sl(*mask, base, b * nh * q_s * k_s),
9238 _ => &[],
9239 };
9240 let out_len = if bhsd {
9241 b * nh * q_s * dh
9242 } else {
9243 b * q_s * hs
9244 };
9245 let out_data = sl_mut(*out, base, out_len);
9246
9247 if bhsd {
9258 let scores = &mut sdpa_scores[..ss];
9259 for bi in 0..b {
9260 for hi in 0..nh {
9261 let q_head_base = bi * nh * q_s * dh + hi * q_s * dh;
9262 let k_head_base = bi * nh * k_s * dh + hi * k_s * dh;
9263 for qi in 0..q_s {
9265 let q_base = q_head_base + qi * dh;
9266 for ki in 0..k_s {
9267 let k_base = k_head_base + ki * dh;
9268 let mut dot = 0f32;
9269 for d in 0..dh {
9270 dot += q_data[q_base + d] * k_data[k_base + d];
9271 }
9272 scores[qi * k_s + ki] = dot * scale;
9273 if matches!(mask_kind, rlx_ir::op::MaskKind::Custom)
9274 && !mask_data.is_empty()
9275 && mask_data[bi * k_s + ki] < mask_thr
9276 {
9277 scores[qi * k_s + ki] = mask_neg;
9278 }
9279 }
9280 }
9281 if matches!(mask_kind, rlx_ir::op::MaskKind::Bias) {
9282 let off = (bi * nh + hi) * q_s * k_s;
9283 for i in 0..q_s * k_s {
9284 scores[i] += mask_data[off + i];
9285 }
9286 }
9287 apply_synthetic_mask(scores, q_s, k_s, *mask_kind);
9288 crate::kernels::neon_softmax(scores, q_s, k_s);
9289 for qi in 0..q_s {
9291 let o_base = q_head_base + qi * dh;
9292 for d in 0..dh {
9293 out_data[o_base + d] = 0.0;
9294 }
9295 for ki in 0..k_s {
9296 let sc = scores[qi * k_s + ki];
9297 if sc > score_thr {
9298 let v_base = k_head_base + ki * dh;
9299 for d in 0..dh {
9300 out_data[o_base + d] += sc * v_data[v_base + d];
9301 }
9302 }
9303 }
9304 }
9305 }
9306 }
9307 continue;
9308 }
9309
9310 if b == 1 && q_s.max(k_s) <= cfg.sdpa_seq_threshold {
9317 let scores = &mut sdpa_scores[..ss];
9319 #[cfg(target_arch = "aarch64")]
9320 let neon_chunks = dh / 4;
9321
9322 for bi in 0..b {
9323 for hi in 0..nh {
9324 for qi in 0..q_s {
9326 let q_off = bi * q_s * qrs + qi * qrs + hi * dh;
9327 for ki in 0..k_s {
9328 let k_off = bi * k_s * krs + ki * krs + hi * dh;
9329 #[cfg(target_arch = "aarch64")]
9330 let mut dot;
9331 #[cfg(not(target_arch = "aarch64"))]
9332 let mut dot = 0f32;
9333 #[cfg(target_arch = "aarch64")]
9334 {
9335 use std::arch::aarch64::*;
9336 let mut acc = vdupq_n_f32(0.0);
9337 for c in 0..neon_chunks {
9338 let vq =
9339 vld1q_f32(q_data.as_ptr().add(q_off + c * 4));
9340 let vk =
9341 vld1q_f32(k_data.as_ptr().add(k_off + c * 4));
9342 acc = vfmaq_f32(acc, vq, vk);
9343 }
9344 dot = vaddvq_f32(acc);
9345 for d in (neon_chunks * 4)..dh {
9346 dot += q_data[q_off + d] * k_data[k_off + d];
9347 }
9348 }
9349 #[cfg(not(target_arch = "aarch64"))]
9350 for d in 0..dh {
9351 dot += q_data[q_off + d] * k_data[k_off + d];
9352 }
9353 scores[qi * k_s + ki] = dot * scale;
9354 if matches!(mask_kind, rlx_ir::op::MaskKind::Custom)
9361 && !mask_data.is_empty()
9362 && mask_data[bi * k_s + ki] < mask_thr
9363 {
9364 scores[qi * k_s + ki] = mask_neg;
9365 }
9366 }
9367 }
9368
9369 if matches!(mask_kind, rlx_ir::op::MaskKind::Bias) {
9370 let off = (bi * nh + hi) * q_s * k_s;
9371 for i in 0..q_s * k_s {
9372 scores[i] += mask_data[off + i];
9373 }
9374 }
9375 apply_synthetic_mask(scores, q_s, k_s, *mask_kind);
9376 crate::kernels::neon_softmax(scores, q_s, k_s);
9377
9378 for qi in 0..q_s {
9380 let o_off = bi * q_s * hs + qi * hs + hi * dh;
9381 for d in 0..dh {
9383 out_data[o_off + d] = 0.0;
9384 }
9385 for ki in 0..k_s {
9386 let sc = scores[qi * k_s + ki];
9387 if sc > score_thr {
9388 let v_off = bi * k_s * vrs + ki * vrs + hi * dh;
9389 #[cfg(target_arch = "aarch64")]
9390 {
9391 use std::arch::aarch64::*;
9392 let vsc = vdupq_n_f32(sc);
9393 for c in 0..neon_chunks {
9394 let off = c * 4;
9395 let vo = vld1q_f32(
9396 out_data.as_ptr().add(o_off + off),
9397 );
9398 let vv =
9399 vld1q_f32(v_data.as_ptr().add(v_off + off));
9400 vst1q_f32(
9401 out_data.as_mut_ptr().add(o_off + off),
9402 vfmaq_f32(vo, vsc, vv),
9403 );
9404 }
9405 }
9406 #[cfg(not(target_arch = "aarch64"))]
9407 for d in 0..dh {
9408 out_data[o_off + d] += sc * v_data[v_off + d];
9409 }
9410 }
9411 }
9412 }
9413 }
9414 }
9415 } else {
9416 let total_work = b * nh;
9418 let q_addr = q_data.as_ptr() as usize;
9419 let k_addr = k_data.as_ptr() as usize;
9420 let v_addr = v_data.as_ptr() as usize;
9421 let m_addr = mask_data.as_ptr() as usize;
9422 let o_addr = out_data.as_mut_ptr() as usize;
9423 let sc_addr = sdpa_scores.as_mut_ptr() as usize;
9424
9425 crate::pool::par_for(total_work, 1, &|off, cnt| {
9426 for idx in off..off + cnt {
9427 let bi = idx / nh;
9428 let hi = idx % nh;
9429
9430 let q_start = (q_addr as *const f32).add(bi * q_s * qrs + hi * dh);
9431 let k_start = (k_addr as *const f32).add(bi * k_s * krs + hi * dh);
9432 let v_start = (v_addr as *const f32).add(bi * k_s * vrs + hi * dh);
9433 let o_start = (o_addr as *mut f32).add(bi * q_s * hs + hi * dh);
9434 let sc = std::slice::from_raw_parts_mut(
9435 (sc_addr as *mut f32).add(idx * ss),
9436 ss,
9437 );
9438
9439 crate::blas::sgemm_general(
9442 q_start,
9443 k_start,
9444 sc.as_mut_ptr(),
9445 q_s,
9446 k_s,
9447 dh,
9448 scale,
9449 0.0,
9450 qrs,
9451 krs,
9452 k_s,
9453 false,
9454 true,
9455 );
9456
9457 match mask_kind {
9458 rlx_ir::op::MaskKind::Custom => {
9459 let mask_bi = std::slice::from_raw_parts(
9460 (m_addr as *const f32).add(bi * k_s),
9461 k_s,
9462 );
9463 for ki in 0..k_s {
9464 if mask_bi[ki] < mask_thr {
9465 for qi in 0..q_s {
9466 sc[qi * k_s + ki] = mask_neg;
9467 }
9468 }
9469 }
9470 }
9471 rlx_ir::op::MaskKind::Bias => {
9472 let bias = std::slice::from_raw_parts(
9474 (m_addr as *const f32).add((bi * nh + hi) * q_s * k_s),
9475 q_s * k_s,
9476 );
9477 for i in 0..q_s * k_s {
9478 sc[i] += bias[i];
9479 }
9480 }
9481 _ => apply_synthetic_mask(sc, q_s, k_s, *mask_kind),
9482 }
9483
9484 crate::kernels::neon_softmax(sc, q_s, k_s);
9485
9486 crate::blas::sgemm_general(
9490 sc.as_ptr(),
9491 v_start,
9492 o_start,
9493 q_s,
9494 dh,
9495 k_s,
9496 1.0,
9497 0.0,
9498 k_s,
9499 vrs,
9500 hs,
9501 false,
9502 false,
9503 );
9504 }
9505 });
9506 }
9507 }
9508 }
9509
9510 Thunk::AttentionBackward {
9511 q,
9512 k,
9513 v,
9514 dy,
9515 mask,
9516 out,
9517 batch,
9518 seq,
9519 kv_seq,
9520 heads,
9521 head_dim,
9522 mask_kind,
9523 wrt,
9524 bhsd,
9525 } => {
9526 let (b, q_s, k_s, nh, dh) = (
9527 *batch as usize,
9528 *seq as usize,
9529 *kv_seq as usize,
9530 *heads as usize,
9531 *head_dim as usize,
9532 );
9533 unsafe {
9534 let q_len = if *bhsd {
9535 b * nh * q_s * dh
9536 } else {
9537 b * q_s * nh * dh
9538 };
9539 let k_len = if *bhsd {
9540 b * nh * k_s * dh
9541 } else {
9542 b * k_s * nh * dh
9543 };
9544 let out_len = match wrt {
9545 rlx_ir::op::AttentionBwdWrt::Key | rlx_ir::op::AttentionBwdWrt::Value => {
9546 k_len
9547 }
9548 rlx_ir::op::AttentionBwdWrt::Query => q_len,
9549 };
9550 let q_data = sl(*q, base, q_len);
9551 let k_data = sl(*k, base, k_len);
9552 let v_data = sl(*v, base, k_len);
9553 let dy_data = sl(*dy, base, q_len);
9554 let out_data = sl_mut(*out, base, out_len);
9555 let mask_data: &[f32] = if *mask != 0 {
9556 let ml = match mask_kind {
9557 rlx_ir::op::MaskKind::Custom => b * k_s,
9558 rlx_ir::op::MaskKind::Bias => b * nh * q_s * k_s,
9559 _ => 0,
9560 };
9561 sl(*mask, base, ml)
9562 } else {
9563 &[]
9564 };
9565 crate::attention_bwd::attention_backward(
9566 *wrt, q_data, k_data, v_data, dy_data, out_data, b, nh, q_s, k_s, dh,
9567 *mask_kind, mask_data, *bhsd,
9568 );
9569 }
9570 }
9571
9572 Thunk::ActivationInPlace { data, len, act } => {
9573 let len = *len as usize;
9574 unsafe {
9575 let d = sl_mut(*data, base, len);
9576 match act {
9577 Activation::Gelu => crate::kernels::par_gelu_inplace(d),
9578 Activation::GeluApprox => crate::kernels::par_gelu_approx_inplace(d),
9579 Activation::Silu => crate::kernels::par_silu_inplace(d),
9580 Activation::Relu => {
9581 for v in d.iter_mut() {
9582 *v = v.max(0.0);
9583 }
9584 }
9585 Activation::Sigmoid => {
9586 for v in d.iter_mut() {
9587 *v = 1.0 / (1.0 + (-*v).exp());
9588 }
9589 }
9590 Activation::Tanh => {
9591 for v in d.iter_mut() {
9592 *v = v.tanh();
9593 }
9594 }
9595 Activation::Exp => {
9596 for v in d.iter_mut() {
9597 *v = v.exp();
9598 }
9599 }
9600 Activation::Log => {
9601 for v in d.iter_mut() {
9602 *v = v.ln();
9603 }
9604 }
9605 Activation::Sqrt => {
9606 for v in d.iter_mut() {
9607 *v = v.sqrt();
9608 }
9609 }
9610 Activation::Rsqrt => {
9611 for v in d.iter_mut() {
9612 *v = 1.0 / v.sqrt();
9613 }
9614 }
9615 Activation::Neg => {
9616 for v in d.iter_mut() {
9617 *v = -*v;
9618 }
9619 }
9620 Activation::Abs => {
9621 for v in d.iter_mut() {
9622 *v = v.abs();
9623 }
9624 }
9625 Activation::Round => {
9626 for v in d.iter_mut() {
9627 *v = v.round();
9628 }
9629 }
9630 Activation::Sin => {
9631 for v in d.iter_mut() {
9632 *v = v.sin();
9633 }
9634 }
9635 Activation::Cos => {
9636 for v in d.iter_mut() {
9637 *v = v.cos();
9638 }
9639 }
9640 Activation::Tan => {
9641 for v in d.iter_mut() {
9642 *v = v.tan();
9643 }
9644 }
9645 Activation::Atan => {
9646 for v in d.iter_mut() {
9647 *v = v.atan();
9648 }
9649 }
9650 }
9651 }
9652 }
9653
9654 Thunk::FusedAttnBlock {
9655 hidden,
9656 qkv_w,
9657 out_w,
9658 mask,
9659 out,
9660 qkv_b,
9661 out_b,
9662 cos,
9663 sin,
9664 cos_len,
9665 batch,
9666 seq,
9667 hs,
9668 nh,
9669 dh,
9670 has_bias,
9671 has_rope,
9672 } => {
9673 let (b, s) = (*batch as usize, *seq as usize);
9674 let (h, n_h, d_h) = (*hs as usize, *nh as usize, *dh as usize);
9675 let m = b * s;
9676 let scale = (d_h as f32).powf(-0.5);
9677 let half = d_h / 2;
9678 unsafe {
9679 let inp = sl(*hidden, base, m * h);
9680 let wq = sl(*qkv_w, base, h * 3 * h);
9681 let wo = sl(*out_w, base, h * h);
9682 let mk = sl(*mask, base, b * s);
9683 let dst = sl_mut(*out, base, m * h);
9684
9685 let mut qkv = vec![0f32; m * 3 * h];
9687 let mut attn_out = vec![0f32; m * h];
9688 let mut scores_buf = vec![0f32; s * s]; crate::blas::sgemm(inp, wq, &mut qkv, m, h, 3 * h);
9692 if *has_bias {
9693 let bias = sl(*qkv_b, base, 3 * h);
9694 crate::blas::bias_add(&mut qkv, bias, m, 3 * h);
9695 }
9696
9697 #[cfg(target_arch = "aarch64")]
9700 let neon_chunks = d_h / 4;
9701 #[cfg(target_arch = "aarch64")]
9702 let _rope_chunks = half / 4;
9703
9704 for bi in 0..b {
9705 for hi in 0..n_h {
9706 for qi in 0..s {
9708 let q_base = bi * s * 3 * h + qi * 3 * h + hi * d_h;
9709 for ki in 0..s {
9710 let k_base = bi * s * 3 * h + ki * 3 * h + h + hi * d_h;
9711 let mut dot = 0f32;
9712
9713 if *has_rope {
9714 let q_cos = qi * half;
9716 let k_cos = ki * half;
9717 let cos_tab = sl(*cos, base, *cos_len as usize);
9718 let sin_tab = sl(*sin, base, *cos_len as usize);
9719 for i in 0..half {
9722 let q1 = qkv[q_base + i];
9723 let q2 = qkv[q_base + half + i];
9724 let k1 = qkv[k_base + i];
9725 let k2 = qkv[k_base + half + i];
9726 let c_q = cos_tab[q_cos + i];
9727 let s_q = sin_tab[q_cos + i];
9728 let c_k = cos_tab[k_cos + i];
9729 let s_k = sin_tab[k_cos + i];
9730 let qr1 = q1 * c_q - q2 * s_q;
9731 let kr1 = k1 * c_k - k2 * s_k;
9732 let qr2 = q2 * c_q + q1 * s_q;
9733 let kr2 = k2 * c_k + k1 * s_k;
9734 dot += qr1 * kr1 + qr2 * kr2;
9735 }
9736 } else {
9737 #[cfg(target_arch = "aarch64")]
9739 {
9740 use std::arch::aarch64::*;
9741 let mut acc = vdupq_n_f32(0.0);
9742 for c in 0..neon_chunks {
9743 let vq =
9744 vld1q_f32(qkv.as_ptr().add(q_base + c * 4));
9745 let vk =
9746 vld1q_f32(qkv.as_ptr().add(k_base + c * 4));
9747 acc = vfmaq_f32(acc, vq, vk);
9748 }
9749 dot = vaddvq_f32(acc);
9750 for d in (neon_chunks * 4)..d_h {
9751 dot += qkv[q_base + d] * qkv[k_base + d];
9752 }
9753 }
9754 #[cfg(not(target_arch = "aarch64"))]
9755 for d in 0..d_h {
9756 dot += qkv[q_base + d] * qkv[k_base + d];
9757 }
9758 }
9759
9760 scores_buf[qi * s + ki] = dot * scale;
9761 if mk[bi * s + ki] < mask_thr {
9762 scores_buf[qi * s + ki] = mask_neg;
9763 }
9764 }
9765 }
9766
9767 crate::kernels::neon_softmax(&mut scores_buf[..s * s], s, s);
9769
9770 for qi in 0..s {
9772 let o_base = bi * s * h + qi * h + hi * d_h;
9773 for d in 0..d_h {
9774 attn_out[o_base + d] = 0.0;
9775 }
9776 for ki in 0..s {
9777 let sc = scores_buf[qi * s + ki];
9778 if sc > score_thr {
9779 let v_base = bi * s * 3 * h + ki * 3 * h + 2 * h + hi * d_h;
9780 #[cfg(target_arch = "aarch64")]
9781 {
9782 use std::arch::aarch64::*;
9783 let vsc = vdupq_n_f32(sc);
9784 for c in 0..neon_chunks {
9785 let off = c * 4;
9786 let vo =
9787 vld1q_f32(attn_out.as_ptr().add(o_base + off));
9788 let vv = vld1q_f32(qkv.as_ptr().add(v_base + off));
9789 vst1q_f32(
9790 attn_out.as_mut_ptr().add(o_base + off),
9791 vfmaq_f32(vo, vsc, vv),
9792 );
9793 }
9794 }
9795 #[cfg(not(target_arch = "aarch64"))]
9796 for d in 0..d_h {
9797 attn_out[o_base + d] += sc * qkv[v_base + d];
9798 }
9799 }
9800 }
9801 }
9802 }
9803 }
9804
9805 crate::blas::sgemm(&attn_out, wo, dst, m, h, h);
9807 if *has_bias {
9808 let bias = sl(*out_b, base, h);
9809 crate::blas::bias_add(dst, bias, m, h);
9810 }
9811 }
9812 }
9813
9814 Thunk::Rope {
9815 src,
9816 cos,
9817 sin,
9818 dst,
9819 batch,
9820 seq,
9821 hidden,
9822 head_dim,
9823 n_rot,
9824 cos_len,
9825 src_row_stride,
9826 } => {
9827 let (b, s, hs, dh, nr) = (
9828 *batch as usize,
9829 *seq as usize,
9830 *hidden as usize,
9831 *head_dim as usize,
9832 *n_rot as usize,
9833 );
9834 let tab_half = dh / 2;
9835 let rot_half = nr / 2;
9836 let nh = hs / dh;
9837 let cl = *cos_len as usize;
9838 let src_rs = *src_row_stride as usize;
9839 unsafe {
9840 let x = sl(*src, base, b * s * src_rs);
9841 let cos_tab = sl(*cos, base, cl);
9842 let sin_tab = sl(*sin, base, cl);
9843 let out = sl_mut(*dst, base, b * s * hs);
9844
9845 let total = b * s;
9846 let x_ptr = x.as_ptr() as usize;
9847 let o_ptr = out.as_mut_ptr() as usize;
9848 let c_ptr = cos_tab.as_ptr() as usize;
9849 let s_ptr = sin_tab.as_ptr() as usize;
9850
9851 crate::pool::par_for(total, 4, &|off, cnt| {
9852 for idx in off..off + cnt {
9853 let bi = idx / s;
9854 let si = idx % s;
9855 let tab_off = si * tab_half;
9856
9857 for hi in 0..nh {
9858 let src_base = bi * s * src_rs + si * src_rs + hi * dh;
9859 let dst_base = bi * s * hs + si * hs + hi * dh;
9860 let xp = (x_ptr as *const f32).add(src_base);
9861 let op = (o_ptr as *mut f32).add(dst_base);
9862 let cp = (c_ptr as *const f32).add(tab_off);
9863 let sp = (s_ptr as *const f32).add(tab_off);
9864
9865 for i in 0..rot_half {
9866 let x1 = *xp.add(i);
9867 let x2 = *xp.add(rot_half + i);
9868 let cv = *cp.add(i);
9869 let sv = *sp.add(i);
9870 *op.add(i) = x1 * cv - x2 * sv;
9871 *op.add(rot_half + i) = x2 * cv + x1 * sv;
9872 }
9873 for j in nr..dh {
9874 *op.add(j) = *xp.add(j);
9875 }
9876 }
9877 }
9878 });
9879 }
9880 }
9881 Thunk::FusedBertLayer {
9882 hidden,
9883 qkv_w,
9884 qkv_b,
9885 out_w,
9886 out_b,
9887 mask,
9888 ln1_g,
9889 ln1_b,
9890 eps1,
9891 fc1_w,
9892 fc1_b,
9893 fc2_w,
9894 fc2_b,
9895 ln2_g,
9896 ln2_b,
9897 eps2,
9898 out,
9899 batch,
9900 seq,
9901 hs,
9902 nh,
9903 dh,
9904 int_dim,
9905 } => {
9906 let (b, s, h, n_h, d_h) = (
9907 *batch as usize,
9908 *seq as usize,
9909 *hs as usize,
9910 *nh as usize,
9911 *dh as usize,
9912 );
9913 let m = b * s;
9914 let id = *int_dim as usize;
9915 let scale = (d_h as f32).powf(-0.5);
9916 let _half = d_h / 2;
9917 #[cfg(target_arch = "aarch64")]
9918 let neon_chunks = d_h / 4;
9919 unsafe {
9920 let inp = sl(*hidden, base, m * h);
9921 let dst = sl_mut(*out, base, m * h);
9922 let mk = sl(*mask, base, b * s);
9923
9924 let qkv = std::slice::from_raw_parts_mut(fl_qkv.as_mut_ptr(), m * 3 * h);
9926 let attn = std::slice::from_raw_parts_mut(fl_attn.as_mut_ptr(), m * h);
9927 let res = std::slice::from_raw_parts_mut(fl_res.as_mut_ptr(), m * h);
9928 let normed = std::slice::from_raw_parts_mut(fl_normed.as_mut_ptr(), m * h);
9929 let ffn = std::slice::from_raw_parts_mut(fl_ffn.as_mut_ptr(), m * id);
9930 let sc = std::slice::from_raw_parts_mut(fl_sc.as_mut_ptr(), s * s);
9931
9932 crate::blas::par_sgemm_bias(
9934 inp,
9935 sl(*qkv_w, base, h * 3 * h),
9936 sl(*qkv_b, base, 3 * h),
9937 qkv,
9938 m,
9939 h,
9940 3 * h,
9941 );
9942
9943 for bi in 0..b {
9945 for hi in 0..n_h {
9946 for qi in 0..s {
9947 for ki in 0..s {
9948 let q_base = bi * s * 3 * h + qi * 3 * h + hi * d_h;
9949 let k_base = bi * s * 3 * h + ki * 3 * h + h + hi * d_h;
9950 #[cfg(target_arch = "aarch64")]
9951 let dot;
9952 #[cfg(not(target_arch = "aarch64"))]
9953 let mut dot = 0f32;
9954 #[cfg(target_arch = "aarch64")]
9955 {
9956 use std::arch::aarch64::*;
9957 let mut acc = vdupq_n_f32(0.0);
9958 for c in 0..neon_chunks {
9959 acc = vfmaq_f32(
9960 acc,
9961 vld1q_f32(qkv.as_ptr().add(q_base + c * 4)),
9962 vld1q_f32(qkv.as_ptr().add(k_base + c * 4)),
9963 );
9964 }
9965 dot = vaddvq_f32(acc);
9966 }
9967 #[cfg(not(target_arch = "aarch64"))]
9968 for d in 0..d_h {
9969 dot += qkv[q_base + d] * qkv[k_base + d];
9970 }
9971 sc[qi * s + ki] = dot * scale;
9972 if mk[bi * s + ki] < mask_thr {
9973 sc[qi * s + ki] = mask_neg;
9974 }
9975 }
9976 }
9977 crate::kernels::neon_softmax(&mut sc[..s * s], s, s);
9978 for qi in 0..s {
9979 let o = bi * s * h + qi * h + hi * d_h;
9980 for d in 0..d_h {
9981 attn[o + d] = 0.0;
9982 }
9983 for ki in 0..s {
9984 let w = sc[qi * s + ki];
9985 if w > score_thr {
9986 let v = bi * s * 3 * h + ki * 3 * h + 2 * h + hi * d_h;
9987 #[cfg(target_arch = "aarch64")]
9988 {
9989 use std::arch::aarch64::*;
9990 let vw = vdupq_n_f32(w);
9991 for c in 0..neon_chunks {
9992 let off = c * 4;
9993 vst1q_f32(
9994 attn.as_mut_ptr().add(o + off),
9995 vfmaq_f32(
9996 vld1q_f32(attn.as_ptr().add(o + off)),
9997 vw,
9998 vld1q_f32(qkv.as_ptr().add(v + off)),
9999 ),
10000 );
10001 }
10002 }
10003 #[cfg(not(target_arch = "aarch64"))]
10004 for d in 0..d_h {
10005 attn[o + d] += w * qkv[v + d];
10006 }
10007 }
10008 }
10009 }
10010 }
10011 }
10012
10013 crate::blas::sgemm_bias(
10015 attn,
10016 sl(*out_w, base, h * h),
10017 sl(*out_b, base, h),
10018 res,
10019 m,
10020 h,
10021 h,
10022 );
10023 #[cfg(target_arch = "aarch64")]
10024 {
10025 use std::arch::aarch64::*;
10026 let chunks_h = (m * h) / 4;
10027 for c in 0..chunks_h {
10028 let off = c * 4;
10029 vst1q_f32(
10030 res.as_mut_ptr().add(off),
10031 vaddq_f32(
10032 vld1q_f32(res.as_ptr().add(off)),
10033 vld1q_f32(inp.as_ptr().add(off)),
10034 ),
10035 );
10036 }
10037 for i in (chunks_h * 4)..(m * h) {
10038 res[i] += inp[i];
10039 }
10040 }
10041 #[cfg(not(target_arch = "aarch64"))]
10042 for i in 0..m * h {
10043 res[i] += inp[i];
10044 }
10045
10046 let g1 = sl(*ln1_g, base, h);
10048 let b1 = sl(*ln1_b, base, h);
10049 for r in 0..m {
10050 crate::kernels::layer_norm_row(
10051 &res[r * h..(r + 1) * h],
10052 g1,
10053 b1,
10054 &mut normed[r * h..(r + 1) * h],
10055 h,
10056 *eps1,
10057 );
10058 }
10059
10060 crate::blas::par_sgemm_bias(
10062 normed,
10063 sl(*fc1_w, base, h * id),
10064 sl(*fc1_b, base, id),
10065 ffn,
10066 m,
10067 h,
10068 id,
10069 );
10070 crate::kernels::par_gelu_inplace(ffn);
10071
10072 crate::blas::par_sgemm_bias(
10074 ffn,
10075 sl(*fc2_w, base, id * h),
10076 sl(*fc2_b, base, h),
10077 res,
10078 m,
10079 id,
10080 h,
10081 );
10082 #[cfg(target_arch = "aarch64")]
10083 {
10084 use std::arch::aarch64::*;
10085 let chunks_h = (m * h) / 4;
10086 for c in 0..chunks_h {
10087 let off = c * 4;
10088 vst1q_f32(
10089 res.as_mut_ptr().add(off),
10090 vaddq_f32(
10091 vld1q_f32(res.as_ptr().add(off)),
10092 vld1q_f32(normed.as_ptr().add(off)),
10093 ),
10094 );
10095 }
10096 for i in (chunks_h * 4)..(m * h) {
10097 res[i] += normed[i];
10098 }
10099 }
10100 #[cfg(not(target_arch = "aarch64"))]
10101 for i in 0..m * h {
10102 res[i] += normed[i];
10103 }
10104
10105 let g2 = sl(*ln2_g, base, h);
10107 let b2 = sl(*ln2_b, base, h);
10108 for r in 0..m {
10109 crate::kernels::layer_norm_row(
10110 &res[r * h..(r + 1) * h],
10111 g2,
10112 b2,
10113 &mut dst[r * h..(r + 1) * h],
10114 h,
10115 *eps2,
10116 );
10117 }
10118 }
10119 }
10120
10121 Thunk::FusedNomicLayer {
10122 hidden,
10123 qkv_w,
10124 out_w,
10125 mask,
10126 cos,
10127 sin,
10128 cos_len,
10129 ln1_g,
10130 ln1_b,
10131 eps1,
10132 fc11_w,
10133 fc12_w: _,
10134 fc2_w,
10135 ln2_g,
10136 ln2_b,
10137 eps2,
10138 out,
10139 batch,
10140 seq,
10141 hs,
10142 nh,
10143 dh,
10144 int_dim,
10145 } => {
10146 let (b, s, h, n_h, d_h) = (
10147 *batch as usize,
10148 *seq as usize,
10149 *hs as usize,
10150 *nh as usize,
10151 *dh as usize,
10152 );
10153 let m = b * s;
10154 let id = *int_dim as usize;
10155 let scale = (d_h as f32).powf(-0.5);
10156 let half_dh = d_h / 2;
10157 #[cfg(target_arch = "aarch64")]
10158 let neon_chunks = d_h / 4;
10159 unsafe {
10160 let inp = sl(*hidden, base, m * h);
10161 let dst = sl_mut(*out, base, m * h);
10162 let mk = sl(*mask, base, b * s);
10163 let cos_tab = sl(*cos, base, *cos_len as usize);
10164 let sin_tab = sl(*sin, base, *cos_len as usize);
10165 let fused_fc_w = sl(*fc11_w, base, h * 2 * id);
10167
10168 let mut qkv = vec![0f32; m * 3 * h];
10169 let mut attn = vec![0f32; m * h];
10170 let mut res = vec![0f32; m * h];
10171 let mut normed = vec![0f32; m * h];
10172 let mut ffn_concat = vec![0f32; m * 2 * id]; let mut sc = vec![0f32; s * s];
10174
10175 crate::blas::sgemm(inp, sl(*qkv_w, base, h * 3 * h), &mut qkv, m, h, 3 * h);
10177
10178 for bi in 0..b {
10180 for hi in 0..n_h {
10181 for qi in 0..s {
10182 for ki in 0..s {
10183 let q_base = bi * s * 3 * h + qi * 3 * h + hi * d_h;
10184 let k_base = bi * s * 3 * h + ki * 3 * h + h + hi * d_h;
10185 let mut dot = 0f32;
10186 for i in 0..half_dh {
10187 let q1 = qkv[q_base + i];
10188 let q2 = qkv[q_base + half_dh + i];
10189 let k1 = qkv[k_base + i];
10190 let k2 = qkv[k_base + half_dh + i];
10191 let cq = cos_tab[qi * half_dh + i];
10192 let sq = sin_tab[qi * half_dh + i];
10193 let ck = cos_tab[ki * half_dh + i];
10194 let sk = sin_tab[ki * half_dh + i];
10195 dot += (q1 * cq - q2 * sq) * (k1 * ck - k2 * sk)
10196 + (q2 * cq + q1 * sq) * (k2 * ck + k1 * sk);
10197 }
10198 sc[qi * s + ki] = dot * scale;
10199 if mk[bi * s + ki] < mask_thr {
10200 sc[qi * s + ki] = mask_neg;
10201 }
10202 }
10203 }
10204 crate::kernels::neon_softmax(&mut sc[..s * s], s, s);
10205 for qi in 0..s {
10206 let o = bi * s * h + qi * h + hi * d_h;
10207 for d in 0..d_h {
10208 attn[o + d] = 0.0;
10209 }
10210 for ki in 0..s {
10211 let w = sc[qi * s + ki];
10212 if w > score_thr {
10213 let v = bi * s * 3 * h + ki * 3 * h + 2 * h + hi * d_h;
10214 #[cfg(target_arch = "aarch64")]
10215 {
10216 use std::arch::aarch64::*;
10217 let vw = vdupq_n_f32(w);
10218 for c in 0..neon_chunks {
10219 let off = c * 4;
10220 vst1q_f32(
10221 attn.as_mut_ptr().add(o + off),
10222 vfmaq_f32(
10223 vld1q_f32(attn.as_ptr().add(o + off)),
10224 vw,
10225 vld1q_f32(qkv.as_ptr().add(v + off)),
10226 ),
10227 );
10228 }
10229 }
10230 #[cfg(not(target_arch = "aarch64"))]
10231 for d in 0..d_h {
10232 attn[o + d] += w * qkv[v + d];
10233 }
10234 }
10235 }
10236 }
10237 }
10238 }
10239
10240 crate::blas::sgemm(&attn, sl(*out_w, base, h * h), &mut res, m, h, h);
10242 for i in 0..m * h {
10243 res[i] += inp[i];
10244 }
10245
10246 let g1 = sl(*ln1_g, base, h);
10248 let b1 = sl(*ln1_b, base, h);
10249 for r in 0..m {
10250 crate::kernels::layer_norm_row(
10251 &res[r * h..(r + 1) * h],
10252 g1,
10253 b1,
10254 &mut normed[r * h..(r + 1) * h],
10255 h,
10256 *eps1,
10257 );
10258 }
10259
10260 crate::blas::sgemm(&normed, fused_fc_w, &mut ffn_concat, m, h, 2 * id);
10262 for row in 0..m {
10265 let bo = row * 2 * id;
10266 for j in 0..id {
10268 let x = ffn_concat[bo + id + j];
10269 ffn_concat[bo + id + j] = x / (1.0 + (-x).exp());
10270 }
10271 for j in 0..id {
10273 ffn_concat[bo + j] *= ffn_concat[bo + id + j];
10274 }
10275 }
10276
10277 crate::blas::sgemm_general(
10282 ffn_concat.as_ptr(),
10283 sl(*fc2_w, base, id * h).as_ptr(),
10284 res.as_mut_ptr(),
10285 m,
10286 h,
10287 id,
10288 1.0,
10289 0.0,
10290 2 * id,
10291 h,
10292 h,
10293 false,
10294 false,
10295 );
10296 for i in 0..m * h {
10297 res[i] += normed[i];
10298 }
10299
10300 let g2 = sl(*ln2_g, base, h);
10302 let b2 = sl(*ln2_b, base, h);
10303 for r in 0..m {
10304 crate::kernels::layer_norm_row(
10305 &res[r * h..(r + 1) * h],
10306 g2,
10307 b2,
10308 &mut dst[r * h..(r + 1) * h],
10309 h,
10310 *eps2,
10311 );
10312 }
10313 }
10314 }
10315
10316 Thunk::FusedSwiGLU {
10317 src,
10318 dst,
10319 n_half,
10320 total,
10321 gate_first,
10322 } => {
10323 let n = *n_half as usize;
10324 let t = *total as usize;
10325 let outer = t / n;
10326 let in_total = outer * 2 * n;
10327 let gate_first = *gate_first;
10328 unsafe {
10329 let inp = sl(*src, base, in_total);
10330 let out = sl_mut(*dst, base, t);
10331 for o in 0..outer {
10332 let in_row = &inp[o * 2 * n..(o + 1) * 2 * n];
10333 let out_row = &mut out[o * n..(o + 1) * n];
10334 for i in 0..n {
10335 let (up, gate) = if gate_first {
10336 (in_row[n + i], in_row[i])
10337 } else {
10338 (in_row[i], in_row[n + i])
10339 };
10340 out_row[i] = up * (gate / (1.0 + (-gate).exp()));
10341 }
10342 }
10343 }
10344 }
10345
10346 Thunk::Concat {
10347 dst,
10348 outer,
10349 inner,
10350 total_axis,
10351 inputs,
10352 } => {
10353 let outer = *outer as usize;
10354 let inner = *inner as usize;
10355 let total_axis = *total_axis as usize;
10356 let row_stride = total_axis * inner;
10357 let out_total = outer * row_stride;
10358 unsafe {
10359 let out = sl_mut(*dst, base, out_total);
10360 let mut cum: usize = 0;
10361 for (src_off, in_axis) in inputs {
10362 let in_axis = *in_axis as usize;
10363 let copy_per_row = in_axis * inner;
10364 let dst_col_off = cum * inner;
10365 let in_total = outer * copy_per_row;
10366 let inp = sl(*src_off, base, in_total);
10367 for o in 0..outer {
10368 let dst_row_start = o * row_stride + dst_col_off;
10369 let src_row_start = o * copy_per_row;
10370 out[dst_row_start..dst_row_start + copy_per_row]
10371 .copy_from_slice(&inp[src_row_start..src_row_start + copy_per_row]);
10372 }
10373 cum += in_axis;
10374 }
10375 }
10376 }
10377
10378 Thunk::ConcatF64 {
10379 dst,
10380 outer,
10381 inner,
10382 total_axis,
10383 inputs,
10384 } => {
10385 let outer = *outer as usize;
10386 let inner = *inner as usize;
10387 let total_axis = *total_axis as usize;
10388 let row_stride = total_axis * inner;
10389 let out_total = outer * row_stride;
10390 unsafe {
10391 let out = sl_mut_f64(*dst, base, out_total);
10392 let mut cum: usize = 0;
10393 for (src_off, in_axis) in inputs {
10394 let in_axis = *in_axis as usize;
10395 let copy_per_row = in_axis * inner;
10396 let dst_col_off = cum * inner;
10397 let in_total = outer * copy_per_row;
10398 let inp = sl_f64(*src_off, base, in_total);
10399 for o in 0..outer {
10400 let dst_row_start = o * row_stride + dst_col_off;
10401 let src_row_start = o * copy_per_row;
10402 out[dst_row_start..dst_row_start + copy_per_row]
10403 .copy_from_slice(&inp[src_row_start..src_row_start + copy_per_row]);
10404 }
10405 cum += in_axis;
10406 }
10407 }
10408 }
10409
10410 Thunk::Compare {
10411 lhs,
10412 rhs,
10413 dst,
10414 len,
10415 op,
10416 } => {
10417 let len = *len as usize;
10418 unsafe {
10419 let l = sl(*lhs, base, len);
10420 let r = sl(*rhs, base, len);
10421 let o = sl_mut(*dst, base, len);
10422 for i in 0..len {
10423 o[i] = match op {
10424 CmpOp::Eq => (l[i] == r[i]) as u32 as f32,
10425 CmpOp::Ne => (l[i] != r[i]) as u32 as f32,
10426 CmpOp::Lt => (l[i] < r[i]) as u32 as f32,
10427 CmpOp::Le => (l[i] <= r[i]) as u32 as f32,
10428 CmpOp::Gt => (l[i] > r[i]) as u32 as f32,
10429 CmpOp::Ge => (l[i] >= r[i]) as u32 as f32,
10430 };
10431 }
10432 }
10433 }
10434
10435 Thunk::Where {
10436 cond,
10437 on_true,
10438 on_false,
10439 dst,
10440 len,
10441 } => {
10442 let len = *len as usize;
10443 unsafe {
10444 let c = sl(*cond, base, len);
10445 let t = sl(*on_true, base, len);
10446 let e = sl(*on_false, base, len);
10447 let o = sl_mut(*dst, base, len);
10448 for i in 0..len {
10449 o[i] = if c[i] != 0.0 { t[i] } else { e[i] };
10451 }
10452 }
10453 }
10454
10455 Thunk::ScatterAdd {
10456 updates,
10457 indices,
10458 dst,
10459 num_updates,
10460 out_dim,
10461 trailing,
10462 } => {
10463 let num_updates = *num_updates as usize;
10464 let out_dim = *out_dim as usize;
10465 let trailing = *trailing as usize;
10466 unsafe {
10467 let upd = sl(*updates, base, num_updates * trailing);
10468 let ids = sl(*indices, base, num_updates);
10469 let out = sl_mut(*dst, base, out_dim * trailing);
10470 for v in out.iter_mut() {
10472 *v = 0.0;
10473 }
10474 for i in 0..num_updates {
10475 let row = ids[i] as usize;
10476 debug_assert!(row < out_dim, "ScatterAdd index out of range");
10477 let src_off = i * trailing;
10478 let dst_off = row * trailing;
10479 for j in 0..trailing {
10480 out[dst_off + j] += upd[src_off + j];
10481 }
10482 }
10483 }
10484 }
10485
10486 Thunk::GroupedMatMul {
10487 input,
10488 weight,
10489 expert_idx,
10490 dst,
10491 m,
10492 k_dim,
10493 n,
10494 num_experts,
10495 } => {
10496 let m = *m as usize;
10497 let k_dim = *k_dim as usize;
10498 let n = *n as usize;
10499 let num_experts = *num_experts as usize;
10500 unsafe {
10501 let inp = sl(*input, base, m * k_dim);
10502 let wt = sl(*weight, base, num_experts * k_dim * n);
10503 let ids = sl(*expert_idx, base, m);
10504 let out = sl_mut(*dst, base, m * n);
10505
10506 let mut counts = vec![0usize; num_experts];
10509 for i in 0..m {
10510 let e = ids[i] as usize;
10511 debug_assert!(
10512 e < num_experts,
10513 "expert_idx out of range: {e} >= {num_experts}"
10514 );
10515 counts[e] += 1;
10516 }
10517 let mut offsets = vec![0usize; num_experts + 1];
10519 for e in 0..num_experts {
10520 offsets[e + 1] = offsets[e] + counts[e];
10521 }
10522 let mut packed_in = vec![0f32; m * k_dim];
10526 let mut original_pos = vec![0usize; m];
10527 let mut write_idx = vec![0usize; num_experts];
10528 for i in 0..m {
10529 let e = ids[i] as usize;
10530 let dst_row = offsets[e] + write_idx[e];
10531 packed_in[dst_row * k_dim..(dst_row + 1) * k_dim]
10532 .copy_from_slice(&inp[i * k_dim..(i + 1) * k_dim]);
10533 original_pos[dst_row] = i;
10534 write_idx[e] += 1;
10535 }
10536
10537 let mut packed_out = vec![0f32; m * n];
10541 let expert_stride = k_dim * n;
10542 let gmm_ord = crate::moe_residency::next_gmm_ord();
10543 let moe_layer = gmm_ord / 3;
10544 for e in 0..num_experts {
10545 let count = counts[e];
10546 if count == 0 {
10547 continue;
10548 }
10549 crate::moe_residency::record_expert_tokens(moe_layer, e, count);
10550 let in_start = offsets[e];
10551 let in_slice = &packed_in[in_start * k_dim..(in_start + count) * k_dim];
10552 let w_slab: &[f32] =
10553 if !crate::moe_residency::expert_on_device_for_layer(moe_layer, e) {
10554 if let Some(ptr) =
10555 crate::moe_residency::host_expert_weight_ptr(gmm_ord, e)
10556 {
10557 std::slice::from_raw_parts(ptr, expert_stride)
10558 } else {
10559 &wt[e * expert_stride..(e + 1) * expert_stride]
10560 }
10561 } else {
10562 &wt[e * expert_stride..(e + 1) * expert_stride]
10563 };
10564 let out_slice = &mut packed_out[in_start * n..(in_start + count) * n];
10565 crate::blas::sgemm(in_slice, w_slab, out_slice, count, k_dim, n);
10566 }
10567
10568 for packed_idx in 0..m {
10570 let i = original_pos[packed_idx];
10571 out[i * n..(i + 1) * n]
10572 .copy_from_slice(&packed_out[packed_idx * n..(packed_idx + 1) * n]);
10573 }
10574 }
10575 }
10576
10577 Thunk::DequantGroupedMatMulGguf {
10578 input,
10579 w_q,
10580 expert_idx,
10581 dst,
10582 m,
10583 k_dim,
10584 n,
10585 num_experts,
10586 scheme,
10587 } => {
10588 let m = *m as usize;
10589 let k_dim = *k_dim as usize;
10590 let n = *n as usize;
10591 let num_experts = *num_experts as usize;
10592 let block_elems = scheme.gguf_block_size() as usize;
10593 let block_bytes = scheme.gguf_block_bytes() as usize;
10594 let slab_bytes = (k_dim * n) / block_elems * block_bytes;
10595 unsafe {
10596 let inp = sl(*input, base, m * k_dim);
10597 let wt = std::slice::from_raw_parts(
10598 base.add(*w_q) as *const u8,
10599 num_experts * slab_bytes,
10600 );
10601 let ids = sl(*expert_idx, base, m);
10602 let out = sl_mut(*dst, base, m * n);
10603 crate::gguf_matmul::gguf_grouped_matmul_bt(
10604 inp,
10605 wt,
10606 ids,
10607 out,
10608 m,
10609 k_dim,
10610 n,
10611 num_experts,
10612 *scheme,
10613 );
10614 }
10615 }
10616
10617 Thunk::DequantMoEWeightsGguf {
10618 w_q,
10619 dst,
10620 k_dim,
10621 n,
10622 num_experts,
10623 scheme,
10624 } => {
10625 let k_dim = *k_dim as usize;
10626 let n = *n as usize;
10627 let num_experts = *num_experts as usize;
10628 let block_elems = scheme.gguf_block_size() as usize;
10629 let block_bytes = scheme.gguf_block_bytes() as usize;
10630 let slab_bytes = (k_dim * n) / block_elems * block_bytes;
10631 unsafe {
10632 let wt = std::slice::from_raw_parts(
10633 base.add(*w_q) as *const u8,
10634 num_experts * slab_bytes,
10635 );
10636 let out = sl_mut(*dst, base, num_experts * k_dim * n);
10637 crate::gguf_matmul::dequant_moe_weights_to_grouped_f32(
10638 wt,
10639 out,
10640 num_experts,
10641 k_dim,
10642 n,
10643 *scheme,
10644 );
10645 }
10646 }
10647
10648 Thunk::TopK {
10649 src,
10650 dst,
10651 outer,
10652 axis_dim,
10653 k,
10654 } => {
10655 let outer = *outer as usize;
10656 let axis_dim = *axis_dim as usize;
10657 let k = *k as usize;
10658 unsafe {
10659 let inp = sl(*src, base, outer * axis_dim);
10660 let out = sl_mut(*dst, base, outer * k);
10661 let mut row_buf: Vec<f32> = vec![0.0; axis_dim];
10665 for o in 0..outer {
10666 row_buf.copy_from_slice(&inp[o * axis_dim..(o + 1) * axis_dim]);
10667 for ki in 0..k {
10668 let mut best_i = 0usize;
10670 let mut best_v = row_buf[0];
10671 for i in 1..axis_dim {
10672 let v = row_buf[i];
10673 if v > best_v {
10674 best_v = v;
10675 best_i = i;
10676 }
10677 }
10678 out[o * k + ki] = best_i as f32;
10679 row_buf[best_i] = f32::NEG_INFINITY;
10682 }
10683 }
10684 if let Some(cap) = schedule.moe_topk_capture.as_ref() {
10685 cap.push_topk_f32(&out[..outer * k], axis_dim);
10686 }
10687 }
10688 }
10689
10690 Thunk::Reduce {
10691 src,
10692 dst,
10693 outer,
10694 reduced,
10695 inner,
10696 op,
10697 } => {
10698 let outer = *outer as usize;
10699 let reduced = *reduced as usize;
10700 let inner = *inner as usize;
10701 let in_total = outer * reduced * inner;
10702 let out_total = outer * inner;
10703 unsafe {
10704 let inp = sl(*src, base, in_total);
10705 let out = sl_mut(*dst, base, out_total);
10706 for o in 0..outer {
10707 for i in 0..inner {
10708 let mut acc = match op {
10709 ReduceOp::Max => f32::NEG_INFINITY,
10710 ReduceOp::Min => f32::INFINITY,
10711 ReduceOp::Prod => 1.0f32,
10712 _ => 0.0f32, };
10714 for r in 0..reduced {
10716 let v = inp[o * reduced * inner + r * inner + i];
10717 acc = match op {
10718 ReduceOp::Sum | ReduceOp::Mean => acc + v,
10719 ReduceOp::Max => acc.max(v),
10720 ReduceOp::Min => acc.min(v),
10721 ReduceOp::Prod => acc * v,
10722 };
10723 }
10724 if matches!(op, ReduceOp::Mean) {
10725 acc /= reduced as f32;
10726 }
10727 out[o * inner + i] = acc;
10728 }
10729 }
10730 }
10731 }
10732
10733 Thunk::Conv2D1x1 {
10734 src,
10735 weight,
10736 dst,
10737 n,
10738 c_in,
10739 c_out,
10740 hw,
10741 } => {
10742 let n = *n as usize;
10743 let c_in = *c_in as usize;
10744 let c_out = *c_out as usize;
10745 let hw = *hw as usize;
10746 unsafe {
10747 let inp = sl(*src, base, n * c_in * hw);
10748 let wt = sl(*weight, base, c_out * c_in);
10749 let out = sl_mut(*dst, base, n * c_out * hw);
10750 for ni in 0..n {
10755 let in_off = ni * c_in * hw;
10756 let out_off = ni * c_out * hw;
10757 crate::blas::sgemm(
10758 wt,
10759 &inp[in_off..in_off + c_in * hw],
10760 &mut out[out_off..out_off + c_out * hw],
10761 c_out,
10762 c_in,
10763 hw,
10764 );
10765 }
10766 }
10767 }
10768
10769 Thunk::Conv2D {
10770 src,
10771 weight,
10772 dst,
10773 n,
10774 c_in,
10775 h,
10776 w,
10777 c_out,
10778 h_out,
10779 w_out,
10780 kh,
10781 kw,
10782 sh,
10783 sw,
10784 ph,
10785 pw,
10786 dh,
10787 dw,
10788 groups,
10789 } => {
10790 let n = *n as usize;
10791 let c_in = *c_in as usize;
10792 let h = *h as usize;
10793 let w = *w as usize;
10794 let c_out = *c_out as usize;
10795 let h_out = *h_out as usize;
10796 let w_out = *w_out as usize;
10797 let kh = *kh as usize;
10798 let kw = *kw as usize;
10799 let sh = *sh as usize;
10800 let sw = *sw as usize;
10801 let ph = *ph as usize;
10802 let pw = *pw as usize;
10803 let dh = *dh as usize;
10804 let dw = *dw as usize;
10805 let groups = *groups as usize;
10806 let c_in_per_g = c_in / groups;
10807 let c_out_per_g = c_out / groups;
10808 unsafe {
10809 let inp = sl(*src, base, n * c_in * h * w);
10810 let wt = sl(*weight, base, c_out * c_in_per_g * kh * kw);
10811 let out = sl_mut(*dst, base, n * c_out * h_out * w_out);
10812 for ni in 0..n {
10813 for co in 0..c_out {
10814 let g = co / c_out_per_g;
10815 let ci_start = g * c_in_per_g;
10816 for ho in 0..h_out {
10817 for wo in 0..w_out {
10818 let mut acc = 0f32;
10819 for ci_off in 0..c_in_per_g {
10820 let ci = ci_start + ci_off;
10821 let in_chan = ((ni * c_in) + ci) * h * w;
10822 let wt_chan = ((co * c_in_per_g) + ci_off) * kh * kw;
10823 for ki in 0..kh {
10824 for kj in 0..kw {
10825 let hi = ho * sh + ki * dh;
10826 let wi = wo * sw + kj * dw;
10827 if hi < ph || wi < pw {
10828 continue;
10829 }
10830 let hi = hi - ph;
10831 let wi = wi - pw;
10832 if hi >= h || wi >= w {
10833 continue;
10834 }
10835 acc += inp[in_chan + hi * w + wi]
10836 * wt[wt_chan + ki * kw + kj];
10837 }
10838 }
10839 }
10840 out[((ni * c_out) + co) * h_out * w_out + ho * w_out + wo] =
10841 acc;
10842 }
10843 }
10844 }
10845 }
10846 }
10847 }
10848
10849 Thunk::Pool2D {
10850 src,
10851 dst,
10852 n,
10853 c,
10854 h,
10855 w,
10856 h_out,
10857 w_out,
10858 kh,
10859 kw,
10860 sh,
10861 sw,
10862 ph,
10863 pw,
10864 kind,
10865 } => {
10866 let n = *n as usize;
10867 let c = *c as usize;
10868 let h = *h as usize;
10869 let w = *w as usize;
10870 let h_out = *h_out as usize;
10871 let w_out = *w_out as usize;
10872 let kh = *kh as usize;
10873 let kw = *kw as usize;
10874 let sh = *sh as usize;
10875 let sw = *sw as usize;
10876 let ph = *ph as usize;
10877 let pw = *pw as usize;
10878 let kernel_area = (kh * kw) as f32;
10879 unsafe {
10880 let inp = sl(*src, base, n * c * h * w);
10881 let out = sl_mut(*dst, base, n * c * h_out * w_out);
10882 for ni in 0..n {
10883 for ci in 0..c {
10884 let in_chan = ni * c * h * w + ci * h * w;
10885 let out_chan = ni * c * h_out * w_out + ci * h_out * w_out;
10886 for ho in 0..h_out {
10887 for wo in 0..w_out {
10888 let mut acc = match kind {
10889 ReduceOp::Max => f32::NEG_INFINITY,
10890 _ => 0f32, };
10892 for ki in 0..kh {
10893 for kj in 0..kw {
10894 let hi = ho * sh + ki;
10895 let wi = wo * sw + kj;
10896 if hi < ph || wi < pw {
10898 continue;
10899 }
10900 let hi = hi - ph;
10901 let wi = wi - pw;
10902 if hi >= h || wi >= w {
10903 continue;
10904 }
10905 let v = inp[in_chan + hi * w + wi];
10906 match kind {
10907 ReduceOp::Max => acc = acc.max(v),
10908 _ => acc += v,
10909 }
10910 }
10911 }
10912 if matches!(kind, ReduceOp::Mean) {
10913 acc /= kernel_area;
10914 }
10915 out[out_chan + ho * w_out + wo] = acc;
10916 }
10917 }
10918 }
10919 }
10920 }
10921 }
10922
10923 Thunk::ReluBackward { x, dy, dx, len } => {
10924 let len = *len as usize;
10925 unsafe {
10926 let xs = sl(*x, base, len);
10927 let dys = sl(*dy, base, len);
10928 let out = sl_mut(*dx, base, len);
10929 for i in 0..len {
10930 out[i] = if xs[i] > 0.0 { dys[i] } else { 0.0 };
10931 }
10932 }
10933 }
10934
10935 Thunk::ReluBackwardF64 { x, dy, dx, len } => {
10936 let len = *len as usize;
10937 unsafe {
10938 let xs = sl_f64(*x, base, len);
10939 let dys = sl_f64(*dy, base, len);
10940 let out = sl_mut_f64(*dx, base, len);
10941 for i in 0..len {
10942 out[i] = if xs[i] > 0.0 { dys[i] } else { 0.0 };
10943 }
10944 }
10945 }
10946
10947 Thunk::QMatMul {
10948 x,
10949 w,
10950 bias,
10951 out,
10952 m,
10953 k,
10954 n,
10955 x_zp,
10956 w_zp,
10957 out_zp,
10958 mult,
10959 } => {
10960 let m = *m as usize;
10961 let k = *k as usize;
10962 let n = *n as usize;
10963 unsafe {
10964 let x_ptr = base.add(*x) as *const i8;
10965 let w_ptr = base.add(*w) as *const i8;
10966 let bias_ptr = base.add(*bias) as *const i32;
10967 let out_ptr = base.add(*out) as *mut i8;
10968 for mi in 0..m {
10969 for ni in 0..n {
10970 let mut acc: i32 = *bias_ptr.add(ni);
10971 for ki in 0..k {
10972 let xv = *x_ptr.add(mi * k + ki) as i32 - *x_zp;
10973 let wv = *w_ptr.add(ki * n + ni) as i32 - *w_zp;
10974 acc += xv * wv;
10975 }
10976 let r = (acc as f32 * *mult).round() as i32 + *out_zp;
10979 let r = r.clamp(-128, 127) as i8;
10980 *out_ptr.add(mi * n + ni) = r;
10981 }
10982 }
10983 }
10984 }
10985
10986 Thunk::QConv2d {
10987 x,
10988 w,
10989 bias,
10990 out,
10991 n,
10992 c_in,
10993 h,
10994 w_in,
10995 c_out,
10996 h_out,
10997 w_out,
10998 kh,
10999 kw,
11000 sh,
11001 sw,
11002 ph,
11003 pw,
11004 dh,
11005 dw,
11006 groups,
11007 x_zp,
11008 w_zp,
11009 out_zp,
11010 mult,
11011 } => {
11012 let n = *n as usize;
11013 let c_in = *c_in as usize;
11014 let h = *h as usize;
11015 let w_in = *w_in as usize;
11016 let c_out = *c_out as usize;
11017 let h_out = *h_out as usize;
11018 let w_out = *w_out as usize;
11019 let kh = *kh as usize;
11020 let kw = *kw as usize;
11021 let sh = *sh as usize;
11022 let sw = *sw as usize;
11023 let ph = *ph as usize;
11024 let pw = *pw as usize;
11025 let dh = *dh as usize;
11026 let dw = *dw as usize;
11027 let groups = *groups as usize;
11028 let c_in_per_g = c_in / groups;
11029 let c_out_per_g = c_out / groups;
11030 unsafe {
11031 let x_ptr = base.add(*x) as *const i8;
11032 let w_ptr = base.add(*w) as *const i8;
11033 let bias_ptr = base.add(*bias) as *const i32;
11034 let out_ptr = base.add(*out) as *mut i8;
11035 for ni in 0..n {
11036 for co in 0..c_out {
11037 let g = co / c_out_per_g;
11038 let ci_start = g * c_in_per_g;
11039 for ho in 0..h_out {
11040 for wo in 0..w_out {
11041 let mut acc: i32 = *bias_ptr.add(co);
11042 for ci_off in 0..c_in_per_g {
11043 let ci = ci_start + ci_off;
11044 let in_chan = ((ni * c_in) + ci) * h * w_in;
11045 let wt_chan = ((co * c_in_per_g) + ci_off) * kh * kw;
11046 for ki in 0..kh {
11047 for kj in 0..kw {
11048 let hi = ho * sh + ki * dh;
11049 let wi = wo * sw + kj * dw;
11050 if hi < ph || wi < pw {
11051 continue;
11052 }
11053 let hi = hi - ph;
11054 let wi = wi - pw;
11055 if hi >= h || wi >= w_in {
11056 continue;
11057 }
11058 let xv = *x_ptr.add(in_chan + hi * w_in + wi)
11059 as i32
11060 - *x_zp;
11061 let wv = *w_ptr.add(wt_chan + ki * kw + kj) as i32
11062 - *w_zp;
11063 acc += xv * wv;
11064 }
11065 }
11066 }
11067 let r = (acc as f32 * *mult).round() as i32 + *out_zp;
11068 let r = r.clamp(-128, 127) as i8;
11069 let dst = ((ni * c_out) + co) * h_out * w_out + ho * w_out + wo;
11070 *out_ptr.add(dst) = r;
11071 }
11072 }
11073 }
11074 }
11075 }
11076 }
11077
11078 Thunk::Quantize {
11079 x,
11080 q,
11081 len,
11082 chan_axis: _,
11083 chan_dim,
11084 inner,
11085 scales,
11086 zero_points,
11087 } => {
11088 let len = *len as usize;
11089 let chan_dim = *chan_dim as usize;
11090 let inner = *inner as usize;
11091 unsafe {
11092 let xs = sl(*x, base, len);
11093 let q_ptr = base.add(*q) as *mut i8;
11094 for i in 0..len {
11095 let c = if chan_dim == 1 {
11096 0
11097 } else {
11098 (i / inner) % chan_dim
11099 };
11100 let inv_scale = 1.0 / scales[c];
11101 let zp = zero_points[c];
11102 let v = (xs[i] * inv_scale).round() as i32 + zp;
11103 *q_ptr.add(i) = v.clamp(-128, 127) as i8;
11104 }
11105 }
11106 }
11107
11108 Thunk::Dequantize {
11109 q,
11110 x,
11111 len,
11112 chan_axis: _,
11113 chan_dim,
11114 inner,
11115 scales,
11116 zero_points,
11117 } => {
11118 let len = *len as usize;
11119 let chan_dim = *chan_dim as usize;
11120 let inner = *inner as usize;
11121 unsafe {
11122 let q_ptr = base.add(*q) as *const i8;
11123 let out = sl_mut(*x, base, len);
11124 for i in 0..len {
11125 let c = if chan_dim == 1 {
11126 0
11127 } else {
11128 (i / inner) % chan_dim
11129 };
11130 let scale = scales[c];
11131 let zp = zero_points[c];
11132 let qv = *q_ptr.add(i) as i32;
11133 out[i] = (qv - zp) as f32 * scale;
11134 }
11135 }
11136 }
11137
11138 Thunk::FakeQuantize {
11139 x,
11140 out,
11141 len,
11142 chan_axis: _,
11143 chan_dim,
11144 inner,
11145 bits,
11146 ste: _,
11147 scale_mode,
11148 state_off,
11149 } => {
11150 use rlx_ir::op::ScaleMode;
11151 let len = *len as usize;
11152 let chan_dim = *chan_dim as usize;
11153 let inner = *inner as usize;
11154 let q_max: f32 = match *bits {
11155 8 => 127.0,
11156 4 => 7.0,
11157 2 => 1.0,
11158 n => panic!("FakeQuantize: unsupported bits {n}"),
11159 };
11160 unsafe {
11161 let xs = sl(*x, base, len);
11162 let outs = sl_mut(*out, base, len);
11163
11164 let mut scale = vec![0f32; chan_dim];
11165 match scale_mode {
11166 ScaleMode::PerBatch => {
11167 let mut max_abs = vec![0f32; chan_dim];
11168 for i in 0..len {
11169 let c = if chan_dim == 1 {
11170 0
11171 } else {
11172 (i / inner) % chan_dim
11173 };
11174 let a = xs[i].abs();
11175 if a > max_abs[c] {
11176 max_abs[c] = a;
11177 }
11178 }
11179 for c in 0..chan_dim {
11180 scale[c] = (max_abs[c] / q_max).max(1e-12);
11181 }
11182 }
11183 ScaleMode::EMA { decay } => {
11184 let mut max_abs = vec![0f32; chan_dim];
11187 for i in 0..len {
11188 let c = if chan_dim == 1 {
11189 0
11190 } else {
11191 (i / inner) % chan_dim
11192 };
11193 let a = xs[i].abs();
11194 if a > max_abs[c] {
11195 max_abs[c] = a;
11196 }
11197 }
11198 let state =
11199 sl_mut(state_off.expect("EMA needs state_off"), base, chan_dim);
11200 for c in 0..chan_dim {
11201 let cur = (max_abs[c] / q_max).max(1e-12);
11202 let blended = if state[c] <= 0.0 {
11204 cur
11205 } else {
11206 *decay * state[c] + (1.0 - *decay) * cur
11207 };
11208 state[c] = blended;
11209 scale[c] = blended;
11210 }
11211 }
11212 ScaleMode::Fixed => {
11213 let state =
11214 sl(state_off.expect("Fixed needs state_off"), base, chan_dim);
11215 for c in 0..chan_dim {
11216 scale[c] = state[c].max(1e-12);
11217 }
11218 }
11219 }
11220
11221 for i in 0..len {
11222 let c = if chan_dim == 1 {
11223 0
11224 } else {
11225 (i / inner) % chan_dim
11226 };
11227 let s = scale[c];
11228 let qv = (xs[i] / s).round().clamp(-q_max, q_max);
11229 outs[i] = qv * s;
11230 }
11231 }
11232 }
11233
11234 Thunk::ActivationBackward {
11235 x,
11236 dy,
11237 dx,
11238 len,
11239 kind,
11240 } => {
11241 let len = *len as usize;
11242 unsafe {
11243 let xs = sl(*x, base, len);
11244 let dys = sl(*dy, base, len);
11245 let out = sl_mut(*dx, base, len);
11246 activation_backward_kernel(*kind, xs, dys, out);
11247 }
11248 }
11249
11250 Thunk::ActivationBackwardF64 {
11251 x,
11252 dy,
11253 dx,
11254 len,
11255 kind,
11256 } => {
11257 let len = *len as usize;
11258 unsafe {
11259 let xs = sl_f64(*x, base, len);
11260 let dys = sl_f64(*dy, base, len);
11261 let out = sl_mut_f64(*dx, base, len);
11262 activation_backward_kernel_f64(*kind, xs, dys, out);
11263 }
11264 }
11265
11266 Thunk::FakeQuantizeLSQ {
11267 x,
11268 scale_off,
11269 out,
11270 len,
11271 chan_axis: _,
11272 chan_dim,
11273 inner,
11274 bits,
11275 } => {
11276 let len = *len as usize;
11277 let chan_dim = *chan_dim as usize;
11278 let inner = *inner as usize;
11279 let q_max: f32 = match *bits {
11280 8 => 127.0,
11281 4 => 7.0,
11282 2 => 1.0,
11283 n => panic!("FakeQuantizeLSQ: bad bits {n}"),
11284 };
11285 unsafe {
11286 let xs = sl(*x, base, len);
11287 let scale = sl(*scale_off, base, chan_dim);
11288 let outs = sl_mut(*out, base, len);
11289 for i in 0..len {
11290 let c = if chan_dim == 1 {
11291 0
11292 } else {
11293 (i / inner) % chan_dim
11294 };
11295 let s = scale[c].max(1e-12);
11296 let qv = (xs[i] / s).round().clamp(-q_max, q_max);
11297 outs[i] = qv * s;
11298 }
11299 }
11300 }
11301
11302 Thunk::FakeQuantizeLSQBackwardX {
11303 x,
11304 scale_off,
11305 dy,
11306 dx,
11307 len,
11308 chan_axis: _,
11309 chan_dim,
11310 inner,
11311 bits,
11312 } => {
11313 let len = *len as usize;
11314 let chan_dim = *chan_dim as usize;
11315 let inner = *inner as usize;
11316 let q_max: f32 = match *bits {
11317 8 => 127.0,
11318 4 => 7.0,
11319 2 => 1.0,
11320 n => panic!("FakeQuantizeLSQBackwardX: bad bits {n}"),
11321 };
11322 unsafe {
11323 let xs = sl(*x, base, len);
11324 let scale = sl(*scale_off, base, chan_dim);
11325 let dys = sl(*dy, base, len);
11326 let outs = sl_mut(*dx, base, len);
11327 for i in 0..len {
11329 let c = if chan_dim == 1 {
11330 0
11331 } else {
11332 (i / inner) % chan_dim
11333 };
11334 let z = xs[i] / scale[c].max(1e-12);
11335 outs[i] = if z.abs() <= q_max { dys[i] } else { 0.0 };
11336 }
11337 }
11338 }
11339
11340 Thunk::FakeQuantizeLSQBackwardScale {
11341 x,
11342 scale_off,
11343 dy,
11344 dscale,
11345 len,
11346 chan_axis: _,
11347 chan_dim,
11348 inner,
11349 bits,
11350 } => {
11351 let len = *len as usize;
11352 let chan_dim = *chan_dim as usize;
11353 let inner = *inner as usize;
11354 let q_max: f32 = match *bits {
11355 8 => 127.0,
11356 4 => 7.0,
11357 2 => 1.0,
11358 n => panic!("FakeQuantizeLSQBackwardScale: bad bits {n}"),
11359 };
11360 unsafe {
11361 let xs = sl(*x, base, len);
11362 let scale = sl(*scale_off, base, chan_dim);
11363 let dys = sl(*dy, base, len);
11364 let outs = sl_mut(*dscale, base, chan_dim);
11365 for v in outs.iter_mut() {
11366 *v = 0.0;
11367 }
11368 for i in 0..len {
11371 let c = if chan_dim == 1 {
11372 0
11373 } else {
11374 (i / inner) % chan_dim
11375 };
11376 let s = scale[c].max(1e-12);
11377 let z = xs[i] / s;
11378 let psi = if z.abs() <= q_max {
11379 -z + z.round()
11380 } else if z > 0.0 {
11381 q_max
11382 } else {
11383 -q_max
11384 };
11385 outs[c] += psi * dys[i];
11386 }
11387 }
11388 }
11389
11390 Thunk::FakeQuantizeBackward {
11391 x,
11392 dy,
11393 dx,
11394 len,
11395 chan_axis: _,
11396 chan_dim,
11397 inner,
11398 bits,
11399 ste,
11400 } => {
11401 use rlx_ir::op::SteKind;
11402 let len = *len as usize;
11403 let chan_dim = *chan_dim as usize;
11404 let inner = *inner as usize;
11405 let q_max: f32 = match *bits {
11406 8 => 127.0,
11407 4 => 7.0,
11408 2 => 1.0,
11409 n => panic!("FakeQuantizeBackward: bad bits {n}"),
11410 };
11411 unsafe {
11412 let xs = sl(*x, base, len);
11413 let dys = sl(*dy, base, len);
11414 let outs = sl_mut(*dx, base, len);
11415
11416 let mut max_abs = vec![0f32; chan_dim];
11418 for i in 0..len {
11419 let c = if chan_dim == 1 {
11420 0
11421 } else {
11422 (i / inner) % chan_dim
11423 };
11424 let a = xs[i].abs();
11425 if a > max_abs[c] {
11426 max_abs[c] = a;
11427 }
11428 }
11429 let mut scale = vec![0f32; chan_dim];
11430 for c in 0..chan_dim {
11431 scale[c] = (max_abs[c] / q_max).max(1e-12);
11432 }
11433
11434 match *ste {
11435 SteKind::Identity => {
11436 outs.copy_from_slice(dys);
11438 }
11439 SteKind::ClippedIdentity => {
11440 for i in 0..len {
11443 let c = if chan_dim == 1 {
11444 0
11445 } else {
11446 (i / inner) % chan_dim
11447 };
11448 let bound = q_max * scale[c];
11449 outs[i] = if xs[i].abs() <= bound { dys[i] } else { 0.0 };
11450 }
11451 }
11452 SteKind::Tanh => {
11453 for i in 0..len {
11455 let c = if chan_dim == 1 {
11456 0
11457 } else {
11458 (i / inner) % chan_dim
11459 };
11460 let t = (xs[i] / scale[c]).tanh();
11461 outs[i] = dys[i] * (1.0 - t * t);
11462 }
11463 }
11464 SteKind::HardTanh => {
11465 for i in 0..len {
11467 let c = if chan_dim == 1 {
11468 0
11469 } else {
11470 (i / inner) % chan_dim
11471 };
11472 let bound = q_max * scale[c];
11473 let attenuation = (1.0 - (xs[i] / bound).abs()).max(0.0);
11474 outs[i] = dys[i] * attenuation;
11475 }
11476 }
11477 }
11478 }
11479 }
11480
11481 Thunk::LayerNormBackwardInput {
11482 x,
11483 gamma,
11484 dy,
11485 dx,
11486 rows,
11487 h,
11488 eps,
11489 } => {
11490 let rows = *rows as usize;
11491 let h = *h as usize;
11492 let eps = *eps;
11493 unsafe {
11494 let xs = sl(*x, base, rows * h);
11495 let g = sl(*gamma, base, h);
11496 let dys = sl(*dy, base, rows * h);
11497 let out = sl_mut(*dx, base, rows * h);
11498 let n_inv = 1.0 / h as f32;
11499 for r in 0..rows {
11500 let xr = &xs[r * h..(r + 1) * h];
11501 let dyr = &dys[r * h..(r + 1) * h];
11502 let mut sum = 0f32;
11505 for &v in xr {
11506 sum += v;
11507 }
11508 let mean = sum * n_inv;
11509 let mut var = 0f32;
11510 for &v in xr {
11511 let d = v - mean;
11512 var += d * d;
11513 }
11514 let inv_std = 1.0 / (var * n_inv + eps).sqrt();
11515
11516 let mut s_sy = 0f32;
11519 let mut s_sxh = 0f32;
11520 for d in 0..h {
11521 let xh = (xr[d] - mean) * inv_std;
11522 let sy = dyr[d] * g[d];
11523 s_sy += sy;
11524 s_sxh += sy * xh;
11525 }
11526 let m_sy = s_sy * n_inv;
11527 let m_sxh = s_sxh * n_inv;
11528
11529 for d in 0..h {
11530 let xh = (xr[d] - mean) * inv_std;
11531 let sy = dyr[d] * g[d];
11532 out[r * h + d] = inv_std * (sy - m_sy - xh * m_sxh);
11533 }
11534 }
11535 }
11536 }
11537
11538 Thunk::LayerNormBackwardGamma {
11539 x,
11540 dy,
11541 dgamma,
11542 rows,
11543 h,
11544 eps,
11545 } => {
11546 let rows = *rows as usize;
11547 let h = *h as usize;
11548 let eps = *eps;
11549 unsafe {
11550 let xs = sl(*x, base, rows * h);
11551 let dys = sl(*dy, base, rows * h);
11552 let out = sl_mut(*dgamma, base, h);
11553 for v in out.iter_mut() {
11554 *v = 0.0;
11555 }
11556 let n_inv = 1.0 / h as f32;
11557 for r in 0..rows {
11558 let xr = &xs[r * h..(r + 1) * h];
11559 let dyr = &dys[r * h..(r + 1) * h];
11560 let mut sum = 0f32;
11561 for &v in xr {
11562 sum += v;
11563 }
11564 let mean = sum * n_inv;
11565 let mut var = 0f32;
11566 for &v in xr {
11567 let d = v - mean;
11568 var += d * d;
11569 }
11570 let inv_std = 1.0 / (var * n_inv + eps).sqrt();
11571 for d in 0..h {
11572 let xh = (xr[d] - mean) * inv_std;
11573 out[d] += dyr[d] * xh;
11574 }
11575 }
11576 }
11577 }
11578
11579 Thunk::RmsNormBackwardInput {
11580 x,
11581 gamma,
11582 beta,
11583 dy,
11584 dx,
11585 rows,
11586 h,
11587 eps,
11588 } => {
11589 let (rows, h) = (*rows as usize, *h as usize);
11590 unsafe {
11591 let xs = sl(*x, base, rows * h);
11592 let g = sl(*gamma, base, h);
11593 let b = sl(*beta, base, h);
11594 let dys = sl(*dy, base, rows * h);
11595 let out = sl_mut(*dx, base, rows * h);
11596 let mut dg = vec![0f32; h];
11597 let mut db = vec![0f32; h];
11598 for r in 0..rows {
11599 crate::training_bwd::rms_norm_backward_row(
11600 &xs[r * h..(r + 1) * h],
11601 g,
11602 b,
11603 &dys[r * h..(r + 1) * h],
11604 &mut out[r * h..(r + 1) * h],
11605 &mut dg,
11606 &mut db,
11607 *eps,
11608 );
11609 }
11610 }
11611 }
11612
11613 Thunk::RmsNormBackwardGamma {
11614 x,
11615 gamma,
11616 beta,
11617 dy,
11618 dgamma,
11619 rows,
11620 h,
11621 eps,
11622 } => {
11623 let (rows, h) = (*rows as usize, *h as usize);
11624 unsafe {
11625 let xs = sl(*x, base, rows * h);
11626 let g = sl(*gamma, base, h);
11627 let b = sl(*beta, base, h);
11628 let dys = sl(*dy, base, rows * h);
11629 let out = sl_mut(*dgamma, base, h);
11630 for v in out.iter_mut() {
11631 *v = 0.0;
11632 }
11633 let mut dx = vec![0f32; h];
11634 let mut db = vec![0f32; h];
11635 for r in 0..rows {
11636 crate::training_bwd::rms_norm_backward_row(
11637 &xs[r * h..(r + 1) * h],
11638 g,
11639 b,
11640 &dys[r * h..(r + 1) * h],
11641 &mut dx,
11642 &mut *out,
11643 &mut db,
11644 *eps,
11645 );
11646 }
11647 }
11648 }
11649
11650 Thunk::RmsNormBackwardBeta {
11651 x,
11652 gamma,
11653 beta,
11654 dy,
11655 dbeta,
11656 rows,
11657 h,
11658 eps,
11659 } => {
11660 let (rows, h) = (*rows as usize, *h as usize);
11661 unsafe {
11662 let xs = sl(*x, base, rows * h);
11663 let g = sl(*gamma, base, h);
11664 let b = sl(*beta, base, h);
11665 let dys = sl(*dy, base, rows * h);
11666 let out = sl_mut(*dbeta, base, h);
11667 for v in out.iter_mut() {
11668 *v = 0.0;
11669 }
11670 let mut dx = vec![0f32; h];
11671 let mut dg = vec![0f32; h];
11672 for r in 0..rows {
11673 crate::training_bwd::rms_norm_backward_row(
11674 &xs[r * h..(r + 1) * h],
11675 g,
11676 b,
11677 &dys[r * h..(r + 1) * h],
11678 &mut dx,
11679 &mut dg,
11680 &mut *out,
11681 *eps,
11682 );
11683 }
11684 }
11685 }
11686
11687 Thunk::RopeBackward {
11688 dy,
11689 cos,
11690 sin,
11691 dx,
11692 batch,
11693 seq,
11694 hidden,
11695 head_dim,
11696 n_rot,
11697 cos_len,
11698 } => {
11699 let (b, s, hs, dh, nr, cl) = (
11700 *batch as usize,
11701 *seq as usize,
11702 *hidden as usize,
11703 *head_dim as usize,
11704 *n_rot as usize,
11705 *cos_len as usize,
11706 );
11707 let nh = hs / dh;
11708 let tab_half = dh / 2;
11709 unsafe {
11710 let dys = sl(*dy, base, b * s * hs);
11711 let cos_tab = sl(*cos, base, cl);
11712 let sin_tab = sl(*sin, base, cl);
11713 let out = sl_mut(*dx, base, b * s * hs);
11714 for bi in 0..b {
11715 for si in 0..s {
11716 let tab_off = si.saturating_mul(tab_half) % cl.max(1);
11717 let cp = &cos_tab[tab_off..tab_off + tab_half.min(cl)];
11718 let sp = &sin_tab[tab_off..tab_off + tab_half.min(cl)];
11719 for hi in 0..nh {
11720 let base_idx = bi * s * hs + si * hs + hi * dh;
11721 crate::training_bwd::rope_backward_row(
11722 &dys[base_idx..base_idx + dh],
11723 cp,
11724 sp,
11725 &mut out[base_idx..base_idx + dh],
11726 dh,
11727 nr,
11728 );
11729 }
11730 }
11731 }
11732 }
11733 }
11734
11735 Thunk::CumsumBackward {
11736 dy,
11737 dx,
11738 rows,
11739 cols,
11740 exclusive,
11741 } => {
11742 let (rows, cols) = (*rows as usize, *cols as usize);
11743 unsafe {
11744 let dys = sl(*dy, base, rows * cols);
11745 let out = sl_mut(*dx, base, rows * cols);
11746 for r in 0..rows {
11747 crate::training_bwd::cumsum_backward_row(
11748 &dys[r * cols..(r + 1) * cols],
11749 &mut out[r * cols..(r + 1) * cols],
11750 *exclusive,
11751 );
11752 }
11753 }
11754 }
11755
11756 Thunk::GroupNormBackwardInput {
11757 x,
11758 gamma,
11759 beta: _beta,
11760 dy,
11761 dx,
11762 n,
11763 c,
11764 h,
11765 w,
11766 num_groups,
11767 eps,
11768 } => {
11769 let (n, c, h, w) = (*n as usize, *c as usize, *h as usize, *w as usize);
11770 let plane = c * h * w;
11771 unsafe {
11772 let xs = sl(*x, base, n * plane);
11773 let g = sl(*gamma, base, c);
11774 let dys = sl(*dy, base, n * plane);
11775 let out = sl_mut(*dx, base, n * plane);
11776 crate::training_bwd::group_norm_backward_input_nchw(
11777 xs,
11778 g,
11779 dys,
11780 out,
11781 n,
11782 c,
11783 h,
11784 w,
11785 *num_groups as usize,
11786 *eps,
11787 );
11788 }
11789 }
11790
11791 Thunk::GroupNormBackwardGamma {
11792 x,
11793 dy,
11794 dgamma,
11795 n,
11796 c,
11797 h,
11798 w,
11799 num_groups,
11800 eps,
11801 } => {
11802 let (n, c, h, w) = (*n as usize, *c as usize, *h as usize, *w as usize);
11803 let plane = c * h * w;
11804 unsafe {
11805 let xs = sl(*x, base, n * plane);
11806 let dys = sl(*dy, base, n * plane);
11807 let out = sl_mut(*dgamma, base, c);
11808 crate::training_bwd::group_norm_backward_gamma_nchw(
11809 xs,
11810 dys,
11811 out,
11812 n,
11813 c,
11814 h,
11815 w,
11816 *num_groups as usize,
11817 *eps,
11818 );
11819 }
11820 }
11821
11822 Thunk::GroupNormBackwardBeta {
11823 dy,
11824 dbeta,
11825 n,
11826 c,
11827 h,
11828 w,
11829 } => {
11830 let (n, c, h, w) = (*n as usize, *c as usize, *h as usize, *w as usize);
11831 let plane = c * h * w;
11832 unsafe {
11833 let dys = sl(*dy, base, n * plane);
11834 let out = sl_mut(*dbeta, base, c);
11835 crate::training_bwd::group_norm_backward_beta_nchw(dys, out, n, c, h, w);
11836 }
11837 }
11838
11839 Thunk::GatherBackward {
11840 dy,
11841 indices,
11842 dst,
11843 outer,
11844 axis_dim,
11845 num_idx,
11846 trailing,
11847 } => {
11848 let (outer, axis_dim, num_idx, trailing) = (
11849 *outer as usize,
11850 *axis_dim as usize,
11851 *num_idx as usize,
11852 *trailing as usize,
11853 );
11854 unsafe {
11855 let dys = sl(*dy, base, outer * num_idx * trailing);
11856 let ids = sl(*indices, base, num_idx);
11857 let out = sl_mut(*dst, base, outer * axis_dim * trailing);
11858 for v in out.iter_mut() {
11859 *v = 0.0;
11860 }
11861 crate::training_bwd::gather_axis_backward(
11862 dys, ids, out, outer, axis_dim, num_idx, trailing,
11863 );
11864 }
11865 }
11866
11867 Thunk::MaxPool2dBackward {
11868 x,
11869 dy,
11870 dx,
11871 n,
11872 c,
11873 h,
11874 w,
11875 h_out,
11876 w_out,
11877 kh,
11878 kw,
11879 sh,
11880 sw,
11881 ph,
11882 pw,
11883 } => {
11884 let n = *n as usize;
11885 let c = *c as usize;
11886 let h = *h as usize;
11887 let w = *w as usize;
11888 let h_out = *h_out as usize;
11889 let w_out = *w_out as usize;
11890 let kh = *kh as usize;
11891 let kw = *kw as usize;
11892 let sh = *sh as usize;
11893 let sw = *sw as usize;
11894 let ph = *ph as usize;
11895 let pw = *pw as usize;
11896 unsafe {
11897 let xs = sl(*x, base, n * c * h * w);
11898 let dys = sl(*dy, base, n * c * h_out * w_out);
11899 let dxs = sl_mut(*dx, base, n * c * h * w);
11900 for v in dxs.iter_mut() {
11903 *v = 0.0;
11904 }
11905 for ni in 0..n {
11906 for ci in 0..c {
11907 let in_chan = (ni * c + ci) * h * w;
11908 let out_chan = (ni * c + ci) * h_out * w_out;
11909 for ho in 0..h_out {
11910 for wo in 0..w_out {
11911 let mut best_v = f32::NEG_INFINITY;
11913 let mut best_idx: Option<usize> = None;
11914 for ki in 0..kh {
11915 for kj in 0..kw {
11916 let hi = ho * sh + ki;
11917 let wi = wo * sw + kj;
11918 if hi < ph || wi < pw {
11919 continue;
11920 }
11921 let hi = hi - ph;
11922 let wi = wi - pw;
11923 if hi >= h || wi >= w {
11924 continue;
11925 }
11926 let idx = in_chan + hi * w + wi;
11927 let v = xs[idx];
11928 if v > best_v {
11932 best_v = v;
11933 best_idx = Some(idx);
11934 }
11935 }
11936 }
11937 if let Some(idx) = best_idx {
11938 dxs[idx] += dys[out_chan + ho * w_out + wo];
11939 }
11940 }
11941 }
11942 }
11943 }
11944 }
11945 }
11946
11947 Thunk::Conv2dBackwardInput {
11948 dy,
11949 w,
11950 dx,
11951 n,
11952 c_in,
11953 h,
11954 w_in,
11955 c_out,
11956 h_out,
11957 w_out,
11958 kh,
11959 kw,
11960 sh,
11961 sw,
11962 ph,
11963 pw,
11964 dh,
11965 dw,
11966 groups,
11967 } => {
11968 let n = *n as usize;
11980 let c_in = *c_in as usize;
11981 let h = *h as usize;
11982 let w_in = *w_in as usize;
11983 let c_out = *c_out as usize;
11984 let h_out = *h_out as usize;
11985 let w_out = *w_out as usize;
11986 let kh = *kh as usize;
11987 let kw = *kw as usize;
11988 let sh = *sh as usize;
11989 let sw = *sw as usize;
11990 let ph = *ph as usize;
11991 let pw = *pw as usize;
11992 let dh = *dh as usize;
11993 let dw = *dw as usize;
11994 let groups = *groups as usize;
11995 let c_in_per_g = c_in / groups;
11996 let c_out_per_g = c_out / groups;
11997
11998 let m_dim = c_in_per_g * kh * kw;
11999 let n_dim = h_out * w_out;
12000 let k_dim = c_out_per_g;
12001
12002 let dy_stride_n = c_out * h_out * w_out;
12003 let dy_stride_g = c_out_per_g * h_out * w_out;
12004 let w_stride_g = c_out_per_g * c_in_per_g * kh * kw;
12005 let dx_stride_n = c_in * h * w_in;
12006 let dx_stride_g = c_in_per_g * h * w_in;
12007
12008 unsafe {
12009 let dys = sl(*dy, base, n * c_out * h_out * w_out);
12010 let ws = sl(*w, base, c_out * c_in_per_g * kh * kw);
12011 let dxs = sl_mut(*dx, base, n * c_in * h * w_in);
12012 for v in dxs.iter_mut() {
12013 *v = 0.0;
12014 }
12015
12016 let mut dcol = vec![0f32; m_dim * n_dim];
12018
12019 for ni in 0..n {
12020 for g in 0..groups {
12021 let w_g_off = g * w_stride_g;
12022 let dy_n_g_off = ni * dy_stride_n + g * dy_stride_g;
12023 let dx_n_g_off = ni * dx_stride_n + g * dx_stride_g;
12024
12025 crate::blas::sgemm_general(
12030 ws.as_ptr().add(w_g_off),
12031 dys.as_ptr().add(dy_n_g_off),
12032 dcol.as_mut_ptr(),
12033 m_dim,
12034 n_dim,
12035 k_dim,
12036 1.0,
12037 0.0,
12038 m_dim,
12039 n_dim,
12040 n_dim,
12041 true,
12042 false,
12043 );
12044
12045 col2im(
12047 &dcol,
12048 &mut dxs[dx_n_g_off..dx_n_g_off + dx_stride_g],
12049 c_in_per_g,
12050 h,
12051 w_in,
12052 h_out,
12053 w_out,
12054 kh,
12055 kw,
12056 sh,
12057 sw,
12058 ph,
12059 pw,
12060 dh,
12061 dw,
12062 );
12063 }
12064 }
12065 }
12066 }
12067
12068 Thunk::Conv2dBackwardWeight {
12069 x,
12070 dy,
12071 dw,
12072 n,
12073 c_in,
12074 h,
12075 w,
12076 c_out,
12077 h_out,
12078 w_out,
12079 kh,
12080 kw,
12081 sh,
12082 sw,
12083 ph,
12084 pw,
12085 dh,
12086 dw_dil,
12087 groups,
12088 } => {
12089 let n = *n as usize;
12090 let c_in = *c_in as usize;
12091 let h = *h as usize;
12092 let w = *w as usize;
12093 let c_out = *c_out as usize;
12104 let h_out = *h_out as usize;
12105 let w_out = *w_out as usize;
12106 let kh = *kh as usize;
12107 let kw = *kw as usize;
12108 let sh = *sh as usize;
12109 let sw = *sw as usize;
12110 let ph = *ph as usize;
12111 let pw = *pw as usize;
12112 let dh = *dh as usize;
12113 let dw_dil = *dw_dil as usize;
12114 let groups = *groups as usize;
12115 let c_in_per_g = c_in / groups;
12116 let c_out_per_g = c_out / groups;
12117
12118 let m_dim = c_out_per_g;
12119 let n_dim = c_in_per_g * kh * kw;
12120 let k_dim = h_out * w_out;
12121
12122 let x_stride_n = c_in * h * w;
12123 let x_stride_g = c_in_per_g * h * w;
12124 let dy_stride_n = c_out * h_out * w_out;
12125 let dy_stride_g = c_out_per_g * h_out * w_out;
12126 let dw_stride_g = c_out_per_g * c_in_per_g * kh * kw;
12127
12128 unsafe {
12129 let xs = sl(*x, base, n * c_in * h * w);
12130 let dys = sl(*dy, base, n * c_out * h_out * w_out);
12131 let dws = sl_mut(*dw, base, c_out * c_in_per_g * kh * kw);
12132 for v in dws.iter_mut() {
12133 *v = 0.0;
12134 }
12135
12136 let mut col = vec![0f32; n_dim * k_dim];
12137
12138 for ni in 0..n {
12139 for g in 0..groups {
12140 let x_n_g_off = ni * x_stride_n + g * x_stride_g;
12141 im2col(
12142 &xs[x_n_g_off..x_n_g_off + x_stride_g],
12143 &mut col,
12144 c_in_per_g,
12145 h,
12146 w,
12147 h_out,
12148 w_out,
12149 kh,
12150 kw,
12151 sh,
12152 sw,
12153 ph,
12154 pw,
12155 dh,
12156 dw_dil,
12157 );
12158
12159 let dy_n_g_off = ni * dy_stride_n + g * dy_stride_g;
12160 let dw_g_off = g * dw_stride_g;
12161
12162 crate::blas::sgemm_general(
12170 dys.as_ptr().add(dy_n_g_off),
12171 col.as_ptr(),
12172 dws.as_mut_ptr().add(dw_g_off),
12173 m_dim,
12174 n_dim,
12175 k_dim,
12176 1.0,
12177 1.0,
12178 k_dim,
12179 k_dim,
12180 n_dim,
12181 false,
12182 true,
12183 );
12184 }
12185 }
12186 }
12187 }
12188
12189 Thunk::SoftmaxCrossEntropy {
12190 logits,
12191 labels,
12192 dst,
12193 n,
12194 c,
12195 } => {
12196 let n = *n as usize;
12197 let c = *c as usize;
12198 unsafe {
12199 let lg = sl(*logits, base, n * c);
12200 let lb = sl(*labels, base, n);
12201 let out = sl_mut(*dst, base, n);
12202 for ni in 0..n {
12203 let row = &lg[ni * c..(ni + 1) * c];
12204 let mut m = f32::NEG_INFINITY;
12206 for &v in row {
12207 if v > m {
12208 m = v;
12209 }
12210 }
12211 let mut sum = 0f32;
12212 for &v in row {
12213 sum += (v - m).exp();
12214 }
12215 let lse = m + sum.ln();
12216 let label_idx = lb[ni] as usize;
12217 out[ni] = lse - row[label_idx];
12219 }
12220 }
12221 }
12222
12223 Thunk::SoftmaxCrossEntropyBackward {
12224 logits,
12225 labels,
12226 d_loss,
12227 dlogits,
12228 n,
12229 c,
12230 } => {
12231 let n = *n as usize;
12232 let c = *c as usize;
12233 unsafe {
12234 let lg = sl(*logits, base, n * c);
12235 let lb = sl(*labels, base, n);
12236 let dl = sl(*d_loss, base, n);
12237 let out = sl_mut(*dlogits, base, n * c);
12238 for ni in 0..n {
12239 let row = &lg[ni * c..(ni + 1) * c];
12240 let label_idx = lb[ni] as usize;
12241 let scale = dl[ni];
12242 let mut m = f32::NEG_INFINITY;
12243 for &v in row {
12244 if v > m {
12245 m = v;
12246 }
12247 }
12248 let mut sum = 0f32;
12249 for &v in row {
12250 sum += (v - m).exp();
12251 }
12252 let inv_sum = 1.0 / sum;
12253 let dst_row = &mut out[ni * c..(ni + 1) * c];
12254 for k in 0..c {
12255 let p = (row[k] - m).exp() * inv_sum;
12256 let one_hot = if k == label_idx { 1.0 } else { 0.0 };
12257 dst_row[k] = (p - one_hot) * scale;
12258 }
12259 }
12260 }
12261 }
12262
12263 Thunk::GatherAxis {
12264 table,
12265 idx,
12266 dst,
12267 outer,
12268 axis_dim,
12269 num_idx,
12270 trailing,
12271 } => {
12272 let outer = *outer as usize;
12273 let axis_dim = *axis_dim as usize;
12274 let num_idx = *num_idx as usize;
12275 let trailing = *trailing as usize;
12276 unsafe {
12277 let tab = sl(*table, base, outer * axis_dim * trailing);
12278 let ids = sl(*idx, base, num_idx);
12279 let out = sl_mut(*dst, base, outer * num_idx * trailing);
12280 for o in 0..outer {
12281 let tab_outer = o * axis_dim * trailing;
12282 let out_outer = o * num_idx * trailing;
12283 for k in 0..num_idx {
12284 let row = ids[k] as usize;
12285 let tab_row = tab_outer + row * trailing;
12286 let out_row = out_outer + k * trailing;
12287 out[out_row..out_row + trailing]
12288 .copy_from_slice(&tab[tab_row..tab_row + trailing]);
12289 }
12290 }
12291 }
12292 }
12293
12294 Thunk::Transpose {
12295 src,
12296 dst,
12297 in_total,
12298 out_dims,
12299 in_strides,
12300 } => {
12301 let rank = out_dims.len();
12306 let total: usize = out_dims.iter().map(|&d| d as usize).product();
12307 let in_total = *in_total as usize;
12308 unsafe {
12309 let inp = sl(*src, base, in_total);
12310 let out = sl_mut(*dst, base, total);
12311 let mut idx = vec![0usize; rank];
12312 for o in 0..total {
12313 let mut src_idx = 0usize;
12314 for d in 0..rank {
12315 src_idx += idx[d] * in_strides[d] as usize;
12316 }
12317 out[o] = inp[src_idx];
12318 for d in (0..rank).rev() {
12320 idx[d] += 1;
12321 if idx[d] < out_dims[d] as usize {
12322 break;
12323 }
12324 idx[d] = 0;
12325 }
12326 }
12327 }
12328 }
12329
12330 Thunk::CustomOp {
12336 kernel,
12337 inputs,
12338 output,
12339 attrs,
12340 } => {
12341 let (out_off, out_len, out_shape) = output;
12342 unsafe {
12343 dispatch_custom_op(
12344 &**kernel, inputs, *out_off, *out_len, out_shape, attrs, base,
12345 );
12346 }
12347 }
12348 }
12349 }
12350}
12351
12352#[allow(clippy::too_many_arguments)]
12367unsafe fn griewank_process_segment(
12368 t_lo: usize,
12369 t_hi: usize,
12370 anchor_carry: &[u8],
12371 cb: usize,
12372 fwd_sched: &ThunkSchedule,
12373 fwd_init: &[u8],
12374 fwd_carry_in_off: usize,
12375 fwd_output_off: usize,
12376 fwd_x_offs: &[usize],
12377 base: *mut u8,
12378 outer_xs_offs: &[(usize, u32)],
12379 fwd_buf: &mut Vec<u8>,
12380 leaf_threshold: usize,
12381 process_iter: &mut dyn FnMut(usize, &[u8]),
12382) {
12383 unsafe {
12384 let size = t_hi - t_lo + 1;
12385 if size == 1 {
12386 process_iter(t_lo, anchor_carry);
12387 return;
12388 }
12389 if size <= leaf_threshold {
12390 let mut cache: Vec<u8> = Vec::with_capacity(size * cb);
12392 cache.extend_from_slice(anchor_carry);
12393 fwd_buf.copy_from_slice(fwd_init);
12394 std::ptr::copy_nonoverlapping(
12395 anchor_carry.as_ptr(),
12396 fwd_buf.as_mut_ptr().add(fwd_carry_in_off),
12397 cb,
12398 );
12399 for i in 1..size {
12400 let cur_iter = t_lo + i - 1;
12401 for (idx, fb_x_off) in fwd_x_offs.iter().enumerate() {
12402 let (outer_xs_off, x_psb) = outer_xs_offs[idx];
12403 let xb = x_psb as usize;
12404 std::ptr::copy_nonoverlapping(
12405 base.add(outer_xs_off + cur_iter * xb),
12406 fwd_buf.as_mut_ptr().add(*fb_x_off),
12407 xb,
12408 );
12409 }
12410 execute_thunks(fwd_sched, fwd_buf);
12411 if fwd_output_off != fwd_carry_in_off {
12412 fwd_buf.copy_within(fwd_output_off..fwd_output_off + cb, fwd_carry_in_off);
12413 }
12414 cache.extend_from_slice(&fwd_buf[fwd_carry_in_off..fwd_carry_in_off + cb]);
12415 }
12416 for t in (t_lo..=t_hi).rev() {
12418 let idx = t - t_lo;
12419 let carry = &cache[idx * cb..(idx + 1) * cb];
12420 process_iter(t, carry);
12421 }
12422 return;
12423 }
12424
12425 let mid = t_lo + size / 2;
12429 fwd_buf.copy_from_slice(fwd_init);
12430 std::ptr::copy_nonoverlapping(
12431 anchor_carry.as_ptr(),
12432 fwd_buf.as_mut_ptr().add(fwd_carry_in_off),
12433 cb,
12434 );
12435 for cur_iter in t_lo..mid {
12436 for (idx, fb_x_off) in fwd_x_offs.iter().enumerate() {
12437 let (outer_xs_off, x_psb) = outer_xs_offs[idx];
12438 let xb = x_psb as usize;
12439 std::ptr::copy_nonoverlapping(
12440 base.add(outer_xs_off + cur_iter * xb),
12441 fwd_buf.as_mut_ptr().add(*fb_x_off),
12442 xb,
12443 );
12444 }
12445 execute_thunks(fwd_sched, fwd_buf);
12446 if fwd_output_off != fwd_carry_in_off {
12447 fwd_buf.copy_within(fwd_output_off..fwd_output_off + cb, fwd_carry_in_off);
12448 }
12449 }
12450 let mid_carry: Vec<u8> = fwd_buf[fwd_carry_in_off..fwd_carry_in_off + cb].to_vec();
12451
12452 griewank_process_segment(
12456 mid,
12457 t_hi,
12458 &mid_carry,
12459 cb,
12460 fwd_sched,
12461 fwd_init,
12462 fwd_carry_in_off,
12463 fwd_output_off,
12464 fwd_x_offs,
12465 base,
12466 outer_xs_offs,
12467 fwd_buf,
12468 leaf_threshold,
12469 process_iter,
12470 );
12471 griewank_process_segment(
12473 t_lo,
12474 mid - 1,
12475 anchor_carry,
12476 cb,
12477 fwd_sched,
12478 fwd_init,
12479 fwd_carry_in_off,
12480 fwd_output_off,
12481 fwd_x_offs,
12482 base,
12483 outer_xs_offs,
12484 fwd_buf,
12485 leaf_threshold,
12486 process_iter,
12487 );
12488 }
12489}
12490
12491pub unsafe fn execute_fft1d_f64(
12508 src: usize,
12509 dst: usize,
12510 outer: usize,
12511 n_complex: usize,
12512 inverse: bool,
12513 base: *mut u8,
12514) {
12515 let row_elems = 2 * n_complex;
12516 let mut re = vec![0f64; n_complex];
12517 let mut im = vec![0f64; n_complex];
12518 let mut scratch = if n_complex.is_power_of_two() {
12521 BluesteinScratchF64::empty()
12522 } else {
12523 BluesteinScratchF64::build(n_complex, inverse)
12524 };
12525 for o in 0..outer {
12526 let row_offset = src + o * row_elems * std::mem::size_of::<f64>();
12527 let s = unsafe { sl_f64(row_offset, base, row_elems) };
12528 re.copy_from_slice(&s[..n_complex]);
12529 im.copy_from_slice(&s[n_complex..]);
12530 if n_complex.is_power_of_two() {
12531 fft_radix2_inplace_f64(&mut re, &mut im, inverse);
12532 } else {
12533 fft_bluestein_inplace_f64(&mut re, &mut im, inverse, &mut scratch);
12534 }
12535 let dst_offset = dst + o * row_elems * std::mem::size_of::<f64>();
12536 let d = unsafe { sl_mut_f64(dst_offset, base, row_elems) };
12537 d[..n_complex].copy_from_slice(&re);
12538 d[n_complex..].copy_from_slice(&im);
12539 }
12540}
12541
12542pub unsafe fn execute_gated_delta_net_f32(
12551 q: usize,
12552 k: usize,
12553 v: usize,
12554 g: usize,
12555 beta: usize,
12556 state: usize,
12557 dst: usize,
12558 batch: usize,
12559 seq: usize,
12560 heads: usize,
12561 state_size: usize,
12562 base: *mut u8,
12563) {
12564 use rayon::prelude::*;
12565
12566 #[derive(Copy, Clone)]
12567 struct ArenaPtr(usize);
12568 unsafe impl Send for ArenaPtr {}
12569 unsafe impl Sync for ArenaPtr {}
12570 impl ArenaPtr {
12571 #[inline]
12572 fn get(self) -> *mut u8 {
12573 self.0 as *mut u8
12574 }
12575 }
12576
12577 unsafe {
12578 let arena = ArenaPtr(base as usize);
12579 let (b, s, h, n) = (batch, seq, heads, state_size);
12580 let scale = 1.0f32 / (n as f32).sqrt();
12581 let use_external = state != 0;
12582 let mut owned_state = vec![0f32; h * n * n];
12583
12584 crate::pool::num_threads();
12585
12586 assert!(
12587 n <= crate::gdn::GDN_MAX_STATE,
12588 "GatedDeltaNet state_size={n} exceeds stack scratch ({})",
12589 crate::gdn::GDN_MAX_STATE
12590 );
12591
12592 let qs = sl(q, arena.get(), b * s * h * n);
12593 let ks = sl(k, arena.get(), b * s * h * n);
12594 let vs = sl(v, arena.get(), b * s * h * n);
12595 let gs = sl(g, arena.get(), b * s * h);
12596 let betas = sl(beta, arena.get(), b * s * h);
12597 let _out = sl_mut(dst, arena.get(), b * s * h * n);
12598 let hs_n = h * n;
12599
12600 let run_head = |bi: usize, hi: usize, s_mat: &mut [f32], sk: &mut [f32]| {
12601 for ti in 0..s {
12602 let qkv_step = bi * s * hs_n + ti * hs_n + hi * n;
12603 let gb_step = bi * s * h + ti * h + hi;
12604 let out_row = sl_mut(dst + qkv_step * std::mem::size_of::<f32>(), arena.get(), n);
12605 crate::gdn::gdn_step_blas(
12606 s_mat,
12607 &qs[qkv_step..qkv_step + n],
12608 &ks[qkv_step..qkv_step + n],
12609 &vs[qkv_step..qkv_step + n],
12610 gs[gb_step],
12611 betas[gb_step],
12612 out_row,
12613 sk,
12614 n,
12615 scale,
12616 );
12617 }
12618 };
12619
12620 if !use_external && s > 1 {
12623 for bi in 0..b {
12624 (0..h).into_par_iter().for_each(|hi| {
12625 let mut sk_buf = [0f32; crate::gdn::GDN_MAX_STATE];
12626 let sk = &mut sk_buf[..n];
12627 let mut local_state =
12628 [0f32; crate::gdn::GDN_MAX_STATE * crate::gdn::GDN_MAX_STATE];
12629 let s_mat = &mut local_state[..n * n];
12630 s_mat.fill(0.0);
12631 run_head(bi, hi, s_mat, sk);
12632 });
12633 }
12634 return;
12635 }
12636
12637 if use_external {
12638 let state_bytes = state;
12639 (0..b * h).into_par_iter().for_each(|bhi| {
12640 let bi = bhi / h;
12641 let hi = bhi % h;
12642 let elem_off = bi * h * n * n + hi * n * n;
12643 let s_mat = sl_mut(
12644 state_bytes + elem_off * std::mem::size_of::<f32>(),
12645 arena.get(),
12646 n * n,
12647 );
12648 let mut sk_buf = [0f32; crate::gdn::GDN_MAX_STATE];
12649 run_head(bi, hi, s_mat, &mut sk_buf[..n]);
12650 });
12651 } else {
12652 for bi in 0..b {
12653 owned_state.fill(0.0);
12654 owned_state
12655 .par_chunks_mut(n * n)
12656 .enumerate()
12657 .for_each(|(hi, s_mat)| {
12658 let mut sk_buf = [0f32; crate::gdn::GDN_MAX_STATE];
12659 run_head(bi, hi, s_mat, &mut sk_buf[..n]);
12660 });
12661 }
12662 }
12663 }
12664}
12665
12666pub unsafe fn execute_rms_norm_backward_input_f32(
12668 x: usize,
12669 gamma: usize,
12670 beta: usize,
12671 dy: usize,
12672 dx: usize,
12673 rows: u32,
12674 h: u32,
12675 eps: f32,
12676 base: *mut u8,
12677) {
12678 let (rows, h) = (rows as usize, h as usize);
12679 let mut dg = vec![0f32; h];
12680 let mut db = vec![0f32; h];
12681 let xs = sl(x, base, rows * h);
12682 let dys = sl(dy, base, rows * h);
12683 let g = sl(gamma, base, h);
12684 let b = sl(beta, base, h);
12685 let out = sl_mut(dx, base, rows * h);
12686 for r in 0..rows {
12687 crate::training_bwd::rms_norm_backward_row(
12688 &xs[r * h..(r + 1) * h],
12689 g,
12690 b,
12691 &dys[r * h..(r + 1) * h],
12692 &mut out[r * h..(r + 1) * h],
12693 &mut dg,
12694 &mut db,
12695 eps,
12696 );
12697 }
12698}
12699
12700pub unsafe fn execute_rms_norm_backward_gamma_f32(
12701 x: usize,
12702 gamma: usize,
12703 beta: usize,
12704 dy: usize,
12705 dgamma: usize,
12706 rows: u32,
12707 h: u32,
12708 eps: f32,
12709 base: *mut u8,
12710) {
12711 let (rows, h) = (rows as usize, h as usize);
12712 let out = sl_mut(dgamma, base, h);
12713 out.fill(0.0);
12714 let mut dx = vec![0f32; h];
12715 let mut db = vec![0f32; h];
12716 let xs = sl(x, base, rows * h);
12717 let dys = sl(dy, base, rows * h);
12718 let g = sl(gamma, base, h);
12719 let b = sl(beta, base, h);
12720 for r in 0..rows {
12721 crate::training_bwd::rms_norm_backward_row(
12722 &xs[r * h..(r + 1) * h],
12723 g,
12724 b,
12725 &dys[r * h..(r + 1) * h],
12726 &mut dx,
12727 out,
12728 &mut db,
12729 eps,
12730 );
12731 }
12732}
12733
12734pub unsafe fn execute_rms_norm_backward_beta_f32(
12735 x: usize,
12736 gamma: usize,
12737 beta: usize,
12738 dy: usize,
12739 dbeta: usize,
12740 rows: u32,
12741 h: u32,
12742 eps: f32,
12743 base: *mut u8,
12744) {
12745 let (rows, h) = (rows as usize, h as usize);
12746 let out = sl_mut(dbeta, base, h);
12747 out.fill(0.0);
12748 let mut dx = vec![0f32; h];
12749 let mut dg = vec![0f32; h];
12750 let xs = sl(x, base, rows * h);
12751 let dys = sl(dy, base, rows * h);
12752 let g = sl(gamma, base, h);
12753 let b = sl(beta, base, h);
12754 for r in 0..rows {
12755 crate::training_bwd::rms_norm_backward_row(
12756 &xs[r * h..(r + 1) * h],
12757 g,
12758 b,
12759 &dys[r * h..(r + 1) * h],
12760 &mut dx,
12761 &mut dg,
12762 out,
12763 eps,
12764 );
12765 }
12766}
12767
12768pub unsafe fn execute_rope_backward_f32(
12769 dy: usize,
12770 cos: usize,
12771 sin: usize,
12772 dx: usize,
12773 batch: u32,
12774 seq: u32,
12775 hidden: u32,
12776 head_dim: u32,
12777 n_rot: u32,
12778 cos_len: u32,
12779 base: *mut u8,
12780) {
12781 let (b, s, hs, dh, nr, cl) = (
12782 batch as usize,
12783 seq as usize,
12784 hidden as usize,
12785 head_dim as usize,
12786 n_rot as usize,
12787 cos_len as usize,
12788 );
12789 let nh = hs / dh;
12790 let tab_half = dh / 2;
12791 let dys = sl(dy, base, b * s * hs);
12792 let cos_tab = sl(cos, base, cl);
12793 let sin_tab = sl(sin, base, cl);
12794 let out = sl_mut(dx, base, b * s * hs);
12795 for bi in 0..b {
12796 for si in 0..s {
12797 let tab_off = si.saturating_mul(tab_half) % cl.max(1);
12798 let cp = &cos_tab[tab_off..tab_off + tab_half.min(cl)];
12799 let sp = &sin_tab[tab_off..tab_off + tab_half.min(cl)];
12800 for hi in 0..nh {
12801 let base_idx = bi * s * hs + si * hs + hi * dh;
12802 crate::training_bwd::rope_backward_row(
12803 &dys[base_idx..base_idx + dh],
12804 cp,
12805 sp,
12806 &mut out[base_idx..base_idx + dh],
12807 dh,
12808 nr,
12809 );
12810 }
12811 }
12812 }
12813}
12814
12815pub unsafe fn execute_cumsum_backward_f32(
12816 dy: usize,
12817 dx: usize,
12818 rows: u32,
12819 cols: u32,
12820 exclusive: bool,
12821 base: *mut u8,
12822) {
12823 let (rows, cols) = (rows as usize, cols as usize);
12824 let dys = sl(dy, base, rows * cols);
12825 let out = sl_mut(dx, base, rows * cols);
12826 for r in 0..rows {
12827 crate::training_bwd::cumsum_backward_row(
12828 &dys[r * cols..(r + 1) * cols],
12829 &mut out[r * cols..(r + 1) * cols],
12830 exclusive,
12831 );
12832 }
12833}
12834
12835pub unsafe fn execute_gather_backward_f32(
12836 dy: usize,
12837 indices: usize,
12838 dst: usize,
12839 outer: u32,
12840 axis_dim: u32,
12841 num_idx: u32,
12842 trailing: u32,
12843 base: *mut u8,
12844) {
12845 let (outer, axis_dim, num_idx, trailing) = (
12846 outer as usize,
12847 axis_dim as usize,
12848 num_idx as usize,
12849 trailing as usize,
12850 );
12851 let out = sl_mut(dst, base, outer * axis_dim * trailing);
12852 out.fill(0.0);
12853 crate::training_bwd::gather_axis_backward(
12854 sl(dy, base, outer * num_idx * trailing),
12855 sl(indices, base, num_idx),
12856 out,
12857 outer,
12858 axis_dim,
12859 num_idx,
12860 trailing,
12861 );
12862}
12863
12864pub unsafe fn execute_dequant_matmul_gguf_f32(
12866 x: usize,
12867 w_q: usize,
12868 dst: usize,
12869 m: usize,
12870 k: usize,
12871 n: usize,
12872 scheme: rlx_ir::quant::QuantScheme,
12873 base: *mut u8,
12874) {
12875 unsafe {
12876 let block_bytes = scheme.gguf_block_bytes() as usize;
12877 let block_elems = scheme.gguf_block_size() as usize;
12878 let total_bytes = (k * n) / block_elems * block_bytes;
12879 let xs = sl(x, base, m * k);
12880 let w_bytes = std::slice::from_raw_parts(base.add(w_q) as *const u8, total_bytes);
12881 let out = sl_mut(dst, base, m * n);
12882 crate::gguf_matmul::gguf_matmul_bt(xs, w_bytes, out, m, k, n, scheme);
12883 }
12884}
12885
12886pub unsafe fn execute_dequant_grouped_matmul_gguf_f32(
12888 input: usize,
12889 w_q: usize,
12890 expert_idx: usize,
12891 dst: usize,
12892 m: usize,
12893 k: usize,
12894 n: usize,
12895 num_experts: usize,
12896 scheme: rlx_ir::quant::QuantScheme,
12897 base: *mut u8,
12898) {
12899 unsafe {
12900 let block_bytes = scheme.gguf_block_bytes() as usize;
12901 let block_elems = scheme.gguf_block_size() as usize;
12902 let slab_bytes = (k * n) / block_elems * block_bytes;
12903 let xs = sl(input, base, m * k);
12904 let w_bytes =
12905 std::slice::from_raw_parts(base.add(w_q) as *const u8, num_experts * slab_bytes);
12906 let ids = sl(expert_idx, base, m);
12907 let out = sl_mut(dst, base, m * n);
12908 crate::gguf_matmul::gguf_grouped_matmul_bt(
12909 xs,
12910 w_bytes,
12911 ids,
12912 out,
12913 m,
12914 k,
12915 n,
12916 num_experts,
12917 scheme,
12918 );
12919 }
12920}
12921
12922pub unsafe fn execute_dequant_matmul_int4_f32(
12924 x: usize,
12925 w_q: usize,
12926 scale: usize,
12927 zp: usize,
12928 dst: usize,
12929 m: usize,
12930 k: usize,
12931 n: usize,
12932 block_size: u32,
12933 is_asymmetric: bool,
12934 base: *mut u8,
12935) {
12936 let bs = block_size as usize;
12937 let n_blocks = k.div_ceil(bs);
12938 unsafe {
12939 let xs = sl(x, base, m * k);
12940 let w_bytes = std::slice::from_raw_parts(base.add(w_q) as *const u8, (k * n).div_ceil(2));
12941 let scales = sl(scale, base, n_blocks * n);
12942 let zps = if is_asymmetric {
12943 sl(zp, base, n_blocks * n)
12944 } else {
12945 &[][..]
12946 };
12947 let out = sl_mut(dst, base, m * n);
12948 dequant_matmul_int4(xs, w_bytes, scales, zps, out, m, k, n, bs, is_asymmetric);
12949 }
12950}
12951
12952pub unsafe fn execute_dequant_matmul_fp8_f32(
12954 x: usize,
12955 w_q: usize,
12956 scale: usize,
12957 dst: usize,
12958 m: usize,
12959 k: usize,
12960 n: usize,
12961 e5m2: bool,
12962 base: *mut u8,
12963) {
12964 unsafe {
12965 let xs = sl(x, base, m * k);
12966 let w_bytes = std::slice::from_raw_parts(base.add(w_q) as *const u8, k * n);
12967 let scales = sl(scale, base, n);
12968 let out = sl_mut(dst, base, m * n);
12969 dequant_matmul_fp8(xs, w_bytes, scales, out, m, k, n, e5m2);
12970 }
12971}
12972
12973pub unsafe fn execute_dequant_matmul_nvfp4_f32(
12975 x: usize,
12976 w_q: usize,
12977 scale: usize,
12978 global_scale: usize,
12979 dst: usize,
12980 m: usize,
12981 k: usize,
12982 n: usize,
12983 base: *mut u8,
12984) {
12985 let n_scale = k.div_ceil(rlx_ir::NVFP4_GROUP_SIZE) * n;
12986 unsafe {
12987 let xs = sl(x, base, m * k);
12988 let w_bytes = std::slice::from_raw_parts(base.add(w_q) as *const u8, (k * n).div_ceil(2));
12989 let scale_bytes = std::slice::from_raw_parts(base.add(scale) as *const u8, n_scale);
12990 let gs = sl(global_scale, base, 1)[0];
12991 let out = sl_mut(dst, base, m * n);
12992 dequant_matmul_nvfp4(xs, w_bytes, scale_bytes, gs, out, m, k, n);
12993 }
12994}
12995
12996pub unsafe fn execute_gated_delta_net_f16(
12998 q: usize,
12999 k: usize,
13000 v: usize,
13001 g: usize,
13002 beta: usize,
13003 state: usize,
13004 dst: usize,
13005 batch: usize,
13006 seq: usize,
13007 heads: usize,
13008 state_size: usize,
13009 base: *mut u8,
13010) {
13011 use half::f16;
13012 unsafe {
13013 let read_f16 = |off: usize, len: usize| -> Vec<f32> {
13014 let raw = std::slice::from_raw_parts(base.add(off) as *const u8, len * 2);
13015 raw.chunks_exact(2)
13016 .map(|c| f16::from_le_bytes([c[0], c[1]]).to_f32())
13017 .collect()
13018 };
13019 let write_f16 = |off: usize, data: &[f32]| {
13020 let out = std::slice::from_raw_parts_mut(base.add(off), data.len() * 2);
13021 for (i, &v) in data.iter().enumerate() {
13022 let le = f16::from_f32(v).to_le_bytes();
13023 out[i * 2] = le[0];
13024 out[i * 2 + 1] = le[1];
13025 }
13026 };
13027
13028 let (b, s, h, n) = (batch, seq, heads, state_size);
13029 let q_f = read_f16(q, b * s * h * n);
13030 let k_f = read_f16(k, b * s * h * n);
13031 let v_f = read_f16(v, b * s * h * n);
13032 let g_f = read_f16(g, b * s * h);
13033 let b_f = read_f16(beta, b * s * h);
13034 let mut state_f = if state != 0 {
13035 read_f16(state, b * h * n * n)
13036 } else {
13037 vec![0f32; b * h * n * n]
13038 };
13039 let mut out_f = vec![0f32; b * s * h * n];
13040 let scale = 1.0f32 / (n as f32).sqrt();
13041 let mut sk_buf = vec![0f32; n];
13042 let mut owned_state = vec![0f32; h * n * n];
13043
13044 for bi in 0..b {
13045 let state_slice: &mut [f32] = if state != 0 {
13046 let start = bi * h * n * n;
13047 &mut state_f[start..start + h * n * n]
13048 } else {
13049 owned_state.fill(0.0);
13050 &mut owned_state
13051 };
13052
13053 for ti in 0..s {
13054 let qkv_step_base = bi * s * h * n + ti * h * n;
13055 let gb_step_base = bi * s * h + ti * h;
13056
13057 for hi in 0..h {
13058 let q_row = &q_f[qkv_step_base + hi * n..qkv_step_base + (hi + 1) * n];
13059 let k_row = &k_f[qkv_step_base + hi * n..qkv_step_base + (hi + 1) * n];
13060 let v_row = &v_f[qkv_step_base + hi * n..qkv_step_base + (hi + 1) * n];
13061 let g_t = g_f[gb_step_base + hi];
13062 let beta_t = b_f[gb_step_base + hi];
13063
13064 let s_base = hi * n * n;
13065 let s_mat = &mut state_slice[s_base..s_base + n * n];
13066
13067 let g_exp = g_t.exp();
13068 for st in s_mat.iter_mut() {
13069 *st *= g_exp;
13070 }
13071
13072 for j in 0..n {
13073 let mut acc = 0f32;
13074 for i in 0..n {
13075 acc += s_mat[i * n + j] * k_row[i];
13076 }
13077 sk_buf[j] = acc;
13078 }
13079
13080 for j in 0..n {
13081 sk_buf[j] = (v_row[j] - sk_buf[j]) * beta_t;
13082 }
13083
13084 for i in 0..n {
13085 let ki = k_row[i];
13086 for j in 0..n {
13087 s_mat[i * n + j] += ki * sk_buf[j];
13088 }
13089 }
13090
13091 let out_row = &mut out_f[qkv_step_base + hi * n..qkv_step_base + (hi + 1) * n];
13092 for j in 0..n {
13093 let mut acc = 0f32;
13094 for i in 0..n {
13095 acc += s_mat[i * n + j] * q_row[i];
13096 }
13097 out_row[j] = acc * scale;
13098 }
13099 }
13100 }
13101 }
13102
13103 write_f16(dst, &out_f);
13104 if state != 0 {
13105 write_f16(state, &state_f);
13106 }
13107 }
13108}
13109
13110pub unsafe fn execute_group_norm_nchw_f32(
13112 src: usize,
13113 g: usize,
13114 b: usize,
13115 dst: usize,
13116 n: usize,
13117 c: usize,
13118 h: usize,
13119 w: usize,
13120 num_groups: usize,
13121 eps: f32,
13122 base: *mut u8,
13123) {
13124 let plane = c * h * w;
13125 for ni in 0..n {
13126 let input = unsafe { sl(src + ni * plane * std::mem::size_of::<f32>(), base, plane) };
13127 let gamma = unsafe { sl(g, base, c) };
13128 let beta = unsafe { sl(b, base, c) };
13129 let output = unsafe { sl_mut(dst + ni * plane * std::mem::size_of::<f32>(), base, plane) };
13130 crate::kernels::group_norm_nchw(input, gamma, beta, output, 1, c, h, w, num_groups, eps);
13131 }
13132}
13133
13134pub unsafe fn execute_layer_norm2d_nchw_f32(
13136 src: usize,
13137 g: usize,
13138 b: usize,
13139 dst: usize,
13140 n: usize,
13141 c: usize,
13142 h: usize,
13143 w: usize,
13144 eps: f32,
13145 base: *mut u8,
13146) {
13147 let plane = c * h * w;
13148 unsafe {
13149 let input = sl(src, base, n * plane);
13150 let gamma = sl(g, base, c);
13151 let beta = sl(b, base, c);
13152 let output = sl_mut(dst, base, n * plane);
13153 crate::kernels::layer_norm2d_nchw(input, gamma, beta, output, n, c, h, w, eps);
13154 }
13155}
13156
13157pub unsafe fn execute_conv_transpose2d_nchw_f32(
13159 src: usize,
13160 weight: usize,
13161 dst: usize,
13162 n: usize,
13163 c_in: usize,
13164 h: usize,
13165 w_in: usize,
13166 c_out: usize,
13167 h_out: usize,
13168 w_out: usize,
13169 kh: usize,
13170 kw: usize,
13171 sh: usize,
13172 sw: usize,
13173 ph: usize,
13174 pw: usize,
13175 dh: usize,
13176 dw: usize,
13177 groups: usize,
13178 base: *mut u8,
13179) {
13180 let in_elems = n * c_in * h * w_in;
13181 let w_elems = c_in * (c_out / groups) * kh * kw;
13182 let out_elems = n * c_out * h_out * w_out;
13183 unsafe {
13184 let input = sl(src, base, in_elems);
13185 let wt = sl(weight, base, w_elems);
13186 let output = sl_mut(dst, base, out_elems);
13187 crate::kernels::conv_transpose2d_nchw(
13188 input, wt, output, n, c_in, h, w_in, c_out, h_out, w_out, kh, kw, sh, sw, ph, pw, dh,
13189 dw, groups,
13190 );
13191 }
13192}
13193
13194pub unsafe fn execute_resize_nearest_2x_f32(
13196 src: usize,
13197 dst: usize,
13198 n: usize,
13199 c: usize,
13200 h: usize,
13201 w: usize,
13202 base: *mut u8,
13203) {
13204 let in_plane = c * h * w;
13205 let out_plane = c * h * 2 * w * 2;
13206 for ni in 0..n {
13207 let input = unsafe {
13208 sl(
13209 src + ni * in_plane * std::mem::size_of::<f32>(),
13210 base,
13211 in_plane,
13212 )
13213 };
13214 let output = unsafe {
13215 sl_mut(
13216 dst + ni * out_plane * std::mem::size_of::<f32>(),
13217 base,
13218 out_plane,
13219 )
13220 };
13221 crate::kernels::resize_nearest_2x_nchw(input, output, c, h, w);
13222 }
13223}
13224
13225pub unsafe fn execute_axial_rope2d_f32(
13227 src: usize,
13228 dst: usize,
13229 batch: usize,
13230 seq: usize,
13231 hidden: usize,
13232 end_x: usize,
13233 end_y: usize,
13234 head_dim: usize,
13235 num_heads: usize,
13236 theta: f32,
13237 repeat_factor: usize,
13238 base: *mut u8,
13239) {
13240 let plane = seq * hidden;
13241 let plane_bytes = plane * std::mem::size_of::<f32>();
13242 for bi in 0..batch {
13243 let in_off = src + bi * plane_bytes;
13244 let input = unsafe { sl(in_off, base, plane) };
13245 let rotated = rlx_ir::ops::axial_rope2d::apply_axial_rope2d(
13246 input,
13247 num_heads,
13248 seq,
13249 head_dim,
13250 end_x,
13251 end_y,
13252 theta,
13253 repeat_factor,
13254 );
13255 let out_off = dst + bi * plane_bytes;
13256 let output = unsafe { sl_mut(out_off, base, plane) };
13257 output.copy_from_slice(&rotated);
13258 }
13259}
13260
13261pub unsafe fn execute_fft1d_f32(
13263 src: usize,
13264 dst: usize,
13265 outer: usize,
13266 n_complex: usize,
13267 inverse: bool,
13268 base: *mut u8,
13269) {
13270 let row_elems = 2 * n_complex;
13271 let mut re = vec![0f32; n_complex];
13272 let mut im = vec![0f32; n_complex];
13273 let mut scratch = if n_complex.is_power_of_two() {
13274 BluesteinScratchF32::empty()
13275 } else {
13276 BluesteinScratchF32::build(n_complex, inverse)
13277 };
13278 for o in 0..outer {
13279 let row_offset = src + o * row_elems * std::mem::size_of::<f32>();
13280 let s = unsafe { sl(row_offset, base, row_elems) };
13281 re.copy_from_slice(&s[..n_complex]);
13282 im.copy_from_slice(&s[n_complex..]);
13283 if n_complex.is_power_of_two() {
13284 fft_radix2_inplace_f32(&mut re, &mut im, inverse);
13285 } else {
13286 fft_bluestein_inplace_f32(&mut re, &mut im, inverse, &mut scratch);
13287 }
13288 let dst_offset = dst + o * row_elems * std::mem::size_of::<f32>();
13289 let d = unsafe { sl_mut(dst_offset, base, row_elems) };
13290 d[..n_complex].copy_from_slice(&re);
13291 d[n_complex..].copy_from_slice(&im);
13292 }
13293}
13294
/// In-place iterative radix-2 FFT on split real/imaginary f32 arrays.
///
/// `re` and `im` must have the same power-of-two length (checked with
/// `debug_assert!`). `inverse` selects the sign of the exponent; the inverse
/// transform is **not** scaled by `1/n`. Twiddle factors are advanced in f64
/// and narrowed to f32 for each butterfly.
fn fft_radix2_inplace_f32(re: &mut [f32], im: &mut [f32], inverse: bool) {
    let n = re.len();
    debug_assert_eq!(im.len(), n);
    debug_assert!(
        n.is_power_of_two(),
        "fft_radix2_f32: n={n} must be a power of two"
    );
    if n <= 1 {
        return;
    }

    // Bit-reversal permutation: `rev` tracks the bit-reversed counterpart of
    // `idx` and is advanced incrementally.
    let mut rev = 0usize;
    for idx in 1..n {
        let mut mask = n >> 1;
        while rev & mask != 0 {
            rev ^= mask;
            mask >>= 1;
        }
        rev ^= mask;
        if idx < rev {
            re.swap(idx, rev);
            im.swap(idx, rev);
        }
    }

    let direction = if inverse { 1.0_f64 } else { -1.0_f64 };
    let mut span = 2usize;
    while span <= n {
        let half = span / 2;
        let angle = direction * 2.0 * std::f64::consts::PI / (span as f64);
        let step_re = angle.cos();
        let step_im = angle.sin();

        for block in (0..n).step_by(span) {
            // Twiddle starts at 1 + 0i for every block and is rotated by the
            // stage step after each butterfly.
            let (mut tw_re, mut tw_im) = (1.0_f64, 0.0_f64);
            for k in 0..half {
                let lo = block + k;
                let hi = block + k + half;
                let c = tw_re as f32;
                let s = tw_im as f32;

                let t_re = c * re[hi] - s * im[hi];
                let t_im = c * im[hi] + s * re[hi];
                let (u_re, u_im) = (re[lo], im[lo]);
                re[lo] = u_re + t_re;
                im[lo] = u_im + t_im;
                re[hi] = u_re - t_re;
                im[hi] = u_im - t_im;

                let next_re = tw_re * step_re - tw_im * step_im;
                let next_im = tw_re * step_im + tw_im * step_re;
                tw_re = next_re;
                tw_im = next_im;
            }
        }
        span <<= 1;
    }
}
13356
/// In-place iterative radix-2 FFT on split real/imaginary f64 arrays.
///
/// `re` and `im` must have the same power-of-two length (checked with
/// `debug_assert!`). `inverse` selects the sign of the exponent; the inverse
/// transform is **not** scaled by `1/n`.
fn fft_radix2_inplace_f64(re: &mut [f64], im: &mut [f64], inverse: bool) {
    let n = re.len();
    debug_assert_eq!(im.len(), n);
    debug_assert!(
        n.is_power_of_two(),
        "fft_radix2: n={n} must be a power of two"
    );
    if n <= 1 {
        return;
    }

    // Bit-reversal permutation: `rev` tracks the bit-reversed counterpart of
    // `idx` and is advanced incrementally.
    let mut rev = 0usize;
    for idx in 1..n {
        let mut mask = n >> 1;
        while rev & mask != 0 {
            rev ^= mask;
            mask >>= 1;
        }
        rev ^= mask;
        if idx < rev {
            re.swap(idx, rev);
            im.swap(idx, rev);
        }
    }

    let direction = if inverse { 1.0 } else { -1.0 };
    let mut span = 2usize;
    while span <= n {
        let half = span / 2;
        let angle = direction * 2.0 * std::f64::consts::PI / (span as f64);
        let step_re = angle.cos();
        let step_im = angle.sin();

        for block in (0..n).step_by(span) {
            // Twiddle starts at 1 + 0i for every block and is rotated by the
            // stage step after each butterfly.
            let (mut tw_re, mut tw_im) = (1.0_f64, 0.0_f64);
            for k in 0..half {
                let lo = block + k;
                let hi = block + k + half;

                let t_re = tw_re * re[hi] - tw_im * im[hi];
                let t_im = tw_re * im[hi] + tw_im * re[hi];
                let (u_re, u_im) = (re[lo], im[lo]);
                re[lo] = u_re + t_re;
                im[lo] = u_im + t_im;
                re[hi] = u_re - t_re;
                im[hi] = u_im - t_im;

                let next_re = tw_re * step_re - tw_im * step_im;
                let next_im = tw_re * step_im + tw_im * step_re;
                tw_re = next_re;
                tw_im = next_im;
            }
        }
        span <<= 1;
    }
}
13418
13419struct BluesteinScratchF64 {
13423 m: usize,
13425 w_re: Vec<f64>,
13429 w_im: Vec<f64>,
13430 bf_re: Vec<f64>,
13433 bf_im: Vec<f64>,
13434 ar: Vec<f64>,
13436 ai: Vec<f64>,
13437}
13438
13439impl BluesteinScratchF64 {
13440 fn empty() -> Self {
13441 Self {
13442 m: 0,
13443 w_re: Vec::new(),
13444 w_im: Vec::new(),
13445 bf_re: Vec::new(),
13446 bf_im: Vec::new(),
13447 ar: Vec::new(),
13448 ai: Vec::new(),
13449 }
13450 }
13451
13452 fn build(n: usize, inverse: bool) -> Self {
13453 let m = if n <= 1 {
13456 1
13457 } else {
13458 (2 * n - 1).next_power_of_two()
13459 };
13460
13461 let mod_2n = (2 * n) as u64;
13464 let sign = if inverse { 1.0_f64 } else { -1.0_f64 };
13465 let mut w_re = vec![0.0_f64; n];
13466 let mut w_im = vec![0.0_f64; n];
13467 for k in 0..n {
13468 let k2 = (k as u64).wrapping_mul(k as u64) % mod_2n;
13469 let theta = sign * std::f64::consts::PI * (k2 as f64) / (n as f64);
13470 w_re[k] = theta.cos();
13471 w_im[k] = theta.sin();
13472 }
13473
13474 let mut bf_re = vec![0.0_f64; m];
13477 let mut bf_im = vec![0.0_f64; m];
13478 if n > 0 {
13479 bf_re[0] = w_re[0];
13480 bf_im[0] = -w_im[0];
13481 for k in 1..n {
13482 bf_re[k] = w_re[k];
13483 bf_im[k] = -w_im[k];
13484 bf_re[m - k] = w_re[k];
13485 bf_im[m - k] = -w_im[k];
13486 }
13487 }
13488 if m > 1 {
13489 fft_radix2_inplace_f64(&mut bf_re, &mut bf_im, false);
13490 }
13491
13492 Self {
13493 m,
13494 w_re,
13495 w_im,
13496 bf_re,
13497 bf_im,
13498 ar: vec![0.0_f64; m],
13499 ai: vec![0.0_f64; m],
13500 }
13501 }
13502}
13503
13504fn fft_bluestein_inplace_f64(
13513 re: &mut [f64],
13514 im: &mut [f64],
13515 _inverse: bool,
13516 s: &mut BluesteinScratchF64,
13517) {
13518 let n = re.len();
13519 debug_assert_eq!(im.len(), n);
13520 debug_assert_eq!(s.w_re.len(), n);
13521 if n <= 1 {
13522 return;
13523 }
13524 let m = s.m;
13525
13526 for k in 0..m {
13528 s.ar[k] = 0.0;
13529 s.ai[k] = 0.0;
13530 }
13531 for k in 0..n {
13532 s.ar[k] = re[k] * s.w_re[k] - im[k] * s.w_im[k];
13533 s.ai[k] = re[k] * s.w_im[k] + im[k] * s.w_re[k];
13534 }
13535
13536 fft_radix2_inplace_f64(&mut s.ar, &mut s.ai, false);
13538
13539 for k in 0..m {
13541 let ar = s.ar[k];
13542 let ai = s.ai[k];
13543 let br = s.bf_re[k];
13544 let bi = s.bf_im[k];
13545 s.ar[k] = ar * br - ai * bi;
13546 s.ai[k] = ar * bi + ai * br;
13547 }
13548
13549 fft_radix2_inplace_f64(&mut s.ar, &mut s.ai, true);
13552 let inv_m = 1.0 / (m as f64);
13553
13554 for k in 0..n {
13556 let yr = s.ar[k] * inv_m;
13557 let yi = s.ai[k] * inv_m;
13558 re[k] = yr * s.w_re[k] - yi * s.w_im[k];
13559 im[k] = yr * s.w_im[k] + yi * s.w_re[k];
13560 }
13561}
13562
13563struct BluesteinScratchF32 {
13567 m: usize,
13568 w_re: Vec<f32>,
13569 w_im: Vec<f32>,
13570 bf_re: Vec<f32>,
13571 bf_im: Vec<f32>,
13572 ar: Vec<f32>,
13573 ai: Vec<f32>,
13574}
13575
13576impl BluesteinScratchF32 {
13577 fn empty() -> Self {
13578 Self {
13579 m: 0,
13580 w_re: Vec::new(),
13581 w_im: Vec::new(),
13582 bf_re: Vec::new(),
13583 bf_im: Vec::new(),
13584 ar: Vec::new(),
13585 ai: Vec::new(),
13586 }
13587 }
13588
13589 fn build(n: usize, inverse: bool) -> Self {
13590 let m = if n <= 1 {
13591 1
13592 } else {
13593 (2 * n - 1).next_power_of_two()
13594 };
13595
13596 let mod_2n = (2 * n) as u64;
13597 let sign = if inverse { 1.0_f64 } else { -1.0_f64 };
13598 let mut w_re = vec![0.0_f32; n];
13599 let mut w_im = vec![0.0_f32; n];
13600 for k in 0..n {
13601 let k2 = (k as u64).wrapping_mul(k as u64) % mod_2n;
13602 let theta = sign * std::f64::consts::PI * (k2 as f64) / (n as f64);
13603 w_re[k] = theta.cos() as f32;
13604 w_im[k] = theta.sin() as f32;
13605 }
13606
13607 let mut bf_re = vec![0.0_f32; m];
13608 let mut bf_im = vec![0.0_f32; m];
13609 if n > 0 {
13610 bf_re[0] = w_re[0];
13611 bf_im[0] = -w_im[0];
13612 for k in 1..n {
13613 bf_re[k] = w_re[k];
13614 bf_im[k] = -w_im[k];
13615 bf_re[m - k] = w_re[k];
13616 bf_im[m - k] = -w_im[k];
13617 }
13618 }
13619 if m > 1 {
13620 fft_radix2_inplace_f32(&mut bf_re, &mut bf_im, false);
13621 }
13622
13623 Self {
13624 m,
13625 w_re,
13626 w_im,
13627 bf_re,
13628 bf_im,
13629 ar: vec![0.0_f32; m],
13630 ai: vec![0.0_f32; m],
13631 }
13632 }
13633}
13634
13635fn fft_bluestein_inplace_f32(
13636 re: &mut [f32],
13637 im: &mut [f32],
13638 _inverse: bool,
13639 s: &mut BluesteinScratchF32,
13640) {
13641 let n = re.len();
13642 debug_assert_eq!(im.len(), n);
13643 debug_assert_eq!(s.w_re.len(), n);
13644 if n <= 1 {
13645 return;
13646 }
13647 let m = s.m;
13648
13649 for k in 0..m {
13650 s.ar[k] = 0.0;
13651 s.ai[k] = 0.0;
13652 }
13653 for k in 0..n {
13654 s.ar[k] = re[k] * s.w_re[k] - im[k] * s.w_im[k];
13655 s.ai[k] = re[k] * s.w_im[k] + im[k] * s.w_re[k];
13656 }
13657
13658 fft_radix2_inplace_f32(&mut s.ar, &mut s.ai, false);
13659
13660 for k in 0..m {
13661 let ar = s.ar[k];
13662 let ai = s.ai[k];
13663 let br = s.bf_re[k];
13664 let bi = s.bf_im[k];
13665 s.ar[k] = ar * br - ai * bi;
13666 s.ai[k] = ar * bi + ai * br;
13667 }
13668
13669 fft_radix2_inplace_f32(&mut s.ar, &mut s.ai, true);
13670 let inv_m = 1.0_f32 / (m as f32);
13671
13672 for k in 0..n {
13673 let yr = s.ar[k] * inv_m;
13674 let yi = s.ai[k] * inv_m;
13675 re[k] = yr * s.w_re[k] - yi * s.w_im[k];
13676 im[k] = yr * s.w_im[k] + yi * s.w_re[k];
13677 }
13678}
13679
13680unsafe fn dispatch_custom_op(
13686 kernel: &dyn crate::op_registry::CpuKernel,
13687 inputs: &[(usize, u32, Shape)],
13688 out_off: usize,
13689 out_len: u32,
13690 out_shape: &Shape,
13691 attrs: &[u8],
13692 base: *mut u8,
13693) {
13694 use crate::op_registry::{CpuTensorMut, CpuTensorRef};
13695 use rlx_ir::DType;
13696
13697 macro_rules! build_in_view {
13702 ($shape:expr, $off:expr, $n:expr, $variant:ident, $rust_ty:ty) => {
13703 CpuTensorRef::$variant {
13704 data: unsafe { sl_typed::<$rust_ty>($off, base, $n) },
13705 shape: $shape,
13706 }
13707 };
13708 }
13709 macro_rules! build_out_view {
13710 ($variant:ident, $rust_ty:ty) => {
13711 CpuTensorMut::$variant {
13712 data: unsafe { sl_mut_typed::<$rust_ty>(out_off, base, out_len as usize) },
13713 shape: out_shape,
13714 }
13715 };
13716 }
13717
13718 let in_views: Vec<CpuTensorRef<'_>> = inputs
13719 .iter()
13720 .map(|(off, len, shape)| {
13721 let n = *len as usize;
13722 let off = *off;
13723 match shape.dtype() {
13724 DType::F32 => build_in_view!(shape, off, n, F32, f32),
13725 DType::F64 => build_in_view!(shape, off, n, F64, f64),
13726 DType::F16 => build_in_view!(shape, off, n, F16, half::f16),
13727 DType::BF16 => build_in_view!(shape, off, n, BF16, half::bf16),
13728 DType::I8 => build_in_view!(shape, off, n, I8, i8),
13729 DType::I16 => build_in_view!(shape, off, n, I16, i16),
13730 DType::I32 => build_in_view!(shape, off, n, I32, i32),
13731 DType::I64 => build_in_view!(shape, off, n, I64, i64),
13732 DType::U8 => build_in_view!(shape, off, n, U8, u8),
13733 DType::U32 => build_in_view!(shape, off, n, U32, u32),
13734 DType::Bool => build_in_view!(shape, off, n, Bool, u8),
13735 DType::C64 => panic!(
13739 "Op::Custom kernel input has DType::C64 — built-in \
13740 complex ops handle their own kernels; user-registered \
13741 ops don't yet see complex tensors"
13742 ),
13743 }
13744 })
13745 .collect();
13746
13747 let result = match out_shape.dtype() {
13748 DType::F32 => kernel.execute(&in_views, build_out_view!(F32, f32), attrs),
13749 DType::F64 => kernel.execute(&in_views, build_out_view!(F64, f64), attrs),
13750 DType::F16 => kernel.execute(&in_views, build_out_view!(F16, half::f16), attrs),
13751 DType::BF16 => kernel.execute(&in_views, build_out_view!(BF16, half::bf16), attrs),
13752 DType::I8 => kernel.execute(&in_views, build_out_view!(I8, i8), attrs),
13753 DType::I16 => kernel.execute(&in_views, build_out_view!(I16, i16), attrs),
13754 DType::I32 => kernel.execute(&in_views, build_out_view!(I32, i32), attrs),
13755 DType::I64 => kernel.execute(&in_views, build_out_view!(I64, i64), attrs),
13756 DType::U8 => kernel.execute(&in_views, build_out_view!(U8, u8), attrs),
13757 DType::U32 => kernel.execute(&in_views, build_out_view!(U32, u32), attrs),
13758 DType::Bool => kernel.execute(&in_views, build_out_view!(Bool, u8), attrs),
13759 DType::C64 => panic!("Op::Custom output DType::C64 not supported"),
13760 };
13761 if let Err(e) = result {
13762 panic!("Op::Custom('{}') CPU kernel failed: {e}", kernel.name());
13763 }
13764}
13765
/// Views `len` elements of `T` starting `offset` bytes past `base`.
///
/// An `offset` of `usize::MAX` is treated as "absent" and yields an empty
/// slice without touching `base`.
///
/// # Safety
/// Unless `offset == usize::MAX`, `base + offset` must be aligned for `T`
/// and address `len` initialised values that stay valid (and are not
/// mutated through another reference) for as long as the returned slice is
/// used.
#[inline(always)]
unsafe fn sl_typed<T>(offset: usize, base: *mut u8, len: usize) -> &'static [T] {
    if offset != usize::MAX {
        let start = unsafe { base.add(offset) }.cast::<T>();
        unsafe { std::slice::from_raw_parts(start, len) }
    } else {
        &[]
    }
}
13778
/// Mutably views `len` elements of `T` starting `offset` bytes past `base`.
///
/// # Safety
/// `base + offset` must be aligned for `T` and address `len` initialised
/// values that stay valid, and are not accessed through any other
/// reference, for as long as the returned slice is used.
#[inline(always)]
unsafe fn sl_mut_typed<T>(offset: usize, base: *mut u8, len: usize) -> &'static mut [T] {
    let start = unsafe { base.add(offset) }.cast::<T>();
    unsafe { std::slice::from_raw_parts_mut(start, len) }
}
13783
13784#[inline(always)]
13786fn apply_activation_inplace(d: &mut [f32], act: rlx_ir::op::Activation) {
13790 use rlx_ir::op::Activation;
13791 match act {
13792 Activation::Gelu => crate::kernels::par_gelu_inplace(d),
13793 Activation::GeluApprox => crate::kernels::par_gelu_approx_inplace(d),
13794 Activation::Silu => crate::kernels::par_silu_inplace(d),
13795 Activation::Relu => {
13796 for v in d.iter_mut() {
13797 *v = v.max(0.0);
13798 }
13799 }
13800 Activation::Sigmoid => {
13801 for v in d.iter_mut() {
13802 *v = 1.0 / (1.0 + (-*v).exp());
13803 }
13804 }
13805 Activation::Tanh => {
13806 for v in d.iter_mut() {
13807 *v = v.tanh();
13808 }
13809 }
13810 Activation::Exp => {
13811 for v in d.iter_mut() {
13812 *v = v.exp();
13813 }
13814 }
13815 Activation::Log => {
13816 for v in d.iter_mut() {
13817 *v = v.ln();
13818 }
13819 }
13820 Activation::Sqrt => {
13821 for v in d.iter_mut() {
13822 *v = v.sqrt();
13823 }
13824 }
13825 Activation::Rsqrt => {
13826 for v in d.iter_mut() {
13827 *v = 1.0 / v.sqrt();
13828 }
13829 }
13830 Activation::Neg => {
13831 for v in d.iter_mut() {
13832 *v = -*v;
13833 }
13834 }
13835 Activation::Abs => {
13836 for v in d.iter_mut() {
13837 *v = v.abs();
13838 }
13839 }
13840 Activation::Round => {
13841 for v in d.iter_mut() {
13842 *v = v.round();
13843 }
13844 }
13845 Activation::Sin => {
13846 for v in d.iter_mut() {
13847 *v = v.sin();
13848 }
13849 }
13850 Activation::Cos => {
13851 for v in d.iter_mut() {
13852 *v = v.cos();
13853 }
13854 }
13855 Activation::Tan => {
13856 for v in d.iter_mut() {
13857 *v = v.tan();
13858 }
13859 }
13860 Activation::Atan => {
13861 for v in d.iter_mut() {
13862 *v = v.atan();
13863 }
13864 }
13865 }
13866}
13867
/// Unfolds one `c_in x h x w` image into a column matrix for convolution.
///
/// `col` has `c_in * kh * kw` rows of `h_out * w_out` values. Row
/// `(ci * kh + ki) * kw + kj` holds, for every output position `(ho, wo)`,
/// the input sample at `(ci, ho * sh + ki * dh - ph, wo * sw + kj * dw_dil - pw)`,
/// or `0.0` when that coordinate lies outside the image. Every element of
/// `col` is written.
///
/// Buffer sizes are verified with `debug_assert_eq!` only.
#[allow(clippy::too_many_arguments)]
fn im2col(
    x: &[f32],
    col: &mut [f32],
    c_in: usize,
    h: usize,
    w: usize,
    h_out: usize,
    w_out: usize,
    kh: usize,
    kw: usize,
    sh: usize,
    sw: usize,
    ph: usize,
    pw: usize,
    dh: usize,
    dw_dil: usize,
) {
    let out_plane = h_out * w_out;
    debug_assert_eq!(col.len(), c_in * kh * kw * out_plane);
    debug_assert_eq!(x.len(), c_in * h * w);

    let valid_rows = 0..h as isize;
    let valid_cols = 0..w as isize;
    let pad_top = ph as isize;
    let pad_left = pw as isize;

    for ci in 0..c_in {
        for ki in 0..kh {
            for kj in 0..kw {
                let dst_base = (((ci * kh) + ki) * kw + kj) * out_plane;
                for ho in 0..h_out {
                    let dst_row = dst_base + ho * w_out;
                    let src_h = (ho * sh + ki * dh) as isize - pad_top;

                    // Whole output row falls into vertical padding.
                    if !valid_rows.contains(&src_h) {
                        for wo in 0..w_out {
                            col[dst_row + wo] = 0.0;
                        }
                        continue;
                    }

                    let src_row = (ci * h + src_h as usize) * w;
                    for wo in 0..w_out {
                        let src_w = (wo * sw + kj * dw_dil) as isize - pad_left;
                        col[dst_row + wo] = if valid_cols.contains(&src_w) {
                            x[src_row + src_w as usize]
                        } else {
                            0.0
                        };
                    }
                }
            }
        }
    }
}
13929
/// Scatter-adds a `[c_in * kh * kw, h_out * w_out]` column matrix back into a
/// `[c_in, h, w]` image.
///
/// For every column entry whose kernel tap lands inside the image, the value
/// is **added** to the corresponding element of `x`; entries that map into the
/// padding are skipped. `x` is accumulated into, not cleared, so the caller
/// decides its initial contents.
///
/// * `sh`/`sw` — stride, `ph`/`pw` — padding, `dh`/`dw_dil` — dilation,
///   each along height/width respectively.
///
/// In debug builds the lengths of `col` and `x` are checked against the
/// dimensions; in release a mismatch surfaces as an index panic.
#[allow(clippy::too_many_arguments)]
fn col2im(
    col: &[f32],
    x: &mut [f32],
    c_in: usize,
    h: usize,
    w: usize,
    h_out: usize,
    w_out: usize,
    kh: usize,
    kw: usize,
    sh: usize,
    sw: usize,
    ph: usize,
    pw: usize,
    dh: usize,
    dw_dil: usize,
) {
    /// Image coordinate touched by kernel tap `tap` at output position
    /// `out_pos`, or `None` when that coordinate falls outside `0..extent`.
    fn target_index(
        out_pos: usize,
        stride: usize,
        tap: usize,
        dilation: usize,
        pad: usize,
        extent: usize,
    ) -> Option<usize> {
        let pos = (out_pos * stride + tap * dilation) as isize - pad as isize;
        (pos >= 0 && pos < extent as isize).then_some(pos as usize)
    }

    let spatial = h_out * w_out;
    debug_assert_eq!(col.len(), c_in * kh * kw * spatial);
    debug_assert_eq!(x.len(), c_in * h * w);

    for ci in 0..c_in {
        for ki in 0..kh {
            for kj in 0..kw {
                let col_base = (((ci * kh) + ki) * kw + kj) * spatial;
                for ho in 0..h_out {
                    // Rows that map into vertical padding contribute nothing.
                    let Some(hi) = target_index(ho, sh, ki, dh, ph, h) else {
                        continue;
                    };
                    let img_base = (ci * h + hi) * w;
                    let src_base = col_base + ho * w_out;
                    for wo in 0..w_out {
                        if let Some(wi) = target_index(wo, sw, kj, dw_dil, pw, w) {
                            x[img_base + wi] += col[src_base + wo];
                        }
                    }
                }
            }
        }
    }
}
13985
13986fn quant_layout(shape: &rlx_ir::Shape, axis: Option<usize>) -> (usize, usize, usize) {
13996 match axis {
13997 None => (0, 1, shape.num_elements().unwrap_or(0).max(1)),
13998 Some(d) => {
13999 let chan_dim = shape.dim(d).unwrap_static();
14000 let inner: usize = (d + 1..shape.rank())
14001 .map(|i| shape.dim(i).unwrap_static())
14002 .product::<usize>()
14003 .max(1);
14004 (d, chan_dim, inner)
14005 }
14006 }
14007}
14008
14009fn activation_backward_kernel(
14010 act: rlx_ir::op::Activation,
14011 xs: &[f32],
14012 dys: &[f32],
14013 out: &mut [f32],
14014) {
14015 use rlx_ir::op::Activation;
14016 let n = xs.len();
14017 debug_assert_eq!(dys.len(), n);
14018 debug_assert_eq!(out.len(), n);
14019 match act {
14020 Activation::Relu => {
14021 for i in 0..n {
14022 out[i] = if xs[i] > 0.0 { dys[i] } else { 0.0 };
14023 }
14024 }
14025 Activation::Sigmoid => {
14026 for i in 0..n {
14027 let s = 1.0 / (1.0 + (-xs[i]).exp());
14028 out[i] = s * (1.0 - s) * dys[i];
14029 }
14030 }
14031 Activation::Tanh => {
14032 for i in 0..n {
14033 let t = xs[i].tanh();
14034 out[i] = (1.0 - t * t) * dys[i];
14035 }
14036 }
14037 Activation::Silu => {
14038 for i in 0..n {
14040 let s = 1.0 / (1.0 + (-xs[i]).exp());
14041 out[i] = s * (1.0 + xs[i] * (1.0 - s)) * dys[i];
14042 }
14043 }
14044 Activation::Gelu => {
14045 const INV_SQRT2: f32 = 0.707_106_77;
14048 const INV_SQRT_2PI: f32 = 0.398_942_3;
14049 for i in 0..n {
14050 let x = xs[i];
14051 let phi = 0.5 * (1.0 + erf_f32(x * INV_SQRT2));
14052 let pdf = INV_SQRT_2PI * (-(x * x) * 0.5).exp();
14053 out[i] = (phi + x * pdf) * dys[i];
14054 }
14055 }
14056 Activation::GeluApprox => {
14057 const C: f32 = 0.797_884_6; const A: f32 = 0.044_715;
14061 for i in 0..n {
14062 let x = xs[i];
14063 let inner = C * (x + A * x * x * x);
14064 let t = inner.tanh();
14065 let dinner = C * (1.0 + 3.0 * A * x * x);
14066 let d = 0.5 * (1.0 + t) + 0.5 * x * (1.0 - t * t) * dinner;
14067 out[i] = d * dys[i];
14068 }
14069 }
14070 Activation::Exp => {
14071 for i in 0..n {
14072 out[i] = xs[i].exp() * dys[i];
14073 }
14074 }
14075 Activation::Log => {
14076 for i in 0..n {
14077 out[i] = dys[i] / xs[i];
14078 }
14079 }
14080 Activation::Sqrt => {
14081 for i in 0..n {
14083 let s = xs[i].sqrt();
14084 out[i] = if s > 0.0 { 0.5 * dys[i] / s } else { 0.0 };
14085 }
14086 }
14087 Activation::Rsqrt => {
14088 for i in 0..n {
14090 let s = xs[i].sqrt();
14091 out[i] = if s > 0.0 {
14092 -0.5 * dys[i] / (xs[i] * s)
14093 } else {
14094 0.0
14095 };
14096 }
14097 }
14098 Activation::Neg => {
14099 for i in 0..n {
14100 out[i] = -dys[i];
14101 }
14102 }
14103 Activation::Abs => {
14104 for i in 0..n {
14106 let x = xs[i];
14107 let s = if x > 0.0 {
14108 1.0
14109 } else if x < 0.0 {
14110 -1.0
14111 } else {
14112 0.0
14113 };
14114 out[i] = s * dys[i];
14115 }
14116 }
14117 Activation::Round => {
14118 out.copy_from_slice(dys);
14123 }
14124 Activation::Sin => {
14125 for i in 0..n {
14127 out[i] = xs[i].cos() * dys[i];
14128 }
14129 }
14130 Activation::Cos => {
14131 for i in 0..n {
14132 out[i] = -xs[i].sin() * dys[i];
14133 }
14134 }
14135 Activation::Tan => {
14136 for i in 0..n {
14138 let t = xs[i].tan();
14139 out[i] = (1.0 + t * t) * dys[i];
14140 }
14141 }
14142 Activation::Atan => {
14143 for i in 0..n {
14145 let x = xs[i];
14146 out[i] = dys[i] / (1.0 + x * x);
14147 }
14148 }
14149 }
14150}
14151
14152fn activation_backward_kernel_f64(
14156 act: rlx_ir::op::Activation,
14157 xs: &[f64],
14158 dys: &[f64],
14159 out: &mut [f64],
14160) {
14161 use rlx_ir::op::Activation;
14162 let n = xs.len();
14163 debug_assert_eq!(dys.len(), n);
14164 debug_assert_eq!(out.len(), n);
14165 match act {
14166 Activation::Relu => {
14167 for i in 0..n {
14168 out[i] = if xs[i] > 0.0 { dys[i] } else { 0.0 };
14169 }
14170 }
14171 Activation::Sigmoid => {
14172 for i in 0..n {
14173 let s = 1.0 / (1.0 + (-xs[i]).exp());
14174 out[i] = s * (1.0 - s) * dys[i];
14175 }
14176 }
14177 Activation::Tanh => {
14178 for i in 0..n {
14179 let t = xs[i].tanh();
14180 out[i] = (1.0 - t * t) * dys[i];
14181 }
14182 }
14183 Activation::Silu => {
14184 for i in 0..n {
14185 let s = 1.0 / (1.0 + (-xs[i]).exp());
14186 out[i] = s * (1.0 + xs[i] * (1.0 - s)) * dys[i];
14187 }
14188 }
14189 Activation::Gelu | Activation::GeluApprox => {
14190 const INV_SQRT2: f64 = std::f64::consts::FRAC_1_SQRT_2;
14192 const INV_SQRT_2PI: f64 = 0.398_942_280_401_432_7;
14193 for i in 0..n {
14194 let x = xs[i];
14195 let phi = 0.5 * (1.0 + erf_f64(x * INV_SQRT2));
14196 let pdf = INV_SQRT_2PI * (-(x * x) * 0.5).exp();
14197 out[i] = (phi + x * pdf) * dys[i];
14198 }
14199 }
14200 Activation::Exp => {
14201 for i in 0..n {
14202 out[i] = xs[i].exp() * dys[i];
14203 }
14204 }
14205 Activation::Log => {
14206 for i in 0..n {
14207 out[i] = dys[i] / xs[i];
14208 }
14209 }
14210 Activation::Sqrt => {
14211 for i in 0..n {
14212 let s = xs[i].sqrt();
14213 out[i] = if s > 0.0 { 0.5 * dys[i] / s } else { 0.0 };
14214 }
14215 }
14216 Activation::Rsqrt => {
14217 for i in 0..n {
14218 let s = xs[i].sqrt();
14219 out[i] = if s > 0.0 {
14220 -0.5 * dys[i] / (xs[i] * s)
14221 } else {
14222 0.0
14223 };
14224 }
14225 }
14226 Activation::Neg => {
14227 for i in 0..n {
14228 out[i] = -dys[i];
14229 }
14230 }
14231 Activation::Abs => {
14232 for i in 0..n {
14233 let x = xs[i];
14234 let s = if x > 0.0 {
14235 1.0
14236 } else if x < 0.0 {
14237 -1.0
14238 } else {
14239 0.0
14240 };
14241 out[i] = s * dys[i];
14242 }
14243 }
14244 Activation::Round => {
14245 out.copy_from_slice(dys);
14246 }
14247 Activation::Sin => {
14248 for i in 0..n {
14249 out[i] = xs[i].cos() * dys[i];
14250 }
14251 }
14252 Activation::Cos => {
14253 for i in 0..n {
14254 out[i] = -xs[i].sin() * dys[i];
14255 }
14256 }
14257 Activation::Tan => {
14258 for i in 0..n {
14259 let t = xs[i].tan();
14260 out[i] = (1.0 + t * t) * dys[i];
14261 }
14262 }
14263 Activation::Atan => {
14264 for i in 0..n {
14265 let x = xs[i];
14266 out[i] = dys[i] / (1.0 + x * x);
14267 }
14268 }
14269 }
14270}
14271
/// Polynomial approximation of the error function on `f64` inputs.
///
/// Evaluates `1 - (a1 t + a2 t^2 + ... + a5 t^5) * exp(-x^2)` with
/// `t = 1 / (1 + p |x|)` on the magnitude of `x`, then multiplies by
/// `x.signum()` to restore the sign. The coefficients carry only about eight
/// significant digits, so the result is far from full `f64` precision even
/// though the arithmetic is done in `f64`.
#[inline(always)]
fn erf_f64(x: f64) -> f64 {
    const P: f64 = 0.327_591_1;
    // Horner coefficients, highest power first.
    const LEAD: f64 = 1.061_405_43;
    const REST: [f64; 4] = [-1.453_152_03, 1.421_413_75, -0.284_496_74, 0.254_829_59];

    let sign = x.signum();
    let mag = x.abs();
    let t = 1.0 / (1.0 + P * mag);
    let poly = REST.iter().fold(LEAD, |acc, &coef| acc * t + coef);
    let y = 1.0 - poly * t * (-mag * mag).exp();
    sign * y
}
14288
/// Polynomial approximation of the error function on `f32` inputs.
///
/// Evaluates `1 - (a1 t + a2 t^2 + ... + a5 t^5) * exp(-x^2)` with
/// `t = 1 / (1 + p |x|)` on the magnitude of `x`, then multiplies by
/// `x.signum()` to restore the sign.
#[inline(always)]
fn erf_f32(x: f32) -> f32 {
    const P: f32 = 0.327_591_1;
    // Horner coefficients, highest power first.
    const LEAD: f32 = 1.061_405_4;
    const REST: [f32; 4] = [-1.453_152_1, 1.421_413_8, -0.284_496_74, 0.254_829_6];

    let sign = x.signum();
    let mag = x.abs();
    let t = 1.0 / (1.0 + P * mag);
    let poly = REST.iter().fold(LEAD, |acc, &coef| acc * t + coef);
    let y = 1.0 - poly * t * (-mag * mag).exp();
    sign * y
}
14303
14304fn narrow_thunk_closure(
14305 src: usize,
14306 dst: usize,
14307 outer: u32,
14308 src_stride: u32,
14309 dst_stride: u32,
14310 inner: u32,
14311 elem_bytes: u8,
14312) -> Arc<dyn Fn(*mut u8) + Send + Sync> {
14313 let (outer, ss, ds, inner) = (
14314 outer as usize,
14315 src_stride as usize,
14316 dst_stride as usize,
14317 inner as usize,
14318 );
14319 if elem_bytes == 8 {
14320 Arc::new(move |base: *mut u8| unsafe {
14321 let s = sl_f64(src, base, outer * ss);
14322 let d = sl_mut_f64(dst, base, outer * ds);
14323 for o in 0..outer {
14324 d[o * ds..o * ds + inner].copy_from_slice(&s[o * ss..o * ss + inner]);
14325 }
14326 })
14327 } else {
14328 Arc::new(move |base: *mut u8| unsafe {
14329 let s = sl(src, base, outer * ss);
14330 let d = sl_mut(dst, base, outer * ds);
14331 for o in 0..outer {
14332 d[o * ds..o * ds + inner].copy_from_slice(&s[o * ss..o * ss + inner]);
14333 }
14334 })
14335 }
14336}
14337
/// Views `len` `f32` values starting `offset` bytes past `base`.
///
/// An `offset` of `usize::MAX` is a sentinel: it yields an empty slice and
/// `base` is not touched.
///
/// Marked `#[inline(always)]` to match the other typed slice accessors in
/// this file.
///
/// # Safety
/// Unless `offset == usize::MAX`, `base + offset` must be aligned for `f32`
/// and point to `len` initialized `f32` values that remain valid, and are not
/// written through another reference, for as long as the returned slice is
/// used. The `'static` lifetime is asserted, not checked.
#[inline(always)]
unsafe fn sl(offset: usize, base: *mut u8, len: usize) -> &'static [f32] {
    if offset == usize::MAX {
        return &[];
    }
    // SAFETY: guaranteed by the caller per the contract above.
    unsafe { std::slice::from_raw_parts(base.add(offset) as *const f32, len) }
}
14344
/// Mutable view of `len` `f32` values starting `offset` bytes past `base`.
///
/// Unlike the read-only accessor there is no sentinel handling: `offset` is
/// always applied to `base`.
///
/// # Safety
/// `base + offset` must be aligned for `f32` and point to `len` `f32` values
/// that are valid for reads and writes and not aliased by any other live
/// reference while the returned slice is used. The `'static` lifetime is
/// asserted, not checked.
#[inline(always)]
unsafe fn sl_mut(offset: usize, base: *mut u8, len: usize) -> &'static mut [f32] {
    unsafe {
        let first = base.add(offset).cast::<f32>();
        std::slice::from_raw_parts_mut(first, len)
    }
}
14349
/// Views `len` `f64` values starting `offset` bytes past `base`.
///
/// An `offset` of `usize::MAX` is a sentinel: it yields an empty slice and
/// `base` is not touched.
///
/// # Safety
/// Unless `offset == usize::MAX`, `base + offset` must be aligned for `f64`
/// and point to `len` initialized `f64` values that remain valid, and are not
/// written through another reference, while the returned slice is used. The
/// `'static` lifetime is asserted, not checked.
#[inline(always)]
unsafe fn sl_f64(offset: usize, base: *mut u8, len: usize) -> &'static [f64] {
    match offset {
        usize::MAX => &[],
        off => unsafe { std::slice::from_raw_parts(base.add(off).cast::<f64>(), len) },
    }
}
14357
/// Mutable view of `len` `f64` values starting `offset` bytes past `base`.
///
/// There is no sentinel handling: `offset` is always applied to `base`.
///
/// # Safety
/// `base + offset` must be aligned for `f64` and point to `len` `f64` values
/// that are valid for reads and writes and not aliased by any other live
/// reference while the returned slice is used. The `'static` lifetime is
/// asserted, not checked.
#[inline(always)]
unsafe fn sl_mut_f64(offset: usize, base: *mut u8, len: usize) -> &'static mut [f64] {
    unsafe {
        let first = base.add(offset).cast::<f64>();
        std::slice::from_raw_parts_mut(first, len)
    }
}
14362
/// Views `len` `i32` values starting `offset` bytes past `base`.
///
/// An `offset` of `usize::MAX` is a sentinel: it yields an empty slice and
/// `base` is not touched.
///
/// # Safety
/// Unless `offset == usize::MAX`, `base + offset` must be aligned for `i32`
/// and point to `len` initialized `i32` values that remain valid, and are not
/// written through another reference, while the returned slice is used. The
/// `'static` lifetime is asserted, not checked.
#[allow(dead_code)]
#[inline(always)]
unsafe fn sl_i32(offset: usize, base: *mut u8, len: usize) -> &'static [i32] {
    if offset != usize::MAX {
        unsafe { std::slice::from_raw_parts(base.add(offset).cast::<i32>(), len) }
    } else {
        &[]
    }
}
14375
/// Mutable view of `len` `i32` values starting `offset` bytes past `base`.
///
/// There is no sentinel handling: `offset` is always applied to `base`.
///
/// # Safety
/// `base + offset` must be aligned for `i32` and point to `len` `i32` values
/// that are valid for reads and writes and not aliased by any other live
/// reference while the returned slice is used. The `'static` lifetime is
/// asserted, not checked.
#[allow(dead_code)]
#[inline(always)]
unsafe fn sl_mut_i32(offset: usize, base: *mut u8, len: usize) -> &'static mut [i32] {
    unsafe {
        let first = base.add(offset).cast::<i32>();
        std::slice::from_raw_parts_mut(first, len)
    }
}
14381
/// Views `len` `i64` values starting `offset` bytes past `base`.
///
/// An `offset` of `usize::MAX` is a sentinel: it yields an empty slice and
/// `base` is not touched.
///
/// # Safety
/// Unless `offset == usize::MAX`, `base + offset` must be aligned for `i64`
/// and point to `len` initialized `i64` values that remain valid, and are not
/// written through another reference, while the returned slice is used. The
/// `'static` lifetime is asserted, not checked.
#[allow(dead_code)]
#[inline(always)]
unsafe fn sl_i64(offset: usize, base: *mut u8, len: usize) -> &'static [i64] {
    match offset {
        usize::MAX => &[],
        off => unsafe { std::slice::from_raw_parts(base.add(off).cast::<i64>(), len) },
    }
}
14390
/// Mutable view of `len` `i64` values starting `offset` bytes past `base`.
///
/// There is no sentinel handling: `offset` is always applied to `base`.
///
/// # Safety
/// `base + offset` must be aligned for `i64` and point to `len` `i64` values
/// that are valid for reads and writes and not aliased by any other live
/// reference while the returned slice is used. The `'static` lifetime is
/// asserted, not checked.
#[allow(dead_code)]
#[inline(always)]
unsafe fn sl_mut_i64(offset: usize, base: *mut u8, len: usize) -> &'static mut [i64] {
    unsafe {
        let first = base.add(offset).cast::<i64>();
        std::slice::from_raw_parts_mut(first, len)
    }
}
14396
/// Gathers `inp` into `out` by walking the output index space in row-major order.
///
/// `out_dims` gives the extent of each output axis and `in_strides[d]` the
/// element stride in `inp` that corresponds to output axis `d`. For each
/// output element the source offset is the dot product of the current output
/// coordinate with `in_strides`. With an empty `out_dims` every output
/// element reads `inp[0]`.
fn transpose_walk_f64(inp: &[f64], out: &mut [f64], out_dims: &[u32], in_strides: &[u32]) {
    let rank = out_dims.len();
    let mut coord = vec![0u32; rank];
    for slot in out.iter_mut() {
        let src_off: usize = (0..rank)
            .map(|d| coord[d] as usize * in_strides[d] as usize)
            .sum();
        *slot = inp[src_off];

        // Odometer-style increment: bump the last axis, carrying leftwards on wrap.
        for (pos, &extent) in coord.iter_mut().zip(out_dims).rev() {
            *pos += 1;
            if *pos < extent {
                break;
            }
            *pos = 0;
        }
    }
}
14419
14420fn apply_activation_f64(inp: &[f64], out: &mut [f64], kind: Activation) {
14426 match kind {
14427 Activation::Neg => {
14428 for (o, &v) in out.iter_mut().zip(inp) {
14429 *o = -v;
14430 }
14431 }
14432 Activation::Exp => {
14433 for (o, &v) in out.iter_mut().zip(inp) {
14434 *o = v.exp();
14435 }
14436 }
14437 Activation::Log => {
14438 for (o, &v) in out.iter_mut().zip(inp) {
14439 *o = v.ln();
14440 }
14441 }
14442 Activation::Sqrt => {
14443 for (o, &v) in out.iter_mut().zip(inp) {
14444 *o = v.sqrt();
14445 }
14446 }
14447 Activation::Rsqrt => {
14448 for (o, &v) in out.iter_mut().zip(inp) {
14449 *o = 1.0 / v.sqrt();
14450 }
14451 }
14452 Activation::Abs => {
14453 for (o, &v) in out.iter_mut().zip(inp) {
14454 *o = v.abs();
14455 }
14456 }
14457 Activation::Tanh => {
14458 for (o, &v) in out.iter_mut().zip(inp) {
14459 *o = v.tanh();
14460 }
14461 }
14462 Activation::Sigmoid => {
14463 for (o, &v) in out.iter_mut().zip(inp) {
14464 *o = 1.0 / (1.0 + (-v).exp());
14465 }
14466 }
14467 Activation::Relu => {
14468 for (o, &v) in out.iter_mut().zip(inp) {
14469 *o = v.max(0.0);
14470 }
14471 }
14472 Activation::Round => {
14473 for (o, &v) in out.iter_mut().zip(inp) {
14474 *o = v.round_ties_even();
14475 }
14476 }
14477 Activation::Sin => {
14478 for (o, &v) in out.iter_mut().zip(inp) {
14479 *o = v.sin();
14480 }
14481 }
14482 Activation::Cos => {
14483 for (o, &v) in out.iter_mut().zip(inp) {
14484 *o = v.cos();
14485 }
14486 }
14487 Activation::Tan => {
14488 for (o, &v) in out.iter_mut().zip(inp) {
14489 *o = v.tan();
14490 }
14491 }
14492 Activation::Atan => {
14493 for (o, &v) in out.iter_mut().zip(inp) {
14494 *o = v.atan();
14495 }
14496 }
14497 Activation::Gelu | Activation::GeluApprox | Activation::Silu => {
14498 panic!(
14499 "apply_activation_f64: {kind:?} not yet implemented at f64. \
14500 Add when a workload needs it."
14501 );
14502 }
14503 }
14504}
14505
14506#[inline]
14507fn binary_op_f64(op: BinaryOp, a: f64, b: f64) -> f64 {
14508 match op {
14509 BinaryOp::Add => a + b,
14510 BinaryOp::Sub => a - b,
14511 BinaryOp::Mul => a * b,
14512 BinaryOp::Div => a / b,
14513 BinaryOp::Max => a.max(b),
14514 BinaryOp::Min => a.min(b),
14515 BinaryOp::Pow => a.powf(b),
14516 }
14517}
14518
/// Sums over the middle axis of a row-major `[outer, reduced, inner]` view of `inp`.
///
/// Writes the `[outer, inner]` result into `out`; element `(o, n)` is the
/// sum of `inp[o, r, n]` for `r` in `0..reduced`, accumulated in increasing
/// `r` starting from `0.0` (so an empty reduction yields `0.0`).
fn reduce_sum_f64(inp: &[f64], out: &mut [f64], outer: usize, reduced: usize, inner: usize) {
    for o in 0..outer {
        for n in 0..inner {
            out[o * inner + n] = (0..reduced)
                .fold(0.0_f64, |acc, r| acc + inp[o * reduced * inner + r * inner + n]);
        }
    }
}
14532
14533#[cfg(test)]
14534mod tests {
14535 use super::*;
14536 use rlx_ir::*;
14537
/// A `narrow` feeding `rope` must compile to a single `Rope` thunk that reads
/// the parent buffer with the parent's row stride, leaving no `Narrow` thunk.
#[test]
fn narrow_rope_fuses_in_unfused_path() {
    let dt = DType::F32;
    let mut graph = Graph::new("nr_fuse");
    let qkv = graph.input("qkv", Shape::new(&[16, 8, 192], dt));
    let cos = graph.input("cos", Shape::new(&[16], dt));
    let sin = graph.input("sin", Shape::new(&[16], dt));
    // Take the first 64 of the 192 features along the last axis, then rotate.
    let q = graph.narrow_(qkv, 2, 0, 64);
    let q_rope = graph.rope(q, cos, sin, 16);
    graph.set_outputs(vec![q_rope]);

    let arena = crate::arena::Arena::from_plan(rlx_opt::memory::plan_memory(&graph));
    let sched = compile_thunks(&graph, &arena);

    let narrow_count = sched
        .thunks
        .iter()
        .filter(|t| matches!(t, Thunk::Narrow { .. }))
        .count();
    // If several Rope thunks exist, the last one is the one inspected.
    let rope_with_stride: Option<u32> = sched
        .thunks
        .iter()
        .filter_map(|t| match t {
            Thunk::Rope { src_row_stride, .. } => Some(*src_row_stride),
            _ => None,
        })
        .last();

    assert_eq!(
        narrow_count, 0,
        "Narrow→Rope fusion should leave zero Narrow thunks; saw {narrow_count}"
    );
    assert_eq!(
        rope_with_stride,
        Some(192),
        "Rope's src_row_stride should be 192 (parent qkv axis), saw {rope_with_stride:?}"
    );
}
14581
/// The selective-scan thunk must reproduce a straightforward scalar
/// recurrence: `state = exp(d * a) * state + d * b * x`, `y = sum(c * state)`.
#[test]
fn ssm_selective_scan_matches_reference() {
    use rlx_ir::Philox4x32;
    let (bch, s, h, n) = (1usize, 4usize, 3usize, 2usize);

    // RNG draw order: x, delta, a, b, c.
    let mut rng = Philox4x32::new(13);
    let mut x = vec![0f32; bch * s * h];
    rng.fill_normal(&mut x);
    let delta: Vec<f32> = (0..bch * s * h)
        .map(|_| (rng.next_f32() - 0.5) * 0.1)
        .collect();
    // `a` is kept strictly negative.
    let a: Vec<f32> = (0..h * n)
        .map(|_| -(rng.next_f32() * 0.5 + 0.1))
        .collect();
    let mut b = vec![0f32; bch * s * n];
    rng.fill_normal(&mut b);
    let mut c = vec![0f32; bch * s * n];
    rng.fill_normal(&mut c);

    // Scalar reference.
    let mut expected = vec![0f32; bch * s * h];
    for bi in 0..bch {
        let mut state = vec![0f32; h * n];
        for si in 0..s {
            for ci in 0..h {
                let tok = bi * s * h + si * h + ci;
                let (d, xv) = (delta[tok], x[tok]);
                let mut acc = 0f32;
                for ni in 0..n {
                    let bc = bi * s * n + si * n + ni;
                    let cell = &mut state[ci * n + ni];
                    *cell = (d * a[ci * n + ni]).exp() * *cell + d * b[bc] * xv;
                    acc += c[bc] * *cell;
                }
                expected[tok] = acc;
            }
        }
    }

    let dt = DType::F32;
    let mut g = Graph::new("ssm");
    let xn = g.input("x", Shape::new(&[bch, s, h], dt));
    let dn = g.input("delta", Shape::new(&[bch, s, h], dt));
    let an = g.param("a", Shape::new(&[h, n], dt));
    let bn = g.param("b", Shape::new(&[bch, s, n], dt));
    let cn = g.param("c", Shape::new(&[bch, s, n], dt));
    let yn = g.selective_scan(xn, dn, an, bn, cn, n, Shape::new(&[bch, s, h], dt));
    g.set_outputs(vec![yn]);

    let mut arena = crate::arena::Arena::from_plan(rlx_opt::memory::plan_memory(&g));
    let sched = compile_thunks(&g, &arena);

    let xn_off = arena.byte_offset(xn);
    let dn_off = arena.byte_offset(dn);
    let an_off = arena.byte_offset(an);
    let bn_off = arena.byte_offset(bn);
    let cn_off = arena.byte_offset(cn);
    let yn_off = arena.byte_offset(yn);

    // Stage every operand at its planned offset.
    let buf = arena.raw_buf_mut();
    for (off, data) in [
        (xn_off, &x),
        (dn_off, &delta),
        (an_off, &a),
        (bn_off, &b),
        (cn_off, &c),
    ] {
        unsafe {
            std::ptr::copy_nonoverlapping(
                data.as_ptr(),
                buf.as_mut_ptr().add(off) as *mut f32,
                data.len(),
            );
        }
    }
    execute_thunks(&sched, arena.raw_buf_mut());

    let actual: Vec<f32> = unsafe {
        let p = arena.raw_buf().as_ptr().add(yn_off) as *const f32;
        std::slice::from_raw_parts(p, bch * s * h).to_vec()
    };

    for (i, (e, a)) in expected.iter().zip(&actual).enumerate() {
        assert!(
            (e - a).abs() < 1e-3,
            "mismatch at {i}: expected {e}, got {a}"
        );
    }
}
14677
/// A 1×1 convolution must be lowered to `Thunk::Conv2D1x1` (never the general
/// `Thunk::Conv2D`) and agree with a per-pixel channel dot product.
#[test]
fn conv_1x1_fast_path_matches_scalar() {
    use rlx_ir::Philox4x32;
    let (n, c_in, h, w, c_out) = (2usize, 4usize, 3usize, 3usize, 5usize);

    let mut rng = Philox4x32::new(31);
    let mut x = vec![0f32; n * c_in * h * w];
    rng.fill_normal(&mut x);
    let mut weight = vec![0f32; c_out * c_in];
    rng.fill_normal(&mut weight);

    // Scalar reference: each output pixel is a dot product over input channels.
    let mut expected = vec![0f32; n * c_out * h * w];
    for ni in 0..n {
        for co in 0..c_out {
            for px in 0..h * w {
                let mut acc = 0f32;
                for ci in 0..c_in {
                    acc += weight[co * c_in + ci] * x[((ni * c_in) + ci) * h * w + px];
                }
                expected[((ni * c_out) + co) * h * w + px] = acc;
            }
        }
    }

    let dt = DType::F32;
    let mut g = Graph::new("conv1x1");
    let xn = g.input("x", Shape::new(&[n, c_in, h, w], dt));
    let wn = g.param("w", Shape::new(&[c_out, c_in, 1, 1], dt));
    let conv_op = rlx_ir::Op::Conv {
        kernel_size: vec![1, 1],
        stride: vec![1, 1],
        padding: vec![0, 0],
        dilation: vec![1, 1],
        groups: 1,
    };
    let cn = g.add_node(conv_op, vec![xn, wn], Shape::new(&[n, c_out, h, w], dt));
    g.set_outputs(vec![cn]);

    let mut arena = crate::arena::Arena::from_plan(rlx_opt::memory::plan_memory(&g));
    let sched = compile_thunks(&g, &arena);

    let saw_fast = sched
        .thunks
        .iter()
        .any(|t| matches!(t, Thunk::Conv2D1x1 { .. }));
    let saw_slow = sched
        .thunks
        .iter()
        .any(|t| matches!(t, Thunk::Conv2D { .. }));
    assert!(saw_fast, "1×1 conv should emit Conv2D1x1");
    assert!(!saw_slow, "1×1 conv must not fall through to scalar Conv2D");

    let xn_off = arena.byte_offset(xn);
    let wn_off = arena.byte_offset(wn);
    let cn_off = arena.byte_offset(cn);

    // Stage input and weights at their planned offsets.
    let buf = arena.raw_buf_mut();
    for (off, data) in [(xn_off, &x), (wn_off, &weight)] {
        unsafe {
            std::ptr::copy_nonoverlapping(
                data.as_ptr(),
                buf.as_mut_ptr().add(off) as *mut f32,
                data.len(),
            );
        }
    }
    execute_thunks(&sched, arena.raw_buf_mut());

    let actual: Vec<f32> = unsafe {
        let p = arena.raw_buf().as_ptr().add(cn_off) as *const f32;
        std::slice::from_raw_parts(p, n * c_out * h * w).to_vec()
    };

    for (i, (e, a)) in expected.iter().zip(&actual).enumerate() {
        assert!(
            (e - a).abs() < 1e-3,
            "mismatch at {i}: expected {e}, got {a}"
        );
    }
}
14776
/// Block-quantized int8 matmul with all-zero zero-points must match
/// "dequantize to f32, then naive matmul".
#[test]
fn dequant_matmul_int8_sym_matches_reference() {
    use rlx_ir::Philox4x32;
    use rlx_ir::quant::QuantScheme;

    let (m, k, n) = (3usize, 8usize, 4usize);
    let block_size = 4usize;
    let blocks_per_col = k / block_size;

    let mut rng = Philox4x32::new(99);
    let mut x = vec![0f32; m * k];
    rng.fill_normal(&mut x);
    // Deterministic weights and per-(block, column) scales.
    let w_q: Vec<i8> = (0..(k * n))
        .map(|i| ((i as i32 * 13 + 7) % 127 - 63) as i8)
        .collect();
    let scales: Vec<f32> = (0..(blocks_per_col * n))
        .map(|i| 0.01 + 0.001 * i as f32)
        .collect();

    // Reference: dequantize, then multiply.
    let w_f32: Vec<f32> = (0..k * n)
        .map(|flat| {
            let (p, j) = (flat / n, flat % n);
            w_q[p * n + j] as f32 * scales[(p / block_size) * n + j]
        })
        .collect();
    let mut expected = vec![0f32; m * n];
    for i in 0..m {
        for j in 0..n {
            let mut acc = 0f32;
            for p in 0..k {
                acc += x[i * k + p] * w_f32[p * n + j];
            }
            expected[i * n + j] = acc;
        }
    }

    let dt = DType::F32;
    let mut g = Graph::new("dq");
    let xn = g.input("x", Shape::new(&[m, k], dt));
    let wn = g.param("w", Shape::new(&[k, n], DType::I8));
    let sn = g.param("scale", Shape::new(&[blocks_per_col, n], dt));
    let zn = g.param("zp", Shape::new(&[blocks_per_col, n], dt));
    let scheme = QuantScheme::Int8Block {
        block_size: block_size as u32,
    };
    let dq = g.dequant_matmul(xn, wn, sn, zn, scheme, Shape::new(&[m, n], dt));
    g.set_outputs(vec![dq]);

    let mut arena = crate::arena::Arena::from_plan(rlx_opt::memory::plan_memory(&g));
    let sched = compile_thunks(&g, &arena);

    let xn_off = arena.byte_offset(xn);
    let wn_off = arena.byte_offset(wn);
    let sn_off = arena.byte_offset(sn);
    let zn_off = arena.byte_offset(zn);
    let dq_off = arena.byte_offset(dq);

    let buf = arena.raw_buf_mut();
    unsafe {
        let base = buf.as_mut_ptr();
        std::ptr::copy_nonoverlapping(x.as_ptr(), base.add(xn_off) as *mut f32, x.len());
        std::ptr::copy_nonoverlapping(
            scales.as_ptr(),
            base.add(sn_off) as *mut f32,
            scales.len(),
        );
        // Zero-points are all zero.
        let zp = base.add(zn_off) as *mut f32;
        for i in 0..(blocks_per_col * n) {
            zp.add(i).write(0.0);
        }
        // Quantized weights are stored as raw i8.
        std::ptr::copy_nonoverlapping(w_q.as_ptr(), base.add(wn_off) as *mut i8, w_q.len());
    }
    execute_thunks(&sched, arena.raw_buf_mut());

    let actual: Vec<f32> = unsafe {
        let p = arena.raw_buf().as_ptr().add(dq_off) as *const f32;
        std::slice::from_raw_parts(p, m * n).to_vec()
    };

    for (i, (e, a)) in expected.iter().zip(&actual).enumerate() {
        assert!(
            (e - a).abs() < 1e-3,
            "mismatch at {i}: expected {e}, got {a}"
        );
    }
}
14884
/// The fused LoRA matmul must equal `x·w + scale · (x·a)·b` computed with
/// three separate naive matmuls.
#[test]
fn lora_matmul_matches_unfused_reference() {
    use rlx_ir::Philox4x32;

    /// Row-major `[rows, inner] x [inner, cols]` reference matmul.
    fn naive(lhs: &[f32], rhs: &[f32], rows: usize, inner: usize, cols: usize) -> Vec<f32> {
        let mut o = vec![0f32; rows * cols];
        for i in 0..rows {
            for j in 0..cols {
                let mut acc = 0f32;
                for p in 0..inner {
                    acc += lhs[i * inner + p] * rhs[p * cols + j];
                }
                o[i * cols + j] = acc;
            }
        }
        o
    }

    let (m, k, n, r) = (4usize, 8usize, 6usize, 2usize);
    let scale = 0.5f32;

    // RNG draw order: x, w, a, b.
    let mut rng = Philox4x32::new(42);
    let mut x = vec![0f32; m * k];
    rng.fill_normal(&mut x);
    let mut w = vec![0f32; k * n];
    rng.fill_normal(&mut w);
    let mut a = vec![0f32; k * r];
    rng.fill_normal(&mut a);
    let mut b = vec![0f32; r * n];
    rng.fill_normal(&mut b);

    let mut expected = naive(&x, &w, m, k, n);
    let xa = naive(&x, &a, m, k, r);
    let xab = naive(&xa, &b, m, r, n);
    for (base, &low_rank) in expected.iter_mut().zip(&xab) {
        *base += scale * low_rank;
    }

    let dt = DType::F32;
    let mut g = Graph::new("lora");
    let xn = g.input("x", Shape::new(&[m, k], dt));
    let wn = g.param("w", Shape::new(&[k, n], dt));
    let an = g.param("a", Shape::new(&[k, r], dt));
    let bn = g.param("b", Shape::new(&[r, n], dt));
    let lm = g.lora_matmul(xn, wn, an, bn, scale, Shape::new(&[m, n], dt));
    g.set_outputs(vec![lm]);

    let mut arena = crate::arena::Arena::from_plan(rlx_opt::memory::plan_memory(&g));
    let sched = compile_thunks(&g, &arena);

    let xn_off = arena.byte_offset(xn);
    let wn_off = arena.byte_offset(wn);
    let an_off = arena.byte_offset(an);
    let bn_off = arena.byte_offset(bn);
    let lm_off = arena.byte_offset(lm);

    // Stage every operand at its planned offset.
    let buf = arena.raw_buf_mut();
    for (off, data) in [(xn_off, &x), (wn_off, &w), (an_off, &a), (bn_off, &b)] {
        unsafe {
            std::ptr::copy_nonoverlapping(
                data.as_ptr(),
                buf.as_mut_ptr().add(off) as *mut f32,
                data.len(),
            );
        }
    }
    execute_thunks(&sched, arena.raw_buf_mut());

    let actual: Vec<f32> = unsafe {
        let p = arena.raw_buf().as_ptr().add(lm_off) as *const f32;
        std::slice::from_raw_parts(p, m * n).to_vec()
    };

    for (i, (e, a)) in expected.iter().zip(&actual).enumerate() {
        assert!(
            (e - a).abs() < 1e-3,
            "mismatch at {i}: expected {e}, got {a}"
        );
    }
}
14974
/// With one logit far above the rest, the `sample` node built here must
/// return that logit's index.
#[test]
fn sample_temperature_zero_is_argmax() {
    let dt = DType::F32;
    let mut g = Graph::new("samp");
    let logits = g.input("logits", Shape::new(&[1, 8], dt));
    let s = g.sample(logits, 0, 1.0, 1e-3, 42, Shape::new(&[1], dt));
    g.set_outputs(vec![s]);

    let mut arena = crate::arena::Arena::from_plan(rlx_opt::memory::plan_memory(&g));
    let sched = compile_thunks(&g, &arena);

    let logits_off = arena.byte_offset(logits);
    let s_off = arena.byte_offset(s);

    // Index 5 dominates.
    let inputs = [0.1f32, 0.2, 0.3, 0.4, 0.5, 9.0, 0.7, 0.8];
    let buf = arena.raw_buf_mut();
    unsafe {
        std::ptr::copy_nonoverlapping(
            inputs.as_ptr(),
            buf.as_mut_ptr().add(logits_off) as *mut f32,
            inputs.len(),
        );
    }
    execute_thunks(&sched, arena.raw_buf_mut());

    // The sampled token id is stored as an f32.
    let token = unsafe { *(arena.raw_buf().as_ptr().add(s_off) as *const f32) } as usize;
    assert_eq!(token, 5, "low-temp sampling should pick the argmax");
}
15008
/// With the second `sample` argument set to 1, the node must return the
/// index of the largest logit.
#[test]
fn sample_top_k_one_is_deterministic() {
    let dt = DType::F32;
    let mut g = Graph::new("samp_k1");
    let logits = g.input("logits", Shape::new(&[1, 4], dt));
    let s = g.sample(logits, 1, 1.0, 1.0, 7, Shape::new(&[1], dt));
    g.set_outputs(vec![s]);

    let mut arena = crate::arena::Arena::from_plan(rlx_opt::memory::plan_memory(&g));
    let sched = compile_thunks(&g, &arena);

    let logits_off = arena.byte_offset(logits);
    let s_off = arena.byte_offset(s);

    // Index 1 holds the maximum.
    let inputs = [0.1f32, 5.0, 0.3, 0.4];
    let buf = arena.raw_buf_mut();
    unsafe {
        std::ptr::copy_nonoverlapping(
            inputs.as_ptr(),
            buf.as_mut_ptr().add(logits_off) as *mut f32,
            inputs.len(),
        );
    }
    execute_thunks(&sched, arena.raw_buf_mut());

    // The sampled token id is stored as an f32.
    let token = unsafe { *(arena.raw_buf().as_ptr().add(s_off) as *const f32) } as usize;
    assert_eq!(token, 1);
}
15038
/// Cumulative sum along the last axis of a `[2, 4]` tensor must give the
/// running totals of each row independently.
#[test]
fn cumsum_inclusive_matches_naive() {
    let dt = DType::F32;
    let mut g = Graph::new("cumsum");
    let x = g.input("x", Shape::new(&[2, 4], dt));
    let cs = g.cumsum(x, -1, false, Shape::new(&[2, 4], dt));
    g.set_outputs(vec![cs]);

    let mut arena = crate::arena::Arena::from_plan(rlx_opt::memory::plan_memory(&g));
    let sched = compile_thunks(&g, &arena);

    let x_off = arena.byte_offset(x);
    let out_off = arena.byte_offset(cs);

    let inputs = [1.0f32, 2.0, 3.0, 4.0, 10.0, 20.0, 30.0, 40.0];
    let buf = arena.raw_buf_mut();
    unsafe {
        std::ptr::copy_nonoverlapping(
            inputs.as_ptr(),
            buf.as_mut_ptr().add(x_off) as *mut f32,
            inputs.len(),
        );
    }
    execute_thunks(&sched, arena.raw_buf_mut());

    let out: Vec<f32> = unsafe {
        let p = arena.raw_buf().as_ptr().add(out_off) as *const f32;
        std::slice::from_raw_parts(p, 8).to_vec()
    };
    assert_eq!(out, vec![1.0, 3.0, 6.0, 10.0, 10.0, 30.0, 60.0, 100.0]);
}
15070
/// Three `narrow`s of one packed QKV tensor feeding `attention` must compile
/// to an `Attention` thunk that reads Q/K/V with the parent row stride,
/// leaving no `Narrow` thunks.
#[test]
fn narrow_attention_fuses_in_unfused_path() {
    let dt = DType::F32;
    let mut graph = Graph::new("nattn_fuse");
    let qkv = graph.input("qkv", Shape::new(&[8, 16, 192], dt));
    let mask = graph.input("mask", Shape::new(&[8, 16], dt));
    // Q, K and V are consecutive 64-wide slices of the last axis.
    let q = graph.narrow_(qkv, 2, 0, 64);
    let k = graph.narrow_(qkv, 2, 64, 64);
    let v = graph.narrow_(qkv, 2, 128, 64);
    let attn = graph.attention(q, k, v, mask, 4, 16, Shape::new(&[8, 16, 64], dt));
    graph.set_outputs(vec![attn]);

    let arena = crate::arena::Arena::from_plan(rlx_opt::memory::plan_memory(&graph));
    let sched = compile_thunks(&graph, &arena);

    let narrow_count = sched
        .thunks
        .iter()
        .filter(|t| matches!(t, Thunk::Narrow { .. }))
        .count();
    // If several Attention thunks exist, the last one is the one inspected.
    let attn_strides: Option<(u32, u32, u32)> = sched
        .thunks
        .iter()
        .filter_map(|t| match t {
            Thunk::Attention {
                q_row_stride,
                k_row_stride,
                v_row_stride,
                ..
            } => Some((*q_row_stride, *k_row_stride, *v_row_stride)),
            _ => None,
        })
        .last();

    assert_eq!(
        narrow_count, 0,
        "Narrow×3→Attention fusion should eliminate all 3 narrows; saw {narrow_count}"
    );
    assert_eq!(
        attn_strides,
        Some((192, 192, 192)),
        "Attention should walk Q/K/V with parent row stride 192"
    );
}
15117
15118 fn run_graph(
15129 g: &Graph,
15130 inputs: &[(NodeId, &[f32])],
15131 out_id: NodeId,
15132 out_len: usize,
15133 ) -> Vec<f32> {
15134 let plan = rlx_opt::memory::plan_memory(g);
15135 let mut arena = crate::arena::Arena::from_plan(plan);
15136 let sched = compile_thunks(g, &arena);
15137 for &(id, data) in inputs {
15138 let off = arena.byte_offset(id);
15139 let buf = arena.raw_buf_mut();
15140 unsafe {
15141 let p = buf.as_mut_ptr().add(off) as *mut f32;
15142 for (i, &v) in data.iter().enumerate() {
15143 *p.add(i) = v;
15144 }
15145 }
15146 }
15147 execute_thunks(&sched, arena.raw_buf_mut());
15148 let off = arena.byte_offset(out_id);
15149 unsafe {
15150 let p = arena.raw_buf().as_ptr().add(off) as *const f32;
15151 (0..out_len).map(|i| *p.add(i)).collect()
15152 }
15153 }
15154
/// ReLU backward must forward `dy` where `x > 0` and produce zero everywhere else
/// (including at `x == 0`).
#[test]
fn relu_backward_matches_mask() {
    let dtype = DType::F32;
    let count = 7usize;
    let x: Vec<f32> = vec![-2.0, -0.1, 0.0, 0.1, 1.0, 3.0, -5.0];
    let dy: Vec<f32> = vec![0.5, 1.5, 2.5, -0.7, 4.0, -1.0, 9.0];

    let mut g = Graph::new("relu_bw");
    let x_node = g.input("x", Shape::new(&[count], dtype));
    let dy_node = g.input("dy", Shape::new(&[count], dtype));
    let dx_node = g.relu_backward(x_node, dy_node);
    g.set_outputs(vec![dx_node]);

    let actual = run_graph(&g, &[(x_node, &x), (dy_node, &dy)], dx_node, count);

    // Reference: gate the upstream gradient on the sign of the forward input.
    let mut expected: Vec<f32> = Vec::with_capacity(x.len());
    for (&xi, &dyi) in x.iter().zip(dy.iter()) {
        expected.push(if xi > 0.0 { dyi } else { 0.0 });
    }

    for (a, e) in actual.iter().zip(expected.iter()) {
        let err = (a - e).abs();
        assert!(err < 1e-6, "relu_bw mismatch: {a} vs {e}");
    }
}
15181
/// Max-pool backward on a 4x4 ramp (1..=16): each upstream gradient must land on the
/// largest element of its 2x2 window (flat indices 5, 7, 13, 15) and nowhere else.
#[test]
fn maxpool2d_backward_routes_to_argmax() {
    let dtype = DType::F32;
    let x: Vec<f32> = (1..=16).map(|v| v as f32).collect();
    let dy: Vec<f32> = vec![0.5, 1.0, 2.0, 4.0];

    let mut g = Graph::new("maxpool_bw");
    let x_node = g.input("x", Shape::new(&[1, 1, 4, 4], dtype));
    let dy_node = g.input("dy", Shape::new(&[1, 1, 2, 2], dtype));
    // Arguments are presumably kernel, stride, padding — confirm against the builder.
    let dx_node = g.maxpool2d_backward(x_node, dy_node, vec![2, 2], vec![2, 2], vec![0, 0]);
    g.set_outputs(vec![dx_node]);

    let actual = run_graph(&g, &[(x_node, &x), (dy_node, &dy)], dx_node, 16);

    let mut expected = vec![0f32; 16];
    for (slot, grad) in [(5, 0.5), (7, 1.0), (13, 2.0), (15, 4.0)] {
        expected[slot] = grad;
    }

    for (i, (a, e)) in actual.iter().zip(expected.iter()).enumerate() {
        let err = (a - e).abs();
        assert!(err < 1e-6, "maxpool_bw[{i}] mismatch: {a} vs {e}");
    }
}
15210
/// Compares `conv2d_backward_input` with a central-difference gradient of a naive
/// reference convolution (3x3 kernel, stride 1, padding 1), using `sum(forward(x) * dy)`
/// as the scalar objective.
#[test]
fn conv2d_backward_input_matches_numerical_gradient() {
    use rlx_ir::Philox4x32;

    /// Inner product over the common prefix of two slices.
    fn dot(a: &[f32], b: &[f32]) -> f32 {
        a.iter().zip(b).map(|(&u, &v)| u * v).sum()
    }

    let (n, c_in, h, w) = (1usize, 2usize, 4usize, 4usize);
    let c_out = 3usize;
    let (kh, kw) = (3usize, 3usize);
    let (ph, pw) = (1usize, 1usize);
    let (sh, sw) = (1usize, 1usize);
    let h_out = (h + 2 * ph - kh) / sh + 1;
    let w_out = (w + 2 * pw - kw) / sw + 1;
    assert_eq!(h_out, 4);
    assert_eq!(w_out, 4);

    // The fill order (x, weights, dy) determines the random data, so keep it fixed.
    let mut rng = Philox4x32::new(7);
    let mut x = vec![0f32; n * c_in * h * w];
    rng.fill_normal(&mut x);
    let mut wt = vec![0f32; c_out * c_in * kh * kw];
    rng.fill_normal(&mut wt);
    let mut dy = vec![0f32; n * c_out * h_out * w_out];
    rng.fill_normal(&mut dy);

    let dtype = DType::F32;
    let mut g = Graph::new("conv_bwi");
    let dy_in = g.input("dy", Shape::new(&[n, c_out, h_out, w_out], dtype));
    let w_in = g.input("w", Shape::new(&[c_out, c_in, kh, kw], dtype));
    // Trailing `vec![1, 1]` and `1` are presumably dilation and groups — confirm.
    let dx = g.conv2d_backward_input(
        dy_in,
        w_in,
        Shape::new(&[n, c_in, h, w], dtype),
        vec![kh, kw],
        vec![sh, sw],
        vec![ph, pw],
        vec![1, 1],
        1,
    );
    g.set_outputs(vec![dx]);
    let analytical = run_graph(&g, &[(dy_in, &dy), (w_in, &wt)], dx, n * c_in * h * w);

    // Naive NCHW reference convolution; taps that fall in the padding are skipped.
    let forward = |input: &[f32]| -> Vec<f32> {
        let mut out = vec![0f32; n * c_out * h_out * w_out];
        for ni in 0..n {
            for co in 0..c_out {
                let out_plane = (ni * c_out + co) * h_out * w_out;
                for ho in 0..h_out {
                    for wo in 0..w_out {
                        let mut acc = 0f32;
                        for ci in 0..c_in {
                            let x_plane = (ni * c_in + ci) * h * w;
                            let w_plane = (co * c_in + ci) * kh * kw;
                            for ki in 0..kh {
                                for kj in 0..kw {
                                    // Above/left of the image: inside the padding.
                                    let Some(hi) = (ho * sh + ki).checked_sub(ph) else {
                                        continue;
                                    };
                                    let Some(wi) = (wo * sw + kj).checked_sub(pw) else {
                                        continue;
                                    };
                                    // Below/right of the image: inside the padding.
                                    if hi >= h || wi >= w {
                                        continue;
                                    }
                                    let xv = input[x_plane + hi * w + wi];
                                    let wv = wt[w_plane + ki * kw + kj];
                                    acc += xv * wv;
                                }
                            }
                        }
                        out[out_plane + ho * w_out + wo] = acc;
                    }
                }
            }
        }
        out
    };

    // Central differences of the scalar objective with respect to each input element.
    let eps = 1e-3f32;
    let mut numerical: Vec<f32> = Vec::with_capacity(x.len());
    for i in 0..x.len() {
        let base = x[i];
        x[i] = base + eps;
        let up = dot(&forward(&x), &dy);
        x[i] = base - eps;
        let down = dot(&forward(&x), &dy);
        x[i] = base;
        numerical.push((up - down) / (2.0 * eps));
    }

    for (i, (a, n)) in analytical.iter().zip(numerical.iter()).enumerate() {
        let err = (a - n).abs();
        assert!(
            err < 5e-3,
            "conv_bw_input[{i}]: analytical {a} vs numerical {n}"
        );
    }
}
15315
/// Compares `conv2d_backward_weight` with a central-difference gradient of a naive
/// reference convolution (3x3 kernel, stride 1, no padding), using
/// `sum(forward(w) * dy)` as the scalar objective.
#[test]
fn conv2d_backward_weight_matches_numerical_gradient() {
    use rlx_ir::Philox4x32;

    /// Inner product over the common prefix of two slices.
    fn dot(a: &[f32], b: &[f32]) -> f32 {
        a.iter().zip(b).map(|(&u, &v)| u * v).sum()
    }

    let (n, c_in, h, w) = (2usize, 2usize, 4usize, 4usize);
    let c_out = 2usize;
    let (kh, kw) = (3usize, 3usize);
    let (ph, pw) = (0usize, 0usize);
    let (sh, sw) = (1usize, 1usize);
    let h_out = (h + 2 * ph - kh) / sh + 1;
    let w_out = (w + 2 * pw - kw) / sw + 1;

    // The fill order (x, weights, dy) determines the random data, so keep it fixed.
    let mut rng = Philox4x32::new(11);
    let mut x = vec![0f32; n * c_in * h * w];
    rng.fill_normal(&mut x);
    let mut wt = vec![0f32; c_out * c_in * kh * kw];
    rng.fill_normal(&mut wt);
    let mut dy = vec![0f32; n * c_out * h_out * w_out];
    rng.fill_normal(&mut dy);

    let dtype = DType::F32;
    let mut g = Graph::new("conv_bww");
    let x_node = g.input("x", Shape::new(&[n, c_in, h, w], dtype));
    let dy_node = g.input("dy", Shape::new(&[n, c_out, h_out, w_out], dtype));
    // Trailing `vec![1, 1]` and `1` are presumably dilation and groups — confirm.
    let dw_node = g.conv2d_backward_weight(
        x_node,
        dy_node,
        Shape::new(&[c_out, c_in, kh, kw], dtype),
        vec![kh, kw],
        vec![sh, sw],
        vec![ph, pw],
        vec![1, 1],
        1,
    );
    g.set_outputs(vec![dw_node]);
    let analytical = run_graph(
        &g,
        &[(x_node, &x), (dy_node, &dy)],
        dw_node,
        c_out * c_in * kh * kw,
    );

    // Naive NCHW reference convolution as a function of the weights only
    // (stride 1, no padding, so the input tap is simply `ho + ki`, `wo + kj`).
    let forward = |weights: &[f32]| -> Vec<f32> {
        let mut out = vec![0f32; n * c_out * h_out * w_out];
        for ni in 0..n {
            for co in 0..c_out {
                let out_plane = (ni * c_out + co) * h_out * w_out;
                for ho in 0..h_out {
                    for wo in 0..w_out {
                        let mut acc = 0f32;
                        for ci in 0..c_in {
                            let x_plane = (ni * c_in + ci) * h * w;
                            let w_plane = (co * c_in + ci) * kh * kw;
                            for ki in 0..kh {
                                for kj in 0..kw {
                                    let xv = x[x_plane + (ho + ki) * w + (wo + kj)];
                                    let wv = weights[w_plane + ki * kw + kj];
                                    acc += xv * wv;
                                }
                            }
                        }
                        out[out_plane + ho * w_out + wo] = acc;
                    }
                }
            }
        }
        out
    };

    // Central differences of the scalar objective with respect to each weight.
    let eps = 1e-3f32;
    let mut numerical: Vec<f32> = Vec::with_capacity(wt.len());
    for i in 0..wt.len() {
        let base = wt[i];
        wt[i] = base + eps;
        let up = dot(&forward(&wt), &dy);
        wt[i] = base - eps;
        let down = dot(&forward(&wt), &dy);
        wt[i] = base;
        numerical.push((up - down) / (2.0 * eps));
    }

    for (i, (a, n)) in analytical.iter().zip(numerical.iter()).enumerate() {
        let err = (a - n).abs();
        assert!(
            err < 5e-3,
            "conv_bw_weight[{i}]: analytical {a} vs numerical {n}"
        );
    }
}
15402
/// Per-row softmax cross-entropy loss must equal `logsumexp(row) - row[label]`,
/// with labels supplied as class indices stored in `f32`.
#[test]
fn softmax_cross_entropy_matches_reference() {
    let dtype = DType::F32;
    let logits: Vec<f32> = vec![
        1.0, 2.0, 3.0, //
        -1.0, 0.0, 4.0, //
        5.0, 5.0, 5.0,
    ];
    let labels: Vec<f32> = vec![2.0, 0.0, 1.0];

    let mut g = Graph::new("sce");
    let logits_node = g.input("logits", Shape::new(&[3, 3], dtype));
    let labels_node = g.input("labels", Shape::new(&[3], dtype));
    let loss = g.softmax_cross_entropy_with_logits(logits_node, labels_node);
    g.set_outputs(vec![loss]);
    let actual = run_graph(
        &g,
        &[(logits_node, &logits), (labels_node, &labels)],
        loss,
        3,
    );

    // Max-shifted log-sum-exp reference, one value per row of three logits.
    let expected: Vec<f32> = logits
        .chunks(3)
        .zip(labels.iter())
        .map(|(row, &label)| {
            let m = row.iter().fold(f32::NEG_INFINITY, |acc, &v| acc.max(v));
            let sum: f32 = row.iter().map(|&v| (v - m).exp()).sum();
            let lse = m + sum.ln();
            lse - row[label as usize]
        })
        .collect();

    for (i, (a, e)) in actual.iter().zip(expected.iter()).enumerate() {
        let err = (a - e).abs();
        assert!(err < 1e-5, "sce loss[{i}]: {a} vs {e}");
    }
}
15434
/// Compares `softmax_cross_entropy_backward` with a central-difference gradient of
/// `sum(loss(logits) * d_loss)`, where `loss` is the max-shifted log-sum-exp reference.
#[test]
fn softmax_cross_entropy_backward_matches_numerical_gradient() {
    use rlx_ir::Philox4x32;

    /// Inner product over the common prefix of two slices.
    fn dot(a: &[f32], b: &[f32]) -> f32 {
        a.iter().zip(b).map(|(&u, &v)| u * v).sum::<f32>()
    }

    let (n, c) = (4usize, 5usize);
    // Fill order (logits, then d_loss) determines the random data, so keep it fixed.
    let mut rng = Philox4x32::new(23);
    let mut logits = vec![0f32; n * c];
    rng.fill_normal(&mut logits);
    let labels: Vec<f32> = (0..n).map(|i| (i % c) as f32).collect();
    let mut d_loss = vec![0f32; n];
    rng.fill_normal(&mut d_loss);

    let dtype = DType::F32;
    let mut g = Graph::new("sce_bw");
    let logits_node = g.input("logits", Shape::new(&[n, c], dtype));
    let labels_node = g.input("labels", Shape::new(&[n], dtype));
    let d_loss_node = g.input("d_loss", Shape::new(&[n], dtype));
    let dlogits = g.softmax_cross_entropy_backward(logits_node, labels_node, d_loss_node);
    g.set_outputs(vec![dlogits]);
    let analytical = run_graph(
        &g,
        &[
            (logits_node, &logits),
            (labels_node, &labels),
            (d_loss_node, &d_loss),
        ],
        dlogits,
        n * c,
    );

    // Reference per-row loss: logsumexp(row) - row[label].
    let sce_loss = |values: &[f32]| -> Vec<f32> {
        values
            .chunks(c)
            .zip(labels.iter())
            .map(|(row, &label)| {
                let m = row.iter().fold(f32::NEG_INFINITY, |acc, &v| acc.max(v));
                let sum: f32 = row.iter().map(|&v| (v - m).exp()).sum();
                (m + sum.ln()) - row[label as usize]
            })
            .collect()
    };

    // Central differences of the scalar objective with respect to each logit.
    let eps = 1e-3f32;
    let mut numerical: Vec<f32> = Vec::with_capacity(logits.len());
    for i in 0..logits.len() {
        let base = logits[i];
        logits[i] = base + eps;
        let up = dot(&sce_loss(&logits), &d_loss);
        logits[i] = base - eps;
        let down = dot(&sce_loss(&logits), &d_loss);
        logits[i] = base;
        numerical.push((up - down) / (2.0 * eps));
    }

    for (i, (a, num)) in analytical.iter().zip(numerical.iter()).enumerate() {
        let err = (a - num).abs();
        assert!(
            err < 5e-3,
            "sce_bw[{i}]: analytical {a} vs numerical {num}"
        );
    }
}
15491
15492 fn fill_constants_into_arena(graph: &Graph, arena: &mut crate::arena::Arena) {
15505 for node in graph.nodes() {
15506 if let Op::Constant { data } = &node.op
15507 && arena.has_buffer(node.id)
15508 && !data.is_empty()
15509 {
15510 let buf = arena.slice_mut(node.id);
15511 let n_floats = data.len() / 4;
15512 let n = buf.len().min(n_floats);
15513 for i in 0..n {
15514 let bytes = [
15515 data[i * 4],
15516 data[i * 4 + 1],
15517 data[i * 4 + 2],
15518 data[i * 4 + 3],
15519 ];
15520 buf[i] = f32::from_le_bytes(bytes);
15521 }
15522 }
15523 }
15524 }
15525
15526 fn prepare(
15530 graph: &Graph,
15531 seed_inputs: &[(NodeId, &[f32])],
15532 ) -> (ThunkSchedule, crate::arena::Arena) {
15533 let plan = rlx_opt::memory::plan_memory(graph);
15534 let mut arena = crate::arena::Arena::from_plan(plan);
15535 let sched = compile_thunks(graph, &arena);
15536 fill_constants_into_arena(graph, &mut arena);
15537 for &(id, data) in seed_inputs {
15538 let off = arena.byte_offset(id);
15539 let buf = arena.raw_buf_mut();
15540 unsafe {
15541 let p = buf.as_mut_ptr().add(off) as *mut f32;
15542 for (i, &v) in data.iter().enumerate() {
15543 *p.add(i) = v;
15544 }
15545 }
15546 }
15547 (sched, arena)
15548 }
15549
15550 fn read_arena(arena: &crate::arena::Arena, id: NodeId, len: usize) -> Vec<f32> {
15551 let off = arena.byte_offset(id);
15552 unsafe {
15553 let p = arena.raw_buf().as_ptr().add(off) as *const f32;
15554 (0..len).map(|i| *p.add(i)).collect()
15555 }
15556 }
15557
15558 fn write_arena(arena: &mut crate::arena::Arena, id: NodeId, data: &[f32]) {
15559 let off = arena.byte_offset(id);
15560 let buf = arena.raw_buf_mut();
15561 unsafe {
15562 let p = buf.as_mut_ptr().add(off) as *mut f32;
15563 for (i, &v) in data.iter().enumerate() {
15564 *p.add(i) = v;
15565 }
15566 }
15567 }
15568
15569 fn prepare_f64(
15571 graph: &Graph,
15572 seed_inputs: &[(NodeId, &[f64])],
15573 ) -> (ThunkSchedule, crate::arena::Arena) {
15574 let plan = rlx_opt::memory::plan_memory(graph);
15575 let mut arena = crate::arena::Arena::from_plan(plan);
15576 let sched = compile_thunks(graph, &arena);
15577 fill_constants_into_arena(graph, &mut arena);
15578 for &(id, data) in seed_inputs {
15579 let off = arena.byte_offset(id);
15580 let buf = arena.raw_buf_mut();
15581 unsafe {
15582 let p = buf.as_mut_ptr().add(off) as *mut f64;
15583 for (i, &v) in data.iter().enumerate() {
15584 *p.add(i) = v;
15585 }
15586 }
15587 }
15588 (sched, arena)
15589 }
15590
15591 fn read_arena_f64(arena: &crate::arena::Arena, id: NodeId, len: usize) -> Vec<f64> {
15592 let off = arena.byte_offset(id);
15593 unsafe {
15594 let p = arena.raw_buf().as_ptr().add(off) as *const f64;
15595 (0..len).map(|i| *p.add(i)).collect()
15596 }
15597 }
15598
/// Solves the 2x2 system `[[2, 1], [1, 3]] x = [5, 10]` through the compiled graph and
/// expects `x = [1, 3]`.
#[test]
fn dense_solve_f64_end_to_end() {
    let mut g = Graph::new("solve_e2e");
    let a_node = g.input("A", Shape::new(&[2, 2], DType::F64));
    let b_node = g.input("b", Shape::new(&[2], DType::F64));
    let x_node = g.dense_solve(a_node, b_node, Shape::new(&[2], DType::F64));
    g.set_outputs(vec![x_node]);

    let a_data = [2.0, 1.0, 1.0, 3.0_f64];
    let b_data = [5.0, 10.0_f64];
    let (sched, mut arena) = prepare_f64(&g, &[(a_node, &a_data), (b_node, &b_data)]);
    execute_thunks(&sched, arena.raw_buf_mut());

    let got = read_arena_f64(&arena, x_node, 2);
    let want = [1.0, 3.0_f64];
    for (i, (g_i, w_i)) in got.iter().zip(want.iter()).enumerate() {
        let err = (g_i - w_i).abs();
        assert!(err < 1e-12, "x[{i}] = {} (expected {})", g_i, w_i);
    }
}
15632
/// Solves a 5x5 tridiagonal system (2 on the diagonal, -1 on both off-diagonals) and
/// verifies the solution by checking that `A * x` reproduces `b`.
#[test]
fn dense_solve_f64_5x5_laplacian() {
    let n = 5usize;
    let mut g = Graph::new("solve_5x5");
    let a_node = g.input("A", Shape::new(&[n, n], DType::F64));
    let b_node = g.input("b", Shape::new(&[n], DType::F64));
    let x_node = g.dense_solve(a_node, b_node, Shape::new(&[n], DType::F64));
    g.set_outputs(vec![x_node]);

    // Row-major tridiagonal matrix, filled one row at a time.
    let mut a_data = vec![0.0_f64; n * n];
    for (i, row) in a_data.chunks_mut(n).enumerate() {
        row[i] = 2.0;
        if i > 0 {
            row[i - 1] = -1.0;
        }
        if i + 1 < n {
            row[i + 1] = -1.0;
        }
    }
    let b_data: Vec<f64> = (1..=n).map(|v| v as f64).collect();

    let (sched, mut arena) = prepare_f64(&g, &[(a_node, &a_data), (b_node, &b_data)]);
    execute_thunks(&sched, arena.raw_buf_mut());
    let got = read_arena_f64(&arena, x_node, n);

    // residual = A * got, accumulated left to right along each row.
    let residual: Vec<f64> = a_data
        .chunks(n)
        .map(|row| {
            row.iter()
                .zip(got.iter())
                .fold(0.0_f64, |acc, (a_ij, x_j)| acc + a_ij * x_j)
        })
        .collect();

    for (i, (r_i, b_i)) in residual.iter().zip(b_data.iter()).enumerate() {
        let err = (r_i - b_i).abs();
        assert!(err < 1e-10, "row {i}: residual {} vs b {}", r_i, b_i);
    }
}
15679
/// End-to-end gradient check for `loss = sum(solve(A, b))` on a 3x3 system.
///
/// The backward graph from `grad_with_loss` is executed and its outputs
/// `[loss, dA, db]` are compared against references built with `crate::blas::dgesv`:
/// `x = A⁻¹ b`, `db = A⁻ᵀ 1`, and `dA = -db xᵀ`. `db` is additionally checked against
/// central finite differences of the loss with respect to `b`.
#[test]
fn hello_resistor_gradient_end_to_end() {
    use rlx_opt::autodiff::grad_with_loss;
    let n = 3usize;

    let mut g = Graph::new("hello_resistor");
    let a = g.param("A", Shape::new(&[n, n], DType::F64));
    let b = g.input("b", Shape::new(&[n], DType::F64));
    let x = g.dense_solve(a, b, Shape::new(&[n], DType::F64));
    let loss = g.reduce(
        x,
        ReduceOp::Sum,
        vec![0],
        false,
        Shape::new(&[1], DType::F64),
    );
    g.set_outputs(vec![loss]);

    let bwd = grad_with_loss(&g, &[a, b]);
    assert_eq!(bwd.outputs.len(), 3, "expect [loss, dA, db]");

    // Look up an input/param node of the backward graph by its name.
    let find_by_name = |graph: &Graph, want: &str| -> NodeId {
        graph
            .nodes()
            .iter()
            .find(|node| match &node.op {
                rlx_ir::Op::Input { name } | rlx_ir::Op::Param { name } => name.as_str() == want,
                _ => false,
            })
            .map(|node| node.id)
            .unwrap_or_else(|| panic!("no node named {want:?} in bwd graph"))
    };
    let a_bwd = find_by_name(&bwd, "A");
    let b_bwd = find_by_name(&bwd, "b");
    let d_out_bwd = find_by_name(&bwd, "d_output");

    let a_data = [2.0, 1.0, 0.0, 1.0, 3.0, 1.0, 0.0, 1.0, 2.0_f64];
    let b_data = [1.0, 2.0, 3.0_f64];
    let d_output = [1.0_f64];
    let (sched, mut arena) = prepare_f64(
        &bwd,
        &[(a_bwd, &a_data), (b_bwd, &b_data), (d_out_bwd, &d_output)],
    );
    execute_thunks(&sched, arena.raw_buf_mut());

    let loss_out = read_arena_f64(&arena, bwd.outputs[0], 1);
    let da_out = read_arena_f64(&arena, bwd.outputs[1], n * n);
    let db_out = read_arena_f64(&arena, bwd.outputs[2], n);

    // Reference forward solve: dgesv overwrites the right-hand side with x.
    let mut x_ref = b_data;
    {
        let mut lu = a_data;
        let info = crate::blas::dgesv(&mut lu, &mut x_ref, n, 1);
        assert_eq!(info, 0);
    }
    let loss_ref: f64 = x_ref.iter().sum();

    // Reference db: solve Aᵀ λ = 1 (the upstream gradient of a plain sum is all ones).
    let mut at = [0.0_f64; 9];
    for (flat, slot) in at.iter_mut().enumerate() {
        let (i, j) = (flat / n, flat % n);
        *slot = a_data[j * n + i];
    }
    let mut db_ref = [1.0_f64; 3];
    {
        let info = crate::blas::dgesv(&mut at, &mut db_ref, n, 1);
        assert_eq!(info, 0);
    }

    // Reference dA: negated outer product of db and x.
    let mut da_ref = [0.0_f64; 9];
    for (flat, slot) in da_ref.iter_mut().enumerate() {
        *slot = -db_ref[flat / n] * x_ref[flat % n];
    }

    assert!(
        (loss_out[0] - loss_ref).abs() < 1e-10,
        "loss: got {}, want {}",
        loss_out[0],
        loss_ref
    );
    for (i, (got, want)) in db_out.iter().zip(db_ref.iter()).enumerate() {
        assert!(
            (got - want).abs() < 1e-10,
            "db[{i}]: got {}, want {}",
            got,
            want
        );
    }
    for (i, (got, want)) in da_out.iter().zip(da_ref.iter()).enumerate() {
        assert!(
            (got - want).abs() < 1e-10,
            "dA[{i}]: got {}, want {}",
            got,
            want
        );
    }

    // Independent check of db via central finite differences on b.
    let h = 1e-6_f64;
    let loss_at = |mut rhs: [f64; 3]| -> f64 {
        let mut ac = a_data;
        let info = crate::blas::dgesv(&mut ac, &mut rhs, n, 1);
        assert_eq!(info, 0);
        rhs.iter().sum::<f64>()
    };
    for k in 0..n {
        let mut bp = b_data;
        bp[k] += h;
        let mut bm = b_data;
        bm[k] -= h;
        let lp = loss_at(bp);
        let lm = loss_at(bm);
        let fd = (lp - lm) / (2.0 * h);
        assert!(
            (db_out[k] - fd).abs() < 1e-7,
            "FD mismatch on db[{k}]: AD={} FD={}",
            db_out[k],
            fd
        );
    }
}
15843
/// A scan whose body computes `x + 0.1 * x`, run for 10 steps from all-ones, must end
/// at `1.1^10` in every lane.
#[test]
fn scan_geometric_growth_f64() {
    let n = 3usize;
    let length = 10u32;

    // Body graph: next = carry + carry * 0.1.
    let mut body = Graph::new("scan_body");
    let carry = body.input("carry", Shape::new(&[n], DType::F64));
    let scale = body.add_node(
        Op::Constant {
            data: 0.1_f64.to_le_bytes().repeat(n),
        },
        vec![],
        Shape::new(&[n], DType::F64),
    );
    let scaled = body.binary(BinaryOp::Mul, carry, scale, Shape::new(&[n], DType::F64));
    let next = body.binary(BinaryOp::Add, carry, scaled, Shape::new(&[n], DType::F64));
    body.set_outputs(vec![next]);

    // Outer graph: a single scan node fed by `init`.
    let mut g = Graph::new("scan_outer");
    let init = g.input("init", Shape::new(&[n], DType::F64));
    let final_carry = g.scan(init, body, length);
    g.set_outputs(vec![final_carry]);

    let init_data = vec![1.0_f64; n];
    let (sched, mut arena) = prepare_f64(&g, &[(init, &init_data)]);
    execute_thunks(&sched, arena.raw_buf_mut());
    let got = read_arena_f64(&arena, final_carry, n);

    let want: f64 = 1.1_f64.powi(length as i32);
    for (i, value) in got.iter().enumerate() {
        assert!(
            (value - want).abs() < 1e-12,
            "got[{i}] = {} want {}",
            value,
            want
        );
    }
}
15886
/// A scan whose body adds the per-step input `x_t` to the carry must end at the
/// column-wise sum of `xs` (starting from a zero carry).
#[test]
fn scan_with_xs_cumulative_sum() {
    let n = 3usize;
    let length = 4u32;

    // Body graph: next = carry + x_t.
    let mut body = Graph::new("cumsum_body");
    let carry = body.input("carry", Shape::new(&[n], DType::F64));
    let x_t = body.input("x_t", Shape::new(&[n], DType::F64));
    let next = body.binary(BinaryOp::Add, carry, x_t, Shape::new(&[n], DType::F64));
    body.set_outputs(vec![next]);

    let mut g = Graph::new("cumsum_outer");
    let init = g.input("init", Shape::new(&[n], DType::F64));
    let xs = g.input("xs", Shape::new(&[length as usize, n], DType::F64));
    let final_carry = g.scan_with_xs(init, &[xs], body, length);
    g.set_outputs(vec![final_carry]);

    let init_data = vec![0.0_f64; n];
    // xs holds 1, 2, 3, ... in row-major order, one row of `n` values per step.
    let xs_data: Vec<f64> = (1..=length as usize * n).map(|v| v as f64).collect();
    let (sched, mut arena) = prepare_f64(&g, &[(init, &init_data), (xs, &xs_data)]);
    execute_thunks(&sched, arena.raw_buf_mut());
    let got = read_arena_f64(&arena, final_carry, n);

    // Reference: fold every step's row into the running carry, in step order.
    let mut want = init_data.clone();
    for step_row in xs_data.chunks(n) {
        for (acc, value) in want.iter_mut().zip(step_row.iter()) {
            *acc += value;
        }
    }

    for (i, (g_i, w_i)) in got.iter().zip(want.iter()).enumerate() {
        assert!(
            (g_i - w_i).abs() < 1e-12,
            "got[{i}] = {} want {}",
            g_i,
            w_i
        );
    }
}
15935
/// Scan with a per-step drive: each step computes `carry = solve(M, carry + drive_t)`
/// with a constant tridiagonal `M` (`1 + 2·dt` on the diagonal, `-dt` off-diagonal).
///
/// The final carry is compared against a host-side loop that performs the same update
/// with `crate::blas::dgesv`. The reference solve's status is asserted to be zero so a
/// failed factorization cannot silently produce a bogus reference.
#[test]
fn scan_with_xs_be_with_drive() {
    let n = 3usize;
    let length = 4u32;
    let dt = 0.1_f64;

    // Row-major system matrix shared by the graph constant and the reference loop.
    let mut m_data = vec![0.0_f64; n * n];
    for i in 0..n {
        m_data[i * n + i] = 1.0 + dt * 2.0;
        if i > 0 {
            m_data[i * n + (i - 1)] = -dt;
        }
        if i + 1 < n {
            m_data[i * n + (i + 1)] = -dt;
        }
    }
    let m_bytes: Vec<u8> = m_data.iter().flat_map(|x| x.to_le_bytes()).collect();

    // Body graph: next = solve(M, carry + drive).
    let mut body = Graph::new("be_drive_body");
    let carry = body.input("carry", Shape::new(&[n], DType::F64));
    let drive = body.input("drive", Shape::new(&[n], DType::F64));
    let m = body.add_node(
        Op::Constant { data: m_bytes },
        vec![],
        Shape::new(&[n, n], DType::F64),
    );
    let driven = body.binary(BinaryOp::Add, carry, drive, Shape::new(&[n], DType::F64));
    let next = body.dense_solve(m, driven, Shape::new(&[n], DType::F64));
    body.set_outputs(vec![next]);

    let mut g = Graph::new("be_drive_outer");
    let init = g.input("init", Shape::new(&[n], DType::F64));
    let xs = g.input("xs", Shape::new(&[length as usize, n], DType::F64));
    let final_carry = g.scan_with_xs(init, &[xs], body, length);
    g.set_outputs(vec![final_carry]);

    let init_data = vec![0.0_f64; n];
    // Drive is zero everywhere except the first lane of the first step.
    let mut xs_data = vec![0.0_f64; length as usize * n];
    xs_data[0] = 1.0;

    let (sched, mut arena) = prepare_f64(&g, &[(init, &init_data), (xs, &xs_data)]);
    execute_thunks(&sched, arena.raw_buf_mut());
    let got = read_arena_f64(&arena, final_carry, n);

    // Host-side reference: same update, one dgesv per step (x is overwritten in place).
    let mut x = init_data.clone();
    for t in 0..length as usize {
        for j in 0..n {
            x[j] += xs_data[t * n + j];
        }
        let mut a_copy = m_data.clone();
        let info = crate::blas::dgesv(&mut a_copy, &mut x, n, 1);
        assert_eq!(info, 0, "reference dgesv failed at step {t}");
    }
    for i in 0..n {
        assert!(
            (got[i] - x[i]).abs() < 1e-12,
            "got[{i}] = {} ref {}",
            got[i],
            x[i]
        );
    }
}
16003
/// Gradient check for `loss = sum(batched_dense_solve(A, b))` over a batch of 3x3 systems.
///
/// For every batch entry the backward graph's `dA` and `db` are compared against
/// per-matrix references built with `crate::blas::dgesv`: `x = A⁻¹ b`, `db = A⁻ᵀ 1`,
/// and `dA = -db xᵀ`. Both reference solves assert a zero status, and the backward
/// graph is required to expose exactly `[loss, dA, db]` before its outputs are indexed.
#[test]
fn batched_dense_solve_gradient_matches_per_batch_analytic() {
    use rlx_opt::autodiff::grad_with_loss;
    let n = 3usize;
    let batch = 4usize;

    let mut g = Graph::new("bds_grad");
    let a = g.param("A", Shape::new(&[batch, n, n], DType::F64));
    let b = g.input("b", Shape::new(&[batch, n], DType::F64));
    let x = g.batched_dense_solve(a, b, Shape::new(&[batch, n], DType::F64));
    let loss = g.reduce(
        x,
        ReduceOp::Sum,
        vec![0, 1],
        false,
        Shape::new(&[1], DType::F64),
    );
    g.set_outputs(vec![loss]);

    let bwd = grad_with_loss(&g, &[a, b]);
    // Outputs 1 and 2 are read below; fail with a clear message if the layout differs.
    assert_eq!(bwd.outputs.len(), 3, "expect [loss, dA, db]");

    // Look up an input/param node of the backward graph by its name.
    let find = |graph: &Graph, want: &str| -> NodeId {
        for node in graph.nodes() {
            let name = match &node.op {
                Op::Input { name } | Op::Param { name } => Some(name.as_str()),
                _ => None,
            };
            if name == Some(want) {
                return node.id;
            }
        }
        panic!("no node named {want}");
    };
    let a_id = find(&bwd, "A");
    let b_id = find(&bwd, "b");
    let d_out_id = find(&bwd, "d_output");

    // Random matrices with a boosted diagonal (`1 + n` added to each diagonal entry);
    // the draw order (matrix rows, then rhs) fixes the data.
    let mut rng = rlx_ir::Philox4x32::new(0x57e1_u64);
    let mut a_data = vec![0.0_f64; batch * n * n];
    let mut b_data = vec![0.0_f64; batch * n];
    for bi in 0..batch {
        for i in 0..n {
            for j in 0..n {
                a_data[bi * n * n + i * n + j] = rng.next_f32() as f64 * 0.1;
            }
            a_data[bi * n * n + i * n + i] += 1.0 + n as f64;
        }
        for i in 0..n {
            b_data[bi * n + i] = rng.next_f32() as f64;
        }
    }
    let d_seed = [1.0_f64];

    let (sched, mut arena) = prepare_f64(
        &bwd,
        &[(a_id, &a_data), (b_id, &b_data), (d_out_id, &d_seed)],
    );
    execute_thunks(&sched, arena.raw_buf_mut());
    let da_out = read_arena_f64(&arena, bwd.outputs[1], batch * n * n);
    let db_out = read_arena_f64(&arena, bwd.outputs[2], batch * n);

    for bi in 0..batch {
        let a_block = &a_data[bi * n * n..(bi + 1) * n * n];

        // Forward reference: dgesv overwrites the right-hand side with x.
        let mut lu = a_block.to_vec();
        let mut x_ref = b_data[bi * n..(bi + 1) * n].to_vec();
        let info = crate::blas::dgesv(&mut lu, &mut x_ref, n, 1);
        assert_eq!(info, 0, "batch {bi}: reference forward solve failed");

        // Adjoint reference: solve Aᵀ λ = 1.
        let mut at = vec![0.0_f64; n * n];
        for i in 0..n {
            for j in 0..n {
                at[i * n + j] = a_block[j * n + i];
            }
        }
        let mut db_ref = vec![1.0_f64; n];
        let info = crate::blas::dgesv(&mut at, &mut db_ref, n, 1);
        assert_eq!(info, 0, "batch {bi}: reference adjoint solve failed");

        for i in 0..n {
            let got = db_out[bi * n + i];
            assert!(
                (got - db_ref[i]).abs() < 1e-10,
                "batch {bi}, db[{i}]: got {got} ref {}",
                db_ref[i]
            );
        }
        // dA reference: negated outer product of db and x.
        for i in 0..n {
            for j in 0..n {
                let got = da_out[bi * n * n + i * n + j];
                let want = -db_ref[i] * x_ref[j];
                assert!(
                    (got - want).abs() < 1e-10,
                    "batch {bi}, dA[{i},{j}]: got {got} ref {want}"
                );
            }
        }
    }
}
16109
/// The gradient of `sum(scan(init))` with respect to `init` must be the same whether
/// the scan is built with `scan` or with `scan_checkpointed` (2 checkpoints). The body
/// multiplies the carry by 1.05 at every one of the 6 steps.
#[test]
fn scan_checkpointed_grad_matches_plain_scan_grad() {
    use rlx_opt::autodiff::grad_with_loss;
    let n = 2usize;
    let length = 6u32;

    // Fresh body graph per call: next = carry * 1.05.
    let make_body = || {
        let mut body = Graph::new("ck_body");
        let carry = body.input("carry", Shape::new(&[n], DType::F64));
        let scale = body.add_node(
            Op::Constant {
                data: 1.05_f64.to_le_bytes().repeat(n),
            },
            vec![],
            Shape::new(&[n], DType::F64),
        );
        let next = body.binary(BinaryOp::Mul, carry, scale, Shape::new(&[n], DType::F64));
        body.set_outputs(vec![next]);
        body
    };

    // Variant 1: plain scan.
    let mut g_plain = Graph::new("ck_plain");
    let init_p = g_plain.input("init", Shape::new(&[n], DType::F64));
    let final_p = g_plain.scan(init_p, make_body(), length);
    let loss_p = g_plain.reduce(
        final_p,
        ReduceOp::Sum,
        vec![0],
        false,
        Shape::new(&[1], DType::F64),
    );
    g_plain.set_outputs(vec![loss_p]);
    let bwd_p = grad_with_loss(&g_plain, &[init_p]);

    // Variant 2: checkpointed scan.
    let mut g_ck = Graph::new("ck_ckpt");
    let init_c = g_ck.input("init", Shape::new(&[n], DType::F64));
    let final_c = g_ck.scan_checkpointed(init_c, make_body(), length, 2);
    let loss_c = g_ck.reduce(
        final_c,
        ReduceOp::Sum,
        vec![0],
        false,
        Shape::new(&[1], DType::F64),
    );
    g_ck.set_outputs(vec![loss_c]);
    let bwd_c = grad_with_loss(&g_ck, &[init_c]);

    // Look up an input/param node by its name.
    let find = |graph: &Graph, want: &str| -> NodeId {
        graph
            .nodes()
            .iter()
            .find(|node| match &node.op {
                Op::Input { name } | Op::Param { name } => name.as_str() == want,
                _ => false,
            })
            .map(|node| node.id)
            .unwrap_or_else(|| panic!("no {want}"))
    };

    let init_data = vec![0.5_f64, -0.5];
    let d_seed = [1.0_f64];

    // Runs one backward graph and returns its second output (read as `n` f64 values).
    let run_backward = |bwd: &Graph| -> Vec<f64> {
        let init_id = find(bwd, "init");
        let d_out_id = find(bwd, "d_output");
        let (sched, mut arena) = prepare_f64(bwd, &[(init_id, &init_data), (d_out_id, &d_seed)]);
        execute_thunks(&sched, arena.raw_buf_mut());
        read_arena_f64(&arena, bwd.outputs[1], n)
    };
    let dinit_p = run_backward(&bwd_p);
    let dinit_c = run_backward(&bwd_c);

    for (i, (plain, ckpt)) in dinit_p.iter().zip(dinit_c.iter()).enumerate() {
        assert!(
            (plain - ckpt).abs() < 1e-12,
            "dinit[{i}]: plain={} checkpointed={}",
            plain,
            ckpt
        );
    }
}
16207
/// `ScanBackward` must produce the same `d_init` whether the forward scan stored a full
/// trajectory or only `k = 2` checkpoints (with the forward body handed to the backward
/// op via `forward_body`). The body adds 1.0 to the carry at each of 4 steps.
#[test]
fn recursive_checkpointing_matches_full_trajectory() {
    let n = 2usize;
    let length = 4u32;

    // Fresh forward body per call: next = carry + 1.
    let build_body = || -> Graph {
        let mut body = Graph::new("rc_body");
        let carry = body.input("carry", Shape::new(&[n], DType::F64));
        let ones = body.add_node(
            Op::Constant {
                data: 1.0_f64.to_le_bytes().repeat(n),
            },
            vec![],
            Shape::new(&[n], DType::F64),
        );
        let next = body.binary(BinaryOp::Add, carry, ones, Shape::new(&[n], DType::F64));
        body.set_outputs(vec![next]);
        body
    };

    // VJP graph of the body with respect to its first `Input` node.
    let body_vjp_for = || -> Graph {
        use rlx_opt::autodiff::grad;
        let body = build_body();
        let carry_id = body
            .nodes()
            .iter()
            .find_map(|node| matches!(node.op, Op::Input { .. }).then_some(node.id))
            .unwrap();
        grad(&body, &[carry_id])
    };

    // Reference graph: trajectory scan followed by scan_backward.
    let mut g_full = Graph::new("rc_outer_full");
    let init_full = g_full.input("init", Shape::new(&[n], DType::F64));
    let traj_full_id = g_full.scan_trajectory(init_full, build_body(), length);
    let upstream_full = g_full.input("upstream", Shape::new(&[length as usize, n], DType::F64));
    let dinit_full_id = g_full.scan_backward(
        init_full,
        traj_full_id,
        upstream_full,
        &[],
        body_vjp_for(),
        length,
        true,
        Shape::new(&[n], DType::F64),
    );
    g_full.set_outputs(vec![dinit_full_id]);

    // Checkpointed graph: the scan output has `k` rows instead of `length`.
    let k = 2u32;
    let mut g_rec = Graph::new("rc_outer_rec");
    let init_rec = g_rec.input("init", Shape::new(&[n], DType::F64));
    let traj_rec_id = g_rec.add_node(
        Op::Scan {
            body: Box::new(build_body()),
            length,
            save_trajectory: true,
            num_bcast: 0,
            num_xs: 0,
            num_checkpoints: k,
        },
        vec![init_rec],
        Shape::new(&[k as usize, n], DType::F64),
    );
    let upstream_rec = g_rec.input("upstream", Shape::new(&[length as usize, n], DType::F64));
    let dinit_rec_id = g_rec.add_node(
        Op::ScanBackward {
            body_vjp: Box::new(body_vjp_for()),
            length,
            save_trajectory: true,
            num_xs: 0,
            num_checkpoints: k,
            forward_body: Some(Box::new(build_body())),
        },
        vec![init_rec, traj_rec_id, upstream_rec],
        Shape::new(&[n], DType::F64),
    );
    g_rec.set_outputs(vec![dinit_rec_id]);

    let init_data = vec![0.5_f64, -0.5];
    let upstream_data: Vec<f64> = (0..length as usize * n).map(|i| (i as f64) * 0.1).collect();

    // Look up an `Input` node by its name.
    let find = |graph: &Graph, want: &str| -> NodeId {
        graph
            .nodes()
            .iter()
            .find(|node| matches!(&node.op, Op::Input { name } if name == want))
            .map(|node| node.id)
            .unwrap_or_else(|| panic!("no input {want}"))
    };

    // Seeds `init`/`upstream`, executes the graph, and reads its first output.
    let run = |graph: &Graph| -> Vec<f64> {
        let init_id = find(graph, "init");
        let upstream_id = find(graph, "upstream");
        let (sched, mut arena) = prepare_f64(
            graph,
            &[(init_id, &init_data), (upstream_id, &upstream_data)],
        );
        execute_thunks(&sched, arena.raw_buf_mut());
        read_arena_f64(&arena, graph.outputs[0], n)
    };
    let dinit_full = run(&g_full);
    let dinit_rec = run(&g_rec);

    for (i, (full, rec)) in dinit_full.iter().zip(dinit_rec.iter()).enumerate() {
        assert!(
            (full - rec).abs() < 1e-12,
            "i={i}: full={} rec={}",
            full,
            rec
        );
    }
}
16344
#[test]
/// Builds the gradient graph of `sum(scan(init))` where the scan body adds a
/// constant vector of ones to the carry, batches that gradient graph over
/// `init` with `vmap`, and checks the batched `dinit`:
///   * every entry equals 1.0 (within 1e-12), and
///   * every row agrees with an un-batched gradient run fed that row alone.
fn vmap_of_grad_scan_matches_per_row_runs() {
    use rlx_opt::autodiff::grad_with_loss;
    use rlx_opt::vmap::vmap;

    /// Id of the `Input`/`Param` node called `want`; panics when none exists.
    fn find(graph: &Graph, want: &str) -> NodeId {
        graph
            .nodes()
            .into_iter()
            .find_map(|node| match &node.op {
                Op::Input { name } | Op::Param { name } if name.as_str() == want => Some(node.id),
                _ => None,
            })
            .unwrap_or_else(|| panic!("no node named {want}"))
    }

    let n = 2usize;
    let length = 3u32;
    let batch = 3usize;
    let vec_shape = || Shape::new(&[n], DType::F64);
    let scalar_shape = || Shape::new(&[1], DType::F64);

    // Scan body `carry -> carry + ones`; only the graph label differs between uses.
    let add_one_body = |label: &'static str| {
        let mut body = Graph::new(label);
        let carry = body.input("carry", vec_shape());
        let ones = body.add_node(
            Op::Constant {
                data: 1.0_f64.to_le_bytes().repeat(n),
            },
            vec![],
            vec_shape(),
        );
        let next = body.binary(BinaryOp::Add, carry, ones, vec_shape());
        body.set_outputs(vec![next]);
        body
    };

    let body = add_one_body("scan_grad_body");
    let mut g = Graph::new("scan_grad_outer");
    let init = g.input("init", vec_shape());
    let final_x = g.scan(init, body, length);
    let loss = g.reduce(final_x, ReduceOp::Sum, vec![0], false, scalar_shape());
    g.set_outputs(vec![loss]);

    // Differentiate first, then batch the resulting backward graph over `init`.
    let bwd = grad_with_loss(&g, &[init]);
    let bg = vmap(&bwd, &["init"], batch);

    let init_b = find(&bg, "init");
    let d_out_b = find(&bg, "d_output");

    let init_data: Vec<f64> = (0..batch * n).map(|i| (i as f64) * 0.5).collect();
    let d_seed = [1.0_f64];

    let (sched, mut arena) = prepare_f64(&bg, &[(init_b, &init_data), (d_out_b, &d_seed)]);
    execute_thunks(&sched, arena.raw_buf_mut());
    // Output slot 1 is read as the gradient w.r.t. `init`.
    let dinit_b = read_arena_f64(&arena, bg.outputs[1], batch * n);

    for (i, &value) in dinit_b[..batch * n].iter().enumerate() {
        assert!(
            (value - 1.0).abs() < 1e-12,
            "dinit[{i}] = {} (expected 1.0)",
            value
        );
    }

    // Cross-check: rebuild the un-batched gradient graph and run it once per batch row.
    for (bi, row) in init_data.chunks_exact(n).enumerate() {
        let mut g2 = Graph::new("per_row_grad");
        let init2 = g2.input("init", vec_shape());
        let body2 = add_one_body("per_row_body");
        let final2 = g2.scan(init2, body2, length);
        let loss2 = g2.reduce(final2, ReduceOp::Sum, vec![0], false, scalar_shape());
        g2.set_outputs(vec![loss2]);

        let bwd2 = grad_with_loss(&g2, &[init2]);
        let init2_id = find(&bwd2, "init");
        let d_out2_id = find(&bwd2, "d_output");
        let (s2, mut a2) = prepare_f64(&bwd2, &[(init2_id, row), (d_out2_id, &d_seed)]);
        execute_thunks(&s2, a2.raw_buf_mut());
        let row_dinit = read_arena_f64(&a2, bwd2.outputs[1], n);

        let batched_row = &dinit_b[bi * n..(bi + 1) * n];
        for (j, (&got, &want)) in batched_row.iter().zip(&row_dinit[..n]).enumerate() {
            assert!(
                (got - want).abs() < 1e-12,
                "row {bi}, j {j}: vmap'd={got} per-row={want}"
            );
        }
    }
}
16457
#[test]
/// Batches a scan whose body is `carry + x_t` over both `init` and `xs` with
/// `vmap`, then checks each batch row's final carry against a plain Rust
/// running sum over that row's slice of `xs` (tolerance 1e-12).
fn vmap_scan_cumulative_sum_matches_scalar_runs() {
    use rlx_opt::vmap::vmap;

    /// Id of the `Input` node called `want`; panics when none exists.
    fn find(graph: &Graph, want: &str) -> NodeId {
        graph
            .nodes()
            .into_iter()
            .find_map(|node| match &node.op {
                Op::Input { name } if name == want => Some(node.id),
                _ => None,
            })
            .unwrap_or_else(|| panic!("no input {want}"))
    }

    let n = 2usize;
    let length = 4u32;
    let batch = 3usize;
    let steps = length as usize;
    let vec_shape = || Shape::new(&[n], DType::F64);

    let mut body = Graph::new("scan_body_cumsum");
    let carry = body.input("carry", vec_shape());
    let x_t = body.input("x_t", vec_shape());
    let next = body.binary(BinaryOp::Add, carry, x_t, vec_shape());
    body.set_outputs(vec![next]);

    let mut g = Graph::new("scan_outer_cumsum");
    let init = g.input("init", vec_shape());
    let xs = g.input("xs", Shape::new(&[steps, n], DType::F64));
    let final_carry = g.scan_with_xs(init, &[xs], body, length);
    g.set_outputs(vec![final_carry]);

    let bg = vmap(&g, &["init", "xs"], batch);

    // init = 1, 2, 3, ...; xs = 0.0, 0.1, 0.2, ... laid out as [batch][step][n].
    let init_data: Vec<f64> = (1..=batch * n).map(|i| i as f64).collect();
    let xs_data: Vec<f64> = (0..batch * steps * n).map(|i| 0.1 * (i as f64)).collect();

    let init_b = find(&bg, "init");
    let xs_b = find(&bg, "xs");
    let (sched, mut arena) = prepare_f64(&bg, &[(init_b, &init_data), (xs_b, &xs_data)]);
    execute_thunks(&sched, arena.raw_buf_mut());
    let batched_out = read_arena_f64(&arena, bg.outputs[0], batch * n);

    let per_row = init_data
        .chunks_exact(n)
        .zip(xs_data.chunks_exact(steps * n));
    for (bi, (row_init, row_xs)) in per_row.enumerate() {
        // Reference: accumulate the row's xs step by step, in time order.
        let mut x = row_init.to_vec();
        for step in row_xs.chunks_exact(n) {
            for (acc, delta) in x.iter_mut().zip(step) {
                *acc += *delta;
            }
        }

        for (i, &expected) in x.iter().enumerate() {
            let got = batched_out[bi * n + i];
            assert!(
                (got - expected).abs() < 1e-12,
                "row {bi}, i {i}: got {got} ref {}",
                expected
            );
        }
    }
}
16530
#[test]
/// Batches a single `dense_solve(A, b)` over both inputs with `vmap` and
/// compares every batch row against `crate::blas::dgesv` run on that row's
/// own `A` and `b` (tolerance 1e-12).
fn vmap_dense_solve_matches_scalar_runs() {
    use rlx_opt::vmap::vmap;

    /// Id of the `Input` node called `want`; panics when none exists.
    fn find(graph: &Graph, want: &str) -> NodeId {
        graph
            .nodes()
            .into_iter()
            .find_map(|node| match &node.op {
                Op::Input { name } if name == want => Some(node.id),
                _ => None,
            })
            .unwrap_or_else(|| panic!("no input named {want}"))
    }

    let n = 3usize;
    let batch = 4usize;

    let mut g = Graph::new("solve_forward");
    let a = g.input("A", Shape::new(&[n, n], DType::F64));
    let b = g.input("b", Shape::new(&[n], DType::F64));
    let x = g.dense_solve(a, b, Shape::new(&[n], DType::F64));
    g.set_outputs(vec![x]);

    let bg = vmap(&g, &["A", "b"], batch);

    let mut rng = rlx_ir::Philox4x32::new(0xb47c_u64);
    let mut a_data = vec![0.0_f64; batch * n * n];
    let mut b_data = vec![0.0_f64; batch * n];
    let systems = a_data
        .chunks_exact_mut(n * n)
        .zip(b_data.chunks_exact_mut(n));
    for (a_mat, b_vec) in systems {
        for (i, a_row) in a_mat.chunks_exact_mut(n).enumerate() {
            for entry in a_row.iter_mut() {
                *entry = rng.next_f32() as f64 * 0.1;
            }
            // Add `1 + n` to the diagonal on top of the scaled random entries.
            a_row[i] += 1.0 + n as f64;
        }
        for entry in b_vec.iter_mut() {
            *entry = rng.next_f32() as f64;
        }
    }

    let ba = find(&bg, "A");
    let bb = find(&bg, "b");
    let (sched, mut arena) = prepare_f64(&bg, &[(ba, &a_data), (bb, &b_data)]);
    execute_thunks(&sched, arena.raw_buf_mut());
    let batched_x = read_arena_f64(&arena, bg.outputs[0], batch * n);

    let references = a_data.chunks_exact(n * n).zip(b_data.chunks_exact(n));
    for (bi, (a_mat, b_vec)) in references.enumerate() {
        // `dgesv` is used as an in-place solver: the rhs buffer is read back as the solution.
        let mut a_slice = a_mat.to_vec();
        let mut b_slice = b_vec.to_vec();
        crate::blas::dgesv(&mut a_slice, &mut b_slice, n, 1);
        for (i, &want) in b_slice.iter().enumerate() {
            let got = batched_x[bi * n + i];
            assert!(
                (got - want).abs() < 1e-12,
                "row {bi}, i {i}: got {got} want {want}"
            );
        }
    }
}
16598
#[test]
/// End-to-end `vmap` check on the chain reshape -> matmul -> reshape -> add ->
/// reduce-sum. Only `x` is batched; `w` and `b` are fed un-batched. Each
/// batched output is compared with a freshly built un-batched graph fed that
/// row of `x` (tolerance 1e-12).
fn vmap_matmul_add_reduce_matches_scalar_runs() {
    use rlx_opt::vmap::vmap;

    /// Id of the `Input` node called `want`; panics when none exists.
    fn find(graph: &Graph, want: &str) -> NodeId {
        graph
            .nodes()
            .into_iter()
            .find_map(|node| match &node.op {
                Op::Input { name } if name == want => Some(node.id),
                _ => None,
            })
            .unwrap_or_else(|| panic!("no input named {want}"))
    }

    let n = 3usize;
    let batch = 4usize;

    // Builds `sum(reshape(reshape(x, [1, n]) @ w, [n]) + b)` and hands back the
    // graph together with the ids of x, w, b and the loss node.
    let build_forward = |label: &'static str| {
        let mut graph = Graph::new(label);
        let x = graph.input("x", Shape::new(&[n], DType::F64));
        let w = graph.input("w", Shape::new(&[n, n], DType::F64));
        let b = graph.input("b", Shape::new(&[n], DType::F64));
        let x_row = graph.add_node(
            Op::Reshape {
                new_shape: vec![1, n as i64],
            },
            vec![x],
            Shape::new(&[1, n], DType::F64),
        );
        let mm = graph.matmul(x_row, w, Shape::new(&[1, n], DType::F64));
        let mm_flat = graph.add_node(
            Op::Reshape {
                new_shape: vec![n as i64],
            },
            vec![mm],
            Shape::new(&[n], DType::F64),
        );
        let yv = graph.binary(BinaryOp::Add, mm_flat, b, Shape::new(&[n], DType::F64));
        let loss = graph.reduce(
            yv,
            ReduceOp::Sum,
            vec![0],
            false,
            Shape::new(&[1], DType::F64),
        );
        graph.set_outputs(vec![loss]);
        (graph, x, w, b, loss)
    };

    let (g, ..) = build_forward("vmap_e2e_forward");
    let bg = vmap(&g, &["x"], batch);

    // Draw order matters for reproducibility: w first, then b, then the batched x.
    let mut rng = rlx_ir::Philox4x32::new(0xc1c0_u64);
    let mut draw = |count: usize| -> Vec<f64> {
        (0..count).map(|_| rng.next_f32() as f64).collect()
    };
    let w_data = draw(n * n);
    let b_data = draw(n);
    let x_data_batched = draw(batch * n);

    let bx = find(&bg, "x");
    let bw = find(&bg, "w");
    let bb = find(&bg, "b");
    let (sched, mut arena) =
        prepare_f64(&bg, &[(bx, &x_data_batched), (bw, &w_data), (bb, &b_data)]);
    execute_thunks(&sched, arena.raw_buf_mut());
    let batched_out = read_arena_f64(&arena, bg.outputs[0], batch);

    for (bi, xs_slice) in x_data_batched.chunks_exact(n).enumerate() {
        let (g2, x2, w2, b2, l2) = build_forward("scalar_run");
        let (s2, mut a2) = prepare_f64(&g2, &[(x2, xs_slice), (w2, &w_data), (b2, &b_data)]);
        execute_thunks(&s2, a2.raw_buf_mut());
        let scalar_out = read_arena_f64(&a2, l2, 1);
        assert!(
            (batched_out[bi] - scalar_out[0]).abs() < 1e-12,
            "row {bi}: batched={} scalar={}",
            batched_out[bi],
            scalar_out[0]
        );
    }
}
16720
#[test]
/// Gradient check for a scan with a per-step input: the body computes
/// `dense_solve(M, carry + drive)` with a constant tridiagonal `M`
/// (`1 + 2*dt` on the diagonal, `-dt` next to it) and the loss is the sum of
/// the final carry. Both `dinit` and `dxs` from `grad_with_loss` are compared
/// against central finite differences of a plain Rust re-implementation
/// (step 1e-6, tolerance 1e-7).
fn scan_with_xs_dxs_matches_fd() {
    use rlx_opt::autodiff::grad_with_loss;

    /// Id of the `Input`/`Param` node called `want`; panics when none exists.
    fn find_by_name(graph: &Graph, want: &str) -> NodeId {
        graph
            .nodes()
            .into_iter()
            .find_map(|node| match &node.op {
                Op::Input { name } | Op::Param { name } if name.as_str() == want => Some(node.id),
                _ => None,
            })
            .unwrap_or_else(|| panic!("no node named {want:?}"))
    }

    let n = 3usize;
    let length = 3u32;
    let dt = 0.1_f64;
    let steps = length as usize;
    let vec_shape = || Shape::new(&[n], DType::F64);

    let mut m_data = vec![0.0_f64; n * n];
    for (i, row) in m_data.chunks_exact_mut(n).enumerate() {
        row[i] = 1.0 + dt * 2.0;
        if i > 0 {
            row[i - 1] = -dt;
        }
        if i + 1 < n {
            row[i + 1] = -dt;
        }
    }
    let mut m_bytes: Vec<u8> = Vec::with_capacity(m_data.len() * 8);
    for value in &m_data {
        m_bytes.extend_from_slice(&value.to_le_bytes());
    }

    let mut body = Graph::new("be_dxs_body");
    let carry = body.input("carry", vec_shape());
    let drive = body.input("drive", vec_shape());
    let m = body.add_node(
        Op::Constant { data: m_bytes },
        vec![],
        Shape::new(&[n, n], DType::F64),
    );
    let driven = body.binary(BinaryOp::Add, carry, drive, vec_shape());
    let next = body.dense_solve(m, driven, vec_shape());
    body.set_outputs(vec![next]);

    let mut g = Graph::new("be_dxs_outer");
    let init = g.input("init", vec_shape());
    let xs = g.input("xs", Shape::new(&[steps, n], DType::F64));
    let final_carry = g.scan_with_xs(init, &[xs], body, length);
    let loss = g.reduce(
        final_carry,
        ReduceOp::Sum,
        vec![0],
        false,
        Shape::new(&[1], DType::F64),
    );
    g.set_outputs(vec![loss]);

    let bwd = grad_with_loss(&g, &[init, xs]);
    assert_eq!(bwd.outputs.len(), 3, "[loss, dinit, dxs]");

    let init_bwd = find_by_name(&bwd, "init");
    let xs_bwd = find_by_name(&bwd, "xs");
    let d_out_bwd = find_by_name(&bwd, "d_output");

    let init_data = vec![0.5_f64, 0.0, -0.5];
    let xs_data: Vec<f64> = (0..steps * n)
        .map(|i| 0.1_f64 * ((i as f64) - 4.0))
        .collect();
    let d_seed = [1.0_f64];

    let (sched, mut arena) = prepare_f64(
        &bwd,
        &[
            (init_bwd, &init_data),
            (xs_bwd, &xs_data),
            (d_out_bwd, &d_seed),
        ],
    );
    execute_thunks(&sched, arena.raw_buf_mut());
    let dinit = read_arena_f64(&arena, bwd.outputs[1], n);
    let dxs = read_arena_f64(&arena, bwd.outputs[2], steps * n);

    let h = 1e-6;
    // Reference loss: replay the scan in plain Rust, solving with `dgesv` each step.
    let loss_at = |x0: &[f64], xs_in: &[f64]| -> f64 {
        let mut state = x0.to_vec();
        for drive_t in xs_in[..steps * n].chunks_exact(n) {
            for (s, d) in state.iter_mut().zip(drive_t) {
                *s += *d;
            }
            let mut lhs = m_data.clone();
            crate::blas::dgesv(&mut lhs, &mut state, n, 1);
        }
        state.iter().sum()
    };
    // Copy of `base` with entry `idx` shifted by `delta`.
    let nudged = |base: &[f64], idx: usize, delta: f64| -> Vec<f64> {
        let mut shifted = base.to_vec();
        shifted[idx] += delta;
        shifted
    };

    for i in 0..n {
        let ip = nudged(&init_data, i, h);
        let im = nudged(&init_data, i, -h);
        let fd = (loss_at(&ip, &xs_data) - loss_at(&im, &xs_data)) / (2.0 * h);
        assert!(
            (dinit[i] - fd).abs() < 1e-7,
            "FD dinit[{i}]: AD={} FD={}",
            dinit[i],
            fd
        );
    }

    for idx in 0..steps * n {
        let (t, j) = (idx / n, idx % n);
        let xp = nudged(&xs_data, idx, h);
        let xm = nudged(&xs_data, idx, -h);
        let fd = (loss_at(&init_data, &xp) - loss_at(&init_data, &xm)) / (2.0 * h);
        assert!(
            (dxs[idx] - fd).abs() < 1e-7,
            "FD dxs[t={t},j={j}]: AD={} FD={}",
            dxs[idx],
            fd
        );
    }
}
16855
#[test]
/// Same driven-solve scan as the `dxs` test (body `dense_solve(M, carry +
/// drive)`, loss = sum of the final carry), but only `init` is passed to
/// `grad_with_loss`. `xs` is still fed as data, and `dinit` is compared with
/// central finite differences (step 1e-6, tolerance 1e-7).
fn scan_with_xs_gradient_dinit_matches_fd() {
    use rlx_opt::autodiff::grad_with_loss;

    /// Id of the `Input`/`Param` node called `want`; panics when none exists.
    fn find_by_name(graph: &Graph, want: &str) -> NodeId {
        graph
            .nodes()
            .into_iter()
            .find_map(|node| match &node.op {
                Op::Input { name } | Op::Param { name } if name.as_str() == want => Some(node.id),
                _ => None,
            })
            .unwrap_or_else(|| panic!("no node named {want:?}"))
    }

    let n = 3usize;
    let length = 3u32;
    let dt = 0.1_f64;
    let steps = length as usize;
    let vec_shape = || Shape::new(&[n], DType::F64);

    // Tridiagonal system matrix: 1 + 2*dt on the diagonal, -dt beside it.
    let mut m_data = vec![0.0_f64; n * n];
    for (i, row) in m_data.chunks_exact_mut(n).enumerate() {
        row[i] = 1.0 + dt * 2.0;
        if i > 0 {
            row[i - 1] = -dt;
        }
        if i + 1 < n {
            row[i + 1] = -dt;
        }
    }
    let mut m_bytes: Vec<u8> = Vec::with_capacity(m_data.len() * 8);
    for value in &m_data {
        m_bytes.extend_from_slice(&value.to_le_bytes());
    }

    let mut body = Graph::new("be_xs_grad_body");
    let carry = body.input("carry", vec_shape());
    let drive = body.input("drive", vec_shape());
    let m = body.add_node(
        Op::Constant { data: m_bytes },
        vec![],
        Shape::new(&[n, n], DType::F64),
    );
    let driven = body.binary(BinaryOp::Add, carry, drive, vec_shape());
    let next = body.dense_solve(m, driven, vec_shape());
    body.set_outputs(vec![next]);

    let mut g = Graph::new("be_xs_grad_outer");
    let init = g.input("init", vec_shape());
    let xs = g.input("xs", Shape::new(&[steps, n], DType::F64));
    let final_carry = g.scan_with_xs(init, &[xs], body, length);
    let loss = g.reduce(
        final_carry,
        ReduceOp::Sum,
        vec![0],
        false,
        Shape::new(&[1], DType::F64),
    );
    g.set_outputs(vec![loss]);

    let bwd = grad_with_loss(&g, &[init]);

    let init_bwd = find_by_name(&bwd, "init");
    let xs_bwd = find_by_name(&bwd, "xs");
    let d_out_bwd = find_by_name(&bwd, "d_output");

    let init_data = vec![0.5_f64, 0.0, -0.5];
    let xs_data: Vec<f64> = (0..steps * n)
        .map(|i| 0.1_f64 * ((i as f64) - 4.0))
        .collect();
    let d_seed = [1.0_f64];

    let (sched, mut arena) = prepare_f64(
        &bwd,
        &[
            (init_bwd, &init_data),
            (xs_bwd, &xs_data),
            (d_out_bwd, &d_seed),
        ],
    );
    execute_thunks(&sched, arena.raw_buf_mut());
    let dinit = read_arena_f64(&arena, bwd.outputs[1], n);

    let h = 1e-6;
    // Reference loss with the fixed `xs_data`: add the drive, solve, repeat.
    let loss_at = |x0: &[f64]| -> f64 {
        let mut state = x0.to_vec();
        for drive_t in xs_data[..steps * n].chunks_exact(n) {
            for (s, d) in state.iter_mut().zip(drive_t) {
                *s += *d;
            }
            let mut lhs = m_data.clone();
            crate::blas::dgesv(&mut lhs, &mut state, n, 1);
        }
        state.iter().sum()
    };
    // Copy of `init_data` with entry `idx` shifted by `delta`.
    let nudged = |idx: usize, delta: f64| -> Vec<f64> {
        let mut shifted = init_data.to_vec();
        shifted[idx] += delta;
        shifted
    };

    for i in 0..n {
        let ip = nudged(i, h);
        let im = nudged(i, -h);
        let fd = (loss_at(&ip) - loss_at(&im)) / (2.0 * h);
        assert!(
            (dinit[i] - fd).abs() < 1e-7,
            "FD dinit[{i}]: AD={} FD={}",
            dinit[i],
            fd
        );
    }
}
16969
#[test]
/// Scan body multiplies the carry by the constant 1.1, so after `length`
/// steps the gradient of `sum(final)` w.r.t. each `init` entry is
/// `1.1^length`. Checks every entry of `dinit` against that closed form
/// (1e-12) and entry 0 against a central finite difference (1e-7).
fn scan_gradient_geometric_matches_closed_form() {
    use rlx_opt::autodiff::grad_with_loss;

    /// Id of the `Input`/`Param` node called `want`; panics when none exists.
    fn find_by_name(graph: &Graph, want: &str) -> NodeId {
        graph
            .nodes()
            .into_iter()
            .find_map(|node| match &node.op {
                Op::Input { name } | Op::Param { name } if name.as_str() == want => Some(node.id),
                _ => None,
            })
            .unwrap_or_else(|| panic!("no node named {want:?}"))
    }

    let n = 3usize;
    let length = 5u32;
    let vec_shape = || Shape::new(&[n], DType::F64);

    let mut body = Graph::new("scan_grad_body");
    let x = body.input("carry", vec_shape());
    let scale = body.add_node(
        Op::Constant {
            data: 1.1_f64.to_le_bytes().repeat(n),
        },
        vec![],
        vec_shape(),
    );
    let next = body.binary(BinaryOp::Mul, x, scale, vec_shape());
    body.set_outputs(vec![next]);

    let mut g = Graph::new("scan_grad_outer");
    let init = g.input("init", vec_shape());
    let final_x = g.scan(init, body, length);
    let loss = g.reduce(
        final_x,
        ReduceOp::Sum,
        vec![0],
        false,
        Shape::new(&[1], DType::F64),
    );
    g.set_outputs(vec![loss]);

    let bwd = grad_with_loss(&g, &[init]);
    assert_eq!(bwd.outputs.len(), 2);

    let init_bwd = find_by_name(&bwd, "init");
    let d_out_bwd = find_by_name(&bwd, "d_output");

    let init_data = vec![1.0_f64; n];
    let d_seed = [1.0_f64];
    let (sched, mut arena) = prepare_f64(&bwd, &[(init_bwd, &init_data), (d_out_bwd, &d_seed)]);
    execute_thunks(&sched, arena.raw_buf_mut());
    let dinit = read_arena_f64(&arena, bwd.outputs[1], n);

    let want = 1.1_f64.powi(length as i32);
    for (i, &grad) in dinit[..n].iter().enumerate() {
        assert!(
            (grad - want).abs() < 1e-12,
            "dinit[{i}] = {} want {}",
            grad,
            want
        );
    }

    let h = 1e-6;
    // Reference loss: scale every entry by 1.1 `length` times, then sum.
    let loss_at = |x: &[f64]| -> f64 {
        x.iter()
            .map(|&start| (0..length).fold(start, |value, _| value * 1.1))
            .sum()
    };
    let mut ip = init_data.clone();
    ip[0] += h;
    let mut im = init_data.clone();
    im[0] -= h;
    let fd = (loss_at(&ip) - loss_at(&im)) / (2.0 * h);
    assert!(
        (dinit[0] - fd).abs() < 1e-7,
        "FD dinit[0]: AD={} FD={}",
        dinit[0],
        fd
    );
}
17063
#[test]
/// Gradient check for a scan whose body is `dense_solve(M, x)` with a
/// constant tridiagonal `M` (`1 + 2*dt` diagonal, `-dt` off-diagonals).
/// `dinit` for loss = sum of the final state is compared with central finite
/// differences of a `dgesv`-based replay (step 1e-6, tolerance 1e-7).
fn scan_gradient_backward_euler_matches_fd() {
    use rlx_opt::autodiff::grad_with_loss;

    /// Id of the `Input`/`Param` node called `want`; panics when none exists.
    fn find_by_name(graph: &Graph, want: &str) -> NodeId {
        graph
            .nodes()
            .into_iter()
            .find_map(|node| match &node.op {
                Op::Input { name } | Op::Param { name } if name.as_str() == want => Some(node.id),
                _ => None,
            })
            .unwrap_or_else(|| panic!("no node named {want:?}"))
    }

    let n = 4usize;
    let length = 3u32;
    let dt = 0.05_f64;
    let vec_shape = || Shape::new(&[n], DType::F64);

    let mut m_data = vec![0.0_f64; n * n];
    for (i, row) in m_data.chunks_exact_mut(n).enumerate() {
        row[i] = 1.0 + dt * 2.0;
        if i > 0 {
            row[i - 1] = -dt;
        }
        if i + 1 < n {
            row[i + 1] = -dt;
        }
    }
    let mut m_bytes: Vec<u8> = Vec::with_capacity(m_data.len() * 8);
    for value in &m_data {
        m_bytes.extend_from_slice(&value.to_le_bytes());
    }

    let mut body = Graph::new("be_grad_body");
    let x = body.input("x", vec_shape());
    let m = body.add_node(
        Op::Constant { data: m_bytes },
        vec![],
        Shape::new(&[n, n], DType::F64),
    );
    let next = body.dense_solve(m, x, vec_shape());
    body.set_outputs(vec![next]);

    let mut g = Graph::new("be_grad_outer");
    let init = g.input("x0", vec_shape());
    let final_x = g.scan(init, body, length);
    let loss = g.reduce(
        final_x,
        ReduceOp::Sum,
        vec![0],
        false,
        Shape::new(&[1], DType::F64),
    );
    g.set_outputs(vec![loss]);

    let bwd = grad_with_loss(&g, &[init]);

    let init_bwd = find_by_name(&bwd, "x0");
    let d_out_bwd = find_by_name(&bwd, "d_output");

    let init_data: [f64; 4] = [0.0, 1.0, 0.0, 0.0];
    let d_seed = [1.0_f64];
    let (sched, mut arena) = prepare_f64(&bwd, &[(init_bwd, &init_data), (d_out_bwd, &d_seed)]);
    execute_thunks(&sched, arena.raw_buf_mut());
    let dinit = read_arena_f64(&arena, bwd.outputs[1], n);

    let h = 1e-6;
    // Reference loss: apply the solve `length` times, then sum the state.
    let loss_at = |x0: &[f64]| -> f64 {
        let mut state = x0.to_vec();
        (0..length).for_each(|_| {
            let mut lhs = m_data.clone();
            crate::blas::dgesv(&mut lhs, &mut state, n, 1);
        });
        state.iter().sum()
    };
    // Copy of `init_data` with entry `idx` shifted by `delta`.
    let nudged = |idx: usize, delta: f64| -> Vec<f64> {
        let mut shifted = init_data.to_vec();
        shifted[idx] += delta;
        shifted
    };

    for i in 0..n {
        let ip = nudged(i, h);
        let im = nudged(i, -h);
        let fd = (loss_at(&ip) - loss_at(&im)) / (2.0 * h);
        assert!(
            (dinit[i] - fd).abs() < 1e-7,
            "FD dinit[{i}]: AD={} FD={}",
            dinit[i],
            fd
        );
    }
}
17153
#[test]
/// Runs `scan_trajectory` with a `dense_solve(M, x)` body and checks three
/// things about the recorded `[length, n]` output:
///   1. every row matches a `dgesv` replay of the same recurrence (1e-12);
///   2. the per-row sums never increase from one row to the next (slack 1e-15);
///   3. the last row equals the result of a plain `scan` over an identically
///      built body (1e-15).
fn scan_trajectory_backward_euler_records_waveform() {
    let n = 4usize;
    let length = 5u32;
    let dt = 0.05_f64;
    let steps = length as usize;
    let vec_shape = || Shape::new(&[n], DType::F64);

    // Tridiagonal system matrix: 1 + 2*dt on the diagonal, -dt beside it.
    let mut m_data = vec![0.0_f64; n * n];
    for (i, row) in m_data.chunks_exact_mut(n).enumerate() {
        row[i] = 1.0 + dt * 2.0;
        if i > 0 {
            row[i - 1] = -dt;
        }
        if i + 1 < n {
            row[i + 1] = -dt;
        }
    }

    // Body `x -> dense_solve(M, x)`; built twice below under different labels.
    let solve_body = |label: &'static str| {
        let mut m_bytes: Vec<u8> = Vec::with_capacity(m_data.len() * 8);
        for value in &m_data {
            m_bytes.extend_from_slice(&value.to_le_bytes());
        }
        let mut body = Graph::new(label);
        let x = body.input("x", vec_shape());
        let m = body.add_node(
            Op::Constant { data: m_bytes },
            vec![],
            Shape::new(&[n, n], DType::F64),
        );
        let next = body.dense_solve(m, x, vec_shape());
        body.set_outputs(vec![next]);
        body
    };

    let body = solve_body("be_traj_body");
    let mut g = Graph::new("be_traj_outer");
    let init = g.input("x0", vec_shape());
    let traj = g.scan_trajectory(init, body, length);
    g.set_outputs(vec![traj]);

    let init_data: [f64; 4] = [0.0, 1.0, 0.0, 0.0];
    let (sched, mut arena) = prepare_f64(&g, &[(init, &init_data)]);
    execute_thunks(&sched, arena.raw_buf_mut());
    let got = read_arena_f64(&arena, traj, steps * n);

    // Reference trajectory: record the state after each in-place solve.
    let mut want = Vec::<f64>::with_capacity(steps * n);
    let mut x_ref = init_data.to_vec();
    for _ in 0..length {
        let mut lhs = m_data.clone();
        crate::blas::dgesv(&mut lhs, &mut x_ref, n, 1);
        want.extend_from_slice(&x_ref);
    }
    for (i, (&actual, &expected)) in got[..steps * n].iter().zip(&want[..steps * n]).enumerate() {
        assert!(
            (actual - expected).abs() < 1e-12,
            "got[{i}] = {} ref {}",
            actual,
            expected
        );
    }

    let row_sums: Vec<f64> = got[..steps * n]
        .chunks_exact(n)
        .map(|row| row.iter().sum())
        .collect();
    for (earlier, pair) in row_sums.windows(2).enumerate() {
        let (prev, curr) = (pair[0], pair[1]);
        let t = earlier + 1;
        assert!(
            curr <= prev + 1e-15,
            "mass should decay: row {} sum {prev}, row {t} sum {curr}",
            earlier
        );
    }

    // Same recurrence through the final-carry-only `scan`.
    let body2 = solve_body("be_final_body");
    let mut g2 = Graph::new("be_final_outer");
    let init2 = g2.input("x0", vec_shape());
    let final_x = g2.scan(init2, body2, length);
    g2.set_outputs(vec![final_x]);
    let (sched2, mut arena2) = prepare_f64(&g2, &[(init2, &init_data)]);
    execute_thunks(&sched2, arena2.raw_buf_mut());
    let final_got = read_arena_f64(&arena2, final_x, n);

    let last_row = &got[(steps - 1) * n..steps * n];
    for (i, (&from_traj, &from_scan)) in last_row.iter().zip(&final_got[..n]).enumerate() {
        assert!(
            (from_traj - from_scan).abs() < 1e-15,
            "last trajectory row[{i}] = {} vs final-scan = {}",
            from_traj,
            from_scan
        );
    }
}
17258
#[test]
/// Forward-only scan test: iterates `x <- dense_solve(M, x)` `length` times
/// from a unit spike, with tridiagonal `M` (`1 + 2*dt` diagonal, `-dt`
/// off-diagonals). Compares the result with a `dgesv` replay (1e-12) and
/// asserts that the sum of the final state lies strictly between 0 and 1.
fn scan_backward_euler_heat_f64() {
    let n = 4usize;
    let length = 5u32;
    let dt = 0.05_f64;
    let vec_shape = || Shape::new(&[n], DType::F64);

    let mut m_data = vec![0.0_f64; n * n];
    for (i, row) in m_data.chunks_exact_mut(n).enumerate() {
        row[i] = 1.0 + dt * 2.0;
        if i > 0 {
            row[i - 1] = -dt;
        }
        if i + 1 < n {
            row[i + 1] = -dt;
        }
    }
    let mut m_bytes: Vec<u8> = Vec::with_capacity(m_data.len() * 8);
    for value in &m_data {
        m_bytes.extend_from_slice(&value.to_le_bytes());
    }

    let mut body = Graph::new("be_body");
    let x = body.input("x", vec_shape());
    let m = body.add_node(
        Op::Constant { data: m_bytes },
        vec![],
        Shape::new(&[n, n], DType::F64),
    );
    let next = body.dense_solve(m, x, vec_shape());
    body.set_outputs(vec![next]);

    let mut g = Graph::new("be_outer");
    let init = g.input("x0", vec_shape());
    let final_x = g.scan(init, body, length);
    g.set_outputs(vec![final_x]);

    let init_data: [f64; 4] = [0.0, 1.0, 0.0, 0.0];
    let (sched, mut arena) = prepare_f64(&g, &[(init, &init_data)]);
    execute_thunks(&sched, arena.raw_buf_mut());
    let got = read_arena_f64(&arena, final_x, n);

    // Reference: the same recurrence driven directly through `dgesv`.
    let mut ref_x = init_data.to_vec();
    (0..length).for_each(|_| {
        let mut lhs = m_data.clone();
        crate::blas::dgesv(&mut lhs, &mut ref_x, n, 1);
    });
    for (i, (&actual, &expected)) in got[..n].iter().zip(&ref_x[..n]).enumerate() {
        assert!(
            (actual - expected).abs() < 1e-12,
            "got[{i}] = {} ref {}",
            actual,
            expected
        );
    }
    let mass: f64 = got.iter().sum();
    assert!(mass > 0.0 && mass < 1.0, "diffusion mass: {mass}");
}
17326
#[test]
/// Solves a 3x3 system with a `[3, 2]` right-hand side through the graph's
/// `dense_solve` and verifies the result by residual: for every column `c`
/// and row `i`, `sum_j A[i, j] * X[j, c]` must reproduce `B[i, c]` to 1e-10.
/// Both `B` and `X` are indexed row-major (`row * k + col`).
fn dense_solve_f64_multi_rhs_forward() {
    let n = 3usize;
    let k = 2usize;
    let rhs_shape = || Shape::new(&[n, k], DType::F64);

    let mut g = Graph::new("solve_multi_rhs");
    let a = g.input("A", Shape::new(&[n, n], DType::F64));
    let b = g.input("B", rhs_shape());
    let x = g.dense_solve(a, b, rhs_shape());
    g.set_outputs(vec![x]);

    let a_data = [2.0, 1.0, 0.0, 1.0, 3.0, 1.0, 0.0, 1.0, 2.0_f64];
    let b_data = [1.0, 4.0, 2.0, -1.0, 3.0, 2.0_f64];
    let (sched, mut arena) = prepare_f64(&g, &[(a, &a_data), (b, &b_data)]);
    execute_thunks(&sched, arena.raw_buf_mut());
    let x_got = read_arena_f64(&arena, x, n * k);

    for c in 0..k {
        for i in 0..n {
            // Row i of A dotted with column c of the computed X.
            let acc = (0..n).fold(0.0_f64, |partial, j| {
                partial + a_data[i * n + j] * x_got[j * k + c]
            });
            let want = b_data[i * k + c];
            assert!(
                (acc - want).abs() < 1e-10,
                "col {c} row {i}: got {acc} want {want}"
            );
        }
    }
}
17359
#[test]
/// Reverse-mode gradient of `sum(dense_solve(A, B))` with a `[3, 2]`
/// right-hand side. The reference values are built in the test itself:
///   * `dB_ref` = solution of `A^T * Y = ones`,
///   * `dA_ref[i, j] = -sum_c dB_ref[i, c] * X[j, c]` with `X` the forward solution.
/// Both are compared with the graph outputs to 1e-10, and `dB[0]` is
/// additionally checked against a central finite difference (1e-7).
fn dense_solve_f64_multi_rhs_gradient() {
    use rlx_opt::autodiff::grad_with_loss;

    /// Id of the `Input`/`Param` node called `want`; panics when none exists.
    fn find_by_name(graph: &Graph, want: &str) -> NodeId {
        graph
            .nodes()
            .into_iter()
            .find_map(|node| match &node.op {
                Op::Input { name } | Op::Param { name } if name.as_str() == want => Some(node.id),
                _ => None,
            })
            .unwrap_or_else(|| panic!("no node named {want:?}"))
    }

    let n = 3usize;
    let k = 2usize;
    let rhs_shape = || Shape::new(&[n, k], DType::F64);

    let mut g = Graph::new("solve_mrhs_grad");
    let a = g.param("A", Shape::new(&[n, n], DType::F64));
    let b = g.input("B", rhs_shape());
    let x = g.dense_solve(a, b, rhs_shape());
    let loss = g.reduce(
        x,
        ReduceOp::Sum,
        vec![0, 1],
        false,
        Shape::new(&[1], DType::F64),
    );
    g.set_outputs(vec![loss]);

    let bwd = grad_with_loss(&g, &[a, b]);
    let a_bwd = find_by_name(&bwd, "A");
    let b_bwd = find_by_name(&bwd, "B");
    let d_out = find_by_name(&bwd, "d_output");

    let a_data = [2.0, 1.0, 0.0, 1.0, 3.0, 1.0, 0.0, 1.0, 2.0_f64];
    let b_data = [1.0, 4.0, 2.0, -1.0, 3.0, 2.0_f64];
    let d_seed = [1.0_f64];

    let (sched, mut arena) = prepare_f64(
        &bwd,
        &[(a_bwd, &a_data), (b_bwd, &b_data), (d_out, &d_seed)],
    );
    execute_thunks(&sched, arena.raw_buf_mut());
    let da_got = read_arena_f64(&arena, bwd.outputs[1], n * n);
    let db_got = read_arena_f64(&arena, bwd.outputs[2], n * k);

    // Solves `A * X = rhs` on private copies; `dgesv` is used as an in-place
    // solver whose rhs buffer is read back as the solution.
    let solve_with_a = |mut rhs: [f64; 6]| -> [f64; 6] {
        let mut lhs = a_data;
        crate::blas::dgesv(&mut lhs, &mut rhs, n, k);
        rhs
    };

    let x_ref = solve_with_a(b_data);

    // Transpose of A, filled entry by entry in row-major order.
    let mut at = [0.0_f64; 9];
    for (flat, slot) in at.iter_mut().enumerate() {
        let (i, j) = (flat / n, flat % n);
        *slot = a_data[j * n + i];
    }
    let mut ones_nk = vec![1.0_f64; n * k];
    crate::blas::dgesv(&mut at, &mut ones_nk, n, k);
    let db_ref = ones_nk;

    let mut da_ref = [0.0_f64; 9];
    for (flat, slot) in da_ref.iter_mut().enumerate() {
        let (i, j) = (flat / n, flat % n);
        let acc = (0..k).fold(0.0_f64, |partial, c| {
            partial + db_ref[i * k + c] * x_ref[j * k + c]
        });
        *slot = -acc;
    }

    for (i, (&got, &expected)) in db_got[..n * k].iter().zip(&db_ref[..n * k]).enumerate() {
        assert!(
            (got - expected).abs() < 1e-10,
            "dB[{i}]: got {} want {}",
            got,
            expected
        );
    }
    for (i, (&got, &expected)) in da_got[..n * n].iter().zip(&da_ref[..n * n]).enumerate() {
        assert!(
            (got - expected).abs() < 1e-10,
            "dA[{i}]: got {} want {}",
            got,
            expected
        );
    }

    // Central finite difference on B[0] as an independent sanity check.
    let h = 1e-6;
    let mut bp = b_data;
    bp[0] += h;
    let mut bm = b_data;
    bm[0] -= h;
    let xp = solve_with_a(bp);
    let xm = solve_with_a(bm);
    let lp: f64 = xp.iter().sum();
    let lm: f64 = xm.iter().sum();
    let fd = (lp - lm) / (2.0 * h);
    assert!(
        (db_got[0] - fd).abs() < 1e-7,
        "FD dB[0,0]: AD={} FD={}",
        db_got[0],
        fd
    );
}
17477
#[test]
// Forward-mode (JVP) check for `dense_solve` with an `n x k` right-hand side,
// differentiating with respect to `B` only. The tangent read from the JVP
// graph's second output is compared with (1) a direct solve of `A` against the
// seeded tangent of `B` and (2) a central finite difference of the solve along
// that same direction.
fn dense_solve_f64_multi_rhs_jvp() {
    use rlx_opt::autodiff_fwd::jvp;
    let n = 3usize;
    let k = 2usize;
    let mut g = Graph::new("solve_mrhs_jvp");
    let a = g.input("A", Shape::new(&[n, n], DType::F64));
    let b = g.input("B", Shape::new(&[n, k], DType::F64));
    let x = g.dense_solve(a, b, Shape::new(&[n, k], DType::F64));
    g.set_outputs(vec![x]);

    let jg = jvp(&g, &[b]);
    // Returns the id of the first `Input`/`Param` node named `want`; panics
    // when the graph has no such node.
    let find_by_name = |graph: &Graph, want: &str| -> NodeId {
        graph
            .nodes()
            .iter()
            .find_map(|node| match &node.op {
                Op::Input { name } | Op::Param { name } if name == want => Some(node.id),
                _ => None,
            })
            .unwrap_or_else(|| panic!("no node named {want:?}"))
    };
    let a_id = find_by_name(&jg, "A");
    let b_id = find_by_name(&jg, "B");
    let tb_id = find_by_name(&jg, "tangent_B");

    let a_data = [2.0, 1.0, 0.0, 1.0, 3.0, 1.0, 0.0, 1.0, 2.0_f64];
    let b_data = [1.0, 4.0, 2.0, -1.0, 3.0, 2.0_f64];
    let tb_data = [0.5, 0.0, -0.25, 1.0, 1.0, -0.5_f64];

    let (sched, mut arena) =
        prepare_f64(&jg, &[(a_id, &a_data), (b_id, &b_data), (tb_id, &tb_data)]);
    execute_thunks(&sched, arena.raw_buf_mut());
    let tangent_x = read_arena_f64(&arena, jg.outputs[1], n * k);

    // Reference solve of `A * X = rhs` on scratch copies; the solution is read
    // back from the right-hand-side buffer. The status code is checked so a
    // failed solve cannot silently feed garbage into the comparisons below.
    let solve = |rhs: [f64; 6]| -> [f64; 6] {
        let mut lu = a_data;
        let mut sol = rhs;
        let info = crate::blas::dgesv(&mut lu, &mut sol, n, k);
        assert_eq!(info, 0, "dgesv returned nonzero info");
        sol
    };

    // Closed form: with `A` held fixed the tangent of `X` solves `A` against
    // the tangent of `B`.
    let t_ref = solve(tb_data);
    for i in 0..n * k {
        assert!(
            (tangent_x[i] - t_ref[i]).abs() < 1e-10,
            "t_X[{i}]: AD={} ref={}",
            tangent_x[i],
            t_ref[i]
        );
    }

    // Central finite difference along the same tangent direction.
    let h = 1e-6;
    let mut bp = b_data;
    let mut bm = b_data;
    for i in 0..n * k {
        bp[i] += h * tb_data[i];
        bm[i] -= h * tb_data[i];
    }
    let xp = solve(bp);
    let xm = solve(bm);
    for i in 0..n * k {
        let fd = (xp[i] - xm[i]) / (2.0 * h);
        assert!(
            (tangent_x[i] - fd).abs() < 1e-7,
            "FD t_X[{i}]: AD={} FD={}",
            tangent_x[i],
            fd
        );
    }
}
17555
#[test]
// End-to-end JVP of `dense_solve` with respect to the vector right-hand side
// `b`: builds the JVP graph, executes it, and checks the tangent output against
// a direct solve of the seeded tangent and a central finite difference. The
// primal output is also checked against a reference solve.
fn jvp_dense_solve_b_runs_and_matches_fd() {
    use rlx_opt::autodiff_fwd::jvp;
    let n = 3usize;

    let mut g = Graph::new("jvp_b_e2e");
    let a = g.input("A", Shape::new(&[n, n], DType::F64));
    let b = g.input("b", Shape::new(&[n], DType::F64));
    let x = g.dense_solve(a, b, Shape::new(&[n], DType::F64));
    g.set_outputs(vec![x]);

    let jg = jvp(&g, &[b]);
    // Returns the id of the first `Input`/`Param` node named `want`; panics
    // when the graph has no such node.
    let find_by_name = |graph: &Graph, want: &str| -> NodeId {
        graph
            .nodes()
            .iter()
            .find_map(|node| match &node.op {
                Op::Input { name } | Op::Param { name } if name == want => Some(node.id),
                _ => None,
            })
            .unwrap_or_else(|| panic!("no node named {want:?}"))
    };
    let a_id = find_by_name(&jg, "A");
    let b_id = find_by_name(&jg, "b");
    let tb_id = find_by_name(&jg, "tangent_b");

    let a_data: [f64; 9] = [2.0, 1.0, 0.0, 1.0, 3.0, 1.0, 0.0, 1.0, 2.0];
    let b_data: [f64; 3] = [1.0, 2.0, 3.0];
    let tb_data: [f64; 3] = [0.5, -0.25, 1.0];

    let (sched, mut arena) =
        prepare_f64(&jg, &[(a_id, &a_data), (b_id, &b_data), (tb_id, &tb_data)]);
    execute_thunks(&sched, arena.raw_buf_mut());

    let primal_x = read_arena_f64(&arena, jg.outputs[0], n);
    let tangent_x = read_arena_f64(&arena, jg.outputs[1], n);

    // Reference solve of `A * x = rhs` on scratch copies; the solution is read
    // back from the right-hand-side buffer. Every solve in this test goes
    // through here so the status code is checked uniformly.
    let solve = |rhs: [f64; 3]| -> [f64; 3] {
        let mut lu = a_data;
        let mut sol = rhs;
        let info = crate::blas::dgesv(&mut lu, &mut sol, n, 1);
        assert_eq!(info, 0, "dgesv returned nonzero info");
        sol
    };

    // Closed form: with `A` held fixed the tangent of `x` solves `A` against
    // the tangent of `b`.
    let t_x_ref = solve(tb_data);
    for i in 0..n {
        assert!(
            (tangent_x[i] - t_x_ref[i]).abs() < 1e-10,
            "t_x[{i}]: got {} want {}",
            tangent_x[i],
            t_x_ref[i]
        );
    }

    // Central finite difference along the seeded tangent direction.
    let h = 1e-6;
    let mut bp = b_data;
    let mut bm = b_data;
    for i in 0..n {
        bp[i] += h * tb_data[i];
        bm[i] -= h * tb_data[i];
    }
    let xp = solve(bp);
    let xm = solve(bm);
    let fd: Vec<f64> = (0..n).map(|i| (xp[i] - xm[i]) / (2.0 * h)).collect();
    for i in 0..n {
        assert!(
            (tangent_x[i] - fd[i]).abs() < 1e-7,
            "FD mismatch t_x[{i}]: AD={} FD={}",
            tangent_x[i],
            fd[i]
        );
    }

    // The JVP graph's first output must still be the plain solve.
    let primal_ref = solve(b_data);
    for i in 0..n {
        assert!(
            (primal_x[i] - primal_ref[i]).abs() < 1e-10,
            "primal x[{i}]: got {} want {}",
            primal_x[i],
            primal_ref[i]
        );
    }
}
17663
#[test]
// End-to-end JVP of `dense_solve` with respect to the matrix `A`: the tangent
// read from the JVP graph's second output is checked against the closed form
// `-solve(A, tangent_A * x)` and against a central finite difference obtained
// by perturbing `A` along `tangent_A`.
fn jvp_dense_solve_a_runs_and_matches_fd() {
    use rlx_opt::autodiff_fwd::jvp;
    let n = 3usize;

    let mut g = Graph::new("jvp_a_e2e");
    let a = g.input("A", Shape::new(&[n, n], DType::F64));
    let b = g.input("b", Shape::new(&[n], DType::F64));
    let x = g.dense_solve(a, b, Shape::new(&[n], DType::F64));
    g.set_outputs(vec![x]);

    let jg = jvp(&g, &[a]);
    // Returns the id of the first `Input`/`Param` node named `want`; panics
    // when the graph has no such node.
    let find_by_name = |graph: &Graph, want: &str| -> NodeId {
        graph
            .nodes()
            .iter()
            .find_map(|node| match &node.op {
                Op::Input { name } | Op::Param { name } if name == want => Some(node.id),
                _ => None,
            })
            .unwrap_or_else(|| panic!("no node named {want:?}"))
    };
    let a_id = find_by_name(&jg, "A");
    let b_id = find_by_name(&jg, "b");
    let ta_id = find_by_name(&jg, "tangent_A");

    let a_data: [f64; 9] = [2.0, 1.0, 0.0, 1.0, 3.0, 1.0, 0.0, 1.0, 2.0];
    let b_data: [f64; 3] = [1.0, 2.0, 3.0];
    let ta_data: [f64; 9] = [0.10, -0.05, 0.02, 0.03, 0.20, -0.04, -0.01, 0.07, 0.15];

    let (sched, mut arena) =
        prepare_f64(&jg, &[(a_id, &a_data), (b_id, &b_data), (ta_id, &ta_data)]);
    execute_thunks(&sched, arena.raw_buf_mut());

    let tangent_x = read_arena_f64(&arena, jg.outputs[1], n);

    // Reference solve of `mat * x = rhs` on scratch copies; the solution is
    // read back from the right-hand-side buffer. The status code is checked so
    // a failed solve cannot silently feed garbage into the comparisons below.
    let solve = |mat: [f64; 9], rhs: [f64; 3]| -> [f64; 3] {
        let mut lu = mat;
        let mut sol = rhs;
        let info = crate::blas::dgesv(&mut lu, &mut sol, n, 1);
        assert_eq!(info, 0, "dgesv returned nonzero info");
        sol
    };

    // Closed form: t_x = -solve(A, tangent_A * x), with `tangent_A` indexed
    // row-major.
    let x_ref = solve(a_data, b_data);
    let mut prod = [0.0_f64; 3];
    for i in 0..n {
        for j in 0..n {
            prod[i] += ta_data[i * n + j] * x_ref[j];
        }
    }
    let t_x_ref = solve(a_data, prod).map(|v| -v);
    for i in 0..n {
        assert!(
            (tangent_x[i] - t_x_ref[i]).abs() < 1e-10,
            "closed-form t_x[{i}]: AD={} ref={}",
            tangent_x[i],
            t_x_ref[i]
        );
    }

    // Central finite difference: perturb `A` along `tangent_A` in both
    // directions and re-solve against the unperturbed `b`.
    let h = 1e-6;
    let mut ap = a_data;
    let mut am = a_data;
    for i in 0..n * n {
        ap[i] += h * ta_data[i];
        am[i] -= h * ta_data[i];
    }
    let xp = solve(ap, b_data);
    let xm = solve(am, b_data);
    for i in 0..n {
        let fd = (xp[i] - xm[i]) / (2.0 * h);
        assert!(
            (tangent_x[i] - fd).abs() < 1e-7,
            "FD t_x[{i}]: AD={} FD={}",
            tangent_x[i],
            fd
        );
    }
}
17766
#[test]
// Executes the int8 `q_conv2d` op through the thunk executor and requires an
// exact element-wise match with a naive nested-loop reference that accumulates
// in `i32`, rescales by `mult`, rounds, and saturates to the `i8` range.
fn q_conv2d_matches_reference() {
    use rlx_ir::Philox4x32;

    // Problem geometry: batch, channels, spatial extent, kernel, padding, stride.
    let n = 1usize;
    let (c_in, c_out) = (2usize, 3usize);
    let (h, w_in) = (5usize, 5usize);
    let (kh, kw) = (3usize, 3usize);
    let (ph, pw) = (1usize, 1usize);
    let (sh, sw) = (1usize, 1usize);
    let h_out = (h + 2 * ph - kh) / sh + 1;
    let w_out = (w_in + 2 * pw - kw) / sw + 1;

    let x_scale = 0.04f32;
    let w_scale = 0.02f32;
    let out_scale = 0.5f32;
    let mult = x_scale * w_scale / out_scale;

    // Random float tensors, quantized to i8 with a symmetric scale.
    let mut rng = Philox4x32::new(2099);
    let mut xf = vec![0f32; n * c_in * h * w_in];
    rng.fill_normal(&mut xf);
    let mut wf = vec![0f32; c_out * c_in * kh * kw];
    rng.fill_normal(&mut wf);
    let to_i8 = |vals: &[f32], scale: f32| -> Vec<i8> {
        vals.iter()
            .map(|&v| ((v / scale).round() as i32).clamp(-128, 127) as i8)
            .collect()
    };
    let xq: Vec<i8> = to_i8(&xf, x_scale);
    let wq: Vec<i8> = to_i8(&wf, w_scale);
    let bias: Vec<i32> = vec![0i32; c_out];

    let mut g = Graph::new("qconv");
    let xn = g.input("x", Shape::new(&[n, c_in, h, w_in], DType::I8));
    let wn = g.input("w", Shape::new(&[c_out, c_in, kh, kw], DType::I8));
    let bn = g.input("b", Shape::new(&[c_out], DType::I32));
    // NOTE(review): the scalar arguments after the dilation vector are passed
    // through verbatim; the reference below assumes one group and all-zero
    // zero points. Confirm against the `q_conv2d` builder signature.
    let out = g.q_conv2d(
        xn,
        wn,
        bn,
        vec![kh, kw],
        vec![sh, sw],
        vec![ph, pw],
        vec![1, 1],
        1,
        0,
        0,
        0,
        mult,
        Shape::new(&[n, c_out, h_out, w_out], DType::I8),
    );
    g.set_outputs(vec![out]);

    let plan = rlx_opt::memory::plan_memory(&g);
    let mut arena = crate::arena::Arena::from_plan(plan);
    let sched = compile_thunks(&g, &arena);
    let xn_off = arena.byte_offset(xn);
    let wn_off = arena.byte_offset(wn);
    let bn_off = arena.byte_offset(bn);
    let out_off = arena.byte_offset(out);
    let buf = arena.raw_buf_mut();
    // SAFETY (assumed): each offset comes from the arena built for this graph;
    // this relies on the plan reserving enough suitably aligned bytes for every
    // node written here. TODO confirm against `plan_memory`.
    unsafe {
        let x_dst = buf.as_mut_ptr().add(xn_off) as *mut i8;
        for (i, &v) in xq.iter().enumerate() {
            x_dst.add(i).write(v);
        }
        let w_dst = buf.as_mut_ptr().add(wn_off) as *mut i8;
        for (i, &v) in wq.iter().enumerate() {
            w_dst.add(i).write(v);
        }
        let b_dst = buf.as_mut_ptr().add(bn_off) as *mut i32;
        for (i, &v) in bias.iter().enumerate() {
            b_dst.add(i).write(v);
        }
    }
    execute_thunks(&sched, arena.raw_buf_mut());
    let out_len = n * c_out * h_out * w_out;
    // SAFETY (assumed): same arena-layout assumption as for the writes above.
    let out_q: Vec<i8> = unsafe {
        let src = arena.raw_buf().as_ptr().add(out_off) as *const i8;
        (0..out_len).map(|i| src.add(i).read()).collect()
    };

    // Maps a coordinate in the padded frame back into the unpadded input, or
    // `None` when the tap falls in the padding border.
    let unpad = |pos: usize, pad: usize, extent: usize| -> Option<usize> {
        pos.checked_sub(pad).filter(|&p| p < extent)
    };

    // Naive reference, emitted in the same row-major order as the output.
    let mut out_ref: Vec<i8> = Vec::with_capacity(out_len);
    for ni in 0..n {
        for co in 0..c_out {
            for ho in 0..h_out {
                for wo in 0..w_out {
                    let mut acc: i32 = 0;
                    for ci in 0..c_in {
                        for ki in 0..kh {
                            for kj in 0..kw {
                                let row = unpad(ho * sh + ki, ph, h);
                                let col = unpad(wo * sw + kj, pw, w_in);
                                if let (Some(hi), Some(wi)) = (row, col) {
                                    let xv =
                                        xq[((ni * c_in) + ci) * h * w_in + hi * w_in + wi] as i32;
                                    let wv =
                                        wq[((co * c_in) + ci) * kh * kw + ki * kw + kj] as i32;
                                    acc += xv * wv;
                                }
                            }
                        }
                    }
                    let scaled = (acc as f32 * mult).round() as i32;
                    out_ref.push(scaled.clamp(-128, 127) as i8);
                }
            }
        }
    }

    for (i, (a, r)) in out_q.iter().zip(&out_ref).enumerate() {
        assert_eq!(a, r, "q_conv2d[{i}]: kernel {a} vs reference {r}");
    }
}
17900
#[test]
// Executes the int8 `q_matmul` op through the thunk executor and requires an
// exact element-wise match with a scalar reference: `i32` accumulation of the
// `i8` products, rescale by `mult`, round, then saturate to the `i8` range.
fn q_matmul_matches_fake_quant_reference() {
    use rlx_ir::Philox4x32;
    let m = 3usize;
    let k = 8usize;
    let n = 5usize;
    let mut rng = Philox4x32::new(2031);

    let x_scale = 0.05f32;
    let w_scale = 0.03f32;
    let out_scale = 0.4f32;
    let mult = x_scale * w_scale / out_scale;

    // Random float operands, quantized to i8 with a symmetric scale.
    let mut xf = vec![0f32; m * k];
    rng.fill_normal(&mut xf);
    let mut wf = vec![0f32; k * n];
    rng.fill_normal(&mut wf);
    let xq: Vec<i8> = xf
        .iter()
        .map(|&v| ((v / x_scale).round() as i32).clamp(-128, 127) as i8)
        .collect();
    let wq: Vec<i8> = wf
        .iter()
        .map(|&v| ((v / w_scale).round() as i32).clamp(-128, 127) as i8)
        .collect();
    let bias: Vec<i32> = vec![0i32; n];

    let mut g_q = Graph::new("qmm_direct");
    let xn = g_q.input("x", Shape::new(&[m, k], DType::I8));
    let wn = g_q.input("w", Shape::new(&[k, n], DType::I8));
    let bn = g_q.input("b", Shape::new(&[n], DType::I32));
    // NOTE(review): the three zero arguments are passed through verbatim; the
    // reference below assumes they are all-zero zero points. Confirm against
    // the `q_matmul` builder signature.
    let out = g_q.q_matmul(xn, wn, bn, 0, 0, 0, mult, Shape::new(&[m, n], DType::I8));
    g_q.set_outputs(vec![out]);
    let plan = rlx_opt::memory::plan_memory(&g_q);
    let mut arena = crate::arena::Arena::from_plan(plan);
    let sched = compile_thunks(&g_q, &arena);

    let xn_off = arena.byte_offset(xn);
    let wn_off = arena.byte_offset(wn);
    let bn_off = arena.byte_offset(bn);
    let out_off = arena.byte_offset(out);
    let buf = arena.raw_buf_mut();
    // SAFETY (assumed): each offset comes from the arena built for this graph;
    // this relies on the plan reserving enough suitably aligned bytes for every
    // node written here. TODO confirm against `plan_memory`.
    unsafe {
        let p = buf.as_mut_ptr().add(xn_off) as *mut i8;
        for (i, &v) in xq.iter().enumerate() {
            *p.add(i) = v;
        }
        let p = buf.as_mut_ptr().add(wn_off) as *mut i8;
        for (i, &v) in wq.iter().enumerate() {
            *p.add(i) = v;
        }
        let p = buf.as_mut_ptr().add(bn_off) as *mut i32;
        for (i, &v) in bias.iter().enumerate() {
            *p.add(i) = v;
        }
    }
    execute_thunks(&sched, arena.raw_buf_mut());
    // SAFETY (assumed): same arena-layout assumption as for the writes above.
    let out_q: Vec<i8> = unsafe {
        let p = arena.raw_buf().as_ptr().add(out_off) as *const i8;
        (0..m * n).map(|i| *p.add(i)).collect()
    };

    // Scalar reference over row-major `x` (m x k) and `w` (k x n).
    let mut out_ref = vec![0i8; m * n];
    for mi in 0..m {
        for ni in 0..n {
            let mut acc: i32 = 0;
            for ki in 0..k {
                acc += (xq[mi * k + ki] as i32) * (wq[ki * n + ni] as i32);
            }
            let r = (acc as f32 * mult).round() as i32;
            out_ref[mi * n + ni] = r.clamp(-128, 127) as i8;
        }
    }

    for (i, (a, r)) in out_q.iter().zip(&out_ref).enumerate() {
        assert_eq!(a, r, "q_matmul[{i}]: kernel {a} vs reference {r}");
    }
}
17993
#[test]
// Quantize -> dequantize round trip with a single scale / zero point. The first
// two inputs are forced far out of range and must come back as the saturated
// values; every other element must come back within `scale + 1e-5` of its
// input.
fn quantize_dequantize_round_trip() {
    use rlx_ir::Philox4x32;
    let len = 64;
    let mut rng = Philox4x32::new(2027);
    let mut x = vec![0f32; len];
    rng.fill_normal(&mut x);
    // Out-of-range probes for the positive and negative saturation limits.
    x[0] = 999.0;
    x[1] = -999.0;

    let scale = 0.05f32;
    let zp = 3i32;

    let f = DType::F32;
    let mut g = Graph::new("qdq");
    let xn = g.input("x", Shape::new(&[len], f));
    let q = g.quantize(xn, scale, zp);
    let dq = g.dequantize(q, scale, zp);
    g.set_outputs(vec![dq]);

    let plan = rlx_opt::memory::plan_memory(&g);
    let mut arena = crate::arena::Arena::from_plan(plan);
    let sched = compile_thunks(&g, &arena);
    let xn_off = arena.byte_offset(xn);
    let dq_off = arena.byte_offset(dq);
    let buf = arena.raw_buf_mut();
    // SAFETY (assumed): the offset comes from the arena built for this graph;
    // this relies on the plan reserving `len` aligned f32 slots for `xn`.
    // TODO confirm against `plan_memory`.
    unsafe {
        let dst = buf.as_mut_ptr().add(xn_off) as *mut f32;
        for (slot, &value) in x.iter().enumerate() {
            dst.add(slot).write(value);
        }
    }
    execute_thunks(&sched, arena.raw_buf_mut());
    // SAFETY (assumed): same arena-layout assumption, for the `dq` output.
    let out: Vec<f32> = unsafe {
        let src = arena.raw_buf().as_ptr().add(dq_off) as *const f32;
        (0..len).map(|slot| src.add(slot).read()).collect()
    };

    // Expected dequantized values at the two ends of the i8 range.
    let sat_pos = (127 - zp) as f32 * scale;
    let sat_neg = (-128 - zp) as f32 * scale;
    assert!((out[0] - sat_pos).abs() < 1e-6, "+sat: {}", out[0]);
    assert!((out[1] - sat_neg).abs() < 1e-6, "-sat: {}", out[1]);

    // `out` and `x` both hold exactly `len` elements, so the zip covers 2..len.
    for (i, (&got, &orig)) in out.iter().zip(&x).enumerate().skip(2) {
        assert!(
            (got - orig).abs() <= scale + 1e-5,
            "qdq[{i}]: {} → {}, scale={scale}",
            orig,
            got
        );
    }
}
18056
#[test]
// Per-channel quantize -> dequantize round trip along axis 0. Each channel has
// its own magnitude and a scale of `magnitude / 127`; the first three columns
// are in range and must round-trip within one scale step, while the last two
// are pushed far out of range and must saturate.
fn quantize_per_channel_round_trip() {
    let c = 4usize;
    let inner = 5usize;
    let mags = [0.01f32, 0.5, 5.0, 50.0];

    // Row-major (channel, column) input; the column index selects the probe.
    let x: Vec<f32> = (0..c * inner)
        .map(|flat| {
            let mag = mags[flat / inner];
            match flat % inner {
                0 => -mag,
                1 => 0.0,
                2 => mag,
                3 => mag * 1000.0,
                _ => -mag * 1000.0,
            }
        })
        .collect();
    let scales: Vec<f32> = mags.iter().map(|&m| m / 127.0).collect();
    let zps: Vec<i32> = vec![0; 4];

    let f = DType::F32;
    let mut g = Graph::new("qdq_pc");
    let xn = g.input("x", Shape::new(&[c, inner], f));
    let q = g.quantize_per_channel(xn, 0, scales.clone(), zps.clone());
    let dq = g.dequantize_per_channel(q, 0, scales.clone(), zps);
    g.set_outputs(vec![dq]);

    let plan = rlx_opt::memory::plan_memory(&g);
    let mut arena = crate::arena::Arena::from_plan(plan);
    let sched = compile_thunks(&g, &arena);
    let xn_off = arena.byte_offset(xn);
    let dq_off = arena.byte_offset(dq);
    let buf = arena.raw_buf_mut();
    // SAFETY (assumed): the offset comes from the arena built for this graph;
    // this relies on the plan reserving `c * inner` aligned f32 slots for `xn`.
    // TODO confirm against `plan_memory`.
    unsafe {
        let dst = buf.as_mut_ptr().add(xn_off) as *mut f32;
        for (slot, &value) in x.iter().enumerate() {
            dst.add(slot).write(value);
        }
    }
    execute_thunks(&sched, arena.raw_buf_mut());
    // SAFETY (assumed): same arena-layout assumption, for the `dq` output.
    let out: Vec<f32> = unsafe {
        let src = arena.raw_buf().as_ptr().add(dq_off) as *const f32;
        (0..c * inner).map(|slot| src.add(slot).read()).collect()
    };

    for ci in 0..c {
        let base = ci * inner;
        // In-range columns: error bounded by one quantization step.
        for ii in 0..3 {
            let idx = base + ii;
            assert!(
                (out[idx] - x[idx]).abs() <= scales[ci] + 1e-5,
                "ch {ci} idx {ii}: {} vs {}",
                x[idx],
                out[idx]
            );
        }
        // Out-of-range columns: expected values at the i8 limits (zero point 0).
        let sat_pos = 127.0 * scales[ci];
        let sat_neg = -128.0 * scales[ci];
        let hi = out[base + 3];
        let lo = out[base + 4];
        assert!((hi - sat_pos).abs() < 1e-5, "ch {ci} +sat: {}", hi);
        assert!((lo - sat_neg).abs() < 1e-5, "ch {ci} -sat: {}", lo);
    }
}
18139
#[test]
// For each listed activation kind, runs `activation_backward` through the thunk
// executor and compares the result element-wise with a central finite
// difference of a scalar reference forward function, scaled by the upstream
// gradient `dy`.
fn activation_backward_matches_numerical_per_kind() {
    use rlx_ir::Philox4x32;
    use rlx_ir::op::Activation;
    let mut rng = Philox4x32::new(91);
    let len = 32;

    // Inputs shifted to be at least 0.5, used for Log / Sqrt / Rsqrt.
    let mut x_pos = vec![0f32; len];
    rng.fill_normal(&mut x_pos);
    for v in x_pos.iter_mut() {
        *v = v.abs() + 0.5;
    }
    let mut x_any = vec![0f32; len];
    rng.fill_normal(&mut x_any);
    let mut dy = vec![0f32; len];
    rng.fill_normal(&mut dy);

    // Scalar reference forward pass; it captures nothing, so one instance is
    // shared by every case below.
    let act_apply = |kind: Activation, x: f32| -> f32 {
        match kind {
            Activation::Sigmoid => 1.0 / (1.0 + (-x).exp()),
            Activation::Tanh => x.tanh(),
            Activation::Silu => x / (1.0 + (-x).exp()),
            Activation::Gelu => {
                const INV_SQRT2: f32 = 0.707_106_77;
                0.5 * x * (1.0 + erf_f32(x * INV_SQRT2))
            }
            Activation::GeluApprox => {
                const C: f32 = 0.797_884_6;
                const A: f32 = 0.044_715;
                let inner = C * (x + A * x * x * x);
                0.5 * x * (1.0 + inner.tanh())
            }
            Activation::Exp => x.exp(),
            Activation::Log => x.ln(),
            Activation::Sqrt => x.sqrt(),
            Activation::Rsqrt => 1.0 / x.sqrt(),
            Activation::Neg => -x,
            Activation::Relu => x.max(0.0),
            Activation::Abs => x.abs(),
            Activation::Round => x.round(),
            Activation::Sin => x.sin(),
            Activation::Cos => x.cos(),
            Activation::Tan => x.tan(),
            Activation::Atan => x.atan(),
        }
    };

    // (kind, input data, finite-difference step, absolute tolerance)
    let cases: [(Activation, &[f32], f32, f32); 10] = [
        (Activation::Sigmoid, &x_any[..], 1e-3, 5e-3),
        (Activation::Tanh, &x_any[..], 1e-3, 5e-3),
        (Activation::Silu, &x_any[..], 1e-3, 5e-3),
        (Activation::Gelu, &x_any[..], 1e-3, 5e-3),
        (Activation::GeluApprox, &x_any[..], 1e-3, 5e-3),
        (Activation::Exp, &x_any[..], 1e-4, 5e-3),
        (Activation::Log, &x_pos[..], 1e-4, 5e-3),
        (Activation::Sqrt, &x_pos[..], 1e-4, 5e-3),
        (Activation::Rsqrt, &x_pos[..], 1e-4, 5e-3),
        (Activation::Neg, &x_any[..], 1e-3, 5e-4),
    ];

    for &(kind, x_data, eps, tol) in cases.iter() {
        let f = DType::F32;
        let mut g = Graph::new("act_bw");
        let xn = g.input("x", Shape::new(&[len], f));
        let dyn_ = g.input("dy", Shape::new(&[len], f));
        let dx = g.activation_backward(kind, xn, dyn_);
        g.set_outputs(vec![dx]);

        let plan = rlx_opt::memory::plan_memory(&g);
        let mut arena = crate::arena::Arena::from_plan(plan);
        let sched = compile_thunks(&g, &arena);

        let xn_off = arena.byte_offset(xn);
        let dyn_off = arena.byte_offset(dyn_);
        let dx_off = arena.byte_offset(dx);
        let buf = arena.raw_buf_mut();
        // SAFETY (assumed): offsets come from the arena built for this graph;
        // this relies on the plan reserving `len` aligned f32 slots per node.
        // TODO confirm against `plan_memory`.
        unsafe {
            let x_dst = buf.as_mut_ptr().add(xn_off) as *mut f32;
            for (slot, &value) in x_data.iter().enumerate() {
                x_dst.add(slot).write(value);
            }
            let dy_dst = buf.as_mut_ptr().add(dyn_off) as *mut f32;
            for (slot, &value) in dy.iter().enumerate() {
                dy_dst.add(slot).write(value);
            }
        }
        execute_thunks(&sched, arena.raw_buf_mut());
        // SAFETY (assumed): same arena-layout assumption, for the `dx` output.
        let analytical: Vec<f32> = unsafe {
            let src = arena.raw_buf().as_ptr().add(dx_off) as *const f32;
            (0..len).map(|slot| src.add(slot).read()).collect()
        };

        for i in 0..len {
            let xv = x_data[i];
            let plus = act_apply(kind, xv + eps);
            let minus = act_apply(kind, xv - eps);
            // Chain rule: central-difference slope times the upstream gradient.
            let num = (plus - minus) / (2.0 * eps) * dy[i];
            assert!(
                (analytical[i] - num).abs() < tol,
                "{kind:?}[{i}]: analytical {} vs numerical {num}",
                analytical[i]
            );
        }
    }
}
18253
#[test]
// Reverse-mode gradient of `sum(matmul(a, b))` with respect to the batched
// parameter `b`, checked element-wise against a central finite difference of a
// scalar reference implementation of the same loss.
fn matmul_3d_gradient_matches_numerical() {
    use rlx_ir::Philox4x32;
    let batch = 2usize;
    let m = 3usize;
    let k = 4usize;
    let n = 5usize;
    let mut rng = Philox4x32::new(101);
    let mut a_data = vec![0f32; batch * m * k];
    rng.fill_normal(&mut a_data);
    let mut b_data = vec![0f32; batch * k * n];
    rng.fill_normal(&mut b_data);

    let f = DType::F32;
    let mut fwd = Graph::new("matmul_3d");
    let an = fwd.input("a", Shape::new(&[batch, m, k], f));
    let bp = fwd.param("b", Shape::new(&[batch, k, n], f));
    let mm = fwd.matmul(an, bp, Shape::new(&[batch, m, n], f));
    let loss = fwd.add_node(
        Op::Reduce {
            op: ReduceOp::Sum,
            axes: vec![0, 1, 2],
            keep_dim: false,
        },
        vec![mm],
        Shape::from_dims(&[], f),
    );
    fwd.set_outputs(vec![loss]);

    let bwd_graph = rlx_opt::autodiff::grad_with_loss(&fwd, &[bp]);
    // The backward graph takes its seed gradient through an input named
    // "d_output".
    let d_out = bwd_graph
        .nodes()
        .iter()
        .find(|n| matches!(&n.op, Op::Input { name } if name == "d_output"))
        .map(|n| n.id)
        .unwrap();

    let plan = rlx_opt::memory::plan_memory(&bwd_graph);
    let mut arena = crate::arena::Arena::from_plan(plan);
    let sched = compile_thunks(&bwd_graph, &arena);
    let seed = vec![1.0f32];
    for &(id, data) in &[(an, &a_data[..]), (bp, &b_data[..]), (d_out, &seed[..])] {
        let off = arena.byte_offset(id);
        let buf = arena.raw_buf_mut();
        // SAFETY (assumed): the offset comes from the arena built for this
        // graph; this relies on the plan reserving enough aligned f32 slots for
        // the node. TODO confirm against `plan_memory`.
        unsafe {
            let dst = buf.as_mut_ptr().add(off) as *mut f32;
            for (slot, &value) in data.iter().enumerate() {
                dst.add(slot).write(value);
            }
        }
    }
    execute_thunks(&sched, arena.raw_buf_mut());
    // The gradient with respect to `b` is read from the second output.
    let gb_id = bwd_graph.outputs[1];
    // SAFETY (assumed): same arena-layout assumption, for the gradient output.
    let g_b: Vec<f32> = unsafe {
        let src = arena.raw_buf().as_ptr().add(arena.byte_offset(gb_id)) as *const f32;
        (0..batch * k * n).map(|slot| src.add(slot).read()).collect()
    };

    // Scalar reference loss: batched row-major matmul, then a sum over every
    // output element in index order.
    let forward_loss = |b_vals: &[f32]| -> f32 {
        let mut out = Vec::with_capacity(batch * m * n);
        for bi in 0..batch {
            for mi in 0..m {
                for ni in 0..n {
                    let dot = (0..k).fold(0f32, |acc, ki| {
                        acc + a_data[bi * m * k + mi * k + ki] * b_vals[bi * k * n + ki * n + ni]
                    });
                    out.push(dot);
                }
            }
        }
        out.iter().sum()
    };

    // Central difference, perturbing one entry of `b` at a time.
    let eps = 1e-3f32;
    let mut probe = b_data.clone();
    let g_b_num: Vec<f32> = (0..b_data.len())
        .map(|idx| {
            let saved = probe[idx];
            probe[idx] = saved + eps;
            let lp = forward_loss(&probe);
            probe[idx] = saved - eps;
            let lm = forward_loss(&probe);
            probe[idx] = saved;
            (lp - lm) / (2.0 * eps)
        })
        .collect();
    for (i, (a, n)) in g_b.iter().zip(&g_b_num).enumerate() {
        assert!(
            (a - n).abs() < 5e-3,
            "matmul_3d g_b[{i}]: analytical {a} vs numerical {n}"
        );
    }
}
18350
#[test]
// Reverse-mode gradient of `sum(softmax(x, axis = -1))` with respect to `x`,
// checked element-wise against a central finite difference of a scalar
// reference implementation of the same loss.
fn softmax_gradient_matches_numerical() {
    use rlx_ir::Philox4x32;
    let n = 3usize;
    let c = 5usize;
    let mut rng = Philox4x32::new(57);
    let mut x_data = vec![0f32; n * c];
    rng.fill_normal(&mut x_data);

    let f = DType::F32;
    let mut fwd = Graph::new("softmax_only");
    let xn = fwd.input("x", Shape::new(&[n, c], f));
    let sm = fwd.add_node(Op::Softmax { axis: -1 }, vec![xn], Shape::new(&[n, c], f));
    let loss = fwd.add_node(
        Op::Reduce {
            op: ReduceOp::Sum,
            axes: vec![0, 1],
            keep_dim: false,
        },
        vec![sm],
        Shape::from_dims(&[], f),
    );
    fwd.set_outputs(vec![loss]);

    let bwd_graph = rlx_opt::autodiff::grad_with_loss(&fwd, &[xn]);
    // The backward graph takes its seed gradient through an input named
    // "d_output".
    let d_out = bwd_graph
        .nodes()
        .iter()
        .find(|n| matches!(&n.op, Op::Input { name } if name == "d_output"))
        .map(|n| n.id)
        .unwrap();

    let plan = rlx_opt::memory::plan_memory(&bwd_graph);
    let mut arena = crate::arena::Arena::from_plan(plan);
    let sched = compile_thunks(&bwd_graph, &arena);
    let seed = vec![1.0f32];
    for &(id, data) in &[(xn, &x_data[..]), (d_out, &seed[..])] {
        let off = arena.byte_offset(id);
        let buf = arena.raw_buf_mut();
        // SAFETY (assumed): the offset comes from the arena built for this
        // graph; this relies on the plan reserving enough aligned f32 slots for
        // the node. TODO confirm against `plan_memory`.
        unsafe {
            let dst = buf.as_mut_ptr().add(off) as *mut f32;
            for (slot, &value) in data.iter().enumerate() {
                dst.add(slot).write(value);
            }
        }
    }
    execute_thunks(&sched, arena.raw_buf_mut());
    // The gradient with respect to `x` is read from the second output.
    let g_x_id = bwd_graph.outputs[1];
    // SAFETY (assumed): same arena-layout assumption, for the gradient output.
    let g_x: Vec<f32> = unsafe {
        let src = arena.raw_buf().as_ptr().add(arena.byte_offset(g_x_id)) as *const f32;
        (0..n * c).map(|slot| src.add(slot).read()).collect()
    };

    // Scalar reference loss: max-shifted softmax per row of length `c`, summed
    // over all `n * c` elements.
    let forward_loss = |x: &[f32]| -> f32 {
        let mut total = 0f32;
        for row in x[..n * c].chunks_exact(c) {
            let row_max = row.iter().fold(f32::NEG_INFINITY, |a, &v| a.max(v));
            let denom: f32 = row.iter().map(|&v| (v - row_max).exp()).sum();
            for &v in row {
                total += (v - row_max).exp() / denom;
            }
        }
        total
    };

    // Central difference, perturbing one entry of `x` at a time.
    let eps = 1e-3f32;
    let mut probe = x_data.clone();
    for i in 0..x_data.len() {
        let saved = probe[i];
        probe[i] = saved + eps;
        let lp = forward_loss(&probe);
        probe[i] = saved - eps;
        let lm = forward_loss(&probe);
        probe[i] = saved;
        let num = (lp - lm) / (2.0 * eps);
        assert!(
            (g_x[i] - num).abs() < 5e-3,
            "softmax g_x[{i}]: analytical {} vs numerical {num}",
            g_x[i]
        );
    }
}
18446
#[test]
// Reverse-mode gradients of `sum(layer_norm(x, gamma, beta))` with respect to
// `x`, `gamma` and `beta`, each checked element-wise against a central finite
// difference of a scalar reference implementation of the same loss.
fn layer_norm_gradient_matches_numerical() {
    use rlx_ir::Philox4x32;
    let rows = 3usize;
    let h = 6usize;
    let mut rng = Philox4x32::new(1009);
    let mut x_data = vec![0f32; rows * h];
    rng.fill_normal(&mut x_data);
    // Gamma is shifted to be at least 0.5.
    let mut g_data = vec![0f32; h];
    rng.fill_normal(&mut g_data);
    for v in g_data.iter_mut() {
        *v = v.abs() + 0.5;
    }
    let mut b_data = vec![0f32; h];
    rng.fill_normal(&mut b_data);
    let eps = 1e-5f32;

    let f = DType::F32;
    let mut fwd = Graph::new("ln_only");
    let xn = fwd.input("x", Shape::new(&[rows, h], f));
    let gp = fwd.param("gamma", Shape::new(&[h], f));
    let bp = fwd.param("beta", Shape::new(&[h], f));
    let ln = fwd.add_node(
        Op::LayerNorm { axis: -1, eps },
        vec![xn, gp, bp],
        Shape::new(&[rows, h], f),
    );
    let loss = fwd.add_node(
        Op::Reduce {
            op: ReduceOp::Sum,
            axes: vec![0, 1],
            keep_dim: false,
        },
        vec![ln],
        Shape::from_dims(&[], f),
    );
    fwd.set_outputs(vec![loss]);

    let bwd_graph = rlx_opt::autodiff::grad_with_loss(&fwd, &[xn, gp, bp]);
    // The backward graph takes its seed gradient through an input named
    // "d_output".
    let d_out = bwd_graph
        .nodes()
        .iter()
        .find(|n| matches!(&n.op, Op::Input { name } if name == "d_output"))
        .map(|n| n.id)
        .unwrap();

    let plan = rlx_opt::memory::plan_memory(&bwd_graph);
    let mut arena = crate::arena::Arena::from_plan(plan);
    let sched = compile_thunks(&bwd_graph, &arena);
    let seed = vec![1.0f32];
    for &(id, data) in &[
        (xn, &x_data[..]),
        (gp, &g_data[..]),
        (bp, &b_data[..]),
        (d_out, &seed[..]),
    ] {
        let off = arena.byte_offset(id);
        let buf = arena.raw_buf_mut();
        // SAFETY (assumed): the offset comes from the arena built for this
        // graph; this relies on the plan reserving enough aligned f32 slots for
        // the node. TODO confirm against `plan_memory`.
        unsafe {
            let dst = buf.as_mut_ptr().add(off) as *mut f32;
            for (slot, &value) in data.iter().enumerate() {
                dst.add(slot).write(value);
            }
        }
    }
    execute_thunks(&sched, arena.raw_buf_mut());
    // Copies `n` f32 values out of the arena slot belonging to node `id`.
    let read = |id: NodeId, n: usize| -> Vec<f32> {
        let off = arena.byte_offset(id);
        // SAFETY (assumed): same arena-layout assumption as for the writes.
        unsafe {
            let src = arena.raw_buf().as_ptr().add(off) as *const f32;
            (0..n).map(|slot| src.add(slot).read()).collect()
        }
    };
    // Outputs 1..=3 are read as the gradients for x, gamma and beta, in the
    // order those nodes were passed to `grad_with_loss`.
    let dx_a = read(bwd_graph.outputs[1], rows * h);
    let dg_a = read(bwd_graph.outputs[2], h);
    let db_a = read(bwd_graph.outputs[3], h);

    // Scalar reference loss: per-row mean / variance normalization (variance
    // divided by `h`), affine transform, summed over all elements.
    let forward_loss = |x: &[f32], g: &[f32], b: &[f32]| -> f32 {
        let mut total = 0f32;
        for row in x[..rows * h].chunks_exact(h) {
            let mean = row.iter().sum::<f32>() / h as f32;
            let var = row.iter().map(|&v| (v - mean) * (v - mean)).sum::<f32>() / h as f32;
            let inv_std = 1.0 / (var + eps).sqrt();
            for ((&xv, &gv), &bv) in row.iter().zip(&g[..h]).zip(&b[..h]) {
                total += ((xv - mean) * inv_std) * gv + bv;
            }
        }
        total
    };
    let h_eps = 1e-3f32;

    // d loss / d x
    let mut x_probe = x_data.clone();
    for i in 0..x_probe.len() {
        let saved = x_probe[i];
        x_probe[i] = saved + h_eps;
        let lp = forward_loss(&x_probe, &g_data, &b_data);
        x_probe[i] = saved - h_eps;
        let lm = forward_loss(&x_probe, &g_data, &b_data);
        x_probe[i] = saved;
        let num = (lp - lm) / (2.0 * h_eps);
        assert!(
            (dx_a[i] - num).abs() < 5e-3,
            "ln dx[{i}]: analytical {} vs numerical {num}",
            dx_a[i]
        );
    }

    // d loss / d gamma
    let mut g_probe = g_data.clone();
    for i in 0..g_probe.len() {
        let saved = g_probe[i];
        g_probe[i] = saved + h_eps;
        let lp = forward_loss(&x_data, &g_probe, &b_data);
        g_probe[i] = saved - h_eps;
        let lm = forward_loss(&x_data, &g_probe, &b_data);
        g_probe[i] = saved;
        let num = (lp - lm) / (2.0 * h_eps);
        assert!(
            (dg_a[i] - num).abs() < 5e-3,
            "ln dg[{i}]: analytical {} vs numerical {num}",
            dg_a[i]
        );
    }

    // d loss / d beta
    let mut b_probe = b_data.clone();
    for i in 0..b_probe.len() {
        let saved = b_probe[i];
        b_probe[i] = saved + h_eps;
        let lp = forward_loss(&x_data, &g_data, &b_probe);
        b_probe[i] = saved - h_eps;
        let lm = forward_loss(&x_data, &g_data, &b_probe);
        b_probe[i] = saved;
        let num = (lp - lm) / (2.0 * h_eps);
        assert!(
            (db_a[i] - num).abs() < 5e-3,
            "ln db[{i}]: analytical {} vs numerical {num}",
            db_a[i]
        );
    }
}
18588
18589 #[test]
18594 fn dense_sce_mean_gradient_matches_numerical() {
18595 use rlx_ir::Philox4x32;
18596 let bs = 4usize;
18597 let k_in = 3usize;
18598 let c = 5usize;
18599 let mut rng = Philox4x32::new(7);
18600 let mut x = vec![0f32; bs * k_in];
18601 rng.fill_normal(&mut x);
18602 let mut w_init = vec![0f32; k_in * c];
18603 rng.fill_normal(&mut w_init);
18604 let mut b_init = vec![0f32; c];
18605 rng.fill_normal(&mut b_init);
18606 let labels: Vec<f32> = (0..bs).map(|i| (i % c) as f32).collect();
18607
18608 let f = DType::F32;
18610 let mut fwd = Graph::new("dense_sce");
18611 let xn = fwd.input("x", Shape::new(&[bs, k_in], f));
18612 let lb = fwd.input("labels", Shape::new(&[bs], f));
18613 let wp = fwd.param("w", Shape::new(&[k_in, c], f));
18614 let bp = fwd.param("b", Shape::new(&[c], f));
18615 let mm = fwd.matmul(xn, wp, Shape::new(&[bs, c], f));
18616 let logits = fwd.binary(BinaryOp::Add, mm, bp, Shape::new(&[bs, c], f));
18617 let loss_per = fwd.softmax_cross_entropy_with_logits(logits, lb);
18618 let loss = fwd.add_node(
18619 Op::Reduce {
18620 op: ReduceOp::Sum,
18621 axes: vec![0],
18622 keep_dim: false,
18623 },
18624 vec![loss_per],
18625 Shape::from_dims(&[], f),
18627 );
18628 fwd.set_outputs(vec![loss]);
18636
18637 let bwd_graph = rlx_opt::autodiff::grad_with_loss(&fwd, &[wp, bp]);
18639 let d_out = bwd_graph
18642 .nodes()
18643 .iter()
18644 .find(|n| matches!(&n.op, Op::Input { name } if name == "d_output"))
18645 .map(|n| n.id)
18646 .expect("d_output input");
18647
18648 let (sched, mut arena) = prepare(
18649 &bwd_graph,
18650 &[
18651 (xn, &x),
18652 (lb, &labels),
18653 (wp, &w_init),
18654 (bp, &b_init),
18655 (d_out, &[1.0]),
18656 ],
18657 );
18658 execute_thunks(&sched, arena.raw_buf_mut());
18659
18660 let outs = &bwd_graph.outputs;
18661 let loss_id = outs[0];
18662 let gw_id = outs[1];
18663 let gb_id = outs[2];
18664 let loss_actual = read_arena(&arena, loss_id, 1)[0];
18665 let gw_actual = read_arena(&arena, gw_id, k_in * c);
18666 let gb_actual = read_arena(&arena, gb_id, c);
18667
18668 let plan = rlx_opt::memory::plan_memory(&fwd);
18672 let mut fwd_arena = crate::arena::Arena::from_plan(plan);
18673 let fwd_sched = compile_thunks(&fwd, &fwd_arena);
18674 write_arena(&mut fwd_arena, xn, &x);
18675 write_arena(&mut fwd_arena, lb, &labels);
18676
18677 let run_loss = |arena: &mut crate::arena::Arena, w: &[f32], b: &[f32]| -> f32 {
18678 write_arena(arena, wp, w);
18679 write_arena(arena, bp, b);
18680 execute_thunks(&fwd_sched, arena.raw_buf_mut());
18681 read_arena(arena, loss, 1)[0]
18682 };
18683
18684 let loss_check = run_loss(&mut fwd_arena, &w_init, &b_init);
18687 assert!(
18688 (loss_actual - loss_check).abs() < 1e-4,
18689 "loss mismatch: bwd graph {loss_actual} vs fwd-only {loss_check}"
18690 );
18691
18692 let eps = 1e-3f32;
18693 let mut w_perturbed = w_init.clone();
18694 let mut gw_numerical = vec![0f32; w_init.len()];
18695 for i in 0..w_init.len() {
18696 let saved = w_perturbed[i];
18697 w_perturbed[i] = saved + eps;
18698 let lp = run_loss(&mut fwd_arena, &w_perturbed, &b_init);
18699 w_perturbed[i] = saved - eps;
18700 let lm = run_loss(&mut fwd_arena, &w_perturbed, &b_init);
18701 w_perturbed[i] = saved;
18702 gw_numerical[i] = (lp - lm) / (2.0 * eps);
18703 }
18704 for (i, (a, n)) in gw_actual.iter().zip(&gw_numerical).enumerate() {
18705 assert!(
18706 (a - n).abs() < 5e-3,
18707 "grad_w[{i}]: analytical {a} vs numerical {n}"
18708 );
18709 }
18710
18711 let mut b_perturbed = b_init.clone();
18712 let mut gb_numerical = vec![0f32; b_init.len()];
18713 for i in 0..b_init.len() {
18714 let saved = b_perturbed[i];
18715 b_perturbed[i] = saved + eps;
18716 let lp = run_loss(&mut fwd_arena, &w_init, &b_perturbed);
18717 b_perturbed[i] = saved - eps;
18718 let lm = run_loss(&mut fwd_arena, &w_init, &b_perturbed);
18719 b_perturbed[i] = saved;
18720 gb_numerical[i] = (lp - lm) / (2.0 * eps);
18721 }
18722 for (i, (a, n)) in gb_actual.iter().zip(&gb_numerical).enumerate() {
18723 assert!(
18724 (a - n).abs() < 5e-3,
18725 "grad_b[{i}]: analytical {a} vs numerical {n}"
18726 );
18727 }
18728 }
18729
18730 #[test]
18733 fn dense_sce_mean_reduce_gradient_matches_numerical() {
18734 use rlx_ir::Philox4x32;
18735 let bs = 3usize;
18736 let k_in = 2usize;
18737 let c = 4usize;
18738 let mut rng = Philox4x32::new(13);
18739 let mut x = vec![0f32; bs * k_in];
18740 rng.fill_normal(&mut x);
18741 let mut w_init = vec![0f32; k_in * c];
18742 rng.fill_normal(&mut w_init);
18743 let labels: Vec<f32> = (0..bs).map(|i| (i % c) as f32).collect();
18744
18745 let f = DType::F32;
18746 let mut fwd = Graph::new("dense_sce_mean");
18747 let xn = fwd.input("x", Shape::new(&[bs, k_in], f));
18748 let lb = fwd.input("labels", Shape::new(&[bs], f));
18749 let wp = fwd.param("w", Shape::new(&[k_in, c], f));
18750 let mm = fwd.matmul(xn, wp, Shape::new(&[bs, c], f));
18751 let loss_per = fwd.softmax_cross_entropy_with_logits(mm, lb);
18752 let loss = fwd.add_node(
18753 Op::Reduce {
18754 op: ReduceOp::Mean,
18755 axes: vec![0],
18756 keep_dim: false,
18757 },
18758 vec![loss_per],
18759 Shape::from_dims(&[], f),
18760 );
18761 fwd.set_outputs(vec![loss]);
18762
18763 let bwd_graph = rlx_opt::autodiff::grad_with_loss(&fwd, &[wp]);
18764 let d_out = bwd_graph
18765 .nodes()
18766 .iter()
18767 .find(|n| matches!(&n.op, Op::Input { name } if name == "d_output"))
18768 .map(|n| n.id)
18769 .unwrap();
18770
18771 let (sched, mut arena) = prepare(
18772 &bwd_graph,
18773 &[(xn, &x), (lb, &labels), (wp, &w_init), (d_out, &[1.0])],
18774 );
18775 execute_thunks(&sched, arena.raw_buf_mut());
18776
18777 let outs = &bwd_graph.outputs;
18778 let loss_id = outs[0];
18779 let gw_id = outs[1];
18780 let _ = read_arena(&arena, loss_id, 1)[0];
18781 let gw_actual = read_arena(&arena, gw_id, k_in * c);
18782
18783 let plan = rlx_opt::memory::plan_memory(&fwd);
18784 let mut fwd_arena = crate::arena::Arena::from_plan(plan);
18785 let fwd_sched = compile_thunks(&fwd, &fwd_arena);
18786 write_arena(&mut fwd_arena, xn, &x);
18787 write_arena(&mut fwd_arena, lb, &labels);
18788
18789 let run_loss = |arena: &mut crate::arena::Arena, w: &[f32]| -> f32 {
18790 write_arena(arena, wp, w);
18791 execute_thunks(&fwd_sched, arena.raw_buf_mut());
18792 read_arena(arena, loss, 1)[0]
18793 };
18794
18795 let eps = 1e-3f32;
18796 let mut wp_p = w_init.clone();
18797 let mut gw_num = vec![0f32; w_init.len()];
18798 for i in 0..w_init.len() {
18799 let s = wp_p[i];
18800 wp_p[i] = s + eps;
18801 let lp = run_loss(&mut fwd_arena, &wp_p);
18802 wp_p[i] = s - eps;
18803 let lm = run_loss(&mut fwd_arena, &wp_p);
18804 wp_p[i] = s;
18805 gw_num[i] = (lp - lm) / (2.0 * eps);
18806 }
18807 for (i, (a, n)) in gw_actual.iter().zip(&gw_num).enumerate() {
18808 assert!((a - n).abs() < 5e-3, "mean reduce grad_w[{i}]: {a} vs {n}");
18809 }
18810 }
18811 #[test]
18816 fn tinyconv_full_gradient_matches_numerical() {
18817 use rlx_ir::Philox4x32;
18818 let n = 1usize;
18820 let c_in = 1usize;
18821 let h = 6usize;
18822 let w_in = 6usize;
18823 let c_mid = 2usize; let kh = 3;
18825 let kw = 3;
18826 let h1 = h - kh + 1; let w1 = w_in - kw + 1; let h2 = h1 / 2;
18829 let w2 = w1 / 2; let flat = c_mid * h2 * w2; let num_classes = 3usize;
18832
18833 let mut rng = Philox4x32::new(31);
18834 let mut x = vec![0f32; n * c_in * h * w_in];
18835 rng.fill_normal(&mut x);
18836 let mut wc = vec![0f32; c_mid * c_in * kh * kw];
18837 rng.fill_normal(&mut wc);
18838 for v in wc.iter_mut() {
18839 *v *= 0.2;
18840 }
18841 let bc: Vec<f32> = (0..c_mid).map(|i| 5.0 + 0.1 * i as f32).collect();
18850 let mut wfc = vec![0f32; flat * num_classes];
18851 rng.fill_normal(&mut wfc);
18852 for v in wfc.iter_mut() {
18853 *v *= 0.5;
18854 }
18855 let mut bfc = vec![0f32; num_classes];
18856 rng.fill_normal(&mut bfc);
18857 let labels: Vec<f32> = vec![1.0]; let f = DType::F32;
18860 let mut fwd = Graph::new("tinyconv");
18861 let xn = fwd.input("x", Shape::new(&[n, c_in, h, w_in], f));
18862 let lb = fwd.input("labels", Shape::new(&[n], f));
18863 let wcp = fwd.param("wc", Shape::new(&[c_mid, c_in, kh, kw], f));
18864 let bcp = fwd.param("bc", Shape::new(&[c_mid], f));
18865 let wfp = fwd.param("wfc", Shape::new(&[flat, num_classes], f));
18866 let bfp = fwd.param("bfc", Shape::new(&[num_classes], f));
18867
18868 let conv = fwd.add_node(
18870 Op::Conv {
18871 kernel_size: vec![kh, kw],
18872 stride: vec![1, 1],
18873 padding: vec![0, 0],
18874 dilation: vec![1, 1],
18875 groups: 1,
18876 },
18877 vec![xn, wcp],
18878 Shape::new(&[n, c_mid, h1, w1], f),
18879 );
18880 let bc_4d = fwd.add_node(
18892 Op::Reshape {
18893 new_shape: vec![1, c_mid as i64, 1, 1],
18894 },
18895 vec![bcp],
18896 Shape::new(&[1, c_mid, 1, 1], f),
18897 );
18898 let bc_expanded = fwd.add_node(
18899 Op::Expand {
18900 target_shape: vec![n as i64, c_mid as i64, h1 as i64, w1 as i64],
18901 },
18902 vec![bc_4d],
18903 Shape::new(&[n, c_mid, h1, w1], f),
18904 );
18905 let conv_b = fwd.binary(
18906 BinaryOp::Add,
18907 conv,
18908 bc_expanded,
18909 Shape::new(&[n, c_mid, h1, w1], f),
18910 );
18911 let relu = fwd.activation(Activation::Relu, conv_b, Shape::new(&[n, c_mid, h1, w1], f));
18912 let pool = fwd.add_node(
18913 Op::Pool {
18914 kind: ReduceOp::Max,
18915 kernel_size: vec![2, 2],
18916 stride: vec![2, 2],
18917 padding: vec![0, 0],
18918 },
18919 vec![relu],
18920 Shape::new(&[n, c_mid, h2, w2], f),
18921 );
18922 let flatn = fwd.add_node(
18923 Op::Reshape {
18924 new_shape: vec![n as i64, flat as i64],
18925 },
18926 vec![pool],
18927 Shape::new(&[n, flat], f),
18928 );
18929 let mm = fwd.matmul(flatn, wfp, Shape::new(&[n, num_classes], f));
18930 let logits = fwd.binary(BinaryOp::Add, mm, bfp, Shape::new(&[n, num_classes], f));
18931 let loss_per = fwd.softmax_cross_entropy_with_logits(logits, lb);
18932 let loss = fwd.add_node(
18933 Op::Reduce {
18934 op: ReduceOp::Mean,
18935 axes: vec![0],
18936 keep_dim: false,
18937 },
18938 vec![loss_per],
18939 Shape::from_dims(&[], f),
18940 );
18941 fwd.set_outputs(vec![loss]);
18942
18943 let bwd_graph = rlx_opt::autodiff::grad_with_loss(&fwd, &[wcp, bcp, wfp, bfp]);
18944 let d_out = bwd_graph
18945 .nodes()
18946 .iter()
18947 .find(|n| matches!(&n.op, Op::Input { name } if name == "d_output"))
18948 .map(|n| n.id)
18949 .unwrap();
18950
18951 let (sched, mut arena) = prepare(
18952 &bwd_graph,
18953 &[
18954 (xn, &x),
18955 (lb, &labels),
18956 (wcp, &wc),
18957 (bcp, &bc),
18958 (wfp, &wfc),
18959 (bfp, &bfc),
18960 (d_out, &[1.0]),
18961 ],
18962 );
18963 execute_thunks(&sched, arena.raw_buf_mut());
18964
18965 let outs = bwd_graph.outputs.clone();
18966 let loss_id = outs[0];
18967 let g_wc_id = outs[1];
18968 let g_bc_id = outs[2];
18969 let g_wfc_id = outs[3];
18970 let g_bfc_id = outs[4];
18971 let loss_actual = read_arena(&arena, loss_id, 1)[0];
18972 let g_wc = read_arena(&arena, g_wc_id, wc.len());
18973 let g_bc = read_arena(&arena, g_bc_id, bc.len());
18974 let g_wfc = read_arena(&arena, g_wfc_id, wfc.len());
18975 let g_bfc = read_arena(&arena, g_bfc_id, bfc.len());
18976
18977 let plan = rlx_opt::memory::plan_memory(&fwd);
18979 let mut fwd_arena = crate::arena::Arena::from_plan(plan);
18980 let fwd_sched = compile_thunks(&fwd, &fwd_arena);
18981 write_arena(&mut fwd_arena, xn, &x);
18982 write_arena(&mut fwd_arena, lb, &labels);
18983
18984 let run_loss = |arena: &mut crate::arena::Arena,
18987 wc: &[f32],
18988 bc: &[f32],
18989 wfc: &[f32],
18990 bfc: &[f32]|
18991 -> f32 {
18992 write_arena(arena, wcp, wc);
18993 write_arena(arena, bcp, bc);
18994 write_arena(arena, wfp, wfc);
18995 write_arena(arena, bfp, bfc);
18996 execute_thunks(&fwd_sched, arena.raw_buf_mut());
18997 read_arena(arena, loss, 1)[0]
18998 };
18999
19000 let loss_check = run_loss(&mut fwd_arena, &wc, &bc, &wfc, &bfc);
19001 assert!(
19002 (loss_actual - loss_check).abs() < 1e-4,
19003 "tinyconv loss mismatch: bwd {loss_actual} vs fwd {loss_check}"
19004 );
19005
19006 let eps = 1e-3f32;
19007 let check_grad = |arena: &mut crate::arena::Arena,
19008 name: &str,
19009 analytical: &[f32],
19010 mut perturb: Box<
19011 dyn FnMut(&mut [f32], usize, f32, &mut crate::arena::Arena) -> f32 + '_,
19012 >,
19013 n: usize| {
19014 for i in 0..n {
19015 let lp = perturb(&mut analytical.to_vec(), i, eps, arena);
19016 let lm = perturb(&mut analytical.to_vec(), i, -eps, arena);
19017 let num = (lp - lm) / (2.0 * eps);
19018 assert!(
19019 (analytical[i] - num).abs() < 5e-3,
19020 "{name}[{i}]: analytical {} vs numerical {num}",
19021 analytical[i]
19022 );
19023 }
19024 };
19025
19026 #[allow(unused_macros)]
19029 macro_rules! sweep {
19030 ($name:expr, $base:expr, $analytical:expr, $set_param:ident) => {{
19031 let n = $base.len();
19032 for i in 0..n {
19033 let mut p = $base.clone();
19034 let s = p[i];
19035 p[i] = s + eps;
19036 let lp = {
19037 let $set_param = &p;
19038 run_loss(&mut fwd_arena, &wc, &bc, &wfc, &bfc).max(f32::NEG_INFINITY);
19039 let _ = $set_param;
19042 0.0_f32
19044 };
19045 let _ = lp;
19046 }
19047 }};
19048 }
19049 let _ = check_grad; for i in 0..wc.len() {
19053 let mut p = wc.clone();
19054 let s = p[i];
19055 p[i] = s + eps;
19056 let lp = run_loss(&mut fwd_arena, &p, &bc, &wfc, &bfc);
19057 p[i] = s - eps;
19058 let lm = run_loss(&mut fwd_arena, &p, &bc, &wfc, &bfc);
19059 let num = (lp - lm) / (2.0 * eps);
19060 assert!(
19061 (g_wc[i] - num).abs() < 5e-3,
19062 "g_wc[{i}]: {} vs {num}",
19063 g_wc[i]
19064 );
19065 }
19066 for i in 0..bc.len() {
19067 let mut p = bc.clone();
19068 let s = p[i];
19069 p[i] = s + eps;
19070 let lp = run_loss(&mut fwd_arena, &wc, &p, &wfc, &bfc);
19071 p[i] = s - eps;
19072 let lm = run_loss(&mut fwd_arena, &wc, &p, &wfc, &bfc);
19073 let num = (lp - lm) / (2.0 * eps);
19074 assert!(
19075 (g_bc[i] - num).abs() < 5e-3,
19076 "g_bc[{i}]: {} vs {num}",
19077 g_bc[i]
19078 );
19079 }
19080 for i in 0..wfc.len() {
19081 let mut p = wfc.clone();
19082 let s = p[i];
19083 p[i] = s + eps;
19084 let lp = run_loss(&mut fwd_arena, &wc, &bc, &p, &bfc);
19085 p[i] = s - eps;
19086 let lm = run_loss(&mut fwd_arena, &wc, &bc, &p, &bfc);
19087 let num = (lp - lm) / (2.0 * eps);
19088 assert!(
19089 (g_wfc[i] - num).abs() < 5e-3,
19090 "g_wfc[{i}]: {} vs {num}",
19091 g_wfc[i]
19092 );
19093 }
19094 for i in 0..bfc.len() {
19095 let mut p = bfc.clone();
19096 let s = p[i];
19097 p[i] = s + eps;
19098 let lp = run_loss(&mut fwd_arena, &wc, &bc, &wfc, &p);
19099 p[i] = s - eps;
19100 let lm = run_loss(&mut fwd_arena, &wc, &bc, &wfc, &p);
19101 let num = (lp - lm) / (2.0 * eps);
19102 assert!(
19103 (g_bfc[i] - num).abs() < 5e-3,
19104 "g_bfc[{i}]: {} vs {num}",
19105 g_bfc[i]
19106 );
19107 }
19108 }
19109
19110 #[test]
19114 fn narrow_rope_skips_when_narrow_has_multiple_consumers() {
19115 let f = DType::F32;
19116 let mut g = Graph::new("nr_skip");
19117 let qkv = g.input("qkv", Shape::new(&[16, 8, 192], f));
19118 let cos = g.input("cos", Shape::new(&[16], f));
19119 let sin = g.input("sin", Shape::new(&[16], f));
19120 let q = g.narrow_(qkv, 2, 0, 64);
19121 let q_rope = g.rope(q, cos, sin, 16);
19122 let q_dup = g.activation(rlx_ir::op::Activation::Relu, q, Shape::new(&[16, 8, 64], f));
19124 g.set_outputs(vec![q_rope, q_dup]);
19125
19126 let plan = rlx_opt::memory::plan_memory(&g);
19127 let arena = crate::arena::Arena::from_plan(plan);
19128 let sched = compile_thunks(&g, &arena);
19129
19130 let narrow_count = sched
19131 .thunks
19132 .iter()
19133 .filter(|t| matches!(t, Thunk::Narrow { .. }))
19134 .count();
19135 assert!(
19136 narrow_count >= 1,
19137 "Narrow with multiple consumers must NOT be fused away"
19138 );
19139 }
19140
19141 #[test]
19154 fn custom_fn_forward_inlines_body() {
19155 let s = Shape::new(&[3], DType::F32);
19156
19157 let mut body = Graph::new("addone_body");
19159 let x = body.input("x", s.clone());
19160 let one_data: Vec<u8> = (0..3).flat_map(|_| 1.0_f32.to_le_bytes()).collect();
19161 let one = body.add_node(Op::Constant { data: one_data }, vec![], s.clone());
19162 let y = body.binary(BinaryOp::Add, x, one, s.clone());
19163 body.set_outputs(vec![y]);
19164
19165 let mut g = Graph::new("custom_fn_outer");
19166 let xin = g.input("x_in", s.clone());
19167 let cf = g.custom_fn(vec![xin], body, None, None);
19168 g.set_outputs(vec![cf]);
19169
19170 let xs = vec![10.0_f32, 20.0, 30.0];
19171 let (sched, mut arena) = prepare(&g, &[(xin, &xs)]);
19172 execute_thunks(&sched, arena.raw_buf_mut());
19173 let got = read_arena(&arena, cf, 3);
19174 assert_eq!(got, vec![11.0, 21.0, 31.0]);
19175 }
19176
19177 fn find_named(graph: &Graph, want: &str) -> NodeId {
19179 for n in graph.nodes() {
19180 let name = match &n.op {
19181 Op::Input { name } | Op::Param { name } => Some(name.as_str()),
19182 _ => None,
19183 };
19184 if name == Some(want) {
19185 return n.id;
19186 }
19187 }
19188 panic!("no node named {want:?} in graph");
19189 }
19190
19191 #[test]
19195 fn custom_fn_vjp_overrides_natural_gradient() {
19196 use rlx_opt::autodiff::grad_with_loss;
19197 let s = Shape::new(&[1], DType::F32);
19198
19199 let mut fwd = Graph::new("id_fwd");
19200 let x = fwd.input("x", s.clone());
19201 fwd.set_outputs(vec![x]);
19202
19203 let mut vjp_g = Graph::new("id_vjp");
19204 let _x_p = vjp_g.input("x", s.clone());
19205 let _y_p = vjp_g.input("primal_output", s.clone());
19206 let dy = vjp_g.input("d_output", s.clone());
19207 let two_data: Vec<u8> = 2.0_f32.to_le_bytes().to_vec();
19208 let two = vjp_g.add_node(Op::Constant { data: two_data }, vec![], s.clone());
19209 let dx = vjp_g.binary(BinaryOp::Mul, dy, two, s.clone());
19210 vjp_g.set_outputs(vec![dx]);
19211
19212 let mut g = Graph::new("outer");
19213 let xp = g.param("x", s.clone());
19214 let cf = g.custom_fn(vec![xp], fwd, Some(vjp_g), None);
19215 g.set_outputs(vec![cf]);
19216
19217 let bwd = grad_with_loss(&g, &[xp]);
19218 assert_eq!(bwd.outputs.len(), 2, "expect [loss, dx]");
19219
19220 let xb = find_named(&bwd, "x");
19221 let dout = find_named(&bwd, "d_output");
19222 let (sched, mut arena) = prepare(&bwd, &[(xb, &[7.0]), (dout, &[1.0])]);
19223 execute_thunks(&sched, arena.raw_buf_mut());
19224 let loss = read_arena(&arena, bwd.outputs[0], 1);
19225 let dx_v = read_arena(&arena, bwd.outputs[1], 1);
19226 assert!((loss[0] - 7.0).abs() < 1e-6, "loss should be 7.0");
19227 assert!(
19228 (dx_v[0] - 2.0).abs() < 1e-6,
19229 "vjp override should yield dx=2.0, got {} (natural autodiff would give 1.0)",
19230 dx_v[0]
19231 );
19232 }
19233
19234 #[test]
19239 fn custom_fn_vjp_two_inputs_matches_mul_autodiff() {
19240 use rlx_opt::autodiff::grad_with_loss;
19241 let s = Shape::new(&[1], DType::F32);
19242
19243 let mut fwd = Graph::new("mul_fwd");
19244 let a_f = fwd.input("a", s.clone());
19245 let b_f = fwd.input("b", s.clone());
19246 let y_f = fwd.binary(BinaryOp::Mul, a_f, b_f, s.clone());
19247 fwd.set_outputs(vec![y_f]);
19248
19249 let mut vjp_g = Graph::new("mul_vjp");
19250 let a_v = vjp_g.input("a", s.clone());
19251 let b_v = vjp_g.input("b", s.clone());
19252 let _y_v = vjp_g.input("primal_output", s.clone());
19253 let dy_v = vjp_g.input("d_output", s.clone());
19254 let da = vjp_g.binary(BinaryOp::Mul, b_v, dy_v, s.clone());
19255 let db = vjp_g.binary(BinaryOp::Mul, a_v, dy_v, s.clone());
19256 vjp_g.set_outputs(vec![da, db]);
19257
19258 let mut g = Graph::new("outer");
19259 let ap = g.param("a", s.clone());
19260 let bp = g.param("b", s.clone());
19261 let cf = g.custom_fn(vec![ap, bp], fwd, Some(vjp_g), None);
19262 g.set_outputs(vec![cf]);
19263
19264 let bwd = grad_with_loss(&g, &[ap, bp]);
19265 assert_eq!(bwd.outputs.len(), 3, "expect [loss, da, db]");
19266
19267 let ab = find_named(&bwd, "a");
19268 let bb = find_named(&bwd, "b");
19269 let dout = find_named(&bwd, "d_output");
19270 let (sched, mut arena) = prepare(&bwd, &[(ab, &[3.0]), (bb, &[5.0]), (dout, &[1.0])]);
19271 execute_thunks(&sched, arena.raw_buf_mut());
19272 let loss = read_arena(&arena, bwd.outputs[0], 1);
19273 let da_v = read_arena(&arena, bwd.outputs[1], 1);
19274 let db_v = read_arena(&arena, bwd.outputs[2], 1);
19275 assert!((loss[0] - 15.0).abs() < 1e-5);
19276 assert!(
19277 (da_v[0] - 5.0).abs() < 1e-5,
19278 "da should be b=5.0, got {}",
19279 da_v[0]
19280 );
19281 assert!(
19282 (db_v[0] - 3.0).abs() < 1e-5,
19283 "db should be a=3.0, got {}",
19284 db_v[0]
19285 );
19286 }
19287
19288 #[test]
19291 fn custom_fn_jvp_overrides_natural_tangent() {
19292 use rlx_opt::autodiff_fwd::jvp;
19293 let s = Shape::new(&[1], DType::F32);
19294
19295 let mut fwd = Graph::new("id_fwd");
19296 let x = fwd.input("x", s.clone());
19297 fwd.set_outputs(vec![x]);
19298
19299 let mut jvp_g = Graph::new("id_jvp");
19300 let _x_p = jvp_g.input("x", s.clone());
19301 let tx = jvp_g.input("tangent_0", s.clone());
19302 let two_data: Vec<u8> = 2.0_f32.to_le_bytes().to_vec();
19303 let two = jvp_g.add_node(Op::Constant { data: two_data }, vec![], s.clone());
19304 let ty = jvp_g.binary(BinaryOp::Mul, tx, two, s.clone());
19305 jvp_g.set_outputs(vec![ty]);
19306
19307 let mut g = Graph::new("outer");
19308 let xin = g.input("x_in", s.clone());
19309 let cf = g.custom_fn(vec![xin], fwd, None, Some(jvp_g));
19310 g.set_outputs(vec![cf]);
19311
19312 let fwd_g = jvp(&g, &[xin]);
19313 assert_eq!(fwd_g.outputs.len(), 2, "expect [primal_y, tangent_y]");
19314
19315 let xb = find_named(&fwd_g, "x_in");
19316 let tan = find_named(&fwd_g, "tangent_x_in");
19317 let (sched, mut arena) = prepare(&fwd_g, &[(xb, &[7.0]), (tan, &[1.0])]);
19318 execute_thunks(&sched, arena.raw_buf_mut());
19319 let y = read_arena(&arena, fwd_g.outputs[0], 1);
19320 let ty_v = read_arena(&arena, fwd_g.outputs[1], 1);
19321 assert!((y[0] - 7.0).abs() < 1e-6);
19322 assert!(
19323 (ty_v[0] - 2.0).abs() < 1e-6,
19324 "jvp override should yield t_y=2.0 (natural autodiff would give 1.0), got {}",
19325 ty_v[0]
19326 );
19327 }
19328
19329 #[test]
19334 fn c64_dtype_storage_layout() {
19335 assert_eq!(
19336 DType::C64.size_bytes(),
19337 8,
19338 "C64 should be 8 bytes (f32 real + f32 imag)"
19339 );
19340 assert!(DType::C64.is_complex());
19341 assert!(!DType::C64.is_float());
19342
19343 let s = Shape::new(&[2], DType::C64);
19345 assert_eq!(s.size_bytes().unwrap(), 16);
19346 }
19347
19348 fn run_c64_binary(op: BinaryOp, a: &[(f32, f32)], b: &[(f32, f32)]) -> Vec<(f32, f32)> {
19355 let n = a.len();
19356 let s = Shape::new(&[n], DType::C64);
19357 let mut g = Graph::new("c64_bin");
19358 let in_a = g.input("a", s.clone());
19359 let in_b = g.input("b", s.clone());
19360 let out = g.binary(op, in_a, in_b, s.clone());
19361 g.set_outputs(vec![out]);
19362
19363 let plan = rlx_opt::memory::plan_memory(&g);
19364 let mut arena = crate::arena::Arena::from_plan(plan);
19365 let sched = compile_thunks(&g, &arena);
19366
19367 let a_off = arena.byte_offset(in_a);
19368 let b_off = arena.byte_offset(in_b);
19369 let out_off = arena.byte_offset(out);
19370 let buf = arena.raw_buf_mut();
19372 unsafe {
19373 let pa = buf.as_mut_ptr().add(a_off) as *mut f32;
19374 let pb = buf.as_mut_ptr().add(b_off) as *mut f32;
19375 for (i, &(re, im)) in a.iter().enumerate() {
19376 *pa.add(2 * i) = re;
19377 *pa.add(2 * i + 1) = im;
19378 }
19379 for (i, &(re, im)) in b.iter().enumerate() {
19380 *pb.add(2 * i) = re;
19381 *pb.add(2 * i + 1) = im;
19382 }
19383 }
19384 execute_thunks(&sched, arena.raw_buf_mut());
19385 let raw_out: Vec<f32> = unsafe {
19386 let p = arena.raw_buf().as_ptr().add(out_off) as *const f32;
19387 (0..(2 * n)).map(|i| *p.add(i)).collect()
19388 };
19389 (0..n)
19390 .map(|i| (raw_out[2 * i], raw_out[2 * i + 1]))
19391 .collect()
19392 }
19393
19394 #[track_caller]
19395 fn assert_close_c(got: (f32, f32), expected: (f32, f32), tol: f32, label: &str) {
19396 let dr = (got.0 - expected.0).abs();
19397 let di = (got.1 - expected.1).abs();
19398 assert!(
19399 dr < tol && di < tol,
19400 "[{label}] got ({:+.4}, {:+.4}), expected ({:+.4}, {:+.4})",
19401 got.0,
19402 got.1,
19403 expected.0,
19404 expected.1
19405 );
19406 }
19407
19408 #[test]
19409 fn c64_binary_add_matches_complex_arithmetic() {
19410 let a = [(1.0_f32, 2.0_f32), (3.0_f32, -1.0_f32)];
19411 let b = [(4.0_f32, -1.0_f32), (0.5_f32, 0.5_f32)];
19412 let out = run_c64_binary(BinaryOp::Add, &a, &b);
19413 assert_close_c(out[0], (5.0, 1.0), 1e-6, "add[0]");
19414 assert_close_c(out[1], (3.5, -0.5), 1e-6, "add[1]");
19415 }
19416
19417 #[test]
19418 fn c64_binary_sub_matches_complex_arithmetic() {
19419 let a = [(5.0_f32, 1.0_f32)];
19420 let b = [(2.0_f32, 3.0_f32)];
19421 let out = run_c64_binary(BinaryOp::Sub, &a, &b);
19422 assert_close_c(out[0], (3.0, -2.0), 1e-6, "sub");
19423 }
19424
19425 #[test]
19426 fn c64_binary_mul_matches_complex_arithmetic() {
19427 let a = [(1.0_f32, 2.0_f32)];
19429 let b = [(3.0_f32, 4.0_f32)];
19430 let out = run_c64_binary(BinaryOp::Mul, &a, &b);
19431 assert_close_c(out[0], (-5.0, 10.0), 1e-5, "mul");
19432 }
19433
19434 #[test]
19435 fn c64_binary_div_matches_complex_arithmetic() {
19436 let a = [(1.0_f32, 2.0_f32)];
19440 let b = [(3.0_f32, 4.0_f32)];
19441 let out = run_c64_binary(BinaryOp::Div, &a, &b);
19442 assert_close_c(out[0], (0.44, 0.08), 1e-5, "div");
19443 }
19444
19445 #[test]
19446 fn c64_binary_mul_identity_one_is_no_op() {
19447 let a = [(3.5_f32, -1.25_f32), (-2.0_f32, 7.0_f32)];
19449 let b = [(1.0_f32, 0.0_f32), (1.0_f32, 0.0_f32)];
19450 let out = run_c64_binary(BinaryOp::Mul, &a, &b);
19451 assert_close_c(out[0], a[0], 1e-6, "mul·1[0]");
19452 assert_close_c(out[1], a[1], 1e-6, "mul·1[1]");
19453 }
19454
19455 #[test]
19456 fn c64_binary_mul_by_i_rotates_90_degrees() {
19457 let a = [(1.0_f32, 0.0_f32)];
19459 let b = [(0.0_f32, 1.0_f32)];
19460 let out = run_c64_binary(BinaryOp::Mul, &a, &b);
19461 assert_close_c(out[0], (0.0, 1.0), 1e-6, "1·i");
19462 }
19463
19464 #[test]
19465 fn c64_binary_div_by_self_gives_unity() {
19466 let a = [(2.5_f32, -1.5_f32), (-0.7_f32, 4.2_f32)];
19467 let out = run_c64_binary(BinaryOp::Div, &a, &a);
19468 assert_close_c(out[0], (1.0, 0.0), 1e-5, "div_self[0]");
19469 assert_close_c(out[1], (1.0, 0.0), 1e-5, "div_self[1]");
19470 }
19471
19472 #[test]
19473 #[should_panic(expected = "C64: complex max/min/pow")]
19474 fn c64_binary_max_is_rejected_at_lowering() {
19475 run_c64_binary(BinaryOp::Max, &[(1.0_f32, 2.0_f32)], &[(3.0_f32, 4.0_f32)]);
19476 }
19477
19478 fn run_c64_activation(act: Activation, a: &[(f32, f32)]) -> Vec<(f32, f32)> {
19479 let n = a.len();
19480 let s = Shape::new(&[n], DType::C64);
19481 let mut g = Graph::new("c64_act");
19482 let in_a = g.input("a", s.clone());
19483 let out = g.activation(act, in_a, s.clone());
19484 g.set_outputs(vec![out]);
19485 let plan = rlx_opt::memory::plan_memory(&g);
19486 let mut arena = crate::arena::Arena::from_plan(plan);
19487 let sched = compile_thunks(&g, &arena);
19488 let a_off = arena.byte_offset(in_a);
19489 let out_off = arena.byte_offset(out);
19490 let buf = arena.raw_buf_mut();
19491 unsafe {
19492 let pa = buf.as_mut_ptr().add(a_off) as *mut f32;
19493 for (i, &(re, im)) in a.iter().enumerate() {
19494 *pa.add(2 * i) = re;
19495 *pa.add(2 * i + 1) = im;
19496 }
19497 }
19498 execute_thunks(&sched, arena.raw_buf_mut());
19499 let raw: Vec<f32> = unsafe {
19500 let p = arena.raw_buf().as_ptr().add(out_off) as *const f32;
19501 (0..(2 * n)).map(|i| *p.add(i)).collect()
19502 };
19503 (0..n).map(|i| (raw[2 * i], raw[2 * i + 1])).collect()
19504 }
19505
19506 #[test]
19507 fn c64_activation_neg_negates_both_components() {
19508 let inp = [(3.5_f32, -1.25_f32), (-2.0_f32, 0.0_f32)];
19509 let out = run_c64_activation(Activation::Neg, &inp);
19510 assert_close_c(out[0], (-3.5, 1.25), 1e-6, "neg[0]");
19511 assert_close_c(out[1], (2.0, 0.0), 1e-6, "neg[1]");
19512 }
19513
19514 #[test]
19515 fn c64_activation_exp_matches_euler() {
19516 let inp = [(0.0_f32, std::f32::consts::PI), (1.0_f32, 0.0_f32)];
19519 let out = run_c64_activation(Activation::Exp, &inp);
19520 assert_close_c(out[0], (-1.0, 0.0), 1e-5, "exp(iπ)");
19521 assert_close_c(out[1], (std::f32::consts::E, 0.0), 1e-5, "exp(1)");
19522 }
19523
19524 #[test]
19525 fn c64_activation_log_matches_principal_branch() {
19526 let inp = [(1.0_f32, 0.0_f32), (0.0_f32, 1.0_f32), (-1.0_f32, 0.0_f32)];
19530 let out = run_c64_activation(Activation::Log, &inp);
19531 assert_close_c(out[0], (0.0, 0.0), 1e-5, "log(1)");
19532 assert_close_c(out[1], (0.0, std::f32::consts::FRAC_PI_2), 1e-5, "log(i)");
19533 assert_close_c(out[2], (0.0, std::f32::consts::PI), 1e-5, "log(-1)");
19534 }
19535
19536 #[test]
19537 fn c64_activation_sqrt_squared_recovers_input() {
19538 let inp = [(4.0_f32, 0.0_f32), (3.0_f32, 4.0_f32)];
19541 let roots = run_c64_activation(Activation::Sqrt, &inp);
19542 assert_close_c(roots[0], (2.0, 0.0), 1e-5, "sqrt(4)");
19544 assert_close_c(roots[1], (2.0, 1.0), 1e-5, "sqrt(3+4i)");
19545 }
19546
19547 #[test]
19548 #[should_panic(expected = "no natural complex extension")]
19549 fn c64_activation_relu_is_rejected_at_lowering() {
19550 run_c64_activation(Activation::Relu, &[(1.0_f32, 2.0_f32)]);
19551 }
19552
19553 fn run_complex_norm_sq(z: &[(f32, f32)]) -> Vec<f32> {
19557 let n = z.len();
19558 let mut g = Graph::new("cns_fwd");
19559 let in_z = g.input("z", Shape::new(&[n], DType::C64));
19560 let out = g.complex_norm_sq(in_z);
19561 g.set_outputs(vec![out]);
19562 let plan = rlx_opt::memory::plan_memory(&g);
19563 let mut arena = crate::arena::Arena::from_plan(plan);
19564 let sched = compile_thunks(&g, &arena);
19565 let z_off = arena.byte_offset(in_z);
19566 let out_off = arena.byte_offset(out);
19567 let buf = arena.raw_buf_mut();
19568 unsafe {
19569 let pz = buf.as_mut_ptr().add(z_off) as *mut f32;
19570 for (i, &(re, im)) in z.iter().enumerate() {
19571 *pz.add(2 * i) = re;
19572 *pz.add(2 * i + 1) = im;
19573 }
19574 }
19575 execute_thunks(&sched, arena.raw_buf_mut());
19576 unsafe {
19577 let p = arena.raw_buf().as_ptr().add(out_off) as *const f32;
19578 (0..n).map(|i| *p.add(i)).collect()
19579 }
19580 }
19581
19582 fn run_complex_norm_sq_bwd(z: &[(f32, f32)], g: &[f32]) -> Vec<(f32, f32)> {
19584 let n = z.len();
19585 let mut gr = Graph::new("cns_bwd");
19586 let in_z = gr.input("z", Shape::new(&[n], DType::C64));
19587 let in_g = gr.input("g", Shape::new(&[n], DType::F32));
19588 let out = gr.complex_norm_sq_backward(in_z, in_g);
19589 gr.set_outputs(vec![out]);
19590 let plan = rlx_opt::memory::plan_memory(&gr);
19591 let mut arena = crate::arena::Arena::from_plan(plan);
19592 let sched = compile_thunks(&gr, &arena);
19593 let z_off = arena.byte_offset(in_z);
19594 let g_off = arena.byte_offset(in_g);
19595 let out_off = arena.byte_offset(out);
19596 let buf = arena.raw_buf_mut();
19597 unsafe {
19598 let pz = buf.as_mut_ptr().add(z_off) as *mut f32;
19599 let pg = buf.as_mut_ptr().add(g_off) as *mut f32;
19600 for (i, &(re, im)) in z.iter().enumerate() {
19601 *pz.add(2 * i) = re;
19602 *pz.add(2 * i + 1) = im;
19603 }
19604 for (i, &v) in g.iter().enumerate() {
19605 *pg.add(i) = v;
19606 }
19607 }
19608 execute_thunks(&sched, arena.raw_buf_mut());
19609 unsafe {
19610 let p = arena.raw_buf().as_ptr().add(out_off) as *const f32;
19611 (0..n).map(|i| (*p.add(2 * i), *p.add(2 * i + 1))).collect()
19612 }
19613 }
19614
19615 #[test]
19616 fn complex_norm_sq_matches_textbook() {
19617 let z = [(3.0_f32, 4.0_f32), (1.0_f32, 0.0_f32), (0.0_f32, 0.0_f32)];
19621 let out = run_complex_norm_sq(&z);
19622 assert!((out[0] - 25.0).abs() < 1e-5);
19623 assert!((out[1] - 1.0).abs() < 1e-6);
19624 assert!(out[2].abs() < 1e-6);
19625 }
19626
19627 #[test]
19628 fn complex_norm_sq_backward_matches_wirtinger_formula() {
19629 let z = [(3.0_f32, 4.0_f32), (1.5_f32, -2.5_f32)];
19631 let g = [1.0_f32, 1.0_f32];
19632 let dz = run_complex_norm_sq_bwd(&z, &g);
19633 assert_close_c(dz[0], z[0], 1e-6, "dz[0] = g·z[0]");
19634 assert_close_c(dz[1], z[1], 1e-6, "dz[1] = g·z[1]");
19635 }
19636
19637 #[test]
19638 fn complex_norm_sq_backward_scales_with_upstream() {
19639 let z = [(2.0_f32, 1.0_f32), (-1.0_f32, 3.0_f32)];
19641 let g = [0.5_f32, -2.0_f32];
19642 let dz = run_complex_norm_sq_bwd(&z, &g);
19643 assert_close_c(dz[0], (1.0, 0.5), 1e-6, "g=0.5 · (2,1)");
19644 assert_close_c(dz[1], (2.0, -6.0), 1e-6, "g=-2 · (-1,3)");
19645 }
19646
/// End-to-end check of a two-output `custom_fn_multi` node.
///
/// The body graph maps `x` to `(x * x, 2 * x)`. The outer graph feeds one
/// 3-element f32 input through it, pulls both outputs out via the returned
/// handle, runs the compiled thunks and compares each output element-wise
/// against the closed-form value (tolerance `1e-5`).
#[test]
fn custom_fn_multi_extracts_each_subgraph_output() {
    use rlx_ir::ops::special::MultiOutputHandle;

    // Build and immediately discard a handle literal. The value is unused;
    // presumably this only pins the struct's field names at compile time.
    let _ = MultiOutputHandle {
        source: NodeId(0),
        sub_shapes: vec![],
        offsets: vec![],
    };

    // Body graph: x -> (x * x, 2 * x).
    let mut body = Graph::new("multi_body");
    let vec3 = Shape::new(&[3], DType::F32);
    let x = body.input("x", vec3.clone());
    let x_sq = body.binary(BinaryOp::Mul, x, x, vec3.clone());
    let two = body.add_node(
        Op::Constant {
            // Three copies of 2.0f32, serialised little-endian.
            data: [2.0_f32; 3].iter().flat_map(|v| v.to_le_bytes()).collect(),
        },
        vec![],
        vec3.clone(),
    );
    let two_x = body.binary(BinaryOp::Mul, two, x, vec3.clone());
    body.set_outputs(vec![x_sq, two_x]);

    // Outer graph: a single input routed through the multi-output node.
    let mut outer = Graph::new("multi_outer");
    let in_x = outer.input("xin", vec3.clone());
    let handle = outer.custom_fn_multi(vec![in_x], body);
    assert_eq!(handle.n_outputs(), 2);
    let out0 = handle.output(&mut outer, 0);
    let out1 = handle.output(&mut outer, 1);
    outer.set_outputs(vec![out0, out1]);

    // Plan memory, compile, and look up where each tensor lives in the arena.
    let plan = rlx_opt::memory::plan_memory(&outer);
    let mut arena = crate::arena::Arena::from_plan(plan);
    let sched = compile_thunks(&outer, &arena);
    let xin_off = arena.byte_offset(in_x);
    let out0_off = arena.byte_offset(out0);
    let out1_off = arena.byte_offset(out1);

    // Write the input values straight into the arena at the input's byte offset.
    let xs = [1.0_f32, 2.0, 3.0];
    unsafe {
        let dst = arena.raw_buf_mut().as_mut_ptr().add(xin_off) as *mut f32;
        for (i, &v) in xs.iter().enumerate() {
            dst.add(i).write(v);
        }
    }

    execute_thunks(&sched, arena.raw_buf_mut());

    // Reads three f32 values from the arena starting at byte offset `off`.
    let read3 = |off: usize| -> Vec<f32> {
        unsafe {
            let src = arena.raw_buf().as_ptr().add(off) as *const f32;
            (0..3).map(|i| src.add(i).read()).collect()
        }
    };
    let out0_v = read3(out0_off);
    let out1_v = read3(out1_off);

    for (i, &x_in) in xs.iter().enumerate() {
        let want_sq = x_in * x_in;
        assert!(
            (out0_v[i] - want_sq).abs() < 1e-5,
            "out0[{i}] = {} != x² = {}",
            out0_v[i],
            want_sq
        );
        let want_double = 2.0 * x_in;
        assert!(
            (out1_v[i] - want_double).abs() < 1e-5,
            "out1[{i}] = {} != 2x = {}",
            out1_v[i],
            want_double
        );
    }
}
19730
/// Cross-checks the forward and backward helpers at `z = (3, 4)`.
///
/// Forward differences of `run_complex_norm_sq` along the real and imaginary
/// components must match `2·re` and `2·im` (tolerance `1e-2`), and twice the
/// value returned by `run_complex_norm_sq_bwd` for a unit upstream gradient
/// must match the same analytic values (tolerance `1e-5`).
#[test]
fn complex_norm_sq_gradient_matches_finite_difference() {
    let (re0, im0) = (3.0_f32, 4.0_f32);
    let step = 1e-3_f32;

    // Evaluates the forward helper at a single complex point.
    let norm_sq_at = |re: f32, im: f32| run_complex_norm_sq(&[(re, im)])[0];
    let base = norm_sq_at(re0, im0);

    // Perturb the real component only.
    let slope_re = (norm_sq_at(re0 + step, im0) - base) / step;
    let expected_re = 2.0 * re0;
    assert!((slope_re - expected_re).abs() < 1e-2);

    // Perturb the imaginary component only.
    let slope_im = (norm_sq_at(re0, im0 + step) - base) / step;
    let expected_im = 2.0 * im0;
    assert!((slope_im - expected_im).abs() < 1e-2);

    // The backward helper's output is half of each analytic partial derivative.
    let grad = run_complex_norm_sq_bwd(&[(re0, im0)], &[1.0_f32]);
    assert!((2.0 * grad[0].0 - expected_re).abs() < 1e-5);
    assert!((2.0 * grad[0].1 - expected_im).abs() < 1e-5);
}
19759
/// Adds a `[bh, h, w, 1, w]` operand to a `[bh, h, w, h, w]` operand and
/// compares the executed result against a nested-loop reference.
///
/// The size-1 axis sits in the middle of the right-hand shape (axis 3), so the
/// right-hand index simply drops the `hk` coordinate. The comparison fails if
/// any element differs by `1e-6` or more, and also if any difference is NaN:
/// a NaN never compares greater than the running maximum, so it must be
/// detected explicitly or a kernel emitting NaN would pass unnoticed.
#[test]
fn binary_full_5d_mid_singleton_broadcast() {
    let bh = 2usize;
    let h = 3;
    let w = 4;
    let f = DType::F32;
    let total = bh * h * w * h * w;

    let mut g = Graph::new("bcast_5d");
    let lhs = g.input("lhs", Shape::new(&[bh, h, w, h, w], f));
    let rhs = g.input("rhs", Shape::new(&[bh, h, w, 1, w], f));
    let out = g.binary(BinaryOp::Add, lhs, rhs, Shape::new(&[bh, h, w, h, w], f));
    g.set_outputs(vec![out]);

    // Distinct values per element so that a mis-strided read is visible.
    let lhs_data: Vec<f32> = (0..total).map(|i| i as f32 * 0.01).collect();
    let rhs_data: Vec<f32> = (0..bh * h * w * w)
        .map(|i| (i as f32 + 100.0) * 0.01)
        .collect();

    // Reference result: row-major indexing, with `hk` absent from the rhs index.
    let mut expected = vec![0f32; total];
    for b_ in 0..bh {
        for hq in 0..h {
            for wq in 0..w {
                for hk in 0..h {
                    for wk in 0..w {
                        let li = (((b_ * h + hq) * w + wq) * h + hk) * w + wk;
                        let ri = ((b_ * h + hq) * w + wq) * w + wk;
                        expected[li] = lhs_data[li] + rhs_data[ri];
                    }
                }
            }
        }
    }

    let plan = rlx_opt::memory::plan_memory(&g);
    let mut arena = crate::arena::Arena::from_plan(plan);
    let sched = compile_thunks(&g, &arena);
    let lhs_off = arena.byte_offset(lhs);
    let rhs_off = arena.byte_offset(rhs);
    let out_off = arena.byte_offset(out);

    // Write both operands straight into the arena at their byte offsets.
    let buf = arena.raw_buf_mut();
    unsafe {
        let p = buf.as_mut_ptr().add(lhs_off) as *mut f32;
        for (i, &v) in lhs_data.iter().enumerate() {
            *p.add(i) = v;
        }
        let p = buf.as_mut_ptr().add(rhs_off) as *mut f32;
        for (i, &v) in rhs_data.iter().enumerate() {
            *p.add(i) = v;
        }
    }

    execute_thunks(&sched, arena.raw_buf_mut());

    let actual: Vec<f32> = unsafe {
        let p = arena.raw_buf().as_ptr().add(out_off) as *const f32;
        (0..total).map(|i| *p.add(i)).collect()
    };

    // Find the worst element. `max_idx` is the first index reaching the maximum.
    let mut max_diff = 0f32;
    let mut max_idx = 0;
    for (i, (&got, &want)) in actual.iter().zip(expected.iter()).enumerate() {
        let d = (got - want).abs();
        if d.is_nan() {
            // `NaN > max_diff` is always false, so without this branch a NaN
            // output would leave `max_diff` at 0 and the assertion would pass.
            // Record it and stop: `NaN < 1e-6` is false, so the assert fails.
            max_diff = d;
            max_idx = i;
            break;
        }
        if d > max_diff {
            max_diff = d;
            max_idx = i;
        }
    }
    assert!(
        max_diff < 1e-6,
        "5D mid-shape singleton broadcast wrong: max |Δ| = {max_diff} at idx {max_idx} \
         (actual={}, expected={})",
        actual[max_idx],
        expected[max_idx]
    );
}
19842
/// Smoke test that calls `layer_norm2d_nchw` and `conv_transpose2d_nchw`
/// directly on tiny inputs.
///
/// The layer-norm half only checks that the first output element differs from
/// `mean0` by more than 0.1. The transposed-convolution half checks that output
/// elements 0 and 3 both equal 2.0 within `1e-5`.
#[test]
fn layer_norm2d_and_conv_transpose2d_kernels() {
    // --- layer_norm2d_nchw ---
    let ln_input = [1.0_f32, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0];
    let ones = [1.0_f32; 2];
    let zeros = [0.0_f32; 2];
    let mut out = vec![0f32; 8];
    // NOTE(review): the four integers are presumably N, C, H, W and the last
    // argument an epsilon — confirm against the kernel signature.
    crate::kernels::layer_norm2d_nchw(&ln_input, &ones, &zeros, &mut out, 1, 2, 2, 2, 1e-5);
    let mean0 = (1.0_f32 + 3.0) / 2.0;
    assert!((out[0] - mean0).abs() > 0.1);

    // --- conv_transpose2d_nchw ---
    let ct_input = [2.0_f32];
    let ct_weight = [1.0_f32, 0.0, 0.0, 1.0];
    let mut up = vec![0f32; 4];
    // The sixteen integer arguments are passed positionally, exactly as before;
    // their meaning is defined by the kernel signature, which is not visible here.
    crate::kernels::conv_transpose2d_nchw(
        &ct_input,
        &ct_weight,
        &mut up,
        1,
        1,
        1,
        1,
        1,
        2,
        2,
        2,
        2,
        2,
        2,
        0,
        0,
        1,
        1,
        1,
    );
    // Only elements 0 and 3 of the four-element output are checked.
    for idx in [0usize, 3] {
        assert!((up[idx] - 2.0).abs() < 1e-5);
    }
}
19885}