use {
super::{Buffer, BufferInfo, DriverError, device::Device},
ash::vk,
derive_builder::{Builder, UninitializedFieldError},
log::warn,
std::{
ffi::c_void,
mem::{replace, size_of_val},
ops::Deref,
sync::Arc,
thread::panicking,
},
vk_sync::AccessType,
};
#[cfg(feature = "parking_lot")]
use parking_lot::Mutex;
#[cfg(not(feature = "parking_lot"))]
use std::sync::Mutex;
/// A Vulkan acceleration structure together with the device buffer that provides its storage.
///
/// Created with [`AccelerationStructure::create`]. The Vulkan handle is destroyed when this value
/// is dropped, unless the current thread is panicking at that time.
#[derive(Debug)]
pub struct AccelerationStructure {
    // Most recently recorded access type; starts as `AccessType::Nothing` and is swapped by
    // `AccelerationStructure::access`.
    access: Mutex<AccessType>,
    // The raw Vulkan handle and the buffer it was created in.
    accel_struct: (vk::AccelerationStructureKHR, Buffer),
    // Device used to create the handle and, in `Drop`, to destroy it.
    device: Arc<Device>,
    /// The information this acceleration structure was created with.
    pub info: AccelerationStructureInfo,
}
impl AccelerationStructure {
    /// Creates a new acceleration structure described by `info`.
    ///
    /// A device-memory buffer of `info.size` bytes is allocated with
    /// `ACCELERATION_STRUCTURE_STORAGE_KHR | SHADER_DEVICE_ADDRESS` usage, and the acceleration
    /// structure is created inside it with the same size and `info.ty` as its type.
    ///
    /// # Errors
    ///
    /// Propagates any error from creating the backing buffer. If the driver fails to create the
    /// acceleration structure, the Vulkan result is logged at `warn` level and mapped:
    /// `ERROR_INVALID_OPAQUE_CAPTURE_ADDRESS` becomes [`DriverError::InvalidData`],
    /// `ERROR_OUT_OF_HOST_MEMORY` becomes [`DriverError::OutOfMemory`], and anything else becomes
    /// [`DriverError::Unsupported`].
    #[profiling::function]
    pub fn create(
        device: &Arc<Device>,
        info: impl Into<AccelerationStructureInfo>,
    ) -> Result<Self, DriverError> {
        debug_assert!(device.physical_device.accel_struct_properties.is_some());

        let info: AccelerationStructureInfo = info.into();
        let storage_usage = vk::BufferUsageFlags::ACCELERATION_STRUCTURE_STORAGE_KHR
            | vk::BufferUsageFlags::SHADER_DEVICE_ADDRESS;
        let buffer = Buffer::create(device, BufferInfo::device_mem(info.size, storage_usage))?;

        let create_info = vk::AccelerationStructureCreateInfoKHR::default()
            .ty(info.ty)
            .buffer(*buffer)
            .size(info.size);
        let accel_struct_ext = Device::expect_accel_struct_ext(device);

        // SAFETY: `create_info` is fully initialized and names a live buffer that was created just
        // above on the same device, with acceleration-structure storage usage.
        let created = unsafe { accel_struct_ext.create_acceleration_structure(&create_info, None) };
        let handle = match created {
            Ok(handle) => handle,
            Err(err) => {
                warn!("{err}");

                return Err(match err {
                    vk::Result::ERROR_INVALID_OPAQUE_CAPTURE_ADDRESS => DriverError::InvalidData,
                    vk::Result::ERROR_OUT_OF_HOST_MEMORY => DriverError::OutOfMemory,
                    _ => DriverError::Unsupported,
                });
            }
        };

        Ok(Self {
            access: Mutex::new(AccessType::Nothing),
            accel_struct: (handle, buffer),
            device: Arc::clone(device),
            info,
        })
    }

    /// Records `access` as the current access type and returns the previously recorded one.
    ///
    /// # Panics
    ///
    /// Without the `parking_lot` feature, panics if the internal `std::sync::Mutex` is poisoned.
    #[profiling::function]
    pub fn access(this: &Self, access: AccessType) -> AccessType {
        #[cfg(feature = "parking_lot")]
        let mut guard = this.access.lock();

        #[cfg(not(feature = "parking_lot"))]
        let mut guard = this.access.lock().unwrap();

        replace(&mut *guard, access)
    }

    /// Returns the device address of this acceleration structure, as reported by the driver.
    #[profiling::function]
    pub fn device_address(this: &Self) -> vk::DeviceAddress {
        let accel_struct_ext = Device::expect_accel_struct_ext(&this.device);
        let address_info = vk::AccelerationStructureDeviceAddressInfoKHR::default()
            .acceleration_structure(this.accel_struct.0);

        // SAFETY: the handle is owned by `this` and is only destroyed in `Drop`, so it is still
        // valid while `this` is borrowed.
        unsafe { accel_struct_ext.get_acceleration_structure_device_address(&address_info) }
    }

    /// Views a slice of instance records as raw bytes, for example to upload into a buffer.
    ///
    /// The returned slice borrows `instances` and is exactly `size_of_val(instances)` bytes long.
    pub fn instance_slice(instances: &[vk::AccelerationStructureInstanceKHR]) -> &[u8] {
        use std::slice::from_raw_parts;

        let byte_len = size_of_val(instances);
        let bytes = instances.as_ptr().cast::<u8>();

        // SAFETY: `bytes` points at the start of `instances`, which covers exactly `byte_len`
        // bytes of one allocation and outlives the returned borrow; `u8` has alignment 1.
        // NOTE(review): this also assumes `vk::AccelerationStructureInstanceKHR` contains no
        // padding bytes (uninitialized padding must not be read as `u8`) — confirm against ash.
        unsafe { from_raw_parts(bytes, byte_len) }
    }

    /// Asks the driver how much memory building the geometries in `info` requires.
    ///
    /// Sizes are queried with `HOST_OR_DEVICE` build type, using each geometry's
    /// `max_primitive_count`. The result holds the acceleration structure size (`create_size`)
    /// and the scratch sizes for a build (`build_size`) and an update (`update_size`).
    #[profiling::function]
    pub fn size_of(
        device: &Device,
        info: &AccelerationStructureGeometryInfo<impl AsRef<AccelerationStructureGeometry>>,
    ) -> AccelerationStructureSize {
        use std::cell::RefCell;

        /// Per-thread scratch vectors, reused between calls so their allocations are kept.
        #[derive(Default)]
        struct Scratch {
            geometries: Vec<vk::AccelerationStructureGeometryKHR<'static>>,
            max_primitive_counts: Vec<u32>,
        }

        thread_local! {
            static SCRATCH: RefCell<Scratch> = Default::default();
        }

        SCRATCH.with_borrow_mut(|scratch| {
            scratch.geometries.clear();
            scratch.max_primitive_counts.clear();

            for geometry in info.geometries.iter().map(AsRef::as_ref) {
                scratch.geometries.push(geometry.into());
                scratch.max_primitive_counts.push(geometry.max_primitive_count);
            }

            let build_info = vk::AccelerationStructureBuildGeometryInfoKHR::default()
                .ty(info.ty)
                .flags(info.flags)
                .geometries(&scratch.geometries);
            let mut sizes = vk::AccelerationStructureBuildSizesInfoKHR::default();
            let accel_struct_ext = Device::expect_accel_struct_ext(device);

            // SAFETY: `build_info` and `max_primitive_counts` were filled by the same loop, so
            // there is exactly one primitive count per geometry.
            unsafe {
                accel_struct_ext.get_acceleration_structure_build_sizes(
                    vk::AccelerationStructureBuildTypeKHR::HOST_OR_DEVICE,
                    &build_info,
                    &scratch.max_primitive_counts,
                    &mut sizes,
                );
            }

            AccelerationStructureSize {
                create_size: sizes.acceleration_structure_size,
                build_size: sizes.build_scratch_size,
                update_size: sizes.update_scratch_size,
            }
        })
    }
}
impl Deref for AccelerationStructure {
    type Target = vk::AccelerationStructureKHR;

    /// Exposes the raw Vulkan acceleration structure handle.
    fn deref(&self) -> &Self::Target {
        let (handle, _buffer) = &self.accel_struct;
        handle
    }
}
impl Drop for AccelerationStructure {
    /// Destroys the Vulkan acceleration structure handle.
    ///
    /// If the current thread is panicking, the handle is intentionally left alive (leaked).
    /// The backing buffer is a field, so it is dropped separately after this method returns.
    #[profiling::function]
    fn drop(&mut self) {
        if !panicking() {
            let handle = self.accel_struct.0;
            let accel_struct_ext = Device::expect_accel_struct_ext(&self.device);

            // SAFETY: this is the only place the handle is destroyed, and `self` is being dropped,
            // so no further use of the handle can follow.
            unsafe {
                accel_struct_ext.destroy_acceleration_structure(handle, None);
            }
        }
    }
}
/// One geometry entry of an acceleration structure build: its data, its flags, and the maximum
/// number of primitives it may contain.
#[derive(Clone, Copy, Debug)]
pub struct AccelerationStructureGeometry {
    /// Maximum primitive count for this geometry; passed to the driver by
    /// [`AccelerationStructure::size_of`] when querying build sizes.
    pub max_primitive_count: u32,
    /// Vulkan geometry flags (for example `vk::GeometryFlagsKHR::OPAQUE`).
    pub flags: vk::GeometryFlagsKHR,
    /// Location and layout of the geometry data.
    pub geometry: AccelerationStructureGeometryData,
}
impl AccelerationStructureGeometry {
    /// Creates a geometry entry with empty (default) geometry flags.
    pub fn new(max_primitive_count: u32, geometry: AccelerationStructureGeometryData) -> Self {
        Self {
            max_primitive_count,
            flags: Default::default(),
            geometry,
        }
    }

    /// Creates a geometry entry whose flags are exactly `vk::GeometryFlagsKHR::OPAQUE`.
    pub fn opaque(max_primitive_count: u32, geometry: AccelerationStructureGeometryData) -> Self {
        Self {
            flags: vk::GeometryFlagsKHR::OPAQUE,
            ..Self::new(max_primitive_count, geometry)
        }
    }

    /// Returns this geometry with its flags replaced by `flags`.
    pub fn flags(self, flags: vk::GeometryFlagsKHR) -> Self {
        Self { flags, ..self }
    }
}
impl<T> AsRef<AccelerationStructureGeometry> for (AccelerationStructureGeometry, T) {
    /// Borrows the geometry half of the pair; the associated value is ignored.
    fn as_ref(&self) -> &AccelerationStructureGeometry {
        let (geometry, _) = self;
        geometry
    }
}
impl<'b> From<&'b AccelerationStructureGeometry> for vk::AccelerationStructureGeometryKHR<'_> {
    /// Copies the geometry and delegates to the by-value conversion.
    fn from(geometry: &'b AccelerationStructureGeometry) -> Self {
        Self::from(*geometry)
    }
}
impl From<AccelerationStructureGeometry> for vk::AccelerationStructureGeometryKHR<'_> {
fn from(value: AccelerationStructureGeometry) -> Self {
Self::default()
.flags(value.flags)
.geometry(value.geometry.into())
.geometry_type(value.geometry.into())
}
}
/// Location and layout of the data for one [`AccelerationStructureGeometry`].
///
/// Each variant converts into the matching `vk::GeometryTypeKHR` and into the corresponding
/// member of the `vk::AccelerationStructureGeometryDataKHR` union.
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
pub enum AccelerationStructureGeometryData {
    /// Axis-aligned bounding boxes read from `addr` with the given `stride` (in bytes).
    AABBs {
        addr: DeviceOrHostAddress,
        stride: vk::DeviceSize,
    },
    /// Instance data read from `addr`. When `array_of_pointers` is set, Vulkan reads `addr` as an
    /// array of addresses of instances rather than an array of instances.
    Instances {
        addr: DeviceOrHostAddress,
        array_of_pointers: bool,
    },
    /// Triangle data. When `transform_addr` is `None`, the conversion to Vulkan passes the
    /// `Default` value of `vk::DeviceOrHostAddressConstKHR` as the transform address.
    Triangles {
        index_addr: DeviceOrHostAddress,
        index_type: vk::IndexType,
        max_vertex: u32,
        transform_addr: Option<DeviceOrHostAddress>,
        vertex_addr: DeviceOrHostAddress,
        vertex_format: vk::Format,
        vertex_stride: vk::DeviceSize,
    },
}
impl AccelerationStructureGeometryData {
    /// Describes axis-aligned bounding boxes at `addr` with the given byte `stride`.
    pub fn aabbs(addr: impl Into<DeviceOrHostAddress>, stride: vk::DeviceSize) -> Self {
        Self::AABBs {
            addr: addr.into(),
            stride,
        }
    }

    /// Describes an array of instances stored directly at `addr`.
    pub fn instances(addr: impl Into<DeviceOrHostAddress>) -> Self {
        Self::Instances {
            addr: addr.into(),
            array_of_pointers: false,
        }
    }

    /// Describes instances referenced through an array of addresses stored at `addr`.
    pub fn instance_pointers(addr: impl Into<DeviceOrHostAddress>) -> Self {
        Self::Instances {
            addr: addr.into(),
            array_of_pointers: true,
        }
    }

    /// Describes triangle geometry.
    ///
    /// `transform_addr` accepts either a `DeviceOrHostAddress` or an `Option` of one; pass
    /// `None` for no transform.
    pub fn triangles(
        index_addr: impl Into<DeviceOrHostAddress>,
        index_type: vk::IndexType,
        max_vertex: u32,
        transform_addr: impl Into<Option<DeviceOrHostAddress>>,
        vertex_addr: impl Into<DeviceOrHostAddress>,
        vertex_format: vk::Format,
        vertex_stride: vk::DeviceSize,
    ) -> Self {
        // Field order matches the original conversion order: index, transform, then vertex.
        Self::Triangles {
            index_addr: index_addr.into(),
            index_type,
            max_vertex,
            transform_addr: transform_addr.into(),
            vertex_addr: vertex_addr.into(),
            vertex_format,
            vertex_stride,
        }
    }
}
impl From<AccelerationStructureGeometryData> for vk::GeometryTypeKHR {
fn from(value: AccelerationStructureGeometryData) -> Self {
match value {
AccelerationStructureGeometryData::AABBs { .. } => Self::AABBS,
AccelerationStructureGeometryData::Instances { .. } => Self::INSTANCES,
AccelerationStructureGeometryData::Triangles { .. } => Self::TRIANGLES,
}
}
}
impl From<AccelerationStructureGeometryData> for vk::AccelerationStructureGeometryDataKHR<'_> {
fn from(value: AccelerationStructureGeometryData) -> Self {
match value {
AccelerationStructureGeometryData::AABBs { addr, stride } => Self {
aabbs: vk::AccelerationStructureGeometryAabbsDataKHR::default()
.data(addr.into())
.stride(stride),
},
AccelerationStructureGeometryData::Instances {
addr,
array_of_pointers,
} => Self {
instances: vk::AccelerationStructureGeometryInstancesDataKHR::default()
.array_of_pointers(array_of_pointers)
.data(addr.into()),
},
AccelerationStructureGeometryData::Triangles {
index_addr,
index_type,
max_vertex,
transform_addr,
vertex_addr,
vertex_format,
vertex_stride,
} => Self {
triangles: vk::AccelerationStructureGeometryTrianglesDataKHR::default()
.index_data(index_addr.into())
.index_type(index_type)
.max_vertex(max_vertex)
.transform_data(transform_addr.map(Into::into).unwrap_or_default())
.vertex_data(vertex_addr.into())
.vertex_format(vertex_format)
.vertex_stride(vertex_stride),
},
}
}
}
/// Describes one acceleration structure build: its type, build flags, and geometries.
///
/// `G` is the per-geometry element type; [`AccelerationStructure::size_of`] requires
/// `G: AsRef<AccelerationStructureGeometry>`.
#[derive(Clone, Debug)]
pub struct AccelerationStructureGeometryInfo<G> {
    /// Acceleration structure type (set to bottom- or top-level by the constructors).
    pub ty: vk::AccelerationStructureTypeKHR,
    /// Build flags; empty unless set through `flags`.
    pub flags: vk::BuildAccelerationStructureFlagsKHR,
    /// The geometries included in the build.
    pub geometries: Box<[G]>,
}
impl<G> AccelerationStructureGeometryInfo<G> {
    /// Describes a bottom-level build over `geometries`, with no build flags.
    pub fn blas(geometries: impl Into<Box<[G]>>) -> Self {
        Self::with_ty(
            vk::AccelerationStructureTypeKHR::BOTTOM_LEVEL,
            geometries.into(),
        )
    }

    /// Describes a top-level build over `geometries`, with no build flags.
    pub fn tlas(geometries: impl Into<Box<[G]>>) -> Self {
        Self::with_ty(vk::AccelerationStructureTypeKHR::TOP_LEVEL, geometries.into())
    }

    /// Returns this build description with its build flags replaced by `flags`.
    pub fn flags(self, flags: vk::BuildAccelerationStructureFlagsKHR) -> Self {
        Self { flags, ..self }
    }

    /// Shared constructor for `blas` and `tlas`: the given type, empty flags.
    fn with_ty(ty: vk::AccelerationStructureTypeKHR, geometries: Box<[G]>) -> Self {
        Self {
            ty,
            flags: Default::default(),
            geometries,
        }
    }
}
/// Parameters used to create an [`AccelerationStructure`].
///
/// Construct with [`AccelerationStructureInfo::blas`], [`AccelerationStructureInfo::tlas`], or
/// the builder returned by [`AccelerationStructureInfo::builder`].
#[derive(Builder, Clone, Copy, Debug, Eq, Hash, PartialEq)]
#[builder(
    build_fn(
        private,
        name = "fallible_build",
        error = "AccelerationStructureInfoBuilderError"
    ),
    derive(Clone, Copy, Debug),
    pattern = "owned"
)]
#[non_exhaustive]
pub struct AccelerationStructureInfo {
    /// Acceleration structure type; the builder uses `GENERIC` when this is not set.
    #[builder(default = "vk::AccelerationStructureTypeKHR::GENERIC")]
    pub ty: vk::AccelerationStructureTypeKHR,
    /// Size in bytes; used both for the backing buffer and for the acceleration structure.
    /// The builder has no default for this field, so it must be set.
    pub size: vk::DeviceSize,
}
impl AccelerationStructureInfo {
    /// Describes a bottom-level acceleration structure of `size` bytes.
    #[inline(always)]
    pub const fn blas(size: vk::DeviceSize) -> Self {
        Self {
            ty: vk::AccelerationStructureTypeKHR::BOTTOM_LEVEL,
            size,
        }
    }

    /// Returns a builder with no fields set.
    ///
    /// When built, `ty` falls back to `GENERIC`; `size` must be set or building panics.
    pub fn builder() -> AccelerationStructureInfoBuilder {
        Default::default()
    }

    /// Describes a top-level acceleration structure of `size` bytes.
    #[inline(always)]
    pub const fn tlas(size: vk::DeviceSize) -> Self {
        Self {
            ty: vk::AccelerationStructureTypeKHR::TOP_LEVEL,
            size,
        }
    }

    /// Converts this info into a builder with every field already set.
    #[inline(always)]
    pub fn to_builder(self) -> AccelerationStructureInfoBuilder {
        // Exhaustive destructuring: adding a field to this struct becomes a compile error here
        // until it is carried over into the builder. `#[non_exhaustive]` does not restrict this
        // pattern inside the defining crate.
        let Self { ty, size } = self;

        AccelerationStructureInfoBuilder {
            ty: Some(ty),
            size: Some(size),
        }
    }
}
/// Converts an `AccelerationStructureInfo` into `()`, discarding it.
///
/// NOTE(review): presumably this exists so generic code bounded on `Into<()>` accepts this info
/// type — confirm against the callers.
impl From<AccelerationStructureInfo> for () {
    fn from(_: AccelerationStructureInfo) -> Self {}
}
impl AccelerationStructureInfoBuilder {
    /// Builds the [`AccelerationStructureInfo`].
    ///
    /// # Panics
    ///
    /// Panics if `size` was never set; the panic message is the `UninitializedFieldError`
    /// text, e.g. `Field not initialized: size`.
    #[inline(always)]
    pub fn build(self) -> AccelerationStructureInfo {
        self.fallible_build()
            .unwrap_or_else(|AccelerationStructureInfoBuilderError(err)| panic!("{err}"))
    }
}
/// Error type of the private `fallible_build`; wraps derive_builder's `UninitializedFieldError`.
#[derive(Debug)]
struct AccelerationStructureInfoBuilderError(UninitializedFieldError);

impl From<UninitializedFieldError> for AccelerationStructureInfoBuilderError {
    /// Wraps the error reported for an unset builder field.
    fn from(uninit: UninitializedFieldError) -> Self {
        AccelerationStructureInfoBuilderError(uninit)
    }
}
/// Memory requirements reported by [`AccelerationStructure::size_of`].
#[derive(Clone, Copy, Debug)]
pub struct AccelerationStructureSize {
    /// Scratch size needed to build (Vulkan `build_scratch_size`).
    pub build_size: vk::DeviceSize,
    /// Size of the acceleration structure itself (Vulkan `acceleration_structure_size`).
    pub create_size: vk::DeviceSize,
    /// Scratch size needed to update (Vulkan `update_scratch_size`).
    pub update_size: vk::DeviceSize,
}
/// Either a device address or a host pointer.
///
/// Converts into Vulkan's `vk::DeviceOrHostAddressKHR` and `vk::DeviceOrHostAddressConstKHR`
/// unions.
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
pub enum DeviceOrHostAddress {
    /// A `vk::DeviceAddress`.
    DeviceAddress(vk::DeviceAddress),
    /// A raw host pointer; this type never dereferences it.
    HostAddress(*mut c_void),
}
impl From<vk::DeviceAddress> for DeviceOrHostAddress {
fn from(device_address: vk::DeviceAddress) -> Self {
Self::DeviceAddress(device_address)
}
}
impl From<*mut c_void> for DeviceOrHostAddress {
fn from(host_address: *mut c_void) -> Self {
Self::HostAddress(host_address)
}
}
// SAFETY: `DeviceOrHostAddress` only stores an address value, and none of the code shown here
// dereferences the `HostAddress` pointer. NOTE(review): whatever eventually reads through a host
// address (e.g. a host-side build) must ensure the memory is valid and synchronized across
// threads — confirm at those use sites.
unsafe impl Send for DeviceOrHostAddress {}
unsafe impl Sync for DeviceOrHostAddress {}
impl From<DeviceOrHostAddress> for vk::DeviceOrHostAddressConstKHR {
fn from(value: DeviceOrHostAddress) -> Self {
match value {
DeviceOrHostAddress::DeviceAddress(device_address) => Self { device_address },
DeviceOrHostAddress::HostAddress(host_address) => Self { host_address },
}
}
}
impl From<DeviceOrHostAddress> for vk::DeviceOrHostAddressKHR {
fn from(value: DeviceOrHostAddress) -> Self {
match value {
DeviceOrHostAddress::DeviceAddress(device_address) => Self { device_address },
DeviceOrHostAddress::HostAddress(host_address) => Self { host_address },
}
}
}
#[cfg(test)]
mod tests {
    use super::*;

    type Info = AccelerationStructureInfo;
    type Builder = AccelerationStructureInfoBuilder;

    /// An info converted into a builder and built again must be unchanged.
    #[test]
    pub fn accel_struct_info() {
        let expected = Info::blas(32);
        let round_tripped = expected.to_builder().build();

        assert_eq!(round_tripped, expected);
    }

    /// Only `size` is required; `ty` falls back to `GENERIC`.
    #[test]
    pub fn accel_struct_info_builder() {
        let built = Builder::default().size(32).build();
        let expected = Info {
            ty: vk::AccelerationStructureTypeKHR::GENERIC,
            size: 32,
        };

        assert_eq!(built, expected);
    }

    /// Building without `size` must panic with derive_builder's message.
    #[test]
    #[should_panic(expected = "Field not initialized: size")]
    pub fn accel_struct_info_builder_uninit_size() {
        let _ = Builder::default().build();
    }
}