use super::device::DeviceInner;
use super::image::{ImageLayout, ImageView, Sampler};
use super::{Buffer, Device, Error, Result, check};
use crate::raw::bindings::*;
use std::sync::Arc;
/// Newtype over the raw `VkDescriptorType` value.
///
/// The inner value is public, so any raw descriptor type can be wrapped; the
/// associated constants below name the ones this module writes.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct DescriptorType(pub VkDescriptorType);

impl DescriptorType {
    // Buffer-backed descriptors.

    /// Wraps `DESCRIPTOR_TYPE_UNIFORM_BUFFER`.
    pub const UNIFORM_BUFFER: Self = Self(VkDescriptorType::DESCRIPTOR_TYPE_UNIFORM_BUFFER);
    /// Wraps `DESCRIPTOR_TYPE_STORAGE_BUFFER`.
    pub const STORAGE_BUFFER: Self = Self(VkDescriptorType::DESCRIPTOR_TYPE_STORAGE_BUFFER);

    // Image-backed descriptors.

    /// Wraps `DESCRIPTOR_TYPE_SAMPLED_IMAGE`.
    pub const SAMPLED_IMAGE: Self = Self(VkDescriptorType::DESCRIPTOR_TYPE_SAMPLED_IMAGE);
    /// Wraps `DESCRIPTOR_TYPE_STORAGE_IMAGE`.
    pub const STORAGE_IMAGE: Self = Self(VkDescriptorType::DESCRIPTOR_TYPE_STORAGE_IMAGE);

    // Sampler descriptors.

    /// Wraps `DESCRIPTOR_TYPE_SAMPLER`.
    pub const SAMPLER: Self = Self(VkDescriptorType::DESCRIPTOR_TYPE_SAMPLER);
    /// Wraps `DESCRIPTOR_TYPE_COMBINED_IMAGE_SAMPLER`.
    pub const COMBINED_IMAGE_SAMPLER: Self =
        Self(VkDescriptorType::DESCRIPTOR_TYPE_COMBINED_IMAGE_SAMPLER);
}
/// Bitmask of shader stages, stored as a raw `u32`.
///
/// Masks are combined with `|` and queried with [`ShaderStageFlags::contains`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct ShaderStageFlags(pub u32);

impl ShaderStageFlags {
    /// Vertex stage bit (`0x1`).
    pub const VERTEX: Self = Self(1 << 0);
    /// Fragment stage bit (`0x10`).
    pub const FRAGMENT: Self = Self(1 << 4);
    /// Compute stage bit (`0x20`).
    pub const COMPUTE: Self = Self(1 << 5);
    /// Mask of the five lowest bits (`0x1F`); includes `VERTEX` and `FRAGMENT`
    /// but not `COMPUTE`.
    pub const ALL_GRAPHICS: Self = Self(0x0000_001F);
    /// Mask with every bit except the top one set (`0x7FFFFFFF`).
    pub const ALL: Self = Self(0x7FFF_FFFF);

    /// Returns `true` when every bit set in `other` is also set in `self`.
    ///
    /// An empty `other` is therefore contained in every mask.
    pub const fn contains(self, other: Self) -> bool {
        // No bit of `other` may fall outside `self`.
        (other.0 & !self.0) == 0
    }
}

impl std::ops::BitOr for ShaderStageFlags {
    type Output = Self;

    /// Produces the union of the two masks.
    fn bitor(self, rhs: Self) -> Self::Output {
        ShaderStageFlags(self.0 | rhs.0)
    }
}
/// Description of one binding passed to `DescriptorSetLayout::new`.
///
/// Each field is copied verbatim into the matching field of the raw
/// `VkDescriptorSetLayoutBinding`.
#[derive(Debug, Clone, Copy)]
pub struct DescriptorSetLayoutBinding {
/// Binding number (becomes the raw `binding` field).
pub binding: u32,
/// Descriptor type for this binding (becomes `descriptorType`).
pub descriptor_type: DescriptorType,
/// Number of descriptors in this binding (becomes `descriptorCount`).
pub descriptor_count: u32,
/// Shader stage mask for this binding (becomes `stageFlags`).
pub stage_flags: ShaderStageFlags,
}
/// Owning wrapper around a raw `VkDescriptorSetLayout` handle.
///
/// Keeps an `Arc` to the device internals it was created with, and destroys
/// the handle on drop when `vkDestroyDescriptorSetLayout` is loaded.
pub struct DescriptorSetLayout {
    pub(crate) handle: VkDescriptorSetLayout,
    pub(crate) device: Arc<DeviceInner>,
}

impl DescriptorSetLayout {
    /// Creates a descriptor set layout on `device` from `bindings`.
    ///
    /// Every binding is translated field-for-field into the raw structure,
    /// with `pImmutableSamplers` always set to null.
    ///
    /// # Errors
    ///
    /// Returns `Error::MissingFunction` when `vkCreateDescriptorSetLayout` is
    /// absent from the device dispatch table, or whatever error `check`
    /// produces for the value returned by the create call.
    pub fn new(device: &Device, bindings: &[DescriptorSetLayoutBinding]) -> Result<Self> {
        let inner = &device.inner;
        let create_fn = inner
            .dispatch
            .vkCreateDescriptorSetLayout
            .ok_or(Error::MissingFunction("vkCreateDescriptorSetLayout"))?;

        // Translate the safe binding descriptions into their raw counterparts.
        let mut vk_bindings: Vec<VkDescriptorSetLayoutBinding> =
            Vec::with_capacity(bindings.len());
        for entry in bindings {
            vk_bindings.push(VkDescriptorSetLayoutBinding {
                binding: entry.binding,
                descriptorType: entry.descriptor_type.0,
                descriptorCount: entry.descriptor_count,
                stageFlags: entry.stage_flags.0,
                pImmutableSamplers: std::ptr::null(),
            });
        }

        let create_info = VkDescriptorSetLayoutCreateInfo {
            sType: VkStructureType::STRUCTURE_TYPE_DESCRIPTOR_SET_LAYOUT_CREATE_INFO,
            bindingCount: vk_bindings.len() as u32,
            pBindings: vk_bindings.as_ptr(),
            ..Default::default()
        };

        let mut layout_handle: VkDescriptorSetLayout = 0;
        // SAFETY: `create_info`, the `vk_bindings` storage it points into, and
        // `layout_handle` are locals that stay alive for the whole call.
        let status = unsafe {
            create_fn(
                inner.handle,
                &create_info,
                std::ptr::null(),
                &mut layout_handle,
            )
        };
        check(status)?;

        Ok(Self {
            handle: layout_handle,
            device: Arc::clone(inner),
        })
    }

    /// Returns the raw layout handle.
    pub fn raw(&self) -> VkDescriptorSetLayout {
        self.handle
    }
}

impl Drop for DescriptorSetLayout {
    /// Destroys the layout; does nothing if the destroy entry point is not
    /// loaded.
    fn drop(&mut self) {
        let device = &self.device;
        if let Some(destroy_fn) = device.dispatch.vkDestroyDescriptorSetLayout {
            // SAFETY: `self.device` keeps the device internals alive, and this
            // runs at most once because it is in `drop`.
            unsafe { destroy_fn(device.handle, self.handle, std::ptr::null()) };
        }
    }
}
/// One pool-size entry passed to `DescriptorPool::new`.
///
/// Each entry is copied into a raw `VkDescriptorPoolSize`.
#[derive(Debug, Clone, Copy)]
pub struct DescriptorPoolSize {
/// Descriptor type this entry applies to (becomes the raw `type` field).
pub descriptor_type: DescriptorType,
/// Descriptor count for this type (becomes `descriptorCount`).
pub descriptor_count: u32,
}
/// Owning wrapper around a raw `VkDescriptorPool` handle.
///
/// Keeps an `Arc` to the device internals it was created with, and destroys
/// the pool on drop when `vkDestroyDescriptorPool` is loaded.
pub struct DescriptorPool {
    pub(crate) handle: VkDescriptorPool,
    pub(crate) device: Arc<DeviceInner>,
}

impl DescriptorPool {
    /// Creates a descriptor pool on `device`.
    ///
    /// `max_sets` is passed through as `maxSets`, and each entry of `sizes`
    /// becomes one raw `VkDescriptorPoolSize`.
    ///
    /// # Errors
    ///
    /// Returns `Error::MissingFunction` when `vkCreateDescriptorPool` is
    /// absent from the device dispatch table, or whatever error `check`
    /// produces for the value returned by the create call.
    pub fn new(device: &Device, max_sets: u32, sizes: &[DescriptorPoolSize]) -> Result<Self> {
        let inner = &device.inner;
        let create_fn = inner
            .dispatch
            .vkCreateDescriptorPool
            .ok_or(Error::MissingFunction("vkCreateDescriptorPool"))?;

        // Translate the safe size entries into their raw counterparts.
        let mut vk_sizes: Vec<VkDescriptorPoolSize> = Vec::with_capacity(sizes.len());
        for size in sizes {
            vk_sizes.push(VkDescriptorPoolSize {
                r#type: size.descriptor_type.0,
                descriptorCount: size.descriptor_count,
            });
        }

        let create_info = VkDescriptorPoolCreateInfo {
            sType: VkStructureType::STRUCTURE_TYPE_DESCRIPTOR_POOL_CREATE_INFO,
            maxSets: max_sets,
            poolSizeCount: vk_sizes.len() as u32,
            pPoolSizes: vk_sizes.as_ptr(),
            ..Default::default()
        };

        let mut pool_handle: VkDescriptorPool = 0;
        // SAFETY: `create_info`, the `vk_sizes` storage it points into, and
        // `pool_handle` are locals that stay alive for the whole call.
        let status = unsafe {
            create_fn(
                inner.handle,
                &create_info,
                std::ptr::null(),
                &mut pool_handle,
            )
        };
        check(status)?;

        Ok(Self {
            handle: pool_handle,
            device: Arc::clone(inner),
        })
    }

    /// Returns the raw pool handle.
    pub fn raw(&self) -> VkDescriptorPool {
        self.handle
    }

    /// Allocates a single descriptor set with the given `layout` from this
    /// pool.
    ///
    /// The returned set shares this pool's device reference but holds no
    /// reference to the pool itself.
    ///
    /// # Errors
    ///
    /// Returns `Error::MissingFunction` when `vkAllocateDescriptorSets` is
    /// absent from the device dispatch table, or whatever error `check`
    /// produces for the value returned by the allocate call.
    pub fn allocate(&self, layout: &DescriptorSetLayout) -> Result<DescriptorSet> {
        let device = &self.device;
        let allocate_fn = device
            .dispatch
            .vkAllocateDescriptorSets
            .ok_or(Error::MissingFunction("vkAllocateDescriptorSets"))?;

        let allocate_info = VkDescriptorSetAllocateInfo {
            sType: VkStructureType::STRUCTURE_TYPE_DESCRIPTOR_SET_ALLOCATE_INFO,
            descriptorPool: self.handle,
            // Exactly one layout is supplied, so exactly one set is requested.
            descriptorSetCount: 1,
            pSetLayouts: &layout.handle,
            ..Default::default()
        };

        let mut set_handle: VkDescriptorSet = 0;
        // SAFETY: `allocate_info` and the borrowed `layout` outlive the call,
        // and `set_handle` has room for the single requested set.
        let status = unsafe { allocate_fn(device.handle, &allocate_info, &mut set_handle) };
        check(status)?;

        Ok(DescriptorSet {
            handle: set_handle,
            device: Arc::clone(device),
        })
    }
}

impl Drop for DescriptorPool {
    /// Destroys the pool; does nothing if the destroy entry point is not
    /// loaded.
    fn drop(&mut self) {
        let device = &self.device;
        if let Some(destroy_fn) = device.dispatch.vkDestroyDescriptorPool {
            // SAFETY: `self.device` keeps the device internals alive, and this
            // runs at most once because it is in `drop`.
            unsafe { destroy_fn(device.handle, self.handle, std::ptr::null()) };
        }
    }
}
pub struct DescriptorSet {
pub(crate) handle: VkDescriptorSet,
pub(crate) device: Arc<DeviceInner>,
}
impl DescriptorSet {
pub fn raw(&self) -> VkDescriptorSet {
self.handle
}
pub fn write_buffer(
&self,
binding: u32,
descriptor_type: DescriptorType,
buffer: &Buffer,
offset: u64,
range: u64,
) {
let update = self
.device
.dispatch
.vkUpdateDescriptorSets
.expect("vkUpdateDescriptorSets is required by Vulkan 1.0");
let info = VkDescriptorBufferInfo {
buffer: buffer.handle,
offset,
range,
};
let write = VkWriteDescriptorSet {
sType: VkStructureType::STRUCTURE_TYPE_WRITE_DESCRIPTOR_SET,
dstSet: self.handle,
dstBinding: binding,
descriptorCount: 1,
descriptorType: descriptor_type.0,
pBufferInfo: &info,
..Default::default()
};
unsafe { update(self.device.handle, 1, &write, 0, std::ptr::null()) };
}
pub fn write_combined_image_sampler(
&self,
binding: u32,
sampler: &Sampler,
view: &ImageView,
image_layout: ImageLayout,
) {
let update = self
.device
.dispatch
.vkUpdateDescriptorSets
.expect("vkUpdateDescriptorSets is required by Vulkan 1.0");
let info = VkDescriptorImageInfo {
sampler: sampler.handle,
imageView: view.handle,
imageLayout: image_layout.0,
};
let write = VkWriteDescriptorSet {
sType: VkStructureType::STRUCTURE_TYPE_WRITE_DESCRIPTOR_SET,
dstSet: self.handle,
dstBinding: binding,
descriptorCount: 1,
descriptorType: VkDescriptorType::DESCRIPTOR_TYPE_COMBINED_IMAGE_SAMPLER,
pImageInfo: &info,
..Default::default()
};
unsafe { update(self.device.handle, 1, &write, 0, std::ptr::null()) };
}
pub fn write_sampled_image(&self, binding: u32, view: &ImageView, image_layout: ImageLayout) {
let update = self
.device
.dispatch
.vkUpdateDescriptorSets
.expect("vkUpdateDescriptorSets is required by Vulkan 1.0");
let info = VkDescriptorImageInfo {
sampler: 0,
imageView: view.handle,
imageLayout: image_layout.0,
};
let write = VkWriteDescriptorSet {
sType: VkStructureType::STRUCTURE_TYPE_WRITE_DESCRIPTOR_SET,
dstSet: self.handle,
dstBinding: binding,
descriptorCount: 1,
descriptorType: VkDescriptorType::DESCRIPTOR_TYPE_SAMPLED_IMAGE,
pImageInfo: &info,
..Default::default()
};
unsafe { update(self.device.handle, 1, &write, 0, std::ptr::null()) };
}
pub fn write_sampler(&self, binding: u32, sampler: &Sampler) {
let update = self
.device
.dispatch
.vkUpdateDescriptorSets
.expect("vkUpdateDescriptorSets is required by Vulkan 1.0");
let info = VkDescriptorImageInfo {
sampler: sampler.handle,
imageView: 0,
imageLayout: VkImageLayout::IMAGE_LAYOUT_UNDEFINED,
};
let write = VkWriteDescriptorSet {
sType: VkStructureType::STRUCTURE_TYPE_WRITE_DESCRIPTOR_SET,
dstSet: self.handle,
dstBinding: binding,
descriptorCount: 1,
descriptorType: VkDescriptorType::DESCRIPTOR_TYPE_SAMPLER,
pImageInfo: &info,
..Default::default()
};
unsafe { update(self.device.handle, 1, &write, 0, std::ptr::null()) };
}
pub fn write_storage_image(&self, binding: u32, view: &ImageView, image_layout: ImageLayout) {
let update = self
.device
.dispatch
.vkUpdateDescriptorSets
.expect("vkUpdateDescriptorSets is required by Vulkan 1.0");
let info = VkDescriptorImageInfo {
sampler: 0,
imageView: view.handle,
imageLayout: image_layout.0,
};
let write = VkWriteDescriptorSet {
sType: VkStructureType::STRUCTURE_TYPE_WRITE_DESCRIPTOR_SET,
dstSet: self.handle,
dstBinding: binding,
descriptorCount: 1,
descriptorType: VkDescriptorType::DESCRIPTOR_TYPE_STORAGE_IMAGE,
pImageInfo: &info,
..Default::default()
};
unsafe { update(self.device.handle, 1, &write, 0, std::ptr::null()) };
}
}