// NOTE(review): this silences dead-code warnings for every item in this
// file/module. Confirm it is still needed and, if possible, narrow it to the
// specific items that trigger the lint instead of allowing it wholesale.
#![allow(dead_code)]
5
6use std::collections::HashMap;
7
/// Named parameter values (name -> value).
///
/// This is both what a leaf node stores and what evaluating any blend node
/// produces.
pub type ParamMap = std::collections::HashMap<String, f32>;
10
/// How two parameter maps `a` (left) and `b` (right) are combined by
/// [`blend_params`] for a given blend `weight`.
///
/// The per-key modes (`Lerp`, `Additive`, `Multiply`) operate on the union of
/// both maps' keys and read a key missing from one side as `0.0`.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum BlendMode {
    /// Per-key linear interpolation: `a * (1 - weight) + b * weight`.
    Lerp,
    /// Per-key weighted addition: `a + b * weight`.
    Additive,
    /// Whole-map switch: a copy of `a` when `weight < 0.5`, otherwise of `b`.
    Override,
    /// Per-key product `a * b`; `weight` is not used.
    Multiply,
}
23
24#[derive(Clone, Debug)]
26pub enum BlendNode {
27 Params { name: String, params: ParamMap },
29 Blend {
31 mode: BlendMode,
32 weight: f32,
33 left: Box<BlendNode>,
34 right: Box<BlendNode>,
35 },
36 Scale { factor: f32, child: Box<BlendNode> },
38 Clamp {
40 min: f32,
41 max: f32,
42 child: Box<BlendNode>,
43 },
44 Select {
46 index: usize,
47 children: Vec<BlendNode>,
48 },
49}
50
51impl BlendNode {
52 pub fn evaluate(&self) -> ParamMap {
54 match self {
55 BlendNode::Params { params, .. } => params.clone(),
56
57 BlendNode::Blend {
58 mode,
59 weight,
60 left,
61 right,
62 } => {
63 let left_result = left.evaluate();
64 let right_result = right.evaluate();
65 blend_params(&left_result, &right_result, *weight, mode)
66 }
67
68 BlendNode::Scale { factor, child } => {
69 let result = child.evaluate();
70 scale_params(&result, *factor)
71 }
72
73 BlendNode::Clamp { min, max, child } => {
74 let result = child.evaluate();
75 clamp_params(&result, *min, *max)
76 }
77
78 BlendNode::Select { index, children } => {
79 if children.is_empty() {
80 ParamMap::new()
81 } else {
82 let i = index % children.len();
83 children[i].evaluate()
84 }
85 }
86 }
87 }
88
89 pub fn leaf(name: impl Into<String>, params: ParamMap) -> Self {
91 BlendNode::Params {
92 name: name.into(),
93 params,
94 }
95 }
96
97 pub fn lerp(weight: f32, left: BlendNode, right: BlendNode) -> Self {
99 BlendNode::Blend {
100 mode: BlendMode::Lerp,
101 weight,
102 left: Box::new(left),
103 right: Box::new(right),
104 }
105 }
106
107 pub fn additive(weight: f32, base: BlendNode, addon: BlendNode) -> Self {
109 BlendNode::Blend {
110 mode: BlendMode::Additive,
111 weight,
112 left: Box::new(base),
113 right: Box::new(addon),
114 }
115 }
116
117 pub fn scale(factor: f32, child: BlendNode) -> Self {
119 BlendNode::Scale {
120 factor,
121 child: Box::new(child),
122 }
123 }
124
125 pub fn clamp(min: f32, max: f32, child: BlendNode) -> Self {
127 BlendNode::Clamp {
128 min,
129 max,
130 child: Box::new(child),
131 }
132 }
133
134 pub fn select(index: usize, children: Vec<BlendNode>) -> Self {
136 BlendNode::Select { index, children }
137 }
138
139 pub fn depth(&self) -> usize {
141 match self {
142 BlendNode::Params { .. } => 1,
143 BlendNode::Blend { left, right, .. } => 1 + left.depth().max(right.depth()),
144 BlendNode::Scale { child, .. } => 1 + child.depth(),
145 BlendNode::Clamp { child, .. } => 1 + child.depth(),
146 BlendNode::Select { children, .. } => {
147 let max_child = children.iter().map(|c| c.depth()).max().unwrap_or(0);
148 1 + max_child
149 }
150 }
151 }
152
153 pub fn leaf_count(&self) -> usize {
155 match self {
156 BlendNode::Params { .. } => 1,
157 BlendNode::Blend { left, right, .. } => left.leaf_count() + right.leaf_count(),
158 BlendNode::Scale { child, .. } => child.leaf_count(),
159 BlendNode::Clamp { child, .. } => child.leaf_count(),
160 BlendNode::Select { children, .. } => children.iter().map(|c| c.leaf_count()).sum(),
161 }
162 }
163}
164
165pub fn blend_params(a: &ParamMap, b: &ParamMap, weight: f32, mode: &BlendMode) -> ParamMap {
167 match mode {
168 BlendMode::Lerp => {
169 let all_keys: std::collections::HashSet<&String> = a.keys().chain(b.keys()).collect();
170 all_keys
171 .into_iter()
172 .map(|k| {
173 let a_val = *a.get(k).unwrap_or(&0.0);
174 let b_val = *b.get(k).unwrap_or(&0.0);
175 let val = a_val * (1.0 - weight) + b_val * weight;
176 (k.clone(), val)
177 })
178 .collect()
179 }
180
181 BlendMode::Additive => {
182 let all_keys: std::collections::HashSet<&String> = a.keys().chain(b.keys()).collect();
183 all_keys
184 .into_iter()
185 .map(|k| {
186 let a_val = *a.get(k).unwrap_or(&0.0);
187 let b_val = *b.get(k).unwrap_or(&0.0);
188 let val = a_val + b_val * weight;
189 (k.clone(), val)
190 })
191 .collect()
192 }
193
194 BlendMode::Override => {
195 if weight < 0.5 {
196 a.clone()
197 } else {
198 b.clone()
199 }
200 }
201
202 BlendMode::Multiply => {
203 let all_keys: std::collections::HashSet<&String> = a.keys().chain(b.keys()).collect();
204 all_keys
205 .into_iter()
206 .map(|k| {
207 let a_val = *a.get(k).unwrap_or(&0.0);
208 let b_val = *b.get(k).unwrap_or(&0.0);
209 let val = a_val * b_val;
210 (k.clone(), val)
211 })
212 .collect()
213 }
214 }
215}
216
217pub fn merge_params(a: &ParamMap, b: &ParamMap) -> ParamMap {
220 let mut result = b.clone();
221 for (k, v) in a {
222 result.insert(k.clone(), *v);
223 }
224 result
225}
226
227pub fn scale_params(params: &ParamMap, factor: f32) -> ParamMap {
229 params
230 .iter()
231 .map(|(k, v)| (k.clone(), v * factor))
232 .collect()
233}
234
235pub fn clamp_params(params: &ParamMap, min: f32, max: f32) -> ParamMap {
237 params
238 .iter()
239 .map(|(k, v)| (k.clone(), v.clamp(min, max)))
240 .collect()
241}
242
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a `ParamMap` from `(name, value)` pairs.
    fn make_params(pairs: &[(&str, f32)]) -> ParamMap {
        pairs.iter().map(|(k, v)| (k.to_string(), *v)).collect()
    }

    /// Asserts that two floats agree to within `1e-6`, reporting both on
    /// failure at the caller's location.
    #[track_caller]
    fn assert_close(actual: f32, expected: f32) {
        assert!(
            (actual - expected).abs() < 1e-6,
            "Expected {expected}, got {actual}"
        );
    }

    #[test]
    fn test_leaf_evaluate() {
        let params = make_params(&[("height", 1.8), ("weight", 75.0)]);
        let node = BlendNode::leaf("base", params);
        let result = node.evaluate();
        assert_eq!(result.get("height"), Some(&1.8));
        assert_eq!(result.get("weight"), Some(&75.0));
    }

    #[test]
    fn test_lerp_blend_zero() {
        let a = make_params(&[("x", 0.0)]);
        let b = make_params(&[("x", 10.0)]);
        let node = BlendNode::lerp(0.0, BlendNode::leaf("a", a), BlendNode::leaf("b", b));
        assert_close(node.evaluate()["x"], 0.0);
    }

    #[test]
    fn test_lerp_blend_one() {
        let a = make_params(&[("x", 0.0)]);
        let b = make_params(&[("x", 10.0)]);
        let node = BlendNode::lerp(1.0, BlendNode::leaf("a", a), BlendNode::leaf("b", b));
        assert_close(node.evaluate()["x"], 10.0);
    }

    #[test]
    fn test_lerp_blend_half() {
        let a = make_params(&[("x", 0.0)]);
        let b = make_params(&[("x", 10.0)]);
        let node = BlendNode::lerp(0.5, BlendNode::leaf("a", a), BlendNode::leaf("b", b));
        assert_close(node.evaluate()["x"], 5.0);
    }

    #[test]
    fn test_additive_blend() {
        let base = make_params(&[("x", 3.0)]);
        let addon = make_params(&[("x", 2.0)]);
        let node = BlendNode::additive(
            0.5,
            BlendNode::leaf("base", base),
            BlendNode::leaf("addon", addon),
        );
        // 3.0 + 2.0 * 0.5
        assert_close(node.evaluate()["x"], 4.0);
    }

    #[test]
    fn test_override_blend() {
        let a = make_params(&[("x", 1.0)]);
        let b = make_params(&[("x", 9.0)]);
        // weight < 0.5 selects the left map wholesale...
        let node_over_a = BlendNode::Blend {
            mode: BlendMode::Override,
            weight: 0.3,
            left: Box::new(BlendNode::leaf("a", a.clone())),
            right: Box::new(BlendNode::leaf("b", b.clone())),
        };
        // ...and weight >= 0.5 selects the right map.
        let node_over_b = BlendNode::Blend {
            mode: BlendMode::Override,
            weight: 0.7,
            left: Box::new(BlendNode::leaf("a", a)),
            right: Box::new(BlendNode::leaf("b", b)),
        };
        assert_close(node_over_a.evaluate()["x"], 1.0);
        assert_close(node_over_b.evaluate()["x"], 9.0);
    }

    #[test]
    fn test_multiply_blend() {
        let a = make_params(&[("x", 3.0)]);
        let b = make_params(&[("x", 4.0)]);
        let node = BlendNode::Blend {
            mode: BlendMode::Multiply,
            weight: 0.5,
            left: Box::new(BlendNode::leaf("a", a)),
            right: Box::new(BlendNode::leaf("b", b)),
        };
        // Multiply ignores the weight.
        assert_close(node.evaluate()["x"], 12.0);
    }

    #[test]
    fn test_scale_node() {
        let params = make_params(&[("x", 5.0), ("y", 2.0)]);
        let node = BlendNode::scale(3.0, BlendNode::leaf("base", params));
        let result = node.evaluate();
        assert_close(result["x"], 15.0);
        assert_close(result["y"], 6.0);
    }

    #[test]
    fn test_clamp_node() {
        let params = make_params(&[("x", -5.0), ("y", 15.0), ("z", 0.5)]);
        let node = BlendNode::clamp(0.0, 1.0, BlendNode::leaf("base", params));
        let result = node.evaluate();
        assert_close(result["x"], 0.0);
        assert_close(result["y"], 1.0);
        assert_close(result["z"], 0.5);
    }

    #[test]
    fn test_select_node() {
        let c0 = BlendNode::leaf("c0", make_params(&[("v", 1.0)]));
        let c1 = BlendNode::leaf("c1", make_params(&[("v", 2.0)]));
        let c2 = BlendNode::leaf("c2", make_params(&[("v", 3.0)]));
        let node = BlendNode::select(1, vec![c0, c1, c2]);
        assert_close(node.evaluate()["v"], 2.0);

        // Out-of-range index wraps: 4 % 3 == 1.
        let c0b = BlendNode::leaf("c0", make_params(&[("v", 1.0)]));
        let c1b = BlendNode::leaf("c1", make_params(&[("v", 2.0)]));
        let c2b = BlendNode::leaf("c2", make_params(&[("v", 3.0)]));
        let node2 = BlendNode::select(4, vec![c0b, c1b, c2b]);
        assert_close(node2.evaluate()["v"], 2.0);

        // No children: empty result instead of a panic.
        let node3 = BlendNode::select(0, vec![]);
        assert!(node3.evaluate().is_empty());
    }

    #[test]
    fn test_blend_params_missing_key() {
        let a = make_params(&[("x", 4.0)]);
        let b = make_params(&[("y", 6.0)]);
        let result = blend_params(&a, &b, 0.5, &BlendMode::Lerp);
        // Missing keys are read as 0.0 on the side that lacks them.
        assert_close(result["x"], 2.0);
        assert_close(result["y"], 3.0);
    }

    #[test]
    fn test_merge_params() {
        let a = make_params(&[("x", 1.0), ("shared", 10.0)]);
        let b = make_params(&[("y", 2.0), ("shared", 99.0)]);
        let result = merge_params(&a, &b);
        assert_close(result["x"], 1.0);
        assert_close(result["y"], 2.0);
        // `a` wins on conflicting keys.
        assert_close(result["shared"], 10.0);
    }

    #[test]
    fn test_depth() {
        let leaf = BlendNode::leaf("l", make_params(&[("x", 1.0)]));
        assert_eq!(leaf.depth(), 1);

        let leaf2 = BlendNode::leaf("l2", make_params(&[("x", 2.0)]));
        let blend = BlendNode::lerp(0.5, leaf, leaf2);
        assert_eq!(blend.depth(), 2);

        let leaf3 = BlendNode::leaf("l3", make_params(&[("x", 3.0)]));
        let scaled = BlendNode::scale(1.0, leaf3);
        assert_eq!(scaled.depth(), 2);

        let la = BlendNode::leaf("a", make_params(&[("x", 0.0)]));
        let lb = BlendNode::leaf("b", make_params(&[("x", 1.0)]));
        let lc = BlendNode::leaf("c", make_params(&[("x", 2.0)]));
        let inner = BlendNode::lerp(0.5, la, lb);
        let outer = BlendNode::lerp(0.5, inner, lc);
        assert_eq!(outer.depth(), 3);
    }

    #[test]
    fn test_leaf_count() {
        let leaf = BlendNode::leaf("l", make_params(&[("x", 1.0)]));
        assert_eq!(leaf.leaf_count(), 1);

        let la = BlendNode::leaf("a", make_params(&[("x", 0.0)]));
        let lb = BlendNode::leaf("b", make_params(&[("x", 1.0)]));
        let blend = BlendNode::lerp(0.5, la, lb);
        assert_eq!(blend.leaf_count(), 2);

        // Select counts leaves under every child, not only the selected one.
        let c0 = BlendNode::leaf("c0", make_params(&[("v", 1.0)]));
        let c1 = BlendNode::leaf("c1", make_params(&[("v", 2.0)]));
        let c2 = BlendNode::leaf("c2", make_params(&[("v", 3.0)]));
        let sel = BlendNode::select(0, vec![c0, c1, c2]);
        assert_eq!(sel.leaf_count(), 3);
    }

    #[test]
    fn test_nested_blend() {
        let a = BlendNode::leaf("a", make_params(&[("x", 1.0)]));
        let b = BlendNode::leaf("b", make_params(&[("x", 3.0)]));
        let blended = BlendNode::lerp(0.5, a, b);
        let scaled = BlendNode::scale(2.0, blended);
        let clamped = BlendNode::clamp(0.0, 5.0, scaled);

        // lerp(1, 3) = 2, scaled by 2 = 4, which lies inside [0, 5].
        assert_close(clamped.evaluate()["x"], 4.0);
        assert_eq!(clamped.depth(), 4);
        assert_eq!(clamped.leaf_count(), 2);
    }
}