/// Tolerance used twice by [`DepthComputer::compute_depth`]: norms below it are
/// treated as the origin, and norms are clamped to `1.0 - EPS` so `atanh` stays finite.
const EPS: f32 = 1e-5;

/// Converts points inside the unit ball into a scalar depth (hyperbolic
/// distance from the origin for a given negative curvature) and buckets that
/// depth into [`HierarchyLevel`]s.
#[derive(Debug, Clone)]
pub struct DepthComputer {
    /// Curvature of the space; expected to be negative (e.g. `-1.0`).
    curvature: f32,
    /// Depth cut-offs checked in order by `classify_level`; each is an
    /// exclusive upper bound for Root, High, Mid and Deep respectively.
    level_thresholds: [f32; 4],
}

impl DepthComputer {
    /// Thresholds used by [`Self::new`].
    const DEFAULT_THRESHOLDS: [f32; 4] = [0.5, 1.0, 2.0, 3.0];

    /// Depth that [`Self::normalized_depth`] maps to `1.0`; deeper points saturate.
    const NORMALIZATION_DEPTH: f32 = 5.0;

    /// Creates a computer with the default thresholds `[0.5, 1.0, 2.0, 3.0]`.
    ///
    /// `curvature` should be negative; zero or positive curvature makes the
    /// depth/radius conversions non-finite or meaningless.
    pub fn new(curvature: f32) -> Self {
        Self::with_thresholds(curvature, Self::DEFAULT_THRESHOLDS)
    }

    /// Creates a computer with custom level thresholds.
    ///
    /// The thresholds are compared in order by [`Self::classify_level`], so
    /// they should be ascending; otherwise some levels become unreachable.
    pub fn with_thresholds(curvature: f32, thresholds: [f32; 4]) -> Self {
        Self {
            curvature,
            level_thresholds: thresholds,
        }
    }

    /// `sqrt(-curvature)`: the scale factor shared by the depth <-> radius maps.
    ///
    /// NaN for positive curvature and a signed zero for zero curvature.
    fn sqrt_neg_curvature(&self) -> f32 {
        (-self.curvature).sqrt()
    }

    /// Depth of `point`: `2 * atanh(‖point‖) / sqrt(-curvature)`.
    ///
    /// Norms below `EPS` (including an empty slice) return exactly `0.0`.
    /// Norms at or beyond the unit sphere are clamped to `1.0 - EPS`, so the
    /// result stays finite for negative curvature.
    pub fn compute_depth(&self, point: &[f32]) -> f32 {
        let norm = point.iter().map(|x| x * x).sum::<f32>().sqrt();
        if norm < EPS {
            return 0.0;
        }
        // atanh(1) is infinite; keep boundary/outside points just inside the ball.
        let clamped_norm = norm.min(1.0 - EPS);
        // std `atanh` avoids forming (1 + x) / (1 - x), which in f32 loses
        // relative precision for small norms.
        2.0 * clamped_norm.atanh() / self.sqrt_neg_curvature()
    }

    /// [`Self::compute_depth`] divided by a fixed normalisation depth of `5.0`,
    /// capped at `1.0`.
    pub fn normalized_depth(&self, point: &[f32]) -> f32 {
        (self.compute_depth(point) / Self::NORMALIZATION_DEPTH).min(1.0)
    }

    /// Buckets `depth` using the configured thresholds.
    ///
    /// Each threshold is an exclusive upper bound, so a depth exactly equal to
    /// a threshold lands in the next level. Depths at or above the last
    /// threshold, and NaN (every comparison fails), map to `VeryDeep`.
    pub fn classify_level(&self, depth: f32) -> HierarchyLevel {
        let [root_max, high_max, mid_max, deep_max] = self.level_thresholds;
        if depth < root_max {
            HierarchyLevel::Root
        } else if depth < high_max {
            HierarchyLevel::High
        } else if depth < mid_max {
            HierarchyLevel::Mid
        } else if depth < deep_max {
            HierarchyLevel::Deep
        } else {
            HierarchyLevel::VeryDeep
        }
    }

    /// Inverse of [`Self::compute_depth`]: the Euclidean norm at which a point
    /// has depth `target_depth`, i.e. `tanh(target_depth * sqrt(-curvature) / 2)`.
    ///
    /// A negative `target_depth` yields a negative value (`tanh` is odd). For
    /// large depths the result may round to exactly `1.0` in f32.
    pub fn radius_for_depth(&self, target_depth: f32) -> f32 {
        (target_depth * self.sqrt_neg_curvature() / 2.0).tanh()
    }

    /// The curvature this computer was constructed with.
    pub fn curvature(&self) -> f32 {
        self.curvature
    }
}
/// Depth bucket, listed from shallowest (`Root`) to deepest (`VeryDeep`).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum HierarchyLevel {
    /// Level 0, named `"root"`.
    Root,
    /// Level 1, named `"high"`.
    High,
    /// Level 2, named `"mid"`.
    Mid,
    /// Level 3, named `"deep"`.
    Deep,
    /// Level 4, named `"very_deep"`.
    VeryDeep,
}

impl HierarchyLevel {
    /// All per-variant data in one exhaustive match:
    /// `(numeric level, weight multiplier, display name)`.
    fn profile(self) -> (usize, f32, &'static str) {
        match self {
            HierarchyLevel::Root => (0, 1.0, "root"),
            HierarchyLevel::High => (1, 1.2, "high"),
            HierarchyLevel::Mid => (2, 1.5, "mid"),
            HierarchyLevel::Deep => (3, 2.0, "deep"),
            HierarchyLevel::VeryDeep => (4, 3.0, "very_deep"),
        }
    }

    /// Numeric level, from `0` for `Root` up to `4` for `VeryDeep`.
    pub fn as_level(&self) -> usize {
        let (level, _, _) = self.profile();
        level
    }

    /// Multiplier attached to this level: `1.0` for `Root`, rising to `3.0`
    /// for `VeryDeep`.
    pub fn weight_multiplier(&self) -> f32 {
        let (_, weight, _) = self.profile();
        weight
    }

    /// Lower-case snake_case label; this is also what `Display` prints.
    pub fn name(&self) -> &'static str {
        let (_, _, label) = self.profile();
        label
    }
}

impl std::fmt::Display for HierarchyLevel {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str(self.name())
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Point on the first axis of a 4-dimensional space with Euclidean norm `r`.
    fn axis_point(r: f32) -> Vec<f32> {
        vec![r, 0.0, 0.0, 0.0]
    }

    #[test]
    fn test_depth_at_origin() {
        let computer = DepthComputer::new(-1.0);
        assert_eq!(computer.compute_depth(&axis_point(0.0)), 0.0);
        // An empty slice has zero norm and is treated as the origin.
        let empty: [f32; 0] = [];
        assert_eq!(computer.compute_depth(&empty), 0.0);
        // Norms below the internal epsilon snap to the origin.
        assert_eq!(computer.compute_depth(&axis_point(1e-6)), 0.0);
    }

    #[test]
    fn test_depth_increases_with_radius() {
        let computer = DepthComputer::new(-1.0);
        let depths: Vec<f32> = [0.1_f32, 0.5, 0.9]
            .iter()
            .map(|&r| computer.compute_depth(&axis_point(r)))
            .collect();
        assert!(depths.windows(2).all(|w| w[0] < w[1]));
    }

    #[test]
    fn test_depth_is_finite_at_and_beyond_boundary() {
        let computer = DepthComputer::new(-1.0);
        let inside = computer.compute_depth(&axis_point(0.9));
        for &r in &[1.0_f32, 2.0] {
            let depth = computer.compute_depth(&axis_point(r));
            assert!(depth.is_finite(), "radius {} gave {}", r, depth);
            assert!(depth > inside);
            assert_eq!(computer.normalized_depth(&axis_point(r)), 1.0);
        }
    }

    #[test]
    fn test_depth_scales_with_curvature() {
        let unit = DepthComputer::new(-1.0);
        let steep = DepthComputer::new(-4.0);
        let point = axis_point(0.5);
        // depth is proportional to 1 / sqrt(-curvature).
        let ratio = unit.compute_depth(&point) / steep.compute_depth(&point);
        assert!((ratio - 2.0).abs() < 1e-5);
    }

    #[test]
    fn test_hierarchy_levels() {
        let computer = DepthComputer::new(-1.0);
        let cases = [
            (0.3_f32, HierarchyLevel::Root),
            (0.7, HierarchyLevel::High),
            (1.5, HierarchyLevel::Mid),
            (2.5, HierarchyLevel::Deep),
            (4.0, HierarchyLevel::VeryDeep),
        ];
        for &(depth, expected) in &cases {
            assert_eq!(computer.classify_level(depth), expected, "depth {}", depth);
        }
    }

    #[test]
    fn test_thresholds_are_exclusive_upper_bounds() {
        let computer = DepthComputer::new(-1.0);
        assert_eq!(computer.classify_level(0.0), HierarchyLevel::Root);
        assert_eq!(computer.classify_level(0.5), HierarchyLevel::High);
        assert_eq!(computer.classify_level(1.0), HierarchyLevel::Mid);
        assert_eq!(computer.classify_level(2.0), HierarchyLevel::Deep);
        assert_eq!(computer.classify_level(3.0), HierarchyLevel::VeryDeep);
    }

    #[test]
    fn test_custom_thresholds() {
        let computer = DepthComputer::with_thresholds(-1.0, [1.0, 2.0, 3.0, 4.0]);
        assert_eq!(computer.classify_level(0.7), HierarchyLevel::Root);
        assert_eq!(computer.classify_level(1.5), HierarchyLevel::High);
        assert_eq!(computer.classify_level(2.5), HierarchyLevel::Mid);
        assert_eq!(computer.classify_level(3.5), HierarchyLevel::Deep);
        assert_eq!(computer.classify_level(4.0), HierarchyLevel::VeryDeep);
    }

    #[test]
    fn test_radius_for_depth() {
        for &curvature in &[-1.0_f32, -2.0] {
            let computer = DepthComputer::new(curvature);
            assert_eq!(computer.radius_for_depth(0.0), 0.0);
            for &target in &[0.1_f32, 0.5, 1.0, 2.0, 3.0] {
                let radius = computer.radius_for_depth(target);
                assert!(radius > 0.0 && radius < 1.0);
                let recovered = computer.compute_depth(&axis_point(radius));
                assert!(
                    (recovered - target).abs() < 1e-3,
                    "curvature {}, depth {}: recovered {}",
                    curvature,
                    target,
                    recovered
                );
            }
        }
    }

    #[test]
    fn test_normalized_depth() {
        let computer = DepthComputer::new(-1.0);
        let norm_origin = computer.normalized_depth(&axis_point(0.0));
        let norm_shallow = computer.normalized_depth(&axis_point(0.1));
        let norm_deep = computer.normalized_depth(&axis_point(0.95));
        assert_eq!(norm_origin, 0.0);
        assert!(norm_shallow < 0.2);
        assert!(norm_deep > 0.5);
        assert!(norm_shallow <= 1.0);
        assert!(norm_deep <= 1.0);
    }

    #[test]
    fn test_curvature_accessor() {
        assert_eq!(DepthComputer::new(-2.5).curvature(), -2.5);
        assert_eq!(
            DepthComputer::with_thresholds(-0.5, [0.5, 1.0, 2.0, 3.0]).curvature(),
            -0.5
        );
    }

    #[test]
    fn test_level_metadata() {
        let all = [
            HierarchyLevel::Root,
            HierarchyLevel::High,
            HierarchyLevel::Mid,
            HierarchyLevel::Deep,
            HierarchyLevel::VeryDeep,
        ];
        for (index, level) in all.iter().enumerate() {
            assert_eq!(level.as_level(), index);
            assert_eq!(level.to_string(), level.name());
        }
        assert!(all
            .windows(2)
            .all(|w| w[0].weight_multiplier() < w[1].weight_multiplier()));
        assert_eq!(HierarchyLevel::VeryDeep.to_string(), "very_deep");
    }
}