use std::sync::Arc;
use vulkane::safe::{
AccelerationStructure, AccelerationStructureBuildFlags, AccelerationStructureBuildMode,
AccelerationStructureBuildType, AccelerationStructureCreateInfo, AccelerationStructureGeometry,
AccelerationStructureType, ApiVersion, Buffer, BufferCreateInfo, BufferUsage, BuildRange,
DeviceCreateInfo, Instance, InstanceCreateInfo, Queue, QueueCreateInfo, QueueFlags,
ShaderBindingRegion,
};
/// Creates a Vulkan 1.0 instance and a logical device for the tests below.
///
/// Picks the first enumerated physical device that exposes at least one
/// queue family with `QueueFlags::COMPUTE`. It returns that device's
/// logical `Device`, the `PhysicalDevice`, and the index of its first
/// compute-capable queue family.
///
/// Returns `None` when any step fails, so callers can skip gracefully on
/// machines without a usable Vulkan driver. Failure cases are: instance
/// creation, device enumeration, no compute-capable device, or device
/// creation.
fn bootstrap() -> Option<(vulkane::safe::Device, vulkane::safe::PhysicalDevice, u32)> {
    let instance = Instance::new(InstanceCreateInfo {
        application_name: Some("vulkane-raytracing-test"),
        api_version: ApiVersion::V1_0,
        ..Default::default()
    })
    .ok()?;
    // Locate the device and its compute queue family in a single pass,
    // instead of scanning the queue families once to pick the device and
    // again to recover the family index.
    let (physical, qf) = instance
        .enumerate_physical_devices()
        .ok()?
        .into_iter()
        .find_map(|pd| {
            let family = pd
                .queue_family_properties()
                .iter()
                .position(|q| q.queue_flags().contains(QueueFlags::COMPUTE))?;
            // Vulkan queue family indices are u32; reject rather than truncate.
            let family = u32::try_from(family).ok()?;
            Some((pd, family))
        })?;
    let device = physical
        .create_device(DeviceCreateInfo {
            queue_create_infos: &[QueueCreateInfo::single(qf)],
            ..Default::default()
        })
        .ok()?;
    Some((device, physical, qf))
}
/// `acceleration_structure_build_sizes` must reject a call whose geometry
/// slice and max-primitive-count slice differ in length.
///
/// The call passes one geometry and zero counts. An `InvalidArgument` error
/// mentioning "length" is expected. `MissingFunction` is also tolerated.
#[test]
fn acceleration_structure_build_sizes_rejects_length_mismatch() {
    // No usable Vulkan device: nothing to test.
    let Some((device, _physical, _qf)) = bootstrap() else {
        return;
    };

    let geometries = [AccelerationStructureGeometry::Aabbs {
        data_address: 0,
        stride: 24,
    }];
    let result = device.acceleration_structure_build_sizes(
        AccelerationStructureBuildType::Device,
        AccelerationStructureType::BottomLevel,
        &geometries,
        &[],
        AccelerationStructureBuildFlags::default(),
    );

    match result {
        Err(vulkane::safe::Error::InvalidArgument(msg)) => assert!(msg.contains("length")),
        // NOTE(review): tolerated, presumably because the extension entry
        // point is not loaded on this device.
        Err(vulkane::safe::Error::MissingFunction(_)) => {}
        other => panic!("expected InvalidArgument, got {other:?}"),
    }
}
/// A well-formed `acceleration_structure_build_sizes` call either succeeds
/// or fails with `MissingFunction` naming the Vulkan entry point.
///
/// Any other error fails the test.
#[test]
fn acceleration_structure_build_sizes_graceful_missing_function() {
    let Some((device, _physical, _qf)) = bootstrap() else {
        return;
    };

    // One AABB geometry paired with exactly one max-primitive count.
    let geometries = [AccelerationStructureGeometry::Aabbs {
        data_address: 0,
        stride: 24,
    }];
    let outcome = device.acceleration_structure_build_sizes(
        AccelerationStructureBuildType::Device,
        AccelerationStructureType::BottomLevel,
        &geometries,
        &[16],
        AccelerationStructureBuildFlags::PREFER_FAST_TRACE,
    );

    let sizes = match outcome {
        Ok(sizes) => sizes,
        Err(vulkane::safe::Error::MissingFunction(name)) => {
            assert_eq!(name, "vkGetAccelerationStructureBuildSizesKHR");
            return;
        }
        Err(other) => panic!("unexpected error: {other:?}"),
    };
    // Touch each reported size so the fields' existence is checked.
    let _ = sizes.acceleration_structure_size;
    let _ = sizes.build_scratch_size;
    let _ = sizes.update_scratch_size;
}
/// `AccelerationStructure::new` must either succeed or fail with one of two
/// errors: `MissingFunction` naming `vkCreateAccelerationStructureKHR`, or a
/// raw `Vk` error.
///
/// The backing buffer is allocated from device-local memory and bound
/// before the acceleration structure is created. Vulkan requires an
/// acceleration structure's buffer to be bound to memory. The test skips
/// (returns early) when any resource setup step fails.
#[test]
fn acceleration_structure_new_graceful_missing_function() {
    let Some((device, physical, _qf)) = bootstrap() else {
        return;
    };
    let Ok(buffer) = Buffer::new(
        &device,
        BufferCreateInfo {
            size: 4096,
            usage: BufferUsage::STORAGE_BUFFER,
        },
    ) else {
        return;
    };
    let req = buffer.memory_requirements();
    let Some(mt) = physical.find_memory_type(
        req.memory_type_bits,
        vulkane::safe::MemoryPropertyFlags::DEVICE_LOCAL,
    ) else {
        return;
    };
    let Ok(memory) = vulkane::safe::DeviceMemory::allocate(&device, req.size, mt) else {
        return;
    };
    // Previously this result was discarded, so a failed bind still went on to
    // create an acceleration structure over an unbound buffer. Skip instead.
    if buffer.bind_memory(&memory, 0).is_err() {
        eprintln!("SKIP: cannot bind buffer memory");
        return;
    }
    // `memory` is declared before `buffer_arc`, so it is dropped after the
    // buffer (and after the acceleration structure created below).
    let buffer_arc = Arc::new(buffer);
    match AccelerationStructure::new(
        &device,
        AccelerationStructureCreateInfo {
            buffer: Arc::clone(&buffer_arc),
            offset: 0,
            size: 1024,
            type_: AccelerationStructureType::BottomLevel,
            _marker: std::marker::PhantomData,
        },
    ) {
        Ok(_) => {}
        Err(vulkane::safe::Error::MissingFunction(name)) => {
            assert_eq!(name, "vkCreateAccelerationStructureKHR");
        }
        // A driver-reported failure is acceptable; only the error shape matters here.
        Err(vulkane::safe::Error::Vk(_)) => {}
        Err(other) => panic!("unexpected error shape: {other:?}"),
    }
}
/// Recording `build_acceleration_structure` with a geometry slice and a
/// build-range slice of different lengths must fail.
///
/// The expected failure is `InvalidArgument` with a message mentioning
/// "length". `MissingFunction` and raw `Vk` errors are also tolerated.
///
/// The destination acceleration structure is backed by a buffer bound to
/// device-local memory, as Vulkan requires for acceleration-structure
/// storage. The test skips (returns early) when any resource setup step
/// fails.
#[test]
fn build_acceleration_structure_rejects_length_mismatch() {
    let Some((device, physical, qf)) = bootstrap() else {
        return;
    };
    let queue: Queue = device.get_queue(qf, 0);
    let Ok(buffer) = Buffer::new(
        &device,
        BufferCreateInfo {
            size: 4096,
            usage: BufferUsage::STORAGE_BUFFER,
        },
    ) else {
        return;
    };
    // The original test never bound this buffer to memory. Creating an
    // acceleration structure on an unbound buffer is invalid Vulkan usage
    // whenever the entry point happens to be available.
    let req = buffer.memory_requirements();
    let Some(mt) = physical.find_memory_type(
        req.memory_type_bits,
        vulkane::safe::MemoryPropertyFlags::DEVICE_LOCAL,
    ) else {
        return;
    };
    let Ok(memory) = vulkane::safe::DeviceMemory::allocate(&device, req.size, mt) else {
        return;
    };
    if buffer.bind_memory(&memory, 0).is_err() {
        eprintln!("SKIP: cannot bind buffer memory");
        return;
    }
    // Wrap in `Arc` only after binding. `memory` is declared first, so it is
    // dropped after both the buffer and `dst`.
    let buffer = Arc::new(buffer);
    let dst = match AccelerationStructure::new(
        &device,
        AccelerationStructureCreateInfo {
            buffer: Arc::clone(&buffer),
            offset: 0,
            size: 1024,
            type_: AccelerationStructureType::BottomLevel,
            _marker: std::marker::PhantomData,
        },
    ) {
        Ok(a) => a,
        Err(e) => {
            eprintln!("SKIP: cannot create AS (extension not loaded?): {e:?}");
            return;
        }
    };
    // One geometry but no build ranges: the slices disagree in length.
    let r = queue.one_shot(&device, qf, |rec| {
        rec.build_acceleration_structure(
            AccelerationStructureType::BottomLevel,
            AccelerationStructureBuildMode::Build,
            AccelerationStructureBuildFlags::default(),
            &dst,
            None,
            &[AccelerationStructureGeometry::Aabbs {
                data_address: 0,
                stride: 24,
            }],
            &[],
            0,
        )
    });
    match r {
        Err(vulkane::safe::Error::InvalidArgument(msg)) => {
            assert!(msg.contains("length"));
        }
        Err(vulkane::safe::Error::MissingFunction(_)) | Err(vulkane::safe::Error::Vk(_)) => {}
        other => panic!("expected InvalidArgument, got {other:?}"),
    }
}
/// Recording a 1x1x1 `trace_rays` with empty shader binding regions must
/// either succeed or fail in an expected shape.
///
/// The expected failures are `MissingFunction` naming `vkCmdTraceRaysKHR`
/// or a raw `Vk` error. Any other error fails the test.
#[test]
fn trace_rays_graceful_missing_function() {
    let Some((device, _physical, qf)) = bootstrap() else {
        return;
    };
    let queue: Queue = device.get_queue(qf, 0);

    let outcome = queue.one_shot(&device, qf, |rec| {
        // All four regions (raygen, miss, hit, callable) are default-constructed.
        let empty = ShaderBindingRegion::default;
        rec.trace_rays(empty(), empty(), empty(), empty(), 1, 1, 1)
    });

    match outcome {
        // Success and driver-reported failures are both acceptable.
        Ok(()) | Err(vulkane::safe::Error::Vk(_)) => {}
        Err(vulkane::safe::Error::MissingFunction(name)) => {
            assert_eq!(name, "vkCmdTraceRaysKHR")
        }
        Err(other) => panic!("unexpected error shape: {other:?}"),
    }
}
/// `ray_tracing_pipeline_properties` can be queried on every enumerated
/// physical device without panicking.
///
/// When properties are reported, their `shader_group_handle_size` field is
/// readable.
#[test]
fn ray_tracing_pipeline_properties_queryable() {
    let Ok(instance) = Instance::new(InstanceCreateInfo::default()) else {
        return;
    };
    // An enumeration failure is treated as "no devices".
    for pd in instance.enumerate_physical_devices().unwrap_or_default() {
        let Some(props) = pd.ray_tracing_pipeline_properties() else {
            continue;
        };
        let _ = props.shader_group_handle_size;
    }
}
/// The `Default` impls of `BuildRange` and `ShaderBindingRegion` zero
/// every field checked here.
#[test]
fn build_range_and_shader_binding_region_default() {
    let range: BuildRange = BuildRange::default();
    assert_eq!(range.primitive_count, 0);
    assert_eq!(range.primitive_offset, 0);

    let region: ShaderBindingRegion = ShaderBindingRegion::default();
    assert_eq!(region.address, 0);
    assert_eq!(region.stride, 0);
    assert_eq!(region.size, 0);
}
/// Smoke test: the ray-query, ray-tracing-pipeline, and
/// acceleration-structure toggles exist on `DeviceFeatures` and can be
/// chained.
///
/// NOTE(review): the resulting features value is discarded. No device is
/// created and hardware support for these features is not checked.
#[test]
fn device_features_ray_query_and_rt_pipeline_toggles_exist() {
let _ = vulkane::safe::DeviceFeatures::new()
.with_ray_query()
.with_ray_tracing_pipeline()
.with_acceleration_structure();
}