use {
super::{
DriverError,
device::Device,
shader::{DescriptorBindingMap, PipelineDescriptorInfo, Shader, align_spriv},
},
ash::vk,
derive_builder::{Builder, UninitializedFieldError},
log::{trace, warn},
std::{ffi::CString, ops::Deref, sync::Arc, thread::panicking},
};
/// Owns a Vulkan compute pipeline (`vk::Pipeline`) together with the pipeline layout and
/// descriptor information it was created with.
///
/// The raw pipeline handle is reachable through `Deref`; the pipeline and its layout are
/// destroyed when this value is dropped.
#[derive(Debug)]
pub struct ComputePipeline {
/// Descriptor bindings reported by the shader. Bindings which reported a count of zero have
/// been resized to `info.bindless_descriptor_count` during creation.
pub(crate) descriptor_bindings: DescriptorBindingMap,
/// Descriptor information created from `descriptor_bindings`; its set layouts were used to
/// build `layout`.
pub(crate) descriptor_info: PipelineDescriptorInfo,
// Device which created `pipeline` and `layout`; also used to destroy them on drop.
device: Arc<Device>,
/// Pipeline layout handle created alongside the pipeline.
pub(crate) layout: vk::PipelineLayout,
/// Information used to create this pipeline.
pub info: ComputePipelineInfo,
/// Optional name of this pipeline; `None` until one is assigned with `with_name`.
pub name: Option<String>,
// Raw pipeline handle, exposed to callers through the `Deref` implementation.
pipeline: vk::Pipeline,
/// Push constant range reported by the shader, if any.
pub(crate) push_constants: Option<vk::PushConstantRange>,
}
impl ComputePipeline {
    /// Creates a new compute pipeline on the given device.
    ///
    /// Shader descriptors which report a binding count of zero are resized to
    /// `info.bindless_descriptor_count` before the descriptor set layouts and the pipeline
    /// layout are created.
    ///
    /// The temporary shader module is destroyed before this function returns, on success and on
    /// failure alike. The pipeline layout is destroyed if pipeline creation fails, so no Vulkan
    /// handle created here outlives an `Err` result.
    ///
    /// # Errors
    ///
    /// Returns [`DriverError::Unsupported`] (after logging the cause with `warn!`) if:
    ///
    /// - the shader entry name contains an interior NUL byte,
    /// - the shader module cannot be created,
    /// - the pipeline layout cannot be created, or
    /// - the compute pipeline cannot be created.
    ///
    /// Errors returned by `PipelineDescriptorInfo::create` and `align_spriv` are propagated
    /// unchanged.
    #[profiling::function]
    pub fn create(
        device: &Arc<Device>,
        info: impl Into<ComputePipelineInfo>,
        shader: impl Into<Shader>,
    ) -> Result<Self, DriverError> {
        use std::slice::from_ref;

        trace!("create");

        let device = Arc::clone(device);
        let info: ComputePipelineInfo = info.into();
        let shader = shader.into();

        // Descriptors which report no binding count get the configured bindless count
        let mut descriptor_bindings = shader.descriptor_bindings();
        for (descriptor_info, _) in descriptor_bindings.values_mut() {
            if descriptor_info.binding_count() == 0 {
                descriptor_info.set_binding_count(info.bindless_descriptor_count);
            }
        }

        let descriptor_info = PipelineDescriptorInfo::create(&device, &descriptor_bindings)?;
        let descriptor_set_layouts = descriptor_info
            .layouts
            .values()
            .map(|descriptor_set_layout| **descriptor_set_layout)
            .collect::<Box<[_]>>();

        // Validate the entry name before any Vulkan object exists so that a bad name is reported
        // as an error (instead of a panic) and there is nothing to clean up.
        let entry_name = CString::new(shader.entry_name.as_bytes()).map_err(|err| {
            warn!("{err}");

            DriverError::Unsupported
        })?;

        let specialization_info = shader.specialization_info.as_ref().map(|specialization| {
            vk::SpecializationInfo::default()
                .map_entries(&specialization.map_entries)
                .data(&specialization.data)
        });

        // SAFETY: Every create-info structure passed to the device below only references locals
        // which outlive the call, and every handle created in this block is either moved into
        // the returned `ComputePipeline` or destroyed before returning.
        unsafe {
            let shader_module = device
                .create_shader_module(
                    &vk::ShaderModuleCreateInfo::default().code(align_spriv(&shader.spirv)?),
                    None,
                )
                .map_err(|err| {
                    warn!("{err}");

                    DriverError::Unsupported
                })?;

            // NOTE: From here on every early exit must destroy `shader_module`.

            let mut stage_create_info = vk::PipelineShaderStageCreateInfo::default()
                .module(shader_module)
                .stage(shader.stage)
                .name(&entry_name);

            if let Some(specialization_info) = &specialization_info {
                stage_create_info = stage_create_info.specialization_info(specialization_info);
            }

            let mut layout_info =
                vk::PipelineLayoutCreateInfo::default().set_layouts(&descriptor_set_layouts);
            let push_constants = shader.push_constant_range();

            if let Some(push_constants) = &push_constants {
                layout_info = layout_info.push_constant_ranges(from_ref(push_constants));
            }

            let layout = match device.create_pipeline_layout(&layout_info, None) {
                Ok(layout) => layout,
                Err(err) => {
                    warn!("{err}");

                    device.destroy_shader_module(shader_module, None);

                    return Err(DriverError::Unsupported);
                }
            };

            let pipeline_info = vk::ComputePipelineCreateInfo::default()
                .stage(stage_create_info)
                .layout(layout);
            let pipelines = device.create_compute_pipelines(
                Device::pipeline_cache(&device),
                from_ref(&pipeline_info),
                None,
            );

            // The module is only needed while the pipeline is being created, so it is released
            // regardless of the outcome.
            device.destroy_shader_module(shader_module, None);

            let pipeline = match pipelines {
                Ok(pipelines) => pipelines[0],
                Err((_, err)) => {
                    warn!("{err}");

                    // Nothing will own the layout once we bail out, so release it here
                    device.destroy_pipeline_layout(layout, None);

                    return Err(DriverError::Unsupported);
                }
            };

            Ok(ComputePipeline {
                descriptor_bindings,
                descriptor_info,
                device,
                info,
                layout,
                name: None,
                pipeline,
                push_constants,
            })
        }
    }

    /// Assigns a name to the given pipeline and returns it.
    ///
    /// This is an associated function (not a method): call it as
    /// `ComputePipeline::with_name(pipeline, "name")`.
    pub fn with_name(mut this: Self, name: impl Into<String>) -> Self {
        this.name = Some(name.into());
        this
    }
}
impl Deref for ComputePipeline {
    type Target = vk::Pipeline;

    /// Borrows the raw Vulkan pipeline handle held by this wrapper.
    fn deref(&self) -> &Self::Target {
        let Self { pipeline, .. } = self;

        pipeline
    }
}
impl Drop for ComputePipeline {
    /// Destroys the pipeline and then its layout using the device which created them.
    ///
    /// Nothing is destroyed while the current thread is unwinding from a panic; in that case
    /// both handles are left untouched.
    #[profiling::function]
    fn drop(&mut self) {
        if !panicking() {
            let Self {
                device,
                layout,
                pipeline,
                ..
            } = self;

            // The pipeline is destroyed before the layout it was created with
            unsafe {
                device.destroy_pipeline(*pipeline, None);
                device.destroy_pipeline_layout(*layout, None);
            }
        }
    }
}
/// Information used to create a [`ComputePipeline`] instance.
///
/// A `ComputePipelineInfoBuilder` is derived for this type; its generated fallible build
/// function is private and named `fallible_build`.
#[derive(Builder, Clone, Copy, Debug, Eq, Hash, PartialEq)]
#[builder(
build_fn(
private,
name = "fallible_build",
error = "ComputePipelineInfoBuilderError"
),
derive(Clone, Copy, Debug),
pattern = "owned"
)]
#[non_exhaustive]
pub struct ComputePipelineInfo {
/// Binding count assigned to shader descriptors which report a binding count of zero when a
/// pipeline is created.
///
/// The builder uses `8192` when this field is not set.
#[builder(default = "8192")]
pub bindless_descriptor_count: u32,
}
impl ComputePipelineInfo {
    /// Converts this info into a builder pre-populated with the current field values.
    #[inline(always)]
    pub fn to_builder(self) -> ComputePipelineInfoBuilder {
        // Exhaustive destructuring: adding a field to the struct forces this function to be
        // updated as well.
        let Self {
            bindless_descriptor_count,
        } = self;

        ComputePipelineInfoBuilder {
            bindless_descriptor_count: Some(bindless_descriptor_count),
        }
    }
}
impl Default for ComputePipelineInfo {
    /// Returns an info whose `bindless_descriptor_count` is `8192`, the same value the builder
    /// uses when the field is left unset.
    fn default() -> Self {
        const DEFAULT_BINDLESS_DESCRIPTOR_COUNT: u32 = 8192;

        Self {
            bindless_descriptor_count: DEFAULT_BINDLESS_DESCRIPTOR_COUNT,
        }
    }
}
impl From<ComputePipelineInfoBuilder> for ComputePipelineInfo {
fn from(info: ComputePipelineInfoBuilder) -> Self {
info.build()
}
}
impl ComputePipelineInfoBuilder {
    /// Builds a new `ComputePipelineInfo` from the values set on this builder.
    ///
    /// Test builds use a checked `unwrap` on the result of the private `fallible_build`
    /// function; all other builds skip the check.
    #[inline(always)]
    pub fn build(self) -> ComputePipelineInfo {
        let outcome = self.fallible_build();

        if cfg!(test) {
            outcome.unwrap()
        } else {
            // SAFETY: This assumes `fallible_build` cannot fail, which relies on every field of
            // `ComputePipelineInfo` carrying a `#[builder(default = ...)]` value. The checked
            // branch above exists so that test builds catch a violation of that assumption.
            unsafe { outcome.unwrap_unchecked() }
        }
    }
}
/// Error type named by the builder's `build_fn(error = ...)` attribute and returned by the
/// private `fallible_build` function; it carries no information.
#[derive(Debug)]
struct ComputePipelineInfoBuilderError;
impl From<UninitializedFieldError> for ComputePipelineInfoBuilderError {
fn from(_: UninitializedFieldError) -> Self {
Self
}
}
#[cfg(test)]
mod tests {
    use super::*;

    type Info = ComputePipelineInfo;
    type Builder = ComputePipelineInfoBuilder;

    /// Converting the default info into a builder and building it yields the same info.
    #[test]
    pub fn compute_pipeline_info() {
        let info = Info::default();
        let builder = info.to_builder().build();

        assert_eq!(info, builder);
    }

    /// A builder with no fields set builds the same info as `Default::default`.
    #[test]
    pub fn compute_pipeline_info_builder() {
        let info = Info::default();
        let builder = Builder::default().build();

        assert_eq!(info, builder);
    }

    /// A non-default value survives the trip through `to_builder`.
    ///
    /// The default-only tests above would still pass if `to_builder` ignored the field values of
    /// the info it was called on.
    #[test]
    pub fn compute_pipeline_info_round_trip() {
        let info = Builder::default()
            .bindless_descriptor_count(42_u32)
            .build();

        assert_eq!(info.bindless_descriptor_count, 42);
        assert_ne!(info, Info::default());
        assert_eq!(info.to_builder().build(), info);
    }
}