Skip to main content

oxihuman_morph/
expression_mixer.rs

1// Copyright (C) 2026 COOLJAPAN OU (Team KitaSan)
2// SPDX-License-Identifier: Apache-2.0
3
4#![allow(dead_code)]
5
6use std::collections::HashMap;
7
/// Morph-target name → weight, the currency every mixer function reads and returns.
pub type MorphWeightMap = std::collections::HashMap<String, f32>;
9
10/// A single mixer layer that contributes to the final output.
11pub struct MixLayer {
12    pub name: String,
13    pub weights: MorphWeightMap,
14    /// Overall blend factor for this layer [0, 1].
15    pub blend: f32,
16    /// Whether this layer is additive (true) or override (false).
17    pub additive: bool,
18}
19
20/// The expression mixer: stacks layers and produces a final morph weight map.
21pub struct ExpressionMixer {
22    layers: Vec<MixLayer>,
23}
24
25impl ExpressionMixer {
26    pub fn new() -> Self {
27        Self { layers: Vec::new() }
28    }
29
30    pub fn add_layer(&mut self, layer: MixLayer) {
31        self.layers.push(layer);
32    }
33
34    /// Remove layer by name; returns true if found and removed.
35    pub fn remove_layer(&mut self, name: &str) -> bool {
36        if let Some(pos) = self.layers.iter().position(|l| l.name == name) {
37            self.layers.remove(pos);
38            true
39        } else {
40            false
41        }
42    }
43
44    /// Set blend factor for the named layer; returns true if found.
45    pub fn set_blend(&mut self, name: &str, blend: f32) -> bool {
46        if let Some(layer) = self.layers.iter_mut().find(|l| l.name == name) {
47            layer.blend = blend;
48            true
49        } else {
50            false
51        }
52    }
53
54    pub fn layer_count(&self) -> usize {
55        self.layers.len()
56    }
57
58    /// Evaluate all layers in order to produce a final morph weight map.
59    ///
60    /// - Additive layers: add `weight * layer.blend` to each key.
61    /// - Override layers: lerp current value toward layer's weight by `layer.blend`.
62    pub fn evaluate(&self) -> MorphWeightMap {
63        let mut result: MorphWeightMap = HashMap::new();
64
65        for layer in &self.layers {
66            if layer.additive {
67                for (key, &val) in &layer.weights {
68                    let current = result.entry(key.clone()).or_insert(0.0);
69                    *current += val * layer.blend;
70                }
71            } else {
72                // Override: lerp current toward layer value by blend.
73                // First, collect all keys from result and layer.
74                let all_keys: Vec<String> = result
75                    .keys()
76                    .chain(layer.weights.keys())
77                    .cloned()
78                    .collect::<std::collections::HashSet<_>>()
79                    .into_iter()
80                    .collect();
81
82                for key in all_keys {
83                    let current = result.get(&key).copied().unwrap_or(0.0);
84                    let target = layer.weights.get(&key).copied().unwrap_or(0.0);
85                    let blended = current + (target - current) * layer.blend;
86                    result.insert(key, blended);
87                }
88            }
89        }
90
91        result
92    }
93
94    pub fn clear(&mut self) {
95        self.layers.clear();
96    }
97}
98
99impl Default for ExpressionMixer {
100    fn default() -> Self {
101        Self::new()
102    }
103}
104
105// ---------------------------------------------------------------------------
106// Standalone utility functions
107// ---------------------------------------------------------------------------
108
109/// Lerp between `a` (t=0) and `b` (t=1) over the union of all keys.
110pub fn merge_weight_maps(a: &MorphWeightMap, b: &MorphWeightMap, t: f32) -> MorphWeightMap {
111    let all_keys: std::collections::HashSet<&String> = a.keys().chain(b.keys()).collect();
112    let mut result = MorphWeightMap::new();
113    for key in all_keys {
114        let av = a.get(key).copied().unwrap_or(0.0);
115        let bv = b.get(key).copied().unwrap_or(0.0);
116        result.insert(key.clone(), av + (bv - av) * t);
117    }
118    result
119}
120
121/// Add `scale * additive[k]` to `base[k]` over the union of all keys.
122pub fn add_weight_maps(
123    base: &MorphWeightMap,
124    additive: &MorphWeightMap,
125    scale: f32,
126) -> MorphWeightMap {
127    let all_keys: std::collections::HashSet<&String> = base.keys().chain(additive.keys()).collect();
128    let mut result = MorphWeightMap::new();
129    for key in all_keys {
130        let bv = base.get(key).copied().unwrap_or(0.0);
131        let av = additive.get(key).copied().unwrap_or(0.0);
132        result.insert(key.clone(), bv + scale * av);
133    }
134    result
135}
136
137/// Clamp every value in the map to [min, max].
138pub fn clamp_weight_map(map: &MorphWeightMap, min: f32, max: f32) -> MorphWeightMap {
139    map.iter()
140        .map(|(k, &v)| (k.clone(), v.clamp(min, max)))
141        .collect()
142}
143
144/// Multiply every value in the map by `scale`.
145pub fn scale_weight_map(map: &MorphWeightMap, scale: f32) -> MorphWeightMap {
146    map.iter().map(|(k, &v)| (k.clone(), v * scale)).collect()
147}
148
149/// Compute the L2 magnitude (sqrt of sum of squares) of the weight map.
150pub fn weight_map_magnitude(map: &MorphWeightMap) -> f32 {
151    map.values().map(|&v| v * v).sum::<f32>().sqrt()
152}
153
154/// Return the top `n` entries by absolute value, sorted descending.
155pub fn top_n_weights(map: &MorphWeightMap, n: usize) -> Vec<(String, f32)> {
156    let mut entries: Vec<(String, f32)> = map.iter().map(|(k, &v)| (k.clone(), v)).collect();
157    entries.sort_by(|a, b| {
158        b.1.abs()
159            .partial_cmp(&a.1.abs())
160            .unwrap_or(std::cmp::Ordering::Equal)
161    });
162    entries.truncate(n);
163    entries
164}
165
166/// Keep only entries whose absolute value is >= threshold.
167pub fn threshold_weight_map(map: &MorphWeightMap, threshold: f32) -> MorphWeightMap {
168    map.iter()
169        .filter(|(_, &v)| v.abs() >= threshold)
170        .map(|(k, &v)| (k.clone(), v))
171        .collect()
172}
173
174// ---------------------------------------------------------------------------
175// Factory functions
176// ---------------------------------------------------------------------------
177
178/// Build a lip-sync override layer.
179pub fn lip_sync_layer(viseme_weights: MorphWeightMap, blend: f32) -> MixLayer {
180    MixLayer {
181        name: "lip_sync".to_string(),
182        weights: viseme_weights,
183        blend,
184        additive: false,
185    }
186}
187
188/// Build an emotion override layer.
189pub fn emotion_layer(emotion_weights: MorphWeightMap, blend: f32) -> MixLayer {
190    MixLayer {
191        name: "emotion".to_string(),
192        weights: emotion_weights,
193        blend,
194        additive: false,
195    }
196}
197
198/// Build a micro-expression additive layer.
199pub fn micro_expression_layer(weights: MorphWeightMap, blend: f32) -> MixLayer {
200    MixLayer {
201        name: "micro_expression".to_string(),
202        weights,
203        blend,
204        additive: true,
205    }
206}
207
208/// Build a corrective additive layer.
209pub fn corrective_layer(weights: MorphWeightMap, blend: f32) -> MixLayer {
210    MixLayer {
211        name: "corrective".to_string(),
212        weights,
213        blend,
214        additive: true,
215    }
216}
217
218// ---------------------------------------------------------------------------
219// Tests
220// ---------------------------------------------------------------------------
221
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a weight map from `(name, weight)` pairs.
    fn map(pairs: &[(&str, f32)]) -> MorphWeightMap {
        pairs.iter().map(|(k, v)| (k.to_string(), *v)).collect()
    }

    /// Builds a layer without going through the named factory functions.
    fn layer(name: &str, pairs: &[(&str, f32)], blend: f32, additive: bool) -> MixLayer {
        MixLayer {
            name: name.to_string(),
            weights: map(pairs),
            blend,
            additive,
        }
    }

    /// Asserts `actual` is within 1e-5 of `expected`, reporting both on failure.
    fn assert_close(actual: f32, expected: f32) {
        assert!(
            (actual - expected).abs() < 1e-5,
            "expected {expected}, got {actual}"
        );
    }

    // --- ExpressionMixer basics ---

    #[test]
    fn test_empty_mixer_evaluates_to_empty_map() {
        let mixer = ExpressionMixer::new();
        assert!(mixer.evaluate().is_empty());
    }

    #[test]
    fn test_add_layer_increases_count() {
        let mut mixer = ExpressionMixer::new();
        assert_eq!(mixer.layer_count(), 0);
        mixer.add_layer(emotion_layer(map(&[("smile", 1.0)]), 1.0));
        assert_eq!(mixer.layer_count(), 1);
    }

    #[test]
    fn test_remove_layer_found() {
        let mut mixer = ExpressionMixer::new();
        mixer.add_layer(emotion_layer(map(&[("smile", 1.0)]), 1.0));
        assert!(mixer.remove_layer("emotion"));
        assert_eq!(mixer.layer_count(), 0);
    }

    #[test]
    fn test_remove_layer_not_found() {
        let mut mixer = ExpressionMixer::new();
        assert!(!mixer.remove_layer("nonexistent"));
    }

    #[test]
    fn test_remove_layer_keeps_other_layers() {
        let mut mixer = ExpressionMixer::new();
        mixer.add_layer(emotion_layer(map(&[("smile", 1.0)]), 1.0));
        mixer.add_layer(corrective_layer(map(&[("jaw_fix", 0.1)]), 1.0));
        assert!(mixer.remove_layer("emotion"));
        assert_eq!(mixer.layer_count(), 1);
        let result = mixer.evaluate();
        assert!(!result.contains_key("smile"));
        assert_close(result["jaw_fix"], 0.1);
    }

    #[test]
    fn test_set_blend_found() {
        let mut mixer = ExpressionMixer::new();
        mixer.add_layer(emotion_layer(map(&[("smile", 1.0)]), 0.5));
        assert!(mixer.set_blend("emotion", 0.8));
        assert_close(mixer.evaluate()["smile"], 0.8);
    }

    #[test]
    fn test_set_blend_not_found() {
        let mut mixer = ExpressionMixer::new();
        assert!(!mixer.set_blend("absent", 0.5));
    }

    #[test]
    fn test_clear() {
        let mut mixer = ExpressionMixer::new();
        mixer.add_layer(emotion_layer(map(&[("smile", 1.0)]), 1.0));
        mixer.clear();
        assert_eq!(mixer.layer_count(), 0);
        assert!(mixer.evaluate().is_empty());
    }

    // --- Override layer behaviour ---

    #[test]
    fn test_override_layer_full_blend() {
        let mut mixer = ExpressionMixer::new();
        mixer.add_layer(layer("base", &[("a", 0.0)], 1.0, false));
        mixer.add_layer(layer("override", &[("a", 1.0)], 1.0, false));
        assert_close(mixer.evaluate()["a"], 1.0);
    }

    #[test]
    fn test_override_layer_half_blend() {
        let mut mixer = ExpressionMixer::new();
        mixer.add_layer(layer("base", &[("a", 0.0)], 1.0, false));
        mixer.add_layer(layer("override", &[("a", 1.0)], 0.5, false));
        assert_close(mixer.evaluate()["a"], 0.5);
    }

    #[test]
    fn test_override_layer_pulls_absent_keys_toward_zero() {
        let mut mixer = ExpressionMixer::new();
        mixer.add_layer(layer("base", &[("a", 1.0), ("b", 0.4)], 1.0, false));
        mixer.add_layer(layer("override", &[("a", 0.0)], 0.5, false));
        let result = mixer.evaluate();
        assert_close(result["a"], 0.5);
        // `b` is not in the override layer, so it is lerped toward 0.0.
        assert_close(result["b"], 0.2);
    }

    #[test]
    fn test_full_override_discards_earlier_additive() {
        let mut mixer = ExpressionMixer::new();
        mixer.add_layer(layer("add", &[("x", 0.4)], 0.5, true));
        mixer.add_layer(layer("override", &[("y", 1.0)], 1.0, false));
        let result = mixer.evaluate();
        assert_close(result["x"], 0.0);
        assert_close(result["y"], 1.0);
    }

    // --- Additive layer behaviour ---

    #[test]
    fn test_additive_layer() {
        let mut mixer = ExpressionMixer::new();
        mixer.add_layer(layer("base", &[("a", 0.3)], 1.0, false));
        mixer.add_layer(layer("add", &[("a", 0.5)], 1.0, true));
        // override sets a=0.3, then additive adds 0.5*1.0=0.5 → 0.8
        assert_close(mixer.evaluate()["a"], 0.8);
    }

    #[test]
    fn test_additive_layer_with_scale() {
        let mut mixer = ExpressionMixer::new();
        mixer.add_layer(micro_expression_layer(map(&[("twitch", 0.4)]), 0.5));
        // additive: 0.4 * 0.5 = 0.2
        assert_close(mixer.evaluate()["twitch"], 0.2);
    }

    // --- Standalone utilities ---

    #[test]
    fn test_merge_weight_maps_midpoint() {
        let a = map(&[("x", 0.0), ("y", 1.0)]);
        let b = map(&[("x", 1.0), ("z", 1.0)]);
        let m = merge_weight_maps(&a, &b, 0.5);
        assert_close(m["x"], 0.5);
        assert_close(m["y"], 0.5);
        assert_close(m["z"], 0.5);
    }

    #[test]
    fn test_merge_weight_maps_t0_equals_a() {
        let a = map(&[("x", 0.3)]);
        let b = map(&[("x", 0.9)]);
        assert_close(merge_weight_maps(&a, &b, 0.0)["x"], 0.3);
    }

    #[test]
    fn test_merge_weight_maps_t1_equals_b() {
        let a = map(&[("x", 0.3), ("y", 0.7)]);
        let b = map(&[("x", 0.9)]);
        let m = merge_weight_maps(&a, &b, 1.0);
        assert_close(m["x"], 0.9);
        // `y` is missing from `b`, so it counts as 0.0 there.
        assert_close(m["y"], 0.0);
    }

    #[test]
    fn test_add_weight_maps() {
        let base = map(&[("a", 0.5)]);
        let add = map(&[("a", 0.2), ("b", 0.4)]);
        let result = add_weight_maps(&base, &add, 2.0);
        assert_close(result["a"], 0.9); // 0.5 + 2*0.2
        assert_close(result["b"], 0.8); // 0.0 + 2*0.4
    }

    #[test]
    fn test_clamp_weight_map() {
        let m = map(&[("a", -0.5), ("b", 1.5), ("c", 0.5)]);
        let c = clamp_weight_map(&m, 0.0, 1.0);
        assert_close(c["a"], 0.0);
        assert_close(c["b"], 1.0);
        assert_close(c["c"], 0.5);
    }

    #[test]
    fn test_scale_weight_map() {
        let m = map(&[("a", 0.4), ("b", 0.8)]);
        let s = scale_weight_map(&m, 0.5);
        assert_close(s["a"], 0.2);
        assert_close(s["b"], 0.4);
    }

    #[test]
    fn test_weight_map_magnitude() {
        let m = map(&[("a", 3.0), ("b", 4.0)]);
        let mag = weight_map_magnitude(&m);
        assert!((mag - 5.0).abs() < 1e-4, "expected 5.0, got {mag}");
    }

    #[test]
    fn test_weight_map_magnitude_empty() {
        assert_eq!(weight_map_magnitude(&MorphWeightMap::new()), 0.0);
    }

    #[test]
    fn test_top_n_weights() {
        let m = map(&[("a", 0.1), ("b", 0.9), ("c", 0.5), ("d", -0.8)]);
        let top = top_n_weights(&m, 2);
        assert_eq!(top.len(), 2);
        assert_eq!(top[0].0, "b");
        assert_eq!(top[1].0, "d");
    }

    #[test]
    fn test_top_n_weights_fewer_than_n() {
        let m = map(&[("x", 0.3)]);
        assert_eq!(top_n_weights(&m, 5).len(), 1);
    }

    #[test]
    fn test_top_n_weights_zero() {
        let m = map(&[("x", 0.3), ("y", 0.6)]);
        assert!(top_n_weights(&m, 0).is_empty());
    }

    #[test]
    fn test_threshold_weight_map() {
        let m = map(&[("a", 0.05), ("b", 0.5), ("c", -0.3)]);
        let t = threshold_weight_map(&m, 0.1);
        assert!(!t.contains_key("a"));
        assert!(t.contains_key("b"));
        assert!(t.contains_key("c"));
    }

    #[test]
    fn test_threshold_weight_map_is_inclusive() {
        let m = map(&[("a", 0.5), ("b", -0.5), ("c", 0.49)]);
        let t = threshold_weight_map(&m, 0.5);
        assert!(t.contains_key("a"));
        assert!(t.contains_key("b"));
        assert!(!t.contains_key("c"));
    }

    // --- Factory functions ---

    #[test]
    fn test_lip_sync_layer_factory() {
        let layer = lip_sync_layer(map(&[("vowel_a", 1.0)]), 0.7);
        assert_eq!(layer.name, "lip_sync");
        assert!(!layer.additive);
        assert_close(layer.blend, 0.7);
    }

    #[test]
    fn test_emotion_layer_factory() {
        let layer = emotion_layer(map(&[("smile", 0.8)]), 1.0);
        assert_eq!(layer.name, "emotion");
        assert!(!layer.additive);
    }

    #[test]
    fn test_micro_expression_layer_factory() {
        let layer = micro_expression_layer(map(&[("brow_raise", 0.3)]), 0.5);
        assert_eq!(layer.name, "micro_expression");
        assert!(layer.additive);
    }

    #[test]
    fn test_corrective_layer_factory() {
        let layer = corrective_layer(map(&[("jaw_fix", 0.1)]), 1.0);
        assert_eq!(layer.name, "corrective");
        assert!(layer.additive);
    }
}