Skip to main content

rlx_oneapi/
host.rs

// RLX — versatile ML compiler + runtime.
// Copyright (C) 2026 Eugene Hauptmann, Nataliya Kosmyna.
//
// SPDX-License-Identifier: GPL-3.0-only
5
//! Single-op CPU reference evaluation via `rlx-cpu`'s thunk executor (the same
//! kernels the CPU backend uses, so results are bit-for-bit the reference).
//!
//! Two callers:
//! - On a host with no Level Zero device (the macOS dev box / CI), `backend.rs`
//!   walks the whole legalized graph through this evaluator — every op is served
//!   by the CPU reference, so the backend is fully correct without Intel HW.
//! - On Intel hardware, ops with no native SPIR-V kernel yet route here (read
//!   from the USM arena, eval, write back) — exactly rlx-vulkan's host-fallback.
15
16use rlx_ir::{Graph, Op, Shape};
17
/// One host-eval input buffer.
///
/// - [`HostBuf::F32`] carries f32 activations.
/// - [`HostBuf::Bytes`] carries raw bytes for a packed quant weight (U8/I8
///   operands such as the GGUF weight of `Op::DequantMatMul`).
///
/// Derives `Debug`, `Clone` and `PartialEq` so callers can log, duplicate and
/// compare buffers (e.g. in tests) without matching on the variants by hand.
#[derive(Debug, Clone, PartialEq)]
pub enum HostBuf {
    /// f32 values.
    F32(Vec<f32>),
    /// Raw byte payload.
    Bytes(Vec<u8>),
}
24
25/// Run a single op on the CPU reference and return its f32 output.
26/// `inputs[i]` is `(declared_shape, buffer)`.
27pub fn eval(op: &Op, out_shape: &Shape, inputs: &[(Shape, HostBuf)]) -> Vec<f32> {
28    let mut g = Graph::new("oneapi_host_eval");
29    let ids: Vec<rlx_ir::NodeId> = inputs
30        .iter()
31        .enumerate()
32        .map(|(i, (sh, _))| {
33            g.append_node(
34                Op::Input {
35                    name: format!("in{i}"),
36                },
37                vec![],
38                sh.clone(),
39                None,
40            )
41        })
42        .collect();
43    let out = g.append_node(op.clone(), ids.clone(), out_shape.clone(), None);
44    g.set_outputs(vec![out]);
45
46    let plan = rlx_compile::memory::plan_memory_aligned(&g, 16);
47    let mut arena = rlx_cpu::arena::Arena::from_plan(plan);
48
49    for (i, (_, buf)) in inputs.iter().enumerate() {
50        match buf {
51            HostBuf::F32(vals) => {
52                let slot = arena.slice_mut(ids[i]);
53                let n = slot.len().min(vals.len());
54                slot[..n].copy_from_slice(&vals[..n]);
55            }
56            HostBuf::Bytes(bytes) => {
57                let off = arena.byte_offset(ids[i]);
58                let raw = arena.raw_buf_mut();
59                let n = bytes.len().min(raw.len().saturating_sub(off));
60                raw[off..off + n].copy_from_slice(&bytes[..n]);
61            }
62        }
63    }
64
65    let schedule = rlx_cpu::thunk::compile_thunks(&g, &arena);
66    rlx_cpu::thunk::execute_thunks(&schedule, arena.raw_buf_mut());
67
68    let n = out_shape.num_elements().unwrap_or(0);
69    arena.slice_mut(out)[..n].to_vec()
70}