1#![allow(unsafe_op_in_unsafe_fn)]
24use crate::arena::Arena;
27use crate::op_registry::CpuKernel;
28use rlx_ir::op::{Activation, BinaryOp, CmpOp, ReduceOp};
29use rlx_ir::{Graph, NodeId, Op, Shape};
30use std::collections::HashMap;
31use std::sync::Arc;
32
/// One unit of work in a [`ThunkSchedule`], produced by `compile_thunks`.
///
/// Each variant carries everything its kernel needs. `compile_thunks` fills
/// the `usize` fields it populates via `node_offset`, i.e. with arena byte
/// offsets (or `usize::MAX` when the node has no buffer).
///
/// NOTE(review): unless a comment cites `thunk_read_offsets`,
/// `compile_thunks` or `strip_nops`, the per-variant summaries below are
/// inferred from variant and field names only — confirm against the executor
/// before relying on them. `u32` fields look like element counts/dimensions.
#[derive(Clone)]
pub enum Thunk {
    /// Does nothing. `compile_thunks` emits it for pure views and for
    /// `Input`/`Param`/`Constant` nodes; `ThunkSchedule::strip_nops` removes it.
    Nop,
    /// Presumably an f32 matrix multiply `c = a * b` with `a` as `m x k` and
    /// `b` as `k x n` — TODO confirm layout. `a` and `b` are reported as reads
    /// by `thunk_read_offsets`.
    Sgemm {
        a: usize,
        b: usize,
        c: usize,
        m: u32,
        k: u32,
        n: u32,
    },
    /// Presumably solves the dense system `a * x = b` in f64 (`a` being
    /// `n x n` with `nrhs` right-hand sides) — TODO confirm.
    DenseSolveF64 {
        a: usize,
        b: usize,
        x: usize,
        n: u32,
        nrhs: u32,
    },
    /// Presumably the f32 counterpart of `DenseSolveF64` (same fields).
    DenseSolveF32 {
        a: usize,
        b: usize,
        x: usize,
        n: u32,
        nrhs: u32,
    },
    /// Presumably `DenseSolveF64` over `batch` independent systems — TODO confirm.
    BatchedDenseSolveF64 {
        a: usize,
        b: usize,
        x: usize,
        batch: u32,
        n: u32,
        nrhs: u32,
    },
    /// Presumably the f32 counterpart of `BatchedDenseSolveF64` (same fields).
    /// NOTE(review): unlike the F64 variant, `thunk_read_offsets` has no arm
    /// for this one.
    BatchedDenseSolveF32 {
        a: usize,
        b: usize,
        x: usize,
        batch: u32,
        n: u32,
        nrhs: u32,
    },
    /// Presumably `batch` independent f64 matrix multiplies — TODO confirm.
    BatchedDgemmF64 {
        a: usize,
        b: usize,
        c: usize,
        batch: u32,
        m: u32,
        k: u32,
        n: u32,
    },
    /// Presumably `batch` independent f32 matrix multiplies — TODO confirm.
    BatchedSgemm {
        a: usize,
        b: usize,
        c: usize,
        batch: u32,
        m: u32,
        k: u32,
        n: u32,
    },
    /// Presumably the f64 counterpart of `Sgemm` (same fields).
    Dgemm {
        a: usize,
        b: usize,
        c: usize,
        m: u32,
        k: u32,
        n: u32,
    },
    /// Presumably an f64 permutation/transpose described by `out_dims` and
    /// `in_strides` — TODO confirm.
    TransposeF64 {
        src: usize,
        dst: usize,
        in_total: u32,
        out_dims: Vec<u32>,
        in_strides: Vec<u32>,
    },
    /// Presumably applies activation `kind` to `len` f64 elements, `src` to `dst`.
    ActivationF64 {
        src: usize,
        dst: usize,
        len: u32,
        kind: Activation,
    },
    /// Presumably the squared magnitude of complex inputs, written as f32 —
    /// TODO confirm. `src` is reported as a read by `thunk_read_offsets`.
    ComplexNormSqF32 {
        src: usize,
        dst: usize,
        len: u32,
    },
    /// Presumably the backward pass of `ComplexNormSqF32`; `z` and `g` are
    /// reported as reads by `thunk_read_offsets`.
    ComplexNormSqBackwardF32 {
        z: usize,
        g: usize,
        dz: usize,
        len: u32,
    },
    /// Presumably complex conjugation of `len` elements, `src` to `dst`.
    ConjugateC64 { src: usize, dst: usize, len: u32 },
    /// Presumably applies activation `kind` to `len` complex elements.
    ActivationC64 {
        src: usize,
        dst: usize,
        len: u32,
        kind: Activation,
    },
    /// Presumably an f64 sum over the middle axis of an
    /// `outer x reduced x inner` view — TODO confirm.
    ReduceSumF64 {
        src: usize,
        dst: usize,
        outer: u32,
        reduced: u32,
        inner: u32,
    },
    /// Presumably copies `len` f64 elements from `src` to `dst`.
    CopyF64 { src: usize, dst: usize, len: u32 },
    /// Presumably an f64 element-wise binary `op` with broadcasting described
    /// by `out_dims_bcast` and the two stride vectors — TODO confirm.
    /// `lhs` and `rhs` are reported as reads by `thunk_read_offsets`.
    BinaryFullF64 {
        lhs: usize,
        rhs: usize,
        dst: usize,
        len: u32,
        lhs_len: u32,
        rhs_len: u32,
        op: BinaryOp,
        out_dims_bcast: Vec<u32>,
        bcast_lhs_strides: Vec<u32>,
        bcast_rhs_strides: Vec<u32>,
    },
    /// Presumably an f64 concatenation along one axis. The first element of
    /// each `inputs` tuple is reported as a read offset by
    /// `thunk_read_offsets`; the second is presumably that input's extent
    /// along the axis — TODO confirm.
    ConcatF64 {
        dst: usize,
        outer: u32,
        inner: u32,
        total_axis: u32,
        inputs: Vec<(usize, u32)>,
    },
    /// Presumably the complex counterpart of `BinaryFullF64` (same fields).
    BinaryFullC64 {
        lhs: usize,
        rhs: usize,
        dst: usize,
        len: u32,
        lhs_len: u32,
        rhs_len: u32,
        op: BinaryOp,
        out_dims_bcast: Vec<u32>,
        bcast_lhs_strides: Vec<u32>,
        bcast_rhs_strides: Vec<u32>,
    },
    /// Presumably runs the nested `body` schedule `length` times, threading a
    /// carry of `carry_bytes` bytes — TODO confirm.
    Scan {
        /// Nested schedule executed per step.
        body: Arc<ThunkSchedule>,
        /// Presumably the initial byte image of the body's arena.
        body_init: Arc<Vec<u8>>,
        body_input_off: usize,
        body_output_off: usize,
        /// Reported as a read by `thunk_read_offsets`.
        outer_init_off: usize,
        outer_final_off: usize,
        length: u32,
        carry_bytes: u32,
        save_trajectory: bool,
        /// The second tuple element is reported as a read (outer offset) by
        /// `thunk_read_offsets`; the first is presumably the body-side offset
        /// and the third a byte count — TODO confirm.
        xs_inputs: Arc<Vec<(usize, usize, u32)>>,
        /// Presumably same tuple layout as `xs_inputs`, for inputs that are
        /// not sliced per step — TODO confirm.
        bcast_inputs: Arc<Vec<(usize, usize, u32)>>,
        num_checkpoints: u32,
    },

    /// Presumably the backward pass of `Scan` with respect to the initial
    /// carry — TODO confirm.
    ScanBackward {
        /// Nested schedule presumably computing the body's vector-Jacobian product.
        body_vjp: Arc<ThunkSchedule>,
        body_init: Arc<Vec<u8>>,
        body_carry_in_off: usize,
        body_x_offs: Arc<Vec<usize>>,
        body_d_output_off: usize,
        body_dcarry_out_off: usize,
        /// Reported as a read by `thunk_read_offsets`.
        outer_init_off: usize,
        /// Reported as a read by `thunk_read_offsets`.
        outer_traj_off: usize,
        /// Reported as a read by `thunk_read_offsets`.
        outer_upstream_off: usize,
        /// The first tuple element is reported as a read by `thunk_read_offsets`.
        outer_xs_offs: Arc<Vec<(usize, u32)>>,
        outer_dinit_off: usize,
        length: u32,
        carry_bytes: u32,
        carry_elem_size: u32,
        save_trajectory: bool,
        num_checkpoints: u32,
        /// Presumably the forward body used for recomputation when
        /// checkpointing — TODO confirm.
        forward_body: Option<Arc<ThunkSchedule>>,
        forward_body_init: Option<Arc<Vec<u8>>>,
        forward_body_carry_in_off: usize,
        forward_body_output_off: usize,
        forward_body_x_offs: Arc<Vec<usize>>,
    },

    /// Presumably the backward pass of `Scan` with respect to the per-step
    /// `xs` inputs — TODO confirm.
    ScanBackwardXs {
        body_vjp: Arc<ThunkSchedule>,
        body_init: Arc<Vec<u8>>,
        body_carry_in_off: usize,
        body_x_offs: Arc<Vec<usize>>,
        body_d_output_off: usize,
        body_dcarry_out_off: usize,
        body_dxs_out_off: usize,
        /// Reported as a read by `thunk_read_offsets`.
        outer_init_off: usize,
        /// Reported as a read by `thunk_read_offsets`.
        outer_traj_off: usize,
        /// Reported as a read by `thunk_read_offsets`.
        outer_upstream_off: usize,
        /// The first tuple element is reported as a read by `thunk_read_offsets`.
        outer_xs_offs: Arc<Vec<(usize, u32)>>,
        outer_dxs_off: usize,
        length: u32,
        carry_bytes: u32,
        carry_elem_size: u32,
        per_step_bytes: u32,
        save_trajectory: bool,
        num_checkpoints: u32,
        forward_body: Option<Arc<ThunkSchedule>>,
        forward_body_init: Option<Arc<Vec<u8>>>,
        forward_body_carry_in_off: usize,
        forward_body_output_off: usize,
        forward_body_x_offs: Arc<Vec<usize>>,
    },
    /// Presumably runs a nested `body` schedule once and copies `out_bytes`
    /// bytes of its output back — TODO confirm.
    CustomFn {
        body: Arc<ThunkSchedule>,
        body_init: Arc<Vec<u8>>,
        /// The second tuple element is reported as a read (outer offset) by
        /// `thunk_read_offsets`.
        inputs: Arc<Vec<(usize, usize, u32)>>,
        body_output_off: usize,
        outer_output_off: usize,
        out_bytes: u32,
    },
    /// Fused matmul + bias + optional activation. `compile_thunks` sets `a`,
    /// `w`, `bias` from node inputs 0, 1, 2 and `c` from the node itself;
    /// `n` is the last output dimension, `m` the output element count divided
    /// by `n`, and `k` the element count of input 0 divided by `m`.
    FusedMmBiasAct {
        a: usize,
        w: usize,
        bias: usize,
        c: usize,
        m: u32,
        k: u32,
        n: u32,
        act: Option<Activation>,
    },
    /// Fused residual add + layer norm. `compile_thunks` sets `bias` to `0`
    /// when `has_bias` is false; `h` is the last output dimension and `rows`
    /// the output element count divided by `h`.
    FusedResidualLN {
        x: usize,
        res: usize,
        bias: usize,
        g: usize,
        b: usize,
        out: usize,
        rows: u32,
        h: u32,
        eps: f32,
        has_bias: bool,
    },
    /// Presumably the RMS-norm counterpart of `FusedResidualLN` (same fields).
    FusedResidualRmsNorm {
        x: usize,
        res: usize,
        bias: usize,
        g: usize,
        b: usize,
        out: usize,
        rows: u32,
        h: u32,
        eps: f32,
        has_bias: bool,
    },
    /// Presumably adds a length-`n` bias to each of `m` rows — TODO confirm.
    BiasAdd {
        src: usize,
        bias: usize,
        dst: usize,
        m: u32,
        n: u32,
    },
    /// Presumably the f32 element-wise binary `op` with broadcasting (same
    /// fields as `BinaryFullF64`).
    BinaryFull {
        lhs: usize,
        rhs: usize,
        dst: usize,
        len: u32,
        lhs_len: u32,
        rhs_len: u32,
        op: BinaryOp,
        out_dims_bcast: Vec<u32>,
        bcast_lhs_strides: Vec<u32>,
        bcast_rhs_strides: Vec<u32>,
    },
    /// Presumably applies `act` in place to `len` elements at `data`;
    /// `data` is reported as a read by `thunk_read_offsets`.
    ActivationInPlace {
        data: usize,
        len: u32,
        act: Activation,
    },
    /// Presumably a row gather from `table` using indices at `idx`, each row
    /// `trailing` elements long — TODO confirm.
    Gather {
        table: usize,
        table_len: u32,
        idx: usize,
        dst: usize,
        num_idx: u32,
        trailing: u32,
    },
    /// Presumably a strided sub-range copy: `outer` chunks of `inner`
    /// elements of `elem_bytes` bytes each — TODO confirm.
    Narrow {
        src: usize,
        dst: usize,
        outer: u32,
        src_stride: u32,
        dst_stride: u32,
        inner: u32,
        elem_bytes: u8,
    },
    /// Presumably copies `len` elements from `src` to `dst`.
    Copy { src: usize, dst: usize, len: u32 },
    /// Presumably layer normalization over rows of width `h` with scale `g`
    /// and shift `b`; `src`, `g`, `b` are reported as reads by
    /// `thunk_read_offsets`.
    LayerNorm {
        src: usize,
        g: usize,
        b: usize,
        dst: usize,
        rows: u32,
        h: u32,
        eps: f32,
    },
    /// Presumably group normalization on an `n x c x h x w` tensor.
    GroupNorm {
        src: usize,
        g: usize,
        b: usize,
        dst: usize,
        n: u32,
        c: u32,
        h: u32,
        w: u32,
        num_groups: u32,
        eps: f32,
    },
    /// Presumably layer normalization over the channel axis of an
    /// `n x c x h x w` tensor — TODO confirm.
    LayerNorm2d {
        src: usize,
        g: usize,
        b: usize,
        dst: usize,
        n: u32,
        c: u32,
        h: u32,
        w: u32,
        eps: f32,
    },
    /// Presumably a 2-D transposed convolution (kernel `kh x kw`, stride
    /// `sh/sw`, padding `ph/pw`, dilation `dh/dw`) — TODO confirm.
    ConvTranspose2d {
        src: usize,
        weight: usize,
        dst: usize,
        n: u32,
        c_in: u32,
        h: u32,
        w_in: u32,
        c_out: u32,
        h_out: u32,
        w_out: u32,
        kh: u32,
        kw: u32,
        sh: u32,
        sw: u32,
        ph: u32,
        pw: u32,
        dh: u32,
        dw: u32,
        groups: u32,
    },
    /// Presumably nearest-neighbour 2x upsampling of an `n x c x h x w` tensor.
    ResizeNearest2x {
        src: usize,
        dst: usize,
        n: u32,
        c: u32,
        h: u32,
        w: u32,
    },
    /// Presumably a 2-D axial rotary position embedding — TODO confirm.
    AxialRope2d {
        src: usize,
        dst: usize,
        batch: u32,
        seq: u32,
        hidden: u32,
        end_x: u32,
        end_y: u32,
        head_dim: u32,
        num_heads: u32,
        theta: f32,
        repeat_factor: u32,
    },
    /// Presumably RMS normalization over rows of width `h`; `src`, `g`, `b`
    /// are reported as reads by `thunk_read_offsets`.
    RmsNorm {
        src: usize,
        g: usize,
        b: usize,
        dst: usize,
        rows: u32,
        h: u32,
        eps: f32,
    },
    /// Presumably an in-place row-wise softmax over a `rows x cols` buffer;
    /// `data` is reported as a read by `thunk_read_offsets`.
    Softmax { data: usize, rows: u32, cols: u32 },
    /// Presumably a row-wise cumulative sum (`exclusive` selecting the
    /// exclusive variant) — TODO confirm.
    Cumsum {
        src: usize,
        dst: usize,
        rows: u32,
        cols: u32,
        exclusive: bool,
    },
    /// Presumably a selective state-space scan; `x`, `delta`, `a`, `b`, `c`
    /// are reported as reads by `thunk_read_offsets`.
    SelectiveScan {
        x: usize,
        delta: usize,
        a: usize,
        b: usize,
        c: usize,
        dst: usize,
        batch: u32,
        seq: u32,
        hidden: u32,
        state_size: u32,
    },

    /// Presumably a gated delta-rule recurrence over `q`/`k`/`v` — TODO confirm.
    GatedDeltaNet {
        q: usize,
        k: usize,
        v: usize,
        g: usize,
        beta: usize,
        /// `thunk_read_offsets` only reports this as a read when it is non-zero.
        state: usize,
        dst: usize,
        batch: u32,
        seq: u32,
        heads: u32,
        state_size: u32,
    },

    /// Presumably a 1x1 convolution over `hw` spatial positions — TODO confirm.
    Conv2D1x1 {
        src: usize,
        weight: usize,
        dst: usize,
        n: u32,
        c_in: u32,
        c_out: u32,
        hw: u32,
    },

    /// Presumably a matmul against block-quantized weights (per-block
    /// `scale`, optional zero point `zp`) — TODO confirm element type.
    DequantMatMul {
        x: usize,
        w_q: usize,
        scale: usize,
        zp: usize,
        dst: usize,
        m: u32,
        k: u32,
        n: u32,
        block_size: u32,
        is_asymmetric: bool,
    },

    /// Presumably a matmul against GGUF-quantized weights in format `scheme`.
    DequantMatMulGguf {
        x: usize,
        w_q: usize,
        dst: usize,
        m: u32,
        k: u32,
        n: u32,
        scheme: rlx_ir::quant::QuantScheme,
    },

    /// Presumably the 4-bit counterpart of `DequantMatMul` (same fields).
    DequantMatMulInt4 {
        x: usize,
        w_q: usize,
        scale: usize,
        zp: usize,
        dst: usize,
        m: u32,
        k: u32,
        n: u32,
        block_size: u32,
        is_asymmetric: bool,
    },

    /// Presumably a matmul against FP8 weights; `e5m2` presumably selects the
    /// E5M2 encoding over E4M3 — TODO confirm.
    DequantMatMulFp8 {
        x: usize,
        w_q: usize,
        scale: usize,
        dst: usize,
        m: u32,
        k: u32,
        n: u32,
        e5m2: bool,
    },

    /// Presumably a matmul against NVFP4 weights with per-group `scale` and a
    /// single `global_scale` — TODO confirm.
    DequantMatMulNvfp4 {
        x: usize,
        w_q: usize,
        scale: usize,
        global_scale: usize,
        dst: usize,
        m: u32,
        k: u32,
        n: u32,
    },

    /// Presumably a base matmul plus a rank-`r` LoRA update scaled by
    /// `scale`; `x`, `w`, `a`, `b` are reported as reads by `thunk_read_offsets`.
    LoraMatMul {
        x: usize,
        w: usize,
        a: usize,
        b: usize,
        dst: usize,
        m: u32,
        k: u32,
        n: u32,
        r: u32,
        scale: f32,
    },
    /// Presumably samples one token per row of a `batch x vocab` logits
    /// buffer using top-k / top-p / temperature — TODO confirm.
    Sample {
        logits: usize,
        dst: usize,
        batch: u32,
        vocab: u32,
        top_k: u32,
        top_p: f32,
        temperature: f32,
        seed: u64,
    },
    /// Presumably scaled dot-product attention; `q`, `k`, `v`, `mask` are
    /// reported as reads by `thunk_read_offsets`.
    Attention {
        q: usize,
        k: usize,
        v: usize,
        mask: usize,
        out: usize,
        batch: u32,
        seq: u32,
        kv_seq: u32,
        heads: u32,
        head_dim: u32,
        mask_kind: rlx_ir::op::MaskKind,
        q_row_stride: u32,
        k_row_stride: u32,
        v_row_stride: u32,
        /// NOTE(review): presumably selects a batch/head/seq/dim memory
        /// layout — confirm.
        bhsd: bool,
    },
    /// Presumably the backward pass of `Attention` with respect to the
    /// operand selected by `wrt` — TODO confirm.
    AttentionBackward {
        q: usize,
        k: usize,
        v: usize,
        dy: usize,
        /// `thunk_read_offsets` only reports this as a read when it is non-zero.
        mask: usize,
        out: usize,
        batch: u32,
        seq: u32,
        kv_seq: u32,
        heads: u32,
        head_dim: u32,
        mask_kind: rlx_ir::op::MaskKind,
        wrt: rlx_ir::op::AttentionBwdWrt,
        bhsd: bool,
    },
    /// Presumably rotary position embedding using `cos`/`sin` tables;
    /// `src`, `cos`, `sin` are reported as reads by `thunk_read_offsets`.
    Rope {
        src: usize,
        cos: usize,
        sin: usize,
        dst: usize,
        batch: u32,
        seq: u32,
        hidden: u32,
        head_dim: u32,
        n_rot: u32,
        cos_len: u32,
        src_row_stride: u32,
    },
    /// Presumably a fused attention block (QKV projection, optional RoPE,
    /// attention, output projection) — TODO confirm. All eight offset fields
    /// other than `out` are reported as reads by `thunk_read_offsets`.
    FusedAttnBlock {
        hidden: usize,
        qkv_w: usize,
        out_w: usize,
        mask: usize,
        out: usize,
        qkv_b: usize,
        out_b: usize,
        cos: usize,
        sin: usize,
        cos_len: u32,
        batch: u32,
        seq: u32,
        hs: u32,
        nh: u32,
        dh: u32,
        has_bias: bool,
        has_rope: bool,
    },
    /// Presumably one fused BERT encoder layer (attention, LN, MLP, LN) —
    /// TODO confirm.
    FusedBertLayer {
        hidden: usize,
        qkv_w: usize,
        qkv_b: usize,
        out_w: usize,
        out_b: usize,
        mask: usize,
        ln1_g: usize,
        ln1_b: usize,
        eps1: f32,
        fc1_w: usize,
        fc1_b: usize,
        fc2_w: usize,
        fc2_b: usize,
        ln2_g: usize,
        ln2_b: usize,
        eps2: f32,
        out: usize,
        batch: u32,
        seq: u32,
        hs: u32,
        nh: u32,
        dh: u32,
        int_dim: u32,
    },
    /// Presumably one fused Nomic-style encoder layer (RoPE attention and a
    /// gated MLP with `fc11_w`/`fc12_w`) — TODO confirm.
    FusedNomicLayer {
        hidden: usize,
        qkv_w: usize,
        out_w: usize,
        mask: usize,
        cos: usize,
        sin: usize,
        cos_len: u32,
        ln1_g: usize,
        ln1_b: usize,
        eps1: f32,
        fc11_w: usize,
        fc12_w: usize,
        fc2_w: usize,
        ln2_g: usize,
        ln2_b: usize,
        eps2: f32,
        out: usize,
        batch: u32,
        seq: u32,
        hs: u32,
        nh: u32,
        dh: u32,
        int_dim: u32,
    },
    /// Presumably a fused SwiGLU over a buffer holding gate and value halves
    /// of width `n_half` — TODO confirm.
    FusedSwiGLU {
        src: usize,
        dst: usize,
        n_half: u32,
        total: u32,
        gate_first: bool,
    },
    /// Presumably the f32 concatenation along one axis (same fields as
    /// `ConcatF64`). The first element of each `inputs` tuple is reported as
    /// a read offset by `thunk_read_offsets`.
    Concat {
        dst: usize,
        outer: u32,
        inner: u32,
        total_axis: u32,
        inputs: Vec<(usize, u32)>,
    },
    /// Presumably an element-wise comparison `op` over `len` elements.
    Compare {
        lhs: usize,
        rhs: usize,
        dst: usize,
        len: u32,
        op: CmpOp,
    },
    /// Presumably a reduction `op` over the middle axis of an
    /// `outer x reduced x inner` view — TODO confirm.
    Reduce {
        src: usize,
        dst: usize,
        outer: u32,
        reduced: u32,
        inner: u32,
        op: ReduceOp,
    },
    /// Presumably selects the top `k` entries along an axis of size
    /// `axis_dim` for each of `outer` rows — TODO confirm output encoding.
    TopK {
        src: usize,
        dst: usize,
        outer: u32,
        axis_dim: u32,
        k: u32,
    },
    /// Presumably a per-row expert-selected matmul (mixture of experts) —
    /// TODO confirm.
    GroupedMatMul {
        input: usize,
        weight: usize,
        expert_idx: usize,
        dst: usize,
        m: u32,
        k_dim: u32,
        n: u32,
        num_experts: u32,
    },
    /// Presumably `GroupedMatMul` against GGUF-quantized expert weights.
    DequantGroupedMatMulGguf {
        input: usize,
        w_q: usize,
        expert_idx: usize,
        dst: usize,
        m: u32,
        k_dim: u32,
        n: u32,
        num_experts: u32,
        scheme: rlx_ir::quant::QuantScheme,
    },
    /// Presumably dequantizes GGUF expert weights into `dst` — TODO confirm.
    DequantMoEWeightsGguf {
        w_q: usize,
        dst: usize,
        k_dim: u32,
        n: u32,
        num_experts: u32,
        scheme: rlx_ir::quant::QuantScheme,
    },
    /// Presumably accumulates `updates` rows into `dst` at positions given by
    /// `indices` — TODO confirm.
    ScatterAdd {
        updates: usize,
        indices: usize,
        dst: usize,
        num_updates: u32,
        out_dim: u32,
        trailing: u32,
    },
    /// Presumably an element-wise select between `on_true` and `on_false`
    /// driven by `cond`.
    Where {
        cond: usize,
        on_true: usize,
        on_false: usize,
        dst: usize,
        len: u32,
    },
    /// Presumably the f32 counterpart of `TransposeF64` (same fields).
    Transpose {
        src: usize,
        dst: usize,
        in_total: u32,
        out_dims: Vec<u32>,
        in_strides: Vec<u32>,
    },
    /// Presumably a gather along an interior axis of an
    /// `outer x axis_dim x trailing` view — TODO confirm.
    GatherAxis {
        table: usize,
        idx: usize,
        dst: usize,
        outer: u32,
        axis_dim: u32,
        num_idx: u32,
        trailing: u32,
    },
    /// Presumably 2-D pooling with reduction `kind` (kernel `kh x kw`,
    /// stride `sh/sw`, padding `ph/pw`) — TODO confirm.
    Pool2D {
        src: usize,
        dst: usize,
        n: u32,
        c: u32,
        h: u32,
        w: u32,
        h_out: u32,
        w_out: u32,
        kh: u32,
        kw: u32,
        sh: u32,
        sw: u32,
        ph: u32,
        pw: u32,
        kind: ReduceOp,
    },
    /// Presumably a general 2-D convolution (kernel `kh x kw`, stride
    /// `sh/sw`, padding `ph/pw`, dilation `dh/dw`) — TODO confirm.
    Conv2D {
        src: usize,
        weight: usize,
        dst: usize,
        n: u32,
        c_in: u32,
        h: u32,
        w: u32,
        c_out: u32,
        h_out: u32,
        w_out: u32,
        kh: u32,
        kw: u32,
        sh: u32,
        sw: u32,
        ph: u32,
        pw: u32,
        dh: u32,
        dw: u32,
        groups: u32,
    },

    /// Presumably an integer-quantized matmul with zero points
    /// (`x_zp`, `w_zp`, `out_zp`) and requantization multiplier `mult` —
    /// TODO confirm.
    QMatMul {
        x: usize,
        w: usize,
        bias: usize,
        out: usize,
        m: u32,
        k: u32,
        n: u32,
        x_zp: i32,
        w_zp: i32,
        out_zp: i32,
        mult: f32,
    },

    /// Presumably the integer-quantized counterpart of `Conv2D` — TODO confirm.
    QConv2d {
        x: usize,
        w: usize,
        bias: usize,
        out: usize,
        n: u32,
        c_in: u32,
        h: u32,
        w_in: u32,
        c_out: u32,
        h_out: u32,
        w_out: u32,
        kh: u32,
        kw: u32,
        sh: u32,
        sw: u32,
        ph: u32,
        pw: u32,
        dh: u32,
        dw: u32,
        groups: u32,
        x_zp: i32,
        w_zp: i32,
        out_zp: i32,
        mult: f32,
    },

    /// Presumably quantizes `x` into `q` using per-channel `scales` and
    /// `zero_points` along `chan_axis` — TODO confirm.
    Quantize {
        x: usize,
        q: usize,
        len: u32,
        chan_axis: u32,
        chan_dim: u32,
        inner: u32,
        scales: Vec<f32>,
        zero_points: Vec<i32>,
    },

    /// Presumably the inverse of `Quantize`: `q` back to `x` — TODO confirm.
    Dequantize {
        q: usize,
        x: usize,
        len: u32,
        chan_axis: u32,
        chan_dim: u32,
        inner: u32,
        scales: Vec<f32>,
        zero_points: Vec<i32>,
    },

    /// Presumably a quantize-dequantize round trip at `bits` bits for
    /// quantization-aware training — TODO confirm.
    FakeQuantize {
        x: usize,
        out: usize,
        len: u32,
        chan_axis: u32,
        chan_dim: u32,
        inner: u32,
        bits: u8,
        ste: rlx_ir::op::SteKind,
        scale_mode: rlx_ir::op::ScaleMode,
        /// NOTE(review): presumably the offset of persistent scale state,
        /// when the scale mode needs one — confirm.
        state_off: Option<usize>,
    },

    /// Presumably the backward pass of `FakeQuantize` using estimator `ste`.
    FakeQuantizeBackward {
        x: usize,
        dy: usize,
        dx: usize,
        len: u32,
        chan_axis: u32,
        chan_dim: u32,
        inner: u32,
        bits: u8,
        ste: rlx_ir::op::SteKind,
    },

    /// Presumably fake quantization with a learned scale stored at
    /// `scale_off` (LSQ) — TODO confirm.
    FakeQuantizeLSQ {
        x: usize,
        scale_off: usize,
        out: usize,
        len: u32,
        chan_axis: u32,
        chan_dim: u32,
        inner: u32,
        bits: u8,
    },

    /// Presumably the gradient of `FakeQuantizeLSQ` with respect to `x`.
    FakeQuantizeLSQBackwardX {
        x: usize,
        scale_off: usize,
        dy: usize,
        dx: usize,
        len: u32,
        chan_axis: u32,
        chan_dim: u32,
        inner: u32,
        bits: u8,
    },

    /// Presumably the gradient of `FakeQuantizeLSQ` with respect to the scale.
    FakeQuantizeLSQBackwardScale {
        x: usize,
        scale_off: usize,
        dy: usize,
        dscale: usize,
        len: u32,
        chan_axis: u32,
        chan_dim: u32,
        inner: u32,
        bits: u8,
    },

    /// Presumably the f32 ReLU gradient: `dx` from input `x` and upstream `dy`.
    ReluBackward {
        x: usize,
        dy: usize,
        dx: usize,
        len: u32,
    },
    /// Presumably the f64 counterpart of `ReluBackward` (same fields).
    ReluBackwardF64 {
        x: usize,
        dy: usize,
        dx: usize,
        len: u32,
    },

    /// Presumably the f32 gradient of activation `kind`.
    ActivationBackward {
        x: usize,
        dy: usize,
        dx: usize,
        len: u32,
        kind: Activation,
    },
    /// Presumably the f64 counterpart of `ActivationBackward` (same fields).
    ActivationBackwardF64 {
        x: usize,
        dy: usize,
        dx: usize,
        len: u32,
        kind: Activation,
    },

    /// Presumably the layer-norm gradient with respect to the input.
    LayerNormBackwardInput {
        x: usize,
        gamma: usize,
        dy: usize,
        dx: usize,
        rows: u32,
        h: u32,
        eps: f32,
    },

    /// Presumably the layer-norm gradient with respect to gamma.
    LayerNormBackwardGamma {
        x: usize,
        dy: usize,
        dgamma: usize,
        rows: u32,
        h: u32,
        eps: f32,
    },

    /// Presumably the RMS-norm gradient with respect to the input.
    RmsNormBackwardInput {
        x: usize,
        gamma: usize,
        beta: usize,
        dy: usize,
        dx: usize,
        rows: u32,
        h: u32,
        eps: f32,
    },
    /// Presumably the RMS-norm gradient with respect to gamma.
    RmsNormBackwardGamma {
        x: usize,
        gamma: usize,
        beta: usize,
        dy: usize,
        dgamma: usize,
        rows: u32,
        h: u32,
        eps: f32,
    },
    /// Presumably the RMS-norm gradient with respect to beta.
    RmsNormBackwardBeta {
        x: usize,
        gamma: usize,
        beta: usize,
        dy: usize,
        dbeta: usize,
        rows: u32,
        h: u32,
        eps: f32,
    },
    /// Presumably the backward pass of `Rope`.
    RopeBackward {
        dy: usize,
        cos: usize,
        sin: usize,
        dx: usize,
        batch: u32,
        seq: u32,
        hidden: u32,
        head_dim: u32,
        n_rot: u32,
        cos_len: u32,
    },
    /// Presumably the backward pass of `Cumsum`.
    CumsumBackward {
        dy: usize,
        dx: usize,
        rows: u32,
        cols: u32,
        exclusive: bool,
    },
    /// Presumably the backward pass of `GatherAxis` (scatter of `dy` into
    /// `dst` by `indices`) — TODO confirm.
    GatherBackward {
        dy: usize,
        indices: usize,
        dst: usize,
        outer: u32,
        axis_dim: u32,
        num_idx: u32,
        trailing: u32,
    },

    /// Presumably the group-norm gradient with respect to the input.
    GroupNormBackwardInput {
        x: usize,
        gamma: usize,
        beta: usize,
        dy: usize,
        dx: usize,
        n: u32,
        c: u32,
        h: u32,
        w: u32,
        num_groups: u32,
        eps: f32,
    },
    /// Presumably the group-norm gradient with respect to gamma.
    GroupNormBackwardGamma {
        x: usize,
        dy: usize,
        dgamma: usize,
        n: u32,
        c: u32,
        h: u32,
        w: u32,
        num_groups: u32,
        eps: f32,
    },
    /// Presumably the group-norm gradient with respect to beta.
    GroupNormBackwardBeta {
        dy: usize,
        dbeta: usize,
        n: u32,
        c: u32,
        h: u32,
        w: u32,
    },

    /// Presumably the max-pool gradient, routing `dy` back to the positions
    /// of `x` that were selected — TODO confirm.
    MaxPool2dBackward {
        x: usize,
        dy: usize,
        dx: usize,
        n: u32,
        c: u32,
        h: u32,
        w: u32,
        h_out: u32,
        w_out: u32,
        kh: u32,
        kw: u32,
        sh: u32,
        sw: u32,
        ph: u32,
        pw: u32,
    },

    /// Presumably the convolution gradient with respect to the input.
    Conv2dBackwardInput {
        dy: usize,
        w: usize,
        dx: usize,
        n: u32,
        c_in: u32,
        h: u32,
        w_in: u32,
        c_out: u32,
        h_out: u32,
        w_out: u32,
        kh: u32,
        kw: u32,
        sh: u32,
        sw: u32,
        ph: u32,
        pw: u32,
        dh: u32,
        dw: u32,
        groups: u32,
    },

    /// Presumably the convolution gradient with respect to the weights.
    Conv2dBackwardWeight {
        x: usize,
        dy: usize,
        /// Offset of the weight gradient (a `usize`, unlike the `dw: u32`
        /// field of the other convolution variants).
        dw: usize,
        n: u32,
        c_in: u32,
        h: u32,
        w: u32,
        c_out: u32,
        h_out: u32,
        w_out: u32,
        kh: u32,
        kw: u32,
        sh: u32,
        sw: u32,
        ph: u32,
        pw: u32,
        dh: u32,
        /// NOTE(review): presumably the width dilation, renamed because `dw`
        /// is taken by the gradient offset — confirm.
        dw_dil: u32,
        groups: u32,
    },

    /// Presumably softmax cross-entropy loss over `n` rows of `c` classes.
    SoftmaxCrossEntropy {
        logits: usize,
        labels: usize,
        dst: usize,
        n: u32,
        c: u32,
    },

    /// Presumably the gradient of `SoftmaxCrossEntropy` with respect to the logits.
    SoftmaxCrossEntropyBackward {
        logits: usize,
        labels: usize,
        d_loss: usize,
        dlogits: usize,
        n: u32,
        c: u32,
    },

    /// Dispatches to a registered [`CpuKernel`] implementation.
    CustomOp {
        kernel: Arc<dyn CpuKernel>,
        /// NOTE(review): presumably `(offset, length, shape)` per input —
        /// confirm what the `u32` counts.
        inputs: Vec<(usize, u32, Shape)>,
        /// Same tuple layout as `inputs`, for the single output.
        output: (usize, u32, Shape),
        /// Opaque attribute bytes, presumably interpreted by `kernel`.
        attrs: Vec<u8>,
    },

    /// Presumably a single-pass Gaussian-splat render; each `*_off`/`*_len`
    /// pair presumably locates one input buffer — TODO confirm units.
    GaussianSplatRender {
        positions_off: usize,
        positions_len: usize,
        scales_off: usize,
        scales_len: usize,
        rotations_off: usize,
        rotations_len: usize,
        opacities_off: usize,
        opacities_len: usize,
        colors_off: usize,
        colors_len: usize,
        sh_coeffs_off: usize,
        sh_coeffs_len: usize,
        meta_off: usize,
        dst_off: usize,
        dst_len: usize,
        width: u32,
        height: u32,
        tile_size: u32,
        radius_scale: f32,
        alpha_cutoff: f32,
        max_splat_steps: u32,
        transmittance_threshold: f32,
        max_list_entries: u32,
    },
    /// Presumably the backward pass of `GaussianSplatRender`, writing packed
    /// gradients at `packed_off` — TODO confirm.
    GaussianSplatRenderBackward {
        positions_off: usize,
        positions_len: usize,
        scales_off: usize,
        scales_len: usize,
        rotations_off: usize,
        rotations_len: usize,
        opacities_off: usize,
        opacities_len: usize,
        colors_off: usize,
        colors_len: usize,
        sh_coeffs_off: usize,
        sh_coeffs_len: usize,
        meta_off: usize,
        d_loss_off: usize,
        d_loss_len: usize,
        packed_off: usize,
        packed_len: usize,
        width: u32,
        height: u32,
        tile_size: u32,
        radius_scale: f32,
        alpha_cutoff: f32,
        max_splat_steps: u32,
        transmittance_threshold: f32,
        max_list_entries: u32,
        loss_grad_clip: f32,
        sh_band: u32,
        max_anisotropy: f32,
    },
    /// Presumably the preparation stage of a two-stage splat render, writing
    /// an intermediate buffer at `prep_off` — TODO confirm.
    GaussianSplatPrepare {
        positions_off: usize,
        positions_len: usize,
        scales_off: usize,
        scales_len: usize,
        rotations_off: usize,
        rotations_len: usize,
        opacities_off: usize,
        opacities_len: usize,
        colors_off: usize,
        colors_len: usize,
        sh_coeffs_off: usize,
        sh_coeffs_len: usize,
        meta_off: usize,
        meta_len: usize,
        prep_off: usize,
        prep_len: usize,
        width: u32,
        height: u32,
        tile_size: u32,
        radius_scale: f32,
        alpha_cutoff: f32,
        max_splat_steps: u32,
        transmittance_threshold: f32,
        max_list_entries: u32,
    },
    /// Presumably the rasterization stage consuming the buffer written by
    /// `GaussianSplatPrepare` — TODO confirm.
    GaussianSplatRasterize {
        prep_off: usize,
        prep_len: usize,
        meta_off: usize,
        meta_len: usize,
        dst_off: usize,
        dst_len: usize,
        count: usize,
        width: u32,
        height: u32,
        tile_size: u32,
        alpha_cutoff: f32,
        max_splat_steps: u32,
        transmittance_threshold: f32,
        max_list_entries: u32,
    },
    /// Presumably `outer` independent 1-D FFTs of `n_complex` points each,
    /// forward or `inverse`, with normalization selected by `norm_tag` —
    /// TODO confirm.
    Fft1d {
        src: usize,
        dst: usize,
        outer: u32,
        n_complex: u32,
        inverse: bool,
        norm_tag: u32,
        dtype: rlx_ir::DType,
    },
}
1624
/// A list of [`Thunk`]s together with the schedule-wide settings that travel
/// with it. This is the value returned by `compile_thunks`, and nested
/// schedules are embedded in `Thunk::Scan`, `Thunk::ScanBackward`,
/// `Thunk::ScanBackwardXs` and `Thunk::CustomFn`.
#[derive(Clone)]
pub struct ThunkSchedule {
    /// The thunks themselves, in the order they were emitted.
    pub thunks: Vec<Thunk>,
    /// NOTE(review): presumably one residency flag per MoE expert — confirm.
    pub moe_resident: Option<std::sync::Arc<[bool]>>,
    /// NOTE(review): presumably the per-layer version of `moe_resident` — confirm.
    pub moe_resident_layers: Option<std::sync::Arc<Vec<std::sync::Arc<[bool]>>>>,
    /// NOTE(review): presumably a sink recording MoE top-k routing decisions — confirm.
    pub moe_topk_capture: Option<std::sync::Arc<crate::moe_topk_capture::MoeTopkCapture>>,
    /// NOTE(review): presumably the cut-off deciding whether a mask entry
    /// counts as masked — confirm.
    pub mask_threshold: f32,
    /// NOTE(review): presumably the large negative score written for masked
    /// positions — confirm.
    pub mask_neg_inf: f32,
    /// NOTE(review): presumably a score below which attention work is
    /// skipped — confirm.
    pub score_skip: f32,
    /// Thread-safe callables taking a raw `*mut u8` (presumably the arena
    /// base pointer — confirm). Cleared by `strip_nops`.
    pub compiled_fns: Vec<Arc<dyn Fn(*mut u8) + Send + Sync>>,
}
1647
1648impl ThunkSchedule {
1649 pub fn strip_nops(&mut self) {
1650 self.thunks.retain(|t| !matches!(t, Thunk::Nop));
1651 self.compiled_fns.clear();
1654 }
1655}
1656
1657fn node_offset(arena: &Arena, id: NodeId) -> usize {
1659 if arena.has_buffer(id) {
1660 arena.byte_offset(id)
1661 } else {
1662 usize::MAX
1663 }
1664}
1665
1666fn thunk_read_offsets(t: &Thunk) -> Vec<usize> {
1672 match t {
1673 Thunk::Sgemm { a, b, .. } => vec![*a, *b],
1674 Thunk::DenseSolveF64 { a, b, .. } => vec![*a, *b],
1675 Thunk::DenseSolveF32 { a, b, .. } => vec![*a, *b],
1676 Thunk::BatchedDenseSolveF64 { a, b, .. } => vec![*a, *b],
1677 Thunk::BatchedDgemmF64 { a, b, .. } => vec![*a, *b],
1678 Thunk::BatchedSgemm { a, b, .. } => vec![*a, *b],
1679 Thunk::FusedMmBiasAct { a, w, bias, .. } => vec![*a, *w, *bias],
1680 Thunk::BiasAdd { src, bias, .. } => vec![*src, *bias],
1681 Thunk::BinaryFull { lhs, rhs, .. } => vec![*lhs, *rhs],
1682 Thunk::BinaryFullF64 { lhs, rhs, .. } => vec![*lhs, *rhs],
1683 Thunk::BinaryFullC64 { lhs, rhs, .. } => vec![*lhs, *rhs],
1684 Thunk::ComplexNormSqF32 { src, .. } => vec![*src],
1685 Thunk::ComplexNormSqBackwardF32 { z, g, .. } => vec![*z, *g],
1686 Thunk::ConjugateC64 { src, .. } => vec![*src],
1687 Thunk::Scan {
1688 outer_init_off,
1689 xs_inputs,
1690 ..
1691 } => {
1692 let mut v = vec![*outer_init_off];
1693 for (_, outer_xs_off, _) in xs_inputs.iter() {
1694 v.push(*outer_xs_off);
1695 }
1696 v
1697 }
1698 Thunk::ScanBackward {
1699 outer_init_off,
1700 outer_traj_off,
1701 outer_upstream_off,
1702 outer_xs_offs,
1703 ..
1704 } => {
1705 let mut v = vec![*outer_init_off, *outer_traj_off, *outer_upstream_off];
1706 for (off, _) in outer_xs_offs.iter() {
1707 v.push(*off);
1708 }
1709 v
1710 }
1711 Thunk::ScanBackwardXs {
1712 outer_init_off,
1713 outer_traj_off,
1714 outer_upstream_off,
1715 outer_xs_offs,
1716 ..
1717 } => {
1718 let mut v = vec![*outer_init_off, *outer_traj_off, *outer_upstream_off];
1719 for (off, _) in outer_xs_offs.iter() {
1720 v.push(*off);
1721 }
1722 v
1723 }
1724 Thunk::CustomFn { inputs, .. } => {
1725 inputs.iter().map(|(_, outer_off, _)| *outer_off).collect()
1726 }
1727 Thunk::ActivationInPlace { data, .. } => vec![*data],
1728 Thunk::LayerNorm { src, g, b, .. } | Thunk::GroupNorm { src, g, b, .. } => {
1729 vec![*src, *g, *b]
1730 }
1731 Thunk::ResizeNearest2x { src, .. } => vec![*src],
1732 Thunk::AxialRope2d { src, .. } => vec![*src],
1733 Thunk::FusedResidualLN {
1734 x, res, bias, g, b, ..
1735 } => vec![*x, *res, *bias, *g, *b],
1736 Thunk::FusedResidualRmsNorm {
1737 x, res, bias, g, b, ..
1738 } => vec![*x, *res, *bias, *g, *b],
1739 Thunk::RmsNorm { src, g, b, .. } => vec![*src, *g, *b],
1740 Thunk::Softmax { data, .. } => vec![*data],
1741 Thunk::Cumsum { src, .. } => vec![*src],
1742 Thunk::Sample { logits, .. } => vec![*logits],
1743 Thunk::LoraMatMul { x, w, a, b, .. } => vec![*x, *w, *a, *b],
1744 Thunk::DequantMatMul {
1745 x, w_q, scale, zp, ..
1746 } => vec![*x, *w_q, *scale, *zp],
1747 Thunk::DequantMatMulGguf { x, w_q, .. } => vec![*x, *w_q],
1748 Thunk::DequantMatMulInt4 {
1749 x, w_q, scale, zp, ..
1750 } => vec![*x, *w_q, *scale, *zp],
1751 Thunk::DequantMatMulFp8 { x, w_q, scale, .. } => vec![*x, *w_q, *scale],
1752 Thunk::DequantMatMulNvfp4 {
1753 x,
1754 w_q,
1755 scale,
1756 global_scale,
1757 ..
1758 } => vec![*x, *w_q, *scale, *global_scale],
1759 Thunk::Conv2D1x1 { src, weight, .. } => vec![*src, *weight],
1760 Thunk::SelectiveScan {
1761 x, delta, a, b, c, ..
1762 } => vec![*x, *delta, *a, *b, *c],
1763 Thunk::GatedDeltaNet {
1764 q,
1765 k,
1766 v,
1767 g,
1768 beta,
1769 state,
1770 ..
1771 } => {
1772 let mut v = vec![*q, *k, *v, *g, *beta];
1773 if *state != 0 {
1774 v.push(*state);
1775 }
1776 v
1777 }
1778 Thunk::Attention { q, k, v, mask, .. } => vec![*q, *k, *v, *mask],
1779 Thunk::AttentionBackward {
1780 q, k, v, dy, mask, ..
1781 } => {
1782 let mut v = vec![*q, *k, *v, *dy];
1783 if *mask != 0 {
1784 v.push(*mask);
1785 }
1786 v
1787 }
1788 Thunk::Rope { src, cos, sin, .. } => vec![*src, *cos, *sin],
1789 Thunk::FusedAttnBlock {
1790 hidden,
1791 qkv_w,
1792 out_w,
1793 mask,
1794 qkv_b,
1795 out_b,
1796 cos,
1797 sin,
1798 ..
1799 } => vec![*hidden, *qkv_w, *out_w, *mask, *qkv_b, *out_b, *cos, *sin],
1800 Thunk::FusedSwiGLU { src, .. } => vec![*src],
1801 Thunk::Concat { inputs, .. } => inputs.iter().map(|(off, _)| *off).collect(),
1802 Thunk::ConcatF64 { inputs, .. } => inputs.iter().map(|(off, _)| *off).collect(),
1803 Thunk::Narrow { src, .. } => vec![*src],
1804 Thunk::Copy { src, .. } => vec![*src],
1805 Thunk::Gather { table, idx, .. } => vec![*table, *idx],
1806 _ => vec![],
1810 }
1811}
1812
/// Reference matmul `out = x * dequant(w)` for block-quantized int8 weights.
///
/// All buffers are row-major:
/// * `x` is `m x k`, `out` is `m x n`;
/// * `w_bytes` is `k x n` (element `(p, j)` at `p * n + j`);
/// * `scales` and `zps` hold one entry per `(p / block_size, j)`, stored at
///   `block * n + j`.
///
/// Each weight is dequantized as `(q - z) * s`, where `z` is read from `zps`
/// only when `asym` is true (otherwise `zps` is never touched and may be
/// empty).
///
/// # Panics
/// Panics if any buffer is too short for the given dimensions, or if
/// `block_size == 0` while `k > 0`.
#[allow(clippy::too_many_arguments)]
fn dequant_matmul_int8(
    x: &[f32],
    w_bytes: &[i8],
    scales: &[f32],
    zps: &[f32],
    out: &mut [f32],
    m: usize,
    k: usize,
    n: usize,
    block_size: usize,
    asym: bool,
) {
    for i in 0..m {
        for j in 0..n {
            let mut acc = 0f32;
            for p in 0..k {
                let block = p / block_size;
                let s = scales[block * n + j];
                let z = if asym { zps[block * n + j] } else { 0.0 };
                let q = f32::from(w_bytes[p * n + j]);
                let dequantized = (q - z) * s;
                acc += x[i * k + p] * dequantized;
            }
            out[i * n + j] = acc;
        }
    }
}
1856
/// Reference matmul `out = x * dequant(w)` for block-quantized 4-bit weights.
///
/// `x` is `m x k` and `out` is `m x n`, both row-major. Weights are packed
/// two per byte in row-major `k x n` order: flat element `p * n + j` lives in
/// byte `(p * n + j) / 2`, even elements in the low nibble and odd elements
/// in the high nibble. Nibbles are treated as unsigned and dequantized as
/// `(nibble - z) * s`, with `s`/`z` read from `scales`/`zps` at
/// `(p / block_size) * n + j`. `zps` is only read when `asym` is true.
///
/// # Panics
/// Panics if a buffer is too short, or if `block_size == 0` while the loops
/// execute.
#[allow(clippy::too_many_arguments)]
fn dequant_matmul_int4(
    x: &[f32],
    w_bytes: &[u8],
    scales: &[f32],
    zps: &[f32],
    out: &mut [f32],
    m: usize,
    k: usize,
    n: usize,
    block_size: usize,
    asym: bool,
) {
    for row in 0..m {
        for col in 0..n {
            let dot = (0..k).fold(0f32, |sum, kk| {
                let group = kk / block_size;
                let scale = scales[group * n + col];
                let zero = if asym { zps[group * n + col] } else { 0.0 };
                let flat = kk * n + col;
                let packed = w_bytes[flat / 2];
                let nibble = match flat % 2 {
                    0 => packed & 0x0F,
                    _ => packed >> 4,
                };
                let weight = (nibble as f32 - zero) * scale;
                sum + x[row * k + kk] * weight
            });
            out[row * n + col] = dot;
        }
    }
}
1890
/// Decodes one FP8 E4M3 byte (1 sign, 4 exponent, 3 mantissa bits, bias 7)
/// to `f32`, following the OCP 8-bit floating point (OFP8) encoding.
///
/// * exponent 0, mantissa 0: zero (returned as `+0.0` for either sign);
/// * exponent 0, mantissa != 0: subnormal, `mant * 2^-9`;
/// * exponent 15, mantissa 7: NaN;
/// * everything else — including exponent 15 with mantissa 0..=6 — is a
///   normal number `(1 + mant / 8) * 2^(exp - 7)`.
///
/// E4M3 has no infinity: the top exponent is reused for the finite values
/// 256..=448, so decoding it as Inf/NaN would corrupt the largest weights.
fn fp8_e4m3_to_f32(b: u8) -> f32 {
    let sign: f32 = if b & 0x80 != 0 { -1.0 } else { 1.0 };
    let exp = (b >> 3) & 0x0F;
    let mant = b & 0x07;
    match (exp, mant) {
        (0, 0) => 0.0,
        (0, _) => sign * f32::from(mant) * 2f32.powi(-9),
        // The single NaN pattern per sign: S.1111.111.
        (0x0F, 0x07) => f32::NAN,
        _ => sign * (1.0 + f32::from(mant) / 8.0) * 2f32.powi(i32::from(exp) - 7),
    }
}
1910
/// Decodes one FP8 E5M2 byte (1 sign, 5 exponent, 2 mantissa bits, bias 15)
/// to `f32`.
///
/// * exponent 0, mantissa 0: `+0.0` regardless of the sign bit;
/// * exponent 0, mantissa != 0: subnormal, `mant * 2^-16`;
/// * exponent 31: signed infinity when the mantissa is 0, NaN otherwise;
/// * otherwise `(1 + mant / 4) * 2^(exp - 15)`.
fn fp8_e5m2_to_f32(b: u8) -> f32 {
    let exp = (b >> 2) & 0x1F;
    let mant = b & 0x03;
    let magnitude = match (exp, mant) {
        (0, 0) => return 0.0,
        (0, frac) => frac as f32 * 2f32.powi(-16),
        (0x1F, 0) => f32::INFINITY,
        (0x1F, _) => return f32::NAN,
        (e, frac) => (1.0 + frac as f32 / 4.0) * 2f32.powi(e as i32 - 15),
    };
    if b & 0x80 != 0 {
        -magnitude
    } else {
        magnitude
    }
}
1930
1931#[allow(clippy::too_many_arguments)]
1932fn dequant_matmul_fp8(
1933 x: &[f32],
1934 w_bytes: &[u8],
1935 scales: &[f32],
1936 out: &mut [f32],
1937 m: usize,
1938 k: usize,
1939 n: usize,
1940 e5m2: bool,
1941) {
1942 let dequant = if e5m2 {
1943 fp8_e5m2_to_f32
1944 } else {
1945 fp8_e4m3_to_f32
1946 };
1947 for i in 0..m {
1948 for j in 0..n {
1949 let mut acc = 0f32;
1950 for p in 0..k {
1951 let w = dequant(w_bytes[p * n + j]);
1952 let s = scales.get(j).copied().unwrap_or(1.0);
1953 acc += x[i * k + p] * w * s;
1954 }
1955 out[i * n + j] = acc;
1956 }
1957 }
1958}
1959
1960#[allow(clippy::too_many_arguments)]
1961pub fn dequant_matmul_nvfp4(
1962 x: &[f32],
1963 w_bytes: &[u8],
1964 scale_bytes: &[u8],
1965 global_scale: f32,
1966 out: &mut [f32],
1967 m: usize,
1968 k: usize,
1969 n: usize,
1970) {
1971 use rlx_ir::{NVFP4_GROUP_SIZE, fp4_e2m1_to_f32, fp8_e4m3_scale_to_f32};
1972 let gs = NVFP4_GROUP_SIZE;
1973 for i in 0..m {
1974 for j in 0..n {
1975 let mut acc = 0f32;
1976 for p in 0..k {
1977 let byte_idx = (p * n + j) / 2;
1978 let nibble = if (p * n + j) & 1 == 0 {
1979 w_bytes[byte_idx] & 0x0F
1980 } else {
1981 w_bytes[byte_idx] >> 4
1982 };
1983 let block = p / gs;
1984 let scale = fp8_e4m3_scale_to_f32(scale_bytes[block * n + j]);
1985 let w = fp4_e2m1_to_f32(nibble) * scale * global_scale;
1986 acc += x[i * k + p] * w;
1987 }
1988 out[i * n + j] = acc;
1989 }
1990 }
1991}
1992
1993fn sample_row(
2002 logits: &[f32],
2003 top_k: usize,
2004 top_p: f32,
2005 temperature: f32,
2006 rng: &mut rlx_ir::Philox4x32,
2007) -> usize {
2008 let v = logits.len();
2009 if v == 0 {
2010 return 0;
2011 }
2012 let temp = temperature.max(1e-6);
2013 let mut scaled: Vec<f32> = logits.iter().map(|&x| x / temp).collect();
2015
2016 if top_k > 0 && top_k < v {
2018 let mut indexed: Vec<(usize, f32)> = scaled.iter().copied().enumerate().collect();
2020 indexed.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap());
2023 let cutoff = indexed[top_k - 1].1;
2024 for x in scaled.iter_mut() {
2025 if *x < cutoff {
2026 *x = f32::NEG_INFINITY;
2027 }
2028 }
2029 }
2030
2031 let mut max_l = f32::NEG_INFINITY;
2033 for &x in &scaled {
2034 if x > max_l {
2035 max_l = x;
2036 }
2037 }
2038 let mut sum = 0.0f32;
2039 for x in scaled.iter_mut() {
2040 *x = (*x - max_l).exp();
2041 sum += *x;
2042 }
2043 let inv = 1.0 / sum.max(f32::MIN_POSITIVE);
2044 for x in scaled.iter_mut() {
2045 *x *= inv;
2046 }
2047
2048 if top_p < 1.0 {
2051 let mut indexed: Vec<(usize, f32)> = scaled.iter().copied().enumerate().collect();
2052 indexed.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap());
2053 let mut cum = 0.0f32;
2054 let mut keep = vec![false; v];
2055 for (idx, p) in indexed.iter() {
2056 keep[*idx] = true;
2057 cum += *p;
2058 if cum >= top_p {
2059 break;
2060 }
2061 }
2062 let mut new_sum = 0.0f32;
2063 for (i, x) in scaled.iter_mut().enumerate() {
2064 if !keep[i] {
2065 *x = 0.0;
2066 }
2067 new_sum += *x;
2068 }
2069 let inv = 1.0 / new_sum.max(f32::MIN_POSITIVE);
2070 for x in scaled.iter_mut() {
2071 *x *= inv;
2072 }
2073 }
2074
2075 let r = rng.next_f32();
2077 let mut acc = 0.0f32;
2078 for (i, &p) in scaled.iter().enumerate() {
2079 acc += p;
2080 if r <= acc {
2081 return i;
2082 }
2083 }
2084 v - 1 }
2086
2087#[inline]
2091fn apply_synthetic_mask(
2092 scores: &mut [f32],
2093 q_seq: usize,
2094 k_seq: usize,
2095 kind: rlx_ir::op::MaskKind,
2096) {
2097 let neg = crate::config::RuntimeConfig::global().attn_mask_neg_inf;
2098 let q_offset = k_seq.saturating_sub(q_seq);
2099 match kind {
2100 rlx_ir::op::MaskKind::None | rlx_ir::op::MaskKind::Custom | rlx_ir::op::MaskKind::Bias => {}
2101 rlx_ir::op::MaskKind::Causal => {
2102 for qi in 0..q_seq {
2103 let abs_q = q_offset + qi;
2104 for ki in (abs_q + 1)..k_seq {
2105 scores[qi * k_seq + ki] = neg;
2106 }
2107 }
2108 }
2109 rlx_ir::op::MaskKind::SlidingWindow(w) => {
2110 for qi in 0..q_seq {
2111 let abs_q = q_offset + qi;
2112 let lo = abs_q.saturating_sub(w);
2113 for ki in 0..k_seq {
2114 if ki < lo || ki > abs_q {
2115 scores[qi * k_seq + ki] = neg;
2116 }
2117 }
2118 }
2119 }
2120 }
2121}
2122
2123pub fn compile_thunks(graph: &Graph, arena: &Arena) -> ThunkSchedule {
2125 let mut thunks = Vec::with_capacity(graph.len());
2126
2127 for node in graph.nodes() {
2128 if rlx_opt::is_pure_view(graph, node) {
2132 thunks.push(Thunk::Nop);
2133 continue;
2134 }
2135 let t = match &node.op {
2136 Op::Input { .. } | Op::Param { .. } | Op::Constant { .. } => Thunk::Nop,
2137
2138 Op::FusedMatMulBiasAct { activation } => {
2139 let shape = &node.shape;
2140 let n = shape.dim(shape.rank() - 1).unwrap_static();
2141 let total = shape.num_elements().unwrap();
2142 let m = total / n;
2143 let a_len = get_len(graph, node.inputs[0]);
2144 let k = a_len / m;
2145 Thunk::FusedMmBiasAct {
2146 a: node_offset(arena, node.inputs[0]),
2147 w: node_offset(arena, node.inputs[1]),
2148 bias: node_offset(arena, node.inputs[2]),
2149 c: node_offset(arena, node.id),
2150 m: m as u32,
2151 k: k as u32,
2152 n: n as u32,
2153 act: *activation,
2154 }
2155 }
2156
2157 Op::FusedResidualLN { has_bias, eps } => {
2158 let h = node.shape.dim(node.shape.rank() - 1).unwrap_static();
2159 let total = node.shape.num_elements().unwrap();
2160 let rows = total / h;
2161 let (g_idx, b_idx) = if *has_bias { (3, 4) } else { (2, 3) };
2162 Thunk::FusedResidualLN {
2163 x: node_offset(arena, node.inputs[0]),
2164 res: node_offset(arena, node.inputs[1]),
2165 bias: if *has_bias {
2166 node_offset(arena, node.inputs[2])
2167 } else {
2168 0
2169 },
2170 g: node_offset(arena, node.inputs[g_idx]),
2171 b: node_offset(arena, node.inputs[b_idx]),
2172 out: node_offset(arena, node.id),
2173 rows: rows as u32,
2174 h: h as u32,
2175 eps: *eps,
2176 has_bias: *has_bias,
2177 }
2178 }
2179
2180 Op::FusedResidualRmsNorm { has_bias, eps } => {
2181 let h = node.shape.dim(node.shape.rank() - 1).unwrap_static();
2182 let total = node.shape.num_elements().unwrap();
2183 let rows = total / h;
2184 let (g_idx, b_idx) = if *has_bias { (3, 4) } else { (2, 3) };
2185 Thunk::FusedResidualRmsNorm {
2186 x: node_offset(arena, node.inputs[0]),
2187 res: node_offset(arena, node.inputs[1]),
2188 bias: if *has_bias {
2189 node_offset(arena, node.inputs[2])
2190 } else {
2191 0
2192 },
2193 g: node_offset(arena, node.inputs[g_idx]),
2194 b: node_offset(arena, node.inputs[b_idx]),
2195 out: node_offset(arena, node.id),
2196 rows: rows as u32,
2197 h: h as u32,
2198 eps: *eps,
2199 has_bias: *has_bias,
2200 }
2201 }
2202
2203 Op::MatMul => {
2204 let shape = &node.shape;
2205 let a_shape = &graph.node(node.inputs[0]).shape;
2206 let b_shape = &graph.node(node.inputs[1]).shape;
2207 let n = shape.dim(shape.rank() - 1).unwrap_static();
2208
2209 let batched_3d = a_shape.rank() >= 3
2216 && b_shape.rank() == a_shape.rank()
2217 && shape.rank() == a_shape.rank()
2218 && {
2219 let mut ok = true;
2221 for d in 0..a_shape.rank() - 2 {
2222 if a_shape.dim(d) != b_shape.dim(d) || a_shape.dim(d) != shape.dim(d) {
2223 ok = false;
2224 break;
2225 }
2226 }
2227 ok
2228 };
2229 if batched_3d && shape.dtype() == rlx_ir::DType::F64 {
2230 let r = shape.rank();
2234 let mut batch_prod = 1usize;
2235 for d in 0..r - 2 {
2236 batch_prod *= shape.dim(d).unwrap_static();
2237 }
2238 let m_dim = shape.dim(r - 2).unwrap_static();
2239 let k_dim = a_shape.dim(r - 1).unwrap_static();
2240 debug_assert_eq!(k_dim, b_shape.dim(r - 2).unwrap_static());
2241 Thunk::BatchedDgemmF64 {
2242 a: node_offset(arena, node.inputs[0]),
2243 b: node_offset(arena, node.inputs[1]),
2244 c: node_offset(arena, node.id),
2245 batch: batch_prod as u32,
2246 m: m_dim as u32,
2247 k: k_dim as u32,
2248 n: n as u32,
2249 }
2250 } else if batched_3d && shape.dtype() == rlx_ir::DType::F32 {
2251 let r = shape.rank();
2254 let mut batch_prod = 1usize;
2255 for d in 0..r - 2 {
2256 batch_prod *= shape.dim(d).unwrap_static();
2257 }
2258 let m_dim = shape.dim(r - 2).unwrap_static();
2259 let k_dim = a_shape.dim(r - 1).unwrap_static();
2260 debug_assert_eq!(k_dim, b_shape.dim(r - 2).unwrap_static());
2261 Thunk::BatchedSgemm {
2262 a: node_offset(arena, node.inputs[0]),
2263 b: node_offset(arena, node.inputs[1]),
2264 c: node_offset(arena, node.id),
2265 batch: batch_prod as u32,
2266 m: m_dim as u32,
2267 k: k_dim as u32,
2268 n: n as u32,
2269 }
2270 } else {
2271 let total = shape.num_elements().unwrap();
2272 let m = total / n;
2273 let a_len = get_len(graph, node.inputs[0]);
2274 let k = a_len / m;
2275 match shape.dtype() {
2276 rlx_ir::DType::F64 => Thunk::Dgemm {
2277 a: node_offset(arena, node.inputs[0]),
2278 b: node_offset(arena, node.inputs[1]),
2279 c: node_offset(arena, node.id),
2280 m: m as u32,
2281 k: k as u32,
2282 n: n as u32,
2283 },
2284 _ => Thunk::Sgemm {
2285 a: node_offset(arena, node.inputs[0]),
2286 b: node_offset(arena, node.inputs[1]),
2287 c: node_offset(arena, node.id),
2288 m: m as u32,
2289 k: k as u32,
2290 n: n as u32,
2291 },
2292 }
2293 }
2294 }
2295
2296 Op::Binary(op) => {
2297 let lhs_len = get_len(graph, node.inputs[0]);
2298 let rhs_len = get_len(graph, node.inputs[1]);
2299 let out_len = node.shape.num_elements().unwrap();
2300 if node.shape.dtype() == rlx_ir::DType::C64 {
2301 match op {
2305 BinaryOp::Add | BinaryOp::Sub | BinaryOp::Mul | BinaryOp::Div => {}
2306 BinaryOp::Max | BinaryOp::Min | BinaryOp::Pow => panic!(
2307 "Op::Binary({op:?}) on DType::C64: complex \
2308 max/min/pow have no single natural definition \
2309 — caller should drop to 2N-real-block (see \
2310 spike-ac) and pick a convention there"
2311 ),
2312 }
2313 }
2314 let (out_dims_bcast, bcast_lhs_strides, bcast_rhs_strides) =
2318 if lhs_len == out_len && rhs_len == out_len {
2319 (Vec::new(), Vec::new(), Vec::new())
2320 } else {
2321 let lhs_dims = get_static_dims(graph, node.inputs[0]);
2322 let rhs_dims = get_static_dims(graph, node.inputs[1]);
2323 let out_dims_v = get_static_dims(graph, node.id);
2324 if lhs_dims.is_empty() || rhs_dims.is_empty() || out_dims_v.is_empty() {
2325 (Vec::new(), Vec::new(), Vec::new())
2330 } else {
2331 let ls = broadcast_strides(&lhs_dims, &out_dims_v);
2332 let rs = broadcast_strides(&rhs_dims, &out_dims_v);
2333 let od: Vec<u32> = out_dims_v.iter().map(|x| *x as u32).collect();
2334 (od, ls, rs)
2335 }
2336 };
2337 if node.shape.dtype() == rlx_ir::DType::C64 {
2338 Thunk::BinaryFullC64 {
2339 lhs: node_offset(arena, node.inputs[0]),
2340 rhs: node_offset(arena, node.inputs[1]),
2341 dst: node_offset(arena, node.id),
2342 len: out_len as u32,
2343 lhs_len: lhs_len as u32,
2344 rhs_len: rhs_len as u32,
2345 op: *op,
2346 out_dims_bcast,
2347 bcast_lhs_strides,
2348 bcast_rhs_strides,
2349 }
2350 } else if node.shape.dtype() == rlx_ir::DType::F64 {
2351 Thunk::BinaryFullF64 {
2354 lhs: node_offset(arena, node.inputs[0]),
2355 rhs: node_offset(arena, node.inputs[1]),
2356 dst: node_offset(arena, node.id),
2357 len: out_len as u32,
2358 lhs_len: lhs_len as u32,
2359 rhs_len: rhs_len as u32,
2360 op: *op,
2361 out_dims_bcast,
2362 bcast_lhs_strides,
2363 bcast_rhs_strides,
2364 }
2365 } else if matches!(op, BinaryOp::Add)
2366 && rhs_len < out_len
2367 && out_len % rhs_len == 0
2368 && is_trailing_bias_broadcast(
2369 graph.node(node.inputs[1]).shape.dims(),
2370 graph.node(node.id).shape.dims(),
2371 )
2372 {
2373 Thunk::BiasAdd {
2383 src: node_offset(arena, node.inputs[0]),
2384 bias: node_offset(arena, node.inputs[1]),
2385 dst: node_offset(arena, node.id),
2386 m: (out_len / rhs_len) as u32,
2387 n: rhs_len as u32,
2388 }
2389 } else {
2390 let lhs_len = get_len(graph, node.inputs[0]);
2391 Thunk::BinaryFull {
2392 lhs: node_offset(arena, node.inputs[0]),
2393 rhs: node_offset(arena, node.inputs[1]),
2394 dst: node_offset(arena, node.id),
2395 len: out_len as u32,
2396 lhs_len: lhs_len as u32,
2397 rhs_len: rhs_len as u32,
2398 op: *op,
2399 out_dims_bcast,
2400 bcast_lhs_strides,
2401 bcast_rhs_strides,
2402 }
2403 }
2404 }
2405
2406 Op::Activation(act) => {
2407 let len = node.shape.num_elements().unwrap();
2408 let in_off = node_offset(arena, node.inputs[0]);
2409 let out_off = node_offset(arena, node.id);
2410 if node.shape.dtype() == rlx_ir::DType::C64 {
2411 match act {
2416 Activation::Neg | Activation::Exp | Activation::Log | Activation::Sqrt => {}
2417 other => panic!(
2418 "Op::Activation({other:?}) on DType::C64: no \
2419 natural complex extension — supported on C64: \
2420 Neg, Exp, Log, Sqrt"
2421 ),
2422 }
2423 Thunk::ActivationC64 {
2424 src: in_off,
2425 dst: out_off,
2426 len: len as u32,
2427 kind: *act,
2428 }
2429 } else if node.shape.dtype() == rlx_ir::DType::F64 {
2430 Thunk::ActivationF64 {
2431 src: in_off,
2432 dst: out_off,
2433 len: len as u32,
2434 kind: *act,
2435 }
2436 } else if in_off == out_off {
2437 Thunk::ActivationInPlace {
2441 data: out_off,
2442 len: len as u32,
2443 act: *act,
2444 }
2445 } else {
2446 thunks.push(Thunk::Copy {
2450 src: in_off,
2451 dst: out_off,
2452 len: len as u32,
2453 });
2454 Thunk::ActivationInPlace {
2455 data: out_off,
2456 len: len as u32,
2457 act: *act,
2458 }
2459 }
2460 }
2461
2462 Op::Gather { axis } if *axis == 0 => {
2463 let table_shape = &graph.node(node.inputs[0]).shape;
2464 let table_total = table_shape.num_elements().unwrap();
2465 let trailing: usize = (1..table_shape.rank())
2466 .map(|i| table_shape.dim(i).unwrap_static())
2467 .product();
2468 let idx_len = get_len(graph, node.inputs[1]);
2469 Thunk::Gather {
2470 table: node_offset(arena, node.inputs[0]),
2471 table_len: table_total as u32,
2472 idx: node_offset(arena, node.inputs[1]),
2473 dst: node_offset(arena, node.id),
2474 num_idx: idx_len as u32,
2475 trailing: trailing as u32,
2476 }
2477 }
2478
2479 Op::Gather { axis } => {
2480 let table_shape = &graph.node(node.inputs[0]).shape;
2482 let rank = table_shape.rank();
2483 let outer: usize = (0..*axis)
2484 .map(|i| table_shape.dim(i).unwrap_static())
2485 .product::<usize>()
2486 .max(1);
2487 let trailing: usize = (*axis + 1..rank)
2488 .map(|i| table_shape.dim(i).unwrap_static())
2489 .product::<usize>()
2490 .max(1);
2491 let axis_dim = table_shape.dim(*axis).unwrap_static();
2492 let idx_len = get_len(graph, node.inputs[1]);
2493 Thunk::GatherAxis {
2494 table: node_offset(arena, node.inputs[0]),
2495 idx: node_offset(arena, node.inputs[1]),
2496 dst: node_offset(arena, node.id),
2497 outer: outer as u32,
2498 axis_dim: axis_dim as u32,
2499 num_idx: idx_len as u32,
2500 trailing: trailing as u32,
2501 }
2502 }
2503
2504 Op::Narrow { axis, start, len } => {
2505 let in_shape = &graph.node(node.inputs[0]).shape;
2506 let elem_bytes = in_shape.dtype().size_bytes() as u8;
2507 let rank = in_shape.rank();
2508 let outer: usize = (0..*axis)
2509 .map(|i| in_shape.dim(i).unwrap_static())
2510 .product::<usize>()
2511 .max(1);
2512 let inner: usize = (*axis + 1..rank)
2513 .map(|i| in_shape.dim(i).unwrap_static())
2514 .product::<usize>()
2515 .max(1);
2516 let in_axis = in_shape.dim(*axis).unwrap_static();
2517 let src_byte_offset =
2518 node_offset(arena, node.inputs[0]) + start * inner * elem_bytes as usize;
2519 Thunk::Narrow {
2520 src: src_byte_offset,
2521 dst: node_offset(arena, node.id),
2522 outer: outer as u32,
2523 src_stride: (in_axis * inner) as u32, dst_stride: (*len * inner) as u32, inner: (*len * inner) as u32, elem_bytes,
2527 }
2528 }
2529
2530 Op::Reshape { .. } | Op::Cast { .. } => {
2531 let len = node.shape.num_elements().unwrap();
2533 let src = node_offset(arena, node.inputs[0]);
2534 let dst = node_offset(arena, node.id);
2535 match node.shape.dtype() {
2536 rlx_ir::DType::F64 => Thunk::CopyF64 {
2537 src,
2538 dst,
2539 len: len as u32,
2540 },
2541 _ => Thunk::Copy {
2542 src,
2543 dst,
2544 len: len as u32,
2545 },
2546 }
2547 }
2548
2549 Op::Quantize {
2550 axis,
2551 scales,
2552 zero_points,
2553 } => {
2554 let (chan_axis, chan_dim, inner) = quant_layout(&node.shape, *axis);
2555 Thunk::Quantize {
2556 x: node_offset(arena, node.inputs[0]),
2557 q: node_offset(arena, node.id),
2558 len: node.shape.num_elements().unwrap() as u32,
2559 chan_axis: chan_axis as u32,
2560 chan_dim: chan_dim as u32,
2561 inner: inner as u32,
2562 scales: scales.clone(),
2563 zero_points: zero_points.clone(),
2564 }
2565 }
2566
2567 Op::FakeQuantize {
2568 bits,
2569 axis,
2570 ste,
2571 scale_mode,
2572 } => {
2573 let (chan_axis, chan_dim, inner) = quant_layout(&node.shape, *axis);
2574 let state_off = match scale_mode {
2575 rlx_ir::op::ScaleMode::PerBatch => None,
2576 rlx_ir::op::ScaleMode::EMA { .. } | rlx_ir::op::ScaleMode::Fixed => {
2577 debug_assert_eq!(
2579 node.inputs.len(),
2580 2,
2581 "EMA/Fixed FakeQuantize needs a state input"
2582 );
2583 Some(node_offset(arena, node.inputs[1]))
2584 }
2585 };
2586 Thunk::FakeQuantize {
2587 x: node_offset(arena, node.inputs[0]),
2588 out: node_offset(arena, node.id),
2589 len: node.shape.num_elements().unwrap() as u32,
2590 chan_axis: chan_axis as u32,
2591 chan_dim: chan_dim as u32,
2592 inner: inner as u32,
2593 bits: *bits,
2594 ste: *ste,
2595 scale_mode: *scale_mode,
2596 state_off,
2597 }
2598 }
2599
2600 Op::FakeQuantizeLSQ { bits, axis } => {
2601 let (chan_axis, chan_dim, inner) = quant_layout(&node.shape, *axis);
2602 Thunk::FakeQuantizeLSQ {
2603 x: node_offset(arena, node.inputs[0]),
2604 scale_off: node_offset(arena, node.inputs[1]),
2605 out: node_offset(arena, node.id),
2606 len: node.shape.num_elements().unwrap() as u32,
2607 chan_axis: chan_axis as u32,
2608 chan_dim: chan_dim as u32,
2609 inner: inner as u32,
2610 bits: *bits,
2611 }
2612 }
2613
2614 Op::FakeQuantizeLSQBackwardX { bits, axis } => {
2615 let (chan_axis, chan_dim, inner) = quant_layout(&node.shape, *axis);
2616 Thunk::FakeQuantizeLSQBackwardX {
2617 x: node_offset(arena, node.inputs[0]),
2618 scale_off: node_offset(arena, node.inputs[1]),
2619 dy: node_offset(arena, node.inputs[2]),
2620 dx: node_offset(arena, node.id),
2621 len: node.shape.num_elements().unwrap() as u32,
2622 chan_axis: chan_axis as u32,
2623 chan_dim: chan_dim as u32,
2624 inner: inner as u32,
2625 bits: *bits,
2626 }
2627 }
2628
2629 Op::FakeQuantizeLSQBackwardScale { bits, axis } => {
2630 let in_shape = &graph.node(node.inputs[0]).shape;
2633 let (chan_axis, chan_dim, inner) = quant_layout(in_shape, *axis);
2634 Thunk::FakeQuantizeLSQBackwardScale {
2635 x: node_offset(arena, node.inputs[0]),
2636 scale_off: node_offset(arena, node.inputs[1]),
2637 dy: node_offset(arena, node.inputs[2]),
2638 dscale: node_offset(arena, node.id),
2639 len: in_shape.num_elements().unwrap() as u32,
2640 chan_axis: chan_axis as u32,
2641 chan_dim: chan_dim as u32,
2642 inner: inner as u32,
2643 bits: *bits,
2644 }
2645 }
2646
2647 Op::FakeQuantizeBackward { bits, axis, ste } => {
2648 let (chan_axis, chan_dim, inner) = quant_layout(&node.shape, *axis);
2649 Thunk::FakeQuantizeBackward {
2650 x: node_offset(arena, node.inputs[0]),
2651 dy: node_offset(arena, node.inputs[1]),
2652 dx: node_offset(arena, node.id),
2653 len: node.shape.num_elements().unwrap() as u32,
2654 chan_axis: chan_axis as u32,
2655 chan_dim: chan_dim as u32,
2656 inner: inner as u32,
2657 bits: *bits,
2658 ste: *ste,
2659 }
2660 }
2661
2662 Op::Dequantize {
2663 axis,
2664 scales,
2665 zero_points,
2666 } => {
2667 let (chan_axis, chan_dim, inner) = quant_layout(&node.shape, *axis);
2668 Thunk::Dequantize {
2669 q: node_offset(arena, node.inputs[0]),
2670 x: node_offset(arena, node.id),
2671 len: node.shape.num_elements().unwrap() as u32,
2672 chan_axis: chan_axis as u32,
2673 chan_dim: chan_dim as u32,
2674 inner: inner as u32,
2675 scales: scales.clone(),
2676 zero_points: zero_points.clone(),
2677 }
2678 }
2679
2680 Op::Expand { .. } => {
2681 let in_shape = &graph.node(node.inputs[0]).shape;
2686 let out_shape = &node.shape;
2687 let in_rank = in_shape.rank();
2688 let out_rank = out_shape.rank();
2689 let pad = out_rank.saturating_sub(in_rank);
2691 let in_dims: Vec<usize> = (0..out_rank)
2692 .map(|i| {
2693 if i < pad {
2694 1
2695 } else {
2696 in_shape.dim(i - pad).unwrap_static()
2697 }
2698 })
2699 .collect();
2700 let mut in_strides_full = vec![1usize; out_rank];
2702 for d in (0..out_rank.saturating_sub(1)).rev() {
2703 in_strides_full[d] = in_strides_full[d + 1] * in_dims[d + 1];
2704 }
2705 let out_dims: Vec<u32> = (0..out_rank)
2706 .map(|i| out_shape.dim(i).unwrap_static() as u32)
2707 .collect();
2708 let in_strides: Vec<u32> = (0..out_rank)
2710 .map(|i| {
2711 if in_dims[i] == 1 && (out_dims[i] as usize) > 1 {
2712 0
2713 } else {
2714 in_strides_full[i] as u32
2715 }
2716 })
2717 .collect();
2718 let in_total = in_dims.iter().product::<usize>() as u32;
2719 let src = node_offset(arena, node.inputs[0]);
2720 let dst = node_offset(arena, node.id);
2721 match node.shape.dtype() {
2722 rlx_ir::DType::F64 => Thunk::TransposeF64 {
2723 src,
2724 dst,
2725 in_total,
2726 out_dims,
2727 in_strides,
2728 },
2729 _ => Thunk::Transpose {
2730 src,
2731 dst,
2732 in_total,
2733 out_dims,
2734 in_strides,
2735 },
2736 }
2737 }
2738
2739 Op::RmsNorm { eps, .. } => {
2740 let h = node.shape.dim(node.shape.rank() - 1).unwrap_static();
2741 let total = node.shape.num_elements().unwrap();
2742 Thunk::RmsNorm {
2743 src: node_offset(arena, node.inputs[0]),
2744 g: node_offset(arena, node.inputs[1]),
2745 b: node_offset(arena, node.inputs[2]),
2746 dst: node_offset(arena, node.id),
2747 rows: (total / h) as u32,
2748 h: h as u32,
2749 eps: *eps,
2750 }
2751 }
2752
2753 Op::LayerNorm { eps, .. } => {
2754 let h = node.shape.dim(node.shape.rank() - 1).unwrap_static();
2755 let total = node.shape.num_elements().unwrap();
2756 Thunk::LayerNorm {
2757 src: node_offset(arena, node.inputs[0]),
2758 g: node_offset(arena, node.inputs[1]),
2759 b: node_offset(arena, node.inputs[2]),
2760 dst: node_offset(arena, node.id),
2761 rows: (total / h) as u32,
2762 h: h as u32,
2763 eps: *eps,
2764 }
2765 }
2766
2767 Op::GroupNorm { num_groups, eps } => {
2768 let in_shape = &graph.node(node.inputs[0]).shape;
2769 Thunk::GroupNorm {
2770 src: node_offset(arena, node.inputs[0]),
2771 g: node_offset(arena, node.inputs[1]),
2772 b: node_offset(arena, node.inputs[2]),
2773 dst: node_offset(arena, node.id),
2774 n: in_shape.dim(0).unwrap_static() as u32,
2775 c: in_shape.dim(1).unwrap_static() as u32,
2776 h: in_shape.dim(2).unwrap_static() as u32,
2777 w: in_shape.dim(3).unwrap_static() as u32,
2778 num_groups: *num_groups as u32,
2779 eps: *eps,
2780 }
2781 }
2782
2783 Op::LayerNorm2d { eps } => {
2784 let in_shape = &graph.node(node.inputs[0]).shape;
2785 Thunk::LayerNorm2d {
2786 src: node_offset(arena, node.inputs[0]),
2787 g: node_offset(arena, node.inputs[1]),
2788 b: node_offset(arena, node.inputs[2]),
2789 dst: node_offset(arena, node.id),
2790 n: in_shape.dim(0).unwrap_static() as u32,
2791 c: in_shape.dim(1).unwrap_static() as u32,
2792 h: in_shape.dim(2).unwrap_static() as u32,
2793 w: in_shape.dim(3).unwrap_static() as u32,
2794 eps: *eps,
2795 }
2796 }
2797
2798 Op::ConvTranspose2d {
2799 kernel_size,
2800 stride,
2801 padding,
2802 dilation,
2803 output_padding: _,
2804 groups,
2805 } => {
2806 let in_shape = &graph.node(node.inputs[0]).shape;
2807 let out_shape = &node.shape;
2808 Thunk::ConvTranspose2d {
2809 src: node_offset(arena, node.inputs[0]),
2810 weight: node_offset(arena, node.inputs[1]),
2811 dst: node_offset(arena, node.id),
2812 n: in_shape.dim(0).unwrap_static() as u32,
2813 c_in: in_shape.dim(1).unwrap_static() as u32,
2814 h: in_shape.dim(2).unwrap_static() as u32,
2815 w_in: in_shape.dim(3).unwrap_static() as u32,
2816 c_out: out_shape.dim(1).unwrap_static() as u32,
2817 h_out: out_shape.dim(2).unwrap_static() as u32,
2818 w_out: out_shape.dim(3).unwrap_static() as u32,
2819 kh: kernel_size[0] as u32,
2820 kw: kernel_size[1] as u32,
2821 sh: stride.first().copied().unwrap_or(1) as u32,
2822 sw: stride.get(1).copied().unwrap_or(1) as u32,
2823 ph: padding.first().copied().unwrap_or(0) as u32,
2824 pw: padding.get(1).copied().unwrap_or(0) as u32,
2825 dh: dilation.first().copied().unwrap_or(1) as u32,
2826 dw: dilation.get(1).copied().unwrap_or(1) as u32,
2827 groups: *groups as u32,
2828 }
2829 }
2830
2831 Op::ResizeNearest2x => {
2832 let in_shape = &graph.node(node.inputs[0]).shape;
2833 Thunk::ResizeNearest2x {
2834 src: node_offset(arena, node.inputs[0]),
2835 dst: node_offset(arena, node.id),
2836 n: in_shape.dim(0).unwrap_static() as u32,
2837 c: in_shape.dim(1).unwrap_static() as u32,
2838 h: in_shape.dim(2).unwrap_static() as u32,
2839 w: in_shape.dim(3).unwrap_static() as u32,
2840 }
2841 }
2842
2843 Op::AxialRope2d {
2844 end_x,
2845 end_y,
2846 head_dim,
2847 num_heads,
2848 theta,
2849 repeat_factor,
2850 } => {
2851 let in_shape = &graph.node(node.inputs[0]).shape;
2852 let batch = in_shape.dim(0).unwrap_static() as u32;
2853 let seq = in_shape.dim(1).unwrap_static() as u32;
2854 let hidden = in_shape.dim(2).unwrap_static() as u32;
2855 Thunk::AxialRope2d {
2856 src: node_offset(arena, node.inputs[0]),
2857 dst: node_offset(arena, node.id),
2858 batch,
2859 seq,
2860 hidden,
2861 end_x: *end_x as u32,
2862 end_y: *end_y as u32,
2863 head_dim: *head_dim as u32,
2864 num_heads: *num_heads as u32,
2865 theta: *theta,
2866 repeat_factor: *repeat_factor as u32,
2867 }
2868 }
2869
2870 Op::Softmax { axis } => {
2871 let rank = node.shape.rank();
2872 let ax = if *axis < 0 {
2873 (rank as i32 + axis) as usize
2874 } else {
2875 *axis as usize
2876 };
2877 let cols = node.shape.dim(ax).unwrap_static();
2878 let total = node.shape.num_elements().unwrap();
2879 let in_off = node_offset(arena, node.inputs[0]);
2880 let out_off = node_offset(arena, node.id);
2881 if in_off != out_off {
2887 thunks.push(Thunk::Copy {
2888 src: in_off,
2889 dst: out_off,
2890 len: total as u32,
2891 });
2892 }
2893 Thunk::Softmax {
2894 data: out_off,
2895 rows: (total / cols) as u32,
2896 cols: cols as u32,
2897 }
2898 }
2899
2900 Op::SelectiveScan { state_size } => {
2901 let in_shape = &graph.node(node.inputs[0]).shape;
2902 let (batch, seq, hidden) = (
2903 in_shape.dim(0).unwrap_static(),
2904 in_shape.dim(1).unwrap_static(),
2905 in_shape.dim(2).unwrap_static(),
2906 );
2907 Thunk::SelectiveScan {
2908 x: node_offset(arena, node.inputs[0]),
2909 delta: node_offset(arena, node.inputs[1]),
2910 a: node_offset(arena, node.inputs[2]),
2911 b: node_offset(arena, node.inputs[3]),
2912 c: node_offset(arena, node.inputs[4]),
2913 dst: node_offset(arena, node.id),
2914 batch: batch as u32,
2915 seq: seq as u32,
2916 hidden: hidden as u32,
2917 state_size: *state_size as u32,
2918 }
2919 }
2920
2921 Op::GatedDeltaNet {
2922 state_size,
2923 carry_state,
2924 } => {
2925 let q_shape = &graph.node(node.inputs[0]).shape;
2926 let (batch, seq, heads) = (
2927 q_shape.dim(0).unwrap_static(),
2928 q_shape.dim(1).unwrap_static(),
2929 q_shape.dim(2).unwrap_static(),
2930 );
2931 let state_off = if *carry_state {
2932 node_offset(arena, node.inputs[5])
2933 } else {
2934 0
2935 };
2936 Thunk::GatedDeltaNet {
2937 q: node_offset(arena, node.inputs[0]),
2938 k: node_offset(arena, node.inputs[1]),
2939 v: node_offset(arena, node.inputs[2]),
2940 g: node_offset(arena, node.inputs[3]),
2941 beta: node_offset(arena, node.inputs[4]),
2942 state: state_off,
2943 dst: node_offset(arena, node.id),
2944 batch: batch as u32,
2945 seq: seq as u32,
2946 heads: heads as u32,
2947 state_size: *state_size as u32,
2948 }
2949 }
2950
2951 Op::QMatMul {
2952 x_zp,
2953 w_zp,
2954 out_zp,
2955 mult,
2956 } => {
2957 let x_shape = &graph.node(node.inputs[0]).shape;
2958 let w_shape = &graph.node(node.inputs[1]).shape;
2959 let m = x_shape.dim(0).unwrap_static();
2960 let k = x_shape.dim(1).unwrap_static();
2961 let n = w_shape.dim(1).unwrap_static();
2962 Thunk::QMatMul {
2963 x: node_offset(arena, node.inputs[0]),
2964 w: node_offset(arena, node.inputs[1]),
2965 bias: node_offset(arena, node.inputs[2]),
2966 out: node_offset(arena, node.id),
2967 m: m as u32,
2968 k: k as u32,
2969 n: n as u32,
2970 x_zp: *x_zp,
2971 w_zp: *w_zp,
2972 out_zp: *out_zp,
2973 mult: *mult,
2974 }
2975 }
2976
2977 Op::QConv2d {
2978 kernel_size,
2979 stride,
2980 padding,
2981 dilation,
2982 groups,
2983 x_zp,
2984 w_zp,
2985 out_zp,
2986 mult,
2987 } => {
2988 let in_shape = &graph.node(node.inputs[0]).shape;
2989 let w_shape = &graph.node(node.inputs[1]).shape;
2990 let out_shape = &node.shape;
2991 if kernel_size.len() == 2
2992 && in_shape.rank() == 4
2993 && w_shape.rank() == 4
2994 && out_shape.rank() == 4
2995 {
2996 Thunk::QConv2d {
2997 x: node_offset(arena, node.inputs[0]),
2998 w: node_offset(arena, node.inputs[1]),
2999 bias: node_offset(arena, node.inputs[2]),
3000 out: node_offset(arena, node.id),
3001 n: in_shape.dim(0).unwrap_static() as u32,
3002 c_in: in_shape.dim(1).unwrap_static() as u32,
3003 h: in_shape.dim(2).unwrap_static() as u32,
3004 w_in: in_shape.dim(3).unwrap_static() as u32,
3005 c_out: out_shape.dim(1).unwrap_static() as u32,
3006 h_out: out_shape.dim(2).unwrap_static() as u32,
3007 w_out: out_shape.dim(3).unwrap_static() as u32,
3008 kh: kernel_size[0] as u32,
3009 kw: kernel_size[1] as u32,
3010 sh: stride.first().copied().unwrap_or(1) as u32,
3011 sw: stride.get(1).copied().unwrap_or(1) as u32,
3012 ph: padding.first().copied().unwrap_or(0) as u32,
3013 pw: padding.get(1).copied().unwrap_or(0) as u32,
3014 dh: dilation.first().copied().unwrap_or(1) as u32,
3015 dw: dilation.get(1).copied().unwrap_or(1) as u32,
3016 groups: *groups as u32,
3017 x_zp: *x_zp,
3018 w_zp: *w_zp,
3019 out_zp: *out_zp,
3020 mult: *mult,
3021 }
3022 } else {
3023 Thunk::Nop
3024 }
3025 }
3026
3027 Op::DequantMatMul { scheme } => {
3028 use rlx_ir::quant::QuantScheme;
3029 let n = node.shape.dim(node.shape.rank() - 1).unwrap_static();
3030 let total = node.shape.num_elements().unwrap();
3031 let m = total / n.max(1);
3032 let x_total = graph.node(node.inputs[0]).shape.num_elements().unwrap();
3033 let k = x_total / m.max(1);
3034 if scheme.is_gguf() {
3035 Thunk::DequantMatMulGguf {
3036 x: node_offset(arena, node.inputs[0]),
3037 w_q: node_offset(arena, node.inputs[1]),
3038 dst: node_offset(arena, node.id),
3039 m: m as u32,
3040 k: k as u32,
3041 n: n as u32,
3042 scheme: *scheme,
3043 }
3044 } else {
3045 match scheme {
3046 QuantScheme::Nvfp4Block => Thunk::DequantMatMulNvfp4 {
3047 x: node_offset(arena, node.inputs[0]),
3048 w_q: node_offset(arena, node.inputs[1]),
3049 scale: node_offset(arena, node.inputs[2]),
3050 global_scale: node_offset(arena, node.inputs[3]),
3051 dst: node_offset(arena, node.id),
3052 m: m as u32,
3053 k: k as u32,
3054 n: n as u32,
3055 },
3056 QuantScheme::Int4Block { block_size } => Thunk::DequantMatMulInt4 {
3057 x: node_offset(arena, node.inputs[0]),
3058 w_q: node_offset(arena, node.inputs[1]),
3059 scale: node_offset(arena, node.inputs[2]),
3060 zp: node_offset(arena, node.inputs[3]),
3061 dst: node_offset(arena, node.id),
3062 m: m as u32,
3063 k: k as u32,
3064 n: n as u32,
3065 block_size: *block_size,
3066 is_asymmetric: false,
3067 },
3068 QuantScheme::Fp8E4m3 => Thunk::DequantMatMulFp8 {
3069 x: node_offset(arena, node.inputs[0]),
3070 w_q: node_offset(arena, node.inputs[1]),
3071 scale: node_offset(arena, node.inputs[2]),
3072 dst: node_offset(arena, node.id),
3073 m: m as u32,
3074 k: k as u32,
3075 n: n as u32,
3076 e5m2: false,
3077 },
3078 QuantScheme::Fp8E5m2 => Thunk::DequantMatMulFp8 {
3079 x: node_offset(arena, node.inputs[0]),
3080 w_q: node_offset(arena, node.inputs[1]),
3081 scale: node_offset(arena, node.inputs[2]),
3082 dst: node_offset(arena, node.id),
3083 m: m as u32,
3084 k: k as u32,
3085 n: n as u32,
3086 e5m2: true,
3087 },
3088 QuantScheme::Int8Block { block_size } => Thunk::DequantMatMul {
3089 x: node_offset(arena, node.inputs[0]),
3090 w_q: node_offset(arena, node.inputs[1]),
3091 scale: node_offset(arena, node.inputs[2]),
3092 zp: node_offset(arena, node.inputs[3]),
3093 dst: node_offset(arena, node.id),
3094 m: m as u32,
3095 k: k as u32,
3096 n: n as u32,
3097 block_size: *block_size,
3098 is_asymmetric: false,
3099 },
3100 QuantScheme::Int8BlockAsym { block_size } => Thunk::DequantMatMul {
3101 x: node_offset(arena, node.inputs[0]),
3102 w_q: node_offset(arena, node.inputs[1]),
3103 scale: node_offset(arena, node.inputs[2]),
3104 zp: node_offset(arena, node.inputs[3]),
3105 dst: node_offset(arena, node.id),
3106 m: m as u32,
3107 k: k as u32,
3108 n: n as u32,
3109 block_size: *block_size,
3110 is_asymmetric: true,
3111 },
3112 other => panic!(
3113 "DequantMatMul on CPU supports Int8/Int4/FP8/NVFP4 legacy or GGUF schemes; got {other}"
3114 ),
3115 }
3116 }
3117 }
3118
3119 Op::LoraMatMul { scale } => {
3120 let n = node.shape.dim(node.shape.rank() - 1).unwrap_static();
3122 let total = node.shape.num_elements().unwrap();
3123 let m = total / n.max(1);
3124 let x_total = graph.node(node.inputs[0]).shape.num_elements().unwrap();
3125 let k = x_total / m.max(1);
3126 let a_total = graph.node(node.inputs[2]).shape.num_elements().unwrap();
3127 let r = a_total / k.max(1);
3128 Thunk::LoraMatMul {
3129 x: node_offset(arena, node.inputs[0]),
3130 w: node_offset(arena, node.inputs[1]),
3131 a: node_offset(arena, node.inputs[2]),
3132 b: node_offset(arena, node.inputs[3]),
3133 dst: node_offset(arena, node.id),
3134 m: m as u32,
3135 k: k as u32,
3136 n: n as u32,
3137 r: r as u32,
3138 scale: *scale,
3139 }
3140 }
3141
3142 Op::Sample {
3143 top_k,
3144 top_p,
3145 temperature,
3146 seed,
3147 } => {
3148 let in_shape = &graph.node(node.inputs[0]).shape;
3149 let (batch, vocab) = if in_shape.rank() >= 2 {
3151 (
3152 in_shape.dim(0).unwrap_static(),
3153 in_shape.dim(in_shape.rank() - 1).unwrap_static(),
3154 )
3155 } else {
3156 (1, in_shape.num_elements().unwrap_or(0))
3157 };
3158 Thunk::Sample {
3159 logits: node_offset(arena, node.inputs[0]),
3160 dst: node_offset(arena, node.id),
3161 batch: batch as u32,
3162 vocab: vocab as u32,
3163 top_k: *top_k as u32,
3164 top_p: *top_p,
3165 temperature: *temperature,
3166 seed: *seed,
3167 }
3168 }
3169
3170 Op::Cumsum { axis, exclusive } => {
3171 let rank = node.shape.rank();
3176 let ax = if *axis < 0 {
3177 (rank as i32 + axis) as usize
3178 } else {
3179 *axis as usize
3180 };
3181 assert_eq!(
3182 ax,
3183 rank - 1,
3184 "Cumsum only supports the last axis on CPU today"
3185 );
3186 let cols = node.shape.dim(ax).unwrap_static();
3187 let total = node.shape.num_elements().unwrap();
3188 Thunk::Cumsum {
3189 src: node_offset(arena, node.inputs[0]),
3190 dst: node_offset(arena, node.id),
3191 rows: (total / cols) as u32,
3192 cols: cols as u32,
3193 exclusive: *exclusive,
3194 }
3195 }
3196
3197 Op::Attention {
3198 num_heads,
3199 head_dim,
3200 mask_kind,
3201 score_scale: _,
3202 attn_logit_softcap: _,
3203 } => {
3204 let q_shape = &graph.node(node.inputs[0]).shape;
3210 let k_shape = &graph.node(node.inputs[1]).shape;
3211 let rank = q_shape.rank();
3212 let (batch, seq, kv_seq, bhsd) = if rank == 4 {
3213 let d1 = q_shape.dim(1).unwrap_static();
3214 let d2 = q_shape.dim(2).unwrap_static();
3215 if d1 == *num_heads {
3216 (
3218 q_shape.dim(0).unwrap_static(),
3219 d2,
3220 k_shape.dim(2).unwrap_static(),
3221 true,
3222 )
3223 } else {
3224 (
3226 q_shape.dim(0).unwrap_static(),
3227 d1,
3228 k_shape.dim(1).unwrap_static(),
3229 false,
3230 )
3231 }
3232 } else if rank >= 3 {
3233 (
3234 q_shape.dim(0).unwrap_static(),
3235 q_shape.dim(1).unwrap_static(),
3236 k_shape.dim(1).unwrap_static(),
3237 false,
3238 )
3239 } else {
3240 (
3241 1,
3242 q_shape.dim(0).unwrap_static(),
3243 k_shape.dim(0).unwrap_static(),
3244 false,
3245 )
3246 };
3247 let mask_off = if matches!(
3248 mask_kind,
3249 rlx_ir::op::MaskKind::Custom | rlx_ir::op::MaskKind::Bias
3250 ) {
3251 node_offset(arena, node.inputs[3])
3252 } else {
3253 0
3254 };
3255 let hs = (*num_heads * *head_dim) as u32;
3256 Thunk::Attention {
3257 q: node_offset(arena, node.inputs[0]),
3258 k: node_offset(arena, node.inputs[1]),
3259 v: node_offset(arena, node.inputs[2]),
3260 mask: mask_off,
3261 out: node_offset(arena, node.id),
3262 batch: batch as u32,
3263 seq: seq as u32,
3264 kv_seq: kv_seq as u32,
3265 heads: *num_heads as u32,
3266 head_dim: *head_dim as u32,
3267 mask_kind: *mask_kind,
3268 q_row_stride: hs,
3272 k_row_stride: hs,
3273 v_row_stride: hs,
3274 bhsd,
3275 }
3276 }
3277
3278 Op::AttentionBackward {
3279 num_heads,
3280 head_dim,
3281 mask_kind,
3282 wrt,
3283 } => {
3284 let q_shape = &graph.node(node.inputs[0]).shape;
3285 let k_shape = &graph.node(node.inputs[1]).shape;
3286 let rank = q_shape.rank();
3287 let (batch, seq, kv_seq, bhsd) = if rank == 4 {
3288 let d1 = q_shape.dim(1).unwrap_static();
3289 let d2 = q_shape.dim(2).unwrap_static();
3290 if d1 == *num_heads {
3291 (
3292 q_shape.dim(0).unwrap_static(),
3293 d2,
3294 k_shape.dim(2).unwrap_static(),
3295 true,
3296 )
3297 } else {
3298 (
3299 q_shape.dim(0).unwrap_static(),
3300 d1,
3301 k_shape.dim(1).unwrap_static(),
3302 false,
3303 )
3304 }
3305 } else if rank >= 3 {
3306 (
3307 q_shape.dim(0).unwrap_static(),
3308 q_shape.dim(1).unwrap_static(),
3309 k_shape.dim(1).unwrap_static(),
3310 false,
3311 )
3312 } else {
3313 (
3314 1,
3315 q_shape.dim(0).unwrap_static(),
3316 k_shape.dim(0).unwrap_static(),
3317 false,
3318 )
3319 };
3320 let mask_off = if matches!(
3321 mask_kind,
3322 rlx_ir::op::MaskKind::Custom | rlx_ir::op::MaskKind::Bias
3323 ) {
3324 node_offset(arena, node.inputs[4])
3325 } else {
3326 0
3327 };
3328 Thunk::AttentionBackward {
3329 q: node_offset(arena, node.inputs[0]),
3330 k: node_offset(arena, node.inputs[1]),
3331 v: node_offset(arena, node.inputs[2]),
3332 dy: node_offset(arena, node.inputs[3]),
3333 mask: mask_off,
3334 out: node_offset(arena, node.id),
3335 batch: batch as u32,
3336 seq: seq as u32,
3337 kv_seq: kv_seq as u32,
3338 heads: *num_heads as u32,
3339 head_dim: *head_dim as u32,
3340 mask_kind: *mask_kind,
3341 wrt: *wrt,
3342 bhsd,
3343 }
3344 }
3345
3346 Op::FusedAttentionBlock {
3347 num_heads,
3348 head_dim,
3349 has_bias,
3350 has_rope,
3351 } => {
3352 let x_shape = &graph.node(node.inputs[0]).shape;
3353 let (batch, seq) = if x_shape.rank() >= 3 {
3354 (
3355 x_shape.dim(0).unwrap_static(),
3356 x_shape.dim(1).unwrap_static(),
3357 )
3358 } else {
3359 let total = x_shape.num_elements().unwrap();
3360 let s = x_shape.dim(x_shape.rank() - 2).unwrap_static();
3361 (total / (s * num_heads * head_dim), s)
3362 };
3363 let hs = (*num_heads * *head_dim) as u32;
3364 let mut idx = 4;
3366 let (qkv_b_off, out_b_off) = if *has_bias {
3367 let qb = node_offset(arena, node.inputs[idx]);
3368 let ob = node_offset(arena, node.inputs[idx + 1]);
3369 idx += 2;
3370 (qb, ob)
3371 } else {
3372 (0, 0)
3373 };
3374 let (cos_off, sin_off, cl) = if *has_rope {
3375 let c = node_offset(arena, node.inputs[idx]);
3376 let s = node_offset(arena, node.inputs[idx + 1]);
3377 let clen = get_len(graph, node.inputs[idx]);
3378 (c, s, clen as u32)
3379 } else {
3380 (0, 0, 0)
3381 };
3382
3383 Thunk::FusedAttnBlock {
3384 hidden: node_offset(arena, node.inputs[0]),
3385 qkv_w: node_offset(arena, node.inputs[1]),
3386 out_w: node_offset(arena, node.inputs[2]),
3387 mask: node_offset(arena, node.inputs[3]),
3388 out: node_offset(arena, node.id),
3389 qkv_b: qkv_b_off,
3390 out_b: out_b_off,
3391 cos: cos_off,
3392 sin: sin_off,
3393 cos_len: cl,
3394 batch: batch as u32,
3395 seq: seq as u32,
3396 hs,
3397 nh: *num_heads as u32,
3398 dh: *head_dim as u32,
3399 has_bias: *has_bias,
3400 has_rope: *has_rope,
3401 }
3402 }
3403
3404 Op::Rope { head_dim, n_rot } => {
3405 let x_shape = &graph.node(node.inputs[0]).shape;
3406 let (batch, seq, hidden) = if x_shape.rank() >= 3 {
3407 (
3408 x_shape.dim(0).unwrap_static(),
3409 x_shape.dim(1).unwrap_static(),
3410 x_shape.dim(2).unwrap_static(),
3411 )
3412 } else {
3413 let total = x_shape.num_elements().unwrap();
3414 (
3415 1,
3416 x_shape.dim(0).unwrap_static(),
3417 total / x_shape.dim(0).unwrap_static(),
3418 )
3419 };
3420 let cos_len = get_len(graph, node.inputs[1]);
3421 Thunk::Rope {
3422 src: node_offset(arena, node.inputs[0]),
3423 cos: node_offset(arena, node.inputs[1]),
3424 sin: node_offset(arena, node.inputs[2]),
3425 dst: node_offset(arena, node.id),
3426 batch: batch as u32,
3427 seq: seq as u32,
3428 hidden: hidden as u32,
3429 head_dim: *head_dim as u32,
3430 n_rot: *n_rot as u32,
3431 cos_len: cos_len as u32,
3432 src_row_stride: hidden as u32,
3436 }
3437 }
3438
3439 Op::FusedSwiGLU {
3440 cast_to: _,
3441 gate_first,
3442 } => {
3443 let n_half = node.shape.dim(node.shape.rank() - 1).unwrap_static();
3444 let total = node.shape.num_elements().unwrap();
3445 Thunk::FusedSwiGLU {
3446 src: node_offset(arena, node.inputs[0]),
3447 dst: node_offset(arena, node.id),
3448 n_half: n_half as u32,
3449 total: total as u32,
3450 gate_first: *gate_first,
3451 }
3452 }
3453
3454 Op::Conv {
3455 kernel_size,
3456 stride,
3457 padding,
3458 dilation,
3459 groups,
3460 } => {
3461 let in_shape = &graph.node(node.inputs[0]).shape;
3462 let w_shape = &graph.node(node.inputs[1]).shape;
3463 let out_shape = &node.shape;
3464 let is_1x1_simple = kernel_size.len() == 2
3468 && kernel_size[0] == 1
3469 && kernel_size[1] == 1
3470 && stride.iter().all(|&s| s == 1)
3471 && padding.iter().all(|&p| p == 0)
3472 && dilation.iter().all(|&d| d == 1)
3473 && *groups == 1;
3474 if is_1x1_simple && in_shape.rank() == 4 && out_shape.rank() == 4 {
3475 let n = in_shape.dim(0).unwrap_static();
3476 let c_in = in_shape.dim(1).unwrap_static();
3477 let c_out = out_shape.dim(1).unwrap_static();
3478 let h = in_shape.dim(2).unwrap_static();
3479 let w = in_shape.dim(3).unwrap_static();
3480 Thunk::Conv2D1x1 {
3481 src: node_offset(arena, node.inputs[0]),
3482 weight: node_offset(arena, node.inputs[1]),
3483 dst: node_offset(arena, node.id),
3484 n: n as u32,
3485 c_in: c_in as u32,
3486 c_out: c_out as u32,
3487 hw: (h * w) as u32,
3488 }
3489 } else if kernel_size.len() == 2
3490 && in_shape.rank() == 4
3491 && w_shape.rank() == 4
3492 && out_shape.rank() == 4
3493 {
3494 Thunk::Conv2D {
3495 src: node_offset(arena, node.inputs[0]),
3496 weight: node_offset(arena, node.inputs[1]),
3497 dst: node_offset(arena, node.id),
3498 n: in_shape.dim(0).unwrap_static() as u32,
3499 c_in: in_shape.dim(1).unwrap_static() as u32,
3500 h: in_shape.dim(2).unwrap_static() as u32,
3501 w: in_shape.dim(3).unwrap_static() as u32,
3502 c_out: out_shape.dim(1).unwrap_static() as u32,
3503 h_out: out_shape.dim(2).unwrap_static() as u32,
3504 w_out: out_shape.dim(3).unwrap_static() as u32,
3505 kh: kernel_size[0] as u32,
3506 kw: kernel_size[1] as u32,
3507 sh: stride.first().copied().unwrap_or(1) as u32,
3508 sw: stride.get(1).copied().unwrap_or(1) as u32,
3509 ph: padding.first().copied().unwrap_or(0) as u32,
3510 pw: padding.get(1).copied().unwrap_or(0) as u32,
3511 dh: dilation.first().copied().unwrap_or(1) as u32,
3512 dw: dilation.get(1).copied().unwrap_or(1) as u32,
3513 groups: *groups as u32,
3514 }
3515 } else {
3516 Thunk::Nop
3517 }
3518 }
3519
3520 Op::Pool {
3521 kind,
3522 kernel_size,
3523 stride,
3524 padding,
3525 } => {
3526 let in_shape = &graph.node(node.inputs[0]).shape;
3528 let out_shape = &node.shape;
3529 if kernel_size.len() == 2 && in_shape.rank() == 4 && out_shape.rank() == 4 {
3530 Thunk::Pool2D {
3531 src: node_offset(arena, node.inputs[0]),
3532 dst: node_offset(arena, node.id),
3533 n: in_shape.dim(0).unwrap_static() as u32,
3534 c: in_shape.dim(1).unwrap_static() as u32,
3535 h: in_shape.dim(2).unwrap_static() as u32,
3536 w: in_shape.dim(3).unwrap_static() as u32,
3537 h_out: out_shape.dim(2).unwrap_static() as u32,
3538 w_out: out_shape.dim(3).unwrap_static() as u32,
3539 kh: kernel_size[0] as u32,
3540 kw: kernel_size[1] as u32,
3541 sh: stride.first().copied().unwrap_or(1) as u32,
3542 sw: stride.get(1).copied().unwrap_or(1) as u32,
3543 ph: padding.first().copied().unwrap_or(0) as u32,
3544 pw: padding.get(1).copied().unwrap_or(0) as u32,
3545 kind: *kind,
3546 }
3547 } else {
3548 Thunk::Nop
3549 }
3550 }
3551
3552 Op::Transpose { perm } => {
3553 let in_shape = &graph.node(node.inputs[0]).shape;
3556 let in_rank = in_shape.rank();
3557 let in_dims: Vec<usize> = (0..in_rank)
3558 .map(|i| in_shape.dim(i).unwrap_static())
3559 .collect();
3560 let mut in_strides_full = vec![1usize; in_rank];
3562 for d in (0..in_rank.saturating_sub(1)).rev() {
3563 in_strides_full[d] = in_strides_full[d + 1] * in_dims[d + 1];
3564 }
3565 let out_dims: Vec<u32> = perm.iter().map(|&p| in_dims[p] as u32).collect();
3566 let in_strides: Vec<u32> =
3567 perm.iter().map(|&p| in_strides_full[p] as u32).collect();
3568 let in_total = in_dims.iter().product::<usize>() as u32;
3569 let src = node_offset(arena, node.inputs[0]);
3570 let dst = node_offset(arena, node.id);
3571 match node.shape.dtype() {
3572 rlx_ir::DType::F64 => Thunk::TransposeF64 {
3573 src,
3574 dst,
3575 in_total,
3576 out_dims,
3577 in_strides,
3578 },
3579 _ => Thunk::Transpose {
3580 src,
3581 dst,
3582 in_total,
3583 out_dims,
3584 in_strides,
3585 },
3586 }
3587 }
3588
3589 Op::ScatterAdd => {
3590 let upd_shape = &graph.node(node.inputs[0]).shape;
3593 let out_shape = &node.shape;
3594 let num_updates = upd_shape.dim(0).unwrap_static();
3595 let out_dim = out_shape.dim(0).unwrap_static();
3596 let trailing: usize = (1..out_shape.rank())
3597 .map(|i| out_shape.dim(i).unwrap_static())
3598 .product::<usize>()
3599 .max(1);
3600 Thunk::ScatterAdd {
3601 updates: node_offset(arena, node.inputs[0]),
3602 indices: node_offset(arena, node.inputs[1]),
3603 dst: node_offset(arena, node.id),
3604 num_updates: num_updates as u32,
3605 out_dim: out_dim as u32,
3606 trailing: trailing as u32,
3607 }
3608 }
3609
3610 Op::GroupedMatMul => {
3611 let in_shape = &graph.node(node.inputs[0]).shape;
3613 let w_shape = &graph.node(node.inputs[1]).shape;
3614 let m = in_shape.dim(in_shape.rank() - 2).unwrap_static();
3615 let k_dim = in_shape.dim(in_shape.rank() - 1).unwrap_static();
3616 let num_experts = w_shape.dim(0).unwrap_static();
3617 let n = w_shape.dim(2).unwrap_static();
3618 Thunk::GroupedMatMul {
3619 input: node_offset(arena, node.inputs[0]),
3620 weight: node_offset(arena, node.inputs[1]),
3621 expert_idx: node_offset(arena, node.inputs[2]),
3622 dst: node_offset(arena, node.id),
3623 m: m as u32,
3624 k_dim: k_dim as u32,
3625 n: n as u32,
3626 num_experts: num_experts as u32,
3627 }
3628 }
3629
3630 Op::DequantGroupedMatMul { scheme } => {
3631 let in_shape = &graph.node(node.inputs[0]).shape;
3632 let w_shape = &graph.node(node.inputs[1]).shape;
3633 let m = in_shape.dim(in_shape.rank() - 2).unwrap_static();
3634 let k_dim = in_shape.dim(in_shape.rank() - 1).unwrap_static();
3635 let out_shape = &node.shape;
3636 let n = out_shape.dim(out_shape.rank() - 1).unwrap_static();
3637 let block_elems = scheme.gguf_block_size() as usize;
3638 let block_bytes = scheme.gguf_block_bytes() as usize;
3639 let slab_bytes = (k_dim * n) / block_elems * block_bytes;
3640 let total_bytes = w_shape.num_elements().unwrap();
3641 let num_experts = total_bytes / slab_bytes.max(1);
3642 Thunk::DequantGroupedMatMulGguf {
3643 input: node_offset(arena, node.inputs[0]),
3644 w_q: node_offset(arena, node.inputs[1]),
3645 expert_idx: node_offset(arena, node.inputs[2]),
3646 dst: node_offset(arena, node.id),
3647 m: m as u32,
3648 k_dim: k_dim as u32,
3649 n: n as u32,
3650 num_experts: num_experts as u32,
3651 scheme: *scheme,
3652 }
3653 }
3654
3655 Op::DequantMoEWeights { scheme } => {
3656 let w_shape = &graph.node(node.inputs[0]).shape;
3657 let out_shape = &node.shape;
3658 let num_experts = out_shape.dim(0).unwrap_static();
3659 let k_dim = out_shape.dim(1).unwrap_static();
3660 let n = out_shape.dim(2).unwrap_static();
3661 let block_elems = scheme.gguf_block_size() as usize;
3662 let block_bytes = scheme.gguf_block_bytes() as usize;
3663 let slab_bytes = (k_dim * n) / block_elems * block_bytes;
3664 let total_bytes = w_shape.num_elements().unwrap();
3665 assert_eq!(
3666 total_bytes,
3667 num_experts * slab_bytes,
3668 "DequantMoEWeights packed bytes mismatch"
3669 );
3670 Thunk::DequantMoEWeightsGguf {
3671 w_q: node_offset(arena, node.inputs[0]),
3672 dst: node_offset(arena, node.id),
3673 k_dim: k_dim as u32,
3674 n: n as u32,
3675 num_experts: num_experts as u32,
3676 scheme: *scheme,
3677 }
3678 }
3679
3680 Op::TopK { k } => {
3681 let in_shape = &graph.node(node.inputs[0]).shape;
3682 let rank = in_shape.rank();
3683 let axis_dim = in_shape.dim(rank - 1).unwrap_static();
3684 let outer = in_shape.num_elements().unwrap() / axis_dim;
3685 Thunk::TopK {
3686 src: node_offset(arena, node.inputs[0]),
3687 dst: node_offset(arena, node.id),
3688 outer: outer as u32,
3689 axis_dim: axis_dim as u32,
3690 k: *k as u32,
3691 }
3692 }
3693
3694 Op::Reduce {
3695 op,
3696 axes,
3697 keep_dim: _,
3698 } => {
3699 let in_shape = &graph.node(node.inputs[0]).shape;
3705 let rank = in_shape.rank();
3706 let mut sorted = axes.clone();
3707 sorted.sort();
3708 sorted.dedup();
3709 let contiguous = sorted.windows(2).all(|w| w[1] == w[0] + 1)
3710 && !sorted.is_empty()
3711 && *sorted.last().unwrap() < rank;
3712 if !contiguous {
3713 Thunk::Nop
3714 } else {
3715 let first = sorted[0];
3716 let last = *sorted.last().unwrap();
3717 let outer: usize = (0..first)
3718 .map(|i| in_shape.dim(i).unwrap_static())
3719 .product::<usize>()
3720 .max(1);
3721 let reduced: usize = (first..=last)
3722 .map(|i| in_shape.dim(i).unwrap_static())
3723 .product();
3724 let inner: usize = (last + 1..rank)
3725 .map(|i| in_shape.dim(i).unwrap_static())
3726 .product::<usize>()
3727 .max(1);
3728 let src = node_offset(arena, node.inputs[0]);
3729 let dst = node_offset(arena, node.id);
3730 if node.shape.dtype() == rlx_ir::DType::F64 && matches!(op, ReduceOp::Sum) {
3731 Thunk::ReduceSumF64 {
3732 src,
3733 dst,
3734 outer: outer as u32,
3735 reduced: reduced as u32,
3736 inner: inner as u32,
3737 }
3738 } else {
3739 Thunk::Reduce {
3740 src,
3741 dst,
3742 outer: outer as u32,
3743 reduced: reduced as u32,
3744 inner: inner as u32,
3745 op: *op,
3746 }
3747 }
3748 }
3749 }
3750
3751 Op::Compare(cmp) => {
3752 let len = node.shape.num_elements().unwrap();
3753 Thunk::Compare {
3754 lhs: node_offset(arena, node.inputs[0]),
3755 rhs: node_offset(arena, node.inputs[1]),
3756 dst: node_offset(arena, node.id),
3757 len: len as u32,
3758 op: *cmp,
3759 }
3760 }
3761
3762 Op::Where => {
3763 let len = node.shape.num_elements().unwrap();
3764 Thunk::Where {
3765 cond: node_offset(arena, node.inputs[0]),
3766 on_true: node_offset(arena, node.inputs[1]),
3767 on_false: node_offset(arena, node.inputs[2]),
3768 dst: node_offset(arena, node.id),
3769 len: len as u32,
3770 }
3771 }
3772
3773 Op::ReluBackward => {
3774 let len: usize = (0..node.shape.rank())
3775 .map(|i| node.shape.dim(i).unwrap_static())
3776 .product();
3777 let x = node_offset(arena, node.inputs[0]);
3778 let dy = node_offset(arena, node.inputs[1]);
3779 let dx = node_offset(arena, node.id);
3780 match node.shape.dtype() {
3781 rlx_ir::DType::F64 => Thunk::ReluBackwardF64 {
3782 x,
3783 dy,
3784 dx,
3785 len: len as u32,
3786 },
3787 _ => Thunk::ReluBackward {
3788 x,
3789 dy,
3790 dx,
3791 len: len as u32,
3792 },
3793 }
3794 }
3795
3796 Op::ComplexNormSq => {
3797 let len: usize = (0..node.shape.rank())
3798 .map(|i| node.shape.dim(i).unwrap_static())
3799 .product();
3800 let src = node_offset(arena, node.inputs[0]);
3801 let dst = node_offset(arena, node.id);
3802 Thunk::ComplexNormSqF32 {
3803 src,
3804 dst,
3805 len: len as u32,
3806 }
3807 }
3808
3809 Op::ComplexNormSqBackward => {
3810 let len: usize = (0..node.shape.rank())
3811 .map(|i| node.shape.dim(i).unwrap_static())
3812 .product();
3813 let z = node_offset(arena, node.inputs[0]);
3814 let g = node_offset(arena, node.inputs[1]);
3815 let dz = node_offset(arena, node.id);
3816 Thunk::ComplexNormSqBackwardF32 {
3817 z,
3818 g,
3819 dz,
3820 len: len as u32,
3821 }
3822 }
3823
3824 Op::Conjugate => {
3825 let len: usize = (0..node.shape.rank())
3826 .map(|i| node.shape.dim(i).unwrap_static())
3827 .product();
3828 Thunk::ConjugateC64 {
3829 src: node_offset(arena, node.inputs[0]),
3830 dst: node_offset(arena, node.id),
3831 len: len as u32,
3832 }
3833 }
3834
3835 Op::ActivationBackward { kind } => {
3836 let len: usize = (0..node.shape.rank())
3837 .map(|i| node.shape.dim(i).unwrap_static())
3838 .product();
3839 let x = node_offset(arena, node.inputs[0]);
3840 let dy = node_offset(arena, node.inputs[1]);
3841 let dx = node_offset(arena, node.id);
3842 match node.shape.dtype() {
3843 rlx_ir::DType::F64 => Thunk::ActivationBackwardF64 {
3844 x,
3845 dy,
3846 dx,
3847 len: len as u32,
3848 kind: *kind,
3849 },
3850 _ => Thunk::ActivationBackward {
3851 x,
3852 dy,
3853 dx,
3854 len: len as u32,
3855 kind: *kind,
3856 },
3857 }
3858 }
3859
3860 Op::LayerNormBackwardInput { eps, .. } => {
3861 let h = node.shape.dim(node.shape.rank() - 1).unwrap_static();
3863 let total = node.shape.num_elements().unwrap();
3864 Thunk::LayerNormBackwardInput {
3865 x: node_offset(arena, node.inputs[0]),
3866 gamma: node_offset(arena, node.inputs[1]),
3867 dy: node_offset(arena, node.inputs[2]),
3868 dx: node_offset(arena, node.id),
3869 rows: (total / h) as u32,
3870 h: h as u32,
3871 eps: *eps,
3872 }
3873 }
3874
3875 Op::LayerNormBackwardGamma { eps, .. } => {
3876 let x_shape = &graph.node(node.inputs[0]).shape;
3877 let h = x_shape.dim(x_shape.rank() - 1).unwrap_static();
3878 let x_total = x_shape.num_elements().unwrap();
3879 Thunk::LayerNormBackwardGamma {
3880 x: node_offset(arena, node.inputs[0]),
3881 dy: node_offset(arena, node.inputs[1]),
3882 dgamma: node_offset(arena, node.id),
3883 rows: (x_total / h) as u32,
3884 h: h as u32,
3885 eps: *eps,
3886 }
3887 }
3888
3889 Op::RmsNormBackwardInput { eps, .. }
3890 | Op::RmsNormBackwardGamma { eps, .. }
3891 | Op::RmsNormBackwardBeta { eps, .. } => {
3892 let x_shape = &graph.node(node.inputs[0]).shape;
3893 let h = x_shape.dim(x_shape.rank() - 1).unwrap_static();
3894 let rows = (x_shape.num_elements().unwrap() / h) as u32;
3895 let off = |i: usize| node_offset(arena, node.inputs[i]);
3896 let common = (off(0), off(1), off(2), off(3), rows, h as u32, *eps);
3897 match &node.op {
3898 Op::RmsNormBackwardInput { .. } => Thunk::RmsNormBackwardInput {
3899 x: common.0,
3900 gamma: common.1,
3901 beta: common.2,
3902 dy: common.3,
3903 dx: node_offset(arena, node.id),
3904 rows: common.4,
3905 h: common.5,
3906 eps: common.6,
3907 },
3908 Op::RmsNormBackwardGamma { .. } => Thunk::RmsNormBackwardGamma {
3909 x: common.0,
3910 gamma: common.1,
3911 beta: common.2,
3912 dy: common.3,
3913 dgamma: node_offset(arena, node.id),
3914 rows: common.4,
3915 h: common.5,
3916 eps: common.6,
3917 },
3918 Op::RmsNormBackwardBeta { .. } => Thunk::RmsNormBackwardBeta {
3919 x: common.0,
3920 gamma: common.1,
3921 beta: common.2,
3922 dy: common.3,
3923 dbeta: node_offset(arena, node.id),
3924 rows: common.4,
3925 h: common.5,
3926 eps: common.6,
3927 },
3928 _ => unreachable!(),
3929 }
3930 }
3931
3932 Op::RopeBackward { head_dim, n_rot } => {
3933 let dy_shape = &graph.node(node.inputs[0]).shape;
3934 let (batch, seq, hidden) = if dy_shape.rank() >= 3 {
3935 (
3936 dy_shape.dim(0).unwrap_static(),
3937 dy_shape.dim(1).unwrap_static(),
3938 dy_shape.dim(2).unwrap_static(),
3939 )
3940 } else {
3941 (
3942 1,
3943 dy_shape.dim(0).unwrap_static(),
3944 dy_shape.dim(1).unwrap_static(),
3945 )
3946 };
3947 let cos_shape = &graph.node(node.inputs[1]).shape;
3948 let cos_len = cos_shape.num_elements().unwrap();
3949 Thunk::RopeBackward {
3950 dy: node_offset(arena, node.inputs[0]),
3951 cos: node_offset(arena, node.inputs[1]),
3952 sin: node_offset(arena, node.inputs[2]),
3953 dx: node_offset(arena, node.id),
3954 batch: batch as u32,
3955 seq: seq as u32,
3956 hidden: hidden as u32,
3957 head_dim: *head_dim as u32,
3958 n_rot: *n_rot as u32,
3959 cos_len: cos_len as u32,
3960 }
3961 }
3962
3963 Op::CumsumBackward { exclusive, .. } => {
3964 let dy_shape = &graph.node(node.inputs[0]).shape;
3965 let rank = dy_shape.rank();
3966 let cols = dy_shape.dim(rank - 1).unwrap_static();
3967 let rows = dy_shape.num_elements().unwrap() / cols;
3968 Thunk::CumsumBackward {
3969 dy: node_offset(arena, node.inputs[0]),
3970 dx: node_offset(arena, node.id),
3971 rows: rows as u32,
3972 cols: cols as u32,
3973 exclusive: *exclusive,
3974 }
3975 }
3976
3977 Op::GatherBackward { .. } => {
3978 let dy_shape = &graph.node(node.inputs[0]).shape;
3979 let idx_shape = &graph.node(node.inputs[1]).shape;
3980 let out_shape = &node.shape;
3981 let rank = out_shape.rank();
3982 let axis = match &node.op {
3983 Op::GatherBackward { axis } => *axis,
3984 _ => 0,
3985 };
3986 let axis_u = if axis < 0 {
3987 (rank as i32 + axis) as usize
3988 } else {
3989 axis as usize
3990 };
3991 let outer: usize = (0..axis_u)
3992 .map(|i| dy_shape.dim(i).unwrap_static())
3993 .product::<usize>()
3994 .max(1);
3995 let num_idx = idx_shape.dim(axis_u).unwrap_static();
3996 let trailing: usize = (axis_u + 1..dy_shape.rank())
3997 .map(|i| dy_shape.dim(i).unwrap_static())
3998 .product::<usize>()
3999 .max(1);
4000 let axis_dim = out_shape.dim(axis_u).unwrap_static();
4001 Thunk::GatherBackward {
4002 dy: node_offset(arena, node.inputs[0]),
4003 indices: node_offset(arena, node.inputs[1]),
4004 dst: node_offset(arena, node.id),
4005 outer: outer as u32,
4006 axis_dim: axis_dim as u32,
4007 num_idx: num_idx as u32,
4008 trailing: trailing as u32,
4009 }
4010 }
4011
4012 Op::GroupNormBackwardInput { num_groups, eps }
4013 | Op::GroupNormBackwardGamma { num_groups, eps }
4014 | Op::GroupNormBackwardBeta { num_groups, eps } => {
4015 let x_shape = &graph.node(node.inputs[0]).shape;
4016 let n = x_shape.dim(0).unwrap_static() as u32;
4017 let c = x_shape.dim(1).unwrap_static() as u32;
4018 let h = x_shape.dim(2).unwrap_static() as u32;
4019 let w = x_shape.dim(3).unwrap_static() as u32;
4020 match &node.op {
4021 Op::GroupNormBackwardInput { .. } => Thunk::GroupNormBackwardInput {
4022 x: node_offset(arena, node.inputs[0]),
4023 gamma: node_offset(arena, node.inputs[1]),
4024 beta: node_offset(arena, node.inputs[2]),
4025 dy: node_offset(arena, node.inputs[3]),
4026 dx: node_offset(arena, node.id),
4027 n,
4028 c,
4029 h,
4030 w,
4031 num_groups: *num_groups as u32,
4032 eps: *eps,
4033 },
4034 Op::GroupNormBackwardGamma { .. } => Thunk::GroupNormBackwardGamma {
4035 x: node_offset(arena, node.inputs[0]),
4036 dy: node_offset(arena, node.inputs[1]),
4037 dgamma: node_offset(arena, node.id),
4038 n,
4039 c,
4040 h,
4041 w,
4042 num_groups: *num_groups as u32,
4043 eps: *eps,
4044 },
4045 Op::GroupNormBackwardBeta { .. } => Thunk::GroupNormBackwardBeta {
4046 dy: node_offset(arena, node.inputs[1]),
4047 dbeta: node_offset(arena, node.id),
4048 n,
4049 c,
4050 h,
4051 w,
4052 },
4053 _ => unreachable!(),
4054 }
4055 }
4056
4057 Op::MaxPool2dBackward {
4058 kernel_size,
4059 stride,
4060 padding,
4061 } => {
4062 let x_shape = &graph.node(node.inputs[0]).shape;
4063 let dy_shape = &graph.node(node.inputs[1]).shape;
4064 if kernel_size.len() == 2 && x_shape.rank() == 4 && dy_shape.rank() == 4 {
4065 Thunk::MaxPool2dBackward {
4066 x: node_offset(arena, node.inputs[0]),
4067 dy: node_offset(arena, node.inputs[1]),
4068 dx: node_offset(arena, node.id),
4069 n: x_shape.dim(0).unwrap_static() as u32,
4070 c: x_shape.dim(1).unwrap_static() as u32,
4071 h: x_shape.dim(2).unwrap_static() as u32,
4072 w: x_shape.dim(3).unwrap_static() as u32,
4073 h_out: dy_shape.dim(2).unwrap_static() as u32,
4074 w_out: dy_shape.dim(3).unwrap_static() as u32,
4075 kh: kernel_size[0] as u32,
4076 kw: kernel_size[1] as u32,
4077 sh: stride.first().copied().unwrap_or(1) as u32,
4078 sw: stride.get(1).copied().unwrap_or(1) as u32,
4079 ph: padding.first().copied().unwrap_or(0) as u32,
4080 pw: padding.get(1).copied().unwrap_or(0) as u32,
4081 }
4082 } else {
4083 Thunk::Nop
4084 }
4085 }
4086
4087 Op::Conv2dBackwardInput {
4088 kernel_size,
4089 stride,
4090 padding,
4091 dilation,
4092 groups,
4093 } => {
4094 let dy_shape = &graph.node(node.inputs[0]).shape;
4095 let w_shape = &graph.node(node.inputs[1]).shape;
4096 let out_shape = &node.shape;
4097 if kernel_size.len() == 2
4098 && dy_shape.rank() == 4
4099 && w_shape.rank() == 4
4100 && out_shape.rank() == 4
4101 {
4102 Thunk::Conv2dBackwardInput {
4103 dy: node_offset(arena, node.inputs[0]),
4104 w: node_offset(arena, node.inputs[1]),
4105 dx: node_offset(arena, node.id),
4106 n: out_shape.dim(0).unwrap_static() as u32,
4107 c_in: out_shape.dim(1).unwrap_static() as u32,
4108 h: out_shape.dim(2).unwrap_static() as u32,
4109 w_in: out_shape.dim(3).unwrap_static() as u32,
4110 c_out: dy_shape.dim(1).unwrap_static() as u32,
4111 h_out: dy_shape.dim(2).unwrap_static() as u32,
4112 w_out: dy_shape.dim(3).unwrap_static() as u32,
4113 kh: kernel_size[0] as u32,
4114 kw: kernel_size[1] as u32,
4115 sh: stride.first().copied().unwrap_or(1) as u32,
4116 sw: stride.get(1).copied().unwrap_or(1) as u32,
4117 ph: padding.first().copied().unwrap_or(0) as u32,
4118 pw: padding.get(1).copied().unwrap_or(0) as u32,
4119 dh: dilation.first().copied().unwrap_or(1) as u32,
4120 dw: dilation.get(1).copied().unwrap_or(1) as u32,
4121 groups: *groups as u32,
4122 }
4123 } else {
4124 Thunk::Nop
4125 }
4126 }
4127
4128 Op::Conv2dBackwardWeight {
4129 kernel_size,
4130 stride,
4131 padding,
4132 dilation,
4133 groups,
4134 } => {
4135 let x_shape = &graph.node(node.inputs[0]).shape;
4136 let dy_shape = &graph.node(node.inputs[1]).shape;
4137 let dw_shape = &node.shape;
4138 if kernel_size.len() == 2
4139 && x_shape.rank() == 4
4140 && dy_shape.rank() == 4
4141 && dw_shape.rank() == 4
4142 {
4143 Thunk::Conv2dBackwardWeight {
4144 x: node_offset(arena, node.inputs[0]),
4145 dy: node_offset(arena, node.inputs[1]),
4146 dw: node_offset(arena, node.id),
4147 n: x_shape.dim(0).unwrap_static() as u32,
4148 c_in: x_shape.dim(1).unwrap_static() as u32,
4149 h: x_shape.dim(2).unwrap_static() as u32,
4150 w: x_shape.dim(3).unwrap_static() as u32,
4151 c_out: dy_shape.dim(1).unwrap_static() as u32,
4152 h_out: dy_shape.dim(2).unwrap_static() as u32,
4153 w_out: dy_shape.dim(3).unwrap_static() as u32,
4154 kh: kernel_size[0] as u32,
4155 kw: kernel_size[1] as u32,
4156 sh: stride.first().copied().unwrap_or(1) as u32,
4157 sw: stride.get(1).copied().unwrap_or(1) as u32,
4158 ph: padding.first().copied().unwrap_or(0) as u32,
4159 pw: padding.get(1).copied().unwrap_or(0) as u32,
4160 dh: dilation.first().copied().unwrap_or(1) as u32,
4161 dw_dil: dilation.get(1).copied().unwrap_or(1) as u32,
4162 groups: *groups as u32,
4163 }
4164 } else {
4165 Thunk::Nop
4166 }
4167 }
4168
4169 Op::SoftmaxCrossEntropyWithLogits => {
4170 let logits_shape = &graph.node(node.inputs[0]).shape;
4171 if logits_shape.rank() == 2 {
4172 Thunk::SoftmaxCrossEntropy {
4173 logits: node_offset(arena, node.inputs[0]),
4174 labels: node_offset(arena, node.inputs[1]),
4175 dst: node_offset(arena, node.id),
4176 n: logits_shape.dim(0).unwrap_static() as u32,
4177 c: logits_shape.dim(1).unwrap_static() as u32,
4178 }
4179 } else {
4180 Thunk::Nop
4181 }
4182 }
4183
4184 Op::SoftmaxCrossEntropyBackward => {
4185 let logits_shape = &graph.node(node.inputs[0]).shape;
4186 if logits_shape.rank() == 2 {
4187 Thunk::SoftmaxCrossEntropyBackward {
4188 logits: node_offset(arena, node.inputs[0]),
4189 labels: node_offset(arena, node.inputs[1]),
4190 d_loss: node_offset(arena, node.inputs[2]),
4191 dlogits: node_offset(arena, node.id),
4192 n: logits_shape.dim(0).unwrap_static() as u32,
4193 c: logits_shape.dim(1).unwrap_static() as u32,
4194 }
4195 } else {
4196 Thunk::Nop
4197 }
4198 }
4199
4200 Op::DenseSolve => {
4201 let a_shape = &graph.node(node.inputs[0]).shape;
4203 let n = a_shape.dim(0).unwrap_static();
4204 debug_assert_eq!(
4205 n,
4206 a_shape.dim(1).unwrap_static(),
4207 "DenseSolve: A must be square"
4208 );
4209 let b_elems = node.shape.num_elements().unwrap();
4210 let nrhs = b_elems / n;
4211 match node.shape.dtype() {
4212 rlx_ir::DType::F64 => Thunk::DenseSolveF64 {
4213 a: node_offset(arena, node.inputs[0]),
4214 b: node_offset(arena, node.inputs[1]),
4215 x: node_offset(arena, node.id),
4216 n: n as u32,
4217 nrhs: nrhs as u32,
4218 },
4219 rlx_ir::DType::F32 => Thunk::DenseSolveF32 {
4220 a: node_offset(arena, node.inputs[0]),
4221 b: node_offset(arena, node.inputs[1]),
4222 x: node_offset(arena, node.id),
4223 n: n as u32,
4224 nrhs: nrhs as u32,
4225 },
4226 other => panic!(
4227 "DenseSolve: F32 + F64 lowered; got {other:?}. \
4228 Add another variant when needed."
4229 ),
4230 }
4231 }
4232
4233 Op::BatchedDenseSolve => {
4234 let a_shape = &graph.node(node.inputs[0]).shape;
4236 assert_eq!(a_shape.rank(), 3, "BatchedDenseSolve: A rank must be 3");
4237 let batch = a_shape.dim(0).unwrap_static();
4238 let n = a_shape.dim(1).unwrap_static();
4239 debug_assert_eq!(
4240 n,
4241 a_shape.dim(2).unwrap_static(),
4242 "BatchedDenseSolve: A's last two dims must match"
4243 );
4244 let total = node.shape.num_elements().unwrap();
4245 let nrhs = total / (batch * n);
4246 match node.shape.dtype() {
4247 rlx_ir::DType::F32 => Thunk::BatchedDenseSolveF32 {
4248 a: node_offset(arena, node.inputs[0]),
4249 b: node_offset(arena, node.inputs[1]),
4250 x: node_offset(arena, node.id),
4251 batch: batch as u32,
4252 n: n as u32,
4253 nrhs: nrhs as u32,
4254 },
4255 rlx_ir::DType::F64 => Thunk::BatchedDenseSolveF64 {
4256 a: node_offset(arena, node.inputs[0]),
4257 b: node_offset(arena, node.inputs[1]),
4258 x: node_offset(arena, node.id),
4259 batch: batch as u32,
4260 n: n as u32,
4261 nrhs: nrhs as u32,
4262 },
4263 other => panic!("BatchedDenseSolve: F32 + F64 only, got {other:?}"),
4264 }
4265 }
4266
4267 Op::Scan {
4268 body,
4269 length,
4270 save_trajectory,
4271 num_bcast,
4272 num_xs,
4273 num_checkpoints,
4274 } => {
4275 assert!(
4276 *num_checkpoints == 0 || *num_checkpoints <= *length,
4277 "Op::Scan: num_checkpoints={} must be 0 or ≤ length={}",
4278 *num_checkpoints,
4279 *length
4280 );
4281 if *num_checkpoints != 0 && *num_checkpoints != *length {
4282 assert!(
4283 *save_trajectory,
4284 "Op::Scan: num_checkpoints<length only meaningful when save_trajectory=true"
4285 );
4286 }
4287 let body_plan = rlx_opt::memory::plan_memory(body);
4298 let _body_arena_size = body_plan.arena_size;
4299 let body_offsets: HashMap<NodeId, usize> = body_plan
4302 .assignments
4303 .iter()
4304 .map(|(id, slot)| (*id, slot.offset))
4305 .collect();
4306
4307 let mut body_inputs: Vec<NodeId> = body
4310 .nodes()
4311 .iter()
4312 .filter(|n| matches!(n.op, Op::Input { .. }))
4313 .map(|n| n.id)
4314 .collect();
4315 body_inputs.sort();
4316 let n_body_inputs = body_inputs.len();
4317 let expected = 1 + *num_bcast as usize + *num_xs as usize;
4318 if n_body_inputs != expected {
4319 let names: Vec<String> = body
4320 .nodes()
4321 .iter()
4322 .filter_map(|n| match &n.op {
4323 Op::Input { name } => Some(format!("{}={}", n.id, name)),
4324 _ => None,
4325 })
4326 .collect();
4327 panic!(
4328 "Op::Scan body has {} Op::Input nodes; expected {} \
4329 (1 carry + {} bcast + {} xs). Inputs by NodeId: [{}]",
4330 n_body_inputs,
4331 expected,
4332 *num_bcast,
4333 *num_xs,
4334 names.join(", ")
4335 );
4336 }
4337
4338 let body_input_id = body_inputs[0];
4339 let body_input_off = body_offsets[&body_input_id];
4340 let body_output_id = body
4341 .outputs
4342 .first()
4343 .copied()
4344 .expect("Op::Scan body must declare one output");
4345 let body_output_off = body_offsets[&body_output_id];
4346
4347 let mut body_arena = crate::arena::Arena::from_plan(body_plan);
4348 for n in body.nodes() {
4351 if let Op::Constant { data } = &n.op
4352 && body_arena.has_buffer(n.id)
4353 && !data.is_empty()
4354 {
4355 match n.shape.dtype() {
4356 rlx_ir::DType::F64 => {
4357 let off = body_arena.byte_offset(n.id);
4358 let buf = body_arena.raw_buf_mut();
4359 let nbytes = (buf.len() - off).min(data.len());
4360 buf[off..off + nbytes].copy_from_slice(&data[..nbytes]);
4361 }
4362 _ => {
4363 let buf = body_arena.slice_mut(n.id);
4364 let n_floats = data.len() / 4;
4365 let n_lim = buf.len().min(n_floats);
4366 for i in 0..n_lim {
4367 let bytes = [
4368 data[i * 4],
4369 data[i * 4 + 1],
4370 data[i * 4 + 2],
4371 data[i * 4 + 3],
4372 ];
4373 buf[i] = f32::from_le_bytes(bytes);
4374 }
4375 }
4376 }
4377 }
4378 }
4379 let body_init = body_arena.raw_buf().to_vec();
4380 let body_schedule = compile_thunks(body, &body_arena);
4381
4382 let carry_bytes = if *save_trajectory {
4387 let total = node
4388 .shape
4389 .size_bytes()
4390 .expect("Op::Scan trajectory output must have static shape");
4391 total / *length as usize
4392 } else {
4393 node.shape
4394 .size_bytes()
4395 .expect("Op::Scan carry must have static shape")
4396 };
4397
4398 let mut bcast_inputs: Vec<(usize, usize, u32)> =
4403 Vec::with_capacity(*num_bcast as usize);
4404 for i in 0..*num_bcast as usize {
4405 let body_b_id = body_inputs[1 + i];
4406 let body_b_off = body_offsets[&body_b_id];
4407 let outer_b_id = node.inputs[1 + i];
4408 let outer_b_off = node_offset(arena, outer_b_id);
4409 let outer_b_shape = &graph.node(outer_b_id).shape;
4410 let total = outer_b_shape
4411 .size_bytes()
4412 .expect("Op::Scan bcast must have static shape");
4413 bcast_inputs.push((body_b_off, outer_b_off, total as u32));
4414 }
4415
4416 let mut xs_inputs: Vec<(usize, usize, u32)> = Vec::with_capacity(*num_xs as usize);
4420 let xs_base = 1 + *num_bcast as usize;
4421 for i in 0..*num_xs as usize {
4422 let body_x_id = body_inputs[xs_base + i];
4423 let body_x_off = body_offsets[&body_x_id];
4424 let outer_xs_id = node.inputs[xs_base + i];
4425 let outer_xs_off = node_offset(arena, outer_xs_id);
4426 let outer_xs_shape = &graph.node(outer_xs_id).shape;
4427 let total = outer_xs_shape
4428 .size_bytes()
4429 .expect("Op::Scan xs must have static shape");
4430 let per_step = total / *length as usize;
4431 xs_inputs.push((body_x_off, outer_xs_off, per_step as u32));
4432 }
4433
4434 Thunk::Scan {
4435 body: Arc::new(body_schedule),
4436 body_init: Arc::new(body_init),
4437 body_input_off,
4438 body_output_off,
4439 outer_init_off: node_offset(arena, node.inputs[0]),
4440 outer_final_off: node_offset(arena, node.id),
4441 length: *length,
4442 carry_bytes: carry_bytes as u32,
4443 save_trajectory: *save_trajectory,
4444 xs_inputs: Arc::new(xs_inputs),
4445 bcast_inputs: Arc::new(bcast_inputs),
4446 num_checkpoints: *num_checkpoints,
4447 }
4448 }
4449
4450 Op::ScanBackward {
4451 body_vjp,
4452 length,
4453 save_trajectory,
4454 num_xs,
4455 num_checkpoints,
4456 forward_body,
4457 } => {
4458 let is_recursive = *num_checkpoints != 0 && *num_checkpoints != *length;
4459 if is_recursive {
4460 assert!(
4461 forward_body.is_some(),
4462 "Op::ScanBackward with num_checkpoints<length requires forward_body"
4463 );
4464 }
4465 let body_plan = rlx_opt::memory::plan_memory(body_vjp);
4473 let body_offsets: HashMap<NodeId, usize> = body_plan
4474 .assignments
4475 .iter()
4476 .map(|(id, slot)| (*id, slot.offset))
4477 .collect();
4478 let mut body_d_output_off: Option<usize> = None;
4479 let mut body_other_inputs: Vec<(NodeId, usize)> = Vec::new();
4480 for n in body_vjp.nodes() {
4481 if let Op::Input { name } = &n.op {
4482 let off = body_offsets[&n.id];
4483 if name == "d_output" {
4484 body_d_output_off = Some(off);
4485 } else {
4486 body_other_inputs.push((n.id, off));
4487 }
4488 }
4489 }
4490 body_other_inputs.sort_by_key(|(id, _)| *id);
4491 let body_d_output_off =
4492 body_d_output_off.expect("ScanBackward body_vjp missing 'd_output' Input");
4493 let expected_others = 1 + *num_xs as usize;
4494 assert_eq!(
4495 body_other_inputs.len(),
4496 expected_others,
4497 "ScanBackward body_vjp has {} non-d_output Inputs; \
4498 expected {} (1 carry + {} xs)",
4499 body_other_inputs.len(),
4500 expected_others,
4501 num_xs
4502 );
4503 let body_carry_in_off = body_other_inputs[0].1;
4504 let body_x_offs: Vec<usize> = body_other_inputs
4505 .iter()
4506 .skip(1)
4507 .map(|(_, off)| *off)
4508 .collect();
4509 let body_dcarry_out_off = body_offsets[&body_vjp.outputs[0]];
4510
4511 let mut body_arena = crate::arena::Arena::from_plan(body_plan);
4512 for n in body_vjp.nodes() {
4514 if let Op::Constant { data } = &n.op
4515 && body_arena.has_buffer(n.id)
4516 && !data.is_empty()
4517 {
4518 match n.shape.dtype() {
4519 rlx_ir::DType::F64 => {
4520 let off = body_arena.byte_offset(n.id);
4521 let buf = body_arena.raw_buf_mut();
4522 let nb = (buf.len() - off).min(data.len());
4523 buf[off..off + nb].copy_from_slice(&data[..nb]);
4524 }
4525 _ => {
4526 let buf = body_arena.slice_mut(n.id);
4527 let nf = data.len() / 4;
4528 let nl = buf.len().min(nf);
4529 for i in 0..nl {
4530 let bytes = [
4531 data[i * 4],
4532 data[i * 4 + 1],
4533 data[i * 4 + 2],
4534 data[i * 4 + 3],
4535 ];
4536 buf[i] = f32::from_le_bytes(bytes);
4537 }
4538 }
4539 }
4540 }
4541 }
4542 let body_init = body_arena.raw_buf().to_vec();
4543 let body_schedule = compile_thunks(body_vjp, &body_arena);
4544
4545 let carry_bytes = body_vjp
4547 .node(body_vjp.outputs[0])
4548 .shape
4549 .size_bytes()
4550 .expect("ScanBackward dcarry must be statically shaped");
4551 let carry_elem_size = body_vjp
4552 .node(body_vjp.outputs[0])
4553 .shape
4554 .dtype()
4555 .size_bytes() as u32;
4556
4557 let mut outer_xs_offs: Vec<(usize, u32)> = Vec::with_capacity(*num_xs as usize);
4560 for i in 0..*num_xs as usize {
4561 let outer_xs_id = node.inputs[3 + i];
4562 let outer_xs_off = node_offset(arena, outer_xs_id);
4563 let outer_xs_shape = &graph.node(outer_xs_id).shape;
4564 let total = outer_xs_shape
4565 .size_bytes()
4566 .expect("ScanBackward xs must have static shape");
4567 let per_step = total / *length as usize;
4568 outer_xs_offs.push((outer_xs_off, per_step as u32));
4569 }
4570
4571 let (fb_schedule, fb_init, fb_carry_in_off, fb_output_off, fb_x_offs) =
4576 if is_recursive {
4577 let fb = forward_body.as_ref().unwrap();
4578 let fb_plan = rlx_opt::memory::plan_memory(fb);
4579 let fb_offsets: HashMap<NodeId, usize> = fb_plan
4580 .assignments
4581 .iter()
4582 .map(|(id, slot)| (*id, slot.offset))
4583 .collect();
4584 let mut fb_inputs: Vec<NodeId> = fb
4585 .nodes()
4586 .iter()
4587 .filter(|n| matches!(n.op, Op::Input { .. }))
4588 .map(|n| n.id)
4589 .collect();
4590 fb_inputs.sort();
4591 let fb_carry = fb_offsets[&fb_inputs[0]];
4592 let fb_xs: Vec<usize> = (1..fb_inputs.len())
4593 .map(|i| fb_offsets[&fb_inputs[i]])
4594 .collect();
4595 let fb_out = fb_offsets[&fb.outputs[0]];
4596 let mut fb_arena = crate::arena::Arena::from_plan(fb_plan);
4597 for n in fb.nodes() {
4598 if let Op::Constant { data } = &n.op
4599 && fb_arena.has_buffer(n.id)
4600 && !data.is_empty()
4601 {
4602 let off = fb_arena.byte_offset(n.id);
4609 let buf = fb_arena.raw_buf_mut();
4610 let nb = (buf.len() - off).min(data.len());
4611 buf[off..off + nb].copy_from_slice(&data[..nb]);
4612 }
4613 }
4614 let fb_init_bytes = fb_arena.raw_buf().to_vec();
4615 let fb_sched = compile_thunks(fb, &fb_arena);
4616 (
4617 Some(Arc::new(fb_sched)),
4618 Some(Arc::new(fb_init_bytes)),
4619 fb_carry,
4620 fb_out,
4621 fb_xs,
4622 )
4623 } else {
4624 (None, None, 0, 0, Vec::new())
4625 };
4626
4627 Thunk::ScanBackward {
4628 body_vjp: Arc::new(body_schedule),
4629 body_init: Arc::new(body_init),
4630 body_carry_in_off,
4631 body_x_offs: Arc::new(body_x_offs),
4632 body_d_output_off,
4633 body_dcarry_out_off,
4634 outer_init_off: node_offset(arena, node.inputs[0]),
4635 outer_traj_off: node_offset(arena, node.inputs[1]),
4636 outer_upstream_off: node_offset(arena, node.inputs[2]),
4637 outer_xs_offs: Arc::new(outer_xs_offs),
4638 outer_dinit_off: node_offset(arena, node.id),
4639 length: *length,
4640 carry_bytes: carry_bytes as u32,
4641 carry_elem_size,
4642 save_trajectory: *save_trajectory,
4643 num_checkpoints: *num_checkpoints,
4644 forward_body: fb_schedule,
4645 forward_body_init: fb_init,
4646 forward_body_carry_in_off: fb_carry_in_off,
4647 forward_body_output_off: fb_output_off,
4648 forward_body_x_offs: Arc::new(fb_x_offs),
4649 }
4650 }
4651
4652 Op::ScanBackwardXs {
4653 body_vjp,
4654 length,
4655 save_trajectory,
4656 num_xs,
4657 xs_idx,
4658 num_checkpoints,
4659 forward_body,
4660 } => {
4661 assert!(
4662 *num_checkpoints == 0 || *num_checkpoints <= *length,
4663 "Op::ScanBackwardXs: num_checkpoints={} must be 0 or ≤ length={}",
4664 *num_checkpoints,
4665 *length
4666 );
4667 let is_recursive = *num_checkpoints != 0 && *num_checkpoints != *length;
4668 if is_recursive {
4669 assert!(
4670 forward_body.is_some(),
4671 "Op::ScanBackwardXs with num_checkpoints<length \
4672 requires forward_body"
4673 );
4674 }
4675 let body_plan = rlx_opt::memory::plan_memory(body_vjp);
4683 let body_offsets: HashMap<NodeId, usize> = body_plan
4684 .assignments
4685 .iter()
4686 .map(|(id, slot)| (*id, slot.offset))
4687 .collect();
4688 let mut body_d_output_off: Option<usize> = None;
4689 let mut body_other_inputs: Vec<(NodeId, usize)> = Vec::new();
4690 for n in body_vjp.nodes() {
4691 if let Op::Input { name } = &n.op {
4692 let off = body_offsets[&n.id];
4693 if name == "d_output" {
4694 body_d_output_off = Some(off);
4695 } else {
4696 body_other_inputs.push((n.id, off));
4697 }
4698 }
4699 }
4700 body_other_inputs.sort_by_key(|(id, _)| *id);
4701 let body_d_output_off =
4702 body_d_output_off.expect("ScanBackwardXs body_vjp missing 'd_output' Input");
4703 let expected_others = 1 + *num_xs as usize;
4704 assert_eq!(
4705 body_other_inputs.len(),
4706 expected_others,
4707 "ScanBackwardXs body_vjp has {} non-d_output Inputs; expected {}",
4708 body_other_inputs.len(),
4709 expected_others
4710 );
4711 let body_carry_in_off = body_other_inputs[0].1;
4712 let body_x_offs: Vec<usize> = body_other_inputs
4713 .iter()
4714 .skip(1)
4715 .map(|(_, off)| *off)
4716 .collect();
4717 let body_dcarry_out_off = body_offsets[&body_vjp.outputs[0]];
4718 let dxs_out_node = body_vjp.outputs[1 + *xs_idx as usize];
4719 let body_dxs_out_off = body_offsets[&dxs_out_node];
4720
4721 let mut body_arena = crate::arena::Arena::from_plan(body_plan);
4722 for n in body_vjp.nodes() {
4723 if let Op::Constant { data } = &n.op
4724 && body_arena.has_buffer(n.id)
4725 && !data.is_empty()
4726 {
4727 match n.shape.dtype() {
4728 rlx_ir::DType::F64 => {
4729 let off = body_arena.byte_offset(n.id);
4730 let buf = body_arena.raw_buf_mut();
4731 let nb = (buf.len() - off).min(data.len());
4732 buf[off..off + nb].copy_from_slice(&data[..nb]);
4733 }
4734 _ => {
4735 let buf = body_arena.slice_mut(n.id);
4736 let nf = data.len() / 4;
4737 let nl = buf.len().min(nf);
4738 for i in 0..nl {
4739 let bytes = [
4740 data[i * 4],
4741 data[i * 4 + 1],
4742 data[i * 4 + 2],
4743 data[i * 4 + 3],
4744 ];
4745 buf[i] = f32::from_le_bytes(bytes);
4746 }
4747 }
4748 }
4749 }
4750 }
4751 let body_init = body_arena.raw_buf().to_vec();
4752 let body_schedule = compile_thunks(body_vjp, &body_arena);
4753
4754 let carry_bytes = body_vjp
4755 .node(body_vjp.outputs[0])
4756 .shape
4757 .size_bytes()
4758 .expect("ScanBackwardXs dcarry must be statically shaped");
4759 let carry_elem_size = body_vjp
4760 .node(body_vjp.outputs[0])
4761 .shape
4762 .dtype()
4763 .size_bytes() as u32;
4764 let per_step_bytes = body_vjp
4765 .node(dxs_out_node)
4766 .shape
4767 .size_bytes()
4768 .expect("ScanBackwardXs dxs body output must be statically shaped");
4769
4770 let mut outer_xs_offs: Vec<(usize, u32)> = Vec::with_capacity(*num_xs as usize);
4771 for i in 0..*num_xs as usize {
4772 let outer_xs_id = node.inputs[3 + i];
4773 let outer_xs_off = node_offset(arena, outer_xs_id);
4774 let outer_xs_shape = &graph.node(outer_xs_id).shape;
4775 let total = outer_xs_shape
4776 .size_bytes()
4777 .expect("ScanBackwardXs xs must have static shape");
4778 let per_step = total / *length as usize;
4779 outer_xs_offs.push((outer_xs_off, per_step as u32));
4780 }
4781
4782 let (fb_schedule, fb_init, fb_carry_in_off, fb_output_off, fb_x_offs) =
4785 if is_recursive {
4786 let fb = forward_body.as_ref().unwrap();
4787 let fb_plan = rlx_opt::memory::plan_memory(fb);
4788 let fb_offsets: HashMap<NodeId, usize> = fb_plan
4789 .assignments
4790 .iter()
4791 .map(|(id, slot)| (*id, slot.offset))
4792 .collect();
4793 let mut fb_inputs: Vec<NodeId> = fb
4794 .nodes()
4795 .iter()
4796 .filter(|n| matches!(n.op, Op::Input { .. }))
4797 .map(|n| n.id)
4798 .collect();
4799 fb_inputs.sort();
4800 let fb_carry = fb_offsets[&fb_inputs[0]];
4801 let fb_xs: Vec<usize> = (1..fb_inputs.len())
4802 .map(|i| fb_offsets[&fb_inputs[i]])
4803 .collect();
4804 let fb_out = fb_offsets[&fb.outputs[0]];
4805 let mut fb_arena = crate::arena::Arena::from_plan(fb_plan);
4806 for n in fb.nodes() {
4807 if let Op::Constant { data } = &n.op
4808 && fb_arena.has_buffer(n.id)
4809 && !data.is_empty()
4810 {
4811 let off = fb_arena.byte_offset(n.id);
4818 let buf = fb_arena.raw_buf_mut();
4819 let nb = (buf.len() - off).min(data.len());
4820 buf[off..off + nb].copy_from_slice(&data[..nb]);
4821 }
4822 }
4823 let fb_init_bytes = fb_arena.raw_buf().to_vec();
4824 let fb_sched = compile_thunks(fb, &fb_arena);
4825 (
4826 Some(Arc::new(fb_sched)),
4827 Some(Arc::new(fb_init_bytes)),
4828 fb_carry,
4829 fb_out,
4830 fb_xs,
4831 )
4832 } else {
4833 (None, None, 0, 0, Vec::new())
4834 };
4835
4836 Thunk::ScanBackwardXs {
4837 body_vjp: Arc::new(body_schedule),
4838 body_init: Arc::new(body_init),
4839 body_carry_in_off,
4840 body_x_offs: Arc::new(body_x_offs),
4841 body_d_output_off,
4842 body_dcarry_out_off,
4843 body_dxs_out_off,
4844 outer_init_off: node_offset(arena, node.inputs[0]),
4845 outer_traj_off: node_offset(arena, node.inputs[1]),
4846 outer_upstream_off: node_offset(arena, node.inputs[2]),
4847 outer_xs_offs: Arc::new(outer_xs_offs),
4848 outer_dxs_off: node_offset(arena, node.id),
4849 length: *length,
4850 carry_bytes: carry_bytes as u32,
4851 carry_elem_size,
4852 per_step_bytes: per_step_bytes as u32,
4853 save_trajectory: *save_trajectory,
4854 num_checkpoints: *num_checkpoints,
4855 forward_body: fb_schedule,
4856 forward_body_init: fb_init,
4857 forward_body_carry_in_off: fb_carry_in_off,
4858 forward_body_output_off: fb_output_off,
4859 forward_body_x_offs: Arc::new(fb_x_offs),
4860 }
4861 }
4862
4863 Op::Concat { axis } => {
4864 let out_shape = &node.shape;
4868 let rank = out_shape.rank();
4869 let outer: usize = (0..*axis)
4870 .map(|i| out_shape.dim(i).unwrap_static())
4871 .product::<usize>()
4872 .max(1);
4873 let inner: usize = (*axis + 1..rank)
4874 .map(|i| out_shape.dim(i).unwrap_static())
4875 .product::<usize>()
4876 .max(1);
4877 let total_axis = out_shape.dim(*axis).unwrap_static();
4878 let inputs: Vec<(usize, u32)> = node
4879 .inputs
4880 .iter()
4881 .map(|&in_id| {
4882 let in_shape = &graph.node(in_id).shape;
4883 let in_axis = in_shape.dim(*axis).unwrap_static();
4884 (node_offset(arena, in_id), in_axis as u32)
4885 })
4886 .collect();
4887 let dst = node_offset(arena, node.id);
4888 match out_shape.dtype() {
4889 rlx_ir::DType::F64 => Thunk::ConcatF64 {
4890 dst,
4891 outer: outer as u32,
4892 inner: inner as u32,
4893 total_axis: total_axis as u32,
4894 inputs,
4895 },
4896 _ => Thunk::Concat {
4897 dst,
4898 outer: outer as u32,
4899 inner: inner as u32,
4900 total_axis: total_axis as u32,
4901 inputs,
4902 },
4903 }
4904 }
4905
4906 Op::GaussianSplatRender {
4907 width,
4908 height,
4909 tile_size,
4910 radius_scale,
4911 alpha_cutoff,
4912 max_splat_steps,
4913 transmittance_threshold,
4914 max_list_entries,
4915 } => {
4916 let elem_len =
4917 |id: NodeId| -> usize { graph.node(id).shape.num_elements().unwrap_or(0) };
4918 Thunk::GaussianSplatRender {
4919 positions_off: node_offset(arena, node.inputs[0]),
4920 positions_len: elem_len(node.inputs[0]),
4921 scales_off: node_offset(arena, node.inputs[1]),
4922 scales_len: elem_len(node.inputs[1]),
4923 rotations_off: node_offset(arena, node.inputs[2]),
4924 rotations_len: elem_len(node.inputs[2]),
4925 opacities_off: node_offset(arena, node.inputs[3]),
4926 opacities_len: elem_len(node.inputs[3]),
4927 colors_off: node_offset(arena, node.inputs[4]),
4928 colors_len: elem_len(node.inputs[4]),
4929 sh_coeffs_off: node_offset(arena, node.inputs[5]),
4930 sh_coeffs_len: elem_len(node.inputs[5]),
4931 meta_off: node_offset(arena, node.inputs[6]),
4932 dst_off: node_offset(arena, node.id),
4933 dst_len: node.shape.num_elements().unwrap_or(0),
4934 width: *width,
4935 height: *height,
4936 tile_size: *tile_size,
4937 radius_scale: *radius_scale,
4938 alpha_cutoff: *alpha_cutoff,
4939 max_splat_steps: *max_splat_steps,
4940 transmittance_threshold: *transmittance_threshold,
4941 max_list_entries: *max_list_entries,
4942 }
4943 }
4944
4945 Op::GaussianSplatRenderBackward {
4946 width,
4947 height,
4948 tile_size,
4949 radius_scale,
4950 alpha_cutoff,
4951 max_splat_steps,
4952 transmittance_threshold,
4953 max_list_entries,
4954 loss_grad_clip,
4955 sh_band,
4956 max_anisotropy,
4957 } => {
4958 let elem_len =
4959 |id: NodeId| -> usize { graph.node(id).shape.num_elements().unwrap_or(0) };
4960 Thunk::GaussianSplatRenderBackward {
4961 positions_off: node_offset(arena, node.inputs[0]),
4962 positions_len: elem_len(node.inputs[0]),
4963 scales_off: node_offset(arena, node.inputs[1]),
4964 scales_len: elem_len(node.inputs[1]),
4965 rotations_off: node_offset(arena, node.inputs[2]),
4966 rotations_len: elem_len(node.inputs[2]),
4967 opacities_off: node_offset(arena, node.inputs[3]),
4968 opacities_len: elem_len(node.inputs[3]),
4969 colors_off: node_offset(arena, node.inputs[4]),
4970 colors_len: elem_len(node.inputs[4]),
4971 sh_coeffs_off: node_offset(arena, node.inputs[5]),
4972 sh_coeffs_len: elem_len(node.inputs[5]),
4973 meta_off: node_offset(arena, node.inputs[6]),
4974 d_loss_off: node_offset(arena, node.inputs[7]),
4975 d_loss_len: elem_len(node.inputs[7]),
4976 packed_off: node_offset(arena, node.id),
4977 packed_len: node.shape.num_elements().unwrap_or(0),
4978 width: *width,
4979 height: *height,
4980 tile_size: *tile_size,
4981 radius_scale: *radius_scale,
4982 alpha_cutoff: *alpha_cutoff,
4983 max_splat_steps: *max_splat_steps,
4984 transmittance_threshold: *transmittance_threshold,
4985 max_list_entries: *max_list_entries,
4986 loss_grad_clip: *loss_grad_clip,
4987 sh_band: *sh_band,
4988 max_anisotropy: *max_anisotropy,
4989 }
4990 }
4991
4992 Op::GaussianSplatPrepare {
4993 width,
4994 height,
4995 tile_size,
4996 radius_scale,
4997 alpha_cutoff,
4998 max_splat_steps,
4999 transmittance_threshold,
5000 max_list_entries,
5001 } => {
5002 let elem_len =
5003 |id: NodeId| -> usize { graph.node(id).shape.num_elements().unwrap_or(0) };
5004 Thunk::GaussianSplatPrepare {
5005 positions_off: node_offset(arena, node.inputs[0]),
5006 positions_len: elem_len(node.inputs[0]),
5007 scales_off: node_offset(arena, node.inputs[1]),
5008 scales_len: elem_len(node.inputs[1]),
5009 rotations_off: node_offset(arena, node.inputs[2]),
5010 rotations_len: elem_len(node.inputs[2]),
5011 opacities_off: node_offset(arena, node.inputs[3]),
5012 opacities_len: elem_len(node.inputs[3]),
5013 colors_off: node_offset(arena, node.inputs[4]),
5014 colors_len: elem_len(node.inputs[4]),
5015 sh_coeffs_off: node_offset(arena, node.inputs[5]),
5016 sh_coeffs_len: elem_len(node.inputs[5]),
5017 meta_off: node_offset(arena, node.inputs[6]),
5018 meta_len: elem_len(node.inputs[6]),
5019 prep_off: node_offset(arena, node.id),
5020 prep_len: node.shape.num_elements().unwrap_or(0),
5021 width: *width,
5022 height: *height,
5023 tile_size: *tile_size,
5024 radius_scale: *radius_scale,
5025 alpha_cutoff: *alpha_cutoff,
5026 max_splat_steps: *max_splat_steps,
5027 transmittance_threshold: *transmittance_threshold,
5028 max_list_entries: *max_list_entries,
5029 }
5030 }
5031
5032 Op::GaussianSplatRasterize {
5033 width,
5034 height,
5035 tile_size,
5036 alpha_cutoff,
5037 max_splat_steps,
5038 transmittance_threshold,
5039 max_list_entries,
5040 } => {
5041 let elem_len =
5042 |id: NodeId| -> usize { graph.node(id).shape.num_elements().unwrap_or(0) };
5043 let prep_id = node.inputs[0];
5044 let count = match &graph.node(prep_id).op {
5045 rlx_ir::Op::GaussianSplatPrepare { .. } => {
5046 elem_len(graph.node(prep_id).inputs[0]) / 3
5047 }
5048 _ => 1,
5049 };
5050 Thunk::GaussianSplatRasterize {
5051 prep_off: node_offset(arena, prep_id),
5052 prep_len: elem_len(prep_id),
5053 meta_off: node_offset(arena, node.inputs[1]),
5054 meta_len: elem_len(node.inputs[1]),
5055 dst_off: node_offset(arena, node.id),
5056 dst_len: node.shape.num_elements().unwrap_or(0),
5057 count,
5058 width: *width,
5059 height: *height,
5060 tile_size: *tile_size,
5061 alpha_cutoff: *alpha_cutoff,
5062 max_splat_steps: *max_splat_steps,
5063 transmittance_threshold: *transmittance_threshold,
5064 max_list_entries: *max_list_entries,
5065 }
5066 }
5067
5068 Op::Custom { name, attrs, .. } => {
5069 let kernel = crate::op_registry::lookup_cpu_kernel(name).unwrap_or_else(|| {
5070 panic!(
5071 "compile_thunks: no CPU kernel registered for \
5072 Op::Custom('{name}'). Register one via \
5073 rlx_cpu::op_registry::register_cpu_kernel \
5074 before compiling on the CPU backend."
5075 )
5076 });
5077 let inputs_v: Vec<(usize, u32, Shape)> = node
5078 .inputs
5079 .iter()
5080 .map(|&in_id| {
5081 let s = graph.node(in_id).shape.clone();
5082 let len = s.num_elements().unwrap_or(0) as u32;
5083 (node_offset(arena, in_id), len, s)
5084 })
5085 .collect();
5086 let out_len = node.shape.num_elements().unwrap_or(0) as u32;
5087 Thunk::CustomOp {
5088 kernel,
5089 inputs: inputs_v,
5090 output: (node_offset(arena, node.id), out_len, node.shape.clone()),
5091 attrs: attrs.clone(),
5092 }
5093 }
5094
5095 Op::Fft { inverse, norm } => {
5096 let shape = &node.shape;
5097 let meta = rlx_ir::fft::fft_meta(shape);
5098 let dtype = shape.dtype();
5099 assert!(
5100 matches!(
5101 dtype,
5102 rlx_ir::DType::F32 | rlx_ir::DType::F64 | rlx_ir::DType::C64
5103 ),
5104 "Op::Fft on CPU requires F32, F64, or C64, got {dtype:?}"
5105 );
5106 Thunk::Fft1d {
5107 src: node_offset(arena, node.inputs[0]),
5108 dst: node_offset(arena, node.id),
5109 outer: meta.outer as u32,
5110 n_complex: meta.n_complex as u32,
5111 inverse: *inverse,
5112 norm_tag: norm.tag(),
5113 dtype,
5114 }
5115 }
5116
5117 Op::CustomFn {
5118 fwd_body,
5119 num_inputs,
5120 ..
5121 } => {
5122 let body_plan = rlx_opt::memory::plan_memory(fwd_body);
5128 let body_offsets: HashMap<NodeId, usize> = body_plan
5129 .assignments
5130 .iter()
5131 .map(|(id, slot)| (*id, slot.offset))
5132 .collect();
5133
5134 let mut body_input_ids: Vec<NodeId> = fwd_body
5135 .nodes()
5136 .iter()
5137 .filter(|n| matches!(n.op, Op::Input { .. }))
5138 .map(|n| n.id)
5139 .collect();
5140 body_input_ids.sort();
5141 assert_eq!(
5142 body_input_ids.len(),
5143 *num_inputs as usize,
5144 "Op::CustomFn fwd_body has {} Op::Input(s); declared num_inputs={}",
5145 body_input_ids.len(),
5146 *num_inputs,
5147 );
5148
5149 let mut body_arena = crate::arena::Arena::from_plan(body_plan);
5150 for n in fwd_body.nodes() {
5151 if let Op::Constant { data } = &n.op
5152 && body_arena.has_buffer(n.id)
5153 && !data.is_empty()
5154 {
5155 match n.shape.dtype() {
5156 rlx_ir::DType::F64 => {
5157 let off = body_arena.byte_offset(n.id);
5158 let buf = body_arena.raw_buf_mut();
5159 let nb = (buf.len() - off).min(data.len());
5160 buf[off..off + nb].copy_from_slice(&data[..nb]);
5161 }
5162 _ => {
5163 let buf = body_arena.slice_mut(n.id);
5164 let nf = data.len() / 4;
5165 let nl = buf.len().min(nf);
5166 for i in 0..nl {
5167 let bytes = [
5168 data[i * 4],
5169 data[i * 4 + 1],
5170 data[i * 4 + 2],
5171 data[i * 4 + 3],
5172 ];
5173 buf[i] = f32::from_le_bytes(bytes);
5174 }
5175 }
5176 }
5177 }
5178 }
5179 let body_init = body_arena.raw_buf().to_vec();
5180 let body_schedule = compile_thunks(fwd_body, &body_arena);
5181
5182 let inputs_v: Vec<(usize, usize, u32)> = (0..*num_inputs as usize)
5184 .map(|i| {
5185 let body_in = body_input_ids[i];
5186 let body_off = body_offsets[&body_in];
5187 let outer_in = node.inputs[i];
5188 let outer_off = node_offset(arena, outer_in);
5189 let bytes = graph
5190 .node(outer_in)
5191 .shape
5192 .size_bytes()
5193 .expect("Op::CustomFn primal input must have static shape");
5194 (body_off, outer_off, bytes as u32)
5195 })
5196 .collect();
5197
5198 let body_output_id = fwd_body
5199 .outputs
5200 .first()
5201 .copied()
5202 .expect("Op::CustomFn fwd_body must declare exactly one output");
5203 let body_output_off = body_offsets[&body_output_id];
5204 let out_bytes = node
5205 .shape
5206 .size_bytes()
5207 .expect("Op::CustomFn output must have static shape");
5208
5209 Thunk::CustomFn {
5210 body: Arc::new(body_schedule),
5211 body_init: Arc::new(body_init),
5212 inputs: Arc::new(inputs_v),
5213 body_output_off,
5214 outer_output_off: node_offset(arena, node.id),
5215 out_bytes: out_bytes as u32,
5216 }
5217 }
5218
5219 _ => Thunk::Nop,
5220 };
5221 thunks.push(t);
5222 }
5223
    // Fetch the global `RuntimeConfig` once and copy three of its settings
    // into locals. NOTE(review): presumably so the closures built below can
    // capture plain values; their uses fall outside this view — confirm.
    let cfg = crate::config::RuntimeConfig::global();
    let mask_thr = cfg.mask_binary_threshold;
    let mask_neg = cfg.attn_mask_neg_inf;
    let score_skip = cfg.score_skip_threshold;
5228
5229 let compiled_fns: Vec<Arc<dyn Fn(*mut u8) + Send + Sync>> = thunks
5231 .iter()
5232 .filter(|t| !matches!(t, Thunk::Nop))
5233 .map(|thunk| {
5234 match thunk.clone() {
5235 Thunk::Nop => Arc::new(|_: *mut u8| {}) as Arc<dyn Fn(*mut u8) + Send + Sync>,
5236
5237 Thunk::Sgemm { a, b, c, m, k, n } => {
5238 let (m, k, n) = (m as usize, k as usize, n as usize);
5239 Arc::new(move |base: *mut u8| unsafe {
5240 crate::blas::sgemm(
5241 sl(a, base, m * k),
5242 sl(b, base, k * n),
5243 sl_mut(c, base, m * n),
5244 m,
5245 k,
5246 n,
5247 );
5248 })
5249 }
5250
5251 Thunk::DenseSolveF64 { a, b, x, n, nrhs } => {
5252 let (n_, nrhs_) = (n as usize, nrhs as usize);
5253 Arc::new(move |base: *mut u8| unsafe {
5254 let a_src = sl_f64(a, base, n_ * n_);
5255 let b_src = sl_f64(b, base, n_ * nrhs_);
5256 let mut a_scratch: Vec<f64> = a_src.to_vec();
5257 let mut x_buf: Vec<f64> = b_src.to_vec();
5258 let info = crate::blas::dgesv(&mut a_scratch, &mut x_buf, n_, nrhs_);
5259 if info != 0 {
5260 panic!("DenseSolveF64: singular (info={info})");
5261 }
5262 sl_mut_f64(x, base, n_ * nrhs_).copy_from_slice(&x_buf);
5263 })
5264 }
5265
5266 Thunk::DenseSolveF32 { a, b, x, n, nrhs } => {
5267 let (n_, nrhs_) = (n as usize, nrhs as usize);
5268 Arc::new(move |base: *mut u8| unsafe {
5269 let a_src = sl(a, base, n_ * n_);
5270 let b_src = sl(b, base, n_ * nrhs_);
5271 let mut a_scratch: Vec<f32> = a_src.to_vec();
5272 let mut x_buf: Vec<f32> = b_src.to_vec();
5273 let info = crate::blas::sgesv(&mut a_scratch, &mut x_buf, n_, nrhs_);
5274 if info != 0 {
5275 panic!("DenseSolveF32: singular (info={info})");
5276 }
5277 sl_mut(x, base, n_ * nrhs_).copy_from_slice(&x_buf);
5278 })
5279 }
5280
5281 Thunk::FusedMmBiasAct {
5282 a,
5283 w,
5284 bias,
5285 c,
5286 m,
5287 k,
5288 n,
5289 act,
5290 } => {
5291 let (m, k, n) = (m as usize, k as usize, n as usize);
5292 Arc::new(move |base: *mut u8| unsafe {
5293 let out = sl_mut(c, base, m * n);
5294 crate::blas::sgemm(sl(a, base, m * k), sl(w, base, k * n), out, m, k, n);
5295 match act {
5303 Some(Activation::Gelu) => {
5304 crate::kernels::par_bias_gelu(out, sl(bias, base, n), m, n)
5305 }
5306 Some(other) => {
5307 crate::blas::bias_add(out, sl(bias, base, n), m, n);
5308 apply_activation_inplace(out, other);
5309 }
5310 None => crate::blas::bias_add(out, sl(bias, base, n), m, n),
5311 }
5312 })
5313 }
5314
5315 Thunk::FusedResidualLN {
5316 x,
5317 res,
5318 bias,
5319 g,
5320 b,
5321 out,
5322 rows,
5323 h,
5324 eps,
5325 has_bias,
5326 } => {
5327 let (rows, h) = (rows as usize, h as usize);
5328 Arc::new(move |base: *mut u8| unsafe {
5329 let zero = vec![0f32; h]; let bi = if has_bias { sl(bias, base, h) } else { &zero };
5331 let xp = sl(x, base, rows * h).as_ptr() as usize;
5332 let rp = sl(res, base, rows * h).as_ptr() as usize;
5333 let op = sl_mut(out, base, rows * h).as_mut_ptr() as usize;
5334 let bp = bi.as_ptr() as usize;
5335 let gp = sl(g, base, h).as_ptr() as usize;
5336 let bbp = sl(b, base, h).as_ptr() as usize;
5337 crate::pool::par_for(rows, 4, &|off, cnt| {
5338 let xs = std::slice::from_raw_parts(
5339 (xp as *const f32).add(off * h),
5340 cnt * h,
5341 );
5342 let rs = std::slice::from_raw_parts(
5343 (rp as *const f32).add(off * h),
5344 cnt * h,
5345 );
5346 let os = std::slice::from_raw_parts_mut(
5347 (op as *mut f32).add(off * h),
5348 cnt * h,
5349 );
5350 let bi = std::slice::from_raw_parts(bp as *const f32, h);
5351 let g = std::slice::from_raw_parts(gp as *const f32, h);
5352 let b = std::slice::from_raw_parts(bbp as *const f32, h);
5353 crate::kernels::residual_bias_layer_norm(
5354 xs, rs, bi, g, b, os, cnt, h, eps,
5355 );
5356 });
5357 })
5358 }
5359
5360 Thunk::BiasAdd {
5361 src,
5362 bias,
5363 dst,
5364 m,
5365 n,
5366 } => {
5367 let (m, n) = (m as usize, n as usize);
5368 Arc::new(move |base: *mut u8| unsafe {
5369 let out = sl_mut(dst, base, m * n);
5370 out.copy_from_slice(sl(src, base, m * n));
5371 crate::blas::bias_add(out, sl(bias, base, n), m, n);
5372 })
5373 }
5374
5375 Thunk::Gather {
5376 table,
5377 table_len,
5378 idx,
5379 dst,
5380 num_idx,
5381 trailing,
5382 } => {
5383 let (ni, tr, tl) = (num_idx as usize, trailing as usize, table_len as usize);
5384 Arc::new(move |base: *mut u8| unsafe {
5385 let tab = sl(table, base, tl);
5386 let ids = sl(idx, base, ni);
5387 let out = sl_mut(dst, base, ni * tr);
5388 for i in 0..ni {
5389 let row = ids[i] as usize;
5390 out[i * tr..(i + 1) * tr]
5391 .copy_from_slice(&tab[row * tr..(row + 1) * tr]);
5392 }
5393 })
5394 }
5395
                // Delegates to `narrow_thunk_closure`, forwarding every field
                // of the thunk unchanged; the closure itself is built by that
                // helper (defined outside this view).
                Thunk::Narrow {
                    src,
                    dst,
                    outer,
                    src_stride,
                    dst_stride,
                    inner,
                    elem_bytes,
                } => {
                    narrow_thunk_closure(src, dst, outer, src_stride, dst_stride, inner, elem_bytes)
                }
5407
5408 Thunk::Copy { src, dst, len } => {
5409 let len = len as usize;
5410 Arc::new(move |base: *mut u8| unsafe {
5411 sl_mut(dst, base, len).copy_from_slice(sl(src, base, len));
5412 })
5413 }
5414
5415 Thunk::Softmax { data, rows, cols } => {
5416 let (rows, cols) = (rows as usize, cols as usize);
5417 Arc::new(move |base: *mut u8| unsafe {
5418 crate::naive::softmax(sl_mut(data, base, rows * cols), rows, cols);
5419 })
5420 }
5421
5422 Thunk::Cumsum {
5423 src,
5424 dst,
5425 rows,
5426 cols,
5427 exclusive,
5428 } => {
5429 let (rows, cols) = (rows as usize, cols as usize);
5430 Arc::new(move |base: *mut u8| unsafe {
5431 let s = sl(src, base, rows * cols);
5432 let d = sl_mut(dst, base, rows * cols);
5433 if exclusive {
5434 for r in 0..rows {
5435 let mut acc = 0.0f32;
5436 for c in 0..cols {
5437 d[r * cols + c] = acc;
5438 acc += s[r * cols + c];
5439 }
5440 }
5441 } else {
5442 for r in 0..rows {
5443 let mut acc = 0.0f32;
5444 for c in 0..cols {
5445 acc += s[r * cols + c];
5446 d[r * cols + c] = acc;
5447 }
5448 }
5449 }
5450 })
5451 }
5452
5453 Thunk::Sample {
5454 logits,
5455 dst,
5456 batch,
5457 vocab,
5458 top_k,
5459 top_p,
5460 temperature,
5461 seed,
5462 } => {
5463 let (b, v) = (batch as usize, vocab as usize);
5464 let k = (top_k as usize).min(v);
5465 Arc::new(move |base: *mut u8| unsafe {
5466 let lg = sl(logits, base, b * v);
5467 let out = sl_mut(dst, base, b);
5468 let mut rng =
5469 rlx_ir::Philox4x32::new(if seed == 0 { 0xDEADBEEF } else { seed });
5470 for bi in 0..b {
5471 let row = &lg[bi * v..(bi + 1) * v];
5472 out[bi] = sample_row(row, k, top_p, temperature, &mut rng) as f32;
5473 }
5474 })
5475 }
5476
5477 Thunk::DequantMatMul {
5478 x,
5479 w_q,
5480 scale,
5481 zp,
5482 dst,
5483 m,
5484 k,
5485 n,
5486 block_size,
5487 is_asymmetric,
5488 } => {
5489 let (m, k, n, bs) = (m as usize, k as usize, n as usize, block_size as usize);
5490 let n_blocks_per_col = k.div_ceil(bs);
5491 Arc::new(move |base: *mut u8| unsafe {
5492 let xs = sl(x, base, m * k);
5493 let raw = base.add(w_q);
5495 let w_bytes = std::slice::from_raw_parts(raw as *const i8, k * n);
5496 let scales = sl(scale, base, n_blocks_per_col * n);
5497 let zps = if is_asymmetric {
5498 sl(zp, base, n_blocks_per_col * n)
5499 } else {
5500 &[][..]
5501 };
5502 let out = sl_mut(dst, base, m * n);
5503 dequant_matmul_int8(
5504 xs,
5505 w_bytes,
5506 scales,
5507 zps,
5508 out,
5509 m,
5510 k,
5511 n,
5512 bs,
5513 is_asymmetric,
5514 );
5515 })
5516 }
5517
5518 Thunk::DequantMatMulGguf {
5519 x,
5520 w_q,
5521 dst,
5522 m,
5523 k,
5524 n,
5525 scheme,
5526 } => {
5527 let (m, k, n) = (m as usize, k as usize, n as usize);
5528 let block_bytes = scheme.gguf_block_bytes() as usize;
5529 let block_elems = scheme.gguf_block_size() as usize;
5530 let total_bytes = (k * n) / block_elems * block_bytes;
5531 Arc::new(move |base: *mut u8| unsafe {
5532 let xs = sl(x, base, m * k);
5533 let w_bytes =
5534 std::slice::from_raw_parts(base.add(w_q) as *const u8, total_bytes);
5535 let out = sl_mut(dst, base, m * n);
5536 crate::gguf_matmul::gguf_matmul_bt(xs, w_bytes, out, m, k, n, scheme);
5537 })
5538 }
5539
5540 Thunk::DequantMatMulInt4 {
5541 x,
5542 w_q,
5543 scale,
5544 zp,
5545 dst,
5546 m,
5547 k,
5548 n,
5549 block_size,
5550 is_asymmetric,
5551 } => {
5552 let (m, k, n, bs) = (m as usize, k as usize, n as usize, block_size as usize);
5553 let n_blocks = k.div_ceil(bs);
5554 Arc::new(move |base: *mut u8| unsafe {
5555 let xs = sl(x, base, m * k);
5556 let w_bytes = std::slice::from_raw_parts(
5557 base.add(w_q) as *const u8,
5558 (k * n).div_ceil(2),
5559 );
5560 let scales = sl(scale, base, n_blocks * n);
5561 let zps = if is_asymmetric {
5562 sl(zp, base, n_blocks * n)
5563 } else {
5564 &[][..]
5565 };
5566 let out = sl_mut(dst, base, m * n);
5567 dequant_matmul_int4(
5568 xs,
5569 w_bytes,
5570 scales,
5571 zps,
5572 out,
5573 m,
5574 k,
5575 n,
5576 bs,
5577 is_asymmetric,
5578 );
5579 })
5580 }
5581
5582 Thunk::DequantMatMulFp8 {
5583 x,
5584 w_q,
5585 scale,
5586 dst,
5587 m,
5588 k,
5589 n,
5590 e5m2,
5591 } => {
5592 let (m, k, n) = (m as usize, k as usize, n as usize);
5593 Arc::new(move |base: *mut u8| unsafe {
5594 let xs = sl(x, base, m * k);
5595 let w_bytes = std::slice::from_raw_parts(base.add(w_q) as *const u8, k * n);
5596 let scales = sl(scale, base, n);
5597 let out = sl_mut(dst, base, m * n);
5598 dequant_matmul_fp8(xs, w_bytes, scales, out, m, k, n, e5m2);
5599 })
5600 }
5601
5602 Thunk::DequantMatMulNvfp4 {
5603 x,
5604 w_q,
5605 scale,
5606 global_scale,
5607 dst,
5608 m,
5609 k,
5610 n,
5611 } => {
5612 let (m, k, n) = (m as usize, k as usize, n as usize);
5613 let n_scale = k.div_ceil(rlx_ir::NVFP4_GROUP_SIZE) * n;
5614 Arc::new(move |base: *mut u8| unsafe {
5615 let xs = sl(x, base, m * k);
5616 let w_bytes = std::slice::from_raw_parts(
5617 base.add(w_q) as *const u8,
5618 (k * n).div_ceil(2),
5619 );
5620 let scale_bytes =
5621 std::slice::from_raw_parts(base.add(scale) as *const u8, n_scale);
5622 let gs = sl(global_scale, base, 1)[0];
5623 let out = sl_mut(dst, base, m * n);
5624 dequant_matmul_nvfp4(xs, w_bytes, scale_bytes, gs, out, m, k, n);
5625 })
5626 }
5627
5628 Thunk::LoraMatMul {
5629 x,
5630 w,
5631 a,
5632 b,
5633 dst,
5634 m,
5635 k,
5636 n,
5637 r,
5638 scale,
5639 } => {
5640 let (m, k, n, r) = (m as usize, k as usize, n as usize, r as usize);
5641 Arc::new(move |base: *mut u8| unsafe {
5642 let xs = sl(x, base, m * k);
5643 let ws = sl(w, base, k * n);
5644 let a_s = sl(a, base, k * r);
5645 let bs = sl(b, base, r * n);
5646 let out = sl_mut(dst, base, m * n);
5647 crate::blas::sgemm(xs, ws, out, m, k, n);
5649 let mut tmp = vec![0f32; m * r];
5651 crate::blas::sgemm(xs, a_s, &mut tmp, m, k, r);
5652 if scale != 1.0 {
5656 for v in tmp.iter_mut() {
5657 *v *= scale;
5658 }
5659 }
5660 crate::blas::sgemm_accumulate(&tmp, bs, out, m, r, n);
5661 })
5662 }
5663
5664 Thunk::LayerNorm {
5665 src,
5666 g,
5667 b,
5668 dst,
5669 rows,
5670 h,
5671 eps,
5672 } => {
5673 let (rows, h) = (rows as usize, h as usize);
5674 Arc::new(move |base: *mut u8| unsafe {
5675 let inp = sl(src, base, rows * h);
5676 let gamma = sl(g, base, h);
5677 let beta = sl(b, base, h);
5678 let out = sl_mut(dst, base, rows * h);
5679 for row in 0..rows {
5680 crate::kernels::layer_norm_row(
5681 &inp[row * h..(row + 1) * h],
5682 gamma,
5683 beta,
5684 &mut out[row * h..(row + 1) * h],
5685 h,
5686 eps,
5687 );
5688 }
5689 })
5690 }
5691
5692 Thunk::Attention {
5693 q,
5694 k,
5695 v,
5696 mask,
5697 out,
5698 batch,
5699 seq,
5700 kv_seq: _,
5701 heads,
5702 head_dim,
5703 mask_kind,
5704 q_row_stride,
5705 k_row_stride,
5706 v_row_stride,
5707 bhsd,
5708 } => {
5709 let (b, s, nh, dh) = (
5710 batch as usize,
5711 seq as usize,
5712 heads as usize,
5713 head_dim as usize,
5714 );
5715 let hs = nh * dh;
5716 let qrs = q_row_stride as usize;
5717 let krs = k_row_stride as usize;
5718 let vrs = v_row_stride as usize;
5719 let scale = (dh as f32).powf(-0.5);
5720 Arc::new(move |base: *mut u8| unsafe {
5721 let (q_len, k_len, v_len, o_len) = if bhsd {
5726 let n = b * nh * s * dh;
5727 (n, n, n, n)
5728 } else {
5729 (b * s * qrs, b * s * krs, b * s * vrs, b * s * hs)
5730 };
5731 let q_d = sl(q, base, q_len);
5732 let k_d = sl(k, base, k_len);
5733 let v_d = sl(v, base, v_len);
5734 let m_d: &[f32] = match mask_kind {
5735 rlx_ir::op::MaskKind::Custom => sl(mask, base, b * s),
5736 rlx_ir::op::MaskKind::Bias => sl(mask, base, b * nh * s * s),
5737 _ => &[],
5738 };
5739 let o_d = sl_mut(out, base, o_len);
5740 let sdh = s * dh;
5741 let mut qh = vec![0f32; sdh];
5742 let mut kh = vec![0f32; sdh];
5743 let mut vh = vec![0f32; sdh];
5744 let mut sc = vec![0f32; s * s];
5745 let mut oh = vec![0f32; sdh];
5746 for bi in 0..b {
5747 for hi in 0..nh {
5748 for si in 0..s {
5749 let (q_off, k_off, v_off) = if bhsd {
5761 (
5762 bi * nh * s * dh + hi * s * dh + si * dh,
5763 bi * nh * s * dh + hi * s * dh + si * dh,
5764 bi * nh * s * dh + hi * s * dh + si * dh,
5765 )
5766 } else {
5767 (
5768 bi * s * qrs + si * qrs + hi * dh,
5769 bi * s * krs + si * krs + hi * dh,
5770 bi * s * vrs + si * vrs + hi * dh,
5771 )
5772 };
5773 qh[si * dh..(si + 1) * dh]
5774 .copy_from_slice(&q_d[q_off..q_off + dh]);
5775 kh[si * dh..(si + 1) * dh]
5776 .copy_from_slice(&k_d[k_off..k_off + dh]);
5777 vh[si * dh..(si + 1) * dh]
5778 .copy_from_slice(&v_d[v_off..v_off + dh]);
5779 }
5780 for qi in 0..s {
5781 for ki in 0..s {
5782 let mut dot = 0f32;
5783 for d in 0..dh {
5784 dot += qh[qi * dh + d] * kh[ki * dh + d];
5785 }
5786 sc[qi * s + ki] = dot * scale;
5787 }
5788 }
5789 match mask_kind {
5792 rlx_ir::op::MaskKind::None => {}
5793 rlx_ir::op::MaskKind::Causal => {
5794 for qi in 0..s {
5795 for ki in (qi + 1)..s {
5796 sc[qi * s + ki] = mask_neg;
5797 }
5798 }
5799 }
5800 rlx_ir::op::MaskKind::SlidingWindow(w) => {
5801 for qi in 0..s {
5802 let lo = qi.saturating_sub(w);
5803 for ki in 0..s {
5804 if ki < lo || ki > qi {
5805 sc[qi * s + ki] = mask_neg;
5806 }
5807 }
5808 }
5809 }
5810 rlx_ir::op::MaskKind::Custom => {
5811 for qi in 0..s {
5812 for ki in 0..s {
5813 if m_d[bi * s + ki] < mask_thr {
5814 sc[qi * s + ki] = mask_neg;
5815 }
5816 }
5817 }
5818 }
5819 rlx_ir::op::MaskKind::Bias => {
5820 let per_bh = s * s;
5821 let off = (bi * nh + hi) * per_bh;
5822 for i in 0..per_bh {
5823 sc[i] += m_d[off + i];
5824 }
5825 }
5826 }
5827 crate::naive::softmax(&mut sc, s, s);
5828 oh.fill(0.0);
5829 for qi in 0..s {
5830 for ki in 0..s {
5831 let w = sc[qi * s + ki];
5832 if w > score_skip {
5833 for d in 0..dh {
5834 oh[qi * dh + d] += w * vh[ki * dh + d];
5835 }
5836 }
5837 }
5838 }
5839 for si in 0..s {
5840 let off = if bhsd {
5841 bi * nh * s * dh + hi * s * dh + si * dh
5842 } else {
5843 bi * s * hs + si * hs + hi * dh
5844 };
5845 o_d[off..off + dh].copy_from_slice(&oh[si * dh..(si + 1) * dh]);
5846 }
5847 }
5848 }
5849 })
5850 }
5851
5852 Thunk::FusedSwiGLU {
5853 src,
5854 dst,
5855 n_half,
5856 total,
5857 gate_first,
5858 } => {
5859 let n = n_half as usize;
5860 let t = total as usize;
5861 let outer = t / n;
5862 let in_total = outer * 2 * n;
5863 Arc::new(move |base: *mut u8| unsafe {
5864 let inp = sl(src, base, in_total);
5865 let out = sl_mut(dst, base, t);
5866 for o in 0..outer {
5867 let in_row = &inp[o * 2 * n..(o + 1) * 2 * n];
5868 let out_row = &mut out[o * n..(o + 1) * n];
5869 for i in 0..n {
5870 let (up, gate) = if gate_first {
5871 (in_row[n + i], in_row[i])
5872 } else {
5873 (in_row[i], in_row[n + i])
5874 };
5875 out_row[i] = up * (gate / (1.0 + (-gate).exp()));
5876 }
5877 }
5878 })
5879 }
5880
5881 Thunk::Concat {
5882 dst,
5883 outer,
5884 inner,
5885 total_axis,
5886 inputs,
5887 } => {
5888 let outer = outer as usize;
5889 let inner = inner as usize;
5890 let total_axis = total_axis as usize;
5891 let out_total = outer * total_axis * inner;
5892 let mut layout: Vec<(usize, usize, usize)> = Vec::with_capacity(inputs.len());
5895 let mut cum: usize = 0;
5896 for (src_off, in_axis) in &inputs {
5897 let in_axis = *in_axis as usize;
5898 layout.push((*src_off, cum * inner, in_axis * inner));
5899 cum += in_axis;
5900 }
5901 Arc::new(move |base: *mut u8| unsafe {
5902 let out = sl_mut(dst, base, out_total);
5903 let row_stride = total_axis * inner;
5904 for (src_off, dst_col_off, copy_per_row) in &layout {
5905 let in_total = outer * *copy_per_row;
5906 let inp = sl(*src_off, base, in_total);
5907 for o in 0..outer {
5908 let dst_row_start = o * row_stride + *dst_col_off;
5909 let src_row_start = o * *copy_per_row;
5910 out[dst_row_start..dst_row_start + *copy_per_row].copy_from_slice(
5911 &inp[src_row_start..src_row_start + *copy_per_row],
5912 );
5913 }
5914 }
5915 })
5916 }
5917
5918 Thunk::CustomOp {
5919 kernel,
5920 inputs,
5921 output,
5922 attrs,
5923 } => {
5924 let kernel = kernel.clone();
5930 let attrs = attrs.clone();
5931 let inputs = inputs.clone();
5932 let (out_off, out_len, out_shape) = output.clone();
5933 Arc::new(move |base: *mut u8| unsafe {
5934 dispatch_custom_op(
5935 &*kernel, &inputs, out_off, out_len, &out_shape, &attrs, base,
5936 );
5937 })
5938 }
5939
5940 Thunk::GaussianSplatRender {
5941 positions_off,
5942 positions_len,
5943 scales_off,
5944 scales_len,
5945 rotations_off,
5946 rotations_len,
5947 opacities_off,
5948 opacities_len,
5949 colors_off,
5950 colors_len,
5951 sh_coeffs_off,
5952 sh_coeffs_len,
5953 meta_off,
5954 dst_off,
5955 dst_len,
5956 width,
5957 height,
5958 tile_size,
5959 radius_scale,
5960 alpha_cutoff,
5961 max_splat_steps,
5962 transmittance_threshold,
5963 max_list_entries,
5964 } => Arc::new(move |base: *mut u8| unsafe {
5965 crate::splat::execute_gaussian_splat_render(
5966 positions_off,
5967 positions_len,
5968 scales_off,
5969 scales_len,
5970 rotations_off,
5971 rotations_len,
5972 opacities_off,
5973 opacities_len,
5974 colors_off,
5975 colors_len,
5976 sh_coeffs_off,
5977 sh_coeffs_len,
5978 meta_off,
5979 dst_off,
5980 dst_len,
5981 width,
5982 height,
5983 tile_size,
5984 radius_scale,
5985 alpha_cutoff,
5986 max_splat_steps,
5987 transmittance_threshold,
5988 max_list_entries,
5989 base,
5990 );
5991 }),
5992
5993 Thunk::GaussianSplatRenderBackward {
5994 positions_off,
5995 positions_len,
5996 scales_off,
5997 scales_len,
5998 rotations_off,
5999 rotations_len,
6000 opacities_off,
6001 opacities_len,
6002 colors_off,
6003 colors_len,
6004 sh_coeffs_off,
6005 sh_coeffs_len,
6006 meta_off,
6007 d_loss_off,
6008 d_loss_len,
6009 packed_off,
6010 packed_len,
6011 width,
6012 height,
6013 tile_size,
6014 radius_scale,
6015 alpha_cutoff,
6016 max_splat_steps,
6017 transmittance_threshold,
6018 max_list_entries,
6019 loss_grad_clip,
6020 sh_band,
6021 max_anisotropy,
6022 } => Arc::new(move |base: *mut u8| unsafe {
6023 crate::splat::execute_gaussian_splat_render_backward(
6024 positions_off,
6025 positions_len,
6026 scales_off,
6027 scales_len,
6028 rotations_off,
6029 rotations_len,
6030 opacities_off,
6031 opacities_len,
6032 colors_off,
6033 colors_len,
6034 sh_coeffs_off,
6035 sh_coeffs_len,
6036 meta_off,
6037 d_loss_off,
6038 d_loss_len,
6039 packed_off,
6040 packed_len,
6041 width,
6042 height,
6043 tile_size,
6044 radius_scale,
6045 alpha_cutoff,
6046 max_splat_steps,
6047 transmittance_threshold,
6048 max_list_entries,
6049 loss_grad_clip,
6050 sh_band,
6051 max_anisotropy,
6052 base,
6053 );
6054 }),
6055
6056 Thunk::GaussianSplatPrepare {
6057 positions_off,
6058 positions_len,
6059 scales_off,
6060 scales_len,
6061 rotations_off,
6062 rotations_len,
6063 opacities_off,
6064 opacities_len,
6065 colors_off,
6066 colors_len,
6067 sh_coeffs_off,
6068 sh_coeffs_len,
6069 meta_off,
6070 meta_len,
6071 prep_off,
6072 prep_len,
6073 width,
6074 height,
6075 tile_size,
6076 radius_scale,
6077 alpha_cutoff,
6078 max_splat_steps,
6079 transmittance_threshold,
6080 max_list_entries,
6081 } => Arc::new(move |base: *mut u8| unsafe {
6082 crate::splat::execute_gaussian_splat_prepare(
6083 positions_off,
6084 positions_len,
6085 scales_off,
6086 scales_len,
6087 rotations_off,
6088 rotations_len,
6089 opacities_off,
6090 opacities_len,
6091 colors_off,
6092 colors_len,
6093 sh_coeffs_off,
6094 sh_coeffs_len,
6095 meta_off,
6096 meta_len,
6097 prep_off,
6098 prep_len,
6099 width,
6100 height,
6101 tile_size,
6102 radius_scale,
6103 alpha_cutoff,
6104 max_splat_steps,
6105 transmittance_threshold,
6106 max_list_entries,
6107 base,
6108 );
6109 }),
6110
6111 Thunk::GaussianSplatRasterize {
6112 prep_off,
6113 prep_len,
6114 meta_off,
6115 meta_len,
6116 dst_off,
6117 dst_len,
6118 count,
6119 width,
6120 height,
6121 tile_size,
6122 alpha_cutoff,
6123 max_splat_steps,
6124 transmittance_threshold,
6125 max_list_entries,
6126 } => Arc::new(move |base: *mut u8| unsafe {
6127 crate::splat::execute_gaussian_splat_rasterize(
6128 prep_off,
6129 prep_len,
6130 meta_off,
6131 meta_len,
6132 dst_off,
6133 dst_len,
6134 count,
6135 width,
6136 height,
6137 tile_size,
6138 alpha_cutoff,
6139 max_splat_steps,
6140 transmittance_threshold,
6141 max_list_entries,
6142 base,
6143 );
6144 }),
6145
6146 Thunk::Fft1d {
6147 src,
6148 dst,
6149 outer,
6150 n_complex,
6151 inverse,
6152 norm_tag,
6153 dtype,
6154 } => {
6155 let f: Arc<dyn Fn(*mut u8) + Send + Sync> = match dtype {
6156 rlx_ir::DType::F64 => Arc::new(move |base: *mut u8| unsafe {
6157 execute_fft1d_f64(
6158 src,
6159 dst,
6160 outer as usize,
6161 n_complex as usize,
6162 inverse,
6163 norm_tag,
6164 base,
6165 );
6166 }),
6167 rlx_ir::DType::F32 => Arc::new(move |base: *mut u8| unsafe {
6168 execute_fft1d_f32(
6169 src,
6170 dst,
6171 outer as usize,
6172 n_complex as usize,
6173 inverse,
6174 norm_tag,
6175 base,
6176 );
6177 }),
6178 rlx_ir::DType::C64 => Arc::new(move |base: *mut u8| unsafe {
6179 execute_fft1d_c64(
6180 src,
6181 dst,
6182 outer as usize,
6183 n_complex as usize,
6184 inverse,
6185 norm_tag,
6186 base,
6187 );
6188 }),
6189 other => panic!("Op::Fft on CPU requires F32/F64/C64, got {other:?}"),
6190 };
6191 f
6192 }
6193
6194 _ => Arc::new(|_: *mut u8| {}),
6195 }
6196 })
6197 .collect();
6198
6199 let fuse_threshold: usize = rlx_ir::env::var("RLX_FUSE_ATTN_THRESHOLD")
6203 .and_then(|v| v.parse().ok())
6204 .unwrap_or(64);
6205 let should_fuse = thunks.iter().any(|t| match t {
6206 Thunk::Attention { batch, seq, .. } => {
6207 (*batch as usize) * (*seq as usize) <= fuse_threshold
6208 }
6209 _ => false,
6210 });
6211
6212 if should_fuse {
6213 let active: Vec<usize> = thunks
6215 .iter()
6216 .enumerate()
6217 .filter(|(_, t)| !matches!(t, Thunk::Nop))
6218 .map(|(i, _)| i)
6219 .collect();
6220
6221 let mut kill = vec![false; thunks.len()]; let mut insertions: Vec<(usize, Thunk)> = Vec::new(); let mut ai = 0;
6225 while ai < active.len() {
6226 let a = |off: usize| -> Option<(usize, &Thunk)> {
6228 active.get(ai + off).map(|&idx| (idx, &thunks[idx]))
6229 };
6230
6231 let matched = (|| {
6233 let (_i0, t0) = a(0)?;
6234 let (_, t1) = a(1)?;
6235 let (_, t2) = a(2)?;
6236 let (_, t3) = a(3)?;
6237
6238 let (hidden, qkv_w, qkv_b, has_b) = match t0 {
6240 Thunk::FusedMmBiasAct {
6241 a,
6242 w,
6243 bias,
6244 n: _,
6245 act: None,
6246 ..
6247 } => (*a, *w, *bias, true),
6248 Thunk::Sgemm { a, b, n: _, .. } => (*a, *b, 0, false),
6249 _ => return None,
6250 };
6251
6252 if !matches!(t1, Thunk::Narrow { .. }) {
6254 return None;
6255 }
6256 if !matches!(t2, Thunk::Narrow { .. }) {
6257 return None;
6258 }
6259 if !matches!(t3, Thunk::Narrow { .. }) {
6260 return None;
6261 }
6262
6263 let (has_rope, attn_ai, cos_off, sin_off, cl) = if let Some((
6265 _,
6266 Thunk::Rope {
6267 cos, sin, cos_len, ..
6268 },
6269 )) = a(4)
6270 {
6271 if matches!(a(5).map(|x| x.1), Some(Thunk::Rope { .. })) {
6272 if matches!(a(6).map(|x| x.1), Some(Thunk::Attention { .. })) {
6273 (true, 6, *cos, *sin, *cos_len)
6274 } else {
6275 return None;
6276 }
6277 } else {
6278 return None;
6279 }
6280 } else if matches!(a(4).map(|x| x.1), Some(Thunk::Attention { .. })) {
6281 (false, 4, 0, 0, 0)
6282 } else {
6283 return None;
6284 };
6285
6286 let (_attn_real_idx, attn_t) = a(attn_ai)?;
6287 let (batch, seq, heads, head_dim, mask) = match attn_t {
6288 Thunk::Attention {
6289 batch,
6290 seq,
6291 heads,
6292 head_dim,
6293 mask,
6294 ..
6295 } => (*batch, *seq, *heads, *head_dim, *mask),
6296 _ => return None,
6297 };
6298
6299 let (_out_real_idx, out_t) = a(attn_ai + 1)?;
6301 let (out_w, out_b, out_dst) = match out_t {
6302 Thunk::FusedMmBiasAct {
6303 w,
6304 bias,
6305 c,
6306 act: None,
6307 ..
6308 } => (*w, *bias, *c),
6309 Thunk::Sgemm { b: w, c, .. } => (*w, 0, *c),
6310 _ => return None,
6311 };
6312
6313 let hs = heads * head_dim;
6314 let total_active = attn_ai + 2; Some((
6317 total_active,
6318 Thunk::FusedAttnBlock {
6319 hidden,
6320 qkv_w,
6321 out_w,
6322 mask,
6323 out: out_dst,
6324 qkv_b: if has_b { qkv_b } else { 0 },
6325 out_b: if has_b { out_b } else { 0 },
6326 cos: cos_off,
6327 sin: sin_off,
6328 cos_len: cl,
6329 batch,
6330 seq,
6331 hs,
6332 nh: heads,
6333 dh: head_dim,
6334 has_bias: has_b,
6335 has_rope,
6336 },
6337 ))
6338 })();
6339
6340 if let Some((count, fused_thunk)) = matched {
6341 for off in 0..count {
6343 if let Some(&idx) = active.get(ai + off) {
6344 kill[idx] = true;
6345 }
6346 }
6347 insertions.push((active[ai], fused_thunk));
6349 ai += count;
6350 } else {
6351 ai += 1;
6352 }
6353 }
6354
6355 if !insertions.is_empty() {
6357 let mut new_thunks = Vec::with_capacity(thunks.len());
6358 let mut insert_idx = 0;
6359 for (i, t) in thunks.into_iter().enumerate() {
6360 if insert_idx < insertions.len() && insertions[insert_idx].0 == i {
6361 new_thunks.push(insertions[insert_idx].1.clone());
6362 insert_idx += 1;
6363 }
6364 if !kill[i] {
6365 new_thunks.push(t);
6366 }
6367 }
6368 if cfg.verbose >= 1 {
6369 eprintln!(
6370 "[rlx] fused_attention: {} attention blocks fused",
6371 insertions.len()
6372 );
6373 }
6374 thunks = new_thunks;
6375 }
6376 }
6377
6378 if should_fuse {
6383 let active: Vec<usize> = thunks
6384 .iter()
6385 .enumerate()
6386 .filter(|(_, t)| !matches!(t, Thunk::Nop))
6387 .map(|(i, _)| i)
6388 .collect();
6389
6390 let mut kill = vec![false; thunks.len()];
6391 let mut insertions: Vec<(usize, Thunk)> = Vec::new();
6392
6393 let a = |ai: usize| -> Option<&Thunk> { active.get(ai).map(|&i| &thunks[i]) };
6394
6395 let mut ai = 0;
6396 while ai < active.len() {
6397 let bert_match = (|| -> Option<usize> {
6399 let fab = a(ai)?;
6400 let rln1 = a(ai + 1)?;
6401 let ffn1 = a(ai + 2)?;
6402 let ffn2 = a(ai + 3)?;
6403 let rln2 = a(ai + 4)?;
6404
6405 let (hidden, qkv_w, qkv_b, out_w, out_b, mask, batch, seq, hs, nh, dh) = match fab {
6406 Thunk::FusedAttnBlock {
6407 hidden,
6408 qkv_w,
6409 qkv_b,
6410 out_w,
6411 out_b,
6412 mask,
6413 batch,
6414 seq,
6415 hs,
6416 nh,
6417 dh,
6418 has_bias: true,
6419 has_rope: false,
6420 ..
6421 } => (
6422 *hidden, *qkv_w, *qkv_b, *out_w, *out_b, *mask, *batch, *seq, *hs, *nh, *dh,
6423 ),
6424 _ => return None,
6425 };
6426 let (ln1_g, ln1_b, eps1) = match rln1 {
6427 Thunk::FusedResidualLN { g, b, eps, .. } => (*g, *b, *eps),
6428 _ => return None,
6429 };
6430 let (fc1_w, fc1_b, int_dim) = match ffn1 {
6431 Thunk::FusedMmBiasAct {
6432 w,
6433 bias,
6434 n,
6435 act: Some(Activation::Gelu),
6436 ..
6437 } => (*w, *bias, *n),
6438 _ => return None,
6439 };
6440 let (fc2_w, fc2_b) = match ffn2 {
6441 Thunk::FusedMmBiasAct {
6442 w, bias, act: None, ..
6443 } => (*w, *bias),
6444 _ => return None,
6445 };
6446 let (ln2_g, ln2_b, eps2, out) = match rln2 {
6447 Thunk::FusedResidualLN { g, b, eps, out, .. } => (*g, *b, *eps, *out),
6448 _ => return None,
6449 };
6450
6451 for off in 0..5 {
6452 kill[active[ai + off]] = true;
6453 }
6454 insertions.push((
6455 active[ai],
6456 Thunk::FusedBertLayer {
6457 hidden,
6458 qkv_w,
6459 qkv_b,
6460 out_w,
6461 out_b,
6462 mask,
6463 ln1_g,
6464 ln1_b,
6465 eps1,
6466 fc1_w,
6467 fc1_b,
6468 fc2_w,
6469 fc2_b,
6470 ln2_g,
6471 ln2_b,
6472 eps2,
6473 out,
6474 batch,
6475 seq,
6476 hs,
6477 nh,
6478 dh,
6479 int_dim,
6480 },
6481 ));
6482 Some(5)
6483 })();
6484 if let Some(n) = bert_match {
6485 ai += n;
6486 continue;
6487 }
6488
6489 #[allow(unreachable_code)]
6493 let nomic_match = (|| -> Option<usize> {
6494 return None; let fab = a(ai)?;
6496 let (hidden, qkv_w, out_w, mask, cos, sin, cos_len, batch, seq, hs, nh, dh) =
6497 match fab {
6498 Thunk::FusedAttnBlock {
6499 hidden,
6500 qkv_w,
6501 out_w,
6502 mask,
6503 cos,
6504 sin,
6505 cos_len,
6506 batch,
6507 seq,
6508 hs,
6509 nh,
6510 dh,
6511 has_bias: false,
6512 has_rope: true,
6513 ..
6514 } => (
6515 *hidden, *qkv_w, *out_w, *mask, *cos, *sin, *cos_len, *batch, *seq,
6516 *hs, *nh, *dh,
6517 ),
6518 _ => return None,
6519 };
6520 let (ln1_g, ln1_b, eps1) = match a(ai + 1)? {
6522 Thunk::FusedResidualLN { g, b, eps, .. } => (*g, *b, *eps),
6523 _ => return None,
6524 };
6525 let fused_fc_w = match a(ai + 2)? {
6527 Thunk::Sgemm { b: w, .. } => *w,
6528 _ => return None,
6529 };
6530 if !matches!(a(ai + 3)?, Thunk::Narrow { .. }) {
6532 return None;
6533 }
6534 if !matches!(a(ai + 4)?, Thunk::Narrow { .. }) {
6535 return None;
6536 }
6537 if !matches!(
6539 a(ai + 5)?,
6540 Thunk::ActivationInPlace {
6541 act: Activation::Silu,
6542 ..
6543 }
6544 ) {
6545 return None;
6546 }
6547 if !matches!(
6549 a(ai + 6)?,
6550 Thunk::BinaryFull {
6551 op: BinaryOp::Mul,
6552 ..
6553 }
6554 ) {
6555 return None;
6556 }
6557 let fc2_w = match a(ai + 7)? {
6559 Thunk::Sgemm { b: w, .. } => *w,
6560 _ => return None,
6561 };
6562 let int_dim = match a(ai + 3)? {
6564 Thunk::Narrow { inner, .. } => *inner,
6565 _ => return None,
6566 };
6567 let (ln2_g, ln2_b, eps2, out) = match a(ai + 8)? {
6569 Thunk::FusedResidualLN { g, b, eps, out, .. } => (*g, *b, *eps, *out),
6570 _ => return None,
6571 };
6572
6573 for off in 0..9 {
6574 kill[active[ai + off]] = true;
6575 }
6576 insertions.push((
6577 active[ai],
6578 Thunk::FusedNomicLayer {
6579 hidden,
6580 qkv_w,
6581 out_w,
6582 mask,
6583 cos,
6584 sin,
6585 cos_len,
6586 ln1_g,
6587 ln1_b,
6588 eps1,
6589 fc11_w: fused_fc_w,
6590 fc12_w: 0,
6591 fc2_w,
6592 ln2_g,
6593 ln2_b,
6594 eps2,
6595 out,
6596 batch,
6597 seq,
6598 hs,
6599 nh,
6600 dh,
6601 int_dim,
6602 },
6603 ));
6604 Some(9)
6605 })();
6606 if let Some(n) = nomic_match {
6607 ai += n;
6608 continue;
6609 }
6610
6611 ai += 1;
6612 }
6613
6614 if !insertions.is_empty() {
6615 let mut new_thunks = Vec::with_capacity(thunks.len());
6616 let mut ins_idx = 0;
6617 for (i, t) in thunks.into_iter().enumerate() {
6618 if ins_idx < insertions.len() && insertions[ins_idx].0 == i {
6619 new_thunks.push(insertions[ins_idx].1.clone());
6620 ins_idx += 1;
6621 }
6622 if !kill[i] {
6623 new_thunks.push(t);
6624 }
6625 }
6626 if cfg.verbose >= 1 {
6627 eprintln!(
6628 "[rlx] fused_layer: {} full transformer layers fused",
6629 insertions.len()
6630 );
6631 }
6632 thunks = new_thunks;
6633 }
6634 }
6635
6636 {
6648 let mut read_offsets: HashMap<usize, usize> = HashMap::new();
6651 for t in &thunks {
6652 for off in thunk_read_offsets(t) {
6653 *read_offsets.entry(off).or_insert(0) += 1;
6654 }
6655 }
6656
6657 let mut fused_count = 0usize;
6658 for i in 0..thunks.len().saturating_sub(1) {
6659 let narrow = match &thunks[i] {
6662 Thunk::Narrow { .. } => i,
6663 _ => continue,
6664 };
6665 let mut j = narrow + 1;
6667 while j < thunks.len() && matches!(thunks[j], Thunk::Nop) {
6668 j += 1;
6669 }
6670 if j >= thunks.len() {
6671 continue;
6672 }
6673 let (n_src, n_dst, n_src_stride) = match &thunks[narrow] {
6675 Thunk::Narrow {
6676 src,
6677 dst,
6678 src_stride,
6679 ..
6680 } => (*src, *dst, *src_stride),
6681 _ => continue,
6682 };
6683 let rope_reads_narrow = matches!(&thunks[j],
6684 Thunk::Rope { src, .. } if *src == n_dst);
6685 if !rope_reads_narrow {
6686 continue;
6687 }
6688 if read_offsets.get(&n_dst).copied().unwrap_or(0) != 1 {
6692 continue;
6693 }
6694
6695 if let Thunk::Rope {
6698 src,
6699 src_row_stride,
6700 ..
6701 } = &mut thunks[j]
6702 {
6703 *src = n_src;
6704 *src_row_stride = n_src_stride;
6705 }
6706 thunks[narrow] = Thunk::Nop;
6707 fused_count += 1;
6708 }
6709
6710 if fused_count > 0 && cfg.verbose >= 1 {
6711 eprintln!(
6712 "[rlx] fused_qk_rope: {} Narrow→Rope pairs collapsed",
6713 fused_count
6714 );
6715 }
6716 }
6717
6718 {
6730 let mut read_counts: HashMap<usize, usize> = HashMap::new();
6731 for t in &thunks {
6732 for off in thunk_read_offsets(t) {
6733 *read_counts.entry(off).or_insert(0) += 1;
6734 }
6735 }
6736 let mut dst_to_idx: HashMap<usize, usize> = HashMap::new();
6738 for (i, t) in thunks.iter().enumerate() {
6739 if let Thunk::Narrow { dst, .. } = t {
6740 dst_to_idx.insert(*dst, i);
6741 }
6742 }
6743
6744 let mut fused_count = 0usize;
6745 for i in 0..thunks.len() {
6746 let (q_off, k_off, v_off) = match &thunks[i] {
6747 Thunk::Attention { q, k, v, .. } => (*q, *k, *v),
6748 _ => continue,
6749 };
6750 let q_n = match dst_to_idx.get(&q_off).copied() {
6752 Some(x) => x,
6753 None => continue,
6754 };
6755 let k_n = match dst_to_idx.get(&k_off).copied() {
6756 Some(x) => x,
6757 None => continue,
6758 };
6759 let v_n = match dst_to_idx.get(&v_off).copied() {
6760 Some(x) => x,
6761 None => continue,
6762 };
6763 if read_counts.get(&q_off).copied().unwrap_or(0) != 1 {
6765 continue;
6766 }
6767 if read_counts.get(&k_off).copied().unwrap_or(0) != 1 {
6768 continue;
6769 }
6770 if read_counts.get(&v_off).copied().unwrap_or(0) != 1 {
6771 continue;
6772 }
6773
6774 let (q_src, q_stride) = match &thunks[q_n] {
6775 Thunk::Narrow {
6776 src, src_stride, ..
6777 } => (*src, *src_stride),
6778 _ => continue,
6779 };
6780 let (k_src, k_stride) = match &thunks[k_n] {
6781 Thunk::Narrow {
6782 src, src_stride, ..
6783 } => (*src, *src_stride),
6784 _ => continue,
6785 };
6786 let (v_src, v_stride) = match &thunks[v_n] {
6787 Thunk::Narrow {
6788 src, src_stride, ..
6789 } => (*src, *src_stride),
6790 _ => continue,
6791 };
6792
6793 if let Thunk::Attention {
6794 q,
6795 k,
6796 v,
6797 q_row_stride,
6798 k_row_stride,
6799 v_row_stride,
6800 ..
6801 } = &mut thunks[i]
6802 {
6803 *q = q_src;
6804 *k = k_src;
6805 *v = v_src;
6806 *q_row_stride = q_stride;
6807 *k_row_stride = k_stride;
6808 *v_row_stride = v_stride;
6809 }
6810 thunks[q_n] = Thunk::Nop;
6811 thunks[k_n] = Thunk::Nop;
6812 thunks[v_n] = Thunk::Nop;
6813 fused_count += 1;
6814 }
6815
6816 if fused_count > 0 && cfg.verbose >= 1 {
6817 eprintln!(
6818 "[rlx] fused_strided_attn: {} Narrow×3→Attention rewrites",
6819 fused_count
6820 );
6821 }
6822 }
6823
6824 ThunkSchedule {
6825 thunks,
6826 moe_resident: None,
6827 moe_resident_layers: None,
6828 moe_topk_capture: None,
6829 mask_threshold: cfg.mask_binary_threshold,
6830 mask_neg_inf: cfg.attn_mask_neg_inf,
6831 score_skip: cfg.score_skip_threshold,
6832 compiled_fns,
6833 }
6834}
6835
6836fn get_len(graph: &Graph, id: NodeId) -> usize {
6837 graph.node(id).shape.num_elements().unwrap_or(0)
6838}
6839
6840fn get_static_dims(graph: &Graph, id: NodeId) -> Vec<usize> {
6842 let dims = graph.node(id).shape.dims();
6843 let mut out = Vec::with_capacity(dims.len());
6844 for d in dims {
6845 if let Some(s) = match d {
6846 rlx_ir::Dim::Static(s) => Some(*s),
6847 _ => None,
6848 } {
6849 out.push(s);
6850 } else {
6851 return Vec::new();
6852 }
6853 }
6854 out
6855}
6856
6857fn is_trailing_bias_broadcast(rhs_dims: &[rlx_ir::Dim], out_dims: &[rlx_ir::Dim]) -> bool {
6875 if rhs_dims.len() > out_dims.len() {
6876 return false;
6877 }
6878 let off = out_dims.len() - rhs_dims.len();
6879 for i in 0..rhs_dims.len() {
6880 let r = match rhs_dims[i] {
6881 rlx_ir::Dim::Static(n) => n,
6882 _ => return false,
6883 };
6884 let o = match out_dims[off + i] {
6885 rlx_ir::Dim::Static(n) => n,
6886 _ => return false,
6887 };
6888 if r != o {
6889 return false;
6890 }
6891 }
6892 true
6893}
6894
/// Computes per-axis element strides for reading a tensor of shape `in_dims`
/// as though it had been broadcast to `out_dims`.
///
/// `in_dims` is right-aligned against `out_dims`. Output axes that have no
/// input counterpart, or whose input extent is 1, get stride 0 (the same
/// element is reused along that axis). Every other axis gets the contiguous
/// row-major stride of the input, i.e. the product of the input extents to
/// its right.
///
/// # Panics
///
/// - if `in_dims` has higher rank than `out_dims`;
/// - if a non-1 input extent differs from the matching output extent;
/// - if a stride does not fit in `u32` (previously this was silently
///   truncated by an `as` cast, producing wrong strides).
fn broadcast_strides(in_dims: &[usize], out_dims: &[usize]) -> Vec<u32> {
    let r_out = out_dims.len();
    let r_in = in_dims.len();
    assert!(
        r_in <= r_out,
        "broadcast: input rank {r_in} > output rank {r_out}"
    );
    let pad = r_out - r_in;
    // Axes in `0..pad` have no input counterpart and keep their stride of 0.
    let mut strides = vec![0u32; r_out];
    // Accumulate in u64 with saturation so an oversized product can never
    // wrap around to a small, plausible-looking stride on any target.
    let mut acc: u64 = 1;
    for d in (pad..r_out).rev() {
        let in_size = in_dims[d - pad];
        if in_size == 1 {
            // Broadcast axis: stride stays 0 and the running product is unchanged.
            continue;
        }
        assert_eq!(
            in_size, out_dims[d],
            "broadcast: input dim {in_size} doesn't match output dim {} at axis {d}",
            out_dims[d]
        );
        strides[d] = u32::try_from(acc).unwrap_or_else(|_| {
            panic!("broadcast: stride {acc} at axis {d} does not fit in u32")
        });
        acc = acc.saturating_mul(in_size as u64);
    }
    strides
}
6921
6922pub fn execute_compiled(schedule: &ThunkSchedule, arena_buf: &mut [u8]) {
6926 let base = arena_buf.as_mut_ptr();
6927 for f in &schedule.compiled_fns {
6928 f(base);
6929 }
6930}
6931
6932pub fn execute_thunks_active(
6937 schedule: &ThunkSchedule,
6938 _arena_buf: &mut [u8],
6939 _actual: usize,
6940 _upper: usize,
6941) -> bool {
6942 let _ = schedule;
6943 false
6944}
6945
6946struct MoeResidencyGuard;
6948impl Drop for MoeResidencyGuard {
6949 fn drop(&mut self) {
6950 if let Some(stats) = crate::moe_residency::take_stats() {
6951 crate::moe_residency::stash_last_forward_stats(stats);
6952 } else {
6953 crate::moe_residency::clear_mask();
6954 }
6955 }
6956}
6957
6958pub fn execute_thunks(schedule: &ThunkSchedule, arena_buf: &mut [u8]) {
6959 crate::moe_residency::reset_gmm_counters();
6960 if let Some(layers) = schedule.moe_resident_layers.clone() {
6961 crate::moe_residency::set_per_layer_masks(Some(layers));
6962 } else {
6963 crate::moe_residency::set_mask(schedule.moe_resident.clone());
6964 }
6965 if let Some(cap) = schedule.moe_topk_capture.as_ref() {
6966 cap.clear();
6967 }
6968 let _moe_guard = MoeResidencyGuard;
6969 let base = arena_buf.as_mut_ptr();
6970 let mask_thr = schedule.mask_threshold;
6971 let mask_neg = schedule.mask_neg_inf;
6972 let score_thr = schedule.score_skip;
6973 let thunks = &schedule.thunks;
6974 let len = thunks.len();
6975
6976 let max_h = thunks
6978 .iter()
6979 .filter_map(|t| match t {
6980 Thunk::FusedResidualLN { h, .. }
6981 | Thunk::FusedResidualRmsNorm { h, .. }
6982 | Thunk::LayerNorm { h, .. } => Some(*h as usize),
6983 _ => None,
6984 })
6985 .max()
6986 .unwrap_or(0);
6987 let zero_bias = vec![0f32; max_h];
6988
6989 let max_sdpa = thunks
6992 .iter()
6993 .filter_map(|t| match t {
6994 Thunk::Attention {
6995 batch,
6996 seq,
6997 kv_seq,
6998 heads,
6999 head_dim,
7000 ..
7001 } => Some((
7002 *batch as usize,
7003 (*seq as usize).max(*kv_seq as usize),
7004 *heads as usize,
7005 *head_dim as usize,
7006 )),
7007 _ => None,
7008 })
7009 .fold((0, 0, 0, 0), |(mb, ms, mh, md), (b, s, h, d)| {
7010 (mb.max(b), ms.max(s), mh.max(h), md.max(d))
7011 });
7012 let (max_batch, max_seq, max_heads, _max_dh) = max_sdpa;
7013 let max_units = max_batch * max_heads;
7014 let mut sdpa_scores = vec![0f32; max_units * max_seq * max_seq];
7015
7016 let fl = thunks
7018 .iter()
7019 .filter_map(|t| match t {
7020 Thunk::FusedBertLayer {
7021 batch,
7022 seq,
7023 hs,
7024 int_dim,
7025 ..
7026 } => {
7027 let m = (*batch as usize) * (*seq as usize);
7028 let h = *hs as usize;
7029 let id = *int_dim as usize;
7030 Some((m, h, id, m * (*seq as usize)))
7031 }
7032 Thunk::FusedNomicLayer {
7033 batch,
7034 seq,
7035 hs,
7036 int_dim,
7037 ..
7038 } => {
7039 let m = (*batch as usize) * (*seq as usize);
7040 let h = *hs as usize;
7041 let id = *int_dim as usize;
7042 Some((m, h, id, m * (*seq as usize)))
7043 }
7044 _ => None,
7045 })
7046 .fold((0, 0, 0, 0), |(mm, mh, mi, ms), (m, h, id, ss)| {
7047 (mm.max(m), mh.max(h), mi.max(id), ms.max(ss))
7048 });
7049 let (fl_m, fl_h, fl_int, fl_ss) = fl;
7050 let mut fl_qkv = vec![0f32; fl_m * 3 * fl_h];
7051 let mut fl_attn = vec![0f32; fl_m * fl_h];
7052 let mut fl_res = vec![0f32; fl_m * fl_h];
7053 let mut fl_normed = vec![0f32; fl_m * fl_h];
7054 let mut fl_ffn = vec![0f32; fl_m * fl_int.max(2 * fl_int)]; let mut fl_sc = vec![0f32; fl_ss.max(1)];
7056
7057 for i in 0..len {
7058 let thunk = unsafe { thunks.get_unchecked(i) };
7059 match thunk {
7060 Thunk::Nop => {}
7061
7062 Thunk::GaussianSplatRender {
7063 positions_off,
7064 positions_len,
7065 scales_off,
7066 scales_len,
7067 rotations_off,
7068 rotations_len,
7069 opacities_off,
7070 opacities_len,
7071 colors_off,
7072 colors_len,
7073 sh_coeffs_off,
7074 sh_coeffs_len,
7075 meta_off,
7076 dst_off,
7077 dst_len,
7078 width,
7079 height,
7080 tile_size,
7081 radius_scale,
7082 alpha_cutoff,
7083 max_splat_steps,
7084 transmittance_threshold,
7085 max_list_entries,
7086 } => unsafe {
7087 crate::splat::execute_gaussian_splat_render(
7088 *positions_off,
7089 *positions_len,
7090 *scales_off,
7091 *scales_len,
7092 *rotations_off,
7093 *rotations_len,
7094 *opacities_off,
7095 *opacities_len,
7096 *colors_off,
7097 *colors_len,
7098 *sh_coeffs_off,
7099 *sh_coeffs_len,
7100 *meta_off,
7101 *dst_off,
7102 *dst_len,
7103 *width,
7104 *height,
7105 *tile_size,
7106 *radius_scale,
7107 *alpha_cutoff,
7108 *max_splat_steps,
7109 *transmittance_threshold,
7110 *max_list_entries,
7111 base,
7112 );
7113 },
7114
7115 Thunk::GaussianSplatRenderBackward {
7116 positions_off,
7117 positions_len,
7118 scales_off,
7119 scales_len,
7120 rotations_off,
7121 rotations_len,
7122 opacities_off,
7123 opacities_len,
7124 colors_off,
7125 colors_len,
7126 sh_coeffs_off,
7127 sh_coeffs_len,
7128 meta_off,
7129 d_loss_off,
7130 d_loss_len,
7131 packed_off,
7132 packed_len,
7133 width,
7134 height,
7135 tile_size,
7136 radius_scale,
7137 alpha_cutoff,
7138 max_splat_steps,
7139 transmittance_threshold,
7140 max_list_entries,
7141 loss_grad_clip,
7142 sh_band,
7143 max_anisotropy,
7144 } => unsafe {
7145 crate::splat::execute_gaussian_splat_render_backward(
7146 *positions_off,
7147 *positions_len,
7148 *scales_off,
7149 *scales_len,
7150 *rotations_off,
7151 *rotations_len,
7152 *opacities_off,
7153 *opacities_len,
7154 *colors_off,
7155 *colors_len,
7156 *sh_coeffs_off,
7157 *sh_coeffs_len,
7158 *meta_off,
7159 *d_loss_off,
7160 *d_loss_len,
7161 *packed_off,
7162 *packed_len,
7163 *width,
7164 *height,
7165 *tile_size,
7166 *radius_scale,
7167 *alpha_cutoff,
7168 *max_splat_steps,
7169 *transmittance_threshold,
7170 *max_list_entries,
7171 *loss_grad_clip,
7172 *sh_band,
7173 *max_anisotropy,
7174 base,
7175 );
7176 },
7177
7178 Thunk::GaussianSplatPrepare {
7179 positions_off,
7180 positions_len,
7181 scales_off,
7182 scales_len,
7183 rotations_off,
7184 rotations_len,
7185 opacities_off,
7186 opacities_len,
7187 colors_off,
7188 colors_len,
7189 sh_coeffs_off,
7190 sh_coeffs_len,
7191 meta_off,
7192 meta_len,
7193 prep_off,
7194 prep_len,
7195 width,
7196 height,
7197 tile_size,
7198 radius_scale,
7199 alpha_cutoff,
7200 max_splat_steps,
7201 transmittance_threshold,
7202 max_list_entries,
7203 } => unsafe {
7204 crate::splat::execute_gaussian_splat_prepare(
7205 *positions_off,
7206 *positions_len,
7207 *scales_off,
7208 *scales_len,
7209 *rotations_off,
7210 *rotations_len,
7211 *opacities_off,
7212 *opacities_len,
7213 *colors_off,
7214 *colors_len,
7215 *sh_coeffs_off,
7216 *sh_coeffs_len,
7217 *meta_off,
7218 *meta_len,
7219 *prep_off,
7220 *prep_len,
7221 *width,
7222 *height,
7223 *tile_size,
7224 *radius_scale,
7225 *alpha_cutoff,
7226 *max_splat_steps,
7227 *transmittance_threshold,
7228 *max_list_entries,
7229 base,
7230 );
7231 },
7232
7233 Thunk::GaussianSplatRasterize {
7234 prep_off,
7235 prep_len,
7236 meta_off,
7237 meta_len,
7238 dst_off,
7239 dst_len,
7240 count,
7241 width,
7242 height,
7243 tile_size,
7244 alpha_cutoff,
7245 max_splat_steps,
7246 transmittance_threshold,
7247 max_list_entries,
7248 } => unsafe {
7249 crate::splat::execute_gaussian_splat_rasterize(
7250 *prep_off,
7251 *prep_len,
7252 *meta_off,
7253 *meta_len,
7254 *dst_off,
7255 *dst_len,
7256 *count,
7257 *width,
7258 *height,
7259 *tile_size,
7260 *alpha_cutoff,
7261 *max_splat_steps,
7262 *transmittance_threshold,
7263 *max_list_entries,
7264 base,
7265 );
7266 },
7267
7268 Thunk::Fft1d {
7269 src,
7270 dst,
7271 outer,
7272 n_complex,
7273 inverse,
7274 norm_tag,
7275 dtype,
7276 } => unsafe {
7277 match dtype {
7278 rlx_ir::DType::F64 => execute_fft1d_f64(
7279 *src,
7280 *dst,
7281 *outer as usize,
7282 *n_complex as usize,
7283 *inverse,
7284 *norm_tag,
7285 base,
7286 ),
7287 rlx_ir::DType::F32 => execute_fft1d_f32(
7288 *src,
7289 *dst,
7290 *outer as usize,
7291 *n_complex as usize,
7292 *inverse,
7293 *norm_tag,
7294 base,
7295 ),
7296 rlx_ir::DType::C64 => execute_fft1d_c64(
7297 *src,
7298 *dst,
7299 *outer as usize,
7300 *n_complex as usize,
7301 *inverse,
7302 *norm_tag,
7303 base,
7304 ),
7305 other => panic!("Op::Fft on CPU requires F32/F64/C64, got {other:?}"),
7306 }
7307 },
7308
7309 Thunk::CustomFn {
7313 body,
7314 body_init,
7315 inputs,
7316 body_output_off,
7317 outer_output_off,
7318 out_bytes,
7319 } => {
7320 let mut body_buf: Vec<u8> = (**body_init).clone();
7321 unsafe {
7322 for (body_in_off, outer_in_off, n_bytes) in inputs.iter() {
7323 let src = (base as *const u8).add(*outer_in_off);
7324 let dst = body_buf.as_mut_ptr().add(*body_in_off);
7325 std::ptr::copy_nonoverlapping(src, dst, *n_bytes as usize);
7326 }
7327 }
7328 execute_thunks(body, &mut body_buf);
7329 unsafe {
7330 let src = body_buf.as_ptr().add(*body_output_off);
7331 let dst = base.add(*outer_output_off);
7332 std::ptr::copy_nonoverlapping(src, dst, *out_bytes as usize);
7333 }
7334 }
7335
7336 Thunk::Sgemm { a, b, c, m, k, n } => {
7337 let (m, k, n) = (*m as usize, *k as usize, *n as usize);
7338 unsafe {
7339 crate::blas::sgemm_auto(
7340 sl(*a, base, m * k),
7341 sl(*b, base, k * n),
7342 sl_mut(*c, base, m * n),
7343 m,
7344 k,
7345 n,
7346 );
7347 }
7348 }
7349
7350 Thunk::DenseSolveF64 { a, b, x, n, nrhs } => {
7351 let (n_, nrhs_) = (*n as usize, *nrhs as usize);
7352 unsafe {
7358 let a_src = sl_f64(*a, base, n_ * n_);
7359 let b_src = sl_f64(*b, base, n_ * nrhs_);
7360 let mut a_scratch: Vec<f64> = a_src.to_vec();
7361 let mut x_buf: Vec<f64> = b_src.to_vec();
7362 let info = crate::blas::dgesv(&mut a_scratch, &mut x_buf, n_, nrhs_);
7363 if info != 0 {
7364 panic!(
7365 "DenseSolveF64: dgesv reported singular matrix \
7366 (info={info}, n={n_}, nrhs={nrhs_})"
7367 );
7368 }
7369 let dst = sl_mut_f64(*x, base, n_ * nrhs_);
7370 dst.copy_from_slice(&x_buf);
7371 }
7372 }
7373
7374 Thunk::DenseSolveF32 { a, b, x, n, nrhs } => {
7375 let (n_, nrhs_) = (*n as usize, *nrhs as usize);
7376 unsafe {
7377 let a_src = sl(*a, base, n_ * n_);
7378 let b_src = sl(*b, base, n_ * nrhs_);
7379 let mut a_scratch: Vec<f32> = a_src.to_vec();
7380 let mut x_buf: Vec<f32> = b_src.to_vec();
7381 let info = crate::blas::sgesv(&mut a_scratch, &mut x_buf, n_, nrhs_);
7382 if info != 0 {
7383 panic!(
7384 "DenseSolveF32: sgesv reported singular matrix \
7385 (info={info}, n={n_}, nrhs={nrhs_})"
7386 );
7387 }
7388 let dst = sl_mut(*x, base, n_ * nrhs_);
7389 dst.copy_from_slice(&x_buf);
7390 }
7391 }
7392
7393 Thunk::BatchedDenseSolveF64 {
7394 a,
7395 b,
7396 x,
7397 batch,
7398 n,
7399 nrhs,
7400 } => {
7401 let (b_, n_, nrhs_) = (*batch as usize, *n as usize, *nrhs as usize);
7408 let a_stride = n_ * n_;
7409 let b_stride = n_ * nrhs_;
7410 unsafe {
7411 let a_full = sl_f64(*a, base, b_ * a_stride);
7412 let b_full = sl_f64(*b, base, b_ * b_stride);
7413 let x_full = sl_mut_f64(*x, base, b_ * b_stride);
7414 for bi in 0..b_ {
7415 let mut a_scratch: Vec<f64> =
7416 a_full[bi * a_stride..(bi + 1) * a_stride].to_vec();
7417 let mut x_buf: Vec<f64> =
7418 b_full[bi * b_stride..(bi + 1) * b_stride].to_vec();
7419 let info = crate::blas::dgesv(&mut a_scratch, &mut x_buf, n_, nrhs_);
7420 if info != 0 {
7421 panic!(
7422 "BatchedDenseSolveF64: slice {bi} \
7423 singular (info={info}, n={n_}, nrhs={nrhs_})"
7424 );
7425 }
7426 x_full[bi * b_stride..(bi + 1) * b_stride].copy_from_slice(&x_buf);
7427 }
7428 }
7429 }
7430
7431 Thunk::BatchedDenseSolveF32 {
7432 a,
7433 b,
7434 x,
7435 batch,
7436 n,
7437 nrhs,
7438 } => {
7439 let (b_, n_, nrhs_) = (*batch as usize, *n as usize, *nrhs as usize);
7440 let a_stride = n_ * n_;
7441 let b_stride = n_ * nrhs_;
7442 unsafe {
7443 let a_full = sl(*a, base, b_ * a_stride);
7444 let b_full = sl(*b, base, b_ * b_stride);
7445 let x_full = sl_mut(*x, base, b_ * b_stride);
7446 for bi in 0..b_ {
7447 let mut a_scratch = a_full[bi * a_stride..(bi + 1) * a_stride].to_vec();
7448 let mut x_buf = b_full[bi * b_stride..(bi + 1) * b_stride].to_vec();
7449 let info = crate::blas::sgesv(&mut a_scratch, &mut x_buf, n_, nrhs_);
7450 if info != 0 {
7451 panic!("BatchedDenseSolveF32: slice {bi} singular (info={info})");
7452 }
7453 x_full[bi * b_stride..(bi + 1) * b_stride].copy_from_slice(&x_buf);
7454 }
7455 }
7456 }
7457
7458 Thunk::BatchedDgemmF64 {
7459 a,
7460 b,
7461 c,
7462 batch,
7463 m,
7464 k,
7465 n,
7466 } => {
7467 let (b_, m_, k_, n_) = (*batch as usize, *m as usize, *k as usize, *n as usize);
7468 let a_stride = m_ * k_;
7469 let b_stride = k_ * n_;
7470 let c_stride = m_ * n_;
7471 unsafe {
7472 let a_full = sl_f64(*a, base, b_ * a_stride);
7473 let b_full = sl_f64(*b, base, b_ * b_stride);
7474 let c_full = sl_mut_f64(*c, base, b_ * c_stride);
7475 for bi in 0..b_ {
7476 let a_slice = &a_full[bi * a_stride..(bi + 1) * a_stride];
7477 let b_slice = &b_full[bi * b_stride..(bi + 1) * b_stride];
7478 let c_slice = &mut c_full[bi * c_stride..(bi + 1) * c_stride];
7479 crate::blas::dgemm(a_slice, b_slice, c_slice, m_, k_, n_);
7480 }
7481 }
7482 }
7483
7484 Thunk::BatchedSgemm {
7485 a,
7486 b,
7487 c,
7488 batch,
7489 m,
7490 k,
7491 n,
7492 } => {
7493 let (b_, m_, k_, n_) = (*batch as usize, *m as usize, *k as usize, *n as usize);
7494 let a_stride = m_ * k_;
7495 let b_stride = k_ * n_;
7496 let c_stride = m_ * n_;
7497 unsafe {
7498 let a_full = sl(*a, base, b_ * a_stride);
7499 let b_full = sl(*b, base, b_ * b_stride);
7500 let c_full = sl_mut(*c, base, b_ * c_stride);
7501 for bi in 0..b_ {
7502 let a_slice = &a_full[bi * a_stride..(bi + 1) * a_stride];
7503 let b_slice = &b_full[bi * b_stride..(bi + 1) * b_stride];
7504 let c_slice = &mut c_full[bi * c_stride..(bi + 1) * c_stride];
7505 crate::blas::sgemm_auto(a_slice, b_slice, c_slice, m_, k_, n_);
7506 }
7507 }
7508 }
7509
7510 Thunk::Dgemm { a, b, c, m, k, n } => {
7511 let (m, k, n) = (*m as usize, *k as usize, *n as usize);
7512 unsafe {
7513 crate::blas::dgemm(
7514 sl_f64(*a, base, m * k),
7515 sl_f64(*b, base, k * n),
7516 sl_mut_f64(*c, base, m * n),
7517 m,
7518 k,
7519 n,
7520 );
7521 }
7522 }
7523
7524 Thunk::TransposeF64 {
7525 src,
7526 dst,
7527 in_total,
7528 out_dims,
7529 in_strides,
7530 } => unsafe {
7531 let inp = sl_f64(*src, base, *in_total as usize);
7532 let out_total: usize = out_dims.iter().map(|d| *d as usize).product();
7533 let out = sl_mut_f64(*dst, base, out_total);
7534 transpose_walk_f64(inp, out, out_dims, in_strides);
7535 },
7536
7537 Thunk::ActivationF64 {
7538 src,
7539 dst,
7540 len,
7541 kind,
7542 } => {
7543 let len = *len as usize;
7544 unsafe {
7545 let inp = sl_f64(*src, base, len);
7546 let out = sl_mut_f64(*dst, base, len);
7547 apply_activation_f64(inp, out, *kind);
7548 }
7549 }
7550
7551 Thunk::ReduceSumF64 {
7552 src,
7553 dst,
7554 outer,
7555 reduced,
7556 inner,
7557 } => {
7558 let (o, r, n) = (*outer as usize, *reduced as usize, *inner as usize);
7559 unsafe {
7560 let inp = sl_f64(*src, base, o * r * n);
7561 let out = sl_mut_f64(*dst, base, o * n);
7562 reduce_sum_f64(inp, out, o, r, n);
7563 }
7564 }
7565
7566 Thunk::CopyF64 { src, dst, len } => {
7567 let len = *len as usize;
7568 if *src == *dst { } else {
7570 unsafe {
7571 let s = sl_f64(*src, base, len);
7572 let d = sl_mut_f64(*dst, base, len);
7573 d.copy_from_slice(s);
7574 }
7575 }
7576 }
7577
7578 Thunk::BinaryFullF64 {
7579 lhs,
7580 rhs,
7581 dst,
7582 len,
7583 lhs_len,
7584 rhs_len,
7585 op,
7586 out_dims_bcast,
7587 bcast_lhs_strides,
7588 bcast_rhs_strides,
7589 } => {
7590 let len = *len as usize;
7591 let lhs_len = *lhs_len as usize;
7592 let rhs_len = *rhs_len as usize;
7593 unsafe {
7594 let l = sl_f64(*lhs, base, lhs_len);
7595 let r = sl_f64(*rhs, base, rhs_len);
7596 let d = sl_mut_f64(*dst, base, len);
7597 if lhs_len == len && rhs_len == len {
7598 for i in 0..len {
7599 d[i] = binary_op_f64(*op, l[i], r[i]);
7600 }
7601 } else if !out_dims_bcast.is_empty() {
7602 let rank = out_dims_bcast.len();
7606 let mut coords = vec![0u32; rank];
7607 for i in 0..len {
7608 let mut rem = i;
7609 for ax in (0..rank).rev() {
7610 let sz = out_dims_bcast[ax] as usize;
7611 coords[ax] = (rem % sz) as u32;
7612 rem /= sz;
7613 }
7614 let mut li: usize = 0;
7615 let mut ri: usize = 0;
7616 for ax in 0..rank {
7617 li += coords[ax] as usize * bcast_lhs_strides[ax] as usize;
7618 ri += coords[ax] as usize * bcast_rhs_strides[ax] as usize;
7619 }
7620 d[i] = binary_op_f64(*op, l[li], r[ri]);
7621 }
7622 } else {
7623 for i in 0..len {
7628 d[i] = binary_op_f64(*op, l[i % lhs_len], r[i % rhs_len]);
7629 }
7630 }
7631 }
7632 }
7633
7634 Thunk::BinaryFullC64 {
7635 lhs,
7636 rhs,
7637 dst,
7638 len,
7639 lhs_len,
7640 rhs_len,
7641 op,
7642 out_dims_bcast,
7643 bcast_lhs_strides,
7644 bcast_rhs_strides,
7645 } => {
7646 let n_out = *len as usize;
7652 let n_l = *lhs_len as usize;
7653 let n_r = *rhs_len as usize;
7654 unsafe {
7655 let l = sl(*lhs, base, 2 * n_l);
7656 let r = sl(*rhs, base, 2 * n_r);
7657 let d = sl_mut(*dst, base, 2 * n_out);
7658 let do_c64 = |a_re: f32, a_im: f32, b_re: f32, b_im: f32| -> (f32, f32) {
7659 match op {
7660 BinaryOp::Add => (a_re + b_re, a_im + b_im),
7661 BinaryOp::Sub => (a_re - b_re, a_im - b_im),
7662 BinaryOp::Mul => (a_re * b_re - a_im * b_im, a_re * b_im + a_im * b_re),
7663 BinaryOp::Div => {
7664 let denom = b_re * b_re + b_im * b_im;
7665 (
7666 (a_re * b_re + a_im * b_im) / denom,
7667 (a_im * b_re - a_re * b_im) / denom,
7668 )
7669 }
7670 BinaryOp::Max | BinaryOp::Min | BinaryOp::Pow => {
7671 unreachable!("C64 max/min/pow rejected at lowering")
7672 }
7673 }
7674 };
7675 if n_l == n_out && n_r == n_out {
7676 for i in 0..n_out {
7677 let (re, im) = do_c64(l[2 * i], l[2 * i + 1], r[2 * i], r[2 * i + 1]);
7678 d[2 * i] = re;
7679 d[2 * i + 1] = im;
7680 }
7681 } else if !out_dims_bcast.is_empty() {
7682 let rank = out_dims_bcast.len();
7686 let mut coords = vec![0u32; rank];
7687 for i in 0..n_out {
7688 let mut rem = i;
7689 for ax in (0..rank).rev() {
7690 let sz = out_dims_bcast[ax] as usize;
7691 coords[ax] = (rem % sz) as u32;
7692 rem /= sz;
7693 }
7694 let mut li: usize = 0;
7695 let mut ri: usize = 0;
7696 for ax in 0..rank {
7697 li += coords[ax] as usize * bcast_lhs_strides[ax] as usize;
7698 ri += coords[ax] as usize * bcast_rhs_strides[ax] as usize;
7699 }
7700 let (re, im) =
7701 do_c64(l[2 * li], l[2 * li + 1], r[2 * ri], r[2 * ri + 1]);
7702 d[2 * i] = re;
7703 d[2 * i + 1] = im;
7704 }
7705 } else {
7706 for i in 0..n_out {
7708 let li = if n_l == 1 { 0 } else { i % n_l };
7709 let ri = if n_r == 1 { 0 } else { i % n_r };
7710 let (re, im) =
7711 do_c64(l[2 * li], l[2 * li + 1], r[2 * ri], r[2 * ri + 1]);
7712 d[2 * i] = re;
7713 d[2 * i + 1] = im;
7714 }
7715 }
7716 }
7717 }
7718
7719 Thunk::ComplexNormSqF32 { src, dst, len } => {
7720 let n = *len as usize;
7721 unsafe {
7722 let s = sl(*src, base, 2 * n);
7723 let d = sl_mut(*dst, base, n);
7724 for i in 0..n {
7725 let re = s[2 * i];
7726 let im = s[2 * i + 1];
7727 d[i] = re * re + im * im;
7728 }
7729 }
7730 }
7731
7732 Thunk::ComplexNormSqBackwardF32 { z, g, dz, len } => {
7733 let n = *len as usize;
7736 unsafe {
7737 let zb = sl(*z, base, 2 * n);
7738 let gb = sl(*g, base, n);
7739 let db = sl_mut(*dz, base, 2 * n);
7740 for i in 0..n {
7741 let re = zb[2 * i];
7742 let im = zb[2 * i + 1];
7743 let gv = gb[i];
7744 db[2 * i] = gv * re;
7745 db[2 * i + 1] = gv * im;
7746 }
7747 }
7748 }
7749
7750 Thunk::ConjugateC64 { src, dst, len } => {
7751 let n = *len as usize;
7752 unsafe {
7753 let s = sl(*src, base, 2 * n);
7754 let d = sl_mut(*dst, base, 2 * n);
7755 for i in 0..n {
7756 d[2 * i] = s[2 * i];
7757 d[2 * i + 1] = -s[2 * i + 1];
7758 }
7759 }
7760 }
7761
7762 Thunk::ActivationC64 {
7763 src,
7764 dst,
7765 len,
7766 kind,
7767 } => {
7768 let n = *len as usize;
7769 unsafe {
7770 let s = sl(*src, base, 2 * n);
7771 let d = sl_mut(*dst, base, 2 * n);
7772 for i in 0..n {
7773 let a = s[2 * i];
7774 let b = s[2 * i + 1];
7775 let (re, im) = match kind {
7776 Activation::Neg => (-a, -b),
7777 Activation::Exp => {
7778 let ea = a.exp();
7780 (ea * b.cos(), ea * b.sin())
7781 }
7782 Activation::Log => {
7783 let r = (a * a + b * b).sqrt();
7785 (r.ln(), b.atan2(a))
7786 }
7787 Activation::Sqrt => {
7788 let r = (a * a + b * b).sqrt();
7791 let re = ((r + a) * 0.5).max(0.0).sqrt();
7792 let im_mag = ((r - a) * 0.5).max(0.0).sqrt();
7793 let im = if b >= 0.0 { im_mag } else { -im_mag };
7794 (re, im)
7795 }
7796 _ => unreachable!("non-C64 activation kind survived lowering"),
7797 };
7798 d[2 * i] = re;
7799 d[2 * i + 1] = im;
7800 }
7801 }
7802 }
7803
7804 Thunk::Scan {
7805 body,
7806 body_init,
7807 body_input_off,
7808 body_output_off,
7809 outer_init_off,
7810 outer_final_off,
7811 length,
7812 carry_bytes,
7813 save_trajectory,
7814 xs_inputs,
7815 bcast_inputs,
7816 num_checkpoints,
7817 } => {
7818 let cb = *carry_bytes as usize;
7819 let n_steps = *length as usize;
7820 let k_total = if *num_checkpoints == 0 || *num_checkpoints == *length {
7824 n_steps } else {
7826 *num_checkpoints as usize
7827 };
7828 let checkpoint_t_for_k = |k: usize| -> usize {
7829 if k_total == n_steps {
7830 k
7831 } else {
7832 ((k + 1) * n_steps)
7833 .div_ceil(k_total)
7834 .saturating_sub(1)
7835 .min(n_steps - 1)
7836 }
7837 };
7838 let mut next_k = 0usize;
7839
7840 let mut body_buf: Vec<u8> = (**body_init).clone();
7841 unsafe {
7842 std::ptr::copy_nonoverlapping(
7843 base.add(*outer_init_off),
7844 body_buf.as_mut_ptr().add(*body_input_off),
7845 cb,
7846 );
7847 for (body_b_off, outer_b_off, total_bytes) in bcast_inputs.iter() {
7851 std::ptr::copy_nonoverlapping(
7852 base.add(*outer_b_off),
7853 body_buf.as_mut_ptr().add(*body_b_off),
7854 *total_bytes as usize,
7855 );
7856 }
7857 }
7858 for t in 0..n_steps {
7859 for (body_x_off, outer_xs_off, per_step_bytes) in xs_inputs.iter() {
7860 let psb = *per_step_bytes as usize;
7861 unsafe {
7862 std::ptr::copy_nonoverlapping(
7863 base.add(*outer_xs_off + t * psb),
7864 body_buf.as_mut_ptr().add(*body_x_off),
7865 psb,
7866 );
7867 }
7868 }
7869
7870 execute_thunks(body, &mut body_buf);
7871
7872 if *save_trajectory && next_k < k_total && t == checkpoint_t_for_k(next_k) {
7873 unsafe {
7874 std::ptr::copy_nonoverlapping(
7875 body_buf.as_ptr().add(*body_output_off),
7876 base.add(*outer_final_off + next_k * cb),
7877 cb,
7878 );
7879 }
7880 next_k += 1;
7881 }
7882
7883 if *body_output_off != *body_input_off {
7884 body_buf
7885 .copy_within(*body_output_off..*body_output_off + cb, *body_input_off);
7886 }
7887 }
7888
7889 if !*save_trajectory {
7890 unsafe {
7892 std::ptr::copy_nonoverlapping(
7893 body_buf.as_ptr().add(*body_output_off),
7894 base.add(*outer_final_off),
7895 cb,
7896 );
7897 }
7898 }
7899 }
7900
7901 Thunk::ScanBackward {
7902 body_vjp,
7903 body_init,
7904 body_carry_in_off,
7905 body_x_offs,
7906 body_d_output_off,
7907 body_dcarry_out_off,
7908 outer_init_off,
7909 outer_traj_off,
7910 outer_upstream_off,
7911 outer_xs_offs,
7912 outer_dinit_off,
7913 length,
7914 carry_bytes,
7915 save_trajectory,
7916 num_checkpoints,
7917 forward_body,
7918 forward_body_init,
7919 forward_body_carry_in_off,
7920 forward_body_output_off,
7921 forward_body_x_offs,
7922 carry_elem_size,
7923 } => {
7924 let cb = *carry_bytes as usize;
7937 let n_steps = *length as usize;
7938 let k_total = *num_checkpoints as usize;
7939 let is_recursive = k_total != 0 && k_total != n_steps;
7940 let checkpoint_t_for_k = |k: usize| -> usize {
7941 ((k + 1) * n_steps)
7942 .div_ceil(k_total)
7943 .saturating_sub(1)
7944 .min(n_steps - 1)
7945 };
7946
7947 let mut fwd_buf: Vec<u8> = if is_recursive {
7948 (**forward_body_init.as_ref().unwrap()).clone()
7949 } else {
7950 Vec::new()
7951 };
7952
7953 let mut dcarry: Vec<u8> = vec![0u8; cb];
7954 if !*save_trajectory {
7955 unsafe {
7956 std::ptr::copy_nonoverlapping(
7957 base.add(*outer_upstream_off),
7958 dcarry.as_mut_ptr(),
7959 cb,
7960 );
7961 }
7962 }
7963
7964 let mut body_buf: Vec<u8> = (**body_init).clone();
7965
7966 let process_iter =
7971 |t: usize, carry_in: &[u8], dcarry: &mut Vec<u8>, body_buf: &mut Vec<u8>| {
7972 if *save_trajectory {
7973 unsafe {
7974 let up_off = *outer_upstream_off + t * cb;
7975 match *carry_elem_size {
7976 4 => {
7977 let up_ptr = base.add(up_off) as *const f32;
7978 let dc_ptr = dcarry.as_mut_ptr() as *mut f32;
7979 let n_elems = cb / 4;
7980 for i in 0..n_elems {
7981 *dc_ptr.add(i) += *up_ptr.add(i);
7982 }
7983 }
7984 8 => {
7985 let up_ptr = base.add(up_off) as *const f64;
7986 let dc_ptr = dcarry.as_mut_ptr() as *mut f64;
7987 let n_elems = cb / 8;
7988 for i in 0..n_elems {
7989 *dc_ptr.add(i) += *up_ptr.add(i);
7990 }
7991 }
7992 other => panic!(
7993 "ScanBackward: unsupported carry elem size {other} \
7994 (only f32/f64 carries are supported today)"
7995 ),
7996 }
7997 }
7998 }
7999 body_buf[*body_carry_in_off..*body_carry_in_off + cb]
8000 .copy_from_slice(carry_in);
8001 unsafe {
8002 for (i, body_x_off) in body_x_offs.iter().enumerate() {
8003 let (outer_xs_off, per_step_bytes) = outer_xs_offs[i];
8004 let psb = per_step_bytes as usize;
8005 std::ptr::copy_nonoverlapping(
8006 base.add(outer_xs_off + t * psb),
8007 body_buf.as_mut_ptr().add(*body_x_off),
8008 psb,
8009 );
8010 }
8011 std::ptr::copy_nonoverlapping(
8012 dcarry.as_ptr(),
8013 body_buf.as_mut_ptr().add(*body_d_output_off),
8014 cb,
8015 );
8016 }
8017 execute_thunks(body_vjp, body_buf);
8018 unsafe {
8019 std::ptr::copy_nonoverlapping(
8020 body_buf.as_ptr().add(*body_dcarry_out_off),
8021 dcarry.as_mut_ptr(),
8022 cb,
8023 );
8024 }
8025 };
8026
8027 if is_recursive {
8028 let leaf_threshold = 4usize;
8036 let fb_sched = forward_body.as_ref().unwrap();
8037 let fb_init = forward_body_init.as_ref().unwrap().as_slice();
8038 let mut segment_end = n_steps - 1;
8039 for seg_k in (0..k_total).rev() {
8040 let segment_start = if seg_k == 0 {
8041 0
8042 } else {
8043 checkpoint_t_for_k(seg_k - 1) + 1
8044 };
8045 let mut anchor: Vec<u8> = vec![0u8; cb];
8046 unsafe {
8047 let src = if seg_k == 0 {
8048 base.add(*outer_init_off)
8049 } else {
8050 base.add(*outer_traj_off + (seg_k - 1) * cb)
8051 };
8052 std::ptr::copy_nonoverlapping(src, anchor.as_mut_ptr(), cb);
8053 }
8054 let mut leaf_action = |t: usize, carry_in: &[u8]| {
8057 process_iter(t, carry_in, &mut dcarry, &mut body_buf);
8058 };
8059 unsafe {
8060 griewank_process_segment(
8061 segment_start,
8062 segment_end,
8063 &anchor,
8064 cb,
8065 fb_sched,
8066 fb_init,
8067 *forward_body_carry_in_off,
8068 *forward_body_output_off,
8069 forward_body_x_offs,
8070 base,
8071 outer_xs_offs,
8072 &mut fwd_buf,
8073 leaf_threshold,
8074 &mut leaf_action,
8075 );
8076 }
8077 if seg_k == 0 {
8078 break;
8079 }
8080 segment_end = segment_start - 1;
8081 }
8082 } else {
8083 let mut carry_buf: Vec<u8> = vec![0u8; cb];
8086 for t in (0..n_steps).rev() {
8087 unsafe {
8088 let src = if t == 0 {
8089 base.add(*outer_init_off)
8090 } else {
8091 base.add(*outer_traj_off + (t - 1) * cb)
8092 };
8093 std::ptr::copy_nonoverlapping(src, carry_buf.as_mut_ptr(), cb);
8094 }
8095 process_iter(t, &carry_buf, &mut dcarry, &mut body_buf);
8096 }
8097 }
8098
8099 unsafe {
8100 std::ptr::copy_nonoverlapping(dcarry.as_ptr(), base.add(*outer_dinit_off), cb);
8101 }
8102 }
8103
8104 Thunk::ScanBackwardXs {
8105 body_vjp,
8106 body_init,
8107 body_carry_in_off,
8108 body_x_offs,
8109 body_d_output_off,
8110 body_dcarry_out_off,
8111 body_dxs_out_off,
8112 outer_init_off,
8113 outer_traj_off,
8114 outer_upstream_off,
8115 outer_xs_offs,
8116 outer_dxs_off,
8117 length,
8118 carry_bytes,
8119 carry_elem_size,
8120 per_step_bytes,
8121 save_trajectory,
8122 num_checkpoints,
8123 forward_body,
8124 forward_body_init,
8125 forward_body_carry_in_off,
8126 forward_body_output_off,
8127 forward_body_x_offs,
8128 } => {
8129 let cb = *carry_bytes as usize;
8130 let psb = *per_step_bytes as usize;
8131 let n_steps = *length as usize;
8132 let k_total = *num_checkpoints as usize;
8133 let is_recursive = k_total != 0 && k_total != n_steps;
8134 let checkpoint_t_for_k = |k: usize| -> usize {
8135 ((k + 1) * n_steps)
8136 .div_ceil(k_total)
8137 .saturating_sub(1)
8138 .min(n_steps - 1)
8139 };
8140
8141 let mut fwd_buf: Vec<u8> = if is_recursive {
8145 (**forward_body_init.as_ref().unwrap()).clone()
8146 } else {
8147 Vec::new()
8148 };
8149 let mut seg_cache: Vec<u8> = Vec::new();
8150 let mut seg_start_t: usize = usize::MAX;
8151 let mut seg_count: usize = 0;
8152 let recompute_carry_t =
8153 |t: usize,
8154 dst: &mut [u8],
8155 fwd_buf: &mut Vec<u8>,
8156 seg_cache: &mut Vec<u8>,
8157 seg_start_t: &mut usize,
8158 seg_count: &mut usize| {
8159 if !is_recursive {
8160 unsafe {
8161 let src = if t == 0 {
8162 base.add(*outer_init_off)
8163 } else {
8164 base.add(*outer_traj_off + (t - 1) * cb)
8165 };
8166 std::ptr::copy_nonoverlapping(src, dst.as_mut_ptr(), cb);
8167 }
8168 return;
8169 }
8170 if *seg_start_t != usize::MAX
8171 && t >= *seg_start_t
8172 && t < *seg_start_t + *seg_count
8173 {
8174 let off = (t - *seg_start_t) * cb;
8175 dst.copy_from_slice(&seg_cache[off..off + cb]);
8176 return;
8177 }
8178 let seg_k = (0..k_total)
8179 .find(|&k| t <= checkpoint_t_for_k(k))
8180 .unwrap_or(k_total - 1);
8181 let (anchor_t, anchor_ptr): (usize, *const u8) = if seg_k == 0 {
8182 (0, unsafe { base.add(*outer_init_off) as *const u8 })
8183 } else {
8184 let prev_ck = checkpoint_t_for_k(seg_k - 1);
8185 (prev_ck + 1, unsafe {
8186 base.add(*outer_traj_off + (seg_k - 1) * cb) as *const u8
8187 })
8188 };
8189 let seg_end_t = checkpoint_t_for_k(seg_k);
8190 let seg_size = seg_end_t - anchor_t + 1;
8191
8192 fwd_buf.copy_from_slice(forward_body_init.as_ref().unwrap());
8193 unsafe {
8194 std::ptr::copy_nonoverlapping(
8195 anchor_ptr,
8196 fwd_buf.as_mut_ptr().add(*forward_body_carry_in_off),
8197 cb,
8198 );
8199 }
8200 seg_cache.resize(seg_size * cb, 0u8);
8201 seg_cache[0..cb].copy_from_slice(
8202 &fwd_buf[*forward_body_carry_in_off..*forward_body_carry_in_off + cb],
8203 );
8204 let fb_sched = forward_body.as_ref().unwrap();
8205 for i in 1..seg_size {
8206 let cur_iter = anchor_t + i - 1;
8207 for (idx, fb_x_off) in forward_body_x_offs.iter().enumerate() {
8208 let (outer_xs_off, x_psb) = outer_xs_offs[idx];
8209 let xb = x_psb as usize;
8210 unsafe {
8211 std::ptr::copy_nonoverlapping(
8212 base.add(outer_xs_off + cur_iter * xb),
8213 fwd_buf.as_mut_ptr().add(*fb_x_off),
8214 xb,
8215 );
8216 }
8217 }
8218 execute_thunks(fb_sched, fwd_buf);
8219 if *forward_body_output_off != *forward_body_carry_in_off {
8220 fwd_buf.copy_within(
8221 *forward_body_output_off..*forward_body_output_off + cb,
8222 *forward_body_carry_in_off,
8223 );
8224 }
8225 let cache_off = i * cb;
8226 seg_cache[cache_off..cache_off + cb].copy_from_slice(
8227 &fwd_buf
8228 [*forward_body_carry_in_off..*forward_body_carry_in_off + cb],
8229 );
8230 }
8231 *seg_start_t = anchor_t;
8232 *seg_count = seg_size;
8233
8234 let off = (t - anchor_t) * cb;
8235 dst.copy_from_slice(&seg_cache[off..off + cb]);
8236 };
8237
8238 let mut dcarry: Vec<u8> = vec![0u8; cb];
8239 if !*save_trajectory {
8240 unsafe {
8241 std::ptr::copy_nonoverlapping(
8242 base.add(*outer_upstream_off),
8243 dcarry.as_mut_ptr(),
8244 cb,
8245 );
8246 }
8247 }
8248
8249 let mut body_buf: Vec<u8> = (**body_init).clone();
8250
8251 for t in (0..n_steps).rev() {
8252 if *save_trajectory {
8253 unsafe {
8254 let up_off = *outer_upstream_off + t * cb;
8255 match *carry_elem_size {
8256 4 => {
8257 let up_ptr = base.add(up_off) as *const f32;
8258 let dc_ptr = dcarry.as_mut_ptr() as *mut f32;
8259 let n_elems = cb / 4;
8260 for i in 0..n_elems {
8261 *dc_ptr.add(i) += *up_ptr.add(i);
8262 }
8263 }
8264 8 => {
8265 let up_ptr = base.add(up_off) as *const f64;
8266 let dc_ptr = dcarry.as_mut_ptr() as *mut f64;
8267 let n_elems = cb / 8;
8268 for i in 0..n_elems {
8269 *dc_ptr.add(i) += *up_ptr.add(i);
8270 }
8271 }
8272 other => panic!(
8273 "ScanBackwardXs: unsupported carry elem size {other} \
8274 (only f32/f64 carries are supported today)"
8275 ),
8276 }
8277 }
8278 }
8279
8280 let carry_dst_start = *body_carry_in_off;
8284 {
8285 let carry_slice = &mut body_buf[carry_dst_start..carry_dst_start + cb];
8286 recompute_carry_t(
8287 t,
8288 carry_slice,
8289 &mut fwd_buf,
8290 &mut seg_cache,
8291 &mut seg_start_t,
8292 &mut seg_count,
8293 );
8294 }
8295 unsafe {
8296 for (i, body_x_off) in body_x_offs.iter().enumerate() {
8297 let (outer_xs_off, x_psb) = outer_xs_offs[i];
8298 let xb = x_psb as usize;
8299 std::ptr::copy_nonoverlapping(
8300 base.add(outer_xs_off + t * xb),
8301 body_buf.as_mut_ptr().add(*body_x_off),
8302 xb,
8303 );
8304 }
8305 std::ptr::copy_nonoverlapping(
8306 dcarry.as_ptr(),
8307 body_buf.as_mut_ptr().add(*body_d_output_off),
8308 cb,
8309 );
8310 }
8311
8312 execute_thunks(body_vjp, &mut body_buf);
8313
8314 unsafe {
8317 std::ptr::copy_nonoverlapping(
8318 body_buf.as_ptr().add(*body_dxs_out_off),
8319 base.add(*outer_dxs_off + t * psb),
8320 psb,
8321 );
8322 }
8323
8324 unsafe {
8326 std::ptr::copy_nonoverlapping(
8327 body_buf.as_ptr().add(*body_dcarry_out_off),
8328 dcarry.as_mut_ptr(),
8329 cb,
8330 );
8331 }
8332 }
8333 }
8334
8335 Thunk::FusedMmBiasAct {
8336 a,
8337 w,
8338 bias,
8339 c,
8340 m,
8341 k,
8342 n,
8343 act,
8344 } => {
8345 let (m, k, n) = (*m as usize, *k as usize, *n as usize);
8346 unsafe {
8347 let out = sl_mut(*c, base, m * n);
8348 crate::blas::sgemm_auto(sl(*a, base, m * k), sl(*w, base, k * n), out, m, k, n);
8349 match act {
8350 Some(Activation::Gelu) => {
8351 crate::kernels::par_bias_gelu(out, sl(*bias, base, n), m, n)
8352 }
8353 Some(other) => {
8354 crate::blas::bias_add(out, sl(*bias, base, n), m, n);
8355 apply_activation_inplace(out, *other);
8356 }
8357 None => crate::blas::bias_add(out, sl(*bias, base, n), m, n),
8358 }
8359 }
8360 }
8361
8362 Thunk::FusedResidualLN {
8363 x,
8364 res,
8365 bias,
8366 g,
8367 b,
8368 out,
8369 rows,
8370 h,
8371 eps,
8372 has_bias,
8373 } => {
8374 let (rows, h) = (*rows as usize, *h as usize);
8375 unsafe {
8376 let zero = &zero_bias[..h];
8377 let bi = if *has_bias { sl(*bias, base, h) } else { zero };
8378 let x_ptr = sl(*x, base, rows * h).as_ptr() as usize;
8379 let r_ptr = sl(*res, base, rows * h).as_ptr() as usize;
8380 let o_ptr = sl_mut(*out, base, rows * h).as_mut_ptr() as usize;
8381 let bi_ptr = bi.as_ptr() as usize;
8382 let g_ptr = sl(*g, base, h).as_ptr() as usize;
8383 let b_ptr = sl(*b, base, h).as_ptr() as usize;
8384 let e = *eps;
8385 crate::pool::par_for(rows, 4, &|off, cnt| {
8386 let xs =
8387 std::slice::from_raw_parts((x_ptr as *const f32).add(off * h), cnt * h);
8388 let rs =
8389 std::slice::from_raw_parts((r_ptr as *const f32).add(off * h), cnt * h);
8390 let os = std::slice::from_raw_parts_mut(
8391 (o_ptr as *mut f32).add(off * h),
8392 cnt * h,
8393 );
8394 let bi = std::slice::from_raw_parts(bi_ptr as *const f32, h);
8395 let g = std::slice::from_raw_parts(g_ptr as *const f32, h);
8396 let b = std::slice::from_raw_parts(b_ptr as *const f32, h);
8397 crate::kernels::residual_bias_layer_norm(xs, rs, bi, g, b, os, cnt, h, e);
8398 });
8399 }
8400 }
8401
8402 Thunk::FusedResidualRmsNorm {
8403 x,
8404 res,
8405 bias,
8406 g,
8407 b,
8408 out,
8409 rows,
8410 h,
8411 eps,
8412 has_bias,
8413 } => {
8414 let (rows, h) = (*rows as usize, *h as usize);
8415 unsafe {
8416 let zero = &zero_bias[..h];
8417 let bi = if *has_bias { sl(*bias, base, h) } else { zero };
8418 let x_ptr = sl(*x, base, rows * h).as_ptr() as usize;
8419 let r_ptr = sl(*res, base, rows * h).as_ptr() as usize;
8420 let o_ptr = sl_mut(*out, base, rows * h).as_mut_ptr() as usize;
8421 let bi_ptr = bi.as_ptr() as usize;
8422 let g_ptr = sl(*g, base, h).as_ptr() as usize;
8423 let b_ptr = sl(*b, base, h).as_ptr() as usize;
8424 let e = *eps;
8425 crate::pool::par_for(rows, 4, &|off, cnt| {
8426 let xs =
8427 std::slice::from_raw_parts((x_ptr as *const f32).add(off * h), cnt * h);
8428 let rs =
8429 std::slice::from_raw_parts((r_ptr as *const f32).add(off * h), cnt * h);
8430 let os = std::slice::from_raw_parts_mut(
8431 (o_ptr as *mut f32).add(off * h),
8432 cnt * h,
8433 );
8434 let bi = std::slice::from_raw_parts(bi_ptr as *const f32, h);
8435 let g = std::slice::from_raw_parts(g_ptr as *const f32, h);
8436 let b = std::slice::from_raw_parts(b_ptr as *const f32, h);
8437 crate::kernels::residual_bias_rms_norm(xs, rs, bi, g, b, os, cnt, h, e);
8438 });
8439 }
8440 }
8441
8442 Thunk::BiasAdd {
8443 src,
8444 bias,
8445 dst,
8446 m,
8447 n,
8448 } => {
8449 let (m, n) = (*m as usize, *n as usize);
8450 unsafe {
8451 let out = sl_mut(*dst, base, m * n);
8452 out.copy_from_slice(sl(*src, base, m * n));
8453 crate::blas::bias_add(out, sl(*bias, base, n), m, n);
8454 }
8455 }
8456
8457 Thunk::BinaryFull {
8458 lhs,
8459 rhs,
8460 dst,
8461 len,
8462 lhs_len,
8463 rhs_len,
8464 op,
8465 out_dims_bcast,
8466 bcast_lhs_strides,
8467 bcast_rhs_strides,
8468 } => {
8469 let len = *len as usize;
8470 let ll = (*lhs_len as usize).max(1);
8471 let rl = (*rhs_len as usize).max(1);
8472 unsafe {
8473 let l = sl(*lhs, base, ll);
8474 let r = sl(*rhs, base, rl);
8475 let o = sl_mut(*dst, base, len);
8476 if ll == len && rl == len {
8478 #[cfg(target_arch = "aarch64")]
8479 if matches!(op, BinaryOp::Add | BinaryOp::Mul) {
8480 use std::arch::aarch64::*;
8481 let chunks = len / 4;
8482 for c in 0..chunks {
8483 let off = c * 4;
8484 let vl = vld1q_f32(l.as_ptr().add(off));
8485 let vr = vld1q_f32(r.as_ptr().add(off));
8486 let res = match op {
8487 BinaryOp::Add => vaddq_f32(vl, vr),
8488 BinaryOp::Mul => vmulq_f32(vl, vr),
8489 _ => unreachable!(),
8490 };
8491 vst1q_f32(o.as_mut_ptr().add(off), res);
8492 }
8493 for i in (chunks * 4)..len {
8494 o[i] = match op {
8495 BinaryOp::Add => l[i] + r[i],
8496 BinaryOp::Mul => l[i] * r[i],
8497 _ => unreachable!(),
8498 };
8499 }
8500 continue;
8506 }
8507 }
8508 if !out_dims_bcast.is_empty() {
8509 let rank = out_dims_bcast.len();
8512 let mut coords = vec![0u32; rank];
8513 for i in 0..len {
8514 let mut rem = i;
8515 for ax in (0..rank).rev() {
8516 let sz = out_dims_bcast[ax] as usize;
8517 coords[ax] = (rem % sz) as u32;
8518 rem /= sz;
8519 }
8520 let mut li: usize = 0;
8521 let mut ri: usize = 0;
8522 for ax in 0..rank {
8523 li += coords[ax] as usize * bcast_lhs_strides[ax] as usize;
8524 ri += coords[ax] as usize * bcast_rhs_strides[ax] as usize;
8525 }
8526 o[i] = match op {
8527 BinaryOp::Add => l[li] + r[ri],
8528 BinaryOp::Sub => l[li] - r[ri],
8529 BinaryOp::Mul => l[li] * r[ri],
8530 BinaryOp::Div => l[li] / r[ri],
8531 BinaryOp::Max => l[li].max(r[ri]),
8532 BinaryOp::Min => l[li].min(r[ri]),
8533 BinaryOp::Pow => l[li].powf(r[ri]),
8534 };
8535 }
8536 } else {
8537 for i in 0..len {
8539 let li = if ll == 1 { 0 } else { i % ll };
8540 let ri = if rl == 1 { 0 } else { i % rl };
8541 o[i] = match op {
8542 BinaryOp::Add => l[li] + r[ri],
8543 BinaryOp::Sub => l[li] - r[ri],
8544 BinaryOp::Mul => l[li] * r[ri],
8545 BinaryOp::Div => l[li] / r[ri],
8546 BinaryOp::Max => l[li].max(r[ri]),
8547 BinaryOp::Min => l[li].min(r[ri]),
8548 BinaryOp::Pow => l[li].powf(r[ri]),
8549 };
8550 }
8551 }
8552 }
8553 }
8554
8555 Thunk::Gather {
8556 table,
8557 table_len,
8558 idx,
8559 dst,
8560 num_idx,
8561 trailing,
8562 } => {
8563 let (ni, tr) = (*num_idx as usize, *trailing as usize);
8564 unsafe {
8565 let tab = sl(*table, base, *table_len as usize);
8566 let ids = sl(*idx, base, ni);
8567 let out = sl_mut(*dst, base, ni * tr);
8568 for i in 0..ni {
8569 let row = ids[i] as usize;
8570 out[i * tr..(i + 1) * tr].copy_from_slice(&tab[row * tr..(row + 1) * tr]);
8571 }
8572 }
8573 }
8574
8575 Thunk::Narrow {
8576 src,
8577 dst,
8578 outer,
8579 src_stride,
8580 dst_stride,
8581 inner,
8582 elem_bytes,
8583 } => {
8584 let f = narrow_thunk_closure(
8585 *src,
8586 *dst,
8587 *outer,
8588 *src_stride,
8589 *dst_stride,
8590 *inner,
8591 *elem_bytes,
8592 );
8593 f(base);
8594 }
8595
8596 Thunk::Copy { src, dst, len } => {
8597 let len = *len as usize;
8598 unsafe {
8599 let s = sl(*src, base, len);
8600 let d = sl_mut(*dst, base, len);
8601 d.copy_from_slice(s);
8602 }
8603 }
8604
8605 Thunk::LayerNorm {
8606 src,
8607 g,
8608 b,
8609 dst,
8610 rows,
8611 h,
8612 eps,
8613 } => {
8614 let (rows, h) = (*rows as usize, *h as usize);
8615 unsafe {
8616 let input = sl(*src, base, rows * h);
8617 let gamma = sl(*g, base, h);
8618 let beta = sl(*b, base, h);
8619 let output = sl_mut(*dst, base, rows * h);
8620 if rows >= 4 && rows * h >= 30_000 {
8622 let i_ptr = input.as_ptr() as usize;
8623 let o_ptr = output.as_mut_ptr() as usize;
8624 let g_ptr = gamma.as_ptr() as usize;
8625 let b_ptr = beta.as_ptr() as usize;
8626 let e = *eps;
8627 crate::pool::par_for(rows, 4, &|off, cnt| {
8628 let inp = std::slice::from_raw_parts(
8629 (i_ptr as *const f32).add(off * h),
8630 cnt * h,
8631 );
8632 let out = std::slice::from_raw_parts_mut(
8633 (o_ptr as *mut f32).add(off * h),
8634 cnt * h,
8635 );
8636 let g = std::slice::from_raw_parts(g_ptr as *const f32, h);
8637 let b = std::slice::from_raw_parts(b_ptr as *const f32, h);
8638 for row in 0..cnt {
8639 crate::kernels::layer_norm_row(
8640 &inp[row * h..(row + 1) * h],
8641 g,
8642 b,
8643 &mut out[row * h..(row + 1) * h],
8644 h,
8645 e,
8646 );
8647 }
8648 });
8649 } else {
8650 for row in 0..rows {
8651 crate::kernels::layer_norm_row(
8652 &input[row * h..(row + 1) * h],
8653 gamma,
8654 beta,
8655 &mut output[row * h..(row + 1) * h],
8656 h,
8657 *eps,
8658 );
8659 }
8660 }
8661 }
8662 }
8663
8664 Thunk::GroupNorm {
8665 src,
8666 g,
8667 b,
8668 dst,
8669 n,
8670 c,
8671 h,
8672 w,
8673 num_groups,
8674 eps,
8675 } => {
8676 let (n, c, h, w) = (*n as usize, *c as usize, *h as usize, *w as usize);
8677 let plane = c * h * w;
8678 unsafe {
8679 for ni in 0..n {
8680 let input = sl(*src, base.add(ni * plane), plane);
8681 let gamma = sl(*g, base, c);
8682 let beta = sl(*b, base, c);
8683 let output = sl_mut(*dst, base.add(ni * plane), plane);
8684 crate::kernels::group_norm_nchw(
8685 input,
8686 gamma,
8687 beta,
8688 output,
8689 1,
8690 c,
8691 h,
8692 w,
8693 *num_groups as usize,
8694 *eps,
8695 );
8696 }
8697 }
8698 }
8699
8700 Thunk::LayerNorm2d {
8701 src,
8702 g,
8703 b,
8704 dst,
8705 n,
8706 c,
8707 h,
8708 w,
8709 eps,
8710 } => {
8711 let (n, c, h, w) = (*n as usize, *c as usize, *h as usize, *w as usize);
8712 let plane = c * h * w;
8713 unsafe {
8714 let input = sl(*src, base, n * plane);
8715 let gamma = sl(*g, base, c);
8716 let beta = sl(*b, base, c);
8717 let output = sl_mut(*dst, base, n * plane);
8718 crate::kernels::layer_norm2d_nchw(input, gamma, beta, output, n, c, h, w, *eps);
8719 }
8720 }
8721
8722 Thunk::ConvTranspose2d {
8723 src,
8724 weight,
8725 dst,
8726 n,
8727 c_in,
8728 h,
8729 w_in,
8730 c_out,
8731 h_out,
8732 w_out,
8733 kh,
8734 kw,
8735 sh,
8736 sw,
8737 ph,
8738 pw,
8739 dh,
8740 dw,
8741 groups,
8742 } => {
8743 let n = *n as usize;
8744 let c_in = *c_in as usize;
8745 let h = *h as usize;
8746 let w_in = *w_in as usize;
8747 let c_out = *c_out as usize;
8748 let h_out = *h_out as usize;
8749 let w_out = *w_out as usize;
8750 unsafe {
8751 let inp = sl(*src, base, n * c_in * h * w_in);
8752 let wt = sl(
8753 *weight,
8754 base,
8755 c_in * (c_out / *groups as usize) * (*kh as usize) * (*kw as usize),
8756 );
8757 let out = sl_mut(*dst, base, n * c_out * h_out * w_out);
8758 crate::kernels::conv_transpose2d_nchw(
8759 inp,
8760 wt,
8761 out,
8762 n,
8763 c_in,
8764 h,
8765 w_in,
8766 c_out,
8767 h_out,
8768 w_out,
8769 *kh as usize,
8770 *kw as usize,
8771 *sh as usize,
8772 *sw as usize,
8773 *ph as usize,
8774 *pw as usize,
8775 *dh as usize,
8776 *dw as usize,
8777 *groups as usize,
8778 );
8779 }
8780 }
8781
8782 Thunk::ResizeNearest2x {
8783 src,
8784 dst,
8785 n,
8786 c,
8787 h,
8788 w,
8789 } => {
8790 let (n, c, h, w) = (*n as usize, *c as usize, *h as usize, *w as usize);
8791 let in_plane = c * h * w;
8792 let out_plane = c * h * 2 * w * 2;
8793 unsafe {
8794 for ni in 0..n {
8795 let input = sl(*src, base.add(ni * in_plane), in_plane);
8796 let output = sl_mut(*dst, base.add(ni * out_plane), out_plane);
8797 crate::kernels::resize_nearest_2x_nchw(input, output, c, h, w);
8798 }
8799 }
8800 }
8801
8802 Thunk::AxialRope2d {
8803 src,
8804 dst,
8805 batch,
8806 seq,
8807 hidden,
8808 end_x,
8809 end_y,
8810 head_dim,
8811 num_heads,
8812 theta,
8813 repeat_factor,
8814 } => {
8815 let b = *batch as usize;
8816 let s = *seq as usize;
8817 let hdim = *head_dim as usize;
8818 let nh = *num_heads as usize;
8819 let plane = s * (*hidden as usize);
8820 unsafe {
8821 for bi in 0..b {
8822 let input = sl(*src, base.add(bi * plane), plane);
8823 let output = sl_mut(*dst, base.add(bi * plane), plane);
8824 let rotated = rlx_ir::ops::axial_rope2d::apply_axial_rope2d(
8825 input,
8826 nh,
8827 s,
8828 hdim,
8829 *end_x as usize,
8830 *end_y as usize,
8831 *theta,
8832 *repeat_factor as usize,
8833 );
8834 output.copy_from_slice(&rotated);
8835 }
8836 }
8837 }
8838
8839 Thunk::RmsNorm {
8840 src,
8841 g,
8842 b,
8843 dst,
8844 rows,
8845 h,
8846 eps,
8847 } => {
8848 let (rows, h) = (*rows as usize, *h as usize);
8849 unsafe {
8850 let input = sl(*src, base, rows * h);
8851 let gamma = sl(*g, base, h);
8852 let beta = sl(*b, base, h);
8853 let output = sl_mut(*dst, base, rows * h);
8854 let inv_h = 1.0 / h as f32;
8855 for row in 0..rows {
8856 let in_row = &input[row * h..(row + 1) * h];
8857 let out_row = &mut output[row * h..(row + 1) * h];
8858 let mut sumsq = 0f32;
8860 for &v in in_row {
8861 sumsq += v * v;
8862 }
8863 let inv_rms = (sumsq * inv_h + *eps).sqrt().recip();
8864 for i in 0..h {
8865 out_row[i] = in_row[i] * inv_rms * gamma[i] + beta[i];
8866 }
8867 }
8868 }
8869 }
8870
8871 Thunk::Softmax { data, rows, cols } => {
8872 let (rows, cols) = (*rows as usize, *cols as usize);
8873 unsafe {
8874 crate::kernels::neon_softmax(sl_mut(*data, base, rows * cols), rows, cols);
8875 }
8876 }
8877
8878 Thunk::Cumsum {
8879 src,
8880 dst,
8881 rows,
8882 cols,
8883 exclusive,
8884 } => {
8885 let (rows, cols) = (*rows as usize, *cols as usize);
8886 unsafe {
8887 let s = sl(*src, base, rows * cols);
8888 let d = sl_mut(*dst, base, rows * cols);
8889 if *exclusive {
8890 for r in 0..rows {
8891 let mut acc = 0.0f32;
8892 for c in 0..cols {
8893 d[r * cols + c] = acc;
8894 acc += s[r * cols + c];
8895 }
8896 }
8897 } else {
8898 for r in 0..rows {
8899 let mut acc = 0.0f32;
8900 for c in 0..cols {
8901 acc += s[r * cols + c];
8902 d[r * cols + c] = acc;
8903 }
8904 }
8905 }
8906 }
8907 }
8908
8909 Thunk::Sample {
8910 logits,
8911 dst,
8912 batch,
8913 vocab,
8914 top_k,
8915 top_p,
8916 temperature,
8917 seed,
8918 } => {
8919 let (b, v) = (*batch as usize, *vocab as usize);
8920 let k = (*top_k as usize).min(v);
8921 unsafe {
8922 let lg = sl(*logits, base, b * v);
8923 let out = sl_mut(*dst, base, b);
8924 let mut rng =
8925 rlx_ir::Philox4x32::new(if *seed == 0 { 0xDEADBEEF } else { *seed });
8926 for bi in 0..b {
8927 let row = &lg[bi * v..(bi + 1) * v];
8928 out[bi] = sample_row(row, k, *top_p, *temperature, &mut rng) as f32;
8929 }
8930 }
8931 }
8932
8933 Thunk::GatedDeltaNet {
8934 q,
8935 k,
8936 v,
8937 g,
8938 beta,
8939 state,
8940 dst,
8941 batch,
8942 seq,
8943 heads,
8944 state_size,
8945 } => unsafe {
8946 execute_gated_delta_net_f32(
8947 *q,
8948 *k,
8949 *v,
8950 *g,
8951 *beta,
8952 *state,
8953 *dst,
8954 *batch as usize,
8955 *seq as usize,
8956 *heads as usize,
8957 *state_size as usize,
8958 base,
8959 );
8960 },
8961
8962 Thunk::SelectiveScan {
8963 x,
8964 delta,
8965 a,
8966 b: bp,
8967 c: cp,
8968 dst,
8969 batch,
8970 seq,
8971 hidden,
8972 state_size,
8973 } => {
8974 let (b, s, h, n) = (
8975 *batch as usize,
8976 *seq as usize,
8977 *hidden as usize,
8978 *state_size as usize,
8979 );
8980 unsafe {
8981 let xs = sl(*x, base, b * s * h);
8982 let dt = sl(*delta, base, b * s * h);
8983 let am = sl(*a, base, h * n);
8984 let bm = sl(*bp, base, b * s * n);
8985 let cm = sl(*cp, base, b * s * n);
8986 let out = sl_mut(*dst, base, b * s * h);
8987
8988 let mut state = vec![0f32; h * n];
8992 for bi in 0..b {
8993 for v in state.iter_mut() {
8995 *v = 0.0;
8996 }
8997 for si in 0..s {
8998 let x_row = &xs[bi * s * h + si * h..bi * s * h + (si + 1) * h];
8999 let dt_row = &dt[bi * s * h + si * h..bi * s * h + (si + 1) * h];
9000 let b_row = &bm[bi * s * n + si * n..bi * s * n + (si + 1) * n];
9001 let c_row = &cm[bi * s * n + si * n..bi * s * n + (si + 1) * n];
9002 let out_row = &mut out[bi * s * h + si * h..bi * s * h + (si + 1) * h];
9003
9004 for ci in 0..h {
9005 let d = dt_row[ci];
9006 let xv = x_row[ci];
9007 let mut acc = 0f32;
9008 for ni in 0..n {
9009 let da = (d * am[ci * n + ni]).exp();
9011 state[ci * n + ni] =
9012 da * state[ci * n + ni] + d * b_row[ni] * xv;
9013 acc += c_row[ni] * state[ci * n + ni];
9014 }
9015 out_row[ci] = acc;
9016 }
9017 }
9018 }
9019 }
9020 }
9021
9022 Thunk::DequantMatMul {
9023 x,
9024 w_q,
9025 scale,
9026 zp,
9027 dst,
9028 m,
9029 k,
9030 n,
9031 block_size,
9032 is_asymmetric,
9033 } => {
9034 let (m, k, n, bs) = (*m as usize, *k as usize, *n as usize, *block_size as usize);
9035 let n_blocks = k.div_ceil(bs);
9036 unsafe {
9037 let xs = sl(*x, base, m * k);
9038 let w_bytes = std::slice::from_raw_parts(base.add(*w_q) as *const i8, k * n);
9039 let scales = sl(*scale, base, n_blocks * n);
9040 let zps = if *is_asymmetric {
9041 sl(*zp, base, n_blocks * n)
9042 } else {
9043 &[][..]
9044 };
9045 let out = sl_mut(*dst, base, m * n);
9046 dequant_matmul_int8(xs, w_bytes, scales, zps, out, m, k, n, bs, *is_asymmetric);
9047 }
9048 }
9049
9050 Thunk::DequantMatMulGguf {
9051 x,
9052 w_q,
9053 dst,
9054 m,
9055 k,
9056 n,
9057 scheme,
9058 } => {
9059 let (m, k, n) = (*m as usize, *k as usize, *n as usize);
9060 let block_bytes = scheme.gguf_block_bytes() as usize;
9061 let block_elems = scheme.gguf_block_size() as usize;
9062 debug_assert!(
9063 block_bytes > 0 && block_elems > 0,
9064 "non-GGUF scheme in GGUF arm"
9065 );
9066 debug_assert!(
9067 (k * n).is_multiple_of(block_elems),
9068 "k*n={} not aligned to GGUF block size {}",
9069 k * n,
9070 block_elems
9071 );
9072 let total_bytes = (k * n) / block_elems * block_bytes;
9073 unsafe {
9074 let xs = sl(*x, base, m * k);
9075 let w_bytes_ptr = base.add(*w_q) as *const u8;
9076 let w_bytes = std::slice::from_raw_parts(w_bytes_ptr, total_bytes);
9077 let out = sl_mut(*dst, base, m * n);
9078 crate::gguf_matmul::gguf_matmul_bt(xs, w_bytes, out, m, k, n, *scheme);
9079 }
9080 }
9081
9082 Thunk::DequantMatMulInt4 {
9083 x,
9084 w_q,
9085 scale,
9086 zp,
9087 dst,
9088 m,
9089 k,
9090 n,
9091 block_size,
9092 is_asymmetric,
9093 } => {
9094 let (m, k, n, bs) = (*m as usize, *k as usize, *n as usize, *block_size as usize);
9095 let n_blocks = k.div_ceil(bs);
9096 unsafe {
9097 let xs = sl(*x, base, m * k);
9098 let w_bytes = std::slice::from_raw_parts(
9099 base.add(*w_q) as *const u8,
9100 (k * n).div_ceil(2),
9101 );
9102 let scales = sl(*scale, base, n_blocks * n);
9103 let zps = if *is_asymmetric {
9104 sl(*zp, base, n_blocks * n)
9105 } else {
9106 &[][..]
9107 };
9108 let out = sl_mut(*dst, base, m * n);
9109 dequant_matmul_int4(xs, w_bytes, scales, zps, out, m, k, n, bs, *is_asymmetric);
9110 }
9111 }
9112
9113 Thunk::DequantMatMulFp8 {
9114 x,
9115 w_q,
9116 scale,
9117 dst,
9118 m,
9119 k,
9120 n,
9121 e5m2,
9122 } => {
9123 let (m, k, n) = (*m as usize, *k as usize, *n as usize);
9124 unsafe {
9125 let xs = sl(*x, base, m * k);
9126 let w_bytes = std::slice::from_raw_parts(base.add(*w_q) as *const u8, k * n);
9127 let scales = sl(*scale, base, n);
9128 let out = sl_mut(*dst, base, m * n);
9129 dequant_matmul_fp8(xs, w_bytes, scales, out, m, k, n, *e5m2);
9130 }
9131 }
9132
9133 Thunk::DequantMatMulNvfp4 {
9134 x,
9135 w_q,
9136 scale,
9137 global_scale,
9138 dst,
9139 m,
9140 k,
9141 n,
9142 } => {
9143 let (m, k, n) = (*m as usize, *k as usize, *n as usize);
9144 let n_scale = k.div_ceil(rlx_ir::NVFP4_GROUP_SIZE) * n;
9145 unsafe {
9146 let xs = sl(*x, base, m * k);
9147 let w_bytes = std::slice::from_raw_parts(
9148 base.add(*w_q) as *const u8,
9149 (k * n).div_ceil(2),
9150 );
9151 let scale_bytes =
9152 std::slice::from_raw_parts(base.add(*scale) as *const u8, n_scale);
9153 let gs = sl(*global_scale, base, 1)[0];
9154 let out = sl_mut(*dst, base, m * n);
9155 dequant_matmul_nvfp4(xs, w_bytes, scale_bytes, gs, out, m, k, n);
9156 }
9157 }
9158
9159 Thunk::LoraMatMul {
9160 x,
9161 w,
9162 a,
9163 b,
9164 dst,
9165 m,
9166 k,
9167 n,
9168 r,
9169 scale,
9170 } => {
9171 let (m, k, n, r) = (*m as usize, *k as usize, *n as usize, *r as usize);
9172 unsafe {
9173 let xs = sl(*x, base, m * k);
9174 let ws = sl(*w, base, k * n);
9175 let a_s = sl(*a, base, k * r);
9176 let bs = sl(*b, base, r * n);
9177 let out = sl_mut(*dst, base, m * n);
9178 crate::blas::sgemm(xs, ws, out, m, k, n);
9179 let mut tmp = vec![0f32; m * r];
9180 crate::blas::sgemm(xs, a_s, &mut tmp, m, k, r);
9181 if *scale != 1.0 {
9182 for v in tmp.iter_mut() {
9183 *v *= *scale;
9184 }
9185 }
9186 crate::blas::sgemm_accumulate(&tmp, bs, out, m, r, n);
9187 }
9188 }
9189
9190 Thunk::Attention {
9191 q,
9192 k,
9193 v,
9194 mask,
9195 out,
9196 batch,
9197 seq,
9198 kv_seq,
9199 heads,
9200 head_dim,
9201 mask_kind,
9202 q_row_stride,
9203 k_row_stride,
9204 v_row_stride,
9205 bhsd,
9206 } => {
9207 let (b, q_s, k_s, nh, dh) = (
9208 *batch as usize,
9209 *seq as usize,
9210 *kv_seq as usize,
9211 *heads as usize,
9212 *head_dim as usize,
9213 );
9214 let hs = nh * dh;
9215 let (qrs, krs, vrs) = if *bhsd {
9218 (dh, dh, dh)
9219 } else {
9220 (
9221 *q_row_stride as usize,
9222 *k_row_stride as usize,
9223 *v_row_stride as usize,
9224 )
9225 };
9226 let bhsd = *bhsd;
9227 let _ = (q_row_stride, k_row_stride, v_row_stride);
9228 let scale = (dh as f32).powf(-0.5);
9229 let ss = q_s * k_s;
9230 let cfg = crate::config::RuntimeConfig::global();
9231 unsafe {
9232 let q_len = if bhsd {
9239 b * nh * q_s * dh
9240 } else {
9241 b * q_s * qrs
9242 };
9243 let k_len = if bhsd {
9244 b * nh * k_s * dh
9245 } else {
9246 b * k_s * krs
9247 };
9248 let v_len = if bhsd {
9249 b * nh * k_s * dh
9250 } else {
9251 b * k_s * vrs
9252 };
9253 let q_data = sl(*q, base, q_len);
9254 let k_data = sl(*k, base, k_len);
9255 let v_data = sl(*v, base, v_len);
9256 let mask_data: &[f32] = match mask_kind {
9257 rlx_ir::op::MaskKind::Custom => sl(*mask, base, b * k_s),
9258 rlx_ir::op::MaskKind::Bias => sl(*mask, base, b * nh * q_s * k_s),
9259 _ => &[],
9260 };
9261 let out_len = if bhsd {
9262 b * nh * q_s * dh
9263 } else {
9264 b * q_s * hs
9265 };
9266 let out_data = sl_mut(*out, base, out_len);
9267
9268 if bhsd {
9279 let scores = &mut sdpa_scores[..ss];
9280 for bi in 0..b {
9281 for hi in 0..nh {
9282 let q_head_base = bi * nh * q_s * dh + hi * q_s * dh;
9283 let k_head_base = bi * nh * k_s * dh + hi * k_s * dh;
9284 for qi in 0..q_s {
9286 let q_base = q_head_base + qi * dh;
9287 for ki in 0..k_s {
9288 let k_base = k_head_base + ki * dh;
9289 let mut dot = 0f32;
9290 for d in 0..dh {
9291 dot += q_data[q_base + d] * k_data[k_base + d];
9292 }
9293 scores[qi * k_s + ki] = dot * scale;
9294 if matches!(mask_kind, rlx_ir::op::MaskKind::Custom)
9295 && !mask_data.is_empty()
9296 && mask_data[bi * k_s + ki] < mask_thr
9297 {
9298 scores[qi * k_s + ki] = mask_neg;
9299 }
9300 }
9301 }
9302 if matches!(mask_kind, rlx_ir::op::MaskKind::Bias) {
9303 let off = (bi * nh + hi) * q_s * k_s;
9304 for i in 0..q_s * k_s {
9305 scores[i] += mask_data[off + i];
9306 }
9307 }
9308 apply_synthetic_mask(scores, q_s, k_s, *mask_kind);
9309 crate::kernels::neon_softmax(scores, q_s, k_s);
9310 for qi in 0..q_s {
9312 let o_base = q_head_base + qi * dh;
9313 for d in 0..dh {
9314 out_data[o_base + d] = 0.0;
9315 }
9316 for ki in 0..k_s {
9317 let sc = scores[qi * k_s + ki];
9318 if sc > score_thr {
9319 let v_base = k_head_base + ki * dh;
9320 for d in 0..dh {
9321 out_data[o_base + d] += sc * v_data[v_base + d];
9322 }
9323 }
9324 }
9325 }
9326 }
9327 }
9328 continue;
9329 }
9330
9331 if b == 1 && q_s.max(k_s) <= cfg.sdpa_seq_threshold {
9338 let scores = &mut sdpa_scores[..ss];
9340 #[cfg(target_arch = "aarch64")]
9341 let neon_chunks = dh / 4;
9342
9343 for bi in 0..b {
9344 for hi in 0..nh {
9345 for qi in 0..q_s {
9347 let q_off = bi * q_s * qrs + qi * qrs + hi * dh;
9348 for ki in 0..k_s {
9349 let k_off = bi * k_s * krs + ki * krs + hi * dh;
9350 #[cfg(target_arch = "aarch64")]
9351 let mut dot;
9352 #[cfg(not(target_arch = "aarch64"))]
9353 let mut dot = 0f32;
9354 #[cfg(target_arch = "aarch64")]
9355 {
9356 use std::arch::aarch64::*;
9357 let mut acc = vdupq_n_f32(0.0);
9358 for c in 0..neon_chunks {
9359 let vq =
9360 vld1q_f32(q_data.as_ptr().add(q_off + c * 4));
9361 let vk =
9362 vld1q_f32(k_data.as_ptr().add(k_off + c * 4));
9363 acc = vfmaq_f32(acc, vq, vk);
9364 }
9365 dot = vaddvq_f32(acc);
9366 for d in (neon_chunks * 4)..dh {
9367 dot += q_data[q_off + d] * k_data[k_off + d];
9368 }
9369 }
9370 #[cfg(not(target_arch = "aarch64"))]
9371 for d in 0..dh {
9372 dot += q_data[q_off + d] * k_data[k_off + d];
9373 }
9374 scores[qi * k_s + ki] = dot * scale;
9375 if matches!(mask_kind, rlx_ir::op::MaskKind::Custom)
9382 && !mask_data.is_empty()
9383 && mask_data[bi * k_s + ki] < mask_thr
9384 {
9385 scores[qi * k_s + ki] = mask_neg;
9386 }
9387 }
9388 }
9389
9390 if matches!(mask_kind, rlx_ir::op::MaskKind::Bias) {
9391 let off = (bi * nh + hi) * q_s * k_s;
9392 for i in 0..q_s * k_s {
9393 scores[i] += mask_data[off + i];
9394 }
9395 }
9396 apply_synthetic_mask(scores, q_s, k_s, *mask_kind);
9397 crate::kernels::neon_softmax(scores, q_s, k_s);
9398
9399 for qi in 0..q_s {
9401 let o_off = bi * q_s * hs + qi * hs + hi * dh;
9402 for d in 0..dh {
9404 out_data[o_off + d] = 0.0;
9405 }
9406 for ki in 0..k_s {
9407 let sc = scores[qi * k_s + ki];
9408 if sc > score_thr {
9409 let v_off = bi * k_s * vrs + ki * vrs + hi * dh;
9410 #[cfg(target_arch = "aarch64")]
9411 {
9412 use std::arch::aarch64::*;
9413 let vsc = vdupq_n_f32(sc);
9414 for c in 0..neon_chunks {
9415 let off = c * 4;
9416 let vo = vld1q_f32(
9417 out_data.as_ptr().add(o_off + off),
9418 );
9419 let vv =
9420 vld1q_f32(v_data.as_ptr().add(v_off + off));
9421 vst1q_f32(
9422 out_data.as_mut_ptr().add(o_off + off),
9423 vfmaq_f32(vo, vsc, vv),
9424 );
9425 }
9426 }
9427 #[cfg(not(target_arch = "aarch64"))]
9428 for d in 0..dh {
9429 out_data[o_off + d] += sc * v_data[v_off + d];
9430 }
9431 }
9432 }
9433 }
9434 }
9435 }
9436 } else {
9437 let total_work = b * nh;
9439 let q_addr = q_data.as_ptr() as usize;
9440 let k_addr = k_data.as_ptr() as usize;
9441 let v_addr = v_data.as_ptr() as usize;
9442 let m_addr = mask_data.as_ptr() as usize;
9443 let o_addr = out_data.as_mut_ptr() as usize;
9444 let sc_addr = sdpa_scores.as_mut_ptr() as usize;
9445
9446 crate::pool::par_for(total_work, 1, &|off, cnt| {
9447 for idx in off..off + cnt {
9448 let bi = idx / nh;
9449 let hi = idx % nh;
9450
9451 let q_start = (q_addr as *const f32).add(bi * q_s * qrs + hi * dh);
9452 let k_start = (k_addr as *const f32).add(bi * k_s * krs + hi * dh);
9453 let v_start = (v_addr as *const f32).add(bi * k_s * vrs + hi * dh);
9454 let o_start = (o_addr as *mut f32).add(bi * q_s * hs + hi * dh);
9455 let sc = std::slice::from_raw_parts_mut(
9456 (sc_addr as *mut f32).add(idx * ss),
9457 ss,
9458 );
9459
9460 crate::blas::sgemm_general(
9463 q_start,
9464 k_start,
9465 sc.as_mut_ptr(),
9466 q_s,
9467 k_s,
9468 dh,
9469 scale,
9470 0.0,
9471 qrs,
9472 krs,
9473 k_s,
9474 false,
9475 true,
9476 );
9477
9478 match mask_kind {
9479 rlx_ir::op::MaskKind::Custom => {
9480 let mask_bi = std::slice::from_raw_parts(
9481 (m_addr as *const f32).add(bi * k_s),
9482 k_s,
9483 );
9484 for ki in 0..k_s {
9485 if mask_bi[ki] < mask_thr {
9486 for qi in 0..q_s {
9487 sc[qi * k_s + ki] = mask_neg;
9488 }
9489 }
9490 }
9491 }
9492 rlx_ir::op::MaskKind::Bias => {
9493 let bias = std::slice::from_raw_parts(
9495 (m_addr as *const f32).add((bi * nh + hi) * q_s * k_s),
9496 q_s * k_s,
9497 );
9498 for i in 0..q_s * k_s {
9499 sc[i] += bias[i];
9500 }
9501 }
9502 _ => apply_synthetic_mask(sc, q_s, k_s, *mask_kind),
9503 }
9504
9505 crate::kernels::neon_softmax(sc, q_s, k_s);
9506
9507 crate::blas::sgemm_general(
9511 sc.as_ptr(),
9512 v_start,
9513 o_start,
9514 q_s,
9515 dh,
9516 k_s,
9517 1.0,
9518 0.0,
9519 k_s,
9520 vrs,
9521 hs,
9522 false,
9523 false,
9524 );
9525 }
9526 });
9527 }
9528 }
9529 }
9530
9531 Thunk::AttentionBackward {
9532 q,
9533 k,
9534 v,
9535 dy,
9536 mask,
9537 out,
9538 batch,
9539 seq,
9540 kv_seq,
9541 heads,
9542 head_dim,
9543 mask_kind,
9544 wrt,
9545 bhsd,
9546 } => {
9547 let (b, q_s, k_s, nh, dh) = (
9548 *batch as usize,
9549 *seq as usize,
9550 *kv_seq as usize,
9551 *heads as usize,
9552 *head_dim as usize,
9553 );
9554 unsafe {
9555 let q_len = if *bhsd {
9556 b * nh * q_s * dh
9557 } else {
9558 b * q_s * nh * dh
9559 };
9560 let k_len = if *bhsd {
9561 b * nh * k_s * dh
9562 } else {
9563 b * k_s * nh * dh
9564 };
9565 let out_len = match wrt {
9566 rlx_ir::op::AttentionBwdWrt::Key | rlx_ir::op::AttentionBwdWrt::Value => {
9567 k_len
9568 }
9569 rlx_ir::op::AttentionBwdWrt::Query => q_len,
9570 };
9571 let q_data = sl(*q, base, q_len);
9572 let k_data = sl(*k, base, k_len);
9573 let v_data = sl(*v, base, k_len);
9574 let dy_data = sl(*dy, base, q_len);
9575 let out_data = sl_mut(*out, base, out_len);
9576 let mask_data: &[f32] = if *mask != 0 {
9577 let ml = match mask_kind {
9578 rlx_ir::op::MaskKind::Custom => b * k_s,
9579 rlx_ir::op::MaskKind::Bias => b * nh * q_s * k_s,
9580 _ => 0,
9581 };
9582 sl(*mask, base, ml)
9583 } else {
9584 &[]
9585 };
9586 crate::attention_bwd::attention_backward(
9587 *wrt, q_data, k_data, v_data, dy_data, out_data, b, nh, q_s, k_s, dh,
9588 *mask_kind, mask_data, *bhsd,
9589 );
9590 }
9591 }
9592
9593 Thunk::ActivationInPlace { data, len, act } => {
9594 let len = *len as usize;
9595 unsafe {
9596 let d = sl_mut(*data, base, len);
9597 match act {
9598 Activation::Gelu => crate::kernels::par_gelu_inplace(d),
9599 Activation::GeluApprox => crate::kernels::par_gelu_approx_inplace(d),
9600 Activation::Silu => crate::kernels::par_silu_inplace(d),
9601 Activation::Relu => {
9602 for v in d.iter_mut() {
9603 *v = v.max(0.0);
9604 }
9605 }
9606 Activation::Sigmoid => {
9607 for v in d.iter_mut() {
9608 *v = 1.0 / (1.0 + (-*v).exp());
9609 }
9610 }
9611 Activation::Tanh => {
9612 for v in d.iter_mut() {
9613 *v = v.tanh();
9614 }
9615 }
9616 Activation::Exp => {
9617 for v in d.iter_mut() {
9618 *v = v.exp();
9619 }
9620 }
9621 Activation::Log => {
9622 for v in d.iter_mut() {
9623 *v = v.ln();
9624 }
9625 }
9626 Activation::Sqrt => {
9627 for v in d.iter_mut() {
9628 *v = v.sqrt();
9629 }
9630 }
9631 Activation::Rsqrt => {
9632 for v in d.iter_mut() {
9633 *v = 1.0 / v.sqrt();
9634 }
9635 }
9636 Activation::Neg => {
9637 for v in d.iter_mut() {
9638 *v = -*v;
9639 }
9640 }
9641 Activation::Abs => {
9642 for v in d.iter_mut() {
9643 *v = v.abs();
9644 }
9645 }
9646 Activation::Round => {
9647 for v in d.iter_mut() {
9648 *v = v.round();
9649 }
9650 }
9651 Activation::Sin => {
9652 for v in d.iter_mut() {
9653 *v = v.sin();
9654 }
9655 }
9656 Activation::Cos => {
9657 for v in d.iter_mut() {
9658 *v = v.cos();
9659 }
9660 }
9661 Activation::Tan => {
9662 for v in d.iter_mut() {
9663 *v = v.tan();
9664 }
9665 }
9666 Activation::Atan => {
9667 for v in d.iter_mut() {
9668 *v = v.atan();
9669 }
9670 }
9671 }
9672 }
9673 }
9674
9675 Thunk::FusedAttnBlock {
9676 hidden,
9677 qkv_w,
9678 out_w,
9679 mask,
9680 out,
9681 qkv_b,
9682 out_b,
9683 cos,
9684 sin,
9685 cos_len,
9686 batch,
9687 seq,
9688 hs,
9689 nh,
9690 dh,
9691 has_bias,
9692 has_rope,
9693 } => {
9694 let (b, s) = (*batch as usize, *seq as usize);
9695 let (h, n_h, d_h) = (*hs as usize, *nh as usize, *dh as usize);
9696 let m = b * s;
9697 let scale = (d_h as f32).powf(-0.5);
9698 let half = d_h / 2;
9699 unsafe {
9700 let inp = sl(*hidden, base, m * h);
9701 let wq = sl(*qkv_w, base, h * 3 * h);
9702 let wo = sl(*out_w, base, h * h);
9703 let mk = sl(*mask, base, b * s);
9704 let dst = sl_mut(*out, base, m * h);
9705
9706 let mut qkv = vec![0f32; m * 3 * h];
9708 let mut attn_out = vec![0f32; m * h];
9709 let mut scores_buf = vec![0f32; s * s]; crate::blas::sgemm(inp, wq, &mut qkv, m, h, 3 * h);
9713 if *has_bias {
9714 let bias = sl(*qkv_b, base, 3 * h);
9715 crate::blas::bias_add(&mut qkv, bias, m, 3 * h);
9716 }
9717
9718 #[cfg(target_arch = "aarch64")]
9721 let neon_chunks = d_h / 4;
9722 #[cfg(target_arch = "aarch64")]
9723 let _rope_chunks = half / 4;
9724
9725 for bi in 0..b {
9726 for hi in 0..n_h {
9727 for qi in 0..s {
9729 let q_base = bi * s * 3 * h + qi * 3 * h + hi * d_h;
9730 for ki in 0..s {
9731 let k_base = bi * s * 3 * h + ki * 3 * h + h + hi * d_h;
9732 let mut dot = 0f32;
9733
9734 if *has_rope {
9735 let q_cos = qi * half;
9737 let k_cos = ki * half;
9738 let cos_tab = sl(*cos, base, *cos_len as usize);
9739 let sin_tab = sl(*sin, base, *cos_len as usize);
9740 for i in 0..half {
9743 let q1 = qkv[q_base + i];
9744 let q2 = qkv[q_base + half + i];
9745 let k1 = qkv[k_base + i];
9746 let k2 = qkv[k_base + half + i];
9747 let c_q = cos_tab[q_cos + i];
9748 let s_q = sin_tab[q_cos + i];
9749 let c_k = cos_tab[k_cos + i];
9750 let s_k = sin_tab[k_cos + i];
9751 let qr1 = q1 * c_q - q2 * s_q;
9752 let kr1 = k1 * c_k - k2 * s_k;
9753 let qr2 = q2 * c_q + q1 * s_q;
9754 let kr2 = k2 * c_k + k1 * s_k;
9755 dot += qr1 * kr1 + qr2 * kr2;
9756 }
9757 } else {
9758 #[cfg(target_arch = "aarch64")]
9760 {
9761 use std::arch::aarch64::*;
9762 let mut acc = vdupq_n_f32(0.0);
9763 for c in 0..neon_chunks {
9764 let vq =
9765 vld1q_f32(qkv.as_ptr().add(q_base + c * 4));
9766 let vk =
9767 vld1q_f32(qkv.as_ptr().add(k_base + c * 4));
9768 acc = vfmaq_f32(acc, vq, vk);
9769 }
9770 dot = vaddvq_f32(acc);
9771 for d in (neon_chunks * 4)..d_h {
9772 dot += qkv[q_base + d] * qkv[k_base + d];
9773 }
9774 }
9775 #[cfg(not(target_arch = "aarch64"))]
9776 for d in 0..d_h {
9777 dot += qkv[q_base + d] * qkv[k_base + d];
9778 }
9779 }
9780
9781 scores_buf[qi * s + ki] = dot * scale;
9782 if mk[bi * s + ki] < mask_thr {
9783 scores_buf[qi * s + ki] = mask_neg;
9784 }
9785 }
9786 }
9787
9788 crate::kernels::neon_softmax(&mut scores_buf[..s * s], s, s);
9790
9791 for qi in 0..s {
9793 let o_base = bi * s * h + qi * h + hi * d_h;
9794 for d in 0..d_h {
9795 attn_out[o_base + d] = 0.0;
9796 }
9797 for ki in 0..s {
9798 let sc = scores_buf[qi * s + ki];
9799 if sc > score_thr {
9800 let v_base = bi * s * 3 * h + ki * 3 * h + 2 * h + hi * d_h;
9801 #[cfg(target_arch = "aarch64")]
9802 {
9803 use std::arch::aarch64::*;
9804 let vsc = vdupq_n_f32(sc);
9805 for c in 0..neon_chunks {
9806 let off = c * 4;
9807 let vo =
9808 vld1q_f32(attn_out.as_ptr().add(o_base + off));
9809 let vv = vld1q_f32(qkv.as_ptr().add(v_base + off));
9810 vst1q_f32(
9811 attn_out.as_mut_ptr().add(o_base + off),
9812 vfmaq_f32(vo, vsc, vv),
9813 );
9814 }
9815 }
9816 #[cfg(not(target_arch = "aarch64"))]
9817 for d in 0..d_h {
9818 attn_out[o_base + d] += sc * qkv[v_base + d];
9819 }
9820 }
9821 }
9822 }
9823 }
9824 }
9825
9826 crate::blas::sgemm(&attn_out, wo, dst, m, h, h);
9828 if *has_bias {
9829 let bias = sl(*out_b, base, h);
9830 crate::blas::bias_add(dst, bias, m, h);
9831 }
9832 }
9833 }
9834
9835 Thunk::Rope {
9836 src,
9837 cos,
9838 sin,
9839 dst,
9840 batch,
9841 seq,
9842 hidden,
9843 head_dim,
9844 n_rot,
9845 cos_len,
9846 src_row_stride,
9847 } => {
9848 let (b, s, hs, dh, nr) = (
9849 *batch as usize,
9850 *seq as usize,
9851 *hidden as usize,
9852 *head_dim as usize,
9853 *n_rot as usize,
9854 );
9855 let tab_half = dh / 2;
9856 let rot_half = nr / 2;
9857 let nh = hs / dh;
9858 let cl = *cos_len as usize;
9859 let src_rs = *src_row_stride as usize;
9860 unsafe {
9861 let x = sl(*src, base, b * s * src_rs);
9862 let cos_tab = sl(*cos, base, cl);
9863 let sin_tab = sl(*sin, base, cl);
9864 let out = sl_mut(*dst, base, b * s * hs);
9865
9866 let total = b * s;
9867 let x_ptr = x.as_ptr() as usize;
9868 let o_ptr = out.as_mut_ptr() as usize;
9869 let c_ptr = cos_tab.as_ptr() as usize;
9870 let s_ptr = sin_tab.as_ptr() as usize;
9871
9872 crate::pool::par_for(total, 4, &|off, cnt| {
9873 for idx in off..off + cnt {
9874 let bi = idx / s;
9875 let si = idx % s;
9876 let tab_off = si * tab_half;
9877
9878 for hi in 0..nh {
9879 let src_base = bi * s * src_rs + si * src_rs + hi * dh;
9880 let dst_base = bi * s * hs + si * hs + hi * dh;
9881 let xp = (x_ptr as *const f32).add(src_base);
9882 let op = (o_ptr as *mut f32).add(dst_base);
9883 let cp = (c_ptr as *const f32).add(tab_off);
9884 let sp = (s_ptr as *const f32).add(tab_off);
9885
9886 for i in 0..rot_half {
9887 let x1 = *xp.add(i);
9888 let x2 = *xp.add(rot_half + i);
9889 let cv = *cp.add(i);
9890 let sv = *sp.add(i);
9891 *op.add(i) = x1 * cv - x2 * sv;
9892 *op.add(rot_half + i) = x2 * cv + x1 * sv;
9893 }
9894 for j in nr..dh {
9895 *op.add(j) = *xp.add(j);
9896 }
9897 }
9898 }
9899 });
9900 }
9901 }
9902 Thunk::FusedBertLayer {
9903 hidden,
9904 qkv_w,
9905 qkv_b,
9906 out_w,
9907 out_b,
9908 mask,
9909 ln1_g,
9910 ln1_b,
9911 eps1,
9912 fc1_w,
9913 fc1_b,
9914 fc2_w,
9915 fc2_b,
9916 ln2_g,
9917 ln2_b,
9918 eps2,
9919 out,
9920 batch,
9921 seq,
9922 hs,
9923 nh,
9924 dh,
9925 int_dim,
9926 } => {
9927 let (b, s, h, n_h, d_h) = (
9928 *batch as usize,
9929 *seq as usize,
9930 *hs as usize,
9931 *nh as usize,
9932 *dh as usize,
9933 );
9934 let m = b * s;
9935 let id = *int_dim as usize;
9936 let scale = (d_h as f32).powf(-0.5);
9937 let _half = d_h / 2;
9938 #[cfg(target_arch = "aarch64")]
9939 let neon_chunks = d_h / 4;
9940 unsafe {
9941 let inp = sl(*hidden, base, m * h);
9942 let dst = sl_mut(*out, base, m * h);
9943 let mk = sl(*mask, base, b * s);
9944
9945 let qkv = std::slice::from_raw_parts_mut(fl_qkv.as_mut_ptr(), m * 3 * h);
9947 let attn = std::slice::from_raw_parts_mut(fl_attn.as_mut_ptr(), m * h);
9948 let res = std::slice::from_raw_parts_mut(fl_res.as_mut_ptr(), m * h);
9949 let normed = std::slice::from_raw_parts_mut(fl_normed.as_mut_ptr(), m * h);
9950 let ffn = std::slice::from_raw_parts_mut(fl_ffn.as_mut_ptr(), m * id);
9951 let sc = std::slice::from_raw_parts_mut(fl_sc.as_mut_ptr(), s * s);
9952
9953 crate::blas::par_sgemm_bias(
9955 inp,
9956 sl(*qkv_w, base, h * 3 * h),
9957 sl(*qkv_b, base, 3 * h),
9958 qkv,
9959 m,
9960 h,
9961 3 * h,
9962 );
9963
9964 for bi in 0..b {
9966 for hi in 0..n_h {
9967 for qi in 0..s {
9968 for ki in 0..s {
9969 let q_base = bi * s * 3 * h + qi * 3 * h + hi * d_h;
9970 let k_base = bi * s * 3 * h + ki * 3 * h + h + hi * d_h;
9971 #[cfg(target_arch = "aarch64")]
9972 let dot;
9973 #[cfg(not(target_arch = "aarch64"))]
9974 let mut dot = 0f32;
9975 #[cfg(target_arch = "aarch64")]
9976 {
9977 use std::arch::aarch64::*;
9978 let mut acc = vdupq_n_f32(0.0);
9979 for c in 0..neon_chunks {
9980 acc = vfmaq_f32(
9981 acc,
9982 vld1q_f32(qkv.as_ptr().add(q_base + c * 4)),
9983 vld1q_f32(qkv.as_ptr().add(k_base + c * 4)),
9984 );
9985 }
9986 dot = vaddvq_f32(acc);
9987 }
9988 #[cfg(not(target_arch = "aarch64"))]
9989 for d in 0..d_h {
9990 dot += qkv[q_base + d] * qkv[k_base + d];
9991 }
9992 sc[qi * s + ki] = dot * scale;
9993 if mk[bi * s + ki] < mask_thr {
9994 sc[qi * s + ki] = mask_neg;
9995 }
9996 }
9997 }
9998 crate::kernels::neon_softmax(&mut sc[..s * s], s, s);
9999 for qi in 0..s {
10000 let o = bi * s * h + qi * h + hi * d_h;
10001 for d in 0..d_h {
10002 attn[o + d] = 0.0;
10003 }
10004 for ki in 0..s {
10005 let w = sc[qi * s + ki];
10006 if w > score_thr {
10007 let v = bi * s * 3 * h + ki * 3 * h + 2 * h + hi * d_h;
10008 #[cfg(target_arch = "aarch64")]
10009 {
10010 use std::arch::aarch64::*;
10011 let vw = vdupq_n_f32(w);
10012 for c in 0..neon_chunks {
10013 let off = c * 4;
10014 vst1q_f32(
10015 attn.as_mut_ptr().add(o + off),
10016 vfmaq_f32(
10017 vld1q_f32(attn.as_ptr().add(o + off)),
10018 vw,
10019 vld1q_f32(qkv.as_ptr().add(v + off)),
10020 ),
10021 );
10022 }
10023 }
10024 #[cfg(not(target_arch = "aarch64"))]
10025 for d in 0..d_h {
10026 attn[o + d] += w * qkv[v + d];
10027 }
10028 }
10029 }
10030 }
10031 }
10032 }
10033
10034 crate::blas::sgemm_bias(
10036 attn,
10037 sl(*out_w, base, h * h),
10038 sl(*out_b, base, h),
10039 res,
10040 m,
10041 h,
10042 h,
10043 );
10044 #[cfg(target_arch = "aarch64")]
10045 {
10046 use std::arch::aarch64::*;
10047 let chunks_h = (m * h) / 4;
10048 for c in 0..chunks_h {
10049 let off = c * 4;
10050 vst1q_f32(
10051 res.as_mut_ptr().add(off),
10052 vaddq_f32(
10053 vld1q_f32(res.as_ptr().add(off)),
10054 vld1q_f32(inp.as_ptr().add(off)),
10055 ),
10056 );
10057 }
10058 for i in (chunks_h * 4)..(m * h) {
10059 res[i] += inp[i];
10060 }
10061 }
10062 #[cfg(not(target_arch = "aarch64"))]
10063 for i in 0..m * h {
10064 res[i] += inp[i];
10065 }
10066
10067 let g1 = sl(*ln1_g, base, h);
10069 let b1 = sl(*ln1_b, base, h);
10070 for r in 0..m {
10071 crate::kernels::layer_norm_row(
10072 &res[r * h..(r + 1) * h],
10073 g1,
10074 b1,
10075 &mut normed[r * h..(r + 1) * h],
10076 h,
10077 *eps1,
10078 );
10079 }
10080
10081 crate::blas::par_sgemm_bias(
10083 normed,
10084 sl(*fc1_w, base, h * id),
10085 sl(*fc1_b, base, id),
10086 ffn,
10087 m,
10088 h,
10089 id,
10090 );
10091 crate::kernels::par_gelu_inplace(ffn);
10092
10093 crate::blas::par_sgemm_bias(
10095 ffn,
10096 sl(*fc2_w, base, id * h),
10097 sl(*fc2_b, base, h),
10098 res,
10099 m,
10100 id,
10101 h,
10102 );
10103 #[cfg(target_arch = "aarch64")]
10104 {
10105 use std::arch::aarch64::*;
10106 let chunks_h = (m * h) / 4;
10107 for c in 0..chunks_h {
10108 let off = c * 4;
10109 vst1q_f32(
10110 res.as_mut_ptr().add(off),
10111 vaddq_f32(
10112 vld1q_f32(res.as_ptr().add(off)),
10113 vld1q_f32(normed.as_ptr().add(off)),
10114 ),
10115 );
10116 }
10117 for i in (chunks_h * 4)..(m * h) {
10118 res[i] += normed[i];
10119 }
10120 }
10121 #[cfg(not(target_arch = "aarch64"))]
10122 for i in 0..m * h {
10123 res[i] += normed[i];
10124 }
10125
10126 let g2 = sl(*ln2_g, base, h);
10128 let b2 = sl(*ln2_b, base, h);
10129 for r in 0..m {
10130 crate::kernels::layer_norm_row(
10131 &res[r * h..(r + 1) * h],
10132 g2,
10133 b2,
10134 &mut dst[r * h..(r + 1) * h],
10135 h,
10136 *eps2,
10137 );
10138 }
10139 }
10140 }
10141
10142 Thunk::FusedNomicLayer {
10143 hidden,
10144 qkv_w,
10145 out_w,
10146 mask,
10147 cos,
10148 sin,
10149 cos_len,
10150 ln1_g,
10151 ln1_b,
10152 eps1,
10153 fc11_w,
10154 fc12_w: _,
10155 fc2_w,
10156 ln2_g,
10157 ln2_b,
10158 eps2,
10159 out,
10160 batch,
10161 seq,
10162 hs,
10163 nh,
10164 dh,
10165 int_dim,
10166 } => {
10167 let (b, s, h, n_h, d_h) = (
10168 *batch as usize,
10169 *seq as usize,
10170 *hs as usize,
10171 *nh as usize,
10172 *dh as usize,
10173 );
10174 let m = b * s;
10175 let id = *int_dim as usize;
10176 let scale = (d_h as f32).powf(-0.5);
10177 let half_dh = d_h / 2;
10178 #[cfg(target_arch = "aarch64")]
10179 let neon_chunks = d_h / 4;
10180 unsafe {
10181 let inp = sl(*hidden, base, m * h);
10182 let dst = sl_mut(*out, base, m * h);
10183 let mk = sl(*mask, base, b * s);
10184 let cos_tab = sl(*cos, base, *cos_len as usize);
10185 let sin_tab = sl(*sin, base, *cos_len as usize);
10186 let fused_fc_w = sl(*fc11_w, base, h * 2 * id);
10188
10189 let mut qkv = vec![0f32; m * 3 * h];
10190 let mut attn = vec![0f32; m * h];
10191 let mut res = vec![0f32; m * h];
10192 let mut normed = vec![0f32; m * h];
10193 let mut ffn_concat = vec![0f32; m * 2 * id]; let mut sc = vec![0f32; s * s];
10195
10196 crate::blas::sgemm(inp, sl(*qkv_w, base, h * 3 * h), &mut qkv, m, h, 3 * h);
10198
10199 for bi in 0..b {
10201 for hi in 0..n_h {
10202 for qi in 0..s {
10203 for ki in 0..s {
10204 let q_base = bi * s * 3 * h + qi * 3 * h + hi * d_h;
10205 let k_base = bi * s * 3 * h + ki * 3 * h + h + hi * d_h;
10206 let mut dot = 0f32;
10207 for i in 0..half_dh {
10208 let q1 = qkv[q_base + i];
10209 let q2 = qkv[q_base + half_dh + i];
10210 let k1 = qkv[k_base + i];
10211 let k2 = qkv[k_base + half_dh + i];
10212 let cq = cos_tab[qi * half_dh + i];
10213 let sq = sin_tab[qi * half_dh + i];
10214 let ck = cos_tab[ki * half_dh + i];
10215 let sk = sin_tab[ki * half_dh + i];
10216 dot += (q1 * cq - q2 * sq) * (k1 * ck - k2 * sk)
10217 + (q2 * cq + q1 * sq) * (k2 * ck + k1 * sk);
10218 }
10219 sc[qi * s + ki] = dot * scale;
10220 if mk[bi * s + ki] < mask_thr {
10221 sc[qi * s + ki] = mask_neg;
10222 }
10223 }
10224 }
10225 crate::kernels::neon_softmax(&mut sc[..s * s], s, s);
10226 for qi in 0..s {
10227 let o = bi * s * h + qi * h + hi * d_h;
10228 for d in 0..d_h {
10229 attn[o + d] = 0.0;
10230 }
10231 for ki in 0..s {
10232 let w = sc[qi * s + ki];
10233 if w > score_thr {
10234 let v = bi * s * 3 * h + ki * 3 * h + 2 * h + hi * d_h;
10235 #[cfg(target_arch = "aarch64")]
10236 {
10237 use std::arch::aarch64::*;
10238 let vw = vdupq_n_f32(w);
10239 for c in 0..neon_chunks {
10240 let off = c * 4;
10241 vst1q_f32(
10242 attn.as_mut_ptr().add(o + off),
10243 vfmaq_f32(
10244 vld1q_f32(attn.as_ptr().add(o + off)),
10245 vw,
10246 vld1q_f32(qkv.as_ptr().add(v + off)),
10247 ),
10248 );
10249 }
10250 }
10251 #[cfg(not(target_arch = "aarch64"))]
10252 for d in 0..d_h {
10253 attn[o + d] += w * qkv[v + d];
10254 }
10255 }
10256 }
10257 }
10258 }
10259 }
10260
10261 crate::blas::sgemm(&attn, sl(*out_w, base, h * h), &mut res, m, h, h);
10263 for i in 0..m * h {
10264 res[i] += inp[i];
10265 }
10266
10267 let g1 = sl(*ln1_g, base, h);
10269 let b1 = sl(*ln1_b, base, h);
10270 for r in 0..m {
10271 crate::kernels::layer_norm_row(
10272 &res[r * h..(r + 1) * h],
10273 g1,
10274 b1,
10275 &mut normed[r * h..(r + 1) * h],
10276 h,
10277 *eps1,
10278 );
10279 }
10280
10281 crate::blas::sgemm(&normed, fused_fc_w, &mut ffn_concat, m, h, 2 * id);
10283 for row in 0..m {
10286 let bo = row * 2 * id;
10287 for j in 0..id {
10289 let x = ffn_concat[bo + id + j];
10290 ffn_concat[bo + id + j] = x / (1.0 + (-x).exp());
10291 }
10292 for j in 0..id {
10294 ffn_concat[bo + j] *= ffn_concat[bo + id + j];
10295 }
10296 }
10297
10298 crate::blas::sgemm_general(
10303 ffn_concat.as_ptr(),
10304 sl(*fc2_w, base, id * h).as_ptr(),
10305 res.as_mut_ptr(),
10306 m,
10307 h,
10308 id,
10309 1.0,
10310 0.0,
10311 2 * id,
10312 h,
10313 h,
10314 false,
10315 false,
10316 );
10317 for i in 0..m * h {
10318 res[i] += normed[i];
10319 }
10320
10321 let g2 = sl(*ln2_g, base, h);
10323 let b2 = sl(*ln2_b, base, h);
10324 for r in 0..m {
10325 crate::kernels::layer_norm_row(
10326 &res[r * h..(r + 1) * h],
10327 g2,
10328 b2,
10329 &mut dst[r * h..(r + 1) * h],
10330 h,
10331 *eps2,
10332 );
10333 }
10334 }
10335 }
10336
10337 Thunk::FusedSwiGLU {
10338 src,
10339 dst,
10340 n_half,
10341 total,
10342 gate_first,
10343 } => {
10344 let n = *n_half as usize;
10345 let t = *total as usize;
10346 let outer = t / n;
10347 let in_total = outer * 2 * n;
10348 let gate_first = *gate_first;
10349 unsafe {
10350 let inp = sl(*src, base, in_total);
10351 let out = sl_mut(*dst, base, t);
10352 for o in 0..outer {
10353 let in_row = &inp[o * 2 * n..(o + 1) * 2 * n];
10354 let out_row = &mut out[o * n..(o + 1) * n];
10355 for i in 0..n {
10356 let (up, gate) = if gate_first {
10357 (in_row[n + i], in_row[i])
10358 } else {
10359 (in_row[i], in_row[n + i])
10360 };
10361 out_row[i] = up * (gate / (1.0 + (-gate).exp()));
10362 }
10363 }
10364 }
10365 }
10366
10367 Thunk::Concat {
10368 dst,
10369 outer,
10370 inner,
10371 total_axis,
10372 inputs,
10373 } => {
10374 let outer = *outer as usize;
10375 let inner = *inner as usize;
10376 let total_axis = *total_axis as usize;
10377 let row_stride = total_axis * inner;
10378 let out_total = outer * row_stride;
10379 unsafe {
10380 let out = sl_mut(*dst, base, out_total);
10381 let mut cum: usize = 0;
10382 for (src_off, in_axis) in inputs {
10383 let in_axis = *in_axis as usize;
10384 let copy_per_row = in_axis * inner;
10385 let dst_col_off = cum * inner;
10386 let in_total = outer * copy_per_row;
10387 let inp = sl(*src_off, base, in_total);
10388 for o in 0..outer {
10389 let dst_row_start = o * row_stride + dst_col_off;
10390 let src_row_start = o * copy_per_row;
10391 out[dst_row_start..dst_row_start + copy_per_row]
10392 .copy_from_slice(&inp[src_row_start..src_row_start + copy_per_row]);
10393 }
10394 cum += in_axis;
10395 }
10396 }
10397 }
10398
10399 Thunk::ConcatF64 {
10400 dst,
10401 outer,
10402 inner,
10403 total_axis,
10404 inputs,
10405 } => {
10406 let outer = *outer as usize;
10407 let inner = *inner as usize;
10408 let total_axis = *total_axis as usize;
10409 let row_stride = total_axis * inner;
10410 let out_total = outer * row_stride;
10411 unsafe {
10412 let out = sl_mut_f64(*dst, base, out_total);
10413 let mut cum: usize = 0;
10414 for (src_off, in_axis) in inputs {
10415 let in_axis = *in_axis as usize;
10416 let copy_per_row = in_axis * inner;
10417 let dst_col_off = cum * inner;
10418 let in_total = outer * copy_per_row;
10419 let inp = sl_f64(*src_off, base, in_total);
10420 for o in 0..outer {
10421 let dst_row_start = o * row_stride + dst_col_off;
10422 let src_row_start = o * copy_per_row;
10423 out[dst_row_start..dst_row_start + copy_per_row]
10424 .copy_from_slice(&inp[src_row_start..src_row_start + copy_per_row]);
10425 }
10426 cum += in_axis;
10427 }
10428 }
10429 }
10430
10431 Thunk::Compare {
10432 lhs,
10433 rhs,
10434 dst,
10435 len,
10436 op,
10437 } => {
10438 let len = *len as usize;
10439 unsafe {
10440 let l = sl(*lhs, base, len);
10441 let r = sl(*rhs, base, len);
10442 let o = sl_mut(*dst, base, len);
10443 for i in 0..len {
10444 o[i] = match op {
10445 CmpOp::Eq => (l[i] == r[i]) as u32 as f32,
10446 CmpOp::Ne => (l[i] != r[i]) as u32 as f32,
10447 CmpOp::Lt => (l[i] < r[i]) as u32 as f32,
10448 CmpOp::Le => (l[i] <= r[i]) as u32 as f32,
10449 CmpOp::Gt => (l[i] > r[i]) as u32 as f32,
10450 CmpOp::Ge => (l[i] >= r[i]) as u32 as f32,
10451 };
10452 }
10453 }
10454 }
10455
10456 Thunk::Where {
10457 cond,
10458 on_true,
10459 on_false,
10460 dst,
10461 len,
10462 } => {
10463 let len = *len as usize;
10464 unsafe {
10465 let c = sl(*cond, base, len);
10466 let t = sl(*on_true, base, len);
10467 let e = sl(*on_false, base, len);
10468 let o = sl_mut(*dst, base, len);
10469 for i in 0..len {
10470 o[i] = if c[i] != 0.0 { t[i] } else { e[i] };
10472 }
10473 }
10474 }
10475
10476 Thunk::ScatterAdd {
10477 updates,
10478 indices,
10479 dst,
10480 num_updates,
10481 out_dim,
10482 trailing,
10483 } => {
10484 let num_updates = *num_updates as usize;
10485 let out_dim = *out_dim as usize;
10486 let trailing = *trailing as usize;
10487 unsafe {
10488 let upd = sl(*updates, base, num_updates * trailing);
10489 let ids = sl(*indices, base, num_updates);
10490 let out = sl_mut(*dst, base, out_dim * trailing);
10491 for v in out.iter_mut() {
10493 *v = 0.0;
10494 }
10495 for i in 0..num_updates {
10496 let row = ids[i] as usize;
10497 debug_assert!(row < out_dim, "ScatterAdd index out of range");
10498 let src_off = i * trailing;
10499 let dst_off = row * trailing;
10500 for j in 0..trailing {
10501 out[dst_off + j] += upd[src_off + j];
10502 }
10503 }
10504 }
10505 }
10506
Thunk::GroupedMatMul {
    input,
    weight,
    expert_idx,
    dst,
    m,
    k_dim,
    n,
    num_experts,
} => {
    // Per-row expert matmul: row `i` of the `m x k_dim` input is multiplied by
    // the `k_dim x n` weight slab of expert `expert_idx[i]`. Rows are first
    // packed so that each expert's rows are contiguous, one sgemm is issued per
    // expert that received rows, and results are scattered back to the original
    // row order.
    let m = *m as usize;
    let k_dim = *k_dim as usize;
    let n = *n as usize;
    let num_experts = *num_experts as usize;
    unsafe {
        let inp = sl(*input, base, m * k_dim);
        let wt = sl(*weight, base, num_experts * k_dim * n);
        // Expert ids are stored as f32 and converted with `as usize` below.
        let ids = sl(*expert_idx, base, m);
        let out = sl_mut(*dst, base, m * n);

        // Pass 1: how many rows each expert receives.
        let mut counts = vec![0usize; num_experts];
        for i in 0..m {
            let e = ids[i] as usize;
            debug_assert!(
                e < num_experts,
                "expert_idx out of range: {e} >= {num_experts}"
            );
            counts[e] += 1;
        }
        // Exclusive prefix sum: offsets[e] is the first packed row of expert e.
        let mut offsets = vec![0usize; num_experts + 1];
        for e in 0..num_experts {
            offsets[e + 1] = offsets[e] + counts[e];
        }
        // Pass 2: copy rows into expert-contiguous order, remembering where each
        // packed row came from.
        let mut packed_in = vec![0f32; m * k_dim];
        let mut original_pos = vec![0usize; m];
        let mut write_idx = vec![0usize; num_experts];
        for i in 0..m {
            let e = ids[i] as usize;
            let dst_row = offsets[e] + write_idx[e];
            packed_in[dst_row * k_dim..(dst_row + 1) * k_dim]
                .copy_from_slice(&inp[i * k_dim..(i + 1) * k_dim]);
            original_pos[dst_row] = i;
            write_idx[e] += 1;
        }

        let mut packed_out = vec![0f32; m * n];
        let expert_stride = k_dim * n;
        // `next_gmm_ord` is called once per execution of this thunk.
        // NOTE(review): the `/ 3` presumably reflects three grouped matmuls per
        // MoE layer — confirm against `moe_residency`.
        let gmm_ord = crate::moe_residency::next_gmm_ord();
        let moe_layer = gmm_ord / 3;
        for e in 0..num_experts {
            let count = counts[e];
            if count == 0 {
                continue;
            }
            crate::moe_residency::record_expert_tokens(moe_layer, e, count);
            let in_start = offsets[e];
            let in_slice = &packed_in[in_start * k_dim..(in_start + count) * k_dim];
            // Weight source: when the expert is not reported as on-device and a
            // host pointer is available, read `expert_stride` floats from that
            // pointer; in every other case use the expert's slab in `wt`.
            let w_slab: &[f32] =
                if !crate::moe_residency::expert_on_device_for_layer(moe_layer, e) {
                    if let Some(ptr) =
                        crate::moe_residency::host_expert_weight_ptr(gmm_ord, e)
                    {
                        std::slice::from_raw_parts(ptr, expert_stride)
                    } else {
                        &wt[e * expert_stride..(e + 1) * expert_stride]
                    }
                } else {
                    &wt[e * expert_stride..(e + 1) * expert_stride]
                };
            let out_slice = &mut packed_out[in_start * n..(in_start + count) * n];
            crate::blas::sgemm(in_slice, w_slab, out_slice, count, k_dim, n);
        }

        // Scatter packed results back to the caller's row order.
        for packed_idx in 0..m {
            let i = original_pos[packed_idx];
            out[i * n..(i + 1) * n]
                .copy_from_slice(&packed_out[packed_idx * n..(packed_idx + 1) * n]);
        }
    }
}
10597
10598 Thunk::DequantGroupedMatMulGguf {
10599 input,
10600 w_q,
10601 expert_idx,
10602 dst,
10603 m,
10604 k_dim,
10605 n,
10606 num_experts,
10607 scheme,
10608 } => {
10609 let m = *m as usize;
10610 let k_dim = *k_dim as usize;
10611 let n = *n as usize;
10612 let num_experts = *num_experts as usize;
10613 let block_elems = scheme.gguf_block_size() as usize;
10614 let block_bytes = scheme.gguf_block_bytes() as usize;
10615 let slab_bytes = (k_dim * n) / block_elems * block_bytes;
10616 unsafe {
10617 let inp = sl(*input, base, m * k_dim);
10618 let wt = std::slice::from_raw_parts(
10619 base.add(*w_q) as *const u8,
10620 num_experts * slab_bytes,
10621 );
10622 let ids = sl(*expert_idx, base, m);
10623 let out = sl_mut(*dst, base, m * n);
10624 crate::gguf_matmul::gguf_grouped_matmul_bt(
10625 inp,
10626 wt,
10627 ids,
10628 out,
10629 m,
10630 k_dim,
10631 n,
10632 num_experts,
10633 *scheme,
10634 );
10635 }
10636 }
10637
10638 Thunk::DequantMoEWeightsGguf {
10639 w_q,
10640 dst,
10641 k_dim,
10642 n,
10643 num_experts,
10644 scheme,
10645 } => {
10646 let k_dim = *k_dim as usize;
10647 let n = *n as usize;
10648 let num_experts = *num_experts as usize;
10649 let block_elems = scheme.gguf_block_size() as usize;
10650 let block_bytes = scheme.gguf_block_bytes() as usize;
10651 let slab_bytes = (k_dim * n) / block_elems * block_bytes;
10652 unsafe {
10653 let wt = std::slice::from_raw_parts(
10654 base.add(*w_q) as *const u8,
10655 num_experts * slab_bytes,
10656 );
10657 let out = sl_mut(*dst, base, num_experts * k_dim * n);
10658 crate::gguf_matmul::dequant_moe_weights_to_grouped_f32(
10659 wt,
10660 out,
10661 num_experts,
10662 k_dim,
10663 n,
10664 *scheme,
10665 );
10666 }
10667 }
10668
10669 Thunk::TopK {
10670 src,
10671 dst,
10672 outer,
10673 axis_dim,
10674 k,
10675 } => {
10676 let outer = *outer as usize;
10677 let axis_dim = *axis_dim as usize;
10678 let k = *k as usize;
10679 unsafe {
10680 let inp = sl(*src, base, outer * axis_dim);
10681 let out = sl_mut(*dst, base, outer * k);
10682 let mut row_buf: Vec<f32> = vec![0.0; axis_dim];
10686 for o in 0..outer {
10687 row_buf.copy_from_slice(&inp[o * axis_dim..(o + 1) * axis_dim]);
10688 for ki in 0..k {
10689 let mut best_i = 0usize;
10691 let mut best_v = row_buf[0];
10692 for i in 1..axis_dim {
10693 let v = row_buf[i];
10694 if v > best_v {
10695 best_v = v;
10696 best_i = i;
10697 }
10698 }
10699 out[o * k + ki] = best_i as f32;
10700 row_buf[best_i] = f32::NEG_INFINITY;
10703 }
10704 }
10705 if let Some(cap) = schedule.moe_topk_capture.as_ref() {
10706 cap.push_topk_f32(&out[..outer * k], axis_dim);
10707 }
10708 }
10709 }
10710
10711 Thunk::Reduce {
10712 src,
10713 dst,
10714 outer,
10715 reduced,
10716 inner,
10717 op,
10718 } => {
10719 let outer = *outer as usize;
10720 let reduced = *reduced as usize;
10721 let inner = *inner as usize;
10722 let in_total = outer * reduced * inner;
10723 let out_total = outer * inner;
10724 unsafe {
10725 let inp = sl(*src, base, in_total);
10726 let out = sl_mut(*dst, base, out_total);
10727 for o in 0..outer {
10728 for i in 0..inner {
10729 let mut acc = match op {
10730 ReduceOp::Max => f32::NEG_INFINITY,
10731 ReduceOp::Min => f32::INFINITY,
10732 ReduceOp::Prod => 1.0f32,
10733 _ => 0.0f32, };
10735 for r in 0..reduced {
10737 let v = inp[o * reduced * inner + r * inner + i];
10738 acc = match op {
10739 ReduceOp::Sum | ReduceOp::Mean => acc + v,
10740 ReduceOp::Max => acc.max(v),
10741 ReduceOp::Min => acc.min(v),
10742 ReduceOp::Prod => acc * v,
10743 };
10744 }
10745 if matches!(op, ReduceOp::Mean) {
10746 acc /= reduced as f32;
10747 }
10748 out[o * inner + i] = acc;
10749 }
10750 }
10751 }
10752 }
10753
10754 Thunk::Conv2D1x1 {
10755 src,
10756 weight,
10757 dst,
10758 n,
10759 c_in,
10760 c_out,
10761 hw,
10762 } => {
10763 let n = *n as usize;
10764 let c_in = *c_in as usize;
10765 let c_out = *c_out as usize;
10766 let hw = *hw as usize;
10767 unsafe {
10768 let inp = sl(*src, base, n * c_in * hw);
10769 let wt = sl(*weight, base, c_out * c_in);
10770 let out = sl_mut(*dst, base, n * c_out * hw);
10771 for ni in 0..n {
10776 let in_off = ni * c_in * hw;
10777 let out_off = ni * c_out * hw;
10778 crate::blas::sgemm(
10779 wt,
10780 &inp[in_off..in_off + c_in * hw],
10781 &mut out[out_off..out_off + c_out * hw],
10782 c_out,
10783 c_in,
10784 hw,
10785 );
10786 }
10787 }
10788 }
10789
10790 Thunk::Conv2D {
10791 src,
10792 weight,
10793 dst,
10794 n,
10795 c_in,
10796 h,
10797 w,
10798 c_out,
10799 h_out,
10800 w_out,
10801 kh,
10802 kw,
10803 sh,
10804 sw,
10805 ph,
10806 pw,
10807 dh,
10808 dw,
10809 groups,
10810 } => {
10811 let n = *n as usize;
10812 let c_in = *c_in as usize;
10813 let h = *h as usize;
10814 let w = *w as usize;
10815 let c_out = *c_out as usize;
10816 let h_out = *h_out as usize;
10817 let w_out = *w_out as usize;
10818 let kh = *kh as usize;
10819 let kw = *kw as usize;
10820 let sh = *sh as usize;
10821 let sw = *sw as usize;
10822 let ph = *ph as usize;
10823 let pw = *pw as usize;
10824 let dh = *dh as usize;
10825 let dw = *dw as usize;
10826 let groups = *groups as usize;
10827 let c_in_per_g = c_in / groups;
10828 let c_out_per_g = c_out / groups;
10829 unsafe {
10830 let inp = sl(*src, base, n * c_in * h * w);
10831 let wt = sl(*weight, base, c_out * c_in_per_g * kh * kw);
10832 let out = sl_mut(*dst, base, n * c_out * h_out * w_out);
10833 for ni in 0..n {
10834 for co in 0..c_out {
10835 let g = co / c_out_per_g;
10836 let ci_start = g * c_in_per_g;
10837 for ho in 0..h_out {
10838 for wo in 0..w_out {
10839 let mut acc = 0f32;
10840 for ci_off in 0..c_in_per_g {
10841 let ci = ci_start + ci_off;
10842 let in_chan = ((ni * c_in) + ci) * h * w;
10843 let wt_chan = ((co * c_in_per_g) + ci_off) * kh * kw;
10844 for ki in 0..kh {
10845 for kj in 0..kw {
10846 let hi = ho * sh + ki * dh;
10847 let wi = wo * sw + kj * dw;
10848 if hi < ph || wi < pw {
10849 continue;
10850 }
10851 let hi = hi - ph;
10852 let wi = wi - pw;
10853 if hi >= h || wi >= w {
10854 continue;
10855 }
10856 acc += inp[in_chan + hi * w + wi]
10857 * wt[wt_chan + ki * kw + kj];
10858 }
10859 }
10860 }
10861 out[((ni * c_out) + co) * h_out * w_out + ho * w_out + wo] =
10862 acc;
10863 }
10864 }
10865 }
10866 }
10867 }
10868 }
10869
10870 Thunk::Pool2D {
10871 src,
10872 dst,
10873 n,
10874 c,
10875 h,
10876 w,
10877 h_out,
10878 w_out,
10879 kh,
10880 kw,
10881 sh,
10882 sw,
10883 ph,
10884 pw,
10885 kind,
10886 } => {
10887 let n = *n as usize;
10888 let c = *c as usize;
10889 let h = *h as usize;
10890 let w = *w as usize;
10891 let h_out = *h_out as usize;
10892 let w_out = *w_out as usize;
10893 let kh = *kh as usize;
10894 let kw = *kw as usize;
10895 let sh = *sh as usize;
10896 let sw = *sw as usize;
10897 let ph = *ph as usize;
10898 let pw = *pw as usize;
10899 let kernel_area = (kh * kw) as f32;
10900 unsafe {
10901 let inp = sl(*src, base, n * c * h * w);
10902 let out = sl_mut(*dst, base, n * c * h_out * w_out);
10903 for ni in 0..n {
10904 for ci in 0..c {
10905 let in_chan = ni * c * h * w + ci * h * w;
10906 let out_chan = ni * c * h_out * w_out + ci * h_out * w_out;
10907 for ho in 0..h_out {
10908 for wo in 0..w_out {
10909 let mut acc = match kind {
10910 ReduceOp::Max => f32::NEG_INFINITY,
10911 _ => 0f32, };
10913 for ki in 0..kh {
10914 for kj in 0..kw {
10915 let hi = ho * sh + ki;
10916 let wi = wo * sw + kj;
10917 if hi < ph || wi < pw {
10919 continue;
10920 }
10921 let hi = hi - ph;
10922 let wi = wi - pw;
10923 if hi >= h || wi >= w {
10924 continue;
10925 }
10926 let v = inp[in_chan + hi * w + wi];
10927 match kind {
10928 ReduceOp::Max => acc = acc.max(v),
10929 _ => acc += v,
10930 }
10931 }
10932 }
10933 if matches!(kind, ReduceOp::Mean) {
10934 acc /= kernel_area;
10935 }
10936 out[out_chan + ho * w_out + wo] = acc;
10937 }
10938 }
10939 }
10940 }
10941 }
10942 }
10943
10944 Thunk::ReluBackward { x, dy, dx, len } => {
10945 let len = *len as usize;
10946 unsafe {
10947 let xs = sl(*x, base, len);
10948 let dys = sl(*dy, base, len);
10949 let out = sl_mut(*dx, base, len);
10950 for i in 0..len {
10951 out[i] = if xs[i] > 0.0 { dys[i] } else { 0.0 };
10952 }
10953 }
10954 }
10955
10956 Thunk::ReluBackwardF64 { x, dy, dx, len } => {
10957 let len = *len as usize;
10958 unsafe {
10959 let xs = sl_f64(*x, base, len);
10960 let dys = sl_f64(*dy, base, len);
10961 let out = sl_mut_f64(*dx, base, len);
10962 for i in 0..len {
10963 out[i] = if xs[i] > 0.0 { dys[i] } else { 0.0 };
10964 }
10965 }
10966 }
10967
10968 Thunk::QMatMul {
10969 x,
10970 w,
10971 bias,
10972 out,
10973 m,
10974 k,
10975 n,
10976 x_zp,
10977 w_zp,
10978 out_zp,
10979 mult,
10980 } => {
10981 let m = *m as usize;
10982 let k = *k as usize;
10983 let n = *n as usize;
10984 unsafe {
10985 let x_ptr = base.add(*x) as *const i8;
10986 let w_ptr = base.add(*w) as *const i8;
10987 let bias_ptr = base.add(*bias) as *const i32;
10988 let out_ptr = base.add(*out) as *mut i8;
10989 for mi in 0..m {
10990 for ni in 0..n {
10991 let mut acc: i32 = *bias_ptr.add(ni);
10992 for ki in 0..k {
10993 let xv = *x_ptr.add(mi * k + ki) as i32 - *x_zp;
10994 let wv = *w_ptr.add(ki * n + ni) as i32 - *w_zp;
10995 acc += xv * wv;
10996 }
10997 let r = (acc as f32 * *mult).round() as i32 + *out_zp;
11000 let r = r.clamp(-128, 127) as i8;
11001 *out_ptr.add(mi * n + ni) = r;
11002 }
11003 }
11004 }
11005 }
11006
11007 Thunk::QConv2d {
11008 x,
11009 w,
11010 bias,
11011 out,
11012 n,
11013 c_in,
11014 h,
11015 w_in,
11016 c_out,
11017 h_out,
11018 w_out,
11019 kh,
11020 kw,
11021 sh,
11022 sw,
11023 ph,
11024 pw,
11025 dh,
11026 dw,
11027 groups,
11028 x_zp,
11029 w_zp,
11030 out_zp,
11031 mult,
11032 } => {
11033 let n = *n as usize;
11034 let c_in = *c_in as usize;
11035 let h = *h as usize;
11036 let w_in = *w_in as usize;
11037 let c_out = *c_out as usize;
11038 let h_out = *h_out as usize;
11039 let w_out = *w_out as usize;
11040 let kh = *kh as usize;
11041 let kw = *kw as usize;
11042 let sh = *sh as usize;
11043 let sw = *sw as usize;
11044 let ph = *ph as usize;
11045 let pw = *pw as usize;
11046 let dh = *dh as usize;
11047 let dw = *dw as usize;
11048 let groups = *groups as usize;
11049 let c_in_per_g = c_in / groups;
11050 let c_out_per_g = c_out / groups;
11051 unsafe {
11052 let x_ptr = base.add(*x) as *const i8;
11053 let w_ptr = base.add(*w) as *const i8;
11054 let bias_ptr = base.add(*bias) as *const i32;
11055 let out_ptr = base.add(*out) as *mut i8;
11056 for ni in 0..n {
11057 for co in 0..c_out {
11058 let g = co / c_out_per_g;
11059 let ci_start = g * c_in_per_g;
11060 for ho in 0..h_out {
11061 for wo in 0..w_out {
11062 let mut acc: i32 = *bias_ptr.add(co);
11063 for ci_off in 0..c_in_per_g {
11064 let ci = ci_start + ci_off;
11065 let in_chan = ((ni * c_in) + ci) * h * w_in;
11066 let wt_chan = ((co * c_in_per_g) + ci_off) * kh * kw;
11067 for ki in 0..kh {
11068 for kj in 0..kw {
11069 let hi = ho * sh + ki * dh;
11070 let wi = wo * sw + kj * dw;
11071 if hi < ph || wi < pw {
11072 continue;
11073 }
11074 let hi = hi - ph;
11075 let wi = wi - pw;
11076 if hi >= h || wi >= w_in {
11077 continue;
11078 }
11079 let xv = *x_ptr.add(in_chan + hi * w_in + wi)
11080 as i32
11081 - *x_zp;
11082 let wv = *w_ptr.add(wt_chan + ki * kw + kj) as i32
11083 - *w_zp;
11084 acc += xv * wv;
11085 }
11086 }
11087 }
11088 let r = (acc as f32 * *mult).round() as i32 + *out_zp;
11089 let r = r.clamp(-128, 127) as i8;
11090 let dst = ((ni * c_out) + co) * h_out * w_out + ho * w_out + wo;
11091 *out_ptr.add(dst) = r;
11092 }
11093 }
11094 }
11095 }
11096 }
11097 }
11098
11099 Thunk::Quantize {
11100 x,
11101 q,
11102 len,
11103 chan_axis: _,
11104 chan_dim,
11105 inner,
11106 scales,
11107 zero_points,
11108 } => {
11109 let len = *len as usize;
11110 let chan_dim = *chan_dim as usize;
11111 let inner = *inner as usize;
11112 unsafe {
11113 let xs = sl(*x, base, len);
11114 let q_ptr = base.add(*q) as *mut i8;
11115 for i in 0..len {
11116 let c = if chan_dim == 1 {
11117 0
11118 } else {
11119 (i / inner) % chan_dim
11120 };
11121 let inv_scale = 1.0 / scales[c];
11122 let zp = zero_points[c];
11123 let v = (xs[i] * inv_scale).round() as i32 + zp;
11124 *q_ptr.add(i) = v.clamp(-128, 127) as i8;
11125 }
11126 }
11127 }
11128
11129 Thunk::Dequantize {
11130 q,
11131 x,
11132 len,
11133 chan_axis: _,
11134 chan_dim,
11135 inner,
11136 scales,
11137 zero_points,
11138 } => {
11139 let len = *len as usize;
11140 let chan_dim = *chan_dim as usize;
11141 let inner = *inner as usize;
11142 unsafe {
11143 let q_ptr = base.add(*q) as *const i8;
11144 let out = sl_mut(*x, base, len);
11145 for i in 0..len {
11146 let c = if chan_dim == 1 {
11147 0
11148 } else {
11149 (i / inner) % chan_dim
11150 };
11151 let scale = scales[c];
11152 let zp = zero_points[c];
11153 let qv = *q_ptr.add(i) as i32;
11154 out[i] = (qv - zp) as f32 * scale;
11155 }
11156 }
11157 }
11158
11159 Thunk::FakeQuantize {
11160 x,
11161 out,
11162 len,
11163 chan_axis: _,
11164 chan_dim,
11165 inner,
11166 bits,
11167 ste: _,
11168 scale_mode,
11169 state_off,
11170 } => {
11171 use rlx_ir::op::ScaleMode;
11172 let len = *len as usize;
11173 let chan_dim = *chan_dim as usize;
11174 let inner = *inner as usize;
11175 let q_max: f32 = match *bits {
11176 8 => 127.0,
11177 4 => 7.0,
11178 2 => 1.0,
11179 n => panic!("FakeQuantize: unsupported bits {n}"),
11180 };
11181 unsafe {
11182 let xs = sl(*x, base, len);
11183 let outs = sl_mut(*out, base, len);
11184
11185 let mut scale = vec![0f32; chan_dim];
11186 match scale_mode {
11187 ScaleMode::PerBatch => {
11188 let mut max_abs = vec![0f32; chan_dim];
11189 for i in 0..len {
11190 let c = if chan_dim == 1 {
11191 0
11192 } else {
11193 (i / inner) % chan_dim
11194 };
11195 let a = xs[i].abs();
11196 if a > max_abs[c] {
11197 max_abs[c] = a;
11198 }
11199 }
11200 for c in 0..chan_dim {
11201 scale[c] = (max_abs[c] / q_max).max(1e-12);
11202 }
11203 }
11204 ScaleMode::EMA { decay } => {
11205 let mut max_abs = vec![0f32; chan_dim];
11208 for i in 0..len {
11209 let c = if chan_dim == 1 {
11210 0
11211 } else {
11212 (i / inner) % chan_dim
11213 };
11214 let a = xs[i].abs();
11215 if a > max_abs[c] {
11216 max_abs[c] = a;
11217 }
11218 }
11219 let state =
11220 sl_mut(state_off.expect("EMA needs state_off"), base, chan_dim);
11221 for c in 0..chan_dim {
11222 let cur = (max_abs[c] / q_max).max(1e-12);
11223 let blended = if state[c] <= 0.0 {
11225 cur
11226 } else {
11227 *decay * state[c] + (1.0 - *decay) * cur
11228 };
11229 state[c] = blended;
11230 scale[c] = blended;
11231 }
11232 }
11233 ScaleMode::Fixed => {
11234 let state =
11235 sl(state_off.expect("Fixed needs state_off"), base, chan_dim);
11236 for c in 0..chan_dim {
11237 scale[c] = state[c].max(1e-12);
11238 }
11239 }
11240 }
11241
11242 for i in 0..len {
11243 let c = if chan_dim == 1 {
11244 0
11245 } else {
11246 (i / inner) % chan_dim
11247 };
11248 let s = scale[c];
11249 let qv = (xs[i] / s).round().clamp(-q_max, q_max);
11250 outs[i] = qv * s;
11251 }
11252 }
11253 }
11254
11255 Thunk::ActivationBackward {
11256 x,
11257 dy,
11258 dx,
11259 len,
11260 kind,
11261 } => {
11262 let len = *len as usize;
11263 unsafe {
11264 let xs = sl(*x, base, len);
11265 let dys = sl(*dy, base, len);
11266 let out = sl_mut(*dx, base, len);
11267 activation_backward_kernel(*kind, xs, dys, out);
11268 }
11269 }
11270
11271 Thunk::ActivationBackwardF64 {
11272 x,
11273 dy,
11274 dx,
11275 len,
11276 kind,
11277 } => {
11278 let len = *len as usize;
11279 unsafe {
11280 let xs = sl_f64(*x, base, len);
11281 let dys = sl_f64(*dy, base, len);
11282 let out = sl_mut_f64(*dx, base, len);
11283 activation_backward_kernel_f64(*kind, xs, dys, out);
11284 }
11285 }
11286
11287 Thunk::FakeQuantizeLSQ {
11288 x,
11289 scale_off,
11290 out,
11291 len,
11292 chan_axis: _,
11293 chan_dim,
11294 inner,
11295 bits,
11296 } => {
11297 let len = *len as usize;
11298 let chan_dim = *chan_dim as usize;
11299 let inner = *inner as usize;
11300 let q_max: f32 = match *bits {
11301 8 => 127.0,
11302 4 => 7.0,
11303 2 => 1.0,
11304 n => panic!("FakeQuantizeLSQ: bad bits {n}"),
11305 };
11306 unsafe {
11307 let xs = sl(*x, base, len);
11308 let scale = sl(*scale_off, base, chan_dim);
11309 let outs = sl_mut(*out, base, len);
11310 for i in 0..len {
11311 let c = if chan_dim == 1 {
11312 0
11313 } else {
11314 (i / inner) % chan_dim
11315 };
11316 let s = scale[c].max(1e-12);
11317 let qv = (xs[i] / s).round().clamp(-q_max, q_max);
11318 outs[i] = qv * s;
11319 }
11320 }
11321 }
11322
11323 Thunk::FakeQuantizeLSQBackwardX {
11324 x,
11325 scale_off,
11326 dy,
11327 dx,
11328 len,
11329 chan_axis: _,
11330 chan_dim,
11331 inner,
11332 bits,
11333 } => {
11334 let len = *len as usize;
11335 let chan_dim = *chan_dim as usize;
11336 let inner = *inner as usize;
11337 let q_max: f32 = match *bits {
11338 8 => 127.0,
11339 4 => 7.0,
11340 2 => 1.0,
11341 n => panic!("FakeQuantizeLSQBackwardX: bad bits {n}"),
11342 };
11343 unsafe {
11344 let xs = sl(*x, base, len);
11345 let scale = sl(*scale_off, base, chan_dim);
11346 let dys = sl(*dy, base, len);
11347 let outs = sl_mut(*dx, base, len);
11348 for i in 0..len {
11350 let c = if chan_dim == 1 {
11351 0
11352 } else {
11353 (i / inner) % chan_dim
11354 };
11355 let z = xs[i] / scale[c].max(1e-12);
11356 outs[i] = if z.abs() <= q_max { dys[i] } else { 0.0 };
11357 }
11358 }
11359 }
11360
11361 Thunk::FakeQuantizeLSQBackwardScale {
11362 x,
11363 scale_off,
11364 dy,
11365 dscale,
11366 len,
11367 chan_axis: _,
11368 chan_dim,
11369 inner,
11370 bits,
11371 } => {
11372 let len = *len as usize;
11373 let chan_dim = *chan_dim as usize;
11374 let inner = *inner as usize;
11375 let q_max: f32 = match *bits {
11376 8 => 127.0,
11377 4 => 7.0,
11378 2 => 1.0,
11379 n => panic!("FakeQuantizeLSQBackwardScale: bad bits {n}"),
11380 };
11381 unsafe {
11382 let xs = sl(*x, base, len);
11383 let scale = sl(*scale_off, base, chan_dim);
11384 let dys = sl(*dy, base, len);
11385 let outs = sl_mut(*dscale, base, chan_dim);
11386 for v in outs.iter_mut() {
11387 *v = 0.0;
11388 }
11389 for i in 0..len {
11392 let c = if chan_dim == 1 {
11393 0
11394 } else {
11395 (i / inner) % chan_dim
11396 };
11397 let s = scale[c].max(1e-12);
11398 let z = xs[i] / s;
11399 let psi = if z.abs() <= q_max {
11400 -z + z.round()
11401 } else if z > 0.0 {
11402 q_max
11403 } else {
11404 -q_max
11405 };
11406 outs[c] += psi * dys[i];
11407 }
11408 }
11409 }
11410
11411 Thunk::FakeQuantizeBackward {
11412 x,
11413 dy,
11414 dx,
11415 len,
11416 chan_axis: _,
11417 chan_dim,
11418 inner,
11419 bits,
11420 ste,
11421 } => {
11422 use rlx_ir::op::SteKind;
11423 let len = *len as usize;
11424 let chan_dim = *chan_dim as usize;
11425 let inner = *inner as usize;
11426 let q_max: f32 = match *bits {
11427 8 => 127.0,
11428 4 => 7.0,
11429 2 => 1.0,
11430 n => panic!("FakeQuantizeBackward: bad bits {n}"),
11431 };
11432 unsafe {
11433 let xs = sl(*x, base, len);
11434 let dys = sl(*dy, base, len);
11435 let outs = sl_mut(*dx, base, len);
11436
11437 let mut max_abs = vec![0f32; chan_dim];
11439 for i in 0..len {
11440 let c = if chan_dim == 1 {
11441 0
11442 } else {
11443 (i / inner) % chan_dim
11444 };
11445 let a = xs[i].abs();
11446 if a > max_abs[c] {
11447 max_abs[c] = a;
11448 }
11449 }
11450 let mut scale = vec![0f32; chan_dim];
11451 for c in 0..chan_dim {
11452 scale[c] = (max_abs[c] / q_max).max(1e-12);
11453 }
11454
11455 match *ste {
11456 SteKind::Identity => {
11457 outs.copy_from_slice(dys);
11459 }
11460 SteKind::ClippedIdentity => {
11461 for i in 0..len {
11464 let c = if chan_dim == 1 {
11465 0
11466 } else {
11467 (i / inner) % chan_dim
11468 };
11469 let bound = q_max * scale[c];
11470 outs[i] = if xs[i].abs() <= bound { dys[i] } else { 0.0 };
11471 }
11472 }
11473 SteKind::Tanh => {
11474 for i in 0..len {
11476 let c = if chan_dim == 1 {
11477 0
11478 } else {
11479 (i / inner) % chan_dim
11480 };
11481 let t = (xs[i] / scale[c]).tanh();
11482 outs[i] = dys[i] * (1.0 - t * t);
11483 }
11484 }
11485 SteKind::HardTanh => {
11486 for i in 0..len {
11488 let c = if chan_dim == 1 {
11489 0
11490 } else {
11491 (i / inner) % chan_dim
11492 };
11493 let bound = q_max * scale[c];
11494 let attenuation = (1.0 - (xs[i] / bound).abs()).max(0.0);
11495 outs[i] = dys[i] * attenuation;
11496 }
11497 }
11498 }
11499 }
11500 }
11501
11502 Thunk::LayerNormBackwardInput {
11503 x,
11504 gamma,
11505 dy,
11506 dx,
11507 rows,
11508 h,
11509 eps,
11510 } => {
11511 let rows = *rows as usize;
11512 let h = *h as usize;
11513 let eps = *eps;
11514 unsafe {
11515 let xs = sl(*x, base, rows * h);
11516 let g = sl(*gamma, base, h);
11517 let dys = sl(*dy, base, rows * h);
11518 let out = sl_mut(*dx, base, rows * h);
11519 let n_inv = 1.0 / h as f32;
11520 for r in 0..rows {
11521 let xr = &xs[r * h..(r + 1) * h];
11522 let dyr = &dys[r * h..(r + 1) * h];
11523 let mut sum = 0f32;
11526 for &v in xr {
11527 sum += v;
11528 }
11529 let mean = sum * n_inv;
11530 let mut var = 0f32;
11531 for &v in xr {
11532 let d = v - mean;
11533 var += d * d;
11534 }
11535 let inv_std = 1.0 / (var * n_inv + eps).sqrt();
11536
11537 let mut s_sy = 0f32;
11540 let mut s_sxh = 0f32;
11541 for d in 0..h {
11542 let xh = (xr[d] - mean) * inv_std;
11543 let sy = dyr[d] * g[d];
11544 s_sy += sy;
11545 s_sxh += sy * xh;
11546 }
11547 let m_sy = s_sy * n_inv;
11548 let m_sxh = s_sxh * n_inv;
11549
11550 for d in 0..h {
11551 let xh = (xr[d] - mean) * inv_std;
11552 let sy = dyr[d] * g[d];
11553 out[r * h + d] = inv_std * (sy - m_sy - xh * m_sxh);
11554 }
11555 }
11556 }
11557 }
11558
11559 Thunk::LayerNormBackwardGamma {
11560 x,
11561 dy,
11562 dgamma,
11563 rows,
11564 h,
11565 eps,
11566 } => {
11567 let rows = *rows as usize;
11568 let h = *h as usize;
11569 let eps = *eps;
11570 unsafe {
11571 let xs = sl(*x, base, rows * h);
11572 let dys = sl(*dy, base, rows * h);
11573 let out = sl_mut(*dgamma, base, h);
11574 for v in out.iter_mut() {
11575 *v = 0.0;
11576 }
11577 let n_inv = 1.0 / h as f32;
11578 for r in 0..rows {
11579 let xr = &xs[r * h..(r + 1) * h];
11580 let dyr = &dys[r * h..(r + 1) * h];
11581 let mut sum = 0f32;
11582 for &v in xr {
11583 sum += v;
11584 }
11585 let mean = sum * n_inv;
11586 let mut var = 0f32;
11587 for &v in xr {
11588 let d = v - mean;
11589 var += d * d;
11590 }
11591 let inv_std = 1.0 / (var * n_inv + eps).sqrt();
11592 for d in 0..h {
11593 let xh = (xr[d] - mean) * inv_std;
11594 out[d] += dyr[d] * xh;
11595 }
11596 }
11597 }
11598 }
11599
11600 Thunk::RmsNormBackwardInput {
11601 x,
11602 gamma,
11603 beta,
11604 dy,
11605 dx,
11606 rows,
11607 h,
11608 eps,
11609 } => {
11610 let (rows, h) = (*rows as usize, *h as usize);
11611 unsafe {
11612 let xs = sl(*x, base, rows * h);
11613 let g = sl(*gamma, base, h);
11614 let b = sl(*beta, base, h);
11615 let dys = sl(*dy, base, rows * h);
11616 let out = sl_mut(*dx, base, rows * h);
11617 let mut dg = vec![0f32; h];
11618 let mut db = vec![0f32; h];
11619 for r in 0..rows {
11620 crate::training_bwd::rms_norm_backward_row(
11621 &xs[r * h..(r + 1) * h],
11622 g,
11623 b,
11624 &dys[r * h..(r + 1) * h],
11625 &mut out[r * h..(r + 1) * h],
11626 &mut dg,
11627 &mut db,
11628 *eps,
11629 );
11630 }
11631 }
11632 }
11633
11634 Thunk::RmsNormBackwardGamma {
11635 x,
11636 gamma,
11637 beta,
11638 dy,
11639 dgamma,
11640 rows,
11641 h,
11642 eps,
11643 } => {
11644 let (rows, h) = (*rows as usize, *h as usize);
11645 unsafe {
11646 let xs = sl(*x, base, rows * h);
11647 let g = sl(*gamma, base, h);
11648 let b = sl(*beta, base, h);
11649 let dys = sl(*dy, base, rows * h);
11650 let out = sl_mut(*dgamma, base, h);
11651 for v in out.iter_mut() {
11652 *v = 0.0;
11653 }
11654 let mut dx = vec![0f32; h];
11655 let mut db = vec![0f32; h];
11656 for r in 0..rows {
11657 crate::training_bwd::rms_norm_backward_row(
11658 &xs[r * h..(r + 1) * h],
11659 g,
11660 b,
11661 &dys[r * h..(r + 1) * h],
11662 &mut dx,
11663 &mut *out,
11664 &mut db,
11665 *eps,
11666 );
11667 }
11668 }
11669 }
11670
11671 Thunk::RmsNormBackwardBeta {
11672 x,
11673 gamma,
11674 beta,
11675 dy,
11676 dbeta,
11677 rows,
11678 h,
11679 eps,
11680 } => {
11681 let (rows, h) = (*rows as usize, *h as usize);
11682 unsafe {
11683 let xs = sl(*x, base, rows * h);
11684 let g = sl(*gamma, base, h);
11685 let b = sl(*beta, base, h);
11686 let dys = sl(*dy, base, rows * h);
11687 let out = sl_mut(*dbeta, base, h);
11688 for v in out.iter_mut() {
11689 *v = 0.0;
11690 }
11691 let mut dx = vec![0f32; h];
11692 let mut dg = vec![0f32; h];
11693 for r in 0..rows {
11694 crate::training_bwd::rms_norm_backward_row(
11695 &xs[r * h..(r + 1) * h],
11696 g,
11697 b,
11698 &dys[r * h..(r + 1) * h],
11699 &mut dx,
11700 &mut dg,
11701 &mut *out,
11702 *eps,
11703 );
11704 }
11705 }
11706 }
11707
11708 Thunk::RopeBackward {
11709 dy,
11710 cos,
11711 sin,
11712 dx,
11713 batch,
11714 seq,
11715 hidden,
11716 head_dim,
11717 n_rot,
11718 cos_len,
11719 } => {
11720 let (b, s, hs, dh, nr, cl) = (
11721 *batch as usize,
11722 *seq as usize,
11723 *hidden as usize,
11724 *head_dim as usize,
11725 *n_rot as usize,
11726 *cos_len as usize,
11727 );
11728 let nh = hs / dh;
11729 let tab_half = dh / 2;
11730 unsafe {
11731 let dys = sl(*dy, base, b * s * hs);
11732 let cos_tab = sl(*cos, base, cl);
11733 let sin_tab = sl(*sin, base, cl);
11734 let out = sl_mut(*dx, base, b * s * hs);
11735 for bi in 0..b {
11736 for si in 0..s {
11737 let tab_off = si.saturating_mul(tab_half) % cl.max(1);
11738 let cp = &cos_tab[tab_off..tab_off + tab_half.min(cl)];
11739 let sp = &sin_tab[tab_off..tab_off + tab_half.min(cl)];
11740 for hi in 0..nh {
11741 let base_idx = bi * s * hs + si * hs + hi * dh;
11742 crate::training_bwd::rope_backward_row(
11743 &dys[base_idx..base_idx + dh],
11744 cp,
11745 sp,
11746 &mut out[base_idx..base_idx + dh],
11747 dh,
11748 nr,
11749 );
11750 }
11751 }
11752 }
11753 }
11754 }
11755
11756 Thunk::CumsumBackward {
11757 dy,
11758 dx,
11759 rows,
11760 cols,
11761 exclusive,
11762 } => {
11763 let (rows, cols) = (*rows as usize, *cols as usize);
11764 unsafe {
11765 let dys = sl(*dy, base, rows * cols);
11766 let out = sl_mut(*dx, base, rows * cols);
11767 for r in 0..rows {
11768 crate::training_bwd::cumsum_backward_row(
11769 &dys[r * cols..(r + 1) * cols],
11770 &mut out[r * cols..(r + 1) * cols],
11771 *exclusive,
11772 );
11773 }
11774 }
11775 }
11776
11777 Thunk::GroupNormBackwardInput {
11778 x,
11779 gamma,
11780 beta: _beta,
11781 dy,
11782 dx,
11783 n,
11784 c,
11785 h,
11786 w,
11787 num_groups,
11788 eps,
11789 } => {
11790 let (n, c, h, w) = (*n as usize, *c as usize, *h as usize, *w as usize);
11791 let plane = c * h * w;
11792 unsafe {
11793 let xs = sl(*x, base, n * plane);
11794 let g = sl(*gamma, base, c);
11795 let dys = sl(*dy, base, n * plane);
11796 let out = sl_mut(*dx, base, n * plane);
11797 crate::training_bwd::group_norm_backward_input_nchw(
11798 xs,
11799 g,
11800 dys,
11801 out,
11802 n,
11803 c,
11804 h,
11805 w,
11806 *num_groups as usize,
11807 *eps,
11808 );
11809 }
11810 }
11811
11812 Thunk::GroupNormBackwardGamma {
11813 x,
11814 dy,
11815 dgamma,
11816 n,
11817 c,
11818 h,
11819 w,
11820 num_groups,
11821 eps,
11822 } => {
11823 let (n, c, h, w) = (*n as usize, *c as usize, *h as usize, *w as usize);
11824 let plane = c * h * w;
11825 unsafe {
11826 let xs = sl(*x, base, n * plane);
11827 let dys = sl(*dy, base, n * plane);
11828 let out = sl_mut(*dgamma, base, c);
11829 crate::training_bwd::group_norm_backward_gamma_nchw(
11830 xs,
11831 dys,
11832 out,
11833 n,
11834 c,
11835 h,
11836 w,
11837 *num_groups as usize,
11838 *eps,
11839 );
11840 }
11841 }
11842
11843 Thunk::GroupNormBackwardBeta {
11844 dy,
11845 dbeta,
11846 n,
11847 c,
11848 h,
11849 w,
11850 } => {
11851 let (n, c, h, w) = (*n as usize, *c as usize, *h as usize, *w as usize);
11852 let plane = c * h * w;
11853 unsafe {
11854 let dys = sl(*dy, base, n * plane);
11855 let out = sl_mut(*dbeta, base, c);
11856 crate::training_bwd::group_norm_backward_beta_nchw(dys, out, n, c, h, w);
11857 }
11858 }
11859
11860 Thunk::GatherBackward {
11861 dy,
11862 indices,
11863 dst,
11864 outer,
11865 axis_dim,
11866 num_idx,
11867 trailing,
11868 } => {
11869 let (outer, axis_dim, num_idx, trailing) = (
11870 *outer as usize,
11871 *axis_dim as usize,
11872 *num_idx as usize,
11873 *trailing as usize,
11874 );
11875 unsafe {
11876 let dys = sl(*dy, base, outer * num_idx * trailing);
11877 let ids = sl(*indices, base, num_idx);
11878 let out = sl_mut(*dst, base, outer * axis_dim * trailing);
11879 for v in out.iter_mut() {
11880 *v = 0.0;
11881 }
11882 crate::training_bwd::gather_axis_backward(
11883 dys, ids, out, outer, axis_dim, num_idx, trailing,
11884 );
11885 }
11886 }
11887
11888 Thunk::MaxPool2dBackward {
11889 x,
11890 dy,
11891 dx,
11892 n,
11893 c,
11894 h,
11895 w,
11896 h_out,
11897 w_out,
11898 kh,
11899 kw,
11900 sh,
11901 sw,
11902 ph,
11903 pw,
11904 } => {
11905 let n = *n as usize;
11906 let c = *c as usize;
11907 let h = *h as usize;
11908 let w = *w as usize;
11909 let h_out = *h_out as usize;
11910 let w_out = *w_out as usize;
11911 let kh = *kh as usize;
11912 let kw = *kw as usize;
11913 let sh = *sh as usize;
11914 let sw = *sw as usize;
11915 let ph = *ph as usize;
11916 let pw = *pw as usize;
11917 unsafe {
11918 let xs = sl(*x, base, n * c * h * w);
11919 let dys = sl(*dy, base, n * c * h_out * w_out);
11920 let dxs = sl_mut(*dx, base, n * c * h * w);
11921 for v in dxs.iter_mut() {
11924 *v = 0.0;
11925 }
11926 for ni in 0..n {
11927 for ci in 0..c {
11928 let in_chan = (ni * c + ci) * h * w;
11929 let out_chan = (ni * c + ci) * h_out * w_out;
11930 for ho in 0..h_out {
11931 for wo in 0..w_out {
11932 let mut best_v = f32::NEG_INFINITY;
11934 let mut best_idx: Option<usize> = None;
11935 for ki in 0..kh {
11936 for kj in 0..kw {
11937 let hi = ho * sh + ki;
11938 let wi = wo * sw + kj;
11939 if hi < ph || wi < pw {
11940 continue;
11941 }
11942 let hi = hi - ph;
11943 let wi = wi - pw;
11944 if hi >= h || wi >= w {
11945 continue;
11946 }
11947 let idx = in_chan + hi * w + wi;
11948 let v = xs[idx];
11949 if v > best_v {
11953 best_v = v;
11954 best_idx = Some(idx);
11955 }
11956 }
11957 }
11958 if let Some(idx) = best_idx {
11959 dxs[idx] += dys[out_chan + ho * w_out + wo];
11960 }
11961 }
11962 }
11963 }
11964 }
11965 }
11966 }
11967
11968 Thunk::Conv2dBackwardInput {
11969 dy,
11970 w,
11971 dx,
11972 n,
11973 c_in,
11974 h,
11975 w_in,
11976 c_out,
11977 h_out,
11978 w_out,
11979 kh,
11980 kw,
11981 sh,
11982 sw,
11983 ph,
11984 pw,
11985 dh,
11986 dw,
11987 groups,
11988 } => {
11989 let n = *n as usize;
12001 let c_in = *c_in as usize;
12002 let h = *h as usize;
12003 let w_in = *w_in as usize;
12004 let c_out = *c_out as usize;
12005 let h_out = *h_out as usize;
12006 let w_out = *w_out as usize;
12007 let kh = *kh as usize;
12008 let kw = *kw as usize;
12009 let sh = *sh as usize;
12010 let sw = *sw as usize;
12011 let ph = *ph as usize;
12012 let pw = *pw as usize;
12013 let dh = *dh as usize;
12014 let dw = *dw as usize;
12015 let groups = *groups as usize;
12016 let c_in_per_g = c_in / groups;
12017 let c_out_per_g = c_out / groups;
12018
12019 let m_dim = c_in_per_g * kh * kw;
12020 let n_dim = h_out * w_out;
12021 let k_dim = c_out_per_g;
12022
12023 let dy_stride_n = c_out * h_out * w_out;
12024 let dy_stride_g = c_out_per_g * h_out * w_out;
12025 let w_stride_g = c_out_per_g * c_in_per_g * kh * kw;
12026 let dx_stride_n = c_in * h * w_in;
12027 let dx_stride_g = c_in_per_g * h * w_in;
12028
12029 unsafe {
12030 let dys = sl(*dy, base, n * c_out * h_out * w_out);
12031 let ws = sl(*w, base, c_out * c_in_per_g * kh * kw);
12032 let dxs = sl_mut(*dx, base, n * c_in * h * w_in);
12033 for v in dxs.iter_mut() {
12034 *v = 0.0;
12035 }
12036
12037 let mut dcol = vec![0f32; m_dim * n_dim];
12039
12040 for ni in 0..n {
12041 for g in 0..groups {
12042 let w_g_off = g * w_stride_g;
12043 let dy_n_g_off = ni * dy_stride_n + g * dy_stride_g;
12044 let dx_n_g_off = ni * dx_stride_n + g * dx_stride_g;
12045
12046 crate::blas::sgemm_general(
12051 ws.as_ptr().add(w_g_off),
12052 dys.as_ptr().add(dy_n_g_off),
12053 dcol.as_mut_ptr(),
12054 m_dim,
12055 n_dim,
12056 k_dim,
12057 1.0,
12058 0.0,
12059 m_dim,
12060 n_dim,
12061 n_dim,
12062 true,
12063 false,
12064 );
12065
12066 col2im(
12068 &dcol,
12069 &mut dxs[dx_n_g_off..dx_n_g_off + dx_stride_g],
12070 c_in_per_g,
12071 h,
12072 w_in,
12073 h_out,
12074 w_out,
12075 kh,
12076 kw,
12077 sh,
12078 sw,
12079 ph,
12080 pw,
12081 dh,
12082 dw,
12083 );
12084 }
12085 }
12086 }
12087 }
12088
12089 Thunk::Conv2dBackwardWeight {
12090 x,
12091 dy,
12092 dw,
12093 n,
12094 c_in,
12095 h,
12096 w,
12097 c_out,
12098 h_out,
12099 w_out,
12100 kh,
12101 kw,
12102 sh,
12103 sw,
12104 ph,
12105 pw,
12106 dh,
12107 dw_dil,
12108 groups,
12109 } => {
12110 let n = *n as usize;
12111 let c_in = *c_in as usize;
12112 let h = *h as usize;
12113 let w = *w as usize;
12114 let c_out = *c_out as usize;
12125 let h_out = *h_out as usize;
12126 let w_out = *w_out as usize;
12127 let kh = *kh as usize;
12128 let kw = *kw as usize;
12129 let sh = *sh as usize;
12130 let sw = *sw as usize;
12131 let ph = *ph as usize;
12132 let pw = *pw as usize;
12133 let dh = *dh as usize;
12134 let dw_dil = *dw_dil as usize;
12135 let groups = *groups as usize;
12136 let c_in_per_g = c_in / groups;
12137 let c_out_per_g = c_out / groups;
12138
12139 let m_dim = c_out_per_g;
12140 let n_dim = c_in_per_g * kh * kw;
12141 let k_dim = h_out * w_out;
12142
12143 let x_stride_n = c_in * h * w;
12144 let x_stride_g = c_in_per_g * h * w;
12145 let dy_stride_n = c_out * h_out * w_out;
12146 let dy_stride_g = c_out_per_g * h_out * w_out;
12147 let dw_stride_g = c_out_per_g * c_in_per_g * kh * kw;
12148
12149 unsafe {
12150 let xs = sl(*x, base, n * c_in * h * w);
12151 let dys = sl(*dy, base, n * c_out * h_out * w_out);
12152 let dws = sl_mut(*dw, base, c_out * c_in_per_g * kh * kw);
12153 for v in dws.iter_mut() {
12154 *v = 0.0;
12155 }
12156
12157 let mut col = vec![0f32; n_dim * k_dim];
12158
12159 for ni in 0..n {
12160 for g in 0..groups {
12161 let x_n_g_off = ni * x_stride_n + g * x_stride_g;
12162 im2col(
12163 &xs[x_n_g_off..x_n_g_off + x_stride_g],
12164 &mut col,
12165 c_in_per_g,
12166 h,
12167 w,
12168 h_out,
12169 w_out,
12170 kh,
12171 kw,
12172 sh,
12173 sw,
12174 ph,
12175 pw,
12176 dh,
12177 dw_dil,
12178 );
12179
12180 let dy_n_g_off = ni * dy_stride_n + g * dy_stride_g;
12181 let dw_g_off = g * dw_stride_g;
12182
12183 crate::blas::sgemm_general(
12191 dys.as_ptr().add(dy_n_g_off),
12192 col.as_ptr(),
12193 dws.as_mut_ptr().add(dw_g_off),
12194 m_dim,
12195 n_dim,
12196 k_dim,
12197 1.0,
12198 1.0,
12199 k_dim,
12200 k_dim,
12201 n_dim,
12202 false,
12203 true,
12204 );
12205 }
12206 }
12207 }
12208 }
12209
12210 Thunk::SoftmaxCrossEntropy {
12211 logits,
12212 labels,
12213 dst,
12214 n,
12215 c,
12216 } => {
12217 let n = *n as usize;
12218 let c = *c as usize;
12219 unsafe {
12220 let lg = sl(*logits, base, n * c);
12221 let lb = sl(*labels, base, n);
12222 let out = sl_mut(*dst, base, n);
12223 for ni in 0..n {
12224 let row = &lg[ni * c..(ni + 1) * c];
12225 let mut m = f32::NEG_INFINITY;
12227 for &v in row {
12228 if v > m {
12229 m = v;
12230 }
12231 }
12232 let mut sum = 0f32;
12233 for &v in row {
12234 sum += (v - m).exp();
12235 }
12236 let lse = m + sum.ln();
12237 let label_idx = lb[ni] as usize;
12238 out[ni] = lse - row[label_idx];
12240 }
12241 }
12242 }
12243
12244 Thunk::SoftmaxCrossEntropyBackward {
12245 logits,
12246 labels,
12247 d_loss,
12248 dlogits,
12249 n,
12250 c,
12251 } => {
12252 let n = *n as usize;
12253 let c = *c as usize;
12254 unsafe {
12255 let lg = sl(*logits, base, n * c);
12256 let lb = sl(*labels, base, n);
12257 let dl = sl(*d_loss, base, n);
12258 let out = sl_mut(*dlogits, base, n * c);
12259 for ni in 0..n {
12260 let row = &lg[ni * c..(ni + 1) * c];
12261 let label_idx = lb[ni] as usize;
12262 let scale = dl[ni];
12263 let mut m = f32::NEG_INFINITY;
12264 for &v in row {
12265 if v > m {
12266 m = v;
12267 }
12268 }
12269 let mut sum = 0f32;
12270 for &v in row {
12271 sum += (v - m).exp();
12272 }
12273 let inv_sum = 1.0 / sum;
12274 let dst_row = &mut out[ni * c..(ni + 1) * c];
12275 for k in 0..c {
12276 let p = (row[k] - m).exp() * inv_sum;
12277 let one_hot = if k == label_idx { 1.0 } else { 0.0 };
12278 dst_row[k] = (p - one_hot) * scale;
12279 }
12280 }
12281 }
12282 }
12283
12284 Thunk::GatherAxis {
12285 table,
12286 idx,
12287 dst,
12288 outer,
12289 axis_dim,
12290 num_idx,
12291 trailing,
12292 } => {
12293 let outer = *outer as usize;
12294 let axis_dim = *axis_dim as usize;
12295 let num_idx = *num_idx as usize;
12296 let trailing = *trailing as usize;
12297 unsafe {
12298 let tab = sl(*table, base, outer * axis_dim * trailing);
12299 let ids = sl(*idx, base, num_idx);
12300 let out = sl_mut(*dst, base, outer * num_idx * trailing);
12301 for o in 0..outer {
12302 let tab_outer = o * axis_dim * trailing;
12303 let out_outer = o * num_idx * trailing;
12304 for k in 0..num_idx {
12305 let row = ids[k] as usize;
12306 let tab_row = tab_outer + row * trailing;
12307 let out_row = out_outer + k * trailing;
12308 out[out_row..out_row + trailing]
12309 .copy_from_slice(&tab[tab_row..tab_row + trailing]);
12310 }
12311 }
12312 }
12313 }
12314
12315 Thunk::Transpose {
12316 src,
12317 dst,
12318 in_total,
12319 out_dims,
12320 in_strides,
12321 } => {
12322 let rank = out_dims.len();
12327 let total: usize = out_dims.iter().map(|&d| d as usize).product();
12328 let in_total = *in_total as usize;
12329 unsafe {
12330 let inp = sl(*src, base, in_total);
12331 let out = sl_mut(*dst, base, total);
12332 let mut idx = vec![0usize; rank];
12333 for o in 0..total {
12334 let mut src_idx = 0usize;
12335 for d in 0..rank {
12336 src_idx += idx[d] * in_strides[d] as usize;
12337 }
12338 out[o] = inp[src_idx];
12339 for d in (0..rank).rev() {
12341 idx[d] += 1;
12342 if idx[d] < out_dims[d] as usize {
12343 break;
12344 }
12345 idx[d] = 0;
12346 }
12347 }
12348 }
12349 }
12350
12351 Thunk::CustomOp {
12357 kernel,
12358 inputs,
12359 output,
12360 attrs,
12361 } => {
12362 let (out_off, out_len, out_shape) = output;
12363 unsafe {
12364 dispatch_custom_op(
12365 &**kernel, inputs, *out_off, *out_len, out_shape, attrs, base,
12366 );
12367 }
12368 }
12369 }
12370 }
12371}
12372
12373#[allow(clippy::too_many_arguments)]
12388unsafe fn griewank_process_segment(
12389 t_lo: usize,
12390 t_hi: usize,
12391 anchor_carry: &[u8],
12392 cb: usize,
12393 fwd_sched: &ThunkSchedule,
12394 fwd_init: &[u8],
12395 fwd_carry_in_off: usize,
12396 fwd_output_off: usize,
12397 fwd_x_offs: &[usize],
12398 base: *mut u8,
12399 outer_xs_offs: &[(usize, u32)],
12400 fwd_buf: &mut Vec<u8>,
12401 leaf_threshold: usize,
12402 process_iter: &mut dyn FnMut(usize, &[u8]),
12403) {
12404 unsafe {
12405 let size = t_hi - t_lo + 1;
12406 if size == 1 {
12407 process_iter(t_lo, anchor_carry);
12408 return;
12409 }
12410 if size <= leaf_threshold {
12411 let mut cache: Vec<u8> = Vec::with_capacity(size * cb);
12413 cache.extend_from_slice(anchor_carry);
12414 fwd_buf.copy_from_slice(fwd_init);
12415 std::ptr::copy_nonoverlapping(
12416 anchor_carry.as_ptr(),
12417 fwd_buf.as_mut_ptr().add(fwd_carry_in_off),
12418 cb,
12419 );
12420 for i in 1..size {
12421 let cur_iter = t_lo + i - 1;
12422 for (idx, fb_x_off) in fwd_x_offs.iter().enumerate() {
12423 let (outer_xs_off, x_psb) = outer_xs_offs[idx];
12424 let xb = x_psb as usize;
12425 std::ptr::copy_nonoverlapping(
12426 base.add(outer_xs_off + cur_iter * xb),
12427 fwd_buf.as_mut_ptr().add(*fb_x_off),
12428 xb,
12429 );
12430 }
12431 execute_thunks(fwd_sched, fwd_buf);
12432 if fwd_output_off != fwd_carry_in_off {
12433 fwd_buf.copy_within(fwd_output_off..fwd_output_off + cb, fwd_carry_in_off);
12434 }
12435 cache.extend_from_slice(&fwd_buf[fwd_carry_in_off..fwd_carry_in_off + cb]);
12436 }
12437 for t in (t_lo..=t_hi).rev() {
12439 let idx = t - t_lo;
12440 let carry = &cache[idx * cb..(idx + 1) * cb];
12441 process_iter(t, carry);
12442 }
12443 return;
12444 }
12445
12446 let mid = t_lo + size / 2;
12450 fwd_buf.copy_from_slice(fwd_init);
12451 std::ptr::copy_nonoverlapping(
12452 anchor_carry.as_ptr(),
12453 fwd_buf.as_mut_ptr().add(fwd_carry_in_off),
12454 cb,
12455 );
12456 for cur_iter in t_lo..mid {
12457 for (idx, fb_x_off) in fwd_x_offs.iter().enumerate() {
12458 let (outer_xs_off, x_psb) = outer_xs_offs[idx];
12459 let xb = x_psb as usize;
12460 std::ptr::copy_nonoverlapping(
12461 base.add(outer_xs_off + cur_iter * xb),
12462 fwd_buf.as_mut_ptr().add(*fb_x_off),
12463 xb,
12464 );
12465 }
12466 execute_thunks(fwd_sched, fwd_buf);
12467 if fwd_output_off != fwd_carry_in_off {
12468 fwd_buf.copy_within(fwd_output_off..fwd_output_off + cb, fwd_carry_in_off);
12469 }
12470 }
12471 let mid_carry: Vec<u8> = fwd_buf[fwd_carry_in_off..fwd_carry_in_off + cb].to_vec();
12472
12473 griewank_process_segment(
12477 mid,
12478 t_hi,
12479 &mid_carry,
12480 cb,
12481 fwd_sched,
12482 fwd_init,
12483 fwd_carry_in_off,
12484 fwd_output_off,
12485 fwd_x_offs,
12486 base,
12487 outer_xs_offs,
12488 fwd_buf,
12489 leaf_threshold,
12490 process_iter,
12491 );
12492 griewank_process_segment(
12494 t_lo,
12495 mid - 1,
12496 anchor_carry,
12497 cb,
12498 fwd_sched,
12499 fwd_init,
12500 fwd_carry_in_off,
12501 fwd_output_off,
12502 fwd_x_offs,
12503 base,
12504 outer_xs_offs,
12505 fwd_buf,
12506 leaf_threshold,
12507 process_iter,
12508 );
12509 }
12510}
12511
12512pub unsafe fn execute_fft1d_f64(
12529 src: usize,
12530 dst: usize,
12531 outer: usize,
12532 n_complex: usize,
12533 inverse: bool,
12534 norm_tag: u32,
12535 base: *mut u8,
12536) {
12537 let row_elems = 2 * n_complex;
12538 let mut re = vec![0f64; n_complex];
12539 let mut im = vec![0f64; n_complex];
12540 let norm = rlx_ir::fft::FftNorm::from_tag(norm_tag);
12541 let scale = norm.output_scale(n_complex, inverse);
12542 let mut scratch = if n_complex.is_power_of_two() || n_complex <= 16 {
12545 BluesteinScratchF64::empty()
12546 } else {
12547 BluesteinScratchF64::build(n_complex, inverse)
12548 };
12549 for o in 0..outer {
12550 let row_offset = src + o * row_elems * std::mem::size_of::<f64>();
12551 let s = unsafe { sl_f64(row_offset, base, row_elems) };
12552 re.copy_from_slice(&s[..n_complex]);
12553 im.copy_from_slice(&s[n_complex..]);
12554 if n_complex.is_power_of_two() {
12555 fft_radix2_inplace_f64(&mut re, &mut im, inverse);
12556 } else if n_complex <= 16 {
12557 fft_naive_inplace_f64(&mut re, &mut im, inverse);
12558 } else {
12559 fft_bluestein_inplace_f64(&mut re, &mut im, inverse, &mut scratch);
12560 }
12561 if scale != 1.0 {
12562 re.iter_mut().for_each(|v| *v *= scale);
12563 im.iter_mut().for_each(|v| *v *= scale);
12564 }
12565 let dst_offset = dst + o * row_elems * std::mem::size_of::<f64>();
12566 let d = unsafe { sl_mut_f64(dst_offset, base, row_elems) };
12567 d[..n_complex].copy_from_slice(&re);
12568 d[n_complex..].copy_from_slice(&im);
12569 }
12570}
12571
12572pub unsafe fn execute_gated_delta_net_f32(
12581 q: usize,
12582 k: usize,
12583 v: usize,
12584 g: usize,
12585 beta: usize,
12586 state: usize,
12587 dst: usize,
12588 batch: usize,
12589 seq: usize,
12590 heads: usize,
12591 state_size: usize,
12592 base: *mut u8,
12593) {
12594 use rayon::prelude::*;
12595
12596 #[derive(Copy, Clone)]
12597 struct ArenaPtr(usize);
12598 unsafe impl Send for ArenaPtr {}
12599 unsafe impl Sync for ArenaPtr {}
12600 impl ArenaPtr {
12601 #[inline]
12602 fn get(self) -> *mut u8 {
12603 self.0 as *mut u8
12604 }
12605 }
12606
12607 unsafe {
12608 let arena = ArenaPtr(base as usize);
12609 let (b, s, h, n) = (batch, seq, heads, state_size);
12610 let scale = 1.0f32 / (n as f32).sqrt();
12611 let use_external = state != 0;
12612 let mut owned_state = vec![0f32; h * n * n];
12613
12614 crate::pool::num_threads();
12615
12616 assert!(
12617 n <= crate::gdn::GDN_MAX_STATE,
12618 "GatedDeltaNet state_size={n} exceeds stack scratch ({})",
12619 crate::gdn::GDN_MAX_STATE
12620 );
12621
12622 let qs = sl(q, arena.get(), b * s * h * n);
12623 let ks = sl(k, arena.get(), b * s * h * n);
12624 let vs = sl(v, arena.get(), b * s * h * n);
12625 let gs = sl(g, arena.get(), b * s * h);
12626 let betas = sl(beta, arena.get(), b * s * h);
12627 let _out = sl_mut(dst, arena.get(), b * s * h * n);
12628 let hs_n = h * n;
12629
12630 let run_head = |bi: usize, hi: usize, s_mat: &mut [f32], sk: &mut [f32]| {
12631 for ti in 0..s {
12632 let qkv_step = bi * s * hs_n + ti * hs_n + hi * n;
12633 let gb_step = bi * s * h + ti * h + hi;
12634 let out_row = sl_mut(dst + qkv_step * std::mem::size_of::<f32>(), arena.get(), n);
12635 crate::gdn::gdn_step_blas(
12636 s_mat,
12637 &qs[qkv_step..qkv_step + n],
12638 &ks[qkv_step..qkv_step + n],
12639 &vs[qkv_step..qkv_step + n],
12640 gs[gb_step],
12641 betas[gb_step],
12642 out_row,
12643 sk,
12644 n,
12645 scale,
12646 );
12647 }
12648 };
12649
12650 if !use_external && s > 1 {
12653 for bi in 0..b {
12654 (0..h).into_par_iter().for_each(|hi| {
12655 let mut sk_buf = [0f32; crate::gdn::GDN_MAX_STATE];
12656 let sk = &mut sk_buf[..n];
12657 let mut local_state =
12658 [0f32; crate::gdn::GDN_MAX_STATE * crate::gdn::GDN_MAX_STATE];
12659 let s_mat = &mut local_state[..n * n];
12660 s_mat.fill(0.0);
12661 run_head(bi, hi, s_mat, sk);
12662 });
12663 }
12664 return;
12665 }
12666
12667 if use_external {
12668 let state_bytes = state;
12669 (0..b * h).into_par_iter().for_each(|bhi| {
12670 let bi = bhi / h;
12671 let hi = bhi % h;
12672 let elem_off = bi * h * n * n + hi * n * n;
12673 let s_mat = sl_mut(
12674 state_bytes + elem_off * std::mem::size_of::<f32>(),
12675 arena.get(),
12676 n * n,
12677 );
12678 let mut sk_buf = [0f32; crate::gdn::GDN_MAX_STATE];
12679 run_head(bi, hi, s_mat, &mut sk_buf[..n]);
12680 });
12681 } else {
12682 for bi in 0..b {
12683 owned_state.fill(0.0);
12684 owned_state
12685 .par_chunks_mut(n * n)
12686 .enumerate()
12687 .for_each(|(hi, s_mat)| {
12688 let mut sk_buf = [0f32; crate::gdn::GDN_MAX_STATE];
12689 run_head(bi, hi, s_mat, &mut sk_buf[..n]);
12690 });
12691 }
12692 }
12693 }
12694}
12695
12696pub unsafe fn execute_rms_norm_backward_input_f32(
12698 x: usize,
12699 gamma: usize,
12700 beta: usize,
12701 dy: usize,
12702 dx: usize,
12703 rows: u32,
12704 h: u32,
12705 eps: f32,
12706 base: *mut u8,
12707) {
12708 let (rows, h) = (rows as usize, h as usize);
12709 let mut dg = vec![0f32; h];
12710 let mut db = vec![0f32; h];
12711 let xs = sl(x, base, rows * h);
12712 let dys = sl(dy, base, rows * h);
12713 let g = sl(gamma, base, h);
12714 let b = sl(beta, base, h);
12715 let out = sl_mut(dx, base, rows * h);
12716 for r in 0..rows {
12717 crate::training_bwd::rms_norm_backward_row(
12718 &xs[r * h..(r + 1) * h],
12719 g,
12720 b,
12721 &dys[r * h..(r + 1) * h],
12722 &mut out[r * h..(r + 1) * h],
12723 &mut dg,
12724 &mut db,
12725 eps,
12726 );
12727 }
12728}
12729
12730pub unsafe fn execute_rms_norm_backward_gamma_f32(
12731 x: usize,
12732 gamma: usize,
12733 beta: usize,
12734 dy: usize,
12735 dgamma: usize,
12736 rows: u32,
12737 h: u32,
12738 eps: f32,
12739 base: *mut u8,
12740) {
12741 let (rows, h) = (rows as usize, h as usize);
12742 let out = sl_mut(dgamma, base, h);
12743 out.fill(0.0);
12744 let mut dx = vec![0f32; h];
12745 let mut db = vec![0f32; h];
12746 let xs = sl(x, base, rows * h);
12747 let dys = sl(dy, base, rows * h);
12748 let g = sl(gamma, base, h);
12749 let b = sl(beta, base, h);
12750 for r in 0..rows {
12751 crate::training_bwd::rms_norm_backward_row(
12752 &xs[r * h..(r + 1) * h],
12753 g,
12754 b,
12755 &dys[r * h..(r + 1) * h],
12756 &mut dx,
12757 out,
12758 &mut db,
12759 eps,
12760 );
12761 }
12762}
12763
12764pub unsafe fn execute_rms_norm_backward_beta_f32(
12765 x: usize,
12766 gamma: usize,
12767 beta: usize,
12768 dy: usize,
12769 dbeta: usize,
12770 rows: u32,
12771 h: u32,
12772 eps: f32,
12773 base: *mut u8,
12774) {
12775 let (rows, h) = (rows as usize, h as usize);
12776 let out = sl_mut(dbeta, base, h);
12777 out.fill(0.0);
12778 let mut dx = vec![0f32; h];
12779 let mut dg = vec![0f32; h];
12780 let xs = sl(x, base, rows * h);
12781 let dys = sl(dy, base, rows * h);
12782 let g = sl(gamma, base, h);
12783 let b = sl(beta, base, h);
12784 for r in 0..rows {
12785 crate::training_bwd::rms_norm_backward_row(
12786 &xs[r * h..(r + 1) * h],
12787 g,
12788 b,
12789 &dys[r * h..(r + 1) * h],
12790 &mut dx,
12791 &mut dg,
12792 out,
12793 eps,
12794 );
12795 }
12796}
12797
12798pub unsafe fn execute_rope_backward_f32(
12799 dy: usize,
12800 cos: usize,
12801 sin: usize,
12802 dx: usize,
12803 batch: u32,
12804 seq: u32,
12805 hidden: u32,
12806 head_dim: u32,
12807 n_rot: u32,
12808 cos_len: u32,
12809 base: *mut u8,
12810) {
12811 let (b, s, hs, dh, nr, cl) = (
12812 batch as usize,
12813 seq as usize,
12814 hidden as usize,
12815 head_dim as usize,
12816 n_rot as usize,
12817 cos_len as usize,
12818 );
12819 let nh = hs / dh;
12820 let tab_half = dh / 2;
12821 let dys = sl(dy, base, b * s * hs);
12822 let cos_tab = sl(cos, base, cl);
12823 let sin_tab = sl(sin, base, cl);
12824 let out = sl_mut(dx, base, b * s * hs);
12825 for bi in 0..b {
12826 for si in 0..s {
12827 let tab_off = si.saturating_mul(tab_half) % cl.max(1);
12828 let cp = &cos_tab[tab_off..tab_off + tab_half.min(cl)];
12829 let sp = &sin_tab[tab_off..tab_off + tab_half.min(cl)];
12830 for hi in 0..nh {
12831 let base_idx = bi * s * hs + si * hs + hi * dh;
12832 crate::training_bwd::rope_backward_row(
12833 &dys[base_idx..base_idx + dh],
12834 cp,
12835 sp,
12836 &mut out[base_idx..base_idx + dh],
12837 dh,
12838 nr,
12839 );
12840 }
12841 }
12842 }
12843}
12844
12845pub unsafe fn execute_cumsum_backward_f32(
12846 dy: usize,
12847 dx: usize,
12848 rows: u32,
12849 cols: u32,
12850 exclusive: bool,
12851 base: *mut u8,
12852) {
12853 let (rows, cols) = (rows as usize, cols as usize);
12854 let dys = sl(dy, base, rows * cols);
12855 let out = sl_mut(dx, base, rows * cols);
12856 for r in 0..rows {
12857 crate::training_bwd::cumsum_backward_row(
12858 &dys[r * cols..(r + 1) * cols],
12859 &mut out[r * cols..(r + 1) * cols],
12860 exclusive,
12861 );
12862 }
12863}
12864
12865pub unsafe fn execute_gather_backward_f32(
12866 dy: usize,
12867 indices: usize,
12868 dst: usize,
12869 outer: u32,
12870 axis_dim: u32,
12871 num_idx: u32,
12872 trailing: u32,
12873 base: *mut u8,
12874) {
12875 let (outer, axis_dim, num_idx, trailing) = (
12876 outer as usize,
12877 axis_dim as usize,
12878 num_idx as usize,
12879 trailing as usize,
12880 );
12881 let out = sl_mut(dst, base, outer * axis_dim * trailing);
12882 out.fill(0.0);
12883 crate::training_bwd::gather_axis_backward(
12884 sl(dy, base, outer * num_idx * trailing),
12885 sl(indices, base, num_idx),
12886 out,
12887 outer,
12888 axis_dim,
12889 num_idx,
12890 trailing,
12891 );
12892}
12893
12894pub unsafe fn execute_dequant_matmul_gguf_f32(
12896 x: usize,
12897 w_q: usize,
12898 dst: usize,
12899 m: usize,
12900 k: usize,
12901 n: usize,
12902 scheme: rlx_ir::quant::QuantScheme,
12903 base: *mut u8,
12904) {
12905 unsafe {
12906 let block_bytes = scheme.gguf_block_bytes() as usize;
12907 let block_elems = scheme.gguf_block_size() as usize;
12908 let total_bytes = (k * n) / block_elems * block_bytes;
12909 let xs = sl(x, base, m * k);
12910 let w_bytes = std::slice::from_raw_parts(base.add(w_q) as *const u8, total_bytes);
12911 let out = sl_mut(dst, base, m * n);
12912 crate::gguf_matmul::gguf_matmul_bt(xs, w_bytes, out, m, k, n, scheme);
12913 }
12914}
12915
12916pub unsafe fn execute_dequant_grouped_matmul_gguf_f32(
12918 input: usize,
12919 w_q: usize,
12920 expert_idx: usize,
12921 dst: usize,
12922 m: usize,
12923 k: usize,
12924 n: usize,
12925 num_experts: usize,
12926 scheme: rlx_ir::quant::QuantScheme,
12927 base: *mut u8,
12928) {
12929 unsafe {
12930 let block_bytes = scheme.gguf_block_bytes() as usize;
12931 let block_elems = scheme.gguf_block_size() as usize;
12932 let slab_bytes = (k * n) / block_elems * block_bytes;
12933 let xs = sl(input, base, m * k);
12934 let w_bytes =
12935 std::slice::from_raw_parts(base.add(w_q) as *const u8, num_experts * slab_bytes);
12936 let ids = sl(expert_idx, base, m);
12937 let out = sl_mut(dst, base, m * n);
12938 crate::gguf_matmul::gguf_grouped_matmul_bt(
12939 xs,
12940 w_bytes,
12941 ids,
12942 out,
12943 m,
12944 k,
12945 n,
12946 num_experts,
12947 scheme,
12948 );
12949 }
12950}
12951
12952pub unsafe fn execute_dequant_matmul_int4_f32(
12954 x: usize,
12955 w_q: usize,
12956 scale: usize,
12957 zp: usize,
12958 dst: usize,
12959 m: usize,
12960 k: usize,
12961 n: usize,
12962 block_size: u32,
12963 is_asymmetric: bool,
12964 base: *mut u8,
12965) {
12966 let bs = block_size as usize;
12967 let n_blocks = k.div_ceil(bs);
12968 unsafe {
12969 let xs = sl(x, base, m * k);
12970 let w_bytes = std::slice::from_raw_parts(base.add(w_q) as *const u8, (k * n).div_ceil(2));
12971 let scales = sl(scale, base, n_blocks * n);
12972 let zps = if is_asymmetric {
12973 sl(zp, base, n_blocks * n)
12974 } else {
12975 &[][..]
12976 };
12977 let out = sl_mut(dst, base, m * n);
12978 dequant_matmul_int4(xs, w_bytes, scales, zps, out, m, k, n, bs, is_asymmetric);
12979 }
12980}
12981
12982pub unsafe fn execute_dequant_matmul_fp8_f32(
12984 x: usize,
12985 w_q: usize,
12986 scale: usize,
12987 dst: usize,
12988 m: usize,
12989 k: usize,
12990 n: usize,
12991 e5m2: bool,
12992 base: *mut u8,
12993) {
12994 unsafe {
12995 let xs = sl(x, base, m * k);
12996 let w_bytes = std::slice::from_raw_parts(base.add(w_q) as *const u8, k * n);
12997 let scales = sl(scale, base, n);
12998 let out = sl_mut(dst, base, m * n);
12999 dequant_matmul_fp8(xs, w_bytes, scales, out, m, k, n, e5m2);
13000 }
13001}
13002
13003pub unsafe fn execute_dequant_matmul_nvfp4_f32(
13005 x: usize,
13006 w_q: usize,
13007 scale: usize,
13008 global_scale: usize,
13009 dst: usize,
13010 m: usize,
13011 k: usize,
13012 n: usize,
13013 base: *mut u8,
13014) {
13015 let n_scale = k.div_ceil(rlx_ir::NVFP4_GROUP_SIZE) * n;
13016 unsafe {
13017 let xs = sl(x, base, m * k);
13018 let w_bytes = std::slice::from_raw_parts(base.add(w_q) as *const u8, (k * n).div_ceil(2));
13019 let scale_bytes = std::slice::from_raw_parts(base.add(scale) as *const u8, n_scale);
13020 let gs = sl(global_scale, base, 1)[0];
13021 let out = sl_mut(dst, base, m * n);
13022 dequant_matmul_nvfp4(xs, w_bytes, scale_bytes, gs, out, m, k, n);
13023 }
13024}
13025
13026pub unsafe fn execute_gated_delta_net_f16(
13028 q: usize,
13029 k: usize,
13030 v: usize,
13031 g: usize,
13032 beta: usize,
13033 state: usize,
13034 dst: usize,
13035 batch: usize,
13036 seq: usize,
13037 heads: usize,
13038 state_size: usize,
13039 base: *mut u8,
13040) {
13041 use half::f16;
13042 unsafe {
13043 let read_f16 = |off: usize, len: usize| -> Vec<f32> {
13044 let raw = std::slice::from_raw_parts(base.add(off) as *const u8, len * 2);
13045 raw.chunks_exact(2)
13046 .map(|c| f16::from_le_bytes([c[0], c[1]]).to_f32())
13047 .collect()
13048 };
13049 let write_f16 = |off: usize, data: &[f32]| {
13050 let out = std::slice::from_raw_parts_mut(base.add(off), data.len() * 2);
13051 for (i, &v) in data.iter().enumerate() {
13052 let le = f16::from_f32(v).to_le_bytes();
13053 out[i * 2] = le[0];
13054 out[i * 2 + 1] = le[1];
13055 }
13056 };
13057
13058 let (b, s, h, n) = (batch, seq, heads, state_size);
13059 let q_f = read_f16(q, b * s * h * n);
13060 let k_f = read_f16(k, b * s * h * n);
13061 let v_f = read_f16(v, b * s * h * n);
13062 let g_f = read_f16(g, b * s * h);
13063 let b_f = read_f16(beta, b * s * h);
13064 let mut state_f = if state != 0 {
13065 read_f16(state, b * h * n * n)
13066 } else {
13067 vec![0f32; b * h * n * n]
13068 };
13069 let mut out_f = vec![0f32; b * s * h * n];
13070 let scale = 1.0f32 / (n as f32).sqrt();
13071 let mut sk_buf = vec![0f32; n];
13072 let mut owned_state = vec![0f32; h * n * n];
13073
13074 for bi in 0..b {
13075 let state_slice: &mut [f32] = if state != 0 {
13076 let start = bi * h * n * n;
13077 &mut state_f[start..start + h * n * n]
13078 } else {
13079 owned_state.fill(0.0);
13080 &mut owned_state
13081 };
13082
13083 for ti in 0..s {
13084 let qkv_step_base = bi * s * h * n + ti * h * n;
13085 let gb_step_base = bi * s * h + ti * h;
13086
13087 for hi in 0..h {
13088 let q_row = &q_f[qkv_step_base + hi * n..qkv_step_base + (hi + 1) * n];
13089 let k_row = &k_f[qkv_step_base + hi * n..qkv_step_base + (hi + 1) * n];
13090 let v_row = &v_f[qkv_step_base + hi * n..qkv_step_base + (hi + 1) * n];
13091 let g_t = g_f[gb_step_base + hi];
13092 let beta_t = b_f[gb_step_base + hi];
13093
13094 let s_base = hi * n * n;
13095 let s_mat = &mut state_slice[s_base..s_base + n * n];
13096
13097 let g_exp = g_t.exp();
13098 for st in s_mat.iter_mut() {
13099 *st *= g_exp;
13100 }
13101
13102 for j in 0..n {
13103 let mut acc = 0f32;
13104 for i in 0..n {
13105 acc += s_mat[i * n + j] * k_row[i];
13106 }
13107 sk_buf[j] = acc;
13108 }
13109
13110 for j in 0..n {
13111 sk_buf[j] = (v_row[j] - sk_buf[j]) * beta_t;
13112 }
13113
13114 for i in 0..n {
13115 let ki = k_row[i];
13116 for j in 0..n {
13117 s_mat[i * n + j] += ki * sk_buf[j];
13118 }
13119 }
13120
13121 let out_row = &mut out_f[qkv_step_base + hi * n..qkv_step_base + (hi + 1) * n];
13122 for j in 0..n {
13123 let mut acc = 0f32;
13124 for i in 0..n {
13125 acc += s_mat[i * n + j] * q_row[i];
13126 }
13127 out_row[j] = acc * scale;
13128 }
13129 }
13130 }
13131 }
13132
13133 write_f16(dst, &out_f);
13134 if state != 0 {
13135 write_f16(state, &state_f);
13136 }
13137 }
13138}
13139
13140pub unsafe fn execute_group_norm_nchw_f32(
13142 src: usize,
13143 g: usize,
13144 b: usize,
13145 dst: usize,
13146 n: usize,
13147 c: usize,
13148 h: usize,
13149 w: usize,
13150 num_groups: usize,
13151 eps: f32,
13152 base: *mut u8,
13153) {
13154 let plane = c * h * w;
13155 for ni in 0..n {
13156 let input = unsafe { sl(src + ni * plane * std::mem::size_of::<f32>(), base, plane) };
13157 let gamma = unsafe { sl(g, base, c) };
13158 let beta = unsafe { sl(b, base, c) };
13159 let output = unsafe { sl_mut(dst + ni * plane * std::mem::size_of::<f32>(), base, plane) };
13160 crate::kernels::group_norm_nchw(input, gamma, beta, output, 1, c, h, w, num_groups, eps);
13161 }
13162}
13163
13164pub unsafe fn execute_layer_norm2d_nchw_f32(
13166 src: usize,
13167 g: usize,
13168 b: usize,
13169 dst: usize,
13170 n: usize,
13171 c: usize,
13172 h: usize,
13173 w: usize,
13174 eps: f32,
13175 base: *mut u8,
13176) {
13177 let plane = c * h * w;
13178 unsafe {
13179 let input = sl(src, base, n * plane);
13180 let gamma = sl(g, base, c);
13181 let beta = sl(b, base, c);
13182 let output = sl_mut(dst, base, n * plane);
13183 crate::kernels::layer_norm2d_nchw(input, gamma, beta, output, n, c, h, w, eps);
13184 }
13185}
13186
13187pub unsafe fn execute_conv_transpose2d_nchw_f32(
13189 src: usize,
13190 weight: usize,
13191 dst: usize,
13192 n: usize,
13193 c_in: usize,
13194 h: usize,
13195 w_in: usize,
13196 c_out: usize,
13197 h_out: usize,
13198 w_out: usize,
13199 kh: usize,
13200 kw: usize,
13201 sh: usize,
13202 sw: usize,
13203 ph: usize,
13204 pw: usize,
13205 dh: usize,
13206 dw: usize,
13207 groups: usize,
13208 base: *mut u8,
13209) {
13210 let in_elems = n * c_in * h * w_in;
13211 let w_elems = c_in * (c_out / groups) * kh * kw;
13212 let out_elems = n * c_out * h_out * w_out;
13213 unsafe {
13214 let input = sl(src, base, in_elems);
13215 let wt = sl(weight, base, w_elems);
13216 let output = sl_mut(dst, base, out_elems);
13217 crate::kernels::conv_transpose2d_nchw(
13218 input, wt, output, n, c_in, h, w_in, c_out, h_out, w_out, kh, kw, sh, sw, ph, pw, dh,
13219 dw, groups,
13220 );
13221 }
13222}
13223
13224pub unsafe fn execute_resize_nearest_2x_f32(
13226 src: usize,
13227 dst: usize,
13228 n: usize,
13229 c: usize,
13230 h: usize,
13231 w: usize,
13232 base: *mut u8,
13233) {
13234 let in_plane = c * h * w;
13235 let out_plane = c * h * 2 * w * 2;
13236 for ni in 0..n {
13237 let input = unsafe {
13238 sl(
13239 src + ni * in_plane * std::mem::size_of::<f32>(),
13240 base,
13241 in_plane,
13242 )
13243 };
13244 let output = unsafe {
13245 sl_mut(
13246 dst + ni * out_plane * std::mem::size_of::<f32>(),
13247 base,
13248 out_plane,
13249 )
13250 };
13251 crate::kernels::resize_nearest_2x_nchw(input, output, c, h, w);
13252 }
13253}
13254
13255pub unsafe fn execute_axial_rope2d_f32(
13257 src: usize,
13258 dst: usize,
13259 batch: usize,
13260 seq: usize,
13261 hidden: usize,
13262 end_x: usize,
13263 end_y: usize,
13264 head_dim: usize,
13265 num_heads: usize,
13266 theta: f32,
13267 repeat_factor: usize,
13268 base: *mut u8,
13269) {
13270 let plane = seq * hidden;
13271 let plane_bytes = plane * std::mem::size_of::<f32>();
13272 for bi in 0..batch {
13273 let in_off = src + bi * plane_bytes;
13274 let input = unsafe { sl(in_off, base, plane) };
13275 let rotated = rlx_ir::ops::axial_rope2d::apply_axial_rope2d(
13276 input,
13277 num_heads,
13278 seq,
13279 head_dim,
13280 end_x,
13281 end_y,
13282 theta,
13283 repeat_factor,
13284 );
13285 let out_off = dst + bi * plane_bytes;
13286 let output = unsafe { sl_mut(out_off, base, plane) };
13287 output.copy_from_slice(&rotated);
13288 }
13289}
13290
13291pub unsafe fn execute_fft1d_f32(
13293 src: usize,
13294 dst: usize,
13295 outer: usize,
13296 n_complex: usize,
13297 inverse: bool,
13298 norm_tag: u32,
13299 base: *mut u8,
13300) {
13301 let row_elems = 2 * n_complex;
13302 let mut re = vec![0f32; n_complex];
13303 let mut im = vec![0f32; n_complex];
13304 let norm = rlx_ir::fft::FftNorm::from_tag(norm_tag);
13305 let scale = norm.output_scale(n_complex, inverse) as f32;
13306 let mut scratch = if n_complex.is_power_of_two() || n_complex <= 16 {
13307 BluesteinScratchF32::empty()
13308 } else {
13309 BluesteinScratchF32::build(n_complex, inverse)
13310 };
13311 for o in 0..outer {
13312 let row_offset = src + o * row_elems * std::mem::size_of::<f32>();
13313 let s = unsafe { sl(row_offset, base, row_elems) };
13314 re.copy_from_slice(&s[..n_complex]);
13315 im.copy_from_slice(&s[n_complex..]);
13316 if n_complex.is_power_of_two() {
13317 fft_radix2_inplace_f32(&mut re, &mut im, inverse);
13318 } else if n_complex <= 16 {
13319 fft_naive_inplace_f32(&mut re, &mut im, inverse);
13320 } else {
13321 fft_bluestein_inplace_f32(&mut re, &mut im, inverse, &mut scratch);
13322 }
13323 if scale != 1.0 {
13324 re.iter_mut().for_each(|v| *v *= scale);
13325 im.iter_mut().for_each(|v| *v *= scale);
13326 }
13327 let dst_offset = dst + o * row_elems * std::mem::size_of::<f32>();
13328 let d = unsafe { sl_mut(dst_offset, base, row_elems) };
13329 d[..n_complex].copy_from_slice(&re);
13330 d[n_complex..].copy_from_slice(&im);
13331 }
13332}
13333
13334pub unsafe fn execute_fft1d_c64(
13336 src: usize,
13337 dst: usize,
13338 outer: usize,
13339 n_complex: usize,
13340 inverse: bool,
13341 norm_tag: u32,
13342 base: *mut u8,
13343) {
13344 let row_bytes = n_complex * 8;
13345 let mut re = vec![0f32; n_complex];
13346 let mut im = vec![0f32; n_complex];
13347 let norm = rlx_ir::fft::FftNorm::from_tag(norm_tag);
13348 let scale = norm.output_scale(n_complex, inverse) as f32;
13349 let mut scratch = if n_complex.is_power_of_two() || n_complex <= 16 {
13350 BluesteinScratchF32::empty()
13351 } else {
13352 BluesteinScratchF32::build(n_complex, inverse)
13353 };
13354 for o in 0..outer {
13355 let row_offset = src + o * row_bytes;
13356 for i in 0..n_complex {
13357 let elem_off = row_offset + i * 8;
13358 re[i] = f32::from_le_bytes([
13359 *base.add(elem_off),
13360 *base.add(elem_off + 1),
13361 *base.add(elem_off + 2),
13362 *base.add(elem_off + 3),
13363 ]);
13364 im[i] = f32::from_le_bytes([
13365 *base.add(elem_off + 4),
13366 *base.add(elem_off + 5),
13367 *base.add(elem_off + 6),
13368 *base.add(elem_off + 7),
13369 ]);
13370 }
13371 if n_complex.is_power_of_two() {
13372 fft_radix2_inplace_f32(&mut re, &mut im, inverse);
13373 } else if n_complex <= 16 {
13374 fft_naive_inplace_f32(&mut re, &mut im, inverse);
13375 } else {
13376 fft_bluestein_inplace_f32(&mut re, &mut im, inverse, &mut scratch);
13377 }
13378 if scale != 1.0 {
13379 re.iter_mut().for_each(|v| *v *= scale);
13380 im.iter_mut().for_each(|v| *v *= scale);
13381 }
13382 let dst_row = dst + o * row_bytes;
13383 for i in 0..n_complex {
13384 let elem_off = dst_row + i * 8;
13385 let re_b = re[i].to_le_bytes();
13386 let im_b = im[i].to_le_bytes();
13387 for j in 0..4 {
13388 *base.add(elem_off + j) = re_b[j];
13389 *base.add(elem_off + 4 + j) = im_b[j];
13390 }
13391 }
13392 }
13393}
13394
13395pub unsafe fn execute_fft1d(
13397 src: usize,
13398 dst: usize,
13399 outer: usize,
13400 n_complex: usize,
13401 inverse: bool,
13402 norm_tag: u32,
13403 dtype: rlx_ir::DType,
13404 base: *mut u8,
13405) {
13406 match dtype {
13407 rlx_ir::DType::F32 => {
13408 execute_fft1d_f32(src, dst, outer, n_complex, inverse, norm_tag, base)
13409 }
13410 rlx_ir::DType::F64 => {
13411 execute_fft1d_f64(src, dst, outer, n_complex, inverse, norm_tag, base)
13412 }
13413 rlx_ir::DType::C64 => {
13414 execute_fft1d_c64(src, dst, outer, n_complex, inverse, norm_tag, base)
13415 }
13416 other => panic!("execute_fft1d: unsupported dtype {other:?}"),
13417 }
13418}
13419
/// In-place iterative radix-2 FFT on split real/imaginary f32 buffers.
///
/// `re` and `im` must have the same power-of-two length (checked only in
/// debug builds). `inverse` flips the sign of the twiddle exponent; no
/// `1/n` scaling is applied. Twiddle factors are tracked in f64 and narrowed
/// to f32 for each butterfly.
fn fft_radix2_inplace_f32(re: &mut [f32], im: &mut [f32], inverse: bool) {
    let n = re.len();
    debug_assert_eq!(im.len(), n);
    debug_assert!(
        n.is_power_of_two(),
        "fft_radix2_f32: n={n} must be a power of two"
    );
    if n <= 1 {
        return;
    }

    // Bit-reversal permutation: `rev` follows the bit-reversed counter.
    let mut rev = 0usize;
    for idx in 1..n {
        let mut mask = n >> 1;
        while rev & mask != 0 {
            rev ^= mask;
            mask >>= 1;
        }
        rev ^= mask;
        if idx < rev {
            re.swap(idx, rev);
            im.swap(idx, rev);
        }
    }

    let direction = if inverse { 1.0_f64 } else { -1.0_f64 };
    let mut span = 2usize;
    while span <= n {
        let half = span / 2;
        let angle = direction * 2.0 * std::f64::consts::PI / (span as f64);
        let step_re = angle.cos();
        let step_im = angle.sin();
        for block in (0..n).step_by(span) {
            // The twiddle restarts at 1 + 0i for every butterfly group.
            let mut tw_re = 1.0_f64;
            let mut tw_im = 0.0_f64;
            for lo in block..block + half {
                let hi = lo + half;
                let c = tw_re as f32;
                let s = tw_im as f32;
                let prod_re = c * re[hi] - s * im[hi];
                let prod_im = c * im[hi] + s * re[hi];
                let top_re = re[lo];
                let top_im = im[lo];
                re[lo] = top_re + prod_re;
                im[lo] = top_im + prod_im;
                re[hi] = top_re - prod_re;
                im[hi] = top_im - prod_im;
                let next_re = tw_re * step_re - tw_im * step_im;
                let next_im = tw_re * step_im + tw_im * step_re;
                tw_re = next_re;
                tw_im = next_im;
            }
        }
        span <<= 1;
    }
}
13481
/// In-place iterative radix-2 FFT on split real/imaginary f64 buffers.
///
/// `re` and `im` must have the same power-of-two length (checked only in
/// debug builds). `inverse` flips the sign of the twiddle exponent; no
/// `1/n` scaling is applied.
fn fft_radix2_inplace_f64(re: &mut [f64], im: &mut [f64], inverse: bool) {
    let n = re.len();
    debug_assert_eq!(im.len(), n);
    debug_assert!(
        n.is_power_of_two(),
        "fft_radix2: n={n} must be a power of two"
    );
    if n <= 1 {
        return;
    }

    // Bit-reversal permutation: `rev` follows the bit-reversed counter.
    let mut rev = 0usize;
    for idx in 1..n {
        let mut mask = n >> 1;
        while rev & mask != 0 {
            rev ^= mask;
            mask >>= 1;
        }
        rev ^= mask;
        if idx < rev {
            re.swap(idx, rev);
            im.swap(idx, rev);
        }
    }

    let direction = if inverse { 1.0 } else { -1.0 };
    let mut span = 2usize;
    while span <= n {
        let half = span / 2;
        let angle = direction * 2.0 * std::f64::consts::PI / (span as f64);
        let step_re = angle.cos();
        let step_im = angle.sin();
        for block in (0..n).step_by(span) {
            // The twiddle restarts at 1 + 0i for every butterfly group.
            let mut tw_re = 1.0_f64;
            let mut tw_im = 0.0_f64;
            for lo in block..block + half {
                let hi = lo + half;
                let prod_re = tw_re * re[hi] - tw_im * im[hi];
                let prod_im = tw_re * im[hi] + tw_im * re[hi];
                let top_re = re[lo];
                let top_im = im[lo];
                re[lo] = top_re + prod_re;
                im[lo] = top_im + prod_im;
                re[hi] = top_re - prod_re;
                im[hi] = top_im - prod_im;
                let next_re = tw_re * step_re - tw_im * step_im;
                let next_im = tw_re * step_im + tw_im * step_re;
                tw_re = next_re;
                tw_im = next_im;
            }
        }
        span <<= 1;
    }
}
13543
/// Precomputed tables and work buffers for the f64 Bluestein FFT.
///
/// Created by `build` for one transform length `n` and direction, and then
/// reused for every row transformed by `fft_bluestein_inplace_f64`.
struct BluesteinScratchF64 {
    /// Length of the padded radix-2 transforms: `(2n - 1).next_power_of_two()`,
    /// or 1 when `n <= 1` (0 for the `empty` placeholder).
    m: usize,
    /// Real part of the length-`n` chirp `exp(sign * i * pi * k^2 / n)`.
    w_re: Vec<f64>,
    /// Imaginary part of the length-`n` chirp.
    w_im: Vec<f64>,
    /// Real part of the forward FFT (length `m`) of the conjugated chirp,
    /// mirrored so that index `m - k` equals index `k`.
    bf_re: Vec<f64>,
    /// Imaginary part of that transformed filter.
    bf_im: Vec<f64>,
    /// Real work buffer of length `m`, overwritten on every transform.
    ar: Vec<f64>,
    /// Imaginary work buffer of length `m`, overwritten on every transform.
    ai: Vec<f64>,
}
13563
13564impl BluesteinScratchF64 {
13565 fn empty() -> Self {
13566 Self {
13567 m: 0,
13568 w_re: Vec::new(),
13569 w_im: Vec::new(),
13570 bf_re: Vec::new(),
13571 bf_im: Vec::new(),
13572 ar: Vec::new(),
13573 ai: Vec::new(),
13574 }
13575 }
13576
13577 fn build(n: usize, inverse: bool) -> Self {
13578 let m = if n <= 1 {
13581 1
13582 } else {
13583 (2 * n - 1).next_power_of_two()
13584 };
13585
13586 let mod_2n = (2 * n) as u64;
13589 let sign = if inverse { 1.0_f64 } else { -1.0_f64 };
13590 let mut w_re = vec![0.0_f64; n];
13591 let mut w_im = vec![0.0_f64; n];
13592 for k in 0..n {
13593 let k2 = (k as u64).wrapping_mul(k as u64) % mod_2n;
13594 let theta = sign * std::f64::consts::PI * (k2 as f64) / (n as f64);
13595 w_re[k] = theta.cos();
13596 w_im[k] = theta.sin();
13597 }
13598
13599 let mut bf_re = vec![0.0_f64; m];
13602 let mut bf_im = vec![0.0_f64; m];
13603 if n > 0 {
13604 bf_re[0] = w_re[0];
13605 bf_im[0] = -w_im[0];
13606 for k in 1..n {
13607 bf_re[k] = w_re[k];
13608 bf_im[k] = -w_im[k];
13609 bf_re[m - k] = w_re[k];
13610 bf_im[m - k] = -w_im[k];
13611 }
13612 }
13613 if m > 1 {
13614 fft_radix2_inplace_f64(&mut bf_re, &mut bf_im, false);
13615 }
13616
13617 Self {
13618 m,
13619 w_re,
13620 w_im,
13621 bf_re,
13622 bf_im,
13623 ar: vec![0.0_f64; m],
13624 ai: vec![0.0_f64; m],
13625 }
13626 }
13627}
13628
/// Direct O(n^2) DFT on split real/imaginary f64 buffers, in place.
///
/// Works for any length; lengths of 0 or 1 are left untouched. `inverse`
/// flips the sign of the exponent; no `1/n` scaling is applied.
fn fft_naive_inplace_f64(re: &mut [f64], im: &mut [f64], inverse: bool) {
    let n = re.len();
    if n <= 1 {
        return;
    }
    let sign = if inverse { 1.0 } else { -1.0 };
    let mut out_re = vec![0.0_f64; n];
    let mut out_im = vec![0.0_f64; n];
    for (bin, (slot_re, slot_im)) in out_re.iter_mut().zip(out_im.iter_mut()).enumerate() {
        let mut sum_re = 0.0_f64;
        let mut sum_im = 0.0_f64;
        for sample in 0..n {
            let angle =
                sign * 2.0 * std::f64::consts::PI * (sample as f64) * (bin as f64) / (n as f64);
            let cos_a = angle.cos();
            let sin_a = angle.sin();
            sum_re += re[sample] * cos_a - im[sample] * sin_a;
            sum_im += re[sample] * sin_a + im[sample] * cos_a;
        }
        *slot_re = sum_re;
        *slot_im = sum_im;
    }
    re.copy_from_slice(&out_re);
    im.copy_from_slice(&out_im);
}
13650
/// Direct O(n^2) DFT on split real/imaginary f32 buffers, in place.
///
/// Works for any length; lengths of 0 or 1 are left untouched. `inverse`
/// flips the sign of the exponent; no `1/n` scaling is applied.
///
/// Twiddle factors come from an `n`-entry table computed in f64 and narrowed
/// to f32, indexed by `(sample * bin) mod n`. This matches how the other f32
/// FFT kernels in this file obtain their twiddles, avoids the precision lost
/// when the unreduced angle is formed in f32, and needs only `n` trig
/// evaluations instead of `n^2`.
///
/// # Panics
/// Panics if `im` is not the same length as `re`.
fn fft_naive_inplace_f32(re: &mut [f32], im: &mut [f32], inverse: bool) {
    let n = re.len();
    if n <= 1 {
        return;
    }
    let sign = if inverse { 1.0_f64 } else { -1.0_f64 };
    let step = sign * 2.0 * std::f64::consts::PI / (n as f64);
    // twiddles[j] = exp(sign * 2*pi*i * j / n), rounded once to f32.
    let twiddles: Vec<(f32, f32)> = (0..n)
        .map(|j| {
            let theta = step * (j as f64);
            (theta.cos() as f32, theta.sin() as f32)
        })
        .collect();

    let mut out_re = vec![0.0_f32; n];
    let mut out_im = vec![0.0_f32; n];
    for k in 0..n {
        let mut acc_re = 0.0_f32;
        let mut acc_im = 0.0_f32;
        // `slot` tracks (nn * k) mod n incrementally, so the product is never
        // formed and cannot overflow.
        let mut slot = 0usize;
        for nn in 0..n {
            let (c, s) = twiddles[slot];
            acc_re += re[nn] * c - im[nn] * s;
            acc_im += re[nn] * s + im[nn] * c;
            slot += k;
            if slot >= n {
                slot -= n;
            }
        }
        out_re[k] = acc_re;
        out_im[k] = acc_im;
    }
    re.copy_from_slice(&out_re);
    im.copy_from_slice(&out_im);
}
13671
13672fn fft_bluestein_inplace_f64(
13681 re: &mut [f64],
13682 im: &mut [f64],
13683 _inverse: bool,
13684 s: &mut BluesteinScratchF64,
13685) {
13686 let n = re.len();
13687 debug_assert_eq!(im.len(), n);
13688 debug_assert_eq!(s.w_re.len(), n);
13689 if n <= 1 {
13690 return;
13691 }
13692 let m = s.m;
13693
13694 for k in 0..m {
13696 s.ar[k] = 0.0;
13697 s.ai[k] = 0.0;
13698 }
13699 for k in 0..n {
13700 s.ar[k] = re[k] * s.w_re[k] - im[k] * s.w_im[k];
13701 s.ai[k] = re[k] * s.w_im[k] + im[k] * s.w_re[k];
13702 }
13703
13704 fft_radix2_inplace_f64(&mut s.ar, &mut s.ai, false);
13706
13707 for k in 0..m {
13709 let ar = s.ar[k];
13710 let ai = s.ai[k];
13711 let br = s.bf_re[k];
13712 let bi = s.bf_im[k];
13713 s.ar[k] = ar * br - ai * bi;
13714 s.ai[k] = ar * bi + ai * br;
13715 }
13716
13717 fft_radix2_inplace_f64(&mut s.ar, &mut s.ai, true);
13720 let inv_m = 1.0 / (m as f64);
13721
13722 for k in 0..n {
13724 let yr = s.ar[k] * inv_m;
13725 let yi = s.ai[k] * inv_m;
13726 re[k] = yr * s.w_re[k] - yi * s.w_im[k];
13727 im[k] = yr * s.w_im[k] + yi * s.w_re[k];
13728 }
13729}
13730
/// Precomputed tables and work buffers for the f32 Bluestein FFT.
///
/// Created by `build` for one transform length `n` and direction, and then
/// reused for every row transformed by `fft_bluestein_inplace_f32`.
struct BluesteinScratchF32 {
    /// Length of the padded radix-2 transforms: `(2n - 1).next_power_of_two()`,
    /// or 1 when `n <= 1` (0 for the `empty` placeholder).
    m: usize,
    /// Real part of the length-`n` chirp `exp(sign * i * pi * k^2 / n)`,
    /// computed in f64 and narrowed to f32.
    w_re: Vec<f32>,
    /// Imaginary part of the length-`n` chirp.
    w_im: Vec<f32>,
    /// Real part of the forward FFT (length `m`) of the conjugated chirp,
    /// mirrored so that index `m - k` equals index `k`.
    bf_re: Vec<f32>,
    /// Imaginary part of that transformed filter.
    bf_im: Vec<f32>,
    /// Real work buffer of length `m`, overwritten on every transform.
    ar: Vec<f32>,
    /// Imaginary work buffer of length `m`, overwritten on every transform.
    ai: Vec<f32>,
}
13743
13744impl BluesteinScratchF32 {
13745 fn empty() -> Self {
13746 Self {
13747 m: 0,
13748 w_re: Vec::new(),
13749 w_im: Vec::new(),
13750 bf_re: Vec::new(),
13751 bf_im: Vec::new(),
13752 ar: Vec::new(),
13753 ai: Vec::new(),
13754 }
13755 }
13756
13757 fn build(n: usize, inverse: bool) -> Self {
13758 let m = if n <= 1 {
13759 1
13760 } else {
13761 (2 * n - 1).next_power_of_two()
13762 };
13763
13764 let mod_2n = (2 * n) as u64;
13765 let sign = if inverse { 1.0_f64 } else { -1.0_f64 };
13766 let mut w_re = vec![0.0_f32; n];
13767 let mut w_im = vec![0.0_f32; n];
13768 for k in 0..n {
13769 let k2 = (k as u64).wrapping_mul(k as u64) % mod_2n;
13770 let theta = sign * std::f64::consts::PI * (k2 as f64) / (n as f64);
13771 w_re[k] = theta.cos() as f32;
13772 w_im[k] = theta.sin() as f32;
13773 }
13774
13775 let mut bf_re = vec![0.0_f32; m];
13776 let mut bf_im = vec![0.0_f32; m];
13777 if n > 0 {
13778 bf_re[0] = w_re[0];
13779 bf_im[0] = -w_im[0];
13780 for k in 1..n {
13781 bf_re[k] = w_re[k];
13782 bf_im[k] = -w_im[k];
13783 bf_re[m - k] = w_re[k];
13784 bf_im[m - k] = -w_im[k];
13785 }
13786 }
13787 if m > 1 {
13788 fft_radix2_inplace_f32(&mut bf_re, &mut bf_im, false);
13789 }
13790
13791 Self {
13792 m,
13793 w_re,
13794 w_im,
13795 bf_re,
13796 bf_im,
13797 ar: vec![0.0_f32; m],
13798 ai: vec![0.0_f32; m],
13799 }
13800 }
13801}
13802
13803fn fft_bluestein_inplace_f32(
13804 re: &mut [f32],
13805 im: &mut [f32],
13806 _inverse: bool,
13807 s: &mut BluesteinScratchF32,
13808) {
13809 let n = re.len();
13810 debug_assert_eq!(im.len(), n);
13811 debug_assert_eq!(s.w_re.len(), n);
13812 if n <= 1 {
13813 return;
13814 }
13815 let m = s.m;
13816
13817 for k in 0..m {
13818 s.ar[k] = 0.0;
13819 s.ai[k] = 0.0;
13820 }
13821 for k in 0..n {
13822 s.ar[k] = re[k] * s.w_re[k] - im[k] * s.w_im[k];
13823 s.ai[k] = re[k] * s.w_im[k] + im[k] * s.w_re[k];
13824 }
13825
13826 fft_radix2_inplace_f32(&mut s.ar, &mut s.ai, false);
13827
13828 for k in 0..m {
13829 let ar = s.ar[k];
13830 let ai = s.ai[k];
13831 let br = s.bf_re[k];
13832 let bi = s.bf_im[k];
13833 s.ar[k] = ar * br - ai * bi;
13834 s.ai[k] = ar * bi + ai * br;
13835 }
13836
13837 fft_radix2_inplace_f32(&mut s.ar, &mut s.ai, true);
13838 let inv_m = 1.0_f32 / (m as f32);
13839
13840 for k in 0..n {
13841 let yr = s.ar[k] * inv_m;
13842 let yi = s.ai[k] * inv_m;
13843 re[k] = yr * s.w_re[k] - yi * s.w_im[k];
13844 im[k] = yr * s.w_im[k] + yi * s.w_re[k];
13845 }
13846}
13847
/// Invokes a user-registered CPU kernel on typed views of arena memory.
///
/// * `kernel` - the registered kernel to run.
/// * `inputs` - one `(byte offset, element count, shape)` triple per input;
///   the view's element type is chosen from the shape's dtype.
/// * `out_off`, `out_len`, `out_shape` - byte offset, element count and shape
///   of the output; the view's element type is chosen from its dtype.
/// * `attrs` - opaque attribute bytes handed to the kernel unchanged.
///
/// # Panics
/// Panics when an input or the output has dtype `C64`, or when the kernel
/// returns an error.
///
/// # Safety
/// `base` plus each offset must address valid, suitably aligned storage for
/// the requested element type and count, and the output region must not be
/// accessed elsewhere during the call.
unsafe fn dispatch_custom_op(
    kernel: &dyn crate::op_registry::CpuKernel,
    inputs: &[(usize, u32, Shape)],
    out_off: usize,
    out_len: u32,
    out_shape: &Shape,
    attrs: &[u8],
    base: *mut u8,
) {
    use crate::op_registry::{CpuTensorMut, CpuTensorRef};
    use rlx_ir::DType;

    // Builds the read-only view variant for one input; `sl_typed` yields an
    // empty slice when the offset is `usize::MAX`.
    macro_rules! build_in_view {
        ($shape:expr, $off:expr, $n:expr, $variant:ident, $rust_ty:ty) => {
            CpuTensorRef::$variant {
                data: unsafe { sl_typed::<$rust_ty>($off, base, $n) },
                shape: $shape,
            }
        };
    }
    // Builds the mutable view variant over the output region.
    macro_rules! build_out_view {
        ($variant:ident, $rust_ty:ty) => {
            CpuTensorMut::$variant {
                data: unsafe { sl_mut_typed::<$rust_ty>(out_off, base, out_len as usize) },
                shape: out_shape,
            }
        };
    }

    // One typed view per input; `Bool` tensors are exposed as `u8` slices.
    let in_views: Vec<CpuTensorRef<'_>> = inputs
        .iter()
        .map(|(off, len, shape)| {
            let n = *len as usize;
            let off = *off;
            match shape.dtype() {
                DType::F32 => build_in_view!(shape, off, n, F32, f32),
                DType::F64 => build_in_view!(shape, off, n, F64, f64),
                DType::F16 => build_in_view!(shape, off, n, F16, half::f16),
                DType::BF16 => build_in_view!(shape, off, n, BF16, half::bf16),
                DType::I8 => build_in_view!(shape, off, n, I8, i8),
                DType::I16 => build_in_view!(shape, off, n, I16, i16),
                DType::I32 => build_in_view!(shape, off, n, I32, i32),
                DType::I64 => build_in_view!(shape, off, n, I64, i64),
                DType::U8 => build_in_view!(shape, off, n, U8, u8),
                DType::U32 => build_in_view!(shape, off, n, U32, u32),
                DType::Bool => build_in_view!(shape, off, n, Bool, u8),
                // Complex tensors have no view variant for custom kernels.
                DType::C64 => panic!(
                    "Op::Custom kernel input has DType::C64 — built-in \
                     complex ops handle their own kernels; user-registered \
                     ops don't yet see complex tensors"
                ),
            }
        })
        .collect();

    // The output view type follows the output dtype; the kernel runs once.
    let result = match out_shape.dtype() {
        DType::F32 => kernel.execute(&in_views, build_out_view!(F32, f32), attrs),
        DType::F64 => kernel.execute(&in_views, build_out_view!(F64, f64), attrs),
        DType::F16 => kernel.execute(&in_views, build_out_view!(F16, half::f16), attrs),
        DType::BF16 => kernel.execute(&in_views, build_out_view!(BF16, half::bf16), attrs),
        DType::I8 => kernel.execute(&in_views, build_out_view!(I8, i8), attrs),
        DType::I16 => kernel.execute(&in_views, build_out_view!(I16, i16), attrs),
        DType::I32 => kernel.execute(&in_views, build_out_view!(I32, i32), attrs),
        DType::I64 => kernel.execute(&in_views, build_out_view!(I64, i64), attrs),
        DType::U8 => kernel.execute(&in_views, build_out_view!(U8, u8), attrs),
        DType::U32 => kernel.execute(&in_views, build_out_view!(U32, u32), attrs),
        DType::Bool => kernel.execute(&in_views, build_out_view!(Bool, u8), attrs),
        DType::C64 => panic!("Op::Custom output DType::C64 not supported"),
    };
    // Kernel errors are fatal at this layer.
    if let Err(e) = result {
        panic!("Op::Custom('{}') CPU kernel failed: {e}", kernel.name());
    }
}
13933
/// Views `len` elements of type `T` starting `offset` bytes past `base`.
///
/// An `offset` of `usize::MAX` is a sentinel: it yields an empty slice and
/// `base` is not touched.
///
/// # Safety
/// Unless `offset == usize::MAX`, `base + offset` must be aligned for `T` and
/// valid for reads of `len` elements for as long as the returned slice is used
/// (the `'static` lifetime is not checked by the compiler).
#[inline(always)]
unsafe fn sl_typed<T>(offset: usize, base: *mut u8, len: usize) -> &'static [T] {
    match offset {
        usize::MAX => &[],
        byte_off => unsafe { std::slice::from_raw_parts(base.add(byte_off).cast::<T>(), len) },
    }
}
13946
/// Mutable view of `len` elements of type `T` starting `offset` bytes past `base`.
///
/// Unlike the read-only accessor there is no `usize::MAX` sentinel here.
///
/// # Safety
/// `base + offset` must be aligned for `T`, valid for reads and writes of `len`
/// elements, and not aliased by any other live reference while the returned
/// slice is in use.
#[inline(always)]
unsafe fn sl_mut_typed<T>(offset: usize, base: *mut u8, len: usize) -> &'static mut [T] {
    let first = unsafe { base.add(offset) }.cast::<T>();
    unsafe { std::slice::from_raw_parts_mut(first, len) }
}
13951
13952#[inline(always)]
13954fn apply_activation_inplace(d: &mut [f32], act: rlx_ir::op::Activation) {
13958 use rlx_ir::op::Activation;
13959 match act {
13960 Activation::Gelu => crate::kernels::par_gelu_inplace(d),
13961 Activation::GeluApprox => crate::kernels::par_gelu_approx_inplace(d),
13962 Activation::Silu => crate::kernels::par_silu_inplace(d),
13963 Activation::Relu => {
13964 for v in d.iter_mut() {
13965 *v = v.max(0.0);
13966 }
13967 }
13968 Activation::Sigmoid => {
13969 for v in d.iter_mut() {
13970 *v = 1.0 / (1.0 + (-*v).exp());
13971 }
13972 }
13973 Activation::Tanh => {
13974 for v in d.iter_mut() {
13975 *v = v.tanh();
13976 }
13977 }
13978 Activation::Exp => {
13979 for v in d.iter_mut() {
13980 *v = v.exp();
13981 }
13982 }
13983 Activation::Log => {
13984 for v in d.iter_mut() {
13985 *v = v.ln();
13986 }
13987 }
13988 Activation::Sqrt => {
13989 for v in d.iter_mut() {
13990 *v = v.sqrt();
13991 }
13992 }
13993 Activation::Rsqrt => {
13994 for v in d.iter_mut() {
13995 *v = 1.0 / v.sqrt();
13996 }
13997 }
13998 Activation::Neg => {
13999 for v in d.iter_mut() {
14000 *v = -*v;
14001 }
14002 }
14003 Activation::Abs => {
14004 for v in d.iter_mut() {
14005 *v = v.abs();
14006 }
14007 }
14008 Activation::Round => {
14009 for v in d.iter_mut() {
14010 *v = v.round();
14011 }
14012 }
14013 Activation::Sin => {
14014 for v in d.iter_mut() {
14015 *v = v.sin();
14016 }
14017 }
14018 Activation::Cos => {
14019 for v in d.iter_mut() {
14020 *v = v.cos();
14021 }
14022 }
14023 Activation::Tan => {
14024 for v in d.iter_mut() {
14025 *v = v.tan();
14026 }
14027 }
14028 Activation::Atan => {
14029 for v in d.iter_mut() {
14030 *v = v.atan();
14031 }
14032 }
14033 }
14034}
14035
/// Unfolds one `[c_in, h, w]` image into a `[c_in * kh * kw, h_out * w_out]`
/// patch matrix (row-major), writing zeros wherever a kernel tap falls into
/// the padding.
///
/// Row `(ci * kh + ki) * kw + kj` of `col` holds, for every output position
/// `(ho, wo)`, the input sample at
/// `(ho * sh + ki * dh - ph, wo * sw + kj * dw_dil - pw)` of channel `ci`.
///
/// Every element of `col` covered by the loops is overwritten. Slice lengths
/// are only verified with `debug_assert_eq!`.
#[allow(clippy::too_many_arguments)]
fn im2col(
    x: &[f32],
    col: &mut [f32],
    c_in: usize,
    h: usize,
    w: usize,
    h_out: usize,
    w_out: usize,
    kh: usize,
    kw: usize,
    sh: usize,
    sw: usize,
    ph: usize,
    pw: usize,
    dh: usize,
    dw_dil: usize,
) {
    /// Input coordinate hit by kernel tap `k_idx` at output index `out_idx`,
    /// or `None` when it lands in the padding.
    #[inline(always)]
    fn tap(
        out_idx: usize,
        stride: usize,
        k_idx: usize,
        dilation: usize,
        pad: usize,
        extent: usize,
    ) -> Option<usize> {
        let pos = (out_idx * stride + k_idx * dilation) as isize - pad as isize;
        if pos < 0 || pos >= extent as isize {
            None
        } else {
            Some(pos as usize)
        }
    }

    let spatial = h_out * w_out;
    debug_assert_eq!(col.len(), c_in * kh * kw * spatial);
    debug_assert_eq!(x.len(), c_in * h * w);

    for ci in 0..c_in {
        for ki in 0..kh {
            for kj in 0..kw {
                let patch_row = ((ci * kh) + ki) * kw + kj;
                let row_off = patch_row * spatial;
                for ho in 0..h_out {
                    let line = row_off + ho * w_out;
                    match tap(ho, sh, ki, dh, ph, h) {
                        // Whole output row reads from vertical padding.
                        None => {
                            for wo in 0..w_out {
                                col[line + wo] = 0.0;
                            }
                        }
                        Some(hi) => {
                            let in_row_off = (ci * h + hi) * w;
                            for wo in 0..w_out {
                                col[line + wo] = match tap(wo, sw, kj, dw_dil, pw, w) {
                                    Some(wi) => x[in_row_off + wi],
                                    None => 0.0,
                                };
                            }
                        }
                    }
                }
            }
        }
    }
}
14097
/// Scatter-adds a `[c_in * kh * kw, h_out * w_out]` patch matrix back into a
/// `[c_in, h, w]` image — the adjoint of `im2col`.
///
/// Each `col` entry is **added** to the image sample its kernel tap maps to;
/// taps that land in the padding are dropped. `x` is accumulated into, not
/// cleared, so the caller decides its initial contents. Slice lengths are only
/// verified with `debug_assert_eq!`.
#[allow(clippy::too_many_arguments)]
fn col2im(
    col: &[f32],
    x: &mut [f32],
    c_in: usize,
    h: usize,
    w: usize,
    h_out: usize,
    w_out: usize,
    kh: usize,
    kw: usize,
    sh: usize,
    sw: usize,
    ph: usize,
    pw: usize,
    dh: usize,
    dw_dil: usize,
) {
    /// Image coordinate addressed by kernel tap `k_idx` at output index
    /// `out_idx`, or `None` when it lies in the padding.
    #[inline(always)]
    fn tap(
        out_idx: usize,
        stride: usize,
        k_idx: usize,
        dilation: usize,
        pad: usize,
        extent: usize,
    ) -> Option<usize> {
        let pos = (out_idx * stride + k_idx * dilation) as isize - pad as isize;
        if pos < 0 || pos >= extent as isize {
            None
        } else {
            Some(pos as usize)
        }
    }

    let spatial = h_out * w_out;
    debug_assert_eq!(col.len(), c_in * kh * kw * spatial);
    debug_assert_eq!(x.len(), c_in * h * w);

    for ci in 0..c_in {
        for ki in 0..kh {
            for kj in 0..kw {
                let patch_row = ((ci * kh) + ki) * kw + kj;
                let row_off = patch_row * spatial;
                for ho in 0..h_out {
                    let Some(hi) = tap(ho, sh, ki, dh, ph, h) else {
                        continue;
                    };
                    let in_row_off = (ci * h + hi) * w;
                    for wo in 0..w_out {
                        if let Some(wi) = tap(wo, sw, kj, dw_dil, pw, w) {
                            x[in_row_off + wi] += col[row_off + ho * w_out + wo];
                        }
                    }
                }
            }
        }
    }
}
14153
14154fn quant_layout(shape: &rlx_ir::Shape, axis: Option<usize>) -> (usize, usize, usize) {
14164 match axis {
14165 None => (0, 1, shape.num_elements().unwrap_or(0).max(1)),
14166 Some(d) => {
14167 let chan_dim = shape.dim(d).unwrap_static();
14168 let inner: usize = (d + 1..shape.rank())
14169 .map(|i| shape.dim(i).unwrap_static())
14170 .product::<usize>()
14171 .max(1);
14172 (d, chan_dim, inner)
14173 }
14174 }
14175}
14176
14177fn activation_backward_kernel(
14178 act: rlx_ir::op::Activation,
14179 xs: &[f32],
14180 dys: &[f32],
14181 out: &mut [f32],
14182) {
14183 use rlx_ir::op::Activation;
14184 let n = xs.len();
14185 debug_assert_eq!(dys.len(), n);
14186 debug_assert_eq!(out.len(), n);
14187 match act {
14188 Activation::Relu => {
14189 for i in 0..n {
14190 out[i] = if xs[i] > 0.0 { dys[i] } else { 0.0 };
14191 }
14192 }
14193 Activation::Sigmoid => {
14194 for i in 0..n {
14195 let s = 1.0 / (1.0 + (-xs[i]).exp());
14196 out[i] = s * (1.0 - s) * dys[i];
14197 }
14198 }
14199 Activation::Tanh => {
14200 for i in 0..n {
14201 let t = xs[i].tanh();
14202 out[i] = (1.0 - t * t) * dys[i];
14203 }
14204 }
14205 Activation::Silu => {
14206 for i in 0..n {
14208 let s = 1.0 / (1.0 + (-xs[i]).exp());
14209 out[i] = s * (1.0 + xs[i] * (1.0 - s)) * dys[i];
14210 }
14211 }
14212 Activation::Gelu => {
14213 const INV_SQRT2: f32 = 0.707_106_77;
14216 const INV_SQRT_2PI: f32 = 0.398_942_3;
14217 for i in 0..n {
14218 let x = xs[i];
14219 let phi = 0.5 * (1.0 + erf_f32(x * INV_SQRT2));
14220 let pdf = INV_SQRT_2PI * (-(x * x) * 0.5).exp();
14221 out[i] = (phi + x * pdf) * dys[i];
14222 }
14223 }
14224 Activation::GeluApprox => {
14225 const C: f32 = 0.797_884_6; const A: f32 = 0.044_715;
14229 for i in 0..n {
14230 let x = xs[i];
14231 let inner = C * (x + A * x * x * x);
14232 let t = inner.tanh();
14233 let dinner = C * (1.0 + 3.0 * A * x * x);
14234 let d = 0.5 * (1.0 + t) + 0.5 * x * (1.0 - t * t) * dinner;
14235 out[i] = d * dys[i];
14236 }
14237 }
14238 Activation::Exp => {
14239 for i in 0..n {
14240 out[i] = xs[i].exp() * dys[i];
14241 }
14242 }
14243 Activation::Log => {
14244 for i in 0..n {
14245 out[i] = dys[i] / xs[i];
14246 }
14247 }
14248 Activation::Sqrt => {
14249 for i in 0..n {
14251 let s = xs[i].sqrt();
14252 out[i] = if s > 0.0 { 0.5 * dys[i] / s } else { 0.0 };
14253 }
14254 }
14255 Activation::Rsqrt => {
14256 for i in 0..n {
14258 let s = xs[i].sqrt();
14259 out[i] = if s > 0.0 {
14260 -0.5 * dys[i] / (xs[i] * s)
14261 } else {
14262 0.0
14263 };
14264 }
14265 }
14266 Activation::Neg => {
14267 for i in 0..n {
14268 out[i] = -dys[i];
14269 }
14270 }
14271 Activation::Abs => {
14272 for i in 0..n {
14274 let x = xs[i];
14275 let s = if x > 0.0 {
14276 1.0
14277 } else if x < 0.0 {
14278 -1.0
14279 } else {
14280 0.0
14281 };
14282 out[i] = s * dys[i];
14283 }
14284 }
14285 Activation::Round => {
14286 out.copy_from_slice(dys);
14291 }
14292 Activation::Sin => {
14293 for i in 0..n {
14295 out[i] = xs[i].cos() * dys[i];
14296 }
14297 }
14298 Activation::Cos => {
14299 for i in 0..n {
14300 out[i] = -xs[i].sin() * dys[i];
14301 }
14302 }
14303 Activation::Tan => {
14304 for i in 0..n {
14306 let t = xs[i].tan();
14307 out[i] = (1.0 + t * t) * dys[i];
14308 }
14309 }
14310 Activation::Atan => {
14311 for i in 0..n {
14313 let x = xs[i];
14314 out[i] = dys[i] / (1.0 + x * x);
14315 }
14316 }
14317 }
14318}
14319
14320fn activation_backward_kernel_f64(
14324 act: rlx_ir::op::Activation,
14325 xs: &[f64],
14326 dys: &[f64],
14327 out: &mut [f64],
14328) {
14329 use rlx_ir::op::Activation;
14330 let n = xs.len();
14331 debug_assert_eq!(dys.len(), n);
14332 debug_assert_eq!(out.len(), n);
14333 match act {
14334 Activation::Relu => {
14335 for i in 0..n {
14336 out[i] = if xs[i] > 0.0 { dys[i] } else { 0.0 };
14337 }
14338 }
14339 Activation::Sigmoid => {
14340 for i in 0..n {
14341 let s = 1.0 / (1.0 + (-xs[i]).exp());
14342 out[i] = s * (1.0 - s) * dys[i];
14343 }
14344 }
14345 Activation::Tanh => {
14346 for i in 0..n {
14347 let t = xs[i].tanh();
14348 out[i] = (1.0 - t * t) * dys[i];
14349 }
14350 }
14351 Activation::Silu => {
14352 for i in 0..n {
14353 let s = 1.0 / (1.0 + (-xs[i]).exp());
14354 out[i] = s * (1.0 + xs[i] * (1.0 - s)) * dys[i];
14355 }
14356 }
14357 Activation::Gelu | Activation::GeluApprox => {
14358 const INV_SQRT2: f64 = std::f64::consts::FRAC_1_SQRT_2;
14360 const INV_SQRT_2PI: f64 = 0.398_942_280_401_432_7;
14361 for i in 0..n {
14362 let x = xs[i];
14363 let phi = 0.5 * (1.0 + erf_f64(x * INV_SQRT2));
14364 let pdf = INV_SQRT_2PI * (-(x * x) * 0.5).exp();
14365 out[i] = (phi + x * pdf) * dys[i];
14366 }
14367 }
14368 Activation::Exp => {
14369 for i in 0..n {
14370 out[i] = xs[i].exp() * dys[i];
14371 }
14372 }
14373 Activation::Log => {
14374 for i in 0..n {
14375 out[i] = dys[i] / xs[i];
14376 }
14377 }
14378 Activation::Sqrt => {
14379 for i in 0..n {
14380 let s = xs[i].sqrt();
14381 out[i] = if s > 0.0 { 0.5 * dys[i] / s } else { 0.0 };
14382 }
14383 }
14384 Activation::Rsqrt => {
14385 for i in 0..n {
14386 let s = xs[i].sqrt();
14387 out[i] = if s > 0.0 {
14388 -0.5 * dys[i] / (xs[i] * s)
14389 } else {
14390 0.0
14391 };
14392 }
14393 }
14394 Activation::Neg => {
14395 for i in 0..n {
14396 out[i] = -dys[i];
14397 }
14398 }
14399 Activation::Abs => {
14400 for i in 0..n {
14401 let x = xs[i];
14402 let s = if x > 0.0 {
14403 1.0
14404 } else if x < 0.0 {
14405 -1.0
14406 } else {
14407 0.0
14408 };
14409 out[i] = s * dys[i];
14410 }
14411 }
14412 Activation::Round => {
14413 out.copy_from_slice(dys);
14414 }
14415 Activation::Sin => {
14416 for i in 0..n {
14417 out[i] = xs[i].cos() * dys[i];
14418 }
14419 }
14420 Activation::Cos => {
14421 for i in 0..n {
14422 out[i] = -xs[i].sin() * dys[i];
14423 }
14424 }
14425 Activation::Tan => {
14426 for i in 0..n {
14427 let t = xs[i].tan();
14428 out[i] = (1.0 + t * t) * dys[i];
14429 }
14430 }
14431 Activation::Atan => {
14432 for i in 0..n {
14433 let x = xs[i];
14434 out[i] = dys[i] / (1.0 + x * x);
14435 }
14436 }
14437 }
14438}
14439
/// Polynomial approximation of the error function for `f64` inputs.
///
/// Evaluates `sign(x) * (1 - poly(t) * t * exp(-x^2))` with
/// `t = 1 / (1 + P * |x|)` and a degree-4 polynomial in Horner form. The
/// coefficients match the Abramowitz & Stegun 7.1.26 fit, so accuracy is far
/// below full `f64` precision (on the order of 1e-7).
#[inline(always)]
fn erf_f64(x: f64) -> f64 {
    const P: f64 = 0.327_591_1;
    const A1: f64 = 0.254_829_59;
    const A2: f64 = 0.284_496_74;
    const A3: f64 = 1.421_413_75;
    const A4: f64 = 1.453_152_03;
    const A5: f64 = 1.061_405_43;

    let sign = x.signum();
    let mag = x.abs();
    let t = 1.0 / (1.0 + P * mag);

    // Horner evaluation, highest-order coefficient first.
    let mut poly = A5 * t - A4;
    poly = poly * t + A3;
    poly = poly * t - A2;
    poly = poly * t + A1;

    let y = 1.0 - poly * t * (-mag * mag).exp();
    sign * y
}
14456
/// Polynomial approximation of the error function for `f32` inputs.
///
/// Same scheme as the f64 variant: `sign(x) * (1 - poly(t) * t * exp(-x^2))`
/// with `t = 1 / (1 + P * |x|)` and a degree-4 Horner polynomial whose
/// coefficients match the Abramowitz & Stegun 7.1.26 fit.
#[inline(always)]
fn erf_f32(x: f32) -> f32 {
    const P: f32 = 0.327_591_1;
    const A1: f32 = 0.254_829_6;
    const A2: f32 = 0.284_496_74;
    const A3: f32 = 1.421_413_8;
    const A4: f32 = 1.453_152_1;
    const A5: f32 = 1.061_405_4;

    let sign = x.signum();
    let mag = x.abs();
    let t = 1.0 / (1.0 + P * mag);

    // Horner evaluation, highest-order coefficient first.
    let mut poly = A5 * t - A4;
    poly = poly * t + A3;
    poly = poly * t - A2;
    poly = poly * t + A1;

    let y = 1.0 - poly * t * (-mag * mag).exp();
    sign * y
}
14471
14472fn narrow_thunk_closure(
14473 src: usize,
14474 dst: usize,
14475 outer: u32,
14476 src_stride: u32,
14477 dst_stride: u32,
14478 inner: u32,
14479 elem_bytes: u8,
14480) -> Arc<dyn Fn(*mut u8) + Send + Sync> {
14481 let (outer, ss, ds, inner) = (
14482 outer as usize,
14483 src_stride as usize,
14484 dst_stride as usize,
14485 inner as usize,
14486 );
14487 if elem_bytes == 8 {
14488 Arc::new(move |base: *mut u8| unsafe {
14489 let s = sl_f64(src, base, outer * ss);
14490 let d = sl_mut_f64(dst, base, outer * ds);
14491 for o in 0..outer {
14492 d[o * ds..o * ds + inner].copy_from_slice(&s[o * ss..o * ss + inner]);
14493 }
14494 })
14495 } else {
14496 Arc::new(move |base: *mut u8| unsafe {
14497 let s = sl(src, base, outer * ss);
14498 let d = sl_mut(dst, base, outer * ds);
14499 for o in 0..outer {
14500 d[o * ds..o * ds + inner].copy_from_slice(&s[o * ss..o * ss + inner]);
14501 }
14502 })
14503 }
14504}
14505
/// Read-only `f32` view of `len` elements starting `offset` bytes past `base`.
///
/// `offset == usize::MAX` is a sentinel that yields an empty slice without
/// touching `base`.
///
/// # Safety
/// Unless the sentinel is passed, `base + offset` must be 4-byte aligned and
/// valid for reads of `len` `f32` values while the returned slice is in use.
unsafe fn sl(offset: usize, base: *mut u8, len: usize) -> &'static [f32] {
    if offset != usize::MAX {
        let first = unsafe { base.add(offset) } as *const f32;
        unsafe { std::slice::from_raw_parts(first, len) }
    } else {
        &[]
    }
}
14512
/// Mutable `f32` view of `len` elements starting `offset` bytes past `base`.
///
/// # Safety
/// `base + offset` must be 4-byte aligned, valid for reads and writes of `len`
/// `f32` values, and not aliased by another live reference while the returned
/// slice is in use. There is no sentinel offset.
#[inline(always)]
unsafe fn sl_mut(offset: usize, base: *mut u8, len: usize) -> &'static mut [f32] {
    let first = unsafe { base.add(offset) }.cast::<f32>();
    unsafe { std::slice::from_raw_parts_mut(first, len) }
}
14517
/// Read-only `f64` view of `len` elements starting `offset` bytes past `base`.
///
/// `offset == usize::MAX` is a sentinel that yields an empty slice without
/// touching `base`.
///
/// # Safety
/// Unless the sentinel is passed, `base + offset` must be 8-byte aligned and
/// valid for reads of `len` `f64` values while the returned slice is in use.
#[inline(always)]
unsafe fn sl_f64(offset: usize, base: *mut u8, len: usize) -> &'static [f64] {
    match offset {
        usize::MAX => &[],
        byte_off => unsafe { std::slice::from_raw_parts(base.add(byte_off).cast::<f64>(), len) },
    }
}
14525
/// Mutable `f64` view of `len` elements starting `offset` bytes past `base`.
///
/// # Safety
/// `base + offset` must be 8-byte aligned, valid for reads and writes of `len`
/// `f64` values, and not aliased by another live reference while the returned
/// slice is in use. There is no sentinel offset.
#[inline(always)]
unsafe fn sl_mut_f64(offset: usize, base: *mut u8, len: usize) -> &'static mut [f64] {
    let first = unsafe { base.add(offset) }.cast::<f64>();
    unsafe { std::slice::from_raw_parts_mut(first, len) }
}
14530
/// Read-only `i32` view of `len` elements starting `offset` bytes past `base`.
///
/// `offset == usize::MAX` is a sentinel that yields an empty slice without
/// touching `base`.
///
/// # Safety
/// Unless the sentinel is passed, `base + offset` must be 4-byte aligned and
/// valid for reads of `len` `i32` values while the returned slice is in use.
#[allow(dead_code)]
#[inline(always)]
unsafe fn sl_i32(offset: usize, base: *mut u8, len: usize) -> &'static [i32] {
    if offset != usize::MAX {
        let first = unsafe { base.add(offset) } as *const i32;
        unsafe { std::slice::from_raw_parts(first, len) }
    } else {
        &[]
    }
}
14543
/// Mutable `i32` view of `len` elements starting `offset` bytes past `base`.
///
/// # Safety
/// `base + offset` must be 4-byte aligned, valid for reads and writes of `len`
/// `i32` values, and not aliased by another live reference while the returned
/// slice is in use. There is no sentinel offset.
#[allow(dead_code)]
#[inline(always)]
unsafe fn sl_mut_i32(offset: usize, base: *mut u8, len: usize) -> &'static mut [i32] {
    let first = unsafe { base.add(offset) }.cast::<i32>();
    unsafe { std::slice::from_raw_parts_mut(first, len) }
}
14549
/// Read-only `i64` view of `len` elements starting `offset` bytes past `base`.
///
/// `offset == usize::MAX` is a sentinel that yields an empty slice without
/// touching `base`.
///
/// # Safety
/// Unless the sentinel is passed, `base + offset` must be 8-byte aligned and
/// valid for reads of `len` `i64` values while the returned slice is in use.
#[allow(dead_code)]
#[inline(always)]
unsafe fn sl_i64(offset: usize, base: *mut u8, len: usize) -> &'static [i64] {
    match offset {
        usize::MAX => &[],
        byte_off => unsafe { std::slice::from_raw_parts(base.add(byte_off).cast::<i64>(), len) },
    }
}
14558
/// Mutable `i64` view of `len` elements starting `offset` bytes past `base`.
///
/// # Safety
/// `base + offset` must be 8-byte aligned, valid for reads and writes of `len`
/// `i64` values, and not aliased by another live reference while the returned
/// slice is in use. There is no sentinel offset.
#[allow(dead_code)]
#[inline(always)]
unsafe fn sl_mut_i64(offset: usize, base: *mut u8, len: usize) -> &'static mut [i64] {
    let first = unsafe { base.add(offset) }.cast::<i64>();
    unsafe { std::slice::from_raw_parts_mut(first, len) }
}
14564
/// Fills `out` in row-major order by walking `inp` with per-output-axis strides.
///
/// `out_dims[d]` is the extent of output axis `d` and `in_strides[d]` is the
/// element stride in `inp` that one step along that axis corresponds to. For
/// each output element the source offset is the dot product of the current
/// multi-index with `in_strides`; the multi-index then advances like an
/// odometer, last axis fastest.
///
/// With an empty `out_dims` every output element is `inp[0]`.
fn transpose_walk_f64(inp: &[f64], out: &mut [f64], out_dims: &[u32], in_strides: &[u32]) {
    let rank = out_dims.len();
    let mut coord = vec![0u32; rank];

    for slot in out.iter_mut() {
        let src_off: usize = (0..rank)
            .map(|d| coord[d] as usize * in_strides[d] as usize)
            .sum();
        *slot = inp[src_off];

        // Advance the odometer: bump the last axis, carrying on wrap-around.
        let mut d = rank;
        while d > 0 {
            d -= 1;
            coord[d] += 1;
            if coord[d] < out_dims[d] {
                break;
            }
            coord[d] = 0;
        }
    }
}
14587
14588fn apply_activation_f64(inp: &[f64], out: &mut [f64], kind: Activation) {
14594 match kind {
14595 Activation::Neg => {
14596 for (o, &v) in out.iter_mut().zip(inp) {
14597 *o = -v;
14598 }
14599 }
14600 Activation::Exp => {
14601 for (o, &v) in out.iter_mut().zip(inp) {
14602 *o = v.exp();
14603 }
14604 }
14605 Activation::Log => {
14606 for (o, &v) in out.iter_mut().zip(inp) {
14607 *o = v.ln();
14608 }
14609 }
14610 Activation::Sqrt => {
14611 for (o, &v) in out.iter_mut().zip(inp) {
14612 *o = v.sqrt();
14613 }
14614 }
14615 Activation::Rsqrt => {
14616 for (o, &v) in out.iter_mut().zip(inp) {
14617 *o = 1.0 / v.sqrt();
14618 }
14619 }
14620 Activation::Abs => {
14621 for (o, &v) in out.iter_mut().zip(inp) {
14622 *o = v.abs();
14623 }
14624 }
14625 Activation::Tanh => {
14626 for (o, &v) in out.iter_mut().zip(inp) {
14627 *o = v.tanh();
14628 }
14629 }
14630 Activation::Sigmoid => {
14631 for (o, &v) in out.iter_mut().zip(inp) {
14632 *o = 1.0 / (1.0 + (-v).exp());
14633 }
14634 }
14635 Activation::Relu => {
14636 for (o, &v) in out.iter_mut().zip(inp) {
14637 *o = v.max(0.0);
14638 }
14639 }
14640 Activation::Round => {
14641 for (o, &v) in out.iter_mut().zip(inp) {
14642 *o = v.round_ties_even();
14643 }
14644 }
14645 Activation::Sin => {
14646 for (o, &v) in out.iter_mut().zip(inp) {
14647 *o = v.sin();
14648 }
14649 }
14650 Activation::Cos => {
14651 for (o, &v) in out.iter_mut().zip(inp) {
14652 *o = v.cos();
14653 }
14654 }
14655 Activation::Tan => {
14656 for (o, &v) in out.iter_mut().zip(inp) {
14657 *o = v.tan();
14658 }
14659 }
14660 Activation::Atan => {
14661 for (o, &v) in out.iter_mut().zip(inp) {
14662 *o = v.atan();
14663 }
14664 }
14665 Activation::Gelu | Activation::GeluApprox | Activation::Silu => {
14666 panic!(
14667 "apply_activation_f64: {kind:?} not yet implemented at f64. \
14668 Add when a workload needs it."
14669 );
14670 }
14671 }
14672}
14673
14674#[inline]
14675fn binary_op_f64(op: BinaryOp, a: f64, b: f64) -> f64 {
14676 match op {
14677 BinaryOp::Add => a + b,
14678 BinaryOp::Sub => a - b,
14679 BinaryOp::Mul => a * b,
14680 BinaryOp::Div => a / b,
14681 BinaryOp::Max => a.max(b),
14682 BinaryOp::Min => a.min(b),
14683 BinaryOp::Pow => a.powf(b),
14684 }
14685}
14686
/// Sums an `[outer, reduced, inner]` row-major tensor over its middle axis.
///
/// Writes `out[o * inner + n] = sum over r of inp[(o * reduced + r) * inner + n]`
/// for every `o < outer`, `n < inner`. Elements of `out` beyond
/// `outer * inner` are left untouched.
///
/// The accumulation runs row by row: each contiguous `inner`-length input row
/// is added into the matching output row. Every output element still receives
/// its addends in increasing `r` order starting from `0.0`, so the results are
/// bit-identical to a per-element accumulator, while the inner loop is
/// unit-stride.
///
/// # Panics
/// Panics up front if `out` holds fewer than `outer * inner` elements or `inp`
/// fewer than `outer * reduced * inner`.
fn reduce_sum_f64(inp: &[f64], out: &mut [f64], outer: usize, reduced: usize, inner: usize) {
    if outer == 0 || inner == 0 {
        return;
    }
    let out = &mut out[..outer * inner];
    if reduced == 0 {
        // Empty reduction axis: every sum is the additive identity.
        out.fill(0.0);
        return;
    }
    let inp = &inp[..outer * reduced * inner];

    let out_rows = out.chunks_exact_mut(inner);
    let in_blocks = inp.chunks_exact(reduced * inner);
    for (dst, block) in out_rows.zip(in_blocks) {
        dst.fill(0.0);
        for row in block.chunks_exact(inner) {
            for (acc, &v) in dst.iter_mut().zip(row) {
                *acc += v;
            }
        }
    }
}
14700
14701#[cfg(test)]
14702mod tests {
14703 use super::*;
14704 use rlx_ir::*;
14705
14706 #[test]
14712 fn narrow_rope_fuses_in_unfused_path() {
14713 let f = DType::F32;
14714 let mut g = Graph::new("nr_fuse");
14715 let qkv = g.input("qkv", Shape::new(&[16, 8, 192], f)); let cos = g.input("cos", Shape::new(&[16], f));
14718 let sin = g.input("sin", Shape::new(&[16], f));
14719 let q = g.narrow_(qkv, 2, 0, 64);
14721 let q_rope = g.rope(q, cos, sin, 16);
14722 g.set_outputs(vec![q_rope]);
14723
14724 let plan = rlx_opt::memory::plan_memory(&g);
14725 let arena = crate::arena::Arena::from_plan(plan);
14726 let sched = compile_thunks(&g, &arena);
14727
14728 let mut narrow_count = 0;
14729 let mut rope_with_stride: Option<u32> = None;
14730 for t in &sched.thunks {
14731 match t {
14732 Thunk::Narrow { .. } => narrow_count += 1,
14733 Thunk::Rope { src_row_stride, .. } => rope_with_stride = Some(*src_row_stride),
14734 _ => {}
14735 }
14736 }
14737 assert_eq!(
14740 narrow_count, 0,
14741 "Narrow→Rope fusion should leave zero Narrow thunks; saw {narrow_count}"
14742 );
14743 assert_eq!(
14744 rope_with_stride,
14745 Some(192),
14746 "Rope's src_row_stride should be 192 (parent qkv axis), saw {rope_with_stride:?}"
14747 );
14748 }
14749
    /// Checks the compiled selective-scan thunk against a straightforward
    /// scalar recurrence computed in this test.
    ///
    /// Shapes: `x`, `delta` and the output are `[bch, s, h]`; `a` is `[h, n]`;
    /// `b` and `c` are `[bch, s, n]`.
    #[test]
    fn ssm_selective_scan_matches_reference() {
        use rlx_ir::Philox4x32;
        let bch = 1usize;
        let s = 4usize;
        let h = 3usize;
        let n = 2usize;

        // Deterministic inputs from a fixed-seed generator.
        let mut rng = Philox4x32::new(13);
        let mut x = vec![0f32; bch * s * h];
        rng.fill_normal(&mut x);
        let mut delta = vec![0f32; bch * s * h];
        for v in delta.iter_mut() {
            // Small step sizes centred on zero (assuming `next_f32` is in [0, 1) — TODO confirm).
            *v = (rng.next_f32() - 0.5) * 0.1;
        }
        let mut a = vec![0f32; h * n];
        for v in a.iter_mut() {
            // Negated so the entries are negative when `next_f32` is non-negative.
            *v = -(rng.next_f32() * 0.5 + 0.1);
        }
        let mut b = vec![0f32; bch * s * n];
        rng.fill_normal(&mut b);
        let mut c = vec![0f32; bch * s * n];
        rng.fill_normal(&mut c);

        // Reference: per batch, carry an [h, n] state across the sequence.
        //   state = exp(delta * a) * state + delta * b * x
        //   y     = sum over n of c * state
        let mut expected = vec![0f32; bch * s * h];
        for bi in 0..bch {
            let mut state = vec![0f32; h * n];
            for si in 0..s {
                for ci in 0..h {
                    let d = delta[bi * s * h + si * h + ci];
                    let xv = x[bi * s * h + si * h + ci];
                    let mut acc = 0f32;
                    for ni in 0..n {
                        let da = (d * a[ci * n + ni]).exp();
                        state[ci * n + ni] =
                            da * state[ci * n + ni] + d * b[bi * s * n + si * n + ni] * xv;
                        acc += c[bi * s * n + si * n + ni] * state[ci * n + ni];
                    }
                    expected[bi * s * h + si * h + ci] = acc;
                }
            }
        }

        // Build the one-node graph and compile it.
        let f = DType::F32;
        let mut g = Graph::new("ssm");
        let xn = g.input("x", Shape::new(&[bch, s, h], f));
        let dn = g.input("delta", Shape::new(&[bch, s, h], f));
        let an = g.param("a", Shape::new(&[h, n], f));
        let bn = g.param("b", Shape::new(&[bch, s, n], f));
        let cn = g.param("c", Shape::new(&[bch, s, n], f));
        let yn = g.selective_scan(xn, dn, an, bn, cn, n, Shape::new(&[bch, s, h], f));
        g.set_outputs(vec![yn]);

        let plan = rlx_opt::memory::plan_memory(&g);
        let mut arena = crate::arena::Arena::from_plan(plan);
        let sched = compile_thunks(&g, &arena);

        // Resolve byte offsets before taking the mutable buffer borrow.
        let xn_off = arena.byte_offset(xn);
        let dn_off = arena.byte_offset(dn);
        let an_off = arena.byte_offset(an);
        let bn_off = arena.byte_offset(bn);
        let cn_off = arena.byte_offset(cn);
        let yn_off = arena.byte_offset(yn);
        let buf = arena.raw_buf_mut();
        unsafe {
            // Element-by-element write of `data` starting at `dst`.
            let copy = |dst: *mut f32, data: &[f32]| {
                for (i, &v) in data.iter().enumerate() {
                    *dst.add(i) = v;
                }
            };
            copy(buf.as_mut_ptr().add(xn_off) as *mut f32, &x);
            copy(buf.as_mut_ptr().add(dn_off) as *mut f32, &delta);
            copy(buf.as_mut_ptr().add(an_off) as *mut f32, &a);
            copy(buf.as_mut_ptr().add(bn_off) as *mut f32, &b);
            copy(buf.as_mut_ptr().add(cn_off) as *mut f32, &c);
        }
        execute_thunks(&sched, arena.raw_buf_mut());

        // Read the output tensor back out of the arena.
        let actual: Vec<f32> = unsafe {
            let p = arena.raw_buf().as_ptr().add(yn_off) as *const f32;
            (0..bch * s * h).map(|i| *p.add(i)).collect()
        };

        for (i, (e, a)) in expected.iter().zip(&actual).enumerate() {
            assert!(
                (e - a).abs() < 1e-3,
                "mismatch at {i}: expected {e}, got {a}"
            );
        }
    }
14845
    /// A 1×1, stride-1, unpadded, ungrouped convolution must compile to the
    /// `Conv2D1x1` thunk (and not the general `Conv2D` one) and produce the
    /// same values as a scalar channel-mixing reference.
    #[test]
    fn conv_1x1_fast_path_matches_scalar() {
        use rlx_ir::Philox4x32;
        let n = 2usize;
        let c_in = 4usize;
        let h = 3usize;
        let w = 3usize;
        let c_out = 5usize;
        let mut rng = Philox4x32::new(31);
        let mut x = vec![0f32; n * c_in * h * w];
        rng.fill_normal(&mut x);
        let mut weight = vec![0f32; c_out * c_in];
        rng.fill_normal(&mut weight);

        // Reference: for every pixel, out[co] = sum over ci of weight[co, ci] * x[ci].
        // Tensors are indexed as [n, c, h, w] row-major.
        let mut expected = vec![0f32; n * c_out * h * w];
        for ni in 0..n {
            for co in 0..c_out {
                for hi in 0..h {
                    for wi in 0..w {
                        let mut acc = 0f32;
                        for ci in 0..c_in {
                            acc += weight[co * c_in + ci]
                                * x[((ni * c_in) + ci) * h * w + hi * w + wi];
                        }
                        expected[((ni * c_out) + co) * h * w + hi * w + wi] = acc;
                    }
                }
            }
        }

        // Build the graph with an explicit `Op::Conv` node.
        let f = DType::F32;
        let mut g = Graph::new("conv1x1");
        let xn = g.input("x", Shape::new(&[n, c_in, h, w], f));
        let wn = g.param("w", Shape::new(&[c_out, c_in, 1, 1], f));
        let cn = g.add_node(
            rlx_ir::Op::Conv {
                kernel_size: vec![1, 1],
                stride: vec![1, 1],
                padding: vec![0, 0],
                dilation: vec![1, 1],
                groups: 1,
            },
            vec![xn, wn],
            Shape::new(&[n, c_out, h, w], f),
        );
        g.set_outputs(vec![cn]);

        let plan = rlx_opt::memory::plan_memory(&g);
        let mut arena = crate::arena::Arena::from_plan(plan);
        let sched = compile_thunks(&g, &arena);

        // The schedule must contain the specialised thunk and not the general one.
        let saw_fast = sched
            .thunks
            .iter()
            .any(|t| matches!(t, Thunk::Conv2D1x1 { .. }));
        let saw_slow = sched
            .thunks
            .iter()
            .any(|t| matches!(t, Thunk::Conv2D { .. }));
        assert!(saw_fast, "1×1 conv should emit Conv2D1x1");
        assert!(!saw_slow, "1×1 conv must not fall through to scalar Conv2D");

        // Resolve byte offsets before taking the mutable buffer borrow.
        let xn_off = arena.byte_offset(xn);
        let wn_off = arena.byte_offset(wn);
        let cn_off = arena.byte_offset(cn);
        let buf = arena.raw_buf_mut();
        unsafe {
            // Write the input and the weights into their arena slots.
            let xp = buf.as_mut_ptr().add(xn_off) as *mut f32;
            for (i, &v) in x.iter().enumerate() {
                *xp.add(i) = v;
            }
            let wp = buf.as_mut_ptr().add(wn_off) as *mut f32;
            for (i, &v) in weight.iter().enumerate() {
                *wp.add(i) = v;
            }
        }
        execute_thunks(&sched, arena.raw_buf_mut());

        // Read the output tensor back out of the arena.
        let actual: Vec<f32> = unsafe {
            let p = arena.raw_buf().as_ptr().add(cn_off) as *const f32;
            (0..(n * c_out * h * w)).map(|i| *p.add(i)).collect()
        };

        for (i, (e, a)) in expected.iter().zip(&actual).enumerate() {
            assert!(
                (e - a).abs() < 1e-3,
                "mismatch at {i}: expected {e}, got {a}"
            );
        }
    }
14944
    /// Checks the fused dequantize + matmul thunk (`Int8Block` scheme, all
    /// zero-points set to 0) against dequantizing the weights in this test and
    /// running a naive f32 matmul.
    #[test]
    fn dequant_matmul_int8_sym_matches_reference() {
        use rlx_ir::Philox4x32;
        use rlx_ir::quant::QuantScheme;

        let m = 3usize;
        let k = 8usize;
        let n = 4usize;
        // Each scale covers `block_size` consecutive rows of the [k, n] weight.
        let block_size = 4usize;
        let blocks_per_col = k / block_size;

        let mut rng = Philox4x32::new(99);
        let mut x = vec![0f32; m * k];
        rng.fill_normal(&mut x);
        // Deterministic int8 weights in the range -63..=63.
        let w_q: Vec<i8> = (0..(k * n))
            .map(|i| ((i as i32 * 13 + 7) % 127 - 63) as i8)
            .collect();
        // One positive scale per (block, column), laid out as [blocks_per_col, n].
        let scales: Vec<f32> = (0..(blocks_per_col * n))
            .map(|i| 0.01 + 0.001 * i as f32)
            .collect();

        // Reference dequantization: w_f32[p, j] = w_q[p, j] * scale[p / block_size, j].
        let mut w_f32 = vec![0f32; k * n];
        for p in 0..k {
            let block = p / block_size;
            for j in 0..n {
                let s = scales[block * n + j];
                w_f32[p * n + j] = w_q[p * n + j] as f32 * s;
            }
        }
        // Reference matmul: expected = x [m, k] * w_f32 [k, n].
        let mut expected = vec![0f32; m * n];
        for i in 0..m {
            for j in 0..n {
                let mut acc = 0f32;
                for p in 0..k {
                    acc += x[i * k + p] * w_f32[p * n + j];
                }
                expected[i * n + j] = acc;
            }
        }

        // Build the one-node graph and compile it.
        let f = DType::F32;
        let mut g = Graph::new("dq");
        let xn = g.input("x", Shape::new(&[m, k], f));
        let wn = g.param("w", Shape::new(&[k, n], DType::I8));
        let sn = g.param("scale", Shape::new(&[blocks_per_col, n], f));
        let zn = g.param("zp", Shape::new(&[blocks_per_col, n], f));
        let dq = g.dequant_matmul(
            xn,
            wn,
            sn,
            zn,
            QuantScheme::Int8Block {
                block_size: block_size as u32,
            },
            Shape::new(&[m, n], f),
        );
        g.set_outputs(vec![dq]);

        let plan = rlx_opt::memory::plan_memory(&g);
        let mut arena = crate::arena::Arena::from_plan(plan);
        let sched = compile_thunks(&g, &arena);

        // Resolve byte offsets before taking the mutable buffer borrow.
        let xn_off = arena.byte_offset(xn);
        let wn_off = arena.byte_offset(wn);
        let sn_off = arena.byte_offset(sn);
        let zn_off = arena.byte_offset(zn);
        let dq_off = arena.byte_offset(dq);
        let buf = arena.raw_buf_mut();
        unsafe {
            // f32 activations.
            let xp = buf.as_mut_ptr().add(xn_off) as *mut f32;
            for (i, &v) in x.iter().enumerate() {
                *xp.add(i) = v;
            }
            // f32 scales.
            let sp = buf.as_mut_ptr().add(sn_off) as *mut f32;
            for (i, &v) in scales.iter().enumerate() {
                *sp.add(i) = v;
            }
            // Zero-points are all zero (symmetric quantization).
            let zp = buf.as_mut_ptr().add(zn_off) as *mut f32;
            for i in 0..(blocks_per_col * n) {
                *zp.add(i) = 0.0;
            }
            // Weights are written as raw i8, one byte per element.
            let wp = buf.as_mut_ptr().add(wn_off) as *mut i8;
            for (i, &v) in w_q.iter().enumerate() {
                *wp.add(i) = v;
            }
        }
        execute_thunks(&sched, arena.raw_buf_mut());

        // Read the output tensor back out of the arena.
        let actual: Vec<f32> = unsafe {
            let p = arena.raw_buf().as_ptr().add(dq_off) as *const f32;
            (0..m * n).map(|i| *p.add(i)).collect()
        };

        for (i, (e, a)) in expected.iter().zip(&actual).enumerate() {
            assert!(
                (e - a).abs() < 1e-3,
                "mismatch at {i}: expected {e}, got {a}"
            );
        }
    }
15052
    /// Checks the fused LoRA matmul thunk against the unfused formula
    /// `x·W + scale · (x·A)·B` evaluated with naive matmuls in this test.
    #[test]
    fn lora_matmul_matches_unfused_reference() {
        use rlx_ir::Philox4x32;

        let m = 4usize;
        let k = 8usize;
        let n = 6usize;
        // Inner dimension shared by the two low-rank factors.
        let r = 2usize;
        let scale = 0.5f32;

        let mut rng = Philox4x32::new(42);
        let mut x = vec![0f32; m * k];
        rng.fill_normal(&mut x);
        let mut w = vec![0f32; k * n];
        rng.fill_normal(&mut w);
        let mut a = vec![0f32; k * r];
        rng.fill_normal(&mut a);
        let mut b = vec![0f32; r * n];
        rng.fill_normal(&mut b);

        // Naive row-major matmul: [rows, inner] x [inner, cols] -> [rows, cols].
        let naive = |a_buf: &[f32], b_buf: &[f32], rows: usize, inner: usize, cols: usize| {
            let mut o = vec![0f32; rows * cols];
            for i in 0..rows {
                for j in 0..cols {
                    let mut acc = 0f32;
                    for p in 0..inner {
                        acc += a_buf[i * inner + p] * b_buf[p * cols + j];
                    }
                    o[i * cols + j] = acc;
                }
            }
            o
        };
        let xw = naive(&x, &w, m, k, n);
        let xa = naive(&x, &a, m, k, r);
        let xab = naive(&xa, &b, m, r, n);
        // expected = x·W + scale · (x·A)·B
        let mut expected = xw;
        for i in 0..(m * n) {
            expected[i] += scale * xab[i];
        }

        // Build the one-node graph and compile it.
        let f = DType::F32;
        let mut g = Graph::new("lora");
        let xn = g.input("x", Shape::new(&[m, k], f));
        let wn = g.param("w", Shape::new(&[k, n], f));
        let an = g.param("a", Shape::new(&[k, r], f));
        let bn = g.param("b", Shape::new(&[r, n], f));
        let lm = g.lora_matmul(xn, wn, an, bn, scale, Shape::new(&[m, n], f));
        g.set_outputs(vec![lm]);

        let plan = rlx_opt::memory::plan_memory(&g);
        let mut arena = crate::arena::Arena::from_plan(plan);
        let sched = compile_thunks(&g, &arena);

        // Resolve byte offsets before taking the mutable buffer borrow.
        let xn_off = arena.byte_offset(xn);
        let wn_off = arena.byte_offset(wn);
        let an_off = arena.byte_offset(an);
        let bn_off = arena.byte_offset(bn);
        let lm_off = arena.byte_offset(lm);
        let buf = arena.raw_buf_mut();
        unsafe {
            // Element-by-element write of `data` starting at `dst`.
            let copy = |dst: *mut f32, data: &[f32]| {
                for (i, &v) in data.iter().enumerate() {
                    *dst.add(i) = v;
                }
            };
            copy(buf.as_mut_ptr().add(xn_off) as *mut f32, &x);
            copy(buf.as_mut_ptr().add(wn_off) as *mut f32, &w);
            copy(buf.as_mut_ptr().add(an_off) as *mut f32, &a);
            copy(buf.as_mut_ptr().add(bn_off) as *mut f32, &b);
        }
        execute_thunks(&sched, arena.raw_buf_mut());

        // Read the output tensor back out of the arena.
        let actual: Vec<f32> = unsafe {
            let p = arena.raw_buf().as_ptr().add(lm_off) as *const f32;
            (0..m * n).map(|i| *p.add(i)).collect()
        };

        for (i, (e, a)) in expected.iter().zip(&actual).enumerate() {
            assert!(
                (e - a).abs() < 1e-3,
                "mismatch at {i}: expected {e}, got {a}"
            );
        }
    }
15142
/// Near-greedy sampling must return the index of the largest logit.
///
/// The `1e-3` argument to `Graph::sample` is presumably the temperature
/// (suggested by the test name and assertion message) — confirm against the builder.
#[test]
fn sample_temperature_zero_is_argmax() {
    // Index 5 carries the single dominant logit.
    const LOGITS: [f32; 8] = [0.1, 0.2, 0.3, 0.4, 0.5, 9.0, 0.7, 0.8];

    let dtype = DType::F32;
    let mut graph = Graph::new("samp");
    let logits = graph.input("logits", Shape::new(&[1, 8], dtype));
    let sampled = graph.sample(logits, 0, 1.0, 1e-3, 42, Shape::new(&[1], dtype));
    graph.set_outputs(vec![sampled]);

    let mut arena = crate::arena::Arena::from_plan(rlx_opt::memory::plan_memory(&graph));
    let schedule = compile_thunks(&graph, &arena);
    let logits_off = arena.byte_offset(logits);
    let sampled_off = arena.byte_offset(sampled);

    // SAFETY: relies on the arena holding at least `LOGITS.len()` aligned f32 slots at
    // `logits_off` (assumed from the [1, 8] F32 input shape; not checked here).
    unsafe {
        let dst = arena.raw_buf_mut().as_mut_ptr().add(logits_off).cast::<f32>();
        std::ptr::copy_nonoverlapping(LOGITS.as_ptr(), dst, LOGITS.len());
    }
    execute_thunks(&schedule, arena.raw_buf_mut());

    // The output slot is read back as one f32 and truncated to an index.
    // SAFETY: same assumption as above, for the single-element output at `sampled_off`.
    let raw = unsafe { arena.raw_buf().as_ptr().add(sampled_off).cast::<f32>().read() };
    let token = raw as usize;
    assert_eq!(token, 5, "low-temp sampling should pick the argmax");
}
15176
/// With the second `Graph::sample` argument set to 1 (presumably top-k = 1 — confirm
/// against the builder) the sampled token must be the index of the largest logit.
#[test]
fn sample_top_k_one_is_deterministic() {
    // Index 1 carries the largest logit.
    const LOGITS: [f32; 4] = [0.1, 5.0, 0.3, 0.4];

    let dtype = DType::F32;
    let mut graph = Graph::new("samp_k1");
    let logits = graph.input("logits", Shape::new(&[1, 4], dtype));
    let sampled = graph.sample(logits, 1, 1.0, 1.0, 7, Shape::new(&[1], dtype));
    graph.set_outputs(vec![sampled]);

    let mut arena = crate::arena::Arena::from_plan(rlx_opt::memory::plan_memory(&graph));
    let schedule = compile_thunks(&graph, &arena);
    let logits_off = arena.byte_offset(logits);
    let sampled_off = arena.byte_offset(sampled);

    // SAFETY: relies on the arena holding at least `LOGITS.len()` aligned f32 slots at
    // `logits_off` (assumed from the [1, 4] F32 input shape; not checked here).
    unsafe {
        let dst = arena.raw_buf_mut().as_mut_ptr().add(logits_off).cast::<f32>();
        std::ptr::copy_nonoverlapping(LOGITS.as_ptr(), dst, LOGITS.len());
    }
    execute_thunks(&schedule, arena.raw_buf_mut());

    // SAFETY: same assumption, for the single f32 output slot at `sampled_off`.
    let raw = unsafe { arena.raw_buf().as_ptr().add(sampled_off).cast::<f32>().read() };
    let token = raw as usize;
    assert_eq!(token, 1);
}
15206
/// Runs `cumsum` along the last axis of a 2x4 tensor and compares against the
/// hand-computed inclusive prefix sums of each row.
#[test]
fn cumsum_inclusive_matches_naive() {
    // Two rows of four values each.
    const INPUT: [f32; 8] = [1.0, 2.0, 3.0, 4.0, 10.0, 20.0, 30.0, 40.0];

    let dtype = DType::F32;
    let mut graph = Graph::new("cumsum");
    let x = graph.input("x", Shape::new(&[2, 4], dtype));
    let cs = graph.cumsum(x, -1, false, Shape::new(&[2, 4], dtype));
    graph.set_outputs(vec![cs]);

    let mut arena = crate::arena::Arena::from_plan(rlx_opt::memory::plan_memory(&graph));
    let schedule = compile_thunks(&graph, &arena);
    let src_off = arena.byte_offset(x);
    let dst_off = arena.byte_offset(cs);

    // SAFETY: relies on the arena holding at least eight aligned f32 slots at `src_off`
    // (assumed from the [2, 4] F32 input shape; not checked here).
    unsafe {
        let dst = arena.raw_buf_mut().as_mut_ptr().add(src_off).cast::<f32>();
        std::ptr::copy_nonoverlapping(INPUT.as_ptr(), dst, INPUT.len());
    }
    execute_thunks(&schedule, arena.raw_buf_mut());

    // SAFETY: same assumption, for the eight-element output at `dst_off`.
    let out: Vec<f32> = unsafe {
        let src = arena.raw_buf().as_ptr().add(dst_off).cast::<f32>();
        std::slice::from_raw_parts(src, INPUT.len()).to_vec()
    };
    assert_eq!(out, vec![1.0, 3.0, 6.0, 10.0, 10.0, 30.0, 60.0, 100.0]);
}
15238
/// Q, K and V are taken as three 64-wide narrows (axis 2) of one `[8, 16, 192]` tensor.
/// The compiled schedule must contain no `Thunk::Narrow`, and the attention thunk must
/// report a row stride of 192 for each of Q, K and V.
#[test]
fn narrow_attention_fuses_in_unfused_path() {
    let dtype = DType::F32;
    let mut graph = Graph::new("nattn_fuse");
    let qkv = graph.input("qkv", Shape::new(&[8, 16, 192], dtype));
    let mask = graph.input("mask", Shape::new(&[8, 16], dtype));
    let q = graph.narrow_(qkv, 2, 0, 64);
    let k = graph.narrow_(qkv, 2, 64, 64);
    let v = graph.narrow_(qkv, 2, 128, 64);
    let attn = graph.attention(q, k, v, mask, 4, 16, Shape::new(&[8, 16, 64], dtype));
    graph.set_outputs(vec![attn]);

    let arena = crate::arena::Arena::from_plan(rlx_opt::memory::plan_memory(&graph));
    let sched = compile_thunks(&graph, &arena);

    // Single pass over the schedule: count surviving narrows, remember the attention strides.
    let mut narrow_count = 0;
    let mut attn_strides: Option<(u32, u32, u32)> = None;
    for thunk in &sched.thunks {
        if let Thunk::Attention {
            q_row_stride,
            k_row_stride,
            v_row_stride,
            ..
        } = thunk
        {
            attn_strides = Some((*q_row_stride, *k_row_stride, *v_row_stride));
        } else if matches!(thunk, Thunk::Narrow { .. }) {
            narrow_count += 1;
        }
    }

    assert_eq!(
        narrow_count, 0,
        "Narrow×3→Attention fusion should eliminate all 3 narrows; saw {narrow_count}"
    );
    assert_eq!(
        attn_strides,
        Some((192, 192, 192)),
        "Attention should walk Q/K/V with parent row stride 192"
    );
}
15285
15286 fn run_graph(
15297 g: &Graph,
15298 inputs: &[(NodeId, &[f32])],
15299 out_id: NodeId,
15300 out_len: usize,
15301 ) -> Vec<f32> {
15302 let plan = rlx_opt::memory::plan_memory(g);
15303 let mut arena = crate::arena::Arena::from_plan(plan);
15304 let sched = compile_thunks(g, &arena);
15305 for &(id, data) in inputs {
15306 let off = arena.byte_offset(id);
15307 let buf = arena.raw_buf_mut();
15308 unsafe {
15309 let p = buf.as_mut_ptr().add(off) as *mut f32;
15310 for (i, &v) in data.iter().enumerate() {
15311 *p.add(i) = v;
15312 }
15313 }
15314 }
15315 execute_thunks(&sched, arena.raw_buf_mut());
15316 let off = arena.byte_offset(out_id);
15317 unsafe {
15318 let p = arena.raw_buf().as_ptr().add(off) as *const f32;
15319 (0..out_len).map(|i| *p.add(i)).collect()
15320 }
15321 }
15322
/// `relu_backward(x, dy)` must pass `dy` through where `x > 0` and emit 0 elsewhere
/// (including at `x == 0`).
#[test]
fn relu_backward_matches_mask() {
    let x = [-2.0f32, -0.1, 0.0, 0.1, 1.0, 3.0, -5.0];
    let dy = [0.5f32, 1.5, 2.5, -0.7, 4.0, -1.0, 9.0];
    let len = x.len();

    let dtype = DType::F32;
    let mut graph = Graph::new("relu_bw");
    let x_id = graph.input("x", Shape::new(&[len], dtype));
    let dy_id = graph.input("dy", Shape::new(&[len], dtype));
    let dx = graph.relu_backward(x_id, dy_id);
    graph.set_outputs(vec![dx]);

    let actual = run_graph(&graph, &[(x_id, &x), (dy_id, &dy)], dx, len);

    // Compare element by element against the mask computed on the fly.
    for ((a, &xi), &dyi) in actual.iter().zip(&x).zip(&dy) {
        let e = if xi > 0.0 { dyi } else { 0.0 };
        assert!((a - e).abs() < 1e-6, "relu_bw mismatch: {a} vs {e}");
    }
}
15349
/// 2x2 / stride-2 max-pool backward on a 4x4 ramp (1..=16): each upstream gradient
/// must land on the maximum of its window — flat indices 5, 7, 13 and 15 — and every
/// other input position must receive zero.
#[test]
fn maxpool2d_backward_routes_to_argmax() {
    let x: Vec<f32> = (1..=16u8).map(f32::from).collect();
    let dy = [0.5f32, 1.0, 2.0, 4.0];

    let dtype = DType::F32;
    let mut graph = Graph::new("maxpool_bw");
    let x_id = graph.input("x", Shape::new(&[1, 1, 4, 4], dtype));
    let dy_id = graph.input("dy", Shape::new(&[1, 1, 2, 2], dtype));
    let dx = graph.maxpool2d_backward(x_id, dy_id, vec![2, 2], vec![2, 2], vec![0, 0]);
    graph.set_outputs(vec![dx]);

    let actual = run_graph(&graph, &[(x_id, &x), (dy_id, &dy)], dx, 16);

    // Window k's gradient goes to the flat index of that window's largest element.
    let mut expected = [0f32; 16];
    for (&slot, &grad) in [5usize, 7, 13, 15].iter().zip(&dy) {
        expected[slot] = grad;
    }

    for (i, (a, e)) in actual.iter().zip(&expected).enumerate() {
        assert!((a - e).abs() < 1e-6, "maxpool_bw[{i}] mismatch: {a} vs {e}");
    }
}
15378
/// Checks `conv2d_backward_input` against a central-difference gradient of
/// `dot(conv(x, w), dy)` with respect to `x`, where `conv` is a naive reference
/// convolution (1x2x4x4 input, 3 output channels, 3x3 kernel, stride 1, padding 1).
#[test]
fn conv2d_backward_input_matches_numerical_gradient() {
    use rlx_ir::Philox4x32;

    /// Inner product of two slices (truncated to the shorter one).
    fn dot(a: &[f32], b: &[f32]) -> f32 {
        a.iter().zip(b).map(|(&u, &v)| u * v).sum()
    }

    let (n, c_in, h, w) = (1usize, 2usize, 4usize, 4usize);
    let (c_out, kh, kw) = (3usize, 3usize, 3usize);
    let (ph, pw) = (1usize, 1usize);
    let (sh, sw) = (1usize, 1usize);
    let h_out = (h + 2 * ph - kh) / sh + 1;
    let w_out = (w + 2 * pw - kw) / sw + 1;
    assert_eq!(h_out, 4);
    assert_eq!(w_out, 4);

    // Draw order (x, weights, dy) is kept so the seeded values stay the same.
    let mut rng = Philox4x32::new(7);
    let mut x = vec![0f32; n * c_in * h * w];
    rng.fill_normal(&mut x);
    let mut wt = vec![0f32; c_out * c_in * kh * kw];
    rng.fill_normal(&mut wt);
    let mut dy = vec![0f32; n * c_out * h_out * w_out];
    rng.fill_normal(&mut dy);

    let dtype = DType::F32;
    let mut graph = Graph::new("conv_bwi");
    let dy_in = graph.input("dy", Shape::new(&[n, c_out, h_out, w_out], dtype));
    let w_in = graph.input("w", Shape::new(&[c_out, c_in, kh, kw], dtype));
    let dx = graph.conv2d_backward_input(
        dy_in,
        w_in,
        Shape::new(&[n, c_in, h, w], dtype),
        vec![kh, kw],
        vec![sh, sw],
        vec![ph, pw],
        vec![1, 1],
        1,
    );
    graph.set_outputs(vec![dx]);
    let analytical = run_graph(&graph, &[(dy_in, &dy), (w_in, &wt)], dx, n * c_in * h * w);

    // Naive reference forward convolution over `input` with the fixed weights `wt`.
    let forward = |input: &[f32]| -> Vec<f32> {
        let mut out = vec![0f32; n * c_out * h_out * w_out];
        for ni in 0..n {
            for co in 0..c_out {
                for ho in 0..h_out {
                    for wo in 0..w_out {
                        let mut acc = 0f32;
                        for ci in 0..c_in {
                            let x_plane = (ni * c_in + ci) * h * w;
                            let w_plane = (co * c_in + ci) * kh * kw;
                            for ki in 0..kh {
                                for kj in 0..kw {
                                    // Taps that fall into the top/left padding are skipped.
                                    let (Some(hi), Some(wi)) = (
                                        (ho * sh + ki).checked_sub(ph),
                                        (wo * sw + kj).checked_sub(pw),
                                    ) else {
                                        continue;
                                    };
                                    // Likewise for the bottom/right padding.
                                    if hi >= h || wi >= w {
                                        continue;
                                    }
                                    let xv = input[x_plane + hi * w + wi];
                                    let wv = wt[w_plane + ki * kw + kj];
                                    acc += xv * wv;
                                }
                            }
                        }
                        out[(ni * c_out + co) * h_out * w_out + ho * w_out + wo] = acc;
                    }
                }
            }
        }
        out
    };

    // Central differences: perturb one input element at a time and restore it afterwards.
    let eps = 1e-3f32;
    let mut numerical = Vec::with_capacity(x.len());
    for i in 0..x.len() {
        let centre = x[i];
        x[i] = centre + eps;
        let plus = dot(&forward(&x), &dy);
        x[i] = centre - eps;
        let minus = dot(&forward(&x), &dy);
        x[i] = centre;
        numerical.push((plus - minus) / (2.0 * eps));
    }

    for (i, (a, n)) in analytical.iter().zip(&numerical).enumerate() {
        assert!(
            (a - n).abs() < 5e-3,
            "conv_bw_input[{i}]: analytical {a} vs numerical {n}"
        );
    }
}
15483
/// Checks `conv2d_backward_weight` against a central-difference gradient of
/// `dot(conv(x, w), dy)` with respect to `w`, where `conv` is a naive reference
/// convolution (2x2x4x4 input, 2 output channels, 3x3 kernel, stride 1, no padding).
#[test]
fn conv2d_backward_weight_matches_numerical_gradient() {
    use rlx_ir::Philox4x32;

    /// Inner product of two slices (truncated to the shorter one).
    fn dot(a: &[f32], b: &[f32]) -> f32 {
        a.iter().zip(b).map(|(&u, &v)| u * v).sum()
    }

    let (n, c_in, h, w) = (2usize, 2usize, 4usize, 4usize);
    let (c_out, kh, kw) = (2usize, 3usize, 3usize);
    let (ph, pw) = (0usize, 0usize);
    let (sh, sw) = (1usize, 1usize);
    let h_out = (h + 2 * ph - kh) / sh + 1;
    let w_out = (w + 2 * pw - kw) / sw + 1;

    // Draw order (x, weights, dy) is kept so the seeded values stay the same.
    let mut rng = Philox4x32::new(11);
    let mut x = vec![0f32; n * c_in * h * w];
    rng.fill_normal(&mut x);
    let mut wt = vec![0f32; c_out * c_in * kh * kw];
    rng.fill_normal(&mut wt);
    let mut dy = vec![0f32; n * c_out * h_out * w_out];
    rng.fill_normal(&mut dy);

    let dtype = DType::F32;
    let mut graph = Graph::new("conv_bww");
    let x_id = graph.input("x", Shape::new(&[n, c_in, h, w], dtype));
    let dy_id = graph.input("dy", Shape::new(&[n, c_out, h_out, w_out], dtype));
    let dw_id = graph.conv2d_backward_weight(
        x_id,
        dy_id,
        Shape::new(&[c_out, c_in, kh, kw], dtype),
        vec![kh, kw],
        vec![sh, sw],
        vec![ph, pw],
        vec![1, 1],
        1,
    );
    graph.set_outputs(vec![dw_id]);
    let analytical = run_graph(
        &graph,
        &[(x_id, &x), (dy_id, &dy)],
        dw_id,
        c_out * c_in * kh * kw,
    );

    // Naive reference forward convolution of the fixed input `x` with `weights`
    // (stride 1, no padding, so the input index is simply output index + tap).
    let forward = |weights: &[f32]| -> Vec<f32> {
        let mut out = vec![0f32; n * c_out * h_out * w_out];
        for ni in 0..n {
            for co in 0..c_out {
                for ho in 0..h_out {
                    for wo in 0..w_out {
                        let mut acc = 0f32;
                        for ci in 0..c_in {
                            let x_plane = (ni * c_in + ci) * h * w;
                            let w_plane = (co * c_in + ci) * kh * kw;
                            for ki in 0..kh {
                                for kj in 0..kw {
                                    let xv = x[x_plane + (ho + ki) * w + (wo + kj)];
                                    let wv = weights[w_plane + ki * kw + kj];
                                    acc += xv * wv;
                                }
                            }
                        }
                        out[(ni * c_out + co) * h_out * w_out + ho * w_out + wo] = acc;
                    }
                }
            }
        }
        out
    };

    // Central differences: perturb one weight at a time and restore it afterwards.
    let eps = 1e-3f32;
    let mut numerical = Vec::with_capacity(wt.len());
    for i in 0..wt.len() {
        let centre = wt[i];
        wt[i] = centre + eps;
        let plus = dot(&forward(&wt), &dy);
        wt[i] = centre - eps;
        let minus = dot(&forward(&wt), &dy);
        wt[i] = centre;
        numerical.push((plus - minus) / (2.0 * eps));
    }

    for (i, (a, n)) in analytical.iter().zip(&numerical).enumerate() {
        assert!(
            (a - n).abs() < 5e-3,
            "conv_bw_weight[{i}]: analytical {a} vs numerical {n}"
        );
    }
}
15570
/// Compares `softmax_cross_entropy_with_logits` on a 3x3 batch against the reference
/// `logsumexp(row) - row[label]`, with labels supplied as f32 class indices.
#[test]
fn softmax_cross_entropy_matches_reference() {
    const CLASSES: usize = 3;
    let logits: Vec<f32> = vec![1.0, 2.0, 3.0, -1.0, 0.0, 4.0, 5.0, 5.0, 5.0];
    let labels: Vec<f32> = vec![2.0, 0.0, 1.0];

    let dtype = DType::F32;
    let mut graph = Graph::new("sce");
    let lg = graph.input("logits", Shape::new(&[3, 3], dtype));
    let lb = graph.input("labels", Shape::new(&[3], dtype));
    let loss = graph.softmax_cross_entropy_with_logits(lg, lb);
    graph.set_outputs(vec![loss]);
    let actual = run_graph(&graph, &[(lg, &logits), (lb, &labels)], loss, 3);

    // Per-row reference loss, using the max-shifted log-sum-exp.
    let expected: Vec<f32> = logits
        .chunks_exact(CLASSES)
        .zip(&labels)
        .map(|(row, &label)| {
            let m = row.iter().copied().fold(f32::NEG_INFINITY, f32::max);
            let sum: f32 = row.iter().map(|&v| (v - m).exp()).sum();
            let lse = m + sum.ln();
            lse - row[label as usize]
        })
        .collect();

    for (i, (a, e)) in actual.iter().zip(&expected).enumerate() {
        assert!((a - e).abs() < 1e-5, "sce loss[{i}]: {a} vs {e}");
    }
}
15602
/// Checks `softmax_cross_entropy_backward` against a central-difference gradient of
/// `dot(loss(logits), d_loss)` with respect to the logits (4 rows, 5 classes).
#[test]
fn softmax_cross_entropy_backward_matches_numerical_gradient() {
    use rlx_ir::Philox4x32;

    /// Inner product of two slices (truncated to the shorter one).
    fn dot(a: &[f32], b: &[f32]) -> f32 {
        a.iter().zip(b).map(|(&u, &v)| u * v).sum::<f32>()
    }

    let (n, c) = (4usize, 5usize);

    // Draw order (logits, then d_loss) is kept so the seeded values stay the same.
    let mut rng = Philox4x32::new(23);
    let mut logits = vec![0f32; n * c];
    rng.fill_normal(&mut logits);
    let labels: Vec<f32> = (0..n).map(|row| (row % c) as f32).collect();
    let mut d_loss = vec![0f32; n];
    rng.fill_normal(&mut d_loss);

    let dtype = DType::F32;
    let mut graph = Graph::new("sce_bw");
    let lg = graph.input("logits", Shape::new(&[n, c], dtype));
    let lb = graph.input("labels", Shape::new(&[n], dtype));
    let dl = graph.input("d_loss", Shape::new(&[n], dtype));
    let dlogits = graph.softmax_cross_entropy_backward(lg, lb, dl);
    graph.set_outputs(vec![dlogits]);
    let analytical = run_graph(
        &graph,
        &[(lg, &logits), (lb, &labels), (dl, &d_loss)],
        dlogits,
        n * c,
    );

    // Reference per-row loss: max-shifted log-sum-exp minus the labelled logit.
    let sce_loss = |z: &[f32]| -> Vec<f32> {
        z.chunks_exact(c)
            .zip(&labels)
            .map(|(row, &label)| {
                let m = row.iter().copied().fold(f32::NEG_INFINITY, f32::max);
                let sum: f32 = row.iter().map(|&v| (v - m).exp()).sum();
                (m + sum.ln()) - row[label as usize]
            })
            .collect()
    };

    // Central differences: perturb one logit at a time and restore it afterwards.
    let eps = 1e-3f32;
    let mut numerical = Vec::with_capacity(logits.len());
    for i in 0..logits.len() {
        let centre = logits[i];
        logits[i] = centre + eps;
        let plus = dot(&sce_loss(&logits), &d_loss);
        logits[i] = centre - eps;
        let minus = dot(&sce_loss(&logits), &d_loss);
        logits[i] = centre;
        numerical.push((plus - minus) / (2.0 * eps));
    }

    for (i, (a, num)) in analytical.iter().zip(&numerical).enumerate() {
        assert!(
            (a - num).abs() < 5e-3,
            "sce_bw[{i}]: analytical {a} vs numerical {num}"
        );
    }
}
15659
15660 fn fill_constants_into_arena(graph: &Graph, arena: &mut crate::arena::Arena) {
15673 for node in graph.nodes() {
15674 if let Op::Constant { data } = &node.op
15675 && arena.has_buffer(node.id)
15676 && !data.is_empty()
15677 {
15678 let buf = arena.slice_mut(node.id);
15679 let n_floats = data.len() / 4;
15680 let n = buf.len().min(n_floats);
15681 for i in 0..n {
15682 let bytes = [
15683 data[i * 4],
15684 data[i * 4 + 1],
15685 data[i * 4 + 2],
15686 data[i * 4 + 3],
15687 ];
15688 buf[i] = f32::from_le_bytes(bytes);
15689 }
15690 }
15691 }
15692 }
15693
15694 fn prepare(
15698 graph: &Graph,
15699 seed_inputs: &[(NodeId, &[f32])],
15700 ) -> (ThunkSchedule, crate::arena::Arena) {
15701 let plan = rlx_opt::memory::plan_memory(graph);
15702 let mut arena = crate::arena::Arena::from_plan(plan);
15703 let sched = compile_thunks(graph, &arena);
15704 fill_constants_into_arena(graph, &mut arena);
15705 for &(id, data) in seed_inputs {
15706 let off = arena.byte_offset(id);
15707 let buf = arena.raw_buf_mut();
15708 unsafe {
15709 let p = buf.as_mut_ptr().add(off) as *mut f32;
15710 for (i, &v) in data.iter().enumerate() {
15711 *p.add(i) = v;
15712 }
15713 }
15714 }
15715 (sched, arena)
15716 }
15717
15718 fn read_arena(arena: &crate::arena::Arena, id: NodeId, len: usize) -> Vec<f32> {
15719 let off = arena.byte_offset(id);
15720 unsafe {
15721 let p = arena.raw_buf().as_ptr().add(off) as *const f32;
15722 (0..len).map(|i| *p.add(i)).collect()
15723 }
15724 }
15725
15726 fn write_arena(arena: &mut crate::arena::Arena, id: NodeId, data: &[f32]) {
15727 let off = arena.byte_offset(id);
15728 let buf = arena.raw_buf_mut();
15729 unsafe {
15730 let p = buf.as_mut_ptr().add(off) as *mut f32;
15731 for (i, &v) in data.iter().enumerate() {
15732 *p.add(i) = v;
15733 }
15734 }
15735 }
15736
15737 fn prepare_f64(
15739 graph: &Graph,
15740 seed_inputs: &[(NodeId, &[f64])],
15741 ) -> (ThunkSchedule, crate::arena::Arena) {
15742 let plan = rlx_opt::memory::plan_memory(graph);
15743 let mut arena = crate::arena::Arena::from_plan(plan);
15744 let sched = compile_thunks(graph, &arena);
15745 fill_constants_into_arena(graph, &mut arena);
15746 for &(id, data) in seed_inputs {
15747 let off = arena.byte_offset(id);
15748 let buf = arena.raw_buf_mut();
15749 unsafe {
15750 let p = buf.as_mut_ptr().add(off) as *mut f64;
15751 for (i, &v) in data.iter().enumerate() {
15752 *p.add(i) = v;
15753 }
15754 }
15755 }
15756 (sched, arena)
15757 }
15758
15759 fn read_arena_f64(arena: &crate::arena::Arena, id: NodeId, len: usize) -> Vec<f64> {
15760 let off = arena.byte_offset(id);
15761 unsafe {
15762 let p = arena.raw_buf().as_ptr().add(off) as *const f64;
15763 (0..len).map(|i| *p.add(i)).collect()
15764 }
15765 }
15766
/// Solves the 2x2 system `[[2, 1], [1, 3]] x = [5, 10]` through the compiled graph
/// and expects `x = [1, 3]` to within 1e-12.
#[test]
fn dense_solve_f64_end_to_end() {
    let a_data: [f64; 4] = [2.0, 1.0, 1.0, 3.0];
    let b_data: [f64; 2] = [5.0, 10.0];
    let want: [f64; 2] = [1.0, 3.0];

    let mut graph = Graph::new("solve_e2e");
    let a = graph.input("A", Shape::new(&[2, 2], DType::F64));
    let b = graph.input("b", Shape::new(&[2], DType::F64));
    let x = graph.dense_solve(a, b, Shape::new(&[2], DType::F64));
    graph.set_outputs(vec![x]);

    let (schedule, mut arena) = prepare_f64(&graph, &[(a, &a_data), (b, &b_data)]);
    execute_thunks(&schedule, arena.raw_buf_mut());
    let got = read_arena_f64(&arena, x, 2);

    for (i, (g, w)) in got.iter().zip(&want).enumerate() {
        assert!((g - w).abs() < 1e-12, "x[{i}] = {} (expected {})", g, w);
    }
}
15800
/// Solves a 5x5 tridiagonal system (2 on the diagonal, -1 on both off-diagonals) with
/// right-hand side `[1, 2, 3, 4, 5]` and verifies the result through the residual:
/// `A * x` must reproduce `b` to within 1e-10.
#[test]
fn dense_solve_f64_5x5_laplacian() {
    let n = 5usize;

    let mut graph = Graph::new("solve_5x5");
    let a = graph.input("A", Shape::new(&[n, n], DType::F64));
    let b = graph.input("b", Shape::new(&[n], DType::F64));
    let x = graph.dense_solve(a, b, Shape::new(&[n], DType::F64));
    graph.set_outputs(vec![x]);

    // Fill the matrix one row at a time.
    let mut a_data = vec![0.0_f64; n * n];
    for (i, row) in a_data.chunks_exact_mut(n).enumerate() {
        row[i] = 2.0;
        if i > 0 {
            row[i - 1] = -1.0;
        }
        if i + 1 < n {
            row[i + 1] = -1.0;
        }
    }
    let b_data: Vec<f64> = (1..=n).map(|v| v as f64).collect();

    let (schedule, mut arena) = prepare_f64(&graph, &[(a, &a_data), (b, &b_data)]);
    execute_thunks(&schedule, arena.raw_buf_mut());
    let got = read_arena_f64(&arena, x, n);

    // residual[i] = sum_j A[i][j] * x[j], accumulated left to right from 0.0.
    let residual: Vec<f64> = a_data
        .chunks_exact(n)
        .map(|row| {
            row.iter()
                .zip(&got)
                .fold(0.0_f64, |acc, (coef, xj)| acc + coef * xj)
        })
        .collect();

    for (i, (r, rhs)) in residual.iter().zip(&b_data).enumerate() {
        assert!(
            (r - rhs).abs() < 1e-10,
            "row {i}: residual {} vs b {}",
            r,
            rhs
        );
    }
}
15847
15848 #[test]
15867 fn hello_resistor_gradient_end_to_end() {
15868 use rlx_opt::autodiff::grad_with_loss;
15869 let n = 3usize;
15870
15871 let mut g = Graph::new("hello_resistor");
15873 let a = g.param("A", Shape::new(&[n, n], DType::F64));
15874 let b = g.input("b", Shape::new(&[n], DType::F64));
15875 let x = g.dense_solve(a, b, Shape::new(&[n], DType::F64));
15876 let loss = g.reduce(
15877 x,
15878 ReduceOp::Sum,
15879 vec![0],
15880 false,
15881 Shape::new(&[1], DType::F64),
15882 );
15883 g.set_outputs(vec![loss]);
15884
15885 let bwd = grad_with_loss(&g, &[a, b]);
15887 assert_eq!(bwd.outputs.len(), 3, "expect [loss, dA, db]");
15888
15889 let find_by_name = |graph: &Graph, want: &str| -> NodeId {
15893 for node in graph.nodes() {
15894 let name = match &node.op {
15895 rlx_ir::Op::Input { name } => Some(name.as_str()),
15896 rlx_ir::Op::Param { name } => Some(name.as_str()),
15897 _ => None,
15898 };
15899 if name == Some(want) {
15900 return node.id;
15901 }
15902 }
15903 panic!("no node named {want:?} in bwd graph");
15904 };
15905 let a_bwd = find_by_name(&bwd, "A");
15906 let b_bwd = find_by_name(&bwd, "b");
15907 let d_out_bwd = find_by_name(&bwd, "d_output");
15908
15909 let a_data = [2.0, 1.0, 0.0, 1.0, 3.0, 1.0, 0.0, 1.0, 2.0_f64];
15913 let b_data = [1.0, 2.0, 3.0_f64];
15914 let d_output = [1.0_f64]; let (sched, mut arena) = prepare_f64(
15918 &bwd,
15919 &[(a_bwd, &a_data), (b_bwd, &b_data), (d_out_bwd, &d_output)],
15920 );
15921 execute_thunks(&sched, arena.raw_buf_mut());
15922
15923 let loss_out = read_arena_f64(&arena, bwd.outputs[0], 1);
15924 let da_out = read_arena_f64(&arena, bwd.outputs[1], n * n);
15925 let db_out = read_arena_f64(&arena, bwd.outputs[2], n);
15926
15927 let x_ref = {
15930 let mut a = a_data;
15931 let mut b = b_data;
15932 let info = crate::blas::dgesv(&mut a, &mut b, n, 1);
15933 assert_eq!(info, 0);
15934 b
15935 };
15936 let loss_ref: f64 = x_ref.iter().sum();
15937 let db_ref = {
15939 let mut at = [0.0_f64; 9];
15940 for i in 0..n {
15941 for j in 0..n {
15942 at[i * n + j] = a_data[j * n + i];
15943 }
15944 }
15945 let mut ones = [1.0_f64; 3];
15946 let info = crate::blas::dgesv(&mut at, &mut ones, n, 1);
15947 assert_eq!(info, 0);
15948 ones
15949 };
15950 let mut da_ref = [0.0_f64; 9];
15952 for i in 0..n {
15953 for j in 0..n {
15954 da_ref[i * n + j] = -db_ref[i] * x_ref[j];
15955 }
15956 }
15957
15958 assert!(
15960 (loss_out[0] - loss_ref).abs() < 1e-10,
15961 "loss: got {}, want {}",
15962 loss_out[0],
15963 loss_ref
15964 );
15965 for i in 0..n {
15966 assert!(
15967 (db_out[i] - db_ref[i]).abs() < 1e-10,
15968 "db[{i}]: got {}, want {}",
15969 db_out[i],
15970 db_ref[i]
15971 );
15972 }
15973 for i in 0..n * n {
15974 assert!(
15975 (da_out[i] - da_ref[i]).abs() < 1e-10,
15976 "dA[{i}]: got {}, want {}",
15977 da_out[i],
15978 da_ref[i]
15979 );
15980 }
15981
15982 let h = 1e-6_f64;
15985 for k in 0..n {
15986 let mut bp = b_data;
15987 bp[k] += h;
15988 let mut bm = b_data;
15989 bm[k] -= h;
15990 let lp = {
15991 let mut ac = a_data;
15992 let info = crate::blas::dgesv(&mut ac, &mut bp, n, 1);
15993 assert_eq!(info, 0);
15994 bp.iter().sum::<f64>()
15995 };
15996 let lm = {
15997 let mut ac = a_data;
15998 let info = crate::blas::dgesv(&mut ac, &mut bm, n, 1);
15999 assert_eq!(info, 0);
16000 bm.iter().sum::<f64>()
16001 };
16002 let fd = (lp - lm) / (2.0 * h);
16003 assert!(
16004 (db_out[k] - fd).abs() < 1e-7,
16005 "FD mismatch on db[{k}]: AD={} FD={}",
16006 db_out[k],
16007 fd
16008 );
16009 }
16010 }
16011
/// Scans the body `carry + carry * 0.1` for 10 steps starting from all ones; every
/// component of the final carry must equal `1.1^10` to within 1e-12.
#[test]
fn scan_geometric_growth_f64() {
    let n = 3usize;
    let length = 10u32;
    let vec_shape = || Shape::new(&[n], DType::F64);

    // Body: next = carry + carry * 0.1, with the factor stored as an f64 constant.
    let mut body = Graph::new("scan_body");
    let carry = body.input("carry", vec_shape());
    let scale_bytes: Vec<u8> = 0.1_f64.to_le_bytes().repeat(n);
    let scale = body.add_node(Op::Constant { data: scale_bytes }, vec![], vec_shape());
    let scaled = body.binary(BinaryOp::Mul, carry, scale, vec_shape());
    let next = body.binary(BinaryOp::Add, carry, scaled, vec_shape());
    body.set_outputs(vec![next]);

    let mut outer = Graph::new("scan_outer");
    let init = outer.input("init", vec_shape());
    let final_carry = outer.scan(init, body, length);
    outer.set_outputs(vec![final_carry]);

    let init_data = vec![1.0_f64; n];
    let (schedule, mut arena) = prepare_f64(&outer, &[(init, &init_data)]);
    execute_thunks(&schedule, arena.raw_buf_mut());
    let got = read_arena_f64(&arena, final_carry, n);

    let want: f64 = 1.1_f64.powi(length as i32);
    for (i, value) in got.iter().enumerate() {
        assert!(
            (value - want).abs() < 1e-12,
            "got[{i}] = {} want {}",
            value,
            want
        );
    }
}
16054
/// Scans `carry + x_t` over a `[4, 3]` xs tensor holding 1..=12, starting from zero;
/// the final carry must equal the column-wise sum of the xs rows.
#[test]
fn scan_with_xs_cumulative_sum() {
    let n = 3usize;
    let length = 4u32;
    let steps = length as usize;
    let vec_shape = || Shape::new(&[n], DType::F64);

    // Body: next = carry + x_t.
    let mut body = Graph::new("cumsum_body");
    let carry = body.input("carry", vec_shape());
    let x_t = body.input("x_t", vec_shape());
    let next = body.binary(BinaryOp::Add, carry, x_t, vec_shape());
    body.set_outputs(vec![next]);

    let mut outer = Graph::new("cumsum_outer");
    let init = outer.input("init", vec_shape());
    let xs = outer.input("xs", Shape::new(&[steps, n], DType::F64));
    let final_carry = outer.scan_with_xs(init, &[xs], body, length);
    outer.set_outputs(vec![final_carry]);

    let init_data = vec![0.0_f64; n];
    let xs_data: Vec<f64> = (1..=steps * n).map(|v| v as f64).collect();
    let (schedule, mut arena) = prepare_f64(&outer, &[(init, &init_data), (xs, &xs_data)]);
    execute_thunks(&schedule, arena.raw_buf_mut());
    let got = read_arena_f64(&arena, final_carry, n);

    // Reference: fold the rows of xs into the initial carry, in time order.
    let mut want = init_data.clone();
    for step in xs_data.chunks_exact(n) {
        for (acc, x) in want.iter_mut().zip(step) {
            *acc += x;
        }
    }

    for (i, (g, w)) in got.iter().zip(&want).enumerate() {
        assert!((g - w).abs() < 1e-12, "got[{i}] = {} want {}", g, w);
    }
}
16103
/// Scan whose body performs one driven implicit step: `next = solve(M, carry + drive)`
/// with the constant tridiagonal matrix `M` (diagonal `1 + 2*dt`, off-diagonals `-dt`).
///
/// The drive is a single unit impulse in the first component of the first step. The
/// final carry is compared with a reference loop that performs the same additions and
/// solves with `crate::blas::dgesv`.
///
/// The reference solve's status is asserted to be zero (as in
/// `hello_resistor_gradient_end_to_end`), so a failed factorisation cannot silently
/// produce a bogus reference.
#[test]
fn scan_with_xs_be_with_drive() {
    let n = 3usize;
    let length = 4u32;
    let steps = length as usize;
    let dt = 0.1_f64;

    // System matrix M, filled row by row.
    let mut m_data = vec![0.0_f64; n * n];
    for (i, row) in m_data.chunks_exact_mut(n).enumerate() {
        row[i] = 1.0 + dt * 2.0;
        if i > 0 {
            row[i - 1] = -dt;
        }
        if i + 1 < n {
            row[i + 1] = -dt;
        }
    }
    let m_bytes: Vec<u8> = m_data.iter().flat_map(|v| v.to_le_bytes()).collect();

    // Body: next = solve(M, carry + drive).
    let mut body = Graph::new("be_drive_body");
    let carry = body.input("carry", Shape::new(&[n], DType::F64));
    let drive = body.input("drive", Shape::new(&[n], DType::F64));
    let m = body.add_node(
        Op::Constant { data: m_bytes },
        vec![],
        Shape::new(&[n, n], DType::F64),
    );
    let driven = body.binary(BinaryOp::Add, carry, drive, Shape::new(&[n], DType::F64));
    let next = body.dense_solve(m, driven, Shape::new(&[n], DType::F64));
    body.set_outputs(vec![next]);

    let mut g = Graph::new("be_drive_outer");
    let init = g.input("init", Shape::new(&[n], DType::F64));
    let xs = g.input("xs", Shape::new(&[steps, n], DType::F64));
    let final_carry = g.scan_with_xs(init, &[xs], body, length);
    g.set_outputs(vec![final_carry]);

    let init_data = vec![0.0_f64; n];
    // Unit impulse on component 0 at step 0; every other drive sample is zero.
    let mut xs_data = vec![0.0_f64; steps * n];
    xs_data[0] = 1.0;

    let (sched, mut arena) = prepare_f64(&g, &[(init, &init_data), (xs, &xs_data)]);
    execute_thunks(&sched, arena.raw_buf_mut());
    let got = read_arena_f64(&arena, final_carry, n);

    // Reference: add the drive, then solve in place, once per step.
    // `init_data` is not needed afterwards, so it is moved instead of cloned.
    let mut x = init_data;
    for (t, drive_t) in xs_data.chunks_exact(n).enumerate() {
        for (xj, dj) in x.iter_mut().zip(drive_t) {
            *xj += dj;
        }
        // dgesv overwrites its matrix argument, so each step works on a fresh copy.
        let mut a_copy = m_data.clone();
        let info = crate::blas::dgesv(&mut a_copy, &mut x, n, 1);
        assert_eq!(info, 0, "reference dgesv failed at step {t}");
    }

    for i in 0..n {
        assert!(
            (got[i] - x[i]).abs() < 1e-12,
            "got[{i}] = {} ref {}",
            got[i],
            x[i]
        );
    }
}
16171
/// Gradient check for `loss = sum(batched_dense_solve(A, b))` with 4 independent
/// 3x3 systems.
///
/// For every batch entry the backward-graph outputs are compared with references
/// built from `crate::blas::dgesv`:
/// * `db_ref = g` where `Aᵀ g = 1`,
/// * `dA_ref[i][j] = -g[i] * x[j]` where `A x = b`.
///
/// Both reference solves assert a zero dgesv status (as in
/// `hello_resistor_gradient_end_to_end`), so a failed factorisation cannot silently
/// produce a bogus reference.
#[test]
fn batched_dense_solve_gradient_matches_per_batch_analytic() {
    use rlx_opt::autodiff::grad_with_loss;
    let n = 3usize;
    let batch = 4usize;

    let mut g = Graph::new("bds_grad");
    let a = g.param("A", Shape::new(&[batch, n, n], DType::F64));
    let b = g.input("b", Shape::new(&[batch, n], DType::F64));
    let x = g.batched_dense_solve(a, b, Shape::new(&[batch, n], DType::F64));
    let loss = g.reduce(
        x,
        ReduceOp::Sum,
        vec![0, 1],
        false,
        Shape::new(&[1], DType::F64),
    );
    g.set_outputs(vec![loss]);

    let bwd = grad_with_loss(&g, &[a, b]);

    // Node ids are looked up by name in the backward graph.
    let find = |graph: &Graph, want: &str| -> NodeId {
        for node in graph.nodes() {
            let name = match &node.op {
                Op::Input { name } | Op::Param { name } => Some(name.as_str()),
                _ => None,
            };
            if name == Some(want) {
                return node.id;
            }
        }
        panic!("no node named {want}");
    };
    let a_id = find(&bwd, "A");
    let b_id = find(&bwd, "b");
    let d_out_id = find(&bwd, "d_output");

    // Small random entries plus `1 + n` on the diagonal, then a random right-hand
    // side. The draw order (row of A, ..., then b, per batch) is part of the fixture.
    let mut rng = rlx_ir::Philox4x32::new(0x57e1_u64);
    let mut a_data = vec![0.0_f64; batch * n * n];
    let mut b_data = vec![0.0_f64; batch * n];
    for bi in 0..batch {
        for i in 0..n {
            for j in 0..n {
                a_data[bi * n * n + i * n + j] = rng.next_f32() as f64 * 0.1;
            }
            a_data[bi * n * n + i * n + i] += 1.0 + n as f64;
        }
        for i in 0..n {
            b_data[bi * n + i] = rng.next_f32() as f64;
        }
    }
    let d_seed = [1.0_f64];

    let (sched, mut arena) = prepare_f64(
        &bwd,
        &[(a_id, &a_data), (b_id, &b_data), (d_out_id, &d_seed)],
    );
    execute_thunks(&sched, arena.raw_buf_mut());
    let da_out = read_arena_f64(&arena, bwd.outputs[1], batch * n * n);
    let db_out = read_arena_f64(&arena, bwd.outputs[2], batch * n);

    for bi in 0..batch {
        let a_slice = &a_data[bi * n * n..(bi + 1) * n * n];

        // x_ref solves A x = b; dgesv overwrites both arguments, so work on copies.
        let mut a_copy = a_slice.to_vec();
        let mut x_ref = b_data[bi * n..(bi + 1) * n].to_vec();
        let info = crate::blas::dgesv(&mut a_copy, &mut x_ref, n, 1);
        assert_eq!(info, 0, "batch {bi}: reference solve A x = b failed");

        // db_ref solves A^T g = 1.
        let mut at = vec![0.0_f64; n * n];
        for i in 0..n {
            for j in 0..n {
                at[i * n + j] = a_slice[j * n + i];
            }
        }
        let mut db_ref = vec![1.0_f64; n];
        let info = crate::blas::dgesv(&mut at, &mut db_ref, n, 1);
        assert_eq!(info, 0, "batch {bi}: reference solve A^T g = 1 failed");

        for i in 0..n {
            let got = db_out[bi * n + i];
            assert!(
                (got - db_ref[i]).abs() < 1e-10,
                "batch {bi}, db[{i}]: got {got} ref {}",
                db_ref[i]
            );
        }
        for i in 0..n {
            for j in 0..n {
                let got = da_out[bi * n * n + i * n + j];
                let want = -db_ref[i] * x_ref[j];
                assert!(
                    (got - want).abs() < 1e-10,
                    "batch {bi}, dA[{i},{j}]: got {got} ref {want}"
                );
            }
        }
    }
}
16277
#[test]
/// The gradient w.r.t. `init` obtained through `scan_checkpointed` must equal
/// the one obtained through plain `scan` over an identical body.
///
/// The body multiplies the carry elementwise by 1.05; both outer graphs run it
/// for `length` steps and sum the final carry into the loss.
fn scan_checkpointed_grad_matches_plain_scan_grad() {
    use rlx_opt::autodiff::grad_with_loss;
    let n = 2usize;
    let length = 6u32;

    // The scan builders take the body graph by value, so build a fresh one per call.
    let make_body = || {
        let vec_shape = || Shape::new(&[n], DType::F64);
        let mut step = Graph::new("ck_body");
        let carry = step.input("carry", vec_shape());
        let scale_bytes: Vec<u8> = 1.05_f64.to_le_bytes().repeat(n);
        let scale = step.add_node(Op::Constant { data: scale_bytes }, vec![], vec_shape());
        let next = step.binary(BinaryOp::Mul, carry, scale, vec_shape());
        step.set_outputs(vec![next]);
        step
    };

    // Reference pipeline: plain scan -> sum -> backward graph.
    let mut g_plain = Graph::new("ck_plain");
    let init_p = g_plain.input("init", Shape::new(&[n], DType::F64));
    let final_p = g_plain.scan(init_p, make_body(), length);
    let loss_p = g_plain.reduce(
        final_p,
        ReduceOp::Sum,
        vec![0],
        false,
        Shape::new(&[1], DType::F64),
    );
    g_plain.set_outputs(vec![loss_p]);
    let bwd_p = grad_with_loss(&g_plain, &[init_p]);

    // Same pipeline, but the scan is requested with 2 checkpoints.
    let mut g_ck = Graph::new("ck_ckpt");
    let init_c = g_ck.input("init", Shape::new(&[n], DType::F64));
    let final_c = g_ck.scan_checkpointed(init_c, make_body(), length, 2);
    let loss_c = g_ck.reduce(
        final_c,
        ReduceOp::Sum,
        vec![0],
        false,
        Shape::new(&[1], DType::F64),
    );
    g_ck.set_outputs(vec![loss_c]);
    let bwd_c = grad_with_loss(&g_ck, &[init_c]);

    // Look up an Input/Param node by its name; panics when absent.
    let find = |graph: &Graph, want: &str| -> NodeId {
        graph
            .nodes()
            .iter()
            .find_map(|node| match &node.op {
                Op::Input { name } | Op::Param { name } if name.as_str() == want => Some(node.id),
                _ => None,
            })
            .unwrap_or_else(|| panic!("no {want}"))
    };

    let init_data = vec![0.5_f64, -0.5];
    let d_seed = [1.0_f64];

    let (s_p, mut a_p) = prepare_f64(
        &bwd_p,
        &[
            (find(&bwd_p, "init"), &init_data),
            (find(&bwd_p, "d_output"), &d_seed),
        ],
    );
    execute_thunks(&s_p, a_p.raw_buf_mut());
    let dinit_p = read_arena_f64(&a_p, bwd_p.outputs[1], n);

    let (s_c, mut a_c) = prepare_f64(
        &bwd_c,
        &[
            (find(&bwd_c, "init"), &init_data),
            (find(&bwd_c, "d_output"), &d_seed),
        ],
    );
    execute_thunks(&s_c, a_c.raw_buf_mut());
    let dinit_c = read_arena_f64(&a_c, bwd_c.outputs[1], n);

    // Elementwise agreement between the two gradients.
    for (i, (plain, ckpt)) in dinit_p[..n].iter().zip(&dinit_c[..n]).enumerate() {
        assert!(
            (plain - ckpt).abs() < 1e-12,
            "dinit[{i}]: plain={} checkpointed={}",
            plain,
            ckpt
        );
    }
}
16375
#[test]
/// Compares two hand-assembled scan backward passes over the same add-one body:
/// one fed the trajectory produced by `scan_trajectory`, the other fed a
/// `[k, n]` checkpoint buffer from an `Op::Scan` with `num_checkpoints: k`.
/// Both must produce the same gradient w.r.t. `init`.
fn recursive_checkpointing_matches_full_trajectory() {
    let n = 2usize;
    let length = 4u32;
    let steps = length as usize;

    // Forward body: next = carry + 1 (elementwise).
    let build_body = || -> Graph {
        let mut step = Graph::new("rc_body");
        let carry = step.input("carry", Shape::new(&[n], DType::F64));
        let ones_bytes: Vec<u8> = 1.0_f64.to_le_bytes().repeat(n);
        let ones = step.add_node(
            Op::Constant { data: ones_bytes },
            vec![],
            Shape::new(&[n], DType::F64),
        );
        let next = step.binary(BinaryOp::Add, carry, ones, Shape::new(&[n], DType::F64));
        step.set_outputs(vec![next]);
        step
    };

    // VJP of the body w.r.t. its first Input node.
    let body_vjp_for = || -> Graph {
        use rlx_opt::autodiff::grad;
        let fwd = build_body();
        let carry_id = fwd
            .nodes()
            .iter()
            .find_map(|node| matches!(node.op, Op::Input { .. }).then_some(node.id))
            .unwrap();
        grad(&fwd, &[carry_id])
    };

    // Reference graph: trajectory scan followed by the builder's scan_backward.
    let mut g_full = Graph::new("rc_outer_full");
    let init_full = g_full.input("init", Shape::new(&[n], DType::F64));
    let traj_full_id = g_full.scan_trajectory(init_full, build_body(), length);
    let upstream_full = g_full.input("upstream", Shape::new(&[steps, n], DType::F64));
    let vjp_full = body_vjp_for();
    let dinit_full_id = g_full.scan_backward(
        init_full,
        traj_full_id,
        upstream_full,
        &[],
        vjp_full,
        length,
        true,
        Shape::new(&[n], DType::F64),
    );
    g_full.set_outputs(vec![dinit_full_id]);

    // Checkpointed graph: the ops are spelled out by hand so that
    // `num_checkpoints` and `forward_body` can be set explicitly.
    let k = 2u32;
    let mut g_rec = Graph::new("rc_outer_rec");
    let init_rec = g_rec.input("init", Shape::new(&[n], DType::F64));
    let scan_op = Op::Scan {
        body: Box::new(build_body()),
        length,
        save_trajectory: true,
        num_bcast: 0,
        num_xs: 0,
        num_checkpoints: k,
    };
    let traj_rec_id = g_rec.add_node(
        scan_op,
        vec![init_rec],
        Shape::new(&[k as usize, n], DType::F64),
    );
    let upstream_rec = g_rec.input("upstream", Shape::new(&[steps, n], DType::F64));
    let backward_op = Op::ScanBackward {
        body_vjp: Box::new(body_vjp_for()),
        length,
        save_trajectory: true,
        num_xs: 0,
        num_checkpoints: k,
        forward_body: Some(Box::new(build_body())),
    };
    let dinit_rec_id = g_rec.add_node(
        backward_op,
        vec![init_rec, traj_rec_id, upstream_rec],
        Shape::new(&[n], DType::F64),
    );
    g_rec.set_outputs(vec![dinit_rec_id]);

    let init_data = vec![0.5_f64, -0.5];
    let upstream_data: Vec<f64> = (0..steps * n).map(|i| (i as f64) * 0.1).collect();

    // Look up an Input node by its name; panics when absent.
    let find = |graph: &Graph, want: &str| -> NodeId {
        graph
            .nodes()
            .iter()
            .find(|node| matches!(&node.op, Op::Input { name } if name == want))
            .map(|node| node.id)
            .unwrap_or_else(|| panic!("no input {want}"))
    };

    let (s_full, mut a_full) = prepare_f64(
        &g_full,
        &[
            (find(&g_full, "init"), &init_data),
            (find(&g_full, "upstream"), &upstream_data),
        ],
    );
    execute_thunks(&s_full, a_full.raw_buf_mut());
    let dinit_full = read_arena_f64(&a_full, g_full.outputs[0], n);

    let (s_rec, mut a_rec) = prepare_f64(
        &g_rec,
        &[
            (find(&g_rec, "init"), &init_data),
            (find(&g_rec, "upstream"), &upstream_data),
        ],
    );
    execute_thunks(&s_rec, a_rec.raw_buf_mut());
    let dinit_rec = read_arena_f64(&a_rec, g_rec.outputs[0], n);

    for (i, (full, rec)) in dinit_full[..n].iter().zip(&dinit_rec[..n]).enumerate() {
        assert!(
            (full - rec).abs() < 1e-12,
            "i={i}: full={} rec={}",
            full,
            rec
        );
    }
}
16512
#[test]
/// Applies `vmap` over `init` to the backward graph of a scan-then-sum loss
/// and checks the batched `d init` two ways: against the constant 1.0, and
/// against an unbatched backward graph rebuilt and executed for every row.
fn vmap_of_grad_scan_matches_per_row_runs() {
    use rlx_opt::autodiff::grad_with_loss;
    use rlx_opt::vmap::vmap;
    let n = 2usize;
    let length = 3u32;
    let batch = 3usize;

    // Scan body shared by both variants: next = carry + 1 (elementwise).
    let add_one_body = |graph_name: &str| -> Graph {
        let mut step = Graph::new(graph_name);
        let carry = step.input("carry", Shape::new(&[n], DType::F64));
        let ones_bytes: Vec<u8> = 1.0_f64.to_le_bytes().repeat(n);
        let ones = step.add_node(
            Op::Constant { data: ones_bytes },
            vec![],
            Shape::new(&[n], DType::F64),
        );
        let next = step.binary(BinaryOp::Add, carry, ones, Shape::new(&[n], DType::F64));
        step.set_outputs(vec![next]);
        step
    };

    let body = add_one_body("scan_grad_body");

    let mut g = Graph::new("scan_grad_outer");
    let init = g.input("init", Shape::new(&[n], DType::F64));
    let final_x = g.scan(init, body, length);
    let loss = g.reduce(
        final_x,
        ReduceOp::Sum,
        vec![0],
        false,
        Shape::new(&[1], DType::F64),
    );
    g.set_outputs(vec![loss]);

    let bwd = grad_with_loss(&g, &[init]);
    let bg = vmap(&bwd, &["init"], batch);

    // Look up an Input/Param node by its name; panics when absent.
    let find = |graph: &Graph, want: &str| -> NodeId {
        graph
            .nodes()
            .iter()
            .find_map(|node| match &node.op {
                Op::Input { name } | Op::Param { name } if name.as_str() == want => Some(node.id),
                _ => None,
            })
            .unwrap_or_else(|| panic!("no node named {want}"))
    };
    let init_b = find(&bg, "init");
    let d_out_b = find(&bg, "d_output");

    let init_data: Vec<f64> = (0..batch * n).map(|i| (i as f64) * 0.5).collect();
    let d_seed = [1.0_f64];

    let (sched, mut arena) = prepare_f64(&bg, &[(init_b, &init_data), (d_out_b, &d_seed)]);
    execute_thunks(&sched, arena.raw_buf_mut());
    let dinit_b = read_arena_f64(&arena, bg.outputs[1], batch * n);

    // Every entry of the batched gradient is expected to be exactly 1.
    for (i, d) in dinit_b[..batch * n].iter().enumerate() {
        assert!(
            (d - 1.0).abs() < 1e-12,
            "dinit[{i}] = {} (expected 1.0)",
            d
        );
    }

    // Cross-check each batch row against its own unbatched run.
    for (bi, row) in init_data.chunks_exact(n).enumerate() {
        let mut g2 = Graph::new("per_row_grad");
        let init2 = g2.input("init", Shape::new(&[n], DType::F64));
        let body2 = add_one_body("per_row_body");
        let final2 = g2.scan(init2, body2, length);
        let loss2 = g2.reduce(
            final2,
            ReduceOp::Sum,
            vec![0],
            false,
            Shape::new(&[1], DType::F64),
        );
        g2.set_outputs(vec![loss2]);
        let bwd2 = grad_with_loss(&g2, &[init2]);
        let init2_id = find(&bwd2, "init");
        let d_out2_id = find(&bwd2, "d_output");
        let (s2, mut a2) = prepare_f64(&bwd2, &[(init2_id, row), (d_out2_id, &d_seed)]);
        execute_thunks(&s2, a2.raw_buf_mut());
        let row_dinit = read_arena_f64(&a2, bwd2.outputs[1], n);

        let batched_row = &dinit_b[bi * n..(bi + 1) * n];
        for (j, (&got, &want)) in batched_row.iter().zip(&row_dinit[..n]).enumerate() {
            assert!(
                (got - want).abs() < 1e-12,
                "row {bi}, j {j}: vmap'd={got} per-row={want}"
            );
        }
    }
}
16625
#[test]
/// Batches a `scan_with_xs` whose body adds the per-step input to the carry
/// (`vmap` over both `init` and `xs`) and compares each batch row of the
/// final carry with a plain-Rust running sum.
fn vmap_scan_cumulative_sum_matches_scalar_runs() {
    use rlx_opt::vmap::vmap;
    let n = 2usize;
    let length = 4u32;
    let batch = 3usize;
    let steps = length as usize;

    // Body: next = carry + x_t.
    let mut body = Graph::new("scan_body_cumsum");
    let carry = body.input("carry", Shape::new(&[n], DType::F64));
    let x_t = body.input("x_t", Shape::new(&[n], DType::F64));
    let next = body.binary(BinaryOp::Add, carry, x_t, Shape::new(&[n], DType::F64));
    body.set_outputs(vec![next]);

    let mut g = Graph::new("scan_outer_cumsum");
    let init = g.input("init", Shape::new(&[n], DType::F64));
    let xs = g.input("xs", Shape::new(&[steps, n], DType::F64));
    let final_carry = g.scan_with_xs(init, &[xs], body, length);
    g.set_outputs(vec![final_carry]);

    let bg = vmap(&g, &["init", "xs"], batch);

    let init_data: Vec<f64> = (0..batch * n).map(|i| (i + 1) as f64).collect();
    let xs_data: Vec<f64> = (0..batch * steps * n).map(|i| 0.1 * (i as f64)).collect();

    // Look up an Input node by its name; panics when absent.
    let find = |graph: &Graph, want: &str| -> NodeId {
        graph
            .nodes()
            .iter()
            .find(|node| matches!(&node.op, Op::Input { name } if name == want))
            .map(|node| node.id)
            .unwrap_or_else(|| panic!("no input {want}"))
    };
    let init_b = find(&bg, "init");
    let xs_b = find(&bg, "xs");
    let (sched, mut arena) = prepare_f64(&bg, &[(init_b, &init_data), (xs_b, &xs_data)]);
    execute_thunks(&sched, arena.raw_buf_mut());
    let batched_out = read_arena_f64(&arena, bg.outputs[0], batch * n);

    for bi in 0..batch {
        // Reference: start from this row's init and add each step's input in order.
        let mut x = init_data[bi * n..(bi + 1) * n].to_vec();
        let xs_row = &xs_data[bi * steps * n..(bi + 1) * steps * n];
        for step_in in xs_row.chunks_exact(n) {
            for (acc, dx) in x.iter_mut().zip(step_in) {
                *acc += dx;
            }
        }

        for (i, &got) in batched_out[bi * n..(bi + 1) * n].iter().enumerate() {
            assert!(
                (got - x[i]).abs() < 1e-12,
                "row {bi}, i {i}: got {got} ref {}",
                x[i]
            );
        }
    }
}
16698
#[test]
/// Batches a single `dense_solve` with `vmap` over both `A` and `b`, then
/// checks every batch row against a direct `crate::blas::dgesv` solve of the
/// same system.
fn vmap_dense_solve_matches_scalar_runs() {
    use rlx_opt::vmap::vmap;
    let n = 3usize;
    let batch = 4usize;

    let mut g = Graph::new("solve_forward");
    let a = g.input("A", Shape::new(&[n, n], DType::F64));
    let b = g.input("b", Shape::new(&[n], DType::F64));
    let x = g.dense_solve(a, b, Shape::new(&[n], DType::F64));
    g.set_outputs(vec![x]);

    let bg = vmap(&g, &["A", "b"], batch);

    let mut rng = rlx_ir::Philox4x32::new(0xb47c_u64);
    let mut a_data = vec![0.0_f64; batch * n * n];
    let mut b_data = vec![0.0_f64; batch * n];
    // Draw order matters for reproducibility: per batch item, the matrix rows
    // are filled first, then the right-hand side.
    for (a_mat, b_vec) in a_data
        .chunks_exact_mut(n * n)
        .zip(b_data.chunks_exact_mut(n))
    {
        for (i, row) in a_mat.chunks_exact_mut(n).enumerate() {
            for entry in row.iter_mut() {
                *entry = rng.next_f32() as f64 * 0.1;
            }
            // Shift the diagonal well away from the small off-diagonal noise.
            row[i] += 1.0 + n as f64;
        }
        for entry in b_vec.iter_mut() {
            *entry = rng.next_f32() as f64;
        }
    }

    // Look up an Input node by its name; panics when absent.
    let find = |graph: &Graph, want: &str| -> NodeId {
        graph
            .nodes()
            .iter()
            .find(|node| matches!(&node.op, Op::Input { name } if name == want))
            .map(|node| node.id)
            .unwrap_or_else(|| panic!("no input named {want}"))
    };
    let ba = find(&bg, "A");
    let bb = find(&bg, "b");
    let (sched, mut arena) = prepare_f64(&bg, &[(ba, &a_data), (bb, &b_data)]);
    execute_thunks(&sched, arena.raw_buf_mut());
    let batched_x = read_arena_f64(&arena, bg.outputs[0], batch * n);

    for bi in 0..batch {
        // dgesv is handed owned copies so the shared inputs stay untouched.
        let mut lhs: Vec<f64> = a_data[bi * n * n..(bi + 1) * n * n].to_vec();
        let mut rhs: Vec<f64> = b_data[bi * n..(bi + 1) * n].to_vec();
        crate::blas::dgesv(&mut lhs, &mut rhs, n, 1);
        for (i, &want) in rhs[..n].iter().enumerate() {
            let got = batched_x[bi * n + i];
            assert!(
                (got - want).abs() < 1e-12,
                "row {bi}, i {i}: got {got} want {want}"
            );
        }
    }
}
16766
#[test]
/// End-to-end `vmap` check on `sum(reshape(x) @ w + b)`: only `x` is batched,
/// `w` and `b` are shared. Each batched loss must equal the loss from running
/// the unbatched graph on that row alone.
fn vmap_matmul_add_reduce_matches_scalar_runs() {
    use rlx_opt::vmap::vmap;
    let n = 3usize;
    let batch = 4usize;

    // Builds the unbatched forward graph; returns it with its [x, w, b]
    // input ids and the id of the loss node.
    let build_forward = |graph_name: &str| -> (Graph, [NodeId; 3], NodeId) {
        let mut graph = Graph::new(graph_name);
        let x = graph.input("x", Shape::new(&[n], DType::F64));
        let w = graph.input("w", Shape::new(&[n, n], DType::F64));
        let b = graph.input("b", Shape::new(&[n], DType::F64));
        let x_row = graph.add_node(
            Op::Reshape {
                new_shape: vec![1, n as i64],
            },
            vec![x],
            Shape::new(&[1, n], DType::F64),
        );
        let mm = graph.matmul(x_row, w, Shape::new(&[1, n], DType::F64));
        let mm_flat = graph.add_node(
            Op::Reshape {
                new_shape: vec![n as i64],
            },
            vec![mm],
            Shape::new(&[n], DType::F64),
        );
        let yv = graph.binary(BinaryOp::Add, mm_flat, b, Shape::new(&[n], DType::F64));
        let loss = graph.reduce(
            yv,
            ReduceOp::Sum,
            vec![0],
            false,
            Shape::new(&[1], DType::F64),
        );
        graph.set_outputs(vec![loss]);
        (graph, [x, w, b], loss)
    };

    let (g, _, _) = build_forward("vmap_e2e_forward");
    let bg = vmap(&g, &["x"], batch);

    // Draw order: w, then b, then the batched x.
    let mut rng = rlx_ir::Philox4x32::new(0xc1c0_u64);
    let n_w = n * n;
    let w_data: Vec<f64> = (0..n_w).map(|_| rng.next_f32() as f64).collect();
    let b_data: Vec<f64> = (0..n).map(|_| rng.next_f32() as f64).collect();
    let x_data_batched: Vec<f64> = (0..batch * n).map(|_| rng.next_f32() as f64).collect();

    // Look up an Input node by its name; panics when absent.
    let find = |graph: &Graph, want: &str| -> NodeId {
        graph
            .nodes()
            .iter()
            .find(|node| matches!(&node.op, Op::Input { name } if name == want))
            .map(|node| node.id)
            .unwrap_or_else(|| panic!("no input named {want}"))
    };
    let bx = find(&bg, "x");
    let bw = find(&bg, "w");
    let bb = find(&bg, "b");
    let (sched, mut arena) =
        prepare_f64(&bg, &[(bx, &x_data_batched), (bw, &w_data), (bb, &b_data)]);
    execute_thunks(&sched, arena.raw_buf_mut());
    let batched_out = read_arena_f64(&arena, bg.outputs[0], batch);

    for (bi, xs_slice) in x_data_batched.chunks_exact(n).enumerate() {
        let (g2, [x2, w2, b2], l2) = build_forward("scalar_run");
        let (s2, mut a2) = prepare_f64(&g2, &[(x2, xs_slice), (w2, &w_data), (b2, &b_data)]);
        execute_thunks(&s2, a2.raw_buf_mut());
        let scalar_out = read_arena_f64(&a2, l2, 1);
        assert!(
            (batched_out[bi] - scalar_out[0]).abs() < 1e-12,
            "row {bi}: batched={} scalar={}",
            batched_out[bi],
            scalar_out[0]
        );
    }
}
16888
#[test]
/// Differentiates a driven scan (`next = solve(M, carry + drive)`) w.r.t. both
/// `init` and the per-step inputs `xs`, and compares every gradient entry with
/// a central finite difference of a plain-Rust reference loss.
fn scan_with_xs_dxs_matches_fd() {
    use rlx_opt::autodiff::grad_with_loss;
    let n = 3usize;
    let length = 3u32;
    let dt = 0.1_f64;
    let steps = length as usize;

    // Tridiagonal M: 1 + 2*dt on the diagonal, -dt on both neighbours.
    let mut m_data = vec![0.0_f64; n * n];
    for (i, row) in m_data.chunks_exact_mut(n).enumerate() {
        row[i] = 1.0 + dt * 2.0;
        if i > 0 {
            row[i - 1] = -dt;
        }
        if i + 1 < n {
            row[i + 1] = -dt;
        }
    }
    let mut m_bytes: Vec<u8> = Vec::with_capacity(m_data.len() * 8);
    for value in &m_data {
        m_bytes.extend_from_slice(&value.to_le_bytes());
    }

    let mut body = Graph::new("be_dxs_body");
    let carry = body.input("carry", Shape::new(&[n], DType::F64));
    let drive = body.input("drive", Shape::new(&[n], DType::F64));
    let m = body.add_node(
        Op::Constant { data: m_bytes },
        vec![],
        Shape::new(&[n, n], DType::F64),
    );
    let driven = body.binary(BinaryOp::Add, carry, drive, Shape::new(&[n], DType::F64));
    let next = body.dense_solve(m, driven, Shape::new(&[n], DType::F64));
    body.set_outputs(vec![next]);

    let mut g = Graph::new("be_dxs_outer");
    let init = g.input("init", Shape::new(&[n], DType::F64));
    let xs = g.input("xs", Shape::new(&[steps, n], DType::F64));
    let final_carry = g.scan_with_xs(init, &[xs], body, length);
    let loss = g.reduce(
        final_carry,
        ReduceOp::Sum,
        vec![0],
        false,
        Shape::new(&[1], DType::F64),
    );
    g.set_outputs(vec![loss]);

    let bwd = grad_with_loss(&g, &[init, xs]);
    assert_eq!(bwd.outputs.len(), 3, "[loss, dinit, dxs]");

    // Look up an Input/Param node by its name; panics when absent.
    let find_by_name = |graph: &Graph, want: &str| -> NodeId {
        graph
            .nodes()
            .iter()
            .find_map(|node| match &node.op {
                Op::Input { name } | Op::Param { name } if name.as_str() == want => Some(node.id),
                _ => None,
            })
            .unwrap_or_else(|| panic!("no node named {want:?}"))
    };
    let init_bwd = find_by_name(&bwd, "init");
    let xs_bwd = find_by_name(&bwd, "xs");
    let d_out_bwd = find_by_name(&bwd, "d_output");

    let init_data = vec![0.5_f64, 0.0, -0.5];
    let xs_data: Vec<f64> = (0..steps * n)
        .map(|i| 0.1_f64 * ((i as f64) - 4.0))
        .collect();
    let d_seed = [1.0_f64];

    let (sched, mut arena) = prepare_f64(
        &bwd,
        &[
            (init_bwd, &init_data),
            (xs_bwd, &xs_data),
            (d_out_bwd, &d_seed),
        ],
    );
    execute_thunks(&sched, arena.raw_buf_mut());
    let dinit = read_arena_f64(&arena, bwd.outputs[1], n);
    let dxs = read_arena_f64(&arena, bwd.outputs[2], steps * n);

    let h = 1e-6;
    // Reference loss: per step add the drive, solve with M, finally sum the state.
    let loss_at = |x0: &[f64], xs_in: &[f64]| -> f64 {
        let mut state = x0.to_vec();
        for drive_t in xs_in[..steps * n].chunks_exact(n) {
            for (s, d) in state.iter_mut().zip(drive_t) {
                *s += d;
            }
            let mut lhs = m_data.clone();
            crate::blas::dgesv(&mut lhs, &mut state, n, 1);
        }
        state.iter().sum()
    };

    // Central differences in each component of init.
    for i in 0..n {
        let mut ip = init_data.to_vec();
        ip[i] += h;
        let mut im = init_data.to_vec();
        im[i] -= h;
        let fd = (loss_at(&ip, &xs_data) - loss_at(&im, &xs_data)) / (2.0 * h);
        assert!(
            (dinit[i] - fd).abs() < 1e-7,
            "FD dinit[{i}]: AD={} FD={}",
            dinit[i],
            fd
        );
    }

    // Central differences in each entry of xs, visited in row-major (t, j) order.
    for idx in 0..steps * n {
        let (t, j) = (idx / n, idx % n);
        let mut xp = xs_data.clone();
        xp[idx] += h;
        let mut xm = xs_data.clone();
        xm[idx] -= h;
        let fd = (loss_at(&init_data, &xp) - loss_at(&init_data, &xm)) / (2.0 * h);
        assert!(
            (dxs[idx] - fd).abs() < 1e-7,
            "FD dxs[t={t},j={j}]: AD={} FD={}",
            dxs[idx],
            fd
        );
    }
}
17023
#[test]
/// Differentiates a driven scan (`next = solve(M, carry + drive)`) w.r.t.
/// `init` only, with `xs` supplied as ordinary data, and compares the result
/// with a central finite difference of a plain-Rust reference loss.
fn scan_with_xs_gradient_dinit_matches_fd() {
    use rlx_opt::autodiff::grad_with_loss;
    let n = 3usize;
    let length = 3u32;
    let dt = 0.1_f64;
    let steps = length as usize;

    // Tridiagonal M: 1 + 2*dt on the diagonal, -dt on both neighbours.
    let mut m_data = vec![0.0_f64; n * n];
    for (i, row) in m_data.chunks_exact_mut(n).enumerate() {
        row[i] = 1.0 + dt * 2.0;
        if i > 0 {
            row[i - 1] = -dt;
        }
        if i + 1 < n {
            row[i + 1] = -dt;
        }
    }
    let mut m_bytes: Vec<u8> = Vec::with_capacity(m_data.len() * 8);
    for value in &m_data {
        m_bytes.extend_from_slice(&value.to_le_bytes());
    }

    let mut body = Graph::new("be_xs_grad_body");
    let carry = body.input("carry", Shape::new(&[n], DType::F64));
    let drive = body.input("drive", Shape::new(&[n], DType::F64));
    let m = body.add_node(
        Op::Constant { data: m_bytes },
        vec![],
        Shape::new(&[n, n], DType::F64),
    );
    let driven = body.binary(BinaryOp::Add, carry, drive, Shape::new(&[n], DType::F64));
    let next = body.dense_solve(m, driven, Shape::new(&[n], DType::F64));
    body.set_outputs(vec![next]);

    let mut g = Graph::new("be_xs_grad_outer");
    let init = g.input("init", Shape::new(&[n], DType::F64));
    let xs = g.input("xs", Shape::new(&[steps, n], DType::F64));
    let final_carry = g.scan_with_xs(init, &[xs], body, length);
    let loss = g.reduce(
        final_carry,
        ReduceOp::Sum,
        vec![0],
        false,
        Shape::new(&[1], DType::F64),
    );
    g.set_outputs(vec![loss]);

    let bwd = grad_with_loss(&g, &[init]);

    // Look up an Input/Param node by its name; panics when absent.
    let find_by_name = |graph: &Graph, want: &str| -> NodeId {
        graph
            .nodes()
            .iter()
            .find_map(|node| match &node.op {
                Op::Input { name } | Op::Param { name } if name.as_str() == want => Some(node.id),
                _ => None,
            })
            .unwrap_or_else(|| panic!("no node named {want:?}"))
    };
    let init_bwd = find_by_name(&bwd, "init");
    let xs_bwd = find_by_name(&bwd, "xs");
    let d_out_bwd = find_by_name(&bwd, "d_output");

    let init_data = vec![0.5_f64, 0.0, -0.5];
    let xs_data: Vec<f64> = (0..steps * n)
        .map(|i| 0.1_f64 * ((i as f64) - 4.0))
        .collect();
    let d_seed = [1.0_f64];

    let (sched, mut arena) = prepare_f64(
        &bwd,
        &[
            (init_bwd, &init_data),
            (xs_bwd, &xs_data),
            (d_out_bwd, &d_seed),
        ],
    );
    execute_thunks(&sched, arena.raw_buf_mut());
    let dinit = read_arena_f64(&arena, bwd.outputs[1], n);

    let h = 1e-6;
    // Reference loss: per step add the drive, solve with M, finally sum the state.
    let loss_at = |x0: &[f64]| -> f64 {
        let mut state = x0.to_vec();
        for drive_t in xs_data[..steps * n].chunks_exact(n) {
            for (s, d) in state.iter_mut().zip(drive_t) {
                *s += d;
            }
            let mut lhs = m_data.clone();
            crate::blas::dgesv(&mut lhs, &mut state, n, 1);
        }
        state.iter().sum()
    };
    // Central differences in each component of init.
    for i in 0..n {
        let mut ip = init_data.to_vec();
        ip[i] += h;
        let mut im = init_data.to_vec();
        im[i] -= h;
        let fd = (loss_at(&ip) - loss_at(&im)) / (2.0 * h);
        assert!(
            (dinit[i] - fd).abs() < 1e-7,
            "FD dinit[{i}]: AD={} FD={}",
            dinit[i],
            fd
        );
    }
}
17137
#[test]
/// Scans `carry * 1.1` for `length` steps and sums the result. The gradient
/// w.r.t. `init` is checked against `1.1^length` for every component, and
/// component 0 is additionally checked against a central finite difference.
fn scan_gradient_geometric_matches_closed_form() {
    use rlx_opt::autodiff::grad_with_loss;
    let n = 3usize;
    let length = 5u32;

    let mut body = Graph::new("scan_grad_body");
    let x = body.input("carry", Shape::new(&[n], DType::F64));
    let scale_bytes: Vec<u8> = 1.1_f64.to_le_bytes().repeat(n);
    let scale = body.add_node(
        Op::Constant { data: scale_bytes },
        vec![],
        Shape::new(&[n], DType::F64),
    );
    let next = body.binary(BinaryOp::Mul, x, scale, Shape::new(&[n], DType::F64));
    body.set_outputs(vec![next]);

    let mut g = Graph::new("scan_grad_outer");
    let init = g.input("init", Shape::new(&[n], DType::F64));
    let final_x = g.scan(init, body, length);
    let loss = g.reduce(
        final_x,
        ReduceOp::Sum,
        vec![0],
        false,
        Shape::new(&[1], DType::F64),
    );
    g.set_outputs(vec![loss]);

    let bwd = grad_with_loss(&g, &[init]);
    assert_eq!(bwd.outputs.len(), 2);

    // Look up an Input/Param node by its name; panics when absent.
    let find_by_name = |graph: &Graph, want: &str| -> NodeId {
        graph
            .nodes()
            .iter()
            .find_map(|node| match &node.op {
                Op::Input { name } | Op::Param { name } if name.as_str() == want => Some(node.id),
                _ => None,
            })
            .unwrap_or_else(|| panic!("no node named {want:?}"))
    };
    let init_bwd = find_by_name(&bwd, "init");
    let d_out_bwd = find_by_name(&bwd, "d_output");

    let init_data = vec![1.0_f64; n];
    let d_seed = [1.0_f64];
    let (sched, mut arena) = prepare_f64(&bwd, &[(init_bwd, &init_data), (d_out_bwd, &d_seed)]);
    execute_thunks(&sched, arena.raw_buf_mut());
    let dinit = read_arena_f64(&arena, bwd.outputs[1], n);

    // Closed form: d(sum(x * 1.1^L)) / dx_i = 1.1^L.
    let want = 1.1_f64.powi(length as i32);
    for (i, d) in dinit[..n].iter().enumerate() {
        assert!(
            (d - want).abs() < 1e-12,
            "dinit[{i}] = {} want {}",
            d,
            want
        );
    }

    let h = 1e-6;
    // Reference loss: scale each component `length` times, then sum.
    let loss_at = |x: &[f64]| -> f64 {
        x.iter()
            .map(|&start| (0..length).fold(start, |v, _| v * 1.1))
            .sum()
    };
    let mut ip = init_data.clone();
    ip[0] += h;
    let mut im = init_data.clone();
    im[0] -= h;
    let fd = (loss_at(&ip) - loss_at(&im)) / (2.0 * h);
    assert!(
        (dinit[0] - fd).abs() < 1e-7,
        "FD dinit[0]: AD={} FD={}",
        dinit[0],
        fd
    );
}
17231
#[test]
/// Differentiates a scan whose body is `next = solve(M, x)` with a constant
/// tridiagonal `M`, and compares the gradient of the summed final state
/// w.r.t. `x0` against central finite differences of a `dgesv`-based reference.
fn scan_gradient_backward_euler_matches_fd() {
    use rlx_opt::autodiff::grad_with_loss;
    let n = 4usize;
    let length = 3u32;
    let dt = 0.05_f64;

    // Tridiagonal M: 1 + 2*dt on the diagonal, -dt on both neighbours.
    let mut m_data = vec![0.0_f64; n * n];
    for (i, row) in m_data.chunks_exact_mut(n).enumerate() {
        row[i] = 1.0 + dt * 2.0;
        if i > 0 {
            row[i - 1] = -dt;
        }
        if i + 1 < n {
            row[i + 1] = -dt;
        }
    }
    let mut m_bytes: Vec<u8> = Vec::with_capacity(m_data.len() * 8);
    for value in &m_data {
        m_bytes.extend_from_slice(&value.to_le_bytes());
    }

    let mut body = Graph::new("be_grad_body");
    let x = body.input("x", Shape::new(&[n], DType::F64));
    let m = body.add_node(
        Op::Constant { data: m_bytes },
        vec![],
        Shape::new(&[n, n], DType::F64),
    );
    let next = body.dense_solve(m, x, Shape::new(&[n], DType::F64));
    body.set_outputs(vec![next]);

    let mut g = Graph::new("be_grad_outer");
    let init = g.input("x0", Shape::new(&[n], DType::F64));
    let final_x = g.scan(init, body, length);
    let loss = g.reduce(
        final_x,
        ReduceOp::Sum,
        vec![0],
        false,
        Shape::new(&[1], DType::F64),
    );
    g.set_outputs(vec![loss]);

    let bwd = grad_with_loss(&g, &[init]);

    // Look up an Input/Param node by its name; panics when absent.
    let find_by_name = |graph: &Graph, want: &str| -> NodeId {
        graph
            .nodes()
            .iter()
            .find_map(|node| match &node.op {
                Op::Input { name } | Op::Param { name } if name.as_str() == want => Some(node.id),
                _ => None,
            })
            .unwrap_or_else(|| panic!("no node named {want:?}"))
    };
    let init_bwd = find_by_name(&bwd, "x0");
    let d_out_bwd = find_by_name(&bwd, "d_output");

    let init_data: [f64; 4] = [0.0, 1.0, 0.0, 0.0];
    let d_seed = [1.0_f64];
    let (sched, mut arena) = prepare_f64(&bwd, &[(init_bwd, &init_data), (d_out_bwd, &d_seed)]);
    execute_thunks(&sched, arena.raw_buf_mut());
    let dinit = read_arena_f64(&arena, bwd.outputs[1], n);

    let h = 1e-6;
    // Reference loss: apply the solve `length` times, then sum the state.
    let loss_at = |x0: &[f64]| -> f64 {
        let mut state = x0.to_vec();
        for _ in 0..length {
            // dgesv gets a scratch copy of M on every step.
            let mut lhs = m_data.clone();
            crate::blas::dgesv(&mut lhs, &mut state, n, 1);
        }
        state.iter().sum()
    };
    // Central differences in each component of x0.
    for i in 0..n {
        let mut ip = init_data.to_vec();
        ip[i] += h;
        let mut im = init_data.to_vec();
        im[i] -= h;
        let fd = (loss_at(&ip) - loss_at(&im)) / (2.0 * h);
        assert!(
            (dinit[i] - fd).abs() < 1e-7,
            "FD dinit[{i}]: AD={} FD={}",
            dinit[i],
            fd
        );
    }
}
17321
#[test]
/// Runs `scan_trajectory` over the body `next = solve(M, x)` and checks three
/// things: every recorded row matches a `dgesv` reference, the row sums never
/// increase from one step to the next, and the last row equals the output of
/// a plain `scan` over an identical body.
fn scan_trajectory_backward_euler_records_waveform() {
    let n = 4usize;
    let length = 5u32;
    let dt = 0.05_f64;
    let steps = length as usize;

    // Tridiagonal M: 1 + 2*dt on the diagonal, -dt on both neighbours.
    let mut m_data = vec![0.0_f64; n * n];
    for (i, row) in m_data.chunks_exact_mut(n).enumerate() {
        row[i] = 1.0 + dt * 2.0;
        if i > 0 {
            row[i - 1] = -dt;
        }
        if i + 1 < n {
            row[i + 1] = -dt;
        }
    }
    // Serialises M as little-endian bytes; needed once per body graph.
    let m_as_bytes = || -> Vec<u8> {
        let mut bytes = Vec::with_capacity(m_data.len() * 8);
        for value in &m_data {
            bytes.extend_from_slice(&value.to_le_bytes());
        }
        bytes
    };
    let m_bytes = m_as_bytes();

    let mut body = Graph::new("be_traj_body");
    let x = body.input("x", Shape::new(&[n], DType::F64));
    let m = body.add_node(
        Op::Constant { data: m_bytes },
        vec![],
        Shape::new(&[n, n], DType::F64),
    );
    let next = body.dense_solve(m, x, Shape::new(&[n], DType::F64));
    body.set_outputs(vec![next]);

    let mut g = Graph::new("be_traj_outer");
    let init = g.input("x0", Shape::new(&[n], DType::F64));
    let traj = g.scan_trajectory(init, body, length);
    g.set_outputs(vec![traj]);

    let init_data: [f64; 4] = [0.0, 1.0, 0.0, 0.0];
    let (sched, mut arena) = prepare_f64(&g, &[(init, &init_data)]);
    execute_thunks(&sched, arena.raw_buf_mut());
    let got = read_arena_f64(&arena, traj, steps * n);

    // Reference trajectory: the state after each of the `length` solves.
    let mut want = Vec::<f64>::with_capacity(steps * n);
    let mut x_ref = init_data.to_vec();
    for _ in 0..length {
        let mut lhs = m_data.clone();
        crate::blas::dgesv(&mut lhs, &mut x_ref, n, 1);
        want.extend_from_slice(&x_ref);
    }
    for (i, (g_i, w_i)) in got[..steps * n].iter().zip(&want[..steps * n]).enumerate() {
        assert!(
            (g_i - w_i).abs() < 1e-12,
            "got[{i}] = {} ref {}",
            g_i,
            w_i
        );
    }

    // The per-row sum must not increase from one step to the next.
    let row_sums: Vec<f64> = got[..steps * n]
        .chunks_exact(n)
        .map(|row| row.iter().sum::<f64>())
        .collect();
    for t in 1..steps {
        let prev = row_sums[t - 1];
        let curr = row_sums[t];
        assert!(
            curr <= prev + 1e-15,
            "mass should decay: row {} sum {prev}, row {t} sum {curr}",
            t - 1
        );
    }

    // Second graph: plain scan over an identical body, final carry only.
    let mut body2 = Graph::new("be_final_body");
    let x2 = body2.input("x", Shape::new(&[n], DType::F64));
    let m_bytes2 = m_as_bytes();
    let m2 = body2.add_node(
        Op::Constant { data: m_bytes2 },
        vec![],
        Shape::new(&[n, n], DType::F64),
    );
    let next2 = body2.dense_solve(m2, x2, Shape::new(&[n], DType::F64));
    body2.set_outputs(vec![next2]);

    let mut g2 = Graph::new("be_final_outer");
    let init2 = g2.input("x0", Shape::new(&[n], DType::F64));
    let final_x = g2.scan(init2, body2, length);
    g2.set_outputs(vec![final_x]);
    let (sched2, mut arena2) = prepare_f64(&g2, &[(init2, &init_data)]);
    execute_thunks(&sched2, arena2.raw_buf_mut());
    let final_got = read_arena_f64(&arena2, final_x, n);

    let last_row = &got[(steps - 1) * n..steps * n];
    for (i, (traj_v, final_v)) in last_row.iter().zip(&final_got[..n]).enumerate() {
        assert!(
            (traj_v - final_v).abs() < 1e-15,
            "last trajectory row[{i}] = {} vs final-scan = {}",
            traj_v,
            final_v
        );
    }
}
17426
#[test]
/// Forward-only check of a scan whose body is `next = solve(M, x)` with a
/// constant tridiagonal `M`: the final carry must match `length` successive
/// `dgesv` solves, and its component sum must lie strictly between 0 and 1.
fn scan_backward_euler_heat_f64() {
    let n = 4usize;
    let length = 5u32;
    let dt = 0.05_f64;

    // Tridiagonal M: 1 + 2*dt on the diagonal, -dt on both neighbours.
    let mut m_data = vec![0.0_f64; n * n];
    for (i, row) in m_data.chunks_exact_mut(n).enumerate() {
        row[i] = 1.0 + dt * 2.0;
        if i > 0 {
            row[i - 1] = -dt;
        }
        if i + 1 < n {
            row[i + 1] = -dt;
        }
    }
    let mut m_bytes: Vec<u8> = Vec::with_capacity(m_data.len() * 8);
    for value in &m_data {
        m_bytes.extend_from_slice(&value.to_le_bytes());
    }

    let mut body = Graph::new("be_body");
    let x = body.input("x", Shape::new(&[n], DType::F64));
    let m = body.add_node(
        Op::Constant { data: m_bytes },
        vec![],
        Shape::new(&[n, n], DType::F64),
    );
    let next = body.dense_solve(m, x, Shape::new(&[n], DType::F64));
    body.set_outputs(vec![next]);

    let mut g = Graph::new("be_outer");
    let init = g.input("x0", Shape::new(&[n], DType::F64));
    let final_x = g.scan(init, body, length);
    g.set_outputs(vec![final_x]);

    // Unit value at index 1, zero elsewhere.
    let init_data: [f64; 4] = [0.0, 1.0, 0.0, 0.0];
    let (sched, mut arena) = prepare_f64(&g, &[(init, &init_data)]);
    execute_thunks(&sched, arena.raw_buf_mut());
    let got = read_arena_f64(&arena, final_x, n);

    // Reference: repeat the solve directly, with a scratch copy of M per step.
    let mut ref_x = init_data.to_vec();
    for _ in 0..length {
        let mut lhs = m_data.clone();
        crate::blas::dgesv(&mut lhs, &mut ref_x, n, 1);
    }
    for (i, (g_i, r_i)) in got[..n].iter().zip(&ref_x[..n]).enumerate() {
        assert!(
            (g_i - r_i).abs() < 1e-12,
            "got[{i}] = {} ref {}",
            g_i,
            r_i
        );
    }
    let mass: f64 = got.iter().sum();
    assert!(mass > 0.0 && mass < 1.0, "diffusion mass: {mass}");
}
17494
/// Forward `dense_solve` with a matrix right-hand side.
///
/// Solves `A X = B` for a 3x3 `A` and a 3x2 `B`, then multiplies `A` by the
/// returned `X` and compares each entry with `B`.
#[test]
fn dense_solve_f64_multi_rhs_forward() {
    let n = 3usize;
    let k = 2usize;

    let mut graph = Graph::new("solve_multi_rhs");
    let a = graph.input("A", Shape::new(&[n, n], DType::F64));
    let b = graph.input("B", Shape::new(&[n, k], DType::F64));
    let x = graph.dense_solve(a, b, Shape::new(&[n, k], DType::F64));
    graph.set_outputs(vec![x]);

    let a_data = [2.0, 1.0, 0.0, 1.0, 3.0, 1.0, 0.0, 1.0, 2.0_f64];
    let b_data = [1.0, 4.0, 2.0, -1.0, 3.0, 2.0_f64];
    let (sched, mut arena) = prepare_f64(&graph, &[(a, &a_data), (b, &b_data)]);
    execute_thunks(&sched, arena.raw_buf_mut());
    let x_got = read_arena_f64(&arena, x, n * k);

    // Residual check, indexing everything row-major:
    // A[i, j] = a_data[i * n + j], X[j, c] = x_got[j * k + c].
    for c in 0..k {
        for i in 0..n {
            let acc = (0..n).fold(0.0_f64, |sum, j| sum + a_data[i * n + j] * x_got[j * k + c]);
            let want = b_data[i * k + c];
            assert!(
                (acc - want).abs() < 1e-10,
                "col {c} row {i}: got {acc} want {want}"
            );
        }
    }
}
17527
/// Reverse-mode gradient of `sum(dense_solve(A, B))` with a 3x2 right-hand side.
///
/// The backward graph's `dB` and `dA` outputs are compared with references
/// built from `crate::blas::dgesv` (`dB_ref` = solve of `A^T` against all-ones,
/// `dA_ref[i, j] = -sum_c dB_ref[i, c] * X[j, c]`), and `dB[0]` is additionally
/// probed with a central finite difference.
#[test]
fn dense_solve_f64_multi_rhs_gradient() {
    use rlx_opt::autodiff::grad_with_loss;
    let n = 3usize;
    let k = 2usize;

    let mut fwd = Graph::new("solve_mrhs_grad");
    let a = fwd.param("A", Shape::new(&[n, n], DType::F64));
    let b = fwd.input("B", Shape::new(&[n, k], DType::F64));
    let x = fwd.dense_solve(a, b, Shape::new(&[n, k], DType::F64));
    let loss = fwd.reduce(
        x,
        ReduceOp::Sum,
        vec![0, 1],
        false,
        Shape::new(&[1], DType::F64),
    );
    fwd.set_outputs(vec![loss]);

    let bwd = grad_with_loss(&fwd, &[a, b]);

    // Locate an Input/Param node in a graph by its name; panics when absent.
    let find_by_name = |graph: &Graph, want: &str| -> NodeId {
        graph
            .nodes()
            .iter()
            .find(|node| match &node.op {
                Op::Input { name } | Op::Param { name } => name.as_str() == want,
                _ => false,
            })
            .map(|node| node.id)
            .unwrap_or_else(|| panic!("no node named {want:?}"))
    };
    let a_bwd = find_by_name(&bwd, "A");
    let b_bwd = find_by_name(&bwd, "B");
    let d_out = find_by_name(&bwd, "d_output");

    let a_data = [2.0, 1.0, 0.0, 1.0, 3.0, 1.0, 0.0, 1.0, 2.0_f64];
    let b_data = [1.0, 4.0, 2.0, -1.0, 3.0, 2.0_f64];
    let d_seed = [1.0_f64];

    let (sched, mut arena) = prepare_f64(
        &bwd,
        &[(a_bwd, &a_data), (b_bwd, &b_data), (d_out, &d_seed)],
    );
    execute_thunks(&sched, arena.raw_buf_mut());
    // outputs[1] / outputs[2] are read as the gradients for `A` and `B`
    // (matching the `&[a, b]` order passed above; outputs[0] is not inspected).
    let da_got = read_arena_f64(&arena, bwd.outputs[1], n * n);
    let db_got = read_arena_f64(&arena, bwd.outputs[2], n * k);

    // X_ref: solve A X = B on copies of the inputs.
    let mut x_ref = b_data;
    {
        let mut a_copy = a_data;
        crate::blas::dgesv(&mut a_copy, &mut x_ref, n, k);
    }

    // dB_ref: solve A^T G = 1 (all-ones right-hand side).
    let mut at = [0.0_f64; 9];
    for (flat, slot) in at.iter_mut().enumerate() {
        let (i, j) = (flat / n, flat % n);
        *slot = a_data[j * n + i];
    }
    let mut db_ref = vec![1.0_f64; n * k];
    crate::blas::dgesv(&mut at, &mut db_ref, n, k);

    // dA_ref[i, j] = -sum_c dB_ref[i, c] * X_ref[j, c].
    let mut da_ref = [0.0_f64; 9];
    for (flat, slot) in da_ref.iter_mut().enumerate() {
        let (i, j) = (flat / n, flat % n);
        let acc = (0..k).fold(0.0_f64, |sum, c| sum + db_ref[i * k + c] * x_ref[j * k + c]);
        *slot = -acc;
    }

    for i in 0..n * k {
        let (actual, expected) = (db_got[i], db_ref[i]);
        assert!(
            (actual - expected).abs() < 1e-10,
            "dB[{i}]: got {} want {}",
            actual,
            expected
        );
    }
    for i in 0..n * n {
        let (actual, expected) = (da_got[i], da_ref[i]);
        assert!(
            (actual - expected).abs() < 1e-10,
            "dA[{i}]: got {} want {}",
            actual,
            expected
        );
    }

    // Central finite difference of the loss with respect to B[0].
    let h = 1e-6;
    let loss_at = |mut rhs: [f64; 6]| -> f64 {
        let mut a_copy = a_data;
        crate::blas::dgesv(&mut a_copy, &mut rhs, n, k);
        rhs.iter().sum()
    };
    let mut bp = b_data;
    bp[0] += h;
    let mut bm = b_data;
    bm[0] -= h;
    let lp = loss_at(bp);
    let lm = loss_at(bm);
    let fd = (lp - lm) / (2.0 * h);
    assert!(
        (db_got[0] - fd).abs() < 1e-7,
        "FD dB[0,0]: AD={} FD={}",
        db_got[0],
        fd
    );
}
17645
/// Forward-mode (JVP) check for `dense_solve` with a 3x2 right-hand side and a
/// tangent on `B` only.
///
/// The tangent output is compared with a `dgesv` solve of `A` against the
/// tangent data, and with a central finite difference along that direction.
#[test]
fn dense_solve_f64_multi_rhs_jvp() {
    use rlx_opt::autodiff_fwd::jvp;
    let n = 3usize;
    let k = 2usize;

    let mut primal = Graph::new("solve_mrhs_jvp");
    let a = primal.input("A", Shape::new(&[n, n], DType::F64));
    let b = primal.input("B", Shape::new(&[n, k], DType::F64));
    let x = primal.dense_solve(a, b, Shape::new(&[n, k], DType::F64));
    primal.set_outputs(vec![x]);

    let jg = jvp(&primal, &[b]);

    // Locate an Input/Param node in a graph by its name; panics when absent.
    let find_by_name = |graph: &Graph, want: &str| -> NodeId {
        graph
            .nodes()
            .iter()
            .find(|node| match &node.op {
                Op::Input { name } | Op::Param { name } => name.as_str() == want,
                _ => false,
            })
            .map(|node| node.id)
            .unwrap_or_else(|| panic!("no node named {want:?}"))
    };
    let a_id = find_by_name(&jg, "A");
    let b_id = find_by_name(&jg, "B");
    let tb_id = find_by_name(&jg, "tangent_B");

    let a_data = [2.0, 1.0, 0.0, 1.0, 3.0, 1.0, 0.0, 1.0, 2.0_f64];
    let b_data = [1.0, 4.0, 2.0, -1.0, 3.0, 2.0_f64];
    let tb_data = [0.5, 0.0, -0.25, 1.0, 1.0, -0.5_f64];

    let (sched, mut arena) =
        prepare_f64(&jg, &[(a_id, &a_data), (b_id, &b_data), (tb_id, &tb_data)]);
    execute_thunks(&sched, arena.raw_buf_mut());
    // outputs[1] is read as the tangent of X.
    let tangent_x = read_arena_f64(&arena, jg.outputs[1], n * k);

    // Host solve of A against a 3x2 right-hand side, on copies of both.
    let solve = |mut rhs: [f64; 6]| -> [f64; 6] {
        let mut a_copy = a_data;
        crate::blas::dgesv(&mut a_copy, &mut rhs, n, k);
        rhs
    };

    let t_ref = solve(tb_data);
    for i in 0..n * k {
        assert!(
            (tangent_x[i] - t_ref[i]).abs() < 1e-10,
            "t_X[{i}]: AD={} ref={}",
            tangent_x[i],
            t_ref[i]
        );
    }

    // Central finite difference along the tangent direction.
    let h = 1e-6;
    let mut bp = b_data;
    let mut bm = b_data;
    for ((plus, minus), &dir) in bp.iter_mut().zip(bm.iter_mut()).zip(tb_data.iter()) {
        *plus += h * dir;
        *minus -= h * dir;
    }
    let xp = solve(bp);
    let xm = solve(bm);
    for i in 0..n * k {
        let fd = (xp[i] - xm[i]) / (2.0 * h);
        assert!(
            (tangent_x[i] - fd).abs() < 1e-7,
            "FD t_X[{i}]: AD={} FD={}",
            tangent_x[i],
            fd
        );
    }
}
17723
/// End-to-end forward-mode check for `dense_solve` with a tangent on `b`.
///
/// Verifies, in order:
/// 1. the tangent output against a `dgesv` solve of `A` with the tangent data,
/// 2. the tangent output against a central finite difference along the tangent,
/// 3. the primal output against a direct `dgesv` solve.
///
/// Every reference solve asserts that `dgesv` returned 0, so a failed host
/// solve cannot silently produce a bogus reference.
#[test]
fn jvp_dense_solve_b_runs_and_matches_fd() {
    use rlx_opt::autodiff_fwd::jvp;
    let n = 3usize;

    let mut g = Graph::new("jvp_b_e2e");
    let a = g.input("A", Shape::new(&[n, n], DType::F64));
    let b = g.input("b", Shape::new(&[n], DType::F64));
    let x = g.dense_solve(a, b, Shape::new(&[n], DType::F64));
    g.set_outputs(vec![x]);

    let jg = jvp(&g, &[b]);

    // Locate an Input/Param node in a graph by its name; panics when absent.
    let find_by_name = |graph: &Graph, want: &str| -> NodeId {
        for node in graph.nodes() {
            let name = match &node.op {
                Op::Input { name } | Op::Param { name } => Some(name.as_str()),
                _ => None,
            };
            if name == Some(want) {
                return node.id;
            }
        }
        panic!("no node named {want:?}");
    };
    let a_id = find_by_name(&jg, "A");
    let b_id = find_by_name(&jg, "b");
    let tb_id = find_by_name(&jg, "tangent_b");

    let a_data: [f64; 9] = [2.0, 1.0, 0.0, 1.0, 3.0, 1.0, 0.0, 1.0, 2.0];
    let b_data: [f64; 3] = [1.0, 2.0, 3.0];
    let tb_data: [f64; 3] = [0.5, -0.25, 1.0];

    let (sched, mut arena) =
        prepare_f64(&jg, &[(a_id, &a_data), (b_id, &b_data), (tb_id, &tb_data)]);
    execute_thunks(&sched, arena.raw_buf_mut());

    // outputs[0] is read as the primal solution, outputs[1] as its tangent.
    let primal_x = read_arena_f64(&arena, jg.outputs[0], n);
    let tangent_x = read_arena_f64(&arena, jg.outputs[1], n);

    // Host reference solve of A against one right-hand side. Works on copies
    // and checks the solver status on every call.
    let solve_ref = |mut rhs: [f64; 3]| -> [f64; 3] {
        let mut a_copy = a_data;
        let info = crate::blas::dgesv(&mut a_copy, &mut rhs, n, 1);
        assert_eq!(info, 0, "reference dgesv reported a failure");
        rhs
    };

    let t_x_ref = solve_ref(tb_data);
    for i in 0..n {
        assert!(
            (tangent_x[i] - t_x_ref[i]).abs() < 1e-10,
            "t_x[{i}]: got {} want {}",
            tangent_x[i],
            t_x_ref[i]
        );
    }

    // Central finite difference along the tangent direction.
    let h = 1e-6;
    let mut bp = b_data;
    let mut bm = b_data;
    for i in 0..n {
        bp[i] += h * tb_data[i];
        bm[i] -= h * tb_data[i];
    }
    let xp = solve_ref(bp);
    let xm = solve_ref(bm);
    let fd: Vec<f64> = (0..n).map(|i| (xp[i] - xm[i]) / (2.0 * h)).collect();
    for i in 0..n {
        assert!(
            (tangent_x[i] - fd[i]).abs() < 1e-7,
            "FD mismatch t_x[{i}]: AD={} FD={}",
            tangent_x[i],
            fd[i]
        );
    }

    // The primal output must still be the plain solution of A x = b.
    let primal_ref = solve_ref(b_data);
    for i in 0..n {
        assert!(
            (primal_x[i] - primal_ref[i]).abs() < 1e-10,
            "primal x[{i}]: got {} want {}",
            primal_x[i],
            primal_ref[i]
        );
    }
}
17831
/// End-to-end forward-mode check for `dense_solve` with a tangent on `A`.
///
/// The tangent output is compared with the closed form
/// `-solve(A, tangent_A * x)` (where `x = solve(A, b)`), and with a central
/// finite difference obtained by perturbing `A` along `tangent_A`.
#[test]
fn jvp_dense_solve_a_runs_and_matches_fd() {
    use rlx_opt::autodiff_fwd::jvp;
    let n = 3usize;

    let mut primal = Graph::new("jvp_a_e2e");
    let a = primal.input("A", Shape::new(&[n, n], DType::F64));
    let b = primal.input("b", Shape::new(&[n], DType::F64));
    let x = primal.dense_solve(a, b, Shape::new(&[n], DType::F64));
    primal.set_outputs(vec![x]);

    let jg = jvp(&primal, &[a]);

    // Locate an Input/Param node in a graph by its name; panics when absent.
    let find_by_name = |graph: &Graph, want: &str| -> NodeId {
        graph
            .nodes()
            .iter()
            .find(|node| match &node.op {
                Op::Input { name } | Op::Param { name } => name.as_str() == want,
                _ => false,
            })
            .map(|node| node.id)
            .unwrap_or_else(|| panic!("no node named {want:?}"))
    };
    let a_id = find_by_name(&jg, "A");
    let b_id = find_by_name(&jg, "b");
    let ta_id = find_by_name(&jg, "tangent_A");

    let a_data: [f64; 9] = [2.0, 1.0, 0.0, 1.0, 3.0, 1.0, 0.0, 1.0, 2.0];
    let b_data: [f64; 3] = [1.0, 2.0, 3.0];
    let ta_data: [f64; 9] = [0.10, -0.05, 0.02, 0.03, 0.20, -0.04, -0.01, 0.07, 0.15];

    let (sched, mut arena) =
        prepare_f64(&jg, &[(a_id, &a_data), (b_id, &b_data), (ta_id, &ta_data)]);
    execute_thunks(&sched, arena.raw_buf_mut());

    // outputs[1] is read as the tangent of x.
    let tangent_x = read_arena_f64(&arena, jg.outputs[1], n);

    // Host solve of `mat * y = rhs`; both arguments are taken by value so the
    // caller's arrays are left untouched.
    let solve = |mut mat: [f64; 9], mut rhs: [f64; 3]| -> [f64; 3] {
        crate::blas::dgesv(&mut mat, &mut rhs, n, 1);
        rhs
    };

    // Closed form: t_x = -solve(A, tangent_A * x).
    let x_ref = solve(a_data, b_data);
    let mut prod = [0.0_f64; 3];
    for (i, slot) in prod.iter_mut().enumerate() {
        *slot = (0..n).fold(0.0_f64, |sum, j| sum + ta_data[i * n + j] * x_ref[j]);
    }
    let t_x_ref = {
        let p = solve(a_data, prod);
        [-p[0], -p[1], -p[2]]
    };
    for i in 0..n {
        assert!(
            (tangent_x[i] - t_x_ref[i]).abs() < 1e-10,
            "closed-form t_x[{i}]: AD={} ref={}",
            tangent_x[i],
            t_x_ref[i]
        );
    }

    // Central finite difference: perturb A by +/- h * tangent_A.
    let h = 1e-6;
    let mut ap = a_data;
    let mut am = a_data;
    for ((plus, minus), &dir) in ap.iter_mut().zip(am.iter_mut()).zip(ta_data.iter()) {
        *plus += h * dir;
        *minus -= h * dir;
    }
    let xp = solve(ap, b_data);
    let xm = solve(am, b_data);
    for i in 0..n {
        let fd = (xp[i] - xm[i]) / (2.0 * h);
        assert!(
            (tangent_x[i] - fd).abs() < 1e-7,
            "FD t_x[{i}]: AD={} FD={}",
            tangent_x[i],
            fd
        );
    }
}
17934
/// Int8 `q_conv2d` against a naive integer convolution.
///
/// Uses a 1x2x5x5 input, a 3x2x3x3 kernel, stride 1, padding 1 and a zero
/// bias. The reference accumulates in `i32`, requantizes with
/// `round(acc * mult)` and clamps to `[-128, 127]`; the kernel output must
/// match it exactly.
#[test]
fn q_conv2d_matches_reference() {
    use rlx_ir::Philox4x32;
    let (n, c_in, h, w_in) = (1usize, 2usize, 5usize, 5usize);
    let c_out = 3usize;
    let (kh, kw) = (3usize, 3usize);
    let (ph, pw) = (1usize, 1usize);
    let (sh, sw) = (1usize, 1usize);
    let h_out = (h + 2 * ph - kh) / sh + 1;
    let w_out = (w_in + 2 * pw - kw) / sw + 1;

    let x_scale = 0.04f32;
    let w_scale = 0.02f32;
    let out_scale = 0.5f32;
    let mult = x_scale * w_scale / out_scale;

    // Random f32 tensors, quantized to i8 by divide / round / clamp.
    let to_i8 = |vals: &[f32], scale: f32| -> Vec<i8> {
        vals.iter()
            .map(|&v| ((v / scale).round() as i32).clamp(-128, 127) as i8)
            .collect()
    };
    let mut rng = Philox4x32::new(2099);
    let mut xf = vec![0f32; n * c_in * h * w_in];
    rng.fill_normal(&mut xf);
    let mut wf = vec![0f32; c_out * c_in * kh * kw];
    rng.fill_normal(&mut wf);
    let xq = to_i8(&xf, x_scale);
    let wq = to_i8(&wf, w_scale);
    let bias: Vec<i32> = vec![0i32; c_out];

    let mut g = Graph::new("qconv");
    let xn = g.input("x", Shape::new(&[n, c_in, h, w_in], DType::I8));
    let wn = g.input("w", Shape::new(&[c_out, c_in, kh, kw], DType::I8));
    let bn = g.input("b", Shape::new(&[c_out], DType::I32));
    // NOTE(review): the bare integer arguments are passed through unchanged;
    // their meaning is defined by `Graph::q_conv2d`, which is not visible here.
    let out = g.q_conv2d(
        xn,
        wn,
        bn,
        vec![kh, kw],
        vec![sh, sw],
        vec![ph, pw],
        vec![1, 1],
        1,
        0,
        0,
        0,
        mult,
        Shape::new(&[n, c_out, h_out, w_out], DType::I8),
    );
    g.set_outputs(vec![out]);

    let plan = rlx_opt::memory::plan_memory(&g);
    let mut arena = crate::arena::Arena::from_plan(plan);
    let sched = compile_thunks(&g, &arena);
    let xn_off = arena.byte_offset(xn);
    let wn_off = arena.byte_offset(wn);
    let bn_off = arena.byte_offset(bn);
    let out_off = arena.byte_offset(out);
    let buf = arena.raw_buf_mut();
    // SAFETY (assumed, confirm against `Arena`): each byte offset is taken to
    // address a region large enough, and suitably aligned, for that node's data.
    unsafe {
        let x_dst = buf.as_mut_ptr().add(xn_off) as *mut i8;
        for (i, &v) in xq.iter().enumerate() {
            x_dst.add(i).write(v);
        }
        let w_dst = buf.as_mut_ptr().add(wn_off) as *mut i8;
        for (i, &v) in wq.iter().enumerate() {
            w_dst.add(i).write(v);
        }
        let b_dst = buf.as_mut_ptr().add(bn_off) as *mut i32;
        for (i, &v) in bias.iter().enumerate() {
            b_dst.add(i).write(v);
        }
    }
    execute_thunks(&sched, arena.raw_buf_mut());
    let out_q: Vec<i8> = unsafe {
        let src = arena.raw_buf().as_ptr().add(out_off) as *const i8;
        (0..n * c_out * h_out * w_out)
            .map(|i| src.add(i).read())
            .collect()
    };

    // Naive reference over NCHW / OIHW layouts.
    let mut out_ref = vec![0i8; n * c_out * h_out * w_out];
    for ni in 0..n {
        for co in 0..c_out {
            for ho in 0..h_out {
                for wo in 0..w_out {
                    let mut acc: i32 = 0;
                    for ci in 0..c_in {
                        for ki in 0..kh {
                            for kj in 0..kw {
                                // Coordinates in the padded frame; taps that
                                // land in the padding contribute nothing.
                                let (row, col) = (ho * sh + ki, wo * sw + kj);
                                let inside =
                                    row >= ph && col >= pw && row - ph < h && col - pw < w_in;
                                if !inside {
                                    continue;
                                }
                                let (hi, wi) = (row - ph, col - pw);
                                let xv =
                                    xq[((ni * c_in) + ci) * h * w_in + hi * w_in + wi] as i32;
                                let wv = wq[((co * c_in) + ci) * kh * kw + ki * kw + kj] as i32;
                                acc += xv * wv;
                            }
                        }
                    }
                    let requant = ((acc as f32 * mult).round() as i32).clamp(-128, 127) as i8;
                    out_ref[((ni * c_out) + co) * h_out * w_out + ho * w_out + wo] = requant;
                }
            }
        }
    }

    for (i, (a, r)) in out_q.iter().zip(&out_ref).enumerate() {
        assert_eq!(a, r, "q_conv2d[{i}]: kernel {a} vs reference {r}");
    }
}
18068
/// Int8 `q_matmul` against a naive integer matmul.
///
/// Quantizes random `[m, k]` and `[k, n]` f32 tensors to i8, runs the kernel
/// with a zero bias, and requires an exact match with an `i32`-accumulating
/// reference that requantizes via `round(acc * mult)` clamped to `[-128, 127]`.
#[test]
fn q_matmul_matches_fake_quant_reference() {
    use rlx_ir::Philox4x32;
    let m = 3usize;
    let k = 8usize;
    let n = 5usize;
    let mut rng = Philox4x32::new(2031);

    let x_scale = 0.05f32;
    let w_scale = 0.03f32;
    let out_scale = 0.4f32;
    let mult = x_scale * w_scale / out_scale;
    let mut xf = vec![0f32; m * k];
    rng.fill_normal(&mut xf);
    let mut wf = vec![0f32; k * n];
    rng.fill_normal(&mut wf);
    let xq: Vec<i8> = xf
        .iter()
        .map(|&v| ((v / x_scale).round() as i32).clamp(-128, 127) as i8)
        .collect();
    let wq: Vec<i8> = wf
        .iter()
        .map(|&v| ((v / w_scale).round() as i32).clamp(-128, 127) as i8)
        .collect();
    let bias: Vec<i32> = vec![0i32; n];

    let mut g_q = Graph::new("qmm_direct");
    let xn = g_q.input("x", Shape::new(&[m, k], DType::I8));
    let wn = g_q.input("w", Shape::new(&[k, n], DType::I8));
    let bn = g_q.input("b", Shape::new(&[n], DType::I32));
    let out = g_q.q_matmul(xn, wn, bn, 0, 0, 0, mult, Shape::new(&[m, n], DType::I8));
    g_q.set_outputs(vec![out]);
    let plan = rlx_opt::memory::plan_memory(&g_q);
    let mut arena = crate::arena::Arena::from_plan(plan);
    let sched = compile_thunks(&g_q, &arena);

    let xn_off = arena.byte_offset(xn);
    let wn_off = arena.byte_offset(wn);
    let bn_off = arena.byte_offset(bn);
    let out_off = arena.byte_offset(out);
    let buf = arena.raw_buf_mut();
    // SAFETY (assumed, confirm against `Arena`): each byte offset is taken to
    // address a region large enough, and suitably aligned, for that node's data.
    unsafe {
        let p = buf.as_mut_ptr().add(xn_off) as *mut i8;
        for (i, &v) in xq.iter().enumerate() {
            *p.add(i) = v;
        }
        let p = buf.as_mut_ptr().add(wn_off) as *mut i8;
        for (i, &v) in wq.iter().enumerate() {
            *p.add(i) = v;
        }
        let p = buf.as_mut_ptr().add(bn_off) as *mut i32;
        for (i, &v) in bias.iter().enumerate() {
            *p.add(i) = v;
        }
    }
    execute_thunks(&sched, arena.raw_buf_mut());
    let out_q: Vec<i8> = unsafe {
        let p = arena.raw_buf().as_ptr().add(out_off) as *const i8;
        (0..m * n).map(|i| *p.add(i)).collect()
    };

    // Reference: row-major X[m, k] * W[k, n], i32 accumulation, then requantize.
    let mut out_ref = vec![0i8; m * n];
    for mi in 0..m {
        for ni in 0..n {
            let mut acc: i32 = 0;
            for ki in 0..k {
                acc += (xq[mi * k + ki] as i32) * (wq[ki * n + ni] as i32);
            }
            let r = (acc as f32 * mult).round() as i32;
            out_ref[mi * n + ni] = r.clamp(-128, 127) as i8;
        }
    }

    for (i, (a, r)) in out_q.iter().zip(&out_ref).enumerate() {
        assert_eq!(a, r, "q_matmul[{i}]: kernel {a} vs reference {r}");
    }
}
18161
/// Per-tensor quantize -> dequantize round trip on 64 f32 values.
///
/// The first two inputs are forced to +/-999 and must come back as the
/// saturated values `(127 - zp) * scale` and `(-128 - zp) * scale`. Every
/// other element must come back within `scale + 1e-5` of its input.
#[test]
fn quantize_dequantize_round_trip() {
    use rlx_ir::Philox4x32;
    let len = 64;
    let mut rng = Philox4x32::new(2027);
    let mut x = vec![0f32; len];
    rng.fill_normal(&mut x);
    // Far outside the representable range on both sides.
    x[0] = 999.0;
    x[1] = -999.0;

    let scale = 0.05f32;
    let zp = 3i32;

    let mut g = Graph::new("qdq");
    let xn = g.input("x", Shape::new(&[len], DType::F32));
    let q = g.quantize(xn, scale, zp);
    let dq = g.dequantize(q, scale, zp);
    g.set_outputs(vec![dq]);

    let plan = rlx_opt::memory::plan_memory(&g);
    let mut arena = crate::arena::Arena::from_plan(plan);
    let sched = compile_thunks(&g, &arena);
    let xn_off = arena.byte_offset(xn);
    let dq_off = arena.byte_offset(dq);
    let buf = arena.raw_buf_mut();
    // SAFETY (assumed, confirm against `Arena`): the offsets are taken to
    // address f32-aligned regions holding at least `len` elements.
    unsafe {
        let dst = buf.as_mut_ptr().add(xn_off) as *mut f32;
        for (i, &v) in x.iter().enumerate() {
            dst.add(i).write(v);
        }
    }
    execute_thunks(&sched, arena.raw_buf_mut());
    let out: Vec<f32> = unsafe {
        let src = arena.raw_buf().as_ptr().add(dq_off) as *const f32;
        (0..len).map(|i| src.add(i).read()).collect()
    };

    let sat_pos = (127 - zp) as f32 * scale;
    let sat_neg = (-128 - zp) as f32 * scale;
    assert!((out[0] - sat_pos).abs() < 1e-6, "+sat: {}", out[0]);
    assert!((out[1] - sat_neg).abs() < 1e-6, "-sat: {}", out[1]);

    for (i, (&before, &after)) in x.iter().zip(out.iter()).enumerate().skip(2) {
        assert!(
            (after - before).abs() <= scale + 1e-5,
            "qdq[{i}]: {} → {}, scale={scale}",
            before,
            after
        );
    }
}
18224
/// Per-channel quantize -> dequantize round trip over 4 channels of 5 values.
///
/// Each channel holds `[-mag, 0, mag, 1000*mag, -1000*mag]` with
/// `scale = mag / 127` and zero point 0. The first three entries must come
/// back within `scale + 1e-5`; the last two must saturate at `127 * scale`
/// and `-128 * scale`.
#[test]
fn quantize_per_channel_round_trip() {
    let c = 4usize;
    let inner = 5usize;
    let mags = [0.01f32, 0.5, 5.0, 50.0];

    let mut x = vec![0f32; c * inner];
    for (row, &mag) in x.chunks_exact_mut(inner).zip(mags.iter()) {
        for (slot, cell) in row.iter_mut().enumerate() {
            *cell = match slot {
                0 => -mag,
                1 => 0.0,
                2 => mag,
                3 => mag * 1000.0,
                _ => -mag * 1000.0,
            };
        }
    }
    let scales: Vec<f32> = mags.iter().map(|&mag| mag / 127.0).collect();
    let zps: Vec<i32> = vec![0i32; c];

    let mut g = Graph::new("qdq_pc");
    let xn = g.input("x", Shape::new(&[c, inner], DType::F32));
    let q = g.quantize_per_channel(xn, 0, scales.clone(), zps.clone());
    let dq = g.dequantize_per_channel(q, 0, scales.clone(), zps);
    g.set_outputs(vec![dq]);

    let plan = rlx_opt::memory::plan_memory(&g);
    let mut arena = crate::arena::Arena::from_plan(plan);
    let sched = compile_thunks(&g, &arena);
    let xn_off = arena.byte_offset(xn);
    let dq_off = arena.byte_offset(dq);
    let buf = arena.raw_buf_mut();
    // SAFETY (assumed, confirm against `Arena`): the offsets are taken to
    // address f32-aligned regions holding at least `c * inner` elements.
    unsafe {
        let dst = buf.as_mut_ptr().add(xn_off) as *mut f32;
        for (i, &v) in x.iter().enumerate() {
            dst.add(i).write(v);
        }
    }
    execute_thunks(&sched, arena.raw_buf_mut());
    let out: Vec<f32> = unsafe {
        let src = arena.raw_buf().as_ptr().add(dq_off) as *const f32;
        (0..c * inner).map(|i| src.add(i).read()).collect()
    };

    for (ci, &scale) in scales.iter().enumerate() {
        let base = ci * inner;

        // In-range entries survive the round trip up to one quantization step.
        for ii in 0..3 {
            let idx = base + ii;
            assert!(
                (out[idx] - x[idx]).abs() <= scale + 1e-5,
                "ch {ci} idx {ii}: {} vs {}",
                x[idx],
                out[idx]
            );
        }

        // Out-of-range entries clamp to the i8 limits.
        let sat_pos = 127.0 * scale;
        let sat_neg = -128.0 * scale;
        assert!(
            (out[base + 3] - sat_pos).abs() < 1e-5,
            "ch {ci} +sat: {}",
            out[base + 3]
        );
        assert!(
            (out[base + 4] - sat_neg).abs() < 1e-5,
            "ch {ci} -sat: {}",
            out[base + 4]
        );
    }
}
18307
/// `activation_backward` against a central finite difference, per activation.
///
/// For each listed kind a fresh graph is compiled and run; its output is
/// compared with `(f(x + eps) - f(x - eps)) / (2 * eps) * dy`, where `f` is a
/// scalar reference. `Log`, `Sqrt` and `Rsqrt` use inputs shifted to be at
/// least 0.5.
#[test]
fn activation_backward_matches_numerical_per_kind() {
    use rlx_ir::Philox4x32;
    use rlx_ir::op::Activation;
    let mut rng = Philox4x32::new(91);
    let len = 32;
    let mut x_pos = vec![0f32; len];
    rng.fill_normal(&mut x_pos);
    x_pos.iter_mut().for_each(|v| *v = v.abs() + 0.5);
    let mut x_any = vec![0f32; len];
    rng.fill_normal(&mut x_any);
    let mut dy = vec![0f32; len];
    rng.fill_normal(&mut dy);

    // Scalar forward reference for every activation kind. It captures nothing,
    // so it is built once rather than per case.
    let reference = |kind: Activation, x: f32| -> f32 {
        match kind {
            Activation::Sigmoid => 1.0 / (1.0 + (-x).exp()),
            Activation::Tanh => x.tanh(),
            Activation::Silu => x / (1.0 + (-x).exp()),
            Activation::Gelu => {
                const INV_SQRT2: f32 = 0.707_106_77;
                0.5 * x * (1.0 + erf_f32(x * INV_SQRT2))
            }
            Activation::GeluApprox => {
                const C: f32 = 0.797_884_6;
                const A: f32 = 0.044_715;
                let inner = C * (x + A * x * x * x);
                0.5 * x * (1.0 + inner.tanh())
            }
            Activation::Exp => x.exp(),
            Activation::Log => x.ln(),
            Activation::Sqrt => x.sqrt(),
            Activation::Rsqrt => 1.0 / x.sqrt(),
            Activation::Neg => -x,
            Activation::Relu => x.max(0.0),
            Activation::Abs => x.abs(),
            Activation::Round => x.round(),
            Activation::Sin => x.sin(),
            Activation::Cos => x.cos(),
            Activation::Tan => x.tan(),
            Activation::Atan => x.atan(),
        }
    };

    // (kind, input data, finite-difference step, absolute tolerance)
    for &(kind, x_data, eps, tol) in &[
        (Activation::Sigmoid, &x_any[..], 1e-3, 5e-3),
        (Activation::Tanh, &x_any[..], 1e-3, 5e-3),
        (Activation::Silu, &x_any[..], 1e-3, 5e-3),
        (Activation::Gelu, &x_any[..], 1e-3, 5e-3),
        (Activation::GeluApprox, &x_any[..], 1e-3, 5e-3),
        (Activation::Exp, &x_any[..], 1e-4, 5e-3),
        (Activation::Log, &x_pos[..], 1e-4, 5e-3),
        (Activation::Sqrt, &x_pos[..], 1e-4, 5e-3),
        (Activation::Rsqrt, &x_pos[..], 1e-4, 5e-3),
        (Activation::Neg, &x_any[..], 1e-3, 5e-4),
    ] {
        let mut g = Graph::new("act_bw");
        let xn = g.input("x", Shape::new(&[len], DType::F32));
        let dyn_ = g.input("dy", Shape::new(&[len], DType::F32));
        let dx = g.activation_backward(kind, xn, dyn_);
        g.set_outputs(vec![dx]);

        let plan = rlx_opt::memory::plan_memory(&g);
        let mut arena = crate::arena::Arena::from_plan(plan);
        let sched = compile_thunks(&g, &arena);

        let xn_off = arena.byte_offset(xn);
        let dyn_off = arena.byte_offset(dyn_);
        let dx_off = arena.byte_offset(dx);
        let buf = arena.raw_buf_mut();
        // SAFETY (assumed, confirm against `Arena`): the offsets are taken to
        // address f32-aligned regions holding at least `len` elements.
        unsafe {
            let x_dst = buf.as_mut_ptr().add(xn_off) as *mut f32;
            for (i, &v) in x_data.iter().enumerate() {
                x_dst.add(i).write(v);
            }
            let dy_dst = buf.as_mut_ptr().add(dyn_off) as *mut f32;
            for (i, &v) in dy.iter().enumerate() {
                dy_dst.add(i).write(v);
            }
        }
        execute_thunks(&sched, arena.raw_buf_mut());
        let analytical: Vec<f32> = unsafe {
            let src = arena.raw_buf().as_ptr().add(dx_off) as *const f32;
            (0..len).map(|i| src.add(i).read()).collect()
        };

        for i in 0..len {
            let xv = x_data[i];
            let rise = reference(kind, xv + eps) - reference(kind, xv - eps);
            let num = rise / (2.0 * eps) * dy[i];
            assert!(
                (analytical[i] - num).abs() < tol,
                "{kind:?}[{i}]: analytical {} vs numerical {num}",
                analytical[i]
            );
        }
    }
}
18421
/// Gradient of `sum(matmul(a, b))` with respect to a batched `[2, 4, 5]`
/// parameter `b`, compared with a central finite difference of a naive
/// host-side batched matmul.
#[test]
fn matmul_3d_gradient_matches_numerical() {
    use rlx_ir::Philox4x32;
    let batch = 2usize;
    let m = 3usize;
    let k = 4usize;
    let n = 5usize;
    let mut rng = Philox4x32::new(101);
    let mut a_data = vec![0f32; batch * m * k];
    rng.fill_normal(&mut a_data);
    let mut b_data = vec![0f32; batch * k * n];
    rng.fill_normal(&mut b_data);

    let f = DType::F32;
    let mut fwd = Graph::new("matmul_3d");
    let an = fwd.input("a", Shape::new(&[batch, m, k], f));
    let bp = fwd.param("b", Shape::new(&[batch, k, n], f));
    let mm = fwd.matmul(an, bp, Shape::new(&[batch, m, n], f));
    let loss = fwd.add_node(
        Op::Reduce {
            op: ReduceOp::Sum,
            axes: vec![0, 1, 2],
            keep_dim: false,
        },
        vec![mm],
        Shape::from_dims(&[], f),
    );
    fwd.set_outputs(vec![loss]);

    let bwd_graph = rlx_opt::autodiff::grad_with_loss(&fwd, &[bp]);
    let d_out = bwd_graph
        .nodes()
        .iter()
        .find(|node| matches!(&node.op, Op::Input { name } if name == "d_output"))
        .map(|node| node.id)
        .unwrap();

    let plan = rlx_opt::memory::plan_memory(&bwd_graph);
    let mut arena = crate::arena::Arena::from_plan(plan);
    let sched = compile_thunks(&bwd_graph, &arena);

    // NOTE(review): `an` / `bp` are ids from the forward graph but are used to
    // address the backward graph's arena — presumably `grad_with_loss` keeps
    // forward node ids stable; confirm.
    let seed = vec![1.0f32];
    let uploads = [(an, &a_data), (bp, &b_data), (d_out, &seed)];
    for &(id, data) in &uploads {
        let off = arena.byte_offset(id);
        let buf = arena.raw_buf_mut();
        // SAFETY (assumed, confirm against `Arena`): the offset is taken to
        // address an f32-aligned region sized for this node.
        unsafe {
            let dst = buf.as_mut_ptr().add(off) as *mut f32;
            for (i, &v) in data.iter().enumerate() {
                dst.add(i).write(v);
            }
        }
    }
    execute_thunks(&sched, arena.raw_buf_mut());

    // outputs[1] is read as the gradient with respect to `b`.
    let gb_id = bwd_graph.outputs[1];
    let g_b: Vec<f32> = unsafe {
        let off = arena.byte_offset(gb_id);
        let src = arena.raw_buf().as_ptr().add(off) as *const f32;
        (0..batch * k * n).map(|i| src.add(i).read()).collect()
    };

    // Host reference loss: naive batched matmul, then a sum over all outputs.
    let forward_loss = |b_vals: &[f32]| -> f32 {
        let mut out = Vec::with_capacity(batch * m * n);
        for bi in 0..batch {
            let a_blk = &a_data[bi * m * k..(bi + 1) * m * k];
            let b_blk = &b_vals[bi * k * n..(bi + 1) * k * n];
            for mi in 0..m {
                for ni in 0..n {
                    let mut acc = 0f32;
                    for ki in 0..k {
                        acc += a_blk[mi * k + ki] * b_blk[ki * n + ni];
                    }
                    out.push(acc);
                }
            }
        }
        out.iter().sum()
    };

    // Central difference, perturbing one entry of `b` at a time and restoring it.
    let eps = 1e-3f32;
    let mut probe = b_data.clone();
    let mut g_b_num = vec![0f32; b_data.len()];
    for (i, slot) in g_b_num.iter_mut().enumerate() {
        let saved = probe[i];
        probe[i] = saved + eps;
        let lp = forward_loss(&probe);
        probe[i] = saved - eps;
        let lm = forward_loss(&probe);
        probe[i] = saved;
        *slot = (lp - lm) / (2.0 * eps);
    }

    for (i, (a, n)) in g_b.iter().zip(&g_b_num).enumerate() {
        assert!(
            (a - n).abs() < 5e-3,
            "matmul_3d g_b[{i}]: analytical {a} vs numerical {n}"
        );
    }
}
18518
/// Gradient of `sum(softmax(x, axis = -1))` for a `[3, 5]` input, compared
/// element by element with a central finite difference of a host-side
/// max-shifted softmax.
#[test]
fn softmax_gradient_matches_numerical() {
    use rlx_ir::Philox4x32;
    let n = 3usize;
    let c = 5usize;
    let mut rng = Philox4x32::new(57);
    let mut x_data = vec![0f32; n * c];
    rng.fill_normal(&mut x_data);

    let f = DType::F32;
    let mut fwd = Graph::new("softmax_only");
    let xn = fwd.input("x", Shape::new(&[n, c], f));
    let sm = fwd.add_node(Op::Softmax { axis: -1 }, vec![xn], Shape::new(&[n, c], f));
    let loss = fwd.add_node(
        Op::Reduce {
            op: ReduceOp::Sum,
            axes: vec![0, 1],
            keep_dim: false,
        },
        vec![sm],
        Shape::from_dims(&[], f),
    );
    fwd.set_outputs(vec![loss]);

    let bwd_graph = rlx_opt::autodiff::grad_with_loss(&fwd, &[xn]);
    let d_out = bwd_graph
        .nodes()
        .iter()
        .find(|node| matches!(&node.op, Op::Input { name } if name == "d_output"))
        .map(|node| node.id)
        .unwrap();

    let plan = rlx_opt::memory::plan_memory(&bwd_graph);
    let mut arena = crate::arena::Arena::from_plan(plan);
    let sched = compile_thunks(&bwd_graph, &arena);

    // NOTE(review): `xn` is an id from the forward graph but is used to address
    // the backward graph's arena — presumably `grad_with_loss` keeps forward
    // node ids stable; confirm.
    let seed = vec![1.0f32];
    let uploads = [(xn, &x_data), (d_out, &seed)];
    for &(id, data) in &uploads {
        let off = arena.byte_offset(id);
        let buf = arena.raw_buf_mut();
        // SAFETY (assumed, confirm against `Arena`): the offset is taken to
        // address an f32-aligned region sized for this node.
        unsafe {
            let dst = buf.as_mut_ptr().add(off) as *mut f32;
            for (i, &v) in data.iter().enumerate() {
                dst.add(i).write(v);
            }
        }
    }
    execute_thunks(&sched, arena.raw_buf_mut());

    // outputs[1] is read as the gradient with respect to `x`.
    let g_x_id = bwd_graph.outputs[1];
    let g_x: Vec<f32> = unsafe {
        let off = arena.byte_offset(g_x_id);
        let src = arena.raw_buf().as_ptr().add(off) as *const f32;
        (0..n * c).map(|i| src.add(i).read()).collect()
    };

    // Host reference loss: row-wise softmax (shifted by the row maximum),
    // summed over every entry.
    let forward_loss = |x: &[f32]| -> f32 {
        let mut total = 0f32;
        for row in x.chunks_exact(c).take(n) {
            let m = row.iter().copied().fold(f32::NEG_INFINITY, f32::max);
            let denom: f32 = row.iter().map(|&v| (v - m).exp()).sum();
            for &v in row {
                total += (v - m).exp() / denom;
            }
        }
        total
    };

    // Central difference, perturbing one input at a time and restoring it.
    let eps = 1e-3f32;
    let mut probe = x_data.clone();
    for i in 0..x_data.len() {
        let saved = probe[i];
        probe[i] = saved + eps;
        let lp = forward_loss(&probe);
        probe[i] = saved - eps;
        let lm = forward_loss(&probe);
        probe[i] = saved;
        let num = (lp - lm) / (2.0 * eps);
        assert!(
            (g_x[i] - num).abs() < 5e-3,
            "softmax g_x[{i}]: analytical {} vs numerical {num}",
            g_x[i]
        );
    }
}
18614
/// Gradient check for `Op::LayerNorm` over the last axis.
///
/// The loss is `sum(layer_norm(x, gamma, beta))`. The gradients that
/// `grad_with_loss` produces for `x`, `gamma` and `beta` are compared with
/// central finite differences of a scalar reference implementation.
#[test]
fn layer_norm_gradient_matches_numerical() {
    use rlx_ir::Philox4x32;

    /// Perturbs every entry of `base` by `±step`, evaluates `loss_at` on the
    /// perturbed buffer, and asserts that the resulting central difference is
    /// within `5e-3` of `analytical[i]`.
    fn assert_matches_central_diff(
        label: &str,
        analytical: &[f32],
        base: &[f32],
        step: f32,
        loss_at: impl Fn(&[f32]) -> f32,
    ) {
        let mut probe = base.to_vec();
        for i in 0..probe.len() {
            let saved = probe[i];
            probe[i] = saved + step;
            let lp = loss_at(&probe);
            probe[i] = saved - step;
            let lm = loss_at(&probe);
            probe[i] = saved;
            let num = (lp - lm) / (2.0 * step);
            assert!(
                (analytical[i] - num).abs() < 5e-3,
                "ln {label}[{i}]: analytical {} vs numerical {num}",
                analytical[i]
            );
        }
    }

    let rows = 3usize;
    let h = 6usize;

    // Draw order matters for reproducibility: x, then gamma, then beta.
    let mut rng = Philox4x32::new(1009);
    let mut x_data = vec![0f32; rows * h];
    rng.fill_normal(&mut x_data);
    let mut g_data = vec![0f32; h];
    rng.fill_normal(&mut g_data);
    // Force every gamma to be at least 0.5.
    g_data.iter_mut().for_each(|v| *v = v.abs() + 0.5);
    let mut b_data = vec![0f32; h];
    rng.fill_normal(&mut b_data);
    let eps = 1e-5f32;

    let f = DType::F32;
    let mut fwd = Graph::new("ln_only");
    let xn = fwd.input("x", Shape::new(&[rows, h], f));
    let gp = fwd.param("gamma", Shape::new(&[h], f));
    let bp = fwd.param("beta", Shape::new(&[h], f));
    let ln = fwd.add_node(
        Op::LayerNorm { axis: -1, eps },
        vec![xn, gp, bp],
        Shape::new(&[rows, h], f),
    );
    let sum_all = Op::Reduce {
        op: ReduceOp::Sum,
        axes: vec![0, 1],
        keep_dim: false,
    };
    let loss = fwd.add_node(sum_all, vec![ln], Shape::from_dims(&[], f));
    fwd.set_outputs(vec![loss]);

    let bwd_graph = rlx_opt::autodiff::grad_with_loss(&fwd, &[xn, gp, bp]);
    let d_out = bwd_graph
        .nodes()
        .iter()
        .find_map(|node| match &node.op {
            Op::Input { name } if name == "d_output" => Some(node.id),
            _ => None,
        })
        .unwrap();

    let plan = rlx_opt::memory::plan_memory(&bwd_graph);
    let mut arena = crate::arena::Arena::from_plan(plan);
    let sched = compile_thunks(&bwd_graph, &arena);

    // Seed the inputs, the parameters and the upstream gradient (1.0).
    let d_seed = [1.0f32];
    let feeds: [(NodeId, &[f32]); 4] = [
        (xn, x_data.as_slice()),
        (gp, g_data.as_slice()),
        (bp, b_data.as_slice()),
        (d_out, &d_seed[..]),
    ];
    for (id, values) in feeds {
        let off = arena.byte_offset(id);
        let buf = arena.raw_buf_mut();
        // SAFETY: assumes the memory plan reserves at least `values.len()`
        // f32-aligned elements for `id` at `off` — TODO confirm against the
        // planner; this mirrors the raw element-wise writes used before.
        unsafe {
            let dst = buf.as_mut_ptr().add(off) as *mut f32;
            std::ptr::copy_nonoverlapping(values.as_ptr(), dst, values.len());
        }
    }
    execute_thunks(&sched, arena.raw_buf_mut());

    let read = |id: NodeId, n: usize| -> Vec<f32> {
        let off = arena.byte_offset(id);
        // SAFETY: same layout assumption as the writes above.
        unsafe {
            let src = arena.raw_buf().as_ptr().add(off) as *const f32;
            std::slice::from_raw_parts(src, n).to_vec()
        }
    };
    // outputs[0] is not read here; [1..=3] are read as dx, dgamma, dbeta.
    let dx_a = read(bwd_graph.outputs[1], rows * h);
    let dg_a = read(bwd_graph.outputs[2], h);
    let db_a = read(bwd_graph.outputs[3], h);

    // Scalar reference: per-row layer norm, accumulated in a single running
    // total, row by row and column by column.
    let forward_loss = |x: &[f32], g: &[f32], b: &[f32]| -> f32 {
        let mut total = 0f32;
        for row in x.chunks_exact(h).take(rows) {
            let mean = row.iter().sum::<f32>() / h as f32;
            let var = row.iter().map(|&v| (v - mean) * (v - mean)).sum::<f32>() / h as f32;
            let inv_std = 1.0 / (var + eps).sqrt();
            for ((&v, &gamma), &beta) in row.iter().zip(g).zip(b) {
                total += ((v - mean) * inv_std) * gamma + beta;
            }
        }
        total
    };
    let h_eps = 1e-3f32;

    assert_matches_central_diff("dx", &dx_a, &x_data, h_eps, |x| {
        forward_loss(x, &g_data, &b_data)
    });
    assert_matches_central_diff("dg", &dg_a, &g_data, h_eps, |g| {
        forward_loss(&x_data, g, &b_data)
    });
    assert_matches_central_diff("db", &db_a, &b_data, h_eps, |b| {
        forward_loss(&x_data, &g_data, b)
    });
}
18756
/// Gradient check for `x @ w + b` followed by softmax cross-entropy.
///
/// Note that the per-example losses are reduced with `ReduceOp::Sum` over the
/// batch axis, despite the `mean` in the test name. The backward graph's loss
/// must match a forward-only run, and the gradients for `w` and `b` must
/// match central finite differences of that forward-only graph.
#[test]
fn dense_sce_mean_gradient_matches_numerical() {
    use rlx_ir::Philox4x32;
    let (bs, k_in, c) = (4usize, 3usize, 5usize);

    // Draw order: x, w, b.
    let mut rng = Philox4x32::new(7);
    let mut x = vec![0f32; bs * k_in];
    rng.fill_normal(&mut x);
    let mut w_init = vec![0f32; k_in * c];
    rng.fill_normal(&mut w_init);
    let mut b_init = vec![0f32; c];
    rng.fill_normal(&mut b_init);
    let labels: Vec<f32> = (0..bs).map(|row| (row % c) as f32).collect();

    let f = DType::F32;
    let mut fwd = Graph::new("dense_sce");
    let xn = fwd.input("x", Shape::new(&[bs, k_in], f));
    let lb = fwd.input("labels", Shape::new(&[bs], f));
    let wp = fwd.param("w", Shape::new(&[k_in, c], f));
    let bp = fwd.param("b", Shape::new(&[c], f));
    let mm = fwd.matmul(xn, wp, Shape::new(&[bs, c], f));
    let logits = fwd.binary(BinaryOp::Add, mm, bp, Shape::new(&[bs, c], f));
    let loss_per = fwd.softmax_cross_entropy_with_logits(logits, lb);
    let sum_over_batch = Op::Reduce {
        op: ReduceOp::Sum,
        axes: vec![0],
        keep_dim: false,
    };
    let loss = fwd.add_node(sum_over_batch, vec![loss_per], Shape::from_dims(&[], f));
    fwd.set_outputs(vec![loss]);

    let bwd_graph = rlx_opt::autodiff::grad_with_loss(&fwd, &[wp, bp]);
    let d_out = bwd_graph
        .nodes()
        .iter()
        .find_map(|node| match &node.op {
            Op::Input { name } if name == "d_output" => Some(node.id),
            _ => None,
        })
        .expect("d_output input");

    let feeds = [
        (xn, &x),
        (lb, &labels),
        (wp, &w_init),
        (bp, &b_init),
        (d_out, &[1.0]),
    ];
    let (sched, mut arena) = prepare(&bwd_graph, &feeds);
    execute_thunks(&sched, arena.raw_buf_mut());

    // Backward-graph outputs are read as [loss, grad_w, grad_b].
    let outs = &bwd_graph.outputs;
    let (loss_id, gw_id, gb_id) = (outs[0], outs[1], outs[2]);
    let loss_actual = read_arena(&arena, loss_id, 1)[0];
    let gw_actual = read_arena(&arena, gw_id, k_in * c);
    let gb_actual = read_arena(&arena, gb_id, c);

    // Forward-only schedule used as the finite-difference oracle.
    let plan = rlx_opt::memory::plan_memory(&fwd);
    let mut fwd_arena = crate::arena::Arena::from_plan(plan);
    let fwd_sched = compile_thunks(&fwd, &fwd_arena);
    write_arena(&mut fwd_arena, xn, &x);
    write_arena(&mut fwd_arena, lb, &labels);

    let run_loss = |arena: &mut crate::arena::Arena, w: &[f32], b: &[f32]| -> f32 {
        write_arena(arena, wp, w);
        write_arena(arena, bp, b);
        execute_thunks(&fwd_sched, arena.raw_buf_mut());
        read_arena(arena, loss, 1)[0]
    };

    let loss_check = run_loss(&mut fwd_arena, &w_init, &b_init);
    assert!(
        (loss_actual - loss_check).abs() < 1e-4,
        "loss mismatch: bwd graph {loss_actual} vs fwd-only {loss_check}"
    );

    let eps = 1e-3f32;

    // Central differences w.r.t. each weight; the probe entry is restored
    // after every evaluation pair.
    let mut w_probe = w_init.clone();
    let gw_numerical: Vec<f32> = (0..w_init.len())
        .map(|i| {
            let saved = w_probe[i];
            w_probe[i] = saved + eps;
            let lp = run_loss(&mut fwd_arena, &w_probe, &b_init);
            w_probe[i] = saved - eps;
            let lm = run_loss(&mut fwd_arena, &w_probe, &b_init);
            w_probe[i] = saved;
            (lp - lm) / (2.0 * eps)
        })
        .collect();
    for (i, (a, n)) in gw_actual.iter().zip(&gw_numerical).enumerate() {
        assert!(
            (a - n).abs() < 5e-3,
            "grad_w[{i}]: analytical {a} vs numerical {n}"
        );
    }

    // Same sweep for the bias.
    let mut b_probe = b_init.clone();
    let gb_numerical: Vec<f32> = (0..b_init.len())
        .map(|i| {
            let saved = b_probe[i];
            b_probe[i] = saved + eps;
            let lp = run_loss(&mut fwd_arena, &w_init, &b_probe);
            b_probe[i] = saved - eps;
            let lm = run_loss(&mut fwd_arena, &w_init, &b_probe);
            b_probe[i] = saved;
            (lp - lm) / (2.0 * eps)
        })
        .collect();
    for (i, (a, n)) in gb_actual.iter().zip(&gb_numerical).enumerate() {
        assert!(
            (a - n).abs() < 5e-3,
            "grad_b[{i}]: analytical {a} vs numerical {n}"
        );
    }
}
18897
/// Gradient check for `x @ w` → softmax cross-entropy → `ReduceOp::Mean`
/// over the batch axis.
///
/// The analytical `grad_w` from the backward graph must match central finite
/// differences of the forward-only graph.
#[test]
fn dense_sce_mean_reduce_gradient_matches_numerical() {
    use rlx_ir::Philox4x32;
    let (bs, k_in, c) = (3usize, 2usize, 4usize);

    // Draw order: x, then w.
    let mut rng = Philox4x32::new(13);
    let mut x = vec![0f32; bs * k_in];
    rng.fill_normal(&mut x);
    let mut w_init = vec![0f32; k_in * c];
    rng.fill_normal(&mut w_init);
    let labels: Vec<f32> = (0..bs).map(|row| (row % c) as f32).collect();

    let f = DType::F32;
    let mut fwd = Graph::new("dense_sce_mean");
    let xn = fwd.input("x", Shape::new(&[bs, k_in], f));
    let lb = fwd.input("labels", Shape::new(&[bs], f));
    let wp = fwd.param("w", Shape::new(&[k_in, c], f));
    let mm = fwd.matmul(xn, wp, Shape::new(&[bs, c], f));
    let loss_per = fwd.softmax_cross_entropy_with_logits(mm, lb);
    let mean_over_batch = Op::Reduce {
        op: ReduceOp::Mean,
        axes: vec![0],
        keep_dim: false,
    };
    let loss = fwd.add_node(mean_over_batch, vec![loss_per], Shape::from_dims(&[], f));
    fwd.set_outputs(vec![loss]);

    let bwd_graph = rlx_opt::autodiff::grad_with_loss(&fwd, &[wp]);
    let d_out = bwd_graph
        .nodes()
        .iter()
        .find_map(|node| match &node.op {
            Op::Input { name } if name == "d_output" => Some(node.id),
            _ => None,
        })
        .unwrap();

    let feeds = [(xn, &x), (lb, &labels), (wp, &w_init), (d_out, &[1.0])];
    let (sched, mut arena) = prepare(&bwd_graph, &feeds);
    execute_thunks(&sched, arena.raw_buf_mut());

    // Backward-graph outputs are read as [loss, grad_w]. The loss is read
    // back but its value is deliberately not asserted on.
    let outs = &bwd_graph.outputs;
    let (loss_id, gw_id) = (outs[0], outs[1]);
    let _ = read_arena(&arena, loss_id, 1)[0];
    let gw_actual = read_arena(&arena, gw_id, k_in * c);

    // Forward-only schedule used as the finite-difference oracle.
    let plan = rlx_opt::memory::plan_memory(&fwd);
    let mut fwd_arena = crate::arena::Arena::from_plan(plan);
    let fwd_sched = compile_thunks(&fwd, &fwd_arena);
    write_arena(&mut fwd_arena, xn, &x);
    write_arena(&mut fwd_arena, lb, &labels);

    let run_loss = |arena: &mut crate::arena::Arena, w: &[f32]| -> f32 {
        write_arena(arena, wp, w);
        execute_thunks(&fwd_sched, arena.raw_buf_mut());
        read_arena(arena, loss, 1)[0]
    };

    let eps = 1e-3f32;
    let mut w_probe = w_init.clone();
    let gw_num: Vec<f32> = (0..w_init.len())
        .map(|i| {
            let saved = w_probe[i];
            w_probe[i] = saved + eps;
            let lp = run_loss(&mut fwd_arena, &w_probe);
            w_probe[i] = saved - eps;
            let lm = run_loss(&mut fwd_arena, &w_probe);
            w_probe[i] = saved;
            (lp - lm) / (2.0 * eps)
        })
        .collect();
    for (i, (a, n)) in gw_actual.iter().zip(&gw_num).enumerate() {
        assert!((a - n).abs() < 5e-3, "mean reduce grad_w[{i}]: {a} vs {n}");
    }
}
/// End-to-end gradient check for a tiny conv net:
/// conv → bias add → ReLU → 2×2 max-pool → flatten → dense → softmax
/// cross-entropy → mean.
///
/// The backward graph's loss must match a forward-only run, and every
/// parameter gradient (`wc`, `bc`, `wfc`, `bfc`) must agree with a central
/// finite difference of the forward-only graph.
#[test]
fn tinyconv_full_gradient_matches_numerical() {
    use rlx_ir::Philox4x32;

    // Geometry: one 1×6×6 image, two 3×3 filters, 2×2 pooling, 3 classes.
    let n = 1usize;
    let c_in = 1usize;
    let h = 6usize;
    let w_in = 6usize;
    let c_mid = 2usize;
    let kh = 3;
    let kw = 3;
    let h1 = h - kh + 1;
    let w1 = w_in - kw + 1;
    let h2 = h1 / 2;
    let w2 = w1 / 2;
    let flat = c_mid * h2 * w2;
    let num_classes = 3usize;

    // Draw order: x, wc, wfc, bfc (bc is deterministic).
    let mut rng = Philox4x32::new(31);
    let mut x = vec![0f32; n * c_in * h * w_in];
    rng.fill_normal(&mut x);
    let mut wc = vec![0f32; c_mid * c_in * kh * kw];
    rng.fill_normal(&mut wc);
    for v in wc.iter_mut() {
        *v *= 0.2;
    }
    // NOTE(review): the large positive conv bias presumably keeps the
    // pre-activations away from the ReLU kink so the finite differences stay
    // smooth — confirm.
    let bc: Vec<f32> = (0..c_mid).map(|i| 5.0 + 0.1 * i as f32).collect();
    let mut wfc = vec![0f32; flat * num_classes];
    rng.fill_normal(&mut wfc);
    for v in wfc.iter_mut() {
        *v *= 0.5;
    }
    let mut bfc = vec![0f32; num_classes];
    rng.fill_normal(&mut bfc);
    let labels: Vec<f32> = vec![1.0];

    let f = DType::F32;
    let mut fwd = Graph::new("tinyconv");
    let xn = fwd.input("x", Shape::new(&[n, c_in, h, w_in], f));
    let lb = fwd.input("labels", Shape::new(&[n], f));
    let wcp = fwd.param("wc", Shape::new(&[c_mid, c_in, kh, kw], f));
    let bcp = fwd.param("bc", Shape::new(&[c_mid], f));
    let wfp = fwd.param("wfc", Shape::new(&[flat, num_classes], f));
    let bfp = fwd.param("bfc", Shape::new(&[num_classes], f));

    let conv = fwd.add_node(
        Op::Conv {
            kernel_size: vec![kh, kw],
            stride: vec![1, 1],
            padding: vec![0, 0],
            dilation: vec![1, 1],
            groups: 1,
        },
        vec![xn, wcp],
        Shape::new(&[n, c_mid, h1, w1], f),
    );
    // The per-channel bias is reshaped to [1, C, 1, 1] and expanded to the
    // full conv output shape before the add.
    let bc_4d = fwd.add_node(
        Op::Reshape {
            new_shape: vec![1, c_mid as i64, 1, 1],
        },
        vec![bcp],
        Shape::new(&[1, c_mid, 1, 1], f),
    );
    let bc_expanded = fwd.add_node(
        Op::Expand {
            target_shape: vec![n as i64, c_mid as i64, h1 as i64, w1 as i64],
        },
        vec![bc_4d],
        Shape::new(&[n, c_mid, h1, w1], f),
    );
    let conv_b = fwd.binary(
        BinaryOp::Add,
        conv,
        bc_expanded,
        Shape::new(&[n, c_mid, h1, w1], f),
    );
    let relu = fwd.activation(Activation::Relu, conv_b, Shape::new(&[n, c_mid, h1, w1], f));
    let pool = fwd.add_node(
        Op::Pool {
            kind: ReduceOp::Max,
            kernel_size: vec![2, 2],
            stride: vec![2, 2],
            padding: vec![0, 0],
        },
        vec![relu],
        Shape::new(&[n, c_mid, h2, w2], f),
    );
    let flatn = fwd.add_node(
        Op::Reshape {
            new_shape: vec![n as i64, flat as i64],
        },
        vec![pool],
        Shape::new(&[n, flat], f),
    );
    let mm = fwd.matmul(flatn, wfp, Shape::new(&[n, num_classes], f));
    let logits = fwd.binary(BinaryOp::Add, mm, bfp, Shape::new(&[n, num_classes], f));
    let loss_per = fwd.softmax_cross_entropy_with_logits(logits, lb);
    let loss = fwd.add_node(
        Op::Reduce {
            op: ReduceOp::Mean,
            axes: vec![0],
            keep_dim: false,
        },
        vec![loss_per],
        Shape::from_dims(&[], f),
    );
    fwd.set_outputs(vec![loss]);

    let bwd_graph = rlx_opt::autodiff::grad_with_loss(&fwd, &[wcp, bcp, wfp, bfp]);
    let d_out = bwd_graph
        .nodes()
        .iter()
        .find(|n| matches!(&n.op, Op::Input { name } if name == "d_output"))
        .map(|n| n.id)
        .unwrap();

    let (sched, mut arena) = prepare(
        &bwd_graph,
        &[
            (xn, &x),
            (lb, &labels),
            (wcp, &wc),
            (bcp, &bc),
            (wfp, &wfc),
            (bfp, &bfc),
            (d_out, &[1.0]),
        ],
    );
    execute_thunks(&sched, arena.raw_buf_mut());

    // Backward-graph outputs are read as [loss, g_wc, g_bc, g_wfc, g_bfc].
    let outs = bwd_graph.outputs.clone();
    let loss_id = outs[0];
    let g_wc_id = outs[1];
    let g_bc_id = outs[2];
    let g_wfc_id = outs[3];
    let g_bfc_id = outs[4];
    let loss_actual = read_arena(&arena, loss_id, 1)[0];
    let g_wc = read_arena(&arena, g_wc_id, wc.len());
    let g_bc = read_arena(&arena, g_bc_id, bc.len());
    let g_wfc = read_arena(&arena, g_wfc_id, wfc.len());
    let g_bfc = read_arena(&arena, g_bfc_id, bfc.len());

    // Forward-only schedule used as the finite-difference oracle.
    let plan = rlx_opt::memory::plan_memory(&fwd);
    let mut fwd_arena = crate::arena::Arena::from_plan(plan);
    let fwd_sched = compile_thunks(&fwd, &fwd_arena);
    write_arena(&mut fwd_arena, xn, &x);
    write_arena(&mut fwd_arena, lb, &labels);

    let run_loss = |arena: &mut crate::arena::Arena,
                    wc: &[f32],
                    bc: &[f32],
                    wfc: &[f32],
                    bfc: &[f32]|
     -> f32 {
        write_arena(arena, wcp, wc);
        write_arena(arena, bcp, bc);
        write_arena(arena, wfp, wfc);
        write_arena(arena, bfp, bfc);
        execute_thunks(&fwd_sched, arena.raw_buf_mut());
        read_arena(arena, loss, 1)[0]
    };

    let loss_check = run_loss(&mut fwd_arena, &wc, &bc, &wfc, &bfc);
    assert!(
        (loss_actual - loss_check).abs() < 1e-4,
        "tinyconv loss mismatch: bwd {loss_actual} vs fwd {loss_check}"
    );

    let eps = 1e-3f32;

    // One central-difference sweep per parameter tensor. Each iteration
    // starts from a fresh clone, so the perturbed entry never needs to be
    // restored.
    for i in 0..wc.len() {
        let mut p = wc.clone();
        let s = p[i];
        p[i] = s + eps;
        let lp = run_loss(&mut fwd_arena, &p, &bc, &wfc, &bfc);
        p[i] = s - eps;
        let lm = run_loss(&mut fwd_arena, &p, &bc, &wfc, &bfc);
        let num = (lp - lm) / (2.0 * eps);
        assert!(
            (g_wc[i] - num).abs() < 5e-3,
            "g_wc[{i}]: {} vs {num}",
            g_wc[i]
        );
    }
    for i in 0..bc.len() {
        let mut p = bc.clone();
        let s = p[i];
        p[i] = s + eps;
        let lp = run_loss(&mut fwd_arena, &wc, &p, &wfc, &bfc);
        p[i] = s - eps;
        let lm = run_loss(&mut fwd_arena, &wc, &p, &wfc, &bfc);
        let num = (lp - lm) / (2.0 * eps);
        assert!(
            (g_bc[i] - num).abs() < 5e-3,
            "g_bc[{i}]: {} vs {num}",
            g_bc[i]
        );
    }
    for i in 0..wfc.len() {
        let mut p = wfc.clone();
        let s = p[i];
        p[i] = s + eps;
        let lp = run_loss(&mut fwd_arena, &wc, &bc, &p, &bfc);
        p[i] = s - eps;
        let lm = run_loss(&mut fwd_arena, &wc, &bc, &p, &bfc);
        let num = (lp - lm) / (2.0 * eps);
        assert!(
            (g_wfc[i] - num).abs() < 5e-3,
            "g_wfc[{i}]: {} vs {num}",
            g_wfc[i]
        );
    }
    for i in 0..bfc.len() {
        let mut p = bfc.clone();
        let s = p[i];
        p[i] = s + eps;
        let lp = run_loss(&mut fwd_arena, &wc, &bc, &wfc, &p);
        p[i] = s - eps;
        let lm = run_loss(&mut fwd_arena, &wc, &bc, &wfc, &p);
        let num = (lp - lm) / (2.0 * eps);
        assert!(
            (g_bfc[i] - num).abs() < 5e-3,
            "g_bfc[{i}]: {} vs {num}",
            g_bfc[i]
        );
    }
}
19277
/// A `narrow_` whose result feeds both a `rope` and a ReLU must still appear
/// as a `Thunk::Narrow` in the compiled schedule.
#[test]
fn narrow_rope_skips_when_narrow_has_multiple_consumers() {
    let f = DType::F32;
    let mut graph = Graph::new("nr_skip");
    let qkv = graph.input("qkv", Shape::new(&[16, 8, 192], f));
    let cos = graph.input("cos", Shape::new(&[16], f));
    let sin = graph.input("sin", Shape::new(&[16], f));
    let q = graph.narrow_(qkv, 2, 0, 64);
    // Two consumers of `q`: the rope and an unrelated activation.
    let q_rope = graph.rope(q, cos, sin, 16);
    let q_dup = graph.activation(rlx_ir::op::Activation::Relu, q, Shape::new(&[16, 8, 64], f));
    graph.set_outputs(vec![q_rope, q_dup]);

    let arena = crate::arena::Arena::from_plan(rlx_opt::memory::plan_memory(&graph));
    let sched = compile_thunks(&graph, &arena);

    let narrow_survives = sched
        .thunks
        .iter()
        .any(|thunk| matches!(thunk, Thunk::Narrow { .. }));
    assert!(
        narrow_survives,
        "Narrow with multiple consumers must NOT be fused away"
    );
}
19308
/// Running a `custom_fn` node whose body computes `x + 1` must produce the
/// body's result for the outer input.
#[test]
fn custom_fn_forward_inlines_body() {
    let shape = Shape::new(&[3], DType::F32);

    // Body graph: y = x + [1, 1, 1].
    let mut body = Graph::new("addone_body");
    let x = body.input("x", shape.clone());
    let ones_bytes: Vec<u8> = [1.0_f32; 3].iter().flat_map(|v| v.to_le_bytes()).collect();
    let one = body.add_node(Op::Constant { data: ones_bytes }, vec![], shape.clone());
    let y = body.binary(BinaryOp::Add, x, one, shape.clone());
    body.set_outputs(vec![y]);

    // Outer graph: a single custom_fn call, no VJP or JVP override.
    let mut outer = Graph::new("custom_fn_outer");
    let xin = outer.input("x_in", shape.clone());
    let cf = outer.custom_fn(vec![xin], body, None, None);
    outer.set_outputs(vec![cf]);

    let xs = vec![10.0_f32, 20.0, 30.0];
    let (sched, mut arena) = prepare(&outer, &[(xin, &xs)]);
    execute_thunks(&sched, arena.raw_buf_mut());

    let got = read_arena(&arena, cf, 3);
    assert_eq!(got, vec![11.0, 21.0, 31.0]);
}
19344
19345 fn find_named(graph: &Graph, want: &str) -> NodeId {
19347 for n in graph.nodes() {
19348 let name = match &n.op {
19349 Op::Input { name } | Op::Param { name } => Some(name.as_str()),
19350 _ => None,
19351 };
19352 if name == Some(want) {
19353 return n.id;
19354 }
19355 }
19356 panic!("no node named {want:?} in graph");
19357 }
19358
/// A user-supplied VJP graph must replace the gradient autodiff would derive
/// from the forward body.
///
/// The forward body is the identity, but the VJP returns `2 * d_output`, so
/// the expected gradient is 2.0 rather than 1.0.
#[test]
fn custom_fn_vjp_overrides_natural_gradient() {
    use rlx_opt::autodiff::grad_with_loss;
    let scalar = Shape::new(&[1], DType::F32);

    // Forward body: y = x.
    let mut fwd = Graph::new("id_fwd");
    let x = fwd.input("x", scalar.clone());
    fwd.set_outputs(vec![x]);

    // VJP body: dx = d_output * 2. The primal inputs are declared but unused.
    let mut vjp_g = Graph::new("id_vjp");
    let _x_p = vjp_g.input("x", scalar.clone());
    let _y_p = vjp_g.input("primal_output", scalar.clone());
    let dy = vjp_g.input("d_output", scalar.clone());
    let two_bytes: Vec<u8> = Vec::from(2.0_f32.to_le_bytes());
    let two = vjp_g.add_node(Op::Constant { data: two_bytes }, vec![], scalar.clone());
    let dx = vjp_g.binary(BinaryOp::Mul, dy, two, scalar.clone());
    vjp_g.set_outputs(vec![dx]);

    let mut outer = Graph::new("outer");
    let xp = outer.param("x", scalar.clone());
    let cf = outer.custom_fn(vec![xp], fwd, Some(vjp_g), None);
    outer.set_outputs(vec![cf]);

    let bwd = grad_with_loss(&outer, &[xp]);
    assert_eq!(bwd.outputs.len(), 2, "expect [loss, dx]");

    let x_node = find_named(&bwd, "x");
    let seed_node = find_named(&bwd, "d_output");
    let (sched, mut arena) = prepare(&bwd, &[(x_node, &[7.0]), (seed_node, &[1.0])]);
    execute_thunks(&sched, arena.raw_buf_mut());

    let loss = read_arena(&arena, bwd.outputs[0], 1);
    let dx_v = read_arena(&arena, bwd.outputs[1], 1);
    assert!((loss[0] - 7.0).abs() < 1e-6, "loss should be 7.0");
    assert!(
        (dx_v[0] - 2.0).abs() < 1e-6,
        "vjp override should yield dx=2.0, got {} (natural autodiff would give 1.0)",
        dx_v[0]
    );
}
19401
/// A two-input `custom_fn` computing `a * b`, with a hand-written VJP
/// (`da = b * dy`, `db = a * dy`), must yield the product-rule gradients.
#[test]
fn custom_fn_vjp_two_inputs_matches_mul_autodiff() {
    use rlx_opt::autodiff::grad_with_loss;
    let scalar = Shape::new(&[1], DType::F32);

    // Forward body: y = a * b.
    let mut fwd = Graph::new("mul_fwd");
    let a_f = fwd.input("a", scalar.clone());
    let b_f = fwd.input("b", scalar.clone());
    let y_f = fwd.binary(BinaryOp::Mul, a_f, b_f, scalar.clone());
    fwd.set_outputs(vec![y_f]);

    // VJP body: one gradient per forward input, in input order.
    let mut vjp_g = Graph::new("mul_vjp");
    let a_v = vjp_g.input("a", scalar.clone());
    let b_v = vjp_g.input("b", scalar.clone());
    let _y_v = vjp_g.input("primal_output", scalar.clone());
    let dy_v = vjp_g.input("d_output", scalar.clone());
    let da = vjp_g.binary(BinaryOp::Mul, b_v, dy_v, scalar.clone());
    let db = vjp_g.binary(BinaryOp::Mul, a_v, dy_v, scalar.clone());
    vjp_g.set_outputs(vec![da, db]);

    let mut outer = Graph::new("outer");
    let ap = outer.param("a", scalar.clone());
    let bp = outer.param("b", scalar.clone());
    let cf = outer.custom_fn(vec![ap, bp], fwd, Some(vjp_g), None);
    outer.set_outputs(vec![cf]);

    let bwd = grad_with_loss(&outer, &[ap, bp]);
    assert_eq!(bwd.outputs.len(), 3, "expect [loss, da, db]");

    let a_node = find_named(&bwd, "a");
    let b_node = find_named(&bwd, "b");
    let seed_node = find_named(&bwd, "d_output");
    let feeds = [(a_node, &[3.0]), (b_node, &[5.0]), (seed_node, &[1.0])];
    let (sched, mut arena) = prepare(&bwd, &feeds);
    execute_thunks(&sched, arena.raw_buf_mut());

    let loss = read_arena(&arena, bwd.outputs[0], 1);
    let da_v = read_arena(&arena, bwd.outputs[1], 1);
    let db_v = read_arena(&arena, bwd.outputs[2], 1);
    assert!((loss[0] - 15.0).abs() < 1e-5);
    assert!(
        (da_v[0] - 5.0).abs() < 1e-5,
        "da should be b=5.0, got {}",
        da_v[0]
    );
    assert!(
        (db_v[0] - 3.0).abs() < 1e-5,
        "db should be a=3.0, got {}",
        db_v[0]
    );
}
19455
/// A user-supplied JVP graph must replace the tangent forward-mode autodiff
/// would derive from the forward body.
///
/// The forward body is the identity, but the JVP returns `2 * tangent_0`, so
/// the expected output tangent is 2.0 rather than 1.0.
#[test]
fn custom_fn_jvp_overrides_natural_tangent() {
    use rlx_opt::autodiff_fwd::jvp;
    let scalar = Shape::new(&[1], DType::F32);

    // Forward body: y = x.
    let mut fwd = Graph::new("id_fwd");
    let x = fwd.input("x", scalar.clone());
    fwd.set_outputs(vec![x]);

    // JVP body: t_y = tangent_0 * 2. The primal input is declared but unused.
    let mut jvp_g = Graph::new("id_jvp");
    let _x_p = jvp_g.input("x", scalar.clone());
    let tx = jvp_g.input("tangent_0", scalar.clone());
    let two_bytes: Vec<u8> = Vec::from(2.0_f32.to_le_bytes());
    let two = jvp_g.add_node(Op::Constant { data: two_bytes }, vec![], scalar.clone());
    let ty = jvp_g.binary(BinaryOp::Mul, tx, two, scalar.clone());
    jvp_g.set_outputs(vec![ty]);

    let mut outer = Graph::new("outer");
    let xin = outer.input("x_in", scalar.clone());
    let cf = outer.custom_fn(vec![xin], fwd, None, Some(jvp_g));
    outer.set_outputs(vec![cf]);

    let fwd_g = jvp(&outer, &[xin]);
    assert_eq!(fwd_g.outputs.len(), 2, "expect [primal_y, tangent_y]");

    let x_node = find_named(&fwd_g, "x_in");
    let tangent_node = find_named(&fwd_g, "tangent_x_in");
    let (sched, mut arena) = prepare(&fwd_g, &[(x_node, &[7.0]), (tangent_node, &[1.0])]);
    execute_thunks(&sched, arena.raw_buf_mut());

    let y = read_arena(&arena, fwd_g.outputs[0], 1);
    let ty_v = read_arena(&arena, fwd_g.outputs[1], 1);
    assert!((y[0] - 7.0).abs() < 1e-6);
    assert!(
        (ty_v[0] - 2.0).abs() < 1e-6,
        "jvp override should yield t_y=2.0 (natural autodiff would give 1.0), got {}",
        ty_v[0]
    );
}
19496
/// Pins the storage contract of `DType::C64`: 8 bytes per element,
/// classified as complex and not as float, and a 2-element shape occupying
/// 16 bytes.
#[test]
fn c64_dtype_storage_layout() {
    let dtype = DType::C64;
    assert_eq!(
        dtype.size_bytes(),
        8,
        "C64 should be 8 bytes (f32 real + f32 imag)"
    );
    assert!(dtype.is_complex());
    assert!(!dtype.is_float());

    let pair = Shape::new(&[2], DType::C64);
    let total_bytes = pair.size_bytes().unwrap();
    assert_eq!(total_bytes, 16);
}
19515
19516 fn run_c64_binary(op: BinaryOp, a: &[(f32, f32)], b: &[(f32, f32)]) -> Vec<(f32, f32)> {
19523 let n = a.len();
19524 let s = Shape::new(&[n], DType::C64);
19525 let mut g = Graph::new("c64_bin");
19526 let in_a = g.input("a", s.clone());
19527 let in_b = g.input("b", s.clone());
19528 let out = g.binary(op, in_a, in_b, s.clone());
19529 g.set_outputs(vec![out]);
19530
19531 let plan = rlx_opt::memory::plan_memory(&g);
19532 let mut arena = crate::arena::Arena::from_plan(plan);
19533 let sched = compile_thunks(&g, &arena);
19534
19535 let a_off = arena.byte_offset(in_a);
19536 let b_off = arena.byte_offset(in_b);
19537 let out_off = arena.byte_offset(out);
19538 let buf = arena.raw_buf_mut();
19540 unsafe {
19541 let pa = buf.as_mut_ptr().add(a_off) as *mut f32;
19542 let pb = buf.as_mut_ptr().add(b_off) as *mut f32;
19543 for (i, &(re, im)) in a.iter().enumerate() {
19544 *pa.add(2 * i) = re;
19545 *pa.add(2 * i + 1) = im;
19546 }
19547 for (i, &(re, im)) in b.iter().enumerate() {
19548 *pb.add(2 * i) = re;
19549 *pb.add(2 * i + 1) = im;
19550 }
19551 }
19552 execute_thunks(&sched, arena.raw_buf_mut());
19553 let raw_out: Vec<f32> = unsafe {
19554 let p = arena.raw_buf().as_ptr().add(out_off) as *const f32;
19555 (0..(2 * n)).map(|i| *p.add(i)).collect()
19556 };
19557 (0..n)
19558 .map(|i| (raw_out[2 * i], raw_out[2 * i + 1]))
19559 .collect()
19560 }
19561
/// Asserts that two complex numbers, given as `(re, im)` pairs, agree to
/// within `tol` in both components.
///
/// # Panics
///
/// Panics with a message tagged by `label` when either component's absolute
/// difference is not strictly less than `tol` (a NaN difference also fails).
#[track_caller]
fn assert_close_c(got: (f32, f32), expected: (f32, f32), tol: f32, label: &str) {
    let (got_re, got_im) = got;
    let (want_re, want_im) = expected;
    let re_ok = (got_re - want_re).abs() < tol;
    let im_ok = (got_im - want_im).abs() < tol;
    if !(re_ok && im_ok) {
        panic!(
            "[{label}] got ({:+.4}, {:+.4}), expected ({:+.4}, {:+.4})",
            got_re, got_im, want_re, want_im
        );
    }
}
19575
/// Complex addition adds real and imaginary parts independently.
#[test]
fn c64_binary_add_matches_complex_arithmetic() {
    let lhs = [(1.0_f32, 2.0_f32), (3.0_f32, -1.0_f32)];
    let rhs = [(4.0_f32, -1.0_f32), (0.5_f32, 0.5_f32)];
    let sums = run_c64_binary(BinaryOp::Add, &lhs, &rhs);

    let expected = [((5.0, 1.0), "add[0]"), ((3.5, -0.5), "add[1]")];
    for (i, &(want, label)) in expected.iter().enumerate() {
        assert_close_c(sums[i], want, 1e-6, label);
    }
}
19584
/// (5 + 1i) − (2 + 3i) = 3 − 2i.
#[test]
fn c64_binary_sub_matches_complex_arithmetic() {
    let minuend = [(5.0_f32, 1.0_f32)];
    let subtrahend = [(2.0_f32, 3.0_f32)];
    let diff = run_c64_binary(BinaryOp::Sub, &minuend, &subtrahend);
    let want = (3.0, -2.0);
    assert_close_c(diff[0], want, 1e-6, "sub");
}
19592
/// (1 + 2i)(3 + 4i) = (3 − 8) + (4 + 6)i = −5 + 10i.
#[test]
fn c64_binary_mul_matches_complex_arithmetic() {
    let lhs = [(1.0_f32, 2.0_f32)];
    let rhs = [(3.0_f32, 4.0_f32)];
    let product = run_c64_binary(BinaryOp::Mul, &lhs, &rhs);
    let want = (-5.0, 10.0);
    assert_close_c(product[0], want, 1e-5, "mul");
}
19601
/// (1 + 2i) / (3 + 4i) = (1 + 2i)(3 − 4i) / 25 = (11 + 2i) / 25
/// = 0.44 + 0.08i.
#[test]
fn c64_binary_div_matches_complex_arithmetic() {
    let numerator = [(1.0_f32, 2.0_f32)];
    let denominator = [(3.0_f32, 4.0_f32)];
    let quotient = run_c64_binary(BinaryOp::Div, &numerator, &denominator);
    let want = (0.44, 0.08);
    assert_close_c(quotient[0], want, 1e-5, "div");
}
19612
/// Multiplying by 1 + 0i must return the input unchanged.
#[test]
fn c64_binary_mul_identity_one_is_no_op() {
    let values = [(3.5_f32, -1.25_f32), (-2.0_f32, 7.0_f32)];
    let ones = [(1.0_f32, 0.0_f32), (1.0_f32, 0.0_f32)];
    let product = run_c64_binary(BinaryOp::Mul, &values, &ones);

    let labels = ["mul·1[0]", "mul·1[1]"];
    for (i, &label) in labels.iter().enumerate() {
        assert_close_c(product[i], values[i], 1e-6, label);
    }
}
19622
/// 1 · i = i: multiplying the real unit by the imaginary unit yields (0, 1).
#[test]
fn c64_binary_mul_by_i_rotates_90_degrees() {
    let real_unit = [(1.0_f32, 0.0_f32)];
    let imag_unit = [(0.0_f32, 1.0_f32)];
    let rotated = run_c64_binary(BinaryOp::Mul, &real_unit, &imag_unit);
    let want = (0.0, 1.0);
    assert_close_c(rotated[0], want, 1e-6, "1·i");
}
19631
/// z / z = 1 + 0i for the non-zero inputs used here.
#[test]
fn c64_binary_div_by_self_gives_unity() {
    let values = [(2.5_f32, -1.5_f32), (-0.7_f32, 4.2_f32)];
    let ratios = run_c64_binary(BinaryOp::Div, &values, &values);

    let unity = (1.0, 0.0);
    let labels = ["div_self[0]", "div_self[1]"];
    for (i, &label) in labels.iter().enumerate() {
        assert_close_c(ratios[i], unity, 1e-5, label);
    }
}
19639
/// `BinaryOp::Max` on C64 operands must panic with the expected message.
#[test]
#[should_panic(expected = "C64: complex max/min/pow")]
fn c64_binary_max_is_rejected_at_lowering() {
    let lhs = [(1.0_f32, 2.0_f32)];
    let rhs = [(3.0_f32, 4.0_f32)];
    run_c64_binary(BinaryOp::Max, &lhs, &rhs);
}
19645
19646 fn run_c64_activation(act: Activation, a: &[(f32, f32)]) -> Vec<(f32, f32)> {
19647 let n = a.len();
19648 let s = Shape::new(&[n], DType::C64);
19649 let mut g = Graph::new("c64_act");
19650 let in_a = g.input("a", s.clone());
19651 let out = g.activation(act, in_a, s.clone());
19652 g.set_outputs(vec![out]);
19653 let plan = rlx_opt::memory::plan_memory(&g);
19654 let mut arena = crate::arena::Arena::from_plan(plan);
19655 let sched = compile_thunks(&g, &arena);
19656 let a_off = arena.byte_offset(in_a);
19657 let out_off = arena.byte_offset(out);
19658 let buf = arena.raw_buf_mut();
19659 unsafe {
19660 let pa = buf.as_mut_ptr().add(a_off) as *mut f32;
19661 for (i, &(re, im)) in a.iter().enumerate() {
19662 *pa.add(2 * i) = re;
19663 *pa.add(2 * i + 1) = im;
19664 }
19665 }
19666 execute_thunks(&sched, arena.raw_buf_mut());
19667 let raw: Vec<f32> = unsafe {
19668 let p = arena.raw_buf().as_ptr().add(out_off) as *const f32;
19669 (0..(2 * n)).map(|i| *p.add(i)).collect()
19670 };
19671 (0..n).map(|i| (raw[2 * i], raw[2 * i + 1])).collect()
19672 }
19673
/// Negation flips the sign of both the real and the imaginary part.
#[test]
fn c64_activation_neg_negates_both_components() {
    let values = [(3.5_f32, -1.25_f32), (-2.0_f32, 0.0_f32)];
    let negated = run_c64_activation(Activation::Neg, &values);

    let expected = [((-3.5, 1.25), "neg[0]"), ((2.0, 0.0), "neg[1]")];
    for (i, &(want, label)) in expected.iter().enumerate() {
        assert_close_c(negated[i], want, 1e-6, label);
    }
}
19681
/// exp(iπ) = −1 (Euler's identity) and exp(1) = e.
#[test]
fn c64_activation_exp_matches_euler() {
    use std::f32::consts::{E, PI};
    let values = [(0.0_f32, PI), (1.0_f32, 0.0_f32)];
    let exps = run_c64_activation(Activation::Exp, &values);

    let expected = [((-1.0, 0.0), "exp(iπ)"), ((E, 0.0), "exp(1)")];
    for (i, &(want, label)) in expected.iter().enumerate() {
        assert_close_c(exps[i], want, 1e-5, label);
    }
}
19691
/// Complex `Log` spot checks on the unit circle: log(1) = 0, log(i) = iπ/2
/// and log(-1) = iπ.
#[test]
fn c64_activation_log_matches_principal_branch() {
    use std::f32::consts::{FRAC_PI_2, PI};

    // (input, expected output, label passed to `assert_close_c`)
    let cases = [
        ((1.0_f32, 0.0_f32), (0.0_f32, 0.0_f32), "log(1)"),
        ((0.0_f32, 1.0_f32), (0.0_f32, FRAC_PI_2), "log(i)"),
        ((-1.0_f32, 0.0_f32), (0.0_f32, PI), "log(-1)"),
    ];
    let inputs: Vec<(f32, f32)> = cases.iter().map(|case| case.0).collect();

    let out = run_c64_activation(Activation::Log, &inputs);

    for (i, &(_, want, label)) in cases.iter().enumerate() {
        assert_close_c(out[i], want, 1e-5, label);
    }
}
19703
/// Complex `Sqrt` spot checks against known roots: sqrt(4) = 2 and
/// sqrt(3 + 4i) = 2 + i (whose square is the input again).
#[test]
fn c64_activation_sqrt_squared_recovers_input() {
    // (input, expected root, label passed to `assert_close_c`)
    let cases = [
        ((4.0_f32, 0.0_f32), (2.0_f32, 0.0_f32), "sqrt(4)"),
        ((3.0_f32, 4.0_f32), (2.0_f32, 1.0_f32), "sqrt(3+4i)"),
    ];
    let inputs: Vec<(f32, f32)> = cases.iter().map(|case| case.0).collect();

    let roots = run_c64_activation(Activation::Sqrt, &inputs);

    for (i, &(_, want, label)) in cases.iter().enumerate() {
        assert_close_c(roots[i], want, 1e-5, label);
    }
}
19714
/// `Activation::Relu` on a C64 tensor must be refused: the run is expected to
/// panic with a message containing "no natural complex extension".
#[test]
#[should_panic(expected = "no natural complex extension")]
fn c64_activation_relu_is_rejected_at_lowering() {
    let input = [(1.0_f32, 2.0_f32)];
    run_c64_activation(Activation::Relu, &input);
}
19720
19721 fn run_complex_norm_sq(z: &[(f32, f32)]) -> Vec<f32> {
19725 let n = z.len();
19726 let mut g = Graph::new("cns_fwd");
19727 let in_z = g.input("z", Shape::new(&[n], DType::C64));
19728 let out = g.complex_norm_sq(in_z);
19729 g.set_outputs(vec![out]);
19730 let plan = rlx_opt::memory::plan_memory(&g);
19731 let mut arena = crate::arena::Arena::from_plan(plan);
19732 let sched = compile_thunks(&g, &arena);
19733 let z_off = arena.byte_offset(in_z);
19734 let out_off = arena.byte_offset(out);
19735 let buf = arena.raw_buf_mut();
19736 unsafe {
19737 let pz = buf.as_mut_ptr().add(z_off) as *mut f32;
19738 for (i, &(re, im)) in z.iter().enumerate() {
19739 *pz.add(2 * i) = re;
19740 *pz.add(2 * i + 1) = im;
19741 }
19742 }
19743 execute_thunks(&sched, arena.raw_buf_mut());
19744 unsafe {
19745 let p = arena.raw_buf().as_ptr().add(out_off) as *const f32;
19746 (0..n).map(|i| *p.add(i)).collect()
19747 }
19748 }
19749
19750 fn run_complex_norm_sq_bwd(z: &[(f32, f32)], g: &[f32]) -> Vec<(f32, f32)> {
19752 let n = z.len();
19753 let mut gr = Graph::new("cns_bwd");
19754 let in_z = gr.input("z", Shape::new(&[n], DType::C64));
19755 let in_g = gr.input("g", Shape::new(&[n], DType::F32));
19756 let out = gr.complex_norm_sq_backward(in_z, in_g);
19757 gr.set_outputs(vec![out]);
19758 let plan = rlx_opt::memory::plan_memory(&gr);
19759 let mut arena = crate::arena::Arena::from_plan(plan);
19760 let sched = compile_thunks(&gr, &arena);
19761 let z_off = arena.byte_offset(in_z);
19762 let g_off = arena.byte_offset(in_g);
19763 let out_off = arena.byte_offset(out);
19764 let buf = arena.raw_buf_mut();
19765 unsafe {
19766 let pz = buf.as_mut_ptr().add(z_off) as *mut f32;
19767 let pg = buf.as_mut_ptr().add(g_off) as *mut f32;
19768 for (i, &(re, im)) in z.iter().enumerate() {
19769 *pz.add(2 * i) = re;
19770 *pz.add(2 * i + 1) = im;
19771 }
19772 for (i, &v) in g.iter().enumerate() {
19773 *pg.add(i) = v;
19774 }
19775 }
19776 execute_thunks(&sched, arena.raw_buf_mut());
19777 unsafe {
19778 let p = arena.raw_buf().as_ptr().add(out_off) as *const f32;
19779 (0..n).map(|i| (*p.add(2 * i), *p.add(2 * i + 1))).collect()
19780 }
19781 }
19782
/// Forward check of `complex_norm_sq` on hand-computed values:
/// |3 + 4i|² = 25, |1|² = 1 and |0|² = 0.
#[test]
fn complex_norm_sq_matches_textbook() {
    let z = [(3.0_f32, 4.0_f32), (1.0_f32, 0.0_f32), (0.0_f32, 0.0_f32)];
    // (expected value, absolute tolerance) for each element of `z`.
    let checks = [(25.0_f32, 1e-5_f32), (1.0_f32, 1e-6_f32), (0.0_f32, 1e-6_f32)];

    let norms = run_complex_norm_sq(&z);

    for (i, &(want, tol)) in checks.iter().enumerate() {
        assert!((norms[i] - want).abs() < tol);
    }
}
19794
/// With an upstream value of 1 for every element, the backward op must return
/// `z` itself (dz = g·z).
#[test]
fn complex_norm_sq_backward_matches_wirtinger_formula() {
    let z = [(3.0_f32, 4.0_f32), (1.5_f32, -2.5_f32)];
    let upstream = [1.0_f32, 1.0_f32];
    let labels = ["dz[0] = g·z[0]", "dz[1] = g·z[1]"];

    let dz = run_complex_norm_sq_bwd(&z, &upstream);

    for (i, &label) in labels.iter().enumerate() {
        assert_close_c(dz[i], z[i], 1e-6, label);
    }
}
19804
/// The backward result scales linearly with the upstream value, including a
/// negative one: 0.5·(2, 1) = (1, 0.5) and -2·(-1, 3) = (2, -6).
#[test]
fn complex_norm_sq_backward_scales_with_upstream() {
    // (z, upstream g, expected dz, label passed to `assert_close_c`)
    let cases = [
        ((2.0_f32, 1.0_f32), 0.5_f32, (1.0_f32, 0.5_f32), "g=0.5 · (2,1)"),
        ((-1.0_f32, 3.0_f32), -2.0_f32, (2.0_f32, -6.0_f32), "g=-2 · (-1,3)"),
    ];
    let z: Vec<(f32, f32)> = cases.iter().map(|case| case.0).collect();
    let g: Vec<f32> = cases.iter().map(|case| case.1).collect();

    let dz = run_complex_norm_sq_bwd(&z, &g);

    for (i, &(_, _, want, label)) in cases.iter().enumerate() {
        assert_close_c(dz[i], want, 1e-6, label);
    }
}
19814
/// End-to-end check of `Graph::custom_fn_multi`: a body graph producing
/// `x * x` and `2 * x` is embedded in an outer graph, and each output pulled
/// from the returned handle must hold the matching values after execution.
#[test]
fn custom_fn_multi_extracts_each_subgraph_output() {
    use rlx_ir::ops::special::MultiOutputHandle;

    // The value is discarded; presumably this only checks that the handle's
    // fields can be named from here — confirm the intent before removing.
    let _ = MultiOutputHandle {
        source: NodeId(0),
        sub_shapes: vec![],
        offsets: vec![],
    };

    // Body graph: outputs are [x * x, 2 * x] over a length-3 f32 vector.
    let mut body = Graph::new("multi_body");
    let s3 = Shape::new(&[3], DType::F32);
    let x = body.input("x", s3.clone());
    let x_sq = body.binary(BinaryOp::Mul, x, x, s3.clone());
    let two = body.add_node(
        Op::Constant {
            // Three little-endian f32 copies of 2.0, flattened to raw bytes.
            data: [2.0_f32; 3]
                .iter()
                .flat_map(|value| value.to_le_bytes())
                .collect(),
        },
        vec![],
        s3.clone(),
    );
    let two_x = body.binary(BinaryOp::Mul, two, x, s3.clone());
    body.set_outputs(vec![x_sq, two_x]);

    // Outer graph: feed `xin` through the body and expose both outputs.
    let mut outer = Graph::new("multi_outer");
    let in_x = outer.input("xin", s3.clone());
    let handle = outer.custom_fn_multi(vec![in_x], body);
    assert_eq!(handle.n_outputs(), 2);
    let out0 = handle.output(&mut outer, 0);
    let out1 = handle.output(&mut outer, 1);
    outer.set_outputs(vec![out0, out1]);

    let plan = rlx_opt::memory::plan_memory(&outer);
    let mut arena = crate::arena::Arena::from_plan(plan);
    let sched = compile_thunks(&outer, &arena);
    let xin_off = arena.byte_offset(in_x);
    let out0_off = arena.byte_offset(out0);
    let out1_off = arena.byte_offset(out1);

    let xs = [1.0_f32, 2.0, 3.0];
    // SAFETY: `xin_off` is the arena's own offset for `in_x`; assumes the plan
    // reserves three f32-aligned floats there — TODO confirm.
    unsafe {
        let dst = arena.raw_buf_mut().as_mut_ptr().add(xin_off) as *mut f32;
        for (slot, &v) in xs.iter().enumerate() {
            dst.add(slot).write(v);
        }
    }

    execute_thunks(&sched, arena.raw_buf_mut());

    // Reads three f32 values starting at byte offset `off` in the arena.
    let read3 = |off: usize| -> Vec<f32> {
        // SAFETY: only called with offsets the arena reported for length-3
        // f32 outputs; same size/alignment assumption as above.
        unsafe {
            let src = arena.raw_buf().as_ptr().add(off) as *const f32;
            (0..3).map(|slot| src.add(slot).read()).collect()
        }
    };
    let out0_v = read3(out0_off);
    let out1_v = read3(out1_off);

    for (i, &x_val) in xs.iter().enumerate() {
        assert!(
            (out0_v[i] - x_val * x_val).abs() < 1e-5,
            "out0[{i}] = {} != x² = {}",
            out0_v[i],
            x_val * x_val
        );
        assert!(
            (out1_v[i] - 2.0 * x_val).abs() < 1e-5,
            "out1[{i}] = {} != 2x = {}",
            out1_v[i],
            2.0 * x_val
        );
    }
}
19898
/// Cross-checks `complex_norm_sq` against forward finite differences at
/// z = 3 + 4i, then checks that twice the backward output (upstream = 1)
/// matches the analytic partials 2·re and 2·im.
#[test]
fn complex_norm_sq_gradient_matches_finite_difference() {
    let (re0, im0) = (3.0_f32, 4.0_f32);
    let z = [(re0, im0)];
    let eps = 1e-3_f32;

    // Forward value of the op at a single complex point.
    let norm_sq_at = |re: f32, im: f32| run_complex_norm_sq(&[(re, im)])[0];
    let v0 = norm_sq_at(re0, im0);

    // Partial derivative along the real axis.
    let fd_re = (norm_sq_at(re0 + eps, im0) - v0) / eps;
    let analytic_re = 2.0 * re0;
    assert!((fd_re - analytic_re).abs() < 1e-2);

    // Partial derivative along the imaginary axis.
    let fd_im = (norm_sq_at(re0, im0 + eps) - v0) / eps;
    let analytic_im = 2.0 * im0;
    assert!((fd_im - analytic_im).abs() < 1e-2);

    // The backward op yields g·z, i.e. half of the real-valued partials.
    let dz = run_complex_norm_sq_bwd(&z, &[1.0_f32]);
    let (dz_re, dz_im) = dz[0];
    assert!((2.0 * dz_re - analytic_re).abs() < 1e-5);
    assert!((2.0 * dz_im - analytic_im).abs() < 1e-5);
}
19927
/// Broadcast regression test: adds a `[bh, h, w, 1, w]` tensor to a
/// `[bh, h, w, h, w]` tensor, where the size-1 axis sits in the middle of the
/// shape, and compares every output element with a reference computed on the
/// host.
///
/// A NaN difference is treated as the worst possible error. With a plain
/// `d > max_diff` scan, NaN never compares greater, so an output full of NaNs
/// would leave `max_diff` at 0 and the test would pass.
#[test]
fn binary_full_5d_mid_singleton_broadcast() {
    let bh = 2usize;
    let h = 3usize;
    let w = 4usize;
    let f = DType::F32;
    let out_len = bh * h * w * h * w;
    let rhs_len = bh * h * w * w;

    let mut g = Graph::new("bcast_5d");
    let lhs = g.input("lhs", Shape::new(&[bh, h, w, h, w], f));
    let rhs = g.input("rhs", Shape::new(&[bh, h, w, 1, w], f));
    let out = g.binary(BinaryOp::Add, lhs, rhs, Shape::new(&[bh, h, w, h, w], f));
    g.set_outputs(vec![out]);

    let lhs_data: Vec<f32> = (0..out_len).map(|i| i as f32 * 0.01).collect();
    let rhs_data: Vec<f32> = (0..rhs_len).map(|i| (i as f32 + 100.0) * 0.01).collect();

    // Reference result. A flat lhs index decomposes as
    //   li = ((prefix * h) + hk) * w + wk,  prefix = (b * h + hq) * w + wq,
    // and the rhs drops the `hk` axis, so its flat index is prefix * w + wk.
    let expected: Vec<f32> = (0..out_len)
        .map(|li| {
            let wk = li % w;
            let prefix = li / (w * h);
            lhs_data[li] + rhs_data[prefix * w + wk]
        })
        .collect();

    let plan = rlx_opt::memory::plan_memory(&g);
    let mut arena = crate::arena::Arena::from_plan(plan);
    let sched = compile_thunks(&g, &arena);
    let lhs_off = arena.byte_offset(lhs);
    let rhs_off = arena.byte_offset(rhs);
    let out_off = arena.byte_offset(out);

    let buf = arena.raw_buf_mut();
    // SAFETY: the offsets are the arena's own offsets for `lhs` / `rhs`, and
    // the fills cover exactly the element counts of the declared shapes.
    // Assumes the plan reserves f32-aligned storage of that size — TODO confirm.
    unsafe {
        let lhs_ptr = buf.as_mut_ptr().add(lhs_off) as *mut f32;
        for (i, &v) in lhs_data.iter().enumerate() {
            lhs_ptr.add(i).write(v);
        }
        let rhs_ptr = buf.as_mut_ptr().add(rhs_off) as *mut f32;
        for (i, &v) in rhs_data.iter().enumerate() {
            rhs_ptr.add(i).write(v);
        }
    }

    execute_thunks(&sched, arena.raw_buf_mut());

    // SAFETY: same assumption for the output slot, read as `out_len` f32s.
    let actual: Vec<f32> = unsafe {
        let p = arena.raw_buf().as_ptr().add(out_off) as *const f32;
        (0..out_len).map(|i| p.add(i).read()).collect()
    };

    let mut max_diff = 0f32;
    let mut max_idx = 0;
    for (i, (&got, &want)) in actual.iter().zip(&expected).enumerate() {
        let d = (got - want).abs();
        if d.is_nan() {
            // Poison the running maximum so the assertion below fails, and
            // stop at the first offending element so it is the one reported.
            max_diff = d;
            max_idx = i;
            break;
        }
        if d > max_diff {
            max_diff = d;
            max_idx = i;
        }
    }
    assert!(
        max_diff < 1e-6,
        "5D mid-shape singleton broadcast wrong: max |Δ| = {max_diff} at idx {max_idx} \
         (actual={}, expected={})",
        actual[max_idx],
        expected[max_idx]
    );
}
20010
/// Smoke test that calls two kernels directly, bypassing the graph/thunk path.
///
/// * `layer_norm2d_nchw`: only checks that `out[0]` ends up more than 0.1 away
///   from 2.0, a loose sanity bound rather than an exact expected value.
/// * `conv_transpose2d_nchw`: a single input value 2.0 with the weights
///   `[1, 0, 0, 1]` must yield 2.0 at output indices 0 and 3.
///
/// NOTE(review): the meaning of the positional integer arguments is not
/// visible from this test; check the kernel signatures before changing them.
#[test]
fn layer_norm2d_and_conv_transpose2d_kernels() {
    let ln_input = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0];
    let ln_scale = [1.0, 1.0];
    let ln_shift = [0.0, 0.0];
    let mut out = vec![0f32; 8];
    crate::kernels::layer_norm2d_nchw(
        &ln_input, &ln_scale, &ln_shift, &mut out, 1, 2, 2, 2, 1e-5,
    );
    let mean0: f32 = (1.0 + 3.0) / 2.0;
    assert!((out[0] - mean0).abs() > 0.1);

    let ct_input = [2.0];
    let ct_weights = [1.0, 0.0, 0.0, 1.0];
    let mut up = vec![0f32; 4];
    crate::kernels::conv_transpose2d_nchw(
        &ct_input,
        &ct_weights,
        &mut up,
        1,
        1,
        1,
        1,
        1,
        2,
        2,
        2,
        2,
        2,
        2,
        0,
        0,
        1,
        1,
        1,
    );
    // Both diagonal taps of the 2x2 weight pattern must carry the input value.
    for corner in [0, 3] {
        assert!((up[corner] - 2.0).abs() < 1e-5);
    }
}
20053}