use crate::Result;
use std::{collections::HashMap, sync::Arc};
use ash::vk::DescriptorSetLayout;
use ash::vk::DescriptorSetLayoutCreateInfo;
use ash::Device;
use parking_lot::RwLock;
use smallvec::SmallVec;
use std::hash::{Hash, Hasher};
use super::DescriptorSetBinding;
use super::MAX_BINDINGS;
#[derive(Clone, Debug)]
pub struct DescriptorLayoutInfo {
bindings: SmallVec<[DescriptorSetBinding; MAX_BINDINGS]>,
}
unsafe impl Send for DescriptorLayoutInfo {}
unsafe impl Sync for DescriptorLayoutInfo {}
impl DescriptorLayoutInfo {
pub fn new(bindings: &[DescriptorSetBinding]) -> Self {
let mut layout = Self {
bindings: Default::default(),
};
for binding in bindings {
layout.insert(*binding);
}
layout
}
pub fn bindings(&self) -> &[DescriptorSetBinding] {
&self.bindings
}
pub fn insert(&mut self, binding: DescriptorSetBinding) {
let mut len = self.bindings.len();
let mut mid = (len / 2).min(1);
loop {
if len < 1 {
if mid == self.bindings.len() || self.bindings[mid].binding > binding.binding {
self.bindings.insert(mid, binding);
} else {
self.bindings.insert(mid + 1, binding);
}
break;
}
match self.bindings[mid].binding.cmp(&binding.binding) {
std::cmp::Ordering::Less => {
len /= 2;
mid += (len as f32 / 2.0).floor() as usize;
}
std::cmp::Ordering::Equal => {
self.bindings[mid] = binding;
break;
}
std::cmp::Ordering::Greater => {
len /= 2;
mid -= (len as f32 / 2.0).ceil() as usize;
}
}
}
}
pub fn extend<I: Iterator<Item = DescriptorSetBinding>>(&mut self, bindings: I) {
self.bindings.extend(bindings);
}
pub fn get(&mut self, binding: u32) -> Option<DescriptorSetBinding> {
let mut len = self.bindings.len();
let mut mid = (len / 2).max(1);
loop {
if mid >= self.bindings.len() || len == 0 {
return None;
}
match self.bindings[mid].binding.cmp(&binding) {
std::cmp::Ordering::Less => {
len /= 2;
mid += (len as f32 / 2.0).ceil() as usize;
}
std::cmp::Ordering::Equal => return Some(self.bindings[mid]),
std::cmp::Ordering::Greater => {
len /= 2;
mid -= (len as f32 / 2.0).ceil() as usize;
}
}
}
}
}
impl Default for DescriptorLayoutInfo {
fn default() -> Self {
Self {
bindings: Default::default(),
}
}
}
impl Hash for DescriptorLayoutInfo {
fn hash<H: Hasher>(&self, state: &mut H) {
for binding in &self.bindings {
binding.binding.hash(state);
}
}
}
impl PartialEq for DescriptorLayoutInfo {
fn eq(&self, other: &Self) -> bool {
for (a, b) in self.bindings.iter().zip(&other.bindings) {
if a.binding != b.binding
|| a.descriptor_type != b.descriptor_type
|| a.descriptor_count != b.descriptor_count
|| a.stage_flags != b.stage_flags
{
return false;
}
}
true
}
}
impl Eq for DescriptorLayoutInfo {}
pub struct DescriptorLayoutCache {
device: Arc<Device>,
layouts: RwLock<HashMap<DescriptorLayoutInfo, DescriptorSetLayout>>,
}
impl DescriptorLayoutCache {
pub fn new(device: Arc<Device>) -> Self {
Self {
device,
layouts: RwLock::new(HashMap::new()),
}
}
pub fn get(&self, info: &DescriptorLayoutInfo) -> Result<DescriptorSetLayout> {
let guard = self.layouts.read();
if let Some(layout) = guard.get(info) {
Ok(*layout)
} else {
drop(guard);
let info = info.clone();
let layout = create(&self.device, &info)?;
Ok(*self.layouts.write().entry(info).or_insert(layout))
}
}
pub fn clear(&mut self) {
for (_, layout) in self.layouts.write().drain() {
destroy(&self.device, layout);
}
}
}
impl Drop for DescriptorLayoutCache {
fn drop(&mut self) {
self.clear()
}
}
pub fn create(device: &Device, info: &DescriptorLayoutInfo) -> Result<DescriptorSetLayout> {
let create_info = DescriptorSetLayoutCreateInfo {
binding_count: info.bindings.len() as u32,
p_bindings: info.bindings.as_ptr(),
..Default::default()
};
let layout = unsafe { device.create_descriptor_set_layout(&create_info, None)? };
Ok(layout)
}
pub fn destroy(device: &Device, layout: DescriptorSetLayout) {
unsafe { device.destroy_descriptor_set_layout(layout, None) }
}
#[cfg(test)]
mod tests {
use ash::vk::DescriptorType;
use super::*;
#[test]
fn layout_info_add() {
let mut layout = DescriptorLayoutInfo::new(&[]);
let mut bindings = [DescriptorSetBinding::default(); 6];
bindings[0] = DescriptorSetBinding {
binding: 0,
descriptor_type: DescriptorType::UNIFORM_BUFFER,
descriptor_count: 1,
..Default::default()
};
layout.insert(bindings[0]);
eprintln!("Layout: {:?}", layout);
bindings[2] = DescriptorSetBinding {
binding: 2,
descriptor_type: DescriptorType::UNIFORM_BUFFER,
descriptor_count: 1,
..Default::default()
};
layout.insert(bindings[2]);
eprintln!("Layout: {:?}", layout);
bindings[1] = DescriptorSetBinding {
binding: 1,
descriptor_type: DescriptorType::UNIFORM_BUFFER,
descriptor_count: 1,
..Default::default()
};
layout.insert(bindings[1]);
eprintln!("Layout: {:?}", layout);
bindings[3] = DescriptorSetBinding {
binding: 3,
descriptor_type: DescriptorType::UNIFORM_BUFFER,
descriptor_count: 1,
..Default::default()
};
layout.insert(bindings[3]);
eprintln!("Layout: {:?}", layout);
bindings[5] = DescriptorSetBinding {
binding: 5,
descriptor_type: DescriptorType::UNIFORM_BUFFER,
descriptor_count: 1,
..Default::default()
};
layout.insert(bindings[5]);
eprintln!("Layout: {:?}", layout);
bindings[4] = DescriptorSetBinding {
binding: 4,
descriptor_type: DescriptorType::UNIFORM_BUFFER,
descriptor_count: 1,
..Default::default()
};
layout.insert(bindings[4]);
layout.insert(DescriptorSetBinding {
binding: 4,
descriptor_type: DescriptorType::COMBINED_IMAGE_SAMPLER,
descriptor_count: 1,
..Default::default()
});
eprintln!("Layout: {:?}", layout);
assert!(layout
.bindings()
.iter()
.map(|val| val.binding)
.eq([0, 1, 2, 3, 4, 5].iter().cloned()));
for binding in &bindings {
assert_eq!(
Some(binding.binding),
layout.get(binding.binding).map(|val| val.binding)
)
}
assert_eq!(None, layout.get(9).map(|_| ()));
}
}