oxicuda_nerf/rendering/
sampling.rs1use crate::error::{NerfError, NerfResult};
4use crate::handle::LcgRng;
5
6pub fn stratified_sample(
18 t_near: f32,
19 t_far: f32,
20 n_samples: usize,
21 rng: &mut LcgRng,
22) -> NerfResult<Vec<f32>> {
23 if t_far <= t_near {
24 return Err(NerfError::InvalidBounds {
25 near: t_near,
26 far: t_far,
27 });
28 }
29 if n_samples == 0 {
30 return Err(NerfError::InvalidSampleCount { n: 0 });
31 }
32
33 let span = t_far - t_near;
34 let inv_n = 1.0 / n_samples as f32;
35
36 let t_vals: Vec<f32> = (0..n_samples)
37 .map(|i| {
38 let jitter = rng.next_f32();
39 t_near + (i as f32 + jitter) * inv_n * span
40 })
41 .collect();
42
43 Ok(t_vals)
44}
45
46pub fn importance_sample(
61 coarse_t: &[f32],
62 weights: &[f32],
63 n_fine: usize,
64 rng: &mut LcgRng,
65) -> NerfResult<Vec<f32>> {
66 if coarse_t.len() != weights.len() {
67 return Err(NerfError::DimensionMismatch {
68 expected: coarse_t.len(),
69 got: weights.len(),
70 });
71 }
72 if coarse_t.is_empty() {
73 return Err(NerfError::EmptyInput);
74 }
75 if n_fine == 0 {
76 return Err(NerfError::InvalidSampleCount { n: 0 });
77 }
78
79 let n = weights.len();
80 let eps = 1e-5_f32;
81
82 let w_sum: f32 = weights.iter().map(|&w| w.max(0.0) + eps).sum();
84 let mut cdf = Vec::with_capacity(n + 1);
85 cdf.push(0.0_f32);
86 let mut running = 0.0_f32;
87 for &w in weights {
88 running += w.max(0.0) + eps;
89 cdf.push(running / w_sum);
90 }
91 if let Some(last) = cdf.last_mut() {
93 *last = 1.0;
94 }
95
96 let t_vals: Vec<f32> = (0..n_fine)
98 .map(|_| {
99 let u = rng.next_f32();
100 let idx = cdf
102 .partition_point(|&c| c <= u)
103 .saturating_sub(1)
104 .min(n - 1);
105 let c_lo = cdf[idx];
106 let c_hi = cdf[idx + 1];
107 let t_lo = coarse_t[idx];
108 let t_hi = if idx + 1 < n {
109 coarse_t[idx + 1]
110 } else {
111 coarse_t[idx]
112 };
113
114 let denom = c_hi - c_lo;
116 if denom < 1e-10 {
117 t_lo
118 } else {
119 t_lo + (u - c_lo) / denom * (t_hi - t_lo)
120 }
121 })
122 .collect();
123
124 Ok(t_vals)
125}
126
/// Merges coarse and fine depth samples into one ascending, de-duplicated list.
///
/// A value is dropped when it equals, or lies within `1e-7` of, the previously
/// *kept* value. NaN entries are discarded.
///
/// Why NaN is filtered: a comparator that maps NaN to `Equal` is not a total
/// order, and `sort_by` may panic on such comparators (Rust 1.81+). A NaN left
/// in place would also make every later tolerance check false, silently
/// dropping the samples that follow it.
#[must_use]
pub fn merge_samples(coarse_t: &[f32], fine_t: &[f32]) -> Vec<f32> {
    let mut merged: Vec<f32> = coarse_t
        .iter()
        .chain(fine_t)
        .copied()
        .filter(|t| !t.is_nan())
        .collect();
    // With NaN gone, `total_cmp` agrees with `<`. The only extra distinction is
    // that -0.0 sorts before +0.0, and the equality check below merges those.
    merged.sort_unstable_by(f32::total_cmp);
    // `dedup_by` passes (current, last kept), matching "keep only if far enough
    // from the previous output". Exact equality also catches repeated
    // infinities, whose difference is NaN.
    merged.dedup_by(|cur, prev| *cur == *prev || (*cur - *prev).abs() <= 1e-7);
    merged
}
148
#[cfg(test)]
mod tests {
    use super::*;

    // --- stratified_sample ---

    #[test]
    fn stratified_count() {
        let mut rng = LcgRng::new(42);
        let t = stratified_sample(0.0, 1.0, 64, &mut rng).unwrap();
        assert_eq!(t.len(), 64);
    }

    #[test]
    fn stratified_in_bounds() {
        let mut rng = LcgRng::new(7);
        let t = stratified_sample(0.1, 5.0, 32, &mut rng).unwrap();
        for &v in &t {
            assert!((0.1..=5.0).contains(&v), "t={v} out of bounds");
        }
    }

    #[test]
    fn stratified_is_nondecreasing() {
        // One sample per stratum, strata visited in order.
        let mut rng = LcgRng::new(3);
        let t = stratified_sample(0.0, 2.0, 16, &mut rng).unwrap();
        assert!(t.windows(2).all(|w| w[0] <= w[1]));
    }

    #[test]
    fn stratified_rejects_bad_input() {
        let mut rng = LcgRng::new(1);
        assert!(stratified_sample(1.0, 1.0, 8, &mut rng).is_err());
        assert!(stratified_sample(2.0, 1.0, 8, &mut rng).is_err());
        assert!(stratified_sample(0.0, 1.0, 0, &mut rng).is_err());
    }

    // --- importance_sample ---

    #[test]
    fn importance_count() {
        let mut rng = LcgRng::new(99);
        let coarse_t = vec![0.1, 0.2, 0.3, 0.4];
        let weights = vec![0.1, 0.5, 0.3, 0.1];
        let fine = importance_sample(&coarse_t, &weights, 8, &mut rng).unwrap();
        assert_eq!(fine.len(), 8);
    }

    #[test]
    fn importance_rejects_bad_input() {
        let mut rng = LcgRng::new(5);
        assert!(importance_sample(&[0.1, 0.2], &[1.0], 4, &mut rng).is_err());
        assert!(importance_sample(&[], &[], 4, &mut rng).is_err());
        assert!(importance_sample(&[0.1, 0.2], &[0.5, 0.5], 0, &mut rng).is_err());
    }

    #[test]
    fn importance_stays_within_coarse_range() {
        let mut rng = LcgRng::new(11);
        let coarse_t = [0.1_f32, 0.2, 0.3, 0.4];
        let weights = [0.1_f32, 0.5, 0.3, 0.1];
        let fine = importance_sample(&coarse_t, &weights, 64, &mut rng).unwrap();
        // Small slack for rounding in the linear interpolation.
        let (lo, hi) = (0.1_f32 - 1e-6, 0.4_f32 + 1e-6);
        for &v in &fine {
            assert!(v >= lo && v <= hi, "t={v} outside coarse range");
        }
    }

    // --- merge_samples ---

    #[test]
    fn merge_sorted() {
        let coarse = [0.1, 0.3, 0.5];
        let fine = [0.2, 0.4];
        let merged = merge_samples(&coarse, &fine);
        assert!(merged.windows(2).all(|w| w[0] <= w[1]));
    }

    #[test]
    fn merge_drops_exact_duplicates() {
        let merged = merge_samples(&[0.1, 0.2], &[0.2, 0.3]);
        assert_eq!(merged, vec![0.1, 0.2, 0.3]);
    }
}