Skip to main content

oxicuda_nerf/rendering/
occupancy.rs

1//! Occupancy grid for accelerating NeRF rendering by skipping empty space.
2
3use crate::error::{NerfError, NerfResult};
4use crate::rendering::ray::Ray;
5
/// Dense boolean voxel grid used to skip empty space while ray marching.
///
/// The grid covers the axis-aligned cube `[-scene_bound, scene_bound]^3` and holds
/// `resolution` voxels along each axis, i.e. `resolution^3` voxels in total.
#[derive(Clone, Debug)]
pub struct OccupancyGrid {
    /// One flag per voxel (`true` = occupied), flattened with the z index varying
    /// fastest, then y, then x; `resolution^3` entries when built via `new`.
    pub data: Vec<bool>,
    /// Number of voxels along each axis.
    pub resolution: usize,
    /// Half the edge length of the cube the grid covers (`[-bound, bound]` per axis).
    pub scene_bound: f32,
}
19
20impl OccupancyGrid {
21    /// Create a new occupancy grid, all empty.
22    ///
23    /// # Errors
24    ///
25    /// Returns `InvalidGridResolution` if `resolution == 0`.
26    pub fn new(resolution: usize, scene_bound: f32) -> NerfResult<Self> {
27        if resolution == 0 {
28            return Err(NerfError::InvalidGridResolution { res: 0 });
29        }
30        let total = resolution * resolution * resolution;
31        Ok(Self {
32            data: vec![false; total],
33            resolution,
34            scene_bound,
35        })
36    }
37
38    #[inline]
39    fn voxel_index(&self, ix: usize, iy: usize, iz: usize) -> usize {
40        ix * self.resolution * self.resolution + iy * self.resolution + iz
41    }
42
43    /// Mark a voxel as occupied/empty.
44    ///
45    /// # Errors
46    ///
47    /// Returns `HashLevelOutOfRange` if any index exceeds the resolution.
48    pub fn set(&mut self, ix: usize, iy: usize, iz: usize, occupied: bool) -> NerfResult<()> {
49        if ix >= self.resolution || iy >= self.resolution || iz >= self.resolution {
50            return Err(NerfError::HashLevelOutOfRange {
51                level: ix.max(iy).max(iz),
52            });
53        }
54        let idx = self.voxel_index(ix, iy, iz);
55        self.data[idx] = occupied;
56        Ok(())
57    }
58
59    /// Query occupancy of a voxel.
60    ///
61    /// # Errors
62    ///
63    /// Returns `HashLevelOutOfRange` if any index exceeds the resolution.
64    pub fn get(&self, ix: usize, iy: usize, iz: usize) -> NerfResult<bool> {
65        if ix >= self.resolution || iy >= self.resolution || iz >= self.resolution {
66            return Err(NerfError::HashLevelOutOfRange {
67                level: ix.max(iy).max(iz),
68            });
69        }
70        Ok(self.data[self.voxel_index(ix, iy, iz)])
71    }
72
73    /// Query if a world-space point lies in an occupied voxel.
74    ///
75    /// Points outside `[-scene_bound, scene_bound]^3` are considered empty.
76    #[must_use]
77    pub fn is_occupied_world(&self, xyz: [f32; 3]) -> bool {
78        let bound = self.scene_bound;
79        // Check bounds
80        if xyz[0] < -bound
81            || xyz[0] > bound
82            || xyz[1] < -bound
83            || xyz[1] > bound
84            || xyz[2] < -bound
85            || xyz[2] > bound
86        {
87            return false;
88        }
89        let res = self.resolution as f32;
90        let to_idx = |v: f32| -> usize {
91            let norm = (v + bound) / (2.0 * bound);
92            (norm * res).floor().clamp(0.0, res - 1.0) as usize
93        };
94        let ix = to_idx(xyz[0]);
95        let iy = to_idx(xyz[1]);
96        let iz = to_idx(xyz[2]);
97        self.data[self.voxel_index(ix, iy, iz)]
98    }
99
100    /// Update occupancy from density values using a threshold.
101    ///
102    /// `density` must have exactly `resolution^3` elements.
103    ///
104    /// # Errors
105    ///
106    /// Returns `DimensionMismatch` if sizes don't match.
107    pub fn update_from_density(&mut self, density: &[f32], threshold: f32) -> NerfResult<()> {
108        let expected = self.resolution * self.resolution * self.resolution;
109        if density.len() != expected {
110            return Err(NerfError::DimensionMismatch {
111                expected,
112                got: density.len(),
113            });
114        }
115        for (occ, &den) in self.data.iter_mut().zip(density.iter()) {
116            *occ = den > threshold;
117        }
118        Ok(())
119    }
120
121    /// March a ray and return t values where the grid is occupied.
122    ///
123    /// Steps along the ray with `step_size` and records t values for occupied voxels.
124    #[must_use]
125    pub fn march_ray_occupied(
126        &self,
127        ray: &Ray,
128        t_near: f32,
129        t_far: f32,
130        step_size: f32,
131    ) -> Vec<f32> {
132        if step_size <= 0.0 || t_far <= t_near {
133            return Vec::new();
134        }
135        let mut t = t_near;
136        let mut occupied_t = Vec::new();
137        while t <= t_far {
138            let pt = ray.at(t);
139            if self.is_occupied_world(pt) {
140                occupied_t.push(t);
141            }
142            t += step_size;
143        }
144        occupied_t
145    }
146}
147
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn occupancy_set_get() {
        let mut grid = OccupancyGrid::new(8, 1.0).unwrap();
        grid.set(2, 3, 4, true).unwrap();
        assert!(grid.get(2, 3, 4).unwrap());
        assert!(!grid.get(0, 0, 0).unwrap());
    }

    #[test]
    fn index_out_of_range_is_an_error() {
        let mut grid = OccupancyGrid::new(8, 1.0).unwrap();
        // The error carries the largest of the three indices.
        assert!(matches!(
            grid.set(8, 0, 0, true),
            Err(NerfError::HashLevelOutOfRange { level: 8 })
        ));
        assert!(matches!(
            grid.get(0, 0, 9),
            Err(NerfError::HashLevelOutOfRange { level: 9 })
        ));
    }

    #[test]
    fn zero_resolution_is_rejected() {
        assert!(matches!(
            OccupancyGrid::new(0, 1.0),
            Err(NerfError::InvalidGridResolution { res: 0 })
        ));
    }

    #[test]
    fn world_query_inside_bound() {
        let mut grid = OccupancyGrid::new(4, 1.0).unwrap();
        // Mark center voxel
        grid.set(2, 2, 2, true).unwrap();
        // World point that maps to (2,2,2) in a 4-res grid spanning [-1,1]:
        // cell width = 2/4 = 0.5, center of voxel 2 = -1 + 2.5*0.5 = 0.25
        assert!(grid.is_occupied_world([0.25, 0.25, 0.25]));
    }

    #[test]
    fn world_query_outside_bound_is_empty() {
        let mut grid = OccupancyGrid::new(4, 1.0).unwrap();
        grid.update_from_density(&[1.0; 64], 0.5).unwrap();
        // A point exactly on the upper bound is clamped into the last voxel.
        assert!(grid.is_occupied_world([1.0, 1.0, 1.0]));
        assert!(!grid.is_occupied_world([1.5, 0.0, 0.0]));
        assert!(!grid.is_occupied_world([0.0, -1.01, 0.0]));
    }

    #[test]
    fn update_from_density() {
        let mut grid = OccupancyGrid::new(2, 1.0).unwrap();
        let density = vec![0.1, 0.2, 0.5, 0.8, 0.0, 0.9, 0.3, 0.7];
        grid.update_from_density(&density, 0.4).unwrap();
        // Densities above 0.4 (voxels 2, 3, 5, 7) are occupied; voxels 0, 1, 4, 6 are empty.
        assert_eq!(
            grid.data,
            vec![false, false, true, true, false, true, false, true]
        );
    }

    #[test]
    fn update_from_density_rejects_wrong_length() {
        let mut grid = OccupancyGrid::new(2, 1.0).unwrap();
        assert!(matches!(
            grid.update_from_density(&[1.0; 7], 0.5),
            Err(NerfError::DimensionMismatch {
                expected: 8,
                got: 7
            })
        ));
        // A rejected update must leave the grid untouched.
        assert!(grid.data.iter().all(|&occ| !occ));
    }
}