1use chrono::Utc;
2use parking_lot::RwLock;
3
4use ai_agents_core::{AgentError, Result, StateMachineSnapshot, StateTransitionEvent};
5
6use super::config::{StateConfig, StateDefinition, Transition};
7
8pub struct StateMachine {
9 config: StateConfig,
10 current: RwLock<String>,
11 previous: RwLock<Option<String>>,
12 turn_count: RwLock<u32>,
13 no_transition_count: RwLock<u32>,
14 history: RwLock<Vec<StateTransitionEvent>>,
15}
16
17impl StateMachine {
18 pub fn new(config: StateConfig) -> Result<Self> {
19 config.validate()?;
20 let initial = Self::resolve_initial_state(&config)?;
21 Ok(Self {
22 config,
23 current: RwLock::new(initial),
24 previous: RwLock::new(None),
25 turn_count: RwLock::new(0),
26 no_transition_count: RwLock::new(0),
27 history: RwLock::new(Vec::new()),
28 })
29 }
30
31 fn resolve_initial_state(config: &StateConfig) -> Result<String> {
32 let mut path = config.initial.clone();
33 let mut current_def = config.states.get(&config.initial);
34
35 while let Some(def) = current_def {
36 if let (Some(initial_sub), Some(sub_states)) = (&def.initial, &def.states) {
37 path = format!("{}.{}", path, initial_sub);
38 current_def = sub_states.get(initial_sub);
39 } else {
40 break;
41 }
42 }
43
44 Ok(path)
45 }
46
47 pub fn current(&self) -> String {
48 self.current.read().clone()
49 }
50
51 pub fn previous(&self) -> Option<String> {
52 self.previous.read().clone()
53 }
54
55 pub fn current_definition(&self) -> Option<StateDefinition> {
56 let current = self.current.read();
57 self.config.get_state(¤t).cloned()
58 }
59
60 pub fn get_definition(&self, state: &str) -> Option<&StateDefinition> {
61 self.config.get_state(state)
62 }
63
64 pub fn get_parent_definition(&self) -> Option<StateDefinition> {
65 let current = self.current.read();
66 let parts: Vec<&str> = current.split('.').collect();
67 if parts.len() <= 1 {
68 return None;
69 }
70 let parent_path = parts[..parts.len() - 1].join(".");
71 self.config.get_state(&parent_path).cloned()
72 }
73
74 pub fn transition_to(&self, state: &str, reason: &str) -> Result<()> {
75 let current_path = self.current.read().clone();
76 let resolved_path = self.config.resolve_full_path(¤t_path, state);
77
78 if self.config.get_state(&resolved_path).is_none() {
79 return Err(AgentError::InvalidSpec(format!(
80 "Unknown state: {} (resolved from {})",
81 resolved_path, state
82 )));
83 }
84
85 let final_path = self.resolve_to_leaf_state(&resolved_path)?;
86
87 let from = {
88 let mut current = self.current.write();
89 let mut previous = self.previous.write();
90 let from = current.clone();
91 *previous = Some(from.clone());
92 *current = final_path.clone();
93 from
94 };
95
96 *self.turn_count.write() = 0;
97 *self.no_transition_count.write() = 0;
98
99 let event = StateTransitionEvent {
100 from,
101 to: final_path,
102 reason: reason.to_string(),
103 timestamp: Utc::now(),
104 };
105 self.history.write().push(event);
106
107 Ok(())
108 }
109
110 fn resolve_to_leaf_state(&self, path: &str) -> Result<String> {
111 let mut current_path = path.to_string();
112
113 loop {
114 let def = self.config.get_state(¤t_path).ok_or_else(|| {
115 AgentError::InvalidSpec(format!("State not found: {}", current_path))
116 })?;
117
118 if let (Some(initial_sub), Some(sub_states)) = (&def.initial, &def.states) {
119 if sub_states.contains_key(initial_sub) {
120 current_path = format!("{}.{}", current_path, initial_sub);
121 continue;
122 }
123 }
124 break;
125 }
126
127 Ok(current_path)
128 }
129
130 pub fn available_transitions(&self) -> Vec<Transition> {
131 let mut transitions = Vec::new();
132
133 if let Some(def) = self.current_definition() {
134 transitions.extend(def.transitions.clone());
135 }
136
137 transitions.extend(self.config.global_transitions.clone());
138
139 transitions.sort_by(|a, b| b.priority.cmp(&a.priority));
140 transitions
141 }
142
143 pub fn auto_transitions(&self) -> Vec<Transition> {
144 self.available_transitions()
145 .into_iter()
146 .filter(|t| t.auto)
147 .collect()
148 }
149
150 pub fn history(&self) -> Vec<StateTransitionEvent> {
151 self.history.read().clone()
152 }
153
154 pub fn increment_turn(&self) {
155 *self.turn_count.write() += 1;
156 }
157
158 pub fn turn_count(&self) -> u32 {
159 *self.turn_count.read()
160 }
161
162 pub fn increment_no_transition(&self) {
163 *self.no_transition_count.write() += 1;
164 }
165
166 pub fn no_transition_count(&self) -> u32 {
167 *self.no_transition_count.read()
168 }
169
170 pub fn reset_no_transition(&self) {
171 *self.no_transition_count.write() = 0;
172 }
173
174 pub fn check_fallback(&self) -> Option<String> {
175 if let Some(max) = self.config.max_no_transition {
176 if self.no_transition_count() >= max {
177 return self.config.fallback.clone();
178 }
179 }
180 None
181 }
182
183 pub fn reset(&self) {
184 let initial =
185 Self::resolve_initial_state(&self.config).unwrap_or(self.config.initial.clone());
186 *self.current.write() = initial;
187 *self.previous.write() = None;
188 *self.turn_count.write() = 0;
189 *self.no_transition_count.write() = 0;
190 self.history.write().clear();
191 }
192
193 pub fn snapshot(&self) -> StateMachineSnapshot {
194 StateMachineSnapshot {
195 current_state: self.current.read().clone(),
196 previous_state: self.previous.read().clone(),
197 turn_count: *self.turn_count.read(),
198 no_transition_count: *self.no_transition_count.read(),
199 history: self.history.read().clone(),
200 }
201 }
202
203 pub fn restore(&self, snapshot: StateMachineSnapshot) -> Result<()> {
204 if self.config.get_state(&snapshot.current_state).is_none() {
205 return Err(AgentError::InvalidSpec(format!(
206 "Snapshot contains unknown state: {}",
207 snapshot.current_state
208 )));
209 }
210 *self.current.write() = snapshot.current_state;
211 *self.previous.write() = snapshot.previous_state;
212 *self.turn_count.write() = snapshot.turn_count;
213 *self.no_transition_count.write() = snapshot.no_transition_count;
214 *self.history.write() = snapshot.history;
215 Ok(())
216 }
217
218 pub fn config(&self) -> &StateConfig {
219 &self.config
220 }
221
222 pub fn check_timeout(&self) -> Option<String> {
223 let def = self.current_definition()?;
224 let max_turns = def.max_turns?;
225 let timeout_to = def.timeout_to.as_ref()?;
226 if self.turn_count() >= max_turns {
227 let current_path = self.current.read().clone();
228 Some(self.config.resolve_full_path(¤t_path, timeout_to))
229 } else {
230 None
231 }
232 }
233
234 pub fn current_depth(&self) -> usize {
235 self.current.read().split('.').count()
236 }
237
238 pub fn is_in_sub_state(&self) -> bool {
239 self.current_depth() > 1
240 }
241
242 pub fn parent_state(&self) -> Option<String> {
243 let current = self.current.read();
244 let parts: Vec<&str> = current.split('.').collect();
245 if parts.len() > 1 {
246 Some(parts[..parts.len() - 1].join("."))
247 } else {
248 None
249 }
250 }
251
252 pub fn root_state(&self) -> String {
253 let current = self.current.read();
254 current.split('.').next().unwrap_or(¤t).to_string()
255 }
256
257 pub fn is_on_cooldown(&self, target: &str, cooldown_turns: u32) -> bool {
259 let history = self.history.read();
260 let total_transitions = history.len();
261 if total_transitions == 0 || cooldown_turns == 0 {
262 return false;
263 }
264 let start = total_transitions.saturating_sub(cooldown_turns as usize);
266 history[start..].iter().any(|e| e.to == target)
267 }
268}
269
#[cfg(test)]
mod tests {
    use super::super::config::TransitionTiming;
    use super::*;
    use std::collections::HashMap;

    /// Builds an automatic, unguarded, post-response transition with priority 0.
    /// Override fields with struct-update syntax where a test needs to.
    fn auto_transition(to: &str, when: &str) -> Transition {
        Transition {
            to: to.into(),
            when: when.into(),
            guard: None,
            intent: None,
            auto: true,
            priority: 0,
            cooldown_turns: None,
            timing: TransitionTiming::PostResponse,
            requires_response: false,
            run_extractors: false,
        }
    }

    /// Flat config: greeting -> support -> (timeout after 3 turns) escalation.
    fn create_test_config() -> StateConfig {
        let mut states = HashMap::new();
        states.insert(
            "greeting".into(),
            StateDefinition {
                prompt: Some("Welcome!".into()),
                transitions: vec![Transition {
                    priority: 10,
                    ..auto_transition("support", "needs help")
                }],
                ..Default::default()
            },
        );
        states.insert(
            "support".into(),
            StateDefinition {
                prompt: Some("How can I help?".into()),
                max_turns: Some(3),
                timeout_to: Some("escalation".into()),
                ..Default::default()
            },
        );
        states.insert(
            "escalation".into(),
            StateDefinition {
                prompt: Some("Escalating...".into()),
                ..Default::default()
            },
        );
        StateConfig {
            initial: "greeting".into(),
            states,
            global_transitions: vec![],
            fallback: None,
            max_no_transition: None,
            regenerate_on_transition: true,
        }
    }

    /// Hierarchical config: problem_solving{gathering_info, proposing} and closing.
    fn create_hierarchical_config() -> StateConfig {
        let mut sub_states = HashMap::new();
        sub_states.insert(
            "gathering_info".into(),
            StateDefinition {
                prompt: Some("Gathering info".into()),
                transitions: vec![auto_transition("proposing", "understood")],
                ..Default::default()
            },
        );
        sub_states.insert(
            "proposing".into(),
            StateDefinition {
                prompt: Some("Proposing solution".into()),
                transitions: vec![auto_transition("^closing", "resolved")],
                ..Default::default()
            },
        );

        let mut states = HashMap::new();
        states.insert(
            "problem_solving".into(),
            StateDefinition {
                prompt: Some("Problem solving".into()),
                initial: Some("gathering_info".into()),
                states: Some(sub_states),
                ..Default::default()
            },
        );
        states.insert(
            "closing".into(),
            StateDefinition {
                prompt: Some("Thank you".into()),
                ..Default::default()
            },
        );

        StateConfig {
            initial: "problem_solving".into(),
            states,
            global_transitions: vec![],
            fallback: None,
            max_no_transition: None,
            regenerate_on_transition: true,
        }
    }

    #[test]
    fn test_new_state_machine() {
        let config = create_test_config();
        let sm = StateMachine::new(config).unwrap();
        assert_eq!(sm.current(), "greeting");
        assert!(sm.previous().is_none());
        assert_eq!(sm.turn_count(), 0);
    }

    #[test]
    fn test_transition() {
        let config = create_test_config();
        let sm = StateMachine::new(config).unwrap();
        sm.transition_to("support", "user asked for help").unwrap();
        assert_eq!(sm.current(), "support");
        assert_eq!(sm.previous(), Some("greeting".into()));

        let history = sm.history();
        assert_eq!(history.len(), 1);
        assert_eq!(history[0].from, "greeting");
        assert_eq!(history[0].to, "support");
        assert_eq!(history[0].reason, "user asked for help");
    }

    #[test]
    fn test_transition_to_unknown_state_fails() {
        let sm = StateMachine::new(create_test_config()).unwrap();
        sm.increment_turn();
        assert!(sm.transition_to("nowhere", "reason").is_err());

        // A failed transition must leave the machine untouched.
        assert_eq!(sm.current(), "greeting");
        assert!(sm.previous().is_none());
        assert_eq!(sm.turn_count(), 1);
        assert!(sm.history().is_empty());
    }

    #[test]
    fn test_turn_counting() {
        let config = create_test_config();
        let sm = StateMachine::new(config).unwrap();
        assert_eq!(sm.turn_count(), 0);
        sm.increment_turn();
        sm.increment_turn();
        assert_eq!(sm.turn_count(), 2);
        sm.transition_to("support", "reason").unwrap();
        assert_eq!(sm.turn_count(), 0);
    }

    #[test]
    fn test_timeout_check() {
        let config = create_test_config();
        let sm = StateMachine::new(config).unwrap();
        sm.transition_to("support", "needs help").unwrap();
        assert!(sm.check_timeout().is_none());
        sm.increment_turn();
        sm.increment_turn();
        sm.increment_turn();
        assert_eq!(sm.check_timeout(), Some("escalation".into()));
    }

    #[test]
    fn test_snapshot_restore() {
        let config = create_test_config();
        let sm = StateMachine::new(config.clone()).unwrap();
        sm.transition_to("support", "reason").unwrap();
        sm.increment_turn();

        let snapshot = sm.snapshot();
        assert_eq!(snapshot.current_state, "support");
        assert_eq!(snapshot.turn_count, 1);

        let sm2 = StateMachine::new(config).unwrap();
        sm2.restore(snapshot).unwrap();
        assert_eq!(sm2.current(), "support");
        assert_eq!(sm2.turn_count(), 1);
    }

    #[test]
    fn test_restore_rejects_unknown_state() {
        let sm = StateMachine::new(create_test_config()).unwrap();
        let mut snapshot = sm.snapshot();
        snapshot.current_state = "nowhere".into();
        snapshot.turn_count = 7;

        assert!(sm.restore(snapshot).is_err());
        assert_eq!(sm.current(), "greeting");
        assert_eq!(sm.turn_count(), 0);
    }

    #[test]
    fn test_reset() {
        let config = create_test_config();
        let sm = StateMachine::new(config).unwrap();
        sm.transition_to("support", "reason").unwrap();
        sm.increment_turn();
        sm.reset();
        assert_eq!(sm.current(), "greeting");
        assert!(sm.previous().is_none());
        assert_eq!(sm.turn_count(), 0);
        assert!(sm.history().is_empty());
    }

    #[test]
    fn test_hierarchical_initial_state() {
        let config = create_hierarchical_config();
        let sm = StateMachine::new(config).unwrap();
        assert_eq!(sm.current(), "problem_solving.gathering_info");
    }

    #[test]
    fn test_hierarchical_transition_sibling() {
        let config = create_hierarchical_config();
        let sm = StateMachine::new(config).unwrap();
        assert_eq!(sm.current(), "problem_solving.gathering_info");

        sm.transition_to("proposing", "understood").unwrap();
        assert_eq!(sm.current(), "problem_solving.proposing");
    }

    #[test]
    fn test_hierarchical_transition_parent() {
        let config = create_hierarchical_config();
        let sm = StateMachine::new(config).unwrap();
        sm.transition_to("proposing", "understood").unwrap();
        sm.transition_to("^closing", "resolved").unwrap();
        assert_eq!(sm.current(), "closing");
    }

    #[test]
    fn test_reentering_compound_state_lands_on_initial_leaf() {
        let sm = StateMachine::new(create_hierarchical_config()).unwrap();
        sm.transition_to("proposing", "understood").unwrap();
        sm.transition_to("^problem_solving", "restart").unwrap();

        assert_eq!(sm.current(), "problem_solving.gathering_info");
        assert_eq!(
            sm.history().last().map(|e| e.to.clone()),
            Some("problem_solving.gathering_info".to_string())
        );
    }

    #[test]
    fn test_current_depth() {
        let config = create_hierarchical_config();
        let sm = StateMachine::new(config).unwrap();
        assert_eq!(sm.current_depth(), 2);
        assert!(sm.is_in_sub_state());

        sm.transition_to("^closing", "done").unwrap();
        assert_eq!(sm.current_depth(), 1);
        assert!(!sm.is_in_sub_state());
    }

    #[test]
    fn test_parent_state() {
        let config = create_hierarchical_config();
        let sm = StateMachine::new(config).unwrap();
        assert_eq!(sm.parent_state(), Some("problem_solving".into()));

        sm.transition_to("^closing", "done").unwrap();
        assert!(sm.parent_state().is_none());
    }

    #[test]
    fn test_root_state() {
        let config = create_hierarchical_config();
        let sm = StateMachine::new(config).unwrap();
        assert_eq!(sm.root_state(), "problem_solving");

        sm.transition_to("^closing", "done").unwrap();
        assert_eq!(sm.root_state(), "closing");
    }

    #[test]
    fn test_no_transition_count() {
        let config = create_test_config();
        let sm = StateMachine::new(config).unwrap();

        assert_eq!(sm.no_transition_count(), 0);
        sm.increment_no_transition();
        sm.increment_no_transition();
        assert_eq!(sm.no_transition_count(), 2);

        sm.reset_no_transition();
        assert_eq!(sm.no_transition_count(), 0);
    }

    #[test]
    fn test_fallback() {
        let mut config = create_test_config();
        config.fallback = Some("escalation".into());
        config.max_no_transition = Some(3);

        let sm = StateMachine::new(config).unwrap();
        assert!(sm.check_fallback().is_none());

        sm.increment_no_transition();
        sm.increment_no_transition();
        sm.increment_no_transition();
        assert_eq!(sm.check_fallback(), Some("escalation".into()));
    }

    #[test]
    fn test_cooldown() {
        let sm = StateMachine::new(create_test_config()).unwrap();
        // Empty history never reports a cooldown.
        assert!(!sm.is_on_cooldown("support", 5));

        sm.transition_to("support", "help").unwrap();
        assert!(sm.is_on_cooldown("support", 1));
        // A zero-length window never reports a cooldown.
        assert!(!sm.is_on_cooldown("support", 0));

        sm.transition_to("escalation", "timeout").unwrap();
        assert!(!sm.is_on_cooldown("support", 1));
        assert!(sm.is_on_cooldown("support", 2));
        assert!(sm.is_on_cooldown("escalation", 1));
    }

    #[test]
    fn test_global_transitions() {
        let mut config = create_test_config();
        config.global_transitions = vec![Transition {
            priority: 100,
            ..auto_transition("escalation", "user is angry")
        }];

        let sm = StateMachine::new(config).unwrap();
        let transitions = sm.available_transitions();

        assert!(
            transitions
                .iter()
                .any(|t| t.to == "escalation" && t.priority == 100)
        );
        assert_eq!(transitions[0].to, "escalation");
    }

    #[test]
    fn test_get_parent_definition() {
        let config = create_hierarchical_config();
        let sm = StateMachine::new(config).unwrap();

        let parent = sm.get_parent_definition();
        assert!(parent.is_some());
        assert_eq!(parent.unwrap().prompt, Some("Problem solving".into()));

        sm.transition_to("^closing", "done").unwrap();
        assert!(sm.get_parent_definition().is_none());
    }
}