Skip to main content

oxihuman_morph/
pose_graph.rs

1// Copyright (C) 2026 COOLJAPAN OU (Team KitaSan)
2// SPDX-License-Identifier: Apache-2.0
3
4#![allow(dead_code)]
5
6use std::collections::{HashMap, VecDeque};
7
/// Named pose parameters: maps a parameter name to its `f32` value.
pub type PoseParams = HashMap<String, f32>;
9
/// Easing function for transitions.
///
/// Selects the curve that `apply_easing` applies to a transition's
/// normalized progress `t`.
#[derive(Clone, Debug)]
pub enum Easing {
    /// Progress is used unchanged (`t`).
    Linear,
    /// Quadratic ease-in (`t * t`).
    EaseIn,
    /// Quadratic ease-out (`1 - (1 - t)^2`).
    EaseOut,
    /// Cubic smoothstep (`t^2 * (3 - 2t)`).
    EaseInOut,
    /// "Spring" curve (`1 - (1 - t)^3 * (1 + 3t)`).
    ///
    /// NOTE: despite the name, the formula as implemented is monotonic on
    /// `[0, 1]` and never exceeds 1, so it does not overshoot; it simply runs
    /// ahead of linear (0.6875 at `t = 0.5`).
    Spring,
}
19
20/// A transition between two poses
21#[derive(Clone, Debug)]
22pub struct PoseTransition {
23    pub from: String,
24    pub to: String,
25    pub duration: f32,
26    pub easing: Easing,
27    /// Condition that triggers this transition (parameter name + threshold)
28    pub trigger: Option<(String, f32)>,
29}
30
31impl PoseTransition {
32    pub fn new(from: impl Into<String>, to: impl Into<String>, duration: f32) -> Self {
33        Self {
34            from: from.into(),
35            to: to.into(),
36            duration,
37            easing: Easing::Linear,
38            trigger: None,
39        }
40    }
41
42    pub fn with_easing(mut self, easing: Easing) -> Self {
43        self.easing = easing;
44        self
45    }
46
47    pub fn with_trigger(mut self, param: impl Into<String>, threshold: f32) -> Self {
48        self.trigger = Some((param.into(), threshold));
49        self
50    }
51
52    /// Evaluate easing at t in `[0,1]`
53    pub fn ease(easing: &Easing, t: f32) -> f32 {
54        apply_easing(easing, t)
55    }
56}
57
58/// A node in the pose graph
59#[derive(Clone, Debug)]
60pub struct PoseNode {
61    pub name: String,
62    pub params: PoseParams,
63    pub loop_animation: bool,
64}
65
66impl PoseNode {
67    pub fn new(name: impl Into<String>, params: PoseParams) -> Self {
68        Self {
69            name: name.into(),
70            params,
71            loop_animation: false,
72        }
73    }
74}
75
76/// Pose graph state machine
77pub struct PoseGraph {
78    nodes: HashMap<String, PoseNode>,
79    transitions: Vec<PoseTransition>,
80    current_state: String,
81    target_state: Option<String>,
82    transition_progress: f32,
83    transition_duration: f32,
84    active_transition: Option<usize>,
85}
86
87impl PoseGraph {
88    pub fn new(initial_state: &str, initial_params: PoseParams) -> Self {
89        let node = PoseNode::new(initial_state, initial_params);
90        let mut nodes = HashMap::new();
91        nodes.insert(initial_state.to_string(), node);
92        Self {
93            nodes,
94            transitions: Vec::new(),
95            current_state: initial_state.to_string(),
96            target_state: None,
97            transition_progress: 0.0,
98            transition_duration: 0.0,
99            active_transition: None,
100        }
101    }
102
103    pub fn add_node(&mut self, node: PoseNode) {
104        self.nodes.insert(node.name.clone(), node);
105    }
106
107    pub fn add_transition(&mut self, transition: PoseTransition) {
108        self.transitions.push(transition);
109    }
110
111    pub fn node_count(&self) -> usize {
112        self.nodes.len()
113    }
114
115    pub fn transition_count(&self) -> usize {
116        self.transitions.len()
117    }
118
119    pub fn current_state(&self) -> &str {
120        &self.current_state
121    }
122
123    pub fn is_transitioning(&self) -> bool {
124        self.target_state.is_some()
125    }
126
127    pub fn transition_progress(&self) -> f32 {
128        self.transition_progress
129    }
130
131    /// Trigger a transition to the named state.
132    /// Returns false if no transition is defined or already in that state.
133    pub fn go_to(&mut self, state: &str) -> bool {
134        if self.current_state == state {
135            return false;
136        }
137        // Find a transition from current_state to state
138        let found = self
139            .transitions
140            .iter()
141            .enumerate()
142            .find(|(_, t)| t.from == self.current_state && t.to == state);
143
144        if let Some((idx, t)) = found {
145            let duration = t.duration;
146            self.target_state = Some(state.to_string());
147            self.transition_duration = duration;
148            self.transition_progress = 0.0;
149            self.active_transition = Some(idx);
150            true
151        } else {
152            false
153        }
154    }
155
156    /// Update the state machine by dt seconds
157    pub fn update(&mut self, dt: f32) {
158        if self.target_state.is_none() {
159            return;
160        }
161        let dur = if self.transition_duration > 0.0 {
162            self.transition_duration
163        } else {
164            1.0
165        };
166        self.transition_progress += dt / dur;
167        if self.transition_progress >= 1.0 {
168            // Complete the transition
169            if let Some(target) = self.target_state.take() {
170                self.current_state = target;
171            }
172            self.transition_progress = 0.0;
173            self.transition_duration = 0.0;
174            self.active_transition = None;
175        }
176    }
177
178    /// Check and auto-trigger transitions based on param values
179    pub fn check_triggers(&mut self, params: &PoseParams) {
180        if self.is_transitioning() {
181            return;
182        }
183        // Collect candidates first to avoid borrow conflict
184        let candidates: Vec<String> = self
185            .transitions
186            .iter()
187            .filter(|t| t.from == self.current_state)
188            .filter_map(|t| {
189                if let Some((ref param_name, threshold)) = t.trigger {
190                    if let Some(&val) = params.get(param_name) {
191                        if val >= threshold {
192                            return Some(t.to.clone());
193                        }
194                    }
195                }
196                None
197            })
198            .collect();
199
200        if let Some(target) = candidates.into_iter().next() {
201            self.go_to(&target);
202        }
203    }
204
205    /// Get current interpolated parameter values
206    pub fn evaluate(&self) -> PoseParams {
207        let current_node = match self.nodes.get(&self.current_state) {
208            Some(n) => n,
209            None => return PoseParams::new(),
210        };
211
212        if !self.is_transitioning() {
213            return current_node.params.clone();
214        }
215
216        let target_name = match &self.target_state {
217            Some(s) => s,
218            None => return current_node.params.clone(),
219        };
220
221        let target_node = match self.nodes.get(target_name) {
222            Some(n) => n,
223            None => return current_node.params.clone(),
224        };
225
226        // Get easing factor
227        let eased_t = if let Some(idx) = self.active_transition {
228            if let Some(trans) = self.transitions.get(idx) {
229                apply_easing(&trans.easing, self.transition_progress.clamp(0.0, 1.0))
230            } else {
231                self.transition_progress.clamp(0.0, 1.0)
232            }
233        } else {
234            self.transition_progress.clamp(0.0, 1.0)
235        };
236
237        // Lerp between current and target params
238        let mut result = current_node.params.clone();
239        // Add keys from target that might not be in current
240        for (k, &tv) in &target_node.params {
241            let cv = current_node.params.get(k).copied().unwrap_or(0.0);
242            result.insert(k.clone(), cv + eased_t * (tv - cv));
243        }
244        // Keys only in current stay as-is (already in result, lerp toward 0 is not desired — keep them)
245        result
246    }
247
248    /// Get all reachable states from current via BFS
249    pub fn reachable_states(&self) -> Vec<&str> {
250        let mut visited: Vec<&str> = Vec::new();
251        let mut queue: VecDeque<&str> = VecDeque::new();
252        queue.push_back(&self.current_state);
253
254        while let Some(state) = queue.pop_front() {
255            if visited.contains(&state) {
256                continue;
257            }
258            visited.push(state);
259            for t in &self.transitions {
260                if t.from == state && !visited.contains(&t.to.as_str()) {
261                    queue.push_back(&t.to);
262                }
263            }
264        }
265
266        // Remove the current state itself from results (reachable = others)
267        visited
268            .into_iter()
269            .filter(|&s| s != self.current_state)
270            .collect()
271    }
272}
273
274/// Apply easing function to t
275pub fn apply_easing(easing: &Easing, t: f32) -> f32 {
276    match easing {
277        Easing::Linear => t,
278        Easing::EaseIn => t * t,
279        Easing::EaseOut => 1.0 - (1.0 - t) * (1.0 - t),
280        Easing::EaseInOut => t * t * (3.0 - 2.0 * t),
281        Easing::Spring => 1.0 - (1.0 - t).powi(3) * (1.0 + 3.0 * t),
282    }
283}
284
// Unit tests for the easing curves and the `PoseGraph` state machine.
#[cfg(test)]
mod tests {
    use super::*;

    /// Float comparison with an absolute tolerance of 1e-5.
    fn approx_eq(a: f32, b: f32) -> bool {
        (a - b).abs() < 1e-5
    }

    /// Builds a `PoseParams` map from `(name, value)` pairs.
    fn make_params(pairs: &[(&str, f32)]) -> PoseParams {
        pairs.iter().map(|(k, v)| (k.to_string(), *v)).collect()
    }

    // Linear easing is the identity.
    #[test]
    fn test_ease_linear() {
        assert!(approx_eq(apply_easing(&Easing::Linear, 0.0), 0.0));
        assert!(approx_eq(apply_easing(&Easing::Linear, 0.5), 0.5));
        assert!(approx_eq(apply_easing(&Easing::Linear, 1.0), 1.0));
    }

    // Quadratic ease-in: 0.5^2 = 0.25 at the midpoint.
    #[test]
    fn test_ease_ease_in() {
        assert!(approx_eq(apply_easing(&Easing::EaseIn, 0.0), 0.0));
        assert!(approx_eq(apply_easing(&Easing::EaseIn, 0.5), 0.25));
        assert!(approx_eq(apply_easing(&Easing::EaseIn, 1.0), 1.0));
    }

    // Quadratic ease-out: 1 - 0.5^2 = 0.75 at the midpoint.
    #[test]
    fn test_ease_ease_out() {
        assert!(approx_eq(apply_easing(&Easing::EaseOut, 0.0), 0.0));
        assert!(approx_eq(apply_easing(&Easing::EaseOut, 0.5), 0.75));
        assert!(approx_eq(apply_easing(&Easing::EaseOut, 1.0), 1.0));
    }

    #[test]
    fn test_ease_ease_in_out() {
        assert!(approx_eq(apply_easing(&Easing::EaseInOut, 0.0), 0.0));
        // smoothstep at 0.5: 0.5*0.5*(3-1) = 0.5
        assert!(approx_eq(apply_easing(&Easing::EaseInOut, 0.5), 0.5));
        assert!(approx_eq(apply_easing(&Easing::EaseInOut, 1.0), 1.0));
    }

    #[test]
    fn test_ease_spring() {
        // At t=0: 1 - 1^3 * (1+0) = 0
        assert!(approx_eq(apply_easing(&Easing::Spring, 0.0), 0.0));
        // At t=1: 1 - 0^3 * (1+3) = 1
        assert!(approx_eq(apply_easing(&Easing::Spring, 1.0), 1.0));
        // At t=0.5: 1 - 0.5^3 * (1 + 1.5) = 1 - 0.125*2.5 = 1 - 0.3125 = 0.6875
        assert!(approx_eq(apply_easing(&Easing::Spring, 0.5), 0.6875));
    }

    // A fresh graph holds only its initial node and is idle.
    #[test]
    fn test_pose_graph_new() {
        let params = make_params(&[("weight", 0.0)]);
        let graph = PoseGraph::new("idle", params);
        assert_eq!(graph.current_state(), "idle");
        assert_eq!(graph.node_count(), 1);
        assert_eq!(graph.transition_count(), 0);
        assert!(!graph.is_transitioning());
        assert!(approx_eq(graph.transition_progress(), 0.0));
    }

    #[test]
    fn test_add_node() {
        let params = make_params(&[("weight", 0.0)]);
        let mut graph = PoseGraph::new("idle", params);
        let walk_params = make_params(&[("weight", 1.0)]);
        graph.add_node(PoseNode::new("walk", walk_params));
        assert_eq!(graph.node_count(), 2);
    }

    #[test]
    fn test_add_transition() {
        let params = make_params(&[("weight", 0.0)]);
        let mut graph = PoseGraph::new("idle", params);
        let walk_params = make_params(&[("weight", 1.0)]);
        graph.add_node(PoseNode::new("walk", walk_params));
        graph.add_transition(PoseTransition::new("idle", "walk", 0.5));
        assert_eq!(graph.transition_count(), 1);
    }

    // Starting a transition does not change the current state until it completes.
    #[test]
    fn test_go_to_valid() {
        let params = make_params(&[("weight", 0.0)]);
        let mut graph = PoseGraph::new("idle", params);
        let walk_params = make_params(&[("weight", 1.0)]);
        graph.add_node(PoseNode::new("walk", walk_params));
        graph.add_transition(PoseTransition::new("idle", "walk", 0.5));

        let result = graph.go_to("walk");
        assert!(result);
        assert!(graph.is_transitioning());
        assert_eq!(graph.current_state(), "idle");
    }

    #[test]
    fn test_go_to_invalid() {
        let params = make_params(&[("weight", 0.0)]);
        let mut graph = PoseGraph::new("idle", params);
        // No transition defined
        let result = graph.go_to("run");
        assert!(!result);
        assert!(!graph.is_transitioning());
    }

    #[test]
    fn test_update_completes_transition() {
        let params = make_params(&[("weight", 0.0)]);
        let mut graph = PoseGraph::new("idle", params);
        let walk_params = make_params(&[("weight", 1.0)]);
        graph.add_node(PoseNode::new("walk", walk_params));
        graph.add_transition(PoseTransition::new("idle", "walk", 1.0));

        graph.go_to("walk");
        assert!(graph.is_transitioning());

        // Advance past the end
        graph.update(1.5);
        assert!(!graph.is_transitioning());
        assert_eq!(graph.current_state(), "walk");
    }

    // With no transition in flight, evaluate returns the current node's params.
    #[test]
    fn test_evaluate_no_transition() {
        let params = make_params(&[("weight", 0.5), ("height", 1.8)]);
        let graph = PoseGraph::new("idle", params.clone());
        let evaluated = graph.evaluate();
        assert!(approx_eq(
            *evaluated.get("weight").expect("should succeed"),
            0.5
        ));
        assert!(approx_eq(
            *evaluated.get("height").expect("should succeed"),
            1.8
        ));
    }

    #[test]
    fn test_evaluate_mid_transition() {
        let idle_params = make_params(&[("weight", 0.0)]);
        let mut graph = PoseGraph::new("idle", idle_params);
        let walk_params = make_params(&[("weight", 1.0)]);
        graph.add_node(PoseNode::new("walk", walk_params));
        graph.add_transition(PoseTransition::new("idle", "walk", 1.0).with_easing(Easing::Linear));

        graph.go_to("walk");
        // Advance halfway
        graph.update(0.5);
        assert!(graph.is_transitioning());

        let evaluated = graph.evaluate();
        let w = *evaluated.get("weight").expect("should succeed");
        // At t=0.5 linear: weight should be ~0.5
        assert!(approx_eq(w, 0.5));
    }

    #[test]
    fn test_check_triggers() {
        let idle_params = make_params(&[("speed", 0.0)]);
        let mut graph = PoseGraph::new("idle", idle_params);
        let walk_params = make_params(&[("speed", 1.0)]);
        graph.add_node(PoseNode::new("walk", walk_params));
        graph.add_transition(PoseTransition::new("idle", "walk", 0.5).with_trigger("speed", 0.5));

        // Trigger with speed below threshold — should NOT transition
        let low_params = make_params(&[("speed", 0.3)]);
        graph.check_triggers(&low_params);
        assert!(!graph.is_transitioning());

        // Trigger with speed above threshold — should transition
        let high_params = make_params(&[("speed", 0.8)]);
        graph.check_triggers(&high_params);
        assert!(graph.is_transitioning());
    }

    // Reachability is transitive (idle -> walk -> run) and excludes the start state.
    #[test]
    fn test_reachable_states() {
        let idle_params = make_params(&[("w", 0.0)]);
        let mut graph = PoseGraph::new("idle", idle_params);
        graph.add_node(PoseNode::new("walk", make_params(&[("w", 1.0)])));
        graph.add_node(PoseNode::new("run", make_params(&[("w", 2.0)])));
        graph.add_transition(PoseTransition::new("idle", "walk", 0.3));
        graph.add_transition(PoseTransition::new("walk", "run", 0.3));

        let reachable = graph.reachable_states();
        assert!(reachable.contains(&"walk"));
        assert!(reachable.contains(&"run"));
        assert!(!reachable.contains(&"idle"));
    }
}
474}