Skip to main content

oxihuman_morph/
diff.rs

1// Copyright (C) 2026 COOLJAPAN OU (Team KitaSan)
2// SPDX-License-Identifier: Apache-2.0
3
4use crate::params::ParamState;
5
6/// The difference between two ParamState values.
7#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
8pub struct ParamDiff {
9    pub d_height: f32,
10    pub d_weight: f32,
11    pub d_muscle: f32,
12    pub d_age: f32,
13    /// Extra keys that changed: (key, old_value, new_value).
14    pub extra_changes: Vec<(String, f32, f32)>,
15    /// Extra keys added in `b` (not present in `a`).
16    pub extra_added: Vec<(String, f32)>,
17    /// Extra keys removed in `b` (present in `a` but not `b`).
18    pub extra_removed: Vec<(String, f32)>,
19}
20
21impl ParamDiff {
22    /// Compute the diff: b - a for each field.
23    pub fn compute(a: &ParamState, b: &ParamState) -> Self {
24        let d_height = b.height - a.height;
25        let d_weight = b.weight - a.weight;
26        let d_muscle = b.muscle - a.muscle;
27        let d_age = b.age - a.age;
28
29        let mut extra_changes = Vec::new();
30        let mut extra_added = Vec::new();
31        let mut extra_removed = Vec::new();
32
33        // Keys in b
34        for (key, &bval) in &b.extra {
35            if let Some(&aval) = a.extra.get(key) {
36                if (bval - aval).abs() > 0.0 {
37                    extra_changes.push((key.clone(), aval, bval));
38                }
39            } else {
40                extra_added.push((key.clone(), bval));
41            }
42        }
43
44        // Keys only in a (removed in b)
45        for (key, &aval) in &a.extra {
46            if !b.extra.contains_key(key) {
47                extra_removed.push((key.clone(), aval));
48            }
49        }
50
51        ParamDiff {
52            d_height,
53            d_weight,
54            d_muscle,
55            d_age,
56            extra_changes,
57            extra_added,
58            extra_removed,
59        }
60    }
61
62    /// True if all diffs are (near) zero.
63    pub fn is_zero(&self, tolerance: f32) -> bool {
64        self.d_height.abs() <= tolerance
65            && self.d_weight.abs() <= tolerance
66            && self.d_muscle.abs() <= tolerance
67            && self.d_age.abs() <= tolerance
68            && self
69                .extra_changes
70                .iter()
71                .all(|(_, a, b)| (b - a).abs() <= tolerance)
72            && self.extra_added.is_empty()
73            && self.extra_removed.is_empty()
74    }
75
76    /// L2 norm of the numeric diffs (height, weight, muscle, age).
77    pub fn magnitude(&self) -> f32 {
78        (self.d_height.powi(2) + self.d_weight.powi(2) + self.d_muscle.powi(2) + self.d_age.powi(2))
79            .sqrt()
80    }
81
82    /// Apply the diff to a ParamState: a + diff = b.
83    pub fn apply(&self, a: &ParamState) -> ParamState {
84        let mut result = a.clone();
85        result.height += self.d_height;
86        result.weight += self.d_weight;
87        result.muscle += self.d_muscle;
88        result.age += self.d_age;
89
90        for (key, _old, new) in &self.extra_changes {
91            result.extra.insert(key.clone(), *new);
92        }
93        for (key, val) in &self.extra_added {
94            result.extra.insert(key.clone(), *val);
95        }
96        for (key, _) in &self.extra_removed {
97            result.extra.remove(key);
98        }
99
100        result
101    }
102
103    /// Scale the diff by a factor.
104    pub fn scaled(&self, factor: f32) -> ParamDiff {
105        ParamDiff {
106            d_height: self.d_height * factor,
107            d_weight: self.d_weight * factor,
108            d_muscle: self.d_muscle * factor,
109            d_age: self.d_age * factor,
110            extra_changes: self
111                .extra_changes
112                .iter()
113                .map(|(k, a, b)| {
114                    let mid = a + (b - a) * factor;
115                    (k.clone(), *a, mid)
116                })
117                .collect(),
118            extra_added: self
119                .extra_added
120                .iter()
121                .map(|(k, v)| (k.clone(), v * factor))
122                .collect(),
123            extra_removed: self.extra_removed.clone(),
124        }
125    }
126
127    /// Human-readable description of what changed.
128    pub fn describe(&self, threshold: f32) -> String {
129        let mut changes = Vec::new();
130        if self.d_height.abs() > threshold {
131            changes.push(format!("height {:+.3}", self.d_height));
132        }
133        if self.d_weight.abs() > threshold {
134            changes.push(format!("weight {:+.3}", self.d_weight));
135        }
136        if self.d_muscle.abs() > threshold {
137            changes.push(format!("muscle {:+.3}", self.d_muscle));
138        }
139        if self.d_age.abs() > threshold {
140            changes.push(format!("age {:+.3}", self.d_age));
141        }
142        if changes.is_empty() {
143            "no significant changes".into()
144        } else {
145            changes.join(", ")
146        }
147    }
148}
149
/// Statistics about vertex displacement between two position buffers.
#[derive(Debug, Clone)]
pub struct MeshDiffStats {
    /// Number of vertex pairs compared (`min(a.len(), b.len())`).
    pub vertex_count: usize,
    /// Number of vertices that moved more than `threshold`.
    pub changed_count: usize,
    /// Largest per-vertex Euclidean displacement (NaN displacements are ignored; 0 if none).
    pub max_displacement: f32,
    /// Mean per-vertex displacement.
    pub avg_displacement: f32,
    /// Root-mean-square per-vertex displacement.
    pub rms_displacement: f32,
    /// Index of the vertex with maximum displacement (lowest index on ties; 0 if none).
    pub max_vertex_idx: usize,
}

impl MeshDiffStats {
    /// Compute displacement statistics between two position buffers.
    ///
    /// Vertices are paired by index. The buffers are expected to have the same
    /// length; if they differ, only the first `min(a.len(), b.len())` vertices are
    /// compared, and `vertex_count` reports that number. A vertex counts as changed
    /// when its displacement is strictly greater than `threshold`.
    ///
    /// Sums are accumulated in `f64` to limit rounding error on large meshes.
    /// `max_vertex_idx` is the lowest index holding the maximum displacement.
    pub fn compute(a: &[[f32; 3]], b: &[[f32; 3]], threshold: f32) -> Self {
        let vertex_count = a.len().min(b.len());
        if vertex_count == 0 {
            return MeshDiffStats {
                vertex_count: 0,
                changed_count: 0,
                max_displacement: 0.0,
                avg_displacement: 0.0,
                rms_displacement: 0.0,
                max_vertex_idx: 0,
            };
        }

        let mut changed_count = 0usize;
        let mut sum = 0.0f64;
        let mut sum_sq = 0.0f64;
        // (index, displacement) of the largest non-NaN displacement seen so far.
        let mut max: Option<(usize, f32)> = None;

        for (i, (pa, pb)) in a.iter().zip(b).enumerate() {
            let dx = pb[0] - pa[0];
            let dy = pb[1] - pa[1];
            let dz = pb[2] - pa[2];
            let d = (dx * dx + dy * dy + dz * dz).sqrt();

            if d > threshold {
                changed_count += 1;
            }

            let d64 = f64::from(d);
            sum += d64;
            sum_sq += d64 * d64;

            // Strict `>` keeps the first index on ties; NaN is skipped explicitly
            // so the reported index always matches `max_displacement`.
            let is_new_max = match max {
                None => true,
                Some((_, m)) => d > m,
            };
            if !d.is_nan() && is_new_max {
                max = Some((i, d));
            }
        }

        let (max_vertex_idx, max_displacement) = max.unwrap_or((0, 0.0));
        let n = vertex_count as f64;

        MeshDiffStats {
            vertex_count,
            changed_count,
            max_displacement,
            avg_displacement: (sum / n) as f32,
            rms_displacement: (sum_sq / n).sqrt() as f32,
            max_vertex_idx,
        }
    }

    /// Human-readable one-line summary of all fields.
    pub fn summary(&self) -> String {
        format!(
            "vertices: {}, changed: {}, max_disp: {:.4}, avg_disp: {:.4}, rms_disp: {:.4}, max_vertex_idx: {}",
            self.vertex_count,
            self.changed_count,
            self.max_displacement,
            self.avg_displacement,
            self.rms_displacement,
            self.max_vertex_idx,
        )
    }
}
221
/// Euclidean displacement of every vertex when moving from `a` to `b`.
///
/// Vertices are paired by index. If the buffers differ in length, only the first
/// `min(a.len(), b.len())` pairs are measured, so the result has that length.
#[allow(dead_code)]
pub fn vertex_displacements(a: &[[f32; 3]], b: &[[f32; 3]]) -> Vec<f32> {
    a.iter()
        .zip(b.iter())
        .map(|(from, to)| {
            let [fx, fy, fz] = *from;
            let [tx, ty, tz] = *to;
            let (ex, ey, ez) = (tx - fx, ty - fy, tz - fz);
            (ex * ex + ey * ey + ez * ez).sqrt()
        })
        .collect()
}
235
236/// Find the top-N most displaced vertices (by displacement magnitude).
237/// Returns (vertex_index, displacement) sorted descending.
238#[allow(dead_code)]
239pub fn top_displaced_vertices(a: &[[f32; 3]], b: &[[f32; 3]], n: usize) -> Vec<(usize, f32)> {
240    let mut displacements: Vec<(usize, f32)> =
241        vertex_displacements(a, b).into_iter().enumerate().collect();
242    displacements.sort_by(|x, y| y.1.partial_cmp(&x.1).unwrap_or(std::cmp::Ordering::Equal));
243    displacements.truncate(n);
244    displacements
245}
246
#[cfg(test)]
mod tests {
    use super::*;

    /// Build a `ParamState` with the four core parameters and no extra keys.
    fn p(h: f32, w: f32, m: f32, a: f32) -> ParamState {
        ParamState {
            height: h,
            weight: w,
            muscle: m,
            age: a,
            extra: Default::default(),
        }
    }

    #[test]
    fn diff_compute_height() {
        let diff = ParamDiff::compute(&p(0.3, 0.5, 0.5, 0.5), &p(0.7, 0.5, 0.5, 0.5));
        assert!((diff.d_height - 0.4).abs() < 1e-5);
    }

    #[test]
    fn diff_is_zero_for_identical() {
        let a = p(0.5, 0.5, 0.5, 0.5);
        let diff = ParamDiff::compute(&a, &a);
        assert!(diff.is_zero(1e-6));
    }

    #[test]
    fn diff_is_not_zero_when_extra_added() {
        let a = p(0.5, 0.5, 0.5, 0.5);
        let mut b = a.clone();
        b.extra.insert("x".to_string(), 1.0);
        let diff = ParamDiff::compute(&a, &b);
        assert!(!diff.is_zero(1e-6));
    }

    #[test]
    fn diff_magnitude_correct() {
        let diff = ParamDiff {
            d_height: 3.0,
            d_weight: 4.0,
            d_muscle: 0.0,
            d_age: 0.0,
            extra_changes: vec![],
            extra_added: vec![],
            extra_removed: vec![],
        };
        assert!((diff.magnitude() - 5.0).abs() < 1e-5);
    }

    #[test]
    fn diff_apply_roundtrip() {
        let a = p(0.3, 0.5, 0.5, 0.5);
        let b = p(0.7, 0.5, 0.5, 0.5);
        let diff = ParamDiff::compute(&a, &b);
        let result = diff.apply(&a);
        assert!((result.height - b.height).abs() < 1e-5);
        assert!((result.weight - b.weight).abs() < 1e-5);
        assert!((result.muscle - b.muscle).abs() < 1e-5);
        assert!((result.age - b.age).abs() < 1e-5);
    }

    #[test]
    fn diff_extra_keys_roundtrip() {
        let mut a = p(0.5, 0.5, 0.5, 0.5);
        a.extra.insert("keep".to_string(), 0.1);
        a.extra.insert("change".to_string(), 0.2);
        a.extra.insert("drop".to_string(), 0.3);
        let mut b = a.clone();
        b.extra.insert("change".to_string(), 0.9);
        b.extra.remove("drop");
        b.extra.insert("new".to_string(), 0.4);

        let diff = ParamDiff::compute(&a, &b);
        assert_eq!(diff.extra_changes.len(), 1);
        assert_eq!(diff.extra_changes[0].0, "change");
        assert_eq!(diff.extra_added.len(), 1);
        assert_eq!(diff.extra_added[0].0, "new");
        assert_eq!(diff.extra_removed.len(), 1);
        assert_eq!(diff.extra_removed[0].0, "drop");

        let result = diff.apply(&a);
        assert_eq!(result.extra.get("change").copied(), Some(0.9));
        assert_eq!(result.extra.get("new").copied(), Some(0.4));
        assert_eq!(result.extra.get("keep").copied(), Some(0.1));
        assert!(!result.extra.contains_key("drop"));
    }

    #[test]
    fn diff_scaled_halves() {
        let a = p(0.2, 0.5, 0.5, 0.5);
        let b = p(0.8, 0.5, 0.5, 0.5);
        let diff = ParamDiff::compute(&a, &b);
        let scaled = diff.scaled(0.5);
        assert!((scaled.d_height - diff.d_height * 0.5).abs() < 1e-5);
    }

    #[test]
    fn diff_describe_nonempty() {
        let a = p(0.2, 0.5, 0.5, 0.5);
        let b = p(0.8, 0.5, 0.5, 0.5);
        let diff = ParamDiff::compute(&a, &b);
        let desc = diff.describe(0.01);
        assert!(!desc.is_empty());
        assert_ne!(desc, "no significant changes");
    }

    #[test]
    fn mesh_diff_zero_for_same() {
        let positions: Vec<[f32; 3]> = vec![[0.0, 0.0, 0.0], [1.0, 0.0, 0.0], [0.0, 1.0, 0.0]];
        let stats = MeshDiffStats::compute(&positions, &positions, 1e-6);
        assert_eq!(stats.max_displacement, 0.0);
    }

    #[test]
    fn mesh_diff_detects_change() {
        let a: Vec<[f32; 3]> = vec![[0.0, 0.0, 0.0], [1.0, 0.0, 0.0], [0.0, 1.0, 0.0]];
        let mut b = a.clone();
        b[1] = [2.0, 0.0, 0.0]; // move vertex 1 by 1 unit
        let stats = MeshDiffStats::compute(&a, &b, 0.5);
        assert_eq!(stats.changed_count, 1);
        assert_eq!(stats.max_vertex_idx, 1);
        assert!((stats.avg_displacement - 1.0 / 3.0).abs() < 1e-5);
        assert!((stats.rms_displacement - (1.0f32 / 3.0).sqrt()).abs() < 1e-5);
    }

    #[test]
    fn mesh_diff_empty_input() {
        let stats = MeshDiffStats::compute(&[], &[], 0.1);
        assert_eq!(stats.vertex_count, 0);
        assert_eq!(stats.changed_count, 0);
        assert_eq!(stats.max_displacement, 0.0);
    }

    #[test]
    fn vertex_displacements_length() {
        let a: Vec<[f32; 3]> = vec![[0.0, 0.0, 0.0], [1.0, 0.0, 0.0], [0.0, 1.0, 0.0]];
        let b: Vec<[f32; 3]> = vec![[0.1, 0.0, 0.0], [1.0, 0.1, 0.0], [0.0, 1.0, 0.1]];
        let disps = vertex_displacements(&a, &b);
        assert_eq!(disps.len(), a.len());
    }

    #[test]
    fn top_displaced_sorted_desc() {
        let a: Vec<[f32; 3]> = vec![
            [0.0, 0.0, 0.0],
            [0.0, 0.0, 0.0],
            [0.0, 0.0, 0.0],
            [0.0, 0.0, 0.0],
        ];
        let b: Vec<[f32; 3]> = vec![
            [1.0, 0.0, 0.0], // disp = 1.0
            [3.0, 0.0, 0.0], // disp = 3.0
            [2.0, 0.0, 0.0], // disp = 2.0
            [0.5, 0.0, 0.0], // disp = 0.5
        ];
        let top = top_displaced_vertices(&a, &b, 3);
        assert_eq!(top.len(), 3);
        // Should be sorted descending: 3.0, 2.0, 1.0
        assert!(top[0].1 >= top[1].1);
        assert!(top[1].1 >= top[2].1);
        // Top should be vertex 1 (disp=3.0)
        assert_eq!(top[0].0, 1);
    }

    #[test]
    fn top_displaced_n_exceeds_len() {
        let a: Vec<[f32; 3]> = vec![[0.0, 0.0, 0.0]; 2];
        let b: Vec<[f32; 3]> = vec![[1.0, 0.0, 0.0], [0.0, 2.0, 0.0]];
        let top = top_displaced_vertices(&a, &b, 10);
        assert_eq!(top.len(), 2);
        assert_eq!(top[0].0, 1);
    }
}