mesh_repair/
template.rs

//! Template-based mesh fitting.
//!
//! This module provides tools for fitting template meshes to scans or measurements,
//! enabling parametric customization of product designs.
//!
//! Fitting runs up to three stages: rigid alignment to a target scan, landmark-driven
//! deformation, and measurement-based adjustment of measurement-plane regions.
//!
//! # Use Cases
//!
//! - Fitting a shoe last template to a foot scan
//! - Adapting a helmet liner to head measurements
//! - Creating size variations of a product template
//!
//! # Example
//!
//! ```
//! use mesh_repair::{Mesh, Vertex};
//! use mesh_repair::template::{FitTemplate, FitParams, ControlRegion};
//! use nalgebra::Point3;
//!
//! // Create a template mesh
//! let mut template_mesh = Mesh::new();
//! template_mesh.vertices.push(Vertex::from_coords(0.0, 0.0, 0.0));
//! template_mesh.vertices.push(Vertex::from_coords(10.0, 0.0, 0.0));
//! template_mesh.vertices.push(Vertex::from_coords(5.0, 10.0, 0.0));
//! template_mesh.faces.push([0, 1, 2]);
//!
//! // Create a template with control regions
//! let template = FitTemplate::new(template_mesh)
//!     .with_control_region(ControlRegion::point("tip", Point3::new(5.0, 10.0, 0.0)));
//!
//! // Fit to target measurements
//! let params = FitParams::default()
//!     .with_landmark_target("tip", Point3::new(5.0, 12.0, 0.0));
//!
//! let result = template.fit(&params).unwrap();
//! println!("Fit error: {:.3} mm", result.fit_error);
//! ```
37
38use crate::morph::{self, Constraint, MorphParams};
39use crate::registration::{self, RegistrationParams, RigidTransform};
40use crate::{Mesh, MeshError, MeshResult};
41use nalgebra::{Point3, Vector3};
42use std::collections::{HashMap, HashSet};
43
44/// A template mesh with control regions for parametric fitting.
45#[derive(Debug, Clone)]
46pub struct FitTemplate {
47    /// The base template mesh.
48    pub mesh: Mesh,
49
50    /// Named control regions for manipulation.
51    pub control_regions: HashMap<String, ControlRegion>,
52
53    /// Default fitting parameters.
54    pub default_params: FitParams,
55}
56
57impl FitTemplate {
58    /// Create a new template from a mesh.
59    pub fn new(mesh: Mesh) -> Self {
60        Self {
61            mesh,
62            control_regions: HashMap::new(),
63            default_params: FitParams::default(),
64        }
65    }
66
67    /// Add a control region to the template.
68    pub fn with_control_region(mut self, region: ControlRegion) -> Self {
69        self.control_regions.insert(region.name.clone(), region);
70        self
71    }
72
73    /// Add multiple control regions.
74    pub fn with_control_regions(mut self, regions: Vec<ControlRegion>) -> Self {
75        for region in regions {
76            self.control_regions.insert(region.name.clone(), region);
77        }
78        self
79    }
80
81    /// Set default fitting parameters.
82    pub fn with_default_params(mut self, params: FitParams) -> Self {
83        self.default_params = params;
84        self
85    }
86
87    /// Get a control region by name.
88    pub fn get_region(&self, name: &str) -> Option<&ControlRegion> {
89        self.control_regions.get(name)
90    }
91
92    /// Get the position of a landmark control region.
93    pub fn get_landmark_position(&self, name: &str) -> Option<Point3<f64>> {
94        self.control_regions
95            .get(name)
96            .and_then(|r| match &r.definition {
97                RegionDefinition::Point(p) => Some(*p),
98                RegionDefinition::Vertices(indices) if indices.len() == 1 => self
99                    .mesh
100                    .vertices
101                    .get(indices[0] as usize)
102                    .map(|v| v.position),
103                _ => None,
104            })
105    }
106
107    /// List all control region names.
108    pub fn region_names(&self) -> Vec<&str> {
109        self.control_regions.keys().map(|s| s.as_str()).collect()
110    }
111
112    /// Fit the template using the given parameters.
113    ///
114    /// This performs a multi-stage fitting process:
115    /// 1. Rigid alignment (if a target scan is provided)
116    /// 2. Landmark-based deformation
117    /// 3. Measurement-based adjustment
118    pub fn fit(&self, params: &FitParams) -> MeshResult<FitResult> {
119        if self.mesh.is_empty() {
120            return Err(MeshError::EmptyMesh {
121                details: "Cannot fit an empty template mesh".to_string(),
122            });
123        }
124
125        let mut current_mesh = self.mesh.clone();
126        let mut total_transform = RigidTransform::identity();
127        let mut stages_completed = Vec::new();
128
129        // Stage 1: Rigid alignment to scan (if provided)
130        if let Some(ref scan) = params.target_scan {
131            let reg_params = RegistrationParams::icp()
132                .with_max_iterations(params.registration_iterations)
133                .with_convergence_threshold(params.convergence_threshold);
134
135            let reg_result = registration::align_meshes(&current_mesh, scan, &reg_params)?;
136            current_mesh = reg_result.mesh;
137            total_transform = reg_result.transformation;
138            stages_completed.push(FitStage::RigidAlignment {
139                rms_error: reg_result.rms_error,
140            });
141        }
142
143        // Stage 2: Landmark-based deformation
144        if !params.landmark_targets.is_empty() {
145            let mut constraints = Vec::new();
146
147            for (name, target) in &params.landmark_targets {
148                if let Some(region) = self.control_regions.get(name) {
149                    // Get current position (after rigid alignment)
150                    let source = match &region.definition {
151                        RegionDefinition::Point(p) => total_transform.transform_point(p),
152                        RegionDefinition::Vertices(indices) if !indices.is_empty() => {
153                            // Average position of vertices
154                            let sum: Vector3<f64> = indices
155                                .iter()
156                                .filter_map(|&i| current_mesh.vertices.get(i as usize))
157                                .map(|v| v.position.coords)
158                                .sum();
159                            Point3::from(sum / indices.len() as f64)
160                        }
161                        _ => continue,
162                    };
163
164                    constraints.push(Constraint::weighted(source, *target, region.weight));
165                }
166            }
167
168            if !constraints.is_empty() {
169                let morph_params = MorphParams::rbf()
170                    .with_constraints(constraints)
171                    .with_smoothness(params.smoothness);
172
173                let morph_result = morph::morph_mesh(&current_mesh, &morph_params)?;
174                current_mesh = morph_result.mesh;
175                stages_completed.push(FitStage::LandmarkDeformation {
176                    constraints_applied: params.landmark_targets.len(),
177                    max_displacement: morph_result.max_displacement,
178                });
179            }
180        }
181
182        // Stage 3: Measurement-based adjustment
183        if !params.measurement_targets.is_empty() {
184            for (name, measurement) in &params.measurement_targets {
185                if let Some(region) = self.control_regions.get(name) {
186                    current_mesh =
187                        apply_measurement_constraint(&current_mesh, region, measurement)?;
188                }
189            }
190            stages_completed.push(FitStage::MeasurementAdjustment {
191                measurements_applied: params.measurement_targets.len(),
192            });
193        }
194
195        // Calculate fit quality metrics
196        let fit_error = calculate_fit_error(&current_mesh, params, &self.control_regions);
197
198        Ok(FitResult {
199            mesh: current_mesh,
200            fit_error,
201            stages: stages_completed,
202            transform: total_transform,
203        })
204    }
205
206    /// Fit the template to a target scan.
207    ///
208    /// This is a convenience method that combines registration and morphing.
209    pub fn fit_to_scan(&self, scan: &Mesh) -> MeshResult<FitResult> {
210        let params = FitParams::default().with_target_scan(scan.clone());
211        self.fit(&params)
212    }
213
214    /// Fit the template to target measurements only.
215    pub fn fit_to_measurements(
216        &self,
217        measurements: HashMap<String, Measurement>,
218    ) -> MeshResult<FitResult> {
219        let params = FitParams::default().with_measurements(measurements);
220        self.fit(&params)
221    }
222}
223
224/// A control region on a template mesh.
225#[derive(Debug, Clone)]
226pub struct ControlRegion {
227    /// Unique name for this region (e.g., "heel", "toe_tip", "ankle").
228    pub name: String,
229
230    /// How this region is defined.
231    pub definition: RegionDefinition,
232
233    /// Weight for this region in fitting operations.
234    pub weight: f64,
235
236    /// Whether this region should be preserved (not deformed).
237    pub preserve: bool,
238}
239
240impl ControlRegion {
241    /// Create a point-based control region (single landmark).
242    pub fn point(name: impl Into<String>, position: Point3<f64>) -> Self {
243        Self {
244            name: name.into(),
245            definition: RegionDefinition::Point(position),
246            weight: 1.0,
247            preserve: false,
248        }
249    }
250
251    /// Create a vertex-based control region.
252    pub fn vertices(name: impl Into<String>, indices: Vec<u32>) -> Self {
253        Self {
254            name: name.into(),
255            definition: RegionDefinition::Vertices(indices),
256            weight: 1.0,
257            preserve: false,
258        }
259    }
260
261    /// Create a face-based control region.
262    pub fn faces(name: impl Into<String>, indices: Vec<u32>) -> Self {
263        Self {
264            name: name.into(),
265            definition: RegionDefinition::Faces(indices),
266            weight: 1.0,
267            preserve: false,
268        }
269    }
270
271    /// Create a spatial bounds region (box).
272    pub fn bounds(name: impl Into<String>, min: Point3<f64>, max: Point3<f64>) -> Self {
273        Self {
274            name: name.into(),
275            definition: RegionDefinition::Bounds { min, max },
276            weight: 1.0,
277            preserve: false,
278        }
279    }
280
281    /// Create a spherical region.
282    pub fn sphere(name: impl Into<String>, center: Point3<f64>, radius: f64) -> Self {
283        Self {
284            name: name.into(),
285            definition: RegionDefinition::Sphere { center, radius },
286            weight: 1.0,
287            preserve: false,
288        }
289    }
290
291    /// Create a cylindrical region.
292    pub fn cylinder(
293        name: impl Into<String>,
294        axis_start: Point3<f64>,
295        axis_end: Point3<f64>,
296        radius: f64,
297    ) -> Self {
298        Self {
299            name: name.into(),
300            definition: RegionDefinition::Cylinder {
301                axis_start,
302                axis_end,
303                radius,
304            },
305            weight: 1.0,
306            preserve: false,
307        }
308    }
309
310    /// Create a measurement region (for circumference, etc.).
311    pub fn measurement(
312        name: impl Into<String>,
313        measurement_type: MeasurementType,
314        plane_origin: Point3<f64>,
315        plane_normal: Vector3<f64>,
316    ) -> Self {
317        Self {
318            name: name.into(),
319            definition: RegionDefinition::MeasurementPlane {
320                measurement_type,
321                origin: plane_origin,
322                normal: plane_normal.normalize(),
323            },
324            weight: 1.0,
325            preserve: false,
326        }
327    }
328
329    /// Set the weight for this region.
330    pub fn with_weight(mut self, weight: f64) -> Self {
331        self.weight = weight;
332        self
333    }
334
335    /// Mark this region as preserved (not deformed).
336    pub fn preserved(mut self) -> Self {
337        self.preserve = true;
338        self
339    }
340
341    /// Get the vertex indices that belong to this region.
342    pub fn get_vertex_indices(&self, mesh: &Mesh) -> HashSet<u32> {
343        match &self.definition {
344            RegionDefinition::Point(_) => HashSet::new(),
345            RegionDefinition::Vertices(indices) => indices.iter().copied().collect(),
346            RegionDefinition::Faces(face_indices) => {
347                let mut vertices = HashSet::new();
348                for &fi in face_indices {
349                    if let Some(face) = mesh.faces.get(fi as usize) {
350                        vertices.insert(face[0]);
351                        vertices.insert(face[1]);
352                        vertices.insert(face[2]);
353                    }
354                }
355                vertices
356            }
357            RegionDefinition::Bounds { min, max } => mesh
358                .vertices
359                .iter()
360                .enumerate()
361                .filter(|(_, v)| {
362                    v.position.x >= min.x
363                        && v.position.x <= max.x
364                        && v.position.y >= min.y
365                        && v.position.y <= max.y
366                        && v.position.z >= min.z
367                        && v.position.z <= max.z
368                })
369                .map(|(i, _)| i as u32)
370                .collect(),
371            RegionDefinition::Sphere { center, radius } => mesh
372                .vertices
373                .iter()
374                .enumerate()
375                .filter(|(_, v)| (v.position - center).norm() <= *radius)
376                .map(|(i, _)| i as u32)
377                .collect(),
378            RegionDefinition::Cylinder {
379                axis_start,
380                axis_end,
381                radius,
382            } => {
383                let axis = axis_end - axis_start;
384                let axis_len_sq = axis.norm_squared();
385                if axis_len_sq < 1e-10 {
386                    return HashSet::new();
387                }
388
389                mesh.vertices
390                    .iter()
391                    .enumerate()
392                    .filter(|(_, v)| {
393                        let to_point = v.position - axis_start;
394                        let t = to_point.dot(&axis) / axis_len_sq;
395                        if !(0.0..=1.0).contains(&t) {
396                            return false;
397                        }
398                        let projection = axis_start + axis * t;
399                        (v.position - projection).norm() <= *radius
400                    })
401                    .map(|(i, _)| i as u32)
402                    .collect()
403            }
404            RegionDefinition::MeasurementPlane { origin, normal, .. } => {
405                // Get vertices near the measurement plane
406                let tolerance = 5.0; // mm
407                mesh.vertices
408                    .iter()
409                    .enumerate()
410                    .filter(|(_, v)| {
411                        let dist = (v.position - origin).dot(normal).abs();
412                        dist <= tolerance
413                    })
414                    .map(|(i, _)| i as u32)
415                    .collect()
416            }
417        }
418    }
419}
420
421/// How a control region is defined.
422#[derive(Debug, Clone)]
423pub enum RegionDefinition {
424    /// A single point (landmark).
425    Point(Point3<f64>),
426
427    /// A set of vertex indices.
428    Vertices(Vec<u32>),
429
430    /// A set of face indices.
431    Faces(Vec<u32>),
432
433    /// An axis-aligned bounding box.
434    Bounds { min: Point3<f64>, max: Point3<f64> },
435
436    /// A sphere.
437    Sphere { center: Point3<f64>, radius: f64 },
438
439    /// A cylinder (for limbs, handles, etc.).
440    Cylinder {
441        axis_start: Point3<f64>,
442        axis_end: Point3<f64>,
443        radius: f64,
444    },
445
446    /// A measurement plane (for circumferences, widths, etc.).
447    MeasurementPlane {
448        measurement_type: MeasurementType,
449        origin: Point3<f64>,
450        normal: Vector3<f64>,
451    },
452}
453
/// Kinds of measurement that can be constrained on a measurement-plane region.
///
/// `Eq` and `Hash` are derived so the type can be used as a set or map key.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum MeasurementType {
    /// Circumference around a cross-section.
    Circumference,
    /// Width (extent in a direction perpendicular to the plane normal).
    Width,
    /// Height (extent along the plane normal).
    Height,
    /// Depth (extent in a direction perpendicular to the plane normal).
    Depth,
}
466
/// A target measurement value together with how strictly it must be met.
#[derive(Debug, Clone)]
pub struct Measurement {
    /// The target value.
    pub value: f64,
    /// Allowed deviation from `value` (1mm for the `exact` and `minimum`
    /// constructors).
    pub tolerance: f64,
    /// `true` for a lower bound ("at least `value`"), `false` for an exact target.
    pub is_minimum: bool,
}

impl Measurement {
    /// Tolerance used by [`Measurement::exact`] and [`Measurement::minimum`].
    const DEFAULT_TOLERANCE: f64 = 1.0;

    /// Create an exact measurement constraint with the default tolerance.
    pub fn exact(value: f64) -> Self {
        Self::with_tolerance(value, Self::DEFAULT_TOLERANCE)
    }

    /// Create an exact measurement constraint with an explicit tolerance.
    pub fn with_tolerance(value: f64, tolerance: f64) -> Self {
        Self {
            value,
            tolerance,
            is_minimum: false,
        }
    }

    /// Create a minimum ("at least `value`") constraint with the default tolerance.
    pub fn minimum(value: f64) -> Self {
        Self {
            is_minimum: true,
            ..Self::exact(value)
        }
    }
}
506
507/// Parameters for template fitting.
508#[derive(Debug, Clone, Default)]
509pub struct FitParams {
510    /// Target scan to fit to (optional).
511    pub target_scan: Option<Mesh>,
512
513    /// Target positions for landmark regions.
514    pub landmark_targets: HashMap<String, Point3<f64>>,
515
516    /// Target measurements for measurement regions.
517    pub measurement_targets: HashMap<String, Measurement>,
518
519    /// Smoothness parameter for morphing (higher = smoother).
520    pub smoothness: f64,
521
522    /// Maximum iterations for registration.
523    pub registration_iterations: usize,
524
525    /// Convergence threshold for registration.
526    pub convergence_threshold: f64,
527}
528
529impl FitParams {
530    /// Create default fitting parameters.
531    pub fn new() -> Self {
532        Self {
533            target_scan: None,
534            landmark_targets: HashMap::new(),
535            measurement_targets: HashMap::new(),
536            smoothness: 1.0,
537            registration_iterations: 100,
538            convergence_threshold: 1e-6,
539        }
540    }
541
542    /// Set the target scan.
543    pub fn with_target_scan(mut self, scan: Mesh) -> Self {
544        self.target_scan = Some(scan);
545        self
546    }
547
548    /// Add a landmark target.
549    pub fn with_landmark_target(mut self, name: impl Into<String>, target: Point3<f64>) -> Self {
550        self.landmark_targets.insert(name.into(), target);
551        self
552    }
553
554    /// Add multiple landmark targets.
555    pub fn with_landmark_targets(mut self, targets: HashMap<String, Point3<f64>>) -> Self {
556        self.landmark_targets.extend(targets);
557        self
558    }
559
560    /// Add a measurement target.
561    pub fn with_measurement(mut self, name: impl Into<String>, measurement: Measurement) -> Self {
562        self.measurement_targets.insert(name.into(), measurement);
563        self
564    }
565
566    /// Add multiple measurement targets.
567    pub fn with_measurements(mut self, measurements: HashMap<String, Measurement>) -> Self {
568        self.measurement_targets.extend(measurements);
569        self
570    }
571
572    /// Set the smoothness parameter.
573    pub fn with_smoothness(mut self, smoothness: f64) -> Self {
574        self.smoothness = smoothness;
575        self
576    }
577
578    /// Set the registration iterations.
579    pub fn with_registration_iterations(mut self, iterations: usize) -> Self {
580        self.registration_iterations = iterations;
581        self
582    }
583}
584
585/// Result of a template fitting operation.
586#[derive(Debug, Clone)]
587pub struct FitResult {
588    /// The fitted mesh.
589    pub mesh: Mesh,
590
591    /// Overall fit error (RMS distance at control points).
592    pub fit_error: f64,
593
594    /// Stages that were completed.
595    pub stages: Vec<FitStage>,
596
597    /// The rigid transformation applied.
598    pub transform: RigidTransform,
599}
600
601impl FitResult {
602    /// Check if the fit is acceptable.
603    pub fn is_acceptable(&self, max_error: f64) -> bool {
604        self.fit_error <= max_error
605    }
606}
607
/// One completed stage of the fitting pipeline, with its summary statistics.
#[derive(Debug, Clone)]
pub enum FitStage {
    /// Rigid alignment to a target scan.
    RigidAlignment {
        /// RMS error reported by the registration.
        rms_error: f64,
    },

    /// Landmark-driven deformation.
    LandmarkDeformation {
        /// Number of landmark constraints passed to the morph.
        constraints_applied: usize,
        /// Largest vertex displacement reported by the morph.
        max_displacement: f64,
    },

    /// Measurement-based adjustment.
    MeasurementAdjustment {
        /// Number of measurements reported as applied.
        measurements_applied: usize,
    },
}
623
624/// Apply a measurement constraint to a mesh.
625fn apply_measurement_constraint(
626    mesh: &Mesh,
627    region: &ControlRegion,
628    measurement: &Measurement,
629) -> MeshResult<Mesh> {
630    let RegionDefinition::MeasurementPlane {
631        measurement_type,
632        origin,
633        normal,
634    } = &region.definition
635    else {
636        return Ok(mesh.clone());
637    };
638
639    // Get vertices in the measurement region
640    let vertex_indices = region.get_vertex_indices(mesh);
641    if vertex_indices.is_empty() {
642        return Ok(mesh.clone());
643    }
644
645    // Calculate current measurement
646    let current_value = match measurement_type {
647        MeasurementType::Circumference => {
648            // Approximate circumference from vertices near the plane
649            // This is a simplified calculation
650            let region_vertices: Vec<Point3<f64>> = vertex_indices
651                .iter()
652                .filter_map(|&i| mesh.vertices.get(i as usize))
653                .map(|v| v.position)
654                .collect();
655
656            if region_vertices.len() < 3 {
657                return Ok(mesh.clone());
658            }
659
660            // Project vertices onto plane and compute perimeter
661            let projected: Vec<Point3<f64>> = region_vertices
662                .iter()
663                .map(|p| {
664                    let dist = (p - origin).dot(normal);
665                    Point3::from(p.coords - dist * normal)
666                })
667                .collect();
668
669            // Very rough circumference estimate using bounding box
670            let centroid: Vector3<f64> =
671                projected.iter().map(|p| p.coords).sum::<Vector3<f64>>() / projected.len() as f64;
672            let avg_radius = projected
673                .iter()
674                .map(|p| (p.coords - centroid).norm())
675                .sum::<f64>()
676                / projected.len() as f64;
677
678            2.0 * std::f64::consts::PI * avg_radius
679        }
680        MeasurementType::Width | MeasurementType::Depth => {
681            // Compute extent perpendicular to normal
682            let region_vertices: Vec<Point3<f64>> = vertex_indices
683                .iter()
684                .filter_map(|&i| mesh.vertices.get(i as usize))
685                .map(|v| v.position)
686                .collect();
687
688            if region_vertices.is_empty() {
689                return Ok(mesh.clone());
690            }
691
692            // Project onto plane
693            let projected: Vec<Point3<f64>> = region_vertices
694                .iter()
695                .map(|p| {
696                    let dist = (p - origin).dot(normal);
697                    Point3::from(p.coords - dist * normal)
698                })
699                .collect();
700
701            // Compute extent in an arbitrary direction perpendicular to normal
702            let perpendicular = if normal.x.abs() < 0.9 {
703                normal.cross(&Vector3::x()).normalize()
704            } else {
705                normal.cross(&Vector3::y()).normalize()
706            };
707
708            let projections: Vec<f64> = projected
709                .iter()
710                .map(|p| p.coords.dot(&perpendicular))
711                .collect();
712            let min = projections.iter().copied().fold(f64::INFINITY, f64::min);
713            let max = projections
714                .iter()
715                .copied()
716                .fold(f64::NEG_INFINITY, f64::max);
717
718            max - min
719        }
720        MeasurementType::Height => {
721            // Compute extent in the normal direction
722            let projections: Vec<f64> = vertex_indices
723                .iter()
724                .filter_map(|&i| mesh.vertices.get(i as usize))
725                .map(|v| (v.position - origin).dot(normal))
726                .collect();
727
728            if projections.is_empty() {
729                return Ok(mesh.clone());
730            }
731
732            let min = projections.iter().copied().fold(f64::INFINITY, f64::min);
733            let max = projections
734                .iter()
735                .copied()
736                .fold(f64::NEG_INFINITY, f64::max);
737
738            max - min
739        }
740    };
741
742    // Check if adjustment is needed
743    let target_value = measurement.value;
744    let diff = target_value - current_value;
745
746    if diff.abs() <= measurement.tolerance
747        || (measurement.is_minimum && current_value >= target_value)
748    {
749        return Ok(mesh.clone());
750    }
751
752    // Calculate scale factor
753    let scale_factor = if current_value > 1e-6 {
754        target_value / current_value
755    } else {
756        1.0
757    };
758
759    // Apply scaling to vertices in the region
760    let mut result = mesh.clone();
761    let centroid = {
762        let sum: Vector3<f64> = vertex_indices
763            .iter()
764            .filter_map(|&i| mesh.vertices.get(i as usize))
765            .map(|v| v.position.coords)
766            .sum();
767        Point3::from(sum / vertex_indices.len() as f64)
768    };
769
770    for &idx in &vertex_indices {
771        if let Some(vertex) = result.vertices.get_mut(idx as usize) {
772            let offset = vertex.position - centroid;
773            // Scale perpendicular to the measurement direction
774            let along_normal = offset.dot(normal) * normal;
775            let perpendicular = offset - along_normal;
776            let scaled_perpendicular = perpendicular * scale_factor;
777            vertex.position = centroid + along_normal + scaled_perpendicular;
778        }
779    }
780
781    Ok(result)
782}
783
784/// Calculate the overall fit error.
785fn calculate_fit_error(
786    mesh: &Mesh,
787    params: &FitParams,
788    regions: &HashMap<String, ControlRegion>,
789) -> f64 {
790    let mut total_error_sq = 0.0;
791    let mut count = 0;
792
793    // Error from landmark targets
794    for (name, target) in &params.landmark_targets {
795        if let Some(region) = regions.get(name) {
796            let current = match &region.definition {
797                RegionDefinition::Point(p) => *p,
798                RegionDefinition::Vertices(indices) if !indices.is_empty() => {
799                    let sum: Vector3<f64> = indices
800                        .iter()
801                        .filter_map(|&i| mesh.vertices.get(i as usize))
802                        .map(|v| v.position.coords)
803                        .sum();
804                    Point3::from(sum / indices.len() as f64)
805                }
806                _ => continue,
807            };
808
809            let error = (current - target).norm();
810            total_error_sq += error * error * region.weight;
811            count += 1;
812        }
813    }
814
815    if count > 0 {
816        (total_error_sq / count as f64).sqrt()
817    } else {
818        0.0
819    }
820}
821
822#[cfg(test)]
823mod tests {
824    use super::*;
825    use crate::Vertex;
826
827    fn create_test_mesh() -> Mesh {
828        let mut mesh = Mesh::new();
829        // Create a simple box-like mesh
830        mesh.vertices.push(Vertex::from_coords(0.0, 0.0, 0.0));
831        mesh.vertices.push(Vertex::from_coords(10.0, 0.0, 0.0));
832        mesh.vertices.push(Vertex::from_coords(10.0, 10.0, 0.0));
833        mesh.vertices.push(Vertex::from_coords(0.0, 10.0, 0.0));
834        mesh.vertices.push(Vertex::from_coords(5.0, 5.0, 10.0)); // Apex
835
836        mesh.faces.push([0, 1, 4]);
837        mesh.faces.push([1, 2, 4]);
838        mesh.faces.push([2, 3, 4]);
839        mesh.faces.push([3, 0, 4]);
840        mesh.faces.push([0, 3, 2]);
841        mesh.faces.push([0, 2, 1]);
842        mesh
843    }
844
845    #[test]
846    fn test_template_creation() {
847        let mesh = create_test_mesh();
848        let template = FitTemplate::new(mesh)
849            .with_control_region(ControlRegion::point("apex", Point3::new(5.0, 5.0, 10.0)))
850            .with_control_region(ControlRegion::vertices("base", vec![0, 1, 2, 3]));
851
852        assert_eq!(template.control_regions.len(), 2);
853        assert!(template.get_region("apex").is_some());
854        assert!(template.get_region("base").is_some());
855    }
856
857    #[test]
858    fn test_landmark_position() {
859        let mesh = create_test_mesh();
860        let template = FitTemplate::new(mesh)
861            .with_control_region(ControlRegion::point("apex", Point3::new(5.0, 5.0, 10.0)));
862
863        let pos = template.get_landmark_position("apex").unwrap();
864        assert!((pos.x - 5.0).abs() < 1e-10);
865        assert!((pos.y - 5.0).abs() < 1e-10);
866        assert!((pos.z - 10.0).abs() < 1e-10);
867    }
868
869    #[test]
870    fn test_landmark_fitting() {
871        let mesh = create_test_mesh();
872        let template = FitTemplate::new(mesh)
873            .with_control_region(ControlRegion::point("apex", Point3::new(5.0, 5.0, 10.0)))
874            .with_control_region(ControlRegion::point("base1", Point3::new(0.0, 0.0, 0.0)))
875            .with_control_region(ControlRegion::point("base2", Point3::new(10.0, 0.0, 0.0)))
876            .with_control_region(ControlRegion::point("base3", Point3::new(10.0, 10.0, 0.0)));
877
878        // Fit to move the apex higher - need multiple landmarks for numerical stability
879        let params = FitParams::default()
880            .with_landmark_target("apex", Point3::new(5.0, 5.0, 15.0))
881            .with_landmark_target("base1", Point3::new(0.0, 0.0, 0.0))
882            .with_landmark_target("base2", Point3::new(10.0, 0.0, 0.0))
883            .with_landmark_target("base3", Point3::new(10.0, 10.0, 0.0));
884
885        let result = template.fit(&params).unwrap();
886
887        // The apex vertex should have moved upward
888        let apex = result.mesh.vertices[4].position;
889        assert!(apex.z > 10.0, "Apex should have moved up: z={}", apex.z);
890    }
891
892    #[test]
893    fn test_region_vertex_indices_bounds() {
894        let mesh = create_test_mesh();
895
896        // Create a bounds region that includes the base vertices
897        let region = ControlRegion::bounds(
898            "lower_half",
899            Point3::new(-1.0, -1.0, -1.0),
900            Point3::new(11.0, 11.0, 5.0),
901        );
902
903        let indices = region.get_vertex_indices(&mesh);
904        assert_eq!(indices.len(), 4); // Should include vertices 0-3 (base)
905        assert!(indices.contains(&0));
906        assert!(indices.contains(&1));
907        assert!(indices.contains(&2));
908        assert!(indices.contains(&3));
909        assert!(!indices.contains(&4)); // Apex is at z=10, above the bounds
910    }
911
912    #[test]
913    fn test_region_vertex_indices_sphere() {
914        let mesh = create_test_mesh();
915
916        // Create a sphere region centered on the apex
917        let region = ControlRegion::sphere("near_apex", Point3::new(5.0, 5.0, 10.0), 3.0);
918
919        let indices = region.get_vertex_indices(&mesh);
920        assert!(indices.contains(&4)); // Apex should be included
921    }
922
923    #[test]
924    fn test_fit_to_scan() {
925        let template_mesh = create_test_mesh();
926        let template = FitTemplate::new(template_mesh.clone());
927
928        // Create a "scan" that's the same mesh translated
929        let mut scan = template_mesh.clone();
930        for vertex in &mut scan.vertices {
931            vertex.position.x += 5.0;
932        }
933
934        let result = template.fit_to_scan(&scan).unwrap();
935
936        // After fitting, the mesh should be closer to the scan
937        assert!(!result.stages.is_empty());
938    }
939
940    #[test]
941    fn test_fit_stages() {
942        let mesh = create_test_mesh();
943        let template = FitTemplate::new(mesh.clone())
944            .with_control_region(ControlRegion::point("apex", Point3::new(5.0, 5.0, 10.0)))
945            .with_control_region(ControlRegion::point("base1", Point3::new(0.0, 0.0, 0.0)))
946            .with_control_region(ControlRegion::point("base2", Point3::new(10.0, 0.0, 0.0)))
947            .with_control_region(ControlRegion::point("base3", Point3::new(10.0, 10.0, 0.0)));
948
949        let mut scan = mesh.clone();
950        for vertex in &mut scan.vertices {
951            vertex.position.x += 2.0;
952        }
953
954        // Need multiple well-separated landmarks for RBF stability
955        let params = FitParams::default()
956            .with_target_scan(scan)
957            .with_landmark_target("apex", Point3::new(7.0, 5.0, 12.0))
958            .with_landmark_target("base1", Point3::new(2.0, 0.0, 0.0))
959            .with_landmark_target("base2", Point3::new(12.0, 0.0, 0.0))
960            .with_landmark_target("base3", Point3::new(12.0, 10.0, 0.0));
961
962        let result = template.fit(&params).unwrap();
963
964        // Should have both rigid alignment and landmark deformation stages
965        assert!(result.stages.len() >= 2);
966    }
967
968    #[test]
969    fn test_empty_template_error() {
970        let mesh = Mesh::new();
971        let template = FitTemplate::new(mesh);
972
973        let params = FitParams::default();
974        assert!(matches!(
975            template.fit(&params),
976            Err(MeshError::EmptyMesh { .. })
977        ));
978    }
979
980    #[test]
981    fn test_measurement_exact() {
982        let m = Measurement::exact(100.0);
983        assert!((m.value - 100.0).abs() < 1e-10);
984        assert!(!m.is_minimum);
985    }
986
987    #[test]
988    fn test_measurement_minimum() {
989        let m = Measurement::minimum(50.0);
990        assert!(m.is_minimum);
991    }
992
993    #[test]
994    fn test_control_region_weights() {
995        let region = ControlRegion::point("test", Point3::new(0.0, 0.0, 0.0)).with_weight(2.5);
996        assert!((region.weight - 2.5).abs() < 1e-10);
997    }
998
999    #[test]
1000    fn test_preserved_region() {
1001        let region = ControlRegion::point("test", Point3::new(0.0, 0.0, 0.0)).preserved();
1002        assert!(region.preserve);
1003    }
1004
1005    #[test]
1006    fn test_cylinder_region() {
1007        let mesh = create_test_mesh();
1008
1009        // Create a cylinder that should include vertices near the center
1010        // The test mesh has: (0,0,0), (10,0,0), (10,10,0), (0,10,0), (5,5,10)
1011        // Center is approximately (5,5,5) - use a larger radius to catch vertices
1012        let region = ControlRegion::cylinder(
1013            "vertical",
1014            Point3::new(5.0, 5.0, 0.0),
1015            Point3::new(5.0, 5.0, 10.0),
1016            10.0, // Larger radius to include corner vertices
1017        );
1018
1019        let indices = region.get_vertex_indices(&mesh);
1020        // With radius 10, should include apex at (5,5,10) and potentially others
1021        assert!(
1022            !indices.is_empty(),
1023            "Should find at least some vertices in cylinder"
1024        );
1025    }
1026
1027    #[test]
1028    fn test_region_names() {
1029        let mesh = create_test_mesh();
1030        let template = FitTemplate::new(mesh)
1031            .with_control_region(ControlRegion::point("a", Point3::origin()))
1032            .with_control_region(ControlRegion::point("b", Point3::origin()))
1033            .with_control_region(ControlRegion::point("c", Point3::origin()));
1034
1035        let names = template.region_names();
1036        assert_eq!(names.len(), 3);
1037    }
1038
1039    #[test]
1040    fn test_fit_params_builder() {
1041        let params = FitParams::new()
1042            .with_landmark_target("heel", Point3::new(0.0, 0.0, 0.0))
1043            .with_landmark_target("toe", Point3::new(100.0, 0.0, 0.0))
1044            .with_smoothness(2.0)
1045            .with_registration_iterations(50);
1046
1047        assert_eq!(params.landmark_targets.len(), 2);
1048        assert!((params.smoothness - 2.0).abs() < 1e-10);
1049        assert_eq!(params.registration_iterations, 50);
1050    }
1051
1052    #[test]
1053    fn test_fit_result_acceptable() {
1054        let mesh = create_test_mesh();
1055        let template = FitTemplate::new(mesh);
1056
1057        let params = FitParams::default();
1058        let result = template.fit(&params).unwrap();
1059
1060        assert!(result.is_acceptable(1.0)); // With no constraints, error should be 0
1061    }
1062}