rlx_oneapi/host.rs
1// RLX — versatile ML compiler + runtime.
2// Copyright (C) 2026 Eugene Hauptmann, Nataliya Kosmyna.
3//
4// SPDX-License-Identifier: GPL-3.0-only
5
6//! Single-op CPU reference evaluation via `rlx-cpu`'s thunk executor (the same
7//! kernels the CPU backend uses, so results are bit-for-bit the reference).
8//!
9//! Two callers:
10//! - On a host with no Level Zero device (the macOS dev box / CI), `backend.rs`
11//! walks the whole legalized graph through this evaluator — every op is served
12//! by the CPU reference, so the backend is fully correct without Intel HW.
13//! - On Intel hardware, ops with no native SPIR-V kernel yet route here (read
14//! from the USM arena, eval, write back) — exactly rlx-vulkan's host-fallback.
15
16use rlx_ir::{Graph, Op, Shape};
17
/// One host-eval input: f32 activations, or raw bytes for a packed quant weight
/// (U8/I8 operands such as the GGUF weight of `Op::DequantMatMul`).
///
/// [`eval`] copies the payload into the arena slot of the matching input
/// node before the op runs.
#[derive(Debug, Clone, PartialEq)]
pub enum HostBuf {
    /// Values copied element-wise into the input's f32 arena slice.
    F32(Vec<f32>),
    /// Raw bytes copied verbatim starting at the input's arena byte offset.
    Bytes(Vec<u8>),
}
24
25/// Run a single op on the CPU reference and return its f32 output.
26/// `inputs[i]` is `(declared_shape, buffer)`.
27pub fn eval(op: &Op, out_shape: &Shape, inputs: &[(Shape, HostBuf)]) -> Vec<f32> {
28 let mut g = Graph::new("oneapi_host_eval");
29 let ids: Vec<rlx_ir::NodeId> = inputs
30 .iter()
31 .enumerate()
32 .map(|(i, (sh, _))| {
33 g.append_node(
34 Op::Input {
35 name: format!("in{i}"),
36 },
37 vec![],
38 sh.clone(),
39 None,
40 )
41 })
42 .collect();
43 let out = g.append_node(op.clone(), ids.clone(), out_shape.clone(), None);
44 g.set_outputs(vec![out]);
45
46 let plan = rlx_compile::memory::plan_memory_aligned(&g, 16);
47 let mut arena = rlx_cpu::arena::Arena::from_plan(plan);
48
49 for (i, (_, buf)) in inputs.iter().enumerate() {
50 match buf {
51 HostBuf::F32(vals) => {
52 let slot = arena.slice_mut(ids[i]);
53 let n = slot.len().min(vals.len());
54 slot[..n].copy_from_slice(&vals[..n]);
55 }
56 HostBuf::Bytes(bytes) => {
57 let off = arena.byte_offset(ids[i]);
58 let raw = arena.raw_buf_mut();
59 let n = bytes.len().min(raw.len().saturating_sub(off));
60 raw[off..off + n].copy_from_slice(&bytes[..n]);
61 }
62 }
63 }
64
65 let schedule = rlx_cpu::thunk::compile_thunks(&g, &arena);
66 rlx_cpu::thunk::execute_thunks(&schedule, arena.raw_buf_mut());
67
68 let n = out_shape.num_elements().unwrap_or(0);
69 arena.slice_mut(out)[..n].to_vec()
70}