use super::descriptor::{DescriptorSetLayout, ShaderStageFlags};
use super::device::DeviceInner;
use super::shader::ShaderModule;
use super::{Device, Error, Result, check};
use crate::raw::bindings::*;
use std::ffi::CString;
use std::sync::Arc;
/// Describes one push-constant range to declare on a [`PipelineLayout`].
///
/// [`PipelineLayout::with_push_constants`] copies each field unchanged into
/// the matching field of a `VkPushConstantRange`.
#[derive(Debug, Clone, Copy)]
pub struct PushConstantRange {
/// Shader stages for this range; its inner `.0` value is passed as `stageFlags`.
pub stage_flags: ShaderStageFlags,
/// Value passed as `VkPushConstantRange::offset`.
pub offset: u32,
/// Value passed as `VkPushConstantRange::size`.
pub size: u32,
}
/// Owned `VkPipelineLayout` handle paired with the device that created it.
///
/// The `Drop` implementation destroys the handle through the device's
/// `vkDestroyPipelineLayout` dispatch entry when that entry is present.
pub struct PipelineLayout {
/// Raw layout handle written by `vkCreatePipelineLayout`.
pub(crate) handle: VkPipelineLayout,
/// Shared device internals; supplies the dispatch table and device handle used on drop.
pub(crate) device: Arc<DeviceInner>,
}
impl PipelineLayout {
pub fn new(device: &Device, set_layouts: &[&DescriptorSetLayout]) -> Result<Self> {
Self::with_push_constants(device, set_layouts, &[])
}
pub fn with_push_constants(
device: &Device,
set_layouts: &[&DescriptorSetLayout],
push_constant_ranges: &[PushConstantRange],
) -> Result<Self> {
let create = device
.inner
.dispatch
.vkCreatePipelineLayout
.ok_or(Error::MissingFunction("vkCreatePipelineLayout"))?;
let raw_layouts: Vec<VkDescriptorSetLayout> =
set_layouts.iter().map(|l| l.handle).collect();
let raw_pcrs: Vec<VkPushConstantRange> = push_constant_ranges
.iter()
.map(|r| VkPushConstantRange {
stageFlags: r.stage_flags.0,
offset: r.offset,
size: r.size,
})
.collect();
let info = VkPipelineLayoutCreateInfo {
sType: VkStructureType::STRUCTURE_TYPE_PIPELINE_LAYOUT_CREATE_INFO,
setLayoutCount: raw_layouts.len() as u32,
pSetLayouts: raw_layouts.as_ptr(),
pushConstantRangeCount: raw_pcrs.len() as u32,
pPushConstantRanges: raw_pcrs.as_ptr(),
..Default::default()
};
let mut handle: VkPipelineLayout = 0;
check(unsafe { create(device.inner.handle, &info, std::ptr::null(), &mut handle) })?;
Ok(Self {
handle,
device: Arc::clone(&device.inner),
})
}
pub fn raw(&self) -> VkPipelineLayout {
self.handle
}
}
impl Drop for PipelineLayout {
    /// Destroys the layout handle if the device exposes `vkDestroyPipelineLayout`.
    ///
    /// When the dispatch entry is absent nothing is called.
    fn drop(&mut self) {
        let destroy_fn = match self.device.dispatch.vkDestroyPipelineLayout {
            Some(f) => f,
            None => return,
        };
        // SAFETY: `self.handle` was produced for `self.device.handle`, and the
        // `Arc` held in `self.device` keeps the device internals alive here.
        unsafe {
            destroy_fn(self.device.handle, self.handle, std::ptr::null());
        }
    }
}
/// Accumulates specialization constant values for a shader stage.
///
/// Values are added through the `add_*` builder methods; `as_raw` exposes the
/// result as a `VkSpecializationInfo` that points into the two vectors below.
#[derive(Default, Clone)]
pub struct SpecializationConstants {
/// One map entry per added constant, recording its id, byte offset into `data`, and byte size.
entries: Vec<VkSpecializationMapEntry>,
/// Native-endian bytes of every added value, appended in insertion order.
data: Vec<u8>,
}
impl SpecializationConstants {
    /// Creates an empty set of specialization constants.
    pub fn new() -> Self {
        Self {
            entries: Vec::new(),
            data: Vec::new(),
        }
    }

    /// Appends a `u32` constant (native-endian bytes) under `constant_id`.
    pub fn add_u32(mut self, constant_id: u32, value: u32) -> Self {
        let raw = value.to_ne_bytes();
        self.push_bytes(constant_id, &raw);
        self
    }

    /// Appends an `i32` constant (native-endian bytes) under `constant_id`.
    pub fn add_i32(mut self, constant_id: u32, value: i32) -> Self {
        let raw = value.to_ne_bytes();
        self.push_bytes(constant_id, &raw);
        self
    }

    /// Appends an `f32` constant (native-endian bytes) under `constant_id`.
    pub fn add_f32(mut self, constant_id: u32, value: f32) -> Self {
        let raw = value.to_ne_bytes();
        self.push_bytes(constant_id, &raw);
        self
    }

    /// Appends a boolean constant, stored as a 4-byte `u32` of `1` or `0`.
    pub fn add_bool(self, constant_id: u32, value: bool) -> Self {
        self.add_u32(constant_id, u32::from(value))
    }

    /// Returns `true` when no constants have been added.
    pub fn is_empty(&self) -> bool {
        self.entries.is_empty()
    }

    /// Returns the number of constants added so far.
    pub fn len(&self) -> usize {
        self.entries.len()
    }

    /// Records a map entry for `bytes` and appends them to the data blob.
    ///
    /// The entry's offset is the blob length before the append, so values
    /// are laid out back to back in insertion order.
    fn push_bytes(&mut self, constant_id: u32, bytes: &[u8]) {
        let entry = VkSpecializationMapEntry {
            constantID: constant_id,
            offset: self.data.len() as u32,
            size: bytes.len(),
        };
        self.data.extend_from_slice(bytes);
        self.entries.push(entry);
    }

    /// Builds a `VkSpecializationInfo` view over the stored entries and data.
    ///
    /// The returned struct holds raw pointers into `self`, so it must not be
    /// used after `self` is dropped or has had constants added.
    pub(crate) fn as_raw(&self) -> VkSpecializationInfo {
        let map_entries = &self.entries;
        let blob = &self.data;
        VkSpecializationInfo {
            mapEntryCount: map_entries.len() as u32,
            pMapEntries: map_entries.as_ptr(),
            dataSize: blob.len(),
            pData: blob.as_ptr().cast(),
        }
    }
}
/// Owned `VkPipelineCache` handle paired with the device that created it.
///
/// The `Drop` implementation destroys the handle through the device's
/// `vkDestroyPipelineCache` dispatch entry when that entry is present.
pub struct PipelineCache {
/// Raw cache handle written by `vkCreatePipelineCache`.
pub(crate) handle: VkPipelineCache,
/// Shared device internals; supplies the dispatch table and device handle for later calls.
pub(crate) device: Arc<DeviceInner>,
}
impl PipelineCache {
pub fn new(device: &Device) -> Result<Self> {
Self::with_data(device, &[])
}
pub fn with_data(device: &Device, data: &[u8]) -> Result<Self> {
let create = device
.inner
.dispatch
.vkCreatePipelineCache
.ok_or(Error::MissingFunction("vkCreatePipelineCache"))?;
let info = VkPipelineCacheCreateInfo {
sType: VkStructureType::STRUCTURE_TYPE_PIPELINE_CACHE_CREATE_INFO,
initialDataSize: data.len(),
pInitialData: if data.is_empty() {
std::ptr::null()
} else {
data.as_ptr() as *const _
},
..Default::default()
};
let mut handle: VkPipelineCache = 0;
check(unsafe { create(device.inner.handle, &info, std::ptr::null(), &mut handle) })?;
Ok(Self {
handle,
device: Arc::clone(&device.inner),
})
}
pub fn raw(&self) -> VkPipelineCache {
self.handle
}
pub fn data(&self) -> Result<Vec<u8>> {
let get = self
.device
.dispatch
.vkGetPipelineCacheData
.ok_or(Error::MissingFunction("vkGetPipelineCacheData"))?;
let mut size: usize = 0;
check(unsafe {
get(
self.device.handle,
self.handle,
&mut size,
std::ptr::null_mut(),
)
})?;
let mut bytes = vec![0u8; size];
check(unsafe {
get(
self.device.handle,
self.handle,
&mut size,
bytes.as_mut_ptr() as *mut _,
)
})?;
bytes.truncate(size);
Ok(bytes)
}
}
impl Drop for PipelineCache {
    /// Destroys the cache handle if the device exposes `vkDestroyPipelineCache`.
    ///
    /// When the dispatch entry is absent nothing is called.
    fn drop(&mut self) {
        let destroy_fn = match self.device.dispatch.vkDestroyPipelineCache {
            Some(f) => f,
            None => return,
        };
        // SAFETY: `self.handle` was produced for `self.device.handle`, and the
        // `Arc` held in `self.device` keeps the device internals alive here.
        unsafe {
            destroy_fn(self.device.handle, self.handle, std::ptr::null());
        }
    }
}
/// Owned compute `VkPipeline` handle paired with the device that created it.
///
/// The `Drop` implementation destroys the handle through the device's
/// `vkDestroyPipeline` dispatch entry when that entry is present.
pub struct ComputePipeline {
/// Raw pipeline handle written by `vkCreateComputePipelines`.
pub(crate) handle: VkPipeline,
/// Shared device internals; supplies the dispatch table and device handle used on drop.
pub(crate) device: Arc<DeviceInner>,
}
impl ComputePipeline {
pub fn new(
device: &Device,
layout: &PipelineLayout,
shader: &ShaderModule,
entry_point: &str,
) -> Result<Self> {
Self::with_specialization(
device,
layout,
shader,
entry_point,
&SpecializationConstants::new(),
)
}
pub fn with_specialization(
device: &Device,
layout: &PipelineLayout,
shader: &ShaderModule,
entry_point: &str,
specialization: &SpecializationConstants,
) -> Result<Self> {
Self::with_specialization_and_cache(
device,
layout,
shader,
entry_point,
specialization,
None,
)
}
pub fn with_specialization_and_cache(
device: &Device,
layout: &PipelineLayout,
shader: &ShaderModule,
entry_point: &str,
specialization: &SpecializationConstants,
cache: Option<&PipelineCache>,
) -> Result<Self> {
let create = device
.inner
.dispatch
.vkCreateComputePipelines
.ok_or(Error::MissingFunction("vkCreateComputePipelines"))?;
let entry_c = CString::new(entry_point)?;
let spec_raw = specialization.as_raw();
let p_spec = if specialization.is_empty() {
std::ptr::null()
} else {
&spec_raw as *const _
};
let stage = VkPipelineShaderStageCreateInfo {
sType: VkStructureType::STRUCTURE_TYPE_PIPELINE_SHADER_STAGE_CREATE_INFO,
stage: 0x20, module: shader.handle,
pName: entry_c.as_ptr(),
pSpecializationInfo: p_spec,
..Default::default()
};
let info = VkComputePipelineCreateInfo {
sType: VkStructureType::STRUCTURE_TYPE_COMPUTE_PIPELINE_CREATE_INFO,
stage,
layout: layout.handle,
..Default::default()
};
let cache_handle = cache.map_or(0u64, |c| c.handle);
let mut handle: VkPipeline = 0;
check(unsafe {
create(
device.inner.handle,
cache_handle,
1,
&info,
std::ptr::null(),
&mut handle,
)
})?;
Ok(Self {
handle,
device: Arc::clone(&device.inner),
})
}
pub fn raw(&self) -> VkPipeline {
self.handle
}
}
impl Drop for ComputePipeline {
    /// Destroys the pipeline handle if the device exposes `vkDestroyPipeline`.
    ///
    /// When the dispatch entry is absent nothing is called.
    fn drop(&mut self) {
        let destroy_fn = match self.device.dispatch.vkDestroyPipeline {
            Some(f) => f,
            None => return,
        };
        // SAFETY: `self.handle` was produced for `self.device.handle`, and the
        // `Arc` held in `self.device` keeps the device internals alive here.
        unsafe {
            destroy_fn(self.device.handle, self.handle, std::ptr::null());
        }
    }
}