// Copyright (C) 2026 COOLJAPAN OU (Team KitaSan)
// SPDX-License-Identifier: Apache-2.0

//! Expression calibration: fit FACS Action Units to facial landmarks.
//! (oxihuman_morph/expression_calibration.rs)
6// ── Types ─────────────────────────────────────────────────────────────────────
7
/// A single 3D facial landmark.
#[allow(dead_code)]
#[derive(Debug, Clone)]
pub struct FacialLandmark {
    /// Index of the landmark within its set (e.g. 0..68 for a 68-point layout).
    pub id: usize,
    /// Human-readable name, e.g. "jaw_0" or "nose_tip".
    pub name: String,
    /// Cartesian `[x, y, z]` position.
    pub position: [f32; 3],
}
16
/// A set of facial landmarks (e.g. 68-point or sparse).
#[allow(dead_code)]
#[derive(Debug, Clone)]
pub struct LandmarkSet {
    /// Landmarks in layout order; pairwise operations assume two sets share
    /// the same ordering and length.
    pub landmarks: Vec<FacialLandmark>,
}
23
/// A FACS Action Unit activation in [0, 1].
#[allow(dead_code)]
#[derive(Debug, Clone)]
pub struct AuActivation {
    /// FACS Action Unit identifier (here: index of the basis vector it was
    /// fitted from — see `calibrate_expression_to_landmarks`).
    pub au_id: u8,
    /// Activation strength, clamped into [0, 1] by the calibration step.
    pub intensity: f32,
}
31
32// ── Core functions ────────────────────────────────────────────────────────────
33
34/// Compute per-landmark displacement from neutral to posed.
35#[allow(dead_code)]
36pub fn landmark_delta(neutral: &LandmarkSet, posed: &LandmarkSet) -> Vec<[f32; 3]> {
37    neutral
38        .landmarks
39        .iter()
40        .zip(posed.landmarks.iter())
41        .map(|(n, p)| {
42            [
43                p.position[0] - n.position[0],
44                p.position[1] - n.position[1],
45                p.position[2] - n.position[2],
46            ]
47        })
48        .collect()
49}
50
/// Project per-landmark deltas onto AU basis vectors via dot product.
///
/// For each basis vector `b_j`, returns `sum_i dot(deltas[i], b_j)` — the
/// total displacement along that AU direction. The output has exactly one
/// entry per basis vector; an empty `deltas` slice yields all zeros.
#[allow(dead_code)]
pub fn project_deltas_to_aus(deltas: &[[f32; 3]], au_basis: &[[f32; 3]]) -> Vec<f32> {
    au_basis
        .iter()
        .map(|basis| {
            // Dot every delta with THIS basis vector. (The previous
            // implementation looked the basis up by the delta index
            // `i % au_basis.len()`, ignoring the outer vector — every
            // output element came out identical.)
            deltas
                .iter()
                .map(|d| d[0] * basis[0] + d[1] * basis[1] + d[2] * basis[2])
                .sum()
        })
        .collect()
}
64
/// Build a simple default AU basis for `n_landmarks` landmarks.
/// Each AU basis vector points along ±Y with magnitude
/// `1 / sqrt(n_landmarks)`, alternating sign from one vector to the next.
#[allow(dead_code)]
pub fn build_default_au_basis(n_landmarks: usize) -> Vec<[f32; 3]> {
    // `max(1)` guards the n == 0 case; the empty result never uses the
    // value, but this keeps the expression finite.
    let scale = 1.0 / (n_landmarks.max(1) as f32).sqrt();
    let mut basis = Vec::with_capacity(n_landmarks);
    for i in 0..n_landmarks {
        let y = if i % 2 == 0 { scale } else { -scale };
        basis.push([0.0, y, 0.0]);
    }
    basis
}
77
78/// Fit AU activations to the displacement between neutral and target landmarks.
79/// Uses a simple least-squares projection.
80#[allow(dead_code)]
81pub fn calibrate_expression_to_landmarks(
82    neutral: &LandmarkSet,
83    target: &LandmarkSet,
84    au_basis: &[[f32; 3]],
85) -> Vec<AuActivation> {
86    let deltas = landmark_delta(neutral, target);
87    let raw = project_deltas_to_aus(&deltas, au_basis);
88    raw.into_iter()
89        .enumerate()
90        .map(|(i, v)| AuActivation {
91            au_id: i as u8,
92            intensity: v.clamp(0.0, 1.0),
93        })
94        .collect()
95}
96
97/// Compute reconstruction error after applying AU activations.
98#[allow(dead_code)]
99pub fn landmark_reconstruction_error(
100    neutral: &LandmarkSet,
101    target: &LandmarkSet,
102    activations: &[AuActivation],
103    au_basis: &[[f32; 3]],
104) -> f32 {
105    let deltas = landmark_delta(neutral, target);
106    let n = deltas.len();
107    if n == 0 {
108        return 0.0;
109    }
110    // Reconstruct deltas from activations
111    let mut reconstructed = vec![[0.0_f32; 3]; n];
112    for act in activations {
113        let idx = (act.au_id as usize).min(au_basis.len().saturating_sub(1));
114        let basis = au_basis[idx];
115        for r in reconstructed.iter_mut() {
116            r[0] += act.intensity * basis[0];
117            r[1] += act.intensity * basis[1];
118            r[2] += act.intensity * basis[2];
119        }
120    }
121    // Mean squared error
122    let mse: f32 = deltas
123        .iter()
124        .zip(reconstructed.iter())
125        .map(|(d, r)| {
126            let e = [d[0] - r[0], d[1] - r[1], d[2] - r[2]];
127            e[0] * e[0] + e[1] * e[1] + e[2] * e[2]
128        })
129        .sum::<f32>()
130        / n as f32;
131    mse.sqrt()
132}
133
134/// Zero-mean, unit-scale normalisation of a landmark set.
135#[allow(dead_code)]
136pub fn normalize_landmark_set(landmarks: &mut LandmarkSet) {
137    let n = landmarks.landmarks.len();
138    if n == 0 {
139        return;
140    }
141    let mean: [f32; 3] = {
142        let sum = landmarks.landmarks.iter().fold([0.0_f32; 3], |acc, l| {
143            [
144                acc[0] + l.position[0],
145                acc[1] + l.position[1],
146                acc[2] + l.position[2],
147            ]
148        });
149        [sum[0] / n as f32, sum[1] / n as f32, sum[2] / n as f32]
150    };
151    for lm in landmarks.landmarks.iter_mut() {
152        lm.position[0] -= mean[0];
153        lm.position[1] -= mean[1];
154        lm.position[2] -= mean[2];
155    }
156    let scale: f32 = landmarks
157        .landmarks
158        .iter()
159        .map(|l| {
160            (l.position[0] * l.position[0]
161                + l.position[1] * l.position[1]
162                + l.position[2] * l.position[2])
163                .sqrt()
164        })
165        .fold(0.0_f32, f32::max);
166    if scale > 1e-8 {
167        for lm in landmarks.landmarks.iter_mut() {
168            lm.position[0] /= scale;
169            lm.position[1] /= scale;
170            lm.position[2] /= scale;
171        }
172    }
173}
174
175/// Build a canonical 68-landmark face set at approximate positions.
176#[allow(dead_code)]
177pub fn standard_68_landmarks() -> LandmarkSet {
178    let names = [
179        "jaw_0",
180        "jaw_1",
181        "jaw_2",
182        "jaw_3",
183        "jaw_4",
184        "jaw_5",
185        "jaw_6",
186        "jaw_7",
187        "jaw_8",
188        "jaw_9",
189        "jaw_10",
190        "jaw_11",
191        "jaw_12",
192        "jaw_13",
193        "jaw_14",
194        "jaw_15",
195        "jaw_16",
196        "brow_l_0",
197        "brow_l_1",
198        "brow_l_2",
199        "brow_l_3",
200        "brow_l_4",
201        "brow_r_0",
202        "brow_r_1",
203        "brow_r_2",
204        "brow_r_3",
205        "brow_r_4",
206        "nose_bridge_0",
207        "nose_bridge_1",
208        "nose_bridge_2",
209        "nose_bridge_3",
210        "nose_tip",
211        "nose_nostril_l",
212        "nose_under_l",
213        "nose_under_r",
214        "nose_nostril_r",
215        "eye_l_0",
216        "eye_l_1",
217        "eye_l_2",
218        "eye_l_3",
219        "eye_l_4",
220        "eye_l_5",
221        "eye_r_0",
222        "eye_r_1",
223        "eye_r_2",
224        "eye_r_3",
225        "eye_r_4",
226        "eye_r_5",
227        "mouth_0",
228        "mouth_1",
229        "mouth_2",
230        "mouth_3",
231        "mouth_4",
232        "mouth_5",
233        "mouth_6",
234        "mouth_7",
235        "mouth_8",
236        "mouth_9",
237        "mouth_10",
238        "mouth_11",
239        "mouth_inner_0",
240        "mouth_inner_1",
241        "mouth_inner_2",
242        "mouth_inner_3",
243        "mouth_inner_4",
244        "mouth_inner_5",
245        "mouth_inner_6",
246        "mouth_inner_7",
247    ];
248    let positions: Vec<[f32; 3]> = (0..68)
249        .map(|i| {
250            let angle = i as f32 * std::f32::consts::TAU / 68.0;
251            [0.5 * angle.cos(), 0.5 * angle.sin(), 0.0]
252        })
253        .collect();
254    LandmarkSet {
255        landmarks: (0..68)
256            .map(|i| FacialLandmark {
257                id: i,
258                name: names.get(i).copied().unwrap_or("lm").to_string(),
259                position: positions[i],
260            })
261            .collect(),
262    }
263}
264
265/// Euclidean distance between two landmarks.
266#[allow(dead_code)]
267pub fn landmark_distance(a: &FacialLandmark, b: &FacialLandmark) -> f32 {
268    let dx = a.position[0] - b.position[0];
269    let dy = a.position[1] - b.position[1];
270    let dz = a.position[2] - b.position[2];
271    (dx * dx + dy * dy + dz * dz).sqrt()
272}
273
274/// Maximum X span of a landmark set.
275#[allow(dead_code)]
276pub fn face_width(landmarks: &LandmarkSet) -> f32 {
277    if landmarks.landmarks.is_empty() {
278        return 0.0;
279    }
280    let min_x = landmarks
281        .landmarks
282        .iter()
283        .map(|l| l.position[0])
284        .fold(f32::INFINITY, f32::min);
285    let max_x = landmarks
286        .landmarks
287        .iter()
288        .map(|l| l.position[0])
289        .fold(f32::NEG_INFINITY, f32::max);
290    (max_x - min_x).max(0.0)
291}
292
293/// Maximum Y span of a landmark set.
294#[allow(dead_code)]
295pub fn face_height(landmarks: &LandmarkSet) -> f32 {
296    if landmarks.landmarks.is_empty() {
297        return 0.0;
298    }
299    let min_y = landmarks
300        .landmarks
301        .iter()
302        .map(|l| l.position[1])
303        .fold(f32::INFINITY, f32::min);
304    let max_y = landmarks
305        .landmarks
306        .iter()
307        .map(|l| l.position[1])
308        .fold(f32::NEG_INFINITY, f32::max);
309    (max_y - min_y).max(0.0)
310}
311
312// ── Tests ─────────────────────────────────────────────────────────────────────
313
#[cfg(test)]
mod tests {
    use super::*;

    /// Convenience constructor for a single named landmark.
    fn make_lm(id: usize, pos: [f32; 3]) -> FacialLandmark {
        FacialLandmark {
            id,
            name: format!("lm{id}"),
            position: pos,
        }
    }

    /// Build a LandmarkSet from raw positions with sequential ids.
    fn make_set(positions: &[[f32; 3]]) -> LandmarkSet {
        LandmarkSet {
            landmarks: positions
                .iter()
                .enumerate()
                .map(|(i, &p)| make_lm(i, p))
                .collect(),
        }
    }

    // ── landmark_delta ────────────────────────────────────────────────────

    #[test]
    fn test_landmark_delta_identical_is_zero() {
        let s = make_set(&[[0.0, 0.0, 0.0], [1.0, 0.0, 0.0]]);
        let deltas = landmark_delta(&s, &s);
        for d in &deltas {
            assert_eq!(*d, [0.0, 0.0, 0.0]);
        }
    }

    #[test]
    fn test_landmark_delta_correct() {
        let n = make_set(&[[0.0, 0.0, 0.0]]);
        let p = make_set(&[[1.0, 2.0, 3.0]]);
        let d = landmark_delta(&n, &p);
        assert_eq!(d[0], [1.0, 2.0, 3.0]);
    }

    // ── face_width / face_height ──────────────────────────────────────────

    #[test]
    fn test_face_width_positive() {
        let ls = standard_68_landmarks();
        assert!(face_width(&ls) > 0.0);
    }

    #[test]
    fn test_face_height_positive() {
        let ls = standard_68_landmarks();
        assert!(face_height(&ls) > 0.0);
    }

    #[test]
    fn test_face_width_empty() {
        let ls = LandmarkSet { landmarks: vec![] };
        assert_eq!(face_width(&ls), 0.0);
    }

    // ── normalize_landmark_set ────────────────────────────────────────────

    #[test]
    fn test_normalize_landmark_set_mean_near_zero() {
        let mut ls = make_set(&[[1.0, 2.0, 0.0], [3.0, 4.0, 0.0], [-1.0, 0.0, 0.0]]);
        normalize_landmark_set(&mut ls);
        let n = ls.landmarks.len() as f32;
        let mean_x: f32 = ls.landmarks.iter().map(|l| l.position[0]).sum::<f32>() / n;
        let mean_y: f32 = ls.landmarks.iter().map(|l| l.position[1]).sum::<f32>() / n;
        assert!(mean_x.abs() < 1e-5);
        assert!(mean_y.abs() < 1e-5);
    }

    #[test]
    fn test_normalize_empty_no_panic() {
        let mut ls = LandmarkSet { landmarks: vec![] };
        normalize_landmark_set(&mut ls);
    }

    // ── calibration pipeline ──────────────────────────────────────────────
    // Note: these use identical neutral/target sets (zero deltas), so they
    // exercise the pipeline's shape and NaN-freedom, not its accuracy.

    #[test]
    fn test_reconstruction_error_nonnegative() {
        let n = standard_68_landmarks();
        let p = standard_68_landmarks();
        let basis = build_default_au_basis(68);
        let acts = calibrate_expression_to_landmarks(&n, &p, &basis);
        let err = landmark_reconstruction_error(&n, &p, &acts, &basis);
        assert!(err >= 0.0);
    }

    #[test]
    fn test_calibrate_no_nan() {
        let n = standard_68_landmarks();
        let p = standard_68_landmarks();
        let basis = build_default_au_basis(68);
        let acts = calibrate_expression_to_landmarks(&n, &p, &basis);
        for a in &acts {
            assert!(!a.intensity.is_nan());
        }
    }

    #[test]
    fn test_calibrate_intensity_clamped() {
        let n = standard_68_landmarks();
        let p = standard_68_landmarks();
        let basis = build_default_au_basis(68);
        let acts = calibrate_expression_to_landmarks(&n, &p, &basis);
        for a in &acts {
            assert!((0.0..=1.0).contains(&a.intensity));
        }
    }

    // ── landmark_distance ─────────────────────────────────────────────────

    #[test]
    fn test_landmark_distance_zero_same_point() {
        let a = make_lm(0, [1.0, 2.0, 3.0]);
        let b = make_lm(1, [1.0, 2.0, 3.0]);
        assert!((landmark_distance(&a, &b)).abs() < 1e-6);
    }

    #[test]
    fn test_landmark_distance_known() {
        // 3-4-5 right triangle.
        let a = make_lm(0, [0.0, 0.0, 0.0]);
        let b = make_lm(1, [3.0, 4.0, 0.0]);
        assert!((landmark_distance(&a, &b) - 5.0).abs() < 1e-5);
    }

    // ── constructors ──────────────────────────────────────────────────────

    #[test]
    fn test_standard_68_landmarks_count() {
        assert_eq!(standard_68_landmarks().landmarks.len(), 68);
    }

    #[test]
    fn test_build_default_au_basis_length() {
        assert_eq!(build_default_au_basis(10).len(), 10);
    }

    #[test]
    fn test_project_deltas_no_nan() {
        let deltas: Vec<[f32; 3]> = (0..5).map(|i| [i as f32, 0.0, 0.0]).collect();
        let basis = build_default_au_basis(5);
        let out = project_deltas_to_aus(&deltas, &basis);
        for v in &out {
            assert!(!v.is_nan());
        }
    }
}
453}