use {
super::{
DriverError,
device::Device,
shader::{DescriptorBindingMap, PipelineDescriptorInfo, Shader},
},
ash::vk,
derive_builder::{Builder, UninitializedFieldError},
log::{trace, warn},
std::{
ffi::CString,
hash::{Hash, Hasher},
slice,
sync::{Arc, OnceLock},
thread::panicking,
},
};
/// A Vulkan compute pipeline together with its layout and descriptor information.
///
/// Clones are cheap and share the same underlying pipeline through reference counting.
#[derive(Debug, Clone)]
pub struct ComputePipeline {
    /// Reference-counted state shared by every clone of this pipeline.
    pub(crate) inner: Arc<ComputePipelineInner>,
}
impl ComputePipeline {
    /// Creates a compute pipeline for `shader` on `device`.
    ///
    /// Any descriptor binding that the shader reports with a binding count of zero is given
    /// [`ComputePipelineInfo::bindless_descriptor_count`] descriptors before the descriptor set
    /// layouts are created.
    ///
    /// # Errors
    ///
    /// Returns [`DriverError::Unsupported`] if the shader module, the pipeline layout or the
    /// pipeline cannot be created. Errors from [`PipelineDescriptorInfo::create`] are propagated
    /// unchanged. On failure, the shader module and pipeline layout created by this function are
    /// destroyed before returning.
    ///
    /// # Panics
    ///
    /// Panics if the shader entry point name contains an interior NUL byte.
    #[profiling::function]
    pub fn create(
        device: &Device,
        info: impl Into<ComputePipelineInfo>,
        shader: impl Into<Shader>,
    ) -> Result<Self, DriverError> {
        trace!("create");

        let info = info.into();
        let shader = shader.into();

        // A reflected binding count of zero presumably marks an unsized (bindless) descriptor
        // array; give such bindings the configured count so their layouts can be created.
        let mut descriptor_bindings = shader.descriptor_bindings();
        for (descriptor_info, _) in descriptor_bindings.values_mut() {
            if descriptor_info.binding_count() == 0 {
                descriptor_info.set_binding_count(info.bindless_descriptor_count);
            }
        }

        let descriptor_info = PipelineDescriptorInfo::create(device, &descriptor_bindings)?;
        let descriptor_set_layouts = descriptor_info
            .layouts
            .values()
            .map(|descriptor_set_layout| descriptor_set_layout.handle)
            .collect::<Box<_>>();

        // Build the entry point name before any Vulkan object is created, so that a panic on an
        // invalid name cannot leak the shader module.
        let entry_name = CString::new(shader.entry_name.as_bytes()).expect("invalid entry name");

        // SAFETY: every handle handed to Vulkan below was created from `device` (directly or via
        // `PipelineDescriptorInfo::create`), and every create-info struct only borrows locals that
        // stay alive until the Vulkan calls using them have returned.
        unsafe {
            let shader_module = device
                .create_shader_module(
                    &vk::ShaderModuleCreateInfo::default().code(shader.spirv.words()),
                    None,
                )
                .map_err(|err| {
                    warn!("unable to create compute shader module: {err}");
                    DriverError::Unsupported
                })?;

            let mut stage_create_info = vk::PipelineShaderStageCreateInfo::default()
                .module(shader_module)
                .stage(shader.stage)
                .name(&entry_name);
            let specialization_info = shader.specialization.as_ref().map(Into::into);
            if let Some(specialization_info) = &specialization_info {
                stage_create_info = stage_create_info.specialization_info(specialization_info);
            }

            let mut layout_info =
                vk::PipelineLayoutCreateInfo::default().set_layouts(&descriptor_set_layouts);
            let push_constants = shader.push_constant_range();
            if let Some(push_constants) = &push_constants {
                layout_info = layout_info.push_constant_ranges(slice::from_ref(push_constants));
            }

            let layout = device
                .create_pipeline_layout(&layout_info, None)
                .map_err(|err| {
                    warn!("unable to create compute pipeline layout: {err}");
                    device.destroy_shader_module(shader_module, None);
                    DriverError::Unsupported
                })?;

            let create_info = vk::ComputePipelineCreateInfo::default()
                .stage(stage_create_info)
                .layout(layout);
            let handle = device
                .create_compute_pipelines(
                    Device::pipeline_cache(device),
                    slice::from_ref(&create_info),
                    None,
                )
                .map_err(|(_, err)| {
                    warn!("unable to create compute pipeline: {err}");
                    // Nothing owns the layout yet, so it must be released here or it leaks.
                    device.destroy_pipeline_layout(layout, None);
                    device.destroy_shader_module(shader_module, None);
                    DriverError::Unsupported
                })?[0];

            // The shader module is only needed while the pipeline is being created.
            device.destroy_shader_module(shader_module, None);

            Ok(ComputePipeline {
                inner: Arc::new(ComputePipelineInner {
                    descriptor_bindings,
                    descriptor_info,
                    device: device.clone(),
                    handle,
                    info,
                    layout,
                    name: Default::default(),
                    push_constants,
                }),
            })
        }
    }

    /// Returns the debug name previously set on this pipeline, if any.
    pub fn debug_name(&self) -> Option<&str> {
        self.inner.name.get().map(String::as_str)
    }

    /// Returns the device this pipeline was created on.
    pub fn device(&self) -> &Device {
        &self.inner.device
    }

    /// Returns the raw Vulkan pipeline handle.
    pub fn handle(&self) -> vk::Pipeline {
        self.inner.handle
    }

    /// Returns a copy of the info this pipeline was created with.
    pub fn info(&self) -> ComputePipelineInfo {
        self.inner.info
    }

    /// Sets the debug name of this pipeline.
    ///
    /// Does nothing unless `info.debug` is set on the device's instance. The name lives in state
    /// shared by all clones of this pipeline and can only be set once; later calls are ignored.
    pub fn set_debug_name(&mut self, name: impl Into<String>) {
        if !self.inner.device.physical_device.instance.info.debug {
            return;
        }

        // `OnceLock::set` fails if a name is already present; keeping the first name is intended.
        let _ = self.inner.name.set(name.into());
    }

    /// Builder-style variant of [`Self::set_debug_name`].
    pub fn with_debug_name(mut self, name: impl Into<String>) -> Self {
        self.set_debug_name(name);
        self
    }
}
impl Eq for ComputePipeline {}
impl Hash for ComputePipeline {
fn hash<H: Hasher>(&self, state: &mut H) {
Arc::as_ptr(&self.inner).hash(state);
}
}
impl PartialEq for ComputePipeline {
fn eq(&self, other: &Self) -> bool {
Arc::ptr_eq(&self.inner, &other.inner)
}
}
/// Parameters used when creating a [`ComputePipeline`].
///
/// Construct one with [`ComputePipelineInfo::builder`] or through [`Default`].
#[derive(Builder, Clone, Copy, Debug, Eq, Hash, PartialEq)]
#[builder(
    pattern = "owned",
    derive(Clone, Copy, Debug),
    build_fn(private, name = "fallible_build", error = "UninitializedFieldError")
)]
pub struct ComputePipelineInfo {
    /// Number of descriptors given to any shader binding whose reflected binding count is zero
    /// (presumably an unsized descriptor array). Defaults to `8192`.
    #[builder(default = "8192")]
    pub bindless_descriptor_count: u32,
}
impl ComputePipelineInfo {
pub fn builder() -> ComputePipelineInfoBuilder {
Default::default()
}
pub fn into_builder(self) -> ComputePipelineInfoBuilder {
ComputePipelineInfoBuilder {
bindless_descriptor_count: Some(self.bindless_descriptor_count),
}
}
#[deprecated = "use into_builder function"]
#[doc(hidden)]
pub fn to_builder(self) -> ComputePipelineInfoBuilder {
self.into_builder()
}
}
/// Uses the same bindless descriptor count (`8192`) as the builder default.
impl Default for ComputePipelineInfo {
    fn default() -> Self {
        let bindless_descriptor_count = 8192;

        Self {
            bindless_descriptor_count,
        }
    }
}
impl From<ComputePipelineInfoBuilder> for ComputePipelineInfo {
fn from(info: ComputePipelineInfoBuilder) -> Self {
info.build()
}
}
impl ComputePipelineInfoBuilder {
    /// Builds a [`ComputePipelineInfo`], filling unset fields with their builder defaults.
    ///
    /// # Panics
    ///
    /// Panics if the underlying fallible build reports an uninitialized field. The only field,
    /// `bindless_descriptor_count`, has a builder default, so this is not expected in practice.
    #[inline(always)]
    pub fn build(self) -> ComputePipelineInfo {
        let result = self.fallible_build();
        result.expect("invalid compute pipeline info")
    }
}
/// State shared by all clones of a [`ComputePipeline`].
///
/// The Vulkan pipeline and pipeline layout are destroyed when this value is dropped, unless the
/// thread is panicking at that moment.
///
/// NOTE: after `Drop::drop` runs, Rust drops fields in declaration order, so reordering these
/// fields changes teardown order.
#[derive(Debug)]
pub(crate) struct ComputePipelineInner {
/// Descriptor bindings reflected from the shader, with zero binding counts replaced by
/// `info.bindless_descriptor_count`.
pub descriptor_bindings: DescriptorBindingMap,
/// Descriptor set layouts created from `descriptor_bindings`.
pub descriptor_info: PipelineDescriptorInfo,
/// Device that created `handle` and `layout`; also used to destroy them.
pub device: Device,
/// Raw Vulkan compute pipeline handle.
pub handle: vk::Pipeline,
/// Info the pipeline was created with.
pub info: ComputePipelineInfo,
/// Pipeline layout built from the descriptor set layouts and the optional push constant range.
pub layout: vk::PipelineLayout,
/// Optional debug name; can be set at most once.
pub name: OnceLock<String>,
/// Push constant range reported by the shader, if any.
pub push_constants: Option<vk::PushConstantRange>,
}
impl Drop for ComputePipelineInner {
    /// Destroys the pipeline first and its layout second.
    ///
    /// Cleanup is skipped entirely while the current thread is panicking.
    #[profiling::function]
    fn drop(&mut self) {
        if !panicking() {
            // SAFETY: `handle` and `layout` were created from `self.device`, and this value is
            // their sole owner. NOTE(review): Vulkan also requires that no pending command buffer
            // still uses them; that is not checked here.
            unsafe {
                self.device.destroy_pipeline(self.handle, None);
                self.device.destroy_pipeline_layout(self.layout, None);
            }
        }
    }
}
mod deprecated {
use crate::driver::compute::ComputePipeline;
impl ComputePipeline {
#[deprecated = "use with_debug_name function"]
#[doc(hidden)]
pub fn with_name(this: Self, name: impl Into<String>) -> Self {
this.with_debug_name(name)
}
}
}
#[cfg(test)]
mod test {
    use super::*;

    /// Converting the default info into a builder and back must round-trip exactly.
    #[test]
    pub fn compute_pipeline_info() {
        let expected = ComputePipelineInfo::default();
        let round_tripped = expected.into_builder().build();

        assert_eq!(expected, round_tripped);
    }

    /// An empty builder must produce the same value as `Default`.
    #[test]
    pub fn compute_pipeline_info_builder() {
        let built = ComputePipelineInfoBuilder::default().build();

        assert_eq!(ComputePipelineInfo::default(), built);
    }
}