1use chrono::Utc;
2use parking_lot::RwLock;
3
4use ai_agents_core::{AgentError, Result, StateMachineSnapshot, StateTransitionEvent};
5
6use super::config::{StateConfig, StateDefinition, Transition};
7
8pub struct StateMachine {
9 config: StateConfig,
10 current: RwLock<String>,
11 previous: RwLock<Option<String>>,
12 turn_count: RwLock<u32>,
13 no_transition_count: RwLock<u32>,
14 history: RwLock<Vec<StateTransitionEvent>>,
15}
16
17impl StateMachine {
18 pub fn new(config: StateConfig) -> Result<Self> {
19 config.validate()?;
20 let initial = Self::resolve_initial_state(&config)?;
21 Ok(Self {
22 config,
23 current: RwLock::new(initial),
24 previous: RwLock::new(None),
25 turn_count: RwLock::new(0),
26 no_transition_count: RwLock::new(0),
27 history: RwLock::new(Vec::new()),
28 })
29 }
30
31 fn resolve_initial_state(config: &StateConfig) -> Result<String> {
32 let mut path = config.initial.clone();
33 let mut current_def = config.states.get(&config.initial);
34
35 while let Some(def) = current_def {
36 if let (Some(initial_sub), Some(sub_states)) = (&def.initial, &def.states) {
37 path = format!("{}.{}", path, initial_sub);
38 current_def = sub_states.get(initial_sub);
39 } else {
40 break;
41 }
42 }
43
44 Ok(path)
45 }
46
47 pub fn current(&self) -> String {
48 self.current.read().clone()
49 }
50
51 pub fn previous(&self) -> Option<String> {
52 self.previous.read().clone()
53 }
54
55 pub fn current_definition(&self) -> Option<StateDefinition> {
56 let current = self.current.read();
57 self.config.get_state(¤t).cloned()
58 }
59
60 pub fn get_definition(&self, state: &str) -> Option<&StateDefinition> {
61 self.config.get_state(state)
62 }
63
64 pub fn get_parent_definition(&self) -> Option<StateDefinition> {
65 let current = self.current.read();
66 let parts: Vec<&str> = current.split('.').collect();
67 if parts.len() <= 1 {
68 return None;
69 }
70 let parent_path = parts[..parts.len() - 1].join(".");
71 self.config.get_state(&parent_path).cloned()
72 }
73
74 pub fn transition_to(&self, state: &str, reason: &str) -> Result<()> {
75 let current_path = self.current.read().clone();
76 let resolved_path = self.config.resolve_full_path(¤t_path, state);
77
78 if self.config.get_state(&resolved_path).is_none() {
79 return Err(AgentError::InvalidSpec(format!(
80 "Unknown state: {} (resolved from {})",
81 resolved_path, state
82 )));
83 }
84
85 let final_path = self.resolve_to_leaf_state(&resolved_path)?;
86
87 let from = {
88 let mut current = self.current.write();
89 let mut previous = self.previous.write();
90 let from = current.clone();
91 *previous = Some(from.clone());
92 *current = final_path.clone();
93 from
94 };
95
96 *self.turn_count.write() = 0;
97 *self.no_transition_count.write() = 0;
98
99 let event = StateTransitionEvent {
100 from,
101 to: final_path,
102 reason: reason.to_string(),
103 timestamp: Utc::now(),
104 };
105 self.history.write().push(event);
106
107 Ok(())
108 }
109
110 fn resolve_to_leaf_state(&self, path: &str) -> Result<String> {
111 let mut current_path = path.to_string();
112
113 loop {
114 let def = self.config.get_state(¤t_path).ok_or_else(|| {
115 AgentError::InvalidSpec(format!("State not found: {}", current_path))
116 })?;
117
118 if let (Some(initial_sub), Some(sub_states)) = (&def.initial, &def.states) {
119 if sub_states.contains_key(initial_sub) {
120 current_path = format!("{}.{}", current_path, initial_sub);
121 continue;
122 }
123 }
124 break;
125 }
126
127 Ok(current_path)
128 }
129
130 pub fn available_transitions(&self) -> Vec<Transition> {
131 let mut transitions = Vec::new();
132
133 if let Some(def) = self.current_definition() {
134 transitions.extend(def.transitions.clone());
135 }
136
137 transitions.extend(self.config.global_transitions.clone());
138
139 transitions.sort_by(|a, b| b.priority.cmp(&a.priority));
140 transitions
141 }
142
143 pub fn auto_transitions(&self) -> Vec<Transition> {
144 self.available_transitions()
145 .into_iter()
146 .filter(|t| t.auto)
147 .collect()
148 }
149
150 pub fn history(&self) -> Vec<StateTransitionEvent> {
151 self.history.read().clone()
152 }
153
154 pub fn increment_turn(&self) {
155 *self.turn_count.write() += 1;
156 }
157
158 pub fn turn_count(&self) -> u32 {
159 *self.turn_count.read()
160 }
161
162 pub fn increment_no_transition(&self) {
163 *self.no_transition_count.write() += 1;
164 }
165
166 pub fn no_transition_count(&self) -> u32 {
167 *self.no_transition_count.read()
168 }
169
170 pub fn reset_no_transition(&self) {
171 *self.no_transition_count.write() = 0;
172 }
173
174 pub fn check_fallback(&self) -> Option<String> {
175 if let Some(max) = self.config.max_no_transition {
176 if self.no_transition_count() >= max {
177 return self.config.fallback.clone();
178 }
179 }
180 None
181 }
182
183 pub fn reset(&self) {
184 let initial =
185 Self::resolve_initial_state(&self.config).unwrap_or(self.config.initial.clone());
186 *self.current.write() = initial;
187 *self.previous.write() = None;
188 *self.turn_count.write() = 0;
189 *self.no_transition_count.write() = 0;
190 self.history.write().clear();
191 }
192
193 pub fn snapshot(&self) -> StateMachineSnapshot {
194 StateMachineSnapshot {
195 current_state: self.current.read().clone(),
196 previous_state: self.previous.read().clone(),
197 turn_count: *self.turn_count.read(),
198 no_transition_count: *self.no_transition_count.read(),
199 history: self.history.read().clone(),
200 }
201 }
202
203 pub fn restore(&self, snapshot: StateMachineSnapshot) -> Result<()> {
204 if self.config.get_state(&snapshot.current_state).is_none() {
205 return Err(AgentError::InvalidSpec(format!(
206 "Snapshot contains unknown state: {}",
207 snapshot.current_state
208 )));
209 }
210 *self.current.write() = snapshot.current_state;
211 *self.previous.write() = snapshot.previous_state;
212 *self.turn_count.write() = snapshot.turn_count;
213 *self.no_transition_count.write() = snapshot.no_transition_count;
214 *self.history.write() = snapshot.history;
215 Ok(())
216 }
217
218 pub fn config(&self) -> &StateConfig {
219 &self.config
220 }
221
222 pub fn check_timeout(&self) -> Option<String> {
223 let def = self.current_definition()?;
224 let max_turns = def.max_turns?;
225 let timeout_to = def.timeout_to.as_ref()?;
226 if self.turn_count() >= max_turns {
227 let current_path = self.current.read().clone();
228 Some(self.config.resolve_full_path(¤t_path, timeout_to))
229 } else {
230 None
231 }
232 }
233
234 pub fn current_depth(&self) -> usize {
235 self.current.read().split('.').count()
236 }
237
238 pub fn is_in_sub_state(&self) -> bool {
239 self.current_depth() > 1
240 }
241
242 pub fn parent_state(&self) -> Option<String> {
243 let current = self.current.read();
244 let parts: Vec<&str> = current.split('.').collect();
245 if parts.len() > 1 {
246 Some(parts[..parts.len() - 1].join("."))
247 } else {
248 None
249 }
250 }
251
252 pub fn root_state(&self) -> String {
253 let current = self.current.read();
254 current.split('.').next().unwrap_or(¤t).to_string()
255 }
256
257 pub fn is_on_cooldown(&self, target: &str, cooldown_turns: u32) -> bool {
259 let history = self.history.read();
260 let total_transitions = history.len();
261 if total_transitions == 0 || cooldown_turns == 0 {
262 return false;
263 }
264 let start = total_transitions.saturating_sub(cooldown_turns as usize);
266 history[start..].iter().any(|e| e.to == target)
267 }
268}
269
#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::HashMap;

    /// Flat config: greeting -> support (auto, priority 10); support times out to
    /// escalation after 3 turns.
    fn create_test_config() -> StateConfig {
        let mut states = HashMap::new();
        states.insert(
            "greeting".into(),
            StateDefinition {
                prompt: Some("Welcome!".into()),
                transitions: vec![Transition {
                    to: "support".into(),
                    when: "needs help".into(),
                    guard: None,
                    intent: None,
                    auto: true,
                    priority: 10,
                    cooldown_turns: None,
                }],
                ..Default::default()
            },
        );
        states.insert(
            "support".into(),
            StateDefinition {
                prompt: Some("How can I help?".into()),
                max_turns: Some(3),
                timeout_to: Some("escalation".into()),
                ..Default::default()
            },
        );
        states.insert(
            "escalation".into(),
            StateDefinition {
                prompt: Some("Escalating...".into()),
                ..Default::default()
            },
        );
        StateConfig {
            initial: "greeting".into(),
            states,
            global_transitions: vec![],
            fallback: None,
            max_no_transition: None,
            regenerate_on_transition: true,
        }
    }

    /// Nested config: problem_solving { gathering_info -> proposing -> ^closing }.
    fn create_hierarchical_config() -> StateConfig {
        let mut sub_states = HashMap::new();
        sub_states.insert(
            "gathering_info".into(),
            StateDefinition {
                prompt: Some("Gathering info".into()),
                transitions: vec![Transition {
                    to: "proposing".into(),
                    when: "understood".into(),
                    guard: None,
                    intent: None,
                    auto: true,
                    priority: 0,
                    cooldown_turns: None,
                }],
                ..Default::default()
            },
        );
        sub_states.insert(
            "proposing".into(),
            StateDefinition {
                prompt: Some("Proposing solution".into()),
                transitions: vec![Transition {
                    to: "^closing".into(),
                    when: "resolved".into(),
                    guard: None,
                    intent: None,
                    auto: true,
                    priority: 0,
                    cooldown_turns: None,
                }],
                ..Default::default()
            },
        );

        let mut states = HashMap::new();
        states.insert(
            "problem_solving".into(),
            StateDefinition {
                prompt: Some("Problem solving".into()),
                initial: Some("gathering_info".into()),
                states: Some(sub_states),
                ..Default::default()
            },
        );
        states.insert(
            "closing".into(),
            StateDefinition {
                prompt: Some("Thank you".into()),
                ..Default::default()
            },
        );

        StateConfig {
            initial: "problem_solving".into(),
            states,
            global_transitions: vec![],
            fallback: None,
            max_no_transition: None,
            regenerate_on_transition: true,
        }
    }

    #[test]
    fn test_new_state_machine() {
        let config = create_test_config();
        let sm = StateMachine::new(config).unwrap();
        assert_eq!(sm.current(), "greeting");
        assert!(sm.previous().is_none());
        assert_eq!(sm.turn_count(), 0);
    }

    #[test]
    fn test_transition() {
        let config = create_test_config();
        let sm = StateMachine::new(config).unwrap();
        sm.transition_to("support", "user asked for help").unwrap();
        assert_eq!(sm.current(), "support");
        assert_eq!(sm.previous(), Some("greeting".into()));
        assert_eq!(sm.history().len(), 1);
    }

    #[test]
    fn test_transition_to_unknown_state_is_rejected() {
        let sm = StateMachine::new(create_test_config()).unwrap();
        assert!(sm.transition_to("nowhere", "bogus").is_err());
        // A failed transition must not touch any runtime state.
        assert_eq!(sm.current(), "greeting");
        assert!(sm.previous().is_none());
        assert!(sm.history().is_empty());
    }

    #[test]
    fn test_turn_counting() {
        let config = create_test_config();
        let sm = StateMachine::new(config).unwrap();
        assert_eq!(sm.turn_count(), 0);
        sm.increment_turn();
        sm.increment_turn();
        assert_eq!(sm.turn_count(), 2);
        sm.transition_to("support", "reason").unwrap();
        assert_eq!(sm.turn_count(), 0);
    }

    #[test]
    fn test_timeout_check() {
        let config = create_test_config();
        let sm = StateMachine::new(config).unwrap();
        sm.transition_to("support", "needs help").unwrap();
        assert!(sm.check_timeout().is_none());
        sm.increment_turn();
        sm.increment_turn();
        sm.increment_turn();
        assert_eq!(sm.check_timeout(), Some("escalation".into()));
    }

    #[test]
    fn test_no_timeout_without_max_turns() {
        let sm = StateMachine::new(create_test_config()).unwrap();
        for _ in 0..10 {
            sm.increment_turn();
        }
        assert!(sm.check_timeout().is_none());
    }

    #[test]
    fn test_snapshot_restore() {
        let config = create_test_config();
        let sm = StateMachine::new(config.clone()).unwrap();
        sm.transition_to("support", "reason").unwrap();
        sm.increment_turn();

        let snapshot = sm.snapshot();
        assert_eq!(snapshot.current_state, "support");
        assert_eq!(snapshot.turn_count, 1);

        let sm2 = StateMachine::new(config).unwrap();
        sm2.restore(snapshot).unwrap();
        assert_eq!(sm2.current(), "support");
        assert_eq!(sm2.turn_count(), 1);
    }

    #[test]
    fn test_restore_rejects_unknown_state() {
        let sm = StateMachine::new(create_test_config()).unwrap();
        let mut snapshot = sm.snapshot();
        snapshot.current_state = "nowhere".into();
        assert!(sm.restore(snapshot).is_err());
        assert_eq!(sm.current(), "greeting");
    }

    #[test]
    fn test_reset() {
        let config = create_test_config();
        let sm = StateMachine::new(config).unwrap();
        sm.transition_to("support", "reason").unwrap();
        sm.increment_turn();
        sm.reset();
        assert_eq!(sm.current(), "greeting");
        assert!(sm.previous().is_none());
        assert_eq!(sm.turn_count(), 0);
        assert!(sm.history().is_empty());
    }

    #[test]
    fn test_reset_hierarchical_returns_to_initial_leaf() {
        let sm = StateMachine::new(create_hierarchical_config()).unwrap();
        sm.transition_to("^closing", "done").unwrap();
        sm.reset();
        assert_eq!(sm.current(), "problem_solving.gathering_info");
    }

    #[test]
    fn test_hierarchical_initial_state() {
        let config = create_hierarchical_config();
        let sm = StateMachine::new(config).unwrap();
        assert_eq!(sm.current(), "problem_solving.gathering_info");
    }

    #[test]
    fn test_hierarchical_transition_sibling() {
        let config = create_hierarchical_config();
        let sm = StateMachine::new(config).unwrap();
        assert_eq!(sm.current(), "problem_solving.gathering_info");

        sm.transition_to("proposing", "understood").unwrap();
        assert_eq!(sm.current(), "problem_solving.proposing");
    }

    #[test]
    fn test_hierarchical_transition_parent() {
        let config = create_hierarchical_config();
        let sm = StateMachine::new(config).unwrap();
        sm.transition_to("proposing", "understood").unwrap();
        sm.transition_to("^closing", "resolved").unwrap();
        assert_eq!(sm.current(), "closing");
    }

    #[test]
    fn test_transition_into_composite_state_lands_on_leaf() {
        let sm = StateMachine::new(create_hierarchical_config()).unwrap();
        sm.transition_to("^closing", "done").unwrap();
        sm.transition_to("problem_solving", "new problem").unwrap();
        assert_eq!(sm.current(), "problem_solving.gathering_info");
        assert_eq!(sm.previous(), Some("closing".into()));
    }

    #[test]
    fn test_history_records_resolved_paths() {
        let sm = StateMachine::new(create_hierarchical_config()).unwrap();
        sm.transition_to("proposing", "understood").unwrap();
        let history = sm.history();
        assert_eq!(history.len(), 1);
        assert_eq!(history[0].from, "problem_solving.gathering_info");
        assert_eq!(history[0].to, "problem_solving.proposing");
        assert_eq!(history[0].reason, "understood");
    }

    #[test]
    fn test_current_depth() {
        let config = create_hierarchical_config();
        let sm = StateMachine::new(config).unwrap();
        assert_eq!(sm.current_depth(), 2);
        assert!(sm.is_in_sub_state());

        sm.transition_to("^closing", "done").unwrap();
        assert_eq!(sm.current_depth(), 1);
        assert!(!sm.is_in_sub_state());
    }

    #[test]
    fn test_parent_state() {
        let config = create_hierarchical_config();
        let sm = StateMachine::new(config).unwrap();
        assert_eq!(sm.parent_state(), Some("problem_solving".into()));

        sm.transition_to("^closing", "done").unwrap();
        assert!(sm.parent_state().is_none());
    }

    #[test]
    fn test_root_state() {
        let config = create_hierarchical_config();
        let sm = StateMachine::new(config).unwrap();
        assert_eq!(sm.root_state(), "problem_solving");

        sm.transition_to("^closing", "done").unwrap();
        assert_eq!(sm.root_state(), "closing");
    }

    #[test]
    fn test_no_transition_count() {
        let config = create_test_config();
        let sm = StateMachine::new(config).unwrap();

        assert_eq!(sm.no_transition_count(), 0);
        sm.increment_no_transition();
        sm.increment_no_transition();
        assert_eq!(sm.no_transition_count(), 2);

        sm.reset_no_transition();
        assert_eq!(sm.no_transition_count(), 0);
    }

    #[test]
    fn test_fallback() {
        let mut config = create_test_config();
        config.fallback = Some("escalation".into());
        config.max_no_transition = Some(3);

        let sm = StateMachine::new(config).unwrap();
        assert!(sm.check_fallback().is_none());

        sm.increment_no_transition();
        sm.increment_no_transition();
        sm.increment_no_transition();
        assert_eq!(sm.check_fallback(), Some("escalation".into()));
    }

    #[test]
    fn test_global_transitions() {
        let mut config = create_test_config();
        config.global_transitions = vec![Transition {
            to: "escalation".into(),
            when: "user is angry".into(),
            guard: None,
            intent: None,
            auto: true,
            priority: 100,
            cooldown_turns: None,
        }];

        let sm = StateMachine::new(config).unwrap();
        let transitions = sm.available_transitions();

        assert!(
            transitions
                .iter()
                .any(|t| t.to == "escalation" && t.priority == 100)
        );
        assert_eq!(transitions[0].to, "escalation");
    }

    #[test]
    fn test_equal_priority_keeps_state_transitions_first() {
        let mut config = create_test_config();
        config.global_transitions = vec![Transition {
            to: "escalation".into(),
            when: "user is angry".into(),
            guard: None,
            intent: None,
            auto: false,
            priority: 10,
            cooldown_turns: None,
        }];

        let sm = StateMachine::new(config).unwrap();
        let transitions = sm.available_transitions();
        assert_eq!(transitions.len(), 2);
        // The sort is stable: the state's own transition stays ahead of the global one.
        assert_eq!(transitions[0].to, "support");
        assert_eq!(transitions[1].to, "escalation");

        let auto = sm.auto_transitions();
        assert_eq!(auto.len(), 1);
        assert_eq!(auto[0].to, "support");
    }

    #[test]
    fn test_auto_transitions() {
        let sm = StateMachine::new(create_test_config()).unwrap();
        let auto = sm.auto_transitions();
        assert_eq!(auto.len(), 1);
        assert_eq!(auto[0].to, "support");
    }

    #[test]
    fn test_is_on_cooldown() {
        let sm = StateMachine::new(create_test_config()).unwrap();
        assert!(!sm.is_on_cooldown("support", 3));

        sm.transition_to("support", "help").unwrap();
        assert!(sm.is_on_cooldown("support", 1));
        assert!(!sm.is_on_cooldown("support", 0));
        assert!(!sm.is_on_cooldown("escalation", 1));

        sm.transition_to("escalation", "timeout").unwrap();
        assert!(!sm.is_on_cooldown("support", 1));
        assert!(sm.is_on_cooldown("support", 2));
    }

    #[test]
    fn test_get_parent_definition() {
        let config = create_hierarchical_config();
        let sm = StateMachine::new(config).unwrap();

        let parent = sm.get_parent_definition();
        assert!(parent.is_some());
        assert_eq!(parent.unwrap().prompt, Some("Problem solving".into()));

        sm.transition_to("^closing", "done").unwrap();
        assert!(sm.get_parent_definition().is_none());
    }
}