1use ash::vk;
14use std::ffi::{CStr, c_char};
15use std::sync::{Mutex, OnceLock};
16
17pub struct VulkanDevice {
19 pub entry: ash::Entry,
20 pub instance: ash::Instance,
21 pub physical: vk::PhysicalDevice,
22 pub device: ash::Device,
23 pub queue: vk::Queue,
24 pub queue_family: u32,
25 pub mem_props: vk::PhysicalDeviceMemoryProperties,
26 pub limits: vk::PhysicalDeviceLimits,
27 pub name: String,
28 pub portability: bool,
34 pub coop_matmul: bool,
39 cmd_pool: vk::CommandPool,
40 submit_lock: Mutex<()>,
42}
43
44unsafe impl Send for VulkanDevice {}
48unsafe impl Sync for VulkanDevice {}
49
50static DEVICE: OnceLock<Option<VulkanDevice>> = OnceLock::new();
51
/// Points the Vulkan loader at a MoltenVK ICD manifest when the user has not
/// configured one.
///
/// Does nothing if `VK_ICD_FILENAMES` or `VK_DRIVER_FILES` is already set.
/// Otherwise the well-known Homebrew `share/vulkan/icd.d` manifests are
/// checked first, then every entry under the Homebrew `Cellar/molten-vk`
/// directories; the first manifest that exists is exported through
/// `VK_ICD_FILENAMES`. When nothing is found the environment is left as is.
#[cfg(all(target_vendor = "apple", not(target_os = "watchos")))]
fn ensure_macos_loader() {
    use std::path::PathBuf;

    const OVERRIDE_VARS: [&str; 2] = ["VK_ICD_FILENAMES", "VK_DRIVER_FILES"];
    const SHARED_MANIFESTS: [&str; 2] = [
        "/opt/homebrew/share/vulkan/icd.d/MoltenVK_icd.json",
        "/usr/local/share/vulkan/icd.d/MoltenVK_icd.json",
    ];
    const CELLAR_ROOTS: [&str; 2] = [
        "/opt/homebrew/Cellar/molten-vk",
        "/usr/local/Cellar/molten-vk",
    ];
    const CELLAR_MANIFEST: &str = "etc/vulkan/icd.d/MoltenVK_icd.json";

    // An explicit user setting always wins.
    if OVERRIDE_VARS
        .iter()
        .any(|var| std::env::var_os(var).is_some())
    {
        return;
    }

    // Both searches are lazy: the Cellar directories are only read when no
    // shared manifest exists, and scanning stops at the first hit.
    let manifest = SHARED_MANIFESTS
        .iter()
        .copied()
        .map(PathBuf::from)
        .find(|path| path.exists())
        .or_else(|| {
            CELLAR_ROOTS
                .iter()
                .filter_map(|root| std::fs::read_dir(root).ok())
                .flat_map(|entries| entries.flatten())
                .map(|entry| entry.path().join(CELLAR_MANIFEST))
                .find(|path| path.exists())
        });

    if let Some(path) = manifest {
        unsafe { std::env::set_var("VK_ICD_FILENAMES", path) };
    }
}
91
92pub fn vulkan_device() -> Option<&'static VulkanDevice> {
94 DEVICE.get_or_init(|| VulkanDevice::new().ok()).as_ref()
95}
96
97impl VulkanDevice {
98 fn new() -> Result<Self, String> {
99 #[cfg(all(target_vendor = "apple", not(target_os = "watchos")))]
104 ensure_macos_loader();
105
106 let entry = unsafe { ash::Entry::load() }
110 .or_else(|orig| {
111 for lib in [
112 "/opt/homebrew/lib/libvulkan.dylib",
113 "/opt/homebrew/lib/libvulkan.1.dylib",
114 "/usr/local/lib/libvulkan.dylib",
115 ] {
116 if std::path::Path::new(lib).exists() {
117 if let Ok(e) = unsafe { ash::Entry::load_from(lib) } {
118 return Ok(e);
119 }
120 }
121 }
122 Err(orig)
123 })
124 .map_err(|e| format!("vk load: {e}"))?;
125
126 let app_name = c"rlx-vulkan";
127 let app_info = vk::ApplicationInfo::default()
128 .application_name(app_name)
129 .engine_name(app_name)
130 .api_version(vk::make_api_version(0, 1, 1, 0));
131
132 let mut inst_ext: Vec<*const c_char> = Vec::new();
135 let mut inst_flags = vk::InstanceCreateFlags::empty();
136 #[cfg(all(target_vendor = "apple", not(target_os = "watchos")))]
137 {
138 inst_ext.push(ash::khr::portability_enumeration::NAME.as_ptr());
139 inst_ext.push(ash::khr::get_physical_device_properties2::NAME.as_ptr());
140 inst_flags |= vk::InstanceCreateFlags::ENUMERATE_PORTABILITY_KHR;
141 }
142
143 let create_info = vk::InstanceCreateInfo::default()
144 .application_info(&app_info)
145 .enabled_extension_names(&inst_ext)
146 .flags(inst_flags);
147 let instance = unsafe { entry.create_instance(&create_info, None) }
148 .map_err(|e| format!("vk instance: {e}"))?;
149
150 let physical_devices = unsafe { instance.enumerate_physical_devices() }
152 .map_err(|e| format!("vk enumerate: {e}"))?;
153 let mut best: Option<(vk::PhysicalDevice, u32, i32)> = None;
154 for &pd in &physical_devices {
155 let props = unsafe { instance.get_physical_device_properties(pd) };
156 let qfams = unsafe { instance.get_physical_device_queue_family_properties(pd) };
157 let Some(qf) = qfams
158 .iter()
159 .position(|q| q.queue_flags.contains(vk::QueueFlags::COMPUTE) && q.queue_count > 0)
160 else {
161 continue;
162 };
163 let score = match props.device_type {
164 vk::PhysicalDeviceType::DISCRETE_GPU => 3,
165 vk::PhysicalDeviceType::INTEGRATED_GPU => 2,
166 vk::PhysicalDeviceType::VIRTUAL_GPU => 1,
167 _ => 0,
168 };
169 if best.map(|(_, _, s)| score > s).unwrap_or(true) {
170 best = Some((pd, qf as u32, score));
171 }
172 }
173 let (physical, queue_family, _) = best.ok_or_else(|| {
174 unsafe { instance.destroy_instance(None) };
175 "no Vulkan device with a compute queue".to_string()
176 })?;
177
178 let props = unsafe { instance.get_physical_device_properties(physical) };
179 let name = unsafe { CStr::from_ptr(props.device_name.as_ptr()) }
180 .to_string_lossy()
181 .into_owned();
182
183 let dev_exts =
186 unsafe { instance.enumerate_device_extension_properties(physical) }.unwrap_or_default();
187 let mut dev_ext: Vec<*const c_char> = Vec::new();
188 let portability_name = c"VK_KHR_portability_subset";
189 let mut is_portability = false;
190 for e in &dev_exts {
191 let n = unsafe { CStr::from_ptr(e.extension_name.as_ptr()) };
192 if n == portability_name {
193 dev_ext.push(portability_name.as_ptr());
194 is_portability = true;
195 }
196 }
197
198 let has_ext = |want: &CStr| {
203 dev_exts
204 .iter()
205 .any(|e| unsafe { CStr::from_ptr(e.extension_name.as_ptr()) } == want)
206 };
207 let coop_ext = c"VK_KHR_cooperative_matrix";
208 let memmodel_ext = c"VK_KHR_vulkan_memory_model";
209 let f16_ext = c"VK_KHR_shader_float16_int8";
210 let s16_ext = c"VK_KHR_16bit_storage";
211 let mut coop_matmul = false;
212 if !is_portability
213 && has_ext(coop_ext)
214 && has_ext(memmodel_ext)
215 && has_ext(f16_ext)
216 && has_ext(s16_ext)
217 {
218 let mut coop_feat = vk::PhysicalDeviceCooperativeMatrixFeaturesKHR::default();
219 let mut probe = vk::PhysicalDeviceFeatures2::default().push_next(&mut coop_feat);
220 unsafe { instance.get_physical_device_features2(physical, &mut probe) };
221 if coop_feat.cooperative_matrix != 0 {
222 let ci = ash::khr::cooperative_matrix::Instance::new(&entry, &instance);
223 let configs =
224 unsafe { ci.get_physical_device_cooperative_matrix_properties(physical) }
225 .unwrap_or_default();
226 coop_matmul = configs.iter().any(|c| {
227 c.m_size == 16
228 && c.n_size == 16
229 && c.k_size == 16
230 && c.a_type == vk::ComponentTypeKHR::FLOAT16
231 && c.b_type == vk::ComponentTypeKHR::FLOAT16
232 && c.result_type == vk::ComponentTypeKHR::FLOAT32
233 && c.scope == vk::ScopeKHR::SUBGROUP
234 });
235 }
236 }
237 if coop_matmul {
238 dev_ext.push(coop_ext.as_ptr());
239 dev_ext.push(memmodel_ext.as_ptr());
240 dev_ext.push(f16_ext.as_ptr());
241 dev_ext.push(s16_ext.as_ptr());
242 }
243 if std::env::var_os("RLX_VULKAN_DEBUG").is_some() {
244 eprintln!(
245 "[rlx-vulkan] device={name:?} portability={is_portability} coop_matmul={coop_matmul}"
246 );
247 }
248
249 let priorities = [1.0f32];
250 let queue_infos = [vk::DeviceQueueCreateInfo::default()
251 .queue_family_index(queue_family)
252 .queue_priorities(&priorities)];
253 let base_features = vk::PhysicalDeviceFeatures::default();
254 let device = if coop_matmul {
255 let mut coop_f =
257 vk::PhysicalDeviceCooperativeMatrixFeaturesKHR::default().cooperative_matrix(true);
258 let mut mm_f =
259 vk::PhysicalDeviceVulkanMemoryModelFeatures::default().vulkan_memory_model(true);
260 let mut f16_f =
261 vk::PhysicalDeviceShaderFloat16Int8Features::default().shader_float16(true);
262 let mut s16_f =
263 vk::PhysicalDevice16BitStorageFeatures::default().storage_buffer16_bit_access(true);
264 let mut feats2 = vk::PhysicalDeviceFeatures2::default()
265 .features(base_features)
266 .push_next(&mut coop_f)
267 .push_next(&mut mm_f)
268 .push_next(&mut f16_f)
269 .push_next(&mut s16_f);
270 let dci = vk::DeviceCreateInfo::default()
271 .queue_create_infos(&queue_infos)
272 .enabled_extension_names(&dev_ext)
273 .push_next(&mut feats2);
274 unsafe { instance.create_device(physical, &dci, None) }
275 } else {
276 let dci = vk::DeviceCreateInfo::default()
277 .queue_create_infos(&queue_infos)
278 .enabled_extension_names(&dev_ext)
279 .enabled_features(&base_features);
280 unsafe { instance.create_device(physical, &dci, None) }
281 }
282 .map_err(|e| format!("vk device: {e}"))?;
283 let queue = unsafe { device.get_device_queue(queue_family, 0) };
284
285 let mem_props = unsafe { instance.get_physical_device_memory_properties(physical) };
286
287 let cmd_pool = unsafe {
288 device.create_command_pool(
289 &vk::CommandPoolCreateInfo::default()
290 .queue_family_index(queue_family)
291 .flags(vk::CommandPoolCreateFlags::RESET_COMMAND_BUFFER),
292 None,
293 )
294 }
295 .map_err(|e| format!("vk cmd pool: {e}"))?;
296
297 Ok(Self {
298 entry,
299 instance,
300 physical,
301 device,
302 queue,
303 queue_family,
304 mem_props,
305 limits: props.limits,
306 name,
307 portability: is_portability,
308 coop_matmul,
309 cmd_pool,
310 submit_lock: Mutex::new(()),
311 })
312 }
313
314 pub fn find_memory_type(&self, type_bits: u32, flags: vk::MemoryPropertyFlags) -> Option<u32> {
316 let mp = &self.mem_props;
317 (0..mp.memory_type_count).find(|&i| {
318 (type_bits & (1 << i)) != 0
319 && mp.memory_types[i as usize].property_flags.contains(flags)
320 })
321 }
322
323 pub fn submit_and_wait<F: FnOnce(vk::CommandBuffer)>(&self, record: F) {
328 let _guard = self.submit_lock.lock().unwrap();
329 let dev = &self.device;
330 unsafe {
331 let cmd = dev
332 .allocate_command_buffers(
333 &vk::CommandBufferAllocateInfo::default()
334 .command_pool(self.cmd_pool)
335 .level(vk::CommandBufferLevel::PRIMARY)
336 .command_buffer_count(1),
337 )
338 .expect("vk alloc cmd buffer")[0];
339
340 dev.begin_command_buffer(
341 cmd,
342 &vk::CommandBufferBeginInfo::default()
343 .flags(vk::CommandBufferUsageFlags::ONE_TIME_SUBMIT),
344 )
345 .expect("vk begin cmd");
346
347 record(cmd);
348
349 dev.end_command_buffer(cmd).expect("vk end cmd");
350
351 let fence = dev
352 .create_fence(&vk::FenceCreateInfo::default(), None)
353 .expect("vk fence");
354 let cmds = [cmd];
355 let submit = vk::SubmitInfo::default().command_buffers(&cmds);
356 dev.queue_submit(self.queue, &[submit], fence)
357 .expect("vk submit");
358 dev.wait_for_fences(&[fence], true, u64::MAX)
359 .expect("vk wait");
360 dev.destroy_fence(fence, None);
361 dev.free_command_buffers(self.cmd_pool, &cmds);
362 }
363 }
364
365 pub fn alloc_primary_cmd(&self) -> vk::CommandBuffer {
371 unsafe {
372 self.device
373 .allocate_command_buffers(
374 &vk::CommandBufferAllocateInfo::default()
375 .command_pool(self.cmd_pool)
376 .level(vk::CommandBufferLevel::PRIMARY)
377 .command_buffer_count(1),
378 )
379 .expect("vk alloc cmd buffer")[0]
380 }
381 }
382
383 pub fn free_cmds(&self, cmds: &[vk::CommandBuffer]) {
385 unsafe {
386 self.device.free_command_buffers(self.cmd_pool, cmds);
387 }
388 }
389
390 pub fn create_reusable_fence(&self) -> vk::Fence {
393 unsafe {
394 self.device
395 .create_fence(&vk::FenceCreateInfo::default(), None)
396 .expect("vk fence")
397 }
398 }
399
400 pub fn destroy_fence(&self, fence: vk::Fence) {
402 unsafe {
403 self.device.destroy_fence(fence, None);
404 }
405 }
406
407 pub fn submit_recorded_wait(&self, cmd: vk::CommandBuffer, fence: vk::Fence) {
413 let _guard = self.submit_lock.lock().unwrap();
414 let dev = &self.device;
415 unsafe {
416 let cmds = [cmd];
417 let submit = vk::SubmitInfo::default().command_buffers(&cmds);
418 dev.queue_submit(self.queue, &[submit], fence)
419 .expect("vk submit");
420 dev.wait_for_fences(&[fence], true, u64::MAX)
421 .expect("vk wait");
422 dev.reset_fences(&[fence]).expect("vk reset fence");
423 }
424 }
425}