Skip to main content

oxicuda_nerf/
lib.rs

1//! `oxicuda-nerf` — Neural Radiance Fields and neural rendering primitives for OxiCUDA.
2//!
3//! Pure-Rust implementation of canonical NeRF algorithms with GPU PTX kernel generation.
4//!
5//! # Architecture
6//!
7//! ```text
8//! oxicuda-nerf
9//! ├── camera/       — Pinhole camera model and ray generation
10//! ├── encoding/     — Positional encoding, Instant-NGP hash grid, Mip-NeRF IPE
11//! ├── error         — NerfError / NerfResult
12//! ├── field/        — TensoRF, Instant-NGP hash field
13//! ├── handle        — NerfHandle (SmVersion + LcgRng)
14//! ├── metrics/      — PSNR, MSE image quality metrics
15//! ├── network/      — NeRF MLP, TinyNeRF
16//! ├── ptx_kernels   — GPU PTX kernel strings (7 kernels × 6 SM versions)
17//! └── rendering/    — Ray, sampling, volume rendering, occupancy grid
18//! ```
19
20// ─── Module declarations ─────────────────────────────────────────────────────
21
22pub mod camera;
23pub mod encoding;
24pub mod error;
25pub mod field;
26pub mod handle;
27pub mod metrics;
28pub mod network;
29pub mod ptx_kernels;
30pub mod rendering;
31
32// ─── Prelude ─────────────────────────────────────────────────────────────────
33
34/// Convenience re-exports for common NeRF types and functions.
35pub mod prelude {
36    pub use crate::camera::pinhole::PinholeCamera;
37    pub use crate::encoding::hash_grid::{HashGrid, HashGridConfig};
38    pub use crate::encoding::integrated_pe::{IpeConfig, integrated_pe};
39    pub use crate::encoding::positional::{PosEncConfig, positional_encode};
40    pub use crate::error::{NerfError, NerfResult};
41    pub use crate::field::hash_field::HashField;
42    pub use crate::field::tensorf::{TensorRf, TensorRfConfig};
43    pub use crate::handle::{LcgRng, NerfHandle, SmVersion};
44    pub use crate::metrics::image_quality::{ImageMetrics, compute_image_metrics, psnr};
45    pub use crate::network::nerf_mlp::{NerfMlp, NerfMlpConfig};
46    pub use crate::network::tiny_nerf::TinyNerf;
47    pub use crate::ptx_kernels::{
48        f32_hex, hash_grid_lookup_ptx, importance_resample_ptx, occupancy_update_ptx,
49        positional_encoding_ptx, ray_march_ptx, sh_to_rgb_ptx, volume_render_ptx,
50    };
51    pub use crate::rendering::occupancy::OccupancyGrid;
52    pub use crate::rendering::ray::{PinholeCamera as RayCamera, Ray};
53    pub use crate::rendering::sampling::{importance_sample, merge_samples, stratified_sample};
54    pub use crate::rendering::volume_render::{RenderResult, volume_render, volume_render_batch};
55}
56
57// ─── End-to-end integration tests ────────────────────────────────────────────
58
59#[cfg(test)]
60mod e2e_tests {
61    use crate::prelude::*;
62
63    // ── Test 1: Positional encoding shape ────────────────────────────────────
64
65    #[test]
66    fn e2e_positional_encoding_shape() {
67        let cfg = PosEncConfig {
68            n_freq: 10,
69            include_input: true,
70            input_dim: 3,
71        };
72        let n_pts = 16;
73        let input = vec![0.5_f32; n_pts * 3];
74        let out = positional_encode(&input, &cfg).unwrap();
75        assert_eq!(
76            out.len(),
77            n_pts * cfg.output_dim(),
78            "E2E: positional encoding output shape mismatch"
79        );
80    }
81
82    // ── Test 2: Positional encoding determinism ───────────────────────────────
83
84    #[test]
85    fn e2e_positional_encoding_deterministic() {
86        let cfg = PosEncConfig {
87            n_freq: 4,
88            include_input: false,
89            input_dim: 3,
90        };
91        let input = vec![0.1_f32, 0.5, -0.3, 0.0, 1.0, 0.7];
92        let out1 = positional_encode(&input, &cfg).unwrap();
93        let out2 = positional_encode(&input, &cfg).unwrap();
94        assert_eq!(out1, out2, "E2E: positional encoding must be deterministic");
95    }
96
97    // ── Test 3: Hash grid query shape ─────────────────────────────────────────
98
99    #[test]
100    fn e2e_hash_grid_query_shape() {
101        let cfg = HashGridConfig {
102            n_levels: 8,
103            n_features_per_level: 2,
104            log2_hashmap_size: 10,
105            base_resolution: 8,
106            max_resolution: 256,
107        };
108        let mut rng = LcgRng::new(42);
109        let grid = HashGrid::new(cfg, &mut rng).unwrap();
110        let feat = grid.query([0.3, 0.7, 0.5]).unwrap();
111        assert_eq!(
112            feat.len(),
113            grid.output_dim(),
114            "E2E: hash grid output dim should be n_levels * n_feat"
115        );
116        assert_eq!(grid.output_dim(), 16);
117    }
118
119    // ── Test 4: Hash grid corner values differ ────────────────────────────────
120
121    #[test]
122    fn e2e_hash_grid_trilinear_corner() {
123        let cfg = HashGridConfig {
124            n_levels: 4,
125            n_features_per_level: 2,
126            log2_hashmap_size: 8,
127            base_resolution: 4,
128            max_resolution: 32,
129        };
130        let mut rng = LcgRng::new(1234);
131        let grid = HashGrid::new(cfg, &mut rng).unwrap();
132        let feat_origin = grid.query([0.0, 0.0, 0.0]).unwrap();
133        let feat_far = grid.query([1.0, 1.0, 1.0]).unwrap();
134        // With random initialization, corners almost certainly differ
135        let are_different = feat_origin
136            .iter()
137            .zip(feat_far.iter())
138            .any(|(a, b)| (a - b).abs() > 1e-9);
139        assert!(
140            are_different,
141            "E2E: corner queries should return different values"
142        );
143    }
144
145    // ── Test 5: Volume render empty scene ─────────────────────────────────────
146
147    #[test]
148    fn e2e_volume_render_empty_scene() {
149        let n = 64;
150        let sigma = vec![0.0_f32; n];
151        let color = vec![0.5_f32; n * 3];
152        let t: Vec<f32> = (0..n).map(|i| 0.1 + i as f32 * 0.1).collect();
153        let res = volume_render(&sigma, &color, &t).unwrap();
154        assert!(
155            res.opacity < 1e-6,
156            "E2E: empty scene (zero density) should have near-zero opacity, got {}",
157            res.opacity
158        );
159    }
160
161    // ── Test 6: Volume render opaque first sample ─────────────────────────────
162
163    #[test]
164    fn e2e_volume_render_opaque_first_sample() {
165        let n = 16;
166        let mut sigma = vec![0.0_f32; n];
167        sigma[0] = 1e8_f32; // Extremely dense first sample
168        let mut color = vec![0.0_f32; n * 3];
169        color[0] = 1.0; // First sample: red
170        color[1] = 0.0;
171        color[2] = 0.0;
172        let t: Vec<f32> = (0..n).map(|i| 0.1 + i as f32 * 0.2).collect();
173        let res = volume_render(&sigma, &color, &t).unwrap();
174        assert!(
175            res.rgb[0] > 0.99,
176            "E2E: opaque red first sample, expected R≈1, got {}",
177            res.rgb[0]
178        );
179        assert!(
180            res.opacity > 0.99,
181            "E2E: opaque first sample should have opacity≈1, got {}",
182            res.opacity
183        );
184    }
185
186    // ── Test 7: Stratified sampling count ─────────────────────────────────────
187
188    #[test]
189    fn e2e_stratified_sampling_count() {
190        let mut rng = LcgRng::new(99);
191        let t_near = 0.1_f32;
192        let t_far = 5.0_f32;
193        let n = 128;
194        let samples = stratified_sample(t_near, t_far, n, &mut rng).unwrap();
195        assert_eq!(
196            samples.len(),
197            n,
198            "E2E: stratified_sample must return exactly n_samples"
199        );
200        for &t in &samples {
201            assert!(
202                t >= t_near && t <= t_far,
203                "E2E: sample {t} out of bounds [{t_near}, {t_far}]"
204            );
205        }
206    }
207
208    // ── Test 8: Importance sampling count ────────────────────────────────────
209
210    #[test]
211    fn e2e_importance_sampling_count() {
212        let mut rng = LcgRng::new(77);
213        let coarse_t = vec![0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8];
214        let weights = vec![0.01, 0.1, 0.5, 0.2, 0.1, 0.05, 0.02, 0.02];
215        let n_fine = 32;
216        let fine = importance_sample(&coarse_t, &weights, n_fine, &mut rng).unwrap();
217        assert_eq!(
218            fine.len(),
219            n_fine,
220            "E2E: importance_sample must return n_fine samples"
221        );
222    }
223
224    // ── Test 9: TensoRF density non-negative ──────────────────────────────────
225
226    #[test]
227    fn e2e_tensorf_density_nonneg() {
228        let cfg = TensorRfConfig {
229            rank: 8,
230            grid_dim: 16,
231            n_color_feat: 3,
232        };
233        let mut rng = LcgRng::new(2024);
234        let tf = TensorRf::new(cfg, &mut rng).unwrap();
235
236        let test_pts: &[[f32; 3]] = &[
237            [0.0, 0.0, 0.0],
238            [0.5, 0.5, 0.5],
239            [-1.0, -1.0, -1.0],
240            [1.0, 1.0, 1.0],
241            [0.3, -0.7, 0.9],
242        ];
243        for &xyz in test_pts {
244            let d = tf.query_density(xyz).unwrap();
245            assert!(
246                d >= 0.0,
247                "E2E: TensoRF density should be >= 0 (got {d}) at {:?}",
248                xyz
249            );
250        }
251    }
252
253    // ── Test 10: TinyNerf forward finite ─────────────────────────────────────
254
255    #[test]
256    fn e2e_tiny_nerf_forward_finite() {
257        let mut rng = LcgRng::new(314);
258        let net = TinyNerf::new(24, 64, &mut rng);
259        let x = vec![0.1_f32; 24];
260        let (sigma, rgb) = net.forward(&x).unwrap();
261        assert!(
262            sigma.is_finite(),
263            "E2E: TinyNerf sigma must be finite, got {sigma}"
264        );
265        assert!(sigma >= 0.0, "E2E: TinyNerf sigma must be >= 0");
266        for (i, &c) in rgb.iter().enumerate() {
267            assert!(
268                c.is_finite(),
269                "E2E: TinyNerf RGB[{i}] must be finite, got {c}"
270            );
271            assert!(
272                (0.0..=1.0).contains(&c),
273                "E2E: TinyNerf RGB[{i}]={c} must be in [0, 1]"
274            );
275        }
276    }
277
278    // ── Test 11: PSNR on identical images ────────────────────────────────────
279
280    #[test]
281    fn e2e_psnr_identity() {
282        let img = vec![0.5_f32; 256 * 256 * 3];
283        let p = psnr(&img, &img).unwrap();
284        assert!(
285            p.is_infinite() || p > 100.0,
286            "E2E: psnr(x, x) should be Inf or very large, got {p}"
287        );
288    }
289
290    // ── Test 12: All 7 PTX kernels × 6 SM versions ───────────────────────────
291
292    #[test]
293    #[allow(clippy::type_complexity)]
294    fn e2e_ptx_kernels_all_sm_versions() {
295        let sm_versions = [75_u32, 80, 86, 90, 100, 120];
296        let kernel_fns: &[(&str, fn(u32) -> String)] = &[
297            ("pe_kernel", positional_encoding_ptx),
298            ("volume_render_kernel", volume_render_ptx),
299            ("hash_grid_kernel", hash_grid_lookup_ptx),
300            ("ray_march_kernel", ray_march_ptx),
301            ("sh_eval_nerf_kernel", sh_to_rgb_ptx),
302            ("occupancy_update_kernel", occupancy_update_ptx),
303            ("importance_resample_kernel", importance_resample_ptx),
304        ];
305        for sm in sm_versions {
306            for (kernel_name, gen_fn) in kernel_fns {
307                let ptx = gen_fn(sm);
308                assert!(
309                    ptx.contains(&format!("sm_{sm}")),
310                    "PTX for {kernel_name} sm={sm} missing sm target"
311                );
312                assert!(
313                    ptx.contains(".version"),
314                    "PTX for {kernel_name} sm={sm} missing .version"
315                );
316                assert!(
317                    ptx.contains(".visible .entry"),
318                    "PTX for {kernel_name} sm={sm} missing .visible .entry"
319                );
320                assert!(
321                    ptx.contains(kernel_name),
322                    "PTX for {kernel_name} sm={sm} missing kernel name"
323                );
324            }
325        }
326        // Smoke-test f32_hex
327        assert_eq!(f32_hex(1.0_f32), "0F3F800000");
328    }
329}