1#![allow(unsafe_op_in_unsafe_fn)]
24use crate::arena::Arena;
27use crate::op_registry::CpuKernel;
28use rlx_ir::op::{Activation, BinaryOp, CmpOp, ReduceOp};
29use rlx_ir::{Graph, NodeId, Op, Shape};
30use std::collections::HashMap;
31use std::sync::Arc;
32
/// One pre-lowered kernel invocation in a [`ThunkSchedule`].
///
/// Each variant carries the operands of a single kernel call with all sizes
/// already resolved to plain integers. Bare `usize` fields are buffer offsets:
/// the constructions visible in `compile_thunks` fill them from `node_offset`
/// (an arena byte offset, or `usize::MAX` for a node without a buffer), and
/// `thunk_read_offsets` reports the input ones as the offsets a thunk reads.
/// Some optional operands use offset `0` to mean "absent" (see
/// `FusedResidualLN::bias`, `GatedDeltaNet::state`, `AttentionBackward::mask`).
///
/// NOTE(review): the one-line summaries on the variants below are inferred
/// from variant and field names, not from the executor (which is outside this
/// view) — confirm against the kernel that consumes each variant.
#[derive(Clone)]
pub enum Thunk {
    /// No work. `compile_thunks` emits this for pure views and for
    /// input/param/constant nodes; `ThunkSchedule::strip_nops` removes it.
    Nop,
    /// Presumably an f32 matrix multiply `c = a * b` with `a: m x k`, `b: k x n`.
    Sgemm {
        a: usize,
        b: usize,
        c: usize,
        m: u32,
        k: u32,
        n: u32,
    },
    /// Presumably an f64 dense linear solve `a * x = b` (`a: n x n`, `nrhs` right-hand sides).
    DenseSolveF64 {
        a: usize,
        b: usize,
        x: usize,
        n: u32,
        nrhs: u32,
    },
    /// f32 counterpart of `DenseSolveF64` (by name).
    DenseSolveF32 {
        a: usize,
        b: usize,
        x: usize,
        n: u32,
        nrhs: u32,
    },
    /// Presumably `batch` independent f64 dense solves.
    BatchedDenseSolveF64 {
        a: usize,
        b: usize,
        x: usize,
        batch: u32,
        n: u32,
        nrhs: u32,
    },
    /// f32 counterpart of `BatchedDenseSolveF64` (by name).
    BatchedDenseSolveF32 {
        a: usize,
        b: usize,
        x: usize,
        batch: u32,
        n: u32,
        nrhs: u32,
    },
    /// Presumably `batch` independent f64 matrix multiplies.
    BatchedDgemmF64 {
        a: usize,
        b: usize,
        c: usize,
        batch: u32,
        m: u32,
        k: u32,
        n: u32,
    },
    /// Presumably `batch` independent f32 matrix multiplies.
    BatchedSgemm {
        a: usize,
        b: usize,
        c: usize,
        batch: u32,
        m: u32,
        k: u32,
        n: u32,
    },
    /// Presumably an f64 matrix multiply (same fields as `Sgemm`).
    Dgemm {
        a: usize,
        b: usize,
        c: usize,
        m: u32,
        k: u32,
        n: u32,
    },
    /// Presumably an f64 strided transpose/permute (same fields as `Transpose`).
    TransposeF64 {
        src: usize,
        dst: usize,
        in_total: u32,
        out_dims: Vec<u32>,
        in_strides: Vec<u32>,
    },
    /// Presumably an out-of-place f64 activation over `len` elements.
    ActivationF64 {
        src: usize,
        dst: usize,
        len: u32,
        kind: Activation,
    },
    /// Presumably |z|^2 of complex input, producing f32 output.
    ComplexNormSqF32 {
        src: usize,
        dst: usize,
        len: u32,
    },
    /// Presumably the backward pass of `ComplexNormSqF32`; reads `z` and `g`
    /// (per `thunk_read_offsets`).
    ComplexNormSqBackwardF32 {
        z: usize,
        g: usize,
        dz: usize,
        len: u32,
    },
    /// Presumably a complex conjugate over `len` elements.
    ConjugateC64 { src: usize, dst: usize, len: u32 },
    /// Presumably an out-of-place activation on complex data.
    ActivationC64 {
        src: usize,
        dst: usize,
        len: u32,
        kind: Activation,
    },
    /// Presumably an f64 sum-reduction over the middle (`reduced`) extent of an
    /// `outer x reduced x inner` view.
    ReduceSumF64 {
        src: usize,
        dst: usize,
        outer: u32,
        reduced: u32,
        inner: u32,
    },
    /// Presumably a copy of `len` f64 elements.
    CopyF64 { src: usize, dst: usize, len: u32 },
    /// Presumably an f64 element-wise binary op with broadcasting (same fields as `BinaryFull`).
    BinaryFullF64 {
        lhs: usize,
        rhs: usize,
        dst: usize,
        len: u32,
        lhs_len: u32,
        rhs_len: u32,
        op: BinaryOp,
        out_dims_bcast: Vec<u32>,
        bcast_lhs_strides: Vec<u32>,
        bcast_rhs_strides: Vec<u32>,
    },
    /// Presumably an f64 concatenation (same fields as `Concat`).
    ConcatF64 {
        dst: usize,
        outer: u32,
        inner: u32,
        total_axis: u32,
        /// First tuple element is the input's offset (per `thunk_read_offsets`);
        /// the `u32` is presumably that input's extent along the concat axis.
        inputs: Vec<(usize, u32)>,
    },
    /// Presumably a complex element-wise binary op with broadcasting (same fields as `BinaryFull`).
    BinaryFullC64 {
        lhs: usize,
        rhs: usize,
        dst: usize,
        len: u32,
        lhs_len: u32,
        rhs_len: u32,
        op: BinaryOp,
        out_dims_bcast: Vec<u32>,
        bcast_lhs_strides: Vec<u32>,
        bcast_rhs_strides: Vec<u32>,
    },
    /// Presumably a forward scan that runs the nested `body` schedule `length`
    /// times, threading a carry of `carry_bytes` bytes.
    Scan {
        body: Arc<ThunkSchedule>,
        body_init: Arc<Vec<u8>>,
        body_input_off: usize,
        body_output_off: usize,
        outer_init_off: usize,
        outer_final_off: usize,
        length: u32,
        carry_bytes: u32,
        save_trajectory: bool,
        /// The middle tuple element is the outer-buffer offset that is read
        /// (per `thunk_read_offsets`); the other two are presumably the
        /// body-side offset and a per-step byte count.
        xs_inputs: Arc<Vec<(usize, usize, u32)>>,
        bcast_inputs: Arc<Vec<(usize, usize, u32)>>,
        num_checkpoints: u32,
    },

    /// Presumably the backward pass of `Scan` with respect to the initial carry.
    /// Reads `outer_init_off`, `outer_traj_off`, `outer_upstream_off` and every
    /// offset in `outer_xs_offs` (per `thunk_read_offsets`).
    ScanBackward {
        body_vjp: Arc<ThunkSchedule>,
        body_init: Arc<Vec<u8>>,
        body_carry_in_off: usize,
        body_x_offs: Arc<Vec<usize>>,
        body_d_output_off: usize,
        body_dcarry_out_off: usize,
        outer_init_off: usize,
        outer_traj_off: usize,
        outer_upstream_off: usize,
        outer_xs_offs: Arc<Vec<(usize, u32)>>,
        outer_dinit_off: usize,
        length: u32,
        carry_bytes: u32,
        carry_elem_size: u32,
        save_trajectory: bool,
        num_checkpoints: u32,
        forward_body: Option<Arc<ThunkSchedule>>,
        forward_body_init: Option<Arc<Vec<u8>>>,
        forward_body_carry_in_off: usize,
        forward_body_output_off: usize,
        forward_body_x_offs: Arc<Vec<usize>>,
    },

    /// Presumably the backward pass of `Scan` with respect to the per-step
    /// inputs. Reads the same outer offsets as `ScanBackward`
    /// (per `thunk_read_offsets`).
    ScanBackwardXs {
        body_vjp: Arc<ThunkSchedule>,
        body_init: Arc<Vec<u8>>,
        body_carry_in_off: usize,
        body_x_offs: Arc<Vec<usize>>,
        body_d_output_off: usize,
        body_dcarry_out_off: usize,
        body_dxs_out_off: usize,
        outer_init_off: usize,
        outer_traj_off: usize,
        outer_upstream_off: usize,
        outer_xs_offs: Arc<Vec<(usize, u32)>>,
        outer_dxs_off: usize,
        length: u32,
        carry_bytes: u32,
        carry_elem_size: u32,
        per_step_bytes: u32,
        save_trajectory: bool,
        num_checkpoints: u32,
        forward_body: Option<Arc<ThunkSchedule>>,
        forward_body_init: Option<Arc<Vec<u8>>>,
        forward_body_carry_in_off: usize,
        forward_body_output_off: usize,
        forward_body_x_offs: Arc<Vec<usize>>,
    },
    /// Presumably a call into a nested `body` schedule with its own buffer.
    CustomFn {
        body: Arc<ThunkSchedule>,
        body_init: Arc<Vec<u8>>,
        /// The middle tuple element is the outer-buffer offset that is read
        /// (per `thunk_read_offsets`).
        inputs: Arc<Vec<(usize, usize, u32)>>,
        body_output_off: usize,
        outer_output_off: usize,
        out_bytes: u32,
    },
    /// Fused matmul + bias + optional activation. `compile_thunks` sets `n` to
    /// the output's last dimension, `m` to `output elements / n`, and `k` to
    /// `len(a) / m`; `a`, `w`, `bias` are node inputs 0..=2 and `c` is the output.
    FusedMmBiasAct {
        a: usize,
        w: usize,
        bias: usize,
        c: usize,
        m: u32,
        k: u32,
        n: u32,
        act: Option<Activation>,
    },
    /// Fused residual add + layer norm. `compile_thunks` sets `h` to the
    /// output's last dimension and `rows` to `output elements / h`.
    FusedResidualLN {
        x: usize,
        res: usize,
        /// Set to `0` by `compile_thunks` when `has_bias` is false.
        bias: usize,
        g: usize,
        b: usize,
        out: usize,
        rows: u32,
        h: u32,
        eps: f32,
        has_bias: bool,
    },
    /// Presumably the RMS-norm variant of `FusedResidualLN` (same fields).
    FusedResidualRmsNorm {
        x: usize,
        res: usize,
        bias: usize,
        g: usize,
        b: usize,
        out: usize,
        rows: u32,
        h: u32,
        eps: f32,
        has_bias: bool,
    },
    /// Presumably adds a length-`n` bias to each of `m` rows.
    BiasAdd {
        src: usize,
        bias: usize,
        dst: usize,
        m: u32,
        n: u32,
    },
    /// Presumably an element-wise binary op with broadcasting described by
    /// `out_dims_bcast` and the two stride vectors.
    BinaryFull {
        lhs: usize,
        rhs: usize,
        dst: usize,
        len: u32,
        lhs_len: u32,
        rhs_len: u32,
        op: BinaryOp,
        out_dims_bcast: Vec<u32>,
        bcast_lhs_strides: Vec<u32>,
        bcast_rhs_strides: Vec<u32>,
    },
    /// Presumably an activation applied in place over `len` elements at `data`.
    ActivationInPlace {
        data: usize,
        len: u32,
        act: Activation,
    },
    /// Presumably an embedding-style row gather from `table` by `idx`.
    Gather {
        table: usize,
        table_len: u32,
        idx: usize,
        dst: usize,
        num_idx: u32,
        trailing: u32,
    },
    /// Presumably a strided slice copy (`outer` chunks of `inner` elements of
    /// `elem_bytes` bytes each).
    Narrow {
        src: usize,
        dst: usize,
        outer: u32,
        src_stride: u32,
        dst_stride: u32,
        inner: u32,
        elem_bytes: u8,
    },
    /// Presumably a contiguous copy of `len` elements.
    Copy { src: usize, dst: usize, len: u32 },
    /// Presumably layer norm over rows of width `h` with scale `g` and shift `b`.
    LayerNorm {
        src: usize,
        g: usize,
        b: usize,
        dst: usize,
        rows: u32,
        h: u32,
        eps: f32,
    },
    /// Presumably group norm over an `n x c x h x w` tensor.
    GroupNorm {
        src: usize,
        g: usize,
        b: usize,
        dst: usize,
        n: u32,
        c: u32,
        h: u32,
        w: u32,
        num_groups: u32,
        eps: f32,
    },
    /// Presumably a channel-wise layer norm over an `n x c x h x w` tensor.
    LayerNorm2d {
        src: usize,
        g: usize,
        b: usize,
        dst: usize,
        n: u32,
        c: u32,
        h: u32,
        w: u32,
        eps: f32,
    },
    /// Presumably a 2-D transposed convolution (`k*` kernel, `s*` stride,
    /// `p*` padding, `d*` dilation).
    ConvTranspose2d {
        src: usize,
        weight: usize,
        dst: usize,
        n: u32,
        c_in: u32,
        h: u32,
        w_in: u32,
        c_out: u32,
        h_out: u32,
        w_out: u32,
        kh: u32,
        kw: u32,
        sh: u32,
        sw: u32,
        ph: u32,
        pw: u32,
        dh: u32,
        dw: u32,
        groups: u32,
    },
    /// Presumably a 2x nearest-neighbour upsample of an `n x c x h x w` tensor.
    ResizeNearest2x {
        src: usize,
        dst: usize,
        n: u32,
        c: u32,
        h: u32,
        w: u32,
    },
    /// Presumably a 2-D axial rotary position embedding.
    AxialRope2d {
        src: usize,
        dst: usize,
        batch: u32,
        seq: u32,
        hidden: u32,
        end_x: u32,
        end_y: u32,
        head_dim: u32,
        num_heads: u32,
        theta: f32,
        repeat_factor: u32,
    },
    /// Presumably RMS norm over rows of width `h`; reads `src`, `g` and `b`
    /// (per `thunk_read_offsets`).
    RmsNorm {
        src: usize,
        g: usize,
        b: usize,
        dst: usize,
        rows: u32,
        h: u32,
        eps: f32,
    },
    /// Presumably an in-place row-wise softmax over a `rows x cols` buffer.
    Softmax { data: usize, rows: u32, cols: u32 },
    /// Presumably a row-wise cumulative sum (`exclusive` selects the exclusive form).
    Cumsum {
        src: usize,
        dst: usize,
        rows: u32,
        cols: u32,
        exclusive: bool,
    },
    /// Presumably a selective state-space scan (Mamba-style); reads `x`,
    /// `delta`, `a`, `b`, `c` (per `thunk_read_offsets`).
    SelectiveScan {
        x: usize,
        delta: usize,
        a: usize,
        b: usize,
        c: usize,
        dst: usize,
        batch: u32,
        seq: u32,
        hidden: u32,
        state_size: u32,
    },

    /// Presumably a gated delta-net recurrence.
    GatedDeltaNet {
        q: usize,
        k: usize,
        v: usize,
        g: usize,
        beta: usize,
        /// Only counted as a read when non-zero (per `thunk_read_offsets`),
        /// so `0` presumably means "no state buffer".
        state: usize,
        dst: usize,
        batch: u32,
        seq: u32,
        heads: u32,
        state_size: u32,
    },

    /// Presumably a 1x1 convolution specialised as a matmul over `hw` positions.
    Conv2D1x1 {
        src: usize,
        weight: usize,
        dst: usize,
        n: u32,
        c_in: u32,
        c_out: u32,
        hw: u32,
    },

    /// Presumably a matmul against block-quantized weights with per-block
    /// `scale` and (when `is_asymmetric`) `zp`.
    DequantMatMul {
        x: usize,
        w_q: usize,
        scale: usize,
        zp: usize,
        dst: usize,
        m: u32,
        k: u32,
        n: u32,
        block_size: u32,
        is_asymmetric: bool,
    },

    /// Presumably a matmul against GGUF-quantized weights in format `scheme`.
    DequantMatMulGguf {
        x: usize,
        w_q: usize,
        dst: usize,
        m: u32,
        k: u32,
        n: u32,
        scheme: rlx_ir::quant::QuantScheme,
    },

    /// Presumably the 4-bit counterpart of `DequantMatMul` (same fields).
    DequantMatMulInt4 {
        x: usize,
        w_q: usize,
        scale: usize,
        zp: usize,
        dst: usize,
        m: u32,
        k: u32,
        n: u32,
        block_size: u32,
        is_asymmetric: bool,
    },

    /// Presumably a matmul against FP8 weights; `e5m2` selects E5M2 over E4M3.
    DequantMatMulFp8 {
        x: usize,
        w_q: usize,
        scale: usize,
        dst: usize,
        m: u32,
        k: u32,
        n: u32,
        e5m2: bool,
    },

    /// Presumably a matmul against NVFP4 weights with per-group scales plus a
    /// global scale.
    DequantMatMulNvfp4 {
        x: usize,
        w_q: usize,
        scale: usize,
        global_scale: usize,
        dst: usize,
        m: u32,
        k: u32,
        n: u32,
    },

    /// Presumably a LoRA matmul (base weight `w` plus rank-`r` adapters `a`, `b`
    /// scaled by `scale`).
    LoraMatMul {
        x: usize,
        w: usize,
        a: usize,
        b: usize,
        dst: usize,
        m: u32,
        k: u32,
        n: u32,
        r: u32,
        scale: f32,
    },
    /// Presumably token sampling from `batch` rows of `vocab` logits.
    Sample {
        logits: usize,
        dst: usize,
        batch: u32,
        vocab: u32,
        top_k: u32,
        top_p: f32,
        temperature: f32,
        seed: u64,
    },
    /// Presumably scaled-dot-product attention. `mask` is always reported as a
    /// read by `thunk_read_offsets`, unlike in `AttentionBackward`.
    Attention {
        q: usize,
        k: usize,
        v: usize,
        mask: usize,
        out: usize,
        batch: u32,
        seq: u32,
        kv_seq: u32,
        heads: u32,
        head_dim: u32,
        mask_kind: rlx_ir::op::MaskKind,
        q_row_stride: u32,
        k_row_stride: u32,
        v_row_stride: u32,
        bhsd: bool,
    },
    /// Presumably the backward pass of `Attention`; `wrt` presumably selects
    /// which gradient is produced.
    AttentionBackward {
        q: usize,
        k: usize,
        v: usize,
        dy: usize,
        /// Only counted as a read when non-zero (per `thunk_read_offsets`).
        mask: usize,
        out: usize,
        batch: u32,
        seq: u32,
        kv_seq: u32,
        heads: u32,
        head_dim: u32,
        mask_kind: rlx_ir::op::MaskKind,
        wrt: rlx_ir::op::AttentionBwdWrt,
        bhsd: bool,
    },
    /// Presumably a rotary position embedding using precomputed `cos`/`sin` tables.
    Rope {
        src: usize,
        cos: usize,
        sin: usize,
        dst: usize,
        batch: u32,
        seq: u32,
        hidden: u32,
        head_dim: u32,
        n_rot: u32,
        cos_len: u32,
        src_row_stride: u32,
    },
    /// Presumably a fused attention block (QKV projection, optional RoPE,
    /// attention, output projection).
    FusedAttnBlock {
        hidden: usize,
        qkv_w: usize,
        out_w: usize,
        mask: usize,
        out: usize,
        qkv_b: usize,
        out_b: usize,
        cos: usize,
        sin: usize,
        cos_len: u32,
        batch: u32,
        seq: u32,
        hs: u32,
        nh: u32,
        dh: u32,
        has_bias: bool,
        has_rope: bool,
    },
    /// Presumably a fully fused BERT encoder layer.
    FusedBertLayer {
        hidden: usize,
        qkv_w: usize,
        qkv_b: usize,
        out_w: usize,
        out_b: usize,
        mask: usize,
        ln1_g: usize,
        ln1_b: usize,
        eps1: f32,
        fc1_w: usize,
        fc1_b: usize,
        fc2_w: usize,
        fc2_b: usize,
        ln2_g: usize,
        ln2_b: usize,
        eps2: f32,
        out: usize,
        batch: u32,
        seq: u32,
        hs: u32,
        nh: u32,
        dh: u32,
        int_dim: u32,
    },
    /// Presumably a fully fused Nomic-BERT encoder layer.
    FusedNomicLayer {
        hidden: usize,
        qkv_w: usize,
        out_w: usize,
        mask: usize,
        cos: usize,
        sin: usize,
        cos_len: u32,
        ln1_g: usize,
        ln1_b: usize,
        eps1: f32,
        fc11_w: usize,
        fc12_w: usize,
        fc2_w: usize,
        ln2_g: usize,
        ln2_b: usize,
        eps2: f32,
        out: usize,
        batch: u32,
        seq: u32,
        hs: u32,
        nh: u32,
        dh: u32,
        int_dim: u32,
    },
    /// Presumably a fused SwiGLU over a buffer holding gate and up halves.
    FusedSwiGLU {
        src: usize,
        dst: usize,
        n_half: u32,
        total: u32,
        gate_first: bool,
    },
    /// Presumably a concatenation along one axis of an `outer x axis x inner` view.
    Concat {
        dst: usize,
        outer: u32,
        inner: u32,
        total_axis: u32,
        /// First tuple element is the input's offset (per `thunk_read_offsets`);
        /// the `u32` is presumably that input's extent along the concat axis.
        inputs: Vec<(usize, u32)>,
    },
    /// Presumably an element-wise comparison over `len` elements.
    Compare {
        lhs: usize,
        rhs: usize,
        dst: usize,
        len: u32,
        op: CmpOp,
    },
    /// Presumably a reduction with `op` over the middle (`reduced`) extent of an
    /// `outer x reduced x inner` view.
    Reduce {
        src: usize,
        dst: usize,
        outer: u32,
        reduced: u32,
        inner: u32,
        op: ReduceOp,
    },
    /// Presumably a top-`k` selection along an axis of length `axis_dim`.
    TopK {
        src: usize,
        dst: usize,
        outer: u32,
        axis_dim: u32,
        k: u32,
    },
    /// Presumably a mixture-of-experts matmul where `expert_idx` selects the
    /// weight matrix per row.
    GroupedMatMul {
        input: usize,
        weight: usize,
        expert_idx: usize,
        dst: usize,
        m: u32,
        k_dim: u32,
        n: u32,
        num_experts: u32,
    },
    /// Presumably `GroupedMatMul` against GGUF-quantized expert weights.
    DequantGroupedMatMulGguf {
        input: usize,
        w_q: usize,
        expert_idx: usize,
        dst: usize,
        m: u32,
        k_dim: u32,
        n: u32,
        num_experts: u32,
        scheme: rlx_ir::quant::QuantScheme,
    },
    /// Presumably dequantizes GGUF-quantized expert weights into `dst`.
    DequantMoEWeightsGguf {
        w_q: usize,
        dst: usize,
        k_dim: u32,
        n: u32,
        num_experts: u32,
        scheme: rlx_ir::quant::QuantScheme,
    },
    /// Presumably accumulates `updates` rows into `dst` at `indices`.
    ScatterAdd {
        updates: usize,
        indices: usize,
        dst: usize,
        num_updates: u32,
        out_dim: u32,
        trailing: u32,
    },
    /// Presumably an element-wise select between `on_true` and `on_false` by `cond`.
    Where {
        cond: usize,
        on_true: usize,
        on_false: usize,
        dst: usize,
        len: u32,
    },
    /// Presumably a strided transpose/permute described by output dims and
    /// input strides.
    Transpose {
        src: usize,
        dst: usize,
        in_total: u32,
        out_dims: Vec<u32>,
        in_strides: Vec<u32>,
    },
    /// Presumably a gather along an arbitrary axis of an
    /// `outer x axis_dim x trailing` view.
    GatherAxis {
        table: usize,
        idx: usize,
        dst: usize,
        outer: u32,
        axis_dim: u32,
        num_idx: u32,
        trailing: u32,
    },
    /// Presumably 2-D pooling; `kind` presumably selects max/mean.
    Pool2D {
        src: usize,
        dst: usize,
        n: u32,
        c: u32,
        h: u32,
        w: u32,
        h_out: u32,
        w_out: u32,
        kh: u32,
        kw: u32,
        sh: u32,
        sw: u32,
        ph: u32,
        pw: u32,
        kind: ReduceOp,
    },
    /// Presumably a general 2-D convolution (`k*` kernel, `s*` stride,
    /// `p*` padding, `d*` dilation).
    Conv2D {
        src: usize,
        weight: usize,
        dst: usize,
        n: u32,
        c_in: u32,
        h: u32,
        w: u32,
        c_out: u32,
        h_out: u32,
        w_out: u32,
        kh: u32,
        kw: u32,
        sh: u32,
        sw: u32,
        ph: u32,
        pw: u32,
        dh: u32,
        dw: u32,
        groups: u32,
    },

    /// Presumably an integer-quantized matmul with zero points and a
    /// requantization multiplier `mult`.
    QMatMul {
        x: usize,
        w: usize,
        bias: usize,
        out: usize,
        m: u32,
        k: u32,
        n: u32,
        x_zp: i32,
        w_zp: i32,
        out_zp: i32,
        mult: f32,
    },

    /// Presumably an integer-quantized 2-D convolution (see `Conv2D` and `QMatMul`).
    QConv2d {
        x: usize,
        w: usize,
        bias: usize,
        out: usize,
        n: u32,
        c_in: u32,
        h: u32,
        w_in: u32,
        c_out: u32,
        h_out: u32,
        w_out: u32,
        kh: u32,
        kw: u32,
        sh: u32,
        sw: u32,
        ph: u32,
        pw: u32,
        dh: u32,
        dw: u32,
        groups: u32,
        x_zp: i32,
        w_zp: i32,
        out_zp: i32,
        mult: f32,
    },

    /// Presumably float-to-integer quantization with per-channel `scales` and
    /// `zero_points` along `chan_axis`.
    Quantize {
        x: usize,
        q: usize,
        len: u32,
        chan_axis: u32,
        chan_dim: u32,
        inner: u32,
        scales: Vec<f32>,
        zero_points: Vec<i32>,
    },

    /// Presumably the inverse of `Quantize`.
    Dequantize {
        q: usize,
        x: usize,
        len: u32,
        chan_axis: u32,
        chan_dim: u32,
        inner: u32,
        scales: Vec<f32>,
        zero_points: Vec<i32>,
    },

    /// Presumably a quantize-dequantize round trip used for quantization-aware
    /// training.
    FakeQuantize {
        x: usize,
        out: usize,
        len: u32,
        chan_axis: u32,
        chan_dim: u32,
        inner: u32,
        bits: u8,
        ste: rlx_ir::op::SteKind,
        scale_mode: rlx_ir::op::ScaleMode,
        state_off: Option<usize>,
    },

    /// Presumably the backward pass of `FakeQuantize`.
    FakeQuantizeBackward {
        x: usize,
        dy: usize,
        dx: usize,
        len: u32,
        chan_axis: u32,
        chan_dim: u32,
        inner: u32,
        bits: u8,
        ste: rlx_ir::op::SteKind,
    },

    /// Presumably fake quantization with a learned scale stored at `scale_off`.
    FakeQuantizeLSQ {
        x: usize,
        scale_off: usize,
        out: usize,
        len: u32,
        chan_axis: u32,
        chan_dim: u32,
        inner: u32,
        bits: u8,
    },

    /// Presumably the gradient of `FakeQuantizeLSQ` with respect to its input.
    FakeQuantizeLSQBackwardX {
        x: usize,
        scale_off: usize,
        dy: usize,
        dx: usize,
        len: u32,
        chan_axis: u32,
        chan_dim: u32,
        inner: u32,
        bits: u8,
    },

    /// Presumably the gradient of `FakeQuantizeLSQ` with respect to its scale.
    FakeQuantizeLSQBackwardScale {
        x: usize,
        scale_off: usize,
        dy: usize,
        dscale: usize,
        len: u32,
        chan_axis: u32,
        chan_dim: u32,
        inner: u32,
        bits: u8,
    },

    /// Presumably the ReLU gradient.
    ReluBackward {
        x: usize,
        dy: usize,
        dx: usize,
        len: u32,
    },
    /// f64 counterpart of `ReluBackward` (by name).
    ReluBackwardF64 {
        x: usize,
        dy: usize,
        dx: usize,
        len: u32,
    },

    /// Presumably the gradient of activation `kind`.
    ActivationBackward {
        x: usize,
        dy: usize,
        dx: usize,
        len: u32,
        kind: Activation,
    },
    /// f64 counterpart of `ActivationBackward` (by name).
    ActivationBackwardF64 {
        x: usize,
        dy: usize,
        dx: usize,
        len: u32,
        kind: Activation,
    },

    /// Presumably the layer-norm gradient with respect to the input.
    LayerNormBackwardInput {
        x: usize,
        gamma: usize,
        dy: usize,
        dx: usize,
        rows: u32,
        h: u32,
        eps: f32,
    },

    /// Presumably the layer-norm gradient with respect to gamma.
    LayerNormBackwardGamma {
        x: usize,
        dy: usize,
        dgamma: usize,
        rows: u32,
        h: u32,
        eps: f32,
    },

    /// Presumably the RMS-norm gradient with respect to the input.
    RmsNormBackwardInput {
        x: usize,
        gamma: usize,
        beta: usize,
        dy: usize,
        dx: usize,
        rows: u32,
        h: u32,
        eps: f32,
    },
    /// Presumably the RMS-norm gradient with respect to gamma.
    RmsNormBackwardGamma {
        x: usize,
        gamma: usize,
        beta: usize,
        dy: usize,
        dgamma: usize,
        rows: u32,
        h: u32,
        eps: f32,
    },
    /// Presumably the RMS-norm gradient with respect to beta.
    RmsNormBackwardBeta {
        x: usize,
        gamma: usize,
        beta: usize,
        dy: usize,
        dbeta: usize,
        rows: u32,
        h: u32,
        eps: f32,
    },
    /// Presumably the backward pass of `Rope`.
    RopeBackward {
        dy: usize,
        cos: usize,
        sin: usize,
        dx: usize,
        batch: u32,
        seq: u32,
        hidden: u32,
        head_dim: u32,
        n_rot: u32,
        cos_len: u32,
    },
    /// Presumably the backward pass of `Cumsum`.
    CumsumBackward {
        dy: usize,
        dx: usize,
        rows: u32,
        cols: u32,
        exclusive: bool,
    },
    /// Presumably the backward pass of `GatherAxis` (scatter of `dy` by `indices`).
    GatherBackward {
        dy: usize,
        indices: usize,
        dst: usize,
        outer: u32,
        axis_dim: u32,
        num_idx: u32,
        trailing: u32,
    },

    /// Presumably the group-norm gradient with respect to the input.
    GroupNormBackwardInput {
        x: usize,
        gamma: usize,
        beta: usize,
        dy: usize,
        dx: usize,
        n: u32,
        c: u32,
        h: u32,
        w: u32,
        num_groups: u32,
        eps: f32,
    },
    /// Presumably the group-norm gradient with respect to gamma.
    GroupNormBackwardGamma {
        x: usize,
        dy: usize,
        dgamma: usize,
        n: u32,
        c: u32,
        h: u32,
        w: u32,
        num_groups: u32,
        eps: f32,
    },
    /// Presumably the group-norm gradient with respect to beta.
    GroupNormBackwardBeta {
        dy: usize,
        dbeta: usize,
        n: u32,
        c: u32,
        h: u32,
        w: u32,
    },

    /// Presumably the max-pool gradient (routes `dy` to the argmax positions of `x`).
    MaxPool2dBackward {
        x: usize,
        dy: usize,
        dx: usize,
        n: u32,
        c: u32,
        h: u32,
        w: u32,
        h_out: u32,
        w_out: u32,
        kh: u32,
        kw: u32,
        sh: u32,
        sw: u32,
        ph: u32,
        pw: u32,
    },

    /// Presumably the convolution gradient with respect to the input.
    /// Here `w` is the weight offset and the input width is `w_in`.
    Conv2dBackwardInput {
        dy: usize,
        w: usize,
        dx: usize,
        n: u32,
        c_in: u32,
        h: u32,
        w_in: u32,
        c_out: u32,
        h_out: u32,
        w_out: u32,
        kh: u32,
        kw: u32,
        sh: u32,
        sw: u32,
        ph: u32,
        pw: u32,
        dh: u32,
        dw: u32,
        groups: u32,
    },

    /// Presumably the convolution gradient with respect to the weight.
    /// Here `dw` is the weight-gradient offset, so the width dilation is
    /// named `dw_dil`.
    Conv2dBackwardWeight {
        x: usize,
        dy: usize,
        dw: usize,
        n: u32,
        c_in: u32,
        h: u32,
        w: u32,
        c_out: u32,
        h_out: u32,
        w_out: u32,
        kh: u32,
        kw: u32,
        sh: u32,
        sw: u32,
        ph: u32,
        pw: u32,
        dh: u32,
        dw_dil: u32,
        groups: u32,
    },

    /// Presumably softmax cross-entropy loss over `n` rows of `c` classes.
    SoftmaxCrossEntropy {
        logits: usize,
        labels: usize,
        dst: usize,
        n: u32,
        c: u32,
    },

    /// Presumably the backward pass of `SoftmaxCrossEntropy`.
    SoftmaxCrossEntropyBackward {
        logits: usize,
        labels: usize,
        d_loss: usize,
        dlogits: usize,
        n: u32,
        c: u32,
    },

    /// A user-registered kernel invoked through the `CpuKernel` trait object.
    /// NOTE(review): the tuple layout is presumably (offset, length, shape) —
    /// confirm against the executor.
    CustomOp {
        kernel: Arc<dyn CpuKernel>,
        inputs: Vec<(usize, u32, Shape)>,
        output: (usize, u32, Shape),
        attrs: Vec<u8>,
    },

    /// Presumably a one-shot Gaussian-splat render (prepare + rasterize).
    /// Each `*_off` / `*_len` pair presumably describes one input buffer.
    GaussianSplatRender {
        positions_off: usize,
        positions_len: usize,
        scales_off: usize,
        scales_len: usize,
        rotations_off: usize,
        rotations_len: usize,
        opacities_off: usize,
        opacities_len: usize,
        colors_off: usize,
        colors_len: usize,
        sh_coeffs_off: usize,
        sh_coeffs_len: usize,
        meta_off: usize,
        dst_off: usize,
        dst_len: usize,
        width: u32,
        height: u32,
        tile_size: u32,
        radius_scale: f32,
        alpha_cutoff: f32,
        max_splat_steps: u32,
        transmittance_threshold: f32,
        max_list_entries: u32,
    },
    /// Presumably the backward pass of `GaussianSplatRender`, writing gradients
    /// into the `packed_*` buffer.
    GaussianSplatRenderBackward {
        positions_off: usize,
        positions_len: usize,
        scales_off: usize,
        scales_len: usize,
        rotations_off: usize,
        rotations_len: usize,
        opacities_off: usize,
        opacities_len: usize,
        colors_off: usize,
        colors_len: usize,
        sh_coeffs_off: usize,
        sh_coeffs_len: usize,
        meta_off: usize,
        d_loss_off: usize,
        d_loss_len: usize,
        packed_off: usize,
        packed_len: usize,
        width: u32,
        height: u32,
        tile_size: u32,
        radius_scale: f32,
        alpha_cutoff: f32,
        max_splat_steps: u32,
        transmittance_threshold: f32,
        max_list_entries: u32,
        loss_grad_clip: f32,
        sh_band: u32,
        max_anisotropy: f32,
    },
    /// Presumably the preparation stage of a split Gaussian-splat render,
    /// writing into the `prep_*` buffer.
    GaussianSplatPrepare {
        positions_off: usize,
        positions_len: usize,
        scales_off: usize,
        scales_len: usize,
        rotations_off: usize,
        rotations_len: usize,
        opacities_off: usize,
        opacities_len: usize,
        colors_off: usize,
        colors_len: usize,
        sh_coeffs_off: usize,
        sh_coeffs_len: usize,
        meta_off: usize,
        meta_len: usize,
        prep_off: usize,
        prep_len: usize,
        width: u32,
        height: u32,
        tile_size: u32,
        radius_scale: f32,
        alpha_cutoff: f32,
        max_splat_steps: u32,
        transmittance_threshold: f32,
        max_list_entries: u32,
    },
    /// Presumably the rasterization stage that consumes the `prep_*` buffer
    /// produced by `GaussianSplatPrepare`.
    GaussianSplatRasterize {
        prep_off: usize,
        prep_len: usize,
        meta_off: usize,
        meta_len: usize,
        dst_off: usize,
        dst_len: usize,
        count: usize,
        width: u32,
        height: u32,
        tile_size: u32,
        alpha_cutoff: f32,
        max_splat_steps: u32,
        transmittance_threshold: f32,
        max_list_entries: u32,
    },
    /// Presumably `outer` independent 1-D FFTs of `n_complex` points each
    /// (`inverse` selects the inverse transform).
    Fft1d {
        src: usize,
        dst: usize,
        outer: u32,
        n_complex: u32,
        inverse: bool,
        dtype: rlx_ir::DType,
    },
}
1623
/// An ordered list of [`Thunk`]s plus the side data the executor needs to run
/// them.
///
/// Nested schedules are shared behind `Arc` by the `Scan*` and `CustomFn`
/// thunk variants.
#[derive(Clone)]
pub struct ThunkSchedule {
    /// The thunks, in the order they were compiled.
    pub thunks: Vec<Thunk>,
    // NOTE(review): presumably a per-expert "is resident" flag for
    // mixture-of-experts weights — confirm against the MoE executor.
    pub moe_resident: Option<std::sync::Arc<[bool]>>,
    // NOTE(review): presumably one residency mask per MoE layer — confirm.
    pub moe_resident_layers: Option<std::sync::Arc<Vec<std::sync::Arc<[bool]>>>>,
    // NOTE(review): presumably an optional sink that records top-k expert
    // choices — confirm against `crate::moe_topk_capture`.
    pub moe_topk_capture: Option<std::sync::Arc<crate::moe_topk_capture::MoeTopkCapture>>,
    // NOTE(review): the three floats below look like attention-mask tuning
    // constants (threshold, the "minus infinity" fill value, and a score
    // below which work is skipped) — confirm against the attention kernels.
    pub mask_threshold: f32,
    pub mask_neg_inf: f32,
    pub score_skip: f32,
    /// Pre-built closures that take a raw `*mut u8` base pointer.
    /// `strip_nops` clears this list whenever it filters `thunks`, presumably
    /// because the closures are tied to thunk positions — confirm.
    pub compiled_fns: Vec<Arc<dyn Fn(*mut u8) + Send + Sync>>,
}
1646
1647impl ThunkSchedule {
1648 pub fn strip_nops(&mut self) {
1649 self.thunks.retain(|t| !matches!(t, Thunk::Nop));
1650 self.compiled_fns.clear();
1653 }
1654}
1655
1656fn node_offset(arena: &Arena, id: NodeId) -> usize {
1658 if arena.has_buffer(id) {
1659 arena.byte_offset(id)
1660 } else {
1661 usize::MAX
1662 }
1663}
1664
1665fn thunk_read_offsets(t: &Thunk) -> Vec<usize> {
1671 match t {
1672 Thunk::Sgemm { a, b, .. } => vec![*a, *b],
1673 Thunk::DenseSolveF64 { a, b, .. } => vec![*a, *b],
1674 Thunk::DenseSolveF32 { a, b, .. } => vec![*a, *b],
1675 Thunk::BatchedDenseSolveF64 { a, b, .. } => vec![*a, *b],
1676 Thunk::BatchedDgemmF64 { a, b, .. } => vec![*a, *b],
1677 Thunk::BatchedSgemm { a, b, .. } => vec![*a, *b],
1678 Thunk::FusedMmBiasAct { a, w, bias, .. } => vec![*a, *w, *bias],
1679 Thunk::BiasAdd { src, bias, .. } => vec![*src, *bias],
1680 Thunk::BinaryFull { lhs, rhs, .. } => vec![*lhs, *rhs],
1681 Thunk::BinaryFullF64 { lhs, rhs, .. } => vec![*lhs, *rhs],
1682 Thunk::BinaryFullC64 { lhs, rhs, .. } => vec![*lhs, *rhs],
1683 Thunk::ComplexNormSqF32 { src, .. } => vec![*src],
1684 Thunk::ComplexNormSqBackwardF32 { z, g, .. } => vec![*z, *g],
1685 Thunk::ConjugateC64 { src, .. } => vec![*src],
1686 Thunk::Scan {
1687 outer_init_off,
1688 xs_inputs,
1689 ..
1690 } => {
1691 let mut v = vec![*outer_init_off];
1692 for (_, outer_xs_off, _) in xs_inputs.iter() {
1693 v.push(*outer_xs_off);
1694 }
1695 v
1696 }
1697 Thunk::ScanBackward {
1698 outer_init_off,
1699 outer_traj_off,
1700 outer_upstream_off,
1701 outer_xs_offs,
1702 ..
1703 } => {
1704 let mut v = vec![*outer_init_off, *outer_traj_off, *outer_upstream_off];
1705 for (off, _) in outer_xs_offs.iter() {
1706 v.push(*off);
1707 }
1708 v
1709 }
1710 Thunk::ScanBackwardXs {
1711 outer_init_off,
1712 outer_traj_off,
1713 outer_upstream_off,
1714 outer_xs_offs,
1715 ..
1716 } => {
1717 let mut v = vec![*outer_init_off, *outer_traj_off, *outer_upstream_off];
1718 for (off, _) in outer_xs_offs.iter() {
1719 v.push(*off);
1720 }
1721 v
1722 }
1723 Thunk::CustomFn { inputs, .. } => {
1724 inputs.iter().map(|(_, outer_off, _)| *outer_off).collect()
1725 }
1726 Thunk::ActivationInPlace { data, .. } => vec![*data],
1727 Thunk::LayerNorm { src, g, b, .. } | Thunk::GroupNorm { src, g, b, .. } => {
1728 vec![*src, *g, *b]
1729 }
1730 Thunk::ResizeNearest2x { src, .. } => vec![*src],
1731 Thunk::AxialRope2d { src, .. } => vec![*src],
1732 Thunk::FusedResidualLN {
1733 x, res, bias, g, b, ..
1734 } => vec![*x, *res, *bias, *g, *b],
1735 Thunk::FusedResidualRmsNorm {
1736 x, res, bias, g, b, ..
1737 } => vec![*x, *res, *bias, *g, *b],
1738 Thunk::RmsNorm { src, g, b, .. } => vec![*src, *g, *b],
1739 Thunk::Softmax { data, .. } => vec![*data],
1740 Thunk::Cumsum { src, .. } => vec![*src],
1741 Thunk::Sample { logits, .. } => vec![*logits],
1742 Thunk::LoraMatMul { x, w, a, b, .. } => vec![*x, *w, *a, *b],
1743 Thunk::DequantMatMul {
1744 x, w_q, scale, zp, ..
1745 } => vec![*x, *w_q, *scale, *zp],
1746 Thunk::DequantMatMulGguf { x, w_q, .. } => vec![*x, *w_q],
1747 Thunk::DequantMatMulInt4 {
1748 x, w_q, scale, zp, ..
1749 } => vec![*x, *w_q, *scale, *zp],
1750 Thunk::DequantMatMulFp8 { x, w_q, scale, .. } => vec![*x, *w_q, *scale],
1751 Thunk::DequantMatMulNvfp4 {
1752 x,
1753 w_q,
1754 scale,
1755 global_scale,
1756 ..
1757 } => vec![*x, *w_q, *scale, *global_scale],
1758 Thunk::Conv2D1x1 { src, weight, .. } => vec![*src, *weight],
1759 Thunk::SelectiveScan {
1760 x, delta, a, b, c, ..
1761 } => vec![*x, *delta, *a, *b, *c],
1762 Thunk::GatedDeltaNet {
1763 q,
1764 k,
1765 v,
1766 g,
1767 beta,
1768 state,
1769 ..
1770 } => {
1771 let mut v = vec![*q, *k, *v, *g, *beta];
1772 if *state != 0 {
1773 v.push(*state);
1774 }
1775 v
1776 }
1777 Thunk::Attention { q, k, v, mask, .. } => vec![*q, *k, *v, *mask],
1778 Thunk::AttentionBackward {
1779 q, k, v, dy, mask, ..
1780 } => {
1781 let mut v = vec![*q, *k, *v, *dy];
1782 if *mask != 0 {
1783 v.push(*mask);
1784 }
1785 v
1786 }
1787 Thunk::Rope { src, cos, sin, .. } => vec![*src, *cos, *sin],
1788 Thunk::FusedAttnBlock {
1789 hidden,
1790 qkv_w,
1791 out_w,
1792 mask,
1793 qkv_b,
1794 out_b,
1795 cos,
1796 sin,
1797 ..
1798 } => vec![*hidden, *qkv_w, *out_w, *mask, *qkv_b, *out_b, *cos, *sin],
1799 Thunk::FusedSwiGLU { src, .. } => vec![*src],
1800 Thunk::Concat { inputs, .. } => inputs.iter().map(|(off, _)| *off).collect(),
1801 Thunk::ConcatF64 { inputs, .. } => inputs.iter().map(|(off, _)| *off).collect(),
1802 Thunk::Narrow { src, .. } => vec![*src],
1803 Thunk::Copy { src, .. } => vec![*src],
1804 Thunk::Gather { table, idx, .. } => vec![*table, *idx],
1805 _ => vec![],
1809 }
1810}
1811
/// Scalar reference matmul of f32 activations against block-quantized int8
/// weights: `out = x * dequant(w)`.
///
/// Layouts (all row-major, as fixed by the indexing below):
/// - `x`: `m x k`, element `(i, p)` at `x[i * k + p]`
/// - `w_bytes`: `k x n`, element `(p, j)` at `w_bytes[p * n + j]`
/// - `scales` / `zps`: one entry per (block, column) at `[block * n + j]`,
///   where `block = p / block_size`
/// - `out`: `m x n`, fully overwritten
///
/// Each weight is dequantized as `(q - z) * s`. When `asym` is false the zero
/// point is taken as `0.0` and `zps` is never read, so it may be empty.
///
/// # Panics
/// Panics if any slice is too short for the given dimensions, or if
/// `block_size == 0` while `k > 0` (division by zero). With `k == 0` the
/// output is simply zero-filled regardless of `block_size`.
#[allow(clippy::too_many_arguments)]
fn dequant_matmul_int8(
    x: &[f32],
    w_bytes: &[i8],
    scales: &[f32],
    zps: &[f32],
    out: &mut [f32],
    m: usize,
    k: usize,
    n: usize,
    block_size: usize,
    asym: bool,
) {
    for i in 0..m {
        for j in 0..n {
            let mut acc = 0f32;
            for p in 0..k {
                // Scale and zero point are shared by `block_size` consecutive
                // rows of the weight matrix within one column.
                let block = p / block_size;
                let s = scales[block * n + j];
                let z = if asym { zps[block * n + j] } else { 0.0 };
                let q = w_bytes[p * n + j] as f32;
                let dequantized = (q - z) * s;
                acc += x[i * k + p] * dequantized;
            }
            out[i * n + j] = acc;
        }
    }
}
1855
/// Scalar reference matmul of f32 activations against block-quantized 4-bit
/// weights packed two per byte.
///
/// The logical weight matrix is `k x n` row-major; logical element
/// `idx = p * n + j` lives in byte `idx / 2`, in the low nibble when `idx` is
/// even and the high nibble when it is odd. Nibbles are read as unsigned
/// values `0..=15` and dequantized as `(nibble - z) * s`, with `s` / `z`
/// looked up at `[(p / block_size) * n + j]`. When `asym` is false the zero
/// point is `0.0` and `zps` is not read.
///
/// `x` is `m x k` row-major and `out` (`m x n`) is fully overwritten.
///
/// # Panics
/// Panics on out-of-range slice access, or on `block_size == 0` with `k > 0`.
#[allow(clippy::too_many_arguments)]
fn dequant_matmul_int4(
    x: &[f32],
    w_bytes: &[u8],
    scales: &[f32],
    zps: &[f32],
    out: &mut [f32],
    m: usize,
    k: usize,
    n: usize,
    block_size: usize,
    asym: bool,
) {
    for row in 0..m {
        for col in 0..n {
            let dot = (0..k).fold(0f32, |partial, depth| {
                let param_idx = (depth / block_size) * n + col;
                let scale = scales[param_idx];
                let zero_point = if asym { zps[param_idx] } else { 0.0 };

                // Even logical index -> low nibble, odd -> high nibble.
                let flat = depth * n + col;
                let packed = w_bytes[flat / 2];
                let nibble = (packed >> ((flat & 1) * 4)) & 0x0F;

                let weight = (nibble as f32 - zero_point) * scale;
                partial + x[row * k + depth] * weight
            });
            out[row * n + col] = dot;
        }
    }
}
1889
/// Decodes one FP8 E4M3 byte (1 sign, 4 exponent, 3 mantissa bits, bias 7)
/// to `f32`.
///
/// This follows the "finite-only" E4M3 encoding used for ML weights
/// (OCP OFP8 / `E4M3FN`): there are no infinities, and the only NaN bit
/// pattern is exponent `1111` with mantissa `111`. Every other value with
/// exponent `1111` is an ordinary normal number, so the largest magnitude is
/// `0x7E` = 448.
///
/// Both zero encodings (`0x00` and `0x80`) decode to `+0.0`.
fn fp8_e4m3_to_f32(b: u8) -> f32 {
    let sign: f32 = if b & 0x80 != 0 { -1.0 } else { 1.0 };
    let exp = (b >> 3) & 0x0F;
    let mant = b & 0x07;
    match (exp, mant) {
        (0, 0) => 0.0,
        // Subnormal: mant/8 * 2^(1 - bias) = mant * 2^-9.
        (0, m) => sign * (m as f32) * 2f32.powi(-9),
        // The single NaN encoding; exponent 15 is otherwise a normal binade.
        (0x0F, 0x07) => f32::NAN,
        (e, m) => sign * (1.0 + m as f32 / 8.0) * 2f32.powi(e as i32 - 7),
    }
}
1909
/// Decodes one FP8 E5M2 byte (1 sign, 5 exponent, 2 mantissa bits, bias 15)
/// to `f32`.
///
/// Exponent `11111` encodes infinity (mantissa zero) or NaN (mantissa
/// non-zero); exponent zero encodes zero or a subnormal. Both zero encodings
/// decode to `+0.0`.
fn fp8_e5m2_to_f32(b: u8) -> f32 {
    let negative = b & 0x80 != 0;
    let sign: f32 = if negative { -1.0 } else { 1.0 };
    let exp_bits = (b >> 2) & 0x1F;
    let frac_bits = b & 0x03;

    match (exp_bits, frac_bits) {
        (0, 0) => 0.0,
        // Subnormal: frac/4 * 2^(1 - bias) = frac * 2^-16.
        (0, frac) => sign * (frac as f32) * 2f32.powi(-16),
        (0x1F, 0) => sign * f32::INFINITY,
        (0x1F, _) => f32::NAN,
        (exp, frac) => sign * (1.0 + frac as f32 / 4.0) * 2f32.powi(exp as i32 - 15),
    }
}
1929
1930#[allow(clippy::too_many_arguments)]
1931fn dequant_matmul_fp8(
1932 x: &[f32],
1933 w_bytes: &[u8],
1934 scales: &[f32],
1935 out: &mut [f32],
1936 m: usize,
1937 k: usize,
1938 n: usize,
1939 e5m2: bool,
1940) {
1941 let dequant = if e5m2 {
1942 fp8_e5m2_to_f32
1943 } else {
1944 fp8_e4m3_to_f32
1945 };
1946 for i in 0..m {
1947 for j in 0..n {
1948 let mut acc = 0f32;
1949 for p in 0..k {
1950 let w = dequant(w_bytes[p * n + j]);
1951 let s = scales.get(j).copied().unwrap_or(1.0);
1952 acc += x[i * k + p] * w * s;
1953 }
1954 out[i * n + j] = acc;
1955 }
1956 }
1957}
1958
1959#[allow(clippy::too_many_arguments)]
1960pub fn dequant_matmul_nvfp4(
1961 x: &[f32],
1962 w_bytes: &[u8],
1963 scale_bytes: &[u8],
1964 global_scale: f32,
1965 out: &mut [f32],
1966 m: usize,
1967 k: usize,
1968 n: usize,
1969) {
1970 use rlx_ir::{NVFP4_GROUP_SIZE, fp4_e2m1_to_f32, fp8_e4m3_scale_to_f32};
1971 let gs = NVFP4_GROUP_SIZE;
1972 for i in 0..m {
1973 for j in 0..n {
1974 let mut acc = 0f32;
1975 for p in 0..k {
1976 let byte_idx = (p * n + j) / 2;
1977 let nibble = if (p * n + j) & 1 == 0 {
1978 w_bytes[byte_idx] & 0x0F
1979 } else {
1980 w_bytes[byte_idx] >> 4
1981 };
1982 let block = p / gs;
1983 let scale = fp8_e4m3_scale_to_f32(scale_bytes[block * n + j]);
1984 let w = fp4_e2m1_to_f32(nibble) * scale * global_scale;
1985 acc += x[i * k + p] * w;
1986 }
1987 out[i * n + j] = acc;
1988 }
1989 }
1990}
1991
1992fn sample_row(
2001 logits: &[f32],
2002 top_k: usize,
2003 top_p: f32,
2004 temperature: f32,
2005 rng: &mut rlx_ir::Philox4x32,
2006) -> usize {
2007 let v = logits.len();
2008 if v == 0 {
2009 return 0;
2010 }
2011 let temp = temperature.max(1e-6);
2012 let mut scaled: Vec<f32> = logits.iter().map(|&x| x / temp).collect();
2014
2015 if top_k > 0 && top_k < v {
2017 let mut indexed: Vec<(usize, f32)> = scaled.iter().copied().enumerate().collect();
2019 indexed.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap());
2022 let cutoff = indexed[top_k - 1].1;
2023 for x in scaled.iter_mut() {
2024 if *x < cutoff {
2025 *x = f32::NEG_INFINITY;
2026 }
2027 }
2028 }
2029
2030 let mut max_l = f32::NEG_INFINITY;
2032 for &x in &scaled {
2033 if x > max_l {
2034 max_l = x;
2035 }
2036 }
2037 let mut sum = 0.0f32;
2038 for x in scaled.iter_mut() {
2039 *x = (*x - max_l).exp();
2040 sum += *x;
2041 }
2042 let inv = 1.0 / sum.max(f32::MIN_POSITIVE);
2043 for x in scaled.iter_mut() {
2044 *x *= inv;
2045 }
2046
2047 if top_p < 1.0 {
2050 let mut indexed: Vec<(usize, f32)> = scaled.iter().copied().enumerate().collect();
2051 indexed.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap());
2052 let mut cum = 0.0f32;
2053 let mut keep = vec![false; v];
2054 for (idx, p) in indexed.iter() {
2055 keep[*idx] = true;
2056 cum += *p;
2057 if cum >= top_p {
2058 break;
2059 }
2060 }
2061 let mut new_sum = 0.0f32;
2062 for (i, x) in scaled.iter_mut().enumerate() {
2063 if !keep[i] {
2064 *x = 0.0;
2065 }
2066 new_sum += *x;
2067 }
2068 let inv = 1.0 / new_sum.max(f32::MIN_POSITIVE);
2069 for x in scaled.iter_mut() {
2070 *x *= inv;
2071 }
2072 }
2073
2074 let r = rng.next_f32();
2076 let mut acc = 0.0f32;
2077 for (i, &p) in scaled.iter().enumerate() {
2078 acc += p;
2079 if r <= acc {
2080 return i;
2081 }
2082 }
2083 v - 1 }
2085
2086#[inline]
2090fn apply_synthetic_mask(
2091 scores: &mut [f32],
2092 q_seq: usize,
2093 k_seq: usize,
2094 kind: rlx_ir::op::MaskKind,
2095) {
2096 let neg = crate::config::RuntimeConfig::global().attn_mask_neg_inf;
2097 let q_offset = k_seq.saturating_sub(q_seq);
2098 match kind {
2099 rlx_ir::op::MaskKind::None | rlx_ir::op::MaskKind::Custom | rlx_ir::op::MaskKind::Bias => {}
2100 rlx_ir::op::MaskKind::Causal => {
2101 for qi in 0..q_seq {
2102 let abs_q = q_offset + qi;
2103 for ki in (abs_q + 1)..k_seq {
2104 scores[qi * k_seq + ki] = neg;
2105 }
2106 }
2107 }
2108 rlx_ir::op::MaskKind::SlidingWindow(w) => {
2109 for qi in 0..q_seq {
2110 let abs_q = q_offset + qi;
2111 let lo = abs_q.saturating_sub(w);
2112 for ki in 0..k_seq {
2113 if ki < lo || ki > abs_q {
2114 scores[qi * k_seq + ki] = neg;
2115 }
2116 }
2117 }
2118 }
2119 }
2120}
2121
2122pub fn compile_thunks(graph: &Graph, arena: &Arena) -> ThunkSchedule {
2124 let mut thunks = Vec::with_capacity(graph.len());
2125
2126 for node in graph.nodes() {
2127 if rlx_opt::is_pure_view(graph, node) {
2131 thunks.push(Thunk::Nop);
2132 continue;
2133 }
2134 let t = match &node.op {
2135 Op::Input { .. } | Op::Param { .. } | Op::Constant { .. } => Thunk::Nop,
2136
2137 Op::FusedMatMulBiasAct { activation } => {
2138 let shape = &node.shape;
2139 let n = shape.dim(shape.rank() - 1).unwrap_static();
2140 let total = shape.num_elements().unwrap();
2141 let m = total / n;
2142 let a_len = get_len(graph, node.inputs[0]);
2143 let k = a_len / m;
2144 Thunk::FusedMmBiasAct {
2145 a: node_offset(arena, node.inputs[0]),
2146 w: node_offset(arena, node.inputs[1]),
2147 bias: node_offset(arena, node.inputs[2]),
2148 c: node_offset(arena, node.id),
2149 m: m as u32,
2150 k: k as u32,
2151 n: n as u32,
2152 act: *activation,
2153 }
2154 }
2155
2156 Op::FusedResidualLN { has_bias, eps } => {
2157 let h = node.shape.dim(node.shape.rank() - 1).unwrap_static();
2158 let total = node.shape.num_elements().unwrap();
2159 let rows = total / h;
2160 let (g_idx, b_idx) = if *has_bias { (3, 4) } else { (2, 3) };
2161 Thunk::FusedResidualLN {
2162 x: node_offset(arena, node.inputs[0]),
2163 res: node_offset(arena, node.inputs[1]),
2164 bias: if *has_bias {
2165 node_offset(arena, node.inputs[2])
2166 } else {
2167 0
2168 },
2169 g: node_offset(arena, node.inputs[g_idx]),
2170 b: node_offset(arena, node.inputs[b_idx]),
2171 out: node_offset(arena, node.id),
2172 rows: rows as u32,
2173 h: h as u32,
2174 eps: *eps,
2175 has_bias: *has_bias,
2176 }
2177 }
2178
2179 Op::FusedResidualRmsNorm { has_bias, eps } => {
2180 let h = node.shape.dim(node.shape.rank() - 1).unwrap_static();
2181 let total = node.shape.num_elements().unwrap();
2182 let rows = total / h;
2183 let (g_idx, b_idx) = if *has_bias { (3, 4) } else { (2, 3) };
2184 Thunk::FusedResidualRmsNorm {
2185 x: node_offset(arena, node.inputs[0]),
2186 res: node_offset(arena, node.inputs[1]),
2187 bias: if *has_bias {
2188 node_offset(arena, node.inputs[2])
2189 } else {
2190 0
2191 },
2192 g: node_offset(arena, node.inputs[g_idx]),
2193 b: node_offset(arena, node.inputs[b_idx]),
2194 out: node_offset(arena, node.id),
2195 rows: rows as u32,
2196 h: h as u32,
2197 eps: *eps,
2198 has_bias: *has_bias,
2199 }
2200 }
2201
2202 Op::MatMul => {
2203 let shape = &node.shape;
2204 let a_shape = &graph.node(node.inputs[0]).shape;
2205 let b_shape = &graph.node(node.inputs[1]).shape;
2206 let n = shape.dim(shape.rank() - 1).unwrap_static();
2207
2208 let batched_3d = a_shape.rank() >= 3
2215 && b_shape.rank() == a_shape.rank()
2216 && shape.rank() == a_shape.rank()
2217 && {
2218 let mut ok = true;
2220 for d in 0..a_shape.rank() - 2 {
2221 if a_shape.dim(d) != b_shape.dim(d) || a_shape.dim(d) != shape.dim(d) {
2222 ok = false;
2223 break;
2224 }
2225 }
2226 ok
2227 };
2228 if batched_3d && shape.dtype() == rlx_ir::DType::F64 {
2229 let r = shape.rank();
2233 let mut batch_prod = 1usize;
2234 for d in 0..r - 2 {
2235 batch_prod *= shape.dim(d).unwrap_static();
2236 }
2237 let m_dim = shape.dim(r - 2).unwrap_static();
2238 let k_dim = a_shape.dim(r - 1).unwrap_static();
2239 debug_assert_eq!(k_dim, b_shape.dim(r - 2).unwrap_static());
2240 Thunk::BatchedDgemmF64 {
2241 a: node_offset(arena, node.inputs[0]),
2242 b: node_offset(arena, node.inputs[1]),
2243 c: node_offset(arena, node.id),
2244 batch: batch_prod as u32,
2245 m: m_dim as u32,
2246 k: k_dim as u32,
2247 n: n as u32,
2248 }
2249 } else if batched_3d && shape.dtype() == rlx_ir::DType::F32 {
2250 let r = shape.rank();
2253 let mut batch_prod = 1usize;
2254 for d in 0..r - 2 {
2255 batch_prod *= shape.dim(d).unwrap_static();
2256 }
2257 let m_dim = shape.dim(r - 2).unwrap_static();
2258 let k_dim = a_shape.dim(r - 1).unwrap_static();
2259 debug_assert_eq!(k_dim, b_shape.dim(r - 2).unwrap_static());
2260 Thunk::BatchedSgemm {
2261 a: node_offset(arena, node.inputs[0]),
2262 b: node_offset(arena, node.inputs[1]),
2263 c: node_offset(arena, node.id),
2264 batch: batch_prod as u32,
2265 m: m_dim as u32,
2266 k: k_dim as u32,
2267 n: n as u32,
2268 }
2269 } else {
2270 let total = shape.num_elements().unwrap();
2271 let m = total / n;
2272 let a_len = get_len(graph, node.inputs[0]);
2273 let k = a_len / m;
2274 match shape.dtype() {
2275 rlx_ir::DType::F64 => Thunk::Dgemm {
2276 a: node_offset(arena, node.inputs[0]),
2277 b: node_offset(arena, node.inputs[1]),
2278 c: node_offset(arena, node.id),
2279 m: m as u32,
2280 k: k as u32,
2281 n: n as u32,
2282 },
2283 _ => Thunk::Sgemm {
2284 a: node_offset(arena, node.inputs[0]),
2285 b: node_offset(arena, node.inputs[1]),
2286 c: node_offset(arena, node.id),
2287 m: m as u32,
2288 k: k as u32,
2289 n: n as u32,
2290 },
2291 }
2292 }
2293 }
2294
2295 Op::Binary(op) => {
2296 let lhs_len = get_len(graph, node.inputs[0]);
2297 let rhs_len = get_len(graph, node.inputs[1]);
2298 let out_len = node.shape.num_elements().unwrap();
2299 if node.shape.dtype() == rlx_ir::DType::C64 {
2300 match op {
2304 BinaryOp::Add | BinaryOp::Sub | BinaryOp::Mul | BinaryOp::Div => {}
2305 BinaryOp::Max | BinaryOp::Min | BinaryOp::Pow => panic!(
2306 "Op::Binary({op:?}) on DType::C64: complex \
2307 max/min/pow have no single natural definition \
2308 — caller should drop to 2N-real-block (see \
2309 spike-ac) and pick a convention there"
2310 ),
2311 }
2312 }
2313 let (out_dims_bcast, bcast_lhs_strides, bcast_rhs_strides) =
2317 if lhs_len == out_len && rhs_len == out_len {
2318 (Vec::new(), Vec::new(), Vec::new())
2319 } else {
2320 let lhs_dims = get_static_dims(graph, node.inputs[0]);
2321 let rhs_dims = get_static_dims(graph, node.inputs[1]);
2322 let out_dims_v = get_static_dims(graph, node.id);
2323 if lhs_dims.is_empty() || rhs_dims.is_empty() || out_dims_v.is_empty() {
2324 (Vec::new(), Vec::new(), Vec::new())
2329 } else {
2330 let ls = broadcast_strides(&lhs_dims, &out_dims_v);
2331 let rs = broadcast_strides(&rhs_dims, &out_dims_v);
2332 let od: Vec<u32> = out_dims_v.iter().map(|x| *x as u32).collect();
2333 (od, ls, rs)
2334 }
2335 };
2336 if node.shape.dtype() == rlx_ir::DType::C64 {
2337 Thunk::BinaryFullC64 {
2338 lhs: node_offset(arena, node.inputs[0]),
2339 rhs: node_offset(arena, node.inputs[1]),
2340 dst: node_offset(arena, node.id),
2341 len: out_len as u32,
2342 lhs_len: lhs_len as u32,
2343 rhs_len: rhs_len as u32,
2344 op: *op,
2345 out_dims_bcast,
2346 bcast_lhs_strides,
2347 bcast_rhs_strides,
2348 }
2349 } else if node.shape.dtype() == rlx_ir::DType::F64 {
2350 Thunk::BinaryFullF64 {
2353 lhs: node_offset(arena, node.inputs[0]),
2354 rhs: node_offset(arena, node.inputs[1]),
2355 dst: node_offset(arena, node.id),
2356 len: out_len as u32,
2357 lhs_len: lhs_len as u32,
2358 rhs_len: rhs_len as u32,
2359 op: *op,
2360 out_dims_bcast,
2361 bcast_lhs_strides,
2362 bcast_rhs_strides,
2363 }
2364 } else if matches!(op, BinaryOp::Add)
2365 && rhs_len < out_len
2366 && out_len % rhs_len == 0
2367 && is_trailing_bias_broadcast(
2368 graph.node(node.inputs[1]).shape.dims(),
2369 graph.node(node.id).shape.dims(),
2370 )
2371 {
2372 Thunk::BiasAdd {
2382 src: node_offset(arena, node.inputs[0]),
2383 bias: node_offset(arena, node.inputs[1]),
2384 dst: node_offset(arena, node.id),
2385 m: (out_len / rhs_len) as u32,
2386 n: rhs_len as u32,
2387 }
2388 } else {
2389 let lhs_len = get_len(graph, node.inputs[0]);
2390 Thunk::BinaryFull {
2391 lhs: node_offset(arena, node.inputs[0]),
2392 rhs: node_offset(arena, node.inputs[1]),
2393 dst: node_offset(arena, node.id),
2394 len: out_len as u32,
2395 lhs_len: lhs_len as u32,
2396 rhs_len: rhs_len as u32,
2397 op: *op,
2398 out_dims_bcast,
2399 bcast_lhs_strides,
2400 bcast_rhs_strides,
2401 }
2402 }
2403 }
2404
2405 Op::Activation(act) => {
2406 let len = node.shape.num_elements().unwrap();
2407 let in_off = node_offset(arena, node.inputs[0]);
2408 let out_off = node_offset(arena, node.id);
2409 if node.shape.dtype() == rlx_ir::DType::C64 {
2410 match act {
2415 Activation::Neg | Activation::Exp | Activation::Log | Activation::Sqrt => {}
2416 other => panic!(
2417 "Op::Activation({other:?}) on DType::C64: no \
2418 natural complex extension — supported on C64: \
2419 Neg, Exp, Log, Sqrt"
2420 ),
2421 }
2422 Thunk::ActivationC64 {
2423 src: in_off,
2424 dst: out_off,
2425 len: len as u32,
2426 kind: *act,
2427 }
2428 } else if node.shape.dtype() == rlx_ir::DType::F64 {
2429 Thunk::ActivationF64 {
2430 src: in_off,
2431 dst: out_off,
2432 len: len as u32,
2433 kind: *act,
2434 }
2435 } else if in_off == out_off {
2436 Thunk::ActivationInPlace {
2440 data: out_off,
2441 len: len as u32,
2442 act: *act,
2443 }
2444 } else {
2445 thunks.push(Thunk::Copy {
2449 src: in_off,
2450 dst: out_off,
2451 len: len as u32,
2452 });
2453 Thunk::ActivationInPlace {
2454 data: out_off,
2455 len: len as u32,
2456 act: *act,
2457 }
2458 }
2459 }
2460
2461 Op::Gather { axis } if *axis == 0 => {
2462 let table_shape = &graph.node(node.inputs[0]).shape;
2463 let table_total = table_shape.num_elements().unwrap();
2464 let trailing: usize = (1..table_shape.rank())
2465 .map(|i| table_shape.dim(i).unwrap_static())
2466 .product();
2467 let idx_len = get_len(graph, node.inputs[1]);
2468 Thunk::Gather {
2469 table: node_offset(arena, node.inputs[0]),
2470 table_len: table_total as u32,
2471 idx: node_offset(arena, node.inputs[1]),
2472 dst: node_offset(arena, node.id),
2473 num_idx: idx_len as u32,
2474 trailing: trailing as u32,
2475 }
2476 }
2477
2478 Op::Gather { axis } => {
2479 let table_shape = &graph.node(node.inputs[0]).shape;
2481 let rank = table_shape.rank();
2482 let outer: usize = (0..*axis)
2483 .map(|i| table_shape.dim(i).unwrap_static())
2484 .product::<usize>()
2485 .max(1);
2486 let trailing: usize = (*axis + 1..rank)
2487 .map(|i| table_shape.dim(i).unwrap_static())
2488 .product::<usize>()
2489 .max(1);
2490 let axis_dim = table_shape.dim(*axis).unwrap_static();
2491 let idx_len = get_len(graph, node.inputs[1]);
2492 Thunk::GatherAxis {
2493 table: node_offset(arena, node.inputs[0]),
2494 idx: node_offset(arena, node.inputs[1]),
2495 dst: node_offset(arena, node.id),
2496 outer: outer as u32,
2497 axis_dim: axis_dim as u32,
2498 num_idx: idx_len as u32,
2499 trailing: trailing as u32,
2500 }
2501 }
2502
2503 Op::Narrow { axis, start, len } => {
2504 let in_shape = &graph.node(node.inputs[0]).shape;
2505 let elem_bytes = in_shape.dtype().size_bytes() as u8;
2506 let rank = in_shape.rank();
2507 let outer: usize = (0..*axis)
2508 .map(|i| in_shape.dim(i).unwrap_static())
2509 .product::<usize>()
2510 .max(1);
2511 let inner: usize = (*axis + 1..rank)
2512 .map(|i| in_shape.dim(i).unwrap_static())
2513 .product::<usize>()
2514 .max(1);
2515 let in_axis = in_shape.dim(*axis).unwrap_static();
2516 let src_byte_offset =
2517 node_offset(arena, node.inputs[0]) + start * inner * elem_bytes as usize;
2518 Thunk::Narrow {
2519 src: src_byte_offset,
2520 dst: node_offset(arena, node.id),
2521 outer: outer as u32,
2522 src_stride: (in_axis * inner) as u32, dst_stride: (*len * inner) as u32, inner: (*len * inner) as u32, elem_bytes,
2526 }
2527 }
2528
2529 Op::Reshape { .. } | Op::Cast { .. } => {
2530 let len = node.shape.num_elements().unwrap();
2532 let src = node_offset(arena, node.inputs[0]);
2533 let dst = node_offset(arena, node.id);
2534 match node.shape.dtype() {
2535 rlx_ir::DType::F64 => Thunk::CopyF64 {
2536 src,
2537 dst,
2538 len: len as u32,
2539 },
2540 _ => Thunk::Copy {
2541 src,
2542 dst,
2543 len: len as u32,
2544 },
2545 }
2546 }
2547
2548 Op::Quantize {
2549 axis,
2550 scales,
2551 zero_points,
2552 } => {
2553 let (chan_axis, chan_dim, inner) = quant_layout(&node.shape, *axis);
2554 Thunk::Quantize {
2555 x: node_offset(arena, node.inputs[0]),
2556 q: node_offset(arena, node.id),
2557 len: node.shape.num_elements().unwrap() as u32,
2558 chan_axis: chan_axis as u32,
2559 chan_dim: chan_dim as u32,
2560 inner: inner as u32,
2561 scales: scales.clone(),
2562 zero_points: zero_points.clone(),
2563 }
2564 }
2565
2566 Op::FakeQuantize {
2567 bits,
2568 axis,
2569 ste,
2570 scale_mode,
2571 } => {
2572 let (chan_axis, chan_dim, inner) = quant_layout(&node.shape, *axis);
2573 let state_off = match scale_mode {
2574 rlx_ir::op::ScaleMode::PerBatch => None,
2575 rlx_ir::op::ScaleMode::EMA { .. } | rlx_ir::op::ScaleMode::Fixed => {
2576 debug_assert_eq!(
2578 node.inputs.len(),
2579 2,
2580 "EMA/Fixed FakeQuantize needs a state input"
2581 );
2582 Some(node_offset(arena, node.inputs[1]))
2583 }
2584 };
2585 Thunk::FakeQuantize {
2586 x: node_offset(arena, node.inputs[0]),
2587 out: node_offset(arena, node.id),
2588 len: node.shape.num_elements().unwrap() as u32,
2589 chan_axis: chan_axis as u32,
2590 chan_dim: chan_dim as u32,
2591 inner: inner as u32,
2592 bits: *bits,
2593 ste: *ste,
2594 scale_mode: *scale_mode,
2595 state_off,
2596 }
2597 }
2598
2599 Op::FakeQuantizeLSQ { bits, axis } => {
2600 let (chan_axis, chan_dim, inner) = quant_layout(&node.shape, *axis);
2601 Thunk::FakeQuantizeLSQ {
2602 x: node_offset(arena, node.inputs[0]),
2603 scale_off: node_offset(arena, node.inputs[1]),
2604 out: node_offset(arena, node.id),
2605 len: node.shape.num_elements().unwrap() as u32,
2606 chan_axis: chan_axis as u32,
2607 chan_dim: chan_dim as u32,
2608 inner: inner as u32,
2609 bits: *bits,
2610 }
2611 }
2612
2613 Op::FakeQuantizeLSQBackwardX { bits, axis } => {
2614 let (chan_axis, chan_dim, inner) = quant_layout(&node.shape, *axis);
2615 Thunk::FakeQuantizeLSQBackwardX {
2616 x: node_offset(arena, node.inputs[0]),
2617 scale_off: node_offset(arena, node.inputs[1]),
2618 dy: node_offset(arena, node.inputs[2]),
2619 dx: node_offset(arena, node.id),
2620 len: node.shape.num_elements().unwrap() as u32,
2621 chan_axis: chan_axis as u32,
2622 chan_dim: chan_dim as u32,
2623 inner: inner as u32,
2624 bits: *bits,
2625 }
2626 }
2627
2628 Op::FakeQuantizeLSQBackwardScale { bits, axis } => {
2629 let in_shape = &graph.node(node.inputs[0]).shape;
2632 let (chan_axis, chan_dim, inner) = quant_layout(in_shape, *axis);
2633 Thunk::FakeQuantizeLSQBackwardScale {
2634 x: node_offset(arena, node.inputs[0]),
2635 scale_off: node_offset(arena, node.inputs[1]),
2636 dy: node_offset(arena, node.inputs[2]),
2637 dscale: node_offset(arena, node.id),
2638 len: in_shape.num_elements().unwrap() as u32,
2639 chan_axis: chan_axis as u32,
2640 chan_dim: chan_dim as u32,
2641 inner: inner as u32,
2642 bits: *bits,
2643 }
2644 }
2645
2646 Op::FakeQuantizeBackward { bits, axis, ste } => {
2647 let (chan_axis, chan_dim, inner) = quant_layout(&node.shape, *axis);
2648 Thunk::FakeQuantizeBackward {
2649 x: node_offset(arena, node.inputs[0]),
2650 dy: node_offset(arena, node.inputs[1]),
2651 dx: node_offset(arena, node.id),
2652 len: node.shape.num_elements().unwrap() as u32,
2653 chan_axis: chan_axis as u32,
2654 chan_dim: chan_dim as u32,
2655 inner: inner as u32,
2656 bits: *bits,
2657 ste: *ste,
2658 }
2659 }
2660
2661 Op::Dequantize {
2662 axis,
2663 scales,
2664 zero_points,
2665 } => {
2666 let (chan_axis, chan_dim, inner) = quant_layout(&node.shape, *axis);
2667 Thunk::Dequantize {
2668 q: node_offset(arena, node.inputs[0]),
2669 x: node_offset(arena, node.id),
2670 len: node.shape.num_elements().unwrap() as u32,
2671 chan_axis: chan_axis as u32,
2672 chan_dim: chan_dim as u32,
2673 inner: inner as u32,
2674 scales: scales.clone(),
2675 zero_points: zero_points.clone(),
2676 }
2677 }
2678
2679 Op::Expand { .. } => {
2680 let in_shape = &graph.node(node.inputs[0]).shape;
2685 let out_shape = &node.shape;
2686 let in_rank = in_shape.rank();
2687 let out_rank = out_shape.rank();
2688 let pad = out_rank.saturating_sub(in_rank);
2690 let in_dims: Vec<usize> = (0..out_rank)
2691 .map(|i| {
2692 if i < pad {
2693 1
2694 } else {
2695 in_shape.dim(i - pad).unwrap_static()
2696 }
2697 })
2698 .collect();
2699 let mut in_strides_full = vec![1usize; out_rank];
2701 for d in (0..out_rank.saturating_sub(1)).rev() {
2702 in_strides_full[d] = in_strides_full[d + 1] * in_dims[d + 1];
2703 }
2704 let out_dims: Vec<u32> = (0..out_rank)
2705 .map(|i| out_shape.dim(i).unwrap_static() as u32)
2706 .collect();
2707 let in_strides: Vec<u32> = (0..out_rank)
2709 .map(|i| {
2710 if in_dims[i] == 1 && (out_dims[i] as usize) > 1 {
2711 0
2712 } else {
2713 in_strides_full[i] as u32
2714 }
2715 })
2716 .collect();
2717 let in_total = in_dims.iter().product::<usize>() as u32;
2718 let src = node_offset(arena, node.inputs[0]);
2719 let dst = node_offset(arena, node.id);
2720 match node.shape.dtype() {
2721 rlx_ir::DType::F64 => Thunk::TransposeF64 {
2722 src,
2723 dst,
2724 in_total,
2725 out_dims,
2726 in_strides,
2727 },
2728 _ => Thunk::Transpose {
2729 src,
2730 dst,
2731 in_total,
2732 out_dims,
2733 in_strides,
2734 },
2735 }
2736 }
2737
2738 Op::RmsNorm { eps, .. } => {
2739 let h = node.shape.dim(node.shape.rank() - 1).unwrap_static();
2740 let total = node.shape.num_elements().unwrap();
2741 Thunk::RmsNorm {
2742 src: node_offset(arena, node.inputs[0]),
2743 g: node_offset(arena, node.inputs[1]),
2744 b: node_offset(arena, node.inputs[2]),
2745 dst: node_offset(arena, node.id),
2746 rows: (total / h) as u32,
2747 h: h as u32,
2748 eps: *eps,
2749 }
2750 }
2751
2752 Op::LayerNorm { eps, .. } => {
2753 let h = node.shape.dim(node.shape.rank() - 1).unwrap_static();
2754 let total = node.shape.num_elements().unwrap();
2755 Thunk::LayerNorm {
2756 src: node_offset(arena, node.inputs[0]),
2757 g: node_offset(arena, node.inputs[1]),
2758 b: node_offset(arena, node.inputs[2]),
2759 dst: node_offset(arena, node.id),
2760 rows: (total / h) as u32,
2761 h: h as u32,
2762 eps: *eps,
2763 }
2764 }
2765
2766 Op::GroupNorm { num_groups, eps } => {
2767 let in_shape = &graph.node(node.inputs[0]).shape;
2768 Thunk::GroupNorm {
2769 src: node_offset(arena, node.inputs[0]),
2770 g: node_offset(arena, node.inputs[1]),
2771 b: node_offset(arena, node.inputs[2]),
2772 dst: node_offset(arena, node.id),
2773 n: in_shape.dim(0).unwrap_static() as u32,
2774 c: in_shape.dim(1).unwrap_static() as u32,
2775 h: in_shape.dim(2).unwrap_static() as u32,
2776 w: in_shape.dim(3).unwrap_static() as u32,
2777 num_groups: *num_groups as u32,
2778 eps: *eps,
2779 }
2780 }
2781
2782 Op::LayerNorm2d { eps } => {
2783 let in_shape = &graph.node(node.inputs[0]).shape;
2784 Thunk::LayerNorm2d {
2785 src: node_offset(arena, node.inputs[0]),
2786 g: node_offset(arena, node.inputs[1]),
2787 b: node_offset(arena, node.inputs[2]),
2788 dst: node_offset(arena, node.id),
2789 n: in_shape.dim(0).unwrap_static() as u32,
2790 c: in_shape.dim(1).unwrap_static() as u32,
2791 h: in_shape.dim(2).unwrap_static() as u32,
2792 w: in_shape.dim(3).unwrap_static() as u32,
2793 eps: *eps,
2794 }
2795 }
2796
2797 Op::ConvTranspose2d {
2798 kernel_size,
2799 stride,
2800 padding,
2801 dilation,
2802 output_padding: _,
2803 groups,
2804 } => {
2805 let in_shape = &graph.node(node.inputs[0]).shape;
2806 let out_shape = &node.shape;
2807 Thunk::ConvTranspose2d {
2808 src: node_offset(arena, node.inputs[0]),
2809 weight: node_offset(arena, node.inputs[1]),
2810 dst: node_offset(arena, node.id),
2811 n: in_shape.dim(0).unwrap_static() as u32,
2812 c_in: in_shape.dim(1).unwrap_static() as u32,
2813 h: in_shape.dim(2).unwrap_static() as u32,
2814 w_in: in_shape.dim(3).unwrap_static() as u32,
2815 c_out: out_shape.dim(1).unwrap_static() as u32,
2816 h_out: out_shape.dim(2).unwrap_static() as u32,
2817 w_out: out_shape.dim(3).unwrap_static() as u32,
2818 kh: kernel_size[0] as u32,
2819 kw: kernel_size[1] as u32,
2820 sh: stride.first().copied().unwrap_or(1) as u32,
2821 sw: stride.get(1).copied().unwrap_or(1) as u32,
2822 ph: padding.first().copied().unwrap_or(0) as u32,
2823 pw: padding.get(1).copied().unwrap_or(0) as u32,
2824 dh: dilation.first().copied().unwrap_or(1) as u32,
2825 dw: dilation.get(1).copied().unwrap_or(1) as u32,
2826 groups: *groups as u32,
2827 }
2828 }
2829
2830 Op::ResizeNearest2x => {
2831 let in_shape = &graph.node(node.inputs[0]).shape;
2832 Thunk::ResizeNearest2x {
2833 src: node_offset(arena, node.inputs[0]),
2834 dst: node_offset(arena, node.id),
2835 n: in_shape.dim(0).unwrap_static() as u32,
2836 c: in_shape.dim(1).unwrap_static() as u32,
2837 h: in_shape.dim(2).unwrap_static() as u32,
2838 w: in_shape.dim(3).unwrap_static() as u32,
2839 }
2840 }
2841
2842 Op::AxialRope2d {
2843 end_x,
2844 end_y,
2845 head_dim,
2846 num_heads,
2847 theta,
2848 repeat_factor,
2849 } => {
2850 let in_shape = &graph.node(node.inputs[0]).shape;
2851 let batch = in_shape.dim(0).unwrap_static() as u32;
2852 let seq = in_shape.dim(1).unwrap_static() as u32;
2853 let hidden = in_shape.dim(2).unwrap_static() as u32;
2854 Thunk::AxialRope2d {
2855 src: node_offset(arena, node.inputs[0]),
2856 dst: node_offset(arena, node.id),
2857 batch,
2858 seq,
2859 hidden,
2860 end_x: *end_x as u32,
2861 end_y: *end_y as u32,
2862 head_dim: *head_dim as u32,
2863 num_heads: *num_heads as u32,
2864 theta: *theta,
2865 repeat_factor: *repeat_factor as u32,
2866 }
2867 }
2868
2869 Op::Softmax { axis } => {
2870 let rank = node.shape.rank();
2871 let ax = if *axis < 0 {
2872 (rank as i32 + axis) as usize
2873 } else {
2874 *axis as usize
2875 };
2876 let cols = node.shape.dim(ax).unwrap_static();
2877 let total = node.shape.num_elements().unwrap();
2878 let in_off = node_offset(arena, node.inputs[0]);
2879 let out_off = node_offset(arena, node.id);
2880 if in_off != out_off {
2886 thunks.push(Thunk::Copy {
2887 src: in_off,
2888 dst: out_off,
2889 len: total as u32,
2890 });
2891 }
2892 Thunk::Softmax {
2893 data: out_off,
2894 rows: (total / cols) as u32,
2895 cols: cols as u32,
2896 }
2897 }
2898
2899 Op::SelectiveScan { state_size } => {
2900 let in_shape = &graph.node(node.inputs[0]).shape;
2901 let (batch, seq, hidden) = (
2902 in_shape.dim(0).unwrap_static(),
2903 in_shape.dim(1).unwrap_static(),
2904 in_shape.dim(2).unwrap_static(),
2905 );
2906 Thunk::SelectiveScan {
2907 x: node_offset(arena, node.inputs[0]),
2908 delta: node_offset(arena, node.inputs[1]),
2909 a: node_offset(arena, node.inputs[2]),
2910 b: node_offset(arena, node.inputs[3]),
2911 c: node_offset(arena, node.inputs[4]),
2912 dst: node_offset(arena, node.id),
2913 batch: batch as u32,
2914 seq: seq as u32,
2915 hidden: hidden as u32,
2916 state_size: *state_size as u32,
2917 }
2918 }
2919
2920 Op::GatedDeltaNet {
2921 state_size,
2922 carry_state,
2923 } => {
2924 let q_shape = &graph.node(node.inputs[0]).shape;
2925 let (batch, seq, heads) = (
2926 q_shape.dim(0).unwrap_static(),
2927 q_shape.dim(1).unwrap_static(),
2928 q_shape.dim(2).unwrap_static(),
2929 );
2930 let state_off = if *carry_state {
2931 node_offset(arena, node.inputs[5])
2932 } else {
2933 0
2934 };
2935 Thunk::GatedDeltaNet {
2936 q: node_offset(arena, node.inputs[0]),
2937 k: node_offset(arena, node.inputs[1]),
2938 v: node_offset(arena, node.inputs[2]),
2939 g: node_offset(arena, node.inputs[3]),
2940 beta: node_offset(arena, node.inputs[4]),
2941 state: state_off,
2942 dst: node_offset(arena, node.id),
2943 batch: batch as u32,
2944 seq: seq as u32,
2945 heads: heads as u32,
2946 state_size: *state_size as u32,
2947 }
2948 }
2949
2950 Op::QMatMul {
2951 x_zp,
2952 w_zp,
2953 out_zp,
2954 mult,
2955 } => {
2956 let x_shape = &graph.node(node.inputs[0]).shape;
2957 let w_shape = &graph.node(node.inputs[1]).shape;
2958 let m = x_shape.dim(0).unwrap_static();
2959 let k = x_shape.dim(1).unwrap_static();
2960 let n = w_shape.dim(1).unwrap_static();
2961 Thunk::QMatMul {
2962 x: node_offset(arena, node.inputs[0]),
2963 w: node_offset(arena, node.inputs[1]),
2964 bias: node_offset(arena, node.inputs[2]),
2965 out: node_offset(arena, node.id),
2966 m: m as u32,
2967 k: k as u32,
2968 n: n as u32,
2969 x_zp: *x_zp,
2970 w_zp: *w_zp,
2971 out_zp: *out_zp,
2972 mult: *mult,
2973 }
2974 }
2975
2976 Op::QConv2d {
2977 kernel_size,
2978 stride,
2979 padding,
2980 dilation,
2981 groups,
2982 x_zp,
2983 w_zp,
2984 out_zp,
2985 mult,
2986 } => {
2987 let in_shape = &graph.node(node.inputs[0]).shape;
2988 let w_shape = &graph.node(node.inputs[1]).shape;
2989 let out_shape = &node.shape;
2990 if kernel_size.len() == 2
2991 && in_shape.rank() == 4
2992 && w_shape.rank() == 4
2993 && out_shape.rank() == 4
2994 {
2995 Thunk::QConv2d {
2996 x: node_offset(arena, node.inputs[0]),
2997 w: node_offset(arena, node.inputs[1]),
2998 bias: node_offset(arena, node.inputs[2]),
2999 out: node_offset(arena, node.id),
3000 n: in_shape.dim(0).unwrap_static() as u32,
3001 c_in: in_shape.dim(1).unwrap_static() as u32,
3002 h: in_shape.dim(2).unwrap_static() as u32,
3003 w_in: in_shape.dim(3).unwrap_static() as u32,
3004 c_out: out_shape.dim(1).unwrap_static() as u32,
3005 h_out: out_shape.dim(2).unwrap_static() as u32,
3006 w_out: out_shape.dim(3).unwrap_static() as u32,
3007 kh: kernel_size[0] as u32,
3008 kw: kernel_size[1] as u32,
3009 sh: stride.first().copied().unwrap_or(1) as u32,
3010 sw: stride.get(1).copied().unwrap_or(1) as u32,
3011 ph: padding.first().copied().unwrap_or(0) as u32,
3012 pw: padding.get(1).copied().unwrap_or(0) as u32,
3013 dh: dilation.first().copied().unwrap_or(1) as u32,
3014 dw: dilation.get(1).copied().unwrap_or(1) as u32,
3015 groups: *groups as u32,
3016 x_zp: *x_zp,
3017 w_zp: *w_zp,
3018 out_zp: *out_zp,
3019 mult: *mult,
3020 }
3021 } else {
3022 Thunk::Nop
3023 }
3024 }
3025
3026 Op::DequantMatMul { scheme } => {
3027 use rlx_ir::quant::QuantScheme;
3028 let n = node.shape.dim(node.shape.rank() - 1).unwrap_static();
3029 let total = node.shape.num_elements().unwrap();
3030 let m = total / n.max(1);
3031 let x_total = graph.node(node.inputs[0]).shape.num_elements().unwrap();
3032 let k = x_total / m.max(1);
3033 if scheme.is_gguf() {
3034 Thunk::DequantMatMulGguf {
3035 x: node_offset(arena, node.inputs[0]),
3036 w_q: node_offset(arena, node.inputs[1]),
3037 dst: node_offset(arena, node.id),
3038 m: m as u32,
3039 k: k as u32,
3040 n: n as u32,
3041 scheme: *scheme,
3042 }
3043 } else {
3044 match scheme {
3045 QuantScheme::Nvfp4Block => Thunk::DequantMatMulNvfp4 {
3046 x: node_offset(arena, node.inputs[0]),
3047 w_q: node_offset(arena, node.inputs[1]),
3048 scale: node_offset(arena, node.inputs[2]),
3049 global_scale: node_offset(arena, node.inputs[3]),
3050 dst: node_offset(arena, node.id),
3051 m: m as u32,
3052 k: k as u32,
3053 n: n as u32,
3054 },
3055 QuantScheme::Int4Block { block_size } => Thunk::DequantMatMulInt4 {
3056 x: node_offset(arena, node.inputs[0]),
3057 w_q: node_offset(arena, node.inputs[1]),
3058 scale: node_offset(arena, node.inputs[2]),
3059 zp: node_offset(arena, node.inputs[3]),
3060 dst: node_offset(arena, node.id),
3061 m: m as u32,
3062 k: k as u32,
3063 n: n as u32,
3064 block_size: *block_size,
3065 is_asymmetric: false,
3066 },
3067 QuantScheme::Fp8E4m3 => Thunk::DequantMatMulFp8 {
3068 x: node_offset(arena, node.inputs[0]),
3069 w_q: node_offset(arena, node.inputs[1]),
3070 scale: node_offset(arena, node.inputs[2]),
3071 dst: node_offset(arena, node.id),
3072 m: m as u32,
3073 k: k as u32,
3074 n: n as u32,
3075 e5m2: false,
3076 },
3077 QuantScheme::Fp8E5m2 => Thunk::DequantMatMulFp8 {
3078 x: node_offset(arena, node.inputs[0]),
3079 w_q: node_offset(arena, node.inputs[1]),
3080 scale: node_offset(arena, node.inputs[2]),
3081 dst: node_offset(arena, node.id),
3082 m: m as u32,
3083 k: k as u32,
3084 n: n as u32,
3085 e5m2: true,
3086 },
3087 QuantScheme::Int8Block { block_size } => Thunk::DequantMatMul {
3088 x: node_offset(arena, node.inputs[0]),
3089 w_q: node_offset(arena, node.inputs[1]),
3090 scale: node_offset(arena, node.inputs[2]),
3091 zp: node_offset(arena, node.inputs[3]),
3092 dst: node_offset(arena, node.id),
3093 m: m as u32,
3094 k: k as u32,
3095 n: n as u32,
3096 block_size: *block_size,
3097 is_asymmetric: false,
3098 },
3099 QuantScheme::Int8BlockAsym { block_size } => Thunk::DequantMatMul {
3100 x: node_offset(arena, node.inputs[0]),
3101 w_q: node_offset(arena, node.inputs[1]),
3102 scale: node_offset(arena, node.inputs[2]),
3103 zp: node_offset(arena, node.inputs[3]),
3104 dst: node_offset(arena, node.id),
3105 m: m as u32,
3106 k: k as u32,
3107 n: n as u32,
3108 block_size: *block_size,
3109 is_asymmetric: true,
3110 },
3111 other => panic!(
3112 "DequantMatMul on CPU supports Int8/Int4/FP8/NVFP4 legacy or GGUF schemes; got {other}"
3113 ),
3114 }
3115 }
3116 }
3117
3118 Op::LoraMatMul { scale } => {
3119 let n = node.shape.dim(node.shape.rank() - 1).unwrap_static();
3121 let total = node.shape.num_elements().unwrap();
3122 let m = total / n.max(1);
3123 let x_total = graph.node(node.inputs[0]).shape.num_elements().unwrap();
3124 let k = x_total / m.max(1);
3125 let a_total = graph.node(node.inputs[2]).shape.num_elements().unwrap();
3126 let r = a_total / k.max(1);
3127 Thunk::LoraMatMul {
3128 x: node_offset(arena, node.inputs[0]),
3129 w: node_offset(arena, node.inputs[1]),
3130 a: node_offset(arena, node.inputs[2]),
3131 b: node_offset(arena, node.inputs[3]),
3132 dst: node_offset(arena, node.id),
3133 m: m as u32,
3134 k: k as u32,
3135 n: n as u32,
3136 r: r as u32,
3137 scale: *scale,
3138 }
3139 }
3140
3141 Op::Sample {
3142 top_k,
3143 top_p,
3144 temperature,
3145 seed,
3146 } => {
3147 let in_shape = &graph.node(node.inputs[0]).shape;
3148 let (batch, vocab) = if in_shape.rank() >= 2 {
3150 (
3151 in_shape.dim(0).unwrap_static(),
3152 in_shape.dim(in_shape.rank() - 1).unwrap_static(),
3153 )
3154 } else {
3155 (1, in_shape.num_elements().unwrap_or(0))
3156 };
3157 Thunk::Sample {
3158 logits: node_offset(arena, node.inputs[0]),
3159 dst: node_offset(arena, node.id),
3160 batch: batch as u32,
3161 vocab: vocab as u32,
3162 top_k: *top_k as u32,
3163 top_p: *top_p,
3164 temperature: *temperature,
3165 seed: *seed,
3166 }
3167 }
3168
3169 Op::Cumsum { axis, exclusive } => {
3170 let rank = node.shape.rank();
3175 let ax = if *axis < 0 {
3176 (rank as i32 + axis) as usize
3177 } else {
3178 *axis as usize
3179 };
3180 assert_eq!(
3181 ax,
3182 rank - 1,
3183 "Cumsum only supports the last axis on CPU today"
3184 );
3185 let cols = node.shape.dim(ax).unwrap_static();
3186 let total = node.shape.num_elements().unwrap();
3187 Thunk::Cumsum {
3188 src: node_offset(arena, node.inputs[0]),
3189 dst: node_offset(arena, node.id),
3190 rows: (total / cols) as u32,
3191 cols: cols as u32,
3192 exclusive: *exclusive,
3193 }
3194 }
3195
// Lower an attention node to `Thunk::Attention`.
//
// Inputs are read positionally: 0 = q, 1 = k, 2 = v and, for the
// `Custom` / `Bias` mask kinds only, 3 = mask.
Op::Attention {
    num_heads,
    head_dim,
    mask_kind,
} => {
    let q_shape = &graph.node(node.inputs[0]).shape;
    let k_shape = &graph.node(node.inputs[1]).shape;
    let rank = q_shape.rank();
    // Derive (batch, seq, kv_seq, bhsd) from the q/k shapes. `bhsd` is only
    // ever true for a rank-4 q whose dim 1 equals `num_heads`; presumably it
    // marks a [batch, heads, seq, head_dim] layout — confirm against the
    // kernel that consumes the flag.
    let (batch, seq, kv_seq, bhsd) = if rank == 4 {
        let d1 = q_shape.dim(1).unwrap_static();
        let d2 = q_shape.dim(2).unwrap_static();
        // NOTE(review): the layout is inferred purely from `d1 == num_heads`.
        // A rank-4 q whose dim 1 is a sequence length that happens to equal
        // `num_heads` would take the heads-first branch as well.
        if d1 == *num_heads {
            (
                // Heads on dim 1: sequence lengths come from dim 2 of q and k.
                q_shape.dim(0).unwrap_static(),
                d2,
                k_shape.dim(2).unwrap_static(),
                true,
            )
        } else {
            (
                // Otherwise sequence lengths come from dim 1 of q and k.
                q_shape.dim(0).unwrap_static(),
                d1,
                k_shape.dim(1).unwrap_static(),
                false,
            )
        }
    } else if rank >= 3 {
        // Rank 3 (or rank > 4): dim 0 is batch, dim 1 is the sequence length.
        (
            q_shape.dim(0).unwrap_static(),
            q_shape.dim(1).unwrap_static(),
            k_shape.dim(1).unwrap_static(),
            false,
        )
    } else {
        // Rank < 3: no batch dimension, so batch is fixed to 1 and dim 0 is
        // used as the sequence length.
        (
            1,
            q_shape.dim(0).unwrap_static(),
            k_shape.dim(0).unwrap_static(),
            false,
        )
    };
    // Only the `Custom` and `Bias` mask kinds read a fourth input; every
    // other kind stores 0 as a placeholder offset.
    let mask_off = if matches!(
        mask_kind,
        rlx_ir::op::MaskKind::Custom | rlx_ir::op::MaskKind::Bias
    ) {
        node_offset(arena, node.inputs[3])
    } else {
        0
    };
    // The same `num_heads * head_dim` value is used as the row stride for
    // q, k and v alike.
    let hs = (*num_heads * *head_dim) as u32;
    Thunk::Attention {
        q: node_offset(arena, node.inputs[0]),
        k: node_offset(arena, node.inputs[1]),
        v: node_offset(arena, node.inputs[2]),
        mask: mask_off,
        out: node_offset(arena, node.id),
        batch: batch as u32,
        seq: seq as u32,
        kv_seq: kv_seq as u32,
        heads: *num_heads as u32,
        head_dim: *head_dim as u32,
        mask_kind: *mask_kind,
        q_row_stride: hs,
        k_row_stride: hs,
        v_row_stride: hs,
        bhsd,
    }
}
3274
// Lower an attention-gradient node to `Thunk::AttentionBackward`.
//
// Inputs are read positionally: 0 = q, 1 = k, 2 = v, 3 = dy and, for the
// `Custom` / `Bias` mask kinds only, 4 = mask. `wrt` is forwarded unchanged
// to the thunk.
Op::AttentionBackward {
    num_heads,
    head_dim,
    mask_kind,
    wrt,
} => {
    let q_shape = &graph.node(node.inputs[0]).shape;
    let k_shape = &graph.node(node.inputs[1]).shape;
    let rank = q_shape.rank();
    // Same shape-to-(batch, seq, kv_seq, bhsd) derivation as the forward
    // `Op::Attention` arm. NOTE(review): the two copies must be kept in sync
    // by hand; a shared helper would remove that risk.
    let (batch, seq, kv_seq, bhsd) = if rank == 4 {
        let d1 = q_shape.dim(1).unwrap_static();
        let d2 = q_shape.dim(2).unwrap_static();
        // The layout is inferred purely from `d1 == num_heads`.
        if d1 == *num_heads {
            // Heads on dim 1: sequence lengths come from dim 2 of q and k.
            (
                q_shape.dim(0).unwrap_static(),
                d2,
                k_shape.dim(2).unwrap_static(),
                true,
            )
        } else {
            // Otherwise sequence lengths come from dim 1 of q and k.
            (
                q_shape.dim(0).unwrap_static(),
                d1,
                k_shape.dim(1).unwrap_static(),
                false,
            )
        }
    } else if rank >= 3 {
        // Rank 3 (or rank > 4): dim 0 is batch, dim 1 is the sequence length.
        (
            q_shape.dim(0).unwrap_static(),
            q_shape.dim(1).unwrap_static(),
            k_shape.dim(1).unwrap_static(),
            false,
        )
    } else {
        // Rank < 3: batch is fixed to 1 and dim 0 is the sequence length.
        (
            1,
            q_shape.dim(0).unwrap_static(),
            k_shape.dim(0).unwrap_static(),
            false,
        )
    };
    // The mask sits one slot later than in the forward op because slot 3
    // holds dy; kinds other than `Custom` / `Bias` store a placeholder 0.
    let mask_off = if matches!(
        mask_kind,
        rlx_ir::op::MaskKind::Custom | rlx_ir::op::MaskKind::Bias
    ) {
        node_offset(arena, node.inputs[4])
    } else {
        0
    };
    Thunk::AttentionBackward {
        q: node_offset(arena, node.inputs[0]),
        k: node_offset(arena, node.inputs[1]),
        v: node_offset(arena, node.inputs[2]),
        dy: node_offset(arena, node.inputs[3]),
        mask: mask_off,
        out: node_offset(arena, node.id),
        batch: batch as u32,
        seq: seq as u32,
        kv_seq: kv_seq as u32,
        heads: *num_heads as u32,
        head_dim: *head_dim as u32,
        mask_kind: *mask_kind,
        wrt: *wrt,
        bhsd,
    }
}
3342
3343 Op::FusedAttentionBlock {
3344 num_heads,
3345 head_dim,
3346 has_bias,
3347 has_rope,
3348 } => {
3349 let x_shape = &graph.node(node.inputs[0]).shape;
3350 let (batch, seq) = if x_shape.rank() >= 3 {
3351 (
3352 x_shape.dim(0).unwrap_static(),
3353 x_shape.dim(1).unwrap_static(),
3354 )
3355 } else {
3356 let total = x_shape.num_elements().unwrap();
3357 let s = x_shape.dim(x_shape.rank() - 2).unwrap_static();
3358 (total / (s * num_heads * head_dim), s)
3359 };
3360 let hs = (*num_heads * *head_dim) as u32;
3361 let mut idx = 4;
3363 let (qkv_b_off, out_b_off) = if *has_bias {
3364 let qb = node_offset(arena, node.inputs[idx]);
3365 let ob = node_offset(arena, node.inputs[idx + 1]);
3366 idx += 2;
3367 (qb, ob)
3368 } else {
3369 (0, 0)
3370 };
3371 let (cos_off, sin_off, cl) = if *has_rope {
3372 let c = node_offset(arena, node.inputs[idx]);
3373 let s = node_offset(arena, node.inputs[idx + 1]);
3374 let clen = get_len(graph, node.inputs[idx]);
3375 (c, s, clen as u32)
3376 } else {
3377 (0, 0, 0)
3378 };
3379
3380 Thunk::FusedAttnBlock {
3381 hidden: node_offset(arena, node.inputs[0]),
3382 qkv_w: node_offset(arena, node.inputs[1]),
3383 out_w: node_offset(arena, node.inputs[2]),
3384 mask: node_offset(arena, node.inputs[3]),
3385 out: node_offset(arena, node.id),
3386 qkv_b: qkv_b_off,
3387 out_b: out_b_off,
3388 cos: cos_off,
3389 sin: sin_off,
3390 cos_len: cl,
3391 batch: batch as u32,
3392 seq: seq as u32,
3393 hs,
3394 nh: *num_heads as u32,
3395 dh: *head_dim as u32,
3396 has_bias: *has_bias,
3397 has_rope: *has_rope,
3398 }
3399 }
3400
3401 Op::Rope { head_dim, n_rot } => {
3402 let x_shape = &graph.node(node.inputs[0]).shape;
3403 let (batch, seq, hidden) = if x_shape.rank() >= 3 {
3404 (
3405 x_shape.dim(0).unwrap_static(),
3406 x_shape.dim(1).unwrap_static(),
3407 x_shape.dim(2).unwrap_static(),
3408 )
3409 } else {
3410 let total = x_shape.num_elements().unwrap();
3411 (
3412 1,
3413 x_shape.dim(0).unwrap_static(),
3414 total / x_shape.dim(0).unwrap_static(),
3415 )
3416 };
3417 let cos_len = get_len(graph, node.inputs[1]);
3418 Thunk::Rope {
3419 src: node_offset(arena, node.inputs[0]),
3420 cos: node_offset(arena, node.inputs[1]),
3421 sin: node_offset(arena, node.inputs[2]),
3422 dst: node_offset(arena, node.id),
3423 batch: batch as u32,
3424 seq: seq as u32,
3425 hidden: hidden as u32,
3426 head_dim: *head_dim as u32,
3427 n_rot: *n_rot as u32,
3428 cos_len: cos_len as u32,
3429 src_row_stride: hidden as u32,
3433 }
3434 }
3435
3436 Op::FusedSwiGLU {
3437 cast_to: _,
3438 gate_first,
3439 } => {
3440 let n_half = node.shape.dim(node.shape.rank() - 1).unwrap_static();
3441 let total = node.shape.num_elements().unwrap();
3442 Thunk::FusedSwiGLU {
3443 src: node_offset(arena, node.inputs[0]),
3444 dst: node_offset(arena, node.id),
3445 n_half: n_half as u32,
3446 total: total as u32,
3447 gate_first: *gate_first,
3448 }
3449 }
3450
3451 Op::Conv {
3452 kernel_size,
3453 stride,
3454 padding,
3455 dilation,
3456 groups,
3457 } => {
3458 let in_shape = &graph.node(node.inputs[0]).shape;
3459 let w_shape = &graph.node(node.inputs[1]).shape;
3460 let out_shape = &node.shape;
3461 let is_1x1_simple = kernel_size.len() == 2
3465 && kernel_size[0] == 1
3466 && kernel_size[1] == 1
3467 && stride.iter().all(|&s| s == 1)
3468 && padding.iter().all(|&p| p == 0)
3469 && dilation.iter().all(|&d| d == 1)
3470 && *groups == 1;
3471 if is_1x1_simple && in_shape.rank() == 4 && out_shape.rank() == 4 {
3472 let n = in_shape.dim(0).unwrap_static();
3473 let c_in = in_shape.dim(1).unwrap_static();
3474 let c_out = out_shape.dim(1).unwrap_static();
3475 let h = in_shape.dim(2).unwrap_static();
3476 let w = in_shape.dim(3).unwrap_static();
3477 Thunk::Conv2D1x1 {
3478 src: node_offset(arena, node.inputs[0]),
3479 weight: node_offset(arena, node.inputs[1]),
3480 dst: node_offset(arena, node.id),
3481 n: n as u32,
3482 c_in: c_in as u32,
3483 c_out: c_out as u32,
3484 hw: (h * w) as u32,
3485 }
3486 } else if kernel_size.len() == 2
3487 && in_shape.rank() == 4
3488 && w_shape.rank() == 4
3489 && out_shape.rank() == 4
3490 {
3491 Thunk::Conv2D {
3492 src: node_offset(arena, node.inputs[0]),
3493 weight: node_offset(arena, node.inputs[1]),
3494 dst: node_offset(arena, node.id),
3495 n: in_shape.dim(0).unwrap_static() as u32,
3496 c_in: in_shape.dim(1).unwrap_static() as u32,
3497 h: in_shape.dim(2).unwrap_static() as u32,
3498 w: in_shape.dim(3).unwrap_static() as u32,
3499 c_out: out_shape.dim(1).unwrap_static() as u32,
3500 h_out: out_shape.dim(2).unwrap_static() as u32,
3501 w_out: out_shape.dim(3).unwrap_static() as u32,
3502 kh: kernel_size[0] as u32,
3503 kw: kernel_size[1] as u32,
3504 sh: stride.first().copied().unwrap_or(1) as u32,
3505 sw: stride.get(1).copied().unwrap_or(1) as u32,
3506 ph: padding.first().copied().unwrap_or(0) as u32,
3507 pw: padding.get(1).copied().unwrap_or(0) as u32,
3508 dh: dilation.first().copied().unwrap_or(1) as u32,
3509 dw: dilation.get(1).copied().unwrap_or(1) as u32,
3510 groups: *groups as u32,
3511 }
3512 } else {
3513 Thunk::Nop
3514 }
3515 }
3516
3517 Op::Pool {
3518 kind,
3519 kernel_size,
3520 stride,
3521 padding,
3522 } => {
3523 let in_shape = &graph.node(node.inputs[0]).shape;
3525 let out_shape = &node.shape;
3526 if kernel_size.len() == 2 && in_shape.rank() == 4 && out_shape.rank() == 4 {
3527 Thunk::Pool2D {
3528 src: node_offset(arena, node.inputs[0]),
3529 dst: node_offset(arena, node.id),
3530 n: in_shape.dim(0).unwrap_static() as u32,
3531 c: in_shape.dim(1).unwrap_static() as u32,
3532 h: in_shape.dim(2).unwrap_static() as u32,
3533 w: in_shape.dim(3).unwrap_static() as u32,
3534 h_out: out_shape.dim(2).unwrap_static() as u32,
3535 w_out: out_shape.dim(3).unwrap_static() as u32,
3536 kh: kernel_size[0] as u32,
3537 kw: kernel_size[1] as u32,
3538 sh: stride.first().copied().unwrap_or(1) as u32,
3539 sw: stride.get(1).copied().unwrap_or(1) as u32,
3540 ph: padding.first().copied().unwrap_or(0) as u32,
3541 pw: padding.get(1).copied().unwrap_or(0) as u32,
3542 kind: *kind,
3543 }
3544 } else {
3545 Thunk::Nop
3546 }
3547 }
3548
3549 Op::Transpose { perm } => {
3550 let in_shape = &graph.node(node.inputs[0]).shape;
3553 let in_rank = in_shape.rank();
3554 let in_dims: Vec<usize> = (0..in_rank)
3555 .map(|i| in_shape.dim(i).unwrap_static())
3556 .collect();
3557 let mut in_strides_full = vec![1usize; in_rank];
3559 for d in (0..in_rank.saturating_sub(1)).rev() {
3560 in_strides_full[d] = in_strides_full[d + 1] * in_dims[d + 1];
3561 }
3562 let out_dims: Vec<u32> = perm.iter().map(|&p| in_dims[p] as u32).collect();
3563 let in_strides: Vec<u32> =
3564 perm.iter().map(|&p| in_strides_full[p] as u32).collect();
3565 let in_total = in_dims.iter().product::<usize>() as u32;
3566 let src = node_offset(arena, node.inputs[0]);
3567 let dst = node_offset(arena, node.id);
3568 match node.shape.dtype() {
3569 rlx_ir::DType::F64 => Thunk::TransposeF64 {
3570 src,
3571 dst,
3572 in_total,
3573 out_dims,
3574 in_strides,
3575 },
3576 _ => Thunk::Transpose {
3577 src,
3578 dst,
3579 in_total,
3580 out_dims,
3581 in_strides,
3582 },
3583 }
3584 }
3585
3586 Op::ScatterAdd => {
3587 let upd_shape = &graph.node(node.inputs[0]).shape;
3590 let out_shape = &node.shape;
3591 let num_updates = upd_shape.dim(0).unwrap_static();
3592 let out_dim = out_shape.dim(0).unwrap_static();
3593 let trailing: usize = (1..out_shape.rank())
3594 .map(|i| out_shape.dim(i).unwrap_static())
3595 .product::<usize>()
3596 .max(1);
3597 Thunk::ScatterAdd {
3598 updates: node_offset(arena, node.inputs[0]),
3599 indices: node_offset(arena, node.inputs[1]),
3600 dst: node_offset(arena, node.id),
3601 num_updates: num_updates as u32,
3602 out_dim: out_dim as u32,
3603 trailing: trailing as u32,
3604 }
3605 }
3606
3607 Op::GroupedMatMul => {
3608 let in_shape = &graph.node(node.inputs[0]).shape;
3610 let w_shape = &graph.node(node.inputs[1]).shape;
3611 let m = in_shape.dim(in_shape.rank() - 2).unwrap_static();
3612 let k_dim = in_shape.dim(in_shape.rank() - 1).unwrap_static();
3613 let num_experts = w_shape.dim(0).unwrap_static();
3614 let n = w_shape.dim(2).unwrap_static();
3615 Thunk::GroupedMatMul {
3616 input: node_offset(arena, node.inputs[0]),
3617 weight: node_offset(arena, node.inputs[1]),
3618 expert_idx: node_offset(arena, node.inputs[2]),
3619 dst: node_offset(arena, node.id),
3620 m: m as u32,
3621 k_dim: k_dim as u32,
3622 n: n as u32,
3623 num_experts: num_experts as u32,
3624 }
3625 }
3626
// Lower a grouped matmul over block-quantized expert weights to
// `Thunk::DequantGroupedMatMulGguf`.
//
// Inputs are read positionally: 0 = activations, 1 = packed weights,
// 2 = expert indices.
Op::DequantGroupedMatMul { scheme } => {
    let in_shape = &graph.node(node.inputs[0]).shape;
    let w_shape = &graph.node(node.inputs[1]).shape;
    // m and k come from the last two dims of the activations.
    let m = in_shape.dim(in_shape.rank() - 2).unwrap_static();
    let k_dim = in_shape.dim(in_shape.rank() - 1).unwrap_static();
    let out_shape = &node.shape;
    // n is the last dim of the output.
    let n = out_shape.dim(out_shape.rank() - 1).unwrap_static();
    let block_elems = scheme.gguf_block_size() as usize;
    let block_bytes = scheme.gguf_block_bytes() as usize;
    // Packed size of one expert's k_dim x n slab: whole blocks times bytes
    // per block. Integer division drops any partial trailing block.
    let slab_bytes = (k_dim * n) / block_elems * block_bytes;
    // The element count of the packed weights is used as a byte count —
    // presumably the tensor is stored with a one-byte dtype; TODO confirm.
    let total_bytes = w_shape.num_elements().unwrap();
    // `.max(1)` only guards the division against a zero slab size.
    // NOTE(review): unlike the `Op::DequantMoEWeights` arm, nothing checks
    // that `total_bytes` is an exact multiple of `slab_bytes`, so a size
    // mismatch silently rounds the expert count down.
    let num_experts = total_bytes / slab_bytes.max(1);
    Thunk::DequantGroupedMatMulGguf {
        input: node_offset(arena, node.inputs[0]),
        w_q: node_offset(arena, node.inputs[1]),
        expert_idx: node_offset(arena, node.inputs[2]),
        dst: node_offset(arena, node.id),
        m: m as u32,
        k_dim: k_dim as u32,
        n: n as u32,
        num_experts: num_experts as u32,
        scheme: *scheme,
    }
}
3651
3652 Op::DequantMoEWeights { scheme } => {
3653 let w_shape = &graph.node(node.inputs[0]).shape;
3654 let out_shape = &node.shape;
3655 let num_experts = out_shape.dim(0).unwrap_static();
3656 let k_dim = out_shape.dim(1).unwrap_static();
3657 let n = out_shape.dim(2).unwrap_static();
3658 let block_elems = scheme.gguf_block_size() as usize;
3659 let block_bytes = scheme.gguf_block_bytes() as usize;
3660 let slab_bytes = (k_dim * n) / block_elems * block_bytes;
3661 let total_bytes = w_shape.num_elements().unwrap();
3662 assert_eq!(
3663 total_bytes,
3664 num_experts * slab_bytes,
3665 "DequantMoEWeights packed bytes mismatch"
3666 );
3667 Thunk::DequantMoEWeightsGguf {
3668 w_q: node_offset(arena, node.inputs[0]),
3669 dst: node_offset(arena, node.id),
3670 k_dim: k_dim as u32,
3671 n: n as u32,
3672 num_experts: num_experts as u32,
3673 scheme: *scheme,
3674 }
3675 }
3676
3677 Op::TopK { k } => {
3678 let in_shape = &graph.node(node.inputs[0]).shape;
3679 let rank = in_shape.rank();
3680 let axis_dim = in_shape.dim(rank - 1).unwrap_static();
3681 let outer = in_shape.num_elements().unwrap() / axis_dim;
3682 Thunk::TopK {
3683 src: node_offset(arena, node.inputs[0]),
3684 dst: node_offset(arena, node.id),
3685 outer: outer as u32,
3686 axis_dim: axis_dim as u32,
3687 k: *k as u32,
3688 }
3689 }
3690
3691 Op::Reduce {
3692 op,
3693 axes,
3694 keep_dim: _,
3695 } => {
3696 let in_shape = &graph.node(node.inputs[0]).shape;
3702 let rank = in_shape.rank();
3703 let mut sorted = axes.clone();
3704 sorted.sort();
3705 sorted.dedup();
3706 let contiguous = sorted.windows(2).all(|w| w[1] == w[0] + 1)
3707 && !sorted.is_empty()
3708 && *sorted.last().unwrap() < rank;
3709 if !contiguous {
3710 Thunk::Nop
3711 } else {
3712 let first = sorted[0];
3713 let last = *sorted.last().unwrap();
3714 let outer: usize = (0..first)
3715 .map(|i| in_shape.dim(i).unwrap_static())
3716 .product::<usize>()
3717 .max(1);
3718 let reduced: usize = (first..=last)
3719 .map(|i| in_shape.dim(i).unwrap_static())
3720 .product();
3721 let inner: usize = (last + 1..rank)
3722 .map(|i| in_shape.dim(i).unwrap_static())
3723 .product::<usize>()
3724 .max(1);
3725 let src = node_offset(arena, node.inputs[0]);
3726 let dst = node_offset(arena, node.id);
3727 if node.shape.dtype() == rlx_ir::DType::F64 && matches!(op, ReduceOp::Sum) {
3728 Thunk::ReduceSumF64 {
3729 src,
3730 dst,
3731 outer: outer as u32,
3732 reduced: reduced as u32,
3733 inner: inner as u32,
3734 }
3735 } else {
3736 Thunk::Reduce {
3737 src,
3738 dst,
3739 outer: outer as u32,
3740 reduced: reduced as u32,
3741 inner: inner as u32,
3742 op: *op,
3743 }
3744 }
3745 }
3746 }
3747
3748 Op::Compare(cmp) => {
3749 let len = node.shape.num_elements().unwrap();
3750 Thunk::Compare {
3751 lhs: node_offset(arena, node.inputs[0]),
3752 rhs: node_offset(arena, node.inputs[1]),
3753 dst: node_offset(arena, node.id),
3754 len: len as u32,
3755 op: *cmp,
3756 }
3757 }
3758
3759 Op::Where => {
3760 let len = node.shape.num_elements().unwrap();
3761 Thunk::Where {
3762 cond: node_offset(arena, node.inputs[0]),
3763 on_true: node_offset(arena, node.inputs[1]),
3764 on_false: node_offset(arena, node.inputs[2]),
3765 dst: node_offset(arena, node.id),
3766 len: len as u32,
3767 }
3768 }
3769
3770 Op::ReluBackward => {
3771 let len: usize = (0..node.shape.rank())
3772 .map(|i| node.shape.dim(i).unwrap_static())
3773 .product();
3774 let x = node_offset(arena, node.inputs[0]);
3775 let dy = node_offset(arena, node.inputs[1]);
3776 let dx = node_offset(arena, node.id);
3777 match node.shape.dtype() {
3778 rlx_ir::DType::F64 => Thunk::ReluBackwardF64 {
3779 x,
3780 dy,
3781 dx,
3782 len: len as u32,
3783 },
3784 _ => Thunk::ReluBackward {
3785 x,
3786 dy,
3787 dx,
3788 len: len as u32,
3789 },
3790 }
3791 }
3792
3793 Op::ComplexNormSq => {
3794 let len: usize = (0..node.shape.rank())
3795 .map(|i| node.shape.dim(i).unwrap_static())
3796 .product();
3797 let src = node_offset(arena, node.inputs[0]);
3798 let dst = node_offset(arena, node.id);
3799 Thunk::ComplexNormSqF32 {
3800 src,
3801 dst,
3802 len: len as u32,
3803 }
3804 }
3805
3806 Op::ComplexNormSqBackward => {
3807 let len: usize = (0..node.shape.rank())
3808 .map(|i| node.shape.dim(i).unwrap_static())
3809 .product();
3810 let z = node_offset(arena, node.inputs[0]);
3811 let g = node_offset(arena, node.inputs[1]);
3812 let dz = node_offset(arena, node.id);
3813 Thunk::ComplexNormSqBackwardF32 {
3814 z,
3815 g,
3816 dz,
3817 len: len as u32,
3818 }
3819 }
3820
3821 Op::Conjugate => {
3822 let len: usize = (0..node.shape.rank())
3823 .map(|i| node.shape.dim(i).unwrap_static())
3824 .product();
3825 Thunk::ConjugateC64 {
3826 src: node_offset(arena, node.inputs[0]),
3827 dst: node_offset(arena, node.id),
3828 len: len as u32,
3829 }
3830 }
3831
3832 Op::ActivationBackward { kind } => {
3833 let len: usize = (0..node.shape.rank())
3834 .map(|i| node.shape.dim(i).unwrap_static())
3835 .product();
3836 let x = node_offset(arena, node.inputs[0]);
3837 let dy = node_offset(arena, node.inputs[1]);
3838 let dx = node_offset(arena, node.id);
3839 match node.shape.dtype() {
3840 rlx_ir::DType::F64 => Thunk::ActivationBackwardF64 {
3841 x,
3842 dy,
3843 dx,
3844 len: len as u32,
3845 kind: *kind,
3846 },
3847 _ => Thunk::ActivationBackward {
3848 x,
3849 dy,
3850 dx,
3851 len: len as u32,
3852 kind: *kind,
3853 },
3854 }
3855 }
3856
3857 Op::LayerNormBackwardInput { eps, .. } => {
3858 let h = node.shape.dim(node.shape.rank() - 1).unwrap_static();
3860 let total = node.shape.num_elements().unwrap();
3861 Thunk::LayerNormBackwardInput {
3862 x: node_offset(arena, node.inputs[0]),
3863 gamma: node_offset(arena, node.inputs[1]),
3864 dy: node_offset(arena, node.inputs[2]),
3865 dx: node_offset(arena, node.id),
3866 rows: (total / h) as u32,
3867 h: h as u32,
3868 eps: *eps,
3869 }
3870 }
3871
3872 Op::LayerNormBackwardGamma { eps, .. } => {
3873 let x_shape = &graph.node(node.inputs[0]).shape;
3874 let h = x_shape.dim(x_shape.rank() - 1).unwrap_static();
3875 let x_total = x_shape.num_elements().unwrap();
3876 Thunk::LayerNormBackwardGamma {
3877 x: node_offset(arena, node.inputs[0]),
3878 dy: node_offset(arena, node.inputs[1]),
3879 dgamma: node_offset(arena, node.id),
3880 rows: (x_total / h) as u32,
3881 h: h as u32,
3882 eps: *eps,
3883 }
3884 }
3885
3886 Op::RmsNormBackwardInput { eps, .. }
3887 | Op::RmsNormBackwardGamma { eps, .. }
3888 | Op::RmsNormBackwardBeta { eps, .. } => {
3889 let x_shape = &graph.node(node.inputs[0]).shape;
3890 let h = x_shape.dim(x_shape.rank() - 1).unwrap_static();
3891 let rows = (x_shape.num_elements().unwrap() / h) as u32;
3892 let off = |i: usize| node_offset(arena, node.inputs[i]);
3893 let common = (off(0), off(1), off(2), off(3), rows, h as u32, *eps);
3894 match &node.op {
3895 Op::RmsNormBackwardInput { .. } => Thunk::RmsNormBackwardInput {
3896 x: common.0,
3897 gamma: common.1,
3898 beta: common.2,
3899 dy: common.3,
3900 dx: node_offset(arena, node.id),
3901 rows: common.4,
3902 h: common.5,
3903 eps: common.6,
3904 },
3905 Op::RmsNormBackwardGamma { .. } => Thunk::RmsNormBackwardGamma {
3906 x: common.0,
3907 gamma: common.1,
3908 beta: common.2,
3909 dy: common.3,
3910 dgamma: node_offset(arena, node.id),
3911 rows: common.4,
3912 h: common.5,
3913 eps: common.6,
3914 },
3915 Op::RmsNormBackwardBeta { .. } => Thunk::RmsNormBackwardBeta {
3916 x: common.0,
3917 gamma: common.1,
3918 beta: common.2,
3919 dy: common.3,
3920 dbeta: node_offset(arena, node.id),
3921 rows: common.4,
3922 h: common.5,
3923 eps: common.6,
3924 },
3925 _ => unreachable!(),
3926 }
3927 }
3928
// Lower the gradient of a rotary embedding to `Thunk::RopeBackward`.
//
// Inputs are read positionally: 0 = dy, 1 = cos table, 2 = sin table.
Op::RopeBackward { head_dim, n_rot } => {
    let dy_shape = &graph.node(node.inputs[0]).shape;
    // Rank >= 3 is read as (batch, seq, hidden) from the first three dims;
    // anything smaller is treated as a single batch of (seq, hidden).
    let (batch, seq, hidden) = if dy_shape.rank() >= 3 {
        (
            dy_shape.dim(0).unwrap_static(),
            dy_shape.dim(1).unwrap_static(),
            dy_shape.dim(2).unwrap_static(),
        )
    } else {
        // NOTE(review): this fallback reads dim 1 directly, whereas the
        // forward `Op::Rope` arm derives hidden as `total / dim(0)`. The two
        // agree for rank 2 but diverge for rank 1 — confirm which is intended.
        (
            1,
            dy_shape.dim(0).unwrap_static(),
            dy_shape.dim(1).unwrap_static(),
        )
    };
    // Length of the cos table, taken as the element count of input 1.
    let cos_shape = &graph.node(node.inputs[1]).shape;
    let cos_len = cos_shape.num_elements().unwrap();
    Thunk::RopeBackward {
        dy: node_offset(arena, node.inputs[0]),
        cos: node_offset(arena, node.inputs[1]),
        sin: node_offset(arena, node.inputs[2]),
        dx: node_offset(arena, node.id),
        batch: batch as u32,
        seq: seq as u32,
        hidden: hidden as u32,
        head_dim: *head_dim as u32,
        n_rot: *n_rot as u32,
        cos_len: cos_len as u32,
    }
}
3959
// Lower the gradient of a cumulative sum to `Thunk::CumsumBackward`,
// treating dy (input 0) as a `rows x cols` matrix over its last dim.
//
// NOTE(review): every field other than `exclusive` is discarded via `..`
// and the last dim is always used, while the forward `Op::Cumsum` arm
// asserts that its axis is the last one. If this op can carry a different
// axis it is silently ignored here — confirm, or add the same assertion.
Op::CumsumBackward { exclusive, .. } => {
    let dy_shape = &graph.node(node.inputs[0]).shape;
    let rank = dy_shape.rank();
    // cols = extent of the last dim; rows = all remaining elements.
    let cols = dy_shape.dim(rank - 1).unwrap_static();
    let rows = dy_shape.num_elements().unwrap() / cols;
    Thunk::CumsumBackward {
        dy: node_offset(arena, node.inputs[0]),
        dx: node_offset(arena, node.id),
        rows: rows as u32,
        cols: cols as u32,
        exclusive: *exclusive,
    }
}
3973
3974 Op::GatherBackward { .. } => {
3975 let dy_shape = &graph.node(node.inputs[0]).shape;
3976 let idx_shape = &graph.node(node.inputs[1]).shape;
3977 let out_shape = &node.shape;
3978 let rank = out_shape.rank();
3979 let axis = match &node.op {
3980 Op::GatherBackward { axis } => *axis,
3981 _ => 0,
3982 };
3983 let axis_u = if axis < 0 {
3984 (rank as i32 + axis) as usize
3985 } else {
3986 axis as usize
3987 };
3988 let outer: usize = (0..axis_u)
3989 .map(|i| dy_shape.dim(i).unwrap_static())
3990 .product::<usize>()
3991 .max(1);
3992 let num_idx = idx_shape.dim(axis_u).unwrap_static();
3993 let trailing: usize = (axis_u + 1..dy_shape.rank())
3994 .map(|i| dy_shape.dim(i).unwrap_static())
3995 .product::<usize>()
3996 .max(1);
3997 let axis_dim = out_shape.dim(axis_u).unwrap_static();
3998 Thunk::GatherBackward {
3999 dy: node_offset(arena, node.inputs[0]),
4000 indices: node_offset(arena, node.inputs[1]),
4001 dst: node_offset(arena, node.id),
4002 outer: outer as u32,
4003 axis_dim: axis_dim as u32,
4004 num_idx: num_idx as u32,
4005 trailing: trailing as u32,
4006 }
4007 }
4008
4009 Op::GroupNormBackwardInput { num_groups, eps }
4010 | Op::GroupNormBackwardGamma { num_groups, eps }
4011 | Op::GroupNormBackwardBeta { num_groups, eps } => {
4012 let x_shape = &graph.node(node.inputs[0]).shape;
4013 let n = x_shape.dim(0).unwrap_static() as u32;
4014 let c = x_shape.dim(1).unwrap_static() as u32;
4015 let h = x_shape.dim(2).unwrap_static() as u32;
4016 let w = x_shape.dim(3).unwrap_static() as u32;
4017 match &node.op {
4018 Op::GroupNormBackwardInput { .. } => Thunk::GroupNormBackwardInput {
4019 x: node_offset(arena, node.inputs[0]),
4020 gamma: node_offset(arena, node.inputs[1]),
4021 beta: node_offset(arena, node.inputs[2]),
4022 dy: node_offset(arena, node.inputs[3]),
4023 dx: node_offset(arena, node.id),
4024 n,
4025 c,
4026 h,
4027 w,
4028 num_groups: *num_groups as u32,
4029 eps: *eps,
4030 },
4031 Op::GroupNormBackwardGamma { .. } => Thunk::GroupNormBackwardGamma {
4032 x: node_offset(arena, node.inputs[0]),
4033 dy: node_offset(arena, node.inputs[1]),
4034 dgamma: node_offset(arena, node.id),
4035 n,
4036 c,
4037 h,
4038 w,
4039 num_groups: *num_groups as u32,
4040 eps: *eps,
4041 },
4042 Op::GroupNormBackwardBeta { .. } => Thunk::GroupNormBackwardBeta {
4043 dy: node_offset(arena, node.inputs[1]),
4044 dbeta: node_offset(arena, node.id),
4045 n,
4046 c,
4047 h,
4048 w,
4049 },
4050 _ => unreachable!(),
4051 }
4052 }
4053
4054 Op::MaxPool2dBackward {
4055 kernel_size,
4056 stride,
4057 padding,
4058 } => {
4059 let x_shape = &graph.node(node.inputs[0]).shape;
4060 let dy_shape = &graph.node(node.inputs[1]).shape;
4061 if kernel_size.len() == 2 && x_shape.rank() == 4 && dy_shape.rank() == 4 {
4062 Thunk::MaxPool2dBackward {
4063 x: node_offset(arena, node.inputs[0]),
4064 dy: node_offset(arena, node.inputs[1]),
4065 dx: node_offset(arena, node.id),
4066 n: x_shape.dim(0).unwrap_static() as u32,
4067 c: x_shape.dim(1).unwrap_static() as u32,
4068 h: x_shape.dim(2).unwrap_static() as u32,
4069 w: x_shape.dim(3).unwrap_static() as u32,
4070 h_out: dy_shape.dim(2).unwrap_static() as u32,
4071 w_out: dy_shape.dim(3).unwrap_static() as u32,
4072 kh: kernel_size[0] as u32,
4073 kw: kernel_size[1] as u32,
4074 sh: stride.first().copied().unwrap_or(1) as u32,
4075 sw: stride.get(1).copied().unwrap_or(1) as u32,
4076 ph: padding.first().copied().unwrap_or(0) as u32,
4077 pw: padding.get(1).copied().unwrap_or(0) as u32,
4078 }
4079 } else {
4080 Thunk::Nop
4081 }
4082 }
4083
4084 Op::Conv2dBackwardInput {
4085 kernel_size,
4086 stride,
4087 padding,
4088 dilation,
4089 groups,
4090 } => {
4091 let dy_shape = &graph.node(node.inputs[0]).shape;
4092 let w_shape = &graph.node(node.inputs[1]).shape;
4093 let out_shape = &node.shape;
4094 if kernel_size.len() == 2
4095 && dy_shape.rank() == 4
4096 && w_shape.rank() == 4
4097 && out_shape.rank() == 4
4098 {
4099 Thunk::Conv2dBackwardInput {
4100 dy: node_offset(arena, node.inputs[0]),
4101 w: node_offset(arena, node.inputs[1]),
4102 dx: node_offset(arena, node.id),
4103 n: out_shape.dim(0).unwrap_static() as u32,
4104 c_in: out_shape.dim(1).unwrap_static() as u32,
4105 h: out_shape.dim(2).unwrap_static() as u32,
4106 w_in: out_shape.dim(3).unwrap_static() as u32,
4107 c_out: dy_shape.dim(1).unwrap_static() as u32,
4108 h_out: dy_shape.dim(2).unwrap_static() as u32,
4109 w_out: dy_shape.dim(3).unwrap_static() as u32,
4110 kh: kernel_size[0] as u32,
4111 kw: kernel_size[1] as u32,
4112 sh: stride.first().copied().unwrap_or(1) as u32,
4113 sw: stride.get(1).copied().unwrap_or(1) as u32,
4114 ph: padding.first().copied().unwrap_or(0) as u32,
4115 pw: padding.get(1).copied().unwrap_or(0) as u32,
4116 dh: dilation.first().copied().unwrap_or(1) as u32,
4117 dw: dilation.get(1).copied().unwrap_or(1) as u32,
4118 groups: *groups as u32,
4119 }
4120 } else {
4121 Thunk::Nop
4122 }
4123 }
4124
4125 Op::Conv2dBackwardWeight {
4126 kernel_size,
4127 stride,
4128 padding,
4129 dilation,
4130 groups,
4131 } => {
4132 let x_shape = &graph.node(node.inputs[0]).shape;
4133 let dy_shape = &graph.node(node.inputs[1]).shape;
4134 let dw_shape = &node.shape;
4135 if kernel_size.len() == 2
4136 && x_shape.rank() == 4
4137 && dy_shape.rank() == 4
4138 && dw_shape.rank() == 4
4139 {
4140 Thunk::Conv2dBackwardWeight {
4141 x: node_offset(arena, node.inputs[0]),
4142 dy: node_offset(arena, node.inputs[1]),
4143 dw: node_offset(arena, node.id),
4144 n: x_shape.dim(0).unwrap_static() as u32,
4145 c_in: x_shape.dim(1).unwrap_static() as u32,
4146 h: x_shape.dim(2).unwrap_static() as u32,
4147 w: x_shape.dim(3).unwrap_static() as u32,
4148 c_out: dy_shape.dim(1).unwrap_static() as u32,
4149 h_out: dy_shape.dim(2).unwrap_static() as u32,
4150 w_out: dy_shape.dim(3).unwrap_static() as u32,
4151 kh: kernel_size[0] as u32,
4152 kw: kernel_size[1] as u32,
4153 sh: stride.first().copied().unwrap_or(1) as u32,
4154 sw: stride.get(1).copied().unwrap_or(1) as u32,
4155 ph: padding.first().copied().unwrap_or(0) as u32,
4156 pw: padding.get(1).copied().unwrap_or(0) as u32,
4157 dh: dilation.first().copied().unwrap_or(1) as u32,
4158 dw_dil: dilation.get(1).copied().unwrap_or(1) as u32,
4159 groups: *groups as u32,
4160 }
4161 } else {
4162 Thunk::Nop
4163 }
4164 }
4165
4166 Op::SoftmaxCrossEntropyWithLogits => {
4167 let logits_shape = &graph.node(node.inputs[0]).shape;
4168 if logits_shape.rank() == 2 {
4169 Thunk::SoftmaxCrossEntropy {
4170 logits: node_offset(arena, node.inputs[0]),
4171 labels: node_offset(arena, node.inputs[1]),
4172 dst: node_offset(arena, node.id),
4173 n: logits_shape.dim(0).unwrap_static() as u32,
4174 c: logits_shape.dim(1).unwrap_static() as u32,
4175 }
4176 } else {
4177 Thunk::Nop
4178 }
4179 }
4180
4181 Op::SoftmaxCrossEntropyBackward => {
4182 let logits_shape = &graph.node(node.inputs[0]).shape;
4183 if logits_shape.rank() == 2 {
4184 Thunk::SoftmaxCrossEntropyBackward {
4185 logits: node_offset(arena, node.inputs[0]),
4186 labels: node_offset(arena, node.inputs[1]),
4187 d_loss: node_offset(arena, node.inputs[2]),
4188 dlogits: node_offset(arena, node.id),
4189 n: logits_shape.dim(0).unwrap_static() as u32,
4190 c: logits_shape.dim(1).unwrap_static() as u32,
4191 }
4192 } else {
4193 Thunk::Nop
4194 }
4195 }
4196
4197 Op::DenseSolve => {
4198 let a_shape = &graph.node(node.inputs[0]).shape;
4200 let n = a_shape.dim(0).unwrap_static();
4201 debug_assert_eq!(
4202 n,
4203 a_shape.dim(1).unwrap_static(),
4204 "DenseSolve: A must be square"
4205 );
4206 let b_elems = node.shape.num_elements().unwrap();
4207 let nrhs = b_elems / n;
4208 match node.shape.dtype() {
4209 rlx_ir::DType::F64 => Thunk::DenseSolveF64 {
4210 a: node_offset(arena, node.inputs[0]),
4211 b: node_offset(arena, node.inputs[1]),
4212 x: node_offset(arena, node.id),
4213 n: n as u32,
4214 nrhs: nrhs as u32,
4215 },
4216 rlx_ir::DType::F32 => Thunk::DenseSolveF32 {
4217 a: node_offset(arena, node.inputs[0]),
4218 b: node_offset(arena, node.inputs[1]),
4219 x: node_offset(arena, node.id),
4220 n: n as u32,
4221 nrhs: nrhs as u32,
4222 },
4223 other => panic!(
4224 "DenseSolve: F32 + F64 lowered; got {other:?}. \
4225 Add another variant when needed."
4226 ),
4227 }
4228 }
4229
4230 Op::BatchedDenseSolve => {
4231 let a_shape = &graph.node(node.inputs[0]).shape;
4233 assert_eq!(a_shape.rank(), 3, "BatchedDenseSolve: A rank must be 3");
4234 let batch = a_shape.dim(0).unwrap_static();
4235 let n = a_shape.dim(1).unwrap_static();
4236 debug_assert_eq!(
4237 n,
4238 a_shape.dim(2).unwrap_static(),
4239 "BatchedDenseSolve: A's last two dims must match"
4240 );
4241 let total = node.shape.num_elements().unwrap();
4242 let nrhs = total / (batch * n);
4243 match node.shape.dtype() {
4244 rlx_ir::DType::F32 => Thunk::BatchedDenseSolveF32 {
4245 a: node_offset(arena, node.inputs[0]),
4246 b: node_offset(arena, node.inputs[1]),
4247 x: node_offset(arena, node.id),
4248 batch: batch as u32,
4249 n: n as u32,
4250 nrhs: nrhs as u32,
4251 },
4252 rlx_ir::DType::F64 => Thunk::BatchedDenseSolveF64 {
4253 a: node_offset(arena, node.inputs[0]),
4254 b: node_offset(arena, node.inputs[1]),
4255 x: node_offset(arena, node.id),
4256 batch: batch as u32,
4257 n: n as u32,
4258 nrhs: nrhs as u32,
4259 },
4260 other => panic!("BatchedDenseSolve: F32 + F64 only, got {other:?}"),
4261 }
4262 }
4263
4264 Op::Scan {
4265 body,
4266 length,
4267 save_trajectory,
4268 num_bcast,
4269 num_xs,
4270 num_checkpoints,
4271 } => {
4272 assert!(
4273 *num_checkpoints == 0 || *num_checkpoints <= *length,
4274 "Op::Scan: num_checkpoints={} must be 0 or ≤ length={}",
4275 *num_checkpoints,
4276 *length
4277 );
4278 if *num_checkpoints != 0 && *num_checkpoints != *length {
4279 assert!(
4280 *save_trajectory,
4281 "Op::Scan: num_checkpoints<length only meaningful when save_trajectory=true"
4282 );
4283 }
4284 let body_plan = rlx_opt::memory::plan_memory(body);
4295 let _body_arena_size = body_plan.arena_size;
4296 let body_offsets: HashMap<NodeId, usize> = body_plan
4299 .assignments
4300 .iter()
4301 .map(|(id, slot)| (*id, slot.offset))
4302 .collect();
4303
4304 let mut body_inputs: Vec<NodeId> = body
4307 .nodes()
4308 .iter()
4309 .filter(|n| matches!(n.op, Op::Input { .. }))
4310 .map(|n| n.id)
4311 .collect();
4312 body_inputs.sort();
4313 let n_body_inputs = body_inputs.len();
4314 let expected = 1 + *num_bcast as usize + *num_xs as usize;
4315 if n_body_inputs != expected {
4316 let names: Vec<String> = body
4317 .nodes()
4318 .iter()
4319 .filter_map(|n| match &n.op {
4320 Op::Input { name } => Some(format!("{}={}", n.id, name)),
4321 _ => None,
4322 })
4323 .collect();
4324 panic!(
4325 "Op::Scan body has {} Op::Input nodes; expected {} \
4326 (1 carry + {} bcast + {} xs). Inputs by NodeId: [{}]",
4327 n_body_inputs,
4328 expected,
4329 *num_bcast,
4330 *num_xs,
4331 names.join(", ")
4332 );
4333 }
4334
4335 let body_input_id = body_inputs[0];
4336 let body_input_off = body_offsets[&body_input_id];
4337 let body_output_id = body
4338 .outputs
4339 .first()
4340 .copied()
4341 .expect("Op::Scan body must declare one output");
4342 let body_output_off = body_offsets[&body_output_id];
4343
4344 let mut body_arena = crate::arena::Arena::from_plan(body_plan);
4345 for n in body.nodes() {
4348 if let Op::Constant { data } = &n.op
4349 && body_arena.has_buffer(n.id)
4350 && !data.is_empty()
4351 {
4352 match n.shape.dtype() {
4353 rlx_ir::DType::F64 => {
4354 let off = body_arena.byte_offset(n.id);
4355 let buf = body_arena.raw_buf_mut();
4356 let nbytes = (buf.len() - off).min(data.len());
4357 buf[off..off + nbytes].copy_from_slice(&data[..nbytes]);
4358 }
4359 _ => {
4360 let buf = body_arena.slice_mut(n.id);
4361 let n_floats = data.len() / 4;
4362 let n_lim = buf.len().min(n_floats);
4363 for i in 0..n_lim {
4364 let bytes = [
4365 data[i * 4],
4366 data[i * 4 + 1],
4367 data[i * 4 + 2],
4368 data[i * 4 + 3],
4369 ];
4370 buf[i] = f32::from_le_bytes(bytes);
4371 }
4372 }
4373 }
4374 }
4375 }
4376 let body_init = body_arena.raw_buf().to_vec();
4377 let body_schedule = compile_thunks(body, &body_arena);
4378
4379 let carry_bytes = if *save_trajectory {
4384 let total = node
4385 .shape
4386 .size_bytes()
4387 .expect("Op::Scan trajectory output must have static shape");
4388 total / *length as usize
4389 } else {
4390 node.shape
4391 .size_bytes()
4392 .expect("Op::Scan carry must have static shape")
4393 };
4394
4395 let mut bcast_inputs: Vec<(usize, usize, u32)> =
4400 Vec::with_capacity(*num_bcast as usize);
4401 for i in 0..*num_bcast as usize {
4402 let body_b_id = body_inputs[1 + i];
4403 let body_b_off = body_offsets[&body_b_id];
4404 let outer_b_id = node.inputs[1 + i];
4405 let outer_b_off = node_offset(arena, outer_b_id);
4406 let outer_b_shape = &graph.node(outer_b_id).shape;
4407 let total = outer_b_shape
4408 .size_bytes()
4409 .expect("Op::Scan bcast must have static shape");
4410 bcast_inputs.push((body_b_off, outer_b_off, total as u32));
4411 }
4412
4413 let mut xs_inputs: Vec<(usize, usize, u32)> = Vec::with_capacity(*num_xs as usize);
4417 let xs_base = 1 + *num_bcast as usize;
4418 for i in 0..*num_xs as usize {
4419 let body_x_id = body_inputs[xs_base + i];
4420 let body_x_off = body_offsets[&body_x_id];
4421 let outer_xs_id = node.inputs[xs_base + i];
4422 let outer_xs_off = node_offset(arena, outer_xs_id);
4423 let outer_xs_shape = &graph.node(outer_xs_id).shape;
4424 let total = outer_xs_shape
4425 .size_bytes()
4426 .expect("Op::Scan xs must have static shape");
4427 let per_step = total / *length as usize;
4428 xs_inputs.push((body_x_off, outer_xs_off, per_step as u32));
4429 }
4430
4431 Thunk::Scan {
4432 body: Arc::new(body_schedule),
4433 body_init: Arc::new(body_init),
4434 body_input_off,
4435 body_output_off,
4436 outer_init_off: node_offset(arena, node.inputs[0]),
4437 outer_final_off: node_offset(arena, node.id),
4438 length: *length,
4439 carry_bytes: carry_bytes as u32,
4440 save_trajectory: *save_trajectory,
4441 xs_inputs: Arc::new(xs_inputs),
4442 bcast_inputs: Arc::new(bcast_inputs),
4443 num_checkpoints: *num_checkpoints,
4444 }
4445 }
4446
4447 Op::ScanBackward {
4448 body_vjp,
4449 length,
4450 save_trajectory,
4451 num_xs,
4452 num_checkpoints,
4453 forward_body,
4454 } => {
4455 let is_recursive = *num_checkpoints != 0 && *num_checkpoints != *length;
4456 if is_recursive {
4457 assert!(
4458 forward_body.is_some(),
4459 "Op::ScanBackward with num_checkpoints<length requires forward_body"
4460 );
4461 }
4462 let body_plan = rlx_opt::memory::plan_memory(body_vjp);
4470 let body_offsets: HashMap<NodeId, usize> = body_plan
4471 .assignments
4472 .iter()
4473 .map(|(id, slot)| (*id, slot.offset))
4474 .collect();
4475 let mut body_d_output_off: Option<usize> = None;
4476 let mut body_other_inputs: Vec<(NodeId, usize)> = Vec::new();
4477 for n in body_vjp.nodes() {
4478 if let Op::Input { name } = &n.op {
4479 let off = body_offsets[&n.id];
4480 if name == "d_output" {
4481 body_d_output_off = Some(off);
4482 } else {
4483 body_other_inputs.push((n.id, off));
4484 }
4485 }
4486 }
4487 body_other_inputs.sort_by_key(|(id, _)| *id);
4488 let body_d_output_off =
4489 body_d_output_off.expect("ScanBackward body_vjp missing 'd_output' Input");
4490 let expected_others = 1 + *num_xs as usize;
4491 assert_eq!(
4492 body_other_inputs.len(),
4493 expected_others,
4494 "ScanBackward body_vjp has {} non-d_output Inputs; \
4495 expected {} (1 carry + {} xs)",
4496 body_other_inputs.len(),
4497 expected_others,
4498 num_xs
4499 );
4500 let body_carry_in_off = body_other_inputs[0].1;
4501 let body_x_offs: Vec<usize> = body_other_inputs
4502 .iter()
4503 .skip(1)
4504 .map(|(_, off)| *off)
4505 .collect();
4506 let body_dcarry_out_off = body_offsets[&body_vjp.outputs[0]];
4507
4508 let mut body_arena = crate::arena::Arena::from_plan(body_plan);
4509 for n in body_vjp.nodes() {
4511 if let Op::Constant { data } = &n.op
4512 && body_arena.has_buffer(n.id)
4513 && !data.is_empty()
4514 {
4515 match n.shape.dtype() {
4516 rlx_ir::DType::F64 => {
4517 let off = body_arena.byte_offset(n.id);
4518 let buf = body_arena.raw_buf_mut();
4519 let nb = (buf.len() - off).min(data.len());
4520 buf[off..off + nb].copy_from_slice(&data[..nb]);
4521 }
4522 _ => {
4523 let buf = body_arena.slice_mut(n.id);
4524 let nf = data.len() / 4;
4525 let nl = buf.len().min(nf);
4526 for i in 0..nl {
4527 let bytes = [
4528 data[i * 4],
4529 data[i * 4 + 1],
4530 data[i * 4 + 2],
4531 data[i * 4 + 3],
4532 ];
4533 buf[i] = f32::from_le_bytes(bytes);
4534 }
4535 }
4536 }
4537 }
4538 }
4539 let body_init = body_arena.raw_buf().to_vec();
4540 let body_schedule = compile_thunks(body_vjp, &body_arena);
4541
4542 let carry_bytes = body_vjp
4544 .node(body_vjp.outputs[0])
4545 .shape
4546 .size_bytes()
4547 .expect("ScanBackward dcarry must be statically shaped");
4548 let carry_elem_size = body_vjp
4549 .node(body_vjp.outputs[0])
4550 .shape
4551 .dtype()
4552 .size_bytes() as u32;
4553
4554 let mut outer_xs_offs: Vec<(usize, u32)> = Vec::with_capacity(*num_xs as usize);
4557 for i in 0..*num_xs as usize {
4558 let outer_xs_id = node.inputs[3 + i];
4559 let outer_xs_off = node_offset(arena, outer_xs_id);
4560 let outer_xs_shape = &graph.node(outer_xs_id).shape;
4561 let total = outer_xs_shape
4562 .size_bytes()
4563 .expect("ScanBackward xs must have static shape");
4564 let per_step = total / *length as usize;
4565 outer_xs_offs.push((outer_xs_off, per_step as u32));
4566 }
4567
4568 let (fb_schedule, fb_init, fb_carry_in_off, fb_output_off, fb_x_offs) =
4573 if is_recursive {
4574 let fb = forward_body.as_ref().unwrap();
4575 let fb_plan = rlx_opt::memory::plan_memory(fb);
4576 let fb_offsets: HashMap<NodeId, usize> = fb_plan
4577 .assignments
4578 .iter()
4579 .map(|(id, slot)| (*id, slot.offset))
4580 .collect();
4581 let mut fb_inputs: Vec<NodeId> = fb
4582 .nodes()
4583 .iter()
4584 .filter(|n| matches!(n.op, Op::Input { .. }))
4585 .map(|n| n.id)
4586 .collect();
4587 fb_inputs.sort();
4588 let fb_carry = fb_offsets[&fb_inputs[0]];
4589 let fb_xs: Vec<usize> = (1..fb_inputs.len())
4590 .map(|i| fb_offsets[&fb_inputs[i]])
4591 .collect();
4592 let fb_out = fb_offsets[&fb.outputs[0]];
4593 let mut fb_arena = crate::arena::Arena::from_plan(fb_plan);
4594 for n in fb.nodes() {
4595 if let Op::Constant { data } = &n.op
4596 && fb_arena.has_buffer(n.id)
4597 && !data.is_empty()
4598 {
4599 let off = fb_arena.byte_offset(n.id);
4606 let buf = fb_arena.raw_buf_mut();
4607 let nb = (buf.len() - off).min(data.len());
4608 buf[off..off + nb].copy_from_slice(&data[..nb]);
4609 }
4610 }
4611 let fb_init_bytes = fb_arena.raw_buf().to_vec();
4612 let fb_sched = compile_thunks(fb, &fb_arena);
4613 (
4614 Some(Arc::new(fb_sched)),
4615 Some(Arc::new(fb_init_bytes)),
4616 fb_carry,
4617 fb_out,
4618 fb_xs,
4619 )
4620 } else {
4621 (None, None, 0, 0, Vec::new())
4622 };
4623
4624 Thunk::ScanBackward {
4625 body_vjp: Arc::new(body_schedule),
4626 body_init: Arc::new(body_init),
4627 body_carry_in_off,
4628 body_x_offs: Arc::new(body_x_offs),
4629 body_d_output_off,
4630 body_dcarry_out_off,
4631 outer_init_off: node_offset(arena, node.inputs[0]),
4632 outer_traj_off: node_offset(arena, node.inputs[1]),
4633 outer_upstream_off: node_offset(arena, node.inputs[2]),
4634 outer_xs_offs: Arc::new(outer_xs_offs),
4635 outer_dinit_off: node_offset(arena, node.id),
4636 length: *length,
4637 carry_bytes: carry_bytes as u32,
4638 carry_elem_size,
4639 save_trajectory: *save_trajectory,
4640 num_checkpoints: *num_checkpoints,
4641 forward_body: fb_schedule,
4642 forward_body_init: fb_init,
4643 forward_body_carry_in_off: fb_carry_in_off,
4644 forward_body_output_off: fb_output_off,
4645 forward_body_x_offs: Arc::new(fb_x_offs),
4646 }
4647 }
4648
4649 Op::ScanBackwardXs {
4650 body_vjp,
4651 length,
4652 save_trajectory,
4653 num_xs,
4654 xs_idx,
4655 num_checkpoints,
4656 forward_body,
4657 } => {
4658 assert!(
4659 *num_checkpoints == 0 || *num_checkpoints <= *length,
4660 "Op::ScanBackwardXs: num_checkpoints={} must be 0 or ≤ length={}",
4661 *num_checkpoints,
4662 *length
4663 );
4664 let is_recursive = *num_checkpoints != 0 && *num_checkpoints != *length;
4665 if is_recursive {
4666 assert!(
4667 forward_body.is_some(),
4668 "Op::ScanBackwardXs with num_checkpoints<length \
4669 requires forward_body"
4670 );
4671 }
4672 let body_plan = rlx_opt::memory::plan_memory(body_vjp);
4680 let body_offsets: HashMap<NodeId, usize> = body_plan
4681 .assignments
4682 .iter()
4683 .map(|(id, slot)| (*id, slot.offset))
4684 .collect();
4685 let mut body_d_output_off: Option<usize> = None;
4686 let mut body_other_inputs: Vec<(NodeId, usize)> = Vec::new();
4687 for n in body_vjp.nodes() {
4688 if let Op::Input { name } = &n.op {
4689 let off = body_offsets[&n.id];
4690 if name == "d_output" {
4691 body_d_output_off = Some(off);
4692 } else {
4693 body_other_inputs.push((n.id, off));
4694 }
4695 }
4696 }
4697 body_other_inputs.sort_by_key(|(id, _)| *id);
4698 let body_d_output_off =
4699 body_d_output_off.expect("ScanBackwardXs body_vjp missing 'd_output' Input");
4700 let expected_others = 1 + *num_xs as usize;
4701 assert_eq!(
4702 body_other_inputs.len(),
4703 expected_others,
4704 "ScanBackwardXs body_vjp has {} non-d_output Inputs; expected {}",
4705 body_other_inputs.len(),
4706 expected_others
4707 );
4708 let body_carry_in_off = body_other_inputs[0].1;
4709 let body_x_offs: Vec<usize> = body_other_inputs
4710 .iter()
4711 .skip(1)
4712 .map(|(_, off)| *off)
4713 .collect();
4714 let body_dcarry_out_off = body_offsets[&body_vjp.outputs[0]];
4715 let dxs_out_node = body_vjp.outputs[1 + *xs_idx as usize];
4716 let body_dxs_out_off = body_offsets[&dxs_out_node];
4717
4718 let mut body_arena = crate::arena::Arena::from_plan(body_plan);
4719 for n in body_vjp.nodes() {
4720 if let Op::Constant { data } = &n.op
4721 && body_arena.has_buffer(n.id)
4722 && !data.is_empty()
4723 {
4724 match n.shape.dtype() {
4725 rlx_ir::DType::F64 => {
4726 let off = body_arena.byte_offset(n.id);
4727 let buf = body_arena.raw_buf_mut();
4728 let nb = (buf.len() - off).min(data.len());
4729 buf[off..off + nb].copy_from_slice(&data[..nb]);
4730 }
4731 _ => {
4732 let buf = body_arena.slice_mut(n.id);
4733 let nf = data.len() / 4;
4734 let nl = buf.len().min(nf);
4735 for i in 0..nl {
4736 let bytes = [
4737 data[i * 4],
4738 data[i * 4 + 1],
4739 data[i * 4 + 2],
4740 data[i * 4 + 3],
4741 ];
4742 buf[i] = f32::from_le_bytes(bytes);
4743 }
4744 }
4745 }
4746 }
4747 }
4748 let body_init = body_arena.raw_buf().to_vec();
4749 let body_schedule = compile_thunks(body_vjp, &body_arena);
4750
4751 let carry_bytes = body_vjp
4752 .node(body_vjp.outputs[0])
4753 .shape
4754 .size_bytes()
4755 .expect("ScanBackwardXs dcarry must be statically shaped");
4756 let carry_elem_size = body_vjp
4757 .node(body_vjp.outputs[0])
4758 .shape
4759 .dtype()
4760 .size_bytes() as u32;
4761 let per_step_bytes = body_vjp
4762 .node(dxs_out_node)
4763 .shape
4764 .size_bytes()
4765 .expect("ScanBackwardXs dxs body output must be statically shaped");
4766
4767 let mut outer_xs_offs: Vec<(usize, u32)> = Vec::with_capacity(*num_xs as usize);
4768 for i in 0..*num_xs as usize {
4769 let outer_xs_id = node.inputs[3 + i];
4770 let outer_xs_off = node_offset(arena, outer_xs_id);
4771 let outer_xs_shape = &graph.node(outer_xs_id).shape;
4772 let total = outer_xs_shape
4773 .size_bytes()
4774 .expect("ScanBackwardXs xs must have static shape");
4775 let per_step = total / *length as usize;
4776 outer_xs_offs.push((outer_xs_off, per_step as u32));
4777 }
4778
4779 let (fb_schedule, fb_init, fb_carry_in_off, fb_output_off, fb_x_offs) =
4782 if is_recursive {
4783 let fb = forward_body.as_ref().unwrap();
4784 let fb_plan = rlx_opt::memory::plan_memory(fb);
4785 let fb_offsets: HashMap<NodeId, usize> = fb_plan
4786 .assignments
4787 .iter()
4788 .map(|(id, slot)| (*id, slot.offset))
4789 .collect();
4790 let mut fb_inputs: Vec<NodeId> = fb
4791 .nodes()
4792 .iter()
4793 .filter(|n| matches!(n.op, Op::Input { .. }))
4794 .map(|n| n.id)
4795 .collect();
4796 fb_inputs.sort();
4797 let fb_carry = fb_offsets[&fb_inputs[0]];
4798 let fb_xs: Vec<usize> = (1..fb_inputs.len())
4799 .map(|i| fb_offsets[&fb_inputs[i]])
4800 .collect();
4801 let fb_out = fb_offsets[&fb.outputs[0]];
4802 let mut fb_arena = crate::arena::Arena::from_plan(fb_plan);
4803 for n in fb.nodes() {
4804 if let Op::Constant { data } = &n.op
4805 && fb_arena.has_buffer(n.id)
4806 && !data.is_empty()
4807 {
4808 let off = fb_arena.byte_offset(n.id);
4815 let buf = fb_arena.raw_buf_mut();
4816 let nb = (buf.len() - off).min(data.len());
4817 buf[off..off + nb].copy_from_slice(&data[..nb]);
4818 }
4819 }
4820 let fb_init_bytes = fb_arena.raw_buf().to_vec();
4821 let fb_sched = compile_thunks(fb, &fb_arena);
4822 (
4823 Some(Arc::new(fb_sched)),
4824 Some(Arc::new(fb_init_bytes)),
4825 fb_carry,
4826 fb_out,
4827 fb_xs,
4828 )
4829 } else {
4830 (None, None, 0, 0, Vec::new())
4831 };
4832
4833 Thunk::ScanBackwardXs {
4834 body_vjp: Arc::new(body_schedule),
4835 body_init: Arc::new(body_init),
4836 body_carry_in_off,
4837 body_x_offs: Arc::new(body_x_offs),
4838 body_d_output_off,
4839 body_dcarry_out_off,
4840 body_dxs_out_off,
4841 outer_init_off: node_offset(arena, node.inputs[0]),
4842 outer_traj_off: node_offset(arena, node.inputs[1]),
4843 outer_upstream_off: node_offset(arena, node.inputs[2]),
4844 outer_xs_offs: Arc::new(outer_xs_offs),
4845 outer_dxs_off: node_offset(arena, node.id),
4846 length: *length,
4847 carry_bytes: carry_bytes as u32,
4848 carry_elem_size,
4849 per_step_bytes: per_step_bytes as u32,
4850 save_trajectory: *save_trajectory,
4851 num_checkpoints: *num_checkpoints,
4852 forward_body: fb_schedule,
4853 forward_body_init: fb_init,
4854 forward_body_carry_in_off: fb_carry_in_off,
4855 forward_body_output_off: fb_output_off,
4856 forward_body_x_offs: Arc::new(fb_x_offs),
4857 }
4858 }
4859
4860 Op::Concat { axis } => {
4861 let out_shape = &node.shape;
4865 let rank = out_shape.rank();
4866 let outer: usize = (0..*axis)
4867 .map(|i| out_shape.dim(i).unwrap_static())
4868 .product::<usize>()
4869 .max(1);
4870 let inner: usize = (*axis + 1..rank)
4871 .map(|i| out_shape.dim(i).unwrap_static())
4872 .product::<usize>()
4873 .max(1);
4874 let total_axis = out_shape.dim(*axis).unwrap_static();
4875 let inputs: Vec<(usize, u32)> = node
4876 .inputs
4877 .iter()
4878 .map(|&in_id| {
4879 let in_shape = &graph.node(in_id).shape;
4880 let in_axis = in_shape.dim(*axis).unwrap_static();
4881 (node_offset(arena, in_id), in_axis as u32)
4882 })
4883 .collect();
4884 let dst = node_offset(arena, node.id);
4885 match out_shape.dtype() {
4886 rlx_ir::DType::F64 => Thunk::ConcatF64 {
4887 dst,
4888 outer: outer as u32,
4889 inner: inner as u32,
4890 total_axis: total_axis as u32,
4891 inputs,
4892 },
4893 _ => Thunk::Concat {
4894 dst,
4895 outer: outer as u32,
4896 inner: inner as u32,
4897 total_axis: total_axis as u32,
4898 inputs,
4899 },
4900 }
4901 }
4902
4903 Op::GaussianSplatRender {
4904 width,
4905 height,
4906 tile_size,
4907 radius_scale,
4908 alpha_cutoff,
4909 max_splat_steps,
4910 transmittance_threshold,
4911 max_list_entries,
4912 } => {
4913 let elem_len =
4914 |id: NodeId| -> usize { graph.node(id).shape.num_elements().unwrap_or(0) };
4915 Thunk::GaussianSplatRender {
4916 positions_off: node_offset(arena, node.inputs[0]),
4917 positions_len: elem_len(node.inputs[0]),
4918 scales_off: node_offset(arena, node.inputs[1]),
4919 scales_len: elem_len(node.inputs[1]),
4920 rotations_off: node_offset(arena, node.inputs[2]),
4921 rotations_len: elem_len(node.inputs[2]),
4922 opacities_off: node_offset(arena, node.inputs[3]),
4923 opacities_len: elem_len(node.inputs[3]),
4924 colors_off: node_offset(arena, node.inputs[4]),
4925 colors_len: elem_len(node.inputs[4]),
4926 sh_coeffs_off: node_offset(arena, node.inputs[5]),
4927 sh_coeffs_len: elem_len(node.inputs[5]),
4928 meta_off: node_offset(arena, node.inputs[6]),
4929 dst_off: node_offset(arena, node.id),
4930 dst_len: node.shape.num_elements().unwrap_or(0),
4931 width: *width,
4932 height: *height,
4933 tile_size: *tile_size,
4934 radius_scale: *radius_scale,
4935 alpha_cutoff: *alpha_cutoff,
4936 max_splat_steps: *max_splat_steps,
4937 transmittance_threshold: *transmittance_threshold,
4938 max_list_entries: *max_list_entries,
4939 }
4940 }
4941
4942 Op::GaussianSplatRenderBackward {
4943 width,
4944 height,
4945 tile_size,
4946 radius_scale,
4947 alpha_cutoff,
4948 max_splat_steps,
4949 transmittance_threshold,
4950 max_list_entries,
4951 loss_grad_clip,
4952 sh_band,
4953 max_anisotropy,
4954 } => {
4955 let elem_len =
4956 |id: NodeId| -> usize { graph.node(id).shape.num_elements().unwrap_or(0) };
4957 Thunk::GaussianSplatRenderBackward {
4958 positions_off: node_offset(arena, node.inputs[0]),
4959 positions_len: elem_len(node.inputs[0]),
4960 scales_off: node_offset(arena, node.inputs[1]),
4961 scales_len: elem_len(node.inputs[1]),
4962 rotations_off: node_offset(arena, node.inputs[2]),
4963 rotations_len: elem_len(node.inputs[2]),
4964 opacities_off: node_offset(arena, node.inputs[3]),
4965 opacities_len: elem_len(node.inputs[3]),
4966 colors_off: node_offset(arena, node.inputs[4]),
4967 colors_len: elem_len(node.inputs[4]),
4968 sh_coeffs_off: node_offset(arena, node.inputs[5]),
4969 sh_coeffs_len: elem_len(node.inputs[5]),
4970 meta_off: node_offset(arena, node.inputs[6]),
4971 d_loss_off: node_offset(arena, node.inputs[7]),
4972 d_loss_len: elem_len(node.inputs[7]),
4973 packed_off: node_offset(arena, node.id),
4974 packed_len: node.shape.num_elements().unwrap_or(0),
4975 width: *width,
4976 height: *height,
4977 tile_size: *tile_size,
4978 radius_scale: *radius_scale,
4979 alpha_cutoff: *alpha_cutoff,
4980 max_splat_steps: *max_splat_steps,
4981 transmittance_threshold: *transmittance_threshold,
4982 max_list_entries: *max_list_entries,
4983 loss_grad_clip: *loss_grad_clip,
4984 sh_band: *sh_band,
4985 max_anisotropy: *max_anisotropy,
4986 }
4987 }
4988
4989 Op::GaussianSplatPrepare {
4990 width,
4991 height,
4992 tile_size,
4993 radius_scale,
4994 alpha_cutoff,
4995 max_splat_steps,
4996 transmittance_threshold,
4997 max_list_entries,
4998 } => {
4999 let elem_len =
5000 |id: NodeId| -> usize { graph.node(id).shape.num_elements().unwrap_or(0) };
5001 Thunk::GaussianSplatPrepare {
5002 positions_off: node_offset(arena, node.inputs[0]),
5003 positions_len: elem_len(node.inputs[0]),
5004 scales_off: node_offset(arena, node.inputs[1]),
5005 scales_len: elem_len(node.inputs[1]),
5006 rotations_off: node_offset(arena, node.inputs[2]),
5007 rotations_len: elem_len(node.inputs[2]),
5008 opacities_off: node_offset(arena, node.inputs[3]),
5009 opacities_len: elem_len(node.inputs[3]),
5010 colors_off: node_offset(arena, node.inputs[4]),
5011 colors_len: elem_len(node.inputs[4]),
5012 sh_coeffs_off: node_offset(arena, node.inputs[5]),
5013 sh_coeffs_len: elem_len(node.inputs[5]),
5014 meta_off: node_offset(arena, node.inputs[6]),
5015 meta_len: elem_len(node.inputs[6]),
5016 prep_off: node_offset(arena, node.id),
5017 prep_len: node.shape.num_elements().unwrap_or(0),
5018 width: *width,
5019 height: *height,
5020 tile_size: *tile_size,
5021 radius_scale: *radius_scale,
5022 alpha_cutoff: *alpha_cutoff,
5023 max_splat_steps: *max_splat_steps,
5024 transmittance_threshold: *transmittance_threshold,
5025 max_list_entries: *max_list_entries,
5026 }
5027 }
5028
5029 Op::GaussianSplatRasterize {
5030 width,
5031 height,
5032 tile_size,
5033 alpha_cutoff,
5034 max_splat_steps,
5035 transmittance_threshold,
5036 max_list_entries,
5037 } => {
5038 let elem_len =
5039 |id: NodeId| -> usize { graph.node(id).shape.num_elements().unwrap_or(0) };
5040 let prep_id = node.inputs[0];
5041 let count = match &graph.node(prep_id).op {
5042 rlx_ir::Op::GaussianSplatPrepare { .. } => {
5043 elem_len(graph.node(prep_id).inputs[0]) / 3
5044 }
5045 _ => 1,
5046 };
5047 Thunk::GaussianSplatRasterize {
5048 prep_off: node_offset(arena, prep_id),
5049 prep_len: elem_len(prep_id),
5050 meta_off: node_offset(arena, node.inputs[1]),
5051 meta_len: elem_len(node.inputs[1]),
5052 dst_off: node_offset(arena, node.id),
5053 dst_len: node.shape.num_elements().unwrap_or(0),
5054 count,
5055 width: *width,
5056 height: *height,
5057 tile_size: *tile_size,
5058 alpha_cutoff: *alpha_cutoff,
5059 max_splat_steps: *max_splat_steps,
5060 transmittance_threshold: *transmittance_threshold,
5061 max_list_entries: *max_list_entries,
5062 }
5063 }
5064
5065 Op::Custom { name, attrs, .. } => {
5066 let kernel = crate::op_registry::lookup_cpu_kernel(name).unwrap_or_else(|| {
5067 panic!(
5068 "compile_thunks: no CPU kernel registered for \
5069 Op::Custom('{name}'). Register one via \
5070 rlx_cpu::op_registry::register_cpu_kernel \
5071 before compiling on the CPU backend."
5072 )
5073 });
5074 let inputs_v: Vec<(usize, u32, Shape)> = node
5075 .inputs
5076 .iter()
5077 .map(|&in_id| {
5078 let s = graph.node(in_id).shape.clone();
5079 let len = s.num_elements().unwrap_or(0) as u32;
5080 (node_offset(arena, in_id), len, s)
5081 })
5082 .collect();
5083 let out_len = node.shape.num_elements().unwrap_or(0) as u32;
5084 Thunk::CustomOp {
5085 kernel,
5086 inputs: inputs_v,
5087 output: (node_offset(arena, node.id), out_len, node.shape.clone()),
5088 attrs: attrs.clone(),
5089 }
5090 }
5091
5092 Op::Fft { inverse } => {
5093 let shape = &node.shape;
5101 let last = shape.dim(shape.rank() - 1).unwrap_static();
5102 let n_complex = (last / 2) as u32;
5103 let total = shape.num_elements().unwrap_or(0);
5104 let outer = (total / last) as u32;
5105 let dtype = shape.dtype();
5106 assert!(
5107 matches!(dtype, rlx_ir::DType::F32 | rlx_ir::DType::F64),
5108 "Op::Fft on CPU requires F32 or F64, got {dtype:?}"
5109 );
5110 Thunk::Fft1d {
5111 src: node_offset(arena, node.inputs[0]),
5112 dst: node_offset(arena, node.id),
5113 outer,
5114 n_complex,
5115 inverse: *inverse,
5116 dtype,
5117 }
5118 }
5119
5120 Op::CustomFn {
5121 fwd_body,
5122 num_inputs,
5123 ..
5124 } => {
5125 let body_plan = rlx_opt::memory::plan_memory(fwd_body);
5131 let body_offsets: HashMap<NodeId, usize> = body_plan
5132 .assignments
5133 .iter()
5134 .map(|(id, slot)| (*id, slot.offset))
5135 .collect();
5136
5137 let mut body_input_ids: Vec<NodeId> = fwd_body
5138 .nodes()
5139 .iter()
5140 .filter(|n| matches!(n.op, Op::Input { .. }))
5141 .map(|n| n.id)
5142 .collect();
5143 body_input_ids.sort();
5144 assert_eq!(
5145 body_input_ids.len(),
5146 *num_inputs as usize,
5147 "Op::CustomFn fwd_body has {} Op::Input(s); declared num_inputs={}",
5148 body_input_ids.len(),
5149 *num_inputs,
5150 );
5151
5152 let mut body_arena = crate::arena::Arena::from_plan(body_plan);
5153 for n in fwd_body.nodes() {
5154 if let Op::Constant { data } = &n.op
5155 && body_arena.has_buffer(n.id)
5156 && !data.is_empty()
5157 {
5158 match n.shape.dtype() {
5159 rlx_ir::DType::F64 => {
5160 let off = body_arena.byte_offset(n.id);
5161 let buf = body_arena.raw_buf_mut();
5162 let nb = (buf.len() - off).min(data.len());
5163 buf[off..off + nb].copy_from_slice(&data[..nb]);
5164 }
5165 _ => {
5166 let buf = body_arena.slice_mut(n.id);
5167 let nf = data.len() / 4;
5168 let nl = buf.len().min(nf);
5169 for i in 0..nl {
5170 let bytes = [
5171 data[i * 4],
5172 data[i * 4 + 1],
5173 data[i * 4 + 2],
5174 data[i * 4 + 3],
5175 ];
5176 buf[i] = f32::from_le_bytes(bytes);
5177 }
5178 }
5179 }
5180 }
5181 }
5182 let body_init = body_arena.raw_buf().to_vec();
5183 let body_schedule = compile_thunks(fwd_body, &body_arena);
5184
5185 let inputs_v: Vec<(usize, usize, u32)> = (0..*num_inputs as usize)
5187 .map(|i| {
5188 let body_in = body_input_ids[i];
5189 let body_off = body_offsets[&body_in];
5190 let outer_in = node.inputs[i];
5191 let outer_off = node_offset(arena, outer_in);
5192 let bytes = graph
5193 .node(outer_in)
5194 .shape
5195 .size_bytes()
5196 .expect("Op::CustomFn primal input must have static shape");
5197 (body_off, outer_off, bytes as u32)
5198 })
5199 .collect();
5200
5201 let body_output_id = fwd_body
5202 .outputs
5203 .first()
5204 .copied()
5205 .expect("Op::CustomFn fwd_body must declare exactly one output");
5206 let body_output_off = body_offsets[&body_output_id];
5207 let out_bytes = node
5208 .shape
5209 .size_bytes()
5210 .expect("Op::CustomFn output must have static shape");
5211
5212 Thunk::CustomFn {
5213 body: Arc::new(body_schedule),
5214 body_init: Arc::new(body_init),
5215 inputs: Arc::new(inputs_v),
5216 body_output_off,
5217 outer_output_off: node_offset(arena, node.id),
5218 out_bytes: out_bytes as u32,
5219 }
5220 }
5221
5222 _ => Thunk::Nop,
5223 };
5224 thunks.push(t);
5225 }
5226
5227 let cfg = crate::config::RuntimeConfig::global();
5228 let mask_thr = cfg.mask_binary_threshold;
5229 let mask_neg = cfg.attn_mask_neg_inf;
5230 let score_skip = cfg.score_skip_threshold;
5231
5232 let compiled_fns: Vec<Arc<dyn Fn(*mut u8) + Send + Sync>> = thunks
5234 .iter()
5235 .filter(|t| !matches!(t, Thunk::Nop))
5236 .map(|thunk| {
5237 match thunk.clone() {
                    // `Nop` thunks are removed by the `.filter` above, so this arm is
                    // not expected to run; its cast also pins the closure trait-object
                    // type that the remaining arms coerce to.
                    Thunk::Nop => Arc::new(|_: *mut u8| {}) as Arc<dyn Fn(*mut u8) + Send + Sync>,
5239
5240 Thunk::Sgemm { a, b, c, m, k, n } => {
5241 let (m, k, n) = (m as usize, k as usize, n as usize);
5242 Arc::new(move |base: *mut u8| unsafe {
5243 crate::blas::sgemm(
5244 sl(a, base, m * k),
5245 sl(b, base, k * n),
5246 sl_mut(c, base, m * n),
5247 m,
5248 k,
5249 n,
5250 );
5251 })
5252 }
5253
5254 Thunk::DenseSolveF64 { a, b, x, n, nrhs } => {
5255 let (n_, nrhs_) = (n as usize, nrhs as usize);
5256 Arc::new(move |base: *mut u8| unsafe {
5257 let a_src = sl_f64(a, base, n_ * n_);
5258 let b_src = sl_f64(b, base, n_ * nrhs_);
5259 let mut a_scratch: Vec<f64> = a_src.to_vec();
5260 let mut x_buf: Vec<f64> = b_src.to_vec();
5261 let info = crate::blas::dgesv(&mut a_scratch, &mut x_buf, n_, nrhs_);
5262 if info != 0 {
5263 panic!("DenseSolveF64: singular (info={info})");
5264 }
5265 sl_mut_f64(x, base, n_ * nrhs_).copy_from_slice(&x_buf);
5266 })
5267 }
5268
5269 Thunk::DenseSolveF32 { a, b, x, n, nrhs } => {
5270 let (n_, nrhs_) = (n as usize, nrhs as usize);
5271 Arc::new(move |base: *mut u8| unsafe {
5272 let a_src = sl(a, base, n_ * n_);
5273 let b_src = sl(b, base, n_ * nrhs_);
5274 let mut a_scratch: Vec<f32> = a_src.to_vec();
5275 let mut x_buf: Vec<f32> = b_src.to_vec();
5276 let info = crate::blas::sgesv(&mut a_scratch, &mut x_buf, n_, nrhs_);
5277 if info != 0 {
5278 panic!("DenseSolveF32: singular (info={info})");
5279 }
5280 sl_mut(x, base, n_ * nrhs_).copy_from_slice(&x_buf);
5281 })
5282 }
5283
5284 Thunk::FusedMmBiasAct {
5285 a,
5286 w,
5287 bias,
5288 c,
5289 m,
5290 k,
5291 n,
5292 act,
5293 } => {
5294 let (m, k, n) = (m as usize, k as usize, n as usize);
5295 Arc::new(move |base: *mut u8| unsafe {
5296 let out = sl_mut(c, base, m * n);
5297 crate::blas::sgemm(sl(a, base, m * k), sl(w, base, k * n), out, m, k, n);
5298 match act {
5306 Some(Activation::Gelu) => {
5307 crate::kernels::par_bias_gelu(out, sl(bias, base, n), m, n)
5308 }
5309 Some(other) => {
5310 crate::blas::bias_add(out, sl(bias, base, n), m, n);
5311 apply_activation_inplace(out, other);
5312 }
5313 None => crate::blas::bias_add(out, sl(bias, base, n), m, n),
5314 }
5315 })
5316 }
5317
5318 Thunk::FusedResidualLN {
5319 x,
5320 res,
5321 bias,
5322 g,
5323 b,
5324 out,
5325 rows,
5326 h,
5327 eps,
5328 has_bias,
5329 } => {
5330 let (rows, h) = (rows as usize, h as usize);
5331 Arc::new(move |base: *mut u8| unsafe {
5332 let zero = vec![0f32; h]; let bi = if has_bias { sl(bias, base, h) } else { &zero };
5334 let xp = sl(x, base, rows * h).as_ptr() as usize;
5335 let rp = sl(res, base, rows * h).as_ptr() as usize;
5336 let op = sl_mut(out, base, rows * h).as_mut_ptr() as usize;
5337 let bp = bi.as_ptr() as usize;
5338 let gp = sl(g, base, h).as_ptr() as usize;
5339 let bbp = sl(b, base, h).as_ptr() as usize;
5340 crate::pool::par_for(rows, 4, &|off, cnt| {
5341 let xs = std::slice::from_raw_parts(
5342 (xp as *const f32).add(off * h),
5343 cnt * h,
5344 );
5345 let rs = std::slice::from_raw_parts(
5346 (rp as *const f32).add(off * h),
5347 cnt * h,
5348 );
5349 let os = std::slice::from_raw_parts_mut(
5350 (op as *mut f32).add(off * h),
5351 cnt * h,
5352 );
5353 let bi = std::slice::from_raw_parts(bp as *const f32, h);
5354 let g = std::slice::from_raw_parts(gp as *const f32, h);
5355 let b = std::slice::from_raw_parts(bbp as *const f32, h);
5356 crate::kernels::residual_bias_layer_norm(
5357 xs, rs, bi, g, b, os, cnt, h, eps,
5358 );
5359 });
5360 })
5361 }
5362
5363 Thunk::BiasAdd {
5364 src,
5365 bias,
5366 dst,
5367 m,
5368 n,
5369 } => {
5370 let (m, n) = (m as usize, n as usize);
5371 Arc::new(move |base: *mut u8| unsafe {
5372 let out = sl_mut(dst, base, m * n);
5373 out.copy_from_slice(sl(src, base, m * n));
5374 crate::blas::bias_add(out, sl(bias, base, n), m, n);
5375 })
5376 }
5377
5378 Thunk::Gather {
5379 table,
5380 table_len,
5381 idx,
5382 dst,
5383 num_idx,
5384 trailing,
5385 } => {
5386 let (ni, tr, tl) = (num_idx as usize, trailing as usize, table_len as usize);
5387 Arc::new(move |base: *mut u8| unsafe {
5388 let tab = sl(table, base, tl);
5389 let ids = sl(idx, base, ni);
5390 let out = sl_mut(dst, base, ni * tr);
5391 for i in 0..ni {
5392 let row = ids[i] as usize;
5393 out[i * tr..(i + 1) * tr]
5394 .copy_from_slice(&tab[row * tr..(row + 1) * tr]);
5395 }
5396 })
5397 }
5398
                    // Closure construction is delegated to `narrow_thunk_closure`; every
                    // field of the thunk is forwarded unchanged, in declaration order.
                    Thunk::Narrow {
                        src,
                        dst,
                        outer,
                        src_stride,
                        dst_stride,
                        inner,
                        elem_bytes,
                    } => {
                        narrow_thunk_closure(src, dst, outer, src_stride, dst_stride, inner, elem_bytes)
                    }
5410
5411 Thunk::Copy { src, dst, len } => {
5412 let len = len as usize;
5413 Arc::new(move |base: *mut u8| unsafe {
5414 sl_mut(dst, base, len).copy_from_slice(sl(src, base, len));
5415 })
5416 }
5417
5418 Thunk::Softmax { data, rows, cols } => {
5419 let (rows, cols) = (rows as usize, cols as usize);
5420 Arc::new(move |base: *mut u8| unsafe {
5421 crate::naive::softmax(sl_mut(data, base, rows * cols), rows, cols);
5422 })
5423 }
5424
5425 Thunk::Cumsum {
5426 src,
5427 dst,
5428 rows,
5429 cols,
5430 exclusive,
5431 } => {
5432 let (rows, cols) = (rows as usize, cols as usize);
5433 Arc::new(move |base: *mut u8| unsafe {
5434 let s = sl(src, base, rows * cols);
5435 let d = sl_mut(dst, base, rows * cols);
5436 if exclusive {
5437 for r in 0..rows {
5438 let mut acc = 0.0f32;
5439 for c in 0..cols {
5440 d[r * cols + c] = acc;
5441 acc += s[r * cols + c];
5442 }
5443 }
5444 } else {
5445 for r in 0..rows {
5446 let mut acc = 0.0f32;
5447 for c in 0..cols {
5448 acc += s[r * cols + c];
5449 d[r * cols + c] = acc;
5450 }
5451 }
5452 }
5453 })
5454 }
5455
5456 Thunk::Sample {
5457 logits,
5458 dst,
5459 batch,
5460 vocab,
5461 top_k,
5462 top_p,
5463 temperature,
5464 seed,
5465 } => {
5466 let (b, v) = (batch as usize, vocab as usize);
5467 let k = (top_k as usize).min(v);
5468 Arc::new(move |base: *mut u8| unsafe {
5469 let lg = sl(logits, base, b * v);
5470 let out = sl_mut(dst, base, b);
5471 let mut rng =
5472 rlx_ir::Philox4x32::new(if seed == 0 { 0xDEADBEEF } else { seed });
5473 for bi in 0..b {
5474 let row = &lg[bi * v..(bi + 1) * v];
5475 out[bi] = sample_row(row, k, top_p, temperature, &mut rng) as f32;
5476 }
5477 })
5478 }
5479
5480 Thunk::DequantMatMul {
5481 x,
5482 w_q,
5483 scale,
5484 zp,
5485 dst,
5486 m,
5487 k,
5488 n,
5489 block_size,
5490 is_asymmetric,
5491 } => {
5492 let (m, k, n, bs) = (m as usize, k as usize, n as usize, block_size as usize);
5493 let n_blocks_per_col = k.div_ceil(bs);
5494 Arc::new(move |base: *mut u8| unsafe {
5495 let xs = sl(x, base, m * k);
5496 let raw = base.add(w_q);
5498 let w_bytes = std::slice::from_raw_parts(raw as *const i8, k * n);
5499 let scales = sl(scale, base, n_blocks_per_col * n);
5500 let zps = if is_asymmetric {
5501 sl(zp, base, n_blocks_per_col * n)
5502 } else {
5503 &[][..]
5504 };
5505 let out = sl_mut(dst, base, m * n);
5506 dequant_matmul_int8(
5507 xs,
5508 w_bytes,
5509 scales,
5510 zps,
5511 out,
5512 m,
5513 k,
5514 n,
5515 bs,
5516 is_asymmetric,
5517 );
5518 })
5519 }
5520
5521 Thunk::DequantMatMulGguf {
5522 x,
5523 w_q,
5524 dst,
5525 m,
5526 k,
5527 n,
5528 scheme,
5529 } => {
5530 let (m, k, n) = (m as usize, k as usize, n as usize);
5531 let block_bytes = scheme.gguf_block_bytes() as usize;
5532 let block_elems = scheme.gguf_block_size() as usize;
5533 let total_bytes = (k * n) / block_elems * block_bytes;
5534 Arc::new(move |base: *mut u8| unsafe {
5535 let xs = sl(x, base, m * k);
5536 let w_bytes =
5537 std::slice::from_raw_parts(base.add(w_q) as *const u8, total_bytes);
5538 let out = sl_mut(dst, base, m * n);
5539 crate::gguf_matmul::gguf_matmul_bt(xs, w_bytes, out, m, k, n, scheme);
5540 })
5541 }
5542
5543 Thunk::DequantMatMulInt4 {
5544 x,
5545 w_q,
5546 scale,
5547 zp,
5548 dst,
5549 m,
5550 k,
5551 n,
5552 block_size,
5553 is_asymmetric,
5554 } => {
5555 let (m, k, n, bs) = (m as usize, k as usize, n as usize, block_size as usize);
5556 let n_blocks = k.div_ceil(bs);
5557 Arc::new(move |base: *mut u8| unsafe {
5558 let xs = sl(x, base, m * k);
5559 let w_bytes = std::slice::from_raw_parts(
5560 base.add(w_q) as *const u8,
5561 (k * n).div_ceil(2),
5562 );
5563 let scales = sl(scale, base, n_blocks * n);
5564 let zps = if is_asymmetric {
5565 sl(zp, base, n_blocks * n)
5566 } else {
5567 &[][..]
5568 };
5569 let out = sl_mut(dst, base, m * n);
5570 dequant_matmul_int4(
5571 xs,
5572 w_bytes,
5573 scales,
5574 zps,
5575 out,
5576 m,
5577 k,
5578 n,
5579 bs,
5580 is_asymmetric,
5581 );
5582 })
5583 }
5584
5585 Thunk::DequantMatMulFp8 {
5586 x,
5587 w_q,
5588 scale,
5589 dst,
5590 m,
5591 k,
5592 n,
5593 e5m2,
5594 } => {
5595 let (m, k, n) = (m as usize, k as usize, n as usize);
5596 Arc::new(move |base: *mut u8| unsafe {
5597 let xs = sl(x, base, m * k);
5598 let w_bytes = std::slice::from_raw_parts(base.add(w_q) as *const u8, k * n);
5599 let scales = sl(scale, base, n);
5600 let out = sl_mut(dst, base, m * n);
5601 dequant_matmul_fp8(xs, w_bytes, scales, out, m, k, n, e5m2);
5602 })
5603 }
5604
5605 Thunk::DequantMatMulNvfp4 {
5606 x,
5607 w_q,
5608 scale,
5609 global_scale,
5610 dst,
5611 m,
5612 k,
5613 n,
5614 } => {
5615 let (m, k, n) = (m as usize, k as usize, n as usize);
5616 let n_scale = k.div_ceil(rlx_ir::NVFP4_GROUP_SIZE) * n;
5617 Arc::new(move |base: *mut u8| unsafe {
5618 let xs = sl(x, base, m * k);
5619 let w_bytes = std::slice::from_raw_parts(
5620 base.add(w_q) as *const u8,
5621 (k * n).div_ceil(2),
5622 );
5623 let scale_bytes =
5624 std::slice::from_raw_parts(base.add(scale) as *const u8, n_scale);
5625 let gs = sl(global_scale, base, 1)[0];
5626 let out = sl_mut(dst, base, m * n);
5627 dequant_matmul_nvfp4(xs, w_bytes, scale_bytes, gs, out, m, k, n);
5628 })
5629 }
5630
5631 Thunk::LoraMatMul {
5632 x,
5633 w,
5634 a,
5635 b,
5636 dst,
5637 m,
5638 k,
5639 n,
5640 r,
5641 scale,
5642 } => {
5643 let (m, k, n, r) = (m as usize, k as usize, n as usize, r as usize);
5644 Arc::new(move |base: *mut u8| unsafe {
5645 let xs = sl(x, base, m * k);
5646 let ws = sl(w, base, k * n);
5647 let a_s = sl(a, base, k * r);
5648 let bs = sl(b, base, r * n);
5649 let out = sl_mut(dst, base, m * n);
5650 crate::blas::sgemm(xs, ws, out, m, k, n);
5652 let mut tmp = vec![0f32; m * r];
5654 crate::blas::sgemm(xs, a_s, &mut tmp, m, k, r);
5655 if scale != 1.0 {
5659 for v in tmp.iter_mut() {
5660 *v *= scale;
5661 }
5662 }
5663 crate::blas::sgemm_accumulate(&tmp, bs, out, m, r, n);
5664 })
5665 }
5666
5667 Thunk::LayerNorm {
5668 src,
5669 g,
5670 b,
5671 dst,
5672 rows,
5673 h,
5674 eps,
5675 } => {
5676 let (rows, h) = (rows as usize, h as usize);
5677 Arc::new(move |base: *mut u8| unsafe {
5678 let inp = sl(src, base, rows * h);
5679 let gamma = sl(g, base, h);
5680 let beta = sl(b, base, h);
5681 let out = sl_mut(dst, base, rows * h);
5682 for row in 0..rows {
5683 crate::kernels::layer_norm_row(
5684 &inp[row * h..(row + 1) * h],
5685 gamma,
5686 beta,
5687 &mut out[row * h..(row + 1) * h],
5688 h,
5689 eps,
5690 );
5691 }
5692 })
5693 }
5694
5695 Thunk::Attention {
5696 q,
5697 k,
5698 v,
5699 mask,
5700 out,
5701 batch,
5702 seq,
5703 kv_seq: _,
5704 heads,
5705 head_dim,
5706 mask_kind,
5707 q_row_stride,
5708 k_row_stride,
5709 v_row_stride,
5710 bhsd,
5711 } => {
5712 let (b, s, nh, dh) = (
5713 batch as usize,
5714 seq as usize,
5715 heads as usize,
5716 head_dim as usize,
5717 );
5718 let hs = nh * dh;
5719 let qrs = q_row_stride as usize;
5720 let krs = k_row_stride as usize;
5721 let vrs = v_row_stride as usize;
5722 let scale = (dh as f32).powf(-0.5);
5723 Arc::new(move |base: *mut u8| unsafe {
5724 let (q_len, k_len, v_len, o_len) = if bhsd {
5729 let n = b * nh * s * dh;
5730 (n, n, n, n)
5731 } else {
5732 (b * s * qrs, b * s * krs, b * s * vrs, b * s * hs)
5733 };
5734 let q_d = sl(q, base, q_len);
5735 let k_d = sl(k, base, k_len);
5736 let v_d = sl(v, base, v_len);
5737 let m_d: &[f32] = match mask_kind {
5738 rlx_ir::op::MaskKind::Custom => sl(mask, base, b * s),
5739 rlx_ir::op::MaskKind::Bias => sl(mask, base, b * nh * s * s),
5740 _ => &[],
5741 };
5742 let o_d = sl_mut(out, base, o_len);
5743 let sdh = s * dh;
5744 let mut qh = vec![0f32; sdh];
5745 let mut kh = vec![0f32; sdh];
5746 let mut vh = vec![0f32; sdh];
5747 let mut sc = vec![0f32; s * s];
5748 let mut oh = vec![0f32; sdh];
5749 for bi in 0..b {
5750 for hi in 0..nh {
5751 for si in 0..s {
5752 let (q_off, k_off, v_off) = if bhsd {
5764 (
5765 bi * nh * s * dh + hi * s * dh + si * dh,
5766 bi * nh * s * dh + hi * s * dh + si * dh,
5767 bi * nh * s * dh + hi * s * dh + si * dh,
5768 )
5769 } else {
5770 (
5771 bi * s * qrs + si * qrs + hi * dh,
5772 bi * s * krs + si * krs + hi * dh,
5773 bi * s * vrs + si * vrs + hi * dh,
5774 )
5775 };
5776 qh[si * dh..(si + 1) * dh]
5777 .copy_from_slice(&q_d[q_off..q_off + dh]);
5778 kh[si * dh..(si + 1) * dh]
5779 .copy_from_slice(&k_d[k_off..k_off + dh]);
5780 vh[si * dh..(si + 1) * dh]
5781 .copy_from_slice(&v_d[v_off..v_off + dh]);
5782 }
5783 for qi in 0..s {
5784 for ki in 0..s {
5785 let mut dot = 0f32;
5786 for d in 0..dh {
5787 dot += qh[qi * dh + d] * kh[ki * dh + d];
5788 }
5789 sc[qi * s + ki] = dot * scale;
5790 }
5791 }
5792 match mask_kind {
5795 rlx_ir::op::MaskKind::None => {}
5796 rlx_ir::op::MaskKind::Causal => {
5797 for qi in 0..s {
5798 for ki in (qi + 1)..s {
5799 sc[qi * s + ki] = mask_neg;
5800 }
5801 }
5802 }
5803 rlx_ir::op::MaskKind::SlidingWindow(w) => {
5804 for qi in 0..s {
5805 let lo = qi.saturating_sub(w);
5806 for ki in 0..s {
5807 if ki < lo || ki > qi {
5808 sc[qi * s + ki] = mask_neg;
5809 }
5810 }
5811 }
5812 }
5813 rlx_ir::op::MaskKind::Custom => {
5814 for qi in 0..s {
5815 for ki in 0..s {
5816 if m_d[bi * s + ki] < mask_thr {
5817 sc[qi * s + ki] = mask_neg;
5818 }
5819 }
5820 }
5821 }
5822 rlx_ir::op::MaskKind::Bias => {
5823 let per_bh = s * s;
5824 let off = (bi * nh + hi) * per_bh;
5825 for i in 0..per_bh {
5826 sc[i] += m_d[off + i];
5827 }
5828 }
5829 }
5830 crate::naive::softmax(&mut sc, s, s);
5831 oh.fill(0.0);
5832 for qi in 0..s {
5833 for ki in 0..s {
5834 let w = sc[qi * s + ki];
5835 if w > score_skip {
5836 for d in 0..dh {
5837 oh[qi * dh + d] += w * vh[ki * dh + d];
5838 }
5839 }
5840 }
5841 }
5842 for si in 0..s {
5843 let off = if bhsd {
5844 bi * nh * s * dh + hi * s * dh + si * dh
5845 } else {
5846 bi * s * hs + si * hs + hi * dh
5847 };
5848 o_d[off..off + dh].copy_from_slice(&oh[si * dh..(si + 1) * dh]);
5849 }
5850 }
5851 }
5852 })
5853 }
5854
5855 Thunk::FusedSwiGLU {
5856 src,
5857 dst,
5858 n_half,
5859 total,
5860 gate_first,
5861 } => {
5862 let n = n_half as usize;
5863 let t = total as usize;
5864 let outer = t / n;
5865 let in_total = outer * 2 * n;
5866 Arc::new(move |base: *mut u8| unsafe {
5867 let inp = sl(src, base, in_total);
5868 let out = sl_mut(dst, base, t);
5869 for o in 0..outer {
5870 let in_row = &inp[o * 2 * n..(o + 1) * 2 * n];
5871 let out_row = &mut out[o * n..(o + 1) * n];
5872 for i in 0..n {
5873 let (up, gate) = if gate_first {
5874 (in_row[n + i], in_row[i])
5875 } else {
5876 (in_row[i], in_row[n + i])
5877 };
5878 out_row[i] = up * (gate / (1.0 + (-gate).exp()));
5879 }
5880 }
5881 })
5882 }
5883
5884 Thunk::Concat {
5885 dst,
5886 outer,
5887 inner,
5888 total_axis,
5889 inputs,
5890 } => {
5891 let outer = outer as usize;
5892 let inner = inner as usize;
5893 let total_axis = total_axis as usize;
5894 let out_total = outer * total_axis * inner;
5895 let mut layout: Vec<(usize, usize, usize)> = Vec::with_capacity(inputs.len());
5898 let mut cum: usize = 0;
5899 for (src_off, in_axis) in &inputs {
5900 let in_axis = *in_axis as usize;
5901 layout.push((*src_off, cum * inner, in_axis * inner));
5902 cum += in_axis;
5903 }
5904 Arc::new(move |base: *mut u8| unsafe {
5905 let out = sl_mut(dst, base, out_total);
5906 let row_stride = total_axis * inner;
5907 for (src_off, dst_col_off, copy_per_row) in &layout {
5908 let in_total = outer * *copy_per_row;
5909 let inp = sl(*src_off, base, in_total);
5910 for o in 0..outer {
5911 let dst_row_start = o * row_stride + *dst_col_off;
5912 let src_row_start = o * *copy_per_row;
5913 out[dst_row_start..dst_row_start + *copy_per_row].copy_from_slice(
5914 &inp[src_row_start..src_row_start + *copy_per_row],
5915 );
5916 }
5917 }
5918 })
5919 }
5920
5921 Thunk::CustomOp {
5922 kernel,
5923 inputs,
5924 output,
5925 attrs,
5926 } => {
5927 let kernel = kernel.clone();
5933 let attrs = attrs.clone();
5934 let inputs = inputs.clone();
5935 let (out_off, out_len, out_shape) = output.clone();
5936 Arc::new(move |base: *mut u8| unsafe {
5937 dispatch_custom_op(
5938 &*kernel, &inputs, out_off, out_len, &out_shape, &attrs, base,
5939 );
5940 })
5941 }
5942
5943 Thunk::GaussianSplatRender {
5944 positions_off,
5945 positions_len,
5946 scales_off,
5947 scales_len,
5948 rotations_off,
5949 rotations_len,
5950 opacities_off,
5951 opacities_len,
5952 colors_off,
5953 colors_len,
5954 sh_coeffs_off,
5955 sh_coeffs_len,
5956 meta_off,
5957 dst_off,
5958 dst_len,
5959 width,
5960 height,
5961 tile_size,
5962 radius_scale,
5963 alpha_cutoff,
5964 max_splat_steps,
5965 transmittance_threshold,
5966 max_list_entries,
5967 } => Arc::new(move |base: *mut u8| unsafe {
5968 crate::splat::execute_gaussian_splat_render(
5969 positions_off,
5970 positions_len,
5971 scales_off,
5972 scales_len,
5973 rotations_off,
5974 rotations_len,
5975 opacities_off,
5976 opacities_len,
5977 colors_off,
5978 colors_len,
5979 sh_coeffs_off,
5980 sh_coeffs_len,
5981 meta_off,
5982 dst_off,
5983 dst_len,
5984 width,
5985 height,
5986 tile_size,
5987 radius_scale,
5988 alpha_cutoff,
5989 max_splat_steps,
5990 transmittance_threshold,
5991 max_list_entries,
5992 base,
5993 );
5994 }),
5995
5996 Thunk::GaussianSplatRenderBackward {
5997 positions_off,
5998 positions_len,
5999 scales_off,
6000 scales_len,
6001 rotations_off,
6002 rotations_len,
6003 opacities_off,
6004 opacities_len,
6005 colors_off,
6006 colors_len,
6007 sh_coeffs_off,
6008 sh_coeffs_len,
6009 meta_off,
6010 d_loss_off,
6011 d_loss_len,
6012 packed_off,
6013 packed_len,
6014 width,
6015 height,
6016 tile_size,
6017 radius_scale,
6018 alpha_cutoff,
6019 max_splat_steps,
6020 transmittance_threshold,
6021 max_list_entries,
6022 loss_grad_clip,
6023 sh_band,
6024 max_anisotropy,
6025 } => Arc::new(move |base: *mut u8| unsafe {
6026 crate::splat::execute_gaussian_splat_render_backward(
6027 positions_off,
6028 positions_len,
6029 scales_off,
6030 scales_len,
6031 rotations_off,
6032 rotations_len,
6033 opacities_off,
6034 opacities_len,
6035 colors_off,
6036 colors_len,
6037 sh_coeffs_off,
6038 sh_coeffs_len,
6039 meta_off,
6040 d_loss_off,
6041 d_loss_len,
6042 packed_off,
6043 packed_len,
6044 width,
6045 height,
6046 tile_size,
6047 radius_scale,
6048 alpha_cutoff,
6049 max_splat_steps,
6050 transmittance_threshold,
6051 max_list_entries,
6052 loss_grad_clip,
6053 sh_band,
6054 max_anisotropy,
6055 base,
6056 );
6057 }),
6058
6059 Thunk::GaussianSplatPrepare {
6060 positions_off,
6061 positions_len,
6062 scales_off,
6063 scales_len,
6064 rotations_off,
6065 rotations_len,
6066 opacities_off,
6067 opacities_len,
6068 colors_off,
6069 colors_len,
6070 sh_coeffs_off,
6071 sh_coeffs_len,
6072 meta_off,
6073 meta_len,
6074 prep_off,
6075 prep_len,
6076 width,
6077 height,
6078 tile_size,
6079 radius_scale,
6080 alpha_cutoff,
6081 max_splat_steps,
6082 transmittance_threshold,
6083 max_list_entries,
6084 } => Arc::new(move |base: *mut u8| unsafe {
6085 crate::splat::execute_gaussian_splat_prepare(
6086 positions_off,
6087 positions_len,
6088 scales_off,
6089 scales_len,
6090 rotations_off,
6091 rotations_len,
6092 opacities_off,
6093 opacities_len,
6094 colors_off,
6095 colors_len,
6096 sh_coeffs_off,
6097 sh_coeffs_len,
6098 meta_off,
6099 meta_len,
6100 prep_off,
6101 prep_len,
6102 width,
6103 height,
6104 tile_size,
6105 radius_scale,
6106 alpha_cutoff,
6107 max_splat_steps,
6108 transmittance_threshold,
6109 max_list_entries,
6110 base,
6111 );
6112 }),
6113
6114 Thunk::GaussianSplatRasterize {
6115 prep_off,
6116 prep_len,
6117 meta_off,
6118 meta_len,
6119 dst_off,
6120 dst_len,
6121 count,
6122 width,
6123 height,
6124 tile_size,
6125 alpha_cutoff,
6126 max_splat_steps,
6127 transmittance_threshold,
6128 max_list_entries,
6129 } => Arc::new(move |base: *mut u8| unsafe {
6130 crate::splat::execute_gaussian_splat_rasterize(
6131 prep_off,
6132 prep_len,
6133 meta_off,
6134 meta_len,
6135 dst_off,
6136 dst_len,
6137 count,
6138 width,
6139 height,
6140 tile_size,
6141 alpha_cutoff,
6142 max_splat_steps,
6143 transmittance_threshold,
6144 max_list_entries,
6145 base,
6146 );
6147 }),
6148
6149 Thunk::Fft1d {
6150 src,
6151 dst,
6152 outer,
6153 n_complex,
6154 inverse,
6155 dtype,
6156 } => {
6157 let f: Arc<dyn Fn(*mut u8) + Send + Sync> = match dtype {
6158 rlx_ir::DType::F64 => Arc::new(move |base: *mut u8| unsafe {
6159 execute_fft1d_f64(
6160 src,
6161 dst,
6162 outer as usize,
6163 n_complex as usize,
6164 inverse,
6165 base,
6166 );
6167 }),
6168 rlx_ir::DType::F32 => Arc::new(move |base: *mut u8| unsafe {
6169 execute_fft1d_f32(
6170 src,
6171 dst,
6172 outer as usize,
6173 n_complex as usize,
6174 inverse,
6175 base,
6176 );
6177 }),
6178 other => panic!("Op::Fft on CPU requires F32/F64, got {other:?}"),
6179 };
6180 f
6181 }
6182
6183 _ => Arc::new(|_: *mut u8| {}),
6184 }
6185 })
6186 .collect();
6187
6188 let fuse_threshold: usize = rlx_ir::env::var("RLX_FUSE_ATTN_THRESHOLD")
6192 .and_then(|v| v.parse().ok())
6193 .unwrap_or(64);
6194 let should_fuse = thunks.iter().any(|t| match t {
6195 Thunk::Attention { batch, seq, .. } => {
6196 (*batch as usize) * (*seq as usize) <= fuse_threshold
6197 }
6198 _ => false,
6199 });
6200
6201 if should_fuse {
6202 let active: Vec<usize> = thunks
6204 .iter()
6205 .enumerate()
6206 .filter(|(_, t)| !matches!(t, Thunk::Nop))
6207 .map(|(i, _)| i)
6208 .collect();
6209
6210 let mut kill = vec![false; thunks.len()]; let mut insertions: Vec<(usize, Thunk)> = Vec::new(); let mut ai = 0;
6214 while ai < active.len() {
6215 let a = |off: usize| -> Option<(usize, &Thunk)> {
6217 active.get(ai + off).map(|&idx| (idx, &thunks[idx]))
6218 };
6219
6220 let matched = (|| {
6222 let (_i0, t0) = a(0)?;
6223 let (_, t1) = a(1)?;
6224 let (_, t2) = a(2)?;
6225 let (_, t3) = a(3)?;
6226
6227 let (hidden, qkv_w, qkv_b, has_b) = match t0 {
6229 Thunk::FusedMmBiasAct {
6230 a,
6231 w,
6232 bias,
6233 n: _,
6234 act: None,
6235 ..
6236 } => (*a, *w, *bias, true),
6237 Thunk::Sgemm { a, b, n: _, .. } => (*a, *b, 0, false),
6238 _ => return None,
6239 };
6240
6241 if !matches!(t1, Thunk::Narrow { .. }) {
6243 return None;
6244 }
6245 if !matches!(t2, Thunk::Narrow { .. }) {
6246 return None;
6247 }
6248 if !matches!(t3, Thunk::Narrow { .. }) {
6249 return None;
6250 }
6251
6252 let (has_rope, attn_ai, cos_off, sin_off, cl) = if let Some((
6254 _,
6255 Thunk::Rope {
6256 cos, sin, cos_len, ..
6257 },
6258 )) = a(4)
6259 {
6260 if matches!(a(5).map(|x| x.1), Some(Thunk::Rope { .. })) {
6261 if matches!(a(6).map(|x| x.1), Some(Thunk::Attention { .. })) {
6262 (true, 6, *cos, *sin, *cos_len)
6263 } else {
6264 return None;
6265 }
6266 } else {
6267 return None;
6268 }
6269 } else if matches!(a(4).map(|x| x.1), Some(Thunk::Attention { .. })) {
6270 (false, 4, 0, 0, 0)
6271 } else {
6272 return None;
6273 };
6274
6275 let (_attn_real_idx, attn_t) = a(attn_ai)?;
6276 let (batch, seq, heads, head_dim, mask) = match attn_t {
6277 Thunk::Attention {
6278 batch,
6279 seq,
6280 heads,
6281 head_dim,
6282 mask,
6283 ..
6284 } => (*batch, *seq, *heads, *head_dim, *mask),
6285 _ => return None,
6286 };
6287
6288 let (_out_real_idx, out_t) = a(attn_ai + 1)?;
6290 let (out_w, out_b, out_dst) = match out_t {
6291 Thunk::FusedMmBiasAct {
6292 w,
6293 bias,
6294 c,
6295 act: None,
6296 ..
6297 } => (*w, *bias, *c),
6298 Thunk::Sgemm { b: w, c, .. } => (*w, 0, *c),
6299 _ => return None,
6300 };
6301
6302 let hs = heads * head_dim;
6303 let total_active = attn_ai + 2; Some((
6306 total_active,
6307 Thunk::FusedAttnBlock {
6308 hidden,
6309 qkv_w,
6310 out_w,
6311 mask,
6312 out: out_dst,
6313 qkv_b: if has_b { qkv_b } else { 0 },
6314 out_b: if has_b { out_b } else { 0 },
6315 cos: cos_off,
6316 sin: sin_off,
6317 cos_len: cl,
6318 batch,
6319 seq,
6320 hs,
6321 nh: heads,
6322 dh: head_dim,
6323 has_bias: has_b,
6324 has_rope,
6325 },
6326 ))
6327 })();
6328
6329 if let Some((count, fused_thunk)) = matched {
6330 for off in 0..count {
6332 if let Some(&idx) = active.get(ai + off) {
6333 kill[idx] = true;
6334 }
6335 }
6336 insertions.push((active[ai], fused_thunk));
6338 ai += count;
6339 } else {
6340 ai += 1;
6341 }
6342 }
6343
6344 if !insertions.is_empty() {
6346 let mut new_thunks = Vec::with_capacity(thunks.len());
6347 let mut insert_idx = 0;
6348 for (i, t) in thunks.into_iter().enumerate() {
6349 if insert_idx < insertions.len() && insertions[insert_idx].0 == i {
6350 new_thunks.push(insertions[insert_idx].1.clone());
6351 insert_idx += 1;
6352 }
6353 if !kill[i] {
6354 new_thunks.push(t);
6355 }
6356 }
6357 if cfg.verbose >= 1 {
6358 eprintln!(
6359 "[rlx] fused_attention: {} attention blocks fused",
6360 insertions.len()
6361 );
6362 }
6363 thunks = new_thunks;
6364 }
6365 }
6366
6367 if should_fuse {
6372 let active: Vec<usize> = thunks
6373 .iter()
6374 .enumerate()
6375 .filter(|(_, t)| !matches!(t, Thunk::Nop))
6376 .map(|(i, _)| i)
6377 .collect();
6378
6379 let mut kill = vec![false; thunks.len()];
6380 let mut insertions: Vec<(usize, Thunk)> = Vec::new();
6381
6382 let a = |ai: usize| -> Option<&Thunk> { active.get(ai).map(|&i| &thunks[i]) };
6383
6384 let mut ai = 0;
6385 while ai < active.len() {
6386 let bert_match = (|| -> Option<usize> {
6388 let fab = a(ai)?;
6389 let rln1 = a(ai + 1)?;
6390 let ffn1 = a(ai + 2)?;
6391 let ffn2 = a(ai + 3)?;
6392 let rln2 = a(ai + 4)?;
6393
6394 let (hidden, qkv_w, qkv_b, out_w, out_b, mask, batch, seq, hs, nh, dh) = match fab {
6395 Thunk::FusedAttnBlock {
6396 hidden,
6397 qkv_w,
6398 qkv_b,
6399 out_w,
6400 out_b,
6401 mask,
6402 batch,
6403 seq,
6404 hs,
6405 nh,
6406 dh,
6407 has_bias: true,
6408 has_rope: false,
6409 ..
6410 } => (
6411 *hidden, *qkv_w, *qkv_b, *out_w, *out_b, *mask, *batch, *seq, *hs, *nh, *dh,
6412 ),
6413 _ => return None,
6414 };
6415 let (ln1_g, ln1_b, eps1) = match rln1 {
6416 Thunk::FusedResidualLN { g, b, eps, .. } => (*g, *b, *eps),
6417 _ => return None,
6418 };
6419 let (fc1_w, fc1_b, int_dim) = match ffn1 {
6420 Thunk::FusedMmBiasAct {
6421 w,
6422 bias,
6423 n,
6424 act: Some(Activation::Gelu),
6425 ..
6426 } => (*w, *bias, *n),
6427 _ => return None,
6428 };
6429 let (fc2_w, fc2_b) = match ffn2 {
6430 Thunk::FusedMmBiasAct {
6431 w, bias, act: None, ..
6432 } => (*w, *bias),
6433 _ => return None,
6434 };
6435 let (ln2_g, ln2_b, eps2, out) = match rln2 {
6436 Thunk::FusedResidualLN { g, b, eps, out, .. } => (*g, *b, *eps, *out),
6437 _ => return None,
6438 };
6439
6440 for off in 0..5 {
6441 kill[active[ai + off]] = true;
6442 }
6443 insertions.push((
6444 active[ai],
6445 Thunk::FusedBertLayer {
6446 hidden,
6447 qkv_w,
6448 qkv_b,
6449 out_w,
6450 out_b,
6451 mask,
6452 ln1_g,
6453 ln1_b,
6454 eps1,
6455 fc1_w,
6456 fc1_b,
6457 fc2_w,
6458 fc2_b,
6459 ln2_g,
6460 ln2_b,
6461 eps2,
6462 out,
6463 batch,
6464 seq,
6465 hs,
6466 nh,
6467 dh,
6468 int_dim,
6469 },
6470 ));
6471 Some(5)
6472 })();
6473 if let Some(n) = bert_match {
6474 ai += n;
6475 continue;
6476 }
6477
6478 #[allow(unreachable_code)]
6482 let nomic_match = (|| -> Option<usize> {
6483 return None; let fab = a(ai)?;
6485 let (hidden, qkv_w, out_w, mask, cos, sin, cos_len, batch, seq, hs, nh, dh) =
6486 match fab {
6487 Thunk::FusedAttnBlock {
6488 hidden,
6489 qkv_w,
6490 out_w,
6491 mask,
6492 cos,
6493 sin,
6494 cos_len,
6495 batch,
6496 seq,
6497 hs,
6498 nh,
6499 dh,
6500 has_bias: false,
6501 has_rope: true,
6502 ..
6503 } => (
6504 *hidden, *qkv_w, *out_w, *mask, *cos, *sin, *cos_len, *batch, *seq,
6505 *hs, *nh, *dh,
6506 ),
6507 _ => return None,
6508 };
6509 let (ln1_g, ln1_b, eps1) = match a(ai + 1)? {
6511 Thunk::FusedResidualLN { g, b, eps, .. } => (*g, *b, *eps),
6512 _ => return None,
6513 };
6514 let fused_fc_w = match a(ai + 2)? {
6516 Thunk::Sgemm { b: w, .. } => *w,
6517 _ => return None,
6518 };
6519 if !matches!(a(ai + 3)?, Thunk::Narrow { .. }) {
6521 return None;
6522 }
6523 if !matches!(a(ai + 4)?, Thunk::Narrow { .. }) {
6524 return None;
6525 }
6526 if !matches!(
6528 a(ai + 5)?,
6529 Thunk::ActivationInPlace {
6530 act: Activation::Silu,
6531 ..
6532 }
6533 ) {
6534 return None;
6535 }
6536 if !matches!(
6538 a(ai + 6)?,
6539 Thunk::BinaryFull {
6540 op: BinaryOp::Mul,
6541 ..
6542 }
6543 ) {
6544 return None;
6545 }
6546 let fc2_w = match a(ai + 7)? {
6548 Thunk::Sgemm { b: w, .. } => *w,
6549 _ => return None,
6550 };
6551 let int_dim = match a(ai + 3)? {
6553 Thunk::Narrow { inner, .. } => *inner,
6554 _ => return None,
6555 };
6556 let (ln2_g, ln2_b, eps2, out) = match a(ai + 8)? {
6558 Thunk::FusedResidualLN { g, b, eps, out, .. } => (*g, *b, *eps, *out),
6559 _ => return None,
6560 };
6561
6562 for off in 0..9 {
6563 kill[active[ai + off]] = true;
6564 }
6565 insertions.push((
6566 active[ai],
6567 Thunk::FusedNomicLayer {
6568 hidden,
6569 qkv_w,
6570 out_w,
6571 mask,
6572 cos,
6573 sin,
6574 cos_len,
6575 ln1_g,
6576 ln1_b,
6577 eps1,
6578 fc11_w: fused_fc_w,
6579 fc12_w: 0,
6580 fc2_w,
6581 ln2_g,
6582 ln2_b,
6583 eps2,
6584 out,
6585 batch,
6586 seq,
6587 hs,
6588 nh,
6589 dh,
6590 int_dim,
6591 },
6592 ));
6593 Some(9)
6594 })();
6595 if let Some(n) = nomic_match {
6596 ai += n;
6597 continue;
6598 }
6599
6600 ai += 1;
6601 }
6602
6603 if !insertions.is_empty() {
6604 let mut new_thunks = Vec::with_capacity(thunks.len());
6605 let mut ins_idx = 0;
6606 for (i, t) in thunks.into_iter().enumerate() {
6607 if ins_idx < insertions.len() && insertions[ins_idx].0 == i {
6608 new_thunks.push(insertions[ins_idx].1.clone());
6609 ins_idx += 1;
6610 }
6611 if !kill[i] {
6612 new_thunks.push(t);
6613 }
6614 }
6615 if cfg.verbose >= 1 {
6616 eprintln!(
6617 "[rlx] fused_layer: {} full transformer layers fused",
6618 insertions.len()
6619 );
6620 }
6621 thunks = new_thunks;
6622 }
6623 }
6624
6625 {
6637 let mut read_offsets: HashMap<usize, usize> = HashMap::new();
6640 for t in &thunks {
6641 for off in thunk_read_offsets(t) {
6642 *read_offsets.entry(off).or_insert(0) += 1;
6643 }
6644 }
6645
6646 let mut fused_count = 0usize;
6647 for i in 0..thunks.len().saturating_sub(1) {
6648 let narrow = match &thunks[i] {
6651 Thunk::Narrow { .. } => i,
6652 _ => continue,
6653 };
6654 let mut j = narrow + 1;
6656 while j < thunks.len() && matches!(thunks[j], Thunk::Nop) {
6657 j += 1;
6658 }
6659 if j >= thunks.len() {
6660 continue;
6661 }
6662 let (n_src, n_dst, n_src_stride) = match &thunks[narrow] {
6664 Thunk::Narrow {
6665 src,
6666 dst,
6667 src_stride,
6668 ..
6669 } => (*src, *dst, *src_stride),
6670 _ => continue,
6671 };
6672 let rope_reads_narrow = matches!(&thunks[j],
6673 Thunk::Rope { src, .. } if *src == n_dst);
6674 if !rope_reads_narrow {
6675 continue;
6676 }
6677 if read_offsets.get(&n_dst).copied().unwrap_or(0) != 1 {
6681 continue;
6682 }
6683
6684 if let Thunk::Rope {
6687 src,
6688 src_row_stride,
6689 ..
6690 } = &mut thunks[j]
6691 {
6692 *src = n_src;
6693 *src_row_stride = n_src_stride;
6694 }
6695 thunks[narrow] = Thunk::Nop;
6696 fused_count += 1;
6697 }
6698
6699 if fused_count > 0 && cfg.verbose >= 1 {
6700 eprintln!(
6701 "[rlx] fused_qk_rope: {} Narrow→Rope pairs collapsed",
6702 fused_count
6703 );
6704 }
6705 }
6706
6707 {
6719 let mut read_counts: HashMap<usize, usize> = HashMap::new();
6720 for t in &thunks {
6721 for off in thunk_read_offsets(t) {
6722 *read_counts.entry(off).or_insert(0) += 1;
6723 }
6724 }
6725 let mut dst_to_idx: HashMap<usize, usize> = HashMap::new();
6727 for (i, t) in thunks.iter().enumerate() {
6728 if let Thunk::Narrow { dst, .. } = t {
6729 dst_to_idx.insert(*dst, i);
6730 }
6731 }
6732
6733 let mut fused_count = 0usize;
6734 for i in 0..thunks.len() {
6735 let (q_off, k_off, v_off) = match &thunks[i] {
6736 Thunk::Attention { q, k, v, .. } => (*q, *k, *v),
6737 _ => continue,
6738 };
6739 let q_n = match dst_to_idx.get(&q_off).copied() {
6741 Some(x) => x,
6742 None => continue,
6743 };
6744 let k_n = match dst_to_idx.get(&k_off).copied() {
6745 Some(x) => x,
6746 None => continue,
6747 };
6748 let v_n = match dst_to_idx.get(&v_off).copied() {
6749 Some(x) => x,
6750 None => continue,
6751 };
6752 if read_counts.get(&q_off).copied().unwrap_or(0) != 1 {
6754 continue;
6755 }
6756 if read_counts.get(&k_off).copied().unwrap_or(0) != 1 {
6757 continue;
6758 }
6759 if read_counts.get(&v_off).copied().unwrap_or(0) != 1 {
6760 continue;
6761 }
6762
6763 let (q_src, q_stride) = match &thunks[q_n] {
6764 Thunk::Narrow {
6765 src, src_stride, ..
6766 } => (*src, *src_stride),
6767 _ => continue,
6768 };
6769 let (k_src, k_stride) = match &thunks[k_n] {
6770 Thunk::Narrow {
6771 src, src_stride, ..
6772 } => (*src, *src_stride),
6773 _ => continue,
6774 };
6775 let (v_src, v_stride) = match &thunks[v_n] {
6776 Thunk::Narrow {
6777 src, src_stride, ..
6778 } => (*src, *src_stride),
6779 _ => continue,
6780 };
6781
6782 if let Thunk::Attention {
6783 q,
6784 k,
6785 v,
6786 q_row_stride,
6787 k_row_stride,
6788 v_row_stride,
6789 ..
6790 } = &mut thunks[i]
6791 {
6792 *q = q_src;
6793 *k = k_src;
6794 *v = v_src;
6795 *q_row_stride = q_stride;
6796 *k_row_stride = k_stride;
6797 *v_row_stride = v_stride;
6798 }
6799 thunks[q_n] = Thunk::Nop;
6800 thunks[k_n] = Thunk::Nop;
6801 thunks[v_n] = Thunk::Nop;
6802 fused_count += 1;
6803 }
6804
6805 if fused_count > 0 && cfg.verbose >= 1 {
6806 eprintln!(
6807 "[rlx] fused_strided_attn: {} Narrow×3→Attention rewrites",
6808 fused_count
6809 );
6810 }
6811 }
6812
6813 ThunkSchedule {
6814 thunks,
6815 moe_resident: None,
6816 moe_resident_layers: None,
6817 moe_topk_capture: None,
6818 mask_threshold: cfg.mask_binary_threshold,
6819 mask_neg_inf: cfg.attn_mask_neg_inf,
6820 score_skip: cfg.score_skip_threshold,
6821 compiled_fns,
6822 }
6823}
6824
6825fn get_len(graph: &Graph, id: NodeId) -> usize {
6826 graph.node(id).shape.num_elements().unwrap_or(0)
6827}
6828
6829fn get_static_dims(graph: &Graph, id: NodeId) -> Vec<usize> {
6831 let dims = graph.node(id).shape.dims();
6832 let mut out = Vec::with_capacity(dims.len());
6833 for d in dims {
6834 if let Some(s) = match d {
6835 rlx_ir::Dim::Static(s) => Some(*s),
6836 _ => None,
6837 } {
6838 out.push(s);
6839 } else {
6840 return Vec::new();
6841 }
6842 }
6843 out
6844}
6845
6846fn is_trailing_bias_broadcast(rhs_dims: &[rlx_ir::Dim], out_dims: &[rlx_ir::Dim]) -> bool {
6864 if rhs_dims.len() > out_dims.len() {
6865 return false;
6866 }
6867 let off = out_dims.len() - rhs_dims.len();
6868 for i in 0..rhs_dims.len() {
6869 let r = match rhs_dims[i] {
6870 rlx_ir::Dim::Static(n) => n,
6871 _ => return false,
6872 };
6873 let o = match out_dims[off + i] {
6874 rlx_ir::Dim::Static(n) => n,
6875 _ => return false,
6876 };
6877 if r != o {
6878 return false;
6879 }
6880 }
6881 true
6882}
6883
/// Computes per-axis element strides for reading an input of shape `in_dims`
/// as though it had been broadcast to `out_dims`.
///
/// `in_dims` is right-aligned against `out_dims`; output axes with no input
/// counterpart are treated as size 1. Every size-1 input axis gets stride 0,
/// so the same element is re-read along that axis. Every other axis gets the
/// row-major stride of the input buffer (the product of the input sizes of
/// all axes to its right).
///
/// Returns one stride per output axis.
///
/// # Panics
///
/// - if `in_dims` has more axes than `out_dims`;
/// - if an input axis of size other than 1 differs from the matching output
///   axis;
/// - if a required stride does not fit in `u32`. A plain `as u32` cast would
///   silently truncate here and produce strides that address the wrong
///   elements.
fn broadcast_strides(in_dims: &[usize], out_dims: &[usize]) -> Vec<u32> {
    let r_out = out_dims.len();
    let r_in = in_dims.len();
    assert!(
        r_in <= r_out,
        "broadcast: input rank {r_in} > output rank {r_out}"
    );
    let pad = r_out - r_in;
    let mut strides = vec![0u32; r_out];
    // Stride to assign to the next non-broadcast axis (walking from the
    // innermost axis outwards); `None` once that value no longer fits in u32.
    let mut next_stride: Option<u32> = Some(1);
    for d in (0..r_out).rev() {
        let in_size = if d < pad { 1 } else { in_dims[d - pad] };
        if in_size == 1 {
            // Broadcast axis: the zero-initialised stride is already correct
            // and the running product is unchanged.
            continue;
        }
        assert_eq!(
            in_size, out_dims[d],
            "broadcast: input dim {in_size} doesn't match output dim {} at axis {d}",
            out_dims[d]
        );
        let stride = match next_stride {
            Some(stride) => stride,
            None => panic!("broadcast: stride at axis {} does not fit in u32", d),
        };
        strides[d] = stride;
        // Only fail if an overflowing stride is actually needed by a later
        // axis; overflowing past the outermost axis is harmless.
        next_stride = if in_size <= u32::MAX as usize {
            stride.checked_mul(in_size as u32)
        } else {
            None
        };
    }
    strides
}
6910
6911pub fn execute_compiled(schedule: &ThunkSchedule, arena_buf: &mut [u8]) {
6915 let base = arena_buf.as_mut_ptr();
6916 for f in &schedule.compiled_fns {
6917 f(base);
6918 }
6919}
6920
6921pub fn execute_thunks_active(
6926 schedule: &ThunkSchedule,
6927 _arena_buf: &mut [u8],
6928 _actual: usize,
6929 _upper: usize,
6930) -> bool {
6931 let _ = schedule;
6932 false
6933}
6934
6935struct MoeResidencyGuard;
6937impl Drop for MoeResidencyGuard {
6938 fn drop(&mut self) {
6939 if let Some(stats) = crate::moe_residency::take_stats() {
6940 crate::moe_residency::stash_last_forward_stats(stats);
6941 } else {
6942 crate::moe_residency::clear_mask();
6943 }
6944 }
6945}
6946
6947pub fn execute_thunks(schedule: &ThunkSchedule, arena_buf: &mut [u8]) {
6948 crate::moe_residency::reset_gmm_counters();
6949 if let Some(layers) = schedule.moe_resident_layers.clone() {
6950 crate::moe_residency::set_per_layer_masks(Some(layers));
6951 } else {
6952 crate::moe_residency::set_mask(schedule.moe_resident.clone());
6953 }
6954 if let Some(cap) = schedule.moe_topk_capture.as_ref() {
6955 cap.clear();
6956 }
6957 let _moe_guard = MoeResidencyGuard;
6958 let base = arena_buf.as_mut_ptr();
6959 let mask_thr = schedule.mask_threshold;
6960 let mask_neg = schedule.mask_neg_inf;
6961 let score_thr = schedule.score_skip;
6962 let thunks = &schedule.thunks;
6963 let len = thunks.len();
6964
6965 let max_h = thunks
6967 .iter()
6968 .filter_map(|t| match t {
6969 Thunk::FusedResidualLN { h, .. }
6970 | Thunk::FusedResidualRmsNorm { h, .. }
6971 | Thunk::LayerNorm { h, .. } => Some(*h as usize),
6972 _ => None,
6973 })
6974 .max()
6975 .unwrap_or(0);
6976 let zero_bias = vec![0f32; max_h];
6977
6978 let max_sdpa = thunks
6981 .iter()
6982 .filter_map(|t| match t {
6983 Thunk::Attention {
6984 batch,
6985 seq,
6986 kv_seq,
6987 heads,
6988 head_dim,
6989 ..
6990 } => Some((
6991 *batch as usize,
6992 (*seq as usize).max(*kv_seq as usize),
6993 *heads as usize,
6994 *head_dim as usize,
6995 )),
6996 _ => None,
6997 })
6998 .fold((0, 0, 0, 0), |(mb, ms, mh, md), (b, s, h, d)| {
6999 (mb.max(b), ms.max(s), mh.max(h), md.max(d))
7000 });
7001 let (max_batch, max_seq, max_heads, _max_dh) = max_sdpa;
7002 let max_units = max_batch * max_heads;
7003 let mut sdpa_scores = vec![0f32; max_units * max_seq * max_seq];
7004
7005 let fl = thunks
7007 .iter()
7008 .filter_map(|t| match t {
7009 Thunk::FusedBertLayer {
7010 batch,
7011 seq,
7012 hs,
7013 int_dim,
7014 ..
7015 } => {
7016 let m = (*batch as usize) * (*seq as usize);
7017 let h = *hs as usize;
7018 let id = *int_dim as usize;
7019 Some((m, h, id, m * (*seq as usize)))
7020 }
7021 Thunk::FusedNomicLayer {
7022 batch,
7023 seq,
7024 hs,
7025 int_dim,
7026 ..
7027 } => {
7028 let m = (*batch as usize) * (*seq as usize);
7029 let h = *hs as usize;
7030 let id = *int_dim as usize;
7031 Some((m, h, id, m * (*seq as usize)))
7032 }
7033 _ => None,
7034 })
7035 .fold((0, 0, 0, 0), |(mm, mh, mi, ms), (m, h, id, ss)| {
7036 (mm.max(m), mh.max(h), mi.max(id), ms.max(ss))
7037 });
7038 let (fl_m, fl_h, fl_int, fl_ss) = fl;
7039 let mut fl_qkv = vec![0f32; fl_m * 3 * fl_h];
7040 let mut fl_attn = vec![0f32; fl_m * fl_h];
7041 let mut fl_res = vec![0f32; fl_m * fl_h];
7042 let mut fl_normed = vec![0f32; fl_m * fl_h];
7043 let mut fl_ffn = vec![0f32; fl_m * fl_int.max(2 * fl_int)]; let mut fl_sc = vec![0f32; fl_ss.max(1)];
7045
7046 for i in 0..len {
7047 let thunk = unsafe { thunks.get_unchecked(i) };
7048 match thunk {
7049 Thunk::Nop => {}
7050
7051 Thunk::GaussianSplatRender {
7052 positions_off,
7053 positions_len,
7054 scales_off,
7055 scales_len,
7056 rotations_off,
7057 rotations_len,
7058 opacities_off,
7059 opacities_len,
7060 colors_off,
7061 colors_len,
7062 sh_coeffs_off,
7063 sh_coeffs_len,
7064 meta_off,
7065 dst_off,
7066 dst_len,
7067 width,
7068 height,
7069 tile_size,
7070 radius_scale,
7071 alpha_cutoff,
7072 max_splat_steps,
7073 transmittance_threshold,
7074 max_list_entries,
7075 } => unsafe {
7076 crate::splat::execute_gaussian_splat_render(
7077 *positions_off,
7078 *positions_len,
7079 *scales_off,
7080 *scales_len,
7081 *rotations_off,
7082 *rotations_len,
7083 *opacities_off,
7084 *opacities_len,
7085 *colors_off,
7086 *colors_len,
7087 *sh_coeffs_off,
7088 *sh_coeffs_len,
7089 *meta_off,
7090 *dst_off,
7091 *dst_len,
7092 *width,
7093 *height,
7094 *tile_size,
7095 *radius_scale,
7096 *alpha_cutoff,
7097 *max_splat_steps,
7098 *transmittance_threshold,
7099 *max_list_entries,
7100 base,
7101 );
7102 },
7103
7104 Thunk::GaussianSplatRenderBackward {
7105 positions_off,
7106 positions_len,
7107 scales_off,
7108 scales_len,
7109 rotations_off,
7110 rotations_len,
7111 opacities_off,
7112 opacities_len,
7113 colors_off,
7114 colors_len,
7115 sh_coeffs_off,
7116 sh_coeffs_len,
7117 meta_off,
7118 d_loss_off,
7119 d_loss_len,
7120 packed_off,
7121 packed_len,
7122 width,
7123 height,
7124 tile_size,
7125 radius_scale,
7126 alpha_cutoff,
7127 max_splat_steps,
7128 transmittance_threshold,
7129 max_list_entries,
7130 loss_grad_clip,
7131 sh_band,
7132 max_anisotropy,
7133 } => unsafe {
7134 crate::splat::execute_gaussian_splat_render_backward(
7135 *positions_off,
7136 *positions_len,
7137 *scales_off,
7138 *scales_len,
7139 *rotations_off,
7140 *rotations_len,
7141 *opacities_off,
7142 *opacities_len,
7143 *colors_off,
7144 *colors_len,
7145 *sh_coeffs_off,
7146 *sh_coeffs_len,
7147 *meta_off,
7148 *d_loss_off,
7149 *d_loss_len,
7150 *packed_off,
7151 *packed_len,
7152 *width,
7153 *height,
7154 *tile_size,
7155 *radius_scale,
7156 *alpha_cutoff,
7157 *max_splat_steps,
7158 *transmittance_threshold,
7159 *max_list_entries,
7160 *loss_grad_clip,
7161 *sh_band,
7162 *max_anisotropy,
7163 base,
7164 );
7165 },
7166
7167 Thunk::GaussianSplatPrepare {
7168 positions_off,
7169 positions_len,
7170 scales_off,
7171 scales_len,
7172 rotations_off,
7173 rotations_len,
7174 opacities_off,
7175 opacities_len,
7176 colors_off,
7177 colors_len,
7178 sh_coeffs_off,
7179 sh_coeffs_len,
7180 meta_off,
7181 meta_len,
7182 prep_off,
7183 prep_len,
7184 width,
7185 height,
7186 tile_size,
7187 radius_scale,
7188 alpha_cutoff,
7189 max_splat_steps,
7190 transmittance_threshold,
7191 max_list_entries,
7192 } => unsafe {
7193 crate::splat::execute_gaussian_splat_prepare(
7194 *positions_off,
7195 *positions_len,
7196 *scales_off,
7197 *scales_len,
7198 *rotations_off,
7199 *rotations_len,
7200 *opacities_off,
7201 *opacities_len,
7202 *colors_off,
7203 *colors_len,
7204 *sh_coeffs_off,
7205 *sh_coeffs_len,
7206 *meta_off,
7207 *meta_len,
7208 *prep_off,
7209 *prep_len,
7210 *width,
7211 *height,
7212 *tile_size,
7213 *radius_scale,
7214 *alpha_cutoff,
7215 *max_splat_steps,
7216 *transmittance_threshold,
7217 *max_list_entries,
7218 base,
7219 );
7220 },
7221
7222 Thunk::GaussianSplatRasterize {
7223 prep_off,
7224 prep_len,
7225 meta_off,
7226 meta_len,
7227 dst_off,
7228 dst_len,
7229 count,
7230 width,
7231 height,
7232 tile_size,
7233 alpha_cutoff,
7234 max_splat_steps,
7235 transmittance_threshold,
7236 max_list_entries,
7237 } => unsafe {
7238 crate::splat::execute_gaussian_splat_rasterize(
7239 *prep_off,
7240 *prep_len,
7241 *meta_off,
7242 *meta_len,
7243 *dst_off,
7244 *dst_len,
7245 *count,
7246 *width,
7247 *height,
7248 *tile_size,
7249 *alpha_cutoff,
7250 *max_splat_steps,
7251 *transmittance_threshold,
7252 *max_list_entries,
7253 base,
7254 );
7255 },
7256
7257 Thunk::Fft1d {
7258 src,
7259 dst,
7260 outer,
7261 n_complex,
7262 inverse,
7263 dtype,
7264 } => unsafe {
7265 match dtype {
7266 rlx_ir::DType::F64 => execute_fft1d_f64(
7267 *src,
7268 *dst,
7269 *outer as usize,
7270 *n_complex as usize,
7271 *inverse,
7272 base,
7273 ),
7274 rlx_ir::DType::F32 => execute_fft1d_f32(
7275 *src,
7276 *dst,
7277 *outer as usize,
7278 *n_complex as usize,
7279 *inverse,
7280 base,
7281 ),
7282 other => panic!("Op::Fft on CPU requires F32/F64, got {other:?}"),
7283 }
7284 },
7285
7286 Thunk::CustomFn {
7290 body,
7291 body_init,
7292 inputs,
7293 body_output_off,
7294 outer_output_off,
7295 out_bytes,
7296 } => {
7297 let mut body_buf: Vec<u8> = (**body_init).clone();
7298 unsafe {
7299 for (body_in_off, outer_in_off, n_bytes) in inputs.iter() {
7300 let src = (base as *const u8).add(*outer_in_off);
7301 let dst = body_buf.as_mut_ptr().add(*body_in_off);
7302 std::ptr::copy_nonoverlapping(src, dst, *n_bytes as usize);
7303 }
7304 }
7305 execute_thunks(body, &mut body_buf);
7306 unsafe {
7307 let src = body_buf.as_ptr().add(*body_output_off);
7308 let dst = base.add(*outer_output_off);
7309 std::ptr::copy_nonoverlapping(src, dst, *out_bytes as usize);
7310 }
7311 }
7312
7313 Thunk::Sgemm { a, b, c, m, k, n } => {
7314 let (m, k, n) = (*m as usize, *k as usize, *n as usize);
7315 unsafe {
7316 crate::blas::sgemm_auto(
7317 sl(*a, base, m * k),
7318 sl(*b, base, k * n),
7319 sl_mut(*c, base, m * n),
7320 m,
7321 k,
7322 n,
7323 );
7324 }
7325 }
7326
7327 Thunk::DenseSolveF64 { a, b, x, n, nrhs } => {
7328 let (n_, nrhs_) = (*n as usize, *nrhs as usize);
7329 unsafe {
7335 let a_src = sl_f64(*a, base, n_ * n_);
7336 let b_src = sl_f64(*b, base, n_ * nrhs_);
7337 let mut a_scratch: Vec<f64> = a_src.to_vec();
7338 let mut x_buf: Vec<f64> = b_src.to_vec();
7339 let info = crate::blas::dgesv(&mut a_scratch, &mut x_buf, n_, nrhs_);
7340 if info != 0 {
7341 panic!(
7342 "DenseSolveF64: dgesv reported singular matrix \
7343 (info={info}, n={n_}, nrhs={nrhs_})"
7344 );
7345 }
7346 let dst = sl_mut_f64(*x, base, n_ * nrhs_);
7347 dst.copy_from_slice(&x_buf);
7348 }
7349 }
7350
7351 Thunk::DenseSolveF32 { a, b, x, n, nrhs } => {
7352 let (n_, nrhs_) = (*n as usize, *nrhs as usize);
7353 unsafe {
7354 let a_src = sl(*a, base, n_ * n_);
7355 let b_src = sl(*b, base, n_ * nrhs_);
7356 let mut a_scratch: Vec<f32> = a_src.to_vec();
7357 let mut x_buf: Vec<f32> = b_src.to_vec();
7358 let info = crate::blas::sgesv(&mut a_scratch, &mut x_buf, n_, nrhs_);
7359 if info != 0 {
7360 panic!(
7361 "DenseSolveF32: sgesv reported singular matrix \
7362 (info={info}, n={n_}, nrhs={nrhs_})"
7363 );
7364 }
7365 let dst = sl_mut(*x, base, n_ * nrhs_);
7366 dst.copy_from_slice(&x_buf);
7367 }
7368 }
7369
7370 Thunk::BatchedDenseSolveF64 {
7371 a,
7372 b,
7373 x,
7374 batch,
7375 n,
7376 nrhs,
7377 } => {
7378 let (b_, n_, nrhs_) = (*batch as usize, *n as usize, *nrhs as usize);
7385 let a_stride = n_ * n_;
7386 let b_stride = n_ * nrhs_;
7387 unsafe {
7388 let a_full = sl_f64(*a, base, b_ * a_stride);
7389 let b_full = sl_f64(*b, base, b_ * b_stride);
7390 let x_full = sl_mut_f64(*x, base, b_ * b_stride);
7391 for bi in 0..b_ {
7392 let mut a_scratch: Vec<f64> =
7393 a_full[bi * a_stride..(bi + 1) * a_stride].to_vec();
7394 let mut x_buf: Vec<f64> =
7395 b_full[bi * b_stride..(bi + 1) * b_stride].to_vec();
7396 let info = crate::blas::dgesv(&mut a_scratch, &mut x_buf, n_, nrhs_);
7397 if info != 0 {
7398 panic!(
7399 "BatchedDenseSolveF64: slice {bi} \
7400 singular (info={info}, n={n_}, nrhs={nrhs_})"
7401 );
7402 }
7403 x_full[bi * b_stride..(bi + 1) * b_stride].copy_from_slice(&x_buf);
7404 }
7405 }
7406 }
7407
7408 Thunk::BatchedDenseSolveF32 {
7409 a,
7410 b,
7411 x,
7412 batch,
7413 n,
7414 nrhs,
7415 } => {
7416 let (b_, n_, nrhs_) = (*batch as usize, *n as usize, *nrhs as usize);
7417 let a_stride = n_ * n_;
7418 let b_stride = n_ * nrhs_;
7419 unsafe {
7420 let a_full = sl(*a, base, b_ * a_stride);
7421 let b_full = sl(*b, base, b_ * b_stride);
7422 let x_full = sl_mut(*x, base, b_ * b_stride);
7423 for bi in 0..b_ {
7424 let mut a_scratch = a_full[bi * a_stride..(bi + 1) * a_stride].to_vec();
7425 let mut x_buf = b_full[bi * b_stride..(bi + 1) * b_stride].to_vec();
7426 let info = crate::blas::sgesv(&mut a_scratch, &mut x_buf, n_, nrhs_);
7427 if info != 0 {
7428 panic!("BatchedDenseSolveF32: slice {bi} singular (info={info})");
7429 }
7430 x_full[bi * b_stride..(bi + 1) * b_stride].copy_from_slice(&x_buf);
7431 }
7432 }
7433 }
7434
7435 Thunk::BatchedDgemmF64 {
7436 a,
7437 b,
7438 c,
7439 batch,
7440 m,
7441 k,
7442 n,
7443 } => {
7444 let (b_, m_, k_, n_) = (*batch as usize, *m as usize, *k as usize, *n as usize);
7445 let a_stride = m_ * k_;
7446 let b_stride = k_ * n_;
7447 let c_stride = m_ * n_;
7448 unsafe {
7449 let a_full = sl_f64(*a, base, b_ * a_stride);
7450 let b_full = sl_f64(*b, base, b_ * b_stride);
7451 let c_full = sl_mut_f64(*c, base, b_ * c_stride);
7452 for bi in 0..b_ {
7453 let a_slice = &a_full[bi * a_stride..(bi + 1) * a_stride];
7454 let b_slice = &b_full[bi * b_stride..(bi + 1) * b_stride];
7455 let c_slice = &mut c_full[bi * c_stride..(bi + 1) * c_stride];
7456 crate::blas::dgemm(a_slice, b_slice, c_slice, m_, k_, n_);
7457 }
7458 }
7459 }
7460
7461 Thunk::BatchedSgemm {
7462 a,
7463 b,
7464 c,
7465 batch,
7466 m,
7467 k,
7468 n,
7469 } => {
7470 let (b_, m_, k_, n_) = (*batch as usize, *m as usize, *k as usize, *n as usize);
7471 let a_stride = m_ * k_;
7472 let b_stride = k_ * n_;
7473 let c_stride = m_ * n_;
7474 unsafe {
7475 let a_full = sl(*a, base, b_ * a_stride);
7476 let b_full = sl(*b, base, b_ * b_stride);
7477 let c_full = sl_mut(*c, base, b_ * c_stride);
7478 for bi in 0..b_ {
7479 let a_slice = &a_full[bi * a_stride..(bi + 1) * a_stride];
7480 let b_slice = &b_full[bi * b_stride..(bi + 1) * b_stride];
7481 let c_slice = &mut c_full[bi * c_stride..(bi + 1) * c_stride];
7482 crate::blas::sgemm_auto(a_slice, b_slice, c_slice, m_, k_, n_);
7483 }
7484 }
7485 }
7486
7487 Thunk::Dgemm { a, b, c, m, k, n } => {
7488 let (m, k, n) = (*m as usize, *k as usize, *n as usize);
7489 unsafe {
7490 crate::blas::dgemm(
7491 sl_f64(*a, base, m * k),
7492 sl_f64(*b, base, k * n),
7493 sl_mut_f64(*c, base, m * n),
7494 m,
7495 k,
7496 n,
7497 );
7498 }
7499 }
7500
7501 Thunk::TransposeF64 {
7502 src,
7503 dst,
7504 in_total,
7505 out_dims,
7506 in_strides,
7507 } => unsafe {
7508 let inp = sl_f64(*src, base, *in_total as usize);
7509 let out_total: usize = out_dims.iter().map(|d| *d as usize).product();
7510 let out = sl_mut_f64(*dst, base, out_total);
7511 transpose_walk_f64(inp, out, out_dims, in_strides);
7512 },
7513
7514 Thunk::ActivationF64 {
7515 src,
7516 dst,
7517 len,
7518 kind,
7519 } => {
7520 let len = *len as usize;
7521 unsafe {
7522 let inp = sl_f64(*src, base, len);
7523 let out = sl_mut_f64(*dst, base, len);
7524 apply_activation_f64(inp, out, *kind);
7525 }
7526 }
7527
7528 Thunk::ReduceSumF64 {
7529 src,
7530 dst,
7531 outer,
7532 reduced,
7533 inner,
7534 } => {
7535 let (o, r, n) = (*outer as usize, *reduced as usize, *inner as usize);
7536 unsafe {
7537 let inp = sl_f64(*src, base, o * r * n);
7538 let out = sl_mut_f64(*dst, base, o * n);
7539 reduce_sum_f64(inp, out, o, r, n);
7540 }
7541 }
7542
7543 Thunk::CopyF64 { src, dst, len } => {
7544 let len = *len as usize;
7545 if *src == *dst { } else {
7547 unsafe {
7548 let s = sl_f64(*src, base, len);
7549 let d = sl_mut_f64(*dst, base, len);
7550 d.copy_from_slice(s);
7551 }
7552 }
7553 }
7554
7555 Thunk::BinaryFullF64 {
7556 lhs,
7557 rhs,
7558 dst,
7559 len,
7560 lhs_len,
7561 rhs_len,
7562 op,
7563 out_dims_bcast,
7564 bcast_lhs_strides,
7565 bcast_rhs_strides,
7566 } => {
7567 let len = *len as usize;
7568 let lhs_len = *lhs_len as usize;
7569 let rhs_len = *rhs_len as usize;
7570 unsafe {
7571 let l = sl_f64(*lhs, base, lhs_len);
7572 let r = sl_f64(*rhs, base, rhs_len);
7573 let d = sl_mut_f64(*dst, base, len);
7574 if lhs_len == len && rhs_len == len {
7575 for i in 0..len {
7576 d[i] = binary_op_f64(*op, l[i], r[i]);
7577 }
7578 } else if !out_dims_bcast.is_empty() {
7579 let rank = out_dims_bcast.len();
7583 let mut coords = vec![0u32; rank];
7584 for i in 0..len {
7585 let mut rem = i;
7586 for ax in (0..rank).rev() {
7587 let sz = out_dims_bcast[ax] as usize;
7588 coords[ax] = (rem % sz) as u32;
7589 rem /= sz;
7590 }
7591 let mut li: usize = 0;
7592 let mut ri: usize = 0;
7593 for ax in 0..rank {
7594 li += coords[ax] as usize * bcast_lhs_strides[ax] as usize;
7595 ri += coords[ax] as usize * bcast_rhs_strides[ax] as usize;
7596 }
7597 d[i] = binary_op_f64(*op, l[li], r[ri]);
7598 }
7599 } else {
7600 for i in 0..len {
7605 d[i] = binary_op_f64(*op, l[i % lhs_len], r[i % rhs_len]);
7606 }
7607 }
7608 }
7609 }
7610
7611 Thunk::BinaryFullC64 {
7612 lhs,
7613 rhs,
7614 dst,
7615 len,
7616 lhs_len,
7617 rhs_len,
7618 op,
7619 out_dims_bcast,
7620 bcast_lhs_strides,
7621 bcast_rhs_strides,
7622 } => {
7623 let n_out = *len as usize;
7629 let n_l = *lhs_len as usize;
7630 let n_r = *rhs_len as usize;
7631 unsafe {
7632 let l = sl(*lhs, base, 2 * n_l);
7633 let r = sl(*rhs, base, 2 * n_r);
7634 let d = sl_mut(*dst, base, 2 * n_out);
7635 let do_c64 = |a_re: f32, a_im: f32, b_re: f32, b_im: f32| -> (f32, f32) {
7636 match op {
7637 BinaryOp::Add => (a_re + b_re, a_im + b_im),
7638 BinaryOp::Sub => (a_re - b_re, a_im - b_im),
7639 BinaryOp::Mul => (a_re * b_re - a_im * b_im, a_re * b_im + a_im * b_re),
7640 BinaryOp::Div => {
7641 let denom = b_re * b_re + b_im * b_im;
7642 (
7643 (a_re * b_re + a_im * b_im) / denom,
7644 (a_im * b_re - a_re * b_im) / denom,
7645 )
7646 }
7647 BinaryOp::Max | BinaryOp::Min | BinaryOp::Pow => {
7648 unreachable!("C64 max/min/pow rejected at lowering")
7649 }
7650 }
7651 };
7652 if n_l == n_out && n_r == n_out {
7653 for i in 0..n_out {
7654 let (re, im) = do_c64(l[2 * i], l[2 * i + 1], r[2 * i], r[2 * i + 1]);
7655 d[2 * i] = re;
7656 d[2 * i + 1] = im;
7657 }
7658 } else if !out_dims_bcast.is_empty() {
7659 let rank = out_dims_bcast.len();
7663 let mut coords = vec![0u32; rank];
7664 for i in 0..n_out {
7665 let mut rem = i;
7666 for ax in (0..rank).rev() {
7667 let sz = out_dims_bcast[ax] as usize;
7668 coords[ax] = (rem % sz) as u32;
7669 rem /= sz;
7670 }
7671 let mut li: usize = 0;
7672 let mut ri: usize = 0;
7673 for ax in 0..rank {
7674 li += coords[ax] as usize * bcast_lhs_strides[ax] as usize;
7675 ri += coords[ax] as usize * bcast_rhs_strides[ax] as usize;
7676 }
7677 let (re, im) =
7678 do_c64(l[2 * li], l[2 * li + 1], r[2 * ri], r[2 * ri + 1]);
7679 d[2 * i] = re;
7680 d[2 * i + 1] = im;
7681 }
7682 } else {
7683 for i in 0..n_out {
7685 let li = if n_l == 1 { 0 } else { i % n_l };
7686 let ri = if n_r == 1 { 0 } else { i % n_r };
7687 let (re, im) =
7688 do_c64(l[2 * li], l[2 * li + 1], r[2 * ri], r[2 * ri + 1]);
7689 d[2 * i] = re;
7690 d[2 * i + 1] = im;
7691 }
7692 }
7693 }
7694 }
7695
7696 Thunk::ComplexNormSqF32 { src, dst, len } => {
7697 let n = *len as usize;
7698 unsafe {
7699 let s = sl(*src, base, 2 * n);
7700 let d = sl_mut(*dst, base, n);
7701 for i in 0..n {
7702 let re = s[2 * i];
7703 let im = s[2 * i + 1];
7704 d[i] = re * re + im * im;
7705 }
7706 }
7707 }
7708
7709 Thunk::ComplexNormSqBackwardF32 { z, g, dz, len } => {
7710 let n = *len as usize;
7713 unsafe {
7714 let zb = sl(*z, base, 2 * n);
7715 let gb = sl(*g, base, n);
7716 let db = sl_mut(*dz, base, 2 * n);
7717 for i in 0..n {
7718 let re = zb[2 * i];
7719 let im = zb[2 * i + 1];
7720 let gv = gb[i];
7721 db[2 * i] = gv * re;
7722 db[2 * i + 1] = gv * im;
7723 }
7724 }
7725 }
7726
7727 Thunk::ConjugateC64 { src, dst, len } => {
7728 let n = *len as usize;
7729 unsafe {
7730 let s = sl(*src, base, 2 * n);
7731 let d = sl_mut(*dst, base, 2 * n);
7732 for i in 0..n {
7733 d[2 * i] = s[2 * i];
7734 d[2 * i + 1] = -s[2 * i + 1];
7735 }
7736 }
7737 }
7738
7739 Thunk::ActivationC64 {
7740 src,
7741 dst,
7742 len,
7743 kind,
7744 } => {
7745 let n = *len as usize;
7746 unsafe {
7747 let s = sl(*src, base, 2 * n);
7748 let d = sl_mut(*dst, base, 2 * n);
7749 for i in 0..n {
7750 let a = s[2 * i];
7751 let b = s[2 * i + 1];
7752 let (re, im) = match kind {
7753 Activation::Neg => (-a, -b),
7754 Activation::Exp => {
7755 let ea = a.exp();
7757 (ea * b.cos(), ea * b.sin())
7758 }
7759 Activation::Log => {
7760 let r = (a * a + b * b).sqrt();
7762 (r.ln(), b.atan2(a))
7763 }
7764 Activation::Sqrt => {
7765 let r = (a * a + b * b).sqrt();
7768 let re = ((r + a) * 0.5).max(0.0).sqrt();
7769 let im_mag = ((r - a) * 0.5).max(0.0).sqrt();
7770 let im = if b >= 0.0 { im_mag } else { -im_mag };
7771 (re, im)
7772 }
7773 _ => unreachable!("non-C64 activation kind survived lowering"),
7774 };
7775 d[2 * i] = re;
7776 d[2 * i + 1] = im;
7777 }
7778 }
7779 }
7780
7781 Thunk::Scan {
7782 body,
7783 body_init,
7784 body_input_off,
7785 body_output_off,
7786 outer_init_off,
7787 outer_final_off,
7788 length,
7789 carry_bytes,
7790 save_trajectory,
7791 xs_inputs,
7792 bcast_inputs,
7793 num_checkpoints,
7794 } => {
7795 let cb = *carry_bytes as usize;
7796 let n_steps = *length as usize;
7797 let k_total = if *num_checkpoints == 0 || *num_checkpoints == *length {
7801 n_steps } else {
7803 *num_checkpoints as usize
7804 };
7805 let checkpoint_t_for_k = |k: usize| -> usize {
7806 if k_total == n_steps {
7807 k
7808 } else {
7809 ((k + 1) * n_steps)
7810 .div_ceil(k_total)
7811 .saturating_sub(1)
7812 .min(n_steps - 1)
7813 }
7814 };
7815 let mut next_k = 0usize;
7816
7817 let mut body_buf: Vec<u8> = (**body_init).clone();
7818 unsafe {
7819 std::ptr::copy_nonoverlapping(
7820 base.add(*outer_init_off),
7821 body_buf.as_mut_ptr().add(*body_input_off),
7822 cb,
7823 );
7824 for (body_b_off, outer_b_off, total_bytes) in bcast_inputs.iter() {
7828 std::ptr::copy_nonoverlapping(
7829 base.add(*outer_b_off),
7830 body_buf.as_mut_ptr().add(*body_b_off),
7831 *total_bytes as usize,
7832 );
7833 }
7834 }
7835 for t in 0..n_steps {
7836 for (body_x_off, outer_xs_off, per_step_bytes) in xs_inputs.iter() {
7837 let psb = *per_step_bytes as usize;
7838 unsafe {
7839 std::ptr::copy_nonoverlapping(
7840 base.add(*outer_xs_off + t * psb),
7841 body_buf.as_mut_ptr().add(*body_x_off),
7842 psb,
7843 );
7844 }
7845 }
7846
7847 execute_thunks(body, &mut body_buf);
7848
7849 if *save_trajectory && next_k < k_total && t == checkpoint_t_for_k(next_k) {
7850 unsafe {
7851 std::ptr::copy_nonoverlapping(
7852 body_buf.as_ptr().add(*body_output_off),
7853 base.add(*outer_final_off + next_k * cb),
7854 cb,
7855 );
7856 }
7857 next_k += 1;
7858 }
7859
7860 if *body_output_off != *body_input_off {
7861 body_buf
7862 .copy_within(*body_output_off..*body_output_off + cb, *body_input_off);
7863 }
7864 }
7865
7866 if !*save_trajectory {
7867 unsafe {
7869 std::ptr::copy_nonoverlapping(
7870 body_buf.as_ptr().add(*body_output_off),
7871 base.add(*outer_final_off),
7872 cb,
7873 );
7874 }
7875 }
7876 }
7877
7878 Thunk::ScanBackward {
7879 body_vjp,
7880 body_init,
7881 body_carry_in_off,
7882 body_x_offs,
7883 body_d_output_off,
7884 body_dcarry_out_off,
7885 outer_init_off,
7886 outer_traj_off,
7887 outer_upstream_off,
7888 outer_xs_offs,
7889 outer_dinit_off,
7890 length,
7891 carry_bytes,
7892 save_trajectory,
7893 num_checkpoints,
7894 forward_body,
7895 forward_body_init,
7896 forward_body_carry_in_off,
7897 forward_body_output_off,
7898 forward_body_x_offs,
7899 carry_elem_size,
7900 } => {
7901 let cb = *carry_bytes as usize;
7914 let n_steps = *length as usize;
7915 let k_total = *num_checkpoints as usize;
7916 let is_recursive = k_total != 0 && k_total != n_steps;
7917 let checkpoint_t_for_k = |k: usize| -> usize {
7918 ((k + 1) * n_steps)
7919 .div_ceil(k_total)
7920 .saturating_sub(1)
7921 .min(n_steps - 1)
7922 };
7923
7924 let mut fwd_buf: Vec<u8> = if is_recursive {
7925 (**forward_body_init.as_ref().unwrap()).clone()
7926 } else {
7927 Vec::new()
7928 };
7929
7930 let mut dcarry: Vec<u8> = vec![0u8; cb];
7931 if !*save_trajectory {
7932 unsafe {
7933 std::ptr::copy_nonoverlapping(
7934 base.add(*outer_upstream_off),
7935 dcarry.as_mut_ptr(),
7936 cb,
7937 );
7938 }
7939 }
7940
7941 let mut body_buf: Vec<u8> = (**body_init).clone();
7942
7943 let process_iter =
7948 |t: usize, carry_in: &[u8], dcarry: &mut Vec<u8>, body_buf: &mut Vec<u8>| {
7949 if *save_trajectory {
7950 unsafe {
7951 let up_off = *outer_upstream_off + t * cb;
7952 match *carry_elem_size {
7953 4 => {
7954 let up_ptr = base.add(up_off) as *const f32;
7955 let dc_ptr = dcarry.as_mut_ptr() as *mut f32;
7956 let n_elems = cb / 4;
7957 for i in 0..n_elems {
7958 *dc_ptr.add(i) += *up_ptr.add(i);
7959 }
7960 }
7961 8 => {
7962 let up_ptr = base.add(up_off) as *const f64;
7963 let dc_ptr = dcarry.as_mut_ptr() as *mut f64;
7964 let n_elems = cb / 8;
7965 for i in 0..n_elems {
7966 *dc_ptr.add(i) += *up_ptr.add(i);
7967 }
7968 }
7969 other => panic!(
7970 "ScanBackward: unsupported carry elem size {other} \
7971 (only f32/f64 carries are supported today)"
7972 ),
7973 }
7974 }
7975 }
7976 body_buf[*body_carry_in_off..*body_carry_in_off + cb]
7977 .copy_from_slice(carry_in);
7978 unsafe {
7979 for (i, body_x_off) in body_x_offs.iter().enumerate() {
7980 let (outer_xs_off, per_step_bytes) = outer_xs_offs[i];
7981 let psb = per_step_bytes as usize;
7982 std::ptr::copy_nonoverlapping(
7983 base.add(outer_xs_off + t * psb),
7984 body_buf.as_mut_ptr().add(*body_x_off),
7985 psb,
7986 );
7987 }
7988 std::ptr::copy_nonoverlapping(
7989 dcarry.as_ptr(),
7990 body_buf.as_mut_ptr().add(*body_d_output_off),
7991 cb,
7992 );
7993 }
7994 execute_thunks(body_vjp, body_buf);
7995 unsafe {
7996 std::ptr::copy_nonoverlapping(
7997 body_buf.as_ptr().add(*body_dcarry_out_off),
7998 dcarry.as_mut_ptr(),
7999 cb,
8000 );
8001 }
8002 };
8003
8004 if is_recursive {
8005 let leaf_threshold = 4usize;
8013 let fb_sched = forward_body.as_ref().unwrap();
8014 let fb_init = forward_body_init.as_ref().unwrap().as_slice();
8015 let mut segment_end = n_steps - 1;
8016 for seg_k in (0..k_total).rev() {
8017 let segment_start = if seg_k == 0 {
8018 0
8019 } else {
8020 checkpoint_t_for_k(seg_k - 1) + 1
8021 };
8022 let mut anchor: Vec<u8> = vec![0u8; cb];
8023 unsafe {
8024 let src = if seg_k == 0 {
8025 base.add(*outer_init_off)
8026 } else {
8027 base.add(*outer_traj_off + (seg_k - 1) * cb)
8028 };
8029 std::ptr::copy_nonoverlapping(src, anchor.as_mut_ptr(), cb);
8030 }
8031 let mut leaf_action = |t: usize, carry_in: &[u8]| {
8034 process_iter(t, carry_in, &mut dcarry, &mut body_buf);
8035 };
8036 unsafe {
8037 griewank_process_segment(
8038 segment_start,
8039 segment_end,
8040 &anchor,
8041 cb,
8042 fb_sched,
8043 fb_init,
8044 *forward_body_carry_in_off,
8045 *forward_body_output_off,
8046 forward_body_x_offs,
8047 base,
8048 outer_xs_offs,
8049 &mut fwd_buf,
8050 leaf_threshold,
8051 &mut leaf_action,
8052 );
8053 }
8054 if seg_k == 0 {
8055 break;
8056 }
8057 segment_end = segment_start - 1;
8058 }
8059 } else {
8060 let mut carry_buf: Vec<u8> = vec![0u8; cb];
8063 for t in (0..n_steps).rev() {
8064 unsafe {
8065 let src = if t == 0 {
8066 base.add(*outer_init_off)
8067 } else {
8068 base.add(*outer_traj_off + (t - 1) * cb)
8069 };
8070 std::ptr::copy_nonoverlapping(src, carry_buf.as_mut_ptr(), cb);
8071 }
8072 process_iter(t, &carry_buf, &mut dcarry, &mut body_buf);
8073 }
8074 }
8075
8076 unsafe {
8077 std::ptr::copy_nonoverlapping(dcarry.as_ptr(), base.add(*outer_dinit_off), cb);
8078 }
8079 }
8080
8081 Thunk::ScanBackwardXs {
8082 body_vjp,
8083 body_init,
8084 body_carry_in_off,
8085 body_x_offs,
8086 body_d_output_off,
8087 body_dcarry_out_off,
8088 body_dxs_out_off,
8089 outer_init_off,
8090 outer_traj_off,
8091 outer_upstream_off,
8092 outer_xs_offs,
8093 outer_dxs_off,
8094 length,
8095 carry_bytes,
8096 carry_elem_size,
8097 per_step_bytes,
8098 save_trajectory,
8099 num_checkpoints,
8100 forward_body,
8101 forward_body_init,
8102 forward_body_carry_in_off,
8103 forward_body_output_off,
8104 forward_body_x_offs,
8105 } => {
8106 let cb = *carry_bytes as usize;
8107 let psb = *per_step_bytes as usize;
8108 let n_steps = *length as usize;
8109 let k_total = *num_checkpoints as usize;
8110 let is_recursive = k_total != 0 && k_total != n_steps;
8111 let checkpoint_t_for_k = |k: usize| -> usize {
8112 ((k + 1) * n_steps)
8113 .div_ceil(k_total)
8114 .saturating_sub(1)
8115 .min(n_steps - 1)
8116 };
8117
8118 let mut fwd_buf: Vec<u8> = if is_recursive {
8122 (**forward_body_init.as_ref().unwrap()).clone()
8123 } else {
8124 Vec::new()
8125 };
8126 let mut seg_cache: Vec<u8> = Vec::new();
8127 let mut seg_start_t: usize = usize::MAX;
8128 let mut seg_count: usize = 0;
8129 let recompute_carry_t =
8130 |t: usize,
8131 dst: &mut [u8],
8132 fwd_buf: &mut Vec<u8>,
8133 seg_cache: &mut Vec<u8>,
8134 seg_start_t: &mut usize,
8135 seg_count: &mut usize| {
8136 if !is_recursive {
8137 unsafe {
8138 let src = if t == 0 {
8139 base.add(*outer_init_off)
8140 } else {
8141 base.add(*outer_traj_off + (t - 1) * cb)
8142 };
8143 std::ptr::copy_nonoverlapping(src, dst.as_mut_ptr(), cb);
8144 }
8145 return;
8146 }
8147 if *seg_start_t != usize::MAX
8148 && t >= *seg_start_t
8149 && t < *seg_start_t + *seg_count
8150 {
8151 let off = (t - *seg_start_t) * cb;
8152 dst.copy_from_slice(&seg_cache[off..off + cb]);
8153 return;
8154 }
8155 let seg_k = (0..k_total)
8156 .find(|&k| t <= checkpoint_t_for_k(k))
8157 .unwrap_or(k_total - 1);
8158 let (anchor_t, anchor_ptr): (usize, *const u8) = if seg_k == 0 {
8159 (0, unsafe { base.add(*outer_init_off) as *const u8 })
8160 } else {
8161 let prev_ck = checkpoint_t_for_k(seg_k - 1);
8162 (prev_ck + 1, unsafe {
8163 base.add(*outer_traj_off + (seg_k - 1) * cb) as *const u8
8164 })
8165 };
8166 let seg_end_t = checkpoint_t_for_k(seg_k);
8167 let seg_size = seg_end_t - anchor_t + 1;
8168
8169 fwd_buf.copy_from_slice(forward_body_init.as_ref().unwrap());
8170 unsafe {
8171 std::ptr::copy_nonoverlapping(
8172 anchor_ptr,
8173 fwd_buf.as_mut_ptr().add(*forward_body_carry_in_off),
8174 cb,
8175 );
8176 }
8177 seg_cache.resize(seg_size * cb, 0u8);
8178 seg_cache[0..cb].copy_from_slice(
8179 &fwd_buf[*forward_body_carry_in_off..*forward_body_carry_in_off + cb],
8180 );
8181 let fb_sched = forward_body.as_ref().unwrap();
8182 for i in 1..seg_size {
8183 let cur_iter = anchor_t + i - 1;
8184 for (idx, fb_x_off) in forward_body_x_offs.iter().enumerate() {
8185 let (outer_xs_off, x_psb) = outer_xs_offs[idx];
8186 let xb = x_psb as usize;
8187 unsafe {
8188 std::ptr::copy_nonoverlapping(
8189 base.add(outer_xs_off + cur_iter * xb),
8190 fwd_buf.as_mut_ptr().add(*fb_x_off),
8191 xb,
8192 );
8193 }
8194 }
8195 execute_thunks(fb_sched, fwd_buf);
8196 if *forward_body_output_off != *forward_body_carry_in_off {
8197 fwd_buf.copy_within(
8198 *forward_body_output_off..*forward_body_output_off + cb,
8199 *forward_body_carry_in_off,
8200 );
8201 }
8202 let cache_off = i * cb;
8203 seg_cache[cache_off..cache_off + cb].copy_from_slice(
8204 &fwd_buf
8205 [*forward_body_carry_in_off..*forward_body_carry_in_off + cb],
8206 );
8207 }
8208 *seg_start_t = anchor_t;
8209 *seg_count = seg_size;
8210
8211 let off = (t - anchor_t) * cb;
8212 dst.copy_from_slice(&seg_cache[off..off + cb]);
8213 };
8214
8215 let mut dcarry: Vec<u8> = vec![0u8; cb];
8216 if !*save_trajectory {
8217 unsafe {
8218 std::ptr::copy_nonoverlapping(
8219 base.add(*outer_upstream_off),
8220 dcarry.as_mut_ptr(),
8221 cb,
8222 );
8223 }
8224 }
8225
8226 let mut body_buf: Vec<u8> = (**body_init).clone();
8227
8228 for t in (0..n_steps).rev() {
8229 if *save_trajectory {
8230 unsafe {
8231 let up_off = *outer_upstream_off + t * cb;
8232 match *carry_elem_size {
8233 4 => {
8234 let up_ptr = base.add(up_off) as *const f32;
8235 let dc_ptr = dcarry.as_mut_ptr() as *mut f32;
8236 let n_elems = cb / 4;
8237 for i in 0..n_elems {
8238 *dc_ptr.add(i) += *up_ptr.add(i);
8239 }
8240 }
8241 8 => {
8242 let up_ptr = base.add(up_off) as *const f64;
8243 let dc_ptr = dcarry.as_mut_ptr() as *mut f64;
8244 let n_elems = cb / 8;
8245 for i in 0..n_elems {
8246 *dc_ptr.add(i) += *up_ptr.add(i);
8247 }
8248 }
8249 other => panic!(
8250 "ScanBackwardXs: unsupported carry elem size {other} \
8251 (only f32/f64 carries are supported today)"
8252 ),
8253 }
8254 }
8255 }
8256
8257 let carry_dst_start = *body_carry_in_off;
8261 {
8262 let carry_slice = &mut body_buf[carry_dst_start..carry_dst_start + cb];
8263 recompute_carry_t(
8264 t,
8265 carry_slice,
8266 &mut fwd_buf,
8267 &mut seg_cache,
8268 &mut seg_start_t,
8269 &mut seg_count,
8270 );
8271 }
8272 unsafe {
8273 for (i, body_x_off) in body_x_offs.iter().enumerate() {
8274 let (outer_xs_off, x_psb) = outer_xs_offs[i];
8275 let xb = x_psb as usize;
8276 std::ptr::copy_nonoverlapping(
8277 base.add(outer_xs_off + t * xb),
8278 body_buf.as_mut_ptr().add(*body_x_off),
8279 xb,
8280 );
8281 }
8282 std::ptr::copy_nonoverlapping(
8283 dcarry.as_ptr(),
8284 body_buf.as_mut_ptr().add(*body_d_output_off),
8285 cb,
8286 );
8287 }
8288
8289 execute_thunks(body_vjp, &mut body_buf);
8290
8291 unsafe {
8294 std::ptr::copy_nonoverlapping(
8295 body_buf.as_ptr().add(*body_dxs_out_off),
8296 base.add(*outer_dxs_off + t * psb),
8297 psb,
8298 );
8299 }
8300
8301 unsafe {
8303 std::ptr::copy_nonoverlapping(
8304 body_buf.as_ptr().add(*body_dcarry_out_off),
8305 dcarry.as_mut_ptr(),
8306 cb,
8307 );
8308 }
8309 }
8310 }
8311
8312 Thunk::FusedMmBiasAct {
8313 a,
8314 w,
8315 bias,
8316 c,
8317 m,
8318 k,
8319 n,
8320 act,
8321 } => {
8322 let (m, k, n) = (*m as usize, *k as usize, *n as usize);
8323 unsafe {
8324 let out = sl_mut(*c, base, m * n);
8325 crate::blas::sgemm_auto(sl(*a, base, m * k), sl(*w, base, k * n), out, m, k, n);
8326 match act {
8327 Some(Activation::Gelu) => {
8328 crate::kernels::par_bias_gelu(out, sl(*bias, base, n), m, n)
8329 }
8330 Some(other) => {
8331 crate::blas::bias_add(out, sl(*bias, base, n), m, n);
8332 apply_activation_inplace(out, *other);
8333 }
8334 None => crate::blas::bias_add(out, sl(*bias, base, n), m, n),
8335 }
8336 }
8337 }
8338
8339 Thunk::FusedResidualLN {
8340 x,
8341 res,
8342 bias,
8343 g,
8344 b,
8345 out,
8346 rows,
8347 h,
8348 eps,
8349 has_bias,
8350 } => {
8351 let (rows, h) = (*rows as usize, *h as usize);
8352 unsafe {
8353 let zero = &zero_bias[..h];
8354 let bi = if *has_bias { sl(*bias, base, h) } else { zero };
8355 let x_ptr = sl(*x, base, rows * h).as_ptr() as usize;
8356 let r_ptr = sl(*res, base, rows * h).as_ptr() as usize;
8357 let o_ptr = sl_mut(*out, base, rows * h).as_mut_ptr() as usize;
8358 let bi_ptr = bi.as_ptr() as usize;
8359 let g_ptr = sl(*g, base, h).as_ptr() as usize;
8360 let b_ptr = sl(*b, base, h).as_ptr() as usize;
8361 let e = *eps;
8362 crate::pool::par_for(rows, 4, &|off, cnt| {
8363 let xs =
8364 std::slice::from_raw_parts((x_ptr as *const f32).add(off * h), cnt * h);
8365 let rs =
8366 std::slice::from_raw_parts((r_ptr as *const f32).add(off * h), cnt * h);
8367 let os = std::slice::from_raw_parts_mut(
8368 (o_ptr as *mut f32).add(off * h),
8369 cnt * h,
8370 );
8371 let bi = std::slice::from_raw_parts(bi_ptr as *const f32, h);
8372 let g = std::slice::from_raw_parts(g_ptr as *const f32, h);
8373 let b = std::slice::from_raw_parts(b_ptr as *const f32, h);
8374 crate::kernels::residual_bias_layer_norm(xs, rs, bi, g, b, os, cnt, h, e);
8375 });
8376 }
8377 }
8378
8379 Thunk::FusedResidualRmsNorm {
8380 x,
8381 res,
8382 bias,
8383 g,
8384 b,
8385 out,
8386 rows,
8387 h,
8388 eps,
8389 has_bias,
8390 } => {
8391 let (rows, h) = (*rows as usize, *h as usize);
8392 unsafe {
8393 let zero = &zero_bias[..h];
8394 let bi = if *has_bias { sl(*bias, base, h) } else { zero };
8395 let x_ptr = sl(*x, base, rows * h).as_ptr() as usize;
8396 let r_ptr = sl(*res, base, rows * h).as_ptr() as usize;
8397 let o_ptr = sl_mut(*out, base, rows * h).as_mut_ptr() as usize;
8398 let bi_ptr = bi.as_ptr() as usize;
8399 let g_ptr = sl(*g, base, h).as_ptr() as usize;
8400 let b_ptr = sl(*b, base, h).as_ptr() as usize;
8401 let e = *eps;
8402 crate::pool::par_for(rows, 4, &|off, cnt| {
8403 let xs =
8404 std::slice::from_raw_parts((x_ptr as *const f32).add(off * h), cnt * h);
8405 let rs =
8406 std::slice::from_raw_parts((r_ptr as *const f32).add(off * h), cnt * h);
8407 let os = std::slice::from_raw_parts_mut(
8408 (o_ptr as *mut f32).add(off * h),
8409 cnt * h,
8410 );
8411 let bi = std::slice::from_raw_parts(bi_ptr as *const f32, h);
8412 let g = std::slice::from_raw_parts(g_ptr as *const f32, h);
8413 let b = std::slice::from_raw_parts(b_ptr as *const f32, h);
8414 crate::kernels::residual_bias_rms_norm(xs, rs, bi, g, b, os, cnt, h, e);
8415 });
8416 }
8417 }
8418
8419 Thunk::BiasAdd {
8420 src,
8421 bias,
8422 dst,
8423 m,
8424 n,
8425 } => {
8426 let (m, n) = (*m as usize, *n as usize);
8427 unsafe {
8428 let out = sl_mut(*dst, base, m * n);
8429 out.copy_from_slice(sl(*src, base, m * n));
8430 crate::blas::bias_add(out, sl(*bias, base, n), m, n);
8431 }
8432 }
8433
8434 Thunk::BinaryFull {
8435 lhs,
8436 rhs,
8437 dst,
8438 len,
8439 lhs_len,
8440 rhs_len,
8441 op,
8442 out_dims_bcast,
8443 bcast_lhs_strides,
8444 bcast_rhs_strides,
8445 } => {
8446 let len = *len as usize;
8447 let ll = (*lhs_len as usize).max(1);
8448 let rl = (*rhs_len as usize).max(1);
8449 unsafe {
8450 let l = sl(*lhs, base, ll);
8451 let r = sl(*rhs, base, rl);
8452 let o = sl_mut(*dst, base, len);
8453 if ll == len && rl == len {
8455 #[cfg(target_arch = "aarch64")]
8456 if matches!(op, BinaryOp::Add | BinaryOp::Mul) {
8457 use std::arch::aarch64::*;
8458 let chunks = len / 4;
8459 for c in 0..chunks {
8460 let off = c * 4;
8461 let vl = vld1q_f32(l.as_ptr().add(off));
8462 let vr = vld1q_f32(r.as_ptr().add(off));
8463 let res = match op {
8464 BinaryOp::Add => vaddq_f32(vl, vr),
8465 BinaryOp::Mul => vmulq_f32(vl, vr),
8466 _ => unreachable!(),
8467 };
8468 vst1q_f32(o.as_mut_ptr().add(off), res);
8469 }
8470 for i in (chunks * 4)..len {
8471 o[i] = match op {
8472 BinaryOp::Add => l[i] + r[i],
8473 BinaryOp::Mul => l[i] * r[i],
8474 _ => unreachable!(),
8475 };
8476 }
8477 continue;
8483 }
8484 }
8485 if !out_dims_bcast.is_empty() {
8486 let rank = out_dims_bcast.len();
8489 let mut coords = vec![0u32; rank];
8490 for i in 0..len {
8491 let mut rem = i;
8492 for ax in (0..rank).rev() {
8493 let sz = out_dims_bcast[ax] as usize;
8494 coords[ax] = (rem % sz) as u32;
8495 rem /= sz;
8496 }
8497 let mut li: usize = 0;
8498 let mut ri: usize = 0;
8499 for ax in 0..rank {
8500 li += coords[ax] as usize * bcast_lhs_strides[ax] as usize;
8501 ri += coords[ax] as usize * bcast_rhs_strides[ax] as usize;
8502 }
8503 o[i] = match op {
8504 BinaryOp::Add => l[li] + r[ri],
8505 BinaryOp::Sub => l[li] - r[ri],
8506 BinaryOp::Mul => l[li] * r[ri],
8507 BinaryOp::Div => l[li] / r[ri],
8508 BinaryOp::Max => l[li].max(r[ri]),
8509 BinaryOp::Min => l[li].min(r[ri]),
8510 BinaryOp::Pow => l[li].powf(r[ri]),
8511 };
8512 }
8513 } else {
8514 for i in 0..len {
8516 let li = if ll == 1 { 0 } else { i % ll };
8517 let ri = if rl == 1 { 0 } else { i % rl };
8518 o[i] = match op {
8519 BinaryOp::Add => l[li] + r[ri],
8520 BinaryOp::Sub => l[li] - r[ri],
8521 BinaryOp::Mul => l[li] * r[ri],
8522 BinaryOp::Div => l[li] / r[ri],
8523 BinaryOp::Max => l[li].max(r[ri]),
8524 BinaryOp::Min => l[li].min(r[ri]),
8525 BinaryOp::Pow => l[li].powf(r[ri]),
8526 };
8527 }
8528 }
8529 }
8530 }
8531
8532 Thunk::Gather {
8533 table,
8534 table_len,
8535 idx,
8536 dst,
8537 num_idx,
8538 trailing,
8539 } => {
8540 let (ni, tr) = (*num_idx as usize, *trailing as usize);
8541 unsafe {
8542 let tab = sl(*table, base, *table_len as usize);
8543 let ids = sl(*idx, base, ni);
8544 let out = sl_mut(*dst, base, ni * tr);
8545 for i in 0..ni {
8546 let row = ids[i] as usize;
8547 out[i * tr..(i + 1) * tr].copy_from_slice(&tab[row * tr..(row + 1) * tr]);
8548 }
8549 }
8550 }
8551
8552 Thunk::Narrow {
8553 src,
8554 dst,
8555 outer,
8556 src_stride,
8557 dst_stride,
8558 inner,
8559 elem_bytes,
8560 } => {
8561 let f = narrow_thunk_closure(
8562 *src,
8563 *dst,
8564 *outer,
8565 *src_stride,
8566 *dst_stride,
8567 *inner,
8568 *elem_bytes,
8569 );
8570 f(base);
8571 }
8572
8573 Thunk::Copy { src, dst, len } => {
8574 let len = *len as usize;
8575 unsafe {
8576 let s = sl(*src, base, len);
8577 let d = sl_mut(*dst, base, len);
8578 d.copy_from_slice(s);
8579 }
8580 }
8581
8582 Thunk::LayerNorm {
8583 src,
8584 g,
8585 b,
8586 dst,
8587 rows,
8588 h,
8589 eps,
8590 } => {
8591 let (rows, h) = (*rows as usize, *h as usize);
8592 unsafe {
8593 let input = sl(*src, base, rows * h);
8594 let gamma = sl(*g, base, h);
8595 let beta = sl(*b, base, h);
8596 let output = sl_mut(*dst, base, rows * h);
8597 if rows >= 4 && rows * h >= 30_000 {
8599 let i_ptr = input.as_ptr() as usize;
8600 let o_ptr = output.as_mut_ptr() as usize;
8601 let g_ptr = gamma.as_ptr() as usize;
8602 let b_ptr = beta.as_ptr() as usize;
8603 let e = *eps;
8604 crate::pool::par_for(rows, 4, &|off, cnt| {
8605 let inp = std::slice::from_raw_parts(
8606 (i_ptr as *const f32).add(off * h),
8607 cnt * h,
8608 );
8609 let out = std::slice::from_raw_parts_mut(
8610 (o_ptr as *mut f32).add(off * h),
8611 cnt * h,
8612 );
8613 let g = std::slice::from_raw_parts(g_ptr as *const f32, h);
8614 let b = std::slice::from_raw_parts(b_ptr as *const f32, h);
8615 for row in 0..cnt {
8616 crate::kernels::layer_norm_row(
8617 &inp[row * h..(row + 1) * h],
8618 g,
8619 b,
8620 &mut out[row * h..(row + 1) * h],
8621 h,
8622 e,
8623 );
8624 }
8625 });
8626 } else {
8627 for row in 0..rows {
8628 crate::kernels::layer_norm_row(
8629 &input[row * h..(row + 1) * h],
8630 gamma,
8631 beta,
8632 &mut output[row * h..(row + 1) * h],
8633 h,
8634 *eps,
8635 );
8636 }
8637 }
8638 }
8639 }
8640
8641 Thunk::GroupNorm {
8642 src,
8643 g,
8644 b,
8645 dst,
8646 n,
8647 c,
8648 h,
8649 w,
8650 num_groups,
8651 eps,
8652 } => {
8653 let (n, c, h, w) = (*n as usize, *c as usize, *h as usize, *w as usize);
8654 let plane = c * h * w;
8655 unsafe {
8656 for ni in 0..n {
8657 let input = sl(*src, base.add(ni * plane), plane);
8658 let gamma = sl(*g, base, c);
8659 let beta = sl(*b, base, c);
8660 let output = sl_mut(*dst, base.add(ni * plane), plane);
8661 crate::kernels::group_norm_nchw(
8662 input,
8663 gamma,
8664 beta,
8665 output,
8666 1,
8667 c,
8668 h,
8669 w,
8670 *num_groups as usize,
8671 *eps,
8672 );
8673 }
8674 }
8675 }
8676
8677 Thunk::LayerNorm2d {
8678 src,
8679 g,
8680 b,
8681 dst,
8682 n,
8683 c,
8684 h,
8685 w,
8686 eps,
8687 } => {
8688 let (n, c, h, w) = (*n as usize, *c as usize, *h as usize, *w as usize);
8689 let plane = c * h * w;
8690 unsafe {
8691 let input = sl(*src, base, n * plane);
8692 let gamma = sl(*g, base, c);
8693 let beta = sl(*b, base, c);
8694 let output = sl_mut(*dst, base, n * plane);
8695 crate::kernels::layer_norm2d_nchw(input, gamma, beta, output, n, c, h, w, *eps);
8696 }
8697 }
8698
8699 Thunk::ConvTranspose2d {
8700 src,
8701 weight,
8702 dst,
8703 n,
8704 c_in,
8705 h,
8706 w_in,
8707 c_out,
8708 h_out,
8709 w_out,
8710 kh,
8711 kw,
8712 sh,
8713 sw,
8714 ph,
8715 pw,
8716 dh,
8717 dw,
8718 groups,
8719 } => {
8720 let n = *n as usize;
8721 let c_in = *c_in as usize;
8722 let h = *h as usize;
8723 let w_in = *w_in as usize;
8724 let c_out = *c_out as usize;
8725 let h_out = *h_out as usize;
8726 let w_out = *w_out as usize;
8727 unsafe {
8728 let inp = sl(*src, base, n * c_in * h * w_in);
8729 let wt = sl(
8730 *weight,
8731 base,
8732 c_in * (c_out / *groups as usize) * (*kh as usize) * (*kw as usize),
8733 );
8734 let out = sl_mut(*dst, base, n * c_out * h_out * w_out);
8735 crate::kernels::conv_transpose2d_nchw(
8736 inp,
8737 wt,
8738 out,
8739 n,
8740 c_in,
8741 h,
8742 w_in,
8743 c_out,
8744 h_out,
8745 w_out,
8746 *kh as usize,
8747 *kw as usize,
8748 *sh as usize,
8749 *sw as usize,
8750 *ph as usize,
8751 *pw as usize,
8752 *dh as usize,
8753 *dw as usize,
8754 *groups as usize,
8755 );
8756 }
8757 }
8758
8759 Thunk::ResizeNearest2x {
8760 src,
8761 dst,
8762 n,
8763 c,
8764 h,
8765 w,
8766 } => {
8767 let (n, c, h, w) = (*n as usize, *c as usize, *h as usize, *w as usize);
8768 let in_plane = c * h * w;
8769 let out_plane = c * h * 2 * w * 2;
8770 unsafe {
8771 for ni in 0..n {
8772 let input = sl(*src, base.add(ni * in_plane), in_plane);
8773 let output = sl_mut(*dst, base.add(ni * out_plane), out_plane);
8774 crate::kernels::resize_nearest_2x_nchw(input, output, c, h, w);
8775 }
8776 }
8777 }
8778
8779 Thunk::AxialRope2d {
8780 src,
8781 dst,
8782 batch,
8783 seq,
8784 hidden,
8785 end_x,
8786 end_y,
8787 head_dim,
8788 num_heads,
8789 theta,
8790 repeat_factor,
8791 } => {
8792 let b = *batch as usize;
8793 let s = *seq as usize;
8794 let hdim = *head_dim as usize;
8795 let nh = *num_heads as usize;
8796 let plane = s * (*hidden as usize);
8797 unsafe {
8798 for bi in 0..b {
8799 let input = sl(*src, base.add(bi * plane), plane);
8800 let output = sl_mut(*dst, base.add(bi * plane), plane);
8801 let rotated = rlx_ir::ops::axial_rope2d::apply_axial_rope2d(
8802 input,
8803 nh,
8804 s,
8805 hdim,
8806 *end_x as usize,
8807 *end_y as usize,
8808 *theta,
8809 *repeat_factor as usize,
8810 );
8811 output.copy_from_slice(&rotated);
8812 }
8813 }
8814 }
8815
8816 Thunk::RmsNorm {
8817 src,
8818 g,
8819 b,
8820 dst,
8821 rows,
8822 h,
8823 eps,
8824 } => {
8825 let (rows, h) = (*rows as usize, *h as usize);
8826 unsafe {
8827 let input = sl(*src, base, rows * h);
8828 let gamma = sl(*g, base, h);
8829 let beta = sl(*b, base, h);
8830 let output = sl_mut(*dst, base, rows * h);
8831 let inv_h = 1.0 / h as f32;
8832 for row in 0..rows {
8833 let in_row = &input[row * h..(row + 1) * h];
8834 let out_row = &mut output[row * h..(row + 1) * h];
8835 let mut sumsq = 0f32;
8837 for &v in in_row {
8838 sumsq += v * v;
8839 }
8840 let inv_rms = (sumsq * inv_h + *eps).sqrt().recip();
8841 for i in 0..h {
8842 out_row[i] = in_row[i] * inv_rms * gamma[i] + beta[i];
8843 }
8844 }
8845 }
8846 }
8847
8848 Thunk::Softmax { data, rows, cols } => {
8849 let (rows, cols) = (*rows as usize, *cols as usize);
8850 unsafe {
8851 crate::kernels::neon_softmax(sl_mut(*data, base, rows * cols), rows, cols);
8852 }
8853 }
8854
8855 Thunk::Cumsum {
8856 src,
8857 dst,
8858 rows,
8859 cols,
8860 exclusive,
8861 } => {
8862 let (rows, cols) = (*rows as usize, *cols as usize);
8863 unsafe {
8864 let s = sl(*src, base, rows * cols);
8865 let d = sl_mut(*dst, base, rows * cols);
8866 if *exclusive {
8867 for r in 0..rows {
8868 let mut acc = 0.0f32;
8869 for c in 0..cols {
8870 d[r * cols + c] = acc;
8871 acc += s[r * cols + c];
8872 }
8873 }
8874 } else {
8875 for r in 0..rows {
8876 let mut acc = 0.0f32;
8877 for c in 0..cols {
8878 acc += s[r * cols + c];
8879 d[r * cols + c] = acc;
8880 }
8881 }
8882 }
8883 }
8884 }
8885
8886 Thunk::Sample {
8887 logits,
8888 dst,
8889 batch,
8890 vocab,
8891 top_k,
8892 top_p,
8893 temperature,
8894 seed,
8895 } => {
8896 let (b, v) = (*batch as usize, *vocab as usize);
8897 let k = (*top_k as usize).min(v);
8898 unsafe {
8899 let lg = sl(*logits, base, b * v);
8900 let out = sl_mut(*dst, base, b);
8901 let mut rng =
8902 rlx_ir::Philox4x32::new(if *seed == 0 { 0xDEADBEEF } else { *seed });
8903 for bi in 0..b {
8904 let row = &lg[bi * v..(bi + 1) * v];
8905 out[bi] = sample_row(row, k, *top_p, *temperature, &mut rng) as f32;
8906 }
8907 }
8908 }
8909
8910 Thunk::GatedDeltaNet {
8911 q,
8912 k,
8913 v,
8914 g,
8915 beta,
8916 state,
8917 dst,
8918 batch,
8919 seq,
8920 heads,
8921 state_size,
8922 } => unsafe {
8923 execute_gated_delta_net_f32(
8924 *q,
8925 *k,
8926 *v,
8927 *g,
8928 *beta,
8929 *state,
8930 *dst,
8931 *batch as usize,
8932 *seq as usize,
8933 *heads as usize,
8934 *state_size as usize,
8935 base,
8936 );
8937 },
8938
8939 Thunk::SelectiveScan {
8940 x,
8941 delta,
8942 a,
8943 b: bp,
8944 c: cp,
8945 dst,
8946 batch,
8947 seq,
8948 hidden,
8949 state_size,
8950 } => {
8951 let (b, s, h, n) = (
8952 *batch as usize,
8953 *seq as usize,
8954 *hidden as usize,
8955 *state_size as usize,
8956 );
8957 unsafe {
8958 let xs = sl(*x, base, b * s * h);
8959 let dt = sl(*delta, base, b * s * h);
8960 let am = sl(*a, base, h * n);
8961 let bm = sl(*bp, base, b * s * n);
8962 let cm = sl(*cp, base, b * s * n);
8963 let out = sl_mut(*dst, base, b * s * h);
8964
8965 let mut state = vec![0f32; h * n];
8969 for bi in 0..b {
8970 for v in state.iter_mut() {
8972 *v = 0.0;
8973 }
8974 for si in 0..s {
8975 let x_row = &xs[bi * s * h + si * h..bi * s * h + (si + 1) * h];
8976 let dt_row = &dt[bi * s * h + si * h..bi * s * h + (si + 1) * h];
8977 let b_row = &bm[bi * s * n + si * n..bi * s * n + (si + 1) * n];
8978 let c_row = &cm[bi * s * n + si * n..bi * s * n + (si + 1) * n];
8979 let out_row = &mut out[bi * s * h + si * h..bi * s * h + (si + 1) * h];
8980
8981 for ci in 0..h {
8982 let d = dt_row[ci];
8983 let xv = x_row[ci];
8984 let mut acc = 0f32;
8985 for ni in 0..n {
8986 let da = (d * am[ci * n + ni]).exp();
8988 state[ci * n + ni] =
8989 da * state[ci * n + ni] + d * b_row[ni] * xv;
8990 acc += c_row[ni] * state[ci * n + ni];
8991 }
8992 out_row[ci] = acc;
8993 }
8994 }
8995 }
8996 }
8997 }
8998
8999 Thunk::DequantMatMul {
9000 x,
9001 w_q,
9002 scale,
9003 zp,
9004 dst,
9005 m,
9006 k,
9007 n,
9008 block_size,
9009 is_asymmetric,
9010 } => {
9011 let (m, k, n, bs) = (*m as usize, *k as usize, *n as usize, *block_size as usize);
9012 let n_blocks = k.div_ceil(bs);
9013 unsafe {
9014 let xs = sl(*x, base, m * k);
9015 let w_bytes = std::slice::from_raw_parts(base.add(*w_q) as *const i8, k * n);
9016 let scales = sl(*scale, base, n_blocks * n);
9017 let zps = if *is_asymmetric {
9018 sl(*zp, base, n_blocks * n)
9019 } else {
9020 &[][..]
9021 };
9022 let out = sl_mut(*dst, base, m * n);
9023 dequant_matmul_int8(xs, w_bytes, scales, zps, out, m, k, n, bs, *is_asymmetric);
9024 }
9025 }
9026
9027 Thunk::DequantMatMulGguf {
9028 x,
9029 w_q,
9030 dst,
9031 m,
9032 k,
9033 n,
9034 scheme,
9035 } => {
9036 let (m, k, n) = (*m as usize, *k as usize, *n as usize);
9037 let block_bytes = scheme.gguf_block_bytes() as usize;
9038 let block_elems = scheme.gguf_block_size() as usize;
9039 debug_assert!(
9040 block_bytes > 0 && block_elems > 0,
9041 "non-GGUF scheme in GGUF arm"
9042 );
9043 debug_assert!(
9044 (k * n).is_multiple_of(block_elems),
9045 "k*n={} not aligned to GGUF block size {}",
9046 k * n,
9047 block_elems
9048 );
9049 let total_bytes = (k * n) / block_elems * block_bytes;
9050 unsafe {
9051 let xs = sl(*x, base, m * k);
9052 let w_bytes_ptr = base.add(*w_q) as *const u8;
9053 let w_bytes = std::slice::from_raw_parts(w_bytes_ptr, total_bytes);
9054 let out = sl_mut(*dst, base, m * n);
9055 crate::gguf_matmul::gguf_matmul_bt(xs, w_bytes, out, m, k, n, *scheme);
9056 }
9057 }
9058
9059 Thunk::DequantMatMulInt4 {
9060 x,
9061 w_q,
9062 scale,
9063 zp,
9064 dst,
9065 m,
9066 k,
9067 n,
9068 block_size,
9069 is_asymmetric,
9070 } => {
9071 let (m, k, n, bs) = (*m as usize, *k as usize, *n as usize, *block_size as usize);
9072 let n_blocks = k.div_ceil(bs);
9073 unsafe {
9074 let xs = sl(*x, base, m * k);
9075 let w_bytes = std::slice::from_raw_parts(
9076 base.add(*w_q) as *const u8,
9077 (k * n).div_ceil(2),
9078 );
9079 let scales = sl(*scale, base, n_blocks * n);
9080 let zps = if *is_asymmetric {
9081 sl(*zp, base, n_blocks * n)
9082 } else {
9083 &[][..]
9084 };
9085 let out = sl_mut(*dst, base, m * n);
9086 dequant_matmul_int4(xs, w_bytes, scales, zps, out, m, k, n, bs, *is_asymmetric);
9087 }
9088 }
9089
9090 Thunk::DequantMatMulFp8 {
9091 x,
9092 w_q,
9093 scale,
9094 dst,
9095 m,
9096 k,
9097 n,
9098 e5m2,
9099 } => {
9100 let (m, k, n) = (*m as usize, *k as usize, *n as usize);
9101 unsafe {
9102 let xs = sl(*x, base, m * k);
9103 let w_bytes = std::slice::from_raw_parts(base.add(*w_q) as *const u8, k * n);
9104 let scales = sl(*scale, base, n);
9105 let out = sl_mut(*dst, base, m * n);
9106 dequant_matmul_fp8(xs, w_bytes, scales, out, m, k, n, *e5m2);
9107 }
9108 }
9109
9110 Thunk::DequantMatMulNvfp4 {
9111 x,
9112 w_q,
9113 scale,
9114 global_scale,
9115 dst,
9116 m,
9117 k,
9118 n,
9119 } => {
9120 let (m, k, n) = (*m as usize, *k as usize, *n as usize);
9121 let n_scale = k.div_ceil(rlx_ir::NVFP4_GROUP_SIZE) * n;
9122 unsafe {
9123 let xs = sl(*x, base, m * k);
9124 let w_bytes = std::slice::from_raw_parts(
9125 base.add(*w_q) as *const u8,
9126 (k * n).div_ceil(2),
9127 );
9128 let scale_bytes =
9129 std::slice::from_raw_parts(base.add(*scale) as *const u8, n_scale);
9130 let gs = sl(*global_scale, base, 1)[0];
9131 let out = sl_mut(*dst, base, m * n);
9132 dequant_matmul_nvfp4(xs, w_bytes, scale_bytes, gs, out, m, k, n);
9133 }
9134 }
9135
9136 Thunk::LoraMatMul {
9137 x,
9138 w,
9139 a,
9140 b,
9141 dst,
9142 m,
9143 k,
9144 n,
9145 r,
9146 scale,
9147 } => {
9148 let (m, k, n, r) = (*m as usize, *k as usize, *n as usize, *r as usize);
9149 unsafe {
9150 let xs = sl(*x, base, m * k);
9151 let ws = sl(*w, base, k * n);
9152 let a_s = sl(*a, base, k * r);
9153 let bs = sl(*b, base, r * n);
9154 let out = sl_mut(*dst, base, m * n);
9155 crate::blas::sgemm(xs, ws, out, m, k, n);
9156 let mut tmp = vec![0f32; m * r];
9157 crate::blas::sgemm(xs, a_s, &mut tmp, m, k, r);
9158 if *scale != 1.0 {
9159 for v in tmp.iter_mut() {
9160 *v *= *scale;
9161 }
9162 }
9163 crate::blas::sgemm_accumulate(&tmp, bs, out, m, r, n);
9164 }
9165 }
9166
9167 Thunk::Attention {
9168 q,
9169 k,
9170 v,
9171 mask,
9172 out,
9173 batch,
9174 seq,
9175 kv_seq,
9176 heads,
9177 head_dim,
9178 mask_kind,
9179 q_row_stride,
9180 k_row_stride,
9181 v_row_stride,
9182 bhsd,
9183 } => {
9184 let (b, q_s, k_s, nh, dh) = (
9185 *batch as usize,
9186 *seq as usize,
9187 *kv_seq as usize,
9188 *heads as usize,
9189 *head_dim as usize,
9190 );
9191 let hs = nh * dh;
9192 let (qrs, krs, vrs) = if *bhsd {
9195 (dh, dh, dh)
9196 } else {
9197 (
9198 *q_row_stride as usize,
9199 *k_row_stride as usize,
9200 *v_row_stride as usize,
9201 )
9202 };
9203 let bhsd = *bhsd;
9204 let _ = (q_row_stride, k_row_stride, v_row_stride);
9205 let scale = (dh as f32).powf(-0.5);
9206 let ss = q_s * k_s;
9207 let cfg = crate::config::RuntimeConfig::global();
9208 unsafe {
9209 let q_len = if bhsd {
9216 b * nh * q_s * dh
9217 } else {
9218 b * q_s * qrs
9219 };
9220 let k_len = if bhsd {
9221 b * nh * k_s * dh
9222 } else {
9223 b * k_s * krs
9224 };
9225 let v_len = if bhsd {
9226 b * nh * k_s * dh
9227 } else {
9228 b * k_s * vrs
9229 };
9230 let q_data = sl(*q, base, q_len);
9231 let k_data = sl(*k, base, k_len);
9232 let v_data = sl(*v, base, v_len);
9233 let mask_data: &[f32] = match mask_kind {
9234 rlx_ir::op::MaskKind::Custom => sl(*mask, base, b * k_s),
9235 rlx_ir::op::MaskKind::Bias => sl(*mask, base, b * nh * q_s * k_s),
9236 _ => &[],
9237 };
9238 let out_len = if bhsd {
9239 b * nh * q_s * dh
9240 } else {
9241 b * q_s * hs
9242 };
9243 let out_data = sl_mut(*out, base, out_len);
9244
9245 if bhsd {
9256 let scores = &mut sdpa_scores[..ss];
9257 for bi in 0..b {
9258 for hi in 0..nh {
9259 let q_head_base = bi * nh * q_s * dh + hi * q_s * dh;
9260 let k_head_base = bi * nh * k_s * dh + hi * k_s * dh;
9261 for qi in 0..q_s {
9263 let q_base = q_head_base + qi * dh;
9264 for ki in 0..k_s {
9265 let k_base = k_head_base + ki * dh;
9266 let mut dot = 0f32;
9267 for d in 0..dh {
9268 dot += q_data[q_base + d] * k_data[k_base + d];
9269 }
9270 scores[qi * k_s + ki] = dot * scale;
9271 if matches!(mask_kind, rlx_ir::op::MaskKind::Custom)
9272 && !mask_data.is_empty()
9273 && mask_data[bi * k_s + ki] < mask_thr
9274 {
9275 scores[qi * k_s + ki] = mask_neg;
9276 }
9277 }
9278 }
9279 if matches!(mask_kind, rlx_ir::op::MaskKind::Bias) {
9280 let off = (bi * nh + hi) * q_s * k_s;
9281 for i in 0..q_s * k_s {
9282 scores[i] += mask_data[off + i];
9283 }
9284 }
9285 apply_synthetic_mask(scores, q_s, k_s, *mask_kind);
9286 crate::kernels::neon_softmax(scores, q_s, k_s);
9287 for qi in 0..q_s {
9289 let o_base = q_head_base + qi * dh;
9290 for d in 0..dh {
9291 out_data[o_base + d] = 0.0;
9292 }
9293 for ki in 0..k_s {
9294 let sc = scores[qi * k_s + ki];
9295 if sc > score_thr {
9296 let v_base = k_head_base + ki * dh;
9297 for d in 0..dh {
9298 out_data[o_base + d] += sc * v_data[v_base + d];
9299 }
9300 }
9301 }
9302 }
9303 }
9304 }
9305 continue;
9306 }
9307
9308 if b == 1 && q_s.max(k_s) <= cfg.sdpa_seq_threshold {
9315 let scores = &mut sdpa_scores[..ss];
9317 #[cfg(target_arch = "aarch64")]
9318 let neon_chunks = dh / 4;
9319
9320 for bi in 0..b {
9321 for hi in 0..nh {
9322 for qi in 0..q_s {
9324 let q_off = bi * q_s * qrs + qi * qrs + hi * dh;
9325 for ki in 0..k_s {
9326 let k_off = bi * k_s * krs + ki * krs + hi * dh;
9327 #[cfg(target_arch = "aarch64")]
9328 let mut dot;
9329 #[cfg(not(target_arch = "aarch64"))]
9330 let mut dot = 0f32;
9331 #[cfg(target_arch = "aarch64")]
9332 {
9333 use std::arch::aarch64::*;
9334 let mut acc = vdupq_n_f32(0.0);
9335 for c in 0..neon_chunks {
9336 let vq =
9337 vld1q_f32(q_data.as_ptr().add(q_off + c * 4));
9338 let vk =
9339 vld1q_f32(k_data.as_ptr().add(k_off + c * 4));
9340 acc = vfmaq_f32(acc, vq, vk);
9341 }
9342 dot = vaddvq_f32(acc);
9343 for d in (neon_chunks * 4)..dh {
9344 dot += q_data[q_off + d] * k_data[k_off + d];
9345 }
9346 }
9347 #[cfg(not(target_arch = "aarch64"))]
9348 for d in 0..dh {
9349 dot += q_data[q_off + d] * k_data[k_off + d];
9350 }
9351 scores[qi * k_s + ki] = dot * scale;
9352 if matches!(mask_kind, rlx_ir::op::MaskKind::Custom)
9359 && !mask_data.is_empty()
9360 && mask_data[bi * k_s + ki] < mask_thr
9361 {
9362 scores[qi * k_s + ki] = mask_neg;
9363 }
9364 }
9365 }
9366
9367 if matches!(mask_kind, rlx_ir::op::MaskKind::Bias) {
9368 let off = (bi * nh + hi) * q_s * k_s;
9369 for i in 0..q_s * k_s {
9370 scores[i] += mask_data[off + i];
9371 }
9372 }
9373 apply_synthetic_mask(scores, q_s, k_s, *mask_kind);
9374 crate::kernels::neon_softmax(scores, q_s, k_s);
9375
9376 for qi in 0..q_s {
9378 let o_off = bi * q_s * hs + qi * hs + hi * dh;
9379 for d in 0..dh {
9381 out_data[o_off + d] = 0.0;
9382 }
9383 for ki in 0..k_s {
9384 let sc = scores[qi * k_s + ki];
9385 if sc > score_thr {
9386 let v_off = bi * k_s * vrs + ki * vrs + hi * dh;
9387 #[cfg(target_arch = "aarch64")]
9388 {
9389 use std::arch::aarch64::*;
9390 let vsc = vdupq_n_f32(sc);
9391 for c in 0..neon_chunks {
9392 let off = c * 4;
9393 let vo = vld1q_f32(
9394 out_data.as_ptr().add(o_off + off),
9395 );
9396 let vv =
9397 vld1q_f32(v_data.as_ptr().add(v_off + off));
9398 vst1q_f32(
9399 out_data.as_mut_ptr().add(o_off + off),
9400 vfmaq_f32(vo, vsc, vv),
9401 );
9402 }
9403 }
9404 #[cfg(not(target_arch = "aarch64"))]
9405 for d in 0..dh {
9406 out_data[o_off + d] += sc * v_data[v_off + d];
9407 }
9408 }
9409 }
9410 }
9411 }
9412 }
9413 } else {
9414 let total_work = b * nh;
9416 let q_addr = q_data.as_ptr() as usize;
9417 let k_addr = k_data.as_ptr() as usize;
9418 let v_addr = v_data.as_ptr() as usize;
9419 let m_addr = mask_data.as_ptr() as usize;
9420 let o_addr = out_data.as_mut_ptr() as usize;
9421 let sc_addr = sdpa_scores.as_mut_ptr() as usize;
9422
9423 crate::pool::par_for(total_work, 1, &|off, cnt| {
9424 for idx in off..off + cnt {
9425 let bi = idx / nh;
9426 let hi = idx % nh;
9427
9428 let q_start = (q_addr as *const f32).add(bi * q_s * qrs + hi * dh);
9429 let k_start = (k_addr as *const f32).add(bi * k_s * krs + hi * dh);
9430 let v_start = (v_addr as *const f32).add(bi * k_s * vrs + hi * dh);
9431 let o_start = (o_addr as *mut f32).add(bi * q_s * hs + hi * dh);
9432 let sc = std::slice::from_raw_parts_mut(
9433 (sc_addr as *mut f32).add(idx * ss),
9434 ss,
9435 );
9436
9437 crate::blas::sgemm_general(
9440 q_start,
9441 k_start,
9442 sc.as_mut_ptr(),
9443 q_s,
9444 k_s,
9445 dh,
9446 scale,
9447 0.0,
9448 qrs,
9449 krs,
9450 k_s,
9451 false,
9452 true,
9453 );
9454
9455 match mask_kind {
9456 rlx_ir::op::MaskKind::Custom => {
9457 let mask_bi = std::slice::from_raw_parts(
9458 (m_addr as *const f32).add(bi * k_s),
9459 k_s,
9460 );
9461 for ki in 0..k_s {
9462 if mask_bi[ki] < mask_thr {
9463 for qi in 0..q_s {
9464 sc[qi * k_s + ki] = mask_neg;
9465 }
9466 }
9467 }
9468 }
9469 rlx_ir::op::MaskKind::Bias => {
9470 let bias = std::slice::from_raw_parts(
9472 (m_addr as *const f32).add((bi * nh + hi) * q_s * k_s),
9473 q_s * k_s,
9474 );
9475 for i in 0..q_s * k_s {
9476 sc[i] += bias[i];
9477 }
9478 }
9479 _ => apply_synthetic_mask(sc, q_s, k_s, *mask_kind),
9480 }
9481
9482 crate::kernels::neon_softmax(sc, q_s, k_s);
9483
9484 crate::blas::sgemm_general(
9488 sc.as_ptr(),
9489 v_start,
9490 o_start,
9491 q_s,
9492 dh,
9493 k_s,
9494 1.0,
9495 0.0,
9496 k_s,
9497 vrs,
9498 hs,
9499 false,
9500 false,
9501 );
9502 }
9503 });
9504 }
9505 }
9506 }
9507
9508 Thunk::AttentionBackward {
9509 q,
9510 k,
9511 v,
9512 dy,
9513 mask,
9514 out,
9515 batch,
9516 seq,
9517 kv_seq,
9518 heads,
9519 head_dim,
9520 mask_kind,
9521 wrt,
9522 bhsd,
9523 } => {
9524 let (b, q_s, k_s, nh, dh) = (
9525 *batch as usize,
9526 *seq as usize,
9527 *kv_seq as usize,
9528 *heads as usize,
9529 *head_dim as usize,
9530 );
9531 unsafe {
9532 let q_len = if *bhsd {
9533 b * nh * q_s * dh
9534 } else {
9535 b * q_s * nh * dh
9536 };
9537 let k_len = if *bhsd {
9538 b * nh * k_s * dh
9539 } else {
9540 b * k_s * nh * dh
9541 };
9542 let out_len = match wrt {
9543 rlx_ir::op::AttentionBwdWrt::Key | rlx_ir::op::AttentionBwdWrt::Value => {
9544 k_len
9545 }
9546 rlx_ir::op::AttentionBwdWrt::Query => q_len,
9547 };
9548 let q_data = sl(*q, base, q_len);
9549 let k_data = sl(*k, base, k_len);
9550 let v_data = sl(*v, base, k_len);
9551 let dy_data = sl(*dy, base, q_len);
9552 let out_data = sl_mut(*out, base, out_len);
9553 let mask_data: &[f32] = if *mask != 0 {
9554 let ml = match mask_kind {
9555 rlx_ir::op::MaskKind::Custom => b * k_s,
9556 rlx_ir::op::MaskKind::Bias => b * nh * q_s * k_s,
9557 _ => 0,
9558 };
9559 sl(*mask, base, ml)
9560 } else {
9561 &[]
9562 };
9563 crate::attention_bwd::attention_backward(
9564 *wrt, q_data, k_data, v_data, dy_data, out_data, b, nh, q_s, k_s, dh,
9565 *mask_kind, mask_data, *bhsd,
9566 );
9567 }
9568 }
9569
9570 Thunk::ActivationInPlace { data, len, act } => {
9571 let len = *len as usize;
9572 unsafe {
9573 let d = sl_mut(*data, base, len);
9574 match act {
9575 Activation::Gelu => crate::kernels::par_gelu_inplace(d),
9576 Activation::GeluApprox => crate::kernels::par_gelu_approx_inplace(d),
9577 Activation::Silu => crate::kernels::par_silu_inplace(d),
9578 Activation::Relu => {
9579 for v in d.iter_mut() {
9580 *v = v.max(0.0);
9581 }
9582 }
9583 Activation::Sigmoid => {
9584 for v in d.iter_mut() {
9585 *v = 1.0 / (1.0 + (-*v).exp());
9586 }
9587 }
9588 Activation::Tanh => {
9589 for v in d.iter_mut() {
9590 *v = v.tanh();
9591 }
9592 }
9593 Activation::Exp => {
9594 for v in d.iter_mut() {
9595 *v = v.exp();
9596 }
9597 }
9598 Activation::Log => {
9599 for v in d.iter_mut() {
9600 *v = v.ln();
9601 }
9602 }
9603 Activation::Sqrt => {
9604 for v in d.iter_mut() {
9605 *v = v.sqrt();
9606 }
9607 }
9608 Activation::Rsqrt => {
9609 for v in d.iter_mut() {
9610 *v = 1.0 / v.sqrt();
9611 }
9612 }
9613 Activation::Neg => {
9614 for v in d.iter_mut() {
9615 *v = -*v;
9616 }
9617 }
9618 Activation::Abs => {
9619 for v in d.iter_mut() {
9620 *v = v.abs();
9621 }
9622 }
9623 Activation::Round => {
9624 for v in d.iter_mut() {
9625 *v = v.round();
9626 }
9627 }
9628 Activation::Sin => {
9629 for v in d.iter_mut() {
9630 *v = v.sin();
9631 }
9632 }
9633 Activation::Cos => {
9634 for v in d.iter_mut() {
9635 *v = v.cos();
9636 }
9637 }
9638 Activation::Tan => {
9639 for v in d.iter_mut() {
9640 *v = v.tan();
9641 }
9642 }
9643 Activation::Atan => {
9644 for v in d.iter_mut() {
9645 *v = v.atan();
9646 }
9647 }
9648 }
9649 }
9650 }
9651
9652 Thunk::FusedAttnBlock {
9653 hidden,
9654 qkv_w,
9655 out_w,
9656 mask,
9657 out,
9658 qkv_b,
9659 out_b,
9660 cos,
9661 sin,
9662 cos_len,
9663 batch,
9664 seq,
9665 hs,
9666 nh,
9667 dh,
9668 has_bias,
9669 has_rope,
9670 } => {
9671 let (b, s) = (*batch as usize, *seq as usize);
9672 let (h, n_h, d_h) = (*hs as usize, *nh as usize, *dh as usize);
9673 let m = b * s;
9674 let scale = (d_h as f32).powf(-0.5);
9675 let half = d_h / 2;
9676 unsafe {
9677 let inp = sl(*hidden, base, m * h);
9678 let wq = sl(*qkv_w, base, h * 3 * h);
9679 let wo = sl(*out_w, base, h * h);
9680 let mk = sl(*mask, base, b * s);
9681 let dst = sl_mut(*out, base, m * h);
9682
9683 let mut qkv = vec![0f32; m * 3 * h];
9685 let mut attn_out = vec![0f32; m * h];
9686 let mut scores_buf = vec![0f32; s * s]; crate::blas::sgemm(inp, wq, &mut qkv, m, h, 3 * h);
9690 if *has_bias {
9691 let bias = sl(*qkv_b, base, 3 * h);
9692 crate::blas::bias_add(&mut qkv, bias, m, 3 * h);
9693 }
9694
9695 #[cfg(target_arch = "aarch64")]
9698 let neon_chunks = d_h / 4;
9699 #[cfg(target_arch = "aarch64")]
9700 let _rope_chunks = half / 4;
9701
9702 for bi in 0..b {
9703 for hi in 0..n_h {
9704 for qi in 0..s {
9706 let q_base = bi * s * 3 * h + qi * 3 * h + hi * d_h;
9707 for ki in 0..s {
9708 let k_base = bi * s * 3 * h + ki * 3 * h + h + hi * d_h;
9709 let mut dot = 0f32;
9710
9711 if *has_rope {
9712 let q_cos = qi * half;
9714 let k_cos = ki * half;
9715 let cos_tab = sl(*cos, base, *cos_len as usize);
9716 let sin_tab = sl(*sin, base, *cos_len as usize);
9717 for i in 0..half {
9720 let q1 = qkv[q_base + i];
9721 let q2 = qkv[q_base + half + i];
9722 let k1 = qkv[k_base + i];
9723 let k2 = qkv[k_base + half + i];
9724 let c_q = cos_tab[q_cos + i];
9725 let s_q = sin_tab[q_cos + i];
9726 let c_k = cos_tab[k_cos + i];
9727 let s_k = sin_tab[k_cos + i];
9728 let qr1 = q1 * c_q - q2 * s_q;
9729 let kr1 = k1 * c_k - k2 * s_k;
9730 let qr2 = q2 * c_q + q1 * s_q;
9731 let kr2 = k2 * c_k + k1 * s_k;
9732 dot += qr1 * kr1 + qr2 * kr2;
9733 }
9734 } else {
9735 #[cfg(target_arch = "aarch64")]
9737 {
9738 use std::arch::aarch64::*;
9739 let mut acc = vdupq_n_f32(0.0);
9740 for c in 0..neon_chunks {
9741 let vq =
9742 vld1q_f32(qkv.as_ptr().add(q_base + c * 4));
9743 let vk =
9744 vld1q_f32(qkv.as_ptr().add(k_base + c * 4));
9745 acc = vfmaq_f32(acc, vq, vk);
9746 }
9747 dot = vaddvq_f32(acc);
9748 for d in (neon_chunks * 4)..d_h {
9749 dot += qkv[q_base + d] * qkv[k_base + d];
9750 }
9751 }
9752 #[cfg(not(target_arch = "aarch64"))]
9753 for d in 0..d_h {
9754 dot += qkv[q_base + d] * qkv[k_base + d];
9755 }
9756 }
9757
9758 scores_buf[qi * s + ki] = dot * scale;
9759 if mk[bi * s + ki] < mask_thr {
9760 scores_buf[qi * s + ki] = mask_neg;
9761 }
9762 }
9763 }
9764
9765 crate::kernels::neon_softmax(&mut scores_buf[..s * s], s, s);
9767
9768 for qi in 0..s {
9770 let o_base = bi * s * h + qi * h + hi * d_h;
9771 for d in 0..d_h {
9772 attn_out[o_base + d] = 0.0;
9773 }
9774 for ki in 0..s {
9775 let sc = scores_buf[qi * s + ki];
9776 if sc > score_thr {
9777 let v_base = bi * s * 3 * h + ki * 3 * h + 2 * h + hi * d_h;
9778 #[cfg(target_arch = "aarch64")]
9779 {
9780 use std::arch::aarch64::*;
9781 let vsc = vdupq_n_f32(sc);
9782 for c in 0..neon_chunks {
9783 let off = c * 4;
9784 let vo =
9785 vld1q_f32(attn_out.as_ptr().add(o_base + off));
9786 let vv = vld1q_f32(qkv.as_ptr().add(v_base + off));
9787 vst1q_f32(
9788 attn_out.as_mut_ptr().add(o_base + off),
9789 vfmaq_f32(vo, vsc, vv),
9790 );
9791 }
9792 }
9793 #[cfg(not(target_arch = "aarch64"))]
9794 for d in 0..d_h {
9795 attn_out[o_base + d] += sc * qkv[v_base + d];
9796 }
9797 }
9798 }
9799 }
9800 }
9801 }
9802
9803 crate::blas::sgemm(&attn_out, wo, dst, m, h, h);
9805 if *has_bias {
9806 let bias = sl(*out_b, base, h);
9807 crate::blas::bias_add(dst, bias, m, h);
9808 }
9809 }
9810 }
9811
9812 Thunk::Rope {
9813 src,
9814 cos,
9815 sin,
9816 dst,
9817 batch,
9818 seq,
9819 hidden,
9820 head_dim,
9821 n_rot,
9822 cos_len,
9823 src_row_stride,
9824 } => {
9825 let (b, s, hs, dh, nr) = (
9826 *batch as usize,
9827 *seq as usize,
9828 *hidden as usize,
9829 *head_dim as usize,
9830 *n_rot as usize,
9831 );
9832 let tab_half = dh / 2;
9833 let rot_half = nr / 2;
9834 let nh = hs / dh;
9835 let cl = *cos_len as usize;
9836 let src_rs = *src_row_stride as usize;
9837 unsafe {
9838 let x = sl(*src, base, b * s * src_rs);
9839 let cos_tab = sl(*cos, base, cl);
9840 let sin_tab = sl(*sin, base, cl);
9841 let out = sl_mut(*dst, base, b * s * hs);
9842
9843 let total = b * s;
9844 let x_ptr = x.as_ptr() as usize;
9845 let o_ptr = out.as_mut_ptr() as usize;
9846 let c_ptr = cos_tab.as_ptr() as usize;
9847 let s_ptr = sin_tab.as_ptr() as usize;
9848
9849 crate::pool::par_for(total, 4, &|off, cnt| {
9850 for idx in off..off + cnt {
9851 let bi = idx / s;
9852 let si = idx % s;
9853 let tab_off = si * tab_half;
9854
9855 for hi in 0..nh {
9856 let src_base = bi * s * src_rs + si * src_rs + hi * dh;
9857 let dst_base = bi * s * hs + si * hs + hi * dh;
9858 let xp = (x_ptr as *const f32).add(src_base);
9859 let op = (o_ptr as *mut f32).add(dst_base);
9860 let cp = (c_ptr as *const f32).add(tab_off);
9861 let sp = (s_ptr as *const f32).add(tab_off);
9862
9863 for i in 0..rot_half {
9864 let x1 = *xp.add(i);
9865 let x2 = *xp.add(rot_half + i);
9866 let cv = *cp.add(i);
9867 let sv = *sp.add(i);
9868 *op.add(i) = x1 * cv - x2 * sv;
9869 *op.add(rot_half + i) = x2 * cv + x1 * sv;
9870 }
9871 for j in nr..dh {
9872 *op.add(j) = *xp.add(j);
9873 }
9874 }
9875 }
9876 });
9877 }
9878 }
9879 Thunk::FusedBertLayer {
9880 hidden,
9881 qkv_w,
9882 qkv_b,
9883 out_w,
9884 out_b,
9885 mask,
9886 ln1_g,
9887 ln1_b,
9888 eps1,
9889 fc1_w,
9890 fc1_b,
9891 fc2_w,
9892 fc2_b,
9893 ln2_g,
9894 ln2_b,
9895 eps2,
9896 out,
9897 batch,
9898 seq,
9899 hs,
9900 nh,
9901 dh,
9902 int_dim,
9903 } => {
9904 let (b, s, h, n_h, d_h) = (
9905 *batch as usize,
9906 *seq as usize,
9907 *hs as usize,
9908 *nh as usize,
9909 *dh as usize,
9910 );
9911 let m = b * s;
9912 let id = *int_dim as usize;
9913 let scale = (d_h as f32).powf(-0.5);
9914 let _half = d_h / 2;
9915 #[cfg(target_arch = "aarch64")]
9916 let neon_chunks = d_h / 4;
9917 unsafe {
9918 let inp = sl(*hidden, base, m * h);
9919 let dst = sl_mut(*out, base, m * h);
9920 let mk = sl(*mask, base, b * s);
9921
9922 let qkv = std::slice::from_raw_parts_mut(fl_qkv.as_mut_ptr(), m * 3 * h);
9924 let attn = std::slice::from_raw_parts_mut(fl_attn.as_mut_ptr(), m * h);
9925 let res = std::slice::from_raw_parts_mut(fl_res.as_mut_ptr(), m * h);
9926 let normed = std::slice::from_raw_parts_mut(fl_normed.as_mut_ptr(), m * h);
9927 let ffn = std::slice::from_raw_parts_mut(fl_ffn.as_mut_ptr(), m * id);
9928 let sc = std::slice::from_raw_parts_mut(fl_sc.as_mut_ptr(), s * s);
9929
9930 crate::blas::par_sgemm_bias(
9932 inp,
9933 sl(*qkv_w, base, h * 3 * h),
9934 sl(*qkv_b, base, 3 * h),
9935 qkv,
9936 m,
9937 h,
9938 3 * h,
9939 );
9940
9941 for bi in 0..b {
9943 for hi in 0..n_h {
9944 for qi in 0..s {
9945 for ki in 0..s {
9946 let q_base = bi * s * 3 * h + qi * 3 * h + hi * d_h;
9947 let k_base = bi * s * 3 * h + ki * 3 * h + h + hi * d_h;
9948 #[cfg(target_arch = "aarch64")]
9949 let dot;
9950 #[cfg(not(target_arch = "aarch64"))]
9951 let mut dot = 0f32;
9952 #[cfg(target_arch = "aarch64")]
9953 {
9954 use std::arch::aarch64::*;
9955 let mut acc = vdupq_n_f32(0.0);
9956 for c in 0..neon_chunks {
9957 acc = vfmaq_f32(
9958 acc,
9959 vld1q_f32(qkv.as_ptr().add(q_base + c * 4)),
9960 vld1q_f32(qkv.as_ptr().add(k_base + c * 4)),
9961 );
9962 }
9963 dot = vaddvq_f32(acc);
9964 }
9965 #[cfg(not(target_arch = "aarch64"))]
9966 for d in 0..d_h {
9967 dot += qkv[q_base + d] * qkv[k_base + d];
9968 }
9969 sc[qi * s + ki] = dot * scale;
9970 if mk[bi * s + ki] < mask_thr {
9971 sc[qi * s + ki] = mask_neg;
9972 }
9973 }
9974 }
9975 crate::kernels::neon_softmax(&mut sc[..s * s], s, s);
9976 for qi in 0..s {
9977 let o = bi * s * h + qi * h + hi * d_h;
9978 for d in 0..d_h {
9979 attn[o + d] = 0.0;
9980 }
9981 for ki in 0..s {
9982 let w = sc[qi * s + ki];
9983 if w > score_thr {
9984 let v = bi * s * 3 * h + ki * 3 * h + 2 * h + hi * d_h;
9985 #[cfg(target_arch = "aarch64")]
9986 {
9987 use std::arch::aarch64::*;
9988 let vw = vdupq_n_f32(w);
9989 for c in 0..neon_chunks {
9990 let off = c * 4;
9991 vst1q_f32(
9992 attn.as_mut_ptr().add(o + off),
9993 vfmaq_f32(
9994 vld1q_f32(attn.as_ptr().add(o + off)),
9995 vw,
9996 vld1q_f32(qkv.as_ptr().add(v + off)),
9997 ),
9998 );
9999 }
10000 }
10001 #[cfg(not(target_arch = "aarch64"))]
10002 for d in 0..d_h {
10003 attn[o + d] += w * qkv[v + d];
10004 }
10005 }
10006 }
10007 }
10008 }
10009 }
10010
10011 crate::blas::sgemm_bias(
10013 attn,
10014 sl(*out_w, base, h * h),
10015 sl(*out_b, base, h),
10016 res,
10017 m,
10018 h,
10019 h,
10020 );
10021 #[cfg(target_arch = "aarch64")]
10022 {
10023 use std::arch::aarch64::*;
10024 let chunks_h = (m * h) / 4;
10025 for c in 0..chunks_h {
10026 let off = c * 4;
10027 vst1q_f32(
10028 res.as_mut_ptr().add(off),
10029 vaddq_f32(
10030 vld1q_f32(res.as_ptr().add(off)),
10031 vld1q_f32(inp.as_ptr().add(off)),
10032 ),
10033 );
10034 }
10035 for i in (chunks_h * 4)..(m * h) {
10036 res[i] += inp[i];
10037 }
10038 }
10039 #[cfg(not(target_arch = "aarch64"))]
10040 for i in 0..m * h {
10041 res[i] += inp[i];
10042 }
10043
10044 let g1 = sl(*ln1_g, base, h);
10046 let b1 = sl(*ln1_b, base, h);
10047 for r in 0..m {
10048 crate::kernels::layer_norm_row(
10049 &res[r * h..(r + 1) * h],
10050 g1,
10051 b1,
10052 &mut normed[r * h..(r + 1) * h],
10053 h,
10054 *eps1,
10055 );
10056 }
10057
10058 crate::blas::par_sgemm_bias(
10060 normed,
10061 sl(*fc1_w, base, h * id),
10062 sl(*fc1_b, base, id),
10063 ffn,
10064 m,
10065 h,
10066 id,
10067 );
10068 crate::kernels::par_gelu_inplace(ffn);
10069
10070 crate::blas::par_sgemm_bias(
10072 ffn,
10073 sl(*fc2_w, base, id * h),
10074 sl(*fc2_b, base, h),
10075 res,
10076 m,
10077 id,
10078 h,
10079 );
10080 #[cfg(target_arch = "aarch64")]
10081 {
10082 use std::arch::aarch64::*;
10083 let chunks_h = (m * h) / 4;
10084 for c in 0..chunks_h {
10085 let off = c * 4;
10086 vst1q_f32(
10087 res.as_mut_ptr().add(off),
10088 vaddq_f32(
10089 vld1q_f32(res.as_ptr().add(off)),
10090 vld1q_f32(normed.as_ptr().add(off)),
10091 ),
10092 );
10093 }
10094 for i in (chunks_h * 4)..(m * h) {
10095 res[i] += normed[i];
10096 }
10097 }
10098 #[cfg(not(target_arch = "aarch64"))]
10099 for i in 0..m * h {
10100 res[i] += normed[i];
10101 }
10102
10103 let g2 = sl(*ln2_g, base, h);
10105 let b2 = sl(*ln2_b, base, h);
10106 for r in 0..m {
10107 crate::kernels::layer_norm_row(
10108 &res[r * h..(r + 1) * h],
10109 g2,
10110 b2,
10111 &mut dst[r * h..(r + 1) * h],
10112 h,
10113 *eps2,
10114 );
10115 }
10116 }
10117 }
10118
10119 Thunk::FusedNomicLayer {
10120 hidden,
10121 qkv_w,
10122 out_w,
10123 mask,
10124 cos,
10125 sin,
10126 cos_len,
10127 ln1_g,
10128 ln1_b,
10129 eps1,
10130 fc11_w,
10131 fc12_w: _,
10132 fc2_w,
10133 ln2_g,
10134 ln2_b,
10135 eps2,
10136 out,
10137 batch,
10138 seq,
10139 hs,
10140 nh,
10141 dh,
10142 int_dim,
10143 } => {
10144 let (b, s, h, n_h, d_h) = (
10145 *batch as usize,
10146 *seq as usize,
10147 *hs as usize,
10148 *nh as usize,
10149 *dh as usize,
10150 );
10151 let m = b * s;
10152 let id = *int_dim as usize;
10153 let scale = (d_h as f32).powf(-0.5);
10154 let half_dh = d_h / 2;
10155 #[cfg(target_arch = "aarch64")]
10156 let neon_chunks = d_h / 4;
10157 unsafe {
10158 let inp = sl(*hidden, base, m * h);
10159 let dst = sl_mut(*out, base, m * h);
10160 let mk = sl(*mask, base, b * s);
10161 let cos_tab = sl(*cos, base, *cos_len as usize);
10162 let sin_tab = sl(*sin, base, *cos_len as usize);
10163 let fused_fc_w = sl(*fc11_w, base, h * 2 * id);
10165
10166 let mut qkv = vec![0f32; m * 3 * h];
10167 let mut attn = vec![0f32; m * h];
10168 let mut res = vec![0f32; m * h];
10169 let mut normed = vec![0f32; m * h];
10170 let mut ffn_concat = vec![0f32; m * 2 * id]; let mut sc = vec![0f32; s * s];
10172
10173 crate::blas::sgemm(inp, sl(*qkv_w, base, h * 3 * h), &mut qkv, m, h, 3 * h);
10175
10176 for bi in 0..b {
10178 for hi in 0..n_h {
10179 for qi in 0..s {
10180 for ki in 0..s {
10181 let q_base = bi * s * 3 * h + qi * 3 * h + hi * d_h;
10182 let k_base = bi * s * 3 * h + ki * 3 * h + h + hi * d_h;
10183 let mut dot = 0f32;
10184 for i in 0..half_dh {
10185 let q1 = qkv[q_base + i];
10186 let q2 = qkv[q_base + half_dh + i];
10187 let k1 = qkv[k_base + i];
10188 let k2 = qkv[k_base + half_dh + i];
10189 let cq = cos_tab[qi * half_dh + i];
10190 let sq = sin_tab[qi * half_dh + i];
10191 let ck = cos_tab[ki * half_dh + i];
10192 let sk = sin_tab[ki * half_dh + i];
10193 dot += (q1 * cq - q2 * sq) * (k1 * ck - k2 * sk)
10194 + (q2 * cq + q1 * sq) * (k2 * ck + k1 * sk);
10195 }
10196 sc[qi * s + ki] = dot * scale;
10197 if mk[bi * s + ki] < mask_thr {
10198 sc[qi * s + ki] = mask_neg;
10199 }
10200 }
10201 }
10202 crate::kernels::neon_softmax(&mut sc[..s * s], s, s);
10203 for qi in 0..s {
10204 let o = bi * s * h + qi * h + hi * d_h;
10205 for d in 0..d_h {
10206 attn[o + d] = 0.0;
10207 }
10208 for ki in 0..s {
10209 let w = sc[qi * s + ki];
10210 if w > score_thr {
10211 let v = bi * s * 3 * h + ki * 3 * h + 2 * h + hi * d_h;
10212 #[cfg(target_arch = "aarch64")]
10213 {
10214 use std::arch::aarch64::*;
10215 let vw = vdupq_n_f32(w);
10216 for c in 0..neon_chunks {
10217 let off = c * 4;
10218 vst1q_f32(
10219 attn.as_mut_ptr().add(o + off),
10220 vfmaq_f32(
10221 vld1q_f32(attn.as_ptr().add(o + off)),
10222 vw,
10223 vld1q_f32(qkv.as_ptr().add(v + off)),
10224 ),
10225 );
10226 }
10227 }
10228 #[cfg(not(target_arch = "aarch64"))]
10229 for d in 0..d_h {
10230 attn[o + d] += w * qkv[v + d];
10231 }
10232 }
10233 }
10234 }
10235 }
10236 }
10237
10238 crate::blas::sgemm(&attn, sl(*out_w, base, h * h), &mut res, m, h, h);
10240 for i in 0..m * h {
10241 res[i] += inp[i];
10242 }
10243
10244 let g1 = sl(*ln1_g, base, h);
10246 let b1 = sl(*ln1_b, base, h);
10247 for r in 0..m {
10248 crate::kernels::layer_norm_row(
10249 &res[r * h..(r + 1) * h],
10250 g1,
10251 b1,
10252 &mut normed[r * h..(r + 1) * h],
10253 h,
10254 *eps1,
10255 );
10256 }
10257
10258 crate::blas::sgemm(&normed, fused_fc_w, &mut ffn_concat, m, h, 2 * id);
10260 for row in 0..m {
10263 let bo = row * 2 * id;
10264 for j in 0..id {
10266 let x = ffn_concat[bo + id + j];
10267 ffn_concat[bo + id + j] = x / (1.0 + (-x).exp());
10268 }
10269 for j in 0..id {
10271 ffn_concat[bo + j] *= ffn_concat[bo + id + j];
10272 }
10273 }
10274
10275 crate::blas::sgemm_general(
10280 ffn_concat.as_ptr(),
10281 sl(*fc2_w, base, id * h).as_ptr(),
10282 res.as_mut_ptr(),
10283 m,
10284 h,
10285 id,
10286 1.0,
10287 0.0,
10288 2 * id,
10289 h,
10290 h,
10291 false,
10292 false,
10293 );
10294 for i in 0..m * h {
10295 res[i] += normed[i];
10296 }
10297
10298 let g2 = sl(*ln2_g, base, h);
10300 let b2 = sl(*ln2_b, base, h);
10301 for r in 0..m {
10302 crate::kernels::layer_norm_row(
10303 &res[r * h..(r + 1) * h],
10304 g2,
10305 b2,
10306 &mut dst[r * h..(r + 1) * h],
10307 h,
10308 *eps2,
10309 );
10310 }
10311 }
10312 }
10313
10314 Thunk::FusedSwiGLU {
10315 src,
10316 dst,
10317 n_half,
10318 total,
10319 gate_first,
10320 } => {
10321 let n = *n_half as usize;
10322 let t = *total as usize;
10323 let outer = t / n;
10324 let in_total = outer * 2 * n;
10325 let gate_first = *gate_first;
10326 unsafe {
10327 let inp = sl(*src, base, in_total);
10328 let out = sl_mut(*dst, base, t);
10329 for o in 0..outer {
10330 let in_row = &inp[o * 2 * n..(o + 1) * 2 * n];
10331 let out_row = &mut out[o * n..(o + 1) * n];
10332 for i in 0..n {
10333 let (up, gate) = if gate_first {
10334 (in_row[n + i], in_row[i])
10335 } else {
10336 (in_row[i], in_row[n + i])
10337 };
10338 out_row[i] = up * (gate / (1.0 + (-gate).exp()));
10339 }
10340 }
10341 }
10342 }
10343
10344 Thunk::Concat {
10345 dst,
10346 outer,
10347 inner,
10348 total_axis,
10349 inputs,
10350 } => {
10351 let outer = *outer as usize;
10352 let inner = *inner as usize;
10353 let total_axis = *total_axis as usize;
10354 let row_stride = total_axis * inner;
10355 let out_total = outer * row_stride;
10356 unsafe {
10357 let out = sl_mut(*dst, base, out_total);
10358 let mut cum: usize = 0;
10359 for (src_off, in_axis) in inputs {
10360 let in_axis = *in_axis as usize;
10361 let copy_per_row = in_axis * inner;
10362 let dst_col_off = cum * inner;
10363 let in_total = outer * copy_per_row;
10364 let inp = sl(*src_off, base, in_total);
10365 for o in 0..outer {
10366 let dst_row_start = o * row_stride + dst_col_off;
10367 let src_row_start = o * copy_per_row;
10368 out[dst_row_start..dst_row_start + copy_per_row]
10369 .copy_from_slice(&inp[src_row_start..src_row_start + copy_per_row]);
10370 }
10371 cum += in_axis;
10372 }
10373 }
10374 }
10375
10376 Thunk::ConcatF64 {
10377 dst,
10378 outer,
10379 inner,
10380 total_axis,
10381 inputs,
10382 } => {
10383 let outer = *outer as usize;
10384 let inner = *inner as usize;
10385 let total_axis = *total_axis as usize;
10386 let row_stride = total_axis * inner;
10387 let out_total = outer * row_stride;
10388 unsafe {
10389 let out = sl_mut_f64(*dst, base, out_total);
10390 let mut cum: usize = 0;
10391 for (src_off, in_axis) in inputs {
10392 let in_axis = *in_axis as usize;
10393 let copy_per_row = in_axis * inner;
10394 let dst_col_off = cum * inner;
10395 let in_total = outer * copy_per_row;
10396 let inp = sl_f64(*src_off, base, in_total);
10397 for o in 0..outer {
10398 let dst_row_start = o * row_stride + dst_col_off;
10399 let src_row_start = o * copy_per_row;
10400 out[dst_row_start..dst_row_start + copy_per_row]
10401 .copy_from_slice(&inp[src_row_start..src_row_start + copy_per_row]);
10402 }
10403 cum += in_axis;
10404 }
10405 }
10406 }
10407
10408 Thunk::Compare {
10409 lhs,
10410 rhs,
10411 dst,
10412 len,
10413 op,
10414 } => {
10415 let len = *len as usize;
10416 unsafe {
10417 let l = sl(*lhs, base, len);
10418 let r = sl(*rhs, base, len);
10419 let o = sl_mut(*dst, base, len);
10420 for i in 0..len {
10421 o[i] = match op {
10422 CmpOp::Eq => (l[i] == r[i]) as u32 as f32,
10423 CmpOp::Ne => (l[i] != r[i]) as u32 as f32,
10424 CmpOp::Lt => (l[i] < r[i]) as u32 as f32,
10425 CmpOp::Le => (l[i] <= r[i]) as u32 as f32,
10426 CmpOp::Gt => (l[i] > r[i]) as u32 as f32,
10427 CmpOp::Ge => (l[i] >= r[i]) as u32 as f32,
10428 };
10429 }
10430 }
10431 }
10432
10433 Thunk::Where {
10434 cond,
10435 on_true,
10436 on_false,
10437 dst,
10438 len,
10439 } => {
10440 let len = *len as usize;
10441 unsafe {
10442 let c = sl(*cond, base, len);
10443 let t = sl(*on_true, base, len);
10444 let e = sl(*on_false, base, len);
10445 let o = sl_mut(*dst, base, len);
10446 for i in 0..len {
10447 o[i] = if c[i] != 0.0 { t[i] } else { e[i] };
10449 }
10450 }
10451 }
10452
10453 Thunk::ScatterAdd {
10454 updates,
10455 indices,
10456 dst,
10457 num_updates,
10458 out_dim,
10459 trailing,
10460 } => {
10461 let num_updates = *num_updates as usize;
10462 let out_dim = *out_dim as usize;
10463 let trailing = *trailing as usize;
10464 unsafe {
10465 let upd = sl(*updates, base, num_updates * trailing);
10466 let ids = sl(*indices, base, num_updates);
10467 let out = sl_mut(*dst, base, out_dim * trailing);
10468 for v in out.iter_mut() {
10470 *v = 0.0;
10471 }
10472 for i in 0..num_updates {
10473 let row = ids[i] as usize;
10474 debug_assert!(row < out_dim, "ScatterAdd index out of range");
10475 let src_off = i * trailing;
10476 let dst_off = row * trailing;
10477 for j in 0..trailing {
10478 out[dst_off + j] += upd[src_off + j];
10479 }
10480 }
10481 }
10482 }
10483
10484 Thunk::GroupedMatMul {
10485 input,
10486 weight,
10487 expert_idx,
10488 dst,
10489 m,
10490 k_dim,
10491 n,
10492 num_experts,
10493 } => {
10494 let m = *m as usize;
10495 let k_dim = *k_dim as usize;
10496 let n = *n as usize;
10497 let num_experts = *num_experts as usize;
10498 unsafe {
10499 let inp = sl(*input, base, m * k_dim);
10500 let wt = sl(*weight, base, num_experts * k_dim * n);
10501 let ids = sl(*expert_idx, base, m);
10502 let out = sl_mut(*dst, base, m * n);
10503
10504 let mut counts = vec![0usize; num_experts];
10507 for i in 0..m {
10508 let e = ids[i] as usize;
10509 debug_assert!(
10510 e < num_experts,
10511 "expert_idx out of range: {e} >= {num_experts}"
10512 );
10513 counts[e] += 1;
10514 }
10515 let mut offsets = vec![0usize; num_experts + 1];
10517 for e in 0..num_experts {
10518 offsets[e + 1] = offsets[e] + counts[e];
10519 }
10520 let mut packed_in = vec![0f32; m * k_dim];
10524 let mut original_pos = vec![0usize; m];
10525 let mut write_idx = vec![0usize; num_experts];
10526 for i in 0..m {
10527 let e = ids[i] as usize;
10528 let dst_row = offsets[e] + write_idx[e];
10529 packed_in[dst_row * k_dim..(dst_row + 1) * k_dim]
10530 .copy_from_slice(&inp[i * k_dim..(i + 1) * k_dim]);
10531 original_pos[dst_row] = i;
10532 write_idx[e] += 1;
10533 }
10534
10535 let mut packed_out = vec![0f32; m * n];
10539 let expert_stride = k_dim * n;
10540 let gmm_ord = crate::moe_residency::next_gmm_ord();
10541 let moe_layer = gmm_ord / 3;
10542 for e in 0..num_experts {
10543 let count = counts[e];
10544 if count == 0 {
10545 continue;
10546 }
10547 crate::moe_residency::record_expert_tokens(moe_layer, e, count);
10548 let in_start = offsets[e];
10549 let in_slice = &packed_in[in_start * k_dim..(in_start + count) * k_dim];
10550 let w_slab: &[f32] =
10551 if !crate::moe_residency::expert_on_device_for_layer(moe_layer, e) {
10552 if let Some(ptr) =
10553 crate::moe_residency::host_expert_weight_ptr(gmm_ord, e)
10554 {
10555 std::slice::from_raw_parts(ptr, expert_stride)
10556 } else {
10557 &wt[e * expert_stride..(e + 1) * expert_stride]
10558 }
10559 } else {
10560 &wt[e * expert_stride..(e + 1) * expert_stride]
10561 };
10562 let out_slice = &mut packed_out[in_start * n..(in_start + count) * n];
10563 crate::blas::sgemm(in_slice, w_slab, out_slice, count, k_dim, n);
10564 }
10565
10566 for packed_idx in 0..m {
10568 let i = original_pos[packed_idx];
10569 out[i * n..(i + 1) * n]
10570 .copy_from_slice(&packed_out[packed_idx * n..(packed_idx + 1) * n]);
10571 }
10572 }
10573 }
10574
10575 Thunk::DequantGroupedMatMulGguf {
10576 input,
10577 w_q,
10578 expert_idx,
10579 dst,
10580 m,
10581 k_dim,
10582 n,
10583 num_experts,
10584 scheme,
10585 } => {
10586 let m = *m as usize;
10587 let k_dim = *k_dim as usize;
10588 let n = *n as usize;
10589 let num_experts = *num_experts as usize;
10590 let block_elems = scheme.gguf_block_size() as usize;
10591 let block_bytes = scheme.gguf_block_bytes() as usize;
10592 let slab_bytes = (k_dim * n) / block_elems * block_bytes;
10593 unsafe {
10594 let inp = sl(*input, base, m * k_dim);
10595 let wt = std::slice::from_raw_parts(
10596 base.add(*w_q) as *const u8,
10597 num_experts * slab_bytes,
10598 );
10599 let ids = sl(*expert_idx, base, m);
10600 let out = sl_mut(*dst, base, m * n);
10601 crate::gguf_matmul::gguf_grouped_matmul_bt(
10602 inp,
10603 wt,
10604 ids,
10605 out,
10606 m,
10607 k_dim,
10608 n,
10609 num_experts,
10610 *scheme,
10611 );
10612 }
10613 }
10614
10615 Thunk::DequantMoEWeightsGguf {
10616 w_q,
10617 dst,
10618 k_dim,
10619 n,
10620 num_experts,
10621 scheme,
10622 } => {
10623 let k_dim = *k_dim as usize;
10624 let n = *n as usize;
10625 let num_experts = *num_experts as usize;
10626 let block_elems = scheme.gguf_block_size() as usize;
10627 let block_bytes = scheme.gguf_block_bytes() as usize;
10628 let slab_bytes = (k_dim * n) / block_elems * block_bytes;
10629 unsafe {
10630 let wt = std::slice::from_raw_parts(
10631 base.add(*w_q) as *const u8,
10632 num_experts * slab_bytes,
10633 );
10634 let out = sl_mut(*dst, base, num_experts * k_dim * n);
10635 crate::gguf_matmul::dequant_moe_weights_to_grouped_f32(
10636 wt,
10637 out,
10638 num_experts,
10639 k_dim,
10640 n,
10641 *scheme,
10642 );
10643 }
10644 }
10645
10646 Thunk::TopK {
10647 src,
10648 dst,
10649 outer,
10650 axis_dim,
10651 k,
10652 } => {
10653 let outer = *outer as usize;
10654 let axis_dim = *axis_dim as usize;
10655 let k = *k as usize;
10656 unsafe {
10657 let inp = sl(*src, base, outer * axis_dim);
10658 let out = sl_mut(*dst, base, outer * k);
10659 let mut row_buf: Vec<f32> = vec![0.0; axis_dim];
10663 for o in 0..outer {
10664 row_buf.copy_from_slice(&inp[o * axis_dim..(o + 1) * axis_dim]);
10665 for ki in 0..k {
10666 let mut best_i = 0usize;
10668 let mut best_v = row_buf[0];
10669 for i in 1..axis_dim {
10670 let v = row_buf[i];
10671 if v > best_v {
10672 best_v = v;
10673 best_i = i;
10674 }
10675 }
10676 out[o * k + ki] = best_i as f32;
10677 row_buf[best_i] = f32::NEG_INFINITY;
10680 }
10681 }
10682 if let Some(cap) = schedule.moe_topk_capture.as_ref() {
10683 cap.push_topk_f32(&out[..outer * k], axis_dim);
10684 }
10685 }
10686 }
10687
10688 Thunk::Reduce {
10689 src,
10690 dst,
10691 outer,
10692 reduced,
10693 inner,
10694 op,
10695 } => {
10696 let outer = *outer as usize;
10697 let reduced = *reduced as usize;
10698 let inner = *inner as usize;
10699 let in_total = outer * reduced * inner;
10700 let out_total = outer * inner;
10701 unsafe {
10702 let inp = sl(*src, base, in_total);
10703 let out = sl_mut(*dst, base, out_total);
10704 for o in 0..outer {
10705 for i in 0..inner {
10706 let mut acc = match op {
10707 ReduceOp::Max => f32::NEG_INFINITY,
10708 ReduceOp::Min => f32::INFINITY,
10709 ReduceOp::Prod => 1.0f32,
10710 _ => 0.0f32, };
10712 for r in 0..reduced {
10714 let v = inp[o * reduced * inner + r * inner + i];
10715 acc = match op {
10716 ReduceOp::Sum | ReduceOp::Mean => acc + v,
10717 ReduceOp::Max => acc.max(v),
10718 ReduceOp::Min => acc.min(v),
10719 ReduceOp::Prod => acc * v,
10720 };
10721 }
10722 if matches!(op, ReduceOp::Mean) {
10723 acc /= reduced as f32;
10724 }
10725 out[o * inner + i] = acc;
10726 }
10727 }
10728 }
10729 }
10730
10731 Thunk::Conv2D1x1 {
10732 src,
10733 weight,
10734 dst,
10735 n,
10736 c_in,
10737 c_out,
10738 hw,
10739 } => {
10740 let n = *n as usize;
10741 let c_in = *c_in as usize;
10742 let c_out = *c_out as usize;
10743 let hw = *hw as usize;
10744 unsafe {
10745 let inp = sl(*src, base, n * c_in * hw);
10746 let wt = sl(*weight, base, c_out * c_in);
10747 let out = sl_mut(*dst, base, n * c_out * hw);
10748 for ni in 0..n {
10753 let in_off = ni * c_in * hw;
10754 let out_off = ni * c_out * hw;
10755 crate::blas::sgemm(
10756 wt,
10757 &inp[in_off..in_off + c_in * hw],
10758 &mut out[out_off..out_off + c_out * hw],
10759 c_out,
10760 c_in,
10761 hw,
10762 );
10763 }
10764 }
10765 }
10766
10767 Thunk::Conv2D {
10768 src,
10769 weight,
10770 dst,
10771 n,
10772 c_in,
10773 h,
10774 w,
10775 c_out,
10776 h_out,
10777 w_out,
10778 kh,
10779 kw,
10780 sh,
10781 sw,
10782 ph,
10783 pw,
10784 dh,
10785 dw,
10786 groups,
10787 } => {
10788 let n = *n as usize;
10789 let c_in = *c_in as usize;
10790 let h = *h as usize;
10791 let w = *w as usize;
10792 let c_out = *c_out as usize;
10793 let h_out = *h_out as usize;
10794 let w_out = *w_out as usize;
10795 let kh = *kh as usize;
10796 let kw = *kw as usize;
10797 let sh = *sh as usize;
10798 let sw = *sw as usize;
10799 let ph = *ph as usize;
10800 let pw = *pw as usize;
10801 let dh = *dh as usize;
10802 let dw = *dw as usize;
10803 let groups = *groups as usize;
10804 let c_in_per_g = c_in / groups;
10805 let c_out_per_g = c_out / groups;
10806 unsafe {
10807 let inp = sl(*src, base, n * c_in * h * w);
10808 let wt = sl(*weight, base, c_out * c_in_per_g * kh * kw);
10809 let out = sl_mut(*dst, base, n * c_out * h_out * w_out);
10810 for ni in 0..n {
10811 for co in 0..c_out {
10812 let g = co / c_out_per_g;
10813 let ci_start = g * c_in_per_g;
10814 for ho in 0..h_out {
10815 for wo in 0..w_out {
10816 let mut acc = 0f32;
10817 for ci_off in 0..c_in_per_g {
10818 let ci = ci_start + ci_off;
10819 let in_chan = ((ni * c_in) + ci) * h * w;
10820 let wt_chan = ((co * c_in_per_g) + ci_off) * kh * kw;
10821 for ki in 0..kh {
10822 for kj in 0..kw {
10823 let hi = ho * sh + ki * dh;
10824 let wi = wo * sw + kj * dw;
10825 if hi < ph || wi < pw {
10826 continue;
10827 }
10828 let hi = hi - ph;
10829 let wi = wi - pw;
10830 if hi >= h || wi >= w {
10831 continue;
10832 }
10833 acc += inp[in_chan + hi * w + wi]
10834 * wt[wt_chan + ki * kw + kj];
10835 }
10836 }
10837 }
10838 out[((ni * c_out) + co) * h_out * w_out + ho * w_out + wo] =
10839 acc;
10840 }
10841 }
10842 }
10843 }
10844 }
10845 }
10846
10847 Thunk::Pool2D {
10848 src,
10849 dst,
10850 n,
10851 c,
10852 h,
10853 w,
10854 h_out,
10855 w_out,
10856 kh,
10857 kw,
10858 sh,
10859 sw,
10860 ph,
10861 pw,
10862 kind,
10863 } => {
10864 let n = *n as usize;
10865 let c = *c as usize;
10866 let h = *h as usize;
10867 let w = *w as usize;
10868 let h_out = *h_out as usize;
10869 let w_out = *w_out as usize;
10870 let kh = *kh as usize;
10871 let kw = *kw as usize;
10872 let sh = *sh as usize;
10873 let sw = *sw as usize;
10874 let ph = *ph as usize;
10875 let pw = *pw as usize;
10876 let kernel_area = (kh * kw) as f32;
10877 unsafe {
10878 let inp = sl(*src, base, n * c * h * w);
10879 let out = sl_mut(*dst, base, n * c * h_out * w_out);
10880 for ni in 0..n {
10881 for ci in 0..c {
10882 let in_chan = ni * c * h * w + ci * h * w;
10883 let out_chan = ni * c * h_out * w_out + ci * h_out * w_out;
10884 for ho in 0..h_out {
10885 for wo in 0..w_out {
10886 let mut acc = match kind {
10887 ReduceOp::Max => f32::NEG_INFINITY,
10888 _ => 0f32, };
10890 for ki in 0..kh {
10891 for kj in 0..kw {
10892 let hi = ho * sh + ki;
10893 let wi = wo * sw + kj;
10894 if hi < ph || wi < pw {
10896 continue;
10897 }
10898 let hi = hi - ph;
10899 let wi = wi - pw;
10900 if hi >= h || wi >= w {
10901 continue;
10902 }
10903 let v = inp[in_chan + hi * w + wi];
10904 match kind {
10905 ReduceOp::Max => acc = acc.max(v),
10906 _ => acc += v,
10907 }
10908 }
10909 }
10910 if matches!(kind, ReduceOp::Mean) {
10911 acc /= kernel_area;
10912 }
10913 out[out_chan + ho * w_out + wo] = acc;
10914 }
10915 }
10916 }
10917 }
10918 }
10919 }
10920
10921 Thunk::ReluBackward { x, dy, dx, len } => {
10922 let len = *len as usize;
10923 unsafe {
10924 let xs = sl(*x, base, len);
10925 let dys = sl(*dy, base, len);
10926 let out = sl_mut(*dx, base, len);
10927 for i in 0..len {
10928 out[i] = if xs[i] > 0.0 { dys[i] } else { 0.0 };
10929 }
10930 }
10931 }
10932
10933 Thunk::ReluBackwardF64 { x, dy, dx, len } => {
10934 let len = *len as usize;
10935 unsafe {
10936 let xs = sl_f64(*x, base, len);
10937 let dys = sl_f64(*dy, base, len);
10938 let out = sl_mut_f64(*dx, base, len);
10939 for i in 0..len {
10940 out[i] = if xs[i] > 0.0 { dys[i] } else { 0.0 };
10941 }
10942 }
10943 }
10944
10945 Thunk::QMatMul {
10946 x,
10947 w,
10948 bias,
10949 out,
10950 m,
10951 k,
10952 n,
10953 x_zp,
10954 w_zp,
10955 out_zp,
10956 mult,
10957 } => {
10958 let m = *m as usize;
10959 let k = *k as usize;
10960 let n = *n as usize;
10961 unsafe {
10962 let x_ptr = base.add(*x) as *const i8;
10963 let w_ptr = base.add(*w) as *const i8;
10964 let bias_ptr = base.add(*bias) as *const i32;
10965 let out_ptr = base.add(*out) as *mut i8;
10966 for mi in 0..m {
10967 for ni in 0..n {
10968 let mut acc: i32 = *bias_ptr.add(ni);
10969 for ki in 0..k {
10970 let xv = *x_ptr.add(mi * k + ki) as i32 - *x_zp;
10971 let wv = *w_ptr.add(ki * n + ni) as i32 - *w_zp;
10972 acc += xv * wv;
10973 }
10974 let r = (acc as f32 * *mult).round() as i32 + *out_zp;
10977 let r = r.clamp(-128, 127) as i8;
10978 *out_ptr.add(mi * n + ni) = r;
10979 }
10980 }
10981 }
10982 }
10983
10984 Thunk::QConv2d {
10985 x,
10986 w,
10987 bias,
10988 out,
10989 n,
10990 c_in,
10991 h,
10992 w_in,
10993 c_out,
10994 h_out,
10995 w_out,
10996 kh,
10997 kw,
10998 sh,
10999 sw,
11000 ph,
11001 pw,
11002 dh,
11003 dw,
11004 groups,
11005 x_zp,
11006 w_zp,
11007 out_zp,
11008 mult,
11009 } => {
11010 let n = *n as usize;
11011 let c_in = *c_in as usize;
11012 let h = *h as usize;
11013 let w_in = *w_in as usize;
11014 let c_out = *c_out as usize;
11015 let h_out = *h_out as usize;
11016 let w_out = *w_out as usize;
11017 let kh = *kh as usize;
11018 let kw = *kw as usize;
11019 let sh = *sh as usize;
11020 let sw = *sw as usize;
11021 let ph = *ph as usize;
11022 let pw = *pw as usize;
11023 let dh = *dh as usize;
11024 let dw = *dw as usize;
11025 let groups = *groups as usize;
11026 let c_in_per_g = c_in / groups;
11027 let c_out_per_g = c_out / groups;
11028 unsafe {
11029 let x_ptr = base.add(*x) as *const i8;
11030 let w_ptr = base.add(*w) as *const i8;
11031 let bias_ptr = base.add(*bias) as *const i32;
11032 let out_ptr = base.add(*out) as *mut i8;
11033 for ni in 0..n {
11034 for co in 0..c_out {
11035 let g = co / c_out_per_g;
11036 let ci_start = g * c_in_per_g;
11037 for ho in 0..h_out {
11038 for wo in 0..w_out {
11039 let mut acc: i32 = *bias_ptr.add(co);
11040 for ci_off in 0..c_in_per_g {
11041 let ci = ci_start + ci_off;
11042 let in_chan = ((ni * c_in) + ci) * h * w_in;
11043 let wt_chan = ((co * c_in_per_g) + ci_off) * kh * kw;
11044 for ki in 0..kh {
11045 for kj in 0..kw {
11046 let hi = ho * sh + ki * dh;
11047 let wi = wo * sw + kj * dw;
11048 if hi < ph || wi < pw {
11049 continue;
11050 }
11051 let hi = hi - ph;
11052 let wi = wi - pw;
11053 if hi >= h || wi >= w_in {
11054 continue;
11055 }
11056 let xv = *x_ptr.add(in_chan + hi * w_in + wi)
11057 as i32
11058 - *x_zp;
11059 let wv = *w_ptr.add(wt_chan + ki * kw + kj) as i32
11060 - *w_zp;
11061 acc += xv * wv;
11062 }
11063 }
11064 }
11065 let r = (acc as f32 * *mult).round() as i32 + *out_zp;
11066 let r = r.clamp(-128, 127) as i8;
11067 let dst = ((ni * c_out) + co) * h_out * w_out + ho * w_out + wo;
11068 *out_ptr.add(dst) = r;
11069 }
11070 }
11071 }
11072 }
11073 }
11074 }
11075
11076 Thunk::Quantize {
11077 x,
11078 q,
11079 len,
11080 chan_axis: _,
11081 chan_dim,
11082 inner,
11083 scales,
11084 zero_points,
11085 } => {
11086 let len = *len as usize;
11087 let chan_dim = *chan_dim as usize;
11088 let inner = *inner as usize;
11089 unsafe {
11090 let xs = sl(*x, base, len);
11091 let q_ptr = base.add(*q) as *mut i8;
11092 for i in 0..len {
11093 let c = if chan_dim == 1 {
11094 0
11095 } else {
11096 (i / inner) % chan_dim
11097 };
11098 let inv_scale = 1.0 / scales[c];
11099 let zp = zero_points[c];
11100 let v = (xs[i] * inv_scale).round() as i32 + zp;
11101 *q_ptr.add(i) = v.clamp(-128, 127) as i8;
11102 }
11103 }
11104 }
11105
11106 Thunk::Dequantize {
11107 q,
11108 x,
11109 len,
11110 chan_axis: _,
11111 chan_dim,
11112 inner,
11113 scales,
11114 zero_points,
11115 } => {
11116 let len = *len as usize;
11117 let chan_dim = *chan_dim as usize;
11118 let inner = *inner as usize;
11119 unsafe {
11120 let q_ptr = base.add(*q) as *const i8;
11121 let out = sl_mut(*x, base, len);
11122 for i in 0..len {
11123 let c = if chan_dim == 1 {
11124 0
11125 } else {
11126 (i / inner) % chan_dim
11127 };
11128 let scale = scales[c];
11129 let zp = zero_points[c];
11130 let qv = *q_ptr.add(i) as i32;
11131 out[i] = (qv - zp) as f32 * scale;
11132 }
11133 }
11134 }
11135
11136 Thunk::FakeQuantize {
11137 x,
11138 out,
11139 len,
11140 chan_axis: _,
11141 chan_dim,
11142 inner,
11143 bits,
11144 ste: _,
11145 scale_mode,
11146 state_off,
11147 } => {
11148 use rlx_ir::op::ScaleMode;
11149 let len = *len as usize;
11150 let chan_dim = *chan_dim as usize;
11151 let inner = *inner as usize;
11152 let q_max: f32 = match *bits {
11153 8 => 127.0,
11154 4 => 7.0,
11155 2 => 1.0,
11156 n => panic!("FakeQuantize: unsupported bits {n}"),
11157 };
11158 unsafe {
11159 let xs = sl(*x, base, len);
11160 let outs = sl_mut(*out, base, len);
11161
11162 let mut scale = vec![0f32; chan_dim];
11163 match scale_mode {
11164 ScaleMode::PerBatch => {
11165 let mut max_abs = vec![0f32; chan_dim];
11166 for i in 0..len {
11167 let c = if chan_dim == 1 {
11168 0
11169 } else {
11170 (i / inner) % chan_dim
11171 };
11172 let a = xs[i].abs();
11173 if a > max_abs[c] {
11174 max_abs[c] = a;
11175 }
11176 }
11177 for c in 0..chan_dim {
11178 scale[c] = (max_abs[c] / q_max).max(1e-12);
11179 }
11180 }
11181 ScaleMode::EMA { decay } => {
11182 let mut max_abs = vec![0f32; chan_dim];
11185 for i in 0..len {
11186 let c = if chan_dim == 1 {
11187 0
11188 } else {
11189 (i / inner) % chan_dim
11190 };
11191 let a = xs[i].abs();
11192 if a > max_abs[c] {
11193 max_abs[c] = a;
11194 }
11195 }
11196 let state =
11197 sl_mut(state_off.expect("EMA needs state_off"), base, chan_dim);
11198 for c in 0..chan_dim {
11199 let cur = (max_abs[c] / q_max).max(1e-12);
11200 let blended = if state[c] <= 0.0 {
11202 cur
11203 } else {
11204 *decay * state[c] + (1.0 - *decay) * cur
11205 };
11206 state[c] = blended;
11207 scale[c] = blended;
11208 }
11209 }
11210 ScaleMode::Fixed => {
11211 let state =
11212 sl(state_off.expect("Fixed needs state_off"), base, chan_dim);
11213 for c in 0..chan_dim {
11214 scale[c] = state[c].max(1e-12);
11215 }
11216 }
11217 }
11218
11219 for i in 0..len {
11220 let c = if chan_dim == 1 {
11221 0
11222 } else {
11223 (i / inner) % chan_dim
11224 };
11225 let s = scale[c];
11226 let qv = (xs[i] / s).round().clamp(-q_max, q_max);
11227 outs[i] = qv * s;
11228 }
11229 }
11230 }
11231
11232 Thunk::ActivationBackward {
11233 x,
11234 dy,
11235 dx,
11236 len,
11237 kind,
11238 } => {
11239 let len = *len as usize;
11240 unsafe {
11241 let xs = sl(*x, base, len);
11242 let dys = sl(*dy, base, len);
11243 let out = sl_mut(*dx, base, len);
11244 activation_backward_kernel(*kind, xs, dys, out);
11245 }
11246 }
11247
11248 Thunk::ActivationBackwardF64 {
11249 x,
11250 dy,
11251 dx,
11252 len,
11253 kind,
11254 } => {
11255 let len = *len as usize;
11256 unsafe {
11257 let xs = sl_f64(*x, base, len);
11258 let dys = sl_f64(*dy, base, len);
11259 let out = sl_mut_f64(*dx, base, len);
11260 activation_backward_kernel_f64(*kind, xs, dys, out);
11261 }
11262 }
11263
11264 Thunk::FakeQuantizeLSQ {
11265 x,
11266 scale_off,
11267 out,
11268 len,
11269 chan_axis: _,
11270 chan_dim,
11271 inner,
11272 bits,
11273 } => {
11274 let len = *len as usize;
11275 let chan_dim = *chan_dim as usize;
11276 let inner = *inner as usize;
11277 let q_max: f32 = match *bits {
11278 8 => 127.0,
11279 4 => 7.0,
11280 2 => 1.0,
11281 n => panic!("FakeQuantizeLSQ: bad bits {n}"),
11282 };
11283 unsafe {
11284 let xs = sl(*x, base, len);
11285 let scale = sl(*scale_off, base, chan_dim);
11286 let outs = sl_mut(*out, base, len);
11287 for i in 0..len {
11288 let c = if chan_dim == 1 {
11289 0
11290 } else {
11291 (i / inner) % chan_dim
11292 };
11293 let s = scale[c].max(1e-12);
11294 let qv = (xs[i] / s).round().clamp(-q_max, q_max);
11295 outs[i] = qv * s;
11296 }
11297 }
11298 }
11299
11300 Thunk::FakeQuantizeLSQBackwardX {
11301 x,
11302 scale_off,
11303 dy,
11304 dx,
11305 len,
11306 chan_axis: _,
11307 chan_dim,
11308 inner,
11309 bits,
11310 } => {
11311 let len = *len as usize;
11312 let chan_dim = *chan_dim as usize;
11313 let inner = *inner as usize;
11314 let q_max: f32 = match *bits {
11315 8 => 127.0,
11316 4 => 7.0,
11317 2 => 1.0,
11318 n => panic!("FakeQuantizeLSQBackwardX: bad bits {n}"),
11319 };
11320 unsafe {
11321 let xs = sl(*x, base, len);
11322 let scale = sl(*scale_off, base, chan_dim);
11323 let dys = sl(*dy, base, len);
11324 let outs = sl_mut(*dx, base, len);
11325 for i in 0..len {
11327 let c = if chan_dim == 1 {
11328 0
11329 } else {
11330 (i / inner) % chan_dim
11331 };
11332 let z = xs[i] / scale[c].max(1e-12);
11333 outs[i] = if z.abs() <= q_max { dys[i] } else { 0.0 };
11334 }
11335 }
11336 }
11337
11338 Thunk::FakeQuantizeLSQBackwardScale {
11339 x,
11340 scale_off,
11341 dy,
11342 dscale,
11343 len,
11344 chan_axis: _,
11345 chan_dim,
11346 inner,
11347 bits,
11348 } => {
11349 let len = *len as usize;
11350 let chan_dim = *chan_dim as usize;
11351 let inner = *inner as usize;
11352 let q_max: f32 = match *bits {
11353 8 => 127.0,
11354 4 => 7.0,
11355 2 => 1.0,
11356 n => panic!("FakeQuantizeLSQBackwardScale: bad bits {n}"),
11357 };
11358 unsafe {
11359 let xs = sl(*x, base, len);
11360 let scale = sl(*scale_off, base, chan_dim);
11361 let dys = sl(*dy, base, len);
11362 let outs = sl_mut(*dscale, base, chan_dim);
11363 for v in outs.iter_mut() {
11364 *v = 0.0;
11365 }
11366 for i in 0..len {
11369 let c = if chan_dim == 1 {
11370 0
11371 } else {
11372 (i / inner) % chan_dim
11373 };
11374 let s = scale[c].max(1e-12);
11375 let z = xs[i] / s;
11376 let psi = if z.abs() <= q_max {
11377 -z + z.round()
11378 } else if z > 0.0 {
11379 q_max
11380 } else {
11381 -q_max
11382 };
11383 outs[c] += psi * dys[i];
11384 }
11385 }
11386 }
11387
11388 Thunk::FakeQuantizeBackward {
11389 x,
11390 dy,
11391 dx,
11392 len,
11393 chan_axis: _,
11394 chan_dim,
11395 inner,
11396 bits,
11397 ste,
11398 } => {
11399 use rlx_ir::op::SteKind;
11400 let len = *len as usize;
11401 let chan_dim = *chan_dim as usize;
11402 let inner = *inner as usize;
11403 let q_max: f32 = match *bits {
11404 8 => 127.0,
11405 4 => 7.0,
11406 2 => 1.0,
11407 n => panic!("FakeQuantizeBackward: bad bits {n}"),
11408 };
11409 unsafe {
11410 let xs = sl(*x, base, len);
11411 let dys = sl(*dy, base, len);
11412 let outs = sl_mut(*dx, base, len);
11413
11414 let mut max_abs = vec![0f32; chan_dim];
11416 for i in 0..len {
11417 let c = if chan_dim == 1 {
11418 0
11419 } else {
11420 (i / inner) % chan_dim
11421 };
11422 let a = xs[i].abs();
11423 if a > max_abs[c] {
11424 max_abs[c] = a;
11425 }
11426 }
11427 let mut scale = vec![0f32; chan_dim];
11428 for c in 0..chan_dim {
11429 scale[c] = (max_abs[c] / q_max).max(1e-12);
11430 }
11431
11432 match *ste {
11433 SteKind::Identity => {
11434 outs.copy_from_slice(dys);
11436 }
11437 SteKind::ClippedIdentity => {
11438 for i in 0..len {
11441 let c = if chan_dim == 1 {
11442 0
11443 } else {
11444 (i / inner) % chan_dim
11445 };
11446 let bound = q_max * scale[c];
11447 outs[i] = if xs[i].abs() <= bound { dys[i] } else { 0.0 };
11448 }
11449 }
11450 SteKind::Tanh => {
11451 for i in 0..len {
11453 let c = if chan_dim == 1 {
11454 0
11455 } else {
11456 (i / inner) % chan_dim
11457 };
11458 let t = (xs[i] / scale[c]).tanh();
11459 outs[i] = dys[i] * (1.0 - t * t);
11460 }
11461 }
11462 SteKind::HardTanh => {
11463 for i in 0..len {
11465 let c = if chan_dim == 1 {
11466 0
11467 } else {
11468 (i / inner) % chan_dim
11469 };
11470 let bound = q_max * scale[c];
11471 let attenuation = (1.0 - (xs[i] / bound).abs()).max(0.0);
11472 outs[i] = dys[i] * attenuation;
11473 }
11474 }
11475 }
11476 }
11477 }
11478
11479 Thunk::LayerNormBackwardInput {
11480 x,
11481 gamma,
11482 dy,
11483 dx,
11484 rows,
11485 h,
11486 eps,
11487 } => {
11488 let rows = *rows as usize;
11489 let h = *h as usize;
11490 let eps = *eps;
11491 unsafe {
11492 let xs = sl(*x, base, rows * h);
11493 let g = sl(*gamma, base, h);
11494 let dys = sl(*dy, base, rows * h);
11495 let out = sl_mut(*dx, base, rows * h);
11496 let n_inv = 1.0 / h as f32;
11497 for r in 0..rows {
11498 let xr = &xs[r * h..(r + 1) * h];
11499 let dyr = &dys[r * h..(r + 1) * h];
11500 let mut sum = 0f32;
11503 for &v in xr {
11504 sum += v;
11505 }
11506 let mean = sum * n_inv;
11507 let mut var = 0f32;
11508 for &v in xr {
11509 let d = v - mean;
11510 var += d * d;
11511 }
11512 let inv_std = 1.0 / (var * n_inv + eps).sqrt();
11513
11514 let mut s_sy = 0f32;
11517 let mut s_sxh = 0f32;
11518 for d in 0..h {
11519 let xh = (xr[d] - mean) * inv_std;
11520 let sy = dyr[d] * g[d];
11521 s_sy += sy;
11522 s_sxh += sy * xh;
11523 }
11524 let m_sy = s_sy * n_inv;
11525 let m_sxh = s_sxh * n_inv;
11526
11527 for d in 0..h {
11528 let xh = (xr[d] - mean) * inv_std;
11529 let sy = dyr[d] * g[d];
11530 out[r * h + d] = inv_std * (sy - m_sy - xh * m_sxh);
11531 }
11532 }
11533 }
11534 }
11535
11536 Thunk::LayerNormBackwardGamma {
11537 x,
11538 dy,
11539 dgamma,
11540 rows,
11541 h,
11542 eps,
11543 } => {
11544 let rows = *rows as usize;
11545 let h = *h as usize;
11546 let eps = *eps;
11547 unsafe {
11548 let xs = sl(*x, base, rows * h);
11549 let dys = sl(*dy, base, rows * h);
11550 let out = sl_mut(*dgamma, base, h);
11551 for v in out.iter_mut() {
11552 *v = 0.0;
11553 }
11554 let n_inv = 1.0 / h as f32;
11555 for r in 0..rows {
11556 let xr = &xs[r * h..(r + 1) * h];
11557 let dyr = &dys[r * h..(r + 1) * h];
11558 let mut sum = 0f32;
11559 for &v in xr {
11560 sum += v;
11561 }
11562 let mean = sum * n_inv;
11563 let mut var = 0f32;
11564 for &v in xr {
11565 let d = v - mean;
11566 var += d * d;
11567 }
11568 let inv_std = 1.0 / (var * n_inv + eps).sqrt();
11569 for d in 0..h {
11570 let xh = (xr[d] - mean) * inv_std;
11571 out[d] += dyr[d] * xh;
11572 }
11573 }
11574 }
11575 }
11576
11577 Thunk::RmsNormBackwardInput {
11578 x,
11579 gamma,
11580 beta,
11581 dy,
11582 dx,
11583 rows,
11584 h,
11585 eps,
11586 } => {
11587 let (rows, h) = (*rows as usize, *h as usize);
11588 unsafe {
11589 let xs = sl(*x, base, rows * h);
11590 let g = sl(*gamma, base, h);
11591 let b = sl(*beta, base, h);
11592 let dys = sl(*dy, base, rows * h);
11593 let out = sl_mut(*dx, base, rows * h);
11594 let mut dg = vec![0f32; h];
11595 let mut db = vec![0f32; h];
11596 for r in 0..rows {
11597 crate::training_bwd::rms_norm_backward_row(
11598 &xs[r * h..(r + 1) * h],
11599 g,
11600 b,
11601 &dys[r * h..(r + 1) * h],
11602 &mut out[r * h..(r + 1) * h],
11603 &mut dg,
11604 &mut db,
11605 *eps,
11606 );
11607 }
11608 }
11609 }
11610
11611 Thunk::RmsNormBackwardGamma {
11612 x,
11613 gamma,
11614 beta,
11615 dy,
11616 dgamma,
11617 rows,
11618 h,
11619 eps,
11620 } => {
11621 let (rows, h) = (*rows as usize, *h as usize);
11622 unsafe {
11623 let xs = sl(*x, base, rows * h);
11624 let g = sl(*gamma, base, h);
11625 let b = sl(*beta, base, h);
11626 let dys = sl(*dy, base, rows * h);
11627 let out = sl_mut(*dgamma, base, h);
11628 for v in out.iter_mut() {
11629 *v = 0.0;
11630 }
11631 let mut dx = vec![0f32; h];
11632 let mut db = vec![0f32; h];
11633 for r in 0..rows {
11634 crate::training_bwd::rms_norm_backward_row(
11635 &xs[r * h..(r + 1) * h],
11636 g,
11637 b,
11638 &dys[r * h..(r + 1) * h],
11639 &mut dx,
11640 &mut *out,
11641 &mut db,
11642 *eps,
11643 );
11644 }
11645 }
11646 }
11647
11648 Thunk::RmsNormBackwardBeta {
11649 x,
11650 gamma,
11651 beta,
11652 dy,
11653 dbeta,
11654 rows,
11655 h,
11656 eps,
11657 } => {
11658 let (rows, h) = (*rows as usize, *h as usize);
11659 unsafe {
11660 let xs = sl(*x, base, rows * h);
11661 let g = sl(*gamma, base, h);
11662 let b = sl(*beta, base, h);
11663 let dys = sl(*dy, base, rows * h);
11664 let out = sl_mut(*dbeta, base, h);
11665 for v in out.iter_mut() {
11666 *v = 0.0;
11667 }
11668 let mut dx = vec![0f32; h];
11669 let mut dg = vec![0f32; h];
11670 for r in 0..rows {
11671 crate::training_bwd::rms_norm_backward_row(
11672 &xs[r * h..(r + 1) * h],
11673 g,
11674 b,
11675 &dys[r * h..(r + 1) * h],
11676 &mut dx,
11677 &mut dg,
11678 &mut *out,
11679 *eps,
11680 );
11681 }
11682 }
11683 }
11684
11685 Thunk::RopeBackward {
11686 dy,
11687 cos,
11688 sin,
11689 dx,
11690 batch,
11691 seq,
11692 hidden,
11693 head_dim,
11694 n_rot,
11695 cos_len,
11696 } => {
11697 let (b, s, hs, dh, nr, cl) = (
11698 *batch as usize,
11699 *seq as usize,
11700 *hidden as usize,
11701 *head_dim as usize,
11702 *n_rot as usize,
11703 *cos_len as usize,
11704 );
11705 let nh = hs / dh;
11706 let tab_half = dh / 2;
11707 unsafe {
11708 let dys = sl(*dy, base, b * s * hs);
11709 let cos_tab = sl(*cos, base, cl);
11710 let sin_tab = sl(*sin, base, cl);
11711 let out = sl_mut(*dx, base, b * s * hs);
11712 for bi in 0..b {
11713 for si in 0..s {
11714 let tab_off = si.saturating_mul(tab_half) % cl.max(1);
11715 let cp = &cos_tab[tab_off..tab_off + tab_half.min(cl)];
11716 let sp = &sin_tab[tab_off..tab_off + tab_half.min(cl)];
11717 for hi in 0..nh {
11718 let base_idx = bi * s * hs + si * hs + hi * dh;
11719 crate::training_bwd::rope_backward_row(
11720 &dys[base_idx..base_idx + dh],
11721 cp,
11722 sp,
11723 &mut out[base_idx..base_idx + dh],
11724 dh,
11725 nr,
11726 );
11727 }
11728 }
11729 }
11730 }
11731 }
11732
11733 Thunk::CumsumBackward {
11734 dy,
11735 dx,
11736 rows,
11737 cols,
11738 exclusive,
11739 } => {
11740 let (rows, cols) = (*rows as usize, *cols as usize);
11741 unsafe {
11742 let dys = sl(*dy, base, rows * cols);
11743 let out = sl_mut(*dx, base, rows * cols);
11744 for r in 0..rows {
11745 crate::training_bwd::cumsum_backward_row(
11746 &dys[r * cols..(r + 1) * cols],
11747 &mut out[r * cols..(r + 1) * cols],
11748 *exclusive,
11749 );
11750 }
11751 }
11752 }
11753
11754 Thunk::GroupNormBackwardInput {
11755 x,
11756 gamma,
11757 beta: _beta,
11758 dy,
11759 dx,
11760 n,
11761 c,
11762 h,
11763 w,
11764 num_groups,
11765 eps,
11766 } => {
11767 let (n, c, h, w) = (*n as usize, *c as usize, *h as usize, *w as usize);
11768 let plane = c * h * w;
11769 unsafe {
11770 let xs = sl(*x, base, n * plane);
11771 let g = sl(*gamma, base, c);
11772 let dys = sl(*dy, base, n * plane);
11773 let out = sl_mut(*dx, base, n * plane);
11774 crate::training_bwd::group_norm_backward_input_nchw(
11775 xs,
11776 g,
11777 dys,
11778 out,
11779 n,
11780 c,
11781 h,
11782 w,
11783 *num_groups as usize,
11784 *eps,
11785 );
11786 }
11787 }
11788
11789 Thunk::GroupNormBackwardGamma {
11790 x,
11791 dy,
11792 dgamma,
11793 n,
11794 c,
11795 h,
11796 w,
11797 num_groups,
11798 eps,
11799 } => {
11800 let (n, c, h, w) = (*n as usize, *c as usize, *h as usize, *w as usize);
11801 let plane = c * h * w;
11802 unsafe {
11803 let xs = sl(*x, base, n * plane);
11804 let dys = sl(*dy, base, n * plane);
11805 let out = sl_mut(*dgamma, base, c);
11806 crate::training_bwd::group_norm_backward_gamma_nchw(
11807 xs,
11808 dys,
11809 out,
11810 n,
11811 c,
11812 h,
11813 w,
11814 *num_groups as usize,
11815 *eps,
11816 );
11817 }
11818 }
11819
11820 Thunk::GroupNormBackwardBeta {
11821 dy,
11822 dbeta,
11823 n,
11824 c,
11825 h,
11826 w,
11827 } => {
11828 let (n, c, h, w) = (*n as usize, *c as usize, *h as usize, *w as usize);
11829 let plane = c * h * w;
11830 unsafe {
11831 let dys = sl(*dy, base, n * plane);
11832 let out = sl_mut(*dbeta, base, c);
11833 crate::training_bwd::group_norm_backward_beta_nchw(dys, out, n, c, h, w);
11834 }
11835 }
11836
11837 Thunk::GatherBackward {
11838 dy,
11839 indices,
11840 dst,
11841 outer,
11842 axis_dim,
11843 num_idx,
11844 trailing,
11845 } => {
11846 let (outer, axis_dim, num_idx, trailing) = (
11847 *outer as usize,
11848 *axis_dim as usize,
11849 *num_idx as usize,
11850 *trailing as usize,
11851 );
11852 unsafe {
11853 let dys = sl(*dy, base, outer * num_idx * trailing);
11854 let ids = sl(*indices, base, num_idx);
11855 let out = sl_mut(*dst, base, outer * axis_dim * trailing);
11856 for v in out.iter_mut() {
11857 *v = 0.0;
11858 }
11859 crate::training_bwd::gather_axis_backward(
11860 dys, ids, out, outer, axis_dim, num_idx, trailing,
11861 );
11862 }
11863 }
11864
11865 Thunk::MaxPool2dBackward {
11866 x,
11867 dy,
11868 dx,
11869 n,
11870 c,
11871 h,
11872 w,
11873 h_out,
11874 w_out,
11875 kh,
11876 kw,
11877 sh,
11878 sw,
11879 ph,
11880 pw,
11881 } => {
11882 let n = *n as usize;
11883 let c = *c as usize;
11884 let h = *h as usize;
11885 let w = *w as usize;
11886 let h_out = *h_out as usize;
11887 let w_out = *w_out as usize;
11888 let kh = *kh as usize;
11889 let kw = *kw as usize;
11890 let sh = *sh as usize;
11891 let sw = *sw as usize;
11892 let ph = *ph as usize;
11893 let pw = *pw as usize;
11894 unsafe {
11895 let xs = sl(*x, base, n * c * h * w);
11896 let dys = sl(*dy, base, n * c * h_out * w_out);
11897 let dxs = sl_mut(*dx, base, n * c * h * w);
11898 for v in dxs.iter_mut() {
11901 *v = 0.0;
11902 }
11903 for ni in 0..n {
11904 for ci in 0..c {
11905 let in_chan = (ni * c + ci) * h * w;
11906 let out_chan = (ni * c + ci) * h_out * w_out;
11907 for ho in 0..h_out {
11908 for wo in 0..w_out {
11909 let mut best_v = f32::NEG_INFINITY;
11911 let mut best_idx: Option<usize> = None;
11912 for ki in 0..kh {
11913 for kj in 0..kw {
11914 let hi = ho * sh + ki;
11915 let wi = wo * sw + kj;
11916 if hi < ph || wi < pw {
11917 continue;
11918 }
11919 let hi = hi - ph;
11920 let wi = wi - pw;
11921 if hi >= h || wi >= w {
11922 continue;
11923 }
11924 let idx = in_chan + hi * w + wi;
11925 let v = xs[idx];
11926 if v > best_v {
11930 best_v = v;
11931 best_idx = Some(idx);
11932 }
11933 }
11934 }
11935 if let Some(idx) = best_idx {
11936 dxs[idx] += dys[out_chan + ho * w_out + wo];
11937 }
11938 }
11939 }
11940 }
11941 }
11942 }
11943 }
11944
11945 Thunk::Conv2dBackwardInput {
11946 dy,
11947 w,
11948 dx,
11949 n,
11950 c_in,
11951 h,
11952 w_in,
11953 c_out,
11954 h_out,
11955 w_out,
11956 kh,
11957 kw,
11958 sh,
11959 sw,
11960 ph,
11961 pw,
11962 dh,
11963 dw,
11964 groups,
11965 } => {
11966 let n = *n as usize;
11978 let c_in = *c_in as usize;
11979 let h = *h as usize;
11980 let w_in = *w_in as usize;
11981 let c_out = *c_out as usize;
11982 let h_out = *h_out as usize;
11983 let w_out = *w_out as usize;
11984 let kh = *kh as usize;
11985 let kw = *kw as usize;
11986 let sh = *sh as usize;
11987 let sw = *sw as usize;
11988 let ph = *ph as usize;
11989 let pw = *pw as usize;
11990 let dh = *dh as usize;
11991 let dw = *dw as usize;
11992 let groups = *groups as usize;
11993 let c_in_per_g = c_in / groups;
11994 let c_out_per_g = c_out / groups;
11995
11996 let m_dim = c_in_per_g * kh * kw;
11997 let n_dim = h_out * w_out;
11998 let k_dim = c_out_per_g;
11999
12000 let dy_stride_n = c_out * h_out * w_out;
12001 let dy_stride_g = c_out_per_g * h_out * w_out;
12002 let w_stride_g = c_out_per_g * c_in_per_g * kh * kw;
12003 let dx_stride_n = c_in * h * w_in;
12004 let dx_stride_g = c_in_per_g * h * w_in;
12005
12006 unsafe {
12007 let dys = sl(*dy, base, n * c_out * h_out * w_out);
12008 let ws = sl(*w, base, c_out * c_in_per_g * kh * kw);
12009 let dxs = sl_mut(*dx, base, n * c_in * h * w_in);
12010 for v in dxs.iter_mut() {
12011 *v = 0.0;
12012 }
12013
12014 let mut dcol = vec![0f32; m_dim * n_dim];
12016
12017 for ni in 0..n {
12018 for g in 0..groups {
12019 let w_g_off = g * w_stride_g;
12020 let dy_n_g_off = ni * dy_stride_n + g * dy_stride_g;
12021 let dx_n_g_off = ni * dx_stride_n + g * dx_stride_g;
12022
12023 crate::blas::sgemm_general(
12028 ws.as_ptr().add(w_g_off),
12029 dys.as_ptr().add(dy_n_g_off),
12030 dcol.as_mut_ptr(),
12031 m_dim,
12032 n_dim,
12033 k_dim,
12034 1.0,
12035 0.0,
12036 m_dim,
12037 n_dim,
12038 n_dim,
12039 true,
12040 false,
12041 );
12042
12043 col2im(
12045 &dcol,
12046 &mut dxs[dx_n_g_off..dx_n_g_off + dx_stride_g],
12047 c_in_per_g,
12048 h,
12049 w_in,
12050 h_out,
12051 w_out,
12052 kh,
12053 kw,
12054 sh,
12055 sw,
12056 ph,
12057 pw,
12058 dh,
12059 dw,
12060 );
12061 }
12062 }
12063 }
12064 }
12065
12066 Thunk::Conv2dBackwardWeight {
12067 x,
12068 dy,
12069 dw,
12070 n,
12071 c_in,
12072 h,
12073 w,
12074 c_out,
12075 h_out,
12076 w_out,
12077 kh,
12078 kw,
12079 sh,
12080 sw,
12081 ph,
12082 pw,
12083 dh,
12084 dw_dil,
12085 groups,
12086 } => {
12087 let n = *n as usize;
12088 let c_in = *c_in as usize;
12089 let h = *h as usize;
12090 let w = *w as usize;
12091 let c_out = *c_out as usize;
12102 let h_out = *h_out as usize;
12103 let w_out = *w_out as usize;
12104 let kh = *kh as usize;
12105 let kw = *kw as usize;
12106 let sh = *sh as usize;
12107 let sw = *sw as usize;
12108 let ph = *ph as usize;
12109 let pw = *pw as usize;
12110 let dh = *dh as usize;
12111 let dw_dil = *dw_dil as usize;
12112 let groups = *groups as usize;
12113 let c_in_per_g = c_in / groups;
12114 let c_out_per_g = c_out / groups;
12115
12116 let m_dim = c_out_per_g;
12117 let n_dim = c_in_per_g * kh * kw;
12118 let k_dim = h_out * w_out;
12119
12120 let x_stride_n = c_in * h * w;
12121 let x_stride_g = c_in_per_g * h * w;
12122 let dy_stride_n = c_out * h_out * w_out;
12123 let dy_stride_g = c_out_per_g * h_out * w_out;
12124 let dw_stride_g = c_out_per_g * c_in_per_g * kh * kw;
12125
12126 unsafe {
12127 let xs = sl(*x, base, n * c_in * h * w);
12128 let dys = sl(*dy, base, n * c_out * h_out * w_out);
12129 let dws = sl_mut(*dw, base, c_out * c_in_per_g * kh * kw);
12130 for v in dws.iter_mut() {
12131 *v = 0.0;
12132 }
12133
12134 let mut col = vec![0f32; n_dim * k_dim];
12135
12136 for ni in 0..n {
12137 for g in 0..groups {
12138 let x_n_g_off = ni * x_stride_n + g * x_stride_g;
12139 im2col(
12140 &xs[x_n_g_off..x_n_g_off + x_stride_g],
12141 &mut col,
12142 c_in_per_g,
12143 h,
12144 w,
12145 h_out,
12146 w_out,
12147 kh,
12148 kw,
12149 sh,
12150 sw,
12151 ph,
12152 pw,
12153 dh,
12154 dw_dil,
12155 );
12156
12157 let dy_n_g_off = ni * dy_stride_n + g * dy_stride_g;
12158 let dw_g_off = g * dw_stride_g;
12159
12160 crate::blas::sgemm_general(
12168 dys.as_ptr().add(dy_n_g_off),
12169 col.as_ptr(),
12170 dws.as_mut_ptr().add(dw_g_off),
12171 m_dim,
12172 n_dim,
12173 k_dim,
12174 1.0,
12175 1.0,
12176 k_dim,
12177 k_dim,
12178 n_dim,
12179 false,
12180 true,
12181 );
12182 }
12183 }
12184 }
12185 }
12186
12187 Thunk::SoftmaxCrossEntropy {
12188 logits,
12189 labels,
12190 dst,
12191 n,
12192 c,
12193 } => {
12194 let n = *n as usize;
12195 let c = *c as usize;
12196 unsafe {
12197 let lg = sl(*logits, base, n * c);
12198 let lb = sl(*labels, base, n);
12199 let out = sl_mut(*dst, base, n);
12200 for ni in 0..n {
12201 let row = &lg[ni * c..(ni + 1) * c];
12202 let mut m = f32::NEG_INFINITY;
12204 for &v in row {
12205 if v > m {
12206 m = v;
12207 }
12208 }
12209 let mut sum = 0f32;
12210 for &v in row {
12211 sum += (v - m).exp();
12212 }
12213 let lse = m + sum.ln();
12214 let label_idx = lb[ni] as usize;
12215 out[ni] = lse - row[label_idx];
12217 }
12218 }
12219 }
12220
12221 Thunk::SoftmaxCrossEntropyBackward {
12222 logits,
12223 labels,
12224 d_loss,
12225 dlogits,
12226 n,
12227 c,
12228 } => {
12229 let n = *n as usize;
12230 let c = *c as usize;
12231 unsafe {
12232 let lg = sl(*logits, base, n * c);
12233 let lb = sl(*labels, base, n);
12234 let dl = sl(*d_loss, base, n);
12235 let out = sl_mut(*dlogits, base, n * c);
12236 for ni in 0..n {
12237 let row = &lg[ni * c..(ni + 1) * c];
12238 let label_idx = lb[ni] as usize;
12239 let scale = dl[ni];
12240 let mut m = f32::NEG_INFINITY;
12241 for &v in row {
12242 if v > m {
12243 m = v;
12244 }
12245 }
12246 let mut sum = 0f32;
12247 for &v in row {
12248 sum += (v - m).exp();
12249 }
12250 let inv_sum = 1.0 / sum;
12251 let dst_row = &mut out[ni * c..(ni + 1) * c];
12252 for k in 0..c {
12253 let p = (row[k] - m).exp() * inv_sum;
12254 let one_hot = if k == label_idx { 1.0 } else { 0.0 };
12255 dst_row[k] = (p - one_hot) * scale;
12256 }
12257 }
12258 }
12259 }
12260
12261 Thunk::GatherAxis {
12262 table,
12263 idx,
12264 dst,
12265 outer,
12266 axis_dim,
12267 num_idx,
12268 trailing,
12269 } => {
12270 let outer = *outer as usize;
12271 let axis_dim = *axis_dim as usize;
12272 let num_idx = *num_idx as usize;
12273 let trailing = *trailing as usize;
12274 unsafe {
12275 let tab = sl(*table, base, outer * axis_dim * trailing);
12276 let ids = sl(*idx, base, num_idx);
12277 let out = sl_mut(*dst, base, outer * num_idx * trailing);
12278 for o in 0..outer {
12279 let tab_outer = o * axis_dim * trailing;
12280 let out_outer = o * num_idx * trailing;
12281 for k in 0..num_idx {
12282 let row = ids[k] as usize;
12283 let tab_row = tab_outer + row * trailing;
12284 let out_row = out_outer + k * trailing;
12285 out[out_row..out_row + trailing]
12286 .copy_from_slice(&tab[tab_row..tab_row + trailing]);
12287 }
12288 }
12289 }
12290 }
12291
12292 Thunk::Transpose {
12293 src,
12294 dst,
12295 in_total,
12296 out_dims,
12297 in_strides,
12298 } => {
12299 let rank = out_dims.len();
12304 let total: usize = out_dims.iter().map(|&d| d as usize).product();
12305 let in_total = *in_total as usize;
12306 unsafe {
12307 let inp = sl(*src, base, in_total);
12308 let out = sl_mut(*dst, base, total);
12309 let mut idx = vec![0usize; rank];
12310 for o in 0..total {
12311 let mut src_idx = 0usize;
12312 for d in 0..rank {
12313 src_idx += idx[d] * in_strides[d] as usize;
12314 }
12315 out[o] = inp[src_idx];
12316 for d in (0..rank).rev() {
12318 idx[d] += 1;
12319 if idx[d] < out_dims[d] as usize {
12320 break;
12321 }
12322 idx[d] = 0;
12323 }
12324 }
12325 }
12326 }
12327
12328 Thunk::CustomOp {
12334 kernel,
12335 inputs,
12336 output,
12337 attrs,
12338 } => {
12339 let (out_off, out_len, out_shape) = output;
12340 unsafe {
12341 dispatch_custom_op(
12342 &**kernel, inputs, *out_off, *out_len, out_shape, attrs, base,
12343 );
12344 }
12345 }
12346 }
12347 }
12348}
12349
12350#[allow(clippy::too_many_arguments)]
12365unsafe fn griewank_process_segment(
12366 t_lo: usize,
12367 t_hi: usize,
12368 anchor_carry: &[u8],
12369 cb: usize,
12370 fwd_sched: &ThunkSchedule,
12371 fwd_init: &[u8],
12372 fwd_carry_in_off: usize,
12373 fwd_output_off: usize,
12374 fwd_x_offs: &[usize],
12375 base: *mut u8,
12376 outer_xs_offs: &[(usize, u32)],
12377 fwd_buf: &mut Vec<u8>,
12378 leaf_threshold: usize,
12379 process_iter: &mut dyn FnMut(usize, &[u8]),
12380) {
12381 unsafe {
12382 let size = t_hi - t_lo + 1;
12383 if size == 1 {
12384 process_iter(t_lo, anchor_carry);
12385 return;
12386 }
12387 if size <= leaf_threshold {
12388 let mut cache: Vec<u8> = Vec::with_capacity(size * cb);
12390 cache.extend_from_slice(anchor_carry);
12391 fwd_buf.copy_from_slice(fwd_init);
12392 std::ptr::copy_nonoverlapping(
12393 anchor_carry.as_ptr(),
12394 fwd_buf.as_mut_ptr().add(fwd_carry_in_off),
12395 cb,
12396 );
12397 for i in 1..size {
12398 let cur_iter = t_lo + i - 1;
12399 for (idx, fb_x_off) in fwd_x_offs.iter().enumerate() {
12400 let (outer_xs_off, x_psb) = outer_xs_offs[idx];
12401 let xb = x_psb as usize;
12402 std::ptr::copy_nonoverlapping(
12403 base.add(outer_xs_off + cur_iter * xb),
12404 fwd_buf.as_mut_ptr().add(*fb_x_off),
12405 xb,
12406 );
12407 }
12408 execute_thunks(fwd_sched, fwd_buf);
12409 if fwd_output_off != fwd_carry_in_off {
12410 fwd_buf.copy_within(fwd_output_off..fwd_output_off + cb, fwd_carry_in_off);
12411 }
12412 cache.extend_from_slice(&fwd_buf[fwd_carry_in_off..fwd_carry_in_off + cb]);
12413 }
12414 for t in (t_lo..=t_hi).rev() {
12416 let idx = t - t_lo;
12417 let carry = &cache[idx * cb..(idx + 1) * cb];
12418 process_iter(t, carry);
12419 }
12420 return;
12421 }
12422
12423 let mid = t_lo + size / 2;
12427 fwd_buf.copy_from_slice(fwd_init);
12428 std::ptr::copy_nonoverlapping(
12429 anchor_carry.as_ptr(),
12430 fwd_buf.as_mut_ptr().add(fwd_carry_in_off),
12431 cb,
12432 );
12433 for cur_iter in t_lo..mid {
12434 for (idx, fb_x_off) in fwd_x_offs.iter().enumerate() {
12435 let (outer_xs_off, x_psb) = outer_xs_offs[idx];
12436 let xb = x_psb as usize;
12437 std::ptr::copy_nonoverlapping(
12438 base.add(outer_xs_off + cur_iter * xb),
12439 fwd_buf.as_mut_ptr().add(*fb_x_off),
12440 xb,
12441 );
12442 }
12443 execute_thunks(fwd_sched, fwd_buf);
12444 if fwd_output_off != fwd_carry_in_off {
12445 fwd_buf.copy_within(fwd_output_off..fwd_output_off + cb, fwd_carry_in_off);
12446 }
12447 }
12448 let mid_carry: Vec<u8> = fwd_buf[fwd_carry_in_off..fwd_carry_in_off + cb].to_vec();
12449
12450 griewank_process_segment(
12454 mid,
12455 t_hi,
12456 &mid_carry,
12457 cb,
12458 fwd_sched,
12459 fwd_init,
12460 fwd_carry_in_off,
12461 fwd_output_off,
12462 fwd_x_offs,
12463 base,
12464 outer_xs_offs,
12465 fwd_buf,
12466 leaf_threshold,
12467 process_iter,
12468 );
12469 griewank_process_segment(
12471 t_lo,
12472 mid - 1,
12473 anchor_carry,
12474 cb,
12475 fwd_sched,
12476 fwd_init,
12477 fwd_carry_in_off,
12478 fwd_output_off,
12479 fwd_x_offs,
12480 base,
12481 outer_xs_offs,
12482 fwd_buf,
12483 leaf_threshold,
12484 process_iter,
12485 );
12486 }
12487}
12488
12489pub unsafe fn execute_fft1d_f64(
12506 src: usize,
12507 dst: usize,
12508 outer: usize,
12509 n_complex: usize,
12510 inverse: bool,
12511 base: *mut u8,
12512) {
12513 let row_elems = 2 * n_complex;
12514 let mut re = vec![0f64; n_complex];
12515 let mut im = vec![0f64; n_complex];
12516 let mut scratch = if n_complex.is_power_of_two() {
12519 BluesteinScratchF64::empty()
12520 } else {
12521 BluesteinScratchF64::build(n_complex, inverse)
12522 };
12523 for o in 0..outer {
12524 let row_offset = src + o * row_elems * std::mem::size_of::<f64>();
12525 let s = unsafe { sl_f64(row_offset, base, row_elems) };
12526 re.copy_from_slice(&s[..n_complex]);
12527 im.copy_from_slice(&s[n_complex..]);
12528 if n_complex.is_power_of_two() {
12529 fft_radix2_inplace_f64(&mut re, &mut im, inverse);
12530 } else {
12531 fft_bluestein_inplace_f64(&mut re, &mut im, inverse, &mut scratch);
12532 }
12533 let dst_offset = dst + o * row_elems * std::mem::size_of::<f64>();
12534 let d = unsafe { sl_mut_f64(dst_offset, base, row_elems) };
12535 d[..n_complex].copy_from_slice(&re);
12536 d[n_complex..].copy_from_slice(&im);
12537 }
12538}
12539
12540pub unsafe fn execute_gated_delta_net_f32(
12549 q: usize,
12550 k: usize,
12551 v: usize,
12552 g: usize,
12553 beta: usize,
12554 state: usize,
12555 dst: usize,
12556 batch: usize,
12557 seq: usize,
12558 heads: usize,
12559 state_size: usize,
12560 base: *mut u8,
12561) {
12562 use rayon::prelude::*;
12563
12564 #[derive(Copy, Clone)]
12565 struct ArenaPtr(usize);
12566 unsafe impl Send for ArenaPtr {}
12567 unsafe impl Sync for ArenaPtr {}
12568 impl ArenaPtr {
12569 #[inline]
12570 fn get(self) -> *mut u8 {
12571 self.0 as *mut u8
12572 }
12573 }
12574
12575 unsafe {
12576 let arena = ArenaPtr(base as usize);
12577 let (b, s, h, n) = (batch, seq, heads, state_size);
12578 let scale = 1.0f32 / (n as f32).sqrt();
12579 let use_external = state != 0;
12580 let mut owned_state = vec![0f32; h * n * n];
12581
12582 crate::pool::num_threads();
12583
12584 assert!(
12585 n <= crate::gdn::GDN_MAX_STATE,
12586 "GatedDeltaNet state_size={n} exceeds stack scratch ({})",
12587 crate::gdn::GDN_MAX_STATE
12588 );
12589
12590 let qs = sl(q, arena.get(), b * s * h * n);
12591 let ks = sl(k, arena.get(), b * s * h * n);
12592 let vs = sl(v, arena.get(), b * s * h * n);
12593 let gs = sl(g, arena.get(), b * s * h);
12594 let betas = sl(beta, arena.get(), b * s * h);
12595 let _out = sl_mut(dst, arena.get(), b * s * h * n);
12596 let hs_n = h * n;
12597
12598 let run_head = |bi: usize, hi: usize, s_mat: &mut [f32], sk: &mut [f32]| {
12599 for ti in 0..s {
12600 let qkv_step = bi * s * hs_n + ti * hs_n + hi * n;
12601 let gb_step = bi * s * h + ti * h + hi;
12602 let out_row = sl_mut(dst + qkv_step * std::mem::size_of::<f32>(), arena.get(), n);
12603 crate::gdn::gdn_step_blas(
12604 s_mat,
12605 &qs[qkv_step..qkv_step + n],
12606 &ks[qkv_step..qkv_step + n],
12607 &vs[qkv_step..qkv_step + n],
12608 gs[gb_step],
12609 betas[gb_step],
12610 out_row,
12611 sk,
12612 n,
12613 scale,
12614 );
12615 }
12616 };
12617
12618 if !use_external && s > 1 {
12621 for bi in 0..b {
12622 (0..h).into_par_iter().for_each(|hi| {
12623 let mut sk_buf = [0f32; crate::gdn::GDN_MAX_STATE];
12624 let sk = &mut sk_buf[..n];
12625 let mut local_state =
12626 [0f32; crate::gdn::GDN_MAX_STATE * crate::gdn::GDN_MAX_STATE];
12627 let s_mat = &mut local_state[..n * n];
12628 s_mat.fill(0.0);
12629 run_head(bi, hi, s_mat, sk);
12630 });
12631 }
12632 return;
12633 }
12634
12635 if use_external {
12636 let state_bytes = state;
12637 (0..b * h).into_par_iter().for_each(|bhi| {
12638 let bi = bhi / h;
12639 let hi = bhi % h;
12640 let elem_off = bi * h * n * n + hi * n * n;
12641 let s_mat = sl_mut(
12642 state_bytes + elem_off * std::mem::size_of::<f32>(),
12643 arena.get(),
12644 n * n,
12645 );
12646 let mut sk_buf = [0f32; crate::gdn::GDN_MAX_STATE];
12647 run_head(bi, hi, s_mat, &mut sk_buf[..n]);
12648 });
12649 } else {
12650 for bi in 0..b {
12651 owned_state.fill(0.0);
12652 owned_state
12653 .par_chunks_mut(n * n)
12654 .enumerate()
12655 .for_each(|(hi, s_mat)| {
12656 let mut sk_buf = [0f32; crate::gdn::GDN_MAX_STATE];
12657 run_head(bi, hi, s_mat, &mut sk_buf[..n]);
12658 });
12659 }
12660 }
12661 }
12662}
12663
12664pub unsafe fn execute_rms_norm_backward_input_f32(
12666 x: usize,
12667 gamma: usize,
12668 beta: usize,
12669 dy: usize,
12670 dx: usize,
12671 rows: u32,
12672 h: u32,
12673 eps: f32,
12674 base: *mut u8,
12675) {
12676 let (rows, h) = (rows as usize, h as usize);
12677 let mut dg = vec![0f32; h];
12678 let mut db = vec![0f32; h];
12679 let xs = sl(x, base, rows * h);
12680 let dys = sl(dy, base, rows * h);
12681 let g = sl(gamma, base, h);
12682 let b = sl(beta, base, h);
12683 let out = sl_mut(dx, base, rows * h);
12684 for r in 0..rows {
12685 crate::training_bwd::rms_norm_backward_row(
12686 &xs[r * h..(r + 1) * h],
12687 g,
12688 b,
12689 &dys[r * h..(r + 1) * h],
12690 &mut out[r * h..(r + 1) * h],
12691 &mut dg,
12692 &mut db,
12693 eps,
12694 );
12695 }
12696}
12697
12698pub unsafe fn execute_rms_norm_backward_gamma_f32(
12699 x: usize,
12700 gamma: usize,
12701 beta: usize,
12702 dy: usize,
12703 dgamma: usize,
12704 rows: u32,
12705 h: u32,
12706 eps: f32,
12707 base: *mut u8,
12708) {
12709 let (rows, h) = (rows as usize, h as usize);
12710 let out = sl_mut(dgamma, base, h);
12711 out.fill(0.0);
12712 let mut dx = vec![0f32; h];
12713 let mut db = vec![0f32; h];
12714 let xs = sl(x, base, rows * h);
12715 let dys = sl(dy, base, rows * h);
12716 let g = sl(gamma, base, h);
12717 let b = sl(beta, base, h);
12718 for r in 0..rows {
12719 crate::training_bwd::rms_norm_backward_row(
12720 &xs[r * h..(r + 1) * h],
12721 g,
12722 b,
12723 &dys[r * h..(r + 1) * h],
12724 &mut dx,
12725 out,
12726 &mut db,
12727 eps,
12728 );
12729 }
12730}
12731
12732pub unsafe fn execute_rms_norm_backward_beta_f32(
12733 x: usize,
12734 gamma: usize,
12735 beta: usize,
12736 dy: usize,
12737 dbeta: usize,
12738 rows: u32,
12739 h: u32,
12740 eps: f32,
12741 base: *mut u8,
12742) {
12743 let (rows, h) = (rows as usize, h as usize);
12744 let out = sl_mut(dbeta, base, h);
12745 out.fill(0.0);
12746 let mut dx = vec![0f32; h];
12747 let mut dg = vec![0f32; h];
12748 let xs = sl(x, base, rows * h);
12749 let dys = sl(dy, base, rows * h);
12750 let g = sl(gamma, base, h);
12751 let b = sl(beta, base, h);
12752 for r in 0..rows {
12753 crate::training_bwd::rms_norm_backward_row(
12754 &xs[r * h..(r + 1) * h],
12755 g,
12756 b,
12757 &dys[r * h..(r + 1) * h],
12758 &mut dx,
12759 &mut dg,
12760 out,
12761 eps,
12762 );
12763 }
12764}
12765
12766pub unsafe fn execute_rope_backward_f32(
12767 dy: usize,
12768 cos: usize,
12769 sin: usize,
12770 dx: usize,
12771 batch: u32,
12772 seq: u32,
12773 hidden: u32,
12774 head_dim: u32,
12775 n_rot: u32,
12776 cos_len: u32,
12777 base: *mut u8,
12778) {
12779 let (b, s, hs, dh, nr, cl) = (
12780 batch as usize,
12781 seq as usize,
12782 hidden as usize,
12783 head_dim as usize,
12784 n_rot as usize,
12785 cos_len as usize,
12786 );
12787 let nh = hs / dh;
12788 let tab_half = dh / 2;
12789 let dys = sl(dy, base, b * s * hs);
12790 let cos_tab = sl(cos, base, cl);
12791 let sin_tab = sl(sin, base, cl);
12792 let out = sl_mut(dx, base, b * s * hs);
12793 for bi in 0..b {
12794 for si in 0..s {
12795 let tab_off = si.saturating_mul(tab_half) % cl.max(1);
12796 let cp = &cos_tab[tab_off..tab_off + tab_half.min(cl)];
12797 let sp = &sin_tab[tab_off..tab_off + tab_half.min(cl)];
12798 for hi in 0..nh {
12799 let base_idx = bi * s * hs + si * hs + hi * dh;
12800 crate::training_bwd::rope_backward_row(
12801 &dys[base_idx..base_idx + dh],
12802 cp,
12803 sp,
12804 &mut out[base_idx..base_idx + dh],
12805 dh,
12806 nr,
12807 );
12808 }
12809 }
12810 }
12811}
12812
12813pub unsafe fn execute_cumsum_backward_f32(
12814 dy: usize,
12815 dx: usize,
12816 rows: u32,
12817 cols: u32,
12818 exclusive: bool,
12819 base: *mut u8,
12820) {
12821 let (rows, cols) = (rows as usize, cols as usize);
12822 let dys = sl(dy, base, rows * cols);
12823 let out = sl_mut(dx, base, rows * cols);
12824 for r in 0..rows {
12825 crate::training_bwd::cumsum_backward_row(
12826 &dys[r * cols..(r + 1) * cols],
12827 &mut out[r * cols..(r + 1) * cols],
12828 exclusive,
12829 );
12830 }
12831}
12832
12833pub unsafe fn execute_gather_backward_f32(
12834 dy: usize,
12835 indices: usize,
12836 dst: usize,
12837 outer: u32,
12838 axis_dim: u32,
12839 num_idx: u32,
12840 trailing: u32,
12841 base: *mut u8,
12842) {
12843 let (outer, axis_dim, num_idx, trailing) = (
12844 outer as usize,
12845 axis_dim as usize,
12846 num_idx as usize,
12847 trailing as usize,
12848 );
12849 let out = sl_mut(dst, base, outer * axis_dim * trailing);
12850 out.fill(0.0);
12851 crate::training_bwd::gather_axis_backward(
12852 sl(dy, base, outer * num_idx * trailing),
12853 sl(indices, base, num_idx),
12854 out,
12855 outer,
12856 axis_dim,
12857 num_idx,
12858 trailing,
12859 );
12860}
12861
12862pub unsafe fn execute_dequant_matmul_gguf_f32(
12864 x: usize,
12865 w_q: usize,
12866 dst: usize,
12867 m: usize,
12868 k: usize,
12869 n: usize,
12870 scheme: rlx_ir::quant::QuantScheme,
12871 base: *mut u8,
12872) {
12873 unsafe {
12874 let block_bytes = scheme.gguf_block_bytes() as usize;
12875 let block_elems = scheme.gguf_block_size() as usize;
12876 let total_bytes = (k * n) / block_elems * block_bytes;
12877 let xs = sl(x, base, m * k);
12878 let w_bytes = std::slice::from_raw_parts(base.add(w_q) as *const u8, total_bytes);
12879 let out = sl_mut(dst, base, m * n);
12880 crate::gguf_matmul::gguf_matmul_bt(xs, w_bytes, out, m, k, n, scheme);
12881 }
12882}
12883
12884pub unsafe fn execute_dequant_grouped_matmul_gguf_f32(
12886 input: usize,
12887 w_q: usize,
12888 expert_idx: usize,
12889 dst: usize,
12890 m: usize,
12891 k: usize,
12892 n: usize,
12893 num_experts: usize,
12894 scheme: rlx_ir::quant::QuantScheme,
12895 base: *mut u8,
12896) {
12897 unsafe {
12898 let block_bytes = scheme.gguf_block_bytes() as usize;
12899 let block_elems = scheme.gguf_block_size() as usize;
12900 let slab_bytes = (k * n) / block_elems * block_bytes;
12901 let xs = sl(input, base, m * k);
12902 let w_bytes =
12903 std::slice::from_raw_parts(base.add(w_q) as *const u8, num_experts * slab_bytes);
12904 let ids = sl(expert_idx, base, m);
12905 let out = sl_mut(dst, base, m * n);
12906 crate::gguf_matmul::gguf_grouped_matmul_bt(
12907 xs,
12908 w_bytes,
12909 ids,
12910 out,
12911 m,
12912 k,
12913 n,
12914 num_experts,
12915 scheme,
12916 );
12917 }
12918}
12919
12920pub unsafe fn execute_dequant_matmul_int4_f32(
12922 x: usize,
12923 w_q: usize,
12924 scale: usize,
12925 zp: usize,
12926 dst: usize,
12927 m: usize,
12928 k: usize,
12929 n: usize,
12930 block_size: u32,
12931 is_asymmetric: bool,
12932 base: *mut u8,
12933) {
12934 let bs = block_size as usize;
12935 let n_blocks = k.div_ceil(bs);
12936 unsafe {
12937 let xs = sl(x, base, m * k);
12938 let w_bytes = std::slice::from_raw_parts(base.add(w_q) as *const u8, (k * n).div_ceil(2));
12939 let scales = sl(scale, base, n_blocks * n);
12940 let zps = if is_asymmetric {
12941 sl(zp, base, n_blocks * n)
12942 } else {
12943 &[][..]
12944 };
12945 let out = sl_mut(dst, base, m * n);
12946 dequant_matmul_int4(xs, w_bytes, scales, zps, out, m, k, n, bs, is_asymmetric);
12947 }
12948}
12949
12950pub unsafe fn execute_dequant_matmul_fp8_f32(
12952 x: usize,
12953 w_q: usize,
12954 scale: usize,
12955 dst: usize,
12956 m: usize,
12957 k: usize,
12958 n: usize,
12959 e5m2: bool,
12960 base: *mut u8,
12961) {
12962 unsafe {
12963 let xs = sl(x, base, m * k);
12964 let w_bytes = std::slice::from_raw_parts(base.add(w_q) as *const u8, k * n);
12965 let scales = sl(scale, base, n);
12966 let out = sl_mut(dst, base, m * n);
12967 dequant_matmul_fp8(xs, w_bytes, scales, out, m, k, n, e5m2);
12968 }
12969}
12970
12971pub unsafe fn execute_dequant_matmul_nvfp4_f32(
12973 x: usize,
12974 w_q: usize,
12975 scale: usize,
12976 global_scale: usize,
12977 dst: usize,
12978 m: usize,
12979 k: usize,
12980 n: usize,
12981 base: *mut u8,
12982) {
12983 let n_scale = k.div_ceil(rlx_ir::NVFP4_GROUP_SIZE) * n;
12984 unsafe {
12985 let xs = sl(x, base, m * k);
12986 let w_bytes = std::slice::from_raw_parts(base.add(w_q) as *const u8, (k * n).div_ceil(2));
12987 let scale_bytes = std::slice::from_raw_parts(base.add(scale) as *const u8, n_scale);
12988 let gs = sl(global_scale, base, 1)[0];
12989 let out = sl_mut(dst, base, m * n);
12990 dequant_matmul_nvfp4(xs, w_bytes, scale_bytes, gs, out, m, k, n);
12991 }
12992}
12993
12994pub unsafe fn execute_gated_delta_net_f16(
12996 q: usize,
12997 k: usize,
12998 v: usize,
12999 g: usize,
13000 beta: usize,
13001 state: usize,
13002 dst: usize,
13003 batch: usize,
13004 seq: usize,
13005 heads: usize,
13006 state_size: usize,
13007 base: *mut u8,
13008) {
13009 use half::f16;
13010 unsafe {
13011 let read_f16 = |off: usize, len: usize| -> Vec<f32> {
13012 let raw = std::slice::from_raw_parts(base.add(off) as *const u8, len * 2);
13013 raw.chunks_exact(2)
13014 .map(|c| f16::from_le_bytes([c[0], c[1]]).to_f32())
13015 .collect()
13016 };
13017 let write_f16 = |off: usize, data: &[f32]| {
13018 let out = std::slice::from_raw_parts_mut(base.add(off), data.len() * 2);
13019 for (i, &v) in data.iter().enumerate() {
13020 let le = f16::from_f32(v).to_le_bytes();
13021 out[i * 2] = le[0];
13022 out[i * 2 + 1] = le[1];
13023 }
13024 };
13025
13026 let (b, s, h, n) = (batch, seq, heads, state_size);
13027 let q_f = read_f16(q, b * s * h * n);
13028 let k_f = read_f16(k, b * s * h * n);
13029 let v_f = read_f16(v, b * s * h * n);
13030 let g_f = read_f16(g, b * s * h);
13031 let b_f = read_f16(beta, b * s * h);
13032 let mut state_f = if state != 0 {
13033 read_f16(state, b * h * n * n)
13034 } else {
13035 vec![0f32; b * h * n * n]
13036 };
13037 let mut out_f = vec![0f32; b * s * h * n];
13038 let scale = 1.0f32 / (n as f32).sqrt();
13039 let mut sk_buf = vec![0f32; n];
13040 let mut owned_state = vec![0f32; h * n * n];
13041
13042 for bi in 0..b {
13043 let state_slice: &mut [f32] = if state != 0 {
13044 let start = bi * h * n * n;
13045 &mut state_f[start..start + h * n * n]
13046 } else {
13047 owned_state.fill(0.0);
13048 &mut owned_state
13049 };
13050
13051 for ti in 0..s {
13052 let qkv_step_base = bi * s * h * n + ti * h * n;
13053 let gb_step_base = bi * s * h + ti * h;
13054
13055 for hi in 0..h {
13056 let q_row = &q_f[qkv_step_base + hi * n..qkv_step_base + (hi + 1) * n];
13057 let k_row = &k_f[qkv_step_base + hi * n..qkv_step_base + (hi + 1) * n];
13058 let v_row = &v_f[qkv_step_base + hi * n..qkv_step_base + (hi + 1) * n];
13059 let g_t = g_f[gb_step_base + hi];
13060 let beta_t = b_f[gb_step_base + hi];
13061
13062 let s_base = hi * n * n;
13063 let s_mat = &mut state_slice[s_base..s_base + n * n];
13064
13065 let g_exp = g_t.exp();
13066 for st in s_mat.iter_mut() {
13067 *st *= g_exp;
13068 }
13069
13070 for j in 0..n {
13071 let mut acc = 0f32;
13072 for i in 0..n {
13073 acc += s_mat[i * n + j] * k_row[i];
13074 }
13075 sk_buf[j] = acc;
13076 }
13077
13078 for j in 0..n {
13079 sk_buf[j] = (v_row[j] - sk_buf[j]) * beta_t;
13080 }
13081
13082 for i in 0..n {
13083 let ki = k_row[i];
13084 for j in 0..n {
13085 s_mat[i * n + j] += ki * sk_buf[j];
13086 }
13087 }
13088
13089 let out_row = &mut out_f[qkv_step_base + hi * n..qkv_step_base + (hi + 1) * n];
13090 for j in 0..n {
13091 let mut acc = 0f32;
13092 for i in 0..n {
13093 acc += s_mat[i * n + j] * q_row[i];
13094 }
13095 out_row[j] = acc * scale;
13096 }
13097 }
13098 }
13099 }
13100
13101 write_f16(dst, &out_f);
13102 if state != 0 {
13103 write_f16(state, &state_f);
13104 }
13105 }
13106}
13107
13108pub unsafe fn execute_group_norm_nchw_f32(
13110 src: usize,
13111 g: usize,
13112 b: usize,
13113 dst: usize,
13114 n: usize,
13115 c: usize,
13116 h: usize,
13117 w: usize,
13118 num_groups: usize,
13119 eps: f32,
13120 base: *mut u8,
13121) {
13122 let plane = c * h * w;
13123 for ni in 0..n {
13124 let input = unsafe { sl(src + ni * plane * std::mem::size_of::<f32>(), base, plane) };
13125 let gamma = unsafe { sl(g, base, c) };
13126 let beta = unsafe { sl(b, base, c) };
13127 let output = unsafe { sl_mut(dst + ni * plane * std::mem::size_of::<f32>(), base, plane) };
13128 crate::kernels::group_norm_nchw(input, gamma, beta, output, 1, c, h, w, num_groups, eps);
13129 }
13130}
13131
13132pub unsafe fn execute_layer_norm2d_nchw_f32(
13134 src: usize,
13135 g: usize,
13136 b: usize,
13137 dst: usize,
13138 n: usize,
13139 c: usize,
13140 h: usize,
13141 w: usize,
13142 eps: f32,
13143 base: *mut u8,
13144) {
13145 let plane = c * h * w;
13146 unsafe {
13147 let input = sl(src, base, n * plane);
13148 let gamma = sl(g, base, c);
13149 let beta = sl(b, base, c);
13150 let output = sl_mut(dst, base, n * plane);
13151 crate::kernels::layer_norm2d_nchw(input, gamma, beta, output, n, c, h, w, eps);
13152 }
13153}
13154
13155pub unsafe fn execute_conv_transpose2d_nchw_f32(
13157 src: usize,
13158 weight: usize,
13159 dst: usize,
13160 n: usize,
13161 c_in: usize,
13162 h: usize,
13163 w_in: usize,
13164 c_out: usize,
13165 h_out: usize,
13166 w_out: usize,
13167 kh: usize,
13168 kw: usize,
13169 sh: usize,
13170 sw: usize,
13171 ph: usize,
13172 pw: usize,
13173 dh: usize,
13174 dw: usize,
13175 groups: usize,
13176 base: *mut u8,
13177) {
13178 let in_elems = n * c_in * h * w_in;
13179 let w_elems = c_in * (c_out / groups) * kh * kw;
13180 let out_elems = n * c_out * h_out * w_out;
13181 unsafe {
13182 let input = sl(src, base, in_elems);
13183 let wt = sl(weight, base, w_elems);
13184 let output = sl_mut(dst, base, out_elems);
13185 crate::kernels::conv_transpose2d_nchw(
13186 input, wt, output, n, c_in, h, w_in, c_out, h_out, w_out, kh, kw, sh, sw, ph, pw, dh,
13187 dw, groups,
13188 );
13189 }
13190}
13191
13192pub unsafe fn execute_resize_nearest_2x_f32(
13194 src: usize,
13195 dst: usize,
13196 n: usize,
13197 c: usize,
13198 h: usize,
13199 w: usize,
13200 base: *mut u8,
13201) {
13202 let in_plane = c * h * w;
13203 let out_plane = c * h * 2 * w * 2;
13204 for ni in 0..n {
13205 let input = unsafe {
13206 sl(
13207 src + ni * in_plane * std::mem::size_of::<f32>(),
13208 base,
13209 in_plane,
13210 )
13211 };
13212 let output = unsafe {
13213 sl_mut(
13214 dst + ni * out_plane * std::mem::size_of::<f32>(),
13215 base,
13216 out_plane,
13217 )
13218 };
13219 crate::kernels::resize_nearest_2x_nchw(input, output, c, h, w);
13220 }
13221}
13222
13223pub unsafe fn execute_axial_rope2d_f32(
13225 src: usize,
13226 dst: usize,
13227 batch: usize,
13228 seq: usize,
13229 hidden: usize,
13230 end_x: usize,
13231 end_y: usize,
13232 head_dim: usize,
13233 num_heads: usize,
13234 theta: f32,
13235 repeat_factor: usize,
13236 base: *mut u8,
13237) {
13238 let plane = seq * hidden;
13239 let plane_bytes = plane * std::mem::size_of::<f32>();
13240 for bi in 0..batch {
13241 let in_off = src + bi * plane_bytes;
13242 let input = unsafe { sl(in_off, base, plane) };
13243 let rotated = rlx_ir::ops::axial_rope2d::apply_axial_rope2d(
13244 input,
13245 num_heads,
13246 seq,
13247 head_dim,
13248 end_x,
13249 end_y,
13250 theta,
13251 repeat_factor,
13252 );
13253 let out_off = dst + bi * plane_bytes;
13254 let output = unsafe { sl_mut(out_off, base, plane) };
13255 output.copy_from_slice(&rotated);
13256 }
13257}
13258
13259pub unsafe fn execute_fft1d_f32(
13261 src: usize,
13262 dst: usize,
13263 outer: usize,
13264 n_complex: usize,
13265 inverse: bool,
13266 base: *mut u8,
13267) {
13268 let row_elems = 2 * n_complex;
13269 let mut re = vec![0f32; n_complex];
13270 let mut im = vec![0f32; n_complex];
13271 let mut scratch = if n_complex.is_power_of_two() {
13272 BluesteinScratchF32::empty()
13273 } else {
13274 BluesteinScratchF32::build(n_complex, inverse)
13275 };
13276 for o in 0..outer {
13277 let row_offset = src + o * row_elems * std::mem::size_of::<f32>();
13278 let s = unsafe { sl(row_offset, base, row_elems) };
13279 re.copy_from_slice(&s[..n_complex]);
13280 im.copy_from_slice(&s[n_complex..]);
13281 if n_complex.is_power_of_two() {
13282 fft_radix2_inplace_f32(&mut re, &mut im, inverse);
13283 } else {
13284 fft_bluestein_inplace_f32(&mut re, &mut im, inverse, &mut scratch);
13285 }
13286 let dst_offset = dst + o * row_elems * std::mem::size_of::<f32>();
13287 let d = unsafe { sl_mut(dst_offset, base, row_elems) };
13288 d[..n_complex].copy_from_slice(&re);
13289 d[n_complex..].copy_from_slice(&im);
13290 }
13291}
13292
/// In-place iterative radix-2 FFT on split real/imaginary f32 buffers.
///
/// `re` and `im` must have the same power-of-two length (checked with debug assertions).
/// The forward transform uses the `exp(-2*pi*i*k/n)` kernel; `inverse` flips the sign of
/// the exponent. No `1/n` scaling is applied in either direction, so a forward transform
/// followed by an inverse one yields the input multiplied by `n`.
///
/// Lengths `0` and `1` are returned untouched. That check runs before the power-of-two
/// assertion because `0usize.is_power_of_two()` is `false`; asserting first made an empty
/// input panic in debug builds.
///
/// Twiddle factors are advanced by recurrence in f64 and rounded to f32 per butterfly.
fn fft_radix2_inplace_f32(re: &mut [f32], im: &mut [f32], inverse: bool) {
    let n = re.len();
    debug_assert_eq!(im.len(), n);
    if n <= 1 {
        return;
    }
    debug_assert!(
        n.is_power_of_two(),
        "fft_radix2_f32: n={n} must be a power of two"
    );

    // Bit-reversal permutation: `j` tracks the bit-reversed counterpart of `i`.
    let mut j = 0usize;
    for i in 1..n {
        let mut bit = n >> 1;
        while j & bit != 0 {
            j ^= bit;
            bit >>= 1;
        }
        j ^= bit;
        if i < j {
            re.swap(i, j);
            im.swap(i, j);
        }
    }

    let sign = if inverse { 1.0_f64 } else { -1.0_f64 };
    let mut len = 2usize;
    while len <= n {
        let half = len / 2;
        let theta = sign * 2.0 * std::f64::consts::PI / (len as f64);
        let w_re_step = theta.cos();
        let w_im_step = theta.sin();
        let mut i = 0usize;
        while i < n {
            // The twiddle restarts at 1 for every block of `len` elements.
            let mut wre = 1.0_f64;
            let mut wim = 0.0_f64;
            for k in 0..half {
                let wre_f = wre as f32;
                let wim_f = wim as f32;
                let t_re = wre_f * re[i + k + half] - wim_f * im[i + k + half];
                let t_im = wre_f * im[i + k + half] + wim_f * re[i + k + half];
                let u_re = re[i + k];
                let u_im = im[i + k];
                re[i + k] = u_re + t_re;
                im[i + k] = u_im + t_im;
                re[i + k + half] = u_re - t_re;
                im[i + k + half] = u_im - t_im;
                let new_wre = wre * w_re_step - wim * w_im_step;
                let new_wim = wre * w_im_step + wim * w_re_step;
                wre = new_wre;
                wim = new_wim;
            }
            i += len;
        }
        len <<= 1;
    }
}
13354
/// In-place iterative radix-2 FFT on split real/imaginary f64 buffers.
///
/// `re` and `im` must have the same power-of-two length (checked with debug assertions).
/// The forward transform uses the `exp(-2*pi*i*k/n)` kernel; `inverse` flips the sign of
/// the exponent. No `1/n` scaling is applied in either direction, so a forward transform
/// followed by an inverse one yields the input multiplied by `n`.
///
/// Lengths `0` and `1` are returned untouched. That check runs before the power-of-two
/// assertion because `0usize.is_power_of_two()` is `false`; asserting first made an empty
/// input panic in debug builds.
fn fft_radix2_inplace_f64(re: &mut [f64], im: &mut [f64], inverse: bool) {
    let n = re.len();
    debug_assert_eq!(im.len(), n);
    if n <= 1 {
        return;
    }
    debug_assert!(
        n.is_power_of_two(),
        "fft_radix2: n={n} must be a power of two"
    );

    // Bit-reversal permutation: `j` tracks the bit-reversed counterpart of `i`.
    let mut j = 0usize;
    for i in 1..n {
        let mut bit = n >> 1;
        while j & bit != 0 {
            j ^= bit;
            bit >>= 1;
        }
        j ^= bit;
        if i < j {
            re.swap(i, j);
            im.swap(i, j);
        }
    }

    let sign = if inverse { 1.0 } else { -1.0 };
    let mut len = 2usize;
    while len <= n {
        let half = len / 2;
        let theta = sign * 2.0 * std::f64::consts::PI / (len as f64);
        let w_re_step = theta.cos();
        let w_im_step = theta.sin();
        let mut i = 0usize;
        while i < n {
            // The twiddle restarts at 1 for every block of `len` elements.
            let mut wre = 1.0_f64;
            let mut wim = 0.0_f64;
            for k in 0..half {
                let t_re = wre * re[i + k + half] - wim * im[i + k + half];
                let t_im = wre * im[i + k + half] + wim * re[i + k + half];
                let u_re = re[i + k];
                let u_im = im[i + k];
                re[i + k] = u_re + t_re;
                im[i + k] = u_im + t_im;
                re[i + k + half] = u_re - t_re;
                im[i + k + half] = u_im - t_im;
                let new_wre = wre * w_re_step - wim * w_im_step;
                let new_wim = wre * w_im_step + wim * w_re_step;
                wre = new_wre;
                wim = new_wim;
            }
            i += len;
        }
        len <<= 1;
    }
}
13416
13417struct BluesteinScratchF64 {
13421 m: usize,
13423 w_re: Vec<f64>,
13427 w_im: Vec<f64>,
13428 bf_re: Vec<f64>,
13431 bf_im: Vec<f64>,
13432 ar: Vec<f64>,
13434 ai: Vec<f64>,
13435}
13436
13437impl BluesteinScratchF64 {
13438 fn empty() -> Self {
13439 Self {
13440 m: 0,
13441 w_re: Vec::new(),
13442 w_im: Vec::new(),
13443 bf_re: Vec::new(),
13444 bf_im: Vec::new(),
13445 ar: Vec::new(),
13446 ai: Vec::new(),
13447 }
13448 }
13449
13450 fn build(n: usize, inverse: bool) -> Self {
13451 let m = if n <= 1 {
13454 1
13455 } else {
13456 (2 * n - 1).next_power_of_two()
13457 };
13458
13459 let mod_2n = (2 * n) as u64;
13462 let sign = if inverse { 1.0_f64 } else { -1.0_f64 };
13463 let mut w_re = vec![0.0_f64; n];
13464 let mut w_im = vec![0.0_f64; n];
13465 for k in 0..n {
13466 let k2 = (k as u64).wrapping_mul(k as u64) % mod_2n;
13467 let theta = sign * std::f64::consts::PI * (k2 as f64) / (n as f64);
13468 w_re[k] = theta.cos();
13469 w_im[k] = theta.sin();
13470 }
13471
13472 let mut bf_re = vec![0.0_f64; m];
13475 let mut bf_im = vec![0.0_f64; m];
13476 if n > 0 {
13477 bf_re[0] = w_re[0];
13478 bf_im[0] = -w_im[0];
13479 for k in 1..n {
13480 bf_re[k] = w_re[k];
13481 bf_im[k] = -w_im[k];
13482 bf_re[m - k] = w_re[k];
13483 bf_im[m - k] = -w_im[k];
13484 }
13485 }
13486 if m > 1 {
13487 fft_radix2_inplace_f64(&mut bf_re, &mut bf_im, false);
13488 }
13489
13490 Self {
13491 m,
13492 w_re,
13493 w_im,
13494 bf_re,
13495 bf_im,
13496 ar: vec![0.0_f64; m],
13497 ai: vec![0.0_f64; m],
13498 }
13499 }
13500}
13501
13502fn fft_bluestein_inplace_f64(
13511 re: &mut [f64],
13512 im: &mut [f64],
13513 _inverse: bool,
13514 s: &mut BluesteinScratchF64,
13515) {
13516 let n = re.len();
13517 debug_assert_eq!(im.len(), n);
13518 debug_assert_eq!(s.w_re.len(), n);
13519 if n <= 1 {
13520 return;
13521 }
13522 let m = s.m;
13523
13524 for k in 0..m {
13526 s.ar[k] = 0.0;
13527 s.ai[k] = 0.0;
13528 }
13529 for k in 0..n {
13530 s.ar[k] = re[k] * s.w_re[k] - im[k] * s.w_im[k];
13531 s.ai[k] = re[k] * s.w_im[k] + im[k] * s.w_re[k];
13532 }
13533
13534 fft_radix2_inplace_f64(&mut s.ar, &mut s.ai, false);
13536
13537 for k in 0..m {
13539 let ar = s.ar[k];
13540 let ai = s.ai[k];
13541 let br = s.bf_re[k];
13542 let bi = s.bf_im[k];
13543 s.ar[k] = ar * br - ai * bi;
13544 s.ai[k] = ar * bi + ai * br;
13545 }
13546
13547 fft_radix2_inplace_f64(&mut s.ar, &mut s.ai, true);
13550 let inv_m = 1.0 / (m as f64);
13551
13552 for k in 0..n {
13554 let yr = s.ar[k] * inv_m;
13555 let yi = s.ai[k] * inv_m;
13556 re[k] = yr * s.w_re[k] - yi * s.w_im[k];
13557 im[k] = yr * s.w_im[k] + yi * s.w_re[k];
13558 }
13559}
13560
13561struct BluesteinScratchF32 {
13565 m: usize,
13566 w_re: Vec<f32>,
13567 w_im: Vec<f32>,
13568 bf_re: Vec<f32>,
13569 bf_im: Vec<f32>,
13570 ar: Vec<f32>,
13571 ai: Vec<f32>,
13572}
13573
13574impl BluesteinScratchF32 {
13575 fn empty() -> Self {
13576 Self {
13577 m: 0,
13578 w_re: Vec::new(),
13579 w_im: Vec::new(),
13580 bf_re: Vec::new(),
13581 bf_im: Vec::new(),
13582 ar: Vec::new(),
13583 ai: Vec::new(),
13584 }
13585 }
13586
13587 fn build(n: usize, inverse: bool) -> Self {
13588 let m = if n <= 1 {
13589 1
13590 } else {
13591 (2 * n - 1).next_power_of_two()
13592 };
13593
13594 let mod_2n = (2 * n) as u64;
13595 let sign = if inverse { 1.0_f64 } else { -1.0_f64 };
13596 let mut w_re = vec![0.0_f32; n];
13597 let mut w_im = vec![0.0_f32; n];
13598 for k in 0..n {
13599 let k2 = (k as u64).wrapping_mul(k as u64) % mod_2n;
13600 let theta = sign * std::f64::consts::PI * (k2 as f64) / (n as f64);
13601 w_re[k] = theta.cos() as f32;
13602 w_im[k] = theta.sin() as f32;
13603 }
13604
13605 let mut bf_re = vec![0.0_f32; m];
13606 let mut bf_im = vec![0.0_f32; m];
13607 if n > 0 {
13608 bf_re[0] = w_re[0];
13609 bf_im[0] = -w_im[0];
13610 for k in 1..n {
13611 bf_re[k] = w_re[k];
13612 bf_im[k] = -w_im[k];
13613 bf_re[m - k] = w_re[k];
13614 bf_im[m - k] = -w_im[k];
13615 }
13616 }
13617 if m > 1 {
13618 fft_radix2_inplace_f32(&mut bf_re, &mut bf_im, false);
13619 }
13620
13621 Self {
13622 m,
13623 w_re,
13624 w_im,
13625 bf_re,
13626 bf_im,
13627 ar: vec![0.0_f32; m],
13628 ai: vec![0.0_f32; m],
13629 }
13630 }
13631}
13632
13633fn fft_bluestein_inplace_f32(
13634 re: &mut [f32],
13635 im: &mut [f32],
13636 _inverse: bool,
13637 s: &mut BluesteinScratchF32,
13638) {
13639 let n = re.len();
13640 debug_assert_eq!(im.len(), n);
13641 debug_assert_eq!(s.w_re.len(), n);
13642 if n <= 1 {
13643 return;
13644 }
13645 let m = s.m;
13646
13647 for k in 0..m {
13648 s.ar[k] = 0.0;
13649 s.ai[k] = 0.0;
13650 }
13651 for k in 0..n {
13652 s.ar[k] = re[k] * s.w_re[k] - im[k] * s.w_im[k];
13653 s.ai[k] = re[k] * s.w_im[k] + im[k] * s.w_re[k];
13654 }
13655
13656 fft_radix2_inplace_f32(&mut s.ar, &mut s.ai, false);
13657
13658 for k in 0..m {
13659 let ar = s.ar[k];
13660 let ai = s.ai[k];
13661 let br = s.bf_re[k];
13662 let bi = s.bf_im[k];
13663 s.ar[k] = ar * br - ai * bi;
13664 s.ai[k] = ar * bi + ai * br;
13665 }
13666
13667 fft_radix2_inplace_f32(&mut s.ar, &mut s.ai, true);
13668 let inv_m = 1.0_f32 / (m as f32);
13669
13670 for k in 0..n {
13671 let yr = s.ar[k] * inv_m;
13672 let yi = s.ai[k] * inv_m;
13673 re[k] = yr * s.w_re[k] - yi * s.w_im[k];
13674 im[k] = yr * s.w_im[k] + yi * s.w_re[k];
13675 }
13676}
13677
13678unsafe fn dispatch_custom_op(
13684 kernel: &dyn crate::op_registry::CpuKernel,
13685 inputs: &[(usize, u32, Shape)],
13686 out_off: usize,
13687 out_len: u32,
13688 out_shape: &Shape,
13689 attrs: &[u8],
13690 base: *mut u8,
13691) {
13692 use crate::op_registry::{CpuTensorMut, CpuTensorRef};
13693 use rlx_ir::DType;
13694
13695 macro_rules! build_in_view {
13700 ($shape:expr, $off:expr, $n:expr, $variant:ident, $rust_ty:ty) => {
13701 CpuTensorRef::$variant {
13702 data: unsafe { sl_typed::<$rust_ty>($off, base, $n) },
13703 shape: $shape,
13704 }
13705 };
13706 }
13707 macro_rules! build_out_view {
13708 ($variant:ident, $rust_ty:ty) => {
13709 CpuTensorMut::$variant {
13710 data: unsafe { sl_mut_typed::<$rust_ty>(out_off, base, out_len as usize) },
13711 shape: out_shape,
13712 }
13713 };
13714 }
13715
13716 let in_views: Vec<CpuTensorRef<'_>> = inputs
13717 .iter()
13718 .map(|(off, len, shape)| {
13719 let n = *len as usize;
13720 let off = *off;
13721 match shape.dtype() {
13722 DType::F32 => build_in_view!(shape, off, n, F32, f32),
13723 DType::F64 => build_in_view!(shape, off, n, F64, f64),
13724 DType::F16 => build_in_view!(shape, off, n, F16, half::f16),
13725 DType::BF16 => build_in_view!(shape, off, n, BF16, half::bf16),
13726 DType::I8 => build_in_view!(shape, off, n, I8, i8),
13727 DType::I16 => build_in_view!(shape, off, n, I16, i16),
13728 DType::I32 => build_in_view!(shape, off, n, I32, i32),
13729 DType::I64 => build_in_view!(shape, off, n, I64, i64),
13730 DType::U8 => build_in_view!(shape, off, n, U8, u8),
13731 DType::U32 => build_in_view!(shape, off, n, U32, u32),
13732 DType::Bool => build_in_view!(shape, off, n, Bool, u8),
13733 DType::C64 => panic!(
13737 "Op::Custom kernel input has DType::C64 — built-in \
13738 complex ops handle their own kernels; user-registered \
13739 ops don't yet see complex tensors"
13740 ),
13741 }
13742 })
13743 .collect();
13744
13745 let result = match out_shape.dtype() {
13746 DType::F32 => kernel.execute(&in_views, build_out_view!(F32, f32), attrs),
13747 DType::F64 => kernel.execute(&in_views, build_out_view!(F64, f64), attrs),
13748 DType::F16 => kernel.execute(&in_views, build_out_view!(F16, half::f16), attrs),
13749 DType::BF16 => kernel.execute(&in_views, build_out_view!(BF16, half::bf16), attrs),
13750 DType::I8 => kernel.execute(&in_views, build_out_view!(I8, i8), attrs),
13751 DType::I16 => kernel.execute(&in_views, build_out_view!(I16, i16), attrs),
13752 DType::I32 => kernel.execute(&in_views, build_out_view!(I32, i32), attrs),
13753 DType::I64 => kernel.execute(&in_views, build_out_view!(I64, i64), attrs),
13754 DType::U8 => kernel.execute(&in_views, build_out_view!(U8, u8), attrs),
13755 DType::U32 => kernel.execute(&in_views, build_out_view!(U32, u32), attrs),
13756 DType::Bool => kernel.execute(&in_views, build_out_view!(Bool, u8), attrs),
13757 DType::C64 => panic!("Op::Custom output DType::C64 not supported"),
13758 };
13759 if let Err(e) = result {
13760 panic!("Op::Custom('{}') CPU kernel failed: {e}", kernel.name());
13761 }
13762}
13763
/// Views `len` elements of type `T` starting `offset` bytes past `base` as a shared slice.
///
/// An `offset` of `usize::MAX` yields an empty slice without touching `base`
/// (presumably a "no tensor" sentinel; confirm against callers).
///
/// # Safety
///
/// For any other offset, `base + offset` must be aligned for `T` and point to `len`
/// initialized `T` values that stay valid, and are not written through another path, for
/// as long as the returned `'static` reference is used.
#[inline(always)]
unsafe fn sl_typed<T>(offset: usize, base: *mut u8, len: usize) -> &'static [T] {
    if offset != usize::MAX {
        let first = unsafe { base.add(offset) } as *const T;
        unsafe { std::slice::from_raw_parts(first, len) }
    } else {
        &[]
    }
}
13776
/// Views `len` elements of type `T` starting `offset` bytes past `base` as a mutable slice.
///
/// Unlike `sl_typed`, no offset value is treated specially.
///
/// # Safety
///
/// `base + offset` must be aligned for `T` and point to `len` initialized `T` values that
/// stay valid, and are not accessed through any other path, for as long as the returned
/// `'static` reference is used.
#[inline(always)]
unsafe fn sl_mut_typed<T>(offset: usize, base: *mut u8, len: usize) -> &'static mut [T] {
    let first = unsafe { base.add(offset) } as *mut T;
    unsafe { std::slice::from_raw_parts_mut(first, len) }
}
13781
13782#[inline(always)]
13784fn apply_activation_inplace(d: &mut [f32], act: rlx_ir::op::Activation) {
13788 use rlx_ir::op::Activation;
13789 match act {
13790 Activation::Gelu => crate::kernels::par_gelu_inplace(d),
13791 Activation::GeluApprox => crate::kernels::par_gelu_approx_inplace(d),
13792 Activation::Silu => crate::kernels::par_silu_inplace(d),
13793 Activation::Relu => {
13794 for v in d.iter_mut() {
13795 *v = v.max(0.0);
13796 }
13797 }
13798 Activation::Sigmoid => {
13799 for v in d.iter_mut() {
13800 *v = 1.0 / (1.0 + (-*v).exp());
13801 }
13802 }
13803 Activation::Tanh => {
13804 for v in d.iter_mut() {
13805 *v = v.tanh();
13806 }
13807 }
13808 Activation::Exp => {
13809 for v in d.iter_mut() {
13810 *v = v.exp();
13811 }
13812 }
13813 Activation::Log => {
13814 for v in d.iter_mut() {
13815 *v = v.ln();
13816 }
13817 }
13818 Activation::Sqrt => {
13819 for v in d.iter_mut() {
13820 *v = v.sqrt();
13821 }
13822 }
13823 Activation::Rsqrt => {
13824 for v in d.iter_mut() {
13825 *v = 1.0 / v.sqrt();
13826 }
13827 }
13828 Activation::Neg => {
13829 for v in d.iter_mut() {
13830 *v = -*v;
13831 }
13832 }
13833 Activation::Abs => {
13834 for v in d.iter_mut() {
13835 *v = v.abs();
13836 }
13837 }
13838 Activation::Round => {
13839 for v in d.iter_mut() {
13840 *v = v.round();
13841 }
13842 }
13843 Activation::Sin => {
13844 for v in d.iter_mut() {
13845 *v = v.sin();
13846 }
13847 }
13848 Activation::Cos => {
13849 for v in d.iter_mut() {
13850 *v = v.cos();
13851 }
13852 }
13853 Activation::Tan => {
13854 for v in d.iter_mut() {
13855 *v = v.tan();
13856 }
13857 }
13858 Activation::Atan => {
13859 for v in d.iter_mut() {
13860 *v = v.atan();
13861 }
13862 }
13863 }
13864}
13865
/// Unfolds one `c_in x h x w` image into a `(c_in * kh * kw) x (h_out * w_out)` matrix.
///
/// Row `((ci * kh) + ki) * kw + kj` of `col` holds, for every output position `(ho, wo)`,
/// the input sample at `(ci, ho * sh + ki * dh - ph, wo * sw + kj * dw_dil - pw)`, or `0.0`
/// when that coordinate falls outside the image. Every element of `col` is written.
///
/// Buffer lengths are verified with debug assertions only.
#[allow(clippy::too_many_arguments)]
fn im2col(
    x: &[f32],
    col: &mut [f32],
    c_in: usize,
    h: usize,
    w: usize,
    h_out: usize,
    w_out: usize,
    kh: usize,
    kw: usize,
    sh: usize,
    sw: usize,
    ph: usize,
    pw: usize,
    dh: usize,
    dw_dil: usize,
) {
    let patches = h_out * w_out;
    debug_assert_eq!(col.len(), c_in * kh * kw * patches);
    debug_assert_eq!(x.len(), c_in * h * w);
    let (rows_in, cols_in) = (h as isize, w as isize);
    let (pad_top, pad_left) = (ph as isize, pw as isize);

    // Running index of the current `col` row; equals `((ci * kh) + ki) * kw + kj`.
    let mut col_row = 0usize;
    for ci in 0..c_in {
        for ki in 0..kh {
            for kj in 0..kw {
                let dst_plane = col_row * patches;
                col_row += 1;
                for ho in 0..h_out {
                    let dst_line = dst_plane + ho * w_out;
                    let src_y = (ho * sh + ki * dh) as isize - pad_top;
                    // Start of the source row in `x`, or `None` when it lies in the padding.
                    let src_line = if (0..rows_in).contains(&src_y) {
                        Some((ci * h + src_y as usize) * w)
                    } else {
                        None
                    };
                    for wo in 0..w_out {
                        let sample = match src_line {
                            None => 0.0,
                            Some(line) => {
                                let src_x = (wo * sw + kj * dw_dil) as isize - pad_left;
                                if (0..cols_in).contains(&src_x) {
                                    x[line + src_x as usize]
                                } else {
                                    0.0
                                }
                            }
                        };
                        col[dst_line + wo] = sample;
                    }
                }
            }
        }
    }
}
13927
/// Scatter-adds a column matrix back into a `[c_in, h, w]` image.
///
/// `col` is read as a row-major `[c_in * kh * kw, h_out * w_out]` matrix. Each entry is
/// **added** to the input position `(ci, ho * sh + ki * dh - ph, wo * sw + kj * dw_dil - pw)`;
/// entries whose position falls outside the image are dropped.
///
/// `x` is accumulated into, not cleared, so the caller decides its starting contents.
/// Slice lengths are only verified in debug builds.
#[allow(clippy::too_many_arguments)]
fn col2im(
    col: &[f32],
    x: &mut [f32],
    c_in: usize,
    h: usize,
    w: usize,
    h_out: usize,
    w_out: usize,
    kh: usize,
    kw: usize,
    sh: usize,
    sw: usize,
    ph: usize,
    pw: usize,
    dh: usize,
    dw_dil: usize,
) {
    let plane = h_out * w_out;
    debug_assert_eq!(col.len(), c_in * kh * kw * plane);
    debug_assert_eq!(x.len(), c_in * h * w);

    // Signed copies so that "coordinate minus padding" can go negative.
    let rows_in = h as isize;
    let cols_in = w as isize;
    let pad_top = ph as isize;
    let pad_left = pw as isize;

    for ci in 0..c_in {
        for ki in 0..kh {
            for kj in 0..kw {
                let col_row_start = (((ci * kh) + ki) * kw + kj) * plane;
                for ho in 0..h_out {
                    let dst_h = (ho * sh + ki * dh) as isize - pad_top;
                    if (0..rows_in).contains(&dst_h) {
                        let dst_row_start = (ci * h + dst_h as usize) * w;
                        for wo in 0..w_out {
                            let dst_w = (wo * sw + kj * dw_dil) as isize - pad_left;
                            if (0..cols_in).contains(&dst_w) {
                                x[dst_row_start + dst_w as usize] +=
                                    col[col_row_start + ho * w_out + wo];
                            }
                        }
                    }
                }
            }
        }
    }
}
13983
13984fn quant_layout(shape: &rlx_ir::Shape, axis: Option<usize>) -> (usize, usize, usize) {
13994 match axis {
13995 None => (0, 1, shape.num_elements().unwrap_or(0).max(1)),
13996 Some(d) => {
13997 let chan_dim = shape.dim(d).unwrap_static();
13998 let inner: usize = (d + 1..shape.rank())
13999 .map(|i| shape.dim(i).unwrap_static())
14000 .product::<usize>()
14001 .max(1);
14002 (d, chan_dim, inner)
14003 }
14004 }
14005}
14006
14007fn activation_backward_kernel(
14008 act: rlx_ir::op::Activation,
14009 xs: &[f32],
14010 dys: &[f32],
14011 out: &mut [f32],
14012) {
14013 use rlx_ir::op::Activation;
14014 let n = xs.len();
14015 debug_assert_eq!(dys.len(), n);
14016 debug_assert_eq!(out.len(), n);
14017 match act {
14018 Activation::Relu => {
14019 for i in 0..n {
14020 out[i] = if xs[i] > 0.0 { dys[i] } else { 0.0 };
14021 }
14022 }
14023 Activation::Sigmoid => {
14024 for i in 0..n {
14025 let s = 1.0 / (1.0 + (-xs[i]).exp());
14026 out[i] = s * (1.0 - s) * dys[i];
14027 }
14028 }
14029 Activation::Tanh => {
14030 for i in 0..n {
14031 let t = xs[i].tanh();
14032 out[i] = (1.0 - t * t) * dys[i];
14033 }
14034 }
14035 Activation::Silu => {
14036 for i in 0..n {
14038 let s = 1.0 / (1.0 + (-xs[i]).exp());
14039 out[i] = s * (1.0 + xs[i] * (1.0 - s)) * dys[i];
14040 }
14041 }
14042 Activation::Gelu => {
14043 const INV_SQRT2: f32 = 0.707_106_77;
14046 const INV_SQRT_2PI: f32 = 0.398_942_3;
14047 for i in 0..n {
14048 let x = xs[i];
14049 let phi = 0.5 * (1.0 + erf_f32(x * INV_SQRT2));
14050 let pdf = INV_SQRT_2PI * (-(x * x) * 0.5).exp();
14051 out[i] = (phi + x * pdf) * dys[i];
14052 }
14053 }
14054 Activation::GeluApprox => {
14055 const C: f32 = 0.797_884_6; const A: f32 = 0.044_715;
14059 for i in 0..n {
14060 let x = xs[i];
14061 let inner = C * (x + A * x * x * x);
14062 let t = inner.tanh();
14063 let dinner = C * (1.0 + 3.0 * A * x * x);
14064 let d = 0.5 * (1.0 + t) + 0.5 * x * (1.0 - t * t) * dinner;
14065 out[i] = d * dys[i];
14066 }
14067 }
14068 Activation::Exp => {
14069 for i in 0..n {
14070 out[i] = xs[i].exp() * dys[i];
14071 }
14072 }
14073 Activation::Log => {
14074 for i in 0..n {
14075 out[i] = dys[i] / xs[i];
14076 }
14077 }
14078 Activation::Sqrt => {
14079 for i in 0..n {
14081 let s = xs[i].sqrt();
14082 out[i] = if s > 0.0 { 0.5 * dys[i] / s } else { 0.0 };
14083 }
14084 }
14085 Activation::Rsqrt => {
14086 for i in 0..n {
14088 let s = xs[i].sqrt();
14089 out[i] = if s > 0.0 {
14090 -0.5 * dys[i] / (xs[i] * s)
14091 } else {
14092 0.0
14093 };
14094 }
14095 }
14096 Activation::Neg => {
14097 for i in 0..n {
14098 out[i] = -dys[i];
14099 }
14100 }
14101 Activation::Abs => {
14102 for i in 0..n {
14104 let x = xs[i];
14105 let s = if x > 0.0 {
14106 1.0
14107 } else if x < 0.0 {
14108 -1.0
14109 } else {
14110 0.0
14111 };
14112 out[i] = s * dys[i];
14113 }
14114 }
14115 Activation::Round => {
14116 out.copy_from_slice(dys);
14121 }
14122 Activation::Sin => {
14123 for i in 0..n {
14125 out[i] = xs[i].cos() * dys[i];
14126 }
14127 }
14128 Activation::Cos => {
14129 for i in 0..n {
14130 out[i] = -xs[i].sin() * dys[i];
14131 }
14132 }
14133 Activation::Tan => {
14134 for i in 0..n {
14136 let t = xs[i].tan();
14137 out[i] = (1.0 + t * t) * dys[i];
14138 }
14139 }
14140 Activation::Atan => {
14141 for i in 0..n {
14143 let x = xs[i];
14144 out[i] = dys[i] / (1.0 + x * x);
14145 }
14146 }
14147 }
14148}
14149
14150fn activation_backward_kernel_f64(
14154 act: rlx_ir::op::Activation,
14155 xs: &[f64],
14156 dys: &[f64],
14157 out: &mut [f64],
14158) {
14159 use rlx_ir::op::Activation;
14160 let n = xs.len();
14161 debug_assert_eq!(dys.len(), n);
14162 debug_assert_eq!(out.len(), n);
14163 match act {
14164 Activation::Relu => {
14165 for i in 0..n {
14166 out[i] = if xs[i] > 0.0 { dys[i] } else { 0.0 };
14167 }
14168 }
14169 Activation::Sigmoid => {
14170 for i in 0..n {
14171 let s = 1.0 / (1.0 + (-xs[i]).exp());
14172 out[i] = s * (1.0 - s) * dys[i];
14173 }
14174 }
14175 Activation::Tanh => {
14176 for i in 0..n {
14177 let t = xs[i].tanh();
14178 out[i] = (1.0 - t * t) * dys[i];
14179 }
14180 }
14181 Activation::Silu => {
14182 for i in 0..n {
14183 let s = 1.0 / (1.0 + (-xs[i]).exp());
14184 out[i] = s * (1.0 + xs[i] * (1.0 - s)) * dys[i];
14185 }
14186 }
14187 Activation::Gelu | Activation::GeluApprox => {
14188 const INV_SQRT2: f64 = std::f64::consts::FRAC_1_SQRT_2;
14190 const INV_SQRT_2PI: f64 = 0.398_942_280_401_432_7;
14191 for i in 0..n {
14192 let x = xs[i];
14193 let phi = 0.5 * (1.0 + erf_f64(x * INV_SQRT2));
14194 let pdf = INV_SQRT_2PI * (-(x * x) * 0.5).exp();
14195 out[i] = (phi + x * pdf) * dys[i];
14196 }
14197 }
14198 Activation::Exp => {
14199 for i in 0..n {
14200 out[i] = xs[i].exp() * dys[i];
14201 }
14202 }
14203 Activation::Log => {
14204 for i in 0..n {
14205 out[i] = dys[i] / xs[i];
14206 }
14207 }
14208 Activation::Sqrt => {
14209 for i in 0..n {
14210 let s = xs[i].sqrt();
14211 out[i] = if s > 0.0 { 0.5 * dys[i] / s } else { 0.0 };
14212 }
14213 }
14214 Activation::Rsqrt => {
14215 for i in 0..n {
14216 let s = xs[i].sqrt();
14217 out[i] = if s > 0.0 {
14218 -0.5 * dys[i] / (xs[i] * s)
14219 } else {
14220 0.0
14221 };
14222 }
14223 }
14224 Activation::Neg => {
14225 for i in 0..n {
14226 out[i] = -dys[i];
14227 }
14228 }
14229 Activation::Abs => {
14230 for i in 0..n {
14231 let x = xs[i];
14232 let s = if x > 0.0 {
14233 1.0
14234 } else if x < 0.0 {
14235 -1.0
14236 } else {
14237 0.0
14238 };
14239 out[i] = s * dys[i];
14240 }
14241 }
14242 Activation::Round => {
14243 out.copy_from_slice(dys);
14244 }
14245 Activation::Sin => {
14246 for i in 0..n {
14247 out[i] = xs[i].cos() * dys[i];
14248 }
14249 }
14250 Activation::Cos => {
14251 for i in 0..n {
14252 out[i] = -xs[i].sin() * dys[i];
14253 }
14254 }
14255 Activation::Tan => {
14256 for i in 0..n {
14257 let t = xs[i].tan();
14258 out[i] = (1.0 + t * t) * dys[i];
14259 }
14260 }
14261 Activation::Atan => {
14262 for i in 0..n {
14263 let x = xs[i];
14264 out[i] = dys[i] / (1.0 + x * x);
14265 }
14266 }
14267 }
14268}
14269
/// Rational-polynomial approximation of the error function, evaluated in `f64`.
///
/// Computes `sign(x) * (1 - poly(t) * t * exp(-x^2))` with `t = 1 / (1 + P * |x|)`.
/// The coefficients match the Abramowitz–Stegun 7.1.26 formula, so accuracy is limited by
/// that approximation (on the order of 1e-7) rather than by `f64` precision.
#[inline(always)]
fn erf_f64(x: f64) -> f64 {
    const P: f64 = 0.327_591_1;
    const A1: f64 = 0.254_829_59;
    const A2: f64 = 0.284_496_74;
    const A3: f64 = 1.421_413_75;
    const A4: f64 = 1.453_152_03;
    const A5: f64 = 1.061_405_43;

    let sign = x.signum();
    let mag = x.abs();
    let t = 1.0 / (1.0 + P * mag);
    // Horner form, highest-order coefficient first.
    let poly = ((((A5 * t - A4) * t) + A3) * t - A2) * t + A1;
    sign * (1.0 - poly * t * (-mag * mag).exp())
}
14286
/// Rational-polynomial approximation of the error function, evaluated in `f32`.
///
/// Same formula as the `f64` variant: `sign(x) * (1 - poly(t) * t * exp(-x^2))` with
/// `t = 1 / (1 + P * |x|)`; the coefficients match Abramowitz–Stegun 7.1.26.
#[inline(always)]
fn erf_f32(x: f32) -> f32 {
    const P: f32 = 0.327_591_1;
    const A1: f32 = 0.254_829_6;
    const A2: f32 = 0.284_496_74;
    const A3: f32 = 1.421_413_8;
    const A4: f32 = 1.453_152_1;
    const A5: f32 = 1.061_405_4;

    let sign = x.signum();
    let mag = x.abs();
    let t = 1.0 / (1.0 + P * mag);
    // Horner form, highest-order coefficient first.
    let poly = ((((A5 * t - A4) * t) + A3) * t - A2) * t + A1;
    sign * (1.0 - poly * t * (-mag * mag).exp())
}
14301
14302fn narrow_thunk_closure(
14303 src: usize,
14304 dst: usize,
14305 outer: u32,
14306 src_stride: u32,
14307 dst_stride: u32,
14308 inner: u32,
14309 elem_bytes: u8,
14310) -> Arc<dyn Fn(*mut u8) + Send + Sync> {
14311 let (outer, ss, ds, inner) = (
14312 outer as usize,
14313 src_stride as usize,
14314 dst_stride as usize,
14315 inner as usize,
14316 );
14317 if elem_bytes == 8 {
14318 Arc::new(move |base: *mut u8| unsafe {
14319 let s = sl_f64(src, base, outer * ss);
14320 let d = sl_mut_f64(dst, base, outer * ds);
14321 for o in 0..outer {
14322 d[o * ds..o * ds + inner].copy_from_slice(&s[o * ss..o * ss + inner]);
14323 }
14324 })
14325 } else {
14326 Arc::new(move |base: *mut u8| unsafe {
14327 let s = sl(src, base, outer * ss);
14328 let d = sl_mut(dst, base, outer * ds);
14329 for o in 0..outer {
14330 d[o * ds..o * ds + inner].copy_from_slice(&s[o * ss..o * ss + inner]);
14331 }
14332 })
14333 }
14334}
14335
/// Read-only `f32` view of `len` elements starting `offset` bytes past `base`.
///
/// `offset == usize::MAX` is a sentinel (presumably "operand absent" — confirm against
/// callers) and yields an empty slice without touching `base`.
///
/// Carries `#[inline(always)]` like the other `sl*` accessors in this file.
///
/// # Safety
/// Unless `offset` is the sentinel, `base + offset` must be aligned for `f32` and valid for
/// reads of `len` elements for as long as the returned slice is used. The `'static`
/// lifetime is not tied to the arena, so the caller must not let the slice outlive it.
#[inline(always)]
unsafe fn sl(offset: usize, base: *mut u8, len: usize) -> &'static [f32] {
    if offset == usize::MAX {
        return &[];
    }
    // SAFETY: upheld by the caller, see the `# Safety` section above.
    unsafe { std::slice::from_raw_parts(base.add(offset) as *const f32, len) }
}
14342
/// Mutable `f32` view of `len` elements starting `offset` bytes past `base`.
///
/// Unlike the read-only accessor there is no `usize::MAX` sentinel handling here.
///
/// # Safety
/// `base + offset` must be aligned for `f32`, valid for reads and writes of `len` elements,
/// and not aliased by any other live reference while the returned slice is used.
#[inline(always)]
unsafe fn sl_mut(offset: usize, base: *mut u8, len: usize) -> &'static mut [f32] {
    // SAFETY: upheld by the caller, see the `# Safety` section above.
    unsafe {
        let first = base.add(offset).cast::<f32>();
        std::slice::from_raw_parts_mut(first, len)
    }
}
14347
/// Read-only `f64` view of `len` elements starting `offset` bytes past `base`.
///
/// `offset == usize::MAX` is a sentinel that yields an empty slice without touching `base`.
///
/// # Safety
/// Unless `offset` is the sentinel, `base + offset` must be aligned for `f64` and valid for
/// reads of `len` elements for as long as the returned slice is used.
#[inline(always)]
unsafe fn sl_f64(offset: usize, base: *mut u8, len: usize) -> &'static [f64] {
    match offset {
        usize::MAX => &[],
        // SAFETY: upheld by the caller, see the `# Safety` section above.
        _ => unsafe { std::slice::from_raw_parts(base.add(offset).cast::<f64>(), len) },
    }
}
14355
/// Mutable `f64` view of `len` elements starting `offset` bytes past `base`.
///
/// # Safety
/// `base + offset` must be aligned for `f64`, valid for reads and writes of `len` elements,
/// and not aliased by any other live reference while the returned slice is used.
#[inline(always)]
unsafe fn sl_mut_f64(offset: usize, base: *mut u8, len: usize) -> &'static mut [f64] {
    // SAFETY: upheld by the caller, see the `# Safety` section above.
    unsafe {
        let first = base.add(offset).cast::<f64>();
        std::slice::from_raw_parts_mut(first, len)
    }
}
14360
/// Read-only `i32` view of `len` elements starting `offset` bytes past `base`.
///
/// `offset == usize::MAX` is a sentinel that yields an empty slice without touching `base`.
///
/// # Safety
/// Unless `offset` is the sentinel, `base + offset` must be aligned for `i32` and valid for
/// reads of `len` elements for as long as the returned slice is used.
#[allow(dead_code)]
#[inline(always)]
unsafe fn sl_i32(offset: usize, base: *mut u8, len: usize) -> &'static [i32] {
    match offset {
        usize::MAX => &[],
        // SAFETY: upheld by the caller, see the `# Safety` section above.
        _ => unsafe { std::slice::from_raw_parts(base.add(offset).cast::<i32>(), len) },
    }
}
14373
/// Mutable `i32` view of `len` elements starting `offset` bytes past `base`.
///
/// # Safety
/// `base + offset` must be aligned for `i32`, valid for reads and writes of `len` elements,
/// and not aliased by any other live reference while the returned slice is used.
#[allow(dead_code)]
#[inline(always)]
unsafe fn sl_mut_i32(offset: usize, base: *mut u8, len: usize) -> &'static mut [i32] {
    // SAFETY: upheld by the caller, see the `# Safety` section above.
    unsafe {
        let first = base.add(offset).cast::<i32>();
        std::slice::from_raw_parts_mut(first, len)
    }
}
14379
/// Read-only `i64` view of `len` elements starting `offset` bytes past `base`.
///
/// `offset == usize::MAX` is a sentinel that yields an empty slice without touching `base`.
///
/// # Safety
/// Unless `offset` is the sentinel, `base + offset` must be aligned for `i64` and valid for
/// reads of `len` elements for as long as the returned slice is used.
#[allow(dead_code)]
#[inline(always)]
unsafe fn sl_i64(offset: usize, base: *mut u8, len: usize) -> &'static [i64] {
    match offset {
        usize::MAX => &[],
        // SAFETY: upheld by the caller, see the `# Safety` section above.
        _ => unsafe { std::slice::from_raw_parts(base.add(offset).cast::<i64>(), len) },
    }
}
14388
/// Mutable `i64` view of `len` elements starting `offset` bytes past `base`.
///
/// # Safety
/// `base + offset` must be aligned for `i64`, valid for reads and writes of `len` elements,
/// and not aliased by any other live reference while the returned slice is used.
#[allow(dead_code)]
#[inline(always)]
unsafe fn sl_mut_i64(offset: usize, base: *mut u8, len: usize) -> &'static mut [i64] {
    // SAFETY: upheld by the caller, see the `# Safety` section above.
    unsafe {
        let first = base.add(offset).cast::<i64>();
        std::slice::from_raw_parts_mut(first, len)
    }
}
14394
/// Gathers `inp` into `out` by walking the output index space in row-major order.
///
/// `out_dims` gives the output extents and `in_strides[d]` the element stride in `inp`
/// that corresponds to output dimension `d`. For each output element the source offset is
/// `sum(coord[d] * in_strides[d])`; the coordinate is then advanced like an odometer, last
/// dimension fastest.
fn transpose_walk_f64(inp: &[f64], out: &mut [f64], out_dims: &[u32], in_strides: &[u32]) {
    let rank = out_dims.len();
    let mut coord = vec![0u32; rank];
    for slot in out.iter_mut() {
        let src: usize = (0..rank)
            .map(|d| coord[d] as usize * in_strides[d] as usize)
            .sum();
        *slot = inp[src];

        // Advance the odometer: bump the innermost digit, carrying outward on wrap.
        for (digit, &extent) in coord.iter_mut().zip(out_dims).rev() {
            *digit += 1;
            if *digit < extent {
                break;
            }
            *digit = 0;
        }
    }
}
14417
14418fn apply_activation_f64(inp: &[f64], out: &mut [f64], kind: Activation) {
14424 match kind {
14425 Activation::Neg => {
14426 for (o, &v) in out.iter_mut().zip(inp) {
14427 *o = -v;
14428 }
14429 }
14430 Activation::Exp => {
14431 for (o, &v) in out.iter_mut().zip(inp) {
14432 *o = v.exp();
14433 }
14434 }
14435 Activation::Log => {
14436 for (o, &v) in out.iter_mut().zip(inp) {
14437 *o = v.ln();
14438 }
14439 }
14440 Activation::Sqrt => {
14441 for (o, &v) in out.iter_mut().zip(inp) {
14442 *o = v.sqrt();
14443 }
14444 }
14445 Activation::Rsqrt => {
14446 for (o, &v) in out.iter_mut().zip(inp) {
14447 *o = 1.0 / v.sqrt();
14448 }
14449 }
14450 Activation::Abs => {
14451 for (o, &v) in out.iter_mut().zip(inp) {
14452 *o = v.abs();
14453 }
14454 }
14455 Activation::Tanh => {
14456 for (o, &v) in out.iter_mut().zip(inp) {
14457 *o = v.tanh();
14458 }
14459 }
14460 Activation::Sigmoid => {
14461 for (o, &v) in out.iter_mut().zip(inp) {
14462 *o = 1.0 / (1.0 + (-v).exp());
14463 }
14464 }
14465 Activation::Relu => {
14466 for (o, &v) in out.iter_mut().zip(inp) {
14467 *o = v.max(0.0);
14468 }
14469 }
14470 Activation::Round => {
14471 for (o, &v) in out.iter_mut().zip(inp) {
14472 *o = v.round_ties_even();
14473 }
14474 }
14475 Activation::Sin => {
14476 for (o, &v) in out.iter_mut().zip(inp) {
14477 *o = v.sin();
14478 }
14479 }
14480 Activation::Cos => {
14481 for (o, &v) in out.iter_mut().zip(inp) {
14482 *o = v.cos();
14483 }
14484 }
14485 Activation::Tan => {
14486 for (o, &v) in out.iter_mut().zip(inp) {
14487 *o = v.tan();
14488 }
14489 }
14490 Activation::Atan => {
14491 for (o, &v) in out.iter_mut().zip(inp) {
14492 *o = v.atan();
14493 }
14494 }
14495 Activation::Gelu | Activation::GeluApprox | Activation::Silu => {
14496 panic!(
14497 "apply_activation_f64: {kind:?} not yet implemented at f64. \
14498 Add when a workload needs it."
14499 );
14500 }
14501 }
14502}
14503
14504#[inline]
14505fn binary_op_f64(op: BinaryOp, a: f64, b: f64) -> f64 {
14506 match op {
14507 BinaryOp::Add => a + b,
14508 BinaryOp::Sub => a - b,
14509 BinaryOp::Mul => a * b,
14510 BinaryOp::Div => a / b,
14511 BinaryOp::Max => a.max(b),
14512 BinaryOp::Min => a.min(b),
14513 BinaryOp::Pow => a.powf(b),
14514 }
14515}
14516
/// Sums over the middle axis of a row-major `[outer, reduced, inner]` tensor.
///
/// `out` receives the row-major `[outer, inner]` result. Each sum starts from `0.0` and
/// adds the `reduced` terms in ascending index order.
fn reduce_sum_f64(inp: &[f64], out: &mut [f64], outer: usize, reduced: usize, inner: usize) {
    let slab = reduced * inner;
    for o in 0..outer {
        let base = o * slab;
        for n in 0..inner {
            // Explicit fold keeps the `+0.0` seed and the left-to-right summation order.
            out[o * inner + n] =
                (0..reduced).fold(0.0_f64, |acc, r| acc + inp[base + r * inner + n]);
        }
    }
}
14530
14531#[cfg(test)]
14532mod tests {
14533 use super::*;
14534 use rlx_ir::*;
14535
14536 #[test]
14542 fn narrow_rope_fuses_in_unfused_path() {
14543 let f = DType::F32;
14544 let mut g = Graph::new("nr_fuse");
14545 let qkv = g.input("qkv", Shape::new(&[16, 8, 192], f)); let cos = g.input("cos", Shape::new(&[16], f));
14548 let sin = g.input("sin", Shape::new(&[16], f));
14549 let q = g.narrow_(qkv, 2, 0, 64);
14551 let q_rope = g.rope(q, cos, sin, 16);
14552 g.set_outputs(vec![q_rope]);
14553
14554 let plan = rlx_opt::memory::plan_memory(&g);
14555 let arena = crate::arena::Arena::from_plan(plan);
14556 let sched = compile_thunks(&g, &arena);
14557
14558 let mut narrow_count = 0;
14559 let mut rope_with_stride: Option<u32> = None;
14560 for t in &sched.thunks {
14561 match t {
14562 Thunk::Narrow { .. } => narrow_count += 1,
14563 Thunk::Rope { src_row_stride, .. } => rope_with_stride = Some(*src_row_stride),
14564 _ => {}
14565 }
14566 }
14567 assert_eq!(
14570 narrow_count, 0,
14571 "Narrow→Rope fusion should leave zero Narrow thunks; saw {narrow_count}"
14572 );
14573 assert_eq!(
14574 rope_with_stride,
14575 Some(192),
14576 "Rope's src_row_stride should be 192 (parent qkv axis), saw {rope_with_stride:?}"
14577 );
14578 }
14579
14580 #[test]
14583 fn ssm_selective_scan_matches_reference() {
14584 use rlx_ir::Philox4x32;
14585 let bch = 1usize;
14586 let s = 4usize;
14587 let h = 3usize;
14588 let n = 2usize;
14589
14590 let mut rng = Philox4x32::new(13);
14591 let mut x = vec![0f32; bch * s * h];
14592 rng.fill_normal(&mut x);
14593 let mut delta = vec![0f32; bch * s * h];
14594 for v in delta.iter_mut() {
14596 *v = (rng.next_f32() - 0.5) * 0.1;
14597 }
14598 let mut a = vec![0f32; h * n];
14599 for v in a.iter_mut() {
14600 *v = -(rng.next_f32() * 0.5 + 0.1);
14601 } let mut b = vec![0f32; bch * s * n];
14603 rng.fill_normal(&mut b);
14604 let mut c = vec![0f32; bch * s * n];
14605 rng.fill_normal(&mut c);
14606
14607 let mut expected = vec![0f32; bch * s * h];
14609 for bi in 0..bch {
14610 let mut state = vec![0f32; h * n];
14611 for si in 0..s {
14612 for ci in 0..h {
14613 let d = delta[bi * s * h + si * h + ci];
14614 let xv = x[bi * s * h + si * h + ci];
14615 let mut acc = 0f32;
14616 for ni in 0..n {
14617 let da = (d * a[ci * n + ni]).exp();
14618 state[ci * n + ni] =
14619 da * state[ci * n + ni] + d * b[bi * s * n + si * n + ni] * xv;
14620 acc += c[bi * s * n + si * n + ni] * state[ci * n + ni];
14621 }
14622 expected[bi * s * h + si * h + ci] = acc;
14623 }
14624 }
14625 }
14626
14627 let f = DType::F32;
14629 let mut g = Graph::new("ssm");
14630 let xn = g.input("x", Shape::new(&[bch, s, h], f));
14631 let dn = g.input("delta", Shape::new(&[bch, s, h], f));
14632 let an = g.param("a", Shape::new(&[h, n], f));
14633 let bn = g.param("b", Shape::new(&[bch, s, n], f));
14634 let cn = g.param("c", Shape::new(&[bch, s, n], f));
14635 let yn = g.selective_scan(xn, dn, an, bn, cn, n, Shape::new(&[bch, s, h], f));
14636 g.set_outputs(vec![yn]);
14637
14638 let plan = rlx_opt::memory::plan_memory(&g);
14639 let mut arena = crate::arena::Arena::from_plan(plan);
14640 let sched = compile_thunks(&g, &arena);
14641
14642 let xn_off = arena.byte_offset(xn);
14643 let dn_off = arena.byte_offset(dn);
14644 let an_off = arena.byte_offset(an);
14645 let bn_off = arena.byte_offset(bn);
14646 let cn_off = arena.byte_offset(cn);
14647 let yn_off = arena.byte_offset(yn);
14648 let buf = arena.raw_buf_mut();
14649 unsafe {
14650 let copy = |dst: *mut f32, data: &[f32]| {
14651 for (i, &v) in data.iter().enumerate() {
14652 *dst.add(i) = v;
14653 }
14654 };
14655 copy(buf.as_mut_ptr().add(xn_off) as *mut f32, &x);
14656 copy(buf.as_mut_ptr().add(dn_off) as *mut f32, &delta);
14657 copy(buf.as_mut_ptr().add(an_off) as *mut f32, &a);
14658 copy(buf.as_mut_ptr().add(bn_off) as *mut f32, &b);
14659 copy(buf.as_mut_ptr().add(cn_off) as *mut f32, &c);
14660 }
14661 execute_thunks(&sched, arena.raw_buf_mut());
14662
14663 let actual: Vec<f32> = unsafe {
14664 let p = arena.raw_buf().as_ptr().add(yn_off) as *const f32;
14665 (0..bch * s * h).map(|i| *p.add(i)).collect()
14666 };
14667
14668 for (i, (e, a)) in expected.iter().zip(&actual).enumerate() {
14669 assert!(
14670 (e - a).abs() < 1e-3,
14671 "mismatch at {i}: expected {e}, got {a}"
14672 );
14673 }
14674 }
14675
14676 #[test]
14679 fn conv_1x1_fast_path_matches_scalar() {
14680 use rlx_ir::Philox4x32;
14681 let n = 2usize;
14683 let c_in = 4usize;
14684 let h = 3usize;
14685 let w = 3usize;
14686 let c_out = 5usize;
14687 let mut rng = Philox4x32::new(31);
14688 let mut x = vec![0f32; n * c_in * h * w];
14689 rng.fill_normal(&mut x);
14690 let mut weight = vec![0f32; c_out * c_in];
14691 rng.fill_normal(&mut weight);
14692
14693 let mut expected = vec![0f32; n * c_out * h * w];
14696 for ni in 0..n {
14697 for co in 0..c_out {
14698 for hi in 0..h {
14699 for wi in 0..w {
14700 let mut acc = 0f32;
14701 for ci in 0..c_in {
14702 acc += weight[co * c_in + ci]
14703 * x[((ni * c_in) + ci) * h * w + hi * w + wi];
14704 }
14705 expected[((ni * c_out) + co) * h * w + hi * w + wi] = acc;
14706 }
14707 }
14708 }
14709 }
14710
14711 let f = DType::F32;
14713 let mut g = Graph::new("conv1x1");
14714 let xn = g.input("x", Shape::new(&[n, c_in, h, w], f));
14715 let wn = g.param("w", Shape::new(&[c_out, c_in, 1, 1], f));
14716 let cn = g.add_node(
14718 rlx_ir::Op::Conv {
14719 kernel_size: vec![1, 1],
14720 stride: vec![1, 1],
14721 padding: vec![0, 0],
14722 dilation: vec![1, 1],
14723 groups: 1,
14724 },
14725 vec![xn, wn],
14726 Shape::new(&[n, c_out, h, w], f),
14727 );
14728 g.set_outputs(vec![cn]);
14729
14730 let plan = rlx_opt::memory::plan_memory(&g);
14731 let mut arena = crate::arena::Arena::from_plan(plan);
14732 let sched = compile_thunks(&g, &arena);
14733
14734 let saw_fast = sched
14736 .thunks
14737 .iter()
14738 .any(|t| matches!(t, Thunk::Conv2D1x1 { .. }));
14739 let saw_slow = sched
14740 .thunks
14741 .iter()
14742 .any(|t| matches!(t, Thunk::Conv2D { .. }));
14743 assert!(saw_fast, "1×1 conv should emit Conv2D1x1");
14744 assert!(!saw_slow, "1×1 conv must not fall through to scalar Conv2D");
14745
14746 let xn_off = arena.byte_offset(xn);
14747 let wn_off = arena.byte_offset(wn);
14748 let cn_off = arena.byte_offset(cn);
14749 let buf = arena.raw_buf_mut();
14750 unsafe {
14751 let xp = buf.as_mut_ptr().add(xn_off) as *mut f32;
14752 for (i, &v) in x.iter().enumerate() {
14753 *xp.add(i) = v;
14754 }
14755 let wp = buf.as_mut_ptr().add(wn_off) as *mut f32;
14756 for (i, &v) in weight.iter().enumerate() {
14757 *wp.add(i) = v;
14758 }
14759 }
14760 execute_thunks(&sched, arena.raw_buf_mut());
14761
14762 let actual: Vec<f32> = unsafe {
14763 let p = arena.raw_buf().as_ptr().add(cn_off) as *const f32;
14764 (0..(n * c_out * h * w)).map(|i| *p.add(i)).collect()
14765 };
14766
14767 for (i, (e, a)) in expected.iter().zip(&actual).enumerate() {
14768 assert!(
14769 (e - a).abs() < 1e-3,
14770 "mismatch at {i}: expected {e}, got {a}"
14771 );
14772 }
14773 }
14774
14775 #[test]
14778 fn dequant_matmul_int8_sym_matches_reference() {
14779 use rlx_ir::Philox4x32;
14780 use rlx_ir::quant::QuantScheme;
14781
14782 let m = 3usize;
14783 let k = 8usize;
14784 let n = 4usize;
14785 let block_size = 4usize; let blocks_per_col = k / block_size;
14787
14788 let mut rng = Philox4x32::new(99);
14790 let mut x = vec![0f32; m * k];
14791 rng.fill_normal(&mut x);
14792 let w_q: Vec<i8> = (0..(k * n))
14793 .map(|i| ((i as i32 * 13 + 7) % 127 - 63) as i8)
14794 .collect();
14795 let scales: Vec<f32> = (0..(blocks_per_col * n))
14796 .map(|i| 0.01 + 0.001 * i as f32)
14797 .collect();
14798
14799 let mut w_f32 = vec![0f32; k * n];
14801 for p in 0..k {
14802 let block = p / block_size;
14803 for j in 0..n {
14804 let s = scales[block * n + j];
14805 w_f32[p * n + j] = w_q[p * n + j] as f32 * s;
14806 }
14807 }
14808 let mut expected = vec![0f32; m * n];
14809 for i in 0..m {
14810 for j in 0..n {
14811 let mut acc = 0f32;
14812 for p in 0..k {
14813 acc += x[i * k + p] * w_f32[p * n + j];
14814 }
14815 expected[i * n + j] = acc;
14816 }
14817 }
14818
14819 let f = DType::F32;
14821 let mut g = Graph::new("dq");
14822 let xn = g.input("x", Shape::new(&[m, k], f));
14823 let wn = g.param("w", Shape::new(&[k, n], DType::I8));
14824 let sn = g.param("scale", Shape::new(&[blocks_per_col, n], f));
14825 let zn = g.param("zp", Shape::new(&[blocks_per_col, n], f)); let dq = g.dequant_matmul(
14827 xn,
14828 wn,
14829 sn,
14830 zn,
14831 QuantScheme::Int8Block {
14832 block_size: block_size as u32,
14833 },
14834 Shape::new(&[m, n], f),
14835 );
14836 g.set_outputs(vec![dq]);
14837
14838 let plan = rlx_opt::memory::plan_memory(&g);
14839 let mut arena = crate::arena::Arena::from_plan(plan);
14840 let sched = compile_thunks(&g, &arena);
14841
14842 let xn_off = arena.byte_offset(xn);
14843 let wn_off = arena.byte_offset(wn);
14844 let sn_off = arena.byte_offset(sn);
14845 let zn_off = arena.byte_offset(zn);
14846 let dq_off = arena.byte_offset(dq);
14847 let buf = arena.raw_buf_mut();
14848 unsafe {
14849 let xp = buf.as_mut_ptr().add(xn_off) as *mut f32;
14851 for (i, &v) in x.iter().enumerate() {
14852 *xp.add(i) = v;
14853 }
14854 let sp = buf.as_mut_ptr().add(sn_off) as *mut f32;
14855 for (i, &v) in scales.iter().enumerate() {
14856 *sp.add(i) = v;
14857 }
14858 let zp = buf.as_mut_ptr().add(zn_off) as *mut f32;
14859 for i in 0..(blocks_per_col * n) {
14860 *zp.add(i) = 0.0;
14861 }
14862 let wp = buf.as_mut_ptr().add(wn_off) as *mut i8;
14864 for (i, &v) in w_q.iter().enumerate() {
14865 *wp.add(i) = v;
14866 }
14867 }
14868 execute_thunks(&sched, arena.raw_buf_mut());
14869
14870 let actual: Vec<f32> = unsafe {
14871 let p = arena.raw_buf().as_ptr().add(dq_off) as *const f32;
14872 (0..m * n).map(|i| *p.add(i)).collect()
14873 };
14874
14875 for (i, (e, a)) in expected.iter().zip(&actual).enumerate() {
14876 assert!(
14877 (e - a).abs() < 1e-3,
14878 "mismatch at {i}: expected {e}, got {a}"
14879 );
14880 }
14881 }
14882
14883 #[test]
14885 fn lora_matmul_matches_unfused_reference() {
14886 use rlx_ir::Philox4x32;
14887
14888 let m = 4usize;
14889 let k = 8usize;
14890 let n = 6usize;
14891 let r = 2usize;
14892 let scale = 0.5f32;
14893
14894 let mut rng = Philox4x32::new(42);
14896 let mut x = vec![0f32; m * k];
14897 rng.fill_normal(&mut x);
14898 let mut w = vec![0f32; k * n];
14899 rng.fill_normal(&mut w);
14900 let mut a = vec![0f32; k * r];
14901 rng.fill_normal(&mut a);
14902 let mut b = vec![0f32; r * n];
14903 rng.fill_normal(&mut b);
14904
14905 let naive = |a_buf: &[f32], b_buf: &[f32], rows: usize, inner: usize, cols: usize| {
14907 let mut o = vec![0f32; rows * cols];
14908 for i in 0..rows {
14909 for j in 0..cols {
14910 let mut acc = 0f32;
14911 for p in 0..inner {
14912 acc += a_buf[i * inner + p] * b_buf[p * cols + j];
14913 }
14914 o[i * cols + j] = acc;
14915 }
14916 }
14917 o
14918 };
14919 let xw = naive(&x, &w, m, k, n);
14920 let xa = naive(&x, &a, m, k, r);
14921 let xab = naive(&xa, &b, m, r, n);
14922 let mut expected = xw;
14923 for i in 0..(m * n) {
14924 expected[i] += scale * xab[i];
14925 }
14926
14927 let f = DType::F32;
14929 let mut g = Graph::new("lora");
14930 let xn = g.input("x", Shape::new(&[m, k], f));
14931 let wn = g.param("w", Shape::new(&[k, n], f));
14932 let an = g.param("a", Shape::new(&[k, r], f));
14933 let bn = g.param("b", Shape::new(&[r, n], f));
14934 let lm = g.lora_matmul(xn, wn, an, bn, scale, Shape::new(&[m, n], f));
14935 g.set_outputs(vec![lm]);
14936
14937 let plan = rlx_opt::memory::plan_memory(&g);
14938 let mut arena = crate::arena::Arena::from_plan(plan);
14939 let sched = compile_thunks(&g, &arena);
14940
14941 let xn_off = arena.byte_offset(xn);
14942 let wn_off = arena.byte_offset(wn);
14943 let an_off = arena.byte_offset(an);
14944 let bn_off = arena.byte_offset(bn);
14945 let lm_off = arena.byte_offset(lm);
14946 let buf = arena.raw_buf_mut();
14947 unsafe {
14948 let copy = |dst: *mut f32, data: &[f32]| {
14949 for (i, &v) in data.iter().enumerate() {
14950 *dst.add(i) = v;
14951 }
14952 };
14953 copy(buf.as_mut_ptr().add(xn_off) as *mut f32, &x);
14954 copy(buf.as_mut_ptr().add(wn_off) as *mut f32, &w);
14955 copy(buf.as_mut_ptr().add(an_off) as *mut f32, &a);
14956 copy(buf.as_mut_ptr().add(bn_off) as *mut f32, &b);
14957 }
14958 execute_thunks(&sched, arena.raw_buf_mut());
14959
14960 let actual: Vec<f32> = unsafe {
14961 let p = arena.raw_buf().as_ptr().add(lm_off) as *const f32;
14962 (0..m * n).map(|i| *p.add(i)).collect()
14963 };
14964
14965 for (i, (e, a)) in expected.iter().zip(&actual).enumerate() {
14966 assert!(
14967 (e - a).abs() < 1e-3,
14968 "mismatch at {i}: expected {e}, got {a}"
14969 );
14970 }
14971 }
14972
14973 #[test]
14975 fn sample_temperature_zero_is_argmax() {
14976 let f = DType::F32;
14979 let mut g = Graph::new("samp");
14980 let logits = g.input("logits", Shape::new(&[1, 8], f));
14981 let s = g.sample(logits, 0, 1.0, 1e-3, 42, Shape::new(&[1], f));
14982 g.set_outputs(vec![s]);
14983 let plan = rlx_opt::memory::plan_memory(&g);
14984 let mut arena = crate::arena::Arena::from_plan(plan);
14985 let sched = compile_thunks(&g, &arena);
14986
14987 let logits_off = arena.byte_offset(logits);
14988 let s_off = arena.byte_offset(s);
14989 let buf = arena.raw_buf_mut();
14990 unsafe {
14991 let p = buf.as_mut_ptr().add(logits_off) as *mut f32;
14992 let inputs = [0.1f32, 0.2, 0.3, 0.4, 0.5, 9.0, 0.7, 0.8];
14994 for (i, &v) in inputs.iter().enumerate() {
14995 *p.add(i) = v;
14996 }
14997 }
14998 execute_thunks(&sched, arena.raw_buf_mut());
14999
15000 let token = unsafe {
15001 let p = arena.raw_buf().as_ptr().add(s_off) as *const f32;
15002 *p as usize
15003 };
15004 assert_eq!(token, 5, "low-temp sampling should pick the argmax");
15005 }
15006
15007 #[test]
15008 fn sample_top_k_one_is_deterministic() {
15009 let f = DType::F32;
15011 let mut g = Graph::new("samp_k1");
15012 let logits = g.input("logits", Shape::new(&[1, 4], f));
15013 let s = g.sample(logits, 1, 1.0, 1.0, 7, Shape::new(&[1], f));
15014 g.set_outputs(vec![s]);
15015 let plan = rlx_opt::memory::plan_memory(&g);
15016 let mut arena = crate::arena::Arena::from_plan(plan);
15017 let sched = compile_thunks(&g, &arena);
15018
15019 let logits_off = arena.byte_offset(logits);
15020 let s_off = arena.byte_offset(s);
15021 let buf = arena.raw_buf_mut();
15022 unsafe {
15023 let p = buf.as_mut_ptr().add(logits_off) as *mut f32;
15024 let inputs = [0.1f32, 5.0, 0.3, 0.4]; for (i, &v) in inputs.iter().enumerate() {
15026 *p.add(i) = v;
15027 }
15028 }
15029 execute_thunks(&sched, arena.raw_buf_mut());
15030 let token = unsafe {
15031 let p = arena.raw_buf().as_ptr().add(s_off) as *const f32;
15032 *p as usize
15033 };
15034 assert_eq!(token, 1);
15035 }
15036
15037 #[test]
15039 fn cumsum_inclusive_matches_naive() {
15040 let f = DType::F32;
15041 let mut g = Graph::new("cumsum");
15042 let x = g.input("x", Shape::new(&[2, 4], f));
15043 let cs = g.cumsum(x, -1, false, Shape::new(&[2, 4], f));
15044 g.set_outputs(vec![cs]);
15045 let plan = rlx_opt::memory::plan_memory(&g);
15046 let mut arena = crate::arena::Arena::from_plan(plan);
15047 let sched = compile_thunks(&g, &arena);
15048
15049 let x_off = arena.byte_offset(x);
15051 let out_off = arena.byte_offset(cs);
15052 let buf = arena.raw_buf_mut();
15053 unsafe {
15054 let p = buf.as_mut_ptr().add(x_off) as *mut f32;
15055 let inputs = [1.0f32, 2.0, 3.0, 4.0, 10.0, 20.0, 30.0, 40.0];
15056 for (i, &v) in inputs.iter().enumerate() {
15057 *p.add(i) = v;
15058 }
15059 }
15060 execute_thunks(&sched, arena.raw_buf_mut());
15061
15062 let out: Vec<f32> = unsafe {
15063 let p = arena.raw_buf().as_ptr().add(out_off) as *const f32;
15064 (0..8).map(|i| *p.add(i)).collect()
15065 };
15066 assert_eq!(out, vec![1.0, 3.0, 6.0, 10.0, 10.0, 30.0, 60.0, 100.0]);
15067 }
15068
15069 #[test]
15073 fn narrow_attention_fuses_in_unfused_path() {
15074 let f = DType::F32;
15075 let mut g = Graph::new("nattn_fuse");
15076 let qkv = g.input("qkv", Shape::new(&[8, 16, 192], f)); let mask = g.input("mask", Shape::new(&[8, 16], f));
15079 let q = g.narrow_(qkv, 2, 0, 64);
15080 let k = g.narrow_(qkv, 2, 64, 64);
15081 let v = g.narrow_(qkv, 2, 128, 64);
15082 let attn = g.attention(q, k, v, mask, 4, 16, Shape::new(&[8, 16, 64], f));
15083 g.set_outputs(vec![attn]);
15084
15085 let plan = rlx_opt::memory::plan_memory(&g);
15086 let arena = crate::arena::Arena::from_plan(plan);
15087 let sched = compile_thunks(&g, &arena);
15088
15089 let mut narrow_count = 0;
15090 let mut attn_strides: Option<(u32, u32, u32)> = None;
15091 for t in &sched.thunks {
15092 match t {
15093 Thunk::Narrow { .. } => narrow_count += 1,
15094 Thunk::Attention {
15095 q_row_stride,
15096 k_row_stride,
15097 v_row_stride,
15098 ..
15099 } => attn_strides = Some((*q_row_stride, *k_row_stride, *v_row_stride)),
15100 _ => {}
15101 }
15102 }
15103 assert_eq!(
15106 narrow_count, 0,
15107 "Narrow×3→Attention fusion should eliminate all 3 narrows; saw {narrow_count}"
15108 );
15109 assert_eq!(
15110 attn_strides,
15111 Some((192, 192, 192)),
15112 "Attention should walk Q/K/V with parent row stride 192"
15113 );
15114 }
15115
15116 fn run_graph(
15127 g: &Graph,
15128 inputs: &[(NodeId, &[f32])],
15129 out_id: NodeId,
15130 out_len: usize,
15131 ) -> Vec<f32> {
15132 let plan = rlx_opt::memory::plan_memory(g);
15133 let mut arena = crate::arena::Arena::from_plan(plan);
15134 let sched = compile_thunks(g, &arena);
15135 for &(id, data) in inputs {
15136 let off = arena.byte_offset(id);
15137 let buf = arena.raw_buf_mut();
15138 unsafe {
15139 let p = buf.as_mut_ptr().add(off) as *mut f32;
15140 for (i, &v) in data.iter().enumerate() {
15141 *p.add(i) = v;
15142 }
15143 }
15144 }
15145 execute_thunks(&sched, arena.raw_buf_mut());
15146 let off = arena.byte_offset(out_id);
15147 unsafe {
15148 let p = arena.raw_buf().as_ptr().add(off) as *const f32;
15149 (0..out_len).map(|i| *p.add(i)).collect()
15150 }
15151 }
15152
15153 #[test]
15154 fn relu_backward_matches_mask() {
15155 let f = DType::F32;
15156 let len = 7usize;
15157 let x: Vec<f32> = vec![-2.0, -0.1, 0.0, 0.1, 1.0, 3.0, -5.0];
15158 let dy: Vec<f32> = vec![0.5, 1.5, 2.5, -0.7, 4.0, -1.0, 9.0];
15159
15160 let mut g = Graph::new("relu_bw");
15161 let xn = g.input("x", Shape::new(&[len], f));
15162 let dyn_ = g.input("dy", Shape::new(&[len], f));
15163 let dx = g.relu_backward(xn, dyn_);
15164 g.set_outputs(vec![dx]);
15165
15166 let actual = run_graph(&g, &[(xn, &x), (dyn_, &dy)], dx, len);
15167 let expected: Vec<f32> = x
15171 .iter()
15172 .zip(&dy)
15173 .map(|(&xi, &dyi)| if xi > 0.0 { dyi } else { 0.0 })
15174 .collect();
15175 for (a, e) in actual.iter().zip(&expected) {
15176 assert!((a - e).abs() < 1e-6, "relu_bw mismatch: {a} vs {e}");
15177 }
15178 }
15179
15180 #[test]
15181 fn maxpool2d_backward_routes_to_argmax() {
15182 let f = DType::F32;
15183 let x: Vec<f32> = vec![
15185 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0, 16.0,
15186 ];
15187 let dy: Vec<f32> = vec![0.5, 1.0, 2.0, 4.0];
15191
15192 let mut g = Graph::new("maxpool_bw");
15193 let xn = g.input("x", Shape::new(&[1, 1, 4, 4], f));
15194 let dyn_ = g.input("dy", Shape::new(&[1, 1, 2, 2], f));
15195 let dx = g.maxpool2d_backward(xn, dyn_, vec![2, 2], vec![2, 2], vec![0, 0]);
15196 g.set_outputs(vec![dx]);
15197
15198 let actual = run_graph(&g, &[(xn, &x), (dyn_, &dy)], dx, 16);
15199 let mut expected = vec![0f32; 16];
15200 expected[5] = 0.5;
15201 expected[7] = 1.0;
15202 expected[13] = 2.0;
15203 expected[15] = 4.0;
15204 for (i, (a, e)) in actual.iter().zip(&expected).enumerate() {
15205 assert!((a - e).abs() < 1e-6, "maxpool_bw[{i}] mismatch: {a} vs {e}");
15206 }
15207 }
15208
15209 #[test]
15210 fn conv2d_backward_input_matches_numerical_gradient() {
15211 use rlx_ir::Philox4x32;
15212 let n = 1usize;
15215 let c_in = 2usize;
15216 let h = 4usize;
15217 let w = 4usize;
15218 let c_out = 3usize;
15219 let kh = 3usize;
15220 let kw = 3usize;
15221 let ph = 1usize;
15222 let pw = 1usize;
15223 let sh = 1usize;
15224 let sw = 1usize;
15225 let h_out = (h + 2 * ph - kh) / sh + 1;
15227 let w_out = (w + 2 * pw - kw) / sw + 1;
15228 assert_eq!(h_out, 4);
15229 assert_eq!(w_out, 4);
15230
15231 let mut rng = Philox4x32::new(7);
15232 let mut x = vec![0f32; n * c_in * h * w];
15233 rng.fill_normal(&mut x);
15234 let mut wt = vec![0f32; c_out * c_in * kh * kw];
15235 rng.fill_normal(&mut wt);
15236 let mut dy = vec![0f32; n * c_out * h_out * w_out];
15237 rng.fill_normal(&mut dy);
15238
15239 let f = DType::F32;
15241 let mut g = Graph::new("conv_bwi");
15242 let dy_in = g.input("dy", Shape::new(&[n, c_out, h_out, w_out], f));
15243 let w_in = g.input("w", Shape::new(&[c_out, c_in, kh, kw], f));
15244 let dx = g.conv2d_backward_input(
15245 dy_in,
15246 w_in,
15247 Shape::new(&[n, c_in, h, w], f),
15248 vec![kh, kw],
15249 vec![sh, sw],
15250 vec![ph, pw],
15251 vec![1, 1],
15252 1,
15253 );
15254 g.set_outputs(vec![dx]);
15255 let analytical = run_graph(&g, &[(dy_in, &dy), (w_in, &wt)], dx, n * c_in * h * w);
15256
15257 let forward = |x: &[f32]| -> Vec<f32> {
15261 let mut out = vec![0f32; n * c_out * h_out * w_out];
15262 for ni in 0..n {
15263 for co in 0..c_out {
15264 for ho in 0..h_out {
15265 for wo in 0..w_out {
15266 let mut acc = 0f32;
15267 for ci in 0..c_in {
15268 for ki in 0..kh {
15269 for kj in 0..kw {
15270 let hi = ho * sh + ki;
15271 let wi = wo * sw + kj;
15272 if hi < ph || wi < pw {
15273 continue;
15274 }
15275 let hi = hi - ph;
15276 let wi = wi - pw;
15277 if hi >= h || wi >= w {
15278 continue;
15279 }
15280 let xv = x[((ni * c_in) + ci) * h * w + hi * w + wi];
15281 let wv = wt[((co * c_in) + ci) * kh * kw + ki * kw + kj];
15282 acc += xv * wv;
15283 }
15284 }
15285 }
15286 out[((ni * c_out) + co) * h_out * w_out + ho * w_out + wo] = acc;
15287 }
15288 }
15289 }
15290 }
15291 out
15292 };
15293 let dot = |a: &[f32], b: &[f32]| -> f32 { a.iter().zip(b).map(|(&u, &v)| u * v).sum() };
15294 let eps = 1e-3f32;
15295 let mut numerical = vec![0f32; x.len()];
15296 for i in 0..x.len() {
15297 let saved = x[i];
15298 x[i] = saved + eps;
15299 let plus = dot(&forward(&x), &dy);
15300 x[i] = saved - eps;
15301 let minus = dot(&forward(&x), &dy);
15302 x[i] = saved;
15303 numerical[i] = (plus - minus) / (2.0 * eps);
15304 }
15305 for (i, (a, n)) in analytical.iter().zip(&numerical).enumerate() {
15306 assert!(
15308 (a - n).abs() < 5e-3,
15309 "conv_bw_input[{i}]: analytical {a} vs numerical {n}"
15310 );
15311 }
15312 }
15313
15314 #[test]
15315 fn conv2d_backward_weight_matches_numerical_gradient() {
15316 use rlx_ir::Philox4x32;
15317 let n = 2usize;
15318 let c_in = 2usize;
15319 let h = 4usize;
15320 let w = 4usize;
15321 let c_out = 2usize;
15322 let kh = 3usize;
15323 let kw = 3usize;
15324 let ph = 0usize;
15325 let pw = 0usize;
15326 let sh = 1usize;
15327 let sw = 1usize;
15328 let h_out = (h + 2 * ph - kh) / sh + 1;
15329 let w_out = (w + 2 * pw - kw) / sw + 1;
15330
15331 let mut rng = Philox4x32::new(11);
15332 let mut x = vec![0f32; n * c_in * h * w];
15333 rng.fill_normal(&mut x);
15334 let mut wt = vec![0f32; c_out * c_in * kh * kw];
15335 rng.fill_normal(&mut wt);
15336 let mut dy = vec![0f32; n * c_out * h_out * w_out];
15337 rng.fill_normal(&mut dy);
15338
15339 let f = DType::F32;
15340 let mut g = Graph::new("conv_bww");
15341 let xn = g.input("x", Shape::new(&[n, c_in, h, w], f));
15342 let dyn_ = g.input("dy", Shape::new(&[n, c_out, h_out, w_out], f));
15343 let dwn = g.conv2d_backward_weight(
15344 xn,
15345 dyn_,
15346 Shape::new(&[c_out, c_in, kh, kw], f),
15347 vec![kh, kw],
15348 vec![sh, sw],
15349 vec![ph, pw],
15350 vec![1, 1],
15351 1,
15352 );
15353 g.set_outputs(vec![dwn]);
15354 let analytical = run_graph(&g, &[(xn, &x), (dyn_, &dy)], dwn, c_out * c_in * kh * kw);
15355
15356 let forward = |wt: &[f32]| -> Vec<f32> {
15357 let mut out = vec![0f32; n * c_out * h_out * w_out];
15358 for ni in 0..n {
15359 for co in 0..c_out {
15360 for ho in 0..h_out {
15361 for wo in 0..w_out {
15362 let mut acc = 0f32;
15363 for ci in 0..c_in {
15364 for ki in 0..kh {
15365 for kj in 0..kw {
15366 let hi = ho + ki;
15367 let wi = wo + kj;
15368 let xv = x[((ni * c_in) + ci) * h * w + hi * w + wi];
15369 let wv = wt[((co * c_in) + ci) * kh * kw + ki * kw + kj];
15370 acc += xv * wv;
15371 }
15372 }
15373 }
15374 out[((ni * c_out) + co) * h_out * w_out + ho * w_out + wo] = acc;
15375 }
15376 }
15377 }
15378 }
15379 out
15380 };
15381 let dot = |a: &[f32], b: &[f32]| -> f32 { a.iter().zip(b).map(|(&u, &v)| u * v).sum() };
15382 let eps = 1e-3f32;
15383 let mut numerical = vec![0f32; wt.len()];
15384 for i in 0..wt.len() {
15385 let saved = wt[i];
15386 wt[i] = saved + eps;
15387 let plus = dot(&forward(&wt), &dy);
15388 wt[i] = saved - eps;
15389 let minus = dot(&forward(&wt), &dy);
15390 wt[i] = saved;
15391 numerical[i] = (plus - minus) / (2.0 * eps);
15392 }
15393 for (i, (a, n)) in analytical.iter().zip(&numerical).enumerate() {
15394 assert!(
15395 (a - n).abs() < 5e-3,
15396 "conv_bw_weight[{i}]: analytical {a} vs numerical {n}"
15397 );
15398 }
15399 }
15400
15401 #[test]
15402 fn softmax_cross_entropy_matches_reference() {
15403 let f = DType::F32;
15404 let logits: Vec<f32> = vec![
15405 1.0, 2.0, 3.0, -1.0, 0.0, 4.0, 5.0, 5.0, 5.0, ];
15409 let labels: Vec<f32> = vec![2.0, 0.0, 1.0];
15410
15411 let mut g = Graph::new("sce");
15412 let lg = g.input("logits", Shape::new(&[3, 3], f));
15413 let lb = g.input("labels", Shape::new(&[3], f));
15414 let loss = g.softmax_cross_entropy_with_logits(lg, lb);
15415 g.set_outputs(vec![loss]);
15416 let actual = run_graph(&g, &[(lg, &logits), (lb, &labels)], loss, 3);
15417
15418 let mut expected = vec![0f32; 3];
15420 for ni in 0..3 {
15421 let row = &logits[ni * 3..(ni + 1) * 3];
15422 let m = row.iter().fold(f32::NEG_INFINITY, |a, &v| a.max(v));
15423 let sum: f32 = row.iter().map(|&v| (v - m).exp()).sum();
15424 let lse = m + sum.ln();
15425 let label_idx = labels[ni] as usize;
15426 expected[ni] = lse - row[label_idx];
15427 }
15428 for (i, (a, e)) in actual.iter().zip(&expected).enumerate() {
15429 assert!((a - e).abs() < 1e-5, "sce loss[{i}]: {a} vs {e}");
15430 }
15431 }
15432
15433 #[test]
15434 fn softmax_cross_entropy_backward_matches_numerical_gradient() {
15435 use rlx_ir::Philox4x32;
15436 let n = 4usize;
15437 let c = 5usize;
15438 let mut rng = Philox4x32::new(23);
15439 let mut logits = vec![0f32; n * c];
15440 rng.fill_normal(&mut logits);
15441 let labels: Vec<f32> = (0..n).map(|i| (i % c) as f32).collect();
15442 let mut d_loss = vec![0f32; n];
15443 rng.fill_normal(&mut d_loss);
15444
15445 let f = DType::F32;
15446 let mut g = Graph::new("sce_bw");
15447 let lg = g.input("logits", Shape::new(&[n, c], f));
15448 let lb = g.input("labels", Shape::new(&[n], f));
15449 let dl = g.input("d_loss", Shape::new(&[n], f));
15450 let dlogits = g.softmax_cross_entropy_backward(lg, lb, dl);
15451 g.set_outputs(vec![dlogits]);
15452 let analytical = run_graph(
15453 &g,
15454 &[(lg, &logits), (lb, &labels), (dl, &d_loss)],
15455 dlogits,
15456 n * c,
15457 );
15458
15459 let sce_loss = |logits: &[f32]| -> Vec<f32> {
15461 let mut out = vec![0f32; n];
15462 for ni in 0..n {
15463 let row = &logits[ni * c..(ni + 1) * c];
15464 let m = row.iter().fold(f32::NEG_INFINITY, |a, &v| a.max(v));
15465 let sum: f32 = row.iter().map(|&v| (v - m).exp()).sum();
15466 out[ni] = (m + sum.ln()) - row[labels[ni] as usize];
15467 }
15468 out
15469 };
15470 let dot = |a: &[f32], b: &[f32]| a.iter().zip(b).map(|(&u, &v)| u * v).sum::<f32>();
15471 let eps = 1e-3f32;
15472 let mut numerical = vec![0f32; logits.len()];
15473 for i in 0..logits.len() {
15474 let saved = logits[i];
15475 logits[i] = saved + eps;
15476 let plus = dot(&sce_loss(&logits), &d_loss);
15477 logits[i] = saved - eps;
15478 let minus = dot(&sce_loss(&logits), &d_loss);
15479 logits[i] = saved;
15480 numerical[i] = (plus - minus) / (2.0 * eps);
15481 }
15482 for (i, (a, num)) in analytical.iter().zip(&numerical).enumerate() {
15483 assert!(
15484 (a - num).abs() < 5e-3,
15485 "sce_bw[{i}]: analytical {a} vs numerical {num}"
15486 );
15487 }
15488 }
15489
15490 fn fill_constants_into_arena(graph: &Graph, arena: &mut crate::arena::Arena) {
15503 for node in graph.nodes() {
15504 if let Op::Constant { data } = &node.op
15505 && arena.has_buffer(node.id)
15506 && !data.is_empty()
15507 {
15508 let buf = arena.slice_mut(node.id);
15509 let n_floats = data.len() / 4;
15510 let n = buf.len().min(n_floats);
15511 for i in 0..n {
15512 let bytes = [
15513 data[i * 4],
15514 data[i * 4 + 1],
15515 data[i * 4 + 2],
15516 data[i * 4 + 3],
15517 ];
15518 buf[i] = f32::from_le_bytes(bytes);
15519 }
15520 }
15521 }
15522 }
15523
15524 fn prepare(
15528 graph: &Graph,
15529 seed_inputs: &[(NodeId, &[f32])],
15530 ) -> (ThunkSchedule, crate::arena::Arena) {
15531 let plan = rlx_opt::memory::plan_memory(graph);
15532 let mut arena = crate::arena::Arena::from_plan(plan);
15533 let sched = compile_thunks(graph, &arena);
15534 fill_constants_into_arena(graph, &mut arena);
15535 for &(id, data) in seed_inputs {
15536 let off = arena.byte_offset(id);
15537 let buf = arena.raw_buf_mut();
15538 unsafe {
15539 let p = buf.as_mut_ptr().add(off) as *mut f32;
15540 for (i, &v) in data.iter().enumerate() {
15541 *p.add(i) = v;
15542 }
15543 }
15544 }
15545 (sched, arena)
15546 }
15547
15548 fn read_arena(arena: &crate::arena::Arena, id: NodeId, len: usize) -> Vec<f32> {
15549 let off = arena.byte_offset(id);
15550 unsafe {
15551 let p = arena.raw_buf().as_ptr().add(off) as *const f32;
15552 (0..len).map(|i| *p.add(i)).collect()
15553 }
15554 }
15555
15556 fn write_arena(arena: &mut crate::arena::Arena, id: NodeId, data: &[f32]) {
15557 let off = arena.byte_offset(id);
15558 let buf = arena.raw_buf_mut();
15559 unsafe {
15560 let p = buf.as_mut_ptr().add(off) as *mut f32;
15561 for (i, &v) in data.iter().enumerate() {
15562 *p.add(i) = v;
15563 }
15564 }
15565 }
15566
15567 fn prepare_f64(
15569 graph: &Graph,
15570 seed_inputs: &[(NodeId, &[f64])],
15571 ) -> (ThunkSchedule, crate::arena::Arena) {
15572 let plan = rlx_opt::memory::plan_memory(graph);
15573 let mut arena = crate::arena::Arena::from_plan(plan);
15574 let sched = compile_thunks(graph, &arena);
15575 fill_constants_into_arena(graph, &mut arena);
15576 for &(id, data) in seed_inputs {
15577 let off = arena.byte_offset(id);
15578 let buf = arena.raw_buf_mut();
15579 unsafe {
15580 let p = buf.as_mut_ptr().add(off) as *mut f64;
15581 for (i, &v) in data.iter().enumerate() {
15582 *p.add(i) = v;
15583 }
15584 }
15585 }
15586 (sched, arena)
15587 }
15588
15589 fn read_arena_f64(arena: &crate::arena::Arena, id: NodeId, len: usize) -> Vec<f64> {
15590 let off = arena.byte_offset(id);
15591 unsafe {
15592 let p = arena.raw_buf().as_ptr().add(off) as *const f64;
15593 (0..len).map(|i| *p.add(i)).collect()
15594 }
15595 }
15596
15597 #[test]
15607 fn dense_solve_f64_end_to_end() {
15608 let mut g = Graph::new("solve_e2e");
15609 let a = g.input("A", Shape::new(&[2, 2], DType::F64));
15610 let b = g.input("b", Shape::new(&[2], DType::F64));
15611 let x = g.dense_solve(a, b, Shape::new(&[2], DType::F64));
15612 g.set_outputs(vec![x]);
15613
15614 let a_data = [2.0, 1.0, 1.0, 3.0_f64];
15615 let b_data = [5.0, 10.0_f64];
15616 let (sched, mut arena) = prepare_f64(&g, &[(a, &a_data), (b, &b_data)]);
15617 execute_thunks(&sched, arena.raw_buf_mut());
15618
15619 let got = read_arena_f64(&arena, x, 2);
15620 let want = [1.0, 3.0_f64];
15621 for i in 0..2 {
15622 assert!(
15623 (got[i] - want[i]).abs() < 1e-12,
15624 "x[{i}] = {} (expected {})",
15625 got[i],
15626 want[i]
15627 );
15628 }
15629 }
15630
15631 #[test]
15637 fn dense_solve_f64_5x5_laplacian() {
15638 let n = 5usize;
15639 let mut g = Graph::new("solve_5x5");
15640 let a = g.input("A", Shape::new(&[n, n], DType::F64));
15641 let b = g.input("b", Shape::new(&[n], DType::F64));
15642 let x = g.dense_solve(a, b, Shape::new(&[n], DType::F64));
15643 g.set_outputs(vec![x]);
15644
15645 let mut a_data = vec![0.0_f64; n * n];
15647 for i in 0..n {
15648 a_data[i * n + i] = 2.0;
15649 if i > 0 {
15650 a_data[i * n + (i - 1)] = -1.0;
15651 }
15652 if i + 1 < n {
15653 a_data[i * n + (i + 1)] = -1.0;
15654 }
15655 }
15656 let b_data: Vec<f64> = (0..n).map(|i| (i + 1) as f64).collect();
15657 let (sched, mut arena) = prepare_f64(&g, &[(a, &a_data), (b, &b_data)]);
15658 execute_thunks(&sched, arena.raw_buf_mut());
15659
15660 let got = read_arena_f64(&arena, x, n);
15661 let mut residual = vec![0.0_f64; n];
15663 for i in 0..n {
15664 for j in 0..n {
15665 residual[i] += a_data[i * n + j] * got[j];
15666 }
15667 }
15668 for i in 0..n {
15669 assert!(
15670 (residual[i] - b_data[i]).abs() < 1e-10,
15671 "row {i}: residual {} vs b {}",
15672 residual[i],
15673 b_data[i]
15674 );
15675 }
15676 }
15677
15678 #[test]
15697 fn hello_resistor_gradient_end_to_end() {
15698 use rlx_opt::autodiff::grad_with_loss;
15699 let n = 3usize;
15700
15701 let mut g = Graph::new("hello_resistor");
15703 let a = g.param("A", Shape::new(&[n, n], DType::F64));
15704 let b = g.input("b", Shape::new(&[n], DType::F64));
15705 let x = g.dense_solve(a, b, Shape::new(&[n], DType::F64));
15706 let loss = g.reduce(
15707 x,
15708 ReduceOp::Sum,
15709 vec![0],
15710 false,
15711 Shape::new(&[1], DType::F64),
15712 );
15713 g.set_outputs(vec![loss]);
15714
15715 let bwd = grad_with_loss(&g, &[a, b]);
15717 assert_eq!(bwd.outputs.len(), 3, "expect [loss, dA, db]");
15718
15719 let find_by_name = |graph: &Graph, want: &str| -> NodeId {
15723 for node in graph.nodes() {
15724 let name = match &node.op {
15725 rlx_ir::Op::Input { name } => Some(name.as_str()),
15726 rlx_ir::Op::Param { name } => Some(name.as_str()),
15727 _ => None,
15728 };
15729 if name == Some(want) {
15730 return node.id;
15731 }
15732 }
15733 panic!("no node named {want:?} in bwd graph");
15734 };
15735 let a_bwd = find_by_name(&bwd, "A");
15736 let b_bwd = find_by_name(&bwd, "b");
15737 let d_out_bwd = find_by_name(&bwd, "d_output");
15738
15739 let a_data = [2.0, 1.0, 0.0, 1.0, 3.0, 1.0, 0.0, 1.0, 2.0_f64];
15743 let b_data = [1.0, 2.0, 3.0_f64];
15744 let d_output = [1.0_f64]; let (sched, mut arena) = prepare_f64(
15748 &bwd,
15749 &[(a_bwd, &a_data), (b_bwd, &b_data), (d_out_bwd, &d_output)],
15750 );
15751 execute_thunks(&sched, arena.raw_buf_mut());
15752
15753 let loss_out = read_arena_f64(&arena, bwd.outputs[0], 1);
15754 let da_out = read_arena_f64(&arena, bwd.outputs[1], n * n);
15755 let db_out = read_arena_f64(&arena, bwd.outputs[2], n);
15756
15757 let x_ref = {
15760 let mut a = a_data;
15761 let mut b = b_data;
15762 let info = crate::blas::dgesv(&mut a, &mut b, n, 1);
15763 assert_eq!(info, 0);
15764 b
15765 };
15766 let loss_ref: f64 = x_ref.iter().sum();
15767 let db_ref = {
15769 let mut at = [0.0_f64; 9];
15770 for i in 0..n {
15771 for j in 0..n {
15772 at[i * n + j] = a_data[j * n + i];
15773 }
15774 }
15775 let mut ones = [1.0_f64; 3];
15776 let info = crate::blas::dgesv(&mut at, &mut ones, n, 1);
15777 assert_eq!(info, 0);
15778 ones
15779 };
15780 let mut da_ref = [0.0_f64; 9];
15782 for i in 0..n {
15783 for j in 0..n {
15784 da_ref[i * n + j] = -db_ref[i] * x_ref[j];
15785 }
15786 }
15787
15788 assert!(
15790 (loss_out[0] - loss_ref).abs() < 1e-10,
15791 "loss: got {}, want {}",
15792 loss_out[0],
15793 loss_ref
15794 );
15795 for i in 0..n {
15796 assert!(
15797 (db_out[i] - db_ref[i]).abs() < 1e-10,
15798 "db[{i}]: got {}, want {}",
15799 db_out[i],
15800 db_ref[i]
15801 );
15802 }
15803 for i in 0..n * n {
15804 assert!(
15805 (da_out[i] - da_ref[i]).abs() < 1e-10,
15806 "dA[{i}]: got {}, want {}",
15807 da_out[i],
15808 da_ref[i]
15809 );
15810 }
15811
15812 let h = 1e-6_f64;
15815 for k in 0..n {
15816 let mut bp = b_data;
15817 bp[k] += h;
15818 let mut bm = b_data;
15819 bm[k] -= h;
15820 let lp = {
15821 let mut ac = a_data;
15822 let info = crate::blas::dgesv(&mut ac, &mut bp, n, 1);
15823 assert_eq!(info, 0);
15824 bp.iter().sum::<f64>()
15825 };
15826 let lm = {
15827 let mut ac = a_data;
15828 let info = crate::blas::dgesv(&mut ac, &mut bm, n, 1);
15829 assert_eq!(info, 0);
15830 bm.iter().sum::<f64>()
15831 };
15832 let fd = (lp - lm) / (2.0 * h);
15833 assert!(
15834 (db_out[k] - fd).abs() < 1e-7,
15835 "FD mismatch on db[{k}]: AD={} FD={}",
15836 db_out[k],
15837 fd
15838 );
15839 }
15840 }
15841
15842 #[test]
15847 fn scan_geometric_growth_f64() {
15848 let n = 3usize;
15849 let length = 10u32;
15850
15851 let mut body = Graph::new("scan_body");
15853 let x = body.input("carry", Shape::new(&[n], DType::F64));
15854 let scale_bytes: Vec<u8> = (0..n).flat_map(|_| 0.1_f64.to_le_bytes()).collect();
15855 let scale = body.add_node(
15856 Op::Constant { data: scale_bytes },
15857 vec![],
15858 Shape::new(&[n], DType::F64),
15859 );
15860 let scaled = body.binary(BinaryOp::Mul, x, scale, Shape::new(&[n], DType::F64));
15861 let next = body.binary(BinaryOp::Add, x, scaled, Shape::new(&[n], DType::F64));
15862 body.set_outputs(vec![next]);
15863
15864 let mut g = Graph::new("scan_outer");
15866 let init = g.input("init", Shape::new(&[n], DType::F64));
15867 let final_carry = g.scan(init, body, length);
15868 g.set_outputs(vec![final_carry]);
15869
15870 let init_data = vec![1.0_f64; n];
15871 let (sched, mut arena) = prepare_f64(&g, &[(init, &init_data)]);
15872 execute_thunks(&sched, arena.raw_buf_mut());
15873 let got = read_arena_f64(&arena, final_carry, n);
15874 let want: f64 = 1.1_f64.powi(length as i32);
15875 for i in 0..n {
15876 assert!(
15877 (got[i] - want).abs() < 1e-12,
15878 "got[{i}] = {} want {}",
15879 got[i],
15880 want
15881 );
15882 }
15883 }
15884
15885 #[test]
15892 fn scan_with_xs_cumulative_sum() {
15893 let n = 3usize;
15894 let length = 4u32;
15895
15896 let mut body = Graph::new("cumsum_body");
15897 let carry = body.input("carry", Shape::new(&[n], DType::F64));
15899 let x_t = body.input("x_t", Shape::new(&[n], DType::F64));
15900 let next = body.binary(BinaryOp::Add, carry, x_t, Shape::new(&[n], DType::F64));
15901 body.set_outputs(vec![next]);
15902
15903 let mut g = Graph::new("cumsum_outer");
15904 let init = g.input("init", Shape::new(&[n], DType::F64));
15905 let xs = g.input("xs", Shape::new(&[length as usize, n], DType::F64));
15906 let final_carry = g.scan_with_xs(init, &[xs], body, length);
15907 g.set_outputs(vec![final_carry]);
15908
15909 let init_data = vec![0.0_f64; n];
15910 let xs_data: Vec<f64> = (0..length as usize * n).map(|i| (i + 1) as f64).collect(); let (sched, mut arena) = prepare_f64(&g, &[(init, &init_data), (xs, &xs_data)]);
15912 execute_thunks(&sched, arena.raw_buf_mut());
15913 let got = read_arena_f64(&arena, final_carry, n);
15914
15915 let mut want = init_data.clone();
15919 for t in 0..length as usize {
15920 for j in 0..n {
15921 want[j] += xs_data[t * n + j];
15922 }
15923 }
15924 for i in 0..n {
15925 assert!(
15926 (got[i] - want[i]).abs() < 1e-12,
15927 "got[{i}] = {} want {}",
15928 got[i],
15929 want[i]
15930 );
15931 }
15932 }
15933
15934 #[test]
15938 fn scan_with_xs_be_with_drive() {
15939 let n = 3usize;
15940 let length = 4u32;
15941 let dt = 0.1_f64;
15942
15943 let mut m_data = vec![0.0_f64; n * n];
15944 for i in 0..n {
15945 m_data[i * n + i] = 1.0 + dt * 2.0;
15946 if i > 0 {
15947 m_data[i * n + (i - 1)] = -dt;
15948 }
15949 if i + 1 < n {
15950 m_data[i * n + (i + 1)] = -dt;
15951 }
15952 }
15953 let m_bytes: Vec<u8> = m_data.iter().flat_map(|x| x.to_le_bytes()).collect();
15954
15955 let mut body = Graph::new("be_drive_body");
15956 let carry = body.input("carry", Shape::new(&[n], DType::F64));
15957 let drive = body.input("drive", Shape::new(&[n], DType::F64));
15958 let m = body.add_node(
15959 Op::Constant { data: m_bytes },
15960 vec![],
15961 Shape::new(&[n, n], DType::F64),
15962 );
15963 let driven = body.binary(BinaryOp::Add, carry, drive, Shape::new(&[n], DType::F64));
15964 let next = body.dense_solve(m, driven, Shape::new(&[n], DType::F64));
15965 body.set_outputs(vec![next]);
15966
15967 let mut g = Graph::new("be_drive_outer");
15968 let init = g.input("init", Shape::new(&[n], DType::F64));
15969 let xs = g.input("xs", Shape::new(&[length as usize, n], DType::F64));
15970 let final_carry = g.scan_with_xs(init, &[xs], body, length);
15971 g.set_outputs(vec![final_carry]);
15972
15973 let init_data = vec![0.0_f64; n];
15974 let mut xs_data = vec![0.0_f64; length as usize * n];
15977 xs_data[0] = 1.0;
15978
15979 let (sched, mut arena) = prepare_f64(&g, &[(init, &init_data), (xs, &xs_data)]);
15980 execute_thunks(&sched, arena.raw_buf_mut());
15981 let got = read_arena_f64(&arena, final_carry, n);
15982
15983 let mut x = init_data.clone();
15985 for t in 0..length as usize {
15986 for j in 0..n {
15987 x[j] += xs_data[t * n + j];
15988 }
15989 let mut a_copy = m_data.clone();
15990 crate::blas::dgesv(&mut a_copy, &mut x, n, 1);
15991 }
15992 for i in 0..n {
15993 assert!(
15994 (got[i] - x[i]).abs() < 1e-12,
15995 "got[{i}] = {} ref {}",
15996 got[i],
15997 x[i]
15998 );
15999 }
16000 }
16001
16002 #[test]
16008 fn batched_dense_solve_gradient_matches_per_batch_analytic() {
16009 use rlx_opt::autodiff::grad_with_loss;
16010 let n = 3usize;
16011 let batch = 4usize;
16012
16013 let mut g = Graph::new("bds_grad");
16014 let a = g.param("A", Shape::new(&[batch, n, n], DType::F64));
16015 let b = g.input("b", Shape::new(&[batch, n], DType::F64));
16016 let x = g.batched_dense_solve(a, b, Shape::new(&[batch, n], DType::F64));
16017 let loss = g.reduce(
16018 x,
16019 ReduceOp::Sum,
16020 vec![0, 1],
16021 false,
16022 Shape::new(&[1], DType::F64),
16023 );
16024 g.set_outputs(vec![loss]);
16025
16026 let bwd = grad_with_loss(&g, &[a, b]);
16027
16028 let find = |graph: &Graph, want: &str| -> NodeId {
16029 for node in graph.nodes() {
16030 let name = match &node.op {
16031 Op::Input { name } | Op::Param { name } => Some(name.as_str()),
16032 _ => None,
16033 };
16034 if name == Some(want) {
16035 return node.id;
16036 }
16037 }
16038 panic!("no node named {want}");
16039 };
16040 let a_id = find(&bwd, "A");
16041 let b_id = find(&bwd, "b");
16042 let d_out_id = find(&bwd, "d_output");
16043
16044 let mut rng = rlx_ir::Philox4x32::new(0x57e1_u64);
16045 let mut a_data = vec![0.0_f64; batch * n * n];
16046 let mut b_data = vec![0.0_f64; batch * n];
16047 for bi in 0..batch {
16048 for i in 0..n {
16049 for j in 0..n {
16050 a_data[bi * n * n + i * n + j] = rng.next_f32() as f64 * 0.1;
16051 }
16052 a_data[bi * n * n + i * n + i] += 1.0 + n as f64;
16053 }
16054 for i in 0..n {
16055 b_data[bi * n + i] = rng.next_f32() as f64;
16056 }
16057 }
16058 let d_seed = [1.0_f64];
16059
16060 let (sched, mut arena) = prepare_f64(
16061 &bwd,
16062 &[(a_id, &a_data), (b_id, &b_data), (d_out_id, &d_seed)],
16063 );
16064 execute_thunks(&sched, arena.raw_buf_mut());
16065 let da_out = read_arena_f64(&arena, bwd.outputs[1], batch * n * n);
16066 let db_out = read_arena_f64(&arena, bwd.outputs[2], batch * n);
16067
16068 for bi in 0..batch {
16071 let a_slice: Vec<f64> = a_data[bi * n * n..(bi + 1) * n * n].to_vec();
16072 let mut b_slice: Vec<f64> = b_data[bi * n..(bi + 1) * n].to_vec();
16073 let mut a_copy = a_slice.clone();
16074 crate::blas::dgesv(&mut a_copy, &mut b_slice, n, 1);
16075 let x_ref = b_slice.clone();
16076 let mut at = vec![0.0_f64; n * n];
16078 for i in 0..n {
16079 for j in 0..n {
16080 at[i * n + j] = a_slice[j * n + i];
16081 }
16082 }
16083 let mut ones = vec![1.0_f64; n];
16084 crate::blas::dgesv(&mut at, &mut ones, n, 1);
16085 let db_ref = ones;
16086 for i in 0..n {
16087 let got = db_out[bi * n + i];
16088 assert!(
16089 (got - db_ref[i]).abs() < 1e-10,
16090 "batch {bi}, db[{i}]: got {got} ref {}",
16091 db_ref[i]
16092 );
16093 }
16094 for i in 0..n {
16096 for j in 0..n {
16097 let got = da_out[bi * n * n + i * n + j];
16098 let want = -db_ref[i] * x_ref[j];
16099 assert!(
16100 (got - want).abs() < 1e-10,
16101 "batch {bi}, dA[{i},{j}]: got {got} ref {want}"
16102 );
16103 }
16104 }
16105 }
16106 }
16107
16108 #[test]
16113 fn scan_checkpointed_grad_matches_plain_scan_grad() {
16114 use rlx_opt::autodiff::grad_with_loss;
16115 let n = 2usize;
16116 let length = 6u32;
16117
16118 let make_body = || {
16119 let mut body = Graph::new("ck_body");
16120 let carry = body.input("carry", Shape::new(&[n], DType::F64));
16121 let scale_bytes: Vec<u8> = (0..n).flat_map(|_| 1.05_f64.to_le_bytes()).collect();
16122 let scale = body.add_node(
16123 Op::Constant { data: scale_bytes },
16124 vec![],
16125 Shape::new(&[n], DType::F64),
16126 );
16127 let next = body.binary(BinaryOp::Mul, carry, scale, Shape::new(&[n], DType::F64));
16128 body.set_outputs(vec![next]);
16129 body
16130 };
16131
16132 let mut g_plain = Graph::new("ck_plain");
16134 let init_p = g_plain.input("init", Shape::new(&[n], DType::F64));
16135 let final_p = g_plain.scan(init_p, make_body(), length);
16136 let loss_p = g_plain.reduce(
16137 final_p,
16138 ReduceOp::Sum,
16139 vec![0],
16140 false,
16141 Shape::new(&[1], DType::F64),
16142 );
16143 g_plain.set_outputs(vec![loss_p]);
16144 let bwd_p = grad_with_loss(&g_plain, &[init_p]);
16145
16146 let mut g_ck = Graph::new("ck_ckpt");
16148 let init_c = g_ck.input("init", Shape::new(&[n], DType::F64));
16149 let final_c = g_ck.scan_checkpointed(init_c, make_body(), length, 2);
16150 let loss_c = g_ck.reduce(
16151 final_c,
16152 ReduceOp::Sum,
16153 vec![0],
16154 false,
16155 Shape::new(&[1], DType::F64),
16156 );
16157 g_ck.set_outputs(vec![loss_c]);
16158 let bwd_c = grad_with_loss(&g_ck, &[init_c]);
16159
16160 let find = |graph: &Graph, want: &str| -> NodeId {
16161 for node in graph.nodes() {
16162 let name = match &node.op {
16163 Op::Input { name } | Op::Param { name } => Some(name.as_str()),
16164 _ => None,
16165 };
16166 if name == Some(want) {
16167 return node.id;
16168 }
16169 }
16170 panic!("no {want}");
16171 };
16172
16173 let init_data = vec![0.5_f64, -0.5];
16174 let d_seed = [1.0_f64];
16175
16176 let (s_p, mut a_p) = prepare_f64(
16177 &bwd_p,
16178 &[
16179 (find(&bwd_p, "init"), &init_data),
16180 (find(&bwd_p, "d_output"), &d_seed),
16181 ],
16182 );
16183 execute_thunks(&s_p, a_p.raw_buf_mut());
16184 let dinit_p = read_arena_f64(&a_p, bwd_p.outputs[1], n);
16185
16186 let (s_c, mut a_c) = prepare_f64(
16187 &bwd_c,
16188 &[
16189 (find(&bwd_c, "init"), &init_data),
16190 (find(&bwd_c, "d_output"), &d_seed),
16191 ],
16192 );
16193 execute_thunks(&s_c, a_c.raw_buf_mut());
16194 let dinit_c = read_arena_f64(&a_c, bwd_c.outputs[1], n);
16195
16196 for i in 0..n {
16197 assert!(
16198 (dinit_p[i] - dinit_c[i]).abs() < 1e-12,
16199 "dinit[{i}]: plain={} checkpointed={}",
16200 dinit_p[i],
16201 dinit_c[i]
16202 );
16203 }
16204 }
16205
16206 #[test]
16212 fn recursive_checkpointing_matches_full_trajectory() {
16213 let n = 2usize;
16214 let length = 4u32;
16215
16216 let build_body = || -> Graph {
16218 let mut body = Graph::new("rc_body");
16219 let carry = body.input("carry", Shape::new(&[n], DType::F64));
16220 let ones_bytes: Vec<u8> = (0..n).flat_map(|_| 1.0_f64.to_le_bytes()).collect();
16221 let ones = body.add_node(
16222 Op::Constant { data: ones_bytes },
16223 vec![],
16224 Shape::new(&[n], DType::F64),
16225 );
16226 let next = body.binary(BinaryOp::Add, carry, ones, Shape::new(&[n], DType::F64));
16227 body.set_outputs(vec![next]);
16228 body
16229 };
16230
16231 let body_vjp_for = || -> Graph {
16234 use rlx_opt::autodiff::grad;
16235 let body = build_body();
16236 let carry_id = body
16238 .nodes()
16239 .iter()
16240 .find(|n| matches!(n.op, Op::Input { .. }))
16241 .map(|n| n.id)
16242 .unwrap();
16243 grad(&body, &[carry_id])
16244 };
16245
16246 let mut g_full = Graph::new("rc_outer_full");
16248 let init_full = g_full.input("init", Shape::new(&[n], DType::F64));
16249 let traj_full_id = g_full.scan_trajectory(init_full, build_body(), length);
16250 let upstream_full = g_full.input("upstream", Shape::new(&[length as usize, n], DType::F64));
16252 let dinit_full_id = g_full.scan_backward(
16253 init_full,
16254 traj_full_id,
16255 upstream_full,
16256 &[],
16257 body_vjp_for(),
16258 length,
16259 true,
16260 Shape::new(&[n], DType::F64),
16261 );
16262 g_full.set_outputs(vec![dinit_full_id]);
16263
16264 let k = 2u32;
16267 let mut g_rec = Graph::new("rc_outer_rec");
16268 let init_rec = g_rec.input("init", Shape::new(&[n], DType::F64));
16269 let traj_rec_id = g_rec.add_node(
16270 Op::Scan {
16271 body: Box::new(build_body()),
16272 length,
16273 save_trajectory: true,
16274 num_bcast: 0,
16275 num_xs: 0,
16276 num_checkpoints: k,
16277 },
16278 vec![init_rec],
16279 Shape::new(&[k as usize, n], DType::F64),
16280 );
16281 let upstream_rec = g_rec.input("upstream", Shape::new(&[length as usize, n], DType::F64));
16284 let dinit_rec_id = g_rec.add_node(
16285 Op::ScanBackward {
16286 body_vjp: Box::new(body_vjp_for()),
16287 length,
16288 save_trajectory: true,
16289 num_xs: 0,
16290 num_checkpoints: k,
16291 forward_body: Some(Box::new(build_body())),
16292 },
16293 vec![init_rec, traj_rec_id, upstream_rec],
16294 Shape::new(&[n], DType::F64),
16295 );
16296 g_rec.set_outputs(vec![dinit_rec_id]);
16297
16298 let init_data = vec![0.5_f64, -0.5];
16300 let upstream_data: Vec<f64> = (0..length as usize * n).map(|i| (i as f64) * 0.1).collect();
16301
16302 let find = |graph: &Graph, want: &str| -> NodeId {
16303 for node in graph.nodes() {
16304 if let Op::Input { name } = &node.op
16305 && name == want
16306 {
16307 return node.id;
16308 }
16309 }
16310 panic!("no input {want}");
16311 };
16312
16313 let (s_full, mut a_full) = prepare_f64(
16314 &g_full,
16315 &[
16316 (find(&g_full, "init"), &init_data),
16317 (find(&g_full, "upstream"), &upstream_data),
16318 ],
16319 );
16320 execute_thunks(&s_full, a_full.raw_buf_mut());
16321 let dinit_full = read_arena_f64(&a_full, g_full.outputs[0], n);
16322
16323 let (s_rec, mut a_rec) = prepare_f64(
16324 &g_rec,
16325 &[
16326 (find(&g_rec, "init"), &init_data),
16327 (find(&g_rec, "upstream"), &upstream_data),
16328 ],
16329 );
16330 execute_thunks(&s_rec, a_rec.raw_buf_mut());
16331 let dinit_rec = read_arena_f64(&a_rec, g_rec.outputs[0], n);
16332
16333 for i in 0..n {
16334 assert!(
16335 (dinit_full[i] - dinit_rec[i]).abs() < 1e-12,
16336 "i={i}: full={} rec={}",
16337 dinit_full[i],
16338 dinit_rec[i]
16339 );
16340 }
16341 }
16342
#[test]
/// Gradient of a scan under `vmap`, cross-checked per batch row.
///
/// The scan body adds a constant vector of ones to the carry and the loss is
/// the sum of the final carry. The graph returned by `grad_with_loss` is
/// passed through `vmap` over `"init"` with `batch = 3`, executed once, and
/// its `outputs[1]` is compared element-wise against 1.0 and against the same
/// gradient graph rebuilt and executed separately for every row of the input.
fn vmap_of_grad_scan_matches_per_row_runs() {
    use rlx_opt::autodiff::grad_with_loss;
    use rlx_opt::vmap::vmap;

    let (n, batch) = (2usize, 3usize);
    let length = 3u32;
    let vec_shape = || Shape::new(&[n], DType::F64);
    let scalar_shape = || Shape::new(&[1], DType::F64);
    // Little-endian bytes of `n` copies of 1.0 for the body constant.
    let ones_payload = || 1.0_f64.to_le_bytes().repeat(n);

    let mut body = Graph::new("scan_grad_body");
    let carry = body.input("carry", vec_shape());
    let ones = body.add_node(
        Op::Constant {
            data: ones_payload(),
        },
        vec![],
        vec_shape(),
    );
    let next = body.binary(BinaryOp::Add, carry, ones, vec_shape());
    body.set_outputs(vec![next]);

    let mut g = Graph::new("scan_grad_outer");
    let init = g.input("init", vec_shape());
    let final_x = g.scan(init, body, length);
    let loss = g.reduce(final_x, ReduceOp::Sum, vec![0], false, scalar_shape());
    g.set_outputs(vec![loss]);

    let bwd = grad_with_loss(&g, &[init]);
    let bg = vmap(&bwd, &["init"], batch);

    // Looks up an `Input` or `Param` node by name; panics when absent.
    let by_name = |graph: &Graph, want: &str| -> NodeId {
        for node in graph.nodes() {
            match &node.op {
                Op::Input { name } | Op::Param { name } if name.as_str() == want => {
                    return node.id;
                }
                _ => {}
            }
        }
        panic!("no node named {want}");
    };
    let init_b = by_name(&bg, "init");
    let d_out_b = by_name(&bg, "d_output");

    let init_data: Vec<f64> = (0..batch * n).map(|i| (i as f64) * 0.5).collect();
    let d_seed = [1.0_f64];

    let (sched, mut arena) = prepare_f64(&bg, &[(init_b, &init_data), (d_out_b, &d_seed)]);
    execute_thunks(&sched, arena.raw_buf_mut());
    let dinit_b = read_arena_f64(&arena, bg.outputs[1], batch * n);

    // Every entry of the batched gradient must equal 1.0.
    for (i, value) in dinit_b[..batch * n].iter().enumerate() {
        assert!(
            (value - 1.0).abs() < 1e-12,
            "dinit[{i}] = {} (expected 1.0)",
            value
        );
    }

    // Re-run the unbatched gradient graph on each row and compare.
    for (bi, row) in init_data.chunks(n).enumerate() {
        let mut row_graph = Graph::new("per_row_grad");
        let row_init = row_graph.input("init", vec_shape());
        let mut row_body = Graph::new("per_row_body");
        let row_carry = row_body.input("carry", vec_shape());
        let row_ones = row_body.add_node(
            Op::Constant {
                data: ones_payload(),
            },
            vec![],
            vec_shape(),
        );
        let row_next = row_body.binary(BinaryOp::Add, row_carry, row_ones, vec_shape());
        row_body.set_outputs(vec![row_next]);
        let row_final = row_graph.scan(row_init, row_body, length);
        let row_loss = row_graph.reduce(
            row_final,
            ReduceOp::Sum,
            vec![0],
            false,
            scalar_shape(),
        );
        row_graph.set_outputs(vec![row_loss]);

        let row_bwd = grad_with_loss(&row_graph, &[row_init]);
        let (row_sched, mut row_arena) = prepare_f64(
            &row_bwd,
            &[
                (by_name(&row_bwd, "init"), row),
                (by_name(&row_bwd, "d_output"), &d_seed),
            ],
        );
        execute_thunks(&row_sched, row_arena.raw_buf_mut());
        let row_dinit = read_arena_f64(&row_arena, row_bwd.outputs[1], n);

        for j in 0..n {
            let (got, want) = (dinit_b[bi * n + j], row_dinit[j]);
            assert!(
                (got - want).abs() < 1e-12,
                "row {bi}, j {j}: vmap'd={got} per-row={want}"
            );
        }
    }
}
16455
#[test]
/// Runs a `vmap`-ed scan-with-xs whose body is `carry + x_t` and compares the
/// batched final carry with a plain Rust running sum computed row by row.
///
/// Both `"init"` and `"xs"` are passed to `vmap` with `batch = 3`; the input
/// buffers hold `batch * n` and `batch * length * n` values respectively.
fn vmap_scan_cumulative_sum_matches_scalar_runs() {
    use rlx_opt::vmap::vmap;

    let n = 2usize;
    let length = 4u32;
    let batch = 3usize;
    let steps = length as usize;
    let vec_shape = || Shape::new(&[n], DType::F64);

    let mut body = Graph::new("scan_body_cumsum");
    let carry = body.input("carry", vec_shape());
    let x_t = body.input("x_t", vec_shape());
    let next = body.binary(BinaryOp::Add, carry, x_t, vec_shape());
    body.set_outputs(vec![next]);

    let mut g = Graph::new("scan_outer_cumsum");
    let init = g.input("init", vec_shape());
    let xs = g.input("xs", Shape::new(&[steps, n], DType::F64));
    let final_carry = g.scan_with_xs(init, &[xs], body, length);
    g.set_outputs(vec![final_carry]);

    let bg = vmap(&g, &["init", "xs"], batch);

    // init = 1, 2, 3, ...; xs = 0.0, 0.1, 0.2, ...
    let init_data: Vec<f64> = (1..=batch * n).map(|v| v as f64).collect();
    let xs_data: Vec<f64> = (0..batch * steps * n)
        .map(|i| 0.1 * (i as f64))
        .collect();

    // Looks up an `Input` node by name; panics when absent.
    let input_id = |graph: &Graph, want: &str| -> NodeId {
        for node in graph.nodes() {
            match &node.op {
                Op::Input { name } if name == want => return node.id,
                _ => {}
            }
        }
        panic!("no input {want}");
    };
    let init_b = input_id(&bg, "init");
    let xs_b = input_id(&bg, "xs");
    let (sched, mut arena) = prepare_f64(&bg, &[(init_b, &init_data), (xs_b, &xs_data)]);
    execute_thunks(&sched, arena.raw_buf_mut());
    let batched_out = read_arena_f64(&arena, bg.outputs[0], batch * n);

    for bi in 0..batch {
        // Reference: add every step's slice of this row onto its init slice.
        let mut x = init_data[bi * n..(bi + 1) * n].to_vec();
        let row_xs = &xs_data[bi * steps * n..(bi + 1) * steps * n];
        for step in row_xs.chunks(n) {
            for (acc, delta) in x.iter_mut().zip(step) {
                *acc += *delta;
            }
        }

        for i in 0..n {
            let got = batched_out[bi * n + i];
            assert!(
                (got - x[i]).abs() < 1e-12,
                "row {bi}, i {i}: got {got} ref {}",
                x[i]
            );
        }
    }
}
16528
#[test]
/// Runs a `vmap`-ed `dense_solve` over `"A"` and `"b"` (batch of 4, `n = 3`)
/// and compares every solution row with `crate::blas::dgesv` applied to the
/// matching slices of the input buffers.
fn vmap_dense_solve_matches_scalar_runs() {
    use rlx_opt::vmap::vmap;

    let n = 3usize;
    let batch = 4usize;

    let mut g = Graph::new("solve_forward");
    let a = g.input("A", Shape::new(&[n, n], DType::F64));
    let b = g.input("b", Shape::new(&[n], DType::F64));
    let x = g.dense_solve(a, b, Shape::new(&[n], DType::F64));
    g.set_outputs(vec![x]);

    let bg = vmap(&g, &["A", "b"], batch);

    let mut rng = rlx_ir::Philox4x32::new(0xb47c_u64);
    let mut a_data = vec![0.0_f64; batch * n * n];
    let mut b_data = vec![0.0_f64; batch * n];
    // Per batch entry: fill the matrix row by row, bump its diagonal by
    // `1 + n` (presumably to keep the system well-conditioned), then fill the
    // right-hand side. The draw order from `rng` is part of the fixture.
    for (a_mat, b_vec) in a_data.chunks_mut(n * n).zip(b_data.chunks_mut(n)) {
        for (i, a_row) in a_mat.chunks_mut(n).enumerate() {
            for entry in a_row.iter_mut() {
                *entry = rng.next_f32() as f64 * 0.1;
            }
            a_row[i] += 1.0 + n as f64;
        }
        for entry in b_vec.iter_mut() {
            *entry = rng.next_f32() as f64;
        }
    }

    // Looks up an `Input` node by name; panics when absent.
    let input_id = |graph: &Graph, want: &str| -> NodeId {
        for node in graph.nodes() {
            match &node.op {
                Op::Input { name } if name == want => return node.id,
                _ => {}
            }
        }
        panic!("no input named {want}");
    };
    let ba = input_id(&bg, "A");
    let bb = input_id(&bg, "b");
    let (sched, mut arena) = prepare_f64(&bg, &[(ba, &a_data), (bb, &b_data)]);
    execute_thunks(&sched, arena.raw_buf_mut());
    let batched_x = read_arena_f64(&arena, bg.outputs[0], batch * n);

    for (bi, (a_src, b_src)) in a_data.chunks(n * n).zip(b_data.chunks(n)).enumerate() {
        // `dgesv` works in place, so solve on copies; the solution is read
        // back from the right-hand-side buffer.
        let mut a_slice: Vec<f64> = a_src.to_vec();
        let mut b_slice: Vec<f64> = b_src.to_vec();
        crate::blas::dgesv(&mut a_slice, &mut b_slice, n, 1);
        for i in 0..n {
            let (got, want) = (batched_x[bi * n + i], b_slice[i]);
            assert!(
                (got - want).abs() < 1e-12,
                "row {bi}, i {i}: got {got} want {want}"
            );
        }
    }
}
16596
#[test]
/// End-to-end `vmap` check on a reshape → matmul → reshape → add → sum graph.
///
/// Only `"x"` is passed to `vmap`; `"w"` and `"b"` are fed the same unbatched
/// buffers in every run. Each of the `batch` outputs is compared with the
/// result of an identically structured unbatched graph run on that row of `x`.
fn vmap_matmul_add_reduce_matches_scalar_runs() {
    use rlx_opt::vmap::vmap;

    let n = 3usize;
    let batch = 4usize;
    let vec_shape = || Shape::new(&[n], DType::F64);
    let row_shape = || Shape::new(&[1, n], DType::F64);
    let mat_shape = || Shape::new(&[n, n], DType::F64);
    let scalar_shape = || Shape::new(&[1], DType::F64);

    let mut g = Graph::new("vmap_e2e_forward");
    let x = g.input("x", vec_shape());
    let w = g.input("w", mat_shape());
    let b = g.input("b", vec_shape());
    let x_row = g.add_node(
        Op::Reshape {
            new_shape: vec![1, n as i64],
        },
        vec![x],
        row_shape(),
    );
    let mm = g.matmul(x_row, w, row_shape());
    let mm_flat = g.add_node(
        Op::Reshape {
            new_shape: vec![n as i64],
        },
        vec![mm],
        vec_shape(),
    );
    let yv = g.binary(BinaryOp::Add, mm_flat, b, vec_shape());
    let loss = g.reduce(yv, ReduceOp::Sum, vec![0], false, scalar_shape());
    g.set_outputs(vec![loss]);

    let bg = vmap(&g, &["x"], batch);

    // Draw order from `rng` is part of the fixture: w, then b, then x.
    let mut rng = rlx_ir::Philox4x32::new(0xc1c0_u64);
    let w_data: Vec<f64> = (0..n * n).map(|_| rng.next_f32() as f64).collect();
    let b_data: Vec<f64> = (0..n).map(|_| rng.next_f32() as f64).collect();
    let x_data_batched: Vec<f64> = (0..batch * n).map(|_| rng.next_f32() as f64).collect();

    // Looks up an `Input` node by name; panics when absent.
    let input_id = |graph: &Graph, want: &str| -> NodeId {
        for node in graph.nodes() {
            match &node.op {
                Op::Input { name } if name == want => return node.id,
                _ => {}
            }
        }
        panic!("no input named {want}");
    };
    let bx = input_id(&bg, "x");
    let bw = input_id(&bg, "w");
    let bb = input_id(&bg, "b");
    let (sched, mut arena) =
        prepare_f64(&bg, &[(bx, &x_data_batched), (bw, &w_data), (bb, &b_data)]);
    execute_thunks(&sched, arena.raw_buf_mut());
    let batched_out = read_arena_f64(&arena, bg.outputs[0], batch);

    for (bi, xs_slice) in x_data_batched.chunks(n).enumerate() {
        // Unbatched twin of the graph above, rebuilt for this row.
        let mut g2 = Graph::new("scalar_run");
        let x2 = g2.input("x", vec_shape());
        let w2 = g2.input("w", mat_shape());
        let b2 = g2.input("b", vec_shape());
        let xr = g2.add_node(
            Op::Reshape {
                new_shape: vec![1, n as i64],
            },
            vec![x2],
            row_shape(),
        );
        let m = g2.matmul(xr, w2, row_shape());
        let mf = g2.add_node(
            Op::Reshape {
                new_shape: vec![n as i64],
            },
            vec![m],
            vec_shape(),
        );
        let yv2 = g2.binary(BinaryOp::Add, mf, b2, vec_shape());
        let l2 = g2.reduce(yv2, ReduceOp::Sum, vec![0], false, scalar_shape());
        g2.set_outputs(vec![l2]);

        let (s2, mut a2) = prepare_f64(&g2, &[(x2, xs_slice), (w2, &w_data), (b2, &b_data)]);
        execute_thunks(&s2, a2.raw_buf_mut());
        let scalar_out = read_arena_f64(&a2, l2, 1);

        let batched_val = batched_out[bi];
        let scalar_val = scalar_out[0];
        assert!(
            (batched_val - scalar_val).abs() < 1e-12,
            "row {bi}: batched={} scalar={}",
            batched_val,
            scalar_val
        );
    }
}
16718
#[test]
/// Checks both gradients of a driven implicit scan against central finite
/// differences.
///
/// Each scan step computes `solve(M, carry + drive)` where `M` has
/// `1 + 2·dt` on the diagonal and `-dt` on the first off-diagonals. The loss
/// is the sum of the final carry. `grad_with_loss` is asked for gradients with
/// respect to `init` and `xs`; `outputs[1]` and `outputs[2]` are compared with
/// finite differences of a pure-Rust reference built on `crate::blas::dgesv`.
fn scan_with_xs_dxs_matches_fd() {
    use rlx_opt::autodiff::grad_with_loss;

    let n = 3usize;
    let length = 3u32;
    let dt = 0.1_f64;
    let steps = length as usize;
    let vec_shape = || Shape::new(&[n], DType::F64);

    // Row-major tridiagonal system matrix.
    let m_data: Vec<f64> = (0..n * n)
        .map(|flat| {
            let (row, col) = (flat / n, flat % n);
            if row == col {
                1.0 + dt * 2.0
            } else if row.abs_diff(col) == 1 {
                -dt
            } else {
                0.0
            }
        })
        .collect();
    let mut m_bytes: Vec<u8> = Vec::with_capacity(m_data.len() * 8);
    for value in &m_data {
        m_bytes.extend_from_slice(&value.to_le_bytes());
    }

    let mut body = Graph::new("be_dxs_body");
    let carry = body.input("carry", vec_shape());
    let drive = body.input("drive", vec_shape());
    let m = body.add_node(
        Op::Constant { data: m_bytes },
        vec![],
        Shape::new(&[n, n], DType::F64),
    );
    let driven = body.binary(BinaryOp::Add, carry, drive, vec_shape());
    let next = body.dense_solve(m, driven, vec_shape());
    body.set_outputs(vec![next]);

    let mut g = Graph::new("be_dxs_outer");
    let init = g.input("init", vec_shape());
    let xs = g.input("xs", Shape::new(&[steps, n], DType::F64));
    let final_carry = g.scan_with_xs(init, &[xs], body, length);
    let loss = g.reduce(
        final_carry,
        ReduceOp::Sum,
        vec![0],
        false,
        Shape::new(&[1], DType::F64),
    );
    g.set_outputs(vec![loss]);

    let bwd = grad_with_loss(&g, &[init, xs]);
    assert_eq!(bwd.outputs.len(), 3, "[loss, dinit, dxs]");

    // Looks up an `Input` or `Param` node by name; panics when absent.
    let node_named = |graph: &Graph, want: &str| -> NodeId {
        for node in graph.nodes() {
            match &node.op {
                Op::Input { name } | Op::Param { name } if name.as_str() == want => {
                    return node.id;
                }
                _ => {}
            }
        }
        panic!("no node named {want:?}");
    };
    let init_bwd = node_named(&bwd, "init");
    let xs_bwd = node_named(&bwd, "xs");
    let d_out_bwd = node_named(&bwd, "d_output");

    let init_data = vec![0.5_f64, 0.0, -0.5];
    let xs_data: Vec<f64> = (0..steps * n)
        .map(|i| 0.1_f64 * ((i as f64) - 4.0))
        .collect();
    let d_seed = [1.0_f64];

    let (sched, mut arena) = prepare_f64(
        &bwd,
        &[
            (init_bwd, &init_data),
            (xs_bwd, &xs_data),
            (d_out_bwd, &d_seed),
        ],
    );
    execute_thunks(&sched, arena.raw_buf_mut());
    let dinit = read_arena_f64(&arena, bwd.outputs[1], n);
    let dxs = read_arena_f64(&arena, bwd.outputs[2], steps * n);

    let h = 1e-6;
    // Pure-Rust reference loss: add the step's drive, solve in place, repeat.
    let loss_at = |x0: &[f64], xs_in: &[f64]| -> f64 {
        let mut state = x0.to_vec();
        for t in 0..steps {
            let drive_row = &xs_in[t * n..(t + 1) * n];
            for (s, d) in state.iter_mut().zip(drive_row) {
                *s += *d;
            }
            let mut a_copy = m_data.clone();
            crate::blas::dgesv(&mut a_copy, &mut state, n, 1);
        }
        state.iter().sum()
    };

    // d loss / d init, one coordinate at a time.
    for i in 0..n {
        let shifted = |delta: f64| -> f64 {
            let mut probe = init_data.clone();
            probe[i] += delta;
            loss_at(&probe, &xs_data)
        };
        let fd = (shifted(h) - shifted(-h)) / (2.0 * h);
        assert!(
            (dinit[i] - fd).abs() < 1e-7,
            "FD dinit[{i}]: AD={} FD={}",
            dinit[i],
            fd
        );
    }

    // d loss / d xs, one (step, component) entry at a time.
    for t in 0..steps {
        for j in 0..n {
            let idx = t * n + j;
            let shifted = |delta: f64| -> f64 {
                let mut probe = xs_data.clone();
                probe[idx] += delta;
                loss_at(&init_data, &probe)
            };
            let fd = (shifted(h) - shifted(-h)) / (2.0 * h);
            assert!(
                (dxs[idx] - fd).abs() < 1e-7,
                "FD dxs[t={t},j={j}]: AD={} FD={}",
                dxs[idx],
                fd
            );
        }
    }
}
16853
#[test]
/// Checks `d loss / d init` of a driven implicit scan against central finite
/// differences, with the gradient requested for `init` only.
///
/// Each scan step computes `solve(M, carry + drive)`; `M` has `1 + 2·dt` on
/// the diagonal and `-dt` on the first off-diagonals. The `xs` input is still
/// fed to the gradient graph but is held fixed in the reference.
fn scan_with_xs_gradient_dinit_matches_fd() {
    use rlx_opt::autodiff::grad_with_loss;

    let n = 3usize;
    let length = 3u32;
    let dt = 0.1_f64;
    let steps = length as usize;
    let vec_shape = || Shape::new(&[n], DType::F64);

    // Row-major tridiagonal system matrix, filled one row at a time.
    let mut m_data = vec![0.0_f64; n * n];
    for (i, row) in m_data.chunks_mut(n).enumerate() {
        row[i] = 1.0 + dt * 2.0;
        if i > 0 {
            row[i - 1] = -dt;
        }
        if i + 1 < n {
            row[i + 1] = -dt;
        }
    }
    let m_bytes: Vec<u8> = m_data.iter().flat_map(|v| v.to_le_bytes()).collect();

    let mut body = Graph::new("be_xs_grad_body");
    let carry = body.input("carry", vec_shape());
    let drive = body.input("drive", vec_shape());
    let m = body.add_node(
        Op::Constant { data: m_bytes },
        vec![],
        Shape::new(&[n, n], DType::F64),
    );
    let driven = body.binary(BinaryOp::Add, carry, drive, vec_shape());
    let next = body.dense_solve(m, driven, vec_shape());
    body.set_outputs(vec![next]);

    let mut g = Graph::new("be_xs_grad_outer");
    let init = g.input("init", vec_shape());
    let xs = g.input("xs", Shape::new(&[steps, n], DType::F64));
    let final_carry = g.scan_with_xs(init, &[xs], body, length);
    let loss = g.reduce(
        final_carry,
        ReduceOp::Sum,
        vec![0],
        false,
        Shape::new(&[1], DType::F64),
    );
    g.set_outputs(vec![loss]);

    let bwd = grad_with_loss(&g, &[init]);

    // Looks up an `Input` or `Param` node by name; panics when absent.
    let node_named = |graph: &Graph, want: &str| -> NodeId {
        for node in graph.nodes() {
            match &node.op {
                Op::Input { name } | Op::Param { name } if name.as_str() == want => {
                    return node.id;
                }
                _ => {}
            }
        }
        panic!("no node named {want:?}");
    };
    let init_bwd = node_named(&bwd, "init");
    let xs_bwd = node_named(&bwd, "xs");
    let d_out_bwd = node_named(&bwd, "d_output");

    let init_data = vec![0.5_f64, 0.0, -0.5];
    let xs_data: Vec<f64> = (0..steps * n)
        .map(|i| 0.1_f64 * ((i as f64) - 4.0))
        .collect();
    let d_seed = [1.0_f64];

    let (sched, mut arena) = prepare_f64(
        &bwd,
        &[
            (init_bwd, &init_data),
            (xs_bwd, &xs_data),
            (d_out_bwd, &d_seed),
        ],
    );
    execute_thunks(&sched, arena.raw_buf_mut());
    let dinit = read_arena_f64(&arena, bwd.outputs[1], n);

    let h = 1e-6;
    // Pure-Rust reference loss with the drive sequence held fixed.
    let loss_at = |x0: &[f64]| -> f64 {
        let mut state = x0.to_vec();
        for t in 0..steps {
            let drive_row = &xs_data[t * n..(t + 1) * n];
            for (s, d) in state.iter_mut().zip(drive_row) {
                *s += *d;
            }
            let mut a_copy = m_data.clone();
            crate::blas::dgesv(&mut a_copy, &mut state, n, 1);
        }
        state.iter().sum()
    };
    for i in 0..n {
        let shifted = |delta: f64| -> f64 {
            let mut probe = init_data.clone();
            probe[i] += delta;
            loss_at(&probe)
        };
        let fd = (shifted(h) - shifted(-h)) / (2.0 * h);
        assert!(
            (dinit[i] - fd).abs() < 1e-7,
            "FD dinit[{i}]: AD={} FD={}",
            dinit[i],
            fd
        );
    }
}
16967
#[test]
/// Gradient of a scan whose body multiplies the carry by 1.1 at every step.
///
/// With the loss being the sum of the final carry, each component of
/// `d loss / d init` is expected to equal `1.1^length`. The test checks that
/// closed form on all components and additionally compares component 0 with a
/// central finite difference of a pure-Rust reference.
fn scan_gradient_geometric_matches_closed_form() {
    use rlx_opt::autodiff::grad_with_loss;

    let n = 3usize;
    let length = 5u32;
    let vec_shape = || Shape::new(&[n], DType::F64);

    let mut body = Graph::new("scan_grad_body");
    let x = body.input("carry", vec_shape());
    // Little-endian bytes of `n` copies of 1.1.
    let scale_bytes: Vec<u8> = 1.1_f64.to_le_bytes().repeat(n);
    let scale = body.add_node(Op::Constant { data: scale_bytes }, vec![], vec_shape());
    let next = body.binary(BinaryOp::Mul, x, scale, vec_shape());
    body.set_outputs(vec![next]);

    let mut g = Graph::new("scan_grad_outer");
    let init = g.input("init", vec_shape());
    let final_x = g.scan(init, body, length);
    let loss = g.reduce(
        final_x,
        ReduceOp::Sum,
        vec![0],
        false,
        Shape::new(&[1], DType::F64),
    );
    g.set_outputs(vec![loss]);

    let bwd = grad_with_loss(&g, &[init]);
    assert_eq!(bwd.outputs.len(), 2);

    // Looks up an `Input` or `Param` node by name; panics when absent.
    let node_named = |graph: &Graph, want: &str| -> NodeId {
        for node in graph.nodes() {
            match &node.op {
                Op::Input { name } | Op::Param { name } if name.as_str() == want => {
                    return node.id;
                }
                _ => {}
            }
        }
        panic!("no node named {want:?}");
    };
    let init_bwd = node_named(&bwd, "init");
    let d_out_bwd = node_named(&bwd, "d_output");

    let init_data = vec![1.0_f64; n];
    let d_seed = [1.0_f64];
    let (sched, mut arena) = prepare_f64(&bwd, &[(init_bwd, &init_data), (d_out_bwd, &d_seed)]);
    execute_thunks(&sched, arena.raw_buf_mut());
    let dinit = read_arena_f64(&arena, bwd.outputs[1], n);

    // Closed form on every component.
    let want = 1.1_f64.powi(length as i32);
    for (i, got) in dinit[..n].iter().enumerate() {
        assert!(
            (got - want).abs() < 1e-12,
            "dinit[{i}] = {} want {}",
            got,
            want
        );
    }

    // Finite-difference spot check on component 0.
    let h = 1e-6;
    let loss_at = |x: &[f64]| -> f64 {
        let mut state = x.to_vec();
        for _ in 0..length {
            state.iter_mut().for_each(|v| *v *= 1.1);
        }
        state.iter().sum()
    };
    let shifted = |delta: f64| -> f64 {
        let mut probe = init_data.clone();
        probe[0] += delta;
        loss_at(&probe)
    };
    let fd = (shifted(h) - shifted(-h)) / (2.0 * h);
    assert!(
        (dinit[0] - fd).abs() < 1e-7,
        "FD dinit[0]: AD={} FD={}",
        dinit[0],
        fd
    );
}
17061
#[test]
/// Gradient of a scan whose body is a single `dense_solve(M, x)`, checked
/// against central finite differences on every component of `x0`.
///
/// `M` is a 4×4 matrix with `1 + 2·dt` on the diagonal and `-dt` on the first
/// off-diagonals; the loss is the sum of the final carry after 3 steps.
fn scan_gradient_backward_euler_matches_fd() {
    use rlx_opt::autodiff::grad_with_loss;

    let n = 4usize;
    let length = 3u32;
    let dt = 0.05_f64;
    let vec_shape = || Shape::new(&[n], DType::F64);

    // Row-major tridiagonal system matrix.
    let m_data: Vec<f64> = (0..n * n)
        .map(|flat| {
            let (row, col) = (flat / n, flat % n);
            if row == col {
                1.0 + dt * 2.0
            } else if row.abs_diff(col) == 1 {
                -dt
            } else {
                0.0
            }
        })
        .collect();
    let m_bytes: Vec<u8> = m_data.iter().flat_map(|v| v.to_le_bytes()).collect();

    let mut body = Graph::new("be_grad_body");
    let x = body.input("x", vec_shape());
    let m = body.add_node(
        Op::Constant { data: m_bytes },
        vec![],
        Shape::new(&[n, n], DType::F64),
    );
    let next = body.dense_solve(m, x, vec_shape());
    body.set_outputs(vec![next]);

    let mut g = Graph::new("be_grad_outer");
    let init = g.input("x0", vec_shape());
    let final_x = g.scan(init, body, length);
    let loss = g.reduce(
        final_x,
        ReduceOp::Sum,
        vec![0],
        false,
        Shape::new(&[1], DType::F64),
    );
    g.set_outputs(vec![loss]);

    let bwd = grad_with_loss(&g, &[init]);

    // Looks up an `Input` or `Param` node by name; panics when absent.
    let node_named = |graph: &Graph, want: &str| -> NodeId {
        for node in graph.nodes() {
            match &node.op {
                Op::Input { name } | Op::Param { name } if name.as_str() == want => {
                    return node.id;
                }
                _ => {}
            }
        }
        panic!("no node named {want:?}");
    };
    let init_bwd = node_named(&bwd, "x0");
    let d_out_bwd = node_named(&bwd, "d_output");

    let init_data: [f64; 4] = [0.0, 1.0, 0.0, 0.0];
    let d_seed = [1.0_f64];
    let (sched, mut arena) = prepare_f64(&bwd, &[(init_bwd, &init_data), (d_out_bwd, &d_seed)]);
    execute_thunks(&sched, arena.raw_buf_mut());
    let dinit = read_arena_f64(&arena, bwd.outputs[1], n);

    let h = 1e-6;
    // Pure-Rust reference loss: `length` in-place solves, then sum.
    let loss_at = |x0: &[f64]| -> f64 {
        let mut state = x0.to_vec();
        for _ in 0..length {
            let mut a_copy = m_data.clone();
            crate::blas::dgesv(&mut a_copy, &mut state, n, 1);
        }
        state.iter().sum()
    };
    for i in 0..n {
        let shifted = |delta: f64| -> f64 {
            let mut probe = init_data.to_vec();
            probe[i] += delta;
            loss_at(&probe)
        };
        let fd = (shifted(h) - shifted(-h)) / (2.0 * h);
        assert!(
            (dinit[i] - fd).abs() < 1e-7,
            "FD dinit[{i}]: AD={} FD={}",
            dinit[i],
            fd
        );
    }
}
17151
#[test]
/// Exercises `scan_trajectory` on a `dense_solve(M, x)` body and checks the
/// recorded rows three ways:
///
/// 1. every row equals the matching step of a `crate::blas::dgesv` reference;
/// 2. the per-row sum never increases from one step to the next;
/// 3. the last recorded row equals the output of a plain `scan` built from an
///    identical body and run on the same initial state.
fn scan_trajectory_backward_euler_records_waveform() {
    let n = 4usize;
    let length = 5u32;
    let dt = 0.05_f64;
    let steps = length as usize;
    let vec_shape = || Shape::new(&[n], DType::F64);
    let mat_shape = || Shape::new(&[n, n], DType::F64);

    // Row-major tridiagonal system matrix, filled one row at a time.
    let mut m_data = vec![0.0_f64; n * n];
    for (i, row) in m_data.chunks_mut(n).enumerate() {
        row[i] = 1.0 + dt * 2.0;
        if i > 0 {
            row[i - 1] = -dt;
        }
        if i + 1 < n {
            row[i + 1] = -dt;
        }
    }
    let matrix_bytes =
        |values: &[f64]| -> Vec<u8> { values.iter().flat_map(|v| v.to_le_bytes()).collect() };

    let mut body = Graph::new("be_traj_body");
    let x = body.input("x", vec_shape());
    let m = body.add_node(
        Op::Constant {
            data: matrix_bytes(&m_data),
        },
        vec![],
        mat_shape(),
    );
    let next = body.dense_solve(m, x, vec_shape());
    body.set_outputs(vec![next]);

    let mut g = Graph::new("be_traj_outer");
    let init = g.input("x0", vec_shape());
    let traj = g.scan_trajectory(init, body, length);
    g.set_outputs(vec![traj]);

    let init_data: [f64; 4] = [0.0, 1.0, 0.0, 0.0];
    let (sched, mut arena) = prepare_f64(&g, &[(init, &init_data)]);
    execute_thunks(&sched, arena.raw_buf_mut());
    let got = read_arena_f64(&arena, traj, steps * n);

    // Reference trajectory: the state after each in-place solve, row by row.
    let mut want = Vec::<f64>::with_capacity(steps * n);
    let mut x_ref = init_data.to_vec();
    for _ in 0..length {
        let mut a_copy = m_data.clone();
        crate::blas::dgesv(&mut a_copy, &mut x_ref, n, 1);
        want.extend_from_slice(&x_ref);
    }
    for i in 0..steps * n {
        let (actual, expected) = (got[i], want[i]);
        assert!(
            (actual - expected).abs() < 1e-12,
            "got[{i}] = {} ref {}",
            actual,
            expected
        );
    }

    // Row sums must be non-increasing (up to a 1e-15 slack).
    let row_mass: Vec<f64> = (0..steps)
        .map(|t| got[t * n..(t + 1) * n].iter().sum::<f64>())
        .collect();
    for t in 1..steps {
        let prev = row_mass[t - 1];
        let curr = row_mass[t];
        assert!(
            curr <= prev + 1e-15,
            "mass should decay: row {} sum {prev}, row {t} sum {curr}",
            t - 1
        );
    }

    // Same body again, this time under a plain `scan` that keeps only the
    // final carry.
    let mut body2 = Graph::new("be_final_body");
    let x2 = body2.input("x", vec_shape());
    let m2 = body2.add_node(
        Op::Constant {
            data: matrix_bytes(&m_data),
        },
        vec![],
        mat_shape(),
    );
    let next2 = body2.dense_solve(m2, x2, vec_shape());
    body2.set_outputs(vec![next2]);

    let mut g2 = Graph::new("be_final_outer");
    let init2 = g2.input("x0", vec_shape());
    let final_x = g2.scan(init2, body2, length);
    g2.set_outputs(vec![final_x]);
    let (sched2, mut arena2) = prepare_f64(&g2, &[(init2, &init_data)]);
    execute_thunks(&sched2, arena2.raw_buf_mut());
    let final_got = read_arena_f64(&arena2, final_x, n);

    let last_row = &got[(steps - 1) * n..steps * n];
    for i in 0..n {
        let (from_traj, from_scan) = (last_row[i], final_got[i]);
        assert!(
            (from_traj - from_scan).abs() < 1e-15,
            "last trajectory row[{i}] = {} vs final-scan = {}",
            from_traj,
            from_scan
        );
    }
}
17256
#[test]
/// Runs a 5-step `scan` whose body is `dense_solve(M, x)` in f64 and compares
/// the final carry with repeated `crate::blas::dgesv` solves, then checks that
/// the sum of the result lies strictly between 0 and 1.
///
/// `M` is a 4×4 matrix with `1 + 2·dt` on the diagonal and `-dt` on the first
/// off-diagonals; the initial state is the unit vector `[0, 1, 0, 0]`.
fn scan_backward_euler_heat_f64() {
    let n = 4usize;
    let length = 5u32;
    let dt = 0.05_f64;
    let vec_shape = || Shape::new(&[n], DType::F64);

    // Row-major tridiagonal system matrix.
    let m_data: Vec<f64> = (0..n * n)
        .map(|flat| {
            let (row, col) = (flat / n, flat % n);
            if row == col {
                1.0 + dt * 2.0
            } else if row.abs_diff(col) == 1 {
                -dt
            } else {
                0.0
            }
        })
        .collect();
    let mut m_bytes: Vec<u8> = Vec::with_capacity(m_data.len() * 8);
    for value in &m_data {
        m_bytes.extend_from_slice(&value.to_le_bytes());
    }

    let mut body = Graph::new("be_body");
    let x = body.input("x", vec_shape());
    let m = body.add_node(
        Op::Constant { data: m_bytes },
        vec![],
        Shape::new(&[n, n], DType::F64),
    );
    let next = body.dense_solve(m, x, vec_shape());
    body.set_outputs(vec![next]);

    let mut g = Graph::new("be_outer");
    let init = g.input("x0", vec_shape());
    let final_x = g.scan(init, body, length);
    g.set_outputs(vec![final_x]);

    let init_data: [f64; 4] = [0.0, 1.0, 0.0, 0.0];
    let (sched, mut arena) = prepare_f64(&g, &[(init, &init_data)]);
    execute_thunks(&sched, arena.raw_buf_mut());
    let got = read_arena_f64(&arena, final_x, n);

    // Reference: the same number of in-place solves on a copy of the state.
    let mut ref_x = init_data.to_vec();
    for _ in 0..length {
        let mut a_copy = m_data.clone();
        crate::blas::dgesv(&mut a_copy, &mut ref_x, n, 1);
    }
    for i in 0..n {
        let (actual, expected) = (got[i], ref_x[i]);
        assert!(
            (actual - expected).abs() < 1e-12,
            "got[{i}] = {} ref {}",
            actual,
            expected
        );
    }

    let mass: f64 = got.iter().sum();
    assert!(0.0 < mass && mass < 1.0, "diffusion mass: {mass}");
}
17324
#[test]
/// Solves `A · X = B` with a 3×3 `A` and a 3×2 `B` through `dense_solve`, then
/// verifies the result by residual: for every column `c` and row `i`, the
/// product `(A · X)[i, c]` must reproduce `B[i, c]` to within 1e-10.
fn dense_solve_f64_multi_rhs_forward() {
    let n = 3usize;
    let k = 2usize;
    let rhs_shape = || Shape::new(&[n, k], DType::F64);

    let mut g = Graph::new("solve_multi_rhs");
    let a = g.input("A", Shape::new(&[n, n], DType::F64));
    let b = g.input("B", rhs_shape());
    let x = g.dense_solve(a, b, rhs_shape());
    g.set_outputs(vec![x]);

    // Both buffers are indexed row-major below.
    #[rustfmt::skip]
    let a_data = [
        2.0_f64, 1.0, 0.0,
        1.0, 3.0, 1.0,
        0.0, 1.0, 2.0,
    ];
    #[rustfmt::skip]
    let b_data = [
        1.0_f64, 4.0,
        2.0, -1.0,
        3.0, 2.0,
    ];
    let (sched, mut arena) = prepare_f64(&g, &[(a, &a_data), (b, &b_data)]);
    execute_thunks(&sched, arena.raw_buf_mut());
    let x_got = read_arena_f64(&arena, x, n * k);

    for c in 0..k {
        for i in 0..n {
            // Row i of A dotted with column c of the computed X.
            let acc = (0..n).fold(0.0_f64, |partial, j| {
                partial + a_data[i * n + j] * x_got[j * k + c]
            });
            let want = b_data[i * k + c];
            assert!(
                (acc - want).abs() < 1e-10,
                "col {c} row {i}: got {acc} want {want}"
            );
        }
    }
}
17357
#[test]
/// Reverse-mode gradient of `sum(dense_solve(A, B))` with a 3×2 right-hand
/// side, checked against a hand-derived reference and a finite difference.
///
/// Reference used by the test: `dB = solve(Aᵀ, ones)` and
/// `dA[i, j] = -Σ_c dB[i, c] · X[j, c]`, where `X = solve(A, B)`. Both are
/// evaluated with `crate::blas::dgesv`. `dB[0, 0]` is additionally compared
/// with a central finite difference of the summed solution.
fn dense_solve_f64_multi_rhs_gradient() {
    use rlx_opt::autodiff::grad_with_loss;

    let n = 3usize;
    let k = 2usize;
    let rhs_shape = || Shape::new(&[n, k], DType::F64);

    let mut g = Graph::new("solve_mrhs_grad");
    let a = g.param("A", Shape::new(&[n, n], DType::F64));
    let b = g.input("B", rhs_shape());
    let x = g.dense_solve(a, b, rhs_shape());
    let loss = g.reduce(
        x,
        ReduceOp::Sum,
        vec![0, 1],
        false,
        Shape::new(&[1], DType::F64),
    );
    g.set_outputs(vec![loss]);

    let bwd = grad_with_loss(&g, &[a, b]);

    // Looks up an `Input` or `Param` node by name; panics when absent.
    let node_named = |graph: &Graph, want: &str| -> NodeId {
        for node in graph.nodes() {
            match &node.op {
                Op::Input { name } | Op::Param { name } if name.as_str() == want => {
                    return node.id;
                }
                _ => {}
            }
        }
        panic!("no node named {want:?}");
    };
    let a_bwd = node_named(&bwd, "A");
    let b_bwd = node_named(&bwd, "B");
    let d_out = node_named(&bwd, "d_output");

    let a_data = [2.0, 1.0, 0.0, 1.0, 3.0, 1.0, 0.0, 1.0, 2.0_f64];
    let b_data = [1.0, 4.0, 2.0, -1.0, 3.0, 2.0_f64];
    let d_seed = [1.0_f64];

    let (sched, mut arena) = prepare_f64(
        &bwd,
        &[(a_bwd, &a_data), (b_bwd, &b_data), (d_out, &d_seed)],
    );
    execute_thunks(&sched, arena.raw_buf_mut());
    let da_got = read_arena_f64(&arena, bwd.outputs[1], n * n);
    let db_got = read_arena_f64(&arena, bwd.outputs[2], n * k);

    // Solves A · X = rhs on copies; `dgesv` leaves the solution in `rhs`.
    let solve_with_a = |mut rhs: [f64; 6]| -> [f64; 6] {
        let mut a_copy = a_data;
        crate::blas::dgesv(&mut a_copy, &mut rhs, n, k);
        rhs
    };

    // Forward solution X.
    let x_ref = solve_with_a(b_data);

    // dB reference: solve against the transpose with an all-ones right side.
    let mut at = [0.0_f64; 9];
    for (flat, slot) in at.iter_mut().enumerate() {
        let (i, j) = (flat / n, flat % n);
        *slot = a_data[j * n + i];
    }
    let mut db_ref = vec![1.0_f64; n * k];
    crate::blas::dgesv(&mut at, &mut db_ref, n, k);

    // dA reference: negated product of dB with the transpose of X.
    let mut da_ref = [0.0_f64; 9];
    for (flat, slot) in da_ref.iter_mut().enumerate() {
        let (i, j) = (flat / n, flat % n);
        let dot = (0..k).fold(0.0_f64, |partial, c| {
            partial + db_ref[i * k + c] * x_ref[j * k + c]
        });
        *slot = -dot;
    }

    for i in 0..n * k {
        let (got, expected) = (db_got[i], db_ref[i]);
        assert!(
            (got - expected).abs() < 1e-10,
            "dB[{i}]: got {} want {}",
            got,
            expected
        );
    }
    for i in 0..n * n {
        let (got, expected) = (da_got[i], da_ref[i]);
        assert!(
            (got - expected).abs() < 1e-10,
            "dA[{i}]: got {} want {}",
            got,
            expected
        );
    }

    // Central finite difference on B[0, 0].
    let h = 1e-6;
    let loss_with_b00_shift = |delta: f64| -> f64 {
        let mut shifted = b_data;
        shifted[0] += delta;
        solve_with_a(shifted).iter().sum()
    };
    let lp: f64 = loss_with_b00_shift(h);
    let lm: f64 = loss_with_b00_shift(-h);
    let fd = (lp - lm) / (2.0 * h);
    assert!(
        (db_got[0] - fd).abs() < 1e-7,
        "FD dB[0,0]: AD={} FD={}",
        db_got[0],
        fd
    );
}
17475
/// JVP of `dense_solve` with respect to a multi-column right-hand side.
///
/// Builds `X = dense_solve(A, B)` with a 3x3 `A` and a 3x2 `B`, takes the
/// forward-mode derivative with respect to `B` only, and checks the tangent
/// output (`jg.outputs[1]`) two ways:
///
/// 1. against a direct `dgesv` solve that uses the tangent of `B` as the
///    right-hand side, and
/// 2. against a central finite difference of the solve along that tangent.
///
/// Every reference `dgesv` call must report `info == 0`; otherwise the
/// comparison would run against an invalid reference.
#[test]
fn dense_solve_f64_multi_rhs_jvp() {
    use rlx_opt::autodiff_fwd::jvp;
    let n = 3usize;
    let k = 2usize;
    let mut g = Graph::new("solve_mrhs_jvp");
    let a = g.input("A", Shape::new(&[n, n], DType::F64));
    let b = g.input("B", Shape::new(&[n, k], DType::F64));
    let x = g.dense_solve(a, b, Shape::new(&[n, k], DType::F64));
    g.set_outputs(vec![x]);

    let jg = jvp(&g, &[b]);
    // Inputs are located by name in the derived graph instead of reusing the
    // ids handed out while building `g`.
    let find_by_name = |graph: &Graph, want: &str| -> NodeId {
        graph
            .nodes()
            .iter()
            .find(|node| {
                matches!(
                    &node.op,
                    Op::Input { name } | Op::Param { name } if name.as_str() == want
                )
            })
            .map(|node| node.id)
            .unwrap_or_else(|| panic!("no node named {want:?}"))
    };
    let a_id = find_by_name(&jg, "A");
    let b_id = find_by_name(&jg, "B");
    let tb_id = find_by_name(&jg, "tangent_B");

    let a_data = [2.0, 1.0, 0.0, 1.0, 3.0, 1.0, 0.0, 1.0, 2.0_f64];
    let b_data = [1.0, 4.0, 2.0, -1.0, 3.0, 2.0_f64];
    let tb_data = [0.5, 0.0, -0.25, 1.0, 1.0, -0.5_f64];

    let (sched, mut arena) =
        prepare_f64(&jg, &[(a_id, &a_data), (b_id, &b_data), (tb_id, &tb_data)]);
    execute_thunks(&sched, arena.raw_buf_mut());
    let tangent_x = read_arena_f64(&arena, jg.outputs[1], n * k);

    // Reference solve against the unperturbed `A`. `dgesv` takes both buffers
    // mutably, so it runs on copies; the result is read back from the
    // right-hand-side buffer.
    let solve = |rhs: [f64; 6]| -> [f64; 6] {
        let mut lhs = a_data;
        let mut sol = rhs;
        let info = crate::blas::dgesv(&mut lhs, &mut sol, n, k);
        assert_eq!(info, 0, "reference dgesv failed");
        sol
    };

    // Check 1: the tangent equals the solve with the tangent of B as rhs.
    let t_ref = solve(tb_data);
    for i in 0..n * k {
        assert!(
            (tangent_x[i] - t_ref[i]).abs() < 1e-10,
            "t_X[{i}]: AD={} ref={}",
            tangent_x[i],
            t_ref[i]
        );
    }

    // Check 2: central finite difference along the tangent direction.
    let h = 1e-6;
    let mut bp = b_data;
    let mut bm = b_data;
    for ((plus, minus), t) in bp.iter_mut().zip(bm.iter_mut()).zip(tb_data.iter()) {
        *plus += h * *t;
        *minus -= h * *t;
    }
    let xp = solve(bp);
    let xm = solve(bm);
    for i in 0..n * k {
        let fd = (xp[i] - xm[i]) / (2.0 * h);
        assert!(
            (tangent_x[i] - fd).abs() < 1e-7,
            "FD t_X[{i}]: AD={} FD={}",
            tangent_x[i],
            fd
        );
    }
}
17553
/// End-to-end JVP of `dense_solve` with respect to the right-hand side `b`.
///
/// Builds `x = dense_solve(A, b)` for a 3x3 system, differentiates in forward
/// mode with respect to `b`, executes the derived graph and checks:
///
/// * the tangent output (`jg.outputs[1]`) against a `dgesv` solve that uses
///   the tangent of `b` as right-hand side,
/// * the same tangent against a central finite difference along that tangent,
/// * the primal output (`jg.outputs[0]`) against a plain `dgesv` solve.
///
/// All reference solves, including the primal one, must return `info == 0`.
#[test]
fn jvp_dense_solve_b_runs_and_matches_fd() {
    use rlx_opt::autodiff_fwd::jvp;
    let n = 3usize;

    let mut g = Graph::new("jvp_b_e2e");
    let a = g.input("A", Shape::new(&[n, n], DType::F64));
    let b = g.input("b", Shape::new(&[n], DType::F64));
    let x = g.dense_solve(a, b, Shape::new(&[n], DType::F64));
    g.set_outputs(vec![x]);

    let jg = jvp(&g, &[b]);
    // Inputs are located by name in the derived graph instead of reusing the
    // ids handed out while building `g`.
    let find_by_name = |graph: &Graph, want: &str| -> NodeId {
        graph
            .nodes()
            .iter()
            .find(|node| {
                matches!(
                    &node.op,
                    Op::Input { name } | Op::Param { name } if name.as_str() == want
                )
            })
            .map(|node| node.id)
            .unwrap_or_else(|| panic!("no node named {want:?}"))
    };
    let a_id = find_by_name(&jg, "A");
    let b_id = find_by_name(&jg, "b");
    let tb_id = find_by_name(&jg, "tangent_b");

    let a_data: [f64; 9] = [2.0, 1.0, 0.0, 1.0, 3.0, 1.0, 0.0, 1.0, 2.0];
    let b_data: [f64; 3] = [1.0, 2.0, 3.0];
    let tb_data: [f64; 3] = [0.5, -0.25, 1.0];

    let (sched, mut arena) =
        prepare_f64(&jg, &[(a_id, &a_data), (b_id, &b_data), (tb_id, &tb_data)]);
    execute_thunks(&sched, arena.raw_buf_mut());

    let primal_x = read_arena_f64(&arena, jg.outputs[0], n);
    let tangent_x = read_arena_f64(&arena, jg.outputs[1], n);

    // Reference solve against `A`. `dgesv` takes both buffers mutably, so it
    // runs on copies; the result is read back from the right-hand-side buffer.
    let solve = |rhs: [f64; 3]| -> [f64; 3] {
        let mut lhs = a_data;
        let mut sol = rhs;
        let info = crate::blas::dgesv(&mut lhs, &mut sol, n, 1);
        assert_eq!(info, 0, "reference dgesv failed");
        sol
    };

    // Tangent vs. direct solve with the tangent of b as rhs.
    let t_x_ref = solve(tb_data);
    for i in 0..n {
        assert!(
            (tangent_x[i] - t_x_ref[i]).abs() < 1e-10,
            "t_x[{i}]: got {} want {}",
            tangent_x[i],
            t_x_ref[i]
        );
    }

    // Tangent vs. central finite difference along the tangent direction.
    let h = 1e-6;
    let mut bp = b_data;
    let mut bm = b_data;
    for ((plus, minus), t) in bp.iter_mut().zip(bm.iter_mut()).zip(tb_data.iter()) {
        *plus += h * *t;
        *minus -= h * *t;
    }
    let xp = solve(bp);
    let xm = solve(bm);
    let fd: Vec<f64> = xp
        .iter()
        .zip(xm.iter())
        .map(|(p, m)| (p - m) / (2.0 * h))
        .collect();
    for i in 0..n {
        assert!(
            (tangent_x[i] - fd[i]).abs() < 1e-7,
            "FD mismatch t_x[{i}]: AD={} FD={}",
            tangent_x[i],
            fd[i]
        );
    }

    // The primal output of the JVP graph must still be the plain solution.
    let primal_ref = solve(b_data);
    for i in 0..n {
        assert!((primal_x[i] - primal_ref[i]).abs() < 1e-10);
    }
}
17661
/// End-to-end JVP of `dense_solve` with respect to the matrix `A`.
///
/// Builds `x = dense_solve(A, b)` for a 3x3 system, differentiates in forward
/// mode with respect to `A`, executes the derived graph and checks the tangent
/// output (`jg.outputs[1]`) against:
///
/// * the closed form computed here as `-(solve(A, tangent_A · x))`, and
/// * a central finite difference of the solve with `A` perturbed along
///   `tangent_A`.
///
/// Every reference `dgesv` call must return `info == 0`, matching the check
/// done in the companion test for `b`.
#[test]
fn jvp_dense_solve_a_runs_and_matches_fd() {
    use rlx_opt::autodiff_fwd::jvp;
    let n = 3usize;

    let mut g = Graph::new("jvp_a_e2e");
    let a = g.input("A", Shape::new(&[n, n], DType::F64));
    let b = g.input("b", Shape::new(&[n], DType::F64));
    let x = g.dense_solve(a, b, Shape::new(&[n], DType::F64));
    g.set_outputs(vec![x]);

    let jg = jvp(&g, &[a]);
    // Inputs are located by name in the derived graph instead of reusing the
    // ids handed out while building `g`.
    let find_by_name = |graph: &Graph, want: &str| -> NodeId {
        graph
            .nodes()
            .iter()
            .find(|node| {
                matches!(
                    &node.op,
                    Op::Input { name } | Op::Param { name } if name.as_str() == want
                )
            })
            .map(|node| node.id)
            .unwrap_or_else(|| panic!("no node named {want:?}"))
    };
    let a_id = find_by_name(&jg, "A");
    let b_id = find_by_name(&jg, "b");
    let ta_id = find_by_name(&jg, "tangent_A");

    let a_data: [f64; 9] = [2.0, 1.0, 0.0, 1.0, 3.0, 1.0, 0.0, 1.0, 2.0];
    let b_data: [f64; 3] = [1.0, 2.0, 3.0];
    let ta_data: [f64; 9] = [0.10, -0.05, 0.02, 0.03, 0.20, -0.04, -0.01, 0.07, 0.15];

    let (sched, mut arena) =
        prepare_f64(&jg, &[(a_id, &a_data), (b_id, &b_data), (ta_id, &ta_data)]);
    execute_thunks(&sched, arena.raw_buf_mut());

    let tangent_x = read_arena_f64(&arena, jg.outputs[1], n);

    // Reference solve for an arbitrary matrix/right-hand-side pair. `dgesv`
    // takes both buffers mutably, so it runs on copies; the result is read
    // back from the right-hand-side buffer.
    let solve = |matrix: [f64; 9], rhs: [f64; 3]| -> [f64; 3] {
        let mut lhs = matrix;
        let mut sol = rhs;
        let info = crate::blas::dgesv(&mut lhs, &mut sol, n, 1);
        assert_eq!(info, 0, "reference dgesv failed");
        sol
    };

    // Closed form: solve for x, form tangent_A · x, solve again and negate.
    let x_ref = solve(a_data, b_data);
    let mut prod = [0.0_f64; 3];
    for (i, slot) in prod.iter_mut().enumerate() {
        for j in 0..n {
            *slot += ta_data[i * n + j] * x_ref[j];
        }
    }
    let t_x_ref = {
        let p = solve(a_data, prod);
        [-p[0], -p[1], -p[2]]
    };
    for i in 0..n {
        assert!(
            (tangent_x[i] - t_x_ref[i]).abs() < 1e-10,
            "closed-form t_x[{i}]: AD={} ref={}",
            tangent_x[i],
            t_x_ref[i]
        );
    }

    // Central finite difference with A perturbed along tangent_A.
    let h = 1e-6;
    let mut ap = a_data;
    let mut am = a_data;
    for ((plus, minus), t) in ap.iter_mut().zip(am.iter_mut()).zip(ta_data.iter()) {
        *plus += h * *t;
        *minus -= h * *t;
    }
    let xp = solve(ap, b_data);
    let xm = solve(am, b_data);
    for i in 0..n {
        let fd = (xp[i] - xm[i]) / (2.0 * h);
        assert!(
            (tangent_x[i] - fd).abs() < 1e-7,
            "FD t_x[{i}]: AD={} FD={}",
            tangent_x[i],
            fd
        );
    }
}
17764
/// Bit-exact check of the quantized 2-D convolution against an integer reference.
///
/// Random f32 activations and weights are quantized to i8 with fixed scales,
/// run through a single `q_conv2d` node (3x3 kernel, stride 1, padding 1, zero
/// bias), and the i8 output is compared element by element with a nested-loop
/// convolution that accumulates in i32 and requantizes with
/// `mult = x_scale * w_scale / out_scale`.
#[test]
fn q_conv2d_matches_reference() {
    use rlx_ir::Philox4x32;
    let n = 1usize;
    let c_in = 2usize;
    let h = 5usize;
    let w_in = 5usize;
    let c_out = 3usize;
    let (kh, kw) = (3usize, 3usize);
    let (ph, pw) = (1usize, 1usize);
    let (sh, sw) = (1usize, 1usize);
    let h_out = (h + 2 * ph - kh) / sh + 1;
    let w_out = (w_in + 2 * pw - kw) / sw + 1;
    let out_len = n * c_out * h_out * w_out;

    let x_scale = 0.04f32;
    let w_scale = 0.02f32;
    let out_scale = 0.5f32;
    let mult = x_scale * w_scale / out_scale;

    // Symmetric quantization: divide by the scale, round, saturate to i8.
    let quantize = |vals: &[f32], scale: f32| -> Vec<i8> {
        vals.iter()
            .map(|&v| ((v / scale).round() as i32).clamp(-128, 127) as i8)
            .collect()
    };

    let mut rng = Philox4x32::new(2099);
    let mut xf = vec![0f32; n * c_in * h * w_in];
    rng.fill_normal(&mut xf);
    let mut wf = vec![0f32; c_out * c_in * kh * kw];
    rng.fill_normal(&mut wf);
    let xq: Vec<i8> = quantize(&xf, x_scale);
    let wq: Vec<i8> = quantize(&wf, w_scale);
    let bias: Vec<i32> = vec![0i32; c_out];

    let mut g = Graph::new("qconv");
    let xn = g.input("x", Shape::new(&[n, c_in, h, w_in], DType::I8));
    let wn = g.input("w", Shape::new(&[c_out, c_in, kh, kw], DType::I8));
    let bn = g.input("b", Shape::new(&[c_out], DType::I32));
    // NOTE(review): the positional arguments after the padding vector are
    // presumably dilation, groups and the three zero points — confirm against
    // `Graph::q_conv2d`.
    let out = g.q_conv2d(
        xn,
        wn,
        bn,
        vec![kh, kw],
        vec![sh, sw],
        vec![ph, pw],
        vec![1, 1],
        1,
        0,
        0,
        0,
        mult,
        Shape::new(&[n, c_out, h_out, w_out], DType::I8),
    );
    g.set_outputs(vec![out]);

    let plan = rlx_opt::memory::plan_memory(&g);
    let mut arena = crate::arena::Arena::from_plan(plan);
    let sched = compile_thunks(&g, &arena);
    let xn_off = arena.byte_offset(xn);
    let wn_off = arena.byte_offset(wn);
    let bn_off = arena.byte_offset(bn);
    let out_off = arena.byte_offset(out);
    let buf = arena.raw_buf_mut();
    // SAFETY: relies on the memory plan reserving room for every input at its
    // byte offset (suitably aligned for the i32 bias) — not visible from this
    // test; confirm against `Arena`.
    unsafe {
        let base = buf.as_mut_ptr();
        std::ptr::copy_nonoverlapping(xq.as_ptr(), base.add(xn_off) as *mut i8, xq.len());
        std::ptr::copy_nonoverlapping(wq.as_ptr(), base.add(wn_off) as *mut i8, wq.len());
        std::ptr::copy_nonoverlapping(bias.as_ptr(), base.add(bn_off) as *mut i32, bias.len());
    }
    execute_thunks(&sched, arena.raw_buf_mut());
    // SAFETY: same assumption as above, for the `out_len` output bytes.
    let out_q: Vec<i8> = unsafe {
        let p = arena.raw_buf().as_ptr().add(out_off) as *const i8;
        std::slice::from_raw_parts(p, out_len).to_vec()
    };

    // Reference convolution. Results are appended in (n, c_out, h_out, w_out)
    // order, which is the flat index order of the output tensor.
    let mut out_ref: Vec<i8> = Vec::with_capacity(out_len);
    for ni in 0..n {
        for co in 0..c_out {
            for ho in 0..h_out {
                for wo in 0..w_out {
                    let mut acc: i32 = 0;
                    for ci in 0..c_in {
                        for ki in 0..kh {
                            // Input row for this tap; taps that land in the
                            // padding contribute nothing.
                            let hi = match (ho * sh + ki).checked_sub(ph) {
                                Some(v) if v < h => v,
                                _ => continue,
                            };
                            for kj in 0..kw {
                                let wi = match (wo * sw + kj).checked_sub(pw) {
                                    Some(v) if v < w_in => v,
                                    _ => continue,
                                };
                                let x_idx = ((ni * c_in) + ci) * h * w_in + hi * w_in + wi;
                                let w_idx = ((co * c_in) + ci) * kh * kw + ki * kw + kj;
                                acc += (xq[x_idx] as i32) * (wq[w_idx] as i32);
                            }
                        }
                    }
                    let scaled = (acc as f32 * mult).round() as i32;
                    out_ref.push(scaled.clamp(-128, 127) as i8);
                }
            }
        }
    }

    for (i, (a, r)) in out_q.iter().zip(&out_ref).enumerate() {
        assert_eq!(a, r, "q_conv2d[{i}]: kernel {a} vs reference {r}");
    }
}
17898
/// Bit-exact check of the quantized matmul against an integer reference.
///
/// Random f32 operands are quantized to i8 with fixed scales, multiplied by a
/// single `q_matmul` node (3x8 times 8x5, zero bias) and compared element by
/// element with a triple-loop product that accumulates in i32 and requantizes
/// with `mult = x_scale * w_scale / out_scale`.
#[test]
fn q_matmul_matches_fake_quant_reference() {
    use rlx_ir::Philox4x32;
    let m = 3usize;
    let k = 8usize;
    let n = 5usize;
    let mut rng = Philox4x32::new(2031);

    let x_scale = 0.05f32;
    let w_scale = 0.03f32;
    let out_scale = 0.4f32;
    let mult = x_scale * w_scale / out_scale;

    // Symmetric quantization: divide by the scale, round, saturate to i8.
    let quantize = |vals: &[f32], scale: f32| -> Vec<i8> {
        vals.iter()
            .map(|&v| ((v / scale).round() as i32).clamp(-128, 127) as i8)
            .collect()
    };

    let mut xf = vec![0f32; m * k];
    rng.fill_normal(&mut xf);
    let mut wf = vec![0f32; k * n];
    rng.fill_normal(&mut wf);
    let xq: Vec<i8> = quantize(&xf, x_scale);
    let wq: Vec<i8> = quantize(&wf, w_scale);
    let bias: Vec<i32> = vec![0i32; n];

    let mut g_q = Graph::new("qmm_direct");
    let xn = g_q.input("x", Shape::new(&[m, k], DType::I8));
    let wn = g_q.input("w", Shape::new(&[k, n], DType::I8));
    let bn = g_q.input("b", Shape::new(&[n], DType::I32));
    // NOTE(review): the three zeros are presumably the zero points of x, w and
    // the output — confirm against `Graph::q_matmul`.
    let out = g_q.q_matmul(xn, wn, bn, 0, 0, 0, mult, Shape::new(&[m, n], DType::I8));
    g_q.set_outputs(vec![out]);
    let plan = rlx_opt::memory::plan_memory(&g_q);
    let mut arena = crate::arena::Arena::from_plan(plan);
    let sched = compile_thunks(&g_q, &arena);

    let xn_off = arena.byte_offset(xn);
    let wn_off = arena.byte_offset(wn);
    let bn_off = arena.byte_offset(bn);
    let out_off = arena.byte_offset(out);
    let buf = arena.raw_buf_mut();
    // SAFETY: relies on the memory plan reserving room for every input at its
    // byte offset (suitably aligned for the i32 bias) — not visible from this
    // test; confirm against `Arena`.
    unsafe {
        let base = buf.as_mut_ptr();
        std::ptr::copy_nonoverlapping(xq.as_ptr(), base.add(xn_off) as *mut i8, xq.len());
        std::ptr::copy_nonoverlapping(wq.as_ptr(), base.add(wn_off) as *mut i8, wq.len());
        std::ptr::copy_nonoverlapping(bias.as_ptr(), base.add(bn_off) as *mut i32, bias.len());
    }
    execute_thunks(&sched, arena.raw_buf_mut());
    // SAFETY: same assumption as above, for the `m * n` output bytes.
    let out_q: Vec<i8> = unsafe {
        let p = arena.raw_buf().as_ptr().add(out_off) as *const i8;
        std::slice::from_raw_parts(p, m * n).to_vec()
    };

    // Integer reference, appended in (row, column) order.
    let mut out_ref: Vec<i8> = Vec::with_capacity(m * n);
    for mi in 0..m {
        for ni in 0..n {
            let acc: i32 = (0..k)
                .map(|ki| (xq[mi * k + ki] as i32) * (wq[ki * n + ni] as i32))
                .sum();
            let scaled = (acc as f32 * mult).round() as i32;
            out_ref.push(scaled.clamp(-128, 127) as i8);
        }
    }

    for (i, (a, r)) in out_q.iter().zip(&out_ref).enumerate() {
        assert_eq!(a, r, "q_matmul[{i}]: kernel {a} vs reference {r}");
    }
}
17991
/// Round trip through per-tensor `quantize` followed by `dequantize`.
///
/// The first two inputs are forced to +/-999 so they must saturate; their
/// dequantized values are compared with `(127 - zp) * scale` and
/// `(-128 - zp) * scale`. Every remaining (normally distributed) element must
/// come back within one quantization step (`scale`, plus a small epsilon).
#[test]
fn quantize_dequantize_round_trip() {
    use rlx_ir::Philox4x32;
    let len = 64usize;
    let mut rng = Philox4x32::new(2027);
    let mut x = vec![0f32; len];
    rng.fill_normal(&mut x);
    // Out-of-range probes for both saturation directions.
    x[0] = 999.0;
    x[1] = -999.0;

    let scale = 0.05f32;
    let zp = 3i32;

    let f = DType::F32;
    let mut g = Graph::new("qdq");
    let xn = g.input("x", Shape::new(&[len], f));
    let q = g.quantize(xn, scale, zp);
    let dq = g.dequantize(q, scale, zp);
    g.set_outputs(vec![dq]);

    let plan = rlx_opt::memory::plan_memory(&g);
    let mut arena = crate::arena::Arena::from_plan(plan);
    let sched = compile_thunks(&g, &arena);
    let xn_off = arena.byte_offset(xn);
    let dq_off = arena.byte_offset(dq);
    let buf = arena.raw_buf_mut();
    // SAFETY: relies on the memory plan reserving `len` aligned f32 slots at
    // each offset — not visible from this test; confirm against `Arena`.
    unsafe {
        std::ptr::copy_nonoverlapping(
            x.as_ptr(),
            buf.as_mut_ptr().add(xn_off) as *mut f32,
            x.len(),
        );
    }
    execute_thunks(&sched, arena.raw_buf_mut());
    // SAFETY: same assumption as above, for the output slot.
    let out: Vec<f32> = unsafe {
        let p = arena.raw_buf().as_ptr().add(dq_off) as *const f32;
        std::slice::from_raw_parts(p, len).to_vec()
    };

    // Saturated probes dequantize to the ends of the representable range.
    let sat_pos = (127 - zp) as f32 * scale;
    let sat_neg = (-128 - zp) as f32 * scale;
    assert!((out[0] - sat_pos).abs() < 1e-6, "+sat: {}", out[0]);
    assert!((out[1] - sat_neg).abs() < 1e-6, "-sat: {}", out[1]);

    // Everything else must survive the round trip to within one step.
    for i in 2..len {
        let err = (out[i] - x[i]).abs();
        assert!(
            err <= scale + 1e-5,
            "qdq[{i}]: {} → {}, scale={scale}",
            x[i],
            out[i]
        );
    }
}
18054
/// Round trip through `quantize_per_channel` / `dequantize_per_channel` on axis 0.
///
/// Four channels with magnitudes spanning 0.01 to 50 each get their own scale
/// (`mag / 127`, zero point 0). Per channel the five inputs are
/// `-mag, 0, +mag, +1000*mag, -1000*mag`: the first three must round-trip to
/// within one step of that channel's scale, the last two must saturate to
/// `127 * scale` and `-128 * scale`.
#[test]
fn quantize_per_channel_round_trip() {
    let c = 4usize;
    let inner = 5usize;
    let mags = [0.01f32, 0.5, 5.0, 50.0];
    // Channel-major layout: element (ci, ii) lives at `ci * inner + ii`.
    let x: Vec<f32> = (0..c)
        .flat_map(|ci| {
            (0..inner).map(move |ii| match ii {
                0 => -mags[ci],
                1 => 0.0,
                2 => mags[ci],
                3 => mags[ci] * 1000.0,
                _ => -mags[ci] * 1000.0,
            })
        })
        .collect();
    let scales: Vec<f32> = mags.iter().map(|&m| m / 127.0).collect();
    let zps: Vec<i32> = vec![0; 4];

    let f = DType::F32;
    let mut g = Graph::new("qdq_pc");
    let xn = g.input("x", Shape::new(&[c, inner], f));
    let q = g.quantize_per_channel(xn, 0, scales.clone(), zps.clone());
    let dq = g.dequantize_per_channel(q, 0, scales.clone(), zps);
    g.set_outputs(vec![dq]);

    let plan = rlx_opt::memory::plan_memory(&g);
    let mut arena = crate::arena::Arena::from_plan(plan);
    let sched = compile_thunks(&g, &arena);
    let xn_off = arena.byte_offset(xn);
    let dq_off = arena.byte_offset(dq);
    let buf = arena.raw_buf_mut();
    // SAFETY: relies on the memory plan reserving `c * inner` aligned f32
    // slots at each offset — not visible from this test; confirm against `Arena`.
    unsafe {
        std::ptr::copy_nonoverlapping(
            x.as_ptr(),
            buf.as_mut_ptr().add(xn_off) as *mut f32,
            x.len(),
        );
    }
    execute_thunks(&sched, arena.raw_buf_mut());
    // SAFETY: same assumption as above, for the output slot.
    let out: Vec<f32> = unsafe {
        let p = arena.raw_buf().as_ptr().add(dq_off) as *const f32;
        std::slice::from_raw_parts(p, c * inner).to_vec()
    };

    for ci in 0..c {
        let base = ci * inner;
        // In-range values: error bounded by this channel's own step.
        for ii in 0..3 {
            let idx = base + ii;
            assert!(
                (out[idx] - x[idx]).abs() <= scales[ci] + 1e-5,
                "ch {ci} idx {ii}: {} vs {}",
                x[idx],
                out[idx]
            );
        }
        // Out-of-range values: clamp to the i8 limits times the scale.
        let sat_pos = 127.0 * scales[ci];
        let sat_neg = -128.0 * scales[ci];
        assert!(
            (out[base + 3] - sat_pos).abs() < 1e-5,
            "ch {ci} +sat: {}",
            out[base + 3]
        );
        assert!(
            (out[base + 4] - sat_neg).abs() < 1e-5,
            "ch {ci} -sat: {}",
            out[base + 4]
        );
    }
}
18137
/// Compares `activation_backward` with a central finite difference, per kind.
///
/// For each listed activation a one-node graph `dx = activation_backward(kind, x, dy)`
/// is compiled and run on 32 random elements. The result is compared with
/// `(f(x + eps) - f(x - eps)) / (2 * eps) * dy`, where `f` is a scalar
/// reference implementation defined in this test. `Log`, `Sqrt` and `Rsqrt`
/// are fed inputs shifted to be at least 0.5.
#[test]
fn activation_backward_matches_numerical_per_kind() {
    use rlx_ir::Philox4x32;
    use rlx_ir::op::Activation;
    let mut rng = Philox4x32::new(91);
    let len = 32;

    // Draw order matters for reproducibility: x_pos, then x_any, then dy.
    let mut x_pos = vec![0f32; len];
    rng.fill_normal(&mut x_pos);
    x_pos.iter_mut().for_each(|v| *v = v.abs() + 0.5);
    let mut x_any = vec![0f32; len];
    rng.fill_normal(&mut x_any);
    let mut dy = vec![0f32; len];
    rng.fill_normal(&mut dy);

    // Scalar forward reference. The match covers every `Activation` variant
    // even though only the kinds listed in `cases` are exercised.
    let act_apply = |kind: Activation, x: f32| -> f32 {
        match kind {
            Activation::Sigmoid => 1.0 / (1.0 + (-x).exp()),
            Activation::Tanh => x.tanh(),
            Activation::Silu => x / (1.0 + (-x).exp()),
            Activation::Gelu => {
                const INV_SQRT2: f32 = 0.707_106_77;
                0.5 * x * (1.0 + erf_f32(x * INV_SQRT2))
            }
            Activation::GeluApprox => {
                const C: f32 = 0.797_884_6;
                const A: f32 = 0.044_715;
                let inner = C * (x + A * x * x * x);
                0.5 * x * (1.0 + inner.tanh())
            }
            Activation::Exp => x.exp(),
            Activation::Log => x.ln(),
            Activation::Sqrt => x.sqrt(),
            Activation::Rsqrt => 1.0 / x.sqrt(),
            Activation::Neg => -x,
            Activation::Relu => x.max(0.0),
            Activation::Abs => x.abs(),
            Activation::Round => x.round(),
            Activation::Sin => x.sin(),
            Activation::Cos => x.cos(),
            Activation::Tan => x.tan(),
            Activation::Atan => x.atan(),
        }
    };

    // (kind, input data, finite-difference step, absolute tolerance)
    let cases: [(Activation, &[f32], f32, f32); 10] = [
        (Activation::Sigmoid, &x_any[..], 1e-3, 5e-3),
        (Activation::Tanh, &x_any[..], 1e-3, 5e-3),
        (Activation::Silu, &x_any[..], 1e-3, 5e-3),
        (Activation::Gelu, &x_any[..], 1e-3, 5e-3),
        (Activation::GeluApprox, &x_any[..], 1e-3, 5e-3),
        (Activation::Exp, &x_any[..], 1e-4, 5e-3),
        (Activation::Log, &x_pos[..], 1e-4, 5e-3),
        (Activation::Sqrt, &x_pos[..], 1e-4, 5e-3),
        (Activation::Rsqrt, &x_pos[..], 1e-4, 5e-3),
        (Activation::Neg, &x_any[..], 1e-3, 5e-4),
    ];

    for &(kind, x_data, eps, tol) in cases.iter() {
        let f = DType::F32;
        let mut g = Graph::new("act_bw");
        let xn = g.input("x", Shape::new(&[len], f));
        let dyn_ = g.input("dy", Shape::new(&[len], f));
        let dx = g.activation_backward(kind, xn, dyn_);
        g.set_outputs(vec![dx]);

        let plan = rlx_opt::memory::plan_memory(&g);
        let mut arena = crate::arena::Arena::from_plan(plan);
        let sched = compile_thunks(&g, &arena);

        let xn_off = arena.byte_offset(xn);
        let dyn_off = arena.byte_offset(dyn_);
        let dx_off = arena.byte_offset(dx);
        let buf = arena.raw_buf_mut();
        // SAFETY: relies on the memory plan reserving `len` aligned f32 slots
        // at each offset — not visible from this test; confirm against `Arena`.
        unsafe {
            let base = buf.as_mut_ptr();
            std::ptr::copy_nonoverlapping(
                x_data.as_ptr(),
                base.add(xn_off) as *mut f32,
                x_data.len(),
            );
            std::ptr::copy_nonoverlapping(dy.as_ptr(), base.add(dyn_off) as *mut f32, dy.len());
        }
        execute_thunks(&sched, arena.raw_buf_mut());
        // SAFETY: same assumption as above, for the output slot.
        let analytical: Vec<f32> = unsafe {
            let p = arena.raw_buf().as_ptr().add(dx_off) as *const f32;
            std::slice::from_raw_parts(p, len).to_vec()
        };

        for i in 0..len {
            let xv = x_data[i];
            let plus = act_apply(kind, xv + eps);
            let minus = act_apply(kind, xv - eps);
            // Chain rule: local slope times the incoming gradient.
            let num = (plus - minus) / (2.0 * eps) * dy[i];
            assert!(
                (analytical[i] - num).abs() < tol,
                "{kind:?}[{i}]: analytical {} vs numerical {num}",
                analytical[i]
            );
        }
    }
}
18251
/// Gradient of a batched (3-D) matmul with respect to its right operand.
///
/// Forward graph: `loss = sum(matmul(a, b))` with `a: [2, 3, 4]` and
/// `b: [2, 4, 5]`. The reverse-mode graph from `grad_with_loss` is executed
/// with `d_output = 1`, and its second output (read here as the gradient of
/// `b`) is compared with a central finite difference of a plain nested-loop
/// implementation of the same loss.
#[test]
fn matmul_3d_gradient_matches_numerical() {
    use rlx_ir::Philox4x32;
    let batch = 2usize;
    let m = 3usize;
    let k = 4usize;
    let n = 5usize;
    let mut rng = Philox4x32::new(101);
    let mut a_data = vec![0f32; batch * m * k];
    rng.fill_normal(&mut a_data);
    let mut b_data = vec![0f32; batch * k * n];
    rng.fill_normal(&mut b_data);

    let f = DType::F32;
    let mut fwd = Graph::new("matmul_3d");
    let an = fwd.input("a", Shape::new(&[batch, m, k], f));
    let bp = fwd.param("b", Shape::new(&[batch, k, n], f));
    let mm = fwd.matmul(an, bp, Shape::new(&[batch, m, n], f));
    let sum_all = Op::Reduce {
        op: ReduceOp::Sum,
        axes: vec![0, 1, 2],
        keep_dim: false,
    };
    let loss = fwd.add_node(sum_all, vec![mm], Shape::from_dims(&[], f));
    fwd.set_outputs(vec![loss]);

    let bwd_graph = rlx_opt::autodiff::grad_with_loss(&fwd, &[bp]);
    // The seed gradient is an extra input of the backward graph.
    let d_out = bwd_graph
        .nodes()
        .iter()
        .find(|node| matches!(&node.op, Op::Input { name } if name == "d_output"))
        .map(|node| node.id)
        .unwrap();

    let plan = rlx_opt::memory::plan_memory(&bwd_graph);
    let mut arena = crate::arena::Arena::from_plan(plan);
    let sched = compile_thunks(&bwd_graph, &arena);
    let seed = [1.0f32];
    let uploads = [(an, &a_data[..]), (bp, &b_data[..]), (d_out, &seed[..])];
    for &(id, data) in uploads.iter() {
        let off = arena.byte_offset(id);
        let buf = arena.raw_buf_mut();
        // SAFETY: relies on the memory plan reserving `data.len()` aligned f32
        // slots at `off` — not visible from this test; confirm against `Arena`.
        unsafe {
            std::ptr::copy_nonoverlapping(
                data.as_ptr(),
                buf.as_mut_ptr().add(off) as *mut f32,
                data.len(),
            );
        }
    }
    execute_thunks(&sched, arena.raw_buf_mut());
    let gb_id = bwd_graph.outputs[1];
    // SAFETY: same assumption as above, for the gradient slot.
    let g_b: Vec<f32> = unsafe {
        let p = arena.raw_buf().as_ptr().add(arena.byte_offset(gb_id)) as *const f32;
        std::slice::from_raw_parts(p, batch * k * n).to_vec()
    };

    // Scalar reference of the forward loss as a function of `b` only.
    let forward_loss = |b_vals: &[f32]| -> f32 {
        let mut out = Vec::with_capacity(batch * m * n);
        for bi in 0..batch {
            for mi in 0..m {
                for ni in 0..n {
                    let mut acc = 0f32;
                    for ki in 0..k {
                        let lhs = a_data[bi * m * k + mi * k + ki];
                        let rhs = b_vals[bi * k * n + ki * n + ni];
                        acc += lhs * rhs;
                    }
                    out.push(acc);
                }
            }
        }
        out.iter().sum()
    };

    // Central difference, perturbing one element of `b` at a time.
    let eps = 1e-3f32;
    let mut bp_p = b_data.clone();
    let g_b_num: Vec<f32> = (0..b_data.len())
        .map(|i| {
            let s = bp_p[i];
            bp_p[i] = s + eps;
            let lp = forward_loss(&bp_p);
            bp_p[i] = s - eps;
            let lm = forward_loss(&bp_p);
            bp_p[i] = s;
            (lp - lm) / (2.0 * eps)
        })
        .collect();
    for (i, (a, n)) in g_b.iter().zip(&g_b_num).enumerate() {
        assert!(
            (a - n).abs() < 5e-3,
            "matmul_3d g_b[{i}]: analytical {a} vs numerical {n}"
        );
    }
}
18348
/// Gradient of `sum(softmax(x))` with respect to `x`, checked numerically.
///
/// Forward graph: softmax over the last axis of a `[3, 5]` input, then a sum
/// over both axes. The reverse-mode graph is executed with `d_output = 1` and
/// its second output (read here as the gradient of `x`) is compared with a
/// central finite difference of a scalar reference. Because every softmax row
/// sums to one, the loss is constant in exact arithmetic, so both sides are
/// expected to be close to zero.
#[test]
fn softmax_gradient_matches_numerical() {
    use rlx_ir::Philox4x32;
    let n = 3usize;
    let c = 5usize;
    let mut rng = Philox4x32::new(57);
    let mut x_data = vec![0f32; n * c];
    rng.fill_normal(&mut x_data);

    let f = DType::F32;
    let mut fwd = Graph::new("softmax_only");
    let xn = fwd.input("x", Shape::new(&[n, c], f));
    let sm = fwd.add_node(Op::Softmax { axis: -1 }, vec![xn], Shape::new(&[n, c], f));
    let sum_all = Op::Reduce {
        op: ReduceOp::Sum,
        axes: vec![0, 1],
        keep_dim: false,
    };
    let loss = fwd.add_node(sum_all, vec![sm], Shape::from_dims(&[], f));
    fwd.set_outputs(vec![loss]);

    let bwd_graph = rlx_opt::autodiff::grad_with_loss(&fwd, &[xn]);
    // The seed gradient is an extra input of the backward graph.
    let d_out = bwd_graph
        .nodes()
        .iter()
        .find(|node| matches!(&node.op, Op::Input { name } if name == "d_output"))
        .map(|node| node.id)
        .unwrap();

    let plan = rlx_opt::memory::plan_memory(&bwd_graph);
    let mut arena = crate::arena::Arena::from_plan(plan);
    let sched = compile_thunks(&bwd_graph, &arena);
    let seed = [1.0f32];
    let uploads = [(xn, &x_data[..]), (d_out, &seed[..])];
    for &(id, data) in uploads.iter() {
        let off = arena.byte_offset(id);
        let buf = arena.raw_buf_mut();
        // SAFETY: relies on the memory plan reserving `data.len()` aligned f32
        // slots at `off` — not visible from this test; confirm against `Arena`.
        unsafe {
            std::ptr::copy_nonoverlapping(
                data.as_ptr(),
                buf.as_mut_ptr().add(off) as *mut f32,
                data.len(),
            );
        }
    }
    execute_thunks(&sched, arena.raw_buf_mut());
    let g_x_id = bwd_graph.outputs[1];
    // SAFETY: same assumption as above, for the gradient slot.
    let g_x: Vec<f32> = unsafe {
        let p = arena.raw_buf().as_ptr().add(arena.byte_offset(g_x_id)) as *const f32;
        std::slice::from_raw_parts(p, n * c).to_vec()
    };

    // Scalar reference: max-shifted softmax per row, all entries summed.
    let forward_loss = |x: &[f32]| -> f32 {
        let mut total = 0f32;
        for row in x.chunks_exact(c).take(n) {
            let m = row.iter().copied().fold(f32::NEG_INFINITY, f32::max);
            let denom: f32 = row.iter().map(|&v| (v - m).exp()).sum();
            for &v in row {
                total += (v - m).exp() / denom;
            }
        }
        total
    };

    // Central difference, perturbing one element of `x` at a time.
    let eps = 1e-3f32;
    let mut p = x_data.clone();
    for i in 0..x_data.len() {
        let s = p[i];
        p[i] = s + eps;
        let lp = forward_loss(&p);
        p[i] = s - eps;
        let lm = forward_loss(&p);
        p[i] = s;
        let num = (lp - lm) / (2.0 * eps);
        assert!(
            (g_x[i] - num).abs() < 5e-3,
            "softmax g_x[{i}]: analytical {} vs numerical {num}",
            g_x[i]
        );
    }
}
18444
/// Gradients of `sum(layer_norm(x, gamma, beta))`, checked numerically.
///
/// Forward graph: `LayerNorm` over the last axis of a `[3, 6]` input with
/// learned `gamma` (shifted to be at least 0.5) and `beta`, then a sum over
/// both axes. The reverse-mode graph is executed with `d_output = 1`; its
/// outputs 1, 2 and 3 (read here as the gradients of `x`, `gamma` and `beta`)
/// are each compared with a central finite difference of a scalar reference
/// that uses the biased variance and the same `eps`.
#[test]
fn layer_norm_gradient_matches_numerical() {
    use rlx_ir::Philox4x32;
    let rows = 3usize;
    let h = 6usize;
    let mut rng = Philox4x32::new(1009);
    // Draw order matters for reproducibility: x, then gamma, then beta.
    let mut x_data = vec![0f32; rows * h];
    rng.fill_normal(&mut x_data);
    let mut g_data = vec![0f32; h];
    rng.fill_normal(&mut g_data);
    g_data.iter_mut().for_each(|v| *v = v.abs() + 0.5);
    let mut b_data = vec![0f32; h];
    rng.fill_normal(&mut b_data);
    let eps = 1e-5f32;

    let f = DType::F32;
    let mut fwd = Graph::new("ln_only");
    let xn = fwd.input("x", Shape::new(&[rows, h], f));
    let gp = fwd.param("gamma", Shape::new(&[h], f));
    let bp = fwd.param("beta", Shape::new(&[h], f));
    let ln = fwd.add_node(
        Op::LayerNorm { axis: -1, eps },
        vec![xn, gp, bp],
        Shape::new(&[rows, h], f),
    );
    let sum_all = Op::Reduce {
        op: ReduceOp::Sum,
        axes: vec![0, 1],
        keep_dim: false,
    };
    let loss = fwd.add_node(sum_all, vec![ln], Shape::from_dims(&[], f));
    fwd.set_outputs(vec![loss]);

    let bwd_graph = rlx_opt::autodiff::grad_with_loss(&fwd, &[xn, gp, bp]);
    // The seed gradient is an extra input of the backward graph.
    let d_out = bwd_graph
        .nodes()
        .iter()
        .find(|node| matches!(&node.op, Op::Input { name } if name == "d_output"))
        .map(|node| node.id)
        .unwrap();

    let plan = rlx_opt::memory::plan_memory(&bwd_graph);
    let mut arena = crate::arena::Arena::from_plan(plan);
    let sched = compile_thunks(&bwd_graph, &arena);
    let seed = [1.0f32];
    let uploads = [
        (xn, &x_data[..]),
        (gp, &g_data[..]),
        (bp, &b_data[..]),
        (d_out, &seed[..]),
    ];
    for &(id, data) in uploads.iter() {
        let off = arena.byte_offset(id);
        let buf = arena.raw_buf_mut();
        // SAFETY: relies on the memory plan reserving `data.len()` aligned f32
        // slots at `off` — not visible from this test; confirm against `Arena`.
        unsafe {
            std::ptr::copy_nonoverlapping(
                data.as_ptr(),
                buf.as_mut_ptr().add(off) as *mut f32,
                data.len(),
            );
        }
    }
    execute_thunks(&sched, arena.raw_buf_mut());
    let read = |id: NodeId, n: usize| -> Vec<f32> {
        let off = arena.byte_offset(id);
        // SAFETY: same assumption as for the uploads, for `n` f32 slots at `off`.
        unsafe {
            let p = arena.raw_buf().as_ptr().add(off) as *const f32;
            std::slice::from_raw_parts(p, n).to_vec()
        }
    };
    let dx_a = read(bwd_graph.outputs[1], rows * h);
    let dg_a = read(bwd_graph.outputs[2], h);
    let db_a = read(bwd_graph.outputs[3], h);

    // Scalar reference of the forward loss.
    let forward_loss = |x: &[f32], g: &[f32], b: &[f32]| -> f32 {
        let mut total = 0f32;
        for row in x.chunks_exact(h).take(rows) {
            let mean = row.iter().sum::<f32>() / h as f32;
            let var = row.iter().map(|&v| (v - mean) * (v - mean)).sum::<f32>() / h as f32;
            let inv_std = 1.0 / (var + eps).sqrt();
            for d in 0..h {
                let normed = (row[d] - mean) * inv_std;
                total += normed * g[d] + b[d];
            }
        }
        total
    };
    let h_eps = 1e-3f32;
    let central = |lp: f32, lm: f32| -> f32 { (lp - lm) / (2.0 * h_eps) };

    // d loss / d x
    let mut x_p = x_data.clone();
    for i in 0..x_p.len() {
        let s = x_p[i];
        x_p[i] = s + h_eps;
        let lp = forward_loss(&x_p, &g_data, &b_data);
        x_p[i] = s - h_eps;
        let lm = forward_loss(&x_p, &g_data, &b_data);
        x_p[i] = s;
        let num = central(lp, lm);
        assert!(
            (dx_a[i] - num).abs() < 5e-3,
            "ln dx[{i}]: analytical {} vs numerical {num}",
            dx_a[i]
        );
    }

    // d loss / d gamma
    let mut g_p = g_data.clone();
    for i in 0..g_p.len() {
        let s = g_p[i];
        g_p[i] = s + h_eps;
        let lp = forward_loss(&x_data, &g_p, &b_data);
        g_p[i] = s - h_eps;
        let lm = forward_loss(&x_data, &g_p, &b_data);
        g_p[i] = s;
        let num = central(lp, lm);
        assert!(
            (dg_a[i] - num).abs() < 5e-3,
            "ln dg[{i}]: analytical {} vs numerical {num}",
            dg_a[i]
        );
    }

    // d loss / d beta
    let mut b_p = b_data.clone();
    for i in 0..b_p.len() {
        let s = b_p[i];
        b_p[i] = s + h_eps;
        let lp = forward_loss(&x_data, &g_data, &b_p);
        b_p[i] = s - h_eps;
        let lm = forward_loss(&x_data, &g_data, &b_p);
        b_p[i] = s;
        let num = central(lp, lm);
        assert!(
            (db_a[i] - num).abs() < 5e-3,
            "ln db[{i}]: analytical {} vs numerical {num}",
            db_a[i]
        );
    }
}
18586
/// Compares autodiff gradients of `sce(x @ w + b, labels)` reduced over the
/// batch axis against central finite differences of the forward-only graph.
///
/// Note: the reduction built here is `ReduceOp::Sum`, despite "mean" in the
/// test name.
#[test]
fn dense_sce_mean_gradient_matches_numerical() {
    use rlx_ir::Philox4x32;

    let (batch, features, classes) = (4usize, 3usize, 5usize);
    let dt = DType::F32;

    // Deterministic data; the draw order (x, w, b) fixes the values.
    let mut rng = Philox4x32::new(7);
    let mut x = vec![0f32; batch * features];
    rng.fill_normal(&mut x);
    let mut w_init = vec![0f32; features * classes];
    rng.fill_normal(&mut w_init);
    let mut b_init = vec![0f32; classes];
    rng.fill_normal(&mut b_init);
    let labels: Vec<f32> = (0..batch).map(|row| (row % classes) as f32).collect();

    // Forward graph: dense layer -> softmax cross-entropy -> reduce over axis 0.
    let mut fwd = Graph::new("dense_sce");
    let xn = fwd.input("x", Shape::new(&[batch, features], dt));
    let lb = fwd.input("labels", Shape::new(&[batch], dt));
    let wp = fwd.param("w", Shape::new(&[features, classes], dt));
    let bp = fwd.param("b", Shape::new(&[classes], dt));
    let mm = fwd.matmul(xn, wp, Shape::new(&[batch, classes], dt));
    let logits = fwd.binary(BinaryOp::Add, mm, bp, Shape::new(&[batch, classes], dt));
    let loss_per = fwd.softmax_cross_entropy_with_logits(logits, lb);
    let batch_reduce = Op::Reduce {
        op: ReduceOp::Sum,
        axes: vec![0],
        keep_dim: false,
    };
    let loss = fwd.add_node(batch_reduce, vec![loss_per], Shape::from_dims(&[], dt));
    fwd.set_outputs(vec![loss]);

    // Backward graph; its outputs are read below as [loss, grad_w, grad_b].
    let bwd_graph = rlx_opt::autodiff::grad_with_loss(&fwd, &[wp, bp]);
    let d_out = bwd_graph
        .nodes()
        .iter()
        .find_map(|node| match &node.op {
            Op::Input { name } if name == "d_output" => Some(node.id),
            _ => None,
        })
        .expect("d_output input");

    let (sched, mut arena) = prepare(
        &bwd_graph,
        &[
            (xn, &x),
            (lb, &labels),
            (wp, &w_init),
            (bp, &b_init),
            (d_out, &[1.0]),
        ],
    );
    execute_thunks(&sched, arena.raw_buf_mut());

    let outs = &bwd_graph.outputs;
    let (loss_id, gw_id, gb_id) = (outs[0], outs[1], outs[2]);
    let loss_actual = read_arena(&arena, loss_id, 1)[0];
    let gw_actual = read_arena(&arena, gw_id, features * classes);
    let gb_actual = read_arena(&arena, gb_id, classes);

    // Forward-only executor used as the finite-difference oracle.
    let mut fwd_arena = crate::arena::Arena::from_plan(rlx_opt::memory::plan_memory(&fwd));
    let fwd_sched = compile_thunks(&fwd, &fwd_arena);
    write_arena(&mut fwd_arena, xn, &x);
    write_arena(&mut fwd_arena, lb, &labels);

    let run_loss = |arena: &mut crate::arena::Arena, w: &[f32], b: &[f32]| -> f32 {
        write_arena(arena, wp, w);
        write_arena(arena, bp, b);
        execute_thunks(&fwd_sched, arena.raw_buf_mut());
        read_arena(arena, loss, 1)[0]
    };

    let loss_check = run_loss(&mut fwd_arena, &w_init, &b_init);
    assert!(
        (loss_actual - loss_check).abs() < 1e-4,
        "loss mismatch: bwd graph {loss_actual} vs fwd-only {loss_check}"
    );

    let eps = 1e-3f32;

    // Central differences w.r.t. each weight entry.
    let mut w_probe = w_init.clone();
    let gw_numerical: Vec<f32> = (0..w_init.len())
        .map(|i| {
            let saved = w_probe[i];
            w_probe[i] = saved + eps;
            let plus = run_loss(&mut fwd_arena, &w_probe, &b_init);
            w_probe[i] = saved - eps;
            let minus = run_loss(&mut fwd_arena, &w_probe, &b_init);
            w_probe[i] = saved;
            (plus - minus) / (2.0 * eps)
        })
        .collect();
    for (i, (a, n)) in gw_actual.iter().zip(&gw_numerical).enumerate() {
        assert!(
            (a - n).abs() < 5e-3,
            "grad_w[{i}]: analytical {a} vs numerical {n}"
        );
    }

    // Central differences w.r.t. each bias entry.
    let mut b_probe = b_init.clone();
    let gb_numerical: Vec<f32> = (0..b_init.len())
        .map(|i| {
            let saved = b_probe[i];
            b_probe[i] = saved + eps;
            let plus = run_loss(&mut fwd_arena, &w_init, &b_probe);
            b_probe[i] = saved - eps;
            let minus = run_loss(&mut fwd_arena, &w_init, &b_probe);
            b_probe[i] = saved;
            (plus - minus) / (2.0 * eps)
        })
        .collect();
    for (i, (a, n)) in gb_actual.iter().zip(&gb_numerical).enumerate() {
        assert!(
            (a - n).abs() < 5e-3,
            "grad_b[{i}]: analytical {a} vs numerical {n}"
        );
    }
}
18727
/// Gradient check for `Reduce(Mean, axis 0)` over softmax cross-entropy of
/// `x @ w`: the analytical `grad_w` from the backward graph must match central
/// finite differences of the forward-only graph.
///
/// The loss emitted by the backward graph is also compared with the
/// forward-only loss (same 1e-4 tolerance as the sibling dense/tinyconv
/// tests) instead of being read and thrown away.
#[test]
fn dense_sce_mean_reduce_gradient_matches_numerical() {
    use rlx_ir::Philox4x32;
    let bs = 3usize;
    let k_in = 2usize;
    let c = 4usize;

    // Deterministic data; draw order (x, then w) fixes the values.
    let mut rng = Philox4x32::new(13);
    let mut x = vec![0f32; bs * k_in];
    rng.fill_normal(&mut x);
    let mut w_init = vec![0f32; k_in * c];
    rng.fill_normal(&mut w_init);
    let labels: Vec<f32> = (0..bs).map(|i| (i % c) as f32).collect();

    let f = DType::F32;
    let mut fwd = Graph::new("dense_sce_mean");
    let xn = fwd.input("x", Shape::new(&[bs, k_in], f));
    let lb = fwd.input("labels", Shape::new(&[bs], f));
    let wp = fwd.param("w", Shape::new(&[k_in, c], f));
    let mm = fwd.matmul(xn, wp, Shape::new(&[bs, c], f));
    let loss_per = fwd.softmax_cross_entropy_with_logits(mm, lb);
    let loss = fwd.add_node(
        Op::Reduce {
            op: ReduceOp::Mean,
            axes: vec![0],
            keep_dim: false,
        },
        vec![loss_per],
        Shape::from_dims(&[], f),
    );
    fwd.set_outputs(vec![loss]);

    let bwd_graph = rlx_opt::autodiff::grad_with_loss(&fwd, &[wp]);
    let d_out = bwd_graph
        .nodes()
        .iter()
        .find(|n| matches!(&n.op, Op::Input { name } if name == "d_output"))
        .map(|n| n.id)
        .expect("d_output input");

    let (sched, mut arena) = prepare(
        &bwd_graph,
        &[(xn, &x), (lb, &labels), (wp, &w_init), (d_out, &[1.0])],
    );
    execute_thunks(&sched, arena.raw_buf_mut());

    // Backward-graph outputs are read as [loss, grad_w].
    let outs = &bwd_graph.outputs;
    let loss_id = outs[0];
    let gw_id = outs[1];
    let loss_actual = read_arena(&arena, loss_id, 1)[0];
    let gw_actual = read_arena(&arena, gw_id, k_in * c);

    // Forward-only executor used as the finite-difference oracle.
    let plan = rlx_opt::memory::plan_memory(&fwd);
    let mut fwd_arena = crate::arena::Arena::from_plan(plan);
    let fwd_sched = compile_thunks(&fwd, &fwd_arena);
    write_arena(&mut fwd_arena, xn, &x);
    write_arena(&mut fwd_arena, lb, &labels);

    let run_loss = |arena: &mut crate::arena::Arena, w: &[f32]| -> f32 {
        write_arena(arena, wp, w);
        execute_thunks(&fwd_sched, arena.raw_buf_mut());
        read_arena(arena, loss, 1)[0]
    };

    let loss_check = run_loss(&mut fwd_arena, &w_init);
    assert!(
        (loss_actual - loss_check).abs() < 1e-4,
        "mean reduce loss mismatch: bwd graph {loss_actual} vs fwd-only {loss_check}"
    );

    let eps = 1e-3f32;
    let mut wp_p = w_init.clone();
    let mut gw_num = vec![0f32; w_init.len()];
    for i in 0..w_init.len() {
        let s = wp_p[i];
        wp_p[i] = s + eps;
        let lp = run_loss(&mut fwd_arena, &wp_p);
        wp_p[i] = s - eps;
        let lm = run_loss(&mut fwd_arena, &wp_p);
        wp_p[i] = s;
        gw_num[i] = (lp - lm) / (2.0 * eps);
    }
    for (i, (a, n)) in gw_actual.iter().zip(&gw_num).enumerate() {
        assert!((a - n).abs() < 5e-3, "mean reduce grad_w[{i}]: {a} vs {n}");
    }
}
/// End-to-end gradient check for a tiny conv net:
/// `Conv(3x3) -> +bias -> ReLU -> MaxPool(2x2) -> flatten -> dense -> SCE -> Mean`.
///
/// The backward graph produced by `grad_with_loss` is executed once; its
/// outputs are read as `[loss, g_wc, g_bc, g_wfc, g_bfc]`. Every analytical
/// gradient entry is then compared with a central finite difference computed
/// by re-running the forward-only graph with one parameter entry perturbed.
#[test]
fn tinyconv_full_gradient_matches_numerical() {
    use rlx_ir::Philox4x32;
    let n = 1usize;
    let c_in = 1usize;
    let h = 6usize;
    let w_in = 6usize;
    let c_mid = 2usize;
    let kh = 3;
    let kw = 3;
    // Valid (unpadded, stride-1) convolution output size.
    let h1 = h - kh + 1;
    let w1 = w_in - kw + 1;
    // 2x2 / stride-2 pooling output size.
    let h2 = h1 / 2;
    let w2 = w1 / 2;
    let flat = c_mid * h2 * w2;
    let num_classes = 3usize;

    let mut rng = Philox4x32::new(31);
    let mut x = vec![0f32; n * c_in * h * w_in];
    rng.fill_normal(&mut x);
    let mut wc = vec![0f32; c_mid * c_in * kh * kw];
    rng.fill_normal(&mut wc);
    for v in wc.iter_mut() {
        *v *= 0.2;
    }
    // NOTE(review): the large positive conv bias presumably keeps the
    // pre-activations away from the ReLU kink so finite differences stay
    // well-defined — confirm.
    let bc: Vec<f32> = (0..c_mid).map(|i| 5.0 + 0.1 * i as f32).collect();
    let mut wfc = vec![0f32; flat * num_classes];
    rng.fill_normal(&mut wfc);
    for v in wfc.iter_mut() {
        *v *= 0.5;
    }
    let mut bfc = vec![0f32; num_classes];
    rng.fill_normal(&mut bfc);
    let labels: Vec<f32> = vec![1.0];

    let f = DType::F32;
    let mut fwd = Graph::new("tinyconv");
    let xn = fwd.input("x", Shape::new(&[n, c_in, h, w_in], f));
    let lb = fwd.input("labels", Shape::new(&[n], f));
    let wcp = fwd.param("wc", Shape::new(&[c_mid, c_in, kh, kw], f));
    let bcp = fwd.param("bc", Shape::new(&[c_mid], f));
    let wfp = fwd.param("wfc", Shape::new(&[flat, num_classes], f));
    let bfp = fwd.param("bfc", Shape::new(&[num_classes], f));

    let conv = fwd.add_node(
        Op::Conv {
            kernel_size: vec![kh, kw],
            stride: vec![1, 1],
            padding: vec![0, 0],
            dilation: vec![1, 1],
            groups: 1,
        },
        vec![xn, wcp],
        Shape::new(&[n, c_mid, h1, w1], f),
    );
    // The per-channel bias is reshaped to [1, C, 1, 1] and expanded to the
    // conv output shape before the add.
    let bc_4d = fwd.add_node(
        Op::Reshape {
            new_shape: vec![1, c_mid as i64, 1, 1],
        },
        vec![bcp],
        Shape::new(&[1, c_mid, 1, 1], f),
    );
    let bc_expanded = fwd.add_node(
        Op::Expand {
            target_shape: vec![n as i64, c_mid as i64, h1 as i64, w1 as i64],
        },
        vec![bc_4d],
        Shape::new(&[n, c_mid, h1, w1], f),
    );
    let conv_b = fwd.binary(
        BinaryOp::Add,
        conv,
        bc_expanded,
        Shape::new(&[n, c_mid, h1, w1], f),
    );
    let relu = fwd.activation(Activation::Relu, conv_b, Shape::new(&[n, c_mid, h1, w1], f));
    let pool = fwd.add_node(
        Op::Pool {
            kind: ReduceOp::Max,
            kernel_size: vec![2, 2],
            stride: vec![2, 2],
            padding: vec![0, 0],
        },
        vec![relu],
        Shape::new(&[n, c_mid, h2, w2], f),
    );
    let flatn = fwd.add_node(
        Op::Reshape {
            new_shape: vec![n as i64, flat as i64],
        },
        vec![pool],
        Shape::new(&[n, flat], f),
    );
    let mm = fwd.matmul(flatn, wfp, Shape::new(&[n, num_classes], f));
    let logits = fwd.binary(BinaryOp::Add, mm, bfp, Shape::new(&[n, num_classes], f));
    let loss_per = fwd.softmax_cross_entropy_with_logits(logits, lb);
    let loss = fwd.add_node(
        Op::Reduce {
            op: ReduceOp::Mean,
            axes: vec![0],
            keep_dim: false,
        },
        vec![loss_per],
        Shape::from_dims(&[], f),
    );
    fwd.set_outputs(vec![loss]);

    let bwd_graph = rlx_opt::autodiff::grad_with_loss(&fwd, &[wcp, bcp, wfp, bfp]);
    let d_out = bwd_graph
        .nodes()
        .iter()
        .find(|n| matches!(&n.op, Op::Input { name } if name == "d_output"))
        .map(|n| n.id)
        .expect("d_output input");

    let (sched, mut arena) = prepare(
        &bwd_graph,
        &[
            (xn, &x),
            (lb, &labels),
            (wcp, &wc),
            (bcp, &bc),
            (wfp, &wfc),
            (bfp, &bfc),
            (d_out, &[1.0]),
        ],
    );
    execute_thunks(&sched, arena.raw_buf_mut());

    let outs = &bwd_graph.outputs;
    let loss_actual = read_arena(&arena, outs[0], 1)[0];
    let g_wc = read_arena(&arena, outs[1], wc.len());
    let g_bc = read_arena(&arena, outs[2], bc.len());
    let g_wfc = read_arena(&arena, outs[3], wfc.len());
    let g_bfc = read_arena(&arena, outs[4], bfc.len());

    // Forward-only executor used as the finite-difference oracle.
    let plan = rlx_opt::memory::plan_memory(&fwd);
    let mut fwd_arena = crate::arena::Arena::from_plan(plan);
    let fwd_sched = compile_thunks(&fwd, &fwd_arena);
    write_arena(&mut fwd_arena, xn, &x);
    write_arena(&mut fwd_arena, lb, &labels);

    let run_loss = |arena: &mut crate::arena::Arena,
                    conv_w: &[f32],
                    conv_b: &[f32],
                    fc_w: &[f32],
                    fc_b: &[f32]|
     -> f32 {
        write_arena(arena, wcp, conv_w);
        write_arena(arena, bcp, conv_b);
        write_arena(arena, wfp, fc_w);
        write_arena(arena, bfp, fc_b);
        execute_thunks(&fwd_sched, arena.raw_buf_mut());
        read_arena(arena, loss, 1)[0]
    };

    let loss_check = run_loss(&mut fwd_arena, &wc, &bc, &wfc, &bfc);
    assert!(
        (loss_actual - loss_check).abs() < 1e-4,
        "tinyconv loss mismatch: bwd {loss_actual} vs fwd {loss_check}"
    );

    // One sweep for all four parameters: `params[k]` is a scratch copy of the
    // k-th parameter and `analytical[k]` is its gradient from the backward
    // graph. Each entry is perturbed by +/- eps, then restored.
    let eps = 1e-3f32;
    let mut params = [wc.clone(), bc.clone(), wfc.clone(), bfc.clone()];
    let analytical = [
        ("g_wc", &g_wc),
        ("g_bc", &g_bc),
        ("g_wfc", &g_wfc),
        ("g_bfc", &g_bfc),
    ];
    for (which, &(name, grad)) in analytical.iter().enumerate() {
        for i in 0..params[which].len() {
            let saved = params[which][i];
            params[which][i] = saved + eps;
            let lp = run_loss(
                &mut fwd_arena,
                &params[0],
                &params[1],
                &params[2],
                &params[3],
            );
            params[which][i] = saved - eps;
            let lm = run_loss(
                &mut fwd_arena,
                &params[0],
                &params[1],
                &params[2],
                &params[3],
            );
            params[which][i] = saved;
            let num = (lp - lm) / (2.0 * eps);
            assert!(
                (grad[i] - num).abs() < 5e-3,
                "{name}[{i}]: {} vs {num}",
                grad[i]
            );
        }
    }
}
19107
/// A `narrow_` whose result feeds both a `rope` and a ReLU must still show up
/// as a `Thunk::Narrow` in the compiled schedule.
#[test]
fn narrow_rope_skips_when_narrow_has_multiple_consumers() {
    let dt = DType::F32;
    let mut graph = Graph::new("nr_skip");
    let qkv = graph.input("qkv", Shape::new(&[16, 8, 192], dt));
    let cos = graph.input("cos", Shape::new(&[16], dt));
    let sin = graph.input("sin", Shape::new(&[16], dt));

    // `q` has two consumers: the rope and the activation.
    let q = graph.narrow_(qkv, 2, 0, 64);
    let roped = graph.rope(q, cos, sin, 16);
    let second_use = graph.activation(
        rlx_ir::op::Activation::Relu,
        q,
        Shape::new(&[16, 8, 64], dt),
    );
    graph.set_outputs(vec![roped, second_use]);

    let arena = crate::arena::Arena::from_plan(rlx_opt::memory::plan_memory(&graph));
    let sched = compile_thunks(&graph, &arena);

    let has_narrow = sched
        .thunks
        .iter()
        .any(|thunk| matches!(thunk, Thunk::Narrow { .. }));
    assert!(
        has_narrow,
        "Narrow with multiple consumers must NOT be fused away"
    );
}
19138
/// Executes a `custom_fn` with no vjp/jvp override whose body computes
/// `x + 1` and checks the result element-wise.
#[test]
fn custom_fn_forward_inlines_body() {
    let vec3 = Shape::new(&[3], DType::F32);

    // Body sub-graph: y = x + [1, 1, 1].
    let body = {
        let mut sub = Graph::new("addone_body");
        let x = sub.input("x", vec3.clone());
        let ones = sub.add_node(
            Op::Constant {
                // Three little-endian f32 ones, back to back.
                data: 1.0_f32.to_le_bytes().repeat(3),
            },
            vec![],
            vec3.clone(),
        );
        let y = sub.binary(BinaryOp::Add, x, ones, vec3.clone());
        sub.set_outputs(vec![y]);
        sub
    };

    let mut outer = Graph::new("custom_fn_outer");
    let xin = outer.input("x_in", vec3.clone());
    let cf = outer.custom_fn(vec![xin], body, None, None);
    outer.set_outputs(vec![cf]);

    let xs = vec![10.0_f32, 20.0, 30.0];
    let (sched, mut arena) = prepare(&outer, &[(xin, &xs)]);
    execute_thunks(&sched, arena.raw_buf_mut());
    assert_eq!(read_arena(&arena, cf, 3), vec![11.0, 21.0, 31.0]);
}
19174
19175 fn find_named(graph: &Graph, want: &str) -> NodeId {
19177 for n in graph.nodes() {
19178 let name = match &n.op {
19179 Op::Input { name } | Op::Param { name } => Some(name.as_str()),
19180 _ => None,
19181 };
19182 if name == Some(want) {
19183 return n.id;
19184 }
19185 }
19186 panic!("no node named {want:?} in graph");
19187 }
19188
/// An identity `custom_fn` carries a vjp graph that returns `2 * d_output`.
/// After `grad_with_loss` the gradient must be 2.0 (the override), not the
/// 1.0 that differentiating the identity body would give.
#[test]
fn custom_fn_vjp_overrides_natural_gradient() {
    use rlx_opt::autodiff::grad_with_loss;
    let scalar = Shape::new(&[1], DType::F32);

    // Forward body: y = x.
    let identity_body = {
        let mut body = Graph::new("id_fwd");
        let x = body.input("x", scalar.clone());
        body.set_outputs(vec![x]);
        body
    };

    // Override rule: dx = d_output * 2. The primal input and output are
    // declared as inputs but not used.
    let doubling_vjp = {
        let mut rule = Graph::new("id_vjp");
        let _primal_in = rule.input("x", scalar.clone());
        let _primal_out = rule.input("primal_output", scalar.clone());
        let dy = rule.input("d_output", scalar.clone());
        let two = rule.add_node(
            Op::Constant {
                data: 2.0_f32.to_le_bytes().to_vec(),
            },
            vec![],
            scalar.clone(),
        );
        let dx = rule.binary(BinaryOp::Mul, dy, two, scalar.clone());
        rule.set_outputs(vec![dx]);
        rule
    };

    let mut outer = Graph::new("outer");
    let xp = outer.param("x", scalar.clone());
    let cf = outer.custom_fn(vec![xp], identity_body, Some(doubling_vjp), None);
    outer.set_outputs(vec![cf]);

    let bwd = grad_with_loss(&outer, &[xp]);
    assert_eq!(bwd.outputs.len(), 2, "expect [loss, dx]");

    let x_node = find_named(&bwd, "x");
    let seed_node = find_named(&bwd, "d_output");
    let (sched, mut arena) = prepare(&bwd, &[(x_node, &[7.0]), (seed_node, &[1.0])]);
    execute_thunks(&sched, arena.raw_buf_mut());

    let loss = read_arena(&arena, bwd.outputs[0], 1);
    let dx_v = read_arena(&arena, bwd.outputs[1], 1);
    assert!((loss[0] - 7.0).abs() < 1e-6, "loss should be 7.0");
    assert!(
        (dx_v[0] - 2.0).abs() < 1e-6,
        "vjp override should yield dx=2.0, got {} (natural autodiff would give 1.0)",
        dx_v[0]
    );
}
19231
/// Two-input `custom_fn` (y = a * b) with a hand-written vjp returning
/// `[b * dy, a * dy]`. With a=3, b=5 and seed 1 the expected results are
/// loss=15, da=5, db=3.
#[test]
fn custom_fn_vjp_two_inputs_matches_mul_autodiff() {
    use rlx_opt::autodiff::grad_with_loss;
    let scalar = Shape::new(&[1], DType::F32);

    // Forward body: y = a * b.
    let product_body = {
        let mut body = Graph::new("mul_fwd");
        let a = body.input("a", scalar.clone());
        let b = body.input("b", scalar.clone());
        let y = body.binary(BinaryOp::Mul, a, b, scalar.clone());
        body.set_outputs(vec![y]);
        body
    };

    // Override rule: (da, db) = (b * dy, a * dy); the primal output is unused.
    let product_vjp = {
        let mut rule = Graph::new("mul_vjp");
        let a = rule.input("a", scalar.clone());
        let b = rule.input("b", scalar.clone());
        let _primal_out = rule.input("primal_output", scalar.clone());
        let dy = rule.input("d_output", scalar.clone());
        let da = rule.binary(BinaryOp::Mul, b, dy, scalar.clone());
        let db = rule.binary(BinaryOp::Mul, a, dy, scalar.clone());
        rule.set_outputs(vec![da, db]);
        rule
    };

    let mut outer = Graph::new("outer");
    let ap = outer.param("a", scalar.clone());
    let bp = outer.param("b", scalar.clone());
    let cf = outer.custom_fn(vec![ap, bp], product_body, Some(product_vjp), None);
    outer.set_outputs(vec![cf]);

    let bwd = grad_with_loss(&outer, &[ap, bp]);
    assert_eq!(bwd.outputs.len(), 3, "expect [loss, da, db]");

    let a_node = find_named(&bwd, "a");
    let b_node = find_named(&bwd, "b");
    let seed_node = find_named(&bwd, "d_output");
    let (sched, mut arena) = prepare(
        &bwd,
        &[(a_node, &[3.0]), (b_node, &[5.0]), (seed_node, &[1.0])],
    );
    execute_thunks(&sched, arena.raw_buf_mut());

    let loss = read_arena(&arena, bwd.outputs[0], 1);
    let da_v = read_arena(&arena, bwd.outputs[1], 1);
    let db_v = read_arena(&arena, bwd.outputs[2], 1);
    assert!((loss[0] - 15.0).abs() < 1e-5);
    assert!(
        (da_v[0] - 5.0).abs() < 1e-5,
        "da should be b=5.0, got {}",
        da_v[0]
    );
    assert!(
        (db_v[0] - 3.0).abs() < 1e-5,
        "db should be a=3.0, got {}",
        db_v[0]
    );
}
19285
/// Forward-mode counterpart of the vjp override test: an identity
/// `custom_fn` carries a jvp graph returning `2 * tangent_0`, so the output
/// tangent must be 2.0 rather than the natural 1.0.
#[test]
fn custom_fn_jvp_overrides_natural_tangent() {
    use rlx_opt::autodiff_fwd::jvp;
    let scalar = Shape::new(&[1], DType::F32);

    // Forward body: y = x.
    let identity_body = {
        let mut body = Graph::new("id_fwd");
        let x = body.input("x", scalar.clone());
        body.set_outputs(vec![x]);
        body
    };

    // Override rule: t_y = tangent_0 * 2; the primal input is unused.
    let doubling_jvp = {
        let mut rule = Graph::new("id_jvp");
        let _primal_in = rule.input("x", scalar.clone());
        let tx = rule.input("tangent_0", scalar.clone());
        let two = rule.add_node(
            Op::Constant {
                data: 2.0_f32.to_le_bytes().to_vec(),
            },
            vec![],
            scalar.clone(),
        );
        let ty = rule.binary(BinaryOp::Mul, tx, two, scalar.clone());
        rule.set_outputs(vec![ty]);
        rule
    };

    let mut outer = Graph::new("outer");
    let xin = outer.input("x_in", scalar.clone());
    let cf = outer.custom_fn(vec![xin], identity_body, None, Some(doubling_jvp));
    outer.set_outputs(vec![cf]);

    let fwd_g = jvp(&outer, &[xin]);
    assert_eq!(fwd_g.outputs.len(), 2, "expect [primal_y, tangent_y]");

    let x_node = find_named(&fwd_g, "x_in");
    let tangent_node = find_named(&fwd_g, "tangent_x_in");
    let (sched, mut arena) = prepare(&fwd_g, &[(x_node, &[7.0]), (tangent_node, &[1.0])]);
    execute_thunks(&sched, arena.raw_buf_mut());

    let y = read_arena(&arena, fwd_g.outputs[0], 1);
    let ty_v = read_arena(&arena, fwd_g.outputs[1], 1);
    assert!((y[0] - 7.0).abs() < 1e-6);
    assert!(
        (ty_v[0] - 2.0).abs() < 1e-6,
        "jvp override should yield t_y=2.0 (natural autodiff would give 1.0), got {}",
        ty_v[0]
    );
}
19326
/// Pins the C64 storage contract: 8 bytes per element, classified as complex
/// and not as float, and a 2-element shape occupying 16 bytes.
#[test]
fn c64_dtype_storage_layout() {
    let elem_bytes = DType::C64.size_bytes();
    assert_eq!(elem_bytes, 8, "C64 should be 8 bytes (f32 real + f32 imag)");

    assert!(DType::C64.is_complex());
    assert!(!DType::C64.is_float());

    let pair = Shape::new(&[2], DType::C64);
    let pair_bytes = pair.size_bytes().unwrap();
    assert_eq!(pair_bytes, 16);
}
19345
19346 fn run_c64_binary(op: BinaryOp, a: &[(f32, f32)], b: &[(f32, f32)]) -> Vec<(f32, f32)> {
19353 let n = a.len();
19354 let s = Shape::new(&[n], DType::C64);
19355 let mut g = Graph::new("c64_bin");
19356 let in_a = g.input("a", s.clone());
19357 let in_b = g.input("b", s.clone());
19358 let out = g.binary(op, in_a, in_b, s.clone());
19359 g.set_outputs(vec![out]);
19360
19361 let plan = rlx_opt::memory::plan_memory(&g);
19362 let mut arena = crate::arena::Arena::from_plan(plan);
19363 let sched = compile_thunks(&g, &arena);
19364
19365 let a_off = arena.byte_offset(in_a);
19366 let b_off = arena.byte_offset(in_b);
19367 let out_off = arena.byte_offset(out);
19368 let buf = arena.raw_buf_mut();
19370 unsafe {
19371 let pa = buf.as_mut_ptr().add(a_off) as *mut f32;
19372 let pb = buf.as_mut_ptr().add(b_off) as *mut f32;
19373 for (i, &(re, im)) in a.iter().enumerate() {
19374 *pa.add(2 * i) = re;
19375 *pa.add(2 * i + 1) = im;
19376 }
19377 for (i, &(re, im)) in b.iter().enumerate() {
19378 *pb.add(2 * i) = re;
19379 *pb.add(2 * i + 1) = im;
19380 }
19381 }
19382 execute_thunks(&sched, arena.raw_buf_mut());
19383 let raw_out: Vec<f32> = unsafe {
19384 let p = arena.raw_buf().as_ptr().add(out_off) as *const f32;
19385 (0..(2 * n)).map(|i| *p.add(i)).collect()
19386 };
19387 (0..n)
19388 .map(|i| (raw_out[2 * i], raw_out[2 * i + 1]))
19389 .collect()
19390 }
19391
/// Asserts that two complex values given as `(re, im)` pairs agree
/// component-wise: both `|Δre|` and `|Δim|` must be strictly below `tol`.
///
/// On failure, panics with `label` and both values printed with an explicit
/// sign and four decimals. `#[track_caller]` makes the panic point at the
/// calling test line.
#[track_caller]
fn assert_close_c(got: (f32, f32), expected: (f32, f32), tol: f32, label: &str) {
    let (got_re, got_im) = got;
    let (want_re, want_im) = expected;
    let re_ok = (got_re - want_re).abs() < tol;
    let im_ok = (got_im - want_im).abs() < tol;
    if re_ok && im_ok {
        return;
    }
    panic!(
        "[{label}] got ({:+.4}, {:+.4}), expected ({:+.4}, {:+.4})",
        got_re, got_im, want_re, want_im
    );
}
19405
/// Complex addition adds real and imaginary parts independently.
#[test]
fn c64_binary_add_matches_complex_arithmetic() {
    let lhs = [(1.0_f32, 2.0_f32), (3.0_f32, -1.0_f32)];
    let rhs = [(4.0_f32, -1.0_f32), (0.5_f32, 0.5_f32)];
    let sums = run_c64_binary(BinaryOp::Add, &lhs, &rhs);

    let expected = [((5.0, 1.0), "add[0]"), ((3.5, -0.5), "add[1]")];
    for (i, (want, label)) in expected.iter().enumerate() {
        assert_close_c(sums[i], *want, 1e-6, label);
    }
}
19414
/// (5 + 1i) - (2 + 3i) = 3 - 2i.
#[test]
fn c64_binary_sub_matches_complex_arithmetic() {
    let minuend = [(5.0_f32, 1.0_f32)];
    let subtrahend = [(2.0_f32, 3.0_f32)];
    let diff = run_c64_binary(BinaryOp::Sub, &minuend, &subtrahend);
    assert_close_c(diff[0], (3.0, -2.0), 1e-6, "sub");
}
19422
/// (1 + 2i)(3 + 4i) = (3 - 8) + (4 + 6)i = -5 + 10i.
#[test]
fn c64_binary_mul_matches_complex_arithmetic() {
    let lhs = [(1.0_f32, 2.0_f32)];
    let rhs = [(3.0_f32, 4.0_f32)];
    let product = run_c64_binary(BinaryOp::Mul, &lhs, &rhs);
    assert_close_c(product[0], (-5.0, 10.0), 1e-5, "mul");
}
19431
/// (1 + 2i) / (3 + 4i) = (1 + 2i)(3 - 4i) / 25 = (11 + 2i) / 25 = 0.44 + 0.08i.
#[test]
fn c64_binary_div_matches_complex_arithmetic() {
    let numerator = [(1.0_f32, 2.0_f32)];
    let denominator = [(3.0_f32, 4.0_f32)];
    let quotient = run_c64_binary(BinaryOp::Div, &numerator, &denominator);
    assert_close_c(quotient[0], (0.44, 0.08), 1e-5, "div");
}
19442
/// Multiplying by 1 + 0i must return the operand unchanged.
#[test]
fn c64_binary_mul_identity_one_is_no_op() {
    let values = [(3.5_f32, -1.25_f32), (-2.0_f32, 7.0_f32)];
    let ones = [(1.0_f32, 0.0_f32); 2];
    let out = run_c64_binary(BinaryOp::Mul, &values, &ones);

    for (i, label) in ["mul·1[0]", "mul·1[1]"].iter().enumerate() {
        assert_close_c(out[i], values[i], 1e-6, label);
    }
}
19452
/// 1 · i = i: multiplying the real unit by the imaginary unit yields (0, 1).
#[test]
fn c64_binary_mul_by_i_rotates_90_degrees() {
    let real_unit = [(1.0_f32, 0.0_f32)];
    let imag_unit = [(0.0_f32, 1.0_f32)];
    let rotated = run_c64_binary(BinaryOp::Mul, &real_unit, &imag_unit);
    assert_close_c(rotated[0], (0.0, 1.0), 1e-6, "1·i");
}
19461
/// z / z = 1 + 0i for each (non-zero) element.
#[test]
fn c64_binary_div_by_self_gives_unity() {
    let values = [(2.5_f32, -1.5_f32), (-0.7_f32, 4.2_f32)];
    let ratios = run_c64_binary(BinaryOp::Div, &values, &values);

    for (i, label) in ["div_self[0]", "div_self[1]"].iter().enumerate() {
        assert_close_c(ratios[i], (1.0, 0.0), 1e-5, label);
    }
}
19469
/// `BinaryOp::Max` on C64 operands must panic with the
/// "C64: complex max/min/pow" message.
#[test]
#[should_panic(expected = "C64: complex max/min/pow")]
fn c64_binary_max_is_rejected_at_lowering() {
    let lhs = [(1.0_f32, 2.0_f32)];
    let rhs = [(3.0_f32, 4.0_f32)];
    let _ = run_c64_binary(BinaryOp::Max, &lhs, &rhs);
}
19475
19476 fn run_c64_activation(act: Activation, a: &[(f32, f32)]) -> Vec<(f32, f32)> {
19477 let n = a.len();
19478 let s = Shape::new(&[n], DType::C64);
19479 let mut g = Graph::new("c64_act");
19480 let in_a = g.input("a", s.clone());
19481 let out = g.activation(act, in_a, s.clone());
19482 g.set_outputs(vec![out]);
19483 let plan = rlx_opt::memory::plan_memory(&g);
19484 let mut arena = crate::arena::Arena::from_plan(plan);
19485 let sched = compile_thunks(&g, &arena);
19486 let a_off = arena.byte_offset(in_a);
19487 let out_off = arena.byte_offset(out);
19488 let buf = arena.raw_buf_mut();
19489 unsafe {
19490 let pa = buf.as_mut_ptr().add(a_off) as *mut f32;
19491 for (i, &(re, im)) in a.iter().enumerate() {
19492 *pa.add(2 * i) = re;
19493 *pa.add(2 * i + 1) = im;
19494 }
19495 }
19496 execute_thunks(&sched, arena.raw_buf_mut());
19497 let raw: Vec<f32> = unsafe {
19498 let p = arena.raw_buf().as_ptr().add(out_off) as *const f32;
19499 (0..(2 * n)).map(|i| *p.add(i)).collect()
19500 };
19501 (0..n).map(|i| (raw[2 * i], raw[2 * i + 1])).collect()
19502 }
19503
/// Negation flips the sign of both the real and the imaginary part.
#[test]
fn c64_activation_neg_negates_both_components() {
    let values = [(3.5_f32, -1.25_f32), (-2.0_f32, 0.0_f32)];
    let negated = run_c64_activation(Activation::Neg, &values);

    let expected = [((-3.5, 1.25), "neg[0]"), ((2.0, 0.0), "neg[1]")];
    for (i, (want, label)) in expected.iter().enumerate() {
        assert_close_c(negated[i], *want, 1e-6, label);
    }
}
19511
/// exp(iπ) = -1 (Euler's identity) and exp(1) = e.
#[test]
fn c64_activation_exp_matches_euler() {
    use std::f32::consts::{E, PI};

    let values = [(0.0_f32, PI), (1.0_f32, 0.0_f32)];
    let exps = run_c64_activation(Activation::Exp, &values);

    let expected = [((-1.0, 0.0), "exp(iπ)"), ((E, 0.0), "exp(1)")];
    for (i, (want, label)) in expected.iter().enumerate() {
        assert_close_c(exps[i], *want, 1e-5, label);
    }
}
19521
/// Complex log: log(1) = 0, log(i) = iπ/2, log(-1) = iπ.
#[test]
fn c64_activation_log_matches_principal_branch() {
    use std::f32::consts::{FRAC_PI_2, PI};

    let values = [(1.0_f32, 0.0_f32), (0.0_f32, 1.0_f32), (-1.0_f32, 0.0_f32)];
    let logs = run_c64_activation(Activation::Log, &values);

    let expected = [
        ((0.0, 0.0), "log(1)"),
        ((0.0, FRAC_PI_2), "log(i)"),
        ((0.0, PI), "log(-1)"),
    ];
    for (i, (want, label)) in expected.iter().enumerate() {
        assert_close_c(logs[i], *want, 1e-5, label);
    }
}
19533
/// sqrt(4) = 2 and sqrt(3 + 4i) = 2 + i (since (2 + i)² = 3 + 4i).
#[test]
fn c64_activation_sqrt_squared_recovers_input() {
    let values = [(4.0_f32, 0.0_f32), (3.0_f32, 4.0_f32)];
    let roots = run_c64_activation(Activation::Sqrt, &values);

    let expected = [((2.0, 0.0), "sqrt(4)"), ((2.0, 1.0), "sqrt(3+4i)")];
    for (i, (want, label)) in expected.iter().enumerate() {
        assert_close_c(roots[i], *want, 1e-5, label);
    }
}
19544
/// ReLU on a C64 input must panic with a message containing
/// "no natural complex extension".
#[test]
#[should_panic(expected = "no natural complex extension")]
fn c64_activation_relu_is_rejected_at_lowering() {
    let values = [(1.0_f32, 2.0_f32)];
    let _ = run_c64_activation(Activation::Relu, &values);
}
19550
19551 fn run_complex_norm_sq(z: &[(f32, f32)]) -> Vec<f32> {
19555 let n = z.len();
19556 let mut g = Graph::new("cns_fwd");
19557 let in_z = g.input("z", Shape::new(&[n], DType::C64));
19558 let out = g.complex_norm_sq(in_z);
19559 g.set_outputs(vec![out]);
19560 let plan = rlx_opt::memory::plan_memory(&g);
19561 let mut arena = crate::arena::Arena::from_plan(plan);
19562 let sched = compile_thunks(&g, &arena);
19563 let z_off = arena.byte_offset(in_z);
19564 let out_off = arena.byte_offset(out);
19565 let buf = arena.raw_buf_mut();
19566 unsafe {
19567 let pz = buf.as_mut_ptr().add(z_off) as *mut f32;
19568 for (i, &(re, im)) in z.iter().enumerate() {
19569 *pz.add(2 * i) = re;
19570 *pz.add(2 * i + 1) = im;
19571 }
19572 }
19573 execute_thunks(&sched, arena.raw_buf_mut());
19574 unsafe {
19575 let p = arena.raw_buf().as_ptr().add(out_off) as *const f32;
19576 (0..n).map(|i| *p.add(i)).collect()
19577 }
19578 }
19579
19580 fn run_complex_norm_sq_bwd(z: &[(f32, f32)], g: &[f32]) -> Vec<(f32, f32)> {
19582 let n = z.len();
19583 let mut gr = Graph::new("cns_bwd");
19584 let in_z = gr.input("z", Shape::new(&[n], DType::C64));
19585 let in_g = gr.input("g", Shape::new(&[n], DType::F32));
19586 let out = gr.complex_norm_sq_backward(in_z, in_g);
19587 gr.set_outputs(vec![out]);
19588 let plan = rlx_opt::memory::plan_memory(&gr);
19589 let mut arena = crate::arena::Arena::from_plan(plan);
19590 let sched = compile_thunks(&gr, &arena);
19591 let z_off = arena.byte_offset(in_z);
19592 let g_off = arena.byte_offset(in_g);
19593 let out_off = arena.byte_offset(out);
19594 let buf = arena.raw_buf_mut();
19595 unsafe {
19596 let pz = buf.as_mut_ptr().add(z_off) as *mut f32;
19597 let pg = buf.as_mut_ptr().add(g_off) as *mut f32;
19598 for (i, &(re, im)) in z.iter().enumerate() {
19599 *pz.add(2 * i) = re;
19600 *pz.add(2 * i + 1) = im;
19601 }
19602 for (i, &v) in g.iter().enumerate() {
19603 *pg.add(i) = v;
19604 }
19605 }
19606 execute_thunks(&sched, arena.raw_buf_mut());
19607 unsafe {
19608 let p = arena.raw_buf().as_ptr().add(out_off) as *const f32;
19609 (0..n).map(|i| (*p.add(2 * i), *p.add(2 * i + 1))).collect()
19610 }
19611 }
19612
/// |3 + 4i|² = 25, |1|² = 1, |0|² = 0.
#[test]
fn complex_norm_sq_matches_textbook() {
    let samples = [(3.0_f32, 4.0_f32), (1.0_f32, 0.0_f32), (0.0_f32, 0.0_f32)];
    let out = run_complex_norm_sq(&samples);

    // 3-4-5 triangle, the real unit, and the origin.
    assert!((out[0] - 25.0).abs() < 1e-5);
    assert!((out[1] - 1.0).abs() < 1e-6);
    assert!(out[2].abs() < 1e-6);
}
19624
/// With a unit upstream gradient the backward result must equal `z` itself
/// (the test's convention is `dz = g · z`).
#[test]
fn complex_norm_sq_backward_matches_wirtinger_formula() {
    let points = [(3.0_f32, 4.0_f32), (1.5_f32, -2.5_f32)];
    let upstream = [1.0_f32; 2];
    let dz = run_complex_norm_sq_bwd(&points, &upstream);

    for (i, label) in ["dz[0] = g·z[0]", "dz[1] = g·z[1]"].iter().enumerate() {
        assert_close_c(dz[i], points[i], 1e-6, label);
    }
}
19634
/// The backward result scales linearly with the upstream gradient:
/// 0.5 · (2, 1) = (1, 0.5) and -2 · (-1, 3) = (2, -6).
#[test]
fn complex_norm_sq_backward_scales_with_upstream() {
    let points = [(2.0_f32, 1.0_f32), (-1.0_f32, 3.0_f32)];
    let upstream = [0.5_f32, -2.0_f32];
    let dz = run_complex_norm_sq_bwd(&points, &upstream);

    let expected = [((1.0, 0.5), "g=0.5 · (2,1)"), ((2.0, -6.0), "g=-2 · (-1,3)")];
    for (i, (want, label)) in expected.iter().enumerate() {
        assert_close_c(dz[i], *want, 1e-6, label);
    }
}
19644
/// A `custom_fn_multi` whose body returns `[x², 2x]` must expose both
/// results through `handle.output(..)`, each readable from its own arena slot.
#[test]
fn custom_fn_multi_extracts_each_subgraph_output() {
    use rlx_ir::ops::special::MultiOutputHandle;

    // Built and dropped immediately.
    // NOTE(review): presumably a compile-time check that the handle's fields
    // stay constructible from here — confirm intent.
    let _ = MultiOutputHandle {
        source: NodeId(0),
        sub_shapes: vec![],
        offsets: vec![],
    };

    let vec3 = Shape::new(&[3], DType::F32);

    // Body sub-graph with two outputs: x * x and 2 * x.
    let body = {
        let mut sub = Graph::new("multi_body");
        let x = sub.input("x", vec3.clone());
        let x_sq = sub.binary(BinaryOp::Mul, x, x, vec3.clone());
        let two = sub.add_node(
            Op::Constant {
                // Three little-endian f32 twos, back to back.
                data: 2.0_f32.to_le_bytes().repeat(3),
            },
            vec![],
            vec3.clone(),
        );
        let two_x = sub.binary(BinaryOp::Mul, two, x, vec3.clone());
        sub.set_outputs(vec![x_sq, two_x]);
        sub
    };

    let mut outer = Graph::new("multi_outer");
    let in_x = outer.input("xin", vec3.clone());
    let handle = outer.custom_fn_multi(vec![in_x], body);
    assert_eq!(handle.n_outputs(), 2);
    let out0 = handle.output(&mut outer, 0);
    let out1 = handle.output(&mut outer, 1);
    outer.set_outputs(vec![out0, out1]);

    let mut arena = crate::arena::Arena::from_plan(rlx_opt::memory::plan_memory(&outer));
    let sched = compile_thunks(&outer, &arena);
    let xin_off = arena.byte_offset(in_x);
    let out0_off = arena.byte_offset(out0);
    let out1_off = arena.byte_offset(out1);

    let xs = [1.0_f32, 2.0, 3.0];
    // SAFETY: writes three f32 values into the `[3]` F32 input node. Assumes
    // the arena region is large enough and f32-aligned — TODO confirm.
    unsafe {
        let dst = arena.raw_buf_mut().as_mut_ptr().add(xin_off) as *mut f32;
        for (i, &v) in xs.iter().enumerate() {
            dst.add(i).write(v);
        }
    }
    execute_thunks(&sched, arena.raw_buf_mut());

    // Reads three consecutive f32 values starting at `byte_off`.
    let read3 = |byte_off: usize| -> Vec<f32> {
        // SAFETY: only called with the offsets of the two output nodes;
        // presumes each holds three f32 values (same arena assumption).
        unsafe {
            let src = arena.raw_buf().as_ptr().add(byte_off) as *const f32;
            (0..3).map(|i| src.add(i).read()).collect()
        }
    };
    let out0_v = read3(out0_off);
    let out1_v = read3(out1_off);

    for (i, &x) in xs.iter().enumerate() {
        assert!(
            (out0_v[i] - x * x).abs() < 1e-5,
            "out0[{i}] = {} != x² = {}",
            out0_v[i],
            x * x
        );
        assert!(
            (out1_v[i] - 2.0 * x).abs() < 1e-5,
            "out1[{i}] = {} != 2x = {}",
            out1_v[i],
            2.0 * x
        );
    }
}
19728
#[test]
/// Cross-checks the norm-squared gradient at z = (3, 4) three ways:
/// a forward finite difference along the real axis, one along the imaginary
/// axis, and twice the value returned by the backward helper with an
/// upstream gradient of 1.
fn complex_norm_sq_gradient_matches_finite_difference() {
    let (re0, im0) = (3.0_f32, 4.0_f32);
    let step = 1e-3_f32;

    // Forward value at a single complex point.
    let norm_sq_at = |re: f32, im: f32| run_complex_norm_sq(&[(re, im)])[0];

    let base = norm_sq_at(re0, im0);

    // Real-axis perturbation.
    let slope_re = (norm_sq_at(re0 + step, im0) - base) / step;
    let analytic_re = 2.0 * re0;
    assert!((slope_re - analytic_re).abs() < 1e-2);

    // Imaginary-axis perturbation.
    let slope_im = (norm_sq_at(re0, im0 + step) - base) / step;
    let analytic_im = 2.0 * im0;
    assert!((slope_im - analytic_im).abs() < 1e-2);

    // The backward helper's components, doubled, must equal the analytic
    // partial derivatives.
    let dz = run_complex_norm_sq_bwd(&[(re0, im0)], &[1.0_f32]);
    assert!((2.0 * dz[0].0 - analytic_re).abs() < 1e-5);
    assert!((2.0 * dz[0].1 - analytic_im).abs() < 1e-5);
}
19757
#[test]
/// Adds a `[bh, h, w, h, w]` tensor to a `[bh, h, w, 1, w]` tensor (singleton
/// in the fourth axis) through the compiled thunk schedule and compares the
/// result element-wise against a reference computed with explicit nested
/// loops.
///
/// The comparison treats a NaN difference as a failure: a plain
/// `d > max_diff` scan would skip NaN entirely and let a corrupted output
/// pass.
fn binary_full_5d_mid_singleton_broadcast() {
    let bh = 2usize;
    let h = 3;
    let w = 4;
    let f = DType::F32;

    let mut g = Graph::new("bcast_5d");
    let lhs = g.input("lhs", Shape::new(&[bh, h, w, h, w], f));
    let rhs = g.input("rhs", Shape::new(&[bh, h, w, 1, w], f));
    let out = g.binary(BinaryOp::Add, lhs, rhs, Shape::new(&[bh, h, w, h, w], f));
    g.set_outputs(vec![out]);

    let lhs_data: Vec<f32> = (0..bh * h * w * h * w).map(|i| i as f32 * 0.01).collect();
    let rhs_data: Vec<f32> = (0..bh * h * w * w)
        .map(|i| (i as f32 + 100.0) * 0.01)
        .collect();

    // Reference result: the rhs index simply drops the broadcast `hk` axis.
    let mut expected = vec![0f32; bh * h * w * h * w];
    for b_ in 0..bh {
        for hq in 0..h {
            for wq in 0..w {
                for hk in 0..h {
                    for wk in 0..w {
                        let li = (((b_ * h + hq) * w + wq) * h + hk) * w + wk;
                        let ri = ((b_ * h + hq) * w + wq) * w + wk;
                        expected[li] = lhs_data[li] + rhs_data[ri];
                    }
                }
            }
        }
    }

    let plan = rlx_opt::memory::plan_memory(&g);
    let mut arena = crate::arena::Arena::from_plan(plan);
    let sched = compile_thunks(&g, &arena);
    let lhs_off = arena.byte_offset(lhs);
    let rhs_off = arena.byte_offset(rhs);
    let out_off = arena.byte_offset(out);
    let buf = arena.raw_buf_mut();
    // SAFETY: assumes the memory plan reserved aligned f32 storage of the
    // declared shapes at `lhs_off` / `rhs_off`.
    unsafe {
        let p = buf.as_mut_ptr().add(lhs_off) as *mut f32;
        for (i, &v) in lhs_data.iter().enumerate() {
            *p.add(i) = v;
        }
        let p = buf.as_mut_ptr().add(rhs_off) as *mut f32;
        for (i, &v) in rhs_data.iter().enumerate() {
            *p.add(i) = v;
        }
    }
    execute_thunks(&sched, arena.raw_buf_mut());
    // SAFETY: assumes the output region at `out_off` holds
    // `bh * h * w * h * w` aligned f32 values.
    let actual: Vec<f32> = unsafe {
        let p = arena.raw_buf().as_ptr().add(out_off) as *const f32;
        (0..bh * h * w * h * w).map(|i| *p.add(i)).collect()
    };

    let mut max_diff = 0f32;
    let mut max_idx = 0;
    for (i, (&got, &want)) in actual.iter().zip(expected.iter()).enumerate() {
        let d = (got - want).abs();
        // A NaN difference is recorded explicitly. Once `max_diff` is NaN,
        // `d > max_diff` is false for every finite `d`, so the NaN sticks
        // and the assertion below fails.
        if d.is_nan() || d > max_diff {
            max_diff = d;
            max_idx = i;
        }
    }
    assert!(
        max_diff < 1e-6,
        "5D mid-shape singleton broadcast wrong: max |Δ| = {max_diff} at idx {max_idx} \
         (actual={}, expected={})",
        actual[max_idx],
        expected[max_idx]
    );
}
19840
#[test]
/// Smoke test for two raw kernels: `layer_norm2d_nchw` must move `out[0]`
/// more than 0.1 away from `mean0`, and `conv_transpose2d_nchw` with a single
/// input value 2.0 and weights `[1, 0, 0, 1]` must write 2.0 into `up[0]` and
/// `up[3]`.
fn layer_norm2d_and_conv_transpose2d_kernels() {
    // --- layer_norm2d_nchw ---
    let pixels = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0];
    let scale = [1.0, 1.0];
    let shift = [0.0, 0.0];
    let mut out = vec![0f32; 8];
    crate::kernels::layer_norm2d_nchw(&pixels, &scale, &shift, &mut out, 1, 2, 2, 2, 1e-5);
    let mean0: f32 = (1.0 + 3.0) / 2.0;
    assert!((out[0] - mean0).abs() > 0.1);

    // --- conv_transpose2d_nchw ---
    let source = [2.0];
    let weights = [1.0, 0.0, 0.0, 1.0];
    let mut up = vec![0f32; 4];
    // The sixteen trailing integers are passed in exactly their original
    // order; their meanings are defined by the kernel signature, not here.
    crate::kernels::conv_transpose2d_nchw(
        &source, &weights, &mut up,
        1, 1, 1, 1, 1,
        2, 2, 2, 2, 2, 2,
        0, 0,
        1, 1, 1,
    );
    for corner in [0usize, 3] {
        assert!((up[corner] - 2.0).abs() < 1e-5);
    }
}
19883}