Skip to main content

ai_agents_state/
machine.rs

1use chrono::Utc;
2use parking_lot::RwLock;
3
4use ai_agents_core::{AgentError, Result, StateMachineSnapshot, StateTransitionEvent};
5
6use super::config::{StateConfig, StateDefinition, Transition};
7
8pub struct StateMachine {
9    config: StateConfig,
10    current: RwLock<String>,
11    previous: RwLock<Option<String>>,
12    turn_count: RwLock<u32>,
13    no_transition_count: RwLock<u32>,
14    history: RwLock<Vec<StateTransitionEvent>>,
15}
16
17impl StateMachine {
18    pub fn new(config: StateConfig) -> Result<Self> {
19        config.validate()?;
20        let initial = Self::resolve_initial_state(&config)?;
21        Ok(Self {
22            config,
23            current: RwLock::new(initial),
24            previous: RwLock::new(None),
25            turn_count: RwLock::new(0),
26            no_transition_count: RwLock::new(0),
27            history: RwLock::new(Vec::new()),
28        })
29    }
30
31    fn resolve_initial_state(config: &StateConfig) -> Result<String> {
32        let mut path = config.initial.clone();
33        let mut current_def = config.states.get(&config.initial);
34
35        while let Some(def) = current_def {
36            if let (Some(initial_sub), Some(sub_states)) = (&def.initial, &def.states) {
37                path = format!("{}.{}", path, initial_sub);
38                current_def = sub_states.get(initial_sub);
39            } else {
40                break;
41            }
42        }
43
44        Ok(path)
45    }
46
47    pub fn current(&self) -> String {
48        self.current.read().clone()
49    }
50
51    pub fn previous(&self) -> Option<String> {
52        self.previous.read().clone()
53    }
54
55    pub fn current_definition(&self) -> Option<StateDefinition> {
56        let current = self.current.read();
57        self.config.get_state(&current).cloned()
58    }
59
60    pub fn get_definition(&self, state: &str) -> Option<&StateDefinition> {
61        self.config.get_state(state)
62    }
63
64    pub fn get_parent_definition(&self) -> Option<StateDefinition> {
65        let current = self.current.read();
66        let parts: Vec<&str> = current.split('.').collect();
67        if parts.len() <= 1 {
68            return None;
69        }
70        let parent_path = parts[..parts.len() - 1].join(".");
71        self.config.get_state(&parent_path).cloned()
72    }
73
74    pub fn transition_to(&self, state: &str, reason: &str) -> Result<()> {
75        let current_path = self.current.read().clone();
76        let resolved_path = self.config.resolve_full_path(&current_path, state);
77
78        if self.config.get_state(&resolved_path).is_none() {
79            return Err(AgentError::InvalidSpec(format!(
80                "Unknown state: {} (resolved from {})",
81                resolved_path, state
82            )));
83        }
84
85        let final_path = self.resolve_to_leaf_state(&resolved_path)?;
86
87        let from = {
88            let mut current = self.current.write();
89            let mut previous = self.previous.write();
90            let from = current.clone();
91            *previous = Some(from.clone());
92            *current = final_path.clone();
93            from
94        };
95
96        *self.turn_count.write() = 0;
97        *self.no_transition_count.write() = 0;
98
99        let event = StateTransitionEvent {
100            from,
101            to: final_path,
102            reason: reason.to_string(),
103            timestamp: Utc::now(),
104        };
105        self.history.write().push(event);
106
107        Ok(())
108    }
109
110    fn resolve_to_leaf_state(&self, path: &str) -> Result<String> {
111        let mut current_path = path.to_string();
112
113        loop {
114            let def = self.config.get_state(&current_path).ok_or_else(|| {
115                AgentError::InvalidSpec(format!("State not found: {}", current_path))
116            })?;
117
118            if let (Some(initial_sub), Some(sub_states)) = (&def.initial, &def.states) {
119                if sub_states.contains_key(initial_sub) {
120                    current_path = format!("{}.{}", current_path, initial_sub);
121                    continue;
122                }
123            }
124            break;
125        }
126
127        Ok(current_path)
128    }
129
130    pub fn available_transitions(&self) -> Vec<Transition> {
131        let mut transitions = Vec::new();
132
133        if let Some(def) = self.current_definition() {
134            transitions.extend(def.transitions.clone());
135        }
136
137        transitions.extend(self.config.global_transitions.clone());
138
139        transitions.sort_by(|a, b| b.priority.cmp(&a.priority));
140        transitions
141    }
142
143    pub fn auto_transitions(&self) -> Vec<Transition> {
144        self.available_transitions()
145            .into_iter()
146            .filter(|t| t.auto)
147            .collect()
148    }
149
150    pub fn history(&self) -> Vec<StateTransitionEvent> {
151        self.history.read().clone()
152    }
153
154    pub fn increment_turn(&self) {
155        *self.turn_count.write() += 1;
156    }
157
158    pub fn turn_count(&self) -> u32 {
159        *self.turn_count.read()
160    }
161
162    pub fn increment_no_transition(&self) {
163        *self.no_transition_count.write() += 1;
164    }
165
166    pub fn no_transition_count(&self) -> u32 {
167        *self.no_transition_count.read()
168    }
169
170    pub fn reset_no_transition(&self) {
171        *self.no_transition_count.write() = 0;
172    }
173
174    pub fn check_fallback(&self) -> Option<String> {
175        if let Some(max) = self.config.max_no_transition {
176            if self.no_transition_count() >= max {
177                return self.config.fallback.clone();
178            }
179        }
180        None
181    }
182
183    pub fn reset(&self) {
184        let initial =
185            Self::resolve_initial_state(&self.config).unwrap_or(self.config.initial.clone());
186        *self.current.write() = initial;
187        *self.previous.write() = None;
188        *self.turn_count.write() = 0;
189        *self.no_transition_count.write() = 0;
190        self.history.write().clear();
191    }
192
193    pub fn snapshot(&self) -> StateMachineSnapshot {
194        StateMachineSnapshot {
195            current_state: self.current.read().clone(),
196            previous_state: self.previous.read().clone(),
197            turn_count: *self.turn_count.read(),
198            no_transition_count: *self.no_transition_count.read(),
199            history: self.history.read().clone(),
200        }
201    }
202
203    pub fn restore(&self, snapshot: StateMachineSnapshot) -> Result<()> {
204        if self.config.get_state(&snapshot.current_state).is_none() {
205            return Err(AgentError::InvalidSpec(format!(
206                "Snapshot contains unknown state: {}",
207                snapshot.current_state
208            )));
209        }
210        *self.current.write() = snapshot.current_state;
211        *self.previous.write() = snapshot.previous_state;
212        *self.turn_count.write() = snapshot.turn_count;
213        *self.no_transition_count.write() = snapshot.no_transition_count;
214        *self.history.write() = snapshot.history;
215        Ok(())
216    }
217
218    pub fn config(&self) -> &StateConfig {
219        &self.config
220    }
221
222    pub fn check_timeout(&self) -> Option<String> {
223        let def = self.current_definition()?;
224        let max_turns = def.max_turns?;
225        let timeout_to = def.timeout_to.as_ref()?;
226        if self.turn_count() >= max_turns {
227            let current_path = self.current.read().clone();
228            Some(self.config.resolve_full_path(&current_path, timeout_to))
229        } else {
230            None
231        }
232    }
233
234    pub fn current_depth(&self) -> usize {
235        self.current.read().split('.').count()
236    }
237
238    pub fn is_in_sub_state(&self) -> bool {
239        self.current_depth() > 1
240    }
241
242    pub fn parent_state(&self) -> Option<String> {
243        let current = self.current.read();
244        let parts: Vec<&str> = current.split('.').collect();
245        if parts.len() > 1 {
246            Some(parts[..parts.len() - 1].join("."))
247        } else {
248            None
249        }
250    }
251
252    pub fn root_state(&self) -> String {
253        let current = self.current.read();
254        current.split('.').next().unwrap_or(&current).to_string()
255    }
256
257    /// Check if a transition target is on cooldown (was recently transitioned to).
258    pub fn is_on_cooldown(&self, target: &str, cooldown_turns: u32) -> bool {
259        let history = self.history.read();
260        let total_transitions = history.len();
261        if total_transitions == 0 || cooldown_turns == 0 {
262            return false;
263        }
264        // Look at the last `cooldown_turns` transitions
265        let start = total_transitions.saturating_sub(cooldown_turns as usize);
266        history[start..].iter().any(|e| e.to == target)
267    }
268}
269
#[cfg(test)]
mod tests {
    use super::super::config::TransitionTiming;
    use super::*;
    use std::collections::HashMap;

    /// Builds an automatic, post-response transition to `to` with priority 0.
    ///
    /// Tests override individual fields with struct-update syntax, e.g.
    /// `Transition { priority: 10, ..auto_transition("a", "b") }`.
    fn auto_transition(to: &str, when: &str) -> Transition {
        Transition {
            to: to.into(),
            when: when.into(),
            guard: None,
            intent: None,
            auto: true,
            priority: 0,
            cooldown_turns: None,
            timing: TransitionTiming::PostResponse,
            requires_response: false,
            run_extractors: false,
        }
    }

    /// Flat config: `greeting` -> `support` (auto, priority 10).
    ///
    /// `support` times out to `escalation` after 3 turns.
    fn create_test_config() -> StateConfig {
        let mut states = HashMap::new();
        states.insert(
            "greeting".into(),
            StateDefinition {
                prompt: Some("Welcome!".into()),
                transitions: vec![Transition {
                    priority: 10,
                    ..auto_transition("support", "needs help")
                }],
                ..Default::default()
            },
        );
        states.insert(
            "support".into(),
            StateDefinition {
                prompt: Some("How can I help?".into()),
                max_turns: Some(3),
                timeout_to: Some("escalation".into()),
                ..Default::default()
            },
        );
        states.insert(
            "escalation".into(),
            StateDefinition {
                prompt: Some("Escalating...".into()),
                ..Default::default()
            },
        );
        StateConfig {
            initial: "greeting".into(),
            states,
            global_transitions: vec![],
            fallback: None,
            max_no_transition: None,
            regenerate_on_transition: true,
        }
    }

    /// Hierarchical config with two top-level states.
    ///
    /// `problem_solving` has sub-states `gathering_info` (its initial
    /// sub-state) and `proposing`; `closing` is top-level.
    fn create_hierarchical_config() -> StateConfig {
        let mut sub_states = HashMap::new();
        sub_states.insert(
            "gathering_info".into(),
            StateDefinition {
                prompt: Some("Gathering info".into()),
                transitions: vec![auto_transition("proposing", "understood")],
                ..Default::default()
            },
        );
        sub_states.insert(
            "proposing".into(),
            StateDefinition {
                prompt: Some("Proposing solution".into()),
                transitions: vec![auto_transition("^closing", "resolved")],
                ..Default::default()
            },
        );

        let mut states = HashMap::new();
        states.insert(
            "problem_solving".into(),
            StateDefinition {
                prompt: Some("Problem solving".into()),
                initial: Some("gathering_info".into()),
                states: Some(sub_states),
                ..Default::default()
            },
        );
        states.insert(
            "closing".into(),
            StateDefinition {
                prompt: Some("Thank you".into()),
                ..Default::default()
            },
        );

        StateConfig {
            initial: "problem_solving".into(),
            states,
            global_transitions: vec![],
            fallback: None,
            max_no_transition: None,
            regenerate_on_transition: true,
        }
    }

    #[test]
    fn test_new_state_machine() {
        let config = create_test_config();
        let sm = StateMachine::new(config).unwrap();
        assert_eq!(sm.current(), "greeting");
        assert!(sm.previous().is_none());
        assert_eq!(sm.turn_count(), 0);
    }

    #[test]
    fn test_transition() {
        let config = create_test_config();
        let sm = StateMachine::new(config).unwrap();
        sm.transition_to("support", "user asked for help").unwrap();
        assert_eq!(sm.current(), "support");
        assert_eq!(sm.previous(), Some("greeting".into()));

        let history = sm.history();
        assert_eq!(history.len(), 1);
        assert_eq!(history[0].from, "greeting");
        assert_eq!(history[0].to, "support");
        assert_eq!(history[0].reason, "user asked for help");
    }

    #[test]
    fn test_transition_to_unknown_state_is_rejected() {
        let sm = StateMachine::new(create_test_config()).unwrap();
        assert!(sm.transition_to("nowhere", "bogus").is_err());
        // A failed transition must leave the machine untouched.
        assert_eq!(sm.current(), "greeting");
        assert!(sm.previous().is_none());
        assert!(sm.history().is_empty());
    }

    #[test]
    fn test_turn_counting() {
        let config = create_test_config();
        let sm = StateMachine::new(config).unwrap();
        assert_eq!(sm.turn_count(), 0);
        sm.increment_turn();
        sm.increment_turn();
        assert_eq!(sm.turn_count(), 2);
        sm.transition_to("support", "reason").unwrap();
        assert_eq!(sm.turn_count(), 0);
    }

    #[test]
    fn test_timeout_check() {
        let config = create_test_config();
        let sm = StateMachine::new(config).unwrap();
        sm.transition_to("support", "needs help").unwrap();
        assert!(sm.check_timeout().is_none());
        sm.increment_turn();
        sm.increment_turn();
        sm.increment_turn();
        assert_eq!(sm.check_timeout(), Some("escalation".into()));
    }

    #[test]
    fn test_snapshot_restore() {
        let config = create_test_config();
        let sm = StateMachine::new(config.clone()).unwrap();
        sm.transition_to("support", "reason").unwrap();
        sm.increment_turn();

        let snapshot = sm.snapshot();
        assert_eq!(snapshot.current_state, "support");
        assert_eq!(snapshot.turn_count, 1);

        let sm2 = StateMachine::new(config).unwrap();
        sm2.restore(snapshot).unwrap();
        assert_eq!(sm2.current(), "support");
        assert_eq!(sm2.turn_count(), 1);
    }

    #[test]
    fn test_restore_rejects_unknown_state() {
        let sm = StateMachine::new(create_test_config()).unwrap();
        let mut snapshot = sm.snapshot();
        snapshot.current_state = "nowhere".into();

        assert!(sm.restore(snapshot).is_err());
        assert_eq!(sm.current(), "greeting");
    }

    #[test]
    fn test_reset() {
        let config = create_test_config();
        let sm = StateMachine::new(config).unwrap();
        sm.transition_to("support", "reason").unwrap();
        sm.increment_turn();
        sm.reset();
        assert_eq!(sm.current(), "greeting");
        assert!(sm.previous().is_none());
        assert_eq!(sm.turn_count(), 0);
        assert!(sm.history().is_empty());
    }

    #[test]
    fn test_reset_hierarchical_returns_to_initial_leaf() {
        let sm = StateMachine::new(create_hierarchical_config()).unwrap();
        sm.transition_to("^closing", "done").unwrap();
        sm.reset();
        assert_eq!(sm.current(), "problem_solving.gathering_info");
    }

    #[test]
    fn test_hierarchical_initial_state() {
        let config = create_hierarchical_config();
        let sm = StateMachine::new(config).unwrap();
        assert_eq!(sm.current(), "problem_solving.gathering_info");
    }

    #[test]
    fn test_hierarchical_transition_sibling() {
        let config = create_hierarchical_config();
        let sm = StateMachine::new(config).unwrap();
        assert_eq!(sm.current(), "problem_solving.gathering_info");

        sm.transition_to("proposing", "understood").unwrap();
        assert_eq!(sm.current(), "problem_solving.proposing");
    }

    #[test]
    fn test_hierarchical_transition_parent() {
        let config = create_hierarchical_config();
        let sm = StateMachine::new(config).unwrap();
        sm.transition_to("proposing", "understood").unwrap();
        sm.transition_to("^closing", "resolved").unwrap();
        assert_eq!(sm.current(), "closing");
    }

    #[test]
    fn test_current_depth() {
        let config = create_hierarchical_config();
        let sm = StateMachine::new(config).unwrap();
        assert_eq!(sm.current_depth(), 2);
        assert!(sm.is_in_sub_state());

        sm.transition_to("^closing", "done").unwrap();
        assert_eq!(sm.current_depth(), 1);
        assert!(!sm.is_in_sub_state());
    }

    #[test]
    fn test_parent_state() {
        let config = create_hierarchical_config();
        let sm = StateMachine::new(config).unwrap();
        assert_eq!(sm.parent_state(), Some("problem_solving".into()));

        sm.transition_to("^closing", "done").unwrap();
        assert!(sm.parent_state().is_none());
    }

    #[test]
    fn test_root_state() {
        let config = create_hierarchical_config();
        let sm = StateMachine::new(config).unwrap();
        assert_eq!(sm.root_state(), "problem_solving");

        sm.transition_to("^closing", "done").unwrap();
        assert_eq!(sm.root_state(), "closing");
    }

    #[test]
    fn test_no_transition_count() {
        let config = create_test_config();
        let sm = StateMachine::new(config).unwrap();

        assert_eq!(sm.no_transition_count(), 0);
        sm.increment_no_transition();
        sm.increment_no_transition();
        assert_eq!(sm.no_transition_count(), 2);

        sm.reset_no_transition();
        assert_eq!(sm.no_transition_count(), 0);
    }

    #[test]
    fn test_fallback() {
        let mut config = create_test_config();
        config.fallback = Some("escalation".into());
        config.max_no_transition = Some(3);

        let sm = StateMachine::new(config).unwrap();
        assert!(sm.check_fallback().is_none());

        sm.increment_no_transition();
        sm.increment_no_transition();
        sm.increment_no_transition();
        assert_eq!(sm.check_fallback(), Some("escalation".into()));
    }

    #[test]
    fn test_global_transitions() {
        let mut config = create_test_config();
        config.global_transitions = vec![Transition {
            priority: 100,
            ..auto_transition("escalation", "user is angry")
        }];

        let sm = StateMachine::new(config).unwrap();
        let transitions = sm.available_transitions();

        assert!(
            transitions
                .iter()
                .any(|t| t.to == "escalation" && t.priority == 100)
        );
        assert_eq!(transitions[0].to, "escalation");
    }

    #[test]
    fn test_auto_transitions() {
        let mut config = create_test_config();
        config.global_transitions = vec![Transition {
            auto: false,
            ..auto_transition("escalation", "user asks for a manager")
        }];
        let sm = StateMachine::new(config).unwrap();

        assert_eq!(sm.available_transitions().len(), 2);
        let auto = sm.auto_transitions();
        assert_eq!(auto.len(), 1);
        assert_eq!(auto[0].to, "support");
    }

    #[test]
    fn test_cooldown() {
        let sm = StateMachine::new(create_test_config()).unwrap();
        // Empty history: nothing is on cooldown.
        assert!(!sm.is_on_cooldown("support", 3));

        sm.transition_to("support", "needs help").unwrap();
        assert!(sm.is_on_cooldown("support", 1));
        assert!(!sm.is_on_cooldown("support", 0));
        assert!(!sm.is_on_cooldown("escalation", 3));

        // The window counts transitions, so "support" drops out of a window of 1.
        sm.transition_to("escalation", "timeout").unwrap();
        assert!(!sm.is_on_cooldown("support", 1));
        assert!(sm.is_on_cooldown("support", 2));
    }

    #[test]
    fn test_get_parent_definition() {
        let config = create_hierarchical_config();
        let sm = StateMachine::new(config).unwrap();

        let parent = sm.get_parent_definition();
        assert!(parent.is_some());
        assert_eq!(parent.unwrap().prompt, Some("Problem solving".into()));

        sm.transition_to("^closing", "done").unwrap();
        assert!(sm.get_parent_definition().is_none());
    }
}