use ash::Instance;
use ash::vk;
use ash::vk::PhysicalDeviceProperties2;
use ash::vk::PhysicalDeviceShaderCoreProperties2AMD;
use ash::vk::PhysicalDeviceShaderCorePropertiesAMD;
use ash::vk::PhysicalDeviceShaderSMBuiltinsPropertiesNV;
use std::ffi::CStr;

use crate::vendor::Vendor;

11#[derive(Debug)]
13pub struct Device {
14 pub vendor: Vendor,
15 pub device_name: String,
16 pub device_type: DeviceType,
17 pub device_id: u32,
18 pub vendor_id: u32,
19 pub driver_name: String,
20 pub driver_info: String,
21 pub api_version: String,
22 pub heapbudget: u64,
24 pub heapsize: u64,
25 pub characteristics: GPUCharacteristics,
26}
27
/// Hardware characteristics gathered for a physical device.
///
/// The `Option` fields are vendor specific and stay `None` when the
/// corresponding information was not collected for the device.
#[derive(Debug)]
pub struct GPUCharacteristics {
    /// Fraction of the selected heap that is outside the reported budget
    /// (`(size - budget) / size`); `NaN` when no budget was reported.
    pub memory_pressure: f32,
    /// AMD: product of shader engines, shader arrays per engine and compute
    /// units per shader array.
    pub compute_units: Option<u32>,
    /// AMD: number of shader engines.
    pub shader_engines: Option<u32>,
    /// AMD: number of shader arrays per shader engine.
    pub shader_arrays_per_engine_count: Option<u32>,
    /// AMD: number of compute units per shader array.
    pub compute_units_per_shader_array: Option<u32>,
    /// AMD: number of SIMDs per compute unit.
    pub simd_per_compute_unit: Option<u32>,
    /// AMD: number of wavefronts per SIMD.
    pub wavefronts_per_simd: Option<u32>,
    /// AMD: wavefront size.
    pub wavefront_size: Option<u32>,
    /// NVIDIA: number of streaming multiprocessors.
    pub streaming_multiprocessors: Option<u32>,
    /// NVIDIA: number of warps per streaming multiprocessor.
    pub warps_per_sm: Option<u32>,
    /// Device limit `max_image_dimension2_d`.
    pub max_image_dimension_2d: u32,
    /// Device limit `max_compute_shared_memory_size`.
    pub max_compute_shared_memory_size: u32,
    /// Device limit `max_compute_work_group_invocations`.
    pub max_compute_work_group_invocations: u32,
    /// `true` when some queue family offers TRANSFER without GRAPHICS or COMPUTE.
    pub dedicated_transfer_queue: bool,
    /// `true` when some queue family offers COMPUTE without GRAPHICS.
    pub dedicated_async_compute_queue: bool,
    /// `true` when the device lists `VK_KHR_ray_tracing_pipeline` or `VK_NV_ray_tracing`.
    pub supports_ray_tracing: bool,
}

56impl Device {
57 pub fn new(instance: &Instance, physical_device: vk::PhysicalDevice) -> Self {
59 let physical_device_properties: vk::PhysicalDeviceProperties =
61 unsafe { instance.get_physical_device_properties(physical_device) };
62 let limits = physical_device_properties.limits;
63
64 let mut driver_properties: vk::PhysicalDeviceDriverProperties =
66 vk::PhysicalDeviceDriverProperties::default();
67 let mut properties2: PhysicalDeviceProperties2 =
68 PhysicalDeviceProperties2::default().push_next(&mut driver_properties);
69 unsafe {
70 instance.get_physical_device_properties2(physical_device, &mut properties2);
71 }
72
73 let vendor_id = physical_device_properties.vendor_id;
74 let vendor = Vendor::from_vendor_id(vendor_id).unwrap_or_else(|| {
75 eprintln!("Unknown vendor: {}", vendor_id);
76 panic!();
77 });
78
79 let device_name = cstring_to_string(
80 physical_device_properties
81 .device_name_as_c_str()
82 .unwrap_or(c"Unknown"),
83 );
84 let device_type = DeviceType::from(physical_device_properties.device_type.as_raw());
85 let device_id = physical_device_properties.device_id;
86 let api_version = decode_version_number(physical_device_properties.api_version);
87 let driver_name = cstring_to_string(
88 driver_properties
89 .driver_name_as_c_str()
90 .unwrap_or(c"Unknown"),
91 );
92 let driver_info = cstring_to_string(
93 driver_properties
94 .driver_info_as_c_str()
95 .unwrap_or(c"Unknown"),
96 );
97
98 let mut memory_budget = vk::PhysicalDeviceMemoryBudgetPropertiesEXT::default();
100 let mut memory_properties2 =
101 vk::PhysicalDeviceMemoryProperties2::default().push_next(&mut memory_budget);
102 unsafe {
103 instance
104 .get_physical_device_memory_properties2(physical_device, &mut memory_properties2);
105 }
106 let memory_properties = memory_properties2.memory_properties;
107 let vram_heap_index = (0..memory_properties.memory_heap_count)
108 .find(|&i| {
109 memory_properties.memory_heaps[i as usize]
110 .flags
111 .contains(vk::MemoryHeapFlags::DEVICE_LOCAL)
112 })
113 .unwrap_or(0);
114 let heapsize = memory_properties.memory_heaps[vram_heap_index as usize].size;
115 let heapbudget = memory_budget.heap_budget[vram_heap_index as usize];
116 let memory_pressure = if heapbudget > 0 {
117 (heapsize - heapbudget) as f32 / heapsize as f32
118 } else {
119 f32::NAN
120 };
121
122 let queue_families =
124 unsafe { instance.get_physical_device_queue_family_properties(physical_device) };
125 let mut dedicated_transfer_queue = false;
126 let mut dedicated_async_compute_queue = false;
127 for qf in queue_families.iter() {
128 let flags = qf.queue_flags;
129 if flags.contains(vk::QueueFlags::TRANSFER)
130 && !(flags.contains(vk::QueueFlags::GRAPHICS)
131 || flags.contains(vk::QueueFlags::COMPUTE))
132 {
133 dedicated_transfer_queue = true;
134 }
135 if flags.contains(vk::QueueFlags::COMPUTE) && !flags.contains(vk::QueueFlags::GRAPHICS)
136 {
137 dedicated_async_compute_queue = true;
138 }
139 }
140
141 let extensions = unsafe {
143 instance
144 .enumerate_device_extension_properties(physical_device)
145 .unwrap_or_default()
146 };
147 let supports_ray_tracing = extensions.iter().any(|ext| {
148 let ext_name = unsafe { CStr::from_ptr(ext.extension_name.as_ptr()) };
149 ext_name.to_str().unwrap_or("") == "VK_KHR_ray_tracing_pipeline"
150 || ext_name.to_str().unwrap_or("") == "VK_NV_ray_tracing"
151 });
152
153 let mut characteristics = GPUCharacteristics {
154 memory_pressure,
155 compute_units: None,
157 shader_engines: None,
158 shader_arrays_per_engine_count: None,
159 compute_units_per_shader_array: None,
160 simd_per_compute_unit: None,
161 wavefronts_per_simd: None,
162 wavefront_size: None,
163 streaming_multiprocessors: None,
164 warps_per_sm: None,
165 max_image_dimension_2d: limits.max_image_dimension2_d,
167 max_compute_shared_memory_size: limits.max_compute_shared_memory_size,
168 max_compute_work_group_invocations: limits.max_compute_work_group_invocations,
169 dedicated_transfer_queue,
171 dedicated_async_compute_queue,
172 supports_ray_tracing,
173 };
174
175 match vendor {
177 Vendor::AMD => {
178 let mut shader_core_properties = PhysicalDeviceShaderCorePropertiesAMD::default();
179 let mut shader_core_properties2 = PhysicalDeviceShaderCoreProperties2AMD::default();
180 let mut amd_properties2 = PhysicalDeviceProperties2::default()
181 .push_next(&mut shader_core_properties)
182 .push_next(&mut shader_core_properties2);
183 unsafe {
184 instance.get_physical_device_properties2(physical_device, &mut amd_properties2);
185 }
186 characteristics.compute_units = Some(
187 shader_core_properties.shader_engine_count
188 * shader_core_properties.shader_arrays_per_engine_count
189 * shader_core_properties.compute_units_per_shader_array,
190 );
191 characteristics.shader_engines = Some(shader_core_properties.shader_engine_count);
192 characteristics.shader_arrays_per_engine_count =
193 Some(shader_core_properties.shader_arrays_per_engine_count);
194 characteristics.compute_units_per_shader_array =
195 Some(shader_core_properties.compute_units_per_shader_array);
196 characteristics.simd_per_compute_unit =
197 Some(shader_core_properties.simd_per_compute_unit);
198 characteristics.wavefronts_per_simd =
199 Some(shader_core_properties.wavefronts_per_simd);
200 characteristics.wavefront_size = Some(shader_core_properties.wavefront_size);
201 }
202 Vendor::Nvidia => {
203 let mut sm_builtins = PhysicalDeviceShaderSMBuiltinsPropertiesNV::default();
204 let mut nv_properties2 =
205 PhysicalDeviceProperties2::default().push_next(&mut sm_builtins);
206 unsafe {
207 instance.get_physical_device_properties2(physical_device, &mut nv_properties2);
208 }
209 characteristics.streaming_multiprocessors = Some(sm_builtins.shader_sm_count);
210 characteristics.warps_per_sm = Some(sm_builtins.shader_warps_per_sm);
211 }
212 _ => {
213 }
215 };
216
217 Device {
218 vendor,
219 device_name,
220 device_type,
221 device_id,
222 vendor_id,
223 driver_name,
224 driver_info,
225 api_version,
226 heapbudget,
227 heapsize,
228 characteristics,
229 }
230 }
231}
232
/// Category of a physical device.
///
/// Discriminants `0..=4` match the raw ids accepted by [`DeviceType::from`];
/// `Unknown` is the fallback for any other id.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum DeviceType {
    /// The device does not match any other category.
    Other = 0,
    /// GPU integrated with the host.
    IntegratedGPU = 1,
    /// Separate (discrete) GPU.
    DiscreteGPU = 2,
    /// Virtual GPU.
    VirtualGPU = 3,
    /// The device runs on the CPU.
    CPU = 4,
    /// Raw id outside the known range.
    Unknown = 5,
}

impl DeviceType {
    /// Converts a raw device-type id into a [`DeviceType`].
    ///
    /// Ids `0..=4` map to the matching variant; every other value, including
    /// negative ones, maps to [`DeviceType::Unknown`].
    pub fn from(id: i32) -> Self {
        match id {
            0 => Self::Other,
            1 => Self::IntegratedGPU,
            2 => Self::DiscreteGPU,
            3 => Self::VirtualGPU,
            4 => Self::CPU,
            _ => Self::Unknown,
        }
    }

    /// Returns the human-readable label for this device type.
    pub fn name(&self) -> &'static str {
        match self {
            Self::Other => "Other",
            Self::IntegratedGPU => "Integrated GPU",
            Self::DiscreteGPU => "Discrete GPU",
            Self::VirtualGPU => "Virtual GPU",
            Self::CPU => "CPU",
            Self::Unknown => "Unknown",
        }
    }
}

/// Formats a packed version number as `"variant.major.minor.patch"`.
///
/// The 32-bit value is split, from the most significant bit down, into a
/// 3-bit variant, a 7-bit major, a 10-bit minor and a 12-bit patch field.
pub fn decode_version_number(version: u32) -> String {
    const VARIANT_SHIFT: u32 = 29;
    const MAJOR_SHIFT: u32 = 22;
    const MINOR_SHIFT: u32 = 12;
    const VARIANT_MASK: u32 = 0x7;
    const MAJOR_MASK: u32 = 0x7F;
    const MINOR_MASK: u32 = 0x3FF;
    const PATCH_MASK: u32 = 0xFFF;

    let variant = (version >> VARIANT_SHIFT) & VARIANT_MASK;
    let major = (version >> MAJOR_SHIFT) & MAJOR_MASK;
    let minor = (version >> MINOR_SHIFT) & MINOR_MASK;
    let patch = version & PATCH_MASK;
    format!("{}.{}.{}.{}", variant, major, minor, patch)
}

/// Copies a C string into an owned [`String`].
///
/// Bytes that are not valid UTF-8 are replaced with U+FFFD rather than
/// causing an error.
pub fn cstring_to_string(cstr: &CStr) -> String {
    cstr.to_string_lossy().into_owned()
}

/// Unit tests for the helpers and plain data types in this file; none of them
/// needs a Vulkan instance.
#[cfg(test)]
mod tests {
    use super::*;
    use ash::vk;
    use std::ffi::CString;

    /// Builds an owned C string from a Rust string without interior NULs.
    fn dummy_cstr(s: &str) -> CString {
        CString::new(s).unwrap()
    }

    #[test]
    fn test_decode_version_number() {
        // variant 0, major 1, minor 2, patch 3
        let version: u32 = (1 << 22) | (2 << 12) | 3;
        let decoded = decode_version_number(version);
        assert_eq!(decoded, "0.1.2.3");

        // Boundaries: every field empty, and every field saturated.
        assert_eq!(decode_version_number(0), "0.0.0.0");
        assert_eq!(decode_version_number(u32::MAX), "7.127.1023.4095");
    }

    #[test]
    fn test_cstring_to_string() {
        let original = "Hello, world!";
        let cstr = dummy_cstr(original);
        let s = cstring_to_string(cstr.as_c_str());
        assert_eq!(s, original);

        // Invalid UTF-8 is replaced instead of failing.
        let invalid = CStr::from_bytes_with_nul(b"\xff\0").unwrap();
        assert_eq!(cstring_to_string(invalid), "\u{FFFD}");
    }

    #[test]
    fn test_device_type_from() {
        assert_eq!(DeviceType::from(0).name(), "Other");
        assert_eq!(DeviceType::from(1).name(), "Integrated GPU");
        assert_eq!(DeviceType::from(2).name(), "Discrete GPU");
        assert_eq!(DeviceType::from(3).name(), "Virtual GPU");
        assert_eq!(DeviceType::from(4).name(), "CPU");
        assert_eq!(DeviceType::from(99).name(), "Unknown");
        assert_eq!(DeviceType::from(-1).name(), "Unknown");
    }

    #[test]
    fn test_gpu_characteristics_defaults() {
        let limits = vk::PhysicalDeviceLimits {
            max_image_dimension2_d: 8192,
            max_compute_shared_memory_size: 16384,
            max_compute_work_group_invocations: 1024,
            ..Default::default()
        };

        let characteristics = GPUCharacteristics {
            memory_pressure: 0.5,
            compute_units: None,
            shader_engines: None,
            shader_arrays_per_engine_count: None,
            compute_units_per_shader_array: None,
            simd_per_compute_unit: None,
            wavefronts_per_simd: None,
            wavefront_size: None,
            streaming_multiprocessors: None,
            warps_per_sm: None,
            max_image_dimension_2d: limits.max_image_dimension2_d,
            max_compute_shared_memory_size: limits.max_compute_shared_memory_size,
            max_compute_work_group_invocations: limits.max_compute_work_group_invocations,
            dedicated_transfer_queue: false,
            dedicated_async_compute_queue: false,
            supports_ray_tracing: false,
        };

        assert_eq!(characteristics.max_image_dimension_2d, 8192);
        assert_eq!(characteristics.max_compute_shared_memory_size, 16384);
        assert_eq!(characteristics.max_compute_work_group_invocations, 1024);
        assert!(characteristics.compute_units.is_none());
    }
}