1use std::{marker::PhantomData, ptr};
2
3use crate::Device;
4
5use anyhow::Result;
6use embree4_sys::{RTCRayHit, RTC_INVALID_GEOMETRY_ID};
7
8use super::Geometry;
9
10pub trait UserGeometryImpl {
15 fn bounds(&self) -> embree4_sys::RTCBounds;
17
18 fn intersect(
30 &self,
31 geom_id: u32,
32 prim_id: u32,
33 ctx: &embree4_sys::RTCRayQueryContext,
34 ray_hit: &mut embree4_sys::RTCRayHit,
35 );
36}
37
38pub struct UserGeometry<T: UserGeometryImpl> {
39 pub handle: embree4_sys::RTCGeometry,
40 data: PhantomData<T>,
41}
42
43#[allow(clippy::missing_safety_doc)]
44impl<T: UserGeometryImpl> UserGeometry<T> {
45 pub fn try_new(device: &Device, data: &T) -> Result<Self> {
59 let handle = unsafe {
60 embree4_sys::rtcNewGeometry(device.handle, embree4_sys::RTCGeometryType::USER)
61 };
62 device.error_or((), "Could not create user geometry")?;
63
64 unsafe {
65 embree4_sys::rtcSetGeometryUserPrimitiveCount(handle, 1);
66 }
67 device.error_or((), "Could not set user geometry primitive count")?;
68
69 let data_ptr = data as *const _ as _;
70
71 unsafe {
72 embree4_sys::rtcSetGeometryUserData(handle, data_ptr);
73 }
74 device.error_or((), "Could not set user geometry data")?;
75
76 unsafe {
77 embree4_sys::rtcSetGeometryBoundsFunction(
78 handle,
79 Some(internal_bounds_fn::<T>),
80 data_ptr,
81 );
82 };
83 device.error_or((), "Could not set user geometry bounds function")?;
84
85 unsafe {
86 embree4_sys::rtcSetGeometryIntersectFunction(handle, Some(internal_intersect_fn::<T>));
87 }
88 device.error_or((), "Could not set user geometry intersect function")?;
89
90 unsafe {
103 embree4_sys::rtcCommitGeometry(handle);
104 }
105 device.error_or((), "Could not commit user geometry")?;
106
107 Ok(Self {
108 handle,
109 data: PhantomData,
110 })
111 }
112}
113
114impl<T: UserGeometryImpl> Geometry for UserGeometry<T> {
115 fn handle(&self) -> &embree4_sys::RTCGeometry {
116 &self.handle
117 }
118}
119
120impl<T: UserGeometryImpl> Drop for UserGeometry<T> {
121 fn drop(&mut self) {
122 unsafe {
123 embree4_sys::rtcReleaseGeometry(self.handle);
124 }
125 }
126}
127
128unsafe extern "C" fn internal_bounds_fn<T: UserGeometryImpl>(
129 args: *const embree4_sys::RTCBoundsFunctionArguments,
130) {
131 let args = *args;
132 let geom_ptr = args.geometryUserPtr as *const T;
133 let geom = ptr::read(geom_ptr);
134
135 *args.bounds_o = geom.bounds();
136}
137
138unsafe extern "C" fn internal_intersect_fn<T: UserGeometryImpl>(
139 args: *const embree4_sys::RTCIntersectFunctionNArguments,
140) {
141 let args = &*args;
142 let geom_ptr = args.geometryUserPtr as *const T;
143 let geom = ptr::read(geom_ptr);
144
145 let rayhit_n = args.rayhit as *mut f32;
146
147 let ray_n = rayhit_n;
148 let hit_n = rayhit_n.add(12 * args.N as usize);
149
150 let valid_ptr = args.valid as *const u32;
151 let valid = std::slice::from_raw_parts(valid_ptr, args.N as usize);
152
153 let context = &*(args.context as *const embree4_sys::RTCRayQueryContext);
154
155 let n = args.N as usize;
156 for (i, valid) in valid.iter().enumerate() {
157 if *valid == 0 {
158 continue;
159 }
160
161 let org_x = ray_n.add(offset(0, n, i));
162 let org_y = ray_n.add(offset(1, n, i));
163 let org_z = ray_n.add(offset(2, n, i));
164 let tnear = ray_n.add(offset(3, n, i));
165
166 let dir_x = ray_n.add(offset(4, n, i));
167 let dir_y = ray_n.add(offset(5, n, i));
168 let dir_z = ray_n.add(offset(6, n, i));
169 let time = ray_n.add(offset(7, n, i));
170
171 let tfar = ray_n.add(offset(8, n, i));
172 let mask = ray_n.add(offset(9, n, i)) as *mut u32;
173 let id = ray_n.add(offset(10, n, i)) as *mut u32;
174 let flags = ray_n.add(offset(11, n, i)) as *mut u32;
175
176 let ng_x = hit_n.add(offset(0, n, i));
177 let ng_y = hit_n.add(offset(1, n, i));
178 let ng_z = hit_n.add(offset(2, n, i));
179
180 let u = hit_n.add(offset(3, n, i));
181 let v = hit_n.add(offset(4, n, i));
182
183 let prim_id = hit_n.add(offset(5, n, i)) as *mut u32;
184 let geom_id = hit_n.add(offset(6, n, i)) as *mut u32;
185 let inst_id = hit_n.add(offset(7, n, i)) as *mut u32;
186
187 let mut ray_hit = RTCRayHit {
188 ray: embree4_sys::RTCRay {
189 org_x: *org_x,
190 org_y: *org_y,
191 org_z: *org_z,
192 tnear: *tnear,
193 dir_x: *dir_x,
194 dir_y: *dir_y,
195 dir_z: *dir_z,
196 time: *time,
197 tfar: *tfar,
198 mask: *mask,
199 id: *id,
200 flags: *flags,
201 },
202 hit: embree4_sys::RTCHit {
203 Ng_x: *ng_x,
204 Ng_y: *ng_y,
205 Ng_z: *ng_z,
206 u: *u,
207 v: *v,
208 primID: *prim_id,
209 geomID: *geom_id,
210 instID: [*inst_id],
211 },
212 };
213
214 geom.intersect(args.geomID, args.primID, context, &mut ray_hit);
215
216 if ray_hit.hit.geomID != RTC_INVALID_GEOMETRY_ID {
217 *tfar = ray_hit.ray.tfar;
218
219 *ng_x = ray_hit.hit.Ng_x;
220 *ng_y = ray_hit.hit.Ng_y;
221 *ng_z = ray_hit.hit.Ng_z;
222
223 *u = ray_hit.hit.u;
224 *v = ray_hit.hit.v;
225
226 *prim_id = ray_hit.hit.primID;
227 *geom_id = ray_hit.hit.geomID;
228 *inst_id = ray_hit.hit.instID[0];
229 }
230 }
231}
232
/// Flat index of lane `i` of structure-of-arrays field `field`, in a buffer
/// that stores `n` lanes per field.
#[inline(always)]
fn offset(field: usize, n: usize, i: usize) -> usize {
    let field_start = field * n;
    field_start + i
}