use crate::error::{NerfError, NerfResult};
use crate::handle::LcgRng;
// Large primes used by `hash_coord` to decorrelate the y and z lattice axes
// before XOR-folding (the x axis keeps an implicit multiplier of 1).
const PI2: u64 = 2_654_435_761;
const PI3: u64 = 805_459_861;
/// Configuration for a multi-resolution hash grid encoder.
///
/// All fields are validated by `HashGrid::new`.
#[derive(Debug, Clone)]
pub struct HashGridConfig {
    /// Number of resolution levels (must be > 0).
    pub n_levels: usize,
    /// Feature channels stored per grid vertex at each level (must be > 0).
    pub n_features_per_level: usize,
    /// Base-2 log of the per-level hash table size; must be in 1..=32.
    pub log2_hashmap_size: usize,
    /// Grid resolution of the coarsest level (must be > 0).
    pub base_resolution: usize,
    /// Grid resolution of the finest level (must be >= `base_resolution`).
    pub max_resolution: usize,
}
/// Multi-resolution hash grid: per-level feature tables that are trilinearly
/// interpolated at query time.
#[derive(Debug, Clone)]
pub struct HashGrid {
    /// The validated configuration this grid was built with.
    pub config: HashGridConfig,
    /// Flat feature storage, level-major:
    /// `n_levels * 2^log2_hashmap_size * n_features_per_level` entries.
    pub data: Vec<f32>,
    /// Grid resolution for each level, geometrically spaced from
    /// `base_resolution` up to (approximately) `max_resolution`.
    level_resolutions: Vec<usize>,
}
impl HashGrid {
    /// Builds a multi-resolution hash grid, validating `cfg` and initialising
    /// every feature entry to a small random value in [-1e-4, 1e-4).
    ///
    /// # Errors
    /// Returns `NerfError::InvalidHashConfig` when a configuration field is
    /// out of range or when the requested table size would overflow `usize`.
    pub fn new(cfg: HashGridConfig, rng: &mut LcgRng) -> NerfResult<Self> {
        if cfg.n_levels == 0 {
            return Err(NerfError::InvalidHashConfig {
                msg: "n_levels must be > 0".into(),
            });
        }
        if cfg.n_features_per_level == 0 {
            return Err(NerfError::InvalidHashConfig {
                msg: "n_features_per_level must be > 0".into(),
            });
        }
        if cfg.log2_hashmap_size == 0 || cfg.log2_hashmap_size > 32 {
            return Err(NerfError::InvalidHashConfig {
                msg: "log2_hashmap_size must be in 1..=32".into(),
            });
        }
        if cfg.base_resolution == 0 {
            return Err(NerfError::InvalidHashConfig {
                msg: "base_resolution must be > 0".into(),
            });
        }
        if cfg.max_resolution < cfg.base_resolution {
            return Err(NerfError::InvalidHashConfig {
                msg: "max_resolution must be >= base_resolution".into(),
            });
        }
        // Entries per level. `checked_shl` guards 32-bit targets, where a
        // plain `1 << 32` would overflow usize (debug panic / release wrap).
        let t = 1_usize
            .checked_shl(cfg.log2_hashmap_size as u32)
            .ok_or_else(|| NerfError::InvalidHashConfig {
                msg: "log2_hashmap_size too large for this platform".into(),
            })?;
        let level_resolutions = if cfg.n_levels == 1 {
            vec![cfg.base_resolution]
        } else {
            // Geometric progression of resolutions from base_resolution to
            // max_resolution: n_l = base * exp(b * l).
            let b = ((cfg.max_resolution as f64) / (cfg.base_resolution as f64)).ln()
                / (cfg.n_levels - 1) as f64;
            (0..cfg.n_levels)
                .map(|l| {
                    let n_l =
                        (cfg.base_resolution as f64 * (b * l as f64).exp()).floor() as usize;
                    // floor() can round a coarse level down to 0; keep >= 1.
                    n_l.max(1)
                })
                .collect()
        };
        // Checked total size so a pathological config yields an error instead
        // of a debug panic or a silent release-mode wrap (under-allocation).
        let total = cfg
            .n_levels
            .checked_mul(t)
            .and_then(|n| n.checked_mul(cfg.n_features_per_level))
            .ok_or_else(|| NerfError::InvalidHashConfig {
                msg: "hash grid size overflows usize".into(),
            })?;
        let mut data = vec![0.0_f32; total];
        for v in data.iter_mut() {
            *v = rng.next_f32_range(-0.0001, 0.0001);
        }
        Ok(Self {
            config: cfg,
            data,
            level_resolutions,
        })
    }

    /// Length of the feature vector produced per queried point
    /// (`n_levels * n_features_per_level`).
    #[must_use]
    pub fn output_dim(&self) -> usize {
        self.config.n_levels * self.config.n_features_per_level
    }

    /// Encodes one 3-D point (each component clamped to [0, 1]) into its
    /// concatenated per-level features, trilinearly interpolating the eight
    /// hashed grid vertices surrounding the point at every level.
    pub fn query(&self, xyz: [f32; 3]) -> NerfResult<Vec<f32>> {
        let mut out = vec![0.0_f32; self.output_dim()];
        self.query_into(xyz, &mut out);
        Ok(out)
    }

    /// Accumulates the interpolated features for `xyz` into `out`, which must
    /// be zero-filled and exactly `output_dim()` long.
    fn query_into(&self, xyz: [f32; 3], out: &mut [f32]) {
        // The shift is safe: `new` validated log2_hashmap_size and that this
        // shift fits in usize on the current platform.
        let t = 1_usize << self.config.log2_hashmap_size;
        let f = self.config.n_features_per_level;
        for (level, &n_l) in self.level_resolutions.iter().enumerate() {
            // Scale the point into this level's grid, then split into the
            // integer cell corner plus the fractional offset inside the cell.
            let sx = xyz[0].clamp(0.0, 1.0) * (n_l as f32);
            let sy = xyz[1].clamp(0.0, 1.0) * (n_l as f32);
            let sz = xyz[2].clamp(0.0, 1.0) * (n_l as f32);
            let ix = sx.floor() as i64;
            let iy = sy.floor() as i64;
            let iz = sz.floor() as i64;
            let fx = sx - ix as f32;
            let fy = sy - iy as f32;
            let fz = sz - iz as f32;
            let level_offset = level * t * f;
            let out_base = level * f;
            // Blend the eight cell corners, each weighted by its trilinear
            // interpolation coefficient.
            for cx in 0_u8..=1 {
                for cy in 0_u8..=1 {
                    for cz in 0_u8..=1 {
                        let bucket = hash_coord(
                            ix + i64::from(cx),
                            iy + i64::from(cy),
                            iz + i64::from(cz),
                            t,
                        );
                        let w = trilinear_weight(fx, fy, fz, cx, cy, cz);
                        let base = level_offset + bucket * f;
                        for feat in 0..f {
                            out[out_base + feat] += w * self.data[base + feat];
                        }
                    }
                }
            }
        }
    }

    /// Encodes `n` points stored as a flat `[x0, y0, z0, x1, ...]` slice,
    /// returning `n * output_dim()` features in point-major order.
    ///
    /// # Errors
    /// Returns `NerfError::DimensionMismatch` when `xyz_batch.len() != n * 3`.
    pub fn query_batch(&self, xyz_batch: &[f32], n: usize) -> NerfResult<Vec<f32>> {
        if xyz_batch.len() != n * 3 {
            return Err(NerfError::DimensionMismatch {
                expected: n * 3,
                got: xyz_batch.len(),
            });
        }
        let out_dim = self.output_dim();
        let mut out = vec![0.0_f32; n * out_dim];
        // Write each point's features straight into its output slot; this
        // avoids the per-point Vec allocation + copy of calling `query` in a
        // loop.
        for (pt, chunk) in xyz_batch.chunks_exact(3).zip(out.chunks_mut(out_dim)) {
            self.query_into([pt[0], pt[1], pt[2]], chunk);
        }
        Ok(out)
    }
}
/// Hashes an integer lattice coordinate into one of `t` buckets.
///
/// The y and z components are decorrelated with the module-level primes
/// `PI2`/`PI3` (the x component keeps an implicit multiplier of 1), the
/// results are XOR-folded, and the hash is reduced modulo the table size.
/// Negative coordinates wrap deterministically via the `i64 -> u64` cast.
#[inline]
fn hash_coord(xi: i64, yi: i64, zi: i64, t: usize) -> usize {
    let mixed = (xi as u64) ^ (yi as u64).wrapping_mul(PI2) ^ (zi as u64).wrapping_mul(PI3);
    mixed as usize % t
}
/// Trilinear interpolation weight of the cell corner `(cx, cy, cz)` (each
/// component 0 or 1) for a point at fractional offset `(fx, fy, fz)` inside
/// the cell. Over all eight corners the weights sum to 1.
#[inline]
fn trilinear_weight(fx: f32, fy: f32, fz: f32, cx: u8, cy: u8, cz: u8) -> f32 {
    // Per axis: the corner at 1 is weighted by the fractional offset, the
    // corner at 0 by its complement; the three axis weights multiply.
    [(fx, cx), (fy, cy), (fz, cz)]
        .iter()
        .map(|&(frac, bit)| if bit == 1 { frac } else { 1.0 - frac })
        .product()
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds the small 4-level grid shared by the tests below.
    fn make_grid(seed: u64) -> HashGrid {
        let cfg = HashGridConfig {
            n_levels: 4,
            n_features_per_level: 2,
            log2_hashmap_size: 8,
            base_resolution: 4,
            max_resolution: 32,
        };
        let mut rng = LcgRng::new(seed);
        HashGrid::new(cfg, &mut rng).unwrap()
    }

    #[test]
    fn query_output_shape() {
        let grid = make_grid(1);
        let features = grid.query([0.5, 0.5, 0.5]).unwrap();
        assert_eq!(features.len(), grid.output_dim());
    }

    #[test]
    fn batch_output_shape() {
        let grid = make_grid(2);
        // Five points whose x/y/z components are all 0.0, 0.2, ..., 0.8.
        let mut pts: Vec<f32> = Vec::with_capacity(15);
        for i in 0..5 {
            let c = i as f32 * 0.2;
            pts.extend_from_slice(&[c, c, c]);
        }
        let out = grid.query_batch(&pts, 5).unwrap();
        assert_eq!(out.len(), 5 * grid.output_dim());
    }

    #[test]
    fn hash_coord_deterministic() {
        let first = hash_coord(1, 2, 3, 256);
        let second = hash_coord(1, 2, 3, 256);
        assert_eq!(first, second);
    }

    #[test]
    fn trilinear_weights_sum_to_one() {
        let (fx, fy, fz) = (0.3, 0.7, 0.1);
        // Enumerate the eight corners via the bits of 0..8 (cx is the high
        // bit, cz the low bit, matching a cx/cy/cz nested-loop order).
        let sum: f32 = (0_u8..8)
            .map(|c| trilinear_weight(fx, fy, fz, (c >> 2) & 1, (c >> 1) & 1, c & 1))
            .sum();
        assert!((sum - 1.0).abs() < 1e-6, "weights sum={sum}");
    }
}