use crate::core;
use crate::ex::RuntimeError;
use ash::vk;
use std::sync::Arc;
/// One binding slot of a descriptor set layout.
///
/// Every variant carries the same payload — the binding number inside the set and the
/// shader stages allowed to access it — and the variant itself selects the Vulkan
/// descriptor type.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum DescriptorBinding {
    /// A `UNIFORM_BUFFER` descriptor.
    UniformBuffer {
        binding: u32,
        stage: vk::ShaderStageFlags,
    },
    /// A `STORAGE_BUFFER` descriptor.
    StorageBuffer {
        binding: u32,
        stage: vk::ShaderStageFlags,
    },
    /// A `COMBINED_IMAGE_SAMPLER` descriptor.
    CombinedImageSampler {
        binding: u32,
        stage: vk::ShaderStageFlags,
    },
    /// A `STORAGE_IMAGE` descriptor.
    StorageImage {
        binding: u32,
        stage: vk::ShaderStageFlags,
    },
}

impl DescriptorBinding {
    /// Returns the `(binding, stage)` payload shared by every variant.
    #[inline]
    fn slot_and_stage(self) -> (u32, vk::ShaderStageFlags) {
        match self {
            Self::UniformBuffer { binding, stage }
            | Self::StorageBuffer { binding, stage }
            | Self::CombinedImageSampler { binding, stage }
            | Self::StorageImage { binding, stage } => (binding, stage),
        }
    }

    /// Binding number of this slot within its descriptor set.
    #[inline]
    pub fn binding(self) -> u32 {
        let (slot, _) = self.slot_and_stage();
        slot
    }

    /// Vulkan descriptor type corresponding to this variant.
    #[inline]
    pub fn descriptor_type(self) -> vk::DescriptorType {
        match self {
            Self::UniformBuffer { .. } => vk::DescriptorType::UNIFORM_BUFFER,
            Self::StorageBuffer { .. } => vk::DescriptorType::STORAGE_BUFFER,
            Self::CombinedImageSampler { .. } => vk::DescriptorType::COMBINED_IMAGE_SAMPLER,
            Self::StorageImage { .. } => vk::DescriptorType::STORAGE_IMAGE,
        }
    }

    /// Shader stages that may access this binding.
    #[inline]
    pub fn stage_flags(self) -> vk::ShaderStageFlags {
        let (_, stage) = self.slot_and_stage();
        stage
    }

    /// Converts this slot into a raw layout binding with a descriptor count of 1 and
    /// no immutable samplers (the `Default` value leaves `p_immutable_samplers` null).
    #[inline]
    pub fn to_layout_binding(self) -> vk::DescriptorSetLayoutBinding<'static> {
        let (slot, stage) = self.slot_and_stage();
        vk::DescriptorSetLayoutBinding::default()
            .binding(slot)
            .descriptor_type(self.descriptor_type())
            .descriptor_count(1)
            .stage_flags(stage)
    }
}
/// Fluent builder for a descriptor set layout.
///
/// Each `uniform_buffer` / `storage_buffer` / `combined_image_sampler` / `storage_image`
/// call appends one binding (descriptor count 1, no immutable samplers); `build` turns
/// the collected bindings into a `core::DescriptorSetLayout`.
#[derive(Debug, Clone)]
pub struct DescriptorLayoutBuilder {
    bindings: Vec<DescriptorBinding>,
}

impl DescriptorLayoutBuilder {
    /// Creates a builder with no bindings.
    pub fn new() -> Self {
        Self {
            bindings: Vec::new(),
        }
    }

    /// Appends `binding` and returns the builder for chaining.
    fn with_binding(mut self, binding: DescriptorBinding) -> Self {
        self.bindings.push(binding);
        self
    }

    /// Adds a uniform-buffer binding visible to `stage`.
    pub fn uniform_buffer(self, binding: u32, stage: vk::ShaderStageFlags) -> Self {
        self.with_binding(DescriptorBinding::UniformBuffer { binding, stage })
    }

    /// Adds a storage-buffer binding visible to `stage`.
    pub fn storage_buffer(self, binding: u32, stage: vk::ShaderStageFlags) -> Self {
        self.with_binding(DescriptorBinding::StorageBuffer { binding, stage })
    }

    /// Adds a combined-image-sampler binding visible to `stage`.
    pub fn combined_image_sampler(self, binding: u32, stage: vk::ShaderStageFlags) -> Self {
        self.with_binding(DescriptorBinding::CombinedImageSampler { binding, stage })
    }

    /// Adds a storage-image binding visible to `stage`.
    pub fn storage_image(self, binding: u32, stage: vk::ShaderStageFlags) -> Self {
        self.with_binding(DescriptorBinding::StorageImage { binding, stage })
    }

    /// Returns the smallest binding number that appears more than once, if any.
    fn first_duplicate_binding(bindings: &[DescriptorBinding]) -> Option<u32> {
        let mut numbers: Vec<u32> = bindings.iter().map(|b| b.binding()).collect();
        numbers.sort_unstable();
        numbers
            .windows(2)
            .find(|pair| pair[0] == pair[1])
            .map(|pair| pair[0])
    }

    /// Creates the descriptor set layout on `device`.
    ///
    /// # Errors
    /// Returns `RuntimeError::Other` if two bindings share the same binding number
    /// (Vulkan requires binding numbers within one layout to be unique), or if
    /// `vkCreateDescriptorSetLayout` fails.
    pub fn build(
        self,
        device: &Arc<core::Device>,
    ) -> Result<core::DescriptorSetLayout, RuntimeError> {
        // Reject duplicates up front: passing them to the driver violates valid usage
        // and is only caught when validation layers happen to be enabled.
        if let Some(duplicate) = Self::first_duplicate_binding(&self.bindings) {
            return Err(RuntimeError::Other(format!(
                "Descriptor layout has duplicate binding number {}",
                duplicate
            )));
        }
        let bindings: Vec<vk::DescriptorSetLayoutBinding<'static>> = self
            .bindings
            .into_iter()
            .map(DescriptorBinding::to_layout_binding)
            .collect();
        let layout_info = vk::DescriptorSetLayoutCreateInfo::default().bindings(&bindings);
        // SAFETY: `layout_info` borrows `bindings`, which outlives this call. The device
        // handle is assumed valid for as long as the `Arc<core::Device>` is alive.
        let layout = unsafe {
            device
                .handle()
                .create_descriptor_set_layout(&layout_info, None)
        }
        .map_err(|e| RuntimeError::Other(format!("Descriptor layout creation failed: {:?}", e)))?;
        Ok(core::DescriptorSetLayout::from_raw(layout))
    }
}

impl Default for DescriptorLayoutBuilder {
    /// Same as [`DescriptorLayoutBuilder::new`].
    fn default() -> Self {
        Self::new()
    }
}
/// Per-descriptor-type capacities used to size a descriptor pool.
///
/// Each field is a descriptor count for one Vulkan descriptor type; types whose count is
/// zero are left out of the pool-size list passed to the driver.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub struct DescriptorPoolSizes {
    /// Number of `UNIFORM_BUFFER` descriptors.
    pub uniform_buffers: u32,
    /// Number of `STORAGE_BUFFER` descriptors.
    pub storage_buffers: u32,
    /// Number of `COMBINED_IMAGE_SAMPLER` descriptors.
    pub combined_image_samplers: u32,
    /// Number of `STORAGE_IMAGE` descriptors.
    pub storage_images: u32,
}

impl DescriptorPoolSizes {
    /// `count` uniform buffers and `count` combined image samplers; everything else zero.
    pub fn simple_material(count: u32) -> Self {
        Self {
            uniform_buffers: count,
            combined_image_samplers: count,
            ..Self::default()
        }
    }

    /// Two storage-buffer descriptors per unit of `count`; everything else zero.
    ///
    /// The doubling saturates at `u32::MAX`. Plain `count * 2` would panic in debug
    /// builds and silently wrap to a far-too-small pool in release builds.
    pub fn compute(count: u32) -> Self {
        Self {
            storage_buffers: count.saturating_mul(2),
            ..Self::default()
        }
    }

    /// All counts zero; fill in with the `with_*` setters.
    pub fn custom() -> Self {
        Self::default()
    }

    /// Sets the uniform-buffer count.
    pub fn with_uniform_buffers(mut self, count: u32) -> Self {
        self.uniform_buffers = count;
        self
    }

    /// Sets the storage-buffer count.
    pub fn with_storage_buffers(mut self, count: u32) -> Self {
        self.storage_buffers = count;
        self
    }

    /// Sets the combined-image-sampler count.
    pub fn with_combined_image_samplers(mut self, count: u32) -> Self {
        self.combined_image_samplers = count;
        self
    }

    /// Sets the storage-image count.
    pub fn with_storage_images(mut self, count: u32) -> Self {
        self.storage_images = count;
        self
    }

    /// Builds the `VkDescriptorPoolSize` list, skipping zero counts. The order is
    /// uniform buffers, storage buffers, combined image samplers, storage images.
    fn to_pool_sizes(&self) -> Vec<vk::DescriptorPoolSize> {
        [
            (vk::DescriptorType::UNIFORM_BUFFER, self.uniform_buffers),
            (vk::DescriptorType::STORAGE_BUFFER, self.storage_buffers),
            (vk::DescriptorType::COMBINED_IMAGE_SAMPLER, self.combined_image_samplers),
            (vk::DescriptorType::STORAGE_IMAGE, self.storage_images),
        ]
        .iter()
        .filter(|(_, count)| *count > 0)
        .map(|&(ty, descriptor_count)| vk::DescriptorPoolSize {
            ty,
            descriptor_count,
        })
        .collect()
    }
}
/// Creates a descriptor pool that can allocate up to `max_sets` sets, with descriptor
/// capacities taken from `sizes`.
///
/// # Errors
/// Returns `RuntimeError::Other` when every count in `sizes` is zero, when `max_sets` is
/// zero (Vulkan requires `maxSets > 0` for a pool created without special flags, and no
/// flags are set here), or when `vkCreateDescriptorPool` fails.
pub fn create_descriptor_pool(
    device: &Arc<core::Device>,
    sizes: DescriptorPoolSizes,
    max_sets: u32,
) -> Result<core::DescriptorPool, RuntimeError> {
    let pool_sizes = sizes.to_pool_sizes();
    if pool_sizes.is_empty() {
        return Err(RuntimeError::Other(
            "Descriptor pool must have at least one pool size".to_string(),
        ));
    }
    if max_sets == 0 {
        return Err(RuntimeError::Other(
            "Descriptor pool max_sets must be greater than zero".to_string(),
        ));
    }
    let pool_info = vk::DescriptorPoolCreateInfo::default()
        .pool_sizes(&pool_sizes)
        .max_sets(max_sets);
    // SAFETY: `pool_info` borrows `pool_sizes`, which outlives this call. The device
    // handle is assumed valid for as long as the `Arc<core::Device>` is alive.
    let pool = unsafe { device.handle().create_descriptor_pool(&pool_info, None) }
        .map_err(|e| RuntimeError::Other(format!("Descriptor pool creation failed: {:?}", e)))?;
    Ok(core::DescriptorPool::from_raw(pool))
}
pub struct DescriptorWriter {
device: Arc<core::Device>,
writes: Vec<(vk::DescriptorSet, u32, DescriptorResource)>,
}
pub enum DescriptorResource {
Buffer {
buffer: vk::Buffer,
offset: vk::DeviceSize,
range: vk::DeviceSize,
descriptor_type: vk::DescriptorType,
},
Image {
image_view: vk::ImageView,
sampler: vk::Sampler,
layout: vk::ImageLayout,
},
}
impl DescriptorWriter {
pub fn new(device: &Arc<core::Device>) -> Self {
Self {
device: Arc::clone(device),
writes: Vec::new(),
}
}
pub fn write_buffer(
mut self,
set: vk::DescriptorSet,
binding: u32,
buffer: vk::Buffer,
offset: vk::DeviceSize,
range: vk::DeviceSize,
descriptor_type: vk::DescriptorType,
) -> Self {
self.writes.push((
set,
binding,
DescriptorResource::Buffer {
buffer,
offset,
range,
descriptor_type,
},
));
self
}
pub fn write_image(
mut self,
set: vk::DescriptorSet,
binding: u32,
image_view: vk::ImageView,
sampler: vk::Sampler,
layout: vk::ImageLayout,
) -> Self {
self.writes.push((
set,
binding,
DescriptorResource::Image {
image_view,
sampler,
layout,
},
));
self
}
pub fn update(self) {
let mut buffer_infos = Vec::new();
let mut image_infos = Vec::new();
for (_, _, resource) in &self.writes {
match resource {
DescriptorResource::Buffer {
buffer,
offset,
range,
..
} => {
buffer_infos.push(vk::DescriptorBufferInfo {
buffer: *buffer,
offset: *offset,
range: *range,
});
}
DescriptorResource::Image {
image_view,
sampler,
layout,
} => {
image_infos.push(vk::DescriptorImageInfo {
sampler: *sampler,
image_view: *image_view,
image_layout: *layout,
});
}
}
}
let mut buffer_idx = 0;
let mut image_idx = 0;
let mut write_descriptors = Vec::new();
for (set, binding, resource) in &self.writes {
let write = match resource {
DescriptorResource::Buffer {
descriptor_type, ..
} => {
let write = vk::WriteDescriptorSet::default()
.dst_set(*set)
.dst_binding(*binding)
.dst_array_element(0)
.descriptor_type(*descriptor_type)
.descriptor_count(1)
.buffer_info(std::slice::from_ref(&buffer_infos[buffer_idx]));
buffer_idx += 1;
write
}
DescriptorResource::Image { .. } => {
let write = vk::WriteDescriptorSet::default()
.dst_set(*set)
.dst_binding(*binding)
.dst_array_element(0)
.descriptor_type(vk::DescriptorType::COMBINED_IMAGE_SAMPLER)
.descriptor_count(1)
.image_info(std::slice::from_ref(&image_infos[image_idx]));
image_idx += 1;
write
}
};
write_descriptors.push(write);
}
unsafe {
self.device
.handle()
.update_descriptor_sets(&write_descriptors, &[]);
}
}
}
/// Builds a descriptor set layout holding a single uniform-buffer binding.
///
/// `binding` is the slot number and `stage` the shader stages that may read it. Errors
/// from layout creation are returned unchanged from `DescriptorLayoutBuilder::build`.
pub fn create_ubo_layout(
    device: &Arc<core::Device>,
    binding: u32,
    stage: vk::ShaderStageFlags,
) -> Result<core::DescriptorSetLayout, RuntimeError> {
    let single_ubo = DescriptorLayoutBuilder::new().uniform_buffer(binding, stage);
    single_ubo.build(device)
}
/// Builds a descriptor set layout holding a single combined-image-sampler binding.
///
/// `binding` is the slot number and `stage` the shader stages that may sample it. Errors
/// from layout creation are returned unchanged from `DescriptorLayoutBuilder::build`.
pub fn create_texture_layout(
    device: &Arc<core::Device>,
    binding: u32,
    stage: vk::ShaderStageFlags,
) -> Result<core::DescriptorSetLayout, RuntimeError> {
    let single_texture = DescriptorLayoutBuilder::new().combined_image_sampler(binding, stage);
    single_texture.build(device)
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::ex::{RuntimeConfig, RuntimeManager};

    /// Starts a runtime with validation disabled for the device-backed tests below.
    fn test_runtime() -> RuntimeManager {
        let config = RuntimeConfig {
            enable_validation: false,
            ..Default::default()
        };
        RuntimeManager::new(config).unwrap()
    }

    #[test]
    fn test_descriptor_binding_conversion() {
        let ubo = DescriptorBinding::UniformBuffer {
            binding: 0,
            stage: vk::ShaderStageFlags::VERTEX,
        };
        assert_eq!(ubo.binding(), 0);
        assert_eq!(ubo.descriptor_type(), vk::DescriptorType::UNIFORM_BUFFER);
        assert_eq!(ubo.stage_flags(), vk::ShaderStageFlags::VERTEX);
    }

    #[test]
    fn test_descriptor_layout_builder() {
        let rt = test_runtime();
        let builder = DescriptorLayoutBuilder::new()
            .uniform_buffer(0, vk::ShaderStageFlags::VERTEX)
            .combined_image_sampler(1, vk::ShaderStageFlags::FRAGMENT);
        let result = builder.build(&rt.device());
        assert!(result.is_ok());
    }

    #[test]
    fn test_create_ubo_layout() {
        let rt = test_runtime();
        let result = create_ubo_layout(&rt.device(), 0, vk::ShaderStageFlags::VERTEX);
        assert!(result.is_ok());
    }

    #[test]
    fn test_descriptor_pool_creation() {
        let rt = test_runtime();
        let sizes = DescriptorPoolSizes::simple_material(10);
        let result = create_descriptor_pool(&rt.device(), sizes, 10);
        assert!(result.is_ok());
    }

    #[test]
    fn test_descriptor_pool_sizes() {
        let material = DescriptorPoolSizes::simple_material(5);
        assert_eq!(material.uniform_buffers, 5);
        assert_eq!(material.combined_image_samplers, 5);

        let compute = DescriptorPoolSizes::compute(3);
        assert_eq!(compute.storage_buffers, 6);
    }

    #[test]
    fn test_custom_pool_sizes() {
        let custom = DescriptorPoolSizes::custom()
            .with_uniform_buffers(10)
            .with_storage_buffers(5);
        assert_eq!(custom.uniform_buffers, 10);
        assert_eq!(custom.storage_buffers, 5);
    }
}