Skip to main content

oxicuda_nerf/rendering/
sampling.rs

1//! Ray sampling strategies: stratified sampling and importance resampling.
2
3use crate::error::{NerfError, NerfResult};
4use crate::handle::LcgRng;
5
6// ─── Stratified sampling ──────────────────────────────────────────────────────
7
8/// Sample `n_samples` positions along a ray between `t_near` and `t_far`
9/// using stratified (jittered) sampling.
10///
11/// `t_i = t_near + (i + U(0,1)) / n_samples * (t_far - t_near)` for i = 0..n_samples
12///
13/// # Errors
14///
15/// Returns `InvalidBounds` if `t_far <= t_near`,
16/// `InvalidSampleCount` if `n_samples == 0`.
17pub fn stratified_sample(
18    t_near: f32,
19    t_far: f32,
20    n_samples: usize,
21    rng: &mut LcgRng,
22) -> NerfResult<Vec<f32>> {
23    if t_far <= t_near {
24        return Err(NerfError::InvalidBounds {
25            near: t_near,
26            far: t_far,
27        });
28    }
29    if n_samples == 0 {
30        return Err(NerfError::InvalidSampleCount { n: 0 });
31    }
32
33    let span = t_far - t_near;
34    let inv_n = 1.0 / n_samples as f32;
35
36    let t_vals: Vec<f32> = (0..n_samples)
37        .map(|i| {
38            let jitter = rng.next_f32();
39            t_near + (i as f32 + jitter) * inv_n * span
40        })
41        .collect();
42
43    Ok(t_vals)
44}
45
46// ─── Importance resampling ────────────────────────────────────────────────────
47
48/// Draw `n_fine` sample positions using inverse-CDF sampling from coarse weights.
49///
50/// Implements hierarchical NeRF sampling:
51/// 1. Build CDF from weights (with small ε for numerical stability, renormalize).
52/// 2. Draw n_fine uniform samples u_j in [0, 1].
53/// 3. Binary search for u_j in CDF to get t_low, t_high.
54/// 4. Linearly interpolate for exact t position.
55///
56/// # Errors
57///
58/// Returns `DimensionMismatch` if `coarse_t.len() != weights.len()`,
59/// `InvalidSampleCount` if `n_fine == 0` or coarse arrays are empty.
60pub fn importance_sample(
61    coarse_t: &[f32],
62    weights: &[f32],
63    n_fine: usize,
64    rng: &mut LcgRng,
65) -> NerfResult<Vec<f32>> {
66    if coarse_t.len() != weights.len() {
67        return Err(NerfError::DimensionMismatch {
68            expected: coarse_t.len(),
69            got: weights.len(),
70        });
71    }
72    if coarse_t.is_empty() {
73        return Err(NerfError::EmptyInput);
74    }
75    if n_fine == 0 {
76        return Err(NerfError::InvalidSampleCount { n: 0 });
77    }
78
79    let n = weights.len();
80    let eps = 1e-5_f32;
81
82    // Build normalized CDF
83    let w_sum: f32 = weights.iter().map(|&w| w.max(0.0) + eps).sum();
84    let mut cdf = Vec::with_capacity(n + 1);
85    cdf.push(0.0_f32);
86    let mut running = 0.0_f32;
87    for &w in weights {
88        running += w.max(0.0) + eps;
89        cdf.push(running / w_sum);
90    }
91    // Clamp last entry to exactly 1.0
92    if let Some(last) = cdf.last_mut() {
93        *last = 1.0;
94    }
95
96    // Draw n_fine samples via inverse CDF
97    let t_vals: Vec<f32> = (0..n_fine)
98        .map(|_| {
99            let u = rng.next_f32();
100            // Binary search for u in cdf[1..n+1]
101            let idx = cdf
102                .partition_point(|&c| c <= u)
103                .saturating_sub(1)
104                .min(n - 1);
105            let c_lo = cdf[idx];
106            let c_hi = cdf[idx + 1];
107            let t_lo = coarse_t[idx];
108            let t_hi = if idx + 1 < n {
109                coarse_t[idx + 1]
110            } else {
111                coarse_t[idx]
112            };
113
114            // Linear interpolation
115            let denom = c_hi - c_lo;
116            if denom < 1e-10 {
117                t_lo
118            } else {
119                t_lo + (u - c_lo) / denom * (t_hi - t_lo)
120            }
121        })
122        .collect();
123
124    Ok(t_vals)
125}
126
127// ─── Merge samples ────────────────────────────────────────────────────────────
128
/// Merge coarse and fine sample positions into one ascending, de-duplicated list.
///
/// NaN positions are discarded because they have no place in an ordering. The
/// remaining values are sorted with [`f32::total_cmp`]. Each value within `1e-7`
/// of the last retained value is then dropped, so the first value of each run is
/// kept. Equal infinities are also collapsed.
///
/// Returns an empty `Vec` when both inputs are empty or contain only NaN.
#[must_use]
pub fn merge_samples(coarse_t: &[f32], fine_t: &[f32]) -> Vec<f32> {
    // Absolute tolerance under which two positions count as the same sample.
    const DEDUP_TOL: f32 = 1e-7;

    let mut merged: Vec<f32> = coarse_t
        .iter()
        .chain(fine_t)
        .copied()
        .filter(|t| !t.is_nan())
        .collect();
    // `total_cmp` is a genuine total order. A `partial_cmp(..).unwrap_or(Equal)`
    // comparator is not once NaN appears, which std sorts may panic on. Such a
    // comparator also let a leading NaN swallow every later value during dedup.
    merged.sort_unstable_by(f32::total_cmp);
    // `dedup_by` passes (current, last retained); returning true drops `current`.
    merged.dedup_by(|cur, kept| *cur == *kept || (*cur - *kept).abs() <= DEDUP_TOL);
    merged
}
148
#[cfg(test)]
mod tests {
    // Unit tests for the sampling helpers, including the documented `# Errors` contracts.
    use super::*;

    #[test]
    fn stratified_count() {
        let mut rng = LcgRng::new(42);
        let t = stratified_sample(0.0, 1.0, 64, &mut rng).unwrap();
        assert_eq!(t.len(), 64);
    }

    #[test]
    fn stratified_in_bounds() {
        let mut rng = LcgRng::new(7);
        let t = stratified_sample(0.1, 5.0, 32, &mut rng).unwrap();
        for &v in &t {
            assert!((0.1..=5.0).contains(&v), "t={v} out of bounds");
        }
    }

    #[test]
    fn stratified_one_sample_per_stratum() {
        let mut rng = LcgRng::new(3);
        let n = 16;
        let t = stratified_sample(0.0, 1.0, n, &mut rng).unwrap();
        let width = 1.0 / n as f32;
        for (i, &v) in t.iter().enumerate() {
            let lo = i as f32 * width;
            let hi = lo + width;
            // Small tolerance for f32 rounding at stratum edges.
            assert!(
                (lo - 1e-6..=hi + 1e-6).contains(&v),
                "t[{i}]={v} outside stratum [{lo}, {hi}]"
            );
        }
    }

    #[test]
    fn stratified_rejects_empty_or_inverted_interval() {
        let mut rng = LcgRng::new(1);
        assert!(matches!(
            stratified_sample(1.0, 1.0, 8, &mut rng),
            Err(NerfError::InvalidBounds { .. })
        ));
        assert!(matches!(
            stratified_sample(2.0, 1.0, 8, &mut rng),
            Err(NerfError::InvalidBounds { .. })
        ));
    }

    #[test]
    fn stratified_rejects_zero_samples() {
        let mut rng = LcgRng::new(1);
        assert!(matches!(
            stratified_sample(0.0, 1.0, 0, &mut rng),
            Err(NerfError::InvalidSampleCount { .. })
        ));
    }

    #[test]
    fn importance_count() {
        let mut rng = LcgRng::new(99);
        let coarse_t = vec![0.1, 0.2, 0.3, 0.4];
        let weights = vec![0.1, 0.5, 0.3, 0.1];
        let fine = importance_sample(&coarse_t, &weights, 8, &mut rng).unwrap();
        assert_eq!(fine.len(), 8);
    }

    #[test]
    fn importance_within_coarse_range() {
        let mut rng = LcgRng::new(11);
        let coarse_t = [0.1, 0.2, 0.3, 0.4];
        let weights = [0.1, 0.5, 0.3, 0.1];
        let fine = importance_sample(&coarse_t, &weights, 64, &mut rng).unwrap();
        for &t in &fine {
            assert!((0.1 - 1e-6..=0.4 + 1e-6).contains(&t), "t={t} outside coarse range");
        }
    }

    #[test]
    fn importance_single_bin_collapses() {
        let mut rng = LcgRng::new(5);
        let fine = importance_sample(&[2.0], &[1.0], 5, &mut rng).unwrap();
        assert!(fine.iter().all(|&t| t == 2.0));
    }

    #[test]
    fn importance_rejects_bad_input() {
        let mut rng = LcgRng::new(1);
        assert!(matches!(
            importance_sample(&[0.1, 0.2], &[1.0], 4, &mut rng),
            Err(NerfError::DimensionMismatch { .. })
        ));
        assert!(matches!(
            importance_sample(&[], &[], 4, &mut rng),
            Err(NerfError::EmptyInput)
        ));
        assert!(matches!(
            importance_sample(&[0.1], &[1.0], 0, &mut rng),
            Err(NerfError::InvalidSampleCount { .. })
        ));
    }

    #[test]
    fn merge_sorted() {
        let coarse = [0.1, 0.3, 0.5];
        let fine = [0.2, 0.4];
        let merged = merge_samples(&coarse, &fine);
        assert!(merged.windows(2).all(|w| w[0] <= w[1]));
    }

    #[test]
    fn merge_removes_exact_duplicates() {
        assert_eq!(merge_samples(&[0.1, 0.2], &[0.2, 0.3]), vec![0.1, 0.2, 0.3]);
    }

    #[test]
    fn merge_empty() {
        assert!(merge_samples(&[], &[]).is_empty());
    }
}