Skip to main content

04_ray_intersection/
04_ray_intersection.rs

1use apple_metal::{resource_options, MetalBuffer, MetalDevice};
2use apple_mps::{
3    acceleration_structure_status, cull_mode, data_type, intersection_data_type, intersection_type,
4    polygon_type, ray_data_type, PolygonAccelerationStructure, RayIntersector, SVGF,
5};
6
/// Three tightly packed `f32` components: 12 bytes, no trailing padding.
///
/// Values of this type are copied byte-for-byte into a GPU ray buffer (as part of
/// `PackedRayOriginDirection`). Presumably this mirrors MPS's `MPSPackedFloat3`;
/// TODO confirm against the `apple_mps` bindings.
#[repr(C)]
#[derive(Clone, Copy, Debug)]
struct PackedFloat3 {
    x: f32,
    y: f32,
    z: f32,
}

// The GPU consumes this type as raw bytes, so its size is part of the contract:
// fail the build rather than silently mis-reading rays if it ever changes.
const _: () = assert!(core::mem::size_of::<PackedFloat3>() == 12);
14
15#[repr(C)]
16#[derive(Clone, Copy, Debug)]
17struct PackedRayOriginDirection {
18    origin: PackedFloat3,
19    direction: PackedFloat3,
20}
21
/// One intersection result: the hit distance along the ray and the index of the
/// primitive that was hit (8 bytes total).
///
/// This is the layout `main` reads back from the intersection buffer after
/// configuring the intersector with `intersection_data_type::DISTANCE_PRIMITIVE_INDEX`.
#[repr(C)]
#[derive(Clone, Copy, Debug)]
struct IntersectionDistancePrimitiveIndex {
    distance: f32,
    primitive_index: u32,
}

// The result is read back from GPU memory as raw bytes, so the size must not drift.
const _: () = assert!(core::mem::size_of::<IntersectionDistancePrimitiveIndex>() == 8);
28
/// Reinterprets a slice of values as the raw bytes that back it.
///
/// The returned view borrows `values` and covers exactly `size_of_val(values)` bytes.
/// Only pass element types without padding bytes (such as the `#[repr(C)]` structs of
/// `f32`/`u32` fields and float arrays used in this file); padding would be exposed
/// as uninitialized memory.
fn as_bytes<T>(values: &[T]) -> &[u8] {
    let byte_len = std::mem::size_of_val(values);
    let first_byte: *const u8 = values.as_ptr().cast();
    // SAFETY: `first_byte` points at a live, initialized slice spanning `byte_len`
    // bytes; `u8` has alignment 1; the elided lifetime ties the result to `values`,
    // so the byte view cannot outlive the data it aliases.
    unsafe { std::slice::from_raw_parts(first_byte, byte_len) }
}
34
35fn read_struct<T: Copy>(buffer: &MetalBuffer) -> T {
36    let ptr = buffer.contents().expect("buffer contents").cast::<T>();
37    unsafe { *ptr }
38}
39
40#[allow(clippy::too_many_lines)]
41fn main() {
42    let device = MetalDevice::system_default().expect("no Metal device available");
43    let queue = device.new_command_queue().expect("command queue");
44
45    let vertices = [
46        [-1.0_f32, -1.0, 0.0, 0.0],
47        [1.0_f32, -1.0, 0.0, 0.0],
48        [0.0_f32, 1.0, 0.0, 0.0],
49    ];
50    let vertex_buffer = device
51        .new_buffer(
52            core::mem::size_of_val(&vertices),
53            resource_options::STORAGE_MODE_SHARED,
54        )
55        .expect("vertex buffer");
56    let _ = vertex_buffer.write_bytes(as_bytes(&vertices));
57
58    let acceleration_structure =
59        PolygonAccelerationStructure::new(&device).expect("polygon acceleration structure");
60    acceleration_structure.set_polygon_type(polygon_type::TRIANGLE);
61    acceleration_structure.set_vertex_stride(core::mem::size_of::<[f32; 4]>());
62    acceleration_structure.set_index_type(data_type::UINT32);
63    acceleration_structure.set_vertex_buffer(Some(&vertex_buffer));
64    acceleration_structure.set_index_buffer(None);
65    acceleration_structure.set_polygon_count(1);
66    acceleration_structure.rebuild();
67    assert_eq!(
68        acceleration_structure.status(),
69        acceleration_structure_status::BUILT,
70        "acceleration structure should build successfully"
71    );
72
73    let ray = PackedRayOriginDirection {
74        origin: PackedFloat3 {
75            x: 0.0,
76            y: 0.0,
77            z: 1.0,
78        },
79        direction: PackedFloat3 {
80            x: 0.0,
81            y: 0.0,
82            z: -1.0,
83        },
84    };
85    let miss = IntersectionDistancePrimitiveIndex {
86        distance: -1.0,
87        primitive_index: u32::MAX,
88    };
89
90    let ray_buffer = device
91        .new_buffer(
92            core::mem::size_of::<PackedRayOriginDirection>(),
93            resource_options::STORAGE_MODE_SHARED,
94        )
95        .expect("ray buffer");
96    let intersection_buffer = device
97        .new_buffer(
98            core::mem::size_of::<IntersectionDistancePrimitiveIndex>(),
99            resource_options::STORAGE_MODE_SHARED,
100        )
101        .expect("intersection buffer");
102    let _ = ray_buffer.write_bytes(as_bytes(&[ray]));
103    let _ = intersection_buffer.write_bytes(as_bytes(&[miss]));
104
105    let intersector = RayIntersector::new(&device).expect("ray intersector");
106    intersector.set_cull_mode(cull_mode::NONE);
107    intersector.set_ray_data_type(ray_data_type::PACKED_ORIGIN_DIRECTION);
108    intersector.set_intersection_data_type(intersection_data_type::DISTANCE_PRIMITIVE_INDEX);
109    assert!(
110        intersector.recommended_minimum_ray_batch_size(1) >= 1,
111        "recommended ray batch size should be positive"
112    );
113
114    let command_buffer = queue.new_command_buffer().expect("command buffer");
115    intersector.encode_intersection(
116        &command_buffer,
117        intersection_type::NEAREST,
118        &ray_buffer,
119        0,
120        &intersection_buffer,
121        0,
122        1,
123        &acceleration_structure,
124    );
125    command_buffer.commit();
126    command_buffer.wait_until_completed();
127
128    let intersection = read_struct::<IntersectionDistancePrimitiveIndex>(&intersection_buffer);
129    assert!(
130        (intersection.distance - 1.0).abs() < 1.0e-4,
131        "unexpected hit distance: {intersection:?}"
132    );
133    assert_eq!(
134        intersection.primitive_index, 0,
135        "expected to hit triangle 0"
136    );
137
138    let svgf = SVGF::new(&device).expect("svgf");
139    svgf.set_depth_weight(0.75);
140    svgf.set_normal_weight(64.0);
141    svgf.set_luminance_weight(2.5);
142    svgf.set_channel_count(3);
143    svgf.set_channel_count2(1);
144    assert!((svgf.depth_weight() - 0.75).abs() < f32::EPSILON);
145    assert!((svgf.normal_weight() - 64.0).abs() < f32::EPSILON);
146    assert!((svgf.luminance_weight() - 2.5).abs() < f32::EPSILON);
147    assert_eq!(svgf.channel_count(), 3);
148    assert_eq!(svgf.channel_count2(), 1);
149
150    println!(
151        "ray smoke passed: distance={:.3} primitive_index={} batch_size={}",
152        intersection.distance,
153        intersection.primitive_index,
154        intersector.recommended_minimum_ray_batch_size(1)
155    );
156}