1use ai_agents_core::{AgentError, Result};
2use ai_agents_disambiguation::StateDisambiguationOverride;
3use ai_agents_process::ProcessConfig;
4use ai_agents_reasoning::{ReasoningConfig, ReflectionConfig};
5use serde::{Deserialize, Serialize};
6use serde_json::Value;
7use std::collections::{HashMap, HashSet};
8
9#[derive(Debug, Clone, Serialize, Deserialize, Default, PartialEq)]
10#[serde(rename_all = "lowercase")]
11pub enum PromptMode {
12 #[default]
13 Append,
14 Replace,
15 Prepend,
16}
17
18#[derive(Debug, Clone, Serialize, Deserialize)]
19pub struct StateConfig {
20 pub initial: String,
21 #[serde(default)]
22 pub states: HashMap<String, StateDefinition>,
23 #[serde(default)]
24 pub global_transitions: Vec<Transition>,
25 #[serde(default)]
26 pub fallback: Option<String>,
27 #[serde(default)]
28 pub max_no_transition: Option<u32>,
29
30 #[serde(default = "default_true")]
32 pub regenerate_on_transition: bool,
33}
34
35#[derive(Debug, Clone, Serialize, Deserialize, Default)]
36pub struct StateDefinition {
37 #[serde(default)]
38 pub prompt: Option<String>,
39
40 #[serde(default)]
41 pub prompt_mode: PromptMode,
42
43 #[serde(default)]
44 pub llm: Option<String>,
45
46 #[serde(default)]
47 pub skills: Vec<String>,
48
49 #[serde(default, skip_serializing_if = "Option::is_none")]
54 pub tools: Option<Vec<ToolRef>>,
55
56 #[serde(default)]
57 pub transitions: Vec<Transition>,
58
59 #[serde(default)]
60 pub max_turns: Option<u32>,
61
62 #[serde(default)]
63 pub timeout_to: Option<String>,
64
65 #[serde(default)]
66 pub initial: Option<String>,
67
68 #[serde(default)]
69 pub states: Option<HashMap<String, StateDefinition>>,
70
71 #[serde(default = "default_inherit_parent")]
72 pub inherit_parent: bool,
73
74 #[serde(default)]
75 pub on_enter: Vec<StateAction>,
76
77 #[serde(default)]
79 pub on_reenter: Vec<StateAction>,
80
81 #[serde(default)]
82 pub on_exit: Vec<StateAction>,
83
84 #[serde(default, skip_serializing_if = "Option::is_none")]
86 pub regenerate_on_enter: Option<bool>,
87
88 #[serde(default)]
90 pub extract: Vec<ContextExtractor>,
91
92 #[serde(default, skip_serializing_if = "Option::is_none")]
93 pub reasoning: Option<ReasoningConfig>,
94
95 #[serde(default, skip_serializing_if = "Option::is_none")]
96 pub reflection: Option<ReflectionConfig>,
97
98 #[serde(default, skip_serializing_if = "Option::is_none")]
99 pub disambiguation: Option<StateDisambiguationOverride>,
100
101 #[serde(default, skip_serializing_if = "Option::is_none")]
103 pub process: Option<ProcessConfig>,
104
105 #[serde(default, skip_serializing_if = "Option::is_none")]
107 pub delegate: Option<String>,
108
109 #[serde(default, skip_serializing_if = "Option::is_none")]
111 pub delegate_context: Option<DelegateContextMode>,
112
113 #[serde(default, skip_serializing_if = "Option::is_none")]
115 pub concurrent: Option<ConcurrentStateConfig>,
116
117 #[serde(default, skip_serializing_if = "Option::is_none")]
119 pub group_chat: Option<GroupChatStateConfig>,
120
121 #[serde(default, skip_serializing_if = "Option::is_none")]
123 pub pipeline: Option<PipelineStateConfig>,
124
125 #[serde(default, skip_serializing_if = "Option::is_none")]
127 pub handoff: Option<HandoffStateConfig>,
128}
129
/// Serde default for `StateDefinition::inherit_parent`: inherit unless told otherwise.
fn default_inherit_parent() -> bool {
    true
}

/// Serde default for boolean flags that are on unless disabled.
fn default_true() -> bool {
    true
}

/// Serde default for `ContextExtractor::llm`: the `"router"` model.
fn default_extractor_llm() -> String {
    String::from("router")
}
141
142#[derive(Debug, Clone, Serialize, Deserialize)]
143#[serde(untagged)]
144pub enum ToolRef {
145 Simple(String),
146 Conditional {
147 id: String,
148 condition: ToolCondition,
149 },
150}
151
152impl ToolRef {
153 pub fn id(&self) -> &str {
154 match self {
155 ToolRef::Simple(id) => id,
156 ToolRef::Conditional { id, .. } => id,
157 }
158 }
159
160 pub fn condition(&self) -> Option<&ToolCondition> {
161 match self {
162 ToolRef::Simple(_) => None,
163 ToolRef::Conditional { condition, .. } => Some(condition),
164 }
165 }
166}
167
168#[derive(Debug, Clone, Serialize, Deserialize)]
169#[serde(rename_all = "snake_case")]
170pub enum ToolCondition {
171 Context(HashMap<String, ContextMatcher>),
172 State(StateMatcher),
173 AfterTool(String),
174 ToolResult {
175 tool: String,
176 result: HashMap<String, Value>,
177 },
178 Semantic {
179 when: String,
180 #[serde(default = "default_semantic_llm")]
181 llm: String,
182 #[serde(default = "default_threshold")]
183 threshold: f32,
184 },
185 Time(TimeMatcher),
186 All(Vec<ToolCondition>),
187 Any(Vec<ToolCondition>),
188 Not(Box<ToolCondition>),
189}
190
/// Serde default for `ToolCondition::Semantic::llm`.
fn default_semantic_llm() -> String {
    String::from("router")
}

/// Serde default for `ToolCondition::Semantic::threshold`.
fn default_threshold() -> f32 {
    0.7_f32
}
198
199#[derive(Debug, Clone, Serialize, Deserialize)]
200#[serde(untagged)]
201pub enum ContextMatcher {
202 Exists { exists: bool },
207 Compare(CompareOp),
208 Exact(Value),
209}
210
211#[derive(Debug, Clone, Serialize, Deserialize)]
212#[serde(rename_all = "snake_case")]
213pub enum CompareOp {
214 Eq(Value),
215 Neq(Value),
216 Gt(f64),
217 Gte(f64),
218 Lt(f64),
219 Lte(f64),
220 In(Vec<Value>),
221 Contains(String),
222}
223
224#[derive(Debug, Clone, Serialize, Deserialize, Default)]
225pub struct StateMatcher {
226 #[serde(default)]
227 pub name: Option<String>,
228 #[serde(default)]
229 pub turn_count: Option<CompareOp>,
230 #[serde(default)]
231 pub previous: Option<String>,
232}
233
234#[derive(Debug, Clone, Serialize, Deserialize, Default)]
235pub struct TimeMatcher {
236 #[serde(default)]
237 pub hours: Option<CompareOp>,
238 #[serde(default)]
239 pub day_of_week: Option<Vec<String>>,
240 #[serde(default)]
241 pub timezone: Option<String>,
242}
243
244#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq)]
245#[serde(rename_all = "snake_case")]
246pub enum TransitionTiming {
247 PostResponse,
249 PreResponse,
251 Parallel,
253}
254
255impl Default for TransitionTiming {
256 fn default() -> Self {
257 Self::PostResponse
258 }
259}
260
261#[derive(Debug, Clone, Serialize, Deserialize)]
262pub struct Transition {
263 pub to: String,
264 #[serde(default)]
265 pub when: String,
266 #[serde(default)]
267 pub guard: Option<TransitionGuard>,
268 #[serde(default, skip_serializing_if = "Option::is_none")]
270 pub intent: Option<String>,
271 #[serde(default = "default_auto")]
272 pub auto: bool,
273 #[serde(default)]
274 pub priority: u8,
275
276 #[serde(default, skip_serializing_if = "Option::is_none")]
278 pub cooldown_turns: Option<u32>,
279
280 #[serde(default)]
282 pub timing: TransitionTiming,
283
284 #[serde(default)]
286 pub requires_response: bool,
287
288 #[serde(default)]
290 pub run_extractors: bool,
291}
292
/// Serde default for `Transition::auto`: transitions are automatic unless disabled.
fn default_auto() -> bool {
    true
}
296
297#[derive(Debug, Clone, Serialize, Deserialize)]
298#[serde(untagged)]
299pub enum TransitionGuard {
300 Expression(String),
301 Conditions(GuardConditions),
302}
303
304#[derive(Debug, Clone, Serialize, Deserialize)]
305#[serde(rename_all = "snake_case")]
306pub enum GuardConditions {
307 All(Vec<String>),
308 Any(Vec<String>),
309 Context(HashMap<String, ContextMatcher>),
310}
311
312#[derive(Debug, Clone, Serialize, Deserialize)]
313#[serde(untagged)]
314pub enum StateAction {
315 Tool {
316 tool: String,
317 #[serde(default)]
318 args: Option<Value>,
319 },
320 Skill {
321 skill: String,
322 },
323 Prompt {
324 prompt: String,
325 #[serde(default)]
326 llm: Option<String>,
327 #[serde(default)]
328 store_as: Option<String>,
329 },
330 SetContext {
331 set_context: HashMap<String, Value>,
332 },
333}
334
335#[derive(Debug, Clone, Serialize, Deserialize)]
337pub struct ContextExtractor {
338 pub key: String,
340
341 #[serde(default)]
343 pub description: Option<String>,
344
345 #[serde(default)]
347 pub llm_extract: Option<String>,
348
349 #[serde(default = "default_extractor_llm")]
351 pub llm: String,
352
353 #[serde(default)]
355 pub required: bool,
356}
357
358#[derive(Debug, Clone, Default, Serialize, Deserialize)]
364#[serde(rename_all = "snake_case")]
365pub enum DelegateContextMode {
366 #[default]
368 InputOnly,
369 Summary,
371 Full,
373}
374
375#[derive(Debug, Clone, Serialize, Deserialize)]
377pub struct ConcurrentStateConfig {
378 pub agents: Vec<ConcurrentAgentRef>,
380 #[serde(default, skip_serializing_if = "Option::is_none")]
382 pub input: Option<String>,
383 pub aggregation: AggregationConfig,
385 #[serde(default, skip_serializing_if = "Option::is_none")]
387 pub min_required: Option<usize>,
388 #[serde(default)]
390 pub on_partial_failure: PartialFailureAction,
391 #[serde(default, skip_serializing_if = "Option::is_none")]
393 pub timeout_ms: Option<u64>,
394 #[serde(default, skip_serializing_if = "Option::is_none")]
396 pub context_mode: Option<DelegateContextMode>,
397}
398
399#[derive(Debug, Clone, Serialize, Deserialize)]
401#[serde(untagged)]
402pub enum ConcurrentAgentRef {
403 Id(String),
404 Weighted { id: String, weight: f64 },
405}
406
407impl ConcurrentAgentRef {
408 pub fn id(&self) -> &str {
409 match self {
410 Self::Id(id) => id,
411 Self::Weighted { id, .. } => id,
412 }
413 }
414
415 pub fn weight(&self) -> f64 {
416 match self {
417 Self::Id(_) => 1.0,
418 Self::Weighted { weight, .. } => *weight,
419 }
420 }
421}
422
423#[derive(Debug, Clone, Serialize, Deserialize)]
425pub struct AggregationConfig {
426 pub strategy: AggregationStrategy,
428 #[serde(default, skip_serializing_if = "Option::is_none")]
430 pub synthesizer_llm: Option<String>,
431 #[serde(default, skip_serializing_if = "Option::is_none")]
433 pub synthesizer_prompt: Option<String>,
434 #[serde(default, skip_serializing_if = "Option::is_none")]
436 pub vote: Option<VoteConfig>,
437}
438
439#[derive(Debug, Clone, Serialize, Deserialize)]
440#[serde(rename_all = "snake_case")]
441pub enum AggregationStrategy {
442 Voting,
443 LlmSynthesis,
444 FirstWins,
445 All,
446}
447
448#[derive(Debug, Clone, Serialize, Deserialize)]
450pub struct VoteConfig {
451 #[serde(default)]
452 pub method: VoteMethod,
453 #[serde(default)]
454 pub tiebreaker: TiebreakerStrategy,
455 #[serde(default, skip_serializing_if = "Option::is_none")]
457 pub vote_prompt: Option<String>,
458}
459
460#[derive(Debug, Clone, Default, Serialize, Deserialize)]
461#[serde(rename_all = "snake_case")]
462pub enum VoteMethod {
463 #[default]
464 Majority,
465 Weighted,
466 Unanimous,
467}
468
469#[derive(Debug, Clone, Default, Serialize, Deserialize)]
470#[serde(rename_all = "snake_case")]
471pub enum TiebreakerStrategy {
472 #[default]
473 First,
474 Random,
475 RouterDecides,
476}
477
478#[derive(Debug, Clone, Default, Serialize, Deserialize)]
479#[serde(rename_all = "snake_case")]
480pub enum PartialFailureAction {
481 #[default]
482 ProceedWithAvailable,
483 Abort,
484}
485
486#[derive(Debug, Clone, Serialize, Deserialize)]
488pub struct GroupChatStateConfig {
489 pub participants: Vec<ChatParticipant>,
491 #[serde(default)]
493 pub style: ChatStyle,
494 #[serde(default = "default_max_rounds")]
496 pub max_rounds: u32,
497 #[serde(default, skip_serializing_if = "Option::is_none")]
499 pub manager: Option<ChatManagerConfig>,
500 #[serde(default)]
502 pub termination: TerminationConfig,
503 #[serde(default, skip_serializing_if = "Option::is_none")]
505 pub debate: Option<DebateStyleConfig>,
506 #[serde(default, skip_serializing_if = "Option::is_none")]
508 pub maker_checker: Option<MakerCheckerConfig>,
509 #[serde(default, skip_serializing_if = "Option::is_none")]
511 pub timeout_ms: Option<u64>,
512 #[serde(default, skip_serializing_if = "Option::is_none")]
515 pub input: Option<String>,
516 #[serde(default, skip_serializing_if = "Option::is_none")]
518 pub context_mode: Option<DelegateContextMode>,
519}
520
521#[derive(Debug, Clone, Serialize, Deserialize)]
523pub struct ChatParticipant {
524 pub id: String,
526 #[serde(default, skip_serializing_if = "Option::is_none")]
528 pub role: Option<String>,
529}
530
531#[derive(Debug, Clone, Default, Serialize, Deserialize)]
532#[serde(rename_all = "snake_case")]
533pub enum ChatStyle {
534 #[default]
535 Brainstorm,
536 Debate,
537 MakerChecker,
538 Consensus,
539}
540
541#[derive(Debug, Clone, Serialize, Deserialize)]
543pub struct ChatManagerConfig {
544 #[serde(default, skip_serializing_if = "Option::is_none")]
546 pub agent: Option<String>,
547 #[serde(default, skip_serializing_if = "Option::is_none")]
549 pub method: Option<TurnMethod>,
550}
551
552#[derive(Debug, Clone, Serialize, Deserialize)]
553#[serde(rename_all = "snake_case")]
554pub enum TurnMethod {
555 RoundRobin,
556 Random,
557 LlmDirected,
558}
559
560#[derive(Debug, Clone, Serialize, Deserialize)]
562pub struct TerminationConfig {
563 #[serde(default)]
564 pub method: TerminationMethod,
565 #[serde(default = "default_stall_rounds")]
566 pub max_stall_rounds: u32,
567}
568
569impl Default for TerminationConfig {
570 fn default() -> Self {
571 Self {
572 method: TerminationMethod::default(),
573 max_stall_rounds: default_stall_rounds(),
574 }
575 }
576}
577
578#[derive(Debug, Clone, Default, Serialize, Deserialize)]
579#[serde(rename_all = "snake_case")]
580pub enum TerminationMethod {
581 #[default]
582 ManagerDecides,
583 MaxRounds,
584 ConsensusReached,
585}
586
587#[derive(Debug, Clone, Serialize, Deserialize)]
589pub struct DebateStyleConfig {
590 #[serde(default = "default_debate_rounds")]
591 pub rounds: u32,
592 pub synthesizer: String,
594}
595
596#[derive(Debug, Clone, Serialize, Deserialize)]
598pub struct MakerCheckerConfig {
599 #[serde(default = "default_maker_checker_iterations")]
600 pub max_iterations: u32,
601 pub acceptance_criteria: String,
603 #[serde(default)]
604 pub on_max_iterations: MaxIterationsAction,
605}
606
607#[derive(Debug, Clone, Default, Serialize, Deserialize)]
608#[serde(rename_all = "snake_case")]
609pub enum MaxIterationsAction {
610 #[default]
611 AcceptLast,
612 Escalate,
613 Fail,
614}
615
/// Serde default for `GroupChatStateConfig::max_rounds`.
fn default_max_rounds() -> u32 {
    5_u32
}

/// Serde default for `TerminationConfig::max_stall_rounds`.
fn default_stall_rounds() -> u32 {
    2_u32
}

/// Serde default for `DebateStyleConfig::rounds`.
fn default_debate_rounds() -> u32 {
    3_u32
}

/// Serde default for `MakerCheckerConfig::max_iterations`.
fn default_maker_checker_iterations() -> u32 {
    3_u32
}
628
629#[derive(Debug, Clone, Serialize, Deserialize)]
631pub struct PipelineStateConfig {
632 pub stages: Vec<PipelineStageEntry>,
633
634 #[serde(default, skip_serializing_if = "Option::is_none")]
635 pub timeout_ms: Option<u64>,
636 #[serde(default, skip_serializing_if = "Option::is_none")]
638 pub context_mode: Option<DelegateContextMode>,
639}
640
641#[derive(Debug, Clone, Serialize, Deserialize)]
643#[serde(untagged)]
644pub enum PipelineStageEntry {
645 Id(String),
647 Config {
649 id: String,
650 #[serde(default, skip_serializing_if = "Option::is_none")]
651 input: Option<String>,
652 },
653}
654
655impl PipelineStageEntry {
656 pub fn id(&self) -> &str {
657 match self {
658 Self::Id(id) => id,
659 Self::Config { id, .. } => id,
660 }
661 }
662
663 pub fn input(&self) -> Option<&str> {
664 match self {
665 Self::Id(_) => None,
666 Self::Config { input, .. } => input.as_deref(),
667 }
668 }
669}
670
671#[derive(Debug, Clone, Serialize, Deserialize)]
673pub struct HandoffStateConfig {
674 pub initial_agent: String,
675 pub available_agents: Vec<String>,
676
677 #[serde(default = "default_max_handoffs")]
678 pub max_handoffs: u32,
679
680 #[serde(default, skip_serializing_if = "Option::is_none")]
683 pub input: Option<String>,
684 #[serde(default, skip_serializing_if = "Option::is_none")]
686 pub context_mode: Option<DelegateContextMode>,
687}
688
/// Serde default for `HandoffStateConfig::max_handoffs`.
fn default_max_handoffs() -> u32 {
    5_u32
}
692
693fn validate_transition_timing(
694 transition: &Transition,
695 scope: &str,
696 state_path: Option<&str>,
697) -> Result<()> {
698 if transition.requires_response && !matches!(transition.timing, TransitionTiming::PostResponse)
699 {
700 let location = state_path
701 .map(|path| format!("State '{}'", path))
702 .unwrap_or_else(|| "Global transition".to_string());
703 return Err(AgentError::InvalidSpec(format!(
704 "{} has response-dependent transition '{}' with non-post-response timing",
705 location, transition.to
706 )));
707 }
708 if matches!(transition.timing, TransitionTiming::Parallel)
709 && transition.guard.is_none()
710 && transition.intent.is_none()
711 && transition.when.trim().is_empty()
712 {
713 return Err(AgentError::InvalidSpec(format!(
714 "{} transition '{}' uses parallel timing without a guard, intent, or when condition",
715 scope, transition.to
716 )));
717 }
718 if matches!(transition.timing, TransitionTiming::PreResponse) {
719 if transition.guard.is_none() && transition.intent.is_none() {
720 return Err(AgentError::InvalidSpec(format!(
721 "{} transition '{}' uses pre-response timing without a guard or intent",
722 scope, transition.to
723 )));
724 }
725 if !transition.when.trim().is_empty() {
726 return Err(AgentError::InvalidSpec(format!(
727 "{} transition '{}' uses pre-response timing with response-dependent when text",
728 scope, transition.to
729 )));
730 }
731 }
732 Ok(())
733}
734
735impl StateConfig {
736 pub fn validate(&self) -> Result<()> {
737 if self.initial.is_empty() {
738 return Err(AgentError::InvalidSpec(
739 "State machine initial state cannot be empty".into(),
740 ));
741 }
742 if !self.states.contains_key(&self.initial) {
743 return Err(AgentError::InvalidSpec(format!(
744 "Initial state '{}' not found in states",
745 self.initial
746 )));
747 }
748 for transition in &self.global_transitions {
749 validate_transition_timing(transition, "Global", None)?;
750 if !self.is_valid_transition_target(&transition.to, &[], &self.states) {
751 return Err(AgentError::InvalidSpec(format!(
752 "Global transition targets unknown state '{}'",
753 transition.to
754 )));
755 }
756 }
757
758 self.validate_states(&self.states, &[])?;
759
760 for warning in self.check_reachability() {
762 tracing::warn!("{}", warning);
763 }
764
765 Ok(())
766 }
767
768 fn validate_states(
769 &self,
770 states: &HashMap<String, StateDefinition>,
771 parent_path: &[String],
772 ) -> Result<()> {
773 for (name, def) in states {
774 let current_path: Vec<String> = parent_path
775 .iter()
776 .cloned()
777 .chain(std::iter::once(name.clone()))
778 .collect();
779
780 for transition in &def.transitions {
781 let path = current_path.join(".");
782 validate_transition_timing(transition, "State", Some(&path))?;
783
784 if !self.is_valid_transition_target(&transition.to, ¤t_path, states) {
785 return Err(AgentError::InvalidSpec(format!(
786 "State '{}' has transition to unknown state '{}'",
787 current_path.join("."),
788 transition.to
789 )));
790 }
791 }
792
793 if let Some(ref timeout_state) = def.timeout_to {
794 if !self.is_valid_transition_target(timeout_state, ¤t_path, states) {
795 return Err(AgentError::InvalidSpec(format!(
796 "State '{}' has timeout_to unknown state '{}'",
797 current_path.join("."),
798 timeout_state
799 )));
800 }
801 }
802
803 if let Some(ref sub_states) = def.states {
804 if let Some(ref initial) = def.initial {
805 if !sub_states.contains_key(initial) {
806 return Err(AgentError::InvalidSpec(format!(
807 "State '{}' has initial sub-state '{}' that doesn't exist",
808 current_path.join("."),
809 initial
810 )));
811 }
812 }
813 self.validate_states(sub_states, ¤t_path)?;
814 }
815 }
816 Ok(())
817 }
818
819 fn is_valid_transition_target(
820 &self,
821 target: &str,
822 current_path: &[String],
823 states: &HashMap<String, StateDefinition>,
824 ) -> bool {
825 if target.starts_with('^') {
826 let target_name = &target[1..];
827 return self.states.contains_key(target_name);
828 }
829
830 if states.contains_key(target) {
831 return true;
832 }
833
834 if current_path.len() > 1 {
835 let parent_path = ¤t_path[..current_path.len() - 1];
836 if let Some(parent_states) = self.get_states_at_path(parent_path) {
837 if parent_states.contains_key(target) {
838 return true;
839 }
840 }
841 }
842
843 self.states.contains_key(target)
844 }
845
846 fn get_states_at_path(&self, path: &[String]) -> Option<&HashMap<String, StateDefinition>> {
847 let mut current = &self.states;
848 for segment in path {
849 if let Some(def) = current.get(segment) {
850 if let Some(ref sub_states) = def.states {
851 current = sub_states;
852 } else {
853 return None;
854 }
855 } else {
856 return None;
857 }
858 }
859 Some(current)
860 }
861
862 pub fn get_state(&self, path: &str) -> Option<&StateDefinition> {
863 let parts: Vec<&str> = path.split('.').collect();
864 self.get_state_by_path(&parts)
865 }
866
867 fn get_state_by_path(&self, path: &[&str]) -> Option<&StateDefinition> {
868 if path.is_empty() {
869 return None;
870 }
871
872 let mut current = self.states.get(path[0])?;
873 for segment in &path[1..] {
874 if let Some(ref sub_states) = current.states {
875 current = sub_states.get(*segment)?;
876 } else {
877 return None;
878 }
879 }
880 Some(current)
881 }
882
883 pub fn resolve_full_path(&self, current_path: &str, target: &str) -> String {
886 if target.starts_with('^') {
887 return target[1..].to_string();
888 }
889
890 if self.states.contains_key(target) {
891 return target.to_string();
892 }
893
894 if !current_path.is_empty() {
895 let parts: Vec<&str> = current_path.split('.').collect();
896 if parts.len() > 1 {
897 let parent_path = parts[..parts.len() - 1].join(".");
898 let potential = format!("{}.{}", parent_path, target);
899 if self.get_state(&potential).is_some() {
900 return potential;
901 }
902 }
903
904 let potential = format!("{}.{}", current_path, target);
905 if self.get_state(&potential).is_some() {
906 return potential;
907 }
908 }
909
910 target.to_string()
911 }
912
913 pub fn check_reachability(&self) -> Vec<String> {
915 let mut reachable: HashSet<String> = HashSet::new();
916 reachable.insert(self.initial.clone());
917
918 if let Some(ref fb) = self.fallback {
919 reachable.insert(fb.clone());
920 }
921 for gt in &self.global_transitions {
922 reachable.insert(self.normalize_target(>.to));
923 }
924
925 let mut queue: Vec<String> = reachable.iter().cloned().collect();
926 while let Some(state_path) = queue.pop() {
927 if let Some(def) = self.get_state(&state_path) {
928 for t in &def.transitions {
929 let target = self.resolve_full_path(&state_path, &t.to);
930 if reachable.insert(target.clone()) {
931 queue.push(target);
932 }
933 }
934 if let Some(ref timeout) = def.timeout_to {
935 let target = self.resolve_full_path(&state_path, timeout);
936 if reachable.insert(target.clone()) {
937 queue.push(target);
938 }
939 }
940 if let (Some(initial), Some(_sub)) = (&def.initial, &def.states) {
941 let sub_path = format!("{}.{}", state_path, initial);
942 if reachable.insert(sub_path.clone()) {
943 queue.push(sub_path);
944 }
945 }
946 }
947 }
948
949 let all_states = self.collect_all_state_paths(&self.states, &[]);
950 let mut warnings = Vec::new();
951 for state_path in &all_states {
952 if !reachable.contains(state_path) {
953 warnings.push(format!(
954 "State '{}' appears unreachable — no transitions lead to it",
955 state_path
956 ));
957 }
958 }
959 warnings
960 }
961
962 fn normalize_target(&self, target: &str) -> String {
963 if target.starts_with('^') {
964 target[1..].to_string()
965 } else {
966 target.to_string()
967 }
968 }
969
970 fn collect_all_state_paths(
971 &self,
972 states: &HashMap<String, StateDefinition>,
973 parent: &[String],
974 ) -> Vec<String> {
975 let mut paths = Vec::new();
976 for (name, def) in states {
977 let mut current: Vec<String> = parent.to_vec();
978 current.push(name.clone());
979 paths.push(current.join("."));
980 if let Some(ref sub) = def.states {
981 paths.extend(self.collect_all_state_paths(sub, ¤t));
982 }
983 }
984 paths
985 }
986}
987
988impl StateDefinition {
989 pub fn has_sub_states(&self) -> bool {
990 self.states.as_ref().map(|s| !s.is_empty()).unwrap_or(false)
991 }
992
993 pub fn get_effective_tools<'a>(
994 &'a self,
995 parent: Option<&'a StateDefinition>,
996 ) -> Option<Vec<&'a ToolRef>> {
997 match &self.tools {
998 Some(tools) => Some(tools.iter().collect()),
1000 None => {
1002 if !self.inherit_parent {
1003 return None;
1004 }
1005 parent
1006 .and_then(|p| p.tools.as_ref())
1007 .map(|t| t.iter().collect())
1008 }
1009 }
1010 }
1011
1012 pub fn get_effective_skills<'a>(
1013 &'a self,
1014 parent: Option<&'a StateDefinition>,
1015 ) -> Vec<&'a String> {
1016 if !self.inherit_parent || parent.is_none() {
1017 return self.skills.iter().collect();
1018 }
1019
1020 let parent = parent.unwrap();
1021 let mut skills: Vec<&'a String> = parent.skills.iter().collect();
1022 skills.extend(self.skills.iter());
1023 skills
1024 }
1025}
1026
1027#[cfg(test)]
1028mod tests {
1029 use super::*;
1030
1031 #[test]
1032 fn test_transition_timing_defaults_to_post_response() {
1033 let yaml = r#"
1034to: next
1035when: "ready"
1036"#;
1037 let transition: Transition = serde_yaml::from_str(yaml).unwrap();
1038 assert_eq!(transition.timing, TransitionTiming::PostResponse);
1039 assert!(!transition.requires_response);
1040 assert!(!transition.run_extractors);
1041 }
1042
1043 #[test]
1044 fn test_state_config_deserialize() {
1045 let yaml = r#"
1046initial: greeting
1047states:
1048 greeting:
1049 prompt: "Welcome!"
1050 transitions:
1051 - to: support
1052 when: "user needs help"
1053 auto: true
1054 support:
1055 prompt: "How can I help?"
1056 llm: fast
1057 tools:
1058 - search
1059"#;
1060 let config: StateConfig = serde_yaml::from_str(yaml).unwrap();
1061 assert_eq!(config.initial, "greeting");
1062 assert_eq!(config.states.len(), 2);
1063 assert!(config.validate().is_ok());
1064 }
1065
1066 #[test]
1067 fn test_prompt_mode_default() {
1068 let def = StateDefinition::default();
1069 assert_eq!(def.prompt_mode, PromptMode::Append);
1070 }
1071
1072 #[test]
1073 fn test_response_dependent_pre_response_transition_is_invalid() {
1074 let yaml = r#"
1075initial: greeting
1076states:
1077 greeting:
1078 transitions:
1079 - to: done
1080 when: "after answer"
1081 timing: pre_response
1082 requires_response: true
1083 done:
1084 prompt: "Done"
1085"#;
1086 let config: StateConfig = serde_yaml::from_str(yaml).unwrap();
1087 assert!(config.validate().is_err());
1088 }
1089
1090 #[test]
1091 fn test_parallel_transition_timing_accepts_response_independent_condition() {
1092 let yaml = r#"
1093initial: greeting
1094states:
1095 greeting:
1096 transitions:
1097 - to: done
1098 when: "ready"
1099 timing: parallel
1100 done:
1101 prompt: "Done"
1102"#;
1103 let config: StateConfig = serde_yaml::from_str(yaml).unwrap();
1104 assert!(config.validate().is_ok());
1105 }
1106
1107 #[test]
1108 fn test_parallel_transition_without_condition_is_invalid() {
1109 let yaml = r#"
1110initial: greeting
1111states:
1112 greeting:
1113 transitions:
1114 - to: done
1115 timing: parallel
1116 done:
1117 prompt: "Done"
1118"#;
1119 let config: StateConfig = serde_yaml::from_str(yaml).unwrap();
1120 let err = config.validate().unwrap_err();
1121 assert!(err.to_string().contains("parallel timing without"));
1122 }
1123
1124 #[test]
1125 fn test_pre_response_when_without_guard_or_intent_is_invalid() {
1126 let yaml = r#"
1127initial: greeting
1128states:
1129 greeting:
1130 transitions:
1131 - to: done
1132 when: "ready"
1133 timing: pre_response
1134 done:
1135 prompt: "Done"
1136"#;
1137 let config: StateConfig = serde_yaml::from_str(yaml).unwrap();
1138 assert!(config.validate().is_err());
1139 }
1140
1141 #[test]
1142 fn test_pre_response_with_when_text_is_invalid() {
1143 let yaml = r#"
1144initial: greeting
1145states:
1146 greeting:
1147 transitions:
1148 - to: done
1149 when: "ready"
1150 guard:
1151 context:
1152 ready:
1153 eq: true
1154 timing: pre_response
1155 done:
1156 prompt: "Done"
1157"#;
1158 let config: StateConfig = serde_yaml::from_str(yaml).unwrap();
1159 assert!(config.validate().is_err());
1160 }
1161
1162 #[test]
1163 fn test_invalid_initial_state() {
1164 let config = StateConfig {
1165 initial: "nonexistent".into(),
1166 states: HashMap::new(),
1167 global_transitions: vec![],
1168 fallback: None,
1169 max_no_transition: None,
1170 regenerate_on_transition: true,
1171 };
1172 assert!(config.validate().is_err());
1173 }
1174
1175 #[test]
1176 fn test_invalid_transition_target() {
1177 let mut states = HashMap::new();
1178 states.insert(
1179 "start".into(),
1180 StateDefinition {
1181 transitions: vec![Transition {
1182 to: "nonexistent".into(),
1183 when: "always".into(),
1184 guard: None,
1185 intent: None,
1186 auto: true,
1187 priority: 0,
1188 cooldown_turns: None,
1189 timing: TransitionTiming::PostResponse,
1190 requires_response: false,
1191 run_extractors: false,
1192 }],
1193 ..Default::default()
1194 },
1195 );
1196 let config = StateConfig {
1197 initial: "start".into(),
1198 states,
1199 global_transitions: vec![],
1200 fallback: None,
1201 max_no_transition: None,
1202 regenerate_on_transition: true,
1203 };
1204 assert!(config.validate().is_err());
1205 }
1206
1207 #[test]
1208 fn test_hierarchical_states() {
1209 let yaml = r#"
1210initial: problem_solving
1211states:
1212 problem_solving:
1213 initial: gathering_info
1214 prompt: "Solving customer problem"
1215 states:
1216 gathering_info:
1217 prompt: "Ask questions"
1218 transitions:
1219 - to: proposing_solution
1220 when: "understood"
1221 proposing_solution:
1222 prompt: "Offer solution"
1223 transitions:
1224 - to: ^closing
1225 when: "resolved"
1226 closing:
1227 prompt: "Thank you"
1228"#;
1229 let config: StateConfig = serde_yaml::from_str(yaml).unwrap();
1230 assert!(config.validate().is_ok());
1231 assert!(
1232 config
1233 .states
1234 .get("problem_solving")
1235 .unwrap()
1236 .has_sub_states()
1237 );
1238 }
1239
1240 #[test]
1241 fn test_tool_ref_simple() {
1242 let yaml = r#"
1243tools:
1244 - calculator
1245 - search
1246"#;
1247 #[derive(Deserialize)]
1248 struct Test {
1249 tools: Vec<ToolRef>,
1250 }
1251 let t: Test = serde_yaml::from_str(yaml).unwrap();
1252 assert_eq!(t.tools.len(), 2);
1253 assert_eq!(t.tools[0].id(), "calculator");
1254 }
1255
1256 #[test]
1257 fn test_tool_ref_conditional() {
1258 let yaml = r#"
1259tools:
1260 - calculator
1261 - id: admin_tool
1262 condition:
1263 context:
1264 user.role: "admin"
1265"#;
1266 #[derive(Deserialize)]
1267 struct Test {
1268 tools: Vec<ToolRef>,
1269 }
1270 let t: Test = serde_yaml::from_str(yaml).unwrap();
1271 assert_eq!(t.tools.len(), 2);
1272 assert_eq!(t.tools[1].id(), "admin_tool");
1273 assert!(t.tools[1].condition().is_some());
1274 }
1275
1276 #[test]
1277 fn test_transition_with_guard() {
1278 let yaml = r#"
1279to: next_state
1280when: "user wants to proceed"
1281guard: "{{ context.has_data }}"
1282auto: true
1283priority: 10
1284"#;
1285 let t: Transition = serde_yaml::from_str(yaml).unwrap();
1286 assert!(t.guard.is_some());
1287 assert_eq!(t.priority, 10);
1288 }
1289
1290 #[test]
1291 fn test_state_action() {
1292 let yaml = r#"
1293- tool: log_event
1294 args:
1295 event: "entered"
1296- skill: greeting_skill
1297- set_context:
1298 entered: true
1299"#;
1300 let actions: Vec<StateAction> = serde_yaml::from_str(yaml).unwrap();
1301 assert_eq!(actions.len(), 3);
1302 match &actions[0] {
1303 StateAction::Tool { tool, .. } => assert_eq!(tool, "log_event"),
1304 _ => panic!("Expected Tool action"),
1305 }
1306 match &actions[1] {
1307 StateAction::Skill { skill } => assert_eq!(skill, "greeting_skill"),
1308 _ => panic!("Expected Skill action"),
1309 }
1310 match &actions[2] {
1311 StateAction::SetContext { set_context } => {
1312 assert!(set_context.contains_key("entered"));
1313 }
1314 _ => panic!("Expected SetContext action"),
1315 }
1316 }
1317
1318 #[test]
1319 fn test_complex_tool_condition() {
1320 let yaml = r#"
1321id: refund_tool
1322condition:
1323 all:
1324 - context:
1325 user.verified: true
1326 - semantic:
1327 when: "user wants refund"
1328 threshold: 0.85
1329"#;
1330 let tool: ToolRef = serde_yaml::from_str(yaml).unwrap();
1331 assert_eq!(tool.id(), "refund_tool");
1332 match tool.condition().unwrap() {
1333 ToolCondition::All(conditions) => assert_eq!(conditions.len(), 2),
1334 _ => panic!("Expected All condition"),
1335 }
1336 }
1337
1338 #[test]
1339 fn test_state_get_path() {
1340 let yaml = r#"
1341initial: problem_solving
1342states:
1343 problem_solving:
1344 initial: gathering_info
1345 states:
1346 gathering_info:
1347 prompt: "Ask"
1348 proposing:
1349 prompt: "Propose"
1350 closing:
1351 prompt: "Done"
1352"#;
1353 let config: StateConfig = serde_yaml::from_str(yaml).unwrap();
1354 assert!(config.get_state("problem_solving").is_some());
1355 assert!(config.get_state("problem_solving.gathering_info").is_some());
1356 assert!(config.get_state("closing").is_some());
1357 assert!(config.get_state("nonexistent").is_none());
1358 }
1359
1360 #[test]
1361 fn test_resolve_full_path() {
1362 let yaml = r#"
1363initial: problem_solving
1364states:
1365 problem_solving:
1366 initial: gathering_info
1367 states:
1368 gathering_info:
1369 prompt: "Ask"
1370 proposing:
1371 prompt: "Propose"
1372 closing:
1373 prompt: "Done"
1374"#;
1375 let config: StateConfig = serde_yaml::from_str(yaml).unwrap();
1376
1377 assert_eq!(
1378 config.resolve_full_path("problem_solving.gathering_info", "proposing"),
1379 "problem_solving.proposing"
1380 );
1381 assert_eq!(
1382 config.resolve_full_path("problem_solving.gathering_info", "^closing"),
1383 "closing"
1384 );
1385 assert_eq!(
1386 config.resolve_full_path("problem_solving", "closing"),
1387 "closing"
1388 );
1389 }
1390
1391 #[test]
1392 fn test_inherit_parent() {
1393 let parent = StateDefinition {
1394 tools: Some(vec![ToolRef::Simple("parent_tool".into())]),
1395 skills: vec!["parent_skill".into()],
1396 ..Default::default()
1397 };
1398
1399 let child = StateDefinition {
1400 tools: Some(vec![ToolRef::Simple("child_tool".into())]),
1401 skills: vec!["child_skill".into()],
1402 inherit_parent: true,
1403 ..Default::default()
1404 };
1405
1406 let effective_tools = child.get_effective_tools(Some(&parent)).unwrap();
1407 assert_eq!(effective_tools.len(), 1); let effective_skills = child.get_effective_skills(Some(&parent));
1410 assert_eq!(effective_skills.len(), 2);
1411 }
1412
1413 #[test]
1414 fn test_no_inherit_parent() {
1415 let parent = StateDefinition {
1416 tools: Some(vec![ToolRef::Simple("parent_tool".into())]),
1417 ..Default::default()
1418 };
1419
1420 let child = StateDefinition {
1421 tools: Some(vec![ToolRef::Simple("child_tool".into())]),
1422 inherit_parent: false,
1423 ..Default::default()
1424 };
1425
1426 let effective_tools = child.get_effective_tools(Some(&parent)).unwrap();
1427 assert_eq!(effective_tools.len(), 1);
1428 assert_eq!(effective_tools[0].id(), "child_tool");
1429 }
1430
1431 #[test]
1432 fn test_tools_none_inherits() {
1433 let parent = StateDefinition {
1434 tools: Some(vec![ToolRef::Simple("parent_tool".into())]),
1435 ..Default::default()
1436 };
1437
1438 let child = StateDefinition {
1439 tools: None, inherit_parent: true,
1441 ..Default::default()
1442 };
1443
1444 let effective_tools = child.get_effective_tools(Some(&parent)).unwrap();
1445 assert_eq!(effective_tools.len(), 1);
1446 assert_eq!(effective_tools[0].id(), "parent_tool");
1447 }
1448
1449 #[test]
1450 fn test_tools_empty_means_no_tools() {
1451 let parent = StateDefinition {
1452 tools: Some(vec![ToolRef::Simple("parent_tool".into())]),
1453 ..Default::default()
1454 };
1455
1456 let child = StateDefinition {
1457 tools: Some(vec![]), inherit_parent: true,
1459 ..Default::default()
1460 };
1461
1462 let effective_tools = child.get_effective_tools(Some(&parent)).unwrap();
1463 assert!(effective_tools.is_empty());
1464 }
1465
1466 #[test]
1467 fn test_state_with_disambiguation_override() {
1468 let yaml = r#"
1469initial: greeting
1470states:
1471 greeting:
1472 prompt: "Hello"
1473 transitions:
1474 - to: payment
1475 when: "User wants to pay"
1476 payment:
1477 prompt: "Processing payment"
1478 disambiguation:
1479 threshold: 0.95
1480 require_confirmation: true
1481 required_clarity:
1482 - recipient
1483 - amount
1484"#;
1485 let config: StateConfig = serde_yaml::from_str(yaml).unwrap();
1486 let payment = config.get_state("payment").unwrap();
1487 let disambig = payment.disambiguation.as_ref().unwrap();
1488 assert_eq!(disambig.threshold, Some(0.95));
1489 assert!(disambig.require_confirmation);
1490 assert_eq!(disambig.required_clarity.len(), 2);
1491 assert!(disambig.required_clarity.contains(&"recipient".to_string()));
1492
1493 let greeting = config.get_state("greeting").unwrap();
1494 assert!(greeting.disambiguation.is_none());
1495 }
1496
1497 #[test]
1498 fn test_context_extractor_vec_deserialize() {
1499 let yaml = r#"
1500initial: a
1501states:
1502 a:
1503 extract:
1504 - key: user_email
1505 description: "The user's email address"
1506 - key: order_id
1507 llm_extract: "Extract the order ID"
1508 required: true
1509"#;
1510 let config: StateConfig = serde_yaml::from_str(yaml).unwrap();
1511 let state = config.get_state("a").unwrap();
1512 assert_eq!(state.extract.len(), 2);
1513 assert_eq!(state.extract[0].key, "user_email");
1514 assert_eq!(
1515 state.extract[0].description.as_deref(),
1516 Some("The user's email address")
1517 );
1518 assert!(!state.extract[0].required);
1519 assert_eq!(state.extract[0].llm, "router");
1520 assert_eq!(state.extract[1].key, "order_id");
1521 assert!(state.extract[1].required);
1522 assert!(state.extract[1].llm_extract.is_some());
1523 }
1524
1525 #[test]
1526 fn test_context_extractor_default_empty() {
1527 let yaml = r#"
1528initial: a
1529states:
1530 a:
1531 prompt: "Hello"
1532"#;
1533 let config: StateConfig = serde_yaml::from_str(yaml).unwrap();
1534 let state = config.get_state("a").unwrap();
1535 assert!(state.extract.is_empty());
1536 }
1537
1538 #[test]
1539 fn test_state_process_override_deserialize() {
1540 let yaml = r#"
1541initial: a
1542states:
1543 a:
1544 process:
1545 input:
1546 - type: normalize
1547 config:
1548 trim: true
1549"#;
1550 let config: StateConfig = serde_yaml::from_str(yaml).unwrap();
1551 let state = config.get_state("a").unwrap();
1552 assert!(state.process.is_some());
1553 assert_eq!(state.process.as_ref().unwrap().input.len(), 1);
1554 }
1555
1556 #[test]
1557 fn test_state_process_default_none() {
1558 let yaml = r#"
1559initial: a
1560states:
1561 a:
1562 prompt: "Hello"
1563"#;
1564 let config: StateConfig = serde_yaml::from_str(yaml).unwrap();
1565 let state = config.get_state("a").unwrap();
1566 assert!(state.process.is_none());
1567 }
1568}