// rlx_cpu/executor.rs

1// RLX — versatile ML compiler + runtime.
2// Copyright (C) 2026 Eugene Hauptmann, Nataliya Kosmyna.
3//
4// This program is free software: you can redistribute it and/or modify
5// it under the terms of the GNU General Public License as published by
6// the Free Software Foundation, version 3.
7//
8// This program is distributed in the hope that it will be useful,
9// but WITHOUT ANY WARRANTY; without even the implied warranty of
10// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
11// GNU General Public License for more details.
12//
13// You should have received a copy of the GNU General Public License
14// along with this program. If not, see <https://www.gnu.org/licenses/>.
15
16//! Graph executor — runs a fused IR graph on CPU using the arena + kernels.
17//!
18//! The executor is the runtime hot path. For a 6-layer BERT, it makes
19//! ~24 kernel calls total (one per fused node). Everything else is
20//! inside the kernels — SIMD, BLAS, pre-allocated arena buffers.
21
22use crate::arena::Arena;
23use crate::kernels;
24use rlx_ir::op::{Activation, BinaryOp, ReduceOp};
25use rlx_ir::{Graph, NodeId, Op};
26use std::collections::HashMap;
27
28/// External data provided at runtime (model weights + inputs).
29pub struct ExternalBuffers<'a> {
30    /// Map from node ID (Input/Param nodes) to external f32 data.
31    pub buffers: HashMap<NodeId, &'a [f32]>,
32}
33
/// Execute a compiled graph on CPU.
///
/// The graph should already be fused and memory-planned.
/// `arena` holds all intermediate buffers and supplies the execution schedule.
/// `external` provides input data and model weights.
///
/// Nothing is returned: each computed node's result is written into that
/// node's output buffer in the arena.
41pub fn execute(graph: &Graph, arena: &mut Arena, external: &ExternalBuffers) {
42    let schedule: Vec<NodeId> = arena.schedule().to_vec();
43    for &node_id in &schedule {
44        let node = graph.node(node_id);
45
46        match &node.op {
47            // External data — skip (data provided via get_data which reads
48            // from external buffers directly without copying to arena)
49            Op::Input { .. } | Op::Param { .. } | Op::Constant { .. } => {}
50
51            // ── Fused matmul + bias + optional activation ───────────
52            Op::FusedMatMulBiasAct { activation } => {
53                let input_id = node.inputs[0];
54                let weight_id = node.inputs[1];
55                let bias_id = node.inputs[2];
56
57                let input = get_data(arena, external, input_id);
58                let weight = get_data(arena, external, weight_id);
59                let bias = get_data(arena, external, bias_id);
60                let output = get_output(arena, node_id);
61
62                // Compute output shape for sgemm
63                let shape = &node.shape;
64                let n = shape.dim(shape.rank() - 1).unwrap_static();
65                let m = shape.num_elements().unwrap() / n;
66                let k = input.len() / m;
67
68                // sgemm: output = input @ weight
69                // TODO: call cblas_sgemm via FFI
70                // For now, naive matmul as placeholder
71                matmul(input, weight, output, m, k, n);
72
73                // Fused bias + activation (parallel NEON kernels)
74                match activation {
75                    Some(Activation::Gelu) => kernels::par_bias_gelu(output, bias, m, n),
76                    Some(Activation::Silu) => {
77                        crate::blas::bias_add(output, bias, m, n);
78                        kernels::silu_inplace(output);
79                    }
80                    _ => crate::blas::bias_add(output, bias, m, n),
81                }
82            }
83
84            // ── Fused residual + LayerNorm (parallel NEON) ──────────
85            Op::FusedResidualLN { has_bias, eps } => {
86                let x_id = node.inputs[0];
87                let residual_id = node.inputs[1];
88                let h = node.shape.dim(node.shape.rank() - 1).unwrap_static();
89                let zero_bias = vec![0f32; h];
90                let (gamma_id, beta_id, bias_slice) = if *has_bias {
91                    let b = get_data(arena, external, node.inputs[2]);
92                    (node.inputs[3], node.inputs[4], b)
93                } else {
94                    (node.inputs[2], node.inputs[3], zero_bias.as_slice())
95                };
96
97                let x = get_data(arena, external, x_id);
98                let residual = get_data(arena, external, residual_id);
99                let gamma = get_data(arena, external, gamma_id);
100                let beta = get_data(arena, external, beta_id);
101                let output = get_output(arena, node_id);
102
103                let n = x.len() / h;
104
105                // Parallel: each thread processes a chunk of rows
106                let x_ptr = x.as_ptr() as usize;
107                let r_ptr = residual.as_ptr() as usize;
108                let o_ptr = output.as_mut_ptr() as usize;
109                let bi_ptr = bias_slice.as_ptr() as usize;
110                let g_ptr = gamma.as_ptr() as usize;
111                let b_ptr = beta.as_ptr() as usize;
112                let e = *eps;
113                crate::pool::par_for(n, 4, &|off, cnt| unsafe {
114                    let x_s =
115                        std::slice::from_raw_parts((x_ptr as *const f32).add(off * h), cnt * h);
116                    let r_s =
117                        std::slice::from_raw_parts((r_ptr as *const f32).add(off * h), cnt * h);
118                    let o_s =
119                        std::slice::from_raw_parts_mut((o_ptr as *mut f32).add(off * h), cnt * h);
120                    let bi = std::slice::from_raw_parts(bi_ptr as *const f32, h);
121                    let g = std::slice::from_raw_parts(g_ptr as *const f32, h);
122                    let b = std::slice::from_raw_parts(b_ptr as *const f32, h);
123                    kernels::residual_bias_layer_norm(x_s, r_s, bi, g, b, o_s, cnt, h, e);
124                });
125            }
126
127            // ── Fused residual + RMSNorm (parallel) ─────────────────
128            Op::FusedResidualRmsNorm { has_bias, eps } => {
129                let x_id = node.inputs[0];
130                let residual_id = node.inputs[1];
131                let h = node.shape.dim(node.shape.rank() - 1).unwrap_static();
132                let zero_bias = vec![0f32; h];
133                let (gamma_id, beta_id, bias_slice) = if *has_bias {
134                    let b = get_data(arena, external, node.inputs[2]);
135                    (node.inputs[3], node.inputs[4], b)
136                } else {
137                    (node.inputs[2], node.inputs[3], zero_bias.as_slice())
138                };
139
140                let x = get_data(arena, external, x_id);
141                let residual = get_data(arena, external, residual_id);
142                let gamma = get_data(arena, external, gamma_id);
143                let beta = get_data(arena, external, beta_id);
144                let output = get_output(arena, node_id);
145
146                let n = x.len() / h;
147
148                let x_ptr = x.as_ptr() as usize;
149                let r_ptr = residual.as_ptr() as usize;
150                let o_ptr = output.as_mut_ptr() as usize;
151                let bi_ptr = bias_slice.as_ptr() as usize;
152                let g_ptr = gamma.as_ptr() as usize;
153                let b_ptr = beta.as_ptr() as usize;
154                let e = *eps;
155                crate::pool::par_for(n, 4, &|off, cnt| unsafe {
156                    let x_s =
157                        std::slice::from_raw_parts((x_ptr as *const f32).add(off * h), cnt * h);
158                    let r_s =
159                        std::slice::from_raw_parts((r_ptr as *const f32).add(off * h), cnt * h);
160                    let o_s =
161                        std::slice::from_raw_parts_mut((o_ptr as *mut f32).add(off * h), cnt * h);
162                    let bi = std::slice::from_raw_parts(bi_ptr as *const f32, h);
163                    let g = std::slice::from_raw_parts(g_ptr as *const f32, h);
164                    let b = std::slice::from_raw_parts(b_ptr as *const f32, h);
165                    kernels::residual_bias_rms_norm(x_s, r_s, bi, g, b, o_s, cnt, h, e);
166                });
167            }
168
169            // ── Plain matmul ────────────────────────────────────────
170            Op::MatMul => {
171                let lhs = get_data(arena, external, node.inputs[0]);
172                let rhs = get_data(arena, external, node.inputs[1]);
173                let output = get_output(arena, node_id);
174
175                let shape = &node.shape;
176                let lhs_shape = &graph.node(node.inputs[0]).shape;
177                let rhs_shape = &graph.node(node.inputs[1]).shape;
178                let n = shape.dim(shape.rank() - 1).unwrap_static();
179                let out_m_inner = shape.dim(shape.rank() - 2).unwrap_static();
180                let k = lhs_shape.dim(lhs_shape.rank() - 1).unwrap_static();
181
182                // Outer batch dims — present when either input has rank > 2.
183                // Compute total batch as output.num_elements / (M * N).
184                let total = shape.num_elements().unwrap();
185                let per_batch_out = out_m_inner * n;
186                let batches = total / per_batch_out;
187
188                if batches == 1 {
189                    matmul(lhs, rhs, output, out_m_inner, k, n);
190                } else {
191                    let lhs_batched =
192                        lhs_shape.num_elements().unwrap_or(0) == batches * out_m_inner * k;
193                    let rhs_batched = rhs_shape.num_elements().unwrap_or(0) == batches * k * n;
194                    for b in 0..batches {
195                        let l_off = if lhs_batched { b * out_m_inner * k } else { 0 };
196                        let r_off = if rhs_batched { b * k * n } else { 0 };
197                        let o_off = b * out_m_inner * n;
198                        let l_slice = &lhs[l_off..l_off + out_m_inner * k];
199                        let r_slice = &rhs[r_off..r_off + k * n];
200                        let o_slice = &mut output[o_off..o_off + out_m_inner * n];
201                        matmul(l_slice, r_slice, o_slice, out_m_inner, k, n);
202                    }
203                }
204            }
205
206            // ── Element-wise binary ─────────────────────────────────
207            Op::Binary(op) => {
208                let lhs = get_data(arena, external, node.inputs[0]);
209                let rhs = get_data(arena, external, node.inputs[1]);
210                let output = get_output(arena, node_id);
211                let len = output.len();
212                let rhs_len = rhs.len();
213
214                // Fast path: Add with broadcast bias → NEON bias_add
215                if matches!(op, BinaryOp::Add) && rhs_len < len && len.is_multiple_of(rhs_len) {
216                    output.copy_from_slice(lhs);
217                    crate::blas::bias_add(output, rhs, len / rhs_len, rhs_len);
218                } else if rhs_len == len {
219                    for i in 0..len {
220                        output[i] = binary_op(*op, lhs[i], rhs[i]);
221                    }
222                } else {
223                    for i in 0..len {
224                        output[i] = binary_op(*op, lhs[i], rhs[i % rhs_len]);
225                    }
226                }
227            }
228
229            // ── Unary activation ────────────────────────────────────
230            Op::Activation(act) => {
231                let input = get_data(arena, external, node.inputs[0]);
232                let output = get_output(arena, node_id);
233                output.copy_from_slice(input);
234                let zeros = vec![0f32; node.shape.dim(node.shape.rank() - 1).unwrap_static()];
235                let m = output.len() / zeros.len();
236                let n = zeros.len();
237                match act {
238                    Activation::Gelu => kernels::par_bias_gelu(output, &zeros, m, n),
239                    Activation::Silu => kernels::silu_inplace(output),
240                    Activation::Relu => {
241                        for v in output.iter_mut() {
242                            *v = v.max(0.0);
243                        }
244                    }
245                    Activation::Exp => {
246                        for v in output.iter_mut() {
247                            *v = v.exp();
248                        }
249                    }
250                    Activation::Sqrt => {
251                        for v in output.iter_mut() {
252                            *v = v.sqrt();
253                        }
254                    }
255                    Activation::Neg => {
256                        for v in output.iter_mut() {
257                            *v = -*v;
258                        }
259                    }
260                    Activation::Tanh => {
261                        for v in output.iter_mut() {
262                            *v = v.tanh();
263                        }
264                    }
265                    Activation::Sigmoid => {
266                        for v in output.iter_mut() {
267                            *v = 1.0 / (1.0 + (-*v).exp());
268                        }
269                    }
270                    _ => {}
271                }
272            }
273
274            // ── Gather (embedding lookup) ───────────────────────────
275            Op::Gather { axis } => {
276                let table = get_data(arena, external, node.inputs[0]);
277                let indices = get_data(arena, external, node.inputs[1]);
278                let output = get_output(arena, node_id);
279
280                let table_shape = &graph.node(node.inputs[0]).shape;
281                let _out_shape = &node.shape;
282
283                // For axis=0 (embedding): table[V, D...], indices[B, S] → [B, S, D...]
284                if *axis == 0 {
285                    let trailing: usize = (1..table_shape.rank())
286                        .map(|i| table_shape.dim(i).unwrap_static())
287                        .product();
288                    for (i, &idx_f32) in indices.iter().enumerate() {
289                        let idx = idx_f32 as usize;
290                        let src = idx * trailing;
291                        let dst = i * trailing;
292                        output[dst..dst + trailing].copy_from_slice(&table[src..src + trailing]);
293                    }
294                } else {
295                    // General gather — fallback
296                    output.fill(0.0);
297                }
298            }
299
300            // ── Narrow (slice along axis) ───────────────────────────
301            Op::Narrow { axis, start, len } => {
302                let input = get_data(arena, external, node.inputs[0]);
303                let output = get_output(arena, node_id);
304                let in_shape = &graph.node(node.inputs[0]).shape;
305
306                let rank = in_shape.rank();
307                let outer: usize = (0..*axis)
308                    .map(|i| in_shape.dim(i).unwrap_static())
309                    .product::<usize>()
310                    .max(1);
311                let inner: usize = (*axis + 1..rank)
312                    .map(|i| in_shape.dim(i).unwrap_static())
313                    .product::<usize>()
314                    .max(1);
315                let in_axis_size = in_shape.dim(*axis).unwrap_static();
316
317                for o in 0..outer {
318                    for s in 0..*len {
319                        let src_off = o * in_axis_size * inner + (*start + s) * inner;
320                        let dst_off = o * len * inner + s * inner;
321                        output[dst_off..dst_off + inner]
322                            .copy_from_slice(&input[src_off..src_off + inner]);
323                    }
324                }
325            }
326
327            // ── Transpose ───────────────────────────────────────────
328            Op::Transpose { perm } => {
329                let input = get_data(arena, external, node.inputs[0]);
330                let output = get_output(arena, node_id);
331                let in_shape = &graph.node(node.inputs[0]).shape;
332                let rank = in_shape.rank();
333
334                let in_dims: Vec<usize> =
335                    (0..rank).map(|i| in_shape.dim(i).unwrap_static()).collect();
336                let out_dims: Vec<usize> = perm.iter().map(|&i| in_dims[i]).collect();
337
338                // Row-major strides for input and output spaces.
339                // For a shape [d0, d1, ..., d_{r-1}], stride[i] = product(d_{i+1..r}).
340                let mut in_strides = vec![1usize; rank];
341                for i in (0..rank - 1).rev() {
342                    in_strides[i] = in_strides[i + 1] * in_dims[i + 1];
343                }
344                let mut out_strides = vec![1usize; rank];
345                for i in (0..rank - 1).rev() {
346                    out_strides[i] = out_strides[i + 1] * out_dims[i + 1];
347                }
348
349                let total = output.len();
350                for flat_out in 0..total {
351                    let mut in_flat = 0;
352                    for d in 0..rank {
353                        // out_coord[d] decoded from flat_out via output strides.
354                        let coord = (flat_out / out_strides[d]) % out_dims[d];
355                        // Output dim d came from input dim perm[d].
356                        in_flat += coord * in_strides[perm[d]];
357                    }
358                    output[flat_out] = input[in_flat];
359                }
360            }
361
362            // ── Concat ──────────────────────────────────────────────
363            Op::Concat { axis } => {
364                let output = get_output(arena, node_id);
365                let out_shape = &node.shape;
366                let rank = out_shape.rank();
367
368                let outer: usize = (0..*axis)
369                    .map(|i| out_shape.dim(i).unwrap_static())
370                    .product::<usize>()
371                    .max(1);
372                let inner: usize = (*axis + 1..rank)
373                    .map(|i| out_shape.dim(i).unwrap_static())
374                    .product::<usize>()
375                    .max(1);
376
377                let mut dst_off = 0;
378                for o in 0..outer {
379                    for &inp_id in &node.inputs {
380                        let inp = get_data(arena, external, inp_id);
381                        let inp_shape = &graph.node(inp_id).shape;
382                        let inp_axis = inp_shape.dim(*axis).unwrap_static();
383                        let chunk = inp_axis * inner;
384                        let src_off = o * chunk;
385                        output[dst_off..dst_off + chunk]
386                            .copy_from_slice(&inp[src_off..src_off + chunk]);
387                        dst_off += chunk;
388                    }
389                }
390            }
391
392            // ── Reshape (zero-copy: same data, different shape) ─────
393            Op::Reshape { .. } | Op::Expand { .. } => {
394                let input = get_data(arena, external, node.inputs[0]);
395                let output = get_output(arena, node_id);
396                output[..input.len()].copy_from_slice(input);
397            }
398
399            // ── LayerNorm (parallel NEON) ────────────────────────────
400            Op::LayerNorm { eps, .. } => {
401                let input = get_data(arena, external, node.inputs[0]);
402                let gamma = get_data(arena, external, node.inputs[1]);
403                let beta = get_data(arena, external, node.inputs[2]);
404                let output = get_output(arena, node_id);
405                let h = node.shape.dim(node.shape.rank() - 1).unwrap_static();
406                let n = input.len() / h;
407                for row in 0..n {
408                    let base = row * h;
409                    kernels::layer_norm_row(
410                        &input[base..base + h],
411                        gamma,
412                        beta,
413                        &mut output[base..base + h],
414                        h,
415                        *eps,
416                    );
417                }
418            }
419
420            Op::GroupNorm { num_groups, eps } => {
421                let input = get_data(arena, external, node.inputs[0]);
422                let gamma = get_data(arena, external, node.inputs[1]);
423                let beta = get_data(arena, external, node.inputs[2]);
424                let output = get_output(arena, node_id);
425                let n = node.shape.dim(0).unwrap_static();
426                let c = node.shape.dim(1).unwrap_static();
427                let h = node.shape.dim(2).unwrap_static();
428                let w = node.shape.dim(3).unwrap_static();
429                kernels::group_norm_nchw(input, gamma, beta, output, n, c, h, w, *num_groups, *eps);
430            }
431
432            Op::ResizeNearest2x => {
433                let input = get_data(arena, external, node.inputs[0]);
434                let output = get_output(arena, node_id);
435                let n = node.shape.dim(0).unwrap_static();
436                let c = node.shape.dim(1).unwrap_static();
437                let h = node.shape.dim(2).unwrap_static() / 2;
438                let w = node.shape.dim(3).unwrap_static() / 2;
439                let in_plane = c * h * w;
440                let out_plane = c * h * 2 * w * 2;
441                for ni in 0..n {
442                    kernels::resize_nearest_2x_nchw(
443                        &input[ni * in_plane..(ni + 1) * in_plane],
444                        &mut output[ni * out_plane..(ni + 1) * out_plane],
445                        c,
446                        h,
447                        w,
448                    );
449                }
450            }
451
452            Op::AxialRope2d {
453                end_x,
454                end_y,
455                head_dim,
456                num_heads,
457                theta,
458                repeat_factor,
459            } => {
460                let input = get_data(arena, external, node.inputs[0]);
461                let output = get_output(arena, node_id);
462                let batch = node.shape.dim(0).unwrap_static();
463                let seq = node.shape.dim(1).unwrap_static();
464                let plane = seq * node.shape.dim(2).unwrap_static();
465                for bi in 0..batch {
466                    let rotated = rlx_ir::ops::axial_rope2d::apply_axial_rope2d(
467                        &input[bi * plane..(bi + 1) * plane],
468                        *num_heads,
469                        seq,
470                        *head_dim,
471                        *end_x,
472                        *end_y,
473                        *theta,
474                        *repeat_factor,
475                    );
476                    output[bi * plane..(bi + 1) * plane].copy_from_slice(&rotated);
477                }
478            }
479
480            // ── Softmax ─────────────────────────────────────────────
481            Op::Softmax { axis } => {
482                let input = get_data(arena, external, node.inputs[0]);
483                let output = get_output(arena, node_id);
484                output.copy_from_slice(input);
485                let rank = node.shape.rank();
486                let ax = if *axis < 0 {
487                    (rank as i32 + axis) as usize
488                } else {
489                    *axis as usize
490                };
491                let cols = node.shape.dim(ax).unwrap_static();
492                let rows = output.len() / cols;
493                crate::naive::softmax(output, rows, cols);
494            }
495
496            // ── Attention (SDPA) — BLAS-accelerated ─────────────────
497            Op::Attention {
498                num_heads,
499                head_dim,
500                mask_kind,
501            } => {
502                let q = get_data(arena, external, node.inputs[0]);
503                let k = get_data(arena, external, node.inputs[1]);
504                let v = get_data(arena, external, node.inputs[2]);
505                // For non-Custom mask kinds the IR emits no mask input —
506                // synthesize an empty slice so the masking branch below
507                // sees `mask.len() < ...` and skips.
508                let mask: &[f32] = if matches!(
509                    mask_kind,
510                    rlx_ir::op::MaskKind::Custom | rlx_ir::op::MaskKind::Bias
511                ) {
512                    get_data(arena, external, node.inputs[3])
513                } else {
514                    &[]
515                };
516                let output = get_output(arena, node_id);
517
518                let q_shape = &graph.node(node.inputs[0]).shape;
519                let k_shape = &graph.node(node.inputs[1]).shape;
520                let hs = num_heads * head_dim;
521                let scale = (*head_dim as f32).powf(-0.5);
522                let (batch_size, s_q) = if q_shape.rank() >= 3 {
523                    (
524                        q_shape.dim(0).unwrap_static(),
525                        q_shape.dim(1).unwrap_static(),
526                    )
527                } else {
528                    (1, q_shape.dim(0).unwrap_static())
529                };
530                // K and V share Lk. In decode mode Lk = past+1 and Lq = 1;
531                // in prefill Lq = Lk. Causal/SlidingWindow masking is
532                // expressed in absolute positions: Q-row qi is at absolute
533                // position (Lk - Lq) + qi, so masking shifts accordingly.
534                let s_k = if k_shape.rank() >= 3 {
535                    k_shape.dim(1).unwrap_static()
536                } else {
537                    k_shape.dim(0).unwrap_static()
538                };
539                let q_offset = s_k.saturating_sub(s_q);
540
541                // Pre-allocate buffers ONCE (reused across heads)
542                let q_buf_len = s_q * head_dim;
543                let k_buf_len = s_k * head_dim;
544                let mut q_head = vec![0f32; q_buf_len];
545                let mut k_head = vec![0f32; k_buf_len];
546                let mut v_head = vec![0f32; k_buf_len];
547                let mut scores = vec![0f32; s_q * s_k];
548                let mut out_head = vec![0f32; q_buf_len];
549
550                for bi in 0..batch_size {
551                    for hi in 0..*num_heads {
552                        // Gather per-head Q (Lq rows).
553                        for si in 0..s_q {
554                            let off = bi * s_q * hs + si * hs + hi * head_dim;
555                            q_head[si * head_dim..(si + 1) * head_dim]
556                                .copy_from_slice(&q[off..off + head_dim]);
557                        }
558                        // Gather per-head K, V (Lk rows).
559                        for si in 0..s_k {
560                            let off = bi * s_k * hs + si * hs + hi * head_dim;
561                            k_head[si * head_dim..(si + 1) * head_dim]
562                                .copy_from_slice(&k[off..off + head_dim]);
563                            v_head[si * head_dim..(si + 1) * head_dim]
564                                .copy_from_slice(&v[off..off + head_dim]);
565                        }
566                        // Q@K^T: scores[Lq, Lk]. Use NEON dots when the
567                        // larger of Lq/Lk is small; BLAS otherwise.
568                        if s_q.max(s_k) <= 32 {
569                            for qi in 0..s_q {
570                                for ki in 0..s_k {
571                                    let q_off = qi * head_dim;
572                                    let k_off = ki * head_dim;
573                                    #[cfg(target_arch = "aarch64")]
574                                    let mut dot;
575                                    #[cfg(not(target_arch = "aarch64"))]
576                                    let mut dot = 0f32;
577                                    #[cfg(target_arch = "aarch64")]
578                                    unsafe {
579                                        use std::arch::aarch64::*;
580                                        let chunks = head_dim / 4;
581                                        let mut acc = vdupq_n_f32(0.0);
582                                        for c in 0..chunks {
583                                            let vq = vld1q_f32(q_head.as_ptr().add(q_off + c * 4));
584                                            let vk = vld1q_f32(k_head.as_ptr().add(k_off + c * 4));
585                                            acc = vfmaq_f32(acc, vq, vk);
586                                        }
587                                        dot = vaddvq_f32(acc);
588                                        for d in (chunks * 4)..*head_dim {
589                                            dot += q_head[q_off + d] * k_head[k_off + d];
590                                        }
591                                    }
592                                    #[cfg(not(target_arch = "aarch64"))]
593                                    {
594                                        for d in 0..*head_dim {
595                                            dot += q_head[q_off + d] * k_head[k_off + d];
596                                        }
597                                    }
598                                    scores[qi * s_k + ki] = dot * scale;
599                                }
600                            }
601                        } else {
602                            crate::blas::sgemm_bt(
603                                &q_head,
604                                &k_head,
605                                &mut scores,
606                                s_q,
607                                *head_dim,
608                                s_k,
609                                scale,
610                            );
611                        }
612                        // Mask: branch on kind so None / Causal skip the
613                        // mask load entirely. Causal/SlidingWindow use
614                        // absolute positions so they handle Lq != Lk
615                        // (decode-mode with cached K/V).
616                        match mask_kind {
617                            rlx_ir::op::MaskKind::None => {}
618                            rlx_ir::op::MaskKind::Causal => {
619                                for qi in 0..s_q {
620                                    let abs_q = q_offset + qi;
621                                    for ki in (abs_q + 1)..s_k {
622                                        scores[qi * s_k + ki] = -1e9;
623                                    }
624                                }
625                            }
626                            rlx_ir::op::MaskKind::SlidingWindow(w) => {
627                                for qi in 0..s_q {
628                                    let abs_q = q_offset + qi;
629                                    let lo = abs_q.saturating_sub(*w);
630                                    for ki in 0..s_k {
631                                        if ki < lo || ki > abs_q {
632                                            scores[qi * s_k + ki] = -1e9;
633                                        }
634                                    }
635                                }
636                            }
637                            rlx_ir::op::MaskKind::Custom => {
638                                if mask.len() >= (bi + 1) * s_k {
639                                    let m = &mask[bi * s_k..(bi + 1) * s_k];
640                                    for qi in 0..s_q {
641                                        for ki in 0..s_k {
642                                            if m[ki] < 0.5 {
643                                                scores[qi * s_k + ki] = -1e9;
644                                            }
645                                        }
646                                    }
647                                }
648                            }
649                            rlx_ir::op::MaskKind::Bias => {
650                                // Bias is [batch, num_heads, s_q, s_k]
651                                // (additive, pre-softmax). Skip if the
652                                // buffer wasn't supplied.
653                                let per_bh = s_q * s_k;
654                                let need = (bi * *num_heads + hi + 1) * per_bh;
655                                if mask.len() >= need {
656                                    let bias_off = (bi * *num_heads + hi) * per_bh;
657                                    let b = &mask[bias_off..bias_off + per_bh];
658                                    for i in 0..per_bh {
659                                        scores[i] += b[i];
660                                    }
661                                }
662                            }
663                        }
664                        crate::naive::softmax(&mut scores, s_q, s_k);
665                        // scores[Lq, Lk] @ V[Lk, head_dim] → out_head[Lq, head_dim]
666                        if s_q.max(s_k) <= 32 {
667                            out_head.fill(0.0);
668                            for qi in 0..s_q {
669                                for ki in 0..s_k {
670                                    let sc = scores[qi * s_k + ki];
671                                    if sc > 1e-8 {
672                                        let v_off = ki * head_dim;
673                                        let o_off = qi * head_dim;
674                                        #[cfg(target_arch = "aarch64")]
675                                        unsafe {
676                                            use std::arch::aarch64::*;
677                                            let vsc = vdupq_n_f32(sc);
678                                            let chunks = head_dim / 4;
679                                            for c in 0..chunks {
680                                                let off = c * 4;
681                                                let vo =
682                                                    vld1q_f32(out_head.as_ptr().add(o_off + off));
683                                                let vv =
684                                                    vld1q_f32(v_head.as_ptr().add(v_off + off));
685                                                vst1q_f32(
686                                                    out_head.as_mut_ptr().add(o_off + off),
687                                                    vfmaq_f32(vo, vsc, vv),
688                                                );
689                                            }
690                                        }
691                                        #[cfg(not(target_arch = "aarch64"))]
692                                        for d in 0..*head_dim {
693                                            out_head[o_off + d] += sc * v_head[v_off + d];
694                                        }
695                                    }
696                                }
697                            }
698                        } else {
699                            crate::blas::sgemm(
700                                &scores,
701                                &v_head,
702                                &mut out_head,
703                                s_q,
704                                s_k,
705                                *head_dim,
706                            );
707                        }
708                        // Scatter back into [B, Lq, hs].
709                        for si in 0..s_q {
710                            let off = bi * s_q * hs + si * hs + hi * head_dim;
711                            output[off..off + head_dim]
712                                .copy_from_slice(&out_head[si * head_dim..(si + 1) * head_dim]);
713                        }
714                    }
715                }
716            }
717
718            // ── Rotary position embedding ────────────────────────────
719            Op::Rope { head_dim, n_rot } => {
720                let head_dim = *head_dim;
721                let n_rot = *n_rot;
722                let x = get_data(arena, external, node.inputs[0]);
723                let cos_cache = get_data(arena, external, node.inputs[1]);
724                let sin_cache = get_data(arena, external, node.inputs[2]);
725                let output = get_output(arena, node_id);
726                output.copy_from_slice(x);
727
728                let rot_half = n_rot / 2;
729                let tab_half = head_dim / 2;
730                let total = output.len();
731                let num_chunks = total / head_dim;
732                for chunk in 0..num_chunks {
733                    let off = chunk * head_dim;
734                    let cos_len = cos_cache.len();
735                    let max_seq = cos_len / tab_half;
736                    let pos = chunk % max_seq;
737                    let cos_off = pos * tab_half;
738
739                    for i in 0..rot_half {
740                        let cos_v = cos_cache[cos_off + i];
741                        let sin_v = sin_cache[cos_off + i];
742                        let x1 = output[off + i];
743                        let x2 = output[off + rot_half + i];
744                        output[off + i] = x1 * cos_v - x2 * sin_v;
745                        output[off + rot_half + i] = x2 * cos_v + x1 * sin_v;
746                    }
747                    output[(n_rot + off)..(head_dim + off)]
748                        .copy_from_slice(&x[(n_rot + off)..(head_dim + off)]);
749                }
750            }
751
752            // ── Compare ─────────────────────────────────────────────
753            Op::Compare(cmp) => {
754                let lhs = get_data(arena, external, node.inputs[0]);
755                let rhs = get_data(arena, external, node.inputs[1]);
756                let output = get_output(arena, node_id);
757                let rhs_len = rhs.len();
758                for i in 0..output.len() {
759                    let a = lhs[i];
760                    let b = rhs[i % rhs_len];
761                    output[i] = if compare_op(*cmp, a, b) { 1.0 } else { 0.0 };
762                }
763            }
764
765            // ── Where (conditional select) ──────────────────────────
766            Op::Where => {
767                let cond = get_data(arena, external, node.inputs[0]);
768                let on_true = get_data(arena, external, node.inputs[1]);
769                let on_false = get_data(arena, external, node.inputs[2]);
770                let output = get_output(arena, node_id);
771                for i in 0..output.len() {
772                    output[i] = if cond[i] > 0.5 {
773                        on_true[i]
774                    } else {
775                        on_false[i]
776                    };
777                }
778            }
779
780            // ── Reduce ──────────────────────────────────────────────
781            Op::Reduce {
782                op: reduce_op,
783                axes,
784                keep_dim: _,
785            } => {
786                let input = get_data(arena, external, node.inputs[0]);
787                let output = get_output(arena, node_id);
788                output.fill(0.0);
789                // Simple: only handle single-axis reduction for now
790                if axes.len() == 1 {
791                    let in_shape = &graph.node(node.inputs[0]).shape;
792                    let axis = axes[0];
793                    let rank = in_shape.rank();
794                    let outer: usize = (0..axis)
795                        .map(|i| in_shape.dim(i).unwrap_static())
796                        .product::<usize>()
797                        .max(1);
798                    let axis_size = in_shape.dim(axis).unwrap_static();
799                    let inner: usize = (axis + 1..rank)
800                        .map(|i| in_shape.dim(i).unwrap_static())
801                        .product::<usize>()
802                        .max(1);
803
804                    match reduce_op {
805                        ReduceOp::Sum | ReduceOp::Mean => {
806                            for o in 0..outer {
807                                for i in 0..inner {
808                                    let mut acc = 0f32;
809                                    for a in 0..axis_size {
810                                        acc += input[o * axis_size * inner + a * inner + i];
811                                    }
812                                    if matches!(reduce_op, ReduceOp::Mean) {
813                                        acc /= axis_size as f32;
814                                    }
815                                    output[o * inner + i] = acc;
816                                }
817                            }
818                        }
819                        ReduceOp::Max => {
820                            output.fill(f32::NEG_INFINITY);
821                            for o in 0..outer {
822                                for i in 0..inner {
823                                    for a in 0..axis_size {
824                                        let v = input[o * axis_size * inner + a * inner + i];
825                                        let idx = o * inner + i;
826                                        if v > output[idx] {
827                                            output[idx] = v;
828                                        }
829                                    }
830                                }
831                            }
832                        }
833                        _ => {} // Min, Prod — TODO
834                    }
835                }
836            }
837
838            // ── Cast ────────────────────────────────────────────────
839            Op::Cast { .. } => {
840                let input = get_data(arena, external, node.inputs[0]);
841                let output = get_output(arena, node_id);
842                output[..input.len()].copy_from_slice(input);
843            }
844
845            // ── Fused SwiGLU ────────────────────────────────────────
846            // Input layout: concatenated [..., 2N] tensor where the first
847            // N elements per row are the "up" projection and the next N
848            // are the "gate" projection. Output: [..., N] = up * silu(gate).
849            // `cast_to` is currently advisory: rlx-cpu always operates in
850            // f32, so backends that distinguish dtypes apply the cast; the
851            // CPU executor stores the f32 result regardless.
852            Op::FusedSwiGLU { cast_to: _, .. } => {
853                let input = get_data(arena, external, node.inputs[0]);
854                let output = get_output(arena, node_id);
855                // n = last-dim half (read from the node's own shape, NOT
856                // derived from buffer lengths — those count total elements
857                // including all leading dims).
858                let n = node.shape.dim(node.shape.rank() - 1).unwrap_static();
859                let outer = output.len() / n;
860                debug_assert_eq!(
861                    outer * 2 * n,
862                    input.len(),
863                    "FusedSwiGLU: input/output shape mismatch"
864                );
865                for o in 0..outer {
866                    let in_row = &input[o * 2 * n..(o + 1) * 2 * n];
867                    let out_row = &mut output[o * n..(o + 1) * n];
868                    for i in 0..n {
869                        let up = in_row[i];
870                        let gate = in_row[n + i];
871                        let silu_gate = gate / (1.0 + (-gate).exp());
872                        out_row[i] = up * silu_gate;
873                    }
874                }
875            }
876
877            // ── DenseSolve: x = A⁻¹ b (F32 / F64 via LAPACK) ────────
878            Op::DenseSolve => {
879                let a_shape = &graph.node(node.inputs[0]).shape;
880                let n = a_shape.dim(0).unwrap_static();
881                let b_elems = node.shape.num_elements().unwrap();
882                let nrhs = b_elems / n.max(1);
883                match node.shape.dtype() {
884                    rlx_ir::DType::F32 => {
885                        let a = get_data(arena, external, node.inputs[0]);
886                        let b = get_data(arena, external, node.inputs[1]);
887                        let x = get_output(arena, node_id);
888                        let mut a_scratch = a.to_vec();
889                        let mut x_buf = b.to_vec();
890                        let info = crate::blas::sgesv(&mut a_scratch, &mut x_buf, n, nrhs);
891                        if info != 0 {
892                            panic!("DenseSolve: singular matrix (info={info})");
893                        }
894                        x[..x_buf.len()].copy_from_slice(&x_buf);
895                    }
896                    rlx_ir::DType::F64 => {
897                        let (a_ptr, a_len) = arena.raw_ptr(node.inputs[0]);
898                        let (b_ptr, b_len) = arena.raw_ptr(node.inputs[1]);
899                        let (x_ptr, x_len) = arena.raw_ptr(node_id);
900                        unsafe {
901                            let a_src = std::slice::from_raw_parts(a_ptr as *const f64, a_len / 8);
902                            let b_src = std::slice::from_raw_parts(b_ptr as *const f64, b_len / 8);
903                            let mut a_scratch = a_src.to_vec();
904                            let mut x_buf = b_src.to_vec();
905                            let info = crate::blas::dgesv(&mut a_scratch, &mut x_buf, n, nrhs);
906                            if info != 0 {
907                                panic!("DenseSolve: singular matrix (info={info})");
908                            }
909                            std::slice::from_raw_parts_mut(x_ptr as *mut f64, x_len / 8)
910                                .copy_from_slice(&x_buf);
911                        }
912                    }
913                    other => panic!("DenseSolve executor: unsupported dtype {other:?}"),
914                }
915            }
916
917            // ── Passthrough for unimplemented ops ───────────────────
918            _ => {
919                if !node.inputs.is_empty() && arena.has_buffer(node_id) {
920                    let input = get_data(arena, external, node.inputs[0]);
921                    let output = get_output(arena, node_id);
922                    let len = output.len().min(input.len());
923                    output[..len].copy_from_slice(&input[..len]);
924                }
925            }
926        }
927    }
928}
929
930/// Get read-only data for a node — from external or arena via raw pointer.
931/// SAFETY: the memory planner guarantees that input buffers don't overlap with
932/// the output buffer being written, so concurrent read+write is safe.
933fn get_data<'a>(arena: &'a Arena, external: &'a ExternalBuffers, id: NodeId) -> &'a [f32] {
934    // Check external first (test mode, or runtime inputs copied via run())
935    // Then arena (params pre-stored by set_param, computed intermediates)
936    if let Some(&ext) = external.buffers.get(&id) {
937        ext
938    } else if arena.has_buffer(id) {
939        let (ptr, len) = arena.raw_ptr(id);
940        unsafe { std::slice::from_raw_parts(ptr, len) }
941    } else {
942        panic!("no data for node {id}")
943    }
944}
945
946/// Get mutable output buffer via raw pointer (doesn't borrow arena).
947/// Takes `&Arena` (not `&mut Arena`) on purpose — the executor walks
948/// the schedule with multiple node-buffer references live at once;
949/// the arena allocator already partitioned them into non-overlapping
950/// regions at compile time.
951#[allow(clippy::mut_from_ref)]
952fn get_output(arena: &Arena, id: NodeId) -> &mut [f32] {
953    let (ptr, len) = arena.raw_ptr(id);
954    unsafe { std::slice::from_raw_parts_mut(ptr, len) }
955}
956
957/// Matrix multiply — uses BLAS when linked, naive fallback otherwise.
958#[inline]
959fn matmul(a: &[f32], b: &[f32], c: &mut [f32], m: usize, k: usize, n: usize) {
960    // Use BLAS sgemm (Accelerate/MKL/OpenBLAS) — linked via build.rs
961    crate::blas::sgemm(a, b, c, m, k, n);
962}
963
964fn binary_op(op: rlx_ir::op::BinaryOp, a: f32, b: f32) -> f32 {
965    use rlx_ir::op::BinaryOp::*;
966    match op {
967        Add => a + b,
968        Sub => a - b,
969        Mul => a * b,
970        Div => a / b,
971        Max => a.max(b),
972        Min => a.min(b),
973        Pow => a.powf(b),
974    }
975}
976
977fn compare_op(op: rlx_ir::op::CmpOp, a: f32, b: f32) -> bool {
978    use rlx_ir::op::CmpOp::*;
979    match op {
980        Eq => a == b,
981        Ne => a != b,
982        Lt => a < b,
983        Le => a <= b,
984        Gt => a > b,
985        Ge => a >= b,
986    }
987}
988
/// Reference scalar GELU, `x · Φ(x) = 0.5 · x · (1 + erf(x / √2))`, kept as a
/// parity oracle for SIMD paths.
///
/// `erf` is evaluated with the Abramowitz & Stegun 7.1.26 rational
/// approximation (`p = 0.3275911` and the five polynomial coefficients
/// below). It is applied to `|z|` and the sign is restored afterwards, where
/// `z = x / √2`.
#[allow(dead_code)]
fn scalar_gelu(x: f32) -> f32 {
    // GELU's Gaussian CDF is Φ(x) = ½(1 + erf(x/√2)). Feeding x straight into
    // erf would give ≈1.475 at x = 1.5 instead of GELU's ≈1.400.
    let z = x * std::f32::consts::FRAC_1_SQRT_2;
    let sign = if z >= 0.0 { 1.0f32 } else { -1.0 };
    let za = z.abs();
    let t = 1.0 / (1.0 + 0.3275911 * za);
    // Horner evaluation of a1·t + a2·t² + a3·t³ + a4·t⁴ + a5·t⁵.
    let y = t
        * (0.254_829_6
            + t * (-0.284_496_72 + t * (1.421_413_8 + t * (-1.453_152_1 + t * 1.061_405_4))));
    let erf = sign * (1.0 - y * (-za * za).exp());
    x * 0.5 * (1.0 + erf)
}
1001
// CPU executor tests. Each test builds a small graph, plans its memory with
// `rlx_opt::memory::plan_memory`, binds Input/Param data through
// `ExternalBuffers`, runs `execute`, and checks the output slice read back
// from the arena. The first test additionally runs the MatMul+Bias+Act
// fusion pass before planning.
#[cfg(test)]
mod tests {
    use super::*;
    use rlx_ir::*;

    use rlx_opt::fusion::FuseMatMulBiasAct;
    use rlx_opt::memory;
    use rlx_opt::pass::Pass;

    /// End-to-end test: build graph → fuse → plan memory → execute.
    #[test]
    fn execute_fused_matmul_bias_gelu() {
        // Build graph: x @ w + b → gelu
        let mut g = Graph::new("test");
        let x_id = g.input("x", Shape::new(&[2, 4], DType::F32));
        let w_id = g.param("w", Shape::new(&[4, 3], DType::F32));
        let b_id = g.param("b", Shape::new(&[3], DType::F32));
        let mm = g.matmul(x_id, w_id, Shape::new(&[2, 3], DType::F32));
        let add = g.binary(BinaryOp::Add, mm, b_id, Shape::new(&[2, 3], DType::F32));
        let out = g.activation(Activation::Gelu, add, Shape::new(&[2, 3], DType::F32));
        g.set_outputs(vec![out]);

        // Fuse
        let fused = FuseMatMulBiasAct.run(g);
        println!("{fused}");

        // Plan memory
        let plan = memory::plan_memory(&fused);
        println!("Arena: {} bytes", plan.arena_size);

        // Prepare data
        let x_data = vec![1.0f32, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0]; // [2, 4] identity-ish
        let w_data = vec![1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0]; // [4, 3]
        let b_data = vec![0.5, -0.5, 0.0]; // [3]

        let mut ext = ExternalBuffers {
            buffers: HashMap::new(),
        };
        // NOTE(review): this registers an *empty* external slice for the
        // output node. `get_data` consults `external` before the arena, so any
        // node that read this output as an input would see an empty slice.
        // Confirm whether the executor still needs this placeholder.
        ext.buffers.insert(fused.outputs[0], &[]); // placeholder
        // Find input/param node IDs in fused graph
        for node in fused.nodes() {
            match &node.op {
                Op::Input { name } if name == "x" => {
                    ext.buffers.insert(node.id, &x_data);
                }
                Op::Param { name } if name == "w" => {
                    ext.buffers.insert(node.id, &w_data);
                }
                Op::Param { name } if name == "b" => {
                    ext.buffers.insert(node.id, &b_data);
                }
                _ => {}
            }
        }

        // Execute
        let mut arena = Arena::from_plan(plan);
        execute(&fused, &mut arena, &ext);

        // Check output
        let output_id = fused.outputs[0];
        let result = arena.slice(output_id);
        println!("Result: {result:?}");

        // x @ w = [[1,0,0], [0,1,0]]; + bias = [[1.5,-0.5,0], [0.5,0.5,0]]
        // gelu(1.5) ≈ 1.399, gelu(-0.5) ≈ -0.154, gelu(0) = 0
        // gelu(0.5) ≈ 0.346
        // (true GELU, x·Φ(x); only the first four outputs are checked)
        assert!((result[0] - 1.399).abs() < 0.01, "got {}", result[0]);
        assert!((result[1] - -0.154).abs() < 0.01, "got {}", result[1]);
        assert!((result[2] - 0.0).abs() < 0.01, "got {}", result[2]);
        assert!((result[3] - 0.346).abs() < 0.01, "got {}", result[3]);
    }

    /// Test Gather (embedding lookup). Indices are supplied as an f32 tensor.
    #[test]
    fn execute_gather() {
        use rlx_ir::infer::GraphExt;
        let mut g = Graph::new("gather_test");
        // Embedding table [4, 3] and indices [2] → output [2, 3]
        let table = g.param("table", Shape::new(&[4, 3], DType::F32));
        let indices = g.input("ids", Shape::new(&[2], DType::F32)); // f32 indices
        let out = g.gather_(table, indices, 0);
        g.set_outputs(vec![out]);

        let plan = memory::plan_memory(&g);
        let mut arena = Arena::from_plan(plan);

        let table_data = vec![
            10.0, 11.0, 12.0, // row 0
            20.0, 21.0, 22.0, // row 1
            30.0, 31.0, 32.0, // row 2
            40.0, 41.0, 42.0, // row 3
        ];
        let ids_data = vec![2.0, 0.0]; // gather rows 2 and 0

        let mut ext = ExternalBuffers {
            buffers: HashMap::new(),
        };
        for node in g.nodes() {
            match &node.op {
                Op::Param { name } if name == "table" => {
                    ext.buffers.insert(node.id, &table_data);
                }
                Op::Input { name } if name == "ids" => {
                    ext.buffers.insert(node.id, &ids_data);
                }
                _ => {}
            }
        }

        execute(&g, &mut arena, &ext);
        let result = arena.slice(g.outputs[0]);
        assert_eq!(&result[..3], &[30.0, 31.0, 32.0]); // row 2
        assert_eq!(&result[3..6], &[10.0, 11.0, 12.0]); // row 0
    }

    /// Test Narrow (slice).
    #[test]
    fn execute_narrow() {
        use rlx_ir::infer::GraphExt;
        let mut g = Graph::new("narrow_test");
        let x = g.input("x", Shape::new(&[2, 6], DType::F32));
        let sliced = g.narrow_(x, 1, 2, 3); // take cols 2..5
        g.set_outputs(vec![sliced]);

        let plan = memory::plan_memory(&g);
        let mut arena = Arena::from_plan(plan);

        // Row-major [2, 6]: row 0 = 0..=5, row 1 = 6..=11.
        let data = vec![0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0];
        let mut ext = ExternalBuffers {
            buffers: HashMap::new(),
        };
        for node in g.nodes() {
            if let Op::Input { .. } = &node.op {
                ext.buffers.insert(node.id, &data);
            }
        }

        execute(&g, &mut arena, &ext);
        let result = arena.slice(g.outputs[0]);
        assert_eq!(result, &[2.0, 3.0, 4.0, 8.0, 9.0, 10.0]);
    }

    /// Test Softmax.
    #[test]
    fn execute_softmax() {
        use rlx_ir::infer::GraphExt;
        let mut g = Graph::new("softmax_test");
        let x = g.input("x", Shape::new(&[1, 4], DType::F32));
        let sm = g.sm(x, -1);
        g.set_outputs(vec![sm]);

        let plan = memory::plan_memory(&g);
        let mut arena = Arena::from_plan(plan);

        let data = vec![1.0, 2.0, 3.0, 4.0];
        let mut ext = ExternalBuffers {
            buffers: HashMap::new(),
        };
        for node in g.nodes() {
            if let Op::Input { .. } = &node.op {
                ext.buffers.insert(node.id, &data);
            }
        }

        execute(&g, &mut arena, &ext);
        let result = arena.slice(g.outputs[0]);
        let sum: f32 = result.iter().sum();
        assert!(
            (sum - 1.0).abs() < 1e-5,
            "softmax should sum to 1, got {sum}"
        );
        // Values should be monotonically increasing
        assert!(result[0] < result[1]);
        assert!(result[1] < result[2]);
        assert!(result[2] < result[3]);
    }

    /// Test RoPE (rotary position embedding).
    ///
    /// Each position rotates the pairs `(x[i], x[half + i])` by the angle
    /// whose cosine/sine are `cos[pos, i]` / `sin[pos, i]`.
    #[test]
    fn execute_rope() {
        use rlx_ir::infer::GraphExt;
        let head_dim = 4;
        let half = head_dim / 2;
        let seq = 2;

        let mut g = Graph::new("rope_test");
        // x: [seq, head_dim], cos: [seq, half], sin: [seq, half]
        let x = g.input("x", Shape::new(&[seq, head_dim], DType::F32));
        let cos = g.param("cos", Shape::new(&[seq, half], DType::F32));
        let sin = g.param("sin", Shape::new(&[seq, half], DType::F32));
        // NOTE(review): the expected values below require every dimension to
        // be rotated, i.e. this presumably builds a Rope with
        // n_rot == head_dim. Confirm against `GraphExt::rope`.
        let rotated = g.rope(x, cos, sin, head_dim);
        g.set_outputs(vec![rotated]);

        let plan = memory::plan_memory(&g);
        let mut arena = Arena::from_plan(plan);

        // x = [[1, 0, 0, 1], [1, 1, 0, 0]] (2 positions, head_dim=4)
        let x_data = vec![1.0, 0.0, 0.0, 1.0, 1.0, 1.0, 0.0, 0.0f32];
        // cos = [[1, 0], [0, 1]], sin = [[0, 1], [1, 0]]: per-pair angles are
        // (0°, 90°) at position 0 and (90°, 0°) at position 1.
        let cos_data = vec![1.0, 0.0, 0.0, 1.0f32];
        let sin_data = vec![0.0, 1.0, 1.0, 0.0f32];

        let mut ext = ExternalBuffers {
            buffers: HashMap::new(),
        };
        for node in g.nodes() {
            match &node.op {
                Op::Input { name } if name == "x" => {
                    ext.buffers.insert(node.id, &x_data);
                }
                Op::Param { name } if name == "cos" => {
                    ext.buffers.insert(node.id, &cos_data);
                }
                Op::Param { name } if name == "sin" => {
                    ext.buffers.insert(node.id, &sin_data);
                }
                _ => {}
            }
        }

        execute(&g, &mut arena, &ext);
        let result = arena.slice(g.outputs[0]);

        // Position 0: x=[1,0,0,1], cos=[1,0], sin=[0,1]; out1 = x1*cos - x2*sin,
        // out2 = x2*cos + x1*sin with (x1, x2) = (x[i], x[half+i]):
        //   i=0: x1=1, x2=0 → out[0]=1*1-0*0=1,  out[2]=0*1+1*0=0
        //   i=1: x1=0, x2=1 → out[1]=0*0-1*1=-1, out[3]=1*0+0*1=0
        assert!((result[0] - 1.0).abs() < 1e-5, "pos0[0]={}", result[0]);
        assert!((result[1] - -1.0).abs() < 1e-5, "pos0[1]={}", result[1]);
        assert!((result[2] - 0.0).abs() < 1e-5, "pos0[2]={}", result[2]);
        assert!((result[3] - 0.0).abs() < 1e-5, "pos0[3]={}", result[3]);

        // Position 1: x=[1,1,0,0], cos=[0,1], sin=[1,0]:
        //   i=0: x1=1, x2=0 → out[4]=1*0-0*1=0, out[6]=0*0+1*1=1
        //   i=1: x1=1, x2=0 → out[5]=1*1-0*0=1, out[7]=0*1+1*0=0
        assert!((result[4] - 0.0).abs() < 1e-5, "pos1[0]={}", result[4]);
        assert!((result[5] - 1.0).abs() < 1e-5, "pos1[1]={}", result[5]);
        assert!((result[6] - 1.0).abs() < 1e-5, "pos1[2]={}", result[6]);
        assert!((result[7] - 0.0).abs() < 1e-5, "pos1[3]={}", result[7]);
    }

    /// Test LayerNorm standalone. With gamma = 1 and beta = 0 the normalized
    /// row must sum to (approximately) zero.
    #[test]
    fn execute_layer_norm() {
        use rlx_ir::infer::GraphExt;
        let mut g = Graph::new("ln_test");
        let x = g.input("x", Shape::new(&[1, 4], DType::F32));
        let gamma = g.param("g", Shape::new(&[4], DType::F32));
        let beta = g.param("b", Shape::new(&[4], DType::F32));
        let ln = g.ln(x, gamma, beta, 1e-5);
        g.set_outputs(vec![ln]);

        let plan = memory::plan_memory(&g);
        let mut arena = Arena::from_plan(plan);

        let x_data = vec![1.0, 2.0, 3.0, 4.0];
        let g_data = vec![1.0, 1.0, 1.0, 1.0];
        let b_data = vec![0.0, 0.0, 0.0, 0.0];

        let mut ext = ExternalBuffers {
            buffers: HashMap::new(),
        };
        for node in g.nodes() {
            match &node.op {
                Op::Input { name } if name == "x" => {
                    ext.buffers.insert(node.id, &x_data);
                }
                Op::Param { name } if name == "g" => {
                    ext.buffers.insert(node.id, &g_data);
                }
                Op::Param { name } if name == "b" => {
                    ext.buffers.insert(node.id, &b_data);
                }
                _ => {}
            }
        }

        execute(&g, &mut arena, &ext);
        let result = arena.slice(g.outputs[0]);
        let sum: f32 = result.iter().sum();
        assert!(
            sum.abs() < 1e-3,
            "LN output should be zero-centered, sum={sum}"
        );
    }
}