use threecrate_core::{PointCloud, Result, Point3f, Vector3f, NormalPoint3f, Error};
use nalgebra::Matrix3;
use rayon::prelude::*;
use std::collections::BinaryHeap;
use std::cmp::Ordering;
/// A candidate in a k-nearest-neighbor search: the index of a point in the
/// input slice together with its squared distance to the query point.
///
/// Equality and ordering are defined by `distance` alone, in natural ascending
/// order. As a result, a `BinaryHeap<Neighbor>` (a max-heap) keeps the
/// *farthest* retained candidate on top. That is the element to evict when a
/// closer candidate shows up.
#[derive(Debug, Clone)]
struct Neighbor {
    index: usize,
    distance: f32,
}

impl PartialEq for Neighbor {
    fn eq(&self, other: &Self) -> bool {
        // Defined through `cmp` so `Eq` stays consistent with `Ord`.
        self.cmp(other) == Ordering::Equal
    }
}

impl Eq for Neighbor {}

impl PartialOrd for Neighbor {
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        Some(self.cmp(other))
    }
}

impl Ord for Neighbor {
    fn cmp(&self, other: &Self) -> Ordering {
        // Ascending by distance: larger distance ranks higher, so the heap top
        // is the farthest neighbor. `total_cmp` is a true total order even for
        // NaN, which `BinaryHeap` needs to keep its invariants.
        self.distance.total_cmp(&other.distance)
    }
}
fn find_k_nearest_neighbors(points: &[Point3f], query_idx: usize, k: usize) -> Vec<usize> {
let query = &points[query_idx];
let mut heap = BinaryHeap::with_capacity(k + 1);
for (i, point) in points.iter().enumerate() {
if i == query_idx {
continue; }
let distance = (point - query).magnitude_squared();
let neighbor = Neighbor { index: i, distance };
if heap.len() < k {
heap.push(neighbor);
} else if let Some(farthest) = heap.peek() {
if neighbor.distance < farthest.distance {
heap.pop();
heap.push(neighbor);
}
}
}
heap.into_iter().map(|n| n.index).collect()
}
/// Estimates a surface normal for the neighborhood `points[indices]` by
/// principal component analysis.
///
/// The function builds the covariance matrix of the selected points about
/// their centroid, dividing by the number of points. It returns the
/// eigenvector belonging to the smallest eigenvalue, i.e. the direction of
/// least variance. With fewer than three indices it returns `+Z` as a fallback.
///
/// The sign of the result is whatever the eigen-decomposition yields; no
/// orientation is enforced.
///
/// # Panics
/// Panics if any entry of `indices` is out of bounds for `points`.
fn compute_normal_pca(points: &[Point3f], indices: &[usize]) -> Vector3f {
    if indices.len() < 3 {
        return Vector3f::new(0.0, 0.0, 1.0);
    }
    let count = indices.len() as f32;

    // Centroid of the neighborhood.
    let mut mean = Point3f::origin();
    indices.iter().for_each(|&idx| mean += points[idx].coords);
    mean /= count;

    // Covariance: average outer product of the centered points.
    let mut cov = Matrix3::zeros();
    for d in indices.iter().map(|&idx| points[idx] - mean) {
        cov += d * d.transpose();
    }
    cov /= count;

    let decomposition = cov.symmetric_eigen();
    // Index of the smallest eigenvalue; ties resolve to the lowest index.
    let smallest = (1..3usize).fold(0usize, |best, i| {
        if decomposition.eigenvalues[i] < decomposition.eigenvalues[best] {
            i
        } else {
            best
        }
    });
    decomposition.eigenvectors.column(smallest).into()
}
/// Estimates a surface normal for every point of `cloud` and returns a new
/// cloud of [`NormalPoint3f`] in the same order as the input.
///
/// For each point, the neighborhood is the point itself plus the indices
/// returned by `find_k_nearest_neighbors` (a brute-force scan over the whole
/// cloud). That neighborhood is fitted with PCA via `compute_normal_pca`.
/// Points are processed in parallel with rayon. The sign of each normal is
/// not made consistent across the cloud.
///
/// # Errors
/// Returns [`Error::InvalidData`] if `k < 3`. The emptiness check runs first,
/// so an empty `cloud` yields an empty result regardless of `k`.
pub fn estimate_normals(cloud: &PointCloud<Point3f>, k: usize) -> Result<PointCloud<NormalPoint3f>> {
    if cloud.is_empty() {
        return Ok(PointCloud::new());
    }
    if k < 3 {
        return Err(Error::InvalidData("k must be at least 3".to_string()));
    }

    let points = &cloud.points;

    // Per-point work: the point itself followed by its neighbors, then PCA.
    let normal_at = |i: usize| {
        let neighborhood: Vec<usize> = std::iter::once(i)
            .chain(find_k_nearest_neighbors(points, i, k))
            .collect();
        NormalPoint3f {
            position: points[i],
            normal: compute_normal_pca(points, &neighborhood),
        }
    };

    let normals: Vec<NormalPoint3f> = (0..points.len()).into_par_iter().map(normal_at).collect();
    Ok(PointCloud::from_points(normals))
}
#[deprecated(note = "Use estimate_normals instead which returns a new point cloud")]
pub fn estimate_normals_inplace(_cloud: &mut PointCloud<Point3f>, k: usize) -> Result<()> {
let _ = k;
Err(Error::Unsupported("Use estimate_normals instead".to_string()))
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a cloud containing the given coordinates, in order.
    fn cloud_from(coords: &[[f32; 3]]) -> PointCloud<Point3f> {
        let mut cloud = PointCloud::new();
        for &[x, y, z] in coords {
            cloud.push(Point3f::new(x, y, z));
        }
        cloud
    }

    #[test]
    fn test_estimate_normals_simple() {
        // Five coplanar points in z = 0: every fitted normal should be along ±Z.
        let cloud = cloud_from(&[
            [0.0, 0.0, 0.0],
            [1.0, 0.0, 0.0],
            [0.0, 1.0, 0.0],
            [1.0, 1.0, 0.0],
            [0.5, 0.5, 0.0],
        ]);
        let result = estimate_normals(&cloud, 3).unwrap();
        assert_eq!(result.len(), 5);
        result.iter().for_each(|point| {
            let normal = point.normal;
            assert!(normal.z.abs() > 0.8, "Normal should be primarily in Z direction: {:?}", normal);
        });
    }

    #[test]
    fn test_estimate_normals_empty() {
        let result = estimate_normals(&PointCloud::<Point3f>::new(), 5).unwrap();
        assert!(result.is_empty());
    }

    #[test]
    fn test_estimate_normals_insufficient_k() {
        let cloud = cloud_from(&[[0.0, 0.0, 0.0]]);
        assert!(estimate_normals(&cloud, 2).is_err());
    }
}