1use crate::arena::Arena;
23use crate::kernels;
24use rlx_ir::op::{Activation, BinaryOp, ReduceOp};
25use rlx_ir::{Graph, NodeId, Op};
26use std::collections::HashMap;
27
28pub struct ExternalBuffers<'a> {
30 pub buffers: HashMap<NodeId, &'a [f32]>,
32}
33
34pub fn execute(graph: &Graph, arena: &mut Arena, external: &ExternalBuffers) {
42 let schedule: Vec<NodeId> = arena.schedule().to_vec();
43 for &node_id in &schedule {
44 let node = graph.node(node_id);
45
46 match &node.op {
47 Op::Input { .. } | Op::Param { .. } | Op::Constant { .. } => {}
50
51 Op::FusedMatMulBiasAct { activation } => {
53 let input_id = node.inputs[0];
54 let weight_id = node.inputs[1];
55 let bias_id = node.inputs[2];
56
57 let input = get_data(arena, external, input_id);
58 let weight = get_data(arena, external, weight_id);
59 let bias = get_data(arena, external, bias_id);
60 let output = get_output(arena, node_id);
61
62 let shape = &node.shape;
64 let n = shape.dim(shape.rank() - 1).unwrap_static();
65 let m = shape.num_elements().unwrap() / n;
66 let k = input.len() / m;
67
68 matmul(input, weight, output, m, k, n);
72
73 match activation {
75 Some(Activation::Gelu) => kernels::par_bias_gelu(output, bias, m, n),
76 Some(Activation::Silu) => {
77 crate::blas::bias_add(output, bias, m, n);
78 kernels::silu_inplace(output);
79 }
80 _ => crate::blas::bias_add(output, bias, m, n),
81 }
82 }
83
84 Op::FusedResidualLN { has_bias, eps } => {
86 let x_id = node.inputs[0];
87 let residual_id = node.inputs[1];
88 let h = node.shape.dim(node.shape.rank() - 1).unwrap_static();
89 let zero_bias = vec![0f32; h];
90 let (gamma_id, beta_id, bias_slice) = if *has_bias {
91 let b = get_data(arena, external, node.inputs[2]);
92 (node.inputs[3], node.inputs[4], b)
93 } else {
94 (node.inputs[2], node.inputs[3], zero_bias.as_slice())
95 };
96
97 let x = get_data(arena, external, x_id);
98 let residual = get_data(arena, external, residual_id);
99 let gamma = get_data(arena, external, gamma_id);
100 let beta = get_data(arena, external, beta_id);
101 let output = get_output(arena, node_id);
102
103 let n = x.len() / h;
104
105 let x_ptr = x.as_ptr() as usize;
107 let r_ptr = residual.as_ptr() as usize;
108 let o_ptr = output.as_mut_ptr() as usize;
109 let bi_ptr = bias_slice.as_ptr() as usize;
110 let g_ptr = gamma.as_ptr() as usize;
111 let b_ptr = beta.as_ptr() as usize;
112 let e = *eps;
113 crate::pool::par_for(n, 4, &|off, cnt| unsafe {
114 let x_s =
115 std::slice::from_raw_parts((x_ptr as *const f32).add(off * h), cnt * h);
116 let r_s =
117 std::slice::from_raw_parts((r_ptr as *const f32).add(off * h), cnt * h);
118 let o_s =
119 std::slice::from_raw_parts_mut((o_ptr as *mut f32).add(off * h), cnt * h);
120 let bi = std::slice::from_raw_parts(bi_ptr as *const f32, h);
121 let g = std::slice::from_raw_parts(g_ptr as *const f32, h);
122 let b = std::slice::from_raw_parts(b_ptr as *const f32, h);
123 kernels::residual_bias_layer_norm(x_s, r_s, bi, g, b, o_s, cnt, h, e);
124 });
125 }
126
127 Op::FusedResidualRmsNorm { has_bias, eps } => {
129 let x_id = node.inputs[0];
130 let residual_id = node.inputs[1];
131 let h = node.shape.dim(node.shape.rank() - 1).unwrap_static();
132 let zero_bias = vec![0f32; h];
133 let (gamma_id, beta_id, bias_slice) = if *has_bias {
134 let b = get_data(arena, external, node.inputs[2]);
135 (node.inputs[3], node.inputs[4], b)
136 } else {
137 (node.inputs[2], node.inputs[3], zero_bias.as_slice())
138 };
139
140 let x = get_data(arena, external, x_id);
141 let residual = get_data(arena, external, residual_id);
142 let gamma = get_data(arena, external, gamma_id);
143 let beta = get_data(arena, external, beta_id);
144 let output = get_output(arena, node_id);
145
146 let n = x.len() / h;
147
148 let x_ptr = x.as_ptr() as usize;
149 let r_ptr = residual.as_ptr() as usize;
150 let o_ptr = output.as_mut_ptr() as usize;
151 let bi_ptr = bias_slice.as_ptr() as usize;
152 let g_ptr = gamma.as_ptr() as usize;
153 let b_ptr = beta.as_ptr() as usize;
154 let e = *eps;
155 crate::pool::par_for(n, 4, &|off, cnt| unsafe {
156 let x_s =
157 std::slice::from_raw_parts((x_ptr as *const f32).add(off * h), cnt * h);
158 let r_s =
159 std::slice::from_raw_parts((r_ptr as *const f32).add(off * h), cnt * h);
160 let o_s =
161 std::slice::from_raw_parts_mut((o_ptr as *mut f32).add(off * h), cnt * h);
162 let bi = std::slice::from_raw_parts(bi_ptr as *const f32, h);
163 let g = std::slice::from_raw_parts(g_ptr as *const f32, h);
164 let b = std::slice::from_raw_parts(b_ptr as *const f32, h);
165 kernels::residual_bias_rms_norm(x_s, r_s, bi, g, b, o_s, cnt, h, e);
166 });
167 }
168
169 Op::MatMul => {
171 let lhs = get_data(arena, external, node.inputs[0]);
172 let rhs = get_data(arena, external, node.inputs[1]);
173 let output = get_output(arena, node_id);
174
175 let shape = &node.shape;
176 let lhs_shape = &graph.node(node.inputs[0]).shape;
177 let rhs_shape = &graph.node(node.inputs[1]).shape;
178 let n = shape.dim(shape.rank() - 1).unwrap_static();
179 let out_m_inner = shape.dim(shape.rank() - 2).unwrap_static();
180 let k = lhs_shape.dim(lhs_shape.rank() - 1).unwrap_static();
181
182 let total = shape.num_elements().unwrap();
185 let per_batch_out = out_m_inner * n;
186 let batches = total / per_batch_out;
187
188 if batches == 1 {
189 matmul(lhs, rhs, output, out_m_inner, k, n);
190 } else {
191 let lhs_batched =
192 lhs_shape.num_elements().unwrap_or(0) == batches * out_m_inner * k;
193 let rhs_batched = rhs_shape.num_elements().unwrap_or(0) == batches * k * n;
194 for b in 0..batches {
195 let l_off = if lhs_batched { b * out_m_inner * k } else { 0 };
196 let r_off = if rhs_batched { b * k * n } else { 0 };
197 let o_off = b * out_m_inner * n;
198 let l_slice = &lhs[l_off..l_off + out_m_inner * k];
199 let r_slice = &rhs[r_off..r_off + k * n];
200 let o_slice = &mut output[o_off..o_off + out_m_inner * n];
201 matmul(l_slice, r_slice, o_slice, out_m_inner, k, n);
202 }
203 }
204 }
205
206 Op::Binary(op) => {
208 let lhs = get_data(arena, external, node.inputs[0]);
209 let rhs = get_data(arena, external, node.inputs[1]);
210 let output = get_output(arena, node_id);
211 let len = output.len();
212 let rhs_len = rhs.len();
213
214 if matches!(op, BinaryOp::Add) && rhs_len < len && len.is_multiple_of(rhs_len) {
216 output.copy_from_slice(lhs);
217 crate::blas::bias_add(output, rhs, len / rhs_len, rhs_len);
218 } else if rhs_len == len {
219 for i in 0..len {
220 output[i] = binary_op(*op, lhs[i], rhs[i]);
221 }
222 } else {
223 for i in 0..len {
224 output[i] = binary_op(*op, lhs[i], rhs[i % rhs_len]);
225 }
226 }
227 }
228
229 Op::Activation(act) => {
231 let input = get_data(arena, external, node.inputs[0]);
232 let output = get_output(arena, node_id);
233 output.copy_from_slice(input);
234 let zeros = vec![0f32; node.shape.dim(node.shape.rank() - 1).unwrap_static()];
235 let m = output.len() / zeros.len();
236 let n = zeros.len();
237 match act {
238 Activation::Gelu => kernels::par_bias_gelu(output, &zeros, m, n),
239 Activation::Silu => kernels::silu_inplace(output),
240 Activation::Relu => {
241 for v in output.iter_mut() {
242 *v = v.max(0.0);
243 }
244 }
245 Activation::Exp => {
246 for v in output.iter_mut() {
247 *v = v.exp();
248 }
249 }
250 Activation::Sqrt => {
251 for v in output.iter_mut() {
252 *v = v.sqrt();
253 }
254 }
255 Activation::Neg => {
256 for v in output.iter_mut() {
257 *v = -*v;
258 }
259 }
260 Activation::Tanh => {
261 for v in output.iter_mut() {
262 *v = v.tanh();
263 }
264 }
265 Activation::Sigmoid => {
266 for v in output.iter_mut() {
267 *v = 1.0 / (1.0 + (-*v).exp());
268 }
269 }
270 _ => {}
271 }
272 }
273
274 Op::Gather { axis } => {
276 let table = get_data(arena, external, node.inputs[0]);
277 let indices = get_data(arena, external, node.inputs[1]);
278 let output = get_output(arena, node_id);
279
280 let table_shape = &graph.node(node.inputs[0]).shape;
281 let _out_shape = &node.shape;
282
283 if *axis == 0 {
285 let trailing: usize = (1..table_shape.rank())
286 .map(|i| table_shape.dim(i).unwrap_static())
287 .product();
288 for (i, &idx_f32) in indices.iter().enumerate() {
289 let idx = idx_f32 as usize;
290 let src = idx * trailing;
291 let dst = i * trailing;
292 output[dst..dst + trailing].copy_from_slice(&table[src..src + trailing]);
293 }
294 } else {
295 output.fill(0.0);
297 }
298 }
299
300 Op::Narrow { axis, start, len } => {
302 let input = get_data(arena, external, node.inputs[0]);
303 let output = get_output(arena, node_id);
304 let in_shape = &graph.node(node.inputs[0]).shape;
305
306 let rank = in_shape.rank();
307 let outer: usize = (0..*axis)
308 .map(|i| in_shape.dim(i).unwrap_static())
309 .product::<usize>()
310 .max(1);
311 let inner: usize = (*axis + 1..rank)
312 .map(|i| in_shape.dim(i).unwrap_static())
313 .product::<usize>()
314 .max(1);
315 let in_axis_size = in_shape.dim(*axis).unwrap_static();
316
317 for o in 0..outer {
318 for s in 0..*len {
319 let src_off = o * in_axis_size * inner + (*start + s) * inner;
320 let dst_off = o * len * inner + s * inner;
321 output[dst_off..dst_off + inner]
322 .copy_from_slice(&input[src_off..src_off + inner]);
323 }
324 }
325 }
326
327 Op::Transpose { perm } => {
329 let input = get_data(arena, external, node.inputs[0]);
330 let output = get_output(arena, node_id);
331 let in_shape = &graph.node(node.inputs[0]).shape;
332 let rank = in_shape.rank();
333
334 let in_dims: Vec<usize> =
335 (0..rank).map(|i| in_shape.dim(i).unwrap_static()).collect();
336 let out_dims: Vec<usize> = perm.iter().map(|&i| in_dims[i]).collect();
337
338 let mut in_strides = vec![1usize; rank];
341 for i in (0..rank - 1).rev() {
342 in_strides[i] = in_strides[i + 1] * in_dims[i + 1];
343 }
344 let mut out_strides = vec![1usize; rank];
345 for i in (0..rank - 1).rev() {
346 out_strides[i] = out_strides[i + 1] * out_dims[i + 1];
347 }
348
349 let total = output.len();
350 for flat_out in 0..total {
351 let mut in_flat = 0;
352 for d in 0..rank {
353 let coord = (flat_out / out_strides[d]) % out_dims[d];
355 in_flat += coord * in_strides[perm[d]];
357 }
358 output[flat_out] = input[in_flat];
359 }
360 }
361
362 Op::Concat { axis } => {
364 let output = get_output(arena, node_id);
365 let out_shape = &node.shape;
366 let rank = out_shape.rank();
367
368 let outer: usize = (0..*axis)
369 .map(|i| out_shape.dim(i).unwrap_static())
370 .product::<usize>()
371 .max(1);
372 let inner: usize = (*axis + 1..rank)
373 .map(|i| out_shape.dim(i).unwrap_static())
374 .product::<usize>()
375 .max(1);
376
377 let mut dst_off = 0;
378 for o in 0..outer {
379 for &inp_id in &node.inputs {
380 let inp = get_data(arena, external, inp_id);
381 let inp_shape = &graph.node(inp_id).shape;
382 let inp_axis = inp_shape.dim(*axis).unwrap_static();
383 let chunk = inp_axis * inner;
384 let src_off = o * chunk;
385 output[dst_off..dst_off + chunk]
386 .copy_from_slice(&inp[src_off..src_off + chunk]);
387 dst_off += chunk;
388 }
389 }
390 }
391
392 Op::Reshape { .. } | Op::Expand { .. } => {
394 let input = get_data(arena, external, node.inputs[0]);
395 let output = get_output(arena, node_id);
396 output[..input.len()].copy_from_slice(input);
397 }
398
399 Op::LayerNorm { eps, .. } => {
401 let input = get_data(arena, external, node.inputs[0]);
402 let gamma = get_data(arena, external, node.inputs[1]);
403 let beta = get_data(arena, external, node.inputs[2]);
404 let output = get_output(arena, node_id);
405 let h = node.shape.dim(node.shape.rank() - 1).unwrap_static();
406 let n = input.len() / h;
407 for row in 0..n {
408 let base = row * h;
409 kernels::layer_norm_row(
410 &input[base..base + h],
411 gamma,
412 beta,
413 &mut output[base..base + h],
414 h,
415 *eps,
416 );
417 }
418 }
419
420 Op::GroupNorm { num_groups, eps } => {
421 let input = get_data(arena, external, node.inputs[0]);
422 let gamma = get_data(arena, external, node.inputs[1]);
423 let beta = get_data(arena, external, node.inputs[2]);
424 let output = get_output(arena, node_id);
425 let n = node.shape.dim(0).unwrap_static();
426 let c = node.shape.dim(1).unwrap_static();
427 let h = node.shape.dim(2).unwrap_static();
428 let w = node.shape.dim(3).unwrap_static();
429 kernels::group_norm_nchw(input, gamma, beta, output, n, c, h, w, *num_groups, *eps);
430 }
431
432 Op::ResizeNearest2x => {
433 let input = get_data(arena, external, node.inputs[0]);
434 let output = get_output(arena, node_id);
435 let n = node.shape.dim(0).unwrap_static();
436 let c = node.shape.dim(1).unwrap_static();
437 let h = node.shape.dim(2).unwrap_static() / 2;
438 let w = node.shape.dim(3).unwrap_static() / 2;
439 let in_plane = c * h * w;
440 let out_plane = c * h * 2 * w * 2;
441 for ni in 0..n {
442 kernels::resize_nearest_2x_nchw(
443 &input[ni * in_plane..(ni + 1) * in_plane],
444 &mut output[ni * out_plane..(ni + 1) * out_plane],
445 c,
446 h,
447 w,
448 );
449 }
450 }
451
452 Op::AxialRope2d {
453 end_x,
454 end_y,
455 head_dim,
456 num_heads,
457 theta,
458 repeat_factor,
459 } => {
460 let input = get_data(arena, external, node.inputs[0]);
461 let output = get_output(arena, node_id);
462 let batch = node.shape.dim(0).unwrap_static();
463 let seq = node.shape.dim(1).unwrap_static();
464 let plane = seq * node.shape.dim(2).unwrap_static();
465 for bi in 0..batch {
466 let rotated = rlx_ir::ops::axial_rope2d::apply_axial_rope2d(
467 &input[bi * plane..(bi + 1) * plane],
468 *num_heads,
469 seq,
470 *head_dim,
471 *end_x,
472 *end_y,
473 *theta,
474 *repeat_factor,
475 );
476 output[bi * plane..(bi + 1) * plane].copy_from_slice(&rotated);
477 }
478 }
479
480 Op::Softmax { axis } => {
482 let input = get_data(arena, external, node.inputs[0]);
483 let output = get_output(arena, node_id);
484 output.copy_from_slice(input);
485 let rank = node.shape.rank();
486 let ax = if *axis < 0 {
487 (rank as i32 + axis) as usize
488 } else {
489 *axis as usize
490 };
491 let cols = node.shape.dim(ax).unwrap_static();
492 let rows = output.len() / cols;
493 crate::naive::softmax(output, rows, cols);
494 }
495
496 Op::Attention {
498 num_heads,
499 head_dim,
500 mask_kind,
501 score_scale,
502 attn_logit_softcap,
503 } => {
504 let q = get_data(arena, external, node.inputs[0]);
505 let k = get_data(arena, external, node.inputs[1]);
506 let v = get_data(arena, external, node.inputs[2]);
507 let mask: &[f32] = if matches!(
511 mask_kind,
512 rlx_ir::op::MaskKind::Custom | rlx_ir::op::MaskKind::Bias
513 ) {
514 get_data(arena, external, node.inputs[3])
515 } else {
516 &[]
517 };
518 let output = get_output(arena, node_id);
519
520 let q_shape = &graph.node(node.inputs[0]).shape;
521 let k_shape = &graph.node(node.inputs[1]).shape;
522 let hs = num_heads * head_dim;
523 let scale = score_scale.unwrap_or((*head_dim as f32).powf(-0.5));
524 let (batch_size, s_q) = if q_shape.rank() >= 3 {
525 (
526 q_shape.dim(0).unwrap_static(),
527 q_shape.dim(1).unwrap_static(),
528 )
529 } else {
530 (1, q_shape.dim(0).unwrap_static())
531 };
532 let s_k = if k_shape.rank() >= 3 {
537 k_shape.dim(1).unwrap_static()
538 } else {
539 k_shape.dim(0).unwrap_static()
540 };
541 let q_offset = s_k.saturating_sub(s_q);
542
543 let q_buf_len = s_q * head_dim;
545 let k_buf_len = s_k * head_dim;
546 let mut q_head = vec![0f32; q_buf_len];
547 let mut k_head = vec![0f32; k_buf_len];
548 let mut v_head = vec![0f32; k_buf_len];
549 let mut scores = vec![0f32; s_q * s_k];
550 let mut out_head = vec![0f32; q_buf_len];
551
552 for bi in 0..batch_size {
553 for hi in 0..*num_heads {
554 for si in 0..s_q {
556 let off = bi * s_q * hs + si * hs + hi * head_dim;
557 q_head[si * head_dim..(si + 1) * head_dim]
558 .copy_from_slice(&q[off..off + head_dim]);
559 }
560 for si in 0..s_k {
562 let off = bi * s_k * hs + si * hs + hi * head_dim;
563 k_head[si * head_dim..(si + 1) * head_dim]
564 .copy_from_slice(&k[off..off + head_dim]);
565 v_head[si * head_dim..(si + 1) * head_dim]
566 .copy_from_slice(&v[off..off + head_dim]);
567 }
568 if s_q.max(s_k) <= 32 {
571 for qi in 0..s_q {
572 for ki in 0..s_k {
573 let q_off = qi * head_dim;
574 let k_off = ki * head_dim;
575 #[cfg(target_arch = "aarch64")]
576 let mut dot;
577 #[cfg(not(target_arch = "aarch64"))]
578 let mut dot = 0f32;
579 #[cfg(target_arch = "aarch64")]
580 unsafe {
581 use std::arch::aarch64::*;
582 let chunks = head_dim / 4;
583 let mut acc = vdupq_n_f32(0.0);
584 for c in 0..chunks {
585 let vq = vld1q_f32(q_head.as_ptr().add(q_off + c * 4));
586 let vk = vld1q_f32(k_head.as_ptr().add(k_off + c * 4));
587 acc = vfmaq_f32(acc, vq, vk);
588 }
589 dot = vaddvq_f32(acc);
590 for d in (chunks * 4)..*head_dim {
591 dot += q_head[q_off + d] * k_head[k_off + d];
592 }
593 }
594 #[cfg(not(target_arch = "aarch64"))]
595 {
596 for d in 0..*head_dim {
597 dot += q_head[q_off + d] * k_head[k_off + d];
598 }
599 }
600 scores[qi * s_k + ki] = dot * scale;
601 }
602 }
603 } else {
604 crate::blas::sgemm_bt(
605 &q_head,
606 &k_head,
607 &mut scores,
608 s_q,
609 *head_dim,
610 s_k,
611 scale,
612 );
613 }
614 match mask_kind {
619 rlx_ir::op::MaskKind::None => {}
620 rlx_ir::op::MaskKind::Causal => {
621 for qi in 0..s_q {
622 let abs_q = q_offset + qi;
623 for ki in (abs_q + 1)..s_k {
624 scores[qi * s_k + ki] = -1e9;
625 }
626 }
627 }
628 rlx_ir::op::MaskKind::SlidingWindow(w) => {
629 for qi in 0..s_q {
630 let abs_q = q_offset + qi;
631 let lo = abs_q.saturating_sub(*w);
632 for ki in 0..s_k {
633 if ki < lo || ki > abs_q {
634 scores[qi * s_k + ki] = -1e9;
635 }
636 }
637 }
638 }
639 rlx_ir::op::MaskKind::Custom => {
640 if mask.len() >= (bi + 1) * s_k {
641 let m = &mask[bi * s_k..(bi + 1) * s_k];
642 for qi in 0..s_q {
643 for ki in 0..s_k {
644 if m[ki] < 0.5 {
645 scores[qi * s_k + ki] = -1e9;
646 }
647 }
648 }
649 }
650 }
651 rlx_ir::op::MaskKind::Bias => {
652 let per_bh = s_q * s_k;
656 let need = (bi * *num_heads + hi + 1) * per_bh;
657 if mask.len() >= need {
658 let bias_off = (bi * *num_heads + hi) * per_bh;
659 let b = &mask[bias_off..bias_off + per_bh];
660 for i in 0..per_bh {
661 scores[i] += b[i];
662 }
663 }
664 }
665 }
666 if let Some(cap) = attn_logit_softcap {
667 if *cap > 0.0 {
668 for s in scores.iter_mut() {
669 *s = cap * (*s / cap).tanh();
670 }
671 }
672 }
673 crate::naive::softmax(&mut scores, s_q, s_k);
674 if s_q.max(s_k) <= 32 {
676 out_head.fill(0.0);
677 for qi in 0..s_q {
678 for ki in 0..s_k {
679 let sc = scores[qi * s_k + ki];
680 if sc > 1e-8 {
681 let v_off = ki * head_dim;
682 let o_off = qi * head_dim;
683 #[cfg(target_arch = "aarch64")]
684 unsafe {
685 use std::arch::aarch64::*;
686 let vsc = vdupq_n_f32(sc);
687 let chunks = head_dim / 4;
688 for c in 0..chunks {
689 let off = c * 4;
690 let vo =
691 vld1q_f32(out_head.as_ptr().add(o_off + off));
692 let vv =
693 vld1q_f32(v_head.as_ptr().add(v_off + off));
694 vst1q_f32(
695 out_head.as_mut_ptr().add(o_off + off),
696 vfmaq_f32(vo, vsc, vv),
697 );
698 }
699 }
700 #[cfg(not(target_arch = "aarch64"))]
701 for d in 0..*head_dim {
702 out_head[o_off + d] += sc * v_head[v_off + d];
703 }
704 }
705 }
706 }
707 } else {
708 crate::blas::sgemm(
709 &scores,
710 &v_head,
711 &mut out_head,
712 s_q,
713 s_k,
714 *head_dim,
715 );
716 }
717 for si in 0..s_q {
719 let off = bi * s_q * hs + si * hs + hi * head_dim;
720 output[off..off + head_dim]
721 .copy_from_slice(&out_head[si * head_dim..(si + 1) * head_dim]);
722 }
723 }
724 }
725 }
726
727 Op::Rope { head_dim, n_rot } => {
744 let head_dim = *head_dim;
745 let n_rot = *n_rot;
746 let x = get_data(arena, external, node.inputs[0]);
747 let cos_cache = get_data(arena, external, node.inputs[1]);
748 let sin_cache = get_data(arena, external, node.inputs[2]);
749 let x_shape = &graph.node(node.inputs[0]).shape;
750 let output = get_output(arena, node_id);
751 output.copy_from_slice(x);
752
753 let rot_half = n_rot / 2;
754 let tab_half = head_dim / 2;
755 let total = output.len();
756 let num_chunks = total / head_dim;
757
758 let cos_rows = cos_cache.len() / tab_half.max(1);
761 let (s_dim, heads_per_seq): (usize, usize) = {
762 let rank = x_shape.rank();
763 if rank == 0 {
764 (1, 1)
765 } else {
766 let last = if x_shape.dim(rank - 1).is_static() {
767 x_shape.dim(rank - 1).unwrap_static()
768 } else {
769 head_dim
770 };
771 if rank >= 3 && last > head_dim && last.is_multiple_of(head_dim) {
772 let s = if x_shape.dim(rank - 2).is_static() {
774 x_shape.dim(rank - 2).unwrap_static()
775 } else {
776 1
777 };
778 (s, last / head_dim)
779 } else if rank >= 4 && last == head_dim {
780 let s = if x_shape.dim(rank - 2).is_static() {
782 x_shape.dim(rank - 2).unwrap_static()
783 } else {
784 1
785 };
786 (s, 1)
787 } else if rank >= 3 && last == head_dim {
788 let s = if x_shape.dim(rank - 2).is_static() {
790 x_shape.dim(rank - 2).unwrap_static()
791 } else {
792 1
793 };
794 (s, 1)
795 } else {
796 (cos_rows.max(1), 1)
798 }
799 }
800 };
801
802 if std::env::var("RLX_ROPE_DEBUG").is_ok() {
803 eprintln!(
804 "[rope] shape={:?} num_chunks={num_chunks} cos_rows={cos_rows} s_dim={s_dim} heads_per_seq={heads_per_seq}",
805 x_shape.dims()
806 );
807 }
808 for chunk in 0..num_chunks {
809 let off = chunk * head_dim;
810 let pos = if heads_per_seq > 1 {
816 (chunk / heads_per_seq) % s_dim
817 } else {
818 chunk % s_dim
819 };
820 let pos = if cos_rows == 1 {
823 0
824 } else {
825 pos.min(cos_rows.saturating_sub(1))
826 };
827 if std::env::var("RLX_ROPE_DEBUG").is_ok() && chunk < 4 {
828 eprintln!("[rope] chunk={chunk} pos={pos}");
829 }
830 let cos_off = pos * tab_half;
831
832 for i in 0..rot_half {
833 let cos_v = cos_cache[cos_off + i];
834 let sin_v = sin_cache[cos_off + i];
835 let x1 = output[off + i];
836 let x2 = output[off + rot_half + i];
837 output[off + i] = x1 * cos_v - x2 * sin_v;
838 output[off + rot_half + i] = x2 * cos_v + x1 * sin_v;
839 }
840 output[(n_rot + off)..(head_dim + off)]
841 .copy_from_slice(&x[(n_rot + off)..(head_dim + off)]);
842 }
843 }
844
845 Op::Compare(cmp) => {
847 let lhs = get_data(arena, external, node.inputs[0]);
848 let rhs = get_data(arena, external, node.inputs[1]);
849 let output = get_output(arena, node_id);
850 let rhs_len = rhs.len();
851 for i in 0..output.len() {
852 let a = lhs[i];
853 let b = rhs[i % rhs_len];
854 output[i] = if compare_op(*cmp, a, b) { 1.0 } else { 0.0 };
855 }
856 }
857
858 Op::Where => {
860 let cond = get_data(arena, external, node.inputs[0]);
861 let on_true = get_data(arena, external, node.inputs[1]);
862 let on_false = get_data(arena, external, node.inputs[2]);
863 let output = get_output(arena, node_id);
864 for i in 0..output.len() {
865 output[i] = if cond[i] > 0.5 {
866 on_true[i]
867 } else {
868 on_false[i]
869 };
870 }
871 }
872
873 Op::Reduce {
875 op: reduce_op,
876 axes,
877 keep_dim: _,
878 } => {
879 let input = get_data(arena, external, node.inputs[0]);
880 let output = get_output(arena, node_id);
881 output.fill(0.0);
882 if axes.len() == 1 {
884 let in_shape = &graph.node(node.inputs[0]).shape;
885 let axis = axes[0];
886 let rank = in_shape.rank();
887 let outer: usize = (0..axis)
888 .map(|i| in_shape.dim(i).unwrap_static())
889 .product::<usize>()
890 .max(1);
891 let axis_size = in_shape.dim(axis).unwrap_static();
892 let inner: usize = (axis + 1..rank)
893 .map(|i| in_shape.dim(i).unwrap_static())
894 .product::<usize>()
895 .max(1);
896
897 match reduce_op {
898 ReduceOp::Sum | ReduceOp::Mean => {
899 for o in 0..outer {
900 for i in 0..inner {
901 let mut acc = 0f32;
902 for a in 0..axis_size {
903 acc += input[o * axis_size * inner + a * inner + i];
904 }
905 if matches!(reduce_op, ReduceOp::Mean) {
906 acc /= axis_size as f32;
907 }
908 output[o * inner + i] = acc;
909 }
910 }
911 }
912 ReduceOp::Max => {
913 output.fill(f32::NEG_INFINITY);
914 for o in 0..outer {
915 for i in 0..inner {
916 for a in 0..axis_size {
917 let v = input[o * axis_size * inner + a * inner + i];
918 let idx = o * inner + i;
919 if v > output[idx] {
920 output[idx] = v;
921 }
922 }
923 }
924 }
925 }
926 _ => {} }
928 }
929 }
930
931 Op::Cast { .. } => {
933 let input = get_data(arena, external, node.inputs[0]);
934 let output = get_output(arena, node_id);
935 output[..input.len()].copy_from_slice(input);
936 }
937
938 Op::FusedSwiGLU { cast_to: _, .. } => {
946 let input = get_data(arena, external, node.inputs[0]);
947 let output = get_output(arena, node_id);
948 let n = node.shape.dim(node.shape.rank() - 1).unwrap_static();
952 let outer = output.len() / n;
953 debug_assert_eq!(
954 outer * 2 * n,
955 input.len(),
956 "FusedSwiGLU: input/output shape mismatch"
957 );
958 for o in 0..outer {
959 let in_row = &input[o * 2 * n..(o + 1) * 2 * n];
960 let out_row = &mut output[o * n..(o + 1) * n];
961 for i in 0..n {
962 let up = in_row[i];
963 let gate = in_row[n + i];
964 let silu_gate = gate / (1.0 + (-gate).exp());
965 out_row[i] = up * silu_gate;
966 }
967 }
968 }
969
970 Op::DenseSolve => {
972 let a_shape = &graph.node(node.inputs[0]).shape;
973 let n = a_shape.dim(0).unwrap_static();
974 let b_elems = node.shape.num_elements().unwrap();
975 let nrhs = b_elems / n.max(1);
976 match node.shape.dtype() {
977 rlx_ir::DType::F32 => {
978 let a = get_data(arena, external, node.inputs[0]);
979 let b = get_data(arena, external, node.inputs[1]);
980 let x = get_output(arena, node_id);
981 let mut a_scratch = a.to_vec();
982 let mut x_buf = b.to_vec();
983 let info = crate::blas::sgesv(&mut a_scratch, &mut x_buf, n, nrhs);
984 if info != 0 {
985 panic!("DenseSolve: singular matrix (info={info})");
986 }
987 x[..x_buf.len()].copy_from_slice(&x_buf);
988 }
989 rlx_ir::DType::F64 => {
990 let (a_ptr, a_len) = arena.raw_ptr(node.inputs[0]);
991 let (b_ptr, b_len) = arena.raw_ptr(node.inputs[1]);
992 let (x_ptr, x_len) = arena.raw_ptr(node_id);
993 unsafe {
994 let a_src = std::slice::from_raw_parts(a_ptr as *const f64, a_len / 8);
995 let b_src = std::slice::from_raw_parts(b_ptr as *const f64, b_len / 8);
996 let mut a_scratch = a_src.to_vec();
997 let mut x_buf = b_src.to_vec();
998 let info = crate::blas::dgesv(&mut a_scratch, &mut x_buf, n, nrhs);
999 if info != 0 {
1000 panic!("DenseSolve: singular matrix (info={info})");
1001 }
1002 std::slice::from_raw_parts_mut(x_ptr as *mut f64, x_len / 8)
1003 .copy_from_slice(&x_buf);
1004 }
1005 }
1006 other => panic!("DenseSolve executor: unsupported dtype {other:?}"),
1007 }
1008 }
1009
1010 _ => {
1012 if !node.inputs.is_empty() && arena.has_buffer(node_id) {
1013 let input = get_data(arena, external, node.inputs[0]);
1014 let output = get_output(arena, node_id);
1015 let len = output.len().min(input.len());
1016 output[..len].copy_from_slice(&input[..len]);
1017 }
1018 }
1019 }
1020 }
1021}
1022
1023fn get_data<'a>(arena: &'a Arena, external: &'a ExternalBuffers, id: NodeId) -> &'a [f32] {
1027 if let Some(&ext) = external.buffers.get(&id) {
1030 ext
1031 } else if arena.has_buffer(id) {
1032 let (ptr, len) = arena.raw_ptr(id);
1033 unsafe { std::slice::from_raw_parts(ptr, len) }
1034 } else {
1035 panic!("no data for node {id}")
1036 }
1037}
1038
1039#[allow(clippy::mut_from_ref)]
1045fn get_output(arena: &Arena, id: NodeId) -> &mut [f32] {
1046 let (ptr, len) = arena.raw_ptr(id);
1047 unsafe { std::slice::from_raw_parts_mut(ptr, len) }
1048}
1049
1050#[inline]
1052fn matmul(a: &[f32], b: &[f32], c: &mut [f32], m: usize, k: usize, n: usize) {
1053 crate::blas::sgemm(a, b, c, m, k, n);
1055}
1056
1057fn binary_op(op: rlx_ir::op::BinaryOp, a: f32, b: f32) -> f32 {
1058 use rlx_ir::op::BinaryOp::*;
1059 match op {
1060 Add => a + b,
1061 Sub => a - b,
1062 Mul => a * b,
1063 Div => a / b,
1064 Max => a.max(b),
1065 Min => a.min(b),
1066 Pow => a.powf(b),
1067 }
1068}
1069
1070fn compare_op(op: rlx_ir::op::CmpOp, a: f32, b: f32) -> bool {
1071 use rlx_ir::op::CmpOp::*;
1072 match op {
1073 Eq => a == b,
1074 Ne => a != b,
1075 Lt => a < b,
1076 Le => a <= b,
1077 Gt => a > b,
1078 Ge => a >= b,
1079 }
1080}
1081
1082#[allow(dead_code)]
1084fn scalar_gelu(x: f32) -> f32 {
1085 let sign = if x >= 0.0 { 1.0f32 } else { -1.0 };
1086 let xa = x.abs();
1087 let t = 1.0 / (1.0 + 0.3275911 * xa);
1088 let y = t
1089 * (0.254_829_6
1090 + t * (-0.284_496_72 + t * (1.421_413_8 + t * (-1.453_152_1 + t * 1.061_405_4))));
1091 let erf = sign * (1.0 - y * (-xa * xa).exp());
1092 x * 0.5 * (1.0 + erf)
1093}
1094
1095#[cfg(test)]
1096mod tests {
1097 use super::*;
1098 use rlx_ir::*;
1099
1100 use rlx_opt::fusion::FuseMatMulBiasAct;
1101 use rlx_opt::memory;
1102 use rlx_opt::pass::Pass;
1103
1104 #[test]
1106 fn execute_fused_matmul_bias_gelu() {
1107 let mut g = Graph::new("test");
1109 let x_id = g.input("x", Shape::new(&[2, 4], DType::F32));
1110 let w_id = g.param("w", Shape::new(&[4, 3], DType::F32));
1111 let b_id = g.param("b", Shape::new(&[3], DType::F32));
1112 let mm = g.matmul(x_id, w_id, Shape::new(&[2, 3], DType::F32));
1113 let add = g.binary(BinaryOp::Add, mm, b_id, Shape::new(&[2, 3], DType::F32));
1114 let out = g.activation(Activation::Gelu, add, Shape::new(&[2, 3], DType::F32));
1115 g.set_outputs(vec![out]);
1116
1117 let fused = FuseMatMulBiasAct.run(g);
1119 println!("{fused}");
1120
1121 let plan = memory::plan_memory(&fused);
1123 println!("Arena: {} bytes", plan.arena_size);
1124
1125 let x_data = vec![1.0f32, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0]; let w_data = vec![1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0]; let b_data = vec![0.5, -0.5, 0.0]; let mut ext = ExternalBuffers {
1131 buffers: HashMap::new(),
1132 };
1133 ext.buffers.insert(fused.outputs[0], &[]); for node in fused.nodes() {
1136 match &node.op {
1137 Op::Input { name } if name == "x" => {
1138 ext.buffers.insert(node.id, &x_data);
1139 }
1140 Op::Param { name } if name == "w" => {
1141 ext.buffers.insert(node.id, &w_data);
1142 }
1143 Op::Param { name } if name == "b" => {
1144 ext.buffers.insert(node.id, &b_data);
1145 }
1146 _ => {}
1147 }
1148 }
1149
1150 let mut arena = Arena::from_plan(plan);
1152 execute(&fused, &mut arena, &ext);
1153
1154 let output_id = fused.outputs[0];
1156 let result = arena.slice(output_id);
1157 println!("Result: {result:?}");
1158
1159 assert!((result[0] - 1.399).abs() < 0.01, "got {}", result[0]);
1163 assert!((result[1] - -0.154).abs() < 0.01, "got {}", result[1]);
1164 assert!((result[2] - 0.0).abs() < 0.01, "got {}", result[2]);
1165 assert!((result[3] - 0.346).abs() < 0.01, "got {}", result[3]);
1166 }
1167
1168 #[test]
1170 fn execute_gather() {
1171 use rlx_ir::infer::GraphExt;
1172 let mut g = Graph::new("gather_test");
1173 let table = g.param("table", Shape::new(&[4, 3], DType::F32));
1175 let indices = g.input("ids", Shape::new(&[2], DType::F32)); let out = g.gather_(table, indices, 0);
1177 g.set_outputs(vec![out]);
1178
1179 let plan = memory::plan_memory(&g);
1180 let mut arena = Arena::from_plan(plan);
1181
1182 let table_data = vec![
1183 10.0, 11.0, 12.0, 20.0, 21.0, 22.0, 30.0, 31.0, 32.0, 40.0, 41.0, 42.0, ];
1188 let ids_data = vec![2.0, 0.0]; let mut ext = ExternalBuffers {
1191 buffers: HashMap::new(),
1192 };
1193 for node in g.nodes() {
1194 match &node.op {
1195 Op::Param { name } if name == "table" => {
1196 ext.buffers.insert(node.id, &table_data);
1197 }
1198 Op::Input { name } if name == "ids" => {
1199 ext.buffers.insert(node.id, &ids_data);
1200 }
1201 _ => {}
1202 }
1203 }
1204
1205 execute(&g, &mut arena, &ext);
1206 let result = arena.slice(g.outputs[0]);
1207 assert_eq!(&result[..3], &[30.0, 31.0, 32.0]); assert_eq!(&result[3..6], &[10.0, 11.0, 12.0]); }
1210
#[test]
fn execute_narrow() {
    // Narrow along dim 1 of a [2, 6] input: keep 3 columns starting at column 2
    // in every row.
    use rlx_ir::infer::GraphExt;

    let mut g = Graph::new("narrow_test");
    let x = g.input("x", Shape::new(&[2, 6], DType::F32));
    let window = g.narrow_(x, 1, 2, 3);
    g.set_outputs(vec![window]);

    let mut arena = Arena::from_plan(memory::plan_memory(&g));

    // Row-major 0..12, so row 0 is 0..=5 and row 1 is 6..=11.
    let data: Vec<f32> = (0..12).map(|v| v as f32).collect();

    // The graph has a single input node; bind the data to it.
    let mut ext = ExternalBuffers {
        buffers: HashMap::new(),
    };
    for node in g.nodes() {
        if matches!(&node.op, Op::Input { .. }) {
            ext.buffers.insert(node.id, &data);
        }
    }

    execute(&g, &mut arena, &ext);

    let expected: [f32; 6] = [2.0, 3.0, 4.0, 8.0, 9.0, 10.0];
    let got = arena.slice(g.outputs[0]);
    assert_eq!(got, &expected);
}
1237
#[test]
fn execute_softmax() {
    // Softmax over the last axis of a single [1, 4] row.
    //
    // Sum-to-one plus strict monotonicity is not enough to identify softmax:
    // plain linear normalisation x / sum(x) = [0.1, 0.2, 0.3, 0.4] satisfies both.
    // The exp-based reference values are therefore checked as well.
    use rlx_ir::infer::GraphExt;
    let mut g = Graph::new("softmax_test");
    let x = g.input("x", Shape::new(&[1, 4], DType::F32));
    let sm = g.sm(x, -1);
    g.set_outputs(vec![sm]);

    let plan = memory::plan_memory(&g);
    let mut arena = Arena::from_plan(plan);

    let data = vec![1.0, 2.0, 3.0, 4.0];
    let mut ext = ExternalBuffers {
        buffers: HashMap::new(),
    };
    for node in g.nodes() {
        if let Op::Input { .. } = &node.op {
            ext.buffers.insert(node.id, &data);
        }
    }

    execute(&g, &mut arena, &ext);
    let result = arena.slice(g.outputs[0]);

    let sum: f32 = result.iter().sum();
    assert!(
        (sum - 1.0).abs() < 1e-5,
        "softmax should sum to 1, got {sum}"
    );
    // Softmax preserves the ordering of its inputs.
    assert!(result[0] < result[1]);
    assert!(result[1] < result[2]);
    assert!(result[2] < result[3]);

    // Reference: exp(x_i) / sum_j exp(x_j) for x = [1, 2, 3, 4].
    // The tolerance leaves headroom for an approximate exp in the kernel.
    let expected: [f32; 4] = [0.032_058_6, 0.087_144_3, 0.236_882_8, 0.643_914_3];
    for (i, (&got, &want)) in result.iter().zip(expected.iter()).enumerate() {
        assert!(
            (got - want).abs() < 1e-4,
            "softmax[{i}]: got {got}, want {want}"
        );
    }
}
1272
#[test]
fn execute_rope() {
    // Rotary embedding on a [seq, head_dim] input with hand-picked cos/sin tables,
    // chosen so every output element is exactly 0, 1 or -1.
    use rlx_ir::infer::GraphExt;
    let (seq, head_dim) = (2, 4);
    let half = head_dim / 2;

    let mut g = Graph::new("rope_test");
    let x = g.input("x", Shape::new(&[seq, head_dim], DType::F32));
    // One cos/sin value per (position, rotation pair): shape [seq, head_dim / 2].
    let cos = g.param("cos", Shape::new(&[seq, half], DType::F32));
    let sin = g.param("sin", Shape::new(&[seq, half], DType::F32));
    let rotated = g.rope(x, cos, sin, head_dim);
    g.set_outputs(vec![rotated]);

    let mut arena = Arena::from_plan(memory::plan_memory(&g));

    // pos 0: x = [1, 0, 0, 1], cos = [1, 0], sin = [0, 1]
    // pos 1: x = [1, 1, 0, 0], cos = [0, 1], sin = [1, 0]
    let x_data: Vec<f32> = vec![1.0, 0.0, 0.0, 1.0, 1.0, 1.0, 0.0, 0.0];
    let cos_data: Vec<f32> = vec![1.0, 0.0, 0.0, 1.0];
    let sin_data: Vec<f32> = vec![0.0, 1.0, 1.0, 0.0];

    // Bind each named graph node to its host buffer; other nodes are skipped.
    let mut buffers = HashMap::new();
    for node in g.nodes() {
        let bound: &[f32] = match &node.op {
            Op::Input { name } if name == "x" => &x_data,
            Op::Param { name } if name == "cos" => &cos_data,
            Op::Param { name } if name == "sin" => &sin_data,
            _ => continue,
        };
        buffers.insert(node.id, bound);
    }
    let ext = ExternalBuffers { buffers };

    execute(&g, &mut arena, &ext);
    let result = arena.slice(g.outputs[0]);

    // The expected values are consistent with the half-split rotation
    //   out[i]        = x[i] * cos[i] - x[i + half] * sin[i]
    //   out[i + half] = x[i + half] * cos[i] + x[i] * sin[i]
    // Row 0 is position 0 and row 1 is position 1, each 4 wide.
    let expected: [f32; 8] = [1.0, -1.0, 0.0, 0.0, 0.0, 1.0, 1.0, 0.0];
    for (i, &want) in expected.iter().enumerate() {
        let got = result[i];
        assert!(
            (got - want).abs() < 1e-5,
            "pos{}[{}]={}",
            i / 4,
            i % 4,
            got
        );
    }
}
1337
#[test]
fn execute_layer_norm() {
    // Layer norm over the last axis of a [1, 4] row with identity affine params
    // (gamma = 1, beta = 0, eps = 1e-5).
    //
    // A zero-sum check alone also passes for a kernel that subtracts the mean
    // but scales wrongly (or not at all), so the normalised values are compared
    // against a reference too.
    use rlx_ir::infer::GraphExt;
    let mut g = Graph::new("ln_test");
    let x = g.input("x", Shape::new(&[1, 4], DType::F32));
    let gamma = g.param("g", Shape::new(&[4], DType::F32));
    let beta = g.param("b", Shape::new(&[4], DType::F32));
    let ln = g.ln(x, gamma, beta, 1e-5);
    g.set_outputs(vec![ln]);

    let plan = memory::plan_memory(&g);
    let mut arena = Arena::from_plan(plan);

    let x_data = vec![1.0, 2.0, 3.0, 4.0];
    let g_data = vec![1.0, 1.0, 1.0, 1.0];
    let b_data = vec![0.0, 0.0, 0.0, 0.0];

    let mut ext = ExternalBuffers {
        buffers: HashMap::new(),
    };
    for node in g.nodes() {
        match &node.op {
            Op::Input { name } if name == "x" => {
                ext.buffers.insert(node.id, &x_data);
            }
            Op::Param { name } if name == "g" => {
                ext.buffers.insert(node.id, &g_data);
            }
            Op::Param { name } if name == "b" => {
                ext.buffers.insert(node.id, &b_data);
            }
            _ => {}
        }
    }

    execute(&g, &mut arena, &ext);
    let result = arena.slice(g.outputs[0]);

    let sum: f32 = result.iter().sum();
    assert!(
        sum.abs() < 1e-3,
        "LN output should be zero-centered, sum={sum}"
    );

    // Reference assumes the usual (PyTorch-style) convention: biased variance,
    // eps inside the square root. mean = 2.5, var = 1.25, so
    // y = (x - 2.5) / sqrt(1.25 + 1e-5).
    let expected: [f32; 4] = [-1.341_635, -0.447_212, 0.447_212, 1.341_635];
    for (i, (&got, &want)) in result.iter().zip(expected.iter()).enumerate() {
        assert!(
            (got - want).abs() < 1e-3,
            "ln[{i}]: got {got}, want {want}"
        );
    }
}
1382}