04_ray_intersection/
04_ray_intersection.rs1use apple_metal::{resource_options, MetalBuffer, MetalDevice};
2use apple_mps::{
3 acceleration_structure_status, cull_mode, data_type, intersection_data_type, intersection_type,
4 polygon_type, ray_data_type, PolygonAccelerationStructure, RayIntersector, SVGF,
5};
6
/// Three tightly packed `f32` components (12 bytes, no padding).
///
/// Used as the origin/direction element of [`PackedRayOriginDirection`], which is uploaded to the
/// GPU for `ray_data_type::PACKED_ORIGIN_DIRECTION`. The layout must therefore match the packed
/// float3 that MPS reads byte-for-byte.
#[repr(C)]
#[derive(Clone, Copy, Debug)]
struct PackedFloat3 {
    x: f32,
    y: f32,
    z: f32,
}

// Compile-time guard: the GPU reads this type from raw buffers, so its size must stay 3 * 4 bytes.
const _: () = assert!(core::mem::size_of::<PackedFloat3>() == 12);
14
15#[repr(C)]
16#[derive(Clone, Copy, Debug)]
17struct PackedRayOriginDirection {
18 origin: PackedFloat3,
19 direction: PackedFloat3,
20}
21
/// One intersection result as read back for
/// `intersection_data_type::DISTANCE_PRIMITIVE_INDEX`: hit distance followed by the index of the
/// primitive that was hit (8 bytes, no padding).
#[repr(C)]
#[derive(Clone, Copy, Debug)]
struct IntersectionDistancePrimitiveIndex {
    distance: f32,
    primitive_index: u32,
}

// Compile-time guard: results are read back from a raw GPU buffer, so the size must stay 4 + 4.
const _: () = assert!(core::mem::size_of::<IntersectionDistancePrimitiveIndex>() == 8);
28
/// Views a slice of plain-old-data values as its raw bytes, in native byte order.
///
/// Used to copy CPU-side arrays into Metal buffers. `T` must be a padding-free plain-data type
/// (primitives, arrays of primitives, or `#[repr(C)]` structs of them with no gaps). Padding bytes
/// are uninitialized, and exposing them as `u8` would be undefined behavior. This function cannot
/// check that requirement itself.
fn as_bytes<T: Copy>(values: &[T]) -> &[u8] {
    let len = core::mem::size_of_val(values);
    // SAFETY: `values` is a live, initialized slice spanning exactly `len` bytes. The returned
    // slice borrows it for the same lifetime, and `u8` has alignment 1, so any pointer is suitably
    // aligned. Every byte is initialized as long as `T` has no padding (see the doc comment).
    unsafe { core::slice::from_raw_parts(values.as_ptr().cast::<u8>(), len) }
}
34
35fn read_struct<T: Copy>(buffer: &MetalBuffer) -> T {
36 let ptr = buffer.contents().expect("buffer contents").cast::<T>();
37 unsafe { *ptr }
38}
39
40#[allow(clippy::too_many_lines)]
41fn main() {
42 let device = MetalDevice::system_default().expect("no Metal device available");
43 let queue = device.new_command_queue().expect("command queue");
44
45 let vertices = [
46 [-1.0_f32, -1.0, 0.0, 0.0],
47 [1.0_f32, -1.0, 0.0, 0.0],
48 [0.0_f32, 1.0, 0.0, 0.0],
49 ];
50 let vertex_buffer = device
51 .new_buffer(
52 core::mem::size_of_val(&vertices),
53 resource_options::STORAGE_MODE_SHARED,
54 )
55 .expect("vertex buffer");
56 let _ = vertex_buffer.write_bytes(as_bytes(&vertices));
57
58 let acceleration_structure =
59 PolygonAccelerationStructure::new(&device).expect("polygon acceleration structure");
60 acceleration_structure.set_polygon_type(polygon_type::TRIANGLE);
61 acceleration_structure.set_vertex_stride(core::mem::size_of::<[f32; 4]>());
62 acceleration_structure.set_index_type(data_type::UINT32);
63 acceleration_structure.set_vertex_buffer(Some(&vertex_buffer));
64 acceleration_structure.set_index_buffer(None);
65 acceleration_structure.set_polygon_count(1);
66 acceleration_structure.rebuild();
67 assert_eq!(
68 acceleration_structure.status(),
69 acceleration_structure_status::BUILT,
70 "acceleration structure should build successfully"
71 );
72
73 let ray = PackedRayOriginDirection {
74 origin: PackedFloat3 {
75 x: 0.0,
76 y: 0.0,
77 z: 1.0,
78 },
79 direction: PackedFloat3 {
80 x: 0.0,
81 y: 0.0,
82 z: -1.0,
83 },
84 };
85 let miss = IntersectionDistancePrimitiveIndex {
86 distance: -1.0,
87 primitive_index: u32::MAX,
88 };
89
90 let ray_buffer = device
91 .new_buffer(
92 core::mem::size_of::<PackedRayOriginDirection>(),
93 resource_options::STORAGE_MODE_SHARED,
94 )
95 .expect("ray buffer");
96 let intersection_buffer = device
97 .new_buffer(
98 core::mem::size_of::<IntersectionDistancePrimitiveIndex>(),
99 resource_options::STORAGE_MODE_SHARED,
100 )
101 .expect("intersection buffer");
102 let _ = ray_buffer.write_bytes(as_bytes(&[ray]));
103 let _ = intersection_buffer.write_bytes(as_bytes(&[miss]));
104
105 let intersector = RayIntersector::new(&device).expect("ray intersector");
106 intersector.set_cull_mode(cull_mode::NONE);
107 intersector.set_ray_data_type(ray_data_type::PACKED_ORIGIN_DIRECTION);
108 intersector.set_intersection_data_type(intersection_data_type::DISTANCE_PRIMITIVE_INDEX);
109 assert!(
110 intersector.recommended_minimum_ray_batch_size(1) >= 1,
111 "recommended ray batch size should be positive"
112 );
113
114 let command_buffer = queue.new_command_buffer().expect("command buffer");
115 intersector.encode_intersection(
116 &command_buffer,
117 intersection_type::NEAREST,
118 &ray_buffer,
119 0,
120 &intersection_buffer,
121 0,
122 1,
123 &acceleration_structure,
124 );
125 command_buffer.commit();
126 command_buffer.wait_until_completed();
127
128 let intersection = read_struct::<IntersectionDistancePrimitiveIndex>(&intersection_buffer);
129 assert!(
130 (intersection.distance - 1.0).abs() < 1.0e-4,
131 "unexpected hit distance: {intersection:?}"
132 );
133 assert_eq!(
134 intersection.primitive_index, 0,
135 "expected to hit triangle 0"
136 );
137
138 let svgf = SVGF::new(&device).expect("svgf");
139 svgf.set_depth_weight(0.75);
140 svgf.set_normal_weight(64.0);
141 svgf.set_luminance_weight(2.5);
142 svgf.set_channel_count(3);
143 svgf.set_channel_count2(1);
144 assert!((svgf.depth_weight() - 0.75).abs() < f32::EPSILON);
145 assert!((svgf.normal_weight() - 64.0).abs() < f32::EPSILON);
146 assert!((svgf.luminance_weight() - 2.5).abs() < f32::EPSILON);
147 assert_eq!(svgf.channel_count(), 3);
148 assert_eq!(svgf.channel_count2(), 1);
149
150 println!(
151 "ray smoke passed: distance={:.3} primitive_index={} batch_size={}",
152 intersection.distance,
153 intersection.primitive_index,
154 intersector.recommended_minimum_ray_batch_size(1)
155 );
156}