// ash_bootstrap/instance.rs

1use crate::system_info::{DEBUG_UTILS_EXT_NAME, SystemInfo, VALIDATION_LAYER_NAME};
2use ash::ext::debug_utils;
3use ash::vk::{AllocationCallbacks, DebugUtilsMessengerEXT, api_version_minor};
4use ash::{khr, vk};
5use raw_window_handle::{DisplayHandle, RawDisplayHandle, RawWindowHandle, WindowHandle};
6use std::borrow::Cow;
7use std::ffi;
8use std::ffi::{CStr, CString, c_char, c_void};
9use std::ops::Not;
10use std::sync::Arc;
11
12unsafe extern "system" fn vulkan_debug_callback(
13    message_severity: vk::DebugUtilsMessageSeverityFlagsEXT,
14    message_type: vk::DebugUtilsMessageTypeFlagsEXT,
15    p_callback_data: *const vk::DebugUtilsMessengerCallbackDataEXT<'_>,
16    _user_data: *mut std::os::raw::c_void,
17) -> vk::Bool32 {
18    unsafe {
19        let callback_data = *p_callback_data;
20        let message_id_number = callback_data.message_id_number;
21
22        let message_id_name = if callback_data.p_message_id_name.is_null() {
23            Cow::from("")
24        } else {
25            ffi::CStr::from_ptr(callback_data.p_message_id_name).to_string_lossy()
26        };
27
28        let message = if callback_data.p_message.is_null() {
29            Cow::from("")
30        } else {
31            ffi::CStr::from_ptr(callback_data.p_message).to_string_lossy()
32        };
33
34        println!(
35            "{message_severity:?}:\n{message_type:?} [{message_id_name} ({message_id_number})] : {message}\n",
36        );
37
38        vk::FALSE
39    }
40}
41
/// Opaque user pointer handed to the debug messenger callback as its `pUserData` argument.
///
/// The builder only forwards this pointer (via [`DebugUserData::into_inner`]) into the
/// messenger create info; it never reads through it. The default value is a null pointer.
#[derive(Debug)]
pub struct DebugUserData(*mut c_void);

impl Default for DebugUserData {
    /// Returns a wrapper around a null pointer.
    fn default() -> Self {
        Self(std::ptr::null_mut())
    }
}

impl DebugUserData {
    /// Wraps a raw user-data pointer.
    ///
    /// # Safety
    ///
    /// `data` must be null, or point to memory that stays valid for as long as any debug
    /// messenger created with it may invoke its callback. The callback receives the pointer
    /// verbatim and is free to dereference it.
    pub unsafe fn new(data: *mut c_void) -> Self {
        Self(data)
    }

    /// Consumes the wrapper and returns the raw pointer it holds.
    pub fn into_inner(self) -> *mut c_void {
        self.0
    }
}
63
64#[derive(Debug)]
65pub struct InstanceBuilder<'a> {
66    // VkApplicationInfo
67    app_name: String,
68    engine_name: String,
69    application_version: u32,
70    engine_version: u32,
71    minimum_instance_version: u32,
72    required_instance_version: u32,
73
74    // VkInstanceCreateInfo
75    layers: Vec<String>,
76    extensions: Vec<Cow<'a, str>>,
77    flags: vk::InstanceCreateFlags,
78
79    // debug callback
80    debug_callback: vk::PFN_vkDebugUtilsMessengerCallbackEXT,
81    debug_message_severity: vk::DebugUtilsMessageSeverityFlagsEXT,
82    debug_message_type: vk::DebugUtilsMessageTypeFlagsEXT,
83    debug_user_data: DebugUserData,
84
85    // validation checks
86    disabled_validation_checks: Vec<vk::ValidationCheckEXT>,
87    enabled_validation_features: Vec<vk::ValidationFeatureEnableEXT>,
88    disabled_validation_features: Vec<vk::ValidationFeatureDisableEXT>,
89
90    allocation_callbacks: Option<vk::AllocationCallbacks<'static>>,
91
92    request_validation_layers: bool,
93    enable_validation_layers: bool,
94    // TODO: make typesafe
95    use_debug_messenger: bool,
96    headless_context: bool,
97
98    window_handle: Option<RawWindowHandle>,
99    display_handle: Option<RawDisplayHandle>,
100}
101
102impl<'a> InstanceBuilder<'a> {
103    pub fn new(window_display_handle: Option<(WindowHandle, DisplayHandle)>) -> Self {
104        let (window_handle, display_handle) = window_display_handle.unzip();
105        Self {
106            app_name: "".to_string(),
107            engine_name: "".to_string(),
108            application_version: 0,
109            engine_version: 0,
110            minimum_instance_version: 0,
111            required_instance_version: vk::API_VERSION_1_0,
112            layers: vec![],
113            extensions: vec![],
114            flags: Default::default(),
115            debug_callback: None,
116            debug_message_severity: vk::DebugUtilsMessageSeverityFlagsEXT::WARNING
117                | vk::DebugUtilsMessageSeverityFlagsEXT::ERROR,
118            debug_message_type: vk::DebugUtilsMessageTypeFlagsEXT::GENERAL
119                | vk::DebugUtilsMessageTypeFlagsEXT::VALIDATION
120                | vk::DebugUtilsMessageTypeFlagsEXT::PERFORMANCE,
121            debug_user_data: Default::default(),
122            disabled_validation_checks: vec![],
123            enabled_validation_features: vec![],
124            disabled_validation_features: vec![],
125            allocation_callbacks: None,
126            request_validation_layers: false,
127            enable_validation_layers: false,
128            use_debug_messenger: false,
129            headless_context: false,
130            display_handle: display_handle.map(|h| h.as_raw()),
131            window_handle: window_handle.map(|h| h.as_raw()),
132        }
133    }
134
135    pub fn app_name(mut self, app_name: impl Into<String>) -> Self {
136        self.app_name = app_name.into();
137        self
138    }
139
140    pub fn engine_name(mut self, engine_name: impl Into<String>) -> Self {
141        self.engine_name = engine_name.into();
142        self
143    }
144
145    pub fn app_version(mut self, version: u32) -> Self {
146        self.application_version = version;
147        self
148    }
149
150    pub fn engine_version(mut self, version: u32) -> Self {
151        self.engine_version = version;
152        self
153    }
154
155    pub fn require_api_version(mut self, version: u32) -> Self {
156        self.required_instance_version = version;
157        self
158    }
159
160    pub fn minimum_instance_version(mut self, version: u32) -> Self {
161        self.minimum_instance_version = version;
162        self
163    }
164
165    pub fn enable_layer(mut self, layer: impl Into<String>) -> Self {
166        self.layers.push(layer.into());
167        self
168    }
169
170    pub fn enable_extension(mut self, extension: impl Into<Cow<'a, str>>) -> Self {
171        self.extensions.push(extension.into());
172        self
173    }
174
175    pub fn enable_validation_layers(mut self, enable: bool) -> Self {
176        self.enable_validation_layers = enable;
177        self
178    }
179
180    pub fn request_validation_layers(mut self, request: bool) -> Self {
181        self.request_validation_layers = request;
182        self
183    }
184
185    pub fn use_default_debug_messenger(mut self) -> Self {
186        self.use_debug_messenger = true;
187        self.debug_callback = Some(vulkan_debug_callback);
188        self
189    }
190
191    #[cfg(feature = "enable_tracing")]
192    pub fn use_default_tracing_messenger(mut self) -> Self {
193        self.use_debug_messenger = true;
194        self.debug_callback = Some(crate::tracing::vulkan_tracing_callback);
195        self
196    }
197
198    pub fn set_debug_messenger(
199        mut self,
200        callback: vk::PFN_vkDebugUtilsMessengerCallbackEXT,
201    ) -> Self {
202        self.use_debug_messenger = true;
203        self.debug_callback = callback;
204        self
205    }
206
207    pub fn debug_user_data(mut self, debug_user_data: DebugUserData) -> Self {
208        self.debug_user_data = debug_user_data;
209        self
210    }
211
212    pub fn headless(mut self, headless: bool) -> Self {
213        self.headless_context = headless;
214        self
215    }
216
217    pub fn debug_messenger_severity(
218        mut self,
219        severity: vk::DebugUtilsMessageSeverityFlagsEXT,
220    ) -> Self {
221        self.debug_message_severity = severity;
222        self
223    }
224
225    pub fn add_debug_messenger_severity(
226        mut self,
227        severity: vk::DebugUtilsMessageSeverityFlagsEXT,
228    ) -> Self {
229        self.debug_message_severity |= severity;
230        self
231    }
232
233    pub fn debug_messenger_type(mut self, message_type: vk::DebugUtilsMessageTypeFlagsEXT) -> Self {
234        self.debug_message_type = message_type;
235        self
236    }
237
238    pub fn add_debug_messenger_type(
239        mut self,
240        message_type: vk::DebugUtilsMessageTypeFlagsEXT,
241    ) -> Self {
242        self.debug_message_type |= message_type;
243        self
244    }
245
246    #[cfg_attr(feature = "enable_tracing", tracing::instrument(skip(self)))]
247    pub fn build(self) -> crate::Result<Arc<Instance>> {
248        let system_info = SystemInfo::get_system_info()?;
249
250        let instance_version = {
251            if self.minimum_instance_version > vk::API_VERSION_1_0
252                || self.required_instance_version > vk::API_VERSION_1_0
253            {
254                let version = unsafe { system_info.entry.try_enumerate_instance_version() }?;
255
256                let version = version.unwrap_or(vk::API_VERSION_1_0);
257
258                if version < self.minimum_instance_version
259                    || (self.minimum_instance_version == 0
260                        && version < self.required_instance_version)
261                {
262                    return match api_version_minor(
263                        self.required_instance_version
264                            .max(self.minimum_instance_version),
265                    ) {
266                        3 => Err(crate::InstanceError::VulkanVersion13Unavailable.into()),
267                        2 => Err(crate::InstanceError::VulkanVersion12Unavailable.into()),
268                        1 => Err(crate::InstanceError::VulkanVersion11Unavailable.into()),
269                        minor => Err(crate::InstanceError::VulkanVersionUnavailable(format!(
270                            "1.{minor}"
271                        ))
272                        .into()),
273                    };
274                } else {
275                    version
276                }
277            } else {
278                vk::API_VERSION_1_0
279            }
280        };
281
282        #[cfg(feature = "enable_tracing")]
283        {
284            tracing::info!(
285                "Instance version: {}.{}.{}",
286                vk::api_version_major(instance_version),
287                vk::api_version_minor(instance_version),
288                vk::api_version_patch(instance_version)
289            );
290        }
291
292        let api_version = if instance_version < vk::API_VERSION_1_1
293            || self.required_instance_version < self.minimum_instance_version
294        {
295            instance_version
296        } else {
297            self.required_instance_version
298                .max(self.minimum_instance_version)
299        };
300        #[cfg(feature = "enable_tracing")]
301        {
302            use crate::version::Version;
303            let version = Version::new(api_version);
304            tracing::info!("api_version: {}", version);
305        }
306
307        let app_name = CString::new(self.app_name).map_err(anyhow::Error::msg)?;
308        let engine_name = CString::new(self.engine_name).map_err(anyhow::Error::msg)?;
309
310        let app_info = vk::ApplicationInfo::default()
311            .application_name(&app_name)
312            .application_version(self.application_version)
313            .engine_name(&engine_name)
314            .engine_version(self.engine_version)
315            .api_version(api_version);
316
317        #[cfg(feature = "enable_tracing")]
318        {
319            tracing::info!("Creating vkInstance with application info...");
320            tracing::debug!(
321                r#"
322Application info: {{
323    name: {:?},
324    version: {}.{}.{},
325    engine_name: {:?},
326    engine_version: {}.{}.{},
327    api_version: {}.{}.{},
328}}
329            "#,
330                app_name,
331                vk::api_version_major(self.application_version),
332                vk::api_version_minor(self.application_version),
333                vk::api_version_patch(self.application_version),
334                engine_name,
335                vk::api_version_major(self.engine_version),
336                vk::api_version_minor(self.engine_version),
337                vk::api_version_patch(self.engine_version),
338                vk::api_version_major(api_version),
339                vk::api_version_minor(api_version),
340                vk::api_version_patch(api_version),
341            )
342        }
343
344        let mut enabled_extensions: Vec<*const c_char> = vec![];
345        let mut enabled_layers: Vec<*const c_char> = vec![];
346
347        let extensions = self
348            .extensions
349            .into_iter()
350            .map(|s| CString::new(s.to_string()).expect("Could not create CString"))
351            .collect::<Vec<_>>();
352
353        let extensions_ptrs = extensions.iter().map(|e| e.as_ptr()).collect::<Vec<_>>();
354
355        enabled_extensions.extend_from_slice(&extensions_ptrs);
356
357        if self.debug_callback.is_some()
358            && self.use_debug_messenger
359            && system_info.debug_utils_available
360        {
361            enabled_extensions.push(DEBUG_UTILS_EXT_NAME.as_ptr());
362        }
363
364        let properties2_ext_enabled = api_version < vk::API_VERSION_1_1
365            && system_info.is_extension_available(vk::KHR_GET_PHYSICAL_DEVICE_PROPERTIES2_NAME)?;
366
367        if properties2_ext_enabled {
368            enabled_extensions.push(vk::KHR_GET_PHYSICAL_DEVICE_PROPERTIES2_NAME.as_ptr());
369        }
370
371        #[cfg(feature = "portability")]
372        let portability_enumeration_support =
373            system_info.is_extension_available(vk::KHR_PORTABILITY_ENUMERATION_NAME)?;
374        #[cfg(feature = "portability")]
375        if portability_enumeration_support {
376            enabled_extensions.push(vk::KHR_PORTABILITY_ENUMERATION_NAME.as_ptr());
377        }
378
379        if !self.headless_context {
380            if let Some(display_handle) = self.display_handle {
381                let surface_extensions_raw =
382                    ash_window::enumerate_required_extensions(display_handle)?;
383                let surface_extensions = surface_extensions_raw
384                    .iter()
385                    .map(|p| unsafe { CStr::from_ptr(*p) })
386                    .collect::<Vec<_>>();
387                let windowing_extensions = surface_extensions
388                    .iter()
389                    .map(|s| s.to_str().unwrap().to_string())
390                    .collect::<Vec<_>>();
391                if !system_info.are_extensions_available(surface_extensions)? {
392                    return Err(crate::InstanceError::WindowingExtensionsNotPresent(
393                        windowing_extensions,
394                    )
395                    .into());
396                };
397
398                enabled_extensions.extend_from_slice(surface_extensions_raw);
399            }
400        }
401
402        let cstr_enabled_extensions = enabled_extensions
403            .iter()
404            .map(|p| unsafe { CStr::from_ptr(*p) })
405            .collect::<Vec<_>>();
406
407        #[cfg(feature = "enable_tracing")]
408        tracing::trace!(?cstr_enabled_extensions);
409
410        let all_extensions_supported =
411            system_info.are_extensions_available(cstr_enabled_extensions)?;
412        if !all_extensions_supported {
413            let string_enabled_extensions = enabled_extensions
414                .iter()
415                .map(|p| unsafe { CStr::from_ptr(*p) }.to_str().unwrap().to_string())
416                .collect::<Vec<_>>();
417
418            return Err(crate::InstanceError::RequestedExtensionsNotPresent(
419                string_enabled_extensions,
420            )
421            .into());
422        };
423
424        let layers = self
425            .layers
426            .into_iter()
427            .map(|s| CString::new(s).expect("Could not create CString"))
428            .collect::<Vec<_>>();
429
430        let layers_ptrs = layers.iter().map(|e| e.as_ptr()).collect::<Vec<_>>();
431
432        enabled_layers.extend_from_slice(&layers_ptrs);
433
434        if self.enable_validation_layers
435            || (self.request_validation_layers && system_info.validation_layers_available)
436        {
437            enabled_layers.push(VALIDATION_LAYER_NAME.as_ptr())
438        };
439
440        let all_layers_supported =
441            system_info.are_layers_available(layers.iter().map(|s| s.as_c_str()))?;
442
443        if !all_layers_supported {
444            let enabled_layers_str = enabled_layers
445                .iter()
446                .map(|p| unsafe { CStr::from_ptr(*p) }.to_str().unwrap().to_string())
447                .collect::<Vec<_>>();
448            return Err(crate::InstanceError::RequestedLayersNotPresent(enabled_layers_str).into());
449        };
450
451        let mut messenger_create_info = vk::DebugUtilsMessengerCreateInfoEXT::default();
452        if self.use_debug_messenger {
453            messenger_create_info = messenger_create_info
454                .message_severity(self.debug_message_severity)
455                .message_type(self.debug_message_type)
456                .pfn_user_callback(self.debug_callback)
457                .user_data(self.debug_user_data.into_inner());
458
459            #[cfg(feature = "enable_tracing")]
460            tracing::trace!(?self.debug_callback, "Using debug messenger");
461        };
462
463        let instance_create_flags = if cfg!(feature = "portability") {
464            self.flags | vk::InstanceCreateFlags::ENUMERATE_PORTABILITY_KHR
465        } else {
466            self.flags
467        };
468
469        let mut instance_create_info = vk::InstanceCreateInfo::default()
470            .flags(instance_create_flags)
471            .application_info(&app_info)
472            .enabled_extension_names(&enabled_extensions)
473            .enabled_layer_names(&enabled_layers);
474
475        let mut features = vk::ValidationFeaturesEXT::default();
476
477        if !self.enabled_validation_features.is_empty()
478            || !self.disabled_validation_features.is_empty()
479        {
480            features = features
481                .enabled_validation_features(&self.enabled_validation_features)
482                .disabled_validation_features(&self.disabled_validation_features);
483
484            instance_create_info = instance_create_info.push_next(&mut features);
485        };
486
487        let mut checks = vk::ValidationFlagsEXT::default();
488
489        if !self.disabled_validation_checks.is_empty() {
490            checks = checks.disabled_validation_checks(&self.disabled_validation_checks);
491
492            instance_create_info = instance_create_info.push_next(&mut checks);
493        };
494
495        let instance = unsafe {
496            system_info
497                .entry
498                .create_instance(&instance_create_info, self.allocation_callbacks.as_ref())
499        }
500        .map_err(|_| crate::InstanceError::FailedCreateInstance)?;
501
502        #[cfg(feature = "enable_tracing")]
503        tracing::info!("Created vkInstance");
504
505        let mut debug_loader = None;
506        let mut debug_messenger = None;
507
508        if self.use_debug_messenger {
509            let loader = debug_utils::Instance::new(&system_info.entry, &instance);
510            let messenger = unsafe {
511                loader.create_debug_utils_messenger(
512                    &messenger_create_info,
513                    self.allocation_callbacks.as_ref(),
514                )
515            }?;
516
517            debug_loader.replace(loader);
518            debug_messenger.replace(messenger);
519        };
520
521        let surface_instance = self
522            .headless_context
523            .not()
524            .then(|| khr::surface::Instance::new(&system_info.entry, &instance));
525        let mut surface = None;
526        if let Some((window_handle, display_handle)) = self.window_handle.zip(self.display_handle) {
527            if surface_instance.is_some() {
528                surface = Some(unsafe {
529                    ash_window::create_surface(
530                        &system_info.entry,
531                        &instance,
532                        display_handle,
533                        window_handle,
534                        None,
535                    )?
536                });
537                #[cfg(feature = "enable_tracing")]
538                tracing::info!("Created vkSurfaceKhr")
539            }
540        };
541
542        Ok(Arc::new(Instance {
543            instance,
544            surface_instance,
545            surface,
546            allocation_callbacks: self.allocation_callbacks,
547            instance_version,
548            api_version,
549            properties2_ext_enabled,
550            debug_loader,
551            debug_messenger,
552            _system_info: system_info,
553        }))
554    }
555}
556
557pub struct Instance {
558    pub(crate) instance: ash::Instance,
559    pub(crate) allocation_callbacks: Option<AllocationCallbacks<'static>>,
560    pub(crate) surface_instance: Option<khr::surface::Instance>,
561    pub(crate) surface: Option<vk::SurfaceKHR>,
562    pub(crate) instance_version: u32,
563    pub api_version: u32,
564    pub(crate) properties2_ext_enabled: bool,
565    pub(crate) debug_loader: Option<debug_utils::Instance>,
566    pub(crate) debug_messenger: Option<DebugUtilsMessengerEXT>,
567    _system_info: SystemInfo,
568}
569
570impl Instance {
571    pub fn destroy(&self) {
572        unsafe {
573            if let Some((debug_messenger, debug_loader)) = self
574                .debug_messenger
575                .as_ref()
576                .zip(self.debug_loader.as_ref())
577            {
578                debug_loader.destroy_debug_utils_messenger(
579                    *debug_messenger,
580                    self.allocation_callbacks.as_ref(),
581                );
582            }
583            if let Some((surface_instance, surface)) =
584                self.surface_instance.as_ref().zip(self.surface)
585            {
586                surface_instance.destroy_surface(surface, self.allocation_callbacks.as_ref());
587            }
588            self.instance
589                .destroy_instance(self.allocation_callbacks.as_ref());
590        }
591    }
592}
593
594impl AsRef<ash::Instance> for Instance {
595    fn as_ref(&self) -> &ash::Instance {
596        &self.instance
597    }
598}
599
#[cfg(test)]
mod tests {
    /// Smoke test: passes whenever this module, and therefore the file, builds under
    /// `cargo test`.
    #[test]
    fn compiles() {
        // Intentionally empty; a successful build is the whole check.
    }
}