1#![allow(unsafe_op_in_unsafe_fn)]
24use crate::arena::Arena;
27use crate::op_registry::CpuKernel;
28use rlx_ir::op::{Activation, BinaryOp, CmpOp, ReduceOp};
29use rlx_ir::{Graph, NodeId, Op, Shape};
30use std::collections::HashMap;
31use std::sync::Arc;
32
33#[derive(Clone)]
35pub enum Thunk {
36 Nop,
38 Sgemm {
40 a: usize,
41 b: usize,
42 c: usize,
43 m: u32,
44 k: u32,
45 n: u32,
46 },
47 DenseSolveF64 {
53 a: usize,
54 b: usize,
55 x: usize,
56 n: u32,
57 nrhs: u32,
58 },
59 DenseSolveF32 {
62 a: usize,
63 b: usize,
64 x: usize,
65 n: u32,
66 nrhs: u32,
67 },
68 BatchedDenseSolveF64 {
73 a: usize,
74 b: usize,
75 x: usize,
76 batch: u32,
77 n: u32,
78 nrhs: u32,
79 },
80 BatchedDenseSolveF32 {
82 a: usize,
83 b: usize,
84 x: usize,
85 batch: u32,
86 n: u32,
87 nrhs: u32,
88 },
89 BatchedDgemmF64 {
95 a: usize,
96 b: usize,
97 c: usize,
98 batch: u32,
99 m: u32,
100 k: u32,
101 n: u32,
102 },
103 BatchedSgemm {
110 a: usize,
111 b: usize,
112 c: usize,
113 batch: u32,
114 m: u32,
115 k: u32,
116 n: u32,
117 },
118 Dgemm {
120 a: usize,
121 b: usize,
122 c: usize,
123 m: u32,
124 k: u32,
125 n: u32,
126 },
127 TransposeF64 {
131 src: usize,
132 dst: usize,
133 in_total: u32,
134 out_dims: Vec<u32>,
135 in_strides: Vec<u32>,
136 },
137 ActivationF64 {
141 src: usize,
142 dst: usize,
143 len: u32,
144 kind: Activation,
145 },
146 ComplexNormSqF32 {
150 src: usize,
151 dst: usize,
152 len: u32,
154 },
155 ComplexNormSqBackwardF32 {
159 z: usize,
160 g: usize,
161 dz: usize,
162 len: u32,
163 },
164 ConjugateC64 { src: usize, dst: usize, len: u32 },
167 ActivationC64 {
174 src: usize,
175 dst: usize,
176 len: u32,
177 kind: Activation,
178 },
179 ReduceSumF64 {
183 src: usize,
184 dst: usize,
185 outer: u32,
186 reduced: u32,
187 inner: u32,
188 },
189 CopyF64 { src: usize, dst: usize, len: u32 },
192 CopyI64 { src: usize, dst: usize, len: u32 },
194 CastF32ToI64 { src: usize, dst: usize, len: u32 },
196 CastI64ToF32 { src: usize, dst: usize, len: u32 },
198 CastBoolToI32 { src: usize, dst: usize, len: u32 },
200 CastI32ToF32 { src: usize, dst: usize, len: u32 },
202 BinaryFullF64 {
206 lhs: usize,
207 rhs: usize,
208 dst: usize,
209 len: u32,
210 lhs_len: u32,
211 rhs_len: u32,
212 op: BinaryOp,
213 out_dims_bcast: Vec<u32>,
216 bcast_lhs_strides: Vec<u32>,
217 bcast_rhs_strides: Vec<u32>,
218 },
219 ConcatF64 {
223 dst: usize,
224 outer: u32,
225 inner: u32,
226 total_axis: u32,
227 inputs: Vec<(usize, u32)>,
228 },
229 BinaryFullC64 {
237 lhs: usize,
238 rhs: usize,
239 dst: usize,
240 len: u32,
243 lhs_len: u32,
244 rhs_len: u32,
245 op: BinaryOp,
246 out_dims_bcast: Vec<u32>,
247 bcast_lhs_strides: Vec<u32>,
248 bcast_rhs_strides: Vec<u32>,
249 },
250 Scan {
259 body: Arc<ThunkSchedule>,
260 body_init: Arc<Vec<u8>>, body_input_off: usize, body_output_off: usize, outer_init_off: usize, outer_final_off: usize, length: u32,
266 carry_bytes: u32, save_trajectory: bool,
272 xs_inputs: Arc<Vec<(usize, usize, u32)>>,
277 bcast_inputs: Arc<Vec<(usize, usize, u32)>>,
283 num_checkpoints: u32,
289 },
290
291 ScanBackward {
299 body_vjp: Arc<ThunkSchedule>,
300 body_init: Arc<Vec<u8>>,
301 body_carry_in_off: usize, body_x_offs: Arc<Vec<usize>>, body_d_output_off: usize, body_dcarry_out_off: usize, outer_init_off: usize, outer_traj_off: usize, outer_upstream_off: usize, outer_xs_offs: Arc<Vec<(usize, u32)>>,
311 outer_dinit_off: usize, length: u32,
313 carry_bytes: u32,
314 carry_elem_size: u32,
320 save_trajectory: bool, num_checkpoints: u32,
327 forward_body: Option<Arc<ThunkSchedule>>,
331 forward_body_init: Option<Arc<Vec<u8>>>,
333 forward_body_carry_in_off: usize,
336 forward_body_output_off: usize,
337 forward_body_x_offs: Arc<Vec<usize>>,
340 },
341
342 ScanBackwardXs {
349 body_vjp: Arc<ThunkSchedule>,
350 body_init: Arc<Vec<u8>>,
351 body_carry_in_off: usize,
352 body_x_offs: Arc<Vec<usize>>,
353 body_d_output_off: usize,
354 body_dcarry_out_off: usize,
355 body_dxs_out_off: usize, outer_init_off: usize,
357 outer_traj_off: usize,
358 outer_upstream_off: usize,
359 outer_xs_offs: Arc<Vec<(usize, u32)>>,
360 outer_dxs_off: usize, length: u32,
362 carry_bytes: u32,
363 carry_elem_size: u32,
365 per_step_bytes: u32, save_trajectory: bool,
367 num_checkpoints: u32,
375 forward_body: Option<Arc<ThunkSchedule>>,
376 forward_body_init: Option<Arc<Vec<u8>>>,
377 forward_body_carry_in_off: usize,
378 forward_body_output_off: usize,
379 forward_body_x_offs: Arc<Vec<usize>>,
380 },
381 CustomFn {
386 body: Arc<ThunkSchedule>,
387 body_init: Arc<Vec<u8>>,
388 inputs: Arc<Vec<(usize, usize, u32)>>,
390 body_output_off: usize,
391 outer_output_off: usize,
392 out_bytes: u32,
393 },
394 FusedMmBiasAct {
396 a: usize,
397 w: usize,
398 bias: usize,
399 c: usize,
400 m: u32,
401 k: u32,
402 n: u32,
403 act: Option<Activation>,
404 },
405 FusedResidualLN {
407 x: usize,
408 res: usize,
409 bias: usize,
410 g: usize,
411 b: usize,
412 out: usize,
413 rows: u32,
414 h: u32,
415 eps: f32,
416 has_bias: bool,
417 },
418 FusedResidualRmsNorm {
420 x: usize,
421 res: usize,
422 bias: usize,
423 g: usize,
424 b: usize,
425 out: usize,
426 rows: u32,
427 h: u32,
428 eps: f32,
429 has_bias: bool,
430 },
431 BiasAdd {
433 src: usize,
434 bias: usize,
435 dst: usize,
436 m: u32,
437 n: u32,
438 },
439 BinaryFull {
454 lhs: usize,
455 rhs: usize,
456 dst: usize,
457 len: u32,
458 lhs_len: u32,
459 rhs_len: u32,
460 op: BinaryOp,
461 out_dims_bcast: Vec<u32>,
463 bcast_lhs_strides: Vec<u32>,
465 bcast_rhs_strides: Vec<u32>,
467 elem_bytes: u8,
469 },
470 ActivationInPlace {
472 data: usize,
473 len: u32,
474 act: Activation,
475 },
476 Gather {
478 table: usize,
479 table_len: u32,
480 idx: usize,
481 dst: usize,
482 num_idx: u32,
483 trailing: u32,
484 idx_i64: u8,
486 table_bytes: u8,
488 },
489 Narrow {
491 src: usize,
492 dst: usize,
493 outer: u32,
494 src_stride: u32,
495 dst_stride: u32,
496 inner: u32,
497 elem_bytes: u8,
498 },
499 Copy { src: usize, dst: usize, len: u32 },
501 LayerNorm {
503 src: usize,
504 g: usize,
505 b: usize,
506 dst: usize,
507 rows: u32,
508 h: u32,
509 eps: f32,
510 },
511 GroupNorm {
513 src: usize,
514 g: usize,
515 b: usize,
516 dst: usize,
517 n: u32,
518 c: u32,
519 h: u32,
520 w: u32,
521 num_groups: u32,
522 eps: f32,
523 },
524 BatchNormInference {
526 src: usize,
527 g: usize,
528 b: usize,
529 mean: usize,
530 var: usize,
531 dst: usize,
532 count: u32,
533 channels: u32,
534 eps: f32,
535 },
536 BatchNormInferenceBackwardInput {
537 x: usize,
538 gamma: usize,
539 mean: usize,
540 var: usize,
541 dy: usize,
542 dx: usize,
543 count: u32,
544 channels: u32,
545 eps: f32,
546 },
547 BatchNormInferenceBackwardGamma {
548 x: usize,
549 mean: usize,
550 var: usize,
551 dy: usize,
552 dgamma: usize,
553 count: u32,
554 channels: u32,
555 eps: f32,
556 },
557 BatchNormInferenceBackwardBeta {
558 dy: usize,
559 dbeta: usize,
560 count: u32,
561 channels: u32,
562 },
563 LayerNorm2d {
565 src: usize,
566 g: usize,
567 b: usize,
568 dst: usize,
569 n: u32,
570 c: u32,
571 h: u32,
572 w: u32,
573 eps: f32,
574 },
575 ConvTranspose2d {
577 src: usize,
578 weight: usize,
579 dst: usize,
580 n: u32,
581 c_in: u32,
582 h: u32,
583 w_in: u32,
584 c_out: u32,
585 h_out: u32,
586 w_out: u32,
587 kh: u32,
588 kw: u32,
589 sh: u32,
590 sw: u32,
591 ph: u32,
592 pw: u32,
593 dh: u32,
594 dw: u32,
595 groups: u32,
596 },
597 ResizeNearest2x {
599 src: usize,
600 dst: usize,
601 n: u32,
602 c: u32,
603 h: u32,
604 w: u32,
605 },
606 AxialRope2d {
608 src: usize,
609 dst: usize,
610 batch: u32,
611 seq: u32,
612 hidden: u32,
613 end_x: u32,
614 end_y: u32,
615 head_dim: u32,
616 num_heads: u32,
617 theta: f32,
618 repeat_factor: u32,
619 },
620 RmsNorm {
623 src: usize,
624 g: usize,
625 b: usize,
626 dst: usize,
627 rows: u32,
628 h: u32,
629 eps: f32,
630 },
631 Softmax { data: usize, rows: u32, cols: u32 },
633 Cumsum {
636 src: usize,
637 dst: usize,
638 rows: u32,
639 cols: u32,
640 exclusive: bool,
641 },
642 SelectiveScan {
646 x: usize,
647 delta: usize,
648 a: usize,
649 b: usize,
650 c: usize,
651 dst: usize,
652 batch: u32,
653 seq: u32,
654 hidden: u32,
655 state_size: u32,
656 },
657
658 GatedDeltaNet {
662 q: usize,
663 k: usize,
664 v: usize,
665 g: usize,
666 beta: usize,
667 state: usize,
670 dst: usize,
671 batch: u32,
672 seq: u32,
673 heads: u32,
674 state_size: u32,
675 },
676
677 Conv2D1x1 {
687 src: usize,
688 weight: usize,
689 dst: usize,
690 n: u32,
691 c_in: u32,
692 c_out: u32,
693 hw: u32,
694 },
695
696 DequantMatMul {
700 x: usize,
701 w_q: usize, scale: usize, zp: usize, dst: usize,
705 m: u32,
706 k: u32,
707 n: u32,
708 block_size: u32,
709 is_asymmetric: bool,
710 },
711
712 DequantMatMulGguf {
722 x: usize, w_q: usize, dst: usize, m: u32,
726 k: u32,
727 n: u32,
728 scheme: rlx_ir::quant::QuantScheme,
729 },
730
731 DequantMatMulInt4 {
733 x: usize,
734 w_q: usize,
735 scale: usize,
736 zp: usize,
737 dst: usize,
738 m: u32,
739 k: u32,
740 n: u32,
741 block_size: u32,
742 is_asymmetric: bool,
743 },
744
745 DequantMatMulFp8 {
747 x: usize,
748 w_q: usize,
749 scale: usize,
750 dst: usize,
751 m: u32,
752 k: u32,
753 n: u32,
754 e5m2: bool,
755 },
756
757 DequantMatMulNvfp4 {
759 x: usize,
760 w_q: usize,
761 scale: usize,
762 global_scale: usize,
763 dst: usize,
764 m: u32,
765 k: u32,
766 n: u32,
767 },
768
769 LoraMatMul {
773 x: usize,
774 w: usize,
775 a: usize,
776 b: usize,
777 dst: usize,
778 m: u32,
779 k: u32,
780 n: u32,
781 r: u32,
782 scale: f32,
783 },
784 Sample {
788 logits: usize,
789 dst: usize,
790 batch: u32,
791 vocab: u32,
792 top_k: u32, top_p: f32, temperature: f32, seed: u64,
796 },
797 Attention {
808 q: usize,
809 k: usize,
810 v: usize,
811 mask: usize,
812 out: usize,
813 batch: u32,
814 seq: u32,
816 kv_seq: u32,
818 heads: u32,
819 head_dim: u32,
820 mask_kind: rlx_ir::op::MaskKind,
821 q_row_stride: u32,
822 k_row_stride: u32,
823 v_row_stride: u32,
824 bhsd: bool,
832 },
833 AttentionBackward {
835 q: usize,
836 k: usize,
837 v: usize,
838 dy: usize,
839 mask: usize,
840 out: usize,
841 batch: u32,
842 seq: u32,
843 kv_seq: u32,
844 heads: u32,
845 head_dim: u32,
846 mask_kind: rlx_ir::op::MaskKind,
847 wrt: rlx_ir::op::AttentionBwdWrt,
848 bhsd: bool,
849 },
850 Rope {
856 src: usize,
857 cos: usize,
858 sin: usize,
859 dst: usize,
860 batch: u32,
861 seq: u32,
862 hidden: u32,
863 head_dim: u32,
864 n_rot: u32,
865 cos_len: u32,
866 src_row_stride: u32,
867 },
868 FusedAttnBlock {
871 hidden: usize,
872 qkv_w: usize,
873 out_w: usize,
874 mask: usize,
875 out: usize,
876 qkv_b: usize,
877 out_b: usize, cos: usize,
879 sin: usize,
880 cos_len: u32, batch: u32,
882 seq: u32,
883 hs: u32,
884 nh: u32,
885 dh: u32,
886 has_bias: bool,
887 has_rope: bool,
888 },
889 FusedBertLayer {
892 hidden: usize,
894 qkv_w: usize,
895 qkv_b: usize,
896 out_w: usize,
897 out_b: usize,
898 mask: usize,
899 ln1_g: usize,
901 ln1_b: usize,
902 eps1: f32,
903 fc1_w: usize,
905 fc1_b: usize,
906 fc2_w: usize,
907 fc2_b: usize,
908 ln2_g: usize,
910 ln2_b: usize,
911 eps2: f32,
912 out: usize,
914 batch: u32,
916 seq: u32,
917 hs: u32,
918 nh: u32,
919 dh: u32,
920 int_dim: u32,
921 },
922 FusedNomicLayer {
924 hidden: usize,
925 qkv_w: usize,
926 out_w: usize,
927 mask: usize,
928 cos: usize,
929 sin: usize,
930 cos_len: u32,
931 ln1_g: usize,
932 ln1_b: usize,
933 eps1: f32,
934 fc11_w: usize,
935 fc12_w: usize,
936 fc2_w: usize,
937 ln2_g: usize,
938 ln2_b: usize,
939 eps2: f32,
940 out: usize,
941 batch: u32,
942 seq: u32,
943 hs: u32,
944 nh: u32,
945 dh: u32,
946 int_dim: u32,
947 },
948 FusedSwiGLU {
952 src: usize,
953 dst: usize,
954 n_half: u32,
955 total: u32,
956 gate_first: bool,
957 },
958 Concat {
963 dst: usize,
964 outer: u32,
965 inner: u32,
966 total_axis: u32,
967 inputs: Vec<(usize, u32)>,
968 },
969 Compare {
971 lhs: usize,
972 rhs: usize,
973 dst: usize,
974 len: u32,
975 op: CmpOp,
976 inputs_i64: u8,
978 inputs_elem_bytes: u8,
980 dst_elem_bytes: u8,
982 },
983 Reduce {
991 src: usize,
992 dst: usize,
993 outer: u32,
994 reduced: u32,
995 inner: u32,
996 op: ReduceOp,
997 },
998 TopK {
1002 src: usize,
1003 dst: usize,
1004 outer: u32,
1005 axis_dim: u32,
1006 k: u32,
1007 indices_i64: u8,
1008 },
1009 GroupedMatMul {
1013 input: usize,
1014 weight: usize,
1015 expert_idx: usize,
1016 dst: usize,
1017 m: u32,
1018 k_dim: u32,
1019 n: u32,
1020 num_experts: u32,
1021 },
1022 DequantGroupedMatMulGguf {
1024 input: usize,
1025 w_q: usize,
1026 expert_idx: usize,
1027 dst: usize,
1028 m: u32,
1029 k_dim: u32,
1030 n: u32,
1031 num_experts: u32,
1032 scheme: rlx_ir::quant::QuantScheme,
1033 },
1034 DequantMoEWeightsGguf {
1036 w_q: usize,
1037 dst: usize,
1038 k_dim: u32,
1039 n: u32,
1040 num_experts: u32,
1041 scheme: rlx_ir::quant::QuantScheme,
1042 },
1043 ScatterAdd {
1046 updates: usize,
1047 indices: usize,
1048 dst: usize,
1049 num_updates: u32,
1050 out_dim: u32,
1051 trailing: u32,
1052 },
1053 Where {
1055 cond: usize,
1056 on_true: usize,
1057 on_false: usize,
1058 dst: usize,
1059 len: u32,
1060 elem_bytes: u8,
1061 cond_elem_bytes: u8,
1063 },
1064 Transpose {
1070 src: usize,
1071 dst: usize,
1072 in_total: u32,
1073 out_dims: Vec<u32>,
1074 in_strides: Vec<u32>,
1075 elem_bytes: u8,
1076 },
1077 GatherAxis {
1082 table: usize,
1083 idx: usize,
1084 dst: usize,
1085 outer: u32,
1086 axis_dim: u32,
1087 num_idx: u32,
1088 trailing: u32,
1089 idx_i64: u8,
1090 table_bytes: u8,
1091 },
1092 Pool2D {
1096 src: usize,
1097 dst: usize,
1098 n: u32,
1099 c: u32,
1100 h: u32,
1101 w: u32,
1102 h_out: u32,
1103 w_out: u32,
1104 kh: u32,
1105 kw: u32,
1106 sh: u32,
1107 sw: u32,
1108 ph: u32,
1109 pw: u32,
1110 kind: ReduceOp,
1111 },
1112 Conv2D {
1117 src: usize,
1118 weight: usize,
1119 dst: usize,
1120 n: u32,
1121 c_in: u32,
1122 h: u32,
1123 w: u32,
1124 c_out: u32,
1125 h_out: u32,
1126 w_out: u32,
1127 kh: u32,
1128 kw: u32,
1129 sh: u32,
1130 sw: u32,
1131 ph: u32,
1132 pw: u32,
1133 dh: u32,
1134 dw: u32,
1135 groups: u32,
1136 },
1137
1138 QMatMul {
1146 x: usize,
1147 w: usize,
1148 bias: usize,
1149 out: usize,
1150 m: u32,
1151 k: u32,
1152 n: u32,
1153 x_zp: i32,
1154 w_zp: i32,
1155 out_zp: i32,
1156 mult: f32,
1157 },
1158
1159 QConv2d {
1163 x: usize,
1164 w: usize,
1165 bias: usize,
1166 out: usize,
1167 n: u32,
1168 c_in: u32,
1169 h: u32,
1170 w_in: u32,
1171 c_out: u32,
1172 h_out: u32,
1173 w_out: u32,
1174 kh: u32,
1175 kw: u32,
1176 sh: u32,
1177 sw: u32,
1178 ph: u32,
1179 pw: u32,
1180 dh: u32,
1181 dw: u32,
1182 groups: u32,
1183 x_zp: i32,
1184 w_zp: i32,
1185 out_zp: i32,
1186 mult: f32,
1187 },
1188
1189 Quantize {
1196 x: usize,
1197 q: usize,
1198 len: u32,
1199 chan_axis: u32,
1200 chan_dim: u32,
1201 inner: u32,
1202 scales: Vec<f32>,
1203 zero_points: Vec<i32>,
1204 },
1205
1206 Dequantize {
1208 q: usize,
1209 x: usize,
1210 len: u32,
1211 chan_axis: u32,
1212 chan_dim: u32,
1213 inner: u32,
1214 scales: Vec<f32>,
1215 zero_points: Vec<i32>,
1216 },
1217
1218 FakeQuantize {
1229 x: usize,
1230 out: usize,
1231 len: u32,
1232 chan_axis: u32,
1233 chan_dim: u32,
1234 inner: u32,
1235 bits: u8,
1236 ste: rlx_ir::op::SteKind,
1240 scale_mode: rlx_ir::op::ScaleMode,
1245 state_off: Option<usize>,
1249 },
1250
1251 FakeQuantizeBackward {
1256 x: usize,
1257 dy: usize,
1258 dx: usize,
1259 len: u32,
1260 chan_axis: u32,
1261 chan_dim: u32,
1262 inner: u32,
1263 bits: u8,
1264 ste: rlx_ir::op::SteKind,
1265 },
1266
1267 FakeQuantizeLSQ {
1270 x: usize,
1271 scale_off: usize,
1272 out: usize,
1273 len: u32,
1274 chan_axis: u32,
1275 chan_dim: u32,
1276 inner: u32,
1277 bits: u8,
1278 },
1279
1280 FakeQuantizeLSQBackwardX {
1283 x: usize,
1284 scale_off: usize,
1285 dy: usize,
1286 dx: usize,
1287 len: u32,
1288 chan_axis: u32,
1289 chan_dim: u32,
1290 inner: u32,
1291 bits: u8,
1292 },
1293
1294 FakeQuantizeLSQBackwardScale {
1299 x: usize,
1300 scale_off: usize,
1301 dy: usize,
1302 dscale: usize,
1303 len: u32,
1304 chan_axis: u32,
1305 chan_dim: u32,
1306 inner: u32,
1307 bits: u8,
1308 },
1309
1310 ReluBackward {
1312 x: usize,
1313 dy: usize,
1314 dx: usize,
1315 len: u32,
1316 },
1317 ReluBackwardF64 {
1323 x: usize,
1324 dy: usize,
1325 dx: usize,
1326 len: u32,
1327 },
1328
1329 ActivationBackward {
1334 x: usize,
1335 dy: usize,
1336 dx: usize,
1337 len: u32,
1338 kind: Activation,
1339 },
1340 ActivationBackwardF64 {
1346 x: usize,
1347 dy: usize,
1348 dx: usize,
1349 len: u32,
1350 kind: Activation,
1351 },
1352
1353 LayerNormBackwardInput {
1356 x: usize,
1357 gamma: usize,
1358 dy: usize,
1359 dx: usize,
1360 rows: u32,
1361 h: u32,
1362 eps: f32,
1363 },
1364
1365 LayerNormBackwardGamma {
1367 x: usize,
1368 dy: usize,
1369 dgamma: usize,
1370 rows: u32,
1371 h: u32,
1372 eps: f32,
1373 },
1374
1375 RmsNormBackwardInput {
1376 x: usize,
1377 gamma: usize,
1378 beta: usize,
1379 dy: usize,
1380 dx: usize,
1381 rows: u32,
1382 h: u32,
1383 eps: f32,
1384 },
1385 RmsNormBackwardGamma {
1386 x: usize,
1387 gamma: usize,
1388 beta: usize,
1389 dy: usize,
1390 dgamma: usize,
1391 rows: u32,
1392 h: u32,
1393 eps: f32,
1394 },
1395 RmsNormBackwardBeta {
1396 x: usize,
1397 gamma: usize,
1398 beta: usize,
1399 dy: usize,
1400 dbeta: usize,
1401 rows: u32,
1402 h: u32,
1403 eps: f32,
1404 },
1405 RopeBackward {
1406 dy: usize,
1407 cos: usize,
1408 sin: usize,
1409 dx: usize,
1410 batch: u32,
1411 seq: u32,
1412 hidden: u32,
1413 head_dim: u32,
1414 n_rot: u32,
1415 cos_len: u32,
1416 },
1417 CumsumBackward {
1418 dy: usize,
1419 dx: usize,
1420 rows: u32,
1421 cols: u32,
1422 exclusive: bool,
1423 },
1424 GatherBackward {
1425 dy: usize,
1426 indices: usize,
1427 dst: usize,
1428 outer: u32,
1429 axis_dim: u32,
1430 num_idx: u32,
1431 trailing: u32,
1432 },
1433
1434 GroupNormBackwardInput {
1435 x: usize,
1436 gamma: usize,
1437 beta: usize,
1438 dy: usize,
1439 dx: usize,
1440 n: u32,
1441 c: u32,
1442 h: u32,
1443 w: u32,
1444 num_groups: u32,
1445 eps: f32,
1446 },
1447 GroupNormBackwardGamma {
1448 x: usize,
1449 dy: usize,
1450 dgamma: usize,
1451 n: u32,
1452 c: u32,
1453 h: u32,
1454 w: u32,
1455 num_groups: u32,
1456 eps: f32,
1457 },
1458 GroupNormBackwardBeta {
1459 dy: usize,
1460 dbeta: usize,
1461 n: u32,
1462 c: u32,
1463 h: u32,
1464 w: u32,
1465 },
1466
1467 MaxPool2dBackward {
1473 x: usize,
1474 dy: usize,
1475 dx: usize,
1476 n: u32,
1477 c: u32,
1478 h: u32,
1479 w: u32,
1480 h_out: u32,
1481 w_out: u32,
1482 kh: u32,
1483 kw: u32,
1484 sh: u32,
1485 sw: u32,
1486 ph: u32,
1487 pw: u32,
1488 },
1489
1490 Conv2dBackwardInput {
1494 dy: usize,
1495 w: usize,
1496 dx: usize,
1497 n: u32,
1498 c_in: u32,
1499 h: u32,
1500 w_in: u32,
1501 c_out: u32,
1502 h_out: u32,
1503 w_out: u32,
1504 kh: u32,
1505 kw: u32,
1506 sh: u32,
1507 sw: u32,
1508 ph: u32,
1509 pw: u32,
1510 dh: u32,
1511 dw: u32,
1512 groups: u32,
1513 },
1514
1515 Conv2dBackwardWeight {
1519 x: usize,
1520 dy: usize,
1521 dw: usize,
1522 n: u32,
1523 c_in: u32,
1524 h: u32,
1525 w: u32,
1526 c_out: u32,
1527 h_out: u32,
1528 w_out: u32,
1529 kh: u32,
1530 kw: u32,
1531 sh: u32,
1532 sw: u32,
1533 ph: u32,
1534 pw: u32,
1535 dh: u32,
1536 dw_dil: u32,
1537 groups: u32,
1538 },
1539
1540 Im2Col {
1543 x: usize,
1544 col: usize,
1545 n: u32,
1546 c_in: u32,
1547 h: u32,
1548 w: u32,
1549 h_out: u32,
1550 w_out: u32,
1551 kh: u32,
1552 kw: u32,
1553 sh: u32,
1554 sw: u32,
1555 ph: u32,
1556 pw: u32,
1557 dh: u32,
1558 dw_dil: u32,
1559 },
1560
1561 SoftmaxCrossEntropy {
1565 logits: usize,
1566 labels: usize,
1567 dst: usize,
1568 n: u32,
1569 c: u32,
1570 },
1571
1572 SoftmaxCrossEntropyBackward {
1575 logits: usize,
1576 labels: usize,
1577 d_loss: usize,
1578 dlogits: usize,
1579 n: u32,
1580 c: u32,
1581 },
1582
1583 CustomOp {
1589 kernel: Arc<dyn CpuKernel>,
1590 inputs: Vec<(usize, u32, Shape)>, output: (usize, u32, Shape), attrs: Vec<u8>,
1593 },
1594
1595 GaussianSplatRender {
1605 positions_off: usize,
1606 positions_len: usize,
1607 scales_off: usize,
1608 scales_len: usize,
1609 rotations_off: usize,
1610 rotations_len: usize,
1611 opacities_off: usize,
1612 opacities_len: usize,
1613 colors_off: usize,
1614 colors_len: usize,
1615 sh_coeffs_off: usize,
1616 sh_coeffs_len: usize,
1617 meta_off: usize,
1618 dst_off: usize,
1619 dst_len: usize,
1620 width: u32,
1621 height: u32,
1622 tile_size: u32,
1623 radius_scale: f32,
1624 alpha_cutoff: f32,
1625 max_splat_steps: u32,
1626 transmittance_threshold: f32,
1627 max_list_entries: u32,
1628 },
1629 GaussianSplatRenderBackward {
1630 positions_off: usize,
1631 positions_len: usize,
1632 scales_off: usize,
1633 scales_len: usize,
1634 rotations_off: usize,
1635 rotations_len: usize,
1636 opacities_off: usize,
1637 opacities_len: usize,
1638 colors_off: usize,
1639 colors_len: usize,
1640 sh_coeffs_off: usize,
1641 sh_coeffs_len: usize,
1642 meta_off: usize,
1643 d_loss_off: usize,
1644 d_loss_len: usize,
1645 packed_off: usize,
1646 packed_len: usize,
1647 width: u32,
1648 height: u32,
1649 tile_size: u32,
1650 radius_scale: f32,
1651 alpha_cutoff: f32,
1652 max_splat_steps: u32,
1653 transmittance_threshold: f32,
1654 max_list_entries: u32,
1655 loss_grad_clip: f32,
1656 sh_band: u32,
1657 max_anisotropy: f32,
1658 },
1659 GaussianSplatPrepare {
1661 positions_off: usize,
1662 positions_len: usize,
1663 scales_off: usize,
1664 scales_len: usize,
1665 rotations_off: usize,
1666 rotations_len: usize,
1667 opacities_off: usize,
1668 opacities_len: usize,
1669 colors_off: usize,
1670 colors_len: usize,
1671 sh_coeffs_off: usize,
1672 sh_coeffs_len: usize,
1673 meta_off: usize,
1674 meta_len: usize,
1675 prep_off: usize,
1676 prep_len: usize,
1677 width: u32,
1678 height: u32,
1679 tile_size: u32,
1680 radius_scale: f32,
1681 alpha_cutoff: f32,
1682 max_splat_steps: u32,
1683 transmittance_threshold: f32,
1684 max_list_entries: u32,
1685 },
1686 GaussianSplatRasterize {
1688 prep_off: usize,
1689 prep_len: usize,
1690 meta_off: usize,
1691 meta_len: usize,
1692 dst_off: usize,
1693 dst_len: usize,
1694 count: usize,
1695 width: u32,
1696 height: u32,
1697 tile_size: u32,
1698 alpha_cutoff: f32,
1699 max_splat_steps: u32,
1700 transmittance_threshold: f32,
1701 max_list_entries: u32,
1702 },
1703 Fft1d {
1704 src: usize,
1705 dst: usize,
1706 outer: u32,
1707 n_complex: u32,
1708 inverse: bool,
1709 norm_tag: u32,
1710 dtype: rlx_ir::DType,
1711 },
1712 FftButterflyStage {
1713 state_src: usize,
1714 state_dst: usize,
1715 gate_src: usize,
1716 rev_src: usize,
1717 tw_re_src: usize,
1718 tw_im_src: usize,
1719 batch: u32,
1720 n_fft: u32,
1721 stage: u32,
1722 },
1723 LogMel {
1724 spec: usize,
1725 filters: usize,
1726 dst: usize,
1727 outer: u32,
1728 n_fft: u32,
1729 n_bins: u32,
1730 n_mels: u32,
1731 },
1732 LogMelBackward {
1733 spec: usize,
1734 filters: usize,
1735 dy: usize,
1736 dst: usize,
1737 outer: u32,
1738 n_fft: u32,
1739 n_bins: u32,
1740 n_mels: u32,
1741 },
1742 WelchPeaks {
1743 spec: usize,
1744 dst: usize,
1745 welch_batch: u32,
1746 n_fft: u32,
1747 n_segments: u32,
1748 k: u32,
1749 },
1750}
1751
1752#[derive(Clone)]
1755pub struct ThunkSchedule {
1756 pub thunks: Vec<Thunk>,
1757 pub moe_resident: Option<std::sync::Arc<[bool]>>,
1759 pub moe_resident_layers: Option<std::sync::Arc<Vec<std::sync::Arc<[bool]>>>>,
1761 pub moe_topk_capture: Option<std::sync::Arc<crate::moe_topk_capture::MoeTopkCapture>>,
1763 pub mask_threshold: f32,
1765 pub mask_neg_inf: f32,
1766 pub score_skip: f32,
1767 pub compiled_fns: Vec<Arc<dyn Fn(*mut u8) + Send + Sync>>,
1773}
1774
1775impl ThunkSchedule {
1776 pub fn strip_nops(&mut self) {
1777 self.thunks.retain(|t| !matches!(t, Thunk::Nop));
1778 self.compiled_fns.clear();
1781 }
1782}
1783
1784fn node_offset(arena: &Arena, id: NodeId) -> usize {
1786 if arena.has_buffer(id) {
1787 arena.byte_offset(id)
1788 } else {
1789 usize::MAX
1790 }
1791}
1792
1793fn thunk_read_offsets(t: &Thunk) -> Vec<usize> {
1799 match t {
1800 Thunk::Sgemm { a, b, .. } => vec![*a, *b],
1801 Thunk::DenseSolveF64 { a, b, .. } => vec![*a, *b],
1802 Thunk::DenseSolveF32 { a, b, .. } => vec![*a, *b],
1803 Thunk::BatchedDenseSolveF64 { a, b, .. } => vec![*a, *b],
1804 Thunk::BatchedDgemmF64 { a, b, .. } => vec![*a, *b],
1805 Thunk::BatchedSgemm { a, b, .. } => vec![*a, *b],
1806 Thunk::FusedMmBiasAct { a, w, bias, .. } => vec![*a, *w, *bias],
1807 Thunk::BiasAdd { src, bias, .. } => vec![*src, *bias],
1808 Thunk::BinaryFull { lhs, rhs, .. } => vec![*lhs, *rhs],
1809 Thunk::BinaryFullF64 { lhs, rhs, .. } => vec![*lhs, *rhs],
1810 Thunk::BinaryFullC64 { lhs, rhs, .. } => vec![*lhs, *rhs],
1811 Thunk::ComplexNormSqF32 { src, .. } => vec![*src],
1812 Thunk::ComplexNormSqBackwardF32 { z, g, .. } => vec![*z, *g],
1813 Thunk::ConjugateC64 { src, .. } => vec![*src],
1814 Thunk::Scan {
1815 outer_init_off,
1816 xs_inputs,
1817 ..
1818 } => {
1819 let mut v = vec![*outer_init_off];
1820 for (_, outer_xs_off, _) in xs_inputs.iter() {
1821 v.push(*outer_xs_off);
1822 }
1823 v
1824 }
1825 Thunk::ScanBackward {
1826 outer_init_off,
1827 outer_traj_off,
1828 outer_upstream_off,
1829 outer_xs_offs,
1830 ..
1831 } => {
1832 let mut v = vec![*outer_init_off, *outer_traj_off, *outer_upstream_off];
1833 for (off, _) in outer_xs_offs.iter() {
1834 v.push(*off);
1835 }
1836 v
1837 }
1838 Thunk::ScanBackwardXs {
1839 outer_init_off,
1840 outer_traj_off,
1841 outer_upstream_off,
1842 outer_xs_offs,
1843 ..
1844 } => {
1845 let mut v = vec![*outer_init_off, *outer_traj_off, *outer_upstream_off];
1846 for (off, _) in outer_xs_offs.iter() {
1847 v.push(*off);
1848 }
1849 v
1850 }
1851 Thunk::CustomFn { inputs, .. } => {
1852 inputs.iter().map(|(_, outer_off, _)| *outer_off).collect()
1853 }
1854 Thunk::ActivationInPlace { data, .. } => vec![*data],
1855 Thunk::LayerNorm { src, g, b, .. } | Thunk::GroupNorm { src, g, b, .. } => {
1856 vec![*src, *g, *b]
1857 }
1858 Thunk::BatchNormInference {
1859 src,
1860 g,
1861 b,
1862 mean,
1863 var,
1864 ..
1865 } => vec![*src, *g, *b, *mean, *var],
1866 Thunk::ResizeNearest2x { src, .. } => vec![*src],
1867 Thunk::AxialRope2d { src, .. } => vec![*src],
1868 Thunk::FusedResidualLN {
1869 x, res, bias, g, b, ..
1870 } => vec![*x, *res, *bias, *g, *b],
1871 Thunk::FusedResidualRmsNorm {
1872 x, res, bias, g, b, ..
1873 } => vec![*x, *res, *bias, *g, *b],
1874 Thunk::RmsNorm { src, g, b, .. } => vec![*src, *g, *b],
1875 Thunk::Softmax { data, .. } => vec![*data],
1876 Thunk::Cumsum { src, .. } => vec![*src],
1877 Thunk::Sample { logits, .. } => vec![*logits],
1878 Thunk::LoraMatMul { x, w, a, b, .. } => vec![*x, *w, *a, *b],
1879 Thunk::DequantMatMul {
1880 x, w_q, scale, zp, ..
1881 } => vec![*x, *w_q, *scale, *zp],
1882 Thunk::DequantMatMulGguf { x, w_q, .. } => vec![*x, *w_q],
1883 Thunk::DequantMatMulInt4 {
1884 x, w_q, scale, zp, ..
1885 } => vec![*x, *w_q, *scale, *zp],
1886 Thunk::DequantMatMulFp8 { x, w_q, scale, .. } => vec![*x, *w_q, *scale],
1887 Thunk::DequantMatMulNvfp4 {
1888 x,
1889 w_q,
1890 scale,
1891 global_scale,
1892 ..
1893 } => vec![*x, *w_q, *scale, *global_scale],
1894 Thunk::Conv2D1x1 { src, weight, .. } => vec![*src, *weight],
1895 Thunk::SelectiveScan {
1896 x, delta, a, b, c, ..
1897 } => vec![*x, *delta, *a, *b, *c],
1898 Thunk::GatedDeltaNet {
1899 q,
1900 k,
1901 v,
1902 g,
1903 beta,
1904 state,
1905 ..
1906 } => {
1907 let mut v = vec![*q, *k, *v, *g, *beta];
1908 if *state != 0 {
1909 v.push(*state);
1910 }
1911 v
1912 }
1913 Thunk::Attention { q, k, v, mask, .. } => vec![*q, *k, *v, *mask],
1914 Thunk::AttentionBackward {
1915 q, k, v, dy, mask, ..
1916 } => {
1917 let mut v = vec![*q, *k, *v, *dy];
1918 if *mask != 0 {
1919 v.push(*mask);
1920 }
1921 v
1922 }
1923 Thunk::Rope { src, cos, sin, .. } => vec![*src, *cos, *sin],
1924 Thunk::FusedAttnBlock {
1925 hidden,
1926 qkv_w,
1927 out_w,
1928 mask,
1929 qkv_b,
1930 out_b,
1931 cos,
1932 sin,
1933 ..
1934 } => vec![*hidden, *qkv_w, *out_w, *mask, *qkv_b, *out_b, *cos, *sin],
1935 Thunk::FusedSwiGLU { src, .. } => vec![*src],
1936 Thunk::Concat { inputs, .. } => inputs.iter().map(|(off, _)| *off).collect(),
1937 Thunk::ConcatF64 { inputs, .. } => inputs.iter().map(|(off, _)| *off).collect(),
1938 Thunk::Narrow { src, .. } => vec![*src],
1939 Thunk::Copy { src, .. } => vec![*src],
1940 Thunk::Gather { table, idx, .. } => vec![*table, *idx],
1941 _ => vec![],
1945 }
1946}
1947
/// Reference matmul against block-quantized int8 weights:
/// `out[i, j] = sum over p of x[i, p] * (w[p, j] - z) * s`.
///
/// Indexing used by the body:
/// * `x[i * k + p]`, `w_bytes[p * n + j]`, `out[i * n + j]`;
/// * `s = scales[(p / block_size) * n + j]`;
/// * `z = zps[(p / block_size) * n + j]` when `asym`, otherwise `0.0`
///   (`zps` is not touched in that case).
///
/// With `k == 0` every output is written as `0.0` and `block_size` is never
/// consulted.
///
/// # Panics
/// Panics if a slice is too short for the indices above, or on division by
/// zero when `block_size == 0` and `k > 0`.
#[allow(clippy::too_many_arguments)]
fn dequant_matmul_int8(
    x: &[f32],
    w_bytes: &[i8],
    scales: &[f32],
    zps: &[f32],
    out: &mut [f32],
    m: usize,
    k: usize,
    n: usize,
    block_size: usize,
    asym: bool,
) {
    for i in 0..m {
        for j in 0..n {
            let mut acc = 0f32;
            for p in 0..k {
                let block = p / block_size;
                let s = scales[block * n + j];
                let z = if asym { zps[block * n + j] } else { 0.0 };
                let q = w_bytes[p * n + j] as f32;
                acc += x[i * k + p] * ((q - z) * s);
            }
            out[i * n + j] = acc;
        }
    }
}
1991
/// Reference matmul against block-quantized 4-bit weights.
///
/// Weight element `(p, j)` has flat index `p * n + j`; two elements share a
/// byte at `flat / 2`, the even one in the low nibble and the odd one in the
/// high nibble. Nibbles are used as unsigned values `0..=15`. Scales and
/// zero points are indexed `(p / block_size) * n + j`; `zps` is only read
/// when `asym` is set.
#[allow(clippy::too_many_arguments)]
fn dequant_matmul_int4(
    x: &[f32],
    w_bytes: &[u8],
    scales: &[f32],
    zps: &[f32],
    out: &mut [f32],
    m: usize,
    k: usize,
    n: usize,
    block_size: usize,
    asym: bool,
) {
    let unpack = |flat: usize| -> u8 {
        let packed = w_bytes[flat / 2];
        if flat % 2 == 0 {
            packed & 0x0F
        } else {
            packed >> 4
        }
    };
    for row in 0..m {
        for col in 0..n {
            // Explicit fold from +0.0 keeps the original accumulation order.
            out[row * n + col] = (0..k).fold(0f32, |acc, p| {
                let blk = p / block_size;
                let s = scales[blk * n + col];
                let z = if asym { zps[blk * n + col] } else { 0.0 };
                let weight = (unpack(p * n + col) as f32 - z) * s;
                acc + x[row * k + p] * weight
            });
        }
    }
}
2025
/// Decodes one byte laid out as 1 sign bit, 4 exponent bits (bias 7) and
/// 3 mantissa bits.
///
/// * exponent 0, mantissa 0 -> `0.0` (positive, regardless of the sign bit);
/// * exponent 0 otherwise -> subnormal `sign * mant * 2^-9`;
/// * exponent 15 -> signed infinity for mantissa 0, NaN otherwise;
/// * anything else -> `sign * (1 + mant / 8) * 2^(exp - 7)`.
///
/// NOTE(review): treating exponent 15 as Inf/NaN is the IEEE-style reading.
/// The OCP "E4M3" encoding instead uses those codes for finite values and
/// reserves only `S.1111.111` for NaN — confirm which encoding the producers
/// of these bytes use.
fn fp8_e4m3_to_f32(b: u8) -> f32 {
    let negative = b & 0x80 != 0;
    let sign = if negative { -1.0 } else { 1.0 };
    let mant = b & 0x07;
    match (b >> 3) & 0x0F {
        0 if mant == 0 => 0.0,
        0 => sign * (mant as f32) * 2f32.powi(-9),
        0x0F if mant == 0 => sign * f32::INFINITY,
        0x0F => f32::NAN,
        exp => sign * (1.0 + mant as f32 / 8.0) * 2f32.powi(exp as i32 - 7),
    }
}
2045
/// Decodes one byte laid out as 1 sign bit, 5 exponent bits (bias 15) and
/// 2 mantissa bits.
///
/// * exponent 0, mantissa 0 -> `0.0` (positive, regardless of the sign bit);
/// * exponent 0 otherwise -> subnormal `sign * mant * 2^-16`;
/// * exponent 31 -> signed infinity for mantissa 0, NaN otherwise;
/// * anything else -> `sign * (1 + mant / 4) * 2^(exp - 15)`.
fn fp8_e5m2_to_f32(b: u8) -> f32 {
    let negative = b & 0x80 != 0;
    let sign = if negative { -1.0 } else { 1.0 };
    let mant = b & 0x03;
    match (b >> 2) & 0x1F {
        0 if mant == 0 => 0.0,
        0 => sign * (mant as f32) * 2f32.powi(-16),
        0x1F if mant == 0 => sign * f32::INFINITY,
        0x1F => f32::NAN,
        exp => sign * (1.0 + mant as f32 / 4.0) * 2f32.powi(exp as i32 - 15),
    }
}
2065
2066#[allow(clippy::too_many_arguments)]
2067fn dequant_matmul_fp8(
2068 x: &[f32],
2069 w_bytes: &[u8],
2070 scales: &[f32],
2071 out: &mut [f32],
2072 m: usize,
2073 k: usize,
2074 n: usize,
2075 e5m2: bool,
2076) {
2077 let dequant = if e5m2 {
2078 fp8_e5m2_to_f32
2079 } else {
2080 fp8_e4m3_to_f32
2081 };
2082 for i in 0..m {
2083 for j in 0..n {
2084 let mut acc = 0f32;
2085 for p in 0..k {
2086 let w = dequant(w_bytes[p * n + j]);
2087 let s = scales.get(j).copied().unwrap_or(1.0);
2088 acc += x[i * k + p] * w * s;
2089 }
2090 out[i * n + j] = acc;
2091 }
2092 }
2093}
2094
2095#[allow(clippy::too_many_arguments)]
2096pub fn dequant_matmul_nvfp4(
2097 x: &[f32],
2098 w_bytes: &[u8],
2099 scale_bytes: &[u8],
2100 global_scale: f32,
2101 out: &mut [f32],
2102 m: usize,
2103 k: usize,
2104 n: usize,
2105) {
2106 use rlx_ir::{NVFP4_GROUP_SIZE, fp4_e2m1_to_f32, fp8_e4m3_scale_to_f32};
2107 let gs = NVFP4_GROUP_SIZE;
2108 for i in 0..m {
2109 for j in 0..n {
2110 let mut acc = 0f32;
2111 for p in 0..k {
2112 let byte_idx = (p * n + j) / 2;
2113 let nibble = if (p * n + j) & 1 == 0 {
2114 w_bytes[byte_idx] & 0x0F
2115 } else {
2116 w_bytes[byte_idx] >> 4
2117 };
2118 let block = p / gs;
2119 let scale = fp8_e4m3_scale_to_f32(scale_bytes[block * n + j]);
2120 let w = fp4_e2m1_to_f32(nibble) * scale * global_scale;
2121 acc += x[i * k + p] * w;
2122 }
2123 out[i * n + j] = acc;
2124 }
2125 }
2126}
2127
2128fn sample_row(
2137 logits: &[f32],
2138 top_k: usize,
2139 top_p: f32,
2140 temperature: f32,
2141 rng: &mut rlx_ir::Philox4x32,
2142) -> usize {
2143 let v = logits.len();
2144 if v == 0 {
2145 return 0;
2146 }
2147 let temp = temperature.max(1e-6);
2148 let mut scaled: Vec<f32> = logits.iter().map(|&x| x / temp).collect();
2150
2151 if top_k > 0 && top_k < v {
2153 let mut indexed: Vec<(usize, f32)> = scaled.iter().copied().enumerate().collect();
2155 indexed.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap());
2158 let cutoff = indexed[top_k - 1].1;
2159 for x in scaled.iter_mut() {
2160 if *x < cutoff {
2161 *x = f32::NEG_INFINITY;
2162 }
2163 }
2164 }
2165
2166 let mut max_l = f32::NEG_INFINITY;
2168 for &x in &scaled {
2169 if x > max_l {
2170 max_l = x;
2171 }
2172 }
2173 let mut sum = 0.0f32;
2174 for x in scaled.iter_mut() {
2175 *x = (*x - max_l).exp();
2176 sum += *x;
2177 }
2178 let inv = 1.0 / sum.max(f32::MIN_POSITIVE);
2179 for x in scaled.iter_mut() {
2180 *x *= inv;
2181 }
2182
2183 if top_p < 1.0 {
2186 let mut indexed: Vec<(usize, f32)> = scaled.iter().copied().enumerate().collect();
2187 indexed.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap());
2188 let mut cum = 0.0f32;
2189 let mut keep = vec![false; v];
2190 for (idx, p) in indexed.iter() {
2191 keep[*idx] = true;
2192 cum += *p;
2193 if cum >= top_p {
2194 break;
2195 }
2196 }
2197 let mut new_sum = 0.0f32;
2198 for (i, x) in scaled.iter_mut().enumerate() {
2199 if !keep[i] {
2200 *x = 0.0;
2201 }
2202 new_sum += *x;
2203 }
2204 let inv = 1.0 / new_sum.max(f32::MIN_POSITIVE);
2205 for x in scaled.iter_mut() {
2206 *x *= inv;
2207 }
2208 }
2209
2210 let r = rng.next_f32();
2212 let mut acc = 0.0f32;
2213 for (i, &p) in scaled.iter().enumerate() {
2214 acc += p;
2215 if r <= acc {
2216 return i;
2217 }
2218 }
2219 v - 1 }
2221
2222#[inline]
2226fn apply_synthetic_mask(
2227 scores: &mut [f32],
2228 q_seq: usize,
2229 k_seq: usize,
2230 kind: rlx_ir::op::MaskKind,
2231) {
2232 let neg = crate::config::RuntimeConfig::global().attn_mask_neg_inf;
2233 let q_offset = k_seq.saturating_sub(q_seq);
2234 match kind {
2235 rlx_ir::op::MaskKind::None | rlx_ir::op::MaskKind::Custom | rlx_ir::op::MaskKind::Bias => {}
2236 rlx_ir::op::MaskKind::Causal => {
2237 for qi in 0..q_seq {
2238 let abs_q = q_offset + qi;
2239 for ki in (abs_q + 1)..k_seq {
2240 scores[qi * k_seq + ki] = neg;
2241 }
2242 }
2243 }
2244 rlx_ir::op::MaskKind::SlidingWindow(w) => {
2245 for qi in 0..q_seq {
2246 let abs_q = q_offset + qi;
2247 let lo = abs_q.saturating_sub(w);
2248 for ki in 0..k_seq {
2249 if ki < lo || ki > abs_q {
2250 scores[qi * k_seq + ki] = neg;
2251 }
2252 }
2253 }
2254 }
2255 }
2256}
2257
2258fn conv_nchw_dims(shape: &Shape) -> (u32, u32, u32, u32) {
2260 match shape.rank() {
2261 3 => (
2262 shape.dim(0).unwrap_static() as u32,
2263 shape.dim(1).unwrap_static() as u32,
2264 1,
2265 shape.dim(2).unwrap_static() as u32,
2266 ),
2267 4 => (
2268 shape.dim(0).unwrap_static() as u32,
2269 shape.dim(1).unwrap_static() as u32,
2270 shape.dim(2).unwrap_static() as u32,
2271 shape.dim(3).unwrap_static() as u32,
2272 ),
2273 r => panic!("conv_nchw_dims: expected rank 3 or 4, got {r}"),
2274 }
2275}
2276
2277pub fn compile_thunks(graph: &Graph, arena: &Arena) -> ThunkSchedule {
2279 let mut thunks = Vec::with_capacity(graph.len());
2280
2281 for node in graph.nodes() {
2282 if rlx_opt::is_pure_view(graph, node) {
2286 thunks.push(Thunk::Nop);
2287 continue;
2288 }
2289 let t = match &node.op {
2290 Op::Input { .. } | Op::Param { .. } | Op::Constant { .. } => Thunk::Nop,
2291
2292 Op::FusedMatMulBiasAct { activation } => {
2293 let shape = &node.shape;
2294 let n = shape.dim(shape.rank() - 1).unwrap_static();
2295 let total = shape.num_elements().unwrap();
2296 let m = total / n;
2297 let a_len = get_len(graph, node.inputs[0]);
2298 let k = a_len / m;
2299 Thunk::FusedMmBiasAct {
2300 a: node_offset(arena, node.inputs[0]),
2301 w: node_offset(arena, node.inputs[1]),
2302 bias: node_offset(arena, node.inputs[2]),
2303 c: node_offset(arena, node.id),
2304 m: m as u32,
2305 k: k as u32,
2306 n: n as u32,
2307 act: *activation,
2308 }
2309 }
2310
2311 Op::FusedResidualLN { has_bias, eps } => {
2312 let h = node.shape.dim(node.shape.rank() - 1).unwrap_static();
2313 let total = node.shape.num_elements().unwrap();
2314 let rows = total / h;
2315 let (g_idx, b_idx) = if *has_bias { (3, 4) } else { (2, 3) };
2316 Thunk::FusedResidualLN {
2317 x: node_offset(arena, node.inputs[0]),
2318 res: node_offset(arena, node.inputs[1]),
2319 bias: if *has_bias {
2320 node_offset(arena, node.inputs[2])
2321 } else {
2322 0
2323 },
2324 g: node_offset(arena, node.inputs[g_idx]),
2325 b: node_offset(arena, node.inputs[b_idx]),
2326 out: node_offset(arena, node.id),
2327 rows: rows as u32,
2328 h: h as u32,
2329 eps: *eps,
2330 has_bias: *has_bias,
2331 }
2332 }
2333
2334 Op::FusedResidualRmsNorm { has_bias, eps } => {
2335 let h = node.shape.dim(node.shape.rank() - 1).unwrap_static();
2336 let total = node.shape.num_elements().unwrap();
2337 let rows = total / h;
2338 let (g_idx, b_idx) = if *has_bias { (3, 4) } else { (2, 3) };
2339 Thunk::FusedResidualRmsNorm {
2340 x: node_offset(arena, node.inputs[0]),
2341 res: node_offset(arena, node.inputs[1]),
2342 bias: if *has_bias {
2343 node_offset(arena, node.inputs[2])
2344 } else {
2345 0
2346 },
2347 g: node_offset(arena, node.inputs[g_idx]),
2348 b: node_offset(arena, node.inputs[b_idx]),
2349 out: node_offset(arena, node.id),
2350 rows: rows as u32,
2351 h: h as u32,
2352 eps: *eps,
2353 has_bias: *has_bias,
2354 }
2355 }
2356
2357 Op::MatMul => {
2358 let shape = &node.shape;
2359 let a_shape = &graph.node(node.inputs[0]).shape;
2360 let b_shape = &graph.node(node.inputs[1]).shape;
2361 let eff =
2364 rlx_ir::shape::matmul_shape(a_shape, b_shape).unwrap_or_else(|_| shape.clone());
2365 let rank = eff.rank().max(2);
2366 let n = eff.dim(rank - 1).unwrap_static();
2367 let k_dim = a_shape.dim(a_shape.rank().max(2) - 1).unwrap_static();
2368 let both_batched = a_shape.rank() >= 3 && b_shape.rank() >= 3;
2371 let batched_3d = rank >= 3 && both_batched && a_shape.rank() + b_shape.rank() > 4;
2372 if batched_3d && shape.dtype() == rlx_ir::DType::F64 {
2373 let mut batch_prod = 1usize;
2374 for d in 0..rank - 2 {
2375 batch_prod *= eff.dim(d).unwrap_static();
2376 }
2377 let m_dim = eff.dim(rank - 2).unwrap_static();
2378 Thunk::BatchedDgemmF64 {
2379 a: node_offset(arena, node.inputs[0]),
2380 b: node_offset(arena, node.inputs[1]),
2381 c: node_offset(arena, node.id),
2382 batch: batch_prod as u32,
2383 m: m_dim as u32,
2384 k: k_dim as u32,
2385 n: n as u32,
2386 }
2387 } else if batched_3d && shape.dtype() == rlx_ir::DType::F32 {
2388 let mut batch_prod = 1usize;
2389 for d in 0..rank - 2 {
2390 batch_prod *= eff.dim(d).unwrap_static();
2391 }
2392 let m_dim = eff.dim(rank - 2).unwrap_static();
2393 Thunk::BatchedSgemm {
2394 a: node_offset(arena, node.inputs[0]),
2395 b: node_offset(arena, node.inputs[1]),
2396 c: node_offset(arena, node.id),
2397 batch: batch_prod as u32,
2398 m: m_dim as u32,
2399 k: k_dim as u32,
2400 n: n as u32,
2401 }
2402 } else {
2403 let m = if a_shape.rank() >= 3 && b_shape.rank() <= 2 {
2404 let mut m_prod = 1usize;
2405 for d in 0..a_shape.rank() - 1 {
2406 m_prod *= a_shape.dim(d).unwrap_static();
2407 }
2408 m_prod
2409 } else if a_shape.rank() >= 2 {
2410 a_shape.dim(a_shape.rank() - 2).unwrap_static()
2411 } else {
2412 eff.num_elements().unwrap_or(1) / n.max(1)
2413 };
2414 match shape.dtype() {
2415 rlx_ir::DType::F64 => Thunk::Dgemm {
2416 a: node_offset(arena, node.inputs[0]),
2417 b: node_offset(arena, node.inputs[1]),
2418 c: node_offset(arena, node.id),
2419 m: m as u32,
2420 k: k_dim as u32,
2421 n: n as u32,
2422 },
2423 _ => Thunk::Sgemm {
2424 a: node_offset(arena, node.inputs[0]),
2425 b: node_offset(arena, node.inputs[1]),
2426 c: node_offset(arena, node.id),
2427 m: m as u32,
2428 k: k_dim as u32,
2429 n: n as u32,
2430 },
2431 }
2432 }
2433 }
2434
2435 Op::Binary(op) => {
2436 let lhs_len = get_len(graph, node.inputs[0]);
2437 let rhs_len = get_len(graph, node.inputs[1]);
2438 let out_len = node.shape.num_elements().unwrap();
2439 if node.shape.dtype() == rlx_ir::DType::C64 {
2440 match op {
2444 BinaryOp::Add | BinaryOp::Sub | BinaryOp::Mul | BinaryOp::Div => {}
2445 BinaryOp::Max | BinaryOp::Min | BinaryOp::Pow => panic!(
2446 "Op::Binary({op:?}) on DType::C64: complex \
2447 max/min/pow have no single natural definition \
2448 — caller should drop to 2N-real-block (see \
2449 spike-ac) and pick a convention there"
2450 ),
2451 }
2452 }
2453 let (out_dims_bcast, bcast_lhs_strides, bcast_rhs_strides) =
2457 if lhs_len == out_len && rhs_len == out_len {
2458 (Vec::new(), Vec::new(), Vec::new())
2459 } else {
2460 let lhs_dims = get_static_dims(graph, node.inputs[0]);
2461 let rhs_dims = get_static_dims(graph, node.inputs[1]);
2462 let out_dims_v = get_static_dims(graph, node.id);
2463 if lhs_dims.is_empty() || rhs_dims.is_empty() || out_dims_v.is_empty() {
2464 (Vec::new(), Vec::new(), Vec::new())
2469 } else {
2470 let ls = broadcast_strides(&lhs_dims, &out_dims_v);
2471 let rs = broadcast_strides(&rhs_dims, &out_dims_v);
2472 let od: Vec<u32> = out_dims_v.iter().map(|x| *x as u32).collect();
2473 (od, ls, rs)
2474 }
2475 };
2476 if node.shape.dtype() == rlx_ir::DType::C64 {
2477 Thunk::BinaryFullC64 {
2478 lhs: node_offset(arena, node.inputs[0]),
2479 rhs: node_offset(arena, node.inputs[1]),
2480 dst: node_offset(arena, node.id),
2481 len: out_len as u32,
2482 lhs_len: lhs_len as u32,
2483 rhs_len: rhs_len as u32,
2484 op: *op,
2485 out_dims_bcast,
2486 bcast_lhs_strides,
2487 bcast_rhs_strides,
2488 }
2489 } else if node.shape.dtype() == rlx_ir::DType::F64 {
2490 Thunk::BinaryFullF64 {
2493 lhs: node_offset(arena, node.inputs[0]),
2494 rhs: node_offset(arena, node.inputs[1]),
2495 dst: node_offset(arena, node.id),
2496 len: out_len as u32,
2497 lhs_len: lhs_len as u32,
2498 rhs_len: rhs_len as u32,
2499 op: *op,
2500 out_dims_bcast,
2501 bcast_lhs_strides,
2502 bcast_rhs_strides,
2503 }
2504 } else if matches!(op, BinaryOp::Add)
2505 && rhs_len < out_len
2506 && out_len % rhs_len == 0
2507 && is_trailing_bias_broadcast(
2508 graph.node(node.inputs[1]).shape.dims(),
2509 graph.node(node.id).shape.dims(),
2510 )
2511 {
2512 Thunk::BiasAdd {
2522 src: node_offset(arena, node.inputs[0]),
2523 bias: node_offset(arena, node.inputs[1]),
2524 dst: node_offset(arena, node.id),
2525 m: (out_len / rhs_len) as u32,
2526 n: rhs_len as u32,
2527 }
2528 } else {
2529 let lhs_len = get_len(graph, node.inputs[0]);
2530 Thunk::BinaryFull {
2531 lhs: node_offset(arena, node.inputs[0]),
2532 rhs: node_offset(arena, node.inputs[1]),
2533 dst: node_offset(arena, node.id),
2534 len: out_len as u32,
2535 lhs_len: lhs_len as u32,
2536 rhs_len: rhs_len as u32,
2537 op: *op,
2538 out_dims_bcast,
2539 bcast_lhs_strides,
2540 bcast_rhs_strides,
2541 elem_bytes: node.shape.dtype().size_bytes() as u8,
2542 }
2543 }
2544 }
2545
2546 Op::Activation(act) => {
2547 let len = node.shape.num_elements().unwrap();
2548 let in_off = node_offset(arena, node.inputs[0]);
2549 let out_off = node_offset(arena, node.id);
2550 if node.shape.dtype() == rlx_ir::DType::C64 {
2551 match act {
2556 Activation::Neg | Activation::Exp | Activation::Log | Activation::Sqrt => {}
2557 other => panic!(
2558 "Op::Activation({other:?}) on DType::C64: no \
2559 natural complex extension — supported on C64: \
2560 Neg, Exp, Log, Sqrt"
2561 ),
2562 }
2563 Thunk::ActivationC64 {
2564 src: in_off,
2565 dst: out_off,
2566 len: len as u32,
2567 kind: *act,
2568 }
2569 } else if node.shape.dtype() == rlx_ir::DType::F64 {
2570 Thunk::ActivationF64 {
2571 src: in_off,
2572 dst: out_off,
2573 len: len as u32,
2574 kind: *act,
2575 }
2576 } else if in_off == out_off {
2577 Thunk::ActivationInPlace {
2581 data: out_off,
2582 len: len as u32,
2583 act: *act,
2584 }
2585 } else {
2586 thunks.push(Thunk::Copy {
2590 src: in_off,
2591 dst: out_off,
2592 len: len as u32,
2593 });
2594 Thunk::ActivationInPlace {
2595 data: out_off,
2596 len: len as u32,
2597 act: *act,
2598 }
2599 }
2600 }
2601
2602 Op::Gather { axis } if *axis == 0 => {
2603 let table_shape = &graph.node(node.inputs[0]).shape;
2604 let table_total = table_shape.num_elements().unwrap();
2605 let trailing: usize = (1..table_shape.rank())
2606 .map(|i| table_shape.dim(i).unwrap_static())
2607 .product();
2608 let idx_len = get_len(graph, node.inputs[1]);
2609 let idx_i64 =
2610 u8::from(graph.node(node.inputs[1]).shape.dtype() == rlx_ir::DType::I64);
2611 let table_bytes = graph.node(node.inputs[0]).shape.dtype().size_bytes() as u8;
2612 Thunk::Gather {
2613 table: node_offset(arena, node.inputs[0]),
2614 table_len: table_total as u32,
2615 idx: node_offset(arena, node.inputs[1]),
2616 dst: node_offset(arena, node.id),
2617 num_idx: idx_len as u32,
2618 trailing: trailing as u32,
2619 idx_i64,
2620 table_bytes,
2621 }
2622 }
2623
2624 Op::Gather { axis } => {
2625 let table_shape = &graph.node(node.inputs[0]).shape;
2627 let rank = table_shape.rank();
2628 let outer: usize = (0..*axis)
2629 .map(|i| table_shape.dim(i).unwrap_static())
2630 .product::<usize>()
2631 .max(1);
2632 let trailing: usize = (*axis + 1..rank)
2633 .map(|i| table_shape.dim(i).unwrap_static())
2634 .product::<usize>()
2635 .max(1);
2636 let axis_dim = table_shape.dim(*axis).unwrap_static();
2637 let idx_len = get_len(graph, node.inputs[1]);
2638 let idx_i64 =
2639 u8::from(graph.node(node.inputs[1]).shape.dtype() == rlx_ir::DType::I64);
2640 let table_bytes = graph.node(node.inputs[0]).shape.dtype().size_bytes() as u8;
2641 Thunk::GatherAxis {
2642 table: node_offset(arena, node.inputs[0]),
2643 idx: node_offset(arena, node.inputs[1]),
2644 dst: node_offset(arena, node.id),
2645 outer: outer as u32,
2646 axis_dim: axis_dim as u32,
2647 num_idx: idx_len as u32,
2648 trailing: trailing as u32,
2649 idx_i64,
2650 table_bytes,
2651 }
2652 }
2653
2654 Op::Narrow { axis, start, len } => {
2655 let in_shape = &graph.node(node.inputs[0]).shape;
2656 let elem_bytes = in_shape.dtype().size_bytes() as u8;
2657 let rank = in_shape.rank();
2658 let outer: usize = (0..*axis)
2659 .map(|i| in_shape.dim(i).unwrap_static())
2660 .product::<usize>()
2661 .max(1);
2662 let inner: usize = (*axis + 1..rank)
2663 .map(|i| in_shape.dim(i).unwrap_static())
2664 .product::<usize>()
2665 .max(1);
2666 let in_axis = in_shape.dim(*axis).unwrap_static();
2667 let src_byte_offset =
2668 node_offset(arena, node.inputs[0]) + start * inner * elem_bytes as usize;
2669 Thunk::Narrow {
2670 src: src_byte_offset,
2671 dst: node_offset(arena, node.id),
2672 outer: outer as u32,
2673 src_stride: (in_axis * inner) as u32, dst_stride: (*len * inner) as u32, inner: (*len * inner) as u32, elem_bytes,
2677 }
2678 }
2679
2680 Op::Reshape { .. } | Op::StopGradient => {
2681 let len = node.shape.num_elements().unwrap();
2683 let src = node_offset(arena, node.inputs[0]);
2684 let dst = node_offset(arena, node.id);
2685 match node.shape.dtype() {
2686 rlx_ir::DType::F64 => Thunk::CopyF64 {
2687 src,
2688 dst,
2689 len: len as u32,
2690 },
2691 rlx_ir::DType::I64 => Thunk::CopyI64 {
2692 src,
2693 dst,
2694 len: len as u32,
2695 },
2696 _ => Thunk::Copy {
2697 src,
2698 dst,
2699 len: len as u32,
2700 },
2701 }
2702 }
2703
2704 Op::Cast { to } => {
2705 let in_node = graph.node(node.inputs[0]);
2706 let in_dtype = in_node.shape.dtype();
2707 let out_dtype = *to;
2708 let len = node.shape.num_elements().unwrap();
2709 let src = node_offset(arena, node.inputs[0]);
2710 let dst = node_offset(arena, node.id);
2711 if in_dtype == rlx_ir::DType::F32 && out_dtype == rlx_ir::DType::I64 {
2712 Thunk::CastF32ToI64 {
2713 src,
2714 dst,
2715 len: len as u32,
2716 }
2717 } else if in_dtype == rlx_ir::DType::I64 && out_dtype == rlx_ir::DType::F32 {
2718 Thunk::CastI64ToF32 {
2719 src,
2720 dst,
2721 len: len as u32,
2722 }
2723 } else if in_dtype == rlx_ir::DType::Bool && out_dtype == rlx_ir::DType::I32 {
2724 Thunk::CastBoolToI32 {
2725 src,
2726 dst,
2727 len: len as u32,
2728 }
2729 } else if in_dtype == rlx_ir::DType::I32 && out_dtype == rlx_ir::DType::F32 {
2730 Thunk::CastI32ToF32 {
2731 src,
2732 dst,
2733 len: len as u32,
2734 }
2735 } else if in_dtype == out_dtype {
2736 match out_dtype {
2737 rlx_ir::DType::F64 => Thunk::CopyF64 {
2738 src,
2739 dst,
2740 len: len as u32,
2741 },
2742 rlx_ir::DType::I64 => Thunk::CopyI64 {
2743 src,
2744 dst,
2745 len: len as u32,
2746 },
2747 _ => Thunk::Copy {
2748 src,
2749 dst,
2750 len: len as u32,
2751 },
2752 }
2753 } else {
2754 Thunk::Copy {
2755 src,
2756 dst,
2757 len: len as u32,
2758 }
2759 }
2760 }
2761
2762 Op::Quantize {
2763 axis,
2764 scales,
2765 zero_points,
2766 } => {
2767 let (chan_axis, chan_dim, inner) = quant_layout(&node.shape, *axis);
2768 Thunk::Quantize {
2769 x: node_offset(arena, node.inputs[0]),
2770 q: node_offset(arena, node.id),
2771 len: node.shape.num_elements().unwrap() as u32,
2772 chan_axis: chan_axis as u32,
2773 chan_dim: chan_dim as u32,
2774 inner: inner as u32,
2775 scales: scales.clone(),
2776 zero_points: zero_points.clone(),
2777 }
2778 }
2779
2780 Op::FakeQuantize {
2781 bits,
2782 axis,
2783 ste,
2784 scale_mode,
2785 } => {
2786 let (chan_axis, chan_dim, inner) = quant_layout(&node.shape, *axis);
2787 let state_off = match scale_mode {
2788 rlx_ir::op::ScaleMode::PerBatch => None,
2789 rlx_ir::op::ScaleMode::EMA { .. } | rlx_ir::op::ScaleMode::Fixed => {
2790 debug_assert_eq!(
2792 node.inputs.len(),
2793 2,
2794 "EMA/Fixed FakeQuantize needs a state input"
2795 );
2796 Some(node_offset(arena, node.inputs[1]))
2797 }
2798 };
2799 Thunk::FakeQuantize {
2800 x: node_offset(arena, node.inputs[0]),
2801 out: node_offset(arena, node.id),
2802 len: node.shape.num_elements().unwrap() as u32,
2803 chan_axis: chan_axis as u32,
2804 chan_dim: chan_dim as u32,
2805 inner: inner as u32,
2806 bits: *bits,
2807 ste: *ste,
2808 scale_mode: *scale_mode,
2809 state_off,
2810 }
2811 }
2812
2813 Op::FakeQuantizeLSQ { bits, axis } => {
2814 let (chan_axis, chan_dim, inner) = quant_layout(&node.shape, *axis);
2815 Thunk::FakeQuantizeLSQ {
2816 x: node_offset(arena, node.inputs[0]),
2817 scale_off: node_offset(arena, node.inputs[1]),
2818 out: node_offset(arena, node.id),
2819 len: node.shape.num_elements().unwrap() as u32,
2820 chan_axis: chan_axis as u32,
2821 chan_dim: chan_dim as u32,
2822 inner: inner as u32,
2823 bits: *bits,
2824 }
2825 }
2826
2827 Op::FakeQuantizeLSQBackwardX { bits, axis } => {
2828 let (chan_axis, chan_dim, inner) = quant_layout(&node.shape, *axis);
2829 Thunk::FakeQuantizeLSQBackwardX {
2830 x: node_offset(arena, node.inputs[0]),
2831 scale_off: node_offset(arena, node.inputs[1]),
2832 dy: node_offset(arena, node.inputs[2]),
2833 dx: node_offset(arena, node.id),
2834 len: node.shape.num_elements().unwrap() as u32,
2835 chan_axis: chan_axis as u32,
2836 chan_dim: chan_dim as u32,
2837 inner: inner as u32,
2838 bits: *bits,
2839 }
2840 }
2841
2842 Op::FakeQuantizeLSQBackwardScale { bits, axis } => {
2843 let in_shape = &graph.node(node.inputs[0]).shape;
2846 let (chan_axis, chan_dim, inner) = quant_layout(in_shape, *axis);
2847 Thunk::FakeQuantizeLSQBackwardScale {
2848 x: node_offset(arena, node.inputs[0]),
2849 scale_off: node_offset(arena, node.inputs[1]),
2850 dy: node_offset(arena, node.inputs[2]),
2851 dscale: node_offset(arena, node.id),
2852 len: in_shape.num_elements().unwrap() as u32,
2853 chan_axis: chan_axis as u32,
2854 chan_dim: chan_dim as u32,
2855 inner: inner as u32,
2856 bits: *bits,
2857 }
2858 }
2859
2860 Op::FakeQuantizeBackward { bits, axis, ste } => {
2861 let (chan_axis, chan_dim, inner) = quant_layout(&node.shape, *axis);
2862 Thunk::FakeQuantizeBackward {
2863 x: node_offset(arena, node.inputs[0]),
2864 dy: node_offset(arena, node.inputs[1]),
2865 dx: node_offset(arena, node.id),
2866 len: node.shape.num_elements().unwrap() as u32,
2867 chan_axis: chan_axis as u32,
2868 chan_dim: chan_dim as u32,
2869 inner: inner as u32,
2870 bits: *bits,
2871 ste: *ste,
2872 }
2873 }
2874
2875 Op::Dequantize {
2876 axis,
2877 scales,
2878 zero_points,
2879 } => {
2880 let (chan_axis, chan_dim, inner) = quant_layout(&node.shape, *axis);
2881 Thunk::Dequantize {
2882 q: node_offset(arena, node.inputs[0]),
2883 x: node_offset(arena, node.id),
2884 len: node.shape.num_elements().unwrap() as u32,
2885 chan_axis: chan_axis as u32,
2886 chan_dim: chan_dim as u32,
2887 inner: inner as u32,
2888 scales: scales.clone(),
2889 zero_points: zero_points.clone(),
2890 }
2891 }
2892
2893 Op::Expand { .. } => {
2894 let in_shape = &graph.node(node.inputs[0]).shape;
2899 let out_shape = &node.shape;
2900 let in_rank = in_shape.rank();
2901 let out_rank = out_shape.rank();
2902 let pad = out_rank.saturating_sub(in_rank);
2904 let in_dims: Vec<usize> = (0..out_rank)
2905 .map(|i| {
2906 if i < pad {
2907 1
2908 } else {
2909 in_shape.dim(i - pad).unwrap_static()
2910 }
2911 })
2912 .collect();
2913 let mut in_strides_full = vec![1usize; out_rank];
2915 for d in (0..out_rank.saturating_sub(1)).rev() {
2916 in_strides_full[d] = in_strides_full[d + 1] * in_dims[d + 1];
2917 }
2918 let out_dims: Vec<u32> = (0..out_rank)
2919 .map(|i| out_shape.dim(i).unwrap_static() as u32)
2920 .collect();
2921 let in_strides: Vec<u32> = (0..out_rank)
2923 .map(|i| {
2924 if in_dims[i] == 1 && (out_dims[i] as usize) > 1 {
2925 0
2926 } else {
2927 in_strides_full[i] as u32
2928 }
2929 })
2930 .collect();
2931 let in_total = in_dims.iter().product::<usize>() as u32;
2932 let src = node_offset(arena, node.inputs[0]);
2933 let dst = node_offset(arena, node.id);
2934 let elem_bytes = node.shape.dtype().size_bytes() as u8;
2935 match node.shape.dtype() {
2936 rlx_ir::DType::F64 => Thunk::TransposeF64 {
2937 src,
2938 dst,
2939 in_total,
2940 out_dims,
2941 in_strides,
2942 },
2943 _ => Thunk::Transpose {
2944 src,
2945 dst,
2946 in_total,
2947 out_dims,
2948 in_strides,
2949 elem_bytes,
2950 },
2951 }
2952 }
2953
2954 Op::RmsNorm { eps, .. } => {
2955 let h = node.shape.dim(node.shape.rank() - 1).unwrap_static();
2956 let total = node.shape.num_elements().unwrap();
2957 Thunk::RmsNorm {
2958 src: node_offset(arena, node.inputs[0]),
2959 g: node_offset(arena, node.inputs[1]),
2960 b: node_offset(arena, node.inputs[2]),
2961 dst: node_offset(arena, node.id),
2962 rows: (total / h) as u32,
2963 h: h as u32,
2964 eps: *eps,
2965 }
2966 }
2967
2968 Op::LayerNorm { eps, .. } => {
2969 let h = node.shape.dim(node.shape.rank() - 1).unwrap_static();
2970 let total = node.shape.num_elements().unwrap();
2971 Thunk::LayerNorm {
2972 src: node_offset(arena, node.inputs[0]),
2973 g: node_offset(arena, node.inputs[1]),
2974 b: node_offset(arena, node.inputs[2]),
2975 dst: node_offset(arena, node.id),
2976 rows: (total / h) as u32,
2977 h: h as u32,
2978 eps: *eps,
2979 }
2980 }
2981
2982 Op::GroupNorm { num_groups, eps } => {
2983 let in_shape = &graph.node(node.inputs[0]).shape;
2984 let (n, c, h, w) = conv_nchw_dims(in_shape);
2985 Thunk::GroupNorm {
2986 src: node_offset(arena, node.inputs[0]),
2987 g: node_offset(arena, node.inputs[1]),
2988 b: node_offset(arena, node.inputs[2]),
2989 dst: node_offset(arena, node.id),
2990 n,
2991 c,
2992 h,
2993 w,
2994 num_groups: *num_groups as u32,
2995 eps: *eps,
2996 }
2997 }
2998
2999 Op::BatchNormInference { eps } => {
3000 let in_shape = &graph.node(node.inputs[0]).shape;
3001 let rank = in_shape.rank();
3002 let channels = in_shape.dim(rank - 1).unwrap_static();
3003 let total = in_shape.num_elements().unwrap_or(0);
3004 let count = (total / channels.max(1)) as u32;
3005 Thunk::BatchNormInference {
3006 src: node_offset(arena, node.inputs[0]),
3007 g: node_offset(arena, node.inputs[1]),
3008 b: node_offset(arena, node.inputs[2]),
3009 mean: node_offset(arena, node.inputs[3]),
3010 var: node_offset(arena, node.inputs[4]),
3011 dst: node_offset(arena, node.id),
3012 count,
3013 channels: channels as u32,
3014 eps: *eps,
3015 }
3016 }
3017
3018 Op::BatchNormInferenceBackwardInput { eps } => {
3019 let x_shape = &graph.node(node.inputs[0]).shape;
3020 let rank = x_shape.rank();
3021 let channels = x_shape.dim(rank - 1).unwrap_static();
3022 let total = x_shape.num_elements().unwrap_or(0);
3023 Thunk::BatchNormInferenceBackwardInput {
3024 x: node_offset(arena, node.inputs[0]),
3025 gamma: node_offset(arena, node.inputs[1]),
3026 mean: node_offset(arena, node.inputs[2]),
3027 var: node_offset(arena, node.inputs[3]),
3028 dy: node_offset(arena, node.inputs[4]),
3029 dx: node_offset(arena, node.id),
3030 count: (total / channels.max(1)) as u32,
3031 channels: channels as u32,
3032 eps: *eps,
3033 }
3034 }
3035
3036 Op::BatchNormInferenceBackwardGamma { eps } => {
3037 let x_shape = &graph.node(node.inputs[0]).shape;
3038 let rank = x_shape.rank();
3039 let channels = x_shape.dim(rank - 1).unwrap_static();
3040 let total = x_shape.num_elements().unwrap_or(0);
3041 let _gamma_shape = &graph.node(node.id).shape;
3042 Thunk::BatchNormInferenceBackwardGamma {
3043 x: node_offset(arena, node.inputs[0]),
3044 mean: node_offset(arena, node.inputs[1]),
3045 var: node_offset(arena, node.inputs[2]),
3046 dy: node_offset(arena, node.inputs[3]),
3047 dgamma: node_offset(arena, node.id),
3048 count: (total / channels.max(1)) as u32,
3049 channels: channels as u32,
3050 eps: *eps,
3051 }
3052 }
3053
3054 Op::BatchNormInferenceBackwardBeta => {
3055 let dy_shape = &graph.node(node.inputs[0]).shape;
3056 let rank = dy_shape.rank();
3057 let channels = dy_shape.dim(rank - 1).unwrap_static();
3058 let total = dy_shape.num_elements().unwrap_or(0);
3059 Thunk::BatchNormInferenceBackwardBeta {
3060 dy: node_offset(arena, node.inputs[0]),
3061 dbeta: node_offset(arena, node.id),
3062 count: (total / channels.max(1)) as u32,
3063 channels: channels as u32,
3064 }
3065 }
3066
3067 Op::LayerNorm2d { eps } => {
3068 let in_shape = &graph.node(node.inputs[0]).shape;
3069 let (n, c, h, w) = conv_nchw_dims(in_shape);
3070 Thunk::LayerNorm2d {
3071 src: node_offset(arena, node.inputs[0]),
3072 g: node_offset(arena, node.inputs[1]),
3073 b: node_offset(arena, node.inputs[2]),
3074 dst: node_offset(arena, node.id),
3075 n,
3076 c,
3077 h,
3078 w,
3079 eps: *eps,
3080 }
3081 }
3082
3083 Op::ConvTranspose2d {
3084 kernel_size,
3085 stride,
3086 padding,
3087 dilation,
3088 output_padding: _,
3089 groups,
3090 } => {
3091 let in_shape = &graph.node(node.inputs[0]).shape;
3092 let out_shape = &node.shape;
3093 let (n, c_in, h, w_in) = conv_nchw_dims(in_shape);
3094 let (_, c_out, h_out, w_out) = conv_nchw_dims(out_shape);
3095 Thunk::ConvTranspose2d {
3096 src: node_offset(arena, node.inputs[0]),
3097 weight: node_offset(arena, node.inputs[1]),
3098 dst: node_offset(arena, node.id),
3099 n,
3100 c_in,
3101 h,
3102 w_in,
3103 c_out,
3104 h_out,
3105 w_out,
3106 kh: kernel_size[0] as u32,
3107 kw: kernel_size[1] as u32,
3108 sh: stride.first().copied().unwrap_or(1) as u32,
3109 sw: stride.get(1).copied().unwrap_or(1) as u32,
3110 ph: padding.first().copied().unwrap_or(0) as u32,
3111 pw: padding.get(1).copied().unwrap_or(0) as u32,
3112 dh: dilation.first().copied().unwrap_or(1) as u32,
3113 dw: dilation.get(1).copied().unwrap_or(1) as u32,
3114 groups: *groups as u32,
3115 }
3116 }
3117
3118 Op::ResizeNearest2x => {
3119 let in_shape = &graph.node(node.inputs[0]).shape;
3120 let (n, c, h, w) = conv_nchw_dims(in_shape);
3121 Thunk::ResizeNearest2x {
3122 src: node_offset(arena, node.inputs[0]),
3123 dst: node_offset(arena, node.id),
3124 n,
3125 c,
3126 h,
3127 w,
3128 }
3129 }
3130
3131 Op::AxialRope2d {
3132 end_x,
3133 end_y,
3134 head_dim,
3135 num_heads,
3136 theta,
3137 repeat_factor,
3138 } => {
3139 let in_shape = &graph.node(node.inputs[0]).shape;
3140 let batch = in_shape.dim(0).unwrap_static() as u32;
3141 let seq = in_shape.dim(1).unwrap_static() as u32;
3142 let hidden = in_shape.dim(2).unwrap_static() as u32;
3143 Thunk::AxialRope2d {
3144 src: node_offset(arena, node.inputs[0]),
3145 dst: node_offset(arena, node.id),
3146 batch,
3147 seq,
3148 hidden,
3149 end_x: *end_x as u32,
3150 end_y: *end_y as u32,
3151 head_dim: *head_dim as u32,
3152 num_heads: *num_heads as u32,
3153 theta: *theta,
3154 repeat_factor: *repeat_factor as u32,
3155 }
3156 }
3157
3158 Op::Softmax { axis } => {
3159 let rank = node.shape.rank();
3160 let ax = if *axis < 0 {
3161 (rank as i32 + axis) as usize
3162 } else {
3163 *axis as usize
3164 };
3165 let cols = node.shape.dim(ax).unwrap_static();
3166 let total = node.shape.num_elements().unwrap();
3167 let in_off = node_offset(arena, node.inputs[0]);
3168 let out_off = node_offset(arena, node.id);
3169 if in_off != out_off {
3175 thunks.push(Thunk::Copy {
3176 src: in_off,
3177 dst: out_off,
3178 len: total as u32,
3179 });
3180 }
3181 Thunk::Softmax {
3182 data: out_off,
3183 rows: (total / cols) as u32,
3184 cols: cols as u32,
3185 }
3186 }
3187
3188 Op::SelectiveScan { state_size } => {
3189 let in_shape = &graph.node(node.inputs[0]).shape;
3190 let (batch, seq, hidden) = (
3191 in_shape.dim(0).unwrap_static(),
3192 in_shape.dim(1).unwrap_static(),
3193 in_shape.dim(2).unwrap_static(),
3194 );
3195 Thunk::SelectiveScan {
3196 x: node_offset(arena, node.inputs[0]),
3197 delta: node_offset(arena, node.inputs[1]),
3198 a: node_offset(arena, node.inputs[2]),
3199 b: node_offset(arena, node.inputs[3]),
3200 c: node_offset(arena, node.inputs[4]),
3201 dst: node_offset(arena, node.id),
3202 batch: batch as u32,
3203 seq: seq as u32,
3204 hidden: hidden as u32,
3205 state_size: *state_size as u32,
3206 }
3207 }
3208
3209 Op::GatedDeltaNet {
3210 state_size,
3211 carry_state,
3212 } => {
3213 let q_shape = &graph.node(node.inputs[0]).shape;
3214 let (batch, seq, heads) = (
3215 q_shape.dim(0).unwrap_static(),
3216 q_shape.dim(1).unwrap_static(),
3217 q_shape.dim(2).unwrap_static(),
3218 );
3219 let state_off = if *carry_state {
3220 node_offset(arena, node.inputs[5])
3221 } else {
3222 0
3223 };
3224 Thunk::GatedDeltaNet {
3225 q: node_offset(arena, node.inputs[0]),
3226 k: node_offset(arena, node.inputs[1]),
3227 v: node_offset(arena, node.inputs[2]),
3228 g: node_offset(arena, node.inputs[3]),
3229 beta: node_offset(arena, node.inputs[4]),
3230 state: state_off,
3231 dst: node_offset(arena, node.id),
3232 batch: batch as u32,
3233 seq: seq as u32,
3234 heads: heads as u32,
3235 state_size: *state_size as u32,
3236 }
3237 }
3238
3239 Op::QMatMul {
3240 x_zp,
3241 w_zp,
3242 out_zp,
3243 mult,
3244 } => {
3245 let x_shape = &graph.node(node.inputs[0]).shape;
3246 let w_shape = &graph.node(node.inputs[1]).shape;
3247 let m = x_shape.dim(0).unwrap_static();
3248 let k = x_shape.dim(1).unwrap_static();
3249 let n = w_shape.dim(1).unwrap_static();
3250 Thunk::QMatMul {
3251 x: node_offset(arena, node.inputs[0]),
3252 w: node_offset(arena, node.inputs[1]),
3253 bias: node_offset(arena, node.inputs[2]),
3254 out: node_offset(arena, node.id),
3255 m: m as u32,
3256 k: k as u32,
3257 n: n as u32,
3258 x_zp: *x_zp,
3259 w_zp: *w_zp,
3260 out_zp: *out_zp,
3261 mult: *mult,
3262 }
3263 }
3264
3265 Op::QConv2d {
3266 kernel_size,
3267 stride,
3268 padding,
3269 dilation,
3270 groups,
3271 x_zp,
3272 w_zp,
3273 out_zp,
3274 mult,
3275 } => {
3276 let in_shape = &graph.node(node.inputs[0]).shape;
3277 let w_shape = &graph.node(node.inputs[1]).shape;
3278 let out_shape = &node.shape;
3279 if kernel_size.len() == 2
3280 && in_shape.rank() == 4
3281 && w_shape.rank() == 4
3282 && out_shape.rank() == 4
3283 {
3284 Thunk::QConv2d {
3285 x: node_offset(arena, node.inputs[0]),
3286 w: node_offset(arena, node.inputs[1]),
3287 bias: node_offset(arena, node.inputs[2]),
3288 out: node_offset(arena, node.id),
3289 n: in_shape.dim(0).unwrap_static() as u32,
3290 c_in: in_shape.dim(1).unwrap_static() as u32,
3291 h: in_shape.dim(2).unwrap_static() as u32,
3292 w_in: in_shape.dim(3).unwrap_static() as u32,
3293 c_out: out_shape.dim(1).unwrap_static() as u32,
3294 h_out: out_shape.dim(2).unwrap_static() as u32,
3295 w_out: out_shape.dim(3).unwrap_static() as u32,
3296 kh: kernel_size[0] as u32,
3297 kw: kernel_size[1] as u32,
3298 sh: stride.first().copied().unwrap_or(1) as u32,
3299 sw: stride.get(1).copied().unwrap_or(1) as u32,
3300 ph: padding.first().copied().unwrap_or(0) as u32,
3301 pw: padding.get(1).copied().unwrap_or(0) as u32,
3302 dh: dilation.first().copied().unwrap_or(1) as u32,
3303 dw: dilation.get(1).copied().unwrap_or(1) as u32,
3304 groups: *groups as u32,
3305 x_zp: *x_zp,
3306 w_zp: *w_zp,
3307 out_zp: *out_zp,
3308 mult: *mult,
3309 }
3310 } else {
3311 Thunk::Nop
3312 }
3313 }
3314
3315 Op::DequantMatMul { scheme } => {
3316 use rlx_ir::quant::QuantScheme;
3317 let n = node.shape.dim(node.shape.rank() - 1).unwrap_static();
3318 let total = node.shape.num_elements().unwrap();
3319 let m = total / n.max(1);
3320 let x_total = graph.node(node.inputs[0]).shape.num_elements().unwrap();
3321 let k = x_total / m.max(1);
3322 if scheme.is_gguf() {
3323 Thunk::DequantMatMulGguf {
3324 x: node_offset(arena, node.inputs[0]),
3325 w_q: node_offset(arena, node.inputs[1]),
3326 dst: node_offset(arena, node.id),
3327 m: m as u32,
3328 k: k as u32,
3329 n: n as u32,
3330 scheme: *scheme,
3331 }
3332 } else {
3333 match scheme {
3334 QuantScheme::Nvfp4Block => Thunk::DequantMatMulNvfp4 {
3335 x: node_offset(arena, node.inputs[0]),
3336 w_q: node_offset(arena, node.inputs[1]),
3337 scale: node_offset(arena, node.inputs[2]),
3338 global_scale: node_offset(arena, node.inputs[3]),
3339 dst: node_offset(arena, node.id),
3340 m: m as u32,
3341 k: k as u32,
3342 n: n as u32,
3343 },
3344 QuantScheme::Int4Block { block_size } => Thunk::DequantMatMulInt4 {
3345 x: node_offset(arena, node.inputs[0]),
3346 w_q: node_offset(arena, node.inputs[1]),
3347 scale: node_offset(arena, node.inputs[2]),
3348 zp: node_offset(arena, node.inputs[3]),
3349 dst: node_offset(arena, node.id),
3350 m: m as u32,
3351 k: k as u32,
3352 n: n as u32,
3353 block_size: *block_size,
3354 is_asymmetric: false,
3355 },
3356 QuantScheme::Fp8E4m3 => Thunk::DequantMatMulFp8 {
3357 x: node_offset(arena, node.inputs[0]),
3358 w_q: node_offset(arena, node.inputs[1]),
3359 scale: node_offset(arena, node.inputs[2]),
3360 dst: node_offset(arena, node.id),
3361 m: m as u32,
3362 k: k as u32,
3363 n: n as u32,
3364 e5m2: false,
3365 },
3366 QuantScheme::Fp8E5m2 => Thunk::DequantMatMulFp8 {
3367 x: node_offset(arena, node.inputs[0]),
3368 w_q: node_offset(arena, node.inputs[1]),
3369 scale: node_offset(arena, node.inputs[2]),
3370 dst: node_offset(arena, node.id),
3371 m: m as u32,
3372 k: k as u32,
3373 n: n as u32,
3374 e5m2: true,
3375 },
3376 QuantScheme::Int8Block { block_size } => Thunk::DequantMatMul {
3377 x: node_offset(arena, node.inputs[0]),
3378 w_q: node_offset(arena, node.inputs[1]),
3379 scale: node_offset(arena, node.inputs[2]),
3380 zp: node_offset(arena, node.inputs[3]),
3381 dst: node_offset(arena, node.id),
3382 m: m as u32,
3383 k: k as u32,
3384 n: n as u32,
3385 block_size: *block_size,
3386 is_asymmetric: false,
3387 },
3388 QuantScheme::Int8BlockAsym { block_size } => Thunk::DequantMatMul {
3389 x: node_offset(arena, node.inputs[0]),
3390 w_q: node_offset(arena, node.inputs[1]),
3391 scale: node_offset(arena, node.inputs[2]),
3392 zp: node_offset(arena, node.inputs[3]),
3393 dst: node_offset(arena, node.id),
3394 m: m as u32,
3395 k: k as u32,
3396 n: n as u32,
3397 block_size: *block_size,
3398 is_asymmetric: true,
3399 },
3400 other => panic!(
3401 "DequantMatMul on CPU supports Int8/Int4/FP8/NVFP4 legacy or GGUF schemes; got {other}"
3402 ),
3403 }
3404 }
3405 }
3406
3407 Op::LoraMatMul { scale } => {
3408 let n = node.shape.dim(node.shape.rank() - 1).unwrap_static();
3410 let total = node.shape.num_elements().unwrap();
3411 let m = total / n.max(1);
3412 let x_total = graph.node(node.inputs[0]).shape.num_elements().unwrap();
3413 let k = x_total / m.max(1);
3414 let a_total = graph.node(node.inputs[2]).shape.num_elements().unwrap();
3415 let r = a_total / k.max(1);
3416 Thunk::LoraMatMul {
3417 x: node_offset(arena, node.inputs[0]),
3418 w: node_offset(arena, node.inputs[1]),
3419 a: node_offset(arena, node.inputs[2]),
3420 b: node_offset(arena, node.inputs[3]),
3421 dst: node_offset(arena, node.id),
3422 m: m as u32,
3423 k: k as u32,
3424 n: n as u32,
3425 r: r as u32,
3426 scale: *scale,
3427 }
3428 }
3429
3430 Op::Sample {
3431 top_k,
3432 top_p,
3433 temperature,
3434 seed,
3435 } => {
3436 let in_shape = &graph.node(node.inputs[0]).shape;
3437 let (batch, vocab) = if in_shape.rank() >= 2 {
3439 (
3440 in_shape.dim(0).unwrap_static(),
3441 in_shape.dim(in_shape.rank() - 1).unwrap_static(),
3442 )
3443 } else {
3444 (1, in_shape.num_elements().unwrap_or(0))
3445 };
3446 Thunk::Sample {
3447 logits: node_offset(arena, node.inputs[0]),
3448 dst: node_offset(arena, node.id),
3449 batch: batch as u32,
3450 vocab: vocab as u32,
3451 top_k: *top_k as u32,
3452 top_p: *top_p,
3453 temperature: *temperature,
3454 seed: *seed,
3455 }
3456 }
3457
3458 Op::Cumsum { axis, exclusive } => {
3459 let rank = node.shape.rank();
3464 let ax = if *axis < 0 {
3465 (rank as i32 + axis) as usize
3466 } else {
3467 *axis as usize
3468 };
3469 assert_eq!(
3470 ax,
3471 rank - 1,
3472 "Cumsum only supports the last axis on CPU today"
3473 );
3474 let cols = node.shape.dim(ax).unwrap_static();
3475 let total = node.shape.num_elements().unwrap();
3476 Thunk::Cumsum {
3477 src: node_offset(arena, node.inputs[0]),
3478 dst: node_offset(arena, node.id),
3479 rows: (total / cols) as u32,
3480 cols: cols as u32,
3481 exclusive: *exclusive,
3482 }
3483 }
3484
3485 Op::Attention {
3486 num_heads,
3487 head_dim,
3488 mask_kind,
3489 score_scale: _,
3490 attn_logit_softcap: _,
3491 } => {
3492 let q_shape = &graph.node(node.inputs[0]).shape;
3498 let k_shape = &graph.node(node.inputs[1]).shape;
3499 let rank = q_shape.rank();
3500 let (batch, seq, kv_seq, bhsd) = if rank == 4 {
3501 let d1 = q_shape.dim(1).unwrap_static();
3502 let d2 = q_shape.dim(2).unwrap_static();
3503 if d1 == *num_heads {
3504 (
3506 q_shape.dim(0).unwrap_static(),
3507 d2,
3508 k_shape.dim(2).unwrap_static(),
3509 true,
3510 )
3511 } else {
3512 (
3514 q_shape.dim(0).unwrap_static(),
3515 d1,
3516 k_shape.dim(1).unwrap_static(),
3517 false,
3518 )
3519 }
3520 } else if rank >= 3 {
3521 (
3522 q_shape.dim(0).unwrap_static(),
3523 q_shape.dim(1).unwrap_static(),
3524 k_shape.dim(1).unwrap_static(),
3525 false,
3526 )
3527 } else {
3528 (
3529 1,
3530 q_shape.dim(0).unwrap_static(),
3531 k_shape.dim(0).unwrap_static(),
3532 false,
3533 )
3534 };
3535 let mask_off = if matches!(
3536 mask_kind,
3537 rlx_ir::op::MaskKind::Custom | rlx_ir::op::MaskKind::Bias
3538 ) {
3539 node_offset(arena, node.inputs[3])
3540 } else {
3541 0
3542 };
3543 let hs = (*num_heads * *head_dim) as u32;
3544 Thunk::Attention {
3545 q: node_offset(arena, node.inputs[0]),
3546 k: node_offset(arena, node.inputs[1]),
3547 v: node_offset(arena, node.inputs[2]),
3548 mask: mask_off,
3549 out: node_offset(arena, node.id),
3550 batch: batch as u32,
3551 seq: seq as u32,
3552 kv_seq: kv_seq as u32,
3553 heads: *num_heads as u32,
3554 head_dim: *head_dim as u32,
3555 mask_kind: *mask_kind,
3556 q_row_stride: hs,
3560 k_row_stride: hs,
3561 v_row_stride: hs,
3562 bhsd,
3563 }
3564 }
3565
3566 Op::AttentionBackward {
3567 num_heads,
3568 head_dim,
3569 mask_kind,
3570 wrt,
3571 } => {
3572 let q_shape = &graph.node(node.inputs[0]).shape;
3573 let k_shape = &graph.node(node.inputs[1]).shape;
3574 let rank = q_shape.rank();
3575 let (batch, seq, kv_seq, bhsd) = if rank == 4 {
3576 let d1 = q_shape.dim(1).unwrap_static();
3577 let d2 = q_shape.dim(2).unwrap_static();
3578 if d1 == *num_heads {
3579 (
3580 q_shape.dim(0).unwrap_static(),
3581 d2,
3582 k_shape.dim(2).unwrap_static(),
3583 true,
3584 )
3585 } else {
3586 (
3587 q_shape.dim(0).unwrap_static(),
3588 d1,
3589 k_shape.dim(1).unwrap_static(),
3590 false,
3591 )
3592 }
3593 } else if rank >= 3 {
3594 (
3595 q_shape.dim(0).unwrap_static(),
3596 q_shape.dim(1).unwrap_static(),
3597 k_shape.dim(1).unwrap_static(),
3598 false,
3599 )
3600 } else {
3601 (
3602 1,
3603 q_shape.dim(0).unwrap_static(),
3604 k_shape.dim(0).unwrap_static(),
3605 false,
3606 )
3607 };
3608 let mask_off = if matches!(
3609 mask_kind,
3610 rlx_ir::op::MaskKind::Custom | rlx_ir::op::MaskKind::Bias
3611 ) {
3612 node_offset(arena, node.inputs[4])
3613 } else {
3614 0
3615 };
3616 Thunk::AttentionBackward {
3617 q: node_offset(arena, node.inputs[0]),
3618 k: node_offset(arena, node.inputs[1]),
3619 v: node_offset(arena, node.inputs[2]),
3620 dy: node_offset(arena, node.inputs[3]),
3621 mask: mask_off,
3622 out: node_offset(arena, node.id),
3623 batch: batch as u32,
3624 seq: seq as u32,
3625 kv_seq: kv_seq as u32,
3626 heads: *num_heads as u32,
3627 head_dim: *head_dim as u32,
3628 mask_kind: *mask_kind,
3629 wrt: *wrt,
3630 bhsd,
3631 }
3632 }
3633
3634 Op::FusedAttentionBlock {
3635 num_heads,
3636 head_dim,
3637 has_bias,
3638 has_rope,
3639 } => {
3640 let x_shape = &graph.node(node.inputs[0]).shape;
3641 let (batch, seq) = if x_shape.rank() >= 3 {
3642 (
3643 x_shape.dim(0).unwrap_static(),
3644 x_shape.dim(1).unwrap_static(),
3645 )
3646 } else {
3647 let total = x_shape.num_elements().unwrap();
3648 let s = x_shape.dim(x_shape.rank() - 2).unwrap_static();
3649 (total / (s * num_heads * head_dim), s)
3650 };
3651 let hs = (*num_heads * *head_dim) as u32;
3652 let mut idx = 4;
3654 let (qkv_b_off, out_b_off) = if *has_bias {
3655 let qb = node_offset(arena, node.inputs[idx]);
3656 let ob = node_offset(arena, node.inputs[idx + 1]);
3657 idx += 2;
3658 (qb, ob)
3659 } else {
3660 (0, 0)
3661 };
3662 let (cos_off, sin_off, cl) = if *has_rope {
3663 let c = node_offset(arena, node.inputs[idx]);
3664 let s = node_offset(arena, node.inputs[idx + 1]);
3665 let clen = get_len(graph, node.inputs[idx]);
3666 (c, s, clen as u32)
3667 } else {
3668 (0, 0, 0)
3669 };
3670
3671 Thunk::FusedAttnBlock {
3672 hidden: node_offset(arena, node.inputs[0]),
3673 qkv_w: node_offset(arena, node.inputs[1]),
3674 out_w: node_offset(arena, node.inputs[2]),
3675 mask: node_offset(arena, node.inputs[3]),
3676 out: node_offset(arena, node.id),
3677 qkv_b: qkv_b_off,
3678 out_b: out_b_off,
3679 cos: cos_off,
3680 sin: sin_off,
3681 cos_len: cl,
3682 batch: batch as u32,
3683 seq: seq as u32,
3684 hs,
3685 nh: *num_heads as u32,
3686 dh: *head_dim as u32,
3687 has_bias: *has_bias,
3688 has_rope: *has_rope,
3689 }
3690 }
3691
3692 Op::Rope { head_dim, n_rot } => {
3693 let x_shape = &graph.node(node.inputs[0]).shape;
3694 let (batch, seq, hidden) = if x_shape.rank() >= 3 {
3695 (
3696 x_shape.dim(0).unwrap_static(),
3697 x_shape.dim(1).unwrap_static(),
3698 x_shape.dim(2).unwrap_static(),
3699 )
3700 } else {
3701 let total = x_shape.num_elements().unwrap();
3702 (
3703 1,
3704 x_shape.dim(0).unwrap_static(),
3705 total / x_shape.dim(0).unwrap_static(),
3706 )
3707 };
3708 let cos_len = get_len(graph, node.inputs[1]);
3709 Thunk::Rope {
3710 src: node_offset(arena, node.inputs[0]),
3711 cos: node_offset(arena, node.inputs[1]),
3712 sin: node_offset(arena, node.inputs[2]),
3713 dst: node_offset(arena, node.id),
3714 batch: batch as u32,
3715 seq: seq as u32,
3716 hidden: hidden as u32,
3717 head_dim: *head_dim as u32,
3718 n_rot: *n_rot as u32,
3719 cos_len: cos_len as u32,
3720 src_row_stride: hidden as u32,
3724 }
3725 }
3726
3727 Op::FusedSwiGLU {
3728 cast_to: _,
3729 gate_first,
3730 } => {
3731 let n_half = node.shape.dim(node.shape.rank() - 1).unwrap_static();
3732 let total = node.shape.num_elements().unwrap();
3733 Thunk::FusedSwiGLU {
3734 src: node_offset(arena, node.inputs[0]),
3735 dst: node_offset(arena, node.id),
3736 n_half: n_half as u32,
3737 total: total as u32,
3738 gate_first: *gate_first,
3739 }
3740 }
3741
3742 Op::Conv {
3743 kernel_size,
3744 stride,
3745 padding,
3746 dilation,
3747 groups,
3748 } => {
3749 let in_shape = &graph.node(node.inputs[0]).shape;
3750 let w_shape = &graph.node(node.inputs[1]).shape;
3751 let out_shape = &node.shape;
3752 let is_1x1_simple = kernel_size.len() == 2
3756 && kernel_size[0] == 1
3757 && kernel_size[1] == 1
3758 && stride.iter().all(|&s| s == 1)
3759 && padding.iter().all(|&p| p == 0)
3760 && dilation.iter().all(|&d| d == 1)
3761 && *groups == 1;
3762 if is_1x1_simple
3763 && in_shape.rank() >= 3
3764 && out_shape.rank() >= 3
3765 && w_shape.rank() >= 2
3766 {
3767 let (n, c_in, h, w) = conv_nchw_dims(in_shape);
3768 let (_, c_out, _, _) = conv_nchw_dims(out_shape);
3769 Thunk::Conv2D1x1 {
3770 src: node_offset(arena, node.inputs[0]),
3771 weight: node_offset(arena, node.inputs[1]),
3772 dst: node_offset(arena, node.id),
3773 n,
3774 c_in,
3775 c_out,
3776 hw: h.saturating_mul(w),
3777 }
3778 } else if kernel_size.len() == 2
3779 && in_shape.rank() >= 3
3780 && w_shape.rank() >= 2
3781 && out_shape.rank() >= 3
3782 {
3783 let (n, c_in, h, w_in) = conv_nchw_dims(in_shape);
3784 let (_, c_out, h_out, w_out) = conv_nchw_dims(out_shape);
3785 Thunk::Conv2D {
3786 src: node_offset(arena, node.inputs[0]),
3787 weight: node_offset(arena, node.inputs[1]),
3788 dst: node_offset(arena, node.id),
3789 n,
3790 c_in,
3791 h,
3792 w: w_in,
3793 c_out,
3794 h_out,
3795 w_out,
3796 kh: kernel_size[0] as u32,
3797 kw: kernel_size[1] as u32,
3798 sh: stride.first().copied().unwrap_or(1) as u32,
3799 sw: stride.get(1).copied().unwrap_or(1) as u32,
3800 ph: padding.first().copied().unwrap_or(0) as u32,
3801 pw: padding.get(1).copied().unwrap_or(0) as u32,
3802 dh: dilation.first().copied().unwrap_or(1) as u32,
3803 dw: dilation.get(1).copied().unwrap_or(1) as u32,
3804 groups: *groups as u32,
3805 }
3806 } else {
3807 Thunk::Nop
3808 }
3809 }
3810
3811 Op::Pool {
3812 kind,
3813 kernel_size,
3814 stride,
3815 padding,
3816 } => {
3817 let in_shape = &graph.node(node.inputs[0]).shape;
3819 let out_shape = &node.shape;
3820 if kernel_size.len() == 2 && in_shape.rank() == 4 && out_shape.rank() == 4 {
3821 Thunk::Pool2D {
3822 src: node_offset(arena, node.inputs[0]),
3823 dst: node_offset(arena, node.id),
3824 n: in_shape.dim(0).unwrap_static() as u32,
3825 c: in_shape.dim(1).unwrap_static() as u32,
3826 h: in_shape.dim(2).unwrap_static() as u32,
3827 w: in_shape.dim(3).unwrap_static() as u32,
3828 h_out: out_shape.dim(2).unwrap_static() as u32,
3829 w_out: out_shape.dim(3).unwrap_static() as u32,
3830 kh: kernel_size[0] as u32,
3831 kw: kernel_size[1] as u32,
3832 sh: stride.first().copied().unwrap_or(1) as u32,
3833 sw: stride.get(1).copied().unwrap_or(1) as u32,
3834 ph: padding.first().copied().unwrap_or(0) as u32,
3835 pw: padding.get(1).copied().unwrap_or(0) as u32,
3836 kind: *kind,
3837 }
3838 } else {
3839 Thunk::Nop
3840 }
3841 }
3842
3843 Op::Transpose { perm } => {
3844 let in_shape = &graph.node(node.inputs[0]).shape;
3847 let in_rank = in_shape.rank();
3848 if perm.iter().any(|&p| p >= in_rank) {
3849 Thunk::Nop
3850 } else {
3851 let in_dims: Vec<usize> = (0..in_rank)
3852 .map(|i| in_shape.dim(i).unwrap_static())
3853 .collect();
3854 let mut in_strides_full = vec![1usize; in_rank];
3856 for d in (0..in_rank.saturating_sub(1)).rev() {
3857 in_strides_full[d] = in_strides_full[d + 1] * in_dims[d + 1];
3858 }
3859 let out_dims: Vec<u32> = perm.iter().map(|&p| in_dims[p] as u32).collect();
3860 let in_strides: Vec<u32> =
3861 perm.iter().map(|&p| in_strides_full[p] as u32).collect();
3862 let in_total = in_dims.iter().product::<usize>() as u32;
3863 let src = node_offset(arena, node.inputs[0]);
3864 let dst = node_offset(arena, node.id);
3865 let elem_bytes = node.shape.dtype().size_bytes() as u8;
3866 match node.shape.dtype() {
3867 rlx_ir::DType::F64 => Thunk::TransposeF64 {
3868 src,
3869 dst,
3870 in_total,
3871 out_dims,
3872 in_strides,
3873 },
3874 _ => Thunk::Transpose {
3875 src,
3876 dst,
3877 in_total,
3878 out_dims,
3879 in_strides,
3880 elem_bytes,
3881 },
3882 }
3883 }
3884 }
3885
3886 Op::ScatterAdd => {
3887 let upd_shape = &graph.node(node.inputs[0]).shape;
3890 let out_shape = &node.shape;
3891 let num_updates = upd_shape.dim(0).unwrap_static();
3892 let out_dim = out_shape.dim(0).unwrap_static();
3893 let trailing: usize = (1..out_shape.rank())
3894 .map(|i| out_shape.dim(i).unwrap_static())
3895 .product::<usize>()
3896 .max(1);
3897 Thunk::ScatterAdd {
3898 updates: node_offset(arena, node.inputs[0]),
3899 indices: node_offset(arena, node.inputs[1]),
3900 dst: node_offset(arena, node.id),
3901 num_updates: num_updates as u32,
3902 out_dim: out_dim as u32,
3903 trailing: trailing as u32,
3904 }
3905 }
3906
3907 Op::GroupedMatMul => {
3908 let in_shape = &graph.node(node.inputs[0]).shape;
3910 let w_shape = &graph.node(node.inputs[1]).shape;
3911 let m = in_shape.dim(in_shape.rank() - 2).unwrap_static();
3912 let k_dim = in_shape.dim(in_shape.rank() - 1).unwrap_static();
3913 let num_experts = w_shape.dim(0).unwrap_static();
3914 let n = w_shape.dim(2).unwrap_static();
3915 Thunk::GroupedMatMul {
3916 input: node_offset(arena, node.inputs[0]),
3917 weight: node_offset(arena, node.inputs[1]),
3918 expert_idx: node_offset(arena, node.inputs[2]),
3919 dst: node_offset(arena, node.id),
3920 m: m as u32,
3921 k_dim: k_dim as u32,
3922 n: n as u32,
3923 num_experts: num_experts as u32,
3924 }
3925 }
3926
3927 Op::DequantGroupedMatMul { scheme } => {
3928 let in_shape = &graph.node(node.inputs[0]).shape;
3929 let w_shape = &graph.node(node.inputs[1]).shape;
3930 let m = in_shape.dim(in_shape.rank() - 2).unwrap_static();
3931 let k_dim = in_shape.dim(in_shape.rank() - 1).unwrap_static();
3932 let out_shape = &node.shape;
3933 let n = out_shape.dim(out_shape.rank() - 1).unwrap_static();
3934 let block_elems = scheme.gguf_block_size() as usize;
3935 let block_bytes = scheme.gguf_block_bytes() as usize;
3936 let slab_bytes = (k_dim * n) / block_elems * block_bytes;
3937 let total_bytes = w_shape.num_elements().unwrap();
3938 let num_experts = total_bytes / slab_bytes.max(1);
3939 Thunk::DequantGroupedMatMulGguf {
3940 input: node_offset(arena, node.inputs[0]),
3941 w_q: node_offset(arena, node.inputs[1]),
3942 expert_idx: node_offset(arena, node.inputs[2]),
3943 dst: node_offset(arena, node.id),
3944 m: m as u32,
3945 k_dim: k_dim as u32,
3946 n: n as u32,
3947 num_experts: num_experts as u32,
3948 scheme: *scheme,
3949 }
3950 }
3951
3952 Op::DequantMoEWeights { scheme } => {
3953 let w_shape = &graph.node(node.inputs[0]).shape;
3954 let out_shape = &node.shape;
3955 let num_experts = out_shape.dim(0).unwrap_static();
3956 let k_dim = out_shape.dim(1).unwrap_static();
3957 let n = out_shape.dim(2).unwrap_static();
3958 let block_elems = scheme.gguf_block_size() as usize;
3959 let block_bytes = scheme.gguf_block_bytes() as usize;
3960 let slab_bytes = (k_dim * n) / block_elems * block_bytes;
3961 let total_bytes = w_shape.num_elements().unwrap();
3962 assert_eq!(
3963 total_bytes,
3964 num_experts * slab_bytes,
3965 "DequantMoEWeights packed bytes mismatch"
3966 );
3967 Thunk::DequantMoEWeightsGguf {
3968 w_q: node_offset(arena, node.inputs[0]),
3969 dst: node_offset(arena, node.id),
3970 k_dim: k_dim as u32,
3971 n: n as u32,
3972 num_experts: num_experts as u32,
3973 scheme: *scheme,
3974 }
3975 }
3976
3977 Op::TopK { k } => {
3978 let in_shape = &graph.node(node.inputs[0]).shape;
3979 let rank = in_shape.rank();
3980 let axis_dim = in_shape.dim(rank - 1).unwrap_static();
3981 let outer = in_shape.num_elements().unwrap() / axis_dim;
3982 let indices_i64 = u8::from(graph.node(node.id).shape.dtype() == rlx_ir::DType::I64);
3983 Thunk::TopK {
3984 src: node_offset(arena, node.inputs[0]),
3985 dst: node_offset(arena, node.id),
3986 outer: outer as u32,
3987 axis_dim: axis_dim as u32,
3988 k: *k as u32,
3989 indices_i64,
3990 }
3991 }
3992
3993 Op::Reduce {
3994 op,
3995 axes,
3996 keep_dim: _,
3997 } => {
3998 let in_shape = &graph.node(node.inputs[0]).shape;
4004 let rank = in_shape.rank();
4005 let mut sorted = axes.clone();
4006 sorted.sort();
4007 sorted.dedup();
4008 let contiguous = sorted.windows(2).all(|w| w[1] == w[0] + 1)
4009 && !sorted.is_empty()
4010 && *sorted.last().unwrap() < rank;
4011 if !contiguous {
4012 Thunk::Nop
4013 } else {
4014 let first = sorted[0];
4015 let last = *sorted.last().unwrap();
4016 let outer: usize = (0..first)
4017 .map(|i| in_shape.dim(i).unwrap_static())
4018 .product::<usize>()
4019 .max(1);
4020 let reduced: usize = (first..=last)
4021 .map(|i| in_shape.dim(i).unwrap_static())
4022 .product();
4023 let inner: usize = (last + 1..rank)
4024 .map(|i| in_shape.dim(i).unwrap_static())
4025 .product::<usize>()
4026 .max(1);
4027 let src = node_offset(arena, node.inputs[0]);
4028 let dst = node_offset(arena, node.id);
4029 if node.shape.dtype() == rlx_ir::DType::F64 && matches!(op, ReduceOp::Sum) {
4030 Thunk::ReduceSumF64 {
4031 src,
4032 dst,
4033 outer: outer as u32,
4034 reduced: reduced as u32,
4035 inner: inner as u32,
4036 }
4037 } else {
4038 Thunk::Reduce {
4039 src,
4040 dst,
4041 outer: outer as u32,
4042 reduced: reduced as u32,
4043 inner: inner as u32,
4044 op: *op,
4045 }
4046 }
4047 }
4048 }
4049
4050 Op::Compare(cmp) => {
4051 let len = node.shape.num_elements().unwrap();
4052 let in_dtype = graph.node(node.inputs[0]).shape.dtype();
4053 let inputs_i64 = u8::from(in_dtype == rlx_ir::DType::I64);
4054 Thunk::Compare {
4055 lhs: node_offset(arena, node.inputs[0]),
4056 rhs: node_offset(arena, node.inputs[1]),
4057 dst: node_offset(arena, node.id),
4058 len: len as u32,
4059 op: *cmp,
4060 inputs_i64,
4061 inputs_elem_bytes: in_dtype.size_bytes() as u8,
4062 dst_elem_bytes: node.shape.dtype().size_bytes() as u8,
4063 }
4064 }
4065
4066 Op::Where => {
4067 let len = node.shape.num_elements().unwrap();
4068 let elem_bytes = node.shape.dtype().size_bytes() as u8;
4069 let cond_elem_bytes = graph.node(node.inputs[0]).shape.dtype().size_bytes() as u8;
4070 Thunk::Where {
4071 cond: node_offset(arena, node.inputs[0]),
4072 on_true: node_offset(arena, node.inputs[1]),
4073 on_false: node_offset(arena, node.inputs[2]),
4074 dst: node_offset(arena, node.id),
4075 len: len as u32,
4076 elem_bytes,
4077 cond_elem_bytes,
4078 }
4079 }
4080
4081 Op::ReluBackward => {
4082 let len: usize = (0..node.shape.rank())
4083 .map(|i| node.shape.dim(i).unwrap_static())
4084 .product();
4085 let x = node_offset(arena, node.inputs[0]);
4086 let dy = node_offset(arena, node.inputs[1]);
4087 let dx = node_offset(arena, node.id);
4088 match node.shape.dtype() {
4089 rlx_ir::DType::F64 => Thunk::ReluBackwardF64 {
4090 x,
4091 dy,
4092 dx,
4093 len: len as u32,
4094 },
4095 _ => Thunk::ReluBackward {
4096 x,
4097 dy,
4098 dx,
4099 len: len as u32,
4100 },
4101 }
4102 }
4103
4104 Op::ComplexNormSq => {
4105 let len: usize = (0..node.shape.rank())
4106 .map(|i| node.shape.dim(i).unwrap_static())
4107 .product();
4108 let src = node_offset(arena, node.inputs[0]);
4109 let dst = node_offset(arena, node.id);
4110 Thunk::ComplexNormSqF32 {
4111 src,
4112 dst,
4113 len: len as u32,
4114 }
4115 }
4116
4117 Op::ComplexNormSqBackward => {
4118 let len: usize = (0..node.shape.rank())
4119 .map(|i| node.shape.dim(i).unwrap_static())
4120 .product();
4121 let z = node_offset(arena, node.inputs[0]);
4122 let g = node_offset(arena, node.inputs[1]);
4123 let dz = node_offset(arena, node.id);
4124 Thunk::ComplexNormSqBackwardF32 {
4125 z,
4126 g,
4127 dz,
4128 len: len as u32,
4129 }
4130 }
4131
4132 Op::Conjugate => {
4133 let len: usize = (0..node.shape.rank())
4134 .map(|i| node.shape.dim(i).unwrap_static())
4135 .product();
4136 Thunk::ConjugateC64 {
4137 src: node_offset(arena, node.inputs[0]),
4138 dst: node_offset(arena, node.id),
4139 len: len as u32,
4140 }
4141 }
4142
4143 Op::ActivationBackward { kind } => {
4144 let len: usize = (0..node.shape.rank())
4145 .map(|i| node.shape.dim(i).unwrap_static())
4146 .product();
4147 let x = node_offset(arena, node.inputs[0]);
4148 let dy = node_offset(arena, node.inputs[1]);
4149 let dx = node_offset(arena, node.id);
4150 match node.shape.dtype() {
4151 rlx_ir::DType::F64 => Thunk::ActivationBackwardF64 {
4152 x,
4153 dy,
4154 dx,
4155 len: len as u32,
4156 kind: *kind,
4157 },
4158 _ => Thunk::ActivationBackward {
4159 x,
4160 dy,
4161 dx,
4162 len: len as u32,
4163 kind: *kind,
4164 },
4165 }
4166 }
4167
4168 Op::LayerNormBackwardInput { eps, .. } => {
4169 let h = node.shape.dim(node.shape.rank() - 1).unwrap_static();
4171 let total = node.shape.num_elements().unwrap();
4172 Thunk::LayerNormBackwardInput {
4173 x: node_offset(arena, node.inputs[0]),
4174 gamma: node_offset(arena, node.inputs[1]),
4175 dy: node_offset(arena, node.inputs[2]),
4176 dx: node_offset(arena, node.id),
4177 rows: (total / h) as u32,
4178 h: h as u32,
4179 eps: *eps,
4180 }
4181 }
4182
4183 Op::LayerNormBackwardGamma { eps, .. } => {
4184 let x_shape = &graph.node(node.inputs[0]).shape;
4185 let h = x_shape.dim(x_shape.rank() - 1).unwrap_static();
4186 let x_total = x_shape.num_elements().unwrap();
4187 Thunk::LayerNormBackwardGamma {
4188 x: node_offset(arena, node.inputs[0]),
4189 dy: node_offset(arena, node.inputs[1]),
4190 dgamma: node_offset(arena, node.id),
4191 rows: (x_total / h) as u32,
4192 h: h as u32,
4193 eps: *eps,
4194 }
4195 }
4196
4197 Op::RmsNormBackwardInput { eps, .. }
4198 | Op::RmsNormBackwardGamma { eps, .. }
4199 | Op::RmsNormBackwardBeta { eps, .. } => {
4200 let x_shape = &graph.node(node.inputs[0]).shape;
4201 let h = x_shape.dim(x_shape.rank() - 1).unwrap_static();
4202 let rows = (x_shape.num_elements().unwrap() / h) as u32;
4203 let off = |i: usize| node_offset(arena, node.inputs[i]);
4204 let common = (off(0), off(1), off(2), off(3), rows, h as u32, *eps);
4205 match &node.op {
4206 Op::RmsNormBackwardInput { .. } => Thunk::RmsNormBackwardInput {
4207 x: common.0,
4208 gamma: common.1,
4209 beta: common.2,
4210 dy: common.3,
4211 dx: node_offset(arena, node.id),
4212 rows: common.4,
4213 h: common.5,
4214 eps: common.6,
4215 },
4216 Op::RmsNormBackwardGamma { .. } => Thunk::RmsNormBackwardGamma {
4217 x: common.0,
4218 gamma: common.1,
4219 beta: common.2,
4220 dy: common.3,
4221 dgamma: node_offset(arena, node.id),
4222 rows: common.4,
4223 h: common.5,
4224 eps: common.6,
4225 },
4226 Op::RmsNormBackwardBeta { .. } => Thunk::RmsNormBackwardBeta {
4227 x: common.0,
4228 gamma: common.1,
4229 beta: common.2,
4230 dy: common.3,
4231 dbeta: node_offset(arena, node.id),
4232 rows: common.4,
4233 h: common.5,
4234 eps: common.6,
4235 },
4236 _ => unreachable!(),
4237 }
4238 }
4239
4240 Op::RopeBackward { head_dim, n_rot } => {
4241 let dy_shape = &graph.node(node.inputs[0]).shape;
4242 let (batch, seq, hidden) = if dy_shape.rank() >= 3 {
4243 (
4244 dy_shape.dim(0).unwrap_static(),
4245 dy_shape.dim(1).unwrap_static(),
4246 dy_shape.dim(2).unwrap_static(),
4247 )
4248 } else {
4249 (
4250 1,
4251 dy_shape.dim(0).unwrap_static(),
4252 dy_shape.dim(1).unwrap_static(),
4253 )
4254 };
4255 let cos_shape = &graph.node(node.inputs[1]).shape;
4256 let cos_len = cos_shape.num_elements().unwrap();
4257 Thunk::RopeBackward {
4258 dy: node_offset(arena, node.inputs[0]),
4259 cos: node_offset(arena, node.inputs[1]),
4260 sin: node_offset(arena, node.inputs[2]),
4261 dx: node_offset(arena, node.id),
4262 batch: batch as u32,
4263 seq: seq as u32,
4264 hidden: hidden as u32,
4265 head_dim: *head_dim as u32,
4266 n_rot: *n_rot as u32,
4267 cos_len: cos_len as u32,
4268 }
4269 }
4270
4271 Op::CumsumBackward { exclusive, .. } => {
4272 let dy_shape = &graph.node(node.inputs[0]).shape;
4273 let rank = dy_shape.rank();
4274 let cols = dy_shape.dim(rank - 1).unwrap_static();
4275 let rows = dy_shape.num_elements().unwrap() / cols;
4276 Thunk::CumsumBackward {
4277 dy: node_offset(arena, node.inputs[0]),
4278 dx: node_offset(arena, node.id),
4279 rows: rows as u32,
4280 cols: cols as u32,
4281 exclusive: *exclusive,
4282 }
4283 }
4284
4285 Op::GatherBackward { .. } => {
4286 let dy_shape = &graph.node(node.inputs[0]).shape;
4287 let idx_shape = &graph.node(node.inputs[1]).shape;
4288 let out_shape = &node.shape;
4289 let rank = out_shape.rank();
4290 let axis = match &node.op {
4291 Op::GatherBackward { axis } => *axis,
4292 _ => 0,
4293 };
4294 let axis_u = if axis < 0 {
4295 (rank as i32 + axis) as usize
4296 } else {
4297 axis as usize
4298 };
4299 let outer: usize = (0..axis_u)
4300 .map(|i| dy_shape.dim(i).unwrap_static())
4301 .product::<usize>()
4302 .max(1);
4303 let num_idx = idx_shape.dim(axis_u).unwrap_static();
4304 let trailing: usize = (axis_u + 1..dy_shape.rank())
4305 .map(|i| dy_shape.dim(i).unwrap_static())
4306 .product::<usize>()
4307 .max(1);
4308 let axis_dim = out_shape.dim(axis_u).unwrap_static();
4309 Thunk::GatherBackward {
4310 dy: node_offset(arena, node.inputs[0]),
4311 indices: node_offset(arena, node.inputs[1]),
4312 dst: node_offset(arena, node.id),
4313 outer: outer as u32,
4314 axis_dim: axis_dim as u32,
4315 num_idx: num_idx as u32,
4316 trailing: trailing as u32,
4317 }
4318 }
4319
4320 Op::GroupNormBackwardInput { num_groups, eps }
4321 | Op::GroupNormBackwardGamma { num_groups, eps }
4322 | Op::GroupNormBackwardBeta { num_groups, eps } => {
4323 let x_shape = &graph.node(node.inputs[0]).shape;
4324 let n = x_shape.dim(0).unwrap_static() as u32;
4325 let c = x_shape.dim(1).unwrap_static() as u32;
4326 let h = x_shape.dim(2).unwrap_static() as u32;
4327 let w = x_shape.dim(3).unwrap_static() as u32;
4328 match &node.op {
4329 Op::GroupNormBackwardInput { .. } => Thunk::GroupNormBackwardInput {
4330 x: node_offset(arena, node.inputs[0]),
4331 gamma: node_offset(arena, node.inputs[1]),
4332 beta: node_offset(arena, node.inputs[2]),
4333 dy: node_offset(arena, node.inputs[3]),
4334 dx: node_offset(arena, node.id),
4335 n,
4336 c,
4337 h,
4338 w,
4339 num_groups: *num_groups as u32,
4340 eps: *eps,
4341 },
4342 Op::GroupNormBackwardGamma { .. } => Thunk::GroupNormBackwardGamma {
4343 x: node_offset(arena, node.inputs[0]),
4344 dy: node_offset(arena, node.inputs[1]),
4345 dgamma: node_offset(arena, node.id),
4346 n,
4347 c,
4348 h,
4349 w,
4350 num_groups: *num_groups as u32,
4351 eps: *eps,
4352 },
4353 Op::GroupNormBackwardBeta { .. } => Thunk::GroupNormBackwardBeta {
4354 dy: node_offset(arena, node.inputs[1]),
4355 dbeta: node_offset(arena, node.id),
4356 n,
4357 c,
4358 h,
4359 w,
4360 },
4361 _ => unreachable!(),
4362 }
4363 }
4364
4365 Op::MaxPool2dBackward {
4366 kernel_size,
4367 stride,
4368 padding,
4369 } => {
4370 let x_shape = &graph.node(node.inputs[0]).shape;
4371 let dy_shape = &graph.node(node.inputs[1]).shape;
4372 if kernel_size.len() == 2 && x_shape.rank() == 4 && dy_shape.rank() == 4 {
4373 Thunk::MaxPool2dBackward {
4374 x: node_offset(arena, node.inputs[0]),
4375 dy: node_offset(arena, node.inputs[1]),
4376 dx: node_offset(arena, node.id),
4377 n: x_shape.dim(0).unwrap_static() as u32,
4378 c: x_shape.dim(1).unwrap_static() as u32,
4379 h: x_shape.dim(2).unwrap_static() as u32,
4380 w: x_shape.dim(3).unwrap_static() as u32,
4381 h_out: dy_shape.dim(2).unwrap_static() as u32,
4382 w_out: dy_shape.dim(3).unwrap_static() as u32,
4383 kh: kernel_size[0] as u32,
4384 kw: kernel_size[1] as u32,
4385 sh: stride.first().copied().unwrap_or(1) as u32,
4386 sw: stride.get(1).copied().unwrap_or(1) as u32,
4387 ph: padding.first().copied().unwrap_or(0) as u32,
4388 pw: padding.get(1).copied().unwrap_or(0) as u32,
4389 }
4390 } else {
4391 Thunk::Nop
4392 }
4393 }
4394
4395 Op::Conv2dBackwardInput {
4396 kernel_size,
4397 stride,
4398 padding,
4399 dilation,
4400 groups,
4401 } => {
4402 let dy_shape = &graph.node(node.inputs[0]).shape;
4403 let w_shape = &graph.node(node.inputs[1]).shape;
4404 let out_shape = &node.shape;
4405 if kernel_size.len() == 2
4406 && dy_shape.rank() == 4
4407 && w_shape.rank() == 4
4408 && out_shape.rank() == 4
4409 {
4410 Thunk::Conv2dBackwardInput {
4411 dy: node_offset(arena, node.inputs[0]),
4412 w: node_offset(arena, node.inputs[1]),
4413 dx: node_offset(arena, node.id),
4414 n: out_shape.dim(0).unwrap_static() as u32,
4415 c_in: out_shape.dim(1).unwrap_static() as u32,
4416 h: out_shape.dim(2).unwrap_static() as u32,
4417 w_in: out_shape.dim(3).unwrap_static() as u32,
4418 c_out: dy_shape.dim(1).unwrap_static() as u32,
4419 h_out: dy_shape.dim(2).unwrap_static() as u32,
4420 w_out: dy_shape.dim(3).unwrap_static() as u32,
4421 kh: kernel_size[0] as u32,
4422 kw: kernel_size[1] as u32,
4423 sh: stride.first().copied().unwrap_or(1) as u32,
4424 sw: stride.get(1).copied().unwrap_or(1) as u32,
4425 ph: padding.first().copied().unwrap_or(0) as u32,
4426 pw: padding.get(1).copied().unwrap_or(0) as u32,
4427 dh: dilation.first().copied().unwrap_or(1) as u32,
4428 dw: dilation.get(1).copied().unwrap_or(1) as u32,
4429 groups: *groups as u32,
4430 }
4431 } else {
4432 Thunk::Nop
4433 }
4434 }
4435
4436 Op::Conv2dBackwardWeight {
4437 kernel_size,
4438 stride,
4439 padding,
4440 dilation,
4441 groups,
4442 } => {
4443 let x_shape = &graph.node(node.inputs[0]).shape;
4444 let dy_shape = &graph.node(node.inputs[1]).shape;
4445 let dw_shape = &node.shape;
4446 if kernel_size.len() == 2
4447 && x_shape.rank() == 4
4448 && dy_shape.rank() == 4
4449 && dw_shape.rank() == 4
4450 {
4451 Thunk::Conv2dBackwardWeight {
4452 x: node_offset(arena, node.inputs[0]),
4453 dy: node_offset(arena, node.inputs[1]),
4454 dw: node_offset(arena, node.id),
4455 n: x_shape.dim(0).unwrap_static() as u32,
4456 c_in: x_shape.dim(1).unwrap_static() as u32,
4457 h: x_shape.dim(2).unwrap_static() as u32,
4458 w: x_shape.dim(3).unwrap_static() as u32,
4459 c_out: dy_shape.dim(1).unwrap_static() as u32,
4460 h_out: dy_shape.dim(2).unwrap_static() as u32,
4461 w_out: dy_shape.dim(3).unwrap_static() as u32,
4462 kh: kernel_size[0] as u32,
4463 kw: kernel_size[1] as u32,
4464 sh: stride.first().copied().unwrap_or(1) as u32,
4465 sw: stride.get(1).copied().unwrap_or(1) as u32,
4466 ph: padding.first().copied().unwrap_or(0) as u32,
4467 pw: padding.get(1).copied().unwrap_or(0) as u32,
4468 dh: dilation.first().copied().unwrap_or(1) as u32,
4469 dw_dil: dilation.get(1).copied().unwrap_or(1) as u32,
4470 groups: *groups as u32,
4471 }
4472 } else {
4473 Thunk::Nop
4474 }
4475 }
4476
4477 Op::Im2Col {
4478 kernel_size,
4479 stride,
4480 padding,
4481 dilation,
4482 } => {
4483 let x_shape = &graph.node(node.inputs[0]).shape;
4484 let out_shape = &node.shape;
4485 if kernel_size.len() == 2 && x_shape.rank() == 4 && out_shape.rank() == 2 {
4486 let n = match x_shape.dim(0) {
4487 rlx_ir::shape::Dim::Static(v) => v as u32,
4488 _ => 0,
4489 };
4490 let c_in = x_shape.dim(1).unwrap_static() as u32;
4491 let h = x_shape.dim(2).unwrap_static() as u32;
4492 let w = x_shape.dim(3).unwrap_static() as u32;
4493 let kh = kernel_size[0] as u32;
4494 let kw = kernel_size[1] as u32;
4495 let sh = stride.first().copied().unwrap_or(1) as u32;
4496 let sw = stride.get(1).copied().unwrap_or(1) as u32;
4497 let ph = padding.first().copied().unwrap_or(0) as u32;
4498 let pw = padding.get(1).copied().unwrap_or(0) as u32;
4499 let dh = dilation.first().copied().unwrap_or(1) as u32;
4500 let dw_dil = dilation.get(1).copied().unwrap_or(1) as u32;
4501 let h_out = rlx_ir::shape::conv2d_spatial_output(
4502 h as usize,
4503 kh as usize,
4504 sh as usize,
4505 ph as usize,
4506 dh as usize,
4507 ) as u32;
4508 let w_out = rlx_ir::shape::conv2d_spatial_output(
4509 w as usize,
4510 kw as usize,
4511 sw as usize,
4512 pw as usize,
4513 dw_dil as usize,
4514 ) as u32;
4515 Thunk::Im2Col {
4516 x: node_offset(arena, node.inputs[0]),
4517 col: node_offset(arena, node.id),
4518 n,
4519 c_in,
4520 h,
4521 w,
4522 h_out,
4523 w_out,
4524 kh,
4525 kw,
4526 sh,
4527 sw,
4528 ph,
4529 pw,
4530 dh,
4531 dw_dil,
4532 }
4533 } else {
4534 Thunk::Nop
4535 }
4536 }
4537
4538 Op::SoftmaxCrossEntropyWithLogits => {
4539 let logits_shape = &graph.node(node.inputs[0]).shape;
4540 if logits_shape.rank() == 2 {
4541 Thunk::SoftmaxCrossEntropy {
4542 logits: node_offset(arena, node.inputs[0]),
4543 labels: node_offset(arena, node.inputs[1]),
4544 dst: node_offset(arena, node.id),
4545 n: logits_shape.dim(0).unwrap_static() as u32,
4546 c: logits_shape.dim(1).unwrap_static() as u32,
4547 }
4548 } else {
4549 Thunk::Nop
4550 }
4551 }
4552
4553 Op::SoftmaxCrossEntropyBackward => {
4554 let logits_shape = &graph.node(node.inputs[0]).shape;
4555 if logits_shape.rank() == 2 {
4556 Thunk::SoftmaxCrossEntropyBackward {
4557 logits: node_offset(arena, node.inputs[0]),
4558 labels: node_offset(arena, node.inputs[1]),
4559 d_loss: node_offset(arena, node.inputs[2]),
4560 dlogits: node_offset(arena, node.id),
4561 n: logits_shape.dim(0).unwrap_static() as u32,
4562 c: logits_shape.dim(1).unwrap_static() as u32,
4563 }
4564 } else {
4565 Thunk::Nop
4566 }
4567 }
4568
4569 Op::DenseSolve => {
4570 let a_shape = &graph.node(node.inputs[0]).shape;
4572 let n = a_shape.dim(0).unwrap_static();
4573 debug_assert_eq!(
4574 n,
4575 a_shape.dim(1).unwrap_static(),
4576 "DenseSolve: A must be square"
4577 );
4578 let b_elems = node.shape.num_elements().unwrap();
4579 let nrhs = b_elems / n;
4580 match node.shape.dtype() {
4581 rlx_ir::DType::F64 => Thunk::DenseSolveF64 {
4582 a: node_offset(arena, node.inputs[0]),
4583 b: node_offset(arena, node.inputs[1]),
4584 x: node_offset(arena, node.id),
4585 n: n as u32,
4586 nrhs: nrhs as u32,
4587 },
4588 rlx_ir::DType::F32 => Thunk::DenseSolveF32 {
4589 a: node_offset(arena, node.inputs[0]),
4590 b: node_offset(arena, node.inputs[1]),
4591 x: node_offset(arena, node.id),
4592 n: n as u32,
4593 nrhs: nrhs as u32,
4594 },
4595 other => panic!(
4596 "DenseSolve: F32 + F64 lowered; got {other:?}. \
4597 Add another variant when needed."
4598 ),
4599 }
4600 }
4601
4602 Op::BatchedDenseSolve => {
4603 let a_shape = &graph.node(node.inputs[0]).shape;
4605 assert_eq!(a_shape.rank(), 3, "BatchedDenseSolve: A rank must be 3");
4606 let batch = a_shape.dim(0).unwrap_static();
4607 let n = a_shape.dim(1).unwrap_static();
4608 debug_assert_eq!(
4609 n,
4610 a_shape.dim(2).unwrap_static(),
4611 "BatchedDenseSolve: A's last two dims must match"
4612 );
4613 let total = node.shape.num_elements().unwrap();
4614 let nrhs = total / (batch * n);
4615 match node.shape.dtype() {
4616 rlx_ir::DType::F32 => Thunk::BatchedDenseSolveF32 {
4617 a: node_offset(arena, node.inputs[0]),
4618 b: node_offset(arena, node.inputs[1]),
4619 x: node_offset(arena, node.id),
4620 batch: batch as u32,
4621 n: n as u32,
4622 nrhs: nrhs as u32,
4623 },
4624 rlx_ir::DType::F64 => Thunk::BatchedDenseSolveF64 {
4625 a: node_offset(arena, node.inputs[0]),
4626 b: node_offset(arena, node.inputs[1]),
4627 x: node_offset(arena, node.id),
4628 batch: batch as u32,
4629 n: n as u32,
4630 nrhs: nrhs as u32,
4631 },
4632 other => panic!("BatchedDenseSolve: F32 + F64 only, got {other:?}"),
4633 }
4634 }
4635
4636 Op::Scan {
4637 body,
4638 length,
4639 save_trajectory,
4640 num_bcast,
4641 num_xs,
4642 num_checkpoints,
4643 } => {
4644 assert!(
4645 *num_checkpoints == 0 || *num_checkpoints <= *length,
4646 "Op::Scan: num_checkpoints={} must be 0 or ≤ length={}",
4647 *num_checkpoints,
4648 *length
4649 );
4650 if *num_checkpoints != 0 && *num_checkpoints != *length {
4651 assert!(
4652 *save_trajectory,
4653 "Op::Scan: num_checkpoints<length only meaningful when save_trajectory=true"
4654 );
4655 }
4656 let body_plan = rlx_opt::memory::plan_memory(body);
4667 let _body_arena_size = body_plan.arena_size;
4668 let body_offsets: HashMap<NodeId, usize> = body_plan
4671 .assignments
4672 .iter()
4673 .map(|(id, slot)| (*id, slot.offset))
4674 .collect();
4675
4676 let mut body_inputs: Vec<NodeId> = body
4679 .nodes()
4680 .iter()
4681 .filter(|n| matches!(n.op, Op::Input { .. }))
4682 .map(|n| n.id)
4683 .collect();
4684 body_inputs.sort();
4685 let n_body_inputs = body_inputs.len();
4686 let expected = 1 + *num_bcast as usize + *num_xs as usize;
4687 if n_body_inputs != expected {
4688 let names: Vec<String> = body
4689 .nodes()
4690 .iter()
4691 .filter_map(|n| match &n.op {
4692 Op::Input { name } => Some(format!("{}={}", n.id, name)),
4693 _ => None,
4694 })
4695 .collect();
4696 panic!(
4697 "Op::Scan body has {} Op::Input nodes; expected {} \
4698 (1 carry + {} bcast + {} xs). Inputs by NodeId: [{}]",
4699 n_body_inputs,
4700 expected,
4701 *num_bcast,
4702 *num_xs,
4703 names.join(", ")
4704 );
4705 }
4706
4707 let body_input_id = body_inputs[0];
4708 let body_input_off = body_offsets[&body_input_id];
4709 let body_output_id = body
4710 .outputs
4711 .first()
4712 .copied()
4713 .expect("Op::Scan body must declare one output");
4714 let body_output_off = body_offsets[&body_output_id];
4715
4716 let mut body_arena = crate::arena::Arena::from_plan(body_plan);
4717 for n in body.nodes() {
4720 if let Op::Constant { data } = &n.op
4721 && body_arena.has_buffer(n.id)
4722 && !data.is_empty()
4723 {
4724 match n.shape.dtype() {
4725 rlx_ir::DType::F64 => {
4726 let off = body_arena.byte_offset(n.id);
4727 let buf = body_arena.raw_buf_mut();
4728 let nbytes = (buf.len() - off).min(data.len());
4729 buf[off..off + nbytes].copy_from_slice(&data[..nbytes]);
4730 }
4731 _ => {
4732 let buf = body_arena.slice_mut(n.id);
4733 let n_floats = data.len() / 4;
4734 let n_lim = buf.len().min(n_floats);
4735 for i in 0..n_lim {
4736 let bytes = [
4737 data[i * 4],
4738 data[i * 4 + 1],
4739 data[i * 4 + 2],
4740 data[i * 4 + 3],
4741 ];
4742 buf[i] = f32::from_le_bytes(bytes);
4743 }
4744 }
4745 }
4746 }
4747 }
4748 let body_init = body_arena.raw_buf().to_vec();
4749 let body_schedule = compile_thunks(body, &body_arena);
4750
4751 let carry_bytes = if *save_trajectory {
4756 let total = node
4757 .shape
4758 .size_bytes()
4759 .expect("Op::Scan trajectory output must have static shape");
4760 total / *length as usize
4761 } else {
4762 node.shape
4763 .size_bytes()
4764 .expect("Op::Scan carry must have static shape")
4765 };
4766
4767 let mut bcast_inputs: Vec<(usize, usize, u32)> =
4772 Vec::with_capacity(*num_bcast as usize);
4773 for i in 0..*num_bcast as usize {
4774 let body_b_id = body_inputs[1 + i];
4775 let body_b_off = body_offsets[&body_b_id];
4776 let outer_b_id = node.inputs[1 + i];
4777 let outer_b_off = node_offset(arena, outer_b_id);
4778 let outer_b_shape = &graph.node(outer_b_id).shape;
4779 let total = outer_b_shape
4780 .size_bytes()
4781 .expect("Op::Scan bcast must have static shape");
4782 bcast_inputs.push((body_b_off, outer_b_off, total as u32));
4783 }
4784
4785 let mut xs_inputs: Vec<(usize, usize, u32)> = Vec::with_capacity(*num_xs as usize);
4789 let xs_base = 1 + *num_bcast as usize;
4790 for i in 0..*num_xs as usize {
4791 let body_x_id = body_inputs[xs_base + i];
4792 let body_x_off = body_offsets[&body_x_id];
4793 let outer_xs_id = node.inputs[xs_base + i];
4794 let outer_xs_off = node_offset(arena, outer_xs_id);
4795 let outer_xs_shape = &graph.node(outer_xs_id).shape;
4796 let total = outer_xs_shape
4797 .size_bytes()
4798 .expect("Op::Scan xs must have static shape");
4799 let per_step = total / *length as usize;
4800 xs_inputs.push((body_x_off, outer_xs_off, per_step as u32));
4801 }
4802
4803 Thunk::Scan {
4804 body: Arc::new(body_schedule),
4805 body_init: Arc::new(body_init),
4806 body_input_off,
4807 body_output_off,
4808 outer_init_off: node_offset(arena, node.inputs[0]),
4809 outer_final_off: node_offset(arena, node.id),
4810 length: *length,
4811 carry_bytes: carry_bytes as u32,
4812 save_trajectory: *save_trajectory,
4813 xs_inputs: Arc::new(xs_inputs),
4814 bcast_inputs: Arc::new(bcast_inputs),
4815 num_checkpoints: *num_checkpoints,
4816 }
4817 }
4818
4819 Op::ScanBackward {
4820 body_vjp,
4821 length,
4822 save_trajectory,
4823 num_xs,
4824 num_checkpoints,
4825 forward_body,
4826 } => {
4827 let is_recursive = *num_checkpoints != 0 && *num_checkpoints != *length;
4828 if is_recursive {
4829 assert!(
4830 forward_body.is_some(),
4831 "Op::ScanBackward with num_checkpoints<length requires forward_body"
4832 );
4833 }
4834 let body_plan = rlx_opt::memory::plan_memory(body_vjp);
4842 let body_offsets: HashMap<NodeId, usize> = body_plan
4843 .assignments
4844 .iter()
4845 .map(|(id, slot)| (*id, slot.offset))
4846 .collect();
4847 let mut body_d_output_off: Option<usize> = None;
4848 let mut body_other_inputs: Vec<(NodeId, usize)> = Vec::new();
4849 for n in body_vjp.nodes() {
4850 if let Op::Input { name } = &n.op {
4851 let off = body_offsets[&n.id];
4852 if name == "d_output" {
4853 body_d_output_off = Some(off);
4854 } else {
4855 body_other_inputs.push((n.id, off));
4856 }
4857 }
4858 }
4859 body_other_inputs.sort_by_key(|(id, _)| *id);
4860 let body_d_output_off =
4861 body_d_output_off.expect("ScanBackward body_vjp missing 'd_output' Input");
4862 let expected_others = 1 + *num_xs as usize;
4863 assert_eq!(
4864 body_other_inputs.len(),
4865 expected_others,
4866 "ScanBackward body_vjp has {} non-d_output Inputs; \
4867 expected {} (1 carry + {} xs)",
4868 body_other_inputs.len(),
4869 expected_others,
4870 num_xs
4871 );
4872 let body_carry_in_off = body_other_inputs[0].1;
4873 let body_x_offs: Vec<usize> = body_other_inputs
4874 .iter()
4875 .skip(1)
4876 .map(|(_, off)| *off)
4877 .collect();
4878 let body_dcarry_out_off = body_offsets[&body_vjp.outputs[0]];
4879
4880 let mut body_arena = crate::arena::Arena::from_plan(body_plan);
4881 for n in body_vjp.nodes() {
4883 if let Op::Constant { data } = &n.op
4884 && body_arena.has_buffer(n.id)
4885 && !data.is_empty()
4886 {
4887 match n.shape.dtype() {
4888 rlx_ir::DType::F64 => {
4889 let off = body_arena.byte_offset(n.id);
4890 let buf = body_arena.raw_buf_mut();
4891 let nb = (buf.len() - off).min(data.len());
4892 buf[off..off + nb].copy_from_slice(&data[..nb]);
4893 }
4894 _ => {
4895 let buf = body_arena.slice_mut(n.id);
4896 let nf = data.len() / 4;
4897 let nl = buf.len().min(nf);
4898 for i in 0..nl {
4899 let bytes = [
4900 data[i * 4],
4901 data[i * 4 + 1],
4902 data[i * 4 + 2],
4903 data[i * 4 + 3],
4904 ];
4905 buf[i] = f32::from_le_bytes(bytes);
4906 }
4907 }
4908 }
4909 }
4910 }
4911 let body_init = body_arena.raw_buf().to_vec();
4912 let body_schedule = compile_thunks(body_vjp, &body_arena);
4913
4914 let carry_bytes = body_vjp
4916 .node(body_vjp.outputs[0])
4917 .shape
4918 .size_bytes()
4919 .expect("ScanBackward dcarry must be statically shaped");
4920 let carry_elem_size = body_vjp
4921 .node(body_vjp.outputs[0])
4922 .shape
4923 .dtype()
4924 .size_bytes() as u32;
4925
4926 let mut outer_xs_offs: Vec<(usize, u32)> = Vec::with_capacity(*num_xs as usize);
4929 for i in 0..*num_xs as usize {
4930 let outer_xs_id = node.inputs[3 + i];
4931 let outer_xs_off = node_offset(arena, outer_xs_id);
4932 let outer_xs_shape = &graph.node(outer_xs_id).shape;
4933 let total = outer_xs_shape
4934 .size_bytes()
4935 .expect("ScanBackward xs must have static shape");
4936 let per_step = total / *length as usize;
4937 outer_xs_offs.push((outer_xs_off, per_step as u32));
4938 }
4939
4940 let (fb_schedule, fb_init, fb_carry_in_off, fb_output_off, fb_x_offs) =
4945 if is_recursive {
4946 let fb = forward_body.as_ref().unwrap();
4947 let fb_plan = rlx_opt::memory::plan_memory(fb);
4948 let fb_offsets: HashMap<NodeId, usize> = fb_plan
4949 .assignments
4950 .iter()
4951 .map(|(id, slot)| (*id, slot.offset))
4952 .collect();
4953 let mut fb_inputs: Vec<NodeId> = fb
4954 .nodes()
4955 .iter()
4956 .filter(|n| matches!(n.op, Op::Input { .. }))
4957 .map(|n| n.id)
4958 .collect();
4959 fb_inputs.sort();
4960 let fb_carry = fb_offsets[&fb_inputs[0]];
4961 let fb_xs: Vec<usize> = (1..fb_inputs.len())
4962 .map(|i| fb_offsets[&fb_inputs[i]])
4963 .collect();
4964 let fb_out = fb_offsets[&fb.outputs[0]];
4965 let mut fb_arena = crate::arena::Arena::from_plan(fb_plan);
4966 for n in fb.nodes() {
4967 if let Op::Constant { data } = &n.op
4968 && fb_arena.has_buffer(n.id)
4969 && !data.is_empty()
4970 {
4971 let off = fb_arena.byte_offset(n.id);
4978 let buf = fb_arena.raw_buf_mut();
4979 let nb = (buf.len() - off).min(data.len());
4980 buf[off..off + nb].copy_from_slice(&data[..nb]);
4981 }
4982 }
4983 let fb_init_bytes = fb_arena.raw_buf().to_vec();
4984 let fb_sched = compile_thunks(fb, &fb_arena);
4985 (
4986 Some(Arc::new(fb_sched)),
4987 Some(Arc::new(fb_init_bytes)),
4988 fb_carry,
4989 fb_out,
4990 fb_xs,
4991 )
4992 } else {
4993 (None, None, 0, 0, Vec::new())
4994 };
4995
4996 Thunk::ScanBackward {
4997 body_vjp: Arc::new(body_schedule),
4998 body_init: Arc::new(body_init),
4999 body_carry_in_off,
5000 body_x_offs: Arc::new(body_x_offs),
5001 body_d_output_off,
5002 body_dcarry_out_off,
5003 outer_init_off: node_offset(arena, node.inputs[0]),
5004 outer_traj_off: node_offset(arena, node.inputs[1]),
5005 outer_upstream_off: node_offset(arena, node.inputs[2]),
5006 outer_xs_offs: Arc::new(outer_xs_offs),
5007 outer_dinit_off: node_offset(arena, node.id),
5008 length: *length,
5009 carry_bytes: carry_bytes as u32,
5010 carry_elem_size,
5011 save_trajectory: *save_trajectory,
5012 num_checkpoints: *num_checkpoints,
5013 forward_body: fb_schedule,
5014 forward_body_init: fb_init,
5015 forward_body_carry_in_off: fb_carry_in_off,
5016 forward_body_output_off: fb_output_off,
5017 forward_body_x_offs: Arc::new(fb_x_offs),
5018 }
5019 }
5020
5021 Op::ScanBackwardXs {
5022 body_vjp,
5023 length,
5024 save_trajectory,
5025 num_xs,
5026 xs_idx,
5027 num_checkpoints,
5028 forward_body,
5029 } => {
5030 assert!(
5031 *num_checkpoints == 0 || *num_checkpoints <= *length,
5032 "Op::ScanBackwardXs: num_checkpoints={} must be 0 or ≤ length={}",
5033 *num_checkpoints,
5034 *length
5035 );
5036 let is_recursive = *num_checkpoints != 0 && *num_checkpoints != *length;
5037 if is_recursive {
5038 assert!(
5039 forward_body.is_some(),
5040 "Op::ScanBackwardXs with num_checkpoints<length \
5041 requires forward_body"
5042 );
5043 }
5044 let body_plan = rlx_opt::memory::plan_memory(body_vjp);
5052 let body_offsets: HashMap<NodeId, usize> = body_plan
5053 .assignments
5054 .iter()
5055 .map(|(id, slot)| (*id, slot.offset))
5056 .collect();
5057 let mut body_d_output_off: Option<usize> = None;
5058 let mut body_other_inputs: Vec<(NodeId, usize)> = Vec::new();
5059 for n in body_vjp.nodes() {
5060 if let Op::Input { name } = &n.op {
5061 let off = body_offsets[&n.id];
5062 if name == "d_output" {
5063 body_d_output_off = Some(off);
5064 } else {
5065 body_other_inputs.push((n.id, off));
5066 }
5067 }
5068 }
5069 body_other_inputs.sort_by_key(|(id, _)| *id);
5070 let body_d_output_off =
5071 body_d_output_off.expect("ScanBackwardXs body_vjp missing 'd_output' Input");
5072 let expected_others = 1 + *num_xs as usize;
5073 assert_eq!(
5074 body_other_inputs.len(),
5075 expected_others,
5076 "ScanBackwardXs body_vjp has {} non-d_output Inputs; expected {}",
5077 body_other_inputs.len(),
5078 expected_others
5079 );
5080 let body_carry_in_off = body_other_inputs[0].1;
5081 let body_x_offs: Vec<usize> = body_other_inputs
5082 .iter()
5083 .skip(1)
5084 .map(|(_, off)| *off)
5085 .collect();
5086 let body_dcarry_out_off = body_offsets[&body_vjp.outputs[0]];
5087 let dxs_out_node = body_vjp.outputs[1 + *xs_idx as usize];
5088 let body_dxs_out_off = body_offsets[&dxs_out_node];
5089
5090 let mut body_arena = crate::arena::Arena::from_plan(body_plan);
5091 for n in body_vjp.nodes() {
5092 if let Op::Constant { data } = &n.op
5093 && body_arena.has_buffer(n.id)
5094 && !data.is_empty()
5095 {
5096 match n.shape.dtype() {
5097 rlx_ir::DType::F64 => {
5098 let off = body_arena.byte_offset(n.id);
5099 let buf = body_arena.raw_buf_mut();
5100 let nb = (buf.len() - off).min(data.len());
5101 buf[off..off + nb].copy_from_slice(&data[..nb]);
5102 }
5103 _ => {
5104 let buf = body_arena.slice_mut(n.id);
5105 let nf = data.len() / 4;
5106 let nl = buf.len().min(nf);
5107 for i in 0..nl {
5108 let bytes = [
5109 data[i * 4],
5110 data[i * 4 + 1],
5111 data[i * 4 + 2],
5112 data[i * 4 + 3],
5113 ];
5114 buf[i] = f32::from_le_bytes(bytes);
5115 }
5116 }
5117 }
5118 }
5119 }
5120 let body_init = body_arena.raw_buf().to_vec();
5121 let body_schedule = compile_thunks(body_vjp, &body_arena);
5122
5123 let carry_bytes = body_vjp
5124 .node(body_vjp.outputs[0])
5125 .shape
5126 .size_bytes()
5127 .expect("ScanBackwardXs dcarry must be statically shaped");
5128 let carry_elem_size = body_vjp
5129 .node(body_vjp.outputs[0])
5130 .shape
5131 .dtype()
5132 .size_bytes() as u32;
5133 let per_step_bytes = body_vjp
5134 .node(dxs_out_node)
5135 .shape
5136 .size_bytes()
5137 .expect("ScanBackwardXs dxs body output must be statically shaped");
5138
5139 let mut outer_xs_offs: Vec<(usize, u32)> = Vec::with_capacity(*num_xs as usize);
5140 for i in 0..*num_xs as usize {
5141 let outer_xs_id = node.inputs[3 + i];
5142 let outer_xs_off = node_offset(arena, outer_xs_id);
5143 let outer_xs_shape = &graph.node(outer_xs_id).shape;
5144 let total = outer_xs_shape
5145 .size_bytes()
5146 .expect("ScanBackwardXs xs must have static shape");
5147 let per_step = total / *length as usize;
5148 outer_xs_offs.push((outer_xs_off, per_step as u32));
5149 }
5150
5151 let (fb_schedule, fb_init, fb_carry_in_off, fb_output_off, fb_x_offs) =
5154 if is_recursive {
5155 let fb = forward_body.as_ref().unwrap();
5156 let fb_plan = rlx_opt::memory::plan_memory(fb);
5157 let fb_offsets: HashMap<NodeId, usize> = fb_plan
5158 .assignments
5159 .iter()
5160 .map(|(id, slot)| (*id, slot.offset))
5161 .collect();
5162 let mut fb_inputs: Vec<NodeId> = fb
5163 .nodes()
5164 .iter()
5165 .filter(|n| matches!(n.op, Op::Input { .. }))
5166 .map(|n| n.id)
5167 .collect();
5168 fb_inputs.sort();
5169 let fb_carry = fb_offsets[&fb_inputs[0]];
5170 let fb_xs: Vec<usize> = (1..fb_inputs.len())
5171 .map(|i| fb_offsets[&fb_inputs[i]])
5172 .collect();
5173 let fb_out = fb_offsets[&fb.outputs[0]];
5174 let mut fb_arena = crate::arena::Arena::from_plan(fb_plan);
5175 for n in fb.nodes() {
5176 if let Op::Constant { data } = &n.op
5177 && fb_arena.has_buffer(n.id)
5178 && !data.is_empty()
5179 {
5180 let off = fb_arena.byte_offset(n.id);
5187 let buf = fb_arena.raw_buf_mut();
5188 let nb = (buf.len() - off).min(data.len());
5189 buf[off..off + nb].copy_from_slice(&data[..nb]);
5190 }
5191 }
5192 let fb_init_bytes = fb_arena.raw_buf().to_vec();
5193 let fb_sched = compile_thunks(fb, &fb_arena);
5194 (
5195 Some(Arc::new(fb_sched)),
5196 Some(Arc::new(fb_init_bytes)),
5197 fb_carry,
5198 fb_out,
5199 fb_xs,
5200 )
5201 } else {
5202 (None, None, 0, 0, Vec::new())
5203 };
5204
5205 Thunk::ScanBackwardXs {
5206 body_vjp: Arc::new(body_schedule),
5207 body_init: Arc::new(body_init),
5208 body_carry_in_off,
5209 body_x_offs: Arc::new(body_x_offs),
5210 body_d_output_off,
5211 body_dcarry_out_off,
5212 body_dxs_out_off,
5213 outer_init_off: node_offset(arena, node.inputs[0]),
5214 outer_traj_off: node_offset(arena, node.inputs[1]),
5215 outer_upstream_off: node_offset(arena, node.inputs[2]),
5216 outer_xs_offs: Arc::new(outer_xs_offs),
5217 outer_dxs_off: node_offset(arena, node.id),
5218 length: *length,
5219 carry_bytes: carry_bytes as u32,
5220 carry_elem_size,
5221 per_step_bytes: per_step_bytes as u32,
5222 save_trajectory: *save_trajectory,
5223 num_checkpoints: *num_checkpoints,
5224 forward_body: fb_schedule,
5225 forward_body_init: fb_init,
5226 forward_body_carry_in_off: fb_carry_in_off,
5227 forward_body_output_off: fb_output_off,
5228 forward_body_x_offs: Arc::new(fb_x_offs),
5229 }
5230 }
5231
5232 Op::Concat { axis } => {
5233 let out_shape = &node.shape;
5237 let rank = out_shape.rank();
5238 let outer: usize = (0..*axis)
5239 .map(|i| out_shape.dim(i).unwrap_static())
5240 .product::<usize>()
5241 .max(1);
5242 let inner: usize = (*axis + 1..rank)
5243 .map(|i| out_shape.dim(i).unwrap_static())
5244 .product::<usize>()
5245 .max(1);
5246 let total_axis = out_shape.dim(*axis).unwrap_static();
5247 let inputs: Vec<(usize, u32)> = node
5248 .inputs
5249 .iter()
5250 .map(|&in_id| {
5251 let in_shape = &graph.node(in_id).shape;
5252 let in_axis = in_shape.dim(*axis).unwrap_static();
5253 (node_offset(arena, in_id), in_axis as u32)
5254 })
5255 .collect();
5256 let dst = node_offset(arena, node.id);
5257 match out_shape.dtype() {
5258 rlx_ir::DType::F64 => Thunk::ConcatF64 {
5259 dst,
5260 outer: outer as u32,
5261 inner: inner as u32,
5262 total_axis: total_axis as u32,
5263 inputs,
5264 },
5265 _ => Thunk::Concat {
5266 dst,
5267 outer: outer as u32,
5268 inner: inner as u32,
5269 total_axis: total_axis as u32,
5270 inputs,
5271 },
5272 }
5273 }
5274
5275 Op::GaussianSplatRender {
5276 width,
5277 height,
5278 tile_size,
5279 radius_scale,
5280 alpha_cutoff,
5281 max_splat_steps,
5282 transmittance_threshold,
5283 max_list_entries,
5284 } => {
5285 let elem_len =
5286 |id: NodeId| -> usize { graph.node(id).shape.num_elements().unwrap_or(0) };
5287 Thunk::GaussianSplatRender {
5288 positions_off: node_offset(arena, node.inputs[0]),
5289 positions_len: elem_len(node.inputs[0]),
5290 scales_off: node_offset(arena, node.inputs[1]),
5291 scales_len: elem_len(node.inputs[1]),
5292 rotations_off: node_offset(arena, node.inputs[2]),
5293 rotations_len: elem_len(node.inputs[2]),
5294 opacities_off: node_offset(arena, node.inputs[3]),
5295 opacities_len: elem_len(node.inputs[3]),
5296 colors_off: node_offset(arena, node.inputs[4]),
5297 colors_len: elem_len(node.inputs[4]),
5298 sh_coeffs_off: node_offset(arena, node.inputs[5]),
5299 sh_coeffs_len: elem_len(node.inputs[5]),
5300 meta_off: node_offset(arena, node.inputs[6]),
5301 dst_off: node_offset(arena, node.id),
5302 dst_len: node.shape.num_elements().unwrap_or(0),
5303 width: *width,
5304 height: *height,
5305 tile_size: *tile_size,
5306 radius_scale: *radius_scale,
5307 alpha_cutoff: *alpha_cutoff,
5308 max_splat_steps: *max_splat_steps,
5309 transmittance_threshold: *transmittance_threshold,
5310 max_list_entries: *max_list_entries,
5311 }
5312 }
5313
5314 Op::GaussianSplatRenderBackward {
5315 width,
5316 height,
5317 tile_size,
5318 radius_scale,
5319 alpha_cutoff,
5320 max_splat_steps,
5321 transmittance_threshold,
5322 max_list_entries,
5323 loss_grad_clip,
5324 sh_band,
5325 max_anisotropy,
5326 } => {
5327 let elem_len =
5328 |id: NodeId| -> usize { graph.node(id).shape.num_elements().unwrap_or(0) };
5329 Thunk::GaussianSplatRenderBackward {
5330 positions_off: node_offset(arena, node.inputs[0]),
5331 positions_len: elem_len(node.inputs[0]),
5332 scales_off: node_offset(arena, node.inputs[1]),
5333 scales_len: elem_len(node.inputs[1]),
5334 rotations_off: node_offset(arena, node.inputs[2]),
5335 rotations_len: elem_len(node.inputs[2]),
5336 opacities_off: node_offset(arena, node.inputs[3]),
5337 opacities_len: elem_len(node.inputs[3]),
5338 colors_off: node_offset(arena, node.inputs[4]),
5339 colors_len: elem_len(node.inputs[4]),
5340 sh_coeffs_off: node_offset(arena, node.inputs[5]),
5341 sh_coeffs_len: elem_len(node.inputs[5]),
5342 meta_off: node_offset(arena, node.inputs[6]),
5343 d_loss_off: node_offset(arena, node.inputs[7]),
5344 d_loss_len: elem_len(node.inputs[7]),
5345 packed_off: node_offset(arena, node.id),
5346 packed_len: node.shape.num_elements().unwrap_or(0),
5347 width: *width,
5348 height: *height,
5349 tile_size: *tile_size,
5350 radius_scale: *radius_scale,
5351 alpha_cutoff: *alpha_cutoff,
5352 max_splat_steps: *max_splat_steps,
5353 transmittance_threshold: *transmittance_threshold,
5354 max_list_entries: *max_list_entries,
5355 loss_grad_clip: *loss_grad_clip,
5356 sh_band: *sh_band,
5357 max_anisotropy: *max_anisotropy,
5358 }
5359 }
5360
5361 Op::GaussianSplatPrepare {
5362 width,
5363 height,
5364 tile_size,
5365 radius_scale,
5366 alpha_cutoff,
5367 max_splat_steps,
5368 transmittance_threshold,
5369 max_list_entries,
5370 } => {
5371 let elem_len =
5372 |id: NodeId| -> usize { graph.node(id).shape.num_elements().unwrap_or(0) };
5373 Thunk::GaussianSplatPrepare {
5374 positions_off: node_offset(arena, node.inputs[0]),
5375 positions_len: elem_len(node.inputs[0]),
5376 scales_off: node_offset(arena, node.inputs[1]),
5377 scales_len: elem_len(node.inputs[1]),
5378 rotations_off: node_offset(arena, node.inputs[2]),
5379 rotations_len: elem_len(node.inputs[2]),
5380 opacities_off: node_offset(arena, node.inputs[3]),
5381 opacities_len: elem_len(node.inputs[3]),
5382 colors_off: node_offset(arena, node.inputs[4]),
5383 colors_len: elem_len(node.inputs[4]),
5384 sh_coeffs_off: node_offset(arena, node.inputs[5]),
5385 sh_coeffs_len: elem_len(node.inputs[5]),
5386 meta_off: node_offset(arena, node.inputs[6]),
5387 meta_len: elem_len(node.inputs[6]),
5388 prep_off: node_offset(arena, node.id),
5389 prep_len: node.shape.num_elements().unwrap_or(0),
5390 width: *width,
5391 height: *height,
5392 tile_size: *tile_size,
5393 radius_scale: *radius_scale,
5394 alpha_cutoff: *alpha_cutoff,
5395 max_splat_steps: *max_splat_steps,
5396 transmittance_threshold: *transmittance_threshold,
5397 max_list_entries: *max_list_entries,
5398 }
5399 }
5400
5401 Op::GaussianSplatRasterize {
5402 width,
5403 height,
5404 tile_size,
5405 alpha_cutoff,
5406 max_splat_steps,
5407 transmittance_threshold,
5408 max_list_entries,
5409 } => {
5410 let elem_len =
5411 |id: NodeId| -> usize { graph.node(id).shape.num_elements().unwrap_or(0) };
5412 let prep_id = node.inputs[0];
5413 let count = match &graph.node(prep_id).op {
5414 rlx_ir::Op::GaussianSplatPrepare { .. } => {
5415 elem_len(graph.node(prep_id).inputs[0]) / 3
5416 }
5417 _ => 1,
5418 };
5419 Thunk::GaussianSplatRasterize {
5420 prep_off: node_offset(arena, prep_id),
5421 prep_len: elem_len(prep_id),
5422 meta_off: node_offset(arena, node.inputs[1]),
5423 meta_len: elem_len(node.inputs[1]),
5424 dst_off: node_offset(arena, node.id),
5425 dst_len: node.shape.num_elements().unwrap_or(0),
5426 count,
5427 width: *width,
5428 height: *height,
5429 tile_size: *tile_size,
5430 alpha_cutoff: *alpha_cutoff,
5431 max_splat_steps: *max_splat_steps,
5432 transmittance_threshold: *transmittance_threshold,
5433 max_list_entries: *max_list_entries,
5434 }
5435 }
5436
5437 Op::Custom { name, attrs, .. } => {
5438 let kernel = crate::op_registry::lookup_cpu_kernel(name).unwrap_or_else(|| {
5439 panic!(
5440 "compile_thunks: no CPU kernel registered for \
5441 Op::Custom('{name}'). Register one via \
5442 rlx_cpu::op_registry::register_cpu_kernel \
5443 before compiling on the CPU backend."
5444 )
5445 });
5446 let inputs_v: Vec<(usize, u32, Shape)> = node
5447 .inputs
5448 .iter()
5449 .map(|&in_id| {
5450 let s = graph.node(in_id).shape.clone();
5451 let len = s.num_elements().unwrap_or(0) as u32;
5452 (node_offset(arena, in_id), len, s)
5453 })
5454 .collect();
5455 let out_len = node.shape.num_elements().unwrap_or(0) as u32;
5456 Thunk::CustomOp {
5457 kernel,
5458 inputs: inputs_v,
5459 output: (node_offset(arena, node.id), out_len, node.shape.clone()),
5460 attrs: attrs.clone(),
5461 }
5462 }
5463
5464 Op::Fft { inverse, norm } => {
5465 let shape = &node.shape;
5466 let meta = rlx_ir::fft::fft_meta(shape);
5467 let dtype = shape.dtype();
5468 assert!(
5469 matches!(
5470 dtype,
5471 rlx_ir::DType::F32 | rlx_ir::DType::F64 | rlx_ir::DType::C64
5472 ),
5473 "Op::Fft on CPU requires F32, F64, or C64, got {dtype:?}"
5474 );
5475 Thunk::Fft1d {
5476 src: node_offset(arena, node.inputs[0]),
5477 dst: node_offset(arena, node.id),
5478 outer: meta.outer as u32,
5479 n_complex: meta.n_complex as u32,
5480 inverse: *inverse,
5481 norm_tag: norm.tag(),
5482 dtype,
5483 }
5484 }
5485
5486 Op::FftButterflyStage { stage, n_fft } => {
5487 let state_shape = graph.node(node.inputs[0]).shape.clone();
5488 assert_eq!(
5489 state_shape.dtype(),
5490 rlx_ir::DType::F32,
5491 "Op::FftButterflyStage requires F32 state"
5492 );
5493 let batch = state_shape.dim(0).unwrap_static() as u32;
5494 Thunk::FftButterflyStage {
5495 state_src: node_offset(arena, node.inputs[0]),
5496 state_dst: node_offset(arena, node.id),
5497 gate_src: node_offset(arena, node.inputs[1]),
5498 rev_src: node_offset(arena, node.inputs[2]),
5499 tw_re_src: node_offset(arena, node.inputs[3]),
5500 tw_im_src: node_offset(arena, node.inputs[4]),
5501 batch,
5502 n_fft: *n_fft,
5503 stage: *stage,
5504 }
5505 }
5506
5507 Op::LogMel => {
5508 let spec_shape = graph.node(node.inputs[0]).shape.clone();
5509 let filt_shape = graph.node(node.inputs[1]).shape.clone();
5510 let meta = rlx_ir::audio::log_mel_meta(&spec_shape, &filt_shape)
5511 .unwrap_or_else(|e| panic!("Op::LogMel: {e}"));
5512 Thunk::LogMel {
5513 spec: node_offset(arena, node.inputs[0]),
5514 filters: node_offset(arena, node.inputs[1]),
5515 dst: node_offset(arena, node.id),
5516 outer: meta.outer as u32,
5517 n_fft: meta.n_fft as u32,
5518 n_bins: meta.n_bins as u32,
5519 n_mels: meta.n_mels as u32,
5520 }
5521 }
5522
5523 Op::LogMelBackward => {
5524 let spec_shape = graph.node(node.inputs[0]).shape.clone();
5525 let filt_shape = graph.node(node.inputs[1]).shape.clone();
5526 let meta = rlx_ir::audio::log_mel_meta(&spec_shape, &filt_shape)
5527 .unwrap_or_else(|e| panic!("Op::LogMelBackward: {e}"));
5528 Thunk::LogMelBackward {
5529 spec: node_offset(arena, node.inputs[0]),
5530 filters: node_offset(arena, node.inputs[1]),
5531 dy: node_offset(arena, node.inputs[2]),
5532 dst: node_offset(arena, node.id),
5533 outer: meta.outer as u32,
5534 n_fft: meta.n_fft as u32,
5535 n_bins: meta.n_bins as u32,
5536 n_mels: meta.n_mels as u32,
5537 }
5538 }
5539
5540 Op::WelchPeaks { k, n_segments } => {
5541 let spec_shape = graph.node(node.inputs[0]).shape.clone();
5542 let meta = rlx_ir::audio::welch_peaks_meta(&spec_shape, *k, *n_segments)
5543 .unwrap_or_else(|e| panic!("Op::WelchPeaks: {e}"));
5544 Thunk::WelchPeaks {
5545 spec: node_offset(arena, node.inputs[0]),
5546 dst: node_offset(arena, node.id),
5547 welch_batch: meta.welch_batch as u32,
5548 n_fft: meta.n_fft as u32,
5549 n_segments: meta.n_segments as u32,
5550 k: meta.k as u32,
5551 }
5552 }
5553
// Lowers `Op::CustomFn`: the forward body graph gets its own memory plan,
// its own arena (pre-filled with the body's constants) and its own thunk
// schedule. The resulting `Thunk::CustomFn` records where each outer input
// must be copied into the body arena and where the body output lives.
Op::CustomFn {
    fwd_body,
    num_inputs,
    ..
} => {
    // Plan the body independently of the outer arena.
    let body_plan = rlx_opt::memory::plan_memory(fwd_body);
    // Snapshot node -> slot offset now; the plan itself is moved into
    // `Arena::from_plan` further down.
    let body_offsets: HashMap<NodeId, usize> = body_plan
        .assignments
        .iter()
        .map(|(id, slot)| (*id, slot.offset))
        .collect();

    // Body `Op::Input` nodes, sorted by id; they are paired positionally
    // with the outer node's inputs below.
    let mut body_input_ids: Vec<NodeId> = fwd_body
        .nodes()
        .iter()
        .filter(|n| matches!(n.op, Op::Input { .. }))
        .map(|n| n.id)
        .collect();
    body_input_ids.sort();
    assert_eq!(
        body_input_ids.len(),
        *num_inputs as usize,
        "Op::CustomFn fwd_body has {} Op::Input(s); declared num_inputs={}",
        body_input_ids.len(),
        *num_inputs,
    );

    // Materialize the body's non-empty constants into its private arena.
    let mut body_arena = crate::arena::Arena::from_plan(body_plan);
    for n in fwd_body.nodes() {
        if let Op::Constant { data } = &n.op
            && body_arena.has_buffer(n.id)
            && !data.is_empty()
        {
            match n.shape.dtype() {
                rlx_ir::DType::F64 => {
                    // Raw byte copy at the node's byte offset, clamped to
                    // whatever remains of the arena buffer.
                    let off = body_arena.byte_offset(n.id);
                    let buf = body_arena.raw_buf_mut();
                    let nb = (buf.len() - off).min(data.len());
                    buf[off..off + nb].copy_from_slice(&data[..nb]);
                }
                _ => {
                    // Decode the payload as little-endian f32 values.
                    // NOTE(review): every non-F64 dtype takes this path —
                    // confirm no other element width can reach here.
                    let buf = body_arena.slice_mut(n.id);
                    let nf = data.len() / 4;
                    let nl = buf.len().min(nf);
                    for i in 0..nl {
                        let bytes = [
                            data[i * 4],
                            data[i * 4 + 1],
                            data[i * 4 + 2],
                            data[i * 4 + 3],
                        ];
                        buf[i] = f32::from_le_bytes(bytes);
                    }
                }
            }
        }
    }
    // Snapshot of the initialized body arena, stored in the thunk.
    let body_init = body_arena.raw_buf().to_vec();
    // Recursively compile the body against its own arena.
    let body_schedule = compile_thunks(fwd_body, &body_arena);

    // One (body offset, outer offset, byte length) triple per input.
    let inputs_v: Vec<(usize, usize, u32)> = (0..*num_inputs as usize)
        .map(|i| {
            let body_in = body_input_ids[i];
            let body_off = body_offsets[&body_in];
            let outer_in = node.inputs[i];
            let outer_off = node_offset(arena, outer_in);
            let bytes = graph
                .node(outer_in)
                .shape
                .size_bytes()
                .expect("Op::CustomFn primal input must have static shape");
            (body_off, outer_off, bytes as u32)
        })
        .collect();

    // NOTE(review): only the first body output is used and only
    // non-emptiness is enforced, although the message says "exactly one".
    let body_output_id = fwd_body
        .outputs
        .first()
        .copied()
        .expect("Op::CustomFn fwd_body must declare exactly one output");
    let body_output_off = body_offsets[&body_output_id];
    let out_bytes = node
        .shape
        .size_bytes()
        .expect("Op::CustomFn output must have static shape");

    Thunk::CustomFn {
        body: Arc::new(body_schedule),
        body_init: Arc::new(body_init),
        inputs: Arc::new(inputs_v),
        body_output_off,
        outer_output_off: node_offset(arena, node.id),
        out_bytes: out_bytes as u32,
    }
}
5655
5656 _ => Thunk::Nop,
5657 };
5658 thunks.push(t);
5659 }
5660
5661 let cfg = crate::config::RuntimeConfig::global();
5662 let mask_thr = cfg.mask_binary_threshold;
5663 let mask_neg = cfg.attn_mask_neg_inf;
5664 let score_skip = cfg.score_skip_threshold;
5665
5666 let compiled_fns: Vec<Arc<dyn Fn(*mut u8) + Send + Sync>> = thunks
5668 .iter()
5669 .filter(|t| !matches!(t, Thunk::Nop))
5670 .map(|thunk| {
5671 match thunk.clone() {
// `Nop` thunks are removed by the `.filter(...)` just before this `.map`,
// so this arm is not expected to run. It is also the only arm that casts
// explicitly to the trait-object type.
Thunk::Nop => Arc::new(|_: *mut u8| {}) as Arc<dyn Fn(*mut u8) + Send + Sync>,
5673
5674 Thunk::Sgemm { a, b, c, m, k, n } => {
5675 let (m, k, n) = (m as usize, k as usize, n as usize);
5676 Arc::new(move |base: *mut u8| unsafe {
5677 crate::blas::sgemm(
5678 sl(a, base, m * k),
5679 sl(b, base, k * n),
5680 sl_mut(c, base, m * n),
5681 m,
5682 k,
5683 n,
5684 );
5685 })
5686 }
5687
5688 Thunk::DenseSolveF64 { a, b, x, n, nrhs } => {
5689 let (n_, nrhs_) = (n as usize, nrhs as usize);
5690 Arc::new(move |base: *mut u8| unsafe {
5691 let a_src = sl_f64(a, base, n_ * n_);
5692 let b_src = sl_f64(b, base, n_ * nrhs_);
5693 let mut a_scratch: Vec<f64> = a_src.to_vec();
5694 let mut x_buf: Vec<f64> = b_src.to_vec();
5695 let info = crate::blas::dgesv(&mut a_scratch, &mut x_buf, n_, nrhs_);
5696 if info != 0 {
5697 panic!("DenseSolveF64: singular (info={info})");
5698 }
5699 sl_mut_f64(x, base, n_ * nrhs_).copy_from_slice(&x_buf);
5700 })
5701 }
5702
5703 Thunk::DenseSolveF32 { a, b, x, n, nrhs } => {
5704 let (n_, nrhs_) = (n as usize, nrhs as usize);
5705 Arc::new(move |base: *mut u8| unsafe {
5706 let a_src = sl(a, base, n_ * n_);
5707 let b_src = sl(b, base, n_ * nrhs_);
5708 let mut a_scratch: Vec<f32> = a_src.to_vec();
5709 let mut x_buf: Vec<f32> = b_src.to_vec();
5710 let info = crate::blas::sgesv(&mut a_scratch, &mut x_buf, n_, nrhs_);
5711 if info != 0 {
5712 panic!("DenseSolveF32: singular (info={info})");
5713 }
5714 sl_mut(x, base, n_ * nrhs_).copy_from_slice(&x_buf);
5715 })
5716 }
5717
5718 Thunk::FusedMmBiasAct {
5719 a,
5720 w,
5721 bias,
5722 c,
5723 m,
5724 k,
5725 n,
5726 act,
5727 } => {
5728 let (m, k, n) = (m as usize, k as usize, n as usize);
5729 Arc::new(move |base: *mut u8| unsafe {
5730 let out = sl_mut(c, base, m * n);
5731 crate::blas::sgemm(sl(a, base, m * k), sl(w, base, k * n), out, m, k, n);
5732 match act {
5740 Some(Activation::Gelu) => {
5741 crate::kernels::par_bias_gelu(out, sl(bias, base, n), m, n)
5742 }
5743 Some(other) => {
5744 crate::blas::bias_add(out, sl(bias, base, n), m, n);
5745 apply_activation_inplace(out, other);
5746 }
5747 None => crate::blas::bias_add(out, sl(bias, base, n), m, n),
5748 }
5749 })
5750 }
5751
5752 Thunk::FusedResidualLN {
5753 x,
5754 res,
5755 bias,
5756 g,
5757 b,
5758 out,
5759 rows,
5760 h,
5761 eps,
5762 has_bias,
5763 } => {
5764 let (rows, h) = (rows as usize, h as usize);
5765 Arc::new(move |base: *mut u8| unsafe {
5766 let zero = vec![0f32; h]; let bi = if has_bias { sl(bias, base, h) } else { &zero };
5768 let xp = sl(x, base, rows * h).as_ptr() as usize;
5769 let rp = sl(res, base, rows * h).as_ptr() as usize;
5770 let op = sl_mut(out, base, rows * h).as_mut_ptr() as usize;
5771 let bp = bi.as_ptr() as usize;
5772 let gp = sl(g, base, h).as_ptr() as usize;
5773 let bbp = sl(b, base, h).as_ptr() as usize;
5774 crate::pool::par_for(rows, 4, &|off, cnt| {
5775 let xs = std::slice::from_raw_parts(
5776 (xp as *const f32).add(off * h),
5777 cnt * h,
5778 );
5779 let rs = std::slice::from_raw_parts(
5780 (rp as *const f32).add(off * h),
5781 cnt * h,
5782 );
5783 let os = std::slice::from_raw_parts_mut(
5784 (op as *mut f32).add(off * h),
5785 cnt * h,
5786 );
5787 let bi = std::slice::from_raw_parts(bp as *const f32, h);
5788 let g = std::slice::from_raw_parts(gp as *const f32, h);
5789 let b = std::slice::from_raw_parts(bbp as *const f32, h);
5790 crate::kernels::residual_bias_layer_norm(
5791 xs, rs, bi, g, b, os, cnt, h, eps,
5792 );
5793 });
5794 })
5795 }
5796
5797 Thunk::BiasAdd {
5798 src,
5799 bias,
5800 dst,
5801 m,
5802 n,
5803 } => {
5804 let (m, n) = (m as usize, n as usize);
5805 let len = m * n;
5806 Arc::new(move |base: *mut u8| unsafe {
5807 let out = sl_mut(dst, base, len);
5808 if src != dst {
5809 let src_ptr = base.add(src) as *const f32;
5810 let dst_ptr = base.add(dst) as *mut f32;
5811 if src_ptr != dst_ptr {
5812 std::ptr::copy_nonoverlapping(src_ptr, dst_ptr, len);
5813 }
5814 }
5815 crate::blas::bias_add(out, sl(bias, base, n), m, n);
5816 })
5817 }
5818
5819 Thunk::Gather {
5820 table,
5821 table_len,
5822 idx,
5823 dst,
5824 num_idx,
5825 trailing,
5826 idx_i64,
5827 table_bytes,
5828 } => {
5829 let (ni, tr, tl) = (num_idx as usize, trailing as usize, table_len as usize);
5830 let rows = tl / tr.max(1);
5831 let (idx_i64, table_bytes) = (idx_i64, table_bytes);
5832 Arc::new(move |base: *mut u8| unsafe {
5833 if table_bytes == 8 {
5834 let tab = sl_i64(table, base, tl);
5835 let out = sl_mut_i64(dst, base, ni * tr);
5836 if idx_i64 != 0 {
5837 let ids = sl_i64(idx, base, ni);
5838 for i in 0..ni {
5839 let row = ids[i].max(0) as usize;
5840 if row < rows {
5841 out[i * tr..(i + 1) * tr]
5842 .copy_from_slice(&tab[row * tr..(row + 1) * tr]);
5843 }
5844 }
5845 } else {
5846 let ids = sl(idx, base, ni);
5847 for i in 0..ni {
5848 let row = ids[i] as usize;
5849 if row < rows {
5850 out[i * tr..(i + 1) * tr]
5851 .copy_from_slice(&tab[row * tr..(row + 1) * tr]);
5852 }
5853 }
5854 }
5855 } else {
5856 let tab = sl(table, base, tl);
5857 let out = sl_mut(dst, base, ni * tr);
5858 if idx_i64 != 0 {
5859 let ids = sl_i64(idx, base, ni);
5860 for i in 0..ni {
5861 let row = ids[i].max(0) as usize;
5862 if row < rows {
5863 out[i * tr..(i + 1) * tr]
5864 .copy_from_slice(&tab[row * tr..(row + 1) * tr]);
5865 }
5866 }
5867 } else {
5868 let ids = sl(idx, base, ni);
5869 for i in 0..ni {
5870 let row = ids[i] as usize;
5871 if row < rows {
5872 out[i * tr..(i + 1) * tr]
5873 .copy_from_slice(&tab[row * tr..(row + 1) * tr]);
5874 }
5875 }
5876 }
5877 }
5878 })
5879 }
5880
// Narrow: closure construction is delegated entirely to
// `narrow_thunk_closure`, which receives every field unchanged.
Thunk::Narrow {
    src,
    dst,
    outer,
    src_stride,
    dst_stride,
    inner,
    elem_bytes,
} => {
    narrow_thunk_closure(src, dst, outer, src_stride, dst_stride, inner, elem_bytes)
}
5892
5893 Thunk::Copy { src, dst, len } => {
5894 let len = len as usize;
5895 Arc::new(move |base: *mut u8| unsafe {
5896 if src == dst || len == 0 {
5897 return;
5898 }
5899 let src_ptr = base.add(src) as *const f32;
5900 let dst_ptr = base.add(dst) as *mut f32;
5901 if src_ptr == dst_ptr {
5902 return;
5903 }
5904 std::ptr::copy_nonoverlapping(src_ptr, dst_ptr, len);
5905 })
5906 }
5907
// In-place softmax: `data` is viewed as rows*cols f32 values and handed,
// mutably, to `naive::softmax` together with the row/column counts.
Thunk::Softmax { data, rows, cols } => {
    let (rows, cols) = (rows as usize, cols as usize);
    Arc::new(move |base: *mut u8| unsafe {
        crate::naive::softmax(sl_mut(data, base, rows * cols), rows, cols);
    })
}
5914
5915 Thunk::Cumsum {
5916 src,
5917 dst,
5918 rows,
5919 cols,
5920 exclusive,
5921 } => {
5922 let (rows, cols) = (rows as usize, cols as usize);
5923 Arc::new(move |base: *mut u8| unsafe {
5924 let s = sl(src, base, rows * cols);
5925 let d = sl_mut(dst, base, rows * cols);
5926 if exclusive {
5927 for r in 0..rows {
5928 let mut acc = 0.0f32;
5929 for c in 0..cols {
5930 d[r * cols + c] = acc;
5931 acc += s[r * cols + c];
5932 }
5933 }
5934 } else {
5935 for r in 0..rows {
5936 let mut acc = 0.0f32;
5937 for c in 0..cols {
5938 acc += s[r * cols + c];
5939 d[r * cols + c] = acc;
5940 }
5941 }
5942 }
5943 })
5944 }
5945
// Token sampling: one draw per batch row from a `vocab`-wide logits row.
// The chosen index is stored in `dst` as an f32.
Thunk::Sample {
    logits,
    dst,
    batch,
    vocab,
    top_k,
    top_p,
    temperature,
    seed,
} => {
    let (b, v) = (batch as usize, vocab as usize);
    // top_k can never exceed the vocabulary size.
    let k = (top_k as usize).min(v);
    Arc::new(move |base: *mut u8| unsafe {
        let lg = sl(logits, base, b * v);
        let out = sl_mut(dst, base, b);
        // A seed of 0 is replaced by a fixed constant. The generator is
        // rebuilt from the same seed on every invocation.
        // NOTE(review): that presumably makes repeated runs draw the same
        // sequence — confirm this determinism is intended.
        let mut rng =
            rlx_ir::Philox4x32::new(if seed == 0 { 0xDEADBEEF } else { seed });
        for bi in 0..b {
            let row = &lg[bi * v..(bi + 1) * v];
            out[bi] = sample_row(row, k, top_p, temperature, &mut rng) as f32;
        }
    })
}
5969
// Int8 block-quantized matmul. Weights are read as k*n signed bytes;
// there is one scale (and, when asymmetric, one zero point) per
// (block, column), with ceil(k / block_size) blocks per column.
Thunk::DequantMatMul {
    x,
    w_q,
    scale,
    zp,
    dst,
    m,
    k,
    n,
    block_size,
    is_asymmetric,
} => {
    let (m, k, n, bs) = (m as usize, k as usize, n as usize, block_size as usize);
    let n_blocks_per_col = k.div_ceil(bs);
    Arc::new(move |base: *mut u8| unsafe {
        let xs = sl(x, base, m * k);
        let raw = base.add(w_q);
        // One i8 per weight element.
        let w_bytes = std::slice::from_raw_parts(raw as *const i8, k * n);
        let scales = sl(scale, base, n_blocks_per_col * n);
        // Symmetric quantization passes an empty zero-point slice.
        let zps = if is_asymmetric {
            sl(zp, base, n_blocks_per_col * n)
        } else {
            &[][..]
        };
        let out = sl_mut(dst, base, m * n);
        dequant_matmul_int8(
            xs,
            w_bytes,
            scales,
            zps,
            out,
            m,
            k,
            n,
            bs,
            is_asymmetric,
        );
    })
}
6010
// GGUF-quantized matmul. The weight byte length is derived from the
// scheme's block geometry: (k*n / elements-per-block) * bytes-per-block.
Thunk::DequantMatMulGguf {
    x,
    w_q,
    dst,
    m,
    k,
    n,
    scheme,
} => {
    let (m, k, n) = (m as usize, k as usize, n as usize);
    let block_bytes = scheme.gguf_block_bytes() as usize;
    let block_elems = scheme.gguf_block_size() as usize;
    // Integer division: any partial trailing block is not counted.
    let total_bytes = (k * n) / block_elems * block_bytes;
    Arc::new(move |base: *mut u8| unsafe {
        let xs = sl(x, base, m * k);
        let w_bytes =
            std::slice::from_raw_parts(base.add(w_q) as *const u8, total_bytes);
        let out = sl_mut(dst, base, m * n);
        crate::gguf_matmul::gguf_matmul_bt(xs, w_bytes, out, m, k, n, scheme);
    })
}
6032
// Int4 block-quantized matmul. The weight buffer is ceil(k*n / 2) bytes
// long; scales and optional zero points are laid out per (block, column)
// as in the int8 variant.
Thunk::DequantMatMulInt4 {
    x,
    w_q,
    scale,
    zp,
    dst,
    m,
    k,
    n,
    block_size,
    is_asymmetric,
} => {
    let (m, k, n, bs) = (m as usize, k as usize, n as usize, block_size as usize);
    let n_blocks = k.div_ceil(bs);
    Arc::new(move |base: *mut u8| unsafe {
        let xs = sl(x, base, m * k);
        let w_bytes = std::slice::from_raw_parts(
            base.add(w_q) as *const u8,
            (k * n).div_ceil(2),
        );
        let scales = sl(scale, base, n_blocks * n);
        // Symmetric quantization passes an empty zero-point slice.
        let zps = if is_asymmetric {
            sl(zp, base, n_blocks * n)
        } else {
            &[][..]
        };
        let out = sl_mut(dst, base, m * n);
        dequant_matmul_int4(
            xs,
            w_bytes,
            scales,
            zps,
            out,
            m,
            k,
            n,
            bs,
            is_asymmetric,
        );
    })
}
6074
// FP8 matmul: one byte per weight element (k*n bytes) and `n` f32 scales.
// `e5m2` is forwarded to the kernel unchanged.
Thunk::DequantMatMulFp8 {
    x,
    w_q,
    scale,
    dst,
    m,
    k,
    n,
    e5m2,
} => {
    let (m, k, n) = (m as usize, k as usize, n as usize);
    Arc::new(move |base: *mut u8| unsafe {
        let xs = sl(x, base, m * k);
        let w_bytes = std::slice::from_raw_parts(base.add(w_q) as *const u8, k * n);
        let scales = sl(scale, base, n);
        let out = sl_mut(dst, base, m * n);
        dequant_matmul_fp8(xs, w_bytes, scales, out, m, k, n, e5m2);
    })
}
6094
// NVFP4 matmul: the weight buffer is ceil(k*n / 2) bytes, with one scale
// byte per (group, column), where a group spans `NVFP4_GROUP_SIZE`
// elements along k, plus a single f32 global scale.
Thunk::DequantMatMulNvfp4 {
    x,
    w_q,
    scale,
    global_scale,
    dst,
    m,
    k,
    n,
} => {
    let (m, k, n) = (m as usize, k as usize, n as usize);
    let n_scale = k.div_ceil(rlx_ir::NVFP4_GROUP_SIZE) * n;
    Arc::new(move |base: *mut u8| unsafe {
        let xs = sl(x, base, m * k);
        let w_bytes = std::slice::from_raw_parts(
            base.add(w_q) as *const u8,
            (k * n).div_ceil(2),
        );
        let scale_bytes =
            std::slice::from_raw_parts(base.add(scale) as *const u8, n_scale);
        // The global scale is read as a one-element f32 slice.
        let gs = sl(global_scale, base, 1)[0];
        let out = sl_mut(dst, base, m * n);
        dequant_matmul_nvfp4(xs, w_bytes, scale_bytes, gs, out, m, k, n);
    })
}
6120
6121 Thunk::LoraMatMul {
6122 x,
6123 w,
6124 a,
6125 b,
6126 dst,
6127 m,
6128 k,
6129 n,
6130 r,
6131 scale,
6132 } => {
6133 let (m, k, n, r) = (m as usize, k as usize, n as usize, r as usize);
6134 Arc::new(move |base: *mut u8| unsafe {
6135 let xs = sl(x, base, m * k);
6136 let ws = sl(w, base, k * n);
6137 let a_s = sl(a, base, k * r);
6138 let bs = sl(b, base, r * n);
6139 let out = sl_mut(dst, base, m * n);
6140 crate::blas::sgemm(xs, ws, out, m, k, n);
6142 let mut tmp = vec![0f32; m * r];
6144 crate::blas::sgemm(xs, a_s, &mut tmp, m, k, r);
6145 if scale != 1.0 {
6149 for v in tmp.iter_mut() {
6150 *v *= scale;
6151 }
6152 }
6153 crate::blas::sgemm_accumulate(&tmp, bs, out, m, r, n);
6154 })
6155 }
6156
6157 Thunk::LayerNorm {
6158 src,
6159 g,
6160 b,
6161 dst,
6162 rows,
6163 h,
6164 eps,
6165 } => {
6166 let (rows, h) = (rows as usize, h as usize);
6167 Arc::new(move |base: *mut u8| unsafe {
6168 let inp = sl(src, base, rows * h);
6169 let gamma = sl(g, base, h);
6170 let beta = sl(b, base, h);
6171 let out = sl_mut(dst, base, rows * h);
6172 for row in 0..rows {
6173 crate::kernels::layer_norm_row(
6174 &inp[row * h..(row + 1) * h],
6175 gamma,
6176 beta,
6177 &mut out[row * h..(row + 1) * h],
6178 h,
6179 eps,
6180 );
6181 }
6182 })
6183 }
6184
6185 Thunk::BatchNormInference {
6186 src,
6187 g,
6188 b,
6189 mean,
6190 var,
6191 dst,
6192 count,
6193 channels,
6194 eps,
6195 } => {
6196 let count = count as usize;
6197 let c = channels as usize;
6198 let n = count * c;
6199 let (src, g, b, mean, var, dst) = (src, g, b, mean, var, dst);
6200 Arc::new(move |base: *mut u8| unsafe {
6201 crate::kernels::batch_norm_inference(
6202 sl(src, base, n),
6203 sl(g, base, c),
6204 sl(b, base, c),
6205 sl(mean, base, c),
6206 sl(var, base, c),
6207 sl_mut(dst, base, n),
6208 c,
6209 eps,
6210 );
6211 })
6212 }
6213
6214 Thunk::Attention {
6215 q,
6216 k,
6217 v,
6218 mask,
6219 out,
6220 batch,
6221 seq,
6222 kv_seq,
6223 heads,
6224 head_dim,
6225 mask_kind,
6226 q_row_stride,
6227 k_row_stride,
6228 v_row_stride,
6229 bhsd,
6230 } => {
6231 if std::env::var("RLX_ATTN_DEBUG").is_ok() {
6232 eprintln!("[attn-compile] batch={batch} seq={seq} kv_seq={kv_seq} heads={heads} bhsd={bhsd}");
6233 }
6234 let (b, q_s, k_s, nh, dh) = (
6243 batch as usize,
6244 seq as usize,
6245 kv_seq as usize,
6246 heads as usize,
6247 head_dim as usize,
6248 );
6249 let hs = nh * dh;
6250 let qrs = q_row_stride as usize;
6251 let krs = k_row_stride as usize;
6252 let vrs = v_row_stride as usize;
6253 let scale = (dh as f32).powf(-0.5);
6254 Arc::new(move |base: *mut u8| unsafe {
6255 if std::env::var("RLX_ATTN_DEBUG").is_ok() {
6256 eprintln!("[attn] b={b} q_s={q_s} k_s={k_s} nh={nh} dh={dh} bhsd={bhsd} mask_kind={:?}", mask_kind);
6257 }
6258 let (q_len, k_len, v_len, o_len) = if bhsd {
6263 let qn = b * nh * q_s * dh;
6264 let kn = b * nh * k_s * dh;
6265 (qn, kn, kn, qn)
6266 } else {
6267 (b * q_s * qrs, b * k_s * krs, b * k_s * vrs, b * q_s * hs)
6268 };
6269 let q_d = sl(q, base, q_len);
6270 let k_d = sl(k, base, k_len);
6271 let v_d = sl(v, base, v_len);
6272 let m_d: &[f32] = match mask_kind {
6273 rlx_ir::op::MaskKind::Custom => sl(mask, base, b * k_s),
6274 rlx_ir::op::MaskKind::Bias => sl(mask, base, b * nh * q_s * k_s),
6275 _ => &[],
6276 };
6277 let o_d = sl_mut(out, base, o_len);
6278 let mut qh = vec![0f32; q_s * dh];
6279 let mut kh = vec![0f32; k_s * dh];
6280 let mut vh = vec![0f32; k_s * dh];
6281 let mut sc = vec![0f32; q_s * k_s];
6282 let mut oh = vec![0f32; q_s * dh];
6283 for bi in 0..b {
6284 for hi in 0..nh {
6285 for si in 0..q_s {
6287 let q_off = if bhsd {
6288 bi * nh * q_s * dh + hi * q_s * dh + si * dh
6289 } else {
6290 bi * q_s * qrs + si * qrs + hi * dh
6291 };
6292 qh[si * dh..(si + 1) * dh]
6293 .copy_from_slice(&q_d[q_off..q_off + dh]);
6294 }
6295 for si in 0..k_s {
6297 let (k_off, v_off) = if bhsd {
6298 (
6299 bi * nh * k_s * dh + hi * k_s * dh + si * dh,
6300 bi * nh * k_s * dh + hi * k_s * dh + si * dh,
6301 )
6302 } else {
6303 (
6304 bi * k_s * krs + si * krs + hi * dh,
6305 bi * k_s * vrs + si * vrs + hi * dh,
6306 )
6307 };
6308 kh[si * dh..(si + 1) * dh]
6309 .copy_from_slice(&k_d[k_off..k_off + dh]);
6310 vh[si * dh..(si + 1) * dh]
6311 .copy_from_slice(&v_d[v_off..v_off + dh]);
6312 }
6313 for qi in 0..q_s {
6314 for ki in 0..k_s {
6315 let mut dot = 0f32;
6316 for d in 0..dh {
6317 dot += qh[qi * dh + d] * kh[ki * dh + d];
6318 }
6319 sc[qi * k_s + ki] = dot * scale;
6320 }
6321 }
6322 let q_offset = k_s.saturating_sub(q_s);
6326 match mask_kind {
6327 rlx_ir::op::MaskKind::None => {}
6328 rlx_ir::op::MaskKind::Causal => {
6329 for qi in 0..q_s {
6330 let abs_q = q_offset + qi;
6331 for ki in (abs_q + 1)..k_s {
6332 sc[qi * k_s + ki] = mask_neg;
6333 }
6334 }
6335 }
6336 rlx_ir::op::MaskKind::SlidingWindow(w) => {
6337 for qi in 0..q_s {
6338 let abs_q = q_offset + qi;
6339 let lo = abs_q.saturating_sub(w);
6340 for ki in 0..k_s {
6341 if ki < lo || ki > abs_q {
6342 sc[qi * k_s + ki] = mask_neg;
6343 }
6344 }
6345 }
6346 }
6347 rlx_ir::op::MaskKind::Custom => {
6348 for qi in 0..q_s {
6349 for ki in 0..k_s {
6350 if m_d[bi * k_s + ki] < mask_thr {
6351 sc[qi * k_s + ki] = mask_neg;
6352 }
6353 }
6354 }
6355 }
6356 rlx_ir::op::MaskKind::Bias => {
6357 let per_bh = q_s * k_s;
6358 let off = (bi * nh + hi) * per_bh;
6359 for i in 0..per_bh {
6360 sc[i] += m_d[off + i];
6361 }
6362 }
6363 }
6364 crate::naive::softmax(&mut sc, q_s, k_s);
6365 oh.fill(0.0);
6366 for qi in 0..q_s {
6367 for ki in 0..k_s {
6368 let w = sc[qi * k_s + ki];
6369 if w > score_skip {
6370 for d in 0..dh {
6371 oh[qi * dh + d] += w * vh[ki * dh + d];
6372 }
6373 }
6374 }
6375 }
6376 for si in 0..q_s {
6377 let off = if bhsd {
6378 bi * nh * q_s * dh + hi * q_s * dh + si * dh
6379 } else {
6380 bi * q_s * hs + si * hs + hi * dh
6381 };
6382 o_d[off..off + dh].copy_from_slice(&oh[si * dh..(si + 1) * dh]);
6383 }
6384 }
6385 }
6386 })
6387 }
6388
6389 Thunk::FusedSwiGLU {
6390 src,
6391 dst,
6392 n_half,
6393 total,
6394 gate_first,
6395 } => {
6396 let n = n_half as usize;
6397 let t = total as usize;
6398 let outer = t / n;
6399 let in_total = outer * 2 * n;
6400 Arc::new(move |base: *mut u8| unsafe {
6401 let inp = sl(src, base, in_total);
6402 let out = sl_mut(dst, base, t);
6403 for o in 0..outer {
6404 let in_row = &inp[o * 2 * n..(o + 1) * 2 * n];
6405 let out_row = &mut out[o * n..(o + 1) * n];
6406 for i in 0..n {
6407 let (up, gate) = if gate_first {
6408 (in_row[n + i], in_row[i])
6409 } else {
6410 (in_row[i], in_row[n + i])
6411 };
6412 out_row[i] = up * (gate / (1.0 + (-gate).exp()));
6413 }
6414 }
6415 })
6416 }
6417
6418 Thunk::Concat {
6419 dst,
6420 outer,
6421 inner,
6422 total_axis,
6423 inputs,
6424 } => {
6425 let outer = outer as usize;
6426 let inner = inner as usize;
6427 let total_axis = total_axis as usize;
6428 let out_total = outer * total_axis * inner;
6429 let mut layout: Vec<(usize, usize, usize)> = Vec::with_capacity(inputs.len());
6432 let mut cum: usize = 0;
6433 for (src_off, in_axis) in &inputs {
6434 let in_axis = *in_axis as usize;
6435 layout.push((*src_off, cum * inner, in_axis * inner));
6436 cum += in_axis;
6437 }
6438 Arc::new(move |base: *mut u8| unsafe {
6439 let out = sl_mut(dst, base, out_total);
6440 let row_stride = total_axis * inner;
6441 for (src_off, dst_col_off, copy_per_row) in &layout {
6442 let in_total = outer * *copy_per_row;
6443 let inp = sl(*src_off, base, in_total);
6444 for o in 0..outer {
6445 let dst_row_start = o * row_stride + *dst_col_off;
6446 let src_row_start = o * *copy_per_row;
6447 out[dst_row_start..dst_row_start + *copy_per_row].copy_from_slice(
6448 &inp[src_row_start..src_row_start + *copy_per_row],
6449 );
6450 }
6451 }
6452 })
6453 }
6454
6455 Thunk::CustomOp {
6456 kernel,
6457 inputs,
6458 output,
6459 attrs,
6460 } => {
6461 let kernel = kernel.clone();
6467 let attrs = attrs.clone();
6468 let inputs = inputs.clone();
6469 let (out_off, out_len, out_shape) = output.clone();
6470 Arc::new(move |base: *mut u8| unsafe {
6471 dispatch_custom_op(
6472 &*kernel, &inputs, out_off, out_len, &out_shape, &attrs, base,
6473 );
6474 })
6475 }
6476
// Gaussian-splat forward render: a pure pass-through. Every field is
// forwarded in declaration order to
// `splat::execute_gaussian_splat_render`, followed by the arena base.
Thunk::GaussianSplatRender {
    positions_off,
    positions_len,
    scales_off,
    scales_len,
    rotations_off,
    rotations_len,
    opacities_off,
    opacities_len,
    colors_off,
    colors_len,
    sh_coeffs_off,
    sh_coeffs_len,
    meta_off,
    dst_off,
    dst_len,
    width,
    height,
    tile_size,
    radius_scale,
    alpha_cutoff,
    max_splat_steps,
    transmittance_threshold,
    max_list_entries,
} => Arc::new(move |base: *mut u8| unsafe {
    crate::splat::execute_gaussian_splat_render(
        positions_off,
        positions_len,
        scales_off,
        scales_len,
        rotations_off,
        rotations_len,
        opacities_off,
        opacities_len,
        colors_off,
        colors_len,
        sh_coeffs_off,
        sh_coeffs_len,
        meta_off,
        dst_off,
        dst_len,
        width,
        height,
        tile_size,
        radius_scale,
        alpha_cutoff,
        max_splat_steps,
        transmittance_threshold,
        max_list_entries,
        base,
    );
}),
6529
// Gaussian-splat backward pass: a pure pass-through to
// `splat::execute_gaussian_splat_render_backward`. Compared with the
// forward arm it takes the loss gradient (`d_loss_*`), a `packed_*`
// buffer, and the extra `loss_grad_clip`, `sh_band` and
// `max_anisotropy` parameters.
Thunk::GaussianSplatRenderBackward {
    positions_off,
    positions_len,
    scales_off,
    scales_len,
    rotations_off,
    rotations_len,
    opacities_off,
    opacities_len,
    colors_off,
    colors_len,
    sh_coeffs_off,
    sh_coeffs_len,
    meta_off,
    d_loss_off,
    d_loss_len,
    packed_off,
    packed_len,
    width,
    height,
    tile_size,
    radius_scale,
    alpha_cutoff,
    max_splat_steps,
    transmittance_threshold,
    max_list_entries,
    loss_grad_clip,
    sh_band,
    max_anisotropy,
} => Arc::new(move |base: *mut u8| unsafe {
    crate::splat::execute_gaussian_splat_render_backward(
        positions_off,
        positions_len,
        scales_off,
        scales_len,
        rotations_off,
        rotations_len,
        opacities_off,
        opacities_len,
        colors_off,
        colors_len,
        sh_coeffs_off,
        sh_coeffs_len,
        meta_off,
        d_loss_off,
        d_loss_len,
        packed_off,
        packed_len,
        width,
        height,
        tile_size,
        radius_scale,
        alpha_cutoff,
        max_splat_steps,
        transmittance_threshold,
        max_list_entries,
        loss_grad_clip,
        sh_band,
        max_anisotropy,
        base,
    );
}),
6592
// Gaussian-splat prepare stage: a pure pass-through to
// `splat::execute_gaussian_splat_prepare`. It takes the same scene inputs
// as the render arm plus `meta_len` and a `prep_*` buffer.
// NOTE(review): `prep_*` is presumably the stage output consumed by the
// rasterize thunk — confirm against the splat module.
Thunk::GaussianSplatPrepare {
    positions_off,
    positions_len,
    scales_off,
    scales_len,
    rotations_off,
    rotations_len,
    opacities_off,
    opacities_len,
    colors_off,
    colors_len,
    sh_coeffs_off,
    sh_coeffs_len,
    meta_off,
    meta_len,
    prep_off,
    prep_len,
    width,
    height,
    tile_size,
    radius_scale,
    alpha_cutoff,
    max_splat_steps,
    transmittance_threshold,
    max_list_entries,
} => Arc::new(move |base: *mut u8| unsafe {
    crate::splat::execute_gaussian_splat_prepare(
        positions_off,
        positions_len,
        scales_off,
        scales_len,
        rotations_off,
        rotations_len,
        opacities_off,
        opacities_len,
        colors_off,
        colors_len,
        sh_coeffs_off,
        sh_coeffs_len,
        meta_off,
        meta_len,
        prep_off,
        prep_len,
        width,
        height,
        tile_size,
        radius_scale,
        alpha_cutoff,
        max_splat_steps,
        transmittance_threshold,
        max_list_entries,
        base,
    );
}),
6647
// Gaussian-splat rasterize stage: a pure pass-through to
// `splat::execute_gaussian_splat_rasterize`, taking a `prep_*` buffer,
// the meta buffer, the destination and the raster parameters.
Thunk::GaussianSplatRasterize {
    prep_off,
    prep_len,
    meta_off,
    meta_len,
    dst_off,
    dst_len,
    count,
    width,
    height,
    tile_size,
    alpha_cutoff,
    max_splat_steps,
    transmittance_threshold,
    max_list_entries,
} => Arc::new(move |base: *mut u8| unsafe {
    crate::splat::execute_gaussian_splat_rasterize(
        prep_off,
        prep_len,
        meta_off,
        meta_len,
        dst_off,
        dst_len,
        count,
        width,
        height,
        tile_size,
        alpha_cutoff,
        max_splat_steps,
        transmittance_threshold,
        max_list_entries,
        base,
    );
}),
6682
6683 Thunk::Fft1d {
6684 src,
6685 dst,
6686 outer,
6687 n_complex,
6688 inverse,
6689 norm_tag,
6690 dtype,
6691 } => {
6692 let f: Arc<dyn Fn(*mut u8) + Send + Sync> = match dtype {
6693 rlx_ir::DType::F64 => Arc::new(move |base: *mut u8| unsafe {
6694 execute_fft1d_f64(
6695 src,
6696 dst,
6697 outer as usize,
6698 n_complex as usize,
6699 inverse,
6700 norm_tag,
6701 base,
6702 );
6703 }),
6704 rlx_ir::DType::F32 => Arc::new(move |base: *mut u8| unsafe {
6705 execute_fft1d_f32(
6706 src,
6707 dst,
6708 outer as usize,
6709 n_complex as usize,
6710 inverse,
6711 norm_tag,
6712 base,
6713 );
6714 }),
6715 rlx_ir::DType::C64 => Arc::new(move |base: *mut u8| unsafe {
6716 execute_fft1d_c64(
6717 src,
6718 dst,
6719 outer as usize,
6720 n_complex as usize,
6721 inverse,
6722 norm_tag,
6723 base,
6724 );
6725 }),
6726 other => panic!("Op::Fft on CPU requires F32/F64/C64, got {other:?}"),
6727 };
6728 f
6729 }
6730
6731 Thunk::FftButterflyStage {
6732 state_src,
6733 state_dst,
6734 gate_src,
6735 rev_src,
6736 tw_re_src,
6737 tw_im_src,
6738 batch,
6739 n_fft,
6740 stage,
6741 } => Arc::new(move |base: *mut u8| unsafe {
6742 execute_fft_butterfly_stage_f32(
6743 state_src,
6744 state_dst,
6745 gate_src,
6746 rev_src,
6747 tw_re_src,
6748 tw_im_src,
6749 batch as usize,
6750 n_fft as usize,
6751 stage as usize,
6752 base,
6753 );
6754 }),
6755
6756 Thunk::LogMel {
6757 spec,
6758 filters,
6759 dst,
6760 outer,
6761 n_fft,
6762 n_bins,
6763 n_mels,
6764 } => Arc::new(move |base: *mut u8| unsafe {
6765 execute_log_mel_f32(
6766 spec,
6767 filters,
6768 dst,
6769 outer as usize,
6770 n_fft as usize,
6771 n_bins as usize,
6772 n_mels as usize,
6773 base,
6774 );
6775 }),
6776
6777 Thunk::LogMelBackward {
6778 spec,
6779 filters,
6780 dy,
6781 dst,
6782 outer,
6783 n_fft,
6784 n_bins,
6785 n_mels,
6786 } => Arc::new(move |base: *mut u8| unsafe {
6787 execute_log_mel_backward_f32(
6788 spec,
6789 filters,
6790 dy,
6791 dst,
6792 outer as usize,
6793 n_fft as usize,
6794 n_bins as usize,
6795 n_mels as usize,
6796 base,
6797 );
6798 }),
6799
6800 Thunk::WelchPeaks {
6801 spec,
6802 dst,
6803 welch_batch,
6804 n_fft,
6805 n_segments,
6806 k,
6807 } => Arc::new(move |base: *mut u8| unsafe {
6808 execute_welch_peaks_f32(
6809 spec,
6810 dst,
6811 welch_batch as usize,
6812 n_fft as usize,
6813 n_segments as usize,
6814 k as usize,
6815 base,
6816 );
6817 }),
6818
// Any thunk variant without an arm above compiles to a closure that
// does nothing.
// NOTE(review): this silently skips such variants on this path —
// confirm they are executed elsewhere or cannot reach here.
_ => Arc::new(|_: *mut u8| {}),
6820 }
6821 })
6822 .collect();
6823
6824 let fuse_threshold: usize = rlx_ir::env::var("RLX_FUSE_ATTN_THRESHOLD")
6828 .and_then(|v| v.parse().ok())
6829 .unwrap_or(64);
6830 let should_fuse = thunks.iter().any(|t| match t {
6831 Thunk::Attention { batch, seq, .. } => {
6832 (*batch as usize) * (*seq as usize) <= fuse_threshold
6833 }
6834 _ => false,
6835 });
6836
6837 if should_fuse {
// Indices into `thunks` of every non-Nop thunk, in schedule order. The
// pattern matcher walks this list so Nops between real thunks do not
// break a match.
let active: Vec<usize> = thunks
    .iter()
    .enumerate()
    .filter(|(_, t)| !matches!(t, Thunk::Nop))
    .map(|(i, _)| i)
    .collect();

// kill[i] is set when thunk i has been absorbed into a fused block.
let mut kill = vec![false; thunks.len()];
// (index of the first absorbed thunk, fused replacement) pairs.
let mut insertions: Vec<(usize, Thunk)> = Vec::new();
// Cursor into `active`.
let mut ai = 0;
6850 while ai < active.len() {
6851 let a = |off: usize| -> Option<(usize, &Thunk)> {
6853 active.get(ai + off).map(|&idx| (idx, &thunks[idx]))
6854 };
6855
6856 let matched = (|| {
6858 let (_i0, t0) = a(0)?;
6859 let (_, t1) = a(1)?;
6860 let (_, t2) = a(2)?;
6861 let (_, t3) = a(3)?;
6862
6863 let (hidden, qkv_w, qkv_b, has_b) = match t0 {
6865 Thunk::FusedMmBiasAct {
6866 a,
6867 w,
6868 bias,
6869 n: _,
6870 act: None,
6871 ..
6872 } => (*a, *w, *bias, true),
6873 Thunk::Sgemm { a, b, n: _, .. } => (*a, *b, 0, false),
6874 _ => return None,
6875 };
6876
6877 if !matches!(t1, Thunk::Narrow { .. }) {
6879 return None;
6880 }
6881 if !matches!(t2, Thunk::Narrow { .. }) {
6882 return None;
6883 }
6884 if !matches!(t3, Thunk::Narrow { .. }) {
6885 return None;
6886 }
6887
6888 let (has_rope, attn_ai, cos_off, sin_off, cl) = if let Some((
6890 _,
6891 Thunk::Rope {
6892 cos, sin, cos_len, ..
6893 },
6894 )) = a(4)
6895 {
6896 if matches!(a(5).map(|x| x.1), Some(Thunk::Rope { .. })) {
6897 if matches!(a(6).map(|x| x.1), Some(Thunk::Attention { .. })) {
6898 (true, 6, *cos, *sin, *cos_len)
6899 } else {
6900 return None;
6901 }
6902 } else {
6903 return None;
6904 }
6905 } else if matches!(a(4).map(|x| x.1), Some(Thunk::Attention { .. })) {
6906 (false, 4, 0, 0, 0)
6907 } else {
6908 return None;
6909 };
6910
6911 let (_attn_real_idx, attn_t) = a(attn_ai)?;
6912 let (batch, seq, heads, head_dim, mask) = match attn_t {
6913 Thunk::Attention {
6914 batch,
6915 seq,
6916 heads,
6917 head_dim,
6918 mask,
6919 ..
6920 } => (*batch, *seq, *heads, *head_dim, *mask),
6921 _ => return None,
6922 };
6923
6924 let (_out_real_idx, out_t) = a(attn_ai + 1)?;
6926 let (out_w, out_b, out_dst) = match out_t {
6927 Thunk::FusedMmBiasAct {
6928 w,
6929 bias,
6930 c,
6931 act: None,
6932 ..
6933 } => (*w, *bias, *c),
6934 Thunk::Sgemm { b: w, c, .. } => (*w, 0, *c),
6935 _ => return None,
6936 };
6937
6938 let hs = heads * head_dim;
6939 let total_active = attn_ai + 2; Some((
6942 total_active,
6943 Thunk::FusedAttnBlock {
6944 hidden,
6945 qkv_w,
6946 out_w,
6947 mask,
6948 out: out_dst,
6949 qkv_b: if has_b { qkv_b } else { 0 },
6950 out_b: if has_b { out_b } else { 0 },
6951 cos: cos_off,
6952 sin: sin_off,
6953 cos_len: cl,
6954 batch,
6955 seq,
6956 hs,
6957 nh: heads,
6958 dh: head_dim,
6959 has_bias: has_b,
6960 has_rope,
6961 },
6962 ))
6963 })();
6964
6965 if let Some((count, fused_thunk)) = matched {
6966 for off in 0..count {
6968 if let Some(&idx) = active.get(ai + off) {
6969 kill[idx] = true;
6970 }
6971 }
6972 insertions.push((active[ai], fused_thunk));
6974 ai += count;
6975 } else {
6976 ai += 1;
6977 }
6978 }
6979
6980 if !insertions.is_empty() {
6982 let mut new_thunks = Vec::with_capacity(thunks.len());
6983 let mut insert_idx = 0;
6984 for (i, t) in thunks.into_iter().enumerate() {
6985 if insert_idx < insertions.len() && insertions[insert_idx].0 == i {
6986 new_thunks.push(insertions[insert_idx].1.clone());
6987 insert_idx += 1;
6988 }
6989 if !kill[i] {
6990 new_thunks.push(t);
6991 }
6992 }
6993 if cfg.verbose >= 1 {
6994 eprintln!(
6995 "[rlx] fused_attention: {} attention blocks fused",
6996 insertions.len()
6997 );
6998 }
6999 thunks = new_thunks;
7000 }
7001 }
7002
7003 if should_fuse {
7008 let active: Vec<usize> = thunks
7009 .iter()
7010 .enumerate()
7011 .filter(|(_, t)| !matches!(t, Thunk::Nop))
7012 .map(|(i, _)| i)
7013 .collect();
7014
7015 let mut kill = vec![false; thunks.len()];
7016 let mut insertions: Vec<(usize, Thunk)> = Vec::new();
7017
7018 let a = |ai: usize| -> Option<&Thunk> { active.get(ai).map(|&i| &thunks[i]) };
7019
7020 let mut ai = 0;
7021 while ai < active.len() {
7022 let bert_match = (|| -> Option<usize> {
7024 let fab = a(ai)?;
7025 let rln1 = a(ai + 1)?;
7026 let ffn1 = a(ai + 2)?;
7027 let ffn2 = a(ai + 3)?;
7028 let rln2 = a(ai + 4)?;
7029
7030 let (hidden, qkv_w, qkv_b, out_w, out_b, mask, batch, seq, hs, nh, dh) = match fab {
7031 Thunk::FusedAttnBlock {
7032 hidden,
7033 qkv_w,
7034 qkv_b,
7035 out_w,
7036 out_b,
7037 mask,
7038 batch,
7039 seq,
7040 hs,
7041 nh,
7042 dh,
7043 has_bias: true,
7044 has_rope: false,
7045 ..
7046 } => (
7047 *hidden, *qkv_w, *qkv_b, *out_w, *out_b, *mask, *batch, *seq, *hs, *nh, *dh,
7048 ),
7049 _ => return None,
7050 };
7051 let (ln1_g, ln1_b, eps1) = match rln1 {
7052 Thunk::FusedResidualLN { g, b, eps, .. } => (*g, *b, *eps),
7053 _ => return None,
7054 };
7055 let (fc1_w, fc1_b, int_dim) = match ffn1 {
7056 Thunk::FusedMmBiasAct {
7057 w,
7058 bias,
7059 n,
7060 act: Some(Activation::Gelu),
7061 ..
7062 } => (*w, *bias, *n),
7063 _ => return None,
7064 };
7065 let (fc2_w, fc2_b) = match ffn2 {
7066 Thunk::FusedMmBiasAct {
7067 w, bias, act: None, ..
7068 } => (*w, *bias),
7069 _ => return None,
7070 };
7071 let (ln2_g, ln2_b, eps2, out) = match rln2 {
7072 Thunk::FusedResidualLN { g, b, eps, out, .. } => (*g, *b, *eps, *out),
7073 _ => return None,
7074 };
7075
7076 for off in 0..5 {
7077 kill[active[ai + off]] = true;
7078 }
7079 insertions.push((
7080 active[ai],
7081 Thunk::FusedBertLayer {
7082 hidden,
7083 qkv_w,
7084 qkv_b,
7085 out_w,
7086 out_b,
7087 mask,
7088 ln1_g,
7089 ln1_b,
7090 eps1,
7091 fc1_w,
7092 fc1_b,
7093 fc2_w,
7094 fc2_b,
7095 ln2_g,
7096 ln2_b,
7097 eps2,
7098 out,
7099 batch,
7100 seq,
7101 hs,
7102 nh,
7103 dh,
7104 int_dim,
7105 },
7106 ));
7107 Some(5)
7108 })();
7109 if let Some(n) = bert_match {
7110 ai += n;
7111 continue;
7112 }
7113
7114 #[allow(unreachable_code)]
7118 let nomic_match = (|| -> Option<usize> {
7119 return None; let fab = a(ai)?;
7121 let (hidden, qkv_w, out_w, mask, cos, sin, cos_len, batch, seq, hs, nh, dh) =
7122 match fab {
7123 Thunk::FusedAttnBlock {
7124 hidden,
7125 qkv_w,
7126 out_w,
7127 mask,
7128 cos,
7129 sin,
7130 cos_len,
7131 batch,
7132 seq,
7133 hs,
7134 nh,
7135 dh,
7136 has_bias: false,
7137 has_rope: true,
7138 ..
7139 } => (
7140 *hidden, *qkv_w, *out_w, *mask, *cos, *sin, *cos_len, *batch, *seq,
7141 *hs, *nh, *dh,
7142 ),
7143 _ => return None,
7144 };
7145 let (ln1_g, ln1_b, eps1) = match a(ai + 1)? {
7147 Thunk::FusedResidualLN { g, b, eps, .. } => (*g, *b, *eps),
7148 _ => return None,
7149 };
7150 let fused_fc_w = match a(ai + 2)? {
7152 Thunk::Sgemm { b: w, .. } => *w,
7153 _ => return None,
7154 };
7155 if !matches!(a(ai + 3)?, Thunk::Narrow { .. }) {
7157 return None;
7158 }
7159 if !matches!(a(ai + 4)?, Thunk::Narrow { .. }) {
7160 return None;
7161 }
7162 if !matches!(
7164 a(ai + 5)?,
7165 Thunk::ActivationInPlace {
7166 act: Activation::Silu,
7167 ..
7168 }
7169 ) {
7170 return None;
7171 }
7172 if !matches!(
7174 a(ai + 6)?,
7175 Thunk::BinaryFull {
7176 op: BinaryOp::Mul,
7177 ..
7178 }
7179 ) {
7180 return None;
7181 }
7182 let fc2_w = match a(ai + 7)? {
7184 Thunk::Sgemm { b: w, .. } => *w,
7185 _ => return None,
7186 };
7187 let int_dim = match a(ai + 3)? {
7189 Thunk::Narrow { inner, .. } => *inner,
7190 _ => return None,
7191 };
7192 let (ln2_g, ln2_b, eps2, out) = match a(ai + 8)? {
7194 Thunk::FusedResidualLN { g, b, eps, out, .. } => (*g, *b, *eps, *out),
7195 _ => return None,
7196 };
7197
7198 for off in 0..9 {
7199 kill[active[ai + off]] = true;
7200 }
7201 insertions.push((
7202 active[ai],
7203 Thunk::FusedNomicLayer {
7204 hidden,
7205 qkv_w,
7206 out_w,
7207 mask,
7208 cos,
7209 sin,
7210 cos_len,
7211 ln1_g,
7212 ln1_b,
7213 eps1,
7214 fc11_w: fused_fc_w,
7215 fc12_w: 0,
7216 fc2_w,
7217 ln2_g,
7218 ln2_b,
7219 eps2,
7220 out,
7221 batch,
7222 seq,
7223 hs,
7224 nh,
7225 dh,
7226 int_dim,
7227 },
7228 ));
7229 Some(9)
7230 })();
7231 if let Some(n) = nomic_match {
7232 ai += n;
7233 continue;
7234 }
7235
7236 ai += 1;
7237 }
7238
7239 if !insertions.is_empty() {
7240 let mut new_thunks = Vec::with_capacity(thunks.len());
7241 let mut ins_idx = 0;
7242 for (i, t) in thunks.into_iter().enumerate() {
7243 if ins_idx < insertions.len() && insertions[ins_idx].0 == i {
7244 new_thunks.push(insertions[ins_idx].1.clone());
7245 ins_idx += 1;
7246 }
7247 if !kill[i] {
7248 new_thunks.push(t);
7249 }
7250 }
7251 if cfg.verbose >= 1 {
7252 eprintln!(
7253 "[rlx] fused_layer: {} full transformer layers fused",
7254 insertions.len()
7255 );
7256 }
7257 thunks = new_thunks;
7258 }
7259 }
7260
7261 {
7273 let mut read_offsets: HashMap<usize, usize> = HashMap::new();
7276 for t in &thunks {
7277 for off in thunk_read_offsets(t) {
7278 *read_offsets.entry(off).or_insert(0) += 1;
7279 }
7280 }
7281
7282 let mut fused_count = 0usize;
7283 for i in 0..thunks.len().saturating_sub(1) {
7284 let narrow = match &thunks[i] {
7287 Thunk::Narrow { .. } => i,
7288 _ => continue,
7289 };
7290 let mut j = narrow + 1;
7292 while j < thunks.len() && matches!(thunks[j], Thunk::Nop) {
7293 j += 1;
7294 }
7295 if j >= thunks.len() {
7296 continue;
7297 }
7298 let (n_src, n_dst, n_src_stride) = match &thunks[narrow] {
7300 Thunk::Narrow {
7301 src,
7302 dst,
7303 src_stride,
7304 ..
7305 } => (*src, *dst, *src_stride),
7306 _ => continue,
7307 };
7308 let rope_reads_narrow = matches!(&thunks[j],
7309 Thunk::Rope { src, .. } if *src == n_dst);
7310 if !rope_reads_narrow {
7311 continue;
7312 }
7313 if read_offsets.get(&n_dst).copied().unwrap_or(0) != 1 {
7317 continue;
7318 }
7319
7320 if let Thunk::Rope {
7323 src,
7324 src_row_stride,
7325 ..
7326 } = &mut thunks[j]
7327 {
7328 *src = n_src;
7329 *src_row_stride = n_src_stride;
7330 }
7331 thunks[narrow] = Thunk::Nop;
7332 fused_count += 1;
7333 }
7334
7335 if fused_count > 0 && cfg.verbose >= 1 {
7336 eprintln!(
7337 "[rlx] fused_qk_rope: {} Narrow→Rope pairs collapsed",
7338 fused_count
7339 );
7340 }
7341 }
7342
7343 {
7355 let mut read_counts: HashMap<usize, usize> = HashMap::new();
7356 for t in &thunks {
7357 for off in thunk_read_offsets(t) {
7358 *read_counts.entry(off).or_insert(0) += 1;
7359 }
7360 }
7361 let mut dst_to_idx: HashMap<usize, usize> = HashMap::new();
7363 for (i, t) in thunks.iter().enumerate() {
7364 if let Thunk::Narrow { dst, .. } = t {
7365 dst_to_idx.insert(*dst, i);
7366 }
7367 }
7368
7369 let mut fused_count = 0usize;
7370 for i in 0..thunks.len() {
7371 let (q_off, k_off, v_off) = match &thunks[i] {
7372 Thunk::Attention { q, k, v, .. } => (*q, *k, *v),
7373 _ => continue,
7374 };
7375 let q_n = match dst_to_idx.get(&q_off).copied() {
7377 Some(x) => x,
7378 None => continue,
7379 };
7380 let k_n = match dst_to_idx.get(&k_off).copied() {
7381 Some(x) => x,
7382 None => continue,
7383 };
7384 let v_n = match dst_to_idx.get(&v_off).copied() {
7385 Some(x) => x,
7386 None => continue,
7387 };
7388 if read_counts.get(&q_off).copied().unwrap_or(0) != 1 {
7390 continue;
7391 }
7392 if read_counts.get(&k_off).copied().unwrap_or(0) != 1 {
7393 continue;
7394 }
7395 if read_counts.get(&v_off).copied().unwrap_or(0) != 1 {
7396 continue;
7397 }
7398
7399 let (q_src, q_stride) = match &thunks[q_n] {
7400 Thunk::Narrow {
7401 src, src_stride, ..
7402 } => (*src, *src_stride),
7403 _ => continue,
7404 };
7405 let (k_src, k_stride) = match &thunks[k_n] {
7406 Thunk::Narrow {
7407 src, src_stride, ..
7408 } => (*src, *src_stride),
7409 _ => continue,
7410 };
7411 let (v_src, v_stride) = match &thunks[v_n] {
7412 Thunk::Narrow {
7413 src, src_stride, ..
7414 } => (*src, *src_stride),
7415 _ => continue,
7416 };
7417
7418 if let Thunk::Attention {
7419 q,
7420 k,
7421 v,
7422 q_row_stride,
7423 k_row_stride,
7424 v_row_stride,
7425 ..
7426 } = &mut thunks[i]
7427 {
7428 *q = q_src;
7429 *k = k_src;
7430 *v = v_src;
7431 *q_row_stride = q_stride;
7432 *k_row_stride = k_stride;
7433 *v_row_stride = v_stride;
7434 }
7435 thunks[q_n] = Thunk::Nop;
7436 thunks[k_n] = Thunk::Nop;
7437 thunks[v_n] = Thunk::Nop;
7438 fused_count += 1;
7439 }
7440
7441 if fused_count > 0 && cfg.verbose >= 1 {
7442 eprintln!(
7443 "[rlx] fused_strided_attn: {} Narrow×3→Attention rewrites",
7444 fused_count
7445 );
7446 }
7447 }
7448
7449 ThunkSchedule {
7450 thunks,
7451 moe_resident: None,
7452 moe_resident_layers: None,
7453 moe_topk_capture: None,
7454 mask_threshold: cfg.mask_binary_threshold,
7455 mask_neg_inf: cfg.attn_mask_neg_inf,
7456 score_skip: cfg.score_skip_threshold,
7457 compiled_fns,
7458 }
7459}
7460
7461fn get_len(graph: &Graph, id: NodeId) -> usize {
7462 graph.node(id).shape.num_elements().unwrap_or(0)
7463}
7464
7465fn get_static_dims(graph: &Graph, id: NodeId) -> Vec<usize> {
7467 let dims = graph.node(id).shape.dims();
7468 let mut out = Vec::with_capacity(dims.len());
7469 for d in dims {
7470 if let Some(s) = match d {
7471 rlx_ir::Dim::Static(s) => Some(*s),
7472 _ => None,
7473 } {
7474 out.push(s);
7475 } else {
7476 return Vec::new();
7477 }
7478 }
7479 out
7480}
7481
7482fn is_trailing_bias_broadcast(rhs_dims: &[rlx_ir::Dim], out_dims: &[rlx_ir::Dim]) -> bool {
7500 if rhs_dims.len() > out_dims.len() {
7501 return false;
7502 }
7503 let off = out_dims.len() - rhs_dims.len();
7504 for i in 0..rhs_dims.len() {
7505 let r = match rhs_dims[i] {
7506 rlx_ir::Dim::Static(n) => n,
7507 _ => return false,
7508 };
7509 let o = match out_dims[off + i] {
7510 rlx_ir::Dim::Static(n) => n,
7511 _ => return false,
7512 };
7513 if r != o {
7514 return false;
7515 }
7516 }
7517 true
7518}
7519
/// Computes, for every output axis, the element stride to use when reading
/// an input of shape `in_dims` as if it had shape `out_dims`.
///
/// `in_dims` is right-aligned against `out_dims`. Output axes that the input
/// lacks, and input axes of size 1, get stride `0` (the same element is
/// re-read along that axis). Every other axis gets the product of the sizes
/// of all input axes to its right, i.e. the stride of a densely packed
/// row-major input.
///
/// # Panics
///
/// * if `in_dims` has more axes than `out_dims`;
/// * if an input axis is neither 1 nor equal to the matching output axis;
/// * if a stride does not fit in `u32` (previously this was silently
///   truncated, which could turn a huge stride into `0` and masquerade as a
///   broadcast axis).
fn broadcast_strides(in_dims: &[usize], out_dims: &[usize]) -> Vec<u32> {
    let r_out = out_dims.len();
    let r_in = in_dims.len();
    assert!(
        r_in <= r_out,
        "broadcast: input rank {r_in} > output rank {r_out}"
    );
    let pad = r_out - r_in;

    // Axes `0..pad` do not exist in the input; they keep the initial stride 0.
    let mut strides = vec![0u32; r_out];
    let mut acc: usize = 1;
    for d in (pad..r_out).rev() {
        let in_size = in_dims[d - pad];
        if in_size == 1 {
            // Broadcast axis: stride stays 0 and the running product is unchanged.
            continue;
        }
        assert_eq!(
            in_size, out_dims[d],
            "broadcast: input dim {in_size} doesn't match output dim {} at axis {d}",
            out_dims[d]
        );
        // Refuse to narrow a stride that cannot be represented; a wrapped
        // value would address the wrong elements without any diagnostic.
        assert!(
            acc <= u32::MAX as usize,
            "broadcast: stride {acc} at axis {d} does not fit in u32"
        );
        strides[d] = acc as u32;
        acc *= in_size;
    }
    strides
}
7546
7547pub fn execute_compiled(schedule: &ThunkSchedule, arena_buf: &mut [u8]) {
7551 let base = arena_buf.as_mut_ptr();
7552 for f in &schedule.compiled_fns {
7553 f(base);
7554 }
7555}
7556
7557pub fn execute_thunks_active(
7562 schedule: &ThunkSchedule,
7563 _arena_buf: &mut [u8],
7564 _actual: usize,
7565 _upper: usize,
7566) -> bool {
7567 let _ = schedule;
7568 false
7569}
7570
7571struct MoeResidencyGuard;
7573impl Drop for MoeResidencyGuard {
7574 fn drop(&mut self) {
7575 if let Some(stats) = crate::moe_residency::take_stats() {
7576 crate::moe_residency::stash_last_forward_stats(stats);
7577 } else {
7578 crate::moe_residency::clear_mask();
7579 }
7580 }
7581}
7582
7583fn thunk_kind_name(t: &Thunk) -> &'static str {
7584 match t {
7585 Thunk::Nop => "Nop",
7586 Thunk::Gather { .. } => "Gather",
7587 Thunk::GatherAxis { .. } => "GatherAxis",
7588 Thunk::TopK { .. } => "TopK",
7589 Thunk::Copy { .. } => "Copy",
7590 Thunk::CopyF64 { .. } => "CopyF64",
7591 Thunk::CopyI64 { .. } => "CopyI64",
7592 Thunk::CastF32ToI64 { .. } => "CastF32ToI64",
7593 Thunk::CastI64ToF32 { .. } => "CastI64ToF32",
7594 Thunk::CastBoolToI32 { .. } => "CastBoolToI32",
7595 Thunk::CastI32ToF32 { .. } => "CastI32ToF32",
7596 Thunk::Transpose { .. } => "Transpose",
7597 Thunk::TransposeF64 { .. } => "TransposeF64",
7598 Thunk::Where { .. } => "Where",
7599 Thunk::Compare { .. } => "Compare",
7600 Thunk::BinaryFull { .. } => "BinaryFull",
7601 Thunk::BinaryFullF64 { .. } => "BinaryFullF64",
7602 Thunk::Sgemm { .. } => "Sgemm",
7603 Thunk::Dgemm { .. } => "Dgemm",
7604 Thunk::FusedMmBiasAct { .. } => "FusedMmBiasAct",
7605 Thunk::BiasAdd { .. } => "BiasAdd",
7606 Thunk::LayerNorm { .. } => "LayerNorm",
7607 Thunk::Softmax { .. } => "Softmax",
7608 Thunk::Conv2D { .. } => "Conv2D",
7609 Thunk::Conv2D1x1 { .. } => "Conv2D1x1",
7610 Thunk::CustomOp { .. } => "CustomOp",
7611 Thunk::ActivationInPlace { .. } => "ActivationInPlace",
7612 Thunk::Narrow { .. } => "Narrow",
7613 Thunk::Cumsum { .. } => "Cumsum",
7614 Thunk::Reduce { .. } => "Reduce",
7615 Thunk::BatchedSgemm { .. } => "BatchedSgemm",
7616 Thunk::DequantMatMul { .. } => "DequantMatMul",
7617 Thunk::Quantize { .. } => "Quantize",
7618 Thunk::Dequantize { .. } => "Dequantize",
7619 Thunk::ConvTranspose2d { .. } => "ConvTranspose2d",
7620 Thunk::ResizeNearest2x { .. } => "ResizeNearest2x",
7621 _ => "Other",
7622 }
7623}
7624
7625pub fn execute_thunks(schedule: &ThunkSchedule, arena_buf: &mut [u8]) {
7626 crate::moe_residency::reset_gmm_counters();
7627 if let Some(layers) = schedule.moe_resident_layers.clone() {
7628 crate::moe_residency::set_per_layer_masks(Some(layers));
7629 } else {
7630 crate::moe_residency::set_mask(schedule.moe_resident.clone());
7631 }
7632 if let Some(cap) = schedule.moe_topk_capture.as_ref() {
7633 cap.clear();
7634 }
7635 let _moe_guard = MoeResidencyGuard;
7636 let base = arena_buf.as_mut_ptr();
7637 let mask_thr = schedule.mask_threshold;
7638 let mask_neg = schedule.mask_neg_inf;
7639 let score_thr = schedule.score_skip;
7640 let thunks = &schedule.thunks;
7641 let len = thunks.len();
7642
7643 let max_h = thunks
7645 .iter()
7646 .filter_map(|t| match t {
7647 Thunk::FusedResidualLN { h, .. }
7648 | Thunk::FusedResidualRmsNorm { h, .. }
7649 | Thunk::LayerNorm { h, .. } => Some(*h as usize),
7650 _ => None,
7651 })
7652 .max()
7653 .unwrap_or(0);
7654 let zero_bias = vec![0f32; max_h];
7655
7656 let max_sdpa = thunks
7659 .iter()
7660 .filter_map(|t| match t {
7661 Thunk::Attention {
7662 batch,
7663 seq,
7664 kv_seq,
7665 heads,
7666 head_dim,
7667 ..
7668 } => Some((
7669 *batch as usize,
7670 (*seq as usize).max(*kv_seq as usize),
7671 *heads as usize,
7672 *head_dim as usize,
7673 )),
7674 _ => None,
7675 })
7676 .fold((0, 0, 0, 0), |(mb, ms, mh, md), (b, s, h, d)| {
7677 (mb.max(b), ms.max(s), mh.max(h), md.max(d))
7678 });
7679 let (max_batch, max_seq, max_heads, _max_dh) = max_sdpa;
7680 let max_units = max_batch * max_heads;
7681 let mut sdpa_scores = vec![0f32; max_units * max_seq * max_seq];
7682
7683 let fl = thunks
7685 .iter()
7686 .filter_map(|t| match t {
7687 Thunk::FusedBertLayer {
7688 batch,
7689 seq,
7690 hs,
7691 int_dim,
7692 ..
7693 } => {
7694 let m = (*batch as usize) * (*seq as usize);
7695 let h = *hs as usize;
7696 let id = *int_dim as usize;
7697 Some((m, h, id, m * (*seq as usize)))
7698 }
7699 Thunk::FusedNomicLayer {
7700 batch,
7701 seq,
7702 hs,
7703 int_dim,
7704 ..
7705 } => {
7706 let m = (*batch as usize) * (*seq as usize);
7707 let h = *hs as usize;
7708 let id = *int_dim as usize;
7709 Some((m, h, id, m * (*seq as usize)))
7710 }
7711 _ => None,
7712 })
7713 .fold((0, 0, 0, 0), |(mm, mh, mi, ms), (m, h, id, ss)| {
7714 (mm.max(m), mh.max(h), mi.max(id), ms.max(ss))
7715 });
7716 let (fl_m, fl_h, fl_int, fl_ss) = fl;
7717 let mut fl_qkv = vec![0f32; fl_m * 3 * fl_h];
7718 let mut fl_attn = vec![0f32; fl_m * fl_h];
7719 let mut fl_res = vec![0f32; fl_m * fl_h];
7720 let mut fl_normed = vec![0f32; fl_m * fl_h];
7721 let mut fl_ffn = vec![0f32; fl_m * fl_int.max(2 * fl_int)]; let mut fl_sc = vec![0f32; fl_ss.max(1)];
7723
7724 let trace_thunks = std::env::var_os("RLX_TRACE_THUNK").is_some();
7725 if trace_thunks {
7726 eprintln!(
7727 "[thunk] prealloc max_h={max_h} sdpa={} fl_m={fl_m} fl_h={fl_h} fl_int={fl_int}",
7728 max_units * max_seq * max_seq
7729 );
7730 }
7731 for i in 0..len {
7732 let thunk = unsafe { thunks.get_unchecked(i) };
7733 if trace_thunks && (i < 120 || i % 200 == 0 || i + 1 == len) {
7734 eprintln!("[thunk {i}/{len}] {}", thunk_kind_name(thunk));
7735 }
7736 let trace_done = trace_thunks && i < 120;
7737 match thunk {
7738 Thunk::Nop => {}
7739
7740 Thunk::GaussianSplatRender {
7741 positions_off,
7742 positions_len,
7743 scales_off,
7744 scales_len,
7745 rotations_off,
7746 rotations_len,
7747 opacities_off,
7748 opacities_len,
7749 colors_off,
7750 colors_len,
7751 sh_coeffs_off,
7752 sh_coeffs_len,
7753 meta_off,
7754 dst_off,
7755 dst_len,
7756 width,
7757 height,
7758 tile_size,
7759 radius_scale,
7760 alpha_cutoff,
7761 max_splat_steps,
7762 transmittance_threshold,
7763 max_list_entries,
7764 } => unsafe {
7765 crate::splat::execute_gaussian_splat_render(
7766 *positions_off,
7767 *positions_len,
7768 *scales_off,
7769 *scales_len,
7770 *rotations_off,
7771 *rotations_len,
7772 *opacities_off,
7773 *opacities_len,
7774 *colors_off,
7775 *colors_len,
7776 *sh_coeffs_off,
7777 *sh_coeffs_len,
7778 *meta_off,
7779 *dst_off,
7780 *dst_len,
7781 *width,
7782 *height,
7783 *tile_size,
7784 *radius_scale,
7785 *alpha_cutoff,
7786 *max_splat_steps,
7787 *transmittance_threshold,
7788 *max_list_entries,
7789 base,
7790 );
7791 },
7792
7793 Thunk::GaussianSplatRenderBackward {
7794 positions_off,
7795 positions_len,
7796 scales_off,
7797 scales_len,
7798 rotations_off,
7799 rotations_len,
7800 opacities_off,
7801 opacities_len,
7802 colors_off,
7803 colors_len,
7804 sh_coeffs_off,
7805 sh_coeffs_len,
7806 meta_off,
7807 d_loss_off,
7808 d_loss_len,
7809 packed_off,
7810 packed_len,
7811 width,
7812 height,
7813 tile_size,
7814 radius_scale,
7815 alpha_cutoff,
7816 max_splat_steps,
7817 transmittance_threshold,
7818 max_list_entries,
7819 loss_grad_clip,
7820 sh_band,
7821 max_anisotropy,
7822 } => unsafe {
7823 crate::splat::execute_gaussian_splat_render_backward(
7824 *positions_off,
7825 *positions_len,
7826 *scales_off,
7827 *scales_len,
7828 *rotations_off,
7829 *rotations_len,
7830 *opacities_off,
7831 *opacities_len,
7832 *colors_off,
7833 *colors_len,
7834 *sh_coeffs_off,
7835 *sh_coeffs_len,
7836 *meta_off,
7837 *d_loss_off,
7838 *d_loss_len,
7839 *packed_off,
7840 *packed_len,
7841 *width,
7842 *height,
7843 *tile_size,
7844 *radius_scale,
7845 *alpha_cutoff,
7846 *max_splat_steps,
7847 *transmittance_threshold,
7848 *max_list_entries,
7849 *loss_grad_clip,
7850 *sh_band,
7851 *max_anisotropy,
7852 base,
7853 );
7854 },
7855
7856 Thunk::GaussianSplatPrepare {
7857 positions_off,
7858 positions_len,
7859 scales_off,
7860 scales_len,
7861 rotations_off,
7862 rotations_len,
7863 opacities_off,
7864 opacities_len,
7865 colors_off,
7866 colors_len,
7867 sh_coeffs_off,
7868 sh_coeffs_len,
7869 meta_off,
7870 meta_len,
7871 prep_off,
7872 prep_len,
7873 width,
7874 height,
7875 tile_size,
7876 radius_scale,
7877 alpha_cutoff,
7878 max_splat_steps,
7879 transmittance_threshold,
7880 max_list_entries,
7881 } => unsafe {
7882 crate::splat::execute_gaussian_splat_prepare(
7883 *positions_off,
7884 *positions_len,
7885 *scales_off,
7886 *scales_len,
7887 *rotations_off,
7888 *rotations_len,
7889 *opacities_off,
7890 *opacities_len,
7891 *colors_off,
7892 *colors_len,
7893 *sh_coeffs_off,
7894 *sh_coeffs_len,
7895 *meta_off,
7896 *meta_len,
7897 *prep_off,
7898 *prep_len,
7899 *width,
7900 *height,
7901 *tile_size,
7902 *radius_scale,
7903 *alpha_cutoff,
7904 *max_splat_steps,
7905 *transmittance_threshold,
7906 *max_list_entries,
7907 base,
7908 );
7909 },
7910
7911 Thunk::GaussianSplatRasterize {
7912 prep_off,
7913 prep_len,
7914 meta_off,
7915 meta_len,
7916 dst_off,
7917 dst_len,
7918 count,
7919 width,
7920 height,
7921 tile_size,
7922 alpha_cutoff,
7923 max_splat_steps,
7924 transmittance_threshold,
7925 max_list_entries,
7926 } => unsafe {
7927 crate::splat::execute_gaussian_splat_rasterize(
7928 *prep_off,
7929 *prep_len,
7930 *meta_off,
7931 *meta_len,
7932 *dst_off,
7933 *dst_len,
7934 *count,
7935 *width,
7936 *height,
7937 *tile_size,
7938 *alpha_cutoff,
7939 *max_splat_steps,
7940 *transmittance_threshold,
7941 *max_list_entries,
7942 base,
7943 );
7944 },
7945
7946 Thunk::Fft1d {
7947 src,
7948 dst,
7949 outer,
7950 n_complex,
7951 inverse,
7952 norm_tag,
7953 dtype,
7954 } => unsafe {
7955 match dtype {
7956 rlx_ir::DType::F64 => execute_fft1d_f64(
7957 *src,
7958 *dst,
7959 *outer as usize,
7960 *n_complex as usize,
7961 *inverse,
7962 *norm_tag,
7963 base,
7964 ),
7965 rlx_ir::DType::F32 => execute_fft1d_f32(
7966 *src,
7967 *dst,
7968 *outer as usize,
7969 *n_complex as usize,
7970 *inverse,
7971 *norm_tag,
7972 base,
7973 ),
7974 rlx_ir::DType::C64 => execute_fft1d_c64(
7975 *src,
7976 *dst,
7977 *outer as usize,
7978 *n_complex as usize,
7979 *inverse,
7980 *norm_tag,
7981 base,
7982 ),
7983 other => panic!("Op::Fft on CPU requires F32/F64/C64, got {other:?}"),
7984 }
7985 },
7986
7987 Thunk::FftButterflyStage {
7988 state_src,
7989 state_dst,
7990 gate_src,
7991 rev_src,
7992 tw_re_src,
7993 tw_im_src,
7994 batch,
7995 n_fft,
7996 stage,
7997 } => unsafe {
7998 execute_fft_butterfly_stage_f32(
7999 *state_src,
8000 *state_dst,
8001 *gate_src,
8002 *rev_src,
8003 *tw_re_src,
8004 *tw_im_src,
8005 *batch as usize,
8006 *n_fft as usize,
8007 *stage as usize,
8008 base,
8009 );
8010 },
8011
8012 Thunk::LogMel {
8013 spec,
8014 filters,
8015 dst,
8016 outer,
8017 n_fft,
8018 n_bins,
8019 n_mels,
8020 } => unsafe {
8021 execute_log_mel_f32(
8022 *spec,
8023 *filters,
8024 *dst,
8025 *outer as usize,
8026 *n_fft as usize,
8027 *n_bins as usize,
8028 *n_mels as usize,
8029 base,
8030 );
8031 },
8032
8033 Thunk::LogMelBackward {
8034 spec,
8035 filters,
8036 dy,
8037 dst,
8038 outer,
8039 n_fft,
8040 n_bins,
8041 n_mels,
8042 } => unsafe {
8043 execute_log_mel_backward_f32(
8044 *spec,
8045 *filters,
8046 *dy,
8047 *dst,
8048 *outer as usize,
8049 *n_fft as usize,
8050 *n_bins as usize,
8051 *n_mels as usize,
8052 base,
8053 );
8054 },
8055
8056 Thunk::WelchPeaks {
8057 spec,
8058 dst,
8059 welch_batch,
8060 n_fft,
8061 n_segments,
8062 k,
8063 } => unsafe {
8064 execute_welch_peaks_f32(
8065 *spec,
8066 *dst,
8067 *welch_batch as usize,
8068 *n_fft as usize,
8069 *n_segments as usize,
8070 *k as usize,
8071 base,
8072 );
8073 },
8074
8075 Thunk::CustomFn {
8079 body,
8080 body_init,
8081 inputs,
8082 body_output_off,
8083 outer_output_off,
8084 out_bytes,
8085 } => {
8086 let mut body_buf: Vec<u8> = (**body_init).clone();
8087 unsafe {
8088 for (body_in_off, outer_in_off, n_bytes) in inputs.iter() {
8089 let src = (base as *const u8).add(*outer_in_off);
8090 let dst = body_buf.as_mut_ptr().add(*body_in_off);
8091 std::ptr::copy_nonoverlapping(src, dst, *n_bytes as usize);
8092 }
8093 }
8094 execute_thunks(body, &mut body_buf);
8095 unsafe {
8096 let src = body_buf.as_ptr().add(*body_output_off);
8097 let dst = base.add(*outer_output_off);
8098 std::ptr::copy_nonoverlapping(src, dst, *out_bytes as usize);
8099 }
8100 }
8101
8102 Thunk::Sgemm { a, b, c, m, k, n } => {
8103 let (m, k, n) = (*m as usize, *k as usize, *n as usize);
8104 if trace_thunks {
8105 eprintln!("[sgemm] m={m} k={k} n={n} a={} b={} c={}", *a, *b, *c);
8106 }
8107 let c_len = m.saturating_mul(n);
8108 let a_len = m.saturating_mul(k);
8109 let b_len = k.saturating_mul(n);
8110 let arena_len = arena_buf.len();
8111 let max_a = (arena_len.saturating_sub(*a)) / 4;
8112 let max_b = (arena_len.saturating_sub(*b)) / 4;
8113 let max_c = (arena_len.saturating_sub(*c)) / 4;
8114 let a_len = a_len.min(max_a);
8115 let b_len = b_len.min(max_b);
8116 let c_len = c_len.min(max_c);
8117 unsafe {
8118 let a_sl = sl(*a, base, a_len);
8119 let b_sl = sl(*b, base, b_len);
8120 let c_sl = sl_mut(*c, base, c_len);
8121 if std::ptr::eq(a_sl.as_ptr(), c_sl.as_ptr())
8122 || std::ptr::eq(b_sl.as_ptr(), c_sl.as_ptr())
8123 {
8124 let mut tmp = vec![0.0f32; c_len];
8125 crate::blas::sgemm_auto(a_sl, b_sl, &mut tmp, m, k, n);
8126 c_sl.copy_from_slice(&tmp);
8127 } else {
8128 crate::blas::sgemm_auto(a_sl, b_sl, c_sl, m, k, n);
8129 }
8130 }
8131 }
8132
8133 Thunk::DenseSolveF64 { a, b, x, n, nrhs } => {
8134 let (n_, nrhs_) = (*n as usize, *nrhs as usize);
8135 unsafe {
8141 let a_src = sl_f64(*a, base, n_ * n_);
8142 let b_src = sl_f64(*b, base, n_ * nrhs_);
8143 let mut a_scratch: Vec<f64> = a_src.to_vec();
8144 let mut x_buf: Vec<f64> = b_src.to_vec();
8145 let info = crate::blas::dgesv(&mut a_scratch, &mut x_buf, n_, nrhs_);
8146 if info != 0 {
8147 panic!(
8148 "DenseSolveF64: dgesv reported singular matrix \
8149 (info={info}, n={n_}, nrhs={nrhs_})"
8150 );
8151 }
8152 let dst = sl_mut_f64(*x, base, n_ * nrhs_);
8153 dst.copy_from_slice(&x_buf);
8154 }
8155 }
8156
8157 Thunk::DenseSolveF32 { a, b, x, n, nrhs } => {
8158 let (n_, nrhs_) = (*n as usize, *nrhs as usize);
8159 unsafe {
8160 let a_src = sl(*a, base, n_ * n_);
8161 let b_src = sl(*b, base, n_ * nrhs_);
8162 let mut a_scratch: Vec<f32> = a_src.to_vec();
8163 let mut x_buf: Vec<f32> = b_src.to_vec();
8164 let info = crate::blas::sgesv(&mut a_scratch, &mut x_buf, n_, nrhs_);
8165 if info != 0 {
8166 panic!(
8167 "DenseSolveF32: sgesv reported singular matrix \
8168 (info={info}, n={n_}, nrhs={nrhs_})"
8169 );
8170 }
8171 let dst = sl_mut(*x, base, n_ * nrhs_);
8172 dst.copy_from_slice(&x_buf);
8173 }
8174 }
8175
8176 Thunk::BatchedDenseSolveF64 {
8177 a,
8178 b,
8179 x,
8180 batch,
8181 n,
8182 nrhs,
8183 } => {
8184 let (b_, n_, nrhs_) = (*batch as usize, *n as usize, *nrhs as usize);
8191 let a_stride = n_ * n_;
8192 let b_stride = n_ * nrhs_;
8193 unsafe {
8194 let a_full = sl_f64(*a, base, b_ * a_stride);
8195 let b_full = sl_f64(*b, base, b_ * b_stride);
8196 let x_full = sl_mut_f64(*x, base, b_ * b_stride);
8197 for bi in 0..b_ {
8198 let mut a_scratch: Vec<f64> =
8199 a_full[bi * a_stride..(bi + 1) * a_stride].to_vec();
8200 let mut x_buf: Vec<f64> =
8201 b_full[bi * b_stride..(bi + 1) * b_stride].to_vec();
8202 let info = crate::blas::dgesv(&mut a_scratch, &mut x_buf, n_, nrhs_);
8203 if info != 0 {
8204 panic!(
8205 "BatchedDenseSolveF64: slice {bi} \
8206 singular (info={info}, n={n_}, nrhs={nrhs_})"
8207 );
8208 }
8209 x_full[bi * b_stride..(bi + 1) * b_stride].copy_from_slice(&x_buf);
8210 }
8211 }
8212 }
8213
8214 Thunk::BatchedDenseSolveF32 {
8215 a,
8216 b,
8217 x,
8218 batch,
8219 n,
8220 nrhs,
8221 } => {
8222 let (b_, n_, nrhs_) = (*batch as usize, *n as usize, *nrhs as usize);
8223 let a_stride = n_ * n_;
8224 let b_stride = n_ * nrhs_;
8225 unsafe {
8226 let a_full = sl(*a, base, b_ * a_stride);
8227 let b_full = sl(*b, base, b_ * b_stride);
8228 let x_full = sl_mut(*x, base, b_ * b_stride);
8229 for bi in 0..b_ {
8230 let mut a_scratch = a_full[bi * a_stride..(bi + 1) * a_stride].to_vec();
8231 let mut x_buf = b_full[bi * b_stride..(bi + 1) * b_stride].to_vec();
8232 let info = crate::blas::sgesv(&mut a_scratch, &mut x_buf, n_, nrhs_);
8233 if info != 0 {
8234 panic!("BatchedDenseSolveF32: slice {bi} singular (info={info})");
8235 }
8236 x_full[bi * b_stride..(bi + 1) * b_stride].copy_from_slice(&x_buf);
8237 }
8238 }
8239 }
8240
8241 Thunk::BatchedDgemmF64 {
8242 a,
8243 b,
8244 c,
8245 batch,
8246 m,
8247 k,
8248 n,
8249 } => {
8250 let (b_, m_, k_, n_) = (*batch as usize, *m as usize, *k as usize, *n as usize);
8251 let a_stride = m_ * k_;
8252 let b_stride = k_ * n_;
8253 let c_stride = m_ * n_;
8254 unsafe {
8255 let a_full = sl_f64(*a, base, b_ * a_stride);
8256 let b_full = sl_f64(*b, base, b_ * b_stride);
8257 let c_full = sl_mut_f64(*c, base, b_ * c_stride);
8258 for bi in 0..b_ {
8259 let a_slice = &a_full[bi * a_stride..(bi + 1) * a_stride];
8260 let b_slice = &b_full[bi * b_stride..(bi + 1) * b_stride];
8261 let c_slice = &mut c_full[bi * c_stride..(bi + 1) * c_stride];
8262 crate::blas::dgemm(a_slice, b_slice, c_slice, m_, k_, n_);
8263 }
8264 }
8265 }
8266
8267 Thunk::BatchedSgemm {
8268 a,
8269 b,
8270 c,
8271 batch,
8272 m,
8273 k,
8274 n,
8275 } => {
8276 let (b_, m_, k_, n_) = (*batch as usize, *m as usize, *k as usize, *n as usize);
8277 if trace_thunks {
8278 eprintln!(
8279 "[batched-sgemm] batch={b_} m={m_} k={k_} n={n_} a={} b={} c={}",
8280 *a, *b, *c
8281 );
8282 }
8283 let a_stride = m_.saturating_mul(k_);
8284 let b_stride = k_.saturating_mul(n_);
8285 let c_stride = m_.saturating_mul(n_);
8286 let arena_len = arena_buf.len();
8287 let a_cap = (arena_len.saturating_sub(*a)) / 4;
8288 let b_cap = (arena_len.saturating_sub(*b)) / 4;
8289 let c_cap = (arena_len.saturating_sub(*c)) / 4;
8290 let a_elems = (b_ * a_stride).min(a_cap);
8291 let b_elems = (b_ * b_stride).min(b_cap);
8292 let c_elems = (b_ * c_stride).min(c_cap);
8293 let b_eff = b_
8294 .min(a_elems.checked_div(a_stride).unwrap_or(0))
8295 .min(b_elems.checked_div(b_stride).unwrap_or(0))
8296 .min(c_elems.checked_div(c_stride).unwrap_or(0));
8297 unsafe {
8298 let a_full = sl(*a, base, a_elems);
8299 let b_full = sl(*b, base, b_elems);
8300 let c_full = sl_mut(*c, base, c_elems);
8301 for bi in 0..b_eff {
8302 let a0 = bi * a_stride;
8303 let b0 = bi * b_stride;
8304 let c0 = bi * c_stride;
8305 if a0 + a_stride > a_full.len()
8306 || b0 + b_stride > b_full.len()
8307 || c0 + c_stride > c_full.len()
8308 {
8309 break;
8310 }
8311 let a_slice = &a_full[a0..a0 + a_stride];
8312 let b_slice = &b_full[b0..b0 + b_stride];
8313 let c_slice = &mut c_full[c0..c0 + c_stride];
8314 if std::ptr::eq(a_slice.as_ptr(), c_slice.as_mut_ptr())
8315 || std::ptr::eq(b_slice.as_ptr(), c_slice.as_mut_ptr())
8316 {
8317 let mut tmp = vec![0.0f32; c_stride];
8318 crate::blas::sgemm_auto(a_slice, b_slice, &mut tmp, m_, k_, n_);
8319 c_slice.copy_from_slice(&tmp);
8320 } else {
8321 crate::blas::sgemm_auto(a_slice, b_slice, c_slice, m_, k_, n_);
8322 }
8323 }
8324 }
8325 }
8326
8327 Thunk::Dgemm { a, b, c, m, k, n } => {
8328 let (m, k, n) = (*m as usize, *k as usize, *n as usize);
8329 unsafe {
8330 crate::blas::dgemm(
8331 sl_f64(*a, base, m * k),
8332 sl_f64(*b, base, k * n),
8333 sl_mut_f64(*c, base, m * n),
8334 m,
8335 k,
8336 n,
8337 );
8338 }
8339 }
8340
8341 Thunk::TransposeF64 {
8342 src,
8343 dst,
8344 in_total,
8345 out_dims,
8346 in_strides,
8347 } => unsafe {
8348 let inp = sl_f64(*src, base, *in_total as usize);
8349 let out_total: usize = out_dims.iter().map(|d| *d as usize).product();
8350 let out = sl_mut_f64(*dst, base, out_total);
8351 transpose_walk_f64(inp, out, out_dims, in_strides);
8352 },
8353
8354 Thunk::ActivationF64 {
8355 src,
8356 dst,
8357 len,
8358 kind,
8359 } => {
8360 let len = *len as usize;
8361 unsafe {
8362 let inp = sl_f64(*src, base, len);
8363 let out = sl_mut_f64(*dst, base, len);
8364 apply_activation_f64(inp, out, *kind);
8365 }
8366 }
8367
8368 Thunk::ReduceSumF64 {
8369 src,
8370 dst,
8371 outer,
8372 reduced,
8373 inner,
8374 } => {
8375 let (o, r, n) = (*outer as usize, *reduced as usize, *inner as usize);
8376 unsafe {
8377 let inp = sl_f64(*src, base, o * r * n);
8378 let out = sl_mut_f64(*dst, base, o * n);
8379 reduce_sum_f64(inp, out, o, r, n);
8380 }
8381 }
8382
8383 Thunk::CopyF64 { src, dst, len } => {
8384 let mut len = *len as usize;
8385 if *src == *dst || len == 0 {
8386 continue;
8387 }
8388 let arena_len = arena_buf.len();
8389 let max_from_src = (arena_len.saturating_sub(*src)) / 8;
8390 let max_from_dst = (arena_len.saturating_sub(*dst)) / 8;
8391 len = len.min(max_from_src).min(max_from_dst);
8392 if len == 0 {
8393 continue;
8394 }
8395 let byte_len = len.saturating_mul(8);
8396 unsafe {
8397 std::ptr::copy(base.add(*src), base.add(*dst), byte_len);
8398 }
8399 }
8400
8401 Thunk::CopyI64 { src, dst, len } => {
8402 let mut len = *len as usize;
8403 if *src == *dst || len == 0 {
8404 continue;
8405 }
8406 let arena_len = arena_buf.len();
8407 let max_from_src = (arena_len.saturating_sub(*src)) / 8;
8408 let max_from_dst = (arena_len.saturating_sub(*dst)) / 8;
8409 len = len.min(max_from_src).min(max_from_dst);
8410 if len == 0 {
8411 continue;
8412 }
8413 let byte_len = len.saturating_mul(8);
8414 unsafe {
8415 std::ptr::copy(base.add(*src), base.add(*dst), byte_len);
8416 }
8417 }
8418
8419 Thunk::CastF32ToI64 { src, dst, len } => {
8420 let len = *len as usize;
8421 if len == 0 {
8422 continue;
8423 }
8424 unsafe {
8425 let inp = sl(*src, base, len);
8426 let out = sl_mut_i64(*dst, base, len);
8427 for i in 0..len {
8428 out[i] = inp[i].round() as i64;
8429 }
8430 }
8431 }
8432
8433 Thunk::CastI64ToF32 { src, dst, len } => {
8434 let len = *len as usize;
8435 if len == 0 {
8436 continue;
8437 }
8438 unsafe {
8439 let inp = sl_i64(*src, base, len);
8440 let out = sl_mut(*dst, base, len);
8441 for i in 0..len {
8442 out[i] = inp[i] as f32;
8443 }
8444 }
8445 }
8446
8447 Thunk::CastBoolToI32 { src, dst, len } => {
8448 let len = *len as usize;
8449 if len == 0 {
8450 continue;
8451 }
8452 unsafe {
8453 let inp = &arena_buf[*src..*src + len];
8454 let out = sl_mut_i32(*dst, base, len);
8455 for i in 0..len {
8456 out[i] = i32::from(inp[i] != 0);
8457 }
8458 }
8459 }
8460
8461 Thunk::CastI32ToF32 { src, dst, len } => {
8462 let len = *len as usize;
8463 if len == 0 {
8464 continue;
8465 }
8466 unsafe {
8467 let inp = sl_i32(*src, base, len);
8468 let out = sl_mut(*dst, base, len);
8469 for i in 0..len {
8470 out[i] = inp[i] as f32;
8471 }
8472 }
8473 }
8474
8475 Thunk::BinaryFullF64 {
8476 lhs,
8477 rhs,
8478 dst,
8479 len,
8480 lhs_len,
8481 rhs_len,
8482 op,
8483 out_dims_bcast,
8484 bcast_lhs_strides,
8485 bcast_rhs_strides,
8486 } => {
8487 let len = *len as usize;
8488 let lhs_len = *lhs_len as usize;
8489 let rhs_len = *rhs_len as usize;
8490 unsafe {
8491 let l = sl_f64(*lhs, base, lhs_len);
8492 let r = sl_f64(*rhs, base, rhs_len);
8493 let d = sl_mut_f64(*dst, base, len);
8494 if lhs_len == len && rhs_len == len {
8495 for i in 0..len {
8496 d[i] = binary_op_f64(*op, l[i], r[i]);
8497 }
8498 } else if !out_dims_bcast.is_empty() {
8499 let rank = out_dims_bcast.len();
8503 let mut coords = vec![0u32; rank];
8504 for i in 0..len {
8505 let mut rem = i;
8506 for ax in (0..rank).rev() {
8507 let sz = out_dims_bcast[ax] as usize;
8508 coords[ax] = (rem % sz) as u32;
8509 rem /= sz;
8510 }
8511 let mut li: usize = 0;
8512 let mut ri: usize = 0;
8513 for ax in 0..rank {
8514 li += coords[ax] as usize * bcast_lhs_strides[ax] as usize;
8515 ri += coords[ax] as usize * bcast_rhs_strides[ax] as usize;
8516 }
8517 d[i] = binary_op_f64(*op, l[li], r[ri]);
8518 }
8519 } else {
8520 for i in 0..len {
8525 d[i] = binary_op_f64(*op, l[i % lhs_len], r[i % rhs_len]);
8526 }
8527 }
8528 }
8529 }
8530
8531 Thunk::BinaryFullC64 {
8532 lhs,
8533 rhs,
8534 dst,
8535 len,
8536 lhs_len,
8537 rhs_len,
8538 op,
8539 out_dims_bcast,
8540 bcast_lhs_strides,
8541 bcast_rhs_strides,
8542 } => {
8543 let n_out = *len as usize;
8549 let n_l = *lhs_len as usize;
8550 let n_r = *rhs_len as usize;
8551 unsafe {
8552 let l = sl(*lhs, base, 2 * n_l);
8553 let r = sl(*rhs, base, 2 * n_r);
8554 let d = sl_mut(*dst, base, 2 * n_out);
8555 let do_c64 = |a_re: f32, a_im: f32, b_re: f32, b_im: f32| -> (f32, f32) {
8556 match op {
8557 BinaryOp::Add => (a_re + b_re, a_im + b_im),
8558 BinaryOp::Sub => (a_re - b_re, a_im - b_im),
8559 BinaryOp::Mul => (a_re * b_re - a_im * b_im, a_re * b_im + a_im * b_re),
8560 BinaryOp::Div => {
8561 let denom = b_re * b_re + b_im * b_im;
8562 (
8563 (a_re * b_re + a_im * b_im) / denom,
8564 (a_im * b_re - a_re * b_im) / denom,
8565 )
8566 }
8567 BinaryOp::Max | BinaryOp::Min | BinaryOp::Pow => {
8568 unreachable!("C64 max/min/pow rejected at lowering")
8569 }
8570 }
8571 };
8572 if n_l == n_out && n_r == n_out {
8573 for i in 0..n_out {
8574 let (re, im) = do_c64(l[2 * i], l[2 * i + 1], r[2 * i], r[2 * i + 1]);
8575 d[2 * i] = re;
8576 d[2 * i + 1] = im;
8577 }
8578 } else if !out_dims_bcast.is_empty() {
8579 let rank = out_dims_bcast.len();
8583 let mut coords = vec![0u32; rank];
8584 for i in 0..n_out {
8585 let mut rem = i;
8586 for ax in (0..rank).rev() {
8587 let sz = out_dims_bcast[ax] as usize;
8588 coords[ax] = (rem % sz) as u32;
8589 rem /= sz;
8590 }
8591 let mut li: usize = 0;
8592 let mut ri: usize = 0;
8593 for ax in 0..rank {
8594 li += coords[ax] as usize * bcast_lhs_strides[ax] as usize;
8595 ri += coords[ax] as usize * bcast_rhs_strides[ax] as usize;
8596 }
8597 let (re, im) =
8598 do_c64(l[2 * li], l[2 * li + 1], r[2 * ri], r[2 * ri + 1]);
8599 d[2 * i] = re;
8600 d[2 * i + 1] = im;
8601 }
8602 } else {
8603 for i in 0..n_out {
8605 let li = if n_l == 1 { 0 } else { i % n_l };
8606 let ri = if n_r == 1 { 0 } else { i % n_r };
8607 let (re, im) =
8608 do_c64(l[2 * li], l[2 * li + 1], r[2 * ri], r[2 * ri + 1]);
8609 d[2 * i] = re;
8610 d[2 * i + 1] = im;
8611 }
8612 }
8613 }
8614 }
8615
8616 Thunk::ComplexNormSqF32 { src, dst, len } => {
8617 let n = *len as usize;
8618 unsafe {
8619 let s = sl(*src, base, 2 * n);
8620 let d = sl_mut(*dst, base, n);
8621 for i in 0..n {
8622 let re = s[2 * i];
8623 let im = s[2 * i + 1];
8624 d[i] = re * re + im * im;
8625 }
8626 }
8627 }
8628
8629 Thunk::ComplexNormSqBackwardF32 { z, g, dz, len } => {
8630 let n = *len as usize;
8633 unsafe {
8634 let zb = sl(*z, base, 2 * n);
8635 let gb = sl(*g, base, n);
8636 let db = sl_mut(*dz, base, 2 * n);
8637 for i in 0..n {
8638 let re = zb[2 * i];
8639 let im = zb[2 * i + 1];
8640 let gv = gb[i];
8641 db[2 * i] = gv * re;
8642 db[2 * i + 1] = gv * im;
8643 }
8644 }
8645 }
8646
8647 Thunk::ConjugateC64 { src, dst, len } => {
8648 let n = *len as usize;
8649 unsafe {
8650 let s = sl(*src, base, 2 * n);
8651 let d = sl_mut(*dst, base, 2 * n);
8652 for i in 0..n {
8653 d[2 * i] = s[2 * i];
8654 d[2 * i + 1] = -s[2 * i + 1];
8655 }
8656 }
8657 }
8658
8659 Thunk::ActivationC64 {
8660 src,
8661 dst,
8662 len,
8663 kind,
8664 } => {
8665 let n = *len as usize;
8666 unsafe {
8667 let s = sl(*src, base, 2 * n);
8668 let d = sl_mut(*dst, base, 2 * n);
8669 for i in 0..n {
8670 let a = s[2 * i];
8671 let b = s[2 * i + 1];
8672 let (re, im) = match kind {
8673 Activation::Neg => (-a, -b),
8674 Activation::Exp => {
8675 let ea = a.exp();
8677 (ea * b.cos(), ea * b.sin())
8678 }
8679 Activation::Log => {
8680 let r = (a * a + b * b).sqrt();
8682 (r.ln(), b.atan2(a))
8683 }
8684 Activation::Sqrt => {
8685 let r = (a * a + b * b).sqrt();
8688 let re = ((r + a) * 0.5).max(0.0).sqrt();
8689 let im_mag = ((r - a) * 0.5).max(0.0).sqrt();
8690 let im = if b >= 0.0 { im_mag } else { -im_mag };
8691 (re, im)
8692 }
8693 _ => unreachable!("non-C64 activation kind survived lowering"),
8694 };
8695 d[2 * i] = re;
8696 d[2 * i + 1] = im;
8697 }
8698 }
8699 }
8700
8701 Thunk::Scan {
8702 body,
8703 body_init,
8704 body_input_off,
8705 body_output_off,
8706 outer_init_off,
8707 outer_final_off,
8708 length,
8709 carry_bytes,
8710 save_trajectory,
8711 xs_inputs,
8712 bcast_inputs,
8713 num_checkpoints,
8714 } => {
8715 let cb = *carry_bytes as usize;
8716 let n_steps = *length as usize;
8717 let k_total = if *num_checkpoints == 0 || *num_checkpoints == *length {
8721 n_steps } else {
8723 *num_checkpoints as usize
8724 };
8725 let checkpoint_t_for_k = |k: usize| -> usize {
8726 if k_total == n_steps {
8727 k
8728 } else {
8729 ((k + 1) * n_steps)
8730 .div_ceil(k_total)
8731 .saturating_sub(1)
8732 .min(n_steps - 1)
8733 }
8734 };
8735 let mut next_k = 0usize;
8736
8737 let mut body_buf: Vec<u8> = (**body_init).clone();
8738 unsafe {
8739 std::ptr::copy_nonoverlapping(
8740 base.add(*outer_init_off),
8741 body_buf.as_mut_ptr().add(*body_input_off),
8742 cb,
8743 );
8744 for (body_b_off, outer_b_off, total_bytes) in bcast_inputs.iter() {
8748 std::ptr::copy_nonoverlapping(
8749 base.add(*outer_b_off),
8750 body_buf.as_mut_ptr().add(*body_b_off),
8751 *total_bytes as usize,
8752 );
8753 }
8754 }
8755 for t in 0..n_steps {
8756 for (body_x_off, outer_xs_off, per_step_bytes) in xs_inputs.iter() {
8757 let psb = *per_step_bytes as usize;
8758 unsafe {
8759 std::ptr::copy_nonoverlapping(
8760 base.add(*outer_xs_off + t * psb),
8761 body_buf.as_mut_ptr().add(*body_x_off),
8762 psb,
8763 );
8764 }
8765 }
8766
8767 execute_thunks(body, &mut body_buf);
8768
8769 if *save_trajectory && next_k < k_total && t == checkpoint_t_for_k(next_k) {
8770 unsafe {
8771 std::ptr::copy_nonoverlapping(
8772 body_buf.as_ptr().add(*body_output_off),
8773 base.add(*outer_final_off + next_k * cb),
8774 cb,
8775 );
8776 }
8777 next_k += 1;
8778 }
8779
8780 if *body_output_off != *body_input_off {
8781 body_buf
8782 .copy_within(*body_output_off..*body_output_off + cb, *body_input_off);
8783 }
8784 }
8785
8786 if !*save_trajectory {
8787 unsafe {
8789 std::ptr::copy_nonoverlapping(
8790 body_buf.as_ptr().add(*body_output_off),
8791 base.add(*outer_final_off),
8792 cb,
8793 );
8794 }
8795 }
8796 }
8797
// Reverse pass of a scan with respect to the initial carry.
//
// Walks the steps from last to first, running the `body_vjp` schedule once per
// step on a private copy of `body_init`, and threads the carry gradient
// (`dcarry`) backwards. The final `dcarry` is written to `outer_dinit_off`.
//
// The carry that entered step `t` is taken from the outer arena: the initial
// carry for `t == 0`, otherwise row `t - 1` of the saved trajectory. When only
// a subset of steps was checkpointed (`is_recursive`), the missing carries are
// regenerated per segment by `griewank_process_segment` using the forward body.
Thunk::ScanBackward {
    body_vjp,
    body_init,
    body_carry_in_off,
    body_x_offs,
    body_d_output_off,
    body_dcarry_out_off,
    outer_init_off,
    outer_traj_off,
    outer_upstream_off,
    outer_xs_offs,
    outer_dinit_off,
    length,
    carry_bytes,
    save_trajectory,
    num_checkpoints,
    forward_body,
    forward_body_init,
    forward_body_carry_in_off,
    forward_body_output_off,
    forward_body_x_offs,
    carry_elem_size,
} => {
    let cb = *carry_bytes as usize;
    let n_steps = *length as usize;
    let k_total = *num_checkpoints as usize;
    // Checkpointed mode: a non-zero checkpoint count that differs from the
    // step count, i.e. not every step's carry is available in the trajectory.
    let is_recursive = k_total != 0 && k_total != n_steps;
    // Step index at which the k-th checkpoint was taken. Only called on the
    // `is_recursive` path, where `k_total != 0`.
    let checkpoint_t_for_k = |k: usize| -> usize {
        ((k + 1) * n_steps)
            .div_ceil(k_total)
            .saturating_sub(1)
            .min(n_steps - 1)
    };

    // Working arena for forward recomputation; only needed when checkpointed.
    let mut fwd_buf: Vec<u8> = if is_recursive {
        (**forward_body_init.as_ref().unwrap()).clone()
    } else {
        Vec::new()
    };

    // Running gradient w.r.t. the carry. Without a trajectory the upstream
    // gradient is a single carry-sized block that seeds it directly; with a
    // trajectory it starts at zero and per-step upstream rows are added below.
    let mut dcarry: Vec<u8> = vec![0u8; cb];
    if !*save_trajectory {
        unsafe {
            std::ptr::copy_nonoverlapping(
                base.add(*outer_upstream_off),
                dcarry.as_mut_ptr(),
                cb,
            );
        }
    }

    // Private working copy of the VJP body arena.
    let mut body_buf: Vec<u8> = (**body_init).clone();

    // One backward step for iteration `t`, given the carry that entered it.
    let process_iter =
        |t: usize, carry_in: &[u8], dcarry: &mut Vec<u8>, body_buf: &mut Vec<u8>| {
            if *save_trajectory {
                unsafe {
                    // Accumulate upstream row `t` into the running gradient,
                    // element-wise in the carry's float width.
                    let up_off = *outer_upstream_off + t * cb;
                    match *carry_elem_size {
                        4 => {
                            let up_ptr = base.add(up_off) as *const f32;
                            let dc_ptr = dcarry.as_mut_ptr() as *mut f32;
                            let n_elems = cb / 4;
                            for i in 0..n_elems {
                                *dc_ptr.add(i) += *up_ptr.add(i);
                            }
                        }
                        8 => {
                            let up_ptr = base.add(up_off) as *const f64;
                            let dc_ptr = dcarry.as_mut_ptr() as *mut f64;
                            let n_elems = cb / 8;
                            for i in 0..n_elems {
                                *dc_ptr.add(i) += *up_ptr.add(i);
                            }
                        }
                        other => panic!(
                            "ScanBackward: unsupported carry elem size {other} \
                             (only f32/f64 carries are supported today)"
                        ),
                    }
                }
            }
            // Stage the VJP body inputs: carry-in, step-`t` xs slices, and the
            // current carry gradient.
            body_buf[*body_carry_in_off..*body_carry_in_off + cb]
                .copy_from_slice(carry_in);
            unsafe {
                for (i, body_x_off) in body_x_offs.iter().enumerate() {
                    let (outer_xs_off, per_step_bytes) = outer_xs_offs[i];
                    let psb = per_step_bytes as usize;
                    std::ptr::copy_nonoverlapping(
                        base.add(outer_xs_off + t * psb),
                        body_buf.as_mut_ptr().add(*body_x_off),
                        psb,
                    );
                }
                std::ptr::copy_nonoverlapping(
                    dcarry.as_ptr(),
                    body_buf.as_mut_ptr().add(*body_d_output_off),
                    cb,
                );
            }
            execute_thunks(body_vjp, body_buf);
            // The VJP's carry-gradient output becomes the running gradient for
            // the previous step.
            unsafe {
                std::ptr::copy_nonoverlapping(
                    body_buf.as_ptr().add(*body_dcarry_out_off),
                    dcarry.as_mut_ptr(),
                    cb,
                );
            }
        };

    if is_recursive {
        // Threshold handed to `griewank_process_segment`; its exact meaning is
        // defined by that helper.
        let leaf_threshold = 4usize;
        let fb_sched = forward_body.as_ref().unwrap();
        let fb_init = forward_body_init.as_ref().unwrap().as_slice();
        // Process checkpoint segments from the last one back to the first.
        let mut segment_end = n_steps - 1;
        for seg_k in (0..k_total).rev() {
            let segment_start = if seg_k == 0 {
                0
            } else {
                checkpoint_t_for_k(seg_k - 1) + 1
            };
            // Carry entering the segment: the initial carry for segment 0,
            // otherwise the previous checkpoint's trajectory row.
            let mut anchor: Vec<u8> = vec![0u8; cb];
            unsafe {
                let src = if seg_k == 0 {
                    base.add(*outer_init_off)
                } else {
                    base.add(*outer_traj_off + (seg_k - 1) * cb)
                };
                std::ptr::copy_nonoverlapping(src, anchor.as_mut_ptr(), cb);
            }
            let mut leaf_action = |t: usize, carry_in: &[u8]| {
                process_iter(t, carry_in, &mut dcarry, &mut body_buf);
            };
            unsafe {
                griewank_process_segment(
                    segment_start,
                    segment_end,
                    &anchor,
                    cb,
                    fb_sched,
                    fb_init,
                    *forward_body_carry_in_off,
                    *forward_body_output_off,
                    forward_body_x_offs,
                    base,
                    outer_xs_offs,
                    &mut fwd_buf,
                    leaf_threshold,
                    &mut leaf_action,
                );
            }
            if seg_k == 0 {
                break;
            }
            segment_end = segment_start - 1;
        }
    } else {
        // Every step's carry is available: walk the steps in reverse.
        let mut carry_buf: Vec<u8> = vec![0u8; cb];
        for t in (0..n_steps).rev() {
            unsafe {
                // Carry entering step `t` is the initial carry for t == 0,
                // otherwise trajectory row t - 1.
                let src = if t == 0 {
                    base.add(*outer_init_off)
                } else {
                    base.add(*outer_traj_off + (t - 1) * cb)
                };
                std::ptr::copy_nonoverlapping(src, carry_buf.as_mut_ptr(), cb);
            }
            process_iter(t, &carry_buf, &mut dcarry, &mut body_buf);
        }
    }

    // Publish the gradient w.r.t. the initial carry.
    unsafe {
        std::ptr::copy_nonoverlapping(dcarry.as_ptr(), base.add(*outer_dinit_off), cb);
    }
}
9000
// Reverse pass of a scan with respect to a per-step input (`xs`).
//
// Walks the steps from last to first, running `body_vjp` once per step. After
// each step the body's xs-gradient (`body_dxs_out_off`, `per_step_bytes` long)
// is written to row `t` at `outer_dxs_off`, and the carry gradient is threaded
// backwards through `dcarry`.
//
// The carry entering step `t` is read from the trajectory when every step was
// saved; in checkpointed mode it is recomputed by replaying the forward body
// over the containing segment, with the replayed carries cached per segment.
Thunk::ScanBackwardXs {
    body_vjp,
    body_init,
    body_carry_in_off,
    body_x_offs,
    body_d_output_off,
    body_dcarry_out_off,
    body_dxs_out_off,
    outer_init_off,
    outer_traj_off,
    outer_upstream_off,
    outer_xs_offs,
    outer_dxs_off,
    length,
    carry_bytes,
    carry_elem_size,
    per_step_bytes,
    save_trajectory,
    num_checkpoints,
    forward_body,
    forward_body_init,
    forward_body_carry_in_off,
    forward_body_output_off,
    forward_body_x_offs,
} => {
    let cb = *carry_bytes as usize;
    let psb = *per_step_bytes as usize;
    let n_steps = *length as usize;
    let k_total = *num_checkpoints as usize;
    // Checkpointed mode: a non-zero checkpoint count that differs from the
    // step count.
    let is_recursive = k_total != 0 && k_total != n_steps;
    // Step index at which the k-th checkpoint was taken. Only called on the
    // `is_recursive` path, where `k_total != 0`.
    let checkpoint_t_for_k = |k: usize| -> usize {
        ((k + 1) * n_steps)
            .div_ceil(k_total)
            .saturating_sub(1)
            .min(n_steps - 1)
    };

    // Working arena for forward replay; only needed when checkpointed.
    let mut fwd_buf: Vec<u8> = if is_recursive {
        (**forward_body_init.as_ref().unwrap()).clone()
    } else {
        Vec::new()
    };
    // Cache of replayed carries for one segment: entry `i` is the carry that
    // enters step `seg_start_t + i`. `usize::MAX` marks "nothing cached yet".
    let mut seg_cache: Vec<u8> = Vec::new();
    let mut seg_start_t: usize = usize::MAX;
    let mut seg_count: usize = 0;
    // Writes into `dst` the carry that entered step `t`.
    let recompute_carry_t =
        |t: usize,
         dst: &mut [u8],
         fwd_buf: &mut Vec<u8>,
         seg_cache: &mut Vec<u8>,
         seg_start_t: &mut usize,
         seg_count: &mut usize| {
            if !is_recursive {
                // Full trajectory: initial carry for t == 0, else row t - 1.
                unsafe {
                    let src = if t == 0 {
                        base.add(*outer_init_off)
                    } else {
                        base.add(*outer_traj_off + (t - 1) * cb)
                    };
                    std::ptr::copy_nonoverlapping(src, dst.as_mut_ptr(), cb);
                }
                return;
            }
            // Cache hit: `t` lies inside the segment replayed last time.
            if *seg_start_t != usize::MAX
                && t >= *seg_start_t
                && t < *seg_start_t + *seg_count
            {
                let off = (t - *seg_start_t) * cb;
                dst.copy_from_slice(&seg_cache[off..off + cb]);
                return;
            }
            // Cache miss: find the first checkpoint at or after `t`; that
            // identifies the segment containing `t`.
            let seg_k = (0..k_total)
                .find(|&k| t <= checkpoint_t_for_k(k))
                .unwrap_or(k_total - 1);
            // The segment starts right after the previous checkpoint (or at
            // step 0), and its entering carry is that checkpoint's row (or the
            // initial carry).
            let (anchor_t, anchor_ptr): (usize, *const u8) = if seg_k == 0 {
                (0, unsafe { base.add(*outer_init_off) as *const u8 })
            } else {
                let prev_ck = checkpoint_t_for_k(seg_k - 1);
                (prev_ck + 1, unsafe {
                    base.add(*outer_traj_off + (seg_k - 1) * cb) as *const u8
                })
            };
            let seg_end_t = checkpoint_t_for_k(seg_k);
            let seg_size = seg_end_t - anchor_t + 1;

            // Reset the forward arena and seed it with the anchor carry.
            fwd_buf.copy_from_slice(forward_body_init.as_ref().unwrap());
            unsafe {
                std::ptr::copy_nonoverlapping(
                    anchor_ptr,
                    fwd_buf.as_mut_ptr().add(*forward_body_carry_in_off),
                    cb,
                );
            }
            seg_cache.resize(seg_size * cb, 0u8);
            // Entry 0 is the anchor carry itself.
            seg_cache[0..cb].copy_from_slice(
                &fwd_buf[*forward_body_carry_in_off..*forward_body_carry_in_off + cb],
            );
            let fb_sched = forward_body.as_ref().unwrap();
            // Replay forward steps; after running step `anchor_t + i - 1` the
            // carry slot holds the carry entering step `anchor_t + i`.
            for i in 1..seg_size {
                let cur_iter = anchor_t + i - 1;
                for (idx, fb_x_off) in forward_body_x_offs.iter().enumerate() {
                    let (outer_xs_off, x_psb) = outer_xs_offs[idx];
                    let xb = x_psb as usize;
                    unsafe {
                        std::ptr::copy_nonoverlapping(
                            base.add(outer_xs_off + cur_iter * xb),
                            fwd_buf.as_mut_ptr().add(*fb_x_off),
                            xb,
                        );
                    }
                }
                execute_thunks(fb_sched, fwd_buf);
                // Move the step output into the carry slot for the next step.
                if *forward_body_output_off != *forward_body_carry_in_off {
                    fwd_buf.copy_within(
                        *forward_body_output_off..*forward_body_output_off + cb,
                        *forward_body_carry_in_off,
                    );
                }
                let cache_off = i * cb;
                seg_cache[cache_off..cache_off + cb].copy_from_slice(
                    &fwd_buf
                        [*forward_body_carry_in_off..*forward_body_carry_in_off + cb],
                );
            }
            *seg_start_t = anchor_t;
            *seg_count = seg_size;

            let off = (t - anchor_t) * cb;
            dst.copy_from_slice(&seg_cache[off..off + cb]);
        };

    // Running gradient w.r.t. the carry. Without a trajectory the upstream
    // gradient is a single carry-sized block that seeds it directly.
    let mut dcarry: Vec<u8> = vec![0u8; cb];
    if !*save_trajectory {
        unsafe {
            std::ptr::copy_nonoverlapping(
                base.add(*outer_upstream_off),
                dcarry.as_mut_ptr(),
                cb,
            );
        }
    }

    // Private working copy of the VJP body arena.
    let mut body_buf: Vec<u8> = (**body_init).clone();

    for t in (0..n_steps).rev() {
        if *save_trajectory {
            unsafe {
                // Accumulate upstream row `t` into the running gradient,
                // element-wise in the carry's float width.
                let up_off = *outer_upstream_off + t * cb;
                match *carry_elem_size {
                    4 => {
                        let up_ptr = base.add(up_off) as *const f32;
                        let dc_ptr = dcarry.as_mut_ptr() as *mut f32;
                        let n_elems = cb / 4;
                        for i in 0..n_elems {
                            *dc_ptr.add(i) += *up_ptr.add(i);
                        }
                    }
                    8 => {
                        let up_ptr = base.add(up_off) as *const f64;
                        let dc_ptr = dcarry.as_mut_ptr() as *mut f64;
                        let n_elems = cb / 8;
                        for i in 0..n_elems {
                            *dc_ptr.add(i) += *up_ptr.add(i);
                        }
                    }
                    other => panic!(
                        "ScanBackwardXs: unsupported carry elem size {other} \
                         (only f32/f64 carries are supported today)"
                    ),
                }
            }
        }

        // Stage the carry that entered step `t` directly into the body arena.
        let carry_dst_start = *body_carry_in_off;
        {
            let carry_slice = &mut body_buf[carry_dst_start..carry_dst_start + cb];
            recompute_carry_t(
                t,
                carry_slice,
                &mut fwd_buf,
                &mut seg_cache,
                &mut seg_start_t,
                &mut seg_count,
            );
        }
        // Stage the step-`t` xs slices and the current carry gradient.
        unsafe {
            for (i, body_x_off) in body_x_offs.iter().enumerate() {
                let (outer_xs_off, x_psb) = outer_xs_offs[i];
                let xb = x_psb as usize;
                std::ptr::copy_nonoverlapping(
                    base.add(outer_xs_off + t * xb),
                    body_buf.as_mut_ptr().add(*body_x_off),
                    xb,
                );
            }
            std::ptr::copy_nonoverlapping(
                dcarry.as_ptr(),
                body_buf.as_mut_ptr().add(*body_d_output_off),
                cb,
            );
        }

        execute_thunks(body_vjp, &mut body_buf);

        // Publish this step's xs gradient to row `t` of the outer output.
        unsafe {
            std::ptr::copy_nonoverlapping(
                body_buf.as_ptr().add(*body_dxs_out_off),
                base.add(*outer_dxs_off + t * psb),
                psb,
            );
        }

        // The VJP's carry-gradient output feeds the previous step.
        unsafe {
            std::ptr::copy_nonoverlapping(
                body_buf.as_ptr().add(*body_dcarry_out_off),
                dcarry.as_mut_ptr(),
                cb,
            );
        }
    }
}
9231
9232 Thunk::FusedMmBiasAct {
9233 a,
9234 w,
9235 bias,
9236 c,
9237 m,
9238 k,
9239 n,
9240 act,
9241 } => {
9242 let (m, k, n) = (*m as usize, *k as usize, *n as usize);
9243 unsafe {
9244 let out = sl_mut(*c, base, m * n);
9245 crate::blas::sgemm_auto(sl(*a, base, m * k), sl(*w, base, k * n), out, m, k, n);
9246 match act {
9247 Some(Activation::Gelu) => {
9248 crate::kernels::par_bias_gelu(out, sl(*bias, base, n), m, n)
9249 }
9250 Some(other) => {
9251 crate::blas::bias_add(out, sl(*bias, base, n), m, n);
9252 apply_activation_inplace(out, *other);
9253 }
9254 None => crate::blas::bias_add(out, sl(*bias, base, n), m, n),
9255 }
9256 }
9257 }
9258
9259 Thunk::FusedResidualLN {
9260 x,
9261 res,
9262 bias,
9263 g,
9264 b,
9265 out,
9266 rows,
9267 h,
9268 eps,
9269 has_bias,
9270 } => {
9271 let (rows, h) = (*rows as usize, *h as usize);
9272 unsafe {
9273 let zero = &zero_bias[..h];
9274 let bi = if *has_bias { sl(*bias, base, h) } else { zero };
9275 let x_ptr = sl(*x, base, rows * h).as_ptr() as usize;
9276 let r_ptr = sl(*res, base, rows * h).as_ptr() as usize;
9277 let o_ptr = sl_mut(*out, base, rows * h).as_mut_ptr() as usize;
9278 let bi_ptr = bi.as_ptr() as usize;
9279 let g_ptr = sl(*g, base, h).as_ptr() as usize;
9280 let b_ptr = sl(*b, base, h).as_ptr() as usize;
9281 let e = *eps;
9282 crate::pool::par_for(rows, 4, &|off, cnt| {
9283 let xs =
9284 std::slice::from_raw_parts((x_ptr as *const f32).add(off * h), cnt * h);
9285 let rs =
9286 std::slice::from_raw_parts((r_ptr as *const f32).add(off * h), cnt * h);
9287 let os = std::slice::from_raw_parts_mut(
9288 (o_ptr as *mut f32).add(off * h),
9289 cnt * h,
9290 );
9291 let bi = std::slice::from_raw_parts(bi_ptr as *const f32, h);
9292 let g = std::slice::from_raw_parts(g_ptr as *const f32, h);
9293 let b = std::slice::from_raw_parts(b_ptr as *const f32, h);
9294 crate::kernels::residual_bias_layer_norm(xs, rs, bi, g, b, os, cnt, h, e);
9295 });
9296 }
9297 }
9298
9299 Thunk::FusedResidualRmsNorm {
9300 x,
9301 res,
9302 bias,
9303 g,
9304 b,
9305 out,
9306 rows,
9307 h,
9308 eps,
9309 has_bias,
9310 } => {
9311 let (rows, h) = (*rows as usize, *h as usize);
9312 unsafe {
9313 let zero = &zero_bias[..h];
9314 let bi = if *has_bias { sl(*bias, base, h) } else { zero };
9315 let x_ptr = sl(*x, base, rows * h).as_ptr() as usize;
9316 let r_ptr = sl(*res, base, rows * h).as_ptr() as usize;
9317 let o_ptr = sl_mut(*out, base, rows * h).as_mut_ptr() as usize;
9318 let bi_ptr = bi.as_ptr() as usize;
9319 let g_ptr = sl(*g, base, h).as_ptr() as usize;
9320 let b_ptr = sl(*b, base, h).as_ptr() as usize;
9321 let e = *eps;
9322 crate::pool::par_for(rows, 4, &|off, cnt| {
9323 let xs =
9324 std::slice::from_raw_parts((x_ptr as *const f32).add(off * h), cnt * h);
9325 let rs =
9326 std::slice::from_raw_parts((r_ptr as *const f32).add(off * h), cnt * h);
9327 let os = std::slice::from_raw_parts_mut(
9328 (o_ptr as *mut f32).add(off * h),
9329 cnt * h,
9330 );
9331 let bi = std::slice::from_raw_parts(bi_ptr as *const f32, h);
9332 let g = std::slice::from_raw_parts(g_ptr as *const f32, h);
9333 let b = std::slice::from_raw_parts(b_ptr as *const f32, h);
9334 crate::kernels::residual_bias_rms_norm(xs, rs, bi, g, b, os, cnt, h, e);
9335 });
9336 }
9337 }
9338
9339 Thunk::BiasAdd {
9340 src,
9341 bias,
9342 dst,
9343 m,
9344 n,
9345 } => {
9346 let (m, n) = (*m as usize, *n as usize);
9347 let len = m * n;
9348 unsafe {
9349 let out = sl_mut(*dst, base, len);
9350 if *src != *dst {
9351 let src_ptr = base.add(*src) as *const f32;
9352 let dst_ptr = base.add(*dst) as *mut f32;
9353 if src_ptr != dst_ptr {
9354 std::ptr::copy_nonoverlapping(src_ptr, dst_ptr, len);
9355 }
9356 }
9357 crate::blas::bias_add(out, sl(*bias, base, n), m, n);
9358 }
9359 }
9360
9361 Thunk::BinaryFull {
9362 lhs,
9363 rhs,
9364 dst,
9365 len,
9366 lhs_len,
9367 rhs_len,
9368 op,
9369 out_dims_bcast,
9370 bcast_lhs_strides,
9371 bcast_rhs_strides,
9372 elem_bytes,
9373 } => {
9374 let len = *len as usize;
9375 let ll = (*lhs_len as usize).max(1);
9376 let rl = (*rhs_len as usize).max(1);
9377 let eb = (*elem_bytes).max(1) as usize;
9378 let arena_len = arena_buf.len();
9379 let ll = ll.min((arena_len.saturating_sub(*lhs)) / eb);
9380 let rl = rl.min((arena_len.saturating_sub(*rhs)) / eb);
9381 let len = len.min((arena_len.saturating_sub(*dst)) / eb);
9382 unsafe {
9383 if eb == 8 {
9384 let l = sl_i64(*lhs, base, ll);
9385 let r = sl_i64(*rhs, base, rl);
9386 let o = sl_mut_i64(*dst, base, len);
9387 if !out_dims_bcast.is_empty() {
9388 let rank = out_dims_bcast.len();
9389 let mut coords = vec![0u32; rank];
9390 for i in 0..len {
9391 let mut rem = i;
9392 for ax in (0..rank).rev() {
9393 let sz = out_dims_bcast[ax] as usize;
9394 coords[ax] = (rem % sz) as u32;
9395 rem /= sz;
9396 }
9397 let mut li = 0usize;
9398 let mut ri = 0usize;
9399 for ax in 0..rank {
9400 li += coords[ax] as usize * bcast_lhs_strides[ax] as usize;
9401 ri += coords[ax] as usize * bcast_rhs_strides[ax] as usize;
9402 }
9403 o[i] = match op {
9404 BinaryOp::Add => l[li].wrapping_add(r[ri]),
9405 BinaryOp::Sub => l[li].wrapping_sub(r[ri]),
9406 BinaryOp::Mul => l[li].wrapping_mul(r[ri]),
9407 BinaryOp::Div => {
9408 if r[ri] == 0 {
9409 0
9410 } else {
9411 l[li] / r[ri]
9412 }
9413 }
9414 BinaryOp::Max => l[li].max(r[ri]),
9415 BinaryOp::Min => l[li].min(r[ri]),
9416 BinaryOp::Pow => l[li].pow(r[ri].max(0) as u32),
9417 };
9418 }
9419 } else {
9420 for i in 0..len {
9421 let li = if ll == 1 { 0 } else { i % ll };
9422 let ri = if rl == 1 { 0 } else { i % rl };
9423 o[i] = match op {
9424 BinaryOp::Add => l[li].wrapping_add(r[ri]),
9425 BinaryOp::Sub => l[li].wrapping_sub(r[ri]),
9426 BinaryOp::Mul => l[li].wrapping_mul(r[ri]),
9427 BinaryOp::Div => {
9428 if r[ri] == 0 {
9429 0
9430 } else {
9431 l[li] / r[ri]
9432 }
9433 }
9434 BinaryOp::Max => l[li].max(r[ri]),
9435 BinaryOp::Min => l[li].min(r[ri]),
9436 BinaryOp::Pow => l[li].pow(r[ri].max(0) as u32),
9437 };
9438 }
9439 }
9440 } else {
9441 let l = sl(*lhs, base, ll);
9442 let r = sl(*rhs, base, rl);
9443 let o = sl_mut(*dst, base, len);
9444 if ll == len && rl == len {
9445 #[cfg(target_arch = "aarch64")]
9446 if matches!(op, BinaryOp::Add | BinaryOp::Mul) {
9447 use std::arch::aarch64::*;
9448 let chunks = len / 4;
9449 for c in 0..chunks {
9450 let off = c * 4;
9451 let vl = vld1q_f32(l.as_ptr().add(off));
9452 let vr = vld1q_f32(r.as_ptr().add(off));
9453 let res = match op {
9454 BinaryOp::Add => vaddq_f32(vl, vr),
9455 BinaryOp::Mul => vmulq_f32(vl, vr),
9456 _ => unreachable!(),
9457 };
9458 vst1q_f32(o.as_mut_ptr().add(off), res);
9459 }
9460 for i in (chunks * 4)..len {
9461 o[i] = match op {
9462 BinaryOp::Add => l[i] + r[i],
9463 BinaryOp::Mul => l[i] * r[i],
9464 _ => unreachable!(),
9465 };
9466 }
9467 continue;
9468 }
9469 }
9470 if !out_dims_bcast.is_empty() {
9471 let rank = out_dims_bcast.len();
9472 let mut coords = vec![0u32; rank];
9473 for i in 0..len {
9474 let mut rem = i;
9475 for ax in (0..rank).rev() {
9476 let sz = out_dims_bcast[ax] as usize;
9477 coords[ax] = (rem % sz) as u32;
9478 rem /= sz;
9479 }
9480 let mut li = 0usize;
9481 let mut ri = 0usize;
9482 for ax in 0..rank {
9483 li += coords[ax] as usize * bcast_lhs_strides[ax] as usize;
9484 ri += coords[ax] as usize * bcast_rhs_strides[ax] as usize;
9485 }
9486 o[i] = match op {
9487 BinaryOp::Add => l[li] + r[ri],
9488 BinaryOp::Sub => l[li] - r[ri],
9489 BinaryOp::Mul => l[li] * r[ri],
9490 BinaryOp::Div => l[li] / r[ri],
9491 BinaryOp::Max => l[li].max(r[ri]),
9492 BinaryOp::Min => l[li].min(r[ri]),
9493 BinaryOp::Pow => l[li].powf(r[ri]),
9494 };
9495 }
9496 } else {
9497 for i in 0..len {
9498 let li = if ll == 1 { 0 } else { i % ll };
9499 let ri = if rl == 1 { 0 } else { i % rl };
9500 o[i] = match op {
9501 BinaryOp::Add => l[li] + r[ri],
9502 BinaryOp::Sub => l[li] - r[ri],
9503 BinaryOp::Mul => l[li] * r[ri],
9504 BinaryOp::Div => l[li] / r[ri],
9505 BinaryOp::Max => l[li].max(r[ri]),
9506 BinaryOp::Min => l[li].min(r[ri]),
9507 BinaryOp::Pow => l[li].powf(r[ri]),
9508 };
9509 }
9510 }
9511 }
9512 }
9513 }
9514
9515 Thunk::Gather {
9516 table,
9517 table_len,
9518 idx,
9519 dst,
9520 num_idx,
9521 trailing,
9522 idx_i64,
9523 table_bytes,
9524 } => {
9525 let (ni, tr) = (*num_idx as usize, *trailing as usize);
9526 let rows = *table_len as usize / tr.max(1);
9527 unsafe {
9528 if *table_bytes == 8 {
9529 let tab = sl_i64(*table, base, *table_len as usize);
9530 let out = sl_mut_i64(*dst, base, ni * tr);
9531 if *idx_i64 != 0 {
9532 let ids = sl_i64(*idx, base, ni);
9533 for i in 0..ni {
9534 let row = ids[i].max(0) as usize;
9535 if row < rows {
9536 out[i * tr..(i + 1) * tr]
9537 .copy_from_slice(&tab[row * tr..(row + 1) * tr]);
9538 }
9539 }
9540 } else {
9541 let ids = sl(*idx, base, ni);
9542 for i in 0..ni {
9543 let row = ids[i] as usize;
9544 if row < rows {
9545 out[i * tr..(i + 1) * tr]
9546 .copy_from_slice(&tab[row * tr..(row + 1) * tr]);
9547 }
9548 }
9549 }
9550 } else {
9551 let tab = sl(*table, base, *table_len as usize);
9552 let out = sl_mut(*dst, base, ni * tr);
9553 if *idx_i64 != 0 {
9554 let ids = sl_i64(*idx, base, ni);
9555 for i in 0..ni {
9556 let row = ids[i].max(0) as usize;
9557 if row < rows {
9558 out[i * tr..(i + 1) * tr]
9559 .copy_from_slice(&tab[row * tr..(row + 1) * tr]);
9560 }
9561 }
9562 } else {
9563 let ids = sl(*idx, base, ni);
9564 for i in 0..ni {
9565 let row = ids[i] as usize;
9566 if row < rows {
9567 out[i * tr..(i + 1) * tr]
9568 .copy_from_slice(&tab[row * tr..(row + 1) * tr]);
9569 }
9570 }
9571 }
9572 }
9573 }
9574 }
9575
9576 Thunk::Narrow {
9577 src,
9578 dst,
9579 outer,
9580 src_stride,
9581 dst_stride,
9582 inner,
9583 elem_bytes,
9584 } => {
9585 let (outer, ss, ds, inner, eb) = (
9586 *outer as usize,
9587 *src_stride as usize,
9588 *dst_stride as usize,
9589 *inner as usize,
9590 *elem_bytes as usize,
9591 );
9592 let row_bytes = inner.saturating_mul(eb);
9593 let src_row_stride = ss.saturating_mul(eb);
9594 let dst_row_stride = ds.saturating_mul(eb);
9595 if trace_thunks {
9596 eprintln!(
9597 "[narrow] src={} dst={} outer={outer} ss={ss} ds={ds} inner={inner} eb={eb} row={row_bytes} arena={}",
9598 *src,
9599 *dst,
9600 arena_buf.len()
9601 );
9602 }
9603 if row_bytes > 0 && *src != *dst {
9604 let arena_len = arena_buf.len();
9605 for o in 0..outer {
9606 let s_off = *src + o * src_row_stride;
9607 let d_off = *dst + o * dst_row_stride;
9608 if s_off == d_off {
9609 continue;
9610 }
9611 if s_off.saturating_add(row_bytes) > arena_len
9612 || d_off.saturating_add(row_bytes) > arena_len
9613 {
9614 break;
9615 }
9616 unsafe {
9617 std::ptr::copy_nonoverlapping(
9618 base.add(s_off),
9619 base.add(d_off),
9620 row_bytes,
9621 );
9622 }
9623 }
9624 }
9625 }
9626
9627 Thunk::Copy { src, dst, len } => {
9628 let mut len = *len as usize;
9629 if *src == *dst || len == 0 {
9630 continue;
9631 }
9632 let arena_len = arena_buf.len();
9633 let max_from_src = (arena_len.saturating_sub(*src)) / 4;
9634 let max_from_dst = (arena_len.saturating_sub(*dst)) / 4;
9635 len = len.min(max_from_src).min(max_from_dst);
9636 if len == 0 {
9637 continue;
9638 }
9639 let byte_len = len.saturating_mul(4);
9640 unsafe {
9641 std::ptr::copy(base.add(*src), base.add(*dst), byte_len);
9642 }
9643 }
9644
9645 Thunk::LayerNorm {
9646 src,
9647 g,
9648 b,
9649 dst,
9650 rows,
9651 h,
9652 eps,
9653 } => {
9654 let (rows, h) = (*rows as usize, *h as usize);
9655 unsafe {
9656 let input = sl(*src, base, rows * h);
9657 let gamma = sl(*g, base, h);
9658 let beta = sl(*b, base, h);
9659 let output = sl_mut(*dst, base, rows * h);
9660 if rows >= 4 && rows * h >= 30_000 {
9662 let i_ptr = input.as_ptr() as usize;
9663 let o_ptr = output.as_mut_ptr() as usize;
9664 let g_ptr = gamma.as_ptr() as usize;
9665 let b_ptr = beta.as_ptr() as usize;
9666 let e = *eps;
9667 crate::pool::par_for(rows, 4, &|off, cnt| {
9668 let inp = std::slice::from_raw_parts(
9669 (i_ptr as *const f32).add(off * h),
9670 cnt * h,
9671 );
9672 let out = std::slice::from_raw_parts_mut(
9673 (o_ptr as *mut f32).add(off * h),
9674 cnt * h,
9675 );
9676 let g = std::slice::from_raw_parts(g_ptr as *const f32, h);
9677 let b = std::slice::from_raw_parts(b_ptr as *const f32, h);
9678 for row in 0..cnt {
9679 crate::kernels::layer_norm_row(
9680 &inp[row * h..(row + 1) * h],
9681 g,
9682 b,
9683 &mut out[row * h..(row + 1) * h],
9684 h,
9685 e,
9686 );
9687 }
9688 });
9689 } else {
9690 for row in 0..rows {
9691 crate::kernels::layer_norm_row(
9692 &input[row * h..(row + 1) * h],
9693 gamma,
9694 beta,
9695 &mut output[row * h..(row + 1) * h],
9696 h,
9697 *eps,
9698 );
9699 }
9700 }
9701 }
9702 }
9703
9704 Thunk::GroupNorm {
9705 src,
9706 g,
9707 b,
9708 dst,
9709 n,
9710 c,
9711 h,
9712 w,
9713 num_groups,
9714 eps,
9715 } => {
9716 let (n, c, h, w) = (*n as usize, *c as usize, *h as usize, *w as usize);
9717 let plane = c * h * w;
9718 unsafe {
9719 for ni in 0..n {
9720 let input = sl(*src, base.add(ni * plane), plane);
9721 let gamma = sl(*g, base, c);
9722 let beta = sl(*b, base, c);
9723 let output = sl_mut(*dst, base.add(ni * plane), plane);
9724 crate::kernels::group_norm_nchw(
9725 input,
9726 gamma,
9727 beta,
9728 output,
9729 1,
9730 c,
9731 h,
9732 w,
9733 *num_groups as usize,
9734 *eps,
9735 );
9736 }
9737 }
9738 }
9739
9740 Thunk::BatchNormInference {
9741 src,
9742 g,
9743 b,
9744 mean,
9745 var,
9746 dst,
9747 count,
9748 channels,
9749 eps,
9750 } => {
9751 let count = *count as usize;
9752 let c = *channels as usize;
9753 let n = count * c;
9754 unsafe {
9755 crate::kernels::batch_norm_inference(
9756 sl(*src, base, n),
9757 sl(*g, base, c),
9758 sl(*b, base, c),
9759 sl(*mean, base, c),
9760 sl(*var, base, c),
9761 sl_mut(*dst, base, n),
9762 c,
9763 *eps,
9764 );
9765 }
9766 }
9767
9768 Thunk::LayerNorm2d {
9769 src,
9770 g,
9771 b,
9772 dst,
9773 n,
9774 c,
9775 h,
9776 w,
9777 eps,
9778 } => {
9779 let (n, c, h, w) = (*n as usize, *c as usize, *h as usize, *w as usize);
9780 let plane = c * h * w;
9781 unsafe {
9782 let input = sl(*src, base, n * plane);
9783 let gamma = sl(*g, base, c);
9784 let beta = sl(*b, base, c);
9785 let output = sl_mut(*dst, base, n * plane);
9786 crate::kernels::layer_norm2d_nchw(input, gamma, beta, output, n, c, h, w, *eps);
9787 }
9788 }
9789
9790 Thunk::ConvTranspose2d {
9791 src,
9792 weight,
9793 dst,
9794 n,
9795 c_in,
9796 h,
9797 w_in,
9798 c_out,
9799 h_out,
9800 w_out,
9801 kh,
9802 kw,
9803 sh,
9804 sw,
9805 ph,
9806 pw,
9807 dh,
9808 dw,
9809 groups,
9810 } => {
9811 let n = *n as usize;
9812 let c_in = *c_in as usize;
9813 let h = *h as usize;
9814 let w_in = *w_in as usize;
9815 let c_out = *c_out as usize;
9816 let h_out = *h_out as usize;
9817 let w_out = *w_out as usize;
9818 unsafe {
9819 let inp = sl(*src, base, n * c_in * h * w_in);
9820 let wt = sl(
9821 *weight,
9822 base,
9823 c_in * (c_out / *groups as usize) * (*kh as usize) * (*kw as usize),
9824 );
9825 let out = sl_mut(*dst, base, n * c_out * h_out * w_out);
9826 crate::kernels::conv_transpose2d_nchw(
9827 inp,
9828 wt,
9829 out,
9830 n,
9831 c_in,
9832 h,
9833 w_in,
9834 c_out,
9835 h_out,
9836 w_out,
9837 *kh as usize,
9838 *kw as usize,
9839 *sh as usize,
9840 *sw as usize,
9841 *ph as usize,
9842 *pw as usize,
9843 *dh as usize,
9844 *dw as usize,
9845 *groups as usize,
9846 );
9847 }
9848 }
9849
9850 Thunk::ResizeNearest2x {
9851 src,
9852 dst,
9853 n,
9854 c,
9855 h,
9856 w,
9857 } => {
9858 let (n, c, h, w) = (*n as usize, *c as usize, *h as usize, *w as usize);
9859 let in_plane = c * h * w;
9860 let out_plane = c * h * 2 * w * 2;
9861 unsafe {
9862 for ni in 0..n {
9863 let input = sl(*src, base.add(ni * in_plane), in_plane);
9864 let output = sl_mut(*dst, base.add(ni * out_plane), out_plane);
9865 crate::kernels::resize_nearest_2x_nchw(input, output, c, h, w);
9866 }
9867 }
9868 }
9869
9870 Thunk::AxialRope2d {
9871 src,
9872 dst,
9873 batch,
9874 seq,
9875 hidden,
9876 end_x,
9877 end_y,
9878 head_dim,
9879 num_heads,
9880 theta,
9881 repeat_factor,
9882 } => {
9883 let b = *batch as usize;
9884 let s = *seq as usize;
9885 let hdim = *head_dim as usize;
9886 let nh = *num_heads as usize;
9887 let plane = s * (*hidden as usize);
9888 unsafe {
9889 for bi in 0..b {
9890 let input = sl(*src, base.add(bi * plane), plane);
9891 let output = sl_mut(*dst, base.add(bi * plane), plane);
9892 let rotated = rlx_ir::ops::axial_rope2d::apply_axial_rope2d(
9893 input,
9894 nh,
9895 s,
9896 hdim,
9897 *end_x as usize,
9898 *end_y as usize,
9899 *theta,
9900 *repeat_factor as usize,
9901 );
9902 output.copy_from_slice(&rotated);
9903 }
9904 }
9905 }
9906
9907 Thunk::RmsNorm {
9908 src,
9909 g,
9910 b,
9911 dst,
9912 rows,
9913 h,
9914 eps,
9915 } => {
9916 let (rows, h) = (*rows as usize, *h as usize);
9917 unsafe {
9918 let input = sl(*src, base, rows * h);
9919 let gamma = sl(*g, base, h);
9920 let beta = sl(*b, base, h);
9921 let output = sl_mut(*dst, base, rows * h);
9922 let inv_h = 1.0 / h as f32;
9923 for row in 0..rows {
9924 let in_row = &input[row * h..(row + 1) * h];
9925 let out_row = &mut output[row * h..(row + 1) * h];
9926 let mut sumsq = 0f32;
9928 for &v in in_row {
9929 sumsq += v * v;
9930 }
9931 let inv_rms = (sumsq * inv_h + *eps).sqrt().recip();
9932 for i in 0..h {
9933 out_row[i] = in_row[i] * inv_rms * gamma[i] + beta[i];
9934 }
9935 }
9936 }
9937 }
9938
9939 Thunk::Softmax { data, rows, cols } => {
9940 let (rows, cols) = (*rows as usize, *cols as usize);
9941 unsafe {
9942 crate::kernels::neon_softmax(sl_mut(*data, base, rows * cols), rows, cols);
9943 }
9944 }
9945
9946 Thunk::Cumsum {
9947 src,
9948 dst,
9949 rows,
9950 cols,
9951 exclusive,
9952 } => {
9953 let (rows, cols) = (*rows as usize, *cols as usize);
9954 unsafe {
9955 let s = sl(*src, base, rows * cols);
9956 let d = sl_mut(*dst, base, rows * cols);
9957 if *exclusive {
9958 for r in 0..rows {
9959 let mut acc = 0.0f32;
9960 for c in 0..cols {
9961 d[r * cols + c] = acc;
9962 acc += s[r * cols + c];
9963 }
9964 }
9965 } else {
9966 for r in 0..rows {
9967 let mut acc = 0.0f32;
9968 for c in 0..cols {
9969 acc += s[r * cols + c];
9970 d[r * cols + c] = acc;
9971 }
9972 }
9973 }
9974 }
9975 }
9976
9977 Thunk::Sample {
9978 logits,
9979 dst,
9980 batch,
9981 vocab,
9982 top_k,
9983 top_p,
9984 temperature,
9985 seed,
9986 } => {
9987 let (b, v) = (*batch as usize, *vocab as usize);
9988 let k = (*top_k as usize).min(v);
9989 unsafe {
9990 let lg = sl(*logits, base, b * v);
9991 let out = sl_mut(*dst, base, b);
9992 let mut rng =
9993 rlx_ir::Philox4x32::new(if *seed == 0 { 0xDEADBEEF } else { *seed });
9994 for bi in 0..b {
9995 let row = &lg[bi * v..(bi + 1) * v];
9996 out[bi] = sample_row(row, k, *top_p, *temperature, &mut rng) as f32;
9997 }
9998 }
9999 }
10000
10001 Thunk::GatedDeltaNet {
10002 q,
10003 k,
10004 v,
10005 g,
10006 beta,
10007 state,
10008 dst,
10009 batch,
10010 seq,
10011 heads,
10012 state_size,
10013 } => unsafe {
10014 execute_gated_delta_net_f32(
10015 *q,
10016 *k,
10017 *v,
10018 *g,
10019 *beta,
10020 *state,
10021 *dst,
10022 *batch as usize,
10023 *seq as usize,
10024 *heads as usize,
10025 *state_size as usize,
10026 base,
10027 );
10028 },
10029
10030 Thunk::SelectiveScan {
10031 x,
10032 delta,
10033 a,
10034 b: bp,
10035 c: cp,
10036 dst,
10037 batch,
10038 seq,
10039 hidden,
10040 state_size,
10041 } => {
10042 let (b, s, h, n) = (
10043 *batch as usize,
10044 *seq as usize,
10045 *hidden as usize,
10046 *state_size as usize,
10047 );
10048 unsafe {
10049 let xs = sl(*x, base, b * s * h);
10050 let dt = sl(*delta, base, b * s * h);
10051 let am = sl(*a, base, h * n);
10052 let bm = sl(*bp, base, b * s * n);
10053 let cm = sl(*cp, base, b * s * n);
10054 let out = sl_mut(*dst, base, b * s * h);
10055
10056 let mut state = vec![0f32; h * n];
10060 for bi in 0..b {
10061 for v in state.iter_mut() {
10063 *v = 0.0;
10064 }
10065 for si in 0..s {
10066 let x_row = &xs[bi * s * h + si * h..bi * s * h + (si + 1) * h];
10067 let dt_row = &dt[bi * s * h + si * h..bi * s * h + (si + 1) * h];
10068 let b_row = &bm[bi * s * n + si * n..bi * s * n + (si + 1) * n];
10069 let c_row = &cm[bi * s * n + si * n..bi * s * n + (si + 1) * n];
10070 let out_row = &mut out[bi * s * h + si * h..bi * s * h + (si + 1) * h];
10071
10072 for ci in 0..h {
10073 let d = dt_row[ci];
10074 let xv = x_row[ci];
10075 let mut acc = 0f32;
10076 for ni in 0..n {
10077 let da = (d * am[ci * n + ni]).exp();
10079 state[ci * n + ni] =
10080 da * state[ci * n + ni] + d * b_row[ni] * xv;
10081 acc += c_row[ni] * state[ci * n + ni];
10082 }
10083 out_row[ci] = acc;
10084 }
10085 }
10086 }
10087 }
10088 }
10089
10090 Thunk::DequantMatMul {
10091 x,
10092 w_q,
10093 scale,
10094 zp,
10095 dst,
10096 m,
10097 k,
10098 n,
10099 block_size,
10100 is_asymmetric,
10101 } => {
10102 let (m, k, n, bs) = (*m as usize, *k as usize, *n as usize, *block_size as usize);
10103 let n_blocks = k.div_ceil(bs);
10104 unsafe {
10105 let xs = sl(*x, base, m * k);
10106 let w_bytes = std::slice::from_raw_parts(base.add(*w_q) as *const i8, k * n);
10107 let scales = sl(*scale, base, n_blocks * n);
10108 let zps = if *is_asymmetric {
10109 sl(*zp, base, n_blocks * n)
10110 } else {
10111 &[][..]
10112 };
10113 let out = sl_mut(*dst, base, m * n);
10114 dequant_matmul_int8(xs, w_bytes, scales, zps, out, m, k, n, bs, *is_asymmetric);
10115 }
10116 }
10117
10118 Thunk::DequantMatMulGguf {
10119 x,
10120 w_q,
10121 dst,
10122 m,
10123 k,
10124 n,
10125 scheme,
10126 } => {
10127 let (m, k, n) = (*m as usize, *k as usize, *n as usize);
10128 let block_bytes = scheme.gguf_block_bytes() as usize;
10129 let block_elems = scheme.gguf_block_size() as usize;
10130 debug_assert!(
10131 block_bytes > 0 && block_elems > 0,
10132 "non-GGUF scheme in GGUF arm"
10133 );
10134 debug_assert!(
10135 (k * n).is_multiple_of(block_elems),
10136 "k*n={} not aligned to GGUF block size {}",
10137 k * n,
10138 block_elems
10139 );
10140 let total_bytes = (k * n) / block_elems * block_bytes;
10141 unsafe {
10142 let xs = sl(*x, base, m * k);
10143 let w_bytes_ptr = base.add(*w_q) as *const u8;
10144 let w_bytes = std::slice::from_raw_parts(w_bytes_ptr, total_bytes);
10145 let out = sl_mut(*dst, base, m * n);
10146 crate::gguf_matmul::gguf_matmul_bt(xs, w_bytes, out, m, k, n, *scheme);
10147 }
10148 }
10149
10150 Thunk::DequantMatMulInt4 {
10151 x,
10152 w_q,
10153 scale,
10154 zp,
10155 dst,
10156 m,
10157 k,
10158 n,
10159 block_size,
10160 is_asymmetric,
10161 } => {
10162 let (m, k, n, bs) = (*m as usize, *k as usize, *n as usize, *block_size as usize);
10163 let n_blocks = k.div_ceil(bs);
10164 unsafe {
10165 let xs = sl(*x, base, m * k);
10166 let w_bytes = std::slice::from_raw_parts(
10167 base.add(*w_q) as *const u8,
10168 (k * n).div_ceil(2),
10169 );
10170 let scales = sl(*scale, base, n_blocks * n);
10171 let zps = if *is_asymmetric {
10172 sl(*zp, base, n_blocks * n)
10173 } else {
10174 &[][..]
10175 };
10176 let out = sl_mut(*dst, base, m * n);
10177 dequant_matmul_int4(xs, w_bytes, scales, zps, out, m, k, n, bs, *is_asymmetric);
10178 }
10179 }
10180
10181 Thunk::DequantMatMulFp8 {
10182 x,
10183 w_q,
10184 scale,
10185 dst,
10186 m,
10187 k,
10188 n,
10189 e5m2,
10190 } => {
10191 let (m, k, n) = (*m as usize, *k as usize, *n as usize);
10192 unsafe {
10193 let xs = sl(*x, base, m * k);
10194 let w_bytes = std::slice::from_raw_parts(base.add(*w_q) as *const u8, k * n);
10195 let scales = sl(*scale, base, n);
10196 let out = sl_mut(*dst, base, m * n);
10197 dequant_matmul_fp8(xs, w_bytes, scales, out, m, k, n, *e5m2);
10198 }
10199 }
10200
10201 Thunk::DequantMatMulNvfp4 {
10202 x,
10203 w_q,
10204 scale,
10205 global_scale,
10206 dst,
10207 m,
10208 k,
10209 n,
10210 } => {
10211 let (m, k, n) = (*m as usize, *k as usize, *n as usize);
10212 let n_scale = k.div_ceil(rlx_ir::NVFP4_GROUP_SIZE) * n;
10213 unsafe {
10214 let xs = sl(*x, base, m * k);
10215 let w_bytes = std::slice::from_raw_parts(
10216 base.add(*w_q) as *const u8,
10217 (k * n).div_ceil(2),
10218 );
10219 let scale_bytes =
10220 std::slice::from_raw_parts(base.add(*scale) as *const u8, n_scale);
10221 let gs = sl(*global_scale, base, 1)[0];
10222 let out = sl_mut(*dst, base, m * n);
10223 dequant_matmul_nvfp4(xs, w_bytes, scale_bytes, gs, out, m, k, n);
10224 }
10225 }
10226
10227 Thunk::LoraMatMul {
10228 x,
10229 w,
10230 a,
10231 b,
10232 dst,
10233 m,
10234 k,
10235 n,
10236 r,
10237 scale,
10238 } => {
10239 let (m, k, n, r) = (*m as usize, *k as usize, *n as usize, *r as usize);
10240 unsafe {
10241 let xs = sl(*x, base, m * k);
10242 let ws = sl(*w, base, k * n);
10243 let a_s = sl(*a, base, k * r);
10244 let bs = sl(*b, base, r * n);
10245 let out = sl_mut(*dst, base, m * n);
10246 crate::blas::sgemm(xs, ws, out, m, k, n);
10247 let mut tmp = vec![0f32; m * r];
10248 crate::blas::sgemm(xs, a_s, &mut tmp, m, k, r);
10249 if *scale != 1.0 {
10250 for v in tmp.iter_mut() {
10251 *v *= *scale;
10252 }
10253 }
10254 crate::blas::sgemm_accumulate(&tmp, bs, out, m, r, n);
10255 }
10256 }
10257
10258 Thunk::Attention {
10259 q,
10260 k,
10261 v,
10262 mask,
10263 out,
10264 batch,
10265 seq,
10266 kv_seq,
10267 heads,
10268 head_dim,
10269 mask_kind,
10270 q_row_stride,
10271 k_row_stride,
10272 v_row_stride,
10273 bhsd,
10274 } => {
10275 let (b, q_s, k_s, nh, dh) = (
10276 *batch as usize,
10277 *seq as usize,
10278 *kv_seq as usize,
10279 *heads as usize,
10280 *head_dim as usize,
10281 );
10282 let hs = nh * dh;
10283 let (qrs, krs, vrs) = if *bhsd {
10286 (dh, dh, dh)
10287 } else {
10288 (
10289 *q_row_stride as usize,
10290 *k_row_stride as usize,
10291 *v_row_stride as usize,
10292 )
10293 };
10294 let bhsd = *bhsd;
10295 let _ = (q_row_stride, k_row_stride, v_row_stride);
10296 let scale = (dh as f32).powf(-0.5);
10297 let ss = q_s * k_s;
10298 let cfg = crate::config::RuntimeConfig::global();
10299 unsafe {
10300 let q_len = if bhsd {
10307 b * nh * q_s * dh
10308 } else {
10309 b * q_s * qrs
10310 };
10311 let k_len = if bhsd {
10312 b * nh * k_s * dh
10313 } else {
10314 b * k_s * krs
10315 };
10316 let v_len = if bhsd {
10317 b * nh * k_s * dh
10318 } else {
10319 b * k_s * vrs
10320 };
10321 let q_data = sl(*q, base, q_len);
10322 let k_data = sl(*k, base, k_len);
10323 let v_data = sl(*v, base, v_len);
10324 let mask_data: &[f32] = match mask_kind {
10325 rlx_ir::op::MaskKind::Custom => sl(*mask, base, b * k_s),
10326 rlx_ir::op::MaskKind::Bias => sl(*mask, base, b * nh * q_s * k_s),
10327 _ => &[],
10328 };
10329 let out_len = if bhsd {
10330 b * nh * q_s * dh
10331 } else {
10332 b * q_s * hs
10333 };
10334 let out_data = sl_mut(*out, base, out_len);
10335
10336 if bhsd {
10347 let scores = &mut sdpa_scores[..ss];
10348 for bi in 0..b {
10349 for hi in 0..nh {
10350 let q_head_base = bi * nh * q_s * dh + hi * q_s * dh;
10351 let k_head_base = bi * nh * k_s * dh + hi * k_s * dh;
10352 for qi in 0..q_s {
10354 let q_base = q_head_base + qi * dh;
10355 for ki in 0..k_s {
10356 let k_base = k_head_base + ki * dh;
10357 let mut dot = 0f32;
10358 for d in 0..dh {
10359 dot += q_data[q_base + d] * k_data[k_base + d];
10360 }
10361 scores[qi * k_s + ki] = dot * scale;
10362 if matches!(mask_kind, rlx_ir::op::MaskKind::Custom)
10363 && !mask_data.is_empty()
10364 && mask_data[bi * k_s + ki] < mask_thr
10365 {
10366 scores[qi * k_s + ki] = mask_neg;
10367 }
10368 }
10369 }
10370 if matches!(mask_kind, rlx_ir::op::MaskKind::Bias) {
10371 let off = (bi * nh + hi) * q_s * k_s;
10372 for i in 0..q_s * k_s {
10373 scores[i] += mask_data[off + i];
10374 }
10375 }
10376 apply_synthetic_mask(scores, q_s, k_s, *mask_kind);
10377 crate::kernels::neon_softmax(scores, q_s, k_s);
10378 for qi in 0..q_s {
10380 let o_base = q_head_base + qi * dh;
10381 for d in 0..dh {
10382 out_data[o_base + d] = 0.0;
10383 }
10384 for ki in 0..k_s {
10385 let sc = scores[qi * k_s + ki];
10386 if sc > score_thr {
10387 let v_base = k_head_base + ki * dh;
10388 for d in 0..dh {
10389 out_data[o_base + d] += sc * v_data[v_base + d];
10390 }
10391 }
10392 }
10393 }
10394 }
10395 }
10396 continue;
10397 }
10398
10399 if b == 1 && q_s.max(k_s) <= cfg.sdpa_seq_threshold {
10406 let scores = &mut sdpa_scores[..ss];
10408 #[cfg(target_arch = "aarch64")]
10409 let neon_chunks = dh / 4;
10410
10411 for bi in 0..b {
10412 for hi in 0..nh {
10413 for qi in 0..q_s {
10415 let q_off = bi * q_s * qrs + qi * qrs + hi * dh;
10416 for ki in 0..k_s {
10417 let k_off = bi * k_s * krs + ki * krs + hi * dh;
10418 #[cfg(target_arch = "aarch64")]
10419 let mut dot;
10420 #[cfg(not(target_arch = "aarch64"))]
10421 let mut dot = 0f32;
10422 #[cfg(target_arch = "aarch64")]
10423 {
10424 use std::arch::aarch64::*;
10425 let mut acc = vdupq_n_f32(0.0);
10426 for c in 0..neon_chunks {
10427 let vq =
10428 vld1q_f32(q_data.as_ptr().add(q_off + c * 4));
10429 let vk =
10430 vld1q_f32(k_data.as_ptr().add(k_off + c * 4));
10431 acc = vfmaq_f32(acc, vq, vk);
10432 }
10433 dot = vaddvq_f32(acc);
10434 for d in (neon_chunks * 4)..dh {
10435 dot += q_data[q_off + d] * k_data[k_off + d];
10436 }
10437 }
10438 #[cfg(not(target_arch = "aarch64"))]
10439 for d in 0..dh {
10440 dot += q_data[q_off + d] * k_data[k_off + d];
10441 }
10442 scores[qi * k_s + ki] = dot * scale;
10443 if matches!(mask_kind, rlx_ir::op::MaskKind::Custom)
10450 && !mask_data.is_empty()
10451 && mask_data[bi * k_s + ki] < mask_thr
10452 {
10453 scores[qi * k_s + ki] = mask_neg;
10454 }
10455 }
10456 }
10457
10458 if matches!(mask_kind, rlx_ir::op::MaskKind::Bias) {
10459 let off = (bi * nh + hi) * q_s * k_s;
10460 for i in 0..q_s * k_s {
10461 scores[i] += mask_data[off + i];
10462 }
10463 }
10464 apply_synthetic_mask(scores, q_s, k_s, *mask_kind);
10465 crate::kernels::neon_softmax(scores, q_s, k_s);
10466
10467 for qi in 0..q_s {
10469 let o_off = bi * q_s * hs + qi * hs + hi * dh;
10470 for d in 0..dh {
10472 out_data[o_off + d] = 0.0;
10473 }
10474 for ki in 0..k_s {
10475 let sc = scores[qi * k_s + ki];
10476 if sc > score_thr {
10477 let v_off = bi * k_s * vrs + ki * vrs + hi * dh;
10478 #[cfg(target_arch = "aarch64")]
10479 {
10480 use std::arch::aarch64::*;
10481 let vsc = vdupq_n_f32(sc);
10482 for c in 0..neon_chunks {
10483 let off = c * 4;
10484 let vo = vld1q_f32(
10485 out_data.as_ptr().add(o_off + off),
10486 );
10487 let vv =
10488 vld1q_f32(v_data.as_ptr().add(v_off + off));
10489 vst1q_f32(
10490 out_data.as_mut_ptr().add(o_off + off),
10491 vfmaq_f32(vo, vsc, vv),
10492 );
10493 }
10494 }
10495 #[cfg(not(target_arch = "aarch64"))]
10496 for d in 0..dh {
10497 out_data[o_off + d] += sc * v_data[v_off + d];
10498 }
10499 }
10500 }
10501 }
10502 }
10503 }
10504 } else {
10505 let total_work = b * nh;
10507 let q_addr = q_data.as_ptr() as usize;
10508 let k_addr = k_data.as_ptr() as usize;
10509 let v_addr = v_data.as_ptr() as usize;
10510 let m_addr = mask_data.as_ptr() as usize;
10511 let o_addr = out_data.as_mut_ptr() as usize;
10512 let sc_addr = sdpa_scores.as_mut_ptr() as usize;
10513
10514 crate::pool::par_for(total_work, 1, &|off, cnt| {
10515 for idx in off..off + cnt {
10516 let bi = idx / nh;
10517 let hi = idx % nh;
10518
10519 let q_start = (q_addr as *const f32).add(bi * q_s * qrs + hi * dh);
10520 let k_start = (k_addr as *const f32).add(bi * k_s * krs + hi * dh);
10521 let v_start = (v_addr as *const f32).add(bi * k_s * vrs + hi * dh);
10522 let o_start = (o_addr as *mut f32).add(bi * q_s * hs + hi * dh);
10523 let sc = std::slice::from_raw_parts_mut(
10524 (sc_addr as *mut f32).add(idx * ss),
10525 ss,
10526 );
10527
10528 crate::blas::sgemm_general(
10531 q_start,
10532 k_start,
10533 sc.as_mut_ptr(),
10534 q_s,
10535 k_s,
10536 dh,
10537 scale,
10538 0.0,
10539 qrs,
10540 krs,
10541 k_s,
10542 false,
10543 true,
10544 );
10545
10546 match mask_kind {
10547 rlx_ir::op::MaskKind::Custom => {
10548 let mask_bi = std::slice::from_raw_parts(
10549 (m_addr as *const f32).add(bi * k_s),
10550 k_s,
10551 );
10552 for ki in 0..k_s {
10553 if mask_bi[ki] < mask_thr {
10554 for qi in 0..q_s {
10555 sc[qi * k_s + ki] = mask_neg;
10556 }
10557 }
10558 }
10559 }
10560 rlx_ir::op::MaskKind::Bias => {
10561 let bias = std::slice::from_raw_parts(
10563 (m_addr as *const f32).add((bi * nh + hi) * q_s * k_s),
10564 q_s * k_s,
10565 );
10566 for i in 0..q_s * k_s {
10567 sc[i] += bias[i];
10568 }
10569 }
10570 _ => apply_synthetic_mask(sc, q_s, k_s, *mask_kind),
10571 }
10572
10573 crate::kernels::neon_softmax(sc, q_s, k_s);
10574
10575 crate::blas::sgemm_general(
10579 sc.as_ptr(),
10580 v_start,
10581 o_start,
10582 q_s,
10583 dh,
10584 k_s,
10585 1.0,
10586 0.0,
10587 k_s,
10588 vrs,
10589 hs,
10590 false,
10591 false,
10592 );
10593 }
10594 });
10595 }
10596 }
10597 }
10598
10599 Thunk::AttentionBackward {
10600 q,
10601 k,
10602 v,
10603 dy,
10604 mask,
10605 out,
10606 batch,
10607 seq,
10608 kv_seq,
10609 heads,
10610 head_dim,
10611 mask_kind,
10612 wrt,
10613 bhsd,
10614 } => {
10615 let (b, q_s, k_s, nh, dh) = (
10616 *batch as usize,
10617 *seq as usize,
10618 *kv_seq as usize,
10619 *heads as usize,
10620 *head_dim as usize,
10621 );
10622 unsafe {
10623 let q_len = if *bhsd {
10624 b * nh * q_s * dh
10625 } else {
10626 b * q_s * nh * dh
10627 };
10628 let k_len = if *bhsd {
10629 b * nh * k_s * dh
10630 } else {
10631 b * k_s * nh * dh
10632 };
10633 let out_len = match wrt {
10634 rlx_ir::op::AttentionBwdWrt::Key | rlx_ir::op::AttentionBwdWrt::Value => {
10635 k_len
10636 }
10637 rlx_ir::op::AttentionBwdWrt::Query => q_len,
10638 };
10639 let q_data = sl(*q, base, q_len);
10640 let k_data = sl(*k, base, k_len);
10641 let v_data = sl(*v, base, k_len);
10642 let dy_data = sl(*dy, base, q_len);
10643 let out_data = sl_mut(*out, base, out_len);
10644 let mask_data: &[f32] = if *mask != 0 {
10645 let ml = match mask_kind {
10646 rlx_ir::op::MaskKind::Custom => b * k_s,
10647 rlx_ir::op::MaskKind::Bias => b * nh * q_s * k_s,
10648 _ => 0,
10649 };
10650 sl(*mask, base, ml)
10651 } else {
10652 &[]
10653 };
10654 crate::attention_bwd::attention_backward(
10655 *wrt, q_data, k_data, v_data, dy_data, out_data, b, nh, q_s, k_s, dh,
10656 *mask_kind, mask_data, *bhsd,
10657 );
10658 }
10659 }
10660
10661 Thunk::ActivationInPlace { data, len, act } => {
10662 let len = *len as usize;
10663 unsafe {
10664 let d = sl_mut(*data, base, len);
10665 match act {
10666 Activation::Gelu => crate::kernels::par_gelu_inplace(d),
10667 Activation::GeluApprox => crate::kernels::par_gelu_approx_inplace(d),
10668 Activation::Silu => crate::kernels::par_silu_inplace(d),
10669 Activation::Relu => {
10670 for v in d.iter_mut() {
10671 *v = v.max(0.0);
10672 }
10673 }
10674 Activation::Sigmoid => {
10675 for v in d.iter_mut() {
10676 *v = 1.0 / (1.0 + (-*v).exp());
10677 }
10678 }
10679 Activation::Tanh => {
10680 for v in d.iter_mut() {
10681 *v = v.tanh();
10682 }
10683 }
10684 Activation::Exp => {
10685 for v in d.iter_mut() {
10686 *v = v.exp();
10687 }
10688 }
10689 Activation::Log => {
10690 for v in d.iter_mut() {
10691 *v = v.ln();
10692 }
10693 }
10694 Activation::Sqrt => {
10695 for v in d.iter_mut() {
10696 *v = v.sqrt();
10697 }
10698 }
10699 Activation::Rsqrt => {
10700 for v in d.iter_mut() {
10701 *v = 1.0 / v.sqrt();
10702 }
10703 }
10704 Activation::Neg => {
10705 for v in d.iter_mut() {
10706 *v = -*v;
10707 }
10708 }
10709 Activation::Abs => {
10710 for v in d.iter_mut() {
10711 *v = v.abs();
10712 }
10713 }
10714 Activation::Round => {
10715 for v in d.iter_mut() {
10716 *v = v.round();
10717 }
10718 }
10719 Activation::Sin => {
10720 for v in d.iter_mut() {
10721 *v = v.sin();
10722 }
10723 }
10724 Activation::Cos => {
10725 for v in d.iter_mut() {
10726 *v = v.cos();
10727 }
10728 }
10729 Activation::Tan => {
10730 for v in d.iter_mut() {
10731 *v = v.tan();
10732 }
10733 }
10734 Activation::Atan => {
10735 for v in d.iter_mut() {
10736 *v = v.atan();
10737 }
10738 }
10739 }
10740 }
10741 }
10742
10743 Thunk::FusedAttnBlock {
10744 hidden,
10745 qkv_w,
10746 out_w,
10747 mask,
10748 out,
10749 qkv_b,
10750 out_b,
10751 cos,
10752 sin,
10753 cos_len,
10754 batch,
10755 seq,
10756 hs,
10757 nh,
10758 dh,
10759 has_bias,
10760 has_rope,
10761 } => {
10762 let (b, s) = (*batch as usize, *seq as usize);
10763 let (h, n_h, d_h) = (*hs as usize, *nh as usize, *dh as usize);
10764 let m = b * s;
10765 let scale = (d_h as f32).powf(-0.5);
10766 let half = d_h / 2;
10767 unsafe {
10768 let inp = sl(*hidden, base, m * h);
10769 let wq = sl(*qkv_w, base, h * 3 * h);
10770 let wo = sl(*out_w, base, h * h);
10771 let mk = sl(*mask, base, b * s);
10772 let dst = sl_mut(*out, base, m * h);
10773
10774 let mut qkv = vec![0f32; m * 3 * h];
10776 let mut attn_out = vec![0f32; m * h];
10777 let mut scores_buf = vec![0f32; s * s]; crate::blas::sgemm(inp, wq, &mut qkv, m, h, 3 * h);
10781 if *has_bias {
10782 let bias = sl(*qkv_b, base, 3 * h);
10783 crate::blas::bias_add(&mut qkv, bias, m, 3 * h);
10784 }
10785
10786 #[cfg(target_arch = "aarch64")]
10789 let neon_chunks = d_h / 4;
10790 #[cfg(target_arch = "aarch64")]
10791 let _rope_chunks = half / 4;
10792
10793 for bi in 0..b {
10794 for hi in 0..n_h {
10795 for qi in 0..s {
10797 let q_base = bi * s * 3 * h + qi * 3 * h + hi * d_h;
10798 for ki in 0..s {
10799 let k_base = bi * s * 3 * h + ki * 3 * h + h + hi * d_h;
10800 let mut dot = 0f32;
10801
10802 if *has_rope {
10803 let q_cos = qi * half;
10805 let k_cos = ki * half;
10806 let cos_tab = sl(*cos, base, *cos_len as usize);
10807 let sin_tab = sl(*sin, base, *cos_len as usize);
10808 for i in 0..half {
10811 let q1 = qkv[q_base + i];
10812 let q2 = qkv[q_base + half + i];
10813 let k1 = qkv[k_base + i];
10814 let k2 = qkv[k_base + half + i];
10815 let c_q = cos_tab[q_cos + i];
10816 let s_q = sin_tab[q_cos + i];
10817 let c_k = cos_tab[k_cos + i];
10818 let s_k = sin_tab[k_cos + i];
10819 let qr1 = q1 * c_q - q2 * s_q;
10820 let kr1 = k1 * c_k - k2 * s_k;
10821 let qr2 = q2 * c_q + q1 * s_q;
10822 let kr2 = k2 * c_k + k1 * s_k;
10823 dot += qr1 * kr1 + qr2 * kr2;
10824 }
10825 } else {
10826 #[cfg(target_arch = "aarch64")]
10828 {
10829 use std::arch::aarch64::*;
10830 let mut acc = vdupq_n_f32(0.0);
10831 for c in 0..neon_chunks {
10832 let vq =
10833 vld1q_f32(qkv.as_ptr().add(q_base + c * 4));
10834 let vk =
10835 vld1q_f32(qkv.as_ptr().add(k_base + c * 4));
10836 acc = vfmaq_f32(acc, vq, vk);
10837 }
10838 dot = vaddvq_f32(acc);
10839 for d in (neon_chunks * 4)..d_h {
10840 dot += qkv[q_base + d] * qkv[k_base + d];
10841 }
10842 }
10843 #[cfg(not(target_arch = "aarch64"))]
10844 for d in 0..d_h {
10845 dot += qkv[q_base + d] * qkv[k_base + d];
10846 }
10847 }
10848
10849 scores_buf[qi * s + ki] = dot * scale;
10850 if mk[bi * s + ki] < mask_thr {
10851 scores_buf[qi * s + ki] = mask_neg;
10852 }
10853 }
10854 }
10855
10856 crate::kernels::neon_softmax(&mut scores_buf[..s * s], s, s);
10858
10859 for qi in 0..s {
10861 let o_base = bi * s * h + qi * h + hi * d_h;
10862 for d in 0..d_h {
10863 attn_out[o_base + d] = 0.0;
10864 }
10865 for ki in 0..s {
10866 let sc = scores_buf[qi * s + ki];
10867 if sc > score_thr {
10868 let v_base = bi * s * 3 * h + ki * 3 * h + 2 * h + hi * d_h;
10869 #[cfg(target_arch = "aarch64")]
10870 {
10871 use std::arch::aarch64::*;
10872 let vsc = vdupq_n_f32(sc);
10873 for c in 0..neon_chunks {
10874 let off = c * 4;
10875 let vo =
10876 vld1q_f32(attn_out.as_ptr().add(o_base + off));
10877 let vv = vld1q_f32(qkv.as_ptr().add(v_base + off));
10878 vst1q_f32(
10879 attn_out.as_mut_ptr().add(o_base + off),
10880 vfmaq_f32(vo, vsc, vv),
10881 );
10882 }
10883 }
10884 #[cfg(not(target_arch = "aarch64"))]
10885 for d in 0..d_h {
10886 attn_out[o_base + d] += sc * qkv[v_base + d];
10887 }
10888 }
10889 }
10890 }
10891 }
10892 }
10893
10894 crate::blas::sgemm(&attn_out, wo, dst, m, h, h);
10896 if *has_bias {
10897 let bias = sl(*out_b, base, h);
10898 crate::blas::bias_add(dst, bias, m, h);
10899 }
10900 }
10901 }
10902
10903 Thunk::Rope {
10904 src,
10905 cos,
10906 sin,
10907 dst,
10908 batch,
10909 seq,
10910 hidden,
10911 head_dim,
10912 n_rot,
10913 cos_len,
10914 src_row_stride,
10915 } => {
10916 let (b, s, hs, dh, nr) = (
10917 *batch as usize,
10918 *seq as usize,
10919 *hidden as usize,
10920 *head_dim as usize,
10921 *n_rot as usize,
10922 );
10923 let tab_half = dh / 2;
10924 let rot_half = nr / 2;
10925 let nh = hs / dh;
10926 let cl = *cos_len as usize;
10927 let src_rs = *src_row_stride as usize;
10928 unsafe {
10929 let x = sl(*src, base, b * s * src_rs);
10930 let cos_tab = sl(*cos, base, cl);
10931 let sin_tab = sl(*sin, base, cl);
10932 let out = sl_mut(*dst, base, b * s * hs);
10933
10934 let total = b * s;
10935 let x_ptr = x.as_ptr() as usize;
10936 let o_ptr = out.as_mut_ptr() as usize;
10937 let c_ptr = cos_tab.as_ptr() as usize;
10938 let s_ptr = sin_tab.as_ptr() as usize;
10939
10940 crate::pool::par_for(total, 4, &|off, cnt| {
10941 for idx in off..off + cnt {
10942 let bi = idx / s;
10943 let si = idx % s;
10944 let tab_off = si * tab_half;
10945
10946 for hi in 0..nh {
10947 let src_base = bi * s * src_rs + si * src_rs + hi * dh;
10948 let dst_base = bi * s * hs + si * hs + hi * dh;
10949 let xp = (x_ptr as *const f32).add(src_base);
10950 let op = (o_ptr as *mut f32).add(dst_base);
10951 let cp = (c_ptr as *const f32).add(tab_off);
10952 let sp = (s_ptr as *const f32).add(tab_off);
10953
10954 for i in 0..rot_half {
10955 let x1 = *xp.add(i);
10956 let x2 = *xp.add(rot_half + i);
10957 let cv = *cp.add(i);
10958 let sv = *sp.add(i);
10959 *op.add(i) = x1 * cv - x2 * sv;
10960 *op.add(rot_half + i) = x2 * cv + x1 * sv;
10961 }
10962 for j in nr..dh {
10963 *op.add(j) = *xp.add(j);
10964 }
10965 }
10966 }
10967 });
10968 }
10969 }
10970 Thunk::FusedBertLayer {
10971 hidden,
10972 qkv_w,
10973 qkv_b,
10974 out_w,
10975 out_b,
10976 mask,
10977 ln1_g,
10978 ln1_b,
10979 eps1,
10980 fc1_w,
10981 fc1_b,
10982 fc2_w,
10983 fc2_b,
10984 ln2_g,
10985 ln2_b,
10986 eps2,
10987 out,
10988 batch,
10989 seq,
10990 hs,
10991 nh,
10992 dh,
10993 int_dim,
10994 } => {
10995 let (b, s, h, n_h, d_h) = (
10996 *batch as usize,
10997 *seq as usize,
10998 *hs as usize,
10999 *nh as usize,
11000 *dh as usize,
11001 );
11002 let m = b * s;
11003 let id = *int_dim as usize;
11004 let scale = (d_h as f32).powf(-0.5);
11005 let _half = d_h / 2;
11006 #[cfg(target_arch = "aarch64")]
11007 let neon_chunks = d_h / 4;
11008 unsafe {
11009 let inp = sl(*hidden, base, m * h);
11010 let dst = sl_mut(*out, base, m * h);
11011 let mk = sl(*mask, base, b * s);
11012
11013 let qkv = std::slice::from_raw_parts_mut(fl_qkv.as_mut_ptr(), m * 3 * h);
11015 let attn = std::slice::from_raw_parts_mut(fl_attn.as_mut_ptr(), m * h);
11016 let res = std::slice::from_raw_parts_mut(fl_res.as_mut_ptr(), m * h);
11017 let normed = std::slice::from_raw_parts_mut(fl_normed.as_mut_ptr(), m * h);
11018 let ffn = std::slice::from_raw_parts_mut(fl_ffn.as_mut_ptr(), m * id);
11019 let sc = std::slice::from_raw_parts_mut(fl_sc.as_mut_ptr(), s * s);
11020
11021 crate::blas::par_sgemm_bias(
11023 inp,
11024 sl(*qkv_w, base, h * 3 * h),
11025 sl(*qkv_b, base, 3 * h),
11026 qkv,
11027 m,
11028 h,
11029 3 * h,
11030 );
11031
11032 for bi in 0..b {
11034 for hi in 0..n_h {
11035 for qi in 0..s {
11036 for ki in 0..s {
11037 let q_base = bi * s * 3 * h + qi * 3 * h + hi * d_h;
11038 let k_base = bi * s * 3 * h + ki * 3 * h + h + hi * d_h;
11039 #[cfg(target_arch = "aarch64")]
11040 let dot;
11041 #[cfg(not(target_arch = "aarch64"))]
11042 let mut dot = 0f32;
11043 #[cfg(target_arch = "aarch64")]
11044 {
11045 use std::arch::aarch64::*;
11046 let mut acc = vdupq_n_f32(0.0);
11047 for c in 0..neon_chunks {
11048 acc = vfmaq_f32(
11049 acc,
11050 vld1q_f32(qkv.as_ptr().add(q_base + c * 4)),
11051 vld1q_f32(qkv.as_ptr().add(k_base + c * 4)),
11052 );
11053 }
11054 dot = vaddvq_f32(acc);
11055 }
11056 #[cfg(not(target_arch = "aarch64"))]
11057 for d in 0..d_h {
11058 dot += qkv[q_base + d] * qkv[k_base + d];
11059 }
11060 sc[qi * s + ki] = dot * scale;
11061 if mk[bi * s + ki] < mask_thr {
11062 sc[qi * s + ki] = mask_neg;
11063 }
11064 }
11065 }
11066 crate::kernels::neon_softmax(&mut sc[..s * s], s, s);
11067 for qi in 0..s {
11068 let o = bi * s * h + qi * h + hi * d_h;
11069 for d in 0..d_h {
11070 attn[o + d] = 0.0;
11071 }
11072 for ki in 0..s {
11073 let w = sc[qi * s + ki];
11074 if w > score_thr {
11075 let v = bi * s * 3 * h + ki * 3 * h + 2 * h + hi * d_h;
11076 #[cfg(target_arch = "aarch64")]
11077 {
11078 use std::arch::aarch64::*;
11079 let vw = vdupq_n_f32(w);
11080 for c in 0..neon_chunks {
11081 let off = c * 4;
11082 vst1q_f32(
11083 attn.as_mut_ptr().add(o + off),
11084 vfmaq_f32(
11085 vld1q_f32(attn.as_ptr().add(o + off)),
11086 vw,
11087 vld1q_f32(qkv.as_ptr().add(v + off)),
11088 ),
11089 );
11090 }
11091 }
11092 #[cfg(not(target_arch = "aarch64"))]
11093 for d in 0..d_h {
11094 attn[o + d] += w * qkv[v + d];
11095 }
11096 }
11097 }
11098 }
11099 }
11100 }
11101
11102 crate::blas::sgemm_bias(
11104 attn,
11105 sl(*out_w, base, h * h),
11106 sl(*out_b, base, h),
11107 res,
11108 m,
11109 h,
11110 h,
11111 );
11112 #[cfg(target_arch = "aarch64")]
11113 {
11114 use std::arch::aarch64::*;
11115 let chunks_h = (m * h) / 4;
11116 for c in 0..chunks_h {
11117 let off = c * 4;
11118 vst1q_f32(
11119 res.as_mut_ptr().add(off),
11120 vaddq_f32(
11121 vld1q_f32(res.as_ptr().add(off)),
11122 vld1q_f32(inp.as_ptr().add(off)),
11123 ),
11124 );
11125 }
11126 for i in (chunks_h * 4)..(m * h) {
11127 res[i] += inp[i];
11128 }
11129 }
11130 #[cfg(not(target_arch = "aarch64"))]
11131 for i in 0..m * h {
11132 res[i] += inp[i];
11133 }
11134
11135 let g1 = sl(*ln1_g, base, h);
11137 let b1 = sl(*ln1_b, base, h);
11138 for r in 0..m {
11139 crate::kernels::layer_norm_row(
11140 &res[r * h..(r + 1) * h],
11141 g1,
11142 b1,
11143 &mut normed[r * h..(r + 1) * h],
11144 h,
11145 *eps1,
11146 );
11147 }
11148
11149 crate::blas::par_sgemm_bias(
11151 normed,
11152 sl(*fc1_w, base, h * id),
11153 sl(*fc1_b, base, id),
11154 ffn,
11155 m,
11156 h,
11157 id,
11158 );
11159 crate::kernels::par_gelu_inplace(ffn);
11160
11161 crate::blas::par_sgemm_bias(
11163 ffn,
11164 sl(*fc2_w, base, id * h),
11165 sl(*fc2_b, base, h),
11166 res,
11167 m,
11168 id,
11169 h,
11170 );
11171 #[cfg(target_arch = "aarch64")]
11172 {
11173 use std::arch::aarch64::*;
11174 let chunks_h = (m * h) / 4;
11175 for c in 0..chunks_h {
11176 let off = c * 4;
11177 vst1q_f32(
11178 res.as_mut_ptr().add(off),
11179 vaddq_f32(
11180 vld1q_f32(res.as_ptr().add(off)),
11181 vld1q_f32(normed.as_ptr().add(off)),
11182 ),
11183 );
11184 }
11185 for i in (chunks_h * 4)..(m * h) {
11186 res[i] += normed[i];
11187 }
11188 }
11189 #[cfg(not(target_arch = "aarch64"))]
11190 for i in 0..m * h {
11191 res[i] += normed[i];
11192 }
11193
11194 let g2 = sl(*ln2_g, base, h);
11196 let b2 = sl(*ln2_b, base, h);
11197 for r in 0..m {
11198 crate::kernels::layer_norm_row(
11199 &res[r * h..(r + 1) * h],
11200 g2,
11201 b2,
11202 &mut dst[r * h..(r + 1) * h],
11203 h,
11204 *eps2,
11205 );
11206 }
11207 }
11208 }
11209
11210 Thunk::FusedNomicLayer {
11211 hidden,
11212 qkv_w,
11213 out_w,
11214 mask,
11215 cos,
11216 sin,
11217 cos_len,
11218 ln1_g,
11219 ln1_b,
11220 eps1,
11221 fc11_w,
11222 fc12_w: _,
11223 fc2_w,
11224 ln2_g,
11225 ln2_b,
11226 eps2,
11227 out,
11228 batch,
11229 seq,
11230 hs,
11231 nh,
11232 dh,
11233 int_dim,
11234 } => {
11235 let (b, s, h, n_h, d_h) = (
11236 *batch as usize,
11237 *seq as usize,
11238 *hs as usize,
11239 *nh as usize,
11240 *dh as usize,
11241 );
11242 let m = b * s;
11243 let id = *int_dim as usize;
11244 let scale = (d_h as f32).powf(-0.5);
11245 let half_dh = d_h / 2;
11246 #[cfg(target_arch = "aarch64")]
11247 let neon_chunks = d_h / 4;
11248 unsafe {
11249 let inp = sl(*hidden, base, m * h);
11250 let dst = sl_mut(*out, base, m * h);
11251 let mk = sl(*mask, base, b * s);
11252 let cos_tab = sl(*cos, base, *cos_len as usize);
11253 let sin_tab = sl(*sin, base, *cos_len as usize);
11254 let fused_fc_w = sl(*fc11_w, base, h * 2 * id);
11256
11257 let mut qkv = vec![0f32; m * 3 * h];
11258 let mut attn = vec![0f32; m * h];
11259 let mut res = vec![0f32; m * h];
11260 let mut normed = vec![0f32; m * h];
11261 let mut ffn_concat = vec![0f32; m * 2 * id]; let mut sc = vec![0f32; s * s];
11263
11264 crate::blas::sgemm(inp, sl(*qkv_w, base, h * 3 * h), &mut qkv, m, h, 3 * h);
11266
11267 for bi in 0..b {
11269 for hi in 0..n_h {
11270 for qi in 0..s {
11271 for ki in 0..s {
11272 let q_base = bi * s * 3 * h + qi * 3 * h + hi * d_h;
11273 let k_base = bi * s * 3 * h + ki * 3 * h + h + hi * d_h;
11274 let mut dot = 0f32;
11275 for i in 0..half_dh {
11276 let q1 = qkv[q_base + i];
11277 let q2 = qkv[q_base + half_dh + i];
11278 let k1 = qkv[k_base + i];
11279 let k2 = qkv[k_base + half_dh + i];
11280 let cq = cos_tab[qi * half_dh + i];
11281 let sq = sin_tab[qi * half_dh + i];
11282 let ck = cos_tab[ki * half_dh + i];
11283 let sk = sin_tab[ki * half_dh + i];
11284 dot += (q1 * cq - q2 * sq) * (k1 * ck - k2 * sk)
11285 + (q2 * cq + q1 * sq) * (k2 * ck + k1 * sk);
11286 }
11287 sc[qi * s + ki] = dot * scale;
11288 if mk[bi * s + ki] < mask_thr {
11289 sc[qi * s + ki] = mask_neg;
11290 }
11291 }
11292 }
11293 crate::kernels::neon_softmax(&mut sc[..s * s], s, s);
11294 for qi in 0..s {
11295 let o = bi * s * h + qi * h + hi * d_h;
11296 for d in 0..d_h {
11297 attn[o + d] = 0.0;
11298 }
11299 for ki in 0..s {
11300 let w = sc[qi * s + ki];
11301 if w > score_thr {
11302 let v = bi * s * 3 * h + ki * 3 * h + 2 * h + hi * d_h;
11303 #[cfg(target_arch = "aarch64")]
11304 {
11305 use std::arch::aarch64::*;
11306 let vw = vdupq_n_f32(w);
11307 for c in 0..neon_chunks {
11308 let off = c * 4;
11309 vst1q_f32(
11310 attn.as_mut_ptr().add(o + off),
11311 vfmaq_f32(
11312 vld1q_f32(attn.as_ptr().add(o + off)),
11313 vw,
11314 vld1q_f32(qkv.as_ptr().add(v + off)),
11315 ),
11316 );
11317 }
11318 }
11319 #[cfg(not(target_arch = "aarch64"))]
11320 for d in 0..d_h {
11321 attn[o + d] += w * qkv[v + d];
11322 }
11323 }
11324 }
11325 }
11326 }
11327 }
11328
11329 crate::blas::sgemm(&attn, sl(*out_w, base, h * h), &mut res, m, h, h);
11331 for i in 0..m * h {
11332 res[i] += inp[i];
11333 }
11334
11335 let g1 = sl(*ln1_g, base, h);
11337 let b1 = sl(*ln1_b, base, h);
11338 for r in 0..m {
11339 crate::kernels::layer_norm_row(
11340 &res[r * h..(r + 1) * h],
11341 g1,
11342 b1,
11343 &mut normed[r * h..(r + 1) * h],
11344 h,
11345 *eps1,
11346 );
11347 }
11348
11349 crate::blas::sgemm(&normed, fused_fc_w, &mut ffn_concat, m, h, 2 * id);
11351 for row in 0..m {
11354 let bo = row * 2 * id;
11355 for j in 0..id {
11357 let x = ffn_concat[bo + id + j];
11358 ffn_concat[bo + id + j] = x / (1.0 + (-x).exp());
11359 }
11360 for j in 0..id {
11362 ffn_concat[bo + j] *= ffn_concat[bo + id + j];
11363 }
11364 }
11365
11366 crate::blas::sgemm_general(
11371 ffn_concat.as_ptr(),
11372 sl(*fc2_w, base, id * h).as_ptr(),
11373 res.as_mut_ptr(),
11374 m,
11375 h,
11376 id,
11377 1.0,
11378 0.0,
11379 2 * id,
11380 h,
11381 h,
11382 false,
11383 false,
11384 );
11385 for i in 0..m * h {
11386 res[i] += normed[i];
11387 }
11388
11389 let g2 = sl(*ln2_g, base, h);
11391 let b2 = sl(*ln2_b, base, h);
11392 for r in 0..m {
11393 crate::kernels::layer_norm_row(
11394 &res[r * h..(r + 1) * h],
11395 g2,
11396 b2,
11397 &mut dst[r * h..(r + 1) * h],
11398 h,
11399 *eps2,
11400 );
11401 }
11402 }
11403 }
11404
11405 Thunk::FusedSwiGLU {
11406 src,
11407 dst,
11408 n_half,
11409 total,
11410 gate_first,
11411 } => {
11412 let n = *n_half as usize;
11413 let t = *total as usize;
11414 let outer = t / n;
11415 let in_total = outer * 2 * n;
11416 let gate_first = *gate_first;
11417 unsafe {
11418 let inp = sl(*src, base, in_total);
11419 let out = sl_mut(*dst, base, t);
11420 for o in 0..outer {
11421 let in_row = &inp[o * 2 * n..(o + 1) * 2 * n];
11422 let out_row = &mut out[o * n..(o + 1) * n];
11423 for i in 0..n {
11424 let (up, gate) = if gate_first {
11425 (in_row[n + i], in_row[i])
11426 } else {
11427 (in_row[i], in_row[n + i])
11428 };
11429 out_row[i] = up * (gate / (1.0 + (-gate).exp()));
11430 }
11431 }
11432 }
11433 }
11434
11435 Thunk::Concat {
11436 dst,
11437 outer,
11438 inner,
11439 total_axis,
11440 inputs,
11441 } => {
11442 let outer = *outer as usize;
11443 let inner = *inner as usize;
11444 let total_axis = *total_axis as usize;
11445 let row_stride = total_axis * inner;
11446 let out_total = outer * row_stride;
11447 unsafe {
11448 let out = sl_mut(*dst, base, out_total);
11449 let mut cum: usize = 0;
11450 for (src_off, in_axis) in inputs {
11451 let in_axis = *in_axis as usize;
11452 let copy_per_row = in_axis * inner;
11453 let dst_col_off = cum * inner;
11454 let in_total = outer * copy_per_row;
11455 let inp = sl(*src_off, base, in_total);
11456 for o in 0..outer {
11457 let dst_row_start = o * row_stride + dst_col_off;
11458 let src_row_start = o * copy_per_row;
11459 out[dst_row_start..dst_row_start + copy_per_row]
11460 .copy_from_slice(&inp[src_row_start..src_row_start + copy_per_row]);
11461 }
11462 cum += in_axis;
11463 }
11464 }
11465 }
11466
11467 Thunk::ConcatF64 {
11468 dst,
11469 outer,
11470 inner,
11471 total_axis,
11472 inputs,
11473 } => {
11474 let outer = *outer as usize;
11475 let inner = *inner as usize;
11476 let total_axis = *total_axis as usize;
11477 let row_stride = total_axis * inner;
11478 let out_total = outer * row_stride;
11479 unsafe {
11480 let out = sl_mut_f64(*dst, base, out_total);
11481 let mut cum: usize = 0;
11482 for (src_off, in_axis) in inputs {
11483 let in_axis = *in_axis as usize;
11484 let copy_per_row = in_axis * inner;
11485 let dst_col_off = cum * inner;
11486 let in_total = outer * copy_per_row;
11487 let inp = sl_f64(*src_off, base, in_total);
11488 for o in 0..outer {
11489 let dst_row_start = o * row_stride + dst_col_off;
11490 let src_row_start = o * copy_per_row;
11491 out[dst_row_start..dst_row_start + copy_per_row]
11492 .copy_from_slice(&inp[src_row_start..src_row_start + copy_per_row]);
11493 }
11494 cum += in_axis;
11495 }
11496 }
11497 }
11498
11499 Thunk::Compare {
11500 lhs,
11501 rhs,
11502 dst,
11503 len,
11504 op,
11505 inputs_i64,
11506 inputs_elem_bytes,
11507 dst_elem_bytes,
11508 } => {
11509 let len = *len as usize;
11510 let arena_len = arena_buf.len();
11511 let elem = (*inputs_elem_bytes).max(1) as usize;
11512 let dst_eb = (*dst_elem_bytes).max(1) as usize;
11513 let max_l = (arena_len.saturating_sub(*lhs)) / elem;
11514 let max_r = (arena_len.saturating_sub(*rhs)) / elem;
11515 let max_d = (arena_len.saturating_sub(*dst)) / dst_eb;
11516 let len = len.min(max_l).min(max_r).min(max_d);
11517 if trace_thunks && len > 0 {
11518 eprintln!("[compare] len={len} lhs={} rhs={} dst={}", *lhs, *rhs, *dst);
11519 }
11520 if elem == 1 {
11521 let l = arena_buf[*lhs..*lhs + len].to_vec();
11522 let r = arena_buf[*rhs..*rhs + len].to_vec();
11523 for i in 0..len {
11524 let v = match op {
11525 CmpOp::Eq => l[i] == r[i],
11526 CmpOp::Ne => l[i] != r[i],
11527 CmpOp::Lt => l[i] < r[i],
11528 CmpOp::Le => l[i] <= r[i],
11529 CmpOp::Gt => l[i] > r[i],
11530 CmpOp::Ge => l[i] >= r[i],
11531 };
11532 if *dst_elem_bytes == 1 {
11533 arena_buf[*dst + i] = u8::from(v);
11534 } else {
11535 unsafe {
11536 let o = sl_mut(*dst, base, len);
11537 o[i] = if v { 1.0 } else { 0.0 };
11538 }
11539 }
11540 }
11541 } else if *inputs_i64 != 0 {
11542 unsafe {
11543 let l = sl_i64(*lhs, base, len);
11544 let r = sl_i64(*rhs, base, len);
11545 for i in 0..len {
11546 let v = match op {
11547 CmpOp::Eq => l[i] == r[i],
11548 CmpOp::Ne => l[i] != r[i],
11549 CmpOp::Lt => l[i] < r[i],
11550 CmpOp::Le => l[i] <= r[i],
11551 CmpOp::Gt => l[i] > r[i],
11552 CmpOp::Ge => l[i] >= r[i],
11553 };
11554 if *dst_elem_bytes == 1 {
11555 arena_buf[*dst + i] = u8::from(v);
11556 } else {
11557 let o = sl_mut(*dst, base, len);
11558 o[i] = if v { 1.0 } else { 0.0 };
11559 }
11560 }
11561 }
11562 } else {
11563 unsafe {
11564 let l = sl(*lhs, base, len);
11565 let r = sl(*rhs, base, len);
11566 for i in 0..len {
11567 let v = match op {
11568 CmpOp::Eq => l[i] == r[i],
11569 CmpOp::Ne => l[i] != r[i],
11570 CmpOp::Lt => l[i] < r[i],
11571 CmpOp::Le => l[i] <= r[i],
11572 CmpOp::Gt => l[i] > r[i],
11573 CmpOp::Ge => l[i] >= r[i],
11574 };
11575 if *dst_elem_bytes == 1 {
11576 arena_buf[*dst + i] = u8::from(v);
11577 } else {
11578 let o = sl_mut(*dst, base, len);
11579 o[i] = if v { 1.0 } else { 0.0 };
11580 }
11581 }
11582 }
11583 }
11584 }
11585
11586 Thunk::Where {
11587 cond,
11588 on_true,
11589 on_false,
11590 dst,
11591 len,
11592 elem_bytes,
11593 cond_elem_bytes,
11594 } => {
11595 let len = *len as usize;
11596 let eb = *elem_bytes as usize;
11597 let cond_eb = (*cond_elem_bytes).max(1) as usize;
11598 let arena_len = arena_buf.len();
11599 let len = len
11600 .min((arena_len.saturating_sub(*cond)) / cond_eb)
11601 .min((arena_len.saturating_sub(*on_true)) / eb)
11602 .min((arena_len.saturating_sub(*on_false)) / eb)
11603 .min((arena_len.saturating_sub(*dst)) / eb);
11604 unsafe {
11605 if *elem_bytes == 8 {
11606 let t = sl_i64(*on_true, base, len);
11607 let e = sl_i64(*on_false, base, len);
11608 let o = sl_mut_i64(*dst, base, len);
11609 if *cond_elem_bytes == 1 {
11610 let c = &arena_buf[*cond..*cond + len];
11611 for i in 0..len {
11612 o[i] = if c[i] != 0 { t[i] } else { e[i] };
11613 }
11614 } else {
11615 let c = sl_i64(*cond, base, len);
11616 for i in 0..len {
11617 o[i] = if c[i] != 0 { t[i] } else { e[i] };
11618 }
11619 }
11620 } else if *cond_elem_bytes == 1 {
11621 let c = &arena_buf[*cond..*cond + len];
11622 let t = sl(*on_true, base, len);
11623 let e = sl(*on_false, base, len);
11624 let o = sl_mut(*dst, base, len);
11625 for i in 0..len {
11626 o[i] = if c[i] != 0 { t[i] } else { e[i] };
11627 }
11628 } else {
11629 let c = sl(*cond, base, len);
11630 let t = sl(*on_true, base, len);
11631 let e = sl(*on_false, base, len);
11632 let o = sl_mut(*dst, base, len);
11633 for i in 0..len {
11634 o[i] = if c[i] != 0.0 { t[i] } else { e[i] };
11635 }
11636 }
11637 }
11638 }
11639
11640 Thunk::ScatterAdd {
11641 updates,
11642 indices,
11643 dst,
11644 num_updates,
11645 out_dim,
11646 trailing,
11647 } => {
11648 let num_updates = *num_updates as usize;
11649 let out_dim = *out_dim as usize;
11650 let trailing = *trailing as usize;
11651 unsafe {
11652 let upd = sl(*updates, base, num_updates * trailing);
11653 let ids = sl(*indices, base, num_updates);
11654 let out = sl_mut(*dst, base, out_dim * trailing);
11655 for v in out.iter_mut() {
11657 *v = 0.0;
11658 }
11659 for i in 0..num_updates {
11660 let row = ids[i] as usize;
11661 debug_assert!(row < out_dim, "ScatterAdd index out of range");
11662 let src_off = i * trailing;
11663 let dst_off = row * trailing;
11664 for j in 0..trailing {
11665 out[dst_off + j] += upd[src_off + j];
11666 }
11667 }
11668 }
11669 }
11670
11671 Thunk::GroupedMatMul {
11672 input,
11673 weight,
11674 expert_idx,
11675 dst,
11676 m,
11677 k_dim,
11678 n,
11679 num_experts,
11680 } => {
11681 let m = *m as usize;
11682 let k_dim = *k_dim as usize;
11683 let n = *n as usize;
11684 let num_experts = *num_experts as usize;
11685 unsafe {
11686 let inp = sl(*input, base, m * k_dim);
11687 let wt = sl(*weight, base, num_experts * k_dim * n);
11688 let ids = sl(*expert_idx, base, m);
11689 let out = sl_mut(*dst, base, m * n);
11690
11691 let mut counts = vec![0usize; num_experts];
11694 for i in 0..m {
11695 let e = ids[i] as usize;
11696 debug_assert!(
11697 e < num_experts,
11698 "expert_idx out of range: {e} >= {num_experts}"
11699 );
11700 counts[e] += 1;
11701 }
11702 let mut offsets = vec![0usize; num_experts + 1];
11704 for e in 0..num_experts {
11705 offsets[e + 1] = offsets[e] + counts[e];
11706 }
11707 let mut packed_in = vec![0f32; m * k_dim];
11711 let mut original_pos = vec![0usize; m];
11712 let mut write_idx = vec![0usize; num_experts];
11713 for i in 0..m {
11714 let e = ids[i] as usize;
11715 let dst_row = offsets[e] + write_idx[e];
11716 packed_in[dst_row * k_dim..(dst_row + 1) * k_dim]
11717 .copy_from_slice(&inp[i * k_dim..(i + 1) * k_dim]);
11718 original_pos[dst_row] = i;
11719 write_idx[e] += 1;
11720 }
11721
11722 let mut packed_out = vec![0f32; m * n];
11726 let expert_stride = k_dim * n;
11727 let gmm_ord = crate::moe_residency::next_gmm_ord();
11728 let moe_layer = gmm_ord / 3;
11729 for e in 0..num_experts {
11730 let count = counts[e];
11731 if count == 0 {
11732 continue;
11733 }
11734 crate::moe_residency::record_expert_tokens(moe_layer, e, count);
11735 let in_start = offsets[e];
11736 let in_slice = &packed_in[in_start * k_dim..(in_start + count) * k_dim];
11737 let w_slab: &[f32] =
11738 if !crate::moe_residency::expert_on_device_for_layer(moe_layer, e) {
11739 if let Some(ptr) =
11740 crate::moe_residency::host_expert_weight_ptr(gmm_ord, e)
11741 {
11742 std::slice::from_raw_parts(ptr, expert_stride)
11743 } else {
11744 &wt[e * expert_stride..(e + 1) * expert_stride]
11745 }
11746 } else {
11747 &wt[e * expert_stride..(e + 1) * expert_stride]
11748 };
11749 let out_slice = &mut packed_out[in_start * n..(in_start + count) * n];
11750 crate::blas::sgemm(in_slice, w_slab, out_slice, count, k_dim, n);
11751 }
11752
11753 for packed_idx in 0..m {
11755 let i = original_pos[packed_idx];
11756 out[i * n..(i + 1) * n]
11757 .copy_from_slice(&packed_out[packed_idx * n..(packed_idx + 1) * n]);
11758 }
11759 }
11760 }
11761
11762 Thunk::DequantGroupedMatMulGguf {
11763 input,
11764 w_q,
11765 expert_idx,
11766 dst,
11767 m,
11768 k_dim,
11769 n,
11770 num_experts,
11771 scheme,
11772 } => {
11773 let m = *m as usize;
11774 let k_dim = *k_dim as usize;
11775 let n = *n as usize;
11776 let num_experts = *num_experts as usize;
11777 let block_elems = scheme.gguf_block_size() as usize;
11778 let block_bytes = scheme.gguf_block_bytes() as usize;
11779 let slab_bytes = (k_dim * n) / block_elems * block_bytes;
11780 unsafe {
11781 let inp = sl(*input, base, m * k_dim);
11782 let wt = std::slice::from_raw_parts(
11783 base.add(*w_q) as *const u8,
11784 num_experts * slab_bytes,
11785 );
11786 let ids = sl(*expert_idx, base, m);
11787 let out = sl_mut(*dst, base, m * n);
11788 crate::gguf_matmul::gguf_grouped_matmul_bt(
11789 inp,
11790 wt,
11791 ids,
11792 out,
11793 m,
11794 k_dim,
11795 n,
11796 num_experts,
11797 *scheme,
11798 );
11799 }
11800 }
11801
11802 Thunk::DequantMoEWeightsGguf {
11803 w_q,
11804 dst,
11805 k_dim,
11806 n,
11807 num_experts,
11808 scheme,
11809 } => {
11810 let k_dim = *k_dim as usize;
11811 let n = *n as usize;
11812 let num_experts = *num_experts as usize;
11813 let block_elems = scheme.gguf_block_size() as usize;
11814 let block_bytes = scheme.gguf_block_bytes() as usize;
11815 let slab_bytes = (k_dim * n) / block_elems * block_bytes;
11816 unsafe {
11817 let wt = std::slice::from_raw_parts(
11818 base.add(*w_q) as *const u8,
11819 num_experts * slab_bytes,
11820 );
11821 let out = sl_mut(*dst, base, num_experts * k_dim * n);
11822 crate::gguf_matmul::dequant_moe_weights_to_grouped_f32(
11823 wt,
11824 out,
11825 num_experts,
11826 k_dim,
11827 n,
11828 *scheme,
11829 );
11830 }
11831 }
11832
11833 Thunk::TopK {
11834 src,
11835 dst,
11836 outer,
11837 axis_dim,
11838 k,
11839 indices_i64,
11840 } => {
11841 let outer = *outer as usize;
11842 let axis_dim = *axis_dim as usize;
11843 let k = *k as usize;
11844 unsafe {
11845 let inp = sl(*src, base, outer * axis_dim);
11846 let mut row_buf: Vec<f32> = vec![0.0; axis_dim];
11850 if *indices_i64 != 0 {
11851 let out = sl_mut_i64(*dst, base, outer * k);
11852 for o in 0..outer {
11853 row_buf.copy_from_slice(&inp[o * axis_dim..(o + 1) * axis_dim]);
11854 for ki in 0..k {
11855 let mut best_i = 0usize;
11856 let mut best_v = row_buf[0];
11857 for i in 1..axis_dim {
11858 let v = row_buf[i];
11859 if v > best_v {
11860 best_v = v;
11861 best_i = i;
11862 }
11863 }
11864 out[o * k + ki] = best_i as i64;
11865 row_buf[best_i] = f32::NEG_INFINITY;
11866 }
11867 }
11868 } else {
11869 let out = sl_mut(*dst, base, outer * k);
11870 for o in 0..outer {
11871 row_buf.copy_from_slice(&inp[o * axis_dim..(o + 1) * axis_dim]);
11872 for ki in 0..k {
11873 let mut best_i = 0usize;
11874 let mut best_v = row_buf[0];
11875 for i in 1..axis_dim {
11876 let v = row_buf[i];
11877 if v > best_v {
11878 best_v = v;
11879 best_i = i;
11880 }
11881 }
11882 out[o * k + ki] = best_i as f32;
11883 row_buf[best_i] = f32::NEG_INFINITY;
11884 }
11885 }
11886 if let Some(cap) = schedule.moe_topk_capture.as_ref() {
11887 cap.push_topk_f32(&out[..outer * k], axis_dim);
11888 }
11889 }
11890 }
11891 }
11892
11893 Thunk::Reduce {
11894 src,
11895 dst,
11896 outer,
11897 reduced,
11898 inner,
11899 op,
11900 } => {
11901 let outer = *outer as usize;
11902 let reduced = *reduced as usize;
11903 let inner = *inner as usize;
11904 let in_total = outer * reduced * inner;
11905 let out_total = outer * inner;
11906 unsafe {
11907 let inp = sl(*src, base, in_total);
11908 let out = sl_mut(*dst, base, out_total);
11909 for o in 0..outer {
11910 for i in 0..inner {
11911 let mut acc = match op {
11912 ReduceOp::Max => f32::NEG_INFINITY,
11913 ReduceOp::Min => f32::INFINITY,
11914 ReduceOp::Prod => 1.0f32,
11915 _ => 0.0f32, };
11917 for r in 0..reduced {
11919 let v = inp[o * reduced * inner + r * inner + i];
11920 acc = match op {
11921 ReduceOp::Sum | ReduceOp::Mean => acc + v,
11922 ReduceOp::Max => acc.max(v),
11923 ReduceOp::Min => acc.min(v),
11924 ReduceOp::Prod => acc * v,
11925 };
11926 }
11927 if matches!(op, ReduceOp::Mean) {
11928 acc /= reduced as f32;
11929 }
11930 out[o * inner + i] = acc;
11931 }
11932 }
11933 }
11934 }
11935
11936 Thunk::Conv2D1x1 {
11937 src,
11938 weight,
11939 dst,
11940 n,
11941 c_in,
11942 c_out,
11943 hw,
11944 } => {
11945 let n = *n as usize;
11946 let c_in = *c_in as usize;
11947 let c_out = *c_out as usize;
11948 let hw = *hw as usize;
11949 unsafe {
11950 let inp = sl(*src, base, n * c_in * hw);
11951 let wt = sl(*weight, base, c_out * c_in);
11952 let out = sl_mut(*dst, base, n * c_out * hw);
11953 for ni in 0..n {
11958 let in_off = ni * c_in * hw;
11959 let out_off = ni * c_out * hw;
11960 crate::blas::sgemm(
11961 wt,
11962 &inp[in_off..in_off + c_in * hw],
11963 &mut out[out_off..out_off + c_out * hw],
11964 c_out,
11965 c_in,
11966 hw,
11967 );
11968 }
11969 }
11970 }
11971
11972 Thunk::Conv2D {
11973 src,
11974 weight,
11975 dst,
11976 n,
11977 c_in,
11978 h,
11979 w,
11980 c_out,
11981 h_out,
11982 w_out,
11983 kh,
11984 kw,
11985 sh,
11986 sw,
11987 ph,
11988 pw,
11989 dh,
11990 dw,
11991 groups,
11992 } => {
11993 let n = *n as usize;
11994 let c_in = *c_in as usize;
11995 let h = *h as usize;
11996 let w = *w as usize;
11997 let c_out = *c_out as usize;
11998 let h_out = *h_out as usize;
11999 let w_out = *w_out as usize;
12000 let kh = *kh as usize;
12001 let kw = *kw as usize;
12002 let sh = *sh as usize;
12003 let sw = *sw as usize;
12004 let ph = *ph as usize;
12005 let pw = *pw as usize;
12006 let dh = *dh as usize;
12007 let dw = *dw as usize;
12008 let groups = *groups as usize;
12009 let c_in_per_g = c_in / groups;
12010 let c_out_per_g = c_out / groups;
12011 unsafe {
12012 let inp = sl(*src, base, n * c_in * h * w);
12013 let wt = sl(*weight, base, c_out * c_in_per_g * kh * kw);
12014 let out = sl_mut(*dst, base, n * c_out * h_out * w_out);
12015 for ni in 0..n {
12016 for co in 0..c_out {
12017 let g = co / c_out_per_g;
12018 let ci_start = g * c_in_per_g;
12019 for ho in 0..h_out {
12020 for wo in 0..w_out {
12021 let mut acc = 0f32;
12022 for ci_off in 0..c_in_per_g {
12023 let ci = ci_start + ci_off;
12024 let in_chan = ((ni * c_in) + ci) * h * w;
12025 let wt_chan = ((co * c_in_per_g) + ci_off) * kh * kw;
12026 for ki in 0..kh {
12027 for kj in 0..kw {
12028 let hi = ho * sh + ki * dh;
12029 let wi = wo * sw + kj * dw;
12030 if hi < ph || wi < pw {
12031 continue;
12032 }
12033 let hi = hi - ph;
12034 let wi = wi - pw;
12035 if hi >= h || wi >= w {
12036 continue;
12037 }
12038 acc += inp[in_chan + hi * w + wi]
12039 * wt[wt_chan + ki * kw + kj];
12040 }
12041 }
12042 }
12043 out[((ni * c_out) + co) * h_out * w_out + ho * w_out + wo] =
12044 acc;
12045 }
12046 }
12047 }
12048 }
12049 }
12050 }
12051
12052 Thunk::Pool2D {
12053 src,
12054 dst,
12055 n,
12056 c,
12057 h,
12058 w,
12059 h_out,
12060 w_out,
12061 kh,
12062 kw,
12063 sh,
12064 sw,
12065 ph,
12066 pw,
12067 kind,
12068 } => {
12069 let n = *n as usize;
12070 let c = *c as usize;
12071 let h = *h as usize;
12072 let w = *w as usize;
12073 let h_out = *h_out as usize;
12074 let w_out = *w_out as usize;
12075 let kh = *kh as usize;
12076 let kw = *kw as usize;
12077 let sh = *sh as usize;
12078 let sw = *sw as usize;
12079 let ph = *ph as usize;
12080 let pw = *pw as usize;
12081 let kernel_area = (kh * kw) as f32;
12082 unsafe {
12083 let inp = sl(*src, base, n * c * h * w);
12084 let out = sl_mut(*dst, base, n * c * h_out * w_out);
12085 for ni in 0..n {
12086 for ci in 0..c {
12087 let in_chan = ni * c * h * w + ci * h * w;
12088 let out_chan = ni * c * h_out * w_out + ci * h_out * w_out;
12089 for ho in 0..h_out {
12090 for wo in 0..w_out {
12091 let mut acc = match kind {
12092 ReduceOp::Max => f32::NEG_INFINITY,
12093 _ => 0f32, };
12095 for ki in 0..kh {
12096 for kj in 0..kw {
12097 let hi = ho * sh + ki;
12098 let wi = wo * sw + kj;
12099 if hi < ph || wi < pw {
12101 continue;
12102 }
12103 let hi = hi - ph;
12104 let wi = wi - pw;
12105 if hi >= h || wi >= w {
12106 continue;
12107 }
12108 let v = inp[in_chan + hi * w + wi];
12109 match kind {
12110 ReduceOp::Max => acc = acc.max(v),
12111 _ => acc += v,
12112 }
12113 }
12114 }
12115 if matches!(kind, ReduceOp::Mean) {
12116 acc /= kernel_area;
12117 }
12118 out[out_chan + ho * w_out + wo] = acc;
12119 }
12120 }
12121 }
12122 }
12123 }
12124 }
12125
12126 Thunk::ReluBackward { x, dy, dx, len } => {
12127 let len = *len as usize;
12128 unsafe {
12129 let xs = sl(*x, base, len);
12130 let dys = sl(*dy, base, len);
12131 let out = sl_mut(*dx, base, len);
12132 for i in 0..len {
12133 out[i] = if xs[i] > 0.0 { dys[i] } else { 0.0 };
12134 }
12135 }
12136 }
12137
12138 Thunk::ReluBackwardF64 { x, dy, dx, len } => {
12139 let len = *len as usize;
12140 unsafe {
12141 let xs = sl_f64(*x, base, len);
12142 let dys = sl_f64(*dy, base, len);
12143 let out = sl_mut_f64(*dx, base, len);
12144 for i in 0..len {
12145 out[i] = if xs[i] > 0.0 { dys[i] } else { 0.0 };
12146 }
12147 }
12148 }
12149
12150 Thunk::QMatMul {
12151 x,
12152 w,
12153 bias,
12154 out,
12155 m,
12156 k,
12157 n,
12158 x_zp,
12159 w_zp,
12160 out_zp,
12161 mult,
12162 } => {
12163 let m = *m as usize;
12164 let k = *k as usize;
12165 let n = *n as usize;
12166 unsafe {
12167 let x_ptr = base.add(*x) as *const i8;
12168 let w_ptr = base.add(*w) as *const i8;
12169 let bias_ptr = base.add(*bias) as *const i32;
12170 let out_ptr = base.add(*out) as *mut i8;
12171 for mi in 0..m {
12172 for ni in 0..n {
12173 let mut acc: i32 = *bias_ptr.add(ni);
12174 for ki in 0..k {
12175 let xv = *x_ptr.add(mi * k + ki) as i32 - *x_zp;
12176 let wv = *w_ptr.add(ki * n + ni) as i32 - *w_zp;
12177 acc += xv * wv;
12178 }
12179 let r = (acc as f32 * *mult).round() as i32 + *out_zp;
12182 let r = r.clamp(-128, 127) as i8;
12183 *out_ptr.add(mi * n + ni) = r;
12184 }
12185 }
12186 }
12187 }
12188
12189 Thunk::QConv2d {
12190 x,
12191 w,
12192 bias,
12193 out,
12194 n,
12195 c_in,
12196 h,
12197 w_in,
12198 c_out,
12199 h_out,
12200 w_out,
12201 kh,
12202 kw,
12203 sh,
12204 sw,
12205 ph,
12206 pw,
12207 dh,
12208 dw,
12209 groups,
12210 x_zp,
12211 w_zp,
12212 out_zp,
12213 mult,
12214 } => {
12215 let n = *n as usize;
12216 let c_in = *c_in as usize;
12217 let h = *h as usize;
12218 let w_in = *w_in as usize;
12219 let c_out = *c_out as usize;
12220 let h_out = *h_out as usize;
12221 let w_out = *w_out as usize;
12222 let kh = *kh as usize;
12223 let kw = *kw as usize;
12224 let sh = *sh as usize;
12225 let sw = *sw as usize;
12226 let ph = *ph as usize;
12227 let pw = *pw as usize;
12228 let dh = *dh as usize;
12229 let dw = *dw as usize;
12230 let groups = *groups as usize;
12231 let c_in_per_g = c_in / groups;
12232 let c_out_per_g = c_out / groups;
12233 unsafe {
12234 let x_ptr = base.add(*x) as *const i8;
12235 let w_ptr = base.add(*w) as *const i8;
12236 let bias_ptr = base.add(*bias) as *const i32;
12237 let out_ptr = base.add(*out) as *mut i8;
12238 for ni in 0..n {
12239 for co in 0..c_out {
12240 let g = co / c_out_per_g;
12241 let ci_start = g * c_in_per_g;
12242 for ho in 0..h_out {
12243 for wo in 0..w_out {
12244 let mut acc: i32 = *bias_ptr.add(co);
12245 for ci_off in 0..c_in_per_g {
12246 let ci = ci_start + ci_off;
12247 let in_chan = ((ni * c_in) + ci) * h * w_in;
12248 let wt_chan = ((co * c_in_per_g) + ci_off) * kh * kw;
12249 for ki in 0..kh {
12250 for kj in 0..kw {
12251 let hi = ho * sh + ki * dh;
12252 let wi = wo * sw + kj * dw;
12253 if hi < ph || wi < pw {
12254 continue;
12255 }
12256 let hi = hi - ph;
12257 let wi = wi - pw;
12258 if hi >= h || wi >= w_in {
12259 continue;
12260 }
12261 let xv = *x_ptr.add(in_chan + hi * w_in + wi)
12262 as i32
12263 - *x_zp;
12264 let wv = *w_ptr.add(wt_chan + ki * kw + kj) as i32
12265 - *w_zp;
12266 acc += xv * wv;
12267 }
12268 }
12269 }
12270 let r = (acc as f32 * *mult).round() as i32 + *out_zp;
12271 let r = r.clamp(-128, 127) as i8;
12272 let dst = ((ni * c_out) + co) * h_out * w_out + ho * w_out + wo;
12273 *out_ptr.add(dst) = r;
12274 }
12275 }
12276 }
12277 }
12278 }
12279 }
12280
12281 Thunk::Quantize {
12282 x,
12283 q,
12284 len,
12285 chan_axis: _,
12286 chan_dim,
12287 inner,
12288 scales,
12289 zero_points,
12290 } => {
12291 let len = *len as usize;
12292 let chan_dim = *chan_dim as usize;
12293 let inner = *inner as usize;
12294 unsafe {
12295 let xs = sl(*x, base, len);
12296 let q_ptr = base.add(*q) as *mut i8;
12297 for i in 0..len {
12298 let c = if chan_dim == 1 {
12299 0
12300 } else {
12301 (i / inner) % chan_dim
12302 };
12303 let inv_scale = 1.0 / scales[c];
12304 let zp = zero_points[c];
12305 let v = (xs[i] * inv_scale).round() as i32 + zp;
12306 *q_ptr.add(i) = v.clamp(-128, 127) as i8;
12307 }
12308 }
12309 }
12310
12311 Thunk::Dequantize {
12312 q,
12313 x,
12314 len,
12315 chan_axis: _,
12316 chan_dim,
12317 inner,
12318 scales,
12319 zero_points,
12320 } => {
12321 let len = *len as usize;
12322 let chan_dim = *chan_dim as usize;
12323 let inner = *inner as usize;
12324 unsafe {
12325 let q_ptr = base.add(*q) as *const i8;
12326 let out = sl_mut(*x, base, len);
12327 for i in 0..len {
12328 let c = if chan_dim == 1 {
12329 0
12330 } else {
12331 (i / inner) % chan_dim
12332 };
12333 let scale = scales[c];
12334 let zp = zero_points[c];
12335 let qv = *q_ptr.add(i) as i32;
12336 out[i] = (qv - zp) as f32 * scale;
12337 }
12338 }
12339 }
12340
12341 Thunk::FakeQuantize {
12342 x,
12343 out,
12344 len,
12345 chan_axis: _,
12346 chan_dim,
12347 inner,
12348 bits,
12349 ste: _,
12350 scale_mode,
12351 state_off,
12352 } => {
12353 use rlx_ir::op::ScaleMode;
12354 let len = *len as usize;
12355 let chan_dim = *chan_dim as usize;
12356 let inner = *inner as usize;
12357 let q_max: f32 = match *bits {
12358 8 => 127.0,
12359 4 => 7.0,
12360 2 => 1.0,
12361 n => panic!("FakeQuantize: unsupported bits {n}"),
12362 };
12363 unsafe {
12364 let xs = sl(*x, base, len);
12365 let outs = sl_mut(*out, base, len);
12366
12367 let mut scale = vec![0f32; chan_dim];
12368 match scale_mode {
12369 ScaleMode::PerBatch => {
12370 let mut max_abs = vec![0f32; chan_dim];
12371 for i in 0..len {
12372 let c = if chan_dim == 1 {
12373 0
12374 } else {
12375 (i / inner) % chan_dim
12376 };
12377 let a = xs[i].abs();
12378 if a > max_abs[c] {
12379 max_abs[c] = a;
12380 }
12381 }
12382 for c in 0..chan_dim {
12383 scale[c] = (max_abs[c] / q_max).max(1e-12);
12384 }
12385 }
12386 ScaleMode::EMA { decay } => {
12387 let mut max_abs = vec![0f32; chan_dim];
12390 for i in 0..len {
12391 let c = if chan_dim == 1 {
12392 0
12393 } else {
12394 (i / inner) % chan_dim
12395 };
12396 let a = xs[i].abs();
12397 if a > max_abs[c] {
12398 max_abs[c] = a;
12399 }
12400 }
12401 let state =
12402 sl_mut(state_off.expect("EMA needs state_off"), base, chan_dim);
12403 for c in 0..chan_dim {
12404 let cur = (max_abs[c] / q_max).max(1e-12);
12405 let blended = if state[c] <= 0.0 {
12407 cur
12408 } else {
12409 *decay * state[c] + (1.0 - *decay) * cur
12410 };
12411 state[c] = blended;
12412 scale[c] = blended;
12413 }
12414 }
12415 ScaleMode::Fixed => {
12416 let state =
12417 sl(state_off.expect("Fixed needs state_off"), base, chan_dim);
12418 for c in 0..chan_dim {
12419 scale[c] = state[c].max(1e-12);
12420 }
12421 }
12422 }
12423
12424 for i in 0..len {
12425 let c = if chan_dim == 1 {
12426 0
12427 } else {
12428 (i / inner) % chan_dim
12429 };
12430 let s = scale[c];
12431 let qv = (xs[i] / s).round().clamp(-q_max, q_max);
12432 outs[i] = qv * s;
12433 }
12434 }
12435 }
12436
12437 Thunk::ActivationBackward {
12438 x,
12439 dy,
12440 dx,
12441 len,
12442 kind,
12443 } => {
12444 let len = *len as usize;
12445 unsafe {
12446 let xs = sl(*x, base, len);
12447 let dys = sl(*dy, base, len);
12448 let out = sl_mut(*dx, base, len);
12449 activation_backward_kernel(*kind, xs, dys, out);
12450 }
12451 }
12452
12453 Thunk::ActivationBackwardF64 {
12454 x,
12455 dy,
12456 dx,
12457 len,
12458 kind,
12459 } => {
12460 let len = *len as usize;
12461 unsafe {
12462 let xs = sl_f64(*x, base, len);
12463 let dys = sl_f64(*dy, base, len);
12464 let out = sl_mut_f64(*dx, base, len);
12465 activation_backward_kernel_f64(*kind, xs, dys, out);
12466 }
12467 }
12468
12469 Thunk::FakeQuantizeLSQ {
12470 x,
12471 scale_off,
12472 out,
12473 len,
12474 chan_axis: _,
12475 chan_dim,
12476 inner,
12477 bits,
12478 } => {
12479 let len = *len as usize;
12480 let chan_dim = *chan_dim as usize;
12481 let inner = *inner as usize;
12482 let q_max: f32 = match *bits {
12483 8 => 127.0,
12484 4 => 7.0,
12485 2 => 1.0,
12486 n => panic!("FakeQuantizeLSQ: bad bits {n}"),
12487 };
12488 unsafe {
12489 let xs = sl(*x, base, len);
12490 let scale = sl(*scale_off, base, chan_dim);
12491 let outs = sl_mut(*out, base, len);
12492 for i in 0..len {
12493 let c = if chan_dim == 1 {
12494 0
12495 } else {
12496 (i / inner) % chan_dim
12497 };
12498 let s = scale[c].max(1e-12);
12499 let qv = (xs[i] / s).round().clamp(-q_max, q_max);
12500 outs[i] = qv * s;
12501 }
12502 }
12503 }
12504
12505 Thunk::FakeQuantizeLSQBackwardX {
12506 x,
12507 scale_off,
12508 dy,
12509 dx,
12510 len,
12511 chan_axis: _,
12512 chan_dim,
12513 inner,
12514 bits,
12515 } => {
12516 let len = *len as usize;
12517 let chan_dim = *chan_dim as usize;
12518 let inner = *inner as usize;
12519 let q_max: f32 = match *bits {
12520 8 => 127.0,
12521 4 => 7.0,
12522 2 => 1.0,
12523 n => panic!("FakeQuantizeLSQBackwardX: bad bits {n}"),
12524 };
12525 unsafe {
12526 let xs = sl(*x, base, len);
12527 let scale = sl(*scale_off, base, chan_dim);
12528 let dys = sl(*dy, base, len);
12529 let outs = sl_mut(*dx, base, len);
12530 for i in 0..len {
12532 let c = if chan_dim == 1 {
12533 0
12534 } else {
12535 (i / inner) % chan_dim
12536 };
12537 let z = xs[i] / scale[c].max(1e-12);
12538 outs[i] = if z.abs() <= q_max { dys[i] } else { 0.0 };
12539 }
12540 }
12541 }
12542
12543 Thunk::FakeQuantizeLSQBackwardScale {
12544 x,
12545 scale_off,
12546 dy,
12547 dscale,
12548 len,
12549 chan_axis: _,
12550 chan_dim,
12551 inner,
12552 bits,
12553 } => {
12554 let len = *len as usize;
12555 let chan_dim = *chan_dim as usize;
12556 let inner = *inner as usize;
12557 let q_max: f32 = match *bits {
12558 8 => 127.0,
12559 4 => 7.0,
12560 2 => 1.0,
12561 n => panic!("FakeQuantizeLSQBackwardScale: bad bits {n}"),
12562 };
12563 unsafe {
12564 let xs = sl(*x, base, len);
12565 let scale = sl(*scale_off, base, chan_dim);
12566 let dys = sl(*dy, base, len);
12567 let outs = sl_mut(*dscale, base, chan_dim);
12568 for v in outs.iter_mut() {
12569 *v = 0.0;
12570 }
12571 for i in 0..len {
12574 let c = if chan_dim == 1 {
12575 0
12576 } else {
12577 (i / inner) % chan_dim
12578 };
12579 let s = scale[c].max(1e-12);
12580 let z = xs[i] / s;
12581 let psi = if z.abs() <= q_max {
12582 -z + z.round()
12583 } else if z > 0.0 {
12584 q_max
12585 } else {
12586 -q_max
12587 };
12588 outs[c] += psi * dys[i];
12589 }
12590 }
12591 }
12592
12593 Thunk::FakeQuantizeBackward {
12594 x,
12595 dy,
12596 dx,
12597 len,
12598 chan_axis: _,
12599 chan_dim,
12600 inner,
12601 bits,
12602 ste,
12603 } => {
12604 use rlx_ir::op::SteKind;
12605 let len = *len as usize;
12606 let chan_dim = *chan_dim as usize;
12607 let inner = *inner as usize;
12608 let q_max: f32 = match *bits {
12609 8 => 127.0,
12610 4 => 7.0,
12611 2 => 1.0,
12612 n => panic!("FakeQuantizeBackward: bad bits {n}"),
12613 };
12614 unsafe {
12615 let xs = sl(*x, base, len);
12616 let dys = sl(*dy, base, len);
12617 let outs = sl_mut(*dx, base, len);
12618
12619 let mut max_abs = vec![0f32; chan_dim];
12621 for i in 0..len {
12622 let c = if chan_dim == 1 {
12623 0
12624 } else {
12625 (i / inner) % chan_dim
12626 };
12627 let a = xs[i].abs();
12628 if a > max_abs[c] {
12629 max_abs[c] = a;
12630 }
12631 }
12632 let mut scale = vec![0f32; chan_dim];
12633 for c in 0..chan_dim {
12634 scale[c] = (max_abs[c] / q_max).max(1e-12);
12635 }
12636
12637 match *ste {
12638 SteKind::Identity => {
12639 outs.copy_from_slice(dys);
12641 }
12642 SteKind::ClippedIdentity => {
12643 for i in 0..len {
12646 let c = if chan_dim == 1 {
12647 0
12648 } else {
12649 (i / inner) % chan_dim
12650 };
12651 let bound = q_max * scale[c];
12652 outs[i] = if xs[i].abs() <= bound { dys[i] } else { 0.0 };
12653 }
12654 }
12655 SteKind::Tanh => {
12656 for i in 0..len {
12658 let c = if chan_dim == 1 {
12659 0
12660 } else {
12661 (i / inner) % chan_dim
12662 };
12663 let t = (xs[i] / scale[c]).tanh();
12664 outs[i] = dys[i] * (1.0 - t * t);
12665 }
12666 }
12667 SteKind::HardTanh => {
12668 for i in 0..len {
12670 let c = if chan_dim == 1 {
12671 0
12672 } else {
12673 (i / inner) % chan_dim
12674 };
12675 let bound = q_max * scale[c];
12676 let attenuation = (1.0 - (xs[i] / bound).abs()).max(0.0);
12677 outs[i] = dys[i] * attenuation;
12678 }
12679 }
12680 }
12681 }
12682 }
12683
12684 Thunk::LayerNormBackwardInput {
12685 x,
12686 gamma,
12687 dy,
12688 dx,
12689 rows,
12690 h,
12691 eps,
12692 } => {
12693 let rows = *rows as usize;
12694 let h = *h as usize;
12695 let eps = *eps;
12696 unsafe {
12697 let xs = sl(*x, base, rows * h);
12698 let g = sl(*gamma, base, h);
12699 let dys = sl(*dy, base, rows * h);
12700 let out = sl_mut(*dx, base, rows * h);
12701 let n_inv = 1.0 / h as f32;
12702 for r in 0..rows {
12703 let xr = &xs[r * h..(r + 1) * h];
12704 let dyr = &dys[r * h..(r + 1) * h];
12705 let mut sum = 0f32;
12708 for &v in xr {
12709 sum += v;
12710 }
12711 let mean = sum * n_inv;
12712 let mut var = 0f32;
12713 for &v in xr {
12714 let d = v - mean;
12715 var += d * d;
12716 }
12717 let inv_std = 1.0 / (var * n_inv + eps).sqrt();
12718
12719 let mut s_sy = 0f32;
12722 let mut s_sxh = 0f32;
12723 for d in 0..h {
12724 let xh = (xr[d] - mean) * inv_std;
12725 let sy = dyr[d] * g[d];
12726 s_sy += sy;
12727 s_sxh += sy * xh;
12728 }
12729 let m_sy = s_sy * n_inv;
12730 let m_sxh = s_sxh * n_inv;
12731
12732 for d in 0..h {
12733 let xh = (xr[d] - mean) * inv_std;
12734 let sy = dyr[d] * g[d];
12735 out[r * h + d] = inv_std * (sy - m_sy - xh * m_sxh);
12736 }
12737 }
12738 }
12739 }
12740
12741 Thunk::BatchNormInferenceBackwardInput {
12742 x,
12743 gamma,
12744 mean,
12745 var,
12746 dy,
12747 dx,
12748 count,
12749 channels,
12750 eps,
12751 } => {
12752 let count = *count as usize;
12753 let c = *channels as usize;
12754 let n = count * c;
12755 let eps = *eps;
12756 unsafe {
12757 crate::kernels::batch_norm_inference_backward_input(
12758 sl(*x, base, n),
12759 sl(*gamma, base, c),
12760 sl(*mean, base, c),
12761 sl(*var, base, c),
12762 sl(*dy, base, n),
12763 sl_mut(*dx, base, n),
12764 c,
12765 eps,
12766 );
12767 }
12768 }
12769
12770 Thunk::BatchNormInferenceBackwardGamma {
12771 x,
12772 mean,
12773 var,
12774 dy,
12775 dgamma,
12776 count,
12777 channels,
12778 eps,
12779 } => {
12780 let count = *count as usize;
12781 let c = *channels as usize;
12782 let n = count * c;
12783 let eps = *eps;
12784 unsafe {
12785 crate::kernels::batch_norm_inference_backward_gamma(
12786 sl(*x, base, n),
12787 sl(*mean, base, c),
12788 sl(*var, base, c),
12789 sl(*dy, base, n),
12790 sl_mut(*dgamma, base, c),
12791 c,
12792 eps,
12793 );
12794 }
12795 }
12796
12797 Thunk::BatchNormInferenceBackwardBeta {
12798 dy,
12799 dbeta,
12800 count,
12801 channels,
12802 } => {
12803 let count = *count as usize;
12804 let c = *channels as usize;
12805 let n = count * c;
12806 unsafe {
12807 crate::kernels::batch_norm_inference_backward_beta(
12808 sl(*dy, base, n),
12809 sl_mut(*dbeta, base, c),
12810 c,
12811 );
12812 }
12813 }
12814
12815 Thunk::LayerNormBackwardGamma {
12816 x,
12817 dy,
12818 dgamma,
12819 rows,
12820 h,
12821 eps,
12822 } => {
12823 let rows = *rows as usize;
12824 let h = *h as usize;
12825 let eps = *eps;
12826 unsafe {
12827 let xs = sl(*x, base, rows * h);
12828 let dys = sl(*dy, base, rows * h);
12829 let out = sl_mut(*dgamma, base, h);
12830 for v in out.iter_mut() {
12831 *v = 0.0;
12832 }
12833 let n_inv = 1.0 / h as f32;
12834 for r in 0..rows {
12835 let xr = &xs[r * h..(r + 1) * h];
12836 let dyr = &dys[r * h..(r + 1) * h];
12837 let mut sum = 0f32;
12838 for &v in xr {
12839 sum += v;
12840 }
12841 let mean = sum * n_inv;
12842 let mut var = 0f32;
12843 for &v in xr {
12844 let d = v - mean;
12845 var += d * d;
12846 }
12847 let inv_std = 1.0 / (var * n_inv + eps).sqrt();
12848 for d in 0..h {
12849 let xh = (xr[d] - mean) * inv_std;
12850 out[d] += dyr[d] * xh;
12851 }
12852 }
12853 }
12854 }
12855
12856 Thunk::RmsNormBackwardInput {
12857 x,
12858 gamma,
12859 beta,
12860 dy,
12861 dx,
12862 rows,
12863 h,
12864 eps,
12865 } => {
12866 let (rows, h) = (*rows as usize, *h as usize);
12867 unsafe {
12868 let xs = sl(*x, base, rows * h);
12869 let g = sl(*gamma, base, h);
12870 let b = sl(*beta, base, h);
12871 let dys = sl(*dy, base, rows * h);
12872 let out = sl_mut(*dx, base, rows * h);
12873 let mut dg = vec![0f32; h];
12874 let mut db = vec![0f32; h];
12875 for r in 0..rows {
12876 crate::training_bwd::rms_norm_backward_row(
12877 &xs[r * h..(r + 1) * h],
12878 g,
12879 b,
12880 &dys[r * h..(r + 1) * h],
12881 &mut out[r * h..(r + 1) * h],
12882 &mut dg,
12883 &mut db,
12884 *eps,
12885 );
12886 }
12887 }
12888 }
12889
12890 Thunk::RmsNormBackwardGamma {
12891 x,
12892 gamma,
12893 beta,
12894 dy,
12895 dgamma,
12896 rows,
12897 h,
12898 eps,
12899 } => {
12900 let (rows, h) = (*rows as usize, *h as usize);
12901 unsafe {
12902 let xs = sl(*x, base, rows * h);
12903 let g = sl(*gamma, base, h);
12904 let b = sl(*beta, base, h);
12905 let dys = sl(*dy, base, rows * h);
12906 let out = sl_mut(*dgamma, base, h);
12907 for v in out.iter_mut() {
12908 *v = 0.0;
12909 }
12910 let mut dx = vec![0f32; h];
12911 let mut db = vec![0f32; h];
12912 for r in 0..rows {
12913 crate::training_bwd::rms_norm_backward_row(
12914 &xs[r * h..(r + 1) * h],
12915 g,
12916 b,
12917 &dys[r * h..(r + 1) * h],
12918 &mut dx,
12919 &mut *out,
12920 &mut db,
12921 *eps,
12922 );
12923 }
12924 }
12925 }
12926
12927 Thunk::RmsNormBackwardBeta {
12928 x,
12929 gamma,
12930 beta,
12931 dy,
12932 dbeta,
12933 rows,
12934 h,
12935 eps,
12936 } => {
12937 let (rows, h) = (*rows as usize, *h as usize);
12938 unsafe {
12939 let xs = sl(*x, base, rows * h);
12940 let g = sl(*gamma, base, h);
12941 let b = sl(*beta, base, h);
12942 let dys = sl(*dy, base, rows * h);
12943 let out = sl_mut(*dbeta, base, h);
12944 for v in out.iter_mut() {
12945 *v = 0.0;
12946 }
12947 let mut dx = vec![0f32; h];
12948 let mut dg = vec![0f32; h];
12949 for r in 0..rows {
12950 crate::training_bwd::rms_norm_backward_row(
12951 &xs[r * h..(r + 1) * h],
12952 g,
12953 b,
12954 &dys[r * h..(r + 1) * h],
12955 &mut dx,
12956 &mut dg,
12957 &mut *out,
12958 *eps,
12959 );
12960 }
12961 }
12962 }
12963
12964 Thunk::RopeBackward {
12965 dy,
12966 cos,
12967 sin,
12968 dx,
12969 batch,
12970 seq,
12971 hidden,
12972 head_dim,
12973 n_rot,
12974 cos_len,
12975 } => {
12976 let (b, s, hs, dh, nr, cl) = (
12977 *batch as usize,
12978 *seq as usize,
12979 *hidden as usize,
12980 *head_dim as usize,
12981 *n_rot as usize,
12982 *cos_len as usize,
12983 );
12984 let nh = hs / dh;
12985 let tab_half = dh / 2;
12986 unsafe {
12987 let dys = sl(*dy, base, b * s * hs);
12988 let cos_tab = sl(*cos, base, cl);
12989 let sin_tab = sl(*sin, base, cl);
12990 let out = sl_mut(*dx, base, b * s * hs);
12991 for bi in 0..b {
12992 for si in 0..s {
12993 let tab_off = si.saturating_mul(tab_half) % cl.max(1);
12994 let cp = &cos_tab[tab_off..tab_off + tab_half.min(cl)];
12995 let sp = &sin_tab[tab_off..tab_off + tab_half.min(cl)];
12996 for hi in 0..nh {
12997 let base_idx = bi * s * hs + si * hs + hi * dh;
12998 crate::training_bwd::rope_backward_row(
12999 &dys[base_idx..base_idx + dh],
13000 cp,
13001 sp,
13002 &mut out[base_idx..base_idx + dh],
13003 dh,
13004 nr,
13005 );
13006 }
13007 }
13008 }
13009 }
13010 }
13011
13012 Thunk::CumsumBackward {
13013 dy,
13014 dx,
13015 rows,
13016 cols,
13017 exclusive,
13018 } => {
13019 let (rows, cols) = (*rows as usize, *cols as usize);
13020 unsafe {
13021 let dys = sl(*dy, base, rows * cols);
13022 let out = sl_mut(*dx, base, rows * cols);
13023 for r in 0..rows {
13024 crate::training_bwd::cumsum_backward_row(
13025 &dys[r * cols..(r + 1) * cols],
13026 &mut out[r * cols..(r + 1) * cols],
13027 *exclusive,
13028 );
13029 }
13030 }
13031 }
13032
13033 Thunk::GroupNormBackwardInput {
13034 x,
13035 gamma,
13036 beta: _beta,
13037 dy,
13038 dx,
13039 n,
13040 c,
13041 h,
13042 w,
13043 num_groups,
13044 eps,
13045 } => {
13046 let (n, c, h, w) = (*n as usize, *c as usize, *h as usize, *w as usize);
13047 let plane = c * h * w;
13048 unsafe {
13049 let xs = sl(*x, base, n * plane);
13050 let g = sl(*gamma, base, c);
13051 let dys = sl(*dy, base, n * plane);
13052 let out = sl_mut(*dx, base, n * plane);
13053 crate::training_bwd::group_norm_backward_input_nchw(
13054 xs,
13055 g,
13056 dys,
13057 out,
13058 n,
13059 c,
13060 h,
13061 w,
13062 *num_groups as usize,
13063 *eps,
13064 );
13065 }
13066 }
13067
13068 Thunk::GroupNormBackwardGamma {
13069 x,
13070 dy,
13071 dgamma,
13072 n,
13073 c,
13074 h,
13075 w,
13076 num_groups,
13077 eps,
13078 } => {
13079 let (n, c, h, w) = (*n as usize, *c as usize, *h as usize, *w as usize);
13080 let plane = c * h * w;
13081 unsafe {
13082 let xs = sl(*x, base, n * plane);
13083 let dys = sl(*dy, base, n * plane);
13084 let out = sl_mut(*dgamma, base, c);
13085 crate::training_bwd::group_norm_backward_gamma_nchw(
13086 xs,
13087 dys,
13088 out,
13089 n,
13090 c,
13091 h,
13092 w,
13093 *num_groups as usize,
13094 *eps,
13095 );
13096 }
13097 }
13098
13099 Thunk::GroupNormBackwardBeta {
13100 dy,
13101 dbeta,
13102 n,
13103 c,
13104 h,
13105 w,
13106 } => {
13107 let (n, c, h, w) = (*n as usize, *c as usize, *h as usize, *w as usize);
13108 let plane = c * h * w;
13109 unsafe {
13110 let dys = sl(*dy, base, n * plane);
13111 let out = sl_mut(*dbeta, base, c);
13112 crate::training_bwd::group_norm_backward_beta_nchw(dys, out, n, c, h, w);
13113 }
13114 }
13115
13116 Thunk::GatherBackward {
13117 dy,
13118 indices,
13119 dst,
13120 outer,
13121 axis_dim,
13122 num_idx,
13123 trailing,
13124 } => {
13125 let (outer, axis_dim, num_idx, trailing) = (
13126 *outer as usize,
13127 *axis_dim as usize,
13128 *num_idx as usize,
13129 *trailing as usize,
13130 );
13131 unsafe {
13132 let dys = sl(*dy, base, outer * num_idx * trailing);
13133 let ids = sl(*indices, base, num_idx);
13134 let out = sl_mut(*dst, base, outer * axis_dim * trailing);
13135 for v in out.iter_mut() {
13136 *v = 0.0;
13137 }
13138 crate::training_bwd::gather_axis_backward(
13139 dys, ids, out, outer, axis_dim, num_idx, trailing,
13140 );
13141 }
13142 }
13143
13144 Thunk::MaxPool2dBackward {
13145 x,
13146 dy,
13147 dx,
13148 n,
13149 c,
13150 h,
13151 w,
13152 h_out,
13153 w_out,
13154 kh,
13155 kw,
13156 sh,
13157 sw,
13158 ph,
13159 pw,
13160 } => unsafe {
13161 execute_maxpool2d_backward_f32(
13162 *x, *dy, *dx, *n, *c, *h, *w, *h_out, *w_out, *kh, *kw, *sh, *sw, *ph, *pw,
13163 base,
13164 );
13165 },
13166
13167 Thunk::Conv2dBackwardInput {
13168 dy,
13169 w,
13170 dx,
13171 n,
13172 c_in,
13173 h,
13174 w_in,
13175 c_out,
13176 h_out,
13177 w_out,
13178 kh,
13179 kw,
13180 sh,
13181 sw,
13182 ph,
13183 pw,
13184 dh,
13185 dw,
13186 groups,
13187 } => {
13188 let n = *n as usize;
13200 let c_in = *c_in as usize;
13201 let h = *h as usize;
13202 let w_in = *w_in as usize;
13203 let c_out = *c_out as usize;
13204 let h_out = *h_out as usize;
13205 let w_out = *w_out as usize;
13206 let kh = *kh as usize;
13207 let kw = *kw as usize;
13208 let sh = *sh as usize;
13209 let sw = *sw as usize;
13210 let ph = *ph as usize;
13211 let pw = *pw as usize;
13212 let dh = *dh as usize;
13213 let dw = *dw as usize;
13214 let groups = *groups as usize;
13215 let c_in_per_g = c_in / groups;
13216 let c_out_per_g = c_out / groups;
13217
13218 let m_dim = c_in_per_g * kh * kw;
13219 let n_dim = h_out * w_out;
13220 let k_dim = c_out_per_g;
13221
13222 let dy_stride_n = c_out * h_out * w_out;
13223 let dy_stride_g = c_out_per_g * h_out * w_out;
13224 let w_stride_g = c_out_per_g * c_in_per_g * kh * kw;
13225 let dx_stride_n = c_in * h * w_in;
13226 let dx_stride_g = c_in_per_g * h * w_in;
13227
13228 unsafe {
13229 let dys = sl(*dy, base, n * c_out * h_out * w_out);
13230 let ws = sl(*w, base, c_out * c_in_per_g * kh * kw);
13231 let dxs = sl_mut(*dx, base, n * c_in * h * w_in);
13232 for v in dxs.iter_mut() {
13233 *v = 0.0;
13234 }
13235
13236 let mut dcol = vec![0f32; m_dim * n_dim];
13238
13239 for ni in 0..n {
13240 for g in 0..groups {
13241 let w_g_off = g * w_stride_g;
13242 let dy_n_g_off = ni * dy_stride_n + g * dy_stride_g;
13243 let dx_n_g_off = ni * dx_stride_n + g * dx_stride_g;
13244
13245 crate::blas::sgemm_general(
13250 ws.as_ptr().add(w_g_off),
13251 dys.as_ptr().add(dy_n_g_off),
13252 dcol.as_mut_ptr(),
13253 m_dim,
13254 n_dim,
13255 k_dim,
13256 1.0,
13257 0.0,
13258 m_dim,
13259 n_dim,
13260 n_dim,
13261 true,
13262 false,
13263 );
13264
13265 col2im(
13267 &dcol,
13268 &mut dxs[dx_n_g_off..dx_n_g_off + dx_stride_g],
13269 c_in_per_g,
13270 h,
13271 w_in,
13272 h_out,
13273 w_out,
13274 kh,
13275 kw,
13276 sh,
13277 sw,
13278 ph,
13279 pw,
13280 dh,
13281 dw,
13282 );
13283 }
13284 }
13285 }
13286 }
13287
13288 Thunk::Conv2dBackwardWeight {
13289 x,
13290 dy,
13291 dw,
13292 n,
13293 c_in,
13294 h,
13295 w,
13296 c_out,
13297 h_out,
13298 w_out,
13299 kh,
13300 kw,
13301 sh,
13302 sw,
13303 ph,
13304 pw,
13305 dh,
13306 dw_dil,
13307 groups,
13308 } => {
13309 let n = *n as usize;
13310 let c_in = *c_in as usize;
13311 let h = *h as usize;
13312 let w = *w as usize;
13313 let c_out = *c_out as usize;
13324 let h_out = *h_out as usize;
13325 let w_out = *w_out as usize;
13326 let kh = *kh as usize;
13327 let kw = *kw as usize;
13328 let sh = *sh as usize;
13329 let sw = *sw as usize;
13330 let ph = *ph as usize;
13331 let pw = *pw as usize;
13332 let dh = *dh as usize;
13333 let dw_dil = *dw_dil as usize;
13334 let groups = *groups as usize;
13335 let c_in_per_g = c_in / groups;
13336 let c_out_per_g = c_out / groups;
13337
13338 let m_dim = c_out_per_g;
13339 let n_dim = c_in_per_g * kh * kw;
13340 let k_dim = h_out * w_out;
13341
13342 let x_stride_n = c_in * h * w;
13343 let x_stride_g = c_in_per_g * h * w;
13344 let dy_stride_n = c_out * h_out * w_out;
13345 let dy_stride_g = c_out_per_g * h_out * w_out;
13346 let dw_stride_g = c_out_per_g * c_in_per_g * kh * kw;
13347
13348 unsafe {
13349 let xs = sl(*x, base, n * c_in * h * w);
13350 let dys = sl(*dy, base, n * c_out * h_out * w_out);
13351 let dws = sl_mut(*dw, base, c_out * c_in_per_g * kh * kw);
13352 for v in dws.iter_mut() {
13353 *v = 0.0;
13354 }
13355
13356 let mut col = vec![0f32; n_dim * k_dim];
13357
13358 for ni in 0..n {
13359 for g in 0..groups {
13360 let x_n_g_off = ni * x_stride_n + g * x_stride_g;
13361 im2col(
13362 &xs[x_n_g_off..x_n_g_off + x_stride_g],
13363 &mut col,
13364 c_in_per_g,
13365 h,
13366 w,
13367 h_out,
13368 w_out,
13369 kh,
13370 kw,
13371 sh,
13372 sw,
13373 ph,
13374 pw,
13375 dh,
13376 dw_dil,
13377 );
13378
13379 let dy_n_g_off = ni * dy_stride_n + g * dy_stride_g;
13380 let dw_g_off = g * dw_stride_g;
13381
13382 crate::blas::sgemm_general(
13390 dys.as_ptr().add(dy_n_g_off),
13391 col.as_ptr(),
13392 dws.as_mut_ptr().add(dw_g_off),
13393 m_dim,
13394 n_dim,
13395 k_dim,
13396 1.0,
13397 1.0,
13398 k_dim,
13399 k_dim,
13400 n_dim,
13401 false,
13402 true,
13403 );
13404 }
13405 }
13406 }
13407 }
13408
13409 Thunk::Im2Col {
13410 x,
13411 col,
13412 n,
13413 c_in,
13414 h,
13415 w,
13416 h_out,
13417 w_out,
13418 kh,
13419 kw,
13420 sh,
13421 sw,
13422 ph,
13423 pw,
13424 dh,
13425 dw_dil,
13426 } => {
13427 let c_in = *c_in as usize;
13428 let h = *h as usize;
13429 let w = *w as usize;
13430 let h_out = *h_out as usize;
13431 let w_out = *w_out as usize;
13432 let kh = *kh as usize;
13433 let kw = *kw as usize;
13434 let sh = *sh as usize;
13435 let sw = *sw as usize;
13436 let ph = *ph as usize;
13437 let pw = *pw as usize;
13438 let dh = *dh as usize;
13439 let dw_dil = *dw_dil as usize;
13440 let per_batch = c_in * h * w;
13441 unsafe {
13442 let n_eff = if *n == 0 { 0usize } else { *n as usize };
13443 let x_floats = if n_eff == 0 {
13444 per_batch.max(1)
13445 } else {
13446 n_eff * per_batch
13447 };
13448 let xs = sl(*x, base, x_floats);
13449 let n = if *n == 0 {
13450 xs.len() / per_batch.max(1)
13451 } else {
13452 n_eff
13453 };
13454 let m = n * h_out * w_out;
13455 let k = c_in * kh * kw;
13456 let cols = sl_mut(*col, base, m * k);
13457 crate::im2col::im2col_rows_layout(
13458 xs, cols, n, c_in, h, w, h_out, w_out, kh, kw, sh, sw, ph, pw, dh, dw_dil,
13459 );
13460 }
13461 }
13462
13463 Thunk::SoftmaxCrossEntropy {
13464 logits,
13465 labels,
13466 dst,
13467 n,
13468 c,
13469 } => {
13470 let n = *n as usize;
13471 let c = *c as usize;
13472 unsafe {
13473 let lg = sl(*logits, base, n * c);
13474 let lb = sl(*labels, base, n);
13475 let out = sl_mut(*dst, base, n);
13476 for ni in 0..n {
13477 let row = &lg[ni * c..(ni + 1) * c];
13478 let mut m = f32::NEG_INFINITY;
13480 for &v in row {
13481 if v > m {
13482 m = v;
13483 }
13484 }
13485 let mut sum = 0f32;
13486 for &v in row {
13487 sum += (v - m).exp();
13488 }
13489 let lse = m + sum.ln();
13490 let label_idx = lb[ni] as usize;
13491 out[ni] = lse - row[label_idx];
13493 }
13494 }
13495 }
13496
13497 Thunk::SoftmaxCrossEntropyBackward {
13498 logits,
13499 labels,
13500 d_loss,
13501 dlogits,
13502 n,
13503 c,
13504 } => {
13505 let n = *n as usize;
13506 let c = *c as usize;
13507 unsafe {
13508 let lg = sl(*logits, base, n * c);
13509 let lb = sl(*labels, base, n);
13510 let dl = sl(*d_loss, base, n);
13511 let out = sl_mut(*dlogits, base, n * c);
13512 for ni in 0..n {
13513 let row = &lg[ni * c..(ni + 1) * c];
13514 let label_idx = lb[ni] as usize;
13515 let scale = dl[ni];
13516 let mut m = f32::NEG_INFINITY;
13517 for &v in row {
13518 if v > m {
13519 m = v;
13520 }
13521 }
13522 let mut sum = 0f32;
13523 for &v in row {
13524 sum += (v - m).exp();
13525 }
13526 let inv_sum = 1.0 / sum;
13527 let dst_row = &mut out[ni * c..(ni + 1) * c];
13528 for k in 0..c {
13529 let p = (row[k] - m).exp() * inv_sum;
13530 let one_hot = if k == label_idx { 1.0 } else { 0.0 };
13531 dst_row[k] = (p - one_hot) * scale;
13532 }
13533 }
13534 }
13535 }
13536
13537 Thunk::GatherAxis {
13538 table,
13539 idx,
13540 dst,
13541 outer,
13542 axis_dim,
13543 num_idx,
13544 trailing,
13545 idx_i64,
13546 table_bytes,
13547 } => {
13548 let outer = *outer as usize;
13549 let axis_dim = *axis_dim as usize;
13550 let num_idx = *num_idx as usize;
13551 let trailing = *trailing as usize;
13552 unsafe {
13553 if *table_bytes == 8 {
13554 let tab = sl_i64(*table, base, outer * axis_dim * trailing);
13555 let out = sl_mut_i64(*dst, base, outer * num_idx * trailing);
13556 for o in 0..outer {
13557 let tab_outer = o * axis_dim * trailing;
13558 let out_outer = o * num_idx * trailing;
13559 if *idx_i64 != 0 {
13560 let ids = sl_i64(*idx, base, num_idx);
13561 for k in 0..num_idx {
13562 let row = ids[k].max(0) as usize;
13563 if row < axis_dim {
13564 let tab_row = tab_outer + row * trailing;
13565 let out_row = out_outer + k * trailing;
13566 out[out_row..out_row + trailing]
13567 .copy_from_slice(&tab[tab_row..tab_row + trailing]);
13568 }
13569 }
13570 } else {
13571 let ids = sl(*idx, base, num_idx);
13572 for k in 0..num_idx {
13573 let row = ids[k] as usize;
13574 if row < axis_dim {
13575 let tab_row = tab_outer + row * trailing;
13576 let out_row = out_outer + k * trailing;
13577 out[out_row..out_row + trailing]
13578 .copy_from_slice(&tab[tab_row..tab_row + trailing]);
13579 }
13580 }
13581 }
13582 }
13583 } else {
13584 let tab = sl(*table, base, outer * axis_dim * trailing);
13585 let out = sl_mut(*dst, base, outer * num_idx * trailing);
13586 for o in 0..outer {
13587 let tab_outer = o * axis_dim * trailing;
13588 let out_outer = o * num_idx * trailing;
13589 if *idx_i64 != 0 {
13590 let ids = sl_i64(*idx, base, num_idx);
13591 for k in 0..num_idx {
13592 let row = ids[k].max(0) as usize;
13593 if row < axis_dim {
13594 let tab_row = tab_outer + row * trailing;
13595 let out_row = out_outer + k * trailing;
13596 out[out_row..out_row + trailing]
13597 .copy_from_slice(&tab[tab_row..tab_row + trailing]);
13598 }
13599 }
13600 } else {
13601 let ids = sl(*idx, base, num_idx);
13602 for k in 0..num_idx {
13603 let row = ids[k] as usize;
13604 if row < axis_dim {
13605 let tab_row = tab_outer + row * trailing;
13606 let out_row = out_outer + k * trailing;
13607 out[out_row..out_row + trailing]
13608 .copy_from_slice(&tab[tab_row..tab_row + trailing]);
13609 }
13610 }
13611 }
13612 }
13613 }
13614 }
13615 }
13616
13617 Thunk::Transpose {
13618 src,
13619 dst,
13620 in_total,
13621 out_dims,
13622 in_strides,
13623 elem_bytes,
13624 } => {
13625 let rank = out_dims.len();
13630 let total: usize = out_dims.iter().map(|&d| d as usize).product();
13631 let in_total = *in_total as usize;
13632 unsafe {
13633 if *elem_bytes == 8 {
13634 let inp = sl_i64(*src, base, in_total);
13635 let out = sl_mut_i64(*dst, base, total);
13636 let mut idx = vec![0usize; rank];
13637 for o in 0..total {
13638 let mut src_idx = 0usize;
13639 for d in 0..rank {
13640 src_idx += idx[d] * in_strides[d] as usize;
13641 }
13642 out[o] = inp[src_idx];
13643 for d in (0..rank).rev() {
13644 idx[d] += 1;
13645 if idx[d] < out_dims[d] as usize {
13646 break;
13647 }
13648 idx[d] = 0;
13649 }
13650 }
13651 } else {
13652 let inp = sl(*src, base, in_total);
13653 let out = sl_mut(*dst, base, total);
13654 let mut idx = vec![0usize; rank];
13655 for o in 0..total {
13656 let mut src_idx = 0usize;
13657 for d in 0..rank {
13658 src_idx += idx[d] * in_strides[d] as usize;
13659 }
13660 out[o] = inp[src_idx];
13661 for d in (0..rank).rev() {
13662 idx[d] += 1;
13663 if idx[d] < out_dims[d] as usize {
13664 break;
13665 }
13666 idx[d] = 0;
13667 }
13668 }
13669 }
13670 }
13671 }
13672
13673 Thunk::CustomOp {
13679 kernel,
13680 inputs,
13681 output,
13682 attrs,
13683 } => {
13684 let (out_off, out_len, out_shape) = output;
13685 unsafe {
13686 dispatch_custom_op(
13687 &**kernel, inputs, *out_off, *out_len, out_shape, attrs, base,
13688 );
13689 }
13690 }
13691 }
13692 if trace_done {
13693 eprintln!("[thunk {i} done]");
13694 }
13695 }
13696}
13697
13698#[allow(clippy::too_many_arguments)]
13713unsafe fn griewank_process_segment(
13714 t_lo: usize,
13715 t_hi: usize,
13716 anchor_carry: &[u8],
13717 cb: usize,
13718 fwd_sched: &ThunkSchedule,
13719 fwd_init: &[u8],
13720 fwd_carry_in_off: usize,
13721 fwd_output_off: usize,
13722 fwd_x_offs: &[usize],
13723 base: *mut u8,
13724 outer_xs_offs: &[(usize, u32)],
13725 fwd_buf: &mut Vec<u8>,
13726 leaf_threshold: usize,
13727 process_iter: &mut dyn FnMut(usize, &[u8]),
13728) {
13729 unsafe {
13730 let size = t_hi - t_lo + 1;
13731 if size == 1 {
13732 process_iter(t_lo, anchor_carry);
13733 return;
13734 }
13735 if size <= leaf_threshold {
13736 let mut cache: Vec<u8> = Vec::with_capacity(size * cb);
13738 cache.extend_from_slice(anchor_carry);
13739 fwd_buf.copy_from_slice(fwd_init);
13740 std::ptr::copy_nonoverlapping(
13741 anchor_carry.as_ptr(),
13742 fwd_buf.as_mut_ptr().add(fwd_carry_in_off),
13743 cb,
13744 );
13745 for i in 1..size {
13746 let cur_iter = t_lo + i - 1;
13747 for (idx, fb_x_off) in fwd_x_offs.iter().enumerate() {
13748 let (outer_xs_off, x_psb) = outer_xs_offs[idx];
13749 let xb = x_psb as usize;
13750 std::ptr::copy_nonoverlapping(
13751 base.add(outer_xs_off + cur_iter * xb),
13752 fwd_buf.as_mut_ptr().add(*fb_x_off),
13753 xb,
13754 );
13755 }
13756 execute_thunks(fwd_sched, fwd_buf);
13757 if fwd_output_off != fwd_carry_in_off {
13758 fwd_buf.copy_within(fwd_output_off..fwd_output_off + cb, fwd_carry_in_off);
13759 }
13760 cache.extend_from_slice(&fwd_buf[fwd_carry_in_off..fwd_carry_in_off + cb]);
13761 }
13762 for t in (t_lo..=t_hi).rev() {
13764 let idx = t - t_lo;
13765 let carry = &cache[idx * cb..(idx + 1) * cb];
13766 process_iter(t, carry);
13767 }
13768 return;
13769 }
13770
13771 let mid = t_lo + size / 2;
13775 fwd_buf.copy_from_slice(fwd_init);
13776 std::ptr::copy_nonoverlapping(
13777 anchor_carry.as_ptr(),
13778 fwd_buf.as_mut_ptr().add(fwd_carry_in_off),
13779 cb,
13780 );
13781 for cur_iter in t_lo..mid {
13782 for (idx, fb_x_off) in fwd_x_offs.iter().enumerate() {
13783 let (outer_xs_off, x_psb) = outer_xs_offs[idx];
13784 let xb = x_psb as usize;
13785 std::ptr::copy_nonoverlapping(
13786 base.add(outer_xs_off + cur_iter * xb),
13787 fwd_buf.as_mut_ptr().add(*fb_x_off),
13788 xb,
13789 );
13790 }
13791 execute_thunks(fwd_sched, fwd_buf);
13792 if fwd_output_off != fwd_carry_in_off {
13793 fwd_buf.copy_within(fwd_output_off..fwd_output_off + cb, fwd_carry_in_off);
13794 }
13795 }
13796 let mid_carry: Vec<u8> = fwd_buf[fwd_carry_in_off..fwd_carry_in_off + cb].to_vec();
13797
13798 griewank_process_segment(
13802 mid,
13803 t_hi,
13804 &mid_carry,
13805 cb,
13806 fwd_sched,
13807 fwd_init,
13808 fwd_carry_in_off,
13809 fwd_output_off,
13810 fwd_x_offs,
13811 base,
13812 outer_xs_offs,
13813 fwd_buf,
13814 leaf_threshold,
13815 process_iter,
13816 );
13817 griewank_process_segment(
13819 t_lo,
13820 mid - 1,
13821 anchor_carry,
13822 cb,
13823 fwd_sched,
13824 fwd_init,
13825 fwd_carry_in_off,
13826 fwd_output_off,
13827 fwd_x_offs,
13828 base,
13829 outer_xs_offs,
13830 fwd_buf,
13831 leaf_threshold,
13832 process_iter,
13833 );
13834 }
13835}
13836
13837pub unsafe fn execute_fft1d_f64(
13854 src: usize,
13855 dst: usize,
13856 outer: usize,
13857 n_complex: usize,
13858 inverse: bool,
13859 norm_tag: u32,
13860 base: *mut u8,
13861) {
13862 let row_elems = 2 * n_complex;
13863 let mut re = vec![0f64; n_complex];
13864 let mut im = vec![0f64; n_complex];
13865 let norm = rlx_ir::fft::FftNorm::from_tag(norm_tag);
13866 let scale = norm.output_scale(n_complex, inverse);
13867 let mut scratch = if n_complex.is_power_of_two() || n_complex <= 16 {
13870 BluesteinScratchF64::empty()
13871 } else {
13872 BluesteinScratchF64::build(n_complex, inverse)
13873 };
13874 for o in 0..outer {
13875 let row_offset = src + o * row_elems * std::mem::size_of::<f64>();
13876 let s = unsafe { sl_f64(row_offset, base, row_elems) };
13877 re.copy_from_slice(&s[..n_complex]);
13878 im.copy_from_slice(&s[n_complex..]);
13879 if n_complex.is_power_of_two() {
13880 fft_radix2_inplace_f64(&mut re, &mut im, inverse);
13881 } else if n_complex <= 16 {
13882 fft_naive_inplace_f64(&mut re, &mut im, inverse);
13883 } else {
13884 fft_bluestein_inplace_f64(&mut re, &mut im, inverse, &mut scratch);
13885 }
13886 if scale != 1.0 {
13887 re.iter_mut().for_each(|v| *v *= scale);
13888 im.iter_mut().for_each(|v| *v *= scale);
13889 }
13890 let dst_offset = dst + o * row_elems * std::mem::size_of::<f64>();
13891 let d = unsafe { sl_mut_f64(dst_offset, base, row_elems) };
13892 d[..n_complex].copy_from_slice(&re);
13893 d[n_complex..].copy_from_slice(&im);
13894 }
13895}
13896
13897pub unsafe fn execute_gated_delta_net_f32(
13906 q: usize,
13907 k: usize,
13908 v: usize,
13909 g: usize,
13910 beta: usize,
13911 state: usize,
13912 dst: usize,
13913 batch: usize,
13914 seq: usize,
13915 heads: usize,
13916 state_size: usize,
13917 base: *mut u8,
13918) {
13919 use rayon::prelude::*;
13920
13921 #[derive(Copy, Clone)]
13922 struct ArenaPtr(usize);
13923 unsafe impl Send for ArenaPtr {}
13924 unsafe impl Sync for ArenaPtr {}
13925 impl ArenaPtr {
13926 #[inline]
13927 fn get(self) -> *mut u8 {
13928 self.0 as *mut u8
13929 }
13930 }
13931
13932 unsafe {
13933 let arena = ArenaPtr(base as usize);
13934 let (b, s, h, n) = (batch, seq, heads, state_size);
13935 let scale = 1.0f32 / (n as f32).sqrt();
13936 let use_external = state != 0;
13937 let mut owned_state = vec![0f32; h * n * n];
13938
13939 crate::pool::num_threads();
13940
13941 assert!(
13942 n <= crate::gdn::GDN_MAX_STATE,
13943 "GatedDeltaNet state_size={n} exceeds stack scratch ({})",
13944 crate::gdn::GDN_MAX_STATE
13945 );
13946
13947 let qs = sl(q, arena.get(), b * s * h * n);
13948 let ks = sl(k, arena.get(), b * s * h * n);
13949 let vs = sl(v, arena.get(), b * s * h * n);
13950 let gs = sl(g, arena.get(), b * s * h);
13951 let betas = sl(beta, arena.get(), b * s * h);
13952 let _out = sl_mut(dst, arena.get(), b * s * h * n);
13953 let hs_n = h * n;
13954
13955 let run_head = |bi: usize, hi: usize, s_mat: &mut [f32], sk: &mut [f32]| {
13956 for ti in 0..s {
13957 let qkv_step = bi * s * hs_n + ti * hs_n + hi * n;
13958 let gb_step = bi * s * h + ti * h + hi;
13959 let out_row = sl_mut(dst + qkv_step * std::mem::size_of::<f32>(), arena.get(), n);
13960 crate::gdn::gdn_step_blas(
13961 s_mat,
13962 &qs[qkv_step..qkv_step + n],
13963 &ks[qkv_step..qkv_step + n],
13964 &vs[qkv_step..qkv_step + n],
13965 gs[gb_step],
13966 betas[gb_step],
13967 out_row,
13968 sk,
13969 n,
13970 scale,
13971 );
13972 }
13973 };
13974
13975 if !use_external && s > 1 {
13978 for bi in 0..b {
13979 (0..h).into_par_iter().for_each(|hi| {
13980 let mut sk_buf = [0f32; crate::gdn::GDN_MAX_STATE];
13981 let sk = &mut sk_buf[..n];
13982 let mut local_state =
13983 [0f32; crate::gdn::GDN_MAX_STATE * crate::gdn::GDN_MAX_STATE];
13984 let s_mat = &mut local_state[..n * n];
13985 s_mat.fill(0.0);
13986 run_head(bi, hi, s_mat, sk);
13987 });
13988 }
13989 return;
13990 }
13991
13992 if use_external {
13993 let state_bytes = state;
13994 (0..b * h).into_par_iter().for_each(|bhi| {
13995 let bi = bhi / h;
13996 let hi = bhi % h;
13997 let elem_off = bi * h * n * n + hi * n * n;
13998 let s_mat = sl_mut(
13999 state_bytes + elem_off * std::mem::size_of::<f32>(),
14000 arena.get(),
14001 n * n,
14002 );
14003 let mut sk_buf = [0f32; crate::gdn::GDN_MAX_STATE];
14004 run_head(bi, hi, s_mat, &mut sk_buf[..n]);
14005 });
14006 } else {
14007 for bi in 0..b {
14008 owned_state.fill(0.0);
14009 owned_state
14010 .par_chunks_mut(n * n)
14011 .enumerate()
14012 .for_each(|(hi, s_mat)| {
14013 let mut sk_buf = [0f32; crate::gdn::GDN_MAX_STATE];
14014 run_head(bi, hi, s_mat, &mut sk_buf[..n]);
14015 });
14016 }
14017 }
14018 }
14019}
14020
14021pub unsafe fn execute_rms_norm_backward_input_f32(
14023 x: usize,
14024 gamma: usize,
14025 beta: usize,
14026 dy: usize,
14027 dx: usize,
14028 rows: u32,
14029 h: u32,
14030 eps: f32,
14031 base: *mut u8,
14032) {
14033 let (rows, h) = (rows as usize, h as usize);
14034 let mut dg = vec![0f32; h];
14035 let mut db = vec![0f32; h];
14036 let xs = sl(x, base, rows * h);
14037 let dys = sl(dy, base, rows * h);
14038 let g = sl(gamma, base, h);
14039 let b = sl(beta, base, h);
14040 let out = sl_mut(dx, base, rows * h);
14041 for r in 0..rows {
14042 crate::training_bwd::rms_norm_backward_row(
14043 &xs[r * h..(r + 1) * h],
14044 g,
14045 b,
14046 &dys[r * h..(r + 1) * h],
14047 &mut out[r * h..(r + 1) * h],
14048 &mut dg,
14049 &mut db,
14050 eps,
14051 );
14052 }
14053}
14054
14055pub unsafe fn execute_rms_norm_backward_gamma_f32(
14056 x: usize,
14057 gamma: usize,
14058 beta: usize,
14059 dy: usize,
14060 dgamma: usize,
14061 rows: u32,
14062 h: u32,
14063 eps: f32,
14064 base: *mut u8,
14065) {
14066 let (rows, h) = (rows as usize, h as usize);
14067 let out = sl_mut(dgamma, base, h);
14068 out.fill(0.0);
14069 let mut dx = vec![0f32; h];
14070 let mut db = vec![0f32; h];
14071 let xs = sl(x, base, rows * h);
14072 let dys = sl(dy, base, rows * h);
14073 let g = sl(gamma, base, h);
14074 let b = sl(beta, base, h);
14075 for r in 0..rows {
14076 crate::training_bwd::rms_norm_backward_row(
14077 &xs[r * h..(r + 1) * h],
14078 g,
14079 b,
14080 &dys[r * h..(r + 1) * h],
14081 &mut dx,
14082 out,
14083 &mut db,
14084 eps,
14085 );
14086 }
14087}
14088
14089pub unsafe fn execute_rms_norm_backward_beta_f32(
14090 x: usize,
14091 gamma: usize,
14092 beta: usize,
14093 dy: usize,
14094 dbeta: usize,
14095 rows: u32,
14096 h: u32,
14097 eps: f32,
14098 base: *mut u8,
14099) {
14100 let (rows, h) = (rows as usize, h as usize);
14101 let out = sl_mut(dbeta, base, h);
14102 out.fill(0.0);
14103 let mut dx = vec![0f32; h];
14104 let mut dg = vec![0f32; h];
14105 let xs = sl(x, base, rows * h);
14106 let dys = sl(dy, base, rows * h);
14107 let g = sl(gamma, base, h);
14108 let b = sl(beta, base, h);
14109 for r in 0..rows {
14110 crate::training_bwd::rms_norm_backward_row(
14111 &xs[r * h..(r + 1) * h],
14112 g,
14113 b,
14114 &dys[r * h..(r + 1) * h],
14115 &mut dx,
14116 &mut dg,
14117 out,
14118 eps,
14119 );
14120 }
14121}
14122
14123#[allow(clippy::too_many_arguments)]
14124pub unsafe fn execute_conv2d_forward_f32(
14125 src: usize,
14126 weight: usize,
14127 dst: usize,
14128 n: u32,
14129 c_in: u32,
14130 h: u32,
14131 w: u32,
14132 c_out: u32,
14133 h_out: u32,
14134 w_out: u32,
14135 kh: u32,
14136 kw: u32,
14137 sh: u32,
14138 sw: u32,
14139 ph: u32,
14140 pw: u32,
14141 dh: u32,
14142 dw: u32,
14143 groups: u32,
14144 base: *mut u8,
14145) {
14146 let n = n as usize;
14147 let c_in = c_in as usize;
14148 let h = h as usize;
14149 let w = w as usize;
14150 let c_out = c_out as usize;
14151 let h_out = h_out as usize;
14152 let w_out = w_out as usize;
14153 let kh = kh as usize;
14154 let kw = kw as usize;
14155 let sh = sh as usize;
14156 let sw = sw as usize;
14157 let ph = ph as usize;
14158 let pw = pw as usize;
14159 let dh = dh as usize;
14160 let dw = dw as usize;
14161 let groups = groups as usize;
14162 let c_in_per_g = c_in / groups;
14163 let inp = sl(src, base, n * c_in * h * w);
14164 let wt = sl(weight, base, c_out * c_in_per_g * kh * kw);
14165 let out = sl_mut(dst, base, n * c_out * h_out * w_out);
14166 crate::conv_fwd::conv2d_forward_nchw_f32(
14167 inp, wt, out, n, c_in, h, w, c_out, h_out, w_out, kh, kw, sh, sw, ph, pw, dh, dw, groups,
14168 );
14169}
14170
14171pub unsafe fn execute_maxpool2d_backward_f32(
14172 x: usize,
14173 dy: usize,
14174 dx: usize,
14175 n: u32,
14176 c: u32,
14177 h: u32,
14178 w: u32,
14179 h_out: u32,
14180 w_out: u32,
14181 kh: u32,
14182 kw: u32,
14183 sh: u32,
14184 sw: u32,
14185 ph: u32,
14186 pw: u32,
14187 base: *mut u8,
14188) {
14189 let (n, c, h, w) = (n as usize, c as usize, h as usize, w as usize);
14190 let (h_out, w_out) = (h_out as usize, w_out as usize);
14191 let (kh, kw) = (kh as usize, kw as usize);
14192 let (sh, sw) = (sh as usize, sw as usize);
14193 let (ph, pw) = (ph as usize, pw as usize);
14194 let xs = sl(x, base, n * c * h * w);
14195 let dys = sl(dy, base, n * c * h_out * w_out);
14196 let dxs = sl_mut(dx, base, n * c * h * w);
14197 crate::training_bwd::maxpool2d_backward_nchw(
14198 xs, dys, dxs, n, c, h, w, h_out, w_out, kh, kw, sh, sw, ph, pw,
14199 );
14200}
14201
14202pub unsafe fn execute_rope_backward_f32(
14203 dy: usize,
14204 cos: usize,
14205 sin: usize,
14206 dx: usize,
14207 batch: u32,
14208 seq: u32,
14209 hidden: u32,
14210 head_dim: u32,
14211 n_rot: u32,
14212 cos_len: u32,
14213 base: *mut u8,
14214) {
14215 let (b, s, hs, dh, nr, cl) = (
14216 batch as usize,
14217 seq as usize,
14218 hidden as usize,
14219 head_dim as usize,
14220 n_rot as usize,
14221 cos_len as usize,
14222 );
14223 let nh = hs / dh;
14224 let tab_half = dh / 2;
14225 let dys = sl(dy, base, b * s * hs);
14226 let cos_tab = sl(cos, base, cl);
14227 let sin_tab = sl(sin, base, cl);
14228 let out = sl_mut(dx, base, b * s * hs);
14229 for bi in 0..b {
14230 for si in 0..s {
14231 let tab_off = si.saturating_mul(tab_half) % cl.max(1);
14232 let cp = &cos_tab[tab_off..tab_off + tab_half.min(cl)];
14233 let sp = &sin_tab[tab_off..tab_off + tab_half.min(cl)];
14234 for hi in 0..nh {
14235 let base_idx = bi * s * hs + si * hs + hi * dh;
14236 crate::training_bwd::rope_backward_row(
14237 &dys[base_idx..base_idx + dh],
14238 cp,
14239 sp,
14240 &mut out[base_idx..base_idx + dh],
14241 dh,
14242 nr,
14243 );
14244 }
14245 }
14246 }
14247}
14248
14249pub unsafe fn execute_cumsum_backward_f32(
14250 dy: usize,
14251 dx: usize,
14252 rows: u32,
14253 cols: u32,
14254 exclusive: bool,
14255 base: *mut u8,
14256) {
14257 let (rows, cols) = (rows as usize, cols as usize);
14258 let dys = sl(dy, base, rows * cols);
14259 let out = sl_mut(dx, base, rows * cols);
14260 for r in 0..rows {
14261 crate::training_bwd::cumsum_backward_row(
14262 &dys[r * cols..(r + 1) * cols],
14263 &mut out[r * cols..(r + 1) * cols],
14264 exclusive,
14265 );
14266 }
14267}
14268
14269pub unsafe fn execute_gather_backward_f32(
14270 dy: usize,
14271 indices: usize,
14272 dst: usize,
14273 outer: u32,
14274 axis_dim: u32,
14275 num_idx: u32,
14276 trailing: u32,
14277 base: *mut u8,
14278) {
14279 let (outer, axis_dim, num_idx, trailing) = (
14280 outer as usize,
14281 axis_dim as usize,
14282 num_idx as usize,
14283 trailing as usize,
14284 );
14285 let out = sl_mut(dst, base, outer * axis_dim * trailing);
14286 out.fill(0.0);
14287 crate::training_bwd::gather_axis_backward(
14288 sl(dy, base, outer * num_idx * trailing),
14289 sl(indices, base, num_idx),
14290 out,
14291 outer,
14292 axis_dim,
14293 num_idx,
14294 trailing,
14295 );
14296}
14297
14298pub unsafe fn execute_dequant_matmul_gguf_f32(
14300 x: usize,
14301 w_q: usize,
14302 dst: usize,
14303 m: usize,
14304 k: usize,
14305 n: usize,
14306 scheme: rlx_ir::quant::QuantScheme,
14307 base: *mut u8,
14308) {
14309 unsafe {
14310 let block_bytes = scheme.gguf_block_bytes() as usize;
14311 let block_elems = scheme.gguf_block_size() as usize;
14312 let total_bytes = (k * n) / block_elems * block_bytes;
14313 let xs = sl(x, base, m * k);
14314 let w_bytes = std::slice::from_raw_parts(base.add(w_q) as *const u8, total_bytes);
14315 let out = sl_mut(dst, base, m * n);
14316 crate::gguf_matmul::gguf_matmul_bt(xs, w_bytes, out, m, k, n, scheme);
14317 }
14318}
14319
14320pub unsafe fn execute_dequant_grouped_matmul_gguf_f32(
14322 input: usize,
14323 w_q: usize,
14324 expert_idx: usize,
14325 dst: usize,
14326 m: usize,
14327 k: usize,
14328 n: usize,
14329 num_experts: usize,
14330 scheme: rlx_ir::quant::QuantScheme,
14331 base: *mut u8,
14332) {
14333 unsafe {
14334 let block_bytes = scheme.gguf_block_bytes() as usize;
14335 let block_elems = scheme.gguf_block_size() as usize;
14336 let slab_bytes = (k * n) / block_elems * block_bytes;
14337 let xs = sl(input, base, m * k);
14338 let w_bytes =
14339 std::slice::from_raw_parts(base.add(w_q) as *const u8, num_experts * slab_bytes);
14340 let ids = sl(expert_idx, base, m);
14341 let out = sl_mut(dst, base, m * n);
14342 crate::gguf_matmul::gguf_grouped_matmul_bt(
14343 xs,
14344 w_bytes,
14345 ids,
14346 out,
14347 m,
14348 k,
14349 n,
14350 num_experts,
14351 scheme,
14352 );
14353 }
14354}
14355
14356pub unsafe fn execute_dequant_matmul_int4_f32(
14358 x: usize,
14359 w_q: usize,
14360 scale: usize,
14361 zp: usize,
14362 dst: usize,
14363 m: usize,
14364 k: usize,
14365 n: usize,
14366 block_size: u32,
14367 is_asymmetric: bool,
14368 base: *mut u8,
14369) {
14370 let bs = block_size as usize;
14371 let n_blocks = k.div_ceil(bs);
14372 unsafe {
14373 let xs = sl(x, base, m * k);
14374 let w_bytes = std::slice::from_raw_parts(base.add(w_q) as *const u8, (k * n).div_ceil(2));
14375 let scales = sl(scale, base, n_blocks * n);
14376 let zps = if is_asymmetric {
14377 sl(zp, base, n_blocks * n)
14378 } else {
14379 &[][..]
14380 };
14381 let out = sl_mut(dst, base, m * n);
14382 dequant_matmul_int4(xs, w_bytes, scales, zps, out, m, k, n, bs, is_asymmetric);
14383 }
14384}
14385
14386pub unsafe fn execute_dequant_matmul_fp8_f32(
14388 x: usize,
14389 w_q: usize,
14390 scale: usize,
14391 dst: usize,
14392 m: usize,
14393 k: usize,
14394 n: usize,
14395 e5m2: bool,
14396 base: *mut u8,
14397) {
14398 unsafe {
14399 let xs = sl(x, base, m * k);
14400 let w_bytes = std::slice::from_raw_parts(base.add(w_q) as *const u8, k * n);
14401 let scales = sl(scale, base, n);
14402 let out = sl_mut(dst, base, m * n);
14403 dequant_matmul_fp8(xs, w_bytes, scales, out, m, k, n, e5m2);
14404 }
14405}
14406
14407pub unsafe fn execute_dequant_matmul_nvfp4_f32(
14409 x: usize,
14410 w_q: usize,
14411 scale: usize,
14412 global_scale: usize,
14413 dst: usize,
14414 m: usize,
14415 k: usize,
14416 n: usize,
14417 base: *mut u8,
14418) {
14419 let n_scale = k.div_ceil(rlx_ir::NVFP4_GROUP_SIZE) * n;
14420 unsafe {
14421 let xs = sl(x, base, m * k);
14422 let w_bytes = std::slice::from_raw_parts(base.add(w_q) as *const u8, (k * n).div_ceil(2));
14423 let scale_bytes = std::slice::from_raw_parts(base.add(scale) as *const u8, n_scale);
14424 let gs = sl(global_scale, base, 1)[0];
14425 let out = sl_mut(dst, base, m * n);
14426 dequant_matmul_nvfp4(xs, w_bytes, scale_bytes, gs, out, m, k, n);
14427 }
14428}
14429
14430pub unsafe fn execute_gated_delta_net_f16(
14432 q: usize,
14433 k: usize,
14434 v: usize,
14435 g: usize,
14436 beta: usize,
14437 state: usize,
14438 dst: usize,
14439 batch: usize,
14440 seq: usize,
14441 heads: usize,
14442 state_size: usize,
14443 base: *mut u8,
14444) {
14445 use half::f16;
14446 unsafe {
14447 let read_f16 = |off: usize, len: usize| -> Vec<f32> {
14448 let raw = std::slice::from_raw_parts(base.add(off) as *const u8, len * 2);
14449 raw.chunks_exact(2)
14450 .map(|c| f16::from_le_bytes([c[0], c[1]]).to_f32())
14451 .collect()
14452 };
14453 let write_f16 = |off: usize, data: &[f32]| {
14454 let out = std::slice::from_raw_parts_mut(base.add(off), data.len() * 2);
14455 for (i, &v) in data.iter().enumerate() {
14456 let le = f16::from_f32(v).to_le_bytes();
14457 out[i * 2] = le[0];
14458 out[i * 2 + 1] = le[1];
14459 }
14460 };
14461
14462 let (b, s, h, n) = (batch, seq, heads, state_size);
14463 let q_f = read_f16(q, b * s * h * n);
14464 let k_f = read_f16(k, b * s * h * n);
14465 let v_f = read_f16(v, b * s * h * n);
14466 let g_f = read_f16(g, b * s * h);
14467 let b_f = read_f16(beta, b * s * h);
14468 let mut state_f = if state != 0 {
14469 read_f16(state, b * h * n * n)
14470 } else {
14471 vec![0f32; b * h * n * n]
14472 };
14473 let mut out_f = vec![0f32; b * s * h * n];
14474 let scale = 1.0f32 / (n as f32).sqrt();
14475 let mut sk_buf = vec![0f32; n];
14476 let mut owned_state = vec![0f32; h * n * n];
14477
14478 for bi in 0..b {
14479 let state_slice: &mut [f32] = if state != 0 {
14480 let start = bi * h * n * n;
14481 &mut state_f[start..start + h * n * n]
14482 } else {
14483 owned_state.fill(0.0);
14484 &mut owned_state
14485 };
14486
14487 for ti in 0..s {
14488 let qkv_step_base = bi * s * h * n + ti * h * n;
14489 let gb_step_base = bi * s * h + ti * h;
14490
14491 for hi in 0..h {
14492 let q_row = &q_f[qkv_step_base + hi * n..qkv_step_base + (hi + 1) * n];
14493 let k_row = &k_f[qkv_step_base + hi * n..qkv_step_base + (hi + 1) * n];
14494 let v_row = &v_f[qkv_step_base + hi * n..qkv_step_base + (hi + 1) * n];
14495 let g_t = g_f[gb_step_base + hi];
14496 let beta_t = b_f[gb_step_base + hi];
14497
14498 let s_base = hi * n * n;
14499 let s_mat = &mut state_slice[s_base..s_base + n * n];
14500
14501 let g_exp = g_t.exp();
14502 for st in s_mat.iter_mut() {
14503 *st *= g_exp;
14504 }
14505
14506 for j in 0..n {
14507 let mut acc = 0f32;
14508 for i in 0..n {
14509 acc += s_mat[i * n + j] * k_row[i];
14510 }
14511 sk_buf[j] = acc;
14512 }
14513
14514 for j in 0..n {
14515 sk_buf[j] = (v_row[j] - sk_buf[j]) * beta_t;
14516 }
14517
14518 for i in 0..n {
14519 let ki = k_row[i];
14520 for j in 0..n {
14521 s_mat[i * n + j] += ki * sk_buf[j];
14522 }
14523 }
14524
14525 let out_row = &mut out_f[qkv_step_base + hi * n..qkv_step_base + (hi + 1) * n];
14526 for j in 0..n {
14527 let mut acc = 0f32;
14528 for i in 0..n {
14529 acc += s_mat[i * n + j] * q_row[i];
14530 }
14531 out_row[j] = acc * scale;
14532 }
14533 }
14534 }
14535 }
14536
14537 write_f16(dst, &out_f);
14538 if state != 0 {
14539 write_f16(state, &state_f);
14540 }
14541 }
14542}
14543
14544pub unsafe fn execute_group_norm_nchw_f32(
14546 src: usize,
14547 g: usize,
14548 b: usize,
14549 dst: usize,
14550 n: usize,
14551 c: usize,
14552 h: usize,
14553 w: usize,
14554 num_groups: usize,
14555 eps: f32,
14556 base: *mut u8,
14557) {
14558 let plane = c * h * w;
14559 for ni in 0..n {
14560 let input = unsafe { sl(src + ni * plane * std::mem::size_of::<f32>(), base, plane) };
14561 let gamma = unsafe { sl(g, base, c) };
14562 let beta = unsafe { sl(b, base, c) };
14563 let output = unsafe { sl_mut(dst + ni * plane * std::mem::size_of::<f32>(), base, plane) };
14564 crate::kernels::group_norm_nchw(input, gamma, beta, output, 1, c, h, w, num_groups, eps);
14565 }
14566}
14567
14568pub unsafe fn execute_layer_norm2d_nchw_f32(
14570 src: usize,
14571 g: usize,
14572 b: usize,
14573 dst: usize,
14574 n: usize,
14575 c: usize,
14576 h: usize,
14577 w: usize,
14578 eps: f32,
14579 base: *mut u8,
14580) {
14581 let plane = c * h * w;
14582 unsafe {
14583 let input = sl(src, base, n * plane);
14584 let gamma = sl(g, base, c);
14585 let beta = sl(b, base, c);
14586 let output = sl_mut(dst, base, n * plane);
14587 crate::kernels::layer_norm2d_nchw(input, gamma, beta, output, n, c, h, w, eps);
14588 }
14589}
14590
14591pub unsafe fn execute_conv_transpose2d_nchw_f32(
14593 src: usize,
14594 weight: usize,
14595 dst: usize,
14596 n: usize,
14597 c_in: usize,
14598 h: usize,
14599 w_in: usize,
14600 c_out: usize,
14601 h_out: usize,
14602 w_out: usize,
14603 kh: usize,
14604 kw: usize,
14605 sh: usize,
14606 sw: usize,
14607 ph: usize,
14608 pw: usize,
14609 dh: usize,
14610 dw: usize,
14611 groups: usize,
14612 base: *mut u8,
14613) {
14614 let in_elems = n * c_in * h * w_in;
14615 let w_elems = c_in * (c_out / groups) * kh * kw;
14616 let out_elems = n * c_out * h_out * w_out;
14617 unsafe {
14618 let input = sl(src, base, in_elems);
14619 let wt = sl(weight, base, w_elems);
14620 let output = sl_mut(dst, base, out_elems);
14621 crate::kernels::conv_transpose2d_nchw(
14622 input, wt, output, n, c_in, h, w_in, c_out, h_out, w_out, kh, kw, sh, sw, ph, pw, dh,
14623 dw, groups,
14624 );
14625 }
14626}
14627
14628pub unsafe fn execute_resize_nearest_2x_f32(
14630 src: usize,
14631 dst: usize,
14632 n: usize,
14633 c: usize,
14634 h: usize,
14635 w: usize,
14636 base: *mut u8,
14637) {
14638 let in_plane = c * h * w;
14639 let out_plane = c * h * 2 * w * 2;
14640 for ni in 0..n {
14641 let input = unsafe {
14642 sl(
14643 src + ni * in_plane * std::mem::size_of::<f32>(),
14644 base,
14645 in_plane,
14646 )
14647 };
14648 let output = unsafe {
14649 sl_mut(
14650 dst + ni * out_plane * std::mem::size_of::<f32>(),
14651 base,
14652 out_plane,
14653 )
14654 };
14655 crate::kernels::resize_nearest_2x_nchw(input, output, c, h, w);
14656 }
14657}
14658
14659pub unsafe fn execute_axial_rope2d_f32(
14661 src: usize,
14662 dst: usize,
14663 batch: usize,
14664 seq: usize,
14665 hidden: usize,
14666 end_x: usize,
14667 end_y: usize,
14668 head_dim: usize,
14669 num_heads: usize,
14670 theta: f32,
14671 repeat_factor: usize,
14672 base: *mut u8,
14673) {
14674 let plane = seq * hidden;
14675 let plane_bytes = plane * std::mem::size_of::<f32>();
14676 for bi in 0..batch {
14677 let in_off = src + bi * plane_bytes;
14678 let input = unsafe { sl(in_off, base, plane) };
14679 let rotated = rlx_ir::ops::axial_rope2d::apply_axial_rope2d(
14680 input,
14681 num_heads,
14682 seq,
14683 head_dim,
14684 end_x,
14685 end_y,
14686 theta,
14687 repeat_factor,
14688 );
14689 let out_off = dst + bi * plane_bytes;
14690 let output = unsafe { sl_mut(out_off, base, plane) };
14691 output.copy_from_slice(&rotated);
14692 }
14693}
14694
14695pub unsafe fn execute_fft_butterfly_stage_f32(
14697 state_src: usize,
14698 state_dst: usize,
14699 gate_src: usize,
14700 rev_src: usize,
14701 tw_re_src: usize,
14702 tw_im_src: usize,
14703 batch: usize,
14704 n_fft: usize,
14705 stage: usize,
14706 base: *mut u8,
14707) {
14708 let half = n_fft / 2;
14709 let stride = 1usize << stage;
14710 let gate = unsafe { sl(gate_src, base, half) };
14711 let rev = unsafe { sl(rev_src, base, half) };
14712 let tw_re = unsafe { sl(tw_re_src, base, half) };
14713 let tw_im = unsafe { sl(tw_im_src, base, half) };
14714 let row_elems = n_fft * 2;
14715 for b in 0..batch {
14716 let in_off = state_src + b * row_elems * std::mem::size_of::<f32>();
14717 let out_off = state_dst + b * row_elems * std::mem::size_of::<f32>();
14718 let inp = unsafe { sl(in_off, base, row_elems) };
14719 let out = unsafe { sl_mut(out_off, base, row_elems) };
14720 out.copy_from_slice(inp);
14721 for bf in 0..half {
14722 if gate[bf] == 0.0 {
14723 continue;
14724 }
14725 let group = bf / stride;
14726 let k = bf % stride;
14727 let i0 = group * 2 * stride + k;
14728 let i1 = i0 + stride;
14729 let w_re = tw_re[bf];
14730 let w_im = tw_im[bf];
14731 let in_a_re = inp[i0 * 2];
14732 let in_a_im = inp[i0 * 2 + 1];
14733 let in_b_re = inp[i1 * 2];
14734 let in_b_im = inp[i1 * 2 + 1];
14735 let (b_re, b_im) = (
14736 in_b_re * w_re - in_b_im * w_im,
14737 in_b_re * w_im + in_b_im * w_re,
14738 );
14739 let (top_re, top_im) = (in_a_re + b_re, in_a_im + b_im);
14740 let (bot_re, bot_im) = (in_a_re - b_re, in_a_im - b_im);
14741 let (oa_re, oa_im, ob_re, ob_im) = if rev[bf] >= 0.5 {
14742 (bot_re, bot_im, top_re, top_im)
14743 } else {
14744 (top_re, top_im, bot_re, bot_im)
14745 };
14746 out[i0 * 2] = oa_re;
14747 out[i0 * 2 + 1] = oa_im;
14748 out[i1 * 2] = ob_re;
14749 out[i1 * 2 + 1] = ob_im;
14750 }
14751 }
14752}
14753
14754pub unsafe fn execute_fft1d_f32(
14756 src: usize,
14757 dst: usize,
14758 outer: usize,
14759 n_complex: usize,
14760 inverse: bool,
14761 norm_tag: u32,
14762 base: *mut u8,
14763) {
14764 let row_elems = 2 * n_complex;
14765 let mut re = vec![0f32; n_complex];
14766 let mut im = vec![0f32; n_complex];
14767 let norm = rlx_ir::fft::FftNorm::from_tag(norm_tag);
14768 let scale = norm.output_scale(n_complex, inverse) as f32;
14769 let mut scratch = if n_complex.is_power_of_two() || n_complex <= 16 {
14770 BluesteinScratchF32::empty()
14771 } else {
14772 BluesteinScratchF32::build(n_complex, inverse)
14773 };
14774 for o in 0..outer {
14775 let row_offset = src + o * row_elems * std::mem::size_of::<f32>();
14776 let s = unsafe { sl(row_offset, base, row_elems) };
14777 re.copy_from_slice(&s[..n_complex]);
14778 im.copy_from_slice(&s[n_complex..]);
14779 if n_complex.is_power_of_two() {
14780 fft_radix2_inplace_f32(&mut re, &mut im, inverse);
14781 } else if n_complex <= 16 {
14782 fft_naive_inplace_f32(&mut re, &mut im, inverse);
14783 } else {
14784 fft_bluestein_inplace_f32(&mut re, &mut im, inverse, &mut scratch);
14785 }
14786 if scale != 1.0 {
14787 re.iter_mut().for_each(|v| *v *= scale);
14788 im.iter_mut().for_each(|v| *v *= scale);
14789 }
14790 let dst_offset = dst + o * row_elems * std::mem::size_of::<f32>();
14791 let d = unsafe { sl_mut(dst_offset, base, row_elems) };
14792 d[..n_complex].copy_from_slice(&re);
14793 d[n_complex..].copy_from_slice(&im);
14794 }
14795}
14796
14797pub unsafe fn execute_fft1d_c64(
14799 src: usize,
14800 dst: usize,
14801 outer: usize,
14802 n_complex: usize,
14803 inverse: bool,
14804 norm_tag: u32,
14805 base: *mut u8,
14806) {
14807 let row_bytes = n_complex * 8;
14808 let mut re = vec![0f32; n_complex];
14809 let mut im = vec![0f32; n_complex];
14810 let norm = rlx_ir::fft::FftNorm::from_tag(norm_tag);
14811 let scale = norm.output_scale(n_complex, inverse) as f32;
14812 let mut scratch = if n_complex.is_power_of_two() || n_complex <= 16 {
14813 BluesteinScratchF32::empty()
14814 } else {
14815 BluesteinScratchF32::build(n_complex, inverse)
14816 };
14817 for o in 0..outer {
14818 let row_offset = src + o * row_bytes;
14819 for i in 0..n_complex {
14820 let elem_off = row_offset + i * 8;
14821 re[i] = f32::from_le_bytes([
14822 *base.add(elem_off),
14823 *base.add(elem_off + 1),
14824 *base.add(elem_off + 2),
14825 *base.add(elem_off + 3),
14826 ]);
14827 im[i] = f32::from_le_bytes([
14828 *base.add(elem_off + 4),
14829 *base.add(elem_off + 5),
14830 *base.add(elem_off + 6),
14831 *base.add(elem_off + 7),
14832 ]);
14833 }
14834 if n_complex.is_power_of_two() {
14835 fft_radix2_inplace_f32(&mut re, &mut im, inverse);
14836 } else if n_complex <= 16 {
14837 fft_naive_inplace_f32(&mut re, &mut im, inverse);
14838 } else {
14839 fft_bluestein_inplace_f32(&mut re, &mut im, inverse, &mut scratch);
14840 }
14841 if scale != 1.0 {
14842 re.iter_mut().for_each(|v| *v *= scale);
14843 im.iter_mut().for_each(|v| *v *= scale);
14844 }
14845 let dst_row = dst + o * row_bytes;
14846 for i in 0..n_complex {
14847 let elem_off = dst_row + i * 8;
14848 let re_b = re[i].to_le_bytes();
14849 let im_b = im[i].to_le_bytes();
14850 for j in 0..4 {
14851 *base.add(elem_off + j) = re_b[j];
14852 *base.add(elem_off + 4 + j) = im_b[j];
14853 }
14854 }
14855 }
14856}
14857
14858pub unsafe fn execute_log_mel(
14860 spec: usize,
14861 filters: usize,
14862 dst: usize,
14863 outer: usize,
14864 n_fft: usize,
14865 n_bins: usize,
14866 n_mels: usize,
14867 base: *mut u8,
14868) {
14869 execute_log_mel_f32(spec, filters, dst, outer, n_fft, n_bins, n_mels, base);
14870}
14871
14872pub unsafe fn execute_log_mel_f32(
14873 spec: usize,
14874 filters: usize,
14875 dst: usize,
14876 outer: usize,
14877 n_fft: usize,
14878 n_bins: usize,
14879 n_mels: usize,
14880 base: *mut u8,
14881) {
14882 let spec_ptr = base.add(spec) as *const f32;
14883 let filt_ptr = base.add(filters) as *const f32;
14884 let dst_ptr = base.add(dst) as *mut f32;
14885 let spec = std::slice::from_raw_parts(spec_ptr, outer * n_fft * 2);
14886 let filters = std::slice::from_raw_parts(filt_ptr, n_mels * n_bins);
14887 let out = std::slice::from_raw_parts_mut(dst_ptr, outer * n_mels);
14888 rlx_ir::audio::log_mel_block_f32(spec, filters, outer, n_fft, n_bins, n_mels, out);
14889}
14890
14891pub unsafe fn execute_welch_peaks_f32(
14892 spec: usize,
14893 dst: usize,
14894 welch_batch: usize,
14895 n_fft: usize,
14896 n_segments: usize,
14897 k: usize,
14898 base: *mut u8,
14899) {
14900 let spec_ptr = base.add(spec) as *const f32;
14901 let dst_ptr = base.add(dst) as *mut f32;
14902 let outer = welch_batch * n_segments;
14903 let spec = std::slice::from_raw_parts(spec_ptr, outer * n_fft * 2);
14904 let out = std::slice::from_raw_parts_mut(dst_ptr, welch_batch * k * 2);
14905 rlx_ir::audio::welch_peaks_block_f32(spec, welch_batch, n_fft, n_segments, k, out);
14906}
14907
14908pub unsafe fn execute_log_mel_backward_f32(
14909 spec: usize,
14910 filters: usize,
14911 dy: usize,
14912 dst: usize,
14913 outer: usize,
14914 n_fft: usize,
14915 n_bins: usize,
14916 n_mels: usize,
14917 base: *mut u8,
14918) {
14919 let spec_ptr = base.add(spec) as *const f32;
14920 let filt_ptr = base.add(filters) as *const f32;
14921 let dy_ptr = base.add(dy) as *const f32;
14922 let dst_ptr = base.add(dst) as *mut f32;
14923 let spec = std::slice::from_raw_parts(spec_ptr, outer * n_fft * 2);
14924 let filters = std::slice::from_raw_parts(filt_ptr, n_mels * n_bins);
14925 let dy = std::slice::from_raw_parts(dy_ptr, outer * n_mels);
14926 let d_spec = std::slice::from_raw_parts_mut(dst_ptr, outer * n_fft * 2);
14927 d_spec.fill(0.0);
14928 rlx_ir::audio::log_mel_block_vjp(spec, filters, dy, outer, n_fft, n_bins, n_mels, d_spec);
14929}
14930
14931pub unsafe fn execute_fft1d(
14933 src: usize,
14934 dst: usize,
14935 outer: usize,
14936 n_complex: usize,
14937 inverse: bool,
14938 norm_tag: u32,
14939 dtype: rlx_ir::DType,
14940 base: *mut u8,
14941) {
14942 match dtype {
14943 rlx_ir::DType::F32 => {
14944 execute_fft1d_f32(src, dst, outer, n_complex, inverse, norm_tag, base)
14945 }
14946 rlx_ir::DType::F64 => {
14947 execute_fft1d_f64(src, dst, outer, n_complex, inverse, norm_tag, base)
14948 }
14949 rlx_ir::DType::C64 => {
14950 execute_fft1d_c64(src, dst, outer, n_complex, inverse, norm_tag, base)
14951 }
14952 other => panic!("execute_fft1d: unsupported dtype {other:?}"),
14953 }
14954}
14955
/// In-place iterative radix-2 FFT over split real/imaginary `f32` buffers.
///
/// `re` and `im` must have the same power-of-two length (checked with
/// `debug_assert!` only). The forward transform uses the `-2πi/len` exponent,
/// the inverse uses `+2πi/len`; no `1/n` scaling is applied in either
/// direction. Twiddle factors are advanced by a recurrence kept in `f64` and
/// narrowed to `f32` right before each butterfly.
fn fft_radix2_inplace_f32(re: &mut [f32], im: &mut [f32], inverse: bool) {
    let n = re.len();
    debug_assert_eq!(im.len(), n);
    debug_assert!(
        n.is_power_of_two(),
        "fft_radix2_f32: n={n} must be a power of two"
    );
    if n <= 1 {
        return;
    }

    // Bit-reversal permutation: `rev` is maintained as the bit-reversed
    // counterpart of `idx` by incrementing it from the most significant bit.
    let mut rev = 0usize;
    for idx in 1..n {
        let mut mask = n >> 1;
        while rev & mask != 0 {
            rev ^= mask;
            mask >>= 1;
        }
        rev ^= mask;
        if idx < rev {
            re.swap(idx, rev);
            im.swap(idx, rev);
        }
    }

    let direction = if inverse { 1.0_f64 } else { -1.0_f64 };
    let mut span = 2usize;
    while span <= n {
        let half = span / 2;
        let angle = direction * 2.0 * std::f64::consts::PI / (span as f64);
        let (step_re, step_im) = (angle.cos(), angle.sin());
        for start in (0..n).step_by(span) {
            // Lower and upper halves of the current butterfly group.
            let (lo_re, hi_re) = re[start..start + span].split_at_mut(half);
            let (lo_im, hi_im) = im[start..start + span].split_at_mut(half);
            let (mut w_re, mut w_im) = (1.0_f64, 0.0_f64);
            for k in 0..half {
                let (wr, wi) = (w_re as f32, w_im as f32);
                let t_re = wr * hi_re[k] - wi * hi_im[k];
                let t_im = wr * hi_im[k] + wi * hi_re[k];
                let (u_re, u_im) = (lo_re[k], lo_im[k]);
                lo_re[k] = u_re + t_re;
                lo_im[k] = u_im + t_im;
                hi_re[k] = u_re - t_re;
                hi_im[k] = u_im - t_im;
                // Advance the twiddle: w <- w * step (complex multiply).
                let next_re = w_re * step_re - w_im * step_im;
                let next_im = w_re * step_im + w_im * step_re;
                w_re = next_re;
                w_im = next_im;
            }
        }
        span <<= 1;
    }
}
15017
/// In-place iterative radix-2 FFT over split real/imaginary `f64` buffers.
///
/// `re` and `im` must have the same power-of-two length (checked with
/// `debug_assert!` only). The forward transform uses the `-2πi/len` exponent,
/// the inverse uses `+2πi/len`; neither direction applies a `1/n` scale.
/// Twiddle factors are advanced by a multiplicative recurrence per stage.
fn fft_radix2_inplace_f64(re: &mut [f64], im: &mut [f64], inverse: bool) {
    let n = re.len();
    debug_assert_eq!(im.len(), n);
    debug_assert!(
        n.is_power_of_two(),
        "fft_radix2: n={n} must be a power of two"
    );
    if n <= 1 {
        return;
    }

    // Bit-reversal permutation: `rev` tracks the bit-reversed value of `idx`.
    let mut rev = 0usize;
    for idx in 1..n {
        let mut mask = n >> 1;
        while rev & mask != 0 {
            rev ^= mask;
            mask >>= 1;
        }
        rev ^= mask;
        if idx < rev {
            re.swap(idx, rev);
            im.swap(idx, rev);
        }
    }

    let direction = if inverse { 1.0 } else { -1.0 };
    let mut span = 2usize;
    while span <= n {
        let half = span / 2;
        let angle = direction * 2.0 * std::f64::consts::PI / (span as f64);
        let (step_re, step_im) = (angle.cos(), angle.sin());
        for start in (0..n).step_by(span) {
            // Lower and upper halves of the current butterfly group.
            let (lo_re, hi_re) = re[start..start + span].split_at_mut(half);
            let (lo_im, hi_im) = im[start..start + span].split_at_mut(half);
            let (mut w_re, mut w_im) = (1.0_f64, 0.0_f64);
            for k in 0..half {
                let t_re = w_re * hi_re[k] - w_im * hi_im[k];
                let t_im = w_re * hi_im[k] + w_im * hi_re[k];
                let (u_re, u_im) = (lo_re[k], lo_im[k]);
                lo_re[k] = u_re + t_re;
                lo_im[k] = u_im + t_im;
                hi_re[k] = u_re - t_re;
                hi_im[k] = u_im - t_im;
                // Advance the twiddle: w <- w * step (complex multiply).
                let next_re = w_re * step_re - w_im * step_im;
                let next_im = w_re * step_im + w_im * step_re;
                w_re = next_re;
                w_im = next_im;
            }
        }
        span <<= 1;
    }
}
15079
15080struct BluesteinScratchF64 {
15084 m: usize,
15086 w_re: Vec<f64>,
15090 w_im: Vec<f64>,
15091 bf_re: Vec<f64>,
15094 bf_im: Vec<f64>,
15095 ar: Vec<f64>,
15097 ai: Vec<f64>,
15098}
15099
15100impl BluesteinScratchF64 {
15101 fn empty() -> Self {
15102 Self {
15103 m: 0,
15104 w_re: Vec::new(),
15105 w_im: Vec::new(),
15106 bf_re: Vec::new(),
15107 bf_im: Vec::new(),
15108 ar: Vec::new(),
15109 ai: Vec::new(),
15110 }
15111 }
15112
15113 fn build(n: usize, inverse: bool) -> Self {
15114 let m = if n <= 1 {
15117 1
15118 } else {
15119 (2 * n - 1).next_power_of_two()
15120 };
15121
15122 let mod_2n = (2 * n) as u64;
15125 let sign = if inverse { 1.0_f64 } else { -1.0_f64 };
15126 let mut w_re = vec![0.0_f64; n];
15127 let mut w_im = vec![0.0_f64; n];
15128 for k in 0..n {
15129 let k2 = (k as u64).wrapping_mul(k as u64) % mod_2n;
15130 let theta = sign * std::f64::consts::PI * (k2 as f64) / (n as f64);
15131 w_re[k] = theta.cos();
15132 w_im[k] = theta.sin();
15133 }
15134
15135 let mut bf_re = vec![0.0_f64; m];
15138 let mut bf_im = vec![0.0_f64; m];
15139 if n > 0 {
15140 bf_re[0] = w_re[0];
15141 bf_im[0] = -w_im[0];
15142 for k in 1..n {
15143 bf_re[k] = w_re[k];
15144 bf_im[k] = -w_im[k];
15145 bf_re[m - k] = w_re[k];
15146 bf_im[m - k] = -w_im[k];
15147 }
15148 }
15149 if m > 1 {
15150 fft_radix2_inplace_f64(&mut bf_re, &mut bf_im, false);
15151 }
15152
15153 Self {
15154 m,
15155 w_re,
15156 w_im,
15157 bf_re,
15158 bf_im,
15159 ar: vec![0.0_f64; m],
15160 ai: vec![0.0_f64; m],
15161 }
15162 }
15163}
15164
/// In-place O(n²) reference DFT over split real/imaginary `f64` buffers.
///
/// Works for any length. The forward transform uses the `-2πi·nn·k/n`
/// exponent, the inverse uses `+2πi·nn·k/n`; no `1/n` scaling is applied.
///
/// The twiddle index `nn·k` is reduced modulo `n` before the angle is formed.
/// `exp(2πi·j/n)` is periodic in `j` with period `n`, so the value is the
/// same, but the angle stays within one turn instead of growing with `n`,
/// which keeps the rounding error of `sin`/`cos` arguments independent of `n`.
fn fft_naive_inplace_f64(re: &mut [f64], im: &mut [f64], inverse: bool) {
    let n = re.len();
    debug_assert_eq!(im.len(), n);
    if n <= 1 {
        return;
    }
    let sign = if inverse { 1.0_f64 } else { -1.0_f64 };
    let step = sign * 2.0 * std::f64::consts::PI / (n as f64);
    let mut out_re = vec![0.0_f64; n];
    let mut out_im = vec![0.0_f64; n];
    for k in 0..n {
        // `idx` is (nn * k) mod n, maintained incrementally so the product
        // is never formed.
        let mut idx = 0usize;
        let mut acc_re = 0.0_f64;
        let mut acc_im = 0.0_f64;
        for nn in 0..n {
            let theta = step * (idx as f64);
            let c = theta.cos();
            let s = theta.sin();
            acc_re += re[nn] * c - im[nn] * s;
            acc_im += re[nn] * s + im[nn] * c;
            idx += k;
            if idx >= n {
                idx -= n;
            }
        }
        out_re[k] = acc_re;
        out_im[k] = acc_im;
    }
    re.copy_from_slice(&out_re);
    im.copy_from_slice(&out_im);
}
15186
/// In-place O(n²) reference DFT over split real/imaginary `f32` buffers.
///
/// Works for any length. The forward transform uses the `-2πi·nn·k/n`
/// exponent, the inverse uses `+2πi·nn·k/n`; no `1/n` scaling is applied.
///
/// Twiddle angles are formed in `f64` from the index `nn·k` reduced modulo
/// `n`, and the sums are accumulated in `f64` before being narrowed to `f32`.
/// Forming the angle in `f32` from the unreduced product loses several
/// decimal digits of phase once `n` reaches a few hundred; the `f32` radix-2
/// and Bluestein paths in this file already build their twiddles in `f64`.
fn fft_naive_inplace_f32(re: &mut [f32], im: &mut [f32], inverse: bool) {
    let n = re.len();
    debug_assert_eq!(im.len(), n);
    if n <= 1 {
        return;
    }
    let sign = if inverse { 1.0_f64 } else { -1.0_f64 };
    let step = sign * 2.0 * std::f64::consts::PI / (n as f64);
    let mut out_re = vec![0.0_f32; n];
    let mut out_im = vec![0.0_f32; n];
    for k in 0..n {
        // `idx` is (nn * k) mod n, maintained incrementally so the product
        // is never formed.
        let mut idx = 0usize;
        let mut acc_re = 0.0_f64;
        let mut acc_im = 0.0_f64;
        for nn in 0..n {
            let theta = step * (idx as f64);
            let c = theta.cos();
            let s = theta.sin();
            let x_re = re[nn] as f64;
            let x_im = im[nn] as f64;
            acc_re += x_re * c - x_im * s;
            acc_im += x_re * s + x_im * c;
            idx += k;
            if idx >= n {
                idx -= n;
            }
        }
        out_re[k] = acc_re as f32;
        out_im[k] = acc_im as f32;
    }
    re.copy_from_slice(&out_re);
    im.copy_from_slice(&out_im);
}
15207
15208fn fft_bluestein_inplace_f64(
15217 re: &mut [f64],
15218 im: &mut [f64],
15219 _inverse: bool,
15220 s: &mut BluesteinScratchF64,
15221) {
15222 let n = re.len();
15223 debug_assert_eq!(im.len(), n);
15224 debug_assert_eq!(s.w_re.len(), n);
15225 if n <= 1 {
15226 return;
15227 }
15228 let m = s.m;
15229
15230 for k in 0..m {
15232 s.ar[k] = 0.0;
15233 s.ai[k] = 0.0;
15234 }
15235 for k in 0..n {
15236 s.ar[k] = re[k] * s.w_re[k] - im[k] * s.w_im[k];
15237 s.ai[k] = re[k] * s.w_im[k] + im[k] * s.w_re[k];
15238 }
15239
15240 fft_radix2_inplace_f64(&mut s.ar, &mut s.ai, false);
15242
15243 for k in 0..m {
15245 let ar = s.ar[k];
15246 let ai = s.ai[k];
15247 let br = s.bf_re[k];
15248 let bi = s.bf_im[k];
15249 s.ar[k] = ar * br - ai * bi;
15250 s.ai[k] = ar * bi + ai * br;
15251 }
15252
15253 fft_radix2_inplace_f64(&mut s.ar, &mut s.ai, true);
15256 let inv_m = 1.0 / (m as f64);
15257
15258 for k in 0..n {
15260 let yr = s.ar[k] * inv_m;
15261 let yi = s.ai[k] * inv_m;
15262 re[k] = yr * s.w_re[k] - yi * s.w_im[k];
15263 im[k] = yr * s.w_im[k] + yi * s.w_re[k];
15264 }
15265}
15266
15267struct BluesteinScratchF32 {
15271 m: usize,
15272 w_re: Vec<f32>,
15273 w_im: Vec<f32>,
15274 bf_re: Vec<f32>,
15275 bf_im: Vec<f32>,
15276 ar: Vec<f32>,
15277 ai: Vec<f32>,
15278}
15279
15280impl BluesteinScratchF32 {
15281 fn empty() -> Self {
15282 Self {
15283 m: 0,
15284 w_re: Vec::new(),
15285 w_im: Vec::new(),
15286 bf_re: Vec::new(),
15287 bf_im: Vec::new(),
15288 ar: Vec::new(),
15289 ai: Vec::new(),
15290 }
15291 }
15292
15293 fn build(n: usize, inverse: bool) -> Self {
15294 let m = if n <= 1 {
15295 1
15296 } else {
15297 (2 * n - 1).next_power_of_two()
15298 };
15299
15300 let mod_2n = (2 * n) as u64;
15301 let sign = if inverse { 1.0_f64 } else { -1.0_f64 };
15302 let mut w_re = vec![0.0_f32; n];
15303 let mut w_im = vec![0.0_f32; n];
15304 for k in 0..n {
15305 let k2 = (k as u64).wrapping_mul(k as u64) % mod_2n;
15306 let theta = sign * std::f64::consts::PI * (k2 as f64) / (n as f64);
15307 w_re[k] = theta.cos() as f32;
15308 w_im[k] = theta.sin() as f32;
15309 }
15310
15311 let mut bf_re = vec![0.0_f32; m];
15312 let mut bf_im = vec![0.0_f32; m];
15313 if n > 0 {
15314 bf_re[0] = w_re[0];
15315 bf_im[0] = -w_im[0];
15316 for k in 1..n {
15317 bf_re[k] = w_re[k];
15318 bf_im[k] = -w_im[k];
15319 bf_re[m - k] = w_re[k];
15320 bf_im[m - k] = -w_im[k];
15321 }
15322 }
15323 if m > 1 {
15324 fft_radix2_inplace_f32(&mut bf_re, &mut bf_im, false);
15325 }
15326
15327 Self {
15328 m,
15329 w_re,
15330 w_im,
15331 bf_re,
15332 bf_im,
15333 ar: vec![0.0_f32; m],
15334 ai: vec![0.0_f32; m],
15335 }
15336 }
15337}
15338
15339fn fft_bluestein_inplace_f32(
15340 re: &mut [f32],
15341 im: &mut [f32],
15342 _inverse: bool,
15343 s: &mut BluesteinScratchF32,
15344) {
15345 let n = re.len();
15346 debug_assert_eq!(im.len(), n);
15347 debug_assert_eq!(s.w_re.len(), n);
15348 if n <= 1 {
15349 return;
15350 }
15351 let m = s.m;
15352
15353 for k in 0..m {
15354 s.ar[k] = 0.0;
15355 s.ai[k] = 0.0;
15356 }
15357 for k in 0..n {
15358 s.ar[k] = re[k] * s.w_re[k] - im[k] * s.w_im[k];
15359 s.ai[k] = re[k] * s.w_im[k] + im[k] * s.w_re[k];
15360 }
15361
15362 fft_radix2_inplace_f32(&mut s.ar, &mut s.ai, false);
15363
15364 for k in 0..m {
15365 let ar = s.ar[k];
15366 let ai = s.ai[k];
15367 let br = s.bf_re[k];
15368 let bi = s.bf_im[k];
15369 s.ar[k] = ar * br - ai * bi;
15370 s.ai[k] = ar * bi + ai * br;
15371 }
15372
15373 fft_radix2_inplace_f32(&mut s.ar, &mut s.ai, true);
15374 let inv_m = 1.0_f32 / (m as f32);
15375
15376 for k in 0..n {
15377 let yr = s.ar[k] * inv_m;
15378 let yi = s.ai[k] * inv_m;
15379 re[k] = yr * s.w_re[k] - yi * s.w_im[k];
15380 im[k] = yr * s.w_im[k] + yi * s.w_re[k];
15381 }
15382}
15383
15384unsafe fn dispatch_custom_op(
15390 kernel: &dyn crate::op_registry::CpuKernel,
15391 inputs: &[(usize, u32, Shape)],
15392 out_off: usize,
15393 out_len: u32,
15394 out_shape: &Shape,
15395 attrs: &[u8],
15396 base: *mut u8,
15397) {
15398 use crate::op_registry::{CpuTensorMut, CpuTensorRef};
15399 use rlx_ir::DType;
15400
15401 macro_rules! build_in_view {
15406 ($shape:expr, $off:expr, $n:expr, $variant:ident, $rust_ty:ty) => {
15407 CpuTensorRef::$variant {
15408 data: unsafe { sl_typed::<$rust_ty>($off, base, $n) },
15409 shape: $shape,
15410 }
15411 };
15412 }
15413 macro_rules! build_out_view {
15414 ($variant:ident, $rust_ty:ty) => {
15415 CpuTensorMut::$variant {
15416 data: unsafe { sl_mut_typed::<$rust_ty>(out_off, base, out_len as usize) },
15417 shape: out_shape,
15418 }
15419 };
15420 }
15421
15422 let in_views: Vec<CpuTensorRef<'_>> = inputs
15423 .iter()
15424 .map(|(off, len, shape)| {
15425 let n = *len as usize;
15426 let off = *off;
15427 match shape.dtype() {
15428 DType::F32 => build_in_view!(shape, off, n, F32, f32),
15429 DType::F64 => build_in_view!(shape, off, n, F64, f64),
15430 DType::F16 => build_in_view!(shape, off, n, F16, half::f16),
15431 DType::BF16 => build_in_view!(shape, off, n, BF16, half::bf16),
15432 DType::I8 => build_in_view!(shape, off, n, I8, i8),
15433 DType::I16 => build_in_view!(shape, off, n, I16, i16),
15434 DType::I32 => build_in_view!(shape, off, n, I32, i32),
15435 DType::I64 => build_in_view!(shape, off, n, I64, i64),
15436 DType::U8 => build_in_view!(shape, off, n, U8, u8),
15437 DType::U32 => build_in_view!(shape, off, n, U32, u32),
15438 DType::Bool => build_in_view!(shape, off, n, Bool, u8),
15439 DType::C64 => panic!(
15443 "Op::Custom kernel input has DType::C64 — built-in \
15444 complex ops handle their own kernels; user-registered \
15445 ops don't yet see complex tensors"
15446 ),
15447 }
15448 })
15449 .collect();
15450
15451 let result = match out_shape.dtype() {
15452 DType::F32 => kernel.execute(&in_views, build_out_view!(F32, f32), attrs),
15453 DType::F64 => kernel.execute(&in_views, build_out_view!(F64, f64), attrs),
15454 DType::F16 => kernel.execute(&in_views, build_out_view!(F16, half::f16), attrs),
15455 DType::BF16 => kernel.execute(&in_views, build_out_view!(BF16, half::bf16), attrs),
15456 DType::I8 => kernel.execute(&in_views, build_out_view!(I8, i8), attrs),
15457 DType::I16 => kernel.execute(&in_views, build_out_view!(I16, i16), attrs),
15458 DType::I32 => kernel.execute(&in_views, build_out_view!(I32, i32), attrs),
15459 DType::I64 => kernel.execute(&in_views, build_out_view!(I64, i64), attrs),
15460 DType::U8 => kernel.execute(&in_views, build_out_view!(U8, u8), attrs),
15461 DType::U32 => kernel.execute(&in_views, build_out_view!(U32, u32), attrs),
15462 DType::Bool => kernel.execute(&in_views, build_out_view!(Bool, u8), attrs),
15463 DType::C64 => panic!("Op::Custom output DType::C64 not supported"),
15464 };
15465 if let Err(e) = result {
15466 panic!("Op::Custom('{}') CPU kernel failed: {e}", kernel.name());
15467 }
15468}
15469
/// Read-only view of `len` elements of `T` starting `offset` bytes past `base`.
///
/// An `offset` of `usize::MAX` is treated as "no buffer" and yields an empty
/// slice without touching `base`.
///
/// # Safety
/// Unless `offset == usize::MAX`, `base + offset` must be aligned for `T` and
/// valid for reading `len` elements for as long as the returned slice is
/// used; the `'static` lifetime is not checked by the compiler.
#[inline(always)]
unsafe fn sl_typed<T>(offset: usize, base: *mut u8, len: usize) -> &'static [T] {
    match offset {
        usize::MAX => &[],
        _ => unsafe {
            let first = base.add(offset) as *const T;
            std::slice::from_raw_parts(first, len)
        },
    }
}
15482
/// Mutable view of `len` elements of `T` starting `offset` bytes past `base`.
///
/// # Safety
/// `base + offset` must be aligned for `T`, valid for reading and writing
/// `len` elements, and not aliased by any other live reference while the
/// returned slice is used; the `'static` lifetime is not checked.
#[inline(always)]
unsafe fn sl_mut_typed<T>(offset: usize, base: *mut u8, len: usize) -> &'static mut [T] {
    unsafe {
        let first = base.add(offset) as *mut T;
        std::slice::from_raw_parts_mut(first, len)
    }
}
15487
15488#[inline(always)]
15490fn apply_activation_inplace(d: &mut [f32], act: rlx_ir::op::Activation) {
15494 use rlx_ir::op::Activation;
15495 match act {
15496 Activation::Gelu => crate::kernels::par_gelu_inplace(d),
15497 Activation::GeluApprox => crate::kernels::par_gelu_approx_inplace(d),
15498 Activation::Silu => crate::kernels::par_silu_inplace(d),
15499 Activation::Relu => {
15500 for v in d.iter_mut() {
15501 *v = v.max(0.0);
15502 }
15503 }
15504 Activation::Sigmoid => {
15505 for v in d.iter_mut() {
15506 *v = 1.0 / (1.0 + (-*v).exp());
15507 }
15508 }
15509 Activation::Tanh => {
15510 for v in d.iter_mut() {
15511 *v = v.tanh();
15512 }
15513 }
15514 Activation::Exp => {
15515 for v in d.iter_mut() {
15516 *v = v.exp();
15517 }
15518 }
15519 Activation::Log => {
15520 for v in d.iter_mut() {
15521 *v = v.ln();
15522 }
15523 }
15524 Activation::Sqrt => {
15525 for v in d.iter_mut() {
15526 *v = v.sqrt();
15527 }
15528 }
15529 Activation::Rsqrt => {
15530 for v in d.iter_mut() {
15531 *v = 1.0 / v.sqrt();
15532 }
15533 }
15534 Activation::Neg => {
15535 for v in d.iter_mut() {
15536 *v = -*v;
15537 }
15538 }
15539 Activation::Abs => {
15540 for v in d.iter_mut() {
15541 *v = v.abs();
15542 }
15543 }
15544 Activation::Round => {
15545 for v in d.iter_mut() {
15546 *v = v.round();
15547 }
15548 }
15549 Activation::Sin => {
15550 for v in d.iter_mut() {
15551 *v = v.sin();
15552 }
15553 }
15554 Activation::Cos => {
15555 for v in d.iter_mut() {
15556 *v = v.cos();
15557 }
15558 }
15559 Activation::Tan => {
15560 for v in d.iter_mut() {
15561 *v = v.tan();
15562 }
15563 }
15564 Activation::Atan => {
15565 for v in d.iter_mut() {
15566 *v = v.atan();
15567 }
15568 }
15569 }
15570}
15571
/// Unfolds one image `x` (`c_in × h × w`, row-major) into the column matrix
/// `col` with `c_in·kh·kw` rows and `h_out·w_out` columns.
///
/// Row `(ci·kh + ki)·kw + kj` holds, for every output position `(ho, wo)`,
/// the input sample at `(ho·sh + ki·dh − ph, wo·sw + kj·dw_dil − pw)` of
/// channel `ci`, or `0.0` when that position falls outside the image.
/// Every element of `col` is written. Buffer sizes are checked with
/// `debug_assert!` only.
#[allow(clippy::too_many_arguments)]
fn im2col(
    x: &[f32],
    col: &mut [f32],
    c_in: usize,
    h: usize,
    w: usize,
    h_out: usize,
    w_out: usize,
    kh: usize,
    kw: usize,
    sh: usize,
    sw: usize,
    ph: usize,
    pw: usize,
    dh: usize,
    dw_dil: usize,
) {
    let n_dim = h_out * w_out;
    debug_assert_eq!(col.len(), c_in * kh * kw * n_dim);
    debug_assert_eq!(x.len(), c_in * h * w);
    let (h_lim, w_lim) = (h as isize, w as isize);
    let (pad_h, pad_w) = (ph as isize, pw as isize);
    for ci in 0..c_in {
        for ki in 0..kh {
            for kj in 0..kw {
                let row_off = ((ci * kh + ki) * kw + kj) * n_dim;
                for ho in 0..h_out {
                    let dst_off = row_off + ho * w_out;
                    let hi = (ho * sh + ki * dh) as isize - pad_h;
                    if !(0..h_lim).contains(&hi) {
                        // The whole output row reads from vertical padding.
                        col[dst_off..dst_off + w_out].fill(0.0);
                        continue;
                    }
                    let src_off = (ci * h + hi as usize) * w;
                    for wo in 0..w_out {
                        let wi = (wo * sw + kj * dw_dil) as isize - pad_w;
                        col[dst_off + wo] = if (0..w_lim).contains(&wi) {
                            x[src_off + wi as usize]
                        } else {
                            0.0
                        };
                    }
                }
            }
        }
    }
}
15633
/// Folds the column matrix `col` (`c_in·kh·kw` rows, `h_out·w_out` columns)
/// back onto the image `x` (`c_in × h × w`, row-major), the scatter
/// counterpart of `im2col`.
///
/// Each column entry is **added** to the input position it was gathered
/// from; entries that map to padding are dropped. `x` is not cleared first,
/// so the caller decides the starting contents. Buffer sizes are checked
/// with `debug_assert!` only.
#[allow(clippy::too_many_arguments)]
fn col2im(
    col: &[f32],
    x: &mut [f32],
    c_in: usize,
    h: usize,
    w: usize,
    h_out: usize,
    w_out: usize,
    kh: usize,
    kw: usize,
    sh: usize,
    sw: usize,
    ph: usize,
    pw: usize,
    dh: usize,
    dw_dil: usize,
) {
    let n_dim = h_out * w_out;
    debug_assert_eq!(col.len(), c_in * kh * kw * n_dim);
    debug_assert_eq!(x.len(), c_in * h * w);
    let (h_lim, w_lim) = (h as isize, w as isize);
    let (pad_h, pad_w) = (ph as isize, pw as isize);
    for ci in 0..c_in {
        for ki in 0..kh {
            for kj in 0..kw {
                let row_off = ((ci * kh + ki) * kw + kj) * n_dim;
                for ho in 0..h_out {
                    let hi = (ho * sh + ki * dh) as isize - pad_h;
                    if !(0..h_lim).contains(&hi) {
                        continue;
                    }
                    let dst_row = (ci * h + hi as usize) * w;
                    let src_row = row_off + ho * w_out;
                    for wo in 0..w_out {
                        let wi = (wo * sw + kj * dw_dil) as isize - pad_w;
                        if (0..w_lim).contains(&wi) {
                            x[dst_row + wi as usize] += col[src_row + wo];
                        }
                    }
                }
            }
        }
    }
}
15689
15690fn quant_layout(shape: &rlx_ir::Shape, axis: Option<usize>) -> (usize, usize, usize) {
15700 match axis {
15701 None => (0, 1, shape.num_elements().unwrap_or(0).max(1)),
15702 Some(d) => {
15703 let chan_dim = shape.dim(d).unwrap_static();
15704 let inner: usize = (d + 1..shape.rank())
15705 .map(|i| shape.dim(i).unwrap_static())
15706 .product::<usize>()
15707 .max(1);
15708 (d, chan_dim, inner)
15709 }
15710 }
15711}
15712
15713fn activation_backward_kernel(
15714 act: rlx_ir::op::Activation,
15715 xs: &[f32],
15716 dys: &[f32],
15717 out: &mut [f32],
15718) {
15719 use rlx_ir::op::Activation;
15720 let n = xs.len();
15721 debug_assert_eq!(dys.len(), n);
15722 debug_assert_eq!(out.len(), n);
15723 match act {
15724 Activation::Relu => {
15725 for i in 0..n {
15726 out[i] = if xs[i] > 0.0 { dys[i] } else { 0.0 };
15727 }
15728 }
15729 Activation::Sigmoid => {
15730 for i in 0..n {
15731 let s = 1.0 / (1.0 + (-xs[i]).exp());
15732 out[i] = s * (1.0 - s) * dys[i];
15733 }
15734 }
15735 Activation::Tanh => {
15736 for i in 0..n {
15737 let t = xs[i].tanh();
15738 out[i] = (1.0 - t * t) * dys[i];
15739 }
15740 }
15741 Activation::Silu => {
15742 for i in 0..n {
15744 let s = 1.0 / (1.0 + (-xs[i]).exp());
15745 out[i] = s * (1.0 + xs[i] * (1.0 - s)) * dys[i];
15746 }
15747 }
15748 Activation::Gelu => {
15749 const INV_SQRT2: f32 = 0.707_106_77;
15752 const INV_SQRT_2PI: f32 = 0.398_942_3;
15753 for i in 0..n {
15754 let x = xs[i];
15755 let phi = 0.5 * (1.0 + erf_f32(x * INV_SQRT2));
15756 let pdf = INV_SQRT_2PI * (-(x * x) * 0.5).exp();
15757 out[i] = (phi + x * pdf) * dys[i];
15758 }
15759 }
15760 Activation::GeluApprox => {
15761 const C: f32 = 0.797_884_6; const A: f32 = 0.044_715;
15765 for i in 0..n {
15766 let x = xs[i];
15767 let inner = C * (x + A * x * x * x);
15768 let t = inner.tanh();
15769 let dinner = C * (1.0 + 3.0 * A * x * x);
15770 let d = 0.5 * (1.0 + t) + 0.5 * x * (1.0 - t * t) * dinner;
15771 out[i] = d * dys[i];
15772 }
15773 }
15774 Activation::Exp => {
15775 for i in 0..n {
15776 out[i] = xs[i].exp() * dys[i];
15777 }
15778 }
15779 Activation::Log => {
15780 for i in 0..n {
15781 out[i] = dys[i] / xs[i];
15782 }
15783 }
15784 Activation::Sqrt => {
15785 for i in 0..n {
15787 let s = xs[i].sqrt();
15788 out[i] = if s > 0.0 { 0.5 * dys[i] / s } else { 0.0 };
15789 }
15790 }
15791 Activation::Rsqrt => {
15792 for i in 0..n {
15794 let s = xs[i].sqrt();
15795 out[i] = if s > 0.0 {
15796 -0.5 * dys[i] / (xs[i] * s)
15797 } else {
15798 0.0
15799 };
15800 }
15801 }
15802 Activation::Neg => {
15803 for i in 0..n {
15804 out[i] = -dys[i];
15805 }
15806 }
15807 Activation::Abs => {
15808 for i in 0..n {
15810 let x = xs[i];
15811 let s = if x > 0.0 {
15812 1.0
15813 } else if x < 0.0 {
15814 -1.0
15815 } else {
15816 0.0
15817 };
15818 out[i] = s * dys[i];
15819 }
15820 }
15821 Activation::Round => {
15822 out.copy_from_slice(dys);
15827 }
15828 Activation::Sin => {
15829 for i in 0..n {
15831 out[i] = xs[i].cos() * dys[i];
15832 }
15833 }
15834 Activation::Cos => {
15835 for i in 0..n {
15836 out[i] = -xs[i].sin() * dys[i];
15837 }
15838 }
15839 Activation::Tan => {
15840 for i in 0..n {
15842 let t = xs[i].tan();
15843 out[i] = (1.0 + t * t) * dys[i];
15844 }
15845 }
15846 Activation::Atan => {
15847 for i in 0..n {
15849 let x = xs[i];
15850 out[i] = dys[i] / (1.0 + x * x);
15851 }
15852 }
15853 }
15854}
15855
15856fn activation_backward_kernel_f64(
15860 act: rlx_ir::op::Activation,
15861 xs: &[f64],
15862 dys: &[f64],
15863 out: &mut [f64],
15864) {
15865 use rlx_ir::op::Activation;
15866 let n = xs.len();
15867 debug_assert_eq!(dys.len(), n);
15868 debug_assert_eq!(out.len(), n);
15869 match act {
15870 Activation::Relu => {
15871 for i in 0..n {
15872 out[i] = if xs[i] > 0.0 { dys[i] } else { 0.0 };
15873 }
15874 }
15875 Activation::Sigmoid => {
15876 for i in 0..n {
15877 let s = 1.0 / (1.0 + (-xs[i]).exp());
15878 out[i] = s * (1.0 - s) * dys[i];
15879 }
15880 }
15881 Activation::Tanh => {
15882 for i in 0..n {
15883 let t = xs[i].tanh();
15884 out[i] = (1.0 - t * t) * dys[i];
15885 }
15886 }
15887 Activation::Silu => {
15888 for i in 0..n {
15889 let s = 1.0 / (1.0 + (-xs[i]).exp());
15890 out[i] = s * (1.0 + xs[i] * (1.0 - s)) * dys[i];
15891 }
15892 }
15893 Activation::Gelu | Activation::GeluApprox => {
15894 const INV_SQRT2: f64 = std::f64::consts::FRAC_1_SQRT_2;
15896 const INV_SQRT_2PI: f64 = 0.398_942_280_401_432_7;
15897 for i in 0..n {
15898 let x = xs[i];
15899 let phi = 0.5 * (1.0 + erf_f64(x * INV_SQRT2));
15900 let pdf = INV_SQRT_2PI * (-(x * x) * 0.5).exp();
15901 out[i] = (phi + x * pdf) * dys[i];
15902 }
15903 }
15904 Activation::Exp => {
15905 for i in 0..n {
15906 out[i] = xs[i].exp() * dys[i];
15907 }
15908 }
15909 Activation::Log => {
15910 for i in 0..n {
15911 out[i] = dys[i] / xs[i];
15912 }
15913 }
15914 Activation::Sqrt => {
15915 for i in 0..n {
15916 let s = xs[i].sqrt();
15917 out[i] = if s > 0.0 { 0.5 * dys[i] / s } else { 0.0 };
15918 }
15919 }
15920 Activation::Rsqrt => {
15921 for i in 0..n {
15922 let s = xs[i].sqrt();
15923 out[i] = if s > 0.0 {
15924 -0.5 * dys[i] / (xs[i] * s)
15925 } else {
15926 0.0
15927 };
15928 }
15929 }
15930 Activation::Neg => {
15931 for i in 0..n {
15932 out[i] = -dys[i];
15933 }
15934 }
15935 Activation::Abs => {
15936 for i in 0..n {
15937 let x = xs[i];
15938 let s = if x > 0.0 {
15939 1.0
15940 } else if x < 0.0 {
15941 -1.0
15942 } else {
15943 0.0
15944 };
15945 out[i] = s * dys[i];
15946 }
15947 }
15948 Activation::Round => {
15949 out.copy_from_slice(dys);
15950 }
15951 Activation::Sin => {
15952 for i in 0..n {
15953 out[i] = xs[i].cos() * dys[i];
15954 }
15955 }
15956 Activation::Cos => {
15957 for i in 0..n {
15958 out[i] = -xs[i].sin() * dys[i];
15959 }
15960 }
15961 Activation::Tan => {
15962 for i in 0..n {
15963 let t = xs[i].tan();
15964 out[i] = (1.0 + t * t) * dys[i];
15965 }
15966 }
15967 Activation::Atan => {
15968 for i in 0..n {
15969 let x = xs[i];
15970 out[i] = dys[i] / (1.0 + x * x);
15971 }
15972 }
15973 }
15974}
15975
/// Polynomial approximation of the error function for `f64` inputs.
///
/// Evaluates `sign(x) · (1 − P(t) · t · exp(−x²))` with
/// `t = 1 / (1 + 0.3275911·|x|)` and a degree-4 polynomial `P` in Horner
/// form. The constants are those of Abramowitz & Stegun formula 7.1.26,
/// whose published absolute error bound is about `1.5e-7`, so the result is
/// not accurate to full `f64` precision.
#[inline(always)]
fn erf_f64(x: f64) -> f64 {
    let sign = x.signum();
    let ax = x.abs();
    let t = 1.0 / (1.0 + 0.327_591_1 * ax);
    // Horner evaluation, highest-order coefficient first.
    let mut poly = 1.061_405_43 * t - 1.453_152_03;
    poly = poly * t + 1.421_413_75;
    poly = poly * t - 0.284_496_74;
    poly = poly * t + 0.254_829_59;
    let tail = poly * t * (-ax * ax).exp();
    sign * (1.0 - tail)
}
15992
/// Polynomial approximation of the error function for `f32` inputs.
///
/// Evaluates `sign(x) · (1 − P(t) · t · exp(−x²))` with
/// `t = 1 / (1 + 0.3275911·|x|)` and a degree-4 polynomial `P` in Horner
/// form. The constants are those of Abramowitz & Stegun formula 7.1.26
/// (published absolute error bound about `1.5e-7`).
#[inline(always)]
fn erf_f32(x: f32) -> f32 {
    let sign = x.signum();
    let ax = x.abs();
    let t = 1.0 / (1.0 + 0.327_591_1 * ax);
    // Horner evaluation, highest-order coefficient first.
    let mut poly = 1.061_405_4 * t - 1.453_152_1;
    poly = poly * t + 1.421_413_8;
    poly = poly * t - 0.284_496_74;
    poly = poly * t + 0.254_829_6;
    let tail = poly * t * (-ax * ax).exp();
    sign * (1.0 - tail)
}
16007
/// Builds a thunk that copies `outer` rows of `inner` elements from a strided
/// source region to a strided destination region of the same buffer.
///
/// * `src`, `dst` — byte offsets of the first source / destination row.
/// * `src_stride`, `dst_stride` — distance between consecutive rows, in
///   elements.
/// * `inner` — elements copied per row.
/// * `elem_bytes` — size of one element in bytes.
///
/// The returned closure receives the buffer base pointer. It does nothing
/// when a row is empty or when `src == dst`, and skips any individual row
/// whose source and destination offsets coincide. Rows are moved with
/// `std::ptr::copy`, so a source row that partially overlaps its destination
/// row is handled correctly rather than being undefined behaviour. If a row
/// offset cannot be represented in `usize`, copying stops at that row.
///
/// The closure performs no bounds check against the buffer length: whoever
/// calls it must pass a base pointer whose allocation covers every row
/// described above.
fn narrow_thunk_closure(
    src: usize,
    dst: usize,
    outer: u32,
    src_stride: u32,
    dst_stride: u32,
    inner: u32,
    elem_bytes: u8,
) -> Arc<dyn Fn(*mut u8) + Send + Sync> {
    /// Byte offset of row `row`, or `None` if the row's start or end does
    /// not fit in `usize`.
    fn row_offset(origin: usize, row: usize, stride: usize, row_bytes: usize) -> Option<usize> {
        let start = row.checked_mul(stride)?.checked_add(origin)?;
        start.checked_add(row_bytes)?;
        Some(start)
    }

    let rows = outer as usize;
    let elem = elem_bytes as usize;
    let row_bytes = (inner as usize).saturating_mul(elem);
    let src_row_stride = (src_stride as usize).saturating_mul(elem);
    let dst_row_stride = (dst_stride as usize).saturating_mul(elem);
    Arc::new(move |base: *mut u8| {
        if row_bytes == 0 || src == dst {
            return;
        }
        for o in 0..rows {
            let offsets = (
                row_offset(src, o, src_row_stride, row_bytes),
                row_offset(dst, o, dst_row_stride, row_bytes),
            );
            let (s_off, d_off) = match offsets {
                (Some(s), Some(d)) => (s, d),
                // Offsets that overflow cannot address real memory.
                _ => break,
            };
            if s_off == d_off {
                continue;
            }
            // SAFETY: the caller of the thunk guarantees that `base` covers
            // `row_bytes` bytes at both offsets. `copy` (memmove semantics)
            // is used because the two rows may overlap.
            unsafe {
                std::ptr::copy(base.add(s_off), base.add(d_off), row_bytes);
            }
        }
    })
}
16048
/// Read-only `f32` view of `len` elements starting `offset` bytes past `base`.
///
/// An `offset` of `usize::MAX` is treated as "no buffer" and yields an empty
/// slice without touching `base`.
///
/// Marked `#[inline(always)]` like every other `sl_*` accessor in this file.
///
/// # Safety
/// Unless `offset == usize::MAX`, `base + offset` must be aligned for `f32`
/// and valid for reading `len` elements for as long as the returned slice is
/// used; the `'static` lifetime is not checked by the compiler.
#[inline(always)]
unsafe fn sl(offset: usize, base: *mut u8, len: usize) -> &'static [f32] {
    if offset == usize::MAX {
        return &[];
    }
    unsafe { std::slice::from_raw_parts(base.add(offset) as *const f32, len) }
}
16055
/// Mutable `f32` view of `len` elements starting `offset` bytes past `base`.
///
/// # Safety
/// `base + offset` must be aligned for `f32`, valid for reading and writing
/// `len` elements, and not aliased by any other live reference while the
/// returned slice is used; the `'static` lifetime is not checked.
#[inline(always)]
unsafe fn sl_mut(offset: usize, base: *mut u8, len: usize) -> &'static mut [f32] {
    unsafe {
        let first = base.add(offset) as *mut f32;
        std::slice::from_raw_parts_mut(first, len)
    }
}
16060
/// Read-only `f64` view of `len` elements starting `offset` bytes past `base`.
///
/// An `offset` of `usize::MAX` is treated as "no buffer" and yields an empty
/// slice without touching `base`.
///
/// # Safety
/// Unless `offset == usize::MAX`, `base + offset` must be aligned for `f64`
/// and valid for reading `len` elements for as long as the returned slice is
/// used; the `'static` lifetime is not checked by the compiler.
#[inline(always)]
unsafe fn sl_f64(offset: usize, base: *mut u8, len: usize) -> &'static [f64] {
    if offset != usize::MAX {
        unsafe {
            let first = base.add(offset) as *const f64;
            std::slice::from_raw_parts(first, len)
        }
    } else {
        &[]
    }
}
16068
/// Mutable `f64` view of `len` elements starting `offset` bytes past `base`.
///
/// # Safety
/// `base + offset` must be aligned for `f64`, valid for reading and writing
/// `len` elements, and not aliased by any other live reference while the
/// returned slice is used; the `'static` lifetime is not checked.
#[inline(always)]
unsafe fn sl_mut_f64(offset: usize, base: *mut u8, len: usize) -> &'static mut [f64] {
    unsafe {
        let first = base.add(offset) as *mut f64;
        std::slice::from_raw_parts_mut(first, len)
    }
}
16073
/// Read-only `i32` view of `len` elements starting `offset` bytes past `base`.
///
/// An `offset` of `usize::MAX` is treated as "no buffer" and yields an empty
/// slice without touching `base`.
///
/// # Safety
/// Unless `offset == usize::MAX`, `base + offset` must be aligned for `i32`
/// and valid for reading `len` elements for as long as the returned slice is
/// used; the `'static` lifetime is not checked by the compiler.
#[inline(always)]
#[allow(dead_code)]
unsafe fn sl_i32(offset: usize, base: *mut u8, len: usize) -> &'static [i32] {
    match offset {
        usize::MAX => &[],
        _ => unsafe { std::slice::from_raw_parts(base.add(offset) as *const i32, len) },
    }
}
16086
/// Mutable `i32` view of `len` elements starting `offset` bytes past `base`.
///
/// # Safety
/// `base + offset` must be aligned for `i32`, valid for reading and writing
/// `len` elements, and not aliased by any other live reference while the
/// returned slice is used; the `'static` lifetime is not checked.
#[inline(always)]
#[allow(dead_code)]
unsafe fn sl_mut_i32(offset: usize, base: *mut u8, len: usize) -> &'static mut [i32] {
    unsafe {
        let first = base.add(offset) as *mut i32;
        std::slice::from_raw_parts_mut(first, len)
    }
}
16092
/// Read-only `i64` view of `len` elements starting `offset` bytes past `base`.
///
/// An `offset` of `usize::MAX` is treated as "no buffer" and yields an empty
/// slice without touching `base`.
///
/// # Safety
/// Unless `offset == usize::MAX`, `base + offset` must be aligned for `i64`
/// and valid for reading `len` elements for as long as the returned slice is
/// used; the `'static` lifetime is not checked by the compiler.
#[inline(always)]
unsafe fn sl_i64(offset: usize, base: *mut u8, len: usize) -> &'static [i64] {
    if offset != usize::MAX {
        unsafe {
            let first = base.add(offset) as *const i64;
            std::slice::from_raw_parts(first, len)
        }
    } else {
        &[]
    }
}
16100
/// Mutable `i64` view of `len` elements starting `offset` bytes past `base`.
///
/// # Safety
/// `base + offset` must be aligned for `i64`, valid for reading and writing
/// `len` elements, and not aliased by any other live reference while the
/// returned slice is used; the `'static` lifetime is not checked.
#[inline(always)]
unsafe fn sl_mut_i64(offset: usize, base: *mut u8, len: usize) -> &'static mut [i64] {
    unsafe {
        let first = base.add(offset) as *mut i64;
        std::slice::from_raw_parts_mut(first, len)
    }
}
16105
/// Gathers `inp` into `out` following a permuted layout.
///
/// `out` is filled in row-major order over `out_dims`. For the output
/// coordinate `(i0, i1, …)` the source element is read from
/// `inp[Σ i_d · in_strides[d]]`, so `in_strides[d]` is the input stride (in
/// elements) of the input axis that output axis `d` walks along.
fn transpose_walk_f64(inp: &[f64], out: &mut [f64], out_dims: &[u32], in_strides: &[u32]) {
    let rank = out_dims.len();
    // Current output coordinate, advanced like an odometer.
    let mut coord = vec![0u32; rank];
    for slot in out.iter_mut() {
        let src_off: usize = coord
            .iter()
            .zip(&in_strides[..rank])
            .map(|(&c, &stride)| c as usize * stride as usize)
            .sum();
        *slot = inp[src_off];
        // Increment the last axis; carry into earlier axes on wrap-around.
        for axis in (0..rank).rev() {
            coord[axis] += 1;
            if coord[axis] < out_dims[axis] {
                break;
            }
            coord[axis] = 0;
        }
    }
}
16128
16129fn apply_activation_f64(inp: &[f64], out: &mut [f64], kind: Activation) {
16135 match kind {
16136 Activation::Neg => {
16137 for (o, &v) in out.iter_mut().zip(inp) {
16138 *o = -v;
16139 }
16140 }
16141 Activation::Exp => {
16142 for (o, &v) in out.iter_mut().zip(inp) {
16143 *o = v.exp();
16144 }
16145 }
16146 Activation::Log => {
16147 for (o, &v) in out.iter_mut().zip(inp) {
16148 *o = v.ln();
16149 }
16150 }
16151 Activation::Sqrt => {
16152 for (o, &v) in out.iter_mut().zip(inp) {
16153 *o = v.sqrt();
16154 }
16155 }
16156 Activation::Rsqrt => {
16157 for (o, &v) in out.iter_mut().zip(inp) {
16158 *o = 1.0 / v.sqrt();
16159 }
16160 }
16161 Activation::Abs => {
16162 for (o, &v) in out.iter_mut().zip(inp) {
16163 *o = v.abs();
16164 }
16165 }
16166 Activation::Tanh => {
16167 for (o, &v) in out.iter_mut().zip(inp) {
16168 *o = v.tanh();
16169 }
16170 }
16171 Activation::Sigmoid => {
16172 for (o, &v) in out.iter_mut().zip(inp) {
16173 *o = 1.0 / (1.0 + (-v).exp());
16174 }
16175 }
16176 Activation::Relu => {
16177 for (o, &v) in out.iter_mut().zip(inp) {
16178 *o = v.max(0.0);
16179 }
16180 }
16181 Activation::Round => {
16182 for (o, &v) in out.iter_mut().zip(inp) {
16183 *o = v.round_ties_even();
16184 }
16185 }
16186 Activation::Sin => {
16187 for (o, &v) in out.iter_mut().zip(inp) {
16188 *o = v.sin();
16189 }
16190 }
16191 Activation::Cos => {
16192 for (o, &v) in out.iter_mut().zip(inp) {
16193 *o = v.cos();
16194 }
16195 }
16196 Activation::Tan => {
16197 for (o, &v) in out.iter_mut().zip(inp) {
16198 *o = v.tan();
16199 }
16200 }
16201 Activation::Atan => {
16202 for (o, &v) in out.iter_mut().zip(inp) {
16203 *o = v.atan();
16204 }
16205 }
16206 Activation::Gelu | Activation::GeluApprox | Activation::Silu => {
16207 panic!(
16208 "apply_activation_f64: {kind:?} not yet implemented at f64. \
16209 Add when a workload needs it."
16210 );
16211 }
16212 }
16213}
16214
16215#[inline]
16216fn binary_op_f64(op: BinaryOp, a: f64, b: f64) -> f64 {
16217 match op {
16218 BinaryOp::Add => a + b,
16219 BinaryOp::Sub => a - b,
16220 BinaryOp::Mul => a * b,
16221 BinaryOp::Div => a / b,
16222 BinaryOp::Max => a.max(b),
16223 BinaryOp::Min => a.min(b),
16224 BinaryOp::Pow => a.powf(b),
16225 }
16226}
16227
/// Sums over the middle axis of a row-major `outer × reduced × inner` tensor.
///
/// Writes `out[o·inner + n] = Σ_r inp[o·reduced·inner + r·inner + n]`, adding
/// terms in increasing `r` starting from `0.0`.
fn reduce_sum_f64(inp: &[f64], out: &mut [f64], outer: usize, reduced: usize, inner: usize) {
    for o in 0..outer {
        for n in 0..inner {
            let src_base = o * reduced * inner;
            // Explicit fold from +0.0 keeps the accumulation order fixed.
            out[o * inner + n] =
                (0..reduced).fold(0.0_f64, |acc, r| acc + inp[src_base + r * inner + n]);
        }
    }
}
16241
16242#[cfg(test)]
16243mod tests {
16244 use super::*;
16245 use rlx_ir::*;
16246
    /// A `narrow_` feeding `rope` must compile without any standalone
    /// `Thunk::Narrow`, and the resulting `Thunk::Rope` must carry the row
    /// stride of the un-narrowed `qkv` tensor.
    #[test]
    fn narrow_rope_fuses_in_unfused_path() {
        let f = DType::F32;
        let mut g = Graph::new("nr_fuse");
        // qkv has a last axis of 192; the narrow keeps its first 64 entries.
        let qkv = g.input("qkv", Shape::new(&[16, 8, 192], f));
        let cos = g.input("cos", Shape::new(&[16], f));
        let sin = g.input("sin", Shape::new(&[16], f));
        let q = g.narrow_(qkv, 2, 0, 64);
        let q_rope = g.rope(q, cos, sin, 16);
        g.set_outputs(vec![q_rope]);

        let plan = rlx_opt::memory::plan_memory(&g);
        let arena = crate::arena::Arena::from_plan(plan);
        let sched = compile_thunks(&g, &arena);

        // Count Narrow thunks and remember the stride of the last Rope thunk seen.
        let mut narrow_count = 0;
        let mut rope_with_stride: Option<u32> = None;
        for t in &sched.thunks {
            match t {
                Thunk::Narrow { .. } => narrow_count += 1,
                Thunk::Rope { src_row_stride, .. } => rope_with_stride = Some(*src_row_stride),
                _ => {}
            }
        }
        assert_eq!(
            narrow_count, 0,
            "Narrow→Rope fusion should leave zero Narrow thunks; saw {narrow_count}"
        );
        assert_eq!(
            rope_with_stride,
            Some(192),
            "Rope's src_row_stride should be 192 (parent qkv axis), saw {rope_with_stride:?}"
        );
    }
16290
    /// Compares the compiled `selective_scan` thunk against a scalar
    /// reference recurrence computed inline, element by element, with an
    /// absolute tolerance of `1e-3`.
    #[test]
    fn ssm_selective_scan_matches_reference() {
        use rlx_ir::Philox4x32;
        // Problem sizes: batch, sequence length, channels, state size.
        let bch = 1usize;
        let s = 4usize;
        let h = 3usize;
        let n = 2usize;

        let mut rng = Philox4x32::new(13);
        let mut x = vec![0f32; bch * s * h];
        rng.fill_normal(&mut x);
        let mut delta = vec![0f32; bch * s * h];
        for v in delta.iter_mut() {
            *v = (rng.next_f32() - 0.5) * 0.1;
        }
        let mut a = vec![0f32; h * n];
        for v in a.iter_mut() {
            // Always negative, so exp(delta * a) is a well-behaved factor
            // whenever delta is positive.
            *v = -(rng.next_f32() * 0.5 + 0.1);
        }
        let mut b = vec![0f32; bch * s * n];
        rng.fill_normal(&mut b);
        let mut c = vec![0f32; bch * s * n];
        rng.fill_normal(&mut c);

        // Reference: state <- exp(d*a) * state + d * b * x, y = sum_n c * state.
        let mut expected = vec![0f32; bch * s * h];
        for bi in 0..bch {
            let mut state = vec![0f32; h * n];
            for si in 0..s {
                for ci in 0..h {
                    let d = delta[bi * s * h + si * h + ci];
                    let xv = x[bi * s * h + si * h + ci];
                    let mut acc = 0f32;
                    for ni in 0..n {
                        let da = (d * a[ci * n + ni]).exp();
                        state[ci * n + ni] =
                            da * state[ci * n + ni] + d * b[bi * s * n + si * n + ni] * xv;
                        acc += c[bi * s * n + si * n + ni] * state[ci * n + ni];
                    }
                    expected[bi * s * h + si * h + ci] = acc;
                }
            }
        }

        // Same computation expressed as a graph and compiled to thunks.
        let f = DType::F32;
        let mut g = Graph::new("ssm");
        let xn = g.input("x", Shape::new(&[bch, s, h], f));
        let dn = g.input("delta", Shape::new(&[bch, s, h], f));
        let an = g.param("a", Shape::new(&[h, n], f));
        let bn = g.param("b", Shape::new(&[bch, s, n], f));
        let cn = g.param("c", Shape::new(&[bch, s, n], f));
        let yn = g.selective_scan(xn, dn, an, bn, cn, n, Shape::new(&[bch, s, h], f));
        g.set_outputs(vec![yn]);

        let plan = rlx_opt::memory::plan_memory(&g);
        let mut arena = crate::arena::Arena::from_plan(plan);
        let sched = compile_thunks(&g, &arena);

        // Write the operands into the arena at their planned byte offsets.
        let xn_off = arena.byte_offset(xn);
        let dn_off = arena.byte_offset(dn);
        let an_off = arena.byte_offset(an);
        let bn_off = arena.byte_offset(bn);
        let cn_off = arena.byte_offset(cn);
        let yn_off = arena.byte_offset(yn);
        let buf = arena.raw_buf_mut();
        unsafe {
            let copy = |dst: *mut f32, data: &[f32]| {
                for (i, &v) in data.iter().enumerate() {
                    *dst.add(i) = v;
                }
            };
            copy(buf.as_mut_ptr().add(xn_off) as *mut f32, &x);
            copy(buf.as_mut_ptr().add(dn_off) as *mut f32, &delta);
            copy(buf.as_mut_ptr().add(an_off) as *mut f32, &a);
            copy(buf.as_mut_ptr().add(bn_off) as *mut f32, &b);
            copy(buf.as_mut_ptr().add(cn_off) as *mut f32, &c);
        }
        execute_thunks(&sched, arena.raw_buf_mut());

        // Read the output back from the arena.
        let actual: Vec<f32> = unsafe {
            let p = arena.raw_buf().as_ptr().add(yn_off) as *const f32;
            (0..bch * s * h).map(|i| *p.add(i)).collect()
        };

        for (i, (e, a)) in expected.iter().zip(&actual).enumerate() {
            assert!(
                (e - a).abs() < 1e-3,
                "mismatch at {i}: expected {e}, got {a}"
            );
        }
    }
16386
#[test]
// A 1x1 convolution must compile to `Thunk::Conv2D1x1` (never the generic
// `Thunk::Conv2D`) and reproduce a scalar per-pixel channel contraction.
fn conv_1x1_fast_path_matches_scalar() {
    use rlx_ir::Philox4x32;

    let (n, c_in, h, w, c_out) = (2usize, 4usize, 3usize, 3usize, 5usize);
    let hw = h * w;

    let mut rng = Philox4x32::new(31);
    let mut x = vec![0f32; n * c_in * hw];
    rng.fill_normal(&mut x);
    let mut weight = vec![0f32; c_out * c_in];
    rng.fill_normal(&mut weight);

    // Scalar reference: each output pixel is a dot product over input channels.
    let mut expected = vec![0f32; n * c_out * hw];
    for ni in 0..n {
        for co in 0..c_out {
            for px in 0..hw {
                let mut acc = 0f32;
                for ci in 0..c_in {
                    acc += weight[co * c_in + ci] * x[(ni * c_in + ci) * hw + px];
                }
                expected[(ni * c_out + co) * hw + px] = acc;
            }
        }
    }

    let f = DType::F32;
    let mut g = Graph::new("conv1x1");
    let xn = g.input("x", Shape::new(&[n, c_in, h, w], f));
    let wn = g.param("w", Shape::new(&[c_out, c_in, 1, 1], f));
    let conv_op = rlx_ir::Op::Conv {
        kernel_size: vec![1, 1],
        stride: vec![1, 1],
        padding: vec![0, 0],
        dilation: vec![1, 1],
        groups: 1,
    };
    let cn = g.add_node(conv_op, vec![xn, wn], Shape::new(&[n, c_out, h, w], f));
    g.set_outputs(vec![cn]);

    let plan = rlx_opt::memory::plan_memory(&g);
    let mut arena = crate::arena::Arena::from_plan(plan);
    let sched = compile_thunks(&g, &arena);

    // Scan the schedule once and record which convolution thunks appear.
    let mut saw_fast = false;
    let mut saw_slow = false;
    for t in &sched.thunks {
        match t {
            Thunk::Conv2D1x1 { .. } => saw_fast = true,
            Thunk::Conv2D { .. } => saw_slow = true,
            _ => {}
        }
    }
    assert!(saw_fast, "1×1 conv should emit Conv2D1x1");
    assert!(!saw_slow, "1×1 conv must not fall through to scalar Conv2D");

    let seeds: [(usize, &[f32]); 2] = [
        (arena.byte_offset(xn), x.as_slice()),
        (arena.byte_offset(wn), weight.as_slice()),
    ];
    let cn_off = arena.byte_offset(cn);

    let base = arena.raw_buf_mut().as_mut_ptr();
    for &(off, data) in &seeds {
        // SAFETY: offsets come from the arena itself; assumes each planned
        // buffer is f32-aligned and at least `data.len()` elements long.
        unsafe {
            std::ptr::copy_nonoverlapping(data.as_ptr(), base.add(off) as *mut f32, data.len());
        }
    }
    execute_thunks(&sched, arena.raw_buf_mut());

    // SAFETY: same assumption as above, for the output node's buffer.
    let actual: Vec<f32> = unsafe {
        let p = arena.raw_buf().as_ptr().add(cn_off) as *const f32;
        std::slice::from_raw_parts(p, n * c_out * h * w).to_vec()
    };

    for (i, (e, a)) in expected.iter().zip(actual.iter()).enumerate() {
        let diff = (e - a).abs();
        assert!(diff < 1e-3, "mismatch at {i}: expected {e}, got {a}");
    }
}
16485
#[test]
// Runs a dequantizing matmul (int8 weights, per-block scales, all-zero zero
// points) and compares it with an f32 matmul over host-dequantized weights.
fn dequant_matmul_int8_sym_matches_reference() {
    use rlx_ir::Philox4x32;
    use rlx_ir::quant::QuantScheme;

    let (m, k, n) = (3usize, 8usize, 4usize);
    let block_size = 4usize;
    let blocks_per_col = k / block_size;

    let mut rng = Philox4x32::new(99);
    let mut x = vec![0f32; m * k];
    rng.fill_normal(&mut x);
    // Deterministic int8 weights and strictly positive scales.
    let w_q: Vec<i8> = (0..k * n)
        .map(|i| ((i as i32 * 13 + 7) % 127 - 63) as i8)
        .collect();
    let scales: Vec<f32> = (0..blocks_per_col * n)
        .map(|i| 0.01 + 0.001 * i as f32)
        .collect();

    // Host dequantization: row p of the weight uses the scale of block p / block_size.
    let w_f32: Vec<f32> = (0..k * n)
        .map(|idx| {
            let p = idx / n;
            let j = idx % n;
            let s = scales[(p / block_size) * n + j];
            w_q[idx] as f32 * s
        })
        .collect();

    let mut expected = vec![0f32; m * n];
    for (i, x_row) in x.chunks_exact(k).enumerate() {
        for j in 0..n {
            let mut acc = 0f32;
            for (p, &xv) in x_row.iter().enumerate() {
                acc += xv * w_f32[p * n + j];
            }
            expected[i * n + j] = acc;
        }
    }

    let f = DType::F32;
    let mut g = Graph::new("dq");
    let xn = g.input("x", Shape::new(&[m, k], f));
    let wn = g.param("w", Shape::new(&[k, n], DType::I8));
    let sn = g.param("scale", Shape::new(&[blocks_per_col, n], f));
    let zn = g.param("zp", Shape::new(&[blocks_per_col, n], f));
    let scheme = QuantScheme::Int8Block {
        block_size: block_size as u32,
    };
    let dq = g.dequant_matmul(xn, wn, sn, zn, scheme, Shape::new(&[m, n], f));
    g.set_outputs(vec![dq]);

    let plan = rlx_opt::memory::plan_memory(&g);
    let mut arena = crate::arena::Arena::from_plan(plan);
    let sched = compile_thunks(&g, &arena);

    // The zero-point buffer is explicitly filled with 0.0.
    let zero_points = vec![0f32; blocks_per_col * n];
    let f32_seeds: [(usize, &[f32]); 3] = [
        (arena.byte_offset(xn), x.as_slice()),
        (arena.byte_offset(sn), scales.as_slice()),
        (arena.byte_offset(zn), zero_points.as_slice()),
    ];
    let wn_off = arena.byte_offset(wn);
    let dq_off = arena.byte_offset(dq);

    let base = arena.raw_buf_mut().as_mut_ptr();
    // SAFETY: offsets come from the arena itself; assumes every planned
    // buffer is suitably aligned and large enough for the data written.
    unsafe {
        for &(off, data) in &f32_seeds {
            std::ptr::copy_nonoverlapping(data.as_ptr(), base.add(off) as *mut f32, data.len());
        }
        // The quantized weights are stored as raw i8 values.
        std::ptr::copy_nonoverlapping(w_q.as_ptr(), base.add(wn_off) as *mut i8, w_q.len());
    }
    execute_thunks(&sched, arena.raw_buf_mut());

    // SAFETY: same assumption as above, for the output node's buffer.
    let actual: Vec<f32> = unsafe {
        let p = arena.raw_buf().as_ptr().add(dq_off) as *const f32;
        std::slice::from_raw_parts(p, m * n).to_vec()
    };

    for (i, (e, a)) in expected.iter().zip(actual.iter()).enumerate() {
        let diff = (e - a).abs();
        assert!(diff < 1e-3, "mismatch at {i}: expected {e}, got {a}");
    }
}
16593
#[test]
// The LoRA matmul node must equal the unfused host computation
// x·w + scale · ((x·a)·b).
fn lora_matmul_matches_unfused_reference() {
    use rlx_ir::Philox4x32;

    // Plain row-major matmul: (rows × inner) · (inner × cols).
    fn matmul(lhs: &[f32], rhs: &[f32], rows: usize, inner: usize, cols: usize) -> Vec<f32> {
        let mut out = vec![0f32; rows * cols];
        for i in 0..rows {
            let lhs_row = &lhs[i * inner..(i + 1) * inner];
            for j in 0..cols {
                let mut acc = 0f32;
                for (p, &l) in lhs_row.iter().enumerate() {
                    acc += l * rhs[p * cols + j];
                }
                out[i * cols + j] = acc;
            }
        }
        out
    }

    let (m, k, n, r) = (4usize, 8usize, 6usize, 2usize);
    let scale = 0.5f32;

    // Draw order (x, w, a, b) fixes the test data.
    let mut rng = Philox4x32::new(42);
    let mut x = vec![0f32; m * k];
    rng.fill_normal(&mut x);
    let mut w = vec![0f32; k * n];
    rng.fill_normal(&mut w);
    let mut a = vec![0f32; k * r];
    rng.fill_normal(&mut a);
    let mut b = vec![0f32; r * n];
    rng.fill_normal(&mut b);

    let mut expected = matmul(&x, &w, m, k, n);
    let xa = matmul(&x, &a, m, k, r);
    let xab = matmul(&xa, &b, m, r, n);
    for (e, &delta) in expected.iter_mut().zip(xab.iter()) {
        *e += scale * delta;
    }

    let f = DType::F32;
    let mut g = Graph::new("lora");
    let xn = g.input("x", Shape::new(&[m, k], f));
    let wn = g.param("w", Shape::new(&[k, n], f));
    let an = g.param("a", Shape::new(&[k, r], f));
    let bn = g.param("b", Shape::new(&[r, n], f));
    let lm = g.lora_matmul(xn, wn, an, bn, scale, Shape::new(&[m, n], f));
    g.set_outputs(vec![lm]);

    let plan = rlx_opt::memory::plan_memory(&g);
    let mut arena = crate::arena::Arena::from_plan(plan);
    let sched = compile_thunks(&g, &arena);

    let seeds: [(usize, &[f32]); 4] = [
        (arena.byte_offset(xn), x.as_slice()),
        (arena.byte_offset(wn), w.as_slice()),
        (arena.byte_offset(an), a.as_slice()),
        (arena.byte_offset(bn), b.as_slice()),
    ];
    let lm_off = arena.byte_offset(lm);

    let base = arena.raw_buf_mut().as_mut_ptr();
    for &(off, data) in &seeds {
        // SAFETY: offsets come from the arena itself; assumes each planned
        // buffer is f32-aligned and at least `data.len()` elements long.
        unsafe {
            std::ptr::copy_nonoverlapping(data.as_ptr(), base.add(off) as *mut f32, data.len());
        }
    }
    execute_thunks(&sched, arena.raw_buf_mut());

    // SAFETY: same assumption as above, for the output node's buffer.
    let actual: Vec<f32> = unsafe {
        let p = arena.raw_buf().as_ptr().add(lm_off) as *const f32;
        std::slice::from_raw_parts(p, m * n).to_vec()
    };

    for (i, (e, a)) in expected.iter().zip(actual.iter()).enumerate() {
        let diff = (e - a).abs();
        assert!(diff < 1e-3, "mismatch at {i}: expected {e}, got {a}");
    }
}
16683
#[test]
// With the sampling parameters used here the op is expected to return the
// index of the largest logit (index 5, value 9.0).
fn sample_temperature_zero_is_argmax() {
    let f = DType::F32;
    let mut g = Graph::new("samp");
    let logits = g.input("logits", Shape::new(&[1, 8], f));
    let s = g.sample(logits, 0, 1.0, 1e-3, 42, Shape::new(&[1], f));
    g.set_outputs(vec![s]);

    let plan = rlx_opt::memory::plan_memory(&g);
    let mut arena = crate::arena::Arena::from_plan(plan);
    let sched = compile_thunks(&g, &arena);

    let logits_off = arena.byte_offset(logits);
    let s_off = arena.byte_offset(s);

    let inputs = [0.1f32, 0.2, 0.3, 0.4, 0.5, 9.0, 0.7, 0.8];
    let base = arena.raw_buf_mut().as_mut_ptr();
    // SAFETY: the offset comes from the arena; assumes the logits buffer is
    // f32-aligned and holds at least eight elements.
    unsafe {
        std::ptr::copy_nonoverlapping(
            inputs.as_ptr(),
            base.add(logits_off) as *mut f32,
            inputs.len(),
        );
    }
    execute_thunks(&sched, arena.raw_buf_mut());

    // The sampled token is stored as an f32 in the output buffer.
    // SAFETY: same alignment assumption, for the single-element output.
    let raw = unsafe { (arena.raw_buf().as_ptr().add(s_off) as *const f32).read() };
    let token = raw as usize;
    assert_eq!(token, 5, "low-temp sampling should pick the argmax");
}
16717
#[test]
// With these sampling parameters the op must return index 1, the position of
// the largest logit (5.0).
fn sample_top_k_one_is_deterministic() {
    let f = DType::F32;
    let mut g = Graph::new("samp_k1");
    let logits = g.input("logits", Shape::new(&[1, 4], f));
    let s = g.sample(logits, 1, 1.0, 1.0, 7, Shape::new(&[1], f));
    g.set_outputs(vec![s]);

    let plan = rlx_opt::memory::plan_memory(&g);
    let mut arena = crate::arena::Arena::from_plan(plan);
    let sched = compile_thunks(&g, &arena);

    let logits_off = arena.byte_offset(logits);
    let s_off = arena.byte_offset(s);

    let inputs = [0.1f32, 5.0, 0.3, 0.4];
    let base = arena.raw_buf_mut().as_mut_ptr();
    // SAFETY: the offset comes from the arena; assumes the logits buffer is
    // f32-aligned and holds at least four elements.
    unsafe {
        std::ptr::copy_nonoverlapping(
            inputs.as_ptr(),
            base.add(logits_off) as *mut f32,
            inputs.len(),
        );
    }
    execute_thunks(&sched, arena.raw_buf_mut());

    // SAFETY: same alignment assumption, for the single-element output.
    let raw = unsafe { (arena.raw_buf().as_ptr().add(s_off) as *const f32).read() };
    let token = raw as usize;
    assert_eq!(token, 1);
}
16747
#[test]
// Cumulative sum along the last axis of a 2x4 tensor: each row must become
// its inclusive running total.
fn cumsum_inclusive_matches_naive() {
    let f = DType::F32;
    let mut g = Graph::new("cumsum");
    let x = g.input("x", Shape::new(&[2, 4], f));
    let cs = g.cumsum(x, -1, false, Shape::new(&[2, 4], f));
    g.set_outputs(vec![cs]);

    let plan = rlx_opt::memory::plan_memory(&g);
    let mut arena = crate::arena::Arena::from_plan(plan);
    let sched = compile_thunks(&g, &arena);

    let x_off = arena.byte_offset(x);
    let out_off = arena.byte_offset(cs);

    let inputs = [1.0f32, 2.0, 3.0, 4.0, 10.0, 20.0, 30.0, 40.0];
    let base = arena.raw_buf_mut().as_mut_ptr();
    // SAFETY: the offset comes from the arena; assumes the input buffer is
    // f32-aligned and holds at least eight elements.
    unsafe {
        std::ptr::copy_nonoverlapping(inputs.as_ptr(), base.add(x_off) as *mut f32, inputs.len());
    }
    execute_thunks(&sched, arena.raw_buf_mut());

    // SAFETY: same assumption, for the eight-element output buffer.
    let out: Vec<f32> = unsafe {
        let p = arena.raw_buf().as_ptr().add(out_off) as *const f32;
        std::slice::from_raw_parts(p, 8).to_vec()
    };
    assert_eq!(out, vec![1.0, 3.0, 6.0, 10.0, 10.0, 30.0, 60.0, 100.0]);
}
16779
#[test]
// Three narrows of one packed qkv tensor feeding an attention node must be
// folded away: the schedule should contain no `Narrow` thunk, and the
// `Attention` thunk should carry the parent row stride (192) for q, k and v.
fn narrow_attention_fuses_in_unfused_path() {
    let f = DType::F32;
    let mut g = Graph::new("nattn_fuse");
    let qkv = g.input("qkv", Shape::new(&[8, 16, 192], f));
    let mask = g.input("mask", Shape::new(&[8, 16], f));
    // Presumably (axis, start, length): three 64-wide slices of the last axis.
    let q = g.narrow_(qkv, 2, 0, 64);
    let k = g.narrow_(qkv, 2, 64, 64);
    let v = g.narrow_(qkv, 2, 128, 64);
    let attn = g.attention(q, k, v, mask, 4, 16, Shape::new(&[8, 16, 64], f));
    g.set_outputs(vec![attn]);

    let plan = rlx_opt::memory::plan_memory(&g);
    let arena = crate::arena::Arena::from_plan(plan);
    let sched = compile_thunks(&g, &arena);

    let narrow_count = sched
        .thunks
        .iter()
        .filter(|t| matches!(t, Thunk::Narrow { .. }))
        .count();
    // If several attention thunks exist, the last one wins.
    let attn_strides: Option<(u32, u32, u32)> = sched
        .thunks
        .iter()
        .filter_map(|t| match t {
            Thunk::Attention {
                q_row_stride,
                k_row_stride,
                v_row_stride,
                ..
            } => Some((*q_row_stride, *k_row_stride, *v_row_stride)),
            _ => None,
        })
        .last();

    assert_eq!(
        narrow_count, 0,
        "Narrow×3→Attention fusion should eliminate all 3 narrows; saw {narrow_count}"
    );
    assert_eq!(
        attn_strides,
        Some((192, 192, 192)),
        "Attention should walk Q/K/V with parent row stride 192"
    );
}
16826
16827 fn run_graph(
16838 g: &Graph,
16839 inputs: &[(NodeId, &[f32])],
16840 out_id: NodeId,
16841 out_len: usize,
16842 ) -> Vec<f32> {
16843 let plan = rlx_opt::memory::plan_memory(g);
16844 let mut arena = crate::arena::Arena::from_plan(plan);
16845 let sched = compile_thunks(g, &arena);
16846 for &(id, data) in inputs {
16847 let off = arena.byte_offset(id);
16848 let buf = arena.raw_buf_mut();
16849 unsafe {
16850 let p = buf.as_mut_ptr().add(off) as *mut f32;
16851 for (i, &v) in data.iter().enumerate() {
16852 *p.add(i) = v;
16853 }
16854 }
16855 }
16856 execute_thunks(&sched, arena.raw_buf_mut());
16857 let off = arena.byte_offset(out_id);
16858 unsafe {
16859 let p = arena.raw_buf().as_ptr().add(off) as *const f32;
16860 (0..out_len).map(|i| *p.add(i)).collect()
16861 }
16862 }
16863
#[test]
// ReLU backward must pass `dy` through where `x > 0` and emit 0 elsewhere
// (including at x == 0).
fn relu_backward_matches_mask() {
    let f = DType::F32;
    let len = 7usize;
    let x: Vec<f32> = vec![-2.0, -0.1, 0.0, 0.1, 1.0, 3.0, -5.0];
    let dy: Vec<f32> = vec![0.5, 1.5, 2.5, -0.7, 4.0, -1.0, 9.0];

    let mut g = Graph::new("relu_bw");
    let xn = g.input("x", Shape::new(&[len], f));
    let dyn_ = g.input("dy", Shape::new(&[len], f));
    let dx = g.relu_backward(xn, dyn_);
    g.set_outputs(vec![dx]);

    let actual = run_graph(&g, &[(xn, &x), (dyn_, &dy)], dx, len);

    let mut expected: Vec<f32> = Vec::with_capacity(len);
    for (&xi, &dyi) in x.iter().zip(dy.iter()) {
        expected.push(if xi > 0.0 { dyi } else { 0.0 });
    }

    for (a, e) in actual.iter().zip(expected.iter()) {
        let diff = (a - e).abs();
        assert!(diff < 1e-6, "relu_bw mismatch: {a} vs {e}");
    }
}
16890
#[test]
// 2x2 max-pool backward over a 4x4 ramp: each window's maximum sits in its
// bottom-right corner (flat indices 5, 7, 13, 15), so the gradient must land
// exactly there and be zero elsewhere.
fn maxpool2d_backward_routes_to_argmax() {
    let f = DType::F32;
    // Values 1.0 ..= 16.0 in row-major order.
    let x: Vec<f32> = (1..=16).map(|v| v as f32).collect();
    let dy: Vec<f32> = vec![0.5, 1.0, 2.0, 4.0];

    let mut g = Graph::new("maxpool_bw");
    let xn = g.input("x", Shape::new(&[1, 1, 4, 4], f));
    let dyn_ = g.input("dy", Shape::new(&[1, 1, 2, 2], f));
    let dx = g.maxpool2d_backward(xn, dyn_, vec![2, 2], vec![2, 2], vec![0, 0]);
    g.set_outputs(vec![dx]);

    let actual = run_graph(&g, &[(xn, &x), (dyn_, &dy)], dx, 16);

    let mut expected = vec![0f32; 16];
    for &(slot, grad) in &[(5usize, 0.5f32), (7, 1.0), (13, 2.0), (15, 4.0)] {
        expected[slot] = grad;
    }

    for (i, (a, e)) in actual.iter().zip(expected.iter()).enumerate() {
        let diff = (a - e).abs();
        assert!(diff < 1e-6, "maxpool_bw[{i}] mismatch: {a} vs {e}");
    }
}
16919
#[test]
// Checks the conv2d input-gradient kernel against central finite differences
// of the scalar loss L(x) = <conv(x, w), dy>, using a host-side forward pass.
fn conv2d_backward_input_matches_numerical_gradient() {
    use rlx_ir::Philox4x32;

    let (n, c_in, h, w, c_out) = (1usize, 2usize, 4usize, 4usize, 3usize);
    let (kh, kw) = (3usize, 3usize);
    let (ph, pw) = (1usize, 1usize);
    let (sh, sw) = (1usize, 1usize);
    let h_out = (h + 2 * ph - kh) / sh + 1;
    let w_out = (w + 2 * pw - kw) / sw + 1;
    assert_eq!(h_out, 4);
    assert_eq!(w_out, 4);

    // Draw order (x, wt, dy) fixes the test data.
    let mut rng = Philox4x32::new(7);
    let mut x = vec![0f32; n * c_in * h * w];
    rng.fill_normal(&mut x);
    let mut wt = vec![0f32; c_out * c_in * kh * kw];
    rng.fill_normal(&mut wt);
    let mut dy = vec![0f32; n * c_out * h_out * w_out];
    rng.fill_normal(&mut dy);

    let f = DType::F32;
    let mut g = Graph::new("conv_bwi");
    let dy_in = g.input("dy", Shape::new(&[n, c_out, h_out, w_out], f));
    let w_in = g.input("w", Shape::new(&[c_out, c_in, kh, kw], f));
    let dx = g.conv2d_backward_input(
        dy_in,
        w_in,
        Shape::new(&[n, c_in, h, w], f),
        vec![kh, kw],
        vec![sh, sw],
        vec![ph, pw],
        vec![1, 1],
        1,
    );
    g.set_outputs(vec![dx]);
    let analytical = run_graph(&g, &[(dy_in, &dy), (w_in, &wt)], dx, n * c_in * h * w);

    // Host forward convolution with zero padding; taps that fall outside the
    // input are skipped.
    let forward = |input: &[f32]| -> Vec<f32> {
        let plane = h_out * w_out;
        let mut out = vec![0f32; n * c_out * plane];
        for (flat, slot) in out.iter_mut().enumerate() {
            let wo = flat % w_out;
            let ho = (flat / w_out) % h_out;
            let co = (flat / plane) % c_out;
            let ni = flat / (plane * c_out);
            let mut acc = 0f32;
            for ci in 0..c_in {
                for ki in 0..kh {
                    let Some(hi) = (ho * sh + ki).checked_sub(ph) else {
                        continue;
                    };
                    if hi >= h {
                        continue;
                    }
                    for kj in 0..kw {
                        let Some(wi) = (wo * sw + kj).checked_sub(pw) else {
                            continue;
                        };
                        if wi >= w {
                            continue;
                        }
                        let xv = input[(ni * c_in + ci) * h * w + hi * w + wi];
                        let wv = wt[(co * c_in + ci) * kh * kw + ki * kw + kj];
                        acc += xv * wv;
                    }
                }
            }
            *slot = acc;
        }
        out
    };
    // Scalar loss whose gradient with respect to x is the expected dx.
    let loss = |input: &[f32]| -> f32 {
        forward(input)
            .iter()
            .zip(&dy)
            .map(|(&u, &v)| u * v)
            .sum()
    };

    let eps = 1e-3f32;
    let mut numerical: Vec<f32> = Vec::with_capacity(x.len());
    for i in 0..x.len() {
        let saved = x[i];
        x[i] = saved + eps;
        let plus = loss(&x);
        x[i] = saved - eps;
        let minus = loss(&x);
        x[i] = saved;
        numerical.push((plus - minus) / (2.0 * eps));
    }

    for (i, (a, n)) in analytical.iter().zip(numerical.iter()).enumerate() {
        let diff = (a - n).abs();
        assert!(
            diff < 5e-3,
            "conv_bw_input[{i}]: analytical {a} vs numerical {n}"
        );
    }
}
17024
#[test]
// Checks the conv2d weight-gradient kernel against central finite differences
// of the scalar loss L(w) = <conv(x, w), dy> (no padding, unit stride).
fn conv2d_backward_weight_matches_numerical_gradient() {
    use rlx_ir::Philox4x32;

    let (n, c_in, h, w, c_out) = (2usize, 2usize, 4usize, 4usize, 2usize);
    let (kh, kw) = (3usize, 3usize);
    let (ph, pw) = (0usize, 0usize);
    let (sh, sw) = (1usize, 1usize);
    let h_out = (h + 2 * ph - kh) / sh + 1;
    let w_out = (w + 2 * pw - kw) / sw + 1;

    // Draw order (x, wt, dy) fixes the test data.
    let mut rng = Philox4x32::new(11);
    let mut x = vec![0f32; n * c_in * h * w];
    rng.fill_normal(&mut x);
    let mut wt = vec![0f32; c_out * c_in * kh * kw];
    rng.fill_normal(&mut wt);
    let mut dy = vec![0f32; n * c_out * h_out * w_out];
    rng.fill_normal(&mut dy);

    let f = DType::F32;
    let mut g = Graph::new("conv_bww");
    let xn = g.input("x", Shape::new(&[n, c_in, h, w], f));
    let dyn_ = g.input("dy", Shape::new(&[n, c_out, h_out, w_out], f));
    let dwn = g.conv2d_backward_weight(
        xn,
        dyn_,
        Shape::new(&[c_out, c_in, kh, kw], f),
        vec![kh, kw],
        vec![sh, sw],
        vec![ph, pw],
        vec![1, 1],
        1,
    );
    g.set_outputs(vec![dwn]);
    let analytical = run_graph(&g, &[(xn, &x), (dyn_, &dy)], dwn, c_out * c_in * kh * kw);

    // Host forward convolution; with zero padding and unit stride every tap
    // is in bounds, so no boundary checks are needed.
    let forward = |kernel: &[f32]| -> Vec<f32> {
        let plane = h_out * w_out;
        let mut out = vec![0f32; n * c_out * plane];
        for (flat, slot) in out.iter_mut().enumerate() {
            let wo = flat % w_out;
            let ho = (flat / w_out) % h_out;
            let co = (flat / plane) % c_out;
            let ni = flat / (plane * c_out);
            let mut acc = 0f32;
            for ci in 0..c_in {
                for ki in 0..kh {
                    for kj in 0..kw {
                        let xv = x[(ni * c_in + ci) * h * w + (ho + ki) * w + (wo + kj)];
                        let wv = kernel[(co * c_in + ci) * kh * kw + ki * kw + kj];
                        acc += xv * wv;
                    }
                }
            }
            *slot = acc;
        }
        out
    };
    // Scalar loss whose gradient with respect to the weights is the expected dw.
    let loss = |kernel: &[f32]| -> f32 {
        forward(kernel)
            .iter()
            .zip(&dy)
            .map(|(&u, &v)| u * v)
            .sum()
    };

    let eps = 1e-3f32;
    let mut numerical: Vec<f32> = Vec::with_capacity(wt.len());
    for i in 0..wt.len() {
        let saved = wt[i];
        wt[i] = saved + eps;
        let plus = loss(&wt);
        wt[i] = saved - eps;
        let minus = loss(&wt);
        wt[i] = saved;
        numerical.push((plus - minus) / (2.0 * eps));
    }

    for (i, (a, n)) in analytical.iter().zip(numerical.iter()).enumerate() {
        let diff = (a - n).abs();
        assert!(
            diff < 5e-3,
            "conv_bw_weight[{i}]: analytical {a} vs numerical {n}"
        );
    }
}
17111
#[test]
// Per-row softmax cross-entropy must equal logsumexp(row) - row[label],
// computed on the host with the usual max-subtraction for stability.
fn softmax_cross_entropy_matches_reference() {
    let f = DType::F32;
    // Three rows of three logits each.
    let logits: Vec<f32> = vec![1.0, 2.0, 3.0, -1.0, 0.0, 4.0, 5.0, 5.0, 5.0];
    // Class indices are carried as f32 values.
    let labels: Vec<f32> = vec![2.0, 0.0, 1.0];

    let mut g = Graph::new("sce");
    let lg = g.input("logits", Shape::new(&[3, 3], f));
    let lb = g.input("labels", Shape::new(&[3], f));
    let loss = g.softmax_cross_entropy_with_logits(lg, lb);
    g.set_outputs(vec![loss]);
    let actual = run_graph(&g, &[(lg, &logits), (lb, &labels)], loss, 3);

    let expected: Vec<f32> = logits
        .chunks_exact(3)
        .zip(labels.iter())
        .map(|(row, &label)| {
            let m = row.iter().fold(f32::NEG_INFINITY, |a, &v| a.max(v));
            let sum: f32 = row.iter().map(|&v| (v - m).exp()).sum();
            let lse = m + sum.ln();
            lse - row[label as usize]
        })
        .collect();

    for (i, (a, e)) in actual.iter().zip(expected.iter()).enumerate() {
        let diff = (a - e).abs();
        assert!(diff < 1e-5, "sce loss[{i}]: {a} vs {e}");
    }
}
17143
#[test]
// Checks the softmax-cross-entropy backward kernel against central finite
// differences of the scalar <sce_loss(logits), d_loss>.
fn softmax_cross_entropy_backward_matches_numerical_gradient() {
    use rlx_ir::Philox4x32;

    let (n, c) = (4usize, 5usize);
    // Draw order (logits, d_loss) fixes the test data; labels are deterministic.
    let mut rng = Philox4x32::new(23);
    let mut logits = vec![0f32; n * c];
    rng.fill_normal(&mut logits);
    let labels: Vec<f32> = (0..n).map(|i| (i % c) as f32).collect();
    let mut d_loss = vec![0f32; n];
    rng.fill_normal(&mut d_loss);

    let f = DType::F32;
    let mut g = Graph::new("sce_bw");
    let lg = g.input("logits", Shape::new(&[n, c], f));
    let lb = g.input("labels", Shape::new(&[n], f));
    let dl = g.input("d_loss", Shape::new(&[n], f));
    let dlogits = g.softmax_cross_entropy_backward(lg, lb, dl);
    g.set_outputs(vec![dlogits]);
    let analytical = run_graph(
        &g,
        &[(lg, &logits), (lb, &labels), (dl, &d_loss)],
        dlogits,
        n * c,
    );

    // Host per-row loss: logsumexp(row) - row[label].
    let sce_loss = |z: &[f32]| -> Vec<f32> {
        z.chunks_exact(c)
            .zip(labels.iter())
            .map(|(row, &label)| {
                let m = row.iter().fold(f32::NEG_INFINITY, |a, &v| a.max(v));
                let sum: f32 = row.iter().map(|&v| (v - m).exp()).sum();
                (m + sum.ln()) - row[label as usize]
            })
            .collect()
    };
    // Upstream-weighted scalar objective used for the finite differences.
    let objective = |z: &[f32]| -> f32 {
        sce_loss(z)
            .iter()
            .zip(&d_loss)
            .map(|(&u, &v)| u * v)
            .sum::<f32>()
    };

    let eps = 1e-3f32;
    let mut numerical: Vec<f32> = Vec::with_capacity(logits.len());
    for i in 0..logits.len() {
        let saved = logits[i];
        logits[i] = saved + eps;
        let plus = objective(&logits);
        logits[i] = saved - eps;
        let minus = objective(&logits);
        logits[i] = saved;
        numerical.push((plus - minus) / (2.0 * eps));
    }

    for (i, (a, num)) in analytical.iter().zip(numerical.iter()).enumerate() {
        let diff = (a - num).abs();
        assert!(
            diff < 5e-3,
            "sce_bw[{i}]: analytical {a} vs numerical {num}"
        );
    }
}
17200
17201 fn fill_constants_into_arena(graph: &Graph, arena: &mut crate::arena::Arena) {
17214 for node in graph.nodes() {
17215 if let Op::Constant { data } = &node.op
17216 && arena.has_buffer(node.id)
17217 && !data.is_empty()
17218 {
17219 let buf = arena.slice_mut(node.id);
17220 let n_floats = data.len() / 4;
17221 let n = buf.len().min(n_floats);
17222 for i in 0..n {
17223 let bytes = [
17224 data[i * 4],
17225 data[i * 4 + 1],
17226 data[i * 4 + 2],
17227 data[i * 4 + 3],
17228 ];
17229 buf[i] = f32::from_le_bytes(bytes);
17230 }
17231 }
17232 }
17233 }
17234
17235 fn prepare(
17239 graph: &Graph,
17240 seed_inputs: &[(NodeId, &[f32])],
17241 ) -> (ThunkSchedule, crate::arena::Arena) {
17242 let plan = rlx_opt::memory::plan_memory(graph);
17243 let mut arena = crate::arena::Arena::from_plan(plan);
17244 let sched = compile_thunks(graph, &arena);
17245 fill_constants_into_arena(graph, &mut arena);
17246 for &(id, data) in seed_inputs {
17247 let off = arena.byte_offset(id);
17248 let buf = arena.raw_buf_mut();
17249 unsafe {
17250 let p = buf.as_mut_ptr().add(off) as *mut f32;
17251 for (i, &v) in data.iter().enumerate() {
17252 *p.add(i) = v;
17253 }
17254 }
17255 }
17256 (sched, arena)
17257 }
17258
17259 fn read_arena(arena: &crate::arena::Arena, id: NodeId, len: usize) -> Vec<f32> {
17260 let off = arena.byte_offset(id);
17261 unsafe {
17262 let p = arena.raw_buf().as_ptr().add(off) as *const f32;
17263 (0..len).map(|i| *p.add(i)).collect()
17264 }
17265 }
17266
17267 fn write_arena(arena: &mut crate::arena::Arena, id: NodeId, data: &[f32]) {
17268 let off = arena.byte_offset(id);
17269 let buf = arena.raw_buf_mut();
17270 unsafe {
17271 let p = buf.as_mut_ptr().add(off) as *mut f32;
17272 for (i, &v) in data.iter().enumerate() {
17273 *p.add(i) = v;
17274 }
17275 }
17276 }
17277
17278 fn prepare_f64(
17280 graph: &Graph,
17281 seed_inputs: &[(NodeId, &[f64])],
17282 ) -> (ThunkSchedule, crate::arena::Arena) {
17283 let plan = rlx_opt::memory::plan_memory(graph);
17284 let mut arena = crate::arena::Arena::from_plan(plan);
17285 let sched = compile_thunks(graph, &arena);
17286 fill_constants_into_arena(graph, &mut arena);
17287 for &(id, data) in seed_inputs {
17288 let off = arena.byte_offset(id);
17289 let buf = arena.raw_buf_mut();
17290 unsafe {
17291 let p = buf.as_mut_ptr().add(off) as *mut f64;
17292 for (i, &v) in data.iter().enumerate() {
17293 *p.add(i) = v;
17294 }
17295 }
17296 }
17297 (sched, arena)
17298 }
17299
17300 fn read_arena_f64(arena: &crate::arena::Arena, id: NodeId, len: usize) -> Vec<f64> {
17301 let off = arena.byte_offset(id);
17302 unsafe {
17303 let p = arena.raw_buf().as_ptr().add(off) as *const f64;
17304 (0..len).map(|i| *p.add(i)).collect()
17305 }
17306 }
17307
#[test]
// Solves the 2x2 system [[2, 1], [1, 3]] · x = [5, 10] through the compiled
// graph; the exact solution is x = [1, 3].
fn dense_solve_f64_end_to_end() {
    let mut g = Graph::new("solve_e2e");
    let a = g.input("A", Shape::new(&[2, 2], DType::F64));
    let b = g.input("b", Shape::new(&[2], DType::F64));
    let x = g.dense_solve(a, b, Shape::new(&[2], DType::F64));
    g.set_outputs(vec![x]);

    let a_data = [2.0, 1.0, 1.0, 3.0_f64];
    let b_data = [5.0, 10.0_f64];
    let (sched, mut arena) = prepare_f64(&g, &[(a, &a_data), (b, &b_data)]);
    execute_thunks(&sched, arena.raw_buf_mut());

    let got = read_arena_f64(&arena, x, 2);
    let want = [1.0, 3.0_f64];
    for (i, expected_i) in want.iter().enumerate() {
        let got_i = got[i];
        assert!(
            (got_i - expected_i).abs() < 1e-12,
            "x[{i}] = {} (expected {})",
            got_i,
            expected_i
        );
    }
}
17341
#[test]
// Solves a 5x5 tridiagonal system (2 on the diagonal, -1 on both
// off-diagonals) with b = [1, 2, 3, 4, 5] and verifies the result through the
// residual A·x ≈ b rather than a hard-coded solution.
fn dense_solve_f64_5x5_laplacian() {
    let n = 5usize;
    let mut g = Graph::new("solve_5x5");
    let a = g.input("A", Shape::new(&[n, n], DType::F64));
    let b = g.input("b", Shape::new(&[n], DType::F64));
    let x = g.dense_solve(a, b, Shape::new(&[n], DType::F64));
    g.set_outputs(vec![x]);

    // Build the matrix row by row.
    let mut a_data = vec![0.0_f64; n * n];
    for (i, row) in a_data.chunks_exact_mut(n).enumerate() {
        row[i] = 2.0;
        if i > 0 {
            row[i - 1] = -1.0;
        }
        if i + 1 < n {
            row[i + 1] = -1.0;
        }
    }
    let b_data: Vec<f64> = (0..n).map(|i| (i + 1) as f64).collect();

    let (sched, mut arena) = prepare_f64(&g, &[(a, &a_data), (b, &b_data)]);
    execute_thunks(&sched, arena.raw_buf_mut());
    let got = read_arena_f64(&arena, x, n);

    // residual[i] = (A · got)[i], accumulated left to right.
    let residual: Vec<f64> = (0..n)
        .map(|i| {
            let mut acc = 0.0_f64;
            for j in 0..n {
                acc += a_data[i * n + j] * got[j];
            }
            acc
        })
        .collect();

    for (i, (r, rhs)) in residual.iter().zip(b_data.iter()).enumerate() {
        assert!(
            (r - rhs).abs() < 1e-10,
            "row {i}: residual {} vs b {}",
            r,
            rhs
        );
    }
}
17388
#[test]
// End-to-end autodiff check for loss = sum(solve(A, b)) on a 3x3 system.
// The backward graph's outputs [loss, dA, db] are compared against host
// references built with `crate::blas::dgesv`:
//   x  = solve(A, b),  db = solve(Aᵀ, 1),  dA[i][j] = -db[i] * x[j],
// and db is additionally cross-checked with central finite differences.
fn hello_resistor_gradient_end_to_end() {
    use rlx_opt::autodiff::grad_with_loss;
    let n = 3usize;

    let mut g = Graph::new("hello_resistor");
    let a = g.param("A", Shape::new(&[n, n], DType::F64));
    let b = g.input("b", Shape::new(&[n], DType::F64));
    let x = g.dense_solve(a, b, Shape::new(&[n], DType::F64));
    let loss = g.reduce(
        x,
        ReduceOp::Sum,
        vec![0],
        false,
        Shape::new(&[1], DType::F64),
    );
    g.set_outputs(vec![loss]);

    let bwd = grad_with_loss(&g, &[a, b]);
    assert_eq!(bwd.outputs.len(), 3, "expect [loss, dA, db]");

    // Looks up the first Input/Param node with the given name; node ids in
    // the backward graph are resolved by name rather than reused from `g`.
    let find_by_name = |graph: &Graph, want: &str| -> NodeId {
        graph
            .nodes()
            .into_iter()
            .find_map(|node| {
                let name = match &node.op {
                    rlx_ir::Op::Input { name } => Some(name.as_str()),
                    rlx_ir::Op::Param { name } => Some(name.as_str()),
                    _ => None,
                };
                (name == Some(want)).then_some(node.id)
            })
            .unwrap_or_else(|| panic!("no node named {want:?} in bwd graph"))
    };
    let a_bwd = find_by_name(&bwd, "A");
    let b_bwd = find_by_name(&bwd, "b");
    let d_out_bwd = find_by_name(&bwd, "d_output");

    let a_data = [2.0, 1.0, 0.0, 1.0, 3.0, 1.0, 0.0, 1.0, 2.0_f64];
    let b_data = [1.0, 2.0, 3.0_f64];
    // Upstream gradient of the scalar loss.
    let d_output = [1.0_f64];
    let (sched, mut arena) = prepare_f64(
        &bwd,
        &[(a_bwd, &a_data), (b_bwd, &b_data), (d_out_bwd, &d_output)],
    );
    execute_thunks(&sched, arena.raw_buf_mut());

    let loss_out = read_arena_f64(&arena, bwd.outputs[0], 1);
    let da_out = read_arena_f64(&arena, bwd.outputs[1], n * n);
    let db_out = read_arena_f64(&arena, bwd.outputs[2], n);

    // Host reference: x = solve(A, b). dgesv overwrites the right-hand side
    // with the solution.
    let x_ref = {
        let mut lhs = a_data;
        let mut rhs = b_data;
        let info = crate::blas::dgesv(&mut lhs, &mut rhs, n, 1);
        assert_eq!(info, 0);
        rhs
    };
    let loss_ref: f64 = x_ref.iter().sum();

    // Host reference: db = solve(Aᵀ, ones).
    let db_ref = {
        let mut at = [0.0_f64; 9];
        for (idx, slot) in at.iter_mut().enumerate() {
            let (i, j) = (idx / n, idx % n);
            *slot = a_data[j * n + i];
        }
        let mut ones = [1.0_f64; 3];
        let info = crate::blas::dgesv(&mut at, &mut ones, n, 1);
        assert_eq!(info, 0);
        ones
    };

    // Host reference: dA is the negated outer product of db and x.
    let mut da_ref = [0.0_f64; 9];
    for (idx, slot) in da_ref.iter_mut().enumerate() {
        let (i, j) = (idx / n, idx % n);
        *slot = -db_ref[i] * x_ref[j];
    }

    assert!(
        (loss_out[0] - loss_ref).abs() < 1e-10,
        "loss: got {}, want {}",
        loss_out[0],
        loss_ref
    );
    for (i, (got, want)) in db_out.iter().zip(db_ref.iter()).enumerate() {
        assert!(
            (got - want).abs() < 1e-10,
            "db[{i}]: got {}, want {}",
            got,
            want
        );
    }
    for (i, (got, want)) in da_out.iter().zip(da_ref.iter()).enumerate() {
        assert!(
            (got - want).abs() < 1e-10,
            "dA[{i}]: got {}, want {}",
            got,
            want
        );
    }

    // Finite-difference cross-check of db: perturb one entry of b at a time
    // and re-solve on the host.
    let loss_at = |rhs: [f64; 3]| -> f64 {
        let mut ac = a_data;
        let mut rhs = rhs;
        let info = crate::blas::dgesv(&mut ac, &mut rhs, n, 1);
        assert_eq!(info, 0);
        rhs.iter().sum::<f64>()
    };
    let h = 1e-6_f64;
    for k in 0..n {
        let mut bp = b_data;
        bp[k] += h;
        let mut bm = b_data;
        bm[k] -= h;
        let lp = loss_at(bp);
        let lm = loss_at(bm);
        let fd = (lp - lm) / (2.0 * h);
        assert!(
            (db_out[k] - fd).abs() < 1e-7,
            "FD mismatch on db[{k}]: AD={} FD={}",
            db_out[k],
            fd
        );
    }
}
17552
17553 #[test]
17558 fn scan_geometric_growth_f64() {
17559 let n = 3usize;
17560 let length = 10u32;
17561
17562 let mut body = Graph::new("scan_body");
17564 let x = body.input("carry", Shape::new(&[n], DType::F64));
17565 let scale_bytes: Vec<u8> = (0..n).flat_map(|_| 0.1_f64.to_le_bytes()).collect();
17566 let scale = body.add_node(
17567 Op::Constant { data: scale_bytes },
17568 vec![],
17569 Shape::new(&[n], DType::F64),
17570 );
17571 let scaled = body.binary(BinaryOp::Mul, x, scale, Shape::new(&[n], DType::F64));
17572 let next = body.binary(BinaryOp::Add, x, scaled, Shape::new(&[n], DType::F64));
17573 body.set_outputs(vec![next]);
17574
17575 let mut g = Graph::new("scan_outer");
17577 let init = g.input("init", Shape::new(&[n], DType::F64));
17578 let final_carry = g.scan(init, body, length);
17579 g.set_outputs(vec![final_carry]);
17580
17581 let init_data = vec![1.0_f64; n];
17582 let (sched, mut arena) = prepare_f64(&g, &[(init, &init_data)]);
17583 execute_thunks(&sched, arena.raw_buf_mut());
17584 let got = read_arena_f64(&arena, final_carry, n);
17585 let want: f64 = 1.1_f64.powi(length as i32);
17586 for i in 0..n {
17587 assert!(
17588 (got[i] - want).abs() < 1e-12,
17589 "got[{i}] = {} want {}",
17590 got[i],
17591 want
17592 );
17593 }
17594 }
17595
17596 #[test]
17603 fn scan_with_xs_cumulative_sum() {
17604 let n = 3usize;
17605 let length = 4u32;
17606
17607 let mut body = Graph::new("cumsum_body");
17608 let carry = body.input("carry", Shape::new(&[n], DType::F64));
17610 let x_t = body.input("x_t", Shape::new(&[n], DType::F64));
17611 let next = body.binary(BinaryOp::Add, carry, x_t, Shape::new(&[n], DType::F64));
17612 body.set_outputs(vec![next]);
17613
17614 let mut g = Graph::new("cumsum_outer");
17615 let init = g.input("init", Shape::new(&[n], DType::F64));
17616 let xs = g.input("xs", Shape::new(&[length as usize, n], DType::F64));
17617 let final_carry = g.scan_with_xs(init, &[xs], body, length);
17618 g.set_outputs(vec![final_carry]);
17619
17620 let init_data = vec![0.0_f64; n];
17621 let xs_data: Vec<f64> = (0..length as usize * n).map(|i| (i + 1) as f64).collect(); let (sched, mut arena) = prepare_f64(&g, &[(init, &init_data), (xs, &xs_data)]);
17623 execute_thunks(&sched, arena.raw_buf_mut());
17624 let got = read_arena_f64(&arena, final_carry, n);
17625
17626 let mut want = init_data.clone();
17630 for t in 0..length as usize {
17631 for j in 0..n {
17632 want[j] += xs_data[t * n + j];
17633 }
17634 }
17635 for i in 0..n {
17636 assert!(
17637 (got[i] - want[i]).abs() < 1e-12,
17638 "got[{i}] = {} want {}",
17639 got[i],
17640 want[i]
17641 );
17642 }
17643 }
17644
17645 #[test]
17649 fn scan_with_xs_be_with_drive() {
17650 let n = 3usize;
17651 let length = 4u32;
17652 let dt = 0.1_f64;
17653
17654 let mut m_data = vec![0.0_f64; n * n];
17655 for i in 0..n {
17656 m_data[i * n + i] = 1.0 + dt * 2.0;
17657 if i > 0 {
17658 m_data[i * n + (i - 1)] = -dt;
17659 }
17660 if i + 1 < n {
17661 m_data[i * n + (i + 1)] = -dt;
17662 }
17663 }
17664 let m_bytes: Vec<u8> = m_data.iter().flat_map(|x| x.to_le_bytes()).collect();
17665
17666 let mut body = Graph::new("be_drive_body");
17667 let carry = body.input("carry", Shape::new(&[n], DType::F64));
17668 let drive = body.input("drive", Shape::new(&[n], DType::F64));
17669 let m = body.add_node(
17670 Op::Constant { data: m_bytes },
17671 vec![],
17672 Shape::new(&[n, n], DType::F64),
17673 );
17674 let driven = body.binary(BinaryOp::Add, carry, drive, Shape::new(&[n], DType::F64));
17675 let next = body.dense_solve(m, driven, Shape::new(&[n], DType::F64));
17676 body.set_outputs(vec![next]);
17677
17678 let mut g = Graph::new("be_drive_outer");
17679 let init = g.input("init", Shape::new(&[n], DType::F64));
17680 let xs = g.input("xs", Shape::new(&[length as usize, n], DType::F64));
17681 let final_carry = g.scan_with_xs(init, &[xs], body, length);
17682 g.set_outputs(vec![final_carry]);
17683
17684 let init_data = vec![0.0_f64; n];
17685 let mut xs_data = vec![0.0_f64; length as usize * n];
17688 xs_data[0] = 1.0;
17689
17690 let (sched, mut arena) = prepare_f64(&g, &[(init, &init_data), (xs, &xs_data)]);
17691 execute_thunks(&sched, arena.raw_buf_mut());
17692 let got = read_arena_f64(&arena, final_carry, n);
17693
17694 let mut x = init_data.clone();
17696 for t in 0..length as usize {
17697 for j in 0..n {
17698 x[j] += xs_data[t * n + j];
17699 }
17700 let mut a_copy = m_data.clone();
17701 crate::blas::dgesv(&mut a_copy, &mut x, n, 1);
17702 }
17703 for i in 0..n {
17704 assert!(
17705 (got[i] - x[i]).abs() < 1e-12,
17706 "got[{i}] = {} ref {}",
17707 got[i],
17708 x[i]
17709 );
17710 }
17711 }
17712
17713 #[test]
17719 fn batched_dense_solve_gradient_matches_per_batch_analytic() {
17720 use rlx_opt::autodiff::grad_with_loss;
17721 let n = 3usize;
17722 let batch = 4usize;
17723
17724 let mut g = Graph::new("bds_grad");
17725 let a = g.param("A", Shape::new(&[batch, n, n], DType::F64));
17726 let b = g.input("b", Shape::new(&[batch, n], DType::F64));
17727 let x = g.batched_dense_solve(a, b, Shape::new(&[batch, n], DType::F64));
17728 let loss = g.reduce(
17729 x,
17730 ReduceOp::Sum,
17731 vec![0, 1],
17732 false,
17733 Shape::new(&[1], DType::F64),
17734 );
17735 g.set_outputs(vec![loss]);
17736
17737 let bwd = grad_with_loss(&g, &[a, b]);
17738
17739 let find = |graph: &Graph, want: &str| -> NodeId {
17740 for node in graph.nodes() {
17741 let name = match &node.op {
17742 Op::Input { name } | Op::Param { name } => Some(name.as_str()),
17743 _ => None,
17744 };
17745 if name == Some(want) {
17746 return node.id;
17747 }
17748 }
17749 panic!("no node named {want}");
17750 };
17751 let a_id = find(&bwd, "A");
17752 let b_id = find(&bwd, "b");
17753 let d_out_id = find(&bwd, "d_output");
17754
17755 let mut rng = rlx_ir::Philox4x32::new(0x57e1_u64);
17756 let mut a_data = vec![0.0_f64; batch * n * n];
17757 let mut b_data = vec![0.0_f64; batch * n];
17758 for bi in 0..batch {
17759 for i in 0..n {
17760 for j in 0..n {
17761 a_data[bi * n * n + i * n + j] = rng.next_f32() as f64 * 0.1;
17762 }
17763 a_data[bi * n * n + i * n + i] += 1.0 + n as f64;
17764 }
17765 for i in 0..n {
17766 b_data[bi * n + i] = rng.next_f32() as f64;
17767 }
17768 }
17769 let d_seed = [1.0_f64];
17770
17771 let (sched, mut arena) = prepare_f64(
17772 &bwd,
17773 &[(a_id, &a_data), (b_id, &b_data), (d_out_id, &d_seed)],
17774 );
17775 execute_thunks(&sched, arena.raw_buf_mut());
17776 let da_out = read_arena_f64(&arena, bwd.outputs[1], batch * n * n);
17777 let db_out = read_arena_f64(&arena, bwd.outputs[2], batch * n);
17778
17779 for bi in 0..batch {
17782 let a_slice: Vec<f64> = a_data[bi * n * n..(bi + 1) * n * n].to_vec();
17783 let mut b_slice: Vec<f64> = b_data[bi * n..(bi + 1) * n].to_vec();
17784 let mut a_copy = a_slice.clone();
17785 crate::blas::dgesv(&mut a_copy, &mut b_slice, n, 1);
17786 let x_ref = b_slice.clone();
17787 let mut at = vec![0.0_f64; n * n];
17789 for i in 0..n {
17790 for j in 0..n {
17791 at[i * n + j] = a_slice[j * n + i];
17792 }
17793 }
17794 let mut ones = vec![1.0_f64; n];
17795 crate::blas::dgesv(&mut at, &mut ones, n, 1);
17796 let db_ref = ones;
17797 for i in 0..n {
17798 let got = db_out[bi * n + i];
17799 assert!(
17800 (got - db_ref[i]).abs() < 1e-10,
17801 "batch {bi}, db[{i}]: got {got} ref {}",
17802 db_ref[i]
17803 );
17804 }
17805 for i in 0..n {
17807 for j in 0..n {
17808 let got = da_out[bi * n * n + i * n + j];
17809 let want = -db_ref[i] * x_ref[j];
17810 assert!(
17811 (got - want).abs() < 1e-10,
17812 "batch {bi}, dA[{i},{j}]: got {got} ref {want}"
17813 );
17814 }
17815 }
17816 }
17817 }
17818
17819 #[test]
17824 fn scan_checkpointed_grad_matches_plain_scan_grad() {
17825 use rlx_opt::autodiff::grad_with_loss;
17826 let n = 2usize;
17827 let length = 6u32;
17828
17829 let make_body = || {
17830 let mut body = Graph::new("ck_body");
17831 let carry = body.input("carry", Shape::new(&[n], DType::F64));
17832 let scale_bytes: Vec<u8> = (0..n).flat_map(|_| 1.05_f64.to_le_bytes()).collect();
17833 let scale = body.add_node(
17834 Op::Constant { data: scale_bytes },
17835 vec![],
17836 Shape::new(&[n], DType::F64),
17837 );
17838 let next = body.binary(BinaryOp::Mul, carry, scale, Shape::new(&[n], DType::F64));
17839 body.set_outputs(vec![next]);
17840 body
17841 };
17842
17843 let mut g_plain = Graph::new("ck_plain");
17845 let init_p = g_plain.input("init", Shape::new(&[n], DType::F64));
17846 let final_p = g_plain.scan(init_p, make_body(), length);
17847 let loss_p = g_plain.reduce(
17848 final_p,
17849 ReduceOp::Sum,
17850 vec![0],
17851 false,
17852 Shape::new(&[1], DType::F64),
17853 );
17854 g_plain.set_outputs(vec![loss_p]);
17855 let bwd_p = grad_with_loss(&g_plain, &[init_p]);
17856
17857 let mut g_ck = Graph::new("ck_ckpt");
17859 let init_c = g_ck.input("init", Shape::new(&[n], DType::F64));
17860 let final_c = g_ck.scan_checkpointed(init_c, make_body(), length, 2);
17861 let loss_c = g_ck.reduce(
17862 final_c,
17863 ReduceOp::Sum,
17864 vec![0],
17865 false,
17866 Shape::new(&[1], DType::F64),
17867 );
17868 g_ck.set_outputs(vec![loss_c]);
17869 let bwd_c = grad_with_loss(&g_ck, &[init_c]);
17870
17871 let find = |graph: &Graph, want: &str| -> NodeId {
17872 for node in graph.nodes() {
17873 let name = match &node.op {
17874 Op::Input { name } | Op::Param { name } => Some(name.as_str()),
17875 _ => None,
17876 };
17877 if name == Some(want) {
17878 return node.id;
17879 }
17880 }
17881 panic!("no {want}");
17882 };
17883
17884 let init_data = vec![0.5_f64, -0.5];
17885 let d_seed = [1.0_f64];
17886
17887 let (s_p, mut a_p) = prepare_f64(
17888 &bwd_p,
17889 &[
17890 (find(&bwd_p, "init"), &init_data),
17891 (find(&bwd_p, "d_output"), &d_seed),
17892 ],
17893 );
17894 execute_thunks(&s_p, a_p.raw_buf_mut());
17895 let dinit_p = read_arena_f64(&a_p, bwd_p.outputs[1], n);
17896
17897 let (s_c, mut a_c) = prepare_f64(
17898 &bwd_c,
17899 &[
17900 (find(&bwd_c, "init"), &init_data),
17901 (find(&bwd_c, "d_output"), &d_seed),
17902 ],
17903 );
17904 execute_thunks(&s_c, a_c.raw_buf_mut());
17905 let dinit_c = read_arena_f64(&a_c, bwd_c.outputs[1], n);
17906
17907 for i in 0..n {
17908 assert!(
17909 (dinit_p[i] - dinit_c[i]).abs() < 1e-12,
17910 "dinit[{i}]: plain={} checkpointed={}",
17911 dinit_p[i],
17912 dinit_c[i]
17913 );
17914 }
17915 }
17916
17917 #[test]
17923 fn recursive_checkpointing_matches_full_trajectory() {
17924 let n = 2usize;
17925 let length = 4u32;
17926
17927 let build_body = || -> Graph {
17929 let mut body = Graph::new("rc_body");
17930 let carry = body.input("carry", Shape::new(&[n], DType::F64));
17931 let ones_bytes: Vec<u8> = (0..n).flat_map(|_| 1.0_f64.to_le_bytes()).collect();
17932 let ones = body.add_node(
17933 Op::Constant { data: ones_bytes },
17934 vec![],
17935 Shape::new(&[n], DType::F64),
17936 );
17937 let next = body.binary(BinaryOp::Add, carry, ones, Shape::new(&[n], DType::F64));
17938 body.set_outputs(vec![next]);
17939 body
17940 };
17941
17942 let body_vjp_for = || -> Graph {
17945 use rlx_opt::autodiff::grad;
17946 let body = build_body();
17947 let carry_id = body
17949 .nodes()
17950 .iter()
17951 .find(|n| matches!(n.op, Op::Input { .. }))
17952 .map(|n| n.id)
17953 .unwrap();
17954 grad(&body, &[carry_id])
17955 };
17956
17957 let mut g_full = Graph::new("rc_outer_full");
17959 let init_full = g_full.input("init", Shape::new(&[n], DType::F64));
17960 let traj_full_id = g_full.scan_trajectory(init_full, build_body(), length);
17961 let upstream_full = g_full.input("upstream", Shape::new(&[length as usize, n], DType::F64));
17963 let dinit_full_id = g_full.scan_backward(
17964 init_full,
17965 traj_full_id,
17966 upstream_full,
17967 &[],
17968 body_vjp_for(),
17969 length,
17970 true,
17971 Shape::new(&[n], DType::F64),
17972 );
17973 g_full.set_outputs(vec![dinit_full_id]);
17974
17975 let k = 2u32;
17978 let mut g_rec = Graph::new("rc_outer_rec");
17979 let init_rec = g_rec.input("init", Shape::new(&[n], DType::F64));
17980 let traj_rec_id = g_rec.add_node(
17981 Op::Scan {
17982 body: Box::new(build_body()),
17983 length,
17984 save_trajectory: true,
17985 num_bcast: 0,
17986 num_xs: 0,
17987 num_checkpoints: k,
17988 },
17989 vec![init_rec],
17990 Shape::new(&[k as usize, n], DType::F64),
17991 );
17992 let upstream_rec = g_rec.input("upstream", Shape::new(&[length as usize, n], DType::F64));
17995 let dinit_rec_id = g_rec.add_node(
17996 Op::ScanBackward {
17997 body_vjp: Box::new(body_vjp_for()),
17998 length,
17999 save_trajectory: true,
18000 num_xs: 0,
18001 num_checkpoints: k,
18002 forward_body: Some(Box::new(build_body())),
18003 },
18004 vec![init_rec, traj_rec_id, upstream_rec],
18005 Shape::new(&[n], DType::F64),
18006 );
18007 g_rec.set_outputs(vec![dinit_rec_id]);
18008
18009 let init_data = vec![0.5_f64, -0.5];
18011 let upstream_data: Vec<f64> = (0..length as usize * n).map(|i| (i as f64) * 0.1).collect();
18012
18013 let find = |graph: &Graph, want: &str| -> NodeId {
18014 for node in graph.nodes() {
18015 if let Op::Input { name } = &node.op
18016 && name == want
18017 {
18018 return node.id;
18019 }
18020 }
18021 panic!("no input {want}");
18022 };
18023
18024 let (s_full, mut a_full) = prepare_f64(
18025 &g_full,
18026 &[
18027 (find(&g_full, "init"), &init_data),
18028 (find(&g_full, "upstream"), &upstream_data),
18029 ],
18030 );
18031 execute_thunks(&s_full, a_full.raw_buf_mut());
18032 let dinit_full = read_arena_f64(&a_full, g_full.outputs[0], n);
18033
18034 let (s_rec, mut a_rec) = prepare_f64(
18035 &g_rec,
18036 &[
18037 (find(&g_rec, "init"), &init_data),
18038 (find(&g_rec, "upstream"), &upstream_data),
18039 ],
18040 );
18041 execute_thunks(&s_rec, a_rec.raw_buf_mut());
18042 let dinit_rec = read_arena_f64(&a_rec, g_rec.outputs[0], n);
18043
18044 for i in 0..n {
18045 assert!(
18046 (dinit_full[i] - dinit_rec[i]).abs() < 1e-12,
18047 "i={i}: full={} rec={}",
18048 dinit_full[i],
18049 dinit_rec[i]
18050 );
18051 }
18052 }
18053
18054 #[test]
18063 fn vmap_of_grad_scan_matches_per_row_runs() {
18064 use rlx_opt::autodiff::grad_with_loss;
18065 use rlx_opt::vmap::vmap;
18066 let n = 2usize;
18067 let length = 3u32;
18068 let batch = 3usize;
18069
18070 let mut body = Graph::new("scan_grad_body");
18071 let carry = body.input("carry", Shape::new(&[n], DType::F64));
18072 let ones_bytes: Vec<u8> = (0..n).flat_map(|_| 1.0_f64.to_le_bytes()).collect();
18073 let ones = body.add_node(
18074 Op::Constant { data: ones_bytes },
18075 vec![],
18076 Shape::new(&[n], DType::F64),
18077 );
18078 let next = body.binary(BinaryOp::Add, carry, ones, Shape::new(&[n], DType::F64));
18079 body.set_outputs(vec![next]);
18080
18081 let mut g = Graph::new("scan_grad_outer");
18082 let init = g.input("init", Shape::new(&[n], DType::F64));
18083 let final_x = g.scan(init, body, length);
18084 let loss = g.reduce(
18085 final_x,
18086 ReduceOp::Sum,
18087 vec![0],
18088 false,
18089 Shape::new(&[1], DType::F64),
18090 );
18091 g.set_outputs(vec![loss]);
18092
18093 let bwd = grad_with_loss(&g, &[init]);
18094 let bg = vmap(&bwd, &["init"], batch);
18095
18096 let find = |graph: &Graph, want: &str| -> NodeId {
18097 for node in graph.nodes() {
18098 let name = match &node.op {
18099 Op::Input { name } | Op::Param { name } => Some(name.as_str()),
18100 _ => None,
18101 };
18102 if name == Some(want) {
18103 return node.id;
18104 }
18105 }
18106 panic!("no node named {want}");
18107 };
18108 let init_b = find(&bg, "init");
18109 let d_out_b = find(&bg, "d_output");
18110
18111 let init_data: Vec<f64> = (0..batch * n).map(|i| (i as f64) * 0.5).collect();
18112 let d_seed = [1.0_f64];
18113
18114 let (sched, mut arena) = prepare_f64(&bg, &[(init_b, &init_data), (d_out_b, &d_seed)]);
18115 execute_thunks(&sched, arena.raw_buf_mut());
18116 let dinit_b = read_arena_f64(&arena, bg.outputs[1], batch * n);
18117
18118 for i in 0..batch * n {
18119 assert!(
18120 (dinit_b[i] - 1.0).abs() < 1e-12,
18121 "dinit[{i}] = {} (expected 1.0)",
18122 dinit_b[i]
18123 );
18124 }
18125
18126 for bi in 0..batch {
18128 let row = &init_data[bi * n..(bi + 1) * n];
18129 let mut g2 = Graph::new("per_row_grad");
18130 let init2 = g2.input("init", Shape::new(&[n], DType::F64));
18131 let mut body2 = Graph::new("per_row_body");
18132 let c2 = body2.input("carry", Shape::new(&[n], DType::F64));
18133 let ones2_bytes: Vec<u8> = (0..n).flat_map(|_| 1.0_f64.to_le_bytes()).collect();
18134 let ones2 = body2.add_node(
18135 Op::Constant { data: ones2_bytes },
18136 vec![],
18137 Shape::new(&[n], DType::F64),
18138 );
18139 let next2 = body2.binary(BinaryOp::Add, c2, ones2, Shape::new(&[n], DType::F64));
18140 body2.set_outputs(vec![next2]);
18141 let final2 = g2.scan(init2, body2, length);
18142 let loss2 = g2.reduce(
18143 final2,
18144 ReduceOp::Sum,
18145 vec![0],
18146 false,
18147 Shape::new(&[1], DType::F64),
18148 );
18149 g2.set_outputs(vec![loss2]);
18150 let bwd2 = grad_with_loss(&g2, &[init2]);
18151 let init2_id = find(&bwd2, "init");
18152 let d_out2_id = find(&bwd2, "d_output");
18153 let (s2, mut a2) = prepare_f64(&bwd2, &[(init2_id, row), (d_out2_id, &d_seed)]);
18154 execute_thunks(&s2, a2.raw_buf_mut());
18155 let row_dinit = read_arena_f64(&a2, bwd2.outputs[1], n);
18156 for j in 0..n {
18157 let got = dinit_b[bi * n + j];
18158 let want = row_dinit[j];
18159 assert!(
18160 (got - want).abs() < 1e-12,
18161 "row {bi}, j {j}: vmap'd={got} per-row={want}"
18162 );
18163 }
18164 }
18165 }
18166
18167 #[test]
18173 fn vmap_scan_cumulative_sum_matches_scalar_runs() {
18174 use rlx_opt::vmap::vmap;
18175 let n = 2usize;
18176 let length = 4u32;
18177 let batch = 3usize;
18178
18179 let mut body = Graph::new("scan_body_cumsum");
18181 let carry = body.input("carry", Shape::new(&[n], DType::F64));
18182 let x_t = body.input("x_t", Shape::new(&[n], DType::F64));
18183 let next = body.binary(BinaryOp::Add, carry, x_t, Shape::new(&[n], DType::F64));
18184 body.set_outputs(vec![next]);
18185
18186 let mut g = Graph::new("scan_outer_cumsum");
18187 let init = g.input("init", Shape::new(&[n], DType::F64));
18188 let xs = g.input("xs", Shape::new(&[length as usize, n], DType::F64));
18189 let final_carry = g.scan_with_xs(init, &[xs], body, length);
18190 g.set_outputs(vec![final_carry]);
18191
18192 let bg = vmap(&g, &["init", "xs"], batch);
18194
18195 let init_data: Vec<f64> = (0..batch * n).map(|i| (i + 1) as f64).collect();
18197 let xs_data: Vec<f64> = (0..batch * length as usize * n)
18200 .map(|i| 0.1 * (i as f64))
18201 .collect();
18202
18203 let find = |graph: &Graph, want: &str| -> NodeId {
18204 for node in graph.nodes() {
18205 if let Op::Input { name } = &node.op
18206 && name == want
18207 {
18208 return node.id;
18209 }
18210 }
18211 panic!("no input {want}");
18212 };
18213 let init_b = find(&bg, "init");
18214 let xs_b = find(&bg, "xs");
18215 let (sched, mut arena) = prepare_f64(&bg, &[(init_b, &init_data), (xs_b, &xs_data)]);
18216 execute_thunks(&sched, arena.raw_buf_mut());
18217 let batched_out = read_arena_f64(&arena, bg.outputs[0], batch * n);
18218
18219 for bi in 0..batch {
18221 let init_slice = &init_data[bi * n..(bi + 1) * n];
18222 let mut x = init_slice.to_vec();
18223 for t in 0..length as usize {
18224 for j in 0..n {
18225 x[j] += xs_data[bi * length as usize * n + t * n + j];
18226 }
18227 }
18228
18229 for i in 0..n {
18230 let got = batched_out[bi * n + i];
18231 assert!(
18232 (got - x[i]).abs() < 1e-12,
18233 "row {bi}, i {i}: got {got} ref {}",
18234 x[i]
18235 );
18236 }
18237 }
18238 }
18239
18240 #[test]
18245 fn vmap_dense_solve_matches_scalar_runs() {
18246 use rlx_opt::vmap::vmap;
18247 let n = 3usize;
18248 let batch = 4usize;
18249
18250 let mut g = Graph::new("solve_forward");
18251 let a = g.input("A", Shape::new(&[n, n], DType::F64));
18252 let b = g.input("b", Shape::new(&[n], DType::F64));
18253 let x = g.dense_solve(a, b, Shape::new(&[n], DType::F64));
18254 g.set_outputs(vec![x]);
18255
18256 let bg = vmap(&g, &["A", "b"], batch);
18258
18259 let mut rng = rlx_ir::Philox4x32::new(0xb47c_u64);
18261 let mut a_data = vec![0.0_f64; batch * n * n];
18262 let mut b_data = vec![0.0_f64; batch * n];
18263 for bi in 0..batch {
18264 for i in 0..n {
18266 for j in 0..n {
18267 a_data[bi * n * n + i * n + j] = rng.next_f32() as f64 * 0.1;
18268 }
18269 a_data[bi * n * n + i * n + i] += 1.0 + n as f64;
18270 }
18271 for i in 0..n {
18272 b_data[bi * n + i] = rng.next_f32() as f64;
18273 }
18274 }
18275
18276 let find = |graph: &Graph, want: &str| -> NodeId {
18277 for node in graph.nodes() {
18278 if let Op::Input { name } = &node.op
18279 && name == want
18280 {
18281 return node.id;
18282 }
18283 }
18284 panic!("no input named {want}");
18285 };
18286 let ba = find(&bg, "A");
18287 let bb = find(&bg, "b");
18288 let (sched, mut arena) = prepare_f64(&bg, &[(ba, &a_data), (bb, &b_data)]);
18289 execute_thunks(&sched, arena.raw_buf_mut());
18290 let batched_x = read_arena_f64(&arena, bg.outputs[0], batch * n);
18291
18292 for bi in 0..batch {
18294 let mut a_slice: Vec<f64> = a_data[bi * n * n..(bi + 1) * n * n].to_vec();
18295 let mut b_slice: Vec<f64> = b_data[bi * n..(bi + 1) * n].to_vec();
18296 crate::blas::dgesv(&mut a_slice, &mut b_slice, n, 1);
18297 for i in 0..n {
18298 let got = batched_x[bi * n + i];
18299 let want = b_slice[i];
18300 assert!(
18301 (got - want).abs() < 1e-12,
18302 "row {bi}, i {i}: got {got} want {want}"
18303 );
18304 }
18305 }
18306 }
18307
18308 #[test]
18315 fn vmap_matmul_add_reduce_matches_scalar_runs() {
18316 use rlx_opt::vmap::vmap;
18317 let n = 3usize;
18318 let batch = 4usize;
18319
18320 let mut g = Graph::new("vmap_e2e_forward");
18322 let x = g.input("x", Shape::new(&[n], DType::F64));
18323 let w = g.input("w", Shape::new(&[n, n], DType::F64));
18324 let b = g.input("b", Shape::new(&[n], DType::F64));
18325 let x_row = g.add_node(
18326 Op::Reshape {
18327 new_shape: vec![1, n as i64],
18328 },
18329 vec![x],
18330 Shape::new(&[1, n], DType::F64),
18331 );
18332 let mm = g.matmul(x_row, w, Shape::new(&[1, n], DType::F64));
18333 let mm_flat = g.add_node(
18334 Op::Reshape {
18335 new_shape: vec![n as i64],
18336 },
18337 vec![mm],
18338 Shape::new(&[n], DType::F64),
18339 );
18340 let yv = g.binary(BinaryOp::Add, mm_flat, b, Shape::new(&[n], DType::F64));
18341 let loss = g.reduce(
18342 yv,
18343 ReduceOp::Sum,
18344 vec![0],
18345 false,
18346 Shape::new(&[1], DType::F64),
18347 );
18348 g.set_outputs(vec![loss]);
18349
18350 let bg = vmap(&g, &["x"], batch);
18352
18353 let mut rng = rlx_ir::Philox4x32::new(0xc1c0_u64);
18355 let n_w = n * n;
18356 let w_data: Vec<f64> = (0..n_w).map(|_| rng.next_f32() as f64).collect();
18357 let b_data: Vec<f64> = (0..n).map(|_| rng.next_f32() as f64).collect();
18358 let mut x_data_batched: Vec<f64> = Vec::with_capacity(batch * n);
18359 for _ in 0..batch * n {
18360 x_data_batched.push(rng.next_f32() as f64);
18361 }
18362
18363 let find = |graph: &Graph, want: &str| -> NodeId {
18365 for node in graph.nodes() {
18366 if let Op::Input { name } = &node.op
18367 && name == want
18368 {
18369 return node.id;
18370 }
18371 }
18372 panic!("no input named {want}");
18373 };
18374 let bx = find(&bg, "x");
18375 let bw = find(&bg, "w");
18376 let bb = find(&bg, "b");
18377 let (sched, mut arena) =
18378 prepare_f64(&bg, &[(bx, &x_data_batched), (bw, &w_data), (bb, &b_data)]);
18379 execute_thunks(&sched, arena.raw_buf_mut());
18380 let batched_out = read_arena_f64(&arena, bg.outputs[0], batch);
18386
18387 for bi in 0..batch {
18389 let xs_slice = &x_data_batched[bi * n..(bi + 1) * n];
18390 let mut g2 = Graph::new("scalar_run");
18391 let x2 = g2.input("x", Shape::new(&[n], DType::F64));
18392 let w2 = g2.input("w", Shape::new(&[n, n], DType::F64));
18393 let b2 = g2.input("b", Shape::new(&[n], DType::F64));
18394 let xr = g2.add_node(
18395 Op::Reshape {
18396 new_shape: vec![1, n as i64],
18397 },
18398 vec![x2],
18399 Shape::new(&[1, n], DType::F64),
18400 );
18401 let m = g2.matmul(xr, w2, Shape::new(&[1, n], DType::F64));
18402 let mf = g2.add_node(
18403 Op::Reshape {
18404 new_shape: vec![n as i64],
18405 },
18406 vec![m],
18407 Shape::new(&[n], DType::F64),
18408 );
18409 let yv2 = g2.binary(BinaryOp::Add, mf, b2, Shape::new(&[n], DType::F64));
18410 let l2 = g2.reduce(
18411 yv2,
18412 ReduceOp::Sum,
18413 vec![0],
18414 false,
18415 Shape::new(&[1], DType::F64),
18416 );
18417 g2.set_outputs(vec![l2]);
18418 let (s2, mut a2) = prepare_f64(&g2, &[(x2, xs_slice), (w2, &w_data), (b2, &b_data)]);
18419 execute_thunks(&s2, a2.raw_buf_mut());
18420 let scalar_out = read_arena_f64(&a2, l2, 1);
18421 assert!(
18422 (batched_out[bi] - scalar_out[0]).abs() < 1e-12,
18423 "row {bi}: batched={} scalar={}",
18424 batched_out[bi],
18425 scalar_out[0]
18426 );
18427 }
18428 }
18429
18430 #[test]
18437 fn scan_with_xs_dxs_matches_fd() {
18438 use rlx_opt::autodiff::grad_with_loss;
18439 let n = 3usize;
18440 let length = 3u32;
18441 let dt = 0.1_f64;
18442
18443 let mut m_data = vec![0.0_f64; n * n];
18444 for i in 0..n {
18445 m_data[i * n + i] = 1.0 + dt * 2.0;
18446 if i > 0 {
18447 m_data[i * n + (i - 1)] = -dt;
18448 }
18449 if i + 1 < n {
18450 m_data[i * n + (i + 1)] = -dt;
18451 }
18452 }
18453 let m_bytes: Vec<u8> = m_data.iter().flat_map(|x| x.to_le_bytes()).collect();
18454
18455 let mut body = Graph::new("be_dxs_body");
18456 let carry = body.input("carry", Shape::new(&[n], DType::F64));
18457 let drive = body.input("drive", Shape::new(&[n], DType::F64));
18458 let m = body.add_node(
18459 Op::Constant { data: m_bytes },
18460 vec![],
18461 Shape::new(&[n, n], DType::F64),
18462 );
18463 let driven = body.binary(BinaryOp::Add, carry, drive, Shape::new(&[n], DType::F64));
18464 let next = body.dense_solve(m, driven, Shape::new(&[n], DType::F64));
18465 body.set_outputs(vec![next]);
18466
18467 let mut g = Graph::new("be_dxs_outer");
18468 let init = g.input("init", Shape::new(&[n], DType::F64));
18469 let xs = g.input("xs", Shape::new(&[length as usize, n], DType::F64));
18470 let final_carry = g.scan_with_xs(init, &[xs], body, length);
18471 let loss = g.reduce(
18472 final_carry,
18473 ReduceOp::Sum,
18474 vec![0],
18475 false,
18476 Shape::new(&[1], DType::F64),
18477 );
18478 g.set_outputs(vec![loss]);
18479
18480 let bwd = grad_with_loss(&g, &[init, xs]);
18482 assert_eq!(bwd.outputs.len(), 3, "[loss, dinit, dxs]");
18483
18484 let find_by_name = |graph: &Graph, want: &str| -> NodeId {
18485 for node in graph.nodes() {
18486 let name = match &node.op {
18487 Op::Input { name } | Op::Param { name } => Some(name.as_str()),
18488 _ => None,
18489 };
18490 if name == Some(want) {
18491 return node.id;
18492 }
18493 }
18494 panic!("no node named {want:?}");
18495 };
18496 let init_bwd = find_by_name(&bwd, "init");
18497 let xs_bwd = find_by_name(&bwd, "xs");
18498 let d_out_bwd = find_by_name(&bwd, "d_output");
18499
18500 let init_data = vec![0.5_f64, 0.0, -0.5];
18501 let xs_data: Vec<f64> = (0..length as usize * n)
18502 .map(|i| 0.1_f64 * ((i as f64) - 4.0))
18503 .collect();
18504 let d_seed = [1.0_f64];
18505
18506 let (sched, mut arena) = prepare_f64(
18507 &bwd,
18508 &[
18509 (init_bwd, &init_data),
18510 (xs_bwd, &xs_data),
18511 (d_out_bwd, &d_seed),
18512 ],
18513 );
18514 execute_thunks(&sched, arena.raw_buf_mut());
18515 let dinit = read_arena_f64(&arena, bwd.outputs[1], n);
18516 let dxs = read_arena_f64(&arena, bwd.outputs[2], length as usize * n);
18517
18518 let h = 1e-6;
18519 let loss_at = |x0: &[f64], xs_in: &[f64]| -> f64 {
18520 let mut acc = x0.to_vec();
18521 for t in 0..length as usize {
18522 for j in 0..n {
18523 acc[j] += xs_in[t * n + j];
18524 }
18525 let mut a_copy = m_data.clone();
18526 crate::blas::dgesv(&mut a_copy, &mut acc, n, 1);
18527 }
18528 acc.iter().sum()
18529 };
18530
18531 for i in 0..n {
18533 let mut ip = init_data.to_vec();
18534 ip[i] += h;
18535 let mut im = init_data.to_vec();
18536 im[i] -= h;
18537 let fd = (loss_at(&ip, &xs_data) - loss_at(&im, &xs_data)) / (2.0 * h);
18538 assert!(
18539 (dinit[i] - fd).abs() < 1e-7,
18540 "FD dinit[{i}]: AD={} FD={}",
18541 dinit[i],
18542 fd
18543 );
18544 }
18545
18546 for t in 0..length as usize {
18548 for j in 0..n {
18549 let idx = t * n + j;
18550 let mut xp = xs_data.clone();
18551 xp[idx] += h;
18552 let mut xm = xs_data.clone();
18553 xm[idx] -= h;
18554 let fd = (loss_at(&init_data, &xp) - loss_at(&init_data, &xm)) / (2.0 * h);
18555 assert!(
18556 (dxs[idx] - fd).abs() < 1e-7,
18557 "FD dxs[t={t},j={j}]: AD={} FD={}",
18558 dxs[idx],
18559 fd
18560 );
18561 }
18562 }
18563 }
18564
18565 #[test]
18573 fn scan_with_xs_gradient_dinit_matches_fd() {
18574 use rlx_opt::autodiff::grad_with_loss;
18575 let n = 3usize;
18576 let length = 3u32;
18577 let dt = 0.1_f64;
18578
18579 let mut m_data = vec![0.0_f64; n * n];
18580 for i in 0..n {
18581 m_data[i * n + i] = 1.0 + dt * 2.0;
18582 if i > 0 {
18583 m_data[i * n + (i - 1)] = -dt;
18584 }
18585 if i + 1 < n {
18586 m_data[i * n + (i + 1)] = -dt;
18587 }
18588 }
18589 let m_bytes: Vec<u8> = m_data.iter().flat_map(|x| x.to_le_bytes()).collect();
18590
18591 let mut body = Graph::new("be_xs_grad_body");
18592 let carry = body.input("carry", Shape::new(&[n], DType::F64));
18593 let drive = body.input("drive", Shape::new(&[n], DType::F64));
18594 let m = body.add_node(
18595 Op::Constant { data: m_bytes },
18596 vec![],
18597 Shape::new(&[n, n], DType::F64),
18598 );
18599 let driven = body.binary(BinaryOp::Add, carry, drive, Shape::new(&[n], DType::F64));
18600 let next = body.dense_solve(m, driven, Shape::new(&[n], DType::F64));
18601 body.set_outputs(vec![next]);
18602
18603 let mut g = Graph::new("be_xs_grad_outer");
18604 let init = g.input("init", Shape::new(&[n], DType::F64));
18605 let xs = g.input("xs", Shape::new(&[length as usize, n], DType::F64));
18606 let final_carry = g.scan_with_xs(init, &[xs], body, length);
18607 let loss = g.reduce(
18608 final_carry,
18609 ReduceOp::Sum,
18610 vec![0],
18611 false,
18612 Shape::new(&[1], DType::F64),
18613 );
18614 g.set_outputs(vec![loss]);
18615
18616 let bwd = grad_with_loss(&g, &[init]);
18617
18618 let find_by_name = |graph: &Graph, want: &str| -> NodeId {
18619 for node in graph.nodes() {
18620 let name = match &node.op {
18621 Op::Input { name } | Op::Param { name } => Some(name.as_str()),
18622 _ => None,
18623 };
18624 if name == Some(want) {
18625 return node.id;
18626 }
18627 }
18628 panic!("no node named {want:?}");
18629 };
18630 let init_bwd = find_by_name(&bwd, "init");
18631 let xs_bwd = find_by_name(&bwd, "xs");
18632 let d_out_bwd = find_by_name(&bwd, "d_output");
18633
18634 let init_data = vec![0.5_f64, 0.0, -0.5];
18635 let xs_data: Vec<f64> = (0..length as usize * n)
18637 .map(|i| 0.1_f64 * ((i as f64) - 4.0))
18638 .collect();
18639 let d_seed = [1.0_f64];
18640
18641 let (sched, mut arena) = prepare_f64(
18642 &bwd,
18643 &[
18644 (init_bwd, &init_data),
18645 (xs_bwd, &xs_data),
18646 (d_out_bwd, &d_seed),
18647 ],
18648 );
18649 execute_thunks(&sched, arena.raw_buf_mut());
18650 let dinit = read_arena_f64(&arena, bwd.outputs[1], n);
18651
18652 let h = 1e-6;
18653 let loss_at = |x0: &[f64]| -> f64 {
18654 let mut acc = x0.to_vec();
18655 for t in 0..length as usize {
18656 for j in 0..n {
18657 acc[j] += xs_data[t * n + j];
18658 }
18659 let mut a_copy = m_data.clone();
18660 crate::blas::dgesv(&mut a_copy, &mut acc, n, 1);
18661 }
18662 acc.iter().sum()
18663 };
18664 for i in 0..n {
18665 let mut ip = init_data.to_vec();
18666 ip[i] += h;
18667 let mut im = init_data.to_vec();
18668 im[i] -= h;
18669 let fd = (loss_at(&ip) - loss_at(&im)) / (2.0 * h);
18670 assert!(
18671 (dinit[i] - fd).abs() < 1e-7,
18672 "FD dinit[{i}]: AD={} FD={}",
18673 dinit[i],
18674 fd
18675 );
18676 }
18677 }
18678
18679 #[test]
18687 fn scan_gradient_geometric_matches_closed_form() {
18688 use rlx_opt::autodiff::grad_with_loss;
18689 let n = 3usize;
18690 let length = 5u32;
18691
18692 let mut body = Graph::new("scan_grad_body");
18693 let x = body.input("carry", Shape::new(&[n], DType::F64));
18694 let scale_bytes: Vec<u8> = (0..n).flat_map(|_| 1.1_f64.to_le_bytes()).collect();
18695 let scale = body.add_node(
18696 Op::Constant { data: scale_bytes },
18697 vec![],
18698 Shape::new(&[n], DType::F64),
18699 );
18700 let next = body.binary(BinaryOp::Mul, x, scale, Shape::new(&[n], DType::F64));
18701 body.set_outputs(vec![next]);
18702
18703 let mut g = Graph::new("scan_grad_outer");
18704 let init = g.input("init", Shape::new(&[n], DType::F64));
18705 let final_x = g.scan(init, body, length);
18706 let loss = g.reduce(
18707 final_x,
18708 ReduceOp::Sum,
18709 vec![0],
18710 false,
18711 Shape::new(&[1], DType::F64),
18712 );
18713 g.set_outputs(vec![loss]);
18714
18715 let bwd = grad_with_loss(&g, &[init]);
18716 assert_eq!(bwd.outputs.len(), 2);
18717
18718 let find_by_name = |graph: &Graph, want: &str| -> NodeId {
18719 for node in graph.nodes() {
18720 let name = match &node.op {
18721 Op::Input { name } | Op::Param { name } => Some(name.as_str()),
18722 _ => None,
18723 };
18724 if name == Some(want) {
18725 return node.id;
18726 }
18727 }
18728 panic!("no node named {want:?}");
18729 };
18730 let init_bwd = find_by_name(&bwd, "init");
18731 let d_out_bwd = find_by_name(&bwd, "d_output");
18732
18733 let init_data = vec![1.0_f64; n];
18734 let d_seed = [1.0_f64];
18735 let (sched, mut arena) = prepare_f64(&bwd, &[(init_bwd, &init_data), (d_out_bwd, &d_seed)]);
18736 execute_thunks(&sched, arena.raw_buf_mut());
18737 let dinit = read_arena_f64(&arena, bwd.outputs[1], n);
18738
18739 let want = 1.1_f64.powi(length as i32);
18740 for i in 0..n {
18741 assert!(
18742 (dinit[i] - want).abs() < 1e-12,
18743 "dinit[{i}] = {} want {}",
18744 dinit[i],
18745 want
18746 );
18747 }
18748
18749 let h = 1e-6;
18751 let loss_at = |x: &[f64]| -> f64 {
18752 let mut acc = x.to_vec();
18753 for _ in 0..length {
18754 for v in acc.iter_mut() {
18755 *v *= 1.1;
18756 }
18757 }
18758 acc.iter().sum()
18759 };
18760 let mut ip = init_data.clone();
18761 ip[0] += h;
18762 let mut im = init_data.clone();
18763 im[0] -= h;
18764 let fd = (loss_at(&ip) - loss_at(&im)) / (2.0 * h);
18765 assert!(
18766 (dinit[0] - fd).abs() < 1e-7,
18767 "FD dinit[0]: AD={} FD={}",
18768 dinit[0],
18769 fd
18770 );
18771 }
18772
18773 #[test]
18776 fn scan_gradient_backward_euler_matches_fd() {
18777 use rlx_opt::autodiff::grad_with_loss;
18778 let n = 4usize;
18779 let length = 3u32;
18780 let dt = 0.05_f64;
18781
18782 let mut m_data = vec![0.0_f64; n * n];
18783 for i in 0..n {
18784 m_data[i * n + i] = 1.0 + dt * 2.0;
18785 if i > 0 {
18786 m_data[i * n + (i - 1)] = -dt;
18787 }
18788 if i + 1 < n {
18789 m_data[i * n + (i + 1)] = -dt;
18790 }
18791 }
18792 let m_bytes: Vec<u8> = m_data.iter().flat_map(|x| x.to_le_bytes()).collect();
18793
18794 let mut body = Graph::new("be_grad_body");
18795 let x = body.input("x", Shape::new(&[n], DType::F64));
18796 let m = body.add_node(
18797 Op::Constant { data: m_bytes },
18798 vec![],
18799 Shape::new(&[n, n], DType::F64),
18800 );
18801 let next = body.dense_solve(m, x, Shape::new(&[n], DType::F64));
18802 body.set_outputs(vec![next]);
18803
18804 let mut g = Graph::new("be_grad_outer");
18805 let init = g.input("x0", Shape::new(&[n], DType::F64));
18806 let final_x = g.scan(init, body, length);
18807 let loss = g.reduce(
18808 final_x,
18809 ReduceOp::Sum,
18810 vec![0],
18811 false,
18812 Shape::new(&[1], DType::F64),
18813 );
18814 g.set_outputs(vec![loss]);
18815
18816 let bwd = grad_with_loss(&g, &[init]);
18817
18818 let find_by_name = |graph: &Graph, want: &str| -> NodeId {
18819 for node in graph.nodes() {
18820 let name = match &node.op {
18821 Op::Input { name } | Op::Param { name } => Some(name.as_str()),
18822 _ => None,
18823 };
18824 if name == Some(want) {
18825 return node.id;
18826 }
18827 }
18828 panic!("no node named {want:?}");
18829 };
18830 let init_bwd = find_by_name(&bwd, "x0");
18831 let d_out_bwd = find_by_name(&bwd, "d_output");
18832
18833 let init_data: [f64; 4] = [0.0, 1.0, 0.0, 0.0];
18834 let d_seed = [1.0_f64];
18835 let (sched, mut arena) = prepare_f64(&bwd, &[(init_bwd, &init_data), (d_out_bwd, &d_seed)]);
18836 execute_thunks(&sched, arena.raw_buf_mut());
18837 let dinit = read_arena_f64(&arena, bwd.outputs[1], n);
18838
18839 let h = 1e-6;
18840 let loss_at = |x0: &[f64]| -> f64 {
18841 let mut acc = x0.to_vec();
18842 for _ in 0..length {
18843 let mut a_copy = m_data.clone();
18844 crate::blas::dgesv(&mut a_copy, &mut acc, n, 1);
18845 }
18846 acc.iter().sum()
18847 };
18848 for i in 0..n {
18849 let mut ip = init_data.to_vec();
18850 ip[i] += h;
18851 let mut im = init_data.to_vec();
18852 im[i] -= h;
18853 let fd = (loss_at(&ip) - loss_at(&im)) / (2.0 * h);
18854 assert!(
18855 (dinit[i] - fd).abs() < 1e-7,
18856 "FD dinit[{i}]: AD={} FD={}",
18857 dinit[i],
18858 fd
18859 );
18860 }
18861 }
18862
18863 #[test]
18869 fn scan_trajectory_backward_euler_records_waveform() {
18870 let n = 4usize;
18871 let length = 5u32;
18872 let dt = 0.05_f64;
18873
18874 let mut m_data = vec![0.0_f64; n * n];
18875 for i in 0..n {
18876 m_data[i * n + i] = 1.0 + dt * 2.0;
18877 if i > 0 {
18878 m_data[i * n + (i - 1)] = -dt;
18879 }
18880 if i + 1 < n {
18881 m_data[i * n + (i + 1)] = -dt;
18882 }
18883 }
18884 let m_bytes: Vec<u8> = m_data.iter().flat_map(|x| x.to_le_bytes()).collect();
18885
18886 let mut body = Graph::new("be_traj_body");
18887 let x = body.input("x", Shape::new(&[n], DType::F64));
18888 let m = body.add_node(
18889 Op::Constant { data: m_bytes },
18890 vec![],
18891 Shape::new(&[n, n], DType::F64),
18892 );
18893 let next = body.dense_solve(m, x, Shape::new(&[n], DType::F64));
18894 body.set_outputs(vec![next]);
18895
18896 let mut g = Graph::new("be_traj_outer");
18897 let init = g.input("x0", Shape::new(&[n], DType::F64));
18898 let traj = g.scan_trajectory(init, body, length);
18899 g.set_outputs(vec![traj]);
18900
18901 let init_data: [f64; 4] = [0.0, 1.0, 0.0, 0.0];
18902 let (sched, mut arena) = prepare_f64(&g, &[(init, &init_data)]);
18903 execute_thunks(&sched, arena.raw_buf_mut());
18904 let got = read_arena_f64(&arena, traj, length as usize * n);
18905
18906 let mut want = Vec::<f64>::with_capacity(length as usize * n);
18908 let mut x_ref = init_data.to_vec();
18909 for _ in 0..length {
18910 let mut a_copy = m_data.clone();
18911 crate::blas::dgesv(&mut a_copy, &mut x_ref, n, 1);
18912 want.extend_from_slice(&x_ref);
18913 }
18914 for i in 0..length as usize * n {
18915 assert!(
18916 (got[i] - want[i]).abs() < 1e-12,
18917 "got[{i}] = {} ref {}",
18918 got[i],
18919 want[i]
18920 );
18921 }
18922
18923 for t in 1..length as usize {
18926 let prev: f64 = got[(t - 1) * n..t * n].iter().sum();
18927 let curr: f64 = got[t * n..(t + 1) * n].iter().sum();
18928 assert!(
18929 curr <= prev + 1e-15,
18930 "mass should decay: row {} sum {prev}, row {t} sum {curr}",
18931 t - 1
18932 );
18933 }
18934
18935 let mut body2 = Graph::new("be_final_body");
18939 let x2 = body2.input("x", Shape::new(&[n], DType::F64));
18940 let m_bytes2: Vec<u8> = m_data.iter().flat_map(|x| x.to_le_bytes()).collect();
18941 let m2 = body2.add_node(
18942 Op::Constant { data: m_bytes2 },
18943 vec![],
18944 Shape::new(&[n, n], DType::F64),
18945 );
18946 let next2 = body2.dense_solve(m2, x2, Shape::new(&[n], DType::F64));
18947 body2.set_outputs(vec![next2]);
18948
18949 let mut g2 = Graph::new("be_final_outer");
18950 let init2 = g2.input("x0", Shape::new(&[n], DType::F64));
18951 let final_x = g2.scan(init2, body2, length);
18952 g2.set_outputs(vec![final_x]);
18953 let (sched2, mut arena2) = prepare_f64(&g2, &[(init2, &init_data)]);
18954 execute_thunks(&sched2, arena2.raw_buf_mut());
18955 let final_got = read_arena_f64(&arena2, final_x, n);
18956
18957 let last_row = &got[(length as usize - 1) * n..length as usize * n];
18958 for i in 0..n {
18959 assert!(
18960 (last_row[i] - final_got[i]).abs() < 1e-15,
18961 "last trajectory row[{i}] = {} vs final-scan = {}",
18962 last_row[i],
18963 final_got[i]
18964 );
18965 }
18966 }
18967
18968 #[test]
18974 fn scan_backward_euler_heat_f64() {
18975 let n = 4usize;
18976 let length = 5u32;
18977 let dt = 0.05_f64;
18978
18979 let mut m_data = vec![0.0_f64; n * n];
18982 for i in 0..n {
18983 m_data[i * n + i] = 1.0 + dt * 2.0;
18984 if i > 0 {
18985 m_data[i * n + (i - 1)] = -dt;
18986 }
18987 if i + 1 < n {
18988 m_data[i * n + (i + 1)] = -dt;
18989 }
18990 }
18991 let m_bytes: Vec<u8> = m_data.iter().flat_map(|x| x.to_le_bytes()).collect();
18992
18993 let mut body = Graph::new("be_body");
18994 let x = body.input("x", Shape::new(&[n], DType::F64));
18995 let m = body.add_node(
18996 Op::Constant { data: m_bytes },
18997 vec![],
18998 Shape::new(&[n, n], DType::F64),
18999 );
19000 let next = body.dense_solve(m, x, Shape::new(&[n], DType::F64));
19001 body.set_outputs(vec![next]);
19002
19003 let mut g = Graph::new("be_outer");
19004 let init = g.input("x0", Shape::new(&[n], DType::F64));
19005 let final_x = g.scan(init, body, length);
19006 g.set_outputs(vec![final_x]);
19007
19008 let init_data: [f64; 4] = [0.0, 1.0, 0.0, 0.0];
19010 let (sched, mut arena) = prepare_f64(&g, &[(init, &init_data)]);
19011 execute_thunks(&sched, arena.raw_buf_mut());
19012 let got = read_arena_f64(&arena, final_x, n);
19013
19014 let mut ref_x = init_data.to_vec();
19016 for _ in 0..length {
19017 let mut a_copy = m_data.clone();
19018 crate::blas::dgesv(&mut a_copy, &mut ref_x, n, 1);
19019 }
19020 for i in 0..n {
19021 assert!(
19022 (got[i] - ref_x[i]).abs() < 1e-12,
19023 "got[{i}] = {} ref {}",
19024 got[i],
19025 ref_x[i]
19026 );
19027 }
19028 let mass: f64 = got.iter().sum();
19033 assert!(mass > 0.0 && mass < 1.0, "diffusion mass: {mass}");
19034 }
19035
19036 #[test]
19040 fn dense_solve_f64_multi_rhs_forward() {
19041 let n = 3usize;
19042 let k = 2usize;
19043 let mut g = Graph::new("solve_multi_rhs");
19044 let a = g.input("A", Shape::new(&[n, n], DType::F64));
19045 let b = g.input("B", Shape::new(&[n, k], DType::F64));
19046 let x = g.dense_solve(a, b, Shape::new(&[n, k], DType::F64));
19047 g.set_outputs(vec![x]);
19048
19049 let a_data = [2.0, 1.0, 0.0, 1.0, 3.0, 1.0, 0.0, 1.0, 2.0_f64];
19050 let b_data = [1.0, 4.0, 2.0, -1.0, 3.0, 2.0_f64];
19051 let (sched, mut arena) = prepare_f64(&g, &[(a, &a_data), (b, &b_data)]);
19052 execute_thunks(&sched, arena.raw_buf_mut());
19053 let x_got = read_arena_f64(&arena, x, n * k);
19054 for c in 0..k {
19055 for i in 0..n {
19056 let mut acc = 0.0_f64;
19057 for j in 0..n {
19058 acc += a_data[i * n + j] * x_got[j * k + c];
19059 }
19060 let want = b_data[i * k + c];
19061 assert!(
19062 (acc - want).abs() < 1e-10,
19063 "col {c} row {i}: got {acc} want {want}"
19064 );
19065 }
19066 }
19067 }
19068
19069 #[test]
19072 fn dense_solve_f64_multi_rhs_gradient() {
19073 use rlx_opt::autodiff::grad_with_loss;
19074 let n = 3usize;
19075 let k = 2usize;
19076 let mut g = Graph::new("solve_mrhs_grad");
19077 let a = g.param("A", Shape::new(&[n, n], DType::F64));
19078 let b = g.input("B", Shape::new(&[n, k], DType::F64));
19079 let x = g.dense_solve(a, b, Shape::new(&[n, k], DType::F64));
19080 let loss = g.reduce(
19081 x,
19082 ReduceOp::Sum,
19083 vec![0, 1],
19084 false,
19085 Shape::new(&[1], DType::F64),
19086 );
19087 g.set_outputs(vec![loss]);
19088
19089 let bwd = grad_with_loss(&g, &[a, b]);
19090 let find_by_name = |graph: &Graph, want: &str| -> NodeId {
19091 for node in graph.nodes() {
19092 let name = match &node.op {
19093 Op::Input { name } | Op::Param { name } => Some(name.as_str()),
19094 _ => None,
19095 };
19096 if name == Some(want) {
19097 return node.id;
19098 }
19099 }
19100 panic!("no node named {want:?}");
19101 };
19102 let a_bwd = find_by_name(&bwd, "A");
19103 let b_bwd = find_by_name(&bwd, "B");
19104 let d_out = find_by_name(&bwd, "d_output");
19105
19106 let a_data = [2.0, 1.0, 0.0, 1.0, 3.0, 1.0, 0.0, 1.0, 2.0_f64];
19107 let b_data = [1.0, 4.0, 2.0, -1.0, 3.0, 2.0_f64];
19108 let d_seed = [1.0_f64];
19109
19110 let (sched, mut arena) = prepare_f64(
19111 &bwd,
19112 &[(a_bwd, &a_data), (b_bwd, &b_data), (d_out, &d_seed)],
19113 );
19114 execute_thunks(&sched, arena.raw_buf_mut());
19115 let da_got = read_arena_f64(&arena, bwd.outputs[1], n * n);
19116 let db_got = read_arena_f64(&arena, bwd.outputs[2], n * k);
19117
19118 let mut x_ref = b_data;
19120 {
19121 let mut a_copy = a_data;
19122 crate::blas::dgesv(&mut a_copy, &mut x_ref, n, k);
19123 }
19124 let mut at = [0.0_f64; 9];
19125 for i in 0..n {
19126 for j in 0..n {
19127 at[i * n + j] = a_data[j * n + i];
19128 }
19129 }
19130 let mut ones_nk = vec![1.0_f64; n * k];
19131 crate::blas::dgesv(&mut at, &mut ones_nk, n, k);
19132 let db_ref = ones_nk;
19133 let mut da_ref = [0.0_f64; 9];
19134 for i in 0..n {
19135 for j in 0..n {
19136 let mut acc = 0.0_f64;
19137 for c in 0..k {
19138 acc += db_ref[i * k + c] * x_ref[j * k + c];
19139 }
19140 da_ref[i * n + j] = -acc;
19141 }
19142 }
19143 for i in 0..n * k {
19144 assert!(
19145 (db_got[i] - db_ref[i]).abs() < 1e-10,
19146 "dB[{i}]: got {} want {}",
19147 db_got[i],
19148 db_ref[i]
19149 );
19150 }
19151 for i in 0..n * n {
19152 assert!(
19153 (da_got[i] - da_ref[i]).abs() < 1e-10,
19154 "dA[{i}]: got {} want {}",
19155 da_got[i],
19156 da_ref[i]
19157 );
19158 }
19159
19160 let h = 1e-6;
19162 let mut bp = b_data;
19163 bp[0] += h;
19164 let mut bm = b_data;
19165 bm[0] -= h;
19166 let xp = {
19167 let mut a_copy = a_data;
19168 crate::blas::dgesv(&mut a_copy, &mut bp, n, k);
19169 bp
19170 };
19171 let xm = {
19172 let mut a_copy = a_data;
19173 crate::blas::dgesv(&mut a_copy, &mut bm, n, k);
19174 bm
19175 };
19176 let lp: f64 = xp.iter().sum();
19177 let lm: f64 = xm.iter().sum();
19178 let fd = (lp - lm) / (2.0 * h);
19179 assert!(
19180 (db_got[0] - fd).abs() < 1e-7,
19181 "FD dB[0,0]: AD={} FD={}",
19182 db_got[0],
19183 fd
19184 );
19185 }
19186
19187 #[test]
19189 fn dense_solve_f64_multi_rhs_jvp() {
19190 use rlx_opt::autodiff_fwd::jvp;
19191 let n = 3usize;
19192 let k = 2usize;
19193 let mut g = Graph::new("solve_mrhs_jvp");
19194 let a = g.input("A", Shape::new(&[n, n], DType::F64));
19195 let b = g.input("B", Shape::new(&[n, k], DType::F64));
19196 let x = g.dense_solve(a, b, Shape::new(&[n, k], DType::F64));
19197 g.set_outputs(vec![x]);
19198
19199 let jg = jvp(&g, &[b]);
19200 let find_by_name = |graph: &Graph, want: &str| -> NodeId {
19201 for node in graph.nodes() {
19202 let name = match &node.op {
19203 Op::Input { name } | Op::Param { name } => Some(name.as_str()),
19204 _ => None,
19205 };
19206 if name == Some(want) {
19207 return node.id;
19208 }
19209 }
19210 panic!("no node named {want:?}");
19211 };
19212 let a_id = find_by_name(&jg, "A");
19213 let b_id = find_by_name(&jg, "B");
19214 let tb_id = find_by_name(&jg, "tangent_B");
19215
19216 let a_data = [2.0, 1.0, 0.0, 1.0, 3.0, 1.0, 0.0, 1.0, 2.0_f64];
19217 let b_data = [1.0, 4.0, 2.0, -1.0, 3.0, 2.0_f64];
19218 let tb_data = [0.5, 0.0, -0.25, 1.0, 1.0, -0.5_f64];
19219
19220 let (sched, mut arena) =
19221 prepare_f64(&jg, &[(a_id, &a_data), (b_id, &b_data), (tb_id, &tb_data)]);
19222 execute_thunks(&sched, arena.raw_buf_mut());
19223 let tangent_x = read_arena_f64(&arena, jg.outputs[1], n * k);
19224
19225 let mut a_copy = a_data;
19226 let mut tb_copy = tb_data;
19227 crate::blas::dgesv(&mut a_copy, &mut tb_copy, n, k);
19228 for i in 0..n * k {
19229 assert!(
19230 (tangent_x[i] - tb_copy[i]).abs() < 1e-10,
19231 "t_X[{i}]: AD={} ref={}",
19232 tangent_x[i],
19233 tb_copy[i]
19234 );
19235 }
19236
19237 let h = 1e-6;
19238 let mut bp = b_data;
19239 let mut bm = b_data;
19240 for i in 0..n * k {
19241 bp[i] += h * tb_data[i];
19242 bm[i] -= h * tb_data[i];
19243 }
19244 let xp = {
19245 let mut a_copy = a_data;
19246 crate::blas::dgesv(&mut a_copy, &mut bp, n, k);
19247 bp
19248 };
19249 let xm = {
19250 let mut a_copy = a_data;
19251 crate::blas::dgesv(&mut a_copy, &mut bm, n, k);
19252 bm
19253 };
19254 for i in 0..n * k {
19255 let fd = (xp[i] - xm[i]) / (2.0 * h);
19256 assert!(
19257 (tangent_x[i] - fd).abs() < 1e-7,
19258 "FD t_X[{i}]: AD={} FD={}",
19259 tangent_x[i],
19260 fd
19261 );
19262 }
19263 }
19264
19265 #[test]
19272 fn jvp_dense_solve_b_runs_and_matches_fd() {
19273 use rlx_opt::autodiff_fwd::jvp;
19274 let n = 3usize;
19275
19276 let mut g = Graph::new("jvp_b_e2e");
19278 let a = g.input("A", Shape::new(&[n, n], DType::F64));
19279 let b = g.input("b", Shape::new(&[n], DType::F64));
19280 let x = g.dense_solve(a, b, Shape::new(&[n], DType::F64));
19281 g.set_outputs(vec![x]);
19282
19283 let jg = jvp(&g, &[b]);
19285 let find_by_name = |graph: &Graph, want: &str| -> NodeId {
19287 for node in graph.nodes() {
19288 let name = match &node.op {
19289 Op::Input { name } | Op::Param { name } => Some(name.as_str()),
19290 _ => None,
19291 };
19292 if name == Some(want) {
19293 return node.id;
19294 }
19295 }
19296 panic!("no node named {want:?}");
19297 };
19298 let a_id = find_by_name(&jg, "A");
19299 let b_id = find_by_name(&jg, "b");
19300 let tb_id = find_by_name(&jg, "tangent_b");
19301
19302 let a_data: [f64; 9] = [2.0, 1.0, 0.0, 1.0, 3.0, 1.0, 0.0, 1.0, 2.0];
19303 let b_data: [f64; 3] = [1.0, 2.0, 3.0];
19304 let tb_data: [f64; 3] = [0.5, -0.25, 1.0];
19306
19307 let (sched, mut arena) =
19308 prepare_f64(&jg, &[(a_id, &a_data), (b_id, &b_data), (tb_id, &tb_data)]);
19309 execute_thunks(&sched, arena.raw_buf_mut());
19310
19311 let primal_x = read_arena_f64(&arena, jg.outputs[0], n);
19313 let tangent_x = read_arena_f64(&arena, jg.outputs[1], n);
19314
19315 let t_x_ref = {
19317 let mut a = a_data;
19318 let mut tb = tb_data;
19319 let info = crate::blas::dgesv(&mut a, &mut tb, n, 1);
19320 assert_eq!(info, 0);
19321 tb
19322 };
19323 for i in 0..n {
19324 assert!(
19325 (tangent_x[i] - t_x_ref[i]).abs() < 1e-10,
19326 "t_x[{i}]: got {} want {}",
19327 tangent_x[i],
19328 t_x_ref[i]
19329 );
19330 }
19331
19332 let h = 1e-6;
19334 let mut bp = b_data;
19335 let mut bm = b_data;
19336 for i in 0..n {
19337 bp[i] += h * tb_data[i];
19338 bm[i] -= h * tb_data[i];
19339 }
19340 let xp = {
19341 let mut a = a_data;
19342 let info = crate::blas::dgesv(&mut a, &mut bp, n, 1);
19343 assert_eq!(info, 0);
19344 bp
19345 };
19346 let xm = {
19347 let mut a = a_data;
19348 let info = crate::blas::dgesv(&mut a, &mut bm, n, 1);
19349 assert_eq!(info, 0);
19350 bm
19351 };
19352 let fd: Vec<f64> = (0..n).map(|i| (xp[i] - xm[i]) / (2.0 * h)).collect();
19353 for i in 0..n {
19354 assert!(
19355 (tangent_x[i] - fd[i]).abs() < 1e-7,
19356 "FD mismatch t_x[{i}]: AD={} FD={}",
19357 tangent_x[i],
19358 fd[i]
19359 );
19360 }
19361 let primal_ref = {
19363 let mut a = a_data;
19364 let mut b = b_data;
19365 crate::blas::dgesv(&mut a, &mut b, n, 1);
19366 b
19367 };
19368 for i in 0..n {
19369 assert!((primal_x[i] - primal_ref[i]).abs() < 1e-10);
19370 }
19371 }
19372
19373 #[test]
19379 fn jvp_dense_solve_a_runs_and_matches_fd() {
19380 use rlx_opt::autodiff_fwd::jvp;
19381 let n = 3usize;
19382
19383 let mut g = Graph::new("jvp_a_e2e");
19384 let a = g.input("A", Shape::new(&[n, n], DType::F64));
19385 let b = g.input("b", Shape::new(&[n], DType::F64));
19386 let x = g.dense_solve(a, b, Shape::new(&[n], DType::F64));
19387 g.set_outputs(vec![x]);
19388
19389 let jg = jvp(&g, &[a]);
19390 let find_by_name = |graph: &Graph, want: &str| -> NodeId {
19391 for node in graph.nodes() {
19392 let name = match &node.op {
19393 Op::Input { name } | Op::Param { name } => Some(name.as_str()),
19394 _ => None,
19395 };
19396 if name == Some(want) {
19397 return node.id;
19398 }
19399 }
19400 panic!("no node named {want:?}");
19401 };
19402 let a_id = find_by_name(&jg, "A");
19403 let b_id = find_by_name(&jg, "b");
19404 let ta_id = find_by_name(&jg, "tangent_A");
19405
19406 let a_data: [f64; 9] = [2.0, 1.0, 0.0, 1.0, 3.0, 1.0, 0.0, 1.0, 2.0];
19407 let b_data: [f64; 3] = [1.0, 2.0, 3.0];
19408 let ta_data: [f64; 9] = [0.10, -0.05, 0.02, 0.03, 0.20, -0.04, -0.01, 0.07, 0.15];
19410
19411 let (sched, mut arena) =
19412 prepare_f64(&jg, &[(a_id, &a_data), (b_id, &b_data), (ta_id, &ta_data)]);
19413 execute_thunks(&sched, arena.raw_buf_mut());
19414
19415 let tangent_x = read_arena_f64(&arena, jg.outputs[1], n);
19416
19417 let x_ref = {
19419 let mut a = a_data;
19420 let mut b = b_data;
19421 crate::blas::dgesv(&mut a, &mut b, n, 1);
19422 b
19423 };
19424 let mut prod = [0.0_f64; 3];
19425 for i in 0..n {
19426 for j in 0..n {
19427 prod[i] += ta_data[i * n + j] * x_ref[j];
19428 }
19429 }
19430 let t_x_ref = {
19431 let mut a = a_data;
19432 let mut p = prod;
19433 crate::blas::dgesv(&mut a, &mut p, n, 1);
19434 [-p[0], -p[1], -p[2]]
19435 };
19436 for i in 0..n {
19437 assert!(
19438 (tangent_x[i] - t_x_ref[i]).abs() < 1e-10,
19439 "closed-form t_x[{i}]: AD={} ref={}",
19440 tangent_x[i],
19441 t_x_ref[i]
19442 );
19443 }
19444
19445 let h = 1e-6;
19447 let mut ap = a_data;
19448 let mut am = a_data;
19449 for i in 0..n * n {
19450 ap[i] += h * ta_data[i];
19451 am[i] -= h * ta_data[i];
19452 }
19453 let xp = {
19454 let mut a = ap;
19455 let mut b = b_data;
19456 crate::blas::dgesv(&mut a, &mut b, n, 1);
19457 b
19458 };
19459 let xm = {
19460 let mut a = am;
19461 let mut b = b_data;
19462 crate::blas::dgesv(&mut a, &mut b, n, 1);
19463 b
19464 };
19465 for i in 0..n {
19466 let fd = (xp[i] - xm[i]) / (2.0 * h);
19467 assert!(
19468 (tangent_x[i] - fd).abs() < 1e-7,
19469 "FD t_x[{i}]: AD={} FD={}",
19470 tangent_x[i],
19471 fd
19472 );
19473 }
19474 }
19475
19476 #[test]
19482 fn q_conv2d_matches_reference() {
19483 use rlx_ir::Philox4x32;
19484 let n = 1usize;
19486 let c_in = 2usize;
19487 let h = 5usize;
19488 let w_in = 5usize;
19489 let c_out = 3usize;
19490 let kh = 3usize;
19491 let kw = 3usize;
19492 let ph = 1usize;
19493 let pw = 1usize;
19494 let sh = 1usize;
19495 let sw = 1usize;
19496 let h_out = (h + 2 * ph - kh) / sh + 1;
19497 let w_out = (w_in + 2 * pw - kw) / sw + 1;
19498
19499 let x_scale = 0.04f32;
19500 let w_scale = 0.02f32;
19501 let out_scale = 0.5f32;
19502 let mult = x_scale * w_scale / out_scale;
19503
19504 let mut rng = Philox4x32::new(2099);
19505 let mut xf = vec![0f32; n * c_in * h * w_in];
19506 rng.fill_normal(&mut xf);
19507 let mut wf = vec![0f32; c_out * c_in * kh * kw];
19508 rng.fill_normal(&mut wf);
19509 let xq: Vec<i8> = xf
19510 .iter()
19511 .map(|&v| ((v / x_scale).round() as i32).clamp(-128, 127) as i8)
19512 .collect();
19513 let wq: Vec<i8> = wf
19514 .iter()
19515 .map(|&v| ((v / w_scale).round() as i32).clamp(-128, 127) as i8)
19516 .collect();
19517 let bias: Vec<i32> = vec![0i32; c_out];
19518
19519 let mut g = Graph::new("qconv");
19520 let xn = g.input("x", Shape::new(&[n, c_in, h, w_in], DType::I8));
19521 let wn = g.input("w", Shape::new(&[c_out, c_in, kh, kw], DType::I8));
19522 let bn = g.input("b", Shape::new(&[c_out], DType::I32));
19523 let out = g.q_conv2d(
19524 xn,
19525 wn,
19526 bn,
19527 vec![kh, kw],
19528 vec![sh, sw],
19529 vec![ph, pw],
19530 vec![1, 1],
19531 1,
19532 0,
19533 0,
19534 0,
19535 mult,
19536 Shape::new(&[n, c_out, h_out, w_out], DType::I8),
19537 );
19538 g.set_outputs(vec![out]);
19539
19540 let plan = rlx_opt::memory::plan_memory(&g);
19541 let mut arena = crate::arena::Arena::from_plan(plan);
19542 let sched = compile_thunks(&g, &arena);
19543 let xn_off = arena.byte_offset(xn);
19546 let wn_off = arena.byte_offset(wn);
19547 let bn_off = arena.byte_offset(bn);
19548 let out_off = arena.byte_offset(out);
19549 let buf = arena.raw_buf_mut();
19550 unsafe {
19551 let p = buf.as_mut_ptr().add(xn_off) as *mut i8;
19552 for (i, &v) in xq.iter().enumerate() {
19553 *p.add(i) = v;
19554 }
19555 let p = buf.as_mut_ptr().add(wn_off) as *mut i8;
19556 for (i, &v) in wq.iter().enumerate() {
19557 *p.add(i) = v;
19558 }
19559 let p = buf.as_mut_ptr().add(bn_off) as *mut i32;
19560 for (i, &v) in bias.iter().enumerate() {
19561 *p.add(i) = v;
19562 }
19563 }
19564 execute_thunks(&sched, arena.raw_buf_mut());
19565 let out_q: Vec<i8> = unsafe {
19566 let p = arena.raw_buf().as_ptr().add(out_off) as *const i8;
19567 (0..n * c_out * h_out * w_out).map(|i| *p.add(i)).collect()
19568 };
19569
19570 let mut out_ref = vec![0i8; n * c_out * h_out * w_out];
19572 for ni in 0..n {
19573 for co in 0..c_out {
19574 for ho in 0..h_out {
19575 for wo in 0..w_out {
19576 let mut acc: i32 = 0;
19577 for ci in 0..c_in {
19578 for ki in 0..kh {
19579 for kj in 0..kw {
19580 let hi = ho * sh + ki;
19581 let wi = wo * sw + kj;
19582 if hi < ph || wi < pw {
19583 continue;
19584 }
19585 let hi = hi - ph;
19586 let wi = wi - pw;
19587 if hi >= h || wi >= w_in {
19588 continue;
19589 }
19590 let xv =
19591 xq[((ni * c_in) + ci) * h * w_in + hi * w_in + wi] as i32;
19592 let wv = wq[((co * c_in) + ci) * kh * kw + ki * kw + kj] as i32;
19593 acc += xv * wv;
19594 }
19595 }
19596 }
19597 let r = (acc as f32 * mult).round() as i32;
19598 let r = r.clamp(-128, 127) as i8;
19599 out_ref[((ni * c_out) + co) * h_out * w_out + ho * w_out + wo] = r;
19600 }
19601 }
19602 }
19603 }
19604
19605 for (i, (a, r)) in out_q.iter().zip(&out_ref).enumerate() {
19606 assert_eq!(a, r, "q_conv2d[{i}]: kernel {a} vs reference {r}");
19607 }
19608 }
19609
19610 #[test]
19618 fn q_matmul_matches_fake_quant_reference() {
19619 use rlx_ir::Philox4x32;
19620 let m = 3usize;
19621 let k = 8usize;
19622 let n = 5usize;
19623 let mut rng = Philox4x32::new(2031);
19624
19625 let x_scale = 0.05f32;
19627 let w_scale = 0.03f32;
19628 let out_scale = 0.4f32;
19629 let mult = x_scale * w_scale / out_scale;
19630 let mut xf = vec![0f32; m * k];
19631 rng.fill_normal(&mut xf);
19632 let mut wf = vec![0f32; k * n];
19633 rng.fill_normal(&mut wf);
19634 let xq: Vec<i8> = xf
19635 .iter()
19636 .map(|&v| ((v / x_scale).round() as i32).clamp(-128, 127) as i8)
19637 .collect();
19638 let wq: Vec<i8> = wf
19639 .iter()
19640 .map(|&v| ((v / w_scale).round() as i32).clamp(-128, 127) as i8)
19641 .collect();
19642 let bias: Vec<i32> = vec![0i32; n];
19643
19644 let _f = DType::F32;
19646 let mut g_q = Graph::new("qmm_direct");
19647 let xn = g_q.input("x", Shape::new(&[m, k], DType::I8));
19648 let wn = g_q.input("w", Shape::new(&[k, n], DType::I8));
19649 let bn = g_q.input("b", Shape::new(&[n], DType::I32));
19650 let out = g_q.q_matmul(xn, wn, bn, 0, 0, 0, mult, Shape::new(&[m, n], DType::I8));
19651 g_q.set_outputs(vec![out]);
19652 let plan = rlx_opt::memory::plan_memory(&g_q);
19653 let mut arena = crate::arena::Arena::from_plan(plan);
19654 let sched = compile_thunks(&g_q, &arena);
19655
19656 let xn_off = arena.byte_offset(xn);
19658 let wn_off = arena.byte_offset(wn);
19659 let bn_off = arena.byte_offset(bn);
19660 let out_off = arena.byte_offset(out);
19661 let buf = arena.raw_buf_mut();
19662 unsafe {
19663 let p = buf.as_mut_ptr().add(xn_off) as *mut i8;
19664 for (i, &v) in xq.iter().enumerate() {
19665 *p.add(i) = v;
19666 }
19667 let p = buf.as_mut_ptr().add(wn_off) as *mut i8;
19668 for (i, &v) in wq.iter().enumerate() {
19669 *p.add(i) = v;
19670 }
19671 let p = buf.as_mut_ptr().add(bn_off) as *mut i32;
19672 for (i, &v) in bias.iter().enumerate() {
19673 *p.add(i) = v;
19674 }
19675 }
19676 execute_thunks(&sched, arena.raw_buf_mut());
19677 let out_q: Vec<i8> = unsafe {
19678 let p = arena.raw_buf().as_ptr().add(out_off) as *const i8;
19679 (0..m * n).map(|i| *p.add(i)).collect()
19680 };
19681
19682 let mut out_ref = vec![0i8; m * n];
19687 for mi in 0..m {
19688 for ni in 0..n {
19689 let mut acc: i32 = 0;
19690 for ki in 0..k {
19691 acc += (xq[mi * k + ki] as i32) * (wq[ki * n + ni] as i32);
19692 }
19693 let r = (acc as f32 * mult).round() as i32;
19694 out_ref[mi * n + ni] = r.clamp(-128, 127) as i8;
19695 }
19696 }
19697
19698 for (i, (a, r)) in out_q.iter().zip(&out_ref).enumerate() {
19699 assert_eq!(a, r, "q_matmul[{i}]: kernel {a} vs reference {r}");
19700 }
19701 }
19702
19703 #[test]
19708 fn quantize_dequantize_round_trip() {
19709 use rlx_ir::Philox4x32;
19710 let len = 64;
19711 let mut rng = Philox4x32::new(2027);
19712 let mut x = vec![0f32; len];
19713 rng.fill_normal(&mut x);
19714 x[0] = 999.0;
19717 x[1] = -999.0;
19718
19719 let scale = 0.05f32;
19720 let zp = 3i32;
19721
19722 let f = DType::F32;
19723 let mut g = Graph::new("qdq");
19724 let xn = g.input("x", Shape::new(&[len], f));
19725 let q = g.quantize(xn, scale, zp);
19726 let dq = g.dequantize(q, scale, zp);
19727 g.set_outputs(vec![dq]);
19728
19729 let plan = rlx_opt::memory::plan_memory(&g);
19730 let mut arena = crate::arena::Arena::from_plan(plan);
19731 let sched = compile_thunks(&g, &arena);
19732 let xn_off = arena.byte_offset(xn);
19733 let dq_off = arena.byte_offset(dq);
19734 let buf = arena.raw_buf_mut();
19735 unsafe {
19736 let p = buf.as_mut_ptr().add(xn_off) as *mut f32;
19737 for (i, &v) in x.iter().enumerate() {
19738 *p.add(i) = v;
19739 }
19740 }
19741 execute_thunks(&sched, arena.raw_buf_mut());
19742 let out: Vec<f32> = unsafe {
19743 let p = arena.raw_buf().as_ptr().add(dq_off) as *const f32;
19744 (0..len).map(|i| *p.add(i)).collect()
19745 };
19746
19747 let sat_pos = (127 - zp) as f32 * scale;
19750 let sat_neg = (-128 - zp) as f32 * scale;
19751 assert!((out[0] - sat_pos).abs() < 1e-6, "+sat: {}", out[0]);
19752 assert!((out[1] - sat_neg).abs() < 1e-6, "-sat: {}", out[1]);
19753
19754 for i in 2..len {
19757 assert!(
19758 (out[i] - x[i]).abs() <= scale + 1e-5,
19759 "qdq[{i}]: {} → {}, scale={scale}",
19760 x[i],
19761 out[i]
19762 );
19763 }
19764 }
19765
#[test]
/// Round-trips a `[4, 5]` f32 tensor through per-channel quantize →
/// dequantize along axis 0, with one scale per channel and all zero-points 0.
///
/// Each channel row holds `[-m, 0, +m, +1000·m, -1000·m]` for that channel's
/// magnitude `m`, and the channel scale is `m / 127`. The first three entries
/// must come back within one quantization step of the input; the last two are
/// far out of range and must come back as `127 * scale` and `-128 * scale`.
fn quantize_per_channel_round_trip() {
    const CHANNELS: usize = 4;
    const PER_CHANNEL: usize = 5;
    let mags = [0.01f32, 0.5, 5.0, 50.0];

    // Row-major `[CHANNELS, PER_CHANNEL]` input, one row per magnitude.
    let x: Vec<f32> = mags
        .iter()
        .flat_map(|&m| [-m, 0.0, m, m * 1000.0, -m * 1000.0])
        .collect();
    let scales: Vec<f32> = mags.iter().map(|&m| m / 127.0).collect();
    let zps: Vec<i32> = vec![0; CHANNELS];

    let f = DType::F32;
    let mut g = Graph::new("qdq_pc");
    let xn = g.input("x", Shape::new(&[CHANNELS, PER_CHANNEL], f));
    let q = g.quantize_per_channel(xn, 0, scales.clone(), zps.clone());
    let dq = g.dequantize_per_channel(q, 0, scales.clone(), zps);
    g.set_outputs(vec![dq]);

    // Use the module's shared arena helpers, as the sibling tests do, instead
    // of hand-written raw-pointer loops.
    let (sched, mut arena) = prepare(&g, &[(xn, &x)]);
    execute_thunks(&sched, arena.raw_buf_mut());
    let out = read_arena(&arena, dq, CHANNELS * PER_CHANNEL);

    for (ci, &scale) in scales.iter().enumerate() {
        let row = ci * PER_CHANNEL;

        // In-range values: error bounded by one quantization step.
        for ii in 0..3 {
            let idx = row + ii;
            assert!(
                (out[idx] - x[idx]).abs() <= scale + 1e-5,
                "ch {ci} idx {ii}: {} vs {}",
                x[idx],
                out[idx]
            );
        }

        // Out-of-range values: must hit the int8 limits times the scale.
        let sat_pos = 127.0 * scale;
        let sat_neg = -128.0 * scale;
        assert!(
            (out[row + 3] - sat_pos).abs() < 1e-5,
            "ch {ci} +sat: {}",
            out[row + 3]
        );
        assert!(
            (out[row + 4] - sat_neg).abs() < 1e-5,
            "ch {ci} -sat: {}",
            out[row + 4]
        );
    }
}
19848
#[test]
/// Checks `activation_backward` for each listed activation kind against the
/// central-difference estimate `(f(x + eps) - f(x - eps)) / (2 * eps) * dy`.
///
/// `Log`, `Sqrt` and `Rsqrt` are fed inputs of the form `|x| + 0.5`, so the
/// reference functions are only evaluated on positive values. The remaining
/// kinds use unconstrained normal samples. Each case carries its own
/// finite-difference step and tolerance.
fn activation_backward_matches_numerical_per_kind() {
    use rlx_ir::Philox4x32;
    use rlx_ir::op::Activation;
    let mut rng = Philox4x32::new(91);
    let len = 32;

    // Draw order (x_pos, x_any, dy) is part of the fixture; keep it stable.
    let mut x_pos = vec![0f32; len];
    rng.fill_normal(&mut x_pos);
    for v in x_pos.iter_mut() {
        *v = v.abs() + 0.5;
    }
    let mut x_any = vec![0f32; len];
    rng.fill_normal(&mut x_any);
    let mut dy = vec![0f32; len];
    rng.fill_normal(&mut dy);

    // Scalar reference for every activation. It does not depend on the
    // per-case state, so it is built once rather than on every iteration.
    // The match is deliberately exhaustive, so adding a new `Activation`
    // variant forces this table to be revisited.
    let reference = |kind: Activation, x: f32| -> f32 {
        match kind {
            Activation::Sigmoid => 1.0 / (1.0 + (-x).exp()),
            Activation::Tanh => x.tanh(),
            Activation::Silu => x / (1.0 + (-x).exp()),
            Activation::Gelu => {
                const INV_SQRT2: f32 = 0.707_106_77;
                0.5 * x * (1.0 + erf_f32(x * INV_SQRT2))
            }
            Activation::GeluApprox => {
                const C: f32 = 0.797_884_6;
                const A: f32 = 0.044_715;
                let inner = C * (x + A * x * x * x);
                0.5 * x * (1.0 + inner.tanh())
            }
            Activation::Exp => x.exp(),
            Activation::Log => x.ln(),
            Activation::Sqrt => x.sqrt(),
            Activation::Rsqrt => 1.0 / x.sqrt(),
            Activation::Neg => -x,
            Activation::Relu => x.max(0.0),
            Activation::Abs => x.abs(),
            Activation::Round => x.round(),
            Activation::Sin => x.sin(),
            Activation::Cos => x.cos(),
            Activation::Tan => x.tan(),
            Activation::Atan => x.atan(),
        }
    };

    // (kind, input samples, finite-difference step, absolute tolerance)
    let cases: [(Activation, &[f32], f32, f32); 10] = [
        (Activation::Sigmoid, x_any.as_slice(), 1e-3, 5e-3),
        (Activation::Tanh, x_any.as_slice(), 1e-3, 5e-3),
        (Activation::Silu, x_any.as_slice(), 1e-3, 5e-3),
        (Activation::Gelu, x_any.as_slice(), 1e-3, 5e-3),
        (Activation::GeluApprox, x_any.as_slice(), 1e-3, 5e-3),
        (Activation::Exp, x_any.as_slice(), 1e-4, 5e-3),
        (Activation::Log, x_pos.as_slice(), 1e-4, 5e-3),
        (Activation::Sqrt, x_pos.as_slice(), 1e-4, 5e-3),
        (Activation::Rsqrt, x_pos.as_slice(), 1e-4, 5e-3),
        (Activation::Neg, x_any.as_slice(), 1e-3, 5e-4),
    ];

    for (kind, x_data, eps, tol) in cases {
        let f = DType::F32;
        let mut g = Graph::new("act_bw");
        let xn = g.input("x", Shape::new(&[len], f));
        let dyn_ = g.input("dy", Shape::new(&[len], f));
        let dx = g.activation_backward(kind, xn, dyn_);
        g.set_outputs(vec![dx]);

        // Use the module's shared arena helpers, as the sibling tests do,
        // instead of hand-written raw-pointer loops.
        let (sched, mut arena) = prepare(&g, &[(xn, x_data), (dyn_, dy.as_slice())]);
        execute_thunks(&sched, arena.raw_buf_mut());
        let analytical = read_arena(&arena, dx, len);

        for i in 0..len {
            let xv = x_data[i];
            let plus = reference(kind, xv + eps);
            let minus = reference(kind, xv - eps);
            let num = (plus - minus) / (2.0 * eps) * dy[i];
            assert!(
                (analytical[i] - num).abs() < tol,
                "{kind:?}[{i}]: analytical {} vs numerical {num}",
                analytical[i]
            );
        }
    }
}
19962
#[test]
/// Gradient check for a batched (rank-3) matmul.
///
/// The forward graph is `loss = sum(a @ b)` with `a: [2, 3, 4]` (input) and
/// `b: [2, 4, 5]` (param). The gradient w.r.t. `b` produced by
/// `grad_with_loss` (read from `outputs[1]`) is compared element-wise with
/// central differences of a naive triple-loop reference.
fn matmul_3d_gradient_matches_numerical() {
    use rlx_ir::Philox4x32;
    let (batch, m, k, n) = (2usize, 3usize, 4usize, 5usize);
    let mut rng = Philox4x32::new(101);
    let mut a_data = vec![0f32; batch * m * k];
    rng.fill_normal(&mut a_data);
    let mut b_data = vec![0f32; batch * k * n];
    rng.fill_normal(&mut b_data);

    let f = DType::F32;
    let mut fwd = Graph::new("matmul_3d");
    let an = fwd.input("a", Shape::new(&[batch, m, k], f));
    let bp = fwd.param("b", Shape::new(&[batch, k, n], f));
    let mm = fwd.matmul(an, bp, Shape::new(&[batch, m, n], f));
    let loss = fwd.add_node(
        Op::Reduce {
            op: ReduceOp::Sum,
            axes: vec![0, 1, 2],
            keep_dim: false,
        },
        vec![mm],
        Shape::from_dims(&[], f),
    );
    fwd.set_outputs(vec![loss]);

    let bwd_graph = rlx_opt::autodiff::grad_with_loss(&fwd, &[bp]);
    // The backward graph takes the upstream gradient through an input named
    // "d_output"; it is seeded with 1.0 below.
    let d_out = bwd_graph
        .nodes()
        .iter()
        .find(|node| matches!(&node.op, Op::Input { name } if name == "d_output"))
        .map(|node| node.id)
        .expect("backward graph should contain a `d_output` input");

    // Use the module's shared arena helpers, as the sibling tests do, instead
    // of hand-written raw-pointer loops.
    let (sched, mut arena) = prepare(
        &bwd_graph,
        &[(an, &a_data), (bp, &b_data), (d_out, &[1.0])],
    );
    execute_thunks(&sched, arena.raw_buf_mut());
    let g_b = read_arena(&arena, bwd_graph.outputs[1], batch * k * n);

    // Naive reference: per-batch matmul, then the sum of every output element.
    let forward_loss = |b_vals: &[f32]| -> f32 {
        let mut out = Vec::with_capacity(batch * m * n);
        for bi in 0..batch {
            let a_mat = &a_data[bi * m * k..(bi + 1) * m * k];
            let b_mat = &b_vals[bi * k * n..(bi + 1) * k * n];
            for mi in 0..m {
                for ni in 0..n {
                    let dot = (0..k).fold(0f32, |acc, ki| {
                        acc + a_mat[mi * k + ki] * b_mat[ki * n + ni]
                    });
                    out.push(dot);
                }
            }
        }
        out.iter().sum()
    };

    // Central differences, perturbing one element of `b` at a time.
    let eps = 1e-3f32;
    let mut probe = b_data.clone();
    let g_b_num: Vec<f32> = (0..b_data.len())
        .map(|i| {
            let saved = probe[i];
            probe[i] = saved + eps;
            let lp = forward_loss(&probe);
            probe[i] = saved - eps;
            let lm = forward_loss(&probe);
            probe[i] = saved;
            (lp - lm) / (2.0 * eps)
        })
        .collect();

    for (i, (a, n)) in g_b.iter().zip(&g_b_num).enumerate() {
        assert!(
            (a - n).abs() < 5e-3,
            "matmul_3d g_b[{i}]: analytical {a} vs numerical {n}"
        );
    }
}
20059
#[test]
/// Gradient check for `Op::Softmax` over the last axis of a `[3, 5]` input.
///
/// The loss is the sum of every softmax output. The gradient w.r.t. the
/// input produced by `grad_with_loss` (read from `outputs[1]`) is compared
/// with central differences of a max-subtracted reference softmax.
fn softmax_gradient_matches_numerical() {
    use rlx_ir::Philox4x32;
    let n = 3usize;
    let c = 5usize;
    let mut rng = Philox4x32::new(57);
    let mut x_data = vec![0f32; n * c];
    rng.fill_normal(&mut x_data);

    let f = DType::F32;
    let mut fwd = Graph::new("softmax_only");
    let xn = fwd.input("x", Shape::new(&[n, c], f));
    let sm = fwd.add_node(Op::Softmax { axis: -1 }, vec![xn], Shape::new(&[n, c], f));
    let loss = fwd.add_node(
        Op::Reduce {
            op: ReduceOp::Sum,
            axes: vec![0, 1],
            keep_dim: false,
        },
        vec![sm],
        Shape::from_dims(&[], f),
    );
    fwd.set_outputs(vec![loss]);

    let bwd_graph = rlx_opt::autodiff::grad_with_loss(&fwd, &[xn]);
    // The backward graph takes the upstream gradient through an input named
    // "d_output"; it is seeded with 1.0 below.
    let d_out = bwd_graph
        .nodes()
        .iter()
        .find(|node| matches!(&node.op, Op::Input { name } if name == "d_output"))
        .map(|node| node.id)
        .expect("backward graph should contain a `d_output` input");

    // Use the module's shared arena helpers, as the sibling tests do, instead
    // of hand-written raw-pointer loops.
    let (sched, mut arena) = prepare(&bwd_graph, &[(xn, &x_data), (d_out, &[1.0])]);
    execute_thunks(&sched, arena.raw_buf_mut());
    let g_x = read_arena(&arena, bwd_graph.outputs[1], n * c);

    // Reference loss: row-wise softmax (max-subtracted), summed over all
    // elements.
    let forward_loss = |x: &[f32]| -> f32 {
        let mut total = 0f32;
        for row in x.chunks_exact(c) {
            let m = row.iter().fold(f32::NEG_INFINITY, |a, &v| a.max(v));
            let denom: f32 = row.iter().map(|&v| (v - m).exp()).sum();
            for &v in row {
                total += (v - m).exp() / denom;
            }
        }
        total
    };

    // Central differences, perturbing one input element at a time.
    let eps = 1e-3f32;
    let mut probe = x_data.clone();
    for i in 0..x_data.len() {
        let saved = probe[i];
        probe[i] = saved + eps;
        let lp = forward_loss(&probe);
        probe[i] = saved - eps;
        let lm = forward_loss(&probe);
        probe[i] = saved;
        let num = (lp - lm) / (2.0 * eps);
        assert!(
            (g_x[i] - num).abs() < 5e-3,
            "softmax g_x[{i}]: analytical {} vs numerical {num}",
            g_x[i]
        );
    }
}
20155
#[test]
/// Gradient check for `Op::LayerNorm` over the last axis of a `[3, 6]` input
/// with learnable `gamma` and `beta`.
///
/// The loss is the sum of every layer-norm output. The gradients w.r.t.
/// `x`, `gamma` and `beta` (read from `outputs[1..=3]`, in the order they
/// were passed to `grad_with_loss`) are each compared with central
/// differences of a scalar reference implementation.
fn layer_norm_gradient_matches_numerical() {
    use rlx_ir::Philox4x32;

    /// Central-difference derivative of `loss` w.r.t. `base[i]`, step `h`.
    fn central_diff(base: &[f32], i: usize, h: f32, loss: impl Fn(&[f32]) -> f32) -> f32 {
        let mut probe = base.to_vec();
        probe[i] = base[i] + h;
        let lp = loss(&probe);
        probe[i] = base[i] - h;
        let lm = loss(&probe);
        (lp - lm) / (2.0 * h)
    }

    let rows = 3usize;
    let h = 6usize;
    let mut rng = Philox4x32::new(1009);
    let mut x_data = vec![0f32; rows * h];
    rng.fill_normal(&mut x_data);
    let mut g_data = vec![0f32; h];
    rng.fill_normal(&mut g_data);
    // Force gamma to be at least 0.5 in magnitude (and positive).
    for v in g_data.iter_mut() {
        *v = v.abs() + 0.5;
    }
    let mut b_data = vec![0f32; h];
    rng.fill_normal(&mut b_data);
    let eps = 1e-5f32;

    let f = DType::F32;
    let mut fwd = Graph::new("ln_only");
    let xn = fwd.input("x", Shape::new(&[rows, h], f));
    let gp = fwd.param("gamma", Shape::new(&[h], f));
    let bp = fwd.param("beta", Shape::new(&[h], f));
    let ln = fwd.add_node(
        Op::LayerNorm { axis: -1, eps },
        vec![xn, gp, bp],
        Shape::new(&[rows, h], f),
    );
    let loss = fwd.add_node(
        Op::Reduce {
            op: ReduceOp::Sum,
            axes: vec![0, 1],
            keep_dim: false,
        },
        vec![ln],
        Shape::from_dims(&[], f),
    );
    fwd.set_outputs(vec![loss]);

    let bwd_graph = rlx_opt::autodiff::grad_with_loss(&fwd, &[xn, gp, bp]);
    // The backward graph takes the upstream gradient through an input named
    // "d_output"; it is seeded with 1.0 below.
    let d_out = bwd_graph
        .nodes()
        .iter()
        .find(|node| matches!(&node.op, Op::Input { name } if name == "d_output"))
        .map(|node| node.id)
        .expect("backward graph should contain a `d_output` input");

    // Use the module's shared arena helpers, as the sibling tests do, instead
    // of hand-written raw-pointer loops.
    let (sched, mut arena) = prepare(
        &bwd_graph,
        &[
            (xn, &x_data),
            (gp, &g_data),
            (bp, &b_data),
            (d_out, &[1.0]),
        ],
    );
    execute_thunks(&sched, arena.raw_buf_mut());
    let dx_a = read_arena(&arena, bwd_graph.outputs[1], rows * h);
    let dg_a = read_arena(&arena, bwd_graph.outputs[2], h);
    let db_a = read_arena(&arena, bwd_graph.outputs[3], h);

    // Reference loss: per-row normalisation with biased variance, then
    // `gamma * x_hat + beta`, summed over all elements.
    let forward_loss = |x: &[f32], g: &[f32], b: &[f32]| -> f32 {
        let mut total = 0f32;
        for r in 0..rows {
            let row = &x[r * h..(r + 1) * h];
            let mean = row.iter().sum::<f32>() / h as f32;
            let var = row.iter().map(|&v| (v - mean) * (v - mean)).sum::<f32>() / h as f32;
            let inv_std = 1.0 / (var + eps).sqrt();
            for d in 0..h {
                total += ((row[d] - mean) * inv_std) * g[d] + b[d];
            }
        }
        total
    };
    let h_eps = 1e-3f32;

    for i in 0..x_data.len() {
        let num = central_diff(&x_data, i, h_eps, |x| forward_loss(x, &g_data, &b_data));
        assert!(
            (dx_a[i] - num).abs() < 5e-3,
            "ln dx[{i}]: analytical {} vs numerical {num}",
            dx_a[i]
        );
    }
    for i in 0..g_data.len() {
        let num = central_diff(&g_data, i, h_eps, |g| forward_loss(&x_data, g, &b_data));
        assert!(
            (dg_a[i] - num).abs() < 5e-3,
            "ln dg[{i}]: analytical {} vs numerical {num}",
            dg_a[i]
        );
    }
    for i in 0..b_data.len() {
        let num = central_diff(&b_data, i, h_eps, |b| forward_loss(&x_data, &g_data, b));
        assert!(
            (db_a[i] - num).abs() < 5e-3,
            "ln db[{i}]: analytical {} vs numerical {num}",
            db_a[i]
        );
    }
}
20297
#[test]
/// End-to-end gradient check for `matmul + bias → softmax cross-entropy →
/// reduce` on a `[4, 3] × [3, 5]` dense layer.
///
/// The backward graph's loss (`outputs[0]`) is first cross-checked against a
/// forward-only compile of the same graph. Then the gradients w.r.t. the
/// weight (`outputs[1]`) and bias (`outputs[2]`) are compared with central
/// differences obtained by re-running the forward schedule.
///
/// NOTE(review): despite "mean" in the name, the reduction built here is
/// `ReduceOp::Sum`; the sibling `dense_sce_mean_reduce_*` test uses
/// `ReduceOp::Mean`.
fn dense_sce_mean_gradient_matches_numerical() {
    use rlx_ir::Philox4x32;
    let (bs, k_in, c) = (4usize, 3usize, 5usize);
    let mut rng = Philox4x32::new(7);
    let mut x = vec![0f32; bs * k_in];
    rng.fill_normal(&mut x);
    let mut w_init = vec![0f32; k_in * c];
    rng.fill_normal(&mut w_init);
    let mut b_init = vec![0f32; c];
    rng.fill_normal(&mut b_init);
    // Class labels are carried as f32 values, cycling through 0..c.
    let labels: Vec<f32> = (0..bs).map(|i| (i % c) as f32).collect();

    let f = DType::F32;
    let mut fwd = Graph::new("dense_sce");
    let xn = fwd.input("x", Shape::new(&[bs, k_in], f));
    let lb = fwd.input("labels", Shape::new(&[bs], f));
    let wp = fwd.param("w", Shape::new(&[k_in, c], f));
    let bp = fwd.param("b", Shape::new(&[c], f));
    let mm = fwd.matmul(xn, wp, Shape::new(&[bs, c], f));
    let logits = fwd.binary(BinaryOp::Add, mm, bp, Shape::new(&[bs, c], f));
    let loss_per = fwd.softmax_cross_entropy_with_logits(logits, lb);
    let reduce_sum = Op::Reduce {
        op: ReduceOp::Sum,
        axes: vec![0],
        keep_dim: false,
    };
    let loss = fwd.add_node(reduce_sum, vec![loss_per], Shape::from_dims(&[], f));
    fwd.set_outputs(vec![loss]);

    let bwd_graph = rlx_opt::autodiff::grad_with_loss(&fwd, &[wp, bp]);
    let d_out = bwd_graph
        .nodes()
        .iter()
        .find(|node| matches!(&node.op, Op::Input { name } if name == "d_output"))
        .map(|node| node.id)
        .expect("d_output input");

    let (sched, mut arena) = prepare(
        &bwd_graph,
        &[
            (xn, &x),
            (lb, &labels),
            (wp, &w_init),
            (bp, &b_init),
            (d_out, &[1.0]),
        ],
    );
    execute_thunks(&sched, arena.raw_buf_mut());

    // Backward-graph outputs are laid out as [loss, grad_w, grad_b].
    let grads = &bwd_graph.outputs;
    let loss_actual = read_arena(&arena, grads[0], 1)[0];
    let gw_actual = read_arena(&arena, grads[1], k_in * c);
    let gb_actual = read_arena(&arena, grads[2], c);

    // Separate forward-only compile, used as the numerical oracle.
    let fwd_plan = rlx_opt::memory::plan_memory(&fwd);
    let mut fwd_arena = crate::arena::Arena::from_plan(fwd_plan);
    let fwd_sched = compile_thunks(&fwd, &fwd_arena);
    write_arena(&mut fwd_arena, xn, &x);
    write_arena(&mut fwd_arena, lb, &labels);

    // Uploads the given parameters, runs the forward schedule, returns loss.
    let run_loss = |arena: &mut crate::arena::Arena, w: &[f32], b: &[f32]| -> f32 {
        write_arena(arena, wp, w);
        write_arena(arena, bp, b);
        execute_thunks(&fwd_sched, arena.raw_buf_mut());
        read_arena(arena, loss, 1)[0]
    };

    let loss_check = run_loss(&mut fwd_arena, &w_init, &b_init);
    assert!(
        (loss_actual - loss_check).abs() < 1e-4,
        "loss mismatch: bwd graph {loss_actual} vs fwd-only {loss_check}"
    );

    let eps = 1e-3f32;

    // Weight sweep: perturb one entry at a time, restoring it afterwards.
    let mut w_probe = w_init.clone();
    let gw_numerical: Vec<f32> = (0..w_init.len())
        .map(|i| {
            let saved = w_probe[i];
            w_probe[i] = saved + eps;
            let lp = run_loss(&mut fwd_arena, &w_probe, &b_init);
            w_probe[i] = saved - eps;
            let lm = run_loss(&mut fwd_arena, &w_probe, &b_init);
            w_probe[i] = saved;
            (lp - lm) / (2.0 * eps)
        })
        .collect();
    for (i, (a, n)) in gw_actual.iter().zip(&gw_numerical).enumerate() {
        assert!(
            (a - n).abs() < 5e-3,
            "grad_w[{i}]: analytical {a} vs numerical {n}"
        );
    }

    // Bias sweep, same scheme.
    let mut b_probe = b_init.clone();
    let gb_numerical: Vec<f32> = (0..b_init.len())
        .map(|i| {
            let saved = b_probe[i];
            b_probe[i] = saved + eps;
            let lp = run_loss(&mut fwd_arena, &w_init, &b_probe);
            b_probe[i] = saved - eps;
            let lm = run_loss(&mut fwd_arena, &w_init, &b_probe);
            b_probe[i] = saved;
            (lp - lm) / (2.0 * eps)
        })
        .collect();
    for (i, (a, n)) in gb_actual.iter().zip(&gb_numerical).enumerate() {
        assert!(
            (a - n).abs() < 5e-3,
            "grad_b[{i}]: analytical {a} vs numerical {n}"
        );
    }
}
20438
#[test]
/// Gradient check for `matmul → softmax cross-entropy → ReduceOp::Mean` on a
/// `[3, 2] × [2, 4]` dense layer with no bias.
///
/// The weight gradient from the backward graph (`outputs[1]`) is compared
/// with central differences obtained by re-running a forward-only compile of
/// the same graph.
fn dense_sce_mean_reduce_gradient_matches_numerical() {
    use rlx_ir::Philox4x32;
    let (bs, k_in, c) = (3usize, 2usize, 4usize);
    let mut rng = Philox4x32::new(13);
    let mut x = vec![0f32; bs * k_in];
    rng.fill_normal(&mut x);
    let mut w_init = vec![0f32; k_in * c];
    rng.fill_normal(&mut w_init);
    // Class labels are carried as f32 values, cycling through 0..c.
    let labels: Vec<f32> = (0..bs).map(|i| (i % c) as f32).collect();

    let f = DType::F32;
    let mut fwd = Graph::new("dense_sce_mean");
    let xn = fwd.input("x", Shape::new(&[bs, k_in], f));
    let lb = fwd.input("labels", Shape::new(&[bs], f));
    let wp = fwd.param("w", Shape::new(&[k_in, c], f));
    let mm = fwd.matmul(xn, wp, Shape::new(&[bs, c], f));
    let loss_per = fwd.softmax_cross_entropy_with_logits(mm, lb);
    let reduce_mean = Op::Reduce {
        op: ReduceOp::Mean,
        axes: vec![0],
        keep_dim: false,
    };
    let loss = fwd.add_node(reduce_mean, vec![loss_per], Shape::from_dims(&[], f));
    fwd.set_outputs(vec![loss]);

    let bwd_graph = rlx_opt::autodiff::grad_with_loss(&fwd, &[wp]);
    let d_out = bwd_graph
        .nodes()
        .iter()
        .find(|node| matches!(&node.op, Op::Input { name } if name == "d_output"))
        .map(|node| node.id)
        .unwrap();

    let (sched, mut arena) = prepare(
        &bwd_graph,
        &[(xn, &x), (lb, &labels), (wp, &w_init), (d_out, &[1.0])],
    );
    execute_thunks(&sched, arena.raw_buf_mut());

    // Backward-graph outputs are laid out as [loss, grad_w]. The loss is read
    // but its value is not checked here.
    let grads = &bwd_graph.outputs;
    let _ = read_arena(&arena, grads[0], 1)[0];
    let gw_actual = read_arena(&arena, grads[1], k_in * c);

    // Separate forward-only compile, used as the numerical oracle.
    let fwd_plan = rlx_opt::memory::plan_memory(&fwd);
    let mut fwd_arena = crate::arena::Arena::from_plan(fwd_plan);
    let fwd_sched = compile_thunks(&fwd, &fwd_arena);
    write_arena(&mut fwd_arena, xn, &x);
    write_arena(&mut fwd_arena, lb, &labels);

    // Uploads the given weights, runs the forward schedule, returns the loss.
    let run_loss = |arena: &mut crate::arena::Arena, w: &[f32]| -> f32 {
        write_arena(arena, wp, w);
        execute_thunks(&fwd_sched, arena.raw_buf_mut());
        read_arena(arena, loss, 1)[0]
    };

    // Central differences, perturbing one weight at a time.
    let eps = 1e-3f32;
    let mut w_probe = w_init.clone();
    let gw_num: Vec<f32> = (0..w_init.len())
        .map(|i| {
            let saved = w_probe[i];
            w_probe[i] = saved + eps;
            let lp = run_loss(&mut fwd_arena, &w_probe);
            w_probe[i] = saved - eps;
            let lm = run_loss(&mut fwd_arena, &w_probe);
            w_probe[i] = saved;
            (lp - lm) / (2.0 * eps)
        })
        .collect();
    for (i, (a, n)) in gw_actual.iter().zip(&gw_num).enumerate() {
        assert!((a - n).abs() < 5e-3, "mean reduce grad_w[{i}]: {a} vs {n}");
    }
}
#[test]
/// End-to-end gradient check for a tiny CNN:
/// `conv 3×3 → + bias (reshape/expand) → ReLU → max-pool 2×2 → flatten →
/// matmul + bias → softmax cross-entropy → mean`.
///
/// The backward graph's loss (`outputs[0]`) is cross-checked against a
/// forward-only compile of the same graph. Then the gradients of all four
/// parameters (`outputs[1..=4]`, in the order passed to `grad_with_loss`) are
/// compared with central differences obtained by re-running the forward
/// schedule.
fn tinyconv_full_gradient_matches_numerical() {
    use rlx_ir::Philox4x32;
    let n = 1usize;
    let c_in = 1usize;
    let h = 6usize;
    let w_in = 6usize;
    let c_mid = 2usize;
    let kh = 3;
    let kw = 3;
    // Conv output extent for stride 1, no padding, no dilation.
    let h1 = h - kh + 1;
    let w1 = w_in - kw + 1;
    // 2×2 pooling with stride 2.
    let h2 = h1 / 2;
    let w2 = w1 / 2;
    let flat = c_mid * h2 * w2;
    let num_classes = 3usize;

    let mut rng = Philox4x32::new(31);
    let mut x = vec![0f32; n * c_in * h * w_in];
    rng.fill_normal(&mut x);
    let mut wc = vec![0f32; c_mid * c_in * kh * kw];
    rng.fill_normal(&mut wc);
    for v in wc.iter_mut() {
        *v *= 0.2;
    }
    // NOTE(review): the conv bias is large and positive — presumably to keep
    // pre-activations away from the ReLU kink so that finite differences stay
    // well-defined; confirm against the original intent.
    let bc: Vec<f32> = (0..c_mid).map(|i| 5.0 + 0.1 * i as f32).collect();
    let mut wfc = vec![0f32; flat * num_classes];
    rng.fill_normal(&mut wfc);
    for v in wfc.iter_mut() {
        *v *= 0.5;
    }
    let mut bfc = vec![0f32; num_classes];
    rng.fill_normal(&mut bfc);
    let labels: Vec<f32> = vec![1.0];

    let f = DType::F32;
    let mut fwd = Graph::new("tinyconv");
    let xn = fwd.input("x", Shape::new(&[n, c_in, h, w_in], f));
    let lb = fwd.input("labels", Shape::new(&[n], f));
    let wcp = fwd.param("wc", Shape::new(&[c_mid, c_in, kh, kw], f));
    let bcp = fwd.param("bc", Shape::new(&[c_mid], f));
    let wfp = fwd.param("wfc", Shape::new(&[flat, num_classes], f));
    let bfp = fwd.param("bfc", Shape::new(&[num_classes], f));

    let conv = fwd.add_node(
        Op::Conv {
            kernel_size: vec![kh, kw],
            stride: vec![1, 1],
            padding: vec![0, 0],
            dilation: vec![1, 1],
            groups: 1,
        },
        vec![xn, wcp],
        Shape::new(&[n, c_mid, h1, w1], f),
    );
    // The per-channel bias is reshaped to [1, C, 1, 1] and expanded to the
    // conv output shape before the add.
    let bc_4d = fwd.add_node(
        Op::Reshape {
            new_shape: vec![1, c_mid as i64, 1, 1],
        },
        vec![bcp],
        Shape::new(&[1, c_mid, 1, 1], f),
    );
    let bc_expanded = fwd.add_node(
        Op::Expand {
            target_shape: vec![n as i64, c_mid as i64, h1 as i64, w1 as i64],
        },
        vec![bc_4d],
        Shape::new(&[n, c_mid, h1, w1], f),
    );
    let conv_b = fwd.binary(
        BinaryOp::Add,
        conv,
        bc_expanded,
        Shape::new(&[n, c_mid, h1, w1], f),
    );
    let relu = fwd.activation(Activation::Relu, conv_b, Shape::new(&[n, c_mid, h1, w1], f));
    let pool = fwd.add_node(
        Op::Pool {
            kind: ReduceOp::Max,
            kernel_size: vec![2, 2],
            stride: vec![2, 2],
            padding: vec![0, 0],
        },
        vec![relu],
        Shape::new(&[n, c_mid, h2, w2], f),
    );
    let flatn = fwd.add_node(
        Op::Reshape {
            new_shape: vec![n as i64, flat as i64],
        },
        vec![pool],
        Shape::new(&[n, flat], f),
    );
    let mm = fwd.matmul(flatn, wfp, Shape::new(&[n, num_classes], f));
    let logits = fwd.binary(BinaryOp::Add, mm, bfp, Shape::new(&[n, num_classes], f));
    let loss_per = fwd.softmax_cross_entropy_with_logits(logits, lb);
    let loss = fwd.add_node(
        Op::Reduce {
            op: ReduceOp::Mean,
            axes: vec![0],
            keep_dim: false,
        },
        vec![loss_per],
        Shape::from_dims(&[], f),
    );
    fwd.set_outputs(vec![loss]);

    let bwd_graph = rlx_opt::autodiff::grad_with_loss(&fwd, &[wcp, bcp, wfp, bfp]);
    let d_out = bwd_graph
        .nodes()
        .iter()
        .find(|node| matches!(&node.op, Op::Input { name } if name == "d_output"))
        .map(|node| node.id)
        .expect("backward graph should contain a `d_output` input");

    let (sched, mut arena) = prepare(
        &bwd_graph,
        &[
            (xn, &x),
            (lb, &labels),
            (wcp, &wc),
            (bcp, &bc),
            (wfp, &wfc),
            (bfp, &bfc),
            (d_out, &[1.0]),
        ],
    );
    execute_thunks(&sched, arena.raw_buf_mut());

    // Backward-graph outputs: [loss, g_wc, g_bc, g_wfc, g_bfc].
    let outs = &bwd_graph.outputs;
    let loss_actual = read_arena(&arena, outs[0], 1)[0];
    let g_wc = read_arena(&arena, outs[1], wc.len());
    let g_bc = read_arena(&arena, outs[2], bc.len());
    let g_wfc = read_arena(&arena, outs[3], wfc.len());
    let g_bfc = read_arena(&arena, outs[4], bfc.len());

    // Separate forward-only compile, used as the numerical oracle.
    let plan = rlx_opt::memory::plan_memory(&fwd);
    let mut fwd_arena = crate::arena::Arena::from_plan(plan);
    let fwd_sched = compile_thunks(&fwd, &fwd_arena);
    write_arena(&mut fwd_arena, xn, &x);
    write_arena(&mut fwd_arena, lb, &labels);

    // Uploads all four parameters, runs the forward schedule, returns loss.
    let run_loss = |arena: &mut crate::arena::Arena,
                    wc_v: &[f32],
                    bc_v: &[f32],
                    wfc_v: &[f32],
                    bfc_v: &[f32]|
     -> f32 {
        write_arena(arena, wcp, wc_v);
        write_arena(arena, bcp, bc_v);
        write_arena(arena, wfp, wfc_v);
        write_arena(arena, bfp, bfc_v);
        execute_thunks(&fwd_sched, arena.raw_buf_mut());
        read_arena(arena, loss, 1)[0]
    };

    let loss_check = run_loss(&mut fwd_arena, &wc, &bc, &wfc, &bfc);
    assert!(
        (loss_actual - loss_check).abs() < 1e-4,
        "tinyconv loss mismatch: bwd {loss_actual} vs fwd {loss_check}"
    );

    let eps = 1e-3f32;
    // Shared check: central difference from the two perturbed losses versus
    // the analytical gradient entry.
    let expect_close = |name: &str, i: usize, analytical: f32, lp: f32, lm: f32| {
        let num = (lp - lm) / (2.0 * eps);
        assert!(
            (analytical - num).abs() < 5e-3,
            "{name}[{i}]: {analytical} vs {num}"
        );
    };

    // Each sweep perturbs one entry of one parameter and restores it, so a
    // single scratch copy per parameter is enough.
    let mut p = wc.clone();
    for i in 0..wc.len() {
        let s = p[i];
        p[i] = s + eps;
        let lp = run_loss(&mut fwd_arena, &p, &bc, &wfc, &bfc);
        p[i] = s - eps;
        let lm = run_loss(&mut fwd_arena, &p, &bc, &wfc, &bfc);
        p[i] = s;
        expect_close("g_wc", i, g_wc[i], lp, lm);
    }

    let mut p = bc.clone();
    for i in 0..bc.len() {
        let s = p[i];
        p[i] = s + eps;
        let lp = run_loss(&mut fwd_arena, &wc, &p, &wfc, &bfc);
        p[i] = s - eps;
        let lm = run_loss(&mut fwd_arena, &wc, &p, &wfc, &bfc);
        p[i] = s;
        expect_close("g_bc", i, g_bc[i], lp, lm);
    }

    let mut p = wfc.clone();
    for i in 0..wfc.len() {
        let s = p[i];
        p[i] = s + eps;
        let lp = run_loss(&mut fwd_arena, &wc, &bc, &p, &bfc);
        p[i] = s - eps;
        let lm = run_loss(&mut fwd_arena, &wc, &bc, &p, &bfc);
        p[i] = s;
        expect_close("g_wfc", i, g_wfc[i], lp, lm);
    }

    let mut p = bfc.clone();
    for i in 0..bfc.len() {
        let s = p[i];
        p[i] = s + eps;
        let lp = run_loss(&mut fwd_arena, &wc, &bc, &wfc, &p);
        p[i] = s - eps;
        let lm = run_loss(&mut fwd_arena, &wc, &bc, &wfc, &p);
        p[i] = s;
        expect_close("g_bfc", i, g_bfc[i], lp, lm);
    }
}
20818
#[test]
/// Builds a graph in which one `narrow_` result feeds both a `rope` node and
/// a ReLU, then asserts that the compiled schedule still contains at least
/// one `Thunk::Narrow`.
fn narrow_rope_skips_when_narrow_has_multiple_consumers() {
    let dtype = DType::F32;
    let mut graph = Graph::new("nr_skip");
    let qkv = graph.input("qkv", Shape::new(&[16, 8, 192], dtype));
    let cos = graph.input("cos", Shape::new(&[16], dtype));
    let sin = graph.input("sin", Shape::new(&[16], dtype));

    // `q` has two consumers: the rope and the activation.
    let q = graph.narrow_(qkv, 2, 0, 64);
    let q_rope = graph.rope(q, cos, sin, 16);
    let q_dup = graph.activation(
        rlx_ir::op::Activation::Relu,
        q,
        Shape::new(&[16, 8, 64], dtype),
    );
    graph.set_outputs(vec![q_rope, q_dup]);

    let arena = crate::arena::Arena::from_plan(rlx_opt::memory::plan_memory(&graph));
    let sched = compile_thunks(&graph, &arena);

    let has_narrow = sched
        .thunks
        .iter()
        .any(|t| matches!(t, Thunk::Narrow { .. }));
    assert!(
        has_narrow,
        "Narrow with multiple consumers must NOT be fused away"
    );
}
20849
#[test]
/// Wraps a body graph computing `x + 1` (element-wise, 3 × f32) in a
/// `custom_fn` node with both optional arguments left as `None`, runs the
/// outer graph on `[10, 20, 30]`, and expects `[11, 21, 31]`.
fn custom_fn_forward_inlines_body() {
    let shape = Shape::new(&[3], DType::F32);

    // Body: y = x + 1, with the constant stored as little-endian f32 bytes.
    let mut body = Graph::new("addone_body");
    let x = body.input("x", shape.clone());
    let ones_bytes: Vec<u8> = 1.0_f32.to_le_bytes().repeat(3);
    let one = body.add_node(Op::Constant { data: ones_bytes }, vec![], shape.clone());
    let y = body.binary(BinaryOp::Add, x, one, shape.clone());
    body.set_outputs(vec![y]);

    // Outer graph: a single custom_fn node fed by the input.
    let mut outer = Graph::new("custom_fn_outer");
    let xin = outer.input("x_in", shape.clone());
    let cf = outer.custom_fn(vec![xin], body, None, None);
    outer.set_outputs(vec![cf]);

    let xs = vec![10.0_f32, 20.0, 30.0];
    let (sched, mut arena) = prepare(&outer, &[(xin, &xs)]);
    execute_thunks(&sched, arena.raw_buf_mut());

    let got = read_arena(&arena, cf, 3);
    assert_eq!(got, vec![11.0, 21.0, 31.0]);
}
20885
20886 fn find_named(graph: &Graph, want: &str) -> NodeId {
20888 for n in graph.nodes() {
20889 let name = match &n.op {
20890 Op::Input { name } | Op::Param { name } => Some(name.as_str()),
20891 _ => None,
20892 };
20893 if name == Some(want) {
20894 return n.id;
20895 }
20896 }
20897 panic!("no node named {want:?} in graph");
20898 }
20899
#[test]
/// A `custom_fn` whose forward body is the identity but whose supplied vjp
/// graph returns `2 * d_output`.
///
/// Differentiating the outer graph must use the supplied vjp: with `x = 7`
/// and `d_output = 1`, the loss is 7 and the reported gradient is 2.
fn custom_fn_vjp_overrides_natural_gradient() {
    use rlx_opt::autodiff::grad_with_loss;
    let shape = Shape::new(&[1], DType::F32);

    // Forward body: identity.
    let mut fwd = Graph::new("id_fwd");
    let x = fwd.input("x", shape.clone());
    fwd.set_outputs(vec![x]);

    // Custom vjp: dx = d_output * 2. The primal input and primal output are
    // declared as inputs but not used.
    let mut vjp_g = Graph::new("id_vjp");
    let _x_p = vjp_g.input("x", shape.clone());
    let _y_p = vjp_g.input("primal_output", shape.clone());
    let dy = vjp_g.input("d_output", shape.clone());
    let two_bytes = Vec::from(2.0_f32.to_le_bytes());
    let two = vjp_g.add_node(Op::Constant { data: two_bytes }, vec![], shape.clone());
    let dx = vjp_g.binary(BinaryOp::Mul, dy, two, shape.clone());
    vjp_g.set_outputs(vec![dx]);

    let mut outer = Graph::new("outer");
    let xp = outer.param("x", shape.clone());
    let cf = outer.custom_fn(vec![xp], fwd, Some(vjp_g), None);
    outer.set_outputs(vec![cf]);

    let bwd = grad_with_loss(&outer, &[xp]);
    assert_eq!(bwd.outputs.len(), 2, "expect [loss, dx]");

    let x_node = find_named(&bwd, "x");
    let seed_node = find_named(&bwd, "d_output");
    let (sched, mut arena) = prepare(&bwd, &[(x_node, &[7.0]), (seed_node, &[1.0])]);
    execute_thunks(&sched, arena.raw_buf_mut());

    let loss = read_arena(&arena, bwd.outputs[0], 1);
    let dx_v = read_arena(&arena, bwd.outputs[1], 1);
    assert!((loss[0] - 7.0).abs() < 1e-6, "loss should be 7.0");
    assert!(
        (dx_v[0] - 2.0).abs() < 1e-6,
        "vjp override should yield dx=2.0, got {} (natural autodiff would give 1.0)",
        dx_v[0]
    );
}
20942
/// Two-input `custom_fn` computing `a * b` with a hand-written VJP
/// (`da = b * dy`, `db = a * dy`): for a = 3, b = 5 and a unit seed the
/// backward graph must yield loss = 15, da = 5, db = 3.
#[test]
fn custom_fn_vjp_two_inputs_matches_mul_autodiff() {
    use rlx_opt::autodiff::grad_with_loss;

    let scalar = Shape::new(&[1], DType::F32);

    // Forward body: y = a * b.
    let forward = {
        let mut body = Graph::new("mul_fwd");
        let lhs = body.input("a", scalar.clone());
        let rhs = body.input("b", scalar.clone());
        let product = body.binary(BinaryOp::Mul, lhs, rhs, scalar.clone());
        body.set_outputs(vec![product]);
        body
    };

    // VJP body: inputs declared as (a, b, primal_output, d_output);
    // outputs are [b * d_output, a * d_output].
    let backward = {
        let mut body = Graph::new("mul_vjp");
        let lhs = body.input("a", scalar.clone());
        let rhs = body.input("b", scalar.clone());
        let _primal_out = body.input("primal_output", scalar.clone());
        let upstream = body.input("d_output", scalar.clone());
        let grad_lhs = body.binary(BinaryOp::Mul, rhs, upstream, scalar.clone());
        let grad_rhs = body.binary(BinaryOp::Mul, lhs, upstream, scalar.clone());
        body.set_outputs(vec![grad_lhs, grad_rhs]);
        body
    };

    let mut outer = Graph::new("outer");
    let a_param = outer.param("a", scalar.clone());
    let b_param = outer.param("b", scalar.clone());
    let call = outer.custom_fn(vec![a_param, b_param], forward, Some(backward), None);
    outer.set_outputs(vec![call]);

    let grad_graph = grad_with_loss(&outer, &[a_param, b_param]);
    assert_eq!(grad_graph.outputs.len(), 3, "expect [loss, da, db]");

    let a_id = find_named(&grad_graph, "a");
    let b_id = find_named(&grad_graph, "b");
    let seed_id = find_named(&grad_graph, "d_output");
    let (sched, mut arena) = prepare(
        &grad_graph,
        &[(a_id, &[3.0]), (b_id, &[5.0]), (seed_id, &[1.0])],
    );
    execute_thunks(&sched, arena.raw_buf_mut());
    let loss = read_arena(&arena, grad_graph.outputs[0], 1);
    let grad_a = read_arena(&arena, grad_graph.outputs[1], 1);
    let grad_b = read_arena(&arena, grad_graph.outputs[2], 1);

    assert!((loss[0] - 15.0).abs() < 1e-5);
    assert!(
        (grad_a[0] - 5.0).abs() < 1e-5,
        "da should be b=5.0, got {}",
        grad_a[0]
    );
    assert!(
        (grad_b[0] - 3.0).abs() < 1e-5,
        "db should be a=3.0, got {}",
        grad_b[0]
    );
}
20996
/// A `custom_fn` with an identity forward body and a JVP subgraph returning
/// `2 * tangent_0`: forward-mode `jvp` must emit the overridden tangent (2.0)
/// alongside the unchanged primal (7.0).
#[test]
fn custom_fn_jvp_overrides_natural_tangent() {
    use rlx_opt::autodiff_fwd::jvp;

    let scalar = Shape::new(&[1], DType::F32);

    // Forward body: y = x.
    let forward = {
        let mut body = Graph::new("id_fwd");
        let x = body.input("x", scalar.clone());
        body.set_outputs(vec![x]);
        body
    };

    // JVP body: inputs declared as (x, tangent_0); output is tangent_0 * 2.
    let tangent_rule = {
        let mut body = Graph::new("id_jvp");
        let _primal_in = body.input("x", scalar.clone());
        let tangent_in = body.input("tangent_0", scalar.clone());
        let two_bytes: Vec<u8> = 2.0_f32.to_le_bytes().to_vec();
        let two = body.add_node(Op::Constant { data: two_bytes }, vec![], scalar.clone());
        let tangent_out = body.binary(BinaryOp::Mul, tangent_in, two, scalar.clone());
        body.set_outputs(vec![tangent_out]);
        body
    };

    let mut outer = Graph::new("outer");
    let x_in = outer.input("x_in", scalar.clone());
    let call = outer.custom_fn(vec![x_in], forward, None, Some(tangent_rule));
    outer.set_outputs(vec![call]);

    let jvp_graph = jvp(&outer, &[x_in]);
    assert_eq!(jvp_graph.outputs.len(), 2, "expect [primal_y, tangent_y]");

    // Feed x = 7 with a unit tangent, then read back [primal_y, tangent_y].
    let primal_id = find_named(&jvp_graph, "x_in");
    let tangent_id = find_named(&jvp_graph, "tangent_x_in");
    let (sched, mut arena) = prepare(&jvp_graph, &[(primal_id, &[7.0]), (tangent_id, &[1.0])]);
    execute_thunks(&sched, arena.raw_buf_mut());
    let primal_y = read_arena(&arena, jvp_graph.outputs[0], 1);
    let tangent_y = read_arena(&arena, jvp_graph.outputs[1], 1);

    assert!((primal_y[0] - 7.0).abs() < 1e-6);
    assert!(
        (tangent_y[0] - 2.0).abs() < 1e-6,
        "jvp override should yield t_y=2.0 (natural autodiff would give 1.0), got {}",
        tangent_y[0]
    );
}
21037
/// Pins the storage facts the C64 tests rely on: 8 bytes per element,
/// classified as complex (and not as float), so a 2-element tensor is 16 bytes.
#[test]
fn c64_dtype_storage_layout() {
    let c64 = DType::C64;
    assert_eq!(
        c64.size_bytes(),
        8,
        "C64 should be 8 bytes (f32 real + f32 imag)"
    );
    assert!(c64.is_complex());
    assert!(!c64.is_float());

    let pair = Shape::new(&[2], c64);
    assert_eq!(pair.size_bytes().unwrap(), 16);
}
21056
21057 fn run_c64_binary(op: BinaryOp, a: &[(f32, f32)], b: &[(f32, f32)]) -> Vec<(f32, f32)> {
21064 let n = a.len();
21065 let s = Shape::new(&[n], DType::C64);
21066 let mut g = Graph::new("c64_bin");
21067 let in_a = g.input("a", s.clone());
21068 let in_b = g.input("b", s.clone());
21069 let out = g.binary(op, in_a, in_b, s.clone());
21070 g.set_outputs(vec![out]);
21071
21072 let plan = rlx_opt::memory::plan_memory(&g);
21073 let mut arena = crate::arena::Arena::from_plan(plan);
21074 let sched = compile_thunks(&g, &arena);
21075
21076 let a_off = arena.byte_offset(in_a);
21077 let b_off = arena.byte_offset(in_b);
21078 let out_off = arena.byte_offset(out);
21079 let buf = arena.raw_buf_mut();
21081 unsafe {
21082 let pa = buf.as_mut_ptr().add(a_off) as *mut f32;
21083 let pb = buf.as_mut_ptr().add(b_off) as *mut f32;
21084 for (i, &(re, im)) in a.iter().enumerate() {
21085 *pa.add(2 * i) = re;
21086 *pa.add(2 * i + 1) = im;
21087 }
21088 for (i, &(re, im)) in b.iter().enumerate() {
21089 *pb.add(2 * i) = re;
21090 *pb.add(2 * i + 1) = im;
21091 }
21092 }
21093 execute_thunks(&sched, arena.raw_buf_mut());
21094 let raw_out: Vec<f32> = unsafe {
21095 let p = arena.raw_buf().as_ptr().add(out_off) as *const f32;
21096 (0..(2 * n)).map(|i| *p.add(i)).collect()
21097 };
21098 (0..n)
21099 .map(|i| (raw_out[2 * i], raw_out[2 * i + 1]))
21100 .collect()
21101 }
21102
/// Asserts that a complex value `got` matches `expected` component-wise:
/// both the real and the imaginary difference must be strictly below `tol`.
///
/// `label` is echoed in the panic message; `#[track_caller]` makes the panic
/// point at the calling assertion rather than at this helper.
#[track_caller]
fn assert_close_c(got: (f32, f32), expected: (f32, f32), tol: f32, label: &str) {
    let (got_re, got_im) = got;
    let (want_re, want_im) = expected;
    let re_ok = (got_re - want_re).abs() < tol;
    let im_ok = (got_im - want_im).abs() < tol;
    assert!(
        re_ok && im_ok,
        "[{label}] got ({:+.4}, {:+.4}), expected ({:+.4}, {:+.4})",
        got_re,
        got_im,
        want_re,
        want_im
    );
}
21116
/// Complex addition is component-wise:
/// (1+2i)+(4-i) = 5+i and (3-i)+(0.5+0.5i) = 3.5-0.5i.
#[test]
fn c64_binary_add_matches_complex_arithmetic() {
    let lhs = [(1.0_f32, 2.0_f32), (3.0_f32, -1.0_f32)];
    let rhs = [(4.0_f32, -1.0_f32), (0.5_f32, 0.5_f32)];
    let sums = run_c64_binary(BinaryOp::Add, &lhs, &rhs);
    assert_close_c(sums[0], (5.0, 1.0), 1e-6, "add[0]");
    assert_close_c(sums[1], (3.5, -0.5), 1e-6, "add[1]");
}
21125
/// Complex subtraction is component-wise: (5+i) - (2+3i) = 3-2i.
#[test]
fn c64_binary_sub_matches_complex_arithmetic() {
    let minuend = [(5.0_f32, 1.0_f32)];
    let subtrahend = [(2.0_f32, 3.0_f32)];
    let diff = run_c64_binary(BinaryOp::Sub, &minuend, &subtrahend);
    assert_close_c(diff[0], (3.0, -2.0), 1e-6, "sub");
}
21133
/// Complex product: (1+2i)(3+4i) = (3-8) + (4+6)i = -5+10i.
#[test]
fn c64_binary_mul_matches_complex_arithmetic() {
    let lhs = [(1.0_f32, 2.0_f32)];
    let rhs = [(3.0_f32, 4.0_f32)];
    let product = run_c64_binary(BinaryOp::Mul, &lhs, &rhs);
    assert_close_c(product[0], (-5.0, 10.0), 1e-5, "mul");
}
21142
/// Complex quotient: (1+2i)/(3+4i) = (1+2i)(3-4i)/25 = (11+2i)/25 = 0.44+0.08i.
#[test]
fn c64_binary_div_matches_complex_arithmetic() {
    let numerator = [(1.0_f32, 2.0_f32)];
    let denominator = [(3.0_f32, 4.0_f32)];
    let quotient = run_c64_binary(BinaryOp::Div, &numerator, &denominator);
    assert_close_c(quotient[0], (0.44, 0.08), 1e-5, "div");
}
21153
/// Multiplying by 1+0i must return each operand unchanged.
#[test]
fn c64_binary_mul_identity_one_is_no_op() {
    let values = [(3.5_f32, -1.25_f32), (-2.0_f32, 7.0_f32)];
    let ones = [(1.0_f32, 0.0_f32), (1.0_f32, 0.0_f32)];
    let product = run_c64_binary(BinaryOp::Mul, &values, &ones);
    assert_close_c(product[0], values[0], 1e-6, "mul·1[0]");
    assert_close_c(product[1], values[1], 1e-6, "mul·1[1]");
}
21163
/// 1 · i = i: multiplying the real unit by the imaginary unit yields (0, 1).
#[test]
fn c64_binary_mul_by_i_rotates_90_degrees() {
    let real_unit = [(1.0_f32, 0.0_f32)];
    let imag_unit = [(0.0_f32, 1.0_f32)];
    let rotated = run_c64_binary(BinaryOp::Mul, &real_unit, &imag_unit);
    assert_close_c(rotated[0], (0.0, 1.0), 1e-6, "1·i");
}
21172
/// z / z must be 1+0i for each (non-zero) sample.
#[test]
fn c64_binary_div_by_self_gives_unity() {
    let samples = [(2.5_f32, -1.5_f32), (-0.7_f32, 4.2_f32)];
    let ratios = run_c64_binary(BinaryOp::Div, &samples, &samples);
    let unity = (1.0, 0.0);
    assert_close_c(ratios[0], unity, 1e-5, "div_self[0]");
    assert_close_c(ratios[1], unity, 1e-5, "div_self[1]");
}
21180
/// Running `BinaryOp::Max` on C64 operands is expected to panic with a
/// message containing "C64: complex max/min/pow".
#[test]
#[should_panic(expected = "C64: complex max/min/pow")]
fn c64_binary_max_is_rejected_at_lowering() {
    let lhs = [(1.0_f32, 2.0_f32)];
    let rhs = [(3.0_f32, 4.0_f32)];
    // The result is irrelevant: the call itself must panic.
    let _ = run_c64_binary(BinaryOp::Max, &lhs, &rhs);
}
21186
21187 fn run_c64_activation(act: Activation, a: &[(f32, f32)]) -> Vec<(f32, f32)> {
21188 let n = a.len();
21189 let s = Shape::new(&[n], DType::C64);
21190 let mut g = Graph::new("c64_act");
21191 let in_a = g.input("a", s.clone());
21192 let out = g.activation(act, in_a, s.clone());
21193 g.set_outputs(vec![out]);
21194 let plan = rlx_opt::memory::plan_memory(&g);
21195 let mut arena = crate::arena::Arena::from_plan(plan);
21196 let sched = compile_thunks(&g, &arena);
21197 let a_off = arena.byte_offset(in_a);
21198 let out_off = arena.byte_offset(out);
21199 let buf = arena.raw_buf_mut();
21200 unsafe {
21201 let pa = buf.as_mut_ptr().add(a_off) as *mut f32;
21202 for (i, &(re, im)) in a.iter().enumerate() {
21203 *pa.add(2 * i) = re;
21204 *pa.add(2 * i + 1) = im;
21205 }
21206 }
21207 execute_thunks(&sched, arena.raw_buf_mut());
21208 let raw: Vec<f32> = unsafe {
21209 let p = arena.raw_buf().as_ptr().add(out_off) as *const f32;
21210 (0..(2 * n)).map(|i| *p.add(i)).collect()
21211 };
21212 (0..n).map(|i| (raw[2 * i], raw[2 * i + 1])).collect()
21213 }
21214
/// `Neg` flips the sign of both the real and the imaginary part.
#[test]
fn c64_activation_neg_negates_both_components() {
    let samples = [(3.5_f32, -1.25_f32), (-2.0_f32, 0.0_f32)];
    let negated = run_c64_activation(Activation::Neg, &samples);
    assert_close_c(negated[0], (-3.5, 1.25), 1e-6, "neg[0]");
    assert_close_c(negated[1], (2.0, 0.0), 1e-6, "neg[1]");
}
21222
/// Complex exponential: exp(iπ) = -1 (Euler's identity) and exp(1) = e.
#[test]
fn c64_activation_exp_matches_euler() {
    use std::f32::consts::{E, PI};

    let samples = [(0.0_f32, PI), (1.0_f32, 0.0_f32)];
    let exps = run_c64_activation(Activation::Exp, &samples);
    assert_close_c(exps[0], (-1.0, 0.0), 1e-5, "exp(iπ)");
    assert_close_c(exps[1], (E, 0.0), 1e-5, "exp(1)");
}
21232
/// Complex logarithm on the unit circle: log(1) = 0, log(i) = iπ/2,
/// log(-1) = iπ (the expected imaginary parts lie in (-π, π]).
#[test]
fn c64_activation_log_matches_principal_branch() {
    use std::f32::consts::{FRAC_PI_2, PI};

    let samples = [(1.0_f32, 0.0_f32), (0.0_f32, 1.0_f32), (-1.0_f32, 0.0_f32)];
    let logs = run_c64_activation(Activation::Log, &samples);
    assert_close_c(logs[0], (0.0, 0.0), 1e-5, "log(1)");
    assert_close_c(logs[1], (0.0, FRAC_PI_2), 1e-5, "log(i)");
    assert_close_c(logs[2], (0.0, PI), 1e-5, "log(-1)");
}
21244
/// Complex square root: checks the known roots sqrt(4) = 2 and
/// sqrt(3+4i) = 2+i, then squares each returned root and verifies it
/// reproduces the original input (the round trip the test name promises).
#[test]
fn c64_activation_sqrt_squared_recovers_input() {
    let inp = [(4.0_f32, 0.0_f32), (3.0_f32, 4.0_f32)];
    let roots = run_c64_activation(Activation::Sqrt, &inp);
    assert_close_c(roots[0], (2.0, 0.0), 1e-5, "sqrt(4)");
    assert_close_c(roots[1], (2.0, 1.0), 1e-5, "sqrt(3+4i)");

    // Round trip: (re + im·i)² = (re² - im²) + (2·re·im)i must give back the
    // input. Squaring roughly doubles the relative error of the root, hence
    // the looser 1e-4 tolerance.
    for (i, (&(re, im), &z)) in roots.iter().zip(inp.iter()).enumerate() {
        let squared = (re * re - im * im, 2.0 * re * im);
        assert_close_c(squared, z, 1e-4, &format!("sqrt²[{i}]"));
    }
}
21255
/// Running `Activation::Relu` on a C64 input is expected to panic with a
/// message containing "no natural complex extension".
#[test]
#[should_panic(expected = "no natural complex extension")]
fn c64_activation_relu_is_rejected_at_lowering() {
    let sample = [(1.0_f32, 2.0_f32)];
    // The result is irrelevant: the call itself must panic.
    let _ = run_c64_activation(Activation::Relu, &sample);
}
21261
21262 fn run_complex_norm_sq(z: &[(f32, f32)]) -> Vec<f32> {
21266 let n = z.len();
21267 let mut g = Graph::new("cns_fwd");
21268 let in_z = g.input("z", Shape::new(&[n], DType::C64));
21269 let out = g.complex_norm_sq(in_z);
21270 g.set_outputs(vec![out]);
21271 let plan = rlx_opt::memory::plan_memory(&g);
21272 let mut arena = crate::arena::Arena::from_plan(plan);
21273 let sched = compile_thunks(&g, &arena);
21274 let z_off = arena.byte_offset(in_z);
21275 let out_off = arena.byte_offset(out);
21276 let buf = arena.raw_buf_mut();
21277 unsafe {
21278 let pz = buf.as_mut_ptr().add(z_off) as *mut f32;
21279 for (i, &(re, im)) in z.iter().enumerate() {
21280 *pz.add(2 * i) = re;
21281 *pz.add(2 * i + 1) = im;
21282 }
21283 }
21284 execute_thunks(&sched, arena.raw_buf_mut());
21285 unsafe {
21286 let p = arena.raw_buf().as_ptr().add(out_off) as *const f32;
21287 (0..n).map(|i| *p.add(i)).collect()
21288 }
21289 }
21290
21291 fn run_complex_norm_sq_bwd(z: &[(f32, f32)], g: &[f32]) -> Vec<(f32, f32)> {
21293 let n = z.len();
21294 let mut gr = Graph::new("cns_bwd");
21295 let in_z = gr.input("z", Shape::new(&[n], DType::C64));
21296 let in_g = gr.input("g", Shape::new(&[n], DType::F32));
21297 let out = gr.complex_norm_sq_backward(in_z, in_g);
21298 gr.set_outputs(vec![out]);
21299 let plan = rlx_opt::memory::plan_memory(&gr);
21300 let mut arena = crate::arena::Arena::from_plan(plan);
21301 let sched = compile_thunks(&gr, &arena);
21302 let z_off = arena.byte_offset(in_z);
21303 let g_off = arena.byte_offset(in_g);
21304 let out_off = arena.byte_offset(out);
21305 let buf = arena.raw_buf_mut();
21306 unsafe {
21307 let pz = buf.as_mut_ptr().add(z_off) as *mut f32;
21308 let pg = buf.as_mut_ptr().add(g_off) as *mut f32;
21309 for (i, &(re, im)) in z.iter().enumerate() {
21310 *pz.add(2 * i) = re;
21311 *pz.add(2 * i + 1) = im;
21312 }
21313 for (i, &v) in g.iter().enumerate() {
21314 *pg.add(i) = v;
21315 }
21316 }
21317 execute_thunks(&sched, arena.raw_buf_mut());
21318 unsafe {
21319 let p = arena.raw_buf().as_ptr().add(out_off) as *const f32;
21320 (0..n).map(|i| (*p.add(2 * i), *p.add(2 * i + 1))).collect()
21321 }
21322 }
21323
/// |z|² = re² + im²: |3+4i|² = 25, |1|² = 1, |0|² = 0.
#[test]
fn complex_norm_sq_matches_textbook() {
    let samples = [(3.0_f32, 4.0_f32), (1.0_f32, 0.0_f32), (0.0_f32, 0.0_f32)];
    let norms = run_complex_norm_sq(&samples);
    assert!((norms[0] - 25.0).abs() < 1e-5);
    assert!((norms[1] - 1.0).abs() < 1e-6);
    assert!(norms[2].abs() < 1e-6);
}
21335
/// With a unit upstream gradient the backward op must return `z` itself
/// (dz = g·z with g = 1).
#[test]
fn complex_norm_sq_backward_matches_wirtinger_formula() {
    let primal = [(3.0_f32, 4.0_f32), (1.5_f32, -2.5_f32)];
    let upstream = [1.0_f32, 1.0_f32];
    let grads = run_complex_norm_sq_bwd(&primal, &upstream);
    assert_close_c(grads[0], primal[0], 1e-6, "dz[0] = g·z[0]");
    assert_close_c(grads[1], primal[1], 1e-6, "dz[1] = g·z[1]");
}
21345
/// The backward result scales linearly with the upstream gradient:
/// 0.5·(2+i) = 1+0.5i and -2·(-1+3i) = 2-6i.
#[test]
fn complex_norm_sq_backward_scales_with_upstream() {
    let primal = [(2.0_f32, 1.0_f32), (-1.0_f32, 3.0_f32)];
    let upstream = [0.5_f32, -2.0_f32];
    let grads = run_complex_norm_sq_bwd(&primal, &upstream);
    assert_close_c(grads[0], (1.0, 0.5), 1e-6, "g=0.5 · (2,1)");
    assert_close_c(grads[1], (2.0, -6.0), 1e-6, "g=-2 · (-1,3)");
}
21355
/// A two-output subgraph (`x*x` and `2*x`) wrapped with `custom_fn_multi`:
/// each output extracted through the returned handle must hold the matching
/// subgraph result for x = [1, 2, 3].
#[test]
fn custom_fn_multi_extracts_each_subgraph_output() {
    use rlx_ir::ops::special::MultiOutputHandle;

    // Builds the handle struct from field literals and discards it.
    // NOTE(review): presumably a compile-time check that the type and its
    // fields stay publicly constructible — confirm intent.
    let _ = MultiOutputHandle {
        source: NodeId(0),
        sub_shapes: vec![],
        offsets: vec![],
    };

    let vec3 = Shape::new(&[3], DType::F32);

    // Subgraph body with outputs [x * x, 2 * x].
    let mut body = Graph::new("multi_body");
    let x = body.input("x", vec3.clone());
    let x_sq = body.binary(BinaryOp::Mul, x, x, vec3.clone());
    let two = body.add_node(
        Op::Constant {
            // Three little-endian f32 copies of 2.0.
            data: (0..3).flat_map(|_| 2.0_f32.to_le_bytes()).collect(),
        },
        vec![],
        vec3.clone(),
    );
    let two_x = body.binary(BinaryOp::Mul, two, x, vec3.clone());
    body.set_outputs(vec![x_sq, two_x]);

    let mut outer = Graph::new("multi_outer");
    let in_x = outer.input("xin", vec3.clone());
    let handle = outer.custom_fn_multi(vec![in_x], body);
    assert_eq!(handle.n_outputs(), 2);
    let first = handle.output(&mut outer, 0);
    let second = handle.output(&mut outer, 1);
    outer.set_outputs(vec![first, second]);

    let mut arena = crate::arena::Arena::from_plan(rlx_opt::memory::plan_memory(&outer));
    let sched = compile_thunks(&outer, &arena);
    let in_off = arena.byte_offset(in_x);
    let first_off = arena.byte_offset(first);
    let second_off = arena.byte_offset(second);

    let xs = [1.0_f32, 2.0, 3.0];
    // SAFETY: writes the 3 f32 values of the [3] F32 input; assumes the arena
    // offset is f32-aligned — TODO confirm against the memory planner.
    unsafe {
        let dst = arena.raw_buf_mut().as_mut_ptr().add(in_off) as *mut f32;
        for (i, &v) in xs.iter().enumerate() {
            dst.add(i).write(v);
        }
    }
    execute_thunks(&sched, arena.raw_buf_mut());

    // Reads three f32 values starting at byte offset `off` in the arena.
    let read3 = |off: usize| -> Vec<f32> {
        // SAFETY: only called with offsets of the extracted outputs, each
        // read as 3 f32 values (same alignment assumption as above).
        unsafe {
            let src = arena.raw_buf().as_ptr().add(off) as *const f32;
            (0..3).map(|i| src.add(i).read()).collect()
        }
    };
    let out0_v = read3(first_off);
    let out1_v = read3(second_off);

    for (i, &x_val) in xs.iter().enumerate() {
        assert!(
            (out0_v[i] - x_val * x_val).abs() < 1e-5,
            "out0[{i}] = {} != x² = {}",
            out0_v[i],
            x_val * x_val
        );
        assert!(
            (out1_v[i] - 2.0 * x_val).abs() < 1e-5,
            "out1[{i}] = {} != 2x = {}",
            out1_v[i],
            2.0 * x_val
        );
    }
}
21439
/// Cross-checks `complex_norm_sq` against forward finite differences at
/// z = 3+4i: perturbing re (resp. im) by `eps` must change the output at a
/// rate close to 2·re (resp. 2·im), and twice the backward op's result for a
/// unit upstream gradient must match those same analytic values.
#[test]
fn complex_norm_sq_gradient_matches_finite_difference() {
    let (re0, im0) = (3.0_f32, 4.0_f32);
    let eps = 1e-3_f32;

    // Evaluates the forward op at a single complex point.
    let norm_at = |re: f32, im: f32| run_complex_norm_sq(&[(re, im)])[0];

    let base = norm_at(re0, im0);

    // Partial derivative w.r.t. the real part.
    let fd_re = (norm_at(re0 + eps, im0) - base) / eps;
    let analytic_re = 2.0 * re0;
    assert!((fd_re - analytic_re).abs() < 1e-2);

    // Partial derivative w.r.t. the imaginary part.
    let fd_im = (norm_at(re0, im0 + eps) - base) / eps;
    let analytic_im = 2.0 * im0;
    assert!((fd_im - analytic_im).abs() < 1e-2);

    // Backward op with g = 1: doubling its components must reproduce the
    // analytic partials above.
    let dz = run_complex_norm_sq_bwd(&[(re0, im0)], &[1.0_f32]);
    let (dz_re, dz_im) = dz[0];
    assert!((2.0 * dz_re - analytic_re).abs() < 1e-5);
    assert!((2.0 * dz_im - analytic_im).abs() < 1e-5);
}
21468
/// Elementwise add of a `[bh, h, w, h, w]` tensor with a `[bh, h, w, 1, w]`
/// tensor: the size-1 fourth axis of the right operand must broadcast across
/// the left operand's fourth axis. The result is compared against a reference
/// computed on the host.
#[test]
fn binary_full_5d_mid_singleton_broadcast() {
    let (bh, h, w) = (2usize, 3usize, 4usize);
    let dtype = DType::F32;
    let full_dims = [bh, h, w, h, w];
    let total = bh * h * w * h * w;

    let mut g = Graph::new("bcast_5d");
    let lhs = g.input("lhs", Shape::new(&full_dims, dtype));
    let rhs = g.input("rhs", Shape::new(&[bh, h, w, 1, w], dtype));
    let out = g.binary(BinaryOp::Add, lhs, rhs, Shape::new(&full_dims, dtype));
    g.set_outputs(vec![out]);

    let lhs_data: Vec<f32> = (0..total).map(|i| i as f32 * 0.01).collect();
    let rhs_data: Vec<f32> = (0..bh * h * w * w)
        .map(|i| (i as f32 + 100.0) * 0.01)
        .collect();

    // Host reference. For a flat lhs index `li`, `li / (h * w)` is the flat
    // index over the three leading axes and `li % w` is the last-axis index;
    // the rhs element ignores the broadcast fourth axis.
    let expected: Vec<f32> = (0..total)
        .map(|li| {
            let lead = li / (h * w);
            let wk = li % w;
            lhs_data[li] + rhs_data[lead * w + wk]
        })
        .collect();

    let mut arena = crate::arena::Arena::from_plan(rlx_opt::memory::plan_memory(&g));
    let sched = compile_thunks(&g, &arena);
    let lhs_off = arena.byte_offset(lhs);
    let rhs_off = arena.byte_offset(rhs);
    let out_off = arena.byte_offset(out);

    let buf = arena.raw_buf_mut();
    // SAFETY: each copy writes exactly as many f32 values as the matching
    // input's shape holds, and the host vectors cannot overlap the arena.
    // Assumes the arena offsets are f32-aligned — TODO confirm against the
    // memory planner.
    unsafe {
        let base = buf.as_mut_ptr();
        std::ptr::copy_nonoverlapping(
            lhs_data.as_ptr(),
            base.add(lhs_off) as *mut f32,
            lhs_data.len(),
        );
        std::ptr::copy_nonoverlapping(
            rhs_data.as_ptr(),
            base.add(rhs_off) as *mut f32,
            rhs_data.len(),
        );
    }
    execute_thunks(&sched, arena.raw_buf_mut());
    // SAFETY: reads `total` f32 values, the element count of the output shape
    // (same alignment assumption).
    let actual: Vec<f32> = unsafe {
        let src = arena.raw_buf().as_ptr().add(out_off) as *const f32;
        (0..total).map(|i| src.add(i).read()).collect()
    };

    // Track the largest absolute deviation and the first index where it occurs.
    let (mut max_diff, mut max_idx) = (0f32, 0usize);
    for (i, (&got, &want)) in actual.iter().zip(&expected).enumerate() {
        let d = (got - want).abs();
        if d > max_diff {
            max_diff = d;
            max_idx = i;
        }
    }
    assert!(
        max_diff < 1e-6,
        "5D mid-shape singleton broadcast wrong: max |Δ| = {max_diff} at idx {max_idx} \
         (actual={}, expected={})",
        actual[max_idx],
        expected[max_idx]
    );
}
21551
/// Smoke test that calls the `layer_norm2d_nchw` and `conv_transpose2d_nchw`
/// kernels directly on tiny inputs.
///
/// NOTE(review): the positional integer arguments are passed through exactly
/// as in the kernel signatures, which are not visible here; their meaning
/// (batch/channel/spatial sizes, stride, padding, ...) should be confirmed
/// against `crate::kernels`.
#[test]
fn layer_norm2d_and_conv_transpose2d_kernels() {
    // --- layer_norm2d_nchw ---
    let ln_input = [1.0_f32, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0];
    let ln_scale = [1.0_f32, 1.0];
    let ln_shift = [0.0_f32, 0.0];
    let mut normed = vec![0f32; 8];
    crate::kernels::layer_norm2d_nchw(
        &ln_input,
        &ln_scale,
        &ln_shift,
        &mut normed,
        1,
        2,
        2,
        2,
        1e-5,
    );
    // Only checks that the first output differs from this reference value by
    // more than 0.1.
    // NOTE(review): this is a weak assertion — confirm it is intended.
    let mean0: f32 = (1.0 + 3.0) / 2.0;
    assert!((normed[0] - mean0).abs() > 0.1);

    // --- conv_transpose2d_nchw ---
    let ct_input = [2.0_f32];
    let ct_weight = [1.0_f32, 0.0, 0.0, 1.0];
    let mut upsampled = vec![0f32; 4];
    crate::kernels::conv_transpose2d_nchw(
        &ct_input,
        &ct_weight,
        &mut upsampled,
        1,
        1,
        1,
        1,
        1,
        2,
        2,
        2,
        2,
        2,
        2,
        0,
        0,
        1,
        1,
        1,
    );
    // The first and last of the four outputs must both equal 2.0.
    assert!((upsampled[0] - 2.0).abs() < 1e-5);
    assert!((upsampled[3] - 2.0).abs() < 1e-5);
}
21594}