use super::device::DeviceInner;
use super::pipeline::PipelineLayout;
use super::shader::ShaderModule;
use super::{Device, Error, Result, check};
use crate::raw::bindings::*;
use std::ffi::CString;
use std::sync::Arc;
/// One shader group of a ray tracing pipeline.
///
/// Every `u32` is an index into the `stages` slice handed to `RayTracingPipeline::new`
/// (the stages are passed to Vulkan in that order). An `Option` that is `None` leaves the
/// corresponding slot empty.
#[derive(Debug, Clone, Copy)]
pub enum ShaderGroup {
    /// A general group containing a single shader (ray generation, miss or callable).
    General { shader: u32 },
    /// A hit group for triangle geometry.
    TrianglesHit {
        closest_hit: Option<u32>,
        any_hit: Option<u32>,
    },
    /// A hit group for procedural geometry, which always needs an intersection shader.
    ProceduralHit {
        closest_hit: Option<u32>,
        any_hit: Option<u32>,
        intersection: u32,
    },
}
impl ShaderGroup {
    /// Lowers the group into a `VkRayTracingShaderGroupCreateInfoKHR`.
    ///
    /// Slots that the variant does not use are set to `!0` (`VK_SHADER_UNUSED_KHR`). No `pNext`
    /// chain and no capture/replay handle are attached.
    fn to_raw(self) -> VkRayTracingShaderGroupCreateInfoKHR {
        // Numeric value of VK_SHADER_UNUSED_KHR.
        const UNUSED: u32 = !0u32;
        let slot = |index: Option<u32>| index.unwrap_or(UNUSED);

        // Begin with a general group whose slots are all empty, then fill in what this
        // particular variant uses.
        let mut raw = VkRayTracingShaderGroupCreateInfoKHR {
            sType: VkStructureType::STRUCTURE_TYPE_RAY_TRACING_SHADER_GROUP_CREATE_INFO_KHR,
            pNext: std::ptr::null(),
            r#type: VkRayTracingShaderGroupTypeKHR::RAY_TRACING_SHADER_GROUP_TYPE_GENERAL_KHR,
            generalShader: UNUSED,
            closestHitShader: UNUSED,
            anyHitShader: UNUSED,
            intersectionShader: UNUSED,
            pShaderGroupCaptureReplayHandle: std::ptr::null(),
        };
        match self {
            ShaderGroup::General { shader } => raw.generalShader = shader,
            ShaderGroup::TrianglesHit {
                closest_hit,
                any_hit,
            } => {
                raw.r#type = VkRayTracingShaderGroupTypeKHR::RAY_TRACING_SHADER_GROUP_TYPE_TRIANGLES_HIT_GROUP_KHR;
                raw.closestHitShader = slot(closest_hit);
                raw.anyHitShader = slot(any_hit);
            }
            ShaderGroup::ProceduralHit {
                closest_hit,
                any_hit,
                intersection,
            } => {
                raw.r#type = VkRayTracingShaderGroupTypeKHR::RAY_TRACING_SHADER_GROUP_TYPE_PROCEDURAL_HIT_GROUP_KHR;
                raw.closestHitShader = slot(closest_hit);
                raw.anyHitShader = slot(any_hit);
                raw.intersectionShader = intersection;
            }
        }
        raw
    }
}
/// The kind of shader a ray tracing pipeline stage contains.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum RayTracingShaderStage {
    /// Ray generation shader.
    Raygen,
    /// Miss shader.
    Miss,
    /// Closest-hit shader.
    ClosestHit,
    /// Any-hit shader.
    AnyHit,
    /// Intersection shader.
    Intersection,
    /// Callable shader.
    Callable,
}
impl RayTracingShaderStage {
    /// Returns the matching `VK_SHADER_STAGE_*_BIT_KHR` flag value.
    #[inline]
    fn to_bit(self) -> u32 {
        // The ray tracing stage flags occupy bits 8 through 13 of VkShaderStageFlagBits.
        let bit_index = match self {
            RayTracingShaderStage::Raygen => 8,
            RayTracingShaderStage::AnyHit => 9,
            RayTracingShaderStage::ClosestHit => 10,
            RayTracingShaderStage::Miss => 11,
            RayTracingShaderStage::Intersection => 12,
            RayTracingShaderStage::Callable => 13,
        };
        1u32 << bit_index
    }
}
/// One shader stage of a ray tracing pipeline: the stage kind, the module holding its code,
/// and the entry point name inside that module.
///
/// The struct only borrows its inputs, so it is `Copy`. The same stage description can be reused
/// across several pipeline creations without rebuilding it.
#[derive(Clone, Copy)]
pub struct RayTracingStage<'a> {
    /// Which ray tracing shader stage this is.
    pub stage: RayTracingShaderStage,
    /// The shader module containing the entry point.
    pub module: &'a ShaderModule,
    /// Entry point name. It must not contain interior NUL bytes, because
    /// `RayTracingPipeline::new` converts it to a `CString`.
    pub entry_point: &'a str,
}
/// A strided device-memory region of shader binding table records.
///
/// Converts field-for-field into `VkStridedDeviceAddressRegionKHR`. `Default` yields an
/// all-zero region (size 0).
#[derive(Debug, Clone, Copy, Default)]
pub struct ShaderBindingRegion {
    /// Device address of the start of the region.
    pub address: u64,
    /// Byte stride between consecutive records.
    pub stride: u64,
    /// Size of the whole region in bytes.
    pub size: u64,
}
impl ShaderBindingRegion {
    /// Converts the region into its raw Vulkan representation.
    pub(crate) fn to_raw(self) -> VkStridedDeviceAddressRegionKHR {
        let Self {
            address,
            stride,
            size,
        } = self;
        VkStridedDeviceAddressRegionKHR {
            deviceAddress: address,
            stride,
            size,
        }
    }
}
/// An owned Vulkan ray tracing pipeline.
///
/// The wrapper keeps the `DeviceInner` it was created from alive (via `Arc`) for as long as
/// the pipeline exists. It destroys the pipeline with `vkDestroyPipeline` when dropped.
pub struct RayTracingPipeline {
    pub(crate) handle: VkPipeline,
    pub(crate) device: Arc<DeviceInner>,
    /// Number of shader groups the pipeline was created with.
    pub(crate) group_count: u32,
}
impl RayTracingPipeline {
    /// Creates a ray tracing pipeline from `stages` and `groups`.
    ///
    /// The stages are passed to Vulkan in the order given, so every shader index stored in a
    /// [`ShaderGroup`] is a position in `stages`. `max_recursion_depth` is forwarded unchanged
    /// as `maxPipelineRayRecursionDepth`. The following are not used: a deferred operation, a
    /// pipeline cache, specialization constants, pipeline libraries and dynamic state.
    ///
    /// # Errors
    ///
    /// * `Error::MissingFunction` if `vkCreateRayTracingPipelinesKHR` was not loaded.
    /// * The `?`-converted `NulError` if an entry point name contains an interior NUL byte.
    /// * Whatever `check` produces when the driver returns a failing result.
    ///
    /// # Panics
    ///
    /// Panics if a shader group references a stage index `>= stages.len()`, except for the
    /// `VK_SHADER_UNUSED_KHR` sentinel. Handing such an index to the driver is invalid usage.
    pub fn new(
        device: &Device,
        layout: &PipelineLayout,
        stages: &[RayTracingStage<'_>],
        groups: &[ShaderGroup],
        max_recursion_depth: u32,
    ) -> Result<Self> {
        let create = device
            .inner
            .dispatch
            .vkCreateRayTracingPipelinesKHR
            .ok_or(Error::MissingFunction("vkCreateRayTracingPipelinesKHR"))?;
        // These CStrings must outlive the create call: each `pName` below points into one.
        let entry_cstrs: Vec<CString> = stages
            .iter()
            .map(|s| CString::new(s.entry_point))
            .collect::<std::result::Result<_, _>>()?;
        let raw_stages: Vec<VkPipelineShaderStageCreateInfo> = stages
            .iter()
            .zip(entry_cstrs.iter())
            .map(|(s, name)| VkPipelineShaderStageCreateInfo {
                sType: VkStructureType::STRUCTURE_TYPE_PIPELINE_SHADER_STAGE_CREATE_INFO,
                pNext: std::ptr::null(),
                flags: 0,
                stage: s.stage.to_bit(),
                module: s.module.raw(),
                pName: name.as_ptr(),
                pSpecializationInfo: std::ptr::null(),
            })
            .collect();
        let raw_groups: Vec<VkRayTracingShaderGroupCreateInfoKHR> =
            groups.iter().map(|g| g.to_raw()).collect();

        // Reject out-of-range stage indices before they reach the driver. Passing one there is
        // invalid usage, which is undefined behavior rather than a reported error.
        // VK_SHADER_UNUSED_KHR marks an empty slot and is always allowed.
        const SHADER_UNUSED: u32 = !0u32;
        for (group_index, g) in raw_groups.iter().enumerate() {
            let slots = [
                g.generalShader,
                g.closestHitShader,
                g.anyHitShader,
                g.intersectionShader,
            ];
            for &stage_index in &slots {
                assert!(
                    stage_index == SHADER_UNUSED || (stage_index as usize) < raw_stages.len(),
                    "shader group {} references stage index {}, but only {} stages were given",
                    group_index,
                    stage_index,
                    raw_stages.len()
                );
            }
        }

        let info = VkRayTracingPipelineCreateInfoKHR {
            sType: VkStructureType::STRUCTURE_TYPE_RAY_TRACING_PIPELINE_CREATE_INFO_KHR,
            pNext: std::ptr::null(),
            flags: 0,
            stageCount: raw_stages.len() as u32,
            pStages: if raw_stages.is_empty() {
                std::ptr::null()
            } else {
                raw_stages.as_ptr()
            },
            groupCount: raw_groups.len() as u32,
            pGroups: if raw_groups.is_empty() {
                std::ptr::null()
            } else {
                raw_groups.as_ptr()
            },
            maxPipelineRayRecursionDepth: max_recursion_depth,
            pLibraryInfo: std::ptr::null(),
            pLibraryInterface: std::ptr::null(),
            pDynamicState: std::ptr::null(),
            layout: layout.raw(),
            basePipelineHandle: 0,
            basePipelineIndex: -1,
        };
        let mut handle: VkPipeline = 0;
        // SAFETY: `info` and everything it points to (`raw_stages`, `raw_groups`, `entry_cstrs`)
        // are alive for the whole call. `handle` is a valid out-slot for exactly one pipeline,
        // matching the create-info count of 1.
        check(unsafe {
            create(
                device.inner.handle,
                0, // no deferred operation
                0, // no pipeline cache
                1, // one create info
                &info,
                std::ptr::null(),
                &mut handle,
            )
        })?;
        Ok(Self {
            handle,
            device: Arc::clone(&device.inner),
            group_count: raw_groups.len() as u32,
        })
    }

    /// Returns the raw `VkPipeline` handle.
    pub fn raw(&self) -> VkPipeline {
        self.handle
    }

    /// Returns the number of shader groups the pipeline was created with.
    pub fn group_count(&self) -> u32 {
        self.group_count
    }

    /// Copies the opaque handles of shader groups `first_group..first_group + group_count`
    /// into `dst` via `vkGetRayTracingShaderGroupHandlesKHR`.
    ///
    /// `dst.len()` is passed as the destination size in bytes. Vulkan requires that size to be
    /// at least `shaderGroupHandleSize * group_count`. This method does not check that
    /// requirement, because the pipeline does not know the device's handle size.
    ///
    /// # Errors
    ///
    /// * `Error::MissingFunction` if `vkGetRayTracingShaderGroupHandlesKHR` was not loaded.
    /// * Whatever `check` produces when the driver returns a failing result.
    ///
    /// # Panics
    ///
    /// Panics in any of these cases, each of which would be invalid usage of the Vulkan call:
    ///
    /// * `first_group >= self.group_count()`.
    /// * The requested range extends past `self.group_count()`.
    /// * `dst` is empty.
    pub fn get_shader_group_handles(
        &self,
        first_group: u32,
        group_count: u32,
        dst: &mut [u8],
    ) -> Result<()> {
        let f = self
            .device
            .dispatch
            .vkGetRayTracingShaderGroupHandlesKHR
            .ok_or(Error::MissingFunction(
                "vkGetRayTracingShaderGroupHandlesKHR",
            ))?;
        assert!(
            first_group < self.group_count,
            "first_group {} is out of range: pipeline has {} shader groups",
            first_group,
            self.group_count
        );
        // Cannot underflow: first_group < self.group_count was asserted above.
        assert!(
            group_count <= self.group_count - first_group,
            "requested {} shader groups starting at {}, but pipeline has only {}",
            group_count,
            first_group,
            self.group_count
        );
        assert!(
            !dst.is_empty(),
            "destination buffer for shader group handles must not be empty"
        );
        // SAFETY: `self.handle` is a live pipeline created on `self.device`. The group range
        // was validated above. `dst` is a valid, writable buffer of exactly `dst.len()` bytes.
        check(unsafe {
            f(
                self.device.handle,
                self.handle,
                first_group,
                group_count,
                dst.len(),
                dst.as_mut_ptr() as *mut std::ffi::c_void,
            )
        })
    }
}
impl Drop for RayTracingPipeline {
    fn drop(&mut self) {
        // If `vkDestroyPipeline` was never loaded there is nothing to call, and the handle leaks.
        if let Some(destroy) = self.device.dispatch.vkDestroyPipeline {
            // SAFETY: `self.handle` was created on `self.device` by `new`, and it is destroyed
            // only here, once. NOTE(review): nothing in this type ensures the GPU has finished
            // using the pipeline; that is left to the caller.
            unsafe { destroy(self.device.handle, self.handle, std::ptr::null()) };
        }
    }
}
/// Ray tracing pipeline limits of a physical device.
///
/// NOTE(review): the field names mirror `VkPhysicalDeviceRayTracingPipelinePropertiesKHR`, so
/// this is presumably filled from that struct. Its `shaderGroupHandleCaptureReplaySize` has no
/// counterpart here. Confirm against the code that populates it.
///
/// `PartialEq`/`Eq` let callers compare limits, for example between physical devices.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
pub struct RayTracingPipelineProperties {
    /// Size of one shader group handle.
    pub shader_group_handle_size: u32,
    /// Maximum ray recursion depth.
    pub max_ray_recursion_depth: u32,
    /// Maximum stride between shader binding table records.
    pub max_shader_group_stride: u32,
    /// Required alignment for the base of a shader binding table.
    pub shader_group_base_alignment: u32,
    /// Required alignment for each shader group handle within a table.
    pub shader_group_handle_alignment: u32,
    /// Maximum number of invocations in a single ray dispatch.
    pub max_ray_dispatch_invocation_count: u32,
    /// Maximum size of ray hit attributes.
    pub max_ray_hit_attribute_size: u32,
}