Skip to main content

ai_agents_state/
machine.rs

1use chrono::Utc;
2use parking_lot::RwLock;
3
4use ai_agents_core::{AgentError, Result, StateMachineSnapshot, StateTransitionEvent};
5
6use super::config::{StateConfig, StateDefinition, Transition};
7
8pub struct StateMachine {
9    config: StateConfig,
10    current: RwLock<String>,
11    previous: RwLock<Option<String>>,
12    turn_count: RwLock<u32>,
13    no_transition_count: RwLock<u32>,
14    history: RwLock<Vec<StateTransitionEvent>>,
15}
16
17impl StateMachine {
18    pub fn new(config: StateConfig) -> Result<Self> {
19        config.validate()?;
20        let initial = Self::resolve_initial_state(&config)?;
21        Ok(Self {
22            config,
23            current: RwLock::new(initial),
24            previous: RwLock::new(None),
25            turn_count: RwLock::new(0),
26            no_transition_count: RwLock::new(0),
27            history: RwLock::new(Vec::new()),
28        })
29    }
30
31    fn resolve_initial_state(config: &StateConfig) -> Result<String> {
32        let mut path = config.initial.clone();
33        let mut current_def = config.states.get(&config.initial);
34
35        while let Some(def) = current_def {
36            if let (Some(initial_sub), Some(sub_states)) = (&def.initial, &def.states) {
37                path = format!("{}.{}", path, initial_sub);
38                current_def = sub_states.get(initial_sub);
39            } else {
40                break;
41            }
42        }
43
44        Ok(path)
45    }
46
47    pub fn current(&self) -> String {
48        self.current.read().clone()
49    }
50
51    pub fn previous(&self) -> Option<String> {
52        self.previous.read().clone()
53    }
54
55    pub fn current_definition(&self) -> Option<StateDefinition> {
56        let current = self.current.read();
57        self.config.get_state(&current).cloned()
58    }
59
60    pub fn get_definition(&self, state: &str) -> Option<&StateDefinition> {
61        self.config.get_state(state)
62    }
63
64    pub fn get_parent_definition(&self) -> Option<StateDefinition> {
65        let current = self.current.read();
66        let parts: Vec<&str> = current.split('.').collect();
67        if parts.len() <= 1 {
68            return None;
69        }
70        let parent_path = parts[..parts.len() - 1].join(".");
71        self.config.get_state(&parent_path).cloned()
72    }
73
74    pub fn transition_to(&self, state: &str, reason: &str) -> Result<()> {
75        let current_path = self.current.read().clone();
76        let resolved_path = self.config.resolve_full_path(&current_path, state);
77
78        if self.config.get_state(&resolved_path).is_none() {
79            return Err(AgentError::InvalidSpec(format!(
80                "Unknown state: {} (resolved from {})",
81                resolved_path, state
82            )));
83        }
84
85        let final_path = self.resolve_to_leaf_state(&resolved_path)?;
86
87        let from = {
88            let mut current = self.current.write();
89            let mut previous = self.previous.write();
90            let from = current.clone();
91            *previous = Some(from.clone());
92            *current = final_path.clone();
93            from
94        };
95
96        *self.turn_count.write() = 0;
97        *self.no_transition_count.write() = 0;
98
99        let event = StateTransitionEvent {
100            from,
101            to: final_path,
102            reason: reason.to_string(),
103            timestamp: Utc::now(),
104        };
105        self.history.write().push(event);
106
107        Ok(())
108    }
109
110    fn resolve_to_leaf_state(&self, path: &str) -> Result<String> {
111        let mut current_path = path.to_string();
112
113        loop {
114            let def = self.config.get_state(&current_path).ok_or_else(|| {
115                AgentError::InvalidSpec(format!("State not found: {}", current_path))
116            })?;
117
118            if let (Some(initial_sub), Some(sub_states)) = (&def.initial, &def.states) {
119                if sub_states.contains_key(initial_sub) {
120                    current_path = format!("{}.{}", current_path, initial_sub);
121                    continue;
122                }
123            }
124            break;
125        }
126
127        Ok(current_path)
128    }
129
130    pub fn available_transitions(&self) -> Vec<Transition> {
131        let mut transitions = Vec::new();
132
133        if let Some(def) = self.current_definition() {
134            transitions.extend(def.transitions.clone());
135        }
136
137        transitions.extend(self.config.global_transitions.clone());
138
139        transitions.sort_by(|a, b| b.priority.cmp(&a.priority));
140        transitions
141    }
142
143    pub fn auto_transitions(&self) -> Vec<Transition> {
144        self.available_transitions()
145            .into_iter()
146            .filter(|t| t.auto)
147            .collect()
148    }
149
150    pub fn history(&self) -> Vec<StateTransitionEvent> {
151        self.history.read().clone()
152    }
153
154    pub fn increment_turn(&self) {
155        *self.turn_count.write() += 1;
156    }
157
158    pub fn turn_count(&self) -> u32 {
159        *self.turn_count.read()
160    }
161
162    pub fn increment_no_transition(&self) {
163        *self.no_transition_count.write() += 1;
164    }
165
166    pub fn no_transition_count(&self) -> u32 {
167        *self.no_transition_count.read()
168    }
169
170    pub fn reset_no_transition(&self) {
171        *self.no_transition_count.write() = 0;
172    }
173
174    pub fn check_fallback(&self) -> Option<String> {
175        if let Some(max) = self.config.max_no_transition {
176            if self.no_transition_count() >= max {
177                return self.config.fallback.clone();
178            }
179        }
180        None
181    }
182
183    pub fn reset(&self) {
184        let initial =
185            Self::resolve_initial_state(&self.config).unwrap_or(self.config.initial.clone());
186        *self.current.write() = initial;
187        *self.previous.write() = None;
188        *self.turn_count.write() = 0;
189        *self.no_transition_count.write() = 0;
190        self.history.write().clear();
191    }
192
193    pub fn snapshot(&self) -> StateMachineSnapshot {
194        StateMachineSnapshot {
195            current_state: self.current.read().clone(),
196            previous_state: self.previous.read().clone(),
197            turn_count: *self.turn_count.read(),
198            no_transition_count: *self.no_transition_count.read(),
199            history: self.history.read().clone(),
200        }
201    }
202
203    pub fn restore(&self, snapshot: StateMachineSnapshot) -> Result<()> {
204        if self.config.get_state(&snapshot.current_state).is_none() {
205            return Err(AgentError::InvalidSpec(format!(
206                "Snapshot contains unknown state: {}",
207                snapshot.current_state
208            )));
209        }
210        *self.current.write() = snapshot.current_state;
211        *self.previous.write() = snapshot.previous_state;
212        *self.turn_count.write() = snapshot.turn_count;
213        *self.no_transition_count.write() = snapshot.no_transition_count;
214        *self.history.write() = snapshot.history;
215        Ok(())
216    }
217
218    pub fn config(&self) -> &StateConfig {
219        &self.config
220    }
221
222    pub fn check_timeout(&self) -> Option<String> {
223        let def = self.current_definition()?;
224        let max_turns = def.max_turns?;
225        let timeout_to = def.timeout_to.as_ref()?;
226        if self.turn_count() >= max_turns {
227            let current_path = self.current.read().clone();
228            Some(self.config.resolve_full_path(&current_path, timeout_to))
229        } else {
230            None
231        }
232    }
233
234    pub fn current_depth(&self) -> usize {
235        self.current.read().split('.').count()
236    }
237
238    pub fn is_in_sub_state(&self) -> bool {
239        self.current_depth() > 1
240    }
241
242    pub fn parent_state(&self) -> Option<String> {
243        let current = self.current.read();
244        let parts: Vec<&str> = current.split('.').collect();
245        if parts.len() > 1 {
246            Some(parts[..parts.len() - 1].join("."))
247        } else {
248            None
249        }
250    }
251
252    pub fn root_state(&self) -> String {
253        let current = self.current.read();
254        current.split('.').next().unwrap_or(&current).to_string()
255    }
256
257    /// Check if a transition target is on cooldown (was recently transitioned to).
258    pub fn is_on_cooldown(&self, target: &str, cooldown_turns: u32) -> bool {
259        let history = self.history.read();
260        let total_transitions = history.len();
261        if total_transitions == 0 || cooldown_turns == 0 {
262            return false;
263        }
264        // Look at the last `cooldown_turns` transitions
265        let start = total_transitions.saturating_sub(cooldown_turns as usize);
266        history[start..].iter().any(|e| e.to == target)
267    }
268}
269
#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::HashMap;

    /// State definition that carries nothing but a prompt.
    fn prompt_only(prompt: &'static str) -> StateDefinition {
        StateDefinition {
            prompt: Some(prompt.into()),
            ..Default::default()
        }
    }

    /// Flat fixture: `greeting` (auto transition to `support`), `support`
    /// (times out to `escalation` after 3 turns) and `escalation`.
    fn create_test_config() -> StateConfig {
        let greeting = StateDefinition {
            prompt: Some("Welcome!".into()),
            transitions: vec![Transition {
                to: "support".into(),
                when: "needs help".into(),
                guard: None,
                intent: None,
                auto: true,
                priority: 10,
                cooldown_turns: None,
            }],
            ..Default::default()
        };
        let support = StateDefinition {
            prompt: Some("How can I help?".into()),
            max_turns: Some(3),
            timeout_to: Some("escalation".into()),
            ..Default::default()
        };

        let mut states = HashMap::new();
        states.insert("greeting".into(), greeting);
        states.insert("support".into(), support);
        states.insert("escalation".into(), prompt_only("Escalating..."));

        StateConfig {
            initial: "greeting".into(),
            states,
            global_transitions: vec![],
            fallback: None,
            max_no_transition: None,
            regenerate_on_transition: true,
        }
    }

    /// Nested fixture: `problem_solving` with sub-states `gathering_info`
    /// (initial) and `proposing`, plus a top-level `closing` state.
    fn create_hierarchical_config() -> StateConfig {
        let gathering_info = StateDefinition {
            prompt: Some("Gathering info".into()),
            transitions: vec![Transition {
                to: "proposing".into(),
                when: "understood".into(),
                guard: None,
                intent: None,
                auto: true,
                priority: 0,
                cooldown_turns: None,
            }],
            ..Default::default()
        };
        let proposing = StateDefinition {
            prompt: Some("Proposing solution".into()),
            transitions: vec![Transition {
                to: "^closing".into(),
                when: "resolved".into(),
                guard: None,
                intent: None,
                auto: true,
                priority: 0,
                cooldown_turns: None,
            }],
            ..Default::default()
        };

        let mut sub_states = HashMap::new();
        sub_states.insert("gathering_info".into(), gathering_info);
        sub_states.insert("proposing".into(), proposing);

        let problem_solving = StateDefinition {
            prompt: Some("Problem solving".into()),
            initial: Some("gathering_info".into()),
            states: Some(sub_states),
            ..Default::default()
        };

        let mut states = HashMap::new();
        states.insert("problem_solving".into(), problem_solving);
        states.insert("closing".into(), prompt_only("Thank you"));

        StateConfig {
            initial: "problem_solving".into(),
            states,
            global_transitions: vec![],
            fallback: None,
            max_no_transition: None,
            regenerate_on_transition: true,
        }
    }

    /// Machine built from the flat fixture.
    fn flat_machine() -> StateMachine {
        StateMachine::new(create_test_config()).unwrap()
    }

    /// Machine built from the nested fixture.
    fn nested_machine() -> StateMachine {
        StateMachine::new(create_hierarchical_config()).unwrap()
    }

    #[test]
    fn test_new_state_machine() {
        let sm = flat_machine();
        assert_eq!(sm.current(), "greeting");
        assert!(sm.previous().is_none());
        assert_eq!(sm.turn_count(), 0);
    }

    #[test]
    fn test_transition() {
        let sm = flat_machine();
        sm.transition_to("support", "user asked for help").unwrap();

        assert_eq!(sm.current(), "support");
        assert_eq!(sm.previous(), Some("greeting".into()));
        assert_eq!(sm.history().len(), 1);
    }

    #[test]
    fn test_turn_counting() {
        let sm = flat_machine();
        assert_eq!(sm.turn_count(), 0);

        for _ in 0..2 {
            sm.increment_turn();
        }
        assert_eq!(sm.turn_count(), 2);

        // A transition zeroes the counter.
        sm.transition_to("support", "reason").unwrap();
        assert_eq!(sm.turn_count(), 0);
    }

    #[test]
    fn test_timeout_check() {
        let sm = flat_machine();
        sm.transition_to("support", "needs help").unwrap();
        assert!(sm.check_timeout().is_none());

        for _ in 0..3 {
            sm.increment_turn();
        }
        assert_eq!(sm.check_timeout(), Some("escalation".into()));
    }

    #[test]
    fn test_snapshot_restore() {
        let config = create_test_config();
        let source = StateMachine::new(config.clone()).unwrap();
        source.transition_to("support", "reason").unwrap();
        source.increment_turn();

        let snapshot = source.snapshot();
        assert_eq!(snapshot.current_state, "support");
        assert_eq!(snapshot.turn_count, 1);

        let restored = StateMachine::new(config).unwrap();
        restored.restore(snapshot).unwrap();
        assert_eq!(restored.current(), "support");
        assert_eq!(restored.turn_count(), 1);
    }

    #[test]
    fn test_reset() {
        let sm = flat_machine();
        sm.transition_to("support", "reason").unwrap();
        sm.increment_turn();

        sm.reset();

        assert_eq!(sm.current(), "greeting");
        assert!(sm.previous().is_none());
        assert_eq!(sm.turn_count(), 0);
        assert!(sm.history().is_empty());
    }

    #[test]
    fn test_hierarchical_initial_state() {
        let sm = nested_machine();
        assert_eq!(sm.current(), "problem_solving.gathering_info");
    }

    #[test]
    fn test_hierarchical_transition_sibling() {
        let sm = nested_machine();
        assert_eq!(sm.current(), "problem_solving.gathering_info");

        sm.transition_to("proposing", "understood").unwrap();
        assert_eq!(sm.current(), "problem_solving.proposing");
    }

    #[test]
    fn test_hierarchical_transition_parent() {
        let sm = nested_machine();
        sm.transition_to("proposing", "understood").unwrap();
        sm.transition_to("^closing", "resolved").unwrap();
        assert_eq!(sm.current(), "closing");
    }

    #[test]
    fn test_current_depth() {
        let sm = nested_machine();
        assert_eq!(sm.current_depth(), 2);
        assert!(sm.is_in_sub_state());

        sm.transition_to("^closing", "done").unwrap();
        assert_eq!(sm.current_depth(), 1);
        assert!(!sm.is_in_sub_state());
    }

    #[test]
    fn test_parent_state() {
        let sm = nested_machine();
        assert_eq!(sm.parent_state(), Some("problem_solving".into()));

        sm.transition_to("^closing", "done").unwrap();
        assert!(sm.parent_state().is_none());
    }

    #[test]
    fn test_root_state() {
        let sm = nested_machine();
        assert_eq!(sm.root_state(), "problem_solving");

        sm.transition_to("^closing", "done").unwrap();
        assert_eq!(sm.root_state(), "closing");
    }

    #[test]
    fn test_no_transition_count() {
        let sm = flat_machine();
        assert_eq!(sm.no_transition_count(), 0);

        for _ in 0..2 {
            sm.increment_no_transition();
        }
        assert_eq!(sm.no_transition_count(), 2);

        sm.reset_no_transition();
        assert_eq!(sm.no_transition_count(), 0);
    }

    #[test]
    fn test_fallback() {
        let mut config = create_test_config();
        config.fallback = Some("escalation".into());
        config.max_no_transition = Some(3);

        let sm = StateMachine::new(config).unwrap();
        assert!(sm.check_fallback().is_none());

        for _ in 0..3 {
            sm.increment_no_transition();
        }
        assert_eq!(sm.check_fallback(), Some("escalation".into()));
    }

    #[test]
    fn test_global_transitions() {
        let mut config = create_test_config();
        config.global_transitions = vec![Transition {
            to: "escalation".into(),
            when: "user is angry".into(),
            guard: None,
            intent: None,
            auto: true,
            priority: 100,
            cooldown_turns: None,
        }];

        let sm = StateMachine::new(config).unwrap();
        let transitions = sm.available_transitions();

        let has_global = transitions
            .iter()
            .any(|t| t.to == "escalation" && t.priority == 100);
        assert!(has_global);
        // Highest priority sorts first.
        assert_eq!(transitions[0].to, "escalation");
    }

    #[test]
    fn test_get_parent_definition() {
        let sm = nested_machine();

        let parent = sm.get_parent_definition();
        assert!(parent.is_some());
        assert_eq!(parent.unwrap().prompt, Some("Problem solving".into()));

        sm.transition_to("^closing", "done").unwrap();
        assert!(sm.get_parent_definition().is_none());
    }
}