1use crate::device::{VulkanDevice, vulkan_device};
16use crate::shaders;
17use ash::vk;
18use std::collections::HashMap;
19use std::sync::{Mutex, OnceLock};
20
/// Size, in bytes, of the single compute-stage push-constant range declared
/// in the pipeline layout that `Kernels::new` creates.
// NOTE(review): 128 equals the minimum `maxPushConstantsSize` the Vulkan spec
// guarantees; presumably chosen for that reason — confirm.
pub const PUSH_CONSTANT_BYTES: u32 = 128;
23
/// Shared Vulkan state for compute kernels: one descriptor-set layout, one
/// pipeline layout, and a by-name cache of compute pipelines.
pub struct Kernels {
    /// Device wrapper; its `device` field is used for every Vulkan call made here.
    dev: &'static VulkanDevice,
    /// Descriptor-set layout with one compute-stage storage-buffer binding (binding 0).
    pub dsl: vk::DescriptorSetLayout,
    /// Pipeline layout built from `dsl` plus one compute-stage push-constant range.
    pub pipeline_layout: vk::PipelineLayout,
    /// Compute pipelines already built by `pipeline`, keyed by kernel name.
    cache: Mutex<HashMap<&'static str, vk::Pipeline>>,
    /// Shader modules created while building pipelines.
    // NOTE(review): presumably retained so they can be destroyed at teardown;
    // no teardown code is visible in this part of the file — confirm.
    modules: Mutex<Vec<vk::ShaderModule>>,
}
31
// SAFETY: the fields of `Kernels` are a shared `&'static VulkanDevice`, two
// Vulkan handles that are written only at construction, and two collections
// that are only accessed through their `Mutex`.
// NOTE(review): these manual impls are presumably needed because
// `VulkanDevice` is not automatically `Sync`; soundness then depends on
// `VulkanDevice` being safe to use from several threads — confirm against its
// definition.
unsafe impl Send for Kernels {}
unsafe impl Sync for Kernels {}
34
/// Lazily initialised global `Kernels` slot, filled on the first call to
/// `kernels()`. It holds `None` when `vulkan_device()` returned `None` at that
/// time; a `OnceLock` is initialised at most once, so that outcome is kept.
static KERNELS: OnceLock<Option<Kernels>> = OnceLock::new();
36
37pub fn kernels() -> Option<&'static Kernels> {
39 KERNELS
40 .get_or_init(|| vulkan_device().map(Kernels::new))
41 .as_ref()
42}
43
44impl Kernels {
45 fn new(dev: &'static VulkanDevice) -> Self {
46 let bindings = [vk::DescriptorSetLayoutBinding::default()
47 .binding(0)
48 .descriptor_type(vk::DescriptorType::STORAGE_BUFFER)
49 .descriptor_count(1)
50 .stage_flags(vk::ShaderStageFlags::COMPUTE)];
51 let dsl = unsafe {
52 dev.device.create_descriptor_set_layout(
53 &vk::DescriptorSetLayoutCreateInfo::default().bindings(&bindings),
54 None,
55 )
56 }
57 .expect("vk descriptor_set_layout");
58
59 let set_layouts = [dsl];
60 let pc_ranges = [vk::PushConstantRange::default()
61 .stage_flags(vk::ShaderStageFlags::COMPUTE)
62 .offset(0)
63 .size(PUSH_CONSTANT_BYTES)];
64 let pipeline_layout = unsafe {
65 dev.device.create_pipeline_layout(
66 &vk::PipelineLayoutCreateInfo::default()
67 .set_layouts(&set_layouts)
68 .push_constant_ranges(&pc_ranges),
69 None,
70 )
71 }
72 .expect("vk pipeline_layout");
73
74 Self {
75 dev,
76 dsl,
77 pipeline_layout,
78 cache: Mutex::new(HashMap::new()),
79 modules: Mutex::new(Vec::new()),
80 }
81 }
82
83 pub fn pipeline(&self, name: &'static str) -> vk::Pipeline {
85 if let Some(p) = self.cache.lock().unwrap().get(name) {
86 return *p;
87 }
88 let blob = shaders::blob(name)
89 .unwrap_or_else(|| panic!("rlx-vulkan: no embedded SPIR-V for kernel '{name}'"));
90 let words = shaders::words(blob);
91 let module = unsafe {
92 self.dev
93 .device
94 .create_shader_module(&vk::ShaderModuleCreateInfo::default().code(&words), None)
95 }
96 .unwrap_or_else(|e| panic!("vk shader_module '{name}': {e}"));
97
98 let stage = vk::PipelineShaderStageCreateInfo::default()
99 .stage(vk::ShaderStageFlags::COMPUTE)
100 .module(module)
101 .name(c"main");
102 let create = vk::ComputePipelineCreateInfo::default()
103 .stage(stage)
104 .layout(self.pipeline_layout);
105 let pipeline = unsafe {
106 self.dev
107 .device
108 .create_compute_pipelines(vk::PipelineCache::null(), &[create], None)
109 }
110 .map_err(|(_, e)| e)
111 .unwrap_or_else(|e| panic!("vk compute_pipeline '{name}': {e}"))[0];
112
113 self.modules.lock().unwrap().push(module);
114 self.cache.lock().unwrap().insert(name, pipeline);
115 pipeline
116 }
117}