rlx-vulkan 0.2.10

Native Vulkan compute backend for RLX (raw `ash` + embedded SPIR-V compute kernels)
Documentation
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
// RLX — versatile ML compiler + runtime.
// Copyright (C) 2026 Eugene Hauptmann, Nataliya Kosmyna.
//
// SPDX-License-Identifier: GPL-3.0-only

//! Vulkan instance / physical-device / logical-device / compute-queue
//! singleton, brought up through `ash` with the dynamically-loaded Vulkan
//! loader. If no loader can be loaded (`Entry::load()` and the known fallback
//! library paths all fail), a later bring-up step fails, or no device exposes
//! a compute queue, [`vulkan_device`] returns `None` and the whole backend
//! reports itself unavailable — the crate still compiles and links on hosts
//! without Vulkan (macOS without MoltenVK, CI).

use ash::vk;
use std::ffi::{CStr, c_char};
use std::sync::{Mutex, OnceLock};

/// Owned Vulkan context. One per process.
///
/// Built lazily by [`vulkan_device`] and cached in a process-wide `OnceLock`
/// static. Statics are never dropped, so the Vulkan handles below live for the
/// rest of the process.
pub struct VulkanDevice {
    /// Loaded Vulkan loader (`ash::Entry::load`, or `load_from` a known
    /// libvulkan path as a fallback).
    pub entry: ash::Entry,
    /// Instance created with API version 1.1 (plus portability enumeration on
    /// Apple platforms).
    pub instance: ash::Instance,
    /// Selected physical device. This is the highest-scoring device that has
    /// a compute-capable queue family (discrete > integrated > virtual >
    /// other; ties keep the first enumerated).
    pub physical: vk::PhysicalDevice,
    /// Logical device created on `physical`.
    pub device: ash::Device,
    /// Queue 0 of `queue_family`; every submit goes through `submit_lock`.
    pub queue: vk::Queue,
    /// Index of the first queue family on `physical` that advertises COMPUTE.
    pub queue_family: u32,
    /// Memory heaps/types of `physical`, searched by `find_memory_type`.
    pub mem_props: vk::PhysicalDeviceMemoryProperties,
    /// Limits reported in `physical`'s `VkPhysicalDeviceProperties`.
    pub limits: vk::PhysicalDeviceLimits,
    /// Driver-reported device name (lossily decoded from the C string).
    pub name: String,
    /// True when the selected device is a portability driver (MoltenVK on
    /// Apple): `VK_KHR_portability_subset` was required and enabled. The matmul
    /// scheduler falls back to the scalar kernel here, since shared-memory
    /// tiling regresses under Vulkan→Metal translation (Metal is Apple's
    /// native path anyway).
    pub portability: bool,
    /// True when the device exposes a usable 16×16×16 f16·f16→f32 subgroup
    /// cooperative-matrix (tensor-core) config AND the features its kernel
    /// needs were enabled. Gates the `matmul_coop` fast path. Always false on
    /// portability drivers (MoltenVK doesn't expose `VK_KHR_cooperative_matrix`).
    pub coop_matmul: bool,
    /// Command pool for `queue_family`, created with `RESET_COMMAND_BUFFER`.
    /// All command buffers handed out by this type are allocated from it.
    cmd_pool: vk::CommandPool,
    /// Vulkan queues require external synchronization; serialize submits.
    /// `submit_and_wait` also holds it while allocating and recording its
    /// one-shot command buffer from `cmd_pool`.
    submit_lock: Mutex<()>,
}

// SAFETY: `Send`/`Sync` are asserted by hand for the process-global Vulkan
// context. Most handles are immutable after construction. The two pieces of
// state Vulkan requires to be externally synchronized are the queue and the
// command pool:
// - queue submission goes through `submit_lock` (`submit_and_wait`,
//   `submit_recorded_wait`);
// - users of command buffers from `cmd_pool` (allocating, freeing, and
//   recording them) must not touch the pool from two threads at once.
// This mirrors the singleton pattern used by rlx-cuda's `CudaContext`.
unsafe impl Send for VulkanDevice {}
unsafe impl Sync for VulkanDevice {}

// Lazily-initialized process-wide device. A stored `None` records a failed
// bring-up; `OnceLock` initializes only once, so it is never retried.
static DEVICE: OnceLock<Option<VulkanDevice>> = OnceLock::new();

/// Point the Vulkan loader at the MoltenVK ICD when the user hasn't, by probing
/// the standard Homebrew / system install locations. The loader reads
/// `VK_ICD_FILENAMES` at `vkCreateInstance`, so setting it here (before the
/// instance is built) is enough for `DEVICE=vulkan` to work out of the box.
///
/// Does nothing if `VK_ICD_FILENAMES` or `VK_DRIVER_FILES` is already set.
/// Otherwise the probe order is:
/// 1. the shared `share/vulkan/icd.d` manifests;
/// 2. Homebrew's `opt/molten-vk` symlink, which tracks the currently linked keg;
/// 3. every versioned keg under `Cellar/molten-vk`, newest version first.
///
/// If nothing is found, the environment is left untouched.
#[cfg(all(target_vendor = "apple", not(target_os = "watchos")))]
fn ensure_macos_loader() {
    use std::path::{Path, PathBuf};

    /// Location of the ICD manifest inside a MoltenVK keg.
    const ICD_IN_KEG: &str = "etc/vulkan/icd.d/MoltenVK_icd.json";

    /// Numeric components of a keg directory name (`1.2.10_1` → `[1, 2, 10, 1]`).
    /// Versions then compare numerically: `1.2.10` sorts above `1.2.9`.
    fn version_key(keg: &Path) -> Vec<u64> {
        keg.file_name()
            .and_then(|n| n.to_str())
            .unwrap_or("")
            .split(|c: char| !c.is_ascii_digit())
            .filter(|part| !part.is_empty())
            .map(|part| part.parse().unwrap_or(0))
            .collect()
    }

    /// Export `icd` as the loader's ICD manifest.
    fn use_icd(icd: &Path) {
        // SAFETY: runs once during the `OnceLock` device init, before the
        // Vulkan instance (and any benchmarking threads) are created.
        unsafe { std::env::set_var("VK_ICD_FILENAMES", icd) };
    }

    if std::env::var_os("VK_ICD_FILENAMES").is_some()
        || std::env::var_os("VK_DRIVER_FILES").is_some()
    {
        return; // user already chose an ICD
    }

    // Fixed locations first: shared manifests, then the `opt` symlink that
    // Homebrew keeps pointed at the linked (current) keg.
    let fixed = [
        PathBuf::from("/opt/homebrew/share/vulkan/icd.d/MoltenVK_icd.json"),
        PathBuf::from("/usr/local/share/vulkan/icd.d/MoltenVK_icd.json"),
        Path::new("/opt/homebrew/opt/molten-vk").join(ICD_IN_KEG),
        Path::new("/usr/local/opt/molten-vk").join(ICD_IN_KEG),
    ];
    if let Some(icd) = fixed.iter().find(|p| p.exists()) {
        use_icd(icd);
        return;
    }

    // Fall back to the versioned Cellar kegs. `read_dir` order is unspecified,
    // so sort newest-first to make the choice deterministic when several
    // versions are installed side by side.
    for cellar in ["/opt/homebrew/Cellar/molten-vk", "/usr/local/Cellar/molten-vk"] {
        let Ok(rd) = std::fs::read_dir(cellar) else {
            continue;
        };
        let mut kegs: Vec<PathBuf> = rd.flatten().map(|ent| ent.path()).collect();
        kegs.sort_by_cached_key(|keg| std::cmp::Reverse(version_key(keg)));
        if let Some(icd) = kegs
            .iter()
            .map(|keg| keg.join(ICD_IN_KEG))
            .find(|icd| icd.exists())
        {
            use_icd(&icd);
            return;
        }
    }
}

/// The process-wide Vulkan device, or `None` when unavailable.
///
/// The first call performs the full bring-up (`VulkanDevice::new`). The outcome,
/// success or failure, is cached, so later calls are cheap and a failed
/// bring-up is not retried. Set `RLX_VULKAN_DEBUG` to have the reason the
/// backend is unavailable printed to stderr.
pub fn vulkan_device() -> Option<&'static VulkanDevice> {
    DEVICE
        .get_or_init(|| match VulkanDevice::new() {
            Ok(dev) => Some(dev),
            Err(reason) => {
                // The reason is otherwise lost once we cache `None`; surface it
                // on request so "unavailable" is diagnosable.
                if std::env::var_os("RLX_VULKAN_DEBUG").is_some() {
                    eprintln!("[rlx-vulkan] unavailable: {reason}");
                }
                None
            }
        })
        .as_ref()
}

impl VulkanDevice {
    /// Bring up the process-wide Vulkan context.
    ///
    /// Steps:
    /// 1. Load the Vulkan loader.
    /// 2. Create a Vulkan 1.1 instance (with portability enumeration on Apple).
    /// 3. Pick the highest-ranked physical device that has a compute queue
    ///    family.
    /// 4. Probe the cooperative-matrix matmul fast path.
    /// 5. Create the logical device, fetch queue 0 of the chosen family, and
    ///    create a resettable command pool.
    ///
    /// # Errors
    /// Returns a short description of the first step that failed. The
    /// instance and logical device, if already created, are destroyed before
    /// returning, so a failed bring-up does not leak them.
    fn new() -> Result<Self, String> {
        // On macOS the Vulkan loader + MoltenVK ICD ship via Homebrew but land
        // off the default dlopen / loader search paths, so a stock `DEVICE=vulkan`
        // reported "unavailable" until the user exported VK_ICD_FILENAMES and
        // DYLD_LIBRARY_PATH by hand. `ensure_macos_loader` points the loader at
        // the ICD; the `load_from` fallback below stands in for DYLD_LIBRARY_PATH.
        #[cfg(all(target_vendor = "apple", not(target_os = "watchos")))]
        ensure_macos_loader();

        // SAFETY: dynamic-load the system Vulkan loader. If the default search
        // misses it (Homebrew prefix), retry from known libvulkan locations.
        // Returns Err on hosts with no loader, which we map to "unavailable".
        let entry = unsafe { ash::Entry::load() }
            .or_else(|orig| {
                for lib in [
                    "/opt/homebrew/lib/libvulkan.dylib",
                    "/opt/homebrew/lib/libvulkan.1.dylib",
                    "/usr/local/lib/libvulkan.dylib",
                ] {
                    if std::path::Path::new(lib).exists() {
                        if let Ok(e) = unsafe { ash::Entry::load_from(lib) } {
                            return Ok(e);
                        }
                    }
                }
                Err(orig)
            })
            .map_err(|e| format!("vk load: {e}"))?;

        let app_name = c"rlx-vulkan";
        let app_info = vk::ApplicationInfo::default()
            .application_name(app_name)
            .engine_name(app_name)
            .api_version(vk::make_api_version(0, 1, 1, 0));

        // On Apple platforms the only ICD is MoltenVK, a portability driver:
        // it must be opted into via the portability-enumeration extension.
        let mut inst_ext: Vec<*const c_char> = Vec::new();
        let mut inst_flags = vk::InstanceCreateFlags::empty();
        #[cfg(all(target_vendor = "apple", not(target_os = "watchos")))]
        {
            inst_ext.push(ash::khr::portability_enumeration::NAME.as_ptr());
            inst_ext.push(ash::khr::get_physical_device_properties2::NAME.as_ptr());
            inst_flags |= vk::InstanceCreateFlags::ENUMERATE_PORTABILITY_KHR;
        }

        let create_info = vk::InstanceCreateInfo::default()
            .application_info(&app_info)
            .enabled_extension_names(&inst_ext)
            .flags(inst_flags);
        let instance = unsafe { entry.create_instance(&create_info, None) }
            .map_err(|e| format!("vk instance: {e}"))?;

        // From here on every error path must destroy the instance; otherwise a
        // failed bring-up leaks it.
        let fail = |msg: String| -> String {
            // SAFETY: called only on error paths, after any child objects were
            // destroyed and before `instance` is moved into `Self`.
            unsafe { instance.destroy_instance(None) };
            msg
        };

        // Pick the best physical device exposing a compute queue. Ranking:
        // discrete > integrated > virtual > other; ties keep the earlier device.
        let physical_devices = unsafe { instance.enumerate_physical_devices() }
            .map_err(|e| fail(format!("vk enumerate: {e}")))?;
        let mut best: Option<(vk::PhysicalDevice, u32, i32)> = None;
        for &pd in &physical_devices {
            let props = unsafe { instance.get_physical_device_properties(pd) };
            let qfams = unsafe { instance.get_physical_device_queue_family_properties(pd) };
            let Some(qf) = qfams
                .iter()
                .position(|q| q.queue_flags.contains(vk::QueueFlags::COMPUTE) && q.queue_count > 0)
            else {
                continue;
            };
            let score = match props.device_type {
                vk::PhysicalDeviceType::DISCRETE_GPU => 3,
                vk::PhysicalDeviceType::INTEGRATED_GPU => 2,
                vk::PhysicalDeviceType::VIRTUAL_GPU => 1,
                _ => 0,
            };
            if best.map(|(_, _, s)| score > s).unwrap_or(true) {
                best = Some((pd, qf as u32, score));
            }
        }
        let (physical, queue_family, _) =
            best.ok_or_else(|| fail("no Vulkan device with a compute queue".to_string()))?;

        let props = unsafe { instance.get_physical_device_properties(physical) };
        let name = unsafe { CStr::from_ptr(props.device_name.as_ptr()) }
            .to_string_lossy()
            .into_owned();

        let dev_exts =
            unsafe { instance.enumerate_device_extension_properties(physical) }.unwrap_or_default();
        let has_ext = |want: &CStr| {
            dev_exts
                .iter()
                .any(|e| unsafe { CStr::from_ptr(e.extension_name.as_ptr()) } == want)
        };

        // Enable VK_KHR_portability_subset on the logical device when the
        // physical device requires it (MoltenVK), else creation fails.
        let portability_name = c"VK_KHR_portability_subset";
        let is_portability = has_ext(portability_name);
        let mut dev_ext: Vec<*const c_char> = Vec::new();
        if is_portability {
            dev_ext.push(portability_name.as_ptr());
        }

        // Cooperative-matrix (tensor-core) matmul fast path. Native drivers
        // only (MoltenVK never exposes it). Needs the coop-matrix extension
        // plus the f16 / memory-model features the `matmul_coop` kernel uses,
        // and a supported 16×16×16 f16·f16→f32 subgroup config.
        let coop_ext = c"VK_KHR_cooperative_matrix";
        let memmodel_ext = c"VK_KHR_vulkan_memory_model";
        let f16_ext = c"VK_KHR_shader_float16_int8";
        let s16_ext = c"VK_KHR_16bit_storage";
        let mut coop_matmul = false;
        if !is_portability
            && has_ext(coop_ext)
            && has_ext(memmodel_ext)
            && has_ext(f16_ext)
            && has_ext(s16_ext)
        {
            // Probe every feature the coop path enables at device creation,
            // not just `cooperativeMatrix`. Requesting an unsupported feature
            // makes `vkCreateDevice` fail, which would take the whole backend
            // down instead of just disabling the fast path.
            let mut coop_feat = vk::PhysicalDeviceCooperativeMatrixFeaturesKHR::default();
            let mut mm_feat = vk::PhysicalDeviceVulkanMemoryModelFeatures::default();
            let mut f16_feat = vk::PhysicalDeviceShaderFloat16Int8Features::default();
            let mut s16_feat = vk::PhysicalDevice16BitStorageFeatures::default();
            let mut probe = vk::PhysicalDeviceFeatures2::default()
                .push_next(&mut coop_feat)
                .push_next(&mut mm_feat)
                .push_next(&mut f16_feat)
                .push_next(&mut s16_feat);
            unsafe { instance.get_physical_device_features2(physical, &mut probe) };
            let features_ok = coop_feat.cooperative_matrix != 0
                && mm_feat.vulkan_memory_model != 0
                && f16_feat.shader_float16 != 0
                && s16_feat.storage_buffer16_bit_access != 0;
            if features_ok {
                let ci = ash::khr::cooperative_matrix::Instance::new(&entry, &instance);
                let configs =
                    unsafe { ci.get_physical_device_cooperative_matrix_properties(physical) }
                        .unwrap_or_default();
                coop_matmul = configs.iter().any(|c| {
                    c.m_size == 16
                        && c.n_size == 16
                        && c.k_size == 16
                        && c.a_type == vk::ComponentTypeKHR::FLOAT16
                        && c.b_type == vk::ComponentTypeKHR::FLOAT16
                        && c.result_type == vk::ComponentTypeKHR::FLOAT32
                        && c.scope == vk::ScopeKHR::SUBGROUP
                });
            }
        }
        if coop_matmul {
            dev_ext.push(coop_ext.as_ptr());
            dev_ext.push(memmodel_ext.as_ptr());
            dev_ext.push(f16_ext.as_ptr());
            dev_ext.push(s16_ext.as_ptr());
        }
        if std::env::var_os("RLX_VULKAN_DEBUG").is_some() {
            eprintln!(
                "[rlx-vulkan] device={name:?} portability={is_portability} coop_matmul={coop_matmul}"
            );
        }

        let priorities = [1.0f32];
        let queue_infos = [vk::DeviceQueueCreateInfo::default()
            .queue_family_index(queue_family)
            .queue_priorities(&priorities)];
        let base_features = vk::PhysicalDeviceFeatures::default();
        let device = if coop_matmul {
            // Chain the feature structs the coop kernel requires via pNext.
            let mut coop_f =
                vk::PhysicalDeviceCooperativeMatrixFeaturesKHR::default().cooperative_matrix(true);
            let mut mm_f =
                vk::PhysicalDeviceVulkanMemoryModelFeatures::default().vulkan_memory_model(true);
            let mut f16_f =
                vk::PhysicalDeviceShaderFloat16Int8Features::default().shader_float16(true);
            let mut s16_f =
                vk::PhysicalDevice16BitStorageFeatures::default().storage_buffer16_bit_access(true);
            let mut feats2 = vk::PhysicalDeviceFeatures2::default()
                .features(base_features)
                .push_next(&mut coop_f)
                .push_next(&mut mm_f)
                .push_next(&mut f16_f)
                .push_next(&mut s16_f);
            let dci = vk::DeviceCreateInfo::default()
                .queue_create_infos(&queue_infos)
                .enabled_extension_names(&dev_ext)
                .push_next(&mut feats2);
            unsafe { instance.create_device(physical, &dci, None) }
        } else {
            let dci = vk::DeviceCreateInfo::default()
                .queue_create_infos(&queue_infos)
                .enabled_extension_names(&dev_ext)
                .enabled_features(&base_features);
            unsafe { instance.create_device(physical, &dci, None) }
        }
        .map_err(|e| fail(format!("vk device: {e}")))?;
        let queue = unsafe { device.get_device_queue(queue_family, 0) };

        let mem_props = unsafe { instance.get_physical_device_memory_properties(physical) };

        let cmd_pool = unsafe {
            device.create_command_pool(
                &vk::CommandPoolCreateInfo::default()
                    .queue_family_index(queue_family)
                    .flags(vk::CommandPoolCreateFlags::RESET_COMMAND_BUFFER),
                None,
            )
        }
        .map_err(|e| {
            // SAFETY: no child objects of `device` exist yet (the pool failed),
            // so it can be destroyed before its parent instance.
            unsafe { device.destroy_device(None) };
            fail(format!("vk cmd pool: {e}"))
        })?;

        Ok(Self {
            entry,
            instance,
            physical,
            device,
            queue,
            queue_family,
            mem_props,
            limits: props.limits,
            name,
            portability: is_portability,
            coop_matmul,
            cmd_pool,
            submit_lock: Mutex::new(()),
        })
    }

    /// Acquire `submit_lock`.
    ///
    /// The lock serializes host access to `queue` and `cmd_pool`, both of
    /// which Vulkan requires to be externally synchronized. It is not
    /// reentrant: code already holding it, such as a `submit_and_wait`
    /// recording closure, must not call a method that takes it again.
    ///
    /// The mutex guards `()`, so poisoning protects no data. Poisoning here
    /// means a panic while the lock was held, e.g. inside a `record` closure.
    /// The guard is recovered instead of making every later submit panic.
    fn lock_submit(&self) -> std::sync::MutexGuard<'_, ()> {
        self.submit_lock
            .lock()
            .unwrap_or_else(std::sync::PoisonError::into_inner)
    }

    /// Find a memory type index satisfying `type_bits` and `flags`.
    ///
    /// Returns the lowest index `i` such that bit `i` of `type_bits` is set and
    /// memory type `i` has every flag in `flags`, or `None` if there is none.
    pub fn find_memory_type(&self, type_bits: u32, flags: vk::MemoryPropertyFlags) -> Option<u32> {
        let mp = &self.mem_props;
        (0..mp.memory_type_count).find(|&i| {
            (type_bits & (1 << i)) != 0
                && mp.memory_types[i as usize].property_flags.contains(flags)
        })
    }

    /// Record `record` into a one-shot primary command buffer, submit it to
    /// the compute queue, and block until the GPU finishes. Serialized
    /// across threads via `submit_lock` (queue submission needs external
    /// synchronization).
    ///
    /// `submit_lock` is held while `record` runs, so `record` must not call
    /// [`Self::alloc_primary_cmd`], [`Self::free_cmds`], or any submit method
    /// (the lock is not reentrant).
    ///
    /// # Panics
    /// Panics if allocating, recording, or submitting the command buffer fails,
    /// or if waiting on the fence fails.
    pub fn submit_and_wait<F: FnOnce(vk::CommandBuffer)>(&self, record: F) {
        let _guard = self.lock_submit();
        let dev = &self.device;
        unsafe {
            let cmd = dev
                .allocate_command_buffers(
                    &vk::CommandBufferAllocateInfo::default()
                        .command_pool(self.cmd_pool)
                        .level(vk::CommandBufferLevel::PRIMARY)
                        .command_buffer_count(1),
                )
                .expect("vk alloc cmd buffer")[0];

            dev.begin_command_buffer(
                cmd,
                &vk::CommandBufferBeginInfo::default()
                    .flags(vk::CommandBufferUsageFlags::ONE_TIME_SUBMIT),
            )
            .expect("vk begin cmd");

            record(cmd);

            dev.end_command_buffer(cmd).expect("vk end cmd");

            let fence = dev
                .create_fence(&vk::FenceCreateInfo::default(), None)
                .expect("vk fence");
            let cmds = [cmd];
            let submit = vk::SubmitInfo::default().command_buffers(&cmds);
            dev.queue_submit(self.queue, &[submit], fence)
                .expect("vk submit");
            dev.wait_for_fences(&[fence], true, u64::MAX)
                .expect("vk wait");
            dev.destroy_fence(fence, None);
            dev.free_command_buffers(self.cmd_pool, &cmds);
        }
    }

    /// Allocate one reusable primary command buffer from the shared pool. The
    /// caller records it once and re-submits it many times via
    /// [`Self::submit_recorded_wait`] (the schedule is static across runs —
    /// inputs flow through the host-visible arena, not the command stream — so
    /// a single recording is valid for every step). Free it with
    /// [`Self::free_cmds`].
    ///
    /// Takes `submit_lock`, because allocating from `cmd_pool` needs external
    /// synchronization. Do not call it from inside a
    /// [`Self::submit_and_wait`] `record` closure. Recording into the returned
    /// buffer also touches the pool, so the caller must not record it
    /// concurrently with other pool users.
    ///
    /// # Panics
    /// Panics if the allocation fails.
    pub fn alloc_primary_cmd(&self) -> vk::CommandBuffer {
        let _guard = self.lock_submit();
        unsafe {
            self.device
                .allocate_command_buffers(
                    &vk::CommandBufferAllocateInfo::default()
                        .command_pool(self.cmd_pool)
                        .level(vk::CommandBufferLevel::PRIMARY)
                        .command_buffer_count(1),
                )
                .expect("vk alloc cmd buffer")[0]
        }
    }

    /// Free command buffers allocated from the shared pool.
    ///
    /// Takes `submit_lock` (freeing touches the externally-synchronized
    /// `cmd_pool`). The buffers must not be pending execution.
    pub fn free_cmds(&self, cmds: &[vk::CommandBuffer]) {
        let _guard = self.lock_submit();
        unsafe {
            self.device.free_command_buffers(self.cmd_pool, cmds);
        }
    }

    /// Create one unsignaled fence, reused across submits (reset after each
    /// wait). Avoids the per-step create/destroy of [`Self::submit_and_wait`].
    ///
    /// # Panics
    /// Panics if fence creation fails.
    pub fn create_reusable_fence(&self) -> vk::Fence {
        unsafe {
            self.device
                .create_fence(&vk::FenceCreateInfo::default(), None)
                .expect("vk fence")
        }
    }

    /// Destroy a fence created by [`Self::create_reusable_fence`].
    pub fn destroy_fence(&self, fence: vk::Fence) {
        unsafe {
            self.device.destroy_fence(fence, None);
        }
    }

    /// Submit an already-recorded command buffer and block until the GPU
    /// finishes, then reset `fence` for reuse. The command buffer must have been
    /// recorded WITHOUT `ONE_TIME_SUBMIT` so it can be resubmitted; since we
    /// wait here it is never pending at the next submit. Serialized via
    /// `submit_lock` (queue submission needs external synchronization).
    ///
    /// # Panics
    /// Panics if the submit, the fence wait, or the fence reset fails.
    pub fn submit_recorded_wait(&self, cmd: vk::CommandBuffer, fence: vk::Fence) {
        let _guard = self.lock_submit();
        let dev = &self.device;
        unsafe {
            let cmds = [cmd];
            let submit = vk::SubmitInfo::default().command_buffers(&cmds);
            dev.queue_submit(self.queue, &[submit], fence)
                .expect("vk submit");
            dev.wait_for_fences(&[fence], true, u64::MAX)
                .expect("vk wait");
            dev.reset_fences(&[fence]).expect("vk reset fence");
        }
    }
}