use {
super::{
shader::ShaderCode, DescriptorBindingMap, Device, DriverError, PipelineDescriptorInfo,
Shader, SpecializationInfo,
},
ash::vk,
derive_builder::Builder,
log::{trace, warn},
std::{ffi::CString, ops::Deref, sync::Arc, thread::panicking},
};
/// A compute pipeline together with the objects needed to use and destroy it.
///
/// Dereferences to the underlying `vk::Pipeline` handle. When dropped (outside of a panic) the
/// pipeline and its layout are destroyed using `device`.
#[derive(Debug)]
pub struct ComputePipeline {
/// Descriptor bindings obtained from the shader; any binding that reported a count of zero has
/// had its count replaced by `ComputePipelineInfo::bindless_descriptor_count`.
pub descriptor_bindings: DescriptorBindingMap,
/// Descriptor information created from `descriptor_bindings`; its layouts are the set layouts
/// the pipeline layout was created with.
pub descriptor_info: PipelineDescriptorInfo,
/// The device this pipeline was created on and is destroyed with.
pub device: Arc<Device>,
/// The pipeline layout handle, destroyed together with the pipeline on drop.
pub layout: vk::PipelineLayout,
/// The information this pipeline was created from.
pub info: ComputePipelineInfo,
// The pipeline handle; private, exposed read-only through `Deref`.
pipeline: vk::Pipeline,
/// The push constant range reported by the shader, if it has one.
pub push_constants: Option<vk::PushConstantRange>,
}
impl ComputePipeline {
    /// Creates a new compute pipeline on the given device.
    ///
    /// `info` supplies the SPIR-V code, the entry point name, optional specialization constants
    /// and the descriptor count applied to any shader binding that reports a count of zero.
    ///
    /// # Errors
    ///
    /// Returns [`DriverError::Unsupported`] if the driver fails to create the shader module, the
    /// pipeline layout or the pipeline itself; the underlying error is logged as a warning.
    /// Failures while creating the descriptor info are propagated with `?`. Any shader module or
    /// pipeline layout created before a failure is destroyed before this function returns.
    ///
    /// # Panics
    ///
    /// Panics if the entry name contains an interior NUL byte.
    pub fn create(
        device: &Arc<Device>,
        info: impl Into<ComputePipelineInfo>,
    ) -> Result<Self, DriverError> {
        use std::slice::from_ref;

        trace!("create");

        let device = Arc::clone(device);
        let info: ComputePipelineInfo = info.into();

        // `into_shader` consumes its receiver, but `info` is stored in the returned pipeline, so
        // a clone is handed over.
        let shader = info.clone().into_shader();

        // Bindings which report a count of zero get the configured bindless descriptor count.
        let mut descriptor_bindings = shader.descriptor_bindings(&device);
        for (descriptor_info, _) in descriptor_bindings.values_mut() {
            if descriptor_info.binding_count() == 0 {
                descriptor_info.set_binding_count(info.bindless_descriptor_count);
            }
        }

        let descriptor_info = PipelineDescriptorInfo::create(&device, &descriptor_bindings)?;
        let descriptor_set_layouts = descriptor_info
            .layouts
            .iter()
            .map(|(_, descriptor_set_layout)| **descriptor_set_layout)
            .collect::<Box<[_]>>();

        // Built before any Vulkan object exists so that a panic here cannot strand a handle.
        let entry_name = CString::new(info.entry_name.as_bytes()).unwrap();

        // SAFETY: Every handle passed to the device below was created from that same device
        // within this function, and every pointer stored in the create-info structures refers to
        // data (`info`, `entry_name`, `descriptor_set_layouts`, `specialization_info`,
        // `push_constants`) that stays alive until after the call which reads it.
        unsafe {
            let shader_module_create_info = vk::ShaderModuleCreateInfo {
                code_size: info.spirv.len(),
                // NOTE(review): assumes the byte buffer is suitably aligned for `u32` and that
                // its length is a multiple of four - confirm against `ShaderCode::into_vec`.
                p_code: info.spirv.as_ptr() as *const u32,
                ..Default::default()
            };
            let shader_module = device
                .create_shader_module(&shader_module_create_info, None)
                .map_err(|err| {
                    warn!("{err}");

                    DriverError::Unsupported
                })?;

            let mut stage_create_info = vk::PipelineShaderStageCreateInfo::builder()
                .module(shader_module)
                .stage(shader.stage)
                .name(&entry_name);

            // Kept in a local binding so the pointer stored in the stage info remains valid.
            let specialization_info = info.specialization_info.as_ref().map(|info| {
                vk::SpecializationInfo::builder()
                    .map_entries(&info.map_entries)
                    .data(&info.data)
                    .build()
            });

            if let Some(specialization_info) = &specialization_info {
                stage_create_info = stage_create_info.specialization_info(specialization_info);
            }

            let mut layout_info =
                vk::PipelineLayoutCreateInfo::builder().set_layouts(&descriptor_set_layouts);
            let push_constants = shader.push_constant_range();

            if let Some(push_constants) = &push_constants {
                layout_info = layout_info.push_constant_ranges(from_ref(push_constants));
            }

            let layout = match device.create_pipeline_layout(&layout_info, None) {
                Ok(layout) => layout,
                Err(err) => {
                    warn!("{err}");

                    // No `ComputePipeline` exists yet, so `Drop` will not run: release the
                    // shader module here instead of leaking it.
                    device.destroy_shader_module(shader_module, None);

                    return Err(DriverError::Unsupported);
                }
            };

            let pipeline_info = vk::ComputePipelineCreateInfo::builder()
                .stage(stage_create_info.build())
                .layout(layout);
            let pipeline_result = device.create_compute_pipelines(
                vk::PipelineCache::null(),
                from_ref(&pipeline_info.build()),
                None,
            );

            // The module is only needed while the pipeline is being created; release it whether
            // or not creation succeeded.
            device.destroy_shader_module(shader_module, None);

            let pipeline = match pipeline_result {
                Ok(pipelines) => pipelines[0],
                Err((_, err)) => {
                    warn!("{err}");

                    // The layout would otherwise be leaked, as nothing owns it yet.
                    device.destroy_pipeline_layout(layout, None);

                    return Err(DriverError::Unsupported);
                }
            };

            Ok(ComputePipeline {
                descriptor_bindings,
                descriptor_info,
                device,
                info,
                layout,
                pipeline,
                push_constants,
            })
        }
    }
}
/// Lets a `ComputePipeline` be used wherever a reference to the `vk::Pipeline` handle is needed.
impl Deref for ComputePipeline {
    type Target = vk::Pipeline;

    /// Borrows the pipeline handle.
    fn deref(&self) -> &Self::Target {
        let Self { pipeline, .. } = self;

        pipeline
    }
}
impl Drop for ComputePipeline {
    /// Destroys the pipeline and then its layout, unless the thread is currently panicking, in
    /// which case nothing is destroyed.
    fn drop(&mut self) {
        if !panicking() {
            unsafe {
                self.device.destroy_pipeline(self.pipeline, None);
                self.device.destroy_pipeline_layout(self.layout, None);
            }
        }
    }
}
/// Information used to create a `ComputePipeline`.
///
/// A `ComputePipelineInfoBuilder` is generated for this type; its setters take and return the
/// builder by value (the "owned" pattern).
#[derive(Builder, Clone, Debug)]
#[builder(pattern = "owned")]
pub struct ComputePipelineInfo {
/// The descriptor count given to any shader binding which reports a binding count of zero
/// when the pipeline is created.
///
/// The builder default is `8192`.
#[builder(default = "8192")]
pub bindless_descriptor_count: u32,
/// The name of the shader entry point.
///
/// The builder default is `"main"`.
#[builder(setter(strip_option), default = "String::from(\"main\")")]
pub entry_name: String,
/// An optional name for this pipeline.
///
/// NOTE(review): not read by any code visible in this file - presumably a debugging label;
/// confirm against the consumers of this type.
#[builder(default, setter(strip_option))]
pub name: Option<String>,
/// Optional specialization constant data (map entries and raw data) passed to the shader stage.
#[builder(default, setter(strip_option))]
pub specialization_info: Option<SpecializationInfo>,
/// The SPIR-V shader code as bytes.
///
/// This field has no builder default, so it must be set before the builder is built.
pub spirv: Vec<u8>,
}
impl ComputePipelineInfo {
    /// Starts a builder which has its SPIR-V code set from the given shader code.
    #[allow(clippy::new_ret_no_self)]
    pub fn new(spirv: impl ShaderCode) -> ComputePipelineInfoBuilder {
        let code = spirv.into_vec();

        ComputePipelineInfoBuilder::default().spirv(code)
    }

    /// Converts this information into a compute-stage `Shader`, carrying over the code, the
    /// entry name and, when present, the specialization info.
    pub fn into_shader(self) -> Shader {
        let Self {
            entry_name,
            specialization_info,
            spirv,
            ..
        } = self;

        let builder = Shader::new(vk::ShaderStageFlags::COMPUTE, spirv).entry_name(entry_name);
        let builder = match specialization_info {
            Some(specialization_info) => builder.specialization_info(specialization_info),
            None => builder,
        };

        builder.build()
    }
}
/// Allows shader code to be used directly as pipeline information, with every other field left
/// at its builder default.
impl<S: ShaderCode> From<S> for ComputePipelineInfo {
    fn from(spirv: S) -> Self {
        let builder = Self::new(spirv);

        builder.build().unwrap()
    }
}
/// Allows a builder to be passed wherever pipeline information is expected.
///
/// Panics if the builder cannot be built.
impl From<ComputePipelineInfoBuilder> for ComputePipelineInfo {
    fn from(info: ComputePipelineInfoBuilder) -> Self {
        let built = info.build();

        built.unwrap()
    }
}