use ai_agents_core::{AgentError, Result};
use ai_agents_disambiguation::StateDisambiguationOverride;
use ai_agents_process::ProcessConfig;
use ai_agents_reasoning::{ReasoningConfig, ReflectionConfig};
use serde::{Deserialize, Serialize};
use serde_json::Value;
use std::collections::{HashMap, HashSet};
8
9#[derive(Debug, Clone, Serialize, Deserialize, Default, PartialEq)]
10#[serde(rename_all = "lowercase")]
11pub enum PromptMode {
12 #[default]
13 Append,
14 Replace,
15 Prepend,
16}
17
18#[derive(Debug, Clone, Serialize, Deserialize)]
19pub struct StateConfig {
20 pub initial: String,
21 #[serde(default)]
22 pub states: HashMap<String, StateDefinition>,
23 #[serde(default)]
24 pub global_transitions: Vec<Transition>,
25 #[serde(default)]
26 pub fallback: Option<String>,
27 #[serde(default)]
28 pub max_no_transition: Option<u32>,
29
30 #[serde(default = "default_true")]
32 pub regenerate_on_transition: bool,
33}
34
35#[derive(Debug, Clone, Serialize, Deserialize, Default)]
36pub struct StateDefinition {
37 #[serde(default)]
38 pub prompt: Option<String>,
39
40 #[serde(default)]
41 pub prompt_mode: PromptMode,
42
43 #[serde(default)]
44 pub llm: Option<String>,
45
46 #[serde(default)]
47 pub skills: Vec<String>,
48
49 #[serde(default, skip_serializing_if = "Option::is_none")]
54 pub tools: Option<Vec<ToolRef>>,
55
56 #[serde(default)]
57 pub transitions: Vec<Transition>,
58
59 #[serde(default)]
60 pub max_turns: Option<u32>,
61
62 #[serde(default)]
63 pub timeout_to: Option<String>,
64
65 #[serde(default)]
66 pub initial: Option<String>,
67
68 #[serde(default)]
69 pub states: Option<HashMap<String, StateDefinition>>,
70
71 #[serde(default = "default_inherit_parent")]
72 pub inherit_parent: bool,
73
74 #[serde(default)]
75 pub on_enter: Vec<StateAction>,
76
77 #[serde(default)]
79 pub on_reenter: Vec<StateAction>,
80
81 #[serde(default)]
82 pub on_exit: Vec<StateAction>,
83
84 #[serde(default, skip_serializing_if = "Option::is_none")]
86 pub regenerate_on_enter: Option<bool>,
87
88 #[serde(default)]
90 pub extract: Vec<ContextExtractor>,
91
92 #[serde(default, skip_serializing_if = "Option::is_none")]
93 pub reasoning: Option<ReasoningConfig>,
94
95 #[serde(default, skip_serializing_if = "Option::is_none")]
96 pub reflection: Option<ReflectionConfig>,
97
98 #[serde(default, skip_serializing_if = "Option::is_none")]
99 pub disambiguation: Option<StateDisambiguationOverride>,
100
101 #[serde(default, skip_serializing_if = "Option::is_none")]
103 pub process: Option<ProcessConfig>,
104
105 #[serde(default, skip_serializing_if = "Option::is_none")]
107 pub delegate: Option<String>,
108
109 #[serde(default, skip_serializing_if = "Option::is_none")]
111 pub delegate_context: Option<DelegateContextMode>,
112
113 #[serde(default, skip_serializing_if = "Option::is_none")]
115 pub concurrent: Option<ConcurrentStateConfig>,
116
117 #[serde(default, skip_serializing_if = "Option::is_none")]
119 pub group_chat: Option<GroupChatStateConfig>,
120
121 #[serde(default, skip_serializing_if = "Option::is_none")]
123 pub pipeline: Option<PipelineStateConfig>,
124
125 #[serde(default, skip_serializing_if = "Option::is_none")]
127 pub handoff: Option<HandoffStateConfig>,
128}
129
/// Serde default for `StateDefinition::inherit_parent`: inheritance is on
/// unless the config turns it off.
fn default_inherit_parent() -> bool {
    true
}
133
/// Serde default helper that yields `true` (used by
/// `StateConfig::regenerate_on_transition`).
fn default_true() -> bool {
    true
}
137
/// Serde default for `ContextExtractor::llm`: the `"router"` LLM alias.
fn default_extractor_llm() -> String {
    "router".to_string()
}
141
142#[derive(Debug, Clone, Serialize, Deserialize)]
143#[serde(untagged)]
144pub enum ToolRef {
145 Simple(String),
146 Conditional {
147 id: String,
148 condition: ToolCondition,
149 },
150}
151
152impl ToolRef {
153 pub fn id(&self) -> &str {
154 match self {
155 ToolRef::Simple(id) => id,
156 ToolRef::Conditional { id, .. } => id,
157 }
158 }
159
160 pub fn condition(&self) -> Option<&ToolCondition> {
161 match self {
162 ToolRef::Simple(_) => None,
163 ToolRef::Conditional { condition, .. } => Some(condition),
164 }
165 }
166}
167
168#[derive(Debug, Clone, Serialize, Deserialize)]
169#[serde(rename_all = "snake_case")]
170pub enum ToolCondition {
171 Context(HashMap<String, ContextMatcher>),
172 State(StateMatcher),
173 AfterTool(String),
174 ToolResult {
175 tool: String,
176 result: HashMap<String, Value>,
177 },
178 Semantic {
179 when: String,
180 #[serde(default = "default_semantic_llm")]
181 llm: String,
182 #[serde(default = "default_threshold")]
183 threshold: f32,
184 },
185 Time(TimeMatcher),
186 All(Vec<ToolCondition>),
187 Any(Vec<ToolCondition>),
188 Not(Box<ToolCondition>),
189}
190
/// Serde default for `ToolCondition::Semantic::llm`: the `"router"` LLM alias.
fn default_semantic_llm() -> String {
    "router".to_string()
}
194
/// Serde default for `ToolCondition::Semantic::threshold`.
fn default_threshold() -> f32 {
    0.7
}
198
199#[derive(Debug, Clone, Serialize, Deserialize)]
200#[serde(untagged)]
201pub enum ContextMatcher {
202 Exists { exists: bool },
207 Compare(CompareOp),
208 Exact(Value),
209}
210
211#[derive(Debug, Clone, Serialize, Deserialize)]
212#[serde(rename_all = "snake_case")]
213pub enum CompareOp {
214 Eq(Value),
215 Neq(Value),
216 Gt(f64),
217 Gte(f64),
218 Lt(f64),
219 Lte(f64),
220 In(Vec<Value>),
221 Contains(String),
222}
223
224#[derive(Debug, Clone, Serialize, Deserialize, Default)]
225pub struct StateMatcher {
226 #[serde(default)]
227 pub name: Option<String>,
228 #[serde(default)]
229 pub turn_count: Option<CompareOp>,
230 #[serde(default)]
231 pub previous: Option<String>,
232}
233
234#[derive(Debug, Clone, Serialize, Deserialize, Default)]
235pub struct TimeMatcher {
236 #[serde(default)]
237 pub hours: Option<CompareOp>,
238 #[serde(default)]
239 pub day_of_week: Option<Vec<String>>,
240 #[serde(default)]
241 pub timezone: Option<String>,
242}
243
244#[derive(Debug, Clone, Serialize, Deserialize)]
245pub struct Transition {
246 pub to: String,
247 #[serde(default)]
248 pub when: String,
249 #[serde(default)]
250 pub guard: Option<TransitionGuard>,
251 #[serde(default, skip_serializing_if = "Option::is_none")]
253 pub intent: Option<String>,
254 #[serde(default = "default_auto")]
255 pub auto: bool,
256 #[serde(default)]
257 pub priority: u8,
258
259 #[serde(default, skip_serializing_if = "Option::is_none")]
261 pub cooldown_turns: Option<u32>,
262}
263
/// Serde default for `Transition::auto`.
fn default_auto() -> bool {
    true
}
267
268#[derive(Debug, Clone, Serialize, Deserialize)]
269#[serde(untagged)]
270pub enum TransitionGuard {
271 Expression(String),
272 Conditions(GuardConditions),
273}
274
275#[derive(Debug, Clone, Serialize, Deserialize)]
276#[serde(rename_all = "snake_case")]
277pub enum GuardConditions {
278 All(Vec<String>),
279 Any(Vec<String>),
280 Context(HashMap<String, ContextMatcher>),
281}
282
283#[derive(Debug, Clone, Serialize, Deserialize)]
284#[serde(untagged)]
285pub enum StateAction {
286 Tool {
287 tool: String,
288 #[serde(default)]
289 args: Option<Value>,
290 },
291 Skill {
292 skill: String,
293 },
294 Prompt {
295 prompt: String,
296 #[serde(default)]
297 llm: Option<String>,
298 #[serde(default)]
299 store_as: Option<String>,
300 },
301 SetContext {
302 set_context: HashMap<String, Value>,
303 },
304}
305
306#[derive(Debug, Clone, Serialize, Deserialize)]
308pub struct ContextExtractor {
309 pub key: String,
311
312 #[serde(default)]
314 pub description: Option<String>,
315
316 #[serde(default)]
318 pub llm_extract: Option<String>,
319
320 #[serde(default = "default_extractor_llm")]
322 pub llm: String,
323
324 #[serde(default)]
326 pub required: bool,
327}
328
329#[derive(Debug, Clone, Default, Serialize, Deserialize)]
335#[serde(rename_all = "snake_case")]
336pub enum DelegateContextMode {
337 #[default]
339 InputOnly,
340 Summary,
342 Full,
344}
345
346#[derive(Debug, Clone, Serialize, Deserialize)]
348pub struct ConcurrentStateConfig {
349 pub agents: Vec<ConcurrentAgentRef>,
351 #[serde(default, skip_serializing_if = "Option::is_none")]
353 pub input: Option<String>,
354 pub aggregation: AggregationConfig,
356 #[serde(default, skip_serializing_if = "Option::is_none")]
358 pub min_required: Option<usize>,
359 #[serde(default)]
361 pub on_partial_failure: PartialFailureAction,
362 #[serde(default, skip_serializing_if = "Option::is_none")]
364 pub timeout_ms: Option<u64>,
365 #[serde(default, skip_serializing_if = "Option::is_none")]
367 pub context_mode: Option<DelegateContextMode>,
368}
369
370#[derive(Debug, Clone, Serialize, Deserialize)]
372#[serde(untagged)]
373pub enum ConcurrentAgentRef {
374 Id(String),
375 Weighted { id: String, weight: f64 },
376}
377
378impl ConcurrentAgentRef {
379 pub fn id(&self) -> &str {
380 match self {
381 Self::Id(id) => id,
382 Self::Weighted { id, .. } => id,
383 }
384 }
385
386 pub fn weight(&self) -> f64 {
387 match self {
388 Self::Id(_) => 1.0,
389 Self::Weighted { weight, .. } => *weight,
390 }
391 }
392}
393
394#[derive(Debug, Clone, Serialize, Deserialize)]
396pub struct AggregationConfig {
397 pub strategy: AggregationStrategy,
399 #[serde(default, skip_serializing_if = "Option::is_none")]
401 pub synthesizer_llm: Option<String>,
402 #[serde(default, skip_serializing_if = "Option::is_none")]
404 pub synthesizer_prompt: Option<String>,
405 #[serde(default, skip_serializing_if = "Option::is_none")]
407 pub vote: Option<VoteConfig>,
408}
409
410#[derive(Debug, Clone, Serialize, Deserialize)]
411#[serde(rename_all = "snake_case")]
412pub enum AggregationStrategy {
413 Voting,
414 LlmSynthesis,
415 FirstWins,
416 All,
417}
418
419#[derive(Debug, Clone, Serialize, Deserialize)]
421pub struct VoteConfig {
422 #[serde(default)]
423 pub method: VoteMethod,
424 #[serde(default)]
425 pub tiebreaker: TiebreakerStrategy,
426 #[serde(default, skip_serializing_if = "Option::is_none")]
428 pub vote_prompt: Option<String>,
429}
430
431#[derive(Debug, Clone, Default, Serialize, Deserialize)]
432#[serde(rename_all = "snake_case")]
433pub enum VoteMethod {
434 #[default]
435 Majority,
436 Weighted,
437 Unanimous,
438}
439
440#[derive(Debug, Clone, Default, Serialize, Deserialize)]
441#[serde(rename_all = "snake_case")]
442pub enum TiebreakerStrategy {
443 #[default]
444 First,
445 Random,
446 RouterDecides,
447}
448
449#[derive(Debug, Clone, Default, Serialize, Deserialize)]
450#[serde(rename_all = "snake_case")]
451pub enum PartialFailureAction {
452 #[default]
453 ProceedWithAvailable,
454 Abort,
455}
456
457#[derive(Debug, Clone, Serialize, Deserialize)]
459pub struct GroupChatStateConfig {
460 pub participants: Vec<ChatParticipant>,
462 #[serde(default)]
464 pub style: ChatStyle,
465 #[serde(default = "default_max_rounds")]
467 pub max_rounds: u32,
468 #[serde(default, skip_serializing_if = "Option::is_none")]
470 pub manager: Option<ChatManagerConfig>,
471 #[serde(default)]
473 pub termination: TerminationConfig,
474 #[serde(default, skip_serializing_if = "Option::is_none")]
476 pub debate: Option<DebateStyleConfig>,
477 #[serde(default, skip_serializing_if = "Option::is_none")]
479 pub maker_checker: Option<MakerCheckerConfig>,
480 #[serde(default, skip_serializing_if = "Option::is_none")]
482 pub timeout_ms: Option<u64>,
483 #[serde(default, skip_serializing_if = "Option::is_none")]
486 pub input: Option<String>,
487 #[serde(default, skip_serializing_if = "Option::is_none")]
489 pub context_mode: Option<DelegateContextMode>,
490}
491
492#[derive(Debug, Clone, Serialize, Deserialize)]
494pub struct ChatParticipant {
495 pub id: String,
497 #[serde(default, skip_serializing_if = "Option::is_none")]
499 pub role: Option<String>,
500}
501
502#[derive(Debug, Clone, Default, Serialize, Deserialize)]
503#[serde(rename_all = "snake_case")]
504pub enum ChatStyle {
505 #[default]
506 Brainstorm,
507 Debate,
508 MakerChecker,
509 Consensus,
510}
511
512#[derive(Debug, Clone, Serialize, Deserialize)]
514pub struct ChatManagerConfig {
515 #[serde(default, skip_serializing_if = "Option::is_none")]
517 pub agent: Option<String>,
518 #[serde(default, skip_serializing_if = "Option::is_none")]
520 pub method: Option<TurnMethod>,
521}
522
523#[derive(Debug, Clone, Serialize, Deserialize)]
524#[serde(rename_all = "snake_case")]
525pub enum TurnMethod {
526 RoundRobin,
527 Random,
528 LlmDirected,
529}
530
531#[derive(Debug, Clone, Serialize, Deserialize)]
533pub struct TerminationConfig {
534 #[serde(default)]
535 pub method: TerminationMethod,
536 #[serde(default = "default_stall_rounds")]
537 pub max_stall_rounds: u32,
538}
539
540impl Default for TerminationConfig {
541 fn default() -> Self {
542 Self {
543 method: TerminationMethod::default(),
544 max_stall_rounds: default_stall_rounds(),
545 }
546 }
547}
548
549#[derive(Debug, Clone, Default, Serialize, Deserialize)]
550#[serde(rename_all = "snake_case")]
551pub enum TerminationMethod {
552 #[default]
553 ManagerDecides,
554 MaxRounds,
555 ConsensusReached,
556}
557
558#[derive(Debug, Clone, Serialize, Deserialize)]
560pub struct DebateStyleConfig {
561 #[serde(default = "default_debate_rounds")]
562 pub rounds: u32,
563 pub synthesizer: String,
565}
566
567#[derive(Debug, Clone, Serialize, Deserialize)]
569pub struct MakerCheckerConfig {
570 #[serde(default = "default_maker_checker_iterations")]
571 pub max_iterations: u32,
572 pub acceptance_criteria: String,
574 #[serde(default)]
575 pub on_max_iterations: MaxIterationsAction,
576}
577
578#[derive(Debug, Clone, Default, Serialize, Deserialize)]
579#[serde(rename_all = "snake_case")]
580pub enum MaxIterationsAction {
581 #[default]
582 AcceptLast,
583 Escalate,
584 Fail,
585}
586
/// Serde default for `GroupChatStateConfig::max_rounds`.
fn default_max_rounds() -> u32 {
    5
}
/// Serde default for `TerminationConfig::max_stall_rounds`.
fn default_stall_rounds() -> u32 {
    2
}
/// Serde default for `DebateStyleConfig::rounds`.
fn default_debate_rounds() -> u32 {
    3
}
/// Serde default for `MakerCheckerConfig::max_iterations`.
fn default_maker_checker_iterations() -> u32 {
    3
}
599
600#[derive(Debug, Clone, Serialize, Deserialize)]
602pub struct PipelineStateConfig {
603 pub stages: Vec<PipelineStageEntry>,
604
605 #[serde(default, skip_serializing_if = "Option::is_none")]
606 pub timeout_ms: Option<u64>,
607 #[serde(default, skip_serializing_if = "Option::is_none")]
609 pub context_mode: Option<DelegateContextMode>,
610}
611
612#[derive(Debug, Clone, Serialize, Deserialize)]
614#[serde(untagged)]
615pub enum PipelineStageEntry {
616 Id(String),
618 Config {
620 id: String,
621 #[serde(default, skip_serializing_if = "Option::is_none")]
622 input: Option<String>,
623 },
624}
625
626impl PipelineStageEntry {
627 pub fn id(&self) -> &str {
628 match self {
629 Self::Id(id) => id,
630 Self::Config { id, .. } => id,
631 }
632 }
633
634 pub fn input(&self) -> Option<&str> {
635 match self {
636 Self::Id(_) => None,
637 Self::Config { input, .. } => input.as_deref(),
638 }
639 }
640}
641
642#[derive(Debug, Clone, Serialize, Deserialize)]
644pub struct HandoffStateConfig {
645 pub initial_agent: String,
646 pub available_agents: Vec<String>,
647
648 #[serde(default = "default_max_handoffs")]
649 pub max_handoffs: u32,
650
651 #[serde(default, skip_serializing_if = "Option::is_none")]
654 pub input: Option<String>,
655 #[serde(default, skip_serializing_if = "Option::is_none")]
657 pub context_mode: Option<DelegateContextMode>,
658}
659
/// Serde default for `HandoffStateConfig::max_handoffs`.
fn default_max_handoffs() -> u32 {
    5
}
663
664impl StateConfig {
665 pub fn validate(&self) -> Result<()> {
666 if self.initial.is_empty() {
667 return Err(AgentError::InvalidSpec(
668 "State machine initial state cannot be empty".into(),
669 ));
670 }
671 if !self.states.contains_key(&self.initial) {
672 return Err(AgentError::InvalidSpec(format!(
673 "Initial state '{}' not found in states",
674 self.initial
675 )));
676 }
677 self.validate_states(&self.states, &[])?;
678
679 for warning in self.check_reachability() {
681 tracing::warn!("{}", warning);
682 }
683
684 Ok(())
685 }
686
687 fn validate_states(
688 &self,
689 states: &HashMap<String, StateDefinition>,
690 parent_path: &[String],
691 ) -> Result<()> {
692 for (name, def) in states {
693 let current_path: Vec<String> = parent_path
694 .iter()
695 .cloned()
696 .chain(std::iter::once(name.clone()))
697 .collect();
698
699 for transition in &def.transitions {
700 if !self.is_valid_transition_target(&transition.to, ¤t_path, states) {
701 return Err(AgentError::InvalidSpec(format!(
702 "State '{}' has transition to unknown state '{}'",
703 current_path.join("."),
704 transition.to
705 )));
706 }
707 }
708
709 if let Some(ref timeout_state) = def.timeout_to {
710 if !self.is_valid_transition_target(timeout_state, ¤t_path, states) {
711 return Err(AgentError::InvalidSpec(format!(
712 "State '{}' has timeout_to unknown state '{}'",
713 current_path.join("."),
714 timeout_state
715 )));
716 }
717 }
718
719 if let Some(ref sub_states) = def.states {
720 if let Some(ref initial) = def.initial {
721 if !sub_states.contains_key(initial) {
722 return Err(AgentError::InvalidSpec(format!(
723 "State '{}' has initial sub-state '{}' that doesn't exist",
724 current_path.join("."),
725 initial
726 )));
727 }
728 }
729 self.validate_states(sub_states, ¤t_path)?;
730 }
731 }
732 Ok(())
733 }
734
735 fn is_valid_transition_target(
736 &self,
737 target: &str,
738 current_path: &[String],
739 states: &HashMap<String, StateDefinition>,
740 ) -> bool {
741 if target.starts_with('^') {
742 let target_name = &target[1..];
743 return self.states.contains_key(target_name);
744 }
745
746 if states.contains_key(target) {
747 return true;
748 }
749
750 if current_path.len() > 1 {
751 let parent_path = ¤t_path[..current_path.len() - 1];
752 if let Some(parent_states) = self.get_states_at_path(parent_path) {
753 if parent_states.contains_key(target) {
754 return true;
755 }
756 }
757 }
758
759 self.states.contains_key(target)
760 }
761
762 fn get_states_at_path(&self, path: &[String]) -> Option<&HashMap<String, StateDefinition>> {
763 let mut current = &self.states;
764 for segment in path {
765 if let Some(def) = current.get(segment) {
766 if let Some(ref sub_states) = def.states {
767 current = sub_states;
768 } else {
769 return None;
770 }
771 } else {
772 return None;
773 }
774 }
775 Some(current)
776 }
777
778 pub fn get_state(&self, path: &str) -> Option<&StateDefinition> {
779 let parts: Vec<&str> = path.split('.').collect();
780 self.get_state_by_path(&parts)
781 }
782
783 fn get_state_by_path(&self, path: &[&str]) -> Option<&StateDefinition> {
784 if path.is_empty() {
785 return None;
786 }
787
788 let mut current = self.states.get(path[0])?;
789 for segment in &path[1..] {
790 if let Some(ref sub_states) = current.states {
791 current = sub_states.get(*segment)?;
792 } else {
793 return None;
794 }
795 }
796 Some(current)
797 }
798
799 pub fn resolve_full_path(&self, current_path: &str, target: &str) -> String {
802 if target.starts_with('^') {
803 return target[1..].to_string();
804 }
805
806 if self.states.contains_key(target) {
807 return target.to_string();
808 }
809
810 if !current_path.is_empty() {
811 let parts: Vec<&str> = current_path.split('.').collect();
812 if parts.len() > 1 {
813 let parent_path = parts[..parts.len() - 1].join(".");
814 let potential = format!("{}.{}", parent_path, target);
815 if self.get_state(&potential).is_some() {
816 return potential;
817 }
818 }
819
820 let potential = format!("{}.{}", current_path, target);
821 if self.get_state(&potential).is_some() {
822 return potential;
823 }
824 }
825
826 target.to_string()
827 }
828
829 pub fn check_reachability(&self) -> Vec<String> {
831 let mut reachable: HashSet<String> = HashSet::new();
832 reachable.insert(self.initial.clone());
833
834 if let Some(ref fb) = self.fallback {
835 reachable.insert(fb.clone());
836 }
837 for gt in &self.global_transitions {
838 reachable.insert(self.normalize_target(>.to));
839 }
840
841 let mut queue: Vec<String> = reachable.iter().cloned().collect();
842 while let Some(state_path) = queue.pop() {
843 if let Some(def) = self.get_state(&state_path) {
844 for t in &def.transitions {
845 let target = self.resolve_full_path(&state_path, &t.to);
846 if reachable.insert(target.clone()) {
847 queue.push(target);
848 }
849 }
850 if let Some(ref timeout) = def.timeout_to {
851 let target = self.resolve_full_path(&state_path, timeout);
852 if reachable.insert(target.clone()) {
853 queue.push(target);
854 }
855 }
856 if let (Some(initial), Some(_sub)) = (&def.initial, &def.states) {
857 let sub_path = format!("{}.{}", state_path, initial);
858 if reachable.insert(sub_path.clone()) {
859 queue.push(sub_path);
860 }
861 }
862 }
863 }
864
865 let all_states = self.collect_all_state_paths(&self.states, &[]);
866 let mut warnings = Vec::new();
867 for state_path in &all_states {
868 if !reachable.contains(state_path) {
869 warnings.push(format!(
870 "State '{}' appears unreachable — no transitions lead to it",
871 state_path
872 ));
873 }
874 }
875 warnings
876 }
877
878 fn normalize_target(&self, target: &str) -> String {
879 if target.starts_with('^') {
880 target[1..].to_string()
881 } else {
882 target.to_string()
883 }
884 }
885
886 fn collect_all_state_paths(
887 &self,
888 states: &HashMap<String, StateDefinition>,
889 parent: &[String],
890 ) -> Vec<String> {
891 let mut paths = Vec::new();
892 for (name, def) in states {
893 let mut current: Vec<String> = parent.to_vec();
894 current.push(name.clone());
895 paths.push(current.join("."));
896 if let Some(ref sub) = def.states {
897 paths.extend(self.collect_all_state_paths(sub, ¤t));
898 }
899 }
900 paths
901 }
902}
903
904impl StateDefinition {
905 pub fn has_sub_states(&self) -> bool {
906 self.states.as_ref().map(|s| !s.is_empty()).unwrap_or(false)
907 }
908
909 pub fn get_effective_tools<'a>(
910 &'a self,
911 parent: Option<&'a StateDefinition>,
912 ) -> Option<Vec<&'a ToolRef>> {
913 match &self.tools {
914 Some(tools) => Some(tools.iter().collect()),
916 None => {
918 if !self.inherit_parent {
919 return None;
920 }
921 parent
922 .and_then(|p| p.tools.as_ref())
923 .map(|t| t.iter().collect())
924 }
925 }
926 }
927
928 pub fn get_effective_skills<'a>(
929 &'a self,
930 parent: Option<&'a StateDefinition>,
931 ) -> Vec<&'a String> {
932 if !self.inherit_parent || parent.is_none() {
933 return self.skills.iter().collect();
934 }
935
936 let parent = parent.unwrap();
937 let mut skills: Vec<&'a String> = parent.skills.iter().collect();
938 skills.extend(self.skills.iter());
939 skills
940 }
941}
942
// Unit tests for deserialization, validation, path resolution and
// tool/skill inheritance.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_state_config_deserialize() {
        let yaml = r#"
initial: greeting
states:
  greeting:
    prompt: "Welcome!"
    transitions:
      - to: support
        when: "user needs help"
        auto: true
  support:
    prompt: "How can I help?"
    llm: fast
    tools:
      - search
"#;
        let config: StateConfig = serde_yaml::from_str(yaml).unwrap();
        assert_eq!(config.initial, "greeting");
        assert_eq!(config.states.len(), 2);
        assert!(config.validate().is_ok());
    }

    #[test]
    fn test_prompt_mode_default() {
        let def = StateDefinition::default();
        assert_eq!(def.prompt_mode, PromptMode::Append);
    }

    #[test]
    fn test_invalid_initial_state() {
        let config = StateConfig {
            initial: "nonexistent".into(),
            states: HashMap::new(),
            global_transitions: vec![],
            fallback: None,
            max_no_transition: None,
            regenerate_on_transition: true,
        };
        assert!(config.validate().is_err());
    }

    #[test]
    fn test_invalid_transition_target() {
        let mut states = HashMap::new();
        states.insert(
            "start".into(),
            StateDefinition {
                transitions: vec![Transition {
                    to: "nonexistent".into(),
                    when: "always".into(),
                    guard: None,
                    intent: None,
                    auto: true,
                    priority: 0,
                    cooldown_turns: None,
                }],
                ..Default::default()
            },
        );
        let config = StateConfig {
            initial: "start".into(),
            states,
            global_transitions: vec![],
            fallback: None,
            max_no_transition: None,
            regenerate_on_transition: true,
        };
        assert!(config.validate().is_err());
    }

    #[test]
    fn test_hierarchical_states() {
        let yaml = r#"
initial: problem_solving
states:
  problem_solving:
    initial: gathering_info
    prompt: "Solving customer problem"
    states:
      gathering_info:
        prompt: "Ask questions"
        transitions:
          - to: proposing_solution
            when: "understood"
      proposing_solution:
        prompt: "Offer solution"
        transitions:
          - to: ^closing
            when: "resolved"
  closing:
    prompt: "Thank you"
"#;
        let config: StateConfig = serde_yaml::from_str(yaml).unwrap();
        assert!(config.validate().is_ok());
        assert!(
            config
                .states
                .get("problem_solving")
                .unwrap()
                .has_sub_states()
        );
    }

    #[test]
    fn test_tool_ref_simple() {
        let yaml = r#"
tools:
  - calculator
  - search
"#;
        #[derive(Deserialize)]
        struct Test {
            tools: Vec<ToolRef>,
        }
        let t: Test = serde_yaml::from_str(yaml).unwrap();
        assert_eq!(t.tools.len(), 2);
        assert_eq!(t.tools[0].id(), "calculator");
    }

    #[test]
    fn test_tool_ref_conditional() {
        let yaml = r#"
tools:
  - calculator
  - id: admin_tool
    condition:
      context:
        user.role: "admin"
"#;
        #[derive(Deserialize)]
        struct Test {
            tools: Vec<ToolRef>,
        }
        let t: Test = serde_yaml::from_str(yaml).unwrap();
        assert_eq!(t.tools.len(), 2);
        assert_eq!(t.tools[1].id(), "admin_tool");
        assert!(t.tools[1].condition().is_some());
    }

    #[test]
    fn test_transition_with_guard() {
        let yaml = r#"
to: next_state
when: "user wants to proceed"
guard: "{{ context.has_data }}"
auto: true
priority: 10
"#;
        let t: Transition = serde_yaml::from_str(yaml).unwrap();
        assert!(t.guard.is_some());
        assert_eq!(t.priority, 10);
    }

    #[test]
    fn test_state_action() {
        let yaml = r#"
- tool: log_event
  args:
    event: "entered"
- skill: greeting_skill
- set_context:
    entered: true
"#;
        let actions: Vec<StateAction> = serde_yaml::from_str(yaml).unwrap();
        assert_eq!(actions.len(), 3);
        match &actions[0] {
            StateAction::Tool { tool, .. } => assert_eq!(tool, "log_event"),
            _ => panic!("Expected Tool action"),
        }
        match &actions[1] {
            StateAction::Skill { skill } => assert_eq!(skill, "greeting_skill"),
            _ => panic!("Expected Skill action"),
        }
        match &actions[2] {
            StateAction::SetContext { set_context } => {
                assert!(set_context.contains_key("entered"));
            }
            _ => panic!("Expected SetContext action"),
        }
    }

    #[test]
    fn test_complex_tool_condition() {
        let yaml = r#"
id: refund_tool
condition:
  all:
    - context:
        user.verified: true
    - semantic:
        when: "user wants refund"
        threshold: 0.85
"#;
        let tool: ToolRef = serde_yaml::from_str(yaml).unwrap();
        assert_eq!(tool.id(), "refund_tool");
        match tool.condition().unwrap() {
            ToolCondition::All(conditions) => assert_eq!(conditions.len(), 2),
            _ => panic!("Expected All condition"),
        }
    }

    #[test]
    fn test_state_get_path() {
        let yaml = r#"
initial: problem_solving
states:
  problem_solving:
    initial: gathering_info
    states:
      gathering_info:
        prompt: "Ask"
      proposing:
        prompt: "Propose"
  closing:
    prompt: "Done"
"#;
        let config: StateConfig = serde_yaml::from_str(yaml).unwrap();
        assert!(config.get_state("problem_solving").is_some());
        assert!(config.get_state("problem_solving.gathering_info").is_some());
        assert!(config.get_state("closing").is_some());
        assert!(config.get_state("nonexistent").is_none());
    }

    #[test]
    fn test_resolve_full_path() {
        let yaml = r#"
initial: problem_solving
states:
  problem_solving:
    initial: gathering_info
    states:
      gathering_info:
        prompt: "Ask"
      proposing:
        prompt: "Propose"
  closing:
    prompt: "Done"
"#;
        let config: StateConfig = serde_yaml::from_str(yaml).unwrap();

        // Sibling inside the same parent.
        assert_eq!(
            config.resolve_full_path("problem_solving.gathering_info", "proposing"),
            "problem_solving.proposing"
        );
        // `^` prefix addresses the top level.
        assert_eq!(
            config.resolve_full_path("problem_solving.gathering_info", "^closing"),
            "closing"
        );
        // Plain top-level name.
        assert_eq!(
            config.resolve_full_path("problem_solving", "closing"),
            "closing"
        );
    }

    #[test]
    fn test_inherit_parent() {
        let parent = StateDefinition {
            tools: Some(vec![ToolRef::Simple("parent_tool".into())]),
            skills: vec!["parent_skill".into()],
            ..Default::default()
        };

        let child = StateDefinition {
            tools: Some(vec![ToolRef::Simple("child_tool".into())]),
            skills: vec!["child_skill".into()],
            inherit_parent: true,
            ..Default::default()
        };

        // An explicit tools list replaces the parent's rather than merging.
        let effective_tools = child.get_effective_tools(Some(&parent)).unwrap();
        assert_eq!(effective_tools.len(), 1);

        // Skills are merged: parent's followed by the child's.
        let effective_skills = child.get_effective_skills(Some(&parent));
        assert_eq!(effective_skills.len(), 2);
    }

    #[test]
    fn test_no_inherit_parent() {
        let parent = StateDefinition {
            tools: Some(vec![ToolRef::Simple("parent_tool".into())]),
            ..Default::default()
        };

        let child = StateDefinition {
            tools: Some(vec![ToolRef::Simple("child_tool".into())]),
            inherit_parent: false,
            ..Default::default()
        };

        let effective_tools = child.get_effective_tools(Some(&parent)).unwrap();
        assert_eq!(effective_tools.len(), 1);
        assert_eq!(effective_tools[0].id(), "child_tool");
    }

    #[test]
    fn test_tools_none_inherits() {
        let parent = StateDefinition {
            tools: Some(vec![ToolRef::Simple("parent_tool".into())]),
            ..Default::default()
        };

        let child = StateDefinition {
            // No tools declared: fall back to the parent's list.
            tools: None,
            inherit_parent: true,
            ..Default::default()
        };

        let effective_tools = child.get_effective_tools(Some(&parent)).unwrap();
        assert_eq!(effective_tools.len(), 1);
        assert_eq!(effective_tools[0].id(), "parent_tool");
    }

    #[test]
    fn test_tools_empty_means_no_tools() {
        let parent = StateDefinition {
            tools: Some(vec![ToolRef::Simple("parent_tool".into())]),
            ..Default::default()
        };

        let child = StateDefinition {
            // An explicit empty list means "no tools", not "inherit".
            tools: Some(vec![]),
            inherit_parent: true,
            ..Default::default()
        };

        let effective_tools = child.get_effective_tools(Some(&parent)).unwrap();
        assert!(effective_tools.is_empty());
    }

    #[test]
    fn test_state_with_disambiguation_override() {
        let yaml = r#"
initial: greeting
states:
  greeting:
    prompt: "Hello"
    transitions:
      - to: payment
        when: "User wants to pay"
  payment:
    prompt: "Processing payment"
    disambiguation:
      threshold: 0.95
      require_confirmation: true
      required_clarity:
        - recipient
        - amount
"#;
        let config: StateConfig = serde_yaml::from_str(yaml).unwrap();
        let payment = config.get_state("payment").unwrap();
        let disambig = payment.disambiguation.as_ref().unwrap();
        assert_eq!(disambig.threshold, Some(0.95));
        assert!(disambig.require_confirmation);
        assert_eq!(disambig.required_clarity.len(), 2);
        assert!(disambig.required_clarity.contains(&"recipient".to_string()));

        let greeting = config.get_state("greeting").unwrap();
        assert!(greeting.disambiguation.is_none());
    }

    #[test]
    fn test_context_extractor_vec_deserialize() {
        let yaml = r#"
initial: a
states:
  a:
    extract:
      - key: user_email
        description: "The user's email address"
      - key: order_id
        llm_extract: "Extract the order ID"
        required: true
"#;
        let config: StateConfig = serde_yaml::from_str(yaml).unwrap();
        let state = config.get_state("a").unwrap();
        assert_eq!(state.extract.len(), 2);
        assert_eq!(state.extract[0].key, "user_email");
        assert_eq!(
            state.extract[0].description.as_deref(),
            Some("The user's email address")
        );
        assert!(!state.extract[0].required);
        assert_eq!(state.extract[0].llm, "router");
        assert_eq!(state.extract[1].key, "order_id");
        assert!(state.extract[1].required);
        assert!(state.extract[1].llm_extract.is_some());
    }

    #[test]
    fn test_context_extractor_default_empty() {
        let yaml = r#"
initial: a
states:
  a:
    prompt: "Hello"
"#;
        let config: StateConfig = serde_yaml::from_str(yaml).unwrap();
        let state = config.get_state("a").unwrap();
        assert!(state.extract.is_empty());
    }

    #[test]
    fn test_state_process_override_deserialize() {
        // NOTE(review): the nesting of `config` under the list item was
        // reconstructed after indentation was lost — confirm against
        // `ProcessConfig`.
        let yaml = r#"
initial: a
states:
  a:
    process:
      input:
        - type: normalize
          config:
            trim: true
"#;
        let config: StateConfig = serde_yaml::from_str(yaml).unwrap();
        let state = config.get_state("a").unwrap();
        assert!(state.process.is_some());
        assert_eq!(state.process.as_ref().unwrap().input.len(), 1);
    }

    #[test]
    fn test_state_process_default_none() {
        let yaml = r#"
initial: a
states:
  a:
    prompt: "Hello"
"#;
        let config: StateConfig = serde_yaml::from_str(yaml).unwrap();
        let state = config.get_state("a").unwrap();
        assert!(state.process.is_none());
    }
}