use {
super::{
shader::{DescriptorBindingMap, PipelineDescriptorInfo, Shader},
Device, DriverError,
},
ash::vk,
derive_builder::{Builder, UninitializedFieldError},
log::{trace, warn},
std::{ffi::CString, ops::Deref, sync::Arc, thread::panicking},
};
/// A Vulkan compute pipeline together with the layout and descriptor information it was
/// created with.
///
/// Dereferences to the raw [`vk::Pipeline`] handle. The pipeline and its layout are destroyed
/// when this value is dropped (unless the thread is panicking).
#[derive(Debug)]
pub struct ComputePipeline {
    /// Descriptor bindings obtained from the shader; bindings that reported a binding count of
    /// zero were resized to `info.bindless_descriptor_count` during creation.
    pub(crate) descriptor_bindings: DescriptorBindingMap,
    /// Descriptor information (including the descriptor set `layouts`) created from
    /// `descriptor_bindings`.
    pub(crate) descriptor_info: PipelineDescriptorInfo,
    /// Device the Vulkan handles were created from; used to destroy them on drop.
    device: Arc<Device>,
    /// Pipeline layout built from the descriptor set layouts and the optional push constant
    /// range.
    pub(crate) layout: vk::PipelineLayout,
    /// The information this pipeline was created with.
    pub info: ComputePipelineInfo,
    /// Raw Vulkan pipeline handle.
    pipeline: vk::Pipeline,
    /// Push constant range reported by the shader, if any.
    pub(crate) push_constants: Option<vk::PushConstantRange>,
}
impl ComputePipeline {
    /// Creates a new compute pipeline on `device` from a single compute `shader`.
    ///
    /// Any descriptor binding obtained from `shader` that reports a binding count of zero is
    /// given `info.bindless_descriptor_count` descriptors before the descriptor set layouts are
    /// created.
    ///
    /// The temporary shader module is always destroyed before this function returns. If
    /// pipeline creation fails, the pipeline layout created for it is destroyed as well, so no
    /// Vulkan objects are leaked on the error paths.
    ///
    /// # Errors
    ///
    /// Returns [`DriverError::Unsupported`] if the shader entry point name contains an interior
    /// NUL byte, or if Vulkan fails to create the shader module, pipeline layout or pipeline.
    /// Errors returned by [`PipelineDescriptorInfo::create`] are propagated unchanged.
    pub fn create(
        device: &Arc<Device>,
        info: impl Into<ComputePipelineInfo>,
        shader: impl Into<Shader>,
    ) -> Result<Self, DriverError> {
        use std::slice::from_ref;

        trace!("create");

        let device = Arc::clone(device);
        let info: ComputePipelineInfo = info.into();
        let shader = shader.into();

        // Bindings that report zero descriptors are sized from the pipeline info.
        let mut descriptor_bindings = shader.descriptor_bindings(&device);
        for (descriptor_info, _) in descriptor_bindings.values_mut() {
            if descriptor_info.binding_count() == 0 {
                descriptor_info.set_binding_count(info.bindless_descriptor_count);
            }
        }

        let descriptor_info = PipelineDescriptorInfo::create(&device, &descriptor_bindings)?;
        let descriptor_set_layouts = descriptor_info
            .layouts
            .values()
            .map(|descriptor_set_layout| **descriptor_set_layout)
            .collect::<Box<[_]>>();

        // Validate the entry point name before any Vulkan object exists, so a malformed name is
        // reported as an error instead of panicking while a shader module is alive.
        let entry_name = CString::new(shader.entry_name.as_bytes()).map_err(|err| {
            warn!("{err}");
            DriverError::Unsupported
        })?;

        // SAFETY: every create-info structure below only borrows locals (`shader`,
        // `entry_name`, `specialization_info`, `descriptor_set_layouts`, `push_constants`) that
        // outlive the Vulkan calls using them. Every handle created here is either returned
        // inside `ComputePipeline` (destroyed in `Drop`) or destroyed before returning.
        unsafe {
            // NOTE(review): Vulkan requires `p_code` to point to 4-byte aligned SPIR-V words;
            // confirm the storage behind `shader.spirv` guarantees that alignment.
            let shader_module_create_info = vk::ShaderModuleCreateInfo {
                code_size: shader.spirv.len(),
                p_code: shader.spirv.as_ptr() as *const u32,
                ..Default::default()
            };
            let shader_module = device
                .create_shader_module(&shader_module_create_info, None)
                .map_err(|err| {
                    warn!("{err}");
                    DriverError::Unsupported
                })?;

            let mut stage_create_info = vk::PipelineShaderStageCreateInfo::builder()
                .module(shader_module)
                .stage(shader.stage)
                .name(&entry_name);
            let specialization_info = shader.specialization_info.as_ref().map(|info| {
                vk::SpecializationInfo::builder()
                    .map_entries(&info.map_entries)
                    .data(&info.data)
                    .build()
            });
            if let Some(specialization_info) = &specialization_info {
                stage_create_info = stage_create_info.specialization_info(specialization_info);
            }

            let mut layout_info =
                vk::PipelineLayoutCreateInfo::builder().set_layouts(&descriptor_set_layouts);
            let push_constants = shader.push_constant_range();
            if let Some(push_constants) = &push_constants {
                layout_info = layout_info.push_constant_ranges(from_ref(push_constants));
            }

            let layout = match device.create_pipeline_layout(&layout_info, None) {
                Ok(layout) => layout,
                Err(err) => {
                    warn!("{err}");

                    // Previously leaked: release the shader module before bailing out.
                    device.destroy_shader_module(shader_module, None);

                    return Err(DriverError::Unsupported);
                }
            };

            let pipeline_info = vk::ComputePipelineCreateInfo::builder()
                .stage(stage_create_info.build())
                .layout(layout);
            let pipelines = device.create_compute_pipelines(
                vk::PipelineCache::null(),
                from_ref(&pipeline_info.build()),
                None,
            );

            // The shader module is only needed while the pipeline is being created, so it is
            // released here regardless of whether creation succeeded.
            device.destroy_shader_module(shader_module, None);

            let pipeline = match pipelines {
                // Exactly one create info was submitted, so exactly one pipeline is returned.
                Ok(pipelines) => pipelines[0],
                Err((_, err)) => {
                    warn!("{err}");

                    // Previously leaked: the layout is not owned by anything yet.
                    device.destroy_pipeline_layout(layout, None);

                    return Err(DriverError::Unsupported);
                }
            };

            Ok(ComputePipeline {
                descriptor_bindings,
                descriptor_info,
                device,
                info,
                layout,
                pipeline,
                push_constants,
            })
        }
    }
}
impl Deref for ComputePipeline {
    type Target = vk::Pipeline;

    /// Borrows the underlying raw Vulkan pipeline handle.
    fn deref(&self) -> &vk::Pipeline {
        let Self { pipeline, .. } = self;

        pipeline
    }
}
impl Drop for ComputePipeline {
    /// Destroys the pipeline first and then its layout.
    ///
    /// Nothing is destroyed while the current thread is panicking; both handles are leaked in
    /// that case (presumably to avoid calling into the driver during unwinding).
    fn drop(&mut self) {
        if !panicking() {
            // SAFETY: both handles were created from `self.device` in
            // `ComputePipeline::create`, and this is the only place they are destroyed.
            unsafe {
                let device = &self.device;
                device.destroy_pipeline(self.pipeline, None);
                device.destroy_pipeline_layout(self.layout, None);
            }
        }
    }
}
/// Default number of descriptors given to shader bindings that report a binding count of zero.
///
/// Shared by the builder default and the [`Default`] impl so both construction paths agree.
const DEFAULT_BINDLESS_DESCRIPTOR_COUNT: u32 = 8192;

/// Information used to create a [`ComputePipeline`].
///
/// Build one with [`ComputePipelineInfoBuilder`] or start from [`Default::default`]; both give
/// the same defaults.
#[derive(Builder, Clone, Debug)]
#[builder(
    pattern = "owned",
    build_fn(
        private,
        name = "fallible_build",
        error = "ComputePipelineInfoBuilderError"
    )
)]
pub struct ComputePipelineInfo {
    /// Number of descriptors given to each shader descriptor binding that reports a binding
    /// count of zero (presumably runtime-sized / bindless arrays). Defaults to 8192.
    #[builder(default = "DEFAULT_BINDLESS_DESCRIPTOR_COUNT")]
    pub bindless_descriptor_count: u32,

    /// Optional name for this pipeline (not consumed by `ComputePipeline::create`).
    #[builder(default, setter(strip_option))]
    pub name: Option<String>,
}

// Hand-written instead of derived: a derived `Default` would set `bindless_descriptor_count`
// to 0, disagreeing with the builder's default. `ComputePipeline::create` then sizes
// zero-count bindings to 0 descriptors.
impl Default for ComputePipelineInfo {
    fn default() -> Self {
        Self {
            bindless_descriptor_count: DEFAULT_BINDLESS_DESCRIPTOR_COUNT,
            name: None,
        }
    }
}
impl From<ComputePipelineInfoBuilder> for ComputePipelineInfo {
fn from(info: ComputePipelineInfoBuilder) -> Self {
info.build()
}
}
impl ComputePipelineInfoBuilder {
    /// Builds a [`ComputePipelineInfo`], using the declared default for any unset field.
    ///
    /// # Panics
    ///
    /// Panics if the generated `fallible_build` reports an error. Every field of
    /// `ComputePipelineInfo` declares a `#[builder(default ...)]`, so this is not expected.
    pub fn build(self) -> ComputePipelineInfo {
        let result = self.fallible_build();

        result.expect("All required fields set at initialization")
    }
}
/// Error type returned by the private `fallible_build` function generated by `derive_builder`
/// for [`ComputePipelineInfoBuilder`].
///
/// Carries no data; `ComputePipelineInfoBuilder::build` expects it never to occur.
#[derive(Debug)]
struct ComputePipelineInfoBuilderError;
impl From<UninitializedFieldError> for ComputePipelineInfoBuilderError {
fn from(_: UninitializedFieldError) -> Self {
Self
}
}