Skip to main content

rlx_vulkan/
kernels.rs

1// RLX — versatile ML compiler + runtime.
2// Copyright (C) 2026 Eugene Hauptmann, Nataliya Kosmyna.
3//
4// SPDX-License-Identifier: GPL-3.0-only
5
6//! Per-kernel compute-pipeline cache.
7//!
8//! Every kernel shares one descriptor-set layout (a single storage buffer =
9//! the arena, at binding 0) and one pipeline layout (that DSL + a 128-byte
10//! push-constant range, the Vulkan-guaranteed minimum). Per-op parameters
11//! (offsets, dims, selectors) travel entirely through push constants, so a
12//! single descriptor set bound to the arena serves every dispatch. Compute
13//! pipelines are compiled lazily from the embedded SPIR-V and cached.
14
15use crate::device::{VulkanDevice, vulkan_device};
16use crate::shaders;
17use ash::vk;
18use std::collections::HashMap;
19use std::sync::{Mutex, OnceLock};
20
/// Size, in bytes, of the single push-constant range shared by every kernel.
///
/// Expressed as 32 four-byte words. 128 bytes is the minimum
/// `maxPushConstantsSize` that Vulkan guarantees, so the range is valid on
/// any conforming device.
pub const PUSH_CONSTANT_BYTES: u32 = 32 * 4;
23
24pub struct Kernels {
25    dev: &'static VulkanDevice,
26    pub dsl: vk::DescriptorSetLayout,
27    pub pipeline_layout: vk::PipelineLayout,
28    cache: Mutex<HashMap<&'static str, vk::Pipeline>>,
29    modules: Mutex<Vec<vk::ShaderModule>>,
30}
31
32unsafe impl Send for Kernels {}
33unsafe impl Sync for Kernels {}
34
35static KERNELS: OnceLock<Option<Kernels>> = OnceLock::new();
36
37/// The process-wide kernel cache, or `None` if no device.
38pub fn kernels() -> Option<&'static Kernels> {
39    KERNELS
40        .get_or_init(|| vulkan_device().map(Kernels::new))
41        .as_ref()
42}
43
44impl Kernels {
45    fn new(dev: &'static VulkanDevice) -> Self {
46        let bindings = [vk::DescriptorSetLayoutBinding::default()
47            .binding(0)
48            .descriptor_type(vk::DescriptorType::STORAGE_BUFFER)
49            .descriptor_count(1)
50            .stage_flags(vk::ShaderStageFlags::COMPUTE)];
51        let dsl = unsafe {
52            dev.device.create_descriptor_set_layout(
53                &vk::DescriptorSetLayoutCreateInfo::default().bindings(&bindings),
54                None,
55            )
56        }
57        .expect("vk descriptor_set_layout");
58
59        let set_layouts = [dsl];
60        let pc_ranges = [vk::PushConstantRange::default()
61            .stage_flags(vk::ShaderStageFlags::COMPUTE)
62            .offset(0)
63            .size(PUSH_CONSTANT_BYTES)];
64        let pipeline_layout = unsafe {
65            dev.device.create_pipeline_layout(
66                &vk::PipelineLayoutCreateInfo::default()
67                    .set_layouts(&set_layouts)
68                    .push_constant_ranges(&pc_ranges),
69                None,
70            )
71        }
72        .expect("vk pipeline_layout");
73
74        Self {
75            dev,
76            dsl,
77            pipeline_layout,
78            cache: Mutex::new(HashMap::new()),
79            modules: Mutex::new(Vec::new()),
80        }
81    }
82
83    /// Get (compiling on first use) the compute pipeline for kernel `name`.
84    pub fn pipeline(&self, name: &'static str) -> vk::Pipeline {
85        if let Some(p) = self.cache.lock().unwrap().get(name) {
86            return *p;
87        }
88        let blob = shaders::blob(name)
89            .unwrap_or_else(|| panic!("rlx-vulkan: no embedded SPIR-V for kernel '{name}'"));
90        let words = shaders::words(blob);
91        let module = unsafe {
92            self.dev
93                .device
94                .create_shader_module(&vk::ShaderModuleCreateInfo::default().code(&words), None)
95        }
96        .unwrap_or_else(|e| panic!("vk shader_module '{name}': {e}"));
97
98        let stage = vk::PipelineShaderStageCreateInfo::default()
99            .stage(vk::ShaderStageFlags::COMPUTE)
100            .module(module)
101            .name(c"main");
102        let create = vk::ComputePipelineCreateInfo::default()
103            .stage(stage)
104            .layout(self.pipeline_layout);
105        let pipeline = unsafe {
106            self.dev
107                .device
108                .create_compute_pipelines(vk::PipelineCache::null(), &[create], None)
109        }
110        .map_err(|(_, e)| e)
111        .unwrap_or_else(|e| panic!("vk compute_pipeline '{name}': {e}"))[0];
112
113        self.modules.lock().unwrap().push(module);
114        self.cache.lock().unwrap().insert(name, pipeline);
115        pipeline
116    }
117}