Skip to main content

oxihuman_morph/
blend_tree.rs

1// Copyright (C) 2026 COOLJAPAN OU (Team KitaSan)
2// SPDX-License-Identifier: Apache-2.0
3
4#![allow(dead_code)]
5
6use std::collections::HashMap;
7
/// A parameter set: parameter name → value.
pub type ParamMap = HashMap<String, f32>;

/// How a [`BlendNode::Blend`] combines its left (`a`) and right (`b`) results.
///
/// `weight` below is the `weight` field of the blend node. For the per-key
/// modes, a key present on only one side is treated as `0.0` on the other.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum BlendMode {
    /// Weighted average per key: `a * (1 - weight) + b * weight`.
    Lerp,
    /// Additive layering per key: `a + b * weight`.
    Additive,
    /// Whole-map switch: `a` when `weight < 0.5`, otherwise `b`.
    Override,
    /// Per-key product `a * b`; `weight` is ignored.
    Multiply,
}

/// A node in the blend tree; evaluating a node yields a [`ParamMap`].
#[derive(Clone, Debug, PartialEq)]
pub enum BlendNode {
    /// Leaf node: a named parameter set, returned as-is on evaluation.
    Params { name: String, params: ParamMap },
    /// Binary blend: combine the results of two children using `mode`.
    Blend {
        mode: BlendMode,
        weight: f32,
        left: Box<BlendNode>,
        right: Box<BlendNode>,
    },
    /// Multiply every value produced by `child` by `factor`.
    Scale { factor: f32, child: Box<BlendNode> },
    /// Clamp every value produced by `child` to `[min, max]`.
    Clamp {
        min: f32,
        max: f32,
        child: Box<BlendNode>,
    },
    /// Evaluate exactly one child, chosen by `index` (wrapped modulo
    /// `children.len()`); an empty list evaluates to an empty map.
    Select {
        index: usize,
        children: Vec<BlendNode>,
    },
}
50
51impl BlendNode {
52    /// Evaluate the node, returning a parameter map
53    pub fn evaluate(&self) -> ParamMap {
54        match self {
55            BlendNode::Params { params, .. } => params.clone(),
56
57            BlendNode::Blend {
58                mode,
59                weight,
60                left,
61                right,
62            } => {
63                let left_result = left.evaluate();
64                let right_result = right.evaluate();
65                blend_params(&left_result, &right_result, *weight, mode)
66            }
67
68            BlendNode::Scale { factor, child } => {
69                let result = child.evaluate();
70                scale_params(&result, *factor)
71            }
72
73            BlendNode::Clamp { min, max, child } => {
74                let result = child.evaluate();
75                clamp_params(&result, *min, *max)
76            }
77
78            BlendNode::Select { index, children } => {
79                if children.is_empty() {
80                    ParamMap::new()
81                } else {
82                    let i = index % children.len();
83                    children[i].evaluate()
84                }
85            }
86        }
87    }
88
89    /// Leaf constructor
90    pub fn leaf(name: impl Into<String>, params: ParamMap) -> Self {
91        BlendNode::Params {
92            name: name.into(),
93            params,
94        }
95    }
96
97    /// Lerp blend constructor
98    pub fn lerp(weight: f32, left: BlendNode, right: BlendNode) -> Self {
99        BlendNode::Blend {
100            mode: BlendMode::Lerp,
101            weight,
102            left: Box::new(left),
103            right: Box::new(right),
104        }
105    }
106
107    /// Additive blend constructor
108    pub fn additive(weight: f32, base: BlendNode, addon: BlendNode) -> Self {
109        BlendNode::Blend {
110            mode: BlendMode::Additive,
111            weight,
112            left: Box::new(base),
113            right: Box::new(addon),
114        }
115    }
116
117    /// Scale constructor
118    pub fn scale(factor: f32, child: BlendNode) -> Self {
119        BlendNode::Scale {
120            factor,
121            child: Box::new(child),
122        }
123    }
124
125    /// Clamp constructor
126    pub fn clamp(min: f32, max: f32, child: BlendNode) -> Self {
127        BlendNode::Clamp {
128            min,
129            max,
130            child: Box::new(child),
131        }
132    }
133
134    /// Select constructor
135    pub fn select(index: usize, children: Vec<BlendNode>) -> Self {
136        BlendNode::Select { index, children }
137    }
138
139    /// Depth of this subtree
140    pub fn depth(&self) -> usize {
141        match self {
142            BlendNode::Params { .. } => 1,
143            BlendNode::Blend { left, right, .. } => 1 + left.depth().max(right.depth()),
144            BlendNode::Scale { child, .. } => 1 + child.depth(),
145            BlendNode::Clamp { child, .. } => 1 + child.depth(),
146            BlendNode::Select { children, .. } => {
147                let max_child = children.iter().map(|c| c.depth()).max().unwrap_or(0);
148                1 + max_child
149            }
150        }
151    }
152
153    /// Count of leaf nodes
154    pub fn leaf_count(&self) -> usize {
155        match self {
156            BlendNode::Params { .. } => 1,
157            BlendNode::Blend { left, right, .. } => left.leaf_count() + right.leaf_count(),
158            BlendNode::Scale { child, .. } => child.leaf_count(),
159            BlendNode::Clamp { child, .. } => child.leaf_count(),
160            BlendNode::Select { children, .. } => children.iter().map(|c| c.leaf_count()).sum(),
161        }
162    }
163}
164
165/// Blend two param maps
166pub fn blend_params(a: &ParamMap, b: &ParamMap, weight: f32, mode: &BlendMode) -> ParamMap {
167    match mode {
168        BlendMode::Lerp => {
169            let all_keys: std::collections::HashSet<&String> = a.keys().chain(b.keys()).collect();
170            all_keys
171                .into_iter()
172                .map(|k| {
173                    let a_val = *a.get(k).unwrap_or(&0.0);
174                    let b_val = *b.get(k).unwrap_or(&0.0);
175                    let val = a_val * (1.0 - weight) + b_val * weight;
176                    (k.clone(), val)
177                })
178                .collect()
179        }
180
181        BlendMode::Additive => {
182            let all_keys: std::collections::HashSet<&String> = a.keys().chain(b.keys()).collect();
183            all_keys
184                .into_iter()
185                .map(|k| {
186                    let a_val = *a.get(k).unwrap_or(&0.0);
187                    let b_val = *b.get(k).unwrap_or(&0.0);
188                    let val = a_val + b_val * weight;
189                    (k.clone(), val)
190                })
191                .collect()
192        }
193
194        BlendMode::Override => {
195            if weight < 0.5 {
196                a.clone()
197            } else {
198                b.clone()
199            }
200        }
201
202        BlendMode::Multiply => {
203            let all_keys: std::collections::HashSet<&String> = a.keys().chain(b.keys()).collect();
204            all_keys
205                .into_iter()
206                .map(|k| {
207                    let a_val = *a.get(k).unwrap_or(&0.0);
208                    let b_val = *b.get(k).unwrap_or(&0.0);
209                    let val = a_val * b_val;
210                    (k.clone(), val)
211                })
212                .collect()
213        }
214    }
215}
216
217/// Merge all keys from both maps (union), using value from a for keys only in a, b for b-only.
218/// When a key exists in both, a wins.
219pub fn merge_params(a: &ParamMap, b: &ParamMap) -> ParamMap {
220    let mut result = b.clone();
221    for (k, v) in a {
222        result.insert(k.clone(), *v);
223    }
224    result
225}
226
227/// Scale all values in a param map
228pub fn scale_params(params: &ParamMap, factor: f32) -> ParamMap {
229    params
230        .iter()
231        .map(|(k, v)| (k.clone(), v * factor))
232        .collect()
233}
234
235/// Clamp all values in a param map
236pub fn clamp_params(params: &ParamMap, min: f32, max: f32) -> ParamMap {
237    params
238        .iter()
239        .map(|(k, v)| (k.clone(), v.clamp(min, max)))
240        .collect()
241}
242
#[cfg(test)]
mod tests {
    use super::*;

    /// Build a `ParamMap` from `(name, value)` pairs.
    fn make_params(pairs: &[(&str, f32)]) -> ParamMap {
        pairs.iter().map(|(k, v)| (k.to_string(), *v)).collect()
    }

    #[test]
    fn test_leaf_evaluate() {
        let params = make_params(&[("height", 1.8), ("weight", 75.0)]);
        let node = BlendNode::leaf("base", params.clone());
        let result = node.evaluate();
        assert_eq!(result.get("height"), Some(&1.8));
        assert_eq!(result.get("weight"), Some(&75.0));
    }

    #[test]
    fn test_lerp_blend_zero() {
        let a = make_params(&[("x", 0.0)]);
        let b = make_params(&[("x", 10.0)]);
        let node = BlendNode::lerp(0.0, BlendNode::leaf("a", a), BlendNode::leaf("b", b));
        let result = node.evaluate();
        let x = result["x"];
        assert!((x - 0.0).abs() < 1e-6, "Expected 0.0, got {x}");
    }

    #[test]
    fn test_lerp_blend_one() {
        let a = make_params(&[("x", 0.0)]);
        let b = make_params(&[("x", 10.0)]);
        let node = BlendNode::lerp(1.0, BlendNode::leaf("a", a), BlendNode::leaf("b", b));
        let result = node.evaluate();
        let x = result["x"];
        assert!((x - 10.0).abs() < 1e-6, "Expected 10.0, got {x}");
    }

    #[test]
    fn test_lerp_blend_half() {
        let a = make_params(&[("x", 0.0)]);
        let b = make_params(&[("x", 10.0)]);
        let node = BlendNode::lerp(0.5, BlendNode::leaf("a", a), BlendNode::leaf("b", b));
        let result = node.evaluate();
        let x = result["x"];
        assert!((x - 5.0).abs() < 1e-6, "Expected 5.0, got {x}");
    }

    #[test]
    fn test_additive_blend() {
        let base = make_params(&[("x", 3.0)]);
        let addon = make_params(&[("x", 2.0)]);
        // weight=0.5: result = 3 + 2*0.5 = 4
        let node = BlendNode::additive(
            0.5,
            BlendNode::leaf("base", base),
            BlendNode::leaf("addon", addon),
        );
        let result = node.evaluate();
        let x = result["x"];
        assert!((x - 4.0).abs() < 1e-6, "Expected 4.0, got {x}");
    }

    #[test]
    fn test_additive_blend_missing_keys() {
        // A key missing on either side is treated as 0.0 for that side.
        let base = make_params(&[("x", 1.0)]);
        let addon = make_params(&[("y", 4.0)]);
        let result = blend_params(&base, &addon, 0.5, &BlendMode::Additive);
        assert!((result["x"] - 1.0).abs() < 1e-6);
        assert!((result["y"] - 2.0).abs() < 1e-6);
    }

    #[test]
    fn test_override_blend() {
        let a = make_params(&[("x", 1.0)]);
        let b = make_params(&[("x", 9.0)]);
        // Override has no dedicated shorthand constructor, so build the variant directly.
        let override_node = |weight: f32| BlendNode::Blend {
            mode: BlendMode::Override,
            weight,
            left: Box::new(BlendNode::leaf("a", a.clone())),
            right: Box::new(BlendNode::leaf("b", b.clone())),
        };
        // weight < 0.5 → a; weight >= 0.5 (including exactly 0.5) → b.
        assert!((override_node(0.3).evaluate()["x"] - 1.0).abs() < 1e-6);
        assert!((override_node(0.5).evaluate()["x"] - 9.0).abs() < 1e-6);
        assert!((override_node(0.7).evaluate()["x"] - 9.0).abs() < 1e-6);
    }

    #[test]
    fn test_multiply_blend() {
        let a = make_params(&[("x", 3.0)]);
        let b = make_params(&[("x", 4.0)]);
        let node = BlendNode::Blend {
            mode: BlendMode::Multiply,
            weight: 0.5, // ignored
            left: Box::new(BlendNode::leaf("a", a)),
            right: Box::new(BlendNode::leaf("b", b)),
        };
        let result = node.evaluate();
        let x = result["x"];
        assert!((x - 12.0).abs() < 1e-6, "Expected 12.0, got {x}");
    }

    #[test]
    fn test_multiply_blend_missing_key_is_zero() {
        let a = make_params(&[("x", 3.0), ("y", 2.0)]);
        let b = make_params(&[("x", 4.0)]);
        let result = blend_params(&a, &b, 0.0, &BlendMode::Multiply);
        assert!((result["x"] - 12.0).abs() < 1e-6);
        // `y` is absent from `b`, so it is multiplied by 0.0.
        assert!(result["y"].abs() < 1e-6);
    }

    #[test]
    fn test_scale_node() {
        let params = make_params(&[("x", 5.0), ("y", 2.0)]);
        let node = BlendNode::scale(3.0, BlendNode::leaf("base", params));
        let result = node.evaluate();
        assert!((result["x"] - 15.0).abs() < 1e-6);
        assert!((result["y"] - 6.0).abs() < 1e-6);
    }

    #[test]
    fn test_clamp_node() {
        let params = make_params(&[("x", -5.0), ("y", 15.0), ("z", 0.5)]);
        let node = BlendNode::clamp(0.0, 1.0, BlendNode::leaf("base", params));
        let result = node.evaluate();
        assert!((result["x"] - 0.0).abs() < 1e-6);
        assert!((result["y"] - 1.0).abs() < 1e-6);
        assert!((result["z"] - 0.5).abs() < 1e-6);
    }

    #[test]
    fn test_scale_and_clamp_on_empty_leaf() {
        let empty = BlendNode::leaf("empty", ParamMap::new());
        let node = BlendNode::clamp(0.0, 1.0, BlendNode::scale(2.0, empty));
        assert!(node.evaluate().is_empty());
    }

    #[test]
    fn test_select_node() {
        let c0 = BlendNode::leaf("c0", make_params(&[("v", 1.0)]));
        let c1 = BlendNode::leaf("c1", make_params(&[("v", 2.0)]));
        let c2 = BlendNode::leaf("c2", make_params(&[("v", 3.0)]));

        let node = BlendNode::select(1, vec![c0, c1, c2]);
        let result = node.evaluate();
        assert!((result["v"] - 2.0).abs() < 1e-6);

        // Test wrapping: index 4 % 3 = 1
        let c0b = BlendNode::leaf("c0", make_params(&[("v", 1.0)]));
        let c1b = BlendNode::leaf("c1", make_params(&[("v", 2.0)]));
        let c2b = BlendNode::leaf("c2", make_params(&[("v", 3.0)]));
        let node2 = BlendNode::select(4, vec![c0b, c1b, c2b]);
        let result2 = node2.evaluate();
        assert!((result2["v"] - 2.0).abs() < 1e-6);

        // Empty children → empty map
        let node3 = BlendNode::select(0, vec![]);
        let result3 = node3.evaluate();
        assert!(result3.is_empty());
    }

    #[test]
    fn test_empty_select_shape() {
        let node = BlendNode::select(3, vec![]);
        assert_eq!(node.depth(), 1);
        assert_eq!(node.leaf_count(), 0);
    }

    #[test]
    fn test_blend_params_missing_key() {
        let a = make_params(&[("x", 4.0)]);
        let b = make_params(&[("y", 6.0)]);
        let result = blend_params(&a, &b, 0.5, &BlendMode::Lerp);
        // x: 4*0.5 + 0*0.5 = 2.0
        // y: 0*0.5 + 6*0.5 = 3.0
        assert!((result["x"] - 2.0).abs() < 1e-6);
        assert!((result["y"] - 3.0).abs() < 1e-6);
    }

    #[test]
    fn test_merge_params() {
        let a = make_params(&[("x", 1.0), ("shared", 10.0)]);
        let b = make_params(&[("y", 2.0), ("shared", 99.0)]);
        let result = merge_params(&a, &b);
        // a wins on shared key
        assert!((result["x"] - 1.0).abs() < 1e-6);
        assert!((result["y"] - 2.0).abs() < 1e-6);
        assert!((result["shared"] - 10.0).abs() < 1e-6);
    }

    #[test]
    fn test_depth() {
        let leaf = BlendNode::leaf("l", make_params(&[("x", 1.0)]));
        assert_eq!(leaf.depth(), 1);

        let leaf2 = BlendNode::leaf("l2", make_params(&[("x", 2.0)]));
        let blend = BlendNode::lerp(0.5, leaf, leaf2);
        assert_eq!(blend.depth(), 2);

        let leaf3 = BlendNode::leaf("l3", make_params(&[("x", 3.0)]));
        let scaled = BlendNode::scale(1.0, leaf3);
        assert_eq!(scaled.depth(), 2);

        // deeper: blend of (blend of leaves) and leaf → depth 3
        let la = BlendNode::leaf("a", make_params(&[("x", 0.0)]));
        let lb = BlendNode::leaf("b", make_params(&[("x", 1.0)]));
        let lc = BlendNode::leaf("c", make_params(&[("x", 2.0)]));
        let inner = BlendNode::lerp(0.5, la, lb);
        let outer = BlendNode::lerp(0.5, inner, lc);
        assert_eq!(outer.depth(), 3);
    }

    #[test]
    fn test_leaf_count() {
        let leaf = BlendNode::leaf("l", make_params(&[("x", 1.0)]));
        assert_eq!(leaf.leaf_count(), 1);

        let la = BlendNode::leaf("a", make_params(&[("x", 0.0)]));
        let lb = BlendNode::leaf("b", make_params(&[("x", 1.0)]));
        let blend = BlendNode::lerp(0.5, la, lb);
        assert_eq!(blend.leaf_count(), 2);

        let c0 = BlendNode::leaf("c0", make_params(&[("v", 1.0)]));
        let c1 = BlendNode::leaf("c1", make_params(&[("v", 2.0)]));
        let c2 = BlendNode::leaf("c2", make_params(&[("v", 3.0)]));
        let sel = BlendNode::select(0, vec![c0, c1, c2]);
        assert_eq!(sel.leaf_count(), 3);
    }

    #[test]
    fn test_nested_blend() {
        // Tree: clamp(0.0, 5.0, scale(2.0, lerp(0.5, leaf_a, leaf_b)))
        // leaf_a: x=1, leaf_b: x=3 → lerp(0.5) → x=2 → scale(2) → x=4 → clamp(0,5) → x=4
        let a = BlendNode::leaf("a", make_params(&[("x", 1.0)]));
        let b = BlendNode::leaf("b", make_params(&[("x", 3.0)]));
        let blended = BlendNode::lerp(0.5, a, b);
        let scaled = BlendNode::scale(2.0, blended);
        let clamped = BlendNode::clamp(0.0, 5.0, scaled);
        let result = clamped.evaluate();
        assert!(
            (result["x"] - 4.0).abs() < 1e-6,
            "Expected 4.0, got {}",
            result["x"]
        );

        // Also verify leaf count and depth
        assert_eq!(clamped.depth(), 4);
        assert_eq!(clamped.leaf_count(), 2);
    }
}