//! Vulkan adapter (physical device) selection and logical device creation.
//!
//! Source path: `rotex_vulkan/backend/vulkan/device.rs`.

1use std::collections::BTreeMap;
2use std::ffi::CStr;
3
4use ash::vk;
5
6use crate::core::Instance;
7use crate::error::{Error, ErrorKind, Severity, vk_error};
8
/// The kind of work a requested queue is intended for.
///
/// The category decides which queue family a request is mapped to when an
/// adapter is checked or a device is created.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum QueueCategory {
    /// A queue from the first family advertising `GRAPHICS`.
    Graphics,
    /// A queue from the first family advertising `COMPUTE`.
    Compute,
    /// A queue for transfer work. A dedicated transfer-only family is preferred,
    /// with fallbacks to graphics, any `TRANSFER` family, then compute.
    Transfer,
}
15
16#[derive(Debug, Clone, Copy, PartialEq, Eq)]
17pub struct QueueRequest {
18    pub category: QueueCategory,
19    pub count: u32,
20}
21
22#[derive(Debug, Clone)]
23pub struct DeviceDescriptor {
24    pub required_features: vk::PhysicalDeviceFeatures,
25    pub enable_swapchain: bool,
26    pub queues: Vec<QueueRequest>,
27}
28
29pub struct Adapter {
30    pub(crate) handle: vk::PhysicalDevice,
31    name: String,
32    device_type: vk::PhysicalDeviceType,
33    limits: vk::PhysicalDeviceLimits,
34}
35
36impl Adapter {
37    pub(crate) fn new(
38        handle: vk::PhysicalDevice,
39        name: String,
40        device_type: vk::PhysicalDeviceType,
41        limits: vk::PhysicalDeviceLimits,
42    ) -> Self {
43        Self {
44            handle,
45            name,
46            device_type,
47            limits,
48        }
49    }
50
51    pub fn name(&self) -> &str {
52        &self.name
53    }
54
55    pub fn device_type(&self) -> vk::PhysicalDeviceType {
56        self.device_type
57    }
58
59    pub fn limits(&self) -> &vk::PhysicalDeviceLimits {
60        &self.limits
61    }
62
63    pub fn physical_device(&self) -> vk::PhysicalDevice {
64        self.handle
65    }
66
67    pub fn selection_score(&self) -> u32 {
68        match self.device_type {
69            vk::PhysicalDeviceType::DISCRETE_GPU => 400,
70            vk::PhysicalDeviceType::INTEGRATED_GPU => 300,
71            vk::PhysicalDeviceType::VIRTUAL_GPU => 200,
72            vk::PhysicalDeviceType::CPU => 100,
73            _ => 0,
74        }
75    }
76
77    pub fn has_swapchain_extension(&self, instance: &Instance) -> Result<bool, Error> {
78        let extensions = unsafe {
79            instance
80                .instance()
81                .enumerate_device_extension_properties(self.handle)
82        }
83        .map_err(vk_error)?;
84        Ok(extensions.iter().any(|ext| unsafe {
85            CStr::from_ptr(ext.extension_name.as_ptr()) == vk::KHR_SWAPCHAIN_NAME
86        }))
87    }
88
89    pub fn supports_queue_requests(&self, instance: &Instance, queues: &[QueueRequest]) -> bool {
90        let queue_families = unsafe {
91            instance
92                .instance()
93                .get_physical_device_queue_family_properties(self.handle)
94        };
95        let graphics_index = queue_families
96            .iter()
97            .enumerate()
98            .find(|(_, family)| family.queue_flags.contains(vk::QueueFlags::GRAPHICS))
99            .map(|(index, _)| index as u32);
100        let compute_index = queue_families
101            .iter()
102            .enumerate()
103            .find(|(_, family)| family.queue_flags.contains(vk::QueueFlags::COMPUTE))
104            .map(|(index, _)| index as u32);
105        let transfer_any_index = queue_families
106            .iter()
107            .enumerate()
108            .find(|(_, family)| family.queue_flags.contains(vk::QueueFlags::TRANSFER))
109            .map(|(index, _)| index as u32);
110        let transfer_dedicated_index = queue_families
111            .iter()
112            .enumerate()
113            .find(|(_, family)| {
114                family.queue_flags.contains(vk::QueueFlags::TRANSFER)
115                    && !family.queue_flags.contains(vk::QueueFlags::GRAPHICS)
116                    && !family.queue_flags.contains(vk::QueueFlags::COMPUTE)
117            })
118            .map(|(index, _)| index as u32);
119
120        let mut has_request = false;
121        for request in queues {
122            if request.count == 0 {
123                continue;
124            }
125            has_request = true;
126            let family_index = match request.category {
127                QueueCategory::Graphics => graphics_index,
128                QueueCategory::Compute => compute_index,
129                QueueCategory::Transfer => transfer_dedicated_index
130                    .or(graphics_index)
131                    .or(transfer_any_index)
132                    .or(compute_index),
133            };
134            if family_index.is_none() {
135                return false;
136            }
137        }
138        has_request
139    }
140
141    pub fn request_device(
142        &self,
143        instance: &Instance,
144        desc: DeviceDescriptor,
145    ) -> Result<Device, Error> {
146        let queue_families = unsafe {
147            instance
148                .instance()
149                .get_physical_device_queue_family_properties(self.handle)
150        };
151        let graphics_index = queue_families
152            .iter()
153            .enumerate()
154            .find(|(_, family)| family.queue_flags.contains(vk::QueueFlags::GRAPHICS))
155            .map(|(index, _)| index as u32);
156        let compute_index = queue_families
157            .iter()
158            .enumerate()
159            .find(|(_, family)| family.queue_flags.contains(vk::QueueFlags::COMPUTE))
160            .map(|(index, _)| index as u32);
161        let transfer_any_index = queue_families
162            .iter()
163            .enumerate()
164            .find(|(_, family)| family.queue_flags.contains(vk::QueueFlags::TRANSFER))
165            .map(|(index, _)| index as u32);
166        let transfer_dedicated_index = queue_families
167            .iter()
168            .enumerate()
169            .find(|(_, family)| {
170                family.queue_flags.contains(vk::QueueFlags::TRANSFER)
171                    && !family.queue_flags.contains(vk::QueueFlags::GRAPHICS)
172                    && !family.queue_flags.contains(vk::QueueFlags::COMPUTE)
173            })
174            .map(|(index, _)| index as u32);
175
176        if desc.enable_swapchain {
177            let extensions = unsafe {
178                instance
179                    .instance()
180                    .enumerate_device_extension_properties(self.handle)
181            }
182            .map_err(vk_error)?;
183            let has_swapchain = extensions.iter().any(|ext| unsafe {
184                CStr::from_ptr(ext.extension_name.as_ptr()) == vk::KHR_SWAPCHAIN_NAME
185            });
186            if !has_swapchain {
187                return Err(Error {
188                    kind: ErrorKind::NoCompatibleDevice,
189                    severity: Severity::Fatal,
190                });
191            }
192        }
193
194        let mut allocations = Vec::new();
195        for request in desc.queues {
196            if request.count == 0 {
197                continue;
198            }
199            let family_index = match request.category {
200                QueueCategory::Graphics => graphics_index,
201                QueueCategory::Compute => compute_index,
202                QueueCategory::Transfer => transfer_dedicated_index
203                    .or(graphics_index)
204                    .or(transfer_any_index)
205                    .or(compute_index),
206            };
207            let family_index = match family_index {
208                Some(index) => index,
209                None => {
210                    return Err(Error {
211                        kind: ErrorKind::NoCompatibleDevice,
212                        severity: Severity::Fatal,
213                    });
214                }
215            };
216            allocations.push(QueueAllocation {
217                category: request.category,
218                family_index,
219                count: request.count,
220            });
221        }
222
223        if allocations.is_empty() {
224            return Err(Error {
225                kind: ErrorKind::NoCompatibleDevice,
226                severity: Severity::Fatal,
227            });
228        }
229
230        let mut queue_priorities: BTreeMap<u32, Vec<f32>> = BTreeMap::new();
231        for allocation in &allocations {
232            let entry = queue_priorities
233                .entry(allocation.family_index)
234                .or_insert_with(Vec::new);
235            entry.extend(std::iter::repeat(1.0).take(allocation.count as usize));
236        }
237
238        for (family_index, priorities) in queue_priorities.iter_mut() {
239            let max_supported = queue_families[*family_index as usize].queue_count as usize;
240            if priorities.len() > max_supported {
241                priorities.truncate(max_supported);
242            }
243        }
244
245        let mut priorities_store = Vec::new();
246        let mut queue_layouts = Vec::new();
247        for (family_index, priorities) in queue_priorities {
248            priorities_store.push(priorities);
249            let idx = priorities_store.len() - 1;
250            queue_layouts.push((family_index, idx));
251        }
252
253        let queue_create_infos: Vec<vk::DeviceQueueCreateInfo> = queue_layouts
254            .into_iter()
255            .map(|(family_index, idx)| {
256                vk::DeviceQueueCreateInfo::default()
257                    .queue_family_index(family_index)
258                    .queue_priorities(&priorities_store[idx])
259            })
260            .collect();
261
262        let device_extensions: Vec<*const i8> = if desc.enable_swapchain {
263            vec![vk::KHR_SWAPCHAIN_NAME.as_ptr()]
264        } else {
265            Vec::new()
266        };
267        let device_create_info = vk::DeviceCreateInfo::default()
268            .queue_create_infos(&queue_create_infos)
269            .enabled_extension_names(&device_extensions)
270            .enabled_features(&desc.required_features);
271
272        let device = unsafe {
273            instance
274                .instance()
275                .create_device(self.handle, &device_create_info, None)
276        }
277        .map_err(vk_error)?;
278
279        let properties = unsafe {
280            instance
281                .instance()
282                .get_physical_device_properties(self.handle)
283        };
284
285        Ok(Device {
286            handle: self.handle,
287            device,
288            properties,
289            queues: allocations,
290        })
291    }
292}
293
294#[derive(Debug, Clone)]
295pub struct QueueAllocation {
296    pub category: QueueCategory,
297    pub family_index: u32,
298    pub count: u32,
299}
300
301pub struct Device {
302    pub(crate) handle: vk::PhysicalDevice,
303    pub(crate) device: ash::Device,
304    properties: vk::PhysicalDeviceProperties,
305    queues: Vec<QueueAllocation>,
306}
307
308impl Device {
309    pub fn logical_device(&self) -> &ash::Device {
310        &self.device
311    }
312
313    pub fn physical_device(&self) -> vk::PhysicalDevice {
314        self.handle
315    }
316
317    pub fn properties(&self) -> &vk::PhysicalDeviceProperties {
318        &self.properties
319    }
320
321    pub fn queues(&self) -> &[QueueAllocation] {
322        &self.queues
323    }
324
325    pub fn get_queue(&self, family_index: u32, queue_index: u32) -> vk::Queue {
326        unsafe { self.device.get_device_queue(family_index, queue_index) }
327    }
328
329    pub fn find_memory_type(
330        &self,
331        instance: &Instance,
332        type_filter: u32,
333        properties: vk::MemoryPropertyFlags,
334    ) -> Result<u32, Error> {
335        let memory_properties = unsafe {
336            instance
337                .instance()
338                .get_physical_device_memory_properties(self.physical_device())
339        };
340
341        for (index, memory_type) in memory_properties.memory_types.iter().enumerate() {
342            let is_allowed_by_hardware = (type_filter & (1 << index)) != 0;
343            let has_required_properties = memory_type.property_flags.contains(properties);
344
345            if is_allowed_by_hardware && has_required_properties {
346                return Ok(index as u32);
347            }
348        }
349
350        Err(Error::fatal(ErrorKind::NoCompatibleDevice))
351    }
352
353    pub fn pad_uniform_buffer_size(&self, original_size: usize) -> usize {
354        let min_alignment = self.properties.limits.min_uniform_buffer_offset_alignment as usize;
355        let mut aligned_size = original_size;
356
357        if min_alignment > 0 {
358            aligned_size = (aligned_size + min_alignment - 1) & !(min_alignment - 1);
359        }
360
361        aligned_size
362    }
363
364    pub fn destroy(&mut self) {
365        unsafe {
366            self.device.destroy_device(None);
367        }
368    }
369}