use threecrate_core::{PointCloud, Result, Point3f, Vector3f, NormalPoint3f, Error};
use nalgebra::Matrix3;
use rayon::prelude::*;
use std::collections::BinaryHeap;
use std::cmp::Ordering;
/// Parameters controlling point-cloud normal estimation.
///
/// Consumed by `estimate_normals_with_config`; `estimate_normals` and
/// `estimate_normals_radius` build one of these internally.
#[derive(Debug, Clone)]
pub struct NormalEstimationConfig {
/// Number of nearest neighbours used to fit each normal when `radius` is
/// `None`. When `radius` is set this also acts as the minimum neighbourhood
/// size: if the radius search returns fewer points, the `k_neighbors`
/// nearest points are used instead. Must be at least 3.
pub k_neighbors: usize,
/// Optional search radius. When `Some(r)`, a point's neighbourhood is every
/// other point whose squared distance to it is at most `r * r`.
pub radius: Option<f32>,
/// Whether each normal is flipped, if needed, so that it faces the viewpoint.
pub consistent_orientation: bool,
/// Point that normals are oriented towards when `consistent_orientation` is
/// set. When `None`, the centre of the cloud's axis-aligned bounding box,
/// offset along +Z by the box diagonal, is used.
pub viewpoint: Option<Point3f>,
}
impl Default for NormalEstimationConfig {
fn default() -> Self {
Self {
k_neighbors: 10,
radius: None,
consistent_orientation: true,
viewpoint: None,
}
}
}
/// A neighbour candidate: the index of a point in the searched slice together
/// with its squared distance to the query point.
///
/// Candidates compare by `distance` alone, in ascending order, using
/// [`f32::total_cmp`] so the ordering is total even if a distance is NaN.
/// As a result a `BinaryHeap<Neighbor>` is a max-heap whose top is the
/// *farthest* retained candidate. A bounded k-nearest search needs exactly
/// this so it can evict its worst candidate when a closer point appears.
#[derive(Debug, Clone)]
struct Neighbor {
    /// Index of the candidate point in the searched slice.
    index: usize,
    /// Squared Euclidean distance from the query point.
    distance: f32,
}

impl PartialEq for Neighbor {
    // Defined through `cmp` so equality agrees with the total order required
    // by `Eq`/`Ord` (plain `f32 ==` is not reflexive for NaN).
    fn eq(&self, other: &Self) -> bool {
        self.cmp(other) == Ordering::Equal
    }
}

impl Eq for Neighbor {}

impl PartialOrd for Neighbor {
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        Some(self.cmp(other))
    }
}

impl Ord for Neighbor {
    fn cmp(&self, other: &Self) -> Ordering {
        // Ascending by distance: do not reverse this. k-NN search treats
        // `BinaryHeap::peek()` as the farthest candidate, and a reversed
        // order would make it evict the closest one instead.
        self.distance.total_cmp(&other.distance)
    }
}
fn find_k_nearest_neighbors(points: &[Point3f], query_idx: usize, k: usize) -> Vec<usize> {
let query = &points[query_idx];
let mut heap = BinaryHeap::with_capacity(k + 1);
for (i, point) in points.iter().enumerate() {
if i == query_idx {
continue; }
let distance = (point - query).magnitude_squared();
let neighbor = Neighbor { index: i, distance };
if heap.len() < k {
heap.push(neighbor);
} else if let Some(farthest) = heap.peek() {
if neighbor.distance < farthest.distance {
heap.pop();
heap.push(neighbor);
}
}
}
heap.into_iter().map(|n| n.index).collect()
}
/// Returns, in ascending index order, the indices of every point other than
/// `points[query_idx]` whose squared distance to it is at most
/// `radius * radius`. Brute force: every point is examined.
fn find_radius_neighbors(points: &[Point3f], query_idx: usize, radius: f32) -> Vec<usize> {
    let query = &points[query_idx];
    let limit = radius * radius;

    let mut within = Vec::new();
    for (idx, candidate) in points.iter().enumerate() {
        if idx != query_idx && (candidate - query).magnitude_squared() <= limit {
            within.push(idx);
        }
    }
    within
}
/// Gathers the neighbourhood of `points[query_idx]`. This is a radius search
/// when `config.radius` is set, otherwise the `config.k_neighbors` nearest
/// points. The query point itself is never included.
fn find_neighbors(points: &[Point3f], query_idx: usize, config: &NormalEstimationConfig) -> Vec<usize> {
    match config.radius {
        Some(r) => find_radius_neighbors(points, query_idx, r),
        None => find_k_nearest_neighbors(points, query_idx, config.k_neighbors),
    }
}
/// Estimates the surface normal of the points selected by `indices` using PCA.
///
/// The result is the eigenvector of the neighbourhood covariance matrix with
/// the smallest eigenvalue (the direction of least spread), rescaled to unit
/// length. Its sign is arbitrary.
///
/// Falls back to `+Z` in two cases: fewer than three indices are supplied, or
/// the selected eigenvector's length is not above `1e-6`.
fn compute_normal_pca(points: &[Point3f], indices: &[usize]) -> Vector3f {
    let up = Vector3f::new(0.0, 0.0, 1.0);
    if indices.len() < 3 {
        return up;
    }
    let count = indices.len() as f32;

    // Mean position of the neighbourhood.
    let centroid = indices
        .iter()
        .fold(Point3f::origin(), |acc, &idx| acc + points[idx].coords)
        / count;

    // Covariance (scatter averaged over the neighbourhood).
    let mut scatter = Matrix3::zeros();
    for offset in indices.iter().map(|&idx| points[idx] - centroid) {
        scatter += offset * offset.transpose();
    }
    scatter /= count;

    // `symmetric_eigen` does not sort its eigenvalues, so locate the smallest
    // one explicitly (first index wins on ties).
    let decomposition = scatter.symmetric_eigen();
    let weakest = (1..3).fold(0, |best, i| {
        if decomposition.eigenvalues[i] < decomposition.eigenvalues[best] {
            i
        } else {
            best
        }
    });

    let candidate: Vector3f = decomposition.eigenvectors.column(weakest).into();
    let length = candidate.magnitude();
    if length > 1e-6 {
        candidate / length
    } else {
        up
    }
}
/// Flips `normal` if necessary so that it faces `viewpoint` as seen from `point`.
///
/// The normal is negated when its dot product with the direction from `point`
/// to `viewpoint` is negative. Otherwise it is returned unchanged, including
/// when the two are perpendicular or `viewpoint == point`.
fn orient_normal_towards_viewpoint(normal: Vector3f, point: Point3f, viewpoint: Point3f) -> Vector3f {
    // Only the sign of the dot product matters, so the direction to the
    // viewpoint does not need normalising. Skipping `normalize()` saves a
    // sqrt/divide per point and avoids a 0/0 NaN when `point == viewpoint`.
    if normal.dot(&(viewpoint - point)) < 0.0 {
        -normal
    } else {
        normal
    }
}
pub fn estimate_normals(cloud: &PointCloud<Point3f>, k: usize) -> Result<PointCloud<NormalPoint3f>> {
let config = NormalEstimationConfig {
k_neighbors: k,
..Default::default()
};
estimate_normals_with_config(cloud, &config)
}
pub fn estimate_normals_with_config(
cloud: &PointCloud<Point3f>,
config: &NormalEstimationConfig
) -> Result<PointCloud<NormalPoint3f>> {
if cloud.is_empty() {
return Ok(PointCloud::new());
}
if config.k_neighbors < 3 {
return Err(Error::InvalidData("k_neighbors must be at least 3".to_string()));
}
let points = &cloud.points;
let viewpoint = config.viewpoint.unwrap_or_else(|| {
let mut min_x = points[0].x;
let mut min_y = points[0].y;
let mut min_z = points[0].z;
let mut max_x = points[0].x;
let mut max_y = points[0].y;
let mut max_z = points[0].z;
for point in points {
min_x = min_x.min(point.x);
min_y = min_y.min(point.y);
min_z = min_z.min(point.z);
max_x = max_x.max(point.x);
max_y = max_y.max(point.y);
max_z = max_z.max(point.z);
}
let center = Point3f::new(
(min_x + max_x) / 2.0,
(min_y + max_y) / 2.0,
(min_z + max_z) / 2.0,
);
let extent = ((max_x - min_x).powi(2) + (max_y - min_y).powi(2) + (max_z - min_z).powi(2)).sqrt();
center + Vector3f::new(0.0, 0.0, extent)
});
let normals: Vec<NormalPoint3f> = (0..points.len())
.into_par_iter()
.map(|i| {
let neighbors = find_neighbors(points, i, config);
let mut neighborhood = neighbors;
if config.radius.is_some() && neighborhood.len() < config.k_neighbors {
neighborhood = find_k_nearest_neighbors(points, i, config.k_neighbors);
}
if neighborhood.len() < 3 {
neighborhood = find_k_nearest_neighbors(points, i, config.k_neighbors.max(5));
}
let mut normal = compute_normal_pca(points, &neighborhood);
if config.consistent_orientation {
normal = orient_normal_towards_viewpoint(normal, points[i], viewpoint);
}
NormalPoint3f {
position: points[i],
normal,
}
})
.collect();
Ok(PointCloud::from_points(normals))
}
/// Estimates normals from the neighbours within `radius` of each point.
///
/// If a radius search finds fewer than 10 points, the 10 nearest points are
/// used instead. When `consistent_orientation` is set, normals face an
/// automatically chosen viewpoint. See `estimate_normals_with_config` for the
/// full behaviour and error conditions.
pub fn estimate_normals_radius(
    cloud: &PointCloud<Point3f>,
    radius: f32,
    consistent_orientation: bool,
) -> Result<PointCloud<NormalPoint3f>> {
    let config = NormalEstimationConfig {
        radius: Some(radius),
        consistent_orientation,
        k_neighbors: 10,
        viewpoint: None,
    };
    estimate_normals_with_config(cloud, &config)
}
#[deprecated(note = "Use estimate_normals instead which returns a new point cloud")]
pub fn estimate_normals_inplace(_cloud: &mut PointCloud<Point3f>, k: usize) -> Result<()> {
let _ = k;
Err(Error::Unsupported("Use estimate_normals instead".to_string()))
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_default_config() {
        let config = NormalEstimationConfig::default();
        assert_eq!(config.k_neighbors, 10);
        assert!(config.radius.is_none());
        assert!(config.consistent_orientation);
        assert!(config.viewpoint.is_none());
    }

    #[test]
    fn test_compute_normal_pca_plane() {
        let points = vec![
            Point3f::new(0.0, 0.0, 0.0),
            Point3f::new(1.0, 0.0, 0.0),
            Point3f::new(0.0, 1.0, 0.0),
            Point3f::new(1.0, 1.0, 0.0),
        ];
        let normal = compute_normal_pca(&points, &[0, 1, 2, 3]);
        assert!((normal.magnitude() - 1.0).abs() < 1e-5, "Normal should be unit length: {:?}", normal);
        assert!(normal.z.abs() > 0.99, "Planar points should give a Z normal: {:?}", normal);
    }

    #[test]
    fn test_compute_normal_pca_too_few_points() {
        let points = vec![Point3f::new(0.0, 0.0, 0.0), Point3f::new(1.0, 0.0, 0.0)];
        let normal = compute_normal_pca(&points, &[0, 1]);
        assert_eq!(normal, Vector3f::new(0.0, 0.0, 1.0));
    }

    #[test]
    fn test_orient_normal_towards_viewpoint() {
        let point = Point3f::new(0.0, 0.0, 0.0);
        let viewpoint = Point3f::new(0.0, 0.0, 5.0);
        let up = Vector3f::new(0.0, 0.0, 1.0);
        assert_eq!(orient_normal_towards_viewpoint(up, point, viewpoint), up);
        assert_eq!(orient_normal_towards_viewpoint(-up, point, viewpoint), up);
    }

    #[test]
    fn test_estimate_normals_simple() {
        let mut cloud = PointCloud::new();
        cloud.push(Point3f::new(0.0, 0.0, 0.0));
        cloud.push(Point3f::new(1.0, 0.0, 0.0));
        cloud.push(Point3f::new(0.0, 1.0, 0.0));
        cloud.push(Point3f::new(1.0, 1.0, 0.0));
        cloud.push(Point3f::new(0.5, 0.5, 0.0));
        let result = estimate_normals(&cloud, 3).unwrap();
        assert_eq!(result.len(), 5);
        for point in result.iter() {
            let normal = point.normal;
            assert!(normal.z.abs() > 0.8, "Normal should be primarily in Z direction: {:?}", normal);
        }
    }

    #[test]
    fn test_estimate_normals_empty() {
        let cloud = PointCloud::<Point3f>::new();
        let result = estimate_normals(&cloud, 5).unwrap();
        assert!(result.is_empty());
    }

    #[test]
    fn test_estimate_normals_insufficient_k() {
        let mut cloud = PointCloud::new();
        cloud.push(Point3f::new(0.0, 0.0, 0.0));
        let result = estimate_normals(&cloud, 2);
        assert!(result.is_err());
    }

    #[test]
    fn test_estimate_normals_radius() {
        let mut cloud = PointCloud::new();
        for i in 0..20 {
            for j in 0..20 {
                let x = (i as f32) * 0.1;
                let y = (j as f32) * 0.1;
                let z = 0.0;
                cloud.push(Point3f::new(x, y, z));
            }
        }
        let result = estimate_normals_radius(&cloud, 0.2, true).unwrap();
        assert_eq!(result.len(), 400);
        let mut z_direction_count = 0;
        for point in result.iter() {
            let normal_magnitude = point.normal.magnitude();
            assert!((normal_magnitude - 1.0).abs() < 0.1, "Normal should be unit vector: magnitude={}", normal_magnitude);
            if point.normal.z.abs() > 0.8 {
                z_direction_count += 1;
            }
        }
        let percentage = (z_direction_count as f32 / result.len() as f32) * 100.0;
        assert!(percentage > 80.0, "Only {:.1}% of normals are in Z direction", percentage);
    }

    #[test]
    fn test_estimate_normals_cylinder() {
        let mut cloud = PointCloud::new();
        for i in 0..10 {
            for j in 0..10 {
                let angle = (i as f32) * 0.6;
                let height = (j as f32) * 0.2 - 1.0;
                let x = angle.cos();
                let y = angle.sin();
                let z = height;
                cloud.push(Point3f::new(x, y, z));
            }
        }
        let config = NormalEstimationConfig {
            k_neighbors: 8,
            radius: None,
            consistent_orientation: true,
            viewpoint: Some(Point3f::new(0.0, 0.0, 2.0)),
        };
        let result = estimate_normals_with_config(&cloud, &config).unwrap();
        assert_eq!(result.len(), 100);
        let mut perpendicular_count = 0;
        let mut inward_count = 0;
        for point in result.iter() {
            let normal_magnitude = point.normal.magnitude();
            assert!((normal_magnitude - 1.0).abs() < 0.1, "Normal should be unit vector: magnitude={}", normal_magnitude);
            let z_alignment = point.normal.z.abs();
            if z_alignment < 0.8 {
                perpendicular_count += 1;
            }
            // The viewpoint sits on the cylinder axis, so orientation flips
            // radial normals to point toward the axis (inward).
            let toward_axis = Vector3f::new(-point.position.x, -point.position.y, 0.0).normalize();
            if point.normal.dot(&toward_axis) > 0.5 {
                inward_count += 1;
            }
        }
        let percentage_perpendicular = (perpendicular_count as f32 / result.len() as f32) * 100.0;
        let percentage_inward = (inward_count as f32 / result.len() as f32) * 100.0;
        println!("Cylinder test: {:.1}% perpendicular to Z, {:.1}% pointing toward the axis",
                 percentage_perpendicular, percentage_inward);
        assert!(percentage_perpendicular > 60.0, "Only {:.1}% of normals are perpendicular to Z-axis", percentage_perpendicular);
    }

    #[test]
    fn test_estimate_normals_orientation_consistency() {
        let mut cloud = PointCloud::new();
        cloud.push(Point3f::new(0.0, 0.0, 0.0));
        cloud.push(Point3f::new(1.0, 0.0, 0.0));
        cloud.push(Point3f::new(0.0, 1.0, 0.0));
        cloud.push(Point3f::new(1.0, 1.0, 0.0));
        let config_consistent = NormalEstimationConfig {
            k_neighbors: 3,
            radius: None,
            consistent_orientation: true,
            viewpoint: Some(Point3f::new(0.0, 0.0, 1.0)),
        };
        let result_consistent = estimate_normals_with_config(&cloud, &config_consistent).unwrap();
        let config_inconsistent = NormalEstimationConfig {
            k_neighbors: 3,
            radius: None,
            consistent_orientation: false,
            viewpoint: None,
        };
        let result_inconsistent = estimate_normals_with_config(&cloud, &config_inconsistent).unwrap();

        // Without orientation the sign is arbitrary, but every normal must still
        // be aligned with Z: each neighbourhood is three corners of the square.
        assert_eq!(result_inconsistent.len(), 4);
        for point in result_inconsistent.iter() {
            assert!(point.normal.z.abs() > 0.99, "Unoriented normal should still be along Z: {:?}", point.normal);
        }

        let first_normal_consistent = result_consistent.points[0].normal.z;
        for point in result_consistent.iter() {
            assert!((point.normal.z * first_normal_consistent) > 0.0,
                    "Normals should have consistent orientation");
        }
    }

    #[test]
    fn test_find_neighbors() {
        let points = vec![
            Point3f::new(0.0, 0.0, 0.0),
            Point3f::new(1.0, 0.0, 0.0),
            Point3f::new(0.0, 1.0, 0.0),
            Point3f::new(2.0, 0.0, 0.0),
        ];
        let config_knn = NormalEstimationConfig {
            k_neighbors: 2,
            radius: None,
            consistent_orientation: false,
            viewpoint: None,
        };
        let mut neighbors_knn = find_neighbors(&points, 0, &config_knn);
        neighbors_knn.sort_unstable();
        assert_eq!(neighbors_knn, vec![1, 2]);

        let config_radius = NormalEstimationConfig {
            k_neighbors: 10,
            radius: Some(1.5),
            consistent_orientation: false,
            viewpoint: None,
        };
        let mut neighbors_radius = find_neighbors(&points, 0, &config_radius);
        neighbors_radius.sort_unstable();
        assert_eq!(neighbors_radius, vec![1, 2]);
    }
}