Skip to main content

rlx_cpu/
executor.rs

1// RLX — versatile ML compiler + runtime.
2// Copyright (C) 2026 Eugene Hauptmann, Nataliya Kosmyna.
3//
4// This program is free software: you can redistribute it and/or modify
5// it under the terms of the GNU General Public License as published by
6// the Free Software Foundation, version 3.
7//
8// This program is distributed in the hope that it will be useful,
9// but WITHOUT ANY WARRANTY; without even the implied warranty of
10// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
11// GNU General Public License for more details.
12//
13// You should have received a copy of the GNU General Public License
14// along with this program. If not, see <https://www.gnu.org/licenses/>.
15
16//! Graph executor — runs a fused IR graph on CPU using the arena + kernels.
17//!
18//! The executor is the runtime hot path. For a 6-layer BERT, it makes
19//! ~24 kernel calls total (one per fused node). Everything else is
20//! inside the kernels — SIMD, BLAS, pre-allocated arena buffers.
21
22use crate::arena::Arena;
23use crate::kernels;
24use rlx_ir::op::{Activation, BinaryOp, ReduceOp};
25use rlx_ir::{Graph, NodeId, Op};
26use std::collections::HashMap;
27
28/// External data provided at runtime (model weights + inputs).
///
/// The map values are `f32` slices borrowed for the lifetime `'a`; this
/// struct stores references only, so the underlying data must outlive it.
29pub struct ExternalBuffers<'a> {
30    /// Map from node ID (Input/Param nodes) to external f32 data.
    // NOTE(review): `execute` also treats `Op::Constant` as externally
    // provided (no kernel is run for it) — presumably constants are
    // looked up in this map as well; confirm against `get_data`.
31    pub buffers: HashMap<NodeId, &'a [f32]>,
32}
33
34/// Execute a compiled graph on CPU.
35///
36/// The graph should already be fused and memory-planned.
37/// `arena` holds all intermediate buffers.
38/// `external` provides input data and model weights.
39///
40/// Returns nothing; the output data is left in the arena.
41pub fn execute(graph: &Graph, arena: &mut Arena, external: &ExternalBuffers) {
42    let schedule: Vec<NodeId> = arena.schedule().to_vec();
43    for &node_id in &schedule {
44        let node = graph.node(node_id);
45
46        match &node.op {
47            // External data — skip (data provided via get_data which reads
48            // from external buffers directly without copying to arena)
49            Op::Input { .. } | Op::Param { .. } | Op::Constant { .. } => {}
50
51            // ── Fused matmul + bias + optional activation ───────────
52            Op::FusedMatMulBiasAct { activation } => {
53                let input_id = node.inputs[0];
54                let weight_id = node.inputs[1];
55                let bias_id = node.inputs[2];
56
57                let input = get_data(arena, external, input_id);
58                let weight = get_data(arena, external, weight_id);
59                let bias = get_data(arena, external, bias_id);
60                let output = get_output(arena, node_id);
61
62                // Compute output shape for sgemm
63                let shape = &node.shape;
64                let n = shape.dim(shape.rank() - 1).unwrap_static();
65                let m = shape.num_elements().unwrap() / n;
66                let k = input.len() / m;
67
68                // sgemm: output = input @ weight
69                // TODO: call cblas_sgemm via FFI
70                // For now, naive matmul as placeholder
71                matmul(input, weight, output, m, k, n);
72
73                // Fused bias + activation (parallel NEON kernels)
74                match activation {
75                    Some(Activation::Gelu) => kernels::par_bias_gelu(output, bias, m, n),
76                    Some(Activation::Silu) => {
77                        crate::blas::bias_add(output, bias, m, n);
78                        kernels::silu_inplace(output);
79                    }
80                    _ => crate::blas::bias_add(output, bias, m, n),
81                }
82            }
83
84            // ── Fused residual + LayerNorm (parallel NEON) ──────────
85            Op::FusedResidualLN { has_bias, eps } => {
86                let x_id = node.inputs[0];
87                let residual_id = node.inputs[1];
88                let h = node.shape.dim(node.shape.rank() - 1).unwrap_static();
89                let zero_bias = vec![0f32; h];
90                let (gamma_id, beta_id, bias_slice) = if *has_bias {
91                    let b = get_data(arena, external, node.inputs[2]);
92                    (node.inputs[3], node.inputs[4], b)
93                } else {
94                    (node.inputs[2], node.inputs[3], zero_bias.as_slice())
95                };
96
97                let x = get_data(arena, external, x_id);
98                let residual = get_data(arena, external, residual_id);
99                let gamma = get_data(arena, external, gamma_id);
100                let beta = get_data(arena, external, beta_id);
101                let output = get_output(arena, node_id);
102
103                let n = x.len() / h;
104
105                // Parallel: each thread processes a chunk of rows
106                let x_ptr = x.as_ptr() as usize;
107                let r_ptr = residual.as_ptr() as usize;
108                let o_ptr = output.as_mut_ptr() as usize;
109                let bi_ptr = bias_slice.as_ptr() as usize;
110                let g_ptr = gamma.as_ptr() as usize;
111                let b_ptr = beta.as_ptr() as usize;
112                let e = *eps;
113                crate::pool::par_for(n, 4, &|off, cnt| unsafe {
114                    let x_s =
115                        std::slice::from_raw_parts((x_ptr as *const f32).add(off * h), cnt * h);
116                    let r_s =
117                        std::slice::from_raw_parts((r_ptr as *const f32).add(off * h), cnt * h);
118                    let o_s =
119                        std::slice::from_raw_parts_mut((o_ptr as *mut f32).add(off * h), cnt * h);
120                    let bi = std::slice::from_raw_parts(bi_ptr as *const f32, h);
121                    let g = std::slice::from_raw_parts(g_ptr as *const f32, h);
122                    let b = std::slice::from_raw_parts(b_ptr as *const f32, h);
123                    kernels::residual_bias_layer_norm(x_s, r_s, bi, g, b, o_s, cnt, h, e);
124                });
125            }
126
127            // ── Fused residual + RMSNorm (parallel) ─────────────────
128            Op::FusedResidualRmsNorm { has_bias, eps } => {
129                let x_id = node.inputs[0];
130                let residual_id = node.inputs[1];
131                let h = node.shape.dim(node.shape.rank() - 1).unwrap_static();
132                let zero_bias = vec![0f32; h];
133                let (gamma_id, beta_id, bias_slice) = if *has_bias {
134                    let b = get_data(arena, external, node.inputs[2]);
135                    (node.inputs[3], node.inputs[4], b)
136                } else {
137                    (node.inputs[2], node.inputs[3], zero_bias.as_slice())
138                };
139
140                let x = get_data(arena, external, x_id);
141                let residual = get_data(arena, external, residual_id);
142                let gamma = get_data(arena, external, gamma_id);
143                let beta = get_data(arena, external, beta_id);
144                let output = get_output(arena, node_id);
145
146                let n = x.len() / h;
147
148                let x_ptr = x.as_ptr() as usize;
149                let r_ptr = residual.as_ptr() as usize;
150                let o_ptr = output.as_mut_ptr() as usize;
151                let bi_ptr = bias_slice.as_ptr() as usize;
152                let g_ptr = gamma.as_ptr() as usize;
153                let b_ptr = beta.as_ptr() as usize;
154                let e = *eps;
155                crate::pool::par_for(n, 4, &|off, cnt| unsafe {
156                    let x_s =
157                        std::slice::from_raw_parts((x_ptr as *const f32).add(off * h), cnt * h);
158                    let r_s =
159                        std::slice::from_raw_parts((r_ptr as *const f32).add(off * h), cnt * h);
160                    let o_s =
161                        std::slice::from_raw_parts_mut((o_ptr as *mut f32).add(off * h), cnt * h);
162                    let bi = std::slice::from_raw_parts(bi_ptr as *const f32, h);
163                    let g = std::slice::from_raw_parts(g_ptr as *const f32, h);
164                    let b = std::slice::from_raw_parts(b_ptr as *const f32, h);
165                    kernels::residual_bias_rms_norm(x_s, r_s, bi, g, b, o_s, cnt, h, e);
166                });
167            }
168
169            // ── Plain matmul ────────────────────────────────────────
170            Op::MatMul => {
171                let lhs = get_data(arena, external, node.inputs[0]);
172                let rhs = get_data(arena, external, node.inputs[1]);
173                let output = get_output(arena, node_id);
174
175                let shape = &node.shape;
176                let lhs_shape = &graph.node(node.inputs[0]).shape;
177                let rhs_shape = &graph.node(node.inputs[1]).shape;
178                let n = shape.dim(shape.rank() - 1).unwrap_static();
179                let out_m_inner = shape.dim(shape.rank() - 2).unwrap_static();
180                let k = lhs_shape.dim(lhs_shape.rank() - 1).unwrap_static();
181
182                // Outer batch dims — present when either input has rank > 2.
183                // Compute total batch as output.num_elements / (M * N).
184                let total = shape.num_elements().unwrap();
185                let per_batch_out = out_m_inner * n;
186                let batches = total / per_batch_out;
187
188                if batches == 1 {
189                    matmul(lhs, rhs, output, out_m_inner, k, n);
190                } else {
191                    let lhs_batched =
192                        lhs_shape.num_elements().unwrap_or(0) == batches * out_m_inner * k;
193                    let rhs_batched = rhs_shape.num_elements().unwrap_or(0) == batches * k * n;
194                    for b in 0..batches {
195                        let l_off = if lhs_batched { b * out_m_inner * k } else { 0 };
196                        let r_off = if rhs_batched { b * k * n } else { 0 };
197                        let o_off = b * out_m_inner * n;
198                        let l_slice = &lhs[l_off..l_off + out_m_inner * k];
199                        let r_slice = &rhs[r_off..r_off + k * n];
200                        let o_slice = &mut output[o_off..o_off + out_m_inner * n];
201                        matmul(l_slice, r_slice, o_slice, out_m_inner, k, n);
202                    }
203                }
204            }
205
206            // ── Element-wise binary ─────────────────────────────────
207            Op::Binary(op) => {
208                let lhs = get_data(arena, external, node.inputs[0]);
209                let rhs = get_data(arena, external, node.inputs[1]);
210                let output = get_output(arena, node_id);
211                let len = output.len();
212                let rhs_len = rhs.len();
213
214                // Fast path: Add with broadcast bias → NEON bias_add
215                if matches!(op, BinaryOp::Add) && rhs_len < len && len.is_multiple_of(rhs_len) {
216                    output.copy_from_slice(lhs);
217                    crate::blas::bias_add(output, rhs, len / rhs_len, rhs_len);
218                } else if rhs_len == len {
219                    for i in 0..len {
220                        output[i] = binary_op(*op, lhs[i], rhs[i]);
221                    }
222                } else {
223                    for i in 0..len {
224                        output[i] = binary_op(*op, lhs[i], rhs[i % rhs_len]);
225                    }
226                }
227            }
228
229            // ── Unary activation ────────────────────────────────────
230            Op::Activation(act) => {
231                let input = get_data(arena, external, node.inputs[0]);
232                let output = get_output(arena, node_id);
233                output.copy_from_slice(input);
234                let zeros = vec![0f32; node.shape.dim(node.shape.rank() - 1).unwrap_static()];
235                let m = output.len() / zeros.len();
236                let n = zeros.len();
237                match act {
238                    Activation::Gelu => kernels::par_bias_gelu(output, &zeros, m, n),
239                    Activation::Silu => kernels::silu_inplace(output),
240                    Activation::Relu => {
241                        for v in output.iter_mut() {
242                            *v = v.max(0.0);
243                        }
244                    }
245                    Activation::Exp => {
246                        for v in output.iter_mut() {
247                            *v = v.exp();
248                        }
249                    }
250                    Activation::Sqrt => {
251                        for v in output.iter_mut() {
252                            *v = v.sqrt();
253                        }
254                    }
255                    Activation::Neg => {
256                        for v in output.iter_mut() {
257                            *v = -*v;
258                        }
259                    }
260                    Activation::Tanh => {
261                        for v in output.iter_mut() {
262                            *v = v.tanh();
263                        }
264                    }
265                    Activation::Sigmoid => {
266                        for v in output.iter_mut() {
267                            *v = 1.0 / (1.0 + (-*v).exp());
268                        }
269                    }
270                    _ => {}
271                }
272            }
273
274            // ── Gather (embedding lookup) ───────────────────────────
275            Op::Gather { axis } => {
276                let table = get_data(arena, external, node.inputs[0]);
277                let indices = get_data(arena, external, node.inputs[1]);
278                let output = get_output(arena, node_id);
279
280                let table_shape = &graph.node(node.inputs[0]).shape;
281                let _out_shape = &node.shape;
282
283                // For axis=0 (embedding): table[V, D...], indices[B, S] → [B, S, D...]
284                if *axis == 0 {
285                    let trailing: usize = (1..table_shape.rank())
286                        .map(|i| table_shape.dim(i).unwrap_static())
287                        .product();
288                    for (i, &idx_f32) in indices.iter().enumerate() {
289                        let idx = idx_f32 as usize;
290                        let src = idx * trailing;
291                        let dst = i * trailing;
292                        output[dst..dst + trailing].copy_from_slice(&table[src..src + trailing]);
293                    }
294                } else {
295                    // General gather — fallback
296                    output.fill(0.0);
297                }
298            }
299
300            // ── Narrow (slice along axis) ───────────────────────────
301            Op::Narrow { axis, start, len } => {
302                let input = get_data(arena, external, node.inputs[0]);
303                let output = get_output(arena, node_id);
304                let in_shape = &graph.node(node.inputs[0]).shape;
305
306                let rank = in_shape.rank();
307                let outer: usize = (0..*axis)
308                    .map(|i| in_shape.dim(i).unwrap_static())
309                    .product::<usize>()
310                    .max(1);
311                let inner: usize = (*axis + 1..rank)
312                    .map(|i| in_shape.dim(i).unwrap_static())
313                    .product::<usize>()
314                    .max(1);
315                let in_axis_size = in_shape.dim(*axis).unwrap_static();
316
317                for o in 0..outer {
318                    for s in 0..*len {
319                        let src_off = o * in_axis_size * inner + (*start + s) * inner;
320                        let dst_off = o * len * inner + s * inner;
321                        output[dst_off..dst_off + inner]
322                            .copy_from_slice(&input[src_off..src_off + inner]);
323                    }
324                }
325            }
326
327            // ── Transpose ───────────────────────────────────────────
328            Op::Transpose { perm } => {
329                let input = get_data(arena, external, node.inputs[0]);
330                let output = get_output(arena, node_id);
331                let in_shape = &graph.node(node.inputs[0]).shape;
332                let rank = in_shape.rank();
333
334                let in_dims: Vec<usize> =
335                    (0..rank).map(|i| in_shape.dim(i).unwrap_static()).collect();
336                let out_dims: Vec<usize> = perm.iter().map(|&i| in_dims[i]).collect();
337
338                // Row-major strides for input and output spaces.
339                // For a shape [d0, d1, ..., d_{r-1}], stride[i] = product(d_{i+1..r}).
340                let mut in_strides = vec![1usize; rank];
341                for i in (0..rank - 1).rev() {
342                    in_strides[i] = in_strides[i + 1] * in_dims[i + 1];
343                }
344                let mut out_strides = vec![1usize; rank];
345                for i in (0..rank - 1).rev() {
346                    out_strides[i] = out_strides[i + 1] * out_dims[i + 1];
347                }
348
349                let total = output.len();
350                for flat_out in 0..total {
351                    let mut in_flat = 0;
352                    for d in 0..rank {
353                        // out_coord[d] decoded from flat_out via output strides.
354                        let coord = (flat_out / out_strides[d]) % out_dims[d];
355                        // Output dim d came from input dim perm[d].
356                        in_flat += coord * in_strides[perm[d]];
357                    }
358                    output[flat_out] = input[in_flat];
359                }
360            }
361
362            // ── Concat ──────────────────────────────────────────────
363            Op::Concat { axis } => {
364                let output = get_output(arena, node_id);
365                let out_shape = &node.shape;
366                let rank = out_shape.rank();
367
368                let outer: usize = (0..*axis)
369                    .map(|i| out_shape.dim(i).unwrap_static())
370                    .product::<usize>()
371                    .max(1);
372                let inner: usize = (*axis + 1..rank)
373                    .map(|i| out_shape.dim(i).unwrap_static())
374                    .product::<usize>()
375                    .max(1);
376
377                let mut dst_off = 0;
378                for o in 0..outer {
379                    for &inp_id in &node.inputs {
380                        let inp = get_data(arena, external, inp_id);
381                        let inp_shape = &graph.node(inp_id).shape;
382                        let inp_axis = inp_shape.dim(*axis).unwrap_static();
383                        let chunk = inp_axis * inner;
384                        let src_off = o * chunk;
385                        output[dst_off..dst_off + chunk]
386                            .copy_from_slice(&inp[src_off..src_off + chunk]);
387                        dst_off += chunk;
388                    }
389                }
390            }
391
392            // ── Reshape / Expand (input copied into the output buffer) ─
393            Op::Reshape { .. } | Op::Expand { .. } => {
394                let input = get_data(arena, external, node.inputs[0]);
395                let output = get_output(arena, node_id);
396                output[..input.len()].copy_from_slice(input);
397            }
398
399            // ── LayerNorm (parallel NEON) ────────────────────────────
400            Op::LayerNorm { eps, .. } => {
401                let input = get_data(arena, external, node.inputs[0]);
402                let gamma = get_data(arena, external, node.inputs[1]);
403                let beta = get_data(arena, external, node.inputs[2]);
404                let output = get_output(arena, node_id);
405                let h = node.shape.dim(node.shape.rank() - 1).unwrap_static();
406                let n = input.len() / h;
407                for row in 0..n {
408                    let base = row * h;
409                    kernels::layer_norm_row(
410                        &input[base..base + h],
411                        gamma,
412                        beta,
413                        &mut output[base..base + h],
414                        h,
415                        *eps,
416                    );
417                }
418            }
419
420            Op::GroupNorm { num_groups, eps } => {
421                let input = get_data(arena, external, node.inputs[0]);
422                let gamma = get_data(arena, external, node.inputs[1]);
423                let beta = get_data(arena, external, node.inputs[2]);
424                let output = get_output(arena, node_id);
425                let n = node.shape.dim(0).unwrap_static();
426                let c = node.shape.dim(1).unwrap_static();
427                let h = node.shape.dim(2).unwrap_static();
428                let w = node.shape.dim(3).unwrap_static();
429                kernels::group_norm_nchw(input, gamma, beta, output, n, c, h, w, *num_groups, *eps);
430            }
431
432            Op::ResizeNearest2x => {
433                let input = get_data(arena, external, node.inputs[0]);
434                let output = get_output(arena, node_id);
435                let n = node.shape.dim(0).unwrap_static();
436                let c = node.shape.dim(1).unwrap_static();
437                let h = node.shape.dim(2).unwrap_static() / 2;
438                let w = node.shape.dim(3).unwrap_static() / 2;
439                let in_plane = c * h * w;
440                let out_plane = c * h * 2 * w * 2;
441                for ni in 0..n {
442                    kernels::resize_nearest_2x_nchw(
443                        &input[ni * in_plane..(ni + 1) * in_plane],
444                        &mut output[ni * out_plane..(ni + 1) * out_plane],
445                        c,
446                        h,
447                        w,
448                    );
449                }
450            }
451
452            Op::AxialRope2d {
453                end_x,
454                end_y,
455                head_dim,
456                num_heads,
457                theta,
458                repeat_factor,
459            } => {
460                let input = get_data(arena, external, node.inputs[0]);
461                let output = get_output(arena, node_id);
462                let batch = node.shape.dim(0).unwrap_static();
463                let seq = node.shape.dim(1).unwrap_static();
464                let plane = seq * node.shape.dim(2).unwrap_static();
465                for bi in 0..batch {
466                    let rotated = rlx_ir::ops::axial_rope2d::apply_axial_rope2d(
467                        &input[bi * plane..(bi + 1) * plane],
468                        *num_heads,
469                        seq,
470                        *head_dim,
471                        *end_x,
472                        *end_y,
473                        *theta,
474                        *repeat_factor,
475                    );
476                    output[bi * plane..(bi + 1) * plane].copy_from_slice(&rotated);
477                }
478            }
479
480            // ── Softmax ─────────────────────────────────────────────
481            Op::Softmax { axis } => {
482                let input = get_data(arena, external, node.inputs[0]);
483                let output = get_output(arena, node_id);
484                output.copy_from_slice(input);
485                let rank = node.shape.rank();
486                let ax = if *axis < 0 {
487                    (rank as i32 + axis) as usize
488                } else {
489                    *axis as usize
490                };
491                let cols = node.shape.dim(ax).unwrap_static();
492                let rows = output.len() / cols;
493                crate::naive::softmax(output, rows, cols);
494            }
495
496            // ── Attention (SDPA) — BLAS-accelerated ─────────────────
497            Op::Attention {
498                num_heads,
499                head_dim,
500                mask_kind,
501                score_scale,
502                attn_logit_softcap,
503            } => {
504                let q = get_data(arena, external, node.inputs[0]);
505                let k = get_data(arena, external, node.inputs[1]);
506                let v = get_data(arena, external, node.inputs[2]);
507                // For non-Custom mask kinds the IR emits no mask input —
508                // synthesize an empty slice so the masking branch below
509                // sees `mask.len() < ...` and skips.
510                let mask: &[f32] = if matches!(
511                    mask_kind,
512                    rlx_ir::op::MaskKind::Custom | rlx_ir::op::MaskKind::Bias
513                ) {
514                    get_data(arena, external, node.inputs[3])
515                } else {
516                    &[]
517                };
518                let output = get_output(arena, node_id);
519
520                let q_shape = &graph.node(node.inputs[0]).shape;
521                let k_shape = &graph.node(node.inputs[1]).shape;
522                let hs = num_heads * head_dim;
523                let scale = score_scale.unwrap_or((*head_dim as f32).powf(-0.5));
524                let (batch_size, s_q) = if q_shape.rank() >= 3 {
525                    (
526                        q_shape.dim(0).unwrap_static(),
527                        q_shape.dim(1).unwrap_static(),
528                    )
529                } else {
530                    (1, q_shape.dim(0).unwrap_static())
531                };
532                // K and V share Lk. In decode mode Lk = past+1 and Lq = 1;
533                // in prefill Lq = Lk. Causal/SlidingWindow masking is
534                // expressed in absolute positions: Q-row qi is at absolute
535                // position (Lk - Lq) + qi, so masking shifts accordingly.
536                let s_k = if k_shape.rank() >= 3 {
537                    k_shape.dim(1).unwrap_static()
538                } else {
539                    k_shape.dim(0).unwrap_static()
540                };
541                let q_offset = s_k.saturating_sub(s_q);
542
543                // Pre-allocate buffers ONCE (reused across heads)
544                let q_buf_len = s_q * head_dim;
545                let k_buf_len = s_k * head_dim;
546                let mut q_head = vec![0f32; q_buf_len];
547                let mut k_head = vec![0f32; k_buf_len];
548                let mut v_head = vec![0f32; k_buf_len];
549                let mut scores = vec![0f32; s_q * s_k];
550                let mut out_head = vec![0f32; q_buf_len];
551
552                for bi in 0..batch_size {
553                    for hi in 0..*num_heads {
554                        // Gather per-head Q (Lq rows).
555                        for si in 0..s_q {
556                            let off = bi * s_q * hs + si * hs + hi * head_dim;
557                            q_head[si * head_dim..(si + 1) * head_dim]
558                                .copy_from_slice(&q[off..off + head_dim]);
559                        }
560                        // Gather per-head K, V (Lk rows).
561                        for si in 0..s_k {
562                            let off = bi * s_k * hs + si * hs + hi * head_dim;
563                            k_head[si * head_dim..(si + 1) * head_dim]
564                                .copy_from_slice(&k[off..off + head_dim]);
565                            v_head[si * head_dim..(si + 1) * head_dim]
566                                .copy_from_slice(&v[off..off + head_dim]);
567                        }
568                        // Q@K^T: scores[Lq, Lk]. Use NEON dots when the
569                        // larger of Lq/Lk is small; BLAS otherwise.
570                        if s_q.max(s_k) <= 32 {
571                            for qi in 0..s_q {
572                                for ki in 0..s_k {
573                                    let q_off = qi * head_dim;
574                                    let k_off = ki * head_dim;
575                                    #[cfg(target_arch = "aarch64")]
576                                    let mut dot;
577                                    #[cfg(not(target_arch = "aarch64"))]
578                                    let mut dot = 0f32;
579                                    #[cfg(target_arch = "aarch64")]
580                                    unsafe {
581                                        use std::arch::aarch64::*;
582                                        let chunks = head_dim / 4;
583                                        let mut acc = vdupq_n_f32(0.0);
584                                        for c in 0..chunks {
585                                            let vq = vld1q_f32(q_head.as_ptr().add(q_off + c * 4));
586                                            let vk = vld1q_f32(k_head.as_ptr().add(k_off + c * 4));
587                                            acc = vfmaq_f32(acc, vq, vk);
588                                        }
589                                        dot = vaddvq_f32(acc);
590                                        for d in (chunks * 4)..*head_dim {
591                                            dot += q_head[q_off + d] * k_head[k_off + d];
592                                        }
593                                    }
594                                    #[cfg(not(target_arch = "aarch64"))]
595                                    {
596                                        for d in 0..*head_dim {
597                                            dot += q_head[q_off + d] * k_head[k_off + d];
598                                        }
599                                    }
600                                    scores[qi * s_k + ki] = dot * scale;
601                                }
602                            }
603                        } else {
604                            crate::blas::sgemm_bt(
605                                &q_head,
606                                &k_head,
607                                &mut scores,
608                                s_q,
609                                *head_dim,
610                                s_k,
611                                scale,
612                            );
613                        }
614                        // Mask: branch on kind so None / Causal skip the
615                        // mask load entirely. Causal/SlidingWindow use
616                        // absolute positions so they handle Lq != Lk
617                        // (decode-mode with cached K/V).
618                        match mask_kind {
619                            rlx_ir::op::MaskKind::None => {}
620                            rlx_ir::op::MaskKind::Causal => {
621                                for qi in 0..s_q {
622                                    let abs_q = q_offset + qi;
623                                    for ki in (abs_q + 1)..s_k {
624                                        scores[qi * s_k + ki] = -1e9;
625                                    }
626                                }
627                            }
628                            rlx_ir::op::MaskKind::SlidingWindow(w) => {
629                                for qi in 0..s_q {
630                                    let abs_q = q_offset + qi;
631                                    let lo = abs_q.saturating_sub(*w);
632                                    for ki in 0..s_k {
633                                        if ki < lo || ki > abs_q {
634                                            scores[qi * s_k + ki] = -1e9;
635                                        }
636                                    }
637                                }
638                            }
639                            rlx_ir::op::MaskKind::Custom => {
640                                if mask.len() >= (bi + 1) * s_k {
641                                    let m = &mask[bi * s_k..(bi + 1) * s_k];
642                                    for qi in 0..s_q {
643                                        for ki in 0..s_k {
644                                            if m[ki] < 0.5 {
645                                                scores[qi * s_k + ki] = -1e9;
646                                            }
647                                        }
648                                    }
649                                }
650                            }
651                            rlx_ir::op::MaskKind::Bias => {
652                                // Bias is [batch, num_heads, s_q, s_k]
653                                // (additive, pre-softmax). Skip if the
654                                // buffer wasn't supplied.
655                                let per_bh = s_q * s_k;
656                                let need = (bi * *num_heads + hi + 1) * per_bh;
657                                if mask.len() >= need {
658                                    let bias_off = (bi * *num_heads + hi) * per_bh;
659                                    let b = &mask[bias_off..bias_off + per_bh];
660                                    for i in 0..per_bh {
661                                        scores[i] += b[i];
662                                    }
663                                }
664                            }
665                        }
666                        if let Some(cap) = attn_logit_softcap {
667                            if *cap > 0.0 {
668                                for s in scores.iter_mut() {
669                                    *s = cap * (*s / cap).tanh();
670                                }
671                            }
672                        }
673                        crate::naive::softmax(&mut scores, s_q, s_k);
674                        // scores[Lq, Lk] @ V[Lk, head_dim] → out_head[Lq, head_dim]
675                        if s_q.max(s_k) <= 32 {
676                            out_head.fill(0.0);
677                            for qi in 0..s_q {
678                                for ki in 0..s_k {
679                                    let sc = scores[qi * s_k + ki];
680                                    if sc > 1e-8 {
681                                        let v_off = ki * head_dim;
682                                        let o_off = qi * head_dim;
683                                        #[cfg(target_arch = "aarch64")]
684                                        unsafe {
685                                            use std::arch::aarch64::*;
686                                            let vsc = vdupq_n_f32(sc);
687                                            let chunks = head_dim / 4;
688                                            for c in 0..chunks {
689                                                let off = c * 4;
690                                                let vo =
691                                                    vld1q_f32(out_head.as_ptr().add(o_off + off));
692                                                let vv =
693                                                    vld1q_f32(v_head.as_ptr().add(v_off + off));
694                                                vst1q_f32(
695                                                    out_head.as_mut_ptr().add(o_off + off),
696                                                    vfmaq_f32(vo, vsc, vv),
697                                                );
698                                            }
699                                        }
700                                        #[cfg(not(target_arch = "aarch64"))]
701                                        for d in 0..*head_dim {
702                                            out_head[o_off + d] += sc * v_head[v_off + d];
703                                        }
704                                    }
705                                }
706                            }
707                        } else {
708                            crate::blas::sgemm(
709                                &scores,
710                                &v_head,
711                                &mut out_head,
712                                s_q,
713                                s_k,
714                                *head_dim,
715                            );
716                        }
717                        // Scatter back into [B, Lq, hs].
718                        for si in 0..s_q {
719                            let off = bi * s_q * hs + si * hs + hi * head_dim;
720                            output[off..off + head_dim]
721                                .copy_from_slice(&out_head[si * head_dim..(si + 1) * head_dim]);
722                        }
723                    }
724                }
725            }
726
727            // ── Rotary position embedding ────────────────────────────
728            //
729            // Layout-aware position derivation:
730            //   - Packed rank-3 input `[B, S, H*D]` (heads in the last dim):
731            //     a head_dim-sized chunk index `c` maps to `(b, s, h)` with
732            //     `s = (c / H) % S`. We compute `s` explicitly and ignore
733            //     `cos.len()` for position derivation. Mirrors the MLX
734            //     `multi_head_packed` path in `rlx-mlx/src/lower.rs` so all
735            //     backends agree on per-chunk position.
736            //   - Rank-4 input `[B, H, S, D]` (heads as their own axis):
737            //     chunks per head are contiguous in `S`, so `s = c % S`.
738            //   - Single-head input `[B, S, D]`: same as rank-4 with H=1.
739            //   - Decode slice `cos.len() == half` (single position table):
740            //     `cos_len / tab_half = 1`, so `s = c % 1 = 0` is the right
741            //     value and matches the runtime-computed slice for absolute
742            //     position `past_seq`.
743            Op::Rope { head_dim, n_rot } => {
744                let head_dim = *head_dim;
745                let n_rot = *n_rot;
746                let x = get_data(arena, external, node.inputs[0]);
747                let cos_cache = get_data(arena, external, node.inputs[1]);
748                let sin_cache = get_data(arena, external, node.inputs[2]);
749                let x_shape = &graph.node(node.inputs[0]).shape;
750                let output = get_output(arena, node_id);
751                output.copy_from_slice(x);
752
753                let rot_half = n_rot / 2;
754                let tab_half = head_dim / 2;
755                let total = output.len();
756                let num_chunks = total / head_dim;
757
758                // Derive (s_dim, heads_per_seq) from the input shape so we can
759                // map chunk index → seq position without assuming layout.
760                let cos_rows = cos_cache.len() / tab_half.max(1);
761                let (s_dim, heads_per_seq): (usize, usize) = {
762                    let rank = x_shape.rank();
763                    if rank == 0 {
764                        (1, 1)
765                    } else {
766                        let last = if x_shape.dim(rank - 1).is_static() {
767                            x_shape.dim(rank - 1).unwrap_static()
768                        } else {
769                            head_dim
770                        };
771                        if rank >= 3 && last > head_dim && last.is_multiple_of(head_dim) {
772                            // Packed multi-head [..., S, H*D].
773                            let s = if x_shape.dim(rank - 2).is_static() {
774                                x_shape.dim(rank - 2).unwrap_static()
775                            } else {
776                                1
777                            };
778                            (s, last / head_dim)
779                        } else if rank >= 4 && last == head_dim {
780                            // [B, H, S, D] — heads on outer axis.
781                            let s = if x_shape.dim(rank - 2).is_static() {
782                                x_shape.dim(rank - 2).unwrap_static()
783                            } else {
784                                1
785                            };
786                            (s, 1)
787                        } else if rank >= 3 && last == head_dim {
788                            // [..., S, D] single head.
789                            let s = if x_shape.dim(rank - 2).is_static() {
790                                x_shape.dim(rank - 2).unwrap_static()
791                            } else {
792                                1
793                            };
794                            (s, 1)
795                        } else {
796                            // Fallback: rely on the cos-table-length heuristic.
797                            (cos_rows.max(1), 1)
798                        }
799                    }
800                };
801
802                if std::env::var("RLX_ROPE_DEBUG").is_ok() {
803                    eprintln!(
804                        "[rope] shape={:?} num_chunks={num_chunks} cos_rows={cos_rows} s_dim={s_dim} heads_per_seq={heads_per_seq}",
805                        x_shape.dims()
806                    );
807                }
808                for chunk in 0..num_chunks {
809                    let off = chunk * head_dim;
810                    // Position derivation:
811                    //   - Packed [B, S, H*D]: chunk = ((b*S)+s)*H + h, so
812                    //     s = (chunk / heads_per_seq) % s_dim.
813                    //   - [B, H, S, D] / [B, S, D]: chunks per seq run contig,
814                    //     s = chunk % s_dim.
815                    let pos = if heads_per_seq > 1 {
816                        (chunk / heads_per_seq) % s_dim
817                    } else {
818                        chunk % s_dim
819                    };
820                    // For the decode-slice case (cos has a single row), force
821                    // pos = 0 so we always index the supplied past_seq slice.
822                    let pos = if cos_rows == 1 {
823                        0
824                    } else {
825                        pos.min(cos_rows.saturating_sub(1))
826                    };
827                    if std::env::var("RLX_ROPE_DEBUG").is_ok() && chunk < 4 {
828                        eprintln!("[rope]   chunk={chunk} pos={pos}");
829                    }
830                    let cos_off = pos * tab_half;
831
832                    for i in 0..rot_half {
833                        let cos_v = cos_cache[cos_off + i];
834                        let sin_v = sin_cache[cos_off + i];
835                        let x1 = output[off + i];
836                        let x2 = output[off + rot_half + i];
837                        output[off + i] = x1 * cos_v - x2 * sin_v;
838                        output[off + rot_half + i] = x2 * cos_v + x1 * sin_v;
839                    }
840                    output[(n_rot + off)..(head_dim + off)]
841                        .copy_from_slice(&x[(n_rot + off)..(head_dim + off)]);
842                }
843            }
844
845            // ── Compare ─────────────────────────────────────────────
846            Op::Compare(cmp) => {
847                let lhs = get_data(arena, external, node.inputs[0]);
848                let rhs = get_data(arena, external, node.inputs[1]);
849                let output = get_output(arena, node_id);
850                let rhs_len = rhs.len();
851                for i in 0..output.len() {
852                    let a = lhs[i];
853                    let b = rhs[i % rhs_len];
854                    output[i] = if compare_op(*cmp, a, b) { 1.0 } else { 0.0 };
855                }
856            }
857
858            // ── Where (conditional select) ──────────────────────────
859            Op::Where => {
860                let cond = get_data(arena, external, node.inputs[0]);
861                let on_true = get_data(arena, external, node.inputs[1]);
862                let on_false = get_data(arena, external, node.inputs[2]);
863                let output = get_output(arena, node_id);
864                for i in 0..output.len() {
865                    output[i] = if cond[i] > 0.5 {
866                        on_true[i]
867                    } else {
868                        on_false[i]
869                    };
870                }
871            }
872
873            // ── Reduce ──────────────────────────────────────────────
874            Op::Reduce {
875                op: reduce_op,
876                axes,
877                keep_dim: _,
878            } => {
879                let input = get_data(arena, external, node.inputs[0]);
880                let output = get_output(arena, node_id);
881                output.fill(0.0);
882                // Simple: only handle single-axis reduction for now
883                if axes.len() == 1 {
884                    let in_shape = &graph.node(node.inputs[0]).shape;
885                    let axis = axes[0];
886                    let rank = in_shape.rank();
887                    let outer: usize = (0..axis)
888                        .map(|i| in_shape.dim(i).unwrap_static())
889                        .product::<usize>()
890                        .max(1);
891                    let axis_size = in_shape.dim(axis).unwrap_static();
892                    let inner: usize = (axis + 1..rank)
893                        .map(|i| in_shape.dim(i).unwrap_static())
894                        .product::<usize>()
895                        .max(1);
896
897                    match reduce_op {
898                        ReduceOp::Sum | ReduceOp::Mean => {
899                            for o in 0..outer {
900                                for i in 0..inner {
901                                    let mut acc = 0f32;
902                                    for a in 0..axis_size {
903                                        acc += input[o * axis_size * inner + a * inner + i];
904                                    }
905                                    if matches!(reduce_op, ReduceOp::Mean) {
906                                        acc /= axis_size as f32;
907                                    }
908                                    output[o * inner + i] = acc;
909                                }
910                            }
911                        }
912                        ReduceOp::Max => {
913                            output.fill(f32::NEG_INFINITY);
914                            for o in 0..outer {
915                                for i in 0..inner {
916                                    for a in 0..axis_size {
917                                        let v = input[o * axis_size * inner + a * inner + i];
918                                        let idx = o * inner + i;
919                                        if v > output[idx] {
920                                            output[idx] = v;
921                                        }
922                                    }
923                                }
924                            }
925                        }
926                        _ => {} // Min, Prod — TODO
927                    }
928                }
929            }
930
931            // ── Cast ────────────────────────────────────────────────
932            Op::Cast { .. } => {
933                let input = get_data(arena, external, node.inputs[0]);
934                let output = get_output(arena, node_id);
935                output[..input.len()].copy_from_slice(input);
936            }
937
938            // ── Fused SwiGLU ────────────────────────────────────────
939            // Input layout: concatenated [..., 2N] tensor where the first
940            // N elements per row are the "up" projection and the next N
941            // are the "gate" projection. Output: [..., N] = up * silu(gate).
942            // `cast_to` is currently advisory: rlx-cpu always operates in
943            // f32, so backends that distinguish dtypes apply the cast; the
944            // CPU executor stores the f32 result regardless.
945            Op::FusedSwiGLU { cast_to: _, .. } => {
946                let input = get_data(arena, external, node.inputs[0]);
947                let output = get_output(arena, node_id);
948                // n = last-dim half (read from the node's own shape, NOT
949                // derived from buffer lengths — those count total elements
950                // including all leading dims).
951                let n = node.shape.dim(node.shape.rank() - 1).unwrap_static();
952                let outer = output.len() / n;
953                debug_assert_eq!(
954                    outer * 2 * n,
955                    input.len(),
956                    "FusedSwiGLU: input/output shape mismatch"
957                );
958                for o in 0..outer {
959                    let in_row = &input[o * 2 * n..(o + 1) * 2 * n];
960                    let out_row = &mut output[o * n..(o + 1) * n];
961                    for i in 0..n {
962                        let up = in_row[i];
963                        let gate = in_row[n + i];
964                        let silu_gate = gate / (1.0 + (-gate).exp());
965                        out_row[i] = up * silu_gate;
966                    }
967                }
968            }
969
970            // ── DenseSolve: x = A⁻¹ b (F32 / F64 via LAPACK) ────────
971            Op::DenseSolve => {
972                let a_shape = &graph.node(node.inputs[0]).shape;
973                let n = a_shape.dim(0).unwrap_static();
974                let b_elems = node.shape.num_elements().unwrap();
975                let nrhs = b_elems / n.max(1);
976                match node.shape.dtype() {
977                    rlx_ir::DType::F32 => {
978                        let a = get_data(arena, external, node.inputs[0]);
979                        let b = get_data(arena, external, node.inputs[1]);
980                        let x = get_output(arena, node_id);
981                        let mut a_scratch = a.to_vec();
982                        let mut x_buf = b.to_vec();
983                        let info = crate::blas::sgesv(&mut a_scratch, &mut x_buf, n, nrhs);
984                        if info != 0 {
985                            panic!("DenseSolve: singular matrix (info={info})");
986                        }
987                        x[..x_buf.len()].copy_from_slice(&x_buf);
988                    }
989                    rlx_ir::DType::F64 => {
990                        let (a_ptr, a_len) = arena.raw_ptr(node.inputs[0]);
991                        let (b_ptr, b_len) = arena.raw_ptr(node.inputs[1]);
992                        let (x_ptr, x_len) = arena.raw_ptr(node_id);
993                        unsafe {
994                            let a_src = std::slice::from_raw_parts(a_ptr as *const f64, a_len / 8);
995                            let b_src = std::slice::from_raw_parts(b_ptr as *const f64, b_len / 8);
996                            let mut a_scratch = a_src.to_vec();
997                            let mut x_buf = b_src.to_vec();
998                            let info = crate::blas::dgesv(&mut a_scratch, &mut x_buf, n, nrhs);
999                            if info != 0 {
1000                                panic!("DenseSolve: singular matrix (info={info})");
1001                            }
1002                            std::slice::from_raw_parts_mut(x_ptr as *mut f64, x_len / 8)
1003                                .copy_from_slice(&x_buf);
1004                        }
1005                    }
1006                    other => panic!("DenseSolve executor: unsupported dtype {other:?}"),
1007                }
1008            }
1009
1010            // ── Passthrough for unimplemented ops ───────────────────
1011            _ => {
1012                if !node.inputs.is_empty() && arena.has_buffer(node_id) {
1013                    let input = get_data(arena, external, node.inputs[0]);
1014                    let output = get_output(arena, node_id);
1015                    let len = output.len().min(input.len());
1016                    output[..len].copy_from_slice(&input[..len]);
1017                }
1018            }
1019        }
1020    }
1021}
1022
1023/// Get read-only data for a node — from external or arena via raw pointer.
1024/// SAFETY: the memory planner guarantees that input buffers don't overlap with
1025/// the output buffer being written, so concurrent read+write is safe.
1026fn get_data<'a>(arena: &'a Arena, external: &'a ExternalBuffers, id: NodeId) -> &'a [f32] {
1027    // Check external first (test mode, or runtime inputs copied via run())
1028    // Then arena (params pre-stored by set_param, computed intermediates)
1029    if let Some(&ext) = external.buffers.get(&id) {
1030        ext
1031    } else if arena.has_buffer(id) {
1032        let (ptr, len) = arena.raw_ptr(id);
1033        unsafe { std::slice::from_raw_parts(ptr, len) }
1034    } else {
1035        panic!("no data for node {id}")
1036    }
1037}
1038
1039/// Get mutable output buffer via raw pointer (doesn't borrow arena).
1040/// Takes `&Arena` (not `&mut Arena`) on purpose — the executor walks
1041/// the schedule with multiple node-buffer references live at once;
1042/// the arena allocator already partitioned them into non-overlapping
1043/// regions at compile time.
1044#[allow(clippy::mut_from_ref)]
1045fn get_output(arena: &Arena, id: NodeId) -> &mut [f32] {
1046    let (ptr, len) = arena.raw_ptr(id);
1047    unsafe { std::slice::from_raw_parts_mut(ptr, len) }
1048}
1049
1050/// Matrix multiply — uses BLAS when linked, naive fallback otherwise.
1051#[inline]
1052fn matmul(a: &[f32], b: &[f32], c: &mut [f32], m: usize, k: usize, n: usize) {
1053    // Use BLAS sgemm (Accelerate/MKL/OpenBLAS) — linked via build.rs
1054    crate::blas::sgemm(a, b, c, m, k, n);
1055}
1056
1057fn binary_op(op: rlx_ir::op::BinaryOp, a: f32, b: f32) -> f32 {
1058    use rlx_ir::op::BinaryOp::*;
1059    match op {
1060        Add => a + b,
1061        Sub => a - b,
1062        Mul => a * b,
1063        Div => a / b,
1064        Max => a.max(b),
1065        Min => a.min(b),
1066        Pow => a.powf(b),
1067    }
1068}
1069
1070fn compare_op(op: rlx_ir::op::CmpOp, a: f32, b: f32) -> bool {
1071    use rlx_ir::op::CmpOp::*;
1072    match op {
1073        Eq => a == b,
1074        Ne => a != b,
1075        Lt => a < b,
1076        Le => a <= b,
1077        Gt => a > b,
1078        Ge => a >= b,
1079    }
1080}
1081
// Reference scalar GELU — kept as a parity oracle for SIMD paths.
//
// Computes the erf form of GELU, `0.5 * x * (1 + erf(x / sqrt(2)))`, with
// erf approximated by the five-term Abramowitz–Stegun polynomial
// (formula 7.1.26): for z >= 0,
//   erf(z) ≈ 1 - (a1*t + a2*t^2 + a3*t^3 + a4*t^4 + a5*t^5) * exp(-z^2),
//   t = 1 / (1 + p*z),
// extended to negative z by odd symmetry.
//
// The erf argument must be `x / sqrt(2)`, not `x`: feeding `x` directly
// yields `0.5 * x * (1 + erf(x))`, which is not GELU (≈1.4746 instead of
// ≈1.3998 at x = 1.5).
#[allow(dead_code)]
fn scalar_gelu(x: f32) -> f32 {
    // Argument of erf in the GELU definition.
    let z = x * std::f32::consts::FRAC_1_SQRT_2;
    // erf is odd: evaluate the polynomial on |z| and restore the sign.
    let sign = if z >= 0.0 { 1.0f32 } else { -1.0 };
    let za = z.abs();
    let t = 1.0 / (1.0 + 0.3275911 * za);
    // Horner evaluation of a1*t + a2*t^2 + ... + a5*t^5.
    let poly = t
        * (0.254_829_6
            + t * (-0.284_496_72 + t * (1.421_413_8 + t * (-1.453_152_1 + t * 1.061_405_4))));
    let erf = sign * (1.0 - poly * (-za * za).exp());
    x * 0.5 * (1.0 + erf)
}
1094
1095#[cfg(test)]
1096mod tests {
1097    use super::*;
1098    use rlx_ir::*;
1099
1100    use rlx_opt::fusion::FuseMatMulBiasAct;
1101    use rlx_opt::memory;
1102    use rlx_opt::pass::Pass;
1103
1104    /// End-to-end test: build graph → fuse → plan memory → execute.
1105    #[test]
1106    fn execute_fused_matmul_bias_gelu() {
1107        // Build graph: x @ w + b → gelu
1108        let mut g = Graph::new("test");
1109        let x_id = g.input("x", Shape::new(&[2, 4], DType::F32));
1110        let w_id = g.param("w", Shape::new(&[4, 3], DType::F32));
1111        let b_id = g.param("b", Shape::new(&[3], DType::F32));
1112        let mm = g.matmul(x_id, w_id, Shape::new(&[2, 3], DType::F32));
1113        let add = g.binary(BinaryOp::Add, mm, b_id, Shape::new(&[2, 3], DType::F32));
1114        let out = g.activation(Activation::Gelu, add, Shape::new(&[2, 3], DType::F32));
1115        g.set_outputs(vec![out]);
1116
1117        // Fuse
1118        let fused = FuseMatMulBiasAct.run(g);
1119        println!("{fused}");
1120
1121        // Plan memory
1122        let plan = memory::plan_memory(&fused);
1123        println!("Arena: {} bytes", plan.arena_size);
1124
1125        // Prepare data
1126        let x_data = vec![1.0f32, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0]; // [2, 4] identity-ish
1127        let w_data = vec![1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0]; // [4, 3]
1128        let b_data = vec![0.5, -0.5, 0.0]; // [3]
1129
1130        let mut ext = ExternalBuffers {
1131            buffers: HashMap::new(),
1132        };
1133        ext.buffers.insert(fused.outputs[0], &[]); // placeholder
1134        // Find input/param node IDs in fused graph
1135        for node in fused.nodes() {
1136            match &node.op {
1137                Op::Input { name } if name == "x" => {
1138                    ext.buffers.insert(node.id, &x_data);
1139                }
1140                Op::Param { name } if name == "w" => {
1141                    ext.buffers.insert(node.id, &w_data);
1142                }
1143                Op::Param { name } if name == "b" => {
1144                    ext.buffers.insert(node.id, &b_data);
1145                }
1146                _ => {}
1147            }
1148        }
1149
1150        // Execute
1151        let mut arena = Arena::from_plan(plan);
1152        execute(&fused, &mut arena, &ext);
1153
1154        // Check output
1155        let output_id = fused.outputs[0];
1156        let result = arena.slice(output_id);
1157        println!("Result: {result:?}");
1158
1159        // x @ w = [[1,0,0], [0,1,0]]; + bias = [[1.5,-0.5,0], [0.5,0.5,0]]
1160        // gelu(1.5) ≈ 1.399, gelu(-0.5) ≈ -0.154, gelu(0) = 0
1161        // gelu(0.5) ≈ 0.346
1162        assert!((result[0] - 1.399).abs() < 0.01, "got {}", result[0]);
1163        assert!((result[1] - -0.154).abs() < 0.01, "got {}", result[1]);
1164        assert!((result[2] - 0.0).abs() < 0.01, "got {}", result[2]);
1165        assert!((result[3] - 0.346).abs() < 0.01, "got {}", result[3]);
1166    }
1167
1168    /// Test Gather (embedding lookup).
1169    #[test]
1170    fn execute_gather() {
1171        use rlx_ir::infer::GraphExt;
1172        let mut g = Graph::new("gather_test");
1173        // Embedding table [4, 3] and indices [2] → output [2, 3]
1174        let table = g.param("table", Shape::new(&[4, 3], DType::F32));
1175        let indices = g.input("ids", Shape::new(&[2], DType::F32)); // f32 indices
1176        let out = g.gather_(table, indices, 0);
1177        g.set_outputs(vec![out]);
1178
1179        let plan = memory::plan_memory(&g);
1180        let mut arena = Arena::from_plan(plan);
1181
1182        let table_data = vec![
1183            10.0, 11.0, 12.0, // row 0
1184            20.0, 21.0, 22.0, // row 1
1185            30.0, 31.0, 32.0, // row 2
1186            40.0, 41.0, 42.0, // row 3
1187        ];
1188        let ids_data = vec![2.0, 0.0]; // gather rows 2 and 0
1189
1190        let mut ext = ExternalBuffers {
1191            buffers: HashMap::new(),
1192        };
1193        for node in g.nodes() {
1194            match &node.op {
1195                Op::Param { name } if name == "table" => {
1196                    ext.buffers.insert(node.id, &table_data);
1197                }
1198                Op::Input { name } if name == "ids" => {
1199                    ext.buffers.insert(node.id, &ids_data);
1200                }
1201                _ => {}
1202            }
1203        }
1204
1205        execute(&g, &mut arena, &ext);
1206        let result = arena.slice(g.outputs[0]);
1207        assert_eq!(&result[..3], &[30.0, 31.0, 32.0]); // row 2
1208        assert_eq!(&result[3..6], &[10.0, 11.0, 12.0]); // row 0
1209    }
1210
1211    /// Test Narrow (slice).
1212    #[test]
1213    fn execute_narrow() {
1214        use rlx_ir::infer::GraphExt;
1215        let mut g = Graph::new("narrow_test");
1216        let x = g.input("x", Shape::new(&[2, 6], DType::F32));
1217        let sliced = g.narrow_(x, 1, 2, 3); // take cols 2..5
1218        g.set_outputs(vec![sliced]);
1219
1220        let plan = memory::plan_memory(&g);
1221        let mut arena = Arena::from_plan(plan);
1222
1223        let data = vec![0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0];
1224        let mut ext = ExternalBuffers {
1225            buffers: HashMap::new(),
1226        };
1227        for node in g.nodes() {
1228            if let Op::Input { .. } = &node.op {
1229                ext.buffers.insert(node.id, &data);
1230            }
1231        }
1232
1233        execute(&g, &mut arena, &ext);
1234        let result = arena.slice(g.outputs[0]);
1235        assert_eq!(result, &[2.0, 3.0, 4.0, 8.0, 9.0, 10.0]);
1236    }
1237
1238    /// Test Softmax.
1239    #[test]
1240    fn execute_softmax() {
1241        use rlx_ir::infer::GraphExt;
1242        let mut g = Graph::new("softmax_test");
1243        let x = g.input("x", Shape::new(&[1, 4], DType::F32));
1244        let sm = g.sm(x, -1);
1245        g.set_outputs(vec![sm]);
1246
1247        let plan = memory::plan_memory(&g);
1248        let mut arena = Arena::from_plan(plan);
1249
1250        let data = vec![1.0, 2.0, 3.0, 4.0];
1251        let mut ext = ExternalBuffers {
1252            buffers: HashMap::new(),
1253        };
1254        for node in g.nodes() {
1255            if let Op::Input { .. } = &node.op {
1256                ext.buffers.insert(node.id, &data);
1257            }
1258        }
1259
1260        execute(&g, &mut arena, &ext);
1261        let result = arena.slice(g.outputs[0]);
1262        let sum: f32 = result.iter().sum();
1263        assert!(
1264            (sum - 1.0).abs() < 1e-5,
1265            "softmax should sum to 1, got {sum}"
1266        );
1267        // Values should be monotonically increasing
1268        assert!(result[0] < result[1]);
1269        assert!(result[1] < result[2]);
1270        assert!(result[2] < result[3]);
1271    }
1272
1273    /// Test RoPE (rotary position embedding).
1274    #[test]
1275    fn execute_rope() {
1276        use rlx_ir::infer::GraphExt;
1277        let head_dim = 4;
1278        let half = head_dim / 2;
1279        let seq = 2;
1280
1281        let mut g = Graph::new("rope_test");
1282        // x: [seq, head_dim], cos: [seq, half], sin: [seq, half]
1283        let x = g.input("x", Shape::new(&[seq, head_dim], DType::F32));
1284        let cos = g.param("cos", Shape::new(&[seq, half], DType::F32));
1285        let sin = g.param("sin", Shape::new(&[seq, half], DType::F32));
1286        let rotated = g.rope(x, cos, sin, head_dim);
1287        g.set_outputs(vec![rotated]);
1288
1289        let plan = memory::plan_memory(&g);
1290        let mut arena = Arena::from_plan(plan);
1291
1292        // x = [[1, 0, 0, 1], [1, 1, 0, 0]] (2 positions, head_dim=4)
1293        let x_data = vec![1.0, 0.0, 0.0, 1.0, 1.0, 1.0, 0.0, 0.0f32];
1294        // cos = [[1, 0], [0, 1]], sin = [[0, 1], [1, 0]] (identity-ish rotation)
1295        let cos_data = vec![1.0, 0.0, 0.0, 1.0f32];
1296        let sin_data = vec![0.0, 1.0, 1.0, 0.0f32];
1297
1298        let mut ext = ExternalBuffers {
1299            buffers: HashMap::new(),
1300        };
1301        for node in g.nodes() {
1302            match &node.op {
1303                Op::Input { name } if name == "x" => {
1304                    ext.buffers.insert(node.id, &x_data);
1305                }
1306                Op::Param { name } if name == "cos" => {
1307                    ext.buffers.insert(node.id, &cos_data);
1308                }
1309                Op::Param { name } if name == "sin" => {
1310                    ext.buffers.insert(node.id, &sin_data);
1311                }
1312                _ => {}
1313            }
1314        }
1315
1316        execute(&g, &mut arena, &ext);
1317        let result = arena.slice(g.outputs[0]);
1318
1319        // Position 0: cos=[1,0], sin=[0,1]
1320        //   x1=1, x2=0 → x1*cos[0]-x2*sin[0] = 1*1-0*0 = 1
1321        //   x1=0, x2=1 → same half: x2*cos[0]+x1*sin[0] = 0*1+1*0 → wait
1322        // Actually: for i=0: x[0]=1, x[half+0]=0 → out[0]=1*1-0*0=1, out[2]=0*1+1*0=0
1323        //           for i=1: x[1]=0, x[half+1]=1 → out[1]=0*0-1*1=-1, out[3]=1*0+0*1=0
1324        assert!((result[0] - 1.0).abs() < 1e-5, "pos0[0]={}", result[0]);
1325        assert!((result[1] - -1.0).abs() < 1e-5, "pos0[1]={}", result[1]);
1326        assert!((result[2] - 0.0).abs() < 1e-5, "pos0[2]={}", result[2]);
1327        assert!((result[3] - 0.0).abs() < 1e-5, "pos0[3]={}", result[3]);
1328
1329        // Position 1: cos=[0,1], sin=[1,0]
1330        //   x=[1,1,0,0]: for i=0: 1*0-0*1=-0=0, out[half+0]=0*0+1*1=1
1331        //                 for i=1: 1*1-0*0=1, out[half+1]=0*1+1*0=0
1332        assert!((result[4] - 0.0).abs() < 1e-5, "pos1[0]={}", result[4]);
1333        assert!((result[5] - 1.0).abs() < 1e-5, "pos1[1]={}", result[5]);
1334        assert!((result[6] - 1.0).abs() < 1e-5, "pos1[2]={}", result[6]);
1335        assert!((result[7] - 0.0).abs() < 1e-5, "pos1[3]={}", result[7]);
1336    }
1337
1338    /// Test LayerNorm standalone.
1339    #[test]
1340    fn execute_layer_norm() {
1341        use rlx_ir::infer::GraphExt;
1342        let mut g = Graph::new("ln_test");
1343        let x = g.input("x", Shape::new(&[1, 4], DType::F32));
1344        let gamma = g.param("g", Shape::new(&[4], DType::F32));
1345        let beta = g.param("b", Shape::new(&[4], DType::F32));
1346        let ln = g.ln(x, gamma, beta, 1e-5);
1347        g.set_outputs(vec![ln]);
1348
1349        let plan = memory::plan_memory(&g);
1350        let mut arena = Arena::from_plan(plan);
1351
1352        let x_data = vec![1.0, 2.0, 3.0, 4.0];
1353        let g_data = vec![1.0, 1.0, 1.0, 1.0];
1354        let b_data = vec![0.0, 0.0, 0.0, 0.0];
1355
1356        let mut ext = ExternalBuffers {
1357            buffers: HashMap::new(),
1358        };
1359        for node in g.nodes() {
1360            match &node.op {
1361                Op::Input { name } if name == "x" => {
1362                    ext.buffers.insert(node.id, &x_data);
1363                }
1364                Op::Param { name } if name == "g" => {
1365                    ext.buffers.insert(node.id, &g_data);
1366                }
1367                Op::Param { name } if name == "b" => {
1368                    ext.buffers.insert(node.id, &b_data);
1369                }
1370                _ => {}
1371            }
1372        }
1373
1374        execute(&g, &mut arena, &ext);
1375        let result = arena.slice(g.outputs[0]);
1376        let sum: f32 = result.iter().sum();
1377        assert!(
1378            sum.abs() < 1e-3,
1379            "LN output should be zero-centered, sum={sum}"
1380        );
1381    }
1382}