pub mod implicit;
pub mod mesh;
pub mod octree;
pub mod renderer;
pub use implicit::{FourierFeatures, GlobalSDF, LocalSDF};
pub use mesh::{MarchingCubes, Mesh, Triangle, Vertex};
pub use octree::{AABB, AdaptiveOctree, OctreeNode};
pub use renderer::{Camera, DifferentiableRenderer, RayHit, RenderOutput, SphereTracingConfig};
use axonml_nn::Parameter;
/// Configuration for [`Aegis3D`]: scene extent, octree/SDF sizing, rendering and loss weights.
#[derive(Debug, Clone)]
pub struct Aegis3DConfig {
/// Scene bounds; passed to `AdaptiveOctree::new` as the extent of the octree.
pub scene_bounds: AABB,
/// Maximum octree subdivision depth (passed to `AdaptiveOctree::new`).
pub max_depth: usize,
/// Copied into `octree.sdf_hidden_dim`; presumably the hidden width of each leaf SDF network.
pub sdf_hidden_dim: usize,
/// Copied into `octree.sdf_num_freq`; presumably the Fourier-feature frequency count.
pub num_frequencies: usize,
/// Depth argument passed to `AdaptiveOctree::init_from_depth` whenever a depth map is ingested.
pub init_depth: usize,
/// Sphere-tracing settings; a clone is handed to the `DifferentiableRenderer`.
pub render_config: SphereTracingConfig,
/// Resolution used for mesh extraction in `Aegis3D::reconstruct_from_depth`.
pub mesh_resolution: usize,
/// Weight applied to the eikonal loss term in `Aegis3D::optimize`.
pub lambda_eikonal: f32,
/// Smoothness-loss weight. NOTE(review): not read anywhere in this file.
pub lambda_smooth: f32,
}
impl Default for Aegis3DConfig {
    /// Default configuration: the cube `[-1, 1]^3`, octree depth 6 with initial
    /// subdivision depth 3, SDF hidden width 64 with 4 frequencies, mesh resolution 64,
    /// eikonal weight 0.1 and smoothness weight 0.01.
    fn default() -> Self {
        let scene_bounds = AABB::new([-1.0; 3], [1.0; 3]);
        let render_config = SphereTracingConfig::default();
        Self {
            scene_bounds,
            max_depth: 6,
            sdf_hidden_dim: 64,
            num_frequencies: 4,
            init_depth: 3,
            render_config,
            mesh_resolution: 64,
            lambda_eikonal: 0.1,
            lambda_smooth: 0.01,
        }
    }
}
/// A depth observation retained by `Aegis3D::add_view` for use in `Aegis3D::optimize`.
struct StoredView {
/// Copy of the observed depth map; compared against the rendered depth in `optimize`.
depth_map: Vec<f32>,
/// Camera the observation was taken from; used to re-render the model for comparison.
camera: Camera,
/// Always `Vec::new()` in this file. NOTE(review): presumably meant to record octree
/// nodes touched by this view; the leading underscore keeps the dead-code lint quiet.
_affected_nodes: Vec<usize>,
}
/// Depth-map-driven surface reconstructor built on an [`AdaptiveOctree`] of SDF networks.
pub struct Aegis3D {
/// Spatial subdivision of the scene; each leaf carries its own SDF network.
pub octree: AdaptiveOctree,
/// Renderer used to render the octree SDF from the stored cameras.
pub renderer: DifferentiableRenderer,
/// Configuration this instance was built with.
pub config: Aegis3DConfig,
/// Depth observations registered through `add_view`.
views: Vec<StoredView>,
}
impl Default for Aegis3D {
    /// Same as [`Aegis3D::new`]: an instance built from [`Aegis3DConfig::default`].
    fn default() -> Self {
        Self::with_config(Default::default())
    }
}
impl Aegis3D {
    /// Creates a reconstructor using [`Aegis3DConfig::default`].
    pub fn new() -> Self {
        Self::with_config(Aegis3DConfig::default())
    }

    /// Creates a reconstructor from an explicit configuration.
    ///
    /// Builds the octree from `config.scene_bounds` / `config.max_depth`, copies the SDF
    /// sizing (`sdf_hidden_dim`, `num_frequencies`) into it, and gives the renderer a clone
    /// of `config.render_config`. No views are stored initially.
    pub fn with_config(config: Aegis3DConfig) -> Self {
        let mut octree = AdaptiveOctree::new(config.scene_bounds, config.max_depth);
        octree.sdf_hidden_dim = config.sdf_hidden_dim;
        octree.sdf_num_freq = config.num_frequencies;
        Self {
            octree,
            renderer: DifferentiableRenderer::with_config(config.render_config.clone()),
            config,
            views: Vec::new(),
        }
    }

    /// Reconstructs a mesh from a single depth map in one shot.
    ///
    /// Back-projects `depth_map` through `camera` and initialises the octree from the
    /// resulting points (with `config.init_depth`). Then it extracts a mesh at
    /// `config.mesh_resolution`. Unlike [`Aegis3D::add_view`], the view is not stored.
    pub fn reconstruct_from_depth(&mut self, depth_map: &[f32], camera: &Camera) -> Mesh {
        let points = self.backproject_depth(depth_map, camera);
        self.octree.init_from_depth(&points, self.config.init_depth);
        self.octree.extract_mesh(self.config.mesh_resolution)
    }

    /// Registers a depth observation for later use by [`Aegis3D::optimize`].
    ///
    /// The depth map initialises the octree exactly as in
    /// [`Aegis3D::reconstruct_from_depth`]. A copy of it is then stored together with
    /// `camera`.
    pub fn add_view(&mut self, depth_map: &[f32], camera: Camera) {
        let points = self.backproject_depth(depth_map, &camera);
        self.octree.init_from_depth(&points, self.config.init_depth);
        self.views.push(StoredView {
            depth_map: depth_map.to_vec(),
            camera,
            // Not populated yet (see `StoredView`).
            _affected_nodes: Vec::new(),
        });
    }

    /// Runs `num_steps` passes over the reconstruction objective.
    ///
    /// Each pass works as follows:
    /// 1. Render the octree from every stored view.
    /// 2. Sum the masked mean-squared depth errors against the observed depth maps.
    /// 3. Add `config.lambda_eikonal` times the eikonal regulariser.
    ///
    /// NOTE(review): the parameter step is still a placeholder. Stepped values are derived
    /// from `learning_rate` and the sign of the total loss, then dropped. Nothing is written
    /// back to the octree parameters, so this currently leaves the model unchanged.
    pub fn optimize(&mut self, num_steps: usize, learning_rate: f32) {
        // Number of probe points used for the eikonal term on every step.
        const EIKONAL_SAMPLES: usize = 256;

        for _step in 0..num_steps {
            let mut total_loss = 0.0f32;
            for view in &self.views {
                let rendered = self.renderer.render(&self.octree, &view.camera);
                total_loss +=
                    compute_depth_loss(&rendered.depth_map, &view.depth_map, &rendered.hit_mask);
            }
            let eikonal_loss = self.compute_eikonal_loss(EIKONAL_SAMPLES);
            total_loss += self.config.lambda_eikonal * eikonal_loss;

            let params = self.octree.parameters();
            for param in &params {
                let var = param.variable();
                let data = var.data().to_vec();
                // Placeholder sign-based step; the result is intentionally unused (see NOTE).
                let _perturbed: Vec<f32> = data
                    .iter()
                    .map(|&v| v - learning_rate * total_loss.signum() * 0.001)
                    .collect();
            }
        }
    }

    /// Estimates the eikonal regulariser: the mean of `(|grad f| - 1)^2` over sample points.
    ///
    /// Sample placement is deterministic inside the root bounds:
    /// - `num_samples` points are generated from a parameter `t = i / num_samples`.
    /// - x sweeps linearly with `t`.
    /// - y and z use `t * 7` and `t * 13` wrapped into `[0, 1)`.
    ///
    /// The gradient comes from central differences of `query_sdf` with step `1e-3`.
    /// Returns `0.0` when `num_samples` is zero, instead of the NaN a 0/0 would produce.
    fn compute_eikonal_loss(&self, num_samples: usize) -> f32 {
        if num_samples == 0 {
            return 0.0;
        }
        let bounds = self.octree.root.bounds();
        let eps = 0.001;
        let mut loss = 0.0f32;
        for i in 0..num_samples {
            let t = i as f32 / num_samples as f32;
            let x = bounds.min[0] + t * (bounds.max[0] - bounds.min[0]);
            let y = bounds.min[1] + ((t * 7.0) % 1.0) * (bounds.max[1] - bounds.min[1]);
            let z = bounds.min[2] + ((t * 13.0) % 1.0) * (bounds.max[2] - bounds.min[2]);
            let dx =
                self.octree.query_sdf([x + eps, y, z]) - self.octree.query_sdf([x - eps, y, z]);
            let dy =
                self.octree.query_sdf([x, y + eps, z]) - self.octree.query_sdf([x, y - eps, z]);
            let dz =
                self.octree.query_sdf([x, y, z + eps]) - self.octree.query_sdf([x, y, z - eps]);
            // Each difference spans 2*eps, so |grad|^2 = (dx^2 + dy^2 + dz^2) / (2*eps)^2.
            let grad_mag = ((dx * dx + dy * dy + dz * dz) / (4.0 * eps * eps)).sqrt();
            loss += (grad_mag - 1.0).powi(2);
        }
        loss / num_samples as f32
    }

    /// Converts a row-major depth map into world-space points.
    ///
    /// - Pixel `i` maps to `(x, y) = (i % width, i / width)`.
    /// - A ray is cast through the pixel centre.
    /// - The point is placed at `camera.position + dir * depth`. `depth` is assumed to be
    ///   distance along `camera.ray_direction` (unit length presumed; TODO confirm).
    /// - Pixels are subsampled with a stride of `max(1, floor(width * height / 1000))`.
    /// - Only depths strictly inside `(0, renderer.config.max_distance)` are kept. NaN depths
    ///   are rejected.
    /// - A `depth_map` shorter than `width * height` is read only up to its length, so a
    ///   short buffer does not panic.
    fn backproject_depth(&self, depth_map: &[f32], camera: &Camera) -> Vec<[f32; 3]> {
        let w = camera.width;
        let h = camera.height;
        let num_pixels = (w * h).min(depth_map.len());
        let step = ((w * h) as f32 / 1000.0).max(1.0) as usize;
        let max_distance = self.renderer.config.max_distance;
        let mut points = Vec::new();
        for i in (0..num_pixels).step_by(step) {
            let depth = depth_map[i];
            // NaN fails every comparison, so it must be rejected explicitly.
            if depth.is_nan() || depth <= 0.0 || depth >= max_distance {
                continue;
            }
            let x = (i % w) as f32;
            let y = (i / w) as f32;
            let dir = camera.ray_direction(x + 0.5, y + 0.5);
            points.push([
                camera.position[0] + dir[0] * depth,
                camera.position[1] + dir[1] * depth,
                camera.position[2] + dir[2] * depth,
            ]);
        }
        points
    }

    /// Extracts a mesh from the current octree SDF at the given grid resolution.
    pub fn extract_mesh(&self, resolution: usize) -> Mesh {
        self.octree.extract_mesh(resolution)
    }

    /// Extracts a mesh with level-of-detail limited to `max_lod`.
    pub fn extract_mesh_lod(&self, resolution: usize, max_lod: usize) -> Mesh {
        self.octree.extract_mesh_lod(resolution, max_lod)
    }

    /// Refines the octree using `error_threshold` (delegates to `AdaptiveOctree::refine`).
    pub fn refine(&mut self, error_threshold: f32) {
        self.octree.refine(error_threshold);
    }

    /// All trainable parameters of the octree's SDF networks.
    pub fn parameters(&self) -> Vec<Parameter> {
        self.octree.parameters()
    }

    /// Number of SDF networks, which equals the number of octree leaves.
    pub fn num_sdf_networks(&self) -> usize {
        self.octree.num_leaves()
    }

    /// Total number of octree nodes (internal and leaf).
    pub fn num_nodes(&self) -> usize {
        self.octree.num_nodes()
    }
}
/// Mean squared depth error over the pixels whose `mask` value exceeds `0.5`.
///
/// Only the common prefix of the three slices is compared, so slices of mismatched
/// lengths (including a short `mask`) never cause an out-of-bounds panic. Returns `0.0`
/// when the mask selects no pixels.
fn compute_depth_loss(rendered: &[f32], observed: &[f32], mask: &[f32]) -> f32 {
    let mut sum = 0.0f32;
    // Counted as an integer: an f32 counter stops increasing past 2^24 pixels.
    let mut count = 0usize;
    for ((&r, &o), &m) in rendered.iter().zip(observed).zip(mask) {
        if m > 0.5 {
            let diff = r - o;
            sum += diff * diff;
            count += 1;
        }
    }
    if count > 0 { sum / count as f32 } else { 0.0 }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_aegis3d_creation() {
        let aegis = Aegis3D::new();
        // A fresh octree is a single root leaf carrying one SDF network.
        assert_eq!(aegis.num_sdf_networks(), 1);
        assert_eq!(aegis.num_nodes(), 1);
        assert!(!aegis.parameters().is_empty());
    }

    #[test]
    fn test_aegis3d_custom_config() {
        let config = Aegis3DConfig {
            scene_bounds: AABB::new([-2.0, -2.0, -2.0], [2.0, 2.0, 2.0]),
            max_depth: 4,
            mesh_resolution: 16,
            ..Default::default()
        };
        let aegis = Aegis3D::with_config(config);
        assert_eq!(aegis.config.max_depth, 4);
    }

    #[test]
    fn test_depth_backprojection() {
        let aegis = Aegis3D::new();
        let camera = Camera::look_at([0.0, 0.0, 3.0], [0.0, 0.0, 0.0], 16, 16, 60.0);
        let depth_map = vec![1.5; 16 * 16];
        let points = aegis.backproject_depth(&depth_map, &camera);
        assert!(!points.is_empty());
        for p in &points {
            let dist = (p[0].powi(2) + p[1].powi(2) + (p[2] - 3.0).powi(2)).sqrt();
            assert!(
                (dist - 1.5).abs() < 0.5,
                "Point distance {dist} should be ~1.5"
            );
        }
    }

    #[test]
    fn test_add_view() {
        let mut aegis = Aegis3D::with_config(Aegis3DConfig {
            scene_bounds: AABB::new([-3.0, -3.0, -3.0], [3.0, 3.0, 3.0]),
            init_depth: 2,
            ..Default::default()
        });
        let camera = Camera::look_at([0.0, 0.0, 3.0], [0.0, 0.0, 0.0], 16, 16, 60.0);
        let depth_map = vec![1.5; 16 * 16];
        aegis.add_view(&depth_map, camera);
        assert_eq!(aegis.views.len(), 1);
        assert!(
            aegis.num_sdf_networks() > 1,
            "Should have subdivided octree"
        );
    }

    #[test]
    fn test_reconstruct_from_depth() {
        let mut aegis = Aegis3D::with_config(Aegis3DConfig {
            mesh_resolution: 8,
            init_depth: 2,
            ..Default::default()
        });
        let camera = Camera::look_at([0.0, 0.0, 3.0], [0.0, 0.0, 0.0], 16, 16, 60.0);
        let depth_map = vec![1.5; 16 * 16];
        let mesh = aegis.reconstruct_from_depth(&depth_map, &camera);
        // Smoke test: reconstruction and vertex counting must complete without panicking.
        // No particular vertex count is pinned (the former `>= 0` check was vacuous).
        let _ = mesh.num_vertices();
    }

    #[test]
    fn test_mesh_export_pipeline() {
        let mc = MarchingCubes::new(12);
        let mesh = mc.extract(
            |x, y, z| {
                let cx = x - 0.5;
                let cy = y - 0.5;
                let cz = z - 0.5;
                (cx * cx + cy * cy + cz * cz).sqrt() - 0.3
            },
            [0.0, 0.0, 0.0],
            [1.0, 1.0, 1.0],
        );
        let obj = mesh.to_obj();
        assert!(obj.contains("v "));
        assert!(obj.contains("f "));
        let stl = mesh.to_stl_binary();
        // 84 is presumably the binary-STL header (80 bytes) plus the u32 triangle count;
        // anything longer means at least one facet was written.
        assert!(stl.len() > 84);
    }

    #[test]
    fn test_eikonal_loss() {
        let aegis = Aegis3D::new();
        let loss = aegis.compute_eikonal_loss(64);
        assert!(loss.is_finite());
        assert!(loss >= 0.0);
    }

    #[test]
    fn test_lod_mesh_extraction() {
        let mut aegis = Aegis3D::with_config(Aegis3DConfig {
            mesh_resolution: 8,
            init_depth: 2,
            ..Default::default()
        });
        let camera = Camera::look_at([0.0, 0.0, 3.0], [0.0, 0.0, 0.0], 8, 8, 60.0);
        aegis.add_view(&[1.5; 64], camera);
        // Smoke test: LOD extraction must succeed at both a coarse and a fine LOD cap.
        let _ = aegis.extract_mesh_lod(8, 1).num_vertices();
        let _ = aegis.extract_mesh_lod(8, 4).num_vertices();
    }

    #[test]
    fn test_depth_loss() {
        let rendered = vec![1.0, 2.0, 3.0, 4.0];
        let observed = vec![1.1, 2.1, 3.1, 4.1];
        // Index 2 is masked out; the remaining three each differ by 0.1.
        let mask = vec![1.0, 1.0, 0.0, 1.0];
        let loss = compute_depth_loss(&rendered, &observed, &mask);
        assert!((loss - 0.01).abs() < 0.001);
    }
}