1#![allow(unsafe_op_in_unsafe_fn)]
24use crate::arena::Arena;
27use crate::op_registry::CpuKernel;
28use rlx_ir::op::{Activation, BinaryOp, CmpOp, ReduceOp};
29use rlx_ir::{Graph, NodeId, Op, Shape};
30use std::collections::HashMap;
31use std::sync::Arc;
32
/// One step of a [`ThunkSchedule`]: an operation bundled with the buffer
/// locations and sizes it works on.
///
/// NOTE(review): the `usize` location fields (`src`, `dst`, `a`, `*_off`, ...)
/// look like byte offsets into the execution arena. `node_offset` in this
/// module yields exactly such offsets, or `usize::MAX` for a node without a
/// buffer.
///
/// `thunk_read_offsets` treats `0` as "optional input absent" for
/// `GatedDeltaNet::state` and `AttentionBackward::mask`. Confirm the
/// absent-input convention per variant.
///
/// The `u32` fields appear to be element counts and dimensions. Unless stated
/// otherwise, the per-variant descriptions below are inferred from the variant
/// and field names and should be checked against the executor.
#[derive(Clone)]
pub enum Thunk {
    /// No-op placeholder; removed by [`ThunkSchedule::strip_nops`].
    Nop,
    /// f32 matrix multiply (presumably `c[m×n] = a[m×k] · b[k×n]`).
    Sgemm {
        a: usize,
        b: usize,
        c: usize,
        m: u32,
        k: u32,
        n: u32,
    },
    /// Complex (`C64`) matrix multiply with the same layout fields as `Sgemm`.
    CgemmC64 {
        a: usize,
        b: usize,
        c: usize,
        m: u32,
        k: u32,
        n: u32,
    },
    /// f64 dense linear solve: `n×n` system `a`, `nrhs` right-hand sides in
    /// `b`, solution written to `x` (presumably solves `a · x = b`).
    DenseSolveF64 {
        a: usize,
        b: usize,
        x: usize,
        n: u32,
        nrhs: u32,
    },
    /// f32 counterpart of `DenseSolveF64`.
    DenseSolveF32 {
        a: usize,
        b: usize,
        x: usize,
        n: u32,
        nrhs: u32,
    },
    /// `batch` independent f64 dense solves.
    BatchedDenseSolveF64 {
        a: usize,
        b: usize,
        x: usize,
        batch: u32,
        n: u32,
        nrhs: u32,
    },
    /// `batch` independent f32 dense solves.
    BatchedDenseSolveF32 {
        a: usize,
        b: usize,
        x: usize,
        batch: u32,
        n: u32,
        nrhs: u32,
    },
    /// `batch` independent f64 matrix multiplies.
    BatchedDgemmF64 {
        a: usize,
        b: usize,
        c: usize,
        batch: u32,
        m: u32,
        k: u32,
        n: u32,
    },
    /// `batch` independent f32 matrix multiplies.
    BatchedSgemm {
        a: usize,
        b: usize,
        c: usize,
        batch: u32,
        m: u32,
        k: u32,
        n: u32,
    },
    /// f64 matrix multiply (f64 counterpart of `Sgemm`).
    Dgemm {
        a: usize,
        b: usize,
        c: usize,
        m: u32,
        k: u32,
        n: u32,
    },
    /// f64 axis permutation driven by the output dims and input strides.
    TransposeF64 {
        src: usize,
        dst: usize,
        in_total: u32,
        out_dims: Vec<u32>,
        in_strides: Vec<u32>,
    },
    /// Element-wise f64 activation `kind` over `len` values.
    ActivationF64 {
        src: usize,
        dst: usize,
        len: u32,
        kind: Activation,
    },
    /// Squared magnitude of complex elements, f32 output (presumably `|z|²`).
    ComplexNormSqF32 {
        src: usize,
        dst: usize,
        len: u32,
    },
    /// Backward of `ComplexNormSqF32`: reads `z` and upstream `g`, writes `dz`.
    ComplexNormSqBackwardF32 {
        z: usize,
        g: usize,
        dz: usize,
        len: u32,
    },
    /// Complex conjugate of `len` complex64 elements.
    ConjugateC64 {
        src: usize,
        dst: usize,
        len: u32,
    },
    /// Element-wise activation `kind` on complex64 data.
    ActivationC64 {
        src: usize,
        dst: usize,
        len: u32,
        kind: Activation,
    },
    /// f64 sum over the `reduced` axis of an `outer × reduced × inner` view.
    ReduceSumF64 {
        src: usize,
        dst: usize,
        outer: u32,
        reduced: u32,
        inner: u32,
    },
    /// Copies `len` f64 elements from `src` to `dst`.
    CopyF64 {
        src: usize,
        dst: usize,
        len: u32,
    },
    /// Copies `len` i64 elements from `src` to `dst`.
    CopyI64 {
        src: usize,
        dst: usize,
        len: u32,
    },
    /// Converts `len` f32 values to i64.
    CastF32ToI64 {
        src: usize,
        dst: usize,
        len: u32,
    },
    /// Converts `len` f32 values to f64.
    CastF32ToF64 {
        src: usize,
        dst: usize,
        len: u32,
    },
    /// Converts `len` f32 values to i32.
    CastF32ToI32 {
        src: usize,
        dst: usize,
        len: u32,
    },
    /// Converts `len` i64 values to f32.
    CastI64ToF32 {
        src: usize,
        dst: usize,
        len: u32,
    },
    /// Converts `len` bool values to i32.
    CastBoolToI32 {
        src: usize,
        dst: usize,
        len: u32,
    },
    /// Converts `len` bool values to f32.
    CastBoolToF32 {
        src: usize,
        dst: usize,
        len: u32,
    },
    /// Converts `len` i32 values to f32.
    CastI32ToF32 {
        src: usize,
        dst: usize,
        len: u32,
    },
    /// Element-wise binary `op` on f64 operands. Broadcasting is described by
    /// `out_dims_bcast` and the per-operand `bcast_*_strides`.
    BinaryFullF64 {
        lhs: usize,
        rhs: usize,
        dst: usize,
        len: u32,
        lhs_len: u32,
        rhs_len: u32,
        op: BinaryOp,
        out_dims_bcast: Vec<u32>,
        bcast_lhs_strides: Vec<u32>,
        bcast_rhs_strides: Vec<u32>,
    },
    /// f64 concatenation into `dst`.
    ///
    /// The first element of each `inputs` tuple is the source offset (it is
    /// what `thunk_read_offsets` reports). The two `u32`s still need
    /// confirming.
    ConcatF64 {
        dst: usize,
        outer: u32,
        inner: u32,
        total_axis: u32,
        inputs: Vec<(usize, u32, u32)>,
    },
    /// Element-wise binary `op` on complex64 operands with broadcasting.
    BinaryFullC64 {
        lhs: usize,
        rhs: usize,
        dst: usize,
        len: u32,
        lhs_len: u32,
        rhs_len: u32,
        op: BinaryOp,
        out_dims_bcast: Vec<u32>,
        bcast_lhs_strides: Vec<u32>,
        bcast_rhs_strides: Vec<u32>,
    },
    /// Runs the nested `body` schedule `length` times, presumably threading
    /// `carry_bytes` of carry state from step to step.
    Scan {
        body: Arc<ThunkSchedule>,
        body_init: Arc<Vec<u8>>,
        body_input_off: usize,
        body_output_off: usize,
        outer_init_off: usize,
        outer_final_off: usize,
        length: u32,
        carry_bytes: u32,
        save_trajectory: bool,
        /// Per-step inputs. The second tuple element is the outer offset read
        /// by `thunk_read_offsets`; the other two are unconfirmed.
        xs_inputs: Arc<Vec<(usize, usize, u32)>>,
        /// Inputs shared by every step (same tuple shape as `xs_inputs`).
        bcast_inputs: Arc<Vec<(usize, usize, u32)>>,
        num_checkpoints: u32,
    },

    /// Reverse pass over a `Scan` using the `body_vjp` schedule.
    ///
    /// NOTE(review): `forward_body*` presumably re-runs the forward body to
    /// rebuild carries that were not saved (see `save_trajectory` and
    /// `num_checkpoints`). Confirm.
    ScanBackward {
        body_vjp: Arc<ThunkSchedule>,
        body_init: Arc<Vec<u8>>,
        body_carry_in_off: usize,
        body_x_offs: Arc<Vec<usize>>,
        body_d_output_off: usize,
        body_dcarry_out_off: usize,
        outer_init_off: usize,
        outer_traj_off: usize,
        outer_upstream_off: usize,
        outer_xs_offs: Arc<Vec<(usize, u32)>>,
        outer_dinit_off: usize,
        length: u32,
        carry_bytes: u32,
        carry_elem_size: u32,
        save_trajectory: bool,
        num_checkpoints: u32,
        forward_body: Option<Arc<ThunkSchedule>>,
        forward_body_init: Option<Arc<Vec<u8>>>,
        forward_body_carry_in_off: usize,
        forward_body_output_off: usize,
        forward_body_x_offs: Arc<Vec<usize>>,
    },

    /// Like `ScanBackward`, but also writes per-step input gradients
    /// (`body_dxs_out_off` → `outer_dxs_off`, `per_step_bytes` each).
    ScanBackwardXs {
        body_vjp: Arc<ThunkSchedule>,
        body_init: Arc<Vec<u8>>,
        body_carry_in_off: usize,
        body_x_offs: Arc<Vec<usize>>,
        body_d_output_off: usize,
        body_dcarry_out_off: usize,
        body_dxs_out_off: usize,
        outer_init_off: usize,
        outer_traj_off: usize,
        outer_upstream_off: usize,
        outer_xs_offs: Arc<Vec<(usize, u32)>>,
        outer_dxs_off: usize,
        length: u32,
        carry_bytes: u32,
        carry_elem_size: u32,
        per_step_bytes: u32,
        save_trajectory: bool,
        num_checkpoints: u32,
        forward_body: Option<Arc<ThunkSchedule>>,
        forward_body_init: Option<Arc<Vec<u8>>>,
        forward_body_carry_in_off: usize,
        forward_body_output_off: usize,
        forward_body_x_offs: Arc<Vec<usize>>,
    },
    /// Runs the nested `body` schedule once.
    ///
    /// The second element of each `inputs` tuple is the outer offset read (per
    /// `thunk_read_offsets`). The result is presumably copied from
    /// `body_output_off` to `outer_output_off` (`out_bytes` bytes).
    CustomFn {
        body: Arc<ThunkSchedule>,
        body_init: Arc<Vec<u8>>,
        inputs: Arc<Vec<(usize, usize, u32)>>,
        body_output_off: usize,
        outer_output_off: usize,
        out_bytes: u32,
    },
    /// Fused matmul + bias add + optional activation `act`.
    FusedMmBiasAct {
        a: usize,
        w: usize,
        bias: usize,
        c: usize,
        m: u32,
        k: u32,
        n: u32,
        act: Option<Activation>,
    },
    /// Fused residual add followed by LayerNorm (gain `g`, shift `b`).
    /// `has_bias` presumably gates the extra `bias` input.
    FusedResidualLN {
        x: usize,
        res: usize,
        bias: usize,
        g: usize,
        b: usize,
        out: usize,
        rows: u32,
        h: u32,
        eps: f32,
        has_bias: bool,
    },
    /// Fused residual add followed by RMSNorm (same fields as `FusedResidualLN`).
    FusedResidualRmsNorm {
        x: usize,
        res: usize,
        bias: usize,
        g: usize,
        b: usize,
        out: usize,
        rows: u32,
        h: u32,
        eps: f32,
        has_bias: bool,
    },
    /// Adds `bias` to an `m × n` `src` (presumably one bias per column).
    BiasAdd {
        src: usize,
        bias: usize,
        dst: usize,
        m: u32,
        n: u32,
    },
    /// Element-wise binary `op` with broadcasting. `elem_bytes` gives the
    /// element width.
    BinaryFull {
        lhs: usize,
        rhs: usize,
        dst: usize,
        len: u32,
        lhs_len: u32,
        rhs_len: u32,
        op: BinaryOp,
        out_dims_bcast: Vec<u32>,
        bcast_lhs_strides: Vec<u32>,
        bcast_rhs_strides: Vec<u32>,
        elem_bytes: u8,
    },
    /// Applies activation `act` to `len` elements of `data` in place.
    ActivationInPlace {
        data: usize,
        len: u32,
        act: Activation,
    },
    /// Gathers rows of `table` selected by `idx`.
    ///
    /// `idx_i64` presumably flags 64-bit indices; `table_bytes` is the
    /// element width.
    Gather {
        table: usize,
        table_len: u32,
        idx: usize,
        dst: usize,
        num_idx: u32,
        trailing: u32,
        idx_i64: u8,
        table_bytes: u8,
    },
    /// Strided sub-range copy: `outer` runs of `inner` elements with separate
    /// source/destination strides.
    Narrow {
        src: usize,
        dst: usize,
        outer: u32,
        src_stride: u32,
        dst_stride: u32,
        inner: u32,
        elem_bytes: u8,
    },
    /// Copies `len` elements from `src` to `dst`.
    Copy {
        src: usize,
        dst: usize,
        len: u32,
    },
    /// Row-wise LayerNorm over `rows × h` with gain `g` and shift `b`.
    LayerNorm {
        src: usize,
        g: usize,
        b: usize,
        dst: usize,
        rows: u32,
        h: u32,
        eps: f32,
    },
    /// GroupNorm over an `n × c × h × w` tensor with `num_groups` groups.
    GroupNorm {
        src: usize,
        g: usize,
        b: usize,
        dst: usize,
        n: u32,
        c: u32,
        h: u32,
        w: u32,
        num_groups: u32,
        eps: f32,
    },
    /// Inference-mode BatchNorm using stored per-channel `mean` and `var`.
    BatchNormInference {
        src: usize,
        g: usize,
        b: usize,
        mean: usize,
        var: usize,
        dst: usize,
        count: u32,
        channels: u32,
        eps: f32,
    },
    /// Gradient of `BatchNormInference` with respect to its input.
    BatchNormInferenceBackwardInput {
        x: usize,
        gamma: usize,
        mean: usize,
        var: usize,
        dy: usize,
        dx: usize,
        count: u32,
        channels: u32,
        eps: f32,
    },
    /// Gradient of `BatchNormInference` with respect to gamma.
    BatchNormInferenceBackwardGamma {
        x: usize,
        mean: usize,
        var: usize,
        dy: usize,
        dgamma: usize,
        count: u32,
        channels: u32,
        eps: f32,
    },
    /// Gradient of `BatchNormInference` with respect to beta.
    BatchNormInferenceBackwardBeta {
        dy: usize,
        dbeta: usize,
        count: u32,
        channels: u32,
    },
    /// LayerNorm on an `n × c × h × w` tensor (presumably normalizing over `c`).
    LayerNorm2d {
        src: usize,
        g: usize,
        b: usize,
        dst: usize,
        n: u32,
        c: u32,
        h: u32,
        w: u32,
        eps: f32,
    },
    /// Transposed 2-D convolution.
    ///
    /// `s*`, `p*` and `d*` are presumably stride, padding and dilation per axis.
    ConvTranspose2d {
        src: usize,
        weight: usize,
        dst: usize,
        n: u32,
        c_in: u32,
        h: u32,
        w_in: u32,
        c_out: u32,
        h_out: u32,
        w_out: u32,
        kh: u32,
        kw: u32,
        sh: u32,
        sw: u32,
        ph: u32,
        pw: u32,
        dh: u32,
        dw: u32,
        groups: u32,
    },
    /// 2× nearest-neighbour upsample of an `n × c × h × w` tensor.
    ResizeNearest2x {
        src: usize,
        dst: usize,
        n: u32,
        c: u32,
        h: u32,
        w: u32,
    },
    /// 2-D axial rotary position embedding over an `end_x × end_y` grid.
    AxialRope2d {
        src: usize,
        dst: usize,
        batch: u32,
        seq: u32,
        hidden: u32,
        end_x: u32,
        end_y: u32,
        head_dim: u32,
        num_heads: u32,
        theta: f32,
        repeat_factor: u32,
    },
    /// Row-wise RMSNorm over `rows × h` with gain `g` (and `b`).
    RmsNorm {
        src: usize,
        g: usize,
        b: usize,
        dst: usize,
        rows: u32,
        h: u32,
        eps: f32,
    },
    /// Row-wise softmax over `rows × cols`, operating on `data`.
    Softmax {
        data: usize,
        rows: u32,
        cols: u32,
    },
    /// Row-wise cumulative sum; `exclusive` selects an exclusive scan.
    Cumsum {
        src: usize,
        dst: usize,
        rows: u32,
        cols: u32,
        exclusive: bool,
    },
    /// Selective state-space scan (presumably Mamba-style).
    SelectiveScan {
        x: usize,
        delta: usize,
        a: usize,
        b: usize,
        c: usize,
        dst: usize,
        batch: u32,
        seq: u32,
        hidden: u32,
        state_size: u32,
    },

    /// Gated DeltaNet recurrence.
    ///
    /// `thunk_read_offsets` treats `state == 0` as "no initial state".
    GatedDeltaNet {
        q: usize,
        k: usize,
        v: usize,
        g: usize,
        beta: usize,
        state: usize,
        dst: usize,
        batch: u32,
        seq: u32,
        heads: u32,
        state_size: u32,
    },

    /// Multi-layer, optionally bidirectional LSTM.
    Lstm {
        x: usize,
        w_ih: usize,
        w_hh: usize,
        bias: usize,
        h0: usize,
        c0: usize,
        dst: usize,
        batch: u32,
        seq: u32,
        input_size: u32,
        hidden: u32,
        num_layers: u32,
        bidirectional: bool,
        carry: bool,
    },

    /// 1×1 convolution over `hw` spatial positions.
    Conv2D1x1 {
        src: usize,
        weight: usize,
        dst: usize,
        n: u32,
        c_in: u32,
        c_out: u32,
        hw: u32,
    },

    /// Matmul against block-quantized int8 weights with per-block `scale`
    /// (and `zp` when `is_asymmetric`); compare `dequant_matmul_int8`.
    DequantMatMul {
        x: usize,
        w_q: usize,
        scale: usize,
        zp: usize,
        dst: usize,
        m: u32,
        k: u32,
        n: u32,
        block_size: u32,
        is_asymmetric: bool,
    },

    /// Matmul against GGUF-packed weights encoded with `scheme`.
    DequantMatMulGguf {
        x: usize,
        w_q: usize,
        dst: usize,
        m: u32,
        k: u32,
        n: u32,
        scheme: rlx_ir::quant::QuantScheme,
    },

    /// Int4 counterpart of `DequantMatMul`; compare `dequant_matmul_int4`.
    DequantMatMulInt4 {
        x: usize,
        w_q: usize,
        scale: usize,
        zp: usize,
        dst: usize,
        m: u32,
        k: u32,
        n: u32,
        block_size: u32,
        is_asymmetric: bool,
    },

    /// Matmul against FP8 weights; `e5m2` presumably selects E5M2 vs E4M3
    /// decoding (compare `dequant_matmul_fp8`).
    DequantMatMulFp8 {
        x: usize,
        w_q: usize,
        scale: usize,
        dst: usize,
        m: u32,
        k: u32,
        n: u32,
        e5m2: bool,
    },

    /// Matmul against NVFP4 weights; compare `dequant_matmul_nvfp4`.
    DequantMatMulNvfp4 {
        x: usize,
        w_q: usize,
        scale: usize,
        global_scale: usize,
        dst: usize,
        m: u32,
        k: u32,
        n: u32,
    },

    /// Matmul with a low-rank adapter: base `w` plus `a`/`b` of rank `r`,
    /// scaled by `scale`.
    LoraMatMul {
        x: usize,
        w: usize,
        a: usize,
        b: usize,
        dst: usize,
        m: u32,
        k: u32,
        n: u32,
        r: u32,
        scale: f32,
    },
    /// Token sampling from `batch × vocab` logits using top-k / top-p /
    /// temperature and `seed`.
    Sample {
        logits: usize,
        dst: usize,
        batch: u32,
        vocab: u32,
        top_k: u32,
        top_p: f32,
        temperature: f32,
        seed: u64,
    },
    /// Fills `len` values with normal samples (`mean`, `scale`); reads no
    /// buffers.
    RngNormal {
        dst: usize,
        len: u32,
        mean: f32,
        scale: f32,
        key: u64,
        op_seed: Option<f32>,
    },
    /// Fills `len` values with uniform samples in `[low, high)` (bounds
    /// presumably); reads no buffers.
    RngUniform {
        dst: usize,
        len: u32,
        low: f32,
        high: f32,
        key: u64,
        op_seed: Option<f32>,
    },
    /// Scaled dot-product attention. `bhsd` presumably selects a
    /// batch-head-seq-dim layout; the `*_row_stride`s allow strided Q/K/V.
    Attention {
        q: usize,
        k: usize,
        v: usize,
        mask: usize,
        out: usize,
        batch: u32,
        seq: u32,
        kv_seq: u32,
        heads: u32,
        head_dim: u32,
        mask_kind: rlx_ir::op::MaskKind,
        scale: f32,
        q_row_stride: u32,
        k_row_stride: u32,
        v_row_stride: u32,
        bhsd: bool,
    },
    /// Attention gradient with respect to the input selected by `wrt`.
    /// `mask == 0` is treated as "no mask" by `thunk_read_offsets`.
    AttentionBackward {
        q: usize,
        k: usize,
        v: usize,
        dy: usize,
        mask: usize,
        out: usize,
        batch: u32,
        seq: u32,
        kv_seq: u32,
        heads: u32,
        head_dim: u32,
        mask_kind: rlx_ir::op::MaskKind,
        wrt: rlx_ir::op::AttentionBwdWrt,
        bhsd: bool,
    },
    /// Rotary position embedding on the first `n_rot` dims of each head,
    /// using `cos`/`sin` tables of length `cos_len`.
    Rope {
        src: usize,
        cos: usize,
        sin: usize,
        dst: usize,
        batch: u32,
        seq: u32,
        hidden: u32,
        head_dim: u32,
        n_rot: u32,
        cos_len: u32,
        src_row_stride: u32,
    },
    /// Fused attention block: QKV projection, attention, output projection,
    /// with optional bias (`has_bias`) and RoPE (`has_rope`).
    FusedAttnBlock {
        hidden: usize,
        qkv_w: usize,
        out_w: usize,
        mask: usize,
        out: usize,
        qkv_b: usize,
        out_b: usize,
        cos: usize,
        sin: usize,
        cos_len: u32,
        batch: u32,
        seq: u32,
        hs: u32,
        nh: u32,
        dh: u32,
        has_bias: bool,
        has_rope: bool,
    },
    /// Fully fused BERT encoder layer (attention, two LayerNorms, MLP).
    FusedBertLayer {
        hidden: usize,
        qkv_w: usize,
        qkv_b: usize,
        out_w: usize,
        out_b: usize,
        mask: usize,
        ln1_g: usize,
        ln1_b: usize,
        eps1: f32,
        fc1_w: usize,
        fc1_b: usize,
        fc2_w: usize,
        fc2_b: usize,
        ln2_g: usize,
        ln2_b: usize,
        eps2: f32,
        out: usize,
        batch: u32,
        seq: u32,
        hs: u32,
        nh: u32,
        dh: u32,
        int_dim: u32,
    },
    /// Fully fused Nomic encoder layer (rotary attention, gated MLP via
    /// `fc11_w`/`fc12_w`).
    FusedNomicLayer {
        hidden: usize,
        qkv_w: usize,
        out_w: usize,
        mask: usize,
        cos: usize,
        sin: usize,
        cos_len: u32,
        ln1_g: usize,
        ln1_b: usize,
        eps1: f32,
        fc11_w: usize,
        fc12_w: usize,
        fc2_w: usize,
        ln2_g: usize,
        ln2_b: usize,
        eps2: f32,
        out: usize,
        batch: u32,
        seq: u32,
        hs: u32,
        nh: u32,
        dh: u32,
        int_dim: u32,
    },
    /// SwiGLU over a tensor split in halves of `n_half`; `gate_first` picks
    /// which half is the gate.
    FusedSwiGLU {
        src: usize,
        dst: usize,
        n_half: u32,
        total: u32,
        gate_first: bool,
    },
    /// Concatenation into `dst`; the first tuple element of each input is its
    /// source offset.
    Concat {
        dst: usize,
        outer: u32,
        inner: u32,
        total_axis: u32,
        inputs: Vec<(usize, u32, u32)>,
    },
    /// Element-wise comparison `op`, with input and output element widths
    /// given explicitly.
    Compare {
        lhs: usize,
        rhs: usize,
        dst: usize,
        len: u32,
        op: CmpOp,
        inputs_i64: u8,
        inputs_elem_bytes: u8,
        dst_elem_bytes: u8,
    },
    /// Reduction `op` over the `reduced` axis of an `outer × reduced × inner`
    /// view.
    Reduce {
        src: usize,
        dst: usize,
        outer: u32,
        reduced: u32,
        inner: u32,
        op: ReduceOp,
    },
    /// Arg-max (`is_max`) or arg-min over the `reduced` axis.
    ArgReduce {
        src: usize,
        dst: usize,
        outer: u32,
        reduced: u32,
        inner: u32,
        is_max: bool,
    },
    /// Top-`k` selection along an axis of length `axis_dim`.
    TopK {
        src: usize,
        dst: usize,
        outer: u32,
        axis_dim: u32,
        k: u32,
        indices_i64: u8,
    },
    /// Mixture-of-experts matmul: each row uses the expert weight chosen by
    /// `expert_idx`.
    GroupedMatMul {
        input: usize,
        weight: usize,
        expert_idx: usize,
        dst: usize,
        m: u32,
        k_dim: u32,
        n: u32,
        num_experts: u32,
    },
    /// `GroupedMatMul` against GGUF-packed expert weights.
    DequantGroupedMatMulGguf {
        input: usize,
        w_q: usize,
        expert_idx: usize,
        dst: usize,
        m: u32,
        k_dim: u32,
        n: u32,
        num_experts: u32,
        scheme: rlx_ir::quant::QuantScheme,
    },
    /// Dequantizes GGUF-packed MoE expert weights into `dst`.
    DequantMoEWeightsGguf {
        w_q: usize,
        dst: usize,
        k_dim: u32,
        n: u32,
        num_experts: u32,
        scheme: rlx_ir::quant::QuantScheme,
    },
    /// Scatter-add of `updates` into `dst` at `indices`.
    ScatterAdd {
        updates: usize,
        indices: usize,
        dst: usize,
        num_updates: u32,
        out_dim: u32,
        trailing: u32,
    },
    /// Element-wise select between `on_true` and `on_false` by `cond`.
    Where {
        cond: usize,
        on_true: usize,
        on_false: usize,
        dst: usize,
        len: u32,
        elem_bytes: u8,
        cond_elem_bytes: u8,
    },
    /// Axis permutation for elements of width `elem_bytes`.
    Transpose {
        src: usize,
        dst: usize,
        in_total: u32,
        out_dims: Vec<u32>,
        in_strides: Vec<u32>,
        elem_bytes: u8,
    },
    /// Gather along an arbitrary axis of length `axis_dim`.
    GatherAxis {
        table: usize,
        idx: usize,
        dst: usize,
        outer: u32,
        axis_dim: u32,
        num_idx: u32,
        trailing: u32,
        idx_i64: u8,
        table_bytes: u8,
    },
    /// 2-D pooling whose reduction is selected by `kind`.
    Pool2D {
        src: usize,
        dst: usize,
        n: u32,
        c: u32,
        h: u32,
        w: u32,
        h_out: u32,
        w_out: u32,
        kh: u32,
        kw: u32,
        sh: u32,
        sw: u32,
        ph: u32,
        pw: u32,
        kind: ReduceOp,
    },
    /// General 2-D convolution (stride / padding / dilation / groups).
    Conv2D {
        src: usize,
        weight: usize,
        dst: usize,
        n: u32,
        c_in: u32,
        h: u32,
        w: u32,
        c_out: u32,
        h_out: u32,
        w_out: u32,
        kh: u32,
        kw: u32,
        sh: u32,
        sw: u32,
        ph: u32,
        pw: u32,
        dh: u32,
        dw: u32,
        groups: u32,
    },

    /// Quantized integer matmul with zero points and output multiplier `mult`.
    QMatMul {
        x: usize,
        w: usize,
        bias: usize,
        out: usize,
        m: u32,
        k: u32,
        n: u32,
        x_zp: i32,
        w_zp: i32,
        out_zp: i32,
        mult: f32,
    },

    /// Quantized integer 2-D convolution.
    QConv2d {
        x: usize,
        w: usize,
        bias: usize,
        out: usize,
        n: u32,
        c_in: u32,
        h: u32,
        w_in: u32,
        c_out: u32,
        h_out: u32,
        w_out: u32,
        kh: u32,
        kw: u32,
        sh: u32,
        sw: u32,
        ph: u32,
        pw: u32,
        dh: u32,
        dw: u32,
        groups: u32,
        x_zp: i32,
        w_zp: i32,
        out_zp: i32,
        mult: f32,
    },

    /// Quantizes `x` into `q` with (possibly per-channel) `scales` and
    /// `zero_points`.
    Quantize {
        x: usize,
        q: usize,
        len: u32,
        chan_axis: u32,
        chan_dim: u32,
        inner: u32,
        scales: Vec<f32>,
        zero_points: Vec<i32>,
    },

    /// Inverse of `Quantize`.
    Dequantize {
        q: usize,
        x: usize,
        len: u32,
        chan_axis: u32,
        chan_dim: u32,
        inner: u32,
        scales: Vec<f32>,
        zero_points: Vec<i32>,
    },

    /// Quantize-dequantize round-trip to `bits` for quantization-aware
    /// training.
    ///
    /// `state_off` presumably points at persistent scale state when the
    /// `scale_mode` needs it.
    FakeQuantize {
        x: usize,
        out: usize,
        len: u32,
        chan_axis: u32,
        chan_dim: u32,
        inner: u32,
        bits: u8,
        ste: rlx_ir::op::SteKind,
        scale_mode: rlx_ir::op::ScaleMode,
        state_off: Option<usize>,
    },

    /// Straight-through-estimator gradient of `FakeQuantize`.
    FakeQuantizeBackward {
        x: usize,
        dy: usize,
        dx: usize,
        len: u32,
        chan_axis: u32,
        chan_dim: u32,
        inner: u32,
        bits: u8,
        ste: rlx_ir::op::SteKind,
    },

    /// LSQ fake-quantization with a learned scale at `scale_off`.
    FakeQuantizeLSQ {
        x: usize,
        scale_off: usize,
        out: usize,
        len: u32,
        chan_axis: u32,
        chan_dim: u32,
        inner: u32,
        bits: u8,
    },

    /// LSQ gradient with respect to the input.
    FakeQuantizeLSQBackwardX {
        x: usize,
        scale_off: usize,
        dy: usize,
        dx: usize,
        len: u32,
        chan_axis: u32,
        chan_dim: u32,
        inner: u32,
        bits: u8,
    },

    /// LSQ gradient with respect to the learned scale.
    FakeQuantizeLSQBackwardScale {
        x: usize,
        scale_off: usize,
        dy: usize,
        dscale: usize,
        len: u32,
        chan_axis: u32,
        chan_dim: u32,
        inner: u32,
        bits: u8,
    },

    /// ReLU gradient, f32.
    ReluBackward {
        x: usize,
        dy: usize,
        dx: usize,
        len: u32,
    },
    /// ReLU gradient, f64.
    ReluBackwardF64 {
        x: usize,
        dy: usize,
        dx: usize,
        len: u32,
    },

    /// Gradient of activation `kind`, f32.
    ActivationBackward {
        x: usize,
        dy: usize,
        dx: usize,
        len: u32,
        kind: Activation,
    },
    /// Gradient of activation `kind`, f64.
    ActivationBackwardF64 {
        x: usize,
        dy: usize,
        dx: usize,
        len: u32,
        kind: Activation,
    },

    /// LayerNorm gradient with respect to the input.
    LayerNormBackwardInput {
        x: usize,
        gamma: usize,
        dy: usize,
        dx: usize,
        rows: u32,
        h: u32,
        eps: f32,
    },

    /// LayerNorm gradient with respect to gamma.
    LayerNormBackwardGamma {
        x: usize,
        dy: usize,
        dgamma: usize,
        rows: u32,
        h: u32,
        eps: f32,
    },

    /// RMSNorm gradient with respect to the input.
    RmsNormBackwardInput {
        x: usize,
        gamma: usize,
        beta: usize,
        dy: usize,
        dx: usize,
        rows: u32,
        h: u32,
        eps: f32,
    },
    /// RMSNorm gradient with respect to gamma.
    RmsNormBackwardGamma {
        x: usize,
        gamma: usize,
        beta: usize,
        dy: usize,
        dgamma: usize,
        rows: u32,
        h: u32,
        eps: f32,
    },
    /// RMSNorm gradient with respect to beta.
    RmsNormBackwardBeta {
        x: usize,
        gamma: usize,
        beta: usize,
        dy: usize,
        dbeta: usize,
        rows: u32,
        h: u32,
        eps: f32,
    },
    /// Gradient of `Rope`.
    RopeBackward {
        dy: usize,
        cos: usize,
        sin: usize,
        dx: usize,
        batch: u32,
        seq: u32,
        hidden: u32,
        head_dim: u32,
        n_rot: u32,
        cos_len: u32,
    },
    /// Gradient of `Cumsum`.
    CumsumBackward {
        dy: usize,
        dx: usize,
        rows: u32,
        cols: u32,
        exclusive: bool,
    },
    /// Gradient of an axis gather (scatter of `dy` back to `indices`).
    GatherBackward {
        dy: usize,
        indices: usize,
        dst: usize,
        outer: u32,
        axis_dim: u32,
        num_idx: u32,
        trailing: u32,
    },

    /// GroupNorm gradient with respect to the input.
    GroupNormBackwardInput {
        x: usize,
        gamma: usize,
        beta: usize,
        dy: usize,
        dx: usize,
        n: u32,
        c: u32,
        h: u32,
        w: u32,
        num_groups: u32,
        eps: f32,
    },
    /// GroupNorm gradient with respect to gamma.
    GroupNormBackwardGamma {
        x: usize,
        dy: usize,
        dgamma: usize,
        n: u32,
        c: u32,
        h: u32,
        w: u32,
        num_groups: u32,
        eps: f32,
    },
    /// GroupNorm gradient with respect to beta.
    GroupNormBackwardBeta {
        dy: usize,
        dbeta: usize,
        n: u32,
        c: u32,
        h: u32,
        w: u32,
    },

    /// Max-pool gradient, routed to the arg-max positions of `x`.
    MaxPool2dBackward {
        x: usize,
        dy: usize,
        dx: usize,
        n: u32,
        c: u32,
        h: u32,
        w: u32,
        h_out: u32,
        w_out: u32,
        kh: u32,
        kw: u32,
        sh: u32,
        sw: u32,
        ph: u32,
        pw: u32,
    },

    /// Convolution gradient with respect to the input.
    Conv2dBackwardInput {
        dy: usize,
        w: usize,
        dx: usize,
        n: u32,
        c_in: u32,
        h: u32,
        w_in: u32,
        c_out: u32,
        h_out: u32,
        w_out: u32,
        kh: u32,
        kw: u32,
        sh: u32,
        sw: u32,
        ph: u32,
        pw: u32,
        dh: u32,
        dw: u32,
        groups: u32,
    },

    /// Convolution gradient with respect to the weight.
    ///
    /// `dw` is the output offset here, so the width dilation is `dw_dil`.
    Conv2dBackwardWeight {
        x: usize,
        dy: usize,
        dw: usize,
        n: u32,
        c_in: u32,
        h: u32,
        w: u32,
        c_out: u32,
        h_out: u32,
        w_out: u32,
        kh: u32,
        kw: u32,
        sh: u32,
        sw: u32,
        ph: u32,
        pw: u32,
        dh: u32,
        dw_dil: u32,
        groups: u32,
    },

    /// Unfolds convolution patches of `x` into the column buffer `col`.
    Im2Col {
        x: usize,
        col: usize,
        n: u32,
        c_in: u32,
        h: u32,
        w: u32,
        h_out: u32,
        w_out: u32,
        kh: u32,
        kw: u32,
        sh: u32,
        sw: u32,
        ph: u32,
        pw: u32,
        dh: u32,
        dw_dil: u32,
    },

    /// Softmax cross-entropy loss over `n` rows of `c` classes.
    SoftmaxCrossEntropy {
        logits: usize,
        labels: usize,
        dst: usize,
        n: u32,
        c: u32,
    },

    /// Gradient of `SoftmaxCrossEntropy` with respect to the logits.
    SoftmaxCrossEntropyBackward {
        logits: usize,
        labels: usize,
        d_loss: usize,
        dlogits: usize,
        n: u32,
        c: u32,
    },

    /// User-registered CPU kernel with its inputs, output and opaque `attrs`.
    CustomOp {
        kernel: Arc<dyn CpuKernel>,
        inputs: Vec<(usize, u32, Shape)>,
        output: (usize, u32, Shape),
        attrs: Vec<u8>,
    },

    /// Gaussian-splat rendering of a `width × height` image in one pass.
    GaussianSplatRender {
        positions_off: usize,
        positions_len: usize,
        scales_off: usize,
        scales_len: usize,
        rotations_off: usize,
        rotations_len: usize,
        opacities_off: usize,
        opacities_len: usize,
        colors_off: usize,
        colors_len: usize,
        sh_coeffs_off: usize,
        sh_coeffs_len: usize,
        meta_off: usize,
        dst_off: usize,
        dst_len: usize,
        width: u32,
        height: u32,
        tile_size: u32,
        radius_scale: f32,
        alpha_cutoff: f32,
        max_splat_steps: u32,
        transmittance_threshold: f32,
        max_list_entries: u32,
    },
    /// Gradient of `GaussianSplatRender`, written into the `packed` buffer.
    GaussianSplatRenderBackward {
        positions_off: usize,
        positions_len: usize,
        scales_off: usize,
        scales_len: usize,
        rotations_off: usize,
        rotations_len: usize,
        opacities_off: usize,
        opacities_len: usize,
        colors_off: usize,
        colors_len: usize,
        sh_coeffs_off: usize,
        sh_coeffs_len: usize,
        meta_off: usize,
        d_loss_off: usize,
        d_loss_len: usize,
        packed_off: usize,
        packed_len: usize,
        width: u32,
        height: u32,
        tile_size: u32,
        radius_scale: f32,
        alpha_cutoff: f32,
        max_splat_steps: u32,
        transmittance_threshold: f32,
        max_list_entries: u32,
        loss_grad_clip: f32,
        sh_band: u32,
        max_anisotropy: f32,
    },
    /// First half of a split render: prepares per-splat data into `prep`.
    GaussianSplatPrepare {
        positions_off: usize,
        positions_len: usize,
        scales_off: usize,
        scales_len: usize,
        rotations_off: usize,
        rotations_len: usize,
        opacities_off: usize,
        opacities_len: usize,
        colors_off: usize,
        colors_len: usize,
        sh_coeffs_off: usize,
        sh_coeffs_len: usize,
        meta_off: usize,
        meta_len: usize,
        prep_off: usize,
        prep_len: usize,
        width: u32,
        height: u32,
        tile_size: u32,
        radius_scale: f32,
        alpha_cutoff: f32,
        max_splat_steps: u32,
        transmittance_threshold: f32,
        max_list_entries: u32,
    },
    /// Second half of a split render: rasterizes `count` prepared splats.
    GaussianSplatRasterize {
        prep_off: usize,
        prep_len: usize,
        meta_off: usize,
        meta_len: usize,
        dst_off: usize,
        dst_len: usize,
        count: usize,
        width: u32,
        height: u32,
        tile_size: u32,
        alpha_cutoff: f32,
        max_splat_steps: u32,
        transmittance_threshold: f32,
        max_list_entries: u32,
    },
    /// 1-D FFT (or inverse) over `outer` rows of `n_complex` points.
    Fft1d {
        src: usize,
        dst: usize,
        outer: u32,
        n_complex: u32,
        inverse: bool,
        norm_tag: u32,
        dtype: rlx_ir::DType,
    },
    /// One butterfly stage of a staged FFT, with explicit twiddle and
    /// bit-reversal tables.
    FftButterflyStage {
        state_src: usize,
        state_dst: usize,
        gate_src: usize,
        rev_src: usize,
        tw_re_src: usize,
        tw_im_src: usize,
        batch: u32,
        n_fft: u32,
        stage: u32,
    },
    /// Log-mel projection of a spectrum through `n_mels` filters.
    LogMel {
        spec: usize,
        filters: usize,
        dst: usize,
        outer: u32,
        n_fft: u32,
        n_bins: u32,
        n_mels: u32,
    },
    /// Gradient of `LogMel` with respect to the spectrum.
    LogMelBackward {
        spec: usize,
        filters: usize,
        dy: usize,
        dst: usize,
        outer: u32,
        n_fft: u32,
        n_bins: u32,
        n_mels: u32,
    },
    /// Top-`k` peaks of a Welch-averaged spectrum over `n_segments` segments.
    WelchPeaks {
        spec: usize,
        dst: usize,
        welch_batch: u32,
        n_fft: u32,
        n_segments: u32,
        k: u32,
    },
}
1867
/// An ordered list of [`Thunk`]s together with the state shared by the
/// whole schedule.
#[derive(Clone)]
pub struct ThunkSchedule {
    /// The steps of this schedule (presumably executed front to back).
    pub thunks: Vec<Thunk>,
    /// NOTE(review): per-expert residency flags for MoE layers. Confirm what
    /// "resident" means to the executor.
    pub moe_resident: Option<std::sync::Arc<[bool]>>,
    /// Per-layer variant of `moe_resident` (one flag slice per layer,
    /// presumably).
    pub moe_resident_layers: Option<std::sync::Arc<Vec<std::sync::Arc<[bool]>>>>,
    /// Optional sink for MoE top-k routing captures
    /// (see `crate::moe_topk_capture`).
    pub moe_topk_capture: Option<std::sync::Arc<crate::moe_topk_capture::MoeTopkCapture>>,
    /// Attention-mask tuning knob. NOTE(review): the exact meaning (e.g. the
    /// value below which a mask entry counts as masked) is not visible here.
    pub mask_threshold: f32,
    /// Value used for masked attention logits (presumably a large negative).
    pub mask_neg_inf: f32,
    /// Attention-score skip threshold. NOTE(review): semantics not visible here.
    pub score_skip: f32,
    /// Pre-compiled closures, each taking a raw byte pointer (presumably the
    /// arena base). Cleared by [`ThunkSchedule::strip_nops`].
    pub compiled_fns: Vec<Arc<dyn Fn(*mut u8) + Send + Sync>>,
    /// RNG configuration shared behind a lock, so clones of the schedule see
    /// the same options.
    pub rng: Arc<std::sync::RwLock<rlx_ir::RngOptions>>,
}
1892
1893impl ThunkSchedule {
1894 pub fn strip_nops(&mut self) {
1895 self.thunks.retain(|t| !matches!(t, Thunk::Nop));
1896 self.compiled_fns.clear();
1899 }
1900}
1901
1902fn node_offset(arena: &Arena, id: NodeId) -> usize {
1904 if arena.has_buffer(id) {
1905 arena.byte_offset(id)
1906 } else {
1907 usize::MAX
1908 }
1909}
1910
1911fn thunk_read_offsets(t: &Thunk) -> Vec<usize> {
1917 match t {
1918 Thunk::Sgemm { a, b, .. } => vec![*a, *b],
1919 Thunk::DenseSolveF64 { a, b, .. } => vec![*a, *b],
1920 Thunk::DenseSolveF32 { a, b, .. } => vec![*a, *b],
1921 Thunk::BatchedDenseSolveF64 { a, b, .. } => vec![*a, *b],
1922 Thunk::BatchedDgemmF64 { a, b, .. } => vec![*a, *b],
1923 Thunk::BatchedSgemm { a, b, .. } => vec![*a, *b],
1924 Thunk::FusedMmBiasAct { a, w, bias, .. } => vec![*a, *w, *bias],
1925 Thunk::BiasAdd { src, bias, .. } => vec![*src, *bias],
1926 Thunk::BinaryFull { lhs, rhs, .. } => vec![*lhs, *rhs],
1927 Thunk::BinaryFullF64 { lhs, rhs, .. } => vec![*lhs, *rhs],
1928 Thunk::BinaryFullC64 { lhs, rhs, .. } => vec![*lhs, *rhs],
1929 Thunk::ComplexNormSqF32 { src, .. } => vec![*src],
1930 Thunk::ComplexNormSqBackwardF32 { z, g, .. } => vec![*z, *g],
1931 Thunk::ConjugateC64 { src, .. } => vec![*src],
1932 Thunk::Scan {
1933 outer_init_off,
1934 xs_inputs,
1935 ..
1936 } => {
1937 let mut v = vec![*outer_init_off];
1938 for (_, outer_xs_off, _) in xs_inputs.iter() {
1939 v.push(*outer_xs_off);
1940 }
1941 v
1942 }
1943 Thunk::ScanBackward {
1944 outer_init_off,
1945 outer_traj_off,
1946 outer_upstream_off,
1947 outer_xs_offs,
1948 ..
1949 } => {
1950 let mut v = vec![*outer_init_off, *outer_traj_off, *outer_upstream_off];
1951 for (off, _) in outer_xs_offs.iter() {
1952 v.push(*off);
1953 }
1954 v
1955 }
1956 Thunk::ScanBackwardXs {
1957 outer_init_off,
1958 outer_traj_off,
1959 outer_upstream_off,
1960 outer_xs_offs,
1961 ..
1962 } => {
1963 let mut v = vec![*outer_init_off, *outer_traj_off, *outer_upstream_off];
1964 for (off, _) in outer_xs_offs.iter() {
1965 v.push(*off);
1966 }
1967 v
1968 }
1969 Thunk::CustomFn { inputs, .. } => {
1970 inputs.iter().map(|(_, outer_off, _)| *outer_off).collect()
1971 }
1972 Thunk::ActivationInPlace { data, .. } => vec![*data],
1973 Thunk::LayerNorm { src, g, b, .. } | Thunk::GroupNorm { src, g, b, .. } => {
1974 vec![*src, *g, *b]
1975 }
1976 Thunk::BatchNormInference {
1977 src,
1978 g,
1979 b,
1980 mean,
1981 var,
1982 ..
1983 } => vec![*src, *g, *b, *mean, *var],
1984 Thunk::ResizeNearest2x { src, .. } => vec![*src],
1985 Thunk::AxialRope2d { src, .. } => vec![*src],
1986 Thunk::FusedResidualLN {
1987 x, res, bias, g, b, ..
1988 } => vec![*x, *res, *bias, *g, *b],
1989 Thunk::FusedResidualRmsNorm {
1990 x, res, bias, g, b, ..
1991 } => vec![*x, *res, *bias, *g, *b],
1992 Thunk::RmsNorm { src, g, b, .. } => vec![*src, *g, *b],
1993 Thunk::Softmax { data, .. } => vec![*data],
1994 Thunk::Cumsum { src, .. } => vec![*src],
1995 Thunk::Sample { logits, .. } => vec![*logits],
1996 Thunk::RngNormal { .. } | Thunk::RngUniform { .. } => vec![],
1997 Thunk::LoraMatMul { x, w, a, b, .. } => vec![*x, *w, *a, *b],
1998 Thunk::DequantMatMul {
1999 x, w_q, scale, zp, ..
2000 } => vec![*x, *w_q, *scale, *zp],
2001 Thunk::DequantMatMulGguf { x, w_q, .. } => vec![*x, *w_q],
2002 Thunk::DequantMatMulInt4 {
2003 x, w_q, scale, zp, ..
2004 } => vec![*x, *w_q, *scale, *zp],
2005 Thunk::DequantMatMulFp8 { x, w_q, scale, .. } => vec![*x, *w_q, *scale],
2006 Thunk::DequantMatMulNvfp4 {
2007 x,
2008 w_q,
2009 scale,
2010 global_scale,
2011 ..
2012 } => vec![*x, *w_q, *scale, *global_scale],
2013 Thunk::Conv2D1x1 { src, weight, .. } => vec![*src, *weight],
2014 Thunk::SelectiveScan {
2015 x, delta, a, b, c, ..
2016 } => vec![*x, *delta, *a, *b, *c],
2017 Thunk::GatedDeltaNet {
2018 q,
2019 k,
2020 v,
2021 g,
2022 beta,
2023 state,
2024 ..
2025 } => {
2026 let mut v = vec![*q, *k, *v, *g, *beta];
2027 if *state != 0 {
2028 v.push(*state);
2029 }
2030 v
2031 }
2032 Thunk::Attention { q, k, v, mask, .. } => vec![*q, *k, *v, *mask],
2033 Thunk::AttentionBackward {
2034 q, k, v, dy, mask, ..
2035 } => {
2036 let mut v = vec![*q, *k, *v, *dy];
2037 if *mask != 0 {
2038 v.push(*mask);
2039 }
2040 v
2041 }
2042 Thunk::Rope { src, cos, sin, .. } => vec![*src, *cos, *sin],
2043 Thunk::FusedAttnBlock {
2044 hidden,
2045 qkv_w,
2046 out_w,
2047 mask,
2048 qkv_b,
2049 out_b,
2050 cos,
2051 sin,
2052 ..
2053 } => vec![*hidden, *qkv_w, *out_w, *mask, *qkv_b, *out_b, *cos, *sin],
2054 Thunk::FusedSwiGLU { src, .. } => vec![*src],
2055 Thunk::Concat { inputs, .. } => inputs.iter().map(|(off, _, _)| *off).collect(),
2056 Thunk::ConcatF64 { inputs, .. } => inputs.iter().map(|(off, _, _)| *off).collect(),
2057 Thunk::Narrow { src, .. } => vec![*src],
2058 Thunk::Copy { src, .. } => vec![*src],
2059 Thunk::Gather { table, idx, .. } => vec![*table, *idx],
2060 _ => vec![],
2064 }
2065}
2066
/// Reference int8 dequantize-and-multiply: `out = x · dequant(w)`.
///
/// Layouts, all row-major as implied by the indexing:
/// - `x`: `m × k` activations.
/// - `w_bytes`: `k × n` quantized weights.
/// - `scales` / `zps`: one entry per `(block, column)`, indexed as
///   `block * n + j`, where `block = p / block_size` groups consecutive
///   positions along `k`.
/// - `out`: `m × n`; every element is overwritten.
///
/// Each weight dequantizes as `(q - z) * s`, with `z = 0` unless `asym`.
/// `zps` is only read when `asym` is true.
///
/// # Panics
/// Panics if `block_size == 0` while `m`, `n` and `k` are all non-zero, or if
/// any slice is shorter than its layout requires.
#[allow(clippy::too_many_arguments)]
pub fn dequant_matmul_int8(
    x: &[f32],
    w_bytes: &[i8],
    scales: &[f32],
    zps: &[f32],
    out: &mut [f32],
    m: usize,
    k: usize,
    n: usize,
    block_size: usize,
    asym: bool,
) {
    for i in 0..m {
        for j in 0..n {
            let mut acc = 0f32;
            for p in 0..k {
                // Scale / zero-point slot for the quantization block holding row `p`.
                let param = (p / block_size) * n + j;
                let s = scales[param];
                let z = if asym { zps[param] } else { 0.0 };
                let q = f32::from(w_bytes[p * n + j]);
                acc += x[i * k + p] * ((q - z) * s);
            }
            out[i * n + j] = acc;
        }
    }
}
2110
/// Reference int4 dequantize-and-multiply: `out = x · dequant(w)`.
///
/// Weights are a packed `k × n` row-major matrix, two 4-bit codes per byte.
/// An even flat index `p * n + j` uses the low nibble; an odd index uses the
/// high nibble.
///
/// Codes are read as unsigned `0..=15` and dequantized as `(q - z) * s`, with
/// the per-`(block, column)` scale/zero-point at `(p / block_size) * n + j`.
/// `z = 0` unless `asym`, and `zps` is only read when `asym` is true. `out`
/// (`m × n`) is fully overwritten.
#[allow(clippy::too_many_arguments)]
fn dequant_matmul_int4(
    x: &[f32],
    w_bytes: &[u8],
    scales: &[f32],
    zps: &[f32],
    out: &mut [f32],
    m: usize,
    k: usize,
    n: usize,
    block_size: usize,
    asym: bool,
) {
    for row in 0..m {
        let x_base = row * k;
        for col in 0..n {
            let mut sum = 0f32;
            for depth in 0..k {
                let param_idx = (depth / block_size) * n + col;
                let scale = scales[param_idx];
                let zero = match asym {
                    true => zps[param_idx],
                    false => 0.0,
                };
                let flat = depth * n + col;
                let packed = w_bytes[flat / 2];
                let nibble = if flat % 2 == 0 { packed & 0x0F } else { packed >> 4 };
                sum += x[x_base + depth] * ((f32::from(nibble) - zero) * scale);
            }
            out[row * n + col] = sum;
        }
    }
}
2144
/// Decodes one FP8 E4M3 byte into an `f32`.
///
/// The layout is 1 sign bit, 4 exponent bits (bias 7) and 3 mantissa bits, as
/// in the OCP FP8 "E4M3" (a.k.a. `float8_e4m3fn`) format. Unlike IEEE-style
/// formats, E4M3 has **no infinities**:
/// - Only `S.1111.111` encodes NaN.
/// - `S.1111.000`..=`S.1111.110` are ordinary normals (256 ..= 448).
///
/// Zero (either sign) decodes to `+0.0`.
fn fp8_e4m3_to_f32(b: u8) -> f32 {
    let sign = if b & 0x80 != 0 { -1.0 } else { 1.0 };
    let exp = (b >> 3) & 0x0F;
    let mant = b & 0x07;
    if exp == 0 {
        if mant == 0 {
            return 0.0;
        }
        // Subnormal: (mant / 8) * 2^(1 - 7) == mant * 2^-9.
        return sign * (mant as f32) * 2f32.powi(-9);
    }
    if exp == 0x0F && mant == 0x07 {
        return f32::NAN;
    }
    sign * (1.0 + mant as f32 / 8.0) * 2f32.powi(exp as i32 - 7)
}
2164
/// Decodes one FP8 E5M2 byte into an `f32`.
///
/// The layout is 1 sign bit, 5 exponent bits (bias 15) and 2 mantissa bits,
/// with IEEE-style specials:
/// - exponent 31 with mantissa 0 is ±∞;
/// - exponent 31 with any other mantissa is NaN;
/// - exponent 0 gives zero or subnormals.
///
/// Zero (either sign) decodes to `+0.0`.
fn fp8_e5m2_to_f32(b: u8) -> f32 {
    let negative = b & 0x80 != 0;
    let exponent = (b >> 2) & 0x1F;
    let mantissa = b & 0x03;
    let magnitude = match (exponent, mantissa) {
        (0, 0) => return 0.0,
        // Subnormal: (m / 4) * 2^(1 - 15) == m * 2^-16.
        (0, m) => f32::from(m) * 2f32.powi(-16),
        (0x1F, 0) => f32::INFINITY,
        (0x1F, _) => return f32::NAN,
        (e, m) => (1.0 + f32::from(m) / 4.0) * 2f32.powi(i32::from(e) - 15),
    };
    if negative {
        -magnitude
    } else {
        magnitude
    }
}
2184
2185#[allow(clippy::too_many_arguments)]
2186fn dequant_matmul_fp8(
2187 x: &[f32],
2188 w_bytes: &[u8],
2189 scales: &[f32],
2190 out: &mut [f32],
2191 m: usize,
2192 k: usize,
2193 n: usize,
2194 e5m2: bool,
2195) {
2196 let dequant = if e5m2 {
2197 fp8_e5m2_to_f32
2198 } else {
2199 fp8_e4m3_to_f32
2200 };
2201 for i in 0..m {
2202 for j in 0..n {
2203 let mut acc = 0f32;
2204 for p in 0..k {
2205 let w = dequant(w_bytes[p * n + j]);
2206 let s = scales.get(j).copied().unwrap_or(1.0);
2207 acc += x[i * k + p] * w * s;
2208 }
2209 out[i * n + j] = acc;
2210 }
2211 }
2212}
2213
2214#[allow(clippy::too_many_arguments)]
2215pub fn dequant_matmul_nvfp4(
2216 x: &[f32],
2217 w_bytes: &[u8],
2218 scale_bytes: &[u8],
2219 global_scale: f32,
2220 out: &mut [f32],
2221 m: usize,
2222 k: usize,
2223 n: usize,
2224) {
2225 use rlx_ir::{NVFP4_GROUP_SIZE, fp4_e2m1_to_f32, fp8_e4m3_scale_to_f32};
2226 let gs = NVFP4_GROUP_SIZE;
2227 for i in 0..m {
2228 for j in 0..n {
2229 let mut acc = 0f32;
2230 for p in 0..k {
2231 let byte_idx = (p * n + j) / 2;
2232 let nibble = if (p * n + j) & 1 == 0 {
2233 w_bytes[byte_idx] & 0x0F
2234 } else {
2235 w_bytes[byte_idx] >> 4
2236 };
2237 let block = p / gs;
2238 let scale = fp8_e4m3_scale_to_f32(scale_bytes[block * n + j]);
2239 let w = fp4_e2m1_to_f32(nibble) * scale * global_scale;
2240 acc += x[i * k + p] * w;
2241 }
2242 out[i * n + j] = acc;
2243 }
2244 }
2245}
2246
2247fn sample_row(
2256 logits: &[f32],
2257 top_k: usize,
2258 top_p: f32,
2259 temperature: f32,
2260 rng: &mut rlx_ir::Philox4x32,
2261) -> usize {
2262 let v = logits.len();
2263 if v == 0 {
2264 return 0;
2265 }
2266 let temp = temperature.max(1e-6);
2267 let mut scaled: Vec<f32> = logits.iter().map(|&x| x / temp).collect();
2269
2270 if top_k > 0 && top_k < v {
2272 let mut indexed: Vec<(usize, f32)> = scaled.iter().copied().enumerate().collect();
2274 indexed.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap());
2277 let cutoff = indexed[top_k - 1].1;
2278 for x in scaled.iter_mut() {
2279 if *x < cutoff {
2280 *x = f32::NEG_INFINITY;
2281 }
2282 }
2283 }
2284
2285 let mut max_l = f32::NEG_INFINITY;
2287 for &x in &scaled {
2288 if x > max_l {
2289 max_l = x;
2290 }
2291 }
2292 let mut sum = 0.0f32;
2293 for x in scaled.iter_mut() {
2294 *x = (*x - max_l).exp();
2295 sum += *x;
2296 }
2297 let inv = 1.0 / sum.max(f32::MIN_POSITIVE);
2298 for x in scaled.iter_mut() {
2299 *x *= inv;
2300 }
2301
2302 if top_p < 1.0 {
2305 let mut indexed: Vec<(usize, f32)> = scaled.iter().copied().enumerate().collect();
2306 indexed.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap());
2307 let mut cum = 0.0f32;
2308 let mut keep = vec![false; v];
2309 for (idx, p) in indexed.iter() {
2310 keep[*idx] = true;
2311 cum += *p;
2312 if cum >= top_p {
2313 break;
2314 }
2315 }
2316 let mut new_sum = 0.0f32;
2317 for (i, x) in scaled.iter_mut().enumerate() {
2318 if !keep[i] {
2319 *x = 0.0;
2320 }
2321 new_sum += *x;
2322 }
2323 let inv = 1.0 / new_sum.max(f32::MIN_POSITIVE);
2324 for x in scaled.iter_mut() {
2325 *x *= inv;
2326 }
2327 }
2328
2329 let r = rng.next_f32();
2331 let mut acc = 0.0f32;
2332 for (i, &p) in scaled.iter().enumerate() {
2333 acc += p;
2334 if r <= acc {
2335 return i;
2336 }
2337 }
2338 v - 1 }
2340
2341#[inline]
2345fn apply_synthetic_mask(
2346 scores: &mut [f32],
2347 q_seq: usize,
2348 k_seq: usize,
2349 kind: rlx_ir::op::MaskKind,
2350) {
2351 let neg = crate::config::RuntimeConfig::global().attn_mask_neg_inf;
2352 let q_offset = k_seq.saturating_sub(q_seq);
2353 match kind {
2354 rlx_ir::op::MaskKind::None | rlx_ir::op::MaskKind::Custom | rlx_ir::op::MaskKind::Bias => {}
2355 rlx_ir::op::MaskKind::Causal => {
2356 for qi in 0..q_seq {
2357 let abs_q = q_offset + qi;
2358 for ki in (abs_q + 1)..k_seq {
2359 scores[qi * k_seq + ki] = neg;
2360 }
2361 }
2362 }
2363 rlx_ir::op::MaskKind::SlidingWindow(w) => {
2364 for qi in 0..q_seq {
2365 let abs_q = q_offset + qi;
2366 let lo = abs_q.saturating_sub(w);
2367 for ki in 0..k_seq {
2368 if ki < lo || ki > abs_q {
2369 scores[qi * k_seq + ki] = neg;
2370 }
2371 }
2372 }
2373 }
2374 }
2375}
2376
2377fn conv_nchw_dims(shape: &Shape) -> (u32, u32, u32, u32) {
2379 match shape.rank() {
2380 3 => (
2381 shape.dim(0).unwrap_static() as u32,
2382 shape.dim(1).unwrap_static() as u32,
2383 1,
2384 shape.dim(2).unwrap_static() as u32,
2385 ),
2386 4 => (
2387 shape.dim(0).unwrap_static() as u32,
2388 shape.dim(1).unwrap_static() as u32,
2389 shape.dim(2).unwrap_static() as u32,
2390 shape.dim(3).unwrap_static() as u32,
2391 ),
2392 r => panic!("conv_nchw_dims: expected rank 3 or 4, got {r}"),
2393 }
2394}
2395
2396pub fn compile_thunks(graph: &Graph, arena: &Arena) -> ThunkSchedule {
2398 compile_thunks_with_rng(graph, arena, rlx_ir::RngOptions::default())
2399}
2400
2401pub fn compile_thunks_with_rng(
2403 graph: &Graph,
2404 arena: &Arena,
2405 rng: rlx_ir::RngOptions,
2406) -> ThunkSchedule {
2407 let rng_shared = Arc::new(std::sync::RwLock::new(rng));
2408 let mut thunks = Vec::with_capacity(graph.len());
2409
2410 for node in graph.nodes() {
2411 if rlx_opt::is_pure_view(graph, node) {
2415 thunks.push(Thunk::Nop);
2416 continue;
2417 }
2418 let t = match &node.op {
2419 Op::Input { .. } | Op::Param { .. } | Op::Constant { .. } => Thunk::Nop,
2420
2421 Op::FusedMatMulBiasAct { activation } => {
2422 let shape = &node.shape;
2423 let n = shape.dim(shape.rank() - 1).unwrap_static();
2424 let total = shape.num_elements().unwrap();
2425 let m = total / n;
2426 let a_len = get_len(graph, node.inputs[0]);
2427 let k = a_len / m;
2428 Thunk::FusedMmBiasAct {
2429 a: node_offset(arena, node.inputs[0]),
2430 w: node_offset(arena, node.inputs[1]),
2431 bias: node_offset(arena, node.inputs[2]),
2432 c: node_offset(arena, node.id),
2433 m: m as u32,
2434 k: k as u32,
2435 n: n as u32,
2436 act: *activation,
2437 }
2438 }
2439
2440 Op::FusedResidualLN { has_bias, eps } => {
2441 let h = node.shape.dim(node.shape.rank() - 1).unwrap_static();
2442 let total = node.shape.num_elements().unwrap();
2443 let rows = total / h;
2444 let (g_idx, b_idx) = if *has_bias { (3, 4) } else { (2, 3) };
2445 Thunk::FusedResidualLN {
2446 x: node_offset(arena, node.inputs[0]),
2447 res: node_offset(arena, node.inputs[1]),
2448 bias: if *has_bias {
2449 node_offset(arena, node.inputs[2])
2450 } else {
2451 0
2452 },
2453 g: node_offset(arena, node.inputs[g_idx]),
2454 b: node_offset(arena, node.inputs[b_idx]),
2455 out: node_offset(arena, node.id),
2456 rows: rows as u32,
2457 h: h as u32,
2458 eps: *eps,
2459 has_bias: *has_bias,
2460 }
2461 }
2462
2463 Op::FusedResidualRmsNorm { has_bias, eps } => {
2464 let h = node.shape.dim(node.shape.rank() - 1).unwrap_static();
2465 let total = node.shape.num_elements().unwrap();
2466 let rows = total / h;
2467 let (g_idx, b_idx) = if *has_bias { (3, 4) } else { (2, 3) };
2468 Thunk::FusedResidualRmsNorm {
2469 x: node_offset(arena, node.inputs[0]),
2470 res: node_offset(arena, node.inputs[1]),
2471 bias: if *has_bias {
2472 node_offset(arena, node.inputs[2])
2473 } else {
2474 0
2475 },
2476 g: node_offset(arena, node.inputs[g_idx]),
2477 b: node_offset(arena, node.inputs[b_idx]),
2478 out: node_offset(arena, node.id),
2479 rows: rows as u32,
2480 h: h as u32,
2481 eps: *eps,
2482 has_bias: *has_bias,
2483 }
2484 }
2485
2486 Op::MatMul => {
2487 let shape = &node.shape;
2488 let a_shape = &graph.node(node.inputs[0]).shape;
2489 let b_shape = &graph.node(node.inputs[1]).shape;
2490 let eff =
2493 rlx_ir::shape::matmul_shape(a_shape, b_shape).unwrap_or_else(|_| shape.clone());
2494 let rank = eff.rank().max(2);
2495 let n = eff.dim(rank - 1).unwrap_static();
2496 let k_dim = a_shape.dim(a_shape.rank().max(2) - 1).unwrap_static();
2497 if shape.dtype() == rlx_ir::DType::C64 {
2498 let both = a_shape.rank() >= 3 && b_shape.rank() >= 3;
2502 assert!(!both, "batched (both-operand) C64 matmul not yet supported");
2503 let m: usize = if a_shape.rank() >= 3 {
2504 (0..a_shape.rank() - 1)
2505 .map(|d| a_shape.dim(d).unwrap_static())
2506 .product()
2507 } else {
2508 a_shape.dim(a_shape.rank() - 2).unwrap_static()
2509 };
2510 Thunk::CgemmC64 {
2511 a: node_offset(arena, node.inputs[0]),
2512 b: node_offset(arena, node.inputs[1]),
2513 c: node_offset(arena, node.id),
2514 m: m as u32,
2515 k: k_dim as u32,
2516 n: n as u32,
2517 }
2518 } else {
2519 let both_batched = a_shape.rank() >= 3 && b_shape.rank() >= 3;
2522 let batched_3d =
2523 rank >= 3 && both_batched && a_shape.rank() + b_shape.rank() > 4;
2524 if batched_3d && shape.dtype() == rlx_ir::DType::F64 {
2525 let mut batch_prod = 1usize;
2526 for d in 0..rank - 2 {
2527 batch_prod *= eff.dim(d).unwrap_static();
2528 }
2529 let m_dim = eff.dim(rank - 2).unwrap_static();
2530 Thunk::BatchedDgemmF64 {
2531 a: node_offset(arena, node.inputs[0]),
2532 b: node_offset(arena, node.inputs[1]),
2533 c: node_offset(arena, node.id),
2534 batch: batch_prod as u32,
2535 m: m_dim as u32,
2536 k: k_dim as u32,
2537 n: n as u32,
2538 }
2539 } else if batched_3d && shape.dtype() == rlx_ir::DType::F32 {
2540 let mut batch_prod = 1usize;
2541 for d in 0..rank - 2 {
2542 batch_prod *= eff.dim(d).unwrap_static();
2543 }
2544 let m_dim = eff.dim(rank - 2).unwrap_static();
2545 Thunk::BatchedSgemm {
2546 a: node_offset(arena, node.inputs[0]),
2547 b: node_offset(arena, node.inputs[1]),
2548 c: node_offset(arena, node.id),
2549 batch: batch_prod as u32,
2550 m: m_dim as u32,
2551 k: k_dim as u32,
2552 n: n as u32,
2553 }
2554 } else {
2555 let m = if a_shape.rank() >= 3 && b_shape.rank() <= 2 {
2556 let mut m_prod = 1usize;
2557 for d in 0..a_shape.rank() - 1 {
2558 m_prod *= a_shape.dim(d).unwrap_static();
2559 }
2560 m_prod
2561 } else if a_shape.rank() >= 2 {
2562 a_shape.dim(a_shape.rank() - 2).unwrap_static()
2563 } else {
2564 eff.num_elements().unwrap_or(1) / n.max(1)
2565 };
2566 match shape.dtype() {
2567 rlx_ir::DType::F64 => Thunk::Dgemm {
2568 a: node_offset(arena, node.inputs[0]),
2569 b: node_offset(arena, node.inputs[1]),
2570 c: node_offset(arena, node.id),
2571 m: m as u32,
2572 k: k_dim as u32,
2573 n: n as u32,
2574 },
2575 _ => Thunk::Sgemm {
2576 a: node_offset(arena, node.inputs[0]),
2577 b: node_offset(arena, node.inputs[1]),
2578 c: node_offset(arena, node.id),
2579 m: m as u32,
2580 k: k_dim as u32,
2581 n: n as u32,
2582 },
2583 }
2584 }
2585 }
2586 }
2587
2588 Op::Binary(op) => {
2589 let lhs_len = get_len(graph, node.inputs[0]);
2590 let rhs_len = get_len(graph, node.inputs[1]);
2591 let out_len = node.shape.num_elements().unwrap();
2592 if node.shape.dtype() == rlx_ir::DType::C64 {
2593 match op {
2597 BinaryOp::Add | BinaryOp::Sub | BinaryOp::Mul | BinaryOp::Div => {}
2598 BinaryOp::Max | BinaryOp::Min | BinaryOp::Pow => panic!(
2599 "Op::Binary({op:?}) on DType::C64: complex \
2600 max/min/pow have no single natural definition \
2601 — caller should drop to 2N-real-block (see \
2602 spike-ac) and pick a convention there"
2603 ),
2604 }
2605 }
2606 let (out_dims_bcast, bcast_lhs_strides, bcast_rhs_strides) =
2610 if lhs_len == out_len && rhs_len == out_len {
2611 (Vec::new(), Vec::new(), Vec::new())
2612 } else {
2613 let lhs_dims = get_static_dims(graph, node.inputs[0]);
2614 let rhs_dims = get_static_dims(graph, node.inputs[1]);
2615 let out_dims_v = get_static_dims(graph, node.id);
2616 if lhs_dims.is_empty() || rhs_dims.is_empty() || out_dims_v.is_empty() {
2617 (Vec::new(), Vec::new(), Vec::new())
2622 } else {
2623 let ls = broadcast_strides(&lhs_dims, &out_dims_v);
2624 let rs = broadcast_strides(&rhs_dims, &out_dims_v);
2625 let od: Vec<u32> = out_dims_v.iter().map(|x| *x as u32).collect();
2626 (od, ls, rs)
2627 }
2628 };
2629 if node.shape.dtype() == rlx_ir::DType::C64 {
2630 Thunk::BinaryFullC64 {
2631 lhs: node_offset(arena, node.inputs[0]),
2632 rhs: node_offset(arena, node.inputs[1]),
2633 dst: node_offset(arena, node.id),
2634 len: out_len as u32,
2635 lhs_len: lhs_len as u32,
2636 rhs_len: rhs_len as u32,
2637 op: *op,
2638 out_dims_bcast,
2639 bcast_lhs_strides,
2640 bcast_rhs_strides,
2641 }
2642 } else if node.shape.dtype() == rlx_ir::DType::F64 {
2643 Thunk::BinaryFullF64 {
2646 lhs: node_offset(arena, node.inputs[0]),
2647 rhs: node_offset(arena, node.inputs[1]),
2648 dst: node_offset(arena, node.id),
2649 len: out_len as u32,
2650 lhs_len: lhs_len as u32,
2651 rhs_len: rhs_len as u32,
2652 op: *op,
2653 out_dims_bcast,
2654 bcast_lhs_strides,
2655 bcast_rhs_strides,
2656 }
2657 } else if matches!(op, BinaryOp::Add)
2658 && rhs_len < out_len
2659 && out_len % rhs_len == 0
2660 && is_trailing_bias_broadcast(
2661 graph.node(node.inputs[1]).shape.dims(),
2662 graph.node(node.id).shape.dims(),
2663 )
2664 {
2665 Thunk::BiasAdd {
2675 src: node_offset(arena, node.inputs[0]),
2676 bias: node_offset(arena, node.inputs[1]),
2677 dst: node_offset(arena, node.id),
2678 m: (out_len / rhs_len) as u32,
2679 n: rhs_len as u32,
2680 }
2681 } else {
2682 let lhs_len = get_len(graph, node.inputs[0]);
2683 Thunk::BinaryFull {
2684 lhs: node_offset(arena, node.inputs[0]),
2685 rhs: node_offset(arena, node.inputs[1]),
2686 dst: node_offset(arena, node.id),
2687 len: out_len as u32,
2688 lhs_len: lhs_len as u32,
2689 rhs_len: rhs_len as u32,
2690 op: *op,
2691 out_dims_bcast,
2692 bcast_lhs_strides,
2693 bcast_rhs_strides,
2694 elem_bytes: node.shape.dtype().size_bytes() as u8,
2695 }
2696 }
2697 }
2698
2699 Op::Activation(act) => {
2700 let len = node.shape.num_elements().unwrap();
2701 let in_off = node_offset(arena, node.inputs[0]);
2702 let out_off = node_offset(arena, node.id);
2703 if node.shape.dtype() == rlx_ir::DType::C64 {
2704 match act {
2709 Activation::Neg | Activation::Exp | Activation::Log | Activation::Sqrt => {}
2710 other => panic!(
2711 "Op::Activation({other:?}) on DType::C64: no \
2712 natural complex extension — supported on C64: \
2713 Neg, Exp, Log, Sqrt"
2714 ),
2715 }
2716 Thunk::ActivationC64 {
2717 src: in_off,
2718 dst: out_off,
2719 len: len as u32,
2720 kind: *act,
2721 }
2722 } else if node.shape.dtype() == rlx_ir::DType::F64 {
2723 Thunk::ActivationF64 {
2724 src: in_off,
2725 dst: out_off,
2726 len: len as u32,
2727 kind: *act,
2728 }
2729 } else if in_off == out_off {
2730 Thunk::ActivationInPlace {
2734 data: out_off,
2735 len: len as u32,
2736 act: *act,
2737 }
2738 } else {
2739 thunks.push(Thunk::Copy {
2743 src: in_off,
2744 dst: out_off,
2745 len: len as u32,
2746 });
2747 Thunk::ActivationInPlace {
2748 data: out_off,
2749 len: len as u32,
2750 act: *act,
2751 }
2752 }
2753 }
2754
2755 Op::Gather { axis } if *axis == 0 => {
2756 let table_shape = &graph.node(node.inputs[0]).shape;
2757 let table_total = table_shape.num_elements().unwrap();
2758 let trailing: usize = (1..table_shape.rank())
2759 .map(|i| table_shape.dim(i).unwrap_static())
2760 .product();
2761 let idx_len = get_len(graph, node.inputs[1]);
2762 let idx_i64 =
2763 u8::from(graph.node(node.inputs[1]).shape.dtype() == rlx_ir::DType::I64);
2764 let table_bytes = graph.node(node.inputs[0]).shape.dtype().size_bytes() as u8;
2765 Thunk::Gather {
2766 table: node_offset(arena, node.inputs[0]),
2767 table_len: table_total as u32,
2768 idx: node_offset(arena, node.inputs[1]),
2769 dst: node_offset(arena, node.id),
2770 num_idx: idx_len as u32,
2771 trailing: trailing as u32,
2772 idx_i64,
2773 table_bytes,
2774 }
2775 }
2776
2777 Op::Gather { axis } => {
2778 let table_shape = &graph.node(node.inputs[0]).shape;
2780 let rank = table_shape.rank();
2781 let outer: usize = (0..*axis)
2782 .map(|i| table_shape.dim(i).unwrap_static())
2783 .product::<usize>()
2784 .max(1);
2785 let trailing: usize = (*axis + 1..rank)
2786 .map(|i| table_shape.dim(i).unwrap_static())
2787 .product::<usize>()
2788 .max(1);
2789 let axis_dim = table_shape.dim(*axis).unwrap_static();
2790 let idx_len = get_len(graph, node.inputs[1]);
2791 let idx_i64 =
2792 u8::from(graph.node(node.inputs[1]).shape.dtype() == rlx_ir::DType::I64);
2793 let table_bytes = graph.node(node.inputs[0]).shape.dtype().size_bytes() as u8;
2794 Thunk::GatherAxis {
2795 table: node_offset(arena, node.inputs[0]),
2796 idx: node_offset(arena, node.inputs[1]),
2797 dst: node_offset(arena, node.id),
2798 outer: outer as u32,
2799 axis_dim: axis_dim as u32,
2800 num_idx: idx_len as u32,
2801 trailing: trailing as u32,
2802 idx_i64,
2803 table_bytes,
2804 }
2805 }
2806
2807 Op::Narrow { axis, start, len } => {
2808 let in_shape = &graph.node(node.inputs[0]).shape;
2809 let elem_bytes = in_shape.dtype().size_bytes() as u8;
2810 let rank = in_shape.rank();
2811 let outer: usize = (0..*axis)
2812 .map(|i| in_shape.dim(i).unwrap_static())
2813 .product::<usize>()
2814 .max(1);
2815 let inner: usize = (*axis + 1..rank)
2816 .map(|i| in_shape.dim(i).unwrap_static())
2817 .product::<usize>()
2818 .max(1);
2819 let in_axis = in_shape.dim(*axis).unwrap_static();
2820 let src_byte_offset =
2821 node_offset(arena, node.inputs[0]) + start * inner * elem_bytes as usize;
2822 Thunk::Narrow {
2823 src: src_byte_offset,
2824 dst: node_offset(arena, node.id),
2825 outer: outer as u32,
2826 src_stride: (in_axis * inner) as u32, dst_stride: (*len * inner) as u32, inner: (*len * inner) as u32, elem_bytes,
2830 }
2831 }
2832
2833 Op::Reshape { .. } | Op::StopGradient => {
2834 let len = node.shape.num_elements().unwrap();
2836 let src = node_offset(arena, node.inputs[0]);
2837 let dst = node_offset(arena, node.id);
2838 match node.shape.dtype() {
2839 rlx_ir::DType::F64 => Thunk::CopyF64 {
2840 src,
2841 dst,
2842 len: len as u32,
2843 },
2844 rlx_ir::DType::I64 => Thunk::CopyI64 {
2845 src,
2846 dst,
2847 len: len as u32,
2848 },
2849 _ => Thunk::Copy {
2850 src,
2851 dst,
2852 len: len as u32,
2853 },
2854 }
2855 }
2856
2857 Op::Cast { to } => {
2858 let in_node = graph.node(node.inputs[0]);
2859 let in_dtype = in_node.shape.dtype();
2860 let out_dtype = *to;
2861 let len = node.shape.num_elements().unwrap();
2862 let src = node_offset(arena, node.inputs[0]);
2863 let dst = node_offset(arena, node.id);
2864 if in_dtype == rlx_ir::DType::F32 && out_dtype == rlx_ir::DType::I64 {
2865 Thunk::CastF32ToI64 {
2866 src,
2867 dst,
2868 len: len as u32,
2869 }
2870 } else if in_dtype == rlx_ir::DType::F32 && out_dtype == rlx_ir::DType::F64 {
2871 Thunk::CastF32ToF64 {
2872 src,
2873 dst,
2874 len: len as u32,
2875 }
2876 } else if in_dtype == rlx_ir::DType::F32 && out_dtype == rlx_ir::DType::I32 {
2877 Thunk::CastF32ToI32 {
2878 src,
2879 dst,
2880 len: len as u32,
2881 }
2882 } else if in_dtype == rlx_ir::DType::I64 && out_dtype == rlx_ir::DType::F32 {
2883 Thunk::CastI64ToF32 {
2884 src,
2885 dst,
2886 len: len as u32,
2887 }
2888 } else if in_dtype == rlx_ir::DType::Bool && out_dtype == rlx_ir::DType::I32 {
2889 Thunk::CastBoolToI32 {
2890 src,
2891 dst,
2892 len: len as u32,
2893 }
2894 } else if in_dtype == rlx_ir::DType::Bool && out_dtype == rlx_ir::DType::F32 {
2895 Thunk::CastBoolToF32 {
2898 src,
2899 dst,
2900 len: len as u32,
2901 }
2902 } else if in_dtype == rlx_ir::DType::I32 && out_dtype == rlx_ir::DType::F32 {
2903 Thunk::CastI32ToF32 {
2904 src,
2905 dst,
2906 len: len as u32,
2907 }
2908 } else if in_dtype == out_dtype {
2909 match out_dtype {
2910 rlx_ir::DType::F64 => Thunk::CopyF64 {
2911 src,
2912 dst,
2913 len: len as u32,
2914 },
2915 rlx_ir::DType::I64 => Thunk::CopyI64 {
2916 src,
2917 dst,
2918 len: len as u32,
2919 },
2920 _ => Thunk::Copy {
2921 src,
2922 dst,
2923 len: len as u32,
2924 },
2925 }
2926 } else {
2927 Thunk::Copy {
2928 src,
2929 dst,
2930 len: len as u32,
2931 }
2932 }
2933 }
2934
2935 Op::Quantize {
2936 axis,
2937 scales,
2938 zero_points,
2939 } => {
2940 let (chan_axis, chan_dim, inner) = quant_layout(&node.shape, *axis);
2941 Thunk::Quantize {
2942 x: node_offset(arena, node.inputs[0]),
2943 q: node_offset(arena, node.id),
2944 len: node.shape.num_elements().unwrap() as u32,
2945 chan_axis: chan_axis as u32,
2946 chan_dim: chan_dim as u32,
2947 inner: inner as u32,
2948 scales: scales.clone(),
2949 zero_points: zero_points.clone(),
2950 }
2951 }
2952
2953 Op::FakeQuantize {
2954 bits,
2955 axis,
2956 ste,
2957 scale_mode,
2958 } => {
2959 let (chan_axis, chan_dim, inner) = quant_layout(&node.shape, *axis);
2960 let state_off = match scale_mode {
2961 rlx_ir::op::ScaleMode::PerBatch => None,
2962 rlx_ir::op::ScaleMode::EMA { .. } | rlx_ir::op::ScaleMode::Fixed => {
2963 debug_assert_eq!(
2965 node.inputs.len(),
2966 2,
2967 "EMA/Fixed FakeQuantize needs a state input"
2968 );
2969 Some(node_offset(arena, node.inputs[1]))
2970 }
2971 };
2972 Thunk::FakeQuantize {
2973 x: node_offset(arena, node.inputs[0]),
2974 out: node_offset(arena, node.id),
2975 len: node.shape.num_elements().unwrap() as u32,
2976 chan_axis: chan_axis as u32,
2977 chan_dim: chan_dim as u32,
2978 inner: inner as u32,
2979 bits: *bits,
2980 ste: *ste,
2981 scale_mode: *scale_mode,
2982 state_off,
2983 }
2984 }
2985
2986 Op::FakeQuantizeLSQ { bits, axis } => {
2987 let (chan_axis, chan_dim, inner) = quant_layout(&node.shape, *axis);
2988 Thunk::FakeQuantizeLSQ {
2989 x: node_offset(arena, node.inputs[0]),
2990 scale_off: node_offset(arena, node.inputs[1]),
2991 out: node_offset(arena, node.id),
2992 len: node.shape.num_elements().unwrap() as u32,
2993 chan_axis: chan_axis as u32,
2994 chan_dim: chan_dim as u32,
2995 inner: inner as u32,
2996 bits: *bits,
2997 }
2998 }
2999
3000 Op::FakeQuantizeLSQBackwardX { bits, axis } => {
3001 let (chan_axis, chan_dim, inner) = quant_layout(&node.shape, *axis);
3002 Thunk::FakeQuantizeLSQBackwardX {
3003 x: node_offset(arena, node.inputs[0]),
3004 scale_off: node_offset(arena, node.inputs[1]),
3005 dy: node_offset(arena, node.inputs[2]),
3006 dx: node_offset(arena, node.id),
3007 len: node.shape.num_elements().unwrap() as u32,
3008 chan_axis: chan_axis as u32,
3009 chan_dim: chan_dim as u32,
3010 inner: inner as u32,
3011 bits: *bits,
3012 }
3013 }
3014
3015 Op::FakeQuantizeLSQBackwardScale { bits, axis } => {
3016 let in_shape = &graph.node(node.inputs[0]).shape;
3019 let (chan_axis, chan_dim, inner) = quant_layout(in_shape, *axis);
3020 Thunk::FakeQuantizeLSQBackwardScale {
3021 x: node_offset(arena, node.inputs[0]),
3022 scale_off: node_offset(arena, node.inputs[1]),
3023 dy: node_offset(arena, node.inputs[2]),
3024 dscale: node_offset(arena, node.id),
3025 len: in_shape.num_elements().unwrap() as u32,
3026 chan_axis: chan_axis as u32,
3027 chan_dim: chan_dim as u32,
3028 inner: inner as u32,
3029 bits: *bits,
3030 }
3031 }
3032
3033 Op::FakeQuantizeBackward { bits, axis, ste } => {
3034 let (chan_axis, chan_dim, inner) = quant_layout(&node.shape, *axis);
3035 Thunk::FakeQuantizeBackward {
3036 x: node_offset(arena, node.inputs[0]),
3037 dy: node_offset(arena, node.inputs[1]),
3038 dx: node_offset(arena, node.id),
3039 len: node.shape.num_elements().unwrap() as u32,
3040 chan_axis: chan_axis as u32,
3041 chan_dim: chan_dim as u32,
3042 inner: inner as u32,
3043 bits: *bits,
3044 ste: *ste,
3045 }
3046 }
3047
3048 Op::Dequantize {
3049 axis,
3050 scales,
3051 zero_points,
3052 } => {
3053 let (chan_axis, chan_dim, inner) = quant_layout(&node.shape, *axis);
3054 Thunk::Dequantize {
3055 q: node_offset(arena, node.inputs[0]),
3056 x: node_offset(arena, node.id),
3057 len: node.shape.num_elements().unwrap() as u32,
3058 chan_axis: chan_axis as u32,
3059 chan_dim: chan_dim as u32,
3060 inner: inner as u32,
3061 scales: scales.clone(),
3062 zero_points: zero_points.clone(),
3063 }
3064 }
3065
3066 Op::Expand { .. } => {
3067 let in_shape = &graph.node(node.inputs[0]).shape;
3072 let out_shape = &node.shape;
3073 let in_rank = in_shape.rank();
3074 let out_rank = out_shape.rank();
3075 let pad = out_rank.saturating_sub(in_rank);
3077 let in_dims: Vec<usize> = (0..out_rank)
3078 .map(|i| {
3079 if i < pad {
3080 1
3081 } else {
3082 in_shape.dim(i - pad).unwrap_static()
3083 }
3084 })
3085 .collect();
3086 let mut in_strides_full = vec![1usize; out_rank];
3088 for d in (0..out_rank.saturating_sub(1)).rev() {
3089 in_strides_full[d] = in_strides_full[d + 1] * in_dims[d + 1];
3090 }
3091 let out_dims: Vec<u32> = (0..out_rank)
3092 .map(|i| out_shape.dim(i).unwrap_static() as u32)
3093 .collect();
3094 let in_strides: Vec<u32> = (0..out_rank)
3096 .map(|i| {
3097 if in_dims[i] == 1 && (out_dims[i] as usize) > 1 {
3098 0
3099 } else {
3100 in_strides_full[i] as u32
3101 }
3102 })
3103 .collect();
3104 let in_total = in_dims.iter().product::<usize>() as u32;
3105 let src = node_offset(arena, node.inputs[0]);
3106 let dst = node_offset(arena, node.id);
3107 let elem_bytes = node.shape.dtype().size_bytes() as u8;
3108 match node.shape.dtype() {
3109 rlx_ir::DType::F64 => Thunk::TransposeF64 {
3110 src,
3111 dst,
3112 in_total,
3113 out_dims,
3114 in_strides,
3115 },
3116 _ => Thunk::Transpose {
3117 src,
3118 dst,
3119 in_total,
3120 out_dims,
3121 in_strides,
3122 elem_bytes,
3123 },
3124 }
3125 }
3126
3127 Op::RmsNorm { eps, .. } => {
3128 let h = node.shape.dim(node.shape.rank() - 1).unwrap_static();
3129 let total = node.shape.num_elements().unwrap();
3130 Thunk::RmsNorm {
3131 src: node_offset(arena, node.inputs[0]),
3132 g: node_offset(arena, node.inputs[1]),
3133 b: node_offset(arena, node.inputs[2]),
3134 dst: node_offset(arena, node.id),
3135 rows: (total / h) as u32,
3136 h: h as u32,
3137 eps: *eps,
3138 }
3139 }
3140
3141 Op::LayerNorm { eps, .. } => {
3142 let h = node.shape.dim(node.shape.rank() - 1).unwrap_static();
3143 let total = node.shape.num_elements().unwrap();
3144 Thunk::LayerNorm {
3145 src: node_offset(arena, node.inputs[0]),
3146 g: node_offset(arena, node.inputs[1]),
3147 b: node_offset(arena, node.inputs[2]),
3148 dst: node_offset(arena, node.id),
3149 rows: (total / h) as u32,
3150 h: h as u32,
3151 eps: *eps,
3152 }
3153 }
3154
3155 Op::GroupNorm { num_groups, eps } => {
3156 let in_shape = &graph.node(node.inputs[0]).shape;
3157 let (n, c, h, w) = conv_nchw_dims(in_shape);
3158 Thunk::GroupNorm {
3159 src: node_offset(arena, node.inputs[0]),
3160 g: node_offset(arena, node.inputs[1]),
3161 b: node_offset(arena, node.inputs[2]),
3162 dst: node_offset(arena, node.id),
3163 n,
3164 c,
3165 h,
3166 w,
3167 num_groups: *num_groups as u32,
3168 eps: *eps,
3169 }
3170 }
3171
3172 Op::BatchNormInference { eps } => {
3173 let in_shape = &graph.node(node.inputs[0]).shape;
3174 let rank = in_shape.rank();
3175 let channels = in_shape.dim(rank - 1).unwrap_static();
3176 let total = in_shape.num_elements().unwrap_or(0);
3177 let count = (total / channels.max(1)) as u32;
3178 Thunk::BatchNormInference {
3179 src: node_offset(arena, node.inputs[0]),
3180 g: node_offset(arena, node.inputs[1]),
3181 b: node_offset(arena, node.inputs[2]),
3182 mean: node_offset(arena, node.inputs[3]),
3183 var: node_offset(arena, node.inputs[4]),
3184 dst: node_offset(arena, node.id),
3185 count,
3186 channels: channels as u32,
3187 eps: *eps,
3188 }
3189 }
3190
3191 Op::BatchNormInferenceBackwardInput { eps } => {
3192 let x_shape = &graph.node(node.inputs[0]).shape;
3193 let rank = x_shape.rank();
3194 let channels = x_shape.dim(rank - 1).unwrap_static();
3195 let total = x_shape.num_elements().unwrap_or(0);
3196 Thunk::BatchNormInferenceBackwardInput {
3197 x: node_offset(arena, node.inputs[0]),
3198 gamma: node_offset(arena, node.inputs[1]),
3199 mean: node_offset(arena, node.inputs[2]),
3200 var: node_offset(arena, node.inputs[3]),
3201 dy: node_offset(arena, node.inputs[4]),
3202 dx: node_offset(arena, node.id),
3203 count: (total / channels.max(1)) as u32,
3204 channels: channels as u32,
3205 eps: *eps,
3206 }
3207 }
3208
3209 Op::BatchNormInferenceBackwardGamma { eps } => {
3210 let x_shape = &graph.node(node.inputs[0]).shape;
3211 let rank = x_shape.rank();
3212 let channels = x_shape.dim(rank - 1).unwrap_static();
3213 let total = x_shape.num_elements().unwrap_or(0);
3214 let _gamma_shape = &graph.node(node.id).shape;
3215 Thunk::BatchNormInferenceBackwardGamma {
3216 x: node_offset(arena, node.inputs[0]),
3217 mean: node_offset(arena, node.inputs[1]),
3218 var: node_offset(arena, node.inputs[2]),
3219 dy: node_offset(arena, node.inputs[3]),
3220 dgamma: node_offset(arena, node.id),
3221 count: (total / channels.max(1)) as u32,
3222 channels: channels as u32,
3223 eps: *eps,
3224 }
3225 }
3226
3227 Op::BatchNormInferenceBackwardBeta => {
3228 let dy_shape = &graph.node(node.inputs[0]).shape;
3229 let rank = dy_shape.rank();
3230 let channels = dy_shape.dim(rank - 1).unwrap_static();
3231 let total = dy_shape.num_elements().unwrap_or(0);
3232 Thunk::BatchNormInferenceBackwardBeta {
3233 dy: node_offset(arena, node.inputs[0]),
3234 dbeta: node_offset(arena, node.id),
3235 count: (total / channels.max(1)) as u32,
3236 channels: channels as u32,
3237 }
3238 }
3239
3240 Op::LayerNorm2d { eps } => {
3241 let in_shape = &graph.node(node.inputs[0]).shape;
3242 let (n, c, h, w) = conv_nchw_dims(in_shape);
3243 Thunk::LayerNorm2d {
3244 src: node_offset(arena, node.inputs[0]),
3245 g: node_offset(arena, node.inputs[1]),
3246 b: node_offset(arena, node.inputs[2]),
3247 dst: node_offset(arena, node.id),
3248 n,
3249 c,
3250 h,
3251 w,
3252 eps: *eps,
3253 }
3254 }
3255
3256 Op::ConvTranspose2d {
3257 kernel_size,
3258 stride,
3259 padding,
3260 dilation,
3261 output_padding: _,
3262 groups,
3263 } => {
3264 let in_shape = &graph.node(node.inputs[0]).shape;
3265 let out_shape = &node.shape;
3266 let (n, c_in, h, w_in) = conv_nchw_dims(in_shape);
3267 let (_, c_out, h_out, w_out) = conv_nchw_dims(out_shape);
3268 Thunk::ConvTranspose2d {
3269 src: node_offset(arena, node.inputs[0]),
3270 weight: node_offset(arena, node.inputs[1]),
3271 dst: node_offset(arena, node.id),
3272 n,
3273 c_in,
3274 h,
3275 w_in,
3276 c_out,
3277 h_out,
3278 w_out,
3279 kh: kernel_size[0] as u32,
3280 kw: kernel_size[1] as u32,
3281 sh: stride.first().copied().unwrap_or(1) as u32,
3282 sw: stride.get(1).copied().unwrap_or(1) as u32,
3283 ph: padding.first().copied().unwrap_or(0) as u32,
3284 pw: padding.get(1).copied().unwrap_or(0) as u32,
3285 dh: dilation.first().copied().unwrap_or(1) as u32,
3286 dw: dilation.get(1).copied().unwrap_or(1) as u32,
3287 groups: *groups as u32,
3288 }
3289 }
3290
3291 Op::ResizeNearest2x => {
3292 let in_shape = &graph.node(node.inputs[0]).shape;
3293 let (n, c, h, w) = conv_nchw_dims(in_shape);
3294 Thunk::ResizeNearest2x {
3295 src: node_offset(arena, node.inputs[0]),
3296 dst: node_offset(arena, node.id),
3297 n,
3298 c,
3299 h,
3300 w,
3301 }
3302 }
3303
3304 Op::AxialRope2d {
3305 end_x,
3306 end_y,
3307 head_dim,
3308 num_heads,
3309 theta,
3310 repeat_factor,
3311 } => {
3312 let in_shape = &graph.node(node.inputs[0]).shape;
3313 let batch = in_shape.dim(0).unwrap_static() as u32;
3314 let seq = in_shape.dim(1).unwrap_static() as u32;
3315 let hidden = in_shape.dim(2).unwrap_static() as u32;
3316 Thunk::AxialRope2d {
3317 src: node_offset(arena, node.inputs[0]),
3318 dst: node_offset(arena, node.id),
3319 batch,
3320 seq,
3321 hidden,
3322 end_x: *end_x as u32,
3323 end_y: *end_y as u32,
3324 head_dim: *head_dim as u32,
3325 num_heads: *num_heads as u32,
3326 theta: *theta,
3327 repeat_factor: *repeat_factor as u32,
3328 }
3329 }
3330
3331 Op::Softmax { axis } => {
3332 let rank = node.shape.rank();
3333 let ax = if *axis < 0 {
3334 (rank as i32 + axis) as usize
3335 } else {
3336 *axis as usize
3337 };
3338 let cols = node.shape.dim(ax).unwrap_static();
3339 let total = node.shape.num_elements().unwrap();
3340 let in_off = node_offset(arena, node.inputs[0]);
3341 let out_off = node_offset(arena, node.id);
3342 if in_off != out_off {
3348 thunks.push(Thunk::Copy {
3349 src: in_off,
3350 dst: out_off,
3351 len: total as u32,
3352 });
3353 }
3354 Thunk::Softmax {
3355 data: out_off,
3356 rows: (total / cols) as u32,
3357 cols: cols as u32,
3358 }
3359 }
3360
3361 Op::SelectiveScan { state_size } => {
3362 let in_shape = &graph.node(node.inputs[0]).shape;
3363 let (batch, seq, hidden) = (
3364 in_shape.dim(0).unwrap_static(),
3365 in_shape.dim(1).unwrap_static(),
3366 in_shape.dim(2).unwrap_static(),
3367 );
3368 Thunk::SelectiveScan {
3369 x: node_offset(arena, node.inputs[0]),
3370 delta: node_offset(arena, node.inputs[1]),
3371 a: node_offset(arena, node.inputs[2]),
3372 b: node_offset(arena, node.inputs[3]),
3373 c: node_offset(arena, node.inputs[4]),
3374 dst: node_offset(arena, node.id),
3375 batch: batch as u32,
3376 seq: seq as u32,
3377 hidden: hidden as u32,
3378 state_size: *state_size as u32,
3379 }
3380 }
3381
3382 Op::GatedDeltaNet {
3383 state_size,
3384 carry_state,
3385 } => {
3386 let q_shape = &graph.node(node.inputs[0]).shape;
3387 let (batch, seq, heads) = (
3388 q_shape.dim(0).unwrap_static(),
3389 q_shape.dim(1).unwrap_static(),
3390 q_shape.dim(2).unwrap_static(),
3391 );
3392 let state_off = if *carry_state {
3393 node_offset(arena, node.inputs[5])
3394 } else {
3395 0
3396 };
3397 Thunk::GatedDeltaNet {
3398 q: node_offset(arena, node.inputs[0]),
3399 k: node_offset(arena, node.inputs[1]),
3400 v: node_offset(arena, node.inputs[2]),
3401 g: node_offset(arena, node.inputs[3]),
3402 beta: node_offset(arena, node.inputs[4]),
3403 state: state_off,
3404 dst: node_offset(arena, node.id),
3405 batch: batch as u32,
3406 seq: seq as u32,
3407 heads: heads as u32,
3408 state_size: *state_size as u32,
3409 }
3410 }
3411
3412 Op::Lstm {
3413 hidden_size,
3414 num_layers,
3415 bidirectional,
3416 carry,
3417 } => {
3418 let x_shape = &graph.node(node.inputs[0]).shape;
3419 let (batch, seq, input_size) = (
3420 x_shape.dim(0).unwrap_static(),
3421 x_shape.dim(1).unwrap_static(),
3422 x_shape.dim(2).unwrap_static(),
3423 );
3424 let (h0, c0) = if *carry {
3425 (
3426 node_offset(arena, node.inputs[4]),
3427 node_offset(arena, node.inputs[5]),
3428 )
3429 } else {
3430 (0, 0)
3431 };
3432 Thunk::Lstm {
3433 x: node_offset(arena, node.inputs[0]),
3434 w_ih: node_offset(arena, node.inputs[1]),
3435 w_hh: node_offset(arena, node.inputs[2]),
3436 bias: node_offset(arena, node.inputs[3]),
3437 h0,
3438 c0,
3439 dst: node_offset(arena, node.id),
3440 batch: batch as u32,
3441 seq: seq as u32,
3442 input_size: input_size as u32,
3443 hidden: *hidden_size as u32,
3444 num_layers: *num_layers as u32,
3445 bidirectional: *bidirectional,
3446 carry: *carry,
3447 }
3448 }
3449
3450 Op::QMatMul {
3451 x_zp,
3452 w_zp,
3453 out_zp,
3454 mult,
3455 } => {
3456 let x_shape = &graph.node(node.inputs[0]).shape;
3457 let w_shape = &graph.node(node.inputs[1]).shape;
3458 let m = x_shape.dim(0).unwrap_static();
3459 let k = x_shape.dim(1).unwrap_static();
3460 let n = w_shape.dim(1).unwrap_static();
3461 Thunk::QMatMul {
3462 x: node_offset(arena, node.inputs[0]),
3463 w: node_offset(arena, node.inputs[1]),
3464 bias: node_offset(arena, node.inputs[2]),
3465 out: node_offset(arena, node.id),
3466 m: m as u32,
3467 k: k as u32,
3468 n: n as u32,
3469 x_zp: *x_zp,
3470 w_zp: *w_zp,
3471 out_zp: *out_zp,
3472 mult: *mult,
3473 }
3474 }
3475
3476 Op::QConv2d {
3477 kernel_size,
3478 stride,
3479 padding,
3480 dilation,
3481 groups,
3482 x_zp,
3483 w_zp,
3484 out_zp,
3485 mult,
3486 } => {
3487 let in_shape = &graph.node(node.inputs[0]).shape;
3488 let w_shape = &graph.node(node.inputs[1]).shape;
3489 let out_shape = &node.shape;
3490 if kernel_size.len() == 2
3491 && in_shape.rank() == 4
3492 && w_shape.rank() == 4
3493 && out_shape.rank() == 4
3494 {
3495 Thunk::QConv2d {
3496 x: node_offset(arena, node.inputs[0]),
3497 w: node_offset(arena, node.inputs[1]),
3498 bias: node_offset(arena, node.inputs[2]),
3499 out: node_offset(arena, node.id),
3500 n: in_shape.dim(0).unwrap_static() as u32,
3501 c_in: in_shape.dim(1).unwrap_static() as u32,
3502 h: in_shape.dim(2).unwrap_static() as u32,
3503 w_in: in_shape.dim(3).unwrap_static() as u32,
3504 c_out: out_shape.dim(1).unwrap_static() as u32,
3505 h_out: out_shape.dim(2).unwrap_static() as u32,
3506 w_out: out_shape.dim(3).unwrap_static() as u32,
3507 kh: kernel_size[0] as u32,
3508 kw: kernel_size[1] as u32,
3509 sh: stride.first().copied().unwrap_or(1) as u32,
3510 sw: stride.get(1).copied().unwrap_or(1) as u32,
3511 ph: padding.first().copied().unwrap_or(0) as u32,
3512 pw: padding.get(1).copied().unwrap_or(0) as u32,
3513 dh: dilation.first().copied().unwrap_or(1) as u32,
3514 dw: dilation.get(1).copied().unwrap_or(1) as u32,
3515 groups: *groups as u32,
3516 x_zp: *x_zp,
3517 w_zp: *w_zp,
3518 out_zp: *out_zp,
3519 mult: *mult,
3520 }
3521 } else {
3522 Thunk::Nop
3523 }
3524 }
3525
3526 Op::DequantMatMul { scheme } => {
3527 use rlx_ir::quant::QuantScheme;
3528 let n = node.shape.dim(node.shape.rank() - 1).unwrap_static();
3529 let total = node.shape.num_elements().unwrap();
3530 let m = total / n.max(1);
3531 let x_total = graph.node(node.inputs[0]).shape.num_elements().unwrap();
3532 let k = x_total / m.max(1);
3533 if scheme.is_gguf() {
3534 Thunk::DequantMatMulGguf {
3535 x: node_offset(arena, node.inputs[0]),
3536 w_q: node_offset(arena, node.inputs[1]),
3537 dst: node_offset(arena, node.id),
3538 m: m as u32,
3539 k: k as u32,
3540 n: n as u32,
3541 scheme: *scheme,
3542 }
3543 } else {
3544 match scheme {
3545 QuantScheme::Nvfp4Block => Thunk::DequantMatMulNvfp4 {
3546 x: node_offset(arena, node.inputs[0]),
3547 w_q: node_offset(arena, node.inputs[1]),
3548 scale: node_offset(arena, node.inputs[2]),
3549 global_scale: node_offset(arena, node.inputs[3]),
3550 dst: node_offset(arena, node.id),
3551 m: m as u32,
3552 k: k as u32,
3553 n: n as u32,
3554 },
3555 QuantScheme::Int4Block { block_size } => Thunk::DequantMatMulInt4 {
3556 x: node_offset(arena, node.inputs[0]),
3557 w_q: node_offset(arena, node.inputs[1]),
3558 scale: node_offset(arena, node.inputs[2]),
3559 zp: node_offset(arena, node.inputs[3]),
3560 dst: node_offset(arena, node.id),
3561 m: m as u32,
3562 k: k as u32,
3563 n: n as u32,
3564 block_size: *block_size,
3565 is_asymmetric: false,
3566 },
3567 QuantScheme::Fp8E4m3 => Thunk::DequantMatMulFp8 {
3568 x: node_offset(arena, node.inputs[0]),
3569 w_q: node_offset(arena, node.inputs[1]),
3570 scale: node_offset(arena, node.inputs[2]),
3571 dst: node_offset(arena, node.id),
3572 m: m as u32,
3573 k: k as u32,
3574 n: n as u32,
3575 e5m2: false,
3576 },
3577 QuantScheme::Fp8E5m2 => Thunk::DequantMatMulFp8 {
3578 x: node_offset(arena, node.inputs[0]),
3579 w_q: node_offset(arena, node.inputs[1]),
3580 scale: node_offset(arena, node.inputs[2]),
3581 dst: node_offset(arena, node.id),
3582 m: m as u32,
3583 k: k as u32,
3584 n: n as u32,
3585 e5m2: true,
3586 },
3587 QuantScheme::Int8Block { block_size } => Thunk::DequantMatMul {
3588 x: node_offset(arena, node.inputs[0]),
3589 w_q: node_offset(arena, node.inputs[1]),
3590 scale: node_offset(arena, node.inputs[2]),
3591 zp: node_offset(arena, node.inputs[3]),
3592 dst: node_offset(arena, node.id),
3593 m: m as u32,
3594 k: k as u32,
3595 n: n as u32,
3596 block_size: *block_size,
3597 is_asymmetric: false,
3598 },
3599 QuantScheme::Int8BlockAsym { block_size } => Thunk::DequantMatMul {
3600 x: node_offset(arena, node.inputs[0]),
3601 w_q: node_offset(arena, node.inputs[1]),
3602 scale: node_offset(arena, node.inputs[2]),
3603 zp: node_offset(arena, node.inputs[3]),
3604 dst: node_offset(arena, node.id),
3605 m: m as u32,
3606 k: k as u32,
3607 n: n as u32,
3608 block_size: *block_size,
3609 is_asymmetric: true,
3610 },
3611 other => panic!(
3612 "DequantMatMul on CPU supports Int8/Int4/FP8/NVFP4 legacy or GGUF schemes; got {other}"
3613 ),
3614 }
3615 }
3616 }
3617
3618 Op::LoraMatMul { scale } => {
3619 let n = node.shape.dim(node.shape.rank() - 1).unwrap_static();
3621 let total = node.shape.num_elements().unwrap();
3622 let m = total / n.max(1);
3623 let x_total = graph.node(node.inputs[0]).shape.num_elements().unwrap();
3624 let k = x_total / m.max(1);
3625 let a_total = graph.node(node.inputs[2]).shape.num_elements().unwrap();
3626 let r = a_total / k.max(1);
3627 Thunk::LoraMatMul {
3628 x: node_offset(arena, node.inputs[0]),
3629 w: node_offset(arena, node.inputs[1]),
3630 a: node_offset(arena, node.inputs[2]),
3631 b: node_offset(arena, node.inputs[3]),
3632 dst: node_offset(arena, node.id),
3633 m: m as u32,
3634 k: k as u32,
3635 n: n as u32,
3636 r: r as u32,
3637 scale: *scale,
3638 }
3639 }
3640
3641 Op::Sample {
3642 top_k,
3643 top_p,
3644 temperature,
3645 seed,
3646 } => {
3647 let in_shape = &graph.node(node.inputs[0]).shape;
3648 let (batch, vocab) = if in_shape.rank() >= 2 {
3650 (
3651 in_shape.dim(0).unwrap_static(),
3652 in_shape.dim(in_shape.rank() - 1).unwrap_static(),
3653 )
3654 } else {
3655 (1, in_shape.num_elements().unwrap_or(0))
3656 };
3657 Thunk::Sample {
3658 logits: node_offset(arena, node.inputs[0]),
3659 dst: node_offset(arena, node.id),
3660 batch: batch as u32,
3661 vocab: vocab as u32,
3662 top_k: *top_k as u32,
3663 top_p: *top_p,
3664 temperature: *temperature,
3665 seed: *seed,
3666 }
3667 }
3668
3669 Op::RngNormal {
3670 mean,
3671 scale,
3672 key,
3673 op_seed,
3674 } => Thunk::RngNormal {
3675 dst: node_offset(arena, node.id),
3676 len: node.shape.num_elements().unwrap_or(0) as u32,
3677 mean: *mean,
3678 scale: *scale,
3679 key: *key,
3680 op_seed: *op_seed,
3681 },
3682
3683 Op::RngUniform {
3684 low,
3685 high,
3686 key,
3687 op_seed,
3688 } => Thunk::RngUniform {
3689 dst: node_offset(arena, node.id),
3690 len: node.shape.num_elements().unwrap_or(0) as u32,
3691 low: *low,
3692 high: *high,
3693 key: *key,
3694 op_seed: *op_seed,
3695 },
3696
3697 Op::Cumsum { axis, exclusive } => {
3698 let rank = node.shape.rank();
3703 let ax = if *axis < 0 {
3704 (rank as i32 + axis) as usize
3705 } else {
3706 *axis as usize
3707 };
3708 assert_eq!(
3709 ax,
3710 rank - 1,
3711 "Cumsum only supports the last axis on CPU today"
3712 );
3713 let cols = node.shape.dim(ax).unwrap_static();
3714 let total = node.shape.num_elements().unwrap();
3715 Thunk::Cumsum {
3716 src: node_offset(arena, node.inputs[0]),
3717 dst: node_offset(arena, node.id),
3718 rows: (total / cols) as u32,
3719 cols: cols as u32,
3720 exclusive: *exclusive,
3721 }
3722 }
3723
3724 Op::Attention {
3725 num_heads,
3726 head_dim,
3727 mask_kind,
3728 score_scale,
3729 attn_logit_softcap: _,
3730 } => {
3731 let q_shape = &graph.node(node.inputs[0]).shape;
3737 let k_shape = &graph.node(node.inputs[1]).shape;
3738 let rank = q_shape.rank();
3739 let (batch, seq, kv_seq, bhsd) = if rank == 4 {
3740 let d1 = q_shape.dim(1).unwrap_static();
3741 let d2 = q_shape.dim(2).unwrap_static();
3742 if d1 == *num_heads {
3743 (
3745 q_shape.dim(0).unwrap_static(),
3746 d2,
3747 k_shape.dim(2).unwrap_static(),
3748 true,
3749 )
3750 } else {
3751 (
3753 q_shape.dim(0).unwrap_static(),
3754 d1,
3755 k_shape.dim(1).unwrap_static(),
3756 false,
3757 )
3758 }
3759 } else if rank >= 3 {
3760 (
3761 q_shape.dim(0).unwrap_static(),
3762 q_shape.dim(1).unwrap_static(),
3763 k_shape.dim(1).unwrap_static(),
3764 false,
3765 )
3766 } else {
3767 (
3768 1,
3769 q_shape.dim(0).unwrap_static(),
3770 k_shape.dim(0).unwrap_static(),
3771 false,
3772 )
3773 };
3774 let mask_off = if matches!(
3775 mask_kind,
3776 rlx_ir::op::MaskKind::Custom | rlx_ir::op::MaskKind::Bias
3777 ) {
3778 node_offset(arena, node.inputs[3])
3779 } else {
3780 0
3781 };
3782 let hs = (*num_heads * *head_dim) as u32;
3783 Thunk::Attention {
3784 q: node_offset(arena, node.inputs[0]),
3785 k: node_offset(arena, node.inputs[1]),
3786 v: node_offset(arena, node.inputs[2]),
3787 mask: mask_off,
3788 out: node_offset(arena, node.id),
3789 batch: batch as u32,
3790 seq: seq as u32,
3791 kv_seq: kv_seq as u32,
3792 heads: *num_heads as u32,
3793 head_dim: *head_dim as u32,
3794 mask_kind: *mask_kind,
3795 scale: score_scale.unwrap_or((*head_dim as f32).powf(-0.5)),
3796 q_row_stride: hs,
3800 k_row_stride: hs,
3801 v_row_stride: hs,
3802 bhsd,
3803 }
3804 }
3805
3806 Op::AttentionBackward {
3807 num_heads,
3808 head_dim,
3809 mask_kind,
3810 wrt,
3811 } => {
3812 let q_shape = &graph.node(node.inputs[0]).shape;
3813 let k_shape = &graph.node(node.inputs[1]).shape;
3814 let rank = q_shape.rank();
3815 let (batch, seq, kv_seq, bhsd) = if rank == 4 {
3816 let d1 = q_shape.dim(1).unwrap_static();
3817 let d2 = q_shape.dim(2).unwrap_static();
3818 if d1 == *num_heads {
3819 (
3820 q_shape.dim(0).unwrap_static(),
3821 d2,
3822 k_shape.dim(2).unwrap_static(),
3823 true,
3824 )
3825 } else {
3826 (
3827 q_shape.dim(0).unwrap_static(),
3828 d1,
3829 k_shape.dim(1).unwrap_static(),
3830 false,
3831 )
3832 }
3833 } else if rank >= 3 {
3834 (
3835 q_shape.dim(0).unwrap_static(),
3836 q_shape.dim(1).unwrap_static(),
3837 k_shape.dim(1).unwrap_static(),
3838 false,
3839 )
3840 } else {
3841 (
3842 1,
3843 q_shape.dim(0).unwrap_static(),
3844 k_shape.dim(0).unwrap_static(),
3845 false,
3846 )
3847 };
3848 let mask_off = if matches!(
3849 mask_kind,
3850 rlx_ir::op::MaskKind::Custom | rlx_ir::op::MaskKind::Bias
3851 ) {
3852 node_offset(arena, node.inputs[4])
3853 } else {
3854 0
3855 };
3856 Thunk::AttentionBackward {
3857 q: node_offset(arena, node.inputs[0]),
3858 k: node_offset(arena, node.inputs[1]),
3859 v: node_offset(arena, node.inputs[2]),
3860 dy: node_offset(arena, node.inputs[3]),
3861 mask: mask_off,
3862 out: node_offset(arena, node.id),
3863 batch: batch as u32,
3864 seq: seq as u32,
3865 kv_seq: kv_seq as u32,
3866 heads: *num_heads as u32,
3867 head_dim: *head_dim as u32,
3868 mask_kind: *mask_kind,
3869 wrt: *wrt,
3870 bhsd,
3871 }
3872 }
3873
3874 Op::FusedAttentionBlock {
3875 num_heads,
3876 head_dim,
3877 has_bias,
3878 has_rope,
3879 } => {
3880 let x_shape = &graph.node(node.inputs[0]).shape;
3881 let (batch, seq) = if x_shape.rank() >= 3 {
3882 (
3883 x_shape.dim(0).unwrap_static(),
3884 x_shape.dim(1).unwrap_static(),
3885 )
3886 } else {
3887 let total = x_shape.num_elements().unwrap();
3888 let s = x_shape.dim(x_shape.rank() - 2).unwrap_static();
3889 (total / (s * num_heads * head_dim), s)
3890 };
3891 let hs = (*num_heads * *head_dim) as u32;
3892 let mut idx = 4;
3894 let (qkv_b_off, out_b_off) = if *has_bias {
3895 let qb = node_offset(arena, node.inputs[idx]);
3896 let ob = node_offset(arena, node.inputs[idx + 1]);
3897 idx += 2;
3898 (qb, ob)
3899 } else {
3900 (0, 0)
3901 };
3902 let (cos_off, sin_off, cl) = if *has_rope {
3903 let c = node_offset(arena, node.inputs[idx]);
3904 let s = node_offset(arena, node.inputs[idx + 1]);
3905 let clen = get_len(graph, node.inputs[idx]);
3906 (c, s, clen as u32)
3907 } else {
3908 (0, 0, 0)
3909 };
3910
3911 Thunk::FusedAttnBlock {
3912 hidden: node_offset(arena, node.inputs[0]),
3913 qkv_w: node_offset(arena, node.inputs[1]),
3914 out_w: node_offset(arena, node.inputs[2]),
3915 mask: node_offset(arena, node.inputs[3]),
3916 out: node_offset(arena, node.id),
3917 qkv_b: qkv_b_off,
3918 out_b: out_b_off,
3919 cos: cos_off,
3920 sin: sin_off,
3921 cos_len: cl,
3922 batch: batch as u32,
3923 seq: seq as u32,
3924 hs,
3925 nh: *num_heads as u32,
3926 dh: *head_dim as u32,
3927 has_bias: *has_bias,
3928 has_rope: *has_rope,
3929 }
3930 }
3931
3932 Op::Rope { head_dim, n_rot } => {
3933 let x_shape = &graph.node(node.inputs[0]).shape;
3934 let (batch, seq, hidden) = if x_shape.rank() >= 3 {
3935 (
3936 x_shape.dim(0).unwrap_static(),
3937 x_shape.dim(1).unwrap_static(),
3938 x_shape.dim(2).unwrap_static(),
3939 )
3940 } else {
3941 let total = x_shape.num_elements().unwrap();
3942 (
3943 1,
3944 x_shape.dim(0).unwrap_static(),
3945 total / x_shape.dim(0).unwrap_static(),
3946 )
3947 };
3948 let cos_len = get_len(graph, node.inputs[1]);
3949 Thunk::Rope {
3950 src: node_offset(arena, node.inputs[0]),
3951 cos: node_offset(arena, node.inputs[1]),
3952 sin: node_offset(arena, node.inputs[2]),
3953 dst: node_offset(arena, node.id),
3954 batch: batch as u32,
3955 seq: seq as u32,
3956 hidden: hidden as u32,
3957 head_dim: *head_dim as u32,
3958 n_rot: *n_rot as u32,
3959 cos_len: cos_len as u32,
3960 src_row_stride: hidden as u32,
3964 }
3965 }
3966
3967 Op::FusedSwiGLU {
3968 cast_to: _,
3969 gate_first,
3970 } => {
3971 let n_half = node.shape.dim(node.shape.rank() - 1).unwrap_static();
3972 let total = node.shape.num_elements().unwrap();
3973 Thunk::FusedSwiGLU {
3974 src: node_offset(arena, node.inputs[0]),
3975 dst: node_offset(arena, node.id),
3976 n_half: n_half as u32,
3977 total: total as u32,
3978 gate_first: *gate_first,
3979 }
3980 }
3981
3982 Op::Conv {
3983 kernel_size,
3984 stride,
3985 padding,
3986 dilation,
3987 groups,
3988 } => {
3989 let in_shape = &graph.node(node.inputs[0]).shape;
3990 let w_shape = &graph.node(node.inputs[1]).shape;
3991 let out_shape = &node.shape;
3992 let is_1x1_simple = kernel_size.len() == 2
3996 && kernel_size[0] == 1
3997 && kernel_size[1] == 1
3998 && stride.iter().all(|&s| s == 1)
3999 && padding.iter().all(|&p| p == 0)
4000 && dilation.iter().all(|&d| d == 1)
4001 && *groups == 1;
4002 if is_1x1_simple
4003 && in_shape.rank() >= 3
4004 && out_shape.rank() >= 3
4005 && w_shape.rank() >= 2
4006 {
4007 let (n, c_in, h, w) = conv_nchw_dims(in_shape);
4008 let (_, c_out, _, _) = conv_nchw_dims(out_shape);
4009 Thunk::Conv2D1x1 {
4010 src: node_offset(arena, node.inputs[0]),
4011 weight: node_offset(arena, node.inputs[1]),
4012 dst: node_offset(arena, node.id),
4013 n,
4014 c_in,
4015 c_out,
4016 hw: h.saturating_mul(w),
4017 }
4018 } else if kernel_size.len() == 2
4019 && in_shape.rank() >= 3
4020 && w_shape.rank() >= 2
4021 && out_shape.rank() >= 3
4022 {
4023 let (n, c_in, h, w_in) = conv_nchw_dims(in_shape);
4024 let (_, c_out, h_out, w_out) = conv_nchw_dims(out_shape);
4025 let one_d_w = h == 1
4033 && w_in > 1
4034 && kernel_size[0] > 1
4035 && kernel_size.get(1).copied().unwrap_or(1) == 1;
4036 let (h, w_in, h_out, w_out, kh, kw, sh, sw, ph, pw, dh, dw) = if one_d_w {
4037 (
4038 w_in,
4039 1,
4040 w_out,
4041 1,
4042 kernel_size[0] as u32,
4043 1,
4044 stride.first().copied().unwrap_or(1) as u32,
4045 1,
4046 padding.first().copied().unwrap_or(0) as u32,
4047 0,
4048 dilation.first().copied().unwrap_or(1) as u32,
4049 1,
4050 )
4051 } else {
4052 (
4053 h,
4054 w_in,
4055 h_out,
4056 w_out,
4057 kernel_size[0] as u32,
4058 kernel_size[1] as u32,
4059 stride.first().copied().unwrap_or(1) as u32,
4060 stride.get(1).copied().unwrap_or(1) as u32,
4061 padding.first().copied().unwrap_or(0) as u32,
4062 padding.get(1).copied().unwrap_or(0) as u32,
4063 dilation.first().copied().unwrap_or(1) as u32,
4064 dilation.get(1).copied().unwrap_or(1) as u32,
4065 )
4066 };
4067 Thunk::Conv2D {
4068 src: node_offset(arena, node.inputs[0]),
4069 weight: node_offset(arena, node.inputs[1]),
4070 dst: node_offset(arena, node.id),
4071 n,
4072 c_in,
4073 h,
4074 w: w_in,
4075 c_out,
4076 h_out,
4077 w_out,
4078 kh,
4079 kw,
4080 sh,
4081 sw,
4082 ph,
4083 pw,
4084 dh,
4085 dw,
4086 groups: *groups as u32,
4087 }
4088 } else {
4089 Thunk::Nop
4090 }
4091 }
4092
4093 Op::Pool {
4094 kind,
4095 kernel_size,
4096 stride,
4097 padding,
4098 } => {
4099 let in_shape = &graph.node(node.inputs[0]).shape;
4101 let out_shape = &node.shape;
4102 if kernel_size.len() == 2 && in_shape.rank() == 4 && out_shape.rank() == 4 {
4103 Thunk::Pool2D {
4104 src: node_offset(arena, node.inputs[0]),
4105 dst: node_offset(arena, node.id),
4106 n: in_shape.dim(0).unwrap_static() as u32,
4107 c: in_shape.dim(1).unwrap_static() as u32,
4108 h: in_shape.dim(2).unwrap_static() as u32,
4109 w: in_shape.dim(3).unwrap_static() as u32,
4110 h_out: out_shape.dim(2).unwrap_static() as u32,
4111 w_out: out_shape.dim(3).unwrap_static() as u32,
4112 kh: kernel_size[0] as u32,
4113 kw: kernel_size[1] as u32,
4114 sh: stride.first().copied().unwrap_or(1) as u32,
4115 sw: stride.get(1).copied().unwrap_or(1) as u32,
4116 ph: padding.first().copied().unwrap_or(0) as u32,
4117 pw: padding.get(1).copied().unwrap_or(0) as u32,
4118 kind: *kind,
4119 }
4120 } else {
4121 Thunk::Nop
4122 }
4123 }
4124
4125 Op::Transpose { perm } => {
4126 let in_shape = &graph.node(node.inputs[0]).shape;
4129 let in_rank = in_shape.rank();
4130 if perm.iter().any(|&p| p >= in_rank) {
4131 Thunk::Nop
4132 } else {
4133 let in_dims: Vec<usize> = (0..in_rank)
4134 .map(|i| in_shape.dim(i).unwrap_static())
4135 .collect();
4136 let mut in_strides_full = vec![1usize; in_rank];
4138 for d in (0..in_rank.saturating_sub(1)).rev() {
4139 in_strides_full[d] = in_strides_full[d + 1] * in_dims[d + 1];
4140 }
4141 let out_dims: Vec<u32> = perm.iter().map(|&p| in_dims[p] as u32).collect();
4142 let in_strides: Vec<u32> =
4143 perm.iter().map(|&p| in_strides_full[p] as u32).collect();
4144 let in_total = in_dims.iter().product::<usize>() as u32;
4145 let src = node_offset(arena, node.inputs[0]);
4146 let dst = node_offset(arena, node.id);
4147 let elem_bytes = node.shape.dtype().size_bytes() as u8;
4148 match node.shape.dtype() {
4149 rlx_ir::DType::F64 => Thunk::TransposeF64 {
4150 src,
4151 dst,
4152 in_total,
4153 out_dims,
4154 in_strides,
4155 },
4156 _ => Thunk::Transpose {
4157 src,
4158 dst,
4159 in_total,
4160 out_dims,
4161 in_strides,
4162 elem_bytes,
4163 },
4164 }
4165 }
4166 }
4167
4168 Op::ScatterAdd => {
4169 let upd_shape = &graph.node(node.inputs[0]).shape;
4172 let out_shape = &node.shape;
4173 let num_updates = upd_shape.dim(0).unwrap_static();
4174 let out_dim = out_shape.dim(0).unwrap_static();
4175 let trailing: usize = (1..out_shape.rank())
4176 .map(|i| out_shape.dim(i).unwrap_static())
4177 .product::<usize>()
4178 .max(1);
4179 Thunk::ScatterAdd {
4180 updates: node_offset(arena, node.inputs[0]),
4181 indices: node_offset(arena, node.inputs[1]),
4182 dst: node_offset(arena, node.id),
4183 num_updates: num_updates as u32,
4184 out_dim: out_dim as u32,
4185 trailing: trailing as u32,
4186 }
4187 }
4188
4189 Op::GroupedMatMul => {
4190 let in_shape = &graph.node(node.inputs[0]).shape;
4192 let w_shape = &graph.node(node.inputs[1]).shape;
4193 let m = in_shape.dim(in_shape.rank() - 2).unwrap_static();
4194 let k_dim = in_shape.dim(in_shape.rank() - 1).unwrap_static();
4195 let num_experts = w_shape.dim(0).unwrap_static();
4196 let n = w_shape.dim(2).unwrap_static();
4197 Thunk::GroupedMatMul {
4198 input: node_offset(arena, node.inputs[0]),
4199 weight: node_offset(arena, node.inputs[1]),
4200 expert_idx: node_offset(arena, node.inputs[2]),
4201 dst: node_offset(arena, node.id),
4202 m: m as u32,
4203 k_dim: k_dim as u32,
4204 n: n as u32,
4205 num_experts: num_experts as u32,
4206 }
4207 }
4208
4209 Op::DequantGroupedMatMul { scheme } => {
4210 let in_shape = &graph.node(node.inputs[0]).shape;
4211 let w_shape = &graph.node(node.inputs[1]).shape;
4212 let m = in_shape.dim(in_shape.rank() - 2).unwrap_static();
4213 let k_dim = in_shape.dim(in_shape.rank() - 1).unwrap_static();
4214 let out_shape = &node.shape;
4215 let n = out_shape.dim(out_shape.rank() - 1).unwrap_static();
4216 let block_elems = scheme.gguf_block_size() as usize;
4217 let block_bytes = scheme.gguf_block_bytes() as usize;
4218 let slab_bytes = (k_dim * n) / block_elems * block_bytes;
4219 let total_bytes = w_shape.num_elements().unwrap();
4220 let num_experts = total_bytes / slab_bytes.max(1);
4221 Thunk::DequantGroupedMatMulGguf {
4222 input: node_offset(arena, node.inputs[0]),
4223 w_q: node_offset(arena, node.inputs[1]),
4224 expert_idx: node_offset(arena, node.inputs[2]),
4225 dst: node_offset(arena, node.id),
4226 m: m as u32,
4227 k_dim: k_dim as u32,
4228 n: n as u32,
4229 num_experts: num_experts as u32,
4230 scheme: *scheme,
4231 }
4232 }
4233
4234 Op::DequantMoEWeights { scheme } => {
4235 let w_shape = &graph.node(node.inputs[0]).shape;
4236 let out_shape = &node.shape;
4237 let num_experts = out_shape.dim(0).unwrap_static();
4238 let k_dim = out_shape.dim(1).unwrap_static();
4239 let n = out_shape.dim(2).unwrap_static();
4240 let block_elems = scheme.gguf_block_size() as usize;
4241 let block_bytes = scheme.gguf_block_bytes() as usize;
4242 let slab_bytes = (k_dim * n) / block_elems * block_bytes;
4243 let total_bytes = w_shape.num_elements().unwrap();
4244 assert_eq!(
4245 total_bytes,
4246 num_experts * slab_bytes,
4247 "DequantMoEWeights packed bytes mismatch"
4248 );
4249 Thunk::DequantMoEWeightsGguf {
4250 w_q: node_offset(arena, node.inputs[0]),
4251 dst: node_offset(arena, node.id),
4252 k_dim: k_dim as u32,
4253 n: n as u32,
4254 num_experts: num_experts as u32,
4255 scheme: *scheme,
4256 }
4257 }
4258
4259 Op::TopK { k } => {
4260 let in_shape = &graph.node(node.inputs[0]).shape;
4261 let rank = in_shape.rank();
4262 let axis_dim = in_shape.dim(rank - 1).unwrap_static();
4263 let outer = in_shape.num_elements().unwrap() / axis_dim;
4264 let indices_i64 = u8::from(graph.node(node.id).shape.dtype() == rlx_ir::DType::I64);
4265 Thunk::TopK {
4266 src: node_offset(arena, node.inputs[0]),
4267 dst: node_offset(arena, node.id),
4268 outer: outer as u32,
4269 axis_dim: axis_dim as u32,
4270 k: *k as u32,
4271 indices_i64,
4272 }
4273 }
4274
4275 Op::Reduce {
4276 op,
4277 axes,
4278 keep_dim: _,
4279 } => {
4280 let in_shape = &graph.node(node.inputs[0]).shape;
4286 let rank = in_shape.rank();
4287 let mut sorted = axes.clone();
4288 sorted.sort();
4289 sorted.dedup();
4290 let contiguous = sorted.windows(2).all(|w| w[1] == w[0] + 1)
4291 && !sorted.is_empty()
4292 && *sorted.last().unwrap() < rank;
4293 if !contiguous {
4294 Thunk::Nop
4295 } else {
4296 let first = sorted[0];
4297 let last = *sorted.last().unwrap();
4298 let outer: usize = (0..first)
4299 .map(|i| in_shape.dim(i).unwrap_static())
4300 .product::<usize>()
4301 .max(1);
4302 let reduced: usize = (first..=last)
4303 .map(|i| in_shape.dim(i).unwrap_static())
4304 .product();
4305 let inner: usize = (last + 1..rank)
4306 .map(|i| in_shape.dim(i).unwrap_static())
4307 .product::<usize>()
4308 .max(1);
4309 let src = node_offset(arena, node.inputs[0]);
4310 let dst = node_offset(arena, node.id);
4311 if node.shape.dtype() == rlx_ir::DType::F64 && matches!(op, ReduceOp::Sum) {
4312 Thunk::ReduceSumF64 {
4313 src,
4314 dst,
4315 outer: outer as u32,
4316 reduced: reduced as u32,
4317 inner: inner as u32,
4318 }
4319 } else {
4320 Thunk::Reduce {
4321 src,
4322 dst,
4323 outer: outer as u32,
4324 reduced: reduced as u32,
4325 inner: inner as u32,
4326 op: *op,
4327 }
4328 }
4329 }
4330 }
4331
4332 Op::ArgMax { axis, keep_dim: _ } | Op::ArgMin { axis, keep_dim: _ } => {
4333 let in_shape = &graph.node(node.inputs[0]).shape;
4334 let rank = in_shape.rank();
4335 let outer: usize = (0..*axis)
4336 .map(|i| in_shape.dim(i).unwrap_static())
4337 .product::<usize>()
4338 .max(1);
4339 let reduced = in_shape.dim(*axis).unwrap_static();
4340 let inner: usize = (*axis + 1..rank)
4341 .map(|i| in_shape.dim(i).unwrap_static())
4342 .product::<usize>()
4343 .max(1);
4344 Thunk::ArgReduce {
4345 src: node_offset(arena, node.inputs[0]),
4346 dst: node_offset(arena, node.id),
4347 outer: outer as u32,
4348 reduced: reduced as u32,
4349 inner: inner as u32,
4350 is_max: matches!(node.op, Op::ArgMax { .. }),
4351 }
4352 }
4353
4354 Op::Compare(cmp) => {
4355 let len = node.shape.num_elements().unwrap();
4356 let in_dtype = graph.node(node.inputs[0]).shape.dtype();
4357 let inputs_i64 = u8::from(in_dtype == rlx_ir::DType::I64);
4358 Thunk::Compare {
4359 lhs: node_offset(arena, node.inputs[0]),
4360 rhs: node_offset(arena, node.inputs[1]),
4361 dst: node_offset(arena, node.id),
4362 len: len as u32,
4363 op: *cmp,
4364 inputs_i64,
4365 inputs_elem_bytes: in_dtype.size_bytes() as u8,
4366 dst_elem_bytes: node.shape.dtype().size_bytes() as u8,
4367 }
4368 }
4369
4370 Op::Where => {
4371 let len = node.shape.num_elements().unwrap();
4372 let elem_bytes = node.shape.dtype().size_bytes() as u8;
4373 let cond_elem_bytes = graph.node(node.inputs[0]).shape.dtype().size_bytes() as u8;
4374 Thunk::Where {
4375 cond: node_offset(arena, node.inputs[0]),
4376 on_true: node_offset(arena, node.inputs[1]),
4377 on_false: node_offset(arena, node.inputs[2]),
4378 dst: node_offset(arena, node.id),
4379 len: len as u32,
4380 elem_bytes,
4381 cond_elem_bytes,
4382 }
4383 }
4384
4385 Op::ReluBackward => {
4386 let len: usize = (0..node.shape.rank())
4387 .map(|i| node.shape.dim(i).unwrap_static())
4388 .product();
4389 let x = node_offset(arena, node.inputs[0]);
4390 let dy = node_offset(arena, node.inputs[1]);
4391 let dx = node_offset(arena, node.id);
4392 match node.shape.dtype() {
4393 rlx_ir::DType::F64 => Thunk::ReluBackwardF64 {
4394 x,
4395 dy,
4396 dx,
4397 len: len as u32,
4398 },
4399 _ => Thunk::ReluBackward {
4400 x,
4401 dy,
4402 dx,
4403 len: len as u32,
4404 },
4405 }
4406 }
4407
4408 Op::ComplexNormSq => {
4409 let len: usize = (0..node.shape.rank())
4410 .map(|i| node.shape.dim(i).unwrap_static())
4411 .product();
4412 let src = node_offset(arena, node.inputs[0]);
4413 let dst = node_offset(arena, node.id);
4414 Thunk::ComplexNormSqF32 {
4415 src,
4416 dst,
4417 len: len as u32,
4418 }
4419 }
4420
4421 Op::ComplexNormSqBackward => {
4422 let len: usize = (0..node.shape.rank())
4423 .map(|i| node.shape.dim(i).unwrap_static())
4424 .product();
4425 let z = node_offset(arena, node.inputs[0]);
4426 let g = node_offset(arena, node.inputs[1]);
4427 let dz = node_offset(arena, node.id);
4428 Thunk::ComplexNormSqBackwardF32 {
4429 z,
4430 g,
4431 dz,
4432 len: len as u32,
4433 }
4434 }
4435
4436 Op::Conjugate => {
4437 let len: usize = (0..node.shape.rank())
4438 .map(|i| node.shape.dim(i).unwrap_static())
4439 .product();
4440 Thunk::ConjugateC64 {
4441 src: node_offset(arena, node.inputs[0]),
4442 dst: node_offset(arena, node.id),
4443 len: len as u32,
4444 }
4445 }
4446
4447 Op::ActivationBackward { kind } => {
4448 let len: usize = (0..node.shape.rank())
4449 .map(|i| node.shape.dim(i).unwrap_static())
4450 .product();
4451 let x = node_offset(arena, node.inputs[0]);
4452 let dy = node_offset(arena, node.inputs[1]);
4453 let dx = node_offset(arena, node.id);
4454 match node.shape.dtype() {
4455 rlx_ir::DType::F64 => Thunk::ActivationBackwardF64 {
4456 x,
4457 dy,
4458 dx,
4459 len: len as u32,
4460 kind: *kind,
4461 },
4462 _ => Thunk::ActivationBackward {
4463 x,
4464 dy,
4465 dx,
4466 len: len as u32,
4467 kind: *kind,
4468 },
4469 }
4470 }
4471
4472 Op::LayerNormBackwardInput { eps, .. } => {
4473 let h = node.shape.dim(node.shape.rank() - 1).unwrap_static();
4475 let total = node.shape.num_elements().unwrap();
4476 Thunk::LayerNormBackwardInput {
4477 x: node_offset(arena, node.inputs[0]),
4478 gamma: node_offset(arena, node.inputs[1]),
4479 dy: node_offset(arena, node.inputs[2]),
4480 dx: node_offset(arena, node.id),
4481 rows: (total / h) as u32,
4482 h: h as u32,
4483 eps: *eps,
4484 }
4485 }
4486
4487 Op::LayerNormBackwardGamma { eps, .. } => {
4488 let x_shape = &graph.node(node.inputs[0]).shape;
4489 let h = x_shape.dim(x_shape.rank() - 1).unwrap_static();
4490 let x_total = x_shape.num_elements().unwrap();
4491 Thunk::LayerNormBackwardGamma {
4492 x: node_offset(arena, node.inputs[0]),
4493 dy: node_offset(arena, node.inputs[1]),
4494 dgamma: node_offset(arena, node.id),
4495 rows: (x_total / h) as u32,
4496 h: h as u32,
4497 eps: *eps,
4498 }
4499 }
4500
4501 Op::RmsNormBackwardInput { eps, .. }
4502 | Op::RmsNormBackwardGamma { eps, .. }
4503 | Op::RmsNormBackwardBeta { eps, .. } => {
4504 let x_shape = &graph.node(node.inputs[0]).shape;
4505 let h = x_shape.dim(x_shape.rank() - 1).unwrap_static();
4506 let rows = (x_shape.num_elements().unwrap() / h) as u32;
4507 let off = |i: usize| node_offset(arena, node.inputs[i]);
4508 let common = (off(0), off(1), off(2), off(3), rows, h as u32, *eps);
4509 match &node.op {
4510 Op::RmsNormBackwardInput { .. } => Thunk::RmsNormBackwardInput {
4511 x: common.0,
4512 gamma: common.1,
4513 beta: common.2,
4514 dy: common.3,
4515 dx: node_offset(arena, node.id),
4516 rows: common.4,
4517 h: common.5,
4518 eps: common.6,
4519 },
4520 Op::RmsNormBackwardGamma { .. } => Thunk::RmsNormBackwardGamma {
4521 x: common.0,
4522 gamma: common.1,
4523 beta: common.2,
4524 dy: common.3,
4525 dgamma: node_offset(arena, node.id),
4526 rows: common.4,
4527 h: common.5,
4528 eps: common.6,
4529 },
4530 Op::RmsNormBackwardBeta { .. } => Thunk::RmsNormBackwardBeta {
4531 x: common.0,
4532 gamma: common.1,
4533 beta: common.2,
4534 dy: common.3,
4535 dbeta: node_offset(arena, node.id),
4536 rows: common.4,
4537 h: common.5,
4538 eps: common.6,
4539 },
4540 _ => unreachable!(),
4541 }
4542 }
4543
4544 Op::RopeBackward { head_dim, n_rot } => {
4545 let dy_shape = &graph.node(node.inputs[0]).shape;
4546 let (batch, seq, hidden) = if dy_shape.rank() >= 3 {
4547 (
4548 dy_shape.dim(0).unwrap_static(),
4549 dy_shape.dim(1).unwrap_static(),
4550 dy_shape.dim(2).unwrap_static(),
4551 )
4552 } else {
4553 (
4554 1,
4555 dy_shape.dim(0).unwrap_static(),
4556 dy_shape.dim(1).unwrap_static(),
4557 )
4558 };
4559 let cos_shape = &graph.node(node.inputs[1]).shape;
4560 let cos_len = cos_shape.num_elements().unwrap();
4561 Thunk::RopeBackward {
4562 dy: node_offset(arena, node.inputs[0]),
4563 cos: node_offset(arena, node.inputs[1]),
4564 sin: node_offset(arena, node.inputs[2]),
4565 dx: node_offset(arena, node.id),
4566 batch: batch as u32,
4567 seq: seq as u32,
4568 hidden: hidden as u32,
4569 head_dim: *head_dim as u32,
4570 n_rot: *n_rot as u32,
4571 cos_len: cos_len as u32,
4572 }
4573 }
4574
4575 Op::CumsumBackward { exclusive, .. } => {
4576 let dy_shape = &graph.node(node.inputs[0]).shape;
4577 let rank = dy_shape.rank();
4578 let cols = dy_shape.dim(rank - 1).unwrap_static();
4579 let rows = dy_shape.num_elements().unwrap() / cols;
4580 Thunk::CumsumBackward {
4581 dy: node_offset(arena, node.inputs[0]),
4582 dx: node_offset(arena, node.id),
4583 rows: rows as u32,
4584 cols: cols as u32,
4585 exclusive: *exclusive,
4586 }
4587 }
4588
4589 Op::GatherBackward { .. } => {
4590 let dy_shape = &graph.node(node.inputs[0]).shape;
4591 let idx_shape = &graph.node(node.inputs[1]).shape;
4592 let out_shape = &node.shape;
4593 let rank = out_shape.rank();
4594 let axis = match &node.op {
4595 Op::GatherBackward { axis } => *axis,
4596 _ => 0,
4597 };
4598 let axis_u = if axis < 0 {
4599 (rank as i32 + axis) as usize
4600 } else {
4601 axis as usize
4602 };
4603 let outer: usize = (0..axis_u)
4604 .map(|i| dy_shape.dim(i).unwrap_static())
4605 .product::<usize>()
4606 .max(1);
4607 let num_idx = idx_shape.dim(axis_u).unwrap_static();
4608 let trailing: usize = (axis_u + 1..dy_shape.rank())
4609 .map(|i| dy_shape.dim(i).unwrap_static())
4610 .product::<usize>()
4611 .max(1);
4612 let axis_dim = out_shape.dim(axis_u).unwrap_static();
4613 Thunk::GatherBackward {
4614 dy: node_offset(arena, node.inputs[0]),
4615 indices: node_offset(arena, node.inputs[1]),
4616 dst: node_offset(arena, node.id),
4617 outer: outer as u32,
4618 axis_dim: axis_dim as u32,
4619 num_idx: num_idx as u32,
4620 trailing: trailing as u32,
4621 }
4622 }
4623
4624 Op::GroupNormBackwardInput { num_groups, eps }
4625 | Op::GroupNormBackwardGamma { num_groups, eps }
4626 | Op::GroupNormBackwardBeta { num_groups, eps } => {
4627 let x_shape = &graph.node(node.inputs[0]).shape;
4628 let n = x_shape.dim(0).unwrap_static() as u32;
4629 let c = x_shape.dim(1).unwrap_static() as u32;
4630 let h = x_shape.dim(2).unwrap_static() as u32;
4631 let w = x_shape.dim(3).unwrap_static() as u32;
4632 match &node.op {
4633 Op::GroupNormBackwardInput { .. } => Thunk::GroupNormBackwardInput {
4634 x: node_offset(arena, node.inputs[0]),
4635 gamma: node_offset(arena, node.inputs[1]),
4636 beta: node_offset(arena, node.inputs[2]),
4637 dy: node_offset(arena, node.inputs[3]),
4638 dx: node_offset(arena, node.id),
4639 n,
4640 c,
4641 h,
4642 w,
4643 num_groups: *num_groups as u32,
4644 eps: *eps,
4645 },
4646 Op::GroupNormBackwardGamma { .. } => Thunk::GroupNormBackwardGamma {
4647 x: node_offset(arena, node.inputs[0]),
4648 dy: node_offset(arena, node.inputs[1]),
4649 dgamma: node_offset(arena, node.id),
4650 n,
4651 c,
4652 h,
4653 w,
4654 num_groups: *num_groups as u32,
4655 eps: *eps,
4656 },
4657 Op::GroupNormBackwardBeta { .. } => Thunk::GroupNormBackwardBeta {
4658 dy: node_offset(arena, node.inputs[1]),
4659 dbeta: node_offset(arena, node.id),
4660 n,
4661 c,
4662 h,
4663 w,
4664 },
4665 _ => unreachable!(),
4666 }
4667 }
4668
4669 Op::MaxPool2dBackward {
4670 kernel_size,
4671 stride,
4672 padding,
4673 } => {
4674 let x_shape = &graph.node(node.inputs[0]).shape;
4675 let dy_shape = &graph.node(node.inputs[1]).shape;
4676 if kernel_size.len() == 2 && x_shape.rank() == 4 && dy_shape.rank() == 4 {
4677 Thunk::MaxPool2dBackward {
4678 x: node_offset(arena, node.inputs[0]),
4679 dy: node_offset(arena, node.inputs[1]),
4680 dx: node_offset(arena, node.id),
4681 n: x_shape.dim(0).unwrap_static() as u32,
4682 c: x_shape.dim(1).unwrap_static() as u32,
4683 h: x_shape.dim(2).unwrap_static() as u32,
4684 w: x_shape.dim(3).unwrap_static() as u32,
4685 h_out: dy_shape.dim(2).unwrap_static() as u32,
4686 w_out: dy_shape.dim(3).unwrap_static() as u32,
4687 kh: kernel_size[0] as u32,
4688 kw: kernel_size[1] as u32,
4689 sh: stride.first().copied().unwrap_or(1) as u32,
4690 sw: stride.get(1).copied().unwrap_or(1) as u32,
4691 ph: padding.first().copied().unwrap_or(0) as u32,
4692 pw: padding.get(1).copied().unwrap_or(0) as u32,
4693 }
4694 } else {
4695 Thunk::Nop
4696 }
4697 }
4698
4699 Op::Conv2dBackwardInput {
4700 kernel_size,
4701 stride,
4702 padding,
4703 dilation,
4704 groups,
4705 } => {
4706 let dy_shape = &graph.node(node.inputs[0]).shape;
4707 let w_shape = &graph.node(node.inputs[1]).shape;
4708 let out_shape = &node.shape;
4709 if kernel_size.len() == 2
4710 && dy_shape.rank() == 4
4711 && w_shape.rank() == 4
4712 && out_shape.rank() == 4
4713 {
4714 Thunk::Conv2dBackwardInput {
4715 dy: node_offset(arena, node.inputs[0]),
4716 w: node_offset(arena, node.inputs[1]),
4717 dx: node_offset(arena, node.id),
4718 n: out_shape.dim(0).unwrap_static() as u32,
4719 c_in: out_shape.dim(1).unwrap_static() as u32,
4720 h: out_shape.dim(2).unwrap_static() as u32,
4721 w_in: out_shape.dim(3).unwrap_static() as u32,
4722 c_out: dy_shape.dim(1).unwrap_static() as u32,
4723 h_out: dy_shape.dim(2).unwrap_static() as u32,
4724 w_out: dy_shape.dim(3).unwrap_static() as u32,
4725 kh: kernel_size[0] as u32,
4726 kw: kernel_size[1] as u32,
4727 sh: stride.first().copied().unwrap_or(1) as u32,
4728 sw: stride.get(1).copied().unwrap_or(1) as u32,
4729 ph: padding.first().copied().unwrap_or(0) as u32,
4730 pw: padding.get(1).copied().unwrap_or(0) as u32,
4731 dh: dilation.first().copied().unwrap_or(1) as u32,
4732 dw: dilation.get(1).copied().unwrap_or(1) as u32,
4733 groups: *groups as u32,
4734 }
4735 } else {
4736 Thunk::Nop
4737 }
4738 }
4739
4740 Op::Conv2dBackwardWeight {
4741 kernel_size,
4742 stride,
4743 padding,
4744 dilation,
4745 groups,
4746 } => {
4747 let x_shape = &graph.node(node.inputs[0]).shape;
4748 let dy_shape = &graph.node(node.inputs[1]).shape;
4749 let dw_shape = &node.shape;
4750 if kernel_size.len() == 2
4751 && x_shape.rank() == 4
4752 && dy_shape.rank() == 4
4753 && dw_shape.rank() == 4
4754 {
4755 Thunk::Conv2dBackwardWeight {
4756 x: node_offset(arena, node.inputs[0]),
4757 dy: node_offset(arena, node.inputs[1]),
4758 dw: node_offset(arena, node.id),
4759 n: x_shape.dim(0).unwrap_static() as u32,
4760 c_in: x_shape.dim(1).unwrap_static() as u32,
4761 h: x_shape.dim(2).unwrap_static() as u32,
4762 w: x_shape.dim(3).unwrap_static() as u32,
4763 c_out: dy_shape.dim(1).unwrap_static() as u32,
4764 h_out: dy_shape.dim(2).unwrap_static() as u32,
4765 w_out: dy_shape.dim(3).unwrap_static() as u32,
4766 kh: kernel_size[0] as u32,
4767 kw: kernel_size[1] as u32,
4768 sh: stride.first().copied().unwrap_or(1) as u32,
4769 sw: stride.get(1).copied().unwrap_or(1) as u32,
4770 ph: padding.first().copied().unwrap_or(0) as u32,
4771 pw: padding.get(1).copied().unwrap_or(0) as u32,
4772 dh: dilation.first().copied().unwrap_or(1) as u32,
4773 dw_dil: dilation.get(1).copied().unwrap_or(1) as u32,
4774 groups: *groups as u32,
4775 }
4776 } else {
4777 Thunk::Nop
4778 }
4779 }
4780
4781 Op::Im2Col {
4782 kernel_size,
4783 stride,
4784 padding,
4785 dilation,
4786 } => {
4787 let x_shape = &graph.node(node.inputs[0]).shape;
4788 let out_shape = &node.shape;
4789 if kernel_size.len() == 2 && x_shape.rank() == 4 && out_shape.rank() == 2 {
4790 let n = match x_shape.dim(0) {
4791 rlx_ir::shape::Dim::Static(v) => v as u32,
4792 _ => 0,
4793 };
4794 let c_in = x_shape.dim(1).unwrap_static() as u32;
4795 let h = x_shape.dim(2).unwrap_static() as u32;
4796 let w = x_shape.dim(3).unwrap_static() as u32;
4797 let kh = kernel_size[0] as u32;
4798 let kw = kernel_size[1] as u32;
4799 let sh = stride.first().copied().unwrap_or(1) as u32;
4800 let sw = stride.get(1).copied().unwrap_or(1) as u32;
4801 let ph = padding.first().copied().unwrap_or(0) as u32;
4802 let pw = padding.get(1).copied().unwrap_or(0) as u32;
4803 let dh = dilation.first().copied().unwrap_or(1) as u32;
4804 let dw_dil = dilation.get(1).copied().unwrap_or(1) as u32;
4805 let h_out = rlx_ir::shape::conv2d_spatial_output(
4806 h as usize,
4807 kh as usize,
4808 sh as usize,
4809 ph as usize,
4810 dh as usize,
4811 ) as u32;
4812 let w_out = rlx_ir::shape::conv2d_spatial_output(
4813 w as usize,
4814 kw as usize,
4815 sw as usize,
4816 pw as usize,
4817 dw_dil as usize,
4818 ) as u32;
4819 Thunk::Im2Col {
4820 x: node_offset(arena, node.inputs[0]),
4821 col: node_offset(arena, node.id),
4822 n,
4823 c_in,
4824 h,
4825 w,
4826 h_out,
4827 w_out,
4828 kh,
4829 kw,
4830 sh,
4831 sw,
4832 ph,
4833 pw,
4834 dh,
4835 dw_dil,
4836 }
4837 } else {
4838 Thunk::Nop
4839 }
4840 }
4841
4842 Op::SoftmaxCrossEntropyWithLogits => {
4843 let logits_shape = &graph.node(node.inputs[0]).shape;
4844 if logits_shape.rank() == 2 {
4845 Thunk::SoftmaxCrossEntropy {
4846 logits: node_offset(arena, node.inputs[0]),
4847 labels: node_offset(arena, node.inputs[1]),
4848 dst: node_offset(arena, node.id),
4849 n: logits_shape.dim(0).unwrap_static() as u32,
4850 c: logits_shape.dim(1).unwrap_static() as u32,
4851 }
4852 } else {
4853 Thunk::Nop
4854 }
4855 }
4856
4857 Op::SoftmaxCrossEntropyBackward => {
4858 let logits_shape = &graph.node(node.inputs[0]).shape;
4859 if logits_shape.rank() == 2 {
4860 Thunk::SoftmaxCrossEntropyBackward {
4861 logits: node_offset(arena, node.inputs[0]),
4862 labels: node_offset(arena, node.inputs[1]),
4863 d_loss: node_offset(arena, node.inputs[2]),
4864 dlogits: node_offset(arena, node.id),
4865 n: logits_shape.dim(0).unwrap_static() as u32,
4866 c: logits_shape.dim(1).unwrap_static() as u32,
4867 }
4868 } else {
4869 Thunk::Nop
4870 }
4871 }
4872
4873 Op::DenseSolve => {
4874 let a_shape = &graph.node(node.inputs[0]).shape;
4876 let n = a_shape.dim(0).unwrap_static();
4877 debug_assert_eq!(
4878 n,
4879 a_shape.dim(1).unwrap_static(),
4880 "DenseSolve: A must be square"
4881 );
4882 let b_elems = node.shape.num_elements().unwrap();
4883 let nrhs = b_elems / n;
4884 match node.shape.dtype() {
4885 rlx_ir::DType::F64 => Thunk::DenseSolveF64 {
4886 a: node_offset(arena, node.inputs[0]),
4887 b: node_offset(arena, node.inputs[1]),
4888 x: node_offset(arena, node.id),
4889 n: n as u32,
4890 nrhs: nrhs as u32,
4891 },
4892 rlx_ir::DType::F32 => Thunk::DenseSolveF32 {
4893 a: node_offset(arena, node.inputs[0]),
4894 b: node_offset(arena, node.inputs[1]),
4895 x: node_offset(arena, node.id),
4896 n: n as u32,
4897 nrhs: nrhs as u32,
4898 },
4899 other => panic!(
4900 "DenseSolve: F32 + F64 lowered; got {other:?}. \
4901 Add another variant when needed."
4902 ),
4903 }
4904 }
4905
4906 Op::BatchedDenseSolve => {
4907 let a_shape = &graph.node(node.inputs[0]).shape;
4909 assert_eq!(a_shape.rank(), 3, "BatchedDenseSolve: A rank must be 3");
4910 let batch = a_shape.dim(0).unwrap_static();
4911 let n = a_shape.dim(1).unwrap_static();
4912 debug_assert_eq!(
4913 n,
4914 a_shape.dim(2).unwrap_static(),
4915 "BatchedDenseSolve: A's last two dims must match"
4916 );
4917 let total = node.shape.num_elements().unwrap();
4918 let nrhs = total / (batch * n);
4919 match node.shape.dtype() {
4920 rlx_ir::DType::F32 => Thunk::BatchedDenseSolveF32 {
4921 a: node_offset(arena, node.inputs[0]),
4922 b: node_offset(arena, node.inputs[1]),
4923 x: node_offset(arena, node.id),
4924 batch: batch as u32,
4925 n: n as u32,
4926 nrhs: nrhs as u32,
4927 },
4928 rlx_ir::DType::F64 => Thunk::BatchedDenseSolveF64 {
4929 a: node_offset(arena, node.inputs[0]),
4930 b: node_offset(arena, node.inputs[1]),
4931 x: node_offset(arena, node.id),
4932 batch: batch as u32,
4933 n: n as u32,
4934 nrhs: nrhs as u32,
4935 },
4936 other => panic!("BatchedDenseSolve: F32 + F64 only, got {other:?}"),
4937 }
4938 }
4939
4940 Op::Scan {
4941 body,
4942 length,
4943 save_trajectory,
4944 num_bcast,
4945 num_xs,
4946 num_checkpoints,
4947 } => {
4948 assert!(
4949 *num_checkpoints == 0 || *num_checkpoints <= *length,
4950 "Op::Scan: num_checkpoints={} must be 0 or ≤ length={}",
4951 *num_checkpoints,
4952 *length
4953 );
4954 if *num_checkpoints != 0 && *num_checkpoints != *length {
4955 assert!(
4956 *save_trajectory,
4957 "Op::Scan: num_checkpoints<length only meaningful when save_trajectory=true"
4958 );
4959 }
4960 let body_plan = rlx_opt::memory::plan_memory(body);
4971 let _body_arena_size = body_plan.arena_size;
4972 let body_offsets: HashMap<NodeId, usize> = body_plan
4975 .assignments
4976 .iter()
4977 .map(|(id, slot)| (*id, slot.offset))
4978 .collect();
4979
4980 let mut body_inputs: Vec<NodeId> = body
4983 .nodes()
4984 .iter()
4985 .filter(|n| matches!(n.op, Op::Input { .. }))
4986 .map(|n| n.id)
4987 .collect();
4988 body_inputs.sort();
4989 let n_body_inputs = body_inputs.len();
4990 let expected = 1 + *num_bcast as usize + *num_xs as usize;
4991 if n_body_inputs != expected {
4992 let names: Vec<String> = body
4993 .nodes()
4994 .iter()
4995 .filter_map(|n| match &n.op {
4996 Op::Input { name } => Some(format!("{}={}", n.id, name)),
4997 _ => None,
4998 })
4999 .collect();
5000 panic!(
5001 "Op::Scan body has {} Op::Input nodes; expected {} \
5002 (1 carry + {} bcast + {} xs). Inputs by NodeId: [{}]",
5003 n_body_inputs,
5004 expected,
5005 *num_bcast,
5006 *num_xs,
5007 names.join(", ")
5008 );
5009 }
5010
5011 let body_input_id = body_inputs[0];
5012 let body_input_off = body_offsets[&body_input_id];
5013 let body_output_id = body
5014 .outputs
5015 .first()
5016 .copied()
5017 .expect("Op::Scan body must declare one output");
5018 let body_output_off = body_offsets[&body_output_id];
5019
5020 let mut body_arena = crate::arena::Arena::from_plan(body_plan);
5021 for n in body.nodes() {
5024 if let Op::Constant { data } = &n.op
5025 && body_arena.has_buffer(n.id)
5026 && !data.is_empty()
5027 {
5028 match n.shape.dtype() {
5029 rlx_ir::DType::F64 => {
5030 let off = body_arena.byte_offset(n.id);
5031 let buf = body_arena.raw_buf_mut();
5032 let nbytes = (buf.len() - off).min(data.len());
5033 buf[off..off + nbytes].copy_from_slice(&data[..nbytes]);
5034 }
5035 _ => {
5036 let buf = body_arena.slice_mut(n.id);
5037 let n_floats = data.len() / 4;
5038 let n_lim = buf.len().min(n_floats);
5039 for i in 0..n_lim {
5040 let bytes = [
5041 data[i * 4],
5042 data[i * 4 + 1],
5043 data[i * 4 + 2],
5044 data[i * 4 + 3],
5045 ];
5046 buf[i] = f32::from_le_bytes(bytes);
5047 }
5048 }
5049 }
5050 }
5051 }
5052 let body_init = body_arena.raw_buf().to_vec();
5053 let body_schedule = compile_thunks_with_rng(body, &body_arena, rng);
5054
5055 let carry_bytes = if *save_trajectory {
5060 let total = node
5061 .shape
5062 .size_bytes()
5063 .expect("Op::Scan trajectory output must have static shape");
5064 total / *length as usize
5065 } else {
5066 node.shape
5067 .size_bytes()
5068 .expect("Op::Scan carry must have static shape")
5069 };
5070
5071 let mut bcast_inputs: Vec<(usize, usize, u32)> =
5076 Vec::with_capacity(*num_bcast as usize);
5077 for i in 0..*num_bcast as usize {
5078 let body_b_id = body_inputs[1 + i];
5079 let body_b_off = body_offsets[&body_b_id];
5080 let outer_b_id = node.inputs[1 + i];
5081 let outer_b_off = node_offset(arena, outer_b_id);
5082 let outer_b_shape = &graph.node(outer_b_id).shape;
5083 let total = outer_b_shape
5084 .size_bytes()
5085 .expect("Op::Scan bcast must have static shape");
5086 bcast_inputs.push((body_b_off, outer_b_off, total as u32));
5087 }
5088
5089 let mut xs_inputs: Vec<(usize, usize, u32)> = Vec::with_capacity(*num_xs as usize);
5093 let xs_base = 1 + *num_bcast as usize;
5094 for i in 0..*num_xs as usize {
5095 let body_x_id = body_inputs[xs_base + i];
5096 let body_x_off = body_offsets[&body_x_id];
5097 let outer_xs_id = node.inputs[xs_base + i];
5098 let outer_xs_off = node_offset(arena, outer_xs_id);
5099 let outer_xs_shape = &graph.node(outer_xs_id).shape;
5100 let total = outer_xs_shape
5101 .size_bytes()
5102 .expect("Op::Scan xs must have static shape");
5103 let per_step = total / *length as usize;
5104 xs_inputs.push((body_x_off, outer_xs_off, per_step as u32));
5105 }
5106
5107 Thunk::Scan {
5108 body: Arc::new(body_schedule),
5109 body_init: Arc::new(body_init),
5110 body_input_off,
5111 body_output_off,
5112 outer_init_off: node_offset(arena, node.inputs[0]),
5113 outer_final_off: node_offset(arena, node.id),
5114 length: *length,
5115 carry_bytes: carry_bytes as u32,
5116 save_trajectory: *save_trajectory,
5117 xs_inputs: Arc::new(xs_inputs),
5118 bcast_inputs: Arc::new(bcast_inputs),
5119 num_checkpoints: *num_checkpoints,
5120 }
5121 }
5122
5123 Op::ScanBackward {
5124 body_vjp,
5125 length,
5126 save_trajectory,
5127 num_xs,
5128 num_checkpoints,
5129 forward_body,
5130 } => {
5131 let is_recursive = *num_checkpoints != 0 && *num_checkpoints != *length;
5132 if is_recursive {
5133 assert!(
5134 forward_body.is_some(),
5135 "Op::ScanBackward with num_checkpoints<length requires forward_body"
5136 );
5137 }
5138 let body_plan = rlx_opt::memory::plan_memory(body_vjp);
5146 let body_offsets: HashMap<NodeId, usize> = body_plan
5147 .assignments
5148 .iter()
5149 .map(|(id, slot)| (*id, slot.offset))
5150 .collect();
5151 let mut body_d_output_off: Option<usize> = None;
5152 let mut body_other_inputs: Vec<(NodeId, usize)> = Vec::new();
5153 for n in body_vjp.nodes() {
5154 if let Op::Input { name } = &n.op {
5155 let off = body_offsets[&n.id];
5156 if name == "d_output" {
5157 body_d_output_off = Some(off);
5158 } else {
5159 body_other_inputs.push((n.id, off));
5160 }
5161 }
5162 }
5163 body_other_inputs.sort_by_key(|(id, _)| *id);
5164 let body_d_output_off =
5165 body_d_output_off.expect("ScanBackward body_vjp missing 'd_output' Input");
5166 let expected_others = 1 + *num_xs as usize;
5167 assert_eq!(
5168 body_other_inputs.len(),
5169 expected_others,
5170 "ScanBackward body_vjp has {} non-d_output Inputs; \
5171 expected {} (1 carry + {} xs)",
5172 body_other_inputs.len(),
5173 expected_others,
5174 num_xs
5175 );
5176 let body_carry_in_off = body_other_inputs[0].1;
5177 let body_x_offs: Vec<usize> = body_other_inputs
5178 .iter()
5179 .skip(1)
5180 .map(|(_, off)| *off)
5181 .collect();
5182 let body_dcarry_out_off = body_offsets[&body_vjp.outputs[0]];
5183
5184 let mut body_arena = crate::arena::Arena::from_plan(body_plan);
5185 for n in body_vjp.nodes() {
5187 if let Op::Constant { data } = &n.op
5188 && body_arena.has_buffer(n.id)
5189 && !data.is_empty()
5190 {
5191 match n.shape.dtype() {
5192 rlx_ir::DType::F64 => {
5193 let off = body_arena.byte_offset(n.id);
5194 let buf = body_arena.raw_buf_mut();
5195 let nb = (buf.len() - off).min(data.len());
5196 buf[off..off + nb].copy_from_slice(&data[..nb]);
5197 }
5198 _ => {
5199 let buf = body_arena.slice_mut(n.id);
5200 let nf = data.len() / 4;
5201 let nl = buf.len().min(nf);
5202 for i in 0..nl {
5203 let bytes = [
5204 data[i * 4],
5205 data[i * 4 + 1],
5206 data[i * 4 + 2],
5207 data[i * 4 + 3],
5208 ];
5209 buf[i] = f32::from_le_bytes(bytes);
5210 }
5211 }
5212 }
5213 }
5214 }
5215 let body_init = body_arena.raw_buf().to_vec();
5216 let body_schedule = compile_thunks_with_rng(body_vjp, &body_arena, rng);
5217
5218 let carry_bytes = body_vjp
5220 .node(body_vjp.outputs[0])
5221 .shape
5222 .size_bytes()
5223 .expect("ScanBackward dcarry must be statically shaped");
5224 let carry_elem_size = body_vjp
5225 .node(body_vjp.outputs[0])
5226 .shape
5227 .dtype()
5228 .size_bytes() as u32;
5229
5230 let mut outer_xs_offs: Vec<(usize, u32)> = Vec::with_capacity(*num_xs as usize);
5233 for i in 0..*num_xs as usize {
5234 let outer_xs_id = node.inputs[3 + i];
5235 let outer_xs_off = node_offset(arena, outer_xs_id);
5236 let outer_xs_shape = &graph.node(outer_xs_id).shape;
5237 let total = outer_xs_shape
5238 .size_bytes()
5239 .expect("ScanBackward xs must have static shape");
5240 let per_step = total / *length as usize;
5241 outer_xs_offs.push((outer_xs_off, per_step as u32));
5242 }
5243
5244 let (fb_schedule, fb_init, fb_carry_in_off, fb_output_off, fb_x_offs) =
5249 if is_recursive {
5250 let fb = forward_body.as_ref().unwrap();
5251 let fb_plan = rlx_opt::memory::plan_memory(fb);
5252 let fb_offsets: HashMap<NodeId, usize> = fb_plan
5253 .assignments
5254 .iter()
5255 .map(|(id, slot)| (*id, slot.offset))
5256 .collect();
5257 let mut fb_inputs: Vec<NodeId> = fb
5258 .nodes()
5259 .iter()
5260 .filter(|n| matches!(n.op, Op::Input { .. }))
5261 .map(|n| n.id)
5262 .collect();
5263 fb_inputs.sort();
5264 let fb_carry = fb_offsets[&fb_inputs[0]];
5265 let fb_xs: Vec<usize> = (1..fb_inputs.len())
5266 .map(|i| fb_offsets[&fb_inputs[i]])
5267 .collect();
5268 let fb_out = fb_offsets[&fb.outputs[0]];
5269 let mut fb_arena = crate::arena::Arena::from_plan(fb_plan);
5270 for n in fb.nodes() {
5271 if let Op::Constant { data } = &n.op
5272 && fb_arena.has_buffer(n.id)
5273 && !data.is_empty()
5274 {
5275 let off = fb_arena.byte_offset(n.id);
5282 let buf = fb_arena.raw_buf_mut();
5283 let nb = (buf.len() - off).min(data.len());
5284 buf[off..off + nb].copy_from_slice(&data[..nb]);
5285 }
5286 }
5287 let fb_init_bytes = fb_arena.raw_buf().to_vec();
5288 let fb_sched = compile_thunks_with_rng(fb, &fb_arena, rng);
5289 (
5290 Some(Arc::new(fb_sched)),
5291 Some(Arc::new(fb_init_bytes)),
5292 fb_carry,
5293 fb_out,
5294 fb_xs,
5295 )
5296 } else {
5297 (None, None, 0, 0, Vec::new())
5298 };
5299
5300 Thunk::ScanBackward {
5301 body_vjp: Arc::new(body_schedule),
5302 body_init: Arc::new(body_init),
5303 body_carry_in_off,
5304 body_x_offs: Arc::new(body_x_offs),
5305 body_d_output_off,
5306 body_dcarry_out_off,
5307 outer_init_off: node_offset(arena, node.inputs[0]),
5308 outer_traj_off: node_offset(arena, node.inputs[1]),
5309 outer_upstream_off: node_offset(arena, node.inputs[2]),
5310 outer_xs_offs: Arc::new(outer_xs_offs),
5311 outer_dinit_off: node_offset(arena, node.id),
5312 length: *length,
5313 carry_bytes: carry_bytes as u32,
5314 carry_elem_size,
5315 save_trajectory: *save_trajectory,
5316 num_checkpoints: *num_checkpoints,
5317 forward_body: fb_schedule,
5318 forward_body_init: fb_init,
5319 forward_body_carry_in_off: fb_carry_in_off,
5320 forward_body_output_off: fb_output_off,
5321 forward_body_x_offs: Arc::new(fb_x_offs),
5322 }
5323 }
5324
5325 Op::ScanBackwardXs {
5326 body_vjp,
5327 length,
5328 save_trajectory,
5329 num_xs,
5330 xs_idx,
5331 num_checkpoints,
5332 forward_body,
5333 } => {
5334 assert!(
5335 *num_checkpoints == 0 || *num_checkpoints <= *length,
5336 "Op::ScanBackwardXs: num_checkpoints={} must be 0 or ≤ length={}",
5337 *num_checkpoints,
5338 *length
5339 );
5340 let is_recursive = *num_checkpoints != 0 && *num_checkpoints != *length;
5341 if is_recursive {
5342 assert!(
5343 forward_body.is_some(),
5344 "Op::ScanBackwardXs with num_checkpoints<length \
5345 requires forward_body"
5346 );
5347 }
5348 let body_plan = rlx_opt::memory::plan_memory(body_vjp);
5356 let body_offsets: HashMap<NodeId, usize> = body_plan
5357 .assignments
5358 .iter()
5359 .map(|(id, slot)| (*id, slot.offset))
5360 .collect();
5361 let mut body_d_output_off: Option<usize> = None;
5362 let mut body_other_inputs: Vec<(NodeId, usize)> = Vec::new();
5363 for n in body_vjp.nodes() {
5364 if let Op::Input { name } = &n.op {
5365 let off = body_offsets[&n.id];
5366 if name == "d_output" {
5367 body_d_output_off = Some(off);
5368 } else {
5369 body_other_inputs.push((n.id, off));
5370 }
5371 }
5372 }
5373 body_other_inputs.sort_by_key(|(id, _)| *id);
5374 let body_d_output_off =
5375 body_d_output_off.expect("ScanBackwardXs body_vjp missing 'd_output' Input");
5376 let expected_others = 1 + *num_xs as usize;
5377 assert_eq!(
5378 body_other_inputs.len(),
5379 expected_others,
5380 "ScanBackwardXs body_vjp has {} non-d_output Inputs; expected {}",
5381 body_other_inputs.len(),
5382 expected_others
5383 );
5384 let body_carry_in_off = body_other_inputs[0].1;
5385 let body_x_offs: Vec<usize> = body_other_inputs
5386 .iter()
5387 .skip(1)
5388 .map(|(_, off)| *off)
5389 .collect();
5390 let body_dcarry_out_off = body_offsets[&body_vjp.outputs[0]];
5391 let dxs_out_node = body_vjp.outputs[1 + *xs_idx as usize];
5392 let body_dxs_out_off = body_offsets[&dxs_out_node];
5393
5394 let mut body_arena = crate::arena::Arena::from_plan(body_plan);
5395 for n in body_vjp.nodes() {
5396 if let Op::Constant { data } = &n.op
5397 && body_arena.has_buffer(n.id)
5398 && !data.is_empty()
5399 {
5400 match n.shape.dtype() {
5401 rlx_ir::DType::F64 => {
5402 let off = body_arena.byte_offset(n.id);
5403 let buf = body_arena.raw_buf_mut();
5404 let nb = (buf.len() - off).min(data.len());
5405 buf[off..off + nb].copy_from_slice(&data[..nb]);
5406 }
5407 _ => {
5408 let buf = body_arena.slice_mut(n.id);
5409 let nf = data.len() / 4;
5410 let nl = buf.len().min(nf);
5411 for i in 0..nl {
5412 let bytes = [
5413 data[i * 4],
5414 data[i * 4 + 1],
5415 data[i * 4 + 2],
5416 data[i * 4 + 3],
5417 ];
5418 buf[i] = f32::from_le_bytes(bytes);
5419 }
5420 }
5421 }
5422 }
5423 }
5424 let body_init = body_arena.raw_buf().to_vec();
5425 let body_schedule = compile_thunks_with_rng(body_vjp, &body_arena, rng);
5426
5427 let carry_bytes = body_vjp
5428 .node(body_vjp.outputs[0])
5429 .shape
5430 .size_bytes()
5431 .expect("ScanBackwardXs dcarry must be statically shaped");
5432 let carry_elem_size = body_vjp
5433 .node(body_vjp.outputs[0])
5434 .shape
5435 .dtype()
5436 .size_bytes() as u32;
5437 let per_step_bytes = body_vjp
5438 .node(dxs_out_node)
5439 .shape
5440 .size_bytes()
5441 .expect("ScanBackwardXs dxs body output must be statically shaped");
5442
5443 let mut outer_xs_offs: Vec<(usize, u32)> = Vec::with_capacity(*num_xs as usize);
5444 for i in 0..*num_xs as usize {
5445 let outer_xs_id = node.inputs[3 + i];
5446 let outer_xs_off = node_offset(arena, outer_xs_id);
5447 let outer_xs_shape = &graph.node(outer_xs_id).shape;
5448 let total = outer_xs_shape
5449 .size_bytes()
5450 .expect("ScanBackwardXs xs must have static shape");
5451 let per_step = total / *length as usize;
5452 outer_xs_offs.push((outer_xs_off, per_step as u32));
5453 }
5454
5455 let (fb_schedule, fb_init, fb_carry_in_off, fb_output_off, fb_x_offs) =
5458 if is_recursive {
5459 let fb = forward_body.as_ref().unwrap();
5460 let fb_plan = rlx_opt::memory::plan_memory(fb);
5461 let fb_offsets: HashMap<NodeId, usize> = fb_plan
5462 .assignments
5463 .iter()
5464 .map(|(id, slot)| (*id, slot.offset))
5465 .collect();
5466 let mut fb_inputs: Vec<NodeId> = fb
5467 .nodes()
5468 .iter()
5469 .filter(|n| matches!(n.op, Op::Input { .. }))
5470 .map(|n| n.id)
5471 .collect();
5472 fb_inputs.sort();
5473 let fb_carry = fb_offsets[&fb_inputs[0]];
5474 let fb_xs: Vec<usize> = (1..fb_inputs.len())
5475 .map(|i| fb_offsets[&fb_inputs[i]])
5476 .collect();
5477 let fb_out = fb_offsets[&fb.outputs[0]];
5478 let mut fb_arena = crate::arena::Arena::from_plan(fb_plan);
5479 for n in fb.nodes() {
5480 if let Op::Constant { data } = &n.op
5481 && fb_arena.has_buffer(n.id)
5482 && !data.is_empty()
5483 {
5484 let off = fb_arena.byte_offset(n.id);
5491 let buf = fb_arena.raw_buf_mut();
5492 let nb = (buf.len() - off).min(data.len());
5493 buf[off..off + nb].copy_from_slice(&data[..nb]);
5494 }
5495 }
5496 let fb_init_bytes = fb_arena.raw_buf().to_vec();
5497 let fb_sched = compile_thunks_with_rng(fb, &fb_arena, rng);
5498 (
5499 Some(Arc::new(fb_sched)),
5500 Some(Arc::new(fb_init_bytes)),
5501 fb_carry,
5502 fb_out,
5503 fb_xs,
5504 )
5505 } else {
5506 (None, None, 0, 0, Vec::new())
5507 };
5508
5509 Thunk::ScanBackwardXs {
5510 body_vjp: Arc::new(body_schedule),
5511 body_init: Arc::new(body_init),
5512 body_carry_in_off,
5513 body_x_offs: Arc::new(body_x_offs),
5514 body_d_output_off,
5515 body_dcarry_out_off,
5516 body_dxs_out_off,
5517 outer_init_off: node_offset(arena, node.inputs[0]),
5518 outer_traj_off: node_offset(arena, node.inputs[1]),
5519 outer_upstream_off: node_offset(arena, node.inputs[2]),
5520 outer_xs_offs: Arc::new(outer_xs_offs),
5521 outer_dxs_off: node_offset(arena, node.id),
5522 length: *length,
5523 carry_bytes: carry_bytes as u32,
5524 carry_elem_size,
5525 per_step_bytes: per_step_bytes as u32,
5526 save_trajectory: *save_trajectory,
5527 num_checkpoints: *num_checkpoints,
5528 forward_body: fb_schedule,
5529 forward_body_init: fb_init,
5530 forward_body_carry_in_off: fb_carry_in_off,
5531 forward_body_output_off: fb_output_off,
5532 forward_body_x_offs: Arc::new(fb_x_offs),
5533 }
5534 }
5535
5536 Op::Concat { axis } => {
5537 let out_shape = &node.shape;
5541 let rank = out_shape.rank();
5542 let outer: usize = (0..*axis)
5543 .map(|i| out_shape.dim(i).unwrap_static())
5544 .product::<usize>()
5545 .max(1);
5546 let inner: usize = (*axis + 1..rank)
5547 .map(|i| out_shape.dim(i).unwrap_static())
5548 .product::<usize>()
5549 .max(1);
5550 let total_axis = out_shape.dim(*axis).unwrap_static();
5551 let inputs: Vec<(usize, u32, u32)> = node
5552 .inputs
5553 .iter()
5554 .map(|&in_id| {
5555 let in_shape = &graph.node(in_id).shape;
5556 let in_axis = concat_axis_extent(in_shape, *axis, rank);
5557 let in_numel = in_shape.num_elements().unwrap_or(0) as u32;
5558 (node_offset(arena, in_id), in_axis as u32, in_numel)
5559 })
5560 .collect();
5561 let dst = node_offset(arena, node.id);
5562 match out_shape.dtype() {
5563 rlx_ir::DType::F64 => Thunk::ConcatF64 {
5564 dst,
5565 outer: outer as u32,
5566 inner: inner as u32,
5567 total_axis: total_axis as u32,
5568 inputs,
5569 },
5570 _ => Thunk::Concat {
5571 dst,
5572 outer: outer as u32,
5573 inner: inner as u32,
5574 total_axis: total_axis as u32,
5575 inputs,
5576 },
5577 }
5578 }
5579
5580 Op::GaussianSplatRender {
5581 width,
5582 height,
5583 tile_size,
5584 radius_scale,
5585 alpha_cutoff,
5586 max_splat_steps,
5587 transmittance_threshold,
5588 max_list_entries,
5589 } => {
5590 let elem_len =
5591 |id: NodeId| -> usize { graph.node(id).shape.num_elements().unwrap_or(0) };
5592 Thunk::GaussianSplatRender {
5593 positions_off: node_offset(arena, node.inputs[0]),
5594 positions_len: elem_len(node.inputs[0]),
5595 scales_off: node_offset(arena, node.inputs[1]),
5596 scales_len: elem_len(node.inputs[1]),
5597 rotations_off: node_offset(arena, node.inputs[2]),
5598 rotations_len: elem_len(node.inputs[2]),
5599 opacities_off: node_offset(arena, node.inputs[3]),
5600 opacities_len: elem_len(node.inputs[3]),
5601 colors_off: node_offset(arena, node.inputs[4]),
5602 colors_len: elem_len(node.inputs[4]),
5603 sh_coeffs_off: node_offset(arena, node.inputs[5]),
5604 sh_coeffs_len: elem_len(node.inputs[5]),
5605 meta_off: node_offset(arena, node.inputs[6]),
5606 dst_off: node_offset(arena, node.id),
5607 dst_len: node.shape.num_elements().unwrap_or(0),
5608 width: *width,
5609 height: *height,
5610 tile_size: *tile_size,
5611 radius_scale: *radius_scale,
5612 alpha_cutoff: *alpha_cutoff,
5613 max_splat_steps: *max_splat_steps,
5614 transmittance_threshold: *transmittance_threshold,
5615 max_list_entries: *max_list_entries,
5616 }
5617 }
5618
5619 Op::GaussianSplatRenderBackward {
5620 width,
5621 height,
5622 tile_size,
5623 radius_scale,
5624 alpha_cutoff,
5625 max_splat_steps,
5626 transmittance_threshold,
5627 max_list_entries,
5628 loss_grad_clip,
5629 sh_band,
5630 max_anisotropy,
5631 } => {
5632 let elem_len =
5633 |id: NodeId| -> usize { graph.node(id).shape.num_elements().unwrap_or(0) };
5634 Thunk::GaussianSplatRenderBackward {
5635 positions_off: node_offset(arena, node.inputs[0]),
5636 positions_len: elem_len(node.inputs[0]),
5637 scales_off: node_offset(arena, node.inputs[1]),
5638 scales_len: elem_len(node.inputs[1]),
5639 rotations_off: node_offset(arena, node.inputs[2]),
5640 rotations_len: elem_len(node.inputs[2]),
5641 opacities_off: node_offset(arena, node.inputs[3]),
5642 opacities_len: elem_len(node.inputs[3]),
5643 colors_off: node_offset(arena, node.inputs[4]),
5644 colors_len: elem_len(node.inputs[4]),
5645 sh_coeffs_off: node_offset(arena, node.inputs[5]),
5646 sh_coeffs_len: elem_len(node.inputs[5]),
5647 meta_off: node_offset(arena, node.inputs[6]),
5648 d_loss_off: node_offset(arena, node.inputs[7]),
5649 d_loss_len: elem_len(node.inputs[7]),
5650 packed_off: node_offset(arena, node.id),
5651 packed_len: node.shape.num_elements().unwrap_or(0),
5652 width: *width,
5653 height: *height,
5654 tile_size: *tile_size,
5655 radius_scale: *radius_scale,
5656 alpha_cutoff: *alpha_cutoff,
5657 max_splat_steps: *max_splat_steps,
5658 transmittance_threshold: *transmittance_threshold,
5659 max_list_entries: *max_list_entries,
5660 loss_grad_clip: *loss_grad_clip,
5661 sh_band: *sh_band,
5662 max_anisotropy: *max_anisotropy,
5663 }
5664 }
5665
5666 Op::GaussianSplatPrepare {
5667 width,
5668 height,
5669 tile_size,
5670 radius_scale,
5671 alpha_cutoff,
5672 max_splat_steps,
5673 transmittance_threshold,
5674 max_list_entries,
5675 } => {
5676 let elem_len =
5677 |id: NodeId| -> usize { graph.node(id).shape.num_elements().unwrap_or(0) };
5678 Thunk::GaussianSplatPrepare {
5679 positions_off: node_offset(arena, node.inputs[0]),
5680 positions_len: elem_len(node.inputs[0]),
5681 scales_off: node_offset(arena, node.inputs[1]),
5682 scales_len: elem_len(node.inputs[1]),
5683 rotations_off: node_offset(arena, node.inputs[2]),
5684 rotations_len: elem_len(node.inputs[2]),
5685 opacities_off: node_offset(arena, node.inputs[3]),
5686 opacities_len: elem_len(node.inputs[3]),
5687 colors_off: node_offset(arena, node.inputs[4]),
5688 colors_len: elem_len(node.inputs[4]),
5689 sh_coeffs_off: node_offset(arena, node.inputs[5]),
5690 sh_coeffs_len: elem_len(node.inputs[5]),
5691 meta_off: node_offset(arena, node.inputs[6]),
5692 meta_len: elem_len(node.inputs[6]),
5693 prep_off: node_offset(arena, node.id),
5694 prep_len: node.shape.num_elements().unwrap_or(0),
5695 width: *width,
5696 height: *height,
5697 tile_size: *tile_size,
5698 radius_scale: *radius_scale,
5699 alpha_cutoff: *alpha_cutoff,
5700 max_splat_steps: *max_splat_steps,
5701 transmittance_threshold: *transmittance_threshold,
5702 max_list_entries: *max_list_entries,
5703 }
5704 }
5705
5706 Op::GaussianSplatRasterize {
5707 width,
5708 height,
5709 tile_size,
5710 alpha_cutoff,
5711 max_splat_steps,
5712 transmittance_threshold,
5713 max_list_entries,
5714 } => {
5715 let elem_len =
5716 |id: NodeId| -> usize { graph.node(id).shape.num_elements().unwrap_or(0) };
5717 let prep_id = node.inputs[0];
5718 let count = match &graph.node(prep_id).op {
5719 rlx_ir::Op::GaussianSplatPrepare { .. } => {
5720 elem_len(graph.node(prep_id).inputs[0]) / 3
5721 }
5722 _ => 1,
5723 };
5724 Thunk::GaussianSplatRasterize {
5725 prep_off: node_offset(arena, prep_id),
5726 prep_len: elem_len(prep_id),
5727 meta_off: node_offset(arena, node.inputs[1]),
5728 meta_len: elem_len(node.inputs[1]),
5729 dst_off: node_offset(arena, node.id),
5730 dst_len: node.shape.num_elements().unwrap_or(0),
5731 count,
5732 width: *width,
5733 height: *height,
5734 tile_size: *tile_size,
5735 alpha_cutoff: *alpha_cutoff,
5736 max_splat_steps: *max_splat_steps,
5737 transmittance_threshold: *transmittance_threshold,
5738 max_list_entries: *max_list_entries,
5739 }
5740 }
5741
5742 Op::Custom { name, attrs, .. } => {
5743 let kernel = crate::op_registry::lookup_cpu_kernel(name).unwrap_or_else(|| {
5744 panic!(
5745 "compile_thunks: no CPU kernel registered for \
5746 Op::Custom('{name}'). Register one via \
5747 rlx_cpu::op_registry::register_cpu_kernel \
5748 before compiling on the CPU backend."
5749 )
5750 });
5751 let inputs_v: Vec<(usize, u32, Shape)> = node
5752 .inputs
5753 .iter()
5754 .map(|&in_id| {
5755 let s = graph.node(in_id).shape.clone();
5756 let len = s.num_elements().unwrap_or(0) as u32;
5757 (node_offset(arena, in_id), len, s)
5758 })
5759 .collect();
5760 let out_len = node.shape.num_elements().unwrap_or(0) as u32;
5761 Thunk::CustomOp {
5762 kernel,
5763 inputs: inputs_v,
5764 output: (node_offset(arena, node.id), out_len, node.shape.clone()),
5765 attrs: attrs.clone(),
5766 }
5767 }
5768
5769 Op::Fft { inverse, norm } => {
5770 let shape = &node.shape;
5771 let meta = rlx_ir::fft::fft_meta(shape);
5772 let dtype = shape.dtype();
5773 assert!(
5774 matches!(
5775 dtype,
5776 rlx_ir::DType::F32 | rlx_ir::DType::F64 | rlx_ir::DType::C64
5777 ),
5778 "Op::Fft on CPU requires F32, F64, or C64, got {dtype:?}"
5779 );
5780 Thunk::Fft1d {
5781 src: node_offset(arena, node.inputs[0]),
5782 dst: node_offset(arena, node.id),
5783 outer: meta.outer as u32,
5784 n_complex: meta.n_complex as u32,
5785 inverse: *inverse,
5786 norm_tag: norm.tag(),
5787 dtype,
5788 }
5789 }
5790
5791 Op::FftButterflyStage { stage, n_fft } => {
5792 let state_shape = graph.node(node.inputs[0]).shape.clone();
5793 assert_eq!(
5794 state_shape.dtype(),
5795 rlx_ir::DType::F32,
5796 "Op::FftButterflyStage requires F32 state"
5797 );
5798 let batch = state_shape.dim(0).unwrap_static() as u32;
5799 Thunk::FftButterflyStage {
5800 state_src: node_offset(arena, node.inputs[0]),
5801 state_dst: node_offset(arena, node.id),
5802 gate_src: node_offset(arena, node.inputs[1]),
5803 rev_src: node_offset(arena, node.inputs[2]),
5804 tw_re_src: node_offset(arena, node.inputs[3]),
5805 tw_im_src: node_offset(arena, node.inputs[4]),
5806 batch,
5807 n_fft: *n_fft,
5808 stage: *stage,
5809 }
5810 }
5811
5812 Op::LogMel => {
5813 let spec_shape = graph.node(node.inputs[0]).shape.clone();
5814 let filt_shape = graph.node(node.inputs[1]).shape.clone();
5815 let meta = rlx_ir::audio::log_mel_meta(&spec_shape, &filt_shape)
5816 .unwrap_or_else(|e| panic!("Op::LogMel: {e}"));
5817 Thunk::LogMel {
5818 spec: node_offset(arena, node.inputs[0]),
5819 filters: node_offset(arena, node.inputs[1]),
5820 dst: node_offset(arena, node.id),
5821 outer: meta.outer as u32,
5822 n_fft: meta.n_fft as u32,
5823 n_bins: meta.n_bins as u32,
5824 n_mels: meta.n_mels as u32,
5825 }
5826 }
5827
5828 Op::LogMelBackward => {
5829 let spec_shape = graph.node(node.inputs[0]).shape.clone();
5830 let filt_shape = graph.node(node.inputs[1]).shape.clone();
5831 let meta = rlx_ir::audio::log_mel_meta(&spec_shape, &filt_shape)
5832 .unwrap_or_else(|e| panic!("Op::LogMelBackward: {e}"));
5833 Thunk::LogMelBackward {
5834 spec: node_offset(arena, node.inputs[0]),
5835 filters: node_offset(arena, node.inputs[1]),
5836 dy: node_offset(arena, node.inputs[2]),
5837 dst: node_offset(arena, node.id),
5838 outer: meta.outer as u32,
5839 n_fft: meta.n_fft as u32,
5840 n_bins: meta.n_bins as u32,
5841 n_mels: meta.n_mels as u32,
5842 }
5843 }
5844
5845 Op::WelchPeaks { k, n_segments } => {
5846 let spec_shape = graph.node(node.inputs[0]).shape.clone();
5847 let meta = rlx_ir::audio::welch_peaks_meta(&spec_shape, *k, *n_segments)
5848 .unwrap_or_else(|e| panic!("Op::WelchPeaks: {e}"));
5849 Thunk::WelchPeaks {
5850 spec: node_offset(arena, node.inputs[0]),
5851 dst: node_offset(arena, node.id),
5852 welch_batch: meta.welch_batch as u32,
5853 n_fft: meta.n_fft as u32,
5854 n_segments: meta.n_segments as u32,
5855 k: meta.k as u32,
5856 }
5857 }
5858
5859 Op::CustomFn {
5860 fwd_body,
5861 num_inputs,
5862 ..
5863 } => {
5864 let body_plan = rlx_opt::memory::plan_memory(fwd_body);
5870 let body_offsets: HashMap<NodeId, usize> = body_plan
5871 .assignments
5872 .iter()
5873 .map(|(id, slot)| (*id, slot.offset))
5874 .collect();
5875
5876 let mut body_input_ids: Vec<NodeId> = fwd_body
5877 .nodes()
5878 .iter()
5879 .filter(|n| matches!(n.op, Op::Input { .. }))
5880 .map(|n| n.id)
5881 .collect();
5882 body_input_ids.sort();
5883 assert_eq!(
5884 body_input_ids.len(),
5885 *num_inputs as usize,
5886 "Op::CustomFn fwd_body has {} Op::Input(s); declared num_inputs={}",
5887 body_input_ids.len(),
5888 *num_inputs,
5889 );
5890
5891 let mut body_arena = crate::arena::Arena::from_plan(body_plan);
5892 for n in fwd_body.nodes() {
5893 if let Op::Constant { data } = &n.op
5894 && body_arena.has_buffer(n.id)
5895 && !data.is_empty()
5896 {
5897 match n.shape.dtype() {
5898 rlx_ir::DType::F64 => {
5899 let off = body_arena.byte_offset(n.id);
5900 let buf = body_arena.raw_buf_mut();
5901 let nb = (buf.len() - off).min(data.len());
5902 buf[off..off + nb].copy_from_slice(&data[..nb]);
5903 }
5904 _ => {
5905 let buf = body_arena.slice_mut(n.id);
5906 let nf = data.len() / 4;
5907 let nl = buf.len().min(nf);
5908 for i in 0..nl {
5909 let bytes = [
5910 data[i * 4],
5911 data[i * 4 + 1],
5912 data[i * 4 + 2],
5913 data[i * 4 + 3],
5914 ];
5915 buf[i] = f32::from_le_bytes(bytes);
5916 }
5917 }
5918 }
5919 }
5920 }
5921 let body_init = body_arena.raw_buf().to_vec();
5922 let body_schedule = compile_thunks_with_rng(fwd_body, &body_arena, rng);
5923
5924 let inputs_v: Vec<(usize, usize, u32)> = (0..*num_inputs as usize)
5926 .map(|i| {
5927 let body_in = body_input_ids[i];
5928 let body_off = body_offsets[&body_in];
5929 let outer_in = node.inputs[i];
5930 let outer_off = node_offset(arena, outer_in);
5931 let bytes = graph
5932 .node(outer_in)
5933 .shape
5934 .size_bytes()
5935 .expect("Op::CustomFn primal input must have static shape");
5936 (body_off, outer_off, bytes as u32)
5937 })
5938 .collect();
5939
5940 let body_output_id = fwd_body
5941 .outputs
5942 .first()
5943 .copied()
5944 .expect("Op::CustomFn fwd_body must declare exactly one output");
5945 let body_output_off = body_offsets[&body_output_id];
5946 let out_bytes = node
5947 .shape
5948 .size_bytes()
5949 .expect("Op::CustomFn output must have static shape");
5950
5951 Thunk::CustomFn {
5952 body: Arc::new(body_schedule),
5953 body_init: Arc::new(body_init),
5954 inputs: Arc::new(inputs_v),
5955 body_output_off,
5956 outer_output_off: node_offset(arena, node.id),
5957 out_bytes: out_bytes as u32,
5958 }
5959 }
5960
5961 _ => Thunk::Nop,
5962 };
5963 thunks.push(t);
5964 }
5965
5966 let cfg = crate::config::RuntimeConfig::global();
5967 let mask_thr = cfg.mask_binary_threshold;
5968 let mask_neg = cfg.attn_mask_neg_inf;
5969 let score_skip = cfg.score_skip_threshold;
5970
5971 let compiled_fns: Vec<Arc<dyn Fn(*mut u8) + Send + Sync>> = thunks
5973 .iter()
5974 .filter(|t| !matches!(t, Thunk::Nop))
5975 .map(|thunk| {
5976 match thunk.clone() {
5977 Thunk::Nop => Arc::new(|_: *mut u8| {}) as Arc<dyn Fn(*mut u8) + Send + Sync>,
5978
5979 Thunk::Sgemm { a, b, c, m, k, n } => {
5980 let (m, k, n) = (m as usize, k as usize, n as usize);
5981 Arc::new(move |base: *mut u8| unsafe {
5982 crate::blas::sgemm(
5983 sl(a, base, m * k),
5984 sl(b, base, k * n),
5985 sl_mut(c, base, m * n),
5986 m,
5987 k,
5988 n,
5989 );
5990 })
5991 }
5992
5993 Thunk::CgemmC64 { a, b, c, m, k, n } => {
5994 let (m, k, n) = (m as usize, k as usize, n as usize);
5995 Arc::new(move |base: *mut u8| unsafe {
5996 cgemm_c64(a, b, c, m, k, n, base);
5997 })
5998 }
5999
6000 Thunk::DenseSolveF64 { a, b, x, n, nrhs } => {
6001 let (n_, nrhs_) = (n as usize, nrhs as usize);
6002 Arc::new(move |base: *mut u8| unsafe {
6003 let a_src = sl_f64(a, base, n_ * n_);
6004 let b_src = sl_f64(b, base, n_ * nrhs_);
6005 let mut a_scratch: Vec<f64> = a_src.to_vec();
6006 let mut x_buf: Vec<f64> = b_src.to_vec();
6007 let info = crate::blas::dgesv(&mut a_scratch, &mut x_buf, n_, nrhs_);
6008 if info != 0 {
6009 panic!("DenseSolveF64: singular (info={info})");
6010 }
6011 sl_mut_f64(x, base, n_ * nrhs_).copy_from_slice(&x_buf);
6012 })
6013 }
6014
6015 Thunk::DenseSolveF32 { a, b, x, n, nrhs } => {
6016 let (n_, nrhs_) = (n as usize, nrhs as usize);
6017 Arc::new(move |base: *mut u8| unsafe {
6018 let a_src = sl(a, base, n_ * n_);
6019 let b_src = sl(b, base, n_ * nrhs_);
6020 let mut a_scratch: Vec<f32> = a_src.to_vec();
6021 let mut x_buf: Vec<f32> = b_src.to_vec();
6022 let info = crate::blas::sgesv(&mut a_scratch, &mut x_buf, n_, nrhs_);
6023 if info != 0 {
6024 panic!("DenseSolveF32: singular (info={info})");
6025 }
6026 sl_mut(x, base, n_ * nrhs_).copy_from_slice(&x_buf);
6027 })
6028 }
6029
6030 Thunk::FusedMmBiasAct {
6031 a,
6032 w,
6033 bias,
6034 c,
6035 m,
6036 k,
6037 n,
6038 act,
6039 } => {
6040 let (m, k, n) = (m as usize, k as usize, n as usize);
6041 Arc::new(move |base: *mut u8| unsafe {
6042 let out = sl_mut(c, base, m * n);
6043 crate::blas::sgemm(sl(a, base, m * k), sl(w, base, k * n), out, m, k, n);
6044 match act {
6052 Some(Activation::Gelu) => {
6053 crate::kernels::par_bias_gelu(out, sl(bias, base, n), m, n)
6054 }
6055 Some(other) => {
6056 crate::blas::bias_add(out, sl(bias, base, n), m, n);
6057 apply_activation_inplace(out, other);
6058 }
6059 None => crate::blas::bias_add(out, sl(bias, base, n), m, n),
6060 }
6061 })
6062 }
6063
6064 Thunk::FusedResidualLN {
6065 x,
6066 res,
6067 bias,
6068 g,
6069 b,
6070 out,
6071 rows,
6072 h,
6073 eps,
6074 has_bias,
6075 } => {
6076 let (rows, h) = (rows as usize, h as usize);
6077 Arc::new(move |base: *mut u8| unsafe {
6078 let zero = vec![0f32; h]; let bi = if has_bias { sl(bias, base, h) } else { &zero };
6080 let xp = sl(x, base, rows * h).as_ptr() as usize;
6081 let rp = sl(res, base, rows * h).as_ptr() as usize;
6082 let op = sl_mut(out, base, rows * h).as_mut_ptr() as usize;
6083 let bp = bi.as_ptr() as usize;
6084 let gp = sl(g, base, h).as_ptr() as usize;
6085 let bbp = sl(b, base, h).as_ptr() as usize;
6086 crate::pool::par_for(rows, 4, &|off, cnt| {
6087 let xs = std::slice::from_raw_parts(
6088 (xp as *const f32).add(off * h),
6089 cnt * h,
6090 );
6091 let rs = std::slice::from_raw_parts(
6092 (rp as *const f32).add(off * h),
6093 cnt * h,
6094 );
6095 let os = std::slice::from_raw_parts_mut(
6096 (op as *mut f32).add(off * h),
6097 cnt * h,
6098 );
6099 let bi = std::slice::from_raw_parts(bp as *const f32, h);
6100 let g = std::slice::from_raw_parts(gp as *const f32, h);
6101 let b = std::slice::from_raw_parts(bbp as *const f32, h);
6102 crate::kernels::residual_bias_layer_norm(
6103 xs, rs, bi, g, b, os, cnt, h, eps,
6104 );
6105 });
6106 })
6107 }
6108
6109 Thunk::BiasAdd {
6110 src,
6111 bias,
6112 dst,
6113 m,
6114 n,
6115 } => {
6116 let (m, n) = (m as usize, n as usize);
6117 let len = m * n;
6118 Arc::new(move |base: *mut u8| unsafe {
6119 let out = sl_mut(dst, base, len);
6120 if src != dst {
6121 let src_ptr = base.add(src) as *const f32;
6122 let dst_ptr = base.add(dst) as *mut f32;
6123 if src_ptr != dst_ptr {
6124 std::ptr::copy_nonoverlapping(src_ptr, dst_ptr, len);
6125 }
6126 }
6127 crate::blas::bias_add(out, sl(bias, base, n), m, n);
6128 })
6129 }
6130
6131 Thunk::Gather {
6132 table,
6133 table_len,
6134 idx,
6135 dst,
6136 num_idx,
6137 trailing,
6138 idx_i64,
6139 table_bytes,
6140 } => {
6141 let (ni, tr, tl) = (num_idx as usize, trailing as usize, table_len as usize);
6142 let rows = tl / tr.max(1);
6143 let (idx_i64, table_bytes) = (idx_i64, table_bytes);
6144 Arc::new(move |base: *mut u8| unsafe {
6145 if table_bytes == 8 {
6146 let tab = sl_i64(table, base, tl);
6147 let out = sl_mut_i64(dst, base, ni * tr);
6148 if idx_i64 != 0 {
6149 let ids = sl_i64(idx, base, ni);
6150 for i in 0..ni {
6151 let row = ids[i].max(0) as usize;
6152 if row < rows {
6153 out[i * tr..(i + 1) * tr]
6154 .copy_from_slice(&tab[row * tr..(row + 1) * tr]);
6155 }
6156 }
6157 } else {
6158 let ids = sl(idx, base, ni);
6159 for i in 0..ni {
6160 let row = ids[i] as usize;
6161 if row < rows {
6162 out[i * tr..(i + 1) * tr]
6163 .copy_from_slice(&tab[row * tr..(row + 1) * tr]);
6164 }
6165 }
6166 }
6167 } else {
6168 let tab = sl(table, base, tl);
6169 let out = sl_mut(dst, base, ni * tr);
6170 if idx_i64 != 0 {
6171 let ids = sl_i64(idx, base, ni);
6172 for i in 0..ni {
6173 let row = ids[i].max(0) as usize;
6174 if row < rows {
6175 out[i * tr..(i + 1) * tr]
6176 .copy_from_slice(&tab[row * tr..(row + 1) * tr]);
6177 }
6178 }
6179 } else {
6180 let ids = sl(idx, base, ni);
6181 for i in 0..ni {
6182 let row = ids[i] as usize;
6183 if row < rows {
6184 out[i * tr..(i + 1) * tr]
6185 .copy_from_slice(&tab[row * tr..(row + 1) * tr]);
6186 }
6187 }
6188 }
6189 }
6190 })
6191 }
6192
6193 Thunk::Narrow {
6194 src,
6195 dst,
6196 outer,
6197 src_stride,
6198 dst_stride,
6199 inner,
6200 elem_bytes,
6201 } => {
6202 narrow_thunk_closure(src, dst, outer, src_stride, dst_stride, inner, elem_bytes)
6203 }
6204
6205 Thunk::Copy { src, dst, len } => {
6206 let len = len as usize;
6207 Arc::new(move |base: *mut u8| unsafe {
6208 if src == dst || len == 0 {
6209 return;
6210 }
6211 let src_ptr = base.add(src) as *const f32;
6212 let dst_ptr = base.add(dst) as *mut f32;
6213 if src_ptr == dst_ptr {
6214 return;
6215 }
6216 std::ptr::copy_nonoverlapping(src_ptr, dst_ptr, len);
6217 })
6218 }
6219
6220 Thunk::Softmax { data, rows, cols } => {
6221 let (rows, cols) = (rows as usize, cols as usize);
6222 Arc::new(move |base: *mut u8| unsafe {
6223 crate::naive::softmax(sl_mut(data, base, rows * cols), rows, cols);
6224 })
6225 }
6226
6227 Thunk::Cumsum {
6228 src,
6229 dst,
6230 rows,
6231 cols,
6232 exclusive,
6233 } => {
6234 let (rows, cols) = (rows as usize, cols as usize);
6235 Arc::new(move |base: *mut u8| unsafe {
6236 let s = sl(src, base, rows * cols);
6237 let d = sl_mut(dst, base, rows * cols);
6238 if exclusive {
6239 for r in 0..rows {
6240 let mut acc = 0.0f32;
6241 for c in 0..cols {
6242 d[r * cols + c] = acc;
6243 acc += s[r * cols + c];
6244 }
6245 }
6246 } else {
6247 for r in 0..rows {
6248 let mut acc = 0.0f32;
6249 for c in 0..cols {
6250 acc += s[r * cols + c];
6251 d[r * cols + c] = acc;
6252 }
6253 }
6254 }
6255 })
6256 }
6257
6258 Thunk::Sample {
6259 logits,
6260 dst,
6261 batch,
6262 vocab,
6263 top_k,
6264 top_p,
6265 temperature,
6266 seed,
6267 } => {
6268 let (b, v) = (batch as usize, vocab as usize);
6269 let k = (top_k as usize).min(v);
6270 Arc::new(move |base: *mut u8| unsafe {
6271 let lg = sl(logits, base, b * v);
6272 let out = sl_mut(dst, base, b);
6273 let mut rng =
6274 rlx_ir::Philox4x32::new(if seed == 0 { 0xDEADBEEF } else { seed });
6275 for bi in 0..b {
6276 let row = &lg[bi * v..(bi + 1) * v];
6277 out[bi] = sample_row(row, k, top_p, temperature, &mut rng) as f32;
6278 }
6279 })
6280 }
6281
6282 Thunk::RngNormal {
6283 dst,
6284 len,
6285 mean,
6286 scale,
6287 key,
6288 op_seed,
6289 } => {
6290 let n = len as usize;
6291 let rng = rng_shared.clone();
6292 Arc::new(move |base: *mut u8| unsafe {
6293 let out = sl_mut(dst, base, n);
6294 let opts = *rng.read().unwrap();
6295 rlx_ir::fill_normal_like(out, mean, scale, opts, key, op_seed);
6296 })
6297 }
6298
6299 Thunk::RngUniform {
6300 dst,
6301 len,
6302 low,
6303 high,
6304 key,
6305 op_seed,
6306 } => {
6307 let n = len as usize;
6308 let rng = rng_shared.clone();
6309 Arc::new(move |base: *mut u8| unsafe {
6310 let out = sl_mut(dst, base, n);
6311 let opts = *rng.read().unwrap();
6312 rlx_ir::fill_uniform_like(out, low, high, opts, key, op_seed);
6313 })
6314 }
6315
6316 Thunk::DequantMatMul {
6317 x,
6318 w_q,
6319 scale,
6320 zp,
6321 dst,
6322 m,
6323 k,
6324 n,
6325 block_size,
6326 is_asymmetric,
6327 } => {
6328 let (m, k, n, bs) = (m as usize, k as usize, n as usize, block_size as usize);
6329 let n_blocks_per_col = k.div_ceil(bs);
6330 Arc::new(move |base: *mut u8| unsafe {
6331 let xs = sl(x, base, m * k);
6332 let raw = base.add(w_q);
6334 let w_bytes = std::slice::from_raw_parts(raw as *const i8, k * n);
6335 let scales = sl(scale, base, n_blocks_per_col * n);
6336 let zps = if is_asymmetric {
6337 sl(zp, base, n_blocks_per_col * n)
6338 } else {
6339 &[][..]
6340 };
6341 let out = sl_mut(dst, base, m * n);
6342 dequant_matmul_int8(
6343 xs,
6344 w_bytes,
6345 scales,
6346 zps,
6347 out,
6348 m,
6349 k,
6350 n,
6351 bs,
6352 is_asymmetric,
6353 );
6354 })
6355 }
6356
6357 Thunk::DequantMatMulGguf {
6358 x,
6359 w_q,
6360 dst,
6361 m,
6362 k,
6363 n,
6364 scheme,
6365 } => {
6366 let (m, k, n) = (m as usize, k as usize, n as usize);
6367 let block_bytes = scheme.gguf_block_bytes() as usize;
6368 let block_elems = scheme.gguf_block_size() as usize;
6369 let total_bytes = (k * n) / block_elems * block_bytes;
6370 Arc::new(move |base: *mut u8| unsafe {
6371 let xs = sl(x, base, m * k);
6372 let w_bytes =
6373 std::slice::from_raw_parts(base.add(w_q) as *const u8, total_bytes);
6374 let out = sl_mut(dst, base, m * n);
6375 crate::gguf_matmul::gguf_matmul_bt(xs, w_bytes, out, m, k, n, scheme);
6376 })
6377 }
6378
6379 Thunk::DequantMatMulInt4 {
6380 x,
6381 w_q,
6382 scale,
6383 zp,
6384 dst,
6385 m,
6386 k,
6387 n,
6388 block_size,
6389 is_asymmetric,
6390 } => {
6391 let (m, k, n, bs) = (m as usize, k as usize, n as usize, block_size as usize);
6392 let n_blocks = k.div_ceil(bs);
6393 Arc::new(move |base: *mut u8| unsafe {
6394 let xs = sl(x, base, m * k);
6395 let w_bytes = std::slice::from_raw_parts(
6396 base.add(w_q) as *const u8,
6397 (k * n).div_ceil(2),
6398 );
6399 let scales = sl(scale, base, n_blocks * n);
6400 let zps = if is_asymmetric {
6401 sl(zp, base, n_blocks * n)
6402 } else {
6403 &[][..]
6404 };
6405 let out = sl_mut(dst, base, m * n);
6406 dequant_matmul_int4(
6407 xs,
6408 w_bytes,
6409 scales,
6410 zps,
6411 out,
6412 m,
6413 k,
6414 n,
6415 bs,
6416 is_asymmetric,
6417 );
6418 })
6419 }
6420
6421 Thunk::DequantMatMulFp8 {
6422 x,
6423 w_q,
6424 scale,
6425 dst,
6426 m,
6427 k,
6428 n,
6429 e5m2,
6430 } => {
6431 let (m, k, n) = (m as usize, k as usize, n as usize);
6432 Arc::new(move |base: *mut u8| unsafe {
6433 let xs = sl(x, base, m * k);
6434 let w_bytes = std::slice::from_raw_parts(base.add(w_q) as *const u8, k * n);
6435 let scales = sl(scale, base, n);
6436 let out = sl_mut(dst, base, m * n);
6437 dequant_matmul_fp8(xs, w_bytes, scales, out, m, k, n, e5m2);
6438 })
6439 }
6440
6441 Thunk::DequantMatMulNvfp4 {
6442 x,
6443 w_q,
6444 scale,
6445 global_scale,
6446 dst,
6447 m,
6448 k,
6449 n,
6450 } => {
6451 let (m, k, n) = (m as usize, k as usize, n as usize);
6452 let n_scale = k.div_ceil(rlx_ir::NVFP4_GROUP_SIZE) * n;
6453 Arc::new(move |base: *mut u8| unsafe {
6454 let xs = sl(x, base, m * k);
6455 let w_bytes = std::slice::from_raw_parts(
6456 base.add(w_q) as *const u8,
6457 (k * n).div_ceil(2),
6458 );
6459 let scale_bytes =
6460 std::slice::from_raw_parts(base.add(scale) as *const u8, n_scale);
6461 let gs = sl(global_scale, base, 1)[0];
6462 let out = sl_mut(dst, base, m * n);
6463 dequant_matmul_nvfp4(xs, w_bytes, scale_bytes, gs, out, m, k, n);
6464 })
6465 }
6466
6467 Thunk::LoraMatMul {
6468 x,
6469 w,
6470 a,
6471 b,
6472 dst,
6473 m,
6474 k,
6475 n,
6476 r,
6477 scale,
6478 } => {
6479 let (m, k, n, r) = (m as usize, k as usize, n as usize, r as usize);
6480 Arc::new(move |base: *mut u8| unsafe {
6481 let xs = sl(x, base, m * k);
6482 let ws = sl(w, base, k * n);
6483 let a_s = sl(a, base, k * r);
6484 let bs = sl(b, base, r * n);
6485 let out = sl_mut(dst, base, m * n);
6486 crate::blas::sgemm(xs, ws, out, m, k, n);
6488 let mut tmp = vec![0f32; m * r];
6490 crate::blas::sgemm(xs, a_s, &mut tmp, m, k, r);
6491 if scale != 1.0 {
6495 for v in tmp.iter_mut() {
6496 *v *= scale;
6497 }
6498 }
6499 crate::blas::sgemm_accumulate(&tmp, bs, out, m, r, n);
6500 })
6501 }
6502
6503 Thunk::LayerNorm {
6504 src,
6505 g,
6506 b,
6507 dst,
6508 rows,
6509 h,
6510 eps,
6511 } => {
6512 let (rows, h) = (rows as usize, h as usize);
6513 Arc::new(move |base: *mut u8| unsafe {
6514 let inp = sl(src, base, rows * h);
6515 let gamma = sl(g, base, h);
6516 let beta = sl(b, base, h);
6517 let out = sl_mut(dst, base, rows * h);
6518 for row in 0..rows {
6519 crate::kernels::layer_norm_row(
6520 &inp[row * h..(row + 1) * h],
6521 gamma,
6522 beta,
6523 &mut out[row * h..(row + 1) * h],
6524 h,
6525 eps,
6526 );
6527 }
6528 })
6529 }
6530
6531 Thunk::BatchNormInference {
6532 src,
6533 g,
6534 b,
6535 mean,
6536 var,
6537 dst,
6538 count,
6539 channels,
6540 eps,
6541 } => {
6542 let count = count as usize;
6543 let c = channels as usize;
6544 let n = count * c;
6545 let (src, g, b, mean, var, dst) = (src, g, b, mean, var, dst);
6546 Arc::new(move |base: *mut u8| unsafe {
6547 crate::kernels::batch_norm_inference(
6548 sl(src, base, n),
6549 sl(g, base, c),
6550 sl(b, base, c),
6551 sl(mean, base, c),
6552 sl(var, base, c),
6553 sl_mut(dst, base, n),
6554 c,
6555 eps,
6556 );
6557 })
6558 }
6559
6560 Thunk::Attention {
6561 q,
6562 k,
6563 v,
6564 mask,
6565 out,
6566 batch,
6567 seq,
6568 kv_seq,
6569 heads,
6570 head_dim,
6571 mask_kind,
6572 scale,
6573 q_row_stride,
6574 k_row_stride,
6575 v_row_stride,
6576 bhsd,
6577 } => {
6578 if std::env::var("RLX_ATTN_DEBUG").is_ok() {
6579 eprintln!("[attn-compile] batch={batch} seq={seq} kv_seq={kv_seq} heads={heads} bhsd={bhsd}");
6580 }
6581 let (b, q_s, k_s, nh, dh) = (
6590 batch as usize,
6591 seq as usize,
6592 kv_seq as usize,
6593 heads as usize,
6594 head_dim as usize,
6595 );
6596 let hs = nh * dh;
6597 let qrs = q_row_stride as usize;
6598 let krs = k_row_stride as usize;
6599 let vrs = v_row_stride as usize;
6600 Arc::new(move |base: *mut u8| unsafe {
6602 if std::env::var("RLX_ATTN_DEBUG").is_ok() {
6603 eprintln!("[attn] b={b} q_s={q_s} k_s={k_s} nh={nh} dh={dh} bhsd={bhsd} mask_kind={:?}", mask_kind);
6604 }
6605 let (q_len, k_len, v_len, o_len) = if bhsd {
6610 let qn = b * nh * q_s * dh;
6611 let kn = b * nh * k_s * dh;
6612 (qn, kn, kn, qn)
6613 } else {
6614 (b * q_s * qrs, b * k_s * krs, b * k_s * vrs, b * q_s * hs)
6615 };
6616 let q_d = sl(q, base, q_len);
6617 let k_d = sl(k, base, k_len);
6618 let v_d = sl(v, base, v_len);
6619 let m_d: &[f32] = match mask_kind {
6620 rlx_ir::op::MaskKind::Custom => sl(mask, base, b * k_s),
6621 rlx_ir::op::MaskKind::Bias => sl(mask, base, b * nh * q_s * k_s),
6622 _ => &[],
6623 };
6624 let o_d = sl_mut(out, base, o_len);
6625 let mut qh = vec![0f32; q_s * dh];
6626 let mut kh = vec![0f32; k_s * dh];
6627 let mut vh = vec![0f32; k_s * dh];
6628 let mut sc = vec![0f32; q_s * k_s];
6629 let mut oh = vec![0f32; q_s * dh];
6630 for bi in 0..b {
6631 for hi in 0..nh {
6632 for si in 0..q_s {
6634 let q_off = if bhsd {
6635 bi * nh * q_s * dh + hi * q_s * dh + si * dh
6636 } else {
6637 bi * q_s * qrs + si * qrs + hi * dh
6638 };
6639 qh[si * dh..(si + 1) * dh]
6640 .copy_from_slice(&q_d[q_off..q_off + dh]);
6641 }
6642 for si in 0..k_s {
6644 let (k_off, v_off) = if bhsd {
6645 (
6646 bi * nh * k_s * dh + hi * k_s * dh + si * dh,
6647 bi * nh * k_s * dh + hi * k_s * dh + si * dh,
6648 )
6649 } else {
6650 (
6651 bi * k_s * krs + si * krs + hi * dh,
6652 bi * k_s * vrs + si * vrs + hi * dh,
6653 )
6654 };
6655 kh[si * dh..(si + 1) * dh]
6656 .copy_from_slice(&k_d[k_off..k_off + dh]);
6657 vh[si * dh..(si + 1) * dh]
6658 .copy_from_slice(&v_d[v_off..v_off + dh]);
6659 }
6660 for qi in 0..q_s {
6661 for ki in 0..k_s {
6662 let mut dot = 0f32;
6663 for d in 0..dh {
6664 dot += qh[qi * dh + d] * kh[ki * dh + d];
6665 }
6666 sc[qi * k_s + ki] = dot * scale;
6667 }
6668 }
6669 let q_offset = k_s.saturating_sub(q_s);
6673 match mask_kind {
6674 rlx_ir::op::MaskKind::None => {}
6675 rlx_ir::op::MaskKind::Causal => {
6676 for qi in 0..q_s {
6677 let abs_q = q_offset + qi;
6678 for ki in (abs_q + 1)..k_s {
6679 sc[qi * k_s + ki] = mask_neg;
6680 }
6681 }
6682 }
6683 rlx_ir::op::MaskKind::SlidingWindow(w) => {
6684 for qi in 0..q_s {
6685 let abs_q = q_offset + qi;
6686 let lo = abs_q.saturating_sub(w);
6687 for ki in 0..k_s {
6688 if ki < lo || ki > abs_q {
6689 sc[qi * k_s + ki] = mask_neg;
6690 }
6691 }
6692 }
6693 }
6694 rlx_ir::op::MaskKind::Custom => {
6695 for qi in 0..q_s {
6696 for ki in 0..k_s {
6697 if m_d[bi * k_s + ki] < mask_thr {
6698 sc[qi * k_s + ki] = mask_neg;
6699 }
6700 }
6701 }
6702 }
6703 rlx_ir::op::MaskKind::Bias => {
6704 let per_bh = q_s * k_s;
6705 let off = (bi * nh + hi) * per_bh;
6706 for i in 0..per_bh {
6707 sc[i] += m_d[off + i];
6708 }
6709 }
6710 }
6711 crate::naive::softmax(&mut sc, q_s, k_s);
6712 oh.fill(0.0);
6713 for qi in 0..q_s {
6714 for ki in 0..k_s {
6715 let w = sc[qi * k_s + ki];
6716 if w > score_skip {
6717 for d in 0..dh {
6718 oh[qi * dh + d] += w * vh[ki * dh + d];
6719 }
6720 }
6721 }
6722 }
6723 for si in 0..q_s {
6724 let off = if bhsd {
6725 bi * nh * q_s * dh + hi * q_s * dh + si * dh
6726 } else {
6727 bi * q_s * hs + si * hs + hi * dh
6728 };
6729 o_d[off..off + dh].copy_from_slice(&oh[si * dh..(si + 1) * dh]);
6730 }
6731 }
6732 }
6733 })
6734 }
6735
6736 Thunk::FusedSwiGLU {
6737 src,
6738 dst,
6739 n_half,
6740 total,
6741 gate_first,
6742 } => {
6743 let n = n_half as usize;
6744 let t = total as usize;
6745 let outer = t / n;
6746 let in_total = outer * 2 * n;
6747 Arc::new(move |base: *mut u8| unsafe {
6748 let inp = sl(src, base, in_total);
6749 let out = sl_mut(dst, base, t);
6750 for o in 0..outer {
6751 let in_row = &inp[o * 2 * n..(o + 1) * 2 * n];
6752 let out_row = &mut out[o * n..(o + 1) * n];
6753 for i in 0..n {
6754 let (up, gate) = if gate_first {
6755 (in_row[n + i], in_row[i])
6756 } else {
6757 (in_row[i], in_row[n + i])
6758 };
6759 out_row[i] = up * (gate / (1.0 + (-gate).exp()));
6760 }
6761 }
6762 })
6763 }
6764
6765 Thunk::Concat {
6766 dst,
6767 outer,
6768 inner,
6769 total_axis,
6770 inputs,
6771 } => {
6772 let outer = outer as usize;
6773 let inner = inner as usize;
6774 let total_axis = total_axis as usize;
6775 let out_total = outer * total_axis * inner;
6776 let mut layout: Vec<(usize, usize, usize, usize)> =
6777 Vec::with_capacity(inputs.len());
6778 let mut cum: usize = 0;
6779 for (src_off, in_axis, in_numel) in &inputs {
6780 let in_axis = *in_axis as usize;
6781 layout.push((*src_off, cum * inner, in_axis * inner, *in_numel as usize));
6782 cum += in_axis;
6783 }
6784 Arc::new(move |base: *mut u8| unsafe {
6785 let out = sl_mut(dst, base, out_total);
6786 let row_stride = total_axis * inner;
6787 for (src_off, dst_col_off, copy_per_row, in_numel) in &layout {
6788 let inp = sl(*src_off, base, (*in_numel).max(1));
6789 concat_copy_rows_f32(
6790 out,
6791 inp,
6792 outer,
6793 *copy_per_row,
6794 row_stride,
6795 *dst_col_off,
6796 *in_numel,
6797 );
6798 }
6799 })
6800 }
6801
6802 Thunk::CustomOp {
6803 kernel,
6804 inputs,
6805 output,
6806 attrs,
6807 } => {
6808 let kernel = kernel.clone();
6814 let attrs = attrs.clone();
6815 let inputs = inputs.clone();
6816 let (out_off, out_len, out_shape) = output.clone();
6817 Arc::new(move |base: *mut u8| unsafe {
6818 dispatch_custom_op(
6819 &*kernel, &inputs, out_off, out_len, &out_shape, &attrs, base,
6820 );
6821 })
6822 }
6823
6824 Thunk::GaussianSplatRender {
6825 positions_off,
6826 positions_len,
6827 scales_off,
6828 scales_len,
6829 rotations_off,
6830 rotations_len,
6831 opacities_off,
6832 opacities_len,
6833 colors_off,
6834 colors_len,
6835 sh_coeffs_off,
6836 sh_coeffs_len,
6837 meta_off,
6838 dst_off,
6839 dst_len,
6840 width,
6841 height,
6842 tile_size,
6843 radius_scale,
6844 alpha_cutoff,
6845 max_splat_steps,
6846 transmittance_threshold,
6847 max_list_entries,
6848 } => Arc::new(move |base: *mut u8| unsafe {
6849 crate::splat::execute_gaussian_splat_render(
6850 positions_off,
6851 positions_len,
6852 scales_off,
6853 scales_len,
6854 rotations_off,
6855 rotations_len,
6856 opacities_off,
6857 opacities_len,
6858 colors_off,
6859 colors_len,
6860 sh_coeffs_off,
6861 sh_coeffs_len,
6862 meta_off,
6863 dst_off,
6864 dst_len,
6865 width,
6866 height,
6867 tile_size,
6868 radius_scale,
6869 alpha_cutoff,
6870 max_splat_steps,
6871 transmittance_threshold,
6872 max_list_entries,
6873 base,
6874 );
6875 }),
6876
6877 Thunk::GaussianSplatRenderBackward {
6878 positions_off,
6879 positions_len,
6880 scales_off,
6881 scales_len,
6882 rotations_off,
6883 rotations_len,
6884 opacities_off,
6885 opacities_len,
6886 colors_off,
6887 colors_len,
6888 sh_coeffs_off,
6889 sh_coeffs_len,
6890 meta_off,
6891 d_loss_off,
6892 d_loss_len,
6893 packed_off,
6894 packed_len,
6895 width,
6896 height,
6897 tile_size,
6898 radius_scale,
6899 alpha_cutoff,
6900 max_splat_steps,
6901 transmittance_threshold,
6902 max_list_entries,
6903 loss_grad_clip,
6904 sh_band,
6905 max_anisotropy,
6906 } => Arc::new(move |base: *mut u8| unsafe {
6907 crate::splat::execute_gaussian_splat_render_backward(
6908 positions_off,
6909 positions_len,
6910 scales_off,
6911 scales_len,
6912 rotations_off,
6913 rotations_len,
6914 opacities_off,
6915 opacities_len,
6916 colors_off,
6917 colors_len,
6918 sh_coeffs_off,
6919 sh_coeffs_len,
6920 meta_off,
6921 d_loss_off,
6922 d_loss_len,
6923 packed_off,
6924 packed_len,
6925 width,
6926 height,
6927 tile_size,
6928 radius_scale,
6929 alpha_cutoff,
6930 max_splat_steps,
6931 transmittance_threshold,
6932 max_list_entries,
6933 loss_grad_clip,
6934 sh_band,
6935 max_anisotropy,
6936 base,
6937 );
6938 }),
6939
6940 Thunk::GaussianSplatPrepare {
6941 positions_off,
6942 positions_len,
6943 scales_off,
6944 scales_len,
6945 rotations_off,
6946 rotations_len,
6947 opacities_off,
6948 opacities_len,
6949 colors_off,
6950 colors_len,
6951 sh_coeffs_off,
6952 sh_coeffs_len,
6953 meta_off,
6954 meta_len,
6955 prep_off,
6956 prep_len,
6957 width,
6958 height,
6959 tile_size,
6960 radius_scale,
6961 alpha_cutoff,
6962 max_splat_steps,
6963 transmittance_threshold,
6964 max_list_entries,
6965 } => Arc::new(move |base: *mut u8| unsafe {
6966 crate::splat::execute_gaussian_splat_prepare(
6967 positions_off,
6968 positions_len,
6969 scales_off,
6970 scales_len,
6971 rotations_off,
6972 rotations_len,
6973 opacities_off,
6974 opacities_len,
6975 colors_off,
6976 colors_len,
6977 sh_coeffs_off,
6978 sh_coeffs_len,
6979 meta_off,
6980 meta_len,
6981 prep_off,
6982 prep_len,
6983 width,
6984 height,
6985 tile_size,
6986 radius_scale,
6987 alpha_cutoff,
6988 max_splat_steps,
6989 transmittance_threshold,
6990 max_list_entries,
6991 base,
6992 );
6993 }),
6994
6995 Thunk::GaussianSplatRasterize {
6996 prep_off,
6997 prep_len,
6998 meta_off,
6999 meta_len,
7000 dst_off,
7001 dst_len,
7002 count,
7003 width,
7004 height,
7005 tile_size,
7006 alpha_cutoff,
7007 max_splat_steps,
7008 transmittance_threshold,
7009 max_list_entries,
7010 } => Arc::new(move |base: *mut u8| unsafe {
7011 crate::splat::execute_gaussian_splat_rasterize(
7012 prep_off,
7013 prep_len,
7014 meta_off,
7015 meta_len,
7016 dst_off,
7017 dst_len,
7018 count,
7019 width,
7020 height,
7021 tile_size,
7022 alpha_cutoff,
7023 max_splat_steps,
7024 transmittance_threshold,
7025 max_list_entries,
7026 base,
7027 );
7028 }),
7029
7030 Thunk::Fft1d {
7031 src,
7032 dst,
7033 outer,
7034 n_complex,
7035 inverse,
7036 norm_tag,
7037 dtype,
7038 } => {
7039 let f: Arc<dyn Fn(*mut u8) + Send + Sync> = match dtype {
7040 rlx_ir::DType::F64 => Arc::new(move |base: *mut u8| unsafe {
7041 execute_fft1d_f64(
7042 src,
7043 dst,
7044 outer as usize,
7045 n_complex as usize,
7046 inverse,
7047 norm_tag,
7048 base,
7049 );
7050 }),
7051 rlx_ir::DType::F32 => Arc::new(move |base: *mut u8| unsafe {
7052 execute_fft1d_f32(
7053 src,
7054 dst,
7055 outer as usize,
7056 n_complex as usize,
7057 inverse,
7058 norm_tag,
7059 base,
7060 );
7061 }),
7062 rlx_ir::DType::C64 => Arc::new(move |base: *mut u8| unsafe {
7063 execute_fft1d_c64(
7064 src,
7065 dst,
7066 outer as usize,
7067 n_complex as usize,
7068 inverse,
7069 norm_tag,
7070 base,
7071 );
7072 }),
7073 other => panic!("Op::Fft on CPU requires F32/F64/C64, got {other:?}"),
7074 };
7075 f
7076 }
7077
7078 Thunk::FftButterflyStage {
7079 state_src,
7080 state_dst,
7081 gate_src,
7082 rev_src,
7083 tw_re_src,
7084 tw_im_src,
7085 batch,
7086 n_fft,
7087 stage,
7088 } => Arc::new(move |base: *mut u8| unsafe {
7089 execute_fft_butterfly_stage_f32(
7090 state_src,
7091 state_dst,
7092 gate_src,
7093 rev_src,
7094 tw_re_src,
7095 tw_im_src,
7096 batch as usize,
7097 n_fft as usize,
7098 stage as usize,
7099 base,
7100 );
7101 }),
7102
7103 Thunk::LogMel {
7104 spec,
7105 filters,
7106 dst,
7107 outer,
7108 n_fft,
7109 n_bins,
7110 n_mels,
7111 } => Arc::new(move |base: *mut u8| unsafe {
7112 execute_log_mel_f32(
7113 spec,
7114 filters,
7115 dst,
7116 outer as usize,
7117 n_fft as usize,
7118 n_bins as usize,
7119 n_mels as usize,
7120 base,
7121 );
7122 }),
7123
7124 Thunk::LogMelBackward {
7125 spec,
7126 filters,
7127 dy,
7128 dst,
7129 outer,
7130 n_fft,
7131 n_bins,
7132 n_mels,
7133 } => Arc::new(move |base: *mut u8| unsafe {
7134 execute_log_mel_backward_f32(
7135 spec,
7136 filters,
7137 dy,
7138 dst,
7139 outer as usize,
7140 n_fft as usize,
7141 n_bins as usize,
7142 n_mels as usize,
7143 base,
7144 );
7145 }),
7146
7147 Thunk::WelchPeaks {
7148 spec,
7149 dst,
7150 welch_batch,
7151 n_fft,
7152 n_segments,
7153 k,
7154 } => Arc::new(move |base: *mut u8| unsafe {
7155 execute_welch_peaks_f32(
7156 spec,
7157 dst,
7158 welch_batch as usize,
7159 n_fft as usize,
7160 n_segments as usize,
7161 k as usize,
7162 base,
7163 );
7164 }),
7165
7166 _ => Arc::new(|_: *mut u8| {}),
7167 }
7168 })
7169 .collect();
7170
7171 let fuse_threshold: usize = rlx_ir::env::var("RLX_FUSE_ATTN_THRESHOLD")
7175 .and_then(|v| v.parse().ok())
7176 .unwrap_or(64);
7177 let should_fuse = thunks.iter().any(|t| match t {
7178 Thunk::Attention { batch, seq, .. } => {
7179 (*batch as usize) * (*seq as usize) <= fuse_threshold
7180 }
7181 _ => false,
7182 });
7183
7184 if should_fuse {
7185 let active: Vec<usize> = thunks
7187 .iter()
7188 .enumerate()
7189 .filter(|(_, t)| !matches!(t, Thunk::Nop))
7190 .map(|(i, _)| i)
7191 .collect();
7192
7193 let mut kill = vec![false; thunks.len()]; let mut insertions: Vec<(usize, Thunk)> = Vec::new(); let mut ai = 0;
7197 while ai < active.len() {
7198 let a = |off: usize| -> Option<(usize, &Thunk)> {
7200 active.get(ai + off).map(|&idx| (idx, &thunks[idx]))
7201 };
7202
7203 let matched = (|| {
7205 let (_i0, t0) = a(0)?;
7206 let (_, t1) = a(1)?;
7207 let (_, t2) = a(2)?;
7208 let (_, t3) = a(3)?;
7209
7210 let (hidden, qkv_w, qkv_b, has_b) = match t0 {
7212 Thunk::FusedMmBiasAct {
7213 a,
7214 w,
7215 bias,
7216 n: _,
7217 act: None,
7218 ..
7219 } => (*a, *w, *bias, true),
7220 Thunk::Sgemm { a, b, n: _, .. } => (*a, *b, 0, false),
7221 _ => return None,
7222 };
7223
7224 if !matches!(t1, Thunk::Narrow { .. }) {
7226 return None;
7227 }
7228 if !matches!(t2, Thunk::Narrow { .. }) {
7229 return None;
7230 }
7231 if !matches!(t3, Thunk::Narrow { .. }) {
7232 return None;
7233 }
7234
7235 let (has_rope, attn_ai, cos_off, sin_off, cl) = if let Some((
7237 _,
7238 Thunk::Rope {
7239 cos, sin, cos_len, ..
7240 },
7241 )) = a(4)
7242 {
7243 if matches!(a(5).map(|x| x.1), Some(Thunk::Rope { .. })) {
7244 if matches!(a(6).map(|x| x.1), Some(Thunk::Attention { .. })) {
7245 (true, 6, *cos, *sin, *cos_len)
7246 } else {
7247 return None;
7248 }
7249 } else {
7250 return None;
7251 }
7252 } else if matches!(a(4).map(|x| x.1), Some(Thunk::Attention { .. })) {
7253 (false, 4, 0, 0, 0)
7254 } else {
7255 return None;
7256 };
7257
7258 let (_attn_real_idx, attn_t) = a(attn_ai)?;
7259 let (batch, seq, heads, head_dim, mask) = match attn_t {
7260 Thunk::Attention {
7261 batch,
7262 seq,
7263 heads,
7264 head_dim,
7265 mask,
7266 ..
7267 } => (*batch, *seq, *heads, *head_dim, *mask),
7268 _ => return None,
7269 };
7270
7271 let (_out_real_idx, out_t) = a(attn_ai + 1)?;
7273 let (out_w, out_b, out_dst) = match out_t {
7274 Thunk::FusedMmBiasAct {
7275 w,
7276 bias,
7277 c,
7278 act: None,
7279 ..
7280 } => (*w, *bias, *c),
7281 Thunk::Sgemm { b: w, c, .. } => (*w, 0, *c),
7282 _ => return None,
7283 };
7284
7285 let hs = heads * head_dim;
7286 let total_active = attn_ai + 2; Some((
7289 total_active,
7290 Thunk::FusedAttnBlock {
7291 hidden,
7292 qkv_w,
7293 out_w,
7294 mask,
7295 out: out_dst,
7296 qkv_b: if has_b { qkv_b } else { 0 },
7297 out_b: if has_b { out_b } else { 0 },
7298 cos: cos_off,
7299 sin: sin_off,
7300 cos_len: cl,
7301 batch,
7302 seq,
7303 hs,
7304 nh: heads,
7305 dh: head_dim,
7306 has_bias: has_b,
7307 has_rope,
7308 },
7309 ))
7310 })();
7311
7312 if let Some((count, fused_thunk)) = matched {
7313 for off in 0..count {
7315 if let Some(&idx) = active.get(ai + off) {
7316 kill[idx] = true;
7317 }
7318 }
7319 insertions.push((active[ai], fused_thunk));
7321 ai += count;
7322 } else {
7323 ai += 1;
7324 }
7325 }
7326
7327 if !insertions.is_empty() {
7329 let mut new_thunks = Vec::with_capacity(thunks.len());
7330 let mut insert_idx = 0;
7331 for (i, t) in thunks.into_iter().enumerate() {
7332 if insert_idx < insertions.len() && insertions[insert_idx].0 == i {
7333 new_thunks.push(insertions[insert_idx].1.clone());
7334 insert_idx += 1;
7335 }
7336 if !kill[i] {
7337 new_thunks.push(t);
7338 }
7339 }
7340 if cfg.verbose >= 1 {
7341 eprintln!(
7342 "[rlx] fused_attention: {} attention blocks fused",
7343 insertions.len()
7344 );
7345 }
7346 thunks = new_thunks;
7347 }
7348 }
7349
7350 if should_fuse {
7355 let active: Vec<usize> = thunks
7356 .iter()
7357 .enumerate()
7358 .filter(|(_, t)| !matches!(t, Thunk::Nop))
7359 .map(|(i, _)| i)
7360 .collect();
7361
7362 let mut kill = vec![false; thunks.len()];
7363 let mut insertions: Vec<(usize, Thunk)> = Vec::new();
7364
7365 let a = |ai: usize| -> Option<&Thunk> { active.get(ai).map(|&i| &thunks[i]) };
7366
7367 let mut ai = 0;
7368 while ai < active.len() {
7369 let bert_match = (|| -> Option<usize> {
7371 let fab = a(ai)?;
7372 let rln1 = a(ai + 1)?;
7373 let ffn1 = a(ai + 2)?;
7374 let ffn2 = a(ai + 3)?;
7375 let rln2 = a(ai + 4)?;
7376
7377 let (hidden, qkv_w, qkv_b, out_w, out_b, mask, batch, seq, hs, nh, dh) = match fab {
7378 Thunk::FusedAttnBlock {
7379 hidden,
7380 qkv_w,
7381 qkv_b,
7382 out_w,
7383 out_b,
7384 mask,
7385 batch,
7386 seq,
7387 hs,
7388 nh,
7389 dh,
7390 has_bias: true,
7391 has_rope: false,
7392 ..
7393 } => (
7394 *hidden, *qkv_w, *qkv_b, *out_w, *out_b, *mask, *batch, *seq, *hs, *nh, *dh,
7395 ),
7396 _ => return None,
7397 };
7398 let (ln1_g, ln1_b, eps1) = match rln1 {
7399 Thunk::FusedResidualLN { g, b, eps, .. } => (*g, *b, *eps),
7400 _ => return None,
7401 };
7402 let (fc1_w, fc1_b, int_dim) = match ffn1 {
7403 Thunk::FusedMmBiasAct {
7404 w,
7405 bias,
7406 n,
7407 act: Some(Activation::Gelu),
7408 ..
7409 } => (*w, *bias, *n),
7410 _ => return None,
7411 };
7412 let (fc2_w, fc2_b) = match ffn2 {
7413 Thunk::FusedMmBiasAct {
7414 w, bias, act: None, ..
7415 } => (*w, *bias),
7416 _ => return None,
7417 };
7418 let (ln2_g, ln2_b, eps2, out) = match rln2 {
7419 Thunk::FusedResidualLN { g, b, eps, out, .. } => (*g, *b, *eps, *out),
7420 _ => return None,
7421 };
7422
7423 for off in 0..5 {
7424 kill[active[ai + off]] = true;
7425 }
7426 insertions.push((
7427 active[ai],
7428 Thunk::FusedBertLayer {
7429 hidden,
7430 qkv_w,
7431 qkv_b,
7432 out_w,
7433 out_b,
7434 mask,
7435 ln1_g,
7436 ln1_b,
7437 eps1,
7438 fc1_w,
7439 fc1_b,
7440 fc2_w,
7441 fc2_b,
7442 ln2_g,
7443 ln2_b,
7444 eps2,
7445 out,
7446 batch,
7447 seq,
7448 hs,
7449 nh,
7450 dh,
7451 int_dim,
7452 },
7453 ));
7454 Some(5)
7455 })();
7456 if let Some(n) = bert_match {
7457 ai += n;
7458 continue;
7459 }
7460
7461 #[allow(unreachable_code)]
7465 let nomic_match = (|| -> Option<usize> {
7466 return None; let fab = a(ai)?;
7468 let (hidden, qkv_w, out_w, mask, cos, sin, cos_len, batch, seq, hs, nh, dh) =
7469 match fab {
7470 Thunk::FusedAttnBlock {
7471 hidden,
7472 qkv_w,
7473 out_w,
7474 mask,
7475 cos,
7476 sin,
7477 cos_len,
7478 batch,
7479 seq,
7480 hs,
7481 nh,
7482 dh,
7483 has_bias: false,
7484 has_rope: true,
7485 ..
7486 } => (
7487 *hidden, *qkv_w, *out_w, *mask, *cos, *sin, *cos_len, *batch, *seq,
7488 *hs, *nh, *dh,
7489 ),
7490 _ => return None,
7491 };
7492 let (ln1_g, ln1_b, eps1) = match a(ai + 1)? {
7494 Thunk::FusedResidualLN { g, b, eps, .. } => (*g, *b, *eps),
7495 _ => return None,
7496 };
7497 let fused_fc_w = match a(ai + 2)? {
7499 Thunk::Sgemm { b: w, .. } => *w,
7500 _ => return None,
7501 };
7502 if !matches!(a(ai + 3)?, Thunk::Narrow { .. }) {
7504 return None;
7505 }
7506 if !matches!(a(ai + 4)?, Thunk::Narrow { .. }) {
7507 return None;
7508 }
7509 if !matches!(
7511 a(ai + 5)?,
7512 Thunk::ActivationInPlace {
7513 act: Activation::Silu,
7514 ..
7515 }
7516 ) {
7517 return None;
7518 }
7519 if !matches!(
7521 a(ai + 6)?,
7522 Thunk::BinaryFull {
7523 op: BinaryOp::Mul,
7524 ..
7525 }
7526 ) {
7527 return None;
7528 }
7529 let fc2_w = match a(ai + 7)? {
7531 Thunk::Sgemm { b: w, .. } => *w,
7532 _ => return None,
7533 };
7534 let int_dim = match a(ai + 3)? {
7536 Thunk::Narrow { inner, .. } => *inner,
7537 _ => return None,
7538 };
7539 let (ln2_g, ln2_b, eps2, out) = match a(ai + 8)? {
7541 Thunk::FusedResidualLN { g, b, eps, out, .. } => (*g, *b, *eps, *out),
7542 _ => return None,
7543 };
7544
7545 for off in 0..9 {
7546 kill[active[ai + off]] = true;
7547 }
7548 insertions.push((
7549 active[ai],
7550 Thunk::FusedNomicLayer {
7551 hidden,
7552 qkv_w,
7553 out_w,
7554 mask,
7555 cos,
7556 sin,
7557 cos_len,
7558 ln1_g,
7559 ln1_b,
7560 eps1,
7561 fc11_w: fused_fc_w,
7562 fc12_w: 0,
7563 fc2_w,
7564 ln2_g,
7565 ln2_b,
7566 eps2,
7567 out,
7568 batch,
7569 seq,
7570 hs,
7571 nh,
7572 dh,
7573 int_dim,
7574 },
7575 ));
7576 Some(9)
7577 })();
7578 if let Some(n) = nomic_match {
7579 ai += n;
7580 continue;
7581 }
7582
7583 ai += 1;
7584 }
7585
7586 if !insertions.is_empty() {
7587 let mut new_thunks = Vec::with_capacity(thunks.len());
7588 let mut ins_idx = 0;
7589 for (i, t) in thunks.into_iter().enumerate() {
7590 if ins_idx < insertions.len() && insertions[ins_idx].0 == i {
7591 new_thunks.push(insertions[ins_idx].1.clone());
7592 ins_idx += 1;
7593 }
7594 if !kill[i] {
7595 new_thunks.push(t);
7596 }
7597 }
7598 if cfg.verbose >= 1 {
7599 eprintln!(
7600 "[rlx] fused_layer: {} full transformer layers fused",
7601 insertions.len()
7602 );
7603 }
7604 thunks = new_thunks;
7605 }
7606 }
7607
7608 {
7620 let mut read_offsets: HashMap<usize, usize> = HashMap::new();
7623 for t in &thunks {
7624 for off in thunk_read_offsets(t) {
7625 *read_offsets.entry(off).or_insert(0) += 1;
7626 }
7627 }
7628
7629 let mut fused_count = 0usize;
7630 for i in 0..thunks.len().saturating_sub(1) {
7631 let narrow = match &thunks[i] {
7634 Thunk::Narrow { .. } => i,
7635 _ => continue,
7636 };
7637 let mut j = narrow + 1;
7639 while j < thunks.len() && matches!(thunks[j], Thunk::Nop) {
7640 j += 1;
7641 }
7642 if j >= thunks.len() {
7643 continue;
7644 }
7645 let (n_src, n_dst, n_src_stride) = match &thunks[narrow] {
7647 Thunk::Narrow {
7648 src,
7649 dst,
7650 src_stride,
7651 ..
7652 } => (*src, *dst, *src_stride),
7653 _ => continue,
7654 };
7655 let rope_reads_narrow = matches!(&thunks[j],
7656 Thunk::Rope { src, .. } if *src == n_dst);
7657 if !rope_reads_narrow {
7658 continue;
7659 }
7660 if read_offsets.get(&n_dst).copied().unwrap_or(0) != 1 {
7664 continue;
7665 }
7666
7667 if let Thunk::Rope {
7670 src,
7671 src_row_stride,
7672 ..
7673 } = &mut thunks[j]
7674 {
7675 *src = n_src;
7676 *src_row_stride = n_src_stride;
7677 }
7678 thunks[narrow] = Thunk::Nop;
7679 fused_count += 1;
7680 }
7681
7682 if fused_count > 0 && cfg.verbose >= 1 {
7683 eprintln!(
7684 "[rlx] fused_qk_rope: {} Narrow→Rope pairs collapsed",
7685 fused_count
7686 );
7687 }
7688 }
7689
7690 {
7702 let mut read_counts: HashMap<usize, usize> = HashMap::new();
7703 for t in &thunks {
7704 for off in thunk_read_offsets(t) {
7705 *read_counts.entry(off).or_insert(0) += 1;
7706 }
7707 }
7708 let mut dst_to_idx: HashMap<usize, usize> = HashMap::new();
7710 for (i, t) in thunks.iter().enumerate() {
7711 if let Thunk::Narrow { dst, .. } = t {
7712 dst_to_idx.insert(*dst, i);
7713 }
7714 }
7715
7716 let mut fused_count = 0usize;
7717 for i in 0..thunks.len() {
7718 let (q_off, k_off, v_off) = match &thunks[i] {
7719 Thunk::Attention { q, k, v, .. } => (*q, *k, *v),
7720 _ => continue,
7721 };
7722 let q_n = match dst_to_idx.get(&q_off).copied() {
7724 Some(x) => x,
7725 None => continue,
7726 };
7727 let k_n = match dst_to_idx.get(&k_off).copied() {
7728 Some(x) => x,
7729 None => continue,
7730 };
7731 let v_n = match dst_to_idx.get(&v_off).copied() {
7732 Some(x) => x,
7733 None => continue,
7734 };
7735 if read_counts.get(&q_off).copied().unwrap_or(0) != 1 {
7737 continue;
7738 }
7739 if read_counts.get(&k_off).copied().unwrap_or(0) != 1 {
7740 continue;
7741 }
7742 if read_counts.get(&v_off).copied().unwrap_or(0) != 1 {
7743 continue;
7744 }
7745
7746 let (q_src, q_stride) = match &thunks[q_n] {
7747 Thunk::Narrow {
7748 src, src_stride, ..
7749 } => (*src, *src_stride),
7750 _ => continue,
7751 };
7752 let (k_src, k_stride) = match &thunks[k_n] {
7753 Thunk::Narrow {
7754 src, src_stride, ..
7755 } => (*src, *src_stride),
7756 _ => continue,
7757 };
7758 let (v_src, v_stride) = match &thunks[v_n] {
7759 Thunk::Narrow {
7760 src, src_stride, ..
7761 } => (*src, *src_stride),
7762 _ => continue,
7763 };
7764
7765 if let Thunk::Attention {
7766 q,
7767 k,
7768 v,
7769 q_row_stride,
7770 k_row_stride,
7771 v_row_stride,
7772 ..
7773 } = &mut thunks[i]
7774 {
7775 *q = q_src;
7776 *k = k_src;
7777 *v = v_src;
7778 *q_row_stride = q_stride;
7779 *k_row_stride = k_stride;
7780 *v_row_stride = v_stride;
7781 }
7782 thunks[q_n] = Thunk::Nop;
7783 thunks[k_n] = Thunk::Nop;
7784 thunks[v_n] = Thunk::Nop;
7785 fused_count += 1;
7786 }
7787
7788 if fused_count > 0 && cfg.verbose >= 1 {
7789 eprintln!(
7790 "[rlx] fused_strided_attn: {} Narrow×3→Attention rewrites",
7791 fused_count
7792 );
7793 }
7794 }
7795
7796 ThunkSchedule {
7797 thunks,
7798 moe_resident: None,
7799 moe_resident_layers: None,
7800 moe_topk_capture: None,
7801 mask_threshold: cfg.mask_binary_threshold,
7802 mask_neg_inf: cfg.attn_mask_neg_inf,
7803 score_skip: cfg.score_skip_threshold,
7804 compiled_fns,
7805 rng: rng_shared,
7806 }
7807}
7808
7809fn get_len(graph: &Graph, id: NodeId) -> usize {
7810 graph.node(id).shape.num_elements().unwrap_or(0)
7811}
7812
7813fn get_static_dims(graph: &Graph, id: NodeId) -> Vec<usize> {
7815 let dims = graph.node(id).shape.dims();
7816 let mut out = Vec::with_capacity(dims.len());
7817 for d in dims {
7818 if let Some(s) = match d {
7819 rlx_ir::Dim::Static(s) => Some(*s),
7820 _ => None,
7821 } {
7822 out.push(s);
7823 } else {
7824 return Vec::new();
7825 }
7826 }
7827 out
7828}
7829
7830fn concat_axis_extent(input: &rlx_ir::Shape, axis: usize, out_rank: usize) -> usize {
7833 let in_rank = input.rank();
7834 if axis >= out_rank {
7835 return 1;
7836 }
7837 if axis < in_rank {
7838 input.dim(axis).unwrap_static()
7839 } else {
7840 1
7841 }
7842}
7843
/// Wraps `src_idx` into `0..in_len` so a shorter input can be read cyclically.
///
/// An empty input (`in_len == 0`) always maps to index 0 rather than
/// dividing by zero.
fn broadcast_src_index(src_idx: usize, in_len: usize) -> usize {
    src_idx.checked_rem(in_len).unwrap_or(0)
}
7847
/// Copies one concat input into its column window of every output row.
///
/// `out` is viewed as `outer` rows of `row_stride` elements. This input
/// occupies `copy_per_row` elements starting at column `dst_col_off` of each
/// row.
///
/// Dense case: when `in_numel >= outer * max(copy_per_row, 1)`, row `o` of
/// `inp` is copied into row `o` of the output.
///
/// Broadcast case: otherwise, every output row's window is filled from the
/// same source:
/// - `in_numel == 1`: the window is filled with `inp[0]`;
/// - `inp.len() >= copy_per_row`: the first `copy_per_row` elements of `inp`
///   are repeated into each row;
/// - otherwise, a non-empty `inp` fills the window with `inp[0]`, and an
///   empty `inp` writes nothing.
///
/// Shared by `concat_copy_rows_f32` and `concat_copy_rows_f64`. Before, those
/// were two hand-duplicated copies of this logic, so any fix had to be
/// applied twice.
///
/// # Panics
/// Panics on out-of-bounds slicing if a window does not fit in `out`, or if a
/// dense row read runs past the end of `inp`.
fn concat_copy_rows_impl<T: Copy>(
    out: &mut [T],
    inp: &[T],
    outer: usize,
    copy_per_row: usize,
    row_stride: usize,
    dst_col_off: usize,
    in_numel: usize,
) {
    let need = outer.saturating_mul(copy_per_row.max(1));
    let broadcast_outer = in_numel < need;
    for o in 0..outer {
        let start = o * row_stride + dst_col_off;
        let end = start + copy_per_row;
        if !broadcast_outer {
            // Dense: input row `o` lands in output row `o`.
            let src = o * copy_per_row;
            out[start..end].copy_from_slice(&inp[src..src + copy_per_row]);
        } else if in_numel == 1 {
            // Scalar input: splat into the whole window (a 1-wide window is
            // simply a single-element fill).
            out[start..end].fill(inp[0]);
        } else if copy_per_row <= inp.len() {
            // Repeat the leading `copy_per_row` input elements in every row.
            out[start..end].copy_from_slice(&inp[..copy_per_row]);
        } else if let Some(&first) = inp.first() {
            // Input too short for one row: fall back to a splat of its head.
            out[start..end].fill(first);
        }
    }
}

/// `f32` entry point for [`concat_copy_rows_impl`]; see that function for the
/// layout and broadcast rules.
fn concat_copy_rows_f32(
    out: &mut [f32],
    inp: &[f32],
    outer: usize,
    copy_per_row: usize,
    row_stride: usize,
    dst_col_off: usize,
    in_numel: usize,
) {
    concat_copy_rows_impl(out, inp, outer, copy_per_row, row_stride, dst_col_off, in_numel);
}

/// `f64` entry point for [`concat_copy_rows_impl`]; see that function for the
/// layout and broadcast rules.
fn concat_copy_rows_f64(
    out: &mut [f64],
    inp: &[f64],
    outer: usize,
    copy_per_row: usize,
    row_stride: usize,
    dst_col_off: usize,
    in_numel: usize,
) {
    concat_copy_rows_impl(out, inp, outer, copy_per_row, row_stride, dst_col_off, in_numel);
}
7915
7916fn is_trailing_bias_broadcast(rhs_dims: &[rlx_ir::Dim], out_dims: &[rlx_ir::Dim]) -> bool {
7934 if rhs_dims.len() > out_dims.len() {
7935 return false;
7936 }
7937 let off = out_dims.len() - rhs_dims.len();
7938 for i in 0..rhs_dims.len() {
7939 let r = match rhs_dims[i] {
7940 rlx_ir::Dim::Static(n) => n,
7941 _ => return false,
7942 };
7943 let o = match out_dims[off + i] {
7944 rlx_ir::Dim::Static(n) => n,
7945 _ => return false,
7946 };
7947 if r != o {
7948 return false;
7949 }
7950 }
7951 true
7952}
7953
/// Per-output-axis element strides for reading an input of shape `in_dims`
/// while iterating over the broadcast output shape `out_dims`.
///
/// The input is right-aligned against the output: missing leading input
/// axes, and input axes of extent 1, get stride 0, so the same input element
/// is reused along them. Every other axis gets the row-major stride of the
/// input (the product of the non-1 input extents to its right).
///
/// # Panics
/// - if `in_dims` has more axes than `out_dims`;
/// - if a non-1 input extent differs from the matching output extent;
/// - if a stride does not fit in `u32`. Previously such strides were silently
///   truncated by an `as` cast, which produced wrong element offsets.
fn broadcast_strides(in_dims: &[usize], out_dims: &[usize]) -> Vec<u32> {
    let r_out = out_dims.len();
    let r_in = in_dims.len();
    assert!(
        r_in <= r_out,
        "broadcast: input rank {r_in} > output rank {r_out}"
    );
    let pad = r_out - r_in;
    let mut strides = vec![0u32; r_out];
    // Tracked in u64 with saturation, so an overflowing product can never wrap
    // back into the u32 range and slip past the check below.
    let mut acc: u64 = 1;
    for d in (0..r_out).rev() {
        let in_size = if d < pad { 1 } else { in_dims[d - pad] };
        if in_size == 1 {
            strides[d] = 0;
        } else {
            assert_eq!(
                in_size, out_dims[d],
                "broadcast: input dim {in_size} doesn't match output dim {} at axis {d}",
                out_dims[d]
            );
            strides[d] = u32::try_from(acc).unwrap_or_else(|_| {
                panic!(
                    "broadcast: stride {} at axis {} does not fit in u32",
                    acc, d
                )
            });
            acc = acc.saturating_mul(in_size as u64);
        }
    }
    strides
}
7980
7981pub fn execute_compiled(schedule: &ThunkSchedule, arena_buf: &mut [u8]) {
7985 let base = arena_buf.as_mut_ptr();
7986 for f in &schedule.compiled_fns {
7987 f(base);
7988 }
7989}
7990
7991pub fn execute_thunks_active(
7996 schedule: &ThunkSchedule,
7997 _arena_buf: &mut [u8],
7998 _actual: usize,
7999 _upper: usize,
8000) -> bool {
8001 let _ = schedule;
8002 false
8003}
8004
8005struct MoeResidencyGuard;
8007impl Drop for MoeResidencyGuard {
8008 fn drop(&mut self) {
8009 if let Some(stats) = crate::moe_residency::take_stats() {
8010 crate::moe_residency::stash_last_forward_stats(stats);
8011 } else {
8012 crate::moe_residency::clear_mask();
8013 }
8014 }
8015}
8016
8017fn thunk_kind_name(t: &Thunk) -> &'static str {
8018 match t {
8019 Thunk::Nop => "Nop",
8020 Thunk::Gather { .. } => "Gather",
8021 Thunk::GatherAxis { .. } => "GatherAxis",
8022 Thunk::TopK { .. } => "TopK",
8023 Thunk::Copy { .. } => "Copy",
8024 Thunk::CopyF64 { .. } => "CopyF64",
8025 Thunk::CopyI64 { .. } => "CopyI64",
8026 Thunk::CastF32ToI64 { .. } => "CastF32ToI64",
8027 Thunk::CastI64ToF32 { .. } => "CastI64ToF32",
8028 Thunk::CastBoolToI32 { .. } => "CastBoolToI32",
8029 Thunk::CastBoolToF32 { .. } => "CastBoolToF32",
8030 Thunk::CastI32ToF32 { .. } => "CastI32ToF32",
8031 Thunk::Transpose { .. } => "Transpose",
8032 Thunk::TransposeF64 { .. } => "TransposeF64",
8033 Thunk::Where { .. } => "Where",
8034 Thunk::Compare { .. } => "Compare",
8035 Thunk::BinaryFull { .. } => "BinaryFull",
8036 Thunk::BinaryFullF64 { .. } => "BinaryFullF64",
8037 Thunk::Sgemm { .. } => "Sgemm",
8038 Thunk::Dgemm { .. } => "Dgemm",
8039 Thunk::FusedMmBiasAct { .. } => "FusedMmBiasAct",
8040 Thunk::BiasAdd { .. } => "BiasAdd",
8041 Thunk::LayerNorm { .. } => "LayerNorm",
8042 Thunk::Softmax { .. } => "Softmax",
8043 Thunk::Conv2D { .. } => "Conv2D",
8044 Thunk::Conv2D1x1 { .. } => "Conv2D1x1",
8045 Thunk::CustomOp { .. } => "CustomOp",
8046 Thunk::ActivationInPlace { .. } => "ActivationInPlace",
8047 Thunk::Narrow { .. } => "Narrow",
8048 Thunk::Cumsum { .. } => "Cumsum",
8049 Thunk::Reduce { .. } => "Reduce",
8050 Thunk::BatchedSgemm { .. } => "BatchedSgemm",
8051 Thunk::DequantMatMul { .. } => "DequantMatMul",
8052 Thunk::Quantize { .. } => "Quantize",
8053 Thunk::Dequantize { .. } => "Dequantize",
8054 Thunk::ConvTranspose2d { .. } => "ConvTranspose2d",
8055 Thunk::ResizeNearest2x { .. } => "ResizeNearest2x",
8056 _ => "Other",
8057 }
8058}
8059
8060pub fn execute_thunks(schedule: &ThunkSchedule, arena_buf: &mut [u8]) {
8061 crate::moe_residency::reset_gmm_counters();
8062 if let Some(layers) = schedule.moe_resident_layers.clone() {
8063 crate::moe_residency::set_per_layer_masks(Some(layers));
8064 } else {
8065 crate::moe_residency::set_mask(schedule.moe_resident.clone());
8066 }
8067 if let Some(cap) = schedule.moe_topk_capture.as_ref() {
8068 cap.clear();
8069 }
8070 let _moe_guard = MoeResidencyGuard;
8071 let base = arena_buf.as_mut_ptr();
8072 let mask_thr = schedule.mask_threshold;
8073 let mask_neg = schedule.mask_neg_inf;
8074 let score_thr = schedule.score_skip;
8075 let thunks = &schedule.thunks;
8076 let len = thunks.len();
8077
8078 let max_h = thunks
8080 .iter()
8081 .filter_map(|t| match t {
8082 Thunk::FusedResidualLN { h, .. }
8083 | Thunk::FusedResidualRmsNorm { h, .. }
8084 | Thunk::LayerNorm { h, .. } => Some(*h as usize),
8085 _ => None,
8086 })
8087 .max()
8088 .unwrap_or(0);
8089 let zero_bias = vec![0f32; max_h];
8090
8091 let max_sdpa = thunks
8094 .iter()
8095 .filter_map(|t| match t {
8096 Thunk::Attention {
8097 batch,
8098 seq,
8099 kv_seq,
8100 heads,
8101 head_dim,
8102 ..
8103 } => Some((
8104 *batch as usize,
8105 (*seq as usize).max(*kv_seq as usize),
8106 *heads as usize,
8107 *head_dim as usize,
8108 )),
8109 _ => None,
8110 })
8111 .fold((0, 0, 0, 0), |(mb, ms, mh, md), (b, s, h, d)| {
8112 (mb.max(b), ms.max(s), mh.max(h), md.max(d))
8113 });
8114 let (max_batch, max_seq, max_heads, _max_dh) = max_sdpa;
8115 let max_units = max_batch * max_heads;
8116 let mut sdpa_scores = vec![0f32; max_units * max_seq * max_seq];
8117
8118 let fl = thunks
8120 .iter()
8121 .filter_map(|t| match t {
8122 Thunk::FusedBertLayer {
8123 batch,
8124 seq,
8125 hs,
8126 int_dim,
8127 ..
8128 } => {
8129 let m = (*batch as usize) * (*seq as usize);
8130 let h = *hs as usize;
8131 let id = *int_dim as usize;
8132 Some((m, h, id, m * (*seq as usize)))
8133 }
8134 Thunk::FusedNomicLayer {
8135 batch,
8136 seq,
8137 hs,
8138 int_dim,
8139 ..
8140 } => {
8141 let m = (*batch as usize) * (*seq as usize);
8142 let h = *hs as usize;
8143 let id = *int_dim as usize;
8144 Some((m, h, id, m * (*seq as usize)))
8145 }
8146 _ => None,
8147 })
8148 .fold((0, 0, 0, 0), |(mm, mh, mi, ms), (m, h, id, ss)| {
8149 (mm.max(m), mh.max(h), mi.max(id), ms.max(ss))
8150 });
8151 let (fl_m, fl_h, fl_int, fl_ss) = fl;
8152 let mut fl_qkv = vec![0f32; fl_m * 3 * fl_h];
8153 let mut fl_attn = vec![0f32; fl_m * fl_h];
8154 let mut fl_res = vec![0f32; fl_m * fl_h];
8155 let mut fl_normed = vec![0f32; fl_m * fl_h];
8156 let mut fl_ffn = vec![0f32; fl_m * fl_int.max(2 * fl_int)]; let mut fl_sc = vec![0f32; fl_ss.max(1)];
8158
8159 let trace_thunks = std::env::var_os("RLX_TRACE_THUNK").is_some();
8160 if trace_thunks {
8161 eprintln!(
8162 "[thunk] prealloc max_h={max_h} sdpa={} fl_m={fl_m} fl_h={fl_h} fl_int={fl_int}",
8163 max_units * max_seq * max_seq
8164 );
8165 }
8166 for i in 0..len {
8167 let thunk = unsafe { thunks.get_unchecked(i) };
8168 if trace_thunks && (i < 120 || i % 200 == 0 || i + 1 == len) {
8169 eprintln!("[thunk {i}/{len}] {}", thunk_kind_name(thunk));
8170 }
8171 let trace_done = trace_thunks && i < 120;
8172 match thunk {
8173 Thunk::Nop => {}
8174
8175 Thunk::GaussianSplatRender {
8176 positions_off,
8177 positions_len,
8178 scales_off,
8179 scales_len,
8180 rotations_off,
8181 rotations_len,
8182 opacities_off,
8183 opacities_len,
8184 colors_off,
8185 colors_len,
8186 sh_coeffs_off,
8187 sh_coeffs_len,
8188 meta_off,
8189 dst_off,
8190 dst_len,
8191 width,
8192 height,
8193 tile_size,
8194 radius_scale,
8195 alpha_cutoff,
8196 max_splat_steps,
8197 transmittance_threshold,
8198 max_list_entries,
8199 } => unsafe {
8200 crate::splat::execute_gaussian_splat_render(
8201 *positions_off,
8202 *positions_len,
8203 *scales_off,
8204 *scales_len,
8205 *rotations_off,
8206 *rotations_len,
8207 *opacities_off,
8208 *opacities_len,
8209 *colors_off,
8210 *colors_len,
8211 *sh_coeffs_off,
8212 *sh_coeffs_len,
8213 *meta_off,
8214 *dst_off,
8215 *dst_len,
8216 *width,
8217 *height,
8218 *tile_size,
8219 *radius_scale,
8220 *alpha_cutoff,
8221 *max_splat_steps,
8222 *transmittance_threshold,
8223 *max_list_entries,
8224 base,
8225 );
8226 },
8227
8228 Thunk::GaussianSplatRenderBackward {
8229 positions_off,
8230 positions_len,
8231 scales_off,
8232 scales_len,
8233 rotations_off,
8234 rotations_len,
8235 opacities_off,
8236 opacities_len,
8237 colors_off,
8238 colors_len,
8239 sh_coeffs_off,
8240 sh_coeffs_len,
8241 meta_off,
8242 d_loss_off,
8243 d_loss_len,
8244 packed_off,
8245 packed_len,
8246 width,
8247 height,
8248 tile_size,
8249 radius_scale,
8250 alpha_cutoff,
8251 max_splat_steps,
8252 transmittance_threshold,
8253 max_list_entries,
8254 loss_grad_clip,
8255 sh_band,
8256 max_anisotropy,
8257 } => unsafe {
8258 crate::splat::execute_gaussian_splat_render_backward(
8259 *positions_off,
8260 *positions_len,
8261 *scales_off,
8262 *scales_len,
8263 *rotations_off,
8264 *rotations_len,
8265 *opacities_off,
8266 *opacities_len,
8267 *colors_off,
8268 *colors_len,
8269 *sh_coeffs_off,
8270 *sh_coeffs_len,
8271 *meta_off,
8272 *d_loss_off,
8273 *d_loss_len,
8274 *packed_off,
8275 *packed_len,
8276 *width,
8277 *height,
8278 *tile_size,
8279 *radius_scale,
8280 *alpha_cutoff,
8281 *max_splat_steps,
8282 *transmittance_threshold,
8283 *max_list_entries,
8284 *loss_grad_clip,
8285 *sh_band,
8286 *max_anisotropy,
8287 base,
8288 );
8289 },
8290
8291 Thunk::GaussianSplatPrepare {
8292 positions_off,
8293 positions_len,
8294 scales_off,
8295 scales_len,
8296 rotations_off,
8297 rotations_len,
8298 opacities_off,
8299 opacities_len,
8300 colors_off,
8301 colors_len,
8302 sh_coeffs_off,
8303 sh_coeffs_len,
8304 meta_off,
8305 meta_len,
8306 prep_off,
8307 prep_len,
8308 width,
8309 height,
8310 tile_size,
8311 radius_scale,
8312 alpha_cutoff,
8313 max_splat_steps,
8314 transmittance_threshold,
8315 max_list_entries,
8316 } => unsafe {
8317 crate::splat::execute_gaussian_splat_prepare(
8318 *positions_off,
8319 *positions_len,
8320 *scales_off,
8321 *scales_len,
8322 *rotations_off,
8323 *rotations_len,
8324 *opacities_off,
8325 *opacities_len,
8326 *colors_off,
8327 *colors_len,
8328 *sh_coeffs_off,
8329 *sh_coeffs_len,
8330 *meta_off,
8331 *meta_len,
8332 *prep_off,
8333 *prep_len,
8334 *width,
8335 *height,
8336 *tile_size,
8337 *radius_scale,
8338 *alpha_cutoff,
8339 *max_splat_steps,
8340 *transmittance_threshold,
8341 *max_list_entries,
8342 base,
8343 );
8344 },
8345
8346 Thunk::GaussianSplatRasterize {
8347 prep_off,
8348 prep_len,
8349 meta_off,
8350 meta_len,
8351 dst_off,
8352 dst_len,
8353 count,
8354 width,
8355 height,
8356 tile_size,
8357 alpha_cutoff,
8358 max_splat_steps,
8359 transmittance_threshold,
8360 max_list_entries,
8361 } => unsafe {
8362 crate::splat::execute_gaussian_splat_rasterize(
8363 *prep_off,
8364 *prep_len,
8365 *meta_off,
8366 *meta_len,
8367 *dst_off,
8368 *dst_len,
8369 *count,
8370 *width,
8371 *height,
8372 *tile_size,
8373 *alpha_cutoff,
8374 *max_splat_steps,
8375 *transmittance_threshold,
8376 *max_list_entries,
8377 base,
8378 );
8379 },
8380
8381 Thunk::Fft1d {
8382 src,
8383 dst,
8384 outer,
8385 n_complex,
8386 inverse,
8387 norm_tag,
8388 dtype,
8389 } => unsafe {
8390 match dtype {
8391 rlx_ir::DType::F64 => execute_fft1d_f64(
8392 *src,
8393 *dst,
8394 *outer as usize,
8395 *n_complex as usize,
8396 *inverse,
8397 *norm_tag,
8398 base,
8399 ),
8400 rlx_ir::DType::F32 => execute_fft1d_f32(
8401 *src,
8402 *dst,
8403 *outer as usize,
8404 *n_complex as usize,
8405 *inverse,
8406 *norm_tag,
8407 base,
8408 ),
8409 rlx_ir::DType::C64 => execute_fft1d_c64(
8410 *src,
8411 *dst,
8412 *outer as usize,
8413 *n_complex as usize,
8414 *inverse,
8415 *norm_tag,
8416 base,
8417 ),
8418 other => panic!("Op::Fft on CPU requires F32/F64/C64, got {other:?}"),
8419 }
8420 },
8421
8422 Thunk::FftButterflyStage {
8423 state_src,
8424 state_dst,
8425 gate_src,
8426 rev_src,
8427 tw_re_src,
8428 tw_im_src,
8429 batch,
8430 n_fft,
8431 stage,
8432 } => unsafe {
8433 execute_fft_butterfly_stage_f32(
8434 *state_src,
8435 *state_dst,
8436 *gate_src,
8437 *rev_src,
8438 *tw_re_src,
8439 *tw_im_src,
8440 *batch as usize,
8441 *n_fft as usize,
8442 *stage as usize,
8443 base,
8444 );
8445 },
8446
8447 Thunk::LogMel {
8448 spec,
8449 filters,
8450 dst,
8451 outer,
8452 n_fft,
8453 n_bins,
8454 n_mels,
8455 } => unsafe {
8456 execute_log_mel_f32(
8457 *spec,
8458 *filters,
8459 *dst,
8460 *outer as usize,
8461 *n_fft as usize,
8462 *n_bins as usize,
8463 *n_mels as usize,
8464 base,
8465 );
8466 },
8467
8468 Thunk::LogMelBackward {
8469 spec,
8470 filters,
8471 dy,
8472 dst,
8473 outer,
8474 n_fft,
8475 n_bins,
8476 n_mels,
8477 } => unsafe {
8478 execute_log_mel_backward_f32(
8479 *spec,
8480 *filters,
8481 *dy,
8482 *dst,
8483 *outer as usize,
8484 *n_fft as usize,
8485 *n_bins as usize,
8486 *n_mels as usize,
8487 base,
8488 );
8489 },
8490
8491 Thunk::WelchPeaks {
8492 spec,
8493 dst,
8494 welch_batch,
8495 n_fft,
8496 n_segments,
8497 k,
8498 } => unsafe {
8499 execute_welch_peaks_f32(
8500 *spec,
8501 *dst,
8502 *welch_batch as usize,
8503 *n_fft as usize,
8504 *n_segments as usize,
8505 *k as usize,
8506 base,
8507 );
8508 },
8509
8510 Thunk::CustomFn {
8514 body,
8515 body_init,
8516 inputs,
8517 body_output_off,
8518 outer_output_off,
8519 out_bytes,
8520 } => {
8521 let mut body_buf: Vec<u8> = (**body_init).clone();
8522 unsafe {
8523 for (body_in_off, outer_in_off, n_bytes) in inputs.iter() {
8524 let src = (base as *const u8).add(*outer_in_off);
8525 let dst = body_buf.as_mut_ptr().add(*body_in_off);
8526 std::ptr::copy_nonoverlapping(src, dst, *n_bytes as usize);
8527 }
8528 }
8529 execute_thunks(body, &mut body_buf);
8530 unsafe {
8531 let src = body_buf.as_ptr().add(*body_output_off);
8532 let dst = base.add(*outer_output_off);
8533 std::ptr::copy_nonoverlapping(src, dst, *out_bytes as usize);
8534 }
8535 }
8536
8537 Thunk::Sgemm { a, b, c, m, k, n } => {
8538 let (m, k, n) = (*m as usize, *k as usize, *n as usize);
8539 if trace_thunks {
8540 eprintln!("[sgemm] m={m} k={k} n={n} a={} b={} c={}", *a, *b, *c);
8541 }
8542 let c_len = m.saturating_mul(n);
8543 let a_len = m.saturating_mul(k);
8544 let b_len = k.saturating_mul(n);
8545 let arena_len = arena_buf.len();
8546 let max_a = (arena_len.saturating_sub(*a)) / 4;
8547 let max_b = (arena_len.saturating_sub(*b)) / 4;
8548 let max_c = (arena_len.saturating_sub(*c)) / 4;
8549 let a_len = a_len.min(max_a);
8550 let b_len = b_len.min(max_b);
8551 let c_len = c_len.min(max_c);
8552 unsafe {
8553 let a_sl = sl(*a, base, a_len);
8554 let b_sl = sl(*b, base, b_len);
8555 let c_sl = sl_mut(*c, base, c_len);
8556 if std::ptr::eq(a_sl.as_ptr(), c_sl.as_ptr())
8557 || std::ptr::eq(b_sl.as_ptr(), c_sl.as_ptr())
8558 {
8559 let mut tmp = vec![0.0f32; c_len];
8560 crate::blas::sgemm_auto(a_sl, b_sl, &mut tmp, m, k, n);
8561 c_sl.copy_from_slice(&tmp);
8562 } else {
8563 crate::blas::sgemm_auto(a_sl, b_sl, c_sl, m, k, n);
8564 }
8565 }
8566 }
8567
8568 Thunk::CgemmC64 { a, b, c, m, k, n } => unsafe {
8569 cgemm_c64(*a, *b, *c, *m as usize, *k as usize, *n as usize, base);
8570 },
8571
8572 Thunk::DenseSolveF64 { a, b, x, n, nrhs } => {
8573 let (n_, nrhs_) = (*n as usize, *nrhs as usize);
8574 unsafe {
8580 let a_src = sl_f64(*a, base, n_ * n_);
8581 let b_src = sl_f64(*b, base, n_ * nrhs_);
8582 let mut a_scratch: Vec<f64> = a_src.to_vec();
8583 let mut x_buf: Vec<f64> = b_src.to_vec();
8584 let info = crate::blas::dgesv(&mut a_scratch, &mut x_buf, n_, nrhs_);
8585 if info != 0 {
8586 panic!(
8587 "DenseSolveF64: dgesv reported singular matrix \
8588 (info={info}, n={n_}, nrhs={nrhs_})"
8589 );
8590 }
8591 let dst = sl_mut_f64(*x, base, n_ * nrhs_);
8592 dst.copy_from_slice(&x_buf);
8593 }
8594 }
8595
8596 Thunk::DenseSolveF32 { a, b, x, n, nrhs } => {
8597 let (n_, nrhs_) = (*n as usize, *nrhs as usize);
8598 unsafe {
8599 let a_src = sl(*a, base, n_ * n_);
8600 let b_src = sl(*b, base, n_ * nrhs_);
8601 let mut a_scratch: Vec<f32> = a_src.to_vec();
8602 let mut x_buf: Vec<f32> = b_src.to_vec();
8603 let info = crate::blas::sgesv(&mut a_scratch, &mut x_buf, n_, nrhs_);
8604 if info != 0 {
8605 panic!(
8606 "DenseSolveF32: sgesv reported singular matrix \
8607 (info={info}, n={n_}, nrhs={nrhs_})"
8608 );
8609 }
8610 let dst = sl_mut(*x, base, n_ * nrhs_);
8611 dst.copy_from_slice(&x_buf);
8612 }
8613 }
8614
8615 Thunk::BatchedDenseSolveF64 {
8616 a,
8617 b,
8618 x,
8619 batch,
8620 n,
8621 nrhs,
8622 } => {
8623 let (b_, n_, nrhs_) = (*batch as usize, *n as usize, *nrhs as usize);
8630 let a_stride = n_ * n_;
8631 let b_stride = n_ * nrhs_;
8632 unsafe {
8633 let a_full = sl_f64(*a, base, b_ * a_stride);
8634 let b_full = sl_f64(*b, base, b_ * b_stride);
8635 let x_full = sl_mut_f64(*x, base, b_ * b_stride);
8636 for bi in 0..b_ {
8637 let mut a_scratch: Vec<f64> =
8638 a_full[bi * a_stride..(bi + 1) * a_stride].to_vec();
8639 let mut x_buf: Vec<f64> =
8640 b_full[bi * b_stride..(bi + 1) * b_stride].to_vec();
8641 let info = crate::blas::dgesv(&mut a_scratch, &mut x_buf, n_, nrhs_);
8642 if info != 0 {
8643 panic!(
8644 "BatchedDenseSolveF64: slice {bi} \
8645 singular (info={info}, n={n_}, nrhs={nrhs_})"
8646 );
8647 }
8648 x_full[bi * b_stride..(bi + 1) * b_stride].copy_from_slice(&x_buf);
8649 }
8650 }
8651 }
8652
8653 Thunk::BatchedDenseSolveF32 {
8654 a,
8655 b,
8656 x,
8657 batch,
8658 n,
8659 nrhs,
8660 } => {
8661 let (b_, n_, nrhs_) = (*batch as usize, *n as usize, *nrhs as usize);
8662 let a_stride = n_ * n_;
8663 let b_stride = n_ * nrhs_;
8664 unsafe {
8665 let a_full = sl(*a, base, b_ * a_stride);
8666 let b_full = sl(*b, base, b_ * b_stride);
8667 let x_full = sl_mut(*x, base, b_ * b_stride);
8668 for bi in 0..b_ {
8669 let mut a_scratch = a_full[bi * a_stride..(bi + 1) * a_stride].to_vec();
8670 let mut x_buf = b_full[bi * b_stride..(bi + 1) * b_stride].to_vec();
8671 let info = crate::blas::sgesv(&mut a_scratch, &mut x_buf, n_, nrhs_);
8672 if info != 0 {
8673 panic!("BatchedDenseSolveF32: slice {bi} singular (info={info})");
8674 }
8675 x_full[bi * b_stride..(bi + 1) * b_stride].copy_from_slice(&x_buf);
8676 }
8677 }
8678 }
8679
8680 Thunk::BatchedDgemmF64 {
8681 a,
8682 b,
8683 c,
8684 batch,
8685 m,
8686 k,
8687 n,
8688 } => {
8689 let (b_, m_, k_, n_) = (*batch as usize, *m as usize, *k as usize, *n as usize);
8690 let a_stride = m_ * k_;
8691 let b_stride = k_ * n_;
8692 let c_stride = m_ * n_;
8693 unsafe {
8694 let a_full = sl_f64(*a, base, b_ * a_stride);
8695 let b_full = sl_f64(*b, base, b_ * b_stride);
8696 let c_full = sl_mut_f64(*c, base, b_ * c_stride);
8697 for bi in 0..b_ {
8698 let a_slice = &a_full[bi * a_stride..(bi + 1) * a_stride];
8699 let b_slice = &b_full[bi * b_stride..(bi + 1) * b_stride];
8700 let c_slice = &mut c_full[bi * c_stride..(bi + 1) * c_stride];
8701 crate::blas::dgemm(a_slice, b_slice, c_slice, m_, k_, n_);
8702 }
8703 }
8704 }
8705
8706 Thunk::BatchedSgemm {
8707 a,
8708 b,
8709 c,
8710 batch,
8711 m,
8712 k,
8713 n,
8714 } => {
8715 let (b_, m_, k_, n_) = (*batch as usize, *m as usize, *k as usize, *n as usize);
8716 if trace_thunks {
8717 eprintln!(
8718 "[batched-sgemm] batch={b_} m={m_} k={k_} n={n_} a={} b={} c={}",
8719 *a, *b, *c
8720 );
8721 }
8722 let a_stride = m_.saturating_mul(k_);
8723 let b_stride = k_.saturating_mul(n_);
8724 let c_stride = m_.saturating_mul(n_);
8725 let arena_len = arena_buf.len();
8726 let a_cap = (arena_len.saturating_sub(*a)) / 4;
8727 let b_cap = (arena_len.saturating_sub(*b)) / 4;
8728 let c_cap = (arena_len.saturating_sub(*c)) / 4;
8729 let a_elems = (b_ * a_stride).min(a_cap);
8730 let b_elems = (b_ * b_stride).min(b_cap);
8731 let c_elems = (b_ * c_stride).min(c_cap);
8732 let b_eff = b_
8733 .min(a_elems.checked_div(a_stride).unwrap_or(0))
8734 .min(b_elems.checked_div(b_stride).unwrap_or(0))
8735 .min(c_elems.checked_div(c_stride).unwrap_or(0));
8736 unsafe {
8737 let a_full = sl(*a, base, a_elems);
8738 let b_full = sl(*b, base, b_elems);
8739 let c_full = sl_mut(*c, base, c_elems);
8740 for bi in 0..b_eff {
8741 let a0 = bi * a_stride;
8742 let b0 = bi * b_stride;
8743 let c0 = bi * c_stride;
8744 if a0 + a_stride > a_full.len()
8745 || b0 + b_stride > b_full.len()
8746 || c0 + c_stride > c_full.len()
8747 {
8748 break;
8749 }
8750 let a_slice = &a_full[a0..a0 + a_stride];
8751 let b_slice = &b_full[b0..b0 + b_stride];
8752 let c_slice = &mut c_full[c0..c0 + c_stride];
8753 if std::ptr::eq(a_slice.as_ptr(), c_slice.as_mut_ptr())
8754 || std::ptr::eq(b_slice.as_ptr(), c_slice.as_mut_ptr())
8755 {
8756 let mut tmp = vec![0.0f32; c_stride];
8757 crate::blas::sgemm_auto(a_slice, b_slice, &mut tmp, m_, k_, n_);
8758 c_slice.copy_from_slice(&tmp);
8759 } else {
8760 crate::blas::sgemm_auto(a_slice, b_slice, c_slice, m_, k_, n_);
8761 }
8762 }
8763 }
8764 }
8765
8766 Thunk::Dgemm { a, b, c, m, k, n } => {
8767 let (m, k, n) = (*m as usize, *k as usize, *n as usize);
8768 unsafe {
8769 crate::blas::dgemm(
8770 sl_f64(*a, base, m * k),
8771 sl_f64(*b, base, k * n),
8772 sl_mut_f64(*c, base, m * n),
8773 m,
8774 k,
8775 n,
8776 );
8777 }
8778 }
8779
8780 Thunk::TransposeF64 {
8781 src,
8782 dst,
8783 in_total,
8784 out_dims,
8785 in_strides,
8786 } => unsafe {
8787 let inp = sl_f64(*src, base, *in_total as usize);
8788 let out_total: usize = out_dims.iter().map(|d| *d as usize).product();
8789 let out = sl_mut_f64(*dst, base, out_total);
8790 transpose_walk_f64(inp, out, out_dims, in_strides);
8791 },
8792
8793 Thunk::ActivationF64 {
8794 src,
8795 dst,
8796 len,
8797 kind,
8798 } => {
8799 let len = *len as usize;
8800 unsafe {
8801 let inp = sl_f64(*src, base, len);
8802 let out = sl_mut_f64(*dst, base, len);
8803 apply_activation_f64(inp, out, *kind);
8804 }
8805 }
8806
8807 Thunk::ReduceSumF64 {
8808 src,
8809 dst,
8810 outer,
8811 reduced,
8812 inner,
8813 } => {
8814 let (o, r, n) = (*outer as usize, *reduced as usize, *inner as usize);
8815 unsafe {
8816 let inp = sl_f64(*src, base, o * r * n);
8817 let out = sl_mut_f64(*dst, base, o * n);
8818 reduce_sum_f64(inp, out, o, r, n);
8819 }
8820 }
8821
8822 Thunk::CopyF64 { src, dst, len } => {
8823 let mut len = *len as usize;
8824 if *src == *dst || len == 0 {
8825 continue;
8826 }
8827 let arena_len = arena_buf.len();
8828 let max_from_src = (arena_len.saturating_sub(*src)) / 8;
8829 let max_from_dst = (arena_len.saturating_sub(*dst)) / 8;
8830 len = len.min(max_from_src).min(max_from_dst);
8831 if len == 0 {
8832 continue;
8833 }
8834 let byte_len = len.saturating_mul(8);
8835 unsafe {
8836 std::ptr::copy(base.add(*src), base.add(*dst), byte_len);
8837 }
8838 }
8839
8840 Thunk::CopyI64 { src, dst, len } => {
8841 let mut len = *len as usize;
8842 if *src == *dst || len == 0 {
8843 continue;
8844 }
8845 let arena_len = arena_buf.len();
8846 let max_from_src = (arena_len.saturating_sub(*src)) / 8;
8847 let max_from_dst = (arena_len.saturating_sub(*dst)) / 8;
8848 len = len.min(max_from_src).min(max_from_dst);
8849 if len == 0 {
8850 continue;
8851 }
8852 let byte_len = len.saturating_mul(8);
8853 unsafe {
8854 std::ptr::copy(base.add(*src), base.add(*dst), byte_len);
8855 }
8856 }
8857
8858 Thunk::CastF32ToI64 { src, dst, len } => {
8859 let len = *len as usize;
8860 if len == 0 {
8861 continue;
8862 }
8863 unsafe {
8864 let inp = sl(*src, base, len);
8865 let out = sl_mut_i64(*dst, base, len);
8866 for i in 0..len {
8867 out[i] = inp[i].round() as i64;
8868 }
8869 }
8870 }
8871
8872 Thunk::CastF32ToF64 { src, dst, len } => {
8873 let len = *len as usize;
8874 if len == 0 {
8875 continue;
8876 }
8877 unsafe {
8878 let inp = sl(*src, base, len);
8879 let out = sl_mut_f64(*dst, base, len);
8880 for i in 0..len {
8881 out[i] = inp[i] as f64;
8882 }
8883 }
8884 }
8885
8886 Thunk::CastF32ToI32 { src, dst, len } => {
8887 let len = *len as usize;
8888 if len == 0 {
8889 continue;
8890 }
8891 unsafe {
8892 let inp = sl(*src, base, len);
8893 let out = sl_mut_i32(*dst, base, len);
8894 for i in 0..len {
8895 out[i] = inp[i].round() as i32;
8896 }
8897 }
8898 }
8899
8900 Thunk::CastI64ToF32 { src, dst, len } => {
8901 let len = *len as usize;
8902 if len == 0 {
8903 continue;
8904 }
8905 unsafe {
8906 let inp = sl_i64(*src, base, len);
8907 let out = sl_mut(*dst, base, len);
8908 for i in 0..len {
8909 out[i] = inp[i] as f32;
8910 }
8911 }
8912 }
8913
8914 Thunk::CastBoolToI32 { src, dst, len } => {
8915 let len = *len as usize;
8916 if len == 0 {
8917 continue;
8918 }
8919 unsafe {
8920 let inp = &arena_buf[*src..*src + len];
8921 let out = sl_mut_i32(*dst, base, len);
8922 for i in 0..len {
8923 out[i] = i32::from(inp[i] != 0);
8924 }
8925 }
8926 }
8927
8928 Thunk::CastI32ToF32 { src, dst, len } => {
8929 let len = *len as usize;
8930 if len == 0 {
8931 continue;
8932 }
8933 unsafe {
8934 let inp = sl_i32(*src, base, len);
8935 let out = sl_mut(*dst, base, len);
8936 for i in 0..len {
8937 out[i] = inp[i] as f32;
8938 }
8939 }
8940 }
8941
8942 Thunk::CastBoolToF32 { src, dst, len } => {
8943 let len = *len as usize;
8944 if len == 0 {
8945 continue;
8946 }
8947 unsafe {
8948 let inp = &arena_buf[*src..*src + len];
8949 let out = sl_mut(*dst, base, len);
8950 for i in 0..len {
8951 out[i] = if inp[i] != 0 { 1.0 } else { 0.0 };
8952 }
8953 }
8954 }
8955
Thunk::BinaryFullF64 {
    lhs,
    rhs,
    dst,
    len,
    lhs_len,
    rhs_len,
    op,
    out_dims_bcast,
    bcast_lhs_strides,
    bcast_rhs_strides,
} => {
    // Element-wise f64 binary op producing `len` outputs. `lhs_len`/`rhs_len`
    // are the operand element counts; when either differs from `len` one of
    // two broadcasting strategies below is used.
    let len = *len as usize;
    let lhs_len = *lhs_len as usize;
    let rhs_len = *rhs_len as usize;
    unsafe {
        let l = sl_f64(*lhs, base, lhs_len);
        let r = sl_f64(*rhs, base, rhs_len);
        let d = sl_mut_f64(*dst, base, len);
        if lhs_len == len && rhs_len == len {
            // Same element count on both sides: straight element-wise loop.
            for i in 0..len {
                d[i] = binary_op_f64(*op, l[i], r[i]);
            }
        } else if !out_dims_bcast.is_empty() {
            // Strided broadcast: unravel flat output index `i` into coordinates
            // over `out_dims_bcast` (last axis varies fastest), then dot them
            // with each operand's per-axis element strides to find its source
            // index. NOTE(review): a 0 stride presumably marks a broadcast
            // axis — confirm against the lowering that fills these vectors.
            let rank = out_dims_bcast.len();
            let mut coords = vec![0u32; rank];
            for i in 0..len {
                let mut rem = i;
                for ax in (0..rank).rev() {
                    let sz = out_dims_bcast[ax] as usize;
                    coords[ax] = (rem % sz) as u32;
                    rem /= sz;
                }
                let mut li: usize = 0;
                let mut ri: usize = 0;
                for ax in 0..rank {
                    li += coords[ax] as usize * bcast_lhs_strides[ax] as usize;
                    ri += coords[ax] as usize * bcast_rhs_strides[ax] as usize;
                }
                d[i] = binary_op_f64(*op, l[li], r[ri]);
            }
        } else {
            // Fallback without shape info: cyclic (modulo) reuse of the shorter
            // operand. NOTE(review): panics with a division by zero if an
            // operand is empty while `len > 0`.
            for i in 0..len {
                d[i] = binary_op_f64(*op, l[i % lhs_len], r[i % rhs_len]);
            }
        }
    }
}
9011
Thunk::BinaryFullC64 {
    lhs,
    rhs,
    dst,
    len,
    lhs_len,
    rhs_len,
    op,
    out_dims_bcast,
    bcast_lhs_strides,
    bcast_rhs_strides,
} => {
    // Element-wise complex binary op. Each C64 element is an interleaved
    // (re, im) f32 pair, so buffers span 2 * count f32 values. Lengths and
    // broadcast strides below are in complex elements.
    let n_out = *len as usize;
    let n_l = *lhs_len as usize;
    let n_r = *rhs_len as usize;
    unsafe {
        let l = sl(*lhs, base, 2 * n_l);
        let r = sl(*rhs, base, 2 * n_r);
        let d = sl_mut(*dst, base, 2 * n_out);
        // Scalar complex kernel on (re, im) pairs. Division uses the textbook
        // formula (a·conj(b)) / |b|^2, so a zero divisor yields inf/NaN.
        let do_c64 = |a_re: f32, a_im: f32, b_re: f32, b_im: f32| -> (f32, f32) {
            match op {
                BinaryOp::Add => (a_re + b_re, a_im + b_im),
                BinaryOp::Sub => (a_re - b_re, a_im - b_im),
                BinaryOp::Mul => (a_re * b_re - a_im * b_im, a_re * b_im + a_im * b_re),
                BinaryOp::Div => {
                    let denom = b_re * b_re + b_im * b_im;
                    (
                        (a_re * b_re + a_im * b_im) / denom,
                        (a_im * b_re - a_re * b_im) / denom,
                    )
                }
                BinaryOp::Max | BinaryOp::Min | BinaryOp::Pow => {
                    unreachable!("C64 max/min/pow rejected at lowering")
                }
            }
        };
        if n_l == n_out && n_r == n_out {
            // Same element count: straight element-wise loop.
            for i in 0..n_out {
                let (re, im) = do_c64(l[2 * i], l[2 * i + 1], r[2 * i], r[2 * i + 1]);
                d[2 * i] = re;
                d[2 * i + 1] = im;
            }
        } else if !out_dims_bcast.is_empty() {
            // Strided broadcast, same scheme as BinaryFullF64: unravel `i` over
            // `out_dims_bcast` (last axis fastest) and dot with per-operand
            // strides to get each operand's complex-element index.
            let rank = out_dims_bcast.len();
            let mut coords = vec![0u32; rank];
            for i in 0..n_out {
                let mut rem = i;
                for ax in (0..rank).rev() {
                    let sz = out_dims_bcast[ax] as usize;
                    coords[ax] = (rem % sz) as u32;
                    rem /= sz;
                }
                let mut li: usize = 0;
                let mut ri: usize = 0;
                for ax in 0..rank {
                    li += coords[ax] as usize * bcast_lhs_strides[ax] as usize;
                    ri += coords[ax] as usize * bcast_rhs_strides[ax] as usize;
                }
                let (re, im) =
                    do_c64(l[2 * li], l[2 * li + 1], r[2 * ri], r[2 * ri + 1]);
                d[2 * i] = re;
                d[2 * i + 1] = im;
            }
        } else {
            // Fallback: scalar operands repeat, otherwise modulo indexing.
            // NOTE(review): an empty operand with `n_out > 0` divides by zero.
            for i in 0..n_out {
                let li = if n_l == 1 { 0 } else { i % n_l };
                let ri = if n_r == 1 { 0 } else { i % n_r };
                let (re, im) =
                    do_c64(l[2 * li], l[2 * li + 1], r[2 * ri], r[2 * ri + 1]);
                d[2 * i] = re;
                d[2 * i + 1] = im;
            }
        }
    }
}
9096
9097 Thunk::ComplexNormSqF32 { src, dst, len } => {
9098 let n = *len as usize;
9099 unsafe {
9100 let s = sl(*src, base, 2 * n);
9101 let d = sl_mut(*dst, base, n);
9102 for i in 0..n {
9103 let re = s[2 * i];
9104 let im = s[2 * i + 1];
9105 d[i] = re * re + im * im;
9106 }
9107 }
9108 }
9109
9110 Thunk::ComplexNormSqBackwardF32 { z, g, dz, len } => {
9111 let n = *len as usize;
9114 unsafe {
9115 let zb = sl(*z, base, 2 * n);
9116 let gb = sl(*g, base, n);
9117 let db = sl_mut(*dz, base, 2 * n);
9118 for i in 0..n {
9119 let re = zb[2 * i];
9120 let im = zb[2 * i + 1];
9121 let gv = gb[i];
9122 db[2 * i] = gv * re;
9123 db[2 * i + 1] = gv * im;
9124 }
9125 }
9126 }
9127
9128 Thunk::ConjugateC64 { src, dst, len } => {
9129 let n = *len as usize;
9130 unsafe {
9131 let s = sl(*src, base, 2 * n);
9132 let d = sl_mut(*dst, base, 2 * n);
9133 for i in 0..n {
9134 d[2 * i] = s[2 * i];
9135 d[2 * i + 1] = -s[2 * i + 1];
9136 }
9137 }
9138 }
9139
9140 Thunk::ActivationC64 {
9141 src,
9142 dst,
9143 len,
9144 kind,
9145 } => {
9146 let n = *len as usize;
9147 unsafe {
9148 let s = sl(*src, base, 2 * n);
9149 let d = sl_mut(*dst, base, 2 * n);
9150 for i in 0..n {
9151 let a = s[2 * i];
9152 let b = s[2 * i + 1];
9153 let (re, im) = match kind {
9154 Activation::Neg => (-a, -b),
9155 Activation::Exp => {
9156 let ea = a.exp();
9158 (ea * b.cos(), ea * b.sin())
9159 }
9160 Activation::Log => {
9161 let r = (a * a + b * b).sqrt();
9163 (r.ln(), b.atan2(a))
9164 }
9165 Activation::Sqrt => {
9166 let r = (a * a + b * b).sqrt();
9169 let re = ((r + a) * 0.5).max(0.0).sqrt();
9170 let im_mag = ((r - a) * 0.5).max(0.0).sqrt();
9171 let im = if b >= 0.0 { im_mag } else { -im_mag };
9172 (re, im)
9173 }
9174 _ => unreachable!("non-C64 activation kind survived lowering"),
9175 };
9176 d[2 * i] = re;
9177 d[2 * i + 1] = im;
9178 }
9179 }
9180 }
9181
9182 Thunk::Scan {
9183 body,
9184 body_init,
9185 body_input_off,
9186 body_output_off,
9187 outer_init_off,
9188 outer_final_off,
9189 length,
9190 carry_bytes,
9191 save_trajectory,
9192 xs_inputs,
9193 bcast_inputs,
9194 num_checkpoints,
9195 } => {
9196 let cb = *carry_bytes as usize;
9197 let n_steps = *length as usize;
9198 let k_total = if *num_checkpoints == 0 || *num_checkpoints == *length {
9202 n_steps } else {
9204 *num_checkpoints as usize
9205 };
9206 let checkpoint_t_for_k = |k: usize| -> usize {
9207 if k_total == n_steps {
9208 k
9209 } else {
9210 ((k + 1) * n_steps)
9211 .div_ceil(k_total)
9212 .saturating_sub(1)
9213 .min(n_steps - 1)
9214 }
9215 };
9216 let mut next_k = 0usize;
9217
9218 let mut body_buf: Vec<u8> = (**body_init).clone();
9219 unsafe {
9220 std::ptr::copy_nonoverlapping(
9221 base.add(*outer_init_off),
9222 body_buf.as_mut_ptr().add(*body_input_off),
9223 cb,
9224 );
9225 for (body_b_off, outer_b_off, total_bytes) in bcast_inputs.iter() {
9229 std::ptr::copy_nonoverlapping(
9230 base.add(*outer_b_off),
9231 body_buf.as_mut_ptr().add(*body_b_off),
9232 *total_bytes as usize,
9233 );
9234 }
9235 }
9236 for t in 0..n_steps {
9237 for (body_x_off, outer_xs_off, per_step_bytes) in xs_inputs.iter() {
9238 let psb = *per_step_bytes as usize;
9239 unsafe {
9240 std::ptr::copy_nonoverlapping(
9241 base.add(*outer_xs_off + t * psb),
9242 body_buf.as_mut_ptr().add(*body_x_off),
9243 psb,
9244 );
9245 }
9246 }
9247
9248 execute_thunks(body, &mut body_buf);
9249
9250 if *save_trajectory && next_k < k_total && t == checkpoint_t_for_k(next_k) {
9251 unsafe {
9252 std::ptr::copy_nonoverlapping(
9253 body_buf.as_ptr().add(*body_output_off),
9254 base.add(*outer_final_off + next_k * cb),
9255 cb,
9256 );
9257 }
9258 next_k += 1;
9259 }
9260
9261 if *body_output_off != *body_input_off {
9262 body_buf
9263 .copy_within(*body_output_off..*body_output_off + cb, *body_input_off);
9264 }
9265 }
9266
9267 if !*save_trajectory {
9268 unsafe {
9270 std::ptr::copy_nonoverlapping(
9271 body_buf.as_ptr().add(*body_output_off),
9272 base.add(*outer_final_off),
9273 cb,
9274 );
9275 }
9276 }
9277 }
9278
Thunk::ScanBackward {
    body_vjp,
    body_init,
    body_carry_in_off,
    body_x_offs,
    body_d_output_off,
    body_dcarry_out_off,
    outer_init_off,
    outer_traj_off,
    outer_upstream_off,
    outer_xs_offs,
    outer_dinit_off,
    length,
    carry_bytes,
    save_trajectory,
    num_checkpoints,
    forward_body,
    forward_body_init,
    forward_body_carry_in_off,
    forward_body_output_off,
    forward_body_x_offs,
    carry_elem_size,
} => {
    // Reverse pass of a scan w.r.t. its initial carry. Steps run from
    // t = length-1 down to 0. At each step the VJP body (`body_vjp`) maps
    // (carry entering t, xs[t], dcarry) -> new dcarry, and the final dcarry is
    // written to `outer_dinit_off`.
    let cb = *carry_bytes as usize;
    let n_steps = *length as usize;
    let k_total = *num_checkpoints as usize;
    // Recursive (recompute) mode when only a subset of carries was stored.
    let is_recursive = k_total != 0 && k_total != n_steps;
    // Same checkpoint schedule as the forward Scan. Only called when
    // `is_recursive`, so `k_total > 0`.
    let checkpoint_t_for_k = |k: usize| -> usize {
        ((k + 1) * n_steps)
            .div_ceil(k_total)
            .saturating_sub(1)
            .min(n_steps - 1)
    };

    // Scratch arena for forward recomputation (only needed in recursive mode).
    let mut fwd_buf: Vec<u8> = if is_recursive {
        (**forward_body_init.as_ref().unwrap()).clone()
    } else {
        Vec::new()
    };

    // Running carry cotangent. Without a saved trajectory the upstream holds a
    // single cotangent for the final carry, so it seeds dcarry directly.
    let mut dcarry: Vec<u8> = vec![0u8; cb];
    if !*save_trajectory {
        unsafe {
            std::ptr::copy_nonoverlapping(
                base.add(*outer_upstream_off),
                dcarry.as_mut_ptr(),
                cb,
            );
        }
    }

    let mut body_buf: Vec<u8> = (**body_init).clone();

    // One reverse step t, given the carry that entered step t.
    let process_iter =
        |t: usize, carry_in: &[u8], dcarry: &mut Vec<u8>, body_buf: &mut Vec<u8>| {
            // With a saved trajectory the upstream holds one cotangent per step
            // (`cb` bytes at index t), which is accumulated element-wise into
            // dcarry. Only f32/f64 carries are handled.
            if *save_trajectory {
                unsafe {
                    let up_off = *outer_upstream_off + t * cb;
                    match *carry_elem_size {
                        4 => {
                            let up_ptr = base.add(up_off) as *const f32;
                            let dc_ptr = dcarry.as_mut_ptr() as *mut f32;
                            let n_elems = cb / 4;
                            for i in 0..n_elems {
                                *dc_ptr.add(i) += *up_ptr.add(i);
                            }
                        }
                        8 => {
                            let up_ptr = base.add(up_off) as *const f64;
                            let dc_ptr = dcarry.as_mut_ptr() as *mut f64;
                            let n_elems = cb / 8;
                            for i in 0..n_elems {
                                *dc_ptr.add(i) += *up_ptr.add(i);
                            }
                        }
                        other => panic!(
                            "ScanBackward: unsupported carry elem size {other} \
                             (only f32/f64 carries are supported today)"
                        ),
                    }
                }
            }
            // Load carry, step-t xs slices and the current dcarry into the VJP body.
            body_buf[*body_carry_in_off..*body_carry_in_off + cb]
                .copy_from_slice(carry_in);
            unsafe {
                for (i, body_x_off) in body_x_offs.iter().enumerate() {
                    let (outer_xs_off, per_step_bytes) = outer_xs_offs[i];
                    let psb = per_step_bytes as usize;
                    std::ptr::copy_nonoverlapping(
                        base.add(outer_xs_off + t * psb),
                        body_buf.as_mut_ptr().add(*body_x_off),
                        psb,
                    );
                }
                std::ptr::copy_nonoverlapping(
                    dcarry.as_ptr(),
                    body_buf.as_mut_ptr().add(*body_d_output_off),
                    cb,
                );
            }
            execute_thunks(body_vjp, body_buf);
            // The VJP's carry cotangent becomes dcarry for step t - 1.
            unsafe {
                std::ptr::copy_nonoverlapping(
                    body_buf.as_ptr().add(*body_dcarry_out_off),
                    dcarry.as_mut_ptr(),
                    cb,
                );
            }
        };

    if is_recursive {
        // Walk checkpoint segments from last to first. Each segment starts from
        // its anchor carry: the initial carry for segment 0, otherwise the
        // stored checkpoint `seg_k - 1`. `griewank_process_segment` is expected
        // to replay it forward and invoke `leaf_action` in reverse step order.
        // NOTE(review): that helper's exact contract is not visible here.
        // NOTE(review): with `length == 0` and `num_checkpoints != 0`,
        // `n_steps - 1` below underflows.
        let leaf_threshold = 4usize;
        let fb_sched = forward_body.as_ref().unwrap();
        let fb_init = forward_body_init.as_ref().unwrap().as_slice();
        let mut segment_end = n_steps - 1;
        for seg_k in (0..k_total).rev() {
            let segment_start = if seg_k == 0 {
                0
            } else {
                checkpoint_t_for_k(seg_k - 1) + 1
            };
            let mut anchor: Vec<u8> = vec![0u8; cb];
            unsafe {
                let src = if seg_k == 0 {
                    base.add(*outer_init_off)
                } else {
                    base.add(*outer_traj_off + (seg_k - 1) * cb)
                };
                std::ptr::copy_nonoverlapping(src, anchor.as_mut_ptr(), cb);
            }
            let mut leaf_action = |t: usize, carry_in: &[u8]| {
                process_iter(t, carry_in, &mut dcarry, &mut body_buf);
            };
            unsafe {
                griewank_process_segment(
                    segment_start,
                    segment_end,
                    &anchor,
                    cb,
                    fb_sched,
                    fb_init,
                    *forward_body_carry_in_off,
                    *forward_body_output_off,
                    forward_body_x_offs,
                    base,
                    outer_xs_offs,
                    &mut fwd_buf,
                    leaf_threshold,
                    &mut leaf_action,
                );
            }
            if seg_k == 0 {
                break;
            }
            segment_end = segment_start - 1;
        }
    } else {
        // Full trajectory stored: the carry entering step t is the initial carry
        // (t == 0) or trajectory slot t - 1.
        let mut carry_buf: Vec<u8> = vec![0u8; cb];
        for t in (0..n_steps).rev() {
            unsafe {
                let src = if t == 0 {
                    base.add(*outer_init_off)
                } else {
                    base.add(*outer_traj_off + (t - 1) * cb)
                };
                std::ptr::copy_nonoverlapping(src, carry_buf.as_mut_ptr(), cb);
            }
            process_iter(t, &carry_buf, &mut dcarry, &mut body_buf);
        }
    }

    // Publish the cotangent w.r.t. the initial carry.
    unsafe {
        std::ptr::copy_nonoverlapping(dcarry.as_ptr(), base.add(*outer_dinit_off), cb);
    }
}
9481
Thunk::ScanBackwardXs {
    body_vjp,
    body_init,
    body_carry_in_off,
    body_x_offs,
    body_d_output_off,
    body_dcarry_out_off,
    body_dxs_out_off,
    outer_init_off,
    outer_traj_off,
    outer_upstream_off,
    outer_xs_offs,
    outer_dxs_off,
    length,
    carry_bytes,
    carry_elem_size,
    per_step_bytes,
    save_trajectory,
    num_checkpoints,
    forward_body,
    forward_body_init,
    forward_body_carry_in_off,
    forward_body_output_off,
    forward_body_x_offs,
} => {
    // Reverse pass of a scan that emits the per-step cotangent of the scanned
    // input. For t = length-1 down to 0 the VJP body runs once, and its
    // `body_dxs_out_off` slot (`per_step_bytes`) is written to
    // `outer_dxs_off + t * per_step_bytes`.
    let cb = *carry_bytes as usize;
    let psb = *per_step_bytes as usize;
    let n_steps = *length as usize;
    let k_total = *num_checkpoints as usize;
    // Recursive (recompute) mode when only a subset of carries was stored.
    let is_recursive = k_total != 0 && k_total != n_steps;
    // Same checkpoint schedule as the forward Scan. Only used when recursive.
    let checkpoint_t_for_k = |k: usize| -> usize {
        ((k + 1) * n_steps)
            .div_ceil(k_total)
            .saturating_sub(1)
            .min(n_steps - 1)
    };

    let mut fwd_buf: Vec<u8> = if is_recursive {
        (**forward_body_init.as_ref().unwrap()).clone()
    } else {
        Vec::new()
    };
    // Cache of recomputed carries for one checkpoint segment: slot i holds the
    // carry entering step seg_start_t + i, for i < seg_count.
    let mut seg_cache: Vec<u8> = Vec::new();
    let mut seg_start_t: usize = usize::MAX;
    let mut seg_count: usize = 0;
    // Write the carry that entered step `t` into `dst`.
    let recompute_carry_t =
        |t: usize,
         dst: &mut [u8],
         fwd_buf: &mut Vec<u8>,
         seg_cache: &mut Vec<u8>,
         seg_start_t: &mut usize,
         seg_count: &mut usize| {
            // Full trajectory: read the initial carry (t == 0) or slot t - 1.
            if !is_recursive {
                unsafe {
                    let src = if t == 0 {
                        base.add(*outer_init_off)
                    } else {
                        base.add(*outer_traj_off + (t - 1) * cb)
                    };
                    std::ptr::copy_nonoverlapping(src, dst.as_mut_ptr(), cb);
                }
                return;
            }
            // Cache hit: t lies inside the segment replayed last time.
            if *seg_start_t != usize::MAX
                && t >= *seg_start_t
                && t < *seg_start_t + *seg_count
            {
                let off = (t - *seg_start_t) * cb;
                dst.copy_from_slice(&seg_cache[off..off + cb]);
                return;
            }
            // Find the segment containing t and its anchor (the initial carry, or
            // the stored checkpoint of the previous segment).
            let seg_k = (0..k_total)
                .find(|&k| t <= checkpoint_t_for_k(k))
                .unwrap_or(k_total - 1);
            let (anchor_t, anchor_ptr): (usize, *const u8) = if seg_k == 0 {
                (0, unsafe { base.add(*outer_init_off) as *const u8 })
            } else {
                let prev_ck = checkpoint_t_for_k(seg_k - 1);
                (prev_ck + 1, unsafe {
                    base.add(*outer_traj_off + (seg_k - 1) * cb) as *const u8
                })
            };
            let seg_end_t = checkpoint_t_for_k(seg_k);
            let seg_size = seg_end_t - anchor_t + 1;

            // Replay the segment forward from the anchor, caching each carry.
            fwd_buf.copy_from_slice(forward_body_init.as_ref().unwrap());
            unsafe {
                std::ptr::copy_nonoverlapping(
                    anchor_ptr,
                    fwd_buf.as_mut_ptr().add(*forward_body_carry_in_off),
                    cb,
                );
            }
            seg_cache.resize(seg_size * cb, 0u8);
            seg_cache[0..cb].copy_from_slice(
                &fwd_buf[*forward_body_carry_in_off..*forward_body_carry_in_off + cb],
            );
            let fb_sched = forward_body.as_ref().unwrap();
            for i in 1..seg_size {
                // Running forward step `cur_iter` yields the carry entering
                // step cur_iter + 1 = anchor_t + i.
                let cur_iter = anchor_t + i - 1;
                for (idx, fb_x_off) in forward_body_x_offs.iter().enumerate() {
                    let (outer_xs_off, x_psb) = outer_xs_offs[idx];
                    let xb = x_psb as usize;
                    unsafe {
                        std::ptr::copy_nonoverlapping(
                            base.add(outer_xs_off + cur_iter * xb),
                            fwd_buf.as_mut_ptr().add(*fb_x_off),
                            xb,
                        );
                    }
                }
                execute_thunks(fb_sched, fwd_buf);
                if *forward_body_output_off != *forward_body_carry_in_off {
                    fwd_buf.copy_within(
                        *forward_body_output_off..*forward_body_output_off + cb,
                        *forward_body_carry_in_off,
                    );
                }
                let cache_off = i * cb;
                seg_cache[cache_off..cache_off + cb].copy_from_slice(
                    &fwd_buf
                        [*forward_body_carry_in_off..*forward_body_carry_in_off + cb],
                );
            }
            *seg_start_t = anchor_t;
            *seg_count = seg_size;

            let off = (t - anchor_t) * cb;
            dst.copy_from_slice(&seg_cache[off..off + cb]);
        };

    // Running carry cotangent. Without a saved trajectory the single upstream
    // cotangent (for the final carry) seeds it.
    let mut dcarry: Vec<u8> = vec![0u8; cb];
    if !*save_trajectory {
        unsafe {
            std::ptr::copy_nonoverlapping(
                base.add(*outer_upstream_off),
                dcarry.as_mut_ptr(),
                cb,
            );
        }
    }

    let mut body_buf: Vec<u8> = (**body_init).clone();

    for t in (0..n_steps).rev() {
        // With a saved trajectory, accumulate step t's upstream cotangent
        // (f32 or f64 elements only).
        if *save_trajectory {
            unsafe {
                let up_off = *outer_upstream_off + t * cb;
                match *carry_elem_size {
                    4 => {
                        let up_ptr = base.add(up_off) as *const f32;
                        let dc_ptr = dcarry.as_mut_ptr() as *mut f32;
                        let n_elems = cb / 4;
                        for i in 0..n_elems {
                            *dc_ptr.add(i) += *up_ptr.add(i);
                        }
                    }
                    8 => {
                        let up_ptr = base.add(up_off) as *const f64;
                        let dc_ptr = dcarry.as_mut_ptr() as *mut f64;
                        let n_elems = cb / 8;
                        for i in 0..n_elems {
                            *dc_ptr.add(i) += *up_ptr.add(i);
                        }
                    }
                    other => panic!(
                        "ScanBackwardXs: unsupported carry elem size {other} \
                         (only f32/f64 carries are supported today)"
                    ),
                }
            }
        }

        // Place the carry entering step t directly into the VJP body's carry slot.
        let carry_dst_start = *body_carry_in_off;
        {
            let carry_slice = &mut body_buf[carry_dst_start..carry_dst_start + cb];
            recompute_carry_t(
                t,
                carry_slice,
                &mut fwd_buf,
                &mut seg_cache,
                &mut seg_start_t,
                &mut seg_count,
            );
        }
        // Load step-t xs slices and the current dcarry.
        unsafe {
            for (i, body_x_off) in body_x_offs.iter().enumerate() {
                let (outer_xs_off, x_psb) = outer_xs_offs[i];
                let xb = x_psb as usize;
                std::ptr::copy_nonoverlapping(
                    base.add(outer_xs_off + t * xb),
                    body_buf.as_mut_ptr().add(*body_x_off),
                    xb,
                );
            }
            std::ptr::copy_nonoverlapping(
                dcarry.as_ptr(),
                body_buf.as_mut_ptr().add(*body_d_output_off),
                cb,
            );
        }

        execute_thunks(body_vjp, &mut body_buf);

        // Emit the step-t input cotangent.
        unsafe {
            std::ptr::copy_nonoverlapping(
                body_buf.as_ptr().add(*body_dxs_out_off),
                base.add(*outer_dxs_off + t * psb),
                psb,
            );
        }

        // Carry cotangent for step t - 1.
        unsafe {
            std::ptr::copy_nonoverlapping(
                body_buf.as_ptr().add(*body_dcarry_out_off),
                dcarry.as_mut_ptr(),
                cb,
            );
        }
    }
}
9712
9713 Thunk::FusedMmBiasAct {
9714 a,
9715 w,
9716 bias,
9717 c,
9718 m,
9719 k,
9720 n,
9721 act,
9722 } => {
9723 let (m, k, n) = (*m as usize, *k as usize, *n as usize);
9724 unsafe {
9725 let out = sl_mut(*c, base, m * n);
9726 crate::blas::sgemm_auto(sl(*a, base, m * k), sl(*w, base, k * n), out, m, k, n);
9727 match act {
9728 Some(Activation::Gelu) => {
9729 crate::kernels::par_bias_gelu(out, sl(*bias, base, n), m, n)
9730 }
9731 Some(other) => {
9732 crate::blas::bias_add(out, sl(*bias, base, n), m, n);
9733 apply_activation_inplace(out, *other);
9734 }
9735 None => crate::blas::bias_add(out, sl(*bias, base, n), m, n),
9736 }
9737 }
9738 }
9739
// Fused residual + optional bias + LayerNorm over `rows` rows of width `h`,
// delegated per row-chunk to `kernels::residual_bias_layer_norm`. When
// `has_bias` is false the shared `zero_bias` buffer stands in for the bias.
Thunk::FusedResidualLN {
    x,
    res,
    bias,
    g,
    b,
    out,
    rows,
    h,
    eps,
    has_bias,
} => {
    let (rows, h) = (*rows as usize, *h as usize);
    unsafe {
        // NOTE(review): `zero_bias` is defined outside this view; this assumes
        // it holds at least `h` zeros for every `h` seen here — confirm.
        let zero = &zero_bias[..h];
        let bi = if *has_bias { sl(*bias, base, h) } else { zero };
        // Addresses are carried into the `par_for` closure as plain `usize`
        // (presumably because raw pointers are not Send/Sync); each worker
        // rebuilds slices for its own row range [off, off + cnt).
        let x_ptr = sl(*x, base, rows * h).as_ptr() as usize;
        let r_ptr = sl(*res, base, rows * h).as_ptr() as usize;
        let o_ptr = sl_mut(*out, base, rows * h).as_mut_ptr() as usize;
        let bi_ptr = bi.as_ptr() as usize;
        let g_ptr = sl(*g, base, h).as_ptr() as usize;
        let b_ptr = sl(*b, base, h).as_ptr() as usize;
        let e = *eps;
        crate::pool::par_for(rows, 4, &|off, cnt| {
            let xs =
                std::slice::from_raw_parts((x_ptr as *const f32).add(off * h), cnt * h);
            let rs =
                std::slice::from_raw_parts((r_ptr as *const f32).add(off * h), cnt * h);
            let os = std::slice::from_raw_parts_mut(
                (o_ptr as *mut f32).add(off * h),
                cnt * h,
            );
            // Per-feature vectors (length `h`) are shared by every row.
            let bi = std::slice::from_raw_parts(bi_ptr as *const f32, h);
            let g = std::slice::from_raw_parts(g_ptr as *const f32, h);
            let b = std::slice::from_raw_parts(b_ptr as *const f32, h);
            crate::kernels::residual_bias_layer_norm(xs, rs, bi, g, b, os, cnt, h, e);
        });
    }
}
9779
// Fused residual + optional bias + RMSNorm over `rows` rows of width `h`,
// delegated per row-chunk to `kernels::residual_bias_rms_norm`. When
// `has_bias` is false the shared `zero_bias` buffer stands in for the bias.
Thunk::FusedResidualRmsNorm {
    x,
    res,
    bias,
    g,
    b,
    out,
    rows,
    h,
    eps,
    has_bias,
} => {
    let (rows, h) = (*rows as usize, *h as usize);
    unsafe {
        // NOTE(review): `zero_bias` is defined outside this view; this assumes
        // it holds at least `h` zeros for every `h` seen here — confirm.
        let zero = &zero_bias[..h];
        let bi = if *has_bias { sl(*bias, base, h) } else { zero };
        // Addresses are carried into the `par_for` closure as plain `usize`
        // (presumably because raw pointers are not Send/Sync); each worker
        // rebuilds slices for its own row range [off, off + cnt).
        let x_ptr = sl(*x, base, rows * h).as_ptr() as usize;
        let r_ptr = sl(*res, base, rows * h).as_ptr() as usize;
        let o_ptr = sl_mut(*out, base, rows * h).as_mut_ptr() as usize;
        let bi_ptr = bi.as_ptr() as usize;
        let g_ptr = sl(*g, base, h).as_ptr() as usize;
        let b_ptr = sl(*b, base, h).as_ptr() as usize;
        let e = *eps;
        crate::pool::par_for(rows, 4, &|off, cnt| {
            let xs =
                std::slice::from_raw_parts((x_ptr as *const f32).add(off * h), cnt * h);
            let rs =
                std::slice::from_raw_parts((r_ptr as *const f32).add(off * h), cnt * h);
            let os = std::slice::from_raw_parts_mut(
                (o_ptr as *mut f32).add(off * h),
                cnt * h,
            );
            // Per-feature vectors (length `h`) are shared by every row.
            let bi = std::slice::from_raw_parts(bi_ptr as *const f32, h);
            let g = std::slice::from_raw_parts(g_ptr as *const f32, h);
            let b = std::slice::from_raw_parts(b_ptr as *const f32, h);
            crate::kernels::residual_bias_rms_norm(xs, rs, bi, g, b, os, cnt, h, e);
        });
    }
}
9819
9820 Thunk::BiasAdd {
9821 src,
9822 bias,
9823 dst,
9824 m,
9825 n,
9826 } => {
9827 let (m, n) = (*m as usize, *n as usize);
9828 let len = m * n;
9829 unsafe {
9830 let out = sl_mut(*dst, base, len);
9831 if *src != *dst {
9832 let src_ptr = base.add(*src) as *const f32;
9833 let dst_ptr = base.add(*dst) as *mut f32;
9834 if src_ptr != dst_ptr {
9835 std::ptr::copy_nonoverlapping(src_ptr, dst_ptr, len);
9836 }
9837 }
9838 crate::blas::bias_add(out, sl(*bias, base, n), m, n);
9839 }
9840 }
9841
9842 Thunk::BinaryFull {
9843 lhs,
9844 rhs,
9845 dst,
9846 len,
9847 lhs_len,
9848 rhs_len,
9849 op,
9850 out_dims_bcast,
9851 bcast_lhs_strides,
9852 bcast_rhs_strides,
9853 elem_bytes,
9854 } => {
9855 let len = *len as usize;
9856 let ll = (*lhs_len as usize).max(1);
9857 let rl = (*rhs_len as usize).max(1);
9858 let eb = (*elem_bytes).max(1) as usize;
9859 let arena_len = arena_buf.len();
9860 let ll = ll.min((arena_len.saturating_sub(*lhs)) / eb);
9861 let rl = rl.min((arena_len.saturating_sub(*rhs)) / eb);
9862 let len = len.min((arena_len.saturating_sub(*dst)) / eb);
9863 unsafe {
9864 if eb == 8 {
9865 let l = sl_i64(*lhs, base, ll);
9866 let r = sl_i64(*rhs, base, rl);
9867 let o = sl_mut_i64(*dst, base, len);
9868 if !out_dims_bcast.is_empty() {
9869 let rank = out_dims_bcast.len();
9870 let mut coords = vec![0u32; rank];
9871 for i in 0..len {
9872 let mut rem = i;
9873 for ax in (0..rank).rev() {
9874 let sz = out_dims_bcast[ax] as usize;
9875 coords[ax] = (rem % sz) as u32;
9876 rem /= sz;
9877 }
9878 let mut li = 0usize;
9879 let mut ri = 0usize;
9880 for ax in 0..rank {
9881 li += coords[ax] as usize * bcast_lhs_strides[ax] as usize;
9882 ri += coords[ax] as usize * bcast_rhs_strides[ax] as usize;
9883 }
9884 o[i] = match op {
9885 BinaryOp::Add => l[li].wrapping_add(r[ri]),
9886 BinaryOp::Sub => l[li].wrapping_sub(r[ri]),
9887 BinaryOp::Mul => l[li].wrapping_mul(r[ri]),
9888 BinaryOp::Div => {
9889 if r[ri] == 0 {
9890 0
9891 } else {
9892 l[li] / r[ri]
9893 }
9894 }
9895 BinaryOp::Max => l[li].max(r[ri]),
9896 BinaryOp::Min => l[li].min(r[ri]),
9897 BinaryOp::Pow => l[li].pow(r[ri].max(0) as u32),
9898 };
9899 }
9900 } else {
9901 for i in 0..len {
9902 let li = if ll == 1 { 0 } else { i % ll };
9903 let ri = if rl == 1 { 0 } else { i % rl };
9904 o[i] = match op {
9905 BinaryOp::Add => l[li].wrapping_add(r[ri]),
9906 BinaryOp::Sub => l[li].wrapping_sub(r[ri]),
9907 BinaryOp::Mul => l[li].wrapping_mul(r[ri]),
9908 BinaryOp::Div => {
9909 if r[ri] == 0 {
9910 0
9911 } else {
9912 l[li] / r[ri]
9913 }
9914 }
9915 BinaryOp::Max => l[li].max(r[ri]),
9916 BinaryOp::Min => l[li].min(r[ri]),
9917 BinaryOp::Pow => l[li].pow(r[ri].max(0) as u32),
9918 };
9919 }
9920 }
9921 } else {
9922 let l = sl(*lhs, base, ll);
9923 let r = sl(*rhs, base, rl);
9924 let o = sl_mut(*dst, base, len);
9925 if ll == len && rl == len {
9926 #[cfg(target_arch = "aarch64")]
9927 if matches!(op, BinaryOp::Add | BinaryOp::Mul) {
9928 use std::arch::aarch64::*;
9929 let chunks = len / 4;
9930 for c in 0..chunks {
9931 let off = c * 4;
9932 let vl = vld1q_f32(l.as_ptr().add(off));
9933 let vr = vld1q_f32(r.as_ptr().add(off));
9934 let res = match op {
9935 BinaryOp::Add => vaddq_f32(vl, vr),
9936 BinaryOp::Mul => vmulq_f32(vl, vr),
9937 _ => unreachable!(),
9938 };
9939 vst1q_f32(o.as_mut_ptr().add(off), res);
9940 }
9941 for i in (chunks * 4)..len {
9942 o[i] = match op {
9943 BinaryOp::Add => l[i] + r[i],
9944 BinaryOp::Mul => l[i] * r[i],
9945 _ => unreachable!(),
9946 };
9947 }
9948 continue;
9949 }
9950 }
9951 if !out_dims_bcast.is_empty() {
9952 let rank = out_dims_bcast.len();
9953 let mut coords = vec![0u32; rank];
9954 for i in 0..len {
9955 let mut rem = i;
9956 for ax in (0..rank).rev() {
9957 let sz = out_dims_bcast[ax] as usize;
9958 coords[ax] = (rem % sz) as u32;
9959 rem /= sz;
9960 }
9961 let mut li = 0usize;
9962 let mut ri = 0usize;
9963 for ax in 0..rank {
9964 li += coords[ax] as usize * bcast_lhs_strides[ax] as usize;
9965 ri += coords[ax] as usize * bcast_rhs_strides[ax] as usize;
9966 }
9967 o[i] = match op {
9968 BinaryOp::Add => l[li] + r[ri],
9969 BinaryOp::Sub => l[li] - r[ri],
9970 BinaryOp::Mul => l[li] * r[ri],
9971 BinaryOp::Div => l[li] / r[ri],
9972 BinaryOp::Max => l[li].max(r[ri]),
9973 BinaryOp::Min => l[li].min(r[ri]),
9974 BinaryOp::Pow => l[li].powf(r[ri]),
9975 };
9976 }
9977 } else {
9978 for i in 0..len {
9979 let li = if ll == 1 { 0 } else { i % ll };
9980 let ri = if rl == 1 { 0 } else { i % rl };
9981 o[i] = match op {
9982 BinaryOp::Add => l[li] + r[ri],
9983 BinaryOp::Sub => l[li] - r[ri],
9984 BinaryOp::Mul => l[li] * r[ri],
9985 BinaryOp::Div => l[li] / r[ri],
9986 BinaryOp::Max => l[li].max(r[ri]),
9987 BinaryOp::Min => l[li].min(r[ri]),
9988 BinaryOp::Pow => l[li].powf(r[ri]),
9989 };
9990 }
9991 }
9992 }
9993 }
9994 }
9995
9996 Thunk::Gather {
9997 table,
9998 table_len,
9999 idx,
10000 dst,
10001 num_idx,
10002 trailing,
10003 idx_i64,
10004 table_bytes,
10005 } => {
10006 let (ni, tr) = (*num_idx as usize, *trailing as usize);
10007 let rows = *table_len as usize / tr.max(1);
10008 unsafe {
10009 if *table_bytes == 8 {
10010 let tab = sl_i64(*table, base, *table_len as usize);
10011 let out = sl_mut_i64(*dst, base, ni * tr);
10012 if *idx_i64 != 0 {
10013 let ids = sl_i64(*idx, base, ni);
10014 for i in 0..ni {
10015 let row = ids[i].max(0) as usize;
10016 if row < rows {
10017 out[i * tr..(i + 1) * tr]
10018 .copy_from_slice(&tab[row * tr..(row + 1) * tr]);
10019 }
10020 }
10021 } else {
10022 let ids = sl(*idx, base, ni);
10023 for i in 0..ni {
10024 let row = ids[i] as usize;
10025 if row < rows {
10026 out[i * tr..(i + 1) * tr]
10027 .copy_from_slice(&tab[row * tr..(row + 1) * tr]);
10028 }
10029 }
10030 }
10031 } else {
10032 let tab = sl(*table, base, *table_len as usize);
10033 let out = sl_mut(*dst, base, ni * tr);
10034 if *idx_i64 != 0 {
10035 let ids = sl_i64(*idx, base, ni);
10036 for i in 0..ni {
10037 let row = ids[i].max(0) as usize;
10038 if row < rows {
10039 out[i * tr..(i + 1) * tr]
10040 .copy_from_slice(&tab[row * tr..(row + 1) * tr]);
10041 }
10042 }
10043 } else {
10044 let ids = sl(*idx, base, ni);
10045 for i in 0..ni {
10046 let row = ids[i] as usize;
10047 if row < rows {
10048 out[i * tr..(i + 1) * tr]
10049 .copy_from_slice(&tab[row * tr..(row + 1) * tr]);
10050 }
10051 }
10052 }
10053 }
10054 }
10055 }
10056
10057 Thunk::Narrow {
10058 src,
10059 dst,
10060 outer,
10061 src_stride,
10062 dst_stride,
10063 inner,
10064 elem_bytes,
10065 } => {
10066 let (outer, ss, ds, inner, eb) = (
10067 *outer as usize,
10068 *src_stride as usize,
10069 *dst_stride as usize,
10070 *inner as usize,
10071 *elem_bytes as usize,
10072 );
10073 let row_bytes = inner.saturating_mul(eb);
10074 let src_row_stride = ss.saturating_mul(eb);
10075 let dst_row_stride = ds.saturating_mul(eb);
10076 if trace_thunks {
10077 eprintln!(
10078 "[narrow] src={} dst={} outer={outer} ss={ss} ds={ds} inner={inner} eb={eb} row={row_bytes} arena={}",
10079 *src,
10080 *dst,
10081 arena_buf.len()
10082 );
10083 }
10084 if row_bytes > 0 && *src != *dst {
10085 let arena_len = arena_buf.len();
10086 for o in 0..outer {
10087 let s_off = *src + o * src_row_stride;
10088 let d_off = *dst + o * dst_row_stride;
10089 if s_off == d_off {
10090 continue;
10091 }
10092 if s_off.saturating_add(row_bytes) > arena_len
10093 || d_off.saturating_add(row_bytes) > arena_len
10094 {
10095 break;
10096 }
10097 unsafe {
10098 std::ptr::copy_nonoverlapping(
10099 base.add(s_off),
10100 base.add(d_off),
10101 row_bytes,
10102 );
10103 }
10104 }
10105 }
10106 }
10107
10108 Thunk::Copy { src, dst, len } => {
10109 let mut len = *len as usize;
10110 if *src == *dst || len == 0 {
10111 continue;
10112 }
10113 let arena_len = arena_buf.len();
10114 let max_from_src = (arena_len.saturating_sub(*src)) / 4;
10115 let max_from_dst = (arena_len.saturating_sub(*dst)) / 4;
10116 len = len.min(max_from_src).min(max_from_dst);
10117 if len == 0 {
10118 continue;
10119 }
10120 let byte_len = len.saturating_mul(4);
10121 unsafe {
10122 std::ptr::copy(base.add(*src), base.add(*dst), byte_len);
10123 }
10124 }
10125
// Row-wise LayerNorm over `rows` rows of width `h` using
// `kernels::layer_norm_row`. Large inputs (at least 4 rows and at least
// 30_000 elements) are split across the worker pool; smaller ones run serially.
Thunk::LayerNorm {
    src,
    g,
    b,
    dst,
    rows,
    h,
    eps,
} => {
    let (rows, h) = (*rows as usize, *h as usize);
    unsafe {
        let input = sl(*src, base, rows * h);
        let gamma = sl(*g, base, h);
        let beta = sl(*b, base, h);
        let output = sl_mut(*dst, base, rows * h);
        // Heuristic cutoff below which threading overhead is presumably not
        // worth it — tuning constant, not derived here.
        if rows >= 4 && rows * h >= 30_000 {
            // Addresses travel into the closure as `usize` (presumably because
            // raw pointers are not Send/Sync); each worker rebuilds slices for
            // its own row range [off, off + cnt).
            let i_ptr = input.as_ptr() as usize;
            let o_ptr = output.as_mut_ptr() as usize;
            let g_ptr = gamma.as_ptr() as usize;
            let b_ptr = beta.as_ptr() as usize;
            let e = *eps;
            crate::pool::par_for(rows, 4, &|off, cnt| {
                let inp = std::slice::from_raw_parts(
                    (i_ptr as *const f32).add(off * h),
                    cnt * h,
                );
                let out = std::slice::from_raw_parts_mut(
                    (o_ptr as *mut f32).add(off * h),
                    cnt * h,
                );
                let g = std::slice::from_raw_parts(g_ptr as *const f32, h);
                let b = std::slice::from_raw_parts(b_ptr as *const f32, h);
                for row in 0..cnt {
                    crate::kernels::layer_norm_row(
                        &inp[row * h..(row + 1) * h],
                        g,
                        b,
                        &mut out[row * h..(row + 1) * h],
                        h,
                        e,
                    );
                }
            });
        } else {
            for row in 0..rows {
                crate::kernels::layer_norm_row(
                    &input[row * h..(row + 1) * h],
                    gamma,
                    beta,
                    &mut output[row * h..(row + 1) * h],
                    h,
                    *eps,
                );
            }
        }
    }
}
10184
10185 Thunk::GroupNorm {
10186 src,
10187 g,
10188 b,
10189 dst,
10190 n,
10191 c,
10192 h,
10193 w,
10194 num_groups,
10195 eps,
10196 } => {
10197 let (n, c, h, w) = (*n as usize, *c as usize, *h as usize, *w as usize);
10198 let plane = c * h * w;
10199 unsafe {
10200 for ni in 0..n {
10201 let input = sl(*src, base.add(ni * plane), plane);
10202 let gamma = sl(*g, base, c);
10203 let beta = sl(*b, base, c);
10204 let output = sl_mut(*dst, base.add(ni * plane), plane);
10205 crate::kernels::group_norm_nchw(
10206 input,
10207 gamma,
10208 beta,
10209 output,
10210 1,
10211 c,
10212 h,
10213 w,
10214 *num_groups as usize,
10215 *eps,
10216 );
10217 }
10218 }
10219 }
10220
10221 Thunk::BatchNormInference {
10222 src,
10223 g,
10224 b,
10225 mean,
10226 var,
10227 dst,
10228 count,
10229 channels,
10230 eps,
10231 } => {
10232 let count = *count as usize;
10233 let c = *channels as usize;
10234 let n = count * c;
10235 unsafe {
10236 crate::kernels::batch_norm_inference(
10237 sl(*src, base, n),
10238 sl(*g, base, c),
10239 sl(*b, base, c),
10240 sl(*mean, base, c),
10241 sl(*var, base, c),
10242 sl_mut(*dst, base, n),
10243 c,
10244 *eps,
10245 );
10246 }
10247 }
10248
10249 Thunk::LayerNorm2d {
10250 src,
10251 g,
10252 b,
10253 dst,
10254 n,
10255 c,
10256 h,
10257 w,
10258 eps,
10259 } => {
10260 let (n, c, h, w) = (*n as usize, *c as usize, *h as usize, *w as usize);
10261 let plane = c * h * w;
10262 unsafe {
10263 let input = sl(*src, base, n * plane);
10264 let gamma = sl(*g, base, c);
10265 let beta = sl(*b, base, c);
10266 let output = sl_mut(*dst, base, n * plane);
10267 crate::kernels::layer_norm2d_nchw(input, gamma, beta, output, n, c, h, w, *eps);
10268 }
10269 }
10270
10271 Thunk::ConvTranspose2d {
10272 src,
10273 weight,
10274 dst,
10275 n,
10276 c_in,
10277 h,
10278 w_in,
10279 c_out,
10280 h_out,
10281 w_out,
10282 kh,
10283 kw,
10284 sh,
10285 sw,
10286 ph,
10287 pw,
10288 dh,
10289 dw,
10290 groups,
10291 } => {
10292 let n = *n as usize;
10293 let c_in = *c_in as usize;
10294 let h = *h as usize;
10295 let w_in = *w_in as usize;
10296 let c_out = *c_out as usize;
10297 let h_out = *h_out as usize;
10298 let w_out = *w_out as usize;
10299 unsafe {
10300 let inp = sl(*src, base, n * c_in * h * w_in);
10301 let wt = sl(
10302 *weight,
10303 base,
10304 c_in * (c_out / *groups as usize) * (*kh as usize) * (*kw as usize),
10305 );
10306 let out = sl_mut(*dst, base, n * c_out * h_out * w_out);
10307 crate::kernels::conv_transpose2d_nchw(
10308 inp,
10309 wt,
10310 out,
10311 n,
10312 c_in,
10313 h,
10314 w_in,
10315 c_out,
10316 h_out,
10317 w_out,
10318 *kh as usize,
10319 *kw as usize,
10320 *sh as usize,
10321 *sw as usize,
10322 *ph as usize,
10323 *pw as usize,
10324 *dh as usize,
10325 *dw as usize,
10326 *groups as usize,
10327 );
10328 }
10329 }
10330
10331 Thunk::ResizeNearest2x {
10332 src,
10333 dst,
10334 n,
10335 c,
10336 h,
10337 w,
10338 } => {
10339 let (n, c, h, w) = (*n as usize, *c as usize, *h as usize, *w as usize);
10340 let in_plane = c * h * w;
10341 let out_plane = c * h * 2 * w * 2;
10342 unsafe {
10343 for ni in 0..n {
10344 let input = sl(*src, base.add(ni * in_plane), in_plane);
10345 let output = sl_mut(*dst, base.add(ni * out_plane), out_plane);
10346 crate::kernels::resize_nearest_2x_nchw(input, output, c, h, w);
10347 }
10348 }
10349 }
10350
10351 Thunk::AxialRope2d {
10352 src,
10353 dst,
10354 batch,
10355 seq,
10356 hidden,
10357 end_x,
10358 end_y,
10359 head_dim,
10360 num_heads,
10361 theta,
10362 repeat_factor,
10363 } => {
10364 let b = *batch as usize;
10365 let s = *seq as usize;
10366 let hdim = *head_dim as usize;
10367 let nh = *num_heads as usize;
10368 let plane = s * (*hidden as usize);
10369 unsafe {
10370 for bi in 0..b {
10371 let input = sl(*src, base.add(bi * plane), plane);
10372 let output = sl_mut(*dst, base.add(bi * plane), plane);
10373 let rotated = rlx_ir::ops::axial_rope2d::apply_axial_rope2d(
10374 input,
10375 nh,
10376 s,
10377 hdim,
10378 *end_x as usize,
10379 *end_y as usize,
10380 *theta,
10381 *repeat_factor as usize,
10382 );
10383 output.copy_from_slice(&rotated);
10384 }
10385 }
10386 }
10387
10388 Thunk::RmsNorm {
10389 src,
10390 g,
10391 b,
10392 dst,
10393 rows,
10394 h,
10395 eps,
10396 } => {
10397 let (rows, h) = (*rows as usize, *h as usize);
10398 unsafe {
10399 let input = sl(*src, base, rows * h);
10400 let gamma = sl(*g, base, h);
10401 let beta = sl(*b, base, h);
10402 let output = sl_mut(*dst, base, rows * h);
10403 let inv_h = 1.0 / h as f32;
10404 for row in 0..rows {
10405 let in_row = &input[row * h..(row + 1) * h];
10406 let out_row = &mut output[row * h..(row + 1) * h];
10407 let mut sumsq = 0f32;
10409 for &v in in_row {
10410 sumsq += v * v;
10411 }
10412 let inv_rms = (sumsq * inv_h + *eps).sqrt().recip();
10413 for i in 0..h {
10414 out_row[i] = in_row[i] * inv_rms * gamma[i] + beta[i];
10415 }
10416 }
10417 }
10418 }
10419
10420 Thunk::Softmax { data, rows, cols } => {
10421 let (rows, cols) = (*rows as usize, *cols as usize);
10422 unsafe {
10423 crate::kernels::neon_softmax(sl_mut(*data, base, rows * cols), rows, cols);
10424 }
10425 }
10426
10427 Thunk::Cumsum {
10428 src,
10429 dst,
10430 rows,
10431 cols,
10432 exclusive,
10433 } => {
10434 let (rows, cols) = (*rows as usize, *cols as usize);
10435 unsafe {
10436 let s = sl(*src, base, rows * cols);
10437 let d = sl_mut(*dst, base, rows * cols);
10438 if *exclusive {
10439 for r in 0..rows {
10440 let mut acc = 0.0f32;
10441 for c in 0..cols {
10442 d[r * cols + c] = acc;
10443 acc += s[r * cols + c];
10444 }
10445 }
10446 } else {
10447 for r in 0..rows {
10448 let mut acc = 0.0f32;
10449 for c in 0..cols {
10450 acc += s[r * cols + c];
10451 d[r * cols + c] = acc;
10452 }
10453 }
10454 }
10455 }
10456 }
10457
10458 Thunk::Sample {
10459 logits,
10460 dst,
10461 batch,
10462 vocab,
10463 top_k,
10464 top_p,
10465 temperature,
10466 seed,
10467 } => {
10468 let (b, v) = (*batch as usize, *vocab as usize);
10469 let k = (*top_k as usize).min(v);
10470 unsafe {
10471 let lg = sl(*logits, base, b * v);
10472 let out = sl_mut(*dst, base, b);
10473 let mut rng =
10474 rlx_ir::Philox4x32::new(if *seed == 0 { 0xDEADBEEF } else { *seed });
10475 for bi in 0..b {
10476 let row = &lg[bi * v..(bi + 1) * v];
10477 out[bi] = sample_row(row, k, *top_p, *temperature, &mut rng) as f32;
10478 }
10479 }
10480 }
10481
10482 Thunk::RngNormal {
10483 dst,
10484 len,
10485 mean,
10486 scale,
10487 key,
10488 op_seed,
10489 } => {
10490 let n = *len as usize;
10491 unsafe {
10492 let out = sl_mut(*dst, base, n);
10493 let opts = *schedule.rng.read().unwrap();
10494 rlx_ir::fill_normal_like(out, *mean, *scale, opts, *key, *op_seed);
10495 }
10496 }
10497
10498 Thunk::RngUniform {
10499 dst,
10500 len,
10501 low,
10502 high,
10503 key,
10504 op_seed,
10505 } => {
10506 let n = *len as usize;
10507 unsafe {
10508 let out = sl_mut(*dst, base, n);
10509 let opts = *schedule.rng.read().unwrap();
10510 rlx_ir::fill_uniform_like(out, *low, *high, opts, *key, *op_seed);
10511 }
10512 }
10513
10514 Thunk::GatedDeltaNet {
10515 q,
10516 k,
10517 v,
10518 g,
10519 beta,
10520 state,
10521 dst,
10522 batch,
10523 seq,
10524 heads,
10525 state_size,
10526 } => unsafe {
10527 execute_gated_delta_net_f32(
10528 *q,
10529 *k,
10530 *v,
10531 *g,
10532 *beta,
10533 *state,
10534 *dst,
10535 *batch as usize,
10536 *seq as usize,
10537 *heads as usize,
10538 *state_size as usize,
10539 base,
10540 );
10541 },
10542
10543 Thunk::Lstm {
10544 x,
10545 w_ih,
10546 w_hh,
10547 bias,
10548 h0,
10549 c0,
10550 dst,
10551 batch,
10552 seq,
10553 input_size,
10554 hidden,
10555 num_layers,
10556 bidirectional,
10557 carry,
10558 } => unsafe {
10559 execute_lstm_f32(
10560 *x,
10561 *w_ih,
10562 *w_hh,
10563 *bias,
10564 *h0,
10565 *c0,
10566 *dst,
10567 *batch as usize,
10568 *seq as usize,
10569 *input_size as usize,
10570 *hidden as usize,
10571 *num_layers as usize,
10572 *bidirectional,
10573 *carry,
10574 base,
10575 );
10576 },
10577
// Sequential selective state-space scan (f32). Per batch element the
// `hidden x state_size` state starts at zero, and for every step `si` and
// channel `ci`:
//   state = exp(delta * A) * state + delta * B * x
//   out   = sum_n C * state
// x/delta are `[batch, seq, hidden]`, B/C are `[batch, seq, state_size]` and
// A is `[hidden, state_size]`, per the indexing below.
Thunk::SelectiveScan {
    x,
    delta,
    a,
    b: bp,
    c: cp,
    dst,
    batch,
    seq,
    hidden,
    state_size,
} => {
    let (b, s, h, n) = (
        *batch as usize,
        *seq as usize,
        *hidden as usize,
        *state_size as usize,
    );
    unsafe {
        let xs = sl(*x, base, b * s * h);
        let dt = sl(*delta, base, b * s * h);
        let am = sl(*a, base, h * n);
        let bm = sl(*bp, base, b * s * n);
        let cm = sl(*cp, base, b * s * n);
        let out = sl_mut(*dst, base, b * s * h);

        // One state buffer reused across batch elements (reset below).
        let mut state = vec![0f32; h * n];
        for bi in 0..b {
            for v in state.iter_mut() {
                *v = 0.0;
            }
            for si in 0..s {
                let x_row = &xs[bi * s * h + si * h..bi * s * h + (si + 1) * h];
                let dt_row = &dt[bi * s * h + si * h..bi * s * h + (si + 1) * h];
                let b_row = &bm[bi * s * n + si * n..bi * s * n + (si + 1) * n];
                let c_row = &cm[bi * s * n + si * n..bi * s * n + (si + 1) * n];
                let out_row = &mut out[bi * s * h + si * h..bi * s * h + (si + 1) * h];

                for ci in 0..h {
                    let d = dt_row[ci];
                    let xv = x_row[ci];
                    let mut acc = 0f32;
                    for ni in 0..n {
                        // Discretized decay for this (channel, state) pair.
                        let da = (d * am[ci * n + ni]).exp();
                        state[ci * n + ni] =
                            da * state[ci * n + ni] + d * b_row[ni] * xv;
                        acc += c_row[ni] * state[ci * n + ni];
                    }
                    out_row[ci] = acc;
                }
            }
        }
    }
}
10637
10638 Thunk::DequantMatMul {
10639 x,
10640 w_q,
10641 scale,
10642 zp,
10643 dst,
10644 m,
10645 k,
10646 n,
10647 block_size,
10648 is_asymmetric,
10649 } => {
10650 let (m, k, n, bs) = (*m as usize, *k as usize, *n as usize, *block_size as usize);
10651 let n_blocks = k.div_ceil(bs);
10652 unsafe {
10653 let xs = sl(*x, base, m * k);
10654 let w_bytes = std::slice::from_raw_parts(base.add(*w_q) as *const i8, k * n);
10655 let scales = sl(*scale, base, n_blocks * n);
10656 let zps = if *is_asymmetric {
10657 sl(*zp, base, n_blocks * n)
10658 } else {
10659 &[][..]
10660 };
10661 let out = sl_mut(*dst, base, m * n);
10662 dequant_matmul_int8(xs, w_bytes, scales, zps, out, m, k, n, bs, *is_asymmetric);
10663 }
10664 }
10665
// Matmul against GGUF block-quantized weights via
// `gguf_matmul::gguf_matmul_bt`. The weight byte length is derived from the
// scheme's block geometry: (k*n / block_elems) blocks of `block_bytes` each.
Thunk::DequantMatMulGguf {
    x,
    w_q,
    dst,
    m,
    k,
    n,
    scheme,
} => {
    let (m, k, n) = (*m as usize, *k as usize, *n as usize);
    let block_bytes = scheme.gguf_block_bytes() as usize;
    let block_elems = scheme.gguf_block_size() as usize;
    // Debug-only sanity checks; release builds trust the planner.
    debug_assert!(
        block_bytes > 0 && block_elems > 0,
        "non-GGUF scheme in GGUF arm"
    );
    debug_assert!(
        (k * n).is_multiple_of(block_elems),
        "k*n={} not aligned to GGUF block size {}",
        k * n,
        block_elems
    );
    let total_bytes = (k * n) / block_elems * block_bytes;
    unsafe {
        let xs = sl(*x, base, m * k);
        let w_bytes_ptr = base.add(*w_q) as *const u8;
        let w_bytes = std::slice::from_raw_parts(w_bytes_ptr, total_bytes);
        let out = sl_mut(*dst, base, m * n);
        crate::gguf_matmul::gguf_matmul_bt(xs, w_bytes, out, m, k, n, *scheme);
    }
}
10697
10698 Thunk::DequantMatMulInt4 {
10699 x,
10700 w_q,
10701 scale,
10702 zp,
10703 dst,
10704 m,
10705 k,
10706 n,
10707 block_size,
10708 is_asymmetric,
10709 } => {
10710 let (m, k, n, bs) = (*m as usize, *k as usize, *n as usize, *block_size as usize);
10711 let n_blocks = k.div_ceil(bs);
10712 unsafe {
10713 let xs = sl(*x, base, m * k);
10714 let w_bytes = std::slice::from_raw_parts(
10715 base.add(*w_q) as *const u8,
10716 (k * n).div_ceil(2),
10717 );
10718 let scales = sl(*scale, base, n_blocks * n);
10719 let zps = if *is_asymmetric {
10720 sl(*zp, base, n_blocks * n)
10721 } else {
10722 &[][..]
10723 };
10724 let out = sl_mut(*dst, base, m * n);
10725 dequant_matmul_int4(xs, w_bytes, scales, zps, out, m, k, n, bs, *is_asymmetric);
10726 }
10727 }
10728
10729 Thunk::DequantMatMulFp8 {
10730 x,
10731 w_q,
10732 scale,
10733 dst,
10734 m,
10735 k,
10736 n,
10737 e5m2,
10738 } => {
10739 let (m, k, n) = (*m as usize, *k as usize, *n as usize);
10740 unsafe {
10741 let xs = sl(*x, base, m * k);
10742 let w_bytes = std::slice::from_raw_parts(base.add(*w_q) as *const u8, k * n);
10743 let scales = sl(*scale, base, n);
10744 let out = sl_mut(*dst, base, m * n);
10745 dequant_matmul_fp8(xs, w_bytes, scales, out, m, k, n, *e5m2);
10746 }
10747 }
10748
10749 Thunk::DequantMatMulNvfp4 {
10750 x,
10751 w_q,
10752 scale,
10753 global_scale,
10754 dst,
10755 m,
10756 k,
10757 n,
10758 } => {
10759 let (m, k, n) = (*m as usize, *k as usize, *n as usize);
10760 let n_scale = k.div_ceil(rlx_ir::NVFP4_GROUP_SIZE) * n;
10761 unsafe {
10762 let xs = sl(*x, base, m * k);
10763 let w_bytes = std::slice::from_raw_parts(
10764 base.add(*w_q) as *const u8,
10765 (k * n).div_ceil(2),
10766 );
10767 let scale_bytes =
10768 std::slice::from_raw_parts(base.add(*scale) as *const u8, n_scale);
10769 let gs = sl(*global_scale, base, 1)[0];
10770 let out = sl_mut(*dst, base, m * n);
10771 dequant_matmul_nvfp4(xs, w_bytes, scale_bytes, gs, out, m, k, n);
10772 }
10773 }
10774
10775 Thunk::LoraMatMul {
10776 x,
10777 w,
10778 a,
10779 b,
10780 dst,
10781 m,
10782 k,
10783 n,
10784 r,
10785 scale,
10786 } => {
10787 let (m, k, n, r) = (*m as usize, *k as usize, *n as usize, *r as usize);
10788 unsafe {
10789 let xs = sl(*x, base, m * k);
10790 let ws = sl(*w, base, k * n);
10791 let a_s = sl(*a, base, k * r);
10792 let bs = sl(*b, base, r * n);
10793 let out = sl_mut(*dst, base, m * n);
10794 crate::blas::sgemm(xs, ws, out, m, k, n);
10795 let mut tmp = vec![0f32; m * r];
10796 crate::blas::sgemm(xs, a_s, &mut tmp, m, k, r);
10797 if *scale != 1.0 {
10798 for v in tmp.iter_mut() {
10799 *v *= *scale;
10800 }
10801 }
10802 crate::blas::sgemm_accumulate(&tmp, bs, out, m, r, n);
10803 }
10804 }
10805
10806 Thunk::Attention {
10807 q,
10808 k,
10809 v,
10810 mask,
10811 out,
10812 batch,
10813 seq,
10814 kv_seq,
10815 heads,
10816 head_dim,
10817 mask_kind,
10818 scale,
10819 q_row_stride,
10820 k_row_stride,
10821 v_row_stride,
10822 bhsd,
10823 } => {
10824 let (b, q_s, k_s, nh, dh) = (
10825 *batch as usize,
10826 *seq as usize,
10827 *kv_seq as usize,
10828 *heads as usize,
10829 *head_dim as usize,
10830 );
10831 let hs = nh * dh;
10832 let (qrs, krs, vrs) = if *bhsd {
10835 (dh, dh, dh)
10836 } else {
10837 (
10838 *q_row_stride as usize,
10839 *k_row_stride as usize,
10840 *v_row_stride as usize,
10841 )
10842 };
10843 let bhsd = *bhsd;
10844 let _ = (q_row_stride, k_row_stride, v_row_stride);
10845 let scale = *scale;
10846 let ss = q_s * k_s;
10847 let cfg = crate::config::RuntimeConfig::global();
10848 unsafe {
10849 let q_len = if bhsd {
10856 b * nh * q_s * dh
10857 } else {
10858 b * q_s * qrs
10859 };
10860 let k_len = if bhsd {
10861 b * nh * k_s * dh
10862 } else {
10863 b * k_s * krs
10864 };
10865 let v_len = if bhsd {
10866 b * nh * k_s * dh
10867 } else {
10868 b * k_s * vrs
10869 };
10870 let q_data = sl(*q, base, q_len);
10871 let k_data = sl(*k, base, k_len);
10872 let v_data = sl(*v, base, v_len);
10873 let mask_data: &[f32] = match mask_kind {
10874 rlx_ir::op::MaskKind::Custom => sl(*mask, base, b * k_s),
10875 rlx_ir::op::MaskKind::Bias => sl(*mask, base, b * nh * q_s * k_s),
10876 _ => &[],
10877 };
10878 let out_len = if bhsd {
10879 b * nh * q_s * dh
10880 } else {
10881 b * q_s * hs
10882 };
10883 let out_data = sl_mut(*out, base, out_len);
10884
10885 if bhsd {
10896 let scores = &mut sdpa_scores[..ss];
10897 for bi in 0..b {
10898 for hi in 0..nh {
10899 let q_head_base = bi * nh * q_s * dh + hi * q_s * dh;
10900 let k_head_base = bi * nh * k_s * dh + hi * k_s * dh;
10901 for qi in 0..q_s {
10903 let q_base = q_head_base + qi * dh;
10904 for ki in 0..k_s {
10905 let k_base = k_head_base + ki * dh;
10906 let mut dot = 0f32;
10907 for d in 0..dh {
10908 dot += q_data[q_base + d] * k_data[k_base + d];
10909 }
10910 scores[qi * k_s + ki] = dot * scale;
10911 if matches!(mask_kind, rlx_ir::op::MaskKind::Custom)
10912 && !mask_data.is_empty()
10913 && mask_data[bi * k_s + ki] < mask_thr
10914 {
10915 scores[qi * k_s + ki] = mask_neg;
10916 }
10917 }
10918 }
10919 if matches!(mask_kind, rlx_ir::op::MaskKind::Bias) {
10920 let off = (bi * nh + hi) * q_s * k_s;
10921 for i in 0..q_s * k_s {
10922 scores[i] += mask_data[off + i];
10923 }
10924 }
10925 apply_synthetic_mask(scores, q_s, k_s, *mask_kind);
10926 crate::kernels::neon_softmax(scores, q_s, k_s);
10927 for qi in 0..q_s {
10929 let o_base = q_head_base + qi * dh;
10930 for d in 0..dh {
10931 out_data[o_base + d] = 0.0;
10932 }
10933 for ki in 0..k_s {
10934 let sc = scores[qi * k_s + ki];
10935 if sc > score_thr {
10936 let v_base = k_head_base + ki * dh;
10937 for d in 0..dh {
10938 out_data[o_base + d] += sc * v_data[v_base + d];
10939 }
10940 }
10941 }
10942 }
10943 }
10944 }
10945 continue;
10946 }
10947
10948 if b == 1 && q_s.max(k_s) <= cfg.sdpa_seq_threshold {
10955 let scores = &mut sdpa_scores[..ss];
10957 #[cfg(target_arch = "aarch64")]
10958 let neon_chunks = dh / 4;
10959
10960 for bi in 0..b {
10961 for hi in 0..nh {
10962 for qi in 0..q_s {
10964 let q_off = bi * q_s * qrs + qi * qrs + hi * dh;
10965 for ki in 0..k_s {
10966 let k_off = bi * k_s * krs + ki * krs + hi * dh;
10967 #[cfg(target_arch = "aarch64")]
10968 let mut dot;
10969 #[cfg(not(target_arch = "aarch64"))]
10970 let mut dot = 0f32;
10971 #[cfg(target_arch = "aarch64")]
10972 {
10973 use std::arch::aarch64::*;
10974 let mut acc = vdupq_n_f32(0.0);
10975 for c in 0..neon_chunks {
10976 let vq =
10977 vld1q_f32(q_data.as_ptr().add(q_off + c * 4));
10978 let vk =
10979 vld1q_f32(k_data.as_ptr().add(k_off + c * 4));
10980 acc = vfmaq_f32(acc, vq, vk);
10981 }
10982 dot = vaddvq_f32(acc);
10983 for d in (neon_chunks * 4)..dh {
10984 dot += q_data[q_off + d] * k_data[k_off + d];
10985 }
10986 }
10987 #[cfg(not(target_arch = "aarch64"))]
10988 for d in 0..dh {
10989 dot += q_data[q_off + d] * k_data[k_off + d];
10990 }
10991 scores[qi * k_s + ki] = dot * scale;
10992 if matches!(mask_kind, rlx_ir::op::MaskKind::Custom)
10999 && !mask_data.is_empty()
11000 && mask_data[bi * k_s + ki] < mask_thr
11001 {
11002 scores[qi * k_s + ki] = mask_neg;
11003 }
11004 }
11005 }
11006
11007 if matches!(mask_kind, rlx_ir::op::MaskKind::Bias) {
11008 let off = (bi * nh + hi) * q_s * k_s;
11009 for i in 0..q_s * k_s {
11010 scores[i] += mask_data[off + i];
11011 }
11012 }
11013 apply_synthetic_mask(scores, q_s, k_s, *mask_kind);
11014 crate::kernels::neon_softmax(scores, q_s, k_s);
11015
11016 for qi in 0..q_s {
11018 let o_off = bi * q_s * hs + qi * hs + hi * dh;
11019 for d in 0..dh {
11021 out_data[o_off + d] = 0.0;
11022 }
11023 for ki in 0..k_s {
11024 let sc = scores[qi * k_s + ki];
11025 if sc > score_thr {
11026 let v_off = bi * k_s * vrs + ki * vrs + hi * dh;
11027 #[cfg(target_arch = "aarch64")]
11028 {
11029 use std::arch::aarch64::*;
11030 let vsc = vdupq_n_f32(sc);
11031 for c in 0..neon_chunks {
11032 let off = c * 4;
11033 let vo = vld1q_f32(
11034 out_data.as_ptr().add(o_off + off),
11035 );
11036 let vv =
11037 vld1q_f32(v_data.as_ptr().add(v_off + off));
11038 vst1q_f32(
11039 out_data.as_mut_ptr().add(o_off + off),
11040 vfmaq_f32(vo, vsc, vv),
11041 );
11042 }
11043 }
11044 #[cfg(not(target_arch = "aarch64"))]
11045 for d in 0..dh {
11046 out_data[o_off + d] += sc * v_data[v_off + d];
11047 }
11048 }
11049 }
11050 }
11051 }
11052 }
11053 } else {
11054 let total_work = b * nh;
11056 let q_addr = q_data.as_ptr() as usize;
11057 let k_addr = k_data.as_ptr() as usize;
11058 let v_addr = v_data.as_ptr() as usize;
11059 let m_addr = mask_data.as_ptr() as usize;
11060 let o_addr = out_data.as_mut_ptr() as usize;
11061 let sc_addr = sdpa_scores.as_mut_ptr() as usize;
11062
11063 crate::pool::par_for(total_work, 1, &|off, cnt| {
11064 for idx in off..off + cnt {
11065 let bi = idx / nh;
11066 let hi = idx % nh;
11067
11068 let q_start = (q_addr as *const f32).add(bi * q_s * qrs + hi * dh);
11069 let k_start = (k_addr as *const f32).add(bi * k_s * krs + hi * dh);
11070 let v_start = (v_addr as *const f32).add(bi * k_s * vrs + hi * dh);
11071 let o_start = (o_addr as *mut f32).add(bi * q_s * hs + hi * dh);
11072 let sc = std::slice::from_raw_parts_mut(
11073 (sc_addr as *mut f32).add(idx * ss),
11074 ss,
11075 );
11076
11077 crate::blas::sgemm_general(
11080 q_start,
11081 k_start,
11082 sc.as_mut_ptr(),
11083 q_s,
11084 k_s,
11085 dh,
11086 scale,
11087 0.0,
11088 qrs,
11089 krs,
11090 k_s,
11091 false,
11092 true,
11093 );
11094
11095 match mask_kind {
11096 rlx_ir::op::MaskKind::Custom => {
11097 let mask_bi = std::slice::from_raw_parts(
11098 (m_addr as *const f32).add(bi * k_s),
11099 k_s,
11100 );
11101 for ki in 0..k_s {
11102 if mask_bi[ki] < mask_thr {
11103 for qi in 0..q_s {
11104 sc[qi * k_s + ki] = mask_neg;
11105 }
11106 }
11107 }
11108 }
11109 rlx_ir::op::MaskKind::Bias => {
11110 let bias = std::slice::from_raw_parts(
11112 (m_addr as *const f32).add((bi * nh + hi) * q_s * k_s),
11113 q_s * k_s,
11114 );
11115 for i in 0..q_s * k_s {
11116 sc[i] += bias[i];
11117 }
11118 }
11119 _ => apply_synthetic_mask(sc, q_s, k_s, *mask_kind),
11120 }
11121
11122 crate::kernels::neon_softmax(sc, q_s, k_s);
11123
11124 crate::blas::sgemm_general(
11128 sc.as_ptr(),
11129 v_start,
11130 o_start,
11131 q_s,
11132 dh,
11133 k_s,
11134 1.0,
11135 0.0,
11136 k_s,
11137 vrs,
11138 hs,
11139 false,
11140 false,
11141 );
11142 }
11143 });
11144 }
11145 }
11146 }
11147
11148 Thunk::AttentionBackward {
11149 q,
11150 k,
11151 v,
11152 dy,
11153 mask,
11154 out,
11155 batch,
11156 seq,
11157 kv_seq,
11158 heads,
11159 head_dim,
11160 mask_kind,
11161 wrt,
11162 bhsd,
11163 } => {
11164 let (b, q_s, k_s, nh, dh) = (
11165 *batch as usize,
11166 *seq as usize,
11167 *kv_seq as usize,
11168 *heads as usize,
11169 *head_dim as usize,
11170 );
11171 unsafe {
11172 let q_len = if *bhsd {
11173 b * nh * q_s * dh
11174 } else {
11175 b * q_s * nh * dh
11176 };
11177 let k_len = if *bhsd {
11178 b * nh * k_s * dh
11179 } else {
11180 b * k_s * nh * dh
11181 };
11182 let out_len = match wrt {
11183 rlx_ir::op::AttentionBwdWrt::Key | rlx_ir::op::AttentionBwdWrt::Value => {
11184 k_len
11185 }
11186 rlx_ir::op::AttentionBwdWrt::Query => q_len,
11187 };
11188 let q_data = sl(*q, base, q_len);
11189 let k_data = sl(*k, base, k_len);
11190 let v_data = sl(*v, base, k_len);
11191 let dy_data = sl(*dy, base, q_len);
11192 let out_data = sl_mut(*out, base, out_len);
11193 let mask_data: &[f32] = if *mask != 0 {
11194 let ml = match mask_kind {
11195 rlx_ir::op::MaskKind::Custom => b * k_s,
11196 rlx_ir::op::MaskKind::Bias => b * nh * q_s * k_s,
11197 _ => 0,
11198 };
11199 sl(*mask, base, ml)
11200 } else {
11201 &[]
11202 };
11203 crate::attention_bwd::attention_backward(
11204 *wrt, q_data, k_data, v_data, dy_data, out_data, b, nh, q_s, k_s, dh,
11205 *mask_kind, mask_data, *bhsd,
11206 );
11207 }
11208 }
11209
11210 Thunk::ActivationInPlace { data, len, act } => {
11211 let len = *len as usize;
11212 unsafe {
11213 let d = sl_mut(*data, base, len);
11214 match act {
11215 Activation::Gelu => crate::kernels::par_gelu_inplace(d),
11216 Activation::GeluApprox => crate::kernels::par_gelu_approx_inplace(d),
11217 Activation::Silu => crate::kernels::par_silu_inplace(d),
11218 Activation::Relu => {
11219 for v in d.iter_mut() {
11220 *v = v.max(0.0);
11221 }
11222 }
11223 Activation::Sigmoid => {
11224 for v in d.iter_mut() {
11225 *v = 1.0 / (1.0 + (-*v).exp());
11226 }
11227 }
11228 Activation::Tanh => {
11229 for v in d.iter_mut() {
11230 *v = v.tanh();
11231 }
11232 }
11233 Activation::Exp => {
11234 for v in d.iter_mut() {
11235 *v = v.exp();
11236 }
11237 }
11238 Activation::Log => {
11239 for v in d.iter_mut() {
11240 *v = v.ln();
11241 }
11242 }
11243 Activation::Sqrt => {
11244 for v in d.iter_mut() {
11245 *v = v.sqrt();
11246 }
11247 }
11248 Activation::Rsqrt => {
11249 for v in d.iter_mut() {
11250 *v = 1.0 / v.sqrt();
11251 }
11252 }
11253 Activation::Neg => {
11254 for v in d.iter_mut() {
11255 *v = -*v;
11256 }
11257 }
11258 Activation::Abs => {
11259 for v in d.iter_mut() {
11260 *v = v.abs();
11261 }
11262 }
11263 Activation::Round => {
11264 for v in d.iter_mut() {
11265 *v = v.round();
11266 }
11267 }
11268 Activation::Sin => {
11269 for v in d.iter_mut() {
11270 *v = v.sin();
11271 }
11272 }
11273 Activation::Cos => {
11274 for v in d.iter_mut() {
11275 *v = v.cos();
11276 }
11277 }
11278 Activation::Tan => {
11279 for v in d.iter_mut() {
11280 *v = v.tan();
11281 }
11282 }
11283 Activation::Atan => {
11284 for v in d.iter_mut() {
11285 *v = v.atan();
11286 }
11287 }
11288 }
11289 }
11290 }
11291
11292 Thunk::FusedAttnBlock {
11293 hidden,
11294 qkv_w,
11295 out_w,
11296 mask,
11297 out,
11298 qkv_b,
11299 out_b,
11300 cos,
11301 sin,
11302 cos_len,
11303 batch,
11304 seq,
11305 hs,
11306 nh,
11307 dh,
11308 has_bias,
11309 has_rope,
11310 } => {
11311 let (b, s) = (*batch as usize, *seq as usize);
11312 let (h, n_h, d_h) = (*hs as usize, *nh as usize, *dh as usize);
11313 let m = b * s;
11314 let scale = (d_h as f32).powf(-0.5);
11315 let half = d_h / 2;
11316 unsafe {
11317 let inp = sl(*hidden, base, m * h);
11318 let wq = sl(*qkv_w, base, h * 3 * h);
11319 let wo = sl(*out_w, base, h * h);
11320 let mk = sl(*mask, base, b * s);
11321 let dst = sl_mut(*out, base, m * h);
11322
11323 let mut qkv = vec![0f32; m * 3 * h];
11325 let mut attn_out = vec![0f32; m * h];
11326 let mut scores_buf = vec![0f32; s * s]; crate::blas::sgemm(inp, wq, &mut qkv, m, h, 3 * h);
11330 if *has_bias {
11331 let bias = sl(*qkv_b, base, 3 * h);
11332 crate::blas::bias_add(&mut qkv, bias, m, 3 * h);
11333 }
11334
11335 #[cfg(target_arch = "aarch64")]
11338 let neon_chunks = d_h / 4;
11339 #[cfg(target_arch = "aarch64")]
11340 let _rope_chunks = half / 4;
11341
11342 for bi in 0..b {
11343 for hi in 0..n_h {
11344 for qi in 0..s {
11346 let q_base = bi * s * 3 * h + qi * 3 * h + hi * d_h;
11347 for ki in 0..s {
11348 let k_base = bi * s * 3 * h + ki * 3 * h + h + hi * d_h;
11349 let mut dot = 0f32;
11350
11351 if *has_rope {
11352 let q_cos = qi * half;
11354 let k_cos = ki * half;
11355 let cos_tab = sl(*cos, base, *cos_len as usize);
11356 let sin_tab = sl(*sin, base, *cos_len as usize);
11357 for i in 0..half {
11360 let q1 = qkv[q_base + i];
11361 let q2 = qkv[q_base + half + i];
11362 let k1 = qkv[k_base + i];
11363 let k2 = qkv[k_base + half + i];
11364 let c_q = cos_tab[q_cos + i];
11365 let s_q = sin_tab[q_cos + i];
11366 let c_k = cos_tab[k_cos + i];
11367 let s_k = sin_tab[k_cos + i];
11368 let qr1 = q1 * c_q - q2 * s_q;
11369 let kr1 = k1 * c_k - k2 * s_k;
11370 let qr2 = q2 * c_q + q1 * s_q;
11371 let kr2 = k2 * c_k + k1 * s_k;
11372 dot += qr1 * kr1 + qr2 * kr2;
11373 }
11374 } else {
11375 #[cfg(target_arch = "aarch64")]
11377 {
11378 use std::arch::aarch64::*;
11379 let mut acc = vdupq_n_f32(0.0);
11380 for c in 0..neon_chunks {
11381 let vq =
11382 vld1q_f32(qkv.as_ptr().add(q_base + c * 4));
11383 let vk =
11384 vld1q_f32(qkv.as_ptr().add(k_base + c * 4));
11385 acc = vfmaq_f32(acc, vq, vk);
11386 }
11387 dot = vaddvq_f32(acc);
11388 for d in (neon_chunks * 4)..d_h {
11389 dot += qkv[q_base + d] * qkv[k_base + d];
11390 }
11391 }
11392 #[cfg(not(target_arch = "aarch64"))]
11393 for d in 0..d_h {
11394 dot += qkv[q_base + d] * qkv[k_base + d];
11395 }
11396 }
11397
11398 scores_buf[qi * s + ki] = dot * scale;
11399 if mk[bi * s + ki] < mask_thr {
11400 scores_buf[qi * s + ki] = mask_neg;
11401 }
11402 }
11403 }
11404
11405 crate::kernels::neon_softmax(&mut scores_buf[..s * s], s, s);
11407
11408 for qi in 0..s {
11410 let o_base = bi * s * h + qi * h + hi * d_h;
11411 for d in 0..d_h {
11412 attn_out[o_base + d] = 0.0;
11413 }
11414 for ki in 0..s {
11415 let sc = scores_buf[qi * s + ki];
11416 if sc > score_thr {
11417 let v_base = bi * s * 3 * h + ki * 3 * h + 2 * h + hi * d_h;
11418 #[cfg(target_arch = "aarch64")]
11419 {
11420 use std::arch::aarch64::*;
11421 let vsc = vdupq_n_f32(sc);
11422 for c in 0..neon_chunks {
11423 let off = c * 4;
11424 let vo =
11425 vld1q_f32(attn_out.as_ptr().add(o_base + off));
11426 let vv = vld1q_f32(qkv.as_ptr().add(v_base + off));
11427 vst1q_f32(
11428 attn_out.as_mut_ptr().add(o_base + off),
11429 vfmaq_f32(vo, vsc, vv),
11430 );
11431 }
11432 }
11433 #[cfg(not(target_arch = "aarch64"))]
11434 for d in 0..d_h {
11435 attn_out[o_base + d] += sc * qkv[v_base + d];
11436 }
11437 }
11438 }
11439 }
11440 }
11441 }
11442
11443 crate::blas::sgemm(&attn_out, wo, dst, m, h, h);
11445 if *has_bias {
11446 let bias = sl(*out_b, base, h);
11447 crate::blas::bias_add(dst, bias, m, h);
11448 }
11449 }
11450 }
11451
11452 Thunk::Rope {
11453 src,
11454 cos,
11455 sin,
11456 dst,
11457 batch,
11458 seq,
11459 hidden,
11460 head_dim,
11461 n_rot,
11462 cos_len,
11463 src_row_stride,
11464 } => {
11465 let (b, s, hs, dh, nr) = (
11466 *batch as usize,
11467 *seq as usize,
11468 *hidden as usize,
11469 *head_dim as usize,
11470 *n_rot as usize,
11471 );
11472 let tab_half = dh / 2;
11473 let rot_half = nr / 2;
11474 let nh = hs / dh;
11475 let cl = *cos_len as usize;
11476 let src_rs = *src_row_stride as usize;
11477 unsafe {
11478 let x = sl(*src, base, b * s * src_rs);
11479 let cos_tab = sl(*cos, base, cl);
11480 let sin_tab = sl(*sin, base, cl);
11481 let out = sl_mut(*dst, base, b * s * hs);
11482
11483 let total = b * s;
11484 let x_ptr = x.as_ptr() as usize;
11485 let o_ptr = out.as_mut_ptr() as usize;
11486 let c_ptr = cos_tab.as_ptr() as usize;
11487 let s_ptr = sin_tab.as_ptr() as usize;
11488
11489 crate::pool::par_for(total, 4, &|off, cnt| {
11490 for idx in off..off + cnt {
11491 let bi = idx / s;
11492 let si = idx % s;
11493 let tab_off = si * tab_half;
11494
11495 for hi in 0..nh {
11496 let src_base = bi * s * src_rs + si * src_rs + hi * dh;
11497 let dst_base = bi * s * hs + si * hs + hi * dh;
11498 let xp = (x_ptr as *const f32).add(src_base);
11499 let op = (o_ptr as *mut f32).add(dst_base);
11500 let cp = (c_ptr as *const f32).add(tab_off);
11501 let sp = (s_ptr as *const f32).add(tab_off);
11502
11503 for i in 0..rot_half {
11504 let x1 = *xp.add(i);
11505 let x2 = *xp.add(rot_half + i);
11506 let cv = *cp.add(i);
11507 let sv = *sp.add(i);
11508 *op.add(i) = x1 * cv - x2 * sv;
11509 *op.add(rot_half + i) = x2 * cv + x1 * sv;
11510 }
11511 for j in nr..dh {
11512 *op.add(j) = *xp.add(j);
11513 }
11514 }
11515 }
11516 });
11517 }
11518 }
11519 Thunk::FusedBertLayer {
11520 hidden,
11521 qkv_w,
11522 qkv_b,
11523 out_w,
11524 out_b,
11525 mask,
11526 ln1_g,
11527 ln1_b,
11528 eps1,
11529 fc1_w,
11530 fc1_b,
11531 fc2_w,
11532 fc2_b,
11533 ln2_g,
11534 ln2_b,
11535 eps2,
11536 out,
11537 batch,
11538 seq,
11539 hs,
11540 nh,
11541 dh,
11542 int_dim,
11543 } => {
11544 let (b, s, h, n_h, d_h) = (
11545 *batch as usize,
11546 *seq as usize,
11547 *hs as usize,
11548 *nh as usize,
11549 *dh as usize,
11550 );
11551 let m = b * s;
11552 let id = *int_dim as usize;
11553 let scale = (d_h as f32).powf(-0.5);
11554 let _half = d_h / 2;
11555 #[cfg(target_arch = "aarch64")]
11556 let neon_chunks = d_h / 4;
11557 unsafe {
11558 let inp = sl(*hidden, base, m * h);
11559 let dst = sl_mut(*out, base, m * h);
11560 let mk = sl(*mask, base, b * s);
11561
11562 let qkv = std::slice::from_raw_parts_mut(fl_qkv.as_mut_ptr(), m * 3 * h);
11564 let attn = std::slice::from_raw_parts_mut(fl_attn.as_mut_ptr(), m * h);
11565 let res = std::slice::from_raw_parts_mut(fl_res.as_mut_ptr(), m * h);
11566 let normed = std::slice::from_raw_parts_mut(fl_normed.as_mut_ptr(), m * h);
11567 let ffn = std::slice::from_raw_parts_mut(fl_ffn.as_mut_ptr(), m * id);
11568 let sc = std::slice::from_raw_parts_mut(fl_sc.as_mut_ptr(), s * s);
11569
11570 crate::blas::par_sgemm_bias(
11572 inp,
11573 sl(*qkv_w, base, h * 3 * h),
11574 sl(*qkv_b, base, 3 * h),
11575 qkv,
11576 m,
11577 h,
11578 3 * h,
11579 );
11580
11581 for bi in 0..b {
11583 for hi in 0..n_h {
11584 for qi in 0..s {
11585 for ki in 0..s {
11586 let q_base = bi * s * 3 * h + qi * 3 * h + hi * d_h;
11587 let k_base = bi * s * 3 * h + ki * 3 * h + h + hi * d_h;
11588 #[cfg(target_arch = "aarch64")]
11589 let dot;
11590 #[cfg(not(target_arch = "aarch64"))]
11591 let mut dot = 0f32;
11592 #[cfg(target_arch = "aarch64")]
11593 {
11594 use std::arch::aarch64::*;
11595 let mut acc = vdupq_n_f32(0.0);
11596 for c in 0..neon_chunks {
11597 acc = vfmaq_f32(
11598 acc,
11599 vld1q_f32(qkv.as_ptr().add(q_base + c * 4)),
11600 vld1q_f32(qkv.as_ptr().add(k_base + c * 4)),
11601 );
11602 }
11603 dot = vaddvq_f32(acc);
11604 }
11605 #[cfg(not(target_arch = "aarch64"))]
11606 for d in 0..d_h {
11607 dot += qkv[q_base + d] * qkv[k_base + d];
11608 }
11609 sc[qi * s + ki] = dot * scale;
11610 if mk[bi * s + ki] < mask_thr {
11611 sc[qi * s + ki] = mask_neg;
11612 }
11613 }
11614 }
11615 crate::kernels::neon_softmax(&mut sc[..s * s], s, s);
11616 for qi in 0..s {
11617 let o = bi * s * h + qi * h + hi * d_h;
11618 for d in 0..d_h {
11619 attn[o + d] = 0.0;
11620 }
11621 for ki in 0..s {
11622 let w = sc[qi * s + ki];
11623 if w > score_thr {
11624 let v = bi * s * 3 * h + ki * 3 * h + 2 * h + hi * d_h;
11625 #[cfg(target_arch = "aarch64")]
11626 {
11627 use std::arch::aarch64::*;
11628 let vw = vdupq_n_f32(w);
11629 for c in 0..neon_chunks {
11630 let off = c * 4;
11631 vst1q_f32(
11632 attn.as_mut_ptr().add(o + off),
11633 vfmaq_f32(
11634 vld1q_f32(attn.as_ptr().add(o + off)),
11635 vw,
11636 vld1q_f32(qkv.as_ptr().add(v + off)),
11637 ),
11638 );
11639 }
11640 }
11641 #[cfg(not(target_arch = "aarch64"))]
11642 for d in 0..d_h {
11643 attn[o + d] += w * qkv[v + d];
11644 }
11645 }
11646 }
11647 }
11648 }
11649 }
11650
11651 crate::blas::sgemm_bias(
11653 attn,
11654 sl(*out_w, base, h * h),
11655 sl(*out_b, base, h),
11656 res,
11657 m,
11658 h,
11659 h,
11660 );
11661 #[cfg(target_arch = "aarch64")]
11662 {
11663 use std::arch::aarch64::*;
11664 let chunks_h = (m * h) / 4;
11665 for c in 0..chunks_h {
11666 let off = c * 4;
11667 vst1q_f32(
11668 res.as_mut_ptr().add(off),
11669 vaddq_f32(
11670 vld1q_f32(res.as_ptr().add(off)),
11671 vld1q_f32(inp.as_ptr().add(off)),
11672 ),
11673 );
11674 }
11675 for i in (chunks_h * 4)..(m * h) {
11676 res[i] += inp[i];
11677 }
11678 }
11679 #[cfg(not(target_arch = "aarch64"))]
11680 for i in 0..m * h {
11681 res[i] += inp[i];
11682 }
11683
11684 let g1 = sl(*ln1_g, base, h);
11686 let b1 = sl(*ln1_b, base, h);
11687 for r in 0..m {
11688 crate::kernels::layer_norm_row(
11689 &res[r * h..(r + 1) * h],
11690 g1,
11691 b1,
11692 &mut normed[r * h..(r + 1) * h],
11693 h,
11694 *eps1,
11695 );
11696 }
11697
11698 crate::blas::par_sgemm_bias(
11700 normed,
11701 sl(*fc1_w, base, h * id),
11702 sl(*fc1_b, base, id),
11703 ffn,
11704 m,
11705 h,
11706 id,
11707 );
11708 crate::kernels::par_gelu_inplace(ffn);
11709
11710 crate::blas::par_sgemm_bias(
11712 ffn,
11713 sl(*fc2_w, base, id * h),
11714 sl(*fc2_b, base, h),
11715 res,
11716 m,
11717 id,
11718 h,
11719 );
11720 #[cfg(target_arch = "aarch64")]
11721 {
11722 use std::arch::aarch64::*;
11723 let chunks_h = (m * h) / 4;
11724 for c in 0..chunks_h {
11725 let off = c * 4;
11726 vst1q_f32(
11727 res.as_mut_ptr().add(off),
11728 vaddq_f32(
11729 vld1q_f32(res.as_ptr().add(off)),
11730 vld1q_f32(normed.as_ptr().add(off)),
11731 ),
11732 );
11733 }
11734 for i in (chunks_h * 4)..(m * h) {
11735 res[i] += normed[i];
11736 }
11737 }
11738 #[cfg(not(target_arch = "aarch64"))]
11739 for i in 0..m * h {
11740 res[i] += normed[i];
11741 }
11742
11743 let g2 = sl(*ln2_g, base, h);
11745 let b2 = sl(*ln2_b, base, h);
11746 for r in 0..m {
11747 crate::kernels::layer_norm_row(
11748 &res[r * h..(r + 1) * h],
11749 g2,
11750 b2,
11751 &mut dst[r * h..(r + 1) * h],
11752 h,
11753 *eps2,
11754 );
11755 }
11756 }
11757 }
11758
11759 Thunk::FusedNomicLayer {
11760 hidden,
11761 qkv_w,
11762 out_w,
11763 mask,
11764 cos,
11765 sin,
11766 cos_len,
11767 ln1_g,
11768 ln1_b,
11769 eps1,
11770 fc11_w,
11771 fc12_w: _,
11772 fc2_w,
11773 ln2_g,
11774 ln2_b,
11775 eps2,
11776 out,
11777 batch,
11778 seq,
11779 hs,
11780 nh,
11781 dh,
11782 int_dim,
11783 } => {
11784 let (b, s, h, n_h, d_h) = (
11785 *batch as usize,
11786 *seq as usize,
11787 *hs as usize,
11788 *nh as usize,
11789 *dh as usize,
11790 );
11791 let m = b * s;
11792 let id = *int_dim as usize;
11793 let scale = (d_h as f32).powf(-0.5);
11794 let half_dh = d_h / 2;
11795 #[cfg(target_arch = "aarch64")]
11796 let neon_chunks = d_h / 4;
11797 unsafe {
11798 let inp = sl(*hidden, base, m * h);
11799 let dst = sl_mut(*out, base, m * h);
11800 let mk = sl(*mask, base, b * s);
11801 let cos_tab = sl(*cos, base, *cos_len as usize);
11802 let sin_tab = sl(*sin, base, *cos_len as usize);
11803 let fused_fc_w = sl(*fc11_w, base, h * 2 * id);
11805
11806 let mut qkv = vec![0f32; m * 3 * h];
11807 let mut attn = vec![0f32; m * h];
11808 let mut res = vec![0f32; m * h];
11809 let mut normed = vec![0f32; m * h];
11810 let mut ffn_concat = vec![0f32; m * 2 * id]; let mut sc = vec![0f32; s * s];
11812
11813 crate::blas::sgemm(inp, sl(*qkv_w, base, h * 3 * h), &mut qkv, m, h, 3 * h);
11815
11816 for bi in 0..b {
11818 for hi in 0..n_h {
11819 for qi in 0..s {
11820 for ki in 0..s {
11821 let q_base = bi * s * 3 * h + qi * 3 * h + hi * d_h;
11822 let k_base = bi * s * 3 * h + ki * 3 * h + h + hi * d_h;
11823 let mut dot = 0f32;
11824 for i in 0..half_dh {
11825 let q1 = qkv[q_base + i];
11826 let q2 = qkv[q_base + half_dh + i];
11827 let k1 = qkv[k_base + i];
11828 let k2 = qkv[k_base + half_dh + i];
11829 let cq = cos_tab[qi * half_dh + i];
11830 let sq = sin_tab[qi * half_dh + i];
11831 let ck = cos_tab[ki * half_dh + i];
11832 let sk = sin_tab[ki * half_dh + i];
11833 dot += (q1 * cq - q2 * sq) * (k1 * ck - k2 * sk)
11834 + (q2 * cq + q1 * sq) * (k2 * ck + k1 * sk);
11835 }
11836 sc[qi * s + ki] = dot * scale;
11837 if mk[bi * s + ki] < mask_thr {
11838 sc[qi * s + ki] = mask_neg;
11839 }
11840 }
11841 }
11842 crate::kernels::neon_softmax(&mut sc[..s * s], s, s);
11843 for qi in 0..s {
11844 let o = bi * s * h + qi * h + hi * d_h;
11845 for d in 0..d_h {
11846 attn[o + d] = 0.0;
11847 }
11848 for ki in 0..s {
11849 let w = sc[qi * s + ki];
11850 if w > score_thr {
11851 let v = bi * s * 3 * h + ki * 3 * h + 2 * h + hi * d_h;
11852 #[cfg(target_arch = "aarch64")]
11853 {
11854 use std::arch::aarch64::*;
11855 let vw = vdupq_n_f32(w);
11856 for c in 0..neon_chunks {
11857 let off = c * 4;
11858 vst1q_f32(
11859 attn.as_mut_ptr().add(o + off),
11860 vfmaq_f32(
11861 vld1q_f32(attn.as_ptr().add(o + off)),
11862 vw,
11863 vld1q_f32(qkv.as_ptr().add(v + off)),
11864 ),
11865 );
11866 }
11867 }
11868 #[cfg(not(target_arch = "aarch64"))]
11869 for d in 0..d_h {
11870 attn[o + d] += w * qkv[v + d];
11871 }
11872 }
11873 }
11874 }
11875 }
11876 }
11877
11878 crate::blas::sgemm(&attn, sl(*out_w, base, h * h), &mut res, m, h, h);
11880 for i in 0..m * h {
11881 res[i] += inp[i];
11882 }
11883
11884 let g1 = sl(*ln1_g, base, h);
11886 let b1 = sl(*ln1_b, base, h);
11887 for r in 0..m {
11888 crate::kernels::layer_norm_row(
11889 &res[r * h..(r + 1) * h],
11890 g1,
11891 b1,
11892 &mut normed[r * h..(r + 1) * h],
11893 h,
11894 *eps1,
11895 );
11896 }
11897
11898 crate::blas::sgemm(&normed, fused_fc_w, &mut ffn_concat, m, h, 2 * id);
11900 for row in 0..m {
11903 let bo = row * 2 * id;
11904 for j in 0..id {
11906 let x = ffn_concat[bo + id + j];
11907 ffn_concat[bo + id + j] = x / (1.0 + (-x).exp());
11908 }
11909 for j in 0..id {
11911 ffn_concat[bo + j] *= ffn_concat[bo + id + j];
11912 }
11913 }
11914
11915 crate::blas::sgemm_general(
11920 ffn_concat.as_ptr(),
11921 sl(*fc2_w, base, id * h).as_ptr(),
11922 res.as_mut_ptr(),
11923 m,
11924 h,
11925 id,
11926 1.0,
11927 0.0,
11928 2 * id,
11929 h,
11930 h,
11931 false,
11932 false,
11933 );
11934 for i in 0..m * h {
11935 res[i] += normed[i];
11936 }
11937
11938 let g2 = sl(*ln2_g, base, h);
11940 let b2 = sl(*ln2_b, base, h);
11941 for r in 0..m {
11942 crate::kernels::layer_norm_row(
11943 &res[r * h..(r + 1) * h],
11944 g2,
11945 b2,
11946 &mut dst[r * h..(r + 1) * h],
11947 h,
11948 *eps2,
11949 );
11950 }
11951 }
11952 }
11953
11954 Thunk::FusedSwiGLU {
11955 src,
11956 dst,
11957 n_half,
11958 total,
11959 gate_first,
11960 } => {
11961 let n = *n_half as usize;
11962 let t = *total as usize;
11963 let outer = t / n;
11964 let in_total = outer * 2 * n;
11965 let gate_first = *gate_first;
11966 unsafe {
11967 let inp = sl(*src, base, in_total);
11968 let out = sl_mut(*dst, base, t);
11969 for o in 0..outer {
11970 let in_row = &inp[o * 2 * n..(o + 1) * 2 * n];
11971 let out_row = &mut out[o * n..(o + 1) * n];
11972 for i in 0..n {
11973 let (up, gate) = if gate_first {
11974 (in_row[n + i], in_row[i])
11975 } else {
11976 (in_row[i], in_row[n + i])
11977 };
11978 out_row[i] = up * (gate / (1.0 + (-gate).exp()));
11979 }
11980 }
11981 }
11982 }
11983
11984 Thunk::Concat {
11985 dst,
11986 outer,
11987 inner,
11988 total_axis,
11989 inputs,
11990 } => {
11991 let outer = *outer as usize;
11992 let inner = *inner as usize;
11993 let total_axis = *total_axis as usize;
11994 let row_stride = total_axis * inner;
11995 let out_total = outer * row_stride;
11996 unsafe {
11997 let out = sl_mut(*dst, base, out_total);
11998 let mut cum: usize = 0;
11999 for (src_off, in_axis, in_numel) in inputs {
12000 let in_axis = *in_axis as usize;
12001 let copy_per_row = in_axis * inner;
12002 let dst_col_off = cum * inner;
12003 let inp = sl(*src_off, base, (*in_numel as usize).max(1));
12004 concat_copy_rows_f32(
12005 out,
12006 inp,
12007 outer,
12008 copy_per_row,
12009 row_stride,
12010 dst_col_off,
12011 *in_numel as usize,
12012 );
12013 cum += in_axis;
12014 }
12015 }
12016 }
12017
12018 Thunk::ConcatF64 {
12019 dst,
12020 outer,
12021 inner,
12022 total_axis,
12023 inputs,
12024 } => {
12025 let outer = *outer as usize;
12026 let inner = *inner as usize;
12027 let total_axis = *total_axis as usize;
12028 let row_stride = total_axis * inner;
12029 let out_total = outer * row_stride;
12030 unsafe {
12031 let out = sl_mut_f64(*dst, base, out_total);
12032 let mut cum: usize = 0;
12033 for (src_off, in_axis, in_numel) in inputs {
12034 let in_axis = *in_axis as usize;
12035 let copy_per_row = in_axis * inner;
12036 let dst_col_off = cum * inner;
12037 let inp = sl_f64(*src_off, base, (*in_numel as usize).max(1));
12038 concat_copy_rows_f64(
12039 out,
12040 inp,
12041 outer,
12042 copy_per_row,
12043 row_stride,
12044 dst_col_off,
12045 *in_numel as usize,
12046 );
12047 cum += in_axis;
12048 }
12049 }
12050 }
12051
12052 Thunk::Compare {
12053 lhs,
12054 rhs,
12055 dst,
12056 len,
12057 op,
12058 inputs_i64,
12059 inputs_elem_bytes,
12060 dst_elem_bytes,
12061 } => {
12062 let len = *len as usize;
12063 let arena_len = arena_buf.len();
12064 let elem = (*inputs_elem_bytes).max(1) as usize;
12065 let dst_eb = (*dst_elem_bytes).max(1) as usize;
12066 let max_l = (arena_len.saturating_sub(*lhs)) / elem;
12067 let max_r = (arena_len.saturating_sub(*rhs)) / elem;
12068 let max_d = (arena_len.saturating_sub(*dst)) / dst_eb;
12069 let len = len.min(max_l).min(max_r).min(max_d);
12070 if trace_thunks && len > 0 {
12071 eprintln!("[compare] len={len} lhs={} rhs={} dst={}", *lhs, *rhs, *dst);
12072 }
12073 if elem == 1 {
12074 let l = arena_buf[*lhs..*lhs + len].to_vec();
12075 let r = arena_buf[*rhs..*rhs + len].to_vec();
12076 for i in 0..len {
12077 let v = match op {
12078 CmpOp::Eq => l[i] == r[i],
12079 CmpOp::Ne => l[i] != r[i],
12080 CmpOp::Lt => l[i] < r[i],
12081 CmpOp::Le => l[i] <= r[i],
12082 CmpOp::Gt => l[i] > r[i],
12083 CmpOp::Ge => l[i] >= r[i],
12084 };
12085 if *dst_elem_bytes == 1 {
12086 arena_buf[*dst + i] = u8::from(v);
12087 } else {
12088 unsafe {
12089 let o = sl_mut(*dst, base, len);
12090 o[i] = if v { 1.0 } else { 0.0 };
12091 }
12092 }
12093 }
12094 } else if *inputs_i64 != 0 {
12095 unsafe {
12096 let l = sl_i64(*lhs, base, len);
12097 let r = sl_i64(*rhs, base, len);
12098 for i in 0..len {
12099 let v = match op {
12100 CmpOp::Eq => l[i] == r[i],
12101 CmpOp::Ne => l[i] != r[i],
12102 CmpOp::Lt => l[i] < r[i],
12103 CmpOp::Le => l[i] <= r[i],
12104 CmpOp::Gt => l[i] > r[i],
12105 CmpOp::Ge => l[i] >= r[i],
12106 };
12107 if *dst_elem_bytes == 1 {
12108 arena_buf[*dst + i] = u8::from(v);
12109 } else {
12110 let o = sl_mut(*dst, base, len);
12111 o[i] = if v { 1.0 } else { 0.0 };
12112 }
12113 }
12114 }
12115 } else {
12116 unsafe {
12117 let l = sl(*lhs, base, len);
12118 let r = sl(*rhs, base, len);
12119 for i in 0..len {
12120 let v = match op {
12121 CmpOp::Eq => l[i] == r[i],
12122 CmpOp::Ne => l[i] != r[i],
12123 CmpOp::Lt => l[i] < r[i],
12124 CmpOp::Le => l[i] <= r[i],
12125 CmpOp::Gt => l[i] > r[i],
12126 CmpOp::Ge => l[i] >= r[i],
12127 };
12128 if *dst_elem_bytes == 1 {
12129 arena_buf[*dst + i] = u8::from(v);
12130 } else {
12131 let o = sl_mut(*dst, base, len);
12132 o[i] = if v { 1.0 } else { 0.0 };
12133 }
12134 }
12135 }
12136 }
12137 }
12138
12139 Thunk::Where {
12140 cond,
12141 on_true,
12142 on_false,
12143 dst,
12144 len,
12145 elem_bytes,
12146 cond_elem_bytes,
12147 } => {
12148 let len = *len as usize;
12149 let eb = *elem_bytes as usize;
12150 let cond_eb = (*cond_elem_bytes).max(1) as usize;
12151 let arena_len = arena_buf.len();
12152 let len = len
12153 .min((arena_len.saturating_sub(*cond)) / cond_eb)
12154 .min((arena_len.saturating_sub(*on_true)) / eb)
12155 .min((arena_len.saturating_sub(*on_false)) / eb)
12156 .min((arena_len.saturating_sub(*dst)) / eb);
12157 unsafe {
12158 if *elem_bytes == 8 {
12159 let t = sl_i64(*on_true, base, len);
12160 let e = sl_i64(*on_false, base, len);
12161 let o = sl_mut_i64(*dst, base, len);
12162 if *cond_elem_bytes == 1 {
12163 let c = &arena_buf[*cond..*cond + len];
12164 for i in 0..len {
12165 o[i] = if c[i] != 0 { t[i] } else { e[i] };
12166 }
12167 } else {
12168 let c = sl_i64(*cond, base, len);
12169 for i in 0..len {
12170 o[i] = if c[i] != 0 { t[i] } else { e[i] };
12171 }
12172 }
12173 } else if *cond_elem_bytes == 1 {
12174 let c = &arena_buf[*cond..*cond + len];
12175 let t = sl(*on_true, base, len);
12176 let e = sl(*on_false, base, len);
12177 let o = sl_mut(*dst, base, len);
12178 for i in 0..len {
12179 o[i] = if c[i] != 0 { t[i] } else { e[i] };
12180 }
12181 } else {
12182 let c = sl(*cond, base, len);
12183 let t = sl(*on_true, base, len);
12184 let e = sl(*on_false, base, len);
12185 let o = sl_mut(*dst, base, len);
12186 for i in 0..len {
12187 o[i] = if c[i] != 0.0 { t[i] } else { e[i] };
12188 }
12189 }
12190 }
12191 }
12192
12193 Thunk::ScatterAdd {
12194 updates,
12195 indices,
12196 dst,
12197 num_updates,
12198 out_dim,
12199 trailing,
12200 } => {
12201 let num_updates = *num_updates as usize;
12202 let out_dim = *out_dim as usize;
12203 let trailing = *trailing as usize;
12204 unsafe {
12205 let upd = sl(*updates, base, num_updates * trailing);
12206 let ids = sl(*indices, base, num_updates);
12207 let out = sl_mut(*dst, base, out_dim * trailing);
12208 for v in out.iter_mut() {
12210 *v = 0.0;
12211 }
12212 for i in 0..num_updates {
12213 let row = ids[i] as usize;
12214 debug_assert!(row < out_dim, "ScatterAdd index out of range");
12215 let src_off = i * trailing;
12216 let dst_off = row * trailing;
12217 for j in 0..trailing {
12218 out[dst_off + j] += upd[src_off + j];
12219 }
12220 }
12221 }
12222 }
12223
12224 Thunk::GroupedMatMul {
12225 input,
12226 weight,
12227 expert_idx,
12228 dst,
12229 m,
12230 k_dim,
12231 n,
12232 num_experts,
12233 } => {
12234 let m = *m as usize;
12235 let k_dim = *k_dim as usize;
12236 let n = *n as usize;
12237 let num_experts = *num_experts as usize;
12238 unsafe {
12239 let inp = sl(*input, base, m * k_dim);
12240 let wt = sl(*weight, base, num_experts * k_dim * n);
12241 let ids = sl(*expert_idx, base, m);
12242 let out = sl_mut(*dst, base, m * n);
12243
12244 let mut counts = vec![0usize; num_experts];
12247 for i in 0..m {
12248 let e = ids[i] as usize;
12249 debug_assert!(
12250 e < num_experts,
12251 "expert_idx out of range: {e} >= {num_experts}"
12252 );
12253 counts[e] += 1;
12254 }
12255 let mut offsets = vec![0usize; num_experts + 1];
12257 for e in 0..num_experts {
12258 offsets[e + 1] = offsets[e] + counts[e];
12259 }
12260 let mut packed_in = vec![0f32; m * k_dim];
12264 let mut original_pos = vec![0usize; m];
12265 let mut write_idx = vec![0usize; num_experts];
12266 for i in 0..m {
12267 let e = ids[i] as usize;
12268 let dst_row = offsets[e] + write_idx[e];
12269 packed_in[dst_row * k_dim..(dst_row + 1) * k_dim]
12270 .copy_from_slice(&inp[i * k_dim..(i + 1) * k_dim]);
12271 original_pos[dst_row] = i;
12272 write_idx[e] += 1;
12273 }
12274
12275 let mut packed_out = vec![0f32; m * n];
12279 let expert_stride = k_dim * n;
12280 let gmm_ord = crate::moe_residency::next_gmm_ord();
12281 let moe_layer = gmm_ord / 3;
12282 for e in 0..num_experts {
12283 let count = counts[e];
12284 if count == 0 {
12285 continue;
12286 }
12287 crate::moe_residency::record_expert_tokens(moe_layer, e, count);
12288 let in_start = offsets[e];
12289 let in_slice = &packed_in[in_start * k_dim..(in_start + count) * k_dim];
12290 let w_slab: &[f32] =
12291 if !crate::moe_residency::expert_on_device_for_layer(moe_layer, e) {
12292 if let Some(ptr) =
12293 crate::moe_residency::host_expert_weight_ptr(gmm_ord, e)
12294 {
12295 std::slice::from_raw_parts(ptr, expert_stride)
12296 } else {
12297 &wt[e * expert_stride..(e + 1) * expert_stride]
12298 }
12299 } else {
12300 &wt[e * expert_stride..(e + 1) * expert_stride]
12301 };
12302 let out_slice = &mut packed_out[in_start * n..(in_start + count) * n];
12303 crate::blas::sgemm(in_slice, w_slab, out_slice, count, k_dim, n);
12304 }
12305
12306 for packed_idx in 0..m {
12308 let i = original_pos[packed_idx];
12309 out[i * n..(i + 1) * n]
12310 .copy_from_slice(&packed_out[packed_idx * n..(packed_idx + 1) * n]);
12311 }
12312 }
12313 }
12314
12315 Thunk::DequantGroupedMatMulGguf {
12316 input,
12317 w_q,
12318 expert_idx,
12319 dst,
12320 m,
12321 k_dim,
12322 n,
12323 num_experts,
12324 scheme,
12325 } => {
12326 let m = *m as usize;
12327 let k_dim = *k_dim as usize;
12328 let n = *n as usize;
12329 let num_experts = *num_experts as usize;
12330 let block_elems = scheme.gguf_block_size() as usize;
12331 let block_bytes = scheme.gguf_block_bytes() as usize;
12332 let slab_bytes = (k_dim * n) / block_elems * block_bytes;
12333 unsafe {
12334 let inp = sl(*input, base, m * k_dim);
12335 let wt = std::slice::from_raw_parts(
12336 base.add(*w_q) as *const u8,
12337 num_experts * slab_bytes,
12338 );
12339 let ids = sl(*expert_idx, base, m);
12340 let out = sl_mut(*dst, base, m * n);
12341 crate::gguf_matmul::gguf_grouped_matmul_bt(
12342 inp,
12343 wt,
12344 ids,
12345 out,
12346 m,
12347 k_dim,
12348 n,
12349 num_experts,
12350 *scheme,
12351 );
12352 }
12353 }
12354
12355 Thunk::DequantMoEWeightsGguf {
12356 w_q,
12357 dst,
12358 k_dim,
12359 n,
12360 num_experts,
12361 scheme,
12362 } => {
12363 let k_dim = *k_dim as usize;
12364 let n = *n as usize;
12365 let num_experts = *num_experts as usize;
12366 let block_elems = scheme.gguf_block_size() as usize;
12367 let block_bytes = scheme.gguf_block_bytes() as usize;
12368 let slab_bytes = (k_dim * n) / block_elems * block_bytes;
12369 unsafe {
12370 let wt = std::slice::from_raw_parts(
12371 base.add(*w_q) as *const u8,
12372 num_experts * slab_bytes,
12373 );
12374 let out = sl_mut(*dst, base, num_experts * k_dim * n);
12375 crate::gguf_matmul::dequant_moe_weights_to_grouped_f32(
12376 wt,
12377 out,
12378 num_experts,
12379 k_dim,
12380 n,
12381 *scheme,
12382 );
12383 }
12384 }
12385
12386 Thunk::TopK {
12387 src,
12388 dst,
12389 outer,
12390 axis_dim,
12391 k,
12392 indices_i64,
12393 } => {
12394 let outer = *outer as usize;
12395 let axis_dim = *axis_dim as usize;
12396 let k = *k as usize;
12397 unsafe {
12398 let inp = sl(*src, base, outer * axis_dim);
12399 let mut row_buf: Vec<f32> = vec![0.0; axis_dim];
12403 if *indices_i64 != 0 {
12404 let out = sl_mut_i64(*dst, base, outer * k);
12405 for o in 0..outer {
12406 row_buf.copy_from_slice(&inp[o * axis_dim..(o + 1) * axis_dim]);
12407 for ki in 0..k {
12408 let mut best_i = 0usize;
12409 let mut best_v = row_buf[0];
12410 for i in 1..axis_dim {
12411 let v = row_buf[i];
12412 if v > best_v {
12413 best_v = v;
12414 best_i = i;
12415 }
12416 }
12417 out[o * k + ki] = best_i as i64;
12418 row_buf[best_i] = f32::NEG_INFINITY;
12419 }
12420 }
12421 } else {
12422 let out = sl_mut(*dst, base, outer * k);
12423 for o in 0..outer {
12424 row_buf.copy_from_slice(&inp[o * axis_dim..(o + 1) * axis_dim]);
12425 for ki in 0..k {
12426 let mut best_i = 0usize;
12427 let mut best_v = row_buf[0];
12428 for i in 1..axis_dim {
12429 let v = row_buf[i];
12430 if v > best_v {
12431 best_v = v;
12432 best_i = i;
12433 }
12434 }
12435 out[o * k + ki] = best_i as f32;
12436 row_buf[best_i] = f32::NEG_INFINITY;
12437 }
12438 }
12439 if let Some(cap) = schedule.moe_topk_capture.as_ref() {
12440 cap.push_topk_f32(&out[..outer * k], axis_dim);
12441 }
12442 }
12443 }
12444 }
12445
12446 Thunk::Reduce {
12447 src,
12448 dst,
12449 outer,
12450 reduced,
12451 inner,
12452 op,
12453 } => {
12454 let outer = *outer as usize;
12455 let reduced = *reduced as usize;
12456 let inner = *inner as usize;
12457 let in_total = outer * reduced * inner;
12458 let out_total = outer * inner;
12459 unsafe {
12460 let inp = sl(*src, base, in_total);
12461 let out = sl_mut(*dst, base, out_total);
12462 for o in 0..outer {
12463 for i in 0..inner {
12464 let mut acc = match op {
12465 ReduceOp::Max => f32::NEG_INFINITY,
12466 ReduceOp::Min => f32::INFINITY,
12467 ReduceOp::Prod => 1.0f32,
12468 _ => 0.0f32, };
12470 for r in 0..reduced {
12472 let v = inp[o * reduced * inner + r * inner + i];
12473 acc = match op {
12474 ReduceOp::Sum | ReduceOp::Mean => acc + v,
12475 ReduceOp::Max => acc.max(v),
12476 ReduceOp::Min => acc.min(v),
12477 ReduceOp::Prod => acc * v,
12478 };
12479 }
12480 if matches!(op, ReduceOp::Mean) {
12481 acc /= reduced as f32;
12482 }
12483 out[o * inner + i] = acc;
12484 }
12485 }
12486 }
12487 }
12488
12489 Thunk::ArgReduce {
12490 src,
12491 dst,
12492 outer,
12493 reduced,
12494 inner,
12495 is_max,
12496 } => {
12497 let outer = *outer as usize;
12498 let reduced = *reduced as usize;
12499 let inner = *inner as usize;
12500 let in_total = outer * reduced * inner;
12501 let out_total = outer * inner;
12502 unsafe {
12503 let inp = sl(*src, base, in_total);
12504 let out = sl_mut(*dst, base, out_total);
12505 for o in 0..outer {
12506 for i in 0..inner {
12507 let mut best = inp[o * reduced * inner + i];
12508 let mut best_idx = 0usize;
12509 for r in 1..reduced {
12510 let v = inp[o * reduced * inner + r * inner + i];
12511 let better = if *is_max { v > best } else { v < best };
12512 if better {
12513 best = v;
12514 best_idx = r;
12515 }
12516 }
12517 out[o * inner + i] = best_idx as f32;
12518 }
12519 }
12520 }
12521 }
12522
12523 Thunk::Conv2D1x1 {
12524 src,
12525 weight,
12526 dst,
12527 n,
12528 c_in,
12529 c_out,
12530 hw,
12531 } => {
12532 let n = *n as usize;
12533 let c_in = *c_in as usize;
12534 let c_out = *c_out as usize;
12535 let hw = *hw as usize;
12536 unsafe {
12537 let inp = sl(*src, base, n * c_in * hw);
12538 let wt = sl(*weight, base, c_out * c_in);
12539 let out = sl_mut(*dst, base, n * c_out * hw);
12540 for ni in 0..n {
12545 let in_off = ni * c_in * hw;
12546 let out_off = ni * c_out * hw;
12547 crate::blas::sgemm(
12548 wt,
12549 &inp[in_off..in_off + c_in * hw],
12550 &mut out[out_off..out_off + c_out * hw],
12551 c_out,
12552 c_in,
12553 hw,
12554 );
12555 }
12556 }
12557 }
12558
            Thunk::Conv2D {
                src,
                weight,
                dst,
                n,
                c_in,
                h,
                w,
                c_out,
                h_out,
                w_out,
                kh,
                kw,
                sh,
                sw,
                ph,
                pw,
                dh,
                dw,
                groups,
            } => {
                // Direct (loop-nest) grouped 2-D convolution in f32.
                // Input is indexed as [n, c_in, h, w], weights as
                // [c_out, c_in / groups, kh, kw], and output is written as
                // [n, c_out, h_out, w_out].
                // Supports stride (sh, sw), padding (ph, pw) and dilation (dh, dw).
                // Taps landing in the padding are skipped, which is equivalent to zero padding.
                let n = *n as usize;
                let c_in = *c_in as usize;
                let h = *h as usize;
                let w = *w as usize;
                let c_out = *c_out as usize;
                let h_out = *h_out as usize;
                let w_out = *w_out as usize;
                let kh = *kh as usize;
                let kw = *kw as usize;
                let sh = *sh as usize;
                let sw = *sw as usize;
                let ph = *ph as usize;
                let pw = *pw as usize;
                let dh = *dh as usize;
                let dw = *dw as usize;
                let groups = *groups as usize;
                // Channels per group (integer division). NOTE(review): assumes c_in and
                // c_out are multiples of `groups`; confirm this is validated upstream.
                let c_in_per_g = c_in / groups;
                let c_out_per_g = c_out / groups;
                unsafe {
                    let inp = sl(*src, base, n * c_in * h * w);
                    let wt = sl(*weight, base, c_out * c_in_per_g * kh * kw);
                    let out = sl_mut(*dst, base, n * c_out * h_out * w_out);
                    for ni in 0..n {
                        for co in 0..c_out {
                            // Output channel `co` only reads the input channels of its group.
                            let g = co / c_out_per_g;
                            let ci_start = g * c_in_per_g;
                            for ho in 0..h_out {
                                for wo in 0..w_out {
                                    let mut acc = 0f32;
                                    for ci_off in 0..c_in_per_g {
                                        let ci = ci_start + ci_off;
                                        let in_chan = ((ni * c_in) + ci) * h * w;
                                        let wt_chan = ((co * c_in_per_g) + ci_off) * kh * kw;
                                        for ki in 0..kh {
                                            for kj in 0..kw {
                                                // Tap position in padded coordinates.
                                                let hi = ho * sh + ki * dh;
                                                let wi = wo * sw + kj * dw;
                                                // Inside the top/left padding. Checked before
                                                // subtracting so the usize math cannot underflow.
                                                if hi < ph || wi < pw {
                                                    continue;
                                                }
                                                let hi = hi - ph;
                                                let wi = wi - pw;
                                                // Inside the bottom/right padding.
                                                if hi >= h || wi >= w {
                                                    continue;
                                                }
                                                acc += inp[in_chan + hi * w + wi]
                                                    * wt[wt_chan + ki * kw + kj];
                                            }
                                        }
                                    }
                                    out[((ni * c_out) + co) * h_out * w_out + ho * w_out + wo] =
                                        acc;
                                }
                            }
                        }
                    }
                }
            }
12638
12639 Thunk::Pool2D {
12640 src,
12641 dst,
12642 n,
12643 c,
12644 h,
12645 w,
12646 h_out,
12647 w_out,
12648 kh,
12649 kw,
12650 sh,
12651 sw,
12652 ph,
12653 pw,
12654 kind,
12655 } => {
12656 let n = *n as usize;
12657 let c = *c as usize;
12658 let h = *h as usize;
12659 let w = *w as usize;
12660 let h_out = *h_out as usize;
12661 let w_out = *w_out as usize;
12662 let kh = *kh as usize;
12663 let kw = *kw as usize;
12664 let sh = *sh as usize;
12665 let sw = *sw as usize;
12666 let ph = *ph as usize;
12667 let pw = *pw as usize;
12668 let kernel_area = (kh * kw) as f32;
12669 unsafe {
12670 let inp = sl(*src, base, n * c * h * w);
12671 let out = sl_mut(*dst, base, n * c * h_out * w_out);
12672 for ni in 0..n {
12673 for ci in 0..c {
12674 let in_chan = ni * c * h * w + ci * h * w;
12675 let out_chan = ni * c * h_out * w_out + ci * h_out * w_out;
12676 for ho in 0..h_out {
12677 for wo in 0..w_out {
12678 let mut acc = match kind {
12679 ReduceOp::Max => f32::NEG_INFINITY,
12680 _ => 0f32, };
12682 for ki in 0..kh {
12683 for kj in 0..kw {
12684 let hi = ho * sh + ki;
12685 let wi = wo * sw + kj;
12686 if hi < ph || wi < pw {
12688 continue;
12689 }
12690 let hi = hi - ph;
12691 let wi = wi - pw;
12692 if hi >= h || wi >= w {
12693 continue;
12694 }
12695 let v = inp[in_chan + hi * w + wi];
12696 match kind {
12697 ReduceOp::Max => acc = acc.max(v),
12698 _ => acc += v,
12699 }
12700 }
12701 }
12702 if matches!(kind, ReduceOp::Mean) {
12703 acc /= kernel_area;
12704 }
12705 out[out_chan + ho * w_out + wo] = acc;
12706 }
12707 }
12708 }
12709 }
12710 }
12711 }
12712
12713 Thunk::ReluBackward { x, dy, dx, len } => {
12714 let len = *len as usize;
12715 unsafe {
12716 let xs = sl(*x, base, len);
12717 let dys = sl(*dy, base, len);
12718 let out = sl_mut(*dx, base, len);
12719 for i in 0..len {
12720 out[i] = if xs[i] > 0.0 { dys[i] } else { 0.0 };
12721 }
12722 }
12723 }
12724
12725 Thunk::ReluBackwardF64 { x, dy, dx, len } => {
12726 let len = *len as usize;
12727 unsafe {
12728 let xs = sl_f64(*x, base, len);
12729 let dys = sl_f64(*dy, base, len);
12730 let out = sl_mut_f64(*dx, base, len);
12731 for i in 0..len {
12732 out[i] = if xs[i] > 0.0 { dys[i] } else { 0.0 };
12733 }
12734 }
12735 }
12736
12737 Thunk::QMatMul {
12738 x,
12739 w,
12740 bias,
12741 out,
12742 m,
12743 k,
12744 n,
12745 x_zp,
12746 w_zp,
12747 out_zp,
12748 mult,
12749 } => {
12750 let m = *m as usize;
12751 let k = *k as usize;
12752 let n = *n as usize;
12753 unsafe {
12754 let x_ptr = base.add(*x) as *const i8;
12755 let w_ptr = base.add(*w) as *const i8;
12756 let bias_ptr = base.add(*bias) as *const i32;
12757 let out_ptr = base.add(*out) as *mut i8;
12758 for mi in 0..m {
12759 for ni in 0..n {
12760 let mut acc: i32 = *bias_ptr.add(ni);
12761 for ki in 0..k {
12762 let xv = *x_ptr.add(mi * k + ki) as i32 - *x_zp;
12763 let wv = *w_ptr.add(ki * n + ni) as i32 - *w_zp;
12764 acc += xv * wv;
12765 }
12766 let r = (acc as f32 * *mult).round() as i32 + *out_zp;
12769 let r = r.clamp(-128, 127) as i8;
12770 *out_ptr.add(mi * n + ni) = r;
12771 }
12772 }
12773 }
12774 }
12775
            Thunk::QConv2d {
                x,
                w,
                bias,
                out,
                n,
                c_in,
                h,
                w_in,
                c_out,
                h_out,
                w_out,
                kh,
                kw,
                sh,
                sw,
                ph,
                pw,
                dh,
                dw,
                groups,
                x_zp,
                w_zp,
                out_zp,
                mult,
            } => {
                // Direct grouped int8 2-D convolution with i32 accumulation. It uses the
                // same geometry as the f32 Conv2D arm: input [n, c_in, h, w_in],
                // weights [c_out, c_in / groups, kh, kw], and output
                // [n, c_out, h_out, w_out], all i8.
                // Each accumulator starts at bias[co]. Zero points are subtracted from
                // every operand. The sum is scaled by `mult`, rounded, shifted by
                // `out_zp` and saturated to [-128, 127].
                let n = *n as usize;
                let c_in = *c_in as usize;
                let h = *h as usize;
                let w_in = *w_in as usize;
                let c_out = *c_out as usize;
                let h_out = *h_out as usize;
                let w_out = *w_out as usize;
                let kh = *kh as usize;
                let kw = *kw as usize;
                let sh = *sh as usize;
                let sw = *sw as usize;
                let ph = *ph as usize;
                let pw = *pw as usize;
                let dh = *dh as usize;
                let dw = *dw as usize;
                let groups = *groups as usize;
                // NOTE(review): assumes c_in and c_out are multiples of `groups`.
                let c_in_per_g = c_in / groups;
                let c_out_per_g = c_out / groups;
                unsafe {
                    let x_ptr = base.add(*x) as *const i8;
                    let w_ptr = base.add(*w) as *const i8;
                    let bias_ptr = base.add(*bias) as *const i32;
                    let out_ptr = base.add(*out) as *mut i8;
                    for ni in 0..n {
                        for co in 0..c_out {
                            // Input channel range belonging to this output channel's group.
                            let g = co / c_out_per_g;
                            let ci_start = g * c_in_per_g;
                            for ho in 0..h_out {
                                for wo in 0..w_out {
                                    let mut acc: i32 = *bias_ptr.add(co);
                                    for ci_off in 0..c_in_per_g {
                                        let ci = ci_start + ci_off;
                                        let in_chan = ((ni * c_in) + ci) * h * w_in;
                                        let wt_chan = ((co * c_in_per_g) + ci_off) * kh * kw;
                                        for ki in 0..kh {
                                            for kj in 0..kw {
                                                // Tap position in padded coordinates; padding
                                                // taps are skipped (zero-point padding).
                                                let hi = ho * sh + ki * dh;
                                                let wi = wo * sw + kj * dw;
                                                if hi < ph || wi < pw {
                                                    continue;
                                                }
                                                let hi = hi - ph;
                                                let wi = wi - pw;
                                                if hi >= h || wi >= w_in {
                                                    continue;
                                                }
                                                let xv = *x_ptr.add(in_chan + hi * w_in + wi)
                                                    as i32
                                                    - *x_zp;
                                                let wv = *w_ptr.add(wt_chan + ki * kw + kj) as i32
                                                    - *w_zp;
                                                acc += xv * wv;
                                            }
                                        }
                                    }
                                    // Requantize. The f32 -> i32 `as` cast saturates, and the
                                    // clamp keeps the result in i8 range.
                                    let r = (acc as f32 * *mult).round() as i32 + *out_zp;
                                    let r = r.clamp(-128, 127) as i8;
                                    let dst = ((ni * c_out) + co) * h_out * w_out + ho * w_out + wo;
                                    *out_ptr.add(dst) = r;
                                }
                            }
                        }
                    }
                }
            }
12867
12868 Thunk::Quantize {
12869 x,
12870 q,
12871 len,
12872 chan_axis: _,
12873 chan_dim,
12874 inner,
12875 scales,
12876 zero_points,
12877 } => {
12878 let len = *len as usize;
12879 let chan_dim = *chan_dim as usize;
12880 let inner = *inner as usize;
12881 unsafe {
12882 let xs = sl(*x, base, len);
12883 let q_ptr = base.add(*q) as *mut i8;
12884 for i in 0..len {
12885 let c = if chan_dim == 1 {
12886 0
12887 } else {
12888 (i / inner) % chan_dim
12889 };
12890 let inv_scale = 1.0 / scales[c];
12891 let zp = zero_points[c];
12892 let v = (xs[i] * inv_scale).round() as i32 + zp;
12893 *q_ptr.add(i) = v.clamp(-128, 127) as i8;
12894 }
12895 }
12896 }
12897
12898 Thunk::Dequantize {
12899 q,
12900 x,
12901 len,
12902 chan_axis: _,
12903 chan_dim,
12904 inner,
12905 scales,
12906 zero_points,
12907 } => {
12908 let len = *len as usize;
12909 let chan_dim = *chan_dim as usize;
12910 let inner = *inner as usize;
12911 unsafe {
12912 let q_ptr = base.add(*q) as *const i8;
12913 let out = sl_mut(*x, base, len);
12914 for i in 0..len {
12915 let c = if chan_dim == 1 {
12916 0
12917 } else {
12918 (i / inner) % chan_dim
12919 };
12920 let scale = scales[c];
12921 let zp = zero_points[c];
12922 let qv = *q_ptr.add(i) as i32;
12923 out[i] = (qv - zp) as f32 * scale;
12924 }
12925 }
12926 }
12927
12928 Thunk::FakeQuantize {
12929 x,
12930 out,
12931 len,
12932 chan_axis: _,
12933 chan_dim,
12934 inner,
12935 bits,
12936 ste: _,
12937 scale_mode,
12938 state_off,
12939 } => {
12940 use rlx_ir::op::ScaleMode;
12941 let len = *len as usize;
12942 let chan_dim = *chan_dim as usize;
12943 let inner = *inner as usize;
12944 let q_max: f32 = match *bits {
12945 8 => 127.0,
12946 4 => 7.0,
12947 2 => 1.0,
12948 n => panic!("FakeQuantize: unsupported bits {n}"),
12949 };
12950 unsafe {
12951 let xs = sl(*x, base, len);
12952 let outs = sl_mut(*out, base, len);
12953
12954 let mut scale = vec![0f32; chan_dim];
12955 match scale_mode {
12956 ScaleMode::PerBatch => {
12957 let mut max_abs = vec![0f32; chan_dim];
12958 for i in 0..len {
12959 let c = if chan_dim == 1 {
12960 0
12961 } else {
12962 (i / inner) % chan_dim
12963 };
12964 let a = xs[i].abs();
12965 if a > max_abs[c] {
12966 max_abs[c] = a;
12967 }
12968 }
12969 for c in 0..chan_dim {
12970 scale[c] = (max_abs[c] / q_max).max(1e-12);
12971 }
12972 }
12973 ScaleMode::EMA { decay } => {
12974 let mut max_abs = vec![0f32; chan_dim];
12977 for i in 0..len {
12978 let c = if chan_dim == 1 {
12979 0
12980 } else {
12981 (i / inner) % chan_dim
12982 };
12983 let a = xs[i].abs();
12984 if a > max_abs[c] {
12985 max_abs[c] = a;
12986 }
12987 }
12988 let state =
12989 sl_mut(state_off.expect("EMA needs state_off"), base, chan_dim);
12990 for c in 0..chan_dim {
12991 let cur = (max_abs[c] / q_max).max(1e-12);
12992 let blended = if state[c] <= 0.0 {
12994 cur
12995 } else {
12996 *decay * state[c] + (1.0 - *decay) * cur
12997 };
12998 state[c] = blended;
12999 scale[c] = blended;
13000 }
13001 }
13002 ScaleMode::Fixed => {
13003 let state =
13004 sl(state_off.expect("Fixed needs state_off"), base, chan_dim);
13005 for c in 0..chan_dim {
13006 scale[c] = state[c].max(1e-12);
13007 }
13008 }
13009 }
13010
13011 for i in 0..len {
13012 let c = if chan_dim == 1 {
13013 0
13014 } else {
13015 (i / inner) % chan_dim
13016 };
13017 let s = scale[c];
13018 let qv = (xs[i] / s).round().clamp(-q_max, q_max);
13019 outs[i] = qv * s;
13020 }
13021 }
13022 }
13023
13024 Thunk::ActivationBackward {
13025 x,
13026 dy,
13027 dx,
13028 len,
13029 kind,
13030 } => {
13031 let len = *len as usize;
13032 unsafe {
13033 let xs = sl(*x, base, len);
13034 let dys = sl(*dy, base, len);
13035 let out = sl_mut(*dx, base, len);
13036 activation_backward_kernel(*kind, xs, dys, out);
13037 }
13038 }
13039
13040 Thunk::ActivationBackwardF64 {
13041 x,
13042 dy,
13043 dx,
13044 len,
13045 kind,
13046 } => {
13047 let len = *len as usize;
13048 unsafe {
13049 let xs = sl_f64(*x, base, len);
13050 let dys = sl_f64(*dy, base, len);
13051 let out = sl_mut_f64(*dx, base, len);
13052 activation_backward_kernel_f64(*kind, xs, dys, out);
13053 }
13054 }
13055
13056 Thunk::FakeQuantizeLSQ {
13057 x,
13058 scale_off,
13059 out,
13060 len,
13061 chan_axis: _,
13062 chan_dim,
13063 inner,
13064 bits,
13065 } => {
13066 let len = *len as usize;
13067 let chan_dim = *chan_dim as usize;
13068 let inner = *inner as usize;
13069 let q_max: f32 = match *bits {
13070 8 => 127.0,
13071 4 => 7.0,
13072 2 => 1.0,
13073 n => panic!("FakeQuantizeLSQ: bad bits {n}"),
13074 };
13075 unsafe {
13076 let xs = sl(*x, base, len);
13077 let scale = sl(*scale_off, base, chan_dim);
13078 let outs = sl_mut(*out, base, len);
13079 for i in 0..len {
13080 let c = if chan_dim == 1 {
13081 0
13082 } else {
13083 (i / inner) % chan_dim
13084 };
13085 let s = scale[c].max(1e-12);
13086 let qv = (xs[i] / s).round().clamp(-q_max, q_max);
13087 outs[i] = qv * s;
13088 }
13089 }
13090 }
13091
13092 Thunk::FakeQuantizeLSQBackwardX {
13093 x,
13094 scale_off,
13095 dy,
13096 dx,
13097 len,
13098 chan_axis: _,
13099 chan_dim,
13100 inner,
13101 bits,
13102 } => {
13103 let len = *len as usize;
13104 let chan_dim = *chan_dim as usize;
13105 let inner = *inner as usize;
13106 let q_max: f32 = match *bits {
13107 8 => 127.0,
13108 4 => 7.0,
13109 2 => 1.0,
13110 n => panic!("FakeQuantizeLSQBackwardX: bad bits {n}"),
13111 };
13112 unsafe {
13113 let xs = sl(*x, base, len);
13114 let scale = sl(*scale_off, base, chan_dim);
13115 let dys = sl(*dy, base, len);
13116 let outs = sl_mut(*dx, base, len);
13117 for i in 0..len {
13119 let c = if chan_dim == 1 {
13120 0
13121 } else {
13122 (i / inner) % chan_dim
13123 };
13124 let z = xs[i] / scale[c].max(1e-12);
13125 outs[i] = if z.abs() <= q_max { dys[i] } else { 0.0 };
13126 }
13127 }
13128 }
13129
13130 Thunk::FakeQuantizeLSQBackwardScale {
13131 x,
13132 scale_off,
13133 dy,
13134 dscale,
13135 len,
13136 chan_axis: _,
13137 chan_dim,
13138 inner,
13139 bits,
13140 } => {
13141 let len = *len as usize;
13142 let chan_dim = *chan_dim as usize;
13143 let inner = *inner as usize;
13144 let q_max: f32 = match *bits {
13145 8 => 127.0,
13146 4 => 7.0,
13147 2 => 1.0,
13148 n => panic!("FakeQuantizeLSQBackwardScale: bad bits {n}"),
13149 };
13150 unsafe {
13151 let xs = sl(*x, base, len);
13152 let scale = sl(*scale_off, base, chan_dim);
13153 let dys = sl(*dy, base, len);
13154 let outs = sl_mut(*dscale, base, chan_dim);
13155 for v in outs.iter_mut() {
13156 *v = 0.0;
13157 }
13158 for i in 0..len {
13161 let c = if chan_dim == 1 {
13162 0
13163 } else {
13164 (i / inner) % chan_dim
13165 };
13166 let s = scale[c].max(1e-12);
13167 let z = xs[i] / s;
13168 let psi = if z.abs() <= q_max {
13169 -z + z.round()
13170 } else if z > 0.0 {
13171 q_max
13172 } else {
13173 -q_max
13174 };
13175 outs[c] += psi * dys[i];
13176 }
13177 }
13178 }
13179
13180 Thunk::FakeQuantizeBackward {
13181 x,
13182 dy,
13183 dx,
13184 len,
13185 chan_axis: _,
13186 chan_dim,
13187 inner,
13188 bits,
13189 ste,
13190 } => {
13191 use rlx_ir::op::SteKind;
13192 let len = *len as usize;
13193 let chan_dim = *chan_dim as usize;
13194 let inner = *inner as usize;
13195 let q_max: f32 = match *bits {
13196 8 => 127.0,
13197 4 => 7.0,
13198 2 => 1.0,
13199 n => panic!("FakeQuantizeBackward: bad bits {n}"),
13200 };
13201 unsafe {
13202 let xs = sl(*x, base, len);
13203 let dys = sl(*dy, base, len);
13204 let outs = sl_mut(*dx, base, len);
13205
13206 let mut max_abs = vec![0f32; chan_dim];
13208 for i in 0..len {
13209 let c = if chan_dim == 1 {
13210 0
13211 } else {
13212 (i / inner) % chan_dim
13213 };
13214 let a = xs[i].abs();
13215 if a > max_abs[c] {
13216 max_abs[c] = a;
13217 }
13218 }
13219 let mut scale = vec![0f32; chan_dim];
13220 for c in 0..chan_dim {
13221 scale[c] = (max_abs[c] / q_max).max(1e-12);
13222 }
13223
13224 match *ste {
13225 SteKind::Identity => {
13226 outs.copy_from_slice(dys);
13228 }
13229 SteKind::ClippedIdentity => {
13230 for i in 0..len {
13233 let c = if chan_dim == 1 {
13234 0
13235 } else {
13236 (i / inner) % chan_dim
13237 };
13238 let bound = q_max * scale[c];
13239 outs[i] = if xs[i].abs() <= bound { dys[i] } else { 0.0 };
13240 }
13241 }
13242 SteKind::Tanh => {
13243 for i in 0..len {
13245 let c = if chan_dim == 1 {
13246 0
13247 } else {
13248 (i / inner) % chan_dim
13249 };
13250 let t = (xs[i] / scale[c]).tanh();
13251 outs[i] = dys[i] * (1.0 - t * t);
13252 }
13253 }
13254 SteKind::HardTanh => {
13255 for i in 0..len {
13257 let c = if chan_dim == 1 {
13258 0
13259 } else {
13260 (i / inner) % chan_dim
13261 };
13262 let bound = q_max * scale[c];
13263 let attenuation = (1.0 - (xs[i] / bound).abs()).max(0.0);
13264 outs[i] = dys[i] * attenuation;
13265 }
13266 }
13267 }
13268 }
13269 }
13270
13271 Thunk::LayerNormBackwardInput {
13272 x,
13273 gamma,
13274 dy,
13275 dx,
13276 rows,
13277 h,
13278 eps,
13279 } => {
13280 let rows = *rows as usize;
13281 let h = *h as usize;
13282 let eps = *eps;
13283 unsafe {
13284 let xs = sl(*x, base, rows * h);
13285 let g = sl(*gamma, base, h);
13286 let dys = sl(*dy, base, rows * h);
13287 let out = sl_mut(*dx, base, rows * h);
13288 let n_inv = 1.0 / h as f32;
13289 for r in 0..rows {
13290 let xr = &xs[r * h..(r + 1) * h];
13291 let dyr = &dys[r * h..(r + 1) * h];
13292 let mut sum = 0f32;
13295 for &v in xr {
13296 sum += v;
13297 }
13298 let mean = sum * n_inv;
13299 let mut var = 0f32;
13300 for &v in xr {
13301 let d = v - mean;
13302 var += d * d;
13303 }
13304 let inv_std = 1.0 / (var * n_inv + eps).sqrt();
13305
13306 let mut s_sy = 0f32;
13309 let mut s_sxh = 0f32;
13310 for d in 0..h {
13311 let xh = (xr[d] - mean) * inv_std;
13312 let sy = dyr[d] * g[d];
13313 s_sy += sy;
13314 s_sxh += sy * xh;
13315 }
13316 let m_sy = s_sy * n_inv;
13317 let m_sxh = s_sxh * n_inv;
13318
13319 for d in 0..h {
13320 let xh = (xr[d] - mean) * inv_std;
13321 let sy = dyr[d] * g[d];
13322 out[r * h + d] = inv_std * (sy - m_sy - xh * m_sxh);
13323 }
13324 }
13325 }
13326 }
13327
13328 Thunk::BatchNormInferenceBackwardInput {
13329 x,
13330 gamma,
13331 mean,
13332 var,
13333 dy,
13334 dx,
13335 count,
13336 channels,
13337 eps,
13338 } => {
13339 let count = *count as usize;
13340 let c = *channels as usize;
13341 let n = count * c;
13342 let eps = *eps;
13343 unsafe {
13344 crate::kernels::batch_norm_inference_backward_input(
13345 sl(*x, base, n),
13346 sl(*gamma, base, c),
13347 sl(*mean, base, c),
13348 sl(*var, base, c),
13349 sl(*dy, base, n),
13350 sl_mut(*dx, base, n),
13351 c,
13352 eps,
13353 );
13354 }
13355 }
13356
13357 Thunk::BatchNormInferenceBackwardGamma {
13358 x,
13359 mean,
13360 var,
13361 dy,
13362 dgamma,
13363 count,
13364 channels,
13365 eps,
13366 } => {
13367 let count = *count as usize;
13368 let c = *channels as usize;
13369 let n = count * c;
13370 let eps = *eps;
13371 unsafe {
13372 crate::kernels::batch_norm_inference_backward_gamma(
13373 sl(*x, base, n),
13374 sl(*mean, base, c),
13375 sl(*var, base, c),
13376 sl(*dy, base, n),
13377 sl_mut(*dgamma, base, c),
13378 c,
13379 eps,
13380 );
13381 }
13382 }
13383
13384 Thunk::BatchNormInferenceBackwardBeta {
13385 dy,
13386 dbeta,
13387 count,
13388 channels,
13389 } => {
13390 let count = *count as usize;
13391 let c = *channels as usize;
13392 let n = count * c;
13393 unsafe {
13394 crate::kernels::batch_norm_inference_backward_beta(
13395 sl(*dy, base, n),
13396 sl_mut(*dbeta, base, c),
13397 c,
13398 );
13399 }
13400 }
13401
13402 Thunk::LayerNormBackwardGamma {
13403 x,
13404 dy,
13405 dgamma,
13406 rows,
13407 h,
13408 eps,
13409 } => {
13410 let rows = *rows as usize;
13411 let h = *h as usize;
13412 let eps = *eps;
13413 unsafe {
13414 let xs = sl(*x, base, rows * h);
13415 let dys = sl(*dy, base, rows * h);
13416 let out = sl_mut(*dgamma, base, h);
13417 for v in out.iter_mut() {
13418 *v = 0.0;
13419 }
13420 let n_inv = 1.0 / h as f32;
13421 for r in 0..rows {
13422 let xr = &xs[r * h..(r + 1) * h];
13423 let dyr = &dys[r * h..(r + 1) * h];
13424 let mut sum = 0f32;
13425 for &v in xr {
13426 sum += v;
13427 }
13428 let mean = sum * n_inv;
13429 let mut var = 0f32;
13430 for &v in xr {
13431 let d = v - mean;
13432 var += d * d;
13433 }
13434 let inv_std = 1.0 / (var * n_inv + eps).sqrt();
13435 for d in 0..h {
13436 let xh = (xr[d] - mean) * inv_std;
13437 out[d] += dyr[d] * xh;
13438 }
13439 }
13440 }
13441 }
13442
13443 Thunk::RmsNormBackwardInput {
13444 x,
13445 gamma,
13446 beta,
13447 dy,
13448 dx,
13449 rows,
13450 h,
13451 eps,
13452 } => {
13453 let (rows, h) = (*rows as usize, *h as usize);
13454 unsafe {
13455 let xs = sl(*x, base, rows * h);
13456 let g = sl(*gamma, base, h);
13457 let b = sl(*beta, base, h);
13458 let dys = sl(*dy, base, rows * h);
13459 let out = sl_mut(*dx, base, rows * h);
13460 let mut dg = vec![0f32; h];
13461 let mut db = vec![0f32; h];
13462 for r in 0..rows {
13463 crate::training_bwd::rms_norm_backward_row(
13464 &xs[r * h..(r + 1) * h],
13465 g,
13466 b,
13467 &dys[r * h..(r + 1) * h],
13468 &mut out[r * h..(r + 1) * h],
13469 &mut dg,
13470 &mut db,
13471 *eps,
13472 );
13473 }
13474 }
13475 }
13476
13477 Thunk::RmsNormBackwardGamma {
13478 x,
13479 gamma,
13480 beta,
13481 dy,
13482 dgamma,
13483 rows,
13484 h,
13485 eps,
13486 } => {
13487 let (rows, h) = (*rows as usize, *h as usize);
13488 unsafe {
13489 let xs = sl(*x, base, rows * h);
13490 let g = sl(*gamma, base, h);
13491 let b = sl(*beta, base, h);
13492 let dys = sl(*dy, base, rows * h);
13493 let out = sl_mut(*dgamma, base, h);
13494 for v in out.iter_mut() {
13495 *v = 0.0;
13496 }
13497 let mut dx = vec![0f32; h];
13498 let mut db = vec![0f32; h];
13499 for r in 0..rows {
13500 crate::training_bwd::rms_norm_backward_row(
13501 &xs[r * h..(r + 1) * h],
13502 g,
13503 b,
13504 &dys[r * h..(r + 1) * h],
13505 &mut dx,
13506 &mut *out,
13507 &mut db,
13508 *eps,
13509 );
13510 }
13511 }
13512 }
13513
13514 Thunk::RmsNormBackwardBeta {
13515 x,
13516 gamma,
13517 beta,
13518 dy,
13519 dbeta,
13520 rows,
13521 h,
13522 eps,
13523 } => {
13524 let (rows, h) = (*rows as usize, *h as usize);
13525 unsafe {
13526 let xs = sl(*x, base, rows * h);
13527 let g = sl(*gamma, base, h);
13528 let b = sl(*beta, base, h);
13529 let dys = sl(*dy, base, rows * h);
13530 let out = sl_mut(*dbeta, base, h);
13531 for v in out.iter_mut() {
13532 *v = 0.0;
13533 }
13534 let mut dx = vec![0f32; h];
13535 let mut dg = vec![0f32; h];
13536 for r in 0..rows {
13537 crate::training_bwd::rms_norm_backward_row(
13538 &xs[r * h..(r + 1) * h],
13539 g,
13540 b,
13541 &dys[r * h..(r + 1) * h],
13542 &mut dx,
13543 &mut dg,
13544 &mut *out,
13545 *eps,
13546 );
13547 }
13548 }
13549 }
13550
            Thunk::RopeBackward {
                dy,
                cos,
                sin,
                dx,
                batch,
                seq,
                hidden,
                head_dim,
                n_rot,
                cos_len,
            } => {
                // Gradient of the rotary position embedding.
                // `dy`/`dx` are indexed as [batch, seq, hidden], with `hidden` split into
                // `hidden / head_dim` heads of `head_dim` values. Each head slice is
                // passed to `rope_backward_row` with the cos/sin rows for its position.
                // `n_rot` is forwarded unchanged (presumably the number of rotated dims;
                // confirm in `rope_backward_row`).
                let (b, s, hs, dh, nr, cl) = (
                    *batch as usize,
                    *seq as usize,
                    *hidden as usize,
                    *head_dim as usize,
                    *n_rot as usize,
                    *cos_len as usize,
                );
                let nh = hs / dh;
                // The cos/sin tables are read `head_dim / 2` entries per position.
                let tab_half = dh / 2;
                unsafe {
                    let dys = sl(*dy, base, b * s * hs);
                    let cos_tab = sl(*cos, base, cl);
                    let sin_tab = sl(*sin, base, cl);
                    let out = sl_mut(*dx, base, b * s * hs);
                    for bi in 0..b {
                        for si in 0..s {
                            // Table row for position `si`, wrapped modulo the table length.
                            // `cl.max(1)` avoids modulo-by-zero on an empty table.
                            // NOTE(review): the slices below panic whenever
                            // `tab_off + tab_half.min(cl)` exceeds `cl`, e.g. when `cl` is
                            // not a multiple of `tab_half`.
                            let tab_off = si.saturating_mul(tab_half) % cl.max(1);
                            let cp = &cos_tab[tab_off..tab_off + tab_half.min(cl)];
                            let sp = &sin_tab[tab_off..tab_off + tab_half.min(cl)];
                            for hi in 0..nh {
                                // Start of head `hi` at (bi, si).
                                let base_idx = bi * s * hs + si * hs + hi * dh;
                                crate::training_bwd::rope_backward_row(
                                    &dys[base_idx..base_idx + dh],
                                    cp,
                                    sp,
                                    &mut out[base_idx..base_idx + dh],
                                    dh,
                                    nr,
                                );
                            }
                        }
                    }
                }
            }
13598
13599 Thunk::CumsumBackward {
13600 dy,
13601 dx,
13602 rows,
13603 cols,
13604 exclusive,
13605 } => {
13606 let (rows, cols) = (*rows as usize, *cols as usize);
13607 unsafe {
13608 let dys = sl(*dy, base, rows * cols);
13609 let out = sl_mut(*dx, base, rows * cols);
13610 for r in 0..rows {
13611 crate::training_bwd::cumsum_backward_row(
13612 &dys[r * cols..(r + 1) * cols],
13613 &mut out[r * cols..(r + 1) * cols],
13614 *exclusive,
13615 );
13616 }
13617 }
13618 }
13619
13620 Thunk::GroupNormBackwardInput {
13621 x,
13622 gamma,
13623 beta: _beta,
13624 dy,
13625 dx,
13626 n,
13627 c,
13628 h,
13629 w,
13630 num_groups,
13631 eps,
13632 } => {
13633 let (n, c, h, w) = (*n as usize, *c as usize, *h as usize, *w as usize);
13634 let plane = c * h * w;
13635 unsafe {
13636 let xs = sl(*x, base, n * plane);
13637 let g = sl(*gamma, base, c);
13638 let dys = sl(*dy, base, n * plane);
13639 let out = sl_mut(*dx, base, n * plane);
13640 crate::training_bwd::group_norm_backward_input_nchw(
13641 xs,
13642 g,
13643 dys,
13644 out,
13645 n,
13646 c,
13647 h,
13648 w,
13649 *num_groups as usize,
13650 *eps,
13651 );
13652 }
13653 }
13654
13655 Thunk::GroupNormBackwardGamma {
13656 x,
13657 dy,
13658 dgamma,
13659 n,
13660 c,
13661 h,
13662 w,
13663 num_groups,
13664 eps,
13665 } => {
13666 let (n, c, h, w) = (*n as usize, *c as usize, *h as usize, *w as usize);
13667 let plane = c * h * w;
13668 unsafe {
13669 let xs = sl(*x, base, n * plane);
13670 let dys = sl(*dy, base, n * plane);
13671 let out = sl_mut(*dgamma, base, c);
13672 crate::training_bwd::group_norm_backward_gamma_nchw(
13673 xs,
13674 dys,
13675 out,
13676 n,
13677 c,
13678 h,
13679 w,
13680 *num_groups as usize,
13681 *eps,
13682 );
13683 }
13684 }
13685
13686 Thunk::GroupNormBackwardBeta {
13687 dy,
13688 dbeta,
13689 n,
13690 c,
13691 h,
13692 w,
13693 } => {
13694 let (n, c, h, w) = (*n as usize, *c as usize, *h as usize, *w as usize);
13695 let plane = c * h * w;
13696 unsafe {
13697 let dys = sl(*dy, base, n * plane);
13698 let out = sl_mut(*dbeta, base, c);
13699 crate::training_bwd::group_norm_backward_beta_nchw(dys, out, n, c, h, w);
13700 }
13701 }
13702
13703 Thunk::GatherBackward {
13704 dy,
13705 indices,
13706 dst,
13707 outer,
13708 axis_dim,
13709 num_idx,
13710 trailing,
13711 } => {
13712 let (outer, axis_dim, num_idx, trailing) = (
13713 *outer as usize,
13714 *axis_dim as usize,
13715 *num_idx as usize,
13716 *trailing as usize,
13717 );
13718 unsafe {
13719 let dys = sl(*dy, base, outer * num_idx * trailing);
13720 let ids = sl(*indices, base, num_idx);
13721 let out = sl_mut(*dst, base, outer * axis_dim * trailing);
13722 for v in out.iter_mut() {
13723 *v = 0.0;
13724 }
13725 crate::training_bwd::gather_axis_backward(
13726 dys, ids, out, outer, axis_dim, num_idx, trailing,
13727 );
13728 }
13729 }
13730
13731 Thunk::MaxPool2dBackward {
13732 x,
13733 dy,
13734 dx,
13735 n,
13736 c,
13737 h,
13738 w,
13739 h_out,
13740 w_out,
13741 kh,
13742 kw,
13743 sh,
13744 sw,
13745 ph,
13746 pw,
13747 } => unsafe {
13748 execute_maxpool2d_backward_f32(
13749 *x, *dy, *dx, *n, *c, *h, *w, *h_out, *w_out, *kh, *kw, *sh, *sw, *ph, *pw,
13750 base,
13751 );
13752 },
13753
            Thunk::Conv2dBackwardInput {
                dy,
                w,
                dx,
                n,
                c_in,
                h,
                w_in,
                c_out,
                h_out,
                w_out,
                kh,
                kw,
                sh,
                sw,
                ph,
                pw,
                dh,
                dw,
                groups,
            } => {
                // Input gradient of a grouped, strided, padded, dilated 2-D convolution
                // (NCHW), computed per image and per group in two steps:
                //   1. GEMM: dcol[m_dim, n_dim] = W_g^T * dy_g, where
                //      W_g is [k_dim, m_dim] and dy_g is [k_dim, n_dim];
                //   2. col2im scatters dcol back onto that group's pixels of dx.
                // Here m_dim = c_in_per_g * kh * kw, n_dim = h_out * w_out and
                // k_dim = c_out_per_g.
                let n = *n as usize;
                let c_in = *c_in as usize;
                let h = *h as usize;
                let w_in = *w_in as usize;
                let c_out = *c_out as usize;
                let h_out = *h_out as usize;
                let w_out = *w_out as usize;
                let kh = *kh as usize;
                let kw = *kw as usize;
                let sh = *sh as usize;
                let sw = *sw as usize;
                let ph = *ph as usize;
                let pw = *pw as usize;
                let dh = *dh as usize;
                let dw = *dw as usize;
                let groups = *groups as usize;
                // NOTE(review): assumes c_in and c_out are multiples of `groups`.
                let c_in_per_g = c_in / groups;
                let c_out_per_g = c_out / groups;

                // GEMM shape: one dcol row per (input channel, kernel tap), one column per
                // output pixel, contracted over the group's output channels.
                let m_dim = c_in_per_g * kh * kw;
                let n_dim = h_out * w_out;
                let k_dim = c_out_per_g;

                // Element strides to the start of image `ni` / group `g` in each buffer.
                let dy_stride_n = c_out * h_out * w_out;
                let dy_stride_g = c_out_per_g * h_out * w_out;
                let w_stride_g = c_out_per_g * c_in_per_g * kh * kw;
                let dx_stride_n = c_in * h * w_in;
                let dx_stride_g = c_in_per_g * h * w_in;

                unsafe {
                    let dys = sl(*dy, base, n * c_out * h_out * w_out);
                    let ws = sl(*w, base, c_out * c_in_per_g * kh * kw);
                    let dxs = sl_mut(*dx, base, n * c_in * h * w_in);
                    // Cleared up front. Presumably col2im accumulates into dx, since
                    // overlapping kernel windows touch the same input pixel.
                    for v in dxs.iter_mut() {
                        *v = 0.0;
                    }

                    // Column buffer reused for every (image, group) pair.
                    let mut dcol = vec![0f32; m_dim * n_dim];

                    for ni in 0..n {
                        for g in 0..groups {
                            let w_g_off = g * w_stride_g;
                            let dy_n_g_off = ni * dy_stride_n + g * dy_stride_g;
                            let dx_n_g_off = ni * dx_stride_n + g * dx_stride_g;

                            // dcol = W_g^T * dy_g with alpha = 1, beta = 0 (overwrite).
                            // NOTE(review): assumed argument order is (a, b, c, m, n, k,
                            // alpha, beta, lda, ldb, ldc, trans_a, trans_b); confirm
                            // against `crate::blas::sgemm_general`.
                            crate::blas::sgemm_general(
                                ws.as_ptr().add(w_g_off),
                                dys.as_ptr().add(dy_n_g_off),
                                dcol.as_mut_ptr(),
                                m_dim,
                                n_dim,
                                k_dim,
                                1.0,
                                0.0,
                                m_dim,
                                n_dim,
                                n_dim,
                                true,
                                false,
                            );

                            // Scatter the column gradients back onto this group's
                            // channels of image `ni`.
                            col2im(
                                &dcol,
                                &mut dxs[dx_n_g_off..dx_n_g_off + dx_stride_g],
                                c_in_per_g,
                                h,
                                w_in,
                                h_out,
                                w_out,
                                kh,
                                kw,
                                sh,
                                sw,
                                ph,
                                pw,
                                dh,
                                dw,
                            );
                        }
                    }
                }
            }
13874
13875 Thunk::Conv2dBackwardWeight {
13876 x,
13877 dy,
13878 dw,
13879 n,
13880 c_in,
13881 h,
13882 w,
13883 c_out,
13884 h_out,
13885 w_out,
13886 kh,
13887 kw,
13888 sh,
13889 sw,
13890 ph,
13891 pw,
13892 dh,
13893 dw_dil,
13894 groups,
13895 } => {
13896 let n = *n as usize;
13897 let c_in = *c_in as usize;
13898 let h = *h as usize;
13899 let w = *w as usize;
13900 let c_out = *c_out as usize;
13911 let h_out = *h_out as usize;
13912 let w_out = *w_out as usize;
13913 let kh = *kh as usize;
13914 let kw = *kw as usize;
13915 let sh = *sh as usize;
13916 let sw = *sw as usize;
13917 let ph = *ph as usize;
13918 let pw = *pw as usize;
13919 let dh = *dh as usize;
13920 let dw_dil = *dw_dil as usize;
13921 let groups = *groups as usize;
13922 let c_in_per_g = c_in / groups;
13923 let c_out_per_g = c_out / groups;
13924
13925 let m_dim = c_out_per_g;
13926 let n_dim = c_in_per_g * kh * kw;
13927 let k_dim = h_out * w_out;
13928
13929 let x_stride_n = c_in * h * w;
13930 let x_stride_g = c_in_per_g * h * w;
13931 let dy_stride_n = c_out * h_out * w_out;
13932 let dy_stride_g = c_out_per_g * h_out * w_out;
13933 let dw_stride_g = c_out_per_g * c_in_per_g * kh * kw;
13934
13935 unsafe {
13936 let xs = sl(*x, base, n * c_in * h * w);
13937 let dys = sl(*dy, base, n * c_out * h_out * w_out);
13938 let dws = sl_mut(*dw, base, c_out * c_in_per_g * kh * kw);
13939 for v in dws.iter_mut() {
13940 *v = 0.0;
13941 }
13942
13943 let mut col = vec![0f32; n_dim * k_dim];
13944
13945 for ni in 0..n {
13946 for g in 0..groups {
13947 let x_n_g_off = ni * x_stride_n + g * x_stride_g;
13948 im2col(
13949 &xs[x_n_g_off..x_n_g_off + x_stride_g],
13950 &mut col,
13951 c_in_per_g,
13952 h,
13953 w,
13954 h_out,
13955 w_out,
13956 kh,
13957 kw,
13958 sh,
13959 sw,
13960 ph,
13961 pw,
13962 dh,
13963 dw_dil,
13964 );
13965
13966 let dy_n_g_off = ni * dy_stride_n + g * dy_stride_g;
13967 let dw_g_off = g * dw_stride_g;
13968
13969 crate::blas::sgemm_general(
13977 dys.as_ptr().add(dy_n_g_off),
13978 col.as_ptr(),
13979 dws.as_mut_ptr().add(dw_g_off),
13980 m_dim,
13981 n_dim,
13982 k_dim,
13983 1.0,
13984 1.0,
13985 k_dim,
13986 k_dim,
13987 n_dim,
13988 false,
13989 true,
13990 );
13991 }
13992 }
13993 }
13994 }
13995
13996 Thunk::Im2Col {
13997 x,
13998 col,
13999 n,
14000 c_in,
14001 h,
14002 w,
14003 h_out,
14004 w_out,
14005 kh,
14006 kw,
14007 sh,
14008 sw,
14009 ph,
14010 pw,
14011 dh,
14012 dw_dil,
14013 } => {
14014 let c_in = *c_in as usize;
14015 let h = *h as usize;
14016 let w = *w as usize;
14017 let h_out = *h_out as usize;
14018 let w_out = *w_out as usize;
14019 let kh = *kh as usize;
14020 let kw = *kw as usize;
14021 let sh = *sh as usize;
14022 let sw = *sw as usize;
14023 let ph = *ph as usize;
14024 let pw = *pw as usize;
14025 let dh = *dh as usize;
14026 let dw_dil = *dw_dil as usize;
14027 let per_batch = c_in * h * w;
14028 unsafe {
14029 let n_eff = if *n == 0 { 0usize } else { *n as usize };
14030 let x_floats = if n_eff == 0 {
14031 per_batch.max(1)
14032 } else {
14033 n_eff * per_batch
14034 };
14035 let xs = sl(*x, base, x_floats);
14036 let n = if *n == 0 {
14037 xs.len() / per_batch.max(1)
14038 } else {
14039 n_eff
14040 };
14041 let m = n * h_out * w_out;
14042 let k = c_in * kh * kw;
14043 let cols = sl_mut(*col, base, m * k);
14044 crate::im2col::im2col_rows_layout(
14045 xs, cols, n, c_in, h, w, h_out, w_out, kh, kw, sh, sw, ph, pw, dh, dw_dil,
14046 );
14047 }
14048 }
14049
14050 Thunk::SoftmaxCrossEntropy {
14051 logits,
14052 labels,
14053 dst,
14054 n,
14055 c,
14056 } => {
14057 let n = *n as usize;
14058 let c = *c as usize;
14059 unsafe {
14060 let lg = sl(*logits, base, n * c);
14061 let lb = sl(*labels, base, n);
14062 let out = sl_mut(*dst, base, n);
14063 for ni in 0..n {
14064 let row = &lg[ni * c..(ni + 1) * c];
14065 let mut m = f32::NEG_INFINITY;
14067 for &v in row {
14068 if v > m {
14069 m = v;
14070 }
14071 }
14072 let mut sum = 0f32;
14073 for &v in row {
14074 sum += (v - m).exp();
14075 }
14076 let lse = m + sum.ln();
14077 let label_idx = lb[ni] as usize;
14078 out[ni] = lse - row[label_idx];
14080 }
14081 }
14082 }
14083
14084 Thunk::SoftmaxCrossEntropyBackward {
14085 logits,
14086 labels,
14087 d_loss,
14088 dlogits,
14089 n,
14090 c,
14091 } => {
14092 let n = *n as usize;
14093 let c = *c as usize;
14094 unsafe {
14095 let lg = sl(*logits, base, n * c);
14096 let lb = sl(*labels, base, n);
14097 let dl = sl(*d_loss, base, n);
14098 let out = sl_mut(*dlogits, base, n * c);
14099 for ni in 0..n {
14100 let row = &lg[ni * c..(ni + 1) * c];
14101 let label_idx = lb[ni] as usize;
14102 let scale = dl[ni];
14103 let mut m = f32::NEG_INFINITY;
14104 for &v in row {
14105 if v > m {
14106 m = v;
14107 }
14108 }
14109 let mut sum = 0f32;
14110 for &v in row {
14111 sum += (v - m).exp();
14112 }
14113 let inv_sum = 1.0 / sum;
14114 let dst_row = &mut out[ni * c..(ni + 1) * c];
14115 for k in 0..c {
14116 let p = (row[k] - m).exp() * inv_sum;
14117 let one_hot = if k == label_idx { 1.0 } else { 0.0 };
14118 dst_row[k] = (p - one_hot) * scale;
14119 }
14120 }
14121 }
14122 }
14123
14124 Thunk::GatherAxis {
14125 table,
14126 idx,
14127 dst,
14128 outer,
14129 axis_dim,
14130 num_idx,
14131 trailing,
14132 idx_i64,
14133 table_bytes,
14134 } => {
14135 let outer = *outer as usize;
14136 let axis_dim = *axis_dim as usize;
14137 let num_idx = *num_idx as usize;
14138 let trailing = *trailing as usize;
14139 unsafe {
14140 if *table_bytes == 8 {
14141 let tab = sl_i64(*table, base, outer * axis_dim * trailing);
14142 let out = sl_mut_i64(*dst, base, outer * num_idx * trailing);
14143 for o in 0..outer {
14144 let tab_outer = o * axis_dim * trailing;
14145 let out_outer = o * num_idx * trailing;
14146 if *idx_i64 != 0 {
14147 let ids = sl_i64(*idx, base, num_idx);
14148 for k in 0..num_idx {
14149 let row = ids[k].max(0) as usize;
14150 if row < axis_dim {
14151 let tab_row = tab_outer + row * trailing;
14152 let out_row = out_outer + k * trailing;
14153 out[out_row..out_row + trailing]
14154 .copy_from_slice(&tab[tab_row..tab_row + trailing]);
14155 }
14156 }
14157 } else {
14158 let ids = sl(*idx, base, num_idx);
14159 for k in 0..num_idx {
14160 let row = ids[k] as usize;
14161 if row < axis_dim {
14162 let tab_row = tab_outer + row * trailing;
14163 let out_row = out_outer + k * trailing;
14164 out[out_row..out_row + trailing]
14165 .copy_from_slice(&tab[tab_row..tab_row + trailing]);
14166 }
14167 }
14168 }
14169 }
14170 } else {
14171 let tab = sl(*table, base, outer * axis_dim * trailing);
14172 let out = sl_mut(*dst, base, outer * num_idx * trailing);
14173 for o in 0..outer {
14174 let tab_outer = o * axis_dim * trailing;
14175 let out_outer = o * num_idx * trailing;
14176 if *idx_i64 != 0 {
14177 let ids = sl_i64(*idx, base, num_idx);
14178 for k in 0..num_idx {
14179 let row = ids[k].max(0) as usize;
14180 if row < axis_dim {
14181 let tab_row = tab_outer + row * trailing;
14182 let out_row = out_outer + k * trailing;
14183 out[out_row..out_row + trailing]
14184 .copy_from_slice(&tab[tab_row..tab_row + trailing]);
14185 }
14186 }
14187 } else {
14188 let ids = sl(*idx, base, num_idx);
14189 for k in 0..num_idx {
14190 let row = ids[k] as usize;
14191 if row < axis_dim {
14192 let tab_row = tab_outer + row * trailing;
14193 let out_row = out_outer + k * trailing;
14194 out[out_row..out_row + trailing]
14195 .copy_from_slice(&tab[tab_row..tab_row + trailing]);
14196 }
14197 }
14198 }
14199 }
14200 }
14201 }
14202 }
14203
14204 Thunk::Transpose {
14205 src,
14206 dst,
14207 in_total,
14208 out_dims,
14209 in_strides,
14210 elem_bytes,
14211 } => {
14212 let rank = out_dims.len();
14217 let total: usize = out_dims.iter().map(|&d| d as usize).product();
14218 let in_total = *in_total as usize;
14219 unsafe {
14220 if *elem_bytes == 1 {
14221 let inp = arena_buf[*src..*src + in_total].to_vec();
14226 let out = &mut arena_buf[*dst..*dst + total];
14227 let mut idx = vec![0usize; rank];
14228 for o in 0..total {
14229 let mut src_idx = 0usize;
14230 for d in 0..rank {
14231 src_idx += idx[d] * in_strides[d] as usize;
14232 }
14233 out[o] = inp[broadcast_src_index(src_idx, in_total)];
14234 for d in (0..rank).rev() {
14235 idx[d] += 1;
14236 if idx[d] < out_dims[d] as usize {
14237 break;
14238 }
14239 idx[d] = 0;
14240 }
14241 }
14242 } else if *elem_bytes == 8 {
14243 let inp = sl_i64(*src, base, in_total);
14244 let out = sl_mut_i64(*dst, base, total);
14245 let mut idx = vec![0usize; rank];
14246 for o in 0..total {
14247 let mut src_idx = 0usize;
14248 for d in 0..rank {
14249 src_idx += idx[d] * in_strides[d] as usize;
14250 }
14251 out[o] = inp[broadcast_src_index(src_idx, in_total)];
14252 for d in (0..rank).rev() {
14253 idx[d] += 1;
14254 if idx[d] < out_dims[d] as usize {
14255 break;
14256 }
14257 idx[d] = 0;
14258 }
14259 }
14260 } else {
14261 let inp = sl(*src, base, in_total);
14262 let out = sl_mut(*dst, base, total);
14263 let mut idx = vec![0usize; rank];
14264 for o in 0..total {
14265 let mut src_idx = 0usize;
14266 for d in 0..rank {
14267 src_idx += idx[d] * in_strides[d] as usize;
14268 }
14269 out[o] = inp[broadcast_src_index(src_idx, in_total)];
14270 for d in (0..rank).rev() {
14271 idx[d] += 1;
14272 if idx[d] < out_dims[d] as usize {
14273 break;
14274 }
14275 idx[d] = 0;
14276 }
14277 }
14278 }
14279 }
14280 }
14281
14282 Thunk::CustomOp {
14288 kernel,
14289 inputs,
14290 output,
14291 attrs,
14292 } => {
14293 let (out_off, out_len, out_shape) = output;
14294 unsafe {
14295 dispatch_custom_op(
14296 &**kernel, inputs, *out_off, *out_len, out_shape, attrs, base,
14297 );
14298 }
14299 }
14300 }
14301 if trace_done {
14302 eprintln!("[thunk {i} done]");
14303 }
14304 }
14305}
14306
14307#[allow(clippy::too_many_arguments)]
14322unsafe fn griewank_process_segment(
14323 t_lo: usize,
14324 t_hi: usize,
14325 anchor_carry: &[u8],
14326 cb: usize,
14327 fwd_sched: &ThunkSchedule,
14328 fwd_init: &[u8],
14329 fwd_carry_in_off: usize,
14330 fwd_output_off: usize,
14331 fwd_x_offs: &[usize],
14332 base: *mut u8,
14333 outer_xs_offs: &[(usize, u32)],
14334 fwd_buf: &mut Vec<u8>,
14335 leaf_threshold: usize,
14336 process_iter: &mut dyn FnMut(usize, &[u8]),
14337) {
14338 unsafe {
14339 let size = t_hi - t_lo + 1;
14340 if size == 1 {
14341 process_iter(t_lo, anchor_carry);
14342 return;
14343 }
14344 if size <= leaf_threshold {
14345 let mut cache: Vec<u8> = Vec::with_capacity(size * cb);
14347 cache.extend_from_slice(anchor_carry);
14348 fwd_buf.copy_from_slice(fwd_init);
14349 std::ptr::copy_nonoverlapping(
14350 anchor_carry.as_ptr(),
14351 fwd_buf.as_mut_ptr().add(fwd_carry_in_off),
14352 cb,
14353 );
14354 for i in 1..size {
14355 let cur_iter = t_lo + i - 1;
14356 for (idx, fb_x_off) in fwd_x_offs.iter().enumerate() {
14357 let (outer_xs_off, x_psb) = outer_xs_offs[idx];
14358 let xb = x_psb as usize;
14359 std::ptr::copy_nonoverlapping(
14360 base.add(outer_xs_off + cur_iter * xb),
14361 fwd_buf.as_mut_ptr().add(*fb_x_off),
14362 xb,
14363 );
14364 }
14365 execute_thunks(fwd_sched, fwd_buf);
14366 if fwd_output_off != fwd_carry_in_off {
14367 fwd_buf.copy_within(fwd_output_off..fwd_output_off + cb, fwd_carry_in_off);
14368 }
14369 cache.extend_from_slice(&fwd_buf[fwd_carry_in_off..fwd_carry_in_off + cb]);
14370 }
14371 for t in (t_lo..=t_hi).rev() {
14373 let idx = t - t_lo;
14374 let carry = &cache[idx * cb..(idx + 1) * cb];
14375 process_iter(t, carry);
14376 }
14377 return;
14378 }
14379
14380 let mid = t_lo + size / 2;
14384 fwd_buf.copy_from_slice(fwd_init);
14385 std::ptr::copy_nonoverlapping(
14386 anchor_carry.as_ptr(),
14387 fwd_buf.as_mut_ptr().add(fwd_carry_in_off),
14388 cb,
14389 );
14390 for cur_iter in t_lo..mid {
14391 for (idx, fb_x_off) in fwd_x_offs.iter().enumerate() {
14392 let (outer_xs_off, x_psb) = outer_xs_offs[idx];
14393 let xb = x_psb as usize;
14394 std::ptr::copy_nonoverlapping(
14395 base.add(outer_xs_off + cur_iter * xb),
14396 fwd_buf.as_mut_ptr().add(*fb_x_off),
14397 xb,
14398 );
14399 }
14400 execute_thunks(fwd_sched, fwd_buf);
14401 if fwd_output_off != fwd_carry_in_off {
14402 fwd_buf.copy_within(fwd_output_off..fwd_output_off + cb, fwd_carry_in_off);
14403 }
14404 }
14405 let mid_carry: Vec<u8> = fwd_buf[fwd_carry_in_off..fwd_carry_in_off + cb].to_vec();
14406
14407 griewank_process_segment(
14411 mid,
14412 t_hi,
14413 &mid_carry,
14414 cb,
14415 fwd_sched,
14416 fwd_init,
14417 fwd_carry_in_off,
14418 fwd_output_off,
14419 fwd_x_offs,
14420 base,
14421 outer_xs_offs,
14422 fwd_buf,
14423 leaf_threshold,
14424 process_iter,
14425 );
14426 griewank_process_segment(
14428 t_lo,
14429 mid - 1,
14430 anchor_carry,
14431 cb,
14432 fwd_sched,
14433 fwd_init,
14434 fwd_carry_in_off,
14435 fwd_output_off,
14436 fwd_x_offs,
14437 base,
14438 outer_xs_offs,
14439 fwd_buf,
14440 leaf_threshold,
14441 process_iter,
14442 );
14443 }
14444}
14445
14446pub unsafe fn execute_fft1d_f64(
14463 src: usize,
14464 dst: usize,
14465 outer: usize,
14466 n_complex: usize,
14467 inverse: bool,
14468 norm_tag: u32,
14469 base: *mut u8,
14470) {
14471 let row_elems = 2 * n_complex;
14472 let mut re = vec![0f64; n_complex];
14473 let mut im = vec![0f64; n_complex];
14474 let norm = rlx_ir::fft::FftNorm::from_tag(norm_tag);
14475 let scale = norm.output_scale(n_complex, inverse);
14476 let mut scratch = if n_complex.is_power_of_two() || n_complex <= 16 {
14479 BluesteinScratchF64::empty()
14480 } else {
14481 BluesteinScratchF64::build(n_complex, inverse)
14482 };
14483 for o in 0..outer {
14484 let row_offset = src + o * row_elems * std::mem::size_of::<f64>();
14485 let s = unsafe { sl_f64(row_offset, base, row_elems) };
14486 re.copy_from_slice(&s[..n_complex]);
14487 im.copy_from_slice(&s[n_complex..]);
14488 if n_complex.is_power_of_two() {
14489 fft_radix2_inplace_f64(&mut re, &mut im, inverse);
14490 } else if n_complex <= 16 {
14491 fft_naive_inplace_f64(&mut re, &mut im, inverse);
14492 } else {
14493 fft_bluestein_inplace_f64(&mut re, &mut im, inverse, &mut scratch);
14494 }
14495 if scale != 1.0 {
14496 re.iter_mut().for_each(|v| *v *= scale);
14497 im.iter_mut().for_each(|v| *v *= scale);
14498 }
14499 let dst_offset = dst + o * row_elems * std::mem::size_of::<f64>();
14500 let d = unsafe { sl_mut_f64(dst_offset, base, row_elems) };
14501 d[..n_complex].copy_from_slice(&re);
14502 d[n_complex..].copy_from_slice(&im);
14503 }
14504}
14505
14506unsafe fn cgemm_c64(
14516 a_off: usize,
14517 b_off: usize,
14518 c_off: usize,
14519 m: usize,
14520 k: usize,
14521 n: usize,
14522 base: *mut u8,
14523) {
14524 use rayon::prelude::*;
14525 let bptr = base as usize;
14526 unsafe {
14527 let a = std::slice::from_raw_parts((bptr + a_off) as *const f32, 2 * m * k);
14528 let b = std::slice::from_raw_parts((bptr + b_off) as *const f32, 2 * k * n);
14529 let c_base = bptr + c_off;
14530 (0..m).into_par_iter().for_each(|i| {
14531 let crow = std::slice::from_raw_parts_mut((c_base + i * n * 8) as *mut f32, 2 * n);
14532 for j in 0..n {
14533 let mut re = 0f32;
14534 let mut im = 0f32;
14535 for l in 0..k {
14536 let ar = a[2 * (i * k + l)];
14537 let ai = a[2 * (i * k + l) + 1];
14538 let br = b[2 * (l * n + j)];
14539 let bi = b[2 * (l * n + j) + 1];
14540 re += ar * br - ai * bi;
14541 im += ar * bi + ai * br;
14542 }
14543 crow[2 * j] = re;
14544 crow[2 * j + 1] = im;
14545 }
14546 });
14547 }
14548}
14549
/// Runs a stacked (optionally bidirectional) LSTM over `batch` sequences of
/// `seq` steps and writes the last layer's hidden outputs to `dst`.
///
/// Layout, as fixed by the indexing below:
/// - `x`: `batch * seq * input_size` f32s, indexed `(b * seq + t) * in_l`.
/// - Output: `batch * seq * (dirs * hidden)` f32s. Direction `dir` writes
///   columns `dir * hidden .. (dir + 1) * hidden`.
/// - `w_ih`: all (layer, dir) blocks packed back to back. Each block is
///   `4 * hidden` rows of the layer's input width.
/// - `w_hh`: one `4 * hidden x hidden` block per (layer, dir).
/// - `bias`: one `4 * hidden` vector per (layer, dir). Only a single bias is
///   read per gate row.
/// - `h0` / `c0`: one `batch * hidden` block per (layer, dir).
///
/// Gate rows are ordered input, forget, cell-candidate (tanh), output. When
/// `carry` is set, `h0`/`c0` seed each (layer, dir) and are overwritten in
/// place with the final state. Otherwise the state starts at zero and is
/// discarded.
///
/// # Safety
/// All offsets are byte offsets into the arena at `base` and must address
/// buffers of the sizes above.
#[allow(clippy::too_many_arguments)]
pub unsafe fn execute_lstm_f32(
    x: usize,
    w_ih: usize,
    w_hh: usize,
    bias: usize,
    h0: usize,
    c0: usize,
    dst: usize,
    batch: usize,
    seq: usize,
    input_size: usize,
    hidden: usize,
    num_layers: usize,
    bidirectional: bool,
    carry: bool,
    base: *mut u8,
) {
    use rayon::prelude::*;

    #[inline]
    fn sigmoid(z: f32) -> f32 {
        1.0 / (1.0 + (-z).exp())
    }

    let bptr = base as usize;
    let four_h = 4 * hidden;
    let dirs = if bidirectional { 2 } else { 1 };

    unsafe {
        // Views `n` f32s at byte offset `off` in the arena.
        let f32s = |off: usize, n: usize| -> &[f32] {
            std::slice::from_raw_parts((bptr + off) as *const f32, n)
        };

        // Layer 0 consumes a private copy of `x`; each later layer consumes the
        // previous layer's output.
        let mut layer_in: Vec<f32> = f32s(x, batch * seq * input_size).to_vec();
        let mut in_l = input_size;
        // Running element (not byte) offset into the packed `w_ih` blocks.
        // Needed because the block size changes with each layer's input width.
        let mut wih_cursor = 0usize;

        for l in 0..num_layers {
            let out_width = dirs * hidden;
            let mut layer_out = vec![0f32; batch * seq * out_width];
            // Passed to rayon tasks as an integer. Each batch element writes
            // only its own rows, so the writes are disjoint.
            let lo_ptr = layer_out.as_mut_ptr() as usize;
            let li_ref: &[f32] = &layer_in;
            let wih_block = four_h * in_l;

            for dir in 0..dirs {
                // Flat (layer, direction) index selecting w_hh / bias / h0 / c0 blocks.
                let ld = l * dirs + dir;
                // NOTE(review): `w_ih / 4` assumes the byte offset is 4-aligned
                // (it rounds down otherwise).
                let wih = f32s((w_ih / 4 + wih_cursor + dir * wih_block) * 4, wih_block);
                let whh = f32s(w_hh + ld * four_h * hidden * 4, four_h * hidden);
                let bs = f32s(bias + ld * four_h * 4, four_h);
                let h0p = bptr + h0 + ld * batch * hidden * 4;
                let c0p = bptr + c0 + ld * batch * hidden * 4;

                (0..batch).into_par_iter().for_each(|b| {
                    let lo = lo_ptr as *mut f32;
                    let mut h = vec![0f32; hidden];
                    let mut c = vec![0f32; hidden];
                    if carry {
                        let hin = std::slice::from_raw_parts(
                            (h0p + b * hidden * 4) as *const f32,
                            hidden,
                        );
                        let cin = std::slice::from_raw_parts(
                            (c0p + b * hidden * 4) as *const f32,
                            hidden,
                        );
                        h.copy_from_slice(hin);
                        c.copy_from_slice(cin);
                    }
                    let mut z = vec![0f32; four_h];
                    for step in 0..seq {
                        // The reverse direction walks time backwards.
                        let t = if dir == 0 { step } else { seq - 1 - step };
                        let x_t = &li_ref[(b * seq + t) * in_l..(b * seq + t + 1) * in_l];
                        // z = bias + W_ih · x_t + W_hh · h_{t-1}, computed in full
                        // before `h` is updated below.
                        for r in 0..four_h {
                            let wr = &wih[r * in_l..(r + 1) * in_l];
                            let mut acc = bs[r];
                            for j in 0..in_l {
                                acc += wr[j] * x_t[j];
                            }
                            let hr = &whh[r * hidden..(r + 1) * hidden];
                            for (j, &hj) in h.iter().enumerate() {
                                acc += hr[j] * hj;
                            }
                            z[r] = acc;
                        }
                        // Gate slices of z: [i | f | g | o], each `hidden` wide.
                        for k in 0..hidden {
                            let i_g = sigmoid(z[k]);
                            let f_g = sigmoid(z[hidden + k]);
                            let g_g = z[2 * hidden + k].tanh();
                            let o_g = sigmoid(z[3 * hidden + k]);
                            let c_new = f_g * c[k] + i_g * g_g;
                            c[k] = c_new;
                            let h_new = o_g * c_new.tanh();
                            h[k] = h_new;
                            *lo.add((b * seq + t) * out_width + dir * hidden + k) = h_new;
                        }
                    }
                    // Write the final state back into the h0/c0 slots in place.
                    if carry {
                        let hout = std::slice::from_raw_parts_mut(
                            (h0p + b * hidden * 4) as *mut f32,
                            hidden,
                        );
                        let cout = std::slice::from_raw_parts_mut(
                            (c0p + b * hidden * 4) as *mut f32,
                            hidden,
                        );
                        hout.copy_from_slice(&h);
                        cout.copy_from_slice(&c);
                    }
                });
            }

            wih_cursor += dirs * wih_block;
            layer_in = layer_out;
            in_l = out_width;
        }

        // With `num_layers == 0` this copies `x` through unchanged.
        let dst_slice = std::slice::from_raw_parts_mut((bptr + dst) as *mut f32, layer_in.len());
        dst_slice.copy_from_slice(&layer_in);
    }
}
14684
14685pub unsafe fn execute_gated_delta_net_f32(
14688 q: usize,
14689 k: usize,
14690 v: usize,
14691 g: usize,
14692 beta: usize,
14693 state: usize,
14694 dst: usize,
14695 batch: usize,
14696 seq: usize,
14697 heads: usize,
14698 state_size: usize,
14699 base: *mut u8,
14700) {
14701 use rayon::prelude::*;
14702
14703 #[derive(Copy, Clone)]
14704 struct ArenaPtr(usize);
14705 unsafe impl Send for ArenaPtr {}
14706 unsafe impl Sync for ArenaPtr {}
14707 impl ArenaPtr {
14708 #[inline]
14709 fn get(self) -> *mut u8 {
14710 self.0 as *mut u8
14711 }
14712 }
14713
14714 unsafe {
14715 let arena = ArenaPtr(base as usize);
14716 let (b, s, h, n) = (batch, seq, heads, state_size);
14717 let scale = 1.0f32 / (n as f32).sqrt();
14718 let use_external = state != 0;
14719 let mut owned_state = vec![0f32; h * n * n];
14720
14721 crate::pool::num_threads();
14722
14723 assert!(
14724 n <= crate::gdn::GDN_MAX_STATE,
14725 "GatedDeltaNet state_size={n} exceeds stack scratch ({})",
14726 crate::gdn::GDN_MAX_STATE
14727 );
14728
14729 let qs = sl(q, arena.get(), b * s * h * n);
14730 let ks = sl(k, arena.get(), b * s * h * n);
14731 let vs = sl(v, arena.get(), b * s * h * n);
14732 let gs = sl(g, arena.get(), b * s * h);
14733 let betas = sl(beta, arena.get(), b * s * h);
14734 let _out = sl_mut(dst, arena.get(), b * s * h * n);
14735 let hs_n = h * n;
14736
14737 let run_head = |bi: usize, hi: usize, s_mat: &mut [f32], sk: &mut [f32]| {
14738 for ti in 0..s {
14739 let qkv_step = bi * s * hs_n + ti * hs_n + hi * n;
14740 let gb_step = bi * s * h + ti * h + hi;
14741 let out_row = sl_mut(dst + qkv_step * std::mem::size_of::<f32>(), arena.get(), n);
14742 crate::gdn::gdn_step_blas(
14743 s_mat,
14744 &qs[qkv_step..qkv_step + n],
14745 &ks[qkv_step..qkv_step + n],
14746 &vs[qkv_step..qkv_step + n],
14747 gs[gb_step],
14748 betas[gb_step],
14749 out_row,
14750 sk,
14751 n,
14752 scale,
14753 );
14754 }
14755 };
14756
14757 if !use_external && s > 1 {
14760 for bi in 0..b {
14761 (0..h).into_par_iter().for_each(|hi| {
14762 let mut sk_buf = [0f32; crate::gdn::GDN_MAX_STATE];
14763 let sk = &mut sk_buf[..n];
14764 let mut local_state =
14765 [0f32; crate::gdn::GDN_MAX_STATE * crate::gdn::GDN_MAX_STATE];
14766 let s_mat = &mut local_state[..n * n];
14767 s_mat.fill(0.0);
14768 run_head(bi, hi, s_mat, sk);
14769 });
14770 }
14771 return;
14772 }
14773
14774 if use_external {
14775 let state_bytes = state;
14776 (0..b * h).into_par_iter().for_each(|bhi| {
14777 let bi = bhi / h;
14778 let hi = bhi % h;
14779 let elem_off = bi * h * n * n + hi * n * n;
14780 let s_mat = sl_mut(
14781 state_bytes + elem_off * std::mem::size_of::<f32>(),
14782 arena.get(),
14783 n * n,
14784 );
14785 let mut sk_buf = [0f32; crate::gdn::GDN_MAX_STATE];
14786 run_head(bi, hi, s_mat, &mut sk_buf[..n]);
14787 });
14788 } else {
14789 for bi in 0..b {
14790 owned_state.fill(0.0);
14791 owned_state
14792 .par_chunks_mut(n * n)
14793 .enumerate()
14794 .for_each(|(hi, s_mat)| {
14795 let mut sk_buf = [0f32; crate::gdn::GDN_MAX_STATE];
14796 run_head(bi, hi, s_mat, &mut sk_buf[..n]);
14797 });
14798 }
14799 }
14800 }
14801}
14802
14803pub unsafe fn execute_rms_norm_backward_input_f32(
14805 x: usize,
14806 gamma: usize,
14807 beta: usize,
14808 dy: usize,
14809 dx: usize,
14810 rows: u32,
14811 h: u32,
14812 eps: f32,
14813 base: *mut u8,
14814) {
14815 let (rows, h) = (rows as usize, h as usize);
14816 let mut dg = vec![0f32; h];
14817 let mut db = vec![0f32; h];
14818 let xs = sl(x, base, rows * h);
14819 let dys = sl(dy, base, rows * h);
14820 let g = sl(gamma, base, h);
14821 let b = sl(beta, base, h);
14822 let out = sl_mut(dx, base, rows * h);
14823 for r in 0..rows {
14824 crate::training_bwd::rms_norm_backward_row(
14825 &xs[r * h..(r + 1) * h],
14826 g,
14827 b,
14828 &dys[r * h..(r + 1) * h],
14829 &mut out[r * h..(r + 1) * h],
14830 &mut dg,
14831 &mut db,
14832 eps,
14833 );
14834 }
14835}
14836
14837pub unsafe fn execute_rms_norm_backward_gamma_f32(
14838 x: usize,
14839 gamma: usize,
14840 beta: usize,
14841 dy: usize,
14842 dgamma: usize,
14843 rows: u32,
14844 h: u32,
14845 eps: f32,
14846 base: *mut u8,
14847) {
14848 let (rows, h) = (rows as usize, h as usize);
14849 let out = sl_mut(dgamma, base, h);
14850 out.fill(0.0);
14851 let mut dx = vec![0f32; h];
14852 let mut db = vec![0f32; h];
14853 let xs = sl(x, base, rows * h);
14854 let dys = sl(dy, base, rows * h);
14855 let g = sl(gamma, base, h);
14856 let b = sl(beta, base, h);
14857 for r in 0..rows {
14858 crate::training_bwd::rms_norm_backward_row(
14859 &xs[r * h..(r + 1) * h],
14860 g,
14861 b,
14862 &dys[r * h..(r + 1) * h],
14863 &mut dx,
14864 out,
14865 &mut db,
14866 eps,
14867 );
14868 }
14869}
14870
14871pub unsafe fn execute_rms_norm_backward_beta_f32(
14872 x: usize,
14873 gamma: usize,
14874 beta: usize,
14875 dy: usize,
14876 dbeta: usize,
14877 rows: u32,
14878 h: u32,
14879 eps: f32,
14880 base: *mut u8,
14881) {
14882 let (rows, h) = (rows as usize, h as usize);
14883 let out = sl_mut(dbeta, base, h);
14884 out.fill(0.0);
14885 let mut dx = vec![0f32; h];
14886 let mut dg = vec![0f32; h];
14887 let xs = sl(x, base, rows * h);
14888 let dys = sl(dy, base, rows * h);
14889 let g = sl(gamma, base, h);
14890 let b = sl(beta, base, h);
14891 for r in 0..rows {
14892 crate::training_bwd::rms_norm_backward_row(
14893 &xs[r * h..(r + 1) * h],
14894 g,
14895 b,
14896 &dys[r * h..(r + 1) * h],
14897 &mut dx,
14898 &mut dg,
14899 out,
14900 eps,
14901 );
14902 }
14903}
14904
14905#[allow(clippy::too_many_arguments)]
14906pub unsafe fn execute_conv2d_forward_f32(
14907 src: usize,
14908 weight: usize,
14909 dst: usize,
14910 n: u32,
14911 c_in: u32,
14912 h: u32,
14913 w: u32,
14914 c_out: u32,
14915 h_out: u32,
14916 w_out: u32,
14917 kh: u32,
14918 kw: u32,
14919 sh: u32,
14920 sw: u32,
14921 ph: u32,
14922 pw: u32,
14923 dh: u32,
14924 dw: u32,
14925 groups: u32,
14926 base: *mut u8,
14927) {
14928 let n = n as usize;
14929 let c_in = c_in as usize;
14930 let h = h as usize;
14931 let w = w as usize;
14932 let c_out = c_out as usize;
14933 let h_out = h_out as usize;
14934 let w_out = w_out as usize;
14935 let kh = kh as usize;
14936 let kw = kw as usize;
14937 let sh = sh as usize;
14938 let sw = sw as usize;
14939 let ph = ph as usize;
14940 let pw = pw as usize;
14941 let dh = dh as usize;
14942 let dw = dw as usize;
14943 let groups = groups as usize;
14944 let c_in_per_g = c_in / groups;
14945 let inp = sl(src, base, n * c_in * h * w);
14946 let wt = sl(weight, base, c_out * c_in_per_g * kh * kw);
14947 let out = sl_mut(dst, base, n * c_out * h_out * w_out);
14948 crate::conv_fwd::conv2d_forward_nchw_f32(
14949 inp, wt, out, n, c_in, h, w, c_out, h_out, w_out, kh, kw, sh, sw, ph, pw, dh, dw, groups,
14950 );
14951}
14952
14953pub unsafe fn execute_maxpool2d_backward_f32(
14954 x: usize,
14955 dy: usize,
14956 dx: usize,
14957 n: u32,
14958 c: u32,
14959 h: u32,
14960 w: u32,
14961 h_out: u32,
14962 w_out: u32,
14963 kh: u32,
14964 kw: u32,
14965 sh: u32,
14966 sw: u32,
14967 ph: u32,
14968 pw: u32,
14969 base: *mut u8,
14970) {
14971 let (n, c, h, w) = (n as usize, c as usize, h as usize, w as usize);
14972 let (h_out, w_out) = (h_out as usize, w_out as usize);
14973 let (kh, kw) = (kh as usize, kw as usize);
14974 let (sh, sw) = (sh as usize, sw as usize);
14975 let (ph, pw) = (ph as usize, pw as usize);
14976 let xs = sl(x, base, n * c * h * w);
14977 let dys = sl(dy, base, n * c * h_out * w_out);
14978 let dxs = sl_mut(dx, base, n * c * h * w);
14979 crate::training_bwd::maxpool2d_backward_nchw(
14980 xs, dys, dxs, n, c, h, w, h_out, w_out, kh, kw, sh, sw, ph, pw,
14981 );
14982}
14983
14984pub unsafe fn execute_rope_backward_f32(
14985 dy: usize,
14986 cos: usize,
14987 sin: usize,
14988 dx: usize,
14989 batch: u32,
14990 seq: u32,
14991 hidden: u32,
14992 head_dim: u32,
14993 n_rot: u32,
14994 cos_len: u32,
14995 base: *mut u8,
14996) {
14997 let (b, s, hs, dh, nr, cl) = (
14998 batch as usize,
14999 seq as usize,
15000 hidden as usize,
15001 head_dim as usize,
15002 n_rot as usize,
15003 cos_len as usize,
15004 );
15005 let nh = hs / dh;
15006 let tab_half = dh / 2;
15007 let dys = sl(dy, base, b * s * hs);
15008 let cos_tab = sl(cos, base, cl);
15009 let sin_tab = sl(sin, base, cl);
15010 let out = sl_mut(dx, base, b * s * hs);
15011 for bi in 0..b {
15012 for si in 0..s {
15013 let tab_off = si.saturating_mul(tab_half) % cl.max(1);
15014 let cp = &cos_tab[tab_off..tab_off + tab_half.min(cl)];
15015 let sp = &sin_tab[tab_off..tab_off + tab_half.min(cl)];
15016 for hi in 0..nh {
15017 let base_idx = bi * s * hs + si * hs + hi * dh;
15018 crate::training_bwd::rope_backward_row(
15019 &dys[base_idx..base_idx + dh],
15020 cp,
15021 sp,
15022 &mut out[base_idx..base_idx + dh],
15023 dh,
15024 nr,
15025 );
15026 }
15027 }
15028 }
15029}
15030
15031pub unsafe fn execute_cumsum_backward_f32(
15032 dy: usize,
15033 dx: usize,
15034 rows: u32,
15035 cols: u32,
15036 exclusive: bool,
15037 base: *mut u8,
15038) {
15039 let (rows, cols) = (rows as usize, cols as usize);
15040 let dys = sl(dy, base, rows * cols);
15041 let out = sl_mut(dx, base, rows * cols);
15042 for r in 0..rows {
15043 crate::training_bwd::cumsum_backward_row(
15044 &dys[r * cols..(r + 1) * cols],
15045 &mut out[r * cols..(r + 1) * cols],
15046 exclusive,
15047 );
15048 }
15049}
15050
15051pub unsafe fn execute_gather_backward_f32(
15052 dy: usize,
15053 indices: usize,
15054 dst: usize,
15055 outer: u32,
15056 axis_dim: u32,
15057 num_idx: u32,
15058 trailing: u32,
15059 base: *mut u8,
15060) {
15061 let (outer, axis_dim, num_idx, trailing) = (
15062 outer as usize,
15063 axis_dim as usize,
15064 num_idx as usize,
15065 trailing as usize,
15066 );
15067 let out = sl_mut(dst, base, outer * axis_dim * trailing);
15068 out.fill(0.0);
15069 crate::training_bwd::gather_axis_backward(
15070 sl(dy, base, outer * num_idx * trailing),
15071 sl(indices, base, num_idx),
15072 out,
15073 outer,
15074 axis_dim,
15075 num_idx,
15076 trailing,
15077 );
15078}
15079
15080pub unsafe fn execute_dequant_matmul_gguf_f32(
15082 x: usize,
15083 w_q: usize,
15084 dst: usize,
15085 m: usize,
15086 k: usize,
15087 n: usize,
15088 scheme: rlx_ir::quant::QuantScheme,
15089 base: *mut u8,
15090) {
15091 unsafe {
15092 let block_bytes = scheme.gguf_block_bytes() as usize;
15093 let block_elems = scheme.gguf_block_size() as usize;
15094 let total_bytes = (k * n) / block_elems * block_bytes;
15095 let xs = sl(x, base, m * k);
15096 let w_bytes = std::slice::from_raw_parts(base.add(w_q) as *const u8, total_bytes);
15097 let out = sl_mut(dst, base, m * n);
15098 crate::gguf_matmul::gguf_matmul_bt(xs, w_bytes, out, m, k, n, scheme);
15099 }
15100}
15101
15102pub unsafe fn execute_dequant_grouped_matmul_gguf_f32(
15104 input: usize,
15105 w_q: usize,
15106 expert_idx: usize,
15107 dst: usize,
15108 m: usize,
15109 k: usize,
15110 n: usize,
15111 num_experts: usize,
15112 scheme: rlx_ir::quant::QuantScheme,
15113 base: *mut u8,
15114) {
15115 unsafe {
15116 let block_bytes = scheme.gguf_block_bytes() as usize;
15117 let block_elems = scheme.gguf_block_size() as usize;
15118 let slab_bytes = (k * n) / block_elems * block_bytes;
15119 let xs = sl(input, base, m * k);
15120 let w_bytes =
15121 std::slice::from_raw_parts(base.add(w_q) as *const u8, num_experts * slab_bytes);
15122 let ids = sl(expert_idx, base, m);
15123 let out = sl_mut(dst, base, m * n);
15124 crate::gguf_matmul::gguf_grouped_matmul_bt(
15125 xs,
15126 w_bytes,
15127 ids,
15128 out,
15129 m,
15130 k,
15131 n,
15132 num_experts,
15133 scheme,
15134 );
15135 }
15136}
15137
15138pub unsafe fn execute_dequant_matmul_int8_f32(
15140 x: usize,
15141 w_q: usize,
15142 scale: usize,
15143 zp: usize,
15144 dst: usize,
15145 m: usize,
15146 k: usize,
15147 n: usize,
15148 block_size: u32,
15149 is_asymmetric: bool,
15150 base: *mut u8,
15151) {
15152 let bs = block_size as usize;
15153 let n_blocks = k.div_ceil(bs);
15154 unsafe {
15155 let xs = sl(x, base, m * k);
15156 let w_bytes = std::slice::from_raw_parts(base.add(w_q) as *const i8, k * n);
15157 let scales = sl(scale, base, n_blocks * n);
15158 let zps = if is_asymmetric {
15159 sl(zp, base, n_blocks * n)
15160 } else {
15161 &[][..]
15162 };
15163 let out = sl_mut(dst, base, m * n);
15164 dequant_matmul_int8(xs, w_bytes, scales, zps, out, m, k, n, bs, is_asymmetric);
15165 }
15166}
15167
15168pub unsafe fn execute_dequant_matmul_int4_f32(
15170 x: usize,
15171 w_q: usize,
15172 scale: usize,
15173 zp: usize,
15174 dst: usize,
15175 m: usize,
15176 k: usize,
15177 n: usize,
15178 block_size: u32,
15179 is_asymmetric: bool,
15180 base: *mut u8,
15181) {
15182 let bs = block_size as usize;
15183 let n_blocks = k.div_ceil(bs);
15184 unsafe {
15185 let xs = sl(x, base, m * k);
15186 let w_bytes = std::slice::from_raw_parts(base.add(w_q) as *const u8, (k * n).div_ceil(2));
15187 let scales = sl(scale, base, n_blocks * n);
15188 let zps = if is_asymmetric {
15189 sl(zp, base, n_blocks * n)
15190 } else {
15191 &[][..]
15192 };
15193 let out = sl_mut(dst, base, m * n);
15194 dequant_matmul_int4(xs, w_bytes, scales, zps, out, m, k, n, bs, is_asymmetric);
15195 }
15196}
15197
15198pub unsafe fn execute_dequant_matmul_fp8_f32(
15200 x: usize,
15201 w_q: usize,
15202 scale: usize,
15203 dst: usize,
15204 m: usize,
15205 k: usize,
15206 n: usize,
15207 e5m2: bool,
15208 base: *mut u8,
15209) {
15210 unsafe {
15211 let xs = sl(x, base, m * k);
15212 let w_bytes = std::slice::from_raw_parts(base.add(w_q) as *const u8, k * n);
15213 let scales = sl(scale, base, n);
15214 let out = sl_mut(dst, base, m * n);
15215 dequant_matmul_fp8(xs, w_bytes, scales, out, m, k, n, e5m2);
15216 }
15217}
15218
15219pub unsafe fn execute_dequant_matmul_nvfp4_f32(
15221 x: usize,
15222 w_q: usize,
15223 scale: usize,
15224 global_scale: usize,
15225 dst: usize,
15226 m: usize,
15227 k: usize,
15228 n: usize,
15229 base: *mut u8,
15230) {
15231 let n_scale = k.div_ceil(rlx_ir::NVFP4_GROUP_SIZE) * n;
15232 unsafe {
15233 let xs = sl(x, base, m * k);
15234 let w_bytes = std::slice::from_raw_parts(base.add(w_q) as *const u8, (k * n).div_ceil(2));
15235 let scale_bytes = std::slice::from_raw_parts(base.add(scale) as *const u8, n_scale);
15236 let gs = sl(global_scale, base, 1)[0];
15237 let out = sl_mut(dst, base, m * n);
15238 dequant_matmul_nvfp4(xs, w_bytes, scale_bytes, gs, out, m, k, n);
15239 }
15240}
15241
15242pub unsafe fn execute_gated_delta_net_f16(
15244 q: usize,
15245 k: usize,
15246 v: usize,
15247 g: usize,
15248 beta: usize,
15249 state: usize,
15250 dst: usize,
15251 batch: usize,
15252 seq: usize,
15253 heads: usize,
15254 state_size: usize,
15255 base: *mut u8,
15256) {
15257 use half::f16;
15258 unsafe {
15259 let read_f16 = |off: usize, len: usize| -> Vec<f32> {
15260 let raw = std::slice::from_raw_parts(base.add(off) as *const u8, len * 2);
15261 raw.chunks_exact(2)
15262 .map(|c| f16::from_le_bytes([c[0], c[1]]).to_f32())
15263 .collect()
15264 };
15265 let write_f16 = |off: usize, data: &[f32]| {
15266 let out = std::slice::from_raw_parts_mut(base.add(off), data.len() * 2);
15267 for (i, &v) in data.iter().enumerate() {
15268 let le = f16::from_f32(v).to_le_bytes();
15269 out[i * 2] = le[0];
15270 out[i * 2 + 1] = le[1];
15271 }
15272 };
15273
15274 let (b, s, h, n) = (batch, seq, heads, state_size);
15275 let q_f = read_f16(q, b * s * h * n);
15276 let k_f = read_f16(k, b * s * h * n);
15277 let v_f = read_f16(v, b * s * h * n);
15278 let g_f = read_f16(g, b * s * h);
15279 let b_f = read_f16(beta, b * s * h);
15280 let mut state_f = if state != 0 {
15281 read_f16(state, b * h * n * n)
15282 } else {
15283 vec![0f32; b * h * n * n]
15284 };
15285 let mut out_f = vec![0f32; b * s * h * n];
15286 let scale = 1.0f32 / (n as f32).sqrt();
15287 let mut sk_buf = vec![0f32; n];
15288 let mut owned_state = vec![0f32; h * n * n];
15289
15290 for bi in 0..b {
15291 let state_slice: &mut [f32] = if state != 0 {
15292 let start = bi * h * n * n;
15293 &mut state_f[start..start + h * n * n]
15294 } else {
15295 owned_state.fill(0.0);
15296 &mut owned_state
15297 };
15298
15299 for ti in 0..s {
15300 let qkv_step_base = bi * s * h * n + ti * h * n;
15301 let gb_step_base = bi * s * h + ti * h;
15302
15303 for hi in 0..h {
15304 let q_row = &q_f[qkv_step_base + hi * n..qkv_step_base + (hi + 1) * n];
15305 let k_row = &k_f[qkv_step_base + hi * n..qkv_step_base + (hi + 1) * n];
15306 let v_row = &v_f[qkv_step_base + hi * n..qkv_step_base + (hi + 1) * n];
15307 let g_t = g_f[gb_step_base + hi];
15308 let beta_t = b_f[gb_step_base + hi];
15309
15310 let s_base = hi * n * n;
15311 let s_mat = &mut state_slice[s_base..s_base + n * n];
15312
15313 let g_exp = g_t.exp();
15314 for st in s_mat.iter_mut() {
15315 *st *= g_exp;
15316 }
15317
15318 for j in 0..n {
15319 let mut acc = 0f32;
15320 for i in 0..n {
15321 acc += s_mat[i * n + j] * k_row[i];
15322 }
15323 sk_buf[j] = acc;
15324 }
15325
15326 for j in 0..n {
15327 sk_buf[j] = (v_row[j] - sk_buf[j]) * beta_t;
15328 }
15329
15330 for i in 0..n {
15331 let ki = k_row[i];
15332 for j in 0..n {
15333 s_mat[i * n + j] += ki * sk_buf[j];
15334 }
15335 }
15336
15337 let out_row = &mut out_f[qkv_step_base + hi * n..qkv_step_base + (hi + 1) * n];
15338 for j in 0..n {
15339 let mut acc = 0f32;
15340 for i in 0..n {
15341 acc += s_mat[i * n + j] * q_row[i];
15342 }
15343 out_row[j] = acc * scale;
15344 }
15345 }
15346 }
15347 }
15348
15349 write_f16(dst, &out_f);
15350 if state != 0 {
15351 write_f16(state, &state_f);
15352 }
15353 }
15354}
15355
15356pub unsafe fn execute_group_norm_nchw_f32(
15358 src: usize,
15359 g: usize,
15360 b: usize,
15361 dst: usize,
15362 n: usize,
15363 c: usize,
15364 h: usize,
15365 w: usize,
15366 num_groups: usize,
15367 eps: f32,
15368 base: *mut u8,
15369) {
15370 let plane = c * h * w;
15371 for ni in 0..n {
15372 let input = unsafe { sl(src + ni * plane * std::mem::size_of::<f32>(), base, plane) };
15373 let gamma = unsafe { sl(g, base, c) };
15374 let beta = unsafe { sl(b, base, c) };
15375 let output = unsafe { sl_mut(dst + ni * plane * std::mem::size_of::<f32>(), base, plane) };
15376 crate::kernels::group_norm_nchw(input, gamma, beta, output, 1, c, h, w, num_groups, eps);
15377 }
15378}
15379
15380pub unsafe fn execute_layer_norm2d_nchw_f32(
15382 src: usize,
15383 g: usize,
15384 b: usize,
15385 dst: usize,
15386 n: usize,
15387 c: usize,
15388 h: usize,
15389 w: usize,
15390 eps: f32,
15391 base: *mut u8,
15392) {
15393 let plane = c * h * w;
15394 unsafe {
15395 let input = sl(src, base, n * plane);
15396 let gamma = sl(g, base, c);
15397 let beta = sl(b, base, c);
15398 let output = sl_mut(dst, base, n * plane);
15399 crate::kernels::layer_norm2d_nchw(input, gamma, beta, output, n, c, h, w, eps);
15400 }
15401}
15402
15403pub unsafe fn execute_conv_transpose2d_nchw_f32(
15405 src: usize,
15406 weight: usize,
15407 dst: usize,
15408 n: usize,
15409 c_in: usize,
15410 h: usize,
15411 w_in: usize,
15412 c_out: usize,
15413 h_out: usize,
15414 w_out: usize,
15415 kh: usize,
15416 kw: usize,
15417 sh: usize,
15418 sw: usize,
15419 ph: usize,
15420 pw: usize,
15421 dh: usize,
15422 dw: usize,
15423 groups: usize,
15424 base: *mut u8,
15425) {
15426 let in_elems = n * c_in * h * w_in;
15427 let w_elems = c_in * (c_out / groups) * kh * kw;
15428 let out_elems = n * c_out * h_out * w_out;
15429 unsafe {
15430 let input = sl(src, base, in_elems);
15431 let wt = sl(weight, base, w_elems);
15432 let output = sl_mut(dst, base, out_elems);
15433 crate::kernels::conv_transpose2d_nchw(
15434 input, wt, output, n, c_in, h, w_in, c_out, h_out, w_out, kh, kw, sh, sw, ph, pw, dh,
15435 dw, groups,
15436 );
15437 }
15438}
15439
15440pub unsafe fn execute_resize_nearest_2x_f32(
15442 src: usize,
15443 dst: usize,
15444 n: usize,
15445 c: usize,
15446 h: usize,
15447 w: usize,
15448 base: *mut u8,
15449) {
15450 let in_plane = c * h * w;
15451 let out_plane = c * h * 2 * w * 2;
15452 for ni in 0..n {
15453 let input = unsafe {
15454 sl(
15455 src + ni * in_plane * std::mem::size_of::<f32>(),
15456 base,
15457 in_plane,
15458 )
15459 };
15460 let output = unsafe {
15461 sl_mut(
15462 dst + ni * out_plane * std::mem::size_of::<f32>(),
15463 base,
15464 out_plane,
15465 )
15466 };
15467 crate::kernels::resize_nearest_2x_nchw(input, output, c, h, w);
15468 }
15469}
15470
15471pub unsafe fn execute_axial_rope2d_f32(
15473 src: usize,
15474 dst: usize,
15475 batch: usize,
15476 seq: usize,
15477 hidden: usize,
15478 end_x: usize,
15479 end_y: usize,
15480 head_dim: usize,
15481 num_heads: usize,
15482 theta: f32,
15483 repeat_factor: usize,
15484 base: *mut u8,
15485) {
15486 let plane = seq * hidden;
15487 let plane_bytes = plane * std::mem::size_of::<f32>();
15488 for bi in 0..batch {
15489 let in_off = src + bi * plane_bytes;
15490 let input = unsafe { sl(in_off, base, plane) };
15491 let rotated = rlx_ir::ops::axial_rope2d::apply_axial_rope2d(
15492 input,
15493 num_heads,
15494 seq,
15495 head_dim,
15496 end_x,
15497 end_y,
15498 theta,
15499 repeat_factor,
15500 );
15501 let out_off = dst + bi * plane_bytes;
15502 let output = unsafe { sl_mut(out_off, base, plane) };
15503 output.copy_from_slice(&rotated);
15504 }
15505}
15506
15507pub unsafe fn execute_fft_butterfly_stage_f32(
15509 state_src: usize,
15510 state_dst: usize,
15511 gate_src: usize,
15512 rev_src: usize,
15513 tw_re_src: usize,
15514 tw_im_src: usize,
15515 batch: usize,
15516 n_fft: usize,
15517 stage: usize,
15518 base: *mut u8,
15519) {
15520 let half = n_fft / 2;
15521 let stride = 1usize << stage;
15522 let gate = unsafe { sl(gate_src, base, half) };
15523 let rev = unsafe { sl(rev_src, base, half) };
15524 let tw_re = unsafe { sl(tw_re_src, base, half) };
15525 let tw_im = unsafe { sl(tw_im_src, base, half) };
15526 let row_elems = n_fft * 2;
15527 for b in 0..batch {
15528 let in_off = state_src + b * row_elems * std::mem::size_of::<f32>();
15529 let out_off = state_dst + b * row_elems * std::mem::size_of::<f32>();
15530 let inp = unsafe { sl(in_off, base, row_elems) };
15531 let out = unsafe { sl_mut(out_off, base, row_elems) };
15532 out.copy_from_slice(inp);
15533 for bf in 0..half {
15534 if gate[bf] == 0.0 {
15535 continue;
15536 }
15537 let group = bf / stride;
15538 let k = bf % stride;
15539 let i0 = group * 2 * stride + k;
15540 let i1 = i0 + stride;
15541 let w_re = tw_re[bf];
15542 let w_im = tw_im[bf];
15543 let in_a_re = inp[i0 * 2];
15544 let in_a_im = inp[i0 * 2 + 1];
15545 let in_b_re = inp[i1 * 2];
15546 let in_b_im = inp[i1 * 2 + 1];
15547 let (b_re, b_im) = (
15548 in_b_re * w_re - in_b_im * w_im,
15549 in_b_re * w_im + in_b_im * w_re,
15550 );
15551 let (top_re, top_im) = (in_a_re + b_re, in_a_im + b_im);
15552 let (bot_re, bot_im) = (in_a_re - b_re, in_a_im - b_im);
15553 let (oa_re, oa_im, ob_re, ob_im) = if rev[bf] >= 0.5 {
15554 (bot_re, bot_im, top_re, top_im)
15555 } else {
15556 (top_re, top_im, bot_re, bot_im)
15557 };
15558 out[i0 * 2] = oa_re;
15559 out[i0 * 2 + 1] = oa_im;
15560 out[i1 * 2] = ob_re;
15561 out[i1 * 2 + 1] = ob_im;
15562 }
15563 }
15564}
15565
15566pub unsafe fn execute_fft1d_f32(
15568 src: usize,
15569 dst: usize,
15570 outer: usize,
15571 n_complex: usize,
15572 inverse: bool,
15573 norm_tag: u32,
15574 base: *mut u8,
15575) {
15576 let row_elems = 2 * n_complex;
15577 let mut re = vec![0f32; n_complex];
15578 let mut im = vec![0f32; n_complex];
15579 let norm = rlx_ir::fft::FftNorm::from_tag(norm_tag);
15580 let scale = norm.output_scale(n_complex, inverse) as f32;
15581 let mut scratch = if n_complex.is_power_of_two() || n_complex <= 16 {
15582 BluesteinScratchF32::empty()
15583 } else {
15584 BluesteinScratchF32::build(n_complex, inverse)
15585 };
15586 for o in 0..outer {
15587 let row_offset = src + o * row_elems * std::mem::size_of::<f32>();
15588 let s = unsafe { sl(row_offset, base, row_elems) };
15589 re.copy_from_slice(&s[..n_complex]);
15590 im.copy_from_slice(&s[n_complex..]);
15591 if n_complex.is_power_of_two() {
15592 fft_radix2_inplace_f32(&mut re, &mut im, inverse);
15593 } else if n_complex <= 16 {
15594 fft_naive_inplace_f32(&mut re, &mut im, inverse);
15595 } else {
15596 fft_bluestein_inplace_f32(&mut re, &mut im, inverse, &mut scratch);
15597 }
15598 if scale != 1.0 {
15599 re.iter_mut().for_each(|v| *v *= scale);
15600 im.iter_mut().for_each(|v| *v *= scale);
15601 }
15602 let dst_offset = dst + o * row_elems * std::mem::size_of::<f32>();
15603 let d = unsafe { sl_mut(dst_offset, base, row_elems) };
15604 d[..n_complex].copy_from_slice(&re);
15605 d[n_complex..].copy_from_slice(&im);
15606 }
15607}
15608
15609pub unsafe fn execute_fft1d_c64(
15611 src: usize,
15612 dst: usize,
15613 outer: usize,
15614 n_complex: usize,
15615 inverse: bool,
15616 norm_tag: u32,
15617 base: *mut u8,
15618) {
15619 let row_bytes = n_complex * 8;
15620 let mut re = vec![0f32; n_complex];
15621 let mut im = vec![0f32; n_complex];
15622 let norm = rlx_ir::fft::FftNorm::from_tag(norm_tag);
15623 let scale = norm.output_scale(n_complex, inverse) as f32;
15624 let mut scratch = if n_complex.is_power_of_two() || n_complex <= 16 {
15625 BluesteinScratchF32::empty()
15626 } else {
15627 BluesteinScratchF32::build(n_complex, inverse)
15628 };
15629 for o in 0..outer {
15630 let row_offset = src + o * row_bytes;
15631 for i in 0..n_complex {
15632 let elem_off = row_offset + i * 8;
15633 re[i] = f32::from_le_bytes([
15634 *base.add(elem_off),
15635 *base.add(elem_off + 1),
15636 *base.add(elem_off + 2),
15637 *base.add(elem_off + 3),
15638 ]);
15639 im[i] = f32::from_le_bytes([
15640 *base.add(elem_off + 4),
15641 *base.add(elem_off + 5),
15642 *base.add(elem_off + 6),
15643 *base.add(elem_off + 7),
15644 ]);
15645 }
15646 if n_complex.is_power_of_two() {
15647 fft_radix2_inplace_f32(&mut re, &mut im, inverse);
15648 } else if n_complex <= 16 {
15649 fft_naive_inplace_f32(&mut re, &mut im, inverse);
15650 } else {
15651 fft_bluestein_inplace_f32(&mut re, &mut im, inverse, &mut scratch);
15652 }
15653 if scale != 1.0 {
15654 re.iter_mut().for_each(|v| *v *= scale);
15655 im.iter_mut().for_each(|v| *v *= scale);
15656 }
15657 let dst_row = dst + o * row_bytes;
15658 for i in 0..n_complex {
15659 let elem_off = dst_row + i * 8;
15660 let re_b = re[i].to_le_bytes();
15661 let im_b = im[i].to_le_bytes();
15662 for j in 0..4 {
15663 *base.add(elem_off + j) = re_b[j];
15664 *base.add(elem_off + 4 + j) = im_b[j];
15665 }
15666 }
15667 }
15668}
15669
15670pub unsafe fn execute_log_mel(
15672 spec: usize,
15673 filters: usize,
15674 dst: usize,
15675 outer: usize,
15676 n_fft: usize,
15677 n_bins: usize,
15678 n_mels: usize,
15679 base: *mut u8,
15680) {
15681 execute_log_mel_f32(spec, filters, dst, outer, n_fft, n_bins, n_mels, base);
15682}
15683
15684pub unsafe fn execute_log_mel_f32(
15685 spec: usize,
15686 filters: usize,
15687 dst: usize,
15688 outer: usize,
15689 n_fft: usize,
15690 n_bins: usize,
15691 n_mels: usize,
15692 base: *mut u8,
15693) {
15694 let spec_ptr = base.add(spec) as *const f32;
15695 let filt_ptr = base.add(filters) as *const f32;
15696 let dst_ptr = base.add(dst) as *mut f32;
15697 let spec = std::slice::from_raw_parts(spec_ptr, outer * n_fft * 2);
15698 let filters = std::slice::from_raw_parts(filt_ptr, n_mels * n_bins);
15699 let out = std::slice::from_raw_parts_mut(dst_ptr, outer * n_mels);
15700 rlx_ir::audio::log_mel_block_f32(spec, filters, outer, n_fft, n_bins, n_mels, out);
15701}
15702
15703pub unsafe fn execute_welch_peaks_f32(
15704 spec: usize,
15705 dst: usize,
15706 welch_batch: usize,
15707 n_fft: usize,
15708 n_segments: usize,
15709 k: usize,
15710 base: *mut u8,
15711) {
15712 let spec_ptr = base.add(spec) as *const f32;
15713 let dst_ptr = base.add(dst) as *mut f32;
15714 let outer = welch_batch * n_segments;
15715 let spec = std::slice::from_raw_parts(spec_ptr, outer * n_fft * 2);
15716 let out = std::slice::from_raw_parts_mut(dst_ptr, welch_batch * k * 2);
15717 rlx_ir::audio::welch_peaks_block_f32(spec, welch_batch, n_fft, n_segments, k, out);
15718}
15719
15720pub unsafe fn execute_log_mel_backward_f32(
15721 spec: usize,
15722 filters: usize,
15723 dy: usize,
15724 dst: usize,
15725 outer: usize,
15726 n_fft: usize,
15727 n_bins: usize,
15728 n_mels: usize,
15729 base: *mut u8,
15730) {
15731 let spec_ptr = base.add(spec) as *const f32;
15732 let filt_ptr = base.add(filters) as *const f32;
15733 let dy_ptr = base.add(dy) as *const f32;
15734 let dst_ptr = base.add(dst) as *mut f32;
15735 let spec = std::slice::from_raw_parts(spec_ptr, outer * n_fft * 2);
15736 let filters = std::slice::from_raw_parts(filt_ptr, n_mels * n_bins);
15737 let dy = std::slice::from_raw_parts(dy_ptr, outer * n_mels);
15738 let d_spec = std::slice::from_raw_parts_mut(dst_ptr, outer * n_fft * 2);
15739 d_spec.fill(0.0);
15740 rlx_ir::audio::log_mel_block_vjp(spec, filters, dy, outer, n_fft, n_bins, n_mels, d_spec);
15741}
15742
15743pub unsafe fn execute_fft1d(
15745 src: usize,
15746 dst: usize,
15747 outer: usize,
15748 n_complex: usize,
15749 inverse: bool,
15750 norm_tag: u32,
15751 dtype: rlx_ir::DType,
15752 base: *mut u8,
15753) {
15754 match dtype {
15755 rlx_ir::DType::F32 => {
15756 execute_fft1d_f32(src, dst, outer, n_complex, inverse, norm_tag, base)
15757 }
15758 rlx_ir::DType::F64 => {
15759 execute_fft1d_f64(src, dst, outer, n_complex, inverse, norm_tag, base)
15760 }
15761 rlx_ir::DType::C64 => {
15762 execute_fft1d_c64(src, dst, outer, n_complex, inverse, norm_tag, base)
15763 }
15764 other => panic!("execute_fft1d: unsupported dtype {other:?}"),
15765 }
15766}
15767
/// Unnormalised in-place iterative radix-2 FFT on split real/imaginary f32 buffers.
///
/// Uses `exp(-2πi/len)` twiddles for the forward transform and `exp(+2πi/len)`
/// when `inverse` is set; no `1/n` scaling is applied. Twiddles are generated
/// by an f64 rotation recurrence and rounded to f32 per butterfly. `re.len()`
/// must be a power of two (debug-asserted) and equal to `im.len()`.
fn fft_radix2_inplace_f32(re: &mut [f32], im: &mut [f32], inverse: bool) {
    let n = re.len();
    debug_assert_eq!(im.len(), n);
    debug_assert!(
        n.is_power_of_two(),
        "fft_radix2_f32: n={n} must be a power of two"
    );
    if n <= 1 {
        return;
    }

    // Bit-reversal permutation: `rev` counts in bit-reversed order by
    // propagating a carry from the most-significant bit downwards.
    let mut rev = 0usize;
    for idx in 1..n {
        let mut mask = n >> 1;
        while rev & mask != 0 {
            rev ^= mask;
            mask >>= 1;
        }
        rev ^= mask;
        if idx < rev {
            re.swap(idx, rev);
            im.swap(idx, rev);
        }
    }

    let sign = if inverse { 1.0_f64 } else { -1.0_f64 };
    let mut span = 2usize;
    while span <= n {
        let half = span / 2;
        let theta = sign * 2.0 * std::f64::consts::PI / (span as f64);
        let (step_re, step_im) = (theta.cos(), theta.sin());
        for (blk_re, blk_im) in re.chunks_exact_mut(span).zip(im.chunks_exact_mut(span)) {
            let (top_re, bot_re) = blk_re.split_at_mut(half);
            let (top_im, bot_im) = blk_im.split_at_mut(half);
            let (mut tw_re, mut tw_im) = (1.0_f64, 0.0_f64);
            for j in 0..half {
                let (wr, wi) = (tw_re as f32, tw_im as f32);
                let t_re = wr * bot_re[j] - wi * bot_im[j];
                let t_im = wr * bot_im[j] + wi * bot_re[j];
                let (u_re, u_im) = (top_re[j], top_im[j]);
                top_re[j] = u_re + t_re;
                top_im[j] = u_im + t_im;
                bot_re[j] = u_re - t_re;
                bot_im[j] = u_im - t_im;
                // Rotate the twiddle by one step (complex multiply in f64).
                let next_re = tw_re * step_re - tw_im * step_im;
                let next_im = tw_re * step_im + tw_im * step_re;
                tw_re = next_re;
                tw_im = next_im;
            }
        }
        span <<= 1;
    }
}
15829
/// Unnormalised in-place iterative radix-2 FFT on split real/imaginary f64 buffers.
///
/// Uses `exp(-2πi/len)` twiddles for the forward transform and `exp(+2πi/len)`
/// when `inverse` is set; no `1/n` scaling is applied. `re.len()` must be a
/// power of two (debug-asserted) and equal to `im.len()`.
fn fft_radix2_inplace_f64(re: &mut [f64], im: &mut [f64], inverse: bool) {
    let n = re.len();
    debug_assert_eq!(im.len(), n);
    debug_assert!(
        n.is_power_of_two(),
        "fft_radix2: n={n} must be a power of two"
    );
    if n <= 1 {
        return;
    }

    // Bit-reversal permutation: `rev` counts in bit-reversed order by
    // propagating a carry from the most-significant bit downwards.
    let mut rev = 0usize;
    for idx in 1..n {
        let mut mask = n >> 1;
        while rev & mask != 0 {
            rev ^= mask;
            mask >>= 1;
        }
        rev ^= mask;
        if idx < rev {
            re.swap(idx, rev);
            im.swap(idx, rev);
        }
    }

    let sign = if inverse { 1.0 } else { -1.0 };
    let mut span = 2usize;
    while span <= n {
        let half = span / 2;
        let theta = sign * 2.0 * std::f64::consts::PI / (span as f64);
        let (step_re, step_im) = (theta.cos(), theta.sin());
        for (blk_re, blk_im) in re.chunks_exact_mut(span).zip(im.chunks_exact_mut(span)) {
            let (top_re, bot_re) = blk_re.split_at_mut(half);
            let (top_im, bot_im) = blk_im.split_at_mut(half);
            let (mut tw_re, mut tw_im) = (1.0_f64, 0.0_f64);
            for j in 0..half {
                let t_re = tw_re * bot_re[j] - tw_im * bot_im[j];
                let t_im = tw_re * bot_im[j] + tw_im * bot_re[j];
                let (u_re, u_im) = (top_re[j], top_im[j]);
                top_re[j] = u_re + t_re;
                top_im[j] = u_im + t_im;
                bot_re[j] = u_re - t_re;
                bot_im[j] = u_im - t_im;
                // Rotate the twiddle by one step.
                let next_re = tw_re * step_re - tw_im * step_im;
                let next_im = tw_re * step_im + tw_im * step_re;
                tw_re = next_re;
                tw_im = next_im;
            }
        }
        span <<= 1;
    }
}
15891
/// Precomputed tables and work buffers for an `n`-point f64 Bluestein FFT.
/// Built once per transform length and reused across rows.
struct BluesteinScratchF64 {
    /// Power-of-two convolution length, `(2n - 1).next_power_of_two()`
    /// (1 when `n <= 1`, 0 for the empty placeholder).
    m: usize,
    /// Real part of the chirp `w[k] = exp(±iπ·k²/n)`, length `n`.
    w_re: Vec<f64>,
    /// Imaginary part of the chirp, length `n`.
    w_im: Vec<f64>,
    /// Real part of the forward FFT of the conjugate chirp, laid out
    /// circularly in a length-`m` buffer.
    bf_re: Vec<f64>,
    /// Imaginary part of that precomputed spectrum, length `m`.
    bf_im: Vec<f64>,
    /// Real work buffer of length `m`, reused across calls.
    ar: Vec<f64>,
    /// Imaginary work buffer of length `m`, reused across calls.
    ai: Vec<f64>,
}
15911
15912impl BluesteinScratchF64 {
15913 fn empty() -> Self {
15914 Self {
15915 m: 0,
15916 w_re: Vec::new(),
15917 w_im: Vec::new(),
15918 bf_re: Vec::new(),
15919 bf_im: Vec::new(),
15920 ar: Vec::new(),
15921 ai: Vec::new(),
15922 }
15923 }
15924
15925 fn build(n: usize, inverse: bool) -> Self {
15926 let m = if n <= 1 {
15929 1
15930 } else {
15931 (2 * n - 1).next_power_of_two()
15932 };
15933
15934 let mod_2n = (2 * n) as u64;
15937 let sign = if inverse { 1.0_f64 } else { -1.0_f64 };
15938 let mut w_re = vec![0.0_f64; n];
15939 let mut w_im = vec![0.0_f64; n];
15940 for k in 0..n {
15941 let k2 = (k as u64).wrapping_mul(k as u64) % mod_2n;
15942 let theta = sign * std::f64::consts::PI * (k2 as f64) / (n as f64);
15943 w_re[k] = theta.cos();
15944 w_im[k] = theta.sin();
15945 }
15946
15947 let mut bf_re = vec![0.0_f64; m];
15950 let mut bf_im = vec![0.0_f64; m];
15951 if n > 0 {
15952 bf_re[0] = w_re[0];
15953 bf_im[0] = -w_im[0];
15954 for k in 1..n {
15955 bf_re[k] = w_re[k];
15956 bf_im[k] = -w_im[k];
15957 bf_re[m - k] = w_re[k];
15958 bf_im[m - k] = -w_im[k];
15959 }
15960 }
15961 if m > 1 {
15962 fft_radix2_inplace_f64(&mut bf_re, &mut bf_im, false);
15963 }
15964
15965 Self {
15966 m,
15967 w_re,
15968 w_im,
15969 bf_re,
15970 bf_im,
15971 ar: vec![0.0_f64; m],
15972 ai: vec![0.0_f64; m],
15973 }
15974 }
15975}
15976
/// Direct O(n²) DFT, in place, on split real/imaginary f64 buffers.
///
/// Unnormalised. Uses `exp(-2πi·nk/n)` for the forward transform and
/// `exp(+2πi·nk/n)` when `inverse` is set.
///
/// The `n` distinct twiddles are computed once. Each is looked up by the exact
/// integer `(nn·k) mod n`, tracked incrementally, so every angle lies in
/// `[0, 2π)`. Evaluating `2π·nn·k/n` unreduced loses precision as `nn·k` grows,
/// and costs `n²` trig calls instead of `n`.
fn fft_naive_inplace_f64(re: &mut [f64], im: &mut [f64], inverse: bool) {
    let n = re.len();
    debug_assert_eq!(im.len(), n);
    if n <= 1 {
        return;
    }
    let sign = if inverse { 1.0_f64 } else { -1.0_f64 };
    let (tw_re, tw_im): (Vec<f64>, Vec<f64>) = (0..n)
        .map(|j| {
            let theta = sign * 2.0 * std::f64::consts::PI * (j as f64) / (n as f64);
            (theta.cos(), theta.sin())
        })
        .unzip();

    let mut out_re = vec![0.0_f64; n];
    let mut out_im = vec![0.0_f64; n];
    for k in 0..n {
        // idx == (nn * k) mod n, advanced without forming the product.
        let mut idx = 0usize;
        let (mut acc_re, mut acc_im) = (0.0_f64, 0.0_f64);
        for nn in 0..n {
            let (c, s) = (tw_re[idx], tw_im[idx]);
            acc_re += re[nn] * c - im[nn] * s;
            acc_im += re[nn] * s + im[nn] * c;
            idx += k;
            if idx >= n {
                idx -= n;
            }
        }
        out_re[k] = acc_re;
        out_im[k] = acc_im;
    }
    re.copy_from_slice(&out_re);
    im.copy_from_slice(&out_im);
}
15998
/// Direct O(n²) DFT, in place, on split real/imaginary f32 buffers.
///
/// Unnormalised. Uses `exp(-2πi·nk/n)` for the forward transform and
/// `exp(+2πi·nk/n)` when `inverse` is set.
///
/// The `n` twiddles are computed once in f64 and rounded to f32, consistent with
/// the radix-2 and Bluestein f32 paths. Each is looked up by the exact integer
/// `(nn·k) mod n`. Forming the angle `2π·nn·k/n` directly in f32 loses accuracy
/// quickly: around 1e-3 absolute error at n ≈ 500.
fn fft_naive_inplace_f32(re: &mut [f32], im: &mut [f32], inverse: bool) {
    let n = re.len();
    debug_assert_eq!(im.len(), n);
    if n <= 1 {
        return;
    }
    let sign = if inverse { 1.0_f64 } else { -1.0_f64 };
    let (tw_re, tw_im): (Vec<f32>, Vec<f32>) = (0..n)
        .map(|j| {
            let theta = sign * 2.0 * std::f64::consts::PI * (j as f64) / (n as f64);
            (theta.cos() as f32, theta.sin() as f32)
        })
        .unzip();

    let mut out_re = vec![0.0_f32; n];
    let mut out_im = vec![0.0_f32; n];
    for k in 0..n {
        // idx == (nn * k) mod n, advanced without forming the product.
        let mut idx = 0usize;
        let (mut acc_re, mut acc_im) = (0.0_f32, 0.0_f32);
        for nn in 0..n {
            let (c, s) = (tw_re[idx], tw_im[idx]);
            acc_re += re[nn] * c - im[nn] * s;
            acc_im += re[nn] * s + im[nn] * c;
            idx += k;
            if idx >= n {
                idx -= n;
            }
        }
        out_re[k] = acc_re;
        out_im[k] = acc_im;
    }
    re.copy_from_slice(&out_re);
    im.copy_from_slice(&out_im);
}
16019
16020fn fft_bluestein_inplace_f64(
16029 re: &mut [f64],
16030 im: &mut [f64],
16031 _inverse: bool,
16032 s: &mut BluesteinScratchF64,
16033) {
16034 let n = re.len();
16035 debug_assert_eq!(im.len(), n);
16036 debug_assert_eq!(s.w_re.len(), n);
16037 if n <= 1 {
16038 return;
16039 }
16040 let m = s.m;
16041
16042 for k in 0..m {
16044 s.ar[k] = 0.0;
16045 s.ai[k] = 0.0;
16046 }
16047 for k in 0..n {
16048 s.ar[k] = re[k] * s.w_re[k] - im[k] * s.w_im[k];
16049 s.ai[k] = re[k] * s.w_im[k] + im[k] * s.w_re[k];
16050 }
16051
16052 fft_radix2_inplace_f64(&mut s.ar, &mut s.ai, false);
16054
16055 for k in 0..m {
16057 let ar = s.ar[k];
16058 let ai = s.ai[k];
16059 let br = s.bf_re[k];
16060 let bi = s.bf_im[k];
16061 s.ar[k] = ar * br - ai * bi;
16062 s.ai[k] = ar * bi + ai * br;
16063 }
16064
16065 fft_radix2_inplace_f64(&mut s.ar, &mut s.ai, true);
16068 let inv_m = 1.0 / (m as f64);
16069
16070 for k in 0..n {
16072 let yr = s.ar[k] * inv_m;
16073 let yi = s.ai[k] * inv_m;
16074 re[k] = yr * s.w_re[k] - yi * s.w_im[k];
16075 im[k] = yr * s.w_im[k] + yi * s.w_re[k];
16076 }
16077}
16078
/// Precomputed tables and work buffers for an `n`-point f32 Bluestein FFT.
/// Built once per transform length (with trig evaluated in f64) and reused
/// across rows.
struct BluesteinScratchF32 {
    /// Power-of-two convolution length, `(2n - 1).next_power_of_two()`
    /// (1 when `n <= 1`, 0 for the empty placeholder).
    m: usize,
    /// Real part of the chirp `w[k] = exp(±iπ·k²/n)`, length `n`.
    w_re: Vec<f32>,
    /// Imaginary part of the chirp, length `n`.
    w_im: Vec<f32>,
    /// Real part of the forward FFT of the conjugate chirp, laid out
    /// circularly in a length-`m` buffer.
    bf_re: Vec<f32>,
    /// Imaginary part of that precomputed spectrum, length `m`.
    bf_im: Vec<f32>,
    /// Real work buffer of length `m`, reused across calls.
    ar: Vec<f32>,
    /// Imaginary work buffer of length `m`, reused across calls.
    ai: Vec<f32>,
}
16091
16092impl BluesteinScratchF32 {
16093 fn empty() -> Self {
16094 Self {
16095 m: 0,
16096 w_re: Vec::new(),
16097 w_im: Vec::new(),
16098 bf_re: Vec::new(),
16099 bf_im: Vec::new(),
16100 ar: Vec::new(),
16101 ai: Vec::new(),
16102 }
16103 }
16104
16105 fn build(n: usize, inverse: bool) -> Self {
16106 let m = if n <= 1 {
16107 1
16108 } else {
16109 (2 * n - 1).next_power_of_two()
16110 };
16111
16112 let mod_2n = (2 * n) as u64;
16113 let sign = if inverse { 1.0_f64 } else { -1.0_f64 };
16114 let mut w_re = vec![0.0_f32; n];
16115 let mut w_im = vec![0.0_f32; n];
16116 for k in 0..n {
16117 let k2 = (k as u64).wrapping_mul(k as u64) % mod_2n;
16118 let theta = sign * std::f64::consts::PI * (k2 as f64) / (n as f64);
16119 w_re[k] = theta.cos() as f32;
16120 w_im[k] = theta.sin() as f32;
16121 }
16122
16123 let mut bf_re = vec![0.0_f32; m];
16124 let mut bf_im = vec![0.0_f32; m];
16125 if n > 0 {
16126 bf_re[0] = w_re[0];
16127 bf_im[0] = -w_im[0];
16128 for k in 1..n {
16129 bf_re[k] = w_re[k];
16130 bf_im[k] = -w_im[k];
16131 bf_re[m - k] = w_re[k];
16132 bf_im[m - k] = -w_im[k];
16133 }
16134 }
16135 if m > 1 {
16136 fft_radix2_inplace_f32(&mut bf_re, &mut bf_im, false);
16137 }
16138
16139 Self {
16140 m,
16141 w_re,
16142 w_im,
16143 bf_re,
16144 bf_im,
16145 ar: vec![0.0_f32; m],
16146 ai: vec![0.0_f32; m],
16147 }
16148 }
16149}
16150
16151fn fft_bluestein_inplace_f32(
16152 re: &mut [f32],
16153 im: &mut [f32],
16154 _inverse: bool,
16155 s: &mut BluesteinScratchF32,
16156) {
16157 let n = re.len();
16158 debug_assert_eq!(im.len(), n);
16159 debug_assert_eq!(s.w_re.len(), n);
16160 if n <= 1 {
16161 return;
16162 }
16163 let m = s.m;
16164
16165 for k in 0..m {
16166 s.ar[k] = 0.0;
16167 s.ai[k] = 0.0;
16168 }
16169 for k in 0..n {
16170 s.ar[k] = re[k] * s.w_re[k] - im[k] * s.w_im[k];
16171 s.ai[k] = re[k] * s.w_im[k] + im[k] * s.w_re[k];
16172 }
16173
16174 fft_radix2_inplace_f32(&mut s.ar, &mut s.ai, false);
16175
16176 for k in 0..m {
16177 let ar = s.ar[k];
16178 let ai = s.ai[k];
16179 let br = s.bf_re[k];
16180 let bi = s.bf_im[k];
16181 s.ar[k] = ar * br - ai * bi;
16182 s.ai[k] = ar * bi + ai * br;
16183 }
16184
16185 fft_radix2_inplace_f32(&mut s.ar, &mut s.ai, true);
16186 let inv_m = 1.0_f32 / (m as f32);
16187
16188 for k in 0..n {
16189 let yr = s.ar[k] * inv_m;
16190 let yi = s.ai[k] * inv_m;
16191 re[k] = yr * s.w_re[k] - yi * s.w_im[k];
16192 im[k] = yr * s.w_im[k] + yi * s.w_re[k];
16193 }
16194}
16195
/// Runs a user-registered `Op::Custom` CPU kernel over arena-backed tensors.
///
/// Each entry of `inputs` is `(byte offset, element count, shape)`. The shape's
/// dtype selects which `CpuTensorRef` variant is built over the arena at `base`.
/// The output view is built the same way from `out_off`, `out_len` and
/// `out_shape`. `attrs` is passed to the kernel unchanged.
///
/// # Panics
/// - If any input, or the output, has dtype `C64`; custom kernels are not given
///   complex tensors.
/// - If the kernel returns `Err`. The message includes `kernel.name()`.
///
/// # Safety
/// Every offset/length pair must describe a valid, suitably aligned region of
/// the arena at `base`. An input offset of `usize::MAX` yields an empty view
/// (see `sl_typed`).
unsafe fn dispatch_custom_op(
    kernel: &dyn crate::op_registry::CpuKernel,
    inputs: &[(usize, u32, Shape)],
    out_off: usize,
    out_len: u32,
    out_shape: &Shape,
    attrs: &[u8],
    base: *mut u8,
) {
    use crate::op_registry::{CpuTensorMut, CpuTensorRef};
    use rlx_ir::DType;

    // Builds a read-only typed view of `$n` elements of `$rust_ty` at byte
    // offset `$off`, wrapped in the `CpuTensorRef::$variant` variant.
    macro_rules! build_in_view {
        ($shape:expr, $off:expr, $n:expr, $variant:ident, $rust_ty:ty) => {
            CpuTensorRef::$variant {
                data: unsafe { sl_typed::<$rust_ty>($off, base, $n) },
                shape: $shape,
            }
        };
    }
    // Builds the mutable output view (`out_len` elements at `out_off`) as
    // `CpuTensorMut::$variant`.
    macro_rules! build_out_view {
        ($variant:ident, $rust_ty:ty) => {
            CpuTensorMut::$variant {
                data: unsafe { sl_mut_typed::<$rust_ty>(out_off, base, out_len as usize) },
                shape: out_shape,
            }
        };
    }

    let in_views: Vec<CpuTensorRef<'_>> = inputs
        .iter()
        .map(|(off, len, shape)| {
            let n = *len as usize;
            let off = *off;
            match shape.dtype() {
                DType::F32 => build_in_view!(shape, off, n, F32, f32),
                DType::F64 => build_in_view!(shape, off, n, F64, f64),
                DType::F16 => build_in_view!(shape, off, n, F16, half::f16),
                DType::BF16 => build_in_view!(shape, off, n, BF16, half::bf16),
                DType::I8 => build_in_view!(shape, off, n, I8, i8),
                DType::I16 => build_in_view!(shape, off, n, I16, i16),
                DType::I32 => build_in_view!(shape, off, n, I32, i32),
                DType::I64 => build_in_view!(shape, off, n, I64, i64),
                DType::U8 => build_in_view!(shape, off, n, U8, u8),
                DType::U32 => build_in_view!(shape, off, n, U32, u32),
                // Bool tensors are viewed as one byte per element.
                DType::Bool => build_in_view!(shape, off, n, Bool, u8),
                DType::C64 => panic!(
                    "Op::Custom kernel input has DType::C64 — built-in \
                     complex ops handle their own kernels; user-registered \
                     ops don't yet see complex tensors"
                ),
            }
        })
        .collect();

    // The output view is built inline in each arm because its element type
    // differs per dtype.
    let result = match out_shape.dtype() {
        DType::F32 => kernel.execute(&in_views, build_out_view!(F32, f32), attrs),
        DType::F64 => kernel.execute(&in_views, build_out_view!(F64, f64), attrs),
        DType::F16 => kernel.execute(&in_views, build_out_view!(F16, half::f16), attrs),
        DType::BF16 => kernel.execute(&in_views, build_out_view!(BF16, half::bf16), attrs),
        DType::I8 => kernel.execute(&in_views, build_out_view!(I8, i8), attrs),
        DType::I16 => kernel.execute(&in_views, build_out_view!(I16, i16), attrs),
        DType::I32 => kernel.execute(&in_views, build_out_view!(I32, i32), attrs),
        DType::I64 => kernel.execute(&in_views, build_out_view!(I64, i64), attrs),
        DType::U8 => kernel.execute(&in_views, build_out_view!(U8, u8), attrs),
        DType::U32 => kernel.execute(&in_views, build_out_view!(U32, u32), attrs),
        DType::Bool => kernel.execute(&in_views, build_out_view!(Bool, u8), attrs),
        DType::C64 => panic!("Op::Custom output DType::C64 not supported"),
    };
    // Kernel failures are fatal at execution time; there is no error channel here.
    if let Err(e) = result {
        panic!("Op::Custom('{}') CPU kernel failed: {e}", kernel.name());
    }
}
16281
/// Views `len` elements of `T` starting `offset` bytes past `base`.
///
/// An `offset` of `usize::MAX` marks an absent tensor and yields an empty slice
/// without touching `base`.
///
/// # Safety
/// For any other offset, `base + offset` must point to `len` initialised,
/// properly aligned `T`s that stay valid and unaliased-by-writes for as long as
/// the returned slice is used. The `'static` lifetime is not checked.
#[inline(always)]
unsafe fn sl_typed<T>(offset: usize, base: *mut u8, len: usize) -> &'static [T] {
    match offset {
        usize::MAX => &[],
        _ => unsafe { std::slice::from_raw_parts(base.add(offset).cast::<T>(), len) },
    }
}
16294
/// Mutable view of `len` elements of `T` starting `offset` bytes past `base`.
///
/// Unlike `sl_typed` there is no `usize::MAX` sentinel: the offset is always
/// used as given.
///
/// # Safety
/// `base + offset` must point to `len` properly aligned `T`s. Nothing else may
/// access that memory while the returned slice is in use. The `'static` lifetime
/// is not checked.
#[inline(always)]
unsafe fn sl_mut_typed<T>(offset: usize, base: *mut u8, len: usize) -> &'static mut [T] {
    let first: *mut T = unsafe { base.add(offset) }.cast();
    unsafe { std::slice::from_raw_parts_mut(first, len) }
}
16299
16300#[inline(always)]
16302fn apply_activation_inplace(d: &mut [f32], act: rlx_ir::op::Activation) {
16306 use rlx_ir::op::Activation;
16307 match act {
16308 Activation::Gelu => crate::kernels::par_gelu_inplace(d),
16309 Activation::GeluApprox => crate::kernels::par_gelu_approx_inplace(d),
16310 Activation::Silu => crate::kernels::par_silu_inplace(d),
16311 Activation::Relu => {
16312 for v in d.iter_mut() {
16313 *v = v.max(0.0);
16314 }
16315 }
16316 Activation::Sigmoid => {
16317 for v in d.iter_mut() {
16318 *v = 1.0 / (1.0 + (-*v).exp());
16319 }
16320 }
16321 Activation::Tanh => {
16322 for v in d.iter_mut() {
16323 *v = v.tanh();
16324 }
16325 }
16326 Activation::Exp => {
16327 for v in d.iter_mut() {
16328 *v = v.exp();
16329 }
16330 }
16331 Activation::Log => {
16332 for v in d.iter_mut() {
16333 *v = v.ln();
16334 }
16335 }
16336 Activation::Sqrt => {
16337 for v in d.iter_mut() {
16338 *v = v.sqrt();
16339 }
16340 }
16341 Activation::Rsqrt => {
16342 for v in d.iter_mut() {
16343 *v = 1.0 / v.sqrt();
16344 }
16345 }
16346 Activation::Neg => {
16347 for v in d.iter_mut() {
16348 *v = -*v;
16349 }
16350 }
16351 Activation::Abs => {
16352 for v in d.iter_mut() {
16353 *v = v.abs();
16354 }
16355 }
16356 Activation::Round => {
16357 for v in d.iter_mut() {
16358 *v = v.round();
16359 }
16360 }
16361 Activation::Sin => {
16362 for v in d.iter_mut() {
16363 *v = v.sin();
16364 }
16365 }
16366 Activation::Cos => {
16367 for v in d.iter_mut() {
16368 *v = v.cos();
16369 }
16370 }
16371 Activation::Tan => {
16372 for v in d.iter_mut() {
16373 *v = v.tan();
16374 }
16375 }
16376 Activation::Atan => {
16377 for v in d.iter_mut() {
16378 *v = v.atan();
16379 }
16380 }
16381 }
16382}
16383
/// Unfolds one image `x` (`c_in × h × w`, row-major) into the column matrix `col` used by
/// GEMM-based convolution.
///
/// `col` has `c_in * kh * kw` rows of `h_out * w_out` entries each. Row
/// `(ci * kh + ki) * kw + kj` holds, for every output position `(ho, wo)`, the input value
/// at `(ho * sh + ki * dh - ph, wo * sw + kj * dw_dil - pw)` in channel `ci`. Positions
/// that land in the padding are written as `0.0`. Every entry of `col` is overwritten.
#[allow(clippy::too_many_arguments)]
fn im2col(
    x: &[f32],
    col: &mut [f32],
    c_in: usize,
    h: usize,
    w: usize,
    h_out: usize,
    w_out: usize,
    kh: usize,
    kw: usize,
    sh: usize,
    sw: usize,
    ph: usize,
    pw: usize,
    dh: usize,
    dw_dil: usize,
) {
    let plane = h_out * w_out;
    debug_assert_eq!(col.len(), c_in * kh * kw * plane);
    debug_assert_eq!(x.len(), c_in * h * w);

    // Input coordinate reached from output index `o` at kernel offset `tap`, or `None`
    // when it falls into the zero padding (before 0 or at/after `extent`).
    let tap_src = |o: usize, stride: usize, tap: usize, pad: usize, extent: usize| {
        (o * stride + tap).checked_sub(pad).filter(|&i| i < extent)
    };

    for ci in 0..c_in {
        for ki in 0..kh {
            for kj in 0..kw {
                let row = (ci * kh + ki) * kw + kj;
                let dst = &mut col[row * plane..(row + 1) * plane];
                for ho in 0..h_out {
                    let out_row = &mut dst[ho * w_out..(ho + 1) * w_out];
                    match tap_src(ho, sh, ki * dh, ph, h) {
                        None => out_row.fill(0.0),
                        Some(hi) => {
                            let src_row = (ci * h + hi) * w;
                            for (wo, slot) in out_row.iter_mut().enumerate() {
                                *slot = match tap_src(wo, sw, kj * dw_dil, pw, w) {
                                    Some(wi) => x[src_row + wi],
                                    None => 0.0,
                                };
                            }
                        }
                    }
                }
            }
        }
    }
}
16445
/// Folds the column matrix `col` back into image `x` (`c_in × h × w`). This is the adjoint
/// of `im2col` with the same layout and convolution parameters.
///
/// Every column entry whose source coordinate lies inside the image is **added** to the
/// corresponding element of `x`, and entries that map into the padding are dropped. `x` is
/// accumulated into, not overwritten, so callers that want a fresh result must zero it
/// first.
#[allow(clippy::too_many_arguments)]
fn col2im(
    col: &[f32],
    x: &mut [f32],
    c_in: usize,
    h: usize,
    w: usize,
    h_out: usize,
    w_out: usize,
    kh: usize,
    kw: usize,
    sh: usize,
    sw: usize,
    ph: usize,
    pw: usize,
    dh: usize,
    dw_dil: usize,
) {
    let plane = h_out * w_out;
    debug_assert_eq!(col.len(), c_in * kh * kw * plane);
    debug_assert_eq!(x.len(), c_in * h * w);

    // Input coordinate reached from output index `o` at kernel offset `tap`, or `None`
    // when it falls into the padding.
    let tap_src = |o: usize, stride: usize, tap: usize, pad: usize, extent: usize| {
        (o * stride + tap).checked_sub(pad).filter(|&i| i < extent)
    };

    for ci in 0..c_in {
        for ki in 0..kh {
            for kj in 0..kw {
                let row = (ci * kh + ki) * kw + kj;
                let src = &col[row * plane..(row + 1) * plane];
                for ho in 0..h_out {
                    let Some(hi) = tap_src(ho, sh, ki * dh, ph, h) else {
                        continue;
                    };
                    let dst_row = (ci * h + hi) * w;
                    let col_row = &src[ho * w_out..(ho + 1) * w_out];
                    for (wo, &g) in col_row.iter().enumerate() {
                        if let Some(wi) = tap_src(wo, sw, kj * dw_dil, pw, w) {
                            x[dst_row + wi] += g;
                        }
                    }
                }
            }
        }
    }
}
16501
16502fn quant_layout(shape: &rlx_ir::Shape, axis: Option<usize>) -> (usize, usize, usize) {
16512 match axis {
16513 None => (0, 1, shape.num_elements().unwrap_or(0).max(1)),
16514 Some(d) => {
16515 let chan_dim = shape.dim(d).unwrap_static();
16516 let inner: usize = (d + 1..shape.rank())
16517 .map(|i| shape.dim(i).unwrap_static())
16518 .product::<usize>()
16519 .max(1);
16520 (d, chan_dim, inner)
16521 }
16522 }
16523}
16524
16525fn activation_backward_kernel(
16526 act: rlx_ir::op::Activation,
16527 xs: &[f32],
16528 dys: &[f32],
16529 out: &mut [f32],
16530) {
16531 use rlx_ir::op::Activation;
16532 let n = xs.len();
16533 debug_assert_eq!(dys.len(), n);
16534 debug_assert_eq!(out.len(), n);
16535 match act {
16536 Activation::Relu => {
16537 for i in 0..n {
16538 out[i] = if xs[i] > 0.0 { dys[i] } else { 0.0 };
16539 }
16540 }
16541 Activation::Sigmoid => {
16542 for i in 0..n {
16543 let s = 1.0 / (1.0 + (-xs[i]).exp());
16544 out[i] = s * (1.0 - s) * dys[i];
16545 }
16546 }
16547 Activation::Tanh => {
16548 for i in 0..n {
16549 let t = xs[i].tanh();
16550 out[i] = (1.0 - t * t) * dys[i];
16551 }
16552 }
16553 Activation::Silu => {
16554 for i in 0..n {
16556 let s = 1.0 / (1.0 + (-xs[i]).exp());
16557 out[i] = s * (1.0 + xs[i] * (1.0 - s)) * dys[i];
16558 }
16559 }
16560 Activation::Gelu => {
16561 const INV_SQRT2: f32 = 0.707_106_77;
16564 const INV_SQRT_2PI: f32 = 0.398_942_3;
16565 for i in 0..n {
16566 let x = xs[i];
16567 let phi = 0.5 * (1.0 + erf_f32(x * INV_SQRT2));
16568 let pdf = INV_SQRT_2PI * (-(x * x) * 0.5).exp();
16569 out[i] = (phi + x * pdf) * dys[i];
16570 }
16571 }
16572 Activation::GeluApprox => {
16573 const C: f32 = 0.797_884_6; const A: f32 = 0.044_715;
16577 for i in 0..n {
16578 let x = xs[i];
16579 let inner = C * (x + A * x * x * x);
16580 let t = inner.tanh();
16581 let dinner = C * (1.0 + 3.0 * A * x * x);
16582 let d = 0.5 * (1.0 + t) + 0.5 * x * (1.0 - t * t) * dinner;
16583 out[i] = d * dys[i];
16584 }
16585 }
16586 Activation::Exp => {
16587 for i in 0..n {
16588 out[i] = xs[i].exp() * dys[i];
16589 }
16590 }
16591 Activation::Log => {
16592 for i in 0..n {
16593 out[i] = dys[i] / xs[i];
16594 }
16595 }
16596 Activation::Sqrt => {
16597 for i in 0..n {
16599 let s = xs[i].sqrt();
16600 out[i] = if s > 0.0 { 0.5 * dys[i] / s } else { 0.0 };
16601 }
16602 }
16603 Activation::Rsqrt => {
16604 for i in 0..n {
16606 let s = xs[i].sqrt();
16607 out[i] = if s > 0.0 {
16608 -0.5 * dys[i] / (xs[i] * s)
16609 } else {
16610 0.0
16611 };
16612 }
16613 }
16614 Activation::Neg => {
16615 for i in 0..n {
16616 out[i] = -dys[i];
16617 }
16618 }
16619 Activation::Abs => {
16620 for i in 0..n {
16622 let x = xs[i];
16623 let s = if x > 0.0 {
16624 1.0
16625 } else if x < 0.0 {
16626 -1.0
16627 } else {
16628 0.0
16629 };
16630 out[i] = s * dys[i];
16631 }
16632 }
16633 Activation::Round => {
16634 out.copy_from_slice(dys);
16639 }
16640 Activation::Sin => {
16641 for i in 0..n {
16643 out[i] = xs[i].cos() * dys[i];
16644 }
16645 }
16646 Activation::Cos => {
16647 for i in 0..n {
16648 out[i] = -xs[i].sin() * dys[i];
16649 }
16650 }
16651 Activation::Tan => {
16652 for i in 0..n {
16654 let t = xs[i].tan();
16655 out[i] = (1.0 + t * t) * dys[i];
16656 }
16657 }
16658 Activation::Atan => {
16659 for i in 0..n {
16661 let x = xs[i];
16662 out[i] = dys[i] / (1.0 + x * x);
16663 }
16664 }
16665 }
16666}
16667
16668fn activation_backward_kernel_f64(
16672 act: rlx_ir::op::Activation,
16673 xs: &[f64],
16674 dys: &[f64],
16675 out: &mut [f64],
16676) {
16677 use rlx_ir::op::Activation;
16678 let n = xs.len();
16679 debug_assert_eq!(dys.len(), n);
16680 debug_assert_eq!(out.len(), n);
16681 match act {
16682 Activation::Relu => {
16683 for i in 0..n {
16684 out[i] = if xs[i] > 0.0 { dys[i] } else { 0.0 };
16685 }
16686 }
16687 Activation::Sigmoid => {
16688 for i in 0..n {
16689 let s = 1.0 / (1.0 + (-xs[i]).exp());
16690 out[i] = s * (1.0 - s) * dys[i];
16691 }
16692 }
16693 Activation::Tanh => {
16694 for i in 0..n {
16695 let t = xs[i].tanh();
16696 out[i] = (1.0 - t * t) * dys[i];
16697 }
16698 }
16699 Activation::Silu => {
16700 for i in 0..n {
16701 let s = 1.0 / (1.0 + (-xs[i]).exp());
16702 out[i] = s * (1.0 + xs[i] * (1.0 - s)) * dys[i];
16703 }
16704 }
16705 Activation::Gelu | Activation::GeluApprox => {
16706 const INV_SQRT2: f64 = std::f64::consts::FRAC_1_SQRT_2;
16708 const INV_SQRT_2PI: f64 = 0.398_942_280_401_432_7;
16709 for i in 0..n {
16710 let x = xs[i];
16711 let phi = 0.5 * (1.0 + erf_f64(x * INV_SQRT2));
16712 let pdf = INV_SQRT_2PI * (-(x * x) * 0.5).exp();
16713 out[i] = (phi + x * pdf) * dys[i];
16714 }
16715 }
16716 Activation::Exp => {
16717 for i in 0..n {
16718 out[i] = xs[i].exp() * dys[i];
16719 }
16720 }
16721 Activation::Log => {
16722 for i in 0..n {
16723 out[i] = dys[i] / xs[i];
16724 }
16725 }
16726 Activation::Sqrt => {
16727 for i in 0..n {
16728 let s = xs[i].sqrt();
16729 out[i] = if s > 0.0 { 0.5 * dys[i] / s } else { 0.0 };
16730 }
16731 }
16732 Activation::Rsqrt => {
16733 for i in 0..n {
16734 let s = xs[i].sqrt();
16735 out[i] = if s > 0.0 {
16736 -0.5 * dys[i] / (xs[i] * s)
16737 } else {
16738 0.0
16739 };
16740 }
16741 }
16742 Activation::Neg => {
16743 for i in 0..n {
16744 out[i] = -dys[i];
16745 }
16746 }
16747 Activation::Abs => {
16748 for i in 0..n {
16749 let x = xs[i];
16750 let s = if x > 0.0 {
16751 1.0
16752 } else if x < 0.0 {
16753 -1.0
16754 } else {
16755 0.0
16756 };
16757 out[i] = s * dys[i];
16758 }
16759 }
16760 Activation::Round => {
16761 out.copy_from_slice(dys);
16762 }
16763 Activation::Sin => {
16764 for i in 0..n {
16765 out[i] = xs[i].cos() * dys[i];
16766 }
16767 }
16768 Activation::Cos => {
16769 for i in 0..n {
16770 out[i] = -xs[i].sin() * dys[i];
16771 }
16772 }
16773 Activation::Tan => {
16774 for i in 0..n {
16775 let t = xs[i].tan();
16776 out[i] = (1.0 + t * t) * dys[i];
16777 }
16778 }
16779 Activation::Atan => {
16780 for i in 0..n {
16781 let x = xs[i];
16782 out[i] = dys[i] / (1.0 + x * x);
16783 }
16784 }
16785 }
16786}
16787
/// Approximate error function in `f64`.
///
/// The coefficients match Abramowitz & Stegun formula 7.1.26:
/// `erf(|x|) ≈ 1 - (a1 t + a2 t² + a3 t³ + a4 t⁴ + a5 t⁵) · exp(-x²)`, with
/// `t = 1 / (1 + p |x|)`. Odd symmetry is restored by multiplying by `x.signum()`.
/// A&S quote a maximum absolute error of about 1.5e-7, so this is far from full `f64`
/// precision.
///
/// Edge cases visible in the code: `signum` is ±1 for ±0.0, so `erf_f64(0.0)` is only
/// approximately zero. A NaN input propagates to NaN through `signum`.
#[inline(always)]
fn erf_f64(x: f64) -> f64 {
    let s = x.signum();
    let x = x.abs();
    let t = 1.0 / (1.0 + 0.327_591_1 * x);
    // Horner evaluation of a5..a1 in t, then scaled by t * exp(-x^2).
    let y = 1.0
        - (((((1.061_405_43 * t - 1.453_152_03) * t) + 1.421_413_75) * t - 0.284_496_74) * t
            + 0.254_829_59)
            * t
            * (-x * x).exp();
    s * y
}
16804
/// Approximate error function in `f32`.
///
/// Uses the same Abramowitz & Stegun 7.1.26 polynomial as `erf_f64`, with the coefficients
/// rounded to `f32`. The formula's quoted ~1.5e-7 absolute error is on the order of `f32`
/// epsilon near 1. Odd symmetry comes from `x.signum()`, so ±0.0 maps to approximately
/// zero and NaN propagates.
#[inline(always)]
fn erf_f32(x: f32) -> f32 {
    let s = x.signum();
    let x = x.abs();
    let t = 1.0 / (1.0 + 0.327_591_1 * x);
    // Horner evaluation of a5..a1 in t, then scaled by t * exp(-x^2).
    let y = 1.0
        - (((((1.061_405_4 * t - 1.453_152_1) * t) + 1.421_413_8) * t - 0.284_496_74) * t
            + 0.254_829_6)
            * t
            * (-x * x).exp();
    s * y
}
16819
/// Builds the closure that performs a strided `Narrow` copy inside the arena.
///
/// When invoked with the arena base pointer, it copies `outer` rows of `inner` elements
/// (`elem_bytes` bytes each). Row `o` is read from byte offset
/// `src + o * src_stride * elem_bytes` and written to `dst + o * dst_stride * elem_bytes`.
///
/// The closure does nothing when a row is empty or when `src == dst`. Individual rows
/// whose source and destination coincide are skipped.
/// NOTE(review): the `src == dst` early-out presumably means the memory planner aliased the
/// output onto the input as a view. With differing strides the data is *not* compacted in
/// that case. Confirm this is intended.
///
/// Rows are moved with `ptr::copy` (memmove semantics), so partially overlapping source and
/// destination rows are copied correctly. The previous `copy_nonoverlapping` made such
/// overlap undefined behaviour. The former `arena_len = usize::MAX` bounds check could
/// never fire (`saturating_add` never exceeds `usize::MAX`) and has been removed.
///
/// The returned closure is a safe `Fn`, but the caller must pass a pointer to a buffer that
/// contains every source and destination row.
fn narrow_thunk_closure(
    src: usize,
    dst: usize,
    outer: u32,
    src_stride: u32,
    dst_stride: u32,
    inner: u32,
    elem_bytes: u8,
) -> Arc<dyn Fn(*mut u8) + Send + Sync> {
    let eb = elem_bytes as usize;
    let outer = outer as usize;
    let row_bytes = (inner as usize).saturating_mul(eb);
    let src_row_stride = (src_stride as usize).saturating_mul(eb);
    let dst_row_stride = (dst_stride as usize).saturating_mul(eb);
    Arc::new(move |base: *mut u8| {
        if row_bytes == 0 || src == dst {
            return;
        }
        for o in 0..outer {
            let s_off = src + o * src_row_stride;
            let d_off = dst + o * dst_row_stride;
            if s_off == d_off {
                continue;
            }
            // SAFETY: the caller guarantees both rows lie inside the buffer behind `base`;
            // `ptr::copy` tolerates overlapping source/destination ranges.
            unsafe { std::ptr::copy(base.add(s_off), base.add(d_off), row_bytes) };
        }
    })
}
16860
/// Views `len` `f32`s starting at byte `offset` of the arena at `base`.
///
/// `offset == usize::MAX` is the "no buffer" sentinel. It yields an empty slice without
/// dereferencing `base`, and `len` is ignored.
///
/// # Safety
/// Otherwise `base + offset` must be `f32`-aligned and must address `len` initialized
/// `f32`s. They must remain valid, and not be mutated elsewhere, while the returned slice
/// is used. The `'static` lifetime is unchecked.
unsafe fn sl(offset: usize, base: *mut u8, len: usize) -> &'static [f32] {
    match offset {
        usize::MAX => &[],
        off => unsafe { std::slice::from_raw_parts(base.add(off).cast::<f32>(), len) },
    }
}
16867
/// Mutable view of `len` `f32`s starting at byte `offset` of the arena at `base`.
///
/// There is no `usize::MAX` sentinel handling here. `offset` must name a real buffer.
///
/// # Safety
/// `base + offset` must be `f32`-aligned and must address `len` writable `f32`s. No other
/// reference to that memory may be used while the slice is alive. The `'static` lifetime
/// is unchecked.
#[inline(always)]
unsafe fn sl_mut(offset: usize, base: *mut u8, len: usize) -> &'static mut [f32] {
    let start = unsafe { base.add(offset) }.cast::<f32>();
    unsafe { std::slice::from_raw_parts_mut(start, len) }
}
16872
/// Views `len` `f64`s starting at byte `offset` of the arena at `base`.
///
/// `offset == usize::MAX` means "no buffer" and yields an empty slice (`len` is ignored).
///
/// # Safety
/// Otherwise `base + offset` must be `f64`-aligned and must address `len` initialized
/// `f64`s that stay valid and unmutated while the slice is used. The `'static` lifetime is
/// unchecked.
#[inline(always)]
unsafe fn sl_f64(offset: usize, base: *mut u8, len: usize) -> &'static [f64] {
    match offset {
        usize::MAX => &[],
        off => unsafe { std::slice::from_raw_parts(base.add(off).cast::<f64>(), len) },
    }
}
16880
/// Mutable view of `len` `f64`s starting at byte `offset` of the arena at `base`.
///
/// No `usize::MAX` sentinel handling. `offset` must name a real buffer.
///
/// # Safety
/// `base + offset` must be `f64`-aligned and must address `len` writable `f64`s with no
/// other live reference to them. The `'static` lifetime is unchecked.
#[inline(always)]
unsafe fn sl_mut_f64(offset: usize, base: *mut u8, len: usize) -> &'static mut [f64] {
    let start = unsafe { base.add(offset) }.cast::<f64>();
    unsafe { std::slice::from_raw_parts_mut(start, len) }
}
16885
/// Views `len` `i32`s starting at byte `offset` of the arena at `base`.
///
/// `offset == usize::MAX` means "no buffer" and yields an empty slice (`len` is ignored).
///
/// # Safety
/// Otherwise `base + offset` must be `i32`-aligned and must address `len` initialized
/// `i32`s that stay valid and unmutated while the slice is used. The `'static` lifetime is
/// unchecked.
#[inline(always)]
#[allow(dead_code)]
unsafe fn sl_i32(offset: usize, base: *mut u8, len: usize) -> &'static [i32] {
    match offset {
        usize::MAX => &[],
        off => unsafe { std::slice::from_raw_parts(base.add(off).cast::<i32>(), len) },
    }
}
16898
/// Mutable view of `len` `i32`s starting at byte `offset` of the arena at `base`.
///
/// No `usize::MAX` sentinel handling. `offset` must name a real buffer.
///
/// # Safety
/// `base + offset` must be `i32`-aligned and must address `len` writable `i32`s with no
/// other live reference to them. The `'static` lifetime is unchecked.
#[inline(always)]
#[allow(dead_code)]
unsafe fn sl_mut_i32(offset: usize, base: *mut u8, len: usize) -> &'static mut [i32] {
    let start = unsafe { base.add(offset) }.cast::<i32>();
    unsafe { std::slice::from_raw_parts_mut(start, len) }
}
16904
/// Views `len` `i64`s starting at byte `offset` of the arena at `base`.
///
/// `offset == usize::MAX` means "no buffer" and yields an empty slice (`len` is ignored).
///
/// # Safety
/// Otherwise `base + offset` must be `i64`-aligned and must address `len` initialized
/// `i64`s that stay valid and unmutated while the slice is used. The `'static` lifetime is
/// unchecked.
#[inline(always)]
unsafe fn sl_i64(offset: usize, base: *mut u8, len: usize) -> &'static [i64] {
    match offset {
        usize::MAX => &[],
        off => unsafe { std::slice::from_raw_parts(base.add(off).cast::<i64>(), len) },
    }
}
16912
/// Mutable view of `len` `i64`s starting at byte `offset` of the arena at `base`.
///
/// No `usize::MAX` sentinel handling. `offset` must name a real buffer.
///
/// # Safety
/// `base + offset` must be `i64`-aligned and must address `len` writable `i64`s with no
/// other live reference to them. The `'static` lifetime is unchecked.
#[inline(always)]
unsafe fn sl_mut_i64(offset: usize, base: *mut u8, len: usize) -> &'static mut [i64] {
    let start = unsafe { base.add(offset) }.cast::<i64>();
    unsafe { std::slice::from_raw_parts_mut(start, len) }
}
16917
16918fn transpose_walk_f64(inp: &[f64], out: &mut [f64], out_dims: &[u32], in_strides: &[u32]) {
16922 let rank = out_dims.len();
16923 let mut idx = vec![0u32; rank];
16924 for o in 0..out.len() {
16925 let mut src_off = 0usize;
16926 for d in 0..rank {
16927 src_off += idx[d] as usize * in_strides[d] as usize;
16928 }
16929 out[o] = inp[broadcast_src_index(src_off, inp.len())];
16930 for d in (0..rank).rev() {
16932 idx[d] += 1;
16933 if idx[d] < out_dims[d] {
16934 break;
16935 }
16936 idx[d] = 0;
16937 }
16938 }
16939}
16940
16941fn apply_activation_f64(inp: &[f64], out: &mut [f64], kind: Activation) {
16947 match kind {
16948 Activation::Neg => {
16949 for (o, &v) in out.iter_mut().zip(inp) {
16950 *o = -v;
16951 }
16952 }
16953 Activation::Exp => {
16954 for (o, &v) in out.iter_mut().zip(inp) {
16955 *o = v.exp();
16956 }
16957 }
16958 Activation::Log => {
16959 for (o, &v) in out.iter_mut().zip(inp) {
16960 *o = v.ln();
16961 }
16962 }
16963 Activation::Sqrt => {
16964 for (o, &v) in out.iter_mut().zip(inp) {
16965 *o = v.sqrt();
16966 }
16967 }
16968 Activation::Rsqrt => {
16969 for (o, &v) in out.iter_mut().zip(inp) {
16970 *o = 1.0 / v.sqrt();
16971 }
16972 }
16973 Activation::Abs => {
16974 for (o, &v) in out.iter_mut().zip(inp) {
16975 *o = v.abs();
16976 }
16977 }
16978 Activation::Tanh => {
16979 for (o, &v) in out.iter_mut().zip(inp) {
16980 *o = v.tanh();
16981 }
16982 }
16983 Activation::Sigmoid => {
16984 for (o, &v) in out.iter_mut().zip(inp) {
16985 *o = 1.0 / (1.0 + (-v).exp());
16986 }
16987 }
16988 Activation::Relu => {
16989 for (o, &v) in out.iter_mut().zip(inp) {
16990 *o = v.max(0.0);
16991 }
16992 }
16993 Activation::Round => {
16994 for (o, &v) in out.iter_mut().zip(inp) {
16995 *o = v.round_ties_even();
16996 }
16997 }
16998 Activation::Sin => {
16999 for (o, &v) in out.iter_mut().zip(inp) {
17000 *o = v.sin();
17001 }
17002 }
17003 Activation::Cos => {
17004 for (o, &v) in out.iter_mut().zip(inp) {
17005 *o = v.cos();
17006 }
17007 }
17008 Activation::Tan => {
17009 for (o, &v) in out.iter_mut().zip(inp) {
17010 *o = v.tan();
17011 }
17012 }
17013 Activation::Atan => {
17014 for (o, &v) in out.iter_mut().zip(inp) {
17015 *o = v.atan();
17016 }
17017 }
17018 Activation::Gelu | Activation::GeluApprox | Activation::Silu => {
17019 panic!(
17020 "apply_activation_f64: {kind:?} not yet implemented at f64. \
17021 Add when a workload needs it."
17022 );
17023 }
17024 }
17025}
17026
17027#[inline]
17028fn binary_op_f64(op: BinaryOp, a: f64, b: f64) -> f64 {
17029 match op {
17030 BinaryOp::Add => a + b,
17031 BinaryOp::Sub => a - b,
17032 BinaryOp::Mul => a * b,
17033 BinaryOp::Div => a / b,
17034 BinaryOp::Max => a.max(b),
17035 BinaryOp::Min => a.min(b),
17036 BinaryOp::Pow => a.powf(b),
17037 }
17038}
17039
/// Sums the middle axis of an `[outer, reduced, inner]` row-major tensor.
///
/// Writes `out[o * inner + n] = Σ_r inp[(o * reduced + r) * inner + n]`. Each sum is
/// accumulated sequentially from `0.0` in increasing `r`, so a zero-length `reduced` axis
/// yields zeros.
fn reduce_sum_f64(inp: &[f64], out: &mut [f64], outer: usize, reduced: usize, inner: usize) {
    let block = reduced * inner;
    for o in 0..outer {
        let src = &inp[o * block..];
        let dst = &mut out[o * inner..(o + 1) * inner];
        for (n, slot) in dst.iter_mut().enumerate() {
            *slot = (0..reduced).fold(0.0_f64, |acc, r| acc + src[r * inner + n]);
        }
    }
}
17053
17054pub unsafe fn fill_rng_normal_arena(
17060 dst_off: usize,
17061 len: usize,
17062 mean: f32,
17063 scale: f32,
17064 key: u64,
17065 op_seed: Option<f32>,
17066 opts: rlx_ir::RngOptions,
17067 arena: *mut u8,
17068) {
17069 if len == 0 {
17070 return;
17071 }
17072 unsafe {
17073 let out = std::slice::from_raw_parts_mut((arena.add(dst_off)) as *mut f32, len);
17074 rlx_ir::fill_normal_like(out, mean, scale, opts, key, op_seed);
17075 }
17076}
17077
17078pub unsafe fn fill_rng_uniform_arena(
17079 dst_off: usize,
17080 len: usize,
17081 low: f32,
17082 high: f32,
17083 key: u64,
17084 op_seed: Option<f32>,
17085 opts: rlx_ir::RngOptions,
17086 arena: *mut u8,
17087) {
17088 if len == 0 {
17089 return;
17090 }
17091 unsafe {
17092 let out = std::slice::from_raw_parts_mut((arena.add(dst_off)) as *mut f32, len);
17093 rlx_ir::fill_uniform_like(out, low, high, opts, key, op_seed);
17094 }
17095}
17096
17097#[cfg(test)]
17098mod tests {
17099 use super::*;
17100 use rlx_ir::*;
17101
    /// Checks that a `Narrow` feeding `Rope` is fused by `compile_thunks`. The schedule must
    /// contain no `Narrow` thunk, and the `Rope` thunk must read rows with the parent's
    /// stride: 192, the size of `qkv`'s last axis. Only compilation is checked; nothing is
    /// executed.
    #[test]
    fn narrow_rope_fuses_in_unfused_path() {
        let f = DType::F32;
        let mut g = Graph::new("nr_fuse");
        let qkv = g.input("qkv", Shape::new(&[16, 8, 192], f));
        let cos = g.input("cos", Shape::new(&[16], f));
        let sin = g.input("sin", Shape::new(&[16], f));
        // Presumably (input, axis, start, len): the first 64 entries of axis 2.
        let q = g.narrow_(qkv, 2, 0, 64);
        let q_rope = g.rope(q, cos, sin, 16);
        g.set_outputs(vec![q_rope]);

        let plan = rlx_opt::memory::plan_memory(&g);
        let arena = crate::arena::Arena::from_plan(plan);
        let sched = compile_thunks(&g, &arena);

        // Count Narrow thunks and capture the Rope thunk's source row stride.
        let mut narrow_count = 0;
        let mut rope_with_stride: Option<u32> = None;
        for t in &sched.thunks {
            match t {
                Thunk::Narrow { .. } => narrow_count += 1,
                Thunk::Rope { src_row_stride, .. } => rope_with_stride = Some(*src_row_stride),
                _ => {}
            }
        }
        assert_eq!(
            narrow_count, 0,
            "Narrow→Rope fusion should leave zero Narrow thunks; saw {narrow_count}"
        );
        assert_eq!(
            rope_with_stride,
            Some(192),
            "Rope's src_row_stride should be 192 (parent qkv axis), saw {rope_with_stride:?}"
        );
    }
17145
    /// Compares the `selective_scan` thunk against a scalar reference recurrence.
    ///
    /// For each time step and channel, the reference updates a per-(channel, state) value
    /// `state = exp(delta * a) * state + delta * b * x`, and emits `y = Σ_n c * state`.
    /// The compiled graph must match it element-wise within 1e-3.
    #[test]
    fn ssm_selective_scan_matches_reference() {
        use rlx_ir::Philox4x32;
        let bch = 1usize;
        let s = 4usize;
        let h = 3usize;
        let n = 2usize;

        let mut rng = Philox4x32::new(13);
        let mut x = vec![0f32; bch * s * h];
        rng.fill_normal(&mut x);
        let mut delta = vec![0f32; bch * s * h];
        for v in delta.iter_mut() {
            // Small step sizes in [-0.05, 0.05).
            *v = (rng.next_f32() - 0.5) * 0.1;
        }
        let mut a = vec![0f32; h * n];
        for v in a.iter_mut() {
            // Strictly negative decay rates, so exp(delta * a) stays near 1.
            *v = -(rng.next_f32() * 0.5 + 0.1);
        }
        let mut b = vec![0f32; bch * s * n];
        rng.fill_normal(&mut b);
        let mut c = vec![0f32; bch * s * n];
        rng.fill_normal(&mut c);

        // Scalar reference: sequential scan over time with a fresh state per batch entry.
        let mut expected = vec![0f32; bch * s * h];
        for bi in 0..bch {
            let mut state = vec![0f32; h * n];
            for si in 0..s {
                for ci in 0..h {
                    let d = delta[bi * s * h + si * h + ci];
                    let xv = x[bi * s * h + si * h + ci];
                    let mut acc = 0f32;
                    for ni in 0..n {
                        let da = (d * a[ci * n + ni]).exp();
                        state[ci * n + ni] =
                            da * state[ci * n + ni] + d * b[bi * s * n + si * n + ni] * xv;
                        acc += c[bi * s * n + si * n + ni] * state[ci * n + ni];
                    }
                    expected[bi * s * h + si * h + ci] = acc;
                }
            }
        }

        let f = DType::F32;
        let mut g = Graph::new("ssm");
        let xn = g.input("x", Shape::new(&[bch, s, h], f));
        let dn = g.input("delta", Shape::new(&[bch, s, h], f));
        let an = g.param("a", Shape::new(&[h, n], f));
        let bn = g.param("b", Shape::new(&[bch, s, n], f));
        let cn = g.param("c", Shape::new(&[bch, s, n], f));
        let yn = g.selective_scan(xn, dn, an, bn, cn, n, Shape::new(&[bch, s, h], f));
        g.set_outputs(vec![yn]);

        let plan = rlx_opt::memory::plan_memory(&g);
        let mut arena = crate::arena::Arena::from_plan(plan);
        let sched = compile_thunks(&g, &arena);

        // Read every byte offset before taking the mutable buffer borrow below.
        let xn_off = arena.byte_offset(xn);
        let dn_off = arena.byte_offset(dn);
        let an_off = arena.byte_offset(an);
        let bn_off = arena.byte_offset(bn);
        let cn_off = arena.byte_offset(cn);
        let yn_off = arena.byte_offset(yn);
        let buf = arena.raw_buf_mut();
        unsafe {
            let copy = |dst: *mut f32, data: &[f32]| {
                for (i, &v) in data.iter().enumerate() {
                    *dst.add(i) = v;
                }
            };
            copy(buf.as_mut_ptr().add(xn_off) as *mut f32, &x);
            copy(buf.as_mut_ptr().add(dn_off) as *mut f32, &delta);
            copy(buf.as_mut_ptr().add(an_off) as *mut f32, &a);
            copy(buf.as_mut_ptr().add(bn_off) as *mut f32, &b);
            copy(buf.as_mut_ptr().add(cn_off) as *mut f32, &c);
        }
        execute_thunks(&sched, arena.raw_buf_mut());

        let actual: Vec<f32> = unsafe {
            let p = arena.raw_buf().as_ptr().add(yn_off) as *const f32;
            (0..bch * s * h).map(|i| *p.add(i)).collect()
        };

        for (i, (e, a)) in expected.iter().zip(&actual).enumerate() {
            assert!(
                (e - a).abs() < 1e-3,
                "mismatch at {i}: expected {e}, got {a}"
            );
        }
    }
17241
    /// Checks that a 1×1, stride-1, unpadded `Conv` compiles to the `Conv2D1x1` fast path
    /// rather than the general `Conv2D` thunk. Its output must match a scalar per-pixel
    /// channel-mixing reference within 1e-3.
    #[test]
    fn conv_1x1_fast_path_matches_scalar() {
        use rlx_ir::Philox4x32;
        let n = 2usize;
        let c_in = 4usize;
        let h = 3usize;
        let w = 3usize;
        let c_out = 5usize;
        let mut rng = Philox4x32::new(31);
        let mut x = vec![0f32; n * c_in * h * w];
        rng.fill_normal(&mut x);
        let mut weight = vec![0f32; c_out * c_in];
        rng.fill_normal(&mut weight);

        // Reference: out[n, co, hi, wi] = Σ_ci weight[co, ci] * x[n, ci, hi, wi].
        let mut expected = vec![0f32; n * c_out * h * w];
        for ni in 0..n {
            for co in 0..c_out {
                for hi in 0..h {
                    for wi in 0..w {
                        let mut acc = 0f32;
                        for ci in 0..c_in {
                            acc += weight[co * c_in + ci]
                                * x[((ni * c_in) + ci) * h * w + hi * w + wi];
                        }
                        expected[((ni * c_out) + co) * h * w + hi * w + wi] = acc;
                    }
                }
            }
        }

        let f = DType::F32;
        let mut g = Graph::new("conv1x1");
        let xn = g.input("x", Shape::new(&[n, c_in, h, w], f));
        let wn = g.param("w", Shape::new(&[c_out, c_in, 1, 1], f));
        let cn = g.add_node(
            rlx_ir::Op::Conv {
                kernel_size: vec![1, 1],
                stride: vec![1, 1],
                padding: vec![0, 0],
                dilation: vec![1, 1],
                groups: 1,
            },
            vec![xn, wn],
            Shape::new(&[n, c_out, h, w], f),
        );
        g.set_outputs(vec![cn]);

        let plan = rlx_opt::memory::plan_memory(&g);
        let mut arena = crate::arena::Arena::from_plan(plan);
        let sched = compile_thunks(&g, &arena);

        // Kernel selection: fast path present, general path absent.
        let saw_fast = sched
            .thunks
            .iter()
            .any(|t| matches!(t, Thunk::Conv2D1x1 { .. }));
        let saw_slow = sched
            .thunks
            .iter()
            .any(|t| matches!(t, Thunk::Conv2D { .. }));
        assert!(saw_fast, "1×1 conv should emit Conv2D1x1");
        assert!(!saw_slow, "1×1 conv must not fall through to scalar Conv2D");

        // Offsets are read before the mutable buffer borrow is taken.
        let xn_off = arena.byte_offset(xn);
        let wn_off = arena.byte_offset(wn);
        let cn_off = arena.byte_offset(cn);
        let buf = arena.raw_buf_mut();
        unsafe {
            let xp = buf.as_mut_ptr().add(xn_off) as *mut f32;
            for (i, &v) in x.iter().enumerate() {
                *xp.add(i) = v;
            }
            let wp = buf.as_mut_ptr().add(wn_off) as *mut f32;
            for (i, &v) in weight.iter().enumerate() {
                *wp.add(i) = v;
            }
        }
        execute_thunks(&sched, arena.raw_buf_mut());

        let actual: Vec<f32> = unsafe {
            let p = arena.raw_buf().as_ptr().add(cn_off) as *const f32;
            (0..(n * c_out * h * w)).map(|i| *p.add(i)).collect()
        };

        for (i, (e, a)) in expected.iter().zip(&actual).enumerate() {
            assert!(
                (e - a).abs() < 1e-3,
                "mismatch at {i}: expected {e}, got {a}"
            );
        }
    }
17340
    /// Checks `dequant_matmul` with `QuantScheme::Int8Block` against a dequantize-then-matmul
    /// reference. The weights are `i8` `[k, n]`. There is one `f32` scale per
    /// (block of `block_size` rows along `k`, output column), and all zero points are 0.
    #[test]
    fn dequant_matmul_int8_sym_matches_reference() {
        use rlx_ir::Philox4x32;
        use rlx_ir::quant::QuantScheme;

        let m = 3usize;
        let k = 8usize;
        let n = 4usize;
        let block_size = 4usize;
        let blocks_per_col = k / block_size;

        let mut rng = Philox4x32::new(99);
        let mut x = vec![0f32; m * k];
        rng.fill_normal(&mut x);
        // Deterministic i8 weights spread over roughly [-63, 63].
        let w_q: Vec<i8> = (0..(k * n))
            .map(|i| ((i as i32 * 13 + 7) % 127 - 63) as i8)
            .collect();
        let scales: Vec<f32> = (0..(blocks_per_col * n))
            .map(|i| 0.01 + 0.001 * i as f32)
            .collect();

        // Reference: dequantize w[p, j] = w_q[p, j] * scale[p / block_size, j], then x · w.
        let mut w_f32 = vec![0f32; k * n];
        for p in 0..k {
            let block = p / block_size;
            for j in 0..n {
                let s = scales[block * n + j];
                w_f32[p * n + j] = w_q[p * n + j] as f32 * s;
            }
        }
        let mut expected = vec![0f32; m * n];
        for i in 0..m {
            for j in 0..n {
                let mut acc = 0f32;
                for p in 0..k {
                    acc += x[i * k + p] * w_f32[p * n + j];
                }
                expected[i * n + j] = acc;
            }
        }

        let f = DType::F32;
        let mut g = Graph::new("dq");
        let xn = g.input("x", Shape::new(&[m, k], f));
        let wn = g.param("w", Shape::new(&[k, n], DType::I8));
        let sn = g.param("scale", Shape::new(&[blocks_per_col, n], f));
        let zn = g.param("zp", Shape::new(&[blocks_per_col, n], f));
        let dq = g.dequant_matmul(
            xn,
            wn,
            sn,
            zn,
            QuantScheme::Int8Block {
                block_size: block_size as u32,
            },
            Shape::new(&[m, n], f),
        );
        g.set_outputs(vec![dq]);

        let plan = rlx_opt::memory::plan_memory(&g);
        let mut arena = crate::arena::Arena::from_plan(plan);
        let sched = compile_thunks(&g, &arena);

        let xn_off = arena.byte_offset(xn);
        let wn_off = arena.byte_offset(wn);
        let sn_off = arena.byte_offset(sn);
        let zn_off = arena.byte_offset(zn);
        let dq_off = arena.byte_offset(dq);
        let buf = arena.raw_buf_mut();
        unsafe {
            let xp = buf.as_mut_ptr().add(xn_off) as *mut f32;
            for (i, &v) in x.iter().enumerate() {
                *xp.add(i) = v;
            }
            let sp = buf.as_mut_ptr().add(sn_off) as *mut f32;
            for (i, &v) in scales.iter().enumerate() {
                *sp.add(i) = v;
            }
            // Symmetric quantization: every zero point is 0.
            let zp = buf.as_mut_ptr().add(zn_off) as *mut f32;
            for i in 0..(blocks_per_col * n) {
                *zp.add(i) = 0.0;
            }
            // Weights are written byte-wise as i8.
            let wp = buf.as_mut_ptr().add(wn_off) as *mut i8;
            for (i, &v) in w_q.iter().enumerate() {
                *wp.add(i) = v;
            }
        }
        execute_thunks(&sched, arena.raw_buf_mut());

        let actual: Vec<f32> = unsafe {
            let p = arena.raw_buf().as_ptr().add(dq_off) as *const f32;
            (0..m * n).map(|i| *p.add(i)).collect()
        };

        for (i, (e, a)) in expected.iter().zip(&actual).enumerate() {
            assert!(
                (e - a).abs() < 1e-3,
                "mismatch at {i}: expected {e}, got {a}"
            );
        }
    }
17448
    /// Checks the fused `lora_matmul` node against the unfused expression
    /// `x·w + scale · (x·a)·b`, computed with a naive row-major matmul, within 1e-3.
    #[test]
    fn lora_matmul_matches_unfused_reference() {
        use rlx_ir::Philox4x32;

        let m = 4usize;
        let k = 8usize;
        let n = 6usize;
        let r = 2usize;
        let scale = 0.5f32;

        let mut rng = Philox4x32::new(42);
        let mut x = vec![0f32; m * k];
        rng.fill_normal(&mut x);
        let mut w = vec![0f32; k * n];
        rng.fill_normal(&mut w);
        let mut a = vec![0f32; k * r];
        rng.fill_normal(&mut a);
        let mut b = vec![0f32; r * n];
        rng.fill_normal(&mut b);

        // Naive [rows, inner] · [inner, cols] matmul, row-major.
        let naive = |a_buf: &[f32], b_buf: &[f32], rows: usize, inner: usize, cols: usize| {
            let mut o = vec![0f32; rows * cols];
            for i in 0..rows {
                for j in 0..cols {
                    let mut acc = 0f32;
                    for p in 0..inner {
                        acc += a_buf[i * inner + p] * b_buf[p * cols + j];
                    }
                    o[i * cols + j] = acc;
                }
            }
            o
        };
        let xw = naive(&x, &w, m, k, n);
        let xa = naive(&x, &a, m, k, r);
        let xab = naive(&xa, &b, m, r, n);
        let mut expected = xw;
        for i in 0..(m * n) {
            expected[i] += scale * xab[i];
        }

        let f = DType::F32;
        let mut g = Graph::new("lora");
        let xn = g.input("x", Shape::new(&[m, k], f));
        let wn = g.param("w", Shape::new(&[k, n], f));
        let an = g.param("a", Shape::new(&[k, r], f));
        let bn = g.param("b", Shape::new(&[r, n], f));
        let lm = g.lora_matmul(xn, wn, an, bn, scale, Shape::new(&[m, n], f));
        g.set_outputs(vec![lm]);

        let plan = rlx_opt::memory::plan_memory(&g);
        let mut arena = crate::arena::Arena::from_plan(plan);
        let sched = compile_thunks(&g, &arena);

        let xn_off = arena.byte_offset(xn);
        let wn_off = arena.byte_offset(wn);
        let an_off = arena.byte_offset(an);
        let bn_off = arena.byte_offset(bn);
        let lm_off = arena.byte_offset(lm);
        let buf = arena.raw_buf_mut();
        unsafe {
            let copy = |dst: *mut f32, data: &[f32]| {
                for (i, &v) in data.iter().enumerate() {
                    *dst.add(i) = v;
                }
            };
            copy(buf.as_mut_ptr().add(xn_off) as *mut f32, &x);
            copy(buf.as_mut_ptr().add(wn_off) as *mut f32, &w);
            copy(buf.as_mut_ptr().add(an_off) as *mut f32, &a);
            copy(buf.as_mut_ptr().add(bn_off) as *mut f32, &b);
        }
        execute_thunks(&sched, arena.raw_buf_mut());

        let actual: Vec<f32> = unsafe {
            let p = arena.raw_buf().as_ptr().add(lm_off) as *const f32;
            (0..m * n).map(|i| *p.add(i)).collect()
        };

        for (i, (e, a)) in expected.iter().zip(&actual).enumerate() {
            assert!(
                (e - a).abs() < 1e-3,
                "mismatch at {i}: expected {e}, got {a}"
            );
        }
    }
17538
    /// Checks that sampling with the settings below picks the largest logit (index 5, value
    /// 9.0). NOTE(review): per the test name, the `1e-3` argument is presumably a near-zero
    /// temperature; confirm against `Graph::sample`'s parameter order. The sampled token is
    /// stored as an `f32` and cast to `usize`.
    #[test]
    fn sample_temperature_zero_is_argmax() {
        let f = DType::F32;
        let mut g = Graph::new("samp");
        let logits = g.input("logits", Shape::new(&[1, 8], f));
        let s = g.sample(logits, 0, 1.0, 1e-3, 42, Shape::new(&[1], f));
        g.set_outputs(vec![s]);
        let plan = rlx_opt::memory::plan_memory(&g);
        let mut arena = crate::arena::Arena::from_plan(plan);
        let sched = compile_thunks(&g, &arena);

        let logits_off = arena.byte_offset(logits);
        let s_off = arena.byte_offset(s);
        let buf = arena.raw_buf_mut();
        unsafe {
            let p = buf.as_mut_ptr().add(logits_off) as *mut f32;
            // One clearly dominant logit at index 5.
            let inputs = [0.1f32, 0.2, 0.3, 0.4, 0.5, 9.0, 0.7, 0.8];
            for (i, &v) in inputs.iter().enumerate() {
                *p.add(i) = v;
            }
        }
        execute_thunks(&sched, arena.raw_buf_mut());

        let token = unsafe {
            let p = arena.raw_buf().as_ptr().add(s_off) as *const f32;
            *p as usize
        };
        assert_eq!(token, 5, "low-temp sampling should pick the argmax");
    }
17572
    /// Checks that sampling with a single candidate always returns the largest logit
    /// (index 1, value 5.0). NOTE(review): per the test name, the `1` argument is presumably
    /// top-k = 1; confirm against `Graph::sample`'s parameter order.
    #[test]
    fn sample_top_k_one_is_deterministic() {
        let f = DType::F32;
        let mut g = Graph::new("samp_k1");
        let logits = g.input("logits", Shape::new(&[1, 4], f));
        let s = g.sample(logits, 1, 1.0, 1.0, 7, Shape::new(&[1], f));
        g.set_outputs(vec![s]);
        let plan = rlx_opt::memory::plan_memory(&g);
        let mut arena = crate::arena::Arena::from_plan(plan);
        let sched = compile_thunks(&g, &arena);

        let logits_off = arena.byte_offset(logits);
        let s_off = arena.byte_offset(s);
        let buf = arena.raw_buf_mut();
        unsafe {
            let p = buf.as_mut_ptr().add(logits_off) as *mut f32;
            let inputs = [0.1f32, 5.0, 0.3, 0.4];
            for (i, &v) in inputs.iter().enumerate() {
                *p.add(i) = v;
            }
        }
        execute_thunks(&sched, arena.raw_buf_mut());
        // The sampled index is stored as an f32.
        let token = unsafe {
            let p = arena.raw_buf().as_ptr().add(s_off) as *const f32;
            *p as usize
        };
        assert_eq!(token, 1);
    }
17602
    /// Checks an inclusive cumulative sum along the last axis of a `[2, 4]` tensor. Each
    /// row accumulates independently. NOTE(review): `-1` is presumably the axis and `false`
    /// presumably disables exclusive mode; confirm against `Graph::cumsum`.
    #[test]
    fn cumsum_inclusive_matches_naive() {
        let f = DType::F32;
        let mut g = Graph::new("cumsum");
        let x = g.input("x", Shape::new(&[2, 4], f));
        let cs = g.cumsum(x, -1, false, Shape::new(&[2, 4], f));
        g.set_outputs(vec![cs]);
        let plan = rlx_opt::memory::plan_memory(&g);
        let mut arena = crate::arena::Arena::from_plan(plan);
        let sched = compile_thunks(&g, &arena);

        let x_off = arena.byte_offset(x);
        let out_off = arena.byte_offset(cs);
        let buf = arena.raw_buf_mut();
        unsafe {
            let p = buf.as_mut_ptr().add(x_off) as *mut f32;
            let inputs = [1.0f32, 2.0, 3.0, 4.0, 10.0, 20.0, 30.0, 40.0];
            for (i, &v) in inputs.iter().enumerate() {
                *p.add(i) = v;
            }
        }
        execute_thunks(&sched, arena.raw_buf_mut());

        let out: Vec<f32> = unsafe {
            let p = arena.raw_buf().as_ptr().add(out_off) as *const f32;
            (0..8).map(|i| *p.add(i)).collect()
        };
        // Row 0: 1, 1+2, 1+2+3, ...; row 1 restarts from 10.
        assert_eq!(out, vec![1.0, 3.0, 6.0, 10.0, 10.0, 30.0, 60.0, 100.0]);
    }
17634
17635 #[test]
17639 fn narrow_attention_fuses_in_unfused_path() {
17640 let f = DType::F32;
17641 let mut g = Graph::new("nattn_fuse");
17642 let qkv = g.input("qkv", Shape::new(&[8, 16, 192], f)); let mask = g.input("mask", Shape::new(&[8, 16], f));
17645 let q = g.narrow_(qkv, 2, 0, 64);
17646 let k = g.narrow_(qkv, 2, 64, 64);
17647 let v = g.narrow_(qkv, 2, 128, 64);
17648 let attn = g.attention(q, k, v, mask, 4, 16, Shape::new(&[8, 16, 64], f));
17649 g.set_outputs(vec![attn]);
17650
17651 let plan = rlx_opt::memory::plan_memory(&g);
17652 let arena = crate::arena::Arena::from_plan(plan);
17653 let sched = compile_thunks(&g, &arena);
17654
17655 let mut narrow_count = 0;
17656 let mut attn_strides: Option<(u32, u32, u32)> = None;
17657 for t in &sched.thunks {
17658 match t {
17659 Thunk::Narrow { .. } => narrow_count += 1,
17660 Thunk::Attention {
17661 q_row_stride,
17662 k_row_stride,
17663 v_row_stride,
17664 ..
17665 } => attn_strides = Some((*q_row_stride, *k_row_stride, *v_row_stride)),
17666 _ => {}
17667 }
17668 }
17669 assert_eq!(
17672 narrow_count, 0,
17673 "Narrow×3→Attention fusion should eliminate all 3 narrows; saw {narrow_count}"
17674 );
17675 assert_eq!(
17676 attn_strides,
17677 Some((192, 192, 192)),
17678 "Attention should walk Q/K/V with parent row stride 192"
17679 );
17680 }
17681
17682 fn run_graph(
17693 g: &Graph,
17694 inputs: &[(NodeId, &[f32])],
17695 out_id: NodeId,
17696 out_len: usize,
17697 ) -> Vec<f32> {
17698 let plan = rlx_opt::memory::plan_memory(g);
17699 let mut arena = crate::arena::Arena::from_plan(plan);
17700 let sched = compile_thunks(g, &arena);
17701 for &(id, data) in inputs {
17702 let off = arena.byte_offset(id);
17703 let buf = arena.raw_buf_mut();
17704 unsafe {
17705 let p = buf.as_mut_ptr().add(off) as *mut f32;
17706 for (i, &v) in data.iter().enumerate() {
17707 *p.add(i) = v;
17708 }
17709 }
17710 }
17711 execute_thunks(&sched, arena.raw_buf_mut());
17712 let off = arena.byte_offset(out_id);
17713 unsafe {
17714 let p = arena.raw_buf().as_ptr().add(off) as *const f32;
17715 (0..out_len).map(|i| *p.add(i)).collect()
17716 }
17717 }
17718
17719 #[test]
17720 fn relu_backward_matches_mask() {
17721 let f = DType::F32;
17722 let len = 7usize;
17723 let x: Vec<f32> = vec![-2.0, -0.1, 0.0, 0.1, 1.0, 3.0, -5.0];
17724 let dy: Vec<f32> = vec![0.5, 1.5, 2.5, -0.7, 4.0, -1.0, 9.0];
17725
17726 let mut g = Graph::new("relu_bw");
17727 let xn = g.input("x", Shape::new(&[len], f));
17728 let dyn_ = g.input("dy", Shape::new(&[len], f));
17729 let dx = g.relu_backward(xn, dyn_);
17730 g.set_outputs(vec![dx]);
17731
17732 let actual = run_graph(&g, &[(xn, &x), (dyn_, &dy)], dx, len);
17733 let expected: Vec<f32> = x
17737 .iter()
17738 .zip(&dy)
17739 .map(|(&xi, &dyi)| if xi > 0.0 { dyi } else { 0.0 })
17740 .collect();
17741 for (a, e) in actual.iter().zip(&expected) {
17742 assert!((a - e).abs() < 1e-6, "relu_bw mismatch: {a} vs {e}");
17743 }
17744 }
17745
17746 #[test]
17747 fn maxpool2d_backward_routes_to_argmax() {
17748 let f = DType::F32;
17749 let x: Vec<f32> = vec![
17751 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0, 16.0,
17752 ];
17753 let dy: Vec<f32> = vec![0.5, 1.0, 2.0, 4.0];
17757
17758 let mut g = Graph::new("maxpool_bw");
17759 let xn = g.input("x", Shape::new(&[1, 1, 4, 4], f));
17760 let dyn_ = g.input("dy", Shape::new(&[1, 1, 2, 2], f));
17761 let dx = g.maxpool2d_backward(xn, dyn_, vec![2, 2], vec![2, 2], vec![0, 0]);
17762 g.set_outputs(vec![dx]);
17763
17764 let actual = run_graph(&g, &[(xn, &x), (dyn_, &dy)], dx, 16);
17765 let mut expected = vec![0f32; 16];
17766 expected[5] = 0.5;
17767 expected[7] = 1.0;
17768 expected[13] = 2.0;
17769 expected[15] = 4.0;
17770 for (i, (a, e)) in actual.iter().zip(&expected).enumerate() {
17771 assert!((a - e).abs() < 1e-6, "maxpool_bw[{i}] mismatch: {a} vs {e}");
17772 }
17773 }
17774
17775 #[test]
17776 fn conv2d_backward_input_matches_numerical_gradient() {
17777 use rlx_ir::Philox4x32;
17778 let n = 1usize;
17781 let c_in = 2usize;
17782 let h = 4usize;
17783 let w = 4usize;
17784 let c_out = 3usize;
17785 let kh = 3usize;
17786 let kw = 3usize;
17787 let ph = 1usize;
17788 let pw = 1usize;
17789 let sh = 1usize;
17790 let sw = 1usize;
17791 let h_out = (h + 2 * ph - kh) / sh + 1;
17793 let w_out = (w + 2 * pw - kw) / sw + 1;
17794 assert_eq!(h_out, 4);
17795 assert_eq!(w_out, 4);
17796
17797 let mut rng = Philox4x32::new(7);
17798 let mut x = vec![0f32; n * c_in * h * w];
17799 rng.fill_normal(&mut x);
17800 let mut wt = vec![0f32; c_out * c_in * kh * kw];
17801 rng.fill_normal(&mut wt);
17802 let mut dy = vec![0f32; n * c_out * h_out * w_out];
17803 rng.fill_normal(&mut dy);
17804
17805 let f = DType::F32;
17807 let mut g = Graph::new("conv_bwi");
17808 let dy_in = g.input("dy", Shape::new(&[n, c_out, h_out, w_out], f));
17809 let w_in = g.input("w", Shape::new(&[c_out, c_in, kh, kw], f));
17810 let dx = g.conv2d_backward_input(
17811 dy_in,
17812 w_in,
17813 Shape::new(&[n, c_in, h, w], f),
17814 vec![kh, kw],
17815 vec![sh, sw],
17816 vec![ph, pw],
17817 vec![1, 1],
17818 1,
17819 );
17820 g.set_outputs(vec![dx]);
17821 let analytical = run_graph(&g, &[(dy_in, &dy), (w_in, &wt)], dx, n * c_in * h * w);
17822
17823 let forward = |x: &[f32]| -> Vec<f32> {
17827 let mut out = vec![0f32; n * c_out * h_out * w_out];
17828 for ni in 0..n {
17829 for co in 0..c_out {
17830 for ho in 0..h_out {
17831 for wo in 0..w_out {
17832 let mut acc = 0f32;
17833 for ci in 0..c_in {
17834 for ki in 0..kh {
17835 for kj in 0..kw {
17836 let hi = ho * sh + ki;
17837 let wi = wo * sw + kj;
17838 if hi < ph || wi < pw {
17839 continue;
17840 }
17841 let hi = hi - ph;
17842 let wi = wi - pw;
17843 if hi >= h || wi >= w {
17844 continue;
17845 }
17846 let xv = x[((ni * c_in) + ci) * h * w + hi * w + wi];
17847 let wv = wt[((co * c_in) + ci) * kh * kw + ki * kw + kj];
17848 acc += xv * wv;
17849 }
17850 }
17851 }
17852 out[((ni * c_out) + co) * h_out * w_out + ho * w_out + wo] = acc;
17853 }
17854 }
17855 }
17856 }
17857 out
17858 };
17859 let dot = |a: &[f32], b: &[f32]| -> f32 { a.iter().zip(b).map(|(&u, &v)| u * v).sum() };
17860 let eps = 1e-3f32;
17861 let mut numerical = vec![0f32; x.len()];
17862 for i in 0..x.len() {
17863 let saved = x[i];
17864 x[i] = saved + eps;
17865 let plus = dot(&forward(&x), &dy);
17866 x[i] = saved - eps;
17867 let minus = dot(&forward(&x), &dy);
17868 x[i] = saved;
17869 numerical[i] = (plus - minus) / (2.0 * eps);
17870 }
17871 for (i, (a, n)) in analytical.iter().zip(&numerical).enumerate() {
17872 assert!(
17874 (a - n).abs() < 5e-3,
17875 "conv_bw_input[{i}]: analytical {a} vs numerical {n}"
17876 );
17877 }
17878 }
17879
17880 #[test]
17881 fn conv2d_backward_weight_matches_numerical_gradient() {
17882 use rlx_ir::Philox4x32;
17883 let n = 2usize;
17884 let c_in = 2usize;
17885 let h = 4usize;
17886 let w = 4usize;
17887 let c_out = 2usize;
17888 let kh = 3usize;
17889 let kw = 3usize;
17890 let ph = 0usize;
17891 let pw = 0usize;
17892 let sh = 1usize;
17893 let sw = 1usize;
17894 let h_out = (h + 2 * ph - kh) / sh + 1;
17895 let w_out = (w + 2 * pw - kw) / sw + 1;
17896
17897 let mut rng = Philox4x32::new(11);
17898 let mut x = vec![0f32; n * c_in * h * w];
17899 rng.fill_normal(&mut x);
17900 let mut wt = vec![0f32; c_out * c_in * kh * kw];
17901 rng.fill_normal(&mut wt);
17902 let mut dy = vec![0f32; n * c_out * h_out * w_out];
17903 rng.fill_normal(&mut dy);
17904
17905 let f = DType::F32;
17906 let mut g = Graph::new("conv_bww");
17907 let xn = g.input("x", Shape::new(&[n, c_in, h, w], f));
17908 let dyn_ = g.input("dy", Shape::new(&[n, c_out, h_out, w_out], f));
17909 let dwn = g.conv2d_backward_weight(
17910 xn,
17911 dyn_,
17912 Shape::new(&[c_out, c_in, kh, kw], f),
17913 vec![kh, kw],
17914 vec![sh, sw],
17915 vec![ph, pw],
17916 vec![1, 1],
17917 1,
17918 );
17919 g.set_outputs(vec![dwn]);
17920 let analytical = run_graph(&g, &[(xn, &x), (dyn_, &dy)], dwn, c_out * c_in * kh * kw);
17921
17922 let forward = |wt: &[f32]| -> Vec<f32> {
17923 let mut out = vec![0f32; n * c_out * h_out * w_out];
17924 for ni in 0..n {
17925 for co in 0..c_out {
17926 for ho in 0..h_out {
17927 for wo in 0..w_out {
17928 let mut acc = 0f32;
17929 for ci in 0..c_in {
17930 for ki in 0..kh {
17931 for kj in 0..kw {
17932 let hi = ho + ki;
17933 let wi = wo + kj;
17934 let xv = x[((ni * c_in) + ci) * h * w + hi * w + wi];
17935 let wv = wt[((co * c_in) + ci) * kh * kw + ki * kw + kj];
17936 acc += xv * wv;
17937 }
17938 }
17939 }
17940 out[((ni * c_out) + co) * h_out * w_out + ho * w_out + wo] = acc;
17941 }
17942 }
17943 }
17944 }
17945 out
17946 };
17947 let dot = |a: &[f32], b: &[f32]| -> f32 { a.iter().zip(b).map(|(&u, &v)| u * v).sum() };
17948 let eps = 1e-3f32;
17949 let mut numerical = vec![0f32; wt.len()];
17950 for i in 0..wt.len() {
17951 let saved = wt[i];
17952 wt[i] = saved + eps;
17953 let plus = dot(&forward(&wt), &dy);
17954 wt[i] = saved - eps;
17955 let minus = dot(&forward(&wt), &dy);
17956 wt[i] = saved;
17957 numerical[i] = (plus - minus) / (2.0 * eps);
17958 }
17959 for (i, (a, n)) in analytical.iter().zip(&numerical).enumerate() {
17960 assert!(
17961 (a - n).abs() < 5e-3,
17962 "conv_bw_weight[{i}]: analytical {a} vs numerical {n}"
17963 );
17964 }
17965 }
17966
17967 #[test]
17968 fn softmax_cross_entropy_matches_reference() {
17969 let f = DType::F32;
17970 let logits: Vec<f32> = vec![
17971 1.0, 2.0, 3.0, -1.0, 0.0, 4.0, 5.0, 5.0, 5.0, ];
17975 let labels: Vec<f32> = vec![2.0, 0.0, 1.0];
17976
17977 let mut g = Graph::new("sce");
17978 let lg = g.input("logits", Shape::new(&[3, 3], f));
17979 let lb = g.input("labels", Shape::new(&[3], f));
17980 let loss = g.softmax_cross_entropy_with_logits(lg, lb);
17981 g.set_outputs(vec![loss]);
17982 let actual = run_graph(&g, &[(lg, &logits), (lb, &labels)], loss, 3);
17983
17984 let mut expected = vec![0f32; 3];
17986 for ni in 0..3 {
17987 let row = &logits[ni * 3..(ni + 1) * 3];
17988 let m = row.iter().fold(f32::NEG_INFINITY, |a, &v| a.max(v));
17989 let sum: f32 = row.iter().map(|&v| (v - m).exp()).sum();
17990 let lse = m + sum.ln();
17991 let label_idx = labels[ni] as usize;
17992 expected[ni] = lse - row[label_idx];
17993 }
17994 for (i, (a, e)) in actual.iter().zip(&expected).enumerate() {
17995 assert!((a - e).abs() < 1e-5, "sce loss[{i}]: {a} vs {e}");
17996 }
17997 }
17998
17999 #[test]
18000 fn softmax_cross_entropy_backward_matches_numerical_gradient() {
18001 use rlx_ir::Philox4x32;
18002 let n = 4usize;
18003 let c = 5usize;
18004 let mut rng = Philox4x32::new(23);
18005 let mut logits = vec![0f32; n * c];
18006 rng.fill_normal(&mut logits);
18007 let labels: Vec<f32> = (0..n).map(|i| (i % c) as f32).collect();
18008 let mut d_loss = vec![0f32; n];
18009 rng.fill_normal(&mut d_loss);
18010
18011 let f = DType::F32;
18012 let mut g = Graph::new("sce_bw");
18013 let lg = g.input("logits", Shape::new(&[n, c], f));
18014 let lb = g.input("labels", Shape::new(&[n], f));
18015 let dl = g.input("d_loss", Shape::new(&[n], f));
18016 let dlogits = g.softmax_cross_entropy_backward(lg, lb, dl);
18017 g.set_outputs(vec![dlogits]);
18018 let analytical = run_graph(
18019 &g,
18020 &[(lg, &logits), (lb, &labels), (dl, &d_loss)],
18021 dlogits,
18022 n * c,
18023 );
18024
18025 let sce_loss = |logits: &[f32]| -> Vec<f32> {
18027 let mut out = vec![0f32; n];
18028 for ni in 0..n {
18029 let row = &logits[ni * c..(ni + 1) * c];
18030 let m = row.iter().fold(f32::NEG_INFINITY, |a, &v| a.max(v));
18031 let sum: f32 = row.iter().map(|&v| (v - m).exp()).sum();
18032 out[ni] = (m + sum.ln()) - row[labels[ni] as usize];
18033 }
18034 out
18035 };
18036 let dot = |a: &[f32], b: &[f32]| a.iter().zip(b).map(|(&u, &v)| u * v).sum::<f32>();
18037 let eps = 1e-3f32;
18038 let mut numerical = vec![0f32; logits.len()];
18039 for i in 0..logits.len() {
18040 let saved = logits[i];
18041 logits[i] = saved + eps;
18042 let plus = dot(&sce_loss(&logits), &d_loss);
18043 logits[i] = saved - eps;
18044 let minus = dot(&sce_loss(&logits), &d_loss);
18045 logits[i] = saved;
18046 numerical[i] = (plus - minus) / (2.0 * eps);
18047 }
18048 for (i, (a, num)) in analytical.iter().zip(&numerical).enumerate() {
18049 assert!(
18050 (a - num).abs() < 5e-3,
18051 "sce_bw[{i}]: analytical {a} vs numerical {num}"
18052 );
18053 }
18054 }
18055
18056 fn fill_constants_into_arena(graph: &Graph, arena: &mut crate::arena::Arena) {
18069 for node in graph.nodes() {
18070 if let Op::Constant { data } = &node.op
18071 && arena.has_buffer(node.id)
18072 && !data.is_empty()
18073 {
18074 let buf = arena.slice_mut(node.id);
18075 let n_floats = data.len() / 4;
18076 let n = buf.len().min(n_floats);
18077 for i in 0..n {
18078 let bytes = [
18079 data[i * 4],
18080 data[i * 4 + 1],
18081 data[i * 4 + 2],
18082 data[i * 4 + 3],
18083 ];
18084 buf[i] = f32::from_le_bytes(bytes);
18085 }
18086 }
18087 }
18088 }
18089
18090 fn prepare(
18094 graph: &Graph,
18095 seed_inputs: &[(NodeId, &[f32])],
18096 ) -> (ThunkSchedule, crate::arena::Arena) {
18097 let plan = rlx_opt::memory::plan_memory(graph);
18098 let mut arena = crate::arena::Arena::from_plan(plan);
18099 let sched = compile_thunks(graph, &arena);
18100 fill_constants_into_arena(graph, &mut arena);
18101 for &(id, data) in seed_inputs {
18102 let off = arena.byte_offset(id);
18103 let buf = arena.raw_buf_mut();
18104 unsafe {
18105 let p = buf.as_mut_ptr().add(off) as *mut f32;
18106 for (i, &v) in data.iter().enumerate() {
18107 *p.add(i) = v;
18108 }
18109 }
18110 }
18111 (sched, arena)
18112 }
18113
18114 fn read_arena(arena: &crate::arena::Arena, id: NodeId, len: usize) -> Vec<f32> {
18115 let off = arena.byte_offset(id);
18116 unsafe {
18117 let p = arena.raw_buf().as_ptr().add(off) as *const f32;
18118 (0..len).map(|i| *p.add(i)).collect()
18119 }
18120 }
18121
18122 fn write_arena(arena: &mut crate::arena::Arena, id: NodeId, data: &[f32]) {
18123 let off = arena.byte_offset(id);
18124 let buf = arena.raw_buf_mut();
18125 unsafe {
18126 let p = buf.as_mut_ptr().add(off) as *mut f32;
18127 for (i, &v) in data.iter().enumerate() {
18128 *p.add(i) = v;
18129 }
18130 }
18131 }
18132
18133 fn prepare_f64(
18135 graph: &Graph,
18136 seed_inputs: &[(NodeId, &[f64])],
18137 ) -> (ThunkSchedule, crate::arena::Arena) {
18138 let plan = rlx_opt::memory::plan_memory(graph);
18139 let mut arena = crate::arena::Arena::from_plan(plan);
18140 let sched = compile_thunks(graph, &arena);
18141 fill_constants_into_arena(graph, &mut arena);
18142 for &(id, data) in seed_inputs {
18143 let off = arena.byte_offset(id);
18144 let buf = arena.raw_buf_mut();
18145 unsafe {
18146 let p = buf.as_mut_ptr().add(off) as *mut f64;
18147 for (i, &v) in data.iter().enumerate() {
18148 *p.add(i) = v;
18149 }
18150 }
18151 }
18152 (sched, arena)
18153 }
18154
18155 fn read_arena_f64(arena: &crate::arena::Arena, id: NodeId, len: usize) -> Vec<f64> {
18156 let off = arena.byte_offset(id);
18157 unsafe {
18158 let p = arena.raw_buf().as_ptr().add(off) as *const f64;
18159 (0..len).map(|i| *p.add(i)).collect()
18160 }
18161 }
18162
18163 #[test]
18173 fn dense_solve_f64_end_to_end() {
18174 let mut g = Graph::new("solve_e2e");
18175 let a = g.input("A", Shape::new(&[2, 2], DType::F64));
18176 let b = g.input("b", Shape::new(&[2], DType::F64));
18177 let x = g.dense_solve(a, b, Shape::new(&[2], DType::F64));
18178 g.set_outputs(vec![x]);
18179
18180 let a_data = [2.0, 1.0, 1.0, 3.0_f64];
18181 let b_data = [5.0, 10.0_f64];
18182 let (sched, mut arena) = prepare_f64(&g, &[(a, &a_data), (b, &b_data)]);
18183 execute_thunks(&sched, arena.raw_buf_mut());
18184
18185 let got = read_arena_f64(&arena, x, 2);
18186 let want = [1.0, 3.0_f64];
18187 for i in 0..2 {
18188 assert!(
18189 (got[i] - want[i]).abs() < 1e-12,
18190 "x[{i}] = {} (expected {})",
18191 got[i],
18192 want[i]
18193 );
18194 }
18195 }
18196
18197 #[test]
18203 fn dense_solve_f64_5x5_laplacian() {
18204 let n = 5usize;
18205 let mut g = Graph::new("solve_5x5");
18206 let a = g.input("A", Shape::new(&[n, n], DType::F64));
18207 let b = g.input("b", Shape::new(&[n], DType::F64));
18208 let x = g.dense_solve(a, b, Shape::new(&[n], DType::F64));
18209 g.set_outputs(vec![x]);
18210
18211 let mut a_data = vec![0.0_f64; n * n];
18213 for i in 0..n {
18214 a_data[i * n + i] = 2.0;
18215 if i > 0 {
18216 a_data[i * n + (i - 1)] = -1.0;
18217 }
18218 if i + 1 < n {
18219 a_data[i * n + (i + 1)] = -1.0;
18220 }
18221 }
18222 let b_data: Vec<f64> = (0..n).map(|i| (i + 1) as f64).collect();
18223 let (sched, mut arena) = prepare_f64(&g, &[(a, &a_data), (b, &b_data)]);
18224 execute_thunks(&sched, arena.raw_buf_mut());
18225
18226 let got = read_arena_f64(&arena, x, n);
18227 let mut residual = vec![0.0_f64; n];
18229 for i in 0..n {
18230 for j in 0..n {
18231 residual[i] += a_data[i * n + j] * got[j];
18232 }
18233 }
18234 for i in 0..n {
18235 assert!(
18236 (residual[i] - b_data[i]).abs() < 1e-10,
18237 "row {i}: residual {} vs b {}",
18238 residual[i],
18239 b_data[i]
18240 );
18241 }
18242 }
18243
18244 #[test]
18263 fn hello_resistor_gradient_end_to_end() {
18264 use rlx_opt::autodiff::grad_with_loss;
18265 let n = 3usize;
18266
18267 let mut g = Graph::new("hello_resistor");
18269 let a = g.param("A", Shape::new(&[n, n], DType::F64));
18270 let b = g.input("b", Shape::new(&[n], DType::F64));
18271 let x = g.dense_solve(a, b, Shape::new(&[n], DType::F64));
18272 let loss = g.reduce(
18273 x,
18274 ReduceOp::Sum,
18275 vec![0],
18276 false,
18277 Shape::new(&[1], DType::F64),
18278 );
18279 g.set_outputs(vec![loss]);
18280
18281 let bwd = grad_with_loss(&g, &[a, b]);
18283 assert_eq!(bwd.outputs.len(), 3, "expect [loss, dA, db]");
18284
18285 let find_by_name = |graph: &Graph, want: &str| -> NodeId {
18289 for node in graph.nodes() {
18290 let name = match &node.op {
18291 rlx_ir::Op::Input { name } => Some(name.as_str()),
18292 rlx_ir::Op::Param { name } => Some(name.as_str()),
18293 _ => None,
18294 };
18295 if name == Some(want) {
18296 return node.id;
18297 }
18298 }
18299 panic!("no node named {want:?} in bwd graph");
18300 };
18301 let a_bwd = find_by_name(&bwd, "A");
18302 let b_bwd = find_by_name(&bwd, "b");
18303 let d_out_bwd = find_by_name(&bwd, "d_output");
18304
18305 let a_data = [2.0, 1.0, 0.0, 1.0, 3.0, 1.0, 0.0, 1.0, 2.0_f64];
18309 let b_data = [1.0, 2.0, 3.0_f64];
18310 let d_output = [1.0_f64]; let (sched, mut arena) = prepare_f64(
18314 &bwd,
18315 &[(a_bwd, &a_data), (b_bwd, &b_data), (d_out_bwd, &d_output)],
18316 );
18317 execute_thunks(&sched, arena.raw_buf_mut());
18318
18319 let loss_out = read_arena_f64(&arena, bwd.outputs[0], 1);
18320 let da_out = read_arena_f64(&arena, bwd.outputs[1], n * n);
18321 let db_out = read_arena_f64(&arena, bwd.outputs[2], n);
18322
18323 let x_ref = {
18326 let mut a = a_data;
18327 let mut b = b_data;
18328 let info = crate::blas::dgesv(&mut a, &mut b, n, 1);
18329 assert_eq!(info, 0);
18330 b
18331 };
18332 let loss_ref: f64 = x_ref.iter().sum();
18333 let db_ref = {
18335 let mut at = [0.0_f64; 9];
18336 for i in 0..n {
18337 for j in 0..n {
18338 at[i * n + j] = a_data[j * n + i];
18339 }
18340 }
18341 let mut ones = [1.0_f64; 3];
18342 let info = crate::blas::dgesv(&mut at, &mut ones, n, 1);
18343 assert_eq!(info, 0);
18344 ones
18345 };
18346 let mut da_ref = [0.0_f64; 9];
18348 for i in 0..n {
18349 for j in 0..n {
18350 da_ref[i * n + j] = -db_ref[i] * x_ref[j];
18351 }
18352 }
18353
18354 assert!(
18356 (loss_out[0] - loss_ref).abs() < 1e-10,
18357 "loss: got {}, want {}",
18358 loss_out[0],
18359 loss_ref
18360 );
18361 for i in 0..n {
18362 assert!(
18363 (db_out[i] - db_ref[i]).abs() < 1e-10,
18364 "db[{i}]: got {}, want {}",
18365 db_out[i],
18366 db_ref[i]
18367 );
18368 }
18369 for i in 0..n * n {
18370 assert!(
18371 (da_out[i] - da_ref[i]).abs() < 1e-10,
18372 "dA[{i}]: got {}, want {}",
18373 da_out[i],
18374 da_ref[i]
18375 );
18376 }
18377
18378 let h = 1e-6_f64;
18381 for k in 0..n {
18382 let mut bp = b_data;
18383 bp[k] += h;
18384 let mut bm = b_data;
18385 bm[k] -= h;
18386 let lp = {
18387 let mut ac = a_data;
18388 let info = crate::blas::dgesv(&mut ac, &mut bp, n, 1);
18389 assert_eq!(info, 0);
18390 bp.iter().sum::<f64>()
18391 };
18392 let lm = {
18393 let mut ac = a_data;
18394 let info = crate::blas::dgesv(&mut ac, &mut bm, n, 1);
18395 assert_eq!(info, 0);
18396 bm.iter().sum::<f64>()
18397 };
18398 let fd = (lp - lm) / (2.0 * h);
18399 assert!(
18400 (db_out[k] - fd).abs() < 1e-7,
18401 "FD mismatch on db[{k}]: AD={} FD={}",
18402 db_out[k],
18403 fd
18404 );
18405 }
18406 }
18407
18408 #[test]
18413 fn scan_geometric_growth_f64() {
18414 let n = 3usize;
18415 let length = 10u32;
18416
18417 let mut body = Graph::new("scan_body");
18419 let x = body.input("carry", Shape::new(&[n], DType::F64));
18420 let scale_bytes: Vec<u8> = (0..n).flat_map(|_| 0.1_f64.to_le_bytes()).collect();
18421 let scale = body.add_node(
18422 Op::Constant { data: scale_bytes },
18423 vec![],
18424 Shape::new(&[n], DType::F64),
18425 );
18426 let scaled = body.binary(BinaryOp::Mul, x, scale, Shape::new(&[n], DType::F64));
18427 let next = body.binary(BinaryOp::Add, x, scaled, Shape::new(&[n], DType::F64));
18428 body.set_outputs(vec![next]);
18429
18430 let mut g = Graph::new("scan_outer");
18432 let init = g.input("init", Shape::new(&[n], DType::F64));
18433 let final_carry = g.scan(init, body, length);
18434 g.set_outputs(vec![final_carry]);
18435
18436 let init_data = vec![1.0_f64; n];
18437 let (sched, mut arena) = prepare_f64(&g, &[(init, &init_data)]);
18438 execute_thunks(&sched, arena.raw_buf_mut());
18439 let got = read_arena_f64(&arena, final_carry, n);
18440 let want: f64 = 1.1_f64.powi(length as i32);
18441 for i in 0..n {
18442 assert!(
18443 (got[i] - want).abs() < 1e-12,
18444 "got[{i}] = {} want {}",
18445 got[i],
18446 want
18447 );
18448 }
18449 }
18450
18451 #[test]
18458 fn scan_with_xs_cumulative_sum() {
18459 let n = 3usize;
18460 let length = 4u32;
18461
18462 let mut body = Graph::new("cumsum_body");
18463 let carry = body.input("carry", Shape::new(&[n], DType::F64));
18465 let x_t = body.input("x_t", Shape::new(&[n], DType::F64));
18466 let next = body.binary(BinaryOp::Add, carry, x_t, Shape::new(&[n], DType::F64));
18467 body.set_outputs(vec![next]);
18468
18469 let mut g = Graph::new("cumsum_outer");
18470 let init = g.input("init", Shape::new(&[n], DType::F64));
18471 let xs = g.input("xs", Shape::new(&[length as usize, n], DType::F64));
18472 let final_carry = g.scan_with_xs(init, &[xs], body, length);
18473 g.set_outputs(vec![final_carry]);
18474
18475 let init_data = vec![0.0_f64; n];
18476 let xs_data: Vec<f64> = (0..length as usize * n).map(|i| (i + 1) as f64).collect(); let (sched, mut arena) = prepare_f64(&g, &[(init, &init_data), (xs, &xs_data)]);
18478 execute_thunks(&sched, arena.raw_buf_mut());
18479 let got = read_arena_f64(&arena, final_carry, n);
18480
18481 let mut want = init_data.clone();
18485 for t in 0..length as usize {
18486 for j in 0..n {
18487 want[j] += xs_data[t * n + j];
18488 }
18489 }
18490 for i in 0..n {
18491 assert!(
18492 (got[i] - want[i]).abs() < 1e-12,
18493 "got[{i}] = {} want {}",
18494 got[i],
18495 want[i]
18496 );
18497 }
18498 }
18499
18500 #[test]
18504 fn scan_with_xs_be_with_drive() {
18505 let n = 3usize;
18506 let length = 4u32;
18507 let dt = 0.1_f64;
18508
18509 let mut m_data = vec![0.0_f64; n * n];
18510 for i in 0..n {
18511 m_data[i * n + i] = 1.0 + dt * 2.0;
18512 if i > 0 {
18513 m_data[i * n + (i - 1)] = -dt;
18514 }
18515 if i + 1 < n {
18516 m_data[i * n + (i + 1)] = -dt;
18517 }
18518 }
18519 let m_bytes: Vec<u8> = m_data.iter().flat_map(|x| x.to_le_bytes()).collect();
18520
18521 let mut body = Graph::new("be_drive_body");
18522 let carry = body.input("carry", Shape::new(&[n], DType::F64));
18523 let drive = body.input("drive", Shape::new(&[n], DType::F64));
18524 let m = body.add_node(
18525 Op::Constant { data: m_bytes },
18526 vec![],
18527 Shape::new(&[n, n], DType::F64),
18528 );
18529 let driven = body.binary(BinaryOp::Add, carry, drive, Shape::new(&[n], DType::F64));
18530 let next = body.dense_solve(m, driven, Shape::new(&[n], DType::F64));
18531 body.set_outputs(vec![next]);
18532
18533 let mut g = Graph::new("be_drive_outer");
18534 let init = g.input("init", Shape::new(&[n], DType::F64));
18535 let xs = g.input("xs", Shape::new(&[length as usize, n], DType::F64));
18536 let final_carry = g.scan_with_xs(init, &[xs], body, length);
18537 g.set_outputs(vec![final_carry]);
18538
18539 let init_data = vec![0.0_f64; n];
18540 let mut xs_data = vec![0.0_f64; length as usize * n];
18543 xs_data[0] = 1.0;
18544
18545 let (sched, mut arena) = prepare_f64(&g, &[(init, &init_data), (xs, &xs_data)]);
18546 execute_thunks(&sched, arena.raw_buf_mut());
18547 let got = read_arena_f64(&arena, final_carry, n);
18548
18549 let mut x = init_data.clone();
18551 for t in 0..length as usize {
18552 for j in 0..n {
18553 x[j] += xs_data[t * n + j];
18554 }
18555 let mut a_copy = m_data.clone();
18556 crate::blas::dgesv(&mut a_copy, &mut x, n, 1);
18557 }
18558 for i in 0..n {
18559 assert!(
18560 (got[i] - x[i]).abs() < 1e-12,
18561 "got[{i}] = {} ref {}",
18562 got[i],
18563 x[i]
18564 );
18565 }
18566 }
18567
18568 #[test]
18574 fn batched_dense_solve_gradient_matches_per_batch_analytic() {
18575 use rlx_opt::autodiff::grad_with_loss;
18576 let n = 3usize;
18577 let batch = 4usize;
18578
18579 let mut g = Graph::new("bds_grad");
18580 let a = g.param("A", Shape::new(&[batch, n, n], DType::F64));
18581 let b = g.input("b", Shape::new(&[batch, n], DType::F64));
18582 let x = g.batched_dense_solve(a, b, Shape::new(&[batch, n], DType::F64));
18583 let loss = g.reduce(
18584 x,
18585 ReduceOp::Sum,
18586 vec![0, 1],
18587 false,
18588 Shape::new(&[1], DType::F64),
18589 );
18590 g.set_outputs(vec![loss]);
18591
18592 let bwd = grad_with_loss(&g, &[a, b]);
18593
18594 let find = |graph: &Graph, want: &str| -> NodeId {
18595 for node in graph.nodes() {
18596 let name = match &node.op {
18597 Op::Input { name } | Op::Param { name } => Some(name.as_str()),
18598 _ => None,
18599 };
18600 if name == Some(want) {
18601 return node.id;
18602 }
18603 }
18604 panic!("no node named {want}");
18605 };
18606 let a_id = find(&bwd, "A");
18607 let b_id = find(&bwd, "b");
18608 let d_out_id = find(&bwd, "d_output");
18609
18610 let mut rng = rlx_ir::Philox4x32::new(0x57e1_u64);
18611 let mut a_data = vec![0.0_f64; batch * n * n];
18612 let mut b_data = vec![0.0_f64; batch * n];
18613 for bi in 0..batch {
18614 for i in 0..n {
18615 for j in 0..n {
18616 a_data[bi * n * n + i * n + j] = rng.next_f32() as f64 * 0.1;
18617 }
18618 a_data[bi * n * n + i * n + i] += 1.0 + n as f64;
18619 }
18620 for i in 0..n {
18621 b_data[bi * n + i] = rng.next_f32() as f64;
18622 }
18623 }
18624 let d_seed = [1.0_f64];
18625
18626 let (sched, mut arena) = prepare_f64(
18627 &bwd,
18628 &[(a_id, &a_data), (b_id, &b_data), (d_out_id, &d_seed)],
18629 );
18630 execute_thunks(&sched, arena.raw_buf_mut());
18631 let da_out = read_arena_f64(&arena, bwd.outputs[1], batch * n * n);
18632 let db_out = read_arena_f64(&arena, bwd.outputs[2], batch * n);
18633
18634 for bi in 0..batch {
18637 let a_slice: Vec<f64> = a_data[bi * n * n..(bi + 1) * n * n].to_vec();
18638 let mut b_slice: Vec<f64> = b_data[bi * n..(bi + 1) * n].to_vec();
18639 let mut a_copy = a_slice.clone();
18640 crate::blas::dgesv(&mut a_copy, &mut b_slice, n, 1);
18641 let x_ref = b_slice.clone();
18642 let mut at = vec![0.0_f64; n * n];
18644 for i in 0..n {
18645 for j in 0..n {
18646 at[i * n + j] = a_slice[j * n + i];
18647 }
18648 }
18649 let mut ones = vec![1.0_f64; n];
18650 crate::blas::dgesv(&mut at, &mut ones, n, 1);
18651 let db_ref = ones;
18652 for i in 0..n {
18653 let got = db_out[bi * n + i];
18654 assert!(
18655 (got - db_ref[i]).abs() < 1e-10,
18656 "batch {bi}, db[{i}]: got {got} ref {}",
18657 db_ref[i]
18658 );
18659 }
18660 for i in 0..n {
18662 for j in 0..n {
18663 let got = da_out[bi * n * n + i * n + j];
18664 let want = -db_ref[i] * x_ref[j];
18665 assert!(
18666 (got - want).abs() < 1e-10,
18667 "batch {bi}, dA[{i},{j}]: got {got} ref {want}"
18668 );
18669 }
18670 }
18671 }
18672 }
18673
18674 #[test]
18679 fn scan_checkpointed_grad_matches_plain_scan_grad() {
18680 use rlx_opt::autodiff::grad_with_loss;
18681 let n = 2usize;
18682 let length = 6u32;
18683
18684 let make_body = || {
18685 let mut body = Graph::new("ck_body");
18686 let carry = body.input("carry", Shape::new(&[n], DType::F64));
18687 let scale_bytes: Vec<u8> = (0..n).flat_map(|_| 1.05_f64.to_le_bytes()).collect();
18688 let scale = body.add_node(
18689 Op::Constant { data: scale_bytes },
18690 vec![],
18691 Shape::new(&[n], DType::F64),
18692 );
18693 let next = body.binary(BinaryOp::Mul, carry, scale, Shape::new(&[n], DType::F64));
18694 body.set_outputs(vec![next]);
18695 body
18696 };
18697
18698 let mut g_plain = Graph::new("ck_plain");
18700 let init_p = g_plain.input("init", Shape::new(&[n], DType::F64));
18701 let final_p = g_plain.scan(init_p, make_body(), length);
18702 let loss_p = g_plain.reduce(
18703 final_p,
18704 ReduceOp::Sum,
18705 vec![0],
18706 false,
18707 Shape::new(&[1], DType::F64),
18708 );
18709 g_plain.set_outputs(vec![loss_p]);
18710 let bwd_p = grad_with_loss(&g_plain, &[init_p]);
18711
18712 let mut g_ck = Graph::new("ck_ckpt");
18714 let init_c = g_ck.input("init", Shape::new(&[n], DType::F64));
18715 let final_c = g_ck.scan_checkpointed(init_c, make_body(), length, 2);
18716 let loss_c = g_ck.reduce(
18717 final_c,
18718 ReduceOp::Sum,
18719 vec![0],
18720 false,
18721 Shape::new(&[1], DType::F64),
18722 );
18723 g_ck.set_outputs(vec![loss_c]);
18724 let bwd_c = grad_with_loss(&g_ck, &[init_c]);
18725
18726 let find = |graph: &Graph, want: &str| -> NodeId {
18727 for node in graph.nodes() {
18728 let name = match &node.op {
18729 Op::Input { name } | Op::Param { name } => Some(name.as_str()),
18730 _ => None,
18731 };
18732 if name == Some(want) {
18733 return node.id;
18734 }
18735 }
18736 panic!("no {want}");
18737 };
18738
18739 let init_data = vec![0.5_f64, -0.5];
18740 let d_seed = [1.0_f64];
18741
18742 let (s_p, mut a_p) = prepare_f64(
18743 &bwd_p,
18744 &[
18745 (find(&bwd_p, "init"), &init_data),
18746 (find(&bwd_p, "d_output"), &d_seed),
18747 ],
18748 );
18749 execute_thunks(&s_p, a_p.raw_buf_mut());
18750 let dinit_p = read_arena_f64(&a_p, bwd_p.outputs[1], n);
18751
18752 let (s_c, mut a_c) = prepare_f64(
18753 &bwd_c,
18754 &[
18755 (find(&bwd_c, "init"), &init_data),
18756 (find(&bwd_c, "d_output"), &d_seed),
18757 ],
18758 );
18759 execute_thunks(&s_c, a_c.raw_buf_mut());
18760 let dinit_c = read_arena_f64(&a_c, bwd_c.outputs[1], n);
18761
18762 for i in 0..n {
18763 assert!(
18764 (dinit_p[i] - dinit_c[i]).abs() < 1e-12,
18765 "dinit[{i}]: plain={} checkpointed={}",
18766 dinit_p[i],
18767 dinit_c[i]
18768 );
18769 }
18770 }
18771
#[test]
/// Differentiates a `carry + 1` scan two ways and requires the input
/// gradients to agree:
/// * against the fully saved trajectory (`scan_trajectory` + `scan_backward`);
/// * against a `k`-checkpoint forward scan whose backward node is also handed
///   the forward body (presumably to recompute the skipped segments).
fn recursive_checkpointing_matches_full_trajectory() {
    let n = 2usize;
    let length = 4u32;

    // Fresh copy of the step body: next = carry + [1.0; n].
    let build_body = || -> Graph {
        let mut step = Graph::new("rc_body");
        let carry = step.input("carry", Shape::new(&[n], DType::F64));
        let ones_bytes: Vec<u8> = vec![1.0_f64; n]
            .iter()
            .flat_map(|v| v.to_le_bytes())
            .collect();
        let ones = step.add_node(
            Op::Constant { data: ones_bytes },
            vec![],
            Shape::new(&[n], DType::F64),
        );
        let next = step.binary(BinaryOp::Add, carry, ones, Shape::new(&[n], DType::F64));
        step.set_outputs(vec![next]);
        step
    };

    // VJP of the step body with respect to its first (and only) input.
    let body_vjp_for = || -> Graph {
        use rlx_opt::autodiff::grad;
        let step = build_body();
        let carry_id = step
            .nodes()
            .iter()
            .find_map(|node| matches!(node.op, Op::Input { .. }).then_some(node.id))
            .unwrap();
        grad(&step, &[carry_id])
    };

    // Reference graph: full trajectory, then the standard backward scan.
    let mut g_full = Graph::new("rc_outer_full");
    let init_full = g_full.input("init", Shape::new(&[n], DType::F64));
    let traj_full_id = g_full.scan_trajectory(init_full, build_body(), length);
    let upstream_full = g_full.input("upstream", Shape::new(&[length as usize, n], DType::F64));
    let dinit_full_id = g_full.scan_backward(
        init_full,
        traj_full_id,
        upstream_full,
        &[],
        body_vjp_for(),
        length,
        true,
        Shape::new(&[n], DType::F64),
    );
    g_full.set_outputs(vec![dinit_full_id]);

    // Checkpointed graph: the forward scan node is declared with shape
    // `[k, n]` (one row per checkpoint) and the backward node gets the
    // forward body alongside the VJP body.
    let k = 2u32;
    let mut g_rec = Graph::new("rc_outer_rec");
    let init_rec = g_rec.input("init", Shape::new(&[n], DType::F64));
    let traj_rec_id = g_rec.add_node(
        Op::Scan {
            body: Box::new(build_body()),
            length,
            save_trajectory: true,
            num_bcast: 0,
            num_xs: 0,
            num_checkpoints: k,
        },
        vec![init_rec],
        Shape::new(&[k as usize, n], DType::F64),
    );
    let upstream_rec = g_rec.input("upstream", Shape::new(&[length as usize, n], DType::F64));
    let dinit_rec_id = g_rec.add_node(
        Op::ScanBackward {
            body_vjp: Box::new(body_vjp_for()),
            length,
            save_trajectory: true,
            num_xs: 0,
            num_checkpoints: k,
            forward_body: Some(Box::new(build_body())),
        },
        vec![init_rec, traj_rec_id, upstream_rec],
        Shape::new(&[n], DType::F64),
    );
    g_rec.set_outputs(vec![dinit_rec_id]);

    let init_data = vec![0.5_f64, -0.5];
    let upstream_data: Vec<f64> = (0..length as usize * n)
        .map(|i| (i as f64) * 0.1)
        .collect();

    // Locate a named `Input` node; panics if absent.
    let find = |graph: &Graph, want: &str| -> NodeId {
        graph
            .nodes()
            .iter()
            .find_map(|node| match &node.op {
                Op::Input { name } if name == want => Some(node.id),
                _ => None,
            })
            .unwrap_or_else(|| panic!("no input {want}"))
    };

    let (s_full, mut a_full) = prepare_f64(
        &g_full,
        &[
            (find(&g_full, "init"), &init_data),
            (find(&g_full, "upstream"), &upstream_data),
        ],
    );
    execute_thunks(&s_full, a_full.raw_buf_mut());
    let dinit_full = read_arena_f64(&a_full, g_full.outputs[0], n);

    let (s_rec, mut a_rec) = prepare_f64(
        &g_rec,
        &[
            (find(&g_rec, "init"), &init_data),
            (find(&g_rec, "upstream"), &upstream_data),
        ],
    );
    execute_thunks(&s_rec, a_rec.raw_buf_mut());
    let dinit_rec = read_arena_f64(&a_rec, g_rec.outputs[0], n);

    for (i, (full, rec)) in dinit_full.iter().zip(dinit_rec.iter()).enumerate() {
        assert!(
            (full - rec).abs() < 1e-12,
            "i={i}: full={} rec={}",
            full,
            rec
        );
    }
}
18908
#[test]
/// vmap composed with reverse-mode AD through a scan.
///
/// The batched gradient graph must (a) give the analytic gradient of 1.0 for
/// a `carry + 1` body and (b) match, row by row, the unbatched gradient graph
/// run on each row alone. `vmap` only borrows `bwd`, so that same unbatched
/// graph serves as the per-row reference. Rebuilding and re-differentiating
/// an identical graph on every loop iteration would add nothing.
fn vmap_of_grad_scan_matches_per_row_runs() {
    use rlx_opt::autodiff::grad_with_loss;
    use rlx_opt::vmap::vmap;
    let n = 2usize;
    let length = 3u32;
    let batch = 3usize;

    // Body: next = carry + 1, so d(sum(final)) / d(init) is 1 in every entry.
    let mut body = Graph::new("scan_grad_body");
    let carry = body.input("carry", Shape::new(&[n], DType::F64));
    let ones_bytes: Vec<u8> = (0..n).flat_map(|_| 1.0_f64.to_le_bytes()).collect();
    let ones = body.add_node(
        Op::Constant { data: ones_bytes },
        vec![],
        Shape::new(&[n], DType::F64),
    );
    let next = body.binary(BinaryOp::Add, carry, ones, Shape::new(&[n], DType::F64));
    body.set_outputs(vec![next]);

    let mut g = Graph::new("scan_grad_outer");
    let init = g.input("init", Shape::new(&[n], DType::F64));
    let final_x = g.scan(init, body, length);
    let loss = g.reduce(
        final_x,
        ReduceOp::Sum,
        vec![0],
        false,
        Shape::new(&[1], DType::F64),
    );
    g.set_outputs(vec![loss]);

    // Gradient graph; output 1 is d(loss)/d(init). Only `init` is batched,
    // so the `d_output` seed stays a single scalar.
    let bwd = grad_with_loss(&g, &[init]);
    let bg = vmap(&bwd, &["init"], batch);

    // Locate a named `Input`/`Param` node; panics if absent.
    let find = |graph: &Graph, want: &str| -> NodeId {
        for node in graph.nodes() {
            let name = match &node.op {
                Op::Input { name } | Op::Param { name } => Some(name.as_str()),
                _ => None,
            };
            if name == Some(want) {
                return node.id;
            }
        }
        panic!("no node named {want}");
    };
    let init_b = find(&bg, "init");
    let d_out_b = find(&bg, "d_output");

    let init_data: Vec<f64> = (0..batch * n).map(|i| (i as f64) * 0.5).collect();
    let d_seed = [1.0_f64];

    let (sched, mut arena) = prepare_f64(&bg, &[(init_b, &init_data), (d_out_b, &d_seed)]);
    execute_thunks(&sched, arena.raw_buf_mut());
    let dinit_b = read_arena_f64(&arena, bg.outputs[1], batch * n);

    // Analytic check: the `+1` body has an identity Jacobian.
    for i in 0..batch * n {
        assert!(
            (dinit_b[i] - 1.0).abs() < 1e-12,
            "dinit[{i}] = {} (expected 1.0)",
            dinit_b[i]
        );
    }

    // Per-row reference using the (unbatched) gradient graph built above.
    let init_ref = find(&bwd, "init");
    let d_out_ref = find(&bwd, "d_output");
    for (bi, row) in init_data.chunks_exact(n).enumerate() {
        let (s2, mut a2) = prepare_f64(&bwd, &[(init_ref, row), (d_out_ref, &d_seed)]);
        execute_thunks(&s2, a2.raw_buf_mut());
        let row_dinit = read_arena_f64(&a2, bwd.outputs[1], n);
        for j in 0..n {
            let got = dinit_b[bi * n + j];
            let want = row_dinit[j];
            assert!(
                (got - want).abs() < 1e-12,
                "row {bi}, j {j}: vmap'd={got} per-row={want}"
            );
        }
    }
}
19021
#[test]
/// vmap over a scan that consumes per-step inputs (`xs`).
///
/// With body `next = carry + x_t`, each batch row must finish at
/// `init + sum_t xs[t]`; this is compared against a hand-rolled scalar loop.
fn vmap_scan_cumulative_sum_matches_scalar_runs() {
    use rlx_opt::vmap::vmap;
    let n = 2usize;
    let length = 4u32;
    let batch = 3usize;
    let steps = length as usize;

    // Body: next = carry + x_t.
    let mut body = Graph::new("scan_body_cumsum");
    let carry = body.input("carry", Shape::new(&[n], DType::F64));
    let x_t = body.input("x_t", Shape::new(&[n], DType::F64));
    let next = body.binary(BinaryOp::Add, carry, x_t, Shape::new(&[n], DType::F64));
    body.set_outputs(vec![next]);

    let mut g = Graph::new("scan_outer_cumsum");
    let init = g.input("init", Shape::new(&[n], DType::F64));
    let xs = g.input("xs", Shape::new(&[steps, n], DType::F64));
    let final_carry = g.scan_with_xs(init, &[xs], body, length);
    g.set_outputs(vec![final_carry]);

    // Batch both the initial carry and the per-step inputs.
    let bg = vmap(&g, &["init", "xs"], batch);

    let init_data: Vec<f64> = (1..=batch * n).map(|v| v as f64).collect();
    // Indexed below as [batch][t][j].
    let xs_data: Vec<f64> = (0..batch * steps * n).map(|i| 0.1 * (i as f64)).collect();

    // Locate a named `Input` node; panics if absent.
    let find = |graph: &Graph, want: &str| -> NodeId {
        graph
            .nodes()
            .iter()
            .find_map(|node| match &node.op {
                Op::Input { name } if name == want => Some(node.id),
                _ => None,
            })
            .unwrap_or_else(|| panic!("no input {want}"))
    };
    let init_b = find(&bg, "init");
    let xs_b = find(&bg, "xs");
    let (sched, mut arena) = prepare_f64(&bg, &[(init_b, &init_data), (xs_b, &xs_data)]);
    execute_thunks(&sched, arena.raw_buf_mut());
    let batched_out = read_arena_f64(&arena, bg.outputs[0], batch * n);

    let row_inputs = init_data.chunks_exact(n).zip(xs_data.chunks_exact(steps * n));
    for (bi, (init_row, xs_rows)) in row_inputs.enumerate() {
        // Scalar reference: accumulate each step's input in time order.
        let mut expected = init_row.to_vec();
        for step_input in xs_rows.chunks_exact(n) {
            for (acc, dx) in expected.iter_mut().zip(step_input) {
                *acc += dx;
            }
        }

        for (i, want) in expected.iter().enumerate() {
            let got = batched_out[bi * n + i];
            assert!(
                (got - want).abs() < 1e-12,
                "row {bi}, i {i}: got {got} ref {}",
                want
            );
        }
    }
}
19094
#[test]
/// vmap lifts a single dense solve `A x = b` to a batch of independent
/// systems; each batched solution must match `crate::blas::dgesv` run on that
/// row's `(A, b)` alone.
fn vmap_dense_solve_matches_scalar_runs() {
    use rlx_opt::vmap::vmap;
    let n = 3usize;
    let batch = 4usize;

    let mut g = Graph::new("solve_forward");
    let a = g.input("A", Shape::new(&[n, n], DType::F64));
    let b = g.input("b", Shape::new(&[n], DType::F64));
    let x = g.dense_solve(a, b, Shape::new(&[n], DType::F64));
    g.set_outputs(vec![x]);

    let bg = vmap(&g, &["A", "b"], batch);

    // Small random entries plus a `1 + n` diagonal shift; assuming
    // `next_f32` yields values in [0, 1), every A is diagonally dominant.
    // Draw order (per row: A row-major, then b) must stay fixed for
    // reproducibility.
    let mut rng = rlx_ir::Philox4x32::new(0xb47c_u64);
    let mut a_data = vec![0.0_f64; batch * n * n];
    let mut b_data = vec![0.0_f64; batch * n];
    for (a_block, b_row) in a_data
        .chunks_exact_mut(n * n)
        .zip(b_data.chunks_exact_mut(n))
    {
        for (i, a_row) in a_block.chunks_exact_mut(n).enumerate() {
            for entry in a_row.iter_mut() {
                *entry = rng.next_f32() as f64 * 0.1;
            }
            a_row[i] += 1.0 + n as f64;
        }
        for entry in b_row.iter_mut() {
            *entry = rng.next_f32() as f64;
        }
    }

    // Locate a named `Input` node; panics if absent.
    let find = |graph: &Graph, want: &str| -> NodeId {
        graph
            .nodes()
            .iter()
            .find_map(|node| match &node.op {
                Op::Input { name } if name == want => Some(node.id),
                _ => None,
            })
            .unwrap_or_else(|| panic!("no input named {want}"))
    };
    let ba = find(&bg, "A");
    let bb = find(&bg, "b");
    let (sched, mut arena) = prepare_f64(&bg, &[(ba, &a_data), (bb, &b_data)]);
    execute_thunks(&sched, arena.raw_buf_mut());
    let batched_x = read_arena_f64(&arena, bg.outputs[0], batch * n);

    for (bi, (a_block, b_row)) in a_data
        .chunks_exact(n * n)
        .zip(b_data.chunks_exact(n))
        .enumerate()
    {
        // dgesv overwrites both arguments, so work on copies.
        let mut lu = a_block.to_vec();
        let mut x_ref = b_row.to_vec();
        crate::blas::dgesv(&mut lu, &mut x_ref, n, 1);
        for (i, &want) in x_ref.iter().enumerate() {
            let got = batched_x[bi * n + i];
            assert!(
                (got - want).abs() < 1e-12,
                "row {bi}, i {i}: got {got} want {want}"
            );
        }
    }
}
19162
19163 #[test]
19170 fn vmap_matmul_add_reduce_matches_scalar_runs() {
19171 use rlx_opt::vmap::vmap;
19172 let n = 3usize;
19173 let batch = 4usize;
19174
19175 let mut g = Graph::new("vmap_e2e_forward");
19177 let x = g.input("x", Shape::new(&[n], DType::F64));
19178 let w = g.input("w", Shape::new(&[n, n], DType::F64));
19179 let b = g.input("b", Shape::new(&[n], DType::F64));
19180 let x_row = g.add_node(
19181 Op::Reshape {
19182 new_shape: vec![1, n as i64],
19183 },
19184 vec![x],
19185 Shape::new(&[1, n], DType::F64),
19186 );
19187 let mm = g.matmul(x_row, w, Shape::new(&[1, n], DType::F64));
19188 let mm_flat = g.add_node(
19189 Op::Reshape {
19190 new_shape: vec![n as i64],
19191 },
19192 vec![mm],
19193 Shape::new(&[n], DType::F64),
19194 );
19195 let yv = g.binary(BinaryOp::Add, mm_flat, b, Shape::new(&[n], DType::F64));
19196 let loss = g.reduce(
19197 yv,
19198 ReduceOp::Sum,
19199 vec![0],
19200 false,
19201 Shape::new(&[1], DType::F64),
19202 );
19203 g.set_outputs(vec![loss]);
19204
19205 let bg = vmap(&g, &["x"], batch);
19207
19208 let mut rng = rlx_ir::Philox4x32::new(0xc1c0_u64);
19210 let n_w = n * n;
19211 let w_data: Vec<f64> = (0..n_w).map(|_| rng.next_f32() as f64).collect();
19212 let b_data: Vec<f64> = (0..n).map(|_| rng.next_f32() as f64).collect();
19213 let mut x_data_batched: Vec<f64> = Vec::with_capacity(batch * n);
19214 for _ in 0..batch * n {
19215 x_data_batched.push(rng.next_f32() as f64);
19216 }
19217
19218 let find = |graph: &Graph, want: &str| -> NodeId {
19220 for node in graph.nodes() {
19221 if let Op::Input { name } = &node.op
19222 && name == want
19223 {
19224 return node.id;
19225 }
19226 }
19227 panic!("no input named {want}");
19228 };
19229 let bx = find(&bg, "x");
19230 let bw = find(&bg, "w");
19231 let bb = find(&bg, "b");
19232 let (sched, mut arena) =
19233 prepare_f64(&bg, &[(bx, &x_data_batched), (bw, &w_data), (bb, &b_data)]);
19234 execute_thunks(&sched, arena.raw_buf_mut());
19235 let batched_out = read_arena_f64(&arena, bg.outputs[0], batch);
19241
19242 for bi in 0..batch {
19244 let xs_slice = &x_data_batched[bi * n..(bi + 1) * n];
19245 let mut g2 = Graph::new("scalar_run");
19246 let x2 = g2.input("x", Shape::new(&[n], DType::F64));
19247 let w2 = g2.input("w", Shape::new(&[n, n], DType::F64));
19248 let b2 = g2.input("b", Shape::new(&[n], DType::F64));
19249 let xr = g2.add_node(
19250 Op::Reshape {
19251 new_shape: vec![1, n as i64],
19252 },
19253 vec![x2],
19254 Shape::new(&[1, n], DType::F64),
19255 );
19256 let m = g2.matmul(xr, w2, Shape::new(&[1, n], DType::F64));
19257 let mf = g2.add_node(
19258 Op::Reshape {
19259 new_shape: vec![n as i64],
19260 },
19261 vec![m],
19262 Shape::new(&[n], DType::F64),
19263 );
19264 let yv2 = g2.binary(BinaryOp::Add, mf, b2, Shape::new(&[n], DType::F64));
19265 let l2 = g2.reduce(
19266 yv2,
19267 ReduceOp::Sum,
19268 vec![0],
19269 false,
19270 Shape::new(&[1], DType::F64),
19271 );
19272 g2.set_outputs(vec![l2]);
19273 let (s2, mut a2) = prepare_f64(&g2, &[(x2, xs_slice), (w2, &w_data), (b2, &b_data)]);
19274 execute_thunks(&s2, a2.raw_buf_mut());
19275 let scalar_out = read_arena_f64(&a2, l2, 1);
19276 assert!(
19277 (batched_out[bi] - scalar_out[0]).abs() < 1e-12,
19278 "row {bi}: batched={} scalar={}",
19279 batched_out[bi],
19280 scalar_out[0]
19281 );
19282 }
19283 }
19284
#[test]
/// Gradient of `sum(final)` through a driven implicit step
/// `next = M^{-1} (carry + drive_t)`, taken with respect to both the initial
/// carry and every per-step drive entry, checked against central finite
/// differences of a scalar `dgesv` reference.
fn scan_with_xs_dxs_matches_fd() {
    use rlx_opt::autodiff::grad_with_loss;
    let n = 3usize;
    let length = 3u32;
    let steps = length as usize;
    let dt = 0.1_f64;

    // Tridiagonal M: `1 + 2*dt` on the diagonal, `-dt` on both neighbours.
    let m_data: Vec<f64> = (0..n * n)
        .map(|idx| {
            let (row, col) = (idx / n, idx % n);
            if row == col {
                1.0 + dt * 2.0
            } else if row.abs_diff(col) == 1 {
                -dt
            } else {
                0.0
            }
        })
        .collect();
    let m_bytes: Vec<u8> = m_data.iter().flat_map(|v| v.to_le_bytes()).collect();

    let mut body = Graph::new("be_dxs_body");
    let carry = body.input("carry", Shape::new(&[n], DType::F64));
    let drive = body.input("drive", Shape::new(&[n], DType::F64));
    let m = body.add_node(
        Op::Constant { data: m_bytes },
        vec![],
        Shape::new(&[n, n], DType::F64),
    );
    let driven = body.binary(BinaryOp::Add, carry, drive, Shape::new(&[n], DType::F64));
    let next = body.dense_solve(m, driven, Shape::new(&[n], DType::F64));
    body.set_outputs(vec![next]);

    let mut g = Graph::new("be_dxs_outer");
    let init = g.input("init", Shape::new(&[n], DType::F64));
    let xs = g.input("xs", Shape::new(&[steps, n], DType::F64));
    let final_carry = g.scan_with_xs(init, &[xs], body, length);
    let loss = g.reduce(
        final_carry,
        ReduceOp::Sum,
        vec![0],
        false,
        Shape::new(&[1], DType::F64),
    );
    g.set_outputs(vec![loss]);

    let bwd = grad_with_loss(&g, &[init, xs]);
    assert_eq!(bwd.outputs.len(), 3, "[loss, dinit, dxs]");

    // Locate a named `Input`/`Param` node; panics if absent.
    let find_by_name = |graph: &Graph, want: &str| -> NodeId {
        graph
            .nodes()
            .iter()
            .find(|node| match &node.op {
                Op::Input { name } | Op::Param { name } => name == want,
                _ => false,
            })
            .map(|node| node.id)
            .unwrap_or_else(|| panic!("no node named {want:?}"))
    };
    let init_bwd = find_by_name(&bwd, "init");
    let xs_bwd = find_by_name(&bwd, "xs");
    let d_out_bwd = find_by_name(&bwd, "d_output");

    let init_data = vec![0.5_f64, 0.0, -0.5];
    let xs_data: Vec<f64> = (0..steps * n)
        .map(|i| 0.1_f64 * ((i as f64) - 4.0))
        .collect();
    let d_seed = [1.0_f64];

    let (sched, mut arena) = prepare_f64(
        &bwd,
        &[
            (init_bwd, &init_data),
            (xs_bwd, &xs_data),
            (d_out_bwd, &d_seed),
        ],
    );
    execute_thunks(&sched, arena.raw_buf_mut());
    let dinit = read_arena_f64(&arena, bwd.outputs[1], n);
    let dxs = read_arena_f64(&arena, bwd.outputs[2], steps * n);

    // Scalar forward reference: add the step's drive, then solve against M.
    let h = 1e-6;
    let loss_at = |x0: &[f64], drive_all: &[f64]| -> f64 {
        let mut state = x0.to_vec();
        for drive_t in drive_all.chunks_exact(n) {
            for (s, d) in state.iter_mut().zip(drive_t) {
                *s += d;
            }
            let mut lu = m_data.clone();
            crate::blas::dgesv(&mut lu, &mut state, n, 1);
        }
        state.iter().sum()
    };

    for (i, &ad) in dinit.iter().enumerate() {
        let mut ip = init_data.clone();
        ip[i] += h;
        let mut im = init_data.clone();
        im[i] -= h;
        let fd = (loss_at(&ip, &xs_data) - loss_at(&im, &xs_data)) / (2.0 * h);
        assert!(
            (ad - fd).abs() < 1e-7,
            "FD dinit[{i}]: AD={} FD={}",
            ad,
            fd
        );
    }

    for (idx, &ad) in dxs.iter().enumerate() {
        let (t, j) = (idx / n, idx % n);
        let mut xp = xs_data.clone();
        xp[idx] += h;
        let mut xm = xs_data.clone();
        xm[idx] -= h;
        let fd = (loss_at(&init_data, &xp) - loss_at(&init_data, &xm)) / (2.0 * h);
        assert!(
            (ad - fd).abs() < 1e-7,
            "FD dxs[t={t},j={j}]: AD={} FD={}",
            ad,
            fd
        );
    }
}
19419
#[test]
/// Gradient of `sum(final)` with respect to the initial carry only, through a
/// driven implicit step `next = M^{-1} (carry + drive_t)`. Here `xs` is fed to
/// the backward graph as a plain input, and the result is compared against
/// central finite differences.
fn scan_with_xs_gradient_dinit_matches_fd() {
    use rlx_opt::autodiff::grad_with_loss;
    let n = 3usize;
    let length = 3u32;
    let steps = length as usize;
    let dt = 0.1_f64;

    // Tridiagonal M: `1 + 2*dt` on the diagonal, `-dt` on both neighbours.
    let mut m_data = vec![0.0_f64; n * n];
    for (i, row) in m_data.chunks_exact_mut(n).enumerate() {
        row[i] = 1.0 + dt * 2.0;
        if let Some(left) = i.checked_sub(1) {
            row[left] = -dt;
        }
        if let Some(right) = row.get_mut(i + 1) {
            *right = -dt;
        }
    }
    let m_bytes: Vec<u8> = m_data.iter().flat_map(|v| v.to_le_bytes()).collect();

    let mut body = Graph::new("be_xs_grad_body");
    let carry = body.input("carry", Shape::new(&[n], DType::F64));
    let drive = body.input("drive", Shape::new(&[n], DType::F64));
    let m = body.add_node(
        Op::Constant { data: m_bytes },
        vec![],
        Shape::new(&[n, n], DType::F64),
    );
    let driven = body.binary(BinaryOp::Add, carry, drive, Shape::new(&[n], DType::F64));
    let next = body.dense_solve(m, driven, Shape::new(&[n], DType::F64));
    body.set_outputs(vec![next]);

    let mut g = Graph::new("be_xs_grad_outer");
    let init = g.input("init", Shape::new(&[n], DType::F64));
    let xs = g.input("xs", Shape::new(&[steps, n], DType::F64));
    let final_carry = g.scan_with_xs(init, &[xs], body, length);
    let loss = g.reduce(
        final_carry,
        ReduceOp::Sum,
        vec![0],
        false,
        Shape::new(&[1], DType::F64),
    );
    g.set_outputs(vec![loss]);

    let bwd = grad_with_loss(&g, &[init]);

    // Locate a named `Input`/`Param` node; panics if absent.
    let find_by_name = |graph: &Graph, want: &str| -> NodeId {
        graph
            .nodes()
            .iter()
            .find(|node| match &node.op {
                Op::Input { name } | Op::Param { name } => name == want,
                _ => false,
            })
            .map(|node| node.id)
            .unwrap_or_else(|| panic!("no node named {want:?}"))
    };
    let init_bwd = find_by_name(&bwd, "init");
    let xs_bwd = find_by_name(&bwd, "xs");
    let d_out_bwd = find_by_name(&bwd, "d_output");

    let init_data = vec![0.5_f64, 0.0, -0.5];
    let xs_data: Vec<f64> = (0..steps * n)
        .map(|i| 0.1_f64 * ((i as f64) - 4.0))
        .collect();
    let d_seed = [1.0_f64];

    let (sched, mut arena) = prepare_f64(
        &bwd,
        &[
            (init_bwd, &init_data),
            (xs_bwd, &xs_data),
            (d_out_bwd, &d_seed),
        ],
    );
    execute_thunks(&sched, arena.raw_buf_mut());
    let dinit = read_arena_f64(&arena, bwd.outputs[1], n);

    // Scalar forward reference with the drive sequence held fixed.
    let h = 1e-6;
    let loss_at = |x0: &[f64]| -> f64 {
        let mut state = x0.to_vec();
        for drive_t in xs_data.chunks_exact(n) {
            for (s, d) in state.iter_mut().zip(drive_t) {
                *s += d;
            }
            let mut lu = m_data.clone();
            crate::blas::dgesv(&mut lu, &mut state, n, 1);
        }
        state.iter().sum()
    };
    for (i, &ad) in dinit.iter().enumerate() {
        let mut ip = init_data.clone();
        ip[i] += h;
        let mut im = init_data.clone();
        im[i] -= h;
        let fd = (loss_at(&ip) - loss_at(&im)) / (2.0 * h);
        assert!(
            (ad - fd).abs() < 1e-7,
            "FD dinit[{i}]: AD={} FD={}",
            ad,
            fd
        );
    }
}
19533
#[test]
/// Scan gradient for a geometric body `next = 1.1 * carry`.
///
/// After `length` steps, `d(sum(final)) / d(init[i]) = 1.1^length` for every
/// `i`. The AD result is checked against that closed form and against a
/// central finite difference for every component. The earlier version only
/// cross-checked index 0, so a wrong gradient in any other component would
/// have slipped past the FD check.
fn scan_gradient_geometric_matches_closed_form() {
    use rlx_opt::autodiff::grad_with_loss;
    let n = 3usize;
    let length = 5u32;

    let mut body = Graph::new("scan_grad_body");
    let x = body.input("carry", Shape::new(&[n], DType::F64));
    let scale_bytes: Vec<u8> = (0..n).flat_map(|_| 1.1_f64.to_le_bytes()).collect();
    let scale = body.add_node(
        Op::Constant { data: scale_bytes },
        vec![],
        Shape::new(&[n], DType::F64),
    );
    let next = body.binary(BinaryOp::Mul, x, scale, Shape::new(&[n], DType::F64));
    body.set_outputs(vec![next]);

    let mut g = Graph::new("scan_grad_outer");
    let init = g.input("init", Shape::new(&[n], DType::F64));
    let final_x = g.scan(init, body, length);
    let loss = g.reduce(
        final_x,
        ReduceOp::Sum,
        vec![0],
        false,
        Shape::new(&[1], DType::F64),
    );
    g.set_outputs(vec![loss]);

    // Outputs: [loss, dinit].
    let bwd = grad_with_loss(&g, &[init]);
    assert_eq!(bwd.outputs.len(), 2);

    // Locate a named `Input`/`Param` node; panics if absent.
    let find_by_name = |graph: &Graph, want: &str| -> NodeId {
        for node in graph.nodes() {
            let name = match &node.op {
                Op::Input { name } | Op::Param { name } => Some(name.as_str()),
                _ => None,
            };
            if name == Some(want) {
                return node.id;
            }
        }
        panic!("no node named {want:?}");
    };
    let init_bwd = find_by_name(&bwd, "init");
    let d_out_bwd = find_by_name(&bwd, "d_output");

    let init_data = vec![1.0_f64; n];
    let d_seed = [1.0_f64];
    let (sched, mut arena) = prepare_f64(&bwd, &[(init_bwd, &init_data), (d_out_bwd, &d_seed)]);
    execute_thunks(&sched, arena.raw_buf_mut());
    let dinit = read_arena_f64(&arena, bwd.outputs[1], n);

    let want = 1.1_f64.powi(length as i32);
    for i in 0..n {
        assert!(
            (dinit[i] - want).abs() < 1e-12,
            "dinit[{i}] = {} want {}",
            dinit[i],
            want
        );
    }

    // Independent central-difference cross-check on every component.
    let h = 1e-6;
    let loss_at = |x: &[f64]| -> f64 {
        let mut acc = x.to_vec();
        for _ in 0..length {
            for v in acc.iter_mut() {
                *v *= 1.1;
            }
        }
        acc.iter().sum()
    };
    for i in 0..n {
        let mut ip = init_data.clone();
        ip[i] += h;
        let mut im = init_data.clone();
        im[i] -= h;
        let fd = (loss_at(&ip) - loss_at(&im)) / (2.0 * h);
        assert!(
            (dinit[i] - fd).abs() < 1e-7,
            "FD dinit[{i}]: AD={} FD={}",
            dinit[i],
            fd
        );
    }
}
19627
#[test]
/// Gradient of `sum(final)` through `length` backward-Euler steps
/// `M x_{t+1} = x_t` (one `dense_solve` per step), checked component-wise
/// against central finite differences of a scalar `dgesv` loop.
fn scan_gradient_backward_euler_matches_fd() {
    use rlx_opt::autodiff::grad_with_loss;
    let n = 4usize;
    let length = 3u32;
    let dt = 0.05_f64;

    // Tridiagonal M: `1 + 2*dt` on the diagonal, `-dt` on both neighbours.
    let mut m_data = vec![0.0_f64; n * n];
    for (i, row) in m_data.chunks_exact_mut(n).enumerate() {
        row[i] = 1.0 + dt * 2.0;
        if let Some(left) = i.checked_sub(1) {
            row[left] = -dt;
        }
        if let Some(right) = row.get_mut(i + 1) {
            *right = -dt;
        }
    }
    let m_bytes: Vec<u8> = m_data.iter().flat_map(|v| v.to_le_bytes()).collect();

    let mut body = Graph::new("be_grad_body");
    let x = body.input("x", Shape::new(&[n], DType::F64));
    let m = body.add_node(
        Op::Constant { data: m_bytes },
        vec![],
        Shape::new(&[n, n], DType::F64),
    );
    let next = body.dense_solve(m, x, Shape::new(&[n], DType::F64));
    body.set_outputs(vec![next]);

    let mut g = Graph::new("be_grad_outer");
    let init = g.input("x0", Shape::new(&[n], DType::F64));
    let final_x = g.scan(init, body, length);
    let loss = g.reduce(
        final_x,
        ReduceOp::Sum,
        vec![0],
        false,
        Shape::new(&[1], DType::F64),
    );
    g.set_outputs(vec![loss]);

    let bwd = grad_with_loss(&g, &[init]);

    // Locate a named `Input`/`Param` node; panics if absent.
    let find_by_name = |graph: &Graph, want: &str| -> NodeId {
        graph
            .nodes()
            .iter()
            .find(|node| match &node.op {
                Op::Input { name } | Op::Param { name } => name == want,
                _ => false,
            })
            .map(|node| node.id)
            .unwrap_or_else(|| panic!("no node named {want:?}"))
    };
    let init_bwd = find_by_name(&bwd, "x0");
    let d_out_bwd = find_by_name(&bwd, "d_output");

    // Unit "heat" in cell 1.
    let init_data: [f64; 4] = [0.0, 1.0, 0.0, 0.0];
    let d_seed = [1.0_f64];
    let (sched, mut arena) = prepare_f64(&bwd, &[(init_bwd, &init_data), (d_out_bwd, &d_seed)]);
    execute_thunks(&sched, arena.raw_buf_mut());
    let dinit = read_arena_f64(&arena, bwd.outputs[1], n);

    let h = 1e-6;
    let loss_at = |x0: &[f64]| -> f64 {
        let mut state = x0.to_vec();
        for _ in 0..length {
            let mut lu = m_data.clone();
            crate::blas::dgesv(&mut lu, &mut state, n, 1);
        }
        state.iter().sum()
    };
    for (i, &ad) in dinit.iter().enumerate() {
        let mut ip = init_data.to_vec();
        ip[i] += h;
        let mut im = init_data.to_vec();
        im[i] -= h;
        let fd = (loss_at(&ip) - loss_at(&im)) / (2.0 * h);
        assert!(
            (ad - fd).abs() < 1e-7,
            "FD dinit[{i}]: AD={} FD={}",
            ad,
            fd
        );
    }
}
19717
#[test]
/// `scan_trajectory` over a backward-Euler heat step must:
/// 1. record every intermediate state (row `t` = state after step `t + 1`),
///    matching a scalar `dgesv` loop;
/// 2. show non-increasing total mass from row to row;
/// 3. end on exactly the state a plain `scan` of the same body produces.
fn scan_trajectory_backward_euler_records_waveform() {
    let n = 4usize;
    let length = 5u32;
    let steps = length as usize;
    let dt = 0.05_f64;

    // Tridiagonal M: `1 + 2*dt` on the diagonal, `-dt` on both neighbours.
    let mut m_data = vec![0.0_f64; n * n];
    for (i, row) in m_data.chunks_exact_mut(n).enumerate() {
        row[i] = 1.0 + dt * 2.0;
        if let Some(left) = i.checked_sub(1) {
            row[left] = -dt;
        }
        if let Some(right) = row.get_mut(i + 1) {
            *right = -dt;
        }
    }
    // Serialized once; the trajectory body gets a copy and the final-state
    // body below takes the original.
    let m_bytes: Vec<u8> = m_data.iter().flat_map(|v| v.to_le_bytes()).collect();

    let mut body = Graph::new("be_traj_body");
    let x = body.input("x", Shape::new(&[n], DType::F64));
    let m = body.add_node(
        Op::Constant {
            data: m_bytes.clone(),
        },
        vec![],
        Shape::new(&[n, n], DType::F64),
    );
    let next = body.dense_solve(m, x, Shape::new(&[n], DType::F64));
    body.set_outputs(vec![next]);

    let mut g = Graph::new("be_traj_outer");
    let init = g.input("x0", Shape::new(&[n], DType::F64));
    let traj = g.scan_trajectory(init, body, length);
    g.set_outputs(vec![traj]);

    let init_data: [f64; 4] = [0.0, 1.0, 0.0, 0.0];
    let (sched, mut arena) = prepare_f64(&g, &[(init, &init_data)]);
    execute_thunks(&sched, arena.raw_buf_mut());
    let got = read_arena_f64(&arena, traj, steps * n);

    // (1) Reference trajectory: one dgesv per step, recording each result.
    let mut want = Vec::<f64>::with_capacity(steps * n);
    let mut state = init_data.to_vec();
    for _ in 0..steps {
        let mut lu = m_data.clone();
        crate::blas::dgesv(&mut lu, &mut state, n, 1);
        want.extend_from_slice(&state);
    }
    for (i, (got_i, want_i)) in got.iter().zip(want.iter()).enumerate() {
        assert!(
            (got_i - want_i).abs() < 1e-12,
            "got[{i}] = {} ref {}",
            got_i,
            want_i
        );
    }

    // (2) Total mass per recorded row must not grow.
    let row_mass: Vec<f64> = got.chunks_exact(n).map(|row| row.iter().sum()).collect();
    for t in 1..steps {
        let (prev, curr) = (row_mass[t - 1], row_mass[t]);
        assert!(
            curr <= prev + 1e-15,
            "mass should decay: row {} sum {prev}, row {t} sum {curr}",
            t - 1
        );
    }

    // (3) A plain final-state scan of the same step must equal the last row.
    let mut body2 = Graph::new("be_final_body");
    let x2 = body2.input("x", Shape::new(&[n], DType::F64));
    let m2 = body2.add_node(
        Op::Constant { data: m_bytes },
        vec![],
        Shape::new(&[n, n], DType::F64),
    );
    let next2 = body2.dense_solve(m2, x2, Shape::new(&[n], DType::F64));
    body2.set_outputs(vec![next2]);

    let mut g2 = Graph::new("be_final_outer");
    let init2 = g2.input("x0", Shape::new(&[n], DType::F64));
    let final_x = g2.scan(init2, body2, length);
    g2.set_outputs(vec![final_x]);
    let (sched2, mut arena2) = prepare_f64(&g2, &[(init2, &init_data)]);
    execute_thunks(&sched2, arena2.raw_buf_mut());
    let final_got = read_arena_f64(&arena2, final_x, n);

    let last_row = &got[(steps - 1) * n..steps * n];
    for (i, (row_i, final_i)) in last_row.iter().zip(final_got.iter()).enumerate() {
        assert!(
            (row_i - final_i).abs() < 1e-15,
            "last trajectory row[{i}] = {} vs final-scan = {}",
            row_i,
            final_i
        );
    }
}
19822
#[test]
/// Five backward-Euler steps of 1-D diffusion (`M x_{t+1} = x_t`, tridiagonal
/// `M`) via `scan`. The final state is compared against repeated `dgesv`
/// solves, and its total mass must lie strictly between 0 and 1.
fn scan_backward_euler_heat_f64() {
    let n = 4usize;
    let length = 5u32;
    let dt = 0.05_f64;

    // Tridiagonal M: `1 + 2*dt` on the diagonal, `-dt` on both neighbours.
    let m_data: Vec<f64> = (0..n * n)
        .map(|idx| {
            let (row, col) = (idx / n, idx % n);
            if row == col {
                1.0 + dt * 2.0
            } else if row.abs_diff(col) == 1 {
                -dt
            } else {
                0.0
            }
        })
        .collect();
    let m_bytes: Vec<u8> = m_data.iter().flat_map(|v| v.to_le_bytes()).collect();

    let mut body = Graph::new("be_body");
    let x = body.input("x", Shape::new(&[n], DType::F64));
    let m = body.add_node(
        Op::Constant { data: m_bytes },
        vec![],
        Shape::new(&[n, n], DType::F64),
    );
    let next = body.dense_solve(m, x, Shape::new(&[n], DType::F64));
    body.set_outputs(vec![next]);

    let mut g = Graph::new("be_outer");
    let init = g.input("x0", Shape::new(&[n], DType::F64));
    let final_x = g.scan(init, body, length);
    g.set_outputs(vec![final_x]);

    // Unit heat in cell 1.
    let init_data: [f64; 4] = [0.0, 1.0, 0.0, 0.0];
    let (sched, mut arena) = prepare_f64(&g, &[(init, &init_data)]);
    execute_thunks(&sched, arena.raw_buf_mut());
    let got = read_arena_f64(&arena, final_x, n);

    let mut ref_x = init_data.to_vec();
    for _ in 0..length {
        let mut lu = m_data.clone();
        crate::blas::dgesv(&mut lu, &mut ref_x, n, 1);
    }
    for (i, (got_i, ref_i)) in got.iter().zip(ref_x.iter()).enumerate() {
        assert!(
            (got_i - ref_i).abs() < 1e-12,
            "got[{i}] = {} ref {}",
            got_i,
            ref_i
        );
    }

    // The boundary columns of M sum to more than 1, so the solve leaks mass.
    let mass: f64 = got.iter().sum();
    assert!(mass > 0.0 && mass < 1.0, "diffusion mass: {mass}");
}
19890
#[test]
/// `dense_solve` with a 3x3 `A` and two right-hand sides (`B` is `[n, k]`,
/// row-major). The residual `A X - B` is verified column by column.
fn dense_solve_f64_multi_rhs_forward() {
    let n = 3usize;
    let k = 2usize;
    let mut g = Graph::new("solve_multi_rhs");
    let a = g.input("A", Shape::new(&[n, n], DType::F64));
    let b = g.input("B", Shape::new(&[n, k], DType::F64));
    let x = g.dense_solve(a, b, Shape::new(&[n, k], DType::F64));
    g.set_outputs(vec![x]);

    let a_data = [2.0, 1.0, 0.0, 1.0, 3.0, 1.0, 0.0, 1.0, 2.0_f64];
    let b_data = [1.0, 4.0, 2.0, -1.0, 3.0, 2.0_f64];
    let (sched, mut arena) = prepare_f64(&g, &[(a, &a_data), (b, &b_data)]);
    execute_thunks(&sched, arena.raw_buf_mut());
    let x_got = read_arena_f64(&arena, x, n * k);

    for c in 0..k {
        for (i, a_row) in a_data.chunks_exact(n).enumerate() {
            // (A X)[i, c], accumulated left to right over j.
            let acc = a_row
                .iter()
                .enumerate()
                .fold(0.0_f64, |sum, (j, &a_ij)| sum + a_ij * x_got[j * k + c]);
            let want = b_data[i * k + c];
            assert!(
                (acc - want).abs() < 1e-10,
                "col {c} row {i}: got {acc} want {want}"
            );
        }
    }
}
19923
#[test]
/// Reverse-mode gradient of `L = Σ X` with `X = A⁻¹ B` (B has several
/// right-hand-side columns). The executed dA / dB are checked against the
/// closed form, and dB[0, 0] additionally against a central finite difference.
fn dense_solve_f64_multi_rhs_gradient() {
    use rlx_opt::autodiff::grad_with_loss;
    let n = 3usize;
    let k = 2usize;

    // Forward graph: X = A⁻¹ B with B of shape [n, k], loss = Σ X.
    let mut g = Graph::new("solve_mrhs_grad");
    let a = g.param("A", Shape::new(&[n, n], DType::F64));
    let b = g.input("B", Shape::new(&[n, k], DType::F64));
    let x = g.dense_solve(a, b, Shape::new(&[n, k], DType::F64));
    let loss = g.reduce(
        x,
        ReduceOp::Sum,
        vec![0, 1],
        false,
        Shape::new(&[1], DType::F64),
    );
    g.set_outputs(vec![loss]);

    let bwd = grad_with_loss(&g, &[a, b]);
    // Resolve nodes in the backward graph by name rather than reusing the
    // forward graph's ids.
    let find_by_name = |graph: &Graph, want: &str| -> NodeId {
        for node in graph.nodes() {
            let name = match &node.op {
                Op::Input { name } | Op::Param { name } => Some(name.as_str()),
                _ => None,
            };
            if name == Some(want) {
                return node.id;
            }
        }
        panic!("no node named {want:?}");
    };
    let a_bwd = find_by_name(&bwd, "A");
    let b_bwd = find_by_name(&bwd, "B");
    let d_out = find_by_name(&bwd, "d_output");

    let a_data = [2.0, 1.0, 0.0, 1.0, 3.0, 1.0, 0.0, 1.0, 2.0_f64];
    let b_data = [1.0, 4.0, 2.0, -1.0, 3.0, 2.0_f64];
    // Upstream seed dL/dL = 1.
    let d_seed = [1.0_f64];

    let (sched, mut arena) = prepare_f64(
        &bwd,
        &[(a_bwd, &a_data), (b_bwd, &b_data), (d_out, &d_seed)],
    );
    execute_thunks(&sched, arena.raw_buf_mut());
    // NOTE(review): assumes outputs[1] / outputs[2] are dA / dB, following the
    // `[a, b]` order passed to grad_with_loss — TODO confirm.
    let da_got = read_arena_f64(&arena, bwd.outputs[1], n * n);
    let db_got = read_arena_f64(&arena, bwd.outputs[2], n * k);

    // Closed-form reference for L = Σ X (with 1 the all-ones [n, k] matrix):
    //   dL/dB = A⁻ᵀ · 1
    //   dL/dA = −(dL/dB) · Xᵀ
    // Every reference solve checks dgesv's status, so a failed factorisation
    // cannot silently yield a bogus reference.
    let mut x_ref = b_data;
    {
        let mut a_copy = a_data;
        let info = crate::blas::dgesv(&mut a_copy, &mut x_ref, n, k);
        assert_eq!(info, 0, "reference solve A X = B failed");
    }
    let mut at = [0.0_f64; 9];
    for i in 0..n {
        for j in 0..n {
            at[i * n + j] = a_data[j * n + i];
        }
    }
    let mut ones_nk = vec![1.0_f64; n * k];
    let info = crate::blas::dgesv(&mut at, &mut ones_nk, n, k);
    assert_eq!(info, 0, "reference solve Aᵀ Y = 1 failed");
    let db_ref = ones_nk;
    let mut da_ref = [0.0_f64; 9];
    for i in 0..n {
        for j in 0..n {
            let mut acc = 0.0_f64;
            for c in 0..k {
                acc += db_ref[i * k + c] * x_ref[j * k + c];
            }
            da_ref[i * n + j] = -acc;
        }
    }
    for i in 0..n * k {
        assert!(
            (db_got[i] - db_ref[i]).abs() < 1e-10,
            "dB[{i}]: got {} want {}",
            db_got[i],
            db_ref[i]
        );
    }
    for i in 0..n * n {
        assert!(
            (da_got[i] - da_ref[i]).abs() < 1e-10,
            "dA[{i}]: got {} want {}",
            da_got[i],
            da_ref[i]
        );
    }

    // Independent check: central finite difference of L w.r.t. B[0, 0].
    let h = 1e-6;
    let mut bp = b_data;
    bp[0] += h;
    let mut bm = b_data;
    bm[0] -= h;
    let xp = {
        let mut a_copy = a_data;
        let info = crate::blas::dgesv(&mut a_copy, &mut bp, n, k);
        assert_eq!(info, 0, "reference solve at B + h·e00 failed");
        bp
    };
    let xm = {
        let mut a_copy = a_data;
        let info = crate::blas::dgesv(&mut a_copy, &mut bm, n, k);
        assert_eq!(info, 0, "reference solve at B - h·e00 failed");
        bm
    };
    let lp: f64 = xp.iter().sum();
    let lm: f64 = xm.iter().sum();
    let fd = (lp - lm) / (2.0 * h);
    assert!(
        (db_got[0] - fd).abs() < 1e-7,
        "FD dB[0,0]: AD={} FD={}",
        db_got[0],
        fd
    );
}
20041
#[test]
/// Forward-mode JVP of `X = A⁻¹ B` w.r.t. a multi-column `B`. The tangent must
/// equal `A⁻¹ · t_B` exactly (X is linear in B) and match a central finite
/// difference along `t_B`.
fn dense_solve_f64_multi_rhs_jvp() {
    use rlx_opt::autodiff_fwd::jvp;
    let n = 3usize;
    let k = 2usize;
    let mut g = Graph::new("solve_mrhs_jvp");
    let a = g.input("A", Shape::new(&[n, n], DType::F64));
    let b = g.input("B", Shape::new(&[n, k], DType::F64));
    let x = g.dense_solve(a, b, Shape::new(&[n, k], DType::F64));
    g.set_outputs(vec![x]);

    let jg = jvp(&g, &[b]);
    // Resolve the JVP graph's inputs by name; the tangent seed is a new node.
    let find_by_name = |graph: &Graph, want: &str| -> NodeId {
        for node in graph.nodes() {
            let name = match &node.op {
                Op::Input { name } | Op::Param { name } => Some(name.as_str()),
                _ => None,
            };
            if name == Some(want) {
                return node.id;
            }
        }
        panic!("no node named {want:?}");
    };
    let a_id = find_by_name(&jg, "A");
    let b_id = find_by_name(&jg, "B");
    let tb_id = find_by_name(&jg, "tangent_B");

    let a_data = [2.0, 1.0, 0.0, 1.0, 3.0, 1.0, 0.0, 1.0, 2.0_f64];
    let b_data = [1.0, 4.0, 2.0, -1.0, 3.0, 2.0_f64];
    let tb_data = [0.5, 0.0, -0.25, 1.0, 1.0, -0.5_f64];

    let (sched, mut arena) =
        prepare_f64(&jg, &[(a_id, &a_data), (b_id, &b_data), (tb_id, &tb_data)]);
    execute_thunks(&sched, arena.raw_buf_mut());
    // NOTE(review): assumes outputs[1] holds the tangent of X (outputs[0] the
    // primal) — TODO confirm.
    let tangent_x = read_arena_f64(&arena, jg.outputs[1], n * k);

    // Closed form: t_X = A⁻¹ · t_B. dgesv's status is checked so a failed
    // factorisation cannot masquerade as a reference.
    let mut a_copy = a_data;
    let mut tb_copy = tb_data;
    let info = crate::blas::dgesv(&mut a_copy, &mut tb_copy, n, k);
    assert_eq!(info, 0, "reference solve A T = tB failed");
    for i in 0..n * k {
        assert!(
            (tangent_x[i] - tb_copy[i]).abs() < 1e-10,
            "t_X[{i}]: AD={} ref={}",
            tangent_x[i],
            tb_copy[i]
        );
    }

    // Central finite difference along t_B.
    let h = 1e-6;
    let mut bp = b_data;
    let mut bm = b_data;
    for i in 0..n * k {
        bp[i] += h * tb_data[i];
        bm[i] -= h * tb_data[i];
    }
    let xp = {
        let mut a_copy = a_data;
        let info = crate::blas::dgesv(&mut a_copy, &mut bp, n, k);
        assert_eq!(info, 0, "reference solve at B + h·tB failed");
        bp
    };
    let xm = {
        let mut a_copy = a_data;
        let info = crate::blas::dgesv(&mut a_copy, &mut bm, n, k);
        assert_eq!(info, 0, "reference solve at B - h·tB failed");
        bm
    };
    for i in 0..n * k {
        let fd = (xp[i] - xm[i]) / (2.0 * h);
        assert!(
            (tangent_x[i] - fd).abs() < 1e-7,
            "FD t_X[{i}]: AD={} FD={}",
            tangent_x[i],
            fd
        );
    }
}
20119
#[test]
/// Forward-mode JVP of `x = A⁻¹ b` w.r.t. `b`, executed end to end. Checks
/// that the tangent equals `A⁻¹ t_b` and matches a central finite difference,
/// and that the primal output equals the direct solve.
fn jvp_dense_solve_b_runs_and_matches_fd() {
    use rlx_opt::autodiff_fwd::jvp;
    let n = 3usize;

    let mut g = Graph::new("jvp_b_e2e");
    let a = g.input("A", Shape::new(&[n, n], DType::F64));
    let b = g.input("b", Shape::new(&[n], DType::F64));
    let x = g.dense_solve(a, b, Shape::new(&[n], DType::F64));
    g.set_outputs(vec![x]);

    let jg = jvp(&g, &[b]);
    // Resolve the JVP graph's inputs by name; the tangent seed is a new node.
    let find_by_name = |graph: &Graph, want: &str| -> NodeId {
        for node in graph.nodes() {
            let name = match &node.op {
                Op::Input { name } | Op::Param { name } => Some(name.as_str()),
                _ => None,
            };
            if name == Some(want) {
                return node.id;
            }
        }
        panic!("no node named {want:?}");
    };
    let a_id = find_by_name(&jg, "A");
    let b_id = find_by_name(&jg, "b");
    let tb_id = find_by_name(&jg, "tangent_b");

    let a_data: [f64; 9] = [2.0, 1.0, 0.0, 1.0, 3.0, 1.0, 0.0, 1.0, 2.0];
    let b_data: [f64; 3] = [1.0, 2.0, 3.0];
    let tb_data: [f64; 3] = [0.5, -0.25, 1.0];

    let (sched, mut arena) =
        prepare_f64(&jg, &[(a_id, &a_data), (b_id, &b_data), (tb_id, &tb_data)]);
    execute_thunks(&sched, arena.raw_buf_mut());

    // Expected layout: outputs[0] = primal x, outputs[1] = its tangent; both
    // are verified against references below.
    let primal_x = read_arena_f64(&arena, jg.outputs[0], n);
    let tangent_x = read_arena_f64(&arena, jg.outputs[1], n);

    // Closed form: x is linear in b, so t_x = A⁻¹ · t_b.
    let t_x_ref = {
        let mut a = a_data;
        let mut tb = tb_data;
        let info = crate::blas::dgesv(&mut a, &mut tb, n, 1);
        assert_eq!(info, 0);
        tb
    };
    for i in 0..n {
        assert!(
            (tangent_x[i] - t_x_ref[i]).abs() < 1e-10,
            "t_x[{i}]: got {} want {}",
            tangent_x[i],
            t_x_ref[i]
        );
    }

    // Central finite difference along t_b.
    let h = 1e-6;
    let mut bp = b_data;
    let mut bm = b_data;
    for i in 0..n {
        bp[i] += h * tb_data[i];
        bm[i] -= h * tb_data[i];
    }
    let xp = {
        let mut a = a_data;
        let info = crate::blas::dgesv(&mut a, &mut bp, n, 1);
        assert_eq!(info, 0);
        bp
    };
    let xm = {
        let mut a = a_data;
        let info = crate::blas::dgesv(&mut a, &mut bm, n, 1);
        assert_eq!(info, 0);
        bm
    };
    let fd: Vec<f64> = (0..n).map(|i| (xp[i] - xm[i]) / (2.0 * h)).collect();
    for i in 0..n {
        assert!(
            (tangent_x[i] - fd[i]).abs() < 1e-7,
            "FD mismatch t_x[{i}]: AD={} FD={}",
            tangent_x[i],
            fd[i]
        );
    }

    // The primal must also be intact. Check the reference solve's status like
    // every other solve in this test.
    let primal_ref = {
        let mut a = a_data;
        let mut b = b_data;
        let info = crate::blas::dgesv(&mut a, &mut b, n, 1);
        assert_eq!(info, 0, "reference primal solve failed");
        b
    };
    for i in 0..n {
        assert!(
            (primal_x[i] - primal_ref[i]).abs() < 1e-10,
            "primal x[{i}]: got {} want {}",
            primal_x[i],
            primal_ref[i]
        );
    }
}
20227
#[test]
/// Forward-mode JVP of `x = A⁻¹ b` w.r.t. `A`, executed end to end. The
/// tangent must match the closed form `t_x = −A⁻¹ (t_A · x)` and a central
/// finite difference along `t_A`.
fn jvp_dense_solve_a_runs_and_matches_fd() {
    use rlx_opt::autodiff_fwd::jvp;
    let n = 3usize;

    let mut g = Graph::new("jvp_a_e2e");
    let a = g.input("A", Shape::new(&[n, n], DType::F64));
    let b = g.input("b", Shape::new(&[n], DType::F64));
    let x = g.dense_solve(a, b, Shape::new(&[n], DType::F64));
    g.set_outputs(vec![x]);

    let jg = jvp(&g, &[a]);
    // Resolve the JVP graph's inputs by name; the tangent seed is a new node.
    let find_by_name = |graph: &Graph, want: &str| -> NodeId {
        for node in graph.nodes() {
            let name = match &node.op {
                Op::Input { name } | Op::Param { name } => Some(name.as_str()),
                _ => None,
            };
            if name == Some(want) {
                return node.id;
            }
        }
        panic!("no node named {want:?}");
    };
    let a_id = find_by_name(&jg, "A");
    let b_id = find_by_name(&jg, "b");
    let ta_id = find_by_name(&jg, "tangent_A");

    let a_data: [f64; 9] = [2.0, 1.0, 0.0, 1.0, 3.0, 1.0, 0.0, 1.0, 2.0];
    let b_data: [f64; 3] = [1.0, 2.0, 3.0];
    let ta_data: [f64; 9] = [0.10, -0.05, 0.02, 0.03, 0.20, -0.04, -0.01, 0.07, 0.15];

    let (sched, mut arena) =
        prepare_f64(&jg, &[(a_id, &a_data), (b_id, &b_data), (ta_id, &ta_data)]);
    execute_thunks(&sched, arena.raw_buf_mut());

    // NOTE(review): assumes outputs[1] holds the tangent of x — TODO confirm.
    let tangent_x = read_arena_f64(&arena, jg.outputs[1], n);

    // Closed form: t_x = −A⁻¹ (t_A · x). Every reference solve checks dgesv's
    // status so a failed factorisation cannot yield a bogus reference.
    let x_ref = {
        let mut a = a_data;
        let mut b = b_data;
        let info = crate::blas::dgesv(&mut a, &mut b, n, 1);
        assert_eq!(info, 0, "reference solve A x = b failed");
        b
    };
    let mut prod = [0.0_f64; 3];
    for i in 0..n {
        for j in 0..n {
            prod[i] += ta_data[i * n + j] * x_ref[j];
        }
    }
    let t_x_ref = {
        let mut a = a_data;
        let mut p = prod;
        let info = crate::blas::dgesv(&mut a, &mut p, n, 1);
        assert_eq!(info, 0, "reference solve A y = t_A x failed");
        [-p[0], -p[1], -p[2]]
    };
    for i in 0..n {
        assert!(
            (tangent_x[i] - t_x_ref[i]).abs() < 1e-10,
            "closed-form t_x[{i}]: AD={} ref={}",
            tangent_x[i],
            t_x_ref[i]
        );
    }

    // Central finite difference along t_A.
    let h = 1e-6;
    let mut ap = a_data;
    let mut am = a_data;
    for i in 0..n * n {
        ap[i] += h * ta_data[i];
        am[i] -= h * ta_data[i];
    }
    let xp = {
        let mut a = ap;
        let mut b = b_data;
        let info = crate::blas::dgesv(&mut a, &mut b, n, 1);
        assert_eq!(info, 0, "reference solve at A + h·t_A failed");
        b
    };
    let xm = {
        let mut a = am;
        let mut b = b_data;
        let info = crate::blas::dgesv(&mut a, &mut b, n, 1);
        assert_eq!(info, 0, "reference solve at A - h·t_A failed");
        b
    };
    for i in 0..n {
        let fd = (xp[i] - xm[i]) / (2.0 * h);
        assert!(
            (tangent_x[i] - fd).abs() < 1e-7,
            "FD t_x[{i}]: AD={} FD={}",
            tangent_x[i],
            fd
        );
    }
}
20330
#[test]
/// Quantised 2-D convolution (i8 activations and weights, i32 bias) must
/// agree bit-for-bit with a naive host convolution that requantises with the
/// same round-then-saturate rule.
fn q_conv2d_matches_reference() {
    use rlx_ir::Philox4x32;
    // Geometry: one image, 2 → 3 channels, 5×5 spatial, 3×3 kernel,
    // stride 1 and padding 1 on both axes.
    let (n, c_in, h, w_in) = (1usize, 2usize, 5usize, 5usize);
    let (c_out, kh, kw) = (3usize, 3usize, 3usize);
    let (ph, pw) = (1usize, 1usize);
    let (sh, sw) = (1usize, 1usize);
    let h_out = (h + 2 * ph - kh) / sh + 1;
    let w_out = (w_in + 2 * pw - kw) / sw + 1;

    // `mult` maps the i32 accumulator straight into the output's i8 domain.
    let x_scale = 0.04f32;
    let w_scale = 0.02f32;
    let out_scale = 0.5f32;
    let mult = x_scale * w_scale / out_scale;

    // Round-to-nearest then saturate into i8.
    let quantize = |v: f32, scale: f32| ((v / scale).round() as i32).clamp(-128, 127) as i8;
    let mut rng = Philox4x32::new(2099);
    let mut xf = vec![0f32; n * c_in * h * w_in];
    rng.fill_normal(&mut xf);
    let mut wf = vec![0f32; c_out * c_in * kh * kw];
    rng.fill_normal(&mut wf);
    let xq: Vec<i8> = xf.iter().map(|&v| quantize(v, x_scale)).collect();
    let wq: Vec<i8> = wf.iter().map(|&v| quantize(v, w_scale)).collect();
    let bias = vec![0i32; c_out];

    let mut g = Graph::new("qconv");
    let xn = g.input("x", Shape::new(&[n, c_in, h, w_in], DType::I8));
    let wn = g.input("w", Shape::new(&[c_out, c_in, kh, kw], DType::I8));
    let bn = g.input("b", Shape::new(&[c_out], DType::I32));
    // NOTE(review): the reference below applies no zero-point offsets; this
    // assumes the three `0` arguments are the zero points — TODO confirm.
    let out = g.q_conv2d(
        xn,
        wn,
        bn,
        vec![kh, kw],
        vec![sh, sw],
        vec![ph, pw],
        vec![1, 1],
        1,
        0,
        0,
        0,
        mult,
        Shape::new(&[n, c_out, h_out, w_out], DType::I8),
    );
    g.set_outputs(vec![out]);

    let plan = rlx_opt::memory::plan_memory(&g);
    let mut arena = crate::arena::Arena::from_plan(plan);
    let sched = compile_thunks(&g, &arena);
    let xn_off = arena.byte_offset(xn);
    let wn_off = arena.byte_offset(wn);
    let bn_off = arena.byte_offset(bn);
    let out_off = arena.byte_offset(out);
    let buf = arena.raw_buf_mut();
    // SAFETY: each offset is the planned location of a node whose shape holds
    // exactly as many elements as the host vector copied into it.
    unsafe {
        let base = buf.as_mut_ptr();
        std::ptr::copy_nonoverlapping(xq.as_ptr(), base.add(xn_off) as *mut i8, xq.len());
        std::ptr::copy_nonoverlapping(wq.as_ptr(), base.add(wn_off) as *mut i8, wq.len());
        std::ptr::copy_nonoverlapping(bias.as_ptr(), base.add(bn_off) as *mut i32, bias.len());
    }
    execute_thunks(&sched, arena.raw_buf_mut());
    let out_len = n * c_out * h_out * w_out;
    // SAFETY: `out_off` is the planned location of the i8 output of `out_len` elements.
    let out_q: Vec<i8> = unsafe {
        let src = arena.raw_buf().as_ptr().add(out_off) as *const i8;
        std::slice::from_raw_parts(src, out_len).to_vec()
    };

    let mut out_ref = vec![0i8; out_len];
    for ni in 0..n {
        for co in 0..c_out {
            for ho in 0..h_out {
                for wo in 0..w_out {
                    let mut acc: i32 = 0;
                    for ci in 0..c_in {
                        for ki in 0..kh {
                            for kj in 0..kw {
                                // Map the tap to unpadded input coordinates;
                                // taps landing in the padding contribute nothing.
                                let coords = (
                                    (ho * sh + ki).checked_sub(ph),
                                    (wo * sw + kj).checked_sub(pw),
                                );
                                if let (Some(hi), Some(wi)) = coords {
                                    if hi < h && wi < w_in {
                                        let xv = i32::from(
                                            xq[((ni * c_in) + ci) * h * w_in + hi * w_in + wi],
                                        );
                                        let wv = i32::from(
                                            wq[((co * c_in) + ci) * kh * kw + ki * kw + kj],
                                        );
                                        acc += xv * wv;
                                    }
                                }
                            }
                        }
                    }
                    let requantized = ((acc as f32 * mult).round() as i32).clamp(-128, 127) as i8;
                    out_ref[((ni * c_out) + co) * h_out * w_out + ho * w_out + wo] = requantized;
                }
            }
        }
    }

    for (i, (a, r)) in out_q.iter().zip(out_ref.iter()).enumerate() {
        assert_eq!(a, r, "q_conv2d[{i}]: kernel {a} vs reference {r}");
    }
}
20464
#[test]
/// Quantised matmul (i8 × i8 → i32 accumulator → requantised i8) must agree
/// bit-for-bit with a host reference using the same round-then-saturate rule.
fn q_matmul_matches_fake_quant_reference() {
    use rlx_ir::Philox4x32;
    let m = 3usize;
    let k = 8usize;
    let n = 5usize;
    let mut rng = Philox4x32::new(2031);

    // `mult` folds the input, weight and output scales into one factor that
    // maps the i32 accumulator straight into the output's i8 domain.
    let x_scale = 0.05f32;
    let w_scale = 0.03f32;
    let out_scale = 0.4f32;
    let mult = x_scale * w_scale / out_scale;
    let mut xf = vec![0f32; m * k];
    rng.fill_normal(&mut xf);
    let mut wf = vec![0f32; k * n];
    rng.fill_normal(&mut wf);
    let xq: Vec<i8> = xf
        .iter()
        .map(|&v| ((v / x_scale).round() as i32).clamp(-128, 127) as i8)
        .collect();
    let wq: Vec<i8> = wf
        .iter()
        .map(|&v| ((v / w_scale).round() as i32).clamp(-128, 127) as i8)
        .collect();
    let bias: Vec<i32> = vec![0i32; n];

    let mut g_q = Graph::new("qmm_direct");
    let xn = g_q.input("x", Shape::new(&[m, k], DType::I8));
    let wn = g_q.input("w", Shape::new(&[k, n], DType::I8));
    let bn = g_q.input("b", Shape::new(&[n], DType::I32));
    // NOTE(review): the reference applies no zero-point offsets; this assumes
    // the three `0` arguments are the zero points — TODO confirm.
    let out = g_q.q_matmul(xn, wn, bn, 0, 0, 0, mult, Shape::new(&[m, n], DType::I8));
    g_q.set_outputs(vec![out]);
    let plan = rlx_opt::memory::plan_memory(&g_q);
    let mut arena = crate::arena::Arena::from_plan(plan);
    let sched = compile_thunks(&g_q, &arena);

    let xn_off = arena.byte_offset(xn);
    let wn_off = arena.byte_offset(wn);
    let bn_off = arena.byte_offset(bn);
    let out_off = arena.byte_offset(out);
    let buf = arena.raw_buf_mut();
    // SAFETY: each offset is the planned location of a node whose shape holds
    // exactly as many elements as the host vector written into it.
    unsafe {
        let p = buf.as_mut_ptr().add(xn_off) as *mut i8;
        for (i, &v) in xq.iter().enumerate() {
            *p.add(i) = v;
        }
        let p = buf.as_mut_ptr().add(wn_off) as *mut i8;
        for (i, &v) in wq.iter().enumerate() {
            *p.add(i) = v;
        }
        let p = buf.as_mut_ptr().add(bn_off) as *mut i32;
        for (i, &v) in bias.iter().enumerate() {
            *p.add(i) = v;
        }
    }
    execute_thunks(&sched, arena.raw_buf_mut());
    // SAFETY: `out_off` is the planned location of the [m, n] i8 output.
    let out_q: Vec<i8> = unsafe {
        let p = arena.raw_buf().as_ptr().add(out_off) as *const i8;
        (0..m * n).map(|i| *p.add(i)).collect()
    };

    // Host reference: exact i32 dot products, then round and saturate.
    let mut out_ref = vec![0i8; m * n];
    for mi in 0..m {
        for ni in 0..n {
            let mut acc: i32 = 0;
            for ki in 0..k {
                acc += (xq[mi * k + ki] as i32) * (wq[ki * n + ni] as i32);
            }
            let r = (acc as f32 * mult).round() as i32;
            out_ref[mi * n + ni] = r.clamp(-128, 127) as i8;
        }
    }

    for (i, (a, r)) in out_q.iter().zip(&out_ref).enumerate() {
        assert_eq!(a, r, "q_matmul[{i}]: kernel {a} vs reference {r}");
    }
}
20557
#[test]
/// Per-tensor quantize → dequantize round trip with a non-zero zero point.
/// Out-of-range inputs must land on the saturation bounds; in-range inputs
/// must come back within one quantisation step.
fn quantize_dequantize_round_trip() {
    use rlx_ir::Philox4x32;
    let len = 64;
    let scale = 0.05f32;
    let zp = 3i32;

    // Standard-normal samples. The first two are pushed far outside the
    // representable range so they must saturate.
    let mut x = vec![0f32; len];
    let mut rng = Philox4x32::new(2027);
    rng.fill_normal(&mut x);
    x[0] = 999.0;
    x[1] = -999.0;

    let f = DType::F32;
    let mut g = Graph::new("qdq");
    let xn = g.input("x", Shape::new(&[len], f));
    let dq = {
        let q = g.quantize(xn, scale, zp);
        g.dequantize(q, scale, zp)
    };
    g.set_outputs(vec![dq]);

    let plan = rlx_opt::memory::plan_memory(&g);
    let mut arena = crate::arena::Arena::from_plan(plan);
    let sched = compile_thunks(&g, &arena);
    let xn_off = arena.byte_offset(xn);
    let dq_off = arena.byte_offset(dq);
    // SAFETY: `xn_off` is the planned location of the [len] f32 input.
    unsafe {
        let dst = arena.raw_buf_mut().as_mut_ptr().add(xn_off) as *mut f32;
        std::ptr::copy_nonoverlapping(x.as_ptr(), dst, x.len());
    }
    execute_thunks(&sched, arena.raw_buf_mut());
    // SAFETY: `dq_off` is the planned location of the [len] f32 output.
    let out: Vec<f32> = unsafe {
        let src = arena.raw_buf().as_ptr().add(dq_off) as *const f32;
        std::slice::from_raw_parts(src, len).to_vec()
    };

    // Saturated codes 127 / -128, shifted by the zero point and rescaled.
    let sat_pos = (127 - zp) as f32 * scale;
    let sat_neg = (-128 - zp) as f32 * scale;
    assert!((out[0] - sat_pos).abs() < 1e-6, "+sat: {}", out[0]);
    assert!((out[1] - sat_neg).abs() < 1e-6, "-sat: {}", out[1]);

    for (i, (&orig, &back)) in x.iter().zip(&out).enumerate().skip(2) {
        assert!(
            (back - orig).abs() <= scale + 1e-5,
            "qdq[{i}]: {} → {}, scale={scale}",
            orig,
            back
        );
    }
}
20620
#[test]
/// Per-channel (axis 0) quantize → dequantize round trip with channel scales
/// spanning four orders of magnitude. In-range values must survive to within
/// one step of their own channel's scale, and out-of-range values must land
/// on that channel's saturation bounds.
fn quantize_per_channel_round_trip() {
    let c = 4usize;
    let inner = 5usize;
    let mags = [0.01f32, 0.5, 5.0, 50.0];

    // Each channel row is [-m, 0, m, 1000·m, -1000·m]: three representable
    // values followed by one that saturates high and one that saturates low.
    let mut x = vec![0f32; c * inner];
    for (ci, row) in x.chunks_mut(inner).enumerate() {
        let m = mags[ci];
        for (ii, slot) in row.iter_mut().enumerate() {
            *slot = match ii {
                0 => -m,
                1 => 0.0,
                2 => m,
                3 => m * 1000.0,
                _ => -m * 1000.0,
            };
        }
    }
    // Scale each channel so its magnitude maps to code 127.
    let scales: Vec<f32> = mags.iter().map(|&m| m / 127.0).collect();
    let zps = vec![0i32; c];

    let f = DType::F32;
    let mut g = Graph::new("qdq_pc");
    let xn = g.input("x", Shape::new(&[c, inner], f));
    let q = g.quantize_per_channel(xn, 0, scales.clone(), zps.clone());
    let dq = g.dequantize_per_channel(q, 0, scales.clone(), zps);
    g.set_outputs(vec![dq]);

    let plan = rlx_opt::memory::plan_memory(&g);
    let mut arena = crate::arena::Arena::from_plan(plan);
    let sched = compile_thunks(&g, &arena);
    let xn_off = arena.byte_offset(xn);
    let dq_off = arena.byte_offset(dq);
    // SAFETY: `xn_off` is the planned location of the [c, inner] f32 input.
    unsafe {
        let dst = arena.raw_buf_mut().as_mut_ptr().add(xn_off) as *mut f32;
        std::ptr::copy_nonoverlapping(x.as_ptr(), dst, x.len());
    }
    execute_thunks(&sched, arena.raw_buf_mut());
    // SAFETY: `dq_off` is the planned location of the [c, inner] f32 output.
    let out: Vec<f32> = unsafe {
        let src = arena.raw_buf().as_ptr().add(dq_off) as *const f32;
        std::slice::from_raw_parts(src, c * inner).to_vec()
    };

    for (ci, (row_in, row_out)) in x.chunks(inner).zip(out.chunks(inner)).enumerate() {
        let step = scales[ci];
        for ii in 0..3 {
            assert!(
                (row_out[ii] - row_in[ii]).abs() <= step + 1e-5,
                "ch {ci} idx {ii}: {} vs {}",
                row_in[ii],
                row_out[ii]
            );
        }
        let sat_pos = 127.0 * step;
        let sat_neg = -128.0 * step;
        assert!(
            (row_out[3] - sat_pos).abs() < 1e-5,
            "ch {ci} +sat: {}",
            row_out[3]
        );
        assert!(
            (row_out[4] - sat_neg).abs() < 1e-5,
            "ch {ci} -sat: {}",
            row_out[4]
        );
    }
}
20703
#[test]
/// For each listed activation, the executed `activation_backward(kind, x, dy)`
/// must match `dy · f'(x)` estimated by a central finite difference of a host
/// reference of `f`.
fn activation_backward_matches_numerical_per_kind() {
    use rlx_ir::Philox4x32;
    use rlx_ir::op::Activation;
    let len = 32;
    let mut rng = Philox4x32::new(91);

    // Inputs ≥ 0.5 for kinds that are only defined for x > 0 (log, sqrt, rsqrt).
    let mut x_pos = vec![0f32; len];
    rng.fill_normal(&mut x_pos);
    x_pos.iter_mut().for_each(|v| *v = v.abs() + 0.5);
    // Unrestricted inputs and the upstream gradient.
    let mut x_any = vec![0f32; len];
    rng.fill_normal(&mut x_any);
    let mut dy = vec![0f32; len];
    rng.fill_normal(&mut dy);

    // Host reference for each forward activation. It does not depend on the
    // loop below, so it is built once.
    let act_apply = |act: Activation, x: f32| -> f32 {
        match act {
            Activation::Sigmoid => 1.0 / (1.0 + (-x).exp()),
            Activation::Tanh => x.tanh(),
            Activation::Silu => x / (1.0 + (-x).exp()),
            Activation::Gelu => {
                const INV_SQRT2: f32 = 0.707_106_77;
                0.5 * x * (1.0 + erf_f32(x * INV_SQRT2))
            }
            Activation::GeluApprox => {
                const C: f32 = 0.797_884_6;
                const A: f32 = 0.044_715;
                let inner = C * (x + A * x * x * x);
                0.5 * x * (1.0 + inner.tanh())
            }
            Activation::Exp => x.exp(),
            Activation::Log => x.ln(),
            Activation::Sqrt => x.sqrt(),
            Activation::Rsqrt => 1.0 / x.sqrt(),
            Activation::Neg => -x,
            Activation::Relu => x.max(0.0),
            Activation::Abs => x.abs(),
            Activation::Round => x.round(),
            Activation::Sin => x.sin(),
            Activation::Cos => x.cos(),
            Activation::Tan => x.tan(),
            Activation::Atan => x.atan(),
        }
    };

    // (kind, inputs, finite-difference step, tolerance)
    let cases: [(Activation, &[f32], f32, f32); 10] = [
        (Activation::Sigmoid, &x_any[..], 1e-3, 5e-3),
        (Activation::Tanh, &x_any[..], 1e-3, 5e-3),
        (Activation::Silu, &x_any[..], 1e-3, 5e-3),
        (Activation::Gelu, &x_any[..], 1e-3, 5e-3),
        (Activation::GeluApprox, &x_any[..], 1e-3, 5e-3),
        (Activation::Exp, &x_any[..], 1e-4, 5e-3),
        (Activation::Log, &x_pos[..], 1e-4, 5e-3),
        (Activation::Sqrt, &x_pos[..], 1e-4, 5e-3),
        (Activation::Rsqrt, &x_pos[..], 1e-4, 5e-3),
        (Activation::Neg, &x_any[..], 1e-3, 5e-4),
    ];
    for &(kind, x_data, eps, tol) in cases.iter() {
        let f = DType::F32;
        let mut g = Graph::new("act_bw");
        let x_node = g.input("x", Shape::new(&[len], f));
        let dy_node = g.input("dy", Shape::new(&[len], f));
        let dx = g.activation_backward(kind, x_node, dy_node);
        g.set_outputs(vec![dx]);

        let plan = rlx_opt::memory::plan_memory(&g);
        let mut arena = crate::arena::Arena::from_plan(plan);
        let sched = compile_thunks(&g, &arena);

        let x_off = arena.byte_offset(x_node);
        let dy_off = arena.byte_offset(dy_node);
        let dx_off = arena.byte_offset(dx);
        let buf = arena.raw_buf_mut();
        // SAFETY: both offsets are planned locations of [len] f32 inputs, and
        // both host slices hold exactly `len` values.
        unsafe {
            let base = buf.as_mut_ptr();
            std::ptr::copy_nonoverlapping(x_data.as_ptr(), base.add(x_off) as *mut f32, x_data.len());
            std::ptr::copy_nonoverlapping(dy.as_ptr(), base.add(dy_off) as *mut f32, dy.len());
        }
        execute_thunks(&sched, arena.raw_buf_mut());
        // SAFETY: `dx_off` is the planned location of the [len] f32 output.
        let analytical: Vec<f32> = unsafe {
            let src = arena.raw_buf().as_ptr().add(dx_off) as *const f32;
            std::slice::from_raw_parts(src, len).to_vec()
        };

        for (i, (&xv, &upstream)) in x_data.iter().zip(&dy).enumerate() {
            let plus = act_apply(kind, xv + eps);
            let minus = act_apply(kind, xv - eps);
            let num = (plus - minus) / (2.0 * eps) * upstream;
            assert!(
                (analytical[i] - num).abs() < tol,
                "{kind:?}[{i}]: analytical {} vs numerical {num}",
                analytical[i]
            );
        }
    }
}
20817
#[test]
/// Reverse-mode gradient of `Σ (A @ B)` w.r.t. `B` for batched (rank-3)
/// matmul, checked against central finite differences of a host matmul.
fn matmul_3d_gradient_matches_numerical() {
    use rlx_ir::Philox4x32;
    let batch = 2usize;
    let m = 3usize;
    let k = 4usize;
    let n_dim = 5usize;
    let mut rng = Philox4x32::new(101);
    let mut a_data = vec![0f32; batch * m * k];
    rng.fill_normal(&mut a_data);
    let mut b_data = vec![0f32; batch * k * n_dim];
    rng.fill_normal(&mut b_data);

    let f = DType::F32;
    let mut fwd = Graph::new("matmul_3d");
    let an = fwd.input("a", Shape::new(&[batch, m, k], f));
    let bp = fwd.param("b", Shape::new(&[batch, k, n_dim], f));
    let mm = fwd.matmul(an, bp, Shape::new(&[batch, m, n_dim], f));
    let loss = fwd.add_node(
        Op::Reduce {
            op: ReduceOp::Sum,
            axes: vec![0, 1, 2],
            keep_dim: false,
        },
        vec![mm],
        Shape::from_dims(&[], f),
    );
    fwd.set_outputs(vec![loss]);

    let bwd_graph = rlx_opt::autodiff::grad_with_loss(&fwd, &[bp]);
    let d_out = bwd_graph
        .nodes()
        .iter()
        .find(|node| matches!(&node.op, Op::Input { name } if name == "d_output"))
        .map(|node| node.id)
        .unwrap();

    let plan = rlx_opt::memory::plan_memory(&bwd_graph);
    let mut arena = crate::arena::Arena::from_plan(plan);
    let sched = compile_thunks(&bwd_graph, &arena);
    // NOTE(review): `an` / `bp` are forward-graph ids reused in the backward
    // graph; this assumes grad_with_loss preserves them — TODO confirm.
    let seed = vec![1.0f32];
    let feeds = [(an, &a_data), (bp, &b_data), (d_out, &seed)];
    for &(id, data) in feeds.iter() {
        let off = arena.byte_offset(id);
        let buf = arena.raw_buf_mut();
        // SAFETY: `off` is the planned location of a node sized for `data`.
        unsafe {
            std::ptr::copy_nonoverlapping(
                data.as_ptr(),
                buf.as_mut_ptr().add(off) as *mut f32,
                data.len(),
            );
        }
    }
    execute_thunks(&sched, arena.raw_buf_mut());
    let gb_id = bwd_graph.outputs[1];
    // SAFETY: `gb_id` is the planned [batch, k, n] f32 gradient of `b`.
    let g_b: Vec<f32> = unsafe {
        let src = arena.raw_buf().as_ptr().add(arena.byte_offset(gb_id)) as *const f32;
        std::slice::from_raw_parts(src, batch * k * n_dim).to_vec()
    };

    // Host Σ (A @ B); elements are visited in row-major (batch, row, col) order.
    let forward_loss = |b_vals: &[f32]| -> f32 {
        (0..batch * m * n_dim)
            .map(|flat| {
                let (bi, rem) = (flat / (m * n_dim), flat % (m * n_dim));
                let (mi, ni) = (rem / n_dim, rem % n_dim);
                (0..k).fold(0f32, |acc, ki| {
                    acc + a_data[bi * m * k + mi * k + ki] * b_vals[bi * k * n_dim + ki * n_dim + ni]
                })
            })
            .sum()
    };
    let eps = 1e-3f32;
    let mut probe = b_data.clone();
    let g_b_num: Vec<f32> = (0..b_data.len())
        .map(|i| {
            let saved = probe[i];
            probe[i] = saved + eps;
            let lp = forward_loss(&probe);
            probe[i] = saved - eps;
            let lm = forward_loss(&probe);
            probe[i] = saved;
            (lp - lm) / (2.0 * eps)
        })
        .collect();
    for (i, (a, n)) in g_b.iter().zip(&g_b_num).enumerate() {
        assert!(
            (a - n).abs() < 5e-3,
            "matmul_3d g_b[{i}]: analytical {a} vs numerical {n}"
        );
    }
}
20914
#[test]
/// Reverse-mode gradient of a softmax (last axis) checked against central
/// finite differences.
///
/// The loss is `Σ_r Σ_j softmax(x)[r, j] · w[j]` with non-constant `w`. An
/// unweighted `Σ softmax(x)` is constant (every row sums to 1), so its
/// gradient is identically zero and a backward kernel that returned all zeros
/// would pass. Weighting gives `dL/dx[r, j] = s[r, j] · (w[j] − Σ_i s[r, i] w[i])`,
/// which is non-trivial.
fn softmax_gradient_matches_numerical() {
    use rlx_ir::Philox4x32;
    let n = 3usize;
    let c = 5usize;
    let mut rng = Philox4x32::new(57);
    let mut x_data = vec![0f32; n * c];
    rng.fill_normal(&mut x_data);
    // Distinct per-class weights: [-2, -1, 0, 1, 2].
    let w_data: Vec<f32> = (0..c).map(|j| j as f32 - 2.0).collect();

    let f = DType::F32;
    let mut fwd = Graph::new("softmax_only");
    let xn = fwd.input("x", Shape::new(&[n, c], f));
    let wn = fwd.input("w", Shape::new(&[c, 1], f));
    let sm = fwd.add_node(Op::Softmax { axis: -1 }, vec![xn], Shape::new(&[n, c], f));
    let weighted = fwd.matmul(sm, wn, Shape::new(&[n, 1], f));
    let loss = fwd.add_node(
        Op::Reduce {
            op: ReduceOp::Sum,
            axes: vec![0, 1],
            keep_dim: false,
        },
        vec![weighted],
        Shape::from_dims(&[], f),
    );
    fwd.set_outputs(vec![loss]);

    let bwd_graph = rlx_opt::autodiff::grad_with_loss(&fwd, &[xn]);
    let d_out = bwd_graph
        .nodes()
        .iter()
        .find(|n| matches!(&n.op, Op::Input { name } if name == "d_output"))
        .map(|n| n.id)
        .unwrap();

    let plan = rlx_opt::memory::plan_memory(&bwd_graph);
    let mut arena = crate::arena::Arena::from_plan(plan);
    let sched = compile_thunks(&bwd_graph, &arena);
    // NOTE(review): `xn` / `wn` are forward-graph ids reused in the backward
    // graph; this assumes grad_with_loss preserves them — TODO confirm.
    for &(id, data) in &[(xn, &x_data), (wn, &w_data), (d_out, &vec![1.0f32])] {
        let off = arena.byte_offset(id);
        let buf = arena.raw_buf_mut();
        // SAFETY: `off` is the planned location of a node sized for `data`.
        unsafe {
            let p = buf.as_mut_ptr().add(off) as *mut f32;
            for (i, &v) in data.iter().enumerate() {
                *p.add(i) = v;
            }
        }
    }
    execute_thunks(&sched, arena.raw_buf_mut());
    let g_x_id = bwd_graph.outputs[1];
    // SAFETY: `g_x_id` is the planned [n, c] f32 gradient of `x`.
    let g_x: Vec<f32> = unsafe {
        let p = arena.raw_buf().as_ptr().add(arena.byte_offset(g_x_id)) as *const f32;
        (0..n * c).map(|i| *p.add(i)).collect()
    };

    // Guard against the degenerate case this test exists to rule out.
    assert!(
        g_x.iter().any(|v| v.abs() > 1e-3),
        "softmax g_x is all ~0; the weighted loss must give a non-trivial gradient"
    );

    // Host reference: numerically stable row softmax contracted with `w`.
    let forward_loss = |x: &[f32]| -> f32 {
        let mut total = 0f32;
        for row in x.chunks(c) {
            let m = row.iter().fold(f32::NEG_INFINITY, |a, &v| a.max(v));
            let denom: f32 = row.iter().map(|&v| (v - m).exp()).sum();
            for (&v, &wj) in row.iter().zip(&w_data) {
                total += (v - m).exp() / denom * wj;
            }
        }
        total
    };
    let eps = 1e-3f32;
    let mut p = x_data.clone();
    for i in 0..x_data.len() {
        let s = p[i];
        p[i] = s + eps;
        let lp = forward_loss(&p);
        p[i] = s - eps;
        let lm = forward_loss(&p);
        p[i] = s;
        let num = (lp - lm) / (2.0 * eps);
        assert!(
            (g_x[i] - num).abs() < 5e-3,
            "softmax g_x[{i}]: analytical {} vs numerical {num}",
            g_x[i]
        );
    }
}
21010
#[test]
/// Reverse-mode gradients of `Σ LayerNorm(x; γ, β)` (last axis) w.r.t. x, γ
/// and β, each checked against central finite differences of a host layer norm.
fn layer_norm_gradient_matches_numerical() {
    use rlx_ir::Philox4x32;

    /// Central difference of `loss_at` w.r.t. element `i` of `base`.
    fn central_diff(base: &[f32], i: usize, step: f32, loss_at: impl Fn(&[f32]) -> f32) -> f32 {
        let mut probe = base.to_vec();
        probe[i] = base[i] + step;
        let lp = loss_at(&probe);
        probe[i] = base[i] - step;
        let lm = loss_at(&probe);
        (lp - lm) / (2.0 * step)
    }

    let rows = 3usize;
    let h = 6usize;
    let mut rng = Philox4x32::new(1009);
    let mut x_data = vec![0f32; rows * h];
    rng.fill_normal(&mut x_data);
    // γ kept ≥ 0.5 and non-constant so dL/dx is not trivially zero.
    let mut g_data = vec![0f32; h];
    rng.fill_normal(&mut g_data);
    g_data.iter_mut().for_each(|v| *v = v.abs() + 0.5);
    let mut b_data = vec![0f32; h];
    rng.fill_normal(&mut b_data);
    let eps = 1e-5f32;

    let f = DType::F32;
    let mut fwd = Graph::new("ln_only");
    let xn = fwd.input("x", Shape::new(&[rows, h], f));
    let gp = fwd.param("gamma", Shape::new(&[h], f));
    let bp = fwd.param("beta", Shape::new(&[h], f));
    let ln = fwd.add_node(
        Op::LayerNorm { axis: -1, eps },
        vec![xn, gp, bp],
        Shape::new(&[rows, h], f),
    );
    let loss = fwd.add_node(
        Op::Reduce {
            op: ReduceOp::Sum,
            axes: vec![0, 1],
            keep_dim: false,
        },
        vec![ln],
        Shape::from_dims(&[], f),
    );
    fwd.set_outputs(vec![loss]);

    let bwd_graph = rlx_opt::autodiff::grad_with_loss(&fwd, &[xn, gp, bp]);
    let d_out = bwd_graph
        .nodes()
        .iter()
        .find(|node| matches!(&node.op, Op::Input { name } if name == "d_output"))
        .map(|node| node.id)
        .unwrap();

    let plan = rlx_opt::memory::plan_memory(&bwd_graph);
    let mut arena = crate::arena::Arena::from_plan(plan);
    let sched = compile_thunks(&bwd_graph, &arena);
    // NOTE(review): forward-graph ids are reused in the backward graph; this
    // assumes grad_with_loss preserves them — TODO confirm.
    let seed = vec![1.0f32];
    let feeds = [(xn, &x_data), (gp, &g_data), (bp, &b_data), (d_out, &seed)];
    for &(id, data) in feeds.iter() {
        let off = arena.byte_offset(id);
        let buf = arena.raw_buf_mut();
        // SAFETY: `off` is the planned location of a node sized for `data`.
        unsafe {
            std::ptr::copy_nonoverlapping(
                data.as_ptr(),
                buf.as_mut_ptr().add(off) as *mut f32,
                data.len(),
            );
        }
    }
    execute_thunks(&sched, arena.raw_buf_mut());
    let read = |id: NodeId, count: usize| -> Vec<f32> {
        let off = arena.byte_offset(id);
        // SAFETY: `id` is a planned f32 node of at least `count` elements.
        unsafe {
            let src = arena.raw_buf().as_ptr().add(off) as *const f32;
            std::slice::from_raw_parts(src, count).to_vec()
        }
    };
    // NOTE(review): assumes outputs[1..=3] follow the `[xn, gp, bp]` order — TODO confirm.
    let dx_a = read(bwd_graph.outputs[1], rows * h);
    let dg_a = read(bwd_graph.outputs[2], h);
    let db_a = read(bwd_graph.outputs[3], h);

    // Host Σ over all rows of ((x − mean) / sqrt(var + eps)) · γ + β.
    let forward_loss = |x: &[f32], g: &[f32], b: &[f32]| -> f32 {
        let mut total = 0f32;
        for row in x.chunks(h) {
            let mean = row.iter().sum::<f32>() / h as f32;
            let var = row.iter().map(|&v| (v - mean) * (v - mean)).sum::<f32>() / h as f32;
            let inv_std = 1.0 / (var + eps).sqrt();
            for ((&v, &gd), &bd) in row.iter().zip(g).zip(b) {
                total += ((v - mean) * inv_std) * gd + bd;
            }
        }
        total
    };
    let h_eps = 1e-3f32;

    for i in 0..x_data.len() {
        let num = central_diff(&x_data, i, h_eps, |xs| forward_loss(xs, &g_data, &b_data));
        assert!(
            (dx_a[i] - num).abs() < 5e-3,
            "ln dx[{i}]: analytical {} vs numerical {num}",
            dx_a[i]
        );
    }
    for i in 0..g_data.len() {
        let num = central_diff(&g_data, i, h_eps, |gs| forward_loss(&x_data, gs, &b_data));
        assert!(
            (dg_a[i] - num).abs() < 5e-3,
            "ln dg[{i}]: analytical {} vs numerical {num}",
            dg_a[i]
        );
    }
    for i in 0..b_data.len() {
        let num = central_diff(&b_data, i, h_eps, |bs| forward_loss(&x_data, &g_data, bs));
        assert!(
            (db_a[i] - num).abs() < 5e-3,
            "ln db[{i}]: analytical {} vs numerical {num}",
            db_a[i]
        );
    }
}
21152
/// Finite-difference check of `grad_with_loss` for a dense layer (`x @ w + b`)
/// feeding softmax cross-entropy, with the per-row losses combined by
/// `ReduceOp::Sum`. Despite the "mean" in the name, this graph sums; the
/// `ReduceOp::Mean` variant is exercised by a separate test.
///
/// Checks two things:
/// 1. The loss emitted by the backward graph equals the loss computed by an
///    independent forward-only executor (within `1e-4`).
/// 2. The analytical gradients w.r.t. `w` and `b` match central differences
///    (within `5e-3`).
#[test]
fn dense_sce_mean_gradient_matches_numerical() {
    use rlx_ir::Philox4x32;

    let (bs, k_in, c) = (4usize, 3usize, 5usize);
    let mut rng = Philox4x32::new(7);
    // Draw order (x, then w, then b) determines the values; keep it fixed.
    let mut x = vec![0f32; bs * k_in];
    rng.fill_normal(&mut x);
    let mut w_init = vec![0f32; k_in * c];
    rng.fill_normal(&mut w_init);
    let mut b_init = vec![0f32; c];
    rng.fill_normal(&mut b_init);
    // Class indices stored as f32, cycling through the classes row by row.
    let labels: Vec<f32> = (0..bs).map(|row| (row % c) as f32).collect();

    let dt = DType::F32;
    let mut fwd = Graph::new("dense_sce");
    let xn = fwd.input("x", Shape::new(&[bs, k_in], dt));
    let lb = fwd.input("labels", Shape::new(&[bs], dt));
    let wp = fwd.param("w", Shape::new(&[k_in, c], dt));
    let bp = fwd.param("b", Shape::new(&[c], dt));
    let logits_shape = Shape::new(&[bs, c], dt);
    let mm = fwd.matmul(xn, wp, logits_shape.clone());
    let logits = fwd.binary(BinaryOp::Add, mm, bp, logits_shape);
    let loss_per = fwd.softmax_cross_entropy_with_logits(logits, lb);
    let sum_rows = Op::Reduce {
        op: ReduceOp::Sum,
        axes: vec![0],
        keep_dim: false,
    };
    let loss = fwd.add_node(sum_rows, vec![loss_per], Shape::from_dims(&[], dt));
    fwd.set_outputs(vec![loss]);

    // Backward graph outputs are [loss, dL/dw, dL/db].
    let bwd_graph = rlx_opt::autodiff::grad_with_loss(&fwd, &[wp, bp]);
    let d_out = bwd_graph
        .nodes()
        .iter()
        .find_map(|node| match &node.op {
            Op::Input { name } if name == "d_output" => Some(node.id),
            _ => None,
        })
        .expect("d_output input");

    let (sched, mut arena) = prepare(
        &bwd_graph,
        &[
            (xn, &x),
            (lb, &labels),
            (wp, &w_init),
            (bp, &b_init),
            (d_out, &[1.0]),
        ],
    );
    execute_thunks(&sched, arena.raw_buf_mut());

    let loss_actual = read_arena(&arena, bwd_graph.outputs[0], 1)[0];
    let gw_actual = read_arena(&arena, bwd_graph.outputs[1], k_in * c);
    let gb_actual = read_arena(&arena, bwd_graph.outputs[2], c);

    // Independent forward-only executor used as the finite-difference oracle.
    let mut fwd_arena = crate::arena::Arena::from_plan(rlx_opt::memory::plan_memory(&fwd));
    let fwd_sched = compile_thunks(&fwd, &fwd_arena);
    write_arena(&mut fwd_arena, xn, &x);
    write_arena(&mut fwd_arena, lb, &labels);

    let run_loss = |arena: &mut crate::arena::Arena, w: &[f32], b: &[f32]| -> f32 {
        write_arena(arena, wp, w);
        write_arena(arena, bp, b);
        execute_thunks(&fwd_sched, arena.raw_buf_mut());
        read_arena(arena, loss, 1)[0]
    };

    let loss_check = run_loss(&mut fwd_arena, &w_init, &b_init);
    assert!(
        (loss_actual - loss_check).abs() < 1e-4,
        "loss mismatch: bwd graph {loss_actual} vs fwd-only {loss_check}"
    );

    let eps = 1e-3f32;
    let central = |lp: f32, lm: f32| (lp - lm) / (2.0 * eps);

    // Nudge one weight at a time (restoring it afterwards) and difference the loss.
    let mut w_work = w_init.clone();
    let gw_numerical: Vec<f32> = (0..w_init.len())
        .map(|i| {
            let orig = w_work[i];
            w_work[i] = orig + eps;
            let lp = run_loss(&mut fwd_arena, &w_work, &b_init);
            w_work[i] = orig - eps;
            let lm = run_loss(&mut fwd_arena, &w_work, &b_init);
            w_work[i] = orig;
            central(lp, lm)
        })
        .collect();
    for (i, (a, n)) in gw_actual.iter().zip(&gw_numerical).enumerate() {
        assert!(
            (a - n).abs() < 5e-3,
            "grad_w[{i}]: analytical {a} vs numerical {n}"
        );
    }

    // Same procedure for the bias, with the weights held at their initial values.
    let mut b_work = b_init.clone();
    let gb_numerical: Vec<f32> = (0..b_init.len())
        .map(|i| {
            let orig = b_work[i];
            b_work[i] = orig + eps;
            let lp = run_loss(&mut fwd_arena, &w_init, &b_work);
            b_work[i] = orig - eps;
            let lm = run_loss(&mut fwd_arena, &w_init, &b_work);
            b_work[i] = orig;
            central(lp, lm)
        })
        .collect();
    for (i, (a, n)) in gb_actual.iter().zip(&gb_numerical).enumerate() {
        assert!(
            (a - n).abs() < 5e-3,
            "grad_b[{i}]: analytical {a} vs numerical {n}"
        );
    }
}
21293
/// Finite-difference check of `grad_with_loss` for a bias-free dense layer
/// (`x @ w`) feeding softmax cross-entropy, with the per-row losses combined
/// by `ReduceOp::Mean`.
///
/// Checks two things:
/// 1. The loss emitted by the backward graph equals the loss computed by an
///    independent forward-only executor (within `1e-4`).
/// 2. dL/dw matches central differences (within `5e-3`).
#[test]
fn dense_sce_mean_reduce_gradient_matches_numerical() {
    use rlx_ir::Philox4x32;
    let bs = 3usize;
    let k_in = 2usize;
    let c = 4usize;
    let mut rng = Philox4x32::new(13);
    let mut x = vec![0f32; bs * k_in];
    rng.fill_normal(&mut x);
    let mut w_init = vec![0f32; k_in * c];
    rng.fill_normal(&mut w_init);
    // Class indices stored as f32, cycling through the classes row by row.
    let labels: Vec<f32> = (0..bs).map(|i| (i % c) as f32).collect();

    let f = DType::F32;
    let mut fwd = Graph::new("dense_sce_mean");
    let xn = fwd.input("x", Shape::new(&[bs, k_in], f));
    let lb = fwd.input("labels", Shape::new(&[bs], f));
    let wp = fwd.param("w", Shape::new(&[k_in, c], f));
    let mm = fwd.matmul(xn, wp, Shape::new(&[bs, c], f));
    let loss_per = fwd.softmax_cross_entropy_with_logits(mm, lb);
    let loss = fwd.add_node(
        Op::Reduce {
            op: ReduceOp::Mean,
            axes: vec![0],
            keep_dim: false,
        },
        vec![loss_per],
        Shape::from_dims(&[], f),
    );
    fwd.set_outputs(vec![loss]);

    let bwd_graph = rlx_opt::autodiff::grad_with_loss(&fwd, &[wp]);
    let d_out = bwd_graph
        .nodes()
        .iter()
        .find(|n| matches!(&n.op, Op::Input { name } if name == "d_output"))
        .map(|n| n.id)
        .unwrap();

    let (sched, mut arena) = prepare(
        &bwd_graph,
        &[(xn, &x), (lb, &labels), (wp, &w_init), (d_out, &[1.0])],
    );
    execute_thunks(&sched, arena.raw_buf_mut());

    // Backward graph outputs are [loss, dL/dw].
    let outs = &bwd_graph.outputs;
    let loss_actual = read_arena(&arena, outs[0], 1)[0];
    let gw_actual = read_arena(&arena, outs[1], k_in * c);

    // Independent forward-only executor used as the finite-difference oracle.
    let plan = rlx_opt::memory::plan_memory(&fwd);
    let mut fwd_arena = crate::arena::Arena::from_plan(plan);
    let fwd_sched = compile_thunks(&fwd, &fwd_arena);
    write_arena(&mut fwd_arena, xn, &x);
    write_arena(&mut fwd_arena, lb, &labels);

    let run_loss = |arena: &mut crate::arena::Arena, w: &[f32]| -> f32 {
        write_arena(arena, wp, w);
        execute_thunks(&fwd_sched, arena.raw_buf_mut());
        read_arena(arena, loss, 1)[0]
    };

    // Check the loss the backward graph emits against the forward-only graph
    // before trusting its gradients.
    let loss_check = run_loss(&mut fwd_arena, &w_init);
    assert!(
        (loss_actual - loss_check).abs() < 1e-4,
        "mean loss mismatch: bwd graph {loss_actual} vs fwd-only {loss_check}"
    );

    let eps = 1e-3f32;
    let mut wp_p = w_init.clone();
    let mut gw_num = vec![0f32; w_init.len()];
    for i in 0..w_init.len() {
        let s = wp_p[i];
        wp_p[i] = s + eps;
        let lp = run_loss(&mut fwd_arena, &wp_p);
        wp_p[i] = s - eps;
        let lm = run_loss(&mut fwd_arena, &wp_p);
        // Restore so each element is perturbed in isolation.
        wp_p[i] = s;
        gw_num[i] = (lp - lm) / (2.0 * eps);
    }
    for (i, (a, n)) in gw_actual.iter().zip(&gw_num).enumerate() {
        assert!((a - n).abs() < 5e-3, "mean reduce grad_w[{i}]: {a} vs {n}");
    }
}
/// End-to-end finite-difference check through a tiny CNN:
/// conv (3x3, stride 1, no padding) + broadcast per-channel bias -> ReLU ->
/// 2x2/stride-2 max-pool -> flatten -> dense (`wfc`, `bfc`) ->
/// softmax cross-entropy -> mean over the batch.
///
/// Checks two things:
/// 1. The backward graph's loss matches an independent forward-only executor
///    (within `1e-4`).
/// 2. The analytical gradients for every element of `wc`, `bc`, `wfc` and
///    `bfc` match central differences (within `5e-3`).
#[test]
fn tinyconv_full_gradient_matches_numerical() {
    use rlx_ir::Philox4x32;
    let n = 1usize;
    let c_in = 1usize;
    let h = 6usize;
    let w_in = 6usize;
    let c_mid = 2usize;
    let kh = 3;
    let kw = 3;
    // Spatial size after the unpadded stride-1 conv, then after 2x2/stride-2 pooling.
    let h1 = h - kh + 1;
    let w1 = w_in - kw + 1;
    let h2 = h1 / 2;
    let w2 = w1 / 2;
    let flat = c_mid * h2 * w2;
    let num_classes = 3usize;

    let mut rng = Philox4x32::new(31);
    let mut x = vec![0f32; n * c_in * h * w_in];
    rng.fill_normal(&mut x);
    let mut wc = vec![0f32; c_mid * c_in * kh * kw];
    rng.fill_normal(&mut wc);
    for v in wc.iter_mut() {
        *v *= 0.2;
    }
    // Large positive conv bias relative to the 0.2-scaled weights. Presumably
    // chosen so pre-activations stay clear of ReLU's kink at 0, where central
    // differences would be unreliable.
    let bc: Vec<f32> = (0..c_mid).map(|i| 5.0 + 0.1 * i as f32).collect();
    let mut wfc = vec![0f32; flat * num_classes];
    rng.fill_normal(&mut wfc);
    for v in wfc.iter_mut() {
        *v *= 0.5;
    }
    let mut bfc = vec![0f32; num_classes];
    rng.fill_normal(&mut bfc);
    let labels: Vec<f32> = vec![1.0];

    let f = DType::F32;
    let mut fwd = Graph::new("tinyconv");
    let xn = fwd.input("x", Shape::new(&[n, c_in, h, w_in], f));
    let lb = fwd.input("labels", Shape::new(&[n], f));
    let wcp = fwd.param("wc", Shape::new(&[c_mid, c_in, kh, kw], f));
    let bcp = fwd.param("bc", Shape::new(&[c_mid], f));
    let wfp = fwd.param("wfc", Shape::new(&[flat, num_classes], f));
    let bfp = fwd.param("bfc", Shape::new(&[num_classes], f));

    let conv = fwd.add_node(
        Op::Conv {
            kernel_size: vec![kh, kw],
            stride: vec![1, 1],
            padding: vec![0, 0],
            dilation: vec![1, 1],
            groups: 1,
        },
        vec![xn, wcp],
        Shape::new(&[n, c_mid, h1, w1], f),
    );
    // The per-channel bias is reshaped to [1, C, 1, 1] and expanded to the conv
    // output shape before the element-wise add.
    let bc_4d = fwd.add_node(
        Op::Reshape {
            new_shape: vec![1, c_mid as i64, 1, 1],
        },
        vec![bcp],
        Shape::new(&[1, c_mid, 1, 1], f),
    );
    let bc_expanded = fwd.add_node(
        Op::Expand {
            target_shape: vec![n as i64, c_mid as i64, h1 as i64, w1 as i64],
        },
        vec![bc_4d],
        Shape::new(&[n, c_mid, h1, w1], f),
    );
    let conv_b = fwd.binary(
        BinaryOp::Add,
        conv,
        bc_expanded,
        Shape::new(&[n, c_mid, h1, w1], f),
    );
    let relu = fwd.activation(Activation::Relu, conv_b, Shape::new(&[n, c_mid, h1, w1], f));
    let pool = fwd.add_node(
        Op::Pool {
            kind: ReduceOp::Max,
            kernel_size: vec![2, 2],
            stride: vec![2, 2],
            padding: vec![0, 0],
        },
        vec![relu],
        Shape::new(&[n, c_mid, h2, w2], f),
    );
    let flatn = fwd.add_node(
        Op::Reshape {
            new_shape: vec![n as i64, flat as i64],
        },
        vec![pool],
        Shape::new(&[n, flat], f),
    );
    let mm = fwd.matmul(flatn, wfp, Shape::new(&[n, num_classes], f));
    let logits = fwd.binary(BinaryOp::Add, mm, bfp, Shape::new(&[n, num_classes], f));
    let loss_per = fwd.softmax_cross_entropy_with_logits(logits, lb);
    let loss = fwd.add_node(
        Op::Reduce {
            op: ReduceOp::Mean,
            axes: vec![0],
            keep_dim: false,
        },
        vec![loss_per],
        Shape::from_dims(&[], f),
    );
    fwd.set_outputs(vec![loss]);

    // Backward graph outputs are [loss, dL/dwc, dL/dbc, dL/dwfc, dL/dbfc].
    let bwd_graph = rlx_opt::autodiff::grad_with_loss(&fwd, &[wcp, bcp, wfp, bfp]);
    let d_out = bwd_graph
        .nodes()
        .iter()
        .find(|n| matches!(&n.op, Op::Input { name } if name == "d_output"))
        .map(|n| n.id)
        .unwrap();

    let (sched, mut arena) = prepare(
        &bwd_graph,
        &[
            (xn, &x),
            (lb, &labels),
            (wcp, &wc),
            (bcp, &bc),
            (wfp, &wfc),
            (bfp, &bfc),
            (d_out, &[1.0]),
        ],
    );
    execute_thunks(&sched, arena.raw_buf_mut());

    let outs = &bwd_graph.outputs;
    let loss_actual = read_arena(&arena, outs[0], 1)[0];
    let g_wc = read_arena(&arena, outs[1], wc.len());
    let g_bc = read_arena(&arena, outs[2], bc.len());
    let g_wfc = read_arena(&arena, outs[3], wfc.len());
    let g_bfc = read_arena(&arena, outs[4], bfc.len());

    // Independent forward-only executor used as the finite-difference oracle.
    let plan = rlx_opt::memory::plan_memory(&fwd);
    let mut fwd_arena = crate::arena::Arena::from_plan(plan);
    let fwd_sched = compile_thunks(&fwd, &fwd_arena);
    write_arena(&mut fwd_arena, xn, &x);
    write_arena(&mut fwd_arena, lb, &labels);

    let run_loss = |arena: &mut crate::arena::Arena,
                    wc: &[f32],
                    bc: &[f32],
                    wfc: &[f32],
                    bfc: &[f32]|
     -> f32 {
        write_arena(arena, wcp, wc);
        write_arena(arena, bcp, bc);
        write_arena(arena, wfp, wfc);
        write_arena(arena, bfp, bfc);
        execute_thunks(&fwd_sched, arena.raw_buf_mut());
        read_arena(arena, loss, 1)[0]
    };

    let loss_check = run_loss(&mut fwd_arena, &wc, &bc, &wfc, &bfc);
    assert!(
        (loss_actual - loss_check).abs() < 1e-4,
        "tinyconv loss mismatch: bwd {loss_actual} vs fwd {loss_check}"
    );

    // Central-difference check of every parameter element. `perturbed` holds a
    // working copy of all four parameter tensors (in `run_loss` argument order).
    // Exactly one element is nudged at a time and restored before moving on, so
    // all other parameters stay at their initial values.
    let eps = 1e-3f32;
    let mut perturbed: [Vec<f32>; 4] = [wc.clone(), bc.clone(), wfc.clone(), bfc.clone()];
    let analytical: [(&str, &[f32]); 4] = [
        ("g_wc", g_wc.as_slice()),
        ("g_bc", g_bc.as_slice()),
        ("g_wfc", g_wfc.as_slice()),
        ("g_bfc", g_bfc.as_slice()),
    ];
    for (slot, &(name, grad)) in analytical.iter().enumerate() {
        for i in 0..perturbed[slot].len() {
            let saved = perturbed[slot][i];
            perturbed[slot][i] = saved + eps;
            let lp = run_loss(
                &mut fwd_arena,
                &perturbed[0],
                &perturbed[1],
                &perturbed[2],
                &perturbed[3],
            );
            perturbed[slot][i] = saved - eps;
            let lm = run_loss(
                &mut fwd_arena,
                &perturbed[0],
                &perturbed[1],
                &perturbed[2],
                &perturbed[3],
            );
            perturbed[slot][i] = saved;
            let num = (lp - lm) / (2.0 * eps);
            assert!(
                (grad[i] - num).abs() < 5e-3,
                "{name}[{i}]: {} vs {num}",
                grad[i]
            );
        }
    }
}
21673
/// Lowering must not fuse away a `Narrow` that has more than one consumer.
///
/// Here the narrowed tensor `q` feeds both a RoPE and a ReLU. The test
/// requires the compiled schedule to still contain at least one
/// `Thunk::Narrow`.
#[test]
fn narrow_rope_skips_when_narrow_has_multiple_consumers() {
    let dt = DType::F32;
    let mut graph = Graph::new("nr_skip");
    let qkv = graph.input("qkv", Shape::new(&[16, 8, 192], dt));
    let cos = graph.input("cos", Shape::new(&[16], dt));
    let sin = graph.input("sin", Shape::new(&[16], dt));

    // `q` is consumed twice: once by the RoPE and once by the ReLU.
    let q = graph.narrow_(qkv, 2, 0, 64);
    let q_rope = graph.rope(q, cos, sin, 16);
    let q_dup = graph.activation(
        rlx_ir::op::Activation::Relu,
        q,
        Shape::new(&[16, 8, 64], dt),
    );
    graph.set_outputs(vec![q_rope, q_dup]);

    let arena = crate::arena::Arena::from_plan(rlx_opt::memory::plan_memory(&graph));
    let sched = compile_thunks(&graph, &arena);

    let kept_narrow = sched
        .thunks
        .iter()
        .any(|thunk| matches!(thunk, Thunk::Narrow { .. }));
    assert!(
        kept_narrow,
        "Narrow with multiple consumers must NOT be fused away"
    );
}
21704
/// Forward execution of a `custom_fn` with no VJP/JVP overrides.
///
/// The body graph computes `x + [1, 1, 1]`, so running the outer graph on
/// `[10, 20, 30]` must yield `[11, 21, 31]`.
#[test]
fn custom_fn_forward_inlines_body() {
    let vec3 = Shape::new(&[3], DType::F32);

    // Body graph: y = x + ones, with the ones stored as little-endian f32 bytes.
    let mut body = Graph::new("addone_body");
    let x = body.input("x", vec3.clone());
    let ones = body.add_node(
        Op::Constant {
            data: 1.0_f32.to_le_bytes().repeat(3),
        },
        vec![],
        vec3.clone(),
    );
    let y = body.binary(BinaryOp::Add, x, ones, vec3.clone());
    body.set_outputs(vec![y]);

    let mut outer = Graph::new("custom_fn_outer");
    let xin = outer.input("x_in", vec3.clone());
    let cf = outer.custom_fn(vec![xin], body, None, None);
    outer.set_outputs(vec![cf]);

    let inputs = vec![10.0_f32, 20.0, 30.0];
    let (sched, mut arena) = prepare(&outer, &[(xin, &inputs)]);
    execute_thunks(&sched, arena.raw_buf_mut());
    assert_eq!(read_arena(&arena, cf, 3), vec![11.0, 21.0, 31.0]);
}
21740
21741 fn find_named(graph: &Graph, want: &str) -> NodeId {
21743 for n in graph.nodes() {
21744 let name = match &n.op {
21745 Op::Input { name } | Op::Param { name } => Some(name.as_str()),
21746 _ => None,
21747 };
21748 if name == Some(want) {
21749 return n.id;
21750 }
21751 }
21752 panic!("no node named {want:?} in graph");
21753 }
21754
/// A `custom_fn` whose forward body is the identity, but whose VJP returns
/// `2 * d_output`.
///
/// `grad_with_loss` must use the supplied VJP, so dx is 2.0 rather than the
/// 1.0 that differentiating the body would give. The loss is the forward value,
/// 7.0.
#[test]
fn custom_fn_vjp_overrides_natural_gradient() {
    use rlx_opt::autodiff::grad_with_loss;
    let scalar = Shape::new(&[1], DType::F32);

    // Forward body: y = x.
    let mut fwd = Graph::new("id_fwd");
    let body_x = fwd.input("x", scalar.clone());
    fwd.set_outputs(vec![body_x]);

    // VJP graph: inputs (x, primal_output, d_output) -> dx = d_output * 2.
    let mut vjp_g = Graph::new("id_vjp");
    let _primal_x = vjp_g.input("x", scalar.clone());
    let _primal_y = vjp_g.input("primal_output", scalar.clone());
    let dy = vjp_g.input("d_output", scalar.clone());
    let two = vjp_g.add_node(
        Op::Constant {
            data: 2.0_f32.to_le_bytes().to_vec(),
        },
        vec![],
        scalar.clone(),
    );
    let dx = vjp_g.binary(BinaryOp::Mul, dy, two, scalar.clone());
    vjp_g.set_outputs(vec![dx]);

    let mut outer = Graph::new("outer");
    let xp = outer.param("x", scalar.clone());
    let cf = outer.custom_fn(vec![xp], fwd, Some(vjp_g), None);
    outer.set_outputs(vec![cf]);

    let bwd = grad_with_loss(&outer, &[xp]);
    assert_eq!(bwd.outputs.len(), 2, "expect [loss, dx]");

    let x_node = find_named(&bwd, "x");
    let seed = find_named(&bwd, "d_output");
    let (sched, mut arena) = prepare(&bwd, &[(x_node, &[7.0]), (seed, &[1.0])]);
    execute_thunks(&sched, arena.raw_buf_mut());

    let loss = read_arena(&arena, bwd.outputs[0], 1)[0];
    let grad_x = read_arena(&arena, bwd.outputs[1], 1)[0];
    assert!((loss - 7.0).abs() < 1e-6, "loss should be 7.0");
    assert!(
        (grad_x - 2.0).abs() < 1e-6,
        "vjp override should yield dx=2.0, got {} (natural autodiff would give 1.0)",
        grad_x
    );
}
21797
/// A two-input `custom_fn` computing `a * b`, with a hand-written VJP
/// (`da = b * dy`, `db = a * dy`).
///
/// At a = 3, b = 5 the loss must be 15, with gradients da = 5 and db = 3.
#[test]
fn custom_fn_vjp_two_inputs_matches_mul_autodiff() {
    use rlx_opt::autodiff::grad_with_loss;
    let scalar = Shape::new(&[1], DType::F32);

    // Forward body: y = a * b.
    let mut fwd = Graph::new("mul_fwd");
    let body_a = fwd.input("a", scalar.clone());
    let body_b = fwd.input("b", scalar.clone());
    let product = fwd.binary(BinaryOp::Mul, body_a, body_b, scalar.clone());
    fwd.set_outputs(vec![product]);

    // VJP graph: inputs (a, b, primal_output, d_output) -> (da, db).
    let mut vjp_g = Graph::new("mul_vjp");
    let vjp_a = vjp_g.input("a", scalar.clone());
    let vjp_b = vjp_g.input("b", scalar.clone());
    let _primal_y = vjp_g.input("primal_output", scalar.clone());
    let dy = vjp_g.input("d_output", scalar.clone());
    let grad_a = vjp_g.binary(BinaryOp::Mul, vjp_b, dy, scalar.clone());
    let grad_b = vjp_g.binary(BinaryOp::Mul, vjp_a, dy, scalar.clone());
    vjp_g.set_outputs(vec![grad_a, grad_b]);

    let mut outer = Graph::new("outer");
    let ap = outer.param("a", scalar.clone());
    let bp = outer.param("b", scalar.clone());
    let cf = outer.custom_fn(vec![ap, bp], fwd, Some(vjp_g), None);
    outer.set_outputs(vec![cf]);

    let bwd = grad_with_loss(&outer, &[ap, bp]);
    assert_eq!(bwd.outputs.len(), 3, "expect [loss, da, db]");

    let a_node = find_named(&bwd, "a");
    let b_node = find_named(&bwd, "b");
    let seed = find_named(&bwd, "d_output");
    let (sched, mut arena) = prepare(
        &bwd,
        &[(a_node, &[3.0]), (b_node, &[5.0]), (seed, &[1.0])],
    );
    execute_thunks(&sched, arena.raw_buf_mut());

    let loss = read_arena(&arena, bwd.outputs[0], 1)[0];
    let da = read_arena(&arena, bwd.outputs[1], 1)[0];
    let db = read_arena(&arena, bwd.outputs[2], 1)[0];
    assert!((loss - 15.0).abs() < 1e-5);
    assert!((da - 5.0).abs() < 1e-5, "da should be b=5.0, got {}", da);
    assert!((db - 3.0).abs() < 1e-5, "db should be a=3.0, got {}", db);
}
21851
/// A `custom_fn` with an identity forward body and a JVP returning
/// `2 * tangent_0`.
///
/// Forward-mode `jvp` must use the override. With primal 7.0 and tangent 1.0,
/// the outputs are `y = 7.0` and `t_y = 2.0`; the identity body alone would
/// give `t_y = 1.0`.
#[test]
fn custom_fn_jvp_overrides_natural_tangent() {
    use rlx_opt::autodiff_fwd::jvp;
    let scalar = Shape::new(&[1], DType::F32);

    // Forward body: y = x.
    let mut fwd = Graph::new("id_fwd");
    let body_x = fwd.input("x", scalar.clone());
    fwd.set_outputs(vec![body_x]);

    // JVP graph: inputs (x, tangent_0) -> tangent_0 * 2.
    let mut jvp_g = Graph::new("id_jvp");
    let _primal_x = jvp_g.input("x", scalar.clone());
    let tangent_in = jvp_g.input("tangent_0", scalar.clone());
    let two = jvp_g.add_node(
        Op::Constant {
            data: 2.0_f32.to_le_bytes().to_vec(),
        },
        vec![],
        scalar.clone(),
    );
    let tangent_out = jvp_g.binary(BinaryOp::Mul, tangent_in, two, scalar.clone());
    jvp_g.set_outputs(vec![tangent_out]);

    let mut outer = Graph::new("outer");
    let xin = outer.input("x_in", scalar.clone());
    let cf = outer.custom_fn(vec![xin], fwd, None, Some(jvp_g));
    outer.set_outputs(vec![cf]);

    let fwd_g = jvp(&outer, &[xin]);
    assert_eq!(fwd_g.outputs.len(), 2, "expect [primal_y, tangent_y]");

    let primal_node = find_named(&fwd_g, "x_in");
    let tangent_node = find_named(&fwd_g, "tangent_x_in");
    let (sched, mut arena) = prepare(&fwd_g, &[(primal_node, &[7.0]), (tangent_node, &[1.0])]);
    execute_thunks(&sched, arena.raw_buf_mut());

    let primal_y = read_arena(&arena, fwd_g.outputs[0], 1)[0];
    let tangent_y = read_arena(&arena, fwd_g.outputs[1], 1)[0];
    assert!((primal_y - 7.0).abs() < 1e-6);
    assert!(
        (tangent_y - 2.0).abs() < 1e-6,
        "jvp override should yield t_y=2.0 (natural autodiff would give 1.0), got {}",
        tangent_y
    );
}
21892
/// `DType::C64` occupies 8 bytes per element (an f32 real part plus an f32
/// imaginary part). It is classified as complex, not float, and a 2-element
/// C64 shape therefore spans 16 bytes.
#[test]
fn c64_dtype_storage_layout() {
    let c64 = DType::C64;
    assert_eq!(
        c64.size_bytes(),
        8,
        "C64 should be 8 bytes (f32 real + f32 imag)"
    );
    assert!(c64.is_complex());
    assert!(!c64.is_float());

    let pair = Shape::new(&[2], c64);
    assert_eq!(pair.size_bytes().unwrap(), 16);
}
21911
21912 fn run_c64_binary(op: BinaryOp, a: &[(f32, f32)], b: &[(f32, f32)]) -> Vec<(f32, f32)> {
21919 let n = a.len();
21920 let s = Shape::new(&[n], DType::C64);
21921 let mut g = Graph::new("c64_bin");
21922 let in_a = g.input("a", s.clone());
21923 let in_b = g.input("b", s.clone());
21924 let out = g.binary(op, in_a, in_b, s.clone());
21925 g.set_outputs(vec![out]);
21926
21927 let plan = rlx_opt::memory::plan_memory(&g);
21928 let mut arena = crate::arena::Arena::from_plan(plan);
21929 let sched = compile_thunks(&g, &arena);
21930
21931 let a_off = arena.byte_offset(in_a);
21932 let b_off = arena.byte_offset(in_b);
21933 let out_off = arena.byte_offset(out);
21934 let buf = arena.raw_buf_mut();
21936 unsafe {
21937 let pa = buf.as_mut_ptr().add(a_off) as *mut f32;
21938 let pb = buf.as_mut_ptr().add(b_off) as *mut f32;
21939 for (i, &(re, im)) in a.iter().enumerate() {
21940 *pa.add(2 * i) = re;
21941 *pa.add(2 * i + 1) = im;
21942 }
21943 for (i, &(re, im)) in b.iter().enumerate() {
21944 *pb.add(2 * i) = re;
21945 *pb.add(2 * i + 1) = im;
21946 }
21947 }
21948 execute_thunks(&sched, arena.raw_buf_mut());
21949 let raw_out: Vec<f32> = unsafe {
21950 let p = arena.raw_buf().as_ptr().add(out_off) as *const f32;
21951 (0..(2 * n)).map(|i| *p.add(i)).collect()
21952 };
21953 (0..n)
21954 .map(|i| (raw_out[2 * i], raw_out[2 * i + 1]))
21955 .collect()
21956 }
21957
/// Asserts that two complex numbers, given as `(re, im)`, agree component-wise
/// to strictly within `tol`.
///
/// A NaN in either component fails the check. Because of `#[track_caller]`,
/// failures are reported at the caller's location, prefixed with `label`.
#[track_caller]
fn assert_close_c(got: (f32, f32), expected: (f32, f32), tol: f32, label: &str) {
    let (got_re, got_im) = got;
    let (want_re, want_im) = expected;
    let re_ok = (got_re - want_re).abs() < tol;
    let im_ok = (got_im - want_im).abs() < tol;
    assert!(
        re_ok && im_ok,
        "[{label}] got ({:+.4}, {:+.4}), expected ({:+.4}, {:+.4})",
        got_re,
        got_im,
        want_re,
        want_im
    );
}
21971
/// Element-wise complex addition:
/// (1+2i) + (4-i) = 5+i and (3-i) + (0.5+0.5i) = 3.5-0.5i.
#[test]
fn c64_binary_add_matches_complex_arithmetic() {
    let lhs = [(1.0_f32, 2.0_f32), (3.0_f32, -1.0_f32)];
    let rhs = [(4.0_f32, -1.0_f32), (0.5_f32, 0.5_f32)];
    let sum = run_c64_binary(BinaryOp::Add, &lhs, &rhs);
    assert_close_c(sum[0], (5.0, 1.0), 1e-6, "add[0]");
    assert_close_c(sum[1], (3.5, -0.5), 1e-6, "add[1]");
}
21980
/// Element-wise complex subtraction: (5+i) - (2+3i) = 3-2i.
#[test]
fn c64_binary_sub_matches_complex_arithmetic() {
    let diff = run_c64_binary(
        BinaryOp::Sub,
        &[(5.0_f32, 1.0_f32)],
        &[(2.0_f32, 3.0_f32)],
    );
    assert_close_c(diff[0], (3.0, -2.0), 1e-6, "sub");
}
21988
/// Complex multiplication: (1+2i)(3+4i) = (3-8) + (4+6)i = -5+10i.
#[test]
fn c64_binary_mul_matches_complex_arithmetic() {
    let lhs = [(1.0_f32, 2.0_f32)];
    let rhs = [(3.0_f32, 4.0_f32)];
    let product = run_c64_binary(BinaryOp::Mul, &lhs, &rhs);
    assert_close_c(product[0], (-5.0, 10.0), 1e-5, "mul");
}
21997
/// Complex division: (1+2i)/(3+4i) = (1+2i)(3-4i)/25 = (11+2i)/25 = 0.44+0.08i.
#[test]
fn c64_binary_div_matches_complex_arithmetic() {
    let numerator = [(1.0_f32, 2.0_f32)];
    let denominator = [(3.0_f32, 4.0_f32)];
    let quotient = run_c64_binary(BinaryOp::Div, &numerator, &denominator);
    assert_close_c(quotient[0], (0.44, 0.08), 1e-5, "div");
}
22008
/// Multiplying by the complex unit 1+0i leaves every element unchanged.
#[test]
fn c64_binary_mul_identity_one_is_no_op() {
    let values = [(3.5_f32, -1.25_f32), (-2.0_f32, 7.0_f32)];
    let unit = [(1.0_f32, 0.0_f32); 2];
    let scaled = run_c64_binary(BinaryOp::Mul, &values, &unit);
    assert_close_c(scaled[0], values[0], 1e-6, "mul·1[0]");
    assert_close_c(scaled[1], values[1], 1e-6, "mul·1[1]");
}
22018
/// Multiplying 1 by i rotates it a quarter turn onto the imaginary axis: 1·i = i.
#[test]
fn c64_binary_mul_by_i_rotates_90_degrees() {
    let rotated = run_c64_binary(
        BinaryOp::Mul,
        &[(1.0_f32, 0.0_f32)],
        &[(0.0_f32, 1.0_f32)],
    );
    assert_close_c(rotated[0], (0.0, 1.0), 1e-6, "1·i");
}
22027
/// Dividing nonzero complex numbers by themselves yields 1+0i.
#[test]
fn c64_binary_div_by_self_gives_unity() {
    let values = [(2.5_f32, -1.5_f32), (-0.7_f32, 4.2_f32)];
    let ratios = run_c64_binary(BinaryOp::Div, &values, &values);
    assert_close_c(ratios[0], (1.0, 0.0), 1e-5, "div_self[0]");
    assert_close_c(ratios[1], (1.0, 0.0), 1e-5, "div_self[1]");
}
22035
/// Complex numbers have no ordering, so `BinaryOp::Max` on `C64` operands must
/// panic with the "complex max/min/pow" rejection message.
#[test]
#[should_panic(expected = "C64: complex max/min/pow")]
fn c64_binary_max_is_rejected_at_lowering() {
    let lhs = [(1.0_f32, 2.0_f32)];
    let rhs = [(3.0_f32, 4.0_f32)];
    let _ = run_c64_binary(BinaryOp::Max, &lhs, &rhs);
}
22041
22042 fn run_c64_activation(act: Activation, a: &[(f32, f32)]) -> Vec<(f32, f32)> {
22043 let n = a.len();
22044 let s = Shape::new(&[n], DType::C64);
22045 let mut g = Graph::new("c64_act");
22046 let in_a = g.input("a", s.clone());
22047 let out = g.activation(act, in_a, s.clone());
22048 g.set_outputs(vec![out]);
22049 let plan = rlx_opt::memory::plan_memory(&g);
22050 let mut arena = crate::arena::Arena::from_plan(plan);
22051 let sched = compile_thunks(&g, &arena);
22052 let a_off = arena.byte_offset(in_a);
22053 let out_off = arena.byte_offset(out);
22054 let buf = arena.raw_buf_mut();
22055 unsafe {
22056 let pa = buf.as_mut_ptr().add(a_off) as *mut f32;
22057 for (i, &(re, im)) in a.iter().enumerate() {
22058 *pa.add(2 * i) = re;
22059 *pa.add(2 * i + 1) = im;
22060 }
22061 }
22062 execute_thunks(&sched, arena.raw_buf_mut());
22063 let raw: Vec<f32> = unsafe {
22064 let p = arena.raw_buf().as_ptr().add(out_off) as *const f32;
22065 (0..(2 * n)).map(|i| *p.add(i)).collect()
22066 };
22067 (0..n).map(|i| (raw[2 * i], raw[2 * i + 1])).collect()
22068 }
22069
/// Complex negation flips the sign of both the real and imaginary parts.
#[test]
fn c64_activation_neg_negates_both_components() {
    let values = [(3.5_f32, -1.25_f32), (-2.0_f32, 0.0_f32)];
    let negated = run_c64_activation(Activation::Neg, &values);
    assert_close_c(negated[0], (-3.5, 1.25), 1e-6, "neg[0]");
    assert_close_c(negated[1], (2.0, 0.0), 1e-6, "neg[1]");
}
22077
/// Complex exponential checked against Euler's identity, exp(iπ) = -1, and
/// against the real case, exp(1) = e.
#[test]
fn c64_activation_exp_matches_euler() {
    use std::f32::consts::{E, PI};
    let exponents = [(0.0_f32, PI), (1.0_f32, 0.0_f32)];
    let powers = run_c64_activation(Activation::Exp, &exponents);
    assert_close_c(powers[0], (-1.0, 0.0), 1e-5, "exp(iπ)");
    assert_close_c(powers[1], (E, 0.0), 1e-5, "exp(1)");
}
22087
/// Complex logarithm on the principal branch:
/// log(1) = 0, log(i) = iπ/2, log(-1) = iπ.
#[test]
fn c64_activation_log_matches_principal_branch() {
    use std::f32::consts::{FRAC_PI_2, PI};
    let arguments = [(1.0_f32, 0.0_f32), (0.0_f32, 1.0_f32), (-1.0_f32, 0.0_f32)];
    let logs = run_c64_activation(Activation::Log, &arguments);
    assert_close_c(logs[0], (0.0, 0.0), 1e-5, "log(1)");
    assert_close_c(logs[1], (0.0, FRAC_PI_2), 1e-5, "log(i)");
    assert_close_c(logs[2], (0.0, PI), 1e-5, "log(-1)");
}
22099
/// Complex square root: sqrt(4) = 2 and sqrt(3+4i) = 2+i.
///
/// Squaring each returned root must also reproduce the original input,
/// which is the round-trip the test name describes.
#[test]
fn c64_activation_sqrt_squared_recovers_input() {
    let inp = [(4.0_f32, 0.0_f32), (3.0_f32, 4.0_f32)];
    let roots = run_c64_activation(Activation::Sqrt, &inp);
    assert_close_c(roots[0], (2.0, 0.0), 1e-5, "sqrt(4)");
    assert_close_c(roots[1], (2.0, 1.0), 1e-5, "sqrt(3+4i)");
    // (re + i·im)^2 = (re² − im²) + i·(2·re·im). The tolerance is looser than
    // above because squaring roughly doubles the relative error of each root.
    for (i, (&(re, im), &z)) in roots.iter().zip(inp.iter()).enumerate() {
        let squared = (re * re - im * im, 2.0 * re * im);
        assert_close_c(squared, z, 1e-4, &format!("sqrt(z[{i}])^2"));
    }
}
22110
/// ReLU on a `C64` input must panic with the "no natural complex extension"
/// rejection message.
#[test]
#[should_panic(expected = "no natural complex extension")]
fn c64_activation_relu_is_rejected_at_lowering() {
    let values = [(1.0_f32, 2.0_f32)];
    let _ = run_c64_activation(Activation::Relu, &values);
}
22116
22117 fn run_complex_norm_sq(z: &[(f32, f32)]) -> Vec<f32> {
22121 let n = z.len();
22122 let mut g = Graph::new("cns_fwd");
22123 let in_z = g.input("z", Shape::new(&[n], DType::C64));
22124 let out = g.complex_norm_sq(in_z);
22125 g.set_outputs(vec![out]);
22126 let plan = rlx_opt::memory::plan_memory(&g);
22127 let mut arena = crate::arena::Arena::from_plan(plan);
22128 let sched = compile_thunks(&g, &arena);
22129 let z_off = arena.byte_offset(in_z);
22130 let out_off = arena.byte_offset(out);
22131 let buf = arena.raw_buf_mut();
22132 unsafe {
22133 let pz = buf.as_mut_ptr().add(z_off) as *mut f32;
22134 for (i, &(re, im)) in z.iter().enumerate() {
22135 *pz.add(2 * i) = re;
22136 *pz.add(2 * i + 1) = im;
22137 }
22138 }
22139 execute_thunks(&sched, arena.raw_buf_mut());
22140 unsafe {
22141 let p = arena.raw_buf().as_ptr().add(out_off) as *const f32;
22142 (0..n).map(|i| *p.add(i)).collect()
22143 }
22144 }
22145
22146 fn run_complex_norm_sq_bwd(z: &[(f32, f32)], g: &[f32]) -> Vec<(f32, f32)> {
22148 let n = z.len();
22149 let mut gr = Graph::new("cns_bwd");
22150 let in_z = gr.input("z", Shape::new(&[n], DType::C64));
22151 let in_g = gr.input("g", Shape::new(&[n], DType::F32));
22152 let out = gr.complex_norm_sq_backward(in_z, in_g);
22153 gr.set_outputs(vec![out]);
22154 let plan = rlx_opt::memory::plan_memory(&gr);
22155 let mut arena = crate::arena::Arena::from_plan(plan);
22156 let sched = compile_thunks(&gr, &arena);
22157 let z_off = arena.byte_offset(in_z);
22158 let g_off = arena.byte_offset(in_g);
22159 let out_off = arena.byte_offset(out);
22160 let buf = arena.raw_buf_mut();
22161 unsafe {
22162 let pz = buf.as_mut_ptr().add(z_off) as *mut f32;
22163 let pg = buf.as_mut_ptr().add(g_off) as *mut f32;
22164 for (i, &(re, im)) in z.iter().enumerate() {
22165 *pz.add(2 * i) = re;
22166 *pz.add(2 * i + 1) = im;
22167 }
22168 for (i, &v) in g.iter().enumerate() {
22169 *pg.add(i) = v;
22170 }
22171 }
22172 execute_thunks(&sched, arena.raw_buf_mut());
22173 unsafe {
22174 let p = arena.raw_buf().as_ptr().add(out_off) as *const f32;
22175 (0..n).map(|i| (*p.add(2 * i), *p.add(2 * i + 1))).collect()
22176 }
22177 }
22178
#[test]
fn complex_norm_sq_matches_textbook() {
    // |3+4i|^2 = 25, |1|^2 = 1, |0|^2 = 0; each paired with its tolerance.
    let norms = run_complex_norm_sq(&[(3.0_f32, 4.0_f32), (1.0_f32, 0.0_f32), (0.0_f32, 0.0_f32)]);
    let expected = [(25.0_f32, 1e-5_f32), (1.0, 1e-6), (0.0, 1e-6)];
    for (i, &(want, tol)) in expected.iter().enumerate() {
        assert!(
            (norms[i] - want).abs() < tol,
            "|z[{i}]|^2 = {}, expected {want}",
            norms[i]
        );
    }
}
22190
#[test]
fn complex_norm_sq_backward_matches_wirtinger_formula() {
    // With a unit upstream gradient the expected result g·z is just z.
    let z = [(3.0_f32, 4.0_f32), (1.5_f32, -2.5_f32)];
    let dz = run_complex_norm_sq_bwd(&z, &[1.0_f32; 2]);
    for (i, label) in ["dz[0] = g·z[0]", "dz[1] = g·z[1]"].iter().enumerate() {
        assert_close_c(dz[i], z[i], 1e-6, *label);
    }
}
22200
#[test]
fn complex_norm_sq_backward_scales_with_upstream() {
    let z = [(2.0_f32, 1.0_f32), (-1.0_f32, 3.0_f32)];
    let g = [0.5_f32, -2.0_f32];
    let dz = run_complex_norm_sq_bwd(&z, &g);
    // Each gradient entry is the input scaled componentwise by its upstream g.
    let labels = ["g=0.5 · (2,1)", "g=-2 · (-1,3)"];
    for i in 0..z.len() {
        let want = (g[i] * z[i].0, g[i] * z[i].1);
        assert_close_c(dz[i], want, 1e-6, labels[i]);
    }
}
22210
#[test]
fn custom_fn_multi_extracts_each_subgraph_output() {
    use rlx_ir::ops::special::MultiOutputHandle;

    // A handle built by hand and dropped immediately. This is presumably a
    // compile-time check that the struct and its fields stay public.
    let _ = MultiOutputHandle {
        source: NodeId(0),
        sub_shapes: vec![],
        offsets: vec![],
    };

    // Subgraph with two outputs: x*x and c*x, where c is a constant [2, 2, 2].
    let s3 = Shape::new(&[3], DType::F32);
    let mut body = Graph::new("multi_body");
    let x = body.input("x", s3.clone());
    let x_sq = body.binary(BinaryOp::Mul, x, x, s3.clone());
    let two = body.add_node(
        Op::Constant {
            data: [2.0_f32; 3].iter().flat_map(|v| v.to_le_bytes()).collect(),
        },
        vec![],
        s3.clone(),
    );
    let two_x = body.binary(BinaryOp::Mul, two, x, s3.clone());
    body.set_outputs(vec![x_sq, two_x]);

    // Outer graph exposes both subgraph outputs as graph outputs.
    let mut outer = Graph::new("multi_outer");
    let in_x = outer.input("xin", s3);
    let handle = outer.custom_fn_multi(vec![in_x], body);
    assert_eq!(handle.n_outputs(), 2);
    let out0 = handle.output(&mut outer, 0);
    let out1 = handle.output(&mut outer, 1);
    outer.set_outputs(vec![out0, out1]);

    let mut arena = crate::arena::Arena::from_plan(rlx_opt::memory::plan_memory(&outer));
    let sched = compile_thunks(&outer, &arena);
    let xin_off = arena.byte_offset(in_x);
    let out0_off = arena.byte_offset(out0);
    let out1_off = arena.byte_offset(out1);

    let xs = [1.0_f32, 2.0, 3.0];
    unsafe {
        let dst = arena.raw_buf_mut().as_mut_ptr().add(xin_off) as *mut f32;
        for (i, v) in xs.iter().copied().enumerate() {
            dst.add(i).write(v);
        }
    }
    execute_thunks(&sched, arena.raw_buf_mut());

    // Copies three f32 values starting at a byte offset of the arena.
    let read3 = |off: usize| -> Vec<f32> {
        unsafe {
            let src = arena.raw_buf().as_ptr().add(off) as *const f32;
            (0..3).map(|i| src.add(i).read()).collect()
        }
    };
    let out0_v = read3(out0_off);
    let out1_v = read3(out1_off);

    for (i, ((&sq, &dbl), &x)) in out0_v.iter().zip(&out1_v).zip(&xs).enumerate() {
        assert!(
            (sq - x * x).abs() < 1e-5,
            "out0[{i}] = {} != x² = {}",
            sq,
            x * x
        );
        assert!(
            (dbl - 2.0 * x).abs() < 1e-5,
            "out1[{i}] = {} != 2x = {}",
            dbl,
            2.0 * x
        );
    }
}
22294
#[test]
fn complex_norm_sq_gradient_matches_finite_difference() {
    // |z|^2 = re^2 + im^2, so the real-variable partials are 2·re and 2·im.
    let (re, im) = (3.0_f32, 4.0_f32);
    let z = [(re, im)];
    let eps = 1e-3_f32;
    let base = run_complex_norm_sq(&z)[0];
    // One-sided finite difference of |z|^2 between z and a perturbed point.
    let forward_diff = |p: (f32, f32)| (run_complex_norm_sq(&[p])[0] - base) / eps;
    let (want_re, want_im) = (2.0 * re, 2.0 * im);

    let fd_re = forward_diff((re + eps, im));
    assert!((fd_re - want_re).abs() < 1e-2);

    let fd_im = forward_diff((re, im + eps));
    assert!((fd_im - want_im).abs() < 1e-2);

    // The backward kernel yields g·z (see the Wirtinger-formula test), which is
    // half the real gradient; doubling it must reproduce (2·re, 2·im).
    let dz = run_complex_norm_sq_bwd(&z, &[1.0_f32]);
    assert!((2.0 * dz[0].0 - want_re).abs() < 1e-5);
    assert!((2.0 * dz[0].1 - want_im).abs() < 1e-5);
}
22323
#[test]
fn binary_full_5d_mid_singleton_broadcast() {
    // lhs is [bh, h, w, h, w]; rhs is [bh, h, w, 1, w], so rhs broadcasts along
    // axis 3, a singleton in the middle of the shape rather than at the front.
    let (bh, h, w) = (2usize, 3usize, 4usize);
    let dt = DType::F32;
    let full = [bh, h, w, h, w];
    let total = bh * h * w * h * w;

    let mut g = Graph::new("bcast_5d");
    let lhs = g.input("lhs", Shape::new(&full, dt));
    let rhs = g.input("rhs", Shape::new(&[bh, h, w, 1, w], dt));
    let out = g.binary(BinaryOp::Add, lhs, rhs, Shape::new(&full, dt));
    g.set_outputs(vec![out]);

    let lhs_data: Vec<f32> = (0..total).map(|i| i as f32 * 0.01).collect();
    let rhs_data: Vec<f32> = (0..bh * h * w * w)
        .map(|i| (i as f32 + 100.0) * 0.01)
        .collect();

    // Flat lhs index li = (((b*h + hq)*w + wq)*h + hk)*w + wk. Dropping the hk
    // axis gives the rhs index (li / (h*w)) * w + wk, where wk = li % w.
    let expected: Vec<f32> = lhs_data
        .iter()
        .enumerate()
        .map(|(li, &l)| l + rhs_data[(li / (h * w)) * w + li % w])
        .collect();

    let mut arena = crate::arena::Arena::from_plan(rlx_opt::memory::plan_memory(&g));
    let sched = compile_thunks(&g, &arena);
    let lhs_off = arena.byte_offset(lhs);
    let rhs_off = arena.byte_offset(rhs);
    let out_off = arena.byte_offset(out);
    let buf = arena.raw_buf_mut();
    unsafe {
        for (off, data) in [(lhs_off, &lhs_data), (rhs_off, &rhs_data)] {
            let dst = buf.as_mut_ptr().add(off) as *mut f32;
            for (i, &v) in data.iter().enumerate() {
                *dst.add(i) = v;
            }
        }
    }
    execute_thunks(&sched, arena.raw_buf_mut());
    let actual: Vec<f32> = unsafe {
        let src = arena.raw_buf().as_ptr().add(out_off) as *const f32;
        (0..total).map(|i| *src.add(i)).collect()
    };

    // Largest absolute deviation; ties keep the earliest index.
    let (max_idx, max_diff) = actual
        .iter()
        .zip(&expected)
        .map(|(a, e)| (a - e).abs())
        .enumerate()
        .fold((0usize, 0f32), |best, (i, d)| if d > best.1 { (i, d) } else { best });
    assert!(
        max_diff < 1e-6,
        "5D mid-shape singleton broadcast wrong: max |Δ| = {max_diff} at idx {max_idx} \
         (actual={}, expected={})",
        actual[max_idx],
        expected[max_idx]
    );
}
22406
#[test]
fn layer_norm2d_and_conv_transpose2d_kernels() {
    // layer_norm2d_nchw on an 8-element input with gamma = 1 and beta = 0.
    let mut out = vec![0f32; 8];
    crate::kernels::layer_norm2d_nchw(
        &[1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0],
        &[1.0, 1.0],
        &[0.0, 0.0],
        &mut out,
        1,
        2,
        2,
        2,
        1e-5,
    );
    // These properties hold whichever equal-sized groups the kernel normalizes
    // over (channels per pixel, pixels per channel, or the whole tensor), given
    // unit gamma, zero beta and the standard biased variance:
    //   * every group has zero mean, so the whole output does too;
    //   * every group has variance var/(var+eps) ≈ 1, so the mean-square is ≈ 1;
    //   * the global minimum (1.0) maps below 0 and the maximum (8.0) above 0.
    let len = out.len() as f32;
    let mean = out.iter().sum::<f32>() / len;
    let mean_sq = out.iter().map(|v| v * v).sum::<f32>() / len;
    assert!(mean.abs() < 1e-5, "layer_norm2d output mean = {mean}, expected 0");
    assert!(
        (mean_sq - 1.0).abs() < 1e-3,
        "layer_norm2d output mean-square = {mean_sq}, expected 1"
    );
    assert!(
        out[0] < 0.0 && out[7] > 0.0,
        "layer_norm2d sign pattern wrong: {out:?}"
    );

    // conv_transpose2d_nchw: a 1-element input of 2.0 with 2x2 weights
    // [1, 0, 0, 1]. The output should be the weights scaled by 2, i.e.
    // [2, 0, 0, 2]. This assumes the trailing integer args describe a 2x2
    // kernel with stride 2 and no padding; TODO confirm against the kernel
    // signature. Check every output, including the two that must stay 0.
    let mut up = vec![0f32; 4];
    crate::kernels::conv_transpose2d_nchw(
        &[2.0],
        &[1.0, 0.0, 0.0, 1.0],
        &mut up,
        1,
        1,
        1,
        1,
        1,
        2,
        2,
        2,
        2,
        2,
        2,
        0,
        0,
        1,
        1,
        1,
    );
    let want = [2.0_f32, 0.0, 0.0, 2.0];
    for (i, (&got, &exp)) in up.iter().zip(&want).enumerate() {
        assert!(
            (got - exp).abs() < 1e-5,
            "conv_transpose2d up[{i}] = {got}, expected {exp}"
        );
    }
}
22449}