Skip to main content

rlx_oneapi/
kernels.rs

1// RLX — versatile ML compiler + runtime.
2// Copyright (C) 2026 Eugene Hauptmann, Nataliya Kosmyna.
3//
4// SPDX-License-Identifier: GPL-3.0-only
5
6//! Embedded SPIR-V kernel blobs (compiled from `kernels/*.cl` by `build.rs`)
7//! and their Level Zero module/kernel cache. `SPIRV_BLOBS` is empty unless the
8//! crate was built on an Intel oneAPI host with `ocloc` + `RLX_ONEAPI_BUILD_KERNELS=1`,
9//! in which case [`kernels`] builds one `ze_module` + `ze_kernel` per blob (the
10//! kernel name equals the `.cl` function name, which equals the file stem).
11
12include!(concat!(env!("OUT_DIR"), "/kernels_generated.rs"));
13
14use crate::device::{OneApiDevice, oneapi_device};
15use crate::level_zero::*;
16use std::collections::HashMap;
17use std::ffi::CString;
18use std::sync::OnceLock;
19
20/// SPIR-V byte blob for kernel `name`, if it was embedded this build.
21pub fn blob(name: &str) -> Option<&'static [u8]> {
22    SPIRV_BLOBS
23        .iter()
24        .find(|(n, _)| *n == name)
25        .map(|(_, b)| *b)
26}
27
28/// All embedded kernel names (deterministic; build.rs sorts).
29pub fn names() -> impl Iterator<Item = &'static str> {
30    SPIRV_BLOBS.iter().map(|(n, _)| *n)
31}
32
33/// Whether any native SPIR-V kernel was embedded (false on non-Intel builds).
34pub fn kernels_built() -> bool {
35    KERNELS_BUILT
36}
37
/// Built Level Zero modules + kernels for the embedded SPIR-V blobs.
///
/// Owns the raw handles it stores: `Kernels::build` creates them and the
/// `Drop` impl destroys them (kernels first, then modules).
pub struct Kernels {
    // One module handle per successfully built blob. Within this file it is
    // only read by `Drop`, which destroys each module.
    modules: Vec<ModuleHandle>,
    // Kernel handle keyed by kernel name (the name stored in `SPIRV_BLOBS`);
    // this is the table `Kernels::get` looks up.
    kernels: HashMap<&'static str, KernelHandle>,
}

// SAFETY: the handle types are raw pointers (`build` initialises them with
// `std::ptr::null_mut()`), so `Send`/`Sync` must be asserted by hand.
// Handles are process-global, built once, and dispatched behind the device's
// submit lock.
// NOTE(review): the submit lock lives outside this file — confirm every user of
// a handle returned by `Kernels::get` really goes through it.
unsafe impl Send for Kernels {}
unsafe impl Sync for Kernels {}
48
49impl Kernels {
50    /// Kernel handle for `name`, or `None` if that kernel wasn't embedded/built.
51    pub fn get(&self, name: &str) -> Option<KernelHandle> {
52        self.kernels.get(name).copied()
53    }
54
55    fn build(dev: &OneApiDevice) -> Result<Kernels, String> {
56        let mut modules = Vec::new();
57        let mut kernels = HashMap::new();
58        for (name, spirv) in SPIRV_BLOBS {
59            let desc = ModuleDesc {
60                stype: ZE_STRUCTURE_TYPE_MODULE_DESC,
61                pnext: std::ptr::null(),
62                format: ZE_MODULE_FORMAT_IL_SPIRV,
63                input_size: spirv.len(),
64                p_input_module: spirv.as_ptr(),
65                p_build_flags: std::ptr::null(),
66                p_constants: std::ptr::null(),
67            };
68            let mut module: ModuleHandle = std::ptr::null_mut();
69            let mut build_log: ModuleBuildLogHandle = std::ptr::null_mut();
70            unsafe {
71                check(
72                    (dev.lib.module_create)(
73                        dev.context,
74                        dev.device,
75                        &desc,
76                        &mut module,
77                        &mut build_log,
78                    ),
79                    &format!("zeModuleCreate({name})"),
80                )?;
81            }
82
83            let cname = CString::new(*name).map_err(|e| format!("kernel name {name}: {e}"))?;
84            let kdesc = KernelDesc {
85                stype: ZE_STRUCTURE_TYPE_KERNEL_DESC,
86                pnext: std::ptr::null(),
87                flags: 0,
88                p_kernel_name: cname.as_ptr(),
89            };
90            let mut kernel: KernelHandle = std::ptr::null_mut();
91            unsafe {
92                check(
93                    (dev.lib.kernel_create)(module, &kdesc, &mut kernel),
94                    &format!("zeKernelCreate({name})"),
95                )?;
96            }
97            modules.push(module);
98            kernels.insert(*name, kernel);
99        }
100        Ok(Kernels { modules, kernels })
101    }
102}
103
104impl Drop for Kernels {
105    fn drop(&mut self) {
106        if let Some(dev) = oneapi_device() {
107            unsafe {
108                for (_, &k) in self.kernels.iter() {
109                    let _ = (dev.lib.kernel_destroy)(k);
110                }
111                for &m in &self.modules {
112                    let _ = (dev.lib.module_destroy)(m);
113                }
114            }
115        }
116    }
117}
118
119/// Process-wide kernel cache, or `None` when there is no device or no embedded
120/// SPIR-V (the latter is the case on every non-Intel build host).
121pub fn kernels() -> Option<&'static Kernels> {
122    static CACHE: OnceLock<Option<Kernels>> = OnceLock::new();
123    CACHE
124        .get_or_init(|| {
125            let dev = oneapi_device()?;
126            if SPIRV_BLOBS.is_empty() {
127                return None;
128            }
129            match Kernels::build(dev) {
130                Ok(k) => Some(k),
131                Err(e) => {
132                    eprintln!("rlx-oneapi: kernel build failed ({e}); using CPU reference path");
133                    None
134                }
135            }
136        })
137        .as_ref()
138}