Skip to main content

rlx_vulkan/
device.rs

1// RLX — versatile ML compiler + runtime.
2// Copyright (C) 2026 Eugene Hauptmann, Nataliya Kosmyna.
3//
4// SPDX-License-Identifier: GPL-3.0-only
5
6//! Vulkan instance / physical-device / logical-device / compute-queue
7//! singleton, brought up through `ash` with the dynamically-loaded Vulkan
8//! loader. If no loader / driver is present (`Entry::load()` fails) or no
9//! device exposes a compute queue, [`vulkan_device`] returns `None` and the
10//! whole backend reports itself unavailable — the crate still compiles and
11//! links on hosts without Vulkan (macOS without MoltenVK, CI).
12
13use ash::vk;
14use std::ffi::{CStr, c_char};
15use std::sync::{Mutex, OnceLock};
16
/// Owned Vulkan context. One per process.
///
/// Built once by `VulkanDevice::new` and handed out through [`vulkan_device`].
pub struct VulkanDevice {
    /// Vulkan loader entry points, loaded dynamically at runtime.
    pub entry: ash::Entry,
    /// Vulkan instance (requested API version 1.1).
    pub instance: ash::Instance,
    /// Chosen physical device: among devices that have a compute-capable queue
    /// family, the first of the best-ranked type (discrete, then integrated,
    /// then virtual, then any other type).
    pub physical: vk::PhysicalDevice,
    /// Logical device created on `physical`.
    pub device: ash::Device,
    /// Queue 0 of `queue_family`; the only queue this context submits to.
    pub queue: vk::Queue,
    /// Index of the first queue family on `physical` whose flags include COMPUTE.
    pub queue_family: u32,
    /// Memory heaps / types of `physical`, searched by `find_memory_type`.
    pub mem_props: vk::PhysicalDeviceMemoryProperties,
    /// Limits reported by `vkGetPhysicalDeviceProperties` for `physical`.
    pub limits: vk::PhysicalDeviceLimits,
    /// Device name (`deviceName` from the physical-device properties).
    pub name: String,
    /// True when the selected device is a portability driver (MoltenVK on
    /// Apple): `VK_KHR_portability_subset` was required and enabled. The matmul
    /// scheduler falls back to the scalar kernel here, since shared-memory
    /// tiling regresses under Vulkan→Metal translation (Metal is Apple's
    /// native path anyway).
    pub portability: bool,
    /// True when the device exposes a usable 16×16×16 f16·f16→f32 subgroup
    /// cooperative-matrix (tensor-core) config AND the features its kernel
    /// needs were enabled. Gates the `matmul_coop` fast path. Always false on
    /// portability drivers (MoltenVK doesn't expose `VK_KHR_cooperative_matrix`).
    pub coop_matmul: bool,
    /// Command pool on `queue_family`, created with `RESET_COMMAND_BUFFER`;
    /// every command buffer this context allocates comes from it.
    cmd_pool: vk::CommandPool,
    /// Vulkan queues require external synchronization; serialize submits.
    submit_lock: Mutex<()>,
}
43
// SAFETY: `VulkanDevice` is a process-wide singleton (stored in `DEVICE`) and
// exposes no `&mut self` methods, so its handles are never reassigned after
// construction. Queue submission is serialized through `submit_lock`.
// NOTE(review): Vulkan also requires external synchronization of the command
// pool (allocating / freeing command buffers, and recording into buffers
// allocated from it). For these impls to be sound, every path that touches
// `cmd_pool` must hold `submit_lock` — confirm this for all callers.
// This mirrors the singleton pattern used by rlx-cuda's `CudaContext`.
unsafe impl Send for VulkanDevice {}
unsafe impl Sync for VulkanDevice {}
49
/// Lazily-initialized singleton behind [`vulkan_device`]. Holds `None` when
/// bring-up failed; `OnceLock` caches that outcome, so a failed bring-up is
/// not retried on later calls.
static DEVICE: OnceLock<Option<VulkanDevice>> = OnceLock::new();
51
/// Point the Vulkan loader at a MoltenVK ICD manifest when the user has not
/// picked one, probing the standard Homebrew install prefixes.
///
/// Search order:
/// 1. do nothing if `VK_ICD_FILENAMES` or `VK_DRIVER_FILES` is already set;
/// 2. the fixed `share/vulkan/icd.d` manifests under `/opt/homebrew`, then
///    `/usr/local`;
/// 3. the first versioned `molten-vk` Cellar entry (in `read_dir` order) that
///    ships `etc/vulkan/icd.d/MoltenVK_icd.json`.
///
/// The loader reads `VK_ICD_FILENAMES` at `vkCreateInstance`, so exporting it
/// here, before the instance is built, is enough for `DEVICE=vulkan` to work.
#[cfg(all(target_vendor = "apple", not(target_os = "watchos")))]
fn ensure_macos_loader() {
    const FIXED_ICDS: [&str; 2] = [
        "/opt/homebrew/share/vulkan/icd.d/MoltenVK_icd.json",
        "/usr/local/share/vulkan/icd.d/MoltenVK_icd.json",
    ];
    const CELLARS: [&str; 2] = [
        "/opt/homebrew/Cellar/molten-vk",
        "/usr/local/Cellar/molten-vk",
    ];

    // Respect an explicit ICD choice from the user.
    let user_chose_icd = ["VK_ICD_FILENAMES", "VK_DRIVER_FILES"]
        .into_iter()
        .any(|var| std::env::var_os(var).is_some());
    if user_chose_icd {
        return;
    }

    // Fixed locations first; only scan the Cellars (lazily, one at a time)
    // when none of them exists.
    let from_fixed = FIXED_ICDS
        .into_iter()
        .map(std::path::PathBuf::from)
        .find(|candidate| candidate.exists());
    let manifest = from_fixed.or_else(|| {
        CELLARS
            .into_iter()
            .filter_map(|cellar| std::fs::read_dir(cellar).ok())
            .flat_map(|versions| versions.flatten())
            .map(|version| version.path().join("etc/vulkan/icd.d/MoltenVK_icd.json"))
            .find(|icd| icd.exists())
    });

    if let Some(manifest) = manifest {
        // SAFETY: `set_var` is only sound while no other thread reads or writes
        // the environment. This runs inside the `OnceLock` device init, before
        // the Vulkan instance is created.
        // NOTE(review): other threads may already exist when `vulkan_device()`
        // is first called — confirm it is invoked early enough.
        unsafe { std::env::set_var("VK_ICD_FILENAMES", manifest) };
    }
}
91
92/// The process-wide Vulkan device, or `None` when unavailable.
93pub fn vulkan_device() -> Option<&'static VulkanDevice> {
94    DEVICE.get_or_init(|| VulkanDevice::new().ok()).as_ref()
95}
96
97impl VulkanDevice {
98    fn new() -> Result<Self, String> {
99        // On macOS the Vulkan loader + MoltenVK ICD ship via Homebrew but land
100        // off the default dlopen / loader search paths, so a stock `DEVICE=vulkan`
101        // reported "unavailable" until the user exported VK_ICD_FILENAMES and
102        // DYLD_LIBRARY_PATH by hand. Wire both up automatically here.
103        #[cfg(all(target_vendor = "apple", not(target_os = "watchos")))]
104        ensure_macos_loader();
105
106        // SAFETY: dynamic-load the system Vulkan loader. If the default search
107        // misses it (Homebrew prefix), retry from known libvulkan locations.
108        // Returns Err on hosts with no loader, which we map to "unavailable".
109        let entry = unsafe { ash::Entry::load() }
110            .or_else(|orig| {
111                for lib in [
112                    "/opt/homebrew/lib/libvulkan.dylib",
113                    "/opt/homebrew/lib/libvulkan.1.dylib",
114                    "/usr/local/lib/libvulkan.dylib",
115                ] {
116                    if std::path::Path::new(lib).exists() {
117                        if let Ok(e) = unsafe { ash::Entry::load_from(lib) } {
118                            return Ok(e);
119                        }
120                    }
121                }
122                Err(orig)
123            })
124            .map_err(|e| format!("vk load: {e}"))?;
125
126        let app_name = c"rlx-vulkan";
127        let app_info = vk::ApplicationInfo::default()
128            .application_name(app_name)
129            .engine_name(app_name)
130            .api_version(vk::make_api_version(0, 1, 1, 0));
131
132        // On Apple platforms the only ICD is MoltenVK, a portability driver:
133        // it must be opted into via the portability-enumeration extension.
134        let mut inst_ext: Vec<*const c_char> = Vec::new();
135        let mut inst_flags = vk::InstanceCreateFlags::empty();
136        #[cfg(all(target_vendor = "apple", not(target_os = "watchos")))]
137        {
138            inst_ext.push(ash::khr::portability_enumeration::NAME.as_ptr());
139            inst_ext.push(ash::khr::get_physical_device_properties2::NAME.as_ptr());
140            inst_flags |= vk::InstanceCreateFlags::ENUMERATE_PORTABILITY_KHR;
141        }
142
143        let create_info = vk::InstanceCreateInfo::default()
144            .application_info(&app_info)
145            .enabled_extension_names(&inst_ext)
146            .flags(inst_flags);
147        let instance = unsafe { entry.create_instance(&create_info, None) }
148            .map_err(|e| format!("vk instance: {e}"))?;
149
150        // Pick the best physical device exposing a compute queue.
151        let physical_devices = unsafe { instance.enumerate_physical_devices() }
152            .map_err(|e| format!("vk enumerate: {e}"))?;
153        let mut best: Option<(vk::PhysicalDevice, u32, i32)> = None;
154        for &pd in &physical_devices {
155            let props = unsafe { instance.get_physical_device_properties(pd) };
156            let qfams = unsafe { instance.get_physical_device_queue_family_properties(pd) };
157            let Some(qf) = qfams
158                .iter()
159                .position(|q| q.queue_flags.contains(vk::QueueFlags::COMPUTE) && q.queue_count > 0)
160            else {
161                continue;
162            };
163            let score = match props.device_type {
164                vk::PhysicalDeviceType::DISCRETE_GPU => 3,
165                vk::PhysicalDeviceType::INTEGRATED_GPU => 2,
166                vk::PhysicalDeviceType::VIRTUAL_GPU => 1,
167                _ => 0,
168            };
169            if best.map(|(_, _, s)| score > s).unwrap_or(true) {
170                best = Some((pd, qf as u32, score));
171            }
172        }
173        let (physical, queue_family, _) = best.ok_or_else(|| {
174            unsafe { instance.destroy_instance(None) };
175            "no Vulkan device with a compute queue".to_string()
176        })?;
177
178        let props = unsafe { instance.get_physical_device_properties(physical) };
179        let name = unsafe { CStr::from_ptr(props.device_name.as_ptr()) }
180            .to_string_lossy()
181            .into_owned();
182
183        // Enable VK_KHR_portability_subset on the logical device when the
184        // physical device requires it (MoltenVK), else creation fails.
185        let dev_exts =
186            unsafe { instance.enumerate_device_extension_properties(physical) }.unwrap_or_default();
187        let mut dev_ext: Vec<*const c_char> = Vec::new();
188        let portability_name = c"VK_KHR_portability_subset";
189        let mut is_portability = false;
190        for e in &dev_exts {
191            let n = unsafe { CStr::from_ptr(e.extension_name.as_ptr()) };
192            if n == portability_name {
193                dev_ext.push(portability_name.as_ptr());
194                is_portability = true;
195            }
196        }
197
198        // Cooperative-matrix (tensor-core) matmul fast path. Native drivers
199        // only (MoltenVK never exposes it). Needs the coop-matrix extension
200        // plus the f16 / memory-model features the `matmul_coop` kernel uses,
201        // and a supported 16×16×16 f16·f16→f32 subgroup config.
202        let has_ext = |want: &CStr| {
203            dev_exts
204                .iter()
205                .any(|e| unsafe { CStr::from_ptr(e.extension_name.as_ptr()) } == want)
206        };
207        let coop_ext = c"VK_KHR_cooperative_matrix";
208        let memmodel_ext = c"VK_KHR_vulkan_memory_model";
209        let f16_ext = c"VK_KHR_shader_float16_int8";
210        let s16_ext = c"VK_KHR_16bit_storage";
211        let mut coop_matmul = false;
212        if !is_portability
213            && has_ext(coop_ext)
214            && has_ext(memmodel_ext)
215            && has_ext(f16_ext)
216            && has_ext(s16_ext)
217        {
218            let mut coop_feat = vk::PhysicalDeviceCooperativeMatrixFeaturesKHR::default();
219            let mut probe = vk::PhysicalDeviceFeatures2::default().push_next(&mut coop_feat);
220            unsafe { instance.get_physical_device_features2(physical, &mut probe) };
221            if coop_feat.cooperative_matrix != 0 {
222                let ci = ash::khr::cooperative_matrix::Instance::new(&entry, &instance);
223                let configs =
224                    unsafe { ci.get_physical_device_cooperative_matrix_properties(physical) }
225                        .unwrap_or_default();
226                coop_matmul = configs.iter().any(|c| {
227                    c.m_size == 16
228                        && c.n_size == 16
229                        && c.k_size == 16
230                        && c.a_type == vk::ComponentTypeKHR::FLOAT16
231                        && c.b_type == vk::ComponentTypeKHR::FLOAT16
232                        && c.result_type == vk::ComponentTypeKHR::FLOAT32
233                        && c.scope == vk::ScopeKHR::SUBGROUP
234                });
235            }
236        }
237        if coop_matmul {
238            dev_ext.push(coop_ext.as_ptr());
239            dev_ext.push(memmodel_ext.as_ptr());
240            dev_ext.push(f16_ext.as_ptr());
241            dev_ext.push(s16_ext.as_ptr());
242        }
243        if std::env::var_os("RLX_VULKAN_DEBUG").is_some() {
244            eprintln!(
245                "[rlx-vulkan] device={name:?} portability={is_portability} coop_matmul={coop_matmul}"
246            );
247        }
248
249        let priorities = [1.0f32];
250        let queue_infos = [vk::DeviceQueueCreateInfo::default()
251            .queue_family_index(queue_family)
252            .queue_priorities(&priorities)];
253        let base_features = vk::PhysicalDeviceFeatures::default();
254        let device = if coop_matmul {
255            // Chain the feature structs the coop kernel requires via pNext.
256            let mut coop_f =
257                vk::PhysicalDeviceCooperativeMatrixFeaturesKHR::default().cooperative_matrix(true);
258            let mut mm_f =
259                vk::PhysicalDeviceVulkanMemoryModelFeatures::default().vulkan_memory_model(true);
260            let mut f16_f =
261                vk::PhysicalDeviceShaderFloat16Int8Features::default().shader_float16(true);
262            let mut s16_f =
263                vk::PhysicalDevice16BitStorageFeatures::default().storage_buffer16_bit_access(true);
264            let mut feats2 = vk::PhysicalDeviceFeatures2::default()
265                .features(base_features)
266                .push_next(&mut coop_f)
267                .push_next(&mut mm_f)
268                .push_next(&mut f16_f)
269                .push_next(&mut s16_f);
270            let dci = vk::DeviceCreateInfo::default()
271                .queue_create_infos(&queue_infos)
272                .enabled_extension_names(&dev_ext)
273                .push_next(&mut feats2);
274            unsafe { instance.create_device(physical, &dci, None) }
275        } else {
276            let dci = vk::DeviceCreateInfo::default()
277                .queue_create_infos(&queue_infos)
278                .enabled_extension_names(&dev_ext)
279                .enabled_features(&base_features);
280            unsafe { instance.create_device(physical, &dci, None) }
281        }
282        .map_err(|e| format!("vk device: {e}"))?;
283        let queue = unsafe { device.get_device_queue(queue_family, 0) };
284
285        let mem_props = unsafe { instance.get_physical_device_memory_properties(physical) };
286
287        let cmd_pool = unsafe {
288            device.create_command_pool(
289                &vk::CommandPoolCreateInfo::default()
290                    .queue_family_index(queue_family)
291                    .flags(vk::CommandPoolCreateFlags::RESET_COMMAND_BUFFER),
292                None,
293            )
294        }
295        .map_err(|e| format!("vk cmd pool: {e}"))?;
296
297        Ok(Self {
298            entry,
299            instance,
300            physical,
301            device,
302            queue,
303            queue_family,
304            mem_props,
305            limits: props.limits,
306            name,
307            portability: is_portability,
308            coop_matmul,
309            cmd_pool,
310            submit_lock: Mutex::new(()),
311        })
312    }
313
314    /// Find a memory type index satisfying `type_bits` and `flags`.
315    pub fn find_memory_type(&self, type_bits: u32, flags: vk::MemoryPropertyFlags) -> Option<u32> {
316        let mp = &self.mem_props;
317        (0..mp.memory_type_count).find(|&i| {
318            (type_bits & (1 << i)) != 0
319                && mp.memory_types[i as usize].property_flags.contains(flags)
320        })
321    }
322
323    /// Record `record` into a one-shot primary command buffer, submit it to
324    /// the compute queue, and block until the GPU finishes. Serialized
325    /// across threads via `submit_lock` (queue submission needs external
326    /// synchronization).
327    pub fn submit_and_wait<F: FnOnce(vk::CommandBuffer)>(&self, record: F) {
328        let _guard = self.submit_lock.lock().unwrap();
329        let dev = &self.device;
330        unsafe {
331            let cmd = dev
332                .allocate_command_buffers(
333                    &vk::CommandBufferAllocateInfo::default()
334                        .command_pool(self.cmd_pool)
335                        .level(vk::CommandBufferLevel::PRIMARY)
336                        .command_buffer_count(1),
337                )
338                .expect("vk alloc cmd buffer")[0];
339
340            dev.begin_command_buffer(
341                cmd,
342                &vk::CommandBufferBeginInfo::default()
343                    .flags(vk::CommandBufferUsageFlags::ONE_TIME_SUBMIT),
344            )
345            .expect("vk begin cmd");
346
347            record(cmd);
348
349            dev.end_command_buffer(cmd).expect("vk end cmd");
350
351            let fence = dev
352                .create_fence(&vk::FenceCreateInfo::default(), None)
353                .expect("vk fence");
354            let cmds = [cmd];
355            let submit = vk::SubmitInfo::default().command_buffers(&cmds);
356            dev.queue_submit(self.queue, &[submit], fence)
357                .expect("vk submit");
358            dev.wait_for_fences(&[fence], true, u64::MAX)
359                .expect("vk wait");
360            dev.destroy_fence(fence, None);
361            dev.free_command_buffers(self.cmd_pool, &cmds);
362        }
363    }
364
365    /// Allocate one reusable primary command buffer from the shared pool. The
366    /// caller records it once and re-submits it many times via
367    /// [`submit_recorded_wait`] (the schedule is static across runs — inputs
368    /// flow through the host-visible arena, not the command stream — so a single
369    /// recording is valid for every step). Free it with [`free_cmds`].
370    pub fn alloc_primary_cmd(&self) -> vk::CommandBuffer {
371        unsafe {
372            self.device
373                .allocate_command_buffers(
374                    &vk::CommandBufferAllocateInfo::default()
375                        .command_pool(self.cmd_pool)
376                        .level(vk::CommandBufferLevel::PRIMARY)
377                        .command_buffer_count(1),
378                )
379                .expect("vk alloc cmd buffer")[0]
380        }
381    }
382
383    /// Free command buffers allocated from the shared pool.
384    pub fn free_cmds(&self, cmds: &[vk::CommandBuffer]) {
385        unsafe {
386            self.device.free_command_buffers(self.cmd_pool, cmds);
387        }
388    }
389
390    /// Create one unsignaled fence, reused across submits (reset after each
391    /// wait). Avoids the per-step create/destroy of [`submit_and_wait`].
392    pub fn create_reusable_fence(&self) -> vk::Fence {
393        unsafe {
394            self.device
395                .create_fence(&vk::FenceCreateInfo::default(), None)
396                .expect("vk fence")
397        }
398    }
399
400    /// Destroy a fence created by [`create_reusable_fence`].
401    pub fn destroy_fence(&self, fence: vk::Fence) {
402        unsafe {
403            self.device.destroy_fence(fence, None);
404        }
405    }
406
407    /// Submit an already-recorded command buffer and block until the GPU
408    /// finishes, then reset `fence` for reuse. The command buffer must have been
409    /// recorded WITHOUT `ONE_TIME_SUBMIT` so it can be resubmitted; since we
410    /// wait here it is never pending at the next submit. Serialized via
411    /// `submit_lock` (queue submission needs external synchronization).
412    pub fn submit_recorded_wait(&self, cmd: vk::CommandBuffer, fence: vk::Fence) {
413        let _guard = self.submit_lock.lock().unwrap();
414        let dev = &self.device;
415        unsafe {
416            let cmds = [cmd];
417            let submit = vk::SubmitInfo::default().command_buffers(&cmds);
418            dev.queue_submit(self.queue, &[submit], fence)
419                .expect("vk submit");
420            dev.wait_for_fences(&[fence], true, u64::MAX)
421                .expect("vk wait");
422            dev.reset_fences(&[fence]).expect("vk reset fence");
423        }
424    }
425}