use std::collections::HashMap;
use std::sync::Mutex;
use crate::sys::*;
use crate::core::*;
use crate::ffi::*;
use super::error::IcdError;
/// Largest push-constant block size, in bytes, accepted by
/// [`create_push_constant_range`] (larger sizes panic there).
///
/// NOTE(review): presumably chosen as the Vulkan spec's guaranteed minimum
/// for `maxPushConstantsSize` — confirm against the device limits in use.
pub const MAX_PUSH_CONSTANT_SIZE: u32 = 128;
/// Descriptor-set index reserved for the persistent storage-buffer set.
///
/// NOTE(review): not referenced in this section; the pipeline layout built by
/// `create_compute_pipeline_layout` places the persistent layout as its only
/// (i.e. index 0) set, which presumably is what this constant names — confirm.
pub const PERSISTENT_DESCRIPTOR_SET: u32 = 0;
/// A cached descriptor set together with the exact buffer list written into it.
struct PersistentDescriptor {
/// Set allocated from the device's persistent pool; binding `i` refers to `buffers[i]`.
descriptor_set: VkDescriptorSet,
/// Buffers the set was written for; compared on lookup so a hash-key
/// collision never returns a set bound to different buffers.
buffers: Vec<VkBuffer>,
/// Value of the manager's `generation` counter when this entry was cached.
generation: u64,
}
/// State behind the global `DESCRIPTOR_MANAGER`: per-device pools and layouts
/// plus a cache of already-written descriptor sets.
pub struct PersistentDescriptorManager {
/// Persistent descriptor pool per device, keyed by the raw `VkDevice` handle.
pools: HashMap<u64, VkDescriptorPool>,
/// Set-0 descriptor-set layout per device, keyed by the raw `VkDevice` handle.
set0_layout: HashMap<u64, VkDescriptorSetLayout>,
/// Written descriptor sets, keyed by the hash computed in
/// `get_persistent_descriptor_set`.
descriptors: HashMap<u64, PersistentDescriptor>,
/// Counter incremented each time a descriptor set is added to `descriptors`.
generation: u64,
}
lazy_static::lazy_static! {
/// Global descriptor cache shared by the functions in this module.
///
/// Guarded by a `std::sync::Mutex`, which cannot be re-locked by the thread
/// that already holds it: code holding this guard must not call another
/// function here that locks `DESCRIPTOR_MANAGER` again.
static ref DESCRIPTOR_MANAGER: Mutex<PersistentDescriptorManager> = Mutex::new(PersistentDescriptorManager {
pools: HashMap::new(),
set0_layout: HashMap::new(),
descriptors: HashMap::new(),
generation: 0,
});
}
pub unsafe fn create_persistent_layout(
device: VkDevice,
max_bindings: u32,
) -> Result<VkDescriptorSetLayout, IcdError> {
let mut manager = DESCRIPTOR_MANAGER.lock()?;
let device_key = device.as_raw();
if let Some(&layout) = manager.set0_layout.get(&device_key) {
return Ok(layout);
}
let mut bindings = Vec::with_capacity(max_bindings as usize);
for i in 0..max_bindings {
bindings.push(VkDescriptorSetLayoutBinding {
binding: i,
descriptorType: VkDescriptorType::StorageBuffer,
descriptorCount: 1,
stageFlags: VkShaderStageFlags::COMPUTE,
pImmutableSamplers: std::ptr::null(),
});
}
let create_info = VkDescriptorSetLayoutCreateInfo {
sType: VkStructureType::DescriptorSetLayoutCreateInfo,
pNext: std::ptr::null(),
flags: 0,
bindingCount: bindings.len() as u32,
pBindings: bindings.as_ptr(),
};
if let Some(icd) = super::icd_loader::get_icd() {
if let Some(create_fn) = icd.create_descriptor_set_layout {
let mut layout = VkDescriptorSetLayout::NULL;
let result = create_fn(device, &create_info, std::ptr::null(), &mut layout);
if result == VkResult::Success {
manager.set0_layout.insert(device_key, layout);
return Ok(layout);
}
return Err(IcdError::VulkanError(result));
}
}
Err(IcdError::MissingFunction("vkCreateDescriptorSetLayout"))
}
/// Returns the descriptor pool cached for `device`, creating it on first use.
///
/// The pool has a single storage-buffer pool size, since the persistent
/// layout only declares storage-buffer bindings. It is created with
/// `FREE_DESCRIPTOR_SET`, and cached in the global `DESCRIPTOR_MANAGER`
/// keyed by the raw device handle.
///
/// NOTE(review): `max_sets` and `max_descriptors` only take effect on the
/// first call for a device; later calls return the existing pool unchanged.
///
/// # Errors
///
/// * the `DESCRIPTOR_MANAGER` lock error, converted into `IcdError` by `?`;
/// * [`IcdError::NoIcdLoaded`] if no ICD is loaded;
/// * [`IcdError::MissingFunction`] if the ICD lacks `vkCreateDescriptorPool`;
/// * [`IcdError::VulkanError`] carrying the ICD's result if creation fails.
///
/// # Safety
///
/// `device` is passed directly to the ICD's `vkCreateDescriptorPool`; the
/// caller must ensure it is a valid device for the loaded ICD.
pub unsafe fn get_persistent_pool(
    device: VkDevice,
    max_sets: u32,
    max_descriptors: u32,
) -> Result<VkDescriptorPool, IcdError> {
    let mut manager = DESCRIPTOR_MANAGER.lock()?;
    let device_key = device.as_raw();
    if let Some(&pool) = manager.pools.get(&device_key) {
        return Ok(pool);
    }

    // Distinguish "no ICD at all" from "ICD without this entry point".
    let icd = super::icd_loader::get_icd().ok_or(IcdError::NoIcdLoaded)?;
    let create_fn = icd
        .create_descriptor_pool
        .ok_or(IcdError::MissingFunction("vkCreateDescriptorPool"))?;

    let pool_size = VkDescriptorPoolSize {
        type_: VkDescriptorType::StorageBuffer,
        descriptorCount: max_descriptors,
    };
    let create_info = VkDescriptorPoolCreateInfo {
        sType: VkStructureType::DescriptorPoolCreateInfo,
        pNext: std::ptr::null(),
        flags: VkDescriptorPoolCreateFlags::FREE_DESCRIPTOR_SET,
        maxSets: max_sets,
        poolSizeCount: 1,
        pPoolSizes: &pool_size,
    };

    let mut pool = VkDescriptorPool::NULL;
    let result = create_fn(device, &create_info, std::ptr::null(), &mut pool);
    if result != VkResult::Success {
        return Err(IcdError::VulkanError(result));
    }

    manager.pools.insert(device_key, pool);
    Ok(pool)
}
pub unsafe fn get_persistent_descriptor_set(
device: VkDevice,
buffers: &[VkBuffer],
) -> Result<VkDescriptorSet, IcdError> {
let mut manager = DESCRIPTOR_MANAGER.lock()?;
let cache_key = buffers.iter()
.map(|b| b.as_raw())
.fold(0u64, |acc, h| acc.wrapping_add(h).rotate_left(7));
if let Some(descriptor) = manager.descriptors.get(&cache_key) {
if descriptor.buffers == buffers {
return Ok(descriptor.descriptor_set);
}
}
let layout = create_persistent_layout(device, buffers.len() as u32)?;
let pool = get_persistent_pool(device, 1000, 10000)?;
let alloc_info = VkDescriptorSetAllocateInfo {
sType: VkStructureType::DescriptorSetAllocateInfo,
pNext: std::ptr::null(),
descriptorPool: pool,
descriptorSetCount: 1,
pSetLayouts: &layout,
};
let mut descriptor_set = VkDescriptorSet::NULL;
if let Some(icd) = super::icd_loader::get_icd() {
if let Some(alloc_fn) = icd.allocate_descriptor_sets {
let result = alloc_fn(device, &alloc_info, &mut descriptor_set);
if result != VkResult::Success {
return Err(IcdError::VulkanError(result));
}
} else {
return Err(IcdError::MissingFunction("vkAllocateDescriptorSets"));
}
} else {
return Err(IcdError::NoIcdLoaded);
}
let mut buffer_infos = Vec::with_capacity(buffers.len());
let mut writes = Vec::with_capacity(buffers.len());
for (_i, &buffer) in buffers.iter().enumerate() {
buffer_infos.push(VkDescriptorBufferInfo {
buffer,
offset: 0,
range: VK_WHOLE_SIZE,
});
}
for (i, buffer_info) in buffer_infos.iter().enumerate() {
writes.push(VkWriteDescriptorSet {
sType: VkStructureType::WriteDescriptorSet,
pNext: std::ptr::null(),
dstSet: descriptor_set,
dstBinding: i as u32,
dstArrayElement: 0,
descriptorCount: 1,
descriptorType: VkDescriptorType::StorageBuffer,
pImageInfo: std::ptr::null(),
pBufferInfo: buffer_info,
pTexelBufferView: std::ptr::null(),
});
}
if let Some(icd) = super::icd_loader::get_icd() {
if let Some(update_fn) = icd.update_descriptor_sets {
update_fn(device, writes.len() as u32, writes.as_ptr(), 0, std::ptr::null());
}
}
manager.generation += 1;
let generation = manager.generation;
manager.descriptors.insert(cache_key, PersistentDescriptor {
descriptor_set,
buffers: buffers.to_vec(),
generation,
});
Ok(descriptor_set)
}
/// Builds a compute-stage push-constant range covering bytes `0..size`.
///
/// # Panics
///
/// Panics if `size` is greater than [`MAX_PUSH_CONSTANT_SIZE`].
///
/// NOTE(review): Vulkan also requires a push-constant `size` that is non-zero
/// and a multiple of 4. Neither is checked here; confirm callers never pass
/// such sizes.
pub fn create_push_constant_range(size: u32) -> VkPushConstantRange {
    if size > MAX_PUSH_CONSTANT_SIZE {
        panic!(
            "Push constant size {} exceeds limit {}",
            size, MAX_PUSH_CONSTANT_SIZE
        );
    }
    VkPushConstantRange {
        offset: 0,
        size,
        stageFlags: VkShaderStageFlags::COMPUTE,
    }
}
pub unsafe fn create_compute_pipeline_layout(
device: VkDevice,
set0_binding_count: u32,
push_constant_size: u32,
) -> Result<VkPipelineLayout, IcdError> {
let set0_layout = create_persistent_layout(device, set0_binding_count)?;
let mut create_info = VkPipelineLayoutCreateInfo {
sType: VkStructureType::PipelineLayoutCreateInfo,
pNext: std::ptr::null(),
flags: 0,
setLayoutCount: 1,
pSetLayouts: &set0_layout,
pushConstantRangeCount: 0,
pPushConstantRanges: std::ptr::null(),
};
let push_range = if push_constant_size > 0 {
Some(create_push_constant_range(push_constant_size))
} else {
None
};
if let Some(ref range) = push_range {
create_info.pushConstantRangeCount = 1;
create_info.pPushConstantRanges = range;
}
let mut layout = VkPipelineLayout::NULL;
if let Some(icd) = super::icd_loader::get_icd() {
if let Some(create_fn) = icd.create_pipeline_layout {
let result = create_fn(device, &create_info, std::ptr::null(), &mut layout);
if result == VkResult::Success {
return Ok(layout);
}
return Err(IcdError::VulkanError(result));
}
}
Err(IcdError::MissingFunction("vkCreatePipelineLayout"))
}
/// Destroys the persistent descriptor pool and set-0 layout cached for
/// `device`, and drops cached descriptor sets that may have come from that pool.
///
/// Both handles are removed from the cache even when the ICD (or its destroy
/// entry point) is unavailable; in that case the objects are not destroyed.
///
/// Destroying a descriptor pool implicitly frees every set allocated from it,
/// so cached sets must not outlive their pool. `PersistentDescriptor` does not
/// record its device, so whenever a pool is removed the whole descriptor-set
/// cache is cleared. Sets cached for other devices are re-allocated on their
/// next request; their previous sets stay allocated in their own pool until
/// that pool is destroyed.
///
/// # Errors
///
/// Only the `DESCRIPTOR_MANAGER` lock error, converted into `IcdError` by `?`.
///
/// # Safety
///
/// `device` is passed to the ICD's destroy functions.
/// NOTE(review): the caller presumably must also ensure no pending GPU work
/// still uses the pool's sets or the layout — confirm.
pub unsafe fn cleanup_persistent_descriptors(device: VkDevice) -> Result<(), IcdError> {
    let mut manager = DESCRIPTOR_MANAGER.lock()?;
    let device_key = device.as_raw();

    if let Some(pool) = manager.pools.remove(&device_key) {
        if let Some(icd) = super::icd_loader::get_icd() {
            if let Some(destroy_fn) = icd.destroy_descriptor_pool {
                destroy_fn(device, pool, std::ptr::null());
            }
        }
        // Any cached set may have come from the pool just removed; returning
        // it later would hand out a freed handle.
        manager.descriptors.clear();
    }

    if let Some(layout) = manager.set0_layout.remove(&device_key) {
        if let Some(icd) = super::icd_loader::get_icd() {
            if let Some(destroy_fn) = icd.destroy_descriptor_set_layout {
                destroy_fn(device, layout, std::ptr::null());
            }
        }
    }

    Ok(())
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_push_constant_range() {
        let range = create_push_constant_range(64);
        assert_eq!(range.stageFlags, VkShaderStageFlags::COMPUTE);
        assert_eq!(range.offset, 0);
        assert_eq!(range.size, 64);
    }

    #[test]
    fn test_push_constant_range_at_limit() {
        // The limit itself is accepted; only sizes above it panic.
        let range = create_push_constant_range(MAX_PUSH_CONSTANT_SIZE);
        assert_eq!(range.size, MAX_PUSH_CONSTANT_SIZE);
    }

    #[test]
    #[should_panic(expected = "exceeds limit")]
    fn test_push_constant_size_limit() {
        // Pinning the message ensures the size check, not some unrelated
        // panic, is what trips.
        create_push_constant_range(MAX_PUSH_CONSTANT_SIZE + 1);
    }
}