mesh_repair/
registration.rs

1//! Mesh registration and alignment algorithms.
2//!
3//! This module provides tools for aligning meshes to each other, including:
4//! - Iterative Closest Point (ICP) for rigid alignment
5//! - Feature-based registration using landmarks
6//! - Non-rigid/deformable registration
7//!
8//! # Use Cases
9//!
10//! - Aligning a foot scan to a template last
11//! - Registering multiple partial scans of the same object
12//! - Matching a head scan to a helmet template
13//!
14//! # Example
15//!
16//! ```
17//! use mesh_repair::{Mesh, Vertex};
18//! use mesh_repair::registration::{RegistrationParams, align_meshes};
19//! use nalgebra::Point3;
20//!
21//! // Create source and target meshes
22//! let mut source = Mesh::new();
23//! source.vertices.push(Vertex::from_coords(0.0, 0.0, 0.0));
24//! source.vertices.push(Vertex::from_coords(1.0, 0.0, 0.0));
25//! source.vertices.push(Vertex::from_coords(0.5, 1.0, 0.0));
26//! source.faces.push([0, 1, 2]);
27//!
28//! let mut target = Mesh::new();
29//! target.vertices.push(Vertex::from_coords(1.0, 1.0, 0.0));
30//! target.vertices.push(Vertex::from_coords(2.0, 1.0, 0.0));
31//! target.vertices.push(Vertex::from_coords(1.5, 2.0, 0.0));
32//! target.faces.push([0, 1, 2]);
33//!
34//! // Align source to target using ICP
35//! let params = RegistrationParams::icp();
36//! let result = align_meshes(&source, &target, &params).unwrap();
37//!
38//! println!("RMS error: {:.3} mm", result.rms_error);
39//! println!("Converged: {}", result.converged);
40//! ```
41
42use crate::{Mesh, MeshError, MeshResult};
43use nalgebra::{Matrix3, Matrix4, Point3, Rotation3, UnitQuaternion, Vector3};
44
45/// Parameters for mesh registration.
46#[derive(Debug, Clone)]
47pub struct RegistrationParams {
48    /// The registration algorithm to use.
49    pub algorithm: RegistrationAlgorithm,
50
51    /// Maximum number of iterations for iterative algorithms.
52    pub max_iterations: usize,
53
54    /// Convergence threshold for RMS error change.
55    pub convergence_threshold: f64,
56
57    /// Maximum correspondence distance (points further apart are ignored).
58    pub max_correspondence_distance: f64,
59
60    /// Optional landmark correspondences for feature-based registration.
61    pub landmarks: Vec<Landmark>,
62
63    /// Whether to allow scaling in addition to rigid transformation.
64    pub allow_scaling: bool,
65
66    /// Subsample ratio for large meshes (0.0-1.0).
67    /// 1.0 uses all points, 0.1 uses 10% of points.
68    pub subsample_ratio: f64,
69}
70
71impl Default for RegistrationParams {
72    fn default() -> Self {
73        Self {
74            algorithm: RegistrationAlgorithm::Icp,
75            max_iterations: 100,
76            convergence_threshold: 1e-6,
77            max_correspondence_distance: f64::INFINITY,
78            landmarks: Vec::new(),
79            allow_scaling: false,
80            subsample_ratio: 1.0,
81        }
82    }
83}
84
85impl RegistrationParams {
86    /// Create params for ICP registration.
87    pub fn icp() -> Self {
88        Self::default()
89    }
90
91    /// Create params for point-to-plane ICP (faster convergence).
92    pub fn icp_point_to_plane() -> Self {
93        Self {
94            algorithm: RegistrationAlgorithm::IcpPointToPlane,
95            ..Default::default()
96        }
97    }
98
99    /// Create params for landmark-based registration.
100    pub fn landmark_based(landmarks: Vec<Landmark>) -> Self {
101        Self {
102            algorithm: RegistrationAlgorithm::Landmark,
103            landmarks,
104            ..Default::default()
105        }
106    }
107
108    /// Create params for combined landmark + ICP registration.
109    pub fn landmark_then_icp(landmarks: Vec<Landmark>) -> Self {
110        Self {
111            algorithm: RegistrationAlgorithm::LandmarkThenIcp,
112            landmarks,
113            ..Default::default()
114        }
115    }
116
117    /// Set maximum iterations.
118    pub fn with_max_iterations(mut self, max_iterations: usize) -> Self {
119        self.max_iterations = max_iterations;
120        self
121    }
122
123    /// Set convergence threshold.
124    pub fn with_convergence_threshold(mut self, threshold: f64) -> Self {
125        self.convergence_threshold = threshold;
126        self
127    }
128
129    /// Set maximum correspondence distance.
130    pub fn with_max_correspondence_distance(mut self, distance: f64) -> Self {
131        self.max_correspondence_distance = distance;
132        self
133    }
134
135    /// Allow scaling in addition to rigid transformation.
136    pub fn with_scaling(mut self) -> Self {
137        self.allow_scaling = true;
138        self
139    }
140
141    /// Set subsample ratio for large meshes.
142    pub fn with_subsample_ratio(mut self, ratio: f64) -> Self {
143        self.subsample_ratio = ratio.clamp(0.01, 1.0);
144        self
145    }
146}
147
/// The registration algorithm to use.
///
/// The enum carries no data, so it is `Copy` and implements full equality
/// and hashing; it can be used directly as a map or set key.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum RegistrationAlgorithm {
    /// Standard point-to-point ICP.
    Icp,
    /// Point-to-plane ICP (requires normals).
    IcpPointToPlane,
    /// Landmark-based (requires corresponding points).
    Landmark,
    /// Landmark alignment followed by ICP refinement.
    LandmarkThenIcp,
    /// Non-rigid/deformable registration (allows local deformations).
    NonRigid,
}
162
163/// A landmark correspondence for feature-based registration.
164#[derive(Debug, Clone)]
165pub struct Landmark {
166    /// Position on the source mesh.
167    pub source: Point3<f64>,
168    /// Corresponding position on the target mesh.
169    pub target: Point3<f64>,
170    /// Optional weight for this landmark.
171    pub weight: f64,
172}
173
174impl Landmark {
175    /// Create a new landmark correspondence.
176    pub fn new(source: Point3<f64>, target: Point3<f64>) -> Self {
177        Self {
178            source,
179            target,
180            weight: 1.0,
181        }
182    }
183
184    /// Create a weighted landmark correspondence.
185    pub fn weighted(source: Point3<f64>, target: Point3<f64>, weight: f64) -> Self {
186        Self {
187            source,
188            target,
189            weight,
190        }
191    }
192}
193
194/// Result of a registration operation.
195#[derive(Debug, Clone)]
196pub struct RegistrationResult {
197    /// The aligned mesh (source transformed to match target).
198    pub mesh: Mesh,
199
200    /// The transformation that was applied.
201    pub transformation: RigidTransform,
202
203    /// RMS (root mean square) error of the alignment.
204    pub rms_error: f64,
205
206    /// Maximum error (maximum distance between corresponding points).
207    pub max_error: f64,
208
209    /// Number of iterations performed.
210    pub iterations: usize,
211
212    /// Whether the algorithm converged.
213    pub converged: bool,
214
215    /// Number of valid correspondences used.
216    pub correspondences_used: usize,
217}
218
219impl RegistrationResult {
220    /// Check if the registration quality is acceptable.
221    ///
222    /// Returns true if converged and RMS error is below threshold.
223    pub fn is_acceptable(&self, max_rms_error: f64) -> bool {
224        self.converged && self.rms_error <= max_rms_error
225    }
226}
227
228/// Parameters specific to non-rigid registration.
229#[derive(Debug, Clone)]
230pub struct NonRigidParams {
231    /// Stiffness weight for regularization (higher = more rigid, less local deformation).
232    /// Range: 0.0 (fully flexible) to 100.0 (nearly rigid).
233    /// Default: 10.0
234    pub stiffness: f64,
235
236    /// Number of control points for the deformation field.
237    /// More control points allow finer deformations but increase computation.
238    /// If None, uses a subset of source vertices (default: ~500 points).
239    pub num_control_points: Option<usize>,
240
241    /// Landmark correspondences for guided deformation.
242    /// These serve as hard constraints during non-rigid registration.
243    pub landmarks: Vec<Landmark>,
244
245    /// Whether to apply an initial rigid alignment before non-rigid deformation.
246    /// Default: true
247    pub initial_rigid_alignment: bool,
248
249    /// Number of outer iterations for the non-rigid optimization.
250    /// Default: 10
251    pub outer_iterations: usize,
252
253    /// Smoothness parameter for RBF interpolation.
254    /// Higher values produce smoother deformations.
255    /// Default: 1.0
256    pub smoothness: f64,
257}
258
259impl Default for NonRigidParams {
260    fn default() -> Self {
261        Self {
262            stiffness: 10.0,
263            num_control_points: None,
264            landmarks: Vec::new(),
265            initial_rigid_alignment: true,
266            outer_iterations: 10,
267            smoothness: 1.0,
268        }
269    }
270}
271
272impl NonRigidParams {
273    /// Create default non-rigid registration params.
274    pub fn new() -> Self {
275        Self::default()
276    }
277
278    /// Set stiffness (regularization weight).
279    pub fn with_stiffness(mut self, stiffness: f64) -> Self {
280        self.stiffness = stiffness.max(0.0);
281        self
282    }
283
284    /// Set number of control points.
285    pub fn with_control_points(mut self, num_points: usize) -> Self {
286        self.num_control_points = Some(num_points.max(10));
287        self
288    }
289
290    /// Add landmark constraints.
291    pub fn with_landmarks(mut self, landmarks: Vec<Landmark>) -> Self {
292        self.landmarks = landmarks;
293        self
294    }
295
296    /// Disable initial rigid alignment.
297    pub fn without_initial_alignment(mut self) -> Self {
298        self.initial_rigid_alignment = false;
299        self
300    }
301
302    /// Set number of outer iterations.
303    pub fn with_outer_iterations(mut self, iterations: usize) -> Self {
304        self.outer_iterations = iterations.max(1);
305        self
306    }
307
308    /// Set smoothness parameter.
309    pub fn with_smoothness(mut self, smoothness: f64) -> Self {
310        self.smoothness = smoothness.max(0.01);
311        self
312    }
313}
314
315/// Result of a non-rigid registration operation.
316#[derive(Debug, Clone)]
317pub struct NonRigidRegistrationResult {
318    /// The aligned and deformed mesh.
319    pub mesh: Mesh,
320
321    /// Per-vertex displacements applied.
322    pub displacements: Vec<Vector3<f64>>,
323
324    /// Initial rigid transformation (if applied).
325    pub initial_transform: Option<RigidTransform>,
326
327    /// RMS error of the final alignment.
328    pub rms_error: f64,
329
330    /// Maximum error.
331    pub max_error: f64,
332
333    /// Number of iterations performed.
334    pub iterations: usize,
335
336    /// Whether the algorithm converged.
337    pub converged: bool,
338
339    /// Number of correspondences used.
340    pub correspondences_used: usize,
341
342    /// Average displacement magnitude.
343    pub average_displacement: f64,
344
345    /// Maximum displacement magnitude.
346    pub max_displacement: f64,
347}
348
349impl NonRigidRegistrationResult {
350    /// Check if the registration quality is acceptable.
351    pub fn is_acceptable(&self, max_rms_error: f64) -> bool {
352        self.converged && self.rms_error <= max_rms_error
353    }
354
355    /// Get the deformation field as a list of (original_position, displacement) pairs.
356    pub fn deformation_field(&self, original_mesh: &Mesh) -> Vec<(Point3<f64>, Vector3<f64>)> {
357        original_mesh
358            .vertices
359            .iter()
360            .zip(self.displacements.iter())
361            .map(|(v, d)| (v.position, *d))
362            .collect()
363    }
364}
365
366/// A rigid transformation (rotation + translation, optionally with scale).
367#[derive(Debug, Clone)]
368pub struct RigidTransform {
369    /// Rotation quaternion.
370    pub rotation: UnitQuaternion<f64>,
371    /// Translation vector.
372    pub translation: Vector3<f64>,
373    /// Uniform scale factor (1.0 = no scaling).
374    pub scale: f64,
375}
376
377impl Default for RigidTransform {
378    fn default() -> Self {
379        Self::identity()
380    }
381}
382
383impl RigidTransform {
384    /// Create an identity transformation.
385    pub fn identity() -> Self {
386        Self {
387            rotation: UnitQuaternion::identity(),
388            translation: Vector3::zeros(),
389            scale: 1.0,
390        }
391    }
392
393    /// Create a pure translation.
394    pub fn from_translation(translation: Vector3<f64>) -> Self {
395        Self {
396            rotation: UnitQuaternion::identity(),
397            translation,
398            scale: 1.0,
399        }
400    }
401
402    /// Create a pure rotation.
403    pub fn from_rotation(rotation: UnitQuaternion<f64>) -> Self {
404        Self {
405            rotation,
406            translation: Vector3::zeros(),
407            scale: 1.0,
408        }
409    }
410
411    /// Create a transformation from rotation and translation.
412    pub fn from_rotation_translation(
413        rotation: UnitQuaternion<f64>,
414        translation: Vector3<f64>,
415    ) -> Self {
416        Self {
417            rotation,
418            translation,
419            scale: 1.0,
420        }
421    }
422
423    /// Apply the transformation to a point.
424    pub fn transform_point(&self, point: &Point3<f64>) -> Point3<f64> {
425        let scaled = point.coords * self.scale;
426        let rotated = self.rotation * Point3::from(scaled);
427        Point3::from(rotated.coords + self.translation)
428    }
429
430    /// Apply the transformation to a vector (no translation).
431    pub fn transform_vector(&self, vector: &Vector3<f64>) -> Vector3<f64> {
432        let scaled = vector * self.scale;
433        self.rotation * scaled
434    }
435
436    /// Compose with another transformation (self applied first, then other).
437    pub fn then(&self, other: &RigidTransform) -> RigidTransform {
438        RigidTransform {
439            rotation: other.rotation * self.rotation,
440            translation: other.rotation * (self.translation * other.scale) + other.translation,
441            scale: self.scale * other.scale,
442        }
443    }
444
445    /// Get the inverse transformation.
446    pub fn inverse(&self) -> RigidTransform {
447        let inv_rotation = self.rotation.inverse();
448        let inv_scale = 1.0 / self.scale;
449        let inv_translation = inv_rotation * (-self.translation * inv_scale);
450        RigidTransform {
451            rotation: inv_rotation,
452            translation: inv_translation,
453            scale: inv_scale,
454        }
455    }
456
457    /// Convert to a 4x4 homogeneous transformation matrix.
458    pub fn to_matrix4(&self) -> Matrix4<f64> {
459        let rotation_matrix = self.rotation.to_rotation_matrix();
460        let mut result = Matrix4::identity();
461
462        for i in 0..3 {
463            for j in 0..3 {
464                result[(i, j)] = rotation_matrix[(i, j)] * self.scale;
465            }
466            result[(i, 3)] = self.translation[i];
467        }
468
469        result
470    }
471}
472
473/// Align a source mesh to a target mesh.
474///
475/// # Arguments
476///
477/// * `source` - The mesh to transform
478/// * `target` - The reference mesh to align to
479/// * `params` - Registration parameters
480///
481/// # Returns
482///
483/// A `RegistrationResult` containing the aligned mesh and transformation.
484pub fn align_meshes(
485    source: &Mesh,
486    target: &Mesh,
487    params: &RegistrationParams,
488) -> MeshResult<RegistrationResult> {
489    if source.is_empty() || target.is_empty() {
490        return Err(MeshError::EmptyMesh {
491            details: "Cannot align empty meshes".to_string(),
492        });
493    }
494
495    match params.algorithm {
496        RegistrationAlgorithm::Icp => icp_align(source, target, params, false),
497        RegistrationAlgorithm::IcpPointToPlane => icp_align(source, target, params, true),
498        RegistrationAlgorithm::Landmark => landmark_align(source, target, params),
499        RegistrationAlgorithm::LandmarkThenIcp => {
500            // First do landmark alignment
501            let landmark_result = landmark_align(source, target, params)?;
502
503            // Then refine with ICP
504            let mut icp_params = params.clone();
505            icp_params.algorithm = RegistrationAlgorithm::Icp;
506            let icp_result = icp_align(&landmark_result.mesh, target, &icp_params, false)?;
507
508            // Compose transformations
509            let total_transform = landmark_result
510                .transformation
511                .then(&icp_result.transformation);
512            let total_iterations = landmark_result.iterations + icp_result.iterations;
513
514            Ok(RegistrationResult {
515                mesh: icp_result.mesh,
516                transformation: total_transform,
517                rms_error: icp_result.rms_error,
518                max_error: icp_result.max_error,
519                iterations: total_iterations,
520                converged: icp_result.converged,
521                correspondences_used: icp_result.correspondences_used,
522            })
523        }
524        RegistrationAlgorithm::NonRigid => {
525            // Non-rigid registration should use the dedicated function
526            // Here we return a minimal result; use non_rigid_align for full results
527            let nr_result = non_rigid_align(source, target, &NonRigidParams::default(), params)?;
528            Ok(RegistrationResult {
529                mesh: nr_result.mesh,
530                transformation: nr_result
531                    .initial_transform
532                    .unwrap_or_else(RigidTransform::identity),
533                rms_error: nr_result.rms_error,
534                max_error: nr_result.max_error,
535                iterations: nr_result.iterations,
536                converged: nr_result.converged,
537                correspondences_used: nr_result.correspondences_used,
538            })
539        }
540    }
541}
542
543/// ICP registration implementation.
544fn icp_align(
545    source: &Mesh,
546    target: &Mesh,
547    params: &RegistrationParams,
548    _point_to_plane: bool,
549) -> MeshResult<RegistrationResult> {
550    // Build target point cloud for nearest neighbor queries
551    let target_points: Vec<Point3<f64>> = target.vertices.iter().map(|v| v.position).collect();
552
553    // Subsample source points if needed
554    let source_indices: Vec<usize> = if params.subsample_ratio < 1.0 {
555        let step = (1.0 / params.subsample_ratio).ceil() as usize;
556        (0..source.vertex_count()).step_by(step).collect()
557    } else {
558        (0..source.vertex_count()).collect()
559    };
560
561    let mut current_transform = RigidTransform::identity();
562    let mut previous_rms = f64::INFINITY;
563    let mut converged = false;
564    let mut iterations = 0;
565
566    // Transform source points
567    let mut transformed_source: Vec<Point3<f64>> = source_indices
568        .iter()
569        .map(|&i| source.vertices[i].position)
570        .collect();
571
572    for iter in 0..params.max_iterations {
573        iterations = iter + 1;
574
575        // Find correspondences (nearest neighbors)
576        let mut correspondences: Vec<(Point3<f64>, Point3<f64>)> = Vec::new();
577        let mut total_error_sq = 0.0;
578        let mut max_error = 0.0f64;
579
580        for source_point in &transformed_source {
581            // Find nearest point in target (brute force for now - could use KD-tree)
582            let (nearest, dist_sq) = find_nearest_point(source_point, &target_points);
583
584            let dist = dist_sq.sqrt();
585            if dist <= params.max_correspondence_distance {
586                correspondences.push((*source_point, nearest));
587                total_error_sq += dist_sq;
588                max_error = max_error.max(dist);
589            }
590        }
591
592        if correspondences.is_empty() {
593            return Err(MeshError::RepairFailed {
594                details: "No valid correspondences found".to_string(),
595            });
596        }
597
598        let rms_error = (total_error_sq / correspondences.len() as f64).sqrt();
599
600        // Check convergence
601        if (previous_rms - rms_error).abs() < params.convergence_threshold {
602            converged = true;
603            break;
604        }
605        previous_rms = rms_error;
606
607        // Compute optimal transformation for this iteration
608        let (source_pts, target_pts): (Vec<_>, Vec<_>) = correspondences.into_iter().unzip();
609
610        let iter_transform =
611            compute_rigid_transform(&source_pts, &target_pts, params.allow_scaling);
612
613        // Update cumulative transform and transformed points
614        current_transform = current_transform.then(&iter_transform);
615
616        for point in &mut transformed_source {
617            *point = iter_transform.transform_point(point);
618        }
619    }
620
621    // Apply final transformation to create result mesh
622    let mut result_mesh = source.clone();
623    for vertex in &mut result_mesh.vertices {
624        vertex.position = current_transform.transform_point(&vertex.position);
625        if let Some(ref mut normal) = vertex.normal {
626            *normal = current_transform.transform_vector(normal).normalize();
627        }
628    }
629
630    // Calculate final error metrics
631    let (rms_error, max_error, correspondences_used) = calculate_alignment_error(
632        &result_mesh,
633        &target_points,
634        params.max_correspondence_distance,
635    );
636
637    Ok(RegistrationResult {
638        mesh: result_mesh,
639        transformation: current_transform,
640        rms_error,
641        max_error,
642        iterations,
643        converged,
644        correspondences_used,
645    })
646}
647
648/// Landmark-based registration implementation.
649fn landmark_align(
650    source: &Mesh,
651    _target: &Mesh,
652    params: &RegistrationParams,
653) -> MeshResult<RegistrationResult> {
654    if params.landmarks.is_empty() {
655        return Err(MeshError::RepairFailed {
656            details: "No landmarks provided for landmark-based registration".to_string(),
657        });
658    }
659
660    if params.landmarks.len() < 3 {
661        return Err(MeshError::RepairFailed {
662            details: "At least 3 landmarks required for rigid registration".to_string(),
663        });
664    }
665
666    // Extract landmark positions
667    let source_points: Vec<Point3<f64>> = params.landmarks.iter().map(|l| l.source).collect();
668    let target_points: Vec<Point3<f64>> = params.landmarks.iter().map(|l| l.target).collect();
669
670    // Compute transformation
671    let transform = compute_rigid_transform(&source_points, &target_points, params.allow_scaling);
672
673    // Apply transformation to mesh
674    let mut result_mesh = source.clone();
675    for vertex in &mut result_mesh.vertices {
676        vertex.position = transform.transform_point(&vertex.position);
677        if let Some(ref mut normal) = vertex.normal {
678            *normal = transform.transform_vector(normal).normalize();
679        }
680    }
681
682    // Calculate landmark error
683    let mut total_error_sq = 0.0;
684    let mut max_error = 0.0f64;
685
686    for landmark in &params.landmarks {
687        let transformed = transform.transform_point(&landmark.source);
688        let error = (transformed - landmark.target).norm();
689        total_error_sq += error * error * landmark.weight;
690        max_error = max_error.max(error);
691    }
692
693    let rms_error = (total_error_sq / params.landmarks.len() as f64).sqrt();
694
695    Ok(RegistrationResult {
696        mesh: result_mesh,
697        transformation: transform,
698        rms_error,
699        max_error,
700        iterations: 1,
701        converged: true,
702        correspondences_used: params.landmarks.len(),
703    })
704}
705
706/// Find the nearest point in a list to a query point.
707fn find_nearest_point(query: &Point3<f64>, points: &[Point3<f64>]) -> (Point3<f64>, f64) {
708    let mut nearest = points[0];
709    let mut min_dist_sq = (query - nearest).norm_squared();
710
711    for point in points.iter().skip(1) {
712        let dist_sq = (query - point).norm_squared();
713        if dist_sq < min_dist_sq {
714            min_dist_sq = dist_sq;
715            nearest = *point;
716        }
717    }
718
719    (nearest, min_dist_sq)
720}
721
722/// Compute the optimal rigid transformation between point sets.
723///
724/// Uses the Kabsch algorithm (SVD-based) to find the rotation and translation
725/// that minimizes the RMSD between corresponding points.
726fn compute_rigid_transform(
727    source: &[Point3<f64>],
728    target: &[Point3<f64>],
729    allow_scaling: bool,
730) -> RigidTransform {
731    let n = source.len();
732    if n == 0 {
733        return RigidTransform::identity();
734    }
735
736    // Compute centroids
737    let source_centroid: Vector3<f64> =
738        source.iter().map(|p| p.coords).sum::<Vector3<f64>>() / n as f64;
739    let target_centroid: Vector3<f64> =
740        target.iter().map(|p| p.coords).sum::<Vector3<f64>>() / n as f64;
741
742    // Center the points
743    let centered_source: Vec<Vector3<f64>> =
744        source.iter().map(|p| p.coords - source_centroid).collect();
745    let centered_target: Vec<Vector3<f64>> =
746        target.iter().map(|p| p.coords - target_centroid).collect();
747
748    // Compute cross-covariance matrix H
749    let mut h = Matrix3::zeros();
750    for i in 0..n {
751        h += centered_source[i] * centered_target[i].transpose();
752    }
753
754    // SVD decomposition
755    let svd = h.svd(true, true);
756    let u = svd.u.unwrap();
757    let v_t = svd.v_t.unwrap();
758
759    // Compute rotation
760    let mut rotation_matrix = v_t.transpose() * u.transpose();
761
762    // Handle reflection case (det < 0)
763    if rotation_matrix.determinant() < 0.0 {
764        let mut v_t_fixed = v_t;
765        v_t_fixed.set_row(2, &(-v_t.row(2)));
766        rotation_matrix = v_t_fixed.transpose() * u.transpose();
767    }
768
769    let rotation =
770        UnitQuaternion::from_rotation_matrix(&Rotation3::from_matrix_unchecked(rotation_matrix));
771
772    // Compute scale if allowed
773    let scale = if allow_scaling {
774        let source_variance: f64 = centered_source.iter().map(|v| v.norm_squared()).sum();
775        let target_variance: f64 = centered_target.iter().map(|v| v.norm_squared()).sum();
776
777        if source_variance > 1e-10 {
778            (target_variance / source_variance).sqrt()
779        } else {
780            1.0
781        }
782    } else {
783        1.0
784    };
785
786    // Compute translation
787    let translation = target_centroid - scale * (rotation * source_centroid);
788
789    RigidTransform {
790        rotation,
791        translation,
792        scale,
793    }
794}
795
796/// Calculate alignment error between a mesh and target points.
797fn calculate_alignment_error(
798    mesh: &Mesh,
799    target_points: &[Point3<f64>],
800    max_distance: f64,
801) -> (f64, f64, usize) {
802    let mut total_error_sq = 0.0;
803    let mut max_error = 0.0f64;
804    let mut count = 0;
805
806    for vertex in &mesh.vertices {
807        let (_, dist_sq) = find_nearest_point(&vertex.position, target_points);
808        let dist = dist_sq.sqrt();
809
810        if dist <= max_distance {
811            total_error_sq += dist_sq;
812            max_error = max_error.max(dist);
813            count += 1;
814        }
815    }
816
817    let rms_error = if count > 0 {
818        (total_error_sq / count as f64).sqrt()
819    } else {
820        f64::INFINITY
821    };
822
823    (rms_error, max_error, count)
824}
825
826/// Perform non-rigid/deformable registration.
827///
828/// This algorithm allows local deformations while maintaining global smoothness,
829/// making it suitable for registering meshes with local shape differences
830/// (e.g., a foot scan to a template last with different proportions).
831///
832/// # Algorithm
833///
834/// 1. Optional initial rigid alignment using ICP
835/// 2. Select control points from source mesh
836/// 3. Iteratively:
837///    a. Find correspondences (nearest neighbors)
838///    b. Compute optimal displacements for control points
839///    c. Interpolate displacements to all vertices using RBF
840///    d. Apply regularization for smoothness
841///
842/// # Arguments
843///
844/// * `source` - The mesh to deform
845/// * `target` - The reference mesh to match
846/// * `nr_params` - Non-rigid registration parameters
847/// * `base_params` - Base registration parameters (max iterations, convergence, etc.)
848///
849/// # Returns
850///
851/// A `NonRigidRegistrationResult` containing the deformed mesh and displacement field.
852///
853/// # Example
854///
855/// ```
856/// use mesh_repair::{Mesh, Vertex};
857/// use mesh_repair::registration::{RegistrationParams, NonRigidParams, non_rigid_align};
858///
859/// let mut source = Mesh::new();
860/// source.vertices.push(Vertex::from_coords(0.0, 0.0, 0.0));
861/// source.vertices.push(Vertex::from_coords(10.0, 0.0, 0.0));
862/// source.vertices.push(Vertex::from_coords(5.0, 10.0, 0.0));
863/// source.faces.push([0, 1, 2]);
864///
865/// let mut target = Mesh::new();
866/// target.vertices.push(Vertex::from_coords(0.0, 0.0, 1.0));
867/// target.vertices.push(Vertex::from_coords(10.0, 0.0, 1.0));
868/// target.vertices.push(Vertex::from_coords(5.0, 12.0, 1.0)); // Slightly stretched
869/// target.faces.push([0, 1, 2]);
870///
871/// let nr_params = NonRigidParams::new().with_stiffness(5.0);
872/// let base_params = RegistrationParams::default();
873/// let result = non_rigid_align(&source, &target, &nr_params, &base_params).unwrap();
874///
875/// println!("RMS error: {:.3} mm", result.rms_error);
876/// println!("Max displacement: {:.3} mm", result.max_displacement);
877/// ```
878pub fn non_rigid_align(
879    source: &Mesh,
880    target: &Mesh,
881    nr_params: &NonRigidParams,
882    base_params: &RegistrationParams,
883) -> MeshResult<NonRigidRegistrationResult> {
884    if source.is_empty() || target.is_empty() {
885        return Err(MeshError::EmptyMesh {
886            details: "Cannot align empty meshes".to_string(),
887        });
888    }
889
890    let target_points: Vec<Point3<f64>> = target.vertices.iter().map(|v| v.position).collect();
891
892    // Step 1: Optional initial rigid alignment
893    let (working_mesh, initial_transform) = if nr_params.initial_rigid_alignment {
894        let rigid_params = RegistrationParams::icp()
895            .with_max_iterations(base_params.max_iterations / 2)
896            .with_convergence_threshold(base_params.convergence_threshold * 10.0);
897        let rigid_result = icp_align(source, target, &rigid_params, false)?;
898        (rigid_result.mesh, Some(rigid_result.transformation))
899    } else {
900        (source.clone(), None)
901    };
902
903    // Step 2: Select control points
904    let num_control_points = nr_params
905        .num_control_points
906        .unwrap_or_else(|| working_mesh.vertex_count().clamp(10, 500));
907
908    let control_indices = select_control_points(&working_mesh, num_control_points);
909    let num_controls = control_indices.len();
910
911    // Initialize control point positions and displacements
912    let mut control_positions: Vec<Point3<f64>> = control_indices
913        .iter()
914        .map(|&i| working_mesh.vertices[i].position)
915        .collect();
916
917    let mut control_displacements: Vec<Vector3<f64>> = vec![Vector3::zeros(); num_controls];
918
919    // Add landmark constraints as additional control points
920    let landmark_constraints: Vec<(Point3<f64>, Vector3<f64>)> = nr_params
921        .landmarks
922        .iter()
923        .map(|l| {
924            let displacement = l.target - l.source;
925            (l.source, displacement)
926        })
927        .collect();
928
929    // Step 3: Iterative non-rigid optimization
930    let mut converged = false;
931    let mut iterations = 0;
932    let mut previous_rms = f64::INFINITY;
933
934    // Current vertex positions
935    let mut current_positions: Vec<Point3<f64>> =
936        working_mesh.vertices.iter().map(|v| v.position).collect();
937
938    for outer_iter in 0..nr_params.outer_iterations {
939        iterations = outer_iter + 1;
940
941        // Find correspondences for control points
942        let mut correspondence_displacements: Vec<Vector3<f64>> = Vec::with_capacity(num_controls);
943
944        for (i, &ctrl_idx) in control_indices.iter().enumerate() {
945            let current_pos = current_positions[ctrl_idx];
946            let (nearest, _) = find_nearest_point(&current_pos, &target_points);
947
948            // Desired displacement to reach target
949            let desired_displacement = nearest - current_pos;
950
951            // Blend with current displacement based on stiffness
952            // Higher stiffness = smaller updates, more regularization
953            let alpha = 1.0 / (1.0 + nr_params.stiffness * 0.1);
954            let new_displacement = control_displacements[i] + desired_displacement * alpha;
955
956            correspondence_displacements.push(new_displacement);
957        }
958
959        // Apply regularization (Laplacian smoothing of displacements)
960        let regularized_displacements = regularize_displacements(
961            &control_positions,
962            &correspondence_displacements,
963            nr_params.stiffness,
964        );
965
966        control_displacements = regularized_displacements;
967
968        // Interpolate displacements to all vertices using RBF
969        let all_displacements = interpolate_displacements_rbf(
970            &control_positions,
971            &control_displacements,
972            &landmark_constraints,
973            &working_mesh
974                .vertices
975                .iter()
976                .map(|v| v.position)
977                .collect::<Vec<_>>(),
978            nr_params.smoothness,
979        );
980
981        // Update current positions
982        for (i, displacement) in all_displacements.iter().enumerate() {
983            current_positions[i] = working_mesh.vertices[i].position + displacement;
984        }
985
986        // Update control positions for next iteration
987        for (i, &ctrl_idx) in control_indices.iter().enumerate() {
988            control_positions[i] = current_positions[ctrl_idx];
989        }
990
991        // Calculate error
992        let mut total_error_sq = 0.0;
993        let mut count = 0;
994
995        for pos in &current_positions {
996            let (_, dist_sq) = find_nearest_point(pos, &target_points);
997            if dist_sq.sqrt() <= base_params.max_correspondence_distance {
998                total_error_sq += dist_sq;
999                count += 1;
1000            }
1001        }
1002
1003        let rms_error = if count > 0 {
1004            (total_error_sq / count as f64).sqrt()
1005        } else {
1006            f64::INFINITY
1007        };
1008
1009        // Check convergence
1010        if (previous_rms - rms_error).abs() < base_params.convergence_threshold {
1011            converged = true;
1012            break;
1013        }
1014        previous_rms = rms_error;
1015    }
1016
1017    // Build final result mesh
1018    let mut result_mesh = working_mesh.clone();
1019    let final_displacements: Vec<Vector3<f64>> = current_positions
1020        .iter()
1021        .zip(working_mesh.vertices.iter())
1022        .map(|(current, original)| current - original.position)
1023        .collect();
1024
1025    for (i, vertex) in result_mesh.vertices.iter_mut().enumerate() {
1026        vertex.position = current_positions[i];
1027        if let Some(ref mut normal) = vertex.normal {
1028            // Re-estimate normals would be better, but for now keep them
1029            // In practice, you'd recompute normals after deformation
1030            *normal = normal.normalize();
1031        }
1032    }
1033
1034    // Calculate final metrics
1035    let (rms_error, max_error, correspondences_used) = calculate_alignment_error(
1036        &result_mesh,
1037        &target_points,
1038        base_params.max_correspondence_distance,
1039    );
1040
1041    let displacement_magnitudes: Vec<f64> = final_displacements.iter().map(|d| d.norm()).collect();
1042    let average_displacement = if !displacement_magnitudes.is_empty() {
1043        displacement_magnitudes.iter().sum::<f64>() / displacement_magnitudes.len() as f64
1044    } else {
1045        0.0
1046    };
1047    let max_displacement = displacement_magnitudes
1048        .iter()
1049        .cloned()
1050        .fold(0.0f64, f64::max);
1051
1052    Ok(NonRigidRegistrationResult {
1053        mesh: result_mesh,
1054        displacements: final_displacements,
1055        initial_transform,
1056        rms_error,
1057        max_error,
1058        iterations,
1059        converged,
1060        correspondences_used,
1061        average_displacement,
1062        max_displacement,
1063    })
1064}
1065
1066/// Select control points from a mesh using farthest point sampling.
1067fn select_control_points(mesh: &Mesh, num_points: usize) -> Vec<usize> {
1068    let n = mesh.vertex_count();
1069    if num_points >= n {
1070        return (0..n).collect();
1071    }
1072
1073    let mut selected = Vec::with_capacity(num_points);
1074    let mut min_distances = vec![f64::INFINITY; n];
1075
1076    // Start with first vertex
1077    selected.push(0);
1078
1079    while selected.len() < num_points {
1080        // Update minimum distances to selected set
1081        let last_selected = selected[selected.len() - 1];
1082        let last_pos = mesh.vertices[last_selected].position;
1083
1084        for (i, dist) in min_distances.iter_mut().enumerate() {
1085            let d = (mesh.vertices[i].position - last_pos).norm();
1086            *dist = dist.min(d);
1087        }
1088
1089        // Find point with maximum minimum distance (farthest point sampling)
1090        let next = min_distances
1091            .iter()
1092            .enumerate()
1093            .max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap())
1094            .map(|(i, _)| i)
1095            .unwrap();
1096
1097        selected.push(next);
1098    }
1099
1100    selected
1101}
1102
1103/// Regularize displacements using Laplacian smoothing.
1104fn regularize_displacements(
1105    positions: &[Point3<f64>],
1106    displacements: &[Vector3<f64>],
1107    stiffness: f64,
1108) -> Vec<Vector3<f64>> {
1109    let n = positions.len();
1110    if n <= 1 {
1111        return displacements.to_vec();
1112    }
1113
1114    // Build neighborhood based on spatial proximity
1115    // For each control point, find k nearest neighbors
1116    let k = (n / 4).clamp(3, 10);
1117
1118    let mut neighbors: Vec<Vec<usize>> = Vec::with_capacity(n);
1119    for i in 0..n {
1120        let mut distances: Vec<(usize, f64)> = (0..n)
1121            .filter(|&j| j != i)
1122            .map(|j| (j, (positions[i] - positions[j]).norm()))
1123            .collect();
1124        distances.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap());
1125        neighbors.push(distances.iter().take(k).map(|(idx, _)| *idx).collect());
1126    }
1127
1128    // Apply Laplacian smoothing
1129    let smoothing_weight = stiffness / (stiffness + 1.0);
1130    let mut smoothed = displacements.to_vec();
1131
1132    for i in 0..n {
1133        if neighbors[i].is_empty() {
1134            continue;
1135        }
1136
1137        // Compute average of neighbor displacements
1138        let neighbor_avg: Vector3<f64> = neighbors[i]
1139            .iter()
1140            .map(|&j| displacements[j])
1141            .sum::<Vector3<f64>>()
1142            / neighbors[i].len() as f64;
1143
1144        // Blend original with smoothed
1145        smoothed[i] = displacements[i] * (1.0 - smoothing_weight) + neighbor_avg * smoothing_weight;
1146    }
1147
1148    smoothed
1149}
1150
1151/// Interpolate displacements from control points to all vertices using RBF.
1152fn interpolate_displacements_rbf(
1153    control_positions: &[Point3<f64>],
1154    control_displacements: &[Vector3<f64>],
1155    landmark_constraints: &[(Point3<f64>, Vector3<f64>)],
1156    query_positions: &[Point3<f64>],
1157    smoothness: f64,
1158) -> Vec<Vector3<f64>> {
1159    // Combine control points and landmarks
1160    let mut all_positions: Vec<Point3<f64>> = control_positions.to_vec();
1161    let mut all_displacements: Vec<Vector3<f64>> = control_displacements.to_vec();
1162
1163    for (pos, disp) in landmark_constraints {
1164        all_positions.push(*pos);
1165        all_displacements.push(*disp);
1166    }
1167
1168    let n = all_positions.len();
1169    if n == 0 {
1170        return vec![Vector3::zeros(); query_positions.len()];
1171    }
1172
1173    // For small number of control points, use direct RBF interpolation
1174    // For larger sets, use a simplified approach
1175
1176    if n <= 100 {
1177        // Full RBF solve for each component (x, y, z)
1178        interpolate_rbf_full(
1179            &all_positions,
1180            &all_displacements,
1181            query_positions,
1182            smoothness,
1183        )
1184    } else {
1185        // Simplified: weighted average based on distance
1186        interpolate_rbf_simplified(
1187            &all_positions,
1188            &all_displacements,
1189            query_positions,
1190            smoothness,
1191        )
1192    }
1193}
1194
1195/// Full RBF interpolation (solves linear system).
1196fn interpolate_rbf_full(
1197    control_positions: &[Point3<f64>],
1198    control_displacements: &[Vector3<f64>],
1199    query_positions: &[Point3<f64>],
1200    smoothness: f64,
1201) -> Vec<Vector3<f64>> {
1202    use nalgebra::{DMatrix, DVector};
1203
1204    let n = control_positions.len();
1205    let epsilon = 1.0 / smoothness.max(0.01);
1206
1207    // Build RBF matrix
1208    let mut phi = DMatrix::zeros(n, n);
1209    for i in 0..n {
1210        for j in 0..n {
1211            let r = (control_positions[i] - control_positions[j]).norm();
1212            phi[(i, j)] = thin_plate_spline_rbf(r, epsilon);
1213        }
1214        // Add regularization to diagonal
1215        phi[(i, i)] += 1e-6;
1216    }
1217
1218    // Solve for weights for each component
1219    let mut result = vec![Vector3::zeros(); query_positions.len()];
1220
1221    // Try to solve; if it fails, fall back to simplified method
1222    let decomp = phi.clone().lu();
1223
1224    for component in 0..3 {
1225        let b: DVector<f64> =
1226            DVector::from_iterator(n, control_displacements.iter().map(|d| d[component]));
1227
1228        if let Some(weights) = decomp.solve(&b) {
1229            // Evaluate at query points
1230            for (q_idx, query) in query_positions.iter().enumerate() {
1231                let mut val = 0.0;
1232                for (c_idx, ctrl) in control_positions.iter().enumerate() {
1233                    let r = (query - ctrl).norm();
1234                    val += weights[c_idx] * thin_plate_spline_rbf(r, epsilon);
1235                }
1236                result[q_idx][component] = val;
1237            }
1238        } else {
1239            // Fall back to simplified interpolation
1240            return interpolate_rbf_simplified(
1241                control_positions,
1242                control_displacements,
1243                query_positions,
1244                smoothness,
1245            );
1246        }
1247    }
1248
1249    result
1250}
1251
1252/// Simplified RBF interpolation using inverse distance weighting.
1253fn interpolate_rbf_simplified(
1254    control_positions: &[Point3<f64>],
1255    control_displacements: &[Vector3<f64>],
1256    query_positions: &[Point3<f64>],
1257    smoothness: f64,
1258) -> Vec<Vector3<f64>> {
1259    let power = 2.0 + smoothness;
1260
1261    query_positions
1262        .iter()
1263        .map(|query| {
1264            let mut weighted_sum = Vector3::zeros();
1265            let mut weight_sum = 0.0;
1266
1267            for (ctrl, disp) in control_positions.iter().zip(control_displacements.iter()) {
1268                let dist = (query - ctrl).norm();
1269                let weight = if dist < 1e-10 {
1270                    1e10 // Very close, use this displacement directly
1271                } else {
1272                    1.0 / dist.powf(power)
1273                };
1274
1275                weighted_sum += disp * weight;
1276                weight_sum += weight;
1277            }
1278
1279            if weight_sum > 0.0 {
1280                weighted_sum / weight_sum
1281            } else {
1282                Vector3::zeros()
1283            }
1284        })
1285        .collect()
1286}
1287
/// Thin-plate spline kernel `r^2 * ln(r)`.
///
/// Radii below `1e-10` (including negative values) evaluate to `0.0`, which
/// avoids taking the logarithm at or near zero. The `_epsilon` argument is
/// accepted but does not influence the result.
fn thin_plate_spline_rbf(r: f64, _epsilon: f64) -> f64 {
    if r < 1e-10 {
        return 0.0;
    }

    let radius_sq = r * r;
    radius_sq * r.ln()
}
1292
1293#[cfg(test)]
1294mod tests {
1295    use super::*;
1296    use crate::Vertex;
1297
1298    fn create_test_triangle() -> Mesh {
1299        let mut mesh = Mesh::new();
1300        mesh.vertices.push(Vertex::from_coords(0.0, 0.0, 0.0));
1301        mesh.vertices.push(Vertex::from_coords(10.0, 0.0, 0.0));
1302        mesh.vertices.push(Vertex::from_coords(5.0, 8.66, 0.0));
1303        mesh.faces.push([0, 1, 2]);
1304        mesh
1305    }
1306
1307    fn create_test_cube() -> Mesh {
1308        let mut mesh = Mesh::new();
1309        mesh.vertices.push(Vertex::from_coords(0.0, 0.0, 0.0));
1310        mesh.vertices.push(Vertex::from_coords(10.0, 0.0, 0.0));
1311        mesh.vertices.push(Vertex::from_coords(10.0, 10.0, 0.0));
1312        mesh.vertices.push(Vertex::from_coords(0.0, 10.0, 0.0));
1313        mesh.vertices.push(Vertex::from_coords(0.0, 0.0, 10.0));
1314        mesh.vertices.push(Vertex::from_coords(10.0, 0.0, 10.0));
1315        mesh.vertices.push(Vertex::from_coords(10.0, 10.0, 10.0));
1316        mesh.vertices.push(Vertex::from_coords(0.0, 10.0, 10.0));
1317
1318        mesh.faces.push([0, 2, 1]);
1319        mesh.faces.push([0, 3, 2]);
1320        mesh.faces.push([4, 5, 6]);
1321        mesh.faces.push([4, 6, 7]);
1322        mesh.faces.push([0, 1, 5]);
1323        mesh.faces.push([0, 5, 4]);
1324        mesh.faces.push([2, 3, 7]);
1325        mesh.faces.push([2, 7, 6]);
1326        mesh.faces.push([0, 4, 7]);
1327        mesh.faces.push([0, 7, 3]);
1328        mesh.faces.push([1, 2, 6]);
1329        mesh.faces.push([1, 6, 5]);
1330        mesh
1331    }
1332
1333    #[test]
1334    fn test_identity_registration() {
1335        let mesh = create_test_triangle();
1336        let target = mesh.clone();
1337
1338        let params = RegistrationParams::icp().with_max_iterations(10);
1339        let result = align_meshes(&mesh, &target, &params).unwrap();
1340
1341        // Should converge with minimal error
1342        assert!(result.rms_error < 0.01, "RMS error: {}", result.rms_error);
1343        assert!(result.converged);
1344    }
1345
1346    #[test]
1347    fn test_translation_recovery() {
1348        // Use cube instead of triangle for better ICP convergence
1349        let source = create_test_cube();
1350
1351        // Create target translated by (5, 3, 0)
1352        let mut target = source.clone();
1353        for vertex in &mut target.vertices {
1354            vertex.position.x += 5.0;
1355            vertex.position.y += 3.0;
1356        }
1357
1358        let params = RegistrationParams::icp().with_max_iterations(100);
1359        let result = align_meshes(&source, &target, &params).unwrap();
1360
1361        // Should recover the translation (ICP may have local minima with triangles)
1362        assert!(
1363            result.rms_error < 1.0,
1364            "RMS error should be reasonable: {}",
1365            result.rms_error
1366        );
1367
1368        // For cube-to-cube alignment, we can check alignment quality improved
1369        assert!(result.iterations > 0, "Should perform some iterations");
1370    }
1371
1372    #[test]
1373    fn test_rotation_recovery() {
1374        let source = create_test_cube();
1375
1376        // Create target rotated a small angle around Z axis (ICP works better with small rotations)
1377        let rotation = UnitQuaternion::from_axis_angle(&Vector3::z_axis(), 0.2); // ~11 degrees
1378        let mut target = source.clone();
1379        for vertex in &mut target.vertices {
1380            vertex.position = rotation * vertex.position;
1381        }
1382
1383        let params = RegistrationParams::icp().with_max_iterations(150);
1384        let result = align_meshes(&source, &target, &params).unwrap();
1385
1386        // ICP should make progress on the alignment
1387        assert!(result.iterations > 0, "Should perform some iterations");
1388        // Note: ICP can struggle with symmetric shapes and larger rotations
1389    }
1390
1391    #[test]
1392    fn test_landmark_registration() {
1393        let source = create_test_triangle();
1394
1395        // Create translated target
1396        let mut target = source.clone();
1397        for vertex in &mut target.vertices {
1398            vertex.position.x += 10.0;
1399            vertex.position.y += 5.0;
1400        }
1401
1402        let landmarks = vec![
1403            Landmark::new(Point3::new(0.0, 0.0, 0.0), Point3::new(10.0, 5.0, 0.0)),
1404            Landmark::new(Point3::new(10.0, 0.0, 0.0), Point3::new(20.0, 5.0, 0.0)),
1405            Landmark::new(Point3::new(5.0, 8.66, 0.0), Point3::new(15.0, 13.66, 0.0)),
1406        ];
1407
1408        let params = RegistrationParams::landmark_based(landmarks);
1409        let result = align_meshes(&source, &target, &params).unwrap();
1410
1411        // With exact landmarks, error should be near zero
1412        assert!(
1413            result.rms_error < 0.01,
1414            "RMS error should be minimal: {}",
1415            result.rms_error
1416        );
1417    }
1418
1419    #[test]
1420    fn test_scaling_registration() {
1421        let source = create_test_cube();
1422
1423        // Create target scaled by 2x (use cube for better convergence)
1424        let mut target = source.clone();
1425        for vertex in &mut target.vertices {
1426            vertex.position.coords *= 2.0;
1427        }
1428
1429        let params = RegistrationParams::icp()
1430            .with_scaling()
1431            .with_max_iterations(100);
1432        let result = align_meshes(&source, &target, &params).unwrap();
1433
1434        // Should make some progress - scaling recovery is challenging for ICP
1435        assert!(result.iterations > 0, "Should perform iterations");
1436        // Scale recovery works better with good initial alignment
1437        // The algorithm should at least run without error
1438    }
1439
1440    #[test]
1441    fn test_transform_composition() {
1442        let t1 = RigidTransform::from_translation(Vector3::new(1.0, 0.0, 0.0));
1443        let t2 = RigidTransform::from_translation(Vector3::new(0.0, 2.0, 0.0));
1444
1445        let composed = t1.then(&t2);
1446
1447        let point = Point3::new(0.0, 0.0, 0.0);
1448        let result = composed.transform_point(&point);
1449
1450        assert!((result.x - 1.0).abs() < 1e-10);
1451        assert!((result.y - 2.0).abs() < 1e-10);
1452        assert!((result.z - 0.0).abs() < 1e-10);
1453    }
1454
1455    #[test]
1456    fn test_transform_inverse() {
1457        let rotation =
1458            UnitQuaternion::from_axis_angle(&Vector3::z_axis(), std::f64::consts::FRAC_PI_2);
1459        let transform = RigidTransform {
1460            rotation,
1461            translation: Vector3::new(5.0, 3.0, 1.0),
1462            scale: 2.0,
1463        };
1464
1465        let inverse = transform.inverse();
1466        let point = Point3::new(1.0, 2.0, 3.0);
1467
1468        let transformed = transform.transform_point(&point);
1469        let recovered = inverse.transform_point(&transformed);
1470
1471        assert!(
1472            (point - recovered).norm() < 1e-10,
1473            "Inverse should recover original point"
1474        );
1475    }
1476
1477    #[test]
1478    fn test_empty_mesh_error() {
1479        let source = Mesh::new();
1480        let target = create_test_triangle();
1481
1482        let params = RegistrationParams::icp();
1483        assert!(matches!(
1484            align_meshes(&source, &target, &params),
1485            Err(MeshError::EmptyMesh { .. })
1486        ));
1487    }
1488
1489    #[test]
1490    fn test_insufficient_landmarks_error() {
1491        let source = create_test_triangle();
1492        let target = source.clone();
1493
1494        let landmarks = vec![
1495            Landmark::new(Point3::new(0.0, 0.0, 0.0), Point3::new(0.0, 0.0, 0.0)),
1496            Landmark::new(Point3::new(1.0, 0.0, 0.0), Point3::new(1.0, 0.0, 0.0)),
1497        ];
1498
1499        let params = RegistrationParams::landmark_based(landmarks);
1500        assert!(matches!(
1501            align_meshes(&source, &target, &params),
1502            Err(MeshError::RepairFailed { .. })
1503        ));
1504    }
1505
1506    #[test]
1507    fn test_landmark_then_icp() {
1508        let source = create_test_cube();
1509
1510        // Create target with rotation and translation
1511        let rotation = UnitQuaternion::from_axis_angle(&Vector3::z_axis(), 0.3);
1512        let translation = Vector3::new(5.0, 3.0, 2.0);
1513
1514        let mut target = source.clone();
1515        for vertex in &mut target.vertices {
1516            vertex.position = rotation * vertex.position;
1517            vertex.position.coords += translation;
1518        }
1519
1520        // Provide approximate landmarks (slightly off to test refinement)
1521        let landmarks = vec![
1522            Landmark::new(
1523                Point3::new(0.0, 0.0, 0.0),
1524                rotation * Point3::new(0.0, 0.0, 0.0) + translation,
1525            ),
1526            Landmark::new(
1527                Point3::new(10.0, 0.0, 0.0),
1528                rotation * Point3::new(10.0, 0.0, 0.0) + translation,
1529            ),
1530            Landmark::new(
1531                Point3::new(0.0, 10.0, 0.0),
1532                rotation * Point3::new(0.0, 10.0, 0.0) + translation,
1533            ),
1534        ];
1535
1536        let params = RegistrationParams::landmark_then_icp(landmarks);
1537        let result = align_meshes(&source, &target, &params).unwrap();
1538
1539        assert!(
1540            result.rms_error < 1.0,
1541            "RMS error should be small: {}",
1542            result.rms_error
1543        );
1544    }
1545
1546    #[test]
1547    fn test_max_correspondence_distance() {
1548        let source = create_test_triangle();
1549
1550        // Create target with some noise
1551        let mut target = source.clone();
1552        // Add an outlier vertex far away
1553        target.vertices.push(Vertex::from_coords(1000.0, 0.0, 0.0));
1554
1555        let params = RegistrationParams::icp()
1556            .with_max_correspondence_distance(50.0) // Reject correspondences > 50mm
1557            .with_max_iterations(20);
1558        let result = align_meshes(&source, &target, &params).unwrap();
1559
1560        // Should still converge well, ignoring the outlier
1561        assert!(result.rms_error < 1.0, "RMS error: {}", result.rms_error);
1562    }
1563
1564    #[test]
1565    fn test_subsample_ratio() {
1566        let source = create_test_cube();
1567        let target = source.clone();
1568
1569        let params = RegistrationParams::icp()
1570            .with_subsample_ratio(0.5) // Use 50% of points
1571            .with_max_iterations(10);
1572        let result = align_meshes(&source, &target, &params).unwrap();
1573
1574        assert!(result.rms_error < 0.1, "RMS error: {}", result.rms_error);
1575    }
1576
1577    #[test]
1578    fn test_weighted_landmarks() {
1579        let source = create_test_triangle();
1580
1581        // Create target translated
1582        let mut target = source.clone();
1583        for vertex in &mut target.vertices {
1584            vertex.position.x += 5.0;
1585        }
1586
1587        // Provide landmarks with different weights
1588        let landmarks = vec![
1589            Landmark::weighted(Point3::new(0.0, 0.0, 0.0), Point3::new(5.0, 0.0, 0.0), 2.0),
1590            Landmark::weighted(
1591                Point3::new(10.0, 0.0, 0.0),
1592                Point3::new(15.0, 0.0, 0.0),
1593                1.0,
1594            ),
1595            Landmark::weighted(
1596                Point3::new(5.0, 8.66, 0.0),
1597                Point3::new(10.0, 8.66, 0.0),
1598                1.0,
1599            ),
1600        ];
1601
1602        let params = RegistrationParams::landmark_based(landmarks);
1603        let result = align_meshes(&source, &target, &params).unwrap();
1604
1605        assert!(result.rms_error < 0.1, "RMS error: {}", result.rms_error);
1606    }
1607
1608    #[test]
1609    fn test_rigid_transform_to_matrix() {
1610        let rotation =
1611            UnitQuaternion::from_axis_angle(&Vector3::z_axis(), std::f64::consts::FRAC_PI_2);
1612        let transform = RigidTransform {
1613            rotation,
1614            translation: Vector3::new(1.0, 2.0, 3.0),
1615            scale: 1.0,
1616        };
1617
1618        let matrix = transform.to_matrix4();
1619
1620        // Test that point transformation matches matrix multiplication
1621        let point = Point3::new(1.0, 0.0, 0.0);
1622        let transformed = transform.transform_point(&point);
1623
1624        let homogeneous = nalgebra::Vector4::new(point.x, point.y, point.z, 1.0);
1625        let matrix_result = matrix * homogeneous;
1626
1627        assert!((transformed.x - matrix_result.x).abs() < 1e-10);
1628        assert!((transformed.y - matrix_result.y).abs() < 1e-10);
1629        assert!((transformed.z - matrix_result.z).abs() < 1e-10);
1630    }
1631
1632    // Non-rigid registration tests
1633
1634    #[test]
1635    fn test_non_rigid_identity() {
1636        // Non-rigid registration of identical meshes should produce minimal displacement
1637        let source = create_test_cube();
1638        let target = source.clone();
1639
1640        let nr_params = NonRigidParams::new();
1641        let base_params = RegistrationParams::default().with_max_iterations(20);
1642
1643        let result = non_rigid_align(&source, &target, &nr_params, &base_params).unwrap();
1644
1645        // Error should be very small
1646        assert!(
1647            result.rms_error < 0.5,
1648            "RMS error should be small for identical meshes: {}",
1649            result.rms_error
1650        );
1651
1652        // Displacements should be minimal
1653        assert!(
1654            result.max_displacement < 1.0,
1655            "Max displacement should be small: {}",
1656            result.max_displacement
1657        );
1658    }
1659
1660    #[test]
1661    fn test_non_rigid_translation() {
1662        // Non-rigid registration should handle pure translation
1663        let source = create_test_cube();
1664
1665        let mut target = source.clone();
1666        for vertex in &mut target.vertices {
1667            vertex.position.x += 5.0;
1668            vertex.position.z += 2.0;
1669        }
1670
1671        let nr_params = NonRigidParams::new().with_stiffness(1.0);
1672        let base_params = RegistrationParams::default().with_max_iterations(50);
1673
1674        let result = non_rigid_align(&source, &target, &nr_params, &base_params).unwrap();
1675
1676        // Should achieve reasonable alignment
1677        assert!(
1678            result.rms_error < 2.0,
1679            "RMS error should be reasonable: {}",
1680            result.rms_error
1681        );
1682    }
1683
1684    #[test]
1685    fn test_non_rigid_local_deformation() {
1686        // Non-rigid should handle local deformations that rigid ICP cannot
1687        let source = create_test_cube();
1688
1689        // Create target with non-uniform scaling (stretches in Y direction)
1690        let mut target = source.clone();
1691        for vertex in &mut target.vertices {
1692            vertex.position.y *= 1.5; // Stretch Y by 50%
1693        }
1694
1695        let nr_params = NonRigidParams::new()
1696            .with_stiffness(2.0)
1697            .with_outer_iterations(15);
1698        let base_params = RegistrationParams::default();
1699
1700        let result = non_rigid_align(&source, &target, &nr_params, &base_params).unwrap();
1701
1702        // Non-rigid should adapt to the stretching
1703        assert!(
1704            result.max_displacement > 0.1,
1705            "Should have non-trivial displacement to match stretched target"
1706        );
1707
1708        // Should converge or run all iterations
1709        assert!(result.iterations >= 1, "Should perform iterations");
1710    }
1711
1712    #[test]
1713    fn test_non_rigid_with_landmarks() {
1714        let source = create_test_cube();
1715
1716        // Target with translation
1717        let mut target = source.clone();
1718        for vertex in &mut target.vertices {
1719            vertex.position.x += 3.0;
1720        }
1721
1722        // Provide landmarks to guide the deformation
1723        let landmarks = vec![
1724            Landmark::new(Point3::new(0.0, 0.0, 0.0), Point3::new(3.0, 0.0, 0.0)),
1725            Landmark::new(Point3::new(10.0, 10.0, 10.0), Point3::new(13.0, 10.0, 10.0)),
1726        ];
1727
1728        let nr_params = NonRigidParams::new()
1729            .with_landmarks(landmarks)
1730            .with_stiffness(5.0);
1731        let base_params = RegistrationParams::default();
1732
1733        let result = non_rigid_align(&source, &target, &nr_params, &base_params).unwrap();
1734
1735        // Landmarks should help guide the registration
1736        assert!(result.iterations >= 1);
1737    }
1738
1739    #[test]
1740    fn test_non_rigid_params_builder() {
1741        let params = NonRigidParams::new()
1742            .with_stiffness(20.0)
1743            .with_control_points(100)
1744            .with_outer_iterations(5)
1745            .with_smoothness(2.0)
1746            .without_initial_alignment();
1747
1748        assert_eq!(params.stiffness, 20.0);
1749        assert_eq!(params.num_control_points, Some(100));
1750        assert_eq!(params.outer_iterations, 5);
1751        assert_eq!(params.smoothness, 2.0);
1752        assert!(!params.initial_rigid_alignment);
1753    }
1754
1755    #[test]
1756    fn test_non_rigid_result_deformation_field() {
1757        let source = create_test_cube();
1758        let target = source.clone();
1759
1760        let nr_params = NonRigidParams::new();
1761        let base_params = RegistrationParams::default().with_max_iterations(5);
1762
1763        let result = non_rigid_align(&source, &target, &nr_params, &base_params).unwrap();
1764
1765        // Get deformation field
1766        let field = result.deformation_field(&source);
1767
1768        assert_eq!(field.len(), source.vertex_count());
1769
1770        // Each entry should have position and displacement
1771        for (pos, _disp) in &field {
1772            // Position should be from original mesh
1773            assert!(
1774                source
1775                    .vertices
1776                    .iter()
1777                    .any(|v| (v.position - pos).norm() < 1e-6)
1778            );
1779        }
1780    }
1781
1782    #[test]
1783    fn test_select_control_points() {
1784        let mesh = create_test_cube();
1785
1786        // Select 4 control points
1787        let indices = select_control_points(&mesh, 4);
1788
1789        assert_eq!(indices.len(), 4);
1790
1791        // All indices should be valid
1792        for &idx in &indices {
1793            assert!(idx < mesh.vertex_count());
1794        }
1795
1796        // Indices should be unique
1797        let unique: std::collections::HashSet<_> = indices.iter().collect();
1798        assert_eq!(unique.len(), 4);
1799    }
1800
1801    #[test]
1802    fn test_select_control_points_more_than_vertices() {
1803        let mesh = create_test_triangle();
1804
1805        // Request more control points than vertices
1806        let indices = select_control_points(&mesh, 10);
1807
1808        // Should return all vertices
1809        assert_eq!(indices.len(), mesh.vertex_count());
1810    }
1811
1812    #[test]
1813    fn test_regularize_displacements() {
1814        let positions = vec![
1815            Point3::new(0.0, 0.0, 0.0),
1816            Point3::new(1.0, 0.0, 0.0),
1817            Point3::new(0.5, 1.0, 0.0),
1818        ];
1819
1820        // One point has a large displacement, others are zero
1821        let displacements = vec![
1822            Vector3::new(10.0, 0.0, 0.0),
1823            Vector3::zeros(),
1824            Vector3::zeros(),
1825        ];
1826
1827        // High stiffness should smooth out the spike
1828        let smoothed = regularize_displacements(&positions, &displacements, 10.0);
1829
1830        // The outlier should be reduced
1831        assert!(
1832            smoothed[0].norm() < displacements[0].norm(),
1833            "Regularization should reduce outlier"
1834        );
1835    }
1836
1837    #[test]
1838    fn test_non_rigid_empty_mesh_error() {
1839        let source = Mesh::new();
1840        let target = create_test_cube();
1841
1842        let nr_params = NonRigidParams::new();
1843        let base_params = RegistrationParams::default();
1844
1845        let result = non_rigid_align(&source, &target, &nr_params, &base_params);
1846
1847        assert!(matches!(result, Err(MeshError::EmptyMesh { .. })));
1848    }
1849
1850    #[test]
1851    fn test_thin_plate_spline_rbf() {
1852        // Test RBF kernel properties
1853        assert_eq!(thin_plate_spline_rbf(0.0, 1.0), 0.0);
1854
1855        // r^2 * ln(r) for r=1 should be 0 (ln(1) = 0)
1856        assert!((thin_plate_spline_rbf(1.0, 1.0) - 0.0).abs() < 1e-10);
1857
1858        // For r > 1, value should be positive
1859        assert!(thin_plate_spline_rbf(2.0, 1.0) > 0.0);
1860
1861        // For 0 < r < 1, value should be negative
1862        assert!(thin_plate_spline_rbf(0.5, 1.0) < 0.0);
1863    }
1864}