Skip to main content

oxihuman_morph/
body_scan_fit_icp.rs

1// Copyright (C) 2026 COOLJAPAN OU (Team KitaSan)
2// SPDX-License-Identifier: Apache-2.0
3
4#![allow(dead_code)]
5#![allow(clippy::too_many_arguments)]
6
7//! ICP registration, SVD math, PLY/OBJ import, and multi-stage scan-fitting
8//! pipeline.
9//!
10//! Provides:
11//! - [`PointCloud`] — f64-precision point cloud with PLY/OBJ import and
12//!   voxel downsampling.
13//! - [`IcpAligner`] / [`IcpResult`] — point-to-point and point-to-plane ICP.
14//! - [`ScanFitter`] / [`ScanFitConfig`] / [`PhotoFitResult`] — multi-stage
15//!   coarse-to-fine alignment + morph parameter fitting.
16
17// ===========================================================================
18// Photogrammetry fitting — PLY/OBJ import, ICP alignment, multi-stage fit
19// ===========================================================================
20
/// Point cloud from a 3-D scan (PLY/OBJ import), stored in `f64` precision.
///
/// When present, `normals` and `colors` are expected to hold exactly one entry per element
/// of `points`, in the same order.
#[derive(Debug, Clone, PartialEq)]
pub struct PointCloud {
    /// 3-D positions `[x, y, z]`.
    pub points: Vec<[f64; 3]>,
    /// Optional per-point normals (same length and order as `points`).
    pub normals: Option<Vec<[f64; 3]>>,
    /// Optional per-point RGB colours in `[0, 1]` (same length and order as `points`).
    pub colors: Option<Vec<[f64; 3]>>,
}
31
32impl PointCloud {
33    /// Parse PLY ASCII format.
34    pub fn from_ply_ascii(data: &str) -> anyhow::Result<Self> {
35        let mut lines = data.lines();
36        let first = lines.next().unwrap_or("");
37        if first.trim() != "ply" {
38            anyhow::bail!("not a PLY file: missing 'ply' magic");
39        }
40        let mut vertex_count: usize = 0;
41        let mut has_normals = false;
42        let mut has_colors = false;
43        let mut in_header = true;
44        let mut prop_order: Vec<String> = Vec::new();
45
46        while in_header {
47            let line = match lines.next() {
48                Some(l) => l.trim(),
49                None => anyhow::bail!("unexpected end of PLY header"),
50            };
51            if line == "end_header" {
52                in_header = false;
53            } else if line.starts_with("element vertex") {
54                let parts: Vec<&str> = line.split_whitespace().collect();
55                if parts.len() >= 3 {
56                    vertex_count = parts[2]
57                        .parse::<usize>()
58                        .map_err(|e| anyhow::anyhow!("bad vertex count: {}", e))?;
59                }
60            } else if line.starts_with("property") {
61                let parts: Vec<&str> = line.split_whitespace().collect();
62                if parts.len() >= 3 {
63                    let name = parts[2].to_lowercase();
64                    prop_order.push(name.clone());
65                    if name == "nx" || name == "ny" || name == "nz" {
66                        has_normals = true;
67                    }
68                    if name == "red" || name == "green" || name == "blue" {
69                        has_colors = true;
70                    }
71                }
72            }
73        }
74
75        let idx = |name: &str| -> Option<usize> { prop_order.iter().position(|s| s == name) };
76        let ix = idx("x");
77        let iy = idx("y");
78        let iz = idx("z");
79        let inx = idx("nx");
80        let iny = idx("ny");
81        let inz = idx("nz");
82        let ir = idx("red");
83        let ig = idx("green");
84        let ib = idx("blue");
85
86        let mut points = Vec::with_capacity(vertex_count);
87        let mut normals_vec: Vec<[f64; 3]> = if has_normals {
88            Vec::with_capacity(vertex_count)
89        } else {
90            Vec::new()
91        };
92        let mut colors_vec: Vec<[f64; 3]> = if has_colors {
93            Vec::with_capacity(vertex_count)
94        } else {
95            Vec::new()
96        };
97
98        for _ in 0..vertex_count {
99            let line = match lines.next() {
100                Some(l) => l.trim(),
101                None => break,
102            };
103            let vals: Vec<f64> = line
104                .split_whitespace()
105                .filter_map(|s| s.parse::<f64>().ok())
106                .collect();
107
108            let x = ix.and_then(|i| vals.get(i).copied()).unwrap_or(0.0);
109            let y = iy.and_then(|i| vals.get(i).copied()).unwrap_or(0.0);
110            let z = iz.and_then(|i| vals.get(i).copied()).unwrap_or(0.0);
111            points.push([x, y, z]);
112
113            if has_normals {
114                let nx = inx.and_then(|i| vals.get(i).copied()).unwrap_or(0.0);
115                let ny = iny.and_then(|i| vals.get(i).copied()).unwrap_or(0.0);
116                let nz = inz.and_then(|i| vals.get(i).copied()).unwrap_or(0.0);
117                normals_vec.push([nx, ny, nz]);
118            }
119            if has_colors {
120                let r = ir.and_then(|i| vals.get(i).copied()).unwrap_or(0.0);
121                let g = ig.and_then(|i| vals.get(i).copied()).unwrap_or(0.0);
122                let b = ib.and_then(|i| vals.get(i).copied()).unwrap_or(0.0);
123                let scale = if r > 1.0 || g > 1.0 || b > 1.0 {
124                    1.0 / 255.0
125                } else {
126                    1.0
127                };
128                colors_vec.push([r * scale, g * scale, b * scale]);
129            }
130        }
131
132        Ok(Self {
133            points,
134            normals: if has_normals { Some(normals_vec) } else { None },
135            colors: if has_colors { Some(colors_vec) } else { None },
136        })
137    }
138
139    /// Parse PLY binary little-endian format.
140    pub fn from_ply_binary_le(data: &[u8]) -> anyhow::Result<Self> {
141        let header_end = find_header_end(data)
142            .ok_or_else(|| anyhow::anyhow!("no end_header found in PLY binary"))?;
143        let header_str = std::str::from_utf8(&data[..header_end])
144            .map_err(|e| anyhow::anyhow!("invalid UTF-8 in PLY header: {}", e))?;
145
146        let mut vertex_count: usize = 0;
147        let mut props: Vec<(String, PlyPropType)> = Vec::new();
148
149        for line in header_str.lines() {
150            let line = line.trim();
151            if line.starts_with("element vertex") {
152                let parts: Vec<&str> = line.split_whitespace().collect();
153                if parts.len() >= 3 {
154                    vertex_count = parts[2]
155                        .parse::<usize>()
156                        .map_err(|e| anyhow::anyhow!("bad vertex count: {}", e))?;
157                }
158            } else if line.starts_with("property") {
159                let parts: Vec<&str> = line.split_whitespace().collect();
160                if parts.len() >= 3 {
161                    let ptype = match parts[1] {
162                        "float" | "float32" => PlyPropType::Float32,
163                        "double" | "float64" => PlyPropType::Float64,
164                        "uchar" | "uint8" => PlyPropType::Uint8,
165                        "int" | "int32" => PlyPropType::Int32,
166                        "short" | "int16" => PlyPropType::Int16,
167                        _ => PlyPropType::Float32,
168                    };
169                    props.push((parts[2].to_lowercase(), ptype));
170                }
171            }
172        }
173
174        let body_start = header_end + "end_header".len();
175        let body_start = data[body_start..]
176            .iter()
177            .position(|&b| b == b'\n')
178            .map(|p| body_start + p + 1)
179            .unwrap_or(body_start);
180
181        let stride: usize = props.iter().map(|(_, t)| t.byte_size()).sum();
182        let prop_idx = |name: &str| -> Option<(usize, PlyPropType)> {
183            let mut offset = 0usize;
184            for (n, t) in &props {
185                if n == name {
186                    return Some((offset, *t));
187                }
188                offset += t.byte_size();
189            }
190            None
191        };
192
193        let has_normals = prop_idx("nx").is_some();
194        let has_colors = prop_idx("red").is_some();
195
196        let mut points = Vec::with_capacity(vertex_count);
197        let mut normals_vec: Vec<[f64; 3]> = Vec::new();
198        let mut colors_vec: Vec<[f64; 3]> = Vec::new();
199        if has_normals {
200            normals_vec.reserve(vertex_count);
201        }
202        if has_colors {
203            colors_vec.reserve(vertex_count);
204        }
205
206        for i in 0..vertex_count {
207            let base = body_start + i * stride;
208            if base + stride > data.len() {
209                break;
210            }
211            let row = &data[base..base + stride];
212
213            let read_f64 = |name: &str| -> f64 {
214                if let Some((off, t)) = prop_idx(name) {
215                    if off + t.byte_size() <= row.len() {
216                        t.read_le_f64(&row[off..])
217                    } else {
218                        0.0
219                    }
220                } else {
221                    0.0
222                }
223            };
224
225            points.push([read_f64("x"), read_f64("y"), read_f64("z")]);
226
227            if has_normals {
228                normals_vec.push([read_f64("nx"), read_f64("ny"), read_f64("nz")]);
229            }
230            if has_colors {
231                let r = read_f64("red");
232                let g = read_f64("green");
233                let b = read_f64("blue");
234                let scale = if r > 1.0 || g > 1.0 || b > 1.0 {
235                    1.0 / 255.0
236                } else {
237                    1.0
238                };
239                colors_vec.push([r * scale, g * scale, b * scale]);
240            }
241        }
242
243        Ok(Self {
244            points,
245            normals: if has_normals { Some(normals_vec) } else { None },
246            colors: if has_colors { Some(colors_vec) } else { None },
247        })
248    }
249
250    /// Parse OBJ vertex data (vertices only, ignore faces).
251    pub fn from_obj_vertices(data: &str) -> anyhow::Result<Self> {
252        let mut points = Vec::new();
253        let mut normals_vec = Vec::new();
254
255        for line in data.lines() {
256            let line = line.trim();
257            if let Some(rest) = line.strip_prefix("vn ") {
258                let vals: Vec<f64> = rest
259                    .split_whitespace()
260                    .filter_map(|s| s.parse::<f64>().ok())
261                    .collect();
262                if vals.len() >= 3 {
263                    normals_vec.push([vals[0], vals[1], vals[2]]);
264                }
265            } else if let Some(rest) = line.strip_prefix("v ") {
266                let vals: Vec<f64> = rest
267                    .split_whitespace()
268                    .filter_map(|s| s.parse::<f64>().ok())
269                    .collect();
270                if vals.len() >= 3 {
271                    points.push([vals[0], vals[1], vals[2]]);
272                }
273            }
274        }
275
276        let normals = if normals_vec.len() == points.len() && !normals_vec.is_empty() {
277            Some(normals_vec)
278        } else {
279            None
280        };
281
282        Ok(Self {
283            points,
284            normals,
285            colors: None,
286        })
287    }
288
289    /// Downsample by voxel grid.
290    pub fn voxel_downsample(&self, voxel_size: f64) -> Self {
291        if self.points.is_empty() || voxel_size <= 0.0 {
292            return self.clone();
293        }
294        let inv = 1.0 / voxel_size;
295        let mut buckets: std::collections::HashMap<(i64, i64, i64), VoxelAccum> =
296            std::collections::HashMap::new();
297
298        let has_normals = self.normals.is_some();
299        let has_colors = self.colors.is_some();
300
301        for (idx, p) in self.points.iter().enumerate() {
302            let key = (
303                (p[0] * inv).floor() as i64,
304                (p[1] * inv).floor() as i64,
305                (p[2] * inv).floor() as i64,
306            );
307            let entry = buckets.entry(key).or_insert_with(|| VoxelAccum {
308                sum_pos: [0.0; 3],
309                sum_nrm: [0.0; 3],
310                sum_col: [0.0; 3],
311                count: 0,
312            });
313            entry.sum_pos[0] += p[0];
314            entry.sum_pos[1] += p[1];
315            entry.sum_pos[2] += p[2];
316            entry.count += 1;
317
318            if let Some(ref nrms) = self.normals {
319                if let Some(n) = nrms.get(idx) {
320                    entry.sum_nrm[0] += n[0];
321                    entry.sum_nrm[1] += n[1];
322                    entry.sum_nrm[2] += n[2];
323                }
324            }
325            if let Some(ref cols) = self.colors {
326                if let Some(c) = cols.get(idx) {
327                    entry.sum_col[0] += c[0];
328                    entry.sum_col[1] += c[1];
329                    entry.sum_col[2] += c[2];
330                }
331            }
332        }
333
334        let n_out = buckets.len();
335        let mut points = Vec::with_capacity(n_out);
336        let mut normals_out = if has_normals {
337            Vec::with_capacity(n_out)
338        } else {
339            Vec::new()
340        };
341        let mut colors_out = if has_colors {
342            Vec::with_capacity(n_out)
343        } else {
344            Vec::new()
345        };
346
347        for acc in buckets.values() {
348            let inv_n = 1.0 / (acc.count as f64);
349            points.push([
350                acc.sum_pos[0] * inv_n,
351                acc.sum_pos[1] * inv_n,
352                acc.sum_pos[2] * inv_n,
353            ]);
354            if has_normals {
355                let n = [
356                    acc.sum_nrm[0] * inv_n,
357                    acc.sum_nrm[1] * inv_n,
358                    acc.sum_nrm[2] * inv_n,
359                ];
360                let len = (n[0] * n[0] + n[1] * n[1] + n[2] * n[2]).sqrt().max(1e-12);
361                normals_out.push([n[0] / len, n[1] / len, n[2] / len]);
362            }
363            if has_colors {
364                colors_out.push([
365                    acc.sum_col[0] * inv_n,
366                    acc.sum_col[1] * inv_n,
367                    acc.sum_col[2] * inv_n,
368                ]);
369            }
370        }
371
372        Self {
373            points,
374            normals: if has_normals { Some(normals_out) } else { None },
375            colors: if has_colors { Some(colors_out) } else { None },
376        }
377    }
378
379    /// Remove statistical outliers.
380    pub fn remove_outliers(&self, k_neighbors: usize, std_ratio: f64) -> Self {
381        if self.points.len() <= k_neighbors + 1 {
382            return self.clone();
383        }
384        let n = self.points.len();
385        let k = k_neighbors.min(n - 1).max(1);
386
387        let mean_dists: Vec<f64> = self
388            .points
389            .iter()
390            .enumerate()
391            .map(|(i, p)| {
392                let mut dists: Vec<f64> = self
393                    .points
394                    .iter()
395                    .enumerate()
396                    .filter_map(|(j, q)| if j == i { None } else { Some(dist3(p, q)) })
397                    .collect();
398                dists.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
399                dists.iter().take(k).sum::<f64>() / k as f64
400            })
401            .collect();
402
403        let global_mean = mean_dists.iter().sum::<f64>() / n as f64;
404        let variance = mean_dists
405            .iter()
406            .map(|d| (d - global_mean).powi(2))
407            .sum::<f64>()
408            / n as f64;
409        let global_std = variance.sqrt();
410        let threshold = global_mean + std_ratio * global_std;
411
412        let keep: Vec<usize> = mean_dists
413            .iter()
414            .enumerate()
415            .filter(|(_, d)| **d <= threshold)
416            .map(|(i, _)| i)
417            .collect();
418
419        let points: Vec<[f64; 3]> = keep.iter().map(|&i| self.points[i]).collect();
420        let normals = self
421            .normals
422            .as_ref()
423            .map(|nv| keep.iter().map(|&i| nv[i]).collect());
424        let colors = self
425            .colors
426            .as_ref()
427            .map(|cv| keep.iter().map(|&i| cv[i]).collect());
428
429        Self {
430            points,
431            normals,
432            colors,
433        }
434    }
435
436    /// Compute the centroid (f64).
437    fn centroid_f64(&self) -> [f64; 3] {
438        if self.points.is_empty() {
439            return [0.0; 3];
440        }
441        let n = self.points.len() as f64;
442        let mut s = [0.0_f64; 3];
443        for p in &self.points {
444            s[0] += p[0];
445            s[1] += p[1];
446            s[2] += p[2];
447        }
448        [s[0] / n, s[1] / n, s[2] / n]
449    }
450}
451
452// ---------------------------------------------------------------------------
453// PLY binary helpers
454// ---------------------------------------------------------------------------
455
/// Locate the `end_header` keyword that terminates a PLY header.
///
/// Returns the byte offset of the keyword. The keyword only counts when it begins its line
/// (after optional whitespace), so occurrences inside other header lines such as
/// `comment ... end_header` are ignored. Returns `None` when no such line exists.
fn find_header_end(data: &[u8]) -> Option<usize> {
    const NEEDLE: &[u8] = b"end_header";
    data.windows(NEEDLE.len())
        .enumerate()
        .filter(|&(_, window)| window == NEEDLE)
        .map(|(pos, _)| pos)
        .find(|&pos| {
            // Everything between the previous newline and the match must be whitespace.
            let line_start = data[..pos]
                .iter()
                .rposition(|&b| b == b'\n')
                .map_or(0, |p| p + 1);
            data[line_start..pos].iter().all(|b| b.is_ascii_whitespace())
        })
}
460
/// Scalar property types understood by the binary PLY reader.
#[derive(Debug, Clone, Copy)]
enum PlyPropType {
    /// 4-byte IEEE-754 single precision float.
    Float32,
    /// 8-byte IEEE-754 double precision float.
    Float64,
    /// 1-byte unsigned integer.
    Uint8,
    /// 4-byte signed integer.
    Int32,
    /// 2-byte signed integer.
    Int16,
}

impl PlyPropType {
    /// Number of bytes one value of this type occupies in a binary record.
    fn byte_size(self) -> usize {
        match self {
            Self::Uint8 => 1,
            Self::Int16 => 2,
            Self::Float32 | Self::Int32 => 4,
            Self::Float64 => 8,
        }
    }

    /// Decode one little-endian value from the front of `buf` and widen it to `f64`.
    ///
    /// Returns `0.0` when `buf` holds fewer bytes than [`PlyPropType::byte_size`].
    fn read_le_f64(self, buf: &[u8]) -> f64 {
        /// The first `N` bytes of `buf` as an array, or `None` if `buf` is too short.
        fn head<const N: usize>(buf: &[u8]) -> Option<[u8; N]> {
            let mut out = [0u8; N];
            out.copy_from_slice(buf.get(..N)?);
            Some(out)
        }

        let decoded = match self {
            Self::Float32 => head(buf).map(|b| f64::from(f32::from_le_bytes(b))),
            Self::Float64 => head(buf).map(f64::from_le_bytes),
            Self::Uint8 => buf.first().map(|&b| f64::from(b)),
            Self::Int32 => head(buf).map(|b| f64::from(i32::from_le_bytes(b))),
            Self::Int16 => head(buf).map(|b| f64::from(i16::from_le_bytes(b))),
        };
        decoded.unwrap_or(0.0)
    }
}
523
/// Running sums for one voxel cell, used by `PointCloud::voxel_downsample`.
struct VoxelAccum {
    /// Sum of the positions of the points that fell into this voxel.
    sum_pos: [f64; 3],
    /// Sum of their normals (stays zero when the cloud has no normals).
    sum_nrm: [f64; 3],
    /// Sum of their colours (stays zero when the cloud has no colours).
    sum_col: [f64; 3],
    /// Number of points accumulated so far.
    count: usize,
}
530
531// ---------------------------------------------------------------------------
532// 3-D math helpers (f64)
533// ---------------------------------------------------------------------------
534
/// Euclidean distance between two 3-D points.
fn dist3(a: &[f64; 3], b: &[f64; 3]) -> f64 {
    a.iter()
        .zip(b.iter())
        .map(|(p, q)| (p - q) * (p - q))
        .fold(0.0, |acc, sq| acc + sq)
        .sqrt()
}
541
/// Squared Euclidean distance between two 3-D points (no square root taken).
fn dist3_sq(a: &[f64; 3], b: &[f64; 3]) -> f64 {
    let mut total = 0.0;
    for (p, q) in a.iter().zip(b) {
        let delta = p - q;
        total += delta * delta;
    }
    total
}
548
/// Component-wise difference `a - b`.
fn vec3_sub(a: &[f64; 3], b: &[f64; 3]) -> [f64; 3] {
    let mut out = *a;
    for (o, s) in out.iter_mut().zip(b) {
        *o -= s;
    }
    out
}
552
/// Dot product of two 3-vectors.
fn vec3_dot(a: &[f64; 3], b: &[f64; 3]) -> f64 {
    let [ax, ay, az] = *a;
    let [bx, by, bz] = *b;
    ax * bx + ay * by + az * bz
}
556
/// The 3x3 identity matrix.
fn mat3_identity() -> [[f64; 3]; 3] {
    let mut m = [[0.0; 3]; 3];
    for (i, row) in m.iter_mut().enumerate() {
        row[i] = 1.0;
    }
    m
}
560
/// Matrix product `a * b`, with each matrix stored as an array of rows.
fn mat3_mul(a: &[[f64; 3]; 3], b: &[[f64; 3]; 3]) -> [[f64; 3]; 3] {
    let mut out = [[0.0_f64; 3]; 3];
    for (out_row, a_row) in out.iter_mut().zip(a) {
        for (j, cell) in out_row.iter_mut().enumerate() {
            *cell = a_row[0] * b[0][j] + a_row[1] * b[1][j] + a_row[2] * b[2][j];
        }
    }
    out
}
570
/// Transpose of a 3x3 matrix.
fn mat3_transpose(m: &[[f64; 3]; 3]) -> [[f64; 3]; 3] {
    let mut t = *m;
    for i in 0..3 {
        for j in (i + 1)..3 {
            t[i][j] = m[j][i];
            t[j][i] = m[i][j];
        }
    }
    t
}
578
/// Determinant of a 3x3 matrix by cofactor expansion along the first row.
fn mat3_det(m: &[[f64; 3]; 3]) -> f64 {
    let minor0 = m[1][1] * m[2][2] - m[1][2] * m[2][1];
    let minor1 = m[1][0] * m[2][2] - m[1][2] * m[2][0];
    let minor2 = m[1][0] * m[2][1] - m[1][1] * m[2][0];
    m[0][0] * minor0 - m[0][1] * minor1 + m[0][2] * minor2
}
584
/// Matrix-vector product `m * v`, where `m` is stored as an array of rows.
fn mat3_vec(m: &[[f64; 3]; 3], v: &[f64; 3]) -> [f64; 3] {
    let row_dot = |r: &[f64; 3]| r[0] * v[0] + r[1] * v[1] + r[2] * v[2];
    [row_dot(&m[0]), row_dot(&m[1]), row_dot(&m[2])]
}
592
/// Arithmetic mean of a set of points; `[0, 0, 0]` for an empty slice.
fn centroid_of(pts: &[[f64; 3]]) -> [f64; 3] {
    if pts.is_empty() {
        return [0.0; 3];
    }
    let sum = pts.iter().fold([0.0_f64; 3], |acc, p| {
        [acc[0] + p[0], acc[1] + p[1], acc[2] + p[2]]
    });
    let count = pts.len() as f64;
    [sum[0] / count, sum[1] / count, sum[2] / count]
}
606
607// ---------------------------------------------------------------------------
608// 3x3 SVD via Jacobi rotations (pure Rust)
609// ---------------------------------------------------------------------------
610
/// Result of `svd3`: the factors of `M = U · diag(s) · Vᵀ`.
struct Svd3 {
    /// Left singular vectors, one per column.
    u: [[f64; 3]; 3],
    /// Singular values, largest first; the last one may carry a negative sign.
    s: [f64; 3],
    /// Right singular vectors, transposed (one per row).
    vt: [[f64; 3]; 3],
}
616
/// Cosine and sine of the Jacobi rotation that annihilates `a[p][q]` of a symmetric matrix
/// (as applied by `apply_jacobi_sym`).
///
/// Returns the identity rotation `(1, 0)` when `|a[p][q]| < 1e-15`.
fn jacobi_rotation_sym(a: &[[f64; 3]; 3], p: usize, q: usize) -> (f64, f64) {
    let off = a[p][q];
    if off.abs() < 1e-15 {
        return (1.0, 0.0);
    }
    // theta = cot(2·phi) for the rotation angle phi that zeroes the off-diagonal entry.
    let theta = (a[q][q] - a[p][p]) / (2.0 * off);
    let tan_phi = if theta.abs() > 1e15 {
        // theta² would overflow or swamp the 1; 1/(2·theta) is the asymptotic root.
        1.0 / (2.0 * theta)
    } else {
        // Smaller-magnitude root of t² + 2·theta·t − 1 = 0, in cancellation-free form.
        let signed_one = if theta >= 0.0 { 1.0 } else { -1.0 };
        signed_one / (theta.abs() + (1.0 + theta * theta).sqrt())
    };
    let cos_phi = 1.0 / (1.0 + tan_phi * tan_phi).sqrt();
    (cos_phi, tan_phi * cos_phi)
}
633
/// Apply the Jacobi rotation `(c, s)` in the `(p, q)` plane to `a` in place as `R·a·Rᵀ`.
///
/// In that plane, `R` has rows `(c, -s)` and `(s, c)`.
fn apply_jacobi_sym(a: &mut [[f64; 3]; 3], p: usize, q: usize, c: f64, s: f64) {
    // Left multiplication by R: mix rows p and q.
    let (row_p, row_q) = (a[p], a[q]);
    for k in 0..3 {
        a[p][k] = c * row_p[k] - s * row_q[k];
        a[q][k] = s * row_p[k] + c * row_q[k];
    }
    // Right multiplication by Rᵀ: mix columns p and q of the row-rotated matrix.
    for row in a.iter_mut() {
        let (vp, vq) = (row[p], row[q]);
        row[p] = c * vp - s * vq;
        row[q] = s * vp + c * vq;
    }
}
647
/// Accumulate the Jacobi rotation `(c, s)` into the eigenvector matrix: `v ← v·Rᵀ`, which
/// mixes columns `p` and `q` of every row.
fn apply_jacobi_vec(v: &mut [[f64; 3]; 3], p: usize, q: usize, c: f64, s: f64) {
    for i in 0..3 {
        let (left, right) = (v[i][p], v[i][q]);
        v[i][p] = c * left - s * right;
        v[i][q] = s * left + c * right;
    }
}
656
657fn sym_eigen3(m: &[[f64; 3]; 3]) -> ([f64; 3], [[f64; 3]; 3]) {
658    let mut a = *m;
659    let mut v = mat3_identity();
660    let max_iter = 100;
661
662    for _ in 0..max_iter {
663        let pairs: [(usize, usize); 3] = [(0, 1), (0, 2), (1, 2)];
664        let mut max_off = 0.0_f64;
665        for &(p, q) in &pairs {
666            let val = a[p][q].abs();
667            if val > max_off {
668                max_off = val;
669            }
670        }
671        if max_off < 1e-14 {
672            break;
673        }
674        for &(p, q) in &pairs {
675            if a[p][q].abs() < 1e-15 {
676                continue;
677            }
678            let (c, s) = jacobi_rotation_sym(&a, p, q);
679            apply_jacobi_sym(&mut a, p, q, c, s);
680            apply_jacobi_vec(&mut v, p, q, c, s);
681        }
682    }
683
684    ([a[0][0], a[1][1], a[2][2]], v)
685}
686
687fn svd3(m: &[[f64; 3]; 3]) -> Svd3 {
688    let mt = mat3_transpose(m);
689    let ata = mat3_mul(&mt, m);
690    let (eigenvalues, v_cols) = sym_eigen3(&ata);
691
692    let mut s = [0.0_f64; 3];
693    for i in 0..3 {
694        s[i] = eigenvalues[i].max(0.0).sqrt();
695    }
696
697    let mut order = [0usize, 1, 2];
698    if s[order[1]] > s[order[0]] {
699        order.swap(0, 1);
700    }
701    if s[order[2]] > s[order[0]] {
702        order.swap(0, 2);
703    }
704    if s[order[2]] > s[order[1]] {
705        order.swap(1, 2);
706    }
707
708    let s_sorted = [s[order[0]], s[order[1]], s[order[2]]];
709
710    let mut v_mat = [[0.0_f64; 3]; 3];
711    for i in 0..3 {
712        for j in 0..3 {
713            v_mat[i][j] = v_cols[i][order[j]];
714        }
715    }
716
717    let mv = mat3_mul(m, &v_mat);
718    let mut u_mat = [[0.0_f64; 3]; 3];
719    for j in 0..3 {
720        let inv_s = if s_sorted[j] > 1e-12 {
721            1.0 / s_sorted[j]
722        } else {
723            0.0
724        };
725        for i in 0..3 {
726            u_mat[i][j] = mv[i][j] * inv_s;
727        }
728    }
729
730    let det_u = mat3_det(&u_mat);
731    let det_v = mat3_det(&v_mat);
732    let mut s_final = s_sorted;
733
734    if det_u < 0.0 {
735        for row in u_mat.iter_mut() {
736            row[2] = -row[2];
737        }
738        s_final[2] = -s_final[2];
739    }
740    if det_v < 0.0 {
741        for row in v_mat.iter_mut() {
742            row[2] = -row[2];
743        }
744        s_final[2] = -s_final[2];
745    }
746
747    Svd3 {
748        u: u_mat,
749        s: s_final,
750        vt: mat3_transpose(&v_mat),
751    }
752}
753
754// ---------------------------------------------------------------------------
755// ICP (Iterative Closest Point)
756// ---------------------------------------------------------------------------
757
/// Iterative Closest Point (ICP) registration.
///
/// Build one with [`IcpAligner::new`], optionally restrict pairings with
/// [`IcpAligner::with_max_correspondence_distance`], then call one of the `align_*` methods.
#[derive(Debug, Clone)]
pub struct IcpAligner {
    /// Upper bound on the number of ICP iterations.
    pub max_iterations: usize,
    /// Iteration stops once the RMSE changes by less than this between iterations.
    pub convergence_threshold: f64,
    /// Pairs farther apart than this are rejected (`new` sets it to `f64::MAX`).
    pub max_correspondence_distance: f64,
}
768
/// Outcome of an ICP run: the accumulated similarity transform plus quality metrics.
///
/// Mapping a point `p` of the original source through `scale * rotation * p + translation`
/// gives its final aligned position.
#[derive(Debug, Clone)]
pub struct IcpResult {
    /// Accumulated 3x3 rotation matrix.
    pub rotation: [[f64; 3]; 3],
    /// Accumulated translation vector.
    pub translation: [f64; 3],
    /// Accumulated uniform scale factor (always `1.0` for point-to-plane ICP).
    pub scale: f64,
    /// Fraction of source points with a valid correspondence after alignment.
    pub fitness: f64,
    /// Root-mean-square error over the final corresponding pairs.
    pub rmse: f64,
    /// Number of ICP iterations executed.
    pub iterations: usize,
}
785
786impl IcpAligner {
787    /// Create a new ICP aligner.
788    pub fn new(max_iterations: usize, convergence_threshold: f64) -> Self {
789        Self {
790            max_iterations,
791            convergence_threshold,
792            max_correspondence_distance: f64::MAX,
793        }
794    }
795
796    /// Set the maximum correspondence distance for rejecting outlier pairs.
797    pub fn with_max_correspondence_distance(mut self, d: f64) -> Self {
798        self.max_correspondence_distance = d;
799        self
800    }
801
802    /// Align source point cloud to target using point-to-point ICP.
803    pub fn align_point_to_point(
804        &self,
805        source: &[[f64; 3]],
806        target: &[[f64; 3]],
807    ) -> anyhow::Result<IcpResult> {
808        if source.is_empty() || target.is_empty() {
809            anyhow::bail!("ICP requires non-empty point sets");
810        }
811
812        let mut src: Vec<[f64; 3]> = source.to_vec();
813        let mut cumulative_rot = mat3_identity();
814        let mut cumulative_trans = [0.0_f64; 3];
815        let mut cumulative_scale = 1.0_f64;
816        let mut prev_rmse = f64::MAX;
817        let mut iters = 0usize;
818
819        for _ in 0..self.max_iterations {
820            iters += 1;
821
822            let (src_matched, tgt_matched) =
823                find_correspondences(&src, target, self.max_correspondence_distance);
824
825            if src_matched.len() < 3 {
826                break;
827            }
828
829            let (rot, trans, scale) = compute_rigid_transform(&src_matched, &tgt_matched);
830
831            for p in &mut src {
832                let rotated = mat3_vec(&rot, p);
833                p[0] = rotated[0] * scale + trans[0];
834                p[1] = rotated[1] * scale + trans[1];
835                p[2] = rotated[2] * scale + trans[2];
836            }
837
838            let new_rot = mat3_mul(&rot, &cumulative_rot);
839            let ct_rotated = mat3_vec(&rot, &cumulative_trans);
840            let new_trans = [
841                scale * ct_rotated[0] + trans[0],
842                scale * ct_rotated[1] + trans[1],
843                scale * ct_rotated[2] + trans[2],
844            ];
845            let new_scale = scale * cumulative_scale;
846
847            cumulative_rot = new_rot;
848            cumulative_trans = new_trans;
849            cumulative_scale = new_scale;
850
851            let rmse = compute_rmse(&src_matched, &tgt_matched);
852
853            if (prev_rmse - rmse).abs() < self.convergence_threshold {
854                break;
855            }
856            prev_rmse = rmse;
857        }
858
859        let (final_src, final_tgt) =
860            find_correspondences(&src, target, self.max_correspondence_distance);
861        let fitness = final_src.len() as f64 / source.len().max(1) as f64;
862        let rmse = if final_src.is_empty() {
863            f64::MAX
864        } else {
865            compute_rmse(&final_src, &final_tgt)
866        };
867
868        Ok(IcpResult {
869            rotation: cumulative_rot,
870            translation: cumulative_trans,
871            scale: cumulative_scale,
872            fitness,
873            rmse,
874            iterations: iters,
875        })
876    }
877
878    /// Align using point-to-plane ICP (requires normals on target).
879    pub fn align_point_to_plane(
880        &self,
881        source: &[[f64; 3]],
882        target: &[[f64; 3]],
883        target_normals: &[[f64; 3]],
884    ) -> anyhow::Result<IcpResult> {
885        if source.is_empty() || target.is_empty() {
886            anyhow::bail!("ICP requires non-empty point sets");
887        }
888        if target.len() != target_normals.len() {
889            anyhow::bail!("target and target_normals must have the same length");
890        }
891
892        let mut src: Vec<[f64; 3]> = source.to_vec();
893        let mut cumulative_rot = mat3_identity();
894        let mut cumulative_trans = [0.0_f64; 3];
895        let mut prev_rmse = f64::MAX;
896        let mut iters = 0usize;
897
898        for _ in 0..self.max_iterations {
899            iters += 1;
900
901            let (src_idx, tgt_idx) =
902                find_correspondence_indices(&src, target, self.max_correspondence_distance);
903
904            if src_idx.len() < 6 {
905                break;
906            }
907
908            let (delta_rot_vec, delta_trans) =
909                solve_point_to_plane_step(&src, &src_idx, target, target_normals, &tgt_idx);
910
911            let rot_inc = small_angle_rotation(&delta_rot_vec);
912
913            for p in &mut src {
914                let rotated = mat3_vec(&rot_inc, p);
915                p[0] = rotated[0] + delta_trans[0];
916                p[1] = rotated[1] + delta_trans[1];
917                p[2] = rotated[2] + delta_trans[2];
918            }
919
920            let new_rot = mat3_mul(&rot_inc, &cumulative_rot);
921            let ct_rotated = mat3_vec(&rot_inc, &cumulative_trans);
922            cumulative_rot = new_rot;
923            cumulative_trans = [
924                ct_rotated[0] + delta_trans[0],
925                ct_rotated[1] + delta_trans[1],
926                ct_rotated[2] + delta_trans[2],
927            ];
928
929            let matched_src: Vec<[f64; 3]> = src_idx.iter().map(|&i| src[i]).collect();
930            let matched_tgt: Vec<[f64; 3]> = tgt_idx.iter().map(|&i| target[i]).collect();
931            let rmse = compute_rmse(&matched_src, &matched_tgt);
932
933            if (prev_rmse - rmse).abs() < self.convergence_threshold {
934                break;
935            }
936            prev_rmse = rmse;
937        }
938
939        let (final_src_idx, final_tgt_idx) =
940            find_correspondence_indices(&src, target, self.max_correspondence_distance);
941        let fitness = final_src_idx.len() as f64 / source.len().max(1) as f64;
942        let rmse = if final_src_idx.is_empty() {
943            f64::MAX
944        } else {
945            let ms: Vec<[f64; 3]> = final_src_idx.iter().map(|&i| src[i]).collect();
946            let mt: Vec<[f64; 3]> = final_tgt_idx.iter().map(|&i| target[i]).collect();
947            compute_rmse(&ms, &mt)
948        };
949
950        Ok(IcpResult {
951            rotation: cumulative_rot,
952            translation: cumulative_trans,
953            scale: 1.0,
954            fitness,
955            rmse,
956            iterations: iters,
957        })
958    }
959
960    /// Apply a rigid transform (rotation, translation, scale) to points in-place.
961    pub fn transform_points(
962        points: &mut [[f64; 3]],
963        rotation: &[[f64; 3]; 3],
964        translation: &[f64; 3],
965        scale: f64,
966    ) {
967        for p in points.iter_mut() {
968            let r = mat3_vec(rotation, p);
969            p[0] = r[0] * scale + translation[0];
970            p[1] = r[1] * scale + translation[1];
971            p[2] = r[2] * scale + translation[2];
972        }
973    }
974}
975
976// ---------------------------------------------------------------------------
977// ICP helper functions
978// ---------------------------------------------------------------------------
979
980fn find_correspondences(
981    source: &[[f64; 3]],
982    target: &[[f64; 3]],
983    max_dist: f64,
984) -> (Vec<[f64; 3]>, Vec<[f64; 3]>) {
985    let max_dist_sq = max_dist * max_dist;
986    let mut src_out = Vec::new();
987    let mut tgt_out = Vec::new();
988
989    for sp in source {
990        let mut best_dist_sq = f64::MAX;
991        let mut best_pt = [0.0_f64; 3];
992        for tp in target {
993            let d2 = dist3_sq(sp, tp);
994            if d2 < best_dist_sq {
995                best_dist_sq = d2;
996                best_pt = *tp;
997            }
998        }
999        if best_dist_sq <= max_dist_sq {
1000            src_out.push(*sp);
1001            tgt_out.push(best_pt);
1002        }
1003    }
1004
1005    (src_out, tgt_out)
1006}
1007
1008fn find_correspondence_indices(
1009    source: &[[f64; 3]],
1010    target: &[[f64; 3]],
1011    max_dist: f64,
1012) -> (Vec<usize>, Vec<usize>) {
1013    let max_dist_sq = max_dist * max_dist;
1014    let mut src_idx = Vec::new();
1015    let mut tgt_idx = Vec::new();
1016
1017    for (si, sp) in source.iter().enumerate() {
1018        let mut best_dist_sq = f64::MAX;
1019        let mut best_idx = 0usize;
1020        for (ti, tp) in target.iter().enumerate() {
1021            let d2 = dist3_sq(sp, tp);
1022            if d2 < best_dist_sq {
1023                best_dist_sq = d2;
1024                best_idx = ti;
1025            }
1026        }
1027        if best_dist_sq <= max_dist_sq {
1028            src_idx.push(si);
1029            tgt_idx.push(best_idx);
1030        }
1031    }
1032
1033    (src_idx, tgt_idx)
1034}
1035
1036fn compute_rigid_transform(
1037    source: &[[f64; 3]],
1038    target: &[[f64; 3]],
1039) -> ([[f64; 3]; 3], [f64; 3], f64) {
1040    let c_src = centroid_of(source);
1041    let c_tgt = centroid_of(target);
1042
1043    let src_c: Vec<[f64; 3]> = source.iter().map(|p| vec3_sub(p, &c_src)).collect();
1044    let tgt_c: Vec<[f64; 3]> = target.iter().map(|p| vec3_sub(p, &c_tgt)).collect();
1045
1046    let mut h = [[0.0_f64; 3]; 3];
1047    for (s, t) in src_c.iter().zip(tgt_c.iter()) {
1048        for i in 0..3 {
1049            for j in 0..3 {
1050                h[i][j] += s[i] * t[j];
1051            }
1052        }
1053    }
1054
1055    let svd = svd3(&h);
1056    let ut = mat3_transpose(&svd.u);
1057    let vt_t = mat3_transpose(&svd.vt);
1058    let mut rot = mat3_mul(&vt_t, &ut);
1059
1060    if mat3_det(&rot) < 0.0 {
1061        let mut v_fixed = vt_t;
1062        for row in v_fixed.iter_mut() {
1063            row[2] = -row[2];
1064        }
1065        rot = mat3_mul(&v_fixed, &ut);
1066    }
1067
1068    let src_var: f64 = src_c.iter().map(|p| vec3_dot(p, p)).sum();
1069    let scale = if src_var > 1e-12 {
1070        let tgt_var: f64 = tgt_c.iter().map(|p| vec3_dot(p, p)).sum();
1071        (tgt_var / src_var).sqrt()
1072    } else {
1073        1.0
1074    };
1075
1076    let r_csrc = mat3_vec(&rot, &c_src);
1077    let trans = [
1078        c_tgt[0] - scale * r_csrc[0],
1079        c_tgt[1] - scale * r_csrc[1],
1080        c_tgt[2] - scale * r_csrc[2],
1081    ];
1082
1083    (rot, trans, scale)
1084}
1085
1086fn compute_rmse(a: &[[f64; 3]], b: &[[f64; 3]]) -> f64 {
1087    if a.is_empty() {
1088        return 0.0;
1089    }
1090    let sum: f64 = a.iter().zip(b.iter()).map(|(p, q)| dist3_sq(p, q)).sum();
1091    (sum / a.len() as f64).sqrt()
1092}
1093
1094fn small_angle_rotation(w: &[f64; 3]) -> [[f64; 3]; 3] {
1095    let (a, b, g) = (w[0], w[1], w[2]);
1096    let theta = (a * a + b * b + g * g).sqrt();
1097    if theta < 1e-12 {
1098        return mat3_identity();
1099    }
1100    let k = [a / theta, b / theta, g / theta];
1101    let ct = theta.cos();
1102    let st = theta.sin();
1103    let omc = 1.0 - ct;
1104
1105    [
1106        [
1107            ct + k[0] * k[0] * omc,
1108            k[0] * k[1] * omc - k[2] * st,
1109            k[0] * k[2] * omc + k[1] * st,
1110        ],
1111        [
1112            k[1] * k[0] * omc + k[2] * st,
1113            ct + k[1] * k[1] * omc,
1114            k[1] * k[2] * omc - k[0] * st,
1115        ],
1116        [
1117            k[2] * k[0] * omc - k[1] * st,
1118            k[2] * k[1] * omc + k[0] * st,
1119            ct + k[2] * k[2] * omc,
1120        ],
1121    ]
1122}
1123
1124fn solve_point_to_plane_step(
1125    source: &[[f64; 3]],
1126    src_idx: &[usize],
1127    target: &[[f64; 3]],
1128    target_normals: &[[f64; 3]],
1129    tgt_idx: &[usize],
1130) -> ([f64; 3], [f64; 3]) {
1131    let mut ata = [[0.0_f64; 6]; 6];
1132    let mut atb = [0.0_f64; 6];
1133
1134    for (&si, &ti) in src_idx.iter().zip(tgt_idx.iter()) {
1135        let s = &source[si];
1136        let t = &target[ti];
1137        let n = &target_normals[ti];
1138
1139        let d = vec3_sub(s, t);
1140        let r = vec3_dot(n, &d);
1141
1142        let cn = [
1143            s[1] * n[2] - s[2] * n[1],
1144            s[2] * n[0] - s[0] * n[2],
1145            s[0] * n[1] - s[1] * n[0],
1146        ];
1147        let row = [cn[0], cn[1], cn[2], n[0], n[1], n[2]];
1148
1149        for i in 0..6 {
1150            for j in 0..6 {
1151                ata[i][j] += row[i] * row[j];
1152            }
1153            atb[i] += row[i] * (-r);
1154        }
1155    }
1156
1157    let x = solve_6x6(&ata, &atb);
1158    ([x[0], x[1], x[2]], [x[3], x[4], x[5]])
1159}
1160
/// Solve the 6x6 linear system `a · x = b` by Gaussian elimination with partial
/// pivoting.
///
/// A column whose largest remaining pivot candidate has magnitude below 1e-15
/// is skipped, and the corresponding unknown is set to 0.0 during back
/// substitution. Singular or rank-deficient systems therefore return a partial
/// solution instead of failing.
#[allow(clippy::needless_range_loop)]
fn solve_6x6(a: &[[f64; 6]; 6], b: &[f64; 6]) -> [f64; 6] {
    // Augmented matrix [a | b].
    let mut aug = [[0.0_f64; 7]; 6];
    for ((dst, src), &rhs) in aug.iter_mut().zip(a).zip(b) {
        dst[..6].copy_from_slice(src);
        dst[6] = rhs;
    }

    // Forward elimination.
    for col in 0..6 {
        // Largest |entry| at or below the diagonal; strict `>` keeps the first maximum.
        let (pivot_row, pivot_mag) = ((col + 1)..6).fold(
            (col, aug[col][col].abs()),
            |(best_row, best_mag), row| {
                let mag = aug[row][col].abs();
                if mag > best_mag {
                    (row, mag)
                } else {
                    (best_row, best_mag)
                }
            },
        );
        if pivot_mag < 1e-15 {
            continue;
        }
        aug.swap(col, pivot_row);

        let (upper, lower) = aug.split_at_mut(col + 1);
        let pivot = &upper[col];
        for row in lower.iter_mut() {
            let factor = row[col] / pivot[col];
            for (cell, &p) in row[col..].iter_mut().zip(&pivot[col..]) {
                *cell -= factor * p;
            }
        }
    }

    // Back substitution; unknowns with a (near-)zero pivot stay at 0.0.
    let mut x = [0.0_f64; 6];
    for col in (0..6).rev() {
        let diag = aug[col][col];
        if diag.abs() < 1e-15 {
            continue;
        }
        let row = &aug[col];
        let rhs = ((col + 1)..6).fold(row[6], |acc, j| acc - row[j] * x[j]);
        x[col] = rhs / diag;
    }
    x
}
1210
1211// ===========================================================================
1212// Multi-stage body scan fitting pipeline
1213// ===========================================================================
1214
/// Tunable parameters for the multi-stage scan fitting pipeline (`ScanFitter`).
#[derive(Debug, Clone)]
pub struct ScanFitConfig {
    /// ICP iteration cap, applied separately to the coarse and the fine stage.
    pub icp_iterations: usize,
    /// Number of gradient-descent steps used to fit the morph weights.
    pub morph_iterations: usize,
    /// Voxel edge length (metres) used to thin both scan and template before coarse ICP.
    pub coarse_voxel_size: f64,
    /// Voxel edge length (metres) used to thin the scan before fine ICP; `<= 0` disables it.
    pub fine_voxel_size: f64,
    /// Weight of the L2 penalty on the morph weights.
    pub regularization: f64,
}

impl Default for ScanFitConfig {
    /// 50 ICP iterations per stage, 100 morph steps, 2 cm / 5 mm voxels and a
    /// regularisation weight of 0.01.
    fn default() -> Self {
        let (coarse_voxel_size, fine_voxel_size) = (0.02, 0.005);
        Self {
            regularization: 0.01,
            coarse_voxel_size,
            fine_voxel_size,
            morph_iterations: 100,
            icp_iterations: 50,
        }
    }
}
1241
/// Result of the multi-stage scan fitting pipeline (as produced by `ScanFitter::fit`).
#[derive(Debug, Clone)]
pub struct PhotoFitResult {
    /// Fitted morph parameters `(name, weight)`; `ScanFitter::fit` emits one entry per
    /// morph target, in input order (empty when no morph targets were given).
    pub morph_parameters: Vec<(String, f64)>,
    /// Combined scan-to-template transform (coarse stage followed by fine stage).
    /// `fitness` and `rmse` are taken from the fine stage; `iterations` is the
    /// sum over both stages.
    pub alignment: IcpResult,
    /// Mean distance from each aligned scan point to its nearest vertex of the
    /// morph-deformed template (same units as the input points).
    pub surface_error: f64,
    /// Number of fitting stages completed (0-3). `ScanFitter::fit` only returns
    /// `Ok` after all 3 stages, so results from it always carry 3.
    pub stages_completed: usize,
}
1254
1255/// Multi-stage body scan fitting pipeline.
1256#[derive(Debug, Clone)]
1257pub struct ScanFitter {
1258    config: ScanFitConfig,
1259}
1260
1261impl ScanFitter {
1262    /// Create a new scan fitter with the given configuration.
1263    pub fn new(config: ScanFitConfig) -> Self {
1264        Self { config }
1265    }
1266
1267    /// Run the full pipeline: import -> downsample -> align -> fit morphs.
1268    pub fn fit(
1269        &self,
1270        scan_cloud: &PointCloud,
1271        template_vertices: &[[f64; 3]],
1272        template_triangles: &[[usize; 3]],
1273        morph_targets: &[(String, Vec<[f64; 3]>)],
1274    ) -> anyhow::Result<PhotoFitResult> {
1275        if scan_cloud.points.is_empty() {
1276            anyhow::bail!("scan point cloud is empty");
1277        }
1278        if template_vertices.is_empty() {
1279            anyhow::bail!("template mesh has no vertices");
1280        }
1281
1282        // Stage 1: Coarse alignment
1283        let coarse_scan = scan_cloud.voxel_downsample(self.config.coarse_voxel_size);
1284        let coarse_template =
1285            voxel_downsample_slice(template_vertices, self.config.coarse_voxel_size);
1286
1287        let coarse_icp = IcpAligner::new(self.config.icp_iterations, 1e-6);
1288        let coarse_result =
1289            coarse_icp.align_point_to_point(&coarse_scan.points, &coarse_template)?;
1290        // stages_completed: 1
1291
1292        let mut aligned_scan: Vec<[f64; 3]> = scan_cloud.points.clone();
1293        IcpAligner::transform_points(
1294            &mut aligned_scan,
1295            &coarse_result.rotation,
1296            &coarse_result.translation,
1297            coarse_result.scale,
1298        );
1299
1300        // Stage 2: Fine alignment
1301        let fine_scan = if self.config.fine_voxel_size > 0.0 {
1302            let pc = PointCloud {
1303                points: aligned_scan.clone(),
1304                normals: None,
1305                colors: None,
1306            };
1307            pc.voxel_downsample(self.config.fine_voxel_size).points
1308        } else {
1309            aligned_scan.clone()
1310        };
1311
1312        let fine_icp = IcpAligner::new(self.config.icp_iterations, 1e-7);
1313        let fine_result = fine_icp.align_point_to_point(&fine_scan, template_vertices)?;
1314        // stages_completed: 2
1315
1316        IcpAligner::transform_points(
1317            &mut aligned_scan,
1318            &fine_result.rotation,
1319            &fine_result.translation,
1320            fine_result.scale,
1321        );
1322
1323        let combined_rot = mat3_mul(&fine_result.rotation, &coarse_result.rotation);
1324        let cr_trans = mat3_vec(&fine_result.rotation, &coarse_result.translation);
1325        let combined_trans = [
1326            fine_result.scale * cr_trans[0] + fine_result.translation[0],
1327            fine_result.scale * cr_trans[1] + fine_result.translation[1],
1328            fine_result.scale * cr_trans[2] + fine_result.translation[2],
1329        ];
1330        let combined_scale = fine_result.scale * coarse_result.scale;
1331
1332        let combined_alignment = IcpResult {
1333            rotation: combined_rot,
1334            translation: combined_trans,
1335            scale: combined_scale,
1336            fitness: fine_result.fitness,
1337            rmse: fine_result.rmse,
1338            iterations: coarse_result.iterations + fine_result.iterations,
1339        };
1340
1341        // Stage 3: Morph fitting
1342        let morph_params = if morph_targets.is_empty() {
1343            Vec::new()
1344        } else {
1345            self.fit_morphs(
1346                &aligned_scan,
1347                template_vertices,
1348                template_triangles,
1349                morph_targets,
1350            )?
1351        };
1352
1353        // stages_completed: 3
1354        let deformed = apply_morph_deltas(template_vertices, morph_targets, &morph_params);
1355        let surface_error = mean_closest_distance(&aligned_scan, &deformed);
1356
1357        Ok(PhotoFitResult {
1358            morph_parameters: morph_params,
1359            alignment: combined_alignment,
1360            surface_error,
1361            stages_completed: 3,
1362        })
1363    }
1364
1365    /// Gradient descent to fit morph target weights.
1366    fn fit_morphs(
1367        &self,
1368        scan_points: &[[f64; 3]],
1369        template_vertices: &[[f64; 3]],
1370        _template_triangles: &[[usize; 3]],
1371        morph_targets: &[(String, Vec<[f64; 3]>)],
1372    ) -> anyhow::Result<Vec<(String, f64)>> {
1373        let n_morphs = morph_targets.len();
1374        let mut weights = vec![0.0_f64; n_morphs];
1375        let lr = 0.001_f64;
1376        let reg = self.config.regularization;
1377
1378        let scan_sub = if scan_points.len() > 2000 {
1379            let step = scan_points.len() / 2000;
1380            scan_points
1381                .iter()
1382                .step_by(step.max(1))
1383                .copied()
1384                .collect::<Vec<_>>()
1385        } else {
1386            scan_points.to_vec()
1387        };
1388
1389        for _iter in 0..self.config.morph_iterations {
1390            let deformed = apply_morph_deltas(
1391                template_vertices,
1392                morph_targets,
1393                &weight_pairs(morph_targets, &weights),
1394            );
1395
1396            let mut grad = vec![0.0_f64; n_morphs];
1397            let n_scan = scan_sub.len() as f64;
1398
1399            for sp in &scan_sub {
1400                let (closest_idx, _) = find_closest_vertex(sp, &deformed);
1401                let diff = vec3_sub(sp, &deformed[closest_idx]);
1402
1403                for (j, (_name, deltas)) in morph_targets.iter().enumerate() {
1404                    if closest_idx < deltas.len() {
1405                        let d = &deltas[closest_idx];
1406                        grad[j] += -2.0 * vec3_dot(&diff, d) / n_scan;
1407                    }
1408                }
1409            }
1410
1411            for j in 0..n_morphs {
1412                grad[j] += 2.0 * reg * weights[j];
1413            }
1414
1415            for j in 0..n_morphs {
1416                weights[j] -= lr * grad[j];
1417                weights[j] = weights[j].clamp(-2.0, 2.0);
1418            }
1419        }
1420
1421        Ok(weight_pairs(morph_targets, &weights))
1422    }
1423}
1424
1425// ---------------------------------------------------------------------------
1426// Multi-stage fitting helpers
1427// ---------------------------------------------------------------------------
1428
/// Voxel-grid downsample: replace all points that fall in the same axis-aligned
/// cube of edge `voxel_size` with their centroid.
///
/// Voxels are emitted in the order their first point appears in `pts`, so the
/// output is deterministic for a given input. Empty input, or a `voxel_size` that
/// is non-positive or NaN, returns the points unchanged.
fn voxel_downsample_slice(pts: &[[f64; 3]], voxel_size: f64) -> Vec<[f64; 3]> {
    if pts.is_empty() || voxel_size.is_nan() || voxel_size <= 0.0 {
        return pts.to_vec();
    }
    let inv = 1.0 / voxel_size;
    // Voxel key -> slot in `cells`. Output order comes from the Vec, not the map:
    // HashMap iteration order is randomised per process, which previously made the
    // result (and any ICP run consuming it) non-reproducible.
    let mut slot_of: std::collections::HashMap<(i64, i64, i64), usize> =
        std::collections::HashMap::new();
    let mut cells: Vec<([f64; 3], usize)> = Vec::new();

    for p in pts {
        let key = (
            (p[0] * inv).floor() as i64,
            (p[1] * inv).floor() as i64,
            (p[2] * inv).floor() as i64,
        );
        let slot = *slot_of.entry(key).or_insert_with(|| {
            cells.push(([0.0; 3], 0));
            cells.len() - 1
        });
        let (sum, count) = &mut cells[slot];
        sum[0] += p[0];
        sum[1] += p[1];
        sum[2] += p[2];
        *count += 1;
    }

    cells
        .into_iter()
        .map(|(sum, count)| {
            let inv_n = 1.0 / (count as f64);
            [sum[0] * inv_n, sum[1] * inv_n, sum[2] * inv_n]
        })
        .collect()
}
1458
/// Return a copy of `template` with each named morph applied as
/// `vertex += weight * delta`.
///
/// Each `(name, weight)` resolves to the first entry of `morph_targets` with that
/// name. Unknown names are ignored, and so are weights with `|w| < 1e-12`. A
/// delta list shorter than the template only moves the leading vertices.
/// Weights are applied in the order given.
fn apply_morph_deltas(
    template: &[[f64; 3]],
    morph_targets: &[(String, Vec<[f64; 3]>)],
    weights: &[(String, f64)],
) -> Vec<[f64; 3]> {
    let active = weights
        .iter()
        // Skip only weights with |w| < 1e-12; NaN weights are not skipped.
        .filter(|(_, w)| w.is_nan() || w.abs() >= 1e-12)
        .filter_map(|(name, w)| {
            morph_targets
                .iter()
                .find(|(target_name, _)| target_name == name)
                .map(|(_, deltas)| (*w, deltas))
        });

    let mut out = template.to_vec();
    for (w, deltas) in active {
        for (vertex, delta) in out.iter_mut().zip(deltas) {
            for (coord, component) in vertex.iter_mut().zip(delta) {
                *coord += w * component;
            }
        }
    }
    out
}
1482
/// Pair each morph target's name with the weight at the same position.
///
/// The result is as long as the shorter of `morph_targets` and `weights`.
fn weight_pairs(morph_targets: &[(String, Vec<[f64; 3]>)], weights: &[f64]) -> Vec<(String, f64)> {
    let count = morph_targets.len().min(weights.len());
    let mut pairs = Vec::with_capacity(count);
    for (i, (name, _)) in morph_targets.iter().enumerate().take(count) {
        pairs.push((name.clone(), weights[i]));
    }
    pairs
}
1490
1491fn find_closest_vertex(point: &[f64; 3], vertices: &[[f64; 3]]) -> (usize, f64) {
1492    let mut best_idx = 0usize;
1493    let mut best_d2 = f64::MAX;
1494    for (i, v) in vertices.iter().enumerate() {
1495        let d2 = dist3_sq(point, v);
1496        if d2 < best_d2 {
1497            best_d2 = d2;
1498            best_idx = i;
1499        }
1500    }
1501    (best_idx, best_d2)
1502}
1503
1504fn mean_closest_distance(source: &[[f64; 3]], target: &[[f64; 3]]) -> f64 {
1505    if source.is_empty() || target.is_empty() {
1506        return 0.0;
1507    }
1508    let total: f64 = source
1509        .iter()
1510        .map(|sp| {
1511            let (_, d2) = find_closest_vertex(sp, target);
1512            d2.sqrt()
1513        })
1514        .sum();
1515    total / source.len() as f64
1516}