1#![allow(dead_code)]
5
6use std::collections::HashMap;
7
/// Maps a weight name (the `String` key) to its `f32` weight value.
///
/// Every mixer layer and helper function in this module consumes or produces
/// this map type.
pub type MorphWeightMap = HashMap<String, f32>;
9
10pub struct MixLayer {
12 pub name: String,
13 pub weights: MorphWeightMap,
14 pub blend: f32,
16 pub additive: bool,
18}
19
20pub struct ExpressionMixer {
22 layers: Vec<MixLayer>,
23}
24
25impl ExpressionMixer {
26 pub fn new() -> Self {
27 Self { layers: Vec::new() }
28 }
29
30 pub fn add_layer(&mut self, layer: MixLayer) {
31 self.layers.push(layer);
32 }
33
34 pub fn remove_layer(&mut self, name: &str) -> bool {
36 if let Some(pos) = self.layers.iter().position(|l| l.name == name) {
37 self.layers.remove(pos);
38 true
39 } else {
40 false
41 }
42 }
43
44 pub fn set_blend(&mut self, name: &str, blend: f32) -> bool {
46 if let Some(layer) = self.layers.iter_mut().find(|l| l.name == name) {
47 layer.blend = blend;
48 true
49 } else {
50 false
51 }
52 }
53
54 pub fn layer_count(&self) -> usize {
55 self.layers.len()
56 }
57
58 pub fn evaluate(&self) -> MorphWeightMap {
63 let mut result: MorphWeightMap = HashMap::new();
64
65 for layer in &self.layers {
66 if layer.additive {
67 for (key, &val) in &layer.weights {
68 let current = result.entry(key.clone()).or_insert(0.0);
69 *current += val * layer.blend;
70 }
71 } else {
72 let all_keys: Vec<String> = result
75 .keys()
76 .chain(layer.weights.keys())
77 .cloned()
78 .collect::<std::collections::HashSet<_>>()
79 .into_iter()
80 .collect();
81
82 for key in all_keys {
83 let current = result.get(&key).copied().unwrap_or(0.0);
84 let target = layer.weights.get(&key).copied().unwrap_or(0.0);
85 let blended = current + (target - current) * layer.blend;
86 result.insert(key, blended);
87 }
88 }
89 }
90
91 result
92 }
93
94 pub fn clear(&mut self) {
95 self.layers.clear();
96 }
97}
98
99impl Default for ExpressionMixer {
100 fn default() -> Self {
101 Self::new()
102 }
103}
104
105pub fn merge_weight_maps(a: &MorphWeightMap, b: &MorphWeightMap, t: f32) -> MorphWeightMap {
111 let all_keys: std::collections::HashSet<&String> = a.keys().chain(b.keys()).collect();
112 let mut result = MorphWeightMap::new();
113 for key in all_keys {
114 let av = a.get(key).copied().unwrap_or(0.0);
115 let bv = b.get(key).copied().unwrap_or(0.0);
116 result.insert(key.clone(), av + (bv - av) * t);
117 }
118 result
119}
120
121pub fn add_weight_maps(
123 base: &MorphWeightMap,
124 additive: &MorphWeightMap,
125 scale: f32,
126) -> MorphWeightMap {
127 let all_keys: std::collections::HashSet<&String> = base.keys().chain(additive.keys()).collect();
128 let mut result = MorphWeightMap::new();
129 for key in all_keys {
130 let bv = base.get(key).copied().unwrap_or(0.0);
131 let av = additive.get(key).copied().unwrap_or(0.0);
132 result.insert(key.clone(), bv + scale * av);
133 }
134 result
135}
136
137pub fn clamp_weight_map(map: &MorphWeightMap, min: f32, max: f32) -> MorphWeightMap {
139 map.iter()
140 .map(|(k, &v)| (k.clone(), v.clamp(min, max)))
141 .collect()
142}
143
144pub fn scale_weight_map(map: &MorphWeightMap, scale: f32) -> MorphWeightMap {
146 map.iter().map(|(k, &v)| (k.clone(), v * scale)).collect()
147}
148
149pub fn weight_map_magnitude(map: &MorphWeightMap) -> f32 {
151 map.values().map(|&v| v * v).sum::<f32>().sqrt()
152}
153
154pub fn top_n_weights(map: &MorphWeightMap, n: usize) -> Vec<(String, f32)> {
156 let mut entries: Vec<(String, f32)> = map.iter().map(|(k, &v)| (k.clone(), v)).collect();
157 entries.sort_by(|a, b| {
158 b.1.abs()
159 .partial_cmp(&a.1.abs())
160 .unwrap_or(std::cmp::Ordering::Equal)
161 });
162 entries.truncate(n);
163 entries
164}
165
166pub fn threshold_weight_map(map: &MorphWeightMap, threshold: f32) -> MorphWeightMap {
168 map.iter()
169 .filter(|(_, &v)| v.abs() >= threshold)
170 .map(|(k, &v)| (k.clone(), v))
171 .collect()
172}
173
174pub fn lip_sync_layer(viseme_weights: MorphWeightMap, blend: f32) -> MixLayer {
180 MixLayer {
181 name: "lip_sync".to_string(),
182 weights: viseme_weights,
183 blend,
184 additive: false,
185 }
186}
187
188pub fn emotion_layer(emotion_weights: MorphWeightMap, blend: f32) -> MixLayer {
190 MixLayer {
191 name: "emotion".to_string(),
192 weights: emotion_weights,
193 blend,
194 additive: false,
195 }
196}
197
198pub fn micro_expression_layer(weights: MorphWeightMap, blend: f32) -> MixLayer {
200 MixLayer {
201 name: "micro_expression".to_string(),
202 weights,
203 blend,
204 additive: true,
205 }
206}
207
208pub fn corrective_layer(weights: MorphWeightMap, blend: f32) -> MixLayer {
210 MixLayer {
211 name: "corrective".to_string(),
212 weights,
213 blend,
214 additive: true,
215 }
216}
217
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a weight map from `(name, weight)` pairs.
    fn map(pairs: &[(&str, f32)]) -> MorphWeightMap {
        let mut out = MorphWeightMap::new();
        for &(name, weight) in pairs {
            out.insert(name.to_string(), weight);
        }
        out
    }

    /// Returns `true` when `actual` is within `1e-5` of `expected`.
    fn close(actual: f32, expected: f32) -> bool {
        (actual - expected).abs() < 1e-5
    }

    /// Builds a layer with an explicit name, blend and additive flag.
    fn layer(name: &str, pairs: &[(&str, f32)], blend: f32, additive: bool) -> MixLayer {
        MixLayer {
            name: name.to_string(),
            weights: map(pairs),
            blend,
            additive,
        }
    }

    // --- ExpressionMixer bookkeeping -------------------------------------

    #[test]
    fn test_empty_mixer_evaluates_to_empty_map() {
        let mixer = ExpressionMixer::new();
        assert!(mixer.evaluate().is_empty());
    }

    #[test]
    fn test_add_layer_increases_count() {
        let mut mixer = ExpressionMixer::new();
        assert_eq!(mixer.layer_count(), 0);
        mixer.add_layer(emotion_layer(map(&[("smile", 1.0)]), 1.0));
        assert_eq!(mixer.layer_count(), 1);
    }

    #[test]
    fn test_remove_layer_found() {
        let mut mixer = ExpressionMixer::new();
        mixer.add_layer(emotion_layer(map(&[("smile", 1.0)]), 1.0));
        assert!(mixer.remove_layer("emotion"));
        assert_eq!(mixer.layer_count(), 0);
    }

    #[test]
    fn test_remove_layer_not_found() {
        let mut mixer = ExpressionMixer::new();
        assert!(!mixer.remove_layer("nonexistent"));
    }

    #[test]
    fn test_set_blend_found() {
        let mut mixer = ExpressionMixer::new();
        mixer.add_layer(emotion_layer(map(&[("smile", 1.0)]), 0.5));
        assert!(mixer.set_blend("emotion", 0.8));
        let evaluated = mixer.evaluate();
        let val = evaluated["smile"];
        assert!(close(val, 0.8), "expected 0.8, got {val}");
    }

    #[test]
    fn test_set_blend_not_found() {
        let mut mixer = ExpressionMixer::new();
        assert!(!mixer.set_blend("absent", 0.5));
    }

    #[test]
    fn test_clear() {
        let mut mixer = ExpressionMixer::new();
        mixer.add_layer(emotion_layer(map(&[("smile", 1.0)]), 1.0));
        mixer.clear();
        assert_eq!(mixer.layer_count(), 0);
        assert!(mixer.evaluate().is_empty());
    }

    // --- Non-additive (override) layers ----------------------------------

    #[test]
    fn test_override_layer_full_blend() {
        let mut mixer = ExpressionMixer::new();
        mixer.add_layer(layer("base", &[("a", 0.0)], 1.0, false));
        mixer.add_layer(layer("override", &[("a", 1.0)], 1.0, false));
        let result = mixer.evaluate();
        assert!(close(result["a"], 1.0));
    }

    #[test]
    fn test_override_layer_half_blend() {
        let mut mixer = ExpressionMixer::new();
        mixer.add_layer(layer("base", &[("a", 0.0)], 1.0, false));
        mixer.add_layer(layer("override", &[("a", 1.0)], 0.5, false));
        let result = mixer.evaluate();
        assert!(close(result["a"], 0.5));
    }

    // --- Additive layers -------------------------------------------------

    #[test]
    fn test_additive_layer() {
        let mut mixer = ExpressionMixer::new();
        mixer.add_layer(layer("base", &[("a", 0.3)], 1.0, false));
        mixer.add_layer(layer("add", &[("a", 0.5)], 1.0, true));
        let result = mixer.evaluate();
        assert!(close(result["a"], 0.8), "got {}", result["a"]);
    }

    #[test]
    fn test_additive_layer_with_scale() {
        let mut mixer = ExpressionMixer::new();
        mixer.add_layer(micro_expression_layer(map(&[("twitch", 0.4)]), 0.5));
        let result = mixer.evaluate();
        assert!(close(result["twitch"], 0.2));
    }

    // --- Free weight-map helpers -----------------------------------------

    #[test]
    fn test_merge_weight_maps_midpoint() {
        let from = map(&[("x", 0.0), ("y", 1.0)]);
        let to = map(&[("x", 1.0), ("z", 1.0)]);
        let merged = merge_weight_maps(&from, &to, 0.5);
        assert!(close(merged["x"], 0.5));
        assert!(close(merged["y"], 0.5));
        assert!(close(merged["z"], 0.5));
    }

    #[test]
    fn test_merge_weight_maps_t0_equals_a() {
        let from = map(&[("x", 0.3)]);
        let to = map(&[("x", 0.9)]);
        let merged = merge_weight_maps(&from, &to, 0.0);
        assert!(close(merged["x"], 0.3));
    }

    #[test]
    fn test_add_weight_maps() {
        let base = map(&[("a", 0.5)]);
        let extra = map(&[("a", 0.2), ("b", 0.4)]);
        let summed = add_weight_maps(&base, &extra, 2.0);
        assert!(close(summed["a"], 0.9));
        assert!(close(summed["b"], 0.8));
    }

    #[test]
    fn test_clamp_weight_map() {
        let raw = map(&[("a", -0.5), ("b", 1.5), ("c", 0.5)]);
        let clamped = clamp_weight_map(&raw, 0.0, 1.0);
        assert!(close(clamped["a"], 0.0));
        assert!(close(clamped["b"], 1.0));
        assert!(close(clamped["c"], 0.5));
    }

    #[test]
    fn test_scale_weight_map() {
        let raw = map(&[("a", 0.4), ("b", 0.8)]);
        let scaled = scale_weight_map(&raw, 0.5);
        assert!(close(scaled["a"], 0.2));
        assert!(close(scaled["b"], 0.4));
    }

    #[test]
    fn test_weight_map_magnitude() {
        let weights = map(&[("a", 3.0), ("b", 4.0)]);
        let magnitude = weight_map_magnitude(&weights);
        // Looser tolerance than `close`, matching the square-root round trip.
        assert!((magnitude - 5.0).abs() < 1e-4);
    }

    #[test]
    fn test_top_n_weights() {
        let weights = map(&[("a", 0.1), ("b", 0.9), ("c", 0.5), ("d", -0.8)]);
        let top = top_n_weights(&weights, 2);
        assert_eq!(top.len(), 2);
        assert_eq!(top[0].0, "b");
        assert_eq!(top[1].0, "d");
    }

    #[test]
    fn test_top_n_weights_fewer_than_n() {
        let weights = map(&[("x", 0.3)]);
        let top = top_n_weights(&weights, 5);
        assert_eq!(top.len(), 1);
    }

    #[test]
    fn test_threshold_weight_map() {
        let weights = map(&[("a", 0.05), ("b", 0.5), ("c", -0.3)]);
        let kept = threshold_weight_map(&weights, 0.1);
        assert!(!kept.contains_key("a"));
        assert!(kept.contains_key("b"));
        assert!(kept.contains_key("c"));
    }

    // --- Layer factories -------------------------------------------------

    #[test]
    fn test_lip_sync_layer_factory() {
        let built = lip_sync_layer(map(&[("vowel_a", 1.0)]), 0.7);
        assert_eq!(built.name, "lip_sync");
        assert!(!built.additive);
        assert!(close(built.blend, 0.7));
    }

    #[test]
    fn test_emotion_layer_factory() {
        let built = emotion_layer(map(&[("smile", 0.8)]), 1.0);
        assert_eq!(built.name, "emotion");
        assert!(!built.additive);
    }

    #[test]
    fn test_micro_expression_layer_factory() {
        let built = micro_expression_layer(map(&[("brow_raise", 0.3)]), 0.5);
        assert_eq!(built.name, "micro_expression");
        assert!(built.additive);
    }

    #[test]
    fn test_corrective_layer_factory() {
        let built = corrective_layer(map(&[("jaw_fix", 0.1)]), 1.0);
        assert_eq!(built.name, "corrective");
        assert!(built.additive);
    }
}