embree4_rs/geometry/
user.rs

1use std::{marker::PhantomData, ptr};
2
3use crate::Device;
4
5use anyhow::Result;
6use embree4_sys::{RTCRayHit, RTC_INVALID_GEOMETRY_ID};
7
8use super::Geometry;
9
10/// The user geometry implementation.
11/// If you want to use custom geometry, you need to implement this trait.
12/// See the [examples/](https://github.com/psytrx/embree4-rs/tree/main/examples) for an example of
13/// how to implement one.
14pub trait UserGeometryImpl {
15    /// Returns the bounds of the geometry
16    fn bounds(&self) -> embree4_sys::RTCBounds;
17
18    /// Computes an intersection between the given ray and the geometry.
19    /// If an intersection is found,
20    ///
21    /// * the ray's `tfar` field
22    /// * the hit's normals (`Ng_x`, `Ng_y`, `Ng_z`)
23    /// * the hit's `u` and `v` coordinates
24    /// * the hit's `primID`, `geomID` and `instID`
25    ///
26    /// must all be updated.
27    ///
28    /// Setting `ray_hit.hit.geomID` to the supplied `geom_id` signals an intersection.
29    fn intersect(
30        &self,
31        geom_id: u32,
32        prim_id: u32,
33        ctx: &embree4_sys::RTCRayQueryContext,
34        ray_hit: &mut embree4_sys::RTCRayHit,
35    );
36}
37
38pub struct UserGeometry<T: UserGeometryImpl> {
39    pub handle: embree4_sys::RTCGeometry,
40    data: PhantomData<T>,
41}
42
43#[allow(clippy::missing_safety_doc)]
44impl<T: UserGeometryImpl> UserGeometry<T> {
45    /// Creates a new `UserGeometry` object.
46    ///
47    /// # Arguments
48    ///
49    /// * `device` - The Embree device.
50    /// * `data` - The user-defined data associated with the geometry.
51    /// * `bounds_fn` - The function pointer to the bounds function.
52    /// * `intersect_fn` - The function pointer to the intersect function.
53    /// * `occluded_fn` - The function pointer to the occluded function.
54    ///
55    /// # Returns
56    ///
57    /// A `Result` containing the `UserGeometry` object if successful, or an `anyhow::Error` if an error occurred.
58    pub fn try_new(device: &Device, data: &T) -> Result<Self> {
59        let handle = unsafe {
60            embree4_sys::rtcNewGeometry(device.handle, embree4_sys::RTCGeometryType::USER)
61        };
62        device.error_or((), "Could not create user geometry")?;
63
64        unsafe {
65            embree4_sys::rtcSetGeometryUserPrimitiveCount(handle, 1);
66        }
67        device.error_or((), "Could not set user geometry primitive count")?;
68
69        let data_ptr = data as *const _ as _;
70
71        unsafe {
72            embree4_sys::rtcSetGeometryUserData(handle, data_ptr);
73        }
74        device.error_or((), "Could not set user geometry data")?;
75
76        unsafe {
77            embree4_sys::rtcSetGeometryBoundsFunction(
78                handle,
79                Some(internal_bounds_fn::<T>),
80                data_ptr,
81            );
82        };
83        device.error_or((), "Could not set user geometry bounds function")?;
84
85        unsafe {
86            embree4_sys::rtcSetGeometryIntersectFunction(handle, Some(internal_intersect_fn::<T>));
87        }
88        device.error_or((), "Could not set user geometry intersect function")?;
89
90        // unsafe {
91        //     embree4_sys::rtcSetGeometryOccludedFunction(handle, Some(occluded_fn));
92        // }
93        // device_error_or(device, (), "Could not set user geometry occluded function")?;
94
95        // unsafe {
96        //     embree4_sys::rtcSetGeometryPointQueryFunction(
97        //         handle,
98        //         Some(internal_point_query_fn::<T>),
99        //     )
100        // }
101
102        unsafe {
103            embree4_sys::rtcCommitGeometry(handle);
104        }
105        device.error_or((), "Could not commit user geometry")?;
106
107        Ok(Self {
108            handle,
109            data: PhantomData,
110        })
111    }
112}
113
114impl<T: UserGeometryImpl> Geometry for UserGeometry<T> {
115    fn handle(&self) -> &embree4_sys::RTCGeometry {
116        &self.handle
117    }
118}
119
120impl<T: UserGeometryImpl> Drop for UserGeometry<T> {
121    fn drop(&mut self) {
122        unsafe {
123            embree4_sys::rtcReleaseGeometry(self.handle);
124        }
125    }
126}
127
128unsafe extern "C" fn internal_bounds_fn<T: UserGeometryImpl>(
129    args: *const embree4_sys::RTCBoundsFunctionArguments,
130) {
131    let args = *args;
132    let geom_ptr = args.geometryUserPtr as *const T;
133    let geom = ptr::read(geom_ptr);
134
135    *args.bounds_o = geom.bounds();
136}
137
138unsafe extern "C" fn internal_intersect_fn<T: UserGeometryImpl>(
139    args: *const embree4_sys::RTCIntersectFunctionNArguments,
140) {
141    let args = &*args;
142    let geom_ptr = args.geometryUserPtr as *const T;
143    let geom = ptr::read(geom_ptr);
144
145    let rayhit_n = args.rayhit as *mut f32;
146
147    let ray_n = rayhit_n;
148    let hit_n = rayhit_n.add(12 * args.N as usize);
149
150    let valid_ptr = args.valid as *const u32;
151    let valid = std::slice::from_raw_parts(valid_ptr, args.N as usize);
152
153    let context = &*(args.context as *const embree4_sys::RTCRayQueryContext);
154
155    let n = args.N as usize;
156    for (i, valid) in valid.iter().enumerate() {
157        if *valid == 0 {
158            continue;
159        }
160
161        let org_x = ray_n.add(offset(0, n, i));
162        let org_y = ray_n.add(offset(1, n, i));
163        let org_z = ray_n.add(offset(2, n, i));
164        let tnear = ray_n.add(offset(3, n, i));
165
166        let dir_x = ray_n.add(offset(4, n, i));
167        let dir_y = ray_n.add(offset(5, n, i));
168        let dir_z = ray_n.add(offset(6, n, i));
169        let time = ray_n.add(offset(7, n, i));
170
171        let tfar = ray_n.add(offset(8, n, i));
172        let mask = ray_n.add(offset(9, n, i)) as *mut u32;
173        let id = ray_n.add(offset(10, n, i)) as *mut u32;
174        let flags = ray_n.add(offset(11, n, i)) as *mut u32;
175
176        let ng_x = hit_n.add(offset(0, n, i));
177        let ng_y = hit_n.add(offset(1, n, i));
178        let ng_z = hit_n.add(offset(2, n, i));
179
180        let u = hit_n.add(offset(3, n, i));
181        let v = hit_n.add(offset(4, n, i));
182
183        let prim_id = hit_n.add(offset(5, n, i)) as *mut u32;
184        let geom_id = hit_n.add(offset(6, n, i)) as *mut u32;
185        let inst_id = hit_n.add(offset(7, n, i)) as *mut u32;
186
187        let mut ray_hit = RTCRayHit {
188            ray: embree4_sys::RTCRay {
189                org_x: *org_x,
190                org_y: *org_y,
191                org_z: *org_z,
192                tnear: *tnear,
193                dir_x: *dir_x,
194                dir_y: *dir_y,
195                dir_z: *dir_z,
196                time: *time,
197                tfar: *tfar,
198                mask: *mask,
199                id: *id,
200                flags: *flags,
201            },
202            hit: embree4_sys::RTCHit {
203                Ng_x: *ng_x,
204                Ng_y: *ng_y,
205                Ng_z: *ng_z,
206                u: *u,
207                v: *v,
208                primID: *prim_id,
209                geomID: *geom_id,
210                instID: [*inst_id],
211            },
212        };
213
214        geom.intersect(args.geomID, args.primID, context, &mut ray_hit);
215
216        if ray_hit.hit.geomID != RTC_INVALID_GEOMETRY_ID {
217            *tfar = ray_hit.ray.tfar;
218
219            *ng_x = ray_hit.hit.Ng_x;
220            *ng_y = ray_hit.hit.Ng_y;
221            *ng_z = ray_hit.hit.Ng_z;
222
223            *u = ray_hit.hit.u;
224            *v = ray_hit.hit.v;
225
226            *prim_id = ray_hit.hit.primID;
227            *geom_id = ray_hit.hit.geomID;
228            *inst_id = ray_hit.hit.instID[0];
229        }
230    }
231}
232
/// Returns the flat index of lane `i` of structure-of-arrays field number `field` in a packet
/// of `n` lanes.
///
/// Each field occupies `n` consecutive elements, so field `field` starts at `field * n`.
#[inline]
fn offset(field: usize, n: usize, i: usize) -> usize {
    let field_start = field * n;
    field_start + i
}