1#![allow(unsafe_op_in_unsafe_fn)]
24use crate::arena::Arena;
27use crate::op_registry::CpuKernel;
28use rlx_ir::op::{Activation, BinaryOp, CmpOp, ReduceOp};
29use rlx_ir::{Graph, NodeId, Op, Shape};
30use std::collections::HashMap;
31use std::sync::Arc;
32
/// One step of a compiled CPU execution schedule (see `ThunkSchedule`).
///
/// Each variant names a kernel together with the buffers it touches and the
/// shape parameters it needs. Buffer fields are `usize`; they appear to be
/// byte offsets into the arena (compare `node_offset`, which produces
/// `usize::MAX` for nodes without a buffer) — NOTE(review): confirm at the
/// construction sites, which are outside this view. Shape/size parameters are
/// `u32`.
#[derive(Clone)]
pub enum Thunk {
    /// Placeholder that does nothing; removed by `ThunkSchedule::strip_nops`.
    Nop,
    // ---- Dense linear algebra ----
    Sgemm {
        a: usize,
        b: usize,
        c: usize,
        m: u32,
        k: u32,
        n: u32,
    },
    DenseSolveF64 {
        a: usize,
        b: usize,
        x: usize,
        n: u32,
        nrhs: u32,
    },
    DenseSolveF32 {
        a: usize,
        b: usize,
        x: usize,
        n: u32,
        nrhs: u32,
    },
    BatchedDenseSolveF64 {
        a: usize,
        b: usize,
        x: usize,
        batch: u32,
        n: u32,
        nrhs: u32,
    },
    BatchedDenseSolveF32 {
        a: usize,
        b: usize,
        x: usize,
        batch: u32,
        n: u32,
        nrhs: u32,
    },
    BatchedDgemmF64 {
        a: usize,
        b: usize,
        c: usize,
        batch: u32,
        m: u32,
        k: u32,
        n: u32,
    },
    BatchedSgemm {
        a: usize,
        b: usize,
        c: usize,
        batch: u32,
        m: u32,
        k: u32,
        n: u32,
    },
    Dgemm {
        a: usize,
        b: usize,
        c: usize,
        m: u32,
        k: u32,
        n: u32,
    },
    // ---- f64 / complex element-wise and layout kernels ----
    TransposeF64 {
        src: usize,
        dst: usize,
        in_total: u32,
        out_dims: Vec<u32>,
        in_strides: Vec<u32>,
    },
    ActivationF64 {
        src: usize,
        dst: usize,
        len: u32,
        kind: Activation,
    },
    ComplexNormSqF32 {
        src: usize,
        dst: usize,
        len: u32,
    },
    ComplexNormSqBackwardF32 {
        z: usize,
        g: usize,
        dz: usize,
        len: u32,
    },
    ConjugateC64 { src: usize, dst: usize, len: u32 },
    ActivationC64 {
        src: usize,
        dst: usize,
        len: u32,
        kind: Activation,
    },
    ReduceSumF64 {
        src: usize,
        dst: usize,
        outer: u32,
        reduced: u32,
        inner: u32,
    },
    CopyF64 { src: usize, dst: usize, len: u32 },
    CopyI64 { src: usize, dst: usize, len: u32 },
    CastF32ToI64 { src: usize, dst: usize, len: u32 },
    CastI64ToF32 { src: usize, dst: usize, len: u32 },
    CastBoolToI32 { src: usize, dst: usize, len: u32 },
    CastI32ToF32 { src: usize, dst: usize, len: u32 },
    BinaryFullF64 {
        lhs: usize,
        rhs: usize,
        dst: usize,
        len: u32,
        lhs_len: u32,
        rhs_len: u32,
        op: BinaryOp,
        out_dims_bcast: Vec<u32>,
        bcast_lhs_strides: Vec<u32>,
        bcast_rhs_strides: Vec<u32>,
    },
    ConcatF64 {
        dst: usize,
        outer: u32,
        inner: u32,
        total_axis: u32,
        /// `(offset, _)` pairs; the first element is read as an input offset
        /// by `thunk_read_offsets`.
        inputs: Vec<(usize, u32)>,
    },
    BinaryFullC64 {
        lhs: usize,
        rhs: usize,
        dst: usize,
        len: u32,
        lhs_len: u32,
        rhs_len: u32,
        op: BinaryOp,
        out_dims_bcast: Vec<u32>,
        bcast_lhs_strides: Vec<u32>,
        bcast_rhs_strides: Vec<u32>,
    },
    // ---- Control flow: scans and nested sub-schedules ----
    /// Runs the nested `body` schedule `length` times, threading a carry.
    /// NOTE(review): the `*_off` fields presumably address either the body's
    /// own buffer (`body_*`) or the outer arena (`outer_*`) — confirm.
    Scan {
        body: Arc<ThunkSchedule>,
        body_init: Arc<Vec<u8>>,
        body_input_off: usize,
        body_output_off: usize,
        outer_init_off: usize,
        outer_final_off: usize,
        length: u32,
        carry_bytes: u32,
        save_trajectory: bool,
        /// Tuples whose second element is read as an outer-arena input
        /// offset by `thunk_read_offsets`.
        xs_inputs: Arc<Vec<(usize, usize, u32)>>,
        bcast_inputs: Arc<Vec<(usize, usize, u32)>>,
        num_checkpoints: u32,
    },

    /// Backward pass of `Scan` with respect to the initial carry.
    ScanBackward {
        body_vjp: Arc<ThunkSchedule>,
        body_init: Arc<Vec<u8>>,
        body_carry_in_off: usize,
        body_x_offs: Arc<Vec<usize>>,
        body_d_output_off: usize,
        body_dcarry_out_off: usize,
        outer_init_off: usize,
        outer_traj_off: usize,
        outer_upstream_off: usize,
        /// `(offset, _)` pairs; the first element is read as an input offset.
        outer_xs_offs: Arc<Vec<(usize, u32)>>,
        outer_dinit_off: usize,
        length: u32,
        carry_bytes: u32,
        carry_elem_size: u32,
        save_trajectory: bool,
        num_checkpoints: u32,
        forward_body: Option<Arc<ThunkSchedule>>,
        forward_body_init: Option<Arc<Vec<u8>>>,
        forward_body_carry_in_off: usize,
        forward_body_output_off: usize,
        forward_body_x_offs: Arc<Vec<usize>>,
    },

    /// Backward pass of `Scan` with respect to the per-step inputs `xs`.
    ScanBackwardXs {
        body_vjp: Arc<ThunkSchedule>,
        body_init: Arc<Vec<u8>>,
        body_carry_in_off: usize,
        body_x_offs: Arc<Vec<usize>>,
        body_d_output_off: usize,
        body_dcarry_out_off: usize,
        body_dxs_out_off: usize,
        outer_init_off: usize,
        outer_traj_off: usize,
        outer_upstream_off: usize,
        /// `(offset, _)` pairs; the first element is read as an input offset.
        outer_xs_offs: Arc<Vec<(usize, u32)>>,
        outer_dxs_off: usize,
        length: u32,
        carry_bytes: u32,
        carry_elem_size: u32,
        per_step_bytes: u32,
        save_trajectory: bool,
        num_checkpoints: u32,
        forward_body: Option<Arc<ThunkSchedule>>,
        forward_body_init: Option<Arc<Vec<u8>>>,
        forward_body_carry_in_off: usize,
        forward_body_output_off: usize,
        forward_body_x_offs: Arc<Vec<usize>>,
    },
    /// Runs a nested `body` schedule once.
    CustomFn {
        body: Arc<ThunkSchedule>,
        body_init: Arc<Vec<u8>>,
        /// Tuples whose second element is read as an outer-arena input
        /// offset by `thunk_read_offsets`.
        inputs: Arc<Vec<(usize, usize, u32)>>,
        body_output_off: usize,
        outer_output_off: usize,
        out_bytes: u32,
    },
    // ---- Fused and generic f32 kernels ----
    FusedMmBiasAct {
        a: usize,
        w: usize,
        bias: usize,
        c: usize,
        m: u32,
        k: u32,
        n: u32,
        act: Option<Activation>,
    },
    FusedResidualLN {
        x: usize,
        res: usize,
        bias: usize,
        g: usize,
        b: usize,
        out: usize,
        rows: u32,
        h: u32,
        eps: f32,
        has_bias: bool,
    },
    FusedResidualRmsNorm {
        x: usize,
        res: usize,
        bias: usize,
        g: usize,
        b: usize,
        out: usize,
        rows: u32,
        h: u32,
        eps: f32,
        has_bias: bool,
    },
    BiasAdd {
        src: usize,
        bias: usize,
        dst: usize,
        m: u32,
        n: u32,
    },
    BinaryFull {
        lhs: usize,
        rhs: usize,
        dst: usize,
        len: u32,
        lhs_len: u32,
        rhs_len: u32,
        op: BinaryOp,
        out_dims_bcast: Vec<u32>,
        bcast_lhs_strides: Vec<u32>,
        bcast_rhs_strides: Vec<u32>,
        elem_bytes: u8,
    },
    /// Activation over a single buffer `data` (read and, presumably,
    /// overwritten in place).
    ActivationInPlace {
        data: usize,
        len: u32,
        act: Activation,
    },
    Gather {
        table: usize,
        table_len: u32,
        idx: usize,
        dst: usize,
        num_idx: u32,
        trailing: u32,
        idx_i64: u8,
        table_bytes: u8,
    },
    Narrow {
        src: usize,
        dst: usize,
        outer: u32,
        src_stride: u32,
        dst_stride: u32,
        inner: u32,
        elem_bytes: u8,
    },
    Copy { src: usize, dst: usize, len: u32 },
    // ---- Normalization ----
    LayerNorm {
        src: usize,
        g: usize,
        b: usize,
        dst: usize,
        rows: u32,
        h: u32,
        eps: f32,
    },
    GroupNorm {
        src: usize,
        g: usize,
        b: usize,
        dst: usize,
        n: u32,
        c: u32,
        h: u32,
        w: u32,
        num_groups: u32,
        eps: f32,
    },
    BatchNormInference {
        src: usize,
        g: usize,
        b: usize,
        mean: usize,
        var: usize,
        dst: usize,
        count: u32,
        channels: u32,
        eps: f32,
    },
    BatchNormInferenceBackwardInput {
        x: usize,
        gamma: usize,
        mean: usize,
        var: usize,
        dy: usize,
        dx: usize,
        count: u32,
        channels: u32,
        eps: f32,
    },
    BatchNormInferenceBackwardGamma {
        x: usize,
        mean: usize,
        var: usize,
        dy: usize,
        dgamma: usize,
        count: u32,
        channels: u32,
        eps: f32,
    },
    BatchNormInferenceBackwardBeta {
        dy: usize,
        dbeta: usize,
        count: u32,
        channels: u32,
    },
    LayerNorm2d {
        src: usize,
        g: usize,
        b: usize,
        dst: usize,
        n: u32,
        c: u32,
        h: u32,
        w: u32,
        eps: f32,
    },
    // ---- Vision / positional kernels ----
    ConvTranspose2d {
        src: usize,
        weight: usize,
        dst: usize,
        n: u32,
        c_in: u32,
        h: u32,
        w_in: u32,
        c_out: u32,
        h_out: u32,
        w_out: u32,
        kh: u32,
        kw: u32,
        sh: u32,
        sw: u32,
        ph: u32,
        pw: u32,
        dh: u32,
        dw: u32,
        groups: u32,
    },
    ResizeNearest2x {
        src: usize,
        dst: usize,
        n: u32,
        c: u32,
        h: u32,
        w: u32,
    },
    AxialRope2d {
        src: usize,
        dst: usize,
        batch: u32,
        seq: u32,
        hidden: u32,
        end_x: u32,
        end_y: u32,
        head_dim: u32,
        num_heads: u32,
        theta: f32,
        repeat_factor: u32,
    },
    RmsNorm {
        src: usize,
        g: usize,
        b: usize,
        dst: usize,
        rows: u32,
        h: u32,
        eps: f32,
    },
    /// Softmax over a single buffer `data` (presumably in place, row-wise
    /// over `rows x cols`).
    Softmax { data: usize, rows: u32, cols: u32 },
    Cumsum {
        src: usize,
        dst: usize,
        rows: u32,
        cols: u32,
        exclusive: bool,
    },
    // ---- Sequence-model kernels ----
    SelectiveScan {
        x: usize,
        delta: usize,
        a: usize,
        b: usize,
        c: usize,
        dst: usize,
        batch: u32,
        seq: u32,
        hidden: u32,
        state_size: u32,
    },

    GatedDeltaNet {
        q: usize,
        k: usize,
        v: usize,
        g: usize,
        beta: usize,
        /// Optional initial state; `0` is treated as "no state input" by
        /// `thunk_read_offsets`.
        state: usize,
        dst: usize,
        batch: u32,
        seq: u32,
        heads: u32,
        state_size: u32,
    },

    Conv2D1x1 {
        src: usize,
        weight: usize,
        dst: usize,
        n: u32,
        c_in: u32,
        c_out: u32,
        hw: u32,
    },

    // ---- Quantized matmuls ----
    DequantMatMul {
        x: usize,
        w_q: usize,
        scale: usize,
        zp: usize,
        dst: usize,
        m: u32,
        k: u32,
        n: u32,
        block_size: u32,
        is_asymmetric: bool,
    },

    DequantMatMulGguf {
        x: usize,
        w_q: usize,
        dst: usize,
        m: u32,
        k: u32,
        n: u32,
        scheme: rlx_ir::quant::QuantScheme,
    },

    DequantMatMulInt4 {
        x: usize,
        w_q: usize,
        scale: usize,
        zp: usize,
        dst: usize,
        m: u32,
        k: u32,
        n: u32,
        block_size: u32,
        is_asymmetric: bool,
    },

    DequantMatMulFp8 {
        x: usize,
        w_q: usize,
        scale: usize,
        dst: usize,
        m: u32,
        k: u32,
        n: u32,
        e5m2: bool,
    },

    DequantMatMulNvfp4 {
        x: usize,
        w_q: usize,
        scale: usize,
        global_scale: usize,
        dst: usize,
        m: u32,
        k: u32,
        n: u32,
    },

    LoraMatMul {
        x: usize,
        w: usize,
        a: usize,
        b: usize,
        dst: usize,
        m: u32,
        k: u32,
        n: u32,
        r: u32,
        scale: f32,
    },
    /// Token sampling from `logits` (`batch` rows of `vocab` entries).
    Sample {
        logits: usize,
        dst: usize,
        batch: u32,
        vocab: u32,
        top_k: u32,
        top_p: f32,
        temperature: f32,
        seed: u64,
    },
    // ---- Attention ----
    Attention {
        q: usize,
        k: usize,
        v: usize,
        mask: usize,
        out: usize,
        batch: u32,
        seq: u32,
        kv_seq: u32,
        heads: u32,
        head_dim: u32,
        mask_kind: rlx_ir::op::MaskKind,
        q_row_stride: u32,
        k_row_stride: u32,
        v_row_stride: u32,
        bhsd: bool,
    },
    AttentionBackward {
        q: usize,
        k: usize,
        v: usize,
        dy: usize,
        /// Optional mask; `0` is treated as "no mask input" by
        /// `thunk_read_offsets`.
        mask: usize,
        out: usize,
        batch: u32,
        seq: u32,
        kv_seq: u32,
        heads: u32,
        head_dim: u32,
        mask_kind: rlx_ir::op::MaskKind,
        wrt: rlx_ir::op::AttentionBwdWrt,
        bhsd: bool,
    },
    Rope {
        src: usize,
        cos: usize,
        sin: usize,
        dst: usize,
        batch: u32,
        seq: u32,
        hidden: u32,
        head_dim: u32,
        n_rot: u32,
        cos_len: u32,
        src_row_stride: u32,
    },
    // NOTE(review): `has_bias` / `has_rope` presumably gate whether
    // `qkv_b`/`out_b` and `cos`/`sin` are used — confirm in the kernel.
    FusedAttnBlock {
        hidden: usize,
        qkv_w: usize,
        out_w: usize,
        mask: usize,
        out: usize,
        qkv_b: usize,
        out_b: usize,
        cos: usize,
        sin: usize,
        cos_len: u32,
        batch: u32,
        seq: u32,
        hs: u32,
        nh: u32,
        dh: u32,
        has_bias: bool,
        has_rope: bool,
    },
    FusedBertLayer {
        hidden: usize,
        qkv_w: usize,
        qkv_b: usize,
        out_w: usize,
        out_b: usize,
        mask: usize,
        ln1_g: usize,
        ln1_b: usize,
        eps1: f32,
        fc1_w: usize,
        fc1_b: usize,
        fc2_w: usize,
        fc2_b: usize,
        ln2_g: usize,
        ln2_b: usize,
        eps2: f32,
        out: usize,
        batch: u32,
        seq: u32,
        hs: u32,
        nh: u32,
        dh: u32,
        int_dim: u32,
    },
    FusedNomicLayer {
        hidden: usize,
        qkv_w: usize,
        out_w: usize,
        mask: usize,
        cos: usize,
        sin: usize,
        cos_len: u32,
        ln1_g: usize,
        ln1_b: usize,
        eps1: f32,
        fc11_w: usize,
        fc12_w: usize,
        fc2_w: usize,
        ln2_g: usize,
        ln2_b: usize,
        eps2: f32,
        out: usize,
        batch: u32,
        seq: u32,
        hs: u32,
        nh: u32,
        dh: u32,
        int_dim: u32,
    },
    FusedSwiGLU {
        src: usize,
        dst: usize,
        n_half: u32,
        total: u32,
        gate_first: bool,
    },
    // ---- Generic tensor ops ----
    Concat {
        dst: usize,
        outer: u32,
        inner: u32,
        total_axis: u32,
        /// `(offset, _)` pairs; the first element is read as an input offset.
        inputs: Vec<(usize, u32)>,
    },
    Compare {
        lhs: usize,
        rhs: usize,
        dst: usize,
        len: u32,
        op: CmpOp,
        inputs_i64: u8,
        inputs_elem_bytes: u8,
        dst_elem_bytes: u8,
    },
    Reduce {
        src: usize,
        dst: usize,
        outer: u32,
        reduced: u32,
        inner: u32,
        op: ReduceOp,
    },
    TopK {
        src: usize,
        dst: usize,
        outer: u32,
        axis_dim: u32,
        k: u32,
        indices_i64: u8,
    },
    GroupedMatMul {
        input: usize,
        weight: usize,
        expert_idx: usize,
        dst: usize,
        m: u32,
        k_dim: u32,
        n: u32,
        num_experts: u32,
    },
    DequantGroupedMatMulGguf {
        input: usize,
        w_q: usize,
        expert_idx: usize,
        dst: usize,
        m: u32,
        k_dim: u32,
        n: u32,
        num_experts: u32,
        scheme: rlx_ir::quant::QuantScheme,
    },
    DequantMoEWeightsGguf {
        w_q: usize,
        dst: usize,
        k_dim: u32,
        n: u32,
        num_experts: u32,
        scheme: rlx_ir::quant::QuantScheme,
    },
    ScatterAdd {
        updates: usize,
        indices: usize,
        dst: usize,
        num_updates: u32,
        out_dim: u32,
        trailing: u32,
    },
    Where {
        cond: usize,
        on_true: usize,
        on_false: usize,
        dst: usize,
        len: u32,
        elem_bytes: u8,
        cond_elem_bytes: u8,
    },
    Transpose {
        src: usize,
        dst: usize,
        in_total: u32,
        out_dims: Vec<u32>,
        in_strides: Vec<u32>,
        elem_bytes: u8,
    },
    GatherAxis {
        table: usize,
        idx: usize,
        dst: usize,
        outer: u32,
        axis_dim: u32,
        num_idx: u32,
        trailing: u32,
        idx_i64: u8,
        table_bytes: u8,
    },
    Pool2D {
        src: usize,
        dst: usize,
        n: u32,
        c: u32,
        h: u32,
        w: u32,
        h_out: u32,
        w_out: u32,
        kh: u32,
        kw: u32,
        sh: u32,
        sw: u32,
        ph: u32,
        pw: u32,
        kind: ReduceOp,
    },
    Conv2D {
        src: usize,
        weight: usize,
        dst: usize,
        n: u32,
        c_in: u32,
        h: u32,
        w: u32,
        c_out: u32,
        h_out: u32,
        w_out: u32,
        kh: u32,
        kw: u32,
        sh: u32,
        sw: u32,
        ph: u32,
        pw: u32,
        dh: u32,
        dw: u32,
        groups: u32,
    },

    // ---- Integer-quantized inference and quantization ops ----
    QMatMul {
        x: usize,
        w: usize,
        bias: usize,
        out: usize,
        m: u32,
        k: u32,
        n: u32,
        x_zp: i32,
        w_zp: i32,
        out_zp: i32,
        mult: f32,
    },

    QConv2d {
        x: usize,
        w: usize,
        bias: usize,
        out: usize,
        n: u32,
        c_in: u32,
        h: u32,
        w_in: u32,
        c_out: u32,
        h_out: u32,
        w_out: u32,
        kh: u32,
        kw: u32,
        sh: u32,
        sw: u32,
        ph: u32,
        pw: u32,
        dh: u32,
        dw: u32,
        groups: u32,
        x_zp: i32,
        w_zp: i32,
        out_zp: i32,
        mult: f32,
    },

    Quantize {
        x: usize,
        q: usize,
        len: u32,
        chan_axis: u32,
        chan_dim: u32,
        inner: u32,
        scales: Vec<f32>,
        zero_points: Vec<i32>,
    },

    Dequantize {
        q: usize,
        x: usize,
        len: u32,
        chan_axis: u32,
        chan_dim: u32,
        inner: u32,
        scales: Vec<f32>,
        zero_points: Vec<i32>,
    },

    FakeQuantize {
        x: usize,
        out: usize,
        len: u32,
        chan_axis: u32,
        chan_dim: u32,
        inner: u32,
        bits: u8,
        ste: rlx_ir::op::SteKind,
        scale_mode: rlx_ir::op::ScaleMode,
        state_off: Option<usize>,
    },

    FakeQuantizeBackward {
        x: usize,
        dy: usize,
        dx: usize,
        len: u32,
        chan_axis: u32,
        chan_dim: u32,
        inner: u32,
        bits: u8,
        ste: rlx_ir::op::SteKind,
    },

    FakeQuantizeLSQ {
        x: usize,
        scale_off: usize,
        out: usize,
        len: u32,
        chan_axis: u32,
        chan_dim: u32,
        inner: u32,
        bits: u8,
    },

    FakeQuantizeLSQBackwardX {
        x: usize,
        scale_off: usize,
        dy: usize,
        dx: usize,
        len: u32,
        chan_axis: u32,
        chan_dim: u32,
        inner: u32,
        bits: u8,
    },

    FakeQuantizeLSQBackwardScale {
        x: usize,
        scale_off: usize,
        dy: usize,
        dscale: usize,
        len: u32,
        chan_axis: u32,
        chan_dim: u32,
        inner: u32,
        bits: u8,
    },

    // ---- Backward (gradient) kernels ----
    ReluBackward {
        x: usize,
        dy: usize,
        dx: usize,
        len: u32,
    },
    ReluBackwardF64 {
        x: usize,
        dy: usize,
        dx: usize,
        len: u32,
    },

    ActivationBackward {
        x: usize,
        dy: usize,
        dx: usize,
        len: u32,
        kind: Activation,
    },
    ActivationBackwardF64 {
        x: usize,
        dy: usize,
        dx: usize,
        len: u32,
        kind: Activation,
    },

    LayerNormBackwardInput {
        x: usize,
        gamma: usize,
        dy: usize,
        dx: usize,
        rows: u32,
        h: u32,
        eps: f32,
    },

    LayerNormBackwardGamma {
        x: usize,
        dy: usize,
        dgamma: usize,
        rows: u32,
        h: u32,
        eps: f32,
    },

    RmsNormBackwardInput {
        x: usize,
        gamma: usize,
        beta: usize,
        dy: usize,
        dx: usize,
        rows: u32,
        h: u32,
        eps: f32,
    },
    RmsNormBackwardGamma {
        x: usize,
        gamma: usize,
        beta: usize,
        dy: usize,
        dgamma: usize,
        rows: u32,
        h: u32,
        eps: f32,
    },
    RmsNormBackwardBeta {
        x: usize,
        gamma: usize,
        beta: usize,
        dy: usize,
        dbeta: usize,
        rows: u32,
        h: u32,
        eps: f32,
    },
    RopeBackward {
        dy: usize,
        cos: usize,
        sin: usize,
        dx: usize,
        batch: u32,
        seq: u32,
        hidden: u32,
        head_dim: u32,
        n_rot: u32,
        cos_len: u32,
    },
    CumsumBackward {
        dy: usize,
        dx: usize,
        rows: u32,
        cols: u32,
        exclusive: bool,
    },
    GatherBackward {
        dy: usize,
        indices: usize,
        dst: usize,
        outer: u32,
        axis_dim: u32,
        num_idx: u32,
        trailing: u32,
    },

    GroupNormBackwardInput {
        x: usize,
        gamma: usize,
        beta: usize,
        dy: usize,
        dx: usize,
        n: u32,
        c: u32,
        h: u32,
        w: u32,
        num_groups: u32,
        eps: f32,
    },
    GroupNormBackwardGamma {
        x: usize,
        dy: usize,
        dgamma: usize,
        n: u32,
        c: u32,
        h: u32,
        w: u32,
        num_groups: u32,
        eps: f32,
    },
    GroupNormBackwardBeta {
        dy: usize,
        dbeta: usize,
        n: u32,
        c: u32,
        h: u32,
        w: u32,
    },

    MaxPool2dBackward {
        x: usize,
        dy: usize,
        dx: usize,
        n: u32,
        c: u32,
        h: u32,
        w: u32,
        h_out: u32,
        w_out: u32,
        kh: u32,
        kw: u32,
        sh: u32,
        sw: u32,
        ph: u32,
        pw: u32,
    },

    Conv2dBackwardInput {
        dy: usize,
        w: usize,
        dx: usize,
        n: u32,
        c_in: u32,
        h: u32,
        w_in: u32,
        c_out: u32,
        h_out: u32,
        w_out: u32,
        kh: u32,
        kw: u32,
        sh: u32,
        sw: u32,
        ph: u32,
        pw: u32,
        dh: u32,
        dw: u32,
        groups: u32,
    },

    // Here `dw` is the output buffer; the width dilation is `dw_dil`.
    Conv2dBackwardWeight {
        x: usize,
        dy: usize,
        dw: usize,
        n: u32,
        c_in: u32,
        h: u32,
        w: u32,
        c_out: u32,
        h_out: u32,
        w_out: u32,
        kh: u32,
        kw: u32,
        sh: u32,
        sw: u32,
        ph: u32,
        pw: u32,
        dh: u32,
        dw_dil: u32,
        groups: u32,
    },

    Im2Col {
        x: usize,
        col: usize,
        n: u32,
        c_in: u32,
        h: u32,
        w: u32,
        h_out: u32,
        w_out: u32,
        kh: u32,
        kw: u32,
        sh: u32,
        sw: u32,
        ph: u32,
        pw: u32,
        dh: u32,
        dw_dil: u32,
    },

    SoftmaxCrossEntropy {
        logits: usize,
        labels: usize,
        dst: usize,
        n: u32,
        c: u32,
    },

    SoftmaxCrossEntropyBackward {
        logits: usize,
        labels: usize,
        d_loss: usize,
        dlogits: usize,
        n: u32,
        c: u32,
    },

    /// User-registered kernel; `attrs` is an opaque byte payload passed
    /// through to it. NOTE(review): the `(usize, u32, Shape)` tuples
    /// presumably are (offset, dtype/size tag, shape) — confirm.
    CustomOp {
        kernel: Arc<dyn CpuKernel>,
        inputs: Vec<(usize, u32, Shape)>,
        output: (usize, u32, Shape),
        attrs: Vec<u8>,
    },

    // ---- Gaussian splatting ----
    GaussianSplatRender {
        positions_off: usize,
        positions_len: usize,
        scales_off: usize,
        scales_len: usize,
        rotations_off: usize,
        rotations_len: usize,
        opacities_off: usize,
        opacities_len: usize,
        colors_off: usize,
        colors_len: usize,
        sh_coeffs_off: usize,
        sh_coeffs_len: usize,
        meta_off: usize,
        dst_off: usize,
        dst_len: usize,
        width: u32,
        height: u32,
        tile_size: u32,
        radius_scale: f32,
        alpha_cutoff: f32,
        max_splat_steps: u32,
        transmittance_threshold: f32,
        max_list_entries: u32,
    },
    GaussianSplatRenderBackward {
        positions_off: usize,
        positions_len: usize,
        scales_off: usize,
        scales_len: usize,
        rotations_off: usize,
        rotations_len: usize,
        opacities_off: usize,
        opacities_len: usize,
        colors_off: usize,
        colors_len: usize,
        sh_coeffs_off: usize,
        sh_coeffs_len: usize,
        meta_off: usize,
        d_loss_off: usize,
        d_loss_len: usize,
        packed_off: usize,
        packed_len: usize,
        width: u32,
        height: u32,
        tile_size: u32,
        radius_scale: f32,
        alpha_cutoff: f32,
        max_splat_steps: u32,
        transmittance_threshold: f32,
        max_list_entries: u32,
        loss_grad_clip: f32,
        sh_band: u32,
        max_anisotropy: f32,
    },
    GaussianSplatPrepare {
        positions_off: usize,
        positions_len: usize,
        scales_off: usize,
        scales_len: usize,
        rotations_off: usize,
        rotations_len: usize,
        opacities_off: usize,
        opacities_len: usize,
        colors_off: usize,
        colors_len: usize,
        sh_coeffs_off: usize,
        sh_coeffs_len: usize,
        meta_off: usize,
        meta_len: usize,
        prep_off: usize,
        prep_len: usize,
        width: u32,
        height: u32,
        tile_size: u32,
        radius_scale: f32,
        alpha_cutoff: f32,
        max_splat_steps: u32,
        transmittance_threshold: f32,
        max_list_entries: u32,
    },
    GaussianSplatRasterize {
        prep_off: usize,
        prep_len: usize,
        meta_off: usize,
        meta_len: usize,
        dst_off: usize,
        dst_len: usize,
        count: usize,
        width: u32,
        height: u32,
        tile_size: u32,
        alpha_cutoff: f32,
        max_splat_steps: u32,
        transmittance_threshold: f32,
        max_list_entries: u32,
    },
    // ---- Signal processing ----
    Fft1d {
        src: usize,
        dst: usize,
        outer: u32,
        n_complex: u32,
        inverse: bool,
        norm_tag: u32,
        dtype: rlx_ir::DType,
    },
    FftButterflyStage {
        state_src: usize,
        state_dst: usize,
        gate_src: usize,
        rev_src: usize,
        tw_re_src: usize,
        tw_im_src: usize,
        batch: u32,
        n_fft: u32,
        stage: u32,
    },
    LogMel {
        spec: usize,
        filters: usize,
        dst: usize,
        outer: u32,
        n_fft: u32,
        n_bins: u32,
        n_mels: u32,
    },
    LogMelBackward {
        spec: usize,
        filters: usize,
        dy: usize,
        dst: usize,
        outer: u32,
        n_fft: u32,
        n_bins: u32,
        n_mels: u32,
    },
}
1743
1744#[derive(Clone)]
1747pub struct ThunkSchedule {
1748 pub thunks: Vec<Thunk>,
1749 pub moe_resident: Option<std::sync::Arc<[bool]>>,
1751 pub moe_resident_layers: Option<std::sync::Arc<Vec<std::sync::Arc<[bool]>>>>,
1753 pub moe_topk_capture: Option<std::sync::Arc<crate::moe_topk_capture::MoeTopkCapture>>,
1755 pub mask_threshold: f32,
1757 pub mask_neg_inf: f32,
1758 pub score_skip: f32,
1759 pub compiled_fns: Vec<Arc<dyn Fn(*mut u8) + Send + Sync>>,
1765}
1766
1767impl ThunkSchedule {
1768 pub fn strip_nops(&mut self) {
1769 self.thunks.retain(|t| !matches!(t, Thunk::Nop));
1770 self.compiled_fns.clear();
1773 }
1774}
1775
1776fn node_offset(arena: &Arena, id: NodeId) -> usize {
1778 if arena.has_buffer(id) {
1779 arena.byte_offset(id)
1780 } else {
1781 usize::MAX
1782 }
1783}
1784
1785fn thunk_read_offsets(t: &Thunk) -> Vec<usize> {
1791 match t {
1792 Thunk::Sgemm { a, b, .. } => vec![*a, *b],
1793 Thunk::DenseSolveF64 { a, b, .. } => vec![*a, *b],
1794 Thunk::DenseSolveF32 { a, b, .. } => vec![*a, *b],
1795 Thunk::BatchedDenseSolveF64 { a, b, .. } => vec![*a, *b],
1796 Thunk::BatchedDgemmF64 { a, b, .. } => vec![*a, *b],
1797 Thunk::BatchedSgemm { a, b, .. } => vec![*a, *b],
1798 Thunk::FusedMmBiasAct { a, w, bias, .. } => vec![*a, *w, *bias],
1799 Thunk::BiasAdd { src, bias, .. } => vec![*src, *bias],
1800 Thunk::BinaryFull { lhs, rhs, .. } => vec![*lhs, *rhs],
1801 Thunk::BinaryFullF64 { lhs, rhs, .. } => vec![*lhs, *rhs],
1802 Thunk::BinaryFullC64 { lhs, rhs, .. } => vec![*lhs, *rhs],
1803 Thunk::ComplexNormSqF32 { src, .. } => vec![*src],
1804 Thunk::ComplexNormSqBackwardF32 { z, g, .. } => vec![*z, *g],
1805 Thunk::ConjugateC64 { src, .. } => vec![*src],
1806 Thunk::Scan {
1807 outer_init_off,
1808 xs_inputs,
1809 ..
1810 } => {
1811 let mut v = vec![*outer_init_off];
1812 for (_, outer_xs_off, _) in xs_inputs.iter() {
1813 v.push(*outer_xs_off);
1814 }
1815 v
1816 }
1817 Thunk::ScanBackward {
1818 outer_init_off,
1819 outer_traj_off,
1820 outer_upstream_off,
1821 outer_xs_offs,
1822 ..
1823 } => {
1824 let mut v = vec![*outer_init_off, *outer_traj_off, *outer_upstream_off];
1825 for (off, _) in outer_xs_offs.iter() {
1826 v.push(*off);
1827 }
1828 v
1829 }
1830 Thunk::ScanBackwardXs {
1831 outer_init_off,
1832 outer_traj_off,
1833 outer_upstream_off,
1834 outer_xs_offs,
1835 ..
1836 } => {
1837 let mut v = vec![*outer_init_off, *outer_traj_off, *outer_upstream_off];
1838 for (off, _) in outer_xs_offs.iter() {
1839 v.push(*off);
1840 }
1841 v
1842 }
1843 Thunk::CustomFn { inputs, .. } => {
1844 inputs.iter().map(|(_, outer_off, _)| *outer_off).collect()
1845 }
1846 Thunk::ActivationInPlace { data, .. } => vec![*data],
1847 Thunk::LayerNorm { src, g, b, .. } | Thunk::GroupNorm { src, g, b, .. } => {
1848 vec![*src, *g, *b]
1849 }
1850 Thunk::BatchNormInference {
1851 src,
1852 g,
1853 b,
1854 mean,
1855 var,
1856 ..
1857 } => vec![*src, *g, *b, *mean, *var],
1858 Thunk::ResizeNearest2x { src, .. } => vec![*src],
1859 Thunk::AxialRope2d { src, .. } => vec![*src],
1860 Thunk::FusedResidualLN {
1861 x, res, bias, g, b, ..
1862 } => vec![*x, *res, *bias, *g, *b],
1863 Thunk::FusedResidualRmsNorm {
1864 x, res, bias, g, b, ..
1865 } => vec![*x, *res, *bias, *g, *b],
1866 Thunk::RmsNorm { src, g, b, .. } => vec![*src, *g, *b],
1867 Thunk::Softmax { data, .. } => vec![*data],
1868 Thunk::Cumsum { src, .. } => vec![*src],
1869 Thunk::Sample { logits, .. } => vec![*logits],
1870 Thunk::LoraMatMul { x, w, a, b, .. } => vec![*x, *w, *a, *b],
1871 Thunk::DequantMatMul {
1872 x, w_q, scale, zp, ..
1873 } => vec![*x, *w_q, *scale, *zp],
1874 Thunk::DequantMatMulGguf { x, w_q, .. } => vec![*x, *w_q],
1875 Thunk::DequantMatMulInt4 {
1876 x, w_q, scale, zp, ..
1877 } => vec![*x, *w_q, *scale, *zp],
1878 Thunk::DequantMatMulFp8 { x, w_q, scale, .. } => vec![*x, *w_q, *scale],
1879 Thunk::DequantMatMulNvfp4 {
1880 x,
1881 w_q,
1882 scale,
1883 global_scale,
1884 ..
1885 } => vec![*x, *w_q, *scale, *global_scale],
1886 Thunk::Conv2D1x1 { src, weight, .. } => vec![*src, *weight],
1887 Thunk::SelectiveScan {
1888 x, delta, a, b, c, ..
1889 } => vec![*x, *delta, *a, *b, *c],
1890 Thunk::GatedDeltaNet {
1891 q,
1892 k,
1893 v,
1894 g,
1895 beta,
1896 state,
1897 ..
1898 } => {
1899 let mut v = vec![*q, *k, *v, *g, *beta];
1900 if *state != 0 {
1901 v.push(*state);
1902 }
1903 v
1904 }
1905 Thunk::Attention { q, k, v, mask, .. } => vec![*q, *k, *v, *mask],
1906 Thunk::AttentionBackward {
1907 q, k, v, dy, mask, ..
1908 } => {
1909 let mut v = vec![*q, *k, *v, *dy];
1910 if *mask != 0 {
1911 v.push(*mask);
1912 }
1913 v
1914 }
1915 Thunk::Rope { src, cos, sin, .. } => vec![*src, *cos, *sin],
1916 Thunk::FusedAttnBlock {
1917 hidden,
1918 qkv_w,
1919 out_w,
1920 mask,
1921 qkv_b,
1922 out_b,
1923 cos,
1924 sin,
1925 ..
1926 } => vec![*hidden, *qkv_w, *out_w, *mask, *qkv_b, *out_b, *cos, *sin],
1927 Thunk::FusedSwiGLU { src, .. } => vec![*src],
1928 Thunk::Concat { inputs, .. } => inputs.iter().map(|(off, _)| *off).collect(),
1929 Thunk::ConcatF64 { inputs, .. } => inputs.iter().map(|(off, _)| *off).collect(),
1930 Thunk::Narrow { src, .. } => vec![*src],
1931 Thunk::Copy { src, .. } => vec![*src],
1932 Thunk::Gather { table, idx, .. } => vec![*table, *idx],
1933 _ => vec![],
1937 }
1938}
1939
/// Scalar reference matmul against block-quantized int8 weights:
/// `out[i][j] = sum_p x[i][p] * ((w[p][j] - zp) * scale)`.
///
/// Layouts (row-major):
/// * `x` is `m x k`, `w_bytes` is `k x n`, `out` is `m x n`.
/// * `scales` / `zps` hold one entry per `(block, column)` at
///   `block * n + j`, where `block = p / block_size`.
///
/// `zps` is only read when `asym` is true; otherwise the zero point is `0`.
///
/// # Panics
/// Panics if a slice is too short for the given dimensions, or if
/// `block_size == 0` while the `m x n x k` loop performs any work.
#[allow(clippy::too_many_arguments)]
fn dequant_matmul_int8(
    x: &[f32],
    w_bytes: &[i8],
    scales: &[f32],
    zps: &[f32],
    out: &mut [f32],
    m: usize,
    k: usize,
    n: usize,
    block_size: usize,
    asym: bool,
) {
    for i in 0..m {
        for j in 0..n {
            let mut acc = 0f32;
            for p in 0..k {
                // Scale and zero point share the same (block, column) slot.
                let param = (p / block_size) * n + j;
                let s = scales[param];
                let z = if asym { zps[param] } else { 0.0 };
                let q = w_bytes[p * n + j] as f32;
                acc += x[i * k + p] * ((q - z) * s);
            }
            out[i * n + j] = acc;
        }
    }
}
1983
/// Scalar reference matmul against block-quantized unsigned 4-bit weights.
///
/// `w_bytes` packs two weights per byte in row-major `k x n` order: the
/// element with flat index `p * n + j` lives in byte `flat / 2`, in the low
/// nibble when `flat` is even and the high nibble when odd. `scales` / `zps`
/// hold one entry per `(p / block_size, column)` at `block * n + j`; `zps` is
/// only read when `asym` is true. `x` is `m x k`, `out` is `m x n`.
#[allow(clippy::too_many_arguments)]
fn dequant_matmul_int4(
    x: &[f32],
    w_bytes: &[u8],
    scales: &[f32],
    zps: &[f32],
    out: &mut [f32],
    m: usize,
    k: usize,
    n: usize,
    block_size: usize,
    asym: bool,
) {
    for row in 0..m {
        for col in 0..n {
            let total = (0..k).fold(0f32, |sum, p| {
                let param = (p / block_size) * n + col;
                let scale = scales[param];
                let zero = if asym { zps[param] } else { 0.0 };
                let flat = p * n + col;
                // Even flat index -> low nibble, odd -> high nibble.
                let nibble = (w_bytes[flat / 2] >> ((flat & 1) * 4)) & 0x0F;
                sum + x[row * k + p] * ((nibble as f32 - zero) * scale)
            });
            out[row * n + col] = total;
        }
    }
}
2017
/// Decodes one OCP FP8 E4M3 ("E4M3FN") byte to `f32`.
///
/// Layout: 1 sign bit, 4 exponent bits (bias 7), 3 mantissa bits.
/// Unlike IEEE-style formats, E4M3FN has **no infinities**. The all-ones
/// exponent still encodes normal values (up to ±448), and only the
/// `S.1111.111` patterns (`0x7F` / `0xFF`) are NaN.
/// Zero keeps its sign, so `0x80` decodes to `-0.0`.
fn fp8_e4m3_to_f32(b: u8) -> f32 {
    let sign = if b & 0x80 != 0 { -1.0f32 } else { 1.0 };
    let exp = (b >> 3) & 0x0F;
    let mant = b & 0x07;
    if exp == 0x0F && mant == 0x07 {
        return f32::NAN;
    }
    if exp == 0 {
        // Subnormal (or signed zero): mant * 2^(1 - 7 - 3).
        return sign * mant as f32 * 2f32.powi(-9);
    }
    sign * (1.0 + mant as f32 / 8.0) * 2f32.powi(exp as i32 - 7)
}
2037
/// Decodes one FP8 E5M2 byte to `f32`.
///
/// Layout: 1 sign bit, 5 exponent bits (bias 15), 2 mantissa bits, with
/// IEEE-style specials. An all-ones exponent is ±infinity when the mantissa
/// is zero and NaN otherwise. Zero keeps its sign, so `0x80` decodes to
/// `-0.0`.
fn fp8_e5m2_to_f32(b: u8) -> f32 {
    let sign = if b & 0x80 != 0 { -1.0f32 } else { 1.0 };
    let exp = (b >> 2) & 0x1F;
    let mant = b & 0x03;
    match exp {
        // Subnormal (or signed zero): mant * 2^(1 - 15 - 2).
        0 => sign * mant as f32 * 2f32.powi(-16),
        0x1F if mant == 0 => sign * f32::INFINITY,
        0x1F => f32::NAN,
        _ => sign * (1.0 + mant as f32 / 4.0) * 2f32.powi(exp as i32 - 15),
    }
}
2057
2058#[allow(clippy::too_many_arguments)]
2059fn dequant_matmul_fp8(
2060 x: &[f32],
2061 w_bytes: &[u8],
2062 scales: &[f32],
2063 out: &mut [f32],
2064 m: usize,
2065 k: usize,
2066 n: usize,
2067 e5m2: bool,
2068) {
2069 let dequant = if e5m2 {
2070 fp8_e5m2_to_f32
2071 } else {
2072 fp8_e4m3_to_f32
2073 };
2074 for i in 0..m {
2075 for j in 0..n {
2076 let mut acc = 0f32;
2077 for p in 0..k {
2078 let w = dequant(w_bytes[p * n + j]);
2079 let s = scales.get(j).copied().unwrap_or(1.0);
2080 acc += x[i * k + p] * w * s;
2081 }
2082 out[i * n + j] = acc;
2083 }
2084 }
2085}
2086
2087#[allow(clippy::too_many_arguments)]
2088pub fn dequant_matmul_nvfp4(
2089 x: &[f32],
2090 w_bytes: &[u8],
2091 scale_bytes: &[u8],
2092 global_scale: f32,
2093 out: &mut [f32],
2094 m: usize,
2095 k: usize,
2096 n: usize,
2097) {
2098 use rlx_ir::{NVFP4_GROUP_SIZE, fp4_e2m1_to_f32, fp8_e4m3_scale_to_f32};
2099 let gs = NVFP4_GROUP_SIZE;
2100 for i in 0..m {
2101 for j in 0..n {
2102 let mut acc = 0f32;
2103 for p in 0..k {
2104 let byte_idx = (p * n + j) / 2;
2105 let nibble = if (p * n + j) & 1 == 0 {
2106 w_bytes[byte_idx] & 0x0F
2107 } else {
2108 w_bytes[byte_idx] >> 4
2109 };
2110 let block = p / gs;
2111 let scale = fp8_e4m3_scale_to_f32(scale_bytes[block * n + j]);
2112 let w = fp4_e2m1_to_f32(nibble) * scale * global_scale;
2113 acc += x[i * k + p] * w;
2114 }
2115 out[i * n + j] = acc;
2116 }
2117 }
2118}
2119
2120fn sample_row(
2129 logits: &[f32],
2130 top_k: usize,
2131 top_p: f32,
2132 temperature: f32,
2133 rng: &mut rlx_ir::Philox4x32,
2134) -> usize {
2135 let v = logits.len();
2136 if v == 0 {
2137 return 0;
2138 }
2139 let temp = temperature.max(1e-6);
2140 let mut scaled: Vec<f32> = logits.iter().map(|&x| x / temp).collect();
2142
2143 if top_k > 0 && top_k < v {
2145 let mut indexed: Vec<(usize, f32)> = scaled.iter().copied().enumerate().collect();
2147 indexed.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap());
2150 let cutoff = indexed[top_k - 1].1;
2151 for x in scaled.iter_mut() {
2152 if *x < cutoff {
2153 *x = f32::NEG_INFINITY;
2154 }
2155 }
2156 }
2157
2158 let mut max_l = f32::NEG_INFINITY;
2160 for &x in &scaled {
2161 if x > max_l {
2162 max_l = x;
2163 }
2164 }
2165 let mut sum = 0.0f32;
2166 for x in scaled.iter_mut() {
2167 *x = (*x - max_l).exp();
2168 sum += *x;
2169 }
2170 let inv = 1.0 / sum.max(f32::MIN_POSITIVE);
2171 for x in scaled.iter_mut() {
2172 *x *= inv;
2173 }
2174
2175 if top_p < 1.0 {
2178 let mut indexed: Vec<(usize, f32)> = scaled.iter().copied().enumerate().collect();
2179 indexed.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap());
2180 let mut cum = 0.0f32;
2181 let mut keep = vec![false; v];
2182 for (idx, p) in indexed.iter() {
2183 keep[*idx] = true;
2184 cum += *p;
2185 if cum >= top_p {
2186 break;
2187 }
2188 }
2189 let mut new_sum = 0.0f32;
2190 for (i, x) in scaled.iter_mut().enumerate() {
2191 if !keep[i] {
2192 *x = 0.0;
2193 }
2194 new_sum += *x;
2195 }
2196 let inv = 1.0 / new_sum.max(f32::MIN_POSITIVE);
2197 for x in scaled.iter_mut() {
2198 *x *= inv;
2199 }
2200 }
2201
2202 let r = rng.next_f32();
2204 let mut acc = 0.0f32;
2205 for (i, &p) in scaled.iter().enumerate() {
2206 acc += p;
2207 if r <= acc {
2208 return i;
2209 }
2210 }
2211 v - 1 }
2213
2214#[inline]
2218fn apply_synthetic_mask(
2219 scores: &mut [f32],
2220 q_seq: usize,
2221 k_seq: usize,
2222 kind: rlx_ir::op::MaskKind,
2223) {
2224 let neg = crate::config::RuntimeConfig::global().attn_mask_neg_inf;
2225 let q_offset = k_seq.saturating_sub(q_seq);
2226 match kind {
2227 rlx_ir::op::MaskKind::None | rlx_ir::op::MaskKind::Custom | rlx_ir::op::MaskKind::Bias => {}
2228 rlx_ir::op::MaskKind::Causal => {
2229 for qi in 0..q_seq {
2230 let abs_q = q_offset + qi;
2231 for ki in (abs_q + 1)..k_seq {
2232 scores[qi * k_seq + ki] = neg;
2233 }
2234 }
2235 }
2236 rlx_ir::op::MaskKind::SlidingWindow(w) => {
2237 for qi in 0..q_seq {
2238 let abs_q = q_offset + qi;
2239 let lo = abs_q.saturating_sub(w);
2240 for ki in 0..k_seq {
2241 if ki < lo || ki > abs_q {
2242 scores[qi * k_seq + ki] = neg;
2243 }
2244 }
2245 }
2246 }
2247 }
2248}
2249
2250fn conv_nchw_dims(shape: &Shape) -> (u32, u32, u32, u32) {
2252 match shape.rank() {
2253 3 => (
2254 shape.dim(0).unwrap_static() as u32,
2255 shape.dim(1).unwrap_static() as u32,
2256 1,
2257 shape.dim(2).unwrap_static() as u32,
2258 ),
2259 4 => (
2260 shape.dim(0).unwrap_static() as u32,
2261 shape.dim(1).unwrap_static() as u32,
2262 shape.dim(2).unwrap_static() as u32,
2263 shape.dim(3).unwrap_static() as u32,
2264 ),
2265 r => panic!("conv_nchw_dims: expected rank 3 or 4, got {r}"),
2266 }
2267}
2268
2269pub fn compile_thunks(graph: &Graph, arena: &Arena) -> ThunkSchedule {
2271 let mut thunks = Vec::with_capacity(graph.len());
2272
2273 for node in graph.nodes() {
2274 if rlx_opt::is_pure_view(graph, node) {
2278 thunks.push(Thunk::Nop);
2279 continue;
2280 }
2281 let t = match &node.op {
2282 Op::Input { .. } | Op::Param { .. } | Op::Constant { .. } => Thunk::Nop,
2283
2284 Op::FusedMatMulBiasAct { activation } => {
2285 let shape = &node.shape;
2286 let n = shape.dim(shape.rank() - 1).unwrap_static();
2287 let total = shape.num_elements().unwrap();
2288 let m = total / n;
2289 let a_len = get_len(graph, node.inputs[0]);
2290 let k = a_len / m;
2291 Thunk::FusedMmBiasAct {
2292 a: node_offset(arena, node.inputs[0]),
2293 w: node_offset(arena, node.inputs[1]),
2294 bias: node_offset(arena, node.inputs[2]),
2295 c: node_offset(arena, node.id),
2296 m: m as u32,
2297 k: k as u32,
2298 n: n as u32,
2299 act: *activation,
2300 }
2301 }
2302
2303 Op::FusedResidualLN { has_bias, eps } => {
2304 let h = node.shape.dim(node.shape.rank() - 1).unwrap_static();
2305 let total = node.shape.num_elements().unwrap();
2306 let rows = total / h;
2307 let (g_idx, b_idx) = if *has_bias { (3, 4) } else { (2, 3) };
2308 Thunk::FusedResidualLN {
2309 x: node_offset(arena, node.inputs[0]),
2310 res: node_offset(arena, node.inputs[1]),
2311 bias: if *has_bias {
2312 node_offset(arena, node.inputs[2])
2313 } else {
2314 0
2315 },
2316 g: node_offset(arena, node.inputs[g_idx]),
2317 b: node_offset(arena, node.inputs[b_idx]),
2318 out: node_offset(arena, node.id),
2319 rows: rows as u32,
2320 h: h as u32,
2321 eps: *eps,
2322 has_bias: *has_bias,
2323 }
2324 }
2325
2326 Op::FusedResidualRmsNorm { has_bias, eps } => {
2327 let h = node.shape.dim(node.shape.rank() - 1).unwrap_static();
2328 let total = node.shape.num_elements().unwrap();
2329 let rows = total / h;
2330 let (g_idx, b_idx) = if *has_bias { (3, 4) } else { (2, 3) };
2331 Thunk::FusedResidualRmsNorm {
2332 x: node_offset(arena, node.inputs[0]),
2333 res: node_offset(arena, node.inputs[1]),
2334 bias: if *has_bias {
2335 node_offset(arena, node.inputs[2])
2336 } else {
2337 0
2338 },
2339 g: node_offset(arena, node.inputs[g_idx]),
2340 b: node_offset(arena, node.inputs[b_idx]),
2341 out: node_offset(arena, node.id),
2342 rows: rows as u32,
2343 h: h as u32,
2344 eps: *eps,
2345 has_bias: *has_bias,
2346 }
2347 }
2348
2349 Op::MatMul => {
2350 let shape = &node.shape;
2351 let a_shape = &graph.node(node.inputs[0]).shape;
2352 let b_shape = &graph.node(node.inputs[1]).shape;
2353 let eff =
2356 rlx_ir::shape::matmul_shape(a_shape, b_shape).unwrap_or_else(|_| shape.clone());
2357 let rank = eff.rank().max(2);
2358 let n = eff.dim(rank - 1).unwrap_static();
2359 let k_dim = a_shape.dim(a_shape.rank().max(2) - 1).unwrap_static();
2360 let both_batched = a_shape.rank() >= 3 && b_shape.rank() >= 3;
2363 let batched_3d = rank >= 3 && both_batched && a_shape.rank() + b_shape.rank() > 4;
2364 if batched_3d && shape.dtype() == rlx_ir::DType::F64 {
2365 let mut batch_prod = 1usize;
2366 for d in 0..rank - 2 {
2367 batch_prod *= eff.dim(d).unwrap_static();
2368 }
2369 let m_dim = eff.dim(rank - 2).unwrap_static();
2370 Thunk::BatchedDgemmF64 {
2371 a: node_offset(arena, node.inputs[0]),
2372 b: node_offset(arena, node.inputs[1]),
2373 c: node_offset(arena, node.id),
2374 batch: batch_prod as u32,
2375 m: m_dim as u32,
2376 k: k_dim as u32,
2377 n: n as u32,
2378 }
2379 } else if batched_3d && shape.dtype() == rlx_ir::DType::F32 {
2380 let mut batch_prod = 1usize;
2381 for d in 0..rank - 2 {
2382 batch_prod *= eff.dim(d).unwrap_static();
2383 }
2384 let m_dim = eff.dim(rank - 2).unwrap_static();
2385 Thunk::BatchedSgemm {
2386 a: node_offset(arena, node.inputs[0]),
2387 b: node_offset(arena, node.inputs[1]),
2388 c: node_offset(arena, node.id),
2389 batch: batch_prod as u32,
2390 m: m_dim as u32,
2391 k: k_dim as u32,
2392 n: n as u32,
2393 }
2394 } else {
2395 let m = if a_shape.rank() >= 3 && b_shape.rank() <= 2 {
2396 let mut m_prod = 1usize;
2397 for d in 0..a_shape.rank() - 1 {
2398 m_prod *= a_shape.dim(d).unwrap_static();
2399 }
2400 m_prod
2401 } else if a_shape.rank() >= 2 {
2402 a_shape.dim(a_shape.rank() - 2).unwrap_static()
2403 } else {
2404 eff.num_elements().unwrap_or(1) / n.max(1)
2405 };
2406 match shape.dtype() {
2407 rlx_ir::DType::F64 => Thunk::Dgemm {
2408 a: node_offset(arena, node.inputs[0]),
2409 b: node_offset(arena, node.inputs[1]),
2410 c: node_offset(arena, node.id),
2411 m: m as u32,
2412 k: k_dim as u32,
2413 n: n as u32,
2414 },
2415 _ => Thunk::Sgemm {
2416 a: node_offset(arena, node.inputs[0]),
2417 b: node_offset(arena, node.inputs[1]),
2418 c: node_offset(arena, node.id),
2419 m: m as u32,
2420 k: k_dim as u32,
2421 n: n as u32,
2422 },
2423 }
2424 }
2425 }
2426
2427 Op::Binary(op) => {
2428 let lhs_len = get_len(graph, node.inputs[0]);
2429 let rhs_len = get_len(graph, node.inputs[1]);
2430 let out_len = node.shape.num_elements().unwrap();
2431 if node.shape.dtype() == rlx_ir::DType::C64 {
2432 match op {
2436 BinaryOp::Add | BinaryOp::Sub | BinaryOp::Mul | BinaryOp::Div => {}
2437 BinaryOp::Max | BinaryOp::Min | BinaryOp::Pow => panic!(
2438 "Op::Binary({op:?}) on DType::C64: complex \
2439 max/min/pow have no single natural definition \
2440 — caller should drop to 2N-real-block (see \
2441 spike-ac) and pick a convention there"
2442 ),
2443 }
2444 }
2445 let (out_dims_bcast, bcast_lhs_strides, bcast_rhs_strides) =
2449 if lhs_len == out_len && rhs_len == out_len {
2450 (Vec::new(), Vec::new(), Vec::new())
2451 } else {
2452 let lhs_dims = get_static_dims(graph, node.inputs[0]);
2453 let rhs_dims = get_static_dims(graph, node.inputs[1]);
2454 let out_dims_v = get_static_dims(graph, node.id);
2455 if lhs_dims.is_empty() || rhs_dims.is_empty() || out_dims_v.is_empty() {
2456 (Vec::new(), Vec::new(), Vec::new())
2461 } else {
2462 let ls = broadcast_strides(&lhs_dims, &out_dims_v);
2463 let rs = broadcast_strides(&rhs_dims, &out_dims_v);
2464 let od: Vec<u32> = out_dims_v.iter().map(|x| *x as u32).collect();
2465 (od, ls, rs)
2466 }
2467 };
2468 if node.shape.dtype() == rlx_ir::DType::C64 {
2469 Thunk::BinaryFullC64 {
2470 lhs: node_offset(arena, node.inputs[0]),
2471 rhs: node_offset(arena, node.inputs[1]),
2472 dst: node_offset(arena, node.id),
2473 len: out_len as u32,
2474 lhs_len: lhs_len as u32,
2475 rhs_len: rhs_len as u32,
2476 op: *op,
2477 out_dims_bcast,
2478 bcast_lhs_strides,
2479 bcast_rhs_strides,
2480 }
2481 } else if node.shape.dtype() == rlx_ir::DType::F64 {
2482 Thunk::BinaryFullF64 {
2485 lhs: node_offset(arena, node.inputs[0]),
2486 rhs: node_offset(arena, node.inputs[1]),
2487 dst: node_offset(arena, node.id),
2488 len: out_len as u32,
2489 lhs_len: lhs_len as u32,
2490 rhs_len: rhs_len as u32,
2491 op: *op,
2492 out_dims_bcast,
2493 bcast_lhs_strides,
2494 bcast_rhs_strides,
2495 }
2496 } else if matches!(op, BinaryOp::Add)
2497 && rhs_len < out_len
2498 && out_len % rhs_len == 0
2499 && is_trailing_bias_broadcast(
2500 graph.node(node.inputs[1]).shape.dims(),
2501 graph.node(node.id).shape.dims(),
2502 )
2503 {
2504 Thunk::BiasAdd {
2514 src: node_offset(arena, node.inputs[0]),
2515 bias: node_offset(arena, node.inputs[1]),
2516 dst: node_offset(arena, node.id),
2517 m: (out_len / rhs_len) as u32,
2518 n: rhs_len as u32,
2519 }
2520 } else {
2521 let lhs_len = get_len(graph, node.inputs[0]);
2522 Thunk::BinaryFull {
2523 lhs: node_offset(arena, node.inputs[0]),
2524 rhs: node_offset(arena, node.inputs[1]),
2525 dst: node_offset(arena, node.id),
2526 len: out_len as u32,
2527 lhs_len: lhs_len as u32,
2528 rhs_len: rhs_len as u32,
2529 op: *op,
2530 out_dims_bcast,
2531 bcast_lhs_strides,
2532 bcast_rhs_strides,
2533 elem_bytes: node.shape.dtype().size_bytes() as u8,
2534 }
2535 }
2536 }
2537
2538 Op::Activation(act) => {
2539 let len = node.shape.num_elements().unwrap();
2540 let in_off = node_offset(arena, node.inputs[0]);
2541 let out_off = node_offset(arena, node.id);
2542 if node.shape.dtype() == rlx_ir::DType::C64 {
2543 match act {
2548 Activation::Neg | Activation::Exp | Activation::Log | Activation::Sqrt => {}
2549 other => panic!(
2550 "Op::Activation({other:?}) on DType::C64: no \
2551 natural complex extension — supported on C64: \
2552 Neg, Exp, Log, Sqrt"
2553 ),
2554 }
2555 Thunk::ActivationC64 {
2556 src: in_off,
2557 dst: out_off,
2558 len: len as u32,
2559 kind: *act,
2560 }
2561 } else if node.shape.dtype() == rlx_ir::DType::F64 {
2562 Thunk::ActivationF64 {
2563 src: in_off,
2564 dst: out_off,
2565 len: len as u32,
2566 kind: *act,
2567 }
2568 } else if in_off == out_off {
2569 Thunk::ActivationInPlace {
2573 data: out_off,
2574 len: len as u32,
2575 act: *act,
2576 }
2577 } else {
2578 thunks.push(Thunk::Copy {
2582 src: in_off,
2583 dst: out_off,
2584 len: len as u32,
2585 });
2586 Thunk::ActivationInPlace {
2587 data: out_off,
2588 len: len as u32,
2589 act: *act,
2590 }
2591 }
2592 }
2593
2594 Op::Gather { axis } if *axis == 0 => {
2595 let table_shape = &graph.node(node.inputs[0]).shape;
2596 let table_total = table_shape.num_elements().unwrap();
2597 let trailing: usize = (1..table_shape.rank())
2598 .map(|i| table_shape.dim(i).unwrap_static())
2599 .product();
2600 let idx_len = get_len(graph, node.inputs[1]);
2601 let idx_i64 =
2602 u8::from(graph.node(node.inputs[1]).shape.dtype() == rlx_ir::DType::I64);
2603 let table_bytes = graph.node(node.inputs[0]).shape.dtype().size_bytes() as u8;
2604 Thunk::Gather {
2605 table: node_offset(arena, node.inputs[0]),
2606 table_len: table_total as u32,
2607 idx: node_offset(arena, node.inputs[1]),
2608 dst: node_offset(arena, node.id),
2609 num_idx: idx_len as u32,
2610 trailing: trailing as u32,
2611 idx_i64,
2612 table_bytes,
2613 }
2614 }
2615
2616 Op::Gather { axis } => {
2617 let table_shape = &graph.node(node.inputs[0]).shape;
2619 let rank = table_shape.rank();
2620 let outer: usize = (0..*axis)
2621 .map(|i| table_shape.dim(i).unwrap_static())
2622 .product::<usize>()
2623 .max(1);
2624 let trailing: usize = (*axis + 1..rank)
2625 .map(|i| table_shape.dim(i).unwrap_static())
2626 .product::<usize>()
2627 .max(1);
2628 let axis_dim = table_shape.dim(*axis).unwrap_static();
2629 let idx_len = get_len(graph, node.inputs[1]);
2630 let idx_i64 =
2631 u8::from(graph.node(node.inputs[1]).shape.dtype() == rlx_ir::DType::I64);
2632 let table_bytes = graph.node(node.inputs[0]).shape.dtype().size_bytes() as u8;
2633 Thunk::GatherAxis {
2634 table: node_offset(arena, node.inputs[0]),
2635 idx: node_offset(arena, node.inputs[1]),
2636 dst: node_offset(arena, node.id),
2637 outer: outer as u32,
2638 axis_dim: axis_dim as u32,
2639 num_idx: idx_len as u32,
2640 trailing: trailing as u32,
2641 idx_i64,
2642 table_bytes,
2643 }
2644 }
2645
2646 Op::Narrow { axis, start, len } => {
2647 let in_shape = &graph.node(node.inputs[0]).shape;
2648 let elem_bytes = in_shape.dtype().size_bytes() as u8;
2649 let rank = in_shape.rank();
2650 let outer: usize = (0..*axis)
2651 .map(|i| in_shape.dim(i).unwrap_static())
2652 .product::<usize>()
2653 .max(1);
2654 let inner: usize = (*axis + 1..rank)
2655 .map(|i| in_shape.dim(i).unwrap_static())
2656 .product::<usize>()
2657 .max(1);
2658 let in_axis = in_shape.dim(*axis).unwrap_static();
2659 let src_byte_offset =
2660 node_offset(arena, node.inputs[0]) + start * inner * elem_bytes as usize;
2661 Thunk::Narrow {
2662 src: src_byte_offset,
2663 dst: node_offset(arena, node.id),
2664 outer: outer as u32,
2665 src_stride: (in_axis * inner) as u32, dst_stride: (*len * inner) as u32, inner: (*len * inner) as u32, elem_bytes,
2669 }
2670 }
2671
2672 Op::Reshape { .. } | Op::StopGradient => {
2673 let len = node.shape.num_elements().unwrap();
2675 let src = node_offset(arena, node.inputs[0]);
2676 let dst = node_offset(arena, node.id);
2677 match node.shape.dtype() {
2678 rlx_ir::DType::F64 => Thunk::CopyF64 {
2679 src,
2680 dst,
2681 len: len as u32,
2682 },
2683 rlx_ir::DType::I64 => Thunk::CopyI64 {
2684 src,
2685 dst,
2686 len: len as u32,
2687 },
2688 _ => Thunk::Copy {
2689 src,
2690 dst,
2691 len: len as u32,
2692 },
2693 }
2694 }
2695
2696 Op::Cast { to } => {
2697 let in_node = graph.node(node.inputs[0]);
2698 let in_dtype = in_node.shape.dtype();
2699 let out_dtype = *to;
2700 let len = node.shape.num_elements().unwrap();
2701 let src = node_offset(arena, node.inputs[0]);
2702 let dst = node_offset(arena, node.id);
2703 if in_dtype == rlx_ir::DType::F32 && out_dtype == rlx_ir::DType::I64 {
2704 Thunk::CastF32ToI64 {
2705 src,
2706 dst,
2707 len: len as u32,
2708 }
2709 } else if in_dtype == rlx_ir::DType::I64 && out_dtype == rlx_ir::DType::F32 {
2710 Thunk::CastI64ToF32 {
2711 src,
2712 dst,
2713 len: len as u32,
2714 }
2715 } else if in_dtype == rlx_ir::DType::Bool && out_dtype == rlx_ir::DType::I32 {
2716 Thunk::CastBoolToI32 {
2717 src,
2718 dst,
2719 len: len as u32,
2720 }
2721 } else if in_dtype == rlx_ir::DType::I32 && out_dtype == rlx_ir::DType::F32 {
2722 Thunk::CastI32ToF32 {
2723 src,
2724 dst,
2725 len: len as u32,
2726 }
2727 } else if in_dtype == out_dtype {
2728 match out_dtype {
2729 rlx_ir::DType::F64 => Thunk::CopyF64 {
2730 src,
2731 dst,
2732 len: len as u32,
2733 },
2734 rlx_ir::DType::I64 => Thunk::CopyI64 {
2735 src,
2736 dst,
2737 len: len as u32,
2738 },
2739 _ => Thunk::Copy {
2740 src,
2741 dst,
2742 len: len as u32,
2743 },
2744 }
2745 } else {
2746 Thunk::Copy {
2747 src,
2748 dst,
2749 len: len as u32,
2750 }
2751 }
2752 }
2753
2754 Op::Quantize {
2755 axis,
2756 scales,
2757 zero_points,
2758 } => {
2759 let (chan_axis, chan_dim, inner) = quant_layout(&node.shape, *axis);
2760 Thunk::Quantize {
2761 x: node_offset(arena, node.inputs[0]),
2762 q: node_offset(arena, node.id),
2763 len: node.shape.num_elements().unwrap() as u32,
2764 chan_axis: chan_axis as u32,
2765 chan_dim: chan_dim as u32,
2766 inner: inner as u32,
2767 scales: scales.clone(),
2768 zero_points: zero_points.clone(),
2769 }
2770 }
2771
2772 Op::FakeQuantize {
2773 bits,
2774 axis,
2775 ste,
2776 scale_mode,
2777 } => {
2778 let (chan_axis, chan_dim, inner) = quant_layout(&node.shape, *axis);
2779 let state_off = match scale_mode {
2780 rlx_ir::op::ScaleMode::PerBatch => None,
2781 rlx_ir::op::ScaleMode::EMA { .. } | rlx_ir::op::ScaleMode::Fixed => {
2782 debug_assert_eq!(
2784 node.inputs.len(),
2785 2,
2786 "EMA/Fixed FakeQuantize needs a state input"
2787 );
2788 Some(node_offset(arena, node.inputs[1]))
2789 }
2790 };
2791 Thunk::FakeQuantize {
2792 x: node_offset(arena, node.inputs[0]),
2793 out: node_offset(arena, node.id),
2794 len: node.shape.num_elements().unwrap() as u32,
2795 chan_axis: chan_axis as u32,
2796 chan_dim: chan_dim as u32,
2797 inner: inner as u32,
2798 bits: *bits,
2799 ste: *ste,
2800 scale_mode: *scale_mode,
2801 state_off,
2802 }
2803 }
2804
2805 Op::FakeQuantizeLSQ { bits, axis } => {
2806 let (chan_axis, chan_dim, inner) = quant_layout(&node.shape, *axis);
2807 Thunk::FakeQuantizeLSQ {
2808 x: node_offset(arena, node.inputs[0]),
2809 scale_off: node_offset(arena, node.inputs[1]),
2810 out: node_offset(arena, node.id),
2811 len: node.shape.num_elements().unwrap() as u32,
2812 chan_axis: chan_axis as u32,
2813 chan_dim: chan_dim as u32,
2814 inner: inner as u32,
2815 bits: *bits,
2816 }
2817 }
2818
2819 Op::FakeQuantizeLSQBackwardX { bits, axis } => {
2820 let (chan_axis, chan_dim, inner) = quant_layout(&node.shape, *axis);
2821 Thunk::FakeQuantizeLSQBackwardX {
2822 x: node_offset(arena, node.inputs[0]),
2823 scale_off: node_offset(arena, node.inputs[1]),
2824 dy: node_offset(arena, node.inputs[2]),
2825 dx: node_offset(arena, node.id),
2826 len: node.shape.num_elements().unwrap() as u32,
2827 chan_axis: chan_axis as u32,
2828 chan_dim: chan_dim as u32,
2829 inner: inner as u32,
2830 bits: *bits,
2831 }
2832 }
2833
2834 Op::FakeQuantizeLSQBackwardScale { bits, axis } => {
2835 let in_shape = &graph.node(node.inputs[0]).shape;
2838 let (chan_axis, chan_dim, inner) = quant_layout(in_shape, *axis);
2839 Thunk::FakeQuantizeLSQBackwardScale {
2840 x: node_offset(arena, node.inputs[0]),
2841 scale_off: node_offset(arena, node.inputs[1]),
2842 dy: node_offset(arena, node.inputs[2]),
2843 dscale: node_offset(arena, node.id),
2844 len: in_shape.num_elements().unwrap() as u32,
2845 chan_axis: chan_axis as u32,
2846 chan_dim: chan_dim as u32,
2847 inner: inner as u32,
2848 bits: *bits,
2849 }
2850 }
2851
2852 Op::FakeQuantizeBackward { bits, axis, ste } => {
2853 let (chan_axis, chan_dim, inner) = quant_layout(&node.shape, *axis);
2854 Thunk::FakeQuantizeBackward {
2855 x: node_offset(arena, node.inputs[0]),
2856 dy: node_offset(arena, node.inputs[1]),
2857 dx: node_offset(arena, node.id),
2858 len: node.shape.num_elements().unwrap() as u32,
2859 chan_axis: chan_axis as u32,
2860 chan_dim: chan_dim as u32,
2861 inner: inner as u32,
2862 bits: *bits,
2863 ste: *ste,
2864 }
2865 }
2866
2867 Op::Dequantize {
2868 axis,
2869 scales,
2870 zero_points,
2871 } => {
2872 let (chan_axis, chan_dim, inner) = quant_layout(&node.shape, *axis);
2873 Thunk::Dequantize {
2874 q: node_offset(arena, node.inputs[0]),
2875 x: node_offset(arena, node.id),
2876 len: node.shape.num_elements().unwrap() as u32,
2877 chan_axis: chan_axis as u32,
2878 chan_dim: chan_dim as u32,
2879 inner: inner as u32,
2880 scales: scales.clone(),
2881 zero_points: zero_points.clone(),
2882 }
2883 }
2884
2885 Op::Expand { .. } => {
2886 let in_shape = &graph.node(node.inputs[0]).shape;
2891 let out_shape = &node.shape;
2892 let in_rank = in_shape.rank();
2893 let out_rank = out_shape.rank();
2894 let pad = out_rank.saturating_sub(in_rank);
2896 let in_dims: Vec<usize> = (0..out_rank)
2897 .map(|i| {
2898 if i < pad {
2899 1
2900 } else {
2901 in_shape.dim(i - pad).unwrap_static()
2902 }
2903 })
2904 .collect();
2905 let mut in_strides_full = vec![1usize; out_rank];
2907 for d in (0..out_rank.saturating_sub(1)).rev() {
2908 in_strides_full[d] = in_strides_full[d + 1] * in_dims[d + 1];
2909 }
2910 let out_dims: Vec<u32> = (0..out_rank)
2911 .map(|i| out_shape.dim(i).unwrap_static() as u32)
2912 .collect();
2913 let in_strides: Vec<u32> = (0..out_rank)
2915 .map(|i| {
2916 if in_dims[i] == 1 && (out_dims[i] as usize) > 1 {
2917 0
2918 } else {
2919 in_strides_full[i] as u32
2920 }
2921 })
2922 .collect();
2923 let in_total = in_dims.iter().product::<usize>() as u32;
2924 let src = node_offset(arena, node.inputs[0]);
2925 let dst = node_offset(arena, node.id);
2926 let elem_bytes = node.shape.dtype().size_bytes() as u8;
2927 match node.shape.dtype() {
2928 rlx_ir::DType::F64 => Thunk::TransposeF64 {
2929 src,
2930 dst,
2931 in_total,
2932 out_dims,
2933 in_strides,
2934 },
2935 _ => Thunk::Transpose {
2936 src,
2937 dst,
2938 in_total,
2939 out_dims,
2940 in_strides,
2941 elem_bytes,
2942 },
2943 }
2944 }
2945
2946 Op::RmsNorm { eps, .. } => {
2947 let h = node.shape.dim(node.shape.rank() - 1).unwrap_static();
2948 let total = node.shape.num_elements().unwrap();
2949 Thunk::RmsNorm {
2950 src: node_offset(arena, node.inputs[0]),
2951 g: node_offset(arena, node.inputs[1]),
2952 b: node_offset(arena, node.inputs[2]),
2953 dst: node_offset(arena, node.id),
2954 rows: (total / h) as u32,
2955 h: h as u32,
2956 eps: *eps,
2957 }
2958 }
2959
2960 Op::LayerNorm { eps, .. } => {
2961 let h = node.shape.dim(node.shape.rank() - 1).unwrap_static();
2962 let total = node.shape.num_elements().unwrap();
2963 Thunk::LayerNorm {
2964 src: node_offset(arena, node.inputs[0]),
2965 g: node_offset(arena, node.inputs[1]),
2966 b: node_offset(arena, node.inputs[2]),
2967 dst: node_offset(arena, node.id),
2968 rows: (total / h) as u32,
2969 h: h as u32,
2970 eps: *eps,
2971 }
2972 }
2973
2974 Op::GroupNorm { num_groups, eps } => {
2975 let in_shape = &graph.node(node.inputs[0]).shape;
2976 let (n, c, h, w) = conv_nchw_dims(in_shape);
2977 Thunk::GroupNorm {
2978 src: node_offset(arena, node.inputs[0]),
2979 g: node_offset(arena, node.inputs[1]),
2980 b: node_offset(arena, node.inputs[2]),
2981 dst: node_offset(arena, node.id),
2982 n,
2983 c,
2984 h,
2985 w,
2986 num_groups: *num_groups as u32,
2987 eps: *eps,
2988 }
2989 }
2990
2991 Op::BatchNormInference { eps } => {
2992 let in_shape = &graph.node(node.inputs[0]).shape;
2993 let rank = in_shape.rank();
2994 let channels = in_shape.dim(rank - 1).unwrap_static();
2995 let total = in_shape.num_elements().unwrap_or(0);
2996 let count = (total / channels.max(1)) as u32;
2997 Thunk::BatchNormInference {
2998 src: node_offset(arena, node.inputs[0]),
2999 g: node_offset(arena, node.inputs[1]),
3000 b: node_offset(arena, node.inputs[2]),
3001 mean: node_offset(arena, node.inputs[3]),
3002 var: node_offset(arena, node.inputs[4]),
3003 dst: node_offset(arena, node.id),
3004 count,
3005 channels: channels as u32,
3006 eps: *eps,
3007 }
3008 }
3009
3010 Op::BatchNormInferenceBackwardInput { eps } => {
3011 let x_shape = &graph.node(node.inputs[0]).shape;
3012 let rank = x_shape.rank();
3013 let channels = x_shape.dim(rank - 1).unwrap_static();
3014 let total = x_shape.num_elements().unwrap_or(0);
3015 Thunk::BatchNormInferenceBackwardInput {
3016 x: node_offset(arena, node.inputs[0]),
3017 gamma: node_offset(arena, node.inputs[1]),
3018 mean: node_offset(arena, node.inputs[2]),
3019 var: node_offset(arena, node.inputs[3]),
3020 dy: node_offset(arena, node.inputs[4]),
3021 dx: node_offset(arena, node.id),
3022 count: (total / channels.max(1)) as u32,
3023 channels: channels as u32,
3024 eps: *eps,
3025 }
3026 }
3027
3028 Op::BatchNormInferenceBackwardGamma { eps } => {
3029 let x_shape = &graph.node(node.inputs[0]).shape;
3030 let rank = x_shape.rank();
3031 let channels = x_shape.dim(rank - 1).unwrap_static();
3032 let total = x_shape.num_elements().unwrap_or(0);
3033 let _gamma_shape = &graph.node(node.id).shape;
3034 Thunk::BatchNormInferenceBackwardGamma {
3035 x: node_offset(arena, node.inputs[0]),
3036 mean: node_offset(arena, node.inputs[1]),
3037 var: node_offset(arena, node.inputs[2]),
3038 dy: node_offset(arena, node.inputs[3]),
3039 dgamma: node_offset(arena, node.id),
3040 count: (total / channels.max(1)) as u32,
3041 channels: channels as u32,
3042 eps: *eps,
3043 }
3044 }
3045
3046 Op::BatchNormInferenceBackwardBeta => {
3047 let dy_shape = &graph.node(node.inputs[0]).shape;
3048 let rank = dy_shape.rank();
3049 let channels = dy_shape.dim(rank - 1).unwrap_static();
3050 let total = dy_shape.num_elements().unwrap_or(0);
3051 Thunk::BatchNormInferenceBackwardBeta {
3052 dy: node_offset(arena, node.inputs[0]),
3053 dbeta: node_offset(arena, node.id),
3054 count: (total / channels.max(1)) as u32,
3055 channels: channels as u32,
3056 }
3057 }
3058
3059 Op::LayerNorm2d { eps } => {
3060 let in_shape = &graph.node(node.inputs[0]).shape;
3061 let (n, c, h, w) = conv_nchw_dims(in_shape);
3062 Thunk::LayerNorm2d {
3063 src: node_offset(arena, node.inputs[0]),
3064 g: node_offset(arena, node.inputs[1]),
3065 b: node_offset(arena, node.inputs[2]),
3066 dst: node_offset(arena, node.id),
3067 n,
3068 c,
3069 h,
3070 w,
3071 eps: *eps,
3072 }
3073 }
3074
3075 Op::ConvTranspose2d {
3076 kernel_size,
3077 stride,
3078 padding,
3079 dilation,
3080 output_padding: _,
3081 groups,
3082 } => {
3083 let in_shape = &graph.node(node.inputs[0]).shape;
3084 let out_shape = &node.shape;
3085 let (n, c_in, h, w_in) = conv_nchw_dims(in_shape);
3086 let (_, c_out, h_out, w_out) = conv_nchw_dims(out_shape);
3087 Thunk::ConvTranspose2d {
3088 src: node_offset(arena, node.inputs[0]),
3089 weight: node_offset(arena, node.inputs[1]),
3090 dst: node_offset(arena, node.id),
3091 n,
3092 c_in,
3093 h,
3094 w_in,
3095 c_out,
3096 h_out,
3097 w_out,
3098 kh: kernel_size[0] as u32,
3099 kw: kernel_size[1] as u32,
3100 sh: stride.first().copied().unwrap_or(1) as u32,
3101 sw: stride.get(1).copied().unwrap_or(1) as u32,
3102 ph: padding.first().copied().unwrap_or(0) as u32,
3103 pw: padding.get(1).copied().unwrap_or(0) as u32,
3104 dh: dilation.first().copied().unwrap_or(1) as u32,
3105 dw: dilation.get(1).copied().unwrap_or(1) as u32,
3106 groups: *groups as u32,
3107 }
3108 }
3109
3110 Op::ResizeNearest2x => {
3111 let in_shape = &graph.node(node.inputs[0]).shape;
3112 let (n, c, h, w) = conv_nchw_dims(in_shape);
3113 Thunk::ResizeNearest2x {
3114 src: node_offset(arena, node.inputs[0]),
3115 dst: node_offset(arena, node.id),
3116 n,
3117 c,
3118 h,
3119 w,
3120 }
3121 }
3122
3123 Op::AxialRope2d {
3124 end_x,
3125 end_y,
3126 head_dim,
3127 num_heads,
3128 theta,
3129 repeat_factor,
3130 } => {
3131 let in_shape = &graph.node(node.inputs[0]).shape;
3132 let batch = in_shape.dim(0).unwrap_static() as u32;
3133 let seq = in_shape.dim(1).unwrap_static() as u32;
3134 let hidden = in_shape.dim(2).unwrap_static() as u32;
3135 Thunk::AxialRope2d {
3136 src: node_offset(arena, node.inputs[0]),
3137 dst: node_offset(arena, node.id),
3138 batch,
3139 seq,
3140 hidden,
3141 end_x: *end_x as u32,
3142 end_y: *end_y as u32,
3143 head_dim: *head_dim as u32,
3144 num_heads: *num_heads as u32,
3145 theta: *theta,
3146 repeat_factor: *repeat_factor as u32,
3147 }
3148 }
3149
3150 Op::Softmax { axis } => {
3151 let rank = node.shape.rank();
3152 let ax = if *axis < 0 {
3153 (rank as i32 + axis) as usize
3154 } else {
3155 *axis as usize
3156 };
3157 let cols = node.shape.dim(ax).unwrap_static();
3158 let total = node.shape.num_elements().unwrap();
3159 let in_off = node_offset(arena, node.inputs[0]);
3160 let out_off = node_offset(arena, node.id);
3161 if in_off != out_off {
3167 thunks.push(Thunk::Copy {
3168 src: in_off,
3169 dst: out_off,
3170 len: total as u32,
3171 });
3172 }
3173 Thunk::Softmax {
3174 data: out_off,
3175 rows: (total / cols) as u32,
3176 cols: cols as u32,
3177 }
3178 }
3179
3180 Op::SelectiveScan { state_size } => {
3181 let in_shape = &graph.node(node.inputs[0]).shape;
3182 let (batch, seq, hidden) = (
3183 in_shape.dim(0).unwrap_static(),
3184 in_shape.dim(1).unwrap_static(),
3185 in_shape.dim(2).unwrap_static(),
3186 );
3187 Thunk::SelectiveScan {
3188 x: node_offset(arena, node.inputs[0]),
3189 delta: node_offset(arena, node.inputs[1]),
3190 a: node_offset(arena, node.inputs[2]),
3191 b: node_offset(arena, node.inputs[3]),
3192 c: node_offset(arena, node.inputs[4]),
3193 dst: node_offset(arena, node.id),
3194 batch: batch as u32,
3195 seq: seq as u32,
3196 hidden: hidden as u32,
3197 state_size: *state_size as u32,
3198 }
3199 }
3200
3201 Op::GatedDeltaNet {
3202 state_size,
3203 carry_state,
3204 } => {
3205 let q_shape = &graph.node(node.inputs[0]).shape;
3206 let (batch, seq, heads) = (
3207 q_shape.dim(0).unwrap_static(),
3208 q_shape.dim(1).unwrap_static(),
3209 q_shape.dim(2).unwrap_static(),
3210 );
3211 let state_off = if *carry_state {
3212 node_offset(arena, node.inputs[5])
3213 } else {
3214 0
3215 };
3216 Thunk::GatedDeltaNet {
3217 q: node_offset(arena, node.inputs[0]),
3218 k: node_offset(arena, node.inputs[1]),
3219 v: node_offset(arena, node.inputs[2]),
3220 g: node_offset(arena, node.inputs[3]),
3221 beta: node_offset(arena, node.inputs[4]),
3222 state: state_off,
3223 dst: node_offset(arena, node.id),
3224 batch: batch as u32,
3225 seq: seq as u32,
3226 heads: heads as u32,
3227 state_size: *state_size as u32,
3228 }
3229 }
3230
3231 Op::QMatMul {
3232 x_zp,
3233 w_zp,
3234 out_zp,
3235 mult,
3236 } => {
3237 let x_shape = &graph.node(node.inputs[0]).shape;
3238 let w_shape = &graph.node(node.inputs[1]).shape;
3239 let m = x_shape.dim(0).unwrap_static();
3240 let k = x_shape.dim(1).unwrap_static();
3241 let n = w_shape.dim(1).unwrap_static();
3242 Thunk::QMatMul {
3243 x: node_offset(arena, node.inputs[0]),
3244 w: node_offset(arena, node.inputs[1]),
3245 bias: node_offset(arena, node.inputs[2]),
3246 out: node_offset(arena, node.id),
3247 m: m as u32,
3248 k: k as u32,
3249 n: n as u32,
3250 x_zp: *x_zp,
3251 w_zp: *w_zp,
3252 out_zp: *out_zp,
3253 mult: *mult,
3254 }
3255 }
3256
// Quantized 2-D convolution. Lowered to `Thunk::QConv2d` only when the kernel
// has exactly two spatial entries and input, weight and output are all rank 4;
// every other configuration becomes `Thunk::Nop`.
// NOTE(review): a Nop presumably leaves the output buffer unwritten — confirm
// that graph validation rules those configurations out.
Op::QConv2d {
    kernel_size,
    stride,
    padding,
    dilation,
    groups,
    x_zp,
    w_zp,
    out_zp,
    mult,
} => {
    let in_shape = &graph.node(node.inputs[0]).shape;
    let w_shape = &graph.node(node.inputs[1]).shape;
    let out_shape = &node.shape;
    if kernel_size.len() == 2
        && in_shape.rank() == 4
        && w_shape.rank() == 4
        && out_shape.rank() == 4
    {
        // (n, c_in, h, w_in) are read from the input dims and (c_out, h_out,
        // w_out) from the output dims; w_shape only participates in the rank check.
        Thunk::QConv2d {
            x: node_offset(arena, node.inputs[0]),
            w: node_offset(arena, node.inputs[1]),
            bias: node_offset(arena, node.inputs[2]),
            out: node_offset(arena, node.id),
            n: in_shape.dim(0).unwrap_static() as u32,
            c_in: in_shape.dim(1).unwrap_static() as u32,
            h: in_shape.dim(2).unwrap_static() as u32,
            w_in: in_shape.dim(3).unwrap_static() as u32,
            c_out: out_shape.dim(1).unwrap_static() as u32,
            h_out: out_shape.dim(2).unwrap_static() as u32,
            w_out: out_shape.dim(3).unwrap_static() as u32,
            kh: kernel_size[0] as u32,
            kw: kernel_size[1] as u32,
            // Missing stride / padding / dilation entries default to 1 / 0 / 1.
            sh: stride.first().copied().unwrap_or(1) as u32,
            sw: stride.get(1).copied().unwrap_or(1) as u32,
            ph: padding.first().copied().unwrap_or(0) as u32,
            pw: padding.get(1).copied().unwrap_or(0) as u32,
            dh: dilation.first().copied().unwrap_or(1) as u32,
            dw: dilation.get(1).copied().unwrap_or(1) as u32,
            groups: *groups as u32,
            x_zp: *x_zp,
            w_zp: *w_zp,
            out_zp: *out_zp,
            mult: *mult,
        }
    } else {
        Thunk::Nop
    }
}
3306
// Matmul against quantized weights. n is the output's trailing dim, m folds the
// remaining output dims, and k = |x| / m (divisors clamped to 1). GGUF schemes
// take a single packed-weight input; the legacy schemes below read their extra
// scale / zero-point / global-scale tensors from inputs 2 and 3. Any other
// scheme panics.
Op::DequantMatMul { scheme } => {
    use rlx_ir::quant::QuantScheme;
    let n = node.shape.dim(node.shape.rank() - 1).unwrap_static();
    let total = node.shape.num_elements().unwrap();
    let m = total / n.max(1);
    let x_total = graph.node(node.inputs[0]).shape.num_elements().unwrap();
    let k = x_total / m.max(1);
    if scheme.is_gguf() {
        Thunk::DequantMatMulGguf {
            x: node_offset(arena, node.inputs[0]),
            w_q: node_offset(arena, node.inputs[1]),
            dst: node_offset(arena, node.id),
            m: m as u32,
            k: k as u32,
            n: n as u32,
            scheme: *scheme,
        }
    } else {
        match scheme {
            QuantScheme::Nvfp4Block => Thunk::DequantMatMulNvfp4 {
                x: node_offset(arena, node.inputs[0]),
                w_q: node_offset(arena, node.inputs[1]),
                scale: node_offset(arena, node.inputs[2]),
                global_scale: node_offset(arena, node.inputs[3]),
                dst: node_offset(arena, node.id),
                m: m as u32,
                k: k as u32,
                n: n as u32,
            },
            // NOTE(review): the symmetric Int4 path still reads inputs[3] as the
            // zero-point tensor — confirm the graph always provides a 4th input.
            QuantScheme::Int4Block { block_size } => Thunk::DequantMatMulInt4 {
                x: node_offset(arena, node.inputs[0]),
                w_q: node_offset(arena, node.inputs[1]),
                scale: node_offset(arena, node.inputs[2]),
                zp: node_offset(arena, node.inputs[3]),
                dst: node_offset(arena, node.id),
                m: m as u32,
                k: k as u32,
                n: n as u32,
                block_size: *block_size,
                is_asymmetric: false,
            },
            // Both FP8 encodings share one thunk; `e5m2` selects the format.
            QuantScheme::Fp8E4m3 => Thunk::DequantMatMulFp8 {
                x: node_offset(arena, node.inputs[0]),
                w_q: node_offset(arena, node.inputs[1]),
                scale: node_offset(arena, node.inputs[2]),
                dst: node_offset(arena, node.id),
                m: m as u32,
                k: k as u32,
                n: n as u32,
                e5m2: false,
            },
            QuantScheme::Fp8E5m2 => Thunk::DequantMatMulFp8 {
                x: node_offset(arena, node.inputs[0]),
                w_q: node_offset(arena, node.inputs[1]),
                scale: node_offset(arena, node.inputs[2]),
                dst: node_offset(arena, node.id),
                m: m as u32,
                k: k as u32,
                n: n as u32,
                e5m2: true,
            },
            // Int8 symmetric and asymmetric share a thunk; only `is_asymmetric`
            // differs. Both read a zero-point offset from inputs[3].
            QuantScheme::Int8Block { block_size } => Thunk::DequantMatMul {
                x: node_offset(arena, node.inputs[0]),
                w_q: node_offset(arena, node.inputs[1]),
                scale: node_offset(arena, node.inputs[2]),
                zp: node_offset(arena, node.inputs[3]),
                dst: node_offset(arena, node.id),
                m: m as u32,
                k: k as u32,
                n: n as u32,
                block_size: *block_size,
                is_asymmetric: false,
            },
            QuantScheme::Int8BlockAsym { block_size } => Thunk::DequantMatMul {
                x: node_offset(arena, node.inputs[0]),
                w_q: node_offset(arena, node.inputs[1]),
                scale: node_offset(arena, node.inputs[2]),
                zp: node_offset(arena, node.inputs[3]),
                dst: node_offset(arena, node.id),
                m: m as u32,
                k: k as u32,
                n: n as u32,
                block_size: *block_size,
                is_asymmetric: true,
            },
            other => panic!(
                "DequantMatMul on CPU supports Int8/Int4/FP8/NVFP4 legacy or GGUF schemes; got {other}"
            ),
        }
    }
}
3398
3399 Op::LoraMatMul { scale } => {
3400 let n = node.shape.dim(node.shape.rank() - 1).unwrap_static();
3402 let total = node.shape.num_elements().unwrap();
3403 let m = total / n.max(1);
3404 let x_total = graph.node(node.inputs[0]).shape.num_elements().unwrap();
3405 let k = x_total / m.max(1);
3406 let a_total = graph.node(node.inputs[2]).shape.num_elements().unwrap();
3407 let r = a_total / k.max(1);
3408 Thunk::LoraMatMul {
3409 x: node_offset(arena, node.inputs[0]),
3410 w: node_offset(arena, node.inputs[1]),
3411 a: node_offset(arena, node.inputs[2]),
3412 b: node_offset(arena, node.inputs[3]),
3413 dst: node_offset(arena, node.id),
3414 m: m as u32,
3415 k: k as u32,
3416 n: n as u32,
3417 r: r as u32,
3418 scale: *scale,
3419 }
3420 }
3421
// Token sampling over logits. For rank >= 2 logits, batch is dim 0 and vocab
// is the trailing dim. Lower-rank logits are one row whose vocab is the element
// count (0 when the count is unknown).
// NOTE(review): for rank >= 3 logits, middle dims are NOT folded into `batch`
// — confirm the output shape only expects dim-0 rows.
Op::Sample {
    top_k,
    top_p,
    temperature,
    seed,
} => {
    let in_shape = &graph.node(node.inputs[0]).shape;
    let (batch, vocab) = if in_shape.rank() >= 2 {
        (
            in_shape.dim(0).unwrap_static(),
            in_shape.dim(in_shape.rank() - 1).unwrap_static(),
        )
    } else {
        (1, in_shape.num_elements().unwrap_or(0))
    };
    Thunk::Sample {
        logits: node_offset(arena, node.inputs[0]),
        dst: node_offset(arena, node.id),
        batch: batch as u32,
        vocab: vocab as u32,
        top_k: *top_k as u32,
        top_p: *top_p,
        temperature: *temperature,
        seed: *seed,
    }
}
3449
3450 Op::Cumsum { axis, exclusive } => {
3451 let rank = node.shape.rank();
3456 let ax = if *axis < 0 {
3457 (rank as i32 + axis) as usize
3458 } else {
3459 *axis as usize
3460 };
3461 assert_eq!(
3462 ax,
3463 rank - 1,
3464 "Cumsum only supports the last axis on CPU today"
3465 );
3466 let cols = node.shape.dim(ax).unwrap_static();
3467 let total = node.shape.num_elements().unwrap();
3468 Thunk::Cumsum {
3469 src: node_offset(arena, node.inputs[0]),
3470 dst: node_offset(arena, node.id),
3471 rows: (total / cols) as u32,
3472 cols: cols as u32,
3473 exclusive: *exclusive,
3474 }
3475 }
3476
// Scaled-dot-product attention. Layout is inferred from q's rank:
// - rank 4 with dim 1 == num_heads: treated as [batch, heads, seq, dim]
//   (`bhsd = true`); seq is q dim 2 and kv_seq is k dim 2.
// - other rank 4, or rank 3: seq is q dim 1 and kv_seq is k dim 1.
// - lower rank: batch 1, seq / kv_seq from dim 0.
// NOTE(review): the rank-4 check is ambiguous when seq == num_heads in a
// [batch, seq, heads, dim] tensor — confirm callers never hit that case.
// NOTE(review): `score_scale` and `attn_logit_softcap` are discarded here, so
// the CPU thunk cannot honour non-default values — confirm this is intended.
Op::Attention {
    num_heads,
    head_dim,
    mask_kind,
    score_scale: _,
    attn_logit_softcap: _,
} => {
    let q_shape = &graph.node(node.inputs[0]).shape;
    let k_shape = &graph.node(node.inputs[1]).shape;
    let rank = q_shape.rank();
    let (batch, seq, kv_seq, bhsd) = if rank == 4 {
        let d1 = q_shape.dim(1).unwrap_static();
        let d2 = q_shape.dim(2).unwrap_static();
        if d1 == *num_heads {
            (
                q_shape.dim(0).unwrap_static(),
                d2,
                k_shape.dim(2).unwrap_static(),
                true,
            )
        } else {
            (
                q_shape.dim(0).unwrap_static(),
                d1,
                k_shape.dim(1).unwrap_static(),
                false,
            )
        }
    } else if rank >= 3 {
        (
            q_shape.dim(0).unwrap_static(),
            q_shape.dim(1).unwrap_static(),
            k_shape.dim(1).unwrap_static(),
            false,
        )
    } else {
        (
            1,
            q_shape.dim(0).unwrap_static(),
            k_shape.dim(0).unwrap_static(),
            false,
        )
    };
    // Only Custom / Bias masks carry a tensor (input 3); otherwise offset 0 is
    // passed and presumably ignored by the kernel.
    let mask_off = if matches!(
        mask_kind,
        rlx_ir::op::MaskKind::Custom | rlx_ir::op::MaskKind::Bias
    ) {
        node_offset(arena, node.inputs[3])
    } else {
        0
    };
    // Row strides are num_heads * head_dim for every layout. Presumably the
    // kernel uses `bhsd` to pick the actual indexing — TODO confirm.
    let hs = (*num_heads * *head_dim) as u32;
    Thunk::Attention {
        q: node_offset(arena, node.inputs[0]),
        k: node_offset(arena, node.inputs[1]),
        v: node_offset(arena, node.inputs[2]),
        mask: mask_off,
        out: node_offset(arena, node.id),
        batch: batch as u32,
        seq: seq as u32,
        kv_seq: kv_seq as u32,
        heads: *num_heads as u32,
        head_dim: *head_dim as u32,
        mask_kind: *mask_kind,
        q_row_stride: hs,
        k_row_stride: hs,
        v_row_stride: hs,
        bhsd,
    }
}
3557
// Gradient of attention with respect to one operand (`wrt`). Layout detection
// matches the forward lowering: rank-4 q with dim 1 == num_heads is treated as
// [batch, heads, seq, dim]. Inputs are q, k, v, dy, plus an optional mask at
// index 4.
// NOTE(review): same ambiguity as forward when seq == num_heads.
Op::AttentionBackward {
    num_heads,
    head_dim,
    mask_kind,
    wrt,
} => {
    let q_shape = &graph.node(node.inputs[0]).shape;
    let k_shape = &graph.node(node.inputs[1]).shape;
    let rank = q_shape.rank();
    let (batch, seq, kv_seq, bhsd) = if rank == 4 {
        let d1 = q_shape.dim(1).unwrap_static();
        let d2 = q_shape.dim(2).unwrap_static();
        if d1 == *num_heads {
            (
                q_shape.dim(0).unwrap_static(),
                d2,
                k_shape.dim(2).unwrap_static(),
                true,
            )
        } else {
            (
                q_shape.dim(0).unwrap_static(),
                d1,
                k_shape.dim(1).unwrap_static(),
                false,
            )
        }
    } else if rank >= 3 {
        (
            q_shape.dim(0).unwrap_static(),
            q_shape.dim(1).unwrap_static(),
            k_shape.dim(1).unwrap_static(),
            false,
        )
    } else {
        (
            1,
            q_shape.dim(0).unwrap_static(),
            k_shape.dim(0).unwrap_static(),
            false,
        )
    };
    // The mask tensor sits after dy (input 4), and only for Custom / Bias masks.
    let mask_off = if matches!(
        mask_kind,
        rlx_ir::op::MaskKind::Custom | rlx_ir::op::MaskKind::Bias
    ) {
        node_offset(arena, node.inputs[4])
    } else {
        0
    };
    Thunk::AttentionBackward {
        q: node_offset(arena, node.inputs[0]),
        k: node_offset(arena, node.inputs[1]),
        v: node_offset(arena, node.inputs[2]),
        dy: node_offset(arena, node.inputs[3]),
        mask: mask_off,
        out: node_offset(arena, node.id),
        batch: batch as u32,
        seq: seq as u32,
        kv_seq: kv_seq as u32,
        heads: *num_heads as u32,
        head_dim: *head_dim as u32,
        mask_kind: *mask_kind,
        wrt: *wrt,
        bhsd,
    }
}
3625
// Fused attention block. Fixed inputs are 0 = hidden, 1 = qkv weight,
// 2 = out weight, 3 = mask. Optional inputs follow in order: the (qkv bias,
// out bias) pair when `has_bias`, then the (cos, sin) pair when `has_rope`.
// Absent optional inputs are passed as offset 0.
Op::FusedAttentionBlock {
    num_heads,
    head_dim,
    has_bias,
    has_rope,
} => {
    let x_shape = &graph.node(node.inputs[0]).shape;
    // rank >= 3: batch = dim 0, seq = dim 1. Lower rank: seq is the
    // second-to-last dim and batch is recovered from the element count.
    // NOTE(review): rank < 2 would underflow `rank() - 2`.
    let (batch, seq) = if x_shape.rank() >= 3 {
        (
            x_shape.dim(0).unwrap_static(),
            x_shape.dim(1).unwrap_static(),
        )
    } else {
        let total = x_shape.num_elements().unwrap();
        let s = x_shape.dim(x_shape.rank() - 2).unwrap_static();
        (total / (s * num_heads * head_dim), s)
    };
    let hs = (*num_heads * *head_dim) as u32;
    // Cursor over the optional inputs that start after the 4 fixed ones.
    let mut idx = 4;
    let (qkv_b_off, out_b_off) = if *has_bias {
        let qb = node_offset(arena, node.inputs[idx]);
        let ob = node_offset(arena, node.inputs[idx + 1]);
        idx += 2;
        (qb, ob)
    } else {
        (0, 0)
    };
    let (cos_off, sin_off, cl) = if *has_rope {
        let c = node_offset(arena, node.inputs[idx]);
        let s = node_offset(arena, node.inputs[idx + 1]);
        let clen = get_len(graph, node.inputs[idx]);
        (c, s, clen as u32)
    } else {
        (0, 0, 0)
    };

    Thunk::FusedAttnBlock {
        hidden: node_offset(arena, node.inputs[0]),
        qkv_w: node_offset(arena, node.inputs[1]),
        out_w: node_offset(arena, node.inputs[2]),
        mask: node_offset(arena, node.inputs[3]),
        out: node_offset(arena, node.id),
        qkv_b: qkv_b_off,
        out_b: out_b_off,
        cos: cos_off,
        sin: sin_off,
        cos_len: cl,
        batch: batch as u32,
        seq: seq as u32,
        hs,
        nh: *num_heads as u32,
        dh: *head_dim as u32,
        has_bias: *has_bias,
        has_rope: *has_rope,
    }
}
3683
// Rotary position embedding. rank >= 3 input: (batch, seq, hidden) are dims
// 0..3, and any further dims are not consulted. Lower rank: batch 1, seq = dim 0,
// hidden = element count / dim 0. Rows are read contiguously
// (src_row_stride = hidden).
Op::Rope { head_dim, n_rot } => {
    let x_shape = &graph.node(node.inputs[0]).shape;
    let (batch, seq, hidden) = if x_shape.rank() >= 3 {
        (
            x_shape.dim(0).unwrap_static(),
            x_shape.dim(1).unwrap_static(),
            x_shape.dim(2).unwrap_static(),
        )
    } else {
        let total = x_shape.num_elements().unwrap();
        (
            1,
            x_shape.dim(0).unwrap_static(),
            total / x_shape.dim(0).unwrap_static(),
        )
    };
    let cos_len = get_len(graph, node.inputs[1]);
    Thunk::Rope {
        src: node_offset(arena, node.inputs[0]),
        cos: node_offset(arena, node.inputs[1]),
        sin: node_offset(arena, node.inputs[2]),
        dst: node_offset(arena, node.id),
        batch: batch as u32,
        seq: seq as u32,
        hidden: hidden as u32,
        head_dim: *head_dim as u32,
        n_rot: *n_rot as u32,
        cos_len: cos_len as u32,
        src_row_stride: hidden as u32,
    }
}
3718
3719 Op::FusedSwiGLU {
3720 cast_to: _,
3721 gate_first,
3722 } => {
3723 let n_half = node.shape.dim(node.shape.rank() - 1).unwrap_static();
3724 let total = node.shape.num_elements().unwrap();
3725 Thunk::FusedSwiGLU {
3726 src: node_offset(arena, node.inputs[0]),
3727 dst: node_offset(arena, node.id),
3728 n_half: n_half as u32,
3729 total: total as u32,
3730 gate_first: *gate_first,
3731 }
3732 }
3733
// Convolution lowering with three outcomes:
// 1. A 1x1 kernel with unit stride, zero padding, unit dilation and a single
//    group becomes `Conv2D1x1`, with spatial dims collapsed into `hw`.
// 2. Any other 2-D kernel becomes the general `Conv2D`.
// 3. Everything else (non-2-D kernels, too-low ranks) becomes `Nop`.
// Dims come from `conv_nchw_dims`, which is applied to rank >= 3 shapes.
Op::Conv {
    kernel_size,
    stride,
    padding,
    dilation,
    groups,
} => {
    let in_shape = &graph.node(node.inputs[0]).shape;
    let w_shape = &graph.node(node.inputs[1]).shape;
    let out_shape = &node.shape;
    let is_1x1_simple = kernel_size.len() == 2
        && kernel_size[0] == 1
        && kernel_size[1] == 1
        && stride.iter().all(|&s| s == 1)
        && padding.iter().all(|&p| p == 0)
        && dilation.iter().all(|&d| d == 1)
        && *groups == 1;
    if is_1x1_simple
        && in_shape.rank() >= 3
        && out_shape.rank() >= 3
        && w_shape.rank() >= 2
    {
        let (n, c_in, h, w) = conv_nchw_dims(in_shape);
        let (_, c_out, _, _) = conv_nchw_dims(out_shape);
        Thunk::Conv2D1x1 {
            src: node_offset(arena, node.inputs[0]),
            weight: node_offset(arena, node.inputs[1]),
            dst: node_offset(arena, node.id),
            n,
            c_in,
            c_out,
            hw: h.saturating_mul(w),
        }
    } else if kernel_size.len() == 2
        && in_shape.rank() >= 3
        && w_shape.rank() >= 2
        && out_shape.rank() >= 3
    {
        let (n, c_in, h, w_in) = conv_nchw_dims(in_shape);
        let (_, c_out, h_out, w_out) = conv_nchw_dims(out_shape);
        Thunk::Conv2D {
            src: node_offset(arena, node.inputs[0]),
            weight: node_offset(arena, node.inputs[1]),
            dst: node_offset(arena, node.id),
            n,
            c_in,
            h,
            w: w_in,
            c_out,
            h_out,
            w_out,
            kh: kernel_size[0] as u32,
            kw: kernel_size[1] as u32,
            // Missing stride / padding / dilation entries default to 1 / 0 / 1.
            sh: stride.first().copied().unwrap_or(1) as u32,
            sw: stride.get(1).copied().unwrap_or(1) as u32,
            ph: padding.first().copied().unwrap_or(0) as u32,
            pw: padding.get(1).copied().unwrap_or(0) as u32,
            dh: dilation.first().copied().unwrap_or(1) as u32,
            dw: dilation.get(1).copied().unwrap_or(1) as u32,
            groups: *groups as u32,
        }
    } else {
        Thunk::Nop
    }
}
3802
// 2-D pooling. Lowered only for a 2-entry kernel with rank-4 input and output.
// Input dims are read as (n, c, h, w) and output spatial dims as (h_out, w_out).
// Other configurations become `Nop`.
Op::Pool {
    kind,
    kernel_size,
    stride,
    padding,
} => {
    let in_shape = &graph.node(node.inputs[0]).shape;
    let out_shape = &node.shape;
    if kernel_size.len() == 2 && in_shape.rank() == 4 && out_shape.rank() == 4 {
        Thunk::Pool2D {
            src: node_offset(arena, node.inputs[0]),
            dst: node_offset(arena, node.id),
            n: in_shape.dim(0).unwrap_static() as u32,
            c: in_shape.dim(1).unwrap_static() as u32,
            h: in_shape.dim(2).unwrap_static() as u32,
            w: in_shape.dim(3).unwrap_static() as u32,
            h_out: out_shape.dim(2).unwrap_static() as u32,
            w_out: out_shape.dim(3).unwrap_static() as u32,
            kh: kernel_size[0] as u32,
            kw: kernel_size[1] as u32,
            // Missing stride / padding entries default to 1 / 0.
            sh: stride.first().copied().unwrap_or(1) as u32,
            sw: stride.get(1).copied().unwrap_or(1) as u32,
            ph: padding.first().copied().unwrap_or(0) as u32,
            pw: padding.get(1).copied().unwrap_or(0) as u32,
            kind: *kind,
        }
    } else {
        Thunk::Nop
    }
}
3834
3835 Op::Transpose { perm } => {
3836 let in_shape = &graph.node(node.inputs[0]).shape;
3839 let in_rank = in_shape.rank();
3840 if perm.iter().any(|&p| p >= in_rank) {
3841 Thunk::Nop
3842 } else {
3843 let in_dims: Vec<usize> = (0..in_rank)
3844 .map(|i| in_shape.dim(i).unwrap_static())
3845 .collect();
3846 let mut in_strides_full = vec![1usize; in_rank];
3848 for d in (0..in_rank.saturating_sub(1)).rev() {
3849 in_strides_full[d] = in_strides_full[d + 1] * in_dims[d + 1];
3850 }
3851 let out_dims: Vec<u32> = perm.iter().map(|&p| in_dims[p] as u32).collect();
3852 let in_strides: Vec<u32> =
3853 perm.iter().map(|&p| in_strides_full[p] as u32).collect();
3854 let in_total = in_dims.iter().product::<usize>() as u32;
3855 let src = node_offset(arena, node.inputs[0]);
3856 let dst = node_offset(arena, node.id);
3857 let elem_bytes = node.shape.dtype().size_bytes() as u8;
3858 match node.shape.dtype() {
3859 rlx_ir::DType::F64 => Thunk::TransposeF64 {
3860 src,
3861 dst,
3862 in_total,
3863 out_dims,
3864 in_strides,
3865 },
3866 _ => Thunk::Transpose {
3867 src,
3868 dst,
3869 in_total,
3870 out_dims,
3871 in_strides,
3872 elem_bytes,
3873 },
3874 }
3875 }
3876 }
3877
3878 Op::ScatterAdd => {
3879 let upd_shape = &graph.node(node.inputs[0]).shape;
3882 let out_shape = &node.shape;
3883 let num_updates = upd_shape.dim(0).unwrap_static();
3884 let out_dim = out_shape.dim(0).unwrap_static();
3885 let trailing: usize = (1..out_shape.rank())
3886 .map(|i| out_shape.dim(i).unwrap_static())
3887 .product::<usize>()
3888 .max(1);
3889 Thunk::ScatterAdd {
3890 updates: node_offset(arena, node.inputs[0]),
3891 indices: node_offset(arena, node.inputs[1]),
3892 dst: node_offset(arena, node.id),
3893 num_updates: num_updates as u32,
3894 out_dim: out_dim as u32,
3895 trailing: trailing as u32,
3896 }
3897 }
3898
3899 Op::GroupedMatMul => {
3900 let in_shape = &graph.node(node.inputs[0]).shape;
3902 let w_shape = &graph.node(node.inputs[1]).shape;
3903 let m = in_shape.dim(in_shape.rank() - 2).unwrap_static();
3904 let k_dim = in_shape.dim(in_shape.rank() - 1).unwrap_static();
3905 let num_experts = w_shape.dim(0).unwrap_static();
3906 let n = w_shape.dim(2).unwrap_static();
3907 Thunk::GroupedMatMul {
3908 input: node_offset(arena, node.inputs[0]),
3909 weight: node_offset(arena, node.inputs[1]),
3910 expert_idx: node_offset(arena, node.inputs[2]),
3911 dst: node_offset(arena, node.id),
3912 m: m as u32,
3913 k_dim: k_dim as u32,
3914 n: n as u32,
3915 num_experts: num_experts as u32,
3916 }
3917 }
3918
3919 Op::DequantGroupedMatMul { scheme } => {
3920 let in_shape = &graph.node(node.inputs[0]).shape;
3921 let w_shape = &graph.node(node.inputs[1]).shape;
3922 let m = in_shape.dim(in_shape.rank() - 2).unwrap_static();
3923 let k_dim = in_shape.dim(in_shape.rank() - 1).unwrap_static();
3924 let out_shape = &node.shape;
3925 let n = out_shape.dim(out_shape.rank() - 1).unwrap_static();
3926 let block_elems = scheme.gguf_block_size() as usize;
3927 let block_bytes = scheme.gguf_block_bytes() as usize;
3928 let slab_bytes = (k_dim * n) / block_elems * block_bytes;
3929 let total_bytes = w_shape.num_elements().unwrap();
3930 let num_experts = total_bytes / slab_bytes.max(1);
3931 Thunk::DequantGroupedMatMulGguf {
3932 input: node_offset(arena, node.inputs[0]),
3933 w_q: node_offset(arena, node.inputs[1]),
3934 expert_idx: node_offset(arena, node.inputs[2]),
3935 dst: node_offset(arena, node.id),
3936 m: m as u32,
3937 k_dim: k_dim as u32,
3938 n: n as u32,
3939 num_experts: num_experts as u32,
3940 scheme: *scheme,
3941 }
3942 }
3943
3944 Op::DequantMoEWeights { scheme } => {
3945 let w_shape = &graph.node(node.inputs[0]).shape;
3946 let out_shape = &node.shape;
3947 let num_experts = out_shape.dim(0).unwrap_static();
3948 let k_dim = out_shape.dim(1).unwrap_static();
3949 let n = out_shape.dim(2).unwrap_static();
3950 let block_elems = scheme.gguf_block_size() as usize;
3951 let block_bytes = scheme.gguf_block_bytes() as usize;
3952 let slab_bytes = (k_dim * n) / block_elems * block_bytes;
3953 let total_bytes = w_shape.num_elements().unwrap();
3954 assert_eq!(
3955 total_bytes,
3956 num_experts * slab_bytes,
3957 "DequantMoEWeights packed bytes mismatch"
3958 );
3959 Thunk::DequantMoEWeightsGguf {
3960 w_q: node_offset(arena, node.inputs[0]),
3961 dst: node_offset(arena, node.id),
3962 k_dim: k_dim as u32,
3963 n: n as u32,
3964 num_experts: num_experts as u32,
3965 scheme: *scheme,
3966 }
3967 }
3968
3969 Op::TopK { k } => {
3970 let in_shape = &graph.node(node.inputs[0]).shape;
3971 let rank = in_shape.rank();
3972 let axis_dim = in_shape.dim(rank - 1).unwrap_static();
3973 let outer = in_shape.num_elements().unwrap() / axis_dim;
3974 let indices_i64 = u8::from(graph.node(node.id).shape.dtype() == rlx_ir::DType::I64);
3975 Thunk::TopK {
3976 src: node_offset(arena, node.inputs[0]),
3977 dst: node_offset(arena, node.id),
3978 outer: outer as u32,
3979 axis_dim: axis_dim as u32,
3980 k: *k as u32,
3981 indices_i64,
3982 }
3983 }
3984
3985 Op::Reduce {
3986 op,
3987 axes,
3988 keep_dim: _,
3989 } => {
3990 let in_shape = &graph.node(node.inputs[0]).shape;
3996 let rank = in_shape.rank();
3997 let mut sorted = axes.clone();
3998 sorted.sort();
3999 sorted.dedup();
4000 let contiguous = sorted.windows(2).all(|w| w[1] == w[0] + 1)
4001 && !sorted.is_empty()
4002 && *sorted.last().unwrap() < rank;
4003 if !contiguous {
4004 Thunk::Nop
4005 } else {
4006 let first = sorted[0];
4007 let last = *sorted.last().unwrap();
4008 let outer: usize = (0..first)
4009 .map(|i| in_shape.dim(i).unwrap_static())
4010 .product::<usize>()
4011 .max(1);
4012 let reduced: usize = (first..=last)
4013 .map(|i| in_shape.dim(i).unwrap_static())
4014 .product();
4015 let inner: usize = (last + 1..rank)
4016 .map(|i| in_shape.dim(i).unwrap_static())
4017 .product::<usize>()
4018 .max(1);
4019 let src = node_offset(arena, node.inputs[0]);
4020 let dst = node_offset(arena, node.id);
4021 if node.shape.dtype() == rlx_ir::DType::F64 && matches!(op, ReduceOp::Sum) {
4022 Thunk::ReduceSumF64 {
4023 src,
4024 dst,
4025 outer: outer as u32,
4026 reduced: reduced as u32,
4027 inner: inner as u32,
4028 }
4029 } else {
4030 Thunk::Reduce {
4031 src,
4032 dst,
4033 outer: outer as u32,
4034 reduced: reduced as u32,
4035 inner: inner as u32,
4036 op: *op,
4037 }
4038 }
4039 }
4040 }
4041
4042 Op::Compare(cmp) => {
4043 let len = node.shape.num_elements().unwrap();
4044 let in_dtype = graph.node(node.inputs[0]).shape.dtype();
4045 let inputs_i64 = u8::from(in_dtype == rlx_ir::DType::I64);
4046 Thunk::Compare {
4047 lhs: node_offset(arena, node.inputs[0]),
4048 rhs: node_offset(arena, node.inputs[1]),
4049 dst: node_offset(arena, node.id),
4050 len: len as u32,
4051 op: *cmp,
4052 inputs_i64,
4053 inputs_elem_bytes: in_dtype.size_bytes() as u8,
4054 dst_elem_bytes: node.shape.dtype().size_bytes() as u8,
4055 }
4056 }
4057
4058 Op::Where => {
4059 let len = node.shape.num_elements().unwrap();
4060 let elem_bytes = node.shape.dtype().size_bytes() as u8;
4061 let cond_elem_bytes = graph.node(node.inputs[0]).shape.dtype().size_bytes() as u8;
4062 Thunk::Where {
4063 cond: node_offset(arena, node.inputs[0]),
4064 on_true: node_offset(arena, node.inputs[1]),
4065 on_false: node_offset(arena, node.inputs[2]),
4066 dst: node_offset(arena, node.id),
4067 len: len as u32,
4068 elem_bytes,
4069 cond_elem_bytes,
4070 }
4071 }
4072
4073 Op::ReluBackward => {
4074 let len: usize = (0..node.shape.rank())
4075 .map(|i| node.shape.dim(i).unwrap_static())
4076 .product();
4077 let x = node_offset(arena, node.inputs[0]);
4078 let dy = node_offset(arena, node.inputs[1]);
4079 let dx = node_offset(arena, node.id);
4080 match node.shape.dtype() {
4081 rlx_ir::DType::F64 => Thunk::ReluBackwardF64 {
4082 x,
4083 dy,
4084 dx,
4085 len: len as u32,
4086 },
4087 _ => Thunk::ReluBackward {
4088 x,
4089 dy,
4090 dx,
4091 len: len as u32,
4092 },
4093 }
4094 }
4095
4096 Op::ComplexNormSq => {
4097 let len: usize = (0..node.shape.rank())
4098 .map(|i| node.shape.dim(i).unwrap_static())
4099 .product();
4100 let src = node_offset(arena, node.inputs[0]);
4101 let dst = node_offset(arena, node.id);
4102 Thunk::ComplexNormSqF32 {
4103 src,
4104 dst,
4105 len: len as u32,
4106 }
4107 }
4108
4109 Op::ComplexNormSqBackward => {
4110 let len: usize = (0..node.shape.rank())
4111 .map(|i| node.shape.dim(i).unwrap_static())
4112 .product();
4113 let z = node_offset(arena, node.inputs[0]);
4114 let g = node_offset(arena, node.inputs[1]);
4115 let dz = node_offset(arena, node.id);
4116 Thunk::ComplexNormSqBackwardF32 {
4117 z,
4118 g,
4119 dz,
4120 len: len as u32,
4121 }
4122 }
4123
4124 Op::Conjugate => {
4125 let len: usize = (0..node.shape.rank())
4126 .map(|i| node.shape.dim(i).unwrap_static())
4127 .product();
4128 Thunk::ConjugateC64 {
4129 src: node_offset(arena, node.inputs[0]),
4130 dst: node_offset(arena, node.id),
4131 len: len as u32,
4132 }
4133 }
4134
4135 Op::ActivationBackward { kind } => {
4136 let len: usize = (0..node.shape.rank())
4137 .map(|i| node.shape.dim(i).unwrap_static())
4138 .product();
4139 let x = node_offset(arena, node.inputs[0]);
4140 let dy = node_offset(arena, node.inputs[1]);
4141 let dx = node_offset(arena, node.id);
4142 match node.shape.dtype() {
4143 rlx_ir::DType::F64 => Thunk::ActivationBackwardF64 {
4144 x,
4145 dy,
4146 dx,
4147 len: len as u32,
4148 kind: *kind,
4149 },
4150 _ => Thunk::ActivationBackward {
4151 x,
4152 dy,
4153 dx,
4154 len: len as u32,
4155 kind: *kind,
4156 },
4157 }
4158 }
4159
4160 Op::LayerNormBackwardInput { eps, .. } => {
4161 let h = node.shape.dim(node.shape.rank() - 1).unwrap_static();
4163 let total = node.shape.num_elements().unwrap();
4164 Thunk::LayerNormBackwardInput {
4165 x: node_offset(arena, node.inputs[0]),
4166 gamma: node_offset(arena, node.inputs[1]),
4167 dy: node_offset(arena, node.inputs[2]),
4168 dx: node_offset(arena, node.id),
4169 rows: (total / h) as u32,
4170 h: h as u32,
4171 eps: *eps,
4172 }
4173 }
4174
4175 Op::LayerNormBackwardGamma { eps, .. } => {
4176 let x_shape = &graph.node(node.inputs[0]).shape;
4177 let h = x_shape.dim(x_shape.rank() - 1).unwrap_static();
4178 let x_total = x_shape.num_elements().unwrap();
4179 Thunk::LayerNormBackwardGamma {
4180 x: node_offset(arena, node.inputs[0]),
4181 dy: node_offset(arena, node.inputs[1]),
4182 dgamma: node_offset(arena, node.id),
4183 rows: (x_total / h) as u32,
4184 h: h as u32,
4185 eps: *eps,
4186 }
4187 }
4188
4189 Op::RmsNormBackwardInput { eps, .. }
4190 | Op::RmsNormBackwardGamma { eps, .. }
4191 | Op::RmsNormBackwardBeta { eps, .. } => {
4192 let x_shape = &graph.node(node.inputs[0]).shape;
4193 let h = x_shape.dim(x_shape.rank() - 1).unwrap_static();
4194 let rows = (x_shape.num_elements().unwrap() / h) as u32;
4195 let off = |i: usize| node_offset(arena, node.inputs[i]);
4196 let common = (off(0), off(1), off(2), off(3), rows, h as u32, *eps);
4197 match &node.op {
4198 Op::RmsNormBackwardInput { .. } => Thunk::RmsNormBackwardInput {
4199 x: common.0,
4200 gamma: common.1,
4201 beta: common.2,
4202 dy: common.3,
4203 dx: node_offset(arena, node.id),
4204 rows: common.4,
4205 h: common.5,
4206 eps: common.6,
4207 },
4208 Op::RmsNormBackwardGamma { .. } => Thunk::RmsNormBackwardGamma {
4209 x: common.0,
4210 gamma: common.1,
4211 beta: common.2,
4212 dy: common.3,
4213 dgamma: node_offset(arena, node.id),
4214 rows: common.4,
4215 h: common.5,
4216 eps: common.6,
4217 },
4218 Op::RmsNormBackwardBeta { .. } => Thunk::RmsNormBackwardBeta {
4219 x: common.0,
4220 gamma: common.1,
4221 beta: common.2,
4222 dy: common.3,
4223 dbeta: node_offset(arena, node.id),
4224 rows: common.4,
4225 h: common.5,
4226 eps: common.6,
4227 },
4228 _ => unreachable!(),
4229 }
4230 }
4231
// Backward of rotary position embedding. dy is input 0 and cos / sin are
// inputs 1 / 2. rank >= 3: (batch, seq, hidden) = dy dims 0..3. Lower rank:
// batch 1, seq = dim 0, hidden = dim 1.
// NOTE(review): this differs from the forward Rope lowering in two ways.
// Forward derives hidden as element count / dim 0, which also works for rank 1.
// Forward takes cos_len from `get_len`, whereas this arm uses
// `num_elements()`. Confirm both agree for every shape that reaches here.
Op::RopeBackward { head_dim, n_rot } => {
    let dy_shape = &graph.node(node.inputs[0]).shape;
    let (batch, seq, hidden) = if dy_shape.rank() >= 3 {
        (
            dy_shape.dim(0).unwrap_static(),
            dy_shape.dim(1).unwrap_static(),
            dy_shape.dim(2).unwrap_static(),
        )
    } else {
        (
            1,
            dy_shape.dim(0).unwrap_static(),
            dy_shape.dim(1).unwrap_static(),
        )
    };
    let cos_shape = &graph.node(node.inputs[1]).shape;
    let cos_len = cos_shape.num_elements().unwrap();
    Thunk::RopeBackward {
        dy: node_offset(arena, node.inputs[0]),
        cos: node_offset(arena, node.inputs[1]),
        sin: node_offset(arena, node.inputs[2]),
        dx: node_offset(arena, node.id),
        batch: batch as u32,
        seq: seq as u32,
        hidden: hidden as u32,
        head_dim: *head_dim as u32,
        n_rot: *n_rot as u32,
        cos_len: cos_len as u32,
    }
}
4262
4263 Op::CumsumBackward { exclusive, .. } => {
4264 let dy_shape = &graph.node(node.inputs[0]).shape;
4265 let rank = dy_shape.rank();
4266 let cols = dy_shape.dim(rank - 1).unwrap_static();
4267 let rows = dy_shape.num_elements().unwrap() / cols;
4268 Thunk::CumsumBackward {
4269 dy: node_offset(arena, node.inputs[0]),
4270 dx: node_offset(arena, node.id),
4271 rows: rows as u32,
4272 cols: cols as u32,
4273 exclusive: *exclusive,
4274 }
4275 }
4276
4277 Op::GatherBackward { .. } => {
4278 let dy_shape = &graph.node(node.inputs[0]).shape;
4279 let idx_shape = &graph.node(node.inputs[1]).shape;
4280 let out_shape = &node.shape;
4281 let rank = out_shape.rank();
4282 let axis = match &node.op {
4283 Op::GatherBackward { axis } => *axis,
4284 _ => 0,
4285 };
4286 let axis_u = if axis < 0 {
4287 (rank as i32 + axis) as usize
4288 } else {
4289 axis as usize
4290 };
4291 let outer: usize = (0..axis_u)
4292 .map(|i| dy_shape.dim(i).unwrap_static())
4293 .product::<usize>()
4294 .max(1);
4295 let num_idx = idx_shape.dim(axis_u).unwrap_static();
4296 let trailing: usize = (axis_u + 1..dy_shape.rank())
4297 .map(|i| dy_shape.dim(i).unwrap_static())
4298 .product::<usize>()
4299 .max(1);
4300 let axis_dim = out_shape.dim(axis_u).unwrap_static();
4301 Thunk::GatherBackward {
4302 dy: node_offset(arena, node.inputs[0]),
4303 indices: node_offset(arena, node.inputs[1]),
4304 dst: node_offset(arena, node.id),
4305 outer: outer as u32,
4306 axis_dim: axis_dim as u32,
4307 num_idx: num_idx as u32,
4308 trailing: trailing as u32,
4309 }
4310 }
4311
// GroupNorm backward ops. x (input 0) is read as rank 4: (n, c, h, w) = dims 0..3.
// The input layout differs per op:
// - Input: x, gamma, beta, dy (inputs 0..3).
// - Gamma: x, dy (inputs 0, 1).
// - Beta: dy is read from input 1; x is used only for its shape.
Op::GroupNormBackwardInput { num_groups, eps }
| Op::GroupNormBackwardGamma { num_groups, eps }
| Op::GroupNormBackwardBeta { num_groups, eps } => {
    let x_shape = &graph.node(node.inputs[0]).shape;
    let n = x_shape.dim(0).unwrap_static() as u32;
    let c = x_shape.dim(1).unwrap_static() as u32;
    let h = x_shape.dim(2).unwrap_static() as u32;
    let w = x_shape.dim(3).unwrap_static() as u32;
    match &node.op {
        Op::GroupNormBackwardInput { .. } => Thunk::GroupNormBackwardInput {
            x: node_offset(arena, node.inputs[0]),
            gamma: node_offset(arena, node.inputs[1]),
            beta: node_offset(arena, node.inputs[2]),
            dy: node_offset(arena, node.inputs[3]),
            dx: node_offset(arena, node.id),
            n,
            c,
            h,
            w,
            num_groups: *num_groups as u32,
            eps: *eps,
        },
        Op::GroupNormBackwardGamma { .. } => Thunk::GroupNormBackwardGamma {
            x: node_offset(arena, node.inputs[0]),
            dy: node_offset(arena, node.inputs[1]),
            dgamma: node_offset(arena, node.id),
            n,
            c,
            h,
            w,
            num_groups: *num_groups as u32,
            eps: *eps,
        },
        // The beta gradient takes neither num_groups nor eps.
        Op::GroupNormBackwardBeta { .. } => Thunk::GroupNormBackwardBeta {
            dy: node_offset(arena, node.inputs[1]),
            dbeta: node_offset(arena, node.id),
            n,
            c,
            h,
            w,
        },
        _ => unreachable!(),
    }
}
4356
// Backward of 2-D max pooling. Inputs are x (0) and dy (1), both required to be
// rank 4, with a 2-entry kernel; otherwise this lowers to `Nop`. (n, c, h, w)
// come from x and (h_out, w_out) from dy.
Op::MaxPool2dBackward {
    kernel_size,
    stride,
    padding,
} => {
    let x_shape = &graph.node(node.inputs[0]).shape;
    let dy_shape = &graph.node(node.inputs[1]).shape;
    if kernel_size.len() == 2 && x_shape.rank() == 4 && dy_shape.rank() == 4 {
        Thunk::MaxPool2dBackward {
            x: node_offset(arena, node.inputs[0]),
            dy: node_offset(arena, node.inputs[1]),
            dx: node_offset(arena, node.id),
            n: x_shape.dim(0).unwrap_static() as u32,
            c: x_shape.dim(1).unwrap_static() as u32,
            h: x_shape.dim(2).unwrap_static() as u32,
            w: x_shape.dim(3).unwrap_static() as u32,
            h_out: dy_shape.dim(2).unwrap_static() as u32,
            w_out: dy_shape.dim(3).unwrap_static() as u32,
            kh: kernel_size[0] as u32,
            kw: kernel_size[1] as u32,
            // Missing stride / padding entries default to 1 / 0.
            sh: stride.first().copied().unwrap_or(1) as u32,
            sw: stride.get(1).copied().unwrap_or(1) as u32,
            ph: padding.first().copied().unwrap_or(0) as u32,
            pw: padding.get(1).copied().unwrap_or(0) as u32,
        }
    } else {
        Thunk::Nop
    }
}
4386
// Gradient of a 2-D convolution with respect to its input. Inputs are dy (0) and
// the weight (1). The input-side dims (n, c_in, h, w_in) are read from this
// node's output shape (dx); the output-side dims (c_out, h_out, w_out) are read
// from dy. Requires a 2-entry kernel and rank-4 tensors; otherwise lowers to `Nop`.
Op::Conv2dBackwardInput {
    kernel_size,
    stride,
    padding,
    dilation,
    groups,
} => {
    let dy_shape = &graph.node(node.inputs[0]).shape;
    let w_shape = &graph.node(node.inputs[1]).shape;
    let out_shape = &node.shape;
    if kernel_size.len() == 2
        && dy_shape.rank() == 4
        && w_shape.rank() == 4
        && out_shape.rank() == 4
    {
        Thunk::Conv2dBackwardInput {
            dy: node_offset(arena, node.inputs[0]),
            w: node_offset(arena, node.inputs[1]),
            dx: node_offset(arena, node.id),
            n: out_shape.dim(0).unwrap_static() as u32,
            c_in: out_shape.dim(1).unwrap_static() as u32,
            h: out_shape.dim(2).unwrap_static() as u32,
            w_in: out_shape.dim(3).unwrap_static() as u32,
            c_out: dy_shape.dim(1).unwrap_static() as u32,
            h_out: dy_shape.dim(2).unwrap_static() as u32,
            w_out: dy_shape.dim(3).unwrap_static() as u32,
            kh: kernel_size[0] as u32,
            kw: kernel_size[1] as u32,
            // Missing stride / padding / dilation entries default to 1 / 0 / 1.
            sh: stride.first().copied().unwrap_or(1) as u32,
            sw: stride.get(1).copied().unwrap_or(1) as u32,
            ph: padding.first().copied().unwrap_or(0) as u32,
            pw: padding.get(1).copied().unwrap_or(0) as u32,
            dh: dilation.first().copied().unwrap_or(1) as u32,
            dw: dilation.get(1).copied().unwrap_or(1) as u32,
            groups: *groups as u32,
        }
    } else {
        Thunk::Nop
    }
}
4427
4428 Op::Conv2dBackwardWeight {
4429 kernel_size,
4430 stride,
4431 padding,
4432 dilation,
4433 groups,
4434 } => {
4435 let x_shape = &graph.node(node.inputs[0]).shape;
4436 let dy_shape = &graph.node(node.inputs[1]).shape;
4437 let dw_shape = &node.shape;
4438 if kernel_size.len() == 2
4439 && x_shape.rank() == 4
4440 && dy_shape.rank() == 4
4441 && dw_shape.rank() == 4
4442 {
4443 Thunk::Conv2dBackwardWeight {
4444 x: node_offset(arena, node.inputs[0]),
4445 dy: node_offset(arena, node.inputs[1]),
4446 dw: node_offset(arena, node.id),
4447 n: x_shape.dim(0).unwrap_static() as u32,
4448 c_in: x_shape.dim(1).unwrap_static() as u32,
4449 h: x_shape.dim(2).unwrap_static() as u32,
4450 w: x_shape.dim(3).unwrap_static() as u32,
4451 c_out: dy_shape.dim(1).unwrap_static() as u32,
4452 h_out: dy_shape.dim(2).unwrap_static() as u32,
4453 w_out: dy_shape.dim(3).unwrap_static() as u32,
4454 kh: kernel_size[0] as u32,
4455 kw: kernel_size[1] as u32,
4456 sh: stride.first().copied().unwrap_or(1) as u32,
4457 sw: stride.get(1).copied().unwrap_or(1) as u32,
4458 ph: padding.first().copied().unwrap_or(0) as u32,
4459 pw: padding.get(1).copied().unwrap_or(0) as u32,
4460 dh: dilation.first().copied().unwrap_or(1) as u32,
4461 dw_dil: dilation.get(1).copied().unwrap_or(1) as u32,
4462 groups: *groups as u32,
4463 }
4464 } else {
4465 Thunk::Nop
4466 }
4467 }
4468
4469 Op::Im2Col {
4470 kernel_size,
4471 stride,
4472 padding,
4473 dilation,
4474 } => {
4475 let x_shape = &graph.node(node.inputs[0]).shape;
4476 let out_shape = &node.shape;
4477 if kernel_size.len() == 2 && x_shape.rank() == 4 && out_shape.rank() == 2 {
4478 let n = match x_shape.dim(0) {
4479 rlx_ir::shape::Dim::Static(v) => v as u32,
4480 _ => 0,
4481 };
4482 let c_in = x_shape.dim(1).unwrap_static() as u32;
4483 let h = x_shape.dim(2).unwrap_static() as u32;
4484 let w = x_shape.dim(3).unwrap_static() as u32;
4485 let kh = kernel_size[0] as u32;
4486 let kw = kernel_size[1] as u32;
4487 let sh = stride.first().copied().unwrap_or(1) as u32;
4488 let sw = stride.get(1).copied().unwrap_or(1) as u32;
4489 let ph = padding.first().copied().unwrap_or(0) as u32;
4490 let pw = padding.get(1).copied().unwrap_or(0) as u32;
4491 let dh = dilation.first().copied().unwrap_or(1) as u32;
4492 let dw_dil = dilation.get(1).copied().unwrap_or(1) as u32;
4493 let h_out = rlx_ir::shape::conv2d_spatial_output(
4494 h as usize,
4495 kh as usize,
4496 sh as usize,
4497 ph as usize,
4498 dh as usize,
4499 ) as u32;
4500 let w_out = rlx_ir::shape::conv2d_spatial_output(
4501 w as usize,
4502 kw as usize,
4503 sw as usize,
4504 pw as usize,
4505 dw_dil as usize,
4506 ) as u32;
4507 Thunk::Im2Col {
4508 x: node_offset(arena, node.inputs[0]),
4509 col: node_offset(arena, node.id),
4510 n,
4511 c_in,
4512 h,
4513 w,
4514 h_out,
4515 w_out,
4516 kh,
4517 kw,
4518 sh,
4519 sw,
4520 ph,
4521 pw,
4522 dh,
4523 dw_dil,
4524 }
4525 } else {
4526 Thunk::Nop
4527 }
4528 }
4529
4530 Op::SoftmaxCrossEntropyWithLogits => {
4531 let logits_shape = &graph.node(node.inputs[0]).shape;
4532 if logits_shape.rank() == 2 {
4533 Thunk::SoftmaxCrossEntropy {
4534 logits: node_offset(arena, node.inputs[0]),
4535 labels: node_offset(arena, node.inputs[1]),
4536 dst: node_offset(arena, node.id),
4537 n: logits_shape.dim(0).unwrap_static() as u32,
4538 c: logits_shape.dim(1).unwrap_static() as u32,
4539 }
4540 } else {
4541 Thunk::Nop
4542 }
4543 }
4544
4545 Op::SoftmaxCrossEntropyBackward => {
4546 let logits_shape = &graph.node(node.inputs[0]).shape;
4547 if logits_shape.rank() == 2 {
4548 Thunk::SoftmaxCrossEntropyBackward {
4549 logits: node_offset(arena, node.inputs[0]),
4550 labels: node_offset(arena, node.inputs[1]),
4551 d_loss: node_offset(arena, node.inputs[2]),
4552 dlogits: node_offset(arena, node.id),
4553 n: logits_shape.dim(0).unwrap_static() as u32,
4554 c: logits_shape.dim(1).unwrap_static() as u32,
4555 }
4556 } else {
4557 Thunk::Nop
4558 }
4559 }
4560
4561 Op::DenseSolve => {
4562 let a_shape = &graph.node(node.inputs[0]).shape;
4564 let n = a_shape.dim(0).unwrap_static();
4565 debug_assert_eq!(
4566 n,
4567 a_shape.dim(1).unwrap_static(),
4568 "DenseSolve: A must be square"
4569 );
4570 let b_elems = node.shape.num_elements().unwrap();
4571 let nrhs = b_elems / n;
4572 match node.shape.dtype() {
4573 rlx_ir::DType::F64 => Thunk::DenseSolveF64 {
4574 a: node_offset(arena, node.inputs[0]),
4575 b: node_offset(arena, node.inputs[1]),
4576 x: node_offset(arena, node.id),
4577 n: n as u32,
4578 nrhs: nrhs as u32,
4579 },
4580 rlx_ir::DType::F32 => Thunk::DenseSolveF32 {
4581 a: node_offset(arena, node.inputs[0]),
4582 b: node_offset(arena, node.inputs[1]),
4583 x: node_offset(arena, node.id),
4584 n: n as u32,
4585 nrhs: nrhs as u32,
4586 },
4587 other => panic!(
4588 "DenseSolve: F32 + F64 lowered; got {other:?}. \
4589 Add another variant when needed."
4590 ),
4591 }
4592 }
4593
4594 Op::BatchedDenseSolve => {
4595 let a_shape = &graph.node(node.inputs[0]).shape;
4597 assert_eq!(a_shape.rank(), 3, "BatchedDenseSolve: A rank must be 3");
4598 let batch = a_shape.dim(0).unwrap_static();
4599 let n = a_shape.dim(1).unwrap_static();
4600 debug_assert_eq!(
4601 n,
4602 a_shape.dim(2).unwrap_static(),
4603 "BatchedDenseSolve: A's last two dims must match"
4604 );
4605 let total = node.shape.num_elements().unwrap();
4606 let nrhs = total / (batch * n);
4607 match node.shape.dtype() {
4608 rlx_ir::DType::F32 => Thunk::BatchedDenseSolveF32 {
4609 a: node_offset(arena, node.inputs[0]),
4610 b: node_offset(arena, node.inputs[1]),
4611 x: node_offset(arena, node.id),
4612 batch: batch as u32,
4613 n: n as u32,
4614 nrhs: nrhs as u32,
4615 },
4616 rlx_ir::DType::F64 => Thunk::BatchedDenseSolveF64 {
4617 a: node_offset(arena, node.inputs[0]),
4618 b: node_offset(arena, node.inputs[1]),
4619 x: node_offset(arena, node.id),
4620 batch: batch as u32,
4621 n: n as u32,
4622 nrhs: nrhs as u32,
4623 },
4624 other => panic!("BatchedDenseSolve: F32 + F64 only, got {other:?}"),
4625 }
4626 }
4627
4628 Op::Scan {
4629 body,
4630 length,
4631 save_trajectory,
4632 num_bcast,
4633 num_xs,
4634 num_checkpoints,
4635 } => {
4636 assert!(
4637 *num_checkpoints == 0 || *num_checkpoints <= *length,
4638 "Op::Scan: num_checkpoints={} must be 0 or ≤ length={}",
4639 *num_checkpoints,
4640 *length
4641 );
4642 if *num_checkpoints != 0 && *num_checkpoints != *length {
4643 assert!(
4644 *save_trajectory,
4645 "Op::Scan: num_checkpoints<length only meaningful when save_trajectory=true"
4646 );
4647 }
4648 let body_plan = rlx_opt::memory::plan_memory(body);
4659 let _body_arena_size = body_plan.arena_size;
4660 let body_offsets: HashMap<NodeId, usize> = body_plan
4663 .assignments
4664 .iter()
4665 .map(|(id, slot)| (*id, slot.offset))
4666 .collect();
4667
4668 let mut body_inputs: Vec<NodeId> = body
4671 .nodes()
4672 .iter()
4673 .filter(|n| matches!(n.op, Op::Input { .. }))
4674 .map(|n| n.id)
4675 .collect();
4676 body_inputs.sort();
4677 let n_body_inputs = body_inputs.len();
4678 let expected = 1 + *num_bcast as usize + *num_xs as usize;
4679 if n_body_inputs != expected {
4680 let names: Vec<String> = body
4681 .nodes()
4682 .iter()
4683 .filter_map(|n| match &n.op {
4684 Op::Input { name } => Some(format!("{}={}", n.id, name)),
4685 _ => None,
4686 })
4687 .collect();
4688 panic!(
4689 "Op::Scan body has {} Op::Input nodes; expected {} \
4690 (1 carry + {} bcast + {} xs). Inputs by NodeId: [{}]",
4691 n_body_inputs,
4692 expected,
4693 *num_bcast,
4694 *num_xs,
4695 names.join(", ")
4696 );
4697 }
4698
4699 let body_input_id = body_inputs[0];
4700 let body_input_off = body_offsets[&body_input_id];
4701 let body_output_id = body
4702 .outputs
4703 .first()
4704 .copied()
4705 .expect("Op::Scan body must declare one output");
4706 let body_output_off = body_offsets[&body_output_id];
4707
4708 let mut body_arena = crate::arena::Arena::from_plan(body_plan);
4709 for n in body.nodes() {
4712 if let Op::Constant { data } = &n.op
4713 && body_arena.has_buffer(n.id)
4714 && !data.is_empty()
4715 {
4716 match n.shape.dtype() {
4717 rlx_ir::DType::F64 => {
4718 let off = body_arena.byte_offset(n.id);
4719 let buf = body_arena.raw_buf_mut();
4720 let nbytes = (buf.len() - off).min(data.len());
4721 buf[off..off + nbytes].copy_from_slice(&data[..nbytes]);
4722 }
4723 _ => {
4724 let buf = body_arena.slice_mut(n.id);
4725 let n_floats = data.len() / 4;
4726 let n_lim = buf.len().min(n_floats);
4727 for i in 0..n_lim {
4728 let bytes = [
4729 data[i * 4],
4730 data[i * 4 + 1],
4731 data[i * 4 + 2],
4732 data[i * 4 + 3],
4733 ];
4734 buf[i] = f32::from_le_bytes(bytes);
4735 }
4736 }
4737 }
4738 }
4739 }
4740 let body_init = body_arena.raw_buf().to_vec();
4741 let body_schedule = compile_thunks(body, &body_arena);
4742
4743 let carry_bytes = if *save_trajectory {
4748 let total = node
4749 .shape
4750 .size_bytes()
4751 .expect("Op::Scan trajectory output must have static shape");
4752 total / *length as usize
4753 } else {
4754 node.shape
4755 .size_bytes()
4756 .expect("Op::Scan carry must have static shape")
4757 };
4758
4759 let mut bcast_inputs: Vec<(usize, usize, u32)> =
4764 Vec::with_capacity(*num_bcast as usize);
4765 for i in 0..*num_bcast as usize {
4766 let body_b_id = body_inputs[1 + i];
4767 let body_b_off = body_offsets[&body_b_id];
4768 let outer_b_id = node.inputs[1 + i];
4769 let outer_b_off = node_offset(arena, outer_b_id);
4770 let outer_b_shape = &graph.node(outer_b_id).shape;
4771 let total = outer_b_shape
4772 .size_bytes()
4773 .expect("Op::Scan bcast must have static shape");
4774 bcast_inputs.push((body_b_off, outer_b_off, total as u32));
4775 }
4776
4777 let mut xs_inputs: Vec<(usize, usize, u32)> = Vec::with_capacity(*num_xs as usize);
4781 let xs_base = 1 + *num_bcast as usize;
4782 for i in 0..*num_xs as usize {
4783 let body_x_id = body_inputs[xs_base + i];
4784 let body_x_off = body_offsets[&body_x_id];
4785 let outer_xs_id = node.inputs[xs_base + i];
4786 let outer_xs_off = node_offset(arena, outer_xs_id);
4787 let outer_xs_shape = &graph.node(outer_xs_id).shape;
4788 let total = outer_xs_shape
4789 .size_bytes()
4790 .expect("Op::Scan xs must have static shape");
4791 let per_step = total / *length as usize;
4792 xs_inputs.push((body_x_off, outer_xs_off, per_step as u32));
4793 }
4794
4795 Thunk::Scan {
4796 body: Arc::new(body_schedule),
4797 body_init: Arc::new(body_init),
4798 body_input_off,
4799 body_output_off,
4800 outer_init_off: node_offset(arena, node.inputs[0]),
4801 outer_final_off: node_offset(arena, node.id),
4802 length: *length,
4803 carry_bytes: carry_bytes as u32,
4804 save_trajectory: *save_trajectory,
4805 xs_inputs: Arc::new(xs_inputs),
4806 bcast_inputs: Arc::new(bcast_inputs),
4807 num_checkpoints: *num_checkpoints,
4808 }
4809 }
4810
4811 Op::ScanBackward {
4812 body_vjp,
4813 length,
4814 save_trajectory,
4815 num_xs,
4816 num_checkpoints,
4817 forward_body,
4818 } => {
4819 let is_recursive = *num_checkpoints != 0 && *num_checkpoints != *length;
4820 if is_recursive {
4821 assert!(
4822 forward_body.is_some(),
4823 "Op::ScanBackward with num_checkpoints<length requires forward_body"
4824 );
4825 }
4826 let body_plan = rlx_opt::memory::plan_memory(body_vjp);
4834 let body_offsets: HashMap<NodeId, usize> = body_plan
4835 .assignments
4836 .iter()
4837 .map(|(id, slot)| (*id, slot.offset))
4838 .collect();
4839 let mut body_d_output_off: Option<usize> = None;
4840 let mut body_other_inputs: Vec<(NodeId, usize)> = Vec::new();
4841 for n in body_vjp.nodes() {
4842 if let Op::Input { name } = &n.op {
4843 let off = body_offsets[&n.id];
4844 if name == "d_output" {
4845 body_d_output_off = Some(off);
4846 } else {
4847 body_other_inputs.push((n.id, off));
4848 }
4849 }
4850 }
4851 body_other_inputs.sort_by_key(|(id, _)| *id);
4852 let body_d_output_off =
4853 body_d_output_off.expect("ScanBackward body_vjp missing 'd_output' Input");
4854 let expected_others = 1 + *num_xs as usize;
4855 assert_eq!(
4856 body_other_inputs.len(),
4857 expected_others,
4858 "ScanBackward body_vjp has {} non-d_output Inputs; \
4859 expected {} (1 carry + {} xs)",
4860 body_other_inputs.len(),
4861 expected_others,
4862 num_xs
4863 );
4864 let body_carry_in_off = body_other_inputs[0].1;
4865 let body_x_offs: Vec<usize> = body_other_inputs
4866 .iter()
4867 .skip(1)
4868 .map(|(_, off)| *off)
4869 .collect();
4870 let body_dcarry_out_off = body_offsets[&body_vjp.outputs[0]];
4871
4872 let mut body_arena = crate::arena::Arena::from_plan(body_plan);
4873 for n in body_vjp.nodes() {
4875 if let Op::Constant { data } = &n.op
4876 && body_arena.has_buffer(n.id)
4877 && !data.is_empty()
4878 {
4879 match n.shape.dtype() {
4880 rlx_ir::DType::F64 => {
4881 let off = body_arena.byte_offset(n.id);
4882 let buf = body_arena.raw_buf_mut();
4883 let nb = (buf.len() - off).min(data.len());
4884 buf[off..off + nb].copy_from_slice(&data[..nb]);
4885 }
4886 _ => {
4887 let buf = body_arena.slice_mut(n.id);
4888 let nf = data.len() / 4;
4889 let nl = buf.len().min(nf);
4890 for i in 0..nl {
4891 let bytes = [
4892 data[i * 4],
4893 data[i * 4 + 1],
4894 data[i * 4 + 2],
4895 data[i * 4 + 3],
4896 ];
4897 buf[i] = f32::from_le_bytes(bytes);
4898 }
4899 }
4900 }
4901 }
4902 }
4903 let body_init = body_arena.raw_buf().to_vec();
4904 let body_schedule = compile_thunks(body_vjp, &body_arena);
4905
4906 let carry_bytes = body_vjp
4908 .node(body_vjp.outputs[0])
4909 .shape
4910 .size_bytes()
4911 .expect("ScanBackward dcarry must be statically shaped");
4912 let carry_elem_size = body_vjp
4913 .node(body_vjp.outputs[0])
4914 .shape
4915 .dtype()
4916 .size_bytes() as u32;
4917
4918 let mut outer_xs_offs: Vec<(usize, u32)> = Vec::with_capacity(*num_xs as usize);
4921 for i in 0..*num_xs as usize {
4922 let outer_xs_id = node.inputs[3 + i];
4923 let outer_xs_off = node_offset(arena, outer_xs_id);
4924 let outer_xs_shape = &graph.node(outer_xs_id).shape;
4925 let total = outer_xs_shape
4926 .size_bytes()
4927 .expect("ScanBackward xs must have static shape");
4928 let per_step = total / *length as usize;
4929 outer_xs_offs.push((outer_xs_off, per_step as u32));
4930 }
4931
4932 let (fb_schedule, fb_init, fb_carry_in_off, fb_output_off, fb_x_offs) =
4937 if is_recursive {
4938 let fb = forward_body.as_ref().unwrap();
4939 let fb_plan = rlx_opt::memory::plan_memory(fb);
4940 let fb_offsets: HashMap<NodeId, usize> = fb_plan
4941 .assignments
4942 .iter()
4943 .map(|(id, slot)| (*id, slot.offset))
4944 .collect();
4945 let mut fb_inputs: Vec<NodeId> = fb
4946 .nodes()
4947 .iter()
4948 .filter(|n| matches!(n.op, Op::Input { .. }))
4949 .map(|n| n.id)
4950 .collect();
4951 fb_inputs.sort();
4952 let fb_carry = fb_offsets[&fb_inputs[0]];
4953 let fb_xs: Vec<usize> = (1..fb_inputs.len())
4954 .map(|i| fb_offsets[&fb_inputs[i]])
4955 .collect();
4956 let fb_out = fb_offsets[&fb.outputs[0]];
4957 let mut fb_arena = crate::arena::Arena::from_plan(fb_plan);
4958 for n in fb.nodes() {
4959 if let Op::Constant { data } = &n.op
4960 && fb_arena.has_buffer(n.id)
4961 && !data.is_empty()
4962 {
4963 let off = fb_arena.byte_offset(n.id);
4970 let buf = fb_arena.raw_buf_mut();
4971 let nb = (buf.len() - off).min(data.len());
4972 buf[off..off + nb].copy_from_slice(&data[..nb]);
4973 }
4974 }
4975 let fb_init_bytes = fb_arena.raw_buf().to_vec();
4976 let fb_sched = compile_thunks(fb, &fb_arena);
4977 (
4978 Some(Arc::new(fb_sched)),
4979 Some(Arc::new(fb_init_bytes)),
4980 fb_carry,
4981 fb_out,
4982 fb_xs,
4983 )
4984 } else {
4985 (None, None, 0, 0, Vec::new())
4986 };
4987
4988 Thunk::ScanBackward {
4989 body_vjp: Arc::new(body_schedule),
4990 body_init: Arc::new(body_init),
4991 body_carry_in_off,
4992 body_x_offs: Arc::new(body_x_offs),
4993 body_d_output_off,
4994 body_dcarry_out_off,
4995 outer_init_off: node_offset(arena, node.inputs[0]),
4996 outer_traj_off: node_offset(arena, node.inputs[1]),
4997 outer_upstream_off: node_offset(arena, node.inputs[2]),
4998 outer_xs_offs: Arc::new(outer_xs_offs),
4999 outer_dinit_off: node_offset(arena, node.id),
5000 length: *length,
5001 carry_bytes: carry_bytes as u32,
5002 carry_elem_size,
5003 save_trajectory: *save_trajectory,
5004 num_checkpoints: *num_checkpoints,
5005 forward_body: fb_schedule,
5006 forward_body_init: fb_init,
5007 forward_body_carry_in_off: fb_carry_in_off,
5008 forward_body_output_off: fb_output_off,
5009 forward_body_x_offs: Arc::new(fb_x_offs),
5010 }
5011 }
5012
5013 Op::ScanBackwardXs {
5014 body_vjp,
5015 length,
5016 save_trajectory,
5017 num_xs,
5018 xs_idx,
5019 num_checkpoints,
5020 forward_body,
5021 } => {
5022 assert!(
5023 *num_checkpoints == 0 || *num_checkpoints <= *length,
5024 "Op::ScanBackwardXs: num_checkpoints={} must be 0 or ≤ length={}",
5025 *num_checkpoints,
5026 *length
5027 );
5028 let is_recursive = *num_checkpoints != 0 && *num_checkpoints != *length;
5029 if is_recursive {
5030 assert!(
5031 forward_body.is_some(),
5032 "Op::ScanBackwardXs with num_checkpoints<length \
5033 requires forward_body"
5034 );
5035 }
5036 let body_plan = rlx_opt::memory::plan_memory(body_vjp);
5044 let body_offsets: HashMap<NodeId, usize> = body_plan
5045 .assignments
5046 .iter()
5047 .map(|(id, slot)| (*id, slot.offset))
5048 .collect();
5049 let mut body_d_output_off: Option<usize> = None;
5050 let mut body_other_inputs: Vec<(NodeId, usize)> = Vec::new();
5051 for n in body_vjp.nodes() {
5052 if let Op::Input { name } = &n.op {
5053 let off = body_offsets[&n.id];
5054 if name == "d_output" {
5055 body_d_output_off = Some(off);
5056 } else {
5057 body_other_inputs.push((n.id, off));
5058 }
5059 }
5060 }
5061 body_other_inputs.sort_by_key(|(id, _)| *id);
5062 let body_d_output_off =
5063 body_d_output_off.expect("ScanBackwardXs body_vjp missing 'd_output' Input");
5064 let expected_others = 1 + *num_xs as usize;
5065 assert_eq!(
5066 body_other_inputs.len(),
5067 expected_others,
5068 "ScanBackwardXs body_vjp has {} non-d_output Inputs; expected {}",
5069 body_other_inputs.len(),
5070 expected_others
5071 );
5072 let body_carry_in_off = body_other_inputs[0].1;
5073 let body_x_offs: Vec<usize> = body_other_inputs
5074 .iter()
5075 .skip(1)
5076 .map(|(_, off)| *off)
5077 .collect();
5078 let body_dcarry_out_off = body_offsets[&body_vjp.outputs[0]];
5079 let dxs_out_node = body_vjp.outputs[1 + *xs_idx as usize];
5080 let body_dxs_out_off = body_offsets[&dxs_out_node];
5081
5082 let mut body_arena = crate::arena::Arena::from_plan(body_plan);
5083 for n in body_vjp.nodes() {
5084 if let Op::Constant { data } = &n.op
5085 && body_arena.has_buffer(n.id)
5086 && !data.is_empty()
5087 {
5088 match n.shape.dtype() {
5089 rlx_ir::DType::F64 => {
5090 let off = body_arena.byte_offset(n.id);
5091 let buf = body_arena.raw_buf_mut();
5092 let nb = (buf.len() - off).min(data.len());
5093 buf[off..off + nb].copy_from_slice(&data[..nb]);
5094 }
5095 _ => {
5096 let buf = body_arena.slice_mut(n.id);
5097 let nf = data.len() / 4;
5098 let nl = buf.len().min(nf);
5099 for i in 0..nl {
5100 let bytes = [
5101 data[i * 4],
5102 data[i * 4 + 1],
5103 data[i * 4 + 2],
5104 data[i * 4 + 3],
5105 ];
5106 buf[i] = f32::from_le_bytes(bytes);
5107 }
5108 }
5109 }
5110 }
5111 }
5112 let body_init = body_arena.raw_buf().to_vec();
5113 let body_schedule = compile_thunks(body_vjp, &body_arena);
5114
5115 let carry_bytes = body_vjp
5116 .node(body_vjp.outputs[0])
5117 .shape
5118 .size_bytes()
5119 .expect("ScanBackwardXs dcarry must be statically shaped");
5120 let carry_elem_size = body_vjp
5121 .node(body_vjp.outputs[0])
5122 .shape
5123 .dtype()
5124 .size_bytes() as u32;
5125 let per_step_bytes = body_vjp
5126 .node(dxs_out_node)
5127 .shape
5128 .size_bytes()
5129 .expect("ScanBackwardXs dxs body output must be statically shaped");
5130
5131 let mut outer_xs_offs: Vec<(usize, u32)> = Vec::with_capacity(*num_xs as usize);
5132 for i in 0..*num_xs as usize {
5133 let outer_xs_id = node.inputs[3 + i];
5134 let outer_xs_off = node_offset(arena, outer_xs_id);
5135 let outer_xs_shape = &graph.node(outer_xs_id).shape;
5136 let total = outer_xs_shape
5137 .size_bytes()
5138 .expect("ScanBackwardXs xs must have static shape");
5139 let per_step = total / *length as usize;
5140 outer_xs_offs.push((outer_xs_off, per_step as u32));
5141 }
5142
5143 let (fb_schedule, fb_init, fb_carry_in_off, fb_output_off, fb_x_offs) =
5146 if is_recursive {
5147 let fb = forward_body.as_ref().unwrap();
5148 let fb_plan = rlx_opt::memory::plan_memory(fb);
5149 let fb_offsets: HashMap<NodeId, usize> = fb_plan
5150 .assignments
5151 .iter()
5152 .map(|(id, slot)| (*id, slot.offset))
5153 .collect();
5154 let mut fb_inputs: Vec<NodeId> = fb
5155 .nodes()
5156 .iter()
5157 .filter(|n| matches!(n.op, Op::Input { .. }))
5158 .map(|n| n.id)
5159 .collect();
5160 fb_inputs.sort();
5161 let fb_carry = fb_offsets[&fb_inputs[0]];
5162 let fb_xs: Vec<usize> = (1..fb_inputs.len())
5163 .map(|i| fb_offsets[&fb_inputs[i]])
5164 .collect();
5165 let fb_out = fb_offsets[&fb.outputs[0]];
5166 let mut fb_arena = crate::arena::Arena::from_plan(fb_plan);
5167 for n in fb.nodes() {
5168 if let Op::Constant { data } = &n.op
5169 && fb_arena.has_buffer(n.id)
5170 && !data.is_empty()
5171 {
5172 let off = fb_arena.byte_offset(n.id);
5179 let buf = fb_arena.raw_buf_mut();
5180 let nb = (buf.len() - off).min(data.len());
5181 buf[off..off + nb].copy_from_slice(&data[..nb]);
5182 }
5183 }
5184 let fb_init_bytes = fb_arena.raw_buf().to_vec();
5185 let fb_sched = compile_thunks(fb, &fb_arena);
5186 (
5187 Some(Arc::new(fb_sched)),
5188 Some(Arc::new(fb_init_bytes)),
5189 fb_carry,
5190 fb_out,
5191 fb_xs,
5192 )
5193 } else {
5194 (None, None, 0, 0, Vec::new())
5195 };
5196
5197 Thunk::ScanBackwardXs {
5198 body_vjp: Arc::new(body_schedule),
5199 body_init: Arc::new(body_init),
5200 body_carry_in_off,
5201 body_x_offs: Arc::new(body_x_offs),
5202 body_d_output_off,
5203 body_dcarry_out_off,
5204 body_dxs_out_off,
5205 outer_init_off: node_offset(arena, node.inputs[0]),
5206 outer_traj_off: node_offset(arena, node.inputs[1]),
5207 outer_upstream_off: node_offset(arena, node.inputs[2]),
5208 outer_xs_offs: Arc::new(outer_xs_offs),
5209 outer_dxs_off: node_offset(arena, node.id),
5210 length: *length,
5211 carry_bytes: carry_bytes as u32,
5212 carry_elem_size,
5213 per_step_bytes: per_step_bytes as u32,
5214 save_trajectory: *save_trajectory,
5215 num_checkpoints: *num_checkpoints,
5216 forward_body: fb_schedule,
5217 forward_body_init: fb_init,
5218 forward_body_carry_in_off: fb_carry_in_off,
5219 forward_body_output_off: fb_output_off,
5220 forward_body_x_offs: Arc::new(fb_x_offs),
5221 }
5222 }
5223
5224 Op::Concat { axis } => {
5225 let out_shape = &node.shape;
5229 let rank = out_shape.rank();
5230 let outer: usize = (0..*axis)
5231 .map(|i| out_shape.dim(i).unwrap_static())
5232 .product::<usize>()
5233 .max(1);
5234 let inner: usize = (*axis + 1..rank)
5235 .map(|i| out_shape.dim(i).unwrap_static())
5236 .product::<usize>()
5237 .max(1);
5238 let total_axis = out_shape.dim(*axis).unwrap_static();
5239 let inputs: Vec<(usize, u32)> = node
5240 .inputs
5241 .iter()
5242 .map(|&in_id| {
5243 let in_shape = &graph.node(in_id).shape;
5244 let in_axis = in_shape.dim(*axis).unwrap_static();
5245 (node_offset(arena, in_id), in_axis as u32)
5246 })
5247 .collect();
5248 let dst = node_offset(arena, node.id);
5249 match out_shape.dtype() {
5250 rlx_ir::DType::F64 => Thunk::ConcatF64 {
5251 dst,
5252 outer: outer as u32,
5253 inner: inner as u32,
5254 total_axis: total_axis as u32,
5255 inputs,
5256 },
5257 _ => Thunk::Concat {
5258 dst,
5259 outer: outer as u32,
5260 inner: inner as u32,
5261 total_axis: total_axis as u32,
5262 inputs,
5263 },
5264 }
5265 }
5266
5267 Op::GaussianSplatRender {
5268 width,
5269 height,
5270 tile_size,
5271 radius_scale,
5272 alpha_cutoff,
5273 max_splat_steps,
5274 transmittance_threshold,
5275 max_list_entries,
5276 } => {
5277 let elem_len =
5278 |id: NodeId| -> usize { graph.node(id).shape.num_elements().unwrap_or(0) };
5279 Thunk::GaussianSplatRender {
5280 positions_off: node_offset(arena, node.inputs[0]),
5281 positions_len: elem_len(node.inputs[0]),
5282 scales_off: node_offset(arena, node.inputs[1]),
5283 scales_len: elem_len(node.inputs[1]),
5284 rotations_off: node_offset(arena, node.inputs[2]),
5285 rotations_len: elem_len(node.inputs[2]),
5286 opacities_off: node_offset(arena, node.inputs[3]),
5287 opacities_len: elem_len(node.inputs[3]),
5288 colors_off: node_offset(arena, node.inputs[4]),
5289 colors_len: elem_len(node.inputs[4]),
5290 sh_coeffs_off: node_offset(arena, node.inputs[5]),
5291 sh_coeffs_len: elem_len(node.inputs[5]),
5292 meta_off: node_offset(arena, node.inputs[6]),
5293 dst_off: node_offset(arena, node.id),
5294 dst_len: node.shape.num_elements().unwrap_or(0),
5295 width: *width,
5296 height: *height,
5297 tile_size: *tile_size,
5298 radius_scale: *radius_scale,
5299 alpha_cutoff: *alpha_cutoff,
5300 max_splat_steps: *max_splat_steps,
5301 transmittance_threshold: *transmittance_threshold,
5302 max_list_entries: *max_list_entries,
5303 }
5304 }
5305
5306 Op::GaussianSplatRenderBackward {
5307 width,
5308 height,
5309 tile_size,
5310 radius_scale,
5311 alpha_cutoff,
5312 max_splat_steps,
5313 transmittance_threshold,
5314 max_list_entries,
5315 loss_grad_clip,
5316 sh_band,
5317 max_anisotropy,
5318 } => {
5319 let elem_len =
5320 |id: NodeId| -> usize { graph.node(id).shape.num_elements().unwrap_or(0) };
5321 Thunk::GaussianSplatRenderBackward {
5322 positions_off: node_offset(arena, node.inputs[0]),
5323 positions_len: elem_len(node.inputs[0]),
5324 scales_off: node_offset(arena, node.inputs[1]),
5325 scales_len: elem_len(node.inputs[1]),
5326 rotations_off: node_offset(arena, node.inputs[2]),
5327 rotations_len: elem_len(node.inputs[2]),
5328 opacities_off: node_offset(arena, node.inputs[3]),
5329 opacities_len: elem_len(node.inputs[3]),
5330 colors_off: node_offset(arena, node.inputs[4]),
5331 colors_len: elem_len(node.inputs[4]),
5332 sh_coeffs_off: node_offset(arena, node.inputs[5]),
5333 sh_coeffs_len: elem_len(node.inputs[5]),
5334 meta_off: node_offset(arena, node.inputs[6]),
5335 d_loss_off: node_offset(arena, node.inputs[7]),
5336 d_loss_len: elem_len(node.inputs[7]),
5337 packed_off: node_offset(arena, node.id),
5338 packed_len: node.shape.num_elements().unwrap_or(0),
5339 width: *width,
5340 height: *height,
5341 tile_size: *tile_size,
5342 radius_scale: *radius_scale,
5343 alpha_cutoff: *alpha_cutoff,
5344 max_splat_steps: *max_splat_steps,
5345 transmittance_threshold: *transmittance_threshold,
5346 max_list_entries: *max_list_entries,
5347 loss_grad_clip: *loss_grad_clip,
5348 sh_band: *sh_band,
5349 max_anisotropy: *max_anisotropy,
5350 }
5351 }
5352
5353 Op::GaussianSplatPrepare {
5354 width,
5355 height,
5356 tile_size,
5357 radius_scale,
5358 alpha_cutoff,
5359 max_splat_steps,
5360 transmittance_threshold,
5361 max_list_entries,
5362 } => {
5363 let elem_len =
5364 |id: NodeId| -> usize { graph.node(id).shape.num_elements().unwrap_or(0) };
5365 Thunk::GaussianSplatPrepare {
5366 positions_off: node_offset(arena, node.inputs[0]),
5367 positions_len: elem_len(node.inputs[0]),
5368 scales_off: node_offset(arena, node.inputs[1]),
5369 scales_len: elem_len(node.inputs[1]),
5370 rotations_off: node_offset(arena, node.inputs[2]),
5371 rotations_len: elem_len(node.inputs[2]),
5372 opacities_off: node_offset(arena, node.inputs[3]),
5373 opacities_len: elem_len(node.inputs[3]),
5374 colors_off: node_offset(arena, node.inputs[4]),
5375 colors_len: elem_len(node.inputs[4]),
5376 sh_coeffs_off: node_offset(arena, node.inputs[5]),
5377 sh_coeffs_len: elem_len(node.inputs[5]),
5378 meta_off: node_offset(arena, node.inputs[6]),
5379 meta_len: elem_len(node.inputs[6]),
5380 prep_off: node_offset(arena, node.id),
5381 prep_len: node.shape.num_elements().unwrap_or(0),
5382 width: *width,
5383 height: *height,
5384 tile_size: *tile_size,
5385 radius_scale: *radius_scale,
5386 alpha_cutoff: *alpha_cutoff,
5387 max_splat_steps: *max_splat_steps,
5388 transmittance_threshold: *transmittance_threshold,
5389 max_list_entries: *max_list_entries,
5390 }
5391 }
5392
5393 Op::GaussianSplatRasterize {
5394 width,
5395 height,
5396 tile_size,
5397 alpha_cutoff,
5398 max_splat_steps,
5399 transmittance_threshold,
5400 max_list_entries,
5401 } => {
5402 let elem_len =
5403 |id: NodeId| -> usize { graph.node(id).shape.num_elements().unwrap_or(0) };
5404 let prep_id = node.inputs[0];
5405 let count = match &graph.node(prep_id).op {
5406 rlx_ir::Op::GaussianSplatPrepare { .. } => {
5407 elem_len(graph.node(prep_id).inputs[0]) / 3
5408 }
5409 _ => 1,
5410 };
5411 Thunk::GaussianSplatRasterize {
5412 prep_off: node_offset(arena, prep_id),
5413 prep_len: elem_len(prep_id),
5414 meta_off: node_offset(arena, node.inputs[1]),
5415 meta_len: elem_len(node.inputs[1]),
5416 dst_off: node_offset(arena, node.id),
5417 dst_len: node.shape.num_elements().unwrap_or(0),
5418 count,
5419 width: *width,
5420 height: *height,
5421 tile_size: *tile_size,
5422 alpha_cutoff: *alpha_cutoff,
5423 max_splat_steps: *max_splat_steps,
5424 transmittance_threshold: *transmittance_threshold,
5425 max_list_entries: *max_list_entries,
5426 }
5427 }
5428
5429 Op::Custom { name, attrs, .. } => {
5430 let kernel = crate::op_registry::lookup_cpu_kernel(name).unwrap_or_else(|| {
5431 panic!(
5432 "compile_thunks: no CPU kernel registered for \
5433 Op::Custom('{name}'). Register one via \
5434 rlx_cpu::op_registry::register_cpu_kernel \
5435 before compiling on the CPU backend."
5436 )
5437 });
5438 let inputs_v: Vec<(usize, u32, Shape)> = node
5439 .inputs
5440 .iter()
5441 .map(|&in_id| {
5442 let s = graph.node(in_id).shape.clone();
5443 let len = s.num_elements().unwrap_or(0) as u32;
5444 (node_offset(arena, in_id), len, s)
5445 })
5446 .collect();
5447 let out_len = node.shape.num_elements().unwrap_or(0) as u32;
5448 Thunk::CustomOp {
5449 kernel,
5450 inputs: inputs_v,
5451 output: (node_offset(arena, node.id), out_len, node.shape.clone()),
5452 attrs: attrs.clone(),
5453 }
5454 }
5455
5456 Op::Fft { inverse, norm } => {
5457 let shape = &node.shape;
5458 let meta = rlx_ir::fft::fft_meta(shape);
5459 let dtype = shape.dtype();
5460 assert!(
5461 matches!(
5462 dtype,
5463 rlx_ir::DType::F32 | rlx_ir::DType::F64 | rlx_ir::DType::C64
5464 ),
5465 "Op::Fft on CPU requires F32, F64, or C64, got {dtype:?}"
5466 );
5467 Thunk::Fft1d {
5468 src: node_offset(arena, node.inputs[0]),
5469 dst: node_offset(arena, node.id),
5470 outer: meta.outer as u32,
5471 n_complex: meta.n_complex as u32,
5472 inverse: *inverse,
5473 norm_tag: norm.tag(),
5474 dtype,
5475 }
5476 }
5477
5478 Op::FftButterflyStage { stage, n_fft } => {
5479 let state_shape = graph.node(node.inputs[0]).shape.clone();
5480 assert_eq!(
5481 state_shape.dtype(),
5482 rlx_ir::DType::F32,
5483 "Op::FftButterflyStage requires F32 state"
5484 );
5485 let batch = state_shape.dim(0).unwrap_static() as u32;
5486 Thunk::FftButterflyStage {
5487 state_src: node_offset(arena, node.inputs[0]),
5488 state_dst: node_offset(arena, node.id),
5489 gate_src: node_offset(arena, node.inputs[1]),
5490 rev_src: node_offset(arena, node.inputs[2]),
5491 tw_re_src: node_offset(arena, node.inputs[3]),
5492 tw_im_src: node_offset(arena, node.inputs[4]),
5493 batch,
5494 n_fft: *n_fft,
5495 stage: *stage,
5496 }
5497 }
5498
5499 Op::LogMel => {
5500 let spec_shape = graph.node(node.inputs[0]).shape.clone();
5501 let filt_shape = graph.node(node.inputs[1]).shape.clone();
5502 let meta = rlx_ir::audio::log_mel_meta(&spec_shape, &filt_shape)
5503 .unwrap_or_else(|e| panic!("Op::LogMel: {e}"));
5504 Thunk::LogMel {
5505 spec: node_offset(arena, node.inputs[0]),
5506 filters: node_offset(arena, node.inputs[1]),
5507 dst: node_offset(arena, node.id),
5508 outer: meta.outer as u32,
5509 n_fft: meta.n_fft as u32,
5510 n_bins: meta.n_bins as u32,
5511 n_mels: meta.n_mels as u32,
5512 }
5513 }
5514
5515 Op::LogMelBackward => {
5516 let spec_shape = graph.node(node.inputs[0]).shape.clone();
5517 let filt_shape = graph.node(node.inputs[1]).shape.clone();
5518 let meta = rlx_ir::audio::log_mel_meta(&spec_shape, &filt_shape)
5519 .unwrap_or_else(|e| panic!("Op::LogMelBackward: {e}"));
5520 Thunk::LogMelBackward {
5521 spec: node_offset(arena, node.inputs[0]),
5522 filters: node_offset(arena, node.inputs[1]),
5523 dy: node_offset(arena, node.inputs[2]),
5524 dst: node_offset(arena, node.id),
5525 outer: meta.outer as u32,
5526 n_fft: meta.n_fft as u32,
5527 n_bins: meta.n_bins as u32,
5528 n_mels: meta.n_mels as u32,
5529 }
5530 }
5531
5532 Op::CustomFn {
5533 fwd_body,
5534 num_inputs,
5535 ..
5536 } => {
5537 let body_plan = rlx_opt::memory::plan_memory(fwd_body);
5543 let body_offsets: HashMap<NodeId, usize> = body_plan
5544 .assignments
5545 .iter()
5546 .map(|(id, slot)| (*id, slot.offset))
5547 .collect();
5548
5549 let mut body_input_ids: Vec<NodeId> = fwd_body
5550 .nodes()
5551 .iter()
5552 .filter(|n| matches!(n.op, Op::Input { .. }))
5553 .map(|n| n.id)
5554 .collect();
5555 body_input_ids.sort();
5556 assert_eq!(
5557 body_input_ids.len(),
5558 *num_inputs as usize,
5559 "Op::CustomFn fwd_body has {} Op::Input(s); declared num_inputs={}",
5560 body_input_ids.len(),
5561 *num_inputs,
5562 );
5563
5564 let mut body_arena = crate::arena::Arena::from_plan(body_plan);
5565 for n in fwd_body.nodes() {
5566 if let Op::Constant { data } = &n.op
5567 && body_arena.has_buffer(n.id)
5568 && !data.is_empty()
5569 {
5570 match n.shape.dtype() {
5571 rlx_ir::DType::F64 => {
5572 let off = body_arena.byte_offset(n.id);
5573 let buf = body_arena.raw_buf_mut();
5574 let nb = (buf.len() - off).min(data.len());
5575 buf[off..off + nb].copy_from_slice(&data[..nb]);
5576 }
5577 _ => {
5578 let buf = body_arena.slice_mut(n.id);
5579 let nf = data.len() / 4;
5580 let nl = buf.len().min(nf);
5581 for i in 0..nl {
5582 let bytes = [
5583 data[i * 4],
5584 data[i * 4 + 1],
5585 data[i * 4 + 2],
5586 data[i * 4 + 3],
5587 ];
5588 buf[i] = f32::from_le_bytes(bytes);
5589 }
5590 }
5591 }
5592 }
5593 }
5594 let body_init = body_arena.raw_buf().to_vec();
5595 let body_schedule = compile_thunks(fwd_body, &body_arena);
5596
5597 let inputs_v: Vec<(usize, usize, u32)> = (0..*num_inputs as usize)
5599 .map(|i| {
5600 let body_in = body_input_ids[i];
5601 let body_off = body_offsets[&body_in];
5602 let outer_in = node.inputs[i];
5603 let outer_off = node_offset(arena, outer_in);
5604 let bytes = graph
5605 .node(outer_in)
5606 .shape
5607 .size_bytes()
5608 .expect("Op::CustomFn primal input must have static shape");
5609 (body_off, outer_off, bytes as u32)
5610 })
5611 .collect();
5612
5613 let body_output_id = fwd_body
5614 .outputs
5615 .first()
5616 .copied()
5617 .expect("Op::CustomFn fwd_body must declare exactly one output");
5618 let body_output_off = body_offsets[&body_output_id];
5619 let out_bytes = node
5620 .shape
5621 .size_bytes()
5622 .expect("Op::CustomFn output must have static shape");
5623
5624 Thunk::CustomFn {
5625 body: Arc::new(body_schedule),
5626 body_init: Arc::new(body_init),
5627 inputs: Arc::new(inputs_v),
5628 body_output_off,
5629 outer_output_off: node_offset(arena, node.id),
5630 out_bytes: out_bytes as u32,
5631 }
5632 }
5633
5634 _ => Thunk::Nop,
5635 };
5636 thunks.push(t);
5637 }
5638
5639 let cfg = crate::config::RuntimeConfig::global();
5640 let mask_thr = cfg.mask_binary_threshold;
5641 let mask_neg = cfg.attn_mask_neg_inf;
5642 let score_skip = cfg.score_skip_threshold;
5643
5644 let compiled_fns: Vec<Arc<dyn Fn(*mut u8) + Send + Sync>> = thunks
5646 .iter()
5647 .filter(|t| !matches!(t, Thunk::Nop))
5648 .map(|thunk| {
5649 match thunk.clone() {
5650 Thunk::Nop => Arc::new(|_: *mut u8| {}) as Arc<dyn Fn(*mut u8) + Send + Sync>,
5651
5652 Thunk::Sgemm { a, b, c, m, k, n } => {
5653 let (m, k, n) = (m as usize, k as usize, n as usize);
5654 Arc::new(move |base: *mut u8| unsafe {
5655 crate::blas::sgemm(
5656 sl(a, base, m * k),
5657 sl(b, base, k * n),
5658 sl_mut(c, base, m * n),
5659 m,
5660 k,
5661 n,
5662 );
5663 })
5664 }
5665
5666 Thunk::DenseSolveF64 { a, b, x, n, nrhs } => {
5667 let (n_, nrhs_) = (n as usize, nrhs as usize);
5668 Arc::new(move |base: *mut u8| unsafe {
5669 let a_src = sl_f64(a, base, n_ * n_);
5670 let b_src = sl_f64(b, base, n_ * nrhs_);
5671 let mut a_scratch: Vec<f64> = a_src.to_vec();
5672 let mut x_buf: Vec<f64> = b_src.to_vec();
5673 let info = crate::blas::dgesv(&mut a_scratch, &mut x_buf, n_, nrhs_);
5674 if info != 0 {
5675 panic!("DenseSolveF64: singular (info={info})");
5676 }
5677 sl_mut_f64(x, base, n_ * nrhs_).copy_from_slice(&x_buf);
5678 })
5679 }
5680
5681 Thunk::DenseSolveF32 { a, b, x, n, nrhs } => {
5682 let (n_, nrhs_) = (n as usize, nrhs as usize);
5683 Arc::new(move |base: *mut u8| unsafe {
5684 let a_src = sl(a, base, n_ * n_);
5685 let b_src = sl(b, base, n_ * nrhs_);
5686 let mut a_scratch: Vec<f32> = a_src.to_vec();
5687 let mut x_buf: Vec<f32> = b_src.to_vec();
5688 let info = crate::blas::sgesv(&mut a_scratch, &mut x_buf, n_, nrhs_);
5689 if info != 0 {
5690 panic!("DenseSolveF32: singular (info={info})");
5691 }
5692 sl_mut(x, base, n_ * nrhs_).copy_from_slice(&x_buf);
5693 })
5694 }
5695
5696 Thunk::FusedMmBiasAct {
5697 a,
5698 w,
5699 bias,
5700 c,
5701 m,
5702 k,
5703 n,
5704 act,
5705 } => {
5706 let (m, k, n) = (m as usize, k as usize, n as usize);
5707 Arc::new(move |base: *mut u8| unsafe {
5708 let out = sl_mut(c, base, m * n);
5709 crate::blas::sgemm(sl(a, base, m * k), sl(w, base, k * n), out, m, k, n);
5710 match act {
5718 Some(Activation::Gelu) => {
5719 crate::kernels::par_bias_gelu(out, sl(bias, base, n), m, n)
5720 }
5721 Some(other) => {
5722 crate::blas::bias_add(out, sl(bias, base, n), m, n);
5723 apply_activation_inplace(out, other);
5724 }
5725 None => crate::blas::bias_add(out, sl(bias, base, n), m, n),
5726 }
5727 })
5728 }
5729
5730 Thunk::FusedResidualLN {
5731 x,
5732 res,
5733 bias,
5734 g,
5735 b,
5736 out,
5737 rows,
5738 h,
5739 eps,
5740 has_bias,
5741 } => {
5742 let (rows, h) = (rows as usize, h as usize);
5743 Arc::new(move |base: *mut u8| unsafe {
5744 let zero = vec![0f32; h]; let bi = if has_bias { sl(bias, base, h) } else { &zero };
5746 let xp = sl(x, base, rows * h).as_ptr() as usize;
5747 let rp = sl(res, base, rows * h).as_ptr() as usize;
5748 let op = sl_mut(out, base, rows * h).as_mut_ptr() as usize;
5749 let bp = bi.as_ptr() as usize;
5750 let gp = sl(g, base, h).as_ptr() as usize;
5751 let bbp = sl(b, base, h).as_ptr() as usize;
5752 crate::pool::par_for(rows, 4, &|off, cnt| {
5753 let xs = std::slice::from_raw_parts(
5754 (xp as *const f32).add(off * h),
5755 cnt * h,
5756 );
5757 let rs = std::slice::from_raw_parts(
5758 (rp as *const f32).add(off * h),
5759 cnt * h,
5760 );
5761 let os = std::slice::from_raw_parts_mut(
5762 (op as *mut f32).add(off * h),
5763 cnt * h,
5764 );
5765 let bi = std::slice::from_raw_parts(bp as *const f32, h);
5766 let g = std::slice::from_raw_parts(gp as *const f32, h);
5767 let b = std::slice::from_raw_parts(bbp as *const f32, h);
5768 crate::kernels::residual_bias_layer_norm(
5769 xs, rs, bi, g, b, os, cnt, h, eps,
5770 );
5771 });
5772 })
5773 }
5774
5775 Thunk::BiasAdd {
5776 src,
5777 bias,
5778 dst,
5779 m,
5780 n,
5781 } => {
5782 let (m, n) = (m as usize, n as usize);
5783 let len = m * n;
5784 Arc::new(move |base: *mut u8| unsafe {
5785 let out = sl_mut(dst, base, len);
5786 if src != dst {
5787 let src_ptr = base.add(src) as *const f32;
5788 let dst_ptr = base.add(dst) as *mut f32;
5789 if src_ptr != dst_ptr {
5790 std::ptr::copy_nonoverlapping(src_ptr, dst_ptr, len);
5791 }
5792 }
5793 crate::blas::bias_add(out, sl(bias, base, n), m, n);
5794 })
5795 }
5796
5797 Thunk::Gather {
5798 table,
5799 table_len,
5800 idx,
5801 dst,
5802 num_idx,
5803 trailing,
5804 idx_i64,
5805 table_bytes,
5806 } => {
5807 let (ni, tr, tl) = (num_idx as usize, trailing as usize, table_len as usize);
5808 let rows = tl / tr.max(1);
5809 let (idx_i64, table_bytes) = (idx_i64, table_bytes);
5810 Arc::new(move |base: *mut u8| unsafe {
5811 if table_bytes == 8 {
5812 let tab = sl_i64(table, base, tl);
5813 let out = sl_mut_i64(dst, base, ni * tr);
5814 if idx_i64 != 0 {
5815 let ids = sl_i64(idx, base, ni);
5816 for i in 0..ni {
5817 let row = ids[i].max(0) as usize;
5818 if row < rows {
5819 out[i * tr..(i + 1) * tr]
5820 .copy_from_slice(&tab[row * tr..(row + 1) * tr]);
5821 }
5822 }
5823 } else {
5824 let ids = sl(idx, base, ni);
5825 for i in 0..ni {
5826 let row = ids[i] as usize;
5827 if row < rows {
5828 out[i * tr..(i + 1) * tr]
5829 .copy_from_slice(&tab[row * tr..(row + 1) * tr]);
5830 }
5831 }
5832 }
5833 } else {
5834 let tab = sl(table, base, tl);
5835 let out = sl_mut(dst, base, ni * tr);
5836 if idx_i64 != 0 {
5837 let ids = sl_i64(idx, base, ni);
5838 for i in 0..ni {
5839 let row = ids[i].max(0) as usize;
5840 if row < rows {
5841 out[i * tr..(i + 1) * tr]
5842 .copy_from_slice(&tab[row * tr..(row + 1) * tr]);
5843 }
5844 }
5845 } else {
5846 let ids = sl(idx, base, ni);
5847 for i in 0..ni {
5848 let row = ids[i] as usize;
5849 if row < rows {
5850 out[i * tr..(i + 1) * tr]
5851 .copy_from_slice(&tab[row * tr..(row + 1) * tr]);
5852 }
5853 }
5854 }
5855 }
5856 })
5857 }
5858
5859 Thunk::Narrow {
5860 src,
5861 dst,
5862 outer,
5863 src_stride,
5864 dst_stride,
5865 inner,
5866 elem_bytes,
5867 } => {
5868 narrow_thunk_closure(src, dst, outer, src_stride, dst_stride, inner, elem_bytes)
5869 }
5870
5871 Thunk::Copy { src, dst, len } => {
5872 let len = len as usize;
5873 Arc::new(move |base: *mut u8| unsafe {
5874 if src == dst || len == 0 {
5875 return;
5876 }
5877 let src_ptr = base.add(src) as *const f32;
5878 let dst_ptr = base.add(dst) as *mut f32;
5879 if src_ptr == dst_ptr {
5880 return;
5881 }
5882 std::ptr::copy_nonoverlapping(src_ptr, dst_ptr, len);
5883 })
5884 }
5885
5886 Thunk::Softmax { data, rows, cols } => {
5887 let (rows, cols) = (rows as usize, cols as usize);
5888 Arc::new(move |base: *mut u8| unsafe {
5889 crate::naive::softmax(sl_mut(data, base, rows * cols), rows, cols);
5890 })
5891 }
5892
5893 Thunk::Cumsum {
5894 src,
5895 dst,
5896 rows,
5897 cols,
5898 exclusive,
5899 } => {
5900 let (rows, cols) = (rows as usize, cols as usize);
5901 Arc::new(move |base: *mut u8| unsafe {
5902 let s = sl(src, base, rows * cols);
5903 let d = sl_mut(dst, base, rows * cols);
5904 if exclusive {
5905 for r in 0..rows {
5906 let mut acc = 0.0f32;
5907 for c in 0..cols {
5908 d[r * cols + c] = acc;
5909 acc += s[r * cols + c];
5910 }
5911 }
5912 } else {
5913 for r in 0..rows {
5914 let mut acc = 0.0f32;
5915 for c in 0..cols {
5916 acc += s[r * cols + c];
5917 d[r * cols + c] = acc;
5918 }
5919 }
5920 }
5921 })
5922 }
5923
5924 Thunk::Sample {
5925 logits,
5926 dst,
5927 batch,
5928 vocab,
5929 top_k,
5930 top_p,
5931 temperature,
5932 seed,
5933 } => {
5934 let (b, v) = (batch as usize, vocab as usize);
5935 let k = (top_k as usize).min(v);
5936 Arc::new(move |base: *mut u8| unsafe {
5937 let lg = sl(logits, base, b * v);
5938 let out = sl_mut(dst, base, b);
5939 let mut rng =
5940 rlx_ir::Philox4x32::new(if seed == 0 { 0xDEADBEEF } else { seed });
5941 for bi in 0..b {
5942 let row = &lg[bi * v..(bi + 1) * v];
5943 out[bi] = sample_row(row, k, top_p, temperature, &mut rng) as f32;
5944 }
5945 })
5946 }
5947
5948 Thunk::DequantMatMul {
5949 x,
5950 w_q,
5951 scale,
5952 zp,
5953 dst,
5954 m,
5955 k,
5956 n,
5957 block_size,
5958 is_asymmetric,
5959 } => {
5960 let (m, k, n, bs) = (m as usize, k as usize, n as usize, block_size as usize);
5961 let n_blocks_per_col = k.div_ceil(bs);
5962 Arc::new(move |base: *mut u8| unsafe {
5963 let xs = sl(x, base, m * k);
5964 let raw = base.add(w_q);
5966 let w_bytes = std::slice::from_raw_parts(raw as *const i8, k * n);
5967 let scales = sl(scale, base, n_blocks_per_col * n);
5968 let zps = if is_asymmetric {
5969 sl(zp, base, n_blocks_per_col * n)
5970 } else {
5971 &[][..]
5972 };
5973 let out = sl_mut(dst, base, m * n);
5974 dequant_matmul_int8(
5975 xs,
5976 w_bytes,
5977 scales,
5978 zps,
5979 out,
5980 m,
5981 k,
5982 n,
5983 bs,
5984 is_asymmetric,
5985 );
5986 })
5987 }
5988
5989 Thunk::DequantMatMulGguf {
5990 x,
5991 w_q,
5992 dst,
5993 m,
5994 k,
5995 n,
5996 scheme,
5997 } => {
5998 let (m, k, n) = (m as usize, k as usize, n as usize);
5999 let block_bytes = scheme.gguf_block_bytes() as usize;
6000 let block_elems = scheme.gguf_block_size() as usize;
6001 let total_bytes = (k * n) / block_elems * block_bytes;
6002 Arc::new(move |base: *mut u8| unsafe {
6003 let xs = sl(x, base, m * k);
6004 let w_bytes =
6005 std::slice::from_raw_parts(base.add(w_q) as *const u8, total_bytes);
6006 let out = sl_mut(dst, base, m * n);
6007 crate::gguf_matmul::gguf_matmul_bt(xs, w_bytes, out, m, k, n, scheme);
6008 })
6009 }
6010
6011 Thunk::DequantMatMulInt4 {
6012 x,
6013 w_q,
6014 scale,
6015 zp,
6016 dst,
6017 m,
6018 k,
6019 n,
6020 block_size,
6021 is_asymmetric,
6022 } => {
6023 let (m, k, n, bs) = (m as usize, k as usize, n as usize, block_size as usize);
6024 let n_blocks = k.div_ceil(bs);
6025 Arc::new(move |base: *mut u8| unsafe {
6026 let xs = sl(x, base, m * k);
6027 let w_bytes = std::slice::from_raw_parts(
6028 base.add(w_q) as *const u8,
6029 (k * n).div_ceil(2),
6030 );
6031 let scales = sl(scale, base, n_blocks * n);
6032 let zps = if is_asymmetric {
6033 sl(zp, base, n_blocks * n)
6034 } else {
6035 &[][..]
6036 };
6037 let out = sl_mut(dst, base, m * n);
6038 dequant_matmul_int4(
6039 xs,
6040 w_bytes,
6041 scales,
6042 zps,
6043 out,
6044 m,
6045 k,
6046 n,
6047 bs,
6048 is_asymmetric,
6049 );
6050 })
6051 }
6052
6053 Thunk::DequantMatMulFp8 {
6054 x,
6055 w_q,
6056 scale,
6057 dst,
6058 m,
6059 k,
6060 n,
6061 e5m2,
6062 } => {
6063 let (m, k, n) = (m as usize, k as usize, n as usize);
6064 Arc::new(move |base: *mut u8| unsafe {
6065 let xs = sl(x, base, m * k);
6066 let w_bytes = std::slice::from_raw_parts(base.add(w_q) as *const u8, k * n);
6067 let scales = sl(scale, base, n);
6068 let out = sl_mut(dst, base, m * n);
6069 dequant_matmul_fp8(xs, w_bytes, scales, out, m, k, n, e5m2);
6070 })
6071 }
6072
6073 Thunk::DequantMatMulNvfp4 {
6074 x,
6075 w_q,
6076 scale,
6077 global_scale,
6078 dst,
6079 m,
6080 k,
6081 n,
6082 } => {
6083 let (m, k, n) = (m as usize, k as usize, n as usize);
6084 let n_scale = k.div_ceil(rlx_ir::NVFP4_GROUP_SIZE) * n;
6085 Arc::new(move |base: *mut u8| unsafe {
6086 let xs = sl(x, base, m * k);
6087 let w_bytes = std::slice::from_raw_parts(
6088 base.add(w_q) as *const u8,
6089 (k * n).div_ceil(2),
6090 );
6091 let scale_bytes =
6092 std::slice::from_raw_parts(base.add(scale) as *const u8, n_scale);
6093 let gs = sl(global_scale, base, 1)[0];
6094 let out = sl_mut(dst, base, m * n);
6095 dequant_matmul_nvfp4(xs, w_bytes, scale_bytes, gs, out, m, k, n);
6096 })
6097 }
6098
6099 Thunk::LoraMatMul {
6100 x,
6101 w,
6102 a,
6103 b,
6104 dst,
6105 m,
6106 k,
6107 n,
6108 r,
6109 scale,
6110 } => {
6111 let (m, k, n, r) = (m as usize, k as usize, n as usize, r as usize);
6112 Arc::new(move |base: *mut u8| unsafe {
6113 let xs = sl(x, base, m * k);
6114 let ws = sl(w, base, k * n);
6115 let a_s = sl(a, base, k * r);
6116 let bs = sl(b, base, r * n);
6117 let out = sl_mut(dst, base, m * n);
6118 crate::blas::sgemm(xs, ws, out, m, k, n);
6120 let mut tmp = vec![0f32; m * r];
6122 crate::blas::sgemm(xs, a_s, &mut tmp, m, k, r);
6123 if scale != 1.0 {
6127 for v in tmp.iter_mut() {
6128 *v *= scale;
6129 }
6130 }
6131 crate::blas::sgemm_accumulate(&tmp, bs, out, m, r, n);
6132 })
6133 }
6134
6135 Thunk::LayerNorm {
6136 src,
6137 g,
6138 b,
6139 dst,
6140 rows,
6141 h,
6142 eps,
6143 } => {
6144 let (rows, h) = (rows as usize, h as usize);
6145 Arc::new(move |base: *mut u8| unsafe {
6146 let inp = sl(src, base, rows * h);
6147 let gamma = sl(g, base, h);
6148 let beta = sl(b, base, h);
6149 let out = sl_mut(dst, base, rows * h);
6150 for row in 0..rows {
6151 crate::kernels::layer_norm_row(
6152 &inp[row * h..(row + 1) * h],
6153 gamma,
6154 beta,
6155 &mut out[row * h..(row + 1) * h],
6156 h,
6157 eps,
6158 );
6159 }
6160 })
6161 }
6162
6163 Thunk::BatchNormInference {
6164 src,
6165 g,
6166 b,
6167 mean,
6168 var,
6169 dst,
6170 count,
6171 channels,
6172 eps,
6173 } => {
6174 let count = count as usize;
6175 let c = channels as usize;
6176 let n = count * c;
6177 let (src, g, b, mean, var, dst) = (src, g, b, mean, var, dst);
6178 Arc::new(move |base: *mut u8| unsafe {
6179 crate::kernels::batch_norm_inference(
6180 sl(src, base, n),
6181 sl(g, base, c),
6182 sl(b, base, c),
6183 sl(mean, base, c),
6184 sl(var, base, c),
6185 sl_mut(dst, base, n),
6186 c,
6187 eps,
6188 );
6189 })
6190 }
6191
6192 Thunk::Attention {
6193 q,
6194 k,
6195 v,
6196 mask,
6197 out,
6198 batch,
6199 seq,
6200 kv_seq,
6201 heads,
6202 head_dim,
6203 mask_kind,
6204 q_row_stride,
6205 k_row_stride,
6206 v_row_stride,
6207 bhsd,
6208 } => {
6209 if std::env::var("RLX_ATTN_DEBUG").is_ok() {
6210 eprintln!("[attn-compile] batch={batch} seq={seq} kv_seq={kv_seq} heads={heads} bhsd={bhsd}");
6211 }
6212 let (b, q_s, k_s, nh, dh) = (
6221 batch as usize,
6222 seq as usize,
6223 kv_seq as usize,
6224 heads as usize,
6225 head_dim as usize,
6226 );
6227 let hs = nh * dh;
6228 let qrs = q_row_stride as usize;
6229 let krs = k_row_stride as usize;
6230 let vrs = v_row_stride as usize;
6231 let scale = (dh as f32).powf(-0.5);
6232 Arc::new(move |base: *mut u8| unsafe {
6233 if std::env::var("RLX_ATTN_DEBUG").is_ok() {
6234 eprintln!("[attn] b={b} q_s={q_s} k_s={k_s} nh={nh} dh={dh} bhsd={bhsd} mask_kind={:?}", mask_kind);
6235 }
6236 let (q_len, k_len, v_len, o_len) = if bhsd {
6241 let qn = b * nh * q_s * dh;
6242 let kn = b * nh * k_s * dh;
6243 (qn, kn, kn, qn)
6244 } else {
6245 (b * q_s * qrs, b * k_s * krs, b * k_s * vrs, b * q_s * hs)
6246 };
6247 let q_d = sl(q, base, q_len);
6248 let k_d = sl(k, base, k_len);
6249 let v_d = sl(v, base, v_len);
6250 let m_d: &[f32] = match mask_kind {
6251 rlx_ir::op::MaskKind::Custom => sl(mask, base, b * k_s),
6252 rlx_ir::op::MaskKind::Bias => sl(mask, base, b * nh * q_s * k_s),
6253 _ => &[],
6254 };
6255 let o_d = sl_mut(out, base, o_len);
6256 let mut qh = vec![0f32; q_s * dh];
6257 let mut kh = vec![0f32; k_s * dh];
6258 let mut vh = vec![0f32; k_s * dh];
6259 let mut sc = vec![0f32; q_s * k_s];
6260 let mut oh = vec![0f32; q_s * dh];
6261 for bi in 0..b {
6262 for hi in 0..nh {
6263 for si in 0..q_s {
6265 let q_off = if bhsd {
6266 bi * nh * q_s * dh + hi * q_s * dh + si * dh
6267 } else {
6268 bi * q_s * qrs + si * qrs + hi * dh
6269 };
6270 qh[si * dh..(si + 1) * dh]
6271 .copy_from_slice(&q_d[q_off..q_off + dh]);
6272 }
6273 for si in 0..k_s {
6275 let (k_off, v_off) = if bhsd {
6276 (
6277 bi * nh * k_s * dh + hi * k_s * dh + si * dh,
6278 bi * nh * k_s * dh + hi * k_s * dh + si * dh,
6279 )
6280 } else {
6281 (
6282 bi * k_s * krs + si * krs + hi * dh,
6283 bi * k_s * vrs + si * vrs + hi * dh,
6284 )
6285 };
6286 kh[si * dh..(si + 1) * dh]
6287 .copy_from_slice(&k_d[k_off..k_off + dh]);
6288 vh[si * dh..(si + 1) * dh]
6289 .copy_from_slice(&v_d[v_off..v_off + dh]);
6290 }
6291 for qi in 0..q_s {
6292 for ki in 0..k_s {
6293 let mut dot = 0f32;
6294 for d in 0..dh {
6295 dot += qh[qi * dh + d] * kh[ki * dh + d];
6296 }
6297 sc[qi * k_s + ki] = dot * scale;
6298 }
6299 }
6300 let q_offset = k_s.saturating_sub(q_s);
6304 match mask_kind {
6305 rlx_ir::op::MaskKind::None => {}
6306 rlx_ir::op::MaskKind::Causal => {
6307 for qi in 0..q_s {
6308 let abs_q = q_offset + qi;
6309 for ki in (abs_q + 1)..k_s {
6310 sc[qi * k_s + ki] = mask_neg;
6311 }
6312 }
6313 }
6314 rlx_ir::op::MaskKind::SlidingWindow(w) => {
6315 for qi in 0..q_s {
6316 let abs_q = q_offset + qi;
6317 let lo = abs_q.saturating_sub(w);
6318 for ki in 0..k_s {
6319 if ki < lo || ki > abs_q {
6320 sc[qi * k_s + ki] = mask_neg;
6321 }
6322 }
6323 }
6324 }
6325 rlx_ir::op::MaskKind::Custom => {
6326 for qi in 0..q_s {
6327 for ki in 0..k_s {
6328 if m_d[bi * k_s + ki] < mask_thr {
6329 sc[qi * k_s + ki] = mask_neg;
6330 }
6331 }
6332 }
6333 }
6334 rlx_ir::op::MaskKind::Bias => {
6335 let per_bh = q_s * k_s;
6336 let off = (bi * nh + hi) * per_bh;
6337 for i in 0..per_bh {
6338 sc[i] += m_d[off + i];
6339 }
6340 }
6341 }
6342 crate::naive::softmax(&mut sc, q_s, k_s);
6343 oh.fill(0.0);
6344 for qi in 0..q_s {
6345 for ki in 0..k_s {
6346 let w = sc[qi * k_s + ki];
6347 if w > score_skip {
6348 for d in 0..dh {
6349 oh[qi * dh + d] += w * vh[ki * dh + d];
6350 }
6351 }
6352 }
6353 }
6354 for si in 0..q_s {
6355 let off = if bhsd {
6356 bi * nh * q_s * dh + hi * q_s * dh + si * dh
6357 } else {
6358 bi * q_s * hs + si * hs + hi * dh
6359 };
6360 o_d[off..off + dh].copy_from_slice(&oh[si * dh..(si + 1) * dh]);
6361 }
6362 }
6363 }
6364 })
6365 }
6366
6367 Thunk::FusedSwiGLU {
6368 src,
6369 dst,
6370 n_half,
6371 total,
6372 gate_first,
6373 } => {
6374 let n = n_half as usize;
6375 let t = total as usize;
6376 let outer = t / n;
6377 let in_total = outer * 2 * n;
6378 Arc::new(move |base: *mut u8| unsafe {
6379 let inp = sl(src, base, in_total);
6380 let out = sl_mut(dst, base, t);
6381 for o in 0..outer {
6382 let in_row = &inp[o * 2 * n..(o + 1) * 2 * n];
6383 let out_row = &mut out[o * n..(o + 1) * n];
6384 for i in 0..n {
6385 let (up, gate) = if gate_first {
6386 (in_row[n + i], in_row[i])
6387 } else {
6388 (in_row[i], in_row[n + i])
6389 };
6390 out_row[i] = up * (gate / (1.0 + (-gate).exp()));
6391 }
6392 }
6393 })
6394 }
6395
6396 Thunk::Concat {
6397 dst,
6398 outer,
6399 inner,
6400 total_axis,
6401 inputs,
6402 } => {
6403 let outer = outer as usize;
6404 let inner = inner as usize;
6405 let total_axis = total_axis as usize;
6406 let out_total = outer * total_axis * inner;
6407 let mut layout: Vec<(usize, usize, usize)> = Vec::with_capacity(inputs.len());
6410 let mut cum: usize = 0;
6411 for (src_off, in_axis) in &inputs {
6412 let in_axis = *in_axis as usize;
6413 layout.push((*src_off, cum * inner, in_axis * inner));
6414 cum += in_axis;
6415 }
6416 Arc::new(move |base: *mut u8| unsafe {
6417 let out = sl_mut(dst, base, out_total);
6418 let row_stride = total_axis * inner;
6419 for (src_off, dst_col_off, copy_per_row) in &layout {
6420 let in_total = outer * *copy_per_row;
6421 let inp = sl(*src_off, base, in_total);
6422 for o in 0..outer {
6423 let dst_row_start = o * row_stride + *dst_col_off;
6424 let src_row_start = o * *copy_per_row;
6425 out[dst_row_start..dst_row_start + *copy_per_row].copy_from_slice(
6426 &inp[src_row_start..src_row_start + *copy_per_row],
6427 );
6428 }
6429 }
6430 })
6431 }
6432
6433 Thunk::CustomOp {
6434 kernel,
6435 inputs,
6436 output,
6437 attrs,
6438 } => {
6439 let kernel = kernel.clone();
6445 let attrs = attrs.clone();
6446 let inputs = inputs.clone();
6447 let (out_off, out_len, out_shape) = output.clone();
6448 Arc::new(move |base: *mut u8| unsafe {
6449 dispatch_custom_op(
6450 &*kernel, &inputs, out_off, out_len, &out_shape, &attrs, base,
6451 );
6452 })
6453 }
6454
6455 Thunk::GaussianSplatRender {
6456 positions_off,
6457 positions_len,
6458 scales_off,
6459 scales_len,
6460 rotations_off,
6461 rotations_len,
6462 opacities_off,
6463 opacities_len,
6464 colors_off,
6465 colors_len,
6466 sh_coeffs_off,
6467 sh_coeffs_len,
6468 meta_off,
6469 dst_off,
6470 dst_len,
6471 width,
6472 height,
6473 tile_size,
6474 radius_scale,
6475 alpha_cutoff,
6476 max_splat_steps,
6477 transmittance_threshold,
6478 max_list_entries,
6479 } => Arc::new(move |base: *mut u8| unsafe {
6480 crate::splat::execute_gaussian_splat_render(
6481 positions_off,
6482 positions_len,
6483 scales_off,
6484 scales_len,
6485 rotations_off,
6486 rotations_len,
6487 opacities_off,
6488 opacities_len,
6489 colors_off,
6490 colors_len,
6491 sh_coeffs_off,
6492 sh_coeffs_len,
6493 meta_off,
6494 dst_off,
6495 dst_len,
6496 width,
6497 height,
6498 tile_size,
6499 radius_scale,
6500 alpha_cutoff,
6501 max_splat_steps,
6502 transmittance_threshold,
6503 max_list_entries,
6504 base,
6505 );
6506 }),
6507
6508 Thunk::GaussianSplatRenderBackward {
6509 positions_off,
6510 positions_len,
6511 scales_off,
6512 scales_len,
6513 rotations_off,
6514 rotations_len,
6515 opacities_off,
6516 opacities_len,
6517 colors_off,
6518 colors_len,
6519 sh_coeffs_off,
6520 sh_coeffs_len,
6521 meta_off,
6522 d_loss_off,
6523 d_loss_len,
6524 packed_off,
6525 packed_len,
6526 width,
6527 height,
6528 tile_size,
6529 radius_scale,
6530 alpha_cutoff,
6531 max_splat_steps,
6532 transmittance_threshold,
6533 max_list_entries,
6534 loss_grad_clip,
6535 sh_band,
6536 max_anisotropy,
6537 } => Arc::new(move |base: *mut u8| unsafe {
6538 crate::splat::execute_gaussian_splat_render_backward(
6539 positions_off,
6540 positions_len,
6541 scales_off,
6542 scales_len,
6543 rotations_off,
6544 rotations_len,
6545 opacities_off,
6546 opacities_len,
6547 colors_off,
6548 colors_len,
6549 sh_coeffs_off,
6550 sh_coeffs_len,
6551 meta_off,
6552 d_loss_off,
6553 d_loss_len,
6554 packed_off,
6555 packed_len,
6556 width,
6557 height,
6558 tile_size,
6559 radius_scale,
6560 alpha_cutoff,
6561 max_splat_steps,
6562 transmittance_threshold,
6563 max_list_entries,
6564 loss_grad_clip,
6565 sh_band,
6566 max_anisotropy,
6567 base,
6568 );
6569 }),
6570
6571 Thunk::GaussianSplatPrepare {
6572 positions_off,
6573 positions_len,
6574 scales_off,
6575 scales_len,
6576 rotations_off,
6577 rotations_len,
6578 opacities_off,
6579 opacities_len,
6580 colors_off,
6581 colors_len,
6582 sh_coeffs_off,
6583 sh_coeffs_len,
6584 meta_off,
6585 meta_len,
6586 prep_off,
6587 prep_len,
6588 width,
6589 height,
6590 tile_size,
6591 radius_scale,
6592 alpha_cutoff,
6593 max_splat_steps,
6594 transmittance_threshold,
6595 max_list_entries,
6596 } => Arc::new(move |base: *mut u8| unsafe {
6597 crate::splat::execute_gaussian_splat_prepare(
6598 positions_off,
6599 positions_len,
6600 scales_off,
6601 scales_len,
6602 rotations_off,
6603 rotations_len,
6604 opacities_off,
6605 opacities_len,
6606 colors_off,
6607 colors_len,
6608 sh_coeffs_off,
6609 sh_coeffs_len,
6610 meta_off,
6611 meta_len,
6612 prep_off,
6613 prep_len,
6614 width,
6615 height,
6616 tile_size,
6617 radius_scale,
6618 alpha_cutoff,
6619 max_splat_steps,
6620 transmittance_threshold,
6621 max_list_entries,
6622 base,
6623 );
6624 }),
6625
6626 Thunk::GaussianSplatRasterize {
6627 prep_off,
6628 prep_len,
6629 meta_off,
6630 meta_len,
6631 dst_off,
6632 dst_len,
6633 count,
6634 width,
6635 height,
6636 tile_size,
6637 alpha_cutoff,
6638 max_splat_steps,
6639 transmittance_threshold,
6640 max_list_entries,
6641 } => Arc::new(move |base: *mut u8| unsafe {
6642 crate::splat::execute_gaussian_splat_rasterize(
6643 prep_off,
6644 prep_len,
6645 meta_off,
6646 meta_len,
6647 dst_off,
6648 dst_len,
6649 count,
6650 width,
6651 height,
6652 tile_size,
6653 alpha_cutoff,
6654 max_splat_steps,
6655 transmittance_threshold,
6656 max_list_entries,
6657 base,
6658 );
6659 }),
6660
6661 Thunk::Fft1d {
6662 src,
6663 dst,
6664 outer,
6665 n_complex,
6666 inverse,
6667 norm_tag,
6668 dtype,
6669 } => {
6670 let f: Arc<dyn Fn(*mut u8) + Send + Sync> = match dtype {
6671 rlx_ir::DType::F64 => Arc::new(move |base: *mut u8| unsafe {
6672 execute_fft1d_f64(
6673 src,
6674 dst,
6675 outer as usize,
6676 n_complex as usize,
6677 inverse,
6678 norm_tag,
6679 base,
6680 );
6681 }),
6682 rlx_ir::DType::F32 => Arc::new(move |base: *mut u8| unsafe {
6683 execute_fft1d_f32(
6684 src,
6685 dst,
6686 outer as usize,
6687 n_complex as usize,
6688 inverse,
6689 norm_tag,
6690 base,
6691 );
6692 }),
6693 rlx_ir::DType::C64 => Arc::new(move |base: *mut u8| unsafe {
6694 execute_fft1d_c64(
6695 src,
6696 dst,
6697 outer as usize,
6698 n_complex as usize,
6699 inverse,
6700 norm_tag,
6701 base,
6702 );
6703 }),
6704 other => panic!("Op::Fft on CPU requires F32/F64/C64, got {other:?}"),
6705 };
6706 f
6707 }
6708
6709 Thunk::FftButterflyStage {
6710 state_src,
6711 state_dst,
6712 gate_src,
6713 rev_src,
6714 tw_re_src,
6715 tw_im_src,
6716 batch,
6717 n_fft,
6718 stage,
6719 } => Arc::new(move |base: *mut u8| unsafe {
6720 execute_fft_butterfly_stage_f32(
6721 state_src,
6722 state_dst,
6723 gate_src,
6724 rev_src,
6725 tw_re_src,
6726 tw_im_src,
6727 batch as usize,
6728 n_fft as usize,
6729 stage as usize,
6730 base,
6731 );
6732 }),
6733
6734 Thunk::LogMel {
6735 spec,
6736 filters,
6737 dst,
6738 outer,
6739 n_fft,
6740 n_bins,
6741 n_mels,
6742 } => Arc::new(move |base: *mut u8| unsafe {
6743 execute_log_mel_f32(
6744 spec,
6745 filters,
6746 dst,
6747 outer as usize,
6748 n_fft as usize,
6749 n_bins as usize,
6750 n_mels as usize,
6751 base,
6752 );
6753 }),
6754
6755 Thunk::LogMelBackward {
6756 spec,
6757 filters,
6758 dy,
6759 dst,
6760 outer,
6761 n_fft,
6762 n_bins,
6763 n_mels,
6764 } => Arc::new(move |base: *mut u8| unsafe {
6765 execute_log_mel_backward_f32(
6766 spec,
6767 filters,
6768 dy,
6769 dst,
6770 outer as usize,
6771 n_fft as usize,
6772 n_bins as usize,
6773 n_mels as usize,
6774 base,
6775 );
6776 }),
6777
6778 _ => Arc::new(|_: *mut u8| {}),
6779 }
6780 })
6781 .collect();
6782
6783 let fuse_threshold: usize = rlx_ir::env::var("RLX_FUSE_ATTN_THRESHOLD")
6787 .and_then(|v| v.parse().ok())
6788 .unwrap_or(64);
6789 let should_fuse = thunks.iter().any(|t| match t {
6790 Thunk::Attention { batch, seq, .. } => {
6791 (*batch as usize) * (*seq as usize) <= fuse_threshold
6792 }
6793 _ => false,
6794 });
6795
6796 if should_fuse {
6797 let active: Vec<usize> = thunks
6799 .iter()
6800 .enumerate()
6801 .filter(|(_, t)| !matches!(t, Thunk::Nop))
6802 .map(|(i, _)| i)
6803 .collect();
6804
6805 let mut kill = vec![false; thunks.len()]; let mut insertions: Vec<(usize, Thunk)> = Vec::new(); let mut ai = 0;
6809 while ai < active.len() {
6810 let a = |off: usize| -> Option<(usize, &Thunk)> {
6812 active.get(ai + off).map(|&idx| (idx, &thunks[idx]))
6813 };
6814
6815 let matched = (|| {
6817 let (_i0, t0) = a(0)?;
6818 let (_, t1) = a(1)?;
6819 let (_, t2) = a(2)?;
6820 let (_, t3) = a(3)?;
6821
6822 let (hidden, qkv_w, qkv_b, has_b) = match t0 {
6824 Thunk::FusedMmBiasAct {
6825 a,
6826 w,
6827 bias,
6828 n: _,
6829 act: None,
6830 ..
6831 } => (*a, *w, *bias, true),
6832 Thunk::Sgemm { a, b, n: _, .. } => (*a, *b, 0, false),
6833 _ => return None,
6834 };
6835
6836 if !matches!(t1, Thunk::Narrow { .. }) {
6838 return None;
6839 }
6840 if !matches!(t2, Thunk::Narrow { .. }) {
6841 return None;
6842 }
6843 if !matches!(t3, Thunk::Narrow { .. }) {
6844 return None;
6845 }
6846
6847 let (has_rope, attn_ai, cos_off, sin_off, cl) = if let Some((
6849 _,
6850 Thunk::Rope {
6851 cos, sin, cos_len, ..
6852 },
6853 )) = a(4)
6854 {
6855 if matches!(a(5).map(|x| x.1), Some(Thunk::Rope { .. })) {
6856 if matches!(a(6).map(|x| x.1), Some(Thunk::Attention { .. })) {
6857 (true, 6, *cos, *sin, *cos_len)
6858 } else {
6859 return None;
6860 }
6861 } else {
6862 return None;
6863 }
6864 } else if matches!(a(4).map(|x| x.1), Some(Thunk::Attention { .. })) {
6865 (false, 4, 0, 0, 0)
6866 } else {
6867 return None;
6868 };
6869
6870 let (_attn_real_idx, attn_t) = a(attn_ai)?;
6871 let (batch, seq, heads, head_dim, mask) = match attn_t {
6872 Thunk::Attention {
6873 batch,
6874 seq,
6875 heads,
6876 head_dim,
6877 mask,
6878 ..
6879 } => (*batch, *seq, *heads, *head_dim, *mask),
6880 _ => return None,
6881 };
6882
6883 let (_out_real_idx, out_t) = a(attn_ai + 1)?;
6885 let (out_w, out_b, out_dst) = match out_t {
6886 Thunk::FusedMmBiasAct {
6887 w,
6888 bias,
6889 c,
6890 act: None,
6891 ..
6892 } => (*w, *bias, *c),
6893 Thunk::Sgemm { b: w, c, .. } => (*w, 0, *c),
6894 _ => return None,
6895 };
6896
6897 let hs = heads * head_dim;
6898 let total_active = attn_ai + 2; Some((
6901 total_active,
6902 Thunk::FusedAttnBlock {
6903 hidden,
6904 qkv_w,
6905 out_w,
6906 mask,
6907 out: out_dst,
6908 qkv_b: if has_b { qkv_b } else { 0 },
6909 out_b: if has_b { out_b } else { 0 },
6910 cos: cos_off,
6911 sin: sin_off,
6912 cos_len: cl,
6913 batch,
6914 seq,
6915 hs,
6916 nh: heads,
6917 dh: head_dim,
6918 has_bias: has_b,
6919 has_rope,
6920 },
6921 ))
6922 })();
6923
6924 if let Some((count, fused_thunk)) = matched {
6925 for off in 0..count {
6927 if let Some(&idx) = active.get(ai + off) {
6928 kill[idx] = true;
6929 }
6930 }
6931 insertions.push((active[ai], fused_thunk));
6933 ai += count;
6934 } else {
6935 ai += 1;
6936 }
6937 }
6938
6939 if !insertions.is_empty() {
6941 let mut new_thunks = Vec::with_capacity(thunks.len());
6942 let mut insert_idx = 0;
6943 for (i, t) in thunks.into_iter().enumerate() {
6944 if insert_idx < insertions.len() && insertions[insert_idx].0 == i {
6945 new_thunks.push(insertions[insert_idx].1.clone());
6946 insert_idx += 1;
6947 }
6948 if !kill[i] {
6949 new_thunks.push(t);
6950 }
6951 }
6952 if cfg.verbose >= 1 {
6953 eprintln!(
6954 "[rlx] fused_attention: {} attention blocks fused",
6955 insertions.len()
6956 );
6957 }
6958 thunks = new_thunks;
6959 }
6960 }
6961
6962 if should_fuse {
6967 let active: Vec<usize> = thunks
6968 .iter()
6969 .enumerate()
6970 .filter(|(_, t)| !matches!(t, Thunk::Nop))
6971 .map(|(i, _)| i)
6972 .collect();
6973
6974 let mut kill = vec![false; thunks.len()];
6975 let mut insertions: Vec<(usize, Thunk)> = Vec::new();
6976
6977 let a = |ai: usize| -> Option<&Thunk> { active.get(ai).map(|&i| &thunks[i]) };
6978
6979 let mut ai = 0;
6980 while ai < active.len() {
6981 let bert_match = (|| -> Option<usize> {
6983 let fab = a(ai)?;
6984 let rln1 = a(ai + 1)?;
6985 let ffn1 = a(ai + 2)?;
6986 let ffn2 = a(ai + 3)?;
6987 let rln2 = a(ai + 4)?;
6988
6989 let (hidden, qkv_w, qkv_b, out_w, out_b, mask, batch, seq, hs, nh, dh) = match fab {
6990 Thunk::FusedAttnBlock {
6991 hidden,
6992 qkv_w,
6993 qkv_b,
6994 out_w,
6995 out_b,
6996 mask,
6997 batch,
6998 seq,
6999 hs,
7000 nh,
7001 dh,
7002 has_bias: true,
7003 has_rope: false,
7004 ..
7005 } => (
7006 *hidden, *qkv_w, *qkv_b, *out_w, *out_b, *mask, *batch, *seq, *hs, *nh, *dh,
7007 ),
7008 _ => return None,
7009 };
7010 let (ln1_g, ln1_b, eps1) = match rln1 {
7011 Thunk::FusedResidualLN { g, b, eps, .. } => (*g, *b, *eps),
7012 _ => return None,
7013 };
7014 let (fc1_w, fc1_b, int_dim) = match ffn1 {
7015 Thunk::FusedMmBiasAct {
7016 w,
7017 bias,
7018 n,
7019 act: Some(Activation::Gelu),
7020 ..
7021 } => (*w, *bias, *n),
7022 _ => return None,
7023 };
7024 let (fc2_w, fc2_b) = match ffn2 {
7025 Thunk::FusedMmBiasAct {
7026 w, bias, act: None, ..
7027 } => (*w, *bias),
7028 _ => return None,
7029 };
7030 let (ln2_g, ln2_b, eps2, out) = match rln2 {
7031 Thunk::FusedResidualLN { g, b, eps, out, .. } => (*g, *b, *eps, *out),
7032 _ => return None,
7033 };
7034
7035 for off in 0..5 {
7036 kill[active[ai + off]] = true;
7037 }
7038 insertions.push((
7039 active[ai],
7040 Thunk::FusedBertLayer {
7041 hidden,
7042 qkv_w,
7043 qkv_b,
7044 out_w,
7045 out_b,
7046 mask,
7047 ln1_g,
7048 ln1_b,
7049 eps1,
7050 fc1_w,
7051 fc1_b,
7052 fc2_w,
7053 fc2_b,
7054 ln2_g,
7055 ln2_b,
7056 eps2,
7057 out,
7058 batch,
7059 seq,
7060 hs,
7061 nh,
7062 dh,
7063 int_dim,
7064 },
7065 ));
7066 Some(5)
7067 })();
7068 if let Some(n) = bert_match {
7069 ai += n;
7070 continue;
7071 }
7072
7073 #[allow(unreachable_code)]
7077 let nomic_match = (|| -> Option<usize> {
7078 return None; let fab = a(ai)?;
7080 let (hidden, qkv_w, out_w, mask, cos, sin, cos_len, batch, seq, hs, nh, dh) =
7081 match fab {
7082 Thunk::FusedAttnBlock {
7083 hidden,
7084 qkv_w,
7085 out_w,
7086 mask,
7087 cos,
7088 sin,
7089 cos_len,
7090 batch,
7091 seq,
7092 hs,
7093 nh,
7094 dh,
7095 has_bias: false,
7096 has_rope: true,
7097 ..
7098 } => (
7099 *hidden, *qkv_w, *out_w, *mask, *cos, *sin, *cos_len, *batch, *seq,
7100 *hs, *nh, *dh,
7101 ),
7102 _ => return None,
7103 };
7104 let (ln1_g, ln1_b, eps1) = match a(ai + 1)? {
7106 Thunk::FusedResidualLN { g, b, eps, .. } => (*g, *b, *eps),
7107 _ => return None,
7108 };
7109 let fused_fc_w = match a(ai + 2)? {
7111 Thunk::Sgemm { b: w, .. } => *w,
7112 _ => return None,
7113 };
7114 if !matches!(a(ai + 3)?, Thunk::Narrow { .. }) {
7116 return None;
7117 }
7118 if !matches!(a(ai + 4)?, Thunk::Narrow { .. }) {
7119 return None;
7120 }
7121 if !matches!(
7123 a(ai + 5)?,
7124 Thunk::ActivationInPlace {
7125 act: Activation::Silu,
7126 ..
7127 }
7128 ) {
7129 return None;
7130 }
7131 if !matches!(
7133 a(ai + 6)?,
7134 Thunk::BinaryFull {
7135 op: BinaryOp::Mul,
7136 ..
7137 }
7138 ) {
7139 return None;
7140 }
7141 let fc2_w = match a(ai + 7)? {
7143 Thunk::Sgemm { b: w, .. } => *w,
7144 _ => return None,
7145 };
7146 let int_dim = match a(ai + 3)? {
7148 Thunk::Narrow { inner, .. } => *inner,
7149 _ => return None,
7150 };
7151 let (ln2_g, ln2_b, eps2, out) = match a(ai + 8)? {
7153 Thunk::FusedResidualLN { g, b, eps, out, .. } => (*g, *b, *eps, *out),
7154 _ => return None,
7155 };
7156
7157 for off in 0..9 {
7158 kill[active[ai + off]] = true;
7159 }
7160 insertions.push((
7161 active[ai],
7162 Thunk::FusedNomicLayer {
7163 hidden,
7164 qkv_w,
7165 out_w,
7166 mask,
7167 cos,
7168 sin,
7169 cos_len,
7170 ln1_g,
7171 ln1_b,
7172 eps1,
7173 fc11_w: fused_fc_w,
7174 fc12_w: 0,
7175 fc2_w,
7176 ln2_g,
7177 ln2_b,
7178 eps2,
7179 out,
7180 batch,
7181 seq,
7182 hs,
7183 nh,
7184 dh,
7185 int_dim,
7186 },
7187 ));
7188 Some(9)
7189 })();
7190 if let Some(n) = nomic_match {
7191 ai += n;
7192 continue;
7193 }
7194
7195 ai += 1;
7196 }
7197
7198 if !insertions.is_empty() {
7199 let mut new_thunks = Vec::with_capacity(thunks.len());
7200 let mut ins_idx = 0;
7201 for (i, t) in thunks.into_iter().enumerate() {
7202 if ins_idx < insertions.len() && insertions[ins_idx].0 == i {
7203 new_thunks.push(insertions[ins_idx].1.clone());
7204 ins_idx += 1;
7205 }
7206 if !kill[i] {
7207 new_thunks.push(t);
7208 }
7209 }
7210 if cfg.verbose >= 1 {
7211 eprintln!(
7212 "[rlx] fused_layer: {} full transformer layers fused",
7213 insertions.len()
7214 );
7215 }
7216 thunks = new_thunks;
7217 }
7218 }
7219
7220 {
7232 let mut read_offsets: HashMap<usize, usize> = HashMap::new();
7235 for t in &thunks {
7236 for off in thunk_read_offsets(t) {
7237 *read_offsets.entry(off).or_insert(0) += 1;
7238 }
7239 }
7240
7241 let mut fused_count = 0usize;
7242 for i in 0..thunks.len().saturating_sub(1) {
7243 let narrow = match &thunks[i] {
7246 Thunk::Narrow { .. } => i,
7247 _ => continue,
7248 };
7249 let mut j = narrow + 1;
7251 while j < thunks.len() && matches!(thunks[j], Thunk::Nop) {
7252 j += 1;
7253 }
7254 if j >= thunks.len() {
7255 continue;
7256 }
7257 let (n_src, n_dst, n_src_stride) = match &thunks[narrow] {
7259 Thunk::Narrow {
7260 src,
7261 dst,
7262 src_stride,
7263 ..
7264 } => (*src, *dst, *src_stride),
7265 _ => continue,
7266 };
7267 let rope_reads_narrow = matches!(&thunks[j],
7268 Thunk::Rope { src, .. } if *src == n_dst);
7269 if !rope_reads_narrow {
7270 continue;
7271 }
7272 if read_offsets.get(&n_dst).copied().unwrap_or(0) != 1 {
7276 continue;
7277 }
7278
7279 if let Thunk::Rope {
7282 src,
7283 src_row_stride,
7284 ..
7285 } = &mut thunks[j]
7286 {
7287 *src = n_src;
7288 *src_row_stride = n_src_stride;
7289 }
7290 thunks[narrow] = Thunk::Nop;
7291 fused_count += 1;
7292 }
7293
7294 if fused_count > 0 && cfg.verbose >= 1 {
7295 eprintln!(
7296 "[rlx] fused_qk_rope: {} Narrow→Rope pairs collapsed",
7297 fused_count
7298 );
7299 }
7300 }
7301
7302 {
7314 let mut read_counts: HashMap<usize, usize> = HashMap::new();
7315 for t in &thunks {
7316 for off in thunk_read_offsets(t) {
7317 *read_counts.entry(off).or_insert(0) += 1;
7318 }
7319 }
7320 let mut dst_to_idx: HashMap<usize, usize> = HashMap::new();
7322 for (i, t) in thunks.iter().enumerate() {
7323 if let Thunk::Narrow { dst, .. } = t {
7324 dst_to_idx.insert(*dst, i);
7325 }
7326 }
7327
7328 let mut fused_count = 0usize;
7329 for i in 0..thunks.len() {
7330 let (q_off, k_off, v_off) = match &thunks[i] {
7331 Thunk::Attention { q, k, v, .. } => (*q, *k, *v),
7332 _ => continue,
7333 };
7334 let q_n = match dst_to_idx.get(&q_off).copied() {
7336 Some(x) => x,
7337 None => continue,
7338 };
7339 let k_n = match dst_to_idx.get(&k_off).copied() {
7340 Some(x) => x,
7341 None => continue,
7342 };
7343 let v_n = match dst_to_idx.get(&v_off).copied() {
7344 Some(x) => x,
7345 None => continue,
7346 };
7347 if read_counts.get(&q_off).copied().unwrap_or(0) != 1 {
7349 continue;
7350 }
7351 if read_counts.get(&k_off).copied().unwrap_or(0) != 1 {
7352 continue;
7353 }
7354 if read_counts.get(&v_off).copied().unwrap_or(0) != 1 {
7355 continue;
7356 }
7357
7358 let (q_src, q_stride) = match &thunks[q_n] {
7359 Thunk::Narrow {
7360 src, src_stride, ..
7361 } => (*src, *src_stride),
7362 _ => continue,
7363 };
7364 let (k_src, k_stride) = match &thunks[k_n] {
7365 Thunk::Narrow {
7366 src, src_stride, ..
7367 } => (*src, *src_stride),
7368 _ => continue,
7369 };
7370 let (v_src, v_stride) = match &thunks[v_n] {
7371 Thunk::Narrow {
7372 src, src_stride, ..
7373 } => (*src, *src_stride),
7374 _ => continue,
7375 };
7376
7377 if let Thunk::Attention {
7378 q,
7379 k,
7380 v,
7381 q_row_stride,
7382 k_row_stride,
7383 v_row_stride,
7384 ..
7385 } = &mut thunks[i]
7386 {
7387 *q = q_src;
7388 *k = k_src;
7389 *v = v_src;
7390 *q_row_stride = q_stride;
7391 *k_row_stride = k_stride;
7392 *v_row_stride = v_stride;
7393 }
7394 thunks[q_n] = Thunk::Nop;
7395 thunks[k_n] = Thunk::Nop;
7396 thunks[v_n] = Thunk::Nop;
7397 fused_count += 1;
7398 }
7399
7400 if fused_count > 0 && cfg.verbose >= 1 {
7401 eprintln!(
7402 "[rlx] fused_strided_attn: {} Narrow×3→Attention rewrites",
7403 fused_count
7404 );
7405 }
7406 }
7407
7408 ThunkSchedule {
7409 thunks,
7410 moe_resident: None,
7411 moe_resident_layers: None,
7412 moe_topk_capture: None,
7413 mask_threshold: cfg.mask_binary_threshold,
7414 mask_neg_inf: cfg.attn_mask_neg_inf,
7415 score_skip: cfg.score_skip_threshold,
7416 compiled_fns,
7417 }
7418}
7419
7420fn get_len(graph: &Graph, id: NodeId) -> usize {
7421 graph.node(id).shape.num_elements().unwrap_or(0)
7422}
7423
7424fn get_static_dims(graph: &Graph, id: NodeId) -> Vec<usize> {
7426 let dims = graph.node(id).shape.dims();
7427 let mut out = Vec::with_capacity(dims.len());
7428 for d in dims {
7429 if let Some(s) = match d {
7430 rlx_ir::Dim::Static(s) => Some(*s),
7431 _ => None,
7432 } {
7433 out.push(s);
7434 } else {
7435 return Vec::new();
7436 }
7437 }
7438 out
7439}
7440
7441fn is_trailing_bias_broadcast(rhs_dims: &[rlx_ir::Dim], out_dims: &[rlx_ir::Dim]) -> bool {
7459 if rhs_dims.len() > out_dims.len() {
7460 return false;
7461 }
7462 let off = out_dims.len() - rhs_dims.len();
7463 for i in 0..rhs_dims.len() {
7464 let r = match rhs_dims[i] {
7465 rlx_ir::Dim::Static(n) => n,
7466 _ => return false,
7467 };
7468 let o = match out_dims[off + i] {
7469 rlx_ir::Dim::Static(n) => n,
7470 _ => return false,
7471 };
7472 if r != o {
7473 return false;
7474 }
7475 }
7476 true
7477}
7478
/// Computes per-axis element strides for reading an input of shape `in_dims`
/// as if it had been broadcast to shape `out_dims`.
///
/// The input shape is right-aligned against the output. Leading output axes
/// with no input counterpart behave like size-1 axes.
///
/// - An axis where the input has size 1 gets stride `0`, so the same input
///   element is reused along that output axis.
/// - Every other axis gets the row-major (C-contiguous) stride of the
///   un-broadcast input.
///
/// # Panics
///
/// - if `in_dims` has more axes than `out_dims`;
/// - if a non-1 input axis differs from the matching output axis;
/// - if a computed stride does not fit in `u32`. Truncating it with an `as`
///   cast would silently produce wrong addresses.
fn broadcast_strides(in_dims: &[usize], out_dims: &[usize]) -> Vec<u32> {
    let r_out = out_dims.len();
    let r_in = in_dims.len();
    assert!(
        r_in <= r_out,
        "broadcast: input rank {r_in} > output rank {r_out}"
    );
    let pad = r_out - r_in;
    let mut strides = vec![0u32; r_out];
    // Product of the input extents to the right of axis `d`, i.e. the
    // contiguous stride of axis `d` in the un-broadcast input.
    let mut acc: usize = 1;
    for d in (0..r_out).rev() {
        // Output axes left of the input's rank act as size-1 input axes.
        let in_size = if d < pad { 1 } else { in_dims[d - pad] };
        if in_size == 1 {
            // Broadcast axis: stride 0 repeats the element. `strides` is
            // already zero-filled; assigned explicitly for readability.
            strides[d] = 0;
        } else {
            assert_eq!(
                in_size, out_dims[d],
                "broadcast: input dim {in_size} doesn't match output dim {} at axis {d}",
                out_dims[d]
            );
            // Checked narrowing: the thunk layer stores strides as u32.
            strides[d] = u32::try_from(acc).unwrap_or_else(|_| {
                panic!("broadcast: stride {acc} at axis {d} does not fit in u32")
            });
            acc *= in_size;
        }
    }
    strides
}
7505
7506pub fn execute_compiled(schedule: &ThunkSchedule, arena_buf: &mut [u8]) {
7510 let base = arena_buf.as_mut_ptr();
7511 for f in &schedule.compiled_fns {
7512 f(base);
7513 }
7514}
7515
7516pub fn execute_thunks_active(
7521 schedule: &ThunkSchedule,
7522 _arena_buf: &mut [u8],
7523 _actual: usize,
7524 _upper: usize,
7525) -> bool {
7526 let _ = schedule;
7527 false
7528}
7529
7530struct MoeResidencyGuard;
7532impl Drop for MoeResidencyGuard {
7533 fn drop(&mut self) {
7534 if let Some(stats) = crate::moe_residency::take_stats() {
7535 crate::moe_residency::stash_last_forward_stats(stats);
7536 } else {
7537 crate::moe_residency::clear_mask();
7538 }
7539 }
7540}
7541
7542fn thunk_kind_name(t: &Thunk) -> &'static str {
7543 match t {
7544 Thunk::Nop => "Nop",
7545 Thunk::Gather { .. } => "Gather",
7546 Thunk::GatherAxis { .. } => "GatherAxis",
7547 Thunk::TopK { .. } => "TopK",
7548 Thunk::Copy { .. } => "Copy",
7549 Thunk::CopyF64 { .. } => "CopyF64",
7550 Thunk::CopyI64 { .. } => "CopyI64",
7551 Thunk::CastF32ToI64 { .. } => "CastF32ToI64",
7552 Thunk::CastI64ToF32 { .. } => "CastI64ToF32",
7553 Thunk::CastBoolToI32 { .. } => "CastBoolToI32",
7554 Thunk::CastI32ToF32 { .. } => "CastI32ToF32",
7555 Thunk::Transpose { .. } => "Transpose",
7556 Thunk::TransposeF64 { .. } => "TransposeF64",
7557 Thunk::Where { .. } => "Where",
7558 Thunk::Compare { .. } => "Compare",
7559 Thunk::BinaryFull { .. } => "BinaryFull",
7560 Thunk::BinaryFullF64 { .. } => "BinaryFullF64",
7561 Thunk::Sgemm { .. } => "Sgemm",
7562 Thunk::Dgemm { .. } => "Dgemm",
7563 Thunk::FusedMmBiasAct { .. } => "FusedMmBiasAct",
7564 Thunk::BiasAdd { .. } => "BiasAdd",
7565 Thunk::LayerNorm { .. } => "LayerNorm",
7566 Thunk::Softmax { .. } => "Softmax",
7567 Thunk::Conv2D { .. } => "Conv2D",
7568 Thunk::Conv2D1x1 { .. } => "Conv2D1x1",
7569 Thunk::CustomOp { .. } => "CustomOp",
7570 Thunk::ActivationInPlace { .. } => "ActivationInPlace",
7571 Thunk::Narrow { .. } => "Narrow",
7572 Thunk::Cumsum { .. } => "Cumsum",
7573 Thunk::Reduce { .. } => "Reduce",
7574 Thunk::BatchedSgemm { .. } => "BatchedSgemm",
7575 Thunk::DequantMatMul { .. } => "DequantMatMul",
7576 Thunk::Quantize { .. } => "Quantize",
7577 Thunk::Dequantize { .. } => "Dequantize",
7578 Thunk::ConvTranspose2d { .. } => "ConvTranspose2d",
7579 Thunk::ResizeNearest2x { .. } => "ResizeNearest2x",
7580 _ => "Other",
7581 }
7582}
7583
7584pub fn execute_thunks(schedule: &ThunkSchedule, arena_buf: &mut [u8]) {
7585 crate::moe_residency::reset_gmm_counters();
7586 if let Some(layers) = schedule.moe_resident_layers.clone() {
7587 crate::moe_residency::set_per_layer_masks(Some(layers));
7588 } else {
7589 crate::moe_residency::set_mask(schedule.moe_resident.clone());
7590 }
7591 if let Some(cap) = schedule.moe_topk_capture.as_ref() {
7592 cap.clear();
7593 }
7594 let _moe_guard = MoeResidencyGuard;
7595 let base = arena_buf.as_mut_ptr();
7596 let mask_thr = schedule.mask_threshold;
7597 let mask_neg = schedule.mask_neg_inf;
7598 let score_thr = schedule.score_skip;
7599 let thunks = &schedule.thunks;
7600 let len = thunks.len();
7601
7602 let max_h = thunks
7604 .iter()
7605 .filter_map(|t| match t {
7606 Thunk::FusedResidualLN { h, .. }
7607 | Thunk::FusedResidualRmsNorm { h, .. }
7608 | Thunk::LayerNorm { h, .. } => Some(*h as usize),
7609 _ => None,
7610 })
7611 .max()
7612 .unwrap_or(0);
7613 let zero_bias = vec![0f32; max_h];
7614
7615 let max_sdpa = thunks
7618 .iter()
7619 .filter_map(|t| match t {
7620 Thunk::Attention {
7621 batch,
7622 seq,
7623 kv_seq,
7624 heads,
7625 head_dim,
7626 ..
7627 } => Some((
7628 *batch as usize,
7629 (*seq as usize).max(*kv_seq as usize),
7630 *heads as usize,
7631 *head_dim as usize,
7632 )),
7633 _ => None,
7634 })
7635 .fold((0, 0, 0, 0), |(mb, ms, mh, md), (b, s, h, d)| {
7636 (mb.max(b), ms.max(s), mh.max(h), md.max(d))
7637 });
7638 let (max_batch, max_seq, max_heads, _max_dh) = max_sdpa;
7639 let max_units = max_batch * max_heads;
7640 let mut sdpa_scores = vec![0f32; max_units * max_seq * max_seq];
7641
7642 let fl = thunks
7644 .iter()
7645 .filter_map(|t| match t {
7646 Thunk::FusedBertLayer {
7647 batch,
7648 seq,
7649 hs,
7650 int_dim,
7651 ..
7652 } => {
7653 let m = (*batch as usize) * (*seq as usize);
7654 let h = *hs as usize;
7655 let id = *int_dim as usize;
7656 Some((m, h, id, m * (*seq as usize)))
7657 }
7658 Thunk::FusedNomicLayer {
7659 batch,
7660 seq,
7661 hs,
7662 int_dim,
7663 ..
7664 } => {
7665 let m = (*batch as usize) * (*seq as usize);
7666 let h = *hs as usize;
7667 let id = *int_dim as usize;
7668 Some((m, h, id, m * (*seq as usize)))
7669 }
7670 _ => None,
7671 })
7672 .fold((0, 0, 0, 0), |(mm, mh, mi, ms), (m, h, id, ss)| {
7673 (mm.max(m), mh.max(h), mi.max(id), ms.max(ss))
7674 });
7675 let (fl_m, fl_h, fl_int, fl_ss) = fl;
7676 let mut fl_qkv = vec![0f32; fl_m * 3 * fl_h];
7677 let mut fl_attn = vec![0f32; fl_m * fl_h];
7678 let mut fl_res = vec![0f32; fl_m * fl_h];
7679 let mut fl_normed = vec![0f32; fl_m * fl_h];
7680 let mut fl_ffn = vec![0f32; fl_m * fl_int.max(2 * fl_int)]; let mut fl_sc = vec![0f32; fl_ss.max(1)];
7682
7683 let trace_thunks = std::env::var_os("RLX_TRACE_THUNK").is_some();
7684 if trace_thunks {
7685 eprintln!(
7686 "[thunk] prealloc max_h={max_h} sdpa={} fl_m={fl_m} fl_h={fl_h} fl_int={fl_int}",
7687 max_units * max_seq * max_seq
7688 );
7689 }
7690 for i in 0..len {
7691 let thunk = unsafe { thunks.get_unchecked(i) };
7692 if trace_thunks && (i < 120 || i % 200 == 0 || i + 1 == len) {
7693 eprintln!("[thunk {i}/{len}] {}", thunk_kind_name(thunk));
7694 }
7695 let trace_done = trace_thunks && i < 120;
7696 match thunk {
7697 Thunk::Nop => {}
7698
7699 Thunk::GaussianSplatRender {
7700 positions_off,
7701 positions_len,
7702 scales_off,
7703 scales_len,
7704 rotations_off,
7705 rotations_len,
7706 opacities_off,
7707 opacities_len,
7708 colors_off,
7709 colors_len,
7710 sh_coeffs_off,
7711 sh_coeffs_len,
7712 meta_off,
7713 dst_off,
7714 dst_len,
7715 width,
7716 height,
7717 tile_size,
7718 radius_scale,
7719 alpha_cutoff,
7720 max_splat_steps,
7721 transmittance_threshold,
7722 max_list_entries,
7723 } => unsafe {
7724 crate::splat::execute_gaussian_splat_render(
7725 *positions_off,
7726 *positions_len,
7727 *scales_off,
7728 *scales_len,
7729 *rotations_off,
7730 *rotations_len,
7731 *opacities_off,
7732 *opacities_len,
7733 *colors_off,
7734 *colors_len,
7735 *sh_coeffs_off,
7736 *sh_coeffs_len,
7737 *meta_off,
7738 *dst_off,
7739 *dst_len,
7740 *width,
7741 *height,
7742 *tile_size,
7743 *radius_scale,
7744 *alpha_cutoff,
7745 *max_splat_steps,
7746 *transmittance_threshold,
7747 *max_list_entries,
7748 base,
7749 );
7750 },
7751
7752 Thunk::GaussianSplatRenderBackward {
7753 positions_off,
7754 positions_len,
7755 scales_off,
7756 scales_len,
7757 rotations_off,
7758 rotations_len,
7759 opacities_off,
7760 opacities_len,
7761 colors_off,
7762 colors_len,
7763 sh_coeffs_off,
7764 sh_coeffs_len,
7765 meta_off,
7766 d_loss_off,
7767 d_loss_len,
7768 packed_off,
7769 packed_len,
7770 width,
7771 height,
7772 tile_size,
7773 radius_scale,
7774 alpha_cutoff,
7775 max_splat_steps,
7776 transmittance_threshold,
7777 max_list_entries,
7778 loss_grad_clip,
7779 sh_band,
7780 max_anisotropy,
7781 } => unsafe {
7782 crate::splat::execute_gaussian_splat_render_backward(
7783 *positions_off,
7784 *positions_len,
7785 *scales_off,
7786 *scales_len,
7787 *rotations_off,
7788 *rotations_len,
7789 *opacities_off,
7790 *opacities_len,
7791 *colors_off,
7792 *colors_len,
7793 *sh_coeffs_off,
7794 *sh_coeffs_len,
7795 *meta_off,
7796 *d_loss_off,
7797 *d_loss_len,
7798 *packed_off,
7799 *packed_len,
7800 *width,
7801 *height,
7802 *tile_size,
7803 *radius_scale,
7804 *alpha_cutoff,
7805 *max_splat_steps,
7806 *transmittance_threshold,
7807 *max_list_entries,
7808 *loss_grad_clip,
7809 *sh_band,
7810 *max_anisotropy,
7811 base,
7812 );
7813 },
7814
7815 Thunk::GaussianSplatPrepare {
7816 positions_off,
7817 positions_len,
7818 scales_off,
7819 scales_len,
7820 rotations_off,
7821 rotations_len,
7822 opacities_off,
7823 opacities_len,
7824 colors_off,
7825 colors_len,
7826 sh_coeffs_off,
7827 sh_coeffs_len,
7828 meta_off,
7829 meta_len,
7830 prep_off,
7831 prep_len,
7832 width,
7833 height,
7834 tile_size,
7835 radius_scale,
7836 alpha_cutoff,
7837 max_splat_steps,
7838 transmittance_threshold,
7839 max_list_entries,
7840 } => unsafe {
7841 crate::splat::execute_gaussian_splat_prepare(
7842 *positions_off,
7843 *positions_len,
7844 *scales_off,
7845 *scales_len,
7846 *rotations_off,
7847 *rotations_len,
7848 *opacities_off,
7849 *opacities_len,
7850 *colors_off,
7851 *colors_len,
7852 *sh_coeffs_off,
7853 *sh_coeffs_len,
7854 *meta_off,
7855 *meta_len,
7856 *prep_off,
7857 *prep_len,
7858 *width,
7859 *height,
7860 *tile_size,
7861 *radius_scale,
7862 *alpha_cutoff,
7863 *max_splat_steps,
7864 *transmittance_threshold,
7865 *max_list_entries,
7866 base,
7867 );
7868 },
7869
7870 Thunk::GaussianSplatRasterize {
7871 prep_off,
7872 prep_len,
7873 meta_off,
7874 meta_len,
7875 dst_off,
7876 dst_len,
7877 count,
7878 width,
7879 height,
7880 tile_size,
7881 alpha_cutoff,
7882 max_splat_steps,
7883 transmittance_threshold,
7884 max_list_entries,
7885 } => unsafe {
7886 crate::splat::execute_gaussian_splat_rasterize(
7887 *prep_off,
7888 *prep_len,
7889 *meta_off,
7890 *meta_len,
7891 *dst_off,
7892 *dst_len,
7893 *count,
7894 *width,
7895 *height,
7896 *tile_size,
7897 *alpha_cutoff,
7898 *max_splat_steps,
7899 *transmittance_threshold,
7900 *max_list_entries,
7901 base,
7902 );
7903 },
7904
7905 Thunk::Fft1d {
7906 src,
7907 dst,
7908 outer,
7909 n_complex,
7910 inverse,
7911 norm_tag,
7912 dtype,
7913 } => unsafe {
7914 match dtype {
7915 rlx_ir::DType::F64 => execute_fft1d_f64(
7916 *src,
7917 *dst,
7918 *outer as usize,
7919 *n_complex as usize,
7920 *inverse,
7921 *norm_tag,
7922 base,
7923 ),
7924 rlx_ir::DType::F32 => execute_fft1d_f32(
7925 *src,
7926 *dst,
7927 *outer as usize,
7928 *n_complex as usize,
7929 *inverse,
7930 *norm_tag,
7931 base,
7932 ),
7933 rlx_ir::DType::C64 => execute_fft1d_c64(
7934 *src,
7935 *dst,
7936 *outer as usize,
7937 *n_complex as usize,
7938 *inverse,
7939 *norm_tag,
7940 base,
7941 ),
7942 other => panic!("Op::Fft on CPU requires F32/F64/C64, got {other:?}"),
7943 }
7944 },
7945
7946 Thunk::FftButterflyStage {
7947 state_src,
7948 state_dst,
7949 gate_src,
7950 rev_src,
7951 tw_re_src,
7952 tw_im_src,
7953 batch,
7954 n_fft,
7955 stage,
7956 } => unsafe {
7957 execute_fft_butterfly_stage_f32(
7958 *state_src,
7959 *state_dst,
7960 *gate_src,
7961 *rev_src,
7962 *tw_re_src,
7963 *tw_im_src,
7964 *batch as usize,
7965 *n_fft as usize,
7966 *stage as usize,
7967 base,
7968 );
7969 },
7970
7971 Thunk::LogMel {
7972 spec,
7973 filters,
7974 dst,
7975 outer,
7976 n_fft,
7977 n_bins,
7978 n_mels,
7979 } => unsafe {
7980 execute_log_mel_f32(
7981 *spec,
7982 *filters,
7983 *dst,
7984 *outer as usize,
7985 *n_fft as usize,
7986 *n_bins as usize,
7987 *n_mels as usize,
7988 base,
7989 );
7990 },
7991
7992 Thunk::LogMelBackward {
7993 spec,
7994 filters,
7995 dy,
7996 dst,
7997 outer,
7998 n_fft,
7999 n_bins,
8000 n_mels,
8001 } => unsafe {
8002 execute_log_mel_backward_f32(
8003 *spec,
8004 *filters,
8005 *dy,
8006 *dst,
8007 *outer as usize,
8008 *n_fft as usize,
8009 *n_bins as usize,
8010 *n_mels as usize,
8011 base,
8012 );
8013 },
8014
8015 Thunk::CustomFn {
8019 body,
8020 body_init,
8021 inputs,
8022 body_output_off,
8023 outer_output_off,
8024 out_bytes,
8025 } => {
8026 let mut body_buf: Vec<u8> = (**body_init).clone();
8027 unsafe {
8028 for (body_in_off, outer_in_off, n_bytes) in inputs.iter() {
8029 let src = (base as *const u8).add(*outer_in_off);
8030 let dst = body_buf.as_mut_ptr().add(*body_in_off);
8031 std::ptr::copy_nonoverlapping(src, dst, *n_bytes as usize);
8032 }
8033 }
8034 execute_thunks(body, &mut body_buf);
8035 unsafe {
8036 let src = body_buf.as_ptr().add(*body_output_off);
8037 let dst = base.add(*outer_output_off);
8038 std::ptr::copy_nonoverlapping(src, dst, *out_bytes as usize);
8039 }
8040 }
8041
8042 Thunk::Sgemm { a, b, c, m, k, n } => {
8043 let (m, k, n) = (*m as usize, *k as usize, *n as usize);
8044 if trace_thunks {
8045 eprintln!("[sgemm] m={m} k={k} n={n} a={} b={} c={}", *a, *b, *c);
8046 }
8047 let c_len = m.saturating_mul(n);
8048 let a_len = m.saturating_mul(k);
8049 let b_len = k.saturating_mul(n);
8050 let arena_len = arena_buf.len();
8051 let max_a = (arena_len.saturating_sub(*a)) / 4;
8052 let max_b = (arena_len.saturating_sub(*b)) / 4;
8053 let max_c = (arena_len.saturating_sub(*c)) / 4;
8054 let a_len = a_len.min(max_a);
8055 let b_len = b_len.min(max_b);
8056 let c_len = c_len.min(max_c);
8057 unsafe {
8058 let a_sl = sl(*a, base, a_len);
8059 let b_sl = sl(*b, base, b_len);
8060 let c_sl = sl_mut(*c, base, c_len);
8061 if std::ptr::eq(a_sl.as_ptr(), c_sl.as_ptr())
8062 || std::ptr::eq(b_sl.as_ptr(), c_sl.as_ptr())
8063 {
8064 let mut tmp = vec![0.0f32; c_len];
8065 crate::blas::sgemm_auto(a_sl, b_sl, &mut tmp, m, k, n);
8066 c_sl.copy_from_slice(&tmp);
8067 } else {
8068 crate::blas::sgemm_auto(a_sl, b_sl, c_sl, m, k, n);
8069 }
8070 }
8071 }
8072
8073 Thunk::DenseSolveF64 { a, b, x, n, nrhs } => {
8074 let (n_, nrhs_) = (*n as usize, *nrhs as usize);
8075 unsafe {
8081 let a_src = sl_f64(*a, base, n_ * n_);
8082 let b_src = sl_f64(*b, base, n_ * nrhs_);
8083 let mut a_scratch: Vec<f64> = a_src.to_vec();
8084 let mut x_buf: Vec<f64> = b_src.to_vec();
8085 let info = crate::blas::dgesv(&mut a_scratch, &mut x_buf, n_, nrhs_);
8086 if info != 0 {
8087 panic!(
8088 "DenseSolveF64: dgesv reported singular matrix \
8089 (info={info}, n={n_}, nrhs={nrhs_})"
8090 );
8091 }
8092 let dst = sl_mut_f64(*x, base, n_ * nrhs_);
8093 dst.copy_from_slice(&x_buf);
8094 }
8095 }
8096
8097 Thunk::DenseSolveF32 { a, b, x, n, nrhs } => {
8098 let (n_, nrhs_) = (*n as usize, *nrhs as usize);
8099 unsafe {
8100 let a_src = sl(*a, base, n_ * n_);
8101 let b_src = sl(*b, base, n_ * nrhs_);
8102 let mut a_scratch: Vec<f32> = a_src.to_vec();
8103 let mut x_buf: Vec<f32> = b_src.to_vec();
8104 let info = crate::blas::sgesv(&mut a_scratch, &mut x_buf, n_, nrhs_);
8105 if info != 0 {
8106 panic!(
8107 "DenseSolveF32: sgesv reported singular matrix \
8108 (info={info}, n={n_}, nrhs={nrhs_})"
8109 );
8110 }
8111 let dst = sl_mut(*x, base, n_ * nrhs_);
8112 dst.copy_from_slice(&x_buf);
8113 }
8114 }
8115
8116 Thunk::BatchedDenseSolveF64 {
8117 a,
8118 b,
8119 x,
8120 batch,
8121 n,
8122 nrhs,
8123 } => {
8124 let (b_, n_, nrhs_) = (*batch as usize, *n as usize, *nrhs as usize);
8131 let a_stride = n_ * n_;
8132 let b_stride = n_ * nrhs_;
8133 unsafe {
8134 let a_full = sl_f64(*a, base, b_ * a_stride);
8135 let b_full = sl_f64(*b, base, b_ * b_stride);
8136 let x_full = sl_mut_f64(*x, base, b_ * b_stride);
8137 for bi in 0..b_ {
8138 let mut a_scratch: Vec<f64> =
8139 a_full[bi * a_stride..(bi + 1) * a_stride].to_vec();
8140 let mut x_buf: Vec<f64> =
8141 b_full[bi * b_stride..(bi + 1) * b_stride].to_vec();
8142 let info = crate::blas::dgesv(&mut a_scratch, &mut x_buf, n_, nrhs_);
8143 if info != 0 {
8144 panic!(
8145 "BatchedDenseSolveF64: slice {bi} \
8146 singular (info={info}, n={n_}, nrhs={nrhs_})"
8147 );
8148 }
8149 x_full[bi * b_stride..(bi + 1) * b_stride].copy_from_slice(&x_buf);
8150 }
8151 }
8152 }
8153
8154 Thunk::BatchedDenseSolveF32 {
8155 a,
8156 b,
8157 x,
8158 batch,
8159 n,
8160 nrhs,
8161 } => {
8162 let (b_, n_, nrhs_) = (*batch as usize, *n as usize, *nrhs as usize);
8163 let a_stride = n_ * n_;
8164 let b_stride = n_ * nrhs_;
8165 unsafe {
8166 let a_full = sl(*a, base, b_ * a_stride);
8167 let b_full = sl(*b, base, b_ * b_stride);
8168 let x_full = sl_mut(*x, base, b_ * b_stride);
8169 for bi in 0..b_ {
8170 let mut a_scratch = a_full[bi * a_stride..(bi + 1) * a_stride].to_vec();
8171 let mut x_buf = b_full[bi * b_stride..(bi + 1) * b_stride].to_vec();
8172 let info = crate::blas::sgesv(&mut a_scratch, &mut x_buf, n_, nrhs_);
8173 if info != 0 {
8174 panic!("BatchedDenseSolveF32: slice {bi} singular (info={info})");
8175 }
8176 x_full[bi * b_stride..(bi + 1) * b_stride].copy_from_slice(&x_buf);
8177 }
8178 }
8179 }
8180
8181 Thunk::BatchedDgemmF64 {
8182 a,
8183 b,
8184 c,
8185 batch,
8186 m,
8187 k,
8188 n,
8189 } => {
8190 let (b_, m_, k_, n_) = (*batch as usize, *m as usize, *k as usize, *n as usize);
8191 let a_stride = m_ * k_;
8192 let b_stride = k_ * n_;
8193 let c_stride = m_ * n_;
8194 unsafe {
8195 let a_full = sl_f64(*a, base, b_ * a_stride);
8196 let b_full = sl_f64(*b, base, b_ * b_stride);
8197 let c_full = sl_mut_f64(*c, base, b_ * c_stride);
8198 for bi in 0..b_ {
8199 let a_slice = &a_full[bi * a_stride..(bi + 1) * a_stride];
8200 let b_slice = &b_full[bi * b_stride..(bi + 1) * b_stride];
8201 let c_slice = &mut c_full[bi * c_stride..(bi + 1) * c_stride];
8202 crate::blas::dgemm(a_slice, b_slice, c_slice, m_, k_, n_);
8203 }
8204 }
8205 }
8206
8207 Thunk::BatchedSgemm {
8208 a,
8209 b,
8210 c,
8211 batch,
8212 m,
8213 k,
8214 n,
8215 } => {
8216 let (b_, m_, k_, n_) = (*batch as usize, *m as usize, *k as usize, *n as usize);
8217 if trace_thunks {
8218 eprintln!(
8219 "[batched-sgemm] batch={b_} m={m_} k={k_} n={n_} a={} b={} c={}",
8220 *a, *b, *c
8221 );
8222 }
8223 let a_stride = m_.saturating_mul(k_);
8224 let b_stride = k_.saturating_mul(n_);
8225 let c_stride = m_.saturating_mul(n_);
8226 let arena_len = arena_buf.len();
8227 let a_cap = (arena_len.saturating_sub(*a)) / 4;
8228 let b_cap = (arena_len.saturating_sub(*b)) / 4;
8229 let c_cap = (arena_len.saturating_sub(*c)) / 4;
8230 let a_elems = (b_ * a_stride).min(a_cap);
8231 let b_elems = (b_ * b_stride).min(b_cap);
8232 let c_elems = (b_ * c_stride).min(c_cap);
8233 let b_eff = b_
8234 .min(a_elems.checked_div(a_stride).unwrap_or(0))
8235 .min(b_elems.checked_div(b_stride).unwrap_or(0))
8236 .min(c_elems.checked_div(c_stride).unwrap_or(0));
8237 unsafe {
8238 let a_full = sl(*a, base, a_elems);
8239 let b_full = sl(*b, base, b_elems);
8240 let c_full = sl_mut(*c, base, c_elems);
8241 for bi in 0..b_eff {
8242 let a0 = bi * a_stride;
8243 let b0 = bi * b_stride;
8244 let c0 = bi * c_stride;
8245 if a0 + a_stride > a_full.len()
8246 || b0 + b_stride > b_full.len()
8247 || c0 + c_stride > c_full.len()
8248 {
8249 break;
8250 }
8251 let a_slice = &a_full[a0..a0 + a_stride];
8252 let b_slice = &b_full[b0..b0 + b_stride];
8253 let c_slice = &mut c_full[c0..c0 + c_stride];
8254 if std::ptr::eq(a_slice.as_ptr(), c_slice.as_mut_ptr())
8255 || std::ptr::eq(b_slice.as_ptr(), c_slice.as_mut_ptr())
8256 {
8257 let mut tmp = vec![0.0f32; c_stride];
8258 crate::blas::sgemm_auto(a_slice, b_slice, &mut tmp, m_, k_, n_);
8259 c_slice.copy_from_slice(&tmp);
8260 } else {
8261 crate::blas::sgemm_auto(a_slice, b_slice, c_slice, m_, k_, n_);
8262 }
8263 }
8264 }
8265 }
8266
8267 Thunk::Dgemm { a, b, c, m, k, n } => {
8268 let (m, k, n) = (*m as usize, *k as usize, *n as usize);
8269 unsafe {
8270 crate::blas::dgemm(
8271 sl_f64(*a, base, m * k),
8272 sl_f64(*b, base, k * n),
8273 sl_mut_f64(*c, base, m * n),
8274 m,
8275 k,
8276 n,
8277 );
8278 }
8279 }
8280
8281 Thunk::TransposeF64 {
8282 src,
8283 dst,
8284 in_total,
8285 out_dims,
8286 in_strides,
8287 } => unsafe {
8288 let inp = sl_f64(*src, base, *in_total as usize);
8289 let out_total: usize = out_dims.iter().map(|d| *d as usize).product();
8290 let out = sl_mut_f64(*dst, base, out_total);
8291 transpose_walk_f64(inp, out, out_dims, in_strides);
8292 },
8293
8294 Thunk::ActivationF64 {
8295 src,
8296 dst,
8297 len,
8298 kind,
8299 } => {
8300 let len = *len as usize;
8301 unsafe {
8302 let inp = sl_f64(*src, base, len);
8303 let out = sl_mut_f64(*dst, base, len);
8304 apply_activation_f64(inp, out, *kind);
8305 }
8306 }
8307
8308 Thunk::ReduceSumF64 {
8309 src,
8310 dst,
8311 outer,
8312 reduced,
8313 inner,
8314 } => {
8315 let (o, r, n) = (*outer as usize, *reduced as usize, *inner as usize);
8316 unsafe {
8317 let inp = sl_f64(*src, base, o * r * n);
8318 let out = sl_mut_f64(*dst, base, o * n);
8319 reduce_sum_f64(inp, out, o, r, n);
8320 }
8321 }
8322
8323 Thunk::CopyF64 { src, dst, len } => {
8324 let mut len = *len as usize;
8325 if *src == *dst || len == 0 {
8326 continue;
8327 }
8328 let arena_len = arena_buf.len();
8329 let max_from_src = (arena_len.saturating_sub(*src)) / 8;
8330 let max_from_dst = (arena_len.saturating_sub(*dst)) / 8;
8331 len = len.min(max_from_src).min(max_from_dst);
8332 if len == 0 {
8333 continue;
8334 }
8335 let byte_len = len.saturating_mul(8);
8336 unsafe {
8337 std::ptr::copy(base.add(*src), base.add(*dst), byte_len);
8338 }
8339 }
8340
8341 Thunk::CopyI64 { src, dst, len } => {
8342 let mut len = *len as usize;
8343 if *src == *dst || len == 0 {
8344 continue;
8345 }
8346 let arena_len = arena_buf.len();
8347 let max_from_src = (arena_len.saturating_sub(*src)) / 8;
8348 let max_from_dst = (arena_len.saturating_sub(*dst)) / 8;
8349 len = len.min(max_from_src).min(max_from_dst);
8350 if len == 0 {
8351 continue;
8352 }
8353 let byte_len = len.saturating_mul(8);
8354 unsafe {
8355 std::ptr::copy(base.add(*src), base.add(*dst), byte_len);
8356 }
8357 }
8358
8359 Thunk::CastF32ToI64 { src, dst, len } => {
8360 let len = *len as usize;
8361 if len == 0 {
8362 continue;
8363 }
8364 unsafe {
8365 let inp = sl(*src, base, len);
8366 let out = sl_mut_i64(*dst, base, len);
8367 for i in 0..len {
8368 out[i] = inp[i].round() as i64;
8369 }
8370 }
8371 }
8372
8373 Thunk::CastI64ToF32 { src, dst, len } => {
8374 let len = *len as usize;
8375 if len == 0 {
8376 continue;
8377 }
8378 unsafe {
8379 let inp = sl_i64(*src, base, len);
8380 let out = sl_mut(*dst, base, len);
8381 for i in 0..len {
8382 out[i] = inp[i] as f32;
8383 }
8384 }
8385 }
8386
8387 Thunk::CastBoolToI32 { src, dst, len } => {
8388 let len = *len as usize;
8389 if len == 0 {
8390 continue;
8391 }
8392 unsafe {
8393 let inp = &arena_buf[*src..*src + len];
8394 let out = sl_mut_i32(*dst, base, len);
8395 for i in 0..len {
8396 out[i] = i32::from(inp[i] != 0);
8397 }
8398 }
8399 }
8400
8401 Thunk::CastI32ToF32 { src, dst, len } => {
8402 let len = *len as usize;
8403 if len == 0 {
8404 continue;
8405 }
8406 unsafe {
8407 let inp = sl_i32(*src, base, len);
8408 let out = sl_mut(*dst, base, len);
8409 for i in 0..len {
8410 out[i] = inp[i] as f32;
8411 }
8412 }
8413 }
8414
8415 Thunk::BinaryFullF64 {
8416 lhs,
8417 rhs,
8418 dst,
8419 len,
8420 lhs_len,
8421 rhs_len,
8422 op,
8423 out_dims_bcast,
8424 bcast_lhs_strides,
8425 bcast_rhs_strides,
8426 } => {
8427 let len = *len as usize;
8428 let lhs_len = *lhs_len as usize;
8429 let rhs_len = *rhs_len as usize;
8430 unsafe {
8431 let l = sl_f64(*lhs, base, lhs_len);
8432 let r = sl_f64(*rhs, base, rhs_len);
8433 let d = sl_mut_f64(*dst, base, len);
8434 if lhs_len == len && rhs_len == len {
8435 for i in 0..len {
8436 d[i] = binary_op_f64(*op, l[i], r[i]);
8437 }
8438 } else if !out_dims_bcast.is_empty() {
8439 let rank = out_dims_bcast.len();
8443 let mut coords = vec![0u32; rank];
8444 for i in 0..len {
8445 let mut rem = i;
8446 for ax in (0..rank).rev() {
8447 let sz = out_dims_bcast[ax] as usize;
8448 coords[ax] = (rem % sz) as u32;
8449 rem /= sz;
8450 }
8451 let mut li: usize = 0;
8452 let mut ri: usize = 0;
8453 for ax in 0..rank {
8454 li += coords[ax] as usize * bcast_lhs_strides[ax] as usize;
8455 ri += coords[ax] as usize * bcast_rhs_strides[ax] as usize;
8456 }
8457 d[i] = binary_op_f64(*op, l[li], r[ri]);
8458 }
8459 } else {
8460 for i in 0..len {
8465 d[i] = binary_op_f64(*op, l[i % lhs_len], r[i % rhs_len]);
8466 }
8467 }
8468 }
8469 }
8470
8471 Thunk::BinaryFullC64 {
8472 lhs,
8473 rhs,
8474 dst,
8475 len,
8476 lhs_len,
8477 rhs_len,
8478 op,
8479 out_dims_bcast,
8480 bcast_lhs_strides,
8481 bcast_rhs_strides,
8482 } => {
8483 let n_out = *len as usize;
8489 let n_l = *lhs_len as usize;
8490 let n_r = *rhs_len as usize;
8491 unsafe {
8492 let l = sl(*lhs, base, 2 * n_l);
8493 let r = sl(*rhs, base, 2 * n_r);
8494 let d = sl_mut(*dst, base, 2 * n_out);
8495 let do_c64 = |a_re: f32, a_im: f32, b_re: f32, b_im: f32| -> (f32, f32) {
8496 match op {
8497 BinaryOp::Add => (a_re + b_re, a_im + b_im),
8498 BinaryOp::Sub => (a_re - b_re, a_im - b_im),
8499 BinaryOp::Mul => (a_re * b_re - a_im * b_im, a_re * b_im + a_im * b_re),
8500 BinaryOp::Div => {
8501 let denom = b_re * b_re + b_im * b_im;
8502 (
8503 (a_re * b_re + a_im * b_im) / denom,
8504 (a_im * b_re - a_re * b_im) / denom,
8505 )
8506 }
8507 BinaryOp::Max | BinaryOp::Min | BinaryOp::Pow => {
8508 unreachable!("C64 max/min/pow rejected at lowering")
8509 }
8510 }
8511 };
8512 if n_l == n_out && n_r == n_out {
8513 for i in 0..n_out {
8514 let (re, im) = do_c64(l[2 * i], l[2 * i + 1], r[2 * i], r[2 * i + 1]);
8515 d[2 * i] = re;
8516 d[2 * i + 1] = im;
8517 }
8518 } else if !out_dims_bcast.is_empty() {
8519 let rank = out_dims_bcast.len();
8523 let mut coords = vec![0u32; rank];
8524 for i in 0..n_out {
8525 let mut rem = i;
8526 for ax in (0..rank).rev() {
8527 let sz = out_dims_bcast[ax] as usize;
8528 coords[ax] = (rem % sz) as u32;
8529 rem /= sz;
8530 }
8531 let mut li: usize = 0;
8532 let mut ri: usize = 0;
8533 for ax in 0..rank {
8534 li += coords[ax] as usize * bcast_lhs_strides[ax] as usize;
8535 ri += coords[ax] as usize * bcast_rhs_strides[ax] as usize;
8536 }
8537 let (re, im) =
8538 do_c64(l[2 * li], l[2 * li + 1], r[2 * ri], r[2 * ri + 1]);
8539 d[2 * i] = re;
8540 d[2 * i + 1] = im;
8541 }
8542 } else {
8543 for i in 0..n_out {
8545 let li = if n_l == 1 { 0 } else { i % n_l };
8546 let ri = if n_r == 1 { 0 } else { i % n_r };
8547 let (re, im) =
8548 do_c64(l[2 * li], l[2 * li + 1], r[2 * ri], r[2 * ri + 1]);
8549 d[2 * i] = re;
8550 d[2 * i + 1] = im;
8551 }
8552 }
8553 }
8554 }
8555
8556 Thunk::ComplexNormSqF32 { src, dst, len } => {
8557 let n = *len as usize;
8558 unsafe {
8559 let s = sl(*src, base, 2 * n);
8560 let d = sl_mut(*dst, base, n);
8561 for i in 0..n {
8562 let re = s[2 * i];
8563 let im = s[2 * i + 1];
8564 d[i] = re * re + im * im;
8565 }
8566 }
8567 }
8568
8569 Thunk::ComplexNormSqBackwardF32 { z, g, dz, len } => {
8570 let n = *len as usize;
8573 unsafe {
8574 let zb = sl(*z, base, 2 * n);
8575 let gb = sl(*g, base, n);
8576 let db = sl_mut(*dz, base, 2 * n);
8577 for i in 0..n {
8578 let re = zb[2 * i];
8579 let im = zb[2 * i + 1];
8580 let gv = gb[i];
8581 db[2 * i] = gv * re;
8582 db[2 * i + 1] = gv * im;
8583 }
8584 }
8585 }
8586
8587 Thunk::ConjugateC64 { src, dst, len } => {
8588 let n = *len as usize;
8589 unsafe {
8590 let s = sl(*src, base, 2 * n);
8591 let d = sl_mut(*dst, base, 2 * n);
8592 for i in 0..n {
8593 d[2 * i] = s[2 * i];
8594 d[2 * i + 1] = -s[2 * i + 1];
8595 }
8596 }
8597 }
8598
8599 Thunk::ActivationC64 {
8600 src,
8601 dst,
8602 len,
8603 kind,
8604 } => {
8605 let n = *len as usize;
8606 unsafe {
8607 let s = sl(*src, base, 2 * n);
8608 let d = sl_mut(*dst, base, 2 * n);
8609 for i in 0..n {
8610 let a = s[2 * i];
8611 let b = s[2 * i + 1];
8612 let (re, im) = match kind {
8613 Activation::Neg => (-a, -b),
8614 Activation::Exp => {
8615 let ea = a.exp();
8617 (ea * b.cos(), ea * b.sin())
8618 }
8619 Activation::Log => {
8620 let r = (a * a + b * b).sqrt();
8622 (r.ln(), b.atan2(a))
8623 }
8624 Activation::Sqrt => {
8625 let r = (a * a + b * b).sqrt();
8628 let re = ((r + a) * 0.5).max(0.0).sqrt();
8629 let im_mag = ((r - a) * 0.5).max(0.0).sqrt();
8630 let im = if b >= 0.0 { im_mag } else { -im_mag };
8631 (re, im)
8632 }
8633 _ => unreachable!("non-C64 activation kind survived lowering"),
8634 };
8635 d[2 * i] = re;
8636 d[2 * i + 1] = im;
8637 }
8638 }
8639 }
8640
8641 Thunk::Scan {
8642 body,
8643 body_init,
8644 body_input_off,
8645 body_output_off,
8646 outer_init_off,
8647 outer_final_off,
8648 length,
8649 carry_bytes,
8650 save_trajectory,
8651 xs_inputs,
8652 bcast_inputs,
8653 num_checkpoints,
8654 } => {
8655 let cb = *carry_bytes as usize;
8656 let n_steps = *length as usize;
8657 let k_total = if *num_checkpoints == 0 || *num_checkpoints == *length {
8661 n_steps } else {
8663 *num_checkpoints as usize
8664 };
8665 let checkpoint_t_for_k = |k: usize| -> usize {
8666 if k_total == n_steps {
8667 k
8668 } else {
8669 ((k + 1) * n_steps)
8670 .div_ceil(k_total)
8671 .saturating_sub(1)
8672 .min(n_steps - 1)
8673 }
8674 };
8675 let mut next_k = 0usize;
8676
8677 let mut body_buf: Vec<u8> = (**body_init).clone();
8678 unsafe {
8679 std::ptr::copy_nonoverlapping(
8680 base.add(*outer_init_off),
8681 body_buf.as_mut_ptr().add(*body_input_off),
8682 cb,
8683 );
8684 for (body_b_off, outer_b_off, total_bytes) in bcast_inputs.iter() {
8688 std::ptr::copy_nonoverlapping(
8689 base.add(*outer_b_off),
8690 body_buf.as_mut_ptr().add(*body_b_off),
8691 *total_bytes as usize,
8692 );
8693 }
8694 }
8695 for t in 0..n_steps {
8696 for (body_x_off, outer_xs_off, per_step_bytes) in xs_inputs.iter() {
8697 let psb = *per_step_bytes as usize;
8698 unsafe {
8699 std::ptr::copy_nonoverlapping(
8700 base.add(*outer_xs_off + t * psb),
8701 body_buf.as_mut_ptr().add(*body_x_off),
8702 psb,
8703 );
8704 }
8705 }
8706
8707 execute_thunks(body, &mut body_buf);
8708
8709 if *save_trajectory && next_k < k_total && t == checkpoint_t_for_k(next_k) {
8710 unsafe {
8711 std::ptr::copy_nonoverlapping(
8712 body_buf.as_ptr().add(*body_output_off),
8713 base.add(*outer_final_off + next_k * cb),
8714 cb,
8715 );
8716 }
8717 next_k += 1;
8718 }
8719
8720 if *body_output_off != *body_input_off {
8721 body_buf
8722 .copy_within(*body_output_off..*body_output_off + cb, *body_input_off);
8723 }
8724 }
8725
8726 if !*save_trajectory {
8727 unsafe {
8729 std::ptr::copy_nonoverlapping(
8730 body_buf.as_ptr().add(*body_output_off),
8731 base.add(*outer_final_off),
8732 cb,
8733 );
8734 }
8735 }
8736 }
8737
            // Reverse pass of `Thunk::Scan` with respect to the initial carry. Walks the
            // steps from last to first, runs `body_vjp` once per step, and finally writes
            // the cotangent of the initial carry (`carry_bytes` bytes) to `outer_dinit_off`.
            //
            // Seeding the carry cotangent `dcarry`:
            // - `save_trajectory == false`: `outer_upstream_off` holds one cotangent for
            //   the final carry, copied into `dcarry` up front.
            // - `save_trajectory == true`: `outer_upstream_off + t * cb` holds a per-step
            //   cotangent that `process_iter` adds into `dcarry` before step `t` runs.
            //
            // Recovering each step's carry-in:
            // - non-recursive (`num_checkpoints` is 0 or equals `length`): carry-in for
            //   step t is the initial carry for t == 0, else trajectory slot t - 1.
            // - recursive (any other checkpoint count): only `k_total` snapshots exist,
            //   and each segment between snapshots is recomputed from its anchor by
            //   `griewank_process_segment`. That helper presumably invokes `leaf_action`
            //   once per step, in reverse order, with the recomputed carry-in.
            //   TODO confirm against its definition.
            Thunk::ScanBackward {
                body_vjp,
                body_init,
                body_carry_in_off,
                body_x_offs,
                body_d_output_off,
                body_dcarry_out_off,
                outer_init_off,
                outer_traj_off,
                outer_upstream_off,
                outer_xs_offs,
                outer_dinit_off,
                length,
                carry_bytes,
                save_trajectory,
                num_checkpoints,
                forward_body,
                forward_body_init,
                forward_body_carry_in_off,
                forward_body_output_off,
                forward_body_x_offs,
                carry_elem_size,
            } => {
                let cb = *carry_bytes as usize;
                let n_steps = *length as usize;
                let k_total = *num_checkpoints as usize;
                // NOTE(review): with `length == 0` and a non-zero checkpoint count this is
                // true, and `n_steps - 1` below underflows.
                let is_recursive = k_total != 0 && k_total != n_steps;
                // Step whose output carry is stored in trajectory slot `k`. Only called on
                // the recursive path, where `k_total > 0`.
                let checkpoint_t_for_k = |k: usize| -> usize {
                    ((k + 1) * n_steps)
                        .div_ceil(k_total)
                        .saturating_sub(1)
                        .min(n_steps - 1)
                };

                // Scratch buffer for re-running the forward body (recursive path only).
                let mut fwd_buf: Vec<u8> = if is_recursive {
                    (**forward_body_init.as_ref().unwrap()).clone()
                } else {
                    Vec::new()
                };

                let mut dcarry: Vec<u8> = vec![0u8; cb];
                if !*save_trajectory {
                    unsafe {
                        std::ptr::copy_nonoverlapping(
                            base.add(*outer_upstream_off),
                            dcarry.as_mut_ptr(),
                            cb,
                        );
                    }
                }

                let mut body_buf: Vec<u8> = (**body_init).clone();

                // One reverse step t:
                // 1. add step t's upstream cotangent (when saved per step);
                // 2. load carry-in, xs slice t and dcarry into the VJP body;
                // 3. run it;
                // 4. read back the new dcarry.
                let process_iter =
                    |t: usize, carry_in: &[u8], dcarry: &mut Vec<u8>, body_buf: &mut Vec<u8>| {
                        if *save_trajectory {
                            // Elementwise add, interpreting both buffers as f32 or f64.
                            // Assumes the arena offset is suitably aligned for that type;
                            // TODO confirm.
                            unsafe {
                                let up_off = *outer_upstream_off + t * cb;
                                match *carry_elem_size {
                                    4 => {
                                        let up_ptr = base.add(up_off) as *const f32;
                                        let dc_ptr = dcarry.as_mut_ptr() as *mut f32;
                                        let n_elems = cb / 4;
                                        for i in 0..n_elems {
                                            *dc_ptr.add(i) += *up_ptr.add(i);
                                        }
                                    }
                                    8 => {
                                        let up_ptr = base.add(up_off) as *const f64;
                                        let dc_ptr = dcarry.as_mut_ptr() as *mut f64;
                                        let n_elems = cb / 8;
                                        for i in 0..n_elems {
                                            *dc_ptr.add(i) += *up_ptr.add(i);
                                        }
                                    }
                                    other => panic!(
                                        "ScanBackward: unsupported carry elem size {other} \
                                         (only f32/f64 carries are supported today)"
                                    ),
                                }
                            }
                        }
                        body_buf[*body_carry_in_off..*body_carry_in_off + cb]
                            .copy_from_slice(carry_in);
                        unsafe {
                            for (i, body_x_off) in body_x_offs.iter().enumerate() {
                                let (outer_xs_off, per_step_bytes) = outer_xs_offs[i];
                                let psb = per_step_bytes as usize;
                                std::ptr::copy_nonoverlapping(
                                    base.add(outer_xs_off + t * psb),
                                    body_buf.as_mut_ptr().add(*body_x_off),
                                    psb,
                                );
                            }
                            std::ptr::copy_nonoverlapping(
                                dcarry.as_ptr(),
                                body_buf.as_mut_ptr().add(*body_d_output_off),
                                cb,
                            );
                        }
                        execute_thunks(body_vjp, body_buf);
                        unsafe {
                            std::ptr::copy_nonoverlapping(
                                body_buf.as_ptr().add(*body_dcarry_out_off),
                                dcarry.as_mut_ptr(),
                                cb,
                            );
                        }
                    };

                if is_recursive {
                    let leaf_threshold = 4usize;
                    let fb_sched = forward_body.as_ref().unwrap();
                    let fb_init = forward_body_init.as_ref().unwrap().as_slice();
                    // Process segments from the last to the first. Segment `seg_k` spans
                    // steps (checkpoint_t_for_k(seg_k - 1) + 1) ..= segment_end and is
                    // anchored at the initial carry (seg_k == 0) or trajectory slot seg_k - 1.
                    let mut segment_end = n_steps - 1;
                    for seg_k in (0..k_total).rev() {
                        let segment_start = if seg_k == 0 {
                            0
                        } else {
                            checkpoint_t_for_k(seg_k - 1) + 1
                        };
                        let mut anchor: Vec<u8> = vec![0u8; cb];
                        unsafe {
                            let src = if seg_k == 0 {
                                base.add(*outer_init_off)
                            } else {
                                base.add(*outer_traj_off + (seg_k - 1) * cb)
                            };
                            std::ptr::copy_nonoverlapping(src, anchor.as_mut_ptr(), cb);
                        }
                        let mut leaf_action = |t: usize, carry_in: &[u8]| {
                            process_iter(t, carry_in, &mut dcarry, &mut body_buf);
                        };
                        unsafe {
                            griewank_process_segment(
                                segment_start,
                                segment_end,
                                &anchor,
                                cb,
                                fb_sched,
                                fb_init,
                                *forward_body_carry_in_off,
                                *forward_body_output_off,
                                forward_body_x_offs,
                                base,
                                outer_xs_offs,
                                &mut fwd_buf,
                                leaf_threshold,
                                &mut leaf_action,
                            );
                        }
                        if seg_k == 0 {
                            break;
                        }
                        segment_end = segment_start - 1;
                    }
                } else {
                    // Every carry-in is directly available: init for t == 0, else
                    // trajectory slot t - 1.
                    let mut carry_buf: Vec<u8> = vec![0u8; cb];
                    for t in (0..n_steps).rev() {
                        unsafe {
                            let src = if t == 0 {
                                base.add(*outer_init_off)
                            } else {
                                base.add(*outer_traj_off + (t - 1) * cb)
                            };
                            std::ptr::copy_nonoverlapping(src, carry_buf.as_mut_ptr(), cb);
                        }
                        process_iter(t, &carry_buf, &mut dcarry, &mut body_buf);
                    }
                }

                // After the first step is processed, dcarry is the cotangent of the
                // initial carry.
                unsafe {
                    std::ptr::copy_nonoverlapping(dcarry.as_ptr(), base.add(*outer_dinit_off), cb);
                }
            }
8940
            // Reverse pass of `Thunk::Scan` that also produces per-step cotangents for a
            // sequence input. Walks steps from last to first, runs `body_vjp` once per
            // step, and writes each step's `per_step_bytes` from `body_dxs_out_off` to
            // `outer_dxs_off + t * per_step_bytes`. The carry cotangent is threaded
            // through `dcarry` exactly as in `ScanBackward`, but is not written out here.
            //
            // Carry-in for step t comes from `recompute_carry_t`:
            // - non-recursive: read directly from the initial carry / trajectory slot t - 1;
            // - recursive: the segment containing t is re-run once from its anchor, and
            //   every carry in it is cached in `seg_cache`, so later (smaller) t in the same
            //   segment are served from the cache.
            Thunk::ScanBackwardXs {
                body_vjp,
                body_init,
                body_carry_in_off,
                body_x_offs,
                body_d_output_off,
                body_dcarry_out_off,
                body_dxs_out_off,
                outer_init_off,
                outer_traj_off,
                outer_upstream_off,
                outer_xs_offs,
                outer_dxs_off,
                length,
                carry_bytes,
                carry_elem_size,
                per_step_bytes,
                save_trajectory,
                num_checkpoints,
                forward_body,
                forward_body_init,
                forward_body_carry_in_off,
                forward_body_output_off,
                forward_body_x_offs,
            } => {
                let cb = *carry_bytes as usize;
                let psb = *per_step_bytes as usize;
                let n_steps = *length as usize;
                let k_total = *num_checkpoints as usize;
                let is_recursive = k_total != 0 && k_total != n_steps;
                // Step whose output carry is stored in trajectory slot `k`. Only called on
                // the recursive path, where `k_total > 0`.
                let checkpoint_t_for_k = |k: usize| -> usize {
                    ((k + 1) * n_steps)
                        .div_ceil(k_total)
                        .saturating_sub(1)
                        .min(n_steps - 1)
                };

                let mut fwd_buf: Vec<u8> = if is_recursive {
                    (**forward_body_init.as_ref().unwrap()).clone()
                } else {
                    Vec::new()
                };
                // Cache of recomputed carries for one segment. Entry i (cb bytes) is the
                // carry entering step `seg_start_t + i`, for i < seg_count.
                // `seg_start_t == usize::MAX` means the cache is empty.
                let mut seg_cache: Vec<u8> = Vec::new();
                let mut seg_start_t: usize = usize::MAX;
                let mut seg_count: usize = 0;
                // Writes the carry entering step `t` into `dst` (cb bytes).
                let recompute_carry_t =
                    |t: usize,
                     dst: &mut [u8],
                     fwd_buf: &mut Vec<u8>,
                     seg_cache: &mut Vec<u8>,
                     seg_start_t: &mut usize,
                     seg_count: &mut usize| {
                        if !is_recursive {
                            unsafe {
                                let src = if t == 0 {
                                    base.add(*outer_init_off)
                                } else {
                                    base.add(*outer_traj_off + (t - 1) * cb)
                                };
                                std::ptr::copy_nonoverlapping(src, dst.as_mut_ptr(), cb);
                            }
                            return;
                        }
                        // Cache hit: t lies in the segment recomputed last time.
                        if *seg_start_t != usize::MAX
                            && t >= *seg_start_t
                            && t < *seg_start_t + *seg_count
                        {
                            let off = (t - *seg_start_t) * cb;
                            dst.copy_from_slice(&seg_cache[off..off + cb]);
                            return;
                        }
                        // Cache miss: locate the segment containing t and its anchor carry
                        // (initial carry for segment 0, otherwise trajectory slot seg_k - 1).
                        let seg_k = (0..k_total)
                            .find(|&k| t <= checkpoint_t_for_k(k))
                            .unwrap_or(k_total - 1);
                        let (anchor_t, anchor_ptr): (usize, *const u8) = if seg_k == 0 {
                            (0, unsafe { base.add(*outer_init_off) as *const u8 })
                        } else {
                            let prev_ck = checkpoint_t_for_k(seg_k - 1);
                            (prev_ck + 1, unsafe {
                                base.add(*outer_traj_off + (seg_k - 1) * cb) as *const u8
                            })
                        };
                        let seg_end_t = checkpoint_t_for_k(seg_k);
                        let seg_size = seg_end_t - anchor_t + 1;

                        // Re-run the forward body from the anchor, caching every carry-in.
                        fwd_buf.copy_from_slice(forward_body_init.as_ref().unwrap());
                        unsafe {
                            std::ptr::copy_nonoverlapping(
                                anchor_ptr,
                                fwd_buf.as_mut_ptr().add(*forward_body_carry_in_off),
                                cb,
                            );
                        }
                        seg_cache.resize(seg_size * cb, 0u8);
                        seg_cache[0..cb].copy_from_slice(
                            &fwd_buf[*forward_body_carry_in_off..*forward_body_carry_in_off + cb],
                        );
                        let fb_sched = forward_body.as_ref().unwrap();
                        for i in 1..seg_size {
                            // Running step `cur_iter` yields the carry entering step
                            // `cur_iter + 1` = anchor_t + i.
                            let cur_iter = anchor_t + i - 1;
                            for (idx, fb_x_off) in forward_body_x_offs.iter().enumerate() {
                                let (outer_xs_off, x_psb) = outer_xs_offs[idx];
                                let xb = x_psb as usize;
                                unsafe {
                                    std::ptr::copy_nonoverlapping(
                                        base.add(outer_xs_off + cur_iter * xb),
                                        fwd_buf.as_mut_ptr().add(*fb_x_off),
                                        xb,
                                    );
                                }
                            }
                            execute_thunks(fb_sched, fwd_buf);
                            if *forward_body_output_off != *forward_body_carry_in_off {
                                fwd_buf.copy_within(
                                    *forward_body_output_off..*forward_body_output_off + cb,
                                    *forward_body_carry_in_off,
                                );
                            }
                            let cache_off = i * cb;
                            seg_cache[cache_off..cache_off + cb].copy_from_slice(
                                &fwd_buf
                                    [*forward_body_carry_in_off..*forward_body_carry_in_off + cb],
                            );
                        }
                        *seg_start_t = anchor_t;
                        *seg_count = seg_size;

                        let off = (t - anchor_t) * cb;
                        dst.copy_from_slice(&seg_cache[off..off + cb]);
                    };

                let mut dcarry: Vec<u8> = vec![0u8; cb];
                if !*save_trajectory {
                    unsafe {
                        std::ptr::copy_nonoverlapping(
                            base.add(*outer_upstream_off),
                            dcarry.as_mut_ptr(),
                            cb,
                        );
                    }
                }

                let mut body_buf: Vec<u8> = (**body_init).clone();

                for t in (0..n_steps).rev() {
                    if *save_trajectory {
                        // Add step t's upstream cotangent into dcarry (f32 or f64 elements).
                        // Assumes the arena offset is suitably aligned; TODO confirm.
                        unsafe {
                            let up_off = *outer_upstream_off + t * cb;
                            match *carry_elem_size {
                                4 => {
                                    let up_ptr = base.add(up_off) as *const f32;
                                    let dc_ptr = dcarry.as_mut_ptr() as *mut f32;
                                    let n_elems = cb / 4;
                                    for i in 0..n_elems {
                                        *dc_ptr.add(i) += *up_ptr.add(i);
                                    }
                                }
                                8 => {
                                    let up_ptr = base.add(up_off) as *const f64;
                                    let dc_ptr = dcarry.as_mut_ptr() as *mut f64;
                                    let n_elems = cb / 8;
                                    for i in 0..n_elems {
                                        *dc_ptr.add(i) += *up_ptr.add(i);
                                    }
                                }
                                other => panic!(
                                    "ScanBackwardXs: unsupported carry elem size {other} \
                                     (only f32/f64 carries are supported today)"
                                ),
                            }
                        }
                    }

                    // Recover the carry entering step t directly into the VJP body.
                    let carry_dst_start = *body_carry_in_off;
                    {
                        let carry_slice = &mut body_buf[carry_dst_start..carry_dst_start + cb];
                        recompute_carry_t(
                            t,
                            carry_slice,
                            &mut fwd_buf,
                            &mut seg_cache,
                            &mut seg_start_t,
                            &mut seg_count,
                        );
                    }
                    unsafe {
                        for (i, body_x_off) in body_x_offs.iter().enumerate() {
                            let (outer_xs_off, x_psb) = outer_xs_offs[i];
                            let xb = x_psb as usize;
                            std::ptr::copy_nonoverlapping(
                                base.add(outer_xs_off + t * xb),
                                body_buf.as_mut_ptr().add(*body_x_off),
                                xb,
                            );
                        }
                        std::ptr::copy_nonoverlapping(
                            dcarry.as_ptr(),
                            body_buf.as_mut_ptr().add(*body_d_output_off),
                            cb,
                        );
                    }

                    execute_thunks(body_vjp, &mut body_buf);

                    // Emit this step's xs cotangent.
                    unsafe {
                        std::ptr::copy_nonoverlapping(
                            body_buf.as_ptr().add(*body_dxs_out_off),
                            base.add(*outer_dxs_off + t * psb),
                            psb,
                        );
                    }

                    // Carry cotangent flowing into step t - 1.
                    unsafe {
                        std::ptr::copy_nonoverlapping(
                            body_buf.as_ptr().add(*body_dcarry_out_off),
                            dcarry.as_mut_ptr(),
                            cb,
                        );
                    }
                }
            }
9171
9172 Thunk::FusedMmBiasAct {
9173 a,
9174 w,
9175 bias,
9176 c,
9177 m,
9178 k,
9179 n,
9180 act,
9181 } => {
9182 let (m, k, n) = (*m as usize, *k as usize, *n as usize);
9183 unsafe {
9184 let out = sl_mut(*c, base, m * n);
9185 crate::blas::sgemm_auto(sl(*a, base, m * k), sl(*w, base, k * n), out, m, k, n);
9186 match act {
9187 Some(Activation::Gelu) => {
9188 crate::kernels::par_bias_gelu(out, sl(*bias, base, n), m, n)
9189 }
9190 Some(other) => {
9191 crate::blas::bias_add(out, sl(*bias, base, n), m, n);
9192 apply_activation_inplace(out, *other);
9193 }
9194 None => crate::blas::bias_add(out, sl(*bias, base, n), m, n),
9195 }
9196 }
9197 }
9198
9199 Thunk::FusedResidualLN {
9200 x,
9201 res,
9202 bias,
9203 g,
9204 b,
9205 out,
9206 rows,
9207 h,
9208 eps,
9209 has_bias,
9210 } => {
9211 let (rows, h) = (*rows as usize, *h as usize);
9212 unsafe {
9213 let zero = &zero_bias[..h];
9214 let bi = if *has_bias { sl(*bias, base, h) } else { zero };
9215 let x_ptr = sl(*x, base, rows * h).as_ptr() as usize;
9216 let r_ptr = sl(*res, base, rows * h).as_ptr() as usize;
9217 let o_ptr = sl_mut(*out, base, rows * h).as_mut_ptr() as usize;
9218 let bi_ptr = bi.as_ptr() as usize;
9219 let g_ptr = sl(*g, base, h).as_ptr() as usize;
9220 let b_ptr = sl(*b, base, h).as_ptr() as usize;
9221 let e = *eps;
9222 crate::pool::par_for(rows, 4, &|off, cnt| {
9223 let xs =
9224 std::slice::from_raw_parts((x_ptr as *const f32).add(off * h), cnt * h);
9225 let rs =
9226 std::slice::from_raw_parts((r_ptr as *const f32).add(off * h), cnt * h);
9227 let os = std::slice::from_raw_parts_mut(
9228 (o_ptr as *mut f32).add(off * h),
9229 cnt * h,
9230 );
9231 let bi = std::slice::from_raw_parts(bi_ptr as *const f32, h);
9232 let g = std::slice::from_raw_parts(g_ptr as *const f32, h);
9233 let b = std::slice::from_raw_parts(b_ptr as *const f32, h);
9234 crate::kernels::residual_bias_layer_norm(xs, rs, bi, g, b, os, cnt, h, e);
9235 });
9236 }
9237 }
9238
9239 Thunk::FusedResidualRmsNorm {
9240 x,
9241 res,
9242 bias,
9243 g,
9244 b,
9245 out,
9246 rows,
9247 h,
9248 eps,
9249 has_bias,
9250 } => {
9251 let (rows, h) = (*rows as usize, *h as usize);
9252 unsafe {
9253 let zero = &zero_bias[..h];
9254 let bi = if *has_bias { sl(*bias, base, h) } else { zero };
9255 let x_ptr = sl(*x, base, rows * h).as_ptr() as usize;
9256 let r_ptr = sl(*res, base, rows * h).as_ptr() as usize;
9257 let o_ptr = sl_mut(*out, base, rows * h).as_mut_ptr() as usize;
9258 let bi_ptr = bi.as_ptr() as usize;
9259 let g_ptr = sl(*g, base, h).as_ptr() as usize;
9260 let b_ptr = sl(*b, base, h).as_ptr() as usize;
9261 let e = *eps;
9262 crate::pool::par_for(rows, 4, &|off, cnt| {
9263 let xs =
9264 std::slice::from_raw_parts((x_ptr as *const f32).add(off * h), cnt * h);
9265 let rs =
9266 std::slice::from_raw_parts((r_ptr as *const f32).add(off * h), cnt * h);
9267 let os = std::slice::from_raw_parts_mut(
9268 (o_ptr as *mut f32).add(off * h),
9269 cnt * h,
9270 );
9271 let bi = std::slice::from_raw_parts(bi_ptr as *const f32, h);
9272 let g = std::slice::from_raw_parts(g_ptr as *const f32, h);
9273 let b = std::slice::from_raw_parts(b_ptr as *const f32, h);
9274 crate::kernels::residual_bias_rms_norm(xs, rs, bi, g, b, os, cnt, h, e);
9275 });
9276 }
9277 }
9278
9279 Thunk::BiasAdd {
9280 src,
9281 bias,
9282 dst,
9283 m,
9284 n,
9285 } => {
9286 let (m, n) = (*m as usize, *n as usize);
9287 let len = m * n;
9288 unsafe {
9289 let out = sl_mut(*dst, base, len);
9290 if *src != *dst {
9291 let src_ptr = base.add(*src) as *const f32;
9292 let dst_ptr = base.add(*dst) as *mut f32;
9293 if src_ptr != dst_ptr {
9294 std::ptr::copy_nonoverlapping(src_ptr, dst_ptr, len);
9295 }
9296 }
9297 crate::blas::bias_add(out, sl(*bias, base, n), m, n);
9298 }
9299 }
9300
9301 Thunk::BinaryFull {
9302 lhs,
9303 rhs,
9304 dst,
9305 len,
9306 lhs_len,
9307 rhs_len,
9308 op,
9309 out_dims_bcast,
9310 bcast_lhs_strides,
9311 bcast_rhs_strides,
9312 elem_bytes,
9313 } => {
9314 let len = *len as usize;
9315 let ll = (*lhs_len as usize).max(1);
9316 let rl = (*rhs_len as usize).max(1);
9317 let eb = (*elem_bytes).max(1) as usize;
9318 let arena_len = arena_buf.len();
9319 let ll = ll.min((arena_len.saturating_sub(*lhs)) / eb);
9320 let rl = rl.min((arena_len.saturating_sub(*rhs)) / eb);
9321 let len = len.min((arena_len.saturating_sub(*dst)) / eb);
9322 unsafe {
9323 if eb == 8 {
9324 let l = sl_i64(*lhs, base, ll);
9325 let r = sl_i64(*rhs, base, rl);
9326 let o = sl_mut_i64(*dst, base, len);
9327 if !out_dims_bcast.is_empty() {
9328 let rank = out_dims_bcast.len();
9329 let mut coords = vec![0u32; rank];
9330 for i in 0..len {
9331 let mut rem = i;
9332 for ax in (0..rank).rev() {
9333 let sz = out_dims_bcast[ax] as usize;
9334 coords[ax] = (rem % sz) as u32;
9335 rem /= sz;
9336 }
9337 let mut li = 0usize;
9338 let mut ri = 0usize;
9339 for ax in 0..rank {
9340 li += coords[ax] as usize * bcast_lhs_strides[ax] as usize;
9341 ri += coords[ax] as usize * bcast_rhs_strides[ax] as usize;
9342 }
9343 o[i] = match op {
9344 BinaryOp::Add => l[li].wrapping_add(r[ri]),
9345 BinaryOp::Sub => l[li].wrapping_sub(r[ri]),
9346 BinaryOp::Mul => l[li].wrapping_mul(r[ri]),
9347 BinaryOp::Div => {
9348 if r[ri] == 0 {
9349 0
9350 } else {
9351 l[li] / r[ri]
9352 }
9353 }
9354 BinaryOp::Max => l[li].max(r[ri]),
9355 BinaryOp::Min => l[li].min(r[ri]),
9356 BinaryOp::Pow => l[li].pow(r[ri].max(0) as u32),
9357 };
9358 }
9359 } else {
9360 for i in 0..len {
9361 let li = if ll == 1 { 0 } else { i % ll };
9362 let ri = if rl == 1 { 0 } else { i % rl };
9363 o[i] = match op {
9364 BinaryOp::Add => l[li].wrapping_add(r[ri]),
9365 BinaryOp::Sub => l[li].wrapping_sub(r[ri]),
9366 BinaryOp::Mul => l[li].wrapping_mul(r[ri]),
9367 BinaryOp::Div => {
9368 if r[ri] == 0 {
9369 0
9370 } else {
9371 l[li] / r[ri]
9372 }
9373 }
9374 BinaryOp::Max => l[li].max(r[ri]),
9375 BinaryOp::Min => l[li].min(r[ri]),
9376 BinaryOp::Pow => l[li].pow(r[ri].max(0) as u32),
9377 };
9378 }
9379 }
9380 } else {
9381 let l = sl(*lhs, base, ll);
9382 let r = sl(*rhs, base, rl);
9383 let o = sl_mut(*dst, base, len);
9384 if ll == len && rl == len {
9385 #[cfg(target_arch = "aarch64")]
9386 if matches!(op, BinaryOp::Add | BinaryOp::Mul) {
9387 use std::arch::aarch64::*;
9388 let chunks = len / 4;
9389 for c in 0..chunks {
9390 let off = c * 4;
9391 let vl = vld1q_f32(l.as_ptr().add(off));
9392 let vr = vld1q_f32(r.as_ptr().add(off));
9393 let res = match op {
9394 BinaryOp::Add => vaddq_f32(vl, vr),
9395 BinaryOp::Mul => vmulq_f32(vl, vr),
9396 _ => unreachable!(),
9397 };
9398 vst1q_f32(o.as_mut_ptr().add(off), res);
9399 }
9400 for i in (chunks * 4)..len {
9401 o[i] = match op {
9402 BinaryOp::Add => l[i] + r[i],
9403 BinaryOp::Mul => l[i] * r[i],
9404 _ => unreachable!(),
9405 };
9406 }
9407 continue;
9408 }
9409 }
9410 if !out_dims_bcast.is_empty() {
9411 let rank = out_dims_bcast.len();
9412 let mut coords = vec![0u32; rank];
9413 for i in 0..len {
9414 let mut rem = i;
9415 for ax in (0..rank).rev() {
9416 let sz = out_dims_bcast[ax] as usize;
9417 coords[ax] = (rem % sz) as u32;
9418 rem /= sz;
9419 }
9420 let mut li = 0usize;
9421 let mut ri = 0usize;
9422 for ax in 0..rank {
9423 li += coords[ax] as usize * bcast_lhs_strides[ax] as usize;
9424 ri += coords[ax] as usize * bcast_rhs_strides[ax] as usize;
9425 }
9426 o[i] = match op {
9427 BinaryOp::Add => l[li] + r[ri],
9428 BinaryOp::Sub => l[li] - r[ri],
9429 BinaryOp::Mul => l[li] * r[ri],
9430 BinaryOp::Div => l[li] / r[ri],
9431 BinaryOp::Max => l[li].max(r[ri]),
9432 BinaryOp::Min => l[li].min(r[ri]),
9433 BinaryOp::Pow => l[li].powf(r[ri]),
9434 };
9435 }
9436 } else {
9437 for i in 0..len {
9438 let li = if ll == 1 { 0 } else { i % ll };
9439 let ri = if rl == 1 { 0 } else { i % rl };
9440 o[i] = match op {
9441 BinaryOp::Add => l[li] + r[ri],
9442 BinaryOp::Sub => l[li] - r[ri],
9443 BinaryOp::Mul => l[li] * r[ri],
9444 BinaryOp::Div => l[li] / r[ri],
9445 BinaryOp::Max => l[li].max(r[ri]),
9446 BinaryOp::Min => l[li].min(r[ri]),
9447 BinaryOp::Pow => l[li].powf(r[ri]),
9448 };
9449 }
9450 }
9451 }
9452 }
9453 }
9454
9455 Thunk::Gather {
9456 table,
9457 table_len,
9458 idx,
9459 dst,
9460 num_idx,
9461 trailing,
9462 idx_i64,
9463 table_bytes,
9464 } => {
9465 let (ni, tr) = (*num_idx as usize, *trailing as usize);
9466 let rows = *table_len as usize / tr.max(1);
9467 unsafe {
9468 if *table_bytes == 8 {
9469 let tab = sl_i64(*table, base, *table_len as usize);
9470 let out = sl_mut_i64(*dst, base, ni * tr);
9471 if *idx_i64 != 0 {
9472 let ids = sl_i64(*idx, base, ni);
9473 for i in 0..ni {
9474 let row = ids[i].max(0) as usize;
9475 if row < rows {
9476 out[i * tr..(i + 1) * tr]
9477 .copy_from_slice(&tab[row * tr..(row + 1) * tr]);
9478 }
9479 }
9480 } else {
9481 let ids = sl(*idx, base, ni);
9482 for i in 0..ni {
9483 let row = ids[i] as usize;
9484 if row < rows {
9485 out[i * tr..(i + 1) * tr]
9486 .copy_from_slice(&tab[row * tr..(row + 1) * tr]);
9487 }
9488 }
9489 }
9490 } else {
9491 let tab = sl(*table, base, *table_len as usize);
9492 let out = sl_mut(*dst, base, ni * tr);
9493 if *idx_i64 != 0 {
9494 let ids = sl_i64(*idx, base, ni);
9495 for i in 0..ni {
9496 let row = ids[i].max(0) as usize;
9497 if row < rows {
9498 out[i * tr..(i + 1) * tr]
9499 .copy_from_slice(&tab[row * tr..(row + 1) * tr]);
9500 }
9501 }
9502 } else {
9503 let ids = sl(*idx, base, ni);
9504 for i in 0..ni {
9505 let row = ids[i] as usize;
9506 if row < rows {
9507 out[i * tr..(i + 1) * tr]
9508 .copy_from_slice(&tab[row * tr..(row + 1) * tr]);
9509 }
9510 }
9511 }
9512 }
9513 }
9514 }
9515
9516 Thunk::Narrow {
9517 src,
9518 dst,
9519 outer,
9520 src_stride,
9521 dst_stride,
9522 inner,
9523 elem_bytes,
9524 } => {
9525 let (outer, ss, ds, inner, eb) = (
9526 *outer as usize,
9527 *src_stride as usize,
9528 *dst_stride as usize,
9529 *inner as usize,
9530 *elem_bytes as usize,
9531 );
9532 let row_bytes = inner.saturating_mul(eb);
9533 let src_row_stride = ss.saturating_mul(eb);
9534 let dst_row_stride = ds.saturating_mul(eb);
9535 if trace_thunks {
9536 eprintln!(
9537 "[narrow] src={} dst={} outer={outer} ss={ss} ds={ds} inner={inner} eb={eb} row={row_bytes} arena={}",
9538 *src,
9539 *dst,
9540 arena_buf.len()
9541 );
9542 }
9543 if row_bytes > 0 && *src != *dst {
9544 let arena_len = arena_buf.len();
9545 for o in 0..outer {
9546 let s_off = *src + o * src_row_stride;
9547 let d_off = *dst + o * dst_row_stride;
9548 if s_off == d_off {
9549 continue;
9550 }
9551 if s_off.saturating_add(row_bytes) > arena_len
9552 || d_off.saturating_add(row_bytes) > arena_len
9553 {
9554 break;
9555 }
9556 unsafe {
9557 std::ptr::copy_nonoverlapping(
9558 base.add(s_off),
9559 base.add(d_off),
9560 row_bytes,
9561 );
9562 }
9563 }
9564 }
9565 }
9566
9567 Thunk::Copy { src, dst, len } => {
9568 let mut len = *len as usize;
9569 if *src == *dst || len == 0 {
9570 continue;
9571 }
9572 let arena_len = arena_buf.len();
9573 let max_from_src = (arena_len.saturating_sub(*src)) / 4;
9574 let max_from_dst = (arena_len.saturating_sub(*dst)) / 4;
9575 len = len.min(max_from_src).min(max_from_dst);
9576 if len == 0 {
9577 continue;
9578 }
9579 let byte_len = len.saturating_mul(4);
9580 unsafe {
9581 std::ptr::copy(base.add(*src), base.add(*dst), byte_len);
9582 }
9583 }
9584
9585 Thunk::LayerNorm {
9586 src,
9587 g,
9588 b,
9589 dst,
9590 rows,
9591 h,
9592 eps,
9593 } => {
9594 let (rows, h) = (*rows as usize, *h as usize);
9595 unsafe {
9596 let input = sl(*src, base, rows * h);
9597 let gamma = sl(*g, base, h);
9598 let beta = sl(*b, base, h);
9599 let output = sl_mut(*dst, base, rows * h);
9600 if rows >= 4 && rows * h >= 30_000 {
9602 let i_ptr = input.as_ptr() as usize;
9603 let o_ptr = output.as_mut_ptr() as usize;
9604 let g_ptr = gamma.as_ptr() as usize;
9605 let b_ptr = beta.as_ptr() as usize;
9606 let e = *eps;
9607 crate::pool::par_for(rows, 4, &|off, cnt| {
9608 let inp = std::slice::from_raw_parts(
9609 (i_ptr as *const f32).add(off * h),
9610 cnt * h,
9611 );
9612 let out = std::slice::from_raw_parts_mut(
9613 (o_ptr as *mut f32).add(off * h),
9614 cnt * h,
9615 );
9616 let g = std::slice::from_raw_parts(g_ptr as *const f32, h);
9617 let b = std::slice::from_raw_parts(b_ptr as *const f32, h);
9618 for row in 0..cnt {
9619 crate::kernels::layer_norm_row(
9620 &inp[row * h..(row + 1) * h],
9621 g,
9622 b,
9623 &mut out[row * h..(row + 1) * h],
9624 h,
9625 e,
9626 );
9627 }
9628 });
9629 } else {
9630 for row in 0..rows {
9631 crate::kernels::layer_norm_row(
9632 &input[row * h..(row + 1) * h],
9633 gamma,
9634 beta,
9635 &mut output[row * h..(row + 1) * h],
9636 h,
9637 *eps,
9638 );
9639 }
9640 }
9641 }
9642 }
9643
9644 Thunk::GroupNorm {
9645 src,
9646 g,
9647 b,
9648 dst,
9649 n,
9650 c,
9651 h,
9652 w,
9653 num_groups,
9654 eps,
9655 } => {
9656 let (n, c, h, w) = (*n as usize, *c as usize, *h as usize, *w as usize);
9657 let plane = c * h * w;
9658 unsafe {
9659 for ni in 0..n {
9660 let input = sl(*src, base.add(ni * plane), plane);
9661 let gamma = sl(*g, base, c);
9662 let beta = sl(*b, base, c);
9663 let output = sl_mut(*dst, base.add(ni * plane), plane);
9664 crate::kernels::group_norm_nchw(
9665 input,
9666 gamma,
9667 beta,
9668 output,
9669 1,
9670 c,
9671 h,
9672 w,
9673 *num_groups as usize,
9674 *eps,
9675 );
9676 }
9677 }
9678 }
9679
9680 Thunk::BatchNormInference {
9681 src,
9682 g,
9683 b,
9684 mean,
9685 var,
9686 dst,
9687 count,
9688 channels,
9689 eps,
9690 } => {
9691 let count = *count as usize;
9692 let c = *channels as usize;
9693 let n = count * c;
9694 unsafe {
9695 crate::kernels::batch_norm_inference(
9696 sl(*src, base, n),
9697 sl(*g, base, c),
9698 sl(*b, base, c),
9699 sl(*mean, base, c),
9700 sl(*var, base, c),
9701 sl_mut(*dst, base, n),
9702 c,
9703 *eps,
9704 );
9705 }
9706 }
9707
9708 Thunk::LayerNorm2d {
9709 src,
9710 g,
9711 b,
9712 dst,
9713 n,
9714 c,
9715 h,
9716 w,
9717 eps,
9718 } => {
9719 let (n, c, h, w) = (*n as usize, *c as usize, *h as usize, *w as usize);
9720 let plane = c * h * w;
9721 unsafe {
9722 let input = sl(*src, base, n * plane);
9723 let gamma = sl(*g, base, c);
9724 let beta = sl(*b, base, c);
9725 let output = sl_mut(*dst, base, n * plane);
9726 crate::kernels::layer_norm2d_nchw(input, gamma, beta, output, n, c, h, w, *eps);
9727 }
9728 }
9729
9730 Thunk::ConvTranspose2d {
9731 src,
9732 weight,
9733 dst,
9734 n,
9735 c_in,
9736 h,
9737 w_in,
9738 c_out,
9739 h_out,
9740 w_out,
9741 kh,
9742 kw,
9743 sh,
9744 sw,
9745 ph,
9746 pw,
9747 dh,
9748 dw,
9749 groups,
9750 } => {
9751 let n = *n as usize;
9752 let c_in = *c_in as usize;
9753 let h = *h as usize;
9754 let w_in = *w_in as usize;
9755 let c_out = *c_out as usize;
9756 let h_out = *h_out as usize;
9757 let w_out = *w_out as usize;
9758 unsafe {
9759 let inp = sl(*src, base, n * c_in * h * w_in);
9760 let wt = sl(
9761 *weight,
9762 base,
9763 c_in * (c_out / *groups as usize) * (*kh as usize) * (*kw as usize),
9764 );
9765 let out = sl_mut(*dst, base, n * c_out * h_out * w_out);
9766 crate::kernels::conv_transpose2d_nchw(
9767 inp,
9768 wt,
9769 out,
9770 n,
9771 c_in,
9772 h,
9773 w_in,
9774 c_out,
9775 h_out,
9776 w_out,
9777 *kh as usize,
9778 *kw as usize,
9779 *sh as usize,
9780 *sw as usize,
9781 *ph as usize,
9782 *pw as usize,
9783 *dh as usize,
9784 *dw as usize,
9785 *groups as usize,
9786 );
9787 }
9788 }
9789
9790 Thunk::ResizeNearest2x {
9791 src,
9792 dst,
9793 n,
9794 c,
9795 h,
9796 w,
9797 } => {
9798 let (n, c, h, w) = (*n as usize, *c as usize, *h as usize, *w as usize);
9799 let in_plane = c * h * w;
9800 let out_plane = c * h * 2 * w * 2;
9801 unsafe {
9802 for ni in 0..n {
9803 let input = sl(*src, base.add(ni * in_plane), in_plane);
9804 let output = sl_mut(*dst, base.add(ni * out_plane), out_plane);
9805 crate::kernels::resize_nearest_2x_nchw(input, output, c, h, w);
9806 }
9807 }
9808 }
9809
9810 Thunk::AxialRope2d {
9811 src,
9812 dst,
9813 batch,
9814 seq,
9815 hidden,
9816 end_x,
9817 end_y,
9818 head_dim,
9819 num_heads,
9820 theta,
9821 repeat_factor,
9822 } => {
9823 let b = *batch as usize;
9824 let s = *seq as usize;
9825 let hdim = *head_dim as usize;
9826 let nh = *num_heads as usize;
9827 let plane = s * (*hidden as usize);
9828 unsafe {
9829 for bi in 0..b {
9830 let input = sl(*src, base.add(bi * plane), plane);
9831 let output = sl_mut(*dst, base.add(bi * plane), plane);
9832 let rotated = rlx_ir::ops::axial_rope2d::apply_axial_rope2d(
9833 input,
9834 nh,
9835 s,
9836 hdim,
9837 *end_x as usize,
9838 *end_y as usize,
9839 *theta,
9840 *repeat_factor as usize,
9841 );
9842 output.copy_from_slice(&rotated);
9843 }
9844 }
9845 }
9846
9847 Thunk::RmsNorm {
9848 src,
9849 g,
9850 b,
9851 dst,
9852 rows,
9853 h,
9854 eps,
9855 } => {
9856 let (rows, h) = (*rows as usize, *h as usize);
9857 unsafe {
9858 let input = sl(*src, base, rows * h);
9859 let gamma = sl(*g, base, h);
9860 let beta = sl(*b, base, h);
9861 let output = sl_mut(*dst, base, rows * h);
9862 let inv_h = 1.0 / h as f32;
9863 for row in 0..rows {
9864 let in_row = &input[row * h..(row + 1) * h];
9865 let out_row = &mut output[row * h..(row + 1) * h];
9866 let mut sumsq = 0f32;
9868 for &v in in_row {
9869 sumsq += v * v;
9870 }
9871 let inv_rms = (sumsq * inv_h + *eps).sqrt().recip();
9872 for i in 0..h {
9873 out_row[i] = in_row[i] * inv_rms * gamma[i] + beta[i];
9874 }
9875 }
9876 }
9877 }
9878
9879 Thunk::Softmax { data, rows, cols } => {
9880 let (rows, cols) = (*rows as usize, *cols as usize);
9881 unsafe {
9882 crate::kernels::neon_softmax(sl_mut(*data, base, rows * cols), rows, cols);
9883 }
9884 }
9885
9886 Thunk::Cumsum {
9887 src,
9888 dst,
9889 rows,
9890 cols,
9891 exclusive,
9892 } => {
9893 let (rows, cols) = (*rows as usize, *cols as usize);
9894 unsafe {
9895 let s = sl(*src, base, rows * cols);
9896 let d = sl_mut(*dst, base, rows * cols);
9897 if *exclusive {
9898 for r in 0..rows {
9899 let mut acc = 0.0f32;
9900 for c in 0..cols {
9901 d[r * cols + c] = acc;
9902 acc += s[r * cols + c];
9903 }
9904 }
9905 } else {
9906 for r in 0..rows {
9907 let mut acc = 0.0f32;
9908 for c in 0..cols {
9909 acc += s[r * cols + c];
9910 d[r * cols + c] = acc;
9911 }
9912 }
9913 }
9914 }
9915 }
9916
9917 Thunk::Sample {
9918 logits,
9919 dst,
9920 batch,
9921 vocab,
9922 top_k,
9923 top_p,
9924 temperature,
9925 seed,
9926 } => {
9927 let (b, v) = (*batch as usize, *vocab as usize);
9928 let k = (*top_k as usize).min(v);
9929 unsafe {
9930 let lg = sl(*logits, base, b * v);
9931 let out = sl_mut(*dst, base, b);
9932 let mut rng =
9933 rlx_ir::Philox4x32::new(if *seed == 0 { 0xDEADBEEF } else { *seed });
9934 for bi in 0..b {
9935 let row = &lg[bi * v..(bi + 1) * v];
9936 out[bi] = sample_row(row, k, *top_p, *temperature, &mut rng) as f32;
9937 }
9938 }
9939 }
9940
9941 Thunk::GatedDeltaNet {
9942 q,
9943 k,
9944 v,
9945 g,
9946 beta,
9947 state,
9948 dst,
9949 batch,
9950 seq,
9951 heads,
9952 state_size,
9953 } => unsafe {
9954 execute_gated_delta_net_f32(
9955 *q,
9956 *k,
9957 *v,
9958 *g,
9959 *beta,
9960 *state,
9961 *dst,
9962 *batch as usize,
9963 *seq as usize,
9964 *heads as usize,
9965 *state_size as usize,
9966 base,
9967 );
9968 },
9969
9970 Thunk::SelectiveScan {
9971 x,
9972 delta,
9973 a,
9974 b: bp,
9975 c: cp,
9976 dst,
9977 batch,
9978 seq,
9979 hidden,
9980 state_size,
9981 } => {
9982 let (b, s, h, n) = (
9983 *batch as usize,
9984 *seq as usize,
9985 *hidden as usize,
9986 *state_size as usize,
9987 );
9988 unsafe {
9989 let xs = sl(*x, base, b * s * h);
9990 let dt = sl(*delta, base, b * s * h);
9991 let am = sl(*a, base, h * n);
9992 let bm = sl(*bp, base, b * s * n);
9993 let cm = sl(*cp, base, b * s * n);
9994 let out = sl_mut(*dst, base, b * s * h);
9995
9996 let mut state = vec![0f32; h * n];
10000 for bi in 0..b {
10001 for v in state.iter_mut() {
10003 *v = 0.0;
10004 }
10005 for si in 0..s {
10006 let x_row = &xs[bi * s * h + si * h..bi * s * h + (si + 1) * h];
10007 let dt_row = &dt[bi * s * h + si * h..bi * s * h + (si + 1) * h];
10008 let b_row = &bm[bi * s * n + si * n..bi * s * n + (si + 1) * n];
10009 let c_row = &cm[bi * s * n + si * n..bi * s * n + (si + 1) * n];
10010 let out_row = &mut out[bi * s * h + si * h..bi * s * h + (si + 1) * h];
10011
10012 for ci in 0..h {
10013 let d = dt_row[ci];
10014 let xv = x_row[ci];
10015 let mut acc = 0f32;
10016 for ni in 0..n {
10017 let da = (d * am[ci * n + ni]).exp();
10019 state[ci * n + ni] =
10020 da * state[ci * n + ni] + d * b_row[ni] * xv;
10021 acc += c_row[ni] * state[ci * n + ni];
10022 }
10023 out_row[ci] = acc;
10024 }
10025 }
10026 }
10027 }
10028 }
10029
10030 Thunk::DequantMatMul {
10031 x,
10032 w_q,
10033 scale,
10034 zp,
10035 dst,
10036 m,
10037 k,
10038 n,
10039 block_size,
10040 is_asymmetric,
10041 } => {
10042 let (m, k, n, bs) = (*m as usize, *k as usize, *n as usize, *block_size as usize);
10043 let n_blocks = k.div_ceil(bs);
10044 unsafe {
10045 let xs = sl(*x, base, m * k);
10046 let w_bytes = std::slice::from_raw_parts(base.add(*w_q) as *const i8, k * n);
10047 let scales = sl(*scale, base, n_blocks * n);
10048 let zps = if *is_asymmetric {
10049 sl(*zp, base, n_blocks * n)
10050 } else {
10051 &[][..]
10052 };
10053 let out = sl_mut(*dst, base, m * n);
10054 dequant_matmul_int8(xs, w_bytes, scales, zps, out, m, k, n, bs, *is_asymmetric);
10055 }
10056 }
10057
10058 Thunk::DequantMatMulGguf {
10059 x,
10060 w_q,
10061 dst,
10062 m,
10063 k,
10064 n,
10065 scheme,
10066 } => {
10067 let (m, k, n) = (*m as usize, *k as usize, *n as usize);
10068 let block_bytes = scheme.gguf_block_bytes() as usize;
10069 let block_elems = scheme.gguf_block_size() as usize;
10070 debug_assert!(
10071 block_bytes > 0 && block_elems > 0,
10072 "non-GGUF scheme in GGUF arm"
10073 );
10074 debug_assert!(
10075 (k * n).is_multiple_of(block_elems),
10076 "k*n={} not aligned to GGUF block size {}",
10077 k * n,
10078 block_elems
10079 );
10080 let total_bytes = (k * n) / block_elems * block_bytes;
10081 unsafe {
10082 let xs = sl(*x, base, m * k);
10083 let w_bytes_ptr = base.add(*w_q) as *const u8;
10084 let w_bytes = std::slice::from_raw_parts(w_bytes_ptr, total_bytes);
10085 let out = sl_mut(*dst, base, m * n);
10086 crate::gguf_matmul::gguf_matmul_bt(xs, w_bytes, out, m, k, n, *scheme);
10087 }
10088 }
10089
10090 Thunk::DequantMatMulInt4 {
10091 x,
10092 w_q,
10093 scale,
10094 zp,
10095 dst,
10096 m,
10097 k,
10098 n,
10099 block_size,
10100 is_asymmetric,
10101 } => {
10102 let (m, k, n, bs) = (*m as usize, *k as usize, *n as usize, *block_size as usize);
10103 let n_blocks = k.div_ceil(bs);
10104 unsafe {
10105 let xs = sl(*x, base, m * k);
10106 let w_bytes = std::slice::from_raw_parts(
10107 base.add(*w_q) as *const u8,
10108 (k * n).div_ceil(2),
10109 );
10110 let scales = sl(*scale, base, n_blocks * n);
10111 let zps = if *is_asymmetric {
10112 sl(*zp, base, n_blocks * n)
10113 } else {
10114 &[][..]
10115 };
10116 let out = sl_mut(*dst, base, m * n);
10117 dequant_matmul_int4(xs, w_bytes, scales, zps, out, m, k, n, bs, *is_asymmetric);
10118 }
10119 }
10120
10121 Thunk::DequantMatMulFp8 {
10122 x,
10123 w_q,
10124 scale,
10125 dst,
10126 m,
10127 k,
10128 n,
10129 e5m2,
10130 } => {
10131 let (m, k, n) = (*m as usize, *k as usize, *n as usize);
10132 unsafe {
10133 let xs = sl(*x, base, m * k);
10134 let w_bytes = std::slice::from_raw_parts(base.add(*w_q) as *const u8, k * n);
10135 let scales = sl(*scale, base, n);
10136 let out = sl_mut(*dst, base, m * n);
10137 dequant_matmul_fp8(xs, w_bytes, scales, out, m, k, n, *e5m2);
10138 }
10139 }
10140
10141 Thunk::DequantMatMulNvfp4 {
10142 x,
10143 w_q,
10144 scale,
10145 global_scale,
10146 dst,
10147 m,
10148 k,
10149 n,
10150 } => {
10151 let (m, k, n) = (*m as usize, *k as usize, *n as usize);
10152 let n_scale = k.div_ceil(rlx_ir::NVFP4_GROUP_SIZE) * n;
10153 unsafe {
10154 let xs = sl(*x, base, m * k);
10155 let w_bytes = std::slice::from_raw_parts(
10156 base.add(*w_q) as *const u8,
10157 (k * n).div_ceil(2),
10158 );
10159 let scale_bytes =
10160 std::slice::from_raw_parts(base.add(*scale) as *const u8, n_scale);
10161 let gs = sl(*global_scale, base, 1)[0];
10162 let out = sl_mut(*dst, base, m * n);
10163 dequant_matmul_nvfp4(xs, w_bytes, scale_bytes, gs, out, m, k, n);
10164 }
10165 }
10166
10167 Thunk::LoraMatMul {
10168 x,
10169 w,
10170 a,
10171 b,
10172 dst,
10173 m,
10174 k,
10175 n,
10176 r,
10177 scale,
10178 } => {
10179 let (m, k, n, r) = (*m as usize, *k as usize, *n as usize, *r as usize);
10180 unsafe {
10181 let xs = sl(*x, base, m * k);
10182 let ws = sl(*w, base, k * n);
10183 let a_s = sl(*a, base, k * r);
10184 let bs = sl(*b, base, r * n);
10185 let out = sl_mut(*dst, base, m * n);
10186 crate::blas::sgemm(xs, ws, out, m, k, n);
10187 let mut tmp = vec![0f32; m * r];
10188 crate::blas::sgemm(xs, a_s, &mut tmp, m, k, r);
10189 if *scale != 1.0 {
10190 for v in tmp.iter_mut() {
10191 *v *= *scale;
10192 }
10193 }
10194 crate::blas::sgemm_accumulate(&tmp, bs, out, m, r, n);
10195 }
10196 }
10197
10198 Thunk::Attention {
10199 q,
10200 k,
10201 v,
10202 mask,
10203 out,
10204 batch,
10205 seq,
10206 kv_seq,
10207 heads,
10208 head_dim,
10209 mask_kind,
10210 q_row_stride,
10211 k_row_stride,
10212 v_row_stride,
10213 bhsd,
10214 } => {
10215 let (b, q_s, k_s, nh, dh) = (
10216 *batch as usize,
10217 *seq as usize,
10218 *kv_seq as usize,
10219 *heads as usize,
10220 *head_dim as usize,
10221 );
10222 let hs = nh * dh;
10223 let (qrs, krs, vrs) = if *bhsd {
10226 (dh, dh, dh)
10227 } else {
10228 (
10229 *q_row_stride as usize,
10230 *k_row_stride as usize,
10231 *v_row_stride as usize,
10232 )
10233 };
10234 let bhsd = *bhsd;
10235 let _ = (q_row_stride, k_row_stride, v_row_stride);
10236 let scale = (dh as f32).powf(-0.5);
10237 let ss = q_s * k_s;
10238 let cfg = crate::config::RuntimeConfig::global();
10239 unsafe {
10240 let q_len = if bhsd {
10247 b * nh * q_s * dh
10248 } else {
10249 b * q_s * qrs
10250 };
10251 let k_len = if bhsd {
10252 b * nh * k_s * dh
10253 } else {
10254 b * k_s * krs
10255 };
10256 let v_len = if bhsd {
10257 b * nh * k_s * dh
10258 } else {
10259 b * k_s * vrs
10260 };
10261 let q_data = sl(*q, base, q_len);
10262 let k_data = sl(*k, base, k_len);
10263 let v_data = sl(*v, base, v_len);
10264 let mask_data: &[f32] = match mask_kind {
10265 rlx_ir::op::MaskKind::Custom => sl(*mask, base, b * k_s),
10266 rlx_ir::op::MaskKind::Bias => sl(*mask, base, b * nh * q_s * k_s),
10267 _ => &[],
10268 };
10269 let out_len = if bhsd {
10270 b * nh * q_s * dh
10271 } else {
10272 b * q_s * hs
10273 };
10274 let out_data = sl_mut(*out, base, out_len);
10275
10276 if bhsd {
10287 let scores = &mut sdpa_scores[..ss];
10288 for bi in 0..b {
10289 for hi in 0..nh {
10290 let q_head_base = bi * nh * q_s * dh + hi * q_s * dh;
10291 let k_head_base = bi * nh * k_s * dh + hi * k_s * dh;
10292 for qi in 0..q_s {
10294 let q_base = q_head_base + qi * dh;
10295 for ki in 0..k_s {
10296 let k_base = k_head_base + ki * dh;
10297 let mut dot = 0f32;
10298 for d in 0..dh {
10299 dot += q_data[q_base + d] * k_data[k_base + d];
10300 }
10301 scores[qi * k_s + ki] = dot * scale;
10302 if matches!(mask_kind, rlx_ir::op::MaskKind::Custom)
10303 && !mask_data.is_empty()
10304 && mask_data[bi * k_s + ki] < mask_thr
10305 {
10306 scores[qi * k_s + ki] = mask_neg;
10307 }
10308 }
10309 }
10310 if matches!(mask_kind, rlx_ir::op::MaskKind::Bias) {
10311 let off = (bi * nh + hi) * q_s * k_s;
10312 for i in 0..q_s * k_s {
10313 scores[i] += mask_data[off + i];
10314 }
10315 }
10316 apply_synthetic_mask(scores, q_s, k_s, *mask_kind);
10317 crate::kernels::neon_softmax(scores, q_s, k_s);
10318 for qi in 0..q_s {
10320 let o_base = q_head_base + qi * dh;
10321 for d in 0..dh {
10322 out_data[o_base + d] = 0.0;
10323 }
10324 for ki in 0..k_s {
10325 let sc = scores[qi * k_s + ki];
10326 if sc > score_thr {
10327 let v_base = k_head_base + ki * dh;
10328 for d in 0..dh {
10329 out_data[o_base + d] += sc * v_data[v_base + d];
10330 }
10331 }
10332 }
10333 }
10334 }
10335 }
10336 continue;
10337 }
10338
10339 if b == 1 && q_s.max(k_s) <= cfg.sdpa_seq_threshold {
10346 let scores = &mut sdpa_scores[..ss];
10348 #[cfg(target_arch = "aarch64")]
10349 let neon_chunks = dh / 4;
10350
10351 for bi in 0..b {
10352 for hi in 0..nh {
10353 for qi in 0..q_s {
10355 let q_off = bi * q_s * qrs + qi * qrs + hi * dh;
10356 for ki in 0..k_s {
10357 let k_off = bi * k_s * krs + ki * krs + hi * dh;
10358 #[cfg(target_arch = "aarch64")]
10359 let mut dot;
10360 #[cfg(not(target_arch = "aarch64"))]
10361 let mut dot = 0f32;
10362 #[cfg(target_arch = "aarch64")]
10363 {
10364 use std::arch::aarch64::*;
10365 let mut acc = vdupq_n_f32(0.0);
10366 for c in 0..neon_chunks {
10367 let vq =
10368 vld1q_f32(q_data.as_ptr().add(q_off + c * 4));
10369 let vk =
10370 vld1q_f32(k_data.as_ptr().add(k_off + c * 4));
10371 acc = vfmaq_f32(acc, vq, vk);
10372 }
10373 dot = vaddvq_f32(acc);
10374 for d in (neon_chunks * 4)..dh {
10375 dot += q_data[q_off + d] * k_data[k_off + d];
10376 }
10377 }
10378 #[cfg(not(target_arch = "aarch64"))]
10379 for d in 0..dh {
10380 dot += q_data[q_off + d] * k_data[k_off + d];
10381 }
10382 scores[qi * k_s + ki] = dot * scale;
10383 if matches!(mask_kind, rlx_ir::op::MaskKind::Custom)
10390 && !mask_data.is_empty()
10391 && mask_data[bi * k_s + ki] < mask_thr
10392 {
10393 scores[qi * k_s + ki] = mask_neg;
10394 }
10395 }
10396 }
10397
10398 if matches!(mask_kind, rlx_ir::op::MaskKind::Bias) {
10399 let off = (bi * nh + hi) * q_s * k_s;
10400 for i in 0..q_s * k_s {
10401 scores[i] += mask_data[off + i];
10402 }
10403 }
10404 apply_synthetic_mask(scores, q_s, k_s, *mask_kind);
10405 crate::kernels::neon_softmax(scores, q_s, k_s);
10406
10407 for qi in 0..q_s {
10409 let o_off = bi * q_s * hs + qi * hs + hi * dh;
10410 for d in 0..dh {
10412 out_data[o_off + d] = 0.0;
10413 }
10414 for ki in 0..k_s {
10415 let sc = scores[qi * k_s + ki];
10416 if sc > score_thr {
10417 let v_off = bi * k_s * vrs + ki * vrs + hi * dh;
10418 #[cfg(target_arch = "aarch64")]
10419 {
10420 use std::arch::aarch64::*;
10421 let vsc = vdupq_n_f32(sc);
10422 for c in 0..neon_chunks {
10423 let off = c * 4;
10424 let vo = vld1q_f32(
10425 out_data.as_ptr().add(o_off + off),
10426 );
10427 let vv =
10428 vld1q_f32(v_data.as_ptr().add(v_off + off));
10429 vst1q_f32(
10430 out_data.as_mut_ptr().add(o_off + off),
10431 vfmaq_f32(vo, vsc, vv),
10432 );
10433 }
10434 }
10435 #[cfg(not(target_arch = "aarch64"))]
10436 for d in 0..dh {
10437 out_data[o_off + d] += sc * v_data[v_off + d];
10438 }
10439 }
10440 }
10441 }
10442 }
10443 }
10444 } else {
10445 let total_work = b * nh;
10447 let q_addr = q_data.as_ptr() as usize;
10448 let k_addr = k_data.as_ptr() as usize;
10449 let v_addr = v_data.as_ptr() as usize;
10450 let m_addr = mask_data.as_ptr() as usize;
10451 let o_addr = out_data.as_mut_ptr() as usize;
10452 let sc_addr = sdpa_scores.as_mut_ptr() as usize;
10453
10454 crate::pool::par_for(total_work, 1, &|off, cnt| {
10455 for idx in off..off + cnt {
10456 let bi = idx / nh;
10457 let hi = idx % nh;
10458
10459 let q_start = (q_addr as *const f32).add(bi * q_s * qrs + hi * dh);
10460 let k_start = (k_addr as *const f32).add(bi * k_s * krs + hi * dh);
10461 let v_start = (v_addr as *const f32).add(bi * k_s * vrs + hi * dh);
10462 let o_start = (o_addr as *mut f32).add(bi * q_s * hs + hi * dh);
10463 let sc = std::slice::from_raw_parts_mut(
10464 (sc_addr as *mut f32).add(idx * ss),
10465 ss,
10466 );
10467
10468 crate::blas::sgemm_general(
10471 q_start,
10472 k_start,
10473 sc.as_mut_ptr(),
10474 q_s,
10475 k_s,
10476 dh,
10477 scale,
10478 0.0,
10479 qrs,
10480 krs,
10481 k_s,
10482 false,
10483 true,
10484 );
10485
10486 match mask_kind {
10487 rlx_ir::op::MaskKind::Custom => {
10488 let mask_bi = std::slice::from_raw_parts(
10489 (m_addr as *const f32).add(bi * k_s),
10490 k_s,
10491 );
10492 for ki in 0..k_s {
10493 if mask_bi[ki] < mask_thr {
10494 for qi in 0..q_s {
10495 sc[qi * k_s + ki] = mask_neg;
10496 }
10497 }
10498 }
10499 }
10500 rlx_ir::op::MaskKind::Bias => {
10501 let bias = std::slice::from_raw_parts(
10503 (m_addr as *const f32).add((bi * nh + hi) * q_s * k_s),
10504 q_s * k_s,
10505 );
10506 for i in 0..q_s * k_s {
10507 sc[i] += bias[i];
10508 }
10509 }
10510 _ => apply_synthetic_mask(sc, q_s, k_s, *mask_kind),
10511 }
10512
10513 crate::kernels::neon_softmax(sc, q_s, k_s);
10514
10515 crate::blas::sgemm_general(
10519 sc.as_ptr(),
10520 v_start,
10521 o_start,
10522 q_s,
10523 dh,
10524 k_s,
10525 1.0,
10526 0.0,
10527 k_s,
10528 vrs,
10529 hs,
10530 false,
10531 false,
10532 );
10533 }
10534 });
10535 }
10536 }
10537 }
10538
10539 Thunk::AttentionBackward {
10540 q,
10541 k,
10542 v,
10543 dy,
10544 mask,
10545 out,
10546 batch,
10547 seq,
10548 kv_seq,
10549 heads,
10550 head_dim,
10551 mask_kind,
10552 wrt,
10553 bhsd,
10554 } => {
10555 let (b, q_s, k_s, nh, dh) = (
10556 *batch as usize,
10557 *seq as usize,
10558 *kv_seq as usize,
10559 *heads as usize,
10560 *head_dim as usize,
10561 );
10562 unsafe {
10563 let q_len = if *bhsd {
10564 b * nh * q_s * dh
10565 } else {
10566 b * q_s * nh * dh
10567 };
10568 let k_len = if *bhsd {
10569 b * nh * k_s * dh
10570 } else {
10571 b * k_s * nh * dh
10572 };
10573 let out_len = match wrt {
10574 rlx_ir::op::AttentionBwdWrt::Key | rlx_ir::op::AttentionBwdWrt::Value => {
10575 k_len
10576 }
10577 rlx_ir::op::AttentionBwdWrt::Query => q_len,
10578 };
10579 let q_data = sl(*q, base, q_len);
10580 let k_data = sl(*k, base, k_len);
10581 let v_data = sl(*v, base, k_len);
10582 let dy_data = sl(*dy, base, q_len);
10583 let out_data = sl_mut(*out, base, out_len);
10584 let mask_data: &[f32] = if *mask != 0 {
10585 let ml = match mask_kind {
10586 rlx_ir::op::MaskKind::Custom => b * k_s,
10587 rlx_ir::op::MaskKind::Bias => b * nh * q_s * k_s,
10588 _ => 0,
10589 };
10590 sl(*mask, base, ml)
10591 } else {
10592 &[]
10593 };
10594 crate::attention_bwd::attention_backward(
10595 *wrt, q_data, k_data, v_data, dy_data, out_data, b, nh, q_s, k_s, dh,
10596 *mask_kind, mask_data, *bhsd,
10597 );
10598 }
10599 }
10600
10601 Thunk::ActivationInPlace { data, len, act } => {
10602 let len = *len as usize;
10603 unsafe {
10604 let d = sl_mut(*data, base, len);
10605 match act {
10606 Activation::Gelu => crate::kernels::par_gelu_inplace(d),
10607 Activation::GeluApprox => crate::kernels::par_gelu_approx_inplace(d),
10608 Activation::Silu => crate::kernels::par_silu_inplace(d),
10609 Activation::Relu => {
10610 for v in d.iter_mut() {
10611 *v = v.max(0.0);
10612 }
10613 }
10614 Activation::Sigmoid => {
10615 for v in d.iter_mut() {
10616 *v = 1.0 / (1.0 + (-*v).exp());
10617 }
10618 }
10619 Activation::Tanh => {
10620 for v in d.iter_mut() {
10621 *v = v.tanh();
10622 }
10623 }
10624 Activation::Exp => {
10625 for v in d.iter_mut() {
10626 *v = v.exp();
10627 }
10628 }
10629 Activation::Log => {
10630 for v in d.iter_mut() {
10631 *v = v.ln();
10632 }
10633 }
10634 Activation::Sqrt => {
10635 for v in d.iter_mut() {
10636 *v = v.sqrt();
10637 }
10638 }
10639 Activation::Rsqrt => {
10640 for v in d.iter_mut() {
10641 *v = 1.0 / v.sqrt();
10642 }
10643 }
10644 Activation::Neg => {
10645 for v in d.iter_mut() {
10646 *v = -*v;
10647 }
10648 }
10649 Activation::Abs => {
10650 for v in d.iter_mut() {
10651 *v = v.abs();
10652 }
10653 }
10654 Activation::Round => {
10655 for v in d.iter_mut() {
10656 *v = v.round();
10657 }
10658 }
10659 Activation::Sin => {
10660 for v in d.iter_mut() {
10661 *v = v.sin();
10662 }
10663 }
10664 Activation::Cos => {
10665 for v in d.iter_mut() {
10666 *v = v.cos();
10667 }
10668 }
10669 Activation::Tan => {
10670 for v in d.iter_mut() {
10671 *v = v.tan();
10672 }
10673 }
10674 Activation::Atan => {
10675 for v in d.iter_mut() {
10676 *v = v.atan();
10677 }
10678 }
10679 }
10680 }
10681 }
10682
10683 Thunk::FusedAttnBlock {
10684 hidden,
10685 qkv_w,
10686 out_w,
10687 mask,
10688 out,
10689 qkv_b,
10690 out_b,
10691 cos,
10692 sin,
10693 cos_len,
10694 batch,
10695 seq,
10696 hs,
10697 nh,
10698 dh,
10699 has_bias,
10700 has_rope,
10701 } => {
10702 let (b, s) = (*batch as usize, *seq as usize);
10703 let (h, n_h, d_h) = (*hs as usize, *nh as usize, *dh as usize);
10704 let m = b * s;
10705 let scale = (d_h as f32).powf(-0.5);
10706 let half = d_h / 2;
10707 unsafe {
10708 let inp = sl(*hidden, base, m * h);
10709 let wq = sl(*qkv_w, base, h * 3 * h);
10710 let wo = sl(*out_w, base, h * h);
10711 let mk = sl(*mask, base, b * s);
10712 let dst = sl_mut(*out, base, m * h);
10713
10714 let mut qkv = vec![0f32; m * 3 * h];
10716 let mut attn_out = vec![0f32; m * h];
10717 let mut scores_buf = vec![0f32; s * s]; crate::blas::sgemm(inp, wq, &mut qkv, m, h, 3 * h);
10721 if *has_bias {
10722 let bias = sl(*qkv_b, base, 3 * h);
10723 crate::blas::bias_add(&mut qkv, bias, m, 3 * h);
10724 }
10725
10726 #[cfg(target_arch = "aarch64")]
10729 let neon_chunks = d_h / 4;
10730 #[cfg(target_arch = "aarch64")]
10731 let _rope_chunks = half / 4;
10732
10733 for bi in 0..b {
10734 for hi in 0..n_h {
10735 for qi in 0..s {
10737 let q_base = bi * s * 3 * h + qi * 3 * h + hi * d_h;
10738 for ki in 0..s {
10739 let k_base = bi * s * 3 * h + ki * 3 * h + h + hi * d_h;
10740 let mut dot = 0f32;
10741
10742 if *has_rope {
10743 let q_cos = qi * half;
10745 let k_cos = ki * half;
10746 let cos_tab = sl(*cos, base, *cos_len as usize);
10747 let sin_tab = sl(*sin, base, *cos_len as usize);
10748 for i in 0..half {
10751 let q1 = qkv[q_base + i];
10752 let q2 = qkv[q_base + half + i];
10753 let k1 = qkv[k_base + i];
10754 let k2 = qkv[k_base + half + i];
10755 let c_q = cos_tab[q_cos + i];
10756 let s_q = sin_tab[q_cos + i];
10757 let c_k = cos_tab[k_cos + i];
10758 let s_k = sin_tab[k_cos + i];
10759 let qr1 = q1 * c_q - q2 * s_q;
10760 let kr1 = k1 * c_k - k2 * s_k;
10761 let qr2 = q2 * c_q + q1 * s_q;
10762 let kr2 = k2 * c_k + k1 * s_k;
10763 dot += qr1 * kr1 + qr2 * kr2;
10764 }
10765 } else {
10766 #[cfg(target_arch = "aarch64")]
10768 {
10769 use std::arch::aarch64::*;
10770 let mut acc = vdupq_n_f32(0.0);
10771 for c in 0..neon_chunks {
10772 let vq =
10773 vld1q_f32(qkv.as_ptr().add(q_base + c * 4));
10774 let vk =
10775 vld1q_f32(qkv.as_ptr().add(k_base + c * 4));
10776 acc = vfmaq_f32(acc, vq, vk);
10777 }
10778 dot = vaddvq_f32(acc);
10779 for d in (neon_chunks * 4)..d_h {
10780 dot += qkv[q_base + d] * qkv[k_base + d];
10781 }
10782 }
10783 #[cfg(not(target_arch = "aarch64"))]
10784 for d in 0..d_h {
10785 dot += qkv[q_base + d] * qkv[k_base + d];
10786 }
10787 }
10788
10789 scores_buf[qi * s + ki] = dot * scale;
10790 if mk[bi * s + ki] < mask_thr {
10791 scores_buf[qi * s + ki] = mask_neg;
10792 }
10793 }
10794 }
10795
10796 crate::kernels::neon_softmax(&mut scores_buf[..s * s], s, s);
10798
10799 for qi in 0..s {
10801 let o_base = bi * s * h + qi * h + hi * d_h;
10802 for d in 0..d_h {
10803 attn_out[o_base + d] = 0.0;
10804 }
10805 for ki in 0..s {
10806 let sc = scores_buf[qi * s + ki];
10807 if sc > score_thr {
10808 let v_base = bi * s * 3 * h + ki * 3 * h + 2 * h + hi * d_h;
10809 #[cfg(target_arch = "aarch64")]
10810 {
10811 use std::arch::aarch64::*;
10812 let vsc = vdupq_n_f32(sc);
10813 for c in 0..neon_chunks {
10814 let off = c * 4;
10815 let vo =
10816 vld1q_f32(attn_out.as_ptr().add(o_base + off));
10817 let vv = vld1q_f32(qkv.as_ptr().add(v_base + off));
10818 vst1q_f32(
10819 attn_out.as_mut_ptr().add(o_base + off),
10820 vfmaq_f32(vo, vsc, vv),
10821 );
10822 }
10823 }
10824 #[cfg(not(target_arch = "aarch64"))]
10825 for d in 0..d_h {
10826 attn_out[o_base + d] += sc * qkv[v_base + d];
10827 }
10828 }
10829 }
10830 }
10831 }
10832 }
10833
10834 crate::blas::sgemm(&attn_out, wo, dst, m, h, h);
10836 if *has_bias {
10837 let bias = sl(*out_b, base, h);
10838 crate::blas::bias_add(dst, bias, m, h);
10839 }
10840 }
10841 }
10842
10843 Thunk::Rope {
10844 src,
10845 cos,
10846 sin,
10847 dst,
10848 batch,
10849 seq,
10850 hidden,
10851 head_dim,
10852 n_rot,
10853 cos_len,
10854 src_row_stride,
10855 } => {
10856 let (b, s, hs, dh, nr) = (
10857 *batch as usize,
10858 *seq as usize,
10859 *hidden as usize,
10860 *head_dim as usize,
10861 *n_rot as usize,
10862 );
10863 let tab_half = dh / 2;
10864 let rot_half = nr / 2;
10865 let nh = hs / dh;
10866 let cl = *cos_len as usize;
10867 let src_rs = *src_row_stride as usize;
10868 unsafe {
10869 let x = sl(*src, base, b * s * src_rs);
10870 let cos_tab = sl(*cos, base, cl);
10871 let sin_tab = sl(*sin, base, cl);
10872 let out = sl_mut(*dst, base, b * s * hs);
10873
10874 let total = b * s;
10875 let x_ptr = x.as_ptr() as usize;
10876 let o_ptr = out.as_mut_ptr() as usize;
10877 let c_ptr = cos_tab.as_ptr() as usize;
10878 let s_ptr = sin_tab.as_ptr() as usize;
10879
10880 crate::pool::par_for(total, 4, &|off, cnt| {
10881 for idx in off..off + cnt {
10882 let bi = idx / s;
10883 let si = idx % s;
10884 let tab_off = si * tab_half;
10885
10886 for hi in 0..nh {
10887 let src_base = bi * s * src_rs + si * src_rs + hi * dh;
10888 let dst_base = bi * s * hs + si * hs + hi * dh;
10889 let xp = (x_ptr as *const f32).add(src_base);
10890 let op = (o_ptr as *mut f32).add(dst_base);
10891 let cp = (c_ptr as *const f32).add(tab_off);
10892 let sp = (s_ptr as *const f32).add(tab_off);
10893
10894 for i in 0..rot_half {
10895 let x1 = *xp.add(i);
10896 let x2 = *xp.add(rot_half + i);
10897 let cv = *cp.add(i);
10898 let sv = *sp.add(i);
10899 *op.add(i) = x1 * cv - x2 * sv;
10900 *op.add(rot_half + i) = x2 * cv + x1 * sv;
10901 }
10902 for j in nr..dh {
10903 *op.add(j) = *xp.add(j);
10904 }
10905 }
10906 }
10907 });
10908 }
10909 }
10910 Thunk::FusedBertLayer {
10911 hidden,
10912 qkv_w,
10913 qkv_b,
10914 out_w,
10915 out_b,
10916 mask,
10917 ln1_g,
10918 ln1_b,
10919 eps1,
10920 fc1_w,
10921 fc1_b,
10922 fc2_w,
10923 fc2_b,
10924 ln2_g,
10925 ln2_b,
10926 eps2,
10927 out,
10928 batch,
10929 seq,
10930 hs,
10931 nh,
10932 dh,
10933 int_dim,
10934 } => {
10935 let (b, s, h, n_h, d_h) = (
10936 *batch as usize,
10937 *seq as usize,
10938 *hs as usize,
10939 *nh as usize,
10940 *dh as usize,
10941 );
10942 let m = b * s;
10943 let id = *int_dim as usize;
10944 let scale = (d_h as f32).powf(-0.5);
10945 let _half = d_h / 2;
10946 #[cfg(target_arch = "aarch64")]
10947 let neon_chunks = d_h / 4;
10948 unsafe {
10949 let inp = sl(*hidden, base, m * h);
10950 let dst = sl_mut(*out, base, m * h);
10951 let mk = sl(*mask, base, b * s);
10952
10953 let qkv = std::slice::from_raw_parts_mut(fl_qkv.as_mut_ptr(), m * 3 * h);
10955 let attn = std::slice::from_raw_parts_mut(fl_attn.as_mut_ptr(), m * h);
10956 let res = std::slice::from_raw_parts_mut(fl_res.as_mut_ptr(), m * h);
10957 let normed = std::slice::from_raw_parts_mut(fl_normed.as_mut_ptr(), m * h);
10958 let ffn = std::slice::from_raw_parts_mut(fl_ffn.as_mut_ptr(), m * id);
10959 let sc = std::slice::from_raw_parts_mut(fl_sc.as_mut_ptr(), s * s);
10960
10961 crate::blas::par_sgemm_bias(
10963 inp,
10964 sl(*qkv_w, base, h * 3 * h),
10965 sl(*qkv_b, base, 3 * h),
10966 qkv,
10967 m,
10968 h,
10969 3 * h,
10970 );
10971
10972 for bi in 0..b {
10974 for hi in 0..n_h {
10975 for qi in 0..s {
10976 for ki in 0..s {
10977 let q_base = bi * s * 3 * h + qi * 3 * h + hi * d_h;
10978 let k_base = bi * s * 3 * h + ki * 3 * h + h + hi * d_h;
10979 #[cfg(target_arch = "aarch64")]
10980 let dot;
10981 #[cfg(not(target_arch = "aarch64"))]
10982 let mut dot = 0f32;
10983 #[cfg(target_arch = "aarch64")]
10984 {
10985 use std::arch::aarch64::*;
10986 let mut acc = vdupq_n_f32(0.0);
10987 for c in 0..neon_chunks {
10988 acc = vfmaq_f32(
10989 acc,
10990 vld1q_f32(qkv.as_ptr().add(q_base + c * 4)),
10991 vld1q_f32(qkv.as_ptr().add(k_base + c * 4)),
10992 );
10993 }
10994 dot = vaddvq_f32(acc);
10995 }
10996 #[cfg(not(target_arch = "aarch64"))]
10997 for d in 0..d_h {
10998 dot += qkv[q_base + d] * qkv[k_base + d];
10999 }
11000 sc[qi * s + ki] = dot * scale;
11001 if mk[bi * s + ki] < mask_thr {
11002 sc[qi * s + ki] = mask_neg;
11003 }
11004 }
11005 }
11006 crate::kernels::neon_softmax(&mut sc[..s * s], s, s);
11007 for qi in 0..s {
11008 let o = bi * s * h + qi * h + hi * d_h;
11009 for d in 0..d_h {
11010 attn[o + d] = 0.0;
11011 }
11012 for ki in 0..s {
11013 let w = sc[qi * s + ki];
11014 if w > score_thr {
11015 let v = bi * s * 3 * h + ki * 3 * h + 2 * h + hi * d_h;
11016 #[cfg(target_arch = "aarch64")]
11017 {
11018 use std::arch::aarch64::*;
11019 let vw = vdupq_n_f32(w);
11020 for c in 0..neon_chunks {
11021 let off = c * 4;
11022 vst1q_f32(
11023 attn.as_mut_ptr().add(o + off),
11024 vfmaq_f32(
11025 vld1q_f32(attn.as_ptr().add(o + off)),
11026 vw,
11027 vld1q_f32(qkv.as_ptr().add(v + off)),
11028 ),
11029 );
11030 }
11031 }
11032 #[cfg(not(target_arch = "aarch64"))]
11033 for d in 0..d_h {
11034 attn[o + d] += w * qkv[v + d];
11035 }
11036 }
11037 }
11038 }
11039 }
11040 }
11041
11042 crate::blas::sgemm_bias(
11044 attn,
11045 sl(*out_w, base, h * h),
11046 sl(*out_b, base, h),
11047 res,
11048 m,
11049 h,
11050 h,
11051 );
11052 #[cfg(target_arch = "aarch64")]
11053 {
11054 use std::arch::aarch64::*;
11055 let chunks_h = (m * h) / 4;
11056 for c in 0..chunks_h {
11057 let off = c * 4;
11058 vst1q_f32(
11059 res.as_mut_ptr().add(off),
11060 vaddq_f32(
11061 vld1q_f32(res.as_ptr().add(off)),
11062 vld1q_f32(inp.as_ptr().add(off)),
11063 ),
11064 );
11065 }
11066 for i in (chunks_h * 4)..(m * h) {
11067 res[i] += inp[i];
11068 }
11069 }
11070 #[cfg(not(target_arch = "aarch64"))]
11071 for i in 0..m * h {
11072 res[i] += inp[i];
11073 }
11074
11075 let g1 = sl(*ln1_g, base, h);
11077 let b1 = sl(*ln1_b, base, h);
11078 for r in 0..m {
11079 crate::kernels::layer_norm_row(
11080 &res[r * h..(r + 1) * h],
11081 g1,
11082 b1,
11083 &mut normed[r * h..(r + 1) * h],
11084 h,
11085 *eps1,
11086 );
11087 }
11088
11089 crate::blas::par_sgemm_bias(
11091 normed,
11092 sl(*fc1_w, base, h * id),
11093 sl(*fc1_b, base, id),
11094 ffn,
11095 m,
11096 h,
11097 id,
11098 );
11099 crate::kernels::par_gelu_inplace(ffn);
11100
11101 crate::blas::par_sgemm_bias(
11103 ffn,
11104 sl(*fc2_w, base, id * h),
11105 sl(*fc2_b, base, h),
11106 res,
11107 m,
11108 id,
11109 h,
11110 );
11111 #[cfg(target_arch = "aarch64")]
11112 {
11113 use std::arch::aarch64::*;
11114 let chunks_h = (m * h) / 4;
11115 for c in 0..chunks_h {
11116 let off = c * 4;
11117 vst1q_f32(
11118 res.as_mut_ptr().add(off),
11119 vaddq_f32(
11120 vld1q_f32(res.as_ptr().add(off)),
11121 vld1q_f32(normed.as_ptr().add(off)),
11122 ),
11123 );
11124 }
11125 for i in (chunks_h * 4)..(m * h) {
11126 res[i] += normed[i];
11127 }
11128 }
11129 #[cfg(not(target_arch = "aarch64"))]
11130 for i in 0..m * h {
11131 res[i] += normed[i];
11132 }
11133
11134 let g2 = sl(*ln2_g, base, h);
11136 let b2 = sl(*ln2_b, base, h);
11137 for r in 0..m {
11138 crate::kernels::layer_norm_row(
11139 &res[r * h..(r + 1) * h],
11140 g2,
11141 b2,
11142 &mut dst[r * h..(r + 1) * h],
11143 h,
11144 *eps2,
11145 );
11146 }
11147 }
11148 }
11149
11150 Thunk::FusedNomicLayer {
11151 hidden,
11152 qkv_w,
11153 out_w,
11154 mask,
11155 cos,
11156 sin,
11157 cos_len,
11158 ln1_g,
11159 ln1_b,
11160 eps1,
11161 fc11_w,
11162 fc12_w: _,
11163 fc2_w,
11164 ln2_g,
11165 ln2_b,
11166 eps2,
11167 out,
11168 batch,
11169 seq,
11170 hs,
11171 nh,
11172 dh,
11173 int_dim,
11174 } => {
11175 let (b, s, h, n_h, d_h) = (
11176 *batch as usize,
11177 *seq as usize,
11178 *hs as usize,
11179 *nh as usize,
11180 *dh as usize,
11181 );
11182 let m = b * s;
11183 let id = *int_dim as usize;
11184 let scale = (d_h as f32).powf(-0.5);
11185 let half_dh = d_h / 2;
11186 #[cfg(target_arch = "aarch64")]
11187 let neon_chunks = d_h / 4;
11188 unsafe {
11189 let inp = sl(*hidden, base, m * h);
11190 let dst = sl_mut(*out, base, m * h);
11191 let mk = sl(*mask, base, b * s);
11192 let cos_tab = sl(*cos, base, *cos_len as usize);
11193 let sin_tab = sl(*sin, base, *cos_len as usize);
11194 let fused_fc_w = sl(*fc11_w, base, h * 2 * id);
11196
11197 let mut qkv = vec![0f32; m * 3 * h];
11198 let mut attn = vec![0f32; m * h];
11199 let mut res = vec![0f32; m * h];
11200 let mut normed = vec![0f32; m * h];
11201 let mut ffn_concat = vec![0f32; m * 2 * id]; let mut sc = vec![0f32; s * s];
11203
11204 crate::blas::sgemm(inp, sl(*qkv_w, base, h * 3 * h), &mut qkv, m, h, 3 * h);
11206
11207 for bi in 0..b {
11209 for hi in 0..n_h {
11210 for qi in 0..s {
11211 for ki in 0..s {
11212 let q_base = bi * s * 3 * h + qi * 3 * h + hi * d_h;
11213 let k_base = bi * s * 3 * h + ki * 3 * h + h + hi * d_h;
11214 let mut dot = 0f32;
11215 for i in 0..half_dh {
11216 let q1 = qkv[q_base + i];
11217 let q2 = qkv[q_base + half_dh + i];
11218 let k1 = qkv[k_base + i];
11219 let k2 = qkv[k_base + half_dh + i];
11220 let cq = cos_tab[qi * half_dh + i];
11221 let sq = sin_tab[qi * half_dh + i];
11222 let ck = cos_tab[ki * half_dh + i];
11223 let sk = sin_tab[ki * half_dh + i];
11224 dot += (q1 * cq - q2 * sq) * (k1 * ck - k2 * sk)
11225 + (q2 * cq + q1 * sq) * (k2 * ck + k1 * sk);
11226 }
11227 sc[qi * s + ki] = dot * scale;
11228 if mk[bi * s + ki] < mask_thr {
11229 sc[qi * s + ki] = mask_neg;
11230 }
11231 }
11232 }
11233 crate::kernels::neon_softmax(&mut sc[..s * s], s, s);
11234 for qi in 0..s {
11235 let o = bi * s * h + qi * h + hi * d_h;
11236 for d in 0..d_h {
11237 attn[o + d] = 0.0;
11238 }
11239 for ki in 0..s {
11240 let w = sc[qi * s + ki];
11241 if w > score_thr {
11242 let v = bi * s * 3 * h + ki * 3 * h + 2 * h + hi * d_h;
11243 #[cfg(target_arch = "aarch64")]
11244 {
11245 use std::arch::aarch64::*;
11246 let vw = vdupq_n_f32(w);
11247 for c in 0..neon_chunks {
11248 let off = c * 4;
11249 vst1q_f32(
11250 attn.as_mut_ptr().add(o + off),
11251 vfmaq_f32(
11252 vld1q_f32(attn.as_ptr().add(o + off)),
11253 vw,
11254 vld1q_f32(qkv.as_ptr().add(v + off)),
11255 ),
11256 );
11257 }
11258 }
11259 #[cfg(not(target_arch = "aarch64"))]
11260 for d in 0..d_h {
11261 attn[o + d] += w * qkv[v + d];
11262 }
11263 }
11264 }
11265 }
11266 }
11267 }
11268
11269 crate::blas::sgemm(&attn, sl(*out_w, base, h * h), &mut res, m, h, h);
11271 for i in 0..m * h {
11272 res[i] += inp[i];
11273 }
11274
11275 let g1 = sl(*ln1_g, base, h);
11277 let b1 = sl(*ln1_b, base, h);
11278 for r in 0..m {
11279 crate::kernels::layer_norm_row(
11280 &res[r * h..(r + 1) * h],
11281 g1,
11282 b1,
11283 &mut normed[r * h..(r + 1) * h],
11284 h,
11285 *eps1,
11286 );
11287 }
11288
11289 crate::blas::sgemm(&normed, fused_fc_w, &mut ffn_concat, m, h, 2 * id);
11291 for row in 0..m {
11294 let bo = row * 2 * id;
11295 for j in 0..id {
11297 let x = ffn_concat[bo + id + j];
11298 ffn_concat[bo + id + j] = x / (1.0 + (-x).exp());
11299 }
11300 for j in 0..id {
11302 ffn_concat[bo + j] *= ffn_concat[bo + id + j];
11303 }
11304 }
11305
11306 crate::blas::sgemm_general(
11311 ffn_concat.as_ptr(),
11312 sl(*fc2_w, base, id * h).as_ptr(),
11313 res.as_mut_ptr(),
11314 m,
11315 h,
11316 id,
11317 1.0,
11318 0.0,
11319 2 * id,
11320 h,
11321 h,
11322 false,
11323 false,
11324 );
11325 for i in 0..m * h {
11326 res[i] += normed[i];
11327 }
11328
11329 let g2 = sl(*ln2_g, base, h);
11331 let b2 = sl(*ln2_b, base, h);
11332 for r in 0..m {
11333 crate::kernels::layer_norm_row(
11334 &res[r * h..(r + 1) * h],
11335 g2,
11336 b2,
11337 &mut dst[r * h..(r + 1) * h],
11338 h,
11339 *eps2,
11340 );
11341 }
11342 }
11343 }
11344
11345 Thunk::FusedSwiGLU {
11346 src,
11347 dst,
11348 n_half,
11349 total,
11350 gate_first,
11351 } => {
11352 let n = *n_half as usize;
11353 let t = *total as usize;
11354 let outer = t / n;
11355 let in_total = outer * 2 * n;
11356 let gate_first = *gate_first;
11357 unsafe {
11358 let inp = sl(*src, base, in_total);
11359 let out = sl_mut(*dst, base, t);
11360 for o in 0..outer {
11361 let in_row = &inp[o * 2 * n..(o + 1) * 2 * n];
11362 let out_row = &mut out[o * n..(o + 1) * n];
11363 for i in 0..n {
11364 let (up, gate) = if gate_first {
11365 (in_row[n + i], in_row[i])
11366 } else {
11367 (in_row[i], in_row[n + i])
11368 };
11369 out_row[i] = up * (gate / (1.0 + (-gate).exp()));
11370 }
11371 }
11372 }
11373 }
11374
11375 Thunk::Concat {
11376 dst,
11377 outer,
11378 inner,
11379 total_axis,
11380 inputs,
11381 } => {
11382 let outer = *outer as usize;
11383 let inner = *inner as usize;
11384 let total_axis = *total_axis as usize;
11385 let row_stride = total_axis * inner;
11386 let out_total = outer * row_stride;
11387 unsafe {
11388 let out = sl_mut(*dst, base, out_total);
11389 let mut cum: usize = 0;
11390 for (src_off, in_axis) in inputs {
11391 let in_axis = *in_axis as usize;
11392 let copy_per_row = in_axis * inner;
11393 let dst_col_off = cum * inner;
11394 let in_total = outer * copy_per_row;
11395 let inp = sl(*src_off, base, in_total);
11396 for o in 0..outer {
11397 let dst_row_start = o * row_stride + dst_col_off;
11398 let src_row_start = o * copy_per_row;
11399 out[dst_row_start..dst_row_start + copy_per_row]
11400 .copy_from_slice(&inp[src_row_start..src_row_start + copy_per_row]);
11401 }
11402 cum += in_axis;
11403 }
11404 }
11405 }
11406
11407 Thunk::ConcatF64 {
11408 dst,
11409 outer,
11410 inner,
11411 total_axis,
11412 inputs,
11413 } => {
11414 let outer = *outer as usize;
11415 let inner = *inner as usize;
11416 let total_axis = *total_axis as usize;
11417 let row_stride = total_axis * inner;
11418 let out_total = outer * row_stride;
11419 unsafe {
11420 let out = sl_mut_f64(*dst, base, out_total);
11421 let mut cum: usize = 0;
11422 for (src_off, in_axis) in inputs {
11423 let in_axis = *in_axis as usize;
11424 let copy_per_row = in_axis * inner;
11425 let dst_col_off = cum * inner;
11426 let in_total = outer * copy_per_row;
11427 let inp = sl_f64(*src_off, base, in_total);
11428 for o in 0..outer {
11429 let dst_row_start = o * row_stride + dst_col_off;
11430 let src_row_start = o * copy_per_row;
11431 out[dst_row_start..dst_row_start + copy_per_row]
11432 .copy_from_slice(&inp[src_row_start..src_row_start + copy_per_row]);
11433 }
11434 cum += in_axis;
11435 }
11436 }
11437 }
11438
Thunk::Compare {
    lhs,
    rhs,
    dst,
    len,
    op,
    inputs_i64,
    inputs_elem_bytes,
    dst_elem_bytes,
} => {
    // Element-wise comparison: dst[i] = (lhs[i] <op> rhs[i]).
    // Inputs are read as bytes when `inputs_elem_bytes` is 0 or 1, as i64 when
    // `inputs_i64 != 0`, and as f32 otherwise. The result is written as a single
    // byte (0/1) when `dst_elem_bytes == 1`, otherwise as f32 0.0/1.0.
    // NOTE(review): any dst width other than 1 is written as f32 — confirm that
    // callers only use 1 or 4.
    let len = *len as usize;
    let arena_len = arena_buf.len();
    let elem = (*inputs_elem_bytes).max(1) as usize;
    let dst_eb = (*dst_elem_bytes).max(1) as usize;
    // Clamp `len` so no operand range runs past the end of the arena, measured
    // in the declared element widths.
    let max_l = (arena_len.saturating_sub(*lhs)) / elem;
    let max_r = (arena_len.saturating_sub(*rhs)) / elem;
    let max_d = (arena_len.saturating_sub(*dst)) / dst_eb;
    let len = len.min(max_l).min(max_r).min(max_d);
    if trace_thunks && len > 0 {
        eprintln!("[compare] len={len} lhs={} rhs={} dst={}", *lhs, *rhs, *dst);
    }
    if elem == 1 {
        // Byte inputs: both operands are copied out first, presumably so the
        // loop can write through `arena_buf` without an outstanding borrow of it.
        let l = arena_buf[*lhs..*lhs + len].to_vec();
        let r = arena_buf[*rhs..*rhs + len].to_vec();
        for i in 0..len {
            let v = match op {
                CmpOp::Eq => l[i] == r[i],
                CmpOp::Ne => l[i] != r[i],
                CmpOp::Lt => l[i] < r[i],
                CmpOp::Le => l[i] <= r[i],
                CmpOp::Gt => l[i] > r[i],
                CmpOp::Ge => l[i] >= r[i],
            };
            if *dst_elem_bytes == 1 {
                arena_buf[*dst + i] = u8::from(v);
            } else {
                unsafe {
                    // The f32 output view is rebuilt for each element.
                    let o = sl_mut(*dst, base, len);
                    o[i] = if v { 1.0 } else { 0.0 };
                }
            }
        }
    } else if *inputs_i64 != 0 {
        // i64 inputs.
        unsafe {
            let l = sl_i64(*lhs, base, len);
            let r = sl_i64(*rhs, base, len);
            for i in 0..len {
                let v = match op {
                    CmpOp::Eq => l[i] == r[i],
                    CmpOp::Ne => l[i] != r[i],
                    CmpOp::Lt => l[i] < r[i],
                    CmpOp::Le => l[i] <= r[i],
                    CmpOp::Gt => l[i] > r[i],
                    CmpOp::Ge => l[i] >= r[i],
                };
                if *dst_elem_bytes == 1 {
                    arena_buf[*dst + i] = u8::from(v);
                } else {
                    let o = sl_mut(*dst, base, len);
                    o[i] = if v { 1.0 } else { 0.0 };
                }
            }
        }
    } else {
        // f32 inputs (IEEE comparison: any comparison with NaN except Ne is false).
        unsafe {
            let l = sl(*lhs, base, len);
            let r = sl(*rhs, base, len);
            for i in 0..len {
                let v = match op {
                    CmpOp::Eq => l[i] == r[i],
                    CmpOp::Ne => l[i] != r[i],
                    CmpOp::Lt => l[i] < r[i],
                    CmpOp::Le => l[i] <= r[i],
                    CmpOp::Gt => l[i] > r[i],
                    CmpOp::Ge => l[i] >= r[i],
                };
                if *dst_elem_bytes == 1 {
                    arena_buf[*dst + i] = u8::from(v);
                } else {
                    let o = sl_mut(*dst, base, len);
                    o[i] = if v { 1.0 } else { 0.0 };
                }
            }
        }
    }
}
11525
11526 Thunk::Where {
11527 cond,
11528 on_true,
11529 on_false,
11530 dst,
11531 len,
11532 elem_bytes,
11533 cond_elem_bytes,
11534 } => {
11535 let len = *len as usize;
11536 let eb = *elem_bytes as usize;
11537 let cond_eb = (*cond_elem_bytes).max(1) as usize;
11538 let arena_len = arena_buf.len();
11539 let len = len
11540 .min((arena_len.saturating_sub(*cond)) / cond_eb)
11541 .min((arena_len.saturating_sub(*on_true)) / eb)
11542 .min((arena_len.saturating_sub(*on_false)) / eb)
11543 .min((arena_len.saturating_sub(*dst)) / eb);
11544 unsafe {
11545 if *elem_bytes == 8 {
11546 let t = sl_i64(*on_true, base, len);
11547 let e = sl_i64(*on_false, base, len);
11548 let o = sl_mut_i64(*dst, base, len);
11549 if *cond_elem_bytes == 1 {
11550 let c = &arena_buf[*cond..*cond + len];
11551 for i in 0..len {
11552 o[i] = if c[i] != 0 { t[i] } else { e[i] };
11553 }
11554 } else {
11555 let c = sl_i64(*cond, base, len);
11556 for i in 0..len {
11557 o[i] = if c[i] != 0 { t[i] } else { e[i] };
11558 }
11559 }
11560 } else if *cond_elem_bytes == 1 {
11561 let c = &arena_buf[*cond..*cond + len];
11562 let t = sl(*on_true, base, len);
11563 let e = sl(*on_false, base, len);
11564 let o = sl_mut(*dst, base, len);
11565 for i in 0..len {
11566 o[i] = if c[i] != 0 { t[i] } else { e[i] };
11567 }
11568 } else {
11569 let c = sl(*cond, base, len);
11570 let t = sl(*on_true, base, len);
11571 let e = sl(*on_false, base, len);
11572 let o = sl_mut(*dst, base, len);
11573 for i in 0..len {
11574 o[i] = if c[i] != 0.0 { t[i] } else { e[i] };
11575 }
11576 }
11577 }
11578 }
11579
11580 Thunk::ScatterAdd {
11581 updates,
11582 indices,
11583 dst,
11584 num_updates,
11585 out_dim,
11586 trailing,
11587 } => {
11588 let num_updates = *num_updates as usize;
11589 let out_dim = *out_dim as usize;
11590 let trailing = *trailing as usize;
11591 unsafe {
11592 let upd = sl(*updates, base, num_updates * trailing);
11593 let ids = sl(*indices, base, num_updates);
11594 let out = sl_mut(*dst, base, out_dim * trailing);
11595 for v in out.iter_mut() {
11597 *v = 0.0;
11598 }
11599 for i in 0..num_updates {
11600 let row = ids[i] as usize;
11601 debug_assert!(row < out_dim, "ScatterAdd index out of range");
11602 let src_off = i * trailing;
11603 let dst_off = row * trailing;
11604 for j in 0..trailing {
11605 out[dst_off + j] += upd[src_off + j];
11606 }
11607 }
11608 }
11609 }
11610
Thunk::GroupedMatMul {
    input,
    weight,
    expert_idx,
    dst,
    m,
    k_dim,
    n,
    num_experts,
} => {
    // Mixture-of-experts grouped matmul. Row i of `input` (m rows of k_dim f32)
    // is multiplied by the weight slab of expert `expert_idx[i]`. Slabs are
    // k_dim * n f32 each, stored back to back. The result is row i of `dst`
    // (m rows of n).
    // Rows are bucketed by expert (counting sort) so each expert gets a single
    // sgemm over a contiguous block; the results are then scattered back.
    let m = *m as usize;
    let k_dim = *k_dim as usize;
    let n = *n as usize;
    let num_experts = *num_experts as usize;
    unsafe {
        let inp = sl(*input, base, m * k_dim);
        let wt = sl(*weight, base, num_experts * k_dim * n);
        // Expert ids are stored as f32 and truncated with `as usize`.
        let ids = sl(*expert_idx, base, m);
        let out = sl_mut(*dst, base, m * n);

        // Histogram: number of rows routed to each expert.
        let mut counts = vec![0usize; num_experts];
        for i in 0..m {
            let e = ids[i] as usize;
            debug_assert!(
                e < num_experts,
                "expert_idx out of range: {e} >= {num_experts}"
            );
            counts[e] += 1;
        }
        // Exclusive prefix sum: offsets[e] is expert e's first packed row.
        let mut offsets = vec![0usize; num_experts + 1];
        for e in 0..num_experts {
            offsets[e + 1] = offsets[e] + counts[e];
        }
        // Pack input rows grouped by expert (stable within each expert), and
        // remember each packed row's original position for the scatter-back.
        let mut packed_in = vec![0f32; m * k_dim];
        let mut original_pos = vec![0usize; m];
        let mut write_idx = vec![0usize; num_experts];
        for i in 0..m {
            let e = ids[i] as usize;
            let dst_row = offsets[e] + write_idx[e];
            packed_in[dst_row * k_dim..(dst_row + 1) * k_dim]
                .copy_from_slice(&inp[i * k_dim..(i + 1) * k_dim]);
            original_pos[dst_row] = i;
            write_idx[e] += 1;
        }

        let mut packed_out = vec![0f32; m * n];
        let expert_stride = k_dim * n;
        // Global ordinal of this grouped matmul. Dividing by 3 presumably maps
        // three grouped matmuls per MoE layer onto one layer id — TODO confirm
        // against `crate::moe_residency`.
        let gmm_ord = crate::moe_residency::next_gmm_ord();
        let moe_layer = gmm_ord / 3;
        for e in 0..num_experts {
            let count = counts[e];
            if count == 0 {
                continue;
            }
            crate::moe_residency::record_expert_tokens(moe_layer, e, count);
            let in_start = offsets[e];
            let in_slice = &packed_in[in_start * k_dim..(in_start + count) * k_dim];
            // If the expert is reported as not on device for this layer and a
            // host pointer is registered, read the weights from there (assumed
            // to hold `expert_stride` valid f32s). Otherwise use the arena slab.
            let w_slab: &[f32] =
                if !crate::moe_residency::expert_on_device_for_layer(moe_layer, e) {
                    if let Some(ptr) =
                        crate::moe_residency::host_expert_weight_ptr(gmm_ord, e)
                    {
                        std::slice::from_raw_parts(ptr, expert_stride)
                    } else {
                        &wt[e * expert_stride..(e + 1) * expert_stride]
                    }
                } else {
                    &wt[e * expert_stride..(e + 1) * expert_stride]
                };
            let out_slice = &mut packed_out[in_start * n..(in_start + count) * n];
            crate::blas::sgemm(in_slice, w_slab, out_slice, count, k_dim, n);
        }

        // Scatter packed results back to their original row order.
        for packed_idx in 0..m {
            let i = original_pos[packed_idx];
            out[i * n..(i + 1) * n]
                .copy_from_slice(&packed_out[packed_idx * n..(packed_idx + 1) * n]);
        }
    }
}
11701
11702 Thunk::DequantGroupedMatMulGguf {
11703 input,
11704 w_q,
11705 expert_idx,
11706 dst,
11707 m,
11708 k_dim,
11709 n,
11710 num_experts,
11711 scheme,
11712 } => {
11713 let m = *m as usize;
11714 let k_dim = *k_dim as usize;
11715 let n = *n as usize;
11716 let num_experts = *num_experts as usize;
11717 let block_elems = scheme.gguf_block_size() as usize;
11718 let block_bytes = scheme.gguf_block_bytes() as usize;
11719 let slab_bytes = (k_dim * n) / block_elems * block_bytes;
11720 unsafe {
11721 let inp = sl(*input, base, m * k_dim);
11722 let wt = std::slice::from_raw_parts(
11723 base.add(*w_q) as *const u8,
11724 num_experts * slab_bytes,
11725 );
11726 let ids = sl(*expert_idx, base, m);
11727 let out = sl_mut(*dst, base, m * n);
11728 crate::gguf_matmul::gguf_grouped_matmul_bt(
11729 inp,
11730 wt,
11731 ids,
11732 out,
11733 m,
11734 k_dim,
11735 n,
11736 num_experts,
11737 *scheme,
11738 );
11739 }
11740 }
11741
11742 Thunk::DequantMoEWeightsGguf {
11743 w_q,
11744 dst,
11745 k_dim,
11746 n,
11747 num_experts,
11748 scheme,
11749 } => {
11750 let k_dim = *k_dim as usize;
11751 let n = *n as usize;
11752 let num_experts = *num_experts as usize;
11753 let block_elems = scheme.gguf_block_size() as usize;
11754 let block_bytes = scheme.gguf_block_bytes() as usize;
11755 let slab_bytes = (k_dim * n) / block_elems * block_bytes;
11756 unsafe {
11757 let wt = std::slice::from_raw_parts(
11758 base.add(*w_q) as *const u8,
11759 num_experts * slab_bytes,
11760 );
11761 let out = sl_mut(*dst, base, num_experts * k_dim * n);
11762 crate::gguf_matmul::dequant_moe_weights_to_grouped_f32(
11763 wt,
11764 out,
11765 num_experts,
11766 k_dim,
11767 n,
11768 *scheme,
11769 );
11770 }
11771 }
11772
11773 Thunk::TopK {
11774 src,
11775 dst,
11776 outer,
11777 axis_dim,
11778 k,
11779 indices_i64,
11780 } => {
11781 let outer = *outer as usize;
11782 let axis_dim = *axis_dim as usize;
11783 let k = *k as usize;
11784 unsafe {
11785 let inp = sl(*src, base, outer * axis_dim);
11786 let mut row_buf: Vec<f32> = vec![0.0; axis_dim];
11790 if *indices_i64 != 0 {
11791 let out = sl_mut_i64(*dst, base, outer * k);
11792 for o in 0..outer {
11793 row_buf.copy_from_slice(&inp[o * axis_dim..(o + 1) * axis_dim]);
11794 for ki in 0..k {
11795 let mut best_i = 0usize;
11796 let mut best_v = row_buf[0];
11797 for i in 1..axis_dim {
11798 let v = row_buf[i];
11799 if v > best_v {
11800 best_v = v;
11801 best_i = i;
11802 }
11803 }
11804 out[o * k + ki] = best_i as i64;
11805 row_buf[best_i] = f32::NEG_INFINITY;
11806 }
11807 }
11808 } else {
11809 let out = sl_mut(*dst, base, outer * k);
11810 for o in 0..outer {
11811 row_buf.copy_from_slice(&inp[o * axis_dim..(o + 1) * axis_dim]);
11812 for ki in 0..k {
11813 let mut best_i = 0usize;
11814 let mut best_v = row_buf[0];
11815 for i in 1..axis_dim {
11816 let v = row_buf[i];
11817 if v > best_v {
11818 best_v = v;
11819 best_i = i;
11820 }
11821 }
11822 out[o * k + ki] = best_i as f32;
11823 row_buf[best_i] = f32::NEG_INFINITY;
11824 }
11825 }
11826 if let Some(cap) = schedule.moe_topk_capture.as_ref() {
11827 cap.push_topk_f32(&out[..outer * k], axis_dim);
11828 }
11829 }
11830 }
11831 }
11832
11833 Thunk::Reduce {
11834 src,
11835 dst,
11836 outer,
11837 reduced,
11838 inner,
11839 op,
11840 } => {
11841 let outer = *outer as usize;
11842 let reduced = *reduced as usize;
11843 let inner = *inner as usize;
11844 let in_total = outer * reduced * inner;
11845 let out_total = outer * inner;
11846 unsafe {
11847 let inp = sl(*src, base, in_total);
11848 let out = sl_mut(*dst, base, out_total);
11849 for o in 0..outer {
11850 for i in 0..inner {
11851 let mut acc = match op {
11852 ReduceOp::Max => f32::NEG_INFINITY,
11853 ReduceOp::Min => f32::INFINITY,
11854 ReduceOp::Prod => 1.0f32,
11855 _ => 0.0f32, };
11857 for r in 0..reduced {
11859 let v = inp[o * reduced * inner + r * inner + i];
11860 acc = match op {
11861 ReduceOp::Sum | ReduceOp::Mean => acc + v,
11862 ReduceOp::Max => acc.max(v),
11863 ReduceOp::Min => acc.min(v),
11864 ReduceOp::Prod => acc * v,
11865 };
11866 }
11867 if matches!(op, ReduceOp::Mean) {
11868 acc /= reduced as f32;
11869 }
11870 out[o * inner + i] = acc;
11871 }
11872 }
11873 }
11874 }
11875
11876 Thunk::Conv2D1x1 {
11877 src,
11878 weight,
11879 dst,
11880 n,
11881 c_in,
11882 c_out,
11883 hw,
11884 } => {
11885 let n = *n as usize;
11886 let c_in = *c_in as usize;
11887 let c_out = *c_out as usize;
11888 let hw = *hw as usize;
11889 unsafe {
11890 let inp = sl(*src, base, n * c_in * hw);
11891 let wt = sl(*weight, base, c_out * c_in);
11892 let out = sl_mut(*dst, base, n * c_out * hw);
11893 for ni in 0..n {
11898 let in_off = ni * c_in * hw;
11899 let out_off = ni * c_out * hw;
11900 crate::blas::sgemm(
11901 wt,
11902 &inp[in_off..in_off + c_in * hw],
11903 &mut out[out_off..out_off + c_out * hw],
11904 c_out,
11905 c_in,
11906 hw,
11907 );
11908 }
11909 }
11910 }
11911
11912 Thunk::Conv2D {
11913 src,
11914 weight,
11915 dst,
11916 n,
11917 c_in,
11918 h,
11919 w,
11920 c_out,
11921 h_out,
11922 w_out,
11923 kh,
11924 kw,
11925 sh,
11926 sw,
11927 ph,
11928 pw,
11929 dh,
11930 dw,
11931 groups,
11932 } => {
11933 let n = *n as usize;
11934 let c_in = *c_in as usize;
11935 let h = *h as usize;
11936 let w = *w as usize;
11937 let c_out = *c_out as usize;
11938 let h_out = *h_out as usize;
11939 let w_out = *w_out as usize;
11940 let kh = *kh as usize;
11941 let kw = *kw as usize;
11942 let sh = *sh as usize;
11943 let sw = *sw as usize;
11944 let ph = *ph as usize;
11945 let pw = *pw as usize;
11946 let dh = *dh as usize;
11947 let dw = *dw as usize;
11948 let groups = *groups as usize;
11949 let c_in_per_g = c_in / groups;
11950 let c_out_per_g = c_out / groups;
11951 unsafe {
11952 let inp = sl(*src, base, n * c_in * h * w);
11953 let wt = sl(*weight, base, c_out * c_in_per_g * kh * kw);
11954 let out = sl_mut(*dst, base, n * c_out * h_out * w_out);
11955 for ni in 0..n {
11956 for co in 0..c_out {
11957 let g = co / c_out_per_g;
11958 let ci_start = g * c_in_per_g;
11959 for ho in 0..h_out {
11960 for wo in 0..w_out {
11961 let mut acc = 0f32;
11962 for ci_off in 0..c_in_per_g {
11963 let ci = ci_start + ci_off;
11964 let in_chan = ((ni * c_in) + ci) * h * w;
11965 let wt_chan = ((co * c_in_per_g) + ci_off) * kh * kw;
11966 for ki in 0..kh {
11967 for kj in 0..kw {
11968 let hi = ho * sh + ki * dh;
11969 let wi = wo * sw + kj * dw;
11970 if hi < ph || wi < pw {
11971 continue;
11972 }
11973 let hi = hi - ph;
11974 let wi = wi - pw;
11975 if hi >= h || wi >= w {
11976 continue;
11977 }
11978 acc += inp[in_chan + hi * w + wi]
11979 * wt[wt_chan + ki * kw + kj];
11980 }
11981 }
11982 }
11983 out[((ni * c_out) + co) * h_out * w_out + ho * w_out + wo] =
11984 acc;
11985 }
11986 }
11987 }
11988 }
11989 }
11990 }
11991
11992 Thunk::Pool2D {
11993 src,
11994 dst,
11995 n,
11996 c,
11997 h,
11998 w,
11999 h_out,
12000 w_out,
12001 kh,
12002 kw,
12003 sh,
12004 sw,
12005 ph,
12006 pw,
12007 kind,
12008 } => {
12009 let n = *n as usize;
12010 let c = *c as usize;
12011 let h = *h as usize;
12012 let w = *w as usize;
12013 let h_out = *h_out as usize;
12014 let w_out = *w_out as usize;
12015 let kh = *kh as usize;
12016 let kw = *kw as usize;
12017 let sh = *sh as usize;
12018 let sw = *sw as usize;
12019 let ph = *ph as usize;
12020 let pw = *pw as usize;
12021 let kernel_area = (kh * kw) as f32;
12022 unsafe {
12023 let inp = sl(*src, base, n * c * h * w);
12024 let out = sl_mut(*dst, base, n * c * h_out * w_out);
12025 for ni in 0..n {
12026 for ci in 0..c {
12027 let in_chan = ni * c * h * w + ci * h * w;
12028 let out_chan = ni * c * h_out * w_out + ci * h_out * w_out;
12029 for ho in 0..h_out {
12030 for wo in 0..w_out {
12031 let mut acc = match kind {
12032 ReduceOp::Max => f32::NEG_INFINITY,
12033 _ => 0f32, };
12035 for ki in 0..kh {
12036 for kj in 0..kw {
12037 let hi = ho * sh + ki;
12038 let wi = wo * sw + kj;
12039 if hi < ph || wi < pw {
12041 continue;
12042 }
12043 let hi = hi - ph;
12044 let wi = wi - pw;
12045 if hi >= h || wi >= w {
12046 continue;
12047 }
12048 let v = inp[in_chan + hi * w + wi];
12049 match kind {
12050 ReduceOp::Max => acc = acc.max(v),
12051 _ => acc += v,
12052 }
12053 }
12054 }
12055 if matches!(kind, ReduceOp::Mean) {
12056 acc /= kernel_area;
12057 }
12058 out[out_chan + ho * w_out + wo] = acc;
12059 }
12060 }
12061 }
12062 }
12063 }
12064 }
12065
12066 Thunk::ReluBackward { x, dy, dx, len } => {
12067 let len = *len as usize;
12068 unsafe {
12069 let xs = sl(*x, base, len);
12070 let dys = sl(*dy, base, len);
12071 let out = sl_mut(*dx, base, len);
12072 for i in 0..len {
12073 out[i] = if xs[i] > 0.0 { dys[i] } else { 0.0 };
12074 }
12075 }
12076 }
12077
12078 Thunk::ReluBackwardF64 { x, dy, dx, len } => {
12079 let len = *len as usize;
12080 unsafe {
12081 let xs = sl_f64(*x, base, len);
12082 let dys = sl_f64(*dy, base, len);
12083 let out = sl_mut_f64(*dx, base, len);
12084 for i in 0..len {
12085 out[i] = if xs[i] > 0.0 { dys[i] } else { 0.0 };
12086 }
12087 }
12088 }
12089
12090 Thunk::QMatMul {
12091 x,
12092 w,
12093 bias,
12094 out,
12095 m,
12096 k,
12097 n,
12098 x_zp,
12099 w_zp,
12100 out_zp,
12101 mult,
12102 } => {
12103 let m = *m as usize;
12104 let k = *k as usize;
12105 let n = *n as usize;
12106 unsafe {
12107 let x_ptr = base.add(*x) as *const i8;
12108 let w_ptr = base.add(*w) as *const i8;
12109 let bias_ptr = base.add(*bias) as *const i32;
12110 let out_ptr = base.add(*out) as *mut i8;
12111 for mi in 0..m {
12112 for ni in 0..n {
12113 let mut acc: i32 = *bias_ptr.add(ni);
12114 for ki in 0..k {
12115 let xv = *x_ptr.add(mi * k + ki) as i32 - *x_zp;
12116 let wv = *w_ptr.add(ki * n + ni) as i32 - *w_zp;
12117 acc += xv * wv;
12118 }
12119 let r = (acc as f32 * *mult).round() as i32 + *out_zp;
12122 let r = r.clamp(-128, 127) as i8;
12123 *out_ptr.add(mi * n + ni) = r;
12124 }
12125 }
12126 }
12127 }
12128
12129 Thunk::QConv2d {
12130 x,
12131 w,
12132 bias,
12133 out,
12134 n,
12135 c_in,
12136 h,
12137 w_in,
12138 c_out,
12139 h_out,
12140 w_out,
12141 kh,
12142 kw,
12143 sh,
12144 sw,
12145 ph,
12146 pw,
12147 dh,
12148 dw,
12149 groups,
12150 x_zp,
12151 w_zp,
12152 out_zp,
12153 mult,
12154 } => {
12155 let n = *n as usize;
12156 let c_in = *c_in as usize;
12157 let h = *h as usize;
12158 let w_in = *w_in as usize;
12159 let c_out = *c_out as usize;
12160 let h_out = *h_out as usize;
12161 let w_out = *w_out as usize;
12162 let kh = *kh as usize;
12163 let kw = *kw as usize;
12164 let sh = *sh as usize;
12165 let sw = *sw as usize;
12166 let ph = *ph as usize;
12167 let pw = *pw as usize;
12168 let dh = *dh as usize;
12169 let dw = *dw as usize;
12170 let groups = *groups as usize;
12171 let c_in_per_g = c_in / groups;
12172 let c_out_per_g = c_out / groups;
12173 unsafe {
12174 let x_ptr = base.add(*x) as *const i8;
12175 let w_ptr = base.add(*w) as *const i8;
12176 let bias_ptr = base.add(*bias) as *const i32;
12177 let out_ptr = base.add(*out) as *mut i8;
12178 for ni in 0..n {
12179 for co in 0..c_out {
12180 let g = co / c_out_per_g;
12181 let ci_start = g * c_in_per_g;
12182 for ho in 0..h_out {
12183 for wo in 0..w_out {
12184 let mut acc: i32 = *bias_ptr.add(co);
12185 for ci_off in 0..c_in_per_g {
12186 let ci = ci_start + ci_off;
12187 let in_chan = ((ni * c_in) + ci) * h * w_in;
12188 let wt_chan = ((co * c_in_per_g) + ci_off) * kh * kw;
12189 for ki in 0..kh {
12190 for kj in 0..kw {
12191 let hi = ho * sh + ki * dh;
12192 let wi = wo * sw + kj * dw;
12193 if hi < ph || wi < pw {
12194 continue;
12195 }
12196 let hi = hi - ph;
12197 let wi = wi - pw;
12198 if hi >= h || wi >= w_in {
12199 continue;
12200 }
12201 let xv = *x_ptr.add(in_chan + hi * w_in + wi)
12202 as i32
12203 - *x_zp;
12204 let wv = *w_ptr.add(wt_chan + ki * kw + kj) as i32
12205 - *w_zp;
12206 acc += xv * wv;
12207 }
12208 }
12209 }
12210 let r = (acc as f32 * *mult).round() as i32 + *out_zp;
12211 let r = r.clamp(-128, 127) as i8;
12212 let dst = ((ni * c_out) + co) * h_out * w_out + ho * w_out + wo;
12213 *out_ptr.add(dst) = r;
12214 }
12215 }
12216 }
12217 }
12218 }
12219 }
12220
12221 Thunk::Quantize {
12222 x,
12223 q,
12224 len,
12225 chan_axis: _,
12226 chan_dim,
12227 inner,
12228 scales,
12229 zero_points,
12230 } => {
12231 let len = *len as usize;
12232 let chan_dim = *chan_dim as usize;
12233 let inner = *inner as usize;
12234 unsafe {
12235 let xs = sl(*x, base, len);
12236 let q_ptr = base.add(*q) as *mut i8;
12237 for i in 0..len {
12238 let c = if chan_dim == 1 {
12239 0
12240 } else {
12241 (i / inner) % chan_dim
12242 };
12243 let inv_scale = 1.0 / scales[c];
12244 let zp = zero_points[c];
12245 let v = (xs[i] * inv_scale).round() as i32 + zp;
12246 *q_ptr.add(i) = v.clamp(-128, 127) as i8;
12247 }
12248 }
12249 }
12250
12251 Thunk::Dequantize {
12252 q,
12253 x,
12254 len,
12255 chan_axis: _,
12256 chan_dim,
12257 inner,
12258 scales,
12259 zero_points,
12260 } => {
12261 let len = *len as usize;
12262 let chan_dim = *chan_dim as usize;
12263 let inner = *inner as usize;
12264 unsafe {
12265 let q_ptr = base.add(*q) as *const i8;
12266 let out = sl_mut(*x, base, len);
12267 for i in 0..len {
12268 let c = if chan_dim == 1 {
12269 0
12270 } else {
12271 (i / inner) % chan_dim
12272 };
12273 let scale = scales[c];
12274 let zp = zero_points[c];
12275 let qv = *q_ptr.add(i) as i32;
12276 out[i] = (qv - zp) as f32 * scale;
12277 }
12278 }
12279 }
12280
12281 Thunk::FakeQuantize {
12282 x,
12283 out,
12284 len,
12285 chan_axis: _,
12286 chan_dim,
12287 inner,
12288 bits,
12289 ste: _,
12290 scale_mode,
12291 state_off,
12292 } => {
12293 use rlx_ir::op::ScaleMode;
12294 let len = *len as usize;
12295 let chan_dim = *chan_dim as usize;
12296 let inner = *inner as usize;
12297 let q_max: f32 = match *bits {
12298 8 => 127.0,
12299 4 => 7.0,
12300 2 => 1.0,
12301 n => panic!("FakeQuantize: unsupported bits {n}"),
12302 };
12303 unsafe {
12304 let xs = sl(*x, base, len);
12305 let outs = sl_mut(*out, base, len);
12306
12307 let mut scale = vec![0f32; chan_dim];
12308 match scale_mode {
12309 ScaleMode::PerBatch => {
12310 let mut max_abs = vec![0f32; chan_dim];
12311 for i in 0..len {
12312 let c = if chan_dim == 1 {
12313 0
12314 } else {
12315 (i / inner) % chan_dim
12316 };
12317 let a = xs[i].abs();
12318 if a > max_abs[c] {
12319 max_abs[c] = a;
12320 }
12321 }
12322 for c in 0..chan_dim {
12323 scale[c] = (max_abs[c] / q_max).max(1e-12);
12324 }
12325 }
12326 ScaleMode::EMA { decay } => {
12327 let mut max_abs = vec![0f32; chan_dim];
12330 for i in 0..len {
12331 let c = if chan_dim == 1 {
12332 0
12333 } else {
12334 (i / inner) % chan_dim
12335 };
12336 let a = xs[i].abs();
12337 if a > max_abs[c] {
12338 max_abs[c] = a;
12339 }
12340 }
12341 let state =
12342 sl_mut(state_off.expect("EMA needs state_off"), base, chan_dim);
12343 for c in 0..chan_dim {
12344 let cur = (max_abs[c] / q_max).max(1e-12);
12345 let blended = if state[c] <= 0.0 {
12347 cur
12348 } else {
12349 *decay * state[c] + (1.0 - *decay) * cur
12350 };
12351 state[c] = blended;
12352 scale[c] = blended;
12353 }
12354 }
12355 ScaleMode::Fixed => {
12356 let state =
12357 sl(state_off.expect("Fixed needs state_off"), base, chan_dim);
12358 for c in 0..chan_dim {
12359 scale[c] = state[c].max(1e-12);
12360 }
12361 }
12362 }
12363
12364 for i in 0..len {
12365 let c = if chan_dim == 1 {
12366 0
12367 } else {
12368 (i / inner) % chan_dim
12369 };
12370 let s = scale[c];
12371 let qv = (xs[i] / s).round().clamp(-q_max, q_max);
12372 outs[i] = qv * s;
12373 }
12374 }
12375 }
12376
12377 Thunk::ActivationBackward {
12378 x,
12379 dy,
12380 dx,
12381 len,
12382 kind,
12383 } => {
12384 let len = *len as usize;
12385 unsafe {
12386 let xs = sl(*x, base, len);
12387 let dys = sl(*dy, base, len);
12388 let out = sl_mut(*dx, base, len);
12389 activation_backward_kernel(*kind, xs, dys, out);
12390 }
12391 }
12392
12393 Thunk::ActivationBackwardF64 {
12394 x,
12395 dy,
12396 dx,
12397 len,
12398 kind,
12399 } => {
12400 let len = *len as usize;
12401 unsafe {
12402 let xs = sl_f64(*x, base, len);
12403 let dys = sl_f64(*dy, base, len);
12404 let out = sl_mut_f64(*dx, base, len);
12405 activation_backward_kernel_f64(*kind, xs, dys, out);
12406 }
12407 }
12408
12409 Thunk::FakeQuantizeLSQ {
12410 x,
12411 scale_off,
12412 out,
12413 len,
12414 chan_axis: _,
12415 chan_dim,
12416 inner,
12417 bits,
12418 } => {
12419 let len = *len as usize;
12420 let chan_dim = *chan_dim as usize;
12421 let inner = *inner as usize;
12422 let q_max: f32 = match *bits {
12423 8 => 127.0,
12424 4 => 7.0,
12425 2 => 1.0,
12426 n => panic!("FakeQuantizeLSQ: bad bits {n}"),
12427 };
12428 unsafe {
12429 let xs = sl(*x, base, len);
12430 let scale = sl(*scale_off, base, chan_dim);
12431 let outs = sl_mut(*out, base, len);
12432 for i in 0..len {
12433 let c = if chan_dim == 1 {
12434 0
12435 } else {
12436 (i / inner) % chan_dim
12437 };
12438 let s = scale[c].max(1e-12);
12439 let qv = (xs[i] / s).round().clamp(-q_max, q_max);
12440 outs[i] = qv * s;
12441 }
12442 }
12443 }
12444
12445 Thunk::FakeQuantizeLSQBackwardX {
12446 x,
12447 scale_off,
12448 dy,
12449 dx,
12450 len,
12451 chan_axis: _,
12452 chan_dim,
12453 inner,
12454 bits,
12455 } => {
12456 let len = *len as usize;
12457 let chan_dim = *chan_dim as usize;
12458 let inner = *inner as usize;
12459 let q_max: f32 = match *bits {
12460 8 => 127.0,
12461 4 => 7.0,
12462 2 => 1.0,
12463 n => panic!("FakeQuantizeLSQBackwardX: bad bits {n}"),
12464 };
12465 unsafe {
12466 let xs = sl(*x, base, len);
12467 let scale = sl(*scale_off, base, chan_dim);
12468 let dys = sl(*dy, base, len);
12469 let outs = sl_mut(*dx, base, len);
12470 for i in 0..len {
12472 let c = if chan_dim == 1 {
12473 0
12474 } else {
12475 (i / inner) % chan_dim
12476 };
12477 let z = xs[i] / scale[c].max(1e-12);
12478 outs[i] = if z.abs() <= q_max { dys[i] } else { 0.0 };
12479 }
12480 }
12481 }
12482
12483 Thunk::FakeQuantizeLSQBackwardScale {
12484 x,
12485 scale_off,
12486 dy,
12487 dscale,
12488 len,
12489 chan_axis: _,
12490 chan_dim,
12491 inner,
12492 bits,
12493 } => {
12494 let len = *len as usize;
12495 let chan_dim = *chan_dim as usize;
12496 let inner = *inner as usize;
12497 let q_max: f32 = match *bits {
12498 8 => 127.0,
12499 4 => 7.0,
12500 2 => 1.0,
12501 n => panic!("FakeQuantizeLSQBackwardScale: bad bits {n}"),
12502 };
12503 unsafe {
12504 let xs = sl(*x, base, len);
12505 let scale = sl(*scale_off, base, chan_dim);
12506 let dys = sl(*dy, base, len);
12507 let outs = sl_mut(*dscale, base, chan_dim);
12508 for v in outs.iter_mut() {
12509 *v = 0.0;
12510 }
12511 for i in 0..len {
12514 let c = if chan_dim == 1 {
12515 0
12516 } else {
12517 (i / inner) % chan_dim
12518 };
12519 let s = scale[c].max(1e-12);
12520 let z = xs[i] / s;
12521 let psi = if z.abs() <= q_max {
12522 -z + z.round()
12523 } else if z > 0.0 {
12524 q_max
12525 } else {
12526 -q_max
12527 };
12528 outs[c] += psi * dys[i];
12529 }
12530 }
12531 }
12532
12533 Thunk::FakeQuantizeBackward {
12534 x,
12535 dy,
12536 dx,
12537 len,
12538 chan_axis: _,
12539 chan_dim,
12540 inner,
12541 bits,
12542 ste,
12543 } => {
12544 use rlx_ir::op::SteKind;
12545 let len = *len as usize;
12546 let chan_dim = *chan_dim as usize;
12547 let inner = *inner as usize;
12548 let q_max: f32 = match *bits {
12549 8 => 127.0,
12550 4 => 7.0,
12551 2 => 1.0,
12552 n => panic!("FakeQuantizeBackward: bad bits {n}"),
12553 };
12554 unsafe {
12555 let xs = sl(*x, base, len);
12556 let dys = sl(*dy, base, len);
12557 let outs = sl_mut(*dx, base, len);
12558
12559 let mut max_abs = vec![0f32; chan_dim];
12561 for i in 0..len {
12562 let c = if chan_dim == 1 {
12563 0
12564 } else {
12565 (i / inner) % chan_dim
12566 };
12567 let a = xs[i].abs();
12568 if a > max_abs[c] {
12569 max_abs[c] = a;
12570 }
12571 }
12572 let mut scale = vec![0f32; chan_dim];
12573 for c in 0..chan_dim {
12574 scale[c] = (max_abs[c] / q_max).max(1e-12);
12575 }
12576
12577 match *ste {
12578 SteKind::Identity => {
12579 outs.copy_from_slice(dys);
12581 }
12582 SteKind::ClippedIdentity => {
12583 for i in 0..len {
12586 let c = if chan_dim == 1 {
12587 0
12588 } else {
12589 (i / inner) % chan_dim
12590 };
12591 let bound = q_max * scale[c];
12592 outs[i] = if xs[i].abs() <= bound { dys[i] } else { 0.0 };
12593 }
12594 }
12595 SteKind::Tanh => {
12596 for i in 0..len {
12598 let c = if chan_dim == 1 {
12599 0
12600 } else {
12601 (i / inner) % chan_dim
12602 };
12603 let t = (xs[i] / scale[c]).tanh();
12604 outs[i] = dys[i] * (1.0 - t * t);
12605 }
12606 }
12607 SteKind::HardTanh => {
12608 for i in 0..len {
12610 let c = if chan_dim == 1 {
12611 0
12612 } else {
12613 (i / inner) % chan_dim
12614 };
12615 let bound = q_max * scale[c];
12616 let attenuation = (1.0 - (xs[i] / bound).abs()).max(0.0);
12617 outs[i] = dys[i] * attenuation;
12618 }
12619 }
12620 }
12621 }
12622 }
12623
12624 Thunk::LayerNormBackwardInput {
12625 x,
12626 gamma,
12627 dy,
12628 dx,
12629 rows,
12630 h,
12631 eps,
12632 } => {
12633 let rows = *rows as usize;
12634 let h = *h as usize;
12635 let eps = *eps;
12636 unsafe {
12637 let xs = sl(*x, base, rows * h);
12638 let g = sl(*gamma, base, h);
12639 let dys = sl(*dy, base, rows * h);
12640 let out = sl_mut(*dx, base, rows * h);
12641 let n_inv = 1.0 / h as f32;
12642 for r in 0..rows {
12643 let xr = &xs[r * h..(r + 1) * h];
12644 let dyr = &dys[r * h..(r + 1) * h];
12645 let mut sum = 0f32;
12648 for &v in xr {
12649 sum += v;
12650 }
12651 let mean = sum * n_inv;
12652 let mut var = 0f32;
12653 for &v in xr {
12654 let d = v - mean;
12655 var += d * d;
12656 }
12657 let inv_std = 1.0 / (var * n_inv + eps).sqrt();
12658
12659 let mut s_sy = 0f32;
12662 let mut s_sxh = 0f32;
12663 for d in 0..h {
12664 let xh = (xr[d] - mean) * inv_std;
12665 let sy = dyr[d] * g[d];
12666 s_sy += sy;
12667 s_sxh += sy * xh;
12668 }
12669 let m_sy = s_sy * n_inv;
12670 let m_sxh = s_sxh * n_inv;
12671
12672 for d in 0..h {
12673 let xh = (xr[d] - mean) * inv_std;
12674 let sy = dyr[d] * g[d];
12675 out[r * h + d] = inv_std * (sy - m_sy - xh * m_sxh);
12676 }
12677 }
12678 }
12679 }
12680
12681 Thunk::BatchNormInferenceBackwardInput {
12682 x,
12683 gamma,
12684 mean,
12685 var,
12686 dy,
12687 dx,
12688 count,
12689 channels,
12690 eps,
12691 } => {
12692 let count = *count as usize;
12693 let c = *channels as usize;
12694 let n = count * c;
12695 let eps = *eps;
12696 unsafe {
12697 crate::kernels::batch_norm_inference_backward_input(
12698 sl(*x, base, n),
12699 sl(*gamma, base, c),
12700 sl(*mean, base, c),
12701 sl(*var, base, c),
12702 sl(*dy, base, n),
12703 sl_mut(*dx, base, n),
12704 c,
12705 eps,
12706 );
12707 }
12708 }
12709
12710 Thunk::BatchNormInferenceBackwardGamma {
12711 x,
12712 mean,
12713 var,
12714 dy,
12715 dgamma,
12716 count,
12717 channels,
12718 eps,
12719 } => {
12720 let count = *count as usize;
12721 let c = *channels as usize;
12722 let n = count * c;
12723 let eps = *eps;
12724 unsafe {
12725 crate::kernels::batch_norm_inference_backward_gamma(
12726 sl(*x, base, n),
12727 sl(*mean, base, c),
12728 sl(*var, base, c),
12729 sl(*dy, base, n),
12730 sl_mut(*dgamma, base, c),
12731 c,
12732 eps,
12733 );
12734 }
12735 }
12736
12737 Thunk::BatchNormInferenceBackwardBeta {
12738 dy,
12739 dbeta,
12740 count,
12741 channels,
12742 } => {
12743 let count = *count as usize;
12744 let c = *channels as usize;
12745 let n = count * c;
12746 unsafe {
12747 crate::kernels::batch_norm_inference_backward_beta(
12748 sl(*dy, base, n),
12749 sl_mut(*dbeta, base, c),
12750 c,
12751 );
12752 }
12753 }
12754
12755 Thunk::LayerNormBackwardGamma {
12756 x,
12757 dy,
12758 dgamma,
12759 rows,
12760 h,
12761 eps,
12762 } => {
12763 let rows = *rows as usize;
12764 let h = *h as usize;
12765 let eps = *eps;
12766 unsafe {
12767 let xs = sl(*x, base, rows * h);
12768 let dys = sl(*dy, base, rows * h);
12769 let out = sl_mut(*dgamma, base, h);
12770 for v in out.iter_mut() {
12771 *v = 0.0;
12772 }
12773 let n_inv = 1.0 / h as f32;
12774 for r in 0..rows {
12775 let xr = &xs[r * h..(r + 1) * h];
12776 let dyr = &dys[r * h..(r + 1) * h];
12777 let mut sum = 0f32;
12778 for &v in xr {
12779 sum += v;
12780 }
12781 let mean = sum * n_inv;
12782 let mut var = 0f32;
12783 for &v in xr {
12784 let d = v - mean;
12785 var += d * d;
12786 }
12787 let inv_std = 1.0 / (var * n_inv + eps).sqrt();
12788 for d in 0..h {
12789 let xh = (xr[d] - mean) * inv_std;
12790 out[d] += dyr[d] * xh;
12791 }
12792 }
12793 }
12794 }
12795
12796 Thunk::RmsNormBackwardInput {
12797 x,
12798 gamma,
12799 beta,
12800 dy,
12801 dx,
12802 rows,
12803 h,
12804 eps,
12805 } => {
12806 let (rows, h) = (*rows as usize, *h as usize);
12807 unsafe {
12808 let xs = sl(*x, base, rows * h);
12809 let g = sl(*gamma, base, h);
12810 let b = sl(*beta, base, h);
12811 let dys = sl(*dy, base, rows * h);
12812 let out = sl_mut(*dx, base, rows * h);
12813 let mut dg = vec![0f32; h];
12814 let mut db = vec![0f32; h];
12815 for r in 0..rows {
12816 crate::training_bwd::rms_norm_backward_row(
12817 &xs[r * h..(r + 1) * h],
12818 g,
12819 b,
12820 &dys[r * h..(r + 1) * h],
12821 &mut out[r * h..(r + 1) * h],
12822 &mut dg,
12823 &mut db,
12824 *eps,
12825 );
12826 }
12827 }
12828 }
12829
12830 Thunk::RmsNormBackwardGamma {
12831 x,
12832 gamma,
12833 beta,
12834 dy,
12835 dgamma,
12836 rows,
12837 h,
12838 eps,
12839 } => {
12840 let (rows, h) = (*rows as usize, *h as usize);
12841 unsafe {
12842 let xs = sl(*x, base, rows * h);
12843 let g = sl(*gamma, base, h);
12844 let b = sl(*beta, base, h);
12845 let dys = sl(*dy, base, rows * h);
12846 let out = sl_mut(*dgamma, base, h);
12847 for v in out.iter_mut() {
12848 *v = 0.0;
12849 }
12850 let mut dx = vec![0f32; h];
12851 let mut db = vec![0f32; h];
12852 for r in 0..rows {
12853 crate::training_bwd::rms_norm_backward_row(
12854 &xs[r * h..(r + 1) * h],
12855 g,
12856 b,
12857 &dys[r * h..(r + 1) * h],
12858 &mut dx,
12859 &mut *out,
12860 &mut db,
12861 *eps,
12862 );
12863 }
12864 }
12865 }
12866
12867 Thunk::RmsNormBackwardBeta {
12868 x,
12869 gamma,
12870 beta,
12871 dy,
12872 dbeta,
12873 rows,
12874 h,
12875 eps,
12876 } => {
12877 let (rows, h) = (*rows as usize, *h as usize);
12878 unsafe {
12879 let xs = sl(*x, base, rows * h);
12880 let g = sl(*gamma, base, h);
12881 let b = sl(*beta, base, h);
12882 let dys = sl(*dy, base, rows * h);
12883 let out = sl_mut(*dbeta, base, h);
12884 for v in out.iter_mut() {
12885 *v = 0.0;
12886 }
12887 let mut dx = vec![0f32; h];
12888 let mut dg = vec![0f32; h];
12889 for r in 0..rows {
12890 crate::training_bwd::rms_norm_backward_row(
12891 &xs[r * h..(r + 1) * h],
12892 g,
12893 b,
12894 &dys[r * h..(r + 1) * h],
12895 &mut dx,
12896 &mut dg,
12897 &mut *out,
12898 *eps,
12899 );
12900 }
12901 }
12902 }
12903
12904 Thunk::RopeBackward {
12905 dy,
12906 cos,
12907 sin,
12908 dx,
12909 batch,
12910 seq,
12911 hidden,
12912 head_dim,
12913 n_rot,
12914 cos_len,
12915 } => {
12916 let (b, s, hs, dh, nr, cl) = (
12917 *batch as usize,
12918 *seq as usize,
12919 *hidden as usize,
12920 *head_dim as usize,
12921 *n_rot as usize,
12922 *cos_len as usize,
12923 );
12924 let nh = hs / dh;
12925 let tab_half = dh / 2;
12926 unsafe {
12927 let dys = sl(*dy, base, b * s * hs);
12928 let cos_tab = sl(*cos, base, cl);
12929 let sin_tab = sl(*sin, base, cl);
12930 let out = sl_mut(*dx, base, b * s * hs);
12931 for bi in 0..b {
12932 for si in 0..s {
12933 let tab_off = si.saturating_mul(tab_half) % cl.max(1);
12934 let cp = &cos_tab[tab_off..tab_off + tab_half.min(cl)];
12935 let sp = &sin_tab[tab_off..tab_off + tab_half.min(cl)];
12936 for hi in 0..nh {
12937 let base_idx = bi * s * hs + si * hs + hi * dh;
12938 crate::training_bwd::rope_backward_row(
12939 &dys[base_idx..base_idx + dh],
12940 cp,
12941 sp,
12942 &mut out[base_idx..base_idx + dh],
12943 dh,
12944 nr,
12945 );
12946 }
12947 }
12948 }
12949 }
12950 }
12951
12952 Thunk::CumsumBackward {
12953 dy,
12954 dx,
12955 rows,
12956 cols,
12957 exclusive,
12958 } => {
12959 let (rows, cols) = (*rows as usize, *cols as usize);
12960 unsafe {
12961 let dys = sl(*dy, base, rows * cols);
12962 let out = sl_mut(*dx, base, rows * cols);
12963 for r in 0..rows {
12964 crate::training_bwd::cumsum_backward_row(
12965 &dys[r * cols..(r + 1) * cols],
12966 &mut out[r * cols..(r + 1) * cols],
12967 *exclusive,
12968 );
12969 }
12970 }
12971 }
12972
12973 Thunk::GroupNormBackwardInput {
12974 x,
12975 gamma,
12976 beta: _beta,
12977 dy,
12978 dx,
12979 n,
12980 c,
12981 h,
12982 w,
12983 num_groups,
12984 eps,
12985 } => {
12986 let (n, c, h, w) = (*n as usize, *c as usize, *h as usize, *w as usize);
12987 let plane = c * h * w;
12988 unsafe {
12989 let xs = sl(*x, base, n * plane);
12990 let g = sl(*gamma, base, c);
12991 let dys = sl(*dy, base, n * plane);
12992 let out = sl_mut(*dx, base, n * plane);
12993 crate::training_bwd::group_norm_backward_input_nchw(
12994 xs,
12995 g,
12996 dys,
12997 out,
12998 n,
12999 c,
13000 h,
13001 w,
13002 *num_groups as usize,
13003 *eps,
13004 );
13005 }
13006 }
13007
13008 Thunk::GroupNormBackwardGamma {
13009 x,
13010 dy,
13011 dgamma,
13012 n,
13013 c,
13014 h,
13015 w,
13016 num_groups,
13017 eps,
13018 } => {
13019 let (n, c, h, w) = (*n as usize, *c as usize, *h as usize, *w as usize);
13020 let plane = c * h * w;
13021 unsafe {
13022 let xs = sl(*x, base, n * plane);
13023 let dys = sl(*dy, base, n * plane);
13024 let out = sl_mut(*dgamma, base, c);
13025 crate::training_bwd::group_norm_backward_gamma_nchw(
13026 xs,
13027 dys,
13028 out,
13029 n,
13030 c,
13031 h,
13032 w,
13033 *num_groups as usize,
13034 *eps,
13035 );
13036 }
13037 }
13038
13039 Thunk::GroupNormBackwardBeta {
13040 dy,
13041 dbeta,
13042 n,
13043 c,
13044 h,
13045 w,
13046 } => {
13047 let (n, c, h, w) = (*n as usize, *c as usize, *h as usize, *w as usize);
13048 let plane = c * h * w;
13049 unsafe {
13050 let dys = sl(*dy, base, n * plane);
13051 let out = sl_mut(*dbeta, base, c);
13052 crate::training_bwd::group_norm_backward_beta_nchw(dys, out, n, c, h, w);
13053 }
13054 }
13055
13056 Thunk::GatherBackward {
13057 dy,
13058 indices,
13059 dst,
13060 outer,
13061 axis_dim,
13062 num_idx,
13063 trailing,
13064 } => {
13065 let (outer, axis_dim, num_idx, trailing) = (
13066 *outer as usize,
13067 *axis_dim as usize,
13068 *num_idx as usize,
13069 *trailing as usize,
13070 );
13071 unsafe {
13072 let dys = sl(*dy, base, outer * num_idx * trailing);
13073 let ids = sl(*indices, base, num_idx);
13074 let out = sl_mut(*dst, base, outer * axis_dim * trailing);
13075 for v in out.iter_mut() {
13076 *v = 0.0;
13077 }
13078 crate::training_bwd::gather_axis_backward(
13079 dys, ids, out, outer, axis_dim, num_idx, trailing,
13080 );
13081 }
13082 }
13083
13084 Thunk::MaxPool2dBackward {
13085 x,
13086 dy,
13087 dx,
13088 n,
13089 c,
13090 h,
13091 w,
13092 h_out,
13093 w_out,
13094 kh,
13095 kw,
13096 sh,
13097 sw,
13098 ph,
13099 pw,
13100 } => unsafe {
13101 execute_maxpool2d_backward_f32(
13102 *x, *dy, *dx, *n, *c, *h, *w, *h_out, *w_out, *kh, *kw, *sh, *sw, *ph, *pw,
13103 base,
13104 );
13105 },
13106
13107 Thunk::Conv2dBackwardInput {
13108 dy,
13109 w,
13110 dx,
13111 n,
13112 c_in,
13113 h,
13114 w_in,
13115 c_out,
13116 h_out,
13117 w_out,
13118 kh,
13119 kw,
13120 sh,
13121 sw,
13122 ph,
13123 pw,
13124 dh,
13125 dw,
13126 groups,
13127 } => {
13128 let n = *n as usize;
13140 let c_in = *c_in as usize;
13141 let h = *h as usize;
13142 let w_in = *w_in as usize;
13143 let c_out = *c_out as usize;
13144 let h_out = *h_out as usize;
13145 let w_out = *w_out as usize;
13146 let kh = *kh as usize;
13147 let kw = *kw as usize;
13148 let sh = *sh as usize;
13149 let sw = *sw as usize;
13150 let ph = *ph as usize;
13151 let pw = *pw as usize;
13152 let dh = *dh as usize;
13153 let dw = *dw as usize;
13154 let groups = *groups as usize;
13155 let c_in_per_g = c_in / groups;
13156 let c_out_per_g = c_out / groups;
13157
13158 let m_dim = c_in_per_g * kh * kw;
13159 let n_dim = h_out * w_out;
13160 let k_dim = c_out_per_g;
13161
13162 let dy_stride_n = c_out * h_out * w_out;
13163 let dy_stride_g = c_out_per_g * h_out * w_out;
13164 let w_stride_g = c_out_per_g * c_in_per_g * kh * kw;
13165 let dx_stride_n = c_in * h * w_in;
13166 let dx_stride_g = c_in_per_g * h * w_in;
13167
13168 unsafe {
13169 let dys = sl(*dy, base, n * c_out * h_out * w_out);
13170 let ws = sl(*w, base, c_out * c_in_per_g * kh * kw);
13171 let dxs = sl_mut(*dx, base, n * c_in * h * w_in);
13172 for v in dxs.iter_mut() {
13173 *v = 0.0;
13174 }
13175
13176 let mut dcol = vec![0f32; m_dim * n_dim];
13178
13179 for ni in 0..n {
13180 for g in 0..groups {
13181 let w_g_off = g * w_stride_g;
13182 let dy_n_g_off = ni * dy_stride_n + g * dy_stride_g;
13183 let dx_n_g_off = ni * dx_stride_n + g * dx_stride_g;
13184
13185 crate::blas::sgemm_general(
13190 ws.as_ptr().add(w_g_off),
13191 dys.as_ptr().add(dy_n_g_off),
13192 dcol.as_mut_ptr(),
13193 m_dim,
13194 n_dim,
13195 k_dim,
13196 1.0,
13197 0.0,
13198 m_dim,
13199 n_dim,
13200 n_dim,
13201 true,
13202 false,
13203 );
13204
13205 col2im(
13207 &dcol,
13208 &mut dxs[dx_n_g_off..dx_n_g_off + dx_stride_g],
13209 c_in_per_g,
13210 h,
13211 w_in,
13212 h_out,
13213 w_out,
13214 kh,
13215 kw,
13216 sh,
13217 sw,
13218 ph,
13219 pw,
13220 dh,
13221 dw,
13222 );
13223 }
13224 }
13225 }
13226 }
13227
13228 Thunk::Conv2dBackwardWeight {
13229 x,
13230 dy,
13231 dw,
13232 n,
13233 c_in,
13234 h,
13235 w,
13236 c_out,
13237 h_out,
13238 w_out,
13239 kh,
13240 kw,
13241 sh,
13242 sw,
13243 ph,
13244 pw,
13245 dh,
13246 dw_dil,
13247 groups,
13248 } => {
13249 let n = *n as usize;
13250 let c_in = *c_in as usize;
13251 let h = *h as usize;
13252 let w = *w as usize;
13253 let c_out = *c_out as usize;
13264 let h_out = *h_out as usize;
13265 let w_out = *w_out as usize;
13266 let kh = *kh as usize;
13267 let kw = *kw as usize;
13268 let sh = *sh as usize;
13269 let sw = *sw as usize;
13270 let ph = *ph as usize;
13271 let pw = *pw as usize;
13272 let dh = *dh as usize;
13273 let dw_dil = *dw_dil as usize;
13274 let groups = *groups as usize;
13275 let c_in_per_g = c_in / groups;
13276 let c_out_per_g = c_out / groups;
13277
13278 let m_dim = c_out_per_g;
13279 let n_dim = c_in_per_g * kh * kw;
13280 let k_dim = h_out * w_out;
13281
13282 let x_stride_n = c_in * h * w;
13283 let x_stride_g = c_in_per_g * h * w;
13284 let dy_stride_n = c_out * h_out * w_out;
13285 let dy_stride_g = c_out_per_g * h_out * w_out;
13286 let dw_stride_g = c_out_per_g * c_in_per_g * kh * kw;
13287
13288 unsafe {
13289 let xs = sl(*x, base, n * c_in * h * w);
13290 let dys = sl(*dy, base, n * c_out * h_out * w_out);
13291 let dws = sl_mut(*dw, base, c_out * c_in_per_g * kh * kw);
13292 for v in dws.iter_mut() {
13293 *v = 0.0;
13294 }
13295
13296 let mut col = vec![0f32; n_dim * k_dim];
13297
13298 for ni in 0..n {
13299 for g in 0..groups {
13300 let x_n_g_off = ni * x_stride_n + g * x_stride_g;
13301 im2col(
13302 &xs[x_n_g_off..x_n_g_off + x_stride_g],
13303 &mut col,
13304 c_in_per_g,
13305 h,
13306 w,
13307 h_out,
13308 w_out,
13309 kh,
13310 kw,
13311 sh,
13312 sw,
13313 ph,
13314 pw,
13315 dh,
13316 dw_dil,
13317 );
13318
13319 let dy_n_g_off = ni * dy_stride_n + g * dy_stride_g;
13320 let dw_g_off = g * dw_stride_g;
13321
13322 crate::blas::sgemm_general(
13330 dys.as_ptr().add(dy_n_g_off),
13331 col.as_ptr(),
13332 dws.as_mut_ptr().add(dw_g_off),
13333 m_dim,
13334 n_dim,
13335 k_dim,
13336 1.0,
13337 1.0,
13338 k_dim,
13339 k_dim,
13340 n_dim,
13341 false,
13342 true,
13343 );
13344 }
13345 }
13346 }
13347 }
13348
13349 Thunk::Im2Col {
13350 x,
13351 col,
13352 n,
13353 c_in,
13354 h,
13355 w,
13356 h_out,
13357 w_out,
13358 kh,
13359 kw,
13360 sh,
13361 sw,
13362 ph,
13363 pw,
13364 dh,
13365 dw_dil,
13366 } => {
13367 let c_in = *c_in as usize;
13368 let h = *h as usize;
13369 let w = *w as usize;
13370 let h_out = *h_out as usize;
13371 let w_out = *w_out as usize;
13372 let kh = *kh as usize;
13373 let kw = *kw as usize;
13374 let sh = *sh as usize;
13375 let sw = *sw as usize;
13376 let ph = *ph as usize;
13377 let pw = *pw as usize;
13378 let dh = *dh as usize;
13379 let dw_dil = *dw_dil as usize;
13380 let per_batch = c_in * h * w;
13381 unsafe {
13382 let n_eff = if *n == 0 { 0usize } else { *n as usize };
13383 let x_floats = if n_eff == 0 {
13384 per_batch.max(1)
13385 } else {
13386 n_eff * per_batch
13387 };
13388 let xs = sl(*x, base, x_floats);
13389 let n = if *n == 0 {
13390 xs.len() / per_batch.max(1)
13391 } else {
13392 n_eff
13393 };
13394 let m = n * h_out * w_out;
13395 let k = c_in * kh * kw;
13396 let cols = sl_mut(*col, base, m * k);
13397 crate::im2col::im2col_rows_layout(
13398 xs, cols, n, c_in, h, w, h_out, w_out, kh, kw, sh, sw, ph, pw, dh, dw_dil,
13399 );
13400 }
13401 }
13402
13403 Thunk::SoftmaxCrossEntropy {
13404 logits,
13405 labels,
13406 dst,
13407 n,
13408 c,
13409 } => {
13410 let n = *n as usize;
13411 let c = *c as usize;
13412 unsafe {
13413 let lg = sl(*logits, base, n * c);
13414 let lb = sl(*labels, base, n);
13415 let out = sl_mut(*dst, base, n);
13416 for ni in 0..n {
13417 let row = &lg[ni * c..(ni + 1) * c];
13418 let mut m = f32::NEG_INFINITY;
13420 for &v in row {
13421 if v > m {
13422 m = v;
13423 }
13424 }
13425 let mut sum = 0f32;
13426 for &v in row {
13427 sum += (v - m).exp();
13428 }
13429 let lse = m + sum.ln();
13430 let label_idx = lb[ni] as usize;
13431 out[ni] = lse - row[label_idx];
13433 }
13434 }
13435 }
13436
13437 Thunk::SoftmaxCrossEntropyBackward {
13438 logits,
13439 labels,
13440 d_loss,
13441 dlogits,
13442 n,
13443 c,
13444 } => {
13445 let n = *n as usize;
13446 let c = *c as usize;
13447 unsafe {
13448 let lg = sl(*logits, base, n * c);
13449 let lb = sl(*labels, base, n);
13450 let dl = sl(*d_loss, base, n);
13451 let out = sl_mut(*dlogits, base, n * c);
13452 for ni in 0..n {
13453 let row = &lg[ni * c..(ni + 1) * c];
13454 let label_idx = lb[ni] as usize;
13455 let scale = dl[ni];
13456 let mut m = f32::NEG_INFINITY;
13457 for &v in row {
13458 if v > m {
13459 m = v;
13460 }
13461 }
13462 let mut sum = 0f32;
13463 for &v in row {
13464 sum += (v - m).exp();
13465 }
13466 let inv_sum = 1.0 / sum;
13467 let dst_row = &mut out[ni * c..(ni + 1) * c];
13468 for k in 0..c {
13469 let p = (row[k] - m).exp() * inv_sum;
13470 let one_hot = if k == label_idx { 1.0 } else { 0.0 };
13471 dst_row[k] = (p - one_hot) * scale;
13472 }
13473 }
13474 }
13475 }
13476
13477 Thunk::GatherAxis {
13478 table,
13479 idx,
13480 dst,
13481 outer,
13482 axis_dim,
13483 num_idx,
13484 trailing,
13485 idx_i64,
13486 table_bytes,
13487 } => {
13488 let outer = *outer as usize;
13489 let axis_dim = *axis_dim as usize;
13490 let num_idx = *num_idx as usize;
13491 let trailing = *trailing as usize;
13492 unsafe {
13493 if *table_bytes == 8 {
13494 let tab = sl_i64(*table, base, outer * axis_dim * trailing);
13495 let out = sl_mut_i64(*dst, base, outer * num_idx * trailing);
13496 for o in 0..outer {
13497 let tab_outer = o * axis_dim * trailing;
13498 let out_outer = o * num_idx * trailing;
13499 if *idx_i64 != 0 {
13500 let ids = sl_i64(*idx, base, num_idx);
13501 for k in 0..num_idx {
13502 let row = ids[k].max(0) as usize;
13503 if row < axis_dim {
13504 let tab_row = tab_outer + row * trailing;
13505 let out_row = out_outer + k * trailing;
13506 out[out_row..out_row + trailing]
13507 .copy_from_slice(&tab[tab_row..tab_row + trailing]);
13508 }
13509 }
13510 } else {
13511 let ids = sl(*idx, base, num_idx);
13512 for k in 0..num_idx {
13513 let row = ids[k] as usize;
13514 if row < axis_dim {
13515 let tab_row = tab_outer + row * trailing;
13516 let out_row = out_outer + k * trailing;
13517 out[out_row..out_row + trailing]
13518 .copy_from_slice(&tab[tab_row..tab_row + trailing]);
13519 }
13520 }
13521 }
13522 }
13523 } else {
13524 let tab = sl(*table, base, outer * axis_dim * trailing);
13525 let out = sl_mut(*dst, base, outer * num_idx * trailing);
13526 for o in 0..outer {
13527 let tab_outer = o * axis_dim * trailing;
13528 let out_outer = o * num_idx * trailing;
13529 if *idx_i64 != 0 {
13530 let ids = sl_i64(*idx, base, num_idx);
13531 for k in 0..num_idx {
13532 let row = ids[k].max(0) as usize;
13533 if row < axis_dim {
13534 let tab_row = tab_outer + row * trailing;
13535 let out_row = out_outer + k * trailing;
13536 out[out_row..out_row + trailing]
13537 .copy_from_slice(&tab[tab_row..tab_row + trailing]);
13538 }
13539 }
13540 } else {
13541 let ids = sl(*idx, base, num_idx);
13542 for k in 0..num_idx {
13543 let row = ids[k] as usize;
13544 if row < axis_dim {
13545 let tab_row = tab_outer + row * trailing;
13546 let out_row = out_outer + k * trailing;
13547 out[out_row..out_row + trailing]
13548 .copy_from_slice(&tab[tab_row..tab_row + trailing]);
13549 }
13550 }
13551 }
13552 }
13553 }
13554 }
13555 }
13556
13557 Thunk::Transpose {
13558 src,
13559 dst,
13560 in_total,
13561 out_dims,
13562 in_strides,
13563 elem_bytes,
13564 } => {
13565 let rank = out_dims.len();
13570 let total: usize = out_dims.iter().map(|&d| d as usize).product();
13571 let in_total = *in_total as usize;
13572 unsafe {
13573 if *elem_bytes == 8 {
13574 let inp = sl_i64(*src, base, in_total);
13575 let out = sl_mut_i64(*dst, base, total);
13576 let mut idx = vec![0usize; rank];
13577 for o in 0..total {
13578 let mut src_idx = 0usize;
13579 for d in 0..rank {
13580 src_idx += idx[d] * in_strides[d] as usize;
13581 }
13582 out[o] = inp[src_idx];
13583 for d in (0..rank).rev() {
13584 idx[d] += 1;
13585 if idx[d] < out_dims[d] as usize {
13586 break;
13587 }
13588 idx[d] = 0;
13589 }
13590 }
13591 } else {
13592 let inp = sl(*src, base, in_total);
13593 let out = sl_mut(*dst, base, total);
13594 let mut idx = vec![0usize; rank];
13595 for o in 0..total {
13596 let mut src_idx = 0usize;
13597 for d in 0..rank {
13598 src_idx += idx[d] * in_strides[d] as usize;
13599 }
13600 out[o] = inp[src_idx];
13601 for d in (0..rank).rev() {
13602 idx[d] += 1;
13603 if idx[d] < out_dims[d] as usize {
13604 break;
13605 }
13606 idx[d] = 0;
13607 }
13608 }
13609 }
13610 }
13611 }
13612
13613 Thunk::CustomOp {
13619 kernel,
13620 inputs,
13621 output,
13622 attrs,
13623 } => {
13624 let (out_off, out_len, out_shape) = output;
13625 unsafe {
13626 dispatch_custom_op(
13627 &**kernel, inputs, *out_off, *out_len, out_shape, attrs, base,
13628 );
13629 }
13630 }
13631 }
13632 if trace_done {
13633 eprintln!("[thunk {i} done]");
13634 }
13635 }
13636}
13637
13638#[allow(clippy::too_many_arguments)]
13653unsafe fn griewank_process_segment(
13654 t_lo: usize,
13655 t_hi: usize,
13656 anchor_carry: &[u8],
13657 cb: usize,
13658 fwd_sched: &ThunkSchedule,
13659 fwd_init: &[u8],
13660 fwd_carry_in_off: usize,
13661 fwd_output_off: usize,
13662 fwd_x_offs: &[usize],
13663 base: *mut u8,
13664 outer_xs_offs: &[(usize, u32)],
13665 fwd_buf: &mut Vec<u8>,
13666 leaf_threshold: usize,
13667 process_iter: &mut dyn FnMut(usize, &[u8]),
13668) {
13669 unsafe {
13670 let size = t_hi - t_lo + 1;
13671 if size == 1 {
13672 process_iter(t_lo, anchor_carry);
13673 return;
13674 }
13675 if size <= leaf_threshold {
13676 let mut cache: Vec<u8> = Vec::with_capacity(size * cb);
13678 cache.extend_from_slice(anchor_carry);
13679 fwd_buf.copy_from_slice(fwd_init);
13680 std::ptr::copy_nonoverlapping(
13681 anchor_carry.as_ptr(),
13682 fwd_buf.as_mut_ptr().add(fwd_carry_in_off),
13683 cb,
13684 );
13685 for i in 1..size {
13686 let cur_iter = t_lo + i - 1;
13687 for (idx, fb_x_off) in fwd_x_offs.iter().enumerate() {
13688 let (outer_xs_off, x_psb) = outer_xs_offs[idx];
13689 let xb = x_psb as usize;
13690 std::ptr::copy_nonoverlapping(
13691 base.add(outer_xs_off + cur_iter * xb),
13692 fwd_buf.as_mut_ptr().add(*fb_x_off),
13693 xb,
13694 );
13695 }
13696 execute_thunks(fwd_sched, fwd_buf);
13697 if fwd_output_off != fwd_carry_in_off {
13698 fwd_buf.copy_within(fwd_output_off..fwd_output_off + cb, fwd_carry_in_off);
13699 }
13700 cache.extend_from_slice(&fwd_buf[fwd_carry_in_off..fwd_carry_in_off + cb]);
13701 }
13702 for t in (t_lo..=t_hi).rev() {
13704 let idx = t - t_lo;
13705 let carry = &cache[idx * cb..(idx + 1) * cb];
13706 process_iter(t, carry);
13707 }
13708 return;
13709 }
13710
13711 let mid = t_lo + size / 2;
13715 fwd_buf.copy_from_slice(fwd_init);
13716 std::ptr::copy_nonoverlapping(
13717 anchor_carry.as_ptr(),
13718 fwd_buf.as_mut_ptr().add(fwd_carry_in_off),
13719 cb,
13720 );
13721 for cur_iter in t_lo..mid {
13722 for (idx, fb_x_off) in fwd_x_offs.iter().enumerate() {
13723 let (outer_xs_off, x_psb) = outer_xs_offs[idx];
13724 let xb = x_psb as usize;
13725 std::ptr::copy_nonoverlapping(
13726 base.add(outer_xs_off + cur_iter * xb),
13727 fwd_buf.as_mut_ptr().add(*fb_x_off),
13728 xb,
13729 );
13730 }
13731 execute_thunks(fwd_sched, fwd_buf);
13732 if fwd_output_off != fwd_carry_in_off {
13733 fwd_buf.copy_within(fwd_output_off..fwd_output_off + cb, fwd_carry_in_off);
13734 }
13735 }
13736 let mid_carry: Vec<u8> = fwd_buf[fwd_carry_in_off..fwd_carry_in_off + cb].to_vec();
13737
13738 griewank_process_segment(
13742 mid,
13743 t_hi,
13744 &mid_carry,
13745 cb,
13746 fwd_sched,
13747 fwd_init,
13748 fwd_carry_in_off,
13749 fwd_output_off,
13750 fwd_x_offs,
13751 base,
13752 outer_xs_offs,
13753 fwd_buf,
13754 leaf_threshold,
13755 process_iter,
13756 );
13757 griewank_process_segment(
13759 t_lo,
13760 mid - 1,
13761 anchor_carry,
13762 cb,
13763 fwd_sched,
13764 fwd_init,
13765 fwd_carry_in_off,
13766 fwd_output_off,
13767 fwd_x_offs,
13768 base,
13769 outer_xs_offs,
13770 fwd_buf,
13771 leaf_threshold,
13772 process_iter,
13773 );
13774 }
13775}
13776
13777pub unsafe fn execute_fft1d_f64(
13794 src: usize,
13795 dst: usize,
13796 outer: usize,
13797 n_complex: usize,
13798 inverse: bool,
13799 norm_tag: u32,
13800 base: *mut u8,
13801) {
13802 let row_elems = 2 * n_complex;
13803 let mut re = vec![0f64; n_complex];
13804 let mut im = vec![0f64; n_complex];
13805 let norm = rlx_ir::fft::FftNorm::from_tag(norm_tag);
13806 let scale = norm.output_scale(n_complex, inverse);
13807 let mut scratch = if n_complex.is_power_of_two() || n_complex <= 16 {
13810 BluesteinScratchF64::empty()
13811 } else {
13812 BluesteinScratchF64::build(n_complex, inverse)
13813 };
13814 for o in 0..outer {
13815 let row_offset = src + o * row_elems * std::mem::size_of::<f64>();
13816 let s = unsafe { sl_f64(row_offset, base, row_elems) };
13817 re.copy_from_slice(&s[..n_complex]);
13818 im.copy_from_slice(&s[n_complex..]);
13819 if n_complex.is_power_of_two() {
13820 fft_radix2_inplace_f64(&mut re, &mut im, inverse);
13821 } else if n_complex <= 16 {
13822 fft_naive_inplace_f64(&mut re, &mut im, inverse);
13823 } else {
13824 fft_bluestein_inplace_f64(&mut re, &mut im, inverse, &mut scratch);
13825 }
13826 if scale != 1.0 {
13827 re.iter_mut().for_each(|v| *v *= scale);
13828 im.iter_mut().for_each(|v| *v *= scale);
13829 }
13830 let dst_offset = dst + o * row_elems * std::mem::size_of::<f64>();
13831 let d = unsafe { sl_mut_f64(dst_offset, base, row_elems) };
13832 d[..n_complex].copy_from_slice(&re);
13833 d[n_complex..].copy_from_slice(&im);
13834 }
13835}
13836
/// Gated DeltaNet forward pass in f32 over `batch` sequences of `seq` tokens and `heads` heads.
///
/// Inputs are byte offsets into the arena at `base`:
/// - `q`, `k`, `v` hold `batch * seq * heads * state_size` f32, indexed
///   `[bi][ti][hi][0..state_size]`.
/// - `g` and `beta` hold `batch * seq * heads` scalars, indexed `[bi][ti][hi]`.
///
/// For each (batch, head) pair, a `state_size x state_size` state is advanced once per token
/// by `crate::gdn::gdn_step_blas`. That call also writes the token's output row into `dst`,
/// which uses the same layout as `q`. The scale passed to each step is `1 / sqrt(state_size)`.
///
/// State handling:
/// - `state == 0`: every (batch, head) starts from a zero state, which is discarded afterwards.
/// - `state != 0`: the state at that byte offset is used and updated in place. Its layout is
///   `[bi][hi][state_size * state_size]`.
///
/// Heads are processed in parallel with rayon.
///
/// # Panics
/// Panics if `state_size > crate::gdn::GDN_MAX_STATE`, the size of the per-head stack scratch.
///
/// # Safety
/// All offsets must address valid, f32-aligned arena regions of the sizes above. `dst` and
/// `state` must not overlap the inputs.
pub unsafe fn execute_gated_delta_net_f32(
    q: usize,
    k: usize,
    v: usize,
    g: usize,
    beta: usize,
    state: usize,
    dst: usize,
    batch: usize,
    seq: usize,
    heads: usize,
    state_size: usize,
    base: *mut u8,
) {
    use rayon::prelude::*;

    // The arena base address is stored as an integer because raw pointers are neither `Send`
    // nor `Sync`. This wrapper lets the closures below be shared with rayon worker threads.
    #[derive(Copy, Clone)]
    struct ArenaPtr(usize);
    unsafe impl Send for ArenaPtr {}
    unsafe impl Sync for ArenaPtr {}
    impl ArenaPtr {
        #[inline]
        fn get(self) -> *mut u8 {
            self.0 as *mut u8
        }
    }

    unsafe {
        let arena = ArenaPtr(base as usize);
        let (b, s, h, n) = (batch, seq, heads, state_size);
        let scale = 1.0f32 / (n as f32).sqrt();
        // A non-zero `state` offset means the caller supplies the state and receives it back.
        let use_external = state != 0;
        // Zeroed per-batch state for every head. Only the `!use_external && s <= 1` path at the
        // bottom uses it. The other paths allocate it but never touch it.
        let mut owned_state = vec![0f32; h * n * n];

        // The return value is discarded. NOTE(review): presumably called for a side effect,
        // such as initialising the thread pool. Confirm.
        crate::pool::num_threads();

        // The per-head `sk` and state scratch buffers below are fixed-size stack arrays.
        assert!(
            n <= crate::gdn::GDN_MAX_STATE,
            "GatedDeltaNet state_size={n} exceeds stack scratch ({})",
            crate::gdn::GDN_MAX_STATE
        );

        let qs = sl(q, arena.get(), b * s * h * n);
        let ks = sl(k, arena.get(), b * s * h * n);
        let vs = sl(v, arena.get(), b * s * h * n);
        let gs = sl(g, arena.get(), b * s * h);
        let betas = sl(beta, arena.get(), b * s * h);
        // Unused. Output rows are re-sliced per token inside `run_head`.
        let _out = sl_mut(dst, arena.get(), b * s * h * n);
        let hs_n = h * n;

        // Runs one (batch `bi`, head `hi`) recurrence over all `s` tokens and writes each
        // token's output row into `dst`. `s_mat` is that head's n*n state; `sk` is an
        // n-long scratch buffer for `gdn_step_blas`.
        let run_head = |bi: usize, hi: usize, s_mat: &mut [f32], sk: &mut [f32]| {
            for ti in 0..s {
                let qkv_step = bi * s * hs_n + ti * hs_n + hi * n;
                let gb_step = bi * s * h + ti * h + hi;
                let out_row = sl_mut(dst + qkv_step * std::mem::size_of::<f32>(), arena.get(), n);
                crate::gdn::gdn_step_blas(
                    s_mat,
                    &qs[qkv_step..qkv_step + n],
                    &ks[qkv_step..qkv_step + n],
                    &vs[qkv_step..qkv_step + n],
                    gs[gb_step],
                    betas[gb_step],
                    out_row,
                    sk,
                    n,
                    scale,
                );
            }
        };

        // Internal zero state with more than one token: each parallel task keeps its own
        // state in a stack array. NOTE(review): that array is GDN_MAX_STATE^2 f32 on a rayon
        // worker's stack. Confirm it fits for the configured GDN_MAX_STATE.
        if !use_external && s > 1 {
            for bi in 0..b {
                (0..h).into_par_iter().for_each(|hi| {
                    let mut sk_buf = [0f32; crate::gdn::GDN_MAX_STATE];
                    let sk = &mut sk_buf[..n];
                    let mut local_state =
                        [0f32; crate::gdn::GDN_MAX_STATE * crate::gdn::GDN_MAX_STATE];
                    let s_mat = &mut local_state[..n * n];
                    // Redundant: the array initialiser already zeroed it.
                    s_mat.fill(0.0);
                    run_head(bi, hi, s_mat, sk);
                });
            }
            return;
        }

        if use_external {
            // `state` is a byte offset. Each (batch, head) owns the disjoint n*n slice at
            // element offset (bi * h + hi) * n * n, so the parallel `&mut` views never overlap.
            let state_bytes = state;
            (0..b * h).into_par_iter().for_each(|bhi| {
                let bi = bhi / h;
                let hi = bhi % h;
                let elem_off = bi * h * n * n + hi * n * n;
                let s_mat = sl_mut(
                    state_bytes + elem_off * std::mem::size_of::<f32>(),
                    arena.get(),
                    n * n,
                );
                let mut sk_buf = [0f32; crate::gdn::GDN_MAX_STATE];
                run_head(bi, hi, s_mat, &mut sk_buf[..n]);
            });
        } else {
            // Reached only when `state == 0` and `s <= 1`, because the `s > 1` case returned
            // above. One zeroed heap buffer is reused per batch, with one n*n chunk per head.
            for bi in 0..b {
                owned_state.fill(0.0);
                owned_state
                    .par_chunks_mut(n * n)
                    .enumerate()
                    .for_each(|(hi, s_mat)| {
                        let mut sk_buf = [0f32; crate::gdn::GDN_MAX_STATE];
                        run_head(bi, hi, s_mat, &mut sk_buf[..n]);
                    });
            }
        }
    }
}
13960
13961pub unsafe fn execute_rms_norm_backward_input_f32(
13963 x: usize,
13964 gamma: usize,
13965 beta: usize,
13966 dy: usize,
13967 dx: usize,
13968 rows: u32,
13969 h: u32,
13970 eps: f32,
13971 base: *mut u8,
13972) {
13973 let (rows, h) = (rows as usize, h as usize);
13974 let mut dg = vec![0f32; h];
13975 let mut db = vec![0f32; h];
13976 let xs = sl(x, base, rows * h);
13977 let dys = sl(dy, base, rows * h);
13978 let g = sl(gamma, base, h);
13979 let b = sl(beta, base, h);
13980 let out = sl_mut(dx, base, rows * h);
13981 for r in 0..rows {
13982 crate::training_bwd::rms_norm_backward_row(
13983 &xs[r * h..(r + 1) * h],
13984 g,
13985 b,
13986 &dys[r * h..(r + 1) * h],
13987 &mut out[r * h..(r + 1) * h],
13988 &mut dg,
13989 &mut db,
13990 eps,
13991 );
13992 }
13993}
13994
13995pub unsafe fn execute_rms_norm_backward_gamma_f32(
13996 x: usize,
13997 gamma: usize,
13998 beta: usize,
13999 dy: usize,
14000 dgamma: usize,
14001 rows: u32,
14002 h: u32,
14003 eps: f32,
14004 base: *mut u8,
14005) {
14006 let (rows, h) = (rows as usize, h as usize);
14007 let out = sl_mut(dgamma, base, h);
14008 out.fill(0.0);
14009 let mut dx = vec![0f32; h];
14010 let mut db = vec![0f32; h];
14011 let xs = sl(x, base, rows * h);
14012 let dys = sl(dy, base, rows * h);
14013 let g = sl(gamma, base, h);
14014 let b = sl(beta, base, h);
14015 for r in 0..rows {
14016 crate::training_bwd::rms_norm_backward_row(
14017 &xs[r * h..(r + 1) * h],
14018 g,
14019 b,
14020 &dys[r * h..(r + 1) * h],
14021 &mut dx,
14022 out,
14023 &mut db,
14024 eps,
14025 );
14026 }
14027}
14028
14029pub unsafe fn execute_rms_norm_backward_beta_f32(
14030 x: usize,
14031 gamma: usize,
14032 beta: usize,
14033 dy: usize,
14034 dbeta: usize,
14035 rows: u32,
14036 h: u32,
14037 eps: f32,
14038 base: *mut u8,
14039) {
14040 let (rows, h) = (rows as usize, h as usize);
14041 let out = sl_mut(dbeta, base, h);
14042 out.fill(0.0);
14043 let mut dx = vec![0f32; h];
14044 let mut dg = vec![0f32; h];
14045 let xs = sl(x, base, rows * h);
14046 let dys = sl(dy, base, rows * h);
14047 let g = sl(gamma, base, h);
14048 let b = sl(beta, base, h);
14049 for r in 0..rows {
14050 crate::training_bwd::rms_norm_backward_row(
14051 &xs[r * h..(r + 1) * h],
14052 g,
14053 b,
14054 &dys[r * h..(r + 1) * h],
14055 &mut dx,
14056 &mut dg,
14057 out,
14058 eps,
14059 );
14060 }
14061}
14062
14063#[allow(clippy::too_many_arguments)]
14064pub unsafe fn execute_conv2d_forward_f32(
14065 src: usize,
14066 weight: usize,
14067 dst: usize,
14068 n: u32,
14069 c_in: u32,
14070 h: u32,
14071 w: u32,
14072 c_out: u32,
14073 h_out: u32,
14074 w_out: u32,
14075 kh: u32,
14076 kw: u32,
14077 sh: u32,
14078 sw: u32,
14079 ph: u32,
14080 pw: u32,
14081 dh: u32,
14082 dw: u32,
14083 groups: u32,
14084 base: *mut u8,
14085) {
14086 let n = n as usize;
14087 let c_in = c_in as usize;
14088 let h = h as usize;
14089 let w = w as usize;
14090 let c_out = c_out as usize;
14091 let h_out = h_out as usize;
14092 let w_out = w_out as usize;
14093 let kh = kh as usize;
14094 let kw = kw as usize;
14095 let sh = sh as usize;
14096 let sw = sw as usize;
14097 let ph = ph as usize;
14098 let pw = pw as usize;
14099 let dh = dh as usize;
14100 let dw = dw as usize;
14101 let groups = groups as usize;
14102 let c_in_per_g = c_in / groups;
14103 let inp = sl(src, base, n * c_in * h * w);
14104 let wt = sl(weight, base, c_out * c_in_per_g * kh * kw);
14105 let out = sl_mut(dst, base, n * c_out * h_out * w_out);
14106 crate::conv_fwd::conv2d_forward_nchw_f32(
14107 inp, wt, out, n, c_in, h, w, c_out, h_out, w_out, kh, kw, sh, sw, ph, pw, dh, dw, groups,
14108 );
14109}
14110
14111pub unsafe fn execute_maxpool2d_backward_f32(
14112 x: usize,
14113 dy: usize,
14114 dx: usize,
14115 n: u32,
14116 c: u32,
14117 h: u32,
14118 w: u32,
14119 h_out: u32,
14120 w_out: u32,
14121 kh: u32,
14122 kw: u32,
14123 sh: u32,
14124 sw: u32,
14125 ph: u32,
14126 pw: u32,
14127 base: *mut u8,
14128) {
14129 let (n, c, h, w) = (n as usize, c as usize, h as usize, w as usize);
14130 let (h_out, w_out) = (h_out as usize, w_out as usize);
14131 let (kh, kw) = (kh as usize, kw as usize);
14132 let (sh, sw) = (sh as usize, sw as usize);
14133 let (ph, pw) = (ph as usize, pw as usize);
14134 let xs = sl(x, base, n * c * h * w);
14135 let dys = sl(dy, base, n * c * h_out * w_out);
14136 let dxs = sl_mut(dx, base, n * c * h * w);
14137 crate::training_bwd::maxpool2d_backward_nchw(
14138 xs, dys, dxs, n, c, h, w, h_out, w_out, kh, kw, sh, sw, ph, pw,
14139 );
14140}
14141
14142pub unsafe fn execute_rope_backward_f32(
14143 dy: usize,
14144 cos: usize,
14145 sin: usize,
14146 dx: usize,
14147 batch: u32,
14148 seq: u32,
14149 hidden: u32,
14150 head_dim: u32,
14151 n_rot: u32,
14152 cos_len: u32,
14153 base: *mut u8,
14154) {
14155 let (b, s, hs, dh, nr, cl) = (
14156 batch as usize,
14157 seq as usize,
14158 hidden as usize,
14159 head_dim as usize,
14160 n_rot as usize,
14161 cos_len as usize,
14162 );
14163 let nh = hs / dh;
14164 let tab_half = dh / 2;
14165 let dys = sl(dy, base, b * s * hs);
14166 let cos_tab = sl(cos, base, cl);
14167 let sin_tab = sl(sin, base, cl);
14168 let out = sl_mut(dx, base, b * s * hs);
14169 for bi in 0..b {
14170 for si in 0..s {
14171 let tab_off = si.saturating_mul(tab_half) % cl.max(1);
14172 let cp = &cos_tab[tab_off..tab_off + tab_half.min(cl)];
14173 let sp = &sin_tab[tab_off..tab_off + tab_half.min(cl)];
14174 for hi in 0..nh {
14175 let base_idx = bi * s * hs + si * hs + hi * dh;
14176 crate::training_bwd::rope_backward_row(
14177 &dys[base_idx..base_idx + dh],
14178 cp,
14179 sp,
14180 &mut out[base_idx..base_idx + dh],
14181 dh,
14182 nr,
14183 );
14184 }
14185 }
14186 }
14187}
14188
14189pub unsafe fn execute_cumsum_backward_f32(
14190 dy: usize,
14191 dx: usize,
14192 rows: u32,
14193 cols: u32,
14194 exclusive: bool,
14195 base: *mut u8,
14196) {
14197 let (rows, cols) = (rows as usize, cols as usize);
14198 let dys = sl(dy, base, rows * cols);
14199 let out = sl_mut(dx, base, rows * cols);
14200 for r in 0..rows {
14201 crate::training_bwd::cumsum_backward_row(
14202 &dys[r * cols..(r + 1) * cols],
14203 &mut out[r * cols..(r + 1) * cols],
14204 exclusive,
14205 );
14206 }
14207}
14208
14209pub unsafe fn execute_gather_backward_f32(
14210 dy: usize,
14211 indices: usize,
14212 dst: usize,
14213 outer: u32,
14214 axis_dim: u32,
14215 num_idx: u32,
14216 trailing: u32,
14217 base: *mut u8,
14218) {
14219 let (outer, axis_dim, num_idx, trailing) = (
14220 outer as usize,
14221 axis_dim as usize,
14222 num_idx as usize,
14223 trailing as usize,
14224 );
14225 let out = sl_mut(dst, base, outer * axis_dim * trailing);
14226 out.fill(0.0);
14227 crate::training_bwd::gather_axis_backward(
14228 sl(dy, base, outer * num_idx * trailing),
14229 sl(indices, base, num_idx),
14230 out,
14231 outer,
14232 axis_dim,
14233 num_idx,
14234 trailing,
14235 );
14236}
14237
14238pub unsafe fn execute_dequant_matmul_gguf_f32(
14240 x: usize,
14241 w_q: usize,
14242 dst: usize,
14243 m: usize,
14244 k: usize,
14245 n: usize,
14246 scheme: rlx_ir::quant::QuantScheme,
14247 base: *mut u8,
14248) {
14249 unsafe {
14250 let block_bytes = scheme.gguf_block_bytes() as usize;
14251 let block_elems = scheme.gguf_block_size() as usize;
14252 let total_bytes = (k * n) / block_elems * block_bytes;
14253 let xs = sl(x, base, m * k);
14254 let w_bytes = std::slice::from_raw_parts(base.add(w_q) as *const u8, total_bytes);
14255 let out = sl_mut(dst, base, m * n);
14256 crate::gguf_matmul::gguf_matmul_bt(xs, w_bytes, out, m, k, n, scheme);
14257 }
14258}
14259
14260pub unsafe fn execute_dequant_grouped_matmul_gguf_f32(
14262 input: usize,
14263 w_q: usize,
14264 expert_idx: usize,
14265 dst: usize,
14266 m: usize,
14267 k: usize,
14268 n: usize,
14269 num_experts: usize,
14270 scheme: rlx_ir::quant::QuantScheme,
14271 base: *mut u8,
14272) {
14273 unsafe {
14274 let block_bytes = scheme.gguf_block_bytes() as usize;
14275 let block_elems = scheme.gguf_block_size() as usize;
14276 let slab_bytes = (k * n) / block_elems * block_bytes;
14277 let xs = sl(input, base, m * k);
14278 let w_bytes =
14279 std::slice::from_raw_parts(base.add(w_q) as *const u8, num_experts * slab_bytes);
14280 let ids = sl(expert_idx, base, m);
14281 let out = sl_mut(dst, base, m * n);
14282 crate::gguf_matmul::gguf_grouped_matmul_bt(
14283 xs,
14284 w_bytes,
14285 ids,
14286 out,
14287 m,
14288 k,
14289 n,
14290 num_experts,
14291 scheme,
14292 );
14293 }
14294}
14295
14296pub unsafe fn execute_dequant_matmul_int4_f32(
14298 x: usize,
14299 w_q: usize,
14300 scale: usize,
14301 zp: usize,
14302 dst: usize,
14303 m: usize,
14304 k: usize,
14305 n: usize,
14306 block_size: u32,
14307 is_asymmetric: bool,
14308 base: *mut u8,
14309) {
14310 let bs = block_size as usize;
14311 let n_blocks = k.div_ceil(bs);
14312 unsafe {
14313 let xs = sl(x, base, m * k);
14314 let w_bytes = std::slice::from_raw_parts(base.add(w_q) as *const u8, (k * n).div_ceil(2));
14315 let scales = sl(scale, base, n_blocks * n);
14316 let zps = if is_asymmetric {
14317 sl(zp, base, n_blocks * n)
14318 } else {
14319 &[][..]
14320 };
14321 let out = sl_mut(dst, base, m * n);
14322 dequant_matmul_int4(xs, w_bytes, scales, zps, out, m, k, n, bs, is_asymmetric);
14323 }
14324}
14325
14326pub unsafe fn execute_dequant_matmul_fp8_f32(
14328 x: usize,
14329 w_q: usize,
14330 scale: usize,
14331 dst: usize,
14332 m: usize,
14333 k: usize,
14334 n: usize,
14335 e5m2: bool,
14336 base: *mut u8,
14337) {
14338 unsafe {
14339 let xs = sl(x, base, m * k);
14340 let w_bytes = std::slice::from_raw_parts(base.add(w_q) as *const u8, k * n);
14341 let scales = sl(scale, base, n);
14342 let out = sl_mut(dst, base, m * n);
14343 dequant_matmul_fp8(xs, w_bytes, scales, out, m, k, n, e5m2);
14344 }
14345}
14346
14347pub unsafe fn execute_dequant_matmul_nvfp4_f32(
14349 x: usize,
14350 w_q: usize,
14351 scale: usize,
14352 global_scale: usize,
14353 dst: usize,
14354 m: usize,
14355 k: usize,
14356 n: usize,
14357 base: *mut u8,
14358) {
14359 let n_scale = k.div_ceil(rlx_ir::NVFP4_GROUP_SIZE) * n;
14360 unsafe {
14361 let xs = sl(x, base, m * k);
14362 let w_bytes = std::slice::from_raw_parts(base.add(w_q) as *const u8, (k * n).div_ceil(2));
14363 let scale_bytes = std::slice::from_raw_parts(base.add(scale) as *const u8, n_scale);
14364 let gs = sl(global_scale, base, 1)[0];
14365 let out = sl_mut(dst, base, m * n);
14366 dequant_matmul_nvfp4(xs, w_bytes, scale_bytes, gs, out, m, k, n);
14367 }
14368}
14369
/// Gated DeltaNet forward pass for f16 tensors, computed in f32 on a single thread.
///
/// Every input is decoded from little-endian f16 into an f32 vector and processed. The
/// output is then encoded back to f16, and so is the state when `state != 0`.
///
/// Layouts:
/// - `q`, `k`, `v`, `dst`: `[bi][ti][hi][n]`
/// - `g`, `beta`: `[bi][ti][hi]`
/// - external state: `[bi][hi][n * n]`, with `S[i][j] = s_mat[i * n + j]`
///
/// For each token and head the recurrence is:
///   S *= exp(g);  u = (v - Sᵀk) * beta;  S += k uᵀ;  out = (Sᵀq) / sqrt(n)
///
/// With `state == 0`, every batch starts from a zero state that is discarded afterwards.
///
/// # Safety
/// All offsets must be valid byte offsets from `base` for the f16 element counts implied
/// above (two bytes per element).
pub unsafe fn execute_gated_delta_net_f16(
    q: usize,
    k: usize,
    v: usize,
    g: usize,
    beta: usize,
    state: usize,
    dst: usize,
    batch: usize,
    seq: usize,
    heads: usize,
    state_size: usize,
    base: *mut u8,
) {
    use half::f16;
    unsafe {
        // Decodes `len` little-endian f16 values at byte offset `off` into f32.
        let read_f16 = |off: usize, len: usize| -> Vec<f32> {
            let raw = std::slice::from_raw_parts(base.add(off) as *const u8, len * 2);
            raw.chunks_exact(2)
                .map(|c| f16::from_le_bytes([c[0], c[1]]).to_f32())
                .collect()
        };
        // Encodes `data` as little-endian f16 starting at byte offset `off`.
        let write_f16 = |off: usize, data: &[f32]| {
            let out = std::slice::from_raw_parts_mut(base.add(off), data.len() * 2);
            for (i, &v) in data.iter().enumerate() {
                let le = f16::from_f32(v).to_le_bytes();
                out[i * 2] = le[0];
                out[i * 2 + 1] = le[1];
            }
        };

        let (b, s, h, n) = (batch, seq, heads, state_size);
        let q_f = read_f16(q, b * s * h * n);
        let k_f = read_f16(k, b * s * h * n);
        let v_f = read_f16(v, b * s * h * n);
        let g_f = read_f16(g, b * s * h);
        let b_f = read_f16(beta, b * s * h);
        // Working copy of the external state (written back at the end), or zeros.
        let mut state_f = if state != 0 {
            read_f16(state, b * h * n * n)
        } else {
            vec![0f32; b * h * n * n]
        };
        let mut out_f = vec![0f32; b * s * h * n];
        let scale = 1.0f32 / (n as f32).sqrt();
        // Holds Sᵀk and then the scaled delta u for the current head.
        let mut sk_buf = vec![0f32; n];
        // Per-batch zero state used when no external state is supplied.
        let mut owned_state = vec![0f32; h * n * n];

        for bi in 0..b {
            let state_slice: &mut [f32] = if state != 0 {
                let start = bi * h * n * n;
                &mut state_f[start..start + h * n * n]
            } else {
                owned_state.fill(0.0);
                &mut owned_state
            };

            for ti in 0..s {
                let qkv_step_base = bi * s * h * n + ti * h * n;
                let gb_step_base = bi * s * h + ti * h;

                for hi in 0..h {
                    let q_row = &q_f[qkv_step_base + hi * n..qkv_step_base + (hi + 1) * n];
                    let k_row = &k_f[qkv_step_base + hi * n..qkv_step_base + (hi + 1) * n];
                    let v_row = &v_f[qkv_step_base + hi * n..qkv_step_base + (hi + 1) * n];
                    let g_t = g_f[gb_step_base + hi];
                    let beta_t = b_f[gb_step_base + hi];

                    let s_base = hi * n * n;
                    let s_mat = &mut state_slice[s_base..s_base + n * n];

                    // Decay: S *= exp(g).
                    let g_exp = g_t.exp();
                    for st in s_mat.iter_mut() {
                        *st *= g_exp;
                    }

                    // sk = Sᵀk.
                    for j in 0..n {
                        let mut acc = 0f32;
                        for i in 0..n {
                            acc += s_mat[i * n + j] * k_row[i];
                        }
                        sk_buf[j] = acc;
                    }

                    // u = (v - sk) * beta, stored in place in sk_buf.
                    for j in 0..n {
                        sk_buf[j] = (v_row[j] - sk_buf[j]) * beta_t;
                    }

                    // Rank-1 update: S += k uᵀ.
                    for i in 0..n {
                        let ki = k_row[i];
                        for j in 0..n {
                            s_mat[i * n + j] += ki * sk_buf[j];
                        }
                    }

                    // out = (Sᵀq) * scale, using the updated state.
                    let out_row = &mut out_f[qkv_step_base + hi * n..qkv_step_base + (hi + 1) * n];
                    for j in 0..n {
                        let mut acc = 0f32;
                        for i in 0..n {
                            acc += s_mat[i * n + j] * q_row[i];
                        }
                        out_row[j] = acc * scale;
                    }
                }
            }
        }

        write_f16(dst, &out_f);
        if state != 0 {
            write_f16(state, &state_f);
        }
    }
}
14483
14484pub unsafe fn execute_group_norm_nchw_f32(
14486 src: usize,
14487 g: usize,
14488 b: usize,
14489 dst: usize,
14490 n: usize,
14491 c: usize,
14492 h: usize,
14493 w: usize,
14494 num_groups: usize,
14495 eps: f32,
14496 base: *mut u8,
14497) {
14498 let plane = c * h * w;
14499 for ni in 0..n {
14500 let input = unsafe { sl(src + ni * plane * std::mem::size_of::<f32>(), base, plane) };
14501 let gamma = unsafe { sl(g, base, c) };
14502 let beta = unsafe { sl(b, base, c) };
14503 let output = unsafe { sl_mut(dst + ni * plane * std::mem::size_of::<f32>(), base, plane) };
14504 crate::kernels::group_norm_nchw(input, gamma, beta, output, 1, c, h, w, num_groups, eps);
14505 }
14506}
14507
14508pub unsafe fn execute_layer_norm2d_nchw_f32(
14510 src: usize,
14511 g: usize,
14512 b: usize,
14513 dst: usize,
14514 n: usize,
14515 c: usize,
14516 h: usize,
14517 w: usize,
14518 eps: f32,
14519 base: *mut u8,
14520) {
14521 let plane = c * h * w;
14522 unsafe {
14523 let input = sl(src, base, n * plane);
14524 let gamma = sl(g, base, c);
14525 let beta = sl(b, base, c);
14526 let output = sl_mut(dst, base, n * plane);
14527 crate::kernels::layer_norm2d_nchw(input, gamma, beta, output, n, c, h, w, eps);
14528 }
14529}
14530
14531pub unsafe fn execute_conv_transpose2d_nchw_f32(
14533 src: usize,
14534 weight: usize,
14535 dst: usize,
14536 n: usize,
14537 c_in: usize,
14538 h: usize,
14539 w_in: usize,
14540 c_out: usize,
14541 h_out: usize,
14542 w_out: usize,
14543 kh: usize,
14544 kw: usize,
14545 sh: usize,
14546 sw: usize,
14547 ph: usize,
14548 pw: usize,
14549 dh: usize,
14550 dw: usize,
14551 groups: usize,
14552 base: *mut u8,
14553) {
14554 let in_elems = n * c_in * h * w_in;
14555 let w_elems = c_in * (c_out / groups) * kh * kw;
14556 let out_elems = n * c_out * h_out * w_out;
14557 unsafe {
14558 let input = sl(src, base, in_elems);
14559 let wt = sl(weight, base, w_elems);
14560 let output = sl_mut(dst, base, out_elems);
14561 crate::kernels::conv_transpose2d_nchw(
14562 input, wt, output, n, c_in, h, w_in, c_out, h_out, w_out, kh, kw, sh, sw, ph, pw, dh,
14563 dw, groups,
14564 );
14565 }
14566}
14567
14568pub unsafe fn execute_resize_nearest_2x_f32(
14570 src: usize,
14571 dst: usize,
14572 n: usize,
14573 c: usize,
14574 h: usize,
14575 w: usize,
14576 base: *mut u8,
14577) {
14578 let in_plane = c * h * w;
14579 let out_plane = c * h * 2 * w * 2;
14580 for ni in 0..n {
14581 let input = unsafe {
14582 sl(
14583 src + ni * in_plane * std::mem::size_of::<f32>(),
14584 base,
14585 in_plane,
14586 )
14587 };
14588 let output = unsafe {
14589 sl_mut(
14590 dst + ni * out_plane * std::mem::size_of::<f32>(),
14591 base,
14592 out_plane,
14593 )
14594 };
14595 crate::kernels::resize_nearest_2x_nchw(input, output, c, h, w);
14596 }
14597}
14598
14599pub unsafe fn execute_axial_rope2d_f32(
14601 src: usize,
14602 dst: usize,
14603 batch: usize,
14604 seq: usize,
14605 hidden: usize,
14606 end_x: usize,
14607 end_y: usize,
14608 head_dim: usize,
14609 num_heads: usize,
14610 theta: f32,
14611 repeat_factor: usize,
14612 base: *mut u8,
14613) {
14614 let plane = seq * hidden;
14615 let plane_bytes = plane * std::mem::size_of::<f32>();
14616 for bi in 0..batch {
14617 let in_off = src + bi * plane_bytes;
14618 let input = unsafe { sl(in_off, base, plane) };
14619 let rotated = rlx_ir::ops::axial_rope2d::apply_axial_rope2d(
14620 input,
14621 num_heads,
14622 seq,
14623 head_dim,
14624 end_x,
14625 end_y,
14626 theta,
14627 repeat_factor,
14628 );
14629 let out_off = dst + bi * plane_bytes;
14630 let output = unsafe { sl_mut(out_off, base, plane) };
14631 output.copy_from_slice(&rotated);
14632 }
14633}
14634
/// Applies one gated radix-2 butterfly stage to `batch` rows of interleaved complex f32.
///
/// Each row at `state_src` holds `n_fft` complex values as `[re0, im0, re1, im1, ...]`, i.e.
/// `2 * n_fft` f32. The row is first copied to `state_dst`. Then, for every butterfly
/// `bf in 0..n_fft / 2` whose `gate[bf]` is non-zero, the pair
/// `i0 = (bf / stride) * 2 * stride + bf % stride`, `i1 = i0 + stride`
/// (with `stride = 1 << stage`) becomes `(a + w·b, a - w·b)`, where
/// `w = tw_re[bf] + i·tw_im[bf]`. The two outputs are swapped when `rev[bf] >= 0.5`.
/// The gate, reverse-flag, and twiddle tables hold `n_fft / 2` f32 each and are shared by
/// all rows.
///
/// In-place execution (`state_src == state_dst`) is supported. The butterflies of one stage
/// touch disjoint index pairs, so each pair is read from the destination row before it is
/// overwritten. This yields the same values as reading from the source row.
///
/// # Safety
/// All offsets are byte offsets from `base`, which must be valid and f32-aligned for the
/// lengths above. The tables must not overlap the destination rows. Distinct source and
/// destination rows must not partially overlap each other.
pub unsafe fn execute_fft_butterfly_stage_f32(
    state_src: usize,
    state_dst: usize,
    gate_src: usize,
    rev_src: usize,
    tw_re_src: usize,
    tw_im_src: usize,
    batch: usize,
    n_fft: usize,
    stage: usize,
    base: *mut u8,
) {
    let half = n_fft / 2;
    let stride = 1usize << stage;
    // SAFETY: the caller guarantees each table holds `half` f32 values at its offset.
    let gate = unsafe { std::slice::from_raw_parts(base.add(gate_src) as *const f32, half) };
    let rev = unsafe { std::slice::from_raw_parts(base.add(rev_src) as *const f32, half) };
    let tw_re = unsafe { std::slice::from_raw_parts(base.add(tw_re_src) as *const f32, half) };
    let tw_im = unsafe { std::slice::from_raw_parts(base.add(tw_im_src) as *const f32, half) };

    let row_elems = n_fft * 2;
    let row_bytes = row_elems * std::mem::size_of::<f32>();
    for b in 0..batch {
        // SAFETY: the caller guarantees both rows lie inside the arena.
        let in_ptr = unsafe { base.add(state_src + b * row_bytes) } as *const f32;
        let out_ptr = unsafe { base.add(state_dst + b * row_bytes) } as *mut f32;
        if in_ptr != out_ptr as *const f32 {
            // Use `ptr::copy` (memmove) rather than `copy_from_slice`, which requires
            // non-overlapping ranges.
            unsafe { std::ptr::copy(in_ptr, out_ptr, row_elems) };
        }
        // SAFETY: only this one view of the destination row is live, so aliased src/dst
        // never produces a simultaneous `&` / `&mut` pair.
        let out = unsafe { std::slice::from_raw_parts_mut(out_ptr, row_elems) };

        for bf in 0..half {
            if gate[bf] == 0.0 {
                continue;
            }
            let group = bf / stride;
            let k = bf % stride;
            let i0 = group * 2 * stride + k;
            let i1 = i0 + stride;
            let w_re = tw_re[bf];
            let w_im = tw_im[bf];
            // `out` still holds the untouched inputs for this pair (pairs are disjoint).
            let in_a_re = out[i0 * 2];
            let in_a_im = out[i0 * 2 + 1];
            let in_b_re = out[i1 * 2];
            let in_b_im = out[i1 * 2 + 1];
            let (b_re, b_im) = (
                in_b_re * w_re - in_b_im * w_im,
                in_b_re * w_im + in_b_im * w_re,
            );
            let (top_re, top_im) = (in_a_re + b_re, in_a_im + b_im);
            let (bot_re, bot_im) = (in_a_re - b_re, in_a_im - b_im);
            let (oa_re, oa_im, ob_re, ob_im) = if rev[bf] >= 0.5 {
                (bot_re, bot_im, top_re, top_im)
            } else {
                (top_re, top_im, bot_re, bot_im)
            };
            out[i0 * 2] = oa_re;
            out[i0 * 2 + 1] = oa_im;
            out[i1 * 2] = ob_re;
            out[i1 * 2 + 1] = ob_im;
        }
    }
}
14693
14694pub unsafe fn execute_fft1d_f32(
14696 src: usize,
14697 dst: usize,
14698 outer: usize,
14699 n_complex: usize,
14700 inverse: bool,
14701 norm_tag: u32,
14702 base: *mut u8,
14703) {
14704 let row_elems = 2 * n_complex;
14705 let mut re = vec![0f32; n_complex];
14706 let mut im = vec![0f32; n_complex];
14707 let norm = rlx_ir::fft::FftNorm::from_tag(norm_tag);
14708 let scale = norm.output_scale(n_complex, inverse) as f32;
14709 let mut scratch = if n_complex.is_power_of_two() || n_complex <= 16 {
14710 BluesteinScratchF32::empty()
14711 } else {
14712 BluesteinScratchF32::build(n_complex, inverse)
14713 };
14714 for o in 0..outer {
14715 let row_offset = src + o * row_elems * std::mem::size_of::<f32>();
14716 let s = unsafe { sl(row_offset, base, row_elems) };
14717 re.copy_from_slice(&s[..n_complex]);
14718 im.copy_from_slice(&s[n_complex..]);
14719 if n_complex.is_power_of_two() {
14720 fft_radix2_inplace_f32(&mut re, &mut im, inverse);
14721 } else if n_complex <= 16 {
14722 fft_naive_inplace_f32(&mut re, &mut im, inverse);
14723 } else {
14724 fft_bluestein_inplace_f32(&mut re, &mut im, inverse, &mut scratch);
14725 }
14726 if scale != 1.0 {
14727 re.iter_mut().for_each(|v| *v *= scale);
14728 im.iter_mut().for_each(|v| *v *= scale);
14729 }
14730 let dst_offset = dst + o * row_elems * std::mem::size_of::<f32>();
14731 let d = unsafe { sl_mut(dst_offset, base, row_elems) };
14732 d[..n_complex].copy_from_slice(&re);
14733 d[n_complex..].copy_from_slice(&im);
14734 }
14735}
14736
14737pub unsafe fn execute_fft1d_c64(
14739 src: usize,
14740 dst: usize,
14741 outer: usize,
14742 n_complex: usize,
14743 inverse: bool,
14744 norm_tag: u32,
14745 base: *mut u8,
14746) {
14747 let row_bytes = n_complex * 8;
14748 let mut re = vec![0f32; n_complex];
14749 let mut im = vec![0f32; n_complex];
14750 let norm = rlx_ir::fft::FftNorm::from_tag(norm_tag);
14751 let scale = norm.output_scale(n_complex, inverse) as f32;
14752 let mut scratch = if n_complex.is_power_of_two() || n_complex <= 16 {
14753 BluesteinScratchF32::empty()
14754 } else {
14755 BluesteinScratchF32::build(n_complex, inverse)
14756 };
14757 for o in 0..outer {
14758 let row_offset = src + o * row_bytes;
14759 for i in 0..n_complex {
14760 let elem_off = row_offset + i * 8;
14761 re[i] = f32::from_le_bytes([
14762 *base.add(elem_off),
14763 *base.add(elem_off + 1),
14764 *base.add(elem_off + 2),
14765 *base.add(elem_off + 3),
14766 ]);
14767 im[i] = f32::from_le_bytes([
14768 *base.add(elem_off + 4),
14769 *base.add(elem_off + 5),
14770 *base.add(elem_off + 6),
14771 *base.add(elem_off + 7),
14772 ]);
14773 }
14774 if n_complex.is_power_of_two() {
14775 fft_radix2_inplace_f32(&mut re, &mut im, inverse);
14776 } else if n_complex <= 16 {
14777 fft_naive_inplace_f32(&mut re, &mut im, inverse);
14778 } else {
14779 fft_bluestein_inplace_f32(&mut re, &mut im, inverse, &mut scratch);
14780 }
14781 if scale != 1.0 {
14782 re.iter_mut().for_each(|v| *v *= scale);
14783 im.iter_mut().for_each(|v| *v *= scale);
14784 }
14785 let dst_row = dst + o * row_bytes;
14786 for i in 0..n_complex {
14787 let elem_off = dst_row + i * 8;
14788 let re_b = re[i].to_le_bytes();
14789 let im_b = im[i].to_le_bytes();
14790 for j in 0..4 {
14791 *base.add(elem_off + j) = re_b[j];
14792 *base.add(elem_off + 4 + j) = im_b[j];
14793 }
14794 }
14795 }
14796}
14797
14798pub unsafe fn execute_log_mel(
14800 spec: usize,
14801 filters: usize,
14802 dst: usize,
14803 outer: usize,
14804 n_fft: usize,
14805 n_bins: usize,
14806 n_mels: usize,
14807 base: *mut u8,
14808) {
14809 execute_log_mel_f32(spec, filters, dst, outer, n_fft, n_bins, n_mels, base);
14810}
14811
14812pub unsafe fn execute_log_mel_f32(
14813 spec: usize,
14814 filters: usize,
14815 dst: usize,
14816 outer: usize,
14817 n_fft: usize,
14818 n_bins: usize,
14819 n_mels: usize,
14820 base: *mut u8,
14821) {
14822 let spec_ptr = base.add(spec) as *const f32;
14823 let filt_ptr = base.add(filters) as *const f32;
14824 let dst_ptr = base.add(dst) as *mut f32;
14825 let spec = std::slice::from_raw_parts(spec_ptr, outer * n_fft * 2);
14826 let filters = std::slice::from_raw_parts(filt_ptr, n_mels * n_bins);
14827 let out = std::slice::from_raw_parts_mut(dst_ptr, outer * n_mels);
14828 rlx_ir::audio::log_mel_block_f32(spec, filters, outer, n_fft, n_bins, n_mels, out);
14829}
14830
14831pub unsafe fn execute_log_mel_backward_f32(
14832 spec: usize,
14833 filters: usize,
14834 dy: usize,
14835 dst: usize,
14836 outer: usize,
14837 n_fft: usize,
14838 n_bins: usize,
14839 n_mels: usize,
14840 base: *mut u8,
14841) {
14842 let spec_ptr = base.add(spec) as *const f32;
14843 let filt_ptr = base.add(filters) as *const f32;
14844 let dy_ptr = base.add(dy) as *const f32;
14845 let dst_ptr = base.add(dst) as *mut f32;
14846 let spec = std::slice::from_raw_parts(spec_ptr, outer * n_fft * 2);
14847 let filters = std::slice::from_raw_parts(filt_ptr, n_mels * n_bins);
14848 let dy = std::slice::from_raw_parts(dy_ptr, outer * n_mels);
14849 let d_spec = std::slice::from_raw_parts_mut(dst_ptr, outer * n_fft * 2);
14850 d_spec.fill(0.0);
14851 rlx_ir::audio::log_mel_block_vjp(spec, filters, dy, outer, n_fft, n_bins, n_mels, d_spec);
14852}
14853
14854pub unsafe fn execute_fft1d(
14856 src: usize,
14857 dst: usize,
14858 outer: usize,
14859 n_complex: usize,
14860 inverse: bool,
14861 norm_tag: u32,
14862 dtype: rlx_ir::DType,
14863 base: *mut u8,
14864) {
14865 match dtype {
14866 rlx_ir::DType::F32 => {
14867 execute_fft1d_f32(src, dst, outer, n_complex, inverse, norm_tag, base)
14868 }
14869 rlx_ir::DType::F64 => {
14870 execute_fft1d_f64(src, dst, outer, n_complex, inverse, norm_tag, base)
14871 }
14872 rlx_ir::DType::C64 => {
14873 execute_fft1d_c64(src, dst, outer, n_complex, inverse, norm_tag, base)
14874 }
14875 other => panic!("execute_fft1d: unsupported dtype {other:?}"),
14876 }
14877}
14878
/// In-place iterative radix-2 Cooley–Tukey FFT on planar f32 data.
///
/// `re` and `im` hold the real and imaginary parts of `n = re.len()` samples. `n` must be a
/// power of two, which is checked only in debug builds.
///
/// The kernel is `exp(-2πi·jk/n)` for the forward transform and `exp(+2πi·jk/n)` when
/// `inverse` is set. No 1/n normalisation is applied. Twiddle factors are advanced in f64
/// and rounded to f32 for each butterfly.
fn fft_radix2_inplace_f32(re: &mut [f32], im: &mut [f32], inverse: bool) {
    let n = re.len();
    debug_assert_eq!(im.len(), n);
    debug_assert!(
        n.is_power_of_two(),
        "fft_radix2_f32: n={n} must be a power of two"
    );
    if n <= 1 {
        return;
    }

    // Permute into bit-reversed order so the butterflies can run in place.
    let mut rev = 0usize;
    for idx in 1..n {
        let mut mask = n >> 1;
        while rev & mask != 0 {
            rev ^= mask;
            mask >>= 1;
        }
        rev ^= mask;
        if idx < rev {
            re.swap(idx, rev);
            im.swap(idx, rev);
        }
    }

    let direction = if inverse { 1.0_f64 } else { -1.0_f64 };
    let stage_lengths = std::iter::successors(Some(2usize), |&len| len.checked_mul(2))
        .take_while(|&len| len <= n);
    for len in stage_lengths {
        let half = len / 2;
        let angle = direction * 2.0 * std::f64::consts::PI / (len as f64);
        let (step_re, step_im) = (angle.cos(), angle.sin());
        for start in (0..n).step_by(len) {
            let (mut w_re, mut w_im) = (1.0_f64, 0.0_f64);
            for offset in 0..half {
                let top = start + offset;
                let bottom = top + half;
                let (wr, wi) = (w_re as f32, w_im as f32);
                let t_re = wr * re[bottom] - wi * im[bottom];
                let t_im = wr * im[bottom] + wi * re[bottom];
                let (u_re, u_im) = (re[top], im[top]);
                re[top] = u_re + t_re;
                im[top] = u_im + t_im;
                re[bottom] = u_re - t_re;
                im[bottom] = u_im - t_im;
                (w_re, w_im) = (
                    w_re * step_re - w_im * step_im,
                    w_re * step_im + w_im * step_re,
                );
            }
        }
    }
}
14940
/// In-place iterative radix-2 Cooley–Tukey FFT on planar f64 data.
///
/// `re` and `im` hold the real and imaginary parts of `n = re.len()` samples. `n` must be a
/// power of two, which is checked only in debug builds.
///
/// The kernel is `exp(-2πi·jk/n)` for the forward transform and `exp(+2πi·jk/n)` when
/// `inverse` is set. No 1/n normalisation is applied. Twiddles are generated by repeated
/// multiplication with the stage's unit step.
fn fft_radix2_inplace_f64(re: &mut [f64], im: &mut [f64], inverse: bool) {
    let n = re.len();
    debug_assert_eq!(im.len(), n);
    debug_assert!(
        n.is_power_of_two(),
        "fft_radix2: n={n} must be a power of two"
    );
    if n <= 1 {
        return;
    }

    // Permute into bit-reversed order so the butterflies can run in place.
    let mut rev = 0usize;
    for idx in 1..n {
        let mut mask = n >> 1;
        while rev & mask != 0 {
            rev ^= mask;
            mask >>= 1;
        }
        rev ^= mask;
        if idx < rev {
            re.swap(idx, rev);
            im.swap(idx, rev);
        }
    }

    let direction: f64 = if inverse { 1.0 } else { -1.0 };
    let stage_lengths = std::iter::successors(Some(2usize), |&len| len.checked_mul(2))
        .take_while(|&len| len <= n);
    for len in stage_lengths {
        let half = len / 2;
        let angle = direction * 2.0 * std::f64::consts::PI / (len as f64);
        let (step_re, step_im) = (angle.cos(), angle.sin());
        for start in (0..n).step_by(len) {
            let (mut w_re, mut w_im) = (1.0_f64, 0.0_f64);
            for offset in 0..half {
                let top = start + offset;
                let bottom = top + half;
                let t_re = w_re * re[bottom] - w_im * im[bottom];
                let t_im = w_re * im[bottom] + w_im * re[bottom];
                let (u_re, u_im) = (re[top], im[top]);
                re[top] = u_re + t_re;
                im[top] = u_im + t_im;
                re[bottom] = u_re - t_re;
                im[bottom] = u_im - t_im;
                (w_re, w_im) = (
                    w_re * step_re - w_im * step_im,
                    w_re * step_im + w_im * step_re,
                );
            }
        }
    }
}
15002
15003struct BluesteinScratchF64 {
15007 m: usize,
15009 w_re: Vec<f64>,
15013 w_im: Vec<f64>,
15014 bf_re: Vec<f64>,
15017 bf_im: Vec<f64>,
15018 ar: Vec<f64>,
15020 ai: Vec<f64>,
15021}
15022
15023impl BluesteinScratchF64 {
15024 fn empty() -> Self {
15025 Self {
15026 m: 0,
15027 w_re: Vec::new(),
15028 w_im: Vec::new(),
15029 bf_re: Vec::new(),
15030 bf_im: Vec::new(),
15031 ar: Vec::new(),
15032 ai: Vec::new(),
15033 }
15034 }
15035
15036 fn build(n: usize, inverse: bool) -> Self {
15037 let m = if n <= 1 {
15040 1
15041 } else {
15042 (2 * n - 1).next_power_of_two()
15043 };
15044
15045 let mod_2n = (2 * n) as u64;
15048 let sign = if inverse { 1.0_f64 } else { -1.0_f64 };
15049 let mut w_re = vec![0.0_f64; n];
15050 let mut w_im = vec![0.0_f64; n];
15051 for k in 0..n {
15052 let k2 = (k as u64).wrapping_mul(k as u64) % mod_2n;
15053 let theta = sign * std::f64::consts::PI * (k2 as f64) / (n as f64);
15054 w_re[k] = theta.cos();
15055 w_im[k] = theta.sin();
15056 }
15057
15058 let mut bf_re = vec![0.0_f64; m];
15061 let mut bf_im = vec![0.0_f64; m];
15062 if n > 0 {
15063 bf_re[0] = w_re[0];
15064 bf_im[0] = -w_im[0];
15065 for k in 1..n {
15066 bf_re[k] = w_re[k];
15067 bf_im[k] = -w_im[k];
15068 bf_re[m - k] = w_re[k];
15069 bf_im[m - k] = -w_im[k];
15070 }
15071 }
15072 if m > 1 {
15073 fft_radix2_inplace_f64(&mut bf_re, &mut bf_im, false);
15074 }
15075
15076 Self {
15077 m,
15078 w_re,
15079 w_im,
15080 bf_re,
15081 bf_im,
15082 ar: vec![0.0_f64; m],
15083 ai: vec![0.0_f64; m],
15084 }
15085 }
15086}
15087
15088fn fft_naive_inplace_f64(re: &mut [f64], im: &mut [f64], inverse: bool) {
15090 let n = re.len();
15091 if n <= 1 {
15092 return;
15093 }
15094 let sign = if inverse { 1.0 } else { -1.0 };
15095 let mut out_re = vec![0.0_f64; n];
15096 let mut out_im = vec![0.0_f64; n];
15097 for k in 0..n {
15098 for nn in 0..n {
15099 let theta = sign * 2.0 * std::f64::consts::PI * (nn as f64) * (k as f64) / (n as f64);
15100 let c = theta.cos();
15101 let s = theta.sin();
15102 out_re[k] += re[nn] * c - im[nn] * s;
15103 out_im[k] += re[nn] * s + im[nn] * c;
15104 }
15105 }
15106 re.copy_from_slice(&out_re);
15107 im.copy_from_slice(&out_im);
15108}
15109
15110fn fft_naive_inplace_f32(re: &mut [f32], im: &mut [f32], inverse: bool) {
15111 let n = re.len();
15112 if n <= 1 {
15113 return;
15114 }
15115 let sign = if inverse { 1.0f32 } else { -1.0f32 };
15116 let mut out_re = vec![0.0_f32; n];
15117 let mut out_im = vec![0.0_f32; n];
15118 for k in 0..n {
15119 for nn in 0..n {
15120 let theta = sign * 2.0 * std::f32::consts::PI * (nn as f32) * (k as f32) / (n as f32);
15121 let c = theta.cos();
15122 let s = theta.sin();
15123 out_re[k] += re[nn] * c - im[nn] * s;
15124 out_im[k] += re[nn] * s + im[nn] * c;
15125 }
15126 }
15127 re.copy_from_slice(&out_re);
15128 im.copy_from_slice(&out_im);
15129}
15130
15131fn fft_bluestein_inplace_f64(
15140 re: &mut [f64],
15141 im: &mut [f64],
15142 _inverse: bool,
15143 s: &mut BluesteinScratchF64,
15144) {
15145 let n = re.len();
15146 debug_assert_eq!(im.len(), n);
15147 debug_assert_eq!(s.w_re.len(), n);
15148 if n <= 1 {
15149 return;
15150 }
15151 let m = s.m;
15152
15153 for k in 0..m {
15155 s.ar[k] = 0.0;
15156 s.ai[k] = 0.0;
15157 }
15158 for k in 0..n {
15159 s.ar[k] = re[k] * s.w_re[k] - im[k] * s.w_im[k];
15160 s.ai[k] = re[k] * s.w_im[k] + im[k] * s.w_re[k];
15161 }
15162
15163 fft_radix2_inplace_f64(&mut s.ar, &mut s.ai, false);
15165
15166 for k in 0..m {
15168 let ar = s.ar[k];
15169 let ai = s.ai[k];
15170 let br = s.bf_re[k];
15171 let bi = s.bf_im[k];
15172 s.ar[k] = ar * br - ai * bi;
15173 s.ai[k] = ar * bi + ai * br;
15174 }
15175
15176 fft_radix2_inplace_f64(&mut s.ar, &mut s.ai, true);
15179 let inv_m = 1.0 / (m as f64);
15180
15181 for k in 0..n {
15183 let yr = s.ar[k] * inv_m;
15184 let yi = s.ai[k] * inv_m;
15185 re[k] = yr * s.w_re[k] - yi * s.w_im[k];
15186 im[k] = yr * s.w_im[k] + yi * s.w_re[k];
15187 }
15188}
15189
15190struct BluesteinScratchF32 {
15194 m: usize,
15195 w_re: Vec<f32>,
15196 w_im: Vec<f32>,
15197 bf_re: Vec<f32>,
15198 bf_im: Vec<f32>,
15199 ar: Vec<f32>,
15200 ai: Vec<f32>,
15201}
15202
15203impl BluesteinScratchF32 {
15204 fn empty() -> Self {
15205 Self {
15206 m: 0,
15207 w_re: Vec::new(),
15208 w_im: Vec::new(),
15209 bf_re: Vec::new(),
15210 bf_im: Vec::new(),
15211 ar: Vec::new(),
15212 ai: Vec::new(),
15213 }
15214 }
15215
15216 fn build(n: usize, inverse: bool) -> Self {
15217 let m = if n <= 1 {
15218 1
15219 } else {
15220 (2 * n - 1).next_power_of_two()
15221 };
15222
15223 let mod_2n = (2 * n) as u64;
15224 let sign = if inverse { 1.0_f64 } else { -1.0_f64 };
15225 let mut w_re = vec![0.0_f32; n];
15226 let mut w_im = vec![0.0_f32; n];
15227 for k in 0..n {
15228 let k2 = (k as u64).wrapping_mul(k as u64) % mod_2n;
15229 let theta = sign * std::f64::consts::PI * (k2 as f64) / (n as f64);
15230 w_re[k] = theta.cos() as f32;
15231 w_im[k] = theta.sin() as f32;
15232 }
15233
15234 let mut bf_re = vec![0.0_f32; m];
15235 let mut bf_im = vec![0.0_f32; m];
15236 if n > 0 {
15237 bf_re[0] = w_re[0];
15238 bf_im[0] = -w_im[0];
15239 for k in 1..n {
15240 bf_re[k] = w_re[k];
15241 bf_im[k] = -w_im[k];
15242 bf_re[m - k] = w_re[k];
15243 bf_im[m - k] = -w_im[k];
15244 }
15245 }
15246 if m > 1 {
15247 fft_radix2_inplace_f32(&mut bf_re, &mut bf_im, false);
15248 }
15249
15250 Self {
15251 m,
15252 w_re,
15253 w_im,
15254 bf_re,
15255 bf_im,
15256 ar: vec![0.0_f32; m],
15257 ai: vec![0.0_f32; m],
15258 }
15259 }
15260}
15261
15262fn fft_bluestein_inplace_f32(
15263 re: &mut [f32],
15264 im: &mut [f32],
15265 _inverse: bool,
15266 s: &mut BluesteinScratchF32,
15267) {
15268 let n = re.len();
15269 debug_assert_eq!(im.len(), n);
15270 debug_assert_eq!(s.w_re.len(), n);
15271 if n <= 1 {
15272 return;
15273 }
15274 let m = s.m;
15275
15276 for k in 0..m {
15277 s.ar[k] = 0.0;
15278 s.ai[k] = 0.0;
15279 }
15280 for k in 0..n {
15281 s.ar[k] = re[k] * s.w_re[k] - im[k] * s.w_im[k];
15282 s.ai[k] = re[k] * s.w_im[k] + im[k] * s.w_re[k];
15283 }
15284
15285 fft_radix2_inplace_f32(&mut s.ar, &mut s.ai, false);
15286
15287 for k in 0..m {
15288 let ar = s.ar[k];
15289 let ai = s.ai[k];
15290 let br = s.bf_re[k];
15291 let bi = s.bf_im[k];
15292 s.ar[k] = ar * br - ai * bi;
15293 s.ai[k] = ar * bi + ai * br;
15294 }
15295
15296 fft_radix2_inplace_f32(&mut s.ar, &mut s.ai, true);
15297 let inv_m = 1.0_f32 / (m as f32);
15298
15299 for k in 0..n {
15300 let yr = s.ar[k] * inv_m;
15301 let yi = s.ai[k] * inv_m;
15302 re[k] = yr * s.w_re[k] - yi * s.w_im[k];
15303 im[k] = yr * s.w_im[k] + yi * s.w_re[k];
15304 }
15305}
15306
15307unsafe fn dispatch_custom_op(
15313 kernel: &dyn crate::op_registry::CpuKernel,
15314 inputs: &[(usize, u32, Shape)],
15315 out_off: usize,
15316 out_len: u32,
15317 out_shape: &Shape,
15318 attrs: &[u8],
15319 base: *mut u8,
15320) {
15321 use crate::op_registry::{CpuTensorMut, CpuTensorRef};
15322 use rlx_ir::DType;
15323
15324 macro_rules! build_in_view {
15329 ($shape:expr, $off:expr, $n:expr, $variant:ident, $rust_ty:ty) => {
15330 CpuTensorRef::$variant {
15331 data: unsafe { sl_typed::<$rust_ty>($off, base, $n) },
15332 shape: $shape,
15333 }
15334 };
15335 }
15336 macro_rules! build_out_view {
15337 ($variant:ident, $rust_ty:ty) => {
15338 CpuTensorMut::$variant {
15339 data: unsafe { sl_mut_typed::<$rust_ty>(out_off, base, out_len as usize) },
15340 shape: out_shape,
15341 }
15342 };
15343 }
15344
15345 let in_views: Vec<CpuTensorRef<'_>> = inputs
15346 .iter()
15347 .map(|(off, len, shape)| {
15348 let n = *len as usize;
15349 let off = *off;
15350 match shape.dtype() {
15351 DType::F32 => build_in_view!(shape, off, n, F32, f32),
15352 DType::F64 => build_in_view!(shape, off, n, F64, f64),
15353 DType::F16 => build_in_view!(shape, off, n, F16, half::f16),
15354 DType::BF16 => build_in_view!(shape, off, n, BF16, half::bf16),
15355 DType::I8 => build_in_view!(shape, off, n, I8, i8),
15356 DType::I16 => build_in_view!(shape, off, n, I16, i16),
15357 DType::I32 => build_in_view!(shape, off, n, I32, i32),
15358 DType::I64 => build_in_view!(shape, off, n, I64, i64),
15359 DType::U8 => build_in_view!(shape, off, n, U8, u8),
15360 DType::U32 => build_in_view!(shape, off, n, U32, u32),
15361 DType::Bool => build_in_view!(shape, off, n, Bool, u8),
15362 DType::C64 => panic!(
15366 "Op::Custom kernel input has DType::C64 — built-in \
15367 complex ops handle their own kernels; user-registered \
15368 ops don't yet see complex tensors"
15369 ),
15370 }
15371 })
15372 .collect();
15373
15374 let result = match out_shape.dtype() {
15375 DType::F32 => kernel.execute(&in_views, build_out_view!(F32, f32), attrs),
15376 DType::F64 => kernel.execute(&in_views, build_out_view!(F64, f64), attrs),
15377 DType::F16 => kernel.execute(&in_views, build_out_view!(F16, half::f16), attrs),
15378 DType::BF16 => kernel.execute(&in_views, build_out_view!(BF16, half::bf16), attrs),
15379 DType::I8 => kernel.execute(&in_views, build_out_view!(I8, i8), attrs),
15380 DType::I16 => kernel.execute(&in_views, build_out_view!(I16, i16), attrs),
15381 DType::I32 => kernel.execute(&in_views, build_out_view!(I32, i32), attrs),
15382 DType::I64 => kernel.execute(&in_views, build_out_view!(I64, i64), attrs),
15383 DType::U8 => kernel.execute(&in_views, build_out_view!(U8, u8), attrs),
15384 DType::U32 => kernel.execute(&in_views, build_out_view!(U32, u32), attrs),
15385 DType::Bool => kernel.execute(&in_views, build_out_view!(Bool, u8), attrs),
15386 DType::C64 => panic!("Op::Custom output DType::C64 not supported"),
15387 };
15388 if let Err(e) = result {
15389 panic!("Op::Custom('{}') CPU kernel failed: {e}", kernel.name());
15390 }
15391}
15392
15393#[inline(always)]
15399unsafe fn sl_typed<T>(offset: usize, base: *mut u8, len: usize) -> &'static [T] {
15400 if offset == usize::MAX {
15401 return &[];
15402 }
15403 unsafe { std::slice::from_raw_parts(base.add(offset) as *const T, len) }
15404}
15405
15406#[inline(always)]
15407unsafe fn sl_mut_typed<T>(offset: usize, base: *mut u8, len: usize) -> &'static mut [T] {
15408 unsafe { std::slice::from_raw_parts_mut(base.add(offset) as *mut T, len) }
15409}
15410
15411#[inline(always)]
15413fn apply_activation_inplace(d: &mut [f32], act: rlx_ir::op::Activation) {
15417 use rlx_ir::op::Activation;
15418 match act {
15419 Activation::Gelu => crate::kernels::par_gelu_inplace(d),
15420 Activation::GeluApprox => crate::kernels::par_gelu_approx_inplace(d),
15421 Activation::Silu => crate::kernels::par_silu_inplace(d),
15422 Activation::Relu => {
15423 for v in d.iter_mut() {
15424 *v = v.max(0.0);
15425 }
15426 }
15427 Activation::Sigmoid => {
15428 for v in d.iter_mut() {
15429 *v = 1.0 / (1.0 + (-*v).exp());
15430 }
15431 }
15432 Activation::Tanh => {
15433 for v in d.iter_mut() {
15434 *v = v.tanh();
15435 }
15436 }
15437 Activation::Exp => {
15438 for v in d.iter_mut() {
15439 *v = v.exp();
15440 }
15441 }
15442 Activation::Log => {
15443 for v in d.iter_mut() {
15444 *v = v.ln();
15445 }
15446 }
15447 Activation::Sqrt => {
15448 for v in d.iter_mut() {
15449 *v = v.sqrt();
15450 }
15451 }
15452 Activation::Rsqrt => {
15453 for v in d.iter_mut() {
15454 *v = 1.0 / v.sqrt();
15455 }
15456 }
15457 Activation::Neg => {
15458 for v in d.iter_mut() {
15459 *v = -*v;
15460 }
15461 }
15462 Activation::Abs => {
15463 for v in d.iter_mut() {
15464 *v = v.abs();
15465 }
15466 }
15467 Activation::Round => {
15468 for v in d.iter_mut() {
15469 *v = v.round();
15470 }
15471 }
15472 Activation::Sin => {
15473 for v in d.iter_mut() {
15474 *v = v.sin();
15475 }
15476 }
15477 Activation::Cos => {
15478 for v in d.iter_mut() {
15479 *v = v.cos();
15480 }
15481 }
15482 Activation::Tan => {
15483 for v in d.iter_mut() {
15484 *v = v.tan();
15485 }
15486 }
15487 Activation::Atan => {
15488 for v in d.iter_mut() {
15489 *v = v.atan();
15490 }
15491 }
15492 }
15493}
15494
15495#[allow(clippy::too_many_arguments)]
15504fn im2col(
15505 x: &[f32],
15506 col: &mut [f32],
15507 c_in: usize,
15508 h: usize,
15509 w: usize,
15510 h_out: usize,
15511 w_out: usize,
15512 kh: usize,
15513 kw: usize,
15514 sh: usize,
15515 sw: usize,
15516 ph: usize,
15517 pw: usize,
15518 dh: usize,
15519 dw_dil: usize,
15520) {
15521 let n_dim = h_out * w_out;
15522 debug_assert_eq!(col.len(), c_in * kh * kw * n_dim);
15523 debug_assert_eq!(x.len(), c_in * h * w);
15524 let h_isz = h as isize;
15525 let w_isz = w as isize;
15526 let ph_isz = ph as isize;
15527 let pw_isz = pw as isize;
15528 for ci in 0..c_in {
15529 for ki in 0..kh {
15530 for kj in 0..kw {
15531 let row = ((ci * kh) + ki) * kw + kj;
15532 let row_off = row * n_dim;
15533 for ho in 0..h_out {
15534 let hi = (ho * sh + ki * dh) as isize - ph_isz;
15535 if hi < 0 || hi >= h_isz {
15536 for wo in 0..w_out {
15537 col[row_off + ho * w_out + wo] = 0.0;
15538 }
15539 continue;
15540 }
15541 let hi = hi as usize;
15542 let in_row_off = (ci * h + hi) * w;
15543 for wo in 0..w_out {
15544 let wi = (wo * sw + kj * dw_dil) as isize - pw_isz;
15545 col[row_off + ho * w_out + wo] = if wi < 0 || wi >= w_isz {
15546 0.0
15547 } else {
15548 x[in_row_off + wi as usize]
15549 };
15550 }
15551 }
15552 }
15553 }
15554 }
15555}
15556
15557#[allow(clippy::too_many_arguments)]
15564fn col2im(
15565 col: &[f32],
15566 x: &mut [f32],
15567 c_in: usize,
15568 h: usize,
15569 w: usize,
15570 h_out: usize,
15571 w_out: usize,
15572 kh: usize,
15573 kw: usize,
15574 sh: usize,
15575 sw: usize,
15576 ph: usize,
15577 pw: usize,
15578 dh: usize,
15579 dw_dil: usize,
15580) {
15581 let n_dim = h_out * w_out;
15582 debug_assert_eq!(col.len(), c_in * kh * kw * n_dim);
15583 debug_assert_eq!(x.len(), c_in * h * w);
15584 let h_isz = h as isize;
15585 let w_isz = w as isize;
15586 let ph_isz = ph as isize;
15587 let pw_isz = pw as isize;
15588 for ci in 0..c_in {
15589 for ki in 0..kh {
15590 for kj in 0..kw {
15591 let row = ((ci * kh) + ki) * kw + kj;
15592 let row_off = row * n_dim;
15593 for ho in 0..h_out {
15594 let hi = (ho * sh + ki * dh) as isize - ph_isz;
15595 if hi < 0 || hi >= h_isz {
15596 continue;
15597 }
15598 let hi = hi as usize;
15599 let in_row_off = (ci * h + hi) * w;
15600 for wo in 0..w_out {
15601 let wi = (wo * sw + kj * dw_dil) as isize - pw_isz;
15602 if wi < 0 || wi >= w_isz {
15603 continue;
15604 }
15605 x[in_row_off + wi as usize] += col[row_off + ho * w_out + wo];
15606 }
15607 }
15608 }
15609 }
15610 }
15611}
15612
15613fn quant_layout(shape: &rlx_ir::Shape, axis: Option<usize>) -> (usize, usize, usize) {
15623 match axis {
15624 None => (0, 1, shape.num_elements().unwrap_or(0).max(1)),
15625 Some(d) => {
15626 let chan_dim = shape.dim(d).unwrap_static();
15627 let inner: usize = (d + 1..shape.rank())
15628 .map(|i| shape.dim(i).unwrap_static())
15629 .product::<usize>()
15630 .max(1);
15631 (d, chan_dim, inner)
15632 }
15633 }
15634}
15635
15636fn activation_backward_kernel(
15637 act: rlx_ir::op::Activation,
15638 xs: &[f32],
15639 dys: &[f32],
15640 out: &mut [f32],
15641) {
15642 use rlx_ir::op::Activation;
15643 let n = xs.len();
15644 debug_assert_eq!(dys.len(), n);
15645 debug_assert_eq!(out.len(), n);
15646 match act {
15647 Activation::Relu => {
15648 for i in 0..n {
15649 out[i] = if xs[i] > 0.0 { dys[i] } else { 0.0 };
15650 }
15651 }
15652 Activation::Sigmoid => {
15653 for i in 0..n {
15654 let s = 1.0 / (1.0 + (-xs[i]).exp());
15655 out[i] = s * (1.0 - s) * dys[i];
15656 }
15657 }
15658 Activation::Tanh => {
15659 for i in 0..n {
15660 let t = xs[i].tanh();
15661 out[i] = (1.0 - t * t) * dys[i];
15662 }
15663 }
15664 Activation::Silu => {
15665 for i in 0..n {
15667 let s = 1.0 / (1.0 + (-xs[i]).exp());
15668 out[i] = s * (1.0 + xs[i] * (1.0 - s)) * dys[i];
15669 }
15670 }
15671 Activation::Gelu => {
15672 const INV_SQRT2: f32 = 0.707_106_77;
15675 const INV_SQRT_2PI: f32 = 0.398_942_3;
15676 for i in 0..n {
15677 let x = xs[i];
15678 let phi = 0.5 * (1.0 + erf_f32(x * INV_SQRT2));
15679 let pdf = INV_SQRT_2PI * (-(x * x) * 0.5).exp();
15680 out[i] = (phi + x * pdf) * dys[i];
15681 }
15682 }
15683 Activation::GeluApprox => {
15684 const C: f32 = 0.797_884_6; const A: f32 = 0.044_715;
15688 for i in 0..n {
15689 let x = xs[i];
15690 let inner = C * (x + A * x * x * x);
15691 let t = inner.tanh();
15692 let dinner = C * (1.0 + 3.0 * A * x * x);
15693 let d = 0.5 * (1.0 + t) + 0.5 * x * (1.0 - t * t) * dinner;
15694 out[i] = d * dys[i];
15695 }
15696 }
15697 Activation::Exp => {
15698 for i in 0..n {
15699 out[i] = xs[i].exp() * dys[i];
15700 }
15701 }
15702 Activation::Log => {
15703 for i in 0..n {
15704 out[i] = dys[i] / xs[i];
15705 }
15706 }
15707 Activation::Sqrt => {
15708 for i in 0..n {
15710 let s = xs[i].sqrt();
15711 out[i] = if s > 0.0 { 0.5 * dys[i] / s } else { 0.0 };
15712 }
15713 }
15714 Activation::Rsqrt => {
15715 for i in 0..n {
15717 let s = xs[i].sqrt();
15718 out[i] = if s > 0.0 {
15719 -0.5 * dys[i] / (xs[i] * s)
15720 } else {
15721 0.0
15722 };
15723 }
15724 }
15725 Activation::Neg => {
15726 for i in 0..n {
15727 out[i] = -dys[i];
15728 }
15729 }
15730 Activation::Abs => {
15731 for i in 0..n {
15733 let x = xs[i];
15734 let s = if x > 0.0 {
15735 1.0
15736 } else if x < 0.0 {
15737 -1.0
15738 } else {
15739 0.0
15740 };
15741 out[i] = s * dys[i];
15742 }
15743 }
15744 Activation::Round => {
15745 out.copy_from_slice(dys);
15750 }
15751 Activation::Sin => {
15752 for i in 0..n {
15754 out[i] = xs[i].cos() * dys[i];
15755 }
15756 }
15757 Activation::Cos => {
15758 for i in 0..n {
15759 out[i] = -xs[i].sin() * dys[i];
15760 }
15761 }
15762 Activation::Tan => {
15763 for i in 0..n {
15765 let t = xs[i].tan();
15766 out[i] = (1.0 + t * t) * dys[i];
15767 }
15768 }
15769 Activation::Atan => {
15770 for i in 0..n {
15772 let x = xs[i];
15773 out[i] = dys[i] / (1.0 + x * x);
15774 }
15775 }
15776 }
15777}
15778
15779fn activation_backward_kernel_f64(
15783 act: rlx_ir::op::Activation,
15784 xs: &[f64],
15785 dys: &[f64],
15786 out: &mut [f64],
15787) {
15788 use rlx_ir::op::Activation;
15789 let n = xs.len();
15790 debug_assert_eq!(dys.len(), n);
15791 debug_assert_eq!(out.len(), n);
15792 match act {
15793 Activation::Relu => {
15794 for i in 0..n {
15795 out[i] = if xs[i] > 0.0 { dys[i] } else { 0.0 };
15796 }
15797 }
15798 Activation::Sigmoid => {
15799 for i in 0..n {
15800 let s = 1.0 / (1.0 + (-xs[i]).exp());
15801 out[i] = s * (1.0 - s) * dys[i];
15802 }
15803 }
15804 Activation::Tanh => {
15805 for i in 0..n {
15806 let t = xs[i].tanh();
15807 out[i] = (1.0 - t * t) * dys[i];
15808 }
15809 }
15810 Activation::Silu => {
15811 for i in 0..n {
15812 let s = 1.0 / (1.0 + (-xs[i]).exp());
15813 out[i] = s * (1.0 + xs[i] * (1.0 - s)) * dys[i];
15814 }
15815 }
15816 Activation::Gelu | Activation::GeluApprox => {
15817 const INV_SQRT2: f64 = std::f64::consts::FRAC_1_SQRT_2;
15819 const INV_SQRT_2PI: f64 = 0.398_942_280_401_432_7;
15820 for i in 0..n {
15821 let x = xs[i];
15822 let phi = 0.5 * (1.0 + erf_f64(x * INV_SQRT2));
15823 let pdf = INV_SQRT_2PI * (-(x * x) * 0.5).exp();
15824 out[i] = (phi + x * pdf) * dys[i];
15825 }
15826 }
15827 Activation::Exp => {
15828 for i in 0..n {
15829 out[i] = xs[i].exp() * dys[i];
15830 }
15831 }
15832 Activation::Log => {
15833 for i in 0..n {
15834 out[i] = dys[i] / xs[i];
15835 }
15836 }
15837 Activation::Sqrt => {
15838 for i in 0..n {
15839 let s = xs[i].sqrt();
15840 out[i] = if s > 0.0 { 0.5 * dys[i] / s } else { 0.0 };
15841 }
15842 }
15843 Activation::Rsqrt => {
15844 for i in 0..n {
15845 let s = xs[i].sqrt();
15846 out[i] = if s > 0.0 {
15847 -0.5 * dys[i] / (xs[i] * s)
15848 } else {
15849 0.0
15850 };
15851 }
15852 }
15853 Activation::Neg => {
15854 for i in 0..n {
15855 out[i] = -dys[i];
15856 }
15857 }
15858 Activation::Abs => {
15859 for i in 0..n {
15860 let x = xs[i];
15861 let s = if x > 0.0 {
15862 1.0
15863 } else if x < 0.0 {
15864 -1.0
15865 } else {
15866 0.0
15867 };
15868 out[i] = s * dys[i];
15869 }
15870 }
15871 Activation::Round => {
15872 out.copy_from_slice(dys);
15873 }
15874 Activation::Sin => {
15875 for i in 0..n {
15876 out[i] = xs[i].cos() * dys[i];
15877 }
15878 }
15879 Activation::Cos => {
15880 for i in 0..n {
15881 out[i] = -xs[i].sin() * dys[i];
15882 }
15883 }
15884 Activation::Tan => {
15885 for i in 0..n {
15886 let t = xs[i].tan();
15887 out[i] = (1.0 + t * t) * dys[i];
15888 }
15889 }
15890 Activation::Atan => {
15891 for i in 0..n {
15892 let x = xs[i];
15893 out[i] = dys[i] / (1.0 + x * x);
15894 }
15895 }
15896 }
15897}
15898
15899#[inline(always)]
15904fn erf_f64(x: f64) -> f64 {
15905 let s = x.signum();
15906 let x = x.abs();
15907 let t = 1.0 / (1.0 + 0.327_591_1 * x);
15908 let y = 1.0
15909 - (((((1.061_405_43 * t - 1.453_152_03) * t) + 1.421_413_75) * t - 0.284_496_74) * t
15910 + 0.254_829_59)
15911 * t
15912 * (-x * x).exp();
15913 s * y
15914}
15915
15916#[inline(always)]
15919fn erf_f32(x: f32) -> f32 {
15920 let s = x.signum();
15921 let x = x.abs();
15922 let t = 1.0 / (1.0 + 0.327_591_1 * x);
15923 let y = 1.0
15924 - (((((1.061_405_4 * t - 1.453_152_1) * t) + 1.421_413_8) * t - 0.284_496_74) * t
15925 + 0.254_829_6)
15926 * t
15927 * (-x * x).exp();
15928 s * y
15929}
15930
15931fn narrow_thunk_closure(
15932 src: usize,
15933 dst: usize,
15934 outer: u32,
15935 src_stride: u32,
15936 dst_stride: u32,
15937 inner: u32,
15938 elem_bytes: u8,
15939) -> Arc<dyn Fn(*mut u8) + Send + Sync> {
15940 let (outer, ss, ds, inner, eb) = (
15941 outer as usize,
15942 src_stride as usize,
15943 dst_stride as usize,
15944 inner as usize,
15945 elem_bytes as usize,
15946 );
15947 let row_bytes = inner.saturating_mul(eb);
15948 let src_row_stride = ss.saturating_mul(eb);
15949 let dst_row_stride = ds.saturating_mul(eb);
15950 Arc::new(move |base: *mut u8| unsafe {
15951 if row_bytes == 0 || src == dst {
15952 return;
15953 }
15954 let arena_len = usize::MAX;
15956 for o in 0..outer {
15957 let s_off = src + o * src_row_stride;
15958 let d_off = dst + o * dst_row_stride;
15959 if s_off == d_off {
15960 continue;
15961 }
15962 if s_off.saturating_add(row_bytes) > arena_len
15963 || d_off.saturating_add(row_bytes) > arena_len
15964 {
15965 break;
15966 }
15967 std::ptr::copy_nonoverlapping(base.add(s_off), base.add(d_off), row_bytes);
15968 }
15969 })
15970}
15971
15972unsafe fn sl(offset: usize, base: *mut u8, len: usize) -> &'static [f32] {
15973 if offset == usize::MAX {
15974 return &[];
15975 }
15976 unsafe { std::slice::from_raw_parts(base.add(offset) as *const f32, len) }
15977}
15978
15979#[inline(always)]
15980unsafe fn sl_mut(offset: usize, base: *mut u8, len: usize) -> &'static mut [f32] {
15981 unsafe { std::slice::from_raw_parts_mut(base.add(offset) as *mut f32, len) }
15982}
15983
15984#[inline(always)]
15985unsafe fn sl_f64(offset: usize, base: *mut u8, len: usize) -> &'static [f64] {
15986 if offset == usize::MAX {
15987 return &[];
15988 }
15989 unsafe { std::slice::from_raw_parts(base.add(offset) as *const f64, len) }
15990}
15991
15992#[inline(always)]
15993unsafe fn sl_mut_f64(offset: usize, base: *mut u8, len: usize) -> &'static mut [f64] {
15994 unsafe { std::slice::from_raw_parts_mut(base.add(offset) as *mut f64, len) }
15995}
15996
15997#[inline(always)]
16002#[allow(dead_code)]
16003unsafe fn sl_i32(offset: usize, base: *mut u8, len: usize) -> &'static [i32] {
16004 if offset == usize::MAX {
16005 return &[];
16006 }
16007 unsafe { std::slice::from_raw_parts(base.add(offset) as *const i32, len) }
16008}
16009
16010#[inline(always)]
16011#[allow(dead_code)]
16012unsafe fn sl_mut_i32(offset: usize, base: *mut u8, len: usize) -> &'static mut [i32] {
16013 unsafe { std::slice::from_raw_parts_mut(base.add(offset) as *mut i32, len) }
16014}
16015
16016#[inline(always)]
16017unsafe fn sl_i64(offset: usize, base: *mut u8, len: usize) -> &'static [i64] {
16018 if offset == usize::MAX {
16019 return &[];
16020 }
16021 unsafe { std::slice::from_raw_parts(base.add(offset) as *const i64, len) }
16022}
16023
16024#[inline(always)]
16025unsafe fn sl_mut_i64(offset: usize, base: *mut u8, len: usize) -> &'static mut [i64] {
16026 unsafe { std::slice::from_raw_parts_mut(base.add(offset) as *mut i64, len) }
16027}
16028
16029fn transpose_walk_f64(inp: &[f64], out: &mut [f64], out_dims: &[u32], in_strides: &[u32]) {
16033 let rank = out_dims.len();
16034 let mut idx = vec![0u32; rank];
16035 for o in 0..out.len() {
16036 let mut src_off = 0usize;
16037 for d in 0..rank {
16038 src_off += idx[d] as usize * in_strides[d] as usize;
16039 }
16040 out[o] = inp[src_off];
16041 for d in (0..rank).rev() {
16043 idx[d] += 1;
16044 if idx[d] < out_dims[d] {
16045 break;
16046 }
16047 idx[d] = 0;
16048 }
16049 }
16050}
16051
16052fn apply_activation_f64(inp: &[f64], out: &mut [f64], kind: Activation) {
16058 match kind {
16059 Activation::Neg => {
16060 for (o, &v) in out.iter_mut().zip(inp) {
16061 *o = -v;
16062 }
16063 }
16064 Activation::Exp => {
16065 for (o, &v) in out.iter_mut().zip(inp) {
16066 *o = v.exp();
16067 }
16068 }
16069 Activation::Log => {
16070 for (o, &v) in out.iter_mut().zip(inp) {
16071 *o = v.ln();
16072 }
16073 }
16074 Activation::Sqrt => {
16075 for (o, &v) in out.iter_mut().zip(inp) {
16076 *o = v.sqrt();
16077 }
16078 }
16079 Activation::Rsqrt => {
16080 for (o, &v) in out.iter_mut().zip(inp) {
16081 *o = 1.0 / v.sqrt();
16082 }
16083 }
16084 Activation::Abs => {
16085 for (o, &v) in out.iter_mut().zip(inp) {
16086 *o = v.abs();
16087 }
16088 }
16089 Activation::Tanh => {
16090 for (o, &v) in out.iter_mut().zip(inp) {
16091 *o = v.tanh();
16092 }
16093 }
16094 Activation::Sigmoid => {
16095 for (o, &v) in out.iter_mut().zip(inp) {
16096 *o = 1.0 / (1.0 + (-v).exp());
16097 }
16098 }
16099 Activation::Relu => {
16100 for (o, &v) in out.iter_mut().zip(inp) {
16101 *o = v.max(0.0);
16102 }
16103 }
16104 Activation::Round => {
16105 for (o, &v) in out.iter_mut().zip(inp) {
16106 *o = v.round_ties_even();
16107 }
16108 }
16109 Activation::Sin => {
16110 for (o, &v) in out.iter_mut().zip(inp) {
16111 *o = v.sin();
16112 }
16113 }
16114 Activation::Cos => {
16115 for (o, &v) in out.iter_mut().zip(inp) {
16116 *o = v.cos();
16117 }
16118 }
16119 Activation::Tan => {
16120 for (o, &v) in out.iter_mut().zip(inp) {
16121 *o = v.tan();
16122 }
16123 }
16124 Activation::Atan => {
16125 for (o, &v) in out.iter_mut().zip(inp) {
16126 *o = v.atan();
16127 }
16128 }
16129 Activation::Gelu | Activation::GeluApprox | Activation::Silu => {
16130 panic!(
16131 "apply_activation_f64: {kind:?} not yet implemented at f64. \
16132 Add when a workload needs it."
16133 );
16134 }
16135 }
16136}
16137
16138#[inline]
16139fn binary_op_f64(op: BinaryOp, a: f64, b: f64) -> f64 {
16140 match op {
16141 BinaryOp::Add => a + b,
16142 BinaryOp::Sub => a - b,
16143 BinaryOp::Mul => a * b,
16144 BinaryOp::Div => a / b,
16145 BinaryOp::Max => a.max(b),
16146 BinaryOp::Min => a.min(b),
16147 BinaryOp::Pow => a.powf(b),
16148 }
16149}
16150
16151fn reduce_sum_f64(inp: &[f64], out: &mut [f64], outer: usize, reduced: usize, inner: usize) {
16154 for o in 0..outer {
16155 for n in 0..inner {
16156 let mut acc = 0.0_f64;
16157 for r in 0..reduced {
16158 acc += inp[o * reduced * inner + r * inner + n];
16159 }
16160 out[o * inner + n] = acc;
16161 }
16162 }
16163}
16164
16165#[cfg(test)]
16166mod tests {
16167 use super::*;
16168 use rlx_ir::*;
16169
16170 #[test]
16176 fn narrow_rope_fuses_in_unfused_path() {
16177 let f = DType::F32;
16178 let mut g = Graph::new("nr_fuse");
16179 let qkv = g.input("qkv", Shape::new(&[16, 8, 192], f)); let cos = g.input("cos", Shape::new(&[16], f));
16182 let sin = g.input("sin", Shape::new(&[16], f));
16183 let q = g.narrow_(qkv, 2, 0, 64);
16185 let q_rope = g.rope(q, cos, sin, 16);
16186 g.set_outputs(vec![q_rope]);
16187
16188 let plan = rlx_opt::memory::plan_memory(&g);
16189 let arena = crate::arena::Arena::from_plan(plan);
16190 let sched = compile_thunks(&g, &arena);
16191
16192 let mut narrow_count = 0;
16193 let mut rope_with_stride: Option<u32> = None;
16194 for t in &sched.thunks {
16195 match t {
16196 Thunk::Narrow { .. } => narrow_count += 1,
16197 Thunk::Rope { src_row_stride, .. } => rope_with_stride = Some(*src_row_stride),
16198 _ => {}
16199 }
16200 }
16201 assert_eq!(
16204 narrow_count, 0,
16205 "Narrow→Rope fusion should leave zero Narrow thunks; saw {narrow_count}"
16206 );
16207 assert_eq!(
16208 rope_with_stride,
16209 Some(192),
16210 "Rope's src_row_stride should be 192 (parent qkv axis), saw {rope_with_stride:?}"
16211 );
16212 }
16213
16214 #[test]
16217 fn ssm_selective_scan_matches_reference() {
16218 use rlx_ir::Philox4x32;
16219 let bch = 1usize;
16220 let s = 4usize;
16221 let h = 3usize;
16222 let n = 2usize;
16223
16224 let mut rng = Philox4x32::new(13);
16225 let mut x = vec![0f32; bch * s * h];
16226 rng.fill_normal(&mut x);
16227 let mut delta = vec![0f32; bch * s * h];
16228 for v in delta.iter_mut() {
16230 *v = (rng.next_f32() - 0.5) * 0.1;
16231 }
16232 let mut a = vec![0f32; h * n];
16233 for v in a.iter_mut() {
16234 *v = -(rng.next_f32() * 0.5 + 0.1);
16235 } let mut b = vec![0f32; bch * s * n];
16237 rng.fill_normal(&mut b);
16238 let mut c = vec![0f32; bch * s * n];
16239 rng.fill_normal(&mut c);
16240
16241 let mut expected = vec![0f32; bch * s * h];
16243 for bi in 0..bch {
16244 let mut state = vec![0f32; h * n];
16245 for si in 0..s {
16246 for ci in 0..h {
16247 let d = delta[bi * s * h + si * h + ci];
16248 let xv = x[bi * s * h + si * h + ci];
16249 let mut acc = 0f32;
16250 for ni in 0..n {
16251 let da = (d * a[ci * n + ni]).exp();
16252 state[ci * n + ni] =
16253 da * state[ci * n + ni] + d * b[bi * s * n + si * n + ni] * xv;
16254 acc += c[bi * s * n + si * n + ni] * state[ci * n + ni];
16255 }
16256 expected[bi * s * h + si * h + ci] = acc;
16257 }
16258 }
16259 }
16260
16261 let f = DType::F32;
16263 let mut g = Graph::new("ssm");
16264 let xn = g.input("x", Shape::new(&[bch, s, h], f));
16265 let dn = g.input("delta", Shape::new(&[bch, s, h], f));
16266 let an = g.param("a", Shape::new(&[h, n], f));
16267 let bn = g.param("b", Shape::new(&[bch, s, n], f));
16268 let cn = g.param("c", Shape::new(&[bch, s, n], f));
16269 let yn = g.selective_scan(xn, dn, an, bn, cn, n, Shape::new(&[bch, s, h], f));
16270 g.set_outputs(vec![yn]);
16271
16272 let plan = rlx_opt::memory::plan_memory(&g);
16273 let mut arena = crate::arena::Arena::from_plan(plan);
16274 let sched = compile_thunks(&g, &arena);
16275
16276 let xn_off = arena.byte_offset(xn);
16277 let dn_off = arena.byte_offset(dn);
16278 let an_off = arena.byte_offset(an);
16279 let bn_off = arena.byte_offset(bn);
16280 let cn_off = arena.byte_offset(cn);
16281 let yn_off = arena.byte_offset(yn);
16282 let buf = arena.raw_buf_mut();
16283 unsafe {
16284 let copy = |dst: *mut f32, data: &[f32]| {
16285 for (i, &v) in data.iter().enumerate() {
16286 *dst.add(i) = v;
16287 }
16288 };
16289 copy(buf.as_mut_ptr().add(xn_off) as *mut f32, &x);
16290 copy(buf.as_mut_ptr().add(dn_off) as *mut f32, &delta);
16291 copy(buf.as_mut_ptr().add(an_off) as *mut f32, &a);
16292 copy(buf.as_mut_ptr().add(bn_off) as *mut f32, &b);
16293 copy(buf.as_mut_ptr().add(cn_off) as *mut f32, &c);
16294 }
16295 execute_thunks(&sched, arena.raw_buf_mut());
16296
16297 let actual: Vec<f32> = unsafe {
16298 let p = arena.raw_buf().as_ptr().add(yn_off) as *const f32;
16299 (0..bch * s * h).map(|i| *p.add(i)).collect()
16300 };
16301
16302 for (i, (e, a)) in expected.iter().zip(&actual).enumerate() {
16303 assert!(
16304 (e - a).abs() < 1e-3,
16305 "mismatch at {i}: expected {e}, got {a}"
16306 );
16307 }
16308 }
16309
// A 1×1 convolution must lower to the dedicated `Conv2D1x1` thunk (never the
// generic scalar `Conv2D` thunk) and agree with a host-side reference.
#[test]
fn conv_1x1_fast_path_matches_scalar() {
    use rlx_ir::Philox4x32;

    let n = 2usize;
    let c_in = 4usize;
    let (h, w) = (3usize, 3usize);
    let c_out = 5usize;
    let plane = h * w;

    let mut rng = Philox4x32::new(31);
    let mut x = vec![0f32; n * c_in * plane];
    rng.fill_normal(&mut x);
    let mut weight = vec![0f32; c_out * c_in];
    rng.fill_normal(&mut weight);

    // With a 1×1 kernel every output pixel is a dot product over the input
    // channels at the same spatial position.
    let mut expected = vec![0f32; n * c_out * plane];
    for (ni, out_batch) in expected.chunks_exact_mut(c_out * plane).enumerate() {
        let x_batch = &x[ni * c_in * plane..(ni + 1) * c_in * plane];
        for (co, out_plane) in out_batch.chunks_exact_mut(plane).enumerate() {
            let w_row = &weight[co * c_in..(co + 1) * c_in];
            for (pix, slot) in out_plane.iter_mut().enumerate() {
                *slot = w_row
                    .iter()
                    .enumerate()
                    .fold(0f32, |acc, (ci, &wv)| acc + wv * x_batch[ci * plane + pix]);
            }
        }
    }

    let f = DType::F32;
    let mut g = Graph::new("conv1x1");
    let xn = g.input("x", Shape::new(&[n, c_in, h, w], f));
    let wn = g.param("w", Shape::new(&[c_out, c_in, 1, 1], f));
    let conv = rlx_ir::Op::Conv {
        kernel_size: vec![1, 1],
        stride: vec![1, 1],
        padding: vec![0, 0],
        dilation: vec![1, 1],
        groups: 1,
    };
    let cn = g.add_node(conv, vec![xn, wn], Shape::new(&[n, c_out, h, w], f));
    g.set_outputs(vec![cn]);

    let plan = rlx_opt::memory::plan_memory(&g);
    let mut arena = crate::arena::Arena::from_plan(plan);
    let sched = compile_thunks(&g, &arena);

    let (mut saw_fast, mut saw_slow) = (false, false);
    for thunk in sched.thunks.iter() {
        match thunk {
            Thunk::Conv2D1x1 { .. } => saw_fast = true,
            Thunk::Conv2D { .. } => saw_slow = true,
            _ => {}
        }
    }
    assert!(saw_fast, "1×1 conv should emit Conv2D1x1");
    assert!(!saw_slow, "1×1 conv must not fall through to scalar Conv2D");

    let xn_off = arena.byte_offset(xn);
    let wn_off = arena.byte_offset(wn);
    let cn_off = arena.byte_offset(cn);
    let buf = arena.raw_buf_mut();
    unsafe {
        let base = buf.as_mut_ptr();
        std::ptr::copy_nonoverlapping(x.as_ptr(), base.add(xn_off) as *mut f32, x.len());
        std::ptr::copy_nonoverlapping(
            weight.as_ptr(),
            base.add(wn_off) as *mut f32,
            weight.len(),
        );
    }
    execute_thunks(&sched, arena.raw_buf_mut());

    let actual: Vec<f32> = unsafe {
        let out = arena.raw_buf().as_ptr().add(cn_off) as *const f32;
        std::slice::from_raw_parts(out, n * c_out * plane).to_vec()
    };

    for (i, (e, a)) in expected.iter().zip(&actual).enumerate() {
        assert!((e - a).abs() < 1e-3, "mismatch at {i}: expected {e}, got {a}");
    }
}
16408
// Int8 block-quantized matmul with all-zero zero-points must match an f32
// matmul against weights dequantized on the host.
#[test]
fn dequant_matmul_int8_sym_matches_reference() {
    use rlx_ir::Philox4x32;
    use rlx_ir::quant::QuantScheme;

    let (m, k, n) = (3usize, 8usize, 4usize);
    let block_size = 4usize;
    // One scale / zero-point per `block_size` rows of each weight column.
    let blocks_per_col = k / block_size;

    let mut rng = Philox4x32::new(99);
    let mut x = vec![0f32; m * k];
    rng.fill_normal(&mut x);
    // Deterministic int8 weights spread over [-63, 63].
    let w_q: Vec<i8> = (0..k * n)
        .map(|i| ((i as i32 * 13 + 7) % 127 - 63) as i8)
        .collect();
    let scales: Vec<f32> = (0..blocks_per_col * n)
        .map(|i| 0.01 + 0.001 * i as f32)
        .collect();

    // Host dequantization: w[p][j] = q[p][j] * scale[p / block_size][j].
    let w_f32: Vec<f32> = w_q
        .iter()
        .enumerate()
        .map(|(idx, &q)| {
            let (p, j) = (idx / n, idx % n);
            q as f32 * scales[(p / block_size) * n + j]
        })
        .collect();
    let expected: Vec<f32> = (0..m * n)
        .map(|idx| {
            let (i, j) = (idx / n, idx % n);
            (0..k).fold(0f32, |acc, p| acc + x[i * k + p] * w_f32[p * n + j])
        })
        .collect();

    let f = DType::F32;
    let mut g = Graph::new("dq");
    let xn = g.input("x", Shape::new(&[m, k], f));
    let wn = g.param("w", Shape::new(&[k, n], DType::I8));
    let sn = g.param("scale", Shape::new(&[blocks_per_col, n], f));
    let zn = g.param("zp", Shape::new(&[blocks_per_col, n], f));
    let scheme = QuantScheme::Int8Block {
        block_size: block_size as u32,
    };
    let dq = g.dequant_matmul(xn, wn, sn, zn, scheme, Shape::new(&[m, n], f));
    g.set_outputs(vec![dq]);

    let plan = rlx_opt::memory::plan_memory(&g);
    let mut arena = crate::arena::Arena::from_plan(plan);
    let sched = compile_thunks(&g, &arena);

    let xn_off = arena.byte_offset(xn);
    let wn_off = arena.byte_offset(wn);
    let sn_off = arena.byte_offset(sn);
    let zn_off = arena.byte_offset(zn);
    let dq_off = arena.byte_offset(dq);
    let buf = arena.raw_buf_mut();
    unsafe {
        let base = buf.as_mut_ptr();
        let xp = base.add(xn_off) as *mut f32;
        x.iter().enumerate().for_each(|(i, &v)| *xp.add(i) = v);
        let sp = base.add(sn_off) as *mut f32;
        scales.iter().enumerate().for_each(|(i, &v)| *sp.add(i) = v);
        // Symmetric scheme: every zero-point is 0.
        let zp = base.add(zn_off) as *mut f32;
        (0..blocks_per_col * n).for_each(|i| *zp.add(i) = 0.0);
        let wp = base.add(wn_off) as *mut i8;
        w_q.iter().enumerate().for_each(|(i, &v)| *wp.add(i) = v);
    }
    execute_thunks(&sched, arena.raw_buf_mut());

    let actual: Vec<f32> = unsafe {
        let out = arena.raw_buf().as_ptr().add(dq_off) as *const f32;
        std::slice::from_raw_parts(out, m * n).to_vec()
    };

    for (i, (e, a)) in expected.iter().zip(&actual).enumerate() {
        assert!((e - a).abs() < 1e-3, "mismatch at {i}: expected {e}, got {a}");
    }
}
16516
// The fused LoRA matmul must equal the unfused reference `x·W + scale·(x·A)·B`.
#[test]
fn lora_matmul_matches_unfused_reference() {
    use rlx_ir::Philox4x32;

    let (m, k, n) = (4usize, 8usize, 6usize);
    // Inner dimension of the A (k×r) and B (r×n) factors.
    let r = 2usize;
    let scale = 0.5f32;

    let mut rng = Philox4x32::new(42);
    let mut x = vec![0f32; m * k];
    rng.fill_normal(&mut x);
    let mut w = vec![0f32; k * n];
    rng.fill_normal(&mut w);
    let mut a = vec![0f32; k * r];
    rng.fill_normal(&mut a);
    let mut b = vec![0f32; r * n];
    rng.fill_normal(&mut b);

    // Row-major (rows×inner)·(inner×cols) reference product.
    let naive = |lhs: &[f32], rhs: &[f32], rows: usize, inner: usize, cols: usize| -> Vec<f32> {
        (0..rows * cols)
            .map(|idx| {
                let (i, j) = (idx / cols, idx % cols);
                (0..inner).fold(0f32, |acc, p| acc + lhs[i * inner + p] * rhs[p * cols + j])
            })
            .collect()
    };
    let base = naive(&x, &w, m, k, n);
    let low_rank = naive(&naive(&x, &a, m, k, r), &b, m, r, n);
    let expected: Vec<f32> = base
        .iter()
        .zip(&low_rank)
        .map(|(&full, &delta)| full + scale * delta)
        .collect();

    let f = DType::F32;
    let mut g = Graph::new("lora");
    let xn = g.input("x", Shape::new(&[m, k], f));
    let wn = g.param("w", Shape::new(&[k, n], f));
    let an = g.param("a", Shape::new(&[k, r], f));
    let bn = g.param("b", Shape::new(&[r, n], f));
    let lm = g.lora_matmul(xn, wn, an, bn, scale, Shape::new(&[m, n], f));
    g.set_outputs(vec![lm]);

    let plan = rlx_opt::memory::plan_memory(&g);
    let mut arena = crate::arena::Arena::from_plan(plan);
    let sched = compile_thunks(&g, &arena);

    let xn_off = arena.byte_offset(xn);
    let wn_off = arena.byte_offset(wn);
    let an_off = arena.byte_offset(an);
    let bn_off = arena.byte_offset(bn);
    let lm_off = arena.byte_offset(lm);
    let buf = arena.raw_buf_mut();
    for (off, data) in [(xn_off, &x), (wn_off, &w), (an_off, &a), (bn_off, &b)] {
        unsafe {
            let dst = buf.as_mut_ptr().add(off) as *mut f32;
            for (i, &v) in data.iter().enumerate() {
                *dst.add(i) = v;
            }
        }
    }
    execute_thunks(&sched, arena.raw_buf_mut());

    let actual: Vec<f32> = unsafe {
        let out = arena.raw_buf().as_ptr().add(lm_off) as *const f32;
        std::slice::from_raw_parts(out, m * n).to_vec()
    };

    for (i, (e, a)) in expected.iter().zip(&actual).enumerate() {
        assert!((e - a).abs() < 1e-3, "mismatch at {i}: expected {e}, got {a}");
    }
}
16606
// With one logit far above the rest, this sampler configuration must return
// that logit's index.
#[test]
fn sample_temperature_zero_is_argmax() {
    let f = DType::F32;
    let mut g = Graph::new("samp");
    let logits = g.input("logits", Shape::new(&[1, 8], f));
    let s = g.sample(logits, 0, 1.0, 1e-3, 42, Shape::new(&[1], f));
    g.set_outputs(vec![s]);
    let plan = rlx_opt::memory::plan_memory(&g);
    let mut arena = crate::arena::Arena::from_plan(plan);
    let sched = compile_thunks(&g, &arena);

    // Index 5 dominates every other logit.
    let inputs = [0.1f32, 0.2, 0.3, 0.4, 0.5, 9.0, 0.7, 0.8];
    let logits_off = arena.byte_offset(logits);
    let s_off = arena.byte_offset(s);
    let buf = arena.raw_buf_mut();
    unsafe {
        let dst = buf.as_mut_ptr().add(logits_off) as *mut f32;
        std::ptr::copy_nonoverlapping(inputs.as_ptr(), dst, inputs.len());
    }
    execute_thunks(&sched, arena.raw_buf_mut());

    // The sampled index is stored in the output buffer as an f32.
    let token = unsafe { *(arena.raw_buf().as_ptr().add(s_off) as *const f32) as usize };
    assert_eq!(token, 5, "low-temp sampling should pick the argmax");
}
16640
// With this sampler configuration (top-k = 1 per the test name), sampling must
// always return the index of the single largest logit.
#[test]
fn sample_top_k_one_is_deterministic() {
    let f = DType::F32;
    let mut g = Graph::new("samp_k1");
    let logits = g.input("logits", Shape::new(&[1, 4], f));
    let s = g.sample(logits, 1, 1.0, 1.0, 7, Shape::new(&[1], f));
    g.set_outputs(vec![s]);
    let plan = rlx_opt::memory::plan_memory(&g);
    let mut arena = crate::arena::Arena::from_plan(plan);
    let sched = compile_thunks(&g, &arena);

    // Index 1 holds the largest logit.
    let inputs = [0.1f32, 5.0, 0.3, 0.4];
    let logits_off = arena.byte_offset(logits);
    let s_off = arena.byte_offset(s);
    let buf = arena.raw_buf_mut();
    unsafe {
        let dst = buf.as_mut_ptr().add(logits_off) as *mut f32;
        std::ptr::copy_nonoverlapping(inputs.as_ptr(), dst, inputs.len());
    }
    execute_thunks(&sched, arena.raw_buf_mut());

    let token = unsafe { *(arena.raw_buf().as_ptr().add(s_off) as *const f32) as usize };
    assert_eq!(token, 1);
}
16670
// Inclusive cumulative sum over a [2, 4] input: each row of four values is
// prefix-summed independently.
#[test]
fn cumsum_inclusive_matches_naive() {
    let f = DType::F32;
    let mut g = Graph::new("cumsum");
    let x = g.input("x", Shape::new(&[2, 4], f));
    let cs = g.cumsum(x, -1, false, Shape::new(&[2, 4], f));
    g.set_outputs(vec![cs]);
    let plan = rlx_opt::memory::plan_memory(&g);
    let mut arena = crate::arena::Arena::from_plan(plan);
    let sched = compile_thunks(&g, &arena);

    let inputs = [1.0f32, 2.0, 3.0, 4.0, 10.0, 20.0, 30.0, 40.0];
    let x_off = arena.byte_offset(x);
    let out_off = arena.byte_offset(cs);
    let buf = arena.raw_buf_mut();
    unsafe {
        let dst = buf.as_mut_ptr().add(x_off) as *mut f32;
        std::ptr::copy_nonoverlapping(inputs.as_ptr(), dst, inputs.len());
    }
    execute_thunks(&sched, arena.raw_buf_mut());

    let out: Vec<f32> = unsafe {
        let src = arena.raw_buf().as_ptr().add(out_off) as *const f32;
        std::slice::from_raw_parts(src, inputs.len()).to_vec()
    };
    assert_eq!(out, vec![1.0, 3.0, 6.0, 10.0, 10.0, 30.0, 60.0, 100.0]);
}
16702
// Three narrows that slice Q/K/V out of one packed [8, 16, 192] tensor must be
// folded into the Attention thunk, which then walks the parent tensor with its
// row stride, instead of being emitted as standalone Narrow thunks. Only the
// compiled schedule is inspected; nothing is executed.
#[test]
fn narrow_attention_fuses_in_unfused_path() {
    let f = DType::F32;
    let mut g = Graph::new("nattn_fuse");
    let qkv = g.input("qkv", Shape::new(&[8, 16, 192], f));
    let mask = g.input("mask", Shape::new(&[8, 16], f));
    let q = g.narrow_(qkv, 2, 0, 64);
    let k = g.narrow_(qkv, 2, 64, 64);
    let v = g.narrow_(qkv, 2, 128, 64);
    let attn = g.attention(q, k, v, mask, 4, 16, Shape::new(&[8, 16, 64], f));
    g.set_outputs(vec![attn]);

    let plan = rlx_opt::memory::plan_memory(&g);
    let arena = crate::arena::Arena::from_plan(plan);
    let sched = compile_thunks(&g, &arena);

    let narrow_count = sched
        .thunks
        .iter()
        .filter(|t| matches!(t, Thunk::Narrow { .. }))
        .count();
    // If several Attention thunks exist, the last one is the one checked.
    let attn_strides: Option<(u32, u32, u32)> = sched
        .thunks
        .iter()
        .filter_map(|t| match t {
            Thunk::Attention {
                q_row_stride,
                k_row_stride,
                v_row_stride,
                ..
            } => Some((*q_row_stride, *k_row_stride, *v_row_stride)),
            _ => None,
        })
        .last();

    assert_eq!(
        narrow_count, 0,
        "Narrow×3→Attention fusion should eliminate all 3 narrows; saw {narrow_count}"
    );
    assert_eq!(
        attn_strides,
        Some((192, 192, 192)),
        "Attention should walk Q/K/V with parent row stride 192"
    );
}
16749
16750 fn run_graph(
16761 g: &Graph,
16762 inputs: &[(NodeId, &[f32])],
16763 out_id: NodeId,
16764 out_len: usize,
16765 ) -> Vec<f32> {
16766 let plan = rlx_opt::memory::plan_memory(g);
16767 let mut arena = crate::arena::Arena::from_plan(plan);
16768 let sched = compile_thunks(g, &arena);
16769 for &(id, data) in inputs {
16770 let off = arena.byte_offset(id);
16771 let buf = arena.raw_buf_mut();
16772 unsafe {
16773 let p = buf.as_mut_ptr().add(off) as *mut f32;
16774 for (i, &v) in data.iter().enumerate() {
16775 *p.add(i) = v;
16776 }
16777 }
16778 }
16779 execute_thunks(&sched, arena.raw_buf_mut());
16780 let off = arena.byte_offset(out_id);
16781 unsafe {
16782 let p = arena.raw_buf().as_ptr().add(off) as *const f32;
16783 (0..out_len).map(|i| *p.add(i)).collect()
16784 }
16785 }
16786
// ReLU backward passes `dy` through where `x > 0` and zeroes it everywhere
// else, including at exactly `x == 0`.
#[test]
fn relu_backward_matches_mask() {
    let f = DType::F32;
    let x: Vec<f32> = vec![-2.0, -0.1, 0.0, 0.1, 1.0, 3.0, -5.0];
    let dy: Vec<f32> = vec![0.5, 1.5, 2.5, -0.7, 4.0, -1.0, 9.0];
    let len = x.len();

    let mut g = Graph::new("relu_bw");
    let xn = g.input("x", Shape::new(&[len], f));
    let dyn_ = g.input("dy", Shape::new(&[len], f));
    let dx = g.relu_backward(xn, dyn_);
    g.set_outputs(vec![dx]);

    let actual = run_graph(&g, &[(xn, &x), (dyn_, &dy)], dx, len);
    for ((&xi, &dyi), a) in x.iter().zip(&dy).zip(&actual) {
        let e = if xi > 0.0 { dyi } else { 0.0 };
        assert!((a - e).abs() < 1e-6, "relu_bw mismatch: {a} vs {e}");
    }
}
16813
// Max-pool backward must route each upstream gradient to the position of its
// window's maximum and leave every other input position at zero.
#[test]
fn maxpool2d_backward_routes_to_argmax() {
    let f = DType::F32;
    // A 4×4 plane holding 1..=16 row-major: each 2×2 window's maximum is its
    // bottom-right element, i.e. flat indices 5, 7, 13 and 15.
    let x: Vec<f32> = (1..=16).map(|v| v as f32).collect();
    let dy: Vec<f32> = vec![0.5, 1.0, 2.0, 4.0];

    let mut g = Graph::new("maxpool_bw");
    let xn = g.input("x", Shape::new(&[1, 1, 4, 4], f));
    let dyn_ = g.input("dy", Shape::new(&[1, 1, 2, 2], f));
    let dx = g.maxpool2d_backward(xn, dyn_, vec![2, 2], vec![2, 2], vec![0, 0]);
    g.set_outputs(vec![dx]);

    let actual = run_graph(&g, &[(xn, &x), (dyn_, &dy)], dx, 16);

    let mut expected = vec![0f32; 16];
    for (&idx, &grad) in [5usize, 7, 13, 15].iter().zip(&dy) {
        expected[idx] = grad;
    }
    for (i, (a, e)) in actual.iter().zip(&expected).enumerate() {
        assert!((a - e).abs() < 1e-6, "maxpool_bw[{i}] mismatch: {a} vs {e}");
    }
}
16842
// The analytical input gradient of a padded 3×3 convolution must match central
// finite differences of the scalar <conv(x, w), dy> with respect to x.
#[test]
fn conv2d_backward_input_matches_numerical_gradient() {
    use rlx_ir::Philox4x32;

    let (n, c_in, h, w) = (1usize, 2usize, 4usize, 4usize);
    let (c_out, kh, kw) = (3usize, 3usize, 3usize);
    let (ph, pw) = (1usize, 1usize);
    let (sh, sw) = (1usize, 1usize);
    let h_out = (h + 2 * ph - kh) / sh + 1;
    let w_out = (w + 2 * pw - kw) / sw + 1;
    // Padding 1 with a 3×3 kernel and stride 1 keeps the spatial size.
    assert_eq!(h_out, 4);
    assert_eq!(w_out, 4);

    let mut rng = Philox4x32::new(7);
    let mut x = vec![0f32; n * c_in * h * w];
    rng.fill_normal(&mut x);
    let mut wt = vec![0f32; c_out * c_in * kh * kw];
    rng.fill_normal(&mut wt);
    let mut dy = vec![0f32; n * c_out * h_out * w_out];
    rng.fill_normal(&mut dy);

    let f = DType::F32;
    let mut g = Graph::new("conv_bwi");
    let dy_in = g.input("dy", Shape::new(&[n, c_out, h_out, w_out], f));
    let w_in = g.input("w", Shape::new(&[c_out, c_in, kh, kw], f));
    let dx = g.conv2d_backward_input(
        dy_in,
        w_in,
        Shape::new(&[n, c_in, h, w], f),
        vec![kh, kw],
        vec![sh, sw],
        vec![ph, pw],
        vec![1, 1],
        1,
    );
    g.set_outputs(vec![dx]);
    let analytical = run_graph(&g, &[(dy_in, &dy), (w_in, &wt)], dx, n * c_in * h * w);

    // Direct scalar convolution used as the finite-difference oracle.
    let forward = |inp: &[f32]| -> Vec<f32> {
        let mut out = vec![0f32; n * c_out * h_out * w_out];
        for ni in 0..n {
            for co in 0..c_out {
                for ho in 0..h_out {
                    for wo in 0..w_out {
                        let mut acc = 0f32;
                        for ci in 0..c_in {
                            let x_base = ((ni * c_in) + ci) * h * w;
                            let k_base = ((co * c_in) + ci) * kh * kw;
                            for ki in 0..kh {
                                // Map the tap back to an unpadded input row;
                                // rows inside the zero padding contribute nothing.
                                let Some(hi) =
                                    (ho * sh + ki).checked_sub(ph).filter(|&row| row < h)
                                else {
                                    continue;
                                };
                                for kj in 0..kw {
                                    let Some(wi) =
                                        (wo * sw + kj).checked_sub(pw).filter(|&col| col < w)
                                    else {
                                        continue;
                                    };
                                    acc += inp[x_base + hi * w + wi] * wt[k_base + ki * kw + kj];
                                }
                            }
                        }
                        out[((ni * c_out) + co) * h_out * w_out + ho * w_out + wo] = acc;
                    }
                }
            }
        }
        out
    };
    let dot = |u: &[f32], v: &[f32]| -> f32 { u.iter().zip(v).map(|(&p, &q)| p * q).sum() };

    let eps = 1e-3f32;
    let mut numerical = vec![0f32; x.len()];
    for (i, slot) in numerical.iter_mut().enumerate() {
        let saved = x[i];
        x[i] = saved + eps;
        let plus = dot(&forward(&x), &dy);
        x[i] = saved - eps;
        let minus = dot(&forward(&x), &dy);
        x[i] = saved;
        *slot = (plus - minus) / (2.0 * eps);
    }
    for (i, (a, num)) in analytical.iter().zip(&numerical).enumerate() {
        assert!(
            (a - num).abs() < 5e-3,
            "conv_bw_input[{i}]: analytical {a} vs numerical {num}"
        );
    }
}
16947
// The analytical weight gradient of an unpadded, stride-1 3×3 convolution must
// match central finite differences of <conv(x, w), dy> with respect to w.
#[test]
fn conv2d_backward_weight_matches_numerical_gradient() {
    use rlx_ir::Philox4x32;

    let (n, c_in, h, w) = (2usize, 2usize, 4usize, 4usize);
    let (c_out, kh, kw) = (2usize, 3usize, 3usize);
    let (ph, pw) = (0usize, 0usize);
    let (sh, sw) = (1usize, 1usize);
    let h_out = (h + 2 * ph - kh) / sh + 1;
    let w_out = (w + 2 * pw - kw) / sw + 1;

    let mut rng = Philox4x32::new(11);
    let mut x = vec![0f32; n * c_in * h * w];
    rng.fill_normal(&mut x);
    let mut wt = vec![0f32; c_out * c_in * kh * kw];
    rng.fill_normal(&mut wt);
    let mut dy = vec![0f32; n * c_out * h_out * w_out];
    rng.fill_normal(&mut dy);

    let f = DType::F32;
    let mut g = Graph::new("conv_bww");
    let xn = g.input("x", Shape::new(&[n, c_in, h, w], f));
    let dyn_ = g.input("dy", Shape::new(&[n, c_out, h_out, w_out], f));
    let dwn = g.conv2d_backward_weight(
        xn,
        dyn_,
        Shape::new(&[c_out, c_in, kh, kw], f),
        vec![kh, kw],
        vec![sh, sw],
        vec![ph, pw],
        vec![1, 1],
        1,
    );
    g.set_outputs(vec![dwn]);
    let analytical = run_graph(&g, &[(xn, &x), (dyn_, &dy)], dwn, c_out * c_in * kh * kw);

    // Scalar oracle. With zero padding and stride 1 every tap
    // (ho + ki, wo + kj) lies inside the input plane.
    let forward = |kernel: &[f32]| -> Vec<f32> {
        let mut out = vec![0f32; n * c_out * h_out * w_out];
        for ni in 0..n {
            for co in 0..c_out {
                for ho in 0..h_out {
                    for wo in 0..w_out {
                        let mut acc = 0f32;
                        for ci in 0..c_in {
                            let x_plane = &x[((ni * c_in) + ci) * h * w..][..h * w];
                            let k_plane = &kernel[((co * c_in) + ci) * kh * kw..][..kh * kw];
                            for ki in 0..kh {
                                for kj in 0..kw {
                                    acc += x_plane[(ho + ki) * w + (wo + kj)]
                                        * k_plane[ki * kw + kj];
                                }
                            }
                        }
                        out[((ni * c_out) + co) * h_out * w_out + ho * w_out + wo] = acc;
                    }
                }
            }
        }
        out
    };
    let dot = |u: &[f32], v: &[f32]| -> f32 { u.iter().zip(v).map(|(&p, &q)| p * q).sum() };

    let eps = 1e-3f32;
    let mut numerical = vec![0f32; wt.len()];
    for (i, slot) in numerical.iter_mut().enumerate() {
        let saved = wt[i];
        wt[i] = saved + eps;
        let plus = dot(&forward(&wt), &dy);
        wt[i] = saved - eps;
        let minus = dot(&forward(&wt), &dy);
        wt[i] = saved;
        *slot = (plus - minus) / (2.0 * eps);
    }
    for (i, (a, num)) in analytical.iter().zip(&numerical).enumerate() {
        assert!(
            (a - num).abs() < 5e-3,
            "conv_bw_weight[{i}]: analytical {a} vs numerical {num}"
        );
    }
}
17034
// Per-row softmax cross-entropy loss must equal logsumexp(row) - row[label].
#[test]
fn softmax_cross_entropy_matches_reference() {
    let f = DType::F32;
    // Three rows of three logits; the last row is uniform.
    let logits: Vec<f32> = vec![1.0, 2.0, 3.0, -1.0, 0.0, 4.0, 5.0, 5.0, 5.0];
    let labels: Vec<f32> = vec![2.0, 0.0, 1.0];

    let mut g = Graph::new("sce");
    let lg = g.input("logits", Shape::new(&[3, 3], f));
    let lb = g.input("labels", Shape::new(&[3], f));
    let loss = g.softmax_cross_entropy_with_logits(lg, lb);
    g.set_outputs(vec![loss]);
    let actual = run_graph(&g, &[(lg, &logits), (lb, &labels)], loss, 3);

    // The row maximum is subtracted before exponentiating for numerical stability.
    let expected: Vec<f32> = logits
        .chunks_exact(3)
        .zip(&labels)
        .map(|(row, &label)| {
            let peak = row.iter().copied().fold(f32::NEG_INFINITY, f32::max);
            let denom: f32 = row.iter().map(|&v| (v - peak).exp()).sum();
            (peak + denom.ln()) - row[label as usize]
        })
        .collect();
    for (i, (a, e)) in actual.iter().zip(&expected).enumerate() {
        assert!((a - e).abs() < 1e-5, "sce loss[{i}]: {a} vs {e}");
    }
}
17066
// The analytical logit gradient of softmax cross-entropy, driven by an upstream
// per-row `d_loss`, must match central finite differences of Σ loss_i · d_loss_i.
#[test]
fn softmax_cross_entropy_backward_matches_numerical_gradient() {
    use rlx_ir::Philox4x32;

    let (n, c) = (4usize, 5usize);
    let mut rng = Philox4x32::new(23);
    let mut logits = vec![0f32; n * c];
    rng.fill_normal(&mut logits);
    let labels: Vec<f32> = (0..n).map(|i| (i % c) as f32).collect();
    let mut d_loss = vec![0f32; n];
    rng.fill_normal(&mut d_loss);

    let f = DType::F32;
    let mut g = Graph::new("sce_bw");
    let lg = g.input("logits", Shape::new(&[n, c], f));
    let lb = g.input("labels", Shape::new(&[n], f));
    let dl = g.input("d_loss", Shape::new(&[n], f));
    let dlogits = g.softmax_cross_entropy_backward(lg, lb, dl);
    g.set_outputs(vec![dlogits]);
    let analytical = run_graph(
        &g,
        &[(lg, &logits), (lb, &labels), (dl, &d_loss)],
        dlogits,
        n * c,
    );

    // Per-row loss oracle: logsumexp(row) - row[label].
    let sce_loss = |z: &[f32]| -> Vec<f32> {
        z.chunks_exact(c)
            .zip(&labels)
            .map(|(row, &label)| {
                let peak = row.iter().copied().fold(f32::NEG_INFINITY, f32::max);
                let denom: f32 = row.iter().map(|&v| (v - peak).exp()).sum();
                (peak + denom.ln()) - row[label as usize]
            })
            .collect()
    };
    let dot = |u: &[f32], v: &[f32]| -> f32 { u.iter().zip(v).map(|(&p, &q)| p * q).sum() };

    let eps = 1e-3f32;
    let mut numerical = vec![0f32; logits.len()];
    for (i, slot) in numerical.iter_mut().enumerate() {
        let saved = logits[i];
        logits[i] = saved + eps;
        let plus = dot(&sce_loss(&logits), &d_loss);
        logits[i] = saved - eps;
        let minus = dot(&sce_loss(&logits), &d_loss);
        logits[i] = saved;
        *slot = (plus - minus) / (2.0 * eps);
    }
    for (i, (a, num)) in analytical.iter().zip(&numerical).enumerate() {
        assert!(
            (a - num).abs() < 5e-3,
            "sce_bw[{i}]: analytical {a} vs numerical {num}"
        );
    }
}
17123
17124 fn fill_constants_into_arena(graph: &Graph, arena: &mut crate::arena::Arena) {
17137 for node in graph.nodes() {
17138 if let Op::Constant { data } = &node.op
17139 && arena.has_buffer(node.id)
17140 && !data.is_empty()
17141 {
17142 let buf = arena.slice_mut(node.id);
17143 let n_floats = data.len() / 4;
17144 let n = buf.len().min(n_floats);
17145 for i in 0..n {
17146 let bytes = [
17147 data[i * 4],
17148 data[i * 4 + 1],
17149 data[i * 4 + 2],
17150 data[i * 4 + 3],
17151 ];
17152 buf[i] = f32::from_le_bytes(bytes);
17153 }
17154 }
17155 }
17156 }
17157
17158 fn prepare(
17162 graph: &Graph,
17163 seed_inputs: &[(NodeId, &[f32])],
17164 ) -> (ThunkSchedule, crate::arena::Arena) {
17165 let plan = rlx_opt::memory::plan_memory(graph);
17166 let mut arena = crate::arena::Arena::from_plan(plan);
17167 let sched = compile_thunks(graph, &arena);
17168 fill_constants_into_arena(graph, &mut arena);
17169 for &(id, data) in seed_inputs {
17170 let off = arena.byte_offset(id);
17171 let buf = arena.raw_buf_mut();
17172 unsafe {
17173 let p = buf.as_mut_ptr().add(off) as *mut f32;
17174 for (i, &v) in data.iter().enumerate() {
17175 *p.add(i) = v;
17176 }
17177 }
17178 }
17179 (sched, arena)
17180 }
17181
17182 fn read_arena(arena: &crate::arena::Arena, id: NodeId, len: usize) -> Vec<f32> {
17183 let off = arena.byte_offset(id);
17184 unsafe {
17185 let p = arena.raw_buf().as_ptr().add(off) as *const f32;
17186 (0..len).map(|i| *p.add(i)).collect()
17187 }
17188 }
17189
17190 fn write_arena(arena: &mut crate::arena::Arena, id: NodeId, data: &[f32]) {
17191 let off = arena.byte_offset(id);
17192 let buf = arena.raw_buf_mut();
17193 unsafe {
17194 let p = buf.as_mut_ptr().add(off) as *mut f32;
17195 for (i, &v) in data.iter().enumerate() {
17196 *p.add(i) = v;
17197 }
17198 }
17199 }
17200
17201 fn prepare_f64(
17203 graph: &Graph,
17204 seed_inputs: &[(NodeId, &[f64])],
17205 ) -> (ThunkSchedule, crate::arena::Arena) {
17206 let plan = rlx_opt::memory::plan_memory(graph);
17207 let mut arena = crate::arena::Arena::from_plan(plan);
17208 let sched = compile_thunks(graph, &arena);
17209 fill_constants_into_arena(graph, &mut arena);
17210 for &(id, data) in seed_inputs {
17211 let off = arena.byte_offset(id);
17212 let buf = arena.raw_buf_mut();
17213 unsafe {
17214 let p = buf.as_mut_ptr().add(off) as *mut f64;
17215 for (i, &v) in data.iter().enumerate() {
17216 *p.add(i) = v;
17217 }
17218 }
17219 }
17220 (sched, arena)
17221 }
17222
17223 fn read_arena_f64(arena: &crate::arena::Arena, id: NodeId, len: usize) -> Vec<f64> {
17224 let off = arena.byte_offset(id);
17225 unsafe {
17226 let p = arena.raw_buf().as_ptr().add(off) as *const f64;
17227 (0..len).map(|i| *p.add(i)).collect()
17228 }
17229 }
17230
// End-to-end DenseSolve on a 2×2 f64 system with an exact known solution.
#[test]
fn dense_solve_f64_end_to_end() {
    let mut g = Graph::new("solve_e2e");
    let a = g.input("A", Shape::new(&[2, 2], DType::F64));
    let b = g.input("b", Shape::new(&[2], DType::F64));
    let x = g.dense_solve(a, b, Shape::new(&[2], DType::F64));
    g.set_outputs(vec![x]);

    // [[2, 1], [1, 3]] · [1, 3] = [5, 10].
    let a_data = [2.0, 1.0, 1.0, 3.0_f64];
    let b_data = [5.0, 10.0_f64];
    let (sched, mut arena) = prepare_f64(&g, &[(a, &a_data), (b, &b_data)]);
    execute_thunks(&sched, arena.raw_buf_mut());

    let want = [1.0, 3.0_f64];
    let got = read_arena_f64(&arena, x, want.len());
    for (i, (&got_i, &want_i)) in got.iter().zip(want.iter()).enumerate() {
        assert!(
            (got_i - want_i).abs() < 1e-12,
            "x[{i}] = {} (expected {})",
            got_i,
            want_i
        );
    }
}
17264
// DenseSolve on the 5×5 tridiagonal [-1, 2, -1] matrix, verified through the
// residual A·x ≈ b rather than a closed-form solution.
#[test]
fn dense_solve_f64_5x5_laplacian() {
    let n = 5usize;
    let mut g = Graph::new("solve_5x5");
    let a = g.input("A", Shape::new(&[n, n], DType::F64));
    let b = g.input("b", Shape::new(&[n], DType::F64));
    let x = g.dense_solve(a, b, Shape::new(&[n], DType::F64));
    g.set_outputs(vec![x]);

    let a_data: Vec<f64> = (0..n * n)
        .map(|idx| {
            let (i, j) = (idx / n, idx % n);
            if i == j {
                2.0
            } else if i.abs_diff(j) == 1 {
                -1.0
            } else {
                0.0
            }
        })
        .collect();
    let b_data: Vec<f64> = (1..=n).map(|v| v as f64).collect();
    let (sched, mut arena) = prepare_f64(&g, &[(a, &a_data), (b, &b_data)]);
    execute_thunks(&sched, arena.raw_buf_mut());

    let got = read_arena_f64(&arena, x, n);
    for (i, row) in a_data.chunks_exact(n).enumerate() {
        let residual = row
            .iter()
            .zip(&got)
            .fold(0.0_f64, |acc, (&aij, &xj)| acc + aij * xj);
        assert!(
            (residual - b_data[i]).abs() < 1e-10,
            "row {i}: residual {} vs b {}",
            residual,
            b_data[i]
        );
    }
}
17311
// Reverse-mode gradients through DenseSolve for loss = Σ solve(A, b), checked
// three ways:
//   - against the closed-form adjoint db = A⁻ᵀ·1,
//   - against dA = -db·xᵀ,
//   - against central finite differences on b.
#[test]
fn hello_resistor_gradient_end_to_end() {
    use rlx_opt::autodiff::grad_with_loss;
    let n = 3usize;

    let mut g = Graph::new("hello_resistor");
    let a = g.param("A", Shape::new(&[n, n], DType::F64));
    let b = g.input("b", Shape::new(&[n], DType::F64));
    let x = g.dense_solve(a, b, Shape::new(&[n], DType::F64));
    let loss = g.reduce(x, ReduceOp::Sum, vec![0], false, Shape::new(&[1], DType::F64));
    g.set_outputs(vec![loss]);

    let bwd = grad_with_loss(&g, &[a, b]);
    assert_eq!(bwd.outputs.len(), 3, "expect [loss, dA, db]");

    // The backward graph has its own node ids, so the leaves to seed are
    // located by their Input/Param names.
    let find_by_name = |graph: &Graph, want: &str| -> NodeId {
        graph
            .nodes()
            .into_iter()
            .find(|node| match &node.op {
                rlx_ir::Op::Input { name } | rlx_ir::Op::Param { name } => {
                    name.as_str() == want
                }
                _ => false,
            })
            .map(|node| node.id)
            .unwrap_or_else(|| panic!("no node named {want:?} in bwd graph"))
    };
    let a_bwd = find_by_name(&bwd, "A");
    let b_bwd = find_by_name(&bwd, "b");
    let d_out_bwd = find_by_name(&bwd, "d_output");

    // Symmetric, strictly diagonally dominant (hence invertible) 3×3 matrix.
    let a_data = [2.0, 1.0, 0.0, 1.0, 3.0, 1.0, 0.0, 1.0, 2.0_f64];
    let b_data = [1.0, 2.0, 3.0_f64];
    // Upstream gradient fed into the scalar loss.
    let d_output = [1.0_f64];
    let (sched, mut arena) = prepare_f64(
        &bwd,
        &[(a_bwd, &a_data), (b_bwd, &b_data), (d_out_bwd, &d_output)],
    );
    execute_thunks(&sched, arena.raw_buf_mut());

    let loss_out = read_arena_f64(&arena, bwd.outputs[0], 1);
    let da_out = read_arena_f64(&arena, bwd.outputs[1], n * n);
    let db_out = read_arena_f64(&arena, bwd.outputs[2], n);

    // Host reference solve. `dgesv` overwrites the matrix copy and leaves the
    // solution in the right-hand side; the test requires a zero status.
    let solve = |mut mat: [f64; 9], mut rhs: [f64; 3]| -> [f64; 3] {
        let info = crate::blas::dgesv(&mut mat, &mut rhs, n, 1);
        assert_eq!(info, 0);
        rhs
    };

    let x_ref = solve(a_data, b_data);
    let loss_ref: f64 = x_ref.iter().sum();
    // The adjoint λ = A⁻ᵀ·1 is exactly dL/db.
    let mut a_t = [0.0_f64; 9];
    for i in 0..n {
        for j in 0..n {
            a_t[i * n + j] = a_data[j * n + i];
        }
    }
    let db_ref = solve(a_t, [1.0_f64; 3]);
    // dL/dA = -λ·xᵀ.
    let mut da_ref = [0.0_f64; 9];
    for (idx, slot) in da_ref.iter_mut().enumerate() {
        *slot = -db_ref[idx / n] * x_ref[idx % n];
    }

    assert!(
        (loss_out[0] - loss_ref).abs() < 1e-10,
        "loss: got {}, want {}",
        loss_out[0],
        loss_ref
    );
    for (i, (got, want)) in db_out.iter().zip(&db_ref).enumerate() {
        assert!(
            (got - want).abs() < 1e-10,
            "db[{i}]: got {}, want {}",
            got,
            want
        );
    }
    for (i, (got, want)) in da_out.iter().zip(&da_ref).enumerate() {
        assert!(
            (got - want).abs() < 1e-10,
            "dA[{i}]: got {}, want {}",
            got,
            want
        );
    }

    // Central finite differences on b, an independent check of db.
    let h = 1e-6_f64;
    for k in 0..n {
        let mut bp = b_data;
        bp[k] += h;
        let mut bm = b_data;
        bm[k] -= h;
        let lp = solve(a_data, bp).iter().sum::<f64>();
        let lm = solve(a_data, bm).iter().sum::<f64>();
        let fd = (lp - lm) / (2.0 * h);
        assert!(
            (db_out[k] - fd).abs() < 1e-7,
            "FD mismatch on db[{k}]: AD={} FD={}",
            db_out[k],
            fd
        );
    }
}
17475
// A scan whose body maps carry ↦ carry + 0.1·carry must, after `length`
// steps from an all-ones carry, give 1.1^length in every element.
#[test]
fn scan_geometric_growth_f64() {
    let n = 3usize;
    let length = 10u32;
    let vec_shape = || Shape::new(&[n], DType::F64);

    let mut body = Graph::new("scan_body");
    let carry = body.input("carry", vec_shape());
    // n little-endian copies of 0.1_f64.
    let scale_bytes: Vec<u8> = std::iter::repeat(0.1_f64.to_le_bytes())
        .take(n)
        .flatten()
        .collect();
    let scale = body.add_node(Op::Constant { data: scale_bytes }, vec![], vec_shape());
    let scaled = body.binary(BinaryOp::Mul, carry, scale, vec_shape());
    let next = body.binary(BinaryOp::Add, carry, scaled, vec_shape());
    body.set_outputs(vec![next]);

    let mut g = Graph::new("scan_outer");
    let init = g.input("init", vec_shape());
    let final_carry = g.scan(init, body, length);
    g.set_outputs(vec![final_carry]);

    let init_data = vec![1.0_f64; n];
    let (sched, mut arena) = prepare_f64(&g, &[(init, &init_data)]);
    execute_thunks(&sched, arena.raw_buf_mut());
    let got = read_arena_f64(&arena, final_carry, n);
    let want: f64 = 1.1_f64.powi(length as i32);
    for (i, v) in got.iter().enumerate() {
        assert!((v - want).abs() < 1e-12, "got[{i}] = {} want {}", v, want);
    }
}
17518
/// `scan_with_xs` feeds row `t` of the stacked `xs` input into step `t`.
/// With the additive body `carry + x_t`, the final carry must equal
/// `init + sum_t xs[t]`, which is recomputed directly in plain Rust.
#[test]
fn scan_with_xs_cumulative_sum() {
    let n = 3usize;
    let length = 4u32;

    // Body: carry' = carry + x_t.
    let mut body = Graph::new("cumsum_body");
    let carry = body.input("carry", Shape::new(&[n], DType::F64));
    let x_t = body.input("x_t", Shape::new(&[n], DType::F64));
    let next = body.binary(BinaryOp::Add, carry, x_t, Shape::new(&[n], DType::F64));
    body.set_outputs(vec![next]);

    // Outer graph: xs is [length, n], one row consumed per step.
    let mut g = Graph::new("cumsum_outer");
    let init = g.input("init", Shape::new(&[n], DType::F64));
    let xs = g.input("xs", Shape::new(&[length as usize, n], DType::F64));
    let final_carry = g.scan_with_xs(init, &[xs], body, length);
    g.set_outputs(vec![final_carry]);

    let init_data = vec![0.0_f64; n];
    // xs = 1, 2, ..., length * n, row-major as [length, n].
    let xs_data: Vec<f64> = (0..length as usize * n).map(|i| (i + 1) as f64).collect();
    let (sched, mut arena) = prepare_f64(&g, &[(init, &init_data), (xs, &xs_data)]);
    execute_thunks(&sched, arena.raw_buf_mut());
    let got = read_arena_f64(&arena, final_carry, n);

    // Reference: add each row of xs into the initial carry, in step order.
    let mut want = init_data.clone();
    for row in xs_data.chunks_exact(n) {
        for (acc, x) in want.iter_mut().zip(row) {
            *acc += x;
        }
    }
    for i in 0..n {
        assert!(
            (got[i] - want[i]).abs() < 1e-12,
            "got[{i}] = {} want {}",
            got[i],
            want[i]
        );
    }
}
17567
/// `scan_with_xs` with a linear-solve body. Each step computes
/// `next = solve(M, carry + xs[t])`, where `M = I + dt * tridiag(-1, 2, -1)`
/// is constant. Only `xs[0][0]` is non-zero, so this checks that the drive is
/// injected at the right step and lane and then propagated by the repeated
/// solves. The result is compared against a direct `dgesv` loop.
#[test]
fn scan_with_xs_be_with_drive() {
    let n = 3usize;
    let length = 4u32;
    let dt = 0.1_f64;

    // Row-major M: 1 + 2*dt on the diagonal, -dt on both off-diagonals.
    let mut m_data = vec![0.0_f64; n * n];
    for i in 0..n {
        m_data[i * n + i] = 1.0 + dt * 2.0;
        if i > 0 {
            m_data[i * n + (i - 1)] = -dt;
        }
        if i + 1 < n {
            m_data[i * n + (i + 1)] = -dt;
        }
    }
    let m_bytes: Vec<u8> = m_data.iter().flat_map(|x| x.to_le_bytes()).collect();

    // Body: next = solve(M, carry + drive).
    let mut body = Graph::new("be_drive_body");
    let carry = body.input("carry", Shape::new(&[n], DType::F64));
    let drive = body.input("drive", Shape::new(&[n], DType::F64));
    let m = body.add_node(
        Op::Constant { data: m_bytes },
        vec![],
        Shape::new(&[n, n], DType::F64),
    );
    let driven = body.binary(BinaryOp::Add, carry, drive, Shape::new(&[n], DType::F64));
    let next = body.dense_solve(m, driven, Shape::new(&[n], DType::F64));
    body.set_outputs(vec![next]);

    let mut g = Graph::new("be_drive_outer");
    let init = g.input("init", Shape::new(&[n], DType::F64));
    let xs = g.input("xs", Shape::new(&[length as usize, n], DType::F64));
    let final_carry = g.scan_with_xs(init, &[xs], body, length);
    g.set_outputs(vec![final_carry]);

    let init_data = vec![0.0_f64; n];
    // Unit impulse on lane 0 at step 0; all later steps are undriven.
    let mut xs_data = vec![0.0_f64; length as usize * n];
    xs_data[0] = 1.0;

    let (sched, mut arena) = prepare_f64(&g, &[(init, &init_data), (xs, &xs_data)]);
    execute_thunks(&sched, arena.raw_buf_mut());
    let got = read_arena_f64(&arena, final_carry, n);

    // Reference: the same recurrence with LAPACK. dgesv overwrites its
    // operands, so M is copied each step and `x` receives the solution.
    let mut x = init_data.clone();
    for t in 0..length as usize {
        for j in 0..n {
            x[j] += xs_data[t * n + j];
        }
        let mut a_copy = m_data.clone();
        // A non-zero info means the solve failed and `x` is meaningless;
        // fail here rather than compare against a garbage reference.
        let info = crate::blas::dgesv(&mut a_copy, &mut x, n, 1);
        assert_eq!(info, 0, "reference dgesv failed at step {t}");
    }
    for i in 0..n {
        assert!(
            (got[i] - x[i]).abs() < 1e-12,
            "got[{i}] = {} ref {}",
            got[i],
            x[i]
        );
    }
}
17635
/// Gradient of `loss = sum(batched_dense_solve(A, b))` with respect to `A`
/// and `b`, checked per batch element against the analytic adjoint:
/// `db` solves the transposed system against a ones vector, and
/// `dA[i][j] = -db[i] * x[j]`.
#[test]
fn batched_dense_solve_gradient_matches_per_batch_analytic() {
    use rlx_opt::autodiff::grad_with_loss;
    let n = 3usize;
    let batch = 4usize;

    let mut g = Graph::new("bds_grad");
    let a = g.param("A", Shape::new(&[batch, n, n], DType::F64));
    let b = g.input("b", Shape::new(&[batch, n], DType::F64));
    let x = g.batched_dense_solve(a, b, Shape::new(&[batch, n], DType::F64));
    let loss = g.reduce(
        x,
        ReduceOp::Sum,
        vec![0, 1],
        false,
        Shape::new(&[1], DType::F64),
    );
    g.set_outputs(vec![loss]);

    let bwd = grad_with_loss(&g, &[a, b]);

    // Locate an Input/Param node of the backward graph by name.
    let find = |graph: &Graph, want: &str| -> NodeId {
        for node in graph.nodes() {
            let name = match &node.op {
                Op::Input { name } | Op::Param { name } => Some(name.as_str()),
                _ => None,
            };
            if name == Some(want) {
                return node.id;
            }
        }
        panic!("no node named {want}");
    };
    let a_id = find(&bwd, "A");
    let b_id = find(&bwd, "b");
    let d_out_id = find(&bwd, "d_output");

    // Random systems with the diagonal boosted by 1 + n so the solves are well posed.
    let mut rng = rlx_ir::Philox4x32::new(0x57e1_u64);
    let mut a_data = vec![0.0_f64; batch * n * n];
    let mut b_data = vec![0.0_f64; batch * n];
    for bi in 0..batch {
        for i in 0..n {
            for j in 0..n {
                a_data[bi * n * n + i * n + j] = rng.next_f32() as f64 * 0.1;
            }
            a_data[bi * n * n + i * n + i] += 1.0 + n as f64;
        }
        for i in 0..n {
            b_data[bi * n + i] = rng.next_f32() as f64;
        }
    }
    let d_seed = [1.0_f64];

    let (sched, mut arena) = prepare_f64(
        &bwd,
        &[(a_id, &a_data), (b_id, &b_data), (d_out_id, &d_seed)],
    );
    execute_thunks(&sched, arena.raw_buf_mut());
    // Backward graph outputs are [loss, dA, db].
    let da_out = read_arena_f64(&arena, bwd.outputs[1], batch * n * n);
    let db_out = read_arena_f64(&arena, bwd.outputs[2], batch * n);

    for bi in 0..batch {
        let a_slice = &a_data[bi * n * n..(bi + 1) * n * n];

        // Forward reference: x = solve(A, b) for this batch element.
        // dgesv overwrites both operands, so it works on owned copies.
        let mut a_lu = a_slice.to_vec();
        let mut x_ref = b_data[bi * n..(bi + 1) * n].to_vec();
        let info = crate::blas::dgesv(&mut a_lu, &mut x_ref, n, 1);
        assert_eq!(info, 0, "batch {bi}: forward reference dgesv failed");

        // Adjoint reference: db = solve(A^T, ones).
        let mut at = vec![0.0_f64; n * n];
        for i in 0..n {
            for j in 0..n {
                at[i * n + j] = a_slice[j * n + i];
            }
        }
        let mut db_ref = vec![1.0_f64; n];
        let info = crate::blas::dgesv(&mut at, &mut db_ref, n, 1);
        assert_eq!(info, 0, "batch {bi}: adjoint reference dgesv failed");

        for i in 0..n {
            let got = db_out[bi * n + i];
            assert!(
                (got - db_ref[i]).abs() < 1e-10,
                "batch {bi}, db[{i}]: got {got} ref {}",
                db_ref[i]
            );
        }
        // dA is the outer product -db x^T.
        for i in 0..n {
            for j in 0..n {
                let got = da_out[bi * n * n + i * n + j];
                let want = -db_ref[i] * x_ref[j];
                assert!(
                    (got - want).abs() < 1e-10,
                    "batch {bi}, dA[{i},{j}]: got {got} ref {want}"
                );
            }
        }
    }
}
17741
/// `scan_checkpointed` must not change gradients. For an elementwise
/// `carry * 1.05` body, `dloss/dinit` through the checkpointed scan has to
/// match the plain `scan` gradient to within round-off.
#[test]
fn scan_checkpointed_grad_matches_plain_scan_grad() {
    use rlx_opt::autodiff::grad_with_loss;
    let n = 2usize;
    let length = 6u32;
    let vec_shape = || Shape::new(&[n], DType::F64);
    let scalar_shape = || Shape::new(&[1], DType::F64);

    // Each outer graph takes its body by value, so build a fresh one per use.
    let make_body = || {
        let mut body = Graph::new("ck_body");
        let carry = body.input("carry", vec_shape());
        let factor_bytes: Vec<u8> = std::iter::repeat(1.05_f64.to_le_bytes())
            .take(n)
            .flatten()
            .collect();
        let factor = body.add_node(Op::Constant { data: factor_bytes }, vec![], vec_shape());
        let scaled = body.binary(BinaryOp::Mul, carry, factor, vec_shape());
        body.set_outputs(vec![scaled]);
        body
    };

    // Reference: plain scan, loss = sum(final carry).
    let mut plain = Graph::new("ck_plain");
    let plain_init = plain.input("init", vec_shape());
    let plain_final = plain.scan(plain_init, make_body(), length);
    let plain_loss = plain.reduce(plain_final, ReduceOp::Sum, vec![0], false, scalar_shape());
    plain.set_outputs(vec![plain_loss]);
    let plain_bwd = grad_with_loss(&plain, &[plain_init]);

    // Same computation through `scan_checkpointed` (checkpoint argument = 2).
    let mut ckpt = Graph::new("ck_ckpt");
    let ckpt_init = ckpt.input("init", vec_shape());
    let ckpt_final = ckpt.scan_checkpointed(ckpt_init, make_body(), length, 2);
    let ckpt_loss = ckpt.reduce(ckpt_final, ReduceOp::Sum, vec![0], false, scalar_shape());
    ckpt.set_outputs(vec![ckpt_loss]);
    let ckpt_bwd = grad_with_loss(&ckpt, &[ckpt_init]);

    // Locate an Input/Param node by name.
    let find = |graph: &Graph, want: &str| -> NodeId {
        graph
            .nodes()
            .iter()
            .find(|node| {
                matches!(&node.op, Op::Input { name } | Op::Param { name } if name == want)
            })
            .map(|node| node.id)
            .unwrap_or_else(|| panic!("no {want}"))
    };

    let init_data = vec![0.5_f64, -0.5];
    let d_seed = [1.0_f64];

    let (plain_sched, mut plain_arena) = prepare_f64(
        &plain_bwd,
        &[
            (find(&plain_bwd, "init"), &init_data),
            (find(&plain_bwd, "d_output"), &d_seed),
        ],
    );
    execute_thunks(&plain_sched, plain_arena.raw_buf_mut());
    // outputs[1] of a grad_with_loss graph is d(loss)/d(init).
    let dinit_plain = read_arena_f64(&plain_arena, plain_bwd.outputs[1], n);

    let (ckpt_sched, mut ckpt_arena) = prepare_f64(
        &ckpt_bwd,
        &[
            (find(&ckpt_bwd, "init"), &init_data),
            (find(&ckpt_bwd, "d_output"), &d_seed),
        ],
    );
    execute_thunks(&ckpt_sched, ckpt_arena.raw_buf_mut());
    let dinit_ckpt = read_arena_f64(&ckpt_arena, ckpt_bwd.outputs[1], n);

    for (i, (p, c)) in dinit_plain[..n].iter().zip(&dinit_ckpt[..n]).enumerate() {
        assert!(
            (p - c).abs() < 1e-12,
            "dinit[{i}]: plain={} checkpointed={}",
            p,
            c
        );
    }
}
17839
/// Running `ScanBackward` with only `k < length` saved checkpoints must give
/// the same `dinit` as running it over the full saved trajectory. The body
/// adds 1 to the carry, and an explicit per-step upstream gradient drives
/// the backward pass in both variants.
#[test]
fn recursive_checkpointing_matches_full_trajectory() {
    let n = 2usize;
    let length = 4u32;

    // Forward body: carry -> carry + [1.0; n].
    let build_body = || -> Graph {
        let mut body = Graph::new("rc_body");
        let carry = body.input("carry", Shape::new(&[n], DType::F64));
        let ones_bytes: Vec<u8> = std::iter::repeat(1.0_f64.to_le_bytes())
            .take(n)
            .flatten()
            .collect();
        let ones = body.add_node(
            Op::Constant { data: ones_bytes },
            vec![],
            Shape::new(&[n], DType::F64),
        );
        let next = body.binary(BinaryOp::Add, carry, ones, Shape::new(&[n], DType::F64));
        body.set_outputs(vec![next]);
        body
    };

    // Reverse-mode body: gradient of the forward body w.r.t. its first Input (the carry).
    let body_vjp_for = || -> Graph {
        use rlx_opt::autodiff::grad;
        let body = build_body();
        let carry_id = body
            .nodes()
            .iter()
            .find(|node| matches!(node.op, Op::Input { .. }))
            .map(|node| node.id)
            .unwrap();
        grad(&body, &[carry_id])
    };

    // Variant 1: save every forward state, then run ScanBackward over them.
    let mut full = Graph::new("rc_outer_full");
    let full_init = full.input("init", Shape::new(&[n], DType::F64));
    let full_traj = full.scan_trajectory(full_init, build_body(), length);
    let full_upstream = full.input("upstream", Shape::new(&[length as usize, n], DType::F64));
    let full_dinit = full.scan_backward(
        full_init,
        full_traj,
        full_upstream,
        &[],
        body_vjp_for(),
        length,
        true,
        Shape::new(&[n], DType::F64),
    );
    full.set_outputs(vec![full_dinit]);

    // Variant 2: keep only `k` checkpoints. The backward op is also handed
    // the forward body (presumably to recompute states between checkpoints).
    let k = 2u32;
    let mut rec = Graph::new("rc_outer_rec");
    let rec_init = rec.input("init", Shape::new(&[n], DType::F64));
    let rec_traj = rec.add_node(
        Op::Scan {
            body: Box::new(build_body()),
            length,
            save_trajectory: true,
            num_bcast: 0,
            num_xs: 0,
            num_checkpoints: k,
        },
        vec![rec_init],
        Shape::new(&[k as usize, n], DType::F64),
    );
    let rec_upstream = rec.input("upstream", Shape::new(&[length as usize, n], DType::F64));
    let rec_dinit = rec.add_node(
        Op::ScanBackward {
            body_vjp: Box::new(body_vjp_for()),
            length,
            save_trajectory: true,
            num_xs: 0,
            num_checkpoints: k,
            forward_body: Some(Box::new(build_body())),
        },
        vec![rec_init, rec_traj, rec_upstream],
        Shape::new(&[n], DType::F64),
    );
    rec.set_outputs(vec![rec_dinit]);

    let init_data = vec![0.5_f64, -0.5];
    let upstream_data: Vec<f64> = (0..length as usize * n).map(|i| (i as f64) * 0.1).collect();

    // Locate an Input node by name.
    let find = |graph: &Graph, want: &str| -> NodeId {
        graph
            .nodes()
            .iter()
            .find(|node| matches!(&node.op, Op::Input { name } if name == want))
            .map(|node| node.id)
            .unwrap_or_else(|| panic!("no input {want}"))
    };

    let (full_sched, mut full_arena) = prepare_f64(
        &full,
        &[
            (find(&full, "init"), &init_data),
            (find(&full, "upstream"), &upstream_data),
        ],
    );
    execute_thunks(&full_sched, full_arena.raw_buf_mut());
    let dinit_full = read_arena_f64(&full_arena, full.outputs[0], n);

    let (rec_sched, mut rec_arena) = prepare_f64(
        &rec,
        &[
            (find(&rec, "init"), &init_data),
            (find(&rec, "upstream"), &upstream_data),
        ],
    );
    execute_thunks(&rec_sched, rec_arena.raw_buf_mut());
    let dinit_rec = read_arena_f64(&rec_arena, rec.outputs[0], n);

    for (i, (f, r)) in dinit_full[..n].iter().zip(&dinit_rec[..n]).enumerate() {
        assert!(
            (f - r).abs() < 1e-12,
            "i={i}: full={} rec={}",
            f,
            r
        );
    }
}
17976
/// `vmap` applied to a backward graph that contains a `Scan`. Batching the
/// `init` input of `grad_with_loss(sum(scan(init)))` must reproduce, row by
/// row, the `dinit` of the unbatched gradient graph. The body only adds a
/// constant, so every entry is also expected to be exactly 1.
#[test]
fn vmap_of_grad_scan_matches_per_row_runs() {
    use rlx_opt::autodiff::grad_with_loss;
    use rlx_opt::vmap::vmap;

    /// Scan body `carry -> carry + [1.0; n]`, built under the given graph name.
    fn add_ones_body(label: &'static str, n: usize) -> Graph {
        let mut body = Graph::new(label);
        let carry = body.input("carry", Shape::new(&[n], DType::F64));
        let ones_bytes: Vec<u8> = (0..n).flat_map(|_| 1.0_f64.to_le_bytes()).collect();
        let ones = body.add_node(
            Op::Constant { data: ones_bytes },
            vec![],
            Shape::new(&[n], DType::F64),
        );
        let next = body.binary(BinaryOp::Add, carry, ones, Shape::new(&[n], DType::F64));
        body.set_outputs(vec![next]);
        body
    }

    let n = 2usize;
    let length = 3u32;
    let batch = 3usize;

    // Unbatched forward graph: loss = sum(scan(init)).
    let body = add_ones_body("scan_grad_body", n);
    let mut fwd = Graph::new("scan_grad_outer");
    let init = fwd.input("init", Shape::new(&[n], DType::F64));
    let final_x = fwd.scan(init, body, length);
    let loss = fwd.reduce(final_x, ReduceOp::Sum, vec![0], false, Shape::new(&[1], DType::F64));
    fwd.set_outputs(vec![loss]);

    let bwd = grad_with_loss(&fwd, &[init]);
    let bg = vmap(&bwd, &["init"], batch);

    // Locate an Input/Param node by name.
    let find = |graph: &Graph, want: &str| -> NodeId {
        graph
            .nodes()
            .iter()
            .find(|node| {
                matches!(&node.op, Op::Input { name } | Op::Param { name } if name == want)
            })
            .map(|node| node.id)
            .unwrap_or_else(|| panic!("no node named {want}"))
    };
    let init_b = find(&bg, "init");
    let d_out_b = find(&bg, "d_output");

    let init_data: Vec<f64> = (0..batch * n).map(|i| (i as f64) * 0.5).collect();
    let d_seed = [1.0_f64];

    let (sched, mut arena) = prepare_f64(&bg, &[(init_b, &init_data), (d_out_b, &d_seed)]);
    execute_thunks(&sched, arena.raw_buf_mut());
    let dinit_b = read_arena_f64(&arena, bg.outputs[1], batch * n);

    for (i, &v) in dinit_b[..batch * n].iter().enumerate() {
        assert!(
            (v - 1.0).abs() < 1e-12,
            "dinit[{i}] = {} (expected 1.0)",
            v
        );
    }

    // Cross-check each row against a freshly built, unbatched gradient graph.
    for (bi, row) in init_data.chunks_exact(n).enumerate() {
        let mut g2 = Graph::new("per_row_grad");
        let init2 = g2.input("init", Shape::new(&[n], DType::F64));
        let body2 = add_ones_body("per_row_body", n);
        let final2 = g2.scan(init2, body2, length);
        let loss2 = g2.reduce(final2, ReduceOp::Sum, vec![0], false, Shape::new(&[1], DType::F64));
        g2.set_outputs(vec![loss2]);
        let bwd2 = grad_with_loss(&g2, &[init2]);

        let init2_id = find(&bwd2, "init");
        let d_out2_id = find(&bwd2, "d_output");
        let (s2, mut a2) = prepare_f64(&bwd2, &[(init2_id, row), (d_out2_id, &d_seed)]);
        execute_thunks(&s2, a2.raw_buf_mut());
        let row_dinit = read_arena_f64(&a2, bwd2.outputs[1], n);

        for (j, want) in row_dinit[..n].iter().enumerate() {
            let got = dinit_b[bi * n + j];
            assert!(
                (got - want).abs() < 1e-12,
                "row {bi}, j {j}: vmap'd={got} per-row={want}"
            );
        }
    }
}
18089
/// `vmap` over both `init` and `xs` of a cumulative-sum `scan_with_xs`.
/// Batched inputs are laid out batch-index-outermost, and each batch row of
/// the output must equal `init[b] + sum_t xs[b, t]`, computed directly.
#[test]
fn vmap_scan_cumulative_sum_matches_scalar_runs() {
    use rlx_opt::vmap::vmap;
    let n = 2usize;
    let length = 4u32;
    let batch = 3usize;
    let steps = length as usize;

    // Body: carry + x_t.
    let mut body = Graph::new("scan_body_cumsum");
    let carry = body.input("carry", Shape::new(&[n], DType::F64));
    let x_t = body.input("x_t", Shape::new(&[n], DType::F64));
    let next = body.binary(BinaryOp::Add, carry, x_t, Shape::new(&[n], DType::F64));
    body.set_outputs(vec![next]);

    let mut g = Graph::new("scan_outer_cumsum");
    let init = g.input("init", Shape::new(&[n], DType::F64));
    let xs = g.input("xs", Shape::new(&[steps, n], DType::F64));
    let final_carry = g.scan_with_xs(init, &[xs], body, length);
    g.set_outputs(vec![final_carry]);

    let bg = vmap(&g, &["init", "xs"], batch);

    // init = 1, 2, ..., batch * n; xs = 0.0, 0.1, 0.2, ...
    let init_data: Vec<f64> = (1..=batch * n).map(|v| v as f64).collect();
    let xs_data: Vec<f64> = (0..batch * steps * n).map(|i| 0.1 * (i as f64)).collect();

    // Locate an Input node by name.
    let find = |graph: &Graph, want: &str| -> NodeId {
        graph
            .nodes()
            .iter()
            .find(|node| matches!(&node.op, Op::Input { name } if name == want))
            .map(|node| node.id)
            .unwrap_or_else(|| panic!("no input {want}"))
    };
    let init_b = find(&bg, "init");
    let xs_b = find(&bg, "xs");
    let (sched, mut arena) = prepare_f64(&bg, &[(init_b, &init_data), (xs_b, &xs_data)]);
    execute_thunks(&sched, arena.raw_buf_mut());
    let batched_out = read_arena_f64(&arena, bg.outputs[0], batch * n);

    // Reference: per batch row, fold each [n] step of its xs block into init.
    for (bi, (init_row, xs_block)) in init_data
        .chunks_exact(n)
        .zip(xs_data.chunks_exact(steps * n))
        .enumerate()
    {
        let mut x = init_row.to_vec();
        for step in xs_block.chunks_exact(n) {
            for (acc, v) in x.iter_mut().zip(step) {
                *acc += v;
            }
        }

        for (i, want) in x.iter().enumerate() {
            let got = batched_out[bi * n + i];
            assert!(
                (got - want).abs() < 1e-12,
                "row {bi}, i {i}: got {got} ref {}",
                want
            );
        }
    }
}
18162
18163 #[test]
18168 fn vmap_dense_solve_matches_scalar_runs() {
18169 use rlx_opt::vmap::vmap;
18170 let n = 3usize;
18171 let batch = 4usize;
18172
18173 let mut g = Graph::new("solve_forward");
18174 let a = g.input("A", Shape::new(&[n, n], DType::F64));
18175 let b = g.input("b", Shape::new(&[n], DType::F64));
18176 let x = g.dense_solve(a, b, Shape::new(&[n], DType::F64));
18177 g.set_outputs(vec![x]);
18178
18179 let bg = vmap(&g, &["A", "b"], batch);
18181
18182 let mut rng = rlx_ir::Philox4x32::new(0xb47c_u64);
18184 let mut a_data = vec![0.0_f64; batch * n * n];
18185 let mut b_data = vec![0.0_f64; batch * n];
18186 for bi in 0..batch {
18187 for i in 0..n {
18189 for j in 0..n {
18190 a_data[bi * n * n + i * n + j] = rng.next_f32() as f64 * 0.1;
18191 }
18192 a_data[bi * n * n + i * n + i] += 1.0 + n as f64;
18193 }
18194 for i in 0..n {
18195 b_data[bi * n + i] = rng.next_f32() as f64;
18196 }
18197 }
18198
18199 let find = |graph: &Graph, want: &str| -> NodeId {
18200 for node in graph.nodes() {
18201 if let Op::Input { name } = &node.op
18202 && name == want
18203 {
18204 return node.id;
18205 }
18206 }
18207 panic!("no input named {want}");
18208 };
18209 let ba = find(&bg, "A");
18210 let bb = find(&bg, "b");
18211 let (sched, mut arena) = prepare_f64(&bg, &[(ba, &a_data), (bb, &b_data)]);
18212 execute_thunks(&sched, arena.raw_buf_mut());
18213 let batched_x = read_arena_f64(&arena, bg.outputs[0], batch * n);
18214
18215 for bi in 0..batch {
18217 let mut a_slice: Vec<f64> = a_data[bi * n * n..(bi + 1) * n * n].to_vec();
18218 let mut b_slice: Vec<f64> = b_data[bi * n..(bi + 1) * n].to_vec();
18219 crate::blas::dgesv(&mut a_slice, &mut b_slice, n, 1);
18220 for i in 0..n {
18221 let got = batched_x[bi * n + i];
18222 let want = b_slice[i];
18223 assert!(
18224 (got - want).abs() < 1e-12,
18225 "row {bi}, i {i}: got {got} want {want}"
18226 );
18227 }
18228 }
18229 }
18230
/// End-to-end `vmap` over only `x` of `loss = sum(x @ w + b)`, where `w` and
/// `b` stay shared. Each batched loss must match a standalone run of the
/// same unbatched graph on that row of `x`.
#[test]
fn vmap_matmul_add_reduce_matches_scalar_runs() {
    use rlx_opt::vmap::vmap;

    /// Builds `loss = sum(reshape(reshape(x, [1, n]) @ w, [n]) + b)` under the
    /// given graph name and returns `(graph, x, w, b, loss)`.
    fn build_forward(label: &'static str, n: usize) -> (Graph, NodeId, NodeId, NodeId, NodeId) {
        let mut g = Graph::new(label);
        let x = g.input("x", Shape::new(&[n], DType::F64));
        let w = g.input("w", Shape::new(&[n, n], DType::F64));
        let b = g.input("b", Shape::new(&[n], DType::F64));
        let x_row = g.add_node(
            Op::Reshape {
                new_shape: vec![1, n as i64],
            },
            vec![x],
            Shape::new(&[1, n], DType::F64),
        );
        let product = g.matmul(x_row, w, Shape::new(&[1, n], DType::F64));
        let product_flat = g.add_node(
            Op::Reshape {
                new_shape: vec![n as i64],
            },
            vec![product],
            Shape::new(&[n], DType::F64),
        );
        let biased = g.binary(BinaryOp::Add, product_flat, b, Shape::new(&[n], DType::F64));
        let loss = g.reduce(biased, ReduceOp::Sum, vec![0], false, Shape::new(&[1], DType::F64));
        g.set_outputs(vec![loss]);
        (g, x, w, b, loss)
    }

    let n = 3usize;
    let batch = 4usize;

    let (fwd, _, _, _, _) = build_forward("vmap_e2e_forward", n);
    let bg = vmap(&fwd, &["x"], batch);

    // Draw order matters for reproducibility: w, then b, then the batched x.
    let mut rng = rlx_ir::Philox4x32::new(0xc1c0_u64);
    let mut draw = || rng.next_f32() as f64;
    let w_data: Vec<f64> = (0..n * n).map(|_| draw()).collect();
    let b_data: Vec<f64> = (0..n).map(|_| draw()).collect();
    let x_data_batched: Vec<f64> = (0..batch * n).map(|_| draw()).collect();

    // Locate an Input node by name.
    let find = |graph: &Graph, want: &str| -> NodeId {
        graph
            .nodes()
            .iter()
            .find(|node| matches!(&node.op, Op::Input { name } if name == want))
            .map(|node| node.id)
            .unwrap_or_else(|| panic!("no input named {want}"))
    };
    let bx = find(&bg, "x");
    let bw = find(&bg, "w");
    let bb = find(&bg, "b");
    let (sched, mut arena) =
        prepare_f64(&bg, &[(bx, &x_data_batched), (bw, &w_data), (bb, &b_data)]);
    execute_thunks(&sched, arena.raw_buf_mut());
    let batched_out = read_arena_f64(&arena, bg.outputs[0], batch);

    for (bi, x_row) in x_data_batched.chunks_exact(n).enumerate() {
        let (g2, x2, w2, b2, l2) = build_forward("scalar_run", n);
        let (s2, mut a2) = prepare_f64(&g2, &[(x2, x_row), (w2, &w_data), (b2, &b_data)]);
        execute_thunks(&s2, a2.raw_buf_mut());
        let scalar_out = read_arena_f64(&a2, l2, 1);
        assert!(
            (batched_out[bi] - scalar_out[0]).abs() < 1e-12,
            "row {bi}: batched={} scalar={}",
            batched_out[bi],
            scalar_out[0]
        );
    }
}
18352
/// Gradients of `loss = sum(final carry)` through a `scan_with_xs` whose body
/// is `solve(M, carry + xs[t])`. Both `dinit` and `dxs` are checked against
/// central finite differences of a plain-Rust reimplementation of the
/// forward pass.
#[test]
fn scan_with_xs_dxs_matches_fd() {
    use rlx_opt::autodiff::grad_with_loss;
    let n = 3usize;
    let length = 3u32;
    let dt = 0.1_f64;

    // Row-major M = I + dt * tridiag(-1, 2, -1).
    let mut m_data = vec![0.0_f64; n * n];
    for i in 0..n {
        m_data[i * n + i] = 1.0 + dt * 2.0;
        if i > 0 {
            m_data[i * n + (i - 1)] = -dt;
        }
        if i + 1 < n {
            m_data[i * n + (i + 1)] = -dt;
        }
    }
    let m_bytes: Vec<u8> = m_data.iter().flat_map(|x| x.to_le_bytes()).collect();

    // Body: next = solve(M, carry + drive).
    let mut body = Graph::new("be_dxs_body");
    let carry = body.input("carry", Shape::new(&[n], DType::F64));
    let drive = body.input("drive", Shape::new(&[n], DType::F64));
    let m = body.add_node(
        Op::Constant { data: m_bytes },
        vec![],
        Shape::new(&[n, n], DType::F64),
    );
    let driven = body.binary(BinaryOp::Add, carry, drive, Shape::new(&[n], DType::F64));
    let next = body.dense_solve(m, driven, Shape::new(&[n], DType::F64));
    body.set_outputs(vec![next]);

    let mut g = Graph::new("be_dxs_outer");
    let init = g.input("init", Shape::new(&[n], DType::F64));
    let xs = g.input("xs", Shape::new(&[length as usize, n], DType::F64));
    let final_carry = g.scan_with_xs(init, &[xs], body, length);
    let loss = g.reduce(
        final_carry,
        ReduceOp::Sum,
        vec![0],
        false,
        Shape::new(&[1], DType::F64),
    );
    g.set_outputs(vec![loss]);

    let bwd = grad_with_loss(&g, &[init, xs]);
    assert_eq!(bwd.outputs.len(), 3, "[loss, dinit, dxs]");

    let find_by_name = |graph: &Graph, want: &str| -> NodeId {
        for node in graph.nodes() {
            let name = match &node.op {
                Op::Input { name } | Op::Param { name } => Some(name.as_str()),
                _ => None,
            };
            if name == Some(want) {
                return node.id;
            }
        }
        panic!("no node named {want:?}");
    };
    let init_bwd = find_by_name(&bwd, "init");
    let xs_bwd = find_by_name(&bwd, "xs");
    let d_out_bwd = find_by_name(&bwd, "d_output");

    let init_data = vec![0.5_f64, 0.0, -0.5];
    let xs_data: Vec<f64> = (0..length as usize * n)
        .map(|i| 0.1_f64 * ((i as f64) - 4.0))
        .collect();
    let d_seed = [1.0_f64];

    let (sched, mut arena) = prepare_f64(
        &bwd,
        &[
            (init_bwd, &init_data),
            (xs_bwd, &xs_data),
            (d_out_bwd, &d_seed),
        ],
    );
    execute_thunks(&sched, arena.raw_buf_mut());
    let dinit = read_arena_f64(&arena, bwd.outputs[1], n);
    let dxs = read_arena_f64(&arena, bwd.outputs[2], length as usize * n);

    let h = 1e-6;
    // Plain-Rust forward pass used as the finite-difference oracle.
    let loss_at = |x0: &[f64], xs_in: &[f64]| -> f64 {
        let mut acc = x0.to_vec();
        for t in 0..length as usize {
            for j in 0..n {
                acc[j] += xs_in[t * n + j];
            }
            // dgesv overwrites A, so solve against a fresh copy each step.
            let mut a_copy = m_data.clone();
            let info = crate::blas::dgesv(&mut a_copy, &mut acc, n, 1);
            assert_eq!(info, 0, "reference dgesv failed at step {t}");
        }
        acc.iter().sum()
    };

    for i in 0..n {
        let mut ip = init_data.clone();
        ip[i] += h;
        let mut im = init_data.clone();
        im[i] -= h;
        let fd = (loss_at(&ip, &xs_data) - loss_at(&im, &xs_data)) / (2.0 * h);
        assert!(
            (dinit[i] - fd).abs() < 1e-7,
            "FD dinit[{i}]: AD={} FD={}",
            dinit[i],
            fd
        );
    }

    for t in 0..length as usize {
        for j in 0..n {
            let idx = t * n + j;
            let mut xp = xs_data.clone();
            xp[idx] += h;
            let mut xm = xs_data.clone();
            xm[idx] -= h;
            let fd = (loss_at(&init_data, &xp) - loss_at(&init_data, &xm)) / (2.0 * h);
            assert!(
                (dxs[idx] - fd).abs() < 1e-7,
                "FD dxs[t={t},j={j}]: AD={} FD={}",
                dxs[idx],
                fd
            );
        }
    }
}
18487
/// `dinit` of `loss = sum(final carry)` through a `scan_with_xs` whose body
/// is `solve(M, carry + xs[t])`, differentiated with respect to `init` only.
/// It is checked against central finite differences of a plain-Rust forward
/// pass that uses the same fixed `xs`.
#[test]
fn scan_with_xs_gradient_dinit_matches_fd() {
    use rlx_opt::autodiff::grad_with_loss;
    let n = 3usize;
    let length = 3u32;
    let dt = 0.1_f64;

    // Row-major M = I + dt * tridiag(-1, 2, -1).
    let mut m_data = vec![0.0_f64; n * n];
    for i in 0..n {
        m_data[i * n + i] = 1.0 + dt * 2.0;
        if i > 0 {
            m_data[i * n + (i - 1)] = -dt;
        }
        if i + 1 < n {
            m_data[i * n + (i + 1)] = -dt;
        }
    }
    let m_bytes: Vec<u8> = m_data.iter().flat_map(|x| x.to_le_bytes()).collect();

    // Body: next = solve(M, carry + drive).
    let mut body = Graph::new("be_xs_grad_body");
    let carry = body.input("carry", Shape::new(&[n], DType::F64));
    let drive = body.input("drive", Shape::new(&[n], DType::F64));
    let m = body.add_node(
        Op::Constant { data: m_bytes },
        vec![],
        Shape::new(&[n, n], DType::F64),
    );
    let driven = body.binary(BinaryOp::Add, carry, drive, Shape::new(&[n], DType::F64));
    let next = body.dense_solve(m, driven, Shape::new(&[n], DType::F64));
    body.set_outputs(vec![next]);

    let mut g = Graph::new("be_xs_grad_outer");
    let init = g.input("init", Shape::new(&[n], DType::F64));
    let xs = g.input("xs", Shape::new(&[length as usize, n], DType::F64));
    let final_carry = g.scan_with_xs(init, &[xs], body, length);
    let loss = g.reduce(
        final_carry,
        ReduceOp::Sum,
        vec![0],
        false,
        Shape::new(&[1], DType::F64),
    );
    g.set_outputs(vec![loss]);

    let bwd = grad_with_loss(&g, &[init]);

    let find_by_name = |graph: &Graph, want: &str| -> NodeId {
        for node in graph.nodes() {
            let name = match &node.op {
                Op::Input { name } | Op::Param { name } => Some(name.as_str()),
                _ => None,
            };
            if name == Some(want) {
                return node.id;
            }
        }
        panic!("no node named {want:?}");
    };
    let init_bwd = find_by_name(&bwd, "init");
    let xs_bwd = find_by_name(&bwd, "xs");
    let d_out_bwd = find_by_name(&bwd, "d_output");

    let init_data = vec![0.5_f64, 0.0, -0.5];
    let xs_data: Vec<f64> = (0..length as usize * n)
        .map(|i| 0.1_f64 * ((i as f64) - 4.0))
        .collect();
    let d_seed = [1.0_f64];

    let (sched, mut arena) = prepare_f64(
        &bwd,
        &[
            (init_bwd, &init_data),
            (xs_bwd, &xs_data),
            (d_out_bwd, &d_seed),
        ],
    );
    execute_thunks(&sched, arena.raw_buf_mut());
    let dinit = read_arena_f64(&arena, bwd.outputs[1], n);

    let h = 1e-6;
    // Plain-Rust forward pass (xs fixed) used as the finite-difference oracle.
    let loss_at = |x0: &[f64]| -> f64 {
        let mut acc = x0.to_vec();
        for t in 0..length as usize {
            for j in 0..n {
                acc[j] += xs_data[t * n + j];
            }
            // dgesv overwrites A, so solve against a fresh copy each step.
            let mut a_copy = m_data.clone();
            let info = crate::blas::dgesv(&mut a_copy, &mut acc, n, 1);
            assert_eq!(info, 0, "reference dgesv failed at step {t}");
        }
        acc.iter().sum()
    };
    for i in 0..n {
        let mut ip = init_data.clone();
        ip[i] += h;
        let mut im = init_data.clone();
        im[i] -= h;
        let fd = (loss_at(&ip) - loss_at(&im)) / (2.0 * h);
        assert!(
            (dinit[i] - fd).abs() < 1e-7,
            "FD dinit[{i}]: AD={} FD={}",
            dinit[i],
            fd
        );
    }
}
18601
18602 #[test]
18610 fn scan_gradient_geometric_matches_closed_form() {
18611 use rlx_opt::autodiff::grad_with_loss;
18612 let n = 3usize;
18613 let length = 5u32;
18614
18615 let mut body = Graph::new("scan_grad_body");
18616 let x = body.input("carry", Shape::new(&[n], DType::F64));
18617 let scale_bytes: Vec<u8> = (0..n).flat_map(|_| 1.1_f64.to_le_bytes()).collect();
18618 let scale = body.add_node(
18619 Op::Constant { data: scale_bytes },
18620 vec![],
18621 Shape::new(&[n], DType::F64),
18622 );
18623 let next = body.binary(BinaryOp::Mul, x, scale, Shape::new(&[n], DType::F64));
18624 body.set_outputs(vec![next]);
18625
18626 let mut g = Graph::new("scan_grad_outer");
18627 let init = g.input("init", Shape::new(&[n], DType::F64));
18628 let final_x = g.scan(init, body, length);
18629 let loss = g.reduce(
18630 final_x,
18631 ReduceOp::Sum,
18632 vec![0],
18633 false,
18634 Shape::new(&[1], DType::F64),
18635 );
18636 g.set_outputs(vec![loss]);
18637
18638 let bwd = grad_with_loss(&g, &[init]);
18639 assert_eq!(bwd.outputs.len(), 2);
18640
18641 let find_by_name = |graph: &Graph, want: &str| -> NodeId {
18642 for node in graph.nodes() {
18643 let name = match &node.op {
18644 Op::Input { name } | Op::Param { name } => Some(name.as_str()),
18645 _ => None,
18646 };
18647 if name == Some(want) {
18648 return node.id;
18649 }
18650 }
18651 panic!("no node named {want:?}");
18652 };
18653 let init_bwd = find_by_name(&bwd, "init");
18654 let d_out_bwd = find_by_name(&bwd, "d_output");
18655
18656 let init_data = vec![1.0_f64; n];
18657 let d_seed = [1.0_f64];
18658 let (sched, mut arena) = prepare_f64(&bwd, &[(init_bwd, &init_data), (d_out_bwd, &d_seed)]);
18659 execute_thunks(&sched, arena.raw_buf_mut());
18660 let dinit = read_arena_f64(&arena, bwd.outputs[1], n);
18661
18662 let want = 1.1_f64.powi(length as i32);
18663 for i in 0..n {
18664 assert!(
18665 (dinit[i] - want).abs() < 1e-12,
18666 "dinit[{i}] = {} want {}",
18667 dinit[i],
18668 want
18669 );
18670 }
18671
18672 let h = 1e-6;
18674 let loss_at = |x: &[f64]| -> f64 {
18675 let mut acc = x.to_vec();
18676 for _ in 0..length {
18677 for v in acc.iter_mut() {
18678 *v *= 1.1;
18679 }
18680 }
18681 acc.iter().sum()
18682 };
18683 let mut ip = init_data.clone();
18684 ip[0] += h;
18685 let mut im = init_data.clone();
18686 im[0] -= h;
18687 let fd = (loss_at(&ip) - loss_at(&im)) / (2.0 * h);
18688 assert!(
18689 (dinit[0] - fd).abs() < 1e-7,
18690 "FD dinit[0]: AD={} FD={}",
18691 dinit[0],
18692 fd
18693 );
18694 }
18695
18696 #[test]
18699 fn scan_gradient_backward_euler_matches_fd() {
18700 use rlx_opt::autodiff::grad_with_loss;
18701 let n = 4usize;
18702 let length = 3u32;
18703 let dt = 0.05_f64;
18704
18705 let mut m_data = vec![0.0_f64; n * n];
18706 for i in 0..n {
18707 m_data[i * n + i] = 1.0 + dt * 2.0;
18708 if i > 0 {
18709 m_data[i * n + (i - 1)] = -dt;
18710 }
18711 if i + 1 < n {
18712 m_data[i * n + (i + 1)] = -dt;
18713 }
18714 }
18715 let m_bytes: Vec<u8> = m_data.iter().flat_map(|x| x.to_le_bytes()).collect();
18716
18717 let mut body = Graph::new("be_grad_body");
18718 let x = body.input("x", Shape::new(&[n], DType::F64));
18719 let m = body.add_node(
18720 Op::Constant { data: m_bytes },
18721 vec![],
18722 Shape::new(&[n, n], DType::F64),
18723 );
18724 let next = body.dense_solve(m, x, Shape::new(&[n], DType::F64));
18725 body.set_outputs(vec![next]);
18726
18727 let mut g = Graph::new("be_grad_outer");
18728 let init = g.input("x0", Shape::new(&[n], DType::F64));
18729 let final_x = g.scan(init, body, length);
18730 let loss = g.reduce(
18731 final_x,
18732 ReduceOp::Sum,
18733 vec![0],
18734 false,
18735 Shape::new(&[1], DType::F64),
18736 );
18737 g.set_outputs(vec![loss]);
18738
18739 let bwd = grad_with_loss(&g, &[init]);
18740
18741 let find_by_name = |graph: &Graph, want: &str| -> NodeId {
18742 for node in graph.nodes() {
18743 let name = match &node.op {
18744 Op::Input { name } | Op::Param { name } => Some(name.as_str()),
18745 _ => None,
18746 };
18747 if name == Some(want) {
18748 return node.id;
18749 }
18750 }
18751 panic!("no node named {want:?}");
18752 };
18753 let init_bwd = find_by_name(&bwd, "x0");
18754 let d_out_bwd = find_by_name(&bwd, "d_output");
18755
18756 let init_data: [f64; 4] = [0.0, 1.0, 0.0, 0.0];
18757 let d_seed = [1.0_f64];
18758 let (sched, mut arena) = prepare_f64(&bwd, &[(init_bwd, &init_data), (d_out_bwd, &d_seed)]);
18759 execute_thunks(&sched, arena.raw_buf_mut());
18760 let dinit = read_arena_f64(&arena, bwd.outputs[1], n);
18761
18762 let h = 1e-6;
18763 let loss_at = |x0: &[f64]| -> f64 {
18764 let mut acc = x0.to_vec();
18765 for _ in 0..length {
18766 let mut a_copy = m_data.clone();
18767 crate::blas::dgesv(&mut a_copy, &mut acc, n, 1);
18768 }
18769 acc.iter().sum()
18770 };
18771 for i in 0..n {
18772 let mut ip = init_data.to_vec();
18773 ip[i] += h;
18774 let mut im = init_data.to_vec();
18775 im[i] -= h;
18776 let fd = (loss_at(&ip) - loss_at(&im)) / (2.0 * h);
18777 assert!(
18778 (dinit[i] - fd).abs() < 1e-7,
18779 "FD dinit[{i}]: AD={} FD={}",
18780 dinit[i],
18781 fd
18782 );
18783 }
18784 }
18785
18786 #[test]
18792 fn scan_trajectory_backward_euler_records_waveform() {
18793 let n = 4usize;
18794 let length = 5u32;
18795 let dt = 0.05_f64;
18796
18797 let mut m_data = vec![0.0_f64; n * n];
18798 for i in 0..n {
18799 m_data[i * n + i] = 1.0 + dt * 2.0;
18800 if i > 0 {
18801 m_data[i * n + (i - 1)] = -dt;
18802 }
18803 if i + 1 < n {
18804 m_data[i * n + (i + 1)] = -dt;
18805 }
18806 }
18807 let m_bytes: Vec<u8> = m_data.iter().flat_map(|x| x.to_le_bytes()).collect();
18808
18809 let mut body = Graph::new("be_traj_body");
18810 let x = body.input("x", Shape::new(&[n], DType::F64));
18811 let m = body.add_node(
18812 Op::Constant { data: m_bytes },
18813 vec![],
18814 Shape::new(&[n, n], DType::F64),
18815 );
18816 let next = body.dense_solve(m, x, Shape::new(&[n], DType::F64));
18817 body.set_outputs(vec![next]);
18818
18819 let mut g = Graph::new("be_traj_outer");
18820 let init = g.input("x0", Shape::new(&[n], DType::F64));
18821 let traj = g.scan_trajectory(init, body, length);
18822 g.set_outputs(vec![traj]);
18823
18824 let init_data: [f64; 4] = [0.0, 1.0, 0.0, 0.0];
18825 let (sched, mut arena) = prepare_f64(&g, &[(init, &init_data)]);
18826 execute_thunks(&sched, arena.raw_buf_mut());
18827 let got = read_arena_f64(&arena, traj, length as usize * n);
18828
18829 let mut want = Vec::<f64>::with_capacity(length as usize * n);
18831 let mut x_ref = init_data.to_vec();
18832 for _ in 0..length {
18833 let mut a_copy = m_data.clone();
18834 crate::blas::dgesv(&mut a_copy, &mut x_ref, n, 1);
18835 want.extend_from_slice(&x_ref);
18836 }
18837 for i in 0..length as usize * n {
18838 assert!(
18839 (got[i] - want[i]).abs() < 1e-12,
18840 "got[{i}] = {} ref {}",
18841 got[i],
18842 want[i]
18843 );
18844 }
18845
18846 for t in 1..length as usize {
18849 let prev: f64 = got[(t - 1) * n..t * n].iter().sum();
18850 let curr: f64 = got[t * n..(t + 1) * n].iter().sum();
18851 assert!(
18852 curr <= prev + 1e-15,
18853 "mass should decay: row {} sum {prev}, row {t} sum {curr}",
18854 t - 1
18855 );
18856 }
18857
18858 let mut body2 = Graph::new("be_final_body");
18862 let x2 = body2.input("x", Shape::new(&[n], DType::F64));
18863 let m_bytes2: Vec<u8> = m_data.iter().flat_map(|x| x.to_le_bytes()).collect();
18864 let m2 = body2.add_node(
18865 Op::Constant { data: m_bytes2 },
18866 vec![],
18867 Shape::new(&[n, n], DType::F64),
18868 );
18869 let next2 = body2.dense_solve(m2, x2, Shape::new(&[n], DType::F64));
18870 body2.set_outputs(vec![next2]);
18871
18872 let mut g2 = Graph::new("be_final_outer");
18873 let init2 = g2.input("x0", Shape::new(&[n], DType::F64));
18874 let final_x = g2.scan(init2, body2, length);
18875 g2.set_outputs(vec![final_x]);
18876 let (sched2, mut arena2) = prepare_f64(&g2, &[(init2, &init_data)]);
18877 execute_thunks(&sched2, arena2.raw_buf_mut());
18878 let final_got = read_arena_f64(&arena2, final_x, n);
18879
18880 let last_row = &got[(length as usize - 1) * n..length as usize * n];
18881 for i in 0..n {
18882 assert!(
18883 (last_row[i] - final_got[i]).abs() < 1e-15,
18884 "last trajectory row[{i}] = {} vs final-scan = {}",
18885 last_row[i],
18886 final_got[i]
18887 );
18888 }
18889 }
18890
18891 #[test]
18897 fn scan_backward_euler_heat_f64() {
18898 let n = 4usize;
18899 let length = 5u32;
18900 let dt = 0.05_f64;
18901
18902 let mut m_data = vec![0.0_f64; n * n];
18905 for i in 0..n {
18906 m_data[i * n + i] = 1.0 + dt * 2.0;
18907 if i > 0 {
18908 m_data[i * n + (i - 1)] = -dt;
18909 }
18910 if i + 1 < n {
18911 m_data[i * n + (i + 1)] = -dt;
18912 }
18913 }
18914 let m_bytes: Vec<u8> = m_data.iter().flat_map(|x| x.to_le_bytes()).collect();
18915
18916 let mut body = Graph::new("be_body");
18917 let x = body.input("x", Shape::new(&[n], DType::F64));
18918 let m = body.add_node(
18919 Op::Constant { data: m_bytes },
18920 vec![],
18921 Shape::new(&[n, n], DType::F64),
18922 );
18923 let next = body.dense_solve(m, x, Shape::new(&[n], DType::F64));
18924 body.set_outputs(vec![next]);
18925
18926 let mut g = Graph::new("be_outer");
18927 let init = g.input("x0", Shape::new(&[n], DType::F64));
18928 let final_x = g.scan(init, body, length);
18929 g.set_outputs(vec![final_x]);
18930
18931 let init_data: [f64; 4] = [0.0, 1.0, 0.0, 0.0];
18933 let (sched, mut arena) = prepare_f64(&g, &[(init, &init_data)]);
18934 execute_thunks(&sched, arena.raw_buf_mut());
18935 let got = read_arena_f64(&arena, final_x, n);
18936
18937 let mut ref_x = init_data.to_vec();
18939 for _ in 0..length {
18940 let mut a_copy = m_data.clone();
18941 crate::blas::dgesv(&mut a_copy, &mut ref_x, n, 1);
18942 }
18943 for i in 0..n {
18944 assert!(
18945 (got[i] - ref_x[i]).abs() < 1e-12,
18946 "got[{i}] = {} ref {}",
18947 got[i],
18948 ref_x[i]
18949 );
18950 }
18951 let mass: f64 = got.iter().sum();
18956 assert!(mass > 0.0 && mass < 1.0, "diffusion mass: {mass}");
18957 }
18958
18959 #[test]
18963 fn dense_solve_f64_multi_rhs_forward() {
18964 let n = 3usize;
18965 let k = 2usize;
18966 let mut g = Graph::new("solve_multi_rhs");
18967 let a = g.input("A", Shape::new(&[n, n], DType::F64));
18968 let b = g.input("B", Shape::new(&[n, k], DType::F64));
18969 let x = g.dense_solve(a, b, Shape::new(&[n, k], DType::F64));
18970 g.set_outputs(vec![x]);
18971
18972 let a_data = [2.0, 1.0, 0.0, 1.0, 3.0, 1.0, 0.0, 1.0, 2.0_f64];
18973 let b_data = [1.0, 4.0, 2.0, -1.0, 3.0, 2.0_f64];
18974 let (sched, mut arena) = prepare_f64(&g, &[(a, &a_data), (b, &b_data)]);
18975 execute_thunks(&sched, arena.raw_buf_mut());
18976 let x_got = read_arena_f64(&arena, x, n * k);
18977 for c in 0..k {
18978 for i in 0..n {
18979 let mut acc = 0.0_f64;
18980 for j in 0..n {
18981 acc += a_data[i * n + j] * x_got[j * k + c];
18982 }
18983 let want = b_data[i * k + c];
18984 assert!(
18985 (acc - want).abs() < 1e-10,
18986 "col {c} row {i}: got {acc} want {want}"
18987 );
18988 }
18989 }
18990 }
18991
18992 #[test]
18995 fn dense_solve_f64_multi_rhs_gradient() {
18996 use rlx_opt::autodiff::grad_with_loss;
18997 let n = 3usize;
18998 let k = 2usize;
18999 let mut g = Graph::new("solve_mrhs_grad");
19000 let a = g.param("A", Shape::new(&[n, n], DType::F64));
19001 let b = g.input("B", Shape::new(&[n, k], DType::F64));
19002 let x = g.dense_solve(a, b, Shape::new(&[n, k], DType::F64));
19003 let loss = g.reduce(
19004 x,
19005 ReduceOp::Sum,
19006 vec![0, 1],
19007 false,
19008 Shape::new(&[1], DType::F64),
19009 );
19010 g.set_outputs(vec![loss]);
19011
19012 let bwd = grad_with_loss(&g, &[a, b]);
19013 let find_by_name = |graph: &Graph, want: &str| -> NodeId {
19014 for node in graph.nodes() {
19015 let name = match &node.op {
19016 Op::Input { name } | Op::Param { name } => Some(name.as_str()),
19017 _ => None,
19018 };
19019 if name == Some(want) {
19020 return node.id;
19021 }
19022 }
19023 panic!("no node named {want:?}");
19024 };
19025 let a_bwd = find_by_name(&bwd, "A");
19026 let b_bwd = find_by_name(&bwd, "B");
19027 let d_out = find_by_name(&bwd, "d_output");
19028
19029 let a_data = [2.0, 1.0, 0.0, 1.0, 3.0, 1.0, 0.0, 1.0, 2.0_f64];
19030 let b_data = [1.0, 4.0, 2.0, -1.0, 3.0, 2.0_f64];
19031 let d_seed = [1.0_f64];
19032
19033 let (sched, mut arena) = prepare_f64(
19034 &bwd,
19035 &[(a_bwd, &a_data), (b_bwd, &b_data), (d_out, &d_seed)],
19036 );
19037 execute_thunks(&sched, arena.raw_buf_mut());
19038 let da_got = read_arena_f64(&arena, bwd.outputs[1], n * n);
19039 let db_got = read_arena_f64(&arena, bwd.outputs[2], n * k);
19040
19041 let mut x_ref = b_data;
19043 {
19044 let mut a_copy = a_data;
19045 crate::blas::dgesv(&mut a_copy, &mut x_ref, n, k);
19046 }
19047 let mut at = [0.0_f64; 9];
19048 for i in 0..n {
19049 for j in 0..n {
19050 at[i * n + j] = a_data[j * n + i];
19051 }
19052 }
19053 let mut ones_nk = vec![1.0_f64; n * k];
19054 crate::blas::dgesv(&mut at, &mut ones_nk, n, k);
19055 let db_ref = ones_nk;
19056 let mut da_ref = [0.0_f64; 9];
19057 for i in 0..n {
19058 for j in 0..n {
19059 let mut acc = 0.0_f64;
19060 for c in 0..k {
19061 acc += db_ref[i * k + c] * x_ref[j * k + c];
19062 }
19063 da_ref[i * n + j] = -acc;
19064 }
19065 }
19066 for i in 0..n * k {
19067 assert!(
19068 (db_got[i] - db_ref[i]).abs() < 1e-10,
19069 "dB[{i}]: got {} want {}",
19070 db_got[i],
19071 db_ref[i]
19072 );
19073 }
19074 for i in 0..n * n {
19075 assert!(
19076 (da_got[i] - da_ref[i]).abs() < 1e-10,
19077 "dA[{i}]: got {} want {}",
19078 da_got[i],
19079 da_ref[i]
19080 );
19081 }
19082
19083 let h = 1e-6;
19085 let mut bp = b_data;
19086 bp[0] += h;
19087 let mut bm = b_data;
19088 bm[0] -= h;
19089 let xp = {
19090 let mut a_copy = a_data;
19091 crate::blas::dgesv(&mut a_copy, &mut bp, n, k);
19092 bp
19093 };
19094 let xm = {
19095 let mut a_copy = a_data;
19096 crate::blas::dgesv(&mut a_copy, &mut bm, n, k);
19097 bm
19098 };
19099 let lp: f64 = xp.iter().sum();
19100 let lm: f64 = xm.iter().sum();
19101 let fd = (lp - lm) / (2.0 * h);
19102 assert!(
19103 (db_got[0] - fd).abs() < 1e-7,
19104 "FD dB[0,0]: AD={} FD={}",
19105 db_got[0],
19106 fd
19107 );
19108 }
19109
19110 #[test]
19112 fn dense_solve_f64_multi_rhs_jvp() {
19113 use rlx_opt::autodiff_fwd::jvp;
19114 let n = 3usize;
19115 let k = 2usize;
19116 let mut g = Graph::new("solve_mrhs_jvp");
19117 let a = g.input("A", Shape::new(&[n, n], DType::F64));
19118 let b = g.input("B", Shape::new(&[n, k], DType::F64));
19119 let x = g.dense_solve(a, b, Shape::new(&[n, k], DType::F64));
19120 g.set_outputs(vec![x]);
19121
19122 let jg = jvp(&g, &[b]);
19123 let find_by_name = |graph: &Graph, want: &str| -> NodeId {
19124 for node in graph.nodes() {
19125 let name = match &node.op {
19126 Op::Input { name } | Op::Param { name } => Some(name.as_str()),
19127 _ => None,
19128 };
19129 if name == Some(want) {
19130 return node.id;
19131 }
19132 }
19133 panic!("no node named {want:?}");
19134 };
19135 let a_id = find_by_name(&jg, "A");
19136 let b_id = find_by_name(&jg, "B");
19137 let tb_id = find_by_name(&jg, "tangent_B");
19138
19139 let a_data = [2.0, 1.0, 0.0, 1.0, 3.0, 1.0, 0.0, 1.0, 2.0_f64];
19140 let b_data = [1.0, 4.0, 2.0, -1.0, 3.0, 2.0_f64];
19141 let tb_data = [0.5, 0.0, -0.25, 1.0, 1.0, -0.5_f64];
19142
19143 let (sched, mut arena) =
19144 prepare_f64(&jg, &[(a_id, &a_data), (b_id, &b_data), (tb_id, &tb_data)]);
19145 execute_thunks(&sched, arena.raw_buf_mut());
19146 let tangent_x = read_arena_f64(&arena, jg.outputs[1], n * k);
19147
19148 let mut a_copy = a_data;
19149 let mut tb_copy = tb_data;
19150 crate::blas::dgesv(&mut a_copy, &mut tb_copy, n, k);
19151 for i in 0..n * k {
19152 assert!(
19153 (tangent_x[i] - tb_copy[i]).abs() < 1e-10,
19154 "t_X[{i}]: AD={} ref={}",
19155 tangent_x[i],
19156 tb_copy[i]
19157 );
19158 }
19159
19160 let h = 1e-6;
19161 let mut bp = b_data;
19162 let mut bm = b_data;
19163 for i in 0..n * k {
19164 bp[i] += h * tb_data[i];
19165 bm[i] -= h * tb_data[i];
19166 }
19167 let xp = {
19168 let mut a_copy = a_data;
19169 crate::blas::dgesv(&mut a_copy, &mut bp, n, k);
19170 bp
19171 };
19172 let xm = {
19173 let mut a_copy = a_data;
19174 crate::blas::dgesv(&mut a_copy, &mut bm, n, k);
19175 bm
19176 };
19177 for i in 0..n * k {
19178 let fd = (xp[i] - xm[i]) / (2.0 * h);
19179 assert!(
19180 (tangent_x[i] - fd).abs() < 1e-7,
19181 "FD t_X[{i}]: AD={} FD={}",
19182 tangent_x[i],
19183 fd
19184 );
19185 }
19186 }
19187
19188 #[test]
19195 fn jvp_dense_solve_b_runs_and_matches_fd() {
19196 use rlx_opt::autodiff_fwd::jvp;
19197 let n = 3usize;
19198
19199 let mut g = Graph::new("jvp_b_e2e");
19201 let a = g.input("A", Shape::new(&[n, n], DType::F64));
19202 let b = g.input("b", Shape::new(&[n], DType::F64));
19203 let x = g.dense_solve(a, b, Shape::new(&[n], DType::F64));
19204 g.set_outputs(vec![x]);
19205
19206 let jg = jvp(&g, &[b]);
19208 let find_by_name = |graph: &Graph, want: &str| -> NodeId {
19210 for node in graph.nodes() {
19211 let name = match &node.op {
19212 Op::Input { name } | Op::Param { name } => Some(name.as_str()),
19213 _ => None,
19214 };
19215 if name == Some(want) {
19216 return node.id;
19217 }
19218 }
19219 panic!("no node named {want:?}");
19220 };
19221 let a_id = find_by_name(&jg, "A");
19222 let b_id = find_by_name(&jg, "b");
19223 let tb_id = find_by_name(&jg, "tangent_b");
19224
19225 let a_data: [f64; 9] = [2.0, 1.0, 0.0, 1.0, 3.0, 1.0, 0.0, 1.0, 2.0];
19226 let b_data: [f64; 3] = [1.0, 2.0, 3.0];
19227 let tb_data: [f64; 3] = [0.5, -0.25, 1.0];
19229
19230 let (sched, mut arena) =
19231 prepare_f64(&jg, &[(a_id, &a_data), (b_id, &b_data), (tb_id, &tb_data)]);
19232 execute_thunks(&sched, arena.raw_buf_mut());
19233
19234 let primal_x = read_arena_f64(&arena, jg.outputs[0], n);
19236 let tangent_x = read_arena_f64(&arena, jg.outputs[1], n);
19237
19238 let t_x_ref = {
19240 let mut a = a_data;
19241 let mut tb = tb_data;
19242 let info = crate::blas::dgesv(&mut a, &mut tb, n, 1);
19243 assert_eq!(info, 0);
19244 tb
19245 };
19246 for i in 0..n {
19247 assert!(
19248 (tangent_x[i] - t_x_ref[i]).abs() < 1e-10,
19249 "t_x[{i}]: got {} want {}",
19250 tangent_x[i],
19251 t_x_ref[i]
19252 );
19253 }
19254
19255 let h = 1e-6;
19257 let mut bp = b_data;
19258 let mut bm = b_data;
19259 for i in 0..n {
19260 bp[i] += h * tb_data[i];
19261 bm[i] -= h * tb_data[i];
19262 }
19263 let xp = {
19264 let mut a = a_data;
19265 let info = crate::blas::dgesv(&mut a, &mut bp, n, 1);
19266 assert_eq!(info, 0);
19267 bp
19268 };
19269 let xm = {
19270 let mut a = a_data;
19271 let info = crate::blas::dgesv(&mut a, &mut bm, n, 1);
19272 assert_eq!(info, 0);
19273 bm
19274 };
19275 let fd: Vec<f64> = (0..n).map(|i| (xp[i] - xm[i]) / (2.0 * h)).collect();
19276 for i in 0..n {
19277 assert!(
19278 (tangent_x[i] - fd[i]).abs() < 1e-7,
19279 "FD mismatch t_x[{i}]: AD={} FD={}",
19280 tangent_x[i],
19281 fd[i]
19282 );
19283 }
19284 let primal_ref = {
19286 let mut a = a_data;
19287 let mut b = b_data;
19288 crate::blas::dgesv(&mut a, &mut b, n, 1);
19289 b
19290 };
19291 for i in 0..n {
19292 assert!((primal_x[i] - primal_ref[i]).abs() < 1e-10);
19293 }
19294 }
19295
19296 #[test]
19302 fn jvp_dense_solve_a_runs_and_matches_fd() {
19303 use rlx_opt::autodiff_fwd::jvp;
19304 let n = 3usize;
19305
19306 let mut g = Graph::new("jvp_a_e2e");
19307 let a = g.input("A", Shape::new(&[n, n], DType::F64));
19308 let b = g.input("b", Shape::new(&[n], DType::F64));
19309 let x = g.dense_solve(a, b, Shape::new(&[n], DType::F64));
19310 g.set_outputs(vec![x]);
19311
19312 let jg = jvp(&g, &[a]);
19313 let find_by_name = |graph: &Graph, want: &str| -> NodeId {
19314 for node in graph.nodes() {
19315 let name = match &node.op {
19316 Op::Input { name } | Op::Param { name } => Some(name.as_str()),
19317 _ => None,
19318 };
19319 if name == Some(want) {
19320 return node.id;
19321 }
19322 }
19323 panic!("no node named {want:?}");
19324 };
19325 let a_id = find_by_name(&jg, "A");
19326 let b_id = find_by_name(&jg, "b");
19327 let ta_id = find_by_name(&jg, "tangent_A");
19328
19329 let a_data: [f64; 9] = [2.0, 1.0, 0.0, 1.0, 3.0, 1.0, 0.0, 1.0, 2.0];
19330 let b_data: [f64; 3] = [1.0, 2.0, 3.0];
19331 let ta_data: [f64; 9] = [0.10, -0.05, 0.02, 0.03, 0.20, -0.04, -0.01, 0.07, 0.15];
19333
19334 let (sched, mut arena) =
19335 prepare_f64(&jg, &[(a_id, &a_data), (b_id, &b_data), (ta_id, &ta_data)]);
19336 execute_thunks(&sched, arena.raw_buf_mut());
19337
19338 let tangent_x = read_arena_f64(&arena, jg.outputs[1], n);
19339
19340 let x_ref = {
19342 let mut a = a_data;
19343 let mut b = b_data;
19344 crate::blas::dgesv(&mut a, &mut b, n, 1);
19345 b
19346 };
19347 let mut prod = [0.0_f64; 3];
19348 for i in 0..n {
19349 for j in 0..n {
19350 prod[i] += ta_data[i * n + j] * x_ref[j];
19351 }
19352 }
19353 let t_x_ref = {
19354 let mut a = a_data;
19355 let mut p = prod;
19356 crate::blas::dgesv(&mut a, &mut p, n, 1);
19357 [-p[0], -p[1], -p[2]]
19358 };
19359 for i in 0..n {
19360 assert!(
19361 (tangent_x[i] - t_x_ref[i]).abs() < 1e-10,
19362 "closed-form t_x[{i}]: AD={} ref={}",
19363 tangent_x[i],
19364 t_x_ref[i]
19365 );
19366 }
19367
19368 let h = 1e-6;
19370 let mut ap = a_data;
19371 let mut am = a_data;
19372 for i in 0..n * n {
19373 ap[i] += h * ta_data[i];
19374 am[i] -= h * ta_data[i];
19375 }
19376 let xp = {
19377 let mut a = ap;
19378 let mut b = b_data;
19379 crate::blas::dgesv(&mut a, &mut b, n, 1);
19380 b
19381 };
19382 let xm = {
19383 let mut a = am;
19384 let mut b = b_data;
19385 crate::blas::dgesv(&mut a, &mut b, n, 1);
19386 b
19387 };
19388 for i in 0..n {
19389 let fd = (xp[i] - xm[i]) / (2.0 * h);
19390 assert!(
19391 (tangent_x[i] - fd).abs() < 1e-7,
19392 "FD t_x[{i}]: AD={} FD={}",
19393 tangent_x[i],
19394 fd
19395 );
19396 }
19397 }
19398
19399 #[test]
19405 fn q_conv2d_matches_reference() {
19406 use rlx_ir::Philox4x32;
19407 let n = 1usize;
19409 let c_in = 2usize;
19410 let h = 5usize;
19411 let w_in = 5usize;
19412 let c_out = 3usize;
19413 let kh = 3usize;
19414 let kw = 3usize;
19415 let ph = 1usize;
19416 let pw = 1usize;
19417 let sh = 1usize;
19418 let sw = 1usize;
19419 let h_out = (h + 2 * ph - kh) / sh + 1;
19420 let w_out = (w_in + 2 * pw - kw) / sw + 1;
19421
19422 let x_scale = 0.04f32;
19423 let w_scale = 0.02f32;
19424 let out_scale = 0.5f32;
19425 let mult = x_scale * w_scale / out_scale;
19426
19427 let mut rng = Philox4x32::new(2099);
19428 let mut xf = vec![0f32; n * c_in * h * w_in];
19429 rng.fill_normal(&mut xf);
19430 let mut wf = vec![0f32; c_out * c_in * kh * kw];
19431 rng.fill_normal(&mut wf);
19432 let xq: Vec<i8> = xf
19433 .iter()
19434 .map(|&v| ((v / x_scale).round() as i32).clamp(-128, 127) as i8)
19435 .collect();
19436 let wq: Vec<i8> = wf
19437 .iter()
19438 .map(|&v| ((v / w_scale).round() as i32).clamp(-128, 127) as i8)
19439 .collect();
19440 let bias: Vec<i32> = vec![0i32; c_out];
19441
19442 let mut g = Graph::new("qconv");
19443 let xn = g.input("x", Shape::new(&[n, c_in, h, w_in], DType::I8));
19444 let wn = g.input("w", Shape::new(&[c_out, c_in, kh, kw], DType::I8));
19445 let bn = g.input("b", Shape::new(&[c_out], DType::I32));
19446 let out = g.q_conv2d(
19447 xn,
19448 wn,
19449 bn,
19450 vec![kh, kw],
19451 vec![sh, sw],
19452 vec![ph, pw],
19453 vec![1, 1],
19454 1,
19455 0,
19456 0,
19457 0,
19458 mult,
19459 Shape::new(&[n, c_out, h_out, w_out], DType::I8),
19460 );
19461 g.set_outputs(vec![out]);
19462
19463 let plan = rlx_opt::memory::plan_memory(&g);
19464 let mut arena = crate::arena::Arena::from_plan(plan);
19465 let sched = compile_thunks(&g, &arena);
19466 let xn_off = arena.byte_offset(xn);
19469 let wn_off = arena.byte_offset(wn);
19470 let bn_off = arena.byte_offset(bn);
19471 let out_off = arena.byte_offset(out);
19472 let buf = arena.raw_buf_mut();
19473 unsafe {
19474 let p = buf.as_mut_ptr().add(xn_off) as *mut i8;
19475 for (i, &v) in xq.iter().enumerate() {
19476 *p.add(i) = v;
19477 }
19478 let p = buf.as_mut_ptr().add(wn_off) as *mut i8;
19479 for (i, &v) in wq.iter().enumerate() {
19480 *p.add(i) = v;
19481 }
19482 let p = buf.as_mut_ptr().add(bn_off) as *mut i32;
19483 for (i, &v) in bias.iter().enumerate() {
19484 *p.add(i) = v;
19485 }
19486 }
19487 execute_thunks(&sched, arena.raw_buf_mut());
19488 let out_q: Vec<i8> = unsafe {
19489 let p = arena.raw_buf().as_ptr().add(out_off) as *const i8;
19490 (0..n * c_out * h_out * w_out).map(|i| *p.add(i)).collect()
19491 };
19492
19493 let mut out_ref = vec![0i8; n * c_out * h_out * w_out];
19495 for ni in 0..n {
19496 for co in 0..c_out {
19497 for ho in 0..h_out {
19498 for wo in 0..w_out {
19499 let mut acc: i32 = 0;
19500 for ci in 0..c_in {
19501 for ki in 0..kh {
19502 for kj in 0..kw {
19503 let hi = ho * sh + ki;
19504 let wi = wo * sw + kj;
19505 if hi < ph || wi < pw {
19506 continue;
19507 }
19508 let hi = hi - ph;
19509 let wi = wi - pw;
19510 if hi >= h || wi >= w_in {
19511 continue;
19512 }
19513 let xv =
19514 xq[((ni * c_in) + ci) * h * w_in + hi * w_in + wi] as i32;
19515 let wv = wq[((co * c_in) + ci) * kh * kw + ki * kw + kj] as i32;
19516 acc += xv * wv;
19517 }
19518 }
19519 }
19520 let r = (acc as f32 * mult).round() as i32;
19521 let r = r.clamp(-128, 127) as i8;
19522 out_ref[((ni * c_out) + co) * h_out * w_out + ho * w_out + wo] = r;
19523 }
19524 }
19525 }
19526 }
19527
19528 for (i, (a, r)) in out_q.iter().zip(&out_ref).enumerate() {
19529 assert_eq!(a, r, "q_conv2d[{i}]: kernel {a} vs reference {r}");
19530 }
19531 }
19532
19533 #[test]
19541 fn q_matmul_matches_fake_quant_reference() {
19542 use rlx_ir::Philox4x32;
19543 let m = 3usize;
19544 let k = 8usize;
19545 let n = 5usize;
19546 let mut rng = Philox4x32::new(2031);
19547
19548 let x_scale = 0.05f32;
19550 let w_scale = 0.03f32;
19551 let out_scale = 0.4f32;
19552 let mult = x_scale * w_scale / out_scale;
19553 let mut xf = vec![0f32; m * k];
19554 rng.fill_normal(&mut xf);
19555 let mut wf = vec![0f32; k * n];
19556 rng.fill_normal(&mut wf);
19557 let xq: Vec<i8> = xf
19558 .iter()
19559 .map(|&v| ((v / x_scale).round() as i32).clamp(-128, 127) as i8)
19560 .collect();
19561 let wq: Vec<i8> = wf
19562 .iter()
19563 .map(|&v| ((v / w_scale).round() as i32).clamp(-128, 127) as i8)
19564 .collect();
19565 let bias: Vec<i32> = vec![0i32; n];
19566
19567 let _f = DType::F32;
19569 let mut g_q = Graph::new("qmm_direct");
19570 let xn = g_q.input("x", Shape::new(&[m, k], DType::I8));
19571 let wn = g_q.input("w", Shape::new(&[k, n], DType::I8));
19572 let bn = g_q.input("b", Shape::new(&[n], DType::I32));
19573 let out = g_q.q_matmul(xn, wn, bn, 0, 0, 0, mult, Shape::new(&[m, n], DType::I8));
19574 g_q.set_outputs(vec![out]);
19575 let plan = rlx_opt::memory::plan_memory(&g_q);
19576 let mut arena = crate::arena::Arena::from_plan(plan);
19577 let sched = compile_thunks(&g_q, &arena);
19578
19579 let xn_off = arena.byte_offset(xn);
19581 let wn_off = arena.byte_offset(wn);
19582 let bn_off = arena.byte_offset(bn);
19583 let out_off = arena.byte_offset(out);
19584 let buf = arena.raw_buf_mut();
19585 unsafe {
19586 let p = buf.as_mut_ptr().add(xn_off) as *mut i8;
19587 for (i, &v) in xq.iter().enumerate() {
19588 *p.add(i) = v;
19589 }
19590 let p = buf.as_mut_ptr().add(wn_off) as *mut i8;
19591 for (i, &v) in wq.iter().enumerate() {
19592 *p.add(i) = v;
19593 }
19594 let p = buf.as_mut_ptr().add(bn_off) as *mut i32;
19595 for (i, &v) in bias.iter().enumerate() {
19596 *p.add(i) = v;
19597 }
19598 }
19599 execute_thunks(&sched, arena.raw_buf_mut());
19600 let out_q: Vec<i8> = unsafe {
19601 let p = arena.raw_buf().as_ptr().add(out_off) as *const i8;
19602 (0..m * n).map(|i| *p.add(i)).collect()
19603 };
19604
19605 let mut out_ref = vec![0i8; m * n];
19610 for mi in 0..m {
19611 for ni in 0..n {
19612 let mut acc: i32 = 0;
19613 for ki in 0..k {
19614 acc += (xq[mi * k + ki] as i32) * (wq[ki * n + ni] as i32);
19615 }
19616 let r = (acc as f32 * mult).round() as i32;
19617 out_ref[mi * n + ni] = r.clamp(-128, 127) as i8;
19618 }
19619 }
19620
19621 for (i, (a, r)) in out_q.iter().zip(&out_ref).enumerate() {
19622 assert_eq!(a, r, "q_matmul[{i}]: kernel {a} vs reference {r}");
19623 }
19624 }
19625
19626 #[test]
19631 fn quantize_dequantize_round_trip() {
19632 use rlx_ir::Philox4x32;
19633 let len = 64;
19634 let mut rng = Philox4x32::new(2027);
19635 let mut x = vec![0f32; len];
19636 rng.fill_normal(&mut x);
19637 x[0] = 999.0;
19640 x[1] = -999.0;
19641
19642 let scale = 0.05f32;
19643 let zp = 3i32;
19644
19645 let f = DType::F32;
19646 let mut g = Graph::new("qdq");
19647 let xn = g.input("x", Shape::new(&[len], f));
19648 let q = g.quantize(xn, scale, zp);
19649 let dq = g.dequantize(q, scale, zp);
19650 g.set_outputs(vec![dq]);
19651
19652 let plan = rlx_opt::memory::plan_memory(&g);
19653 let mut arena = crate::arena::Arena::from_plan(plan);
19654 let sched = compile_thunks(&g, &arena);
19655 let xn_off = arena.byte_offset(xn);
19656 let dq_off = arena.byte_offset(dq);
19657 let buf = arena.raw_buf_mut();
19658 unsafe {
19659 let p = buf.as_mut_ptr().add(xn_off) as *mut f32;
19660 for (i, &v) in x.iter().enumerate() {
19661 *p.add(i) = v;
19662 }
19663 }
19664 execute_thunks(&sched, arena.raw_buf_mut());
19665 let out: Vec<f32> = unsafe {
19666 let p = arena.raw_buf().as_ptr().add(dq_off) as *const f32;
19667 (0..len).map(|i| *p.add(i)).collect()
19668 };
19669
19670 let sat_pos = (127 - zp) as f32 * scale;
19673 let sat_neg = (-128 - zp) as f32 * scale;
19674 assert!((out[0] - sat_pos).abs() < 1e-6, "+sat: {}", out[0]);
19675 assert!((out[1] - sat_neg).abs() < 1e-6, "-sat: {}", out[1]);
19676
19677 for i in 2..len {
19680 assert!(
19681 (out[i] - x[i]).abs() <= scale + 1e-5,
19682 "qdq[{i}]: {} → {}, scale={scale}",
19683 x[i],
19684 out[i]
19685 );
19686 }
19687 }
19688
19689 #[test]
19695 fn quantize_per_channel_round_trip() {
19696 let c = 4usize;
19697 let inner = 5usize;
19698 let mags = [0.01f32, 0.5, 5.0, 50.0];
19701 let mut x = vec![0f32; c * inner];
19702 for ci in 0..c {
19703 for ii in 0..inner {
19704 x[ci * inner + ii] = match ii {
19708 0 => -mags[ci],
19709 1 => 0.0,
19710 2 => mags[ci],
19711 3 => mags[ci] * 1000.0, _ => -mags[ci] * 1000.0, };
19714 }
19715 }
19716 let scales: Vec<f32> = mags.iter().map(|&m| m / 127.0).collect();
19717 let zps: Vec<i32> = vec![0, 0, 0, 0];
19718
19719 let f = DType::F32;
19720 let mut g = Graph::new("qdq_pc");
19721 let xn = g.input("x", Shape::new(&[c, inner], f));
19722 let q = g.quantize_per_channel(xn, 0, scales.clone(), zps.clone());
19723 let dq = g.dequantize_per_channel(q, 0, scales.clone(), zps);
19724 g.set_outputs(vec![dq]);
19725
19726 let plan = rlx_opt::memory::plan_memory(&g);
19727 let mut arena = crate::arena::Arena::from_plan(plan);
19728 let sched = compile_thunks(&g, &arena);
19729 let xn_off = arena.byte_offset(xn);
19730 let dq_off = arena.byte_offset(dq);
19731 let buf = arena.raw_buf_mut();
19732 unsafe {
19733 let p = buf.as_mut_ptr().add(xn_off) as *mut f32;
19734 for (i, &v) in x.iter().enumerate() {
19735 *p.add(i) = v;
19736 }
19737 }
19738 execute_thunks(&sched, arena.raw_buf_mut());
19739 let out: Vec<f32> = unsafe {
19740 let p = arena.raw_buf().as_ptr().add(dq_off) as *const f32;
19741 (0..c * inner).map(|i| *p.add(i)).collect()
19742 };
19743
19744 for ci in 0..c {
19745 for ii in 0..3 {
19748 let idx = ci * inner + ii;
19749 assert!(
19750 (out[idx] - x[idx]).abs() <= scales[ci] + 1e-5,
19751 "ch {ci} idx {ii}: {} vs {}",
19752 x[idx],
19753 out[idx]
19754 );
19755 }
19756 let sat_pos = 127.0 * scales[ci];
19758 let sat_neg = -128.0 * scales[ci];
19759 assert!(
19760 (out[ci * inner + 3] - sat_pos).abs() < 1e-5,
19761 "ch {ci} +sat: {}",
19762 out[ci * inner + 3]
19763 );
19764 assert!(
19765 (out[ci * inner + 4] - sat_neg).abs() < 1e-5,
19766 "ch {ci} -sat: {}",
19767 out[ci * inner + 4]
19768 );
19769 }
19770 }
19771
#[test]
fn activation_backward_matches_numerical_per_kind() {
    use rlx_ir::Philox4x32;
    use rlx_ir::op::Activation;

    // Scalar reference forward for every activation kind. The loop below
    // differentiates it with a central difference and compares against the
    // compiled `activation_backward` kernel.
    let reference = |kind: Activation, x: f32| -> f32 {
        match kind {
            Activation::Sigmoid => 1.0 / (1.0 + (-x).exp()),
            Activation::Tanh => x.tanh(),
            Activation::Silu => x / (1.0 + (-x).exp()),
            Activation::Gelu => {
                const INV_SQRT2: f32 = 0.707_106_77;
                0.5 * x * (1.0 + erf_f32(x * INV_SQRT2))
            }
            Activation::GeluApprox => {
                const C: f32 = 0.797_884_6;
                const A: f32 = 0.044_715;
                let inner = C * (x + A * x * x * x);
                0.5 * x * (1.0 + inner.tanh())
            }
            Activation::Exp => x.exp(),
            Activation::Log => x.ln(),
            Activation::Sqrt => x.sqrt(),
            Activation::Rsqrt => 1.0 / x.sqrt(),
            Activation::Neg => -x,
            Activation::Relu => x.max(0.0),
            Activation::Abs => x.abs(),
            Activation::Round => x.round(),
            Activation::Sin => x.sin(),
            Activation::Cos => x.cos(),
            Activation::Tan => x.tan(),
            Activation::Atan => x.atan(),
        }
    };

    let len = 32;
    let mut rng = Philox4x32::new(91);
    // Log / Sqrt / Rsqrt need strictly positive inputs: |sample| + 0.5.
    let mut x_pos = vec![0f32; len];
    rng.fill_normal(&mut x_pos);
    x_pos.iter_mut().for_each(|v| *v = v.abs() + 0.5);
    let mut x_any = vec![0f32; len];
    rng.fill_normal(&mut x_any);
    let mut dy = vec![0f32; len];
    rng.fill_normal(&mut dy);

    // (kind, inputs, finite-difference step, absolute tolerance)
    let cases = [
        (Activation::Sigmoid, &x_any[..], 1e-3, 5e-3),
        (Activation::Tanh, &x_any[..], 1e-3, 5e-3),
        (Activation::Silu, &x_any[..], 1e-3, 5e-3),
        (Activation::Gelu, &x_any[..], 1e-3, 5e-3),
        (Activation::GeluApprox, &x_any[..], 1e-3, 5e-3),
        (Activation::Exp, &x_any[..], 1e-4, 5e-3),
        (Activation::Log, &x_pos[..], 1e-4, 5e-3),
        (Activation::Sqrt, &x_pos[..], 1e-4, 5e-3),
        (Activation::Rsqrt, &x_pos[..], 1e-4, 5e-3),
        (Activation::Neg, &x_any[..], 1e-3, 5e-4),
    ];
    for &(kind, inputs, step, tol) in &cases {
        let f = DType::F32;
        let mut g = Graph::new("act_bw");
        let x_node = g.input("x", Shape::new(&[len], f));
        let dy_node = g.input("dy", Shape::new(&[len], f));
        let dx_node = g.activation_backward(kind, x_node, dy_node);
        g.set_outputs(vec![dx_node]);

        let mut arena = crate::arena::Arena::from_plan(rlx_opt::memory::plan_memory(&g));
        let sched = compile_thunks(&g, &arena);

        let x_off = arena.byte_offset(x_node);
        let dy_off = arena.byte_offset(dy_node);
        let dx_off = arena.byte_offset(dx_node);
        let base = arena.raw_buf_mut().as_mut_ptr();
        // SAFETY: assumes the memory plan reserves `len` f32 slots at each
        // input's byte offset (both inputs are shaped `[len]`).
        unsafe {
            std::ptr::copy_nonoverlapping(
                inputs.as_ptr(),
                base.add(x_off) as *mut f32,
                inputs.len(),
            );
            std::ptr::copy_nonoverlapping(dy.as_ptr(), base.add(dy_off) as *mut f32, dy.len());
        }
        execute_thunks(&sched, arena.raw_buf_mut());

        // SAFETY: `dx` is shaped `[len]`; same layout assumption as above.
        let analytical: Vec<f32> = unsafe {
            let src = arena.raw_buf().as_ptr().add(dx_off) as *const f32;
            std::slice::from_raw_parts(src, len).to_vec()
        };

        for (i, ((&xv, &upstream), &got)) in
            inputs.iter().zip(&dy).zip(&analytical).enumerate()
        {
            let plus = reference(kind, xv + step);
            let minus = reference(kind, xv - step);
            let num = (plus - minus) / (2.0 * step) * upstream;
            assert!(
                (got - num).abs() < tol,
                "{kind:?}[{i}]: analytical {} vs numerical {num}",
                got
            );
        }
    }
}
19885
#[test]
fn matmul_3d_gradient_matches_numerical() {
    use rlx_ir::Philox4x32;
    let (batch, m, k, n) = (2usize, 3usize, 4usize, 5usize);
    let mut rng = Philox4x32::new(101);
    let mut a_data = vec![0f32; batch * m * k];
    rng.fill_normal(&mut a_data);
    let mut b_data = vec![0f32; batch * k * n];
    rng.fill_normal(&mut b_data);

    // Forward: loss = sum(batched_matmul(a, b)); only `b` is differentiated.
    let f = DType::F32;
    let mut fwd = Graph::new("matmul_3d");
    let an = fwd.input("a", Shape::new(&[batch, m, k], f));
    let bp = fwd.param("b", Shape::new(&[batch, k, n], f));
    let mm = fwd.matmul(an, bp, Shape::new(&[batch, m, n], f));
    let loss = fwd.add_node(
        Op::Reduce {
            op: ReduceOp::Sum,
            axes: vec![0, 1, 2],
            keep_dim: false,
        },
        vec![mm],
        Shape::from_dims(&[], f),
    );
    fwd.set_outputs(vec![loss]);

    let bwd_graph = rlx_opt::autodiff::grad_with_loss(&fwd, &[bp]);
    let d_out = bwd_graph
        .nodes()
        .iter()
        .find_map(|node| match &node.op {
            Op::Input { name } if name == "d_output" => Some(node.id),
            _ => None,
        })
        .unwrap();

    let mut arena = crate::arena::Arena::from_plan(rlx_opt::memory::plan_memory(&bwd_graph));
    let sched = compile_thunks(&bwd_graph, &arena);
    let seed = [1.0f32];
    let feeds: [(NodeId, &[f32]); 3] = [(an, &a_data), (bp, &b_data), (d_out, &seed)];
    for &(id, data) in &feeds {
        let off = arena.byte_offset(id);
        let base = arena.raw_buf_mut().as_mut_ptr();
        // SAFETY: assumes the plan reserves `data.len()` f32 slots at `off`.
        unsafe {
            std::ptr::copy_nonoverlapping(data.as_ptr(), base.add(off) as *mut f32, data.len());
        }
    }
    execute_thunks(&sched, arena.raw_buf_mut());

    let gb_id = bwd_graph.outputs[1];
    // SAFETY: the gradient of `b` has `batch * k * n` f32 elements.
    let g_b: Vec<f32> = unsafe {
        let src = arena.raw_buf().as_ptr().add(arena.byte_offset(gb_id)) as *const f32;
        std::slice::from_raw_parts(src, batch * k * n).to_vec()
    };

    // Reference forward: per-batch [m, k] x [k, n], then sum every entry.
    let forward_loss = |b_vals: &[f32]| -> f32 {
        let mut out = vec![0f32; batch * m * n];
        for bi in 0..batch {
            let a_blk = &a_data[bi * m * k..(bi + 1) * m * k];
            let b_blk = &b_vals[bi * k * n..(bi + 1) * k * n];
            let out_blk = &mut out[bi * m * n..(bi + 1) * m * n];
            for (mi, row) in out_blk.chunks_mut(n).enumerate() {
                for (ni, cell) in row.iter_mut().enumerate() {
                    *cell = (0..k).fold(0f32, |acc, ki| {
                        acc + a_blk[mi * k + ki] * b_blk[ki * n + ni]
                    });
                }
            }
        }
        out.iter().sum()
    };

    let eps = 1e-3f32;
    let mut probe = b_data.clone();
    let g_b_num: Vec<f32> = (0..b_data.len())
        .map(|i| {
            let saved = probe[i];
            probe[i] = saved + eps;
            let lp = forward_loss(&probe);
            probe[i] = saved - eps;
            let lm = forward_loss(&probe);
            probe[i] = saved;
            (lp - lm) / (2.0 * eps)
        })
        .collect();

    for (i, (a, n)) in g_b.iter().zip(&g_b_num).enumerate() {
        assert!(
            (a - n).abs() < 5e-3,
            "matmul_3d g_b[{i}]: analytical {a} vs numerical {n}"
        );
    }
}
19982
#[test]
fn softmax_gradient_matches_numerical() {
    use rlx_ir::Philox4x32;
    let n = 3usize;
    let c = 5usize;
    let mut rng = Philox4x32::new(57);
    let mut x_data = vec![0f32; n * c];
    rng.fill_normal(&mut x_data);
    // Per-element loss weights, drawn after `x` so `x` keeps the same values.
    // Without them the loss would be `sum(softmax(x))`, which equals `n` for
    // every `x` (each row sums to 1). Its gradient would then be identically
    // zero, and a backward pass returning zeros would pass this check.
    let mut w_data = vec![0f32; n * c];
    rng.fill_normal(&mut w_data);

    // Forward: loss = sum(w * softmax(x, axis = -1)).
    let f = DType::F32;
    let mut fwd = Graph::new("softmax_only");
    let xn = fwd.input("x", Shape::new(&[n, c], f));
    let wn = fwd.input("w", Shape::new(&[n, c], f));
    let sm = fwd.add_node(Op::Softmax { axis: -1 }, vec![xn], Shape::new(&[n, c], f));
    let weighted = fwd.binary(BinaryOp::Mul, sm, wn, Shape::new(&[n, c], f));
    let loss = fwd.add_node(
        Op::Reduce {
            op: ReduceOp::Sum,
            axes: vec![0, 1],
            keep_dim: false,
        },
        vec![weighted],
        Shape::from_dims(&[], f),
    );
    fwd.set_outputs(vec![loss]);

    // Only `x` is differentiated; `w` is fed as a plain input.
    let bwd_graph = rlx_opt::autodiff::grad_with_loss(&fwd, &[xn]);
    let d_out = bwd_graph
        .nodes()
        .iter()
        .find(|n| matches!(&n.op, Op::Input { name } if name == "d_output"))
        .map(|n| n.id)
        .unwrap();

    let plan = rlx_opt::memory::plan_memory(&bwd_graph);
    let mut arena = crate::arena::Arena::from_plan(plan);
    let sched = compile_thunks(&bwd_graph, &arena);
    for &(id, data) in &[(xn, &x_data), (wn, &w_data), (d_out, &vec![1.0f32])] {
        let off = arena.byte_offset(id);
        let buf = arena.raw_buf_mut();
        unsafe {
            let p = buf.as_mut_ptr().add(off) as *mut f32;
            for (i, &v) in data.iter().enumerate() {
                *p.add(i) = v;
            }
        }
    }
    execute_thunks(&sched, arena.raw_buf_mut());
    let g_x_id = bwd_graph.outputs[1];
    let g_x: Vec<f32> = unsafe {
        let p = arena.raw_buf().as_ptr().add(arena.byte_offset(g_x_id)) as *const f32;
        (0..n * c).map(|i| *p.add(i)).collect()
    };

    // Reference forward: row-wise max-subtracted softmax, weighted, summed.
    let forward_loss = |x: &[f32]| -> f32 {
        let mut total = 0f32;
        for ni in 0..n {
            let row = &x[ni * c..(ni + 1) * c];
            let w_row = &w_data[ni * c..(ni + 1) * c];
            let m = row.iter().fold(f32::NEG_INFINITY, |a, &v| a.max(v));
            let denom: f32 = row.iter().map(|&v| (v - m).exp()).sum();
            for (&v, &wv) in row.iter().zip(w_row) {
                total += wv * ((v - m).exp() / denom);
            }
        }
        total
    };
    let eps = 1e-3f32;
    let mut p = x_data.clone();
    for i in 0..x_data.len() {
        let s = p[i];
        p[i] = s + eps;
        let lp = forward_loss(&p);
        p[i] = s - eps;
        let lm = forward_loss(&p);
        p[i] = s;
        let num = (lp - lm) / (2.0 * eps);
        assert!(
            (g_x[i] - num).abs() < 5e-3,
            "softmax g_x[{i}]: analytical {} vs numerical {num}",
            g_x[i]
        );
    }
}
20078
#[test]
fn layer_norm_gradient_matches_numerical() {
    use rlx_ir::Philox4x32;
    let rows = 3usize;
    let h = 6usize;
    let mut rng = Philox4x32::new(1009);
    let mut x_data = vec![0f32; rows * h];
    rng.fill_normal(&mut x_data);
    // Gamma is forced to be >= 0.5 (strictly positive).
    let mut g_data = vec![0f32; h];
    rng.fill_normal(&mut g_data);
    g_data.iter_mut().for_each(|v| *v = v.abs() + 0.5);
    let mut b_data = vec![0f32; h];
    rng.fill_normal(&mut b_data);
    // LayerNorm variance epsilon (distinct from the finite-difference step).
    let eps = 1e-5f32;

    // Forward: loss = sum(layer_norm(x; gamma, beta)) over every element.
    let f = DType::F32;
    let mut fwd = Graph::new("ln_only");
    let xn = fwd.input("x", Shape::new(&[rows, h], f));
    let gp = fwd.param("gamma", Shape::new(&[h], f));
    let bp = fwd.param("beta", Shape::new(&[h], f));
    let ln = fwd.add_node(
        Op::LayerNorm { axis: -1, eps },
        vec![xn, gp, bp],
        Shape::new(&[rows, h], f),
    );
    let loss = fwd.add_node(
        Op::Reduce {
            op: ReduceOp::Sum,
            axes: vec![0, 1],
            keep_dim: false,
        },
        vec![ln],
        Shape::from_dims(&[], f),
    );
    fwd.set_outputs(vec![loss]);

    let bwd_graph = rlx_opt::autodiff::grad_with_loss(&fwd, &[xn, gp, bp]);
    let d_out = bwd_graph
        .nodes()
        .iter()
        .find_map(|node| match &node.op {
            Op::Input { name } if name == "d_output" => Some(node.id),
            _ => None,
        })
        .unwrap();

    let mut arena = crate::arena::Arena::from_plan(rlx_opt::memory::plan_memory(&bwd_graph));
    let sched = compile_thunks(&bwd_graph, &arena);
    let seed = vec![1.0f32];
    for &(id, data) in &[(xn, &x_data), (gp, &g_data), (bp, &b_data), (d_out, &seed)] {
        let off = arena.byte_offset(id);
        let base = arena.raw_buf_mut().as_mut_ptr();
        // SAFETY: assumes the plan reserves `data.len()` f32 slots at `off`.
        unsafe {
            std::ptr::copy_nonoverlapping(data.as_ptr(), base.add(off) as *mut f32, data.len());
        }
    }
    execute_thunks(&sched, arena.raw_buf_mut());

    // Copies `count` f32 values of node `id` out of the arena.
    let read = |id: NodeId, count: usize| -> Vec<f32> {
        let off = arena.byte_offset(id);
        unsafe {
            let src = arena.raw_buf().as_ptr().add(off) as *const f32;
            std::slice::from_raw_parts(src, count).to_vec()
        }
    };
    let dx_a = read(bwd_graph.outputs[1], rows * h);
    let dg_a = read(bwd_graph.outputs[2], h);
    let db_a = read(bwd_graph.outputs[3], h);

    // Reference forward: per-row mean/variance normalisation, affine, summed.
    let forward_loss = |x: &[f32], g: &[f32], b: &[f32]| -> f32 {
        let mut total = 0f32;
        for row in x.chunks(h).take(rows) {
            let mean = row.iter().sum::<f32>() / h as f32;
            let var = row.iter().map(|&v| (v - mean) * (v - mean)).sum::<f32>() / h as f32;
            let inv_std = 1.0 / (var + eps).sqrt();
            for ((&v, &gd), &bd) in row.iter().zip(g).zip(b) {
                total += ((v - mean) * inv_std) * gd + bd;
            }
        }
        total
    };

    // Central-difference gradient of `loss_at` around `base`, element by element.
    let h_eps = 1e-3f32;
    let central = |base: &[f32], loss_at: &dyn Fn(&[f32]) -> f32| -> Vec<f32> {
        let mut probe = base.to_vec();
        (0..base.len())
            .map(|i| {
                let saved = probe[i];
                probe[i] = saved + h_eps;
                let lp = loss_at(&probe);
                probe[i] = saved - h_eps;
                let lm = loss_at(&probe);
                probe[i] = saved;
                (lp - lm) / (2.0 * h_eps)
            })
            .collect()
    };

    let dx_n = central(&x_data, &|xs: &[f32]| forward_loss(xs, &g_data, &b_data));
    for (i, (&got, &num)) in dx_a.iter().zip(&dx_n).enumerate() {
        assert!(
            (got - num).abs() < 5e-3,
            "ln dx[{i}]: analytical {} vs numerical {num}",
            got
        );
    }
    let dg_n = central(&g_data, &|gs: &[f32]| forward_loss(&x_data, gs, &b_data));
    for (i, (&got, &num)) in dg_a.iter().zip(&dg_n).enumerate() {
        assert!(
            (got - num).abs() < 5e-3,
            "ln dg[{i}]: analytical {} vs numerical {num}",
            got
        );
    }
    let db_n = central(&b_data, &|bs: &[f32]| forward_loss(&x_data, &g_data, bs));
    for (i, (&got, &num)) in db_a.iter().zip(&db_n).enumerate() {
        assert!(
            (got - num).abs() < 5e-3,
            "ln db[{i}]: analytical {} vs numerical {num}",
            got
        );
    }
}
20220
#[test]
fn dense_sce_mean_gradient_matches_numerical() {
    use rlx_ir::Philox4x32;
    let (bs, k_in, c) = (4usize, 3usize, 5usize);
    let mut rng = Philox4x32::new(7);
    let mut x = vec![0f32; bs * k_in];
    rng.fill_normal(&mut x);
    let mut w_init = vec![0f32; k_in * c];
    rng.fill_normal(&mut w_init);
    let mut b_init = vec![0f32; c];
    rng.fill_normal(&mut b_init);
    // Class index per sample stored as f32: 0, 1, 2, 3.
    let labels: Vec<f32> = (0..bs).map(|i| (i % c) as f32).collect();

    // Forward: logits = x @ w + b; loss = Sum over the batch of per-sample SCE.
    // NOTE(review): despite the test name, this reduces with `Sum`, not `Mean`.
    let f = DType::F32;
    let mut fwd = Graph::new("dense_sce");
    let xn = fwd.input("x", Shape::new(&[bs, k_in], f));
    let lb = fwd.input("labels", Shape::new(&[bs], f));
    let wp = fwd.param("w", Shape::new(&[k_in, c], f));
    let bp = fwd.param("b", Shape::new(&[c], f));
    let mm = fwd.matmul(xn, wp, Shape::new(&[bs, c], f));
    let logits = fwd.binary(BinaryOp::Add, mm, bp, Shape::new(&[bs, c], f));
    let loss_per = fwd.softmax_cross_entropy_with_logits(logits, lb);
    let loss = fwd.add_node(
        Op::Reduce {
            op: ReduceOp::Sum,
            axes: vec![0],
            keep_dim: false,
        },
        vec![loss_per],
        Shape::from_dims(&[], f),
    );
    fwd.set_outputs(vec![loss]);

    let bwd_graph = rlx_opt::autodiff::grad_with_loss(&fwd, &[wp, bp]);
    let d_out = bwd_graph
        .nodes()
        .iter()
        .find_map(|node| match &node.op {
            Op::Input { name } if name == "d_output" => Some(node.id),
            _ => None,
        })
        .expect("d_output input");

    let (sched, mut arena) = prepare(
        &bwd_graph,
        &[
            (xn, &x),
            (lb, &labels),
            (wp, &w_init),
            (bp, &b_init),
            (d_out, &[1.0]),
        ],
    );
    execute_thunks(&sched, arena.raw_buf_mut());

    // Backward graph outputs: [loss, grad_w, grad_b].
    let outs = &bwd_graph.outputs;
    let (loss_id, gw_id, gb_id) = (outs[0], outs[1], outs[2]);
    let loss_actual = read_arena(&arena, loss_id, 1)[0];
    let gw_actual = read_arena(&arena, gw_id, k_in * c);
    let gb_actual = read_arena(&arena, gb_id, c);

    // Separate forward-only executable used as the numerical oracle.
    let mut fwd_arena = crate::arena::Arena::from_plan(rlx_opt::memory::plan_memory(&fwd));
    let fwd_sched = compile_thunks(&fwd, &fwd_arena);
    write_arena(&mut fwd_arena, xn, &x);
    write_arena(&mut fwd_arena, lb, &labels);

    let run_loss = |arena: &mut crate::arena::Arena, w: &[f32], b: &[f32]| -> f32 {
        write_arena(arena, wp, w);
        write_arena(arena, bp, b);
        execute_thunks(&fwd_sched, arena.raw_buf_mut());
        read_arena(arena, loss, 1)[0]
    };

    let loss_check = run_loss(&mut fwd_arena, &w_init, &b_init);
    assert!(
        (loss_actual - loss_check).abs() < 1e-4,
        "loss mismatch: bwd graph {loss_actual} vs fwd-only {loss_check}"
    );

    let eps = 1e-3f32;
    let mut w_perturbed = w_init.clone();
    let gw_numerical: Vec<f32> = (0..w_init.len())
        .map(|i| {
            let saved = w_perturbed[i];
            w_perturbed[i] = saved + eps;
            let lp = run_loss(&mut fwd_arena, &w_perturbed, &b_init);
            w_perturbed[i] = saved - eps;
            let lm = run_loss(&mut fwd_arena, &w_perturbed, &b_init);
            w_perturbed[i] = saved;
            (lp - lm) / (2.0 * eps)
        })
        .collect();
    for (i, (a, n)) in gw_actual.iter().zip(&gw_numerical).enumerate() {
        assert!(
            (a - n).abs() < 5e-3,
            "grad_w[{i}]: analytical {a} vs numerical {n}"
        );
    }

    let mut b_perturbed = b_init.clone();
    let gb_numerical: Vec<f32> = (0..b_init.len())
        .map(|i| {
            let saved = b_perturbed[i];
            b_perturbed[i] = saved + eps;
            let lp = run_loss(&mut fwd_arena, &w_init, &b_perturbed);
            b_perturbed[i] = saved - eps;
            let lm = run_loss(&mut fwd_arena, &w_init, &b_perturbed);
            b_perturbed[i] = saved;
            (lp - lm) / (2.0 * eps)
        })
        .collect();
    for (i, (a, n)) in gb_actual.iter().zip(&gb_numerical).enumerate() {
        assert!(
            (a - n).abs() < 5e-3,
            "grad_b[{i}]: analytical {a} vs numerical {n}"
        );
    }
}
20361
#[test]
fn dense_sce_mean_reduce_gradient_matches_numerical() {
    use rlx_ir::Philox4x32;
    let bs = 3usize;
    let k_in = 2usize;
    let c = 4usize;
    let mut rng = Philox4x32::new(13);
    let mut x = vec![0f32; bs * k_in];
    rng.fill_normal(&mut x);
    let mut w_init = vec![0f32; k_in * c];
    rng.fill_normal(&mut w_init);
    // Class index per sample stored as f32: 0, 1, 2.
    let labels: Vec<f32> = (0..bs).map(|i| (i % c) as f32).collect();

    // Forward: logits = x @ w; loss = Mean over the batch of per-sample SCE.
    let f = DType::F32;
    let mut fwd = Graph::new("dense_sce_mean");
    let xn = fwd.input("x", Shape::new(&[bs, k_in], f));
    let lb = fwd.input("labels", Shape::new(&[bs], f));
    let wp = fwd.param("w", Shape::new(&[k_in, c], f));
    let mm = fwd.matmul(xn, wp, Shape::new(&[bs, c], f));
    let loss_per = fwd.softmax_cross_entropy_with_logits(mm, lb);
    let loss = fwd.add_node(
        Op::Reduce {
            op: ReduceOp::Mean,
            axes: vec![0],
            keep_dim: false,
        },
        vec![loss_per],
        Shape::from_dims(&[], f),
    );
    fwd.set_outputs(vec![loss]);

    let bwd_graph = rlx_opt::autodiff::grad_with_loss(&fwd, &[wp]);
    let d_out = bwd_graph
        .nodes()
        .iter()
        .find(|n| matches!(&n.op, Op::Input { name } if name == "d_output"))
        .map(|n| n.id)
        .unwrap();

    let (sched, mut arena) = prepare(
        &bwd_graph,
        &[(xn, &x), (lb, &labels), (wp, &w_init), (d_out, &[1.0])],
    );
    execute_thunks(&sched, arena.raw_buf_mut());

    // Backward graph outputs: [loss, grad_w].
    let outs = &bwd_graph.outputs;
    let loss_id = outs[0];
    let gw_id = outs[1];
    let loss_actual = read_arena(&arena, loss_id, 1)[0];
    let gw_actual = read_arena(&arena, gw_id, k_in * c);

    // Separate forward-only executable used as the numerical oracle.
    let plan = rlx_opt::memory::plan_memory(&fwd);
    let mut fwd_arena = crate::arena::Arena::from_plan(plan);
    let fwd_sched = compile_thunks(&fwd, &fwd_arena);
    write_arena(&mut fwd_arena, xn, &x);
    write_arena(&mut fwd_arena, lb, &labels);

    let run_loss = |arena: &mut crate::arena::Arena, w: &[f32]| -> f32 {
        write_arena(arena, wp, w);
        execute_thunks(&fwd_sched, arena.raw_buf_mut());
        read_arena(arena, loss, 1)[0]
    };

    // The backward graph also emits the forward loss. It must agree with a
    // forward-only evaluation at the unperturbed weights.
    let loss_check = run_loss(&mut fwd_arena, &w_init);
    assert!(
        (loss_actual - loss_check).abs() < 1e-4,
        "mean reduce loss mismatch: bwd graph {loss_actual} vs fwd-only {loss_check}"
    );

    let eps = 1e-3f32;
    let mut wp_p = w_init.clone();
    let mut gw_num = vec![0f32; w_init.len()];
    for i in 0..w_init.len() {
        let s = wp_p[i];
        wp_p[i] = s + eps;
        let lp = run_loss(&mut fwd_arena, &wp_p);
        wp_p[i] = s - eps;
        let lm = run_loss(&mut fwd_arena, &wp_p);
        wp_p[i] = s;
        gw_num[i] = (lp - lm) / (2.0 * eps);
    }
    for (i, (a, n)) in gw_actual.iter().zip(&gw_num).enumerate() {
        assert!((a - n).abs() < 5e-3, "mean reduce grad_w[{i}]: {a} vs {n}");
    }
}
#[test]
fn tinyconv_full_gradient_matches_numerical() {
    use rlx_ir::Philox4x32;

    /// Central-difference gradient of `loss_at` around `base`, one element at a time.
    fn central_diff(base: &[f32], eps: f32, mut loss_at: impl FnMut(&[f32]) -> f32) -> Vec<f32> {
        let mut probe = base.to_vec();
        (0..base.len())
            .map(|i| {
                let saved = probe[i];
                probe[i] = saved + eps;
                let lp = loss_at(&probe);
                probe[i] = saved - eps;
                let lm = loss_at(&probe);
                probe[i] = saved;
                (lp - lm) / (2.0 * eps)
            })
            .collect()
    }

    /// Asserts `analytical[i]` is within 5e-3 of `numerical[i]` for every `i`.
    fn assert_grads_close(name: &str, analytical: &[f32], numerical: &[f32]) {
        for (i, (&a, &num)) in analytical.iter().zip(numerical).enumerate() {
            assert!((a - num).abs() < 5e-3, "{name}[{i}]: {} vs {num}", a);
        }
    }

    // Network: [n, c_in, h, w_in] -> valid kh x kw conv (c_mid) + bias -> ReLU
    // -> 2x2 stride-2 max-pool -> flatten -> dense (num_classes) -> SCE, mean.
    let n = 1usize;
    let c_in = 1usize;
    let h = 6usize;
    let w_in = 6usize;
    let c_mid = 2usize;
    let kh = 3;
    let kw = 3;
    let h1 = h - kh + 1;
    let w1 = w_in - kw + 1;
    let h2 = h1 / 2;
    let w2 = w1 / 2;
    let flat = c_mid * h2 * w2;
    let num_classes = 3usize;

    let mut rng = Philox4x32::new(31);
    let mut x = vec![0f32; n * c_in * h * w_in];
    rng.fill_normal(&mut x);
    let mut wc = vec![0f32; c_mid * c_in * kh * kw];
    rng.fill_normal(&mut wc);
    for v in wc.iter_mut() {
        *v *= 0.2;
    }
    // Large positive conv bias (5.0, 5.1). Presumably this keeps the conv
    // pre-activations positive, away from the ReLU kink, so finite
    // differences stay smooth.
    let bc: Vec<f32> = (0..c_mid).map(|i| 5.0 + 0.1 * i as f32).collect();
    let mut wfc = vec![0f32; flat * num_classes];
    rng.fill_normal(&mut wfc);
    for v in wfc.iter_mut() {
        *v *= 0.5;
    }
    let mut bfc = vec![0f32; num_classes];
    rng.fill_normal(&mut bfc);
    let labels: Vec<f32> = vec![1.0];

    let f = DType::F32;
    let mut fwd = Graph::new("tinyconv");
    let xn = fwd.input("x", Shape::new(&[n, c_in, h, w_in], f));
    let lb = fwd.input("labels", Shape::new(&[n], f));
    let wcp = fwd.param("wc", Shape::new(&[c_mid, c_in, kh, kw], f));
    let bcp = fwd.param("bc", Shape::new(&[c_mid], f));
    let wfp = fwd.param("wfc", Shape::new(&[flat, num_classes], f));
    let bfp = fwd.param("bfc", Shape::new(&[num_classes], f));

    let conv = fwd.add_node(
        Op::Conv {
            kernel_size: vec![kh, kw],
            stride: vec![1, 1],
            padding: vec![0, 0],
            dilation: vec![1, 1],
            groups: 1,
        },
        vec![xn, wcp],
        Shape::new(&[n, c_mid, h1, w1], f),
    );
    // Conv bias: reshape [c_mid] -> [1, c_mid, 1, 1], then expand to the conv output shape.
    let bc_4d = fwd.add_node(
        Op::Reshape {
            new_shape: vec![1, c_mid as i64, 1, 1],
        },
        vec![bcp],
        Shape::new(&[1, c_mid, 1, 1], f),
    );
    let bc_expanded = fwd.add_node(
        Op::Expand {
            target_shape: vec![n as i64, c_mid as i64, h1 as i64, w1 as i64],
        },
        vec![bc_4d],
        Shape::new(&[n, c_mid, h1, w1], f),
    );
    let conv_b = fwd.binary(
        BinaryOp::Add,
        conv,
        bc_expanded,
        Shape::new(&[n, c_mid, h1, w1], f),
    );
    let relu = fwd.activation(Activation::Relu, conv_b, Shape::new(&[n, c_mid, h1, w1], f));
    let pool = fwd.add_node(
        Op::Pool {
            kind: ReduceOp::Max,
            kernel_size: vec![2, 2],
            stride: vec![2, 2],
            padding: vec![0, 0],
        },
        vec![relu],
        Shape::new(&[n, c_mid, h2, w2], f),
    );
    let flatn = fwd.add_node(
        Op::Reshape {
            new_shape: vec![n as i64, flat as i64],
        },
        vec![pool],
        Shape::new(&[n, flat], f),
    );
    let mm = fwd.matmul(flatn, wfp, Shape::new(&[n, num_classes], f));
    let logits = fwd.binary(BinaryOp::Add, mm, bfp, Shape::new(&[n, num_classes], f));
    let loss_per = fwd.softmax_cross_entropy_with_logits(logits, lb);
    let loss = fwd.add_node(
        Op::Reduce {
            op: ReduceOp::Mean,
            axes: vec![0],
            keep_dim: false,
        },
        vec![loss_per],
        Shape::from_dims(&[], f),
    );
    fwd.set_outputs(vec![loss]);

    let bwd_graph = rlx_opt::autodiff::grad_with_loss(&fwd, &[wcp, bcp, wfp, bfp]);
    let d_out = bwd_graph
        .nodes()
        .iter()
        .find(|n| matches!(&n.op, Op::Input { name } if name == "d_output"))
        .map(|n| n.id)
        .unwrap();

    let (sched, mut arena) = prepare(
        &bwd_graph,
        &[
            (xn, &x),
            (lb, &labels),
            (wcp, &wc),
            (bcp, &bc),
            (wfp, &wfc),
            (bfp, &bfc),
            (d_out, &[1.0]),
        ],
    );
    execute_thunks(&sched, arena.raw_buf_mut());

    // Backward graph outputs: [loss, g_wc, g_bc, g_wfc, g_bfc].
    let outs = &bwd_graph.outputs;
    let loss_actual = read_arena(&arena, outs[0], 1)[0];
    let g_wc = read_arena(&arena, outs[1], wc.len());
    let g_bc = read_arena(&arena, outs[2], bc.len());
    let g_wfc = read_arena(&arena, outs[3], wfc.len());
    let g_bfc = read_arena(&arena, outs[4], bfc.len());

    // Separate forward-only executable used as the numerical oracle.
    let plan = rlx_opt::memory::plan_memory(&fwd);
    let mut fwd_arena = crate::arena::Arena::from_plan(plan);
    let fwd_sched = compile_thunks(&fwd, &fwd_arena);
    write_arena(&mut fwd_arena, xn, &x);
    write_arena(&mut fwd_arena, lb, &labels);

    // Writes all four parameter tensors, then runs the forward pass.
    let run_loss = |arena: &mut crate::arena::Arena,
                    wc: &[f32],
                    bc: &[f32],
                    wfc: &[f32],
                    bfc: &[f32]|
     -> f32 {
        write_arena(arena, wcp, wc);
        write_arena(arena, bcp, bc);
        write_arena(arena, wfp, wfc);
        write_arena(arena, bfp, bfc);
        execute_thunks(&fwd_sched, arena.raw_buf_mut());
        read_arena(arena, loss, 1)[0]
    };

    let loss_check = run_loss(&mut fwd_arena, &wc, &bc, &wfc, &bfc);
    assert!(
        (loss_actual - loss_check).abs() < 1e-4,
        "tinyconv loss mismatch: bwd {loss_actual} vs fwd {loss_check}"
    );

    let eps = 1e-3f32;
    let num_wc = central_diff(&wc, eps, |p| run_loss(&mut fwd_arena, p, &bc, &wfc, &bfc));
    assert_grads_close("g_wc", &g_wc, &num_wc);
    let num_bc = central_diff(&bc, eps, |p| run_loss(&mut fwd_arena, &wc, p, &wfc, &bfc));
    assert_grads_close("g_bc", &g_bc, &num_bc);
    let num_wfc = central_diff(&wfc, eps, |p| run_loss(&mut fwd_arena, &wc, &bc, p, &bfc));
    assert_grads_close("g_wfc", &g_wfc, &num_wfc);
    let num_bfc = central_diff(&bfc, eps, |p| run_loss(&mut fwd_arena, &wc, &bc, &wfc, p));
    assert_grads_close("g_bfc", &g_bfc, &num_bfc);
}
20741
#[test]
fn narrow_rope_skips_when_narrow_has_multiple_consumers() {
    let f = DType::F32;
    let mut g = Graph::new("nr_skip");
    let qkv = g.input("qkv", Shape::new(&[16, 8, 192], f));
    let cos = g.input("cos", Shape::new(&[16], f));
    let sin = g.input("sin", Shape::new(&[16], f));
    // `q` is consumed twice: by the RoPE and by a ReLU that is also an output.
    let q = g.narrow_(qkv, 2, 0, 64);
    let q_rope = g.rope(q, cos, sin, 16);
    let q_dup = g.activation(rlx_ir::op::Activation::Relu, q, Shape::new(&[16, 8, 64], f));
    g.set_outputs(vec![q_rope, q_dup]);

    let arena = crate::arena::Arena::from_plan(rlx_opt::memory::plan_memory(&g));
    let sched = compile_thunks(&g, &arena);

    // At least one standalone Narrow thunk must survive compilation.
    let narrow_kept = sched
        .thunks
        .iter()
        .any(|t| matches!(t, Thunk::Narrow { .. }));
    assert!(
        narrow_kept,
        "Narrow with multiple consumers must NOT be fused away"
    );
}
20772
#[test]
fn custom_fn_forward_inlines_body() {
    let s = Shape::new(&[3], DType::F32);

    // Body graph: y = x + [1.0, 1.0, 1.0] (constant stored as little-endian f32 bytes).
    let mut body = Graph::new("addone_body");
    let x = body.input("x", s.clone());
    let ones = body.add_node(
        Op::Constant {
            data: 1.0_f32.to_le_bytes().repeat(3),
        },
        vec![],
        s.clone(),
    );
    let y = body.binary(BinaryOp::Add, x, ones, s.clone());
    body.set_outputs(vec![y]);

    // Outer graph wraps the body as a custom function with no VJP/JVP overrides.
    let mut g = Graph::new("custom_fn_outer");
    let xin = g.input("x_in", s.clone());
    let cf = g.custom_fn(vec![xin], body, None, None);
    g.set_outputs(vec![cf]);

    let xs = vec![10.0_f32, 20.0, 30.0];
    let (sched, mut arena) = prepare(&g, &[(xin, &xs)]);
    execute_thunks(&sched, arena.raw_buf_mut());
    assert_eq!(read_arena(&arena, cf, 3), vec![11.0, 21.0, 31.0]);
}
20808
20809 fn find_named(graph: &Graph, want: &str) -> NodeId {
20811 for n in graph.nodes() {
20812 let name = match &n.op {
20813 Op::Input { name } | Op::Param { name } => Some(name.as_str()),
20814 _ => None,
20815 };
20816 if name == Some(want) {
20817 return n.id;
20818 }
20819 }
20820 panic!("no node named {want:?} in graph");
20821 }
20822
#[test]
fn custom_fn_vjp_overrides_natural_gradient() {
    use rlx_opt::autodiff::grad_with_loss;
    let s = Shape::new(&[1], DType::F32);

    // Forward body is the identity: y = x.
    let mut fwd = Graph::new("id_fwd");
    let x = fwd.input("x", s.clone());
    fwd.set_outputs(vec![x]);

    // Custom VJP returns dx = 2 * dy, deliberately different from the
    // identity's natural gradient, so the override is observable.
    let mut vjp_g = Graph::new("id_vjp");
    let _x_p = vjp_g.input("x", s.clone());
    let _y_p = vjp_g.input("primal_output", s.clone());
    let dy = vjp_g.input("d_output", s.clone());
    let two = vjp_g.add_node(
        Op::Constant {
            data: 2.0_f32.to_le_bytes().to_vec(),
        },
        vec![],
        s.clone(),
    );
    let dx = vjp_g.binary(BinaryOp::Mul, dy, two, s.clone());
    vjp_g.set_outputs(vec![dx]);

    let mut g = Graph::new("outer");
    let xp = g.param("x", s.clone());
    let cf = g.custom_fn(vec![xp], fwd, Some(vjp_g), None);
    g.set_outputs(vec![cf]);

    let bwd = grad_with_loss(&g, &[xp]);
    assert_eq!(bwd.outputs.len(), 2, "expect [loss, dx]");

    let (xb, dout) = (find_named(&bwd, "x"), find_named(&bwd, "d_output"));
    let (sched, mut arena) = prepare(&bwd, &[(xb, &[7.0]), (dout, &[1.0])]);
    execute_thunks(&sched, arena.raw_buf_mut());
    let scalar = |id: NodeId| read_arena(&arena, id, 1)[0];
    let loss = scalar(bwd.outputs[0]);
    let dx_v = scalar(bwd.outputs[1]);
    assert!((loss - 7.0).abs() < 1e-6, "loss should be 7.0");
    assert!(
        (dx_v - 2.0).abs() < 1e-6,
        "vjp override should yield dx=2.0, got {} (natural autodiff would give 1.0)",
        dx_v
    );
}
20865
#[test]
fn custom_fn_vjp_two_inputs_matches_mul_autodiff() {
    use rlx_opt::autodiff::grad_with_loss;
    let s = Shape::new(&[1], DType::F32);

    // Forward body: y = a * b.
    let mut fwd = Graph::new("mul_fwd");
    let a_f = fwd.input("a", s.clone());
    let b_f = fwd.input("b", s.clone());
    let y_f = fwd.binary(BinaryOp::Mul, a_f, b_f, s.clone());
    fwd.set_outputs(vec![y_f]);

    // Hand-written VJP (product rule): da = b * dy, db = a * dy.
    let mut vjp_g = Graph::new("mul_vjp");
    let a_v = vjp_g.input("a", s.clone());
    let b_v = vjp_g.input("b", s.clone());
    let _y_v = vjp_g.input("primal_output", s.clone());
    let dy_v = vjp_g.input("d_output", s.clone());
    let da = vjp_g.binary(BinaryOp::Mul, b_v, dy_v, s.clone());
    let db = vjp_g.binary(BinaryOp::Mul, a_v, dy_v, s.clone());
    vjp_g.set_outputs(vec![da, db]);

    let mut g = Graph::new("outer");
    let ap = g.param("a", s.clone());
    let bp = g.param("b", s.clone());
    let cf = g.custom_fn(vec![ap, bp], fwd, Some(vjp_g), None);
    g.set_outputs(vec![cf]);

    let bwd = grad_with_loss(&g, &[ap, bp]);
    assert_eq!(bwd.outputs.len(), 3, "expect [loss, da, db]");

    let (ab, bb, dout) = (
        find_named(&bwd, "a"),
        find_named(&bwd, "b"),
        find_named(&bwd, "d_output"),
    );
    let (sched, mut arena) = prepare(&bwd, &[(ab, &[3.0]), (bb, &[5.0]), (dout, &[1.0])]);
    execute_thunks(&sched, arena.raw_buf_mut());
    let loss = read_arena(&arena, bwd.outputs[0], 1);
    let scalar = |id: NodeId| read_arena(&arena, id, 1)[0];
    let da_v = scalar(bwd.outputs[1]);
    let db_v = scalar(bwd.outputs[2]);
    assert!((loss[0] - 15.0).abs() < 1e-5);
    assert!(
        (da_v - 5.0).abs() < 1e-5,
        "da should be b=5.0, got {}",
        da_v
    );
    assert!(
        (db_v - 3.0).abs() < 1e-5,
        "db should be a=3.0, got {}",
        db_v
    );
}
20919
20920 #[test]
20923 fn custom_fn_jvp_overrides_natural_tangent() {
20924 use rlx_opt::autodiff_fwd::jvp;
20925 let s = Shape::new(&[1], DType::F32);
20926
20927 let mut fwd = Graph::new("id_fwd");
20928 let x = fwd.input("x", s.clone());
20929 fwd.set_outputs(vec![x]);
20930
20931 let mut jvp_g = Graph::new("id_jvp");
20932 let _x_p = jvp_g.input("x", s.clone());
20933 let tx = jvp_g.input("tangent_0", s.clone());
20934 let two_data: Vec<u8> = 2.0_f32.to_le_bytes().to_vec();
20935 let two = jvp_g.add_node(Op::Constant { data: two_data }, vec![], s.clone());
20936 let ty = jvp_g.binary(BinaryOp::Mul, tx, two, s.clone());
20937 jvp_g.set_outputs(vec![ty]);
20938
20939 let mut g = Graph::new("outer");
20940 let xin = g.input("x_in", s.clone());
20941 let cf = g.custom_fn(vec![xin], fwd, None, Some(jvp_g));
20942 g.set_outputs(vec![cf]);
20943
20944 let fwd_g = jvp(&g, &[xin]);
20945 assert_eq!(fwd_g.outputs.len(), 2, "expect [primal_y, tangent_y]");
20946
20947 let xb = find_named(&fwd_g, "x_in");
20948 let tan = find_named(&fwd_g, "tangent_x_in");
20949 let (sched, mut arena) = prepare(&fwd_g, &[(xb, &[7.0]), (tan, &[1.0])]);
20950 execute_thunks(&sched, arena.raw_buf_mut());
20951 let y = read_arena(&arena, fwd_g.outputs[0], 1);
20952 let ty_v = read_arena(&arena, fwd_g.outputs[1], 1);
20953 assert!((y[0] - 7.0).abs() < 1e-6);
20954 assert!(
20955 (ty_v[0] - 2.0).abs() < 1e-6,
20956 "jvp override should yield t_y=2.0 (natural autodiff would give 1.0), got {}",
20957 ty_v[0]
20958 );
20959 }
20960
20961 #[test]
20966 fn c64_dtype_storage_layout() {
20967 assert_eq!(
20968 DType::C64.size_bytes(),
20969 8,
20970 "C64 should be 8 bytes (f32 real + f32 imag)"
20971 );
20972 assert!(DType::C64.is_complex());
20973 assert!(!DType::C64.is_float());
20974
20975 let s = Shape::new(&[2], DType::C64);
20977 assert_eq!(s.size_bytes().unwrap(), 16);
20978 }
20979
20980 fn run_c64_binary(op: BinaryOp, a: &[(f32, f32)], b: &[(f32, f32)]) -> Vec<(f32, f32)> {
20987 let n = a.len();
20988 let s = Shape::new(&[n], DType::C64);
20989 let mut g = Graph::new("c64_bin");
20990 let in_a = g.input("a", s.clone());
20991 let in_b = g.input("b", s.clone());
20992 let out = g.binary(op, in_a, in_b, s.clone());
20993 g.set_outputs(vec![out]);
20994
20995 let plan = rlx_opt::memory::plan_memory(&g);
20996 let mut arena = crate::arena::Arena::from_plan(plan);
20997 let sched = compile_thunks(&g, &arena);
20998
20999 let a_off = arena.byte_offset(in_a);
21000 let b_off = arena.byte_offset(in_b);
21001 let out_off = arena.byte_offset(out);
21002 let buf = arena.raw_buf_mut();
21004 unsafe {
21005 let pa = buf.as_mut_ptr().add(a_off) as *mut f32;
21006 let pb = buf.as_mut_ptr().add(b_off) as *mut f32;
21007 for (i, &(re, im)) in a.iter().enumerate() {
21008 *pa.add(2 * i) = re;
21009 *pa.add(2 * i + 1) = im;
21010 }
21011 for (i, &(re, im)) in b.iter().enumerate() {
21012 *pb.add(2 * i) = re;
21013 *pb.add(2 * i + 1) = im;
21014 }
21015 }
21016 execute_thunks(&sched, arena.raw_buf_mut());
21017 let raw_out: Vec<f32> = unsafe {
21018 let p = arena.raw_buf().as_ptr().add(out_off) as *const f32;
21019 (0..(2 * n)).map(|i| *p.add(i)).collect()
21020 };
21021 (0..n)
21022 .map(|i| (raw_out[2 * i], raw_out[2 * i + 1]))
21023 .collect()
21024 }
21025
21026 #[track_caller]
21027 fn assert_close_c(got: (f32, f32), expected: (f32, f32), tol: f32, label: &str) {
21028 let dr = (got.0 - expected.0).abs();
21029 let di = (got.1 - expected.1).abs();
21030 assert!(
21031 dr < tol && di < tol,
21032 "[{label}] got ({:+.4}, {:+.4}), expected ({:+.4}, {:+.4})",
21033 got.0,
21034 got.1,
21035 expected.0,
21036 expected.1
21037 );
21038 }
21039
21040 #[test]
21041 fn c64_binary_add_matches_complex_arithmetic() {
21042 let a = [(1.0_f32, 2.0_f32), (3.0_f32, -1.0_f32)];
21043 let b = [(4.0_f32, -1.0_f32), (0.5_f32, 0.5_f32)];
21044 let out = run_c64_binary(BinaryOp::Add, &a, &b);
21045 assert_close_c(out[0], (5.0, 1.0), 1e-6, "add[0]");
21046 assert_close_c(out[1], (3.5, -0.5), 1e-6, "add[1]");
21047 }
21048
21049 #[test]
21050 fn c64_binary_sub_matches_complex_arithmetic() {
21051 let a = [(5.0_f32, 1.0_f32)];
21052 let b = [(2.0_f32, 3.0_f32)];
21053 let out = run_c64_binary(BinaryOp::Sub, &a, &b);
21054 assert_close_c(out[0], (3.0, -2.0), 1e-6, "sub");
21055 }
21056
21057 #[test]
21058 fn c64_binary_mul_matches_complex_arithmetic() {
21059 let a = [(1.0_f32, 2.0_f32)];
21061 let b = [(3.0_f32, 4.0_f32)];
21062 let out = run_c64_binary(BinaryOp::Mul, &a, &b);
21063 assert_close_c(out[0], (-5.0, 10.0), 1e-5, "mul");
21064 }
21065
21066 #[test]
21067 fn c64_binary_div_matches_complex_arithmetic() {
21068 let a = [(1.0_f32, 2.0_f32)];
21072 let b = [(3.0_f32, 4.0_f32)];
21073 let out = run_c64_binary(BinaryOp::Div, &a, &b);
21074 assert_close_c(out[0], (0.44, 0.08), 1e-5, "div");
21075 }
21076
21077 #[test]
21078 fn c64_binary_mul_identity_one_is_no_op() {
21079 let a = [(3.5_f32, -1.25_f32), (-2.0_f32, 7.0_f32)];
21081 let b = [(1.0_f32, 0.0_f32), (1.0_f32, 0.0_f32)];
21082 let out = run_c64_binary(BinaryOp::Mul, &a, &b);
21083 assert_close_c(out[0], a[0], 1e-6, "mul·1[0]");
21084 assert_close_c(out[1], a[1], 1e-6, "mul·1[1]");
21085 }
21086
21087 #[test]
21088 fn c64_binary_mul_by_i_rotates_90_degrees() {
21089 let a = [(1.0_f32, 0.0_f32)];
21091 let b = [(0.0_f32, 1.0_f32)];
21092 let out = run_c64_binary(BinaryOp::Mul, &a, &b);
21093 assert_close_c(out[0], (0.0, 1.0), 1e-6, "1·i");
21094 }
21095
21096 #[test]
21097 fn c64_binary_div_by_self_gives_unity() {
21098 let a = [(2.5_f32, -1.5_f32), (-0.7_f32, 4.2_f32)];
21099 let out = run_c64_binary(BinaryOp::Div, &a, &a);
21100 assert_close_c(out[0], (1.0, 0.0), 1e-5, "div_self[0]");
21101 assert_close_c(out[1], (1.0, 0.0), 1e-5, "div_self[1]");
21102 }
21103
21104 #[test]
21105 #[should_panic(expected = "C64: complex max/min/pow")]
21106 fn c64_binary_max_is_rejected_at_lowering() {
21107 run_c64_binary(BinaryOp::Max, &[(1.0_f32, 2.0_f32)], &[(3.0_f32, 4.0_f32)]);
21108 }
21109
21110 fn run_c64_activation(act: Activation, a: &[(f32, f32)]) -> Vec<(f32, f32)> {
21111 let n = a.len();
21112 let s = Shape::new(&[n], DType::C64);
21113 let mut g = Graph::new("c64_act");
21114 let in_a = g.input("a", s.clone());
21115 let out = g.activation(act, in_a, s.clone());
21116 g.set_outputs(vec![out]);
21117 let plan = rlx_opt::memory::plan_memory(&g);
21118 let mut arena = crate::arena::Arena::from_plan(plan);
21119 let sched = compile_thunks(&g, &arena);
21120 let a_off = arena.byte_offset(in_a);
21121 let out_off = arena.byte_offset(out);
21122 let buf = arena.raw_buf_mut();
21123 unsafe {
21124 let pa = buf.as_mut_ptr().add(a_off) as *mut f32;
21125 for (i, &(re, im)) in a.iter().enumerate() {
21126 *pa.add(2 * i) = re;
21127 *pa.add(2 * i + 1) = im;
21128 }
21129 }
21130 execute_thunks(&sched, arena.raw_buf_mut());
21131 let raw: Vec<f32> = unsafe {
21132 let p = arena.raw_buf().as_ptr().add(out_off) as *const f32;
21133 (0..(2 * n)).map(|i| *p.add(i)).collect()
21134 };
21135 (0..n).map(|i| (raw[2 * i], raw[2 * i + 1])).collect()
21136 }
21137
21138 #[test]
21139 fn c64_activation_neg_negates_both_components() {
21140 let inp = [(3.5_f32, -1.25_f32), (-2.0_f32, 0.0_f32)];
21141 let out = run_c64_activation(Activation::Neg, &inp);
21142 assert_close_c(out[0], (-3.5, 1.25), 1e-6, "neg[0]");
21143 assert_close_c(out[1], (2.0, 0.0), 1e-6, "neg[1]");
21144 }
21145
21146 #[test]
21147 fn c64_activation_exp_matches_euler() {
21148 let inp = [(0.0_f32, std::f32::consts::PI), (1.0_f32, 0.0_f32)];
21151 let out = run_c64_activation(Activation::Exp, &inp);
21152 assert_close_c(out[0], (-1.0, 0.0), 1e-5, "exp(iπ)");
21153 assert_close_c(out[1], (std::f32::consts::E, 0.0), 1e-5, "exp(1)");
21154 }
21155
21156 #[test]
21157 fn c64_activation_log_matches_principal_branch() {
21158 let inp = [(1.0_f32, 0.0_f32), (0.0_f32, 1.0_f32), (-1.0_f32, 0.0_f32)];
21162 let out = run_c64_activation(Activation::Log, &inp);
21163 assert_close_c(out[0], (0.0, 0.0), 1e-5, "log(1)");
21164 assert_close_c(out[1], (0.0, std::f32::consts::FRAC_PI_2), 1e-5, "log(i)");
21165 assert_close_c(out[2], (0.0, std::f32::consts::PI), 1e-5, "log(-1)");
21166 }
21167
21168 #[test]
21169 fn c64_activation_sqrt_squared_recovers_input() {
21170 let inp = [(4.0_f32, 0.0_f32), (3.0_f32, 4.0_f32)];
21173 let roots = run_c64_activation(Activation::Sqrt, &inp);
21174 assert_close_c(roots[0], (2.0, 0.0), 1e-5, "sqrt(4)");
21176 assert_close_c(roots[1], (2.0, 1.0), 1e-5, "sqrt(3+4i)");
21177 }
21178
21179 #[test]
21180 #[should_panic(expected = "no natural complex extension")]
21181 fn c64_activation_relu_is_rejected_at_lowering() {
21182 run_c64_activation(Activation::Relu, &[(1.0_f32, 2.0_f32)]);
21183 }
21184
21185 fn run_complex_norm_sq(z: &[(f32, f32)]) -> Vec<f32> {
21189 let n = z.len();
21190 let mut g = Graph::new("cns_fwd");
21191 let in_z = g.input("z", Shape::new(&[n], DType::C64));
21192 let out = g.complex_norm_sq(in_z);
21193 g.set_outputs(vec![out]);
21194 let plan = rlx_opt::memory::plan_memory(&g);
21195 let mut arena = crate::arena::Arena::from_plan(plan);
21196 let sched = compile_thunks(&g, &arena);
21197 let z_off = arena.byte_offset(in_z);
21198 let out_off = arena.byte_offset(out);
21199 let buf = arena.raw_buf_mut();
21200 unsafe {
21201 let pz = buf.as_mut_ptr().add(z_off) as *mut f32;
21202 for (i, &(re, im)) in z.iter().enumerate() {
21203 *pz.add(2 * i) = re;
21204 *pz.add(2 * i + 1) = im;
21205 }
21206 }
21207 execute_thunks(&sched, arena.raw_buf_mut());
21208 unsafe {
21209 let p = arena.raw_buf().as_ptr().add(out_off) as *const f32;
21210 (0..n).map(|i| *p.add(i)).collect()
21211 }
21212 }
21213
21214 fn run_complex_norm_sq_bwd(z: &[(f32, f32)], g: &[f32]) -> Vec<(f32, f32)> {
21216 let n = z.len();
21217 let mut gr = Graph::new("cns_bwd");
21218 let in_z = gr.input("z", Shape::new(&[n], DType::C64));
21219 let in_g = gr.input("g", Shape::new(&[n], DType::F32));
21220 let out = gr.complex_norm_sq_backward(in_z, in_g);
21221 gr.set_outputs(vec![out]);
21222 let plan = rlx_opt::memory::plan_memory(&gr);
21223 let mut arena = crate::arena::Arena::from_plan(plan);
21224 let sched = compile_thunks(&gr, &arena);
21225 let z_off = arena.byte_offset(in_z);
21226 let g_off = arena.byte_offset(in_g);
21227 let out_off = arena.byte_offset(out);
21228 let buf = arena.raw_buf_mut();
21229 unsafe {
21230 let pz = buf.as_mut_ptr().add(z_off) as *mut f32;
21231 let pg = buf.as_mut_ptr().add(g_off) as *mut f32;
21232 for (i, &(re, im)) in z.iter().enumerate() {
21233 *pz.add(2 * i) = re;
21234 *pz.add(2 * i + 1) = im;
21235 }
21236 for (i, &v) in g.iter().enumerate() {
21237 *pg.add(i) = v;
21238 }
21239 }
21240 execute_thunks(&sched, arena.raw_buf_mut());
21241 unsafe {
21242 let p = arena.raw_buf().as_ptr().add(out_off) as *const f32;
21243 (0..n).map(|i| (*p.add(2 * i), *p.add(2 * i + 1))).collect()
21244 }
21245 }
21246
21247 #[test]
21248 fn complex_norm_sq_matches_textbook() {
21249 let z = [(3.0_f32, 4.0_f32), (1.0_f32, 0.0_f32), (0.0_f32, 0.0_f32)];
21253 let out = run_complex_norm_sq(&z);
21254 assert!((out[0] - 25.0).abs() < 1e-5);
21255 assert!((out[1] - 1.0).abs() < 1e-6);
21256 assert!(out[2].abs() < 1e-6);
21257 }
21258
21259 #[test]
21260 fn complex_norm_sq_backward_matches_wirtinger_formula() {
21261 let z = [(3.0_f32, 4.0_f32), (1.5_f32, -2.5_f32)];
21263 let g = [1.0_f32, 1.0_f32];
21264 let dz = run_complex_norm_sq_bwd(&z, &g);
21265 assert_close_c(dz[0], z[0], 1e-6, "dz[0] = g·z[0]");
21266 assert_close_c(dz[1], z[1], 1e-6, "dz[1] = g·z[1]");
21267 }
21268
21269 #[test]
21270 fn complex_norm_sq_backward_scales_with_upstream() {
21271 let z = [(2.0_f32, 1.0_f32), (-1.0_f32, 3.0_f32)];
21273 let g = [0.5_f32, -2.0_f32];
21274 let dz = run_complex_norm_sq_bwd(&z, &g);
21275 assert_close_c(dz[0], (1.0, 0.5), 1e-6, "g=0.5 · (2,1)");
21276 assert_close_c(dz[1], (2.0, -6.0), 1e-6, "g=-2 · (-1,3)");
21277 }
21278
21279 #[test]
21284 fn custom_fn_multi_extracts_each_subgraph_output() {
21285 use rlx_ir::ops::special::MultiOutputHandle;
21286
21287 let _ = MultiOutputHandle {
21288 source: NodeId(0),
21289 sub_shapes: vec![],
21290 offsets: vec![],
21291 }; let mut body = Graph::new("multi_body");
21295 let s3 = Shape::new(&[3], DType::F32);
21296 let x = body.input("x", s3.clone());
21297 let x_sq = body.binary(BinaryOp::Mul, x, x, s3.clone());
21298 let two = body.add_node(
21299 Op::Constant {
21300 data: vec![
21301 2.0_f32.to_le_bytes(),
21302 2.0_f32.to_le_bytes(),
21303 2.0_f32.to_le_bytes(),
21304 ]
21305 .into_iter()
21306 .flatten()
21307 .collect(),
21308 },
21309 vec![],
21310 s3.clone(),
21311 );
21312 let two_x = body.binary(BinaryOp::Mul, two, x, s3.clone());
21313 body.set_outputs(vec![x_sq, two_x]);
21314
21315 let mut outer = Graph::new("multi_outer");
21317 let in_x = outer.input("xin", s3.clone());
21318 let handle = outer.custom_fn_multi(vec![in_x], body);
21319 assert_eq!(handle.n_outputs(), 2);
21320 let out0 = handle.output(&mut outer, 0); let out1 = handle.output(&mut outer, 1); outer.set_outputs(vec![out0, out1]);
21323
21324 let plan = rlx_opt::memory::plan_memory(&outer);
21325 let mut arena = crate::arena::Arena::from_plan(plan);
21326 let sched = compile_thunks(&outer, &arena);
21327 let xin_off = arena.byte_offset(in_x);
21328 let out0_off = arena.byte_offset(out0);
21329 let out1_off = arena.byte_offset(out1);
21330 let xs = [1.0_f32, 2.0, 3.0];
21331 unsafe {
21332 let p = arena.raw_buf_mut().as_mut_ptr().add(xin_off) as *mut f32;
21333 for (i, &v) in xs.iter().enumerate() {
21334 *p.add(i) = v;
21335 }
21336 }
21337 execute_thunks(&sched, arena.raw_buf_mut());
21338 let out0_v: Vec<f32> = unsafe {
21339 let p = arena.raw_buf().as_ptr().add(out0_off) as *const f32;
21340 (0..3).map(|i| *p.add(i)).collect()
21341 };
21342 let out1_v: Vec<f32> = unsafe {
21343 let p = arena.raw_buf().as_ptr().add(out1_off) as *const f32;
21344 (0..3).map(|i| *p.add(i)).collect()
21345 };
21346 for i in 0..3 {
21348 assert!(
21349 (out0_v[i] - xs[i] * xs[i]).abs() < 1e-5,
21350 "out0[{i}] = {} != x² = {}",
21351 out0_v[i],
21352 xs[i] * xs[i]
21353 );
21354 assert!(
21355 (out1_v[i] - 2.0 * xs[i]).abs() < 1e-5,
21356 "out1[{i}] = {} != 2x = {}",
21357 out1_v[i],
21358 2.0 * xs[i]
21359 );
21360 }
21361 }
21362
21363 #[test]
21364 fn complex_norm_sq_gradient_matches_finite_difference() {
21365 let z = [(3.0_f32, 4.0_f32)];
21367 let eps = 1e-3_f32;
21368 let v0 = run_complex_norm_sq(&z)[0];
21369 let z_pert = [(3.0_f32 + eps, 4.0_f32)];
21370 let v1 = run_complex_norm_sq(&z_pert)[0];
21371 let fd_re = (v1 - v0) / eps;
21372 let analytic_re = 2.0 * z[0].0;
21373 assert!((fd_re - analytic_re).abs() < 1e-2);
21374
21375 let z_pert_im = [(3.0_f32, 4.0_f32 + eps)];
21377 let v2 = run_complex_norm_sq(&z_pert_im)[0];
21378 let fd_im = (v2 - v0) / eps;
21379 let analytic_im = 2.0 * z[0].1;
21380 assert!((fd_im - analytic_im).abs() < 1e-2);
21381
21382 let dz = run_complex_norm_sq_bwd(&z, &[1.0_f32]);
21388 assert!((2.0 * dz[0].0 - analytic_re).abs() < 1e-5);
21389 assert!((2.0 * dz[0].1 - analytic_im).abs() < 1e-5);
21390 }
21391
21392 #[test]
21397 fn binary_full_5d_mid_singleton_broadcast() {
21398 let bh = 2usize;
21399 let h = 3;
21400 let w = 4;
21401 let f = DType::F32;
21402
21403 let mut g = Graph::new("bcast_5d");
21404 let lhs = g.input("lhs", Shape::new(&[bh, h, w, h, w], f));
21405 let rhs = g.input("rhs", Shape::new(&[bh, h, w, 1, w], f));
21407 let out = g.binary(BinaryOp::Add, lhs, rhs, Shape::new(&[bh, h, w, h, w], f));
21408 g.set_outputs(vec![out]);
21409
21410 let lhs_data: Vec<f32> = (0..bh * h * w * h * w).map(|i| i as f32 * 0.01).collect();
21412 let rhs_data: Vec<f32> = (0..bh * h * w * w)
21413 .map(|i| (i as f32 + 100.0) * 0.01)
21414 .collect();
21415
21416 let mut expected = vec![0f32; bh * h * w * h * w];
21418 for b_ in 0..bh {
21419 for hq in 0..h {
21420 for wq in 0..w {
21421 for hk in 0..h {
21422 for wk in 0..w {
21423 let li = (((b_ * h + hq) * w + wq) * h + hk) * w + wk;
21424 let ri = ((b_ * h + hq) * w + wq) * w + wk;
21426 expected[li] = lhs_data[li] + rhs_data[ri];
21427 }
21428 }
21429 }
21430 }
21431 }
21432
21433 let plan = rlx_opt::memory::plan_memory(&g);
21434 let mut arena = crate::arena::Arena::from_plan(plan);
21435 let sched = compile_thunks(&g, &arena);
21436 let lhs_off = arena.byte_offset(lhs);
21437 let rhs_off = arena.byte_offset(rhs);
21438 let out_off = arena.byte_offset(out);
21439 let buf = arena.raw_buf_mut();
21440 unsafe {
21441 let p = buf.as_mut_ptr().add(lhs_off) as *mut f32;
21442 for (i, &v) in lhs_data.iter().enumerate() {
21443 *p.add(i) = v;
21444 }
21445 let p = buf.as_mut_ptr().add(rhs_off) as *mut f32;
21446 for (i, &v) in rhs_data.iter().enumerate() {
21447 *p.add(i) = v;
21448 }
21449 }
21450 execute_thunks(&sched, arena.raw_buf_mut());
21451 let actual: Vec<f32> = unsafe {
21452 let p = arena.raw_buf().as_ptr().add(out_off) as *const f32;
21453 (0..bh * h * w * h * w).map(|i| *p.add(i)).collect()
21454 };
21455
21456 let mut max_diff = 0f32;
21458 let mut max_idx = 0;
21459 for i in 0..actual.len() {
21460 let d = (actual[i] - expected[i]).abs();
21461 if d > max_diff {
21462 max_diff = d;
21463 max_idx = i;
21464 }
21465 }
21466 assert!(
21467 max_diff < 1e-6,
21468 "5D mid-shape singleton broadcast wrong: max |Δ| = {max_diff} at idx {max_idx} \
21469 (actual={}, expected={})",
21470 actual[max_idx],
21471 expected[max_idx]
21472 );
21473 }
21474
21475 #[test]
21476 fn layer_norm2d_and_conv_transpose2d_kernels() {
21477 let mut out = vec![0f32; 8];
21478 crate::kernels::layer_norm2d_nchw(
21479 &[1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0],
21480 &[1.0, 1.0],
21481 &[0.0, 0.0],
21482 &mut out,
21483 1,
21484 2,
21485 2,
21486 2,
21487 1e-5,
21488 );
21489 let mean0: f32 = (1.0 + 3.0) / 2.0;
21490 assert!((out[0] - mean0).abs() > 0.1);
21491
21492 let mut up = vec![0f32; 4];
21493 crate::kernels::conv_transpose2d_nchw(
21494 &[2.0],
21495 &[1.0, 0.0, 0.0, 1.0],
21496 &mut up,
21497 1,
21498 1,
21499 1,
21500 1,
21501 1,
21502 2,
21503 2,
21504 2,
21505 2,
21506 2,
21507 2,
21508 0,
21509 0,
21510 1,
21511 1,
21512 1,
21513 );
21514 assert!((up[0] - 2.0).abs() < 1e-5);
21515 assert!((up[3] - 2.0).abs() < 1e-5);
21516 }
21517}