blade_graphics/vulkan/mod.rs

use ash::{
    khr,
    vk::{self},
};
use std::{mem, num::NonZeroU32, path::PathBuf, ptr, sync::Mutex};

mod command;
mod descriptor;
mod init;
mod pipeline;
mod resource;
mod surface;

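// One timestamp query slot per pass plus one extra, presumably so that
// per-pass durations can be derived by differencing adjacent timestamps.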
const QUERY_POOL_SIZE: usize = crate::limits::PASS_COUNT + 1;

#[derive(Debug)]
pub enum PlatformError {
    Loading(ash::LoadingError),
    Init(vk::Result),
}

struct Instance {
    core: ash::Instance,
    _debug_utils: ash::ext::debug_utils::Instance,
    get_physical_device_properties2: khr::get_physical_device_properties2::Instance,
    get_surface_capabilities2: Option<khr::get_surface_capabilities2::Instance>,
    surface: Option<khr::surface::Instance>,
}

#[derive(Clone)]
struct RayTracingDevice {
    acceleration_structure: khr::acceleration_structure::Device,
    scratch_buffer_alignment: u64,
}

#[derive(Clone, Default)]
struct CommandScopeDevice {}
#[derive(Clone, Default)]
struct TimingDevice {
    period: f32,
}

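// Driver/vendor-specific workarounds, presumably chosen during device
// initialization: extra access flags folded into barriers and extra flags
// applied when creating descriptor pools.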
#[derive(Clone)]
struct Workarounds {
    extra_sync_src_access: vk::AccessFlags,
    extra_sync_dst_access: vk::AccessFlags,
    extra_descriptor_pool_create_flags: vk::DescriptorPoolCreateFlags,
}

#[derive(Clone)]
struct Device {
    core: ash::Device,
    device_information: crate::DeviceInformation,
    swapchain: Option<khr::swapchain::Device>,
    debug_utils: ash::ext::debug_utils::Device,
    timeline_semaphore: khr::timeline_semaphore::Device,
    dynamic_rendering: khr::dynamic_rendering::Device,
    ray_tracing: Option<RayTracingDevice>,
    buffer_marker: Option<ash::amd::buffer_marker::Device>,
    shader_info: Option<ash::amd::shader_info::Device>,
    full_screen_exclusive: Option<ash::ext::full_screen_exclusive::Device>,
    #[cfg(target_os = "windows")]
    external_memory: Option<ash::khr::external_memory_win32::Device>,
    #[cfg(not(target_os = "windows"))]
    external_memory: Option<ash::khr::external_memory_fd::Device>,
    command_scope: Option<CommandScopeDevice>,
    timing: Option<TimingDevice>,
    workarounds: Workarounds,
}

struct MemoryManager {
    allocator: gpu_alloc::GpuAllocator<vk::DeviceMemory>,
    slab: slab::Slab<gpu_alloc::MemoryBlock<vk::DeviceMemory>>,
    valid_ash_memory_types: u32,
}

struct Queue {
    raw: vk::Queue,
    timeline_semaphore: vk::Semaphore,
    last_progress: u64,
}

#[derive(Clone, Copy, Debug, Default, PartialEq)]
struct InternalFrame {
    acquire_semaphore: vk::Semaphore,
    present_semaphore: vk::Semaphore,
    image: vk::Image,
    view: vk::ImageView,
}

#[derive(Clone, Copy, Debug, PartialEq)]
struct Swapchain {
    raw: vk::SwapchainKHR,
    format: crate::TextureFormat,
    alpha: crate::AlphaMode,
    target_size: [u16; 2],
}

pub struct Surface {
    device: khr::swapchain::Device,
    raw: vk::SurfaceKHR,
    frames: Vec<InternalFrame>,
    next_semaphore: vk::Semaphore,
    swapchain: Swapchain,
    full_screen_exclusive: bool,
}

#[derive(Clone, Copy, Debug, PartialEq)]
struct Presentation {
    swapchain: vk::SwapchainKHR,
    image_index: u32,
    acquire_semaphore: vk::Semaphore,
    present_semaphore: vk::Semaphore,
}

#[derive(Clone, Copy, Debug, PartialEq)]
pub struct Frame {
    swapchain: Swapchain,
    image_index: Option<u32>,
    internal: InternalFrame,
}

impl Frame {
    pub fn texture(&self) -> Texture {
        Texture {
            raw: self.internal.image,
            memory_handle: !0,
            target_size: self.swapchain.target_size,
            format: self.swapchain.format,
            external: None,
        }
    }

    pub fn texture_view(&self) -> TextureView {
        TextureView {
            raw: self.internal.view,
            target_size: self.swapchain.target_size,
            aspects: crate::TexelAspects::COLOR,
        }
    }
}

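// Converts a millisecond timeout into the nanoseconds Vulkan expects; `!0`
// (u32::MAX) is treated as "wait forever" and passed through as u64::MAX.
// For example, map_timeout(16) == 16_000_000 ns.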
fn map_timeout(millis: u32) -> u64 {
    if millis == !0 {
        !0
    } else {
        millis as u64 * 1_000_000
    }
}

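// The top-level GPU context: owns the logical device, a single queue guarded
// by a mutex, the gpu-alloc based memory manager, and the settings used when
// translating shaders through naga's SPIR-V backend.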
pub struct Context {
    memory: Mutex<MemoryManager>,
    device: Device,
    queue_family_index: u32,
    queue: Mutex<Queue>,
    physical_device: vk::PhysicalDevice,
    naga_flags: naga::back::spv::WriterFlags,
    shader_debug_path: Option<PathBuf>,
    min_buffer_alignment: u64,
    sample_count_flags: vk::SampleCountFlags,
    instance: Instance,
    entry: ash::Entry,
}

#[derive(Clone, Copy, Debug, Hash, PartialEq)]
pub struct Buffer {
    raw: vk::Buffer,
    memory_handle: usize,
    mapped_data: *mut u8,
    external: Option<crate::ExternalMemorySource>,
}

impl Default for Buffer {
    fn default() -> Self {
        Self {
            raw: vk::Buffer::null(),
            memory_handle: !0,
            mapped_data: ptr::null_mut(),
            external: None,
        }
    }
}

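// `data()` exposes the persistently mapped pointer; it stays null for buffers
// whose memory was never mapped (e.g. device-local allocations).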
impl Buffer {
    pub fn data(&self) -> *mut u8 {
        self.mapped_data
    }
}

unsafe impl Send for Buffer {}
unsafe impl Sync for Buffer {}

#[derive(Clone, Copy, Debug, Hash, PartialEq)]
pub struct Texture {
    raw: vk::Image,
    memory_handle: usize,
    target_size: [u16; 2],
    format: crate::TextureFormat,
    external: Option<crate::ExternalMemorySource>,
}

impl Default for Texture {
    fn default() -> Self {
        Self {
            raw: vk::Image::default(),
            memory_handle: !0,
            target_size: [0; 2],
            format: crate::TextureFormat::Rgba8Unorm,
            external: None,
        }
    }
}

#[derive(Clone, Copy, Debug, Default, Hash, PartialEq)]
pub struct TextureView {
    raw: vk::ImageView,
    target_size: [u16; 2],
    aspects: crate::TexelAspects,
}

#[derive(Clone, Copy, Debug, Hash, PartialEq)]
pub struct Sampler {
    raw: vk::Sampler,
}

#[derive(Clone, Copy, Debug, Default, Hash, PartialEq)]
pub struct AccelerationStructure {
    raw: vk::AccelerationStructureKHR,
    buffer: vk::Buffer,
    memory_handle: usize,
}

#[derive(Debug, Default)]
struct DescriptorSetLayout {
    raw: vk::DescriptorSetLayout,
    update_template: vk::DescriptorUpdateTemplate,
    template_size: u32,
    template_offsets: Box<[u32]>,
}

impl DescriptorSetLayout {
    fn is_empty(&self) -> bool {
        self.template_size == 0
    }
}

#[derive(Debug)]
struct PipelineLayout {
    raw: vk::PipelineLayout,
    descriptor_set_layouts: Vec<DescriptorSetLayout>,
}

pub struct PipelineContext<'a> {
    update_data: &'a mut [u8],
    template_offsets: &'a [u32],
}

#[derive(Debug)]
pub struct ComputePipeline {
    raw: vk::Pipeline,
    layout: PipelineLayout,
    wg_size: [u32; 3],
}

impl ComputePipeline {
    pub fn get_workgroup_size(&self) -> [u32; 3] {
        self.wg_size
    }
}

#[derive(Debug)]
pub struct RenderPipeline {
    raw: vk::Pipeline,
    layout: PipelineLayout,
}

#[derive(Debug)]
struct CommandBuffer {
    raw: vk::CommandBuffer,
    descriptor_pool: descriptor::DescriptorPool,
    query_pool: vk::QueryPool,
    timed_pass_names: Vec<String>,
}

struct CrashHandler {
    name: String,
    marker_buf: Buffer,
    raw_string: Box<[u8]>,
    next_offset: usize,
}

pub struct CommandEncoder {
    pool: vk::CommandPool,
    buffers: Box<[CommandBuffer]>,
    device: Device,
    update_data: Vec<u8>,
    present: Option<Presentation>,
    crash_handler: Option<CrashHandler>,
    temp_label: Vec<u8>,
    timings: crate::Timings,
}
pub struct TransferCommandEncoder<'a> {
    raw: vk::CommandBuffer,
    device: &'a Device,
}
pub struct AccelerationStructureCommandEncoder<'a> {
    raw: vk::CommandBuffer,
    device: &'a Device,
}
pub struct ComputeCommandEncoder<'a> {
    cmd_buf: &'a mut CommandBuffer,
    device: &'a Device,
    update_data: &'a mut Vec<u8>,
}
//Note: we aren't merging this with `ComputeCommandEncoder`
// because the destructors are different, and they can't be specialized
// https://github.com/rust-lang/rust/issues/46893
pub struct RenderCommandEncoder<'a> {
    cmd_buf: &'a mut CommandBuffer,
    device: &'a Device,
    update_data: &'a mut Vec<u8>,
}

pub struct PipelineEncoder<'a, 'p> {
    cmd_buf: &'a mut CommandBuffer,
    layout: &'p PipelineLayout,
    bind_point: vk::PipelineBindPoint,
    device: &'a Device,
    update_data: &'a mut Vec<u8>,
}

#[derive(Clone, Debug)]
pub struct SyncPoint {
    progress: u64,
}

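// A rough usage sketch of the command-device interface implemented below.
// The recording methods such as `start()` live on the crate-level encoder
// traits rather than in this file and are shown here only for orientation:
//
//     let mut encoder = context.create_command_encoder(desc);
//     encoder.start();
//     // ... record transfer/compute/render passes ...
//     let sync_point = context.submit(&mut encoder);
//     context.wait_for(&sync_point, !0); // !0 waits indefinitely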
#[hidden_trait::expose]
impl crate::traits::CommandDevice for Context {
    type CommandEncoder = CommandEncoder;
    type SyncPoint = SyncPoint;

    fn create_command_encoder(&self, desc: super::CommandEncoderDesc) -> CommandEncoder {
        //TODO: these numbers are arbitrary; they need to be replaced by
        // an abstraction from gpu-alloc, if possible.
        const ROUGH_SET_COUNT: u32 = 60000;
        let mut descriptor_sizes = vec![
            vk::DescriptorPoolSize {
                ty: vk::DescriptorType::INLINE_UNIFORM_BLOCK_EXT,
                descriptor_count: ROUGH_SET_COUNT * crate::limits::PLAIN_DATA_SIZE,
            },
            vk::DescriptorPoolSize {
                ty: vk::DescriptorType::STORAGE_BUFFER,
                descriptor_count: ROUGH_SET_COUNT,
            },
            vk::DescriptorPoolSize {
                ty: vk::DescriptorType::SAMPLED_IMAGE,
                descriptor_count: 2 * ROUGH_SET_COUNT,
            },
            vk::DescriptorPoolSize {
                ty: vk::DescriptorType::SAMPLER,
                descriptor_count: ROUGH_SET_COUNT,
            },
            vk::DescriptorPoolSize {
                ty: vk::DescriptorType::STORAGE_IMAGE,
                descriptor_count: ROUGH_SET_COUNT,
            },
        ];
        if self.device.ray_tracing.is_some() {
            descriptor_sizes.push(vk::DescriptorPoolSize {
                ty: vk::DescriptorType::ACCELERATION_STRUCTURE_KHR,
                descriptor_count: ROUGH_SET_COUNT,
            });
        }

        let pool_info = vk::CommandPoolCreateInfo {
            flags: vk::CommandPoolCreateFlags::RESET_COMMAND_BUFFER,
            ..Default::default()
        };
        let pool = unsafe {
            self.device
                .core
                .create_command_pool(&pool_info, None)
                .unwrap()
        };
        let cmd_buf_info = vk::CommandBufferAllocateInfo {
            command_pool: pool,
            command_buffer_count: desc.buffer_count,
            ..Default::default()
        };
        let cmd_buffers = unsafe {
            self.device
                .core
                .allocate_command_buffers(&cmd_buf_info)
                .unwrap()
        };

        let buffers = cmd_buffers
            .into_iter()
            .map(|raw| {
                if !desc.name.is_empty() {
                    self.set_object_name(raw, desc.name);
                };
                let descriptor_pool = self.device.create_descriptor_pool();
                let query_pool = if self.device.timing.is_some() {
                    let query_pool_info = vk::QueryPoolCreateInfo::default()
                        .query_type(vk::QueryType::TIMESTAMP)
                        .query_count(QUERY_POOL_SIZE as u32);
                    unsafe {
                        self.device
                            .core
                            .create_query_pool(&query_pool_info, None)
                            .unwrap()
                    }
                } else {
                    vk::QueryPool::null()
                };
                CommandBuffer {
                    raw,
                    descriptor_pool,
                    query_pool,
                    timed_pass_names: Vec::new(),
                }
            })
            .collect();

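        // When VK_AMD_buffer_marker is available, allocate a small host-visible
        // marker buffer; presumably the encoder writes progress markers into it
        // during recording so that the last executed pass can be identified
        // after a device loss (see `check_gpu_crash`).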
        let crash_handler = if self.device.buffer_marker.is_some() {
            Some(CrashHandler {
                name: desc.name.to_string(),
                marker_buf: self.create_buffer(crate::BufferDesc {
                    name: "_marker",
                    size: 4,
                    memory: crate::Memory::Shared,
                }),
                raw_string: vec![0; 0x1000].into_boxed_slice(),
                next_offset: 0,
            })
        } else {
            None
        };

        CommandEncoder {
            pool,
            buffers,
            device: self.device.clone(),
            update_data: Vec::new(),
            present: None,
            crash_handler,
            temp_label: Vec::new(),
            timings: Default::default(),
        }
    }

    fn destroy_command_encoder(&self, command_encoder: &mut CommandEncoder) {
        for cmd_buf in command_encoder.buffers.iter_mut() {
            let raw_cmd_buffers = [cmd_buf.raw];
            unsafe {
                self.device
                    .core
                    .free_command_buffers(command_encoder.pool, &raw_cmd_buffers);
            }
            self.device
                .destroy_descriptor_pool(&mut cmd_buf.descriptor_pool);
            if self.device.timing.is_some() {
                unsafe {
                    self.device
                        .core
                        .destroy_query_pool(cmd_buf.query_pool, None);
                }
            }
        }
        unsafe {
            self.device
                .core
                .destroy_command_pool(mem::take(&mut command_encoder.pool), None)
        };
        if let Some(crash_handler) = command_encoder.crash_handler.take() {
            self.destroy_buffer(crash_handler.marker_buf);
        };
    }

    fn submit(&self, encoder: &mut CommandEncoder) -> SyncPoint {
        let raw_cmd_buf = encoder.finish();
        let mut queue = self.queue.lock().unwrap();
        queue.last_progress += 1;
        let progress = queue.last_progress;
        let command_buffers = [raw_cmd_buf];
        let wait_values_all = [0];
        let mut wait_semaphores_all = [vk::Semaphore::null()];
        let wait_stages = [vk::PipelineStageFlags::ALL_COMMANDS];
        let mut signal_semaphores_all = [queue.timeline_semaphore, vk::Semaphore::null()];
        let signal_values_all = [progress, 0];
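        // The arrays above are sized for the worst case: when a presentation is
        // pending, the submit additionally waits on the swapchain acquire semaphore
        // and signals the binary present semaphore alongside the timeline semaphore.
        // The slice lengths chosen below select how many entries actually apply.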
        let (num_wait_semaphores, num_signal_semaphores) = match encoder.present {
            Some(ref presentation) => {
                wait_semaphores_all[0] = presentation.acquire_semaphore;
                signal_semaphores_all[1] = presentation.present_semaphore;
                (1, 2)
            }
            None => (0, 1),
        };
        let mut timeline_info = vk::TimelineSemaphoreSubmitInfo::default()
            .wait_semaphore_values(&wait_values_all[..num_wait_semaphores])
            .signal_semaphore_values(&signal_values_all[..num_signal_semaphores]);
        let vk_info = vk::SubmitInfo::default()
            .command_buffers(&command_buffers)
            .wait_semaphores(&wait_semaphores_all[..num_wait_semaphores])
            .wait_dst_stage_mask(&wait_stages[..num_wait_semaphores])
            .signal_semaphores(&signal_semaphores_all[..num_signal_semaphores])
            .push_next(&mut timeline_info);
        let ret = unsafe {
            self.device
                .core
                .queue_submit(queue.raw, &[vk_info], vk::Fence::null())
        };
        encoder.check_gpu_crash(ret);

        if let Some(presentation) = encoder.present.take() {
            let khr_swapchain = self.device.swapchain.as_ref().unwrap();
            let swapchains = [presentation.swapchain];
            let image_indices = [presentation.image_index];
            let wait_semaphores = [presentation.present_semaphore];
            let present_info = vk::PresentInfoKHR::default()
                .swapchains(&swapchains)
                .image_indices(&image_indices)
                .wait_semaphores(&wait_semaphores);
            let ret = unsafe { khr_swapchain.queue_present(queue.raw, &present_info) };
            let _ = encoder.check_gpu_crash(ret);
        }

        SyncPoint { progress }
    }

    fn wait_for(&self, sp: &SyncPoint, timeout_ms: u32) -> bool {
        //Note: technically we could get away without locking the queue,
        // but this isn't time-sensitive anyway, so it's fine.
        let timeline_semaphore = self.queue.lock().unwrap().timeline_semaphore;
        let semaphores = [timeline_semaphore];
        let semaphore_values = [sp.progress];
        let wait_info = vk::SemaphoreWaitInfoKHR::default()
            .semaphores(&semaphores)
            .values(&semaphore_values);
        let timeout_ns = map_timeout(timeout_ms);
        unsafe {
            self.device
                .timeline_semaphore
                .wait_semaphores(&wait_info, timeout_ns)
                .is_ok()
        }
    }
}

fn map_texture_format(format: crate::TextureFormat) -> vk::Format {
    use crate::TextureFormat as Tf;
    match format {
        Tf::R8Unorm => vk::Format::R8_UNORM,
        Tf::Rg8Unorm => vk::Format::R8G8_UNORM,
        Tf::Rg8Snorm => vk::Format::R8G8_SNORM,
        Tf::Rgba8Unorm => vk::Format::R8G8B8A8_UNORM,
        Tf::Rgba8UnormSrgb => vk::Format::R8G8B8A8_SRGB,
        Tf::Bgra8Unorm => vk::Format::B8G8R8A8_UNORM,
        Tf::Bgra8UnormSrgb => vk::Format::B8G8R8A8_SRGB,
        Tf::Rgba8Snorm => vk::Format::R8G8B8A8_SNORM,
        Tf::R16Float => vk::Format::R16_SFLOAT,
        Tf::Rg16Float => vk::Format::R16G16_SFLOAT,
        Tf::Rgba16Float => vk::Format::R16G16B16A16_SFLOAT,
        Tf::R32Float => vk::Format::R32_SFLOAT,
        Tf::Rg32Float => vk::Format::R32G32_SFLOAT,
        Tf::Rgba32Float => vk::Format::R32G32B32A32_SFLOAT,
        Tf::R32Uint => vk::Format::R32_UINT,
        Tf::Rg32Uint => vk::Format::R32G32_UINT,
        Tf::Rgba32Uint => vk::Format::R32G32B32A32_UINT,
        Tf::Depth32Float => vk::Format::D32_SFLOAT,
        Tf::Depth32FloatStencil8Uint => vk::Format::D32_SFLOAT_S8_UINT,
        Tf::Stencil8Uint => vk::Format::S8_UINT,
        Tf::Bc1Unorm => vk::Format::BC1_RGBA_UNORM_BLOCK,
        Tf::Bc1UnormSrgb => vk::Format::BC1_RGBA_SRGB_BLOCK,
        Tf::Bc2Unorm => vk::Format::BC2_UNORM_BLOCK,
        Tf::Bc2UnormSrgb => vk::Format::BC2_SRGB_BLOCK,
        Tf::Bc3Unorm => vk::Format::BC3_UNORM_BLOCK,
        Tf::Bc3UnormSrgb => vk::Format::BC3_SRGB_BLOCK,
        Tf::Bc4Unorm => vk::Format::BC4_UNORM_BLOCK,
        Tf::Bc4Snorm => vk::Format::BC4_SNORM_BLOCK,
        Tf::Bc5Unorm => vk::Format::BC5_UNORM_BLOCK,
        Tf::Bc5Snorm => vk::Format::BC5_SNORM_BLOCK,
        Tf::Bc6hUfloat => vk::Format::BC6H_UFLOAT_BLOCK,
        Tf::Bc6hFloat => vk::Format::BC6H_SFLOAT_BLOCK,
        Tf::Bc7Unorm => vk::Format::BC7_UNORM_BLOCK,
        Tf::Bc7UnormSrgb => vk::Format::BC7_SRGB_BLOCK,
        Tf::Rgb10a2Unorm => vk::Format::A2B10G10R10_UNORM_PACK32,
        Tf::Rg11b10Ufloat => vk::Format::B10G11R11_UFLOAT_PACK32,
        Tf::Rgb9e5Ufloat => vk::Format::E5B9G9R9_UFLOAT_PACK32,
    }
}

fn map_aspects(aspects: crate::TexelAspects) -> vk::ImageAspectFlags {
    let mut flags = vk::ImageAspectFlags::empty();
    if aspects.contains(crate::TexelAspects::COLOR) {
        flags |= vk::ImageAspectFlags::COLOR;
    }
    if aspects.contains(crate::TexelAspects::DEPTH) {
        flags |= vk::ImageAspectFlags::DEPTH;
    }
    if aspects.contains(crate::TexelAspects::STENCIL) {
        flags |= vk::ImageAspectFlags::STENCIL;
    }
    flags
}

fn map_extent_3d(extent: &crate::Extent) -> vk::Extent3D {
    vk::Extent3D {
        width: extent.width,
        height: extent.height,
        depth: extent.depth,
    }
}

fn map_subresource_range(
    subresources: &crate::TextureSubresources,
    aspects: crate::TexelAspects,
) -> vk::ImageSubresourceRange {
    vk::ImageSubresourceRange {
        aspect_mask: map_aspects(aspects),
        base_mip_level: subresources.base_mip_level,
        level_count: subresources
            .mip_level_count
            .map_or(vk::REMAINING_MIP_LEVELS, NonZeroU32::get),
        base_array_layer: subresources.base_array_layer,
        layer_count: subresources
            .array_layer_count
            .map_or(vk::REMAINING_ARRAY_LAYERS, NonZeroU32::get),
    }
}

fn map_comparison(fun: crate::CompareFunction) -> vk::CompareOp {
    use crate::CompareFunction as Cf;
    match fun {
        Cf::Never => vk::CompareOp::NEVER,
        Cf::Less => vk::CompareOp::LESS,
        Cf::LessEqual => vk::CompareOp::LESS_OR_EQUAL,
        Cf::Equal => vk::CompareOp::EQUAL,
        Cf::GreaterEqual => vk::CompareOp::GREATER_OR_EQUAL,
        Cf::Greater => vk::CompareOp::GREATER,
        Cf::NotEqual => vk::CompareOp::NOT_EQUAL,
        Cf::Always => vk::CompareOp::ALWAYS,
    }
}

fn map_index_type(index_type: crate::IndexType) -> vk::IndexType {
    match index_type {
        crate::IndexType::U16 => vk::IndexType::UINT16,
        crate::IndexType::U32 => vk::IndexType::UINT32,
    }
}

fn map_vertex_format(vertex_format: crate::VertexFormat) -> vk::Format {
    use crate::VertexFormat as Vf;
    match vertex_format {
        Vf::F32 => vk::Format::R32_SFLOAT,
        Vf::F32Vec2 => vk::Format::R32G32_SFLOAT,
        Vf::F32Vec3 => vk::Format::R32G32B32_SFLOAT,
        Vf::F32Vec4 => vk::Format::R32G32B32A32_SFLOAT,
        Vf::U32 => vk::Format::R32_UINT,
        Vf::U32Vec2 => vk::Format::R32G32_UINT,
        Vf::U32Vec3 => vk::Format::R32G32B32_UINT,
        Vf::U32Vec4 => vk::Format::R32G32B32A32_UINT,
        Vf::I32 => vk::Format::R32_SINT,
        Vf::I32Vec2 => vk::Format::R32G32_SINT,
        Vf::I32Vec3 => vk::Format::R32G32B32_SINT,
        Vf::I32Vec4 => vk::Format::R32G32B32A32_SINT,
    }
}

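// Owned arrays backing the raw pointers stored in `build_info` (its
// `p_geometries` points into `_geometries`), so they must outlive any use of
// the build info.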
struct BottomLevelAccelerationStructureInput<'a> {
    max_primitive_counts: Box<[u32]>,
    build_range_infos: Box<[vk::AccelerationStructureBuildRangeInfoKHR]>,
    _geometries: Box<[vk::AccelerationStructureGeometryKHR<'a>]>,
    build_info: vk::AccelerationStructureBuildGeometryInfoKHR<'a>,
}

impl Device {
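    // Resolves a buffer piece to a raw GPU virtual address (base buffer device
    // address plus the piece offset), as needed by the acceleration structure
    // build inputs below.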
    fn get_device_address(&self, piece: &crate::BufferPiece) -> u64 {
        let vk_info = vk::BufferDeviceAddressInfo {
            buffer: piece.buffer.raw,
            ..Default::default()
        };
        let base = unsafe { self.core.get_buffer_device_address(&vk_info) };
        base + piece.offset
    }

    fn map_acceleration_structure_meshes(
        &self,
        meshes: &[crate::AccelerationStructureMesh],
    ) -> BottomLevelAccelerationStructureInput {
        let mut total_primitive_count = 0;
        let mut max_primitive_counts = Vec::with_capacity(meshes.len());
        let mut build_range_infos = Vec::with_capacity(meshes.len());
        let mut geometries = Vec::with_capacity(meshes.len());
        for mesh in meshes {
            total_primitive_count += mesh.triangle_count;
            max_primitive_counts.push(mesh.triangle_count);
            build_range_infos.push(vk::AccelerationStructureBuildRangeInfoKHR {
                primitive_count: mesh.triangle_count,
                primitive_offset: 0,
                first_vertex: 0,
                transform_offset: 0,
            });

            let mut triangles = vk::AccelerationStructureGeometryTrianglesDataKHR {
                vertex_format: map_vertex_format(mesh.vertex_format),
                vertex_data: {
                    let device_address = self.get_device_address(&mesh.vertex_data);
                    assert!(
                        device_address & 0x3 == 0,
                        "Vertex data address {device_address} is not aligned"
                    );
                    vk::DeviceOrHostAddressConstKHR { device_address }
                },
                vertex_stride: mesh.vertex_stride as u64,
                max_vertex: mesh.vertex_count.saturating_sub(1),
                ..Default::default()
            };
            if let Some(index_type) = mesh.index_type {
                let device_address = self.get_device_address(&mesh.index_data);
                assert!(
                    device_address & 0x3 == 0,
                    "Index data address {device_address} is not aligned"
                );
                triangles.index_type = map_index_type(index_type);
                triangles.index_data = vk::DeviceOrHostAddressConstKHR { device_address };
            }
            if mesh.transform_data.buffer.raw != vk::Buffer::null() {
                let device_address = self.get_device_address(&mesh.transform_data);
                assert!(
                    device_address & 0xF == 0,
                    "Transform data address {device_address} is not aligned"
                );
                triangles.transform_data = vk::DeviceOrHostAddressConstKHR { device_address };
            }

            let geometry = vk::AccelerationStructureGeometryKHR {
                geometry_type: vk::GeometryTypeKHR::TRIANGLES,
                geometry: vk::AccelerationStructureGeometryDataKHR { triangles },
                flags: if mesh.is_opaque {
                    vk::GeometryFlagsKHR::OPAQUE
                } else {
                    vk::GeometryFlagsKHR::empty()
                },
                ..Default::default()
            };
            geometries.push(geometry);
        }
        let build_info = vk::AccelerationStructureBuildGeometryInfoKHR {
            ty: vk::AccelerationStructureTypeKHR::BOTTOM_LEVEL,
            flags: vk::BuildAccelerationStructureFlagsKHR::PREFER_FAST_TRACE,
            mode: vk::BuildAccelerationStructureModeKHR::BUILD,
            geometry_count: geometries.len() as u32,
            p_geometries: geometries.as_ptr(),
            ..Default::default()
        };

        log::debug!(
            "BLAS total {} primitives in {} geometries",
            total_primitive_count,
            geometries.len()
        );
        BottomLevelAccelerationStructureInput {
            max_primitive_counts: max_primitive_counts.into_boxed_slice(),
            build_range_infos: build_range_infos.into_boxed_slice(),
            _geometries: geometries.into_boxed_slice(),
            build_info,
        }
    }
}