Skip to main content

ai_agents_state/
config.rs

1use ai_agents_core::{AgentError, Result};
2use ai_agents_disambiguation::StateDisambiguationOverride;
3use ai_agents_process::ProcessConfig;
4use ai_agents_reasoning::{ReasoningConfig, ReflectionConfig};
5use serde::{Deserialize, Serialize};
6use serde_json::Value;
7use std::collections::{HashMap, HashSet};
8
/// How a state's `prompt` is combined with the prompt that would otherwise apply.
///
/// Serialized in lowercase: `append` (the default), `replace`, or `prepend`.
/// NOTE(review): the actual prompt composition lives outside this file; the
/// per-variant notes below follow the variant names — confirm against the
/// prompt builder.
#[derive(Debug, Clone, Serialize, Deserialize, Default, PartialEq)]
#[serde(rename_all = "lowercase")]
pub enum PromptMode {
    /// Presumably: state prompt is added after the base prompt. Default.
    #[default]
    Append,
    /// Presumably: state prompt is used instead of the base prompt.
    Replace,
    /// Presumably: state prompt is added before the base prompt.
    Prepend,
}
17
/// Top-level state machine configuration.
///
/// `states` holds the top-level states (each of which may nest further
/// sub-states). `StateConfig::validate` requires `initial` to be a key of
/// `states` and checks every transition / `timeout_to` target.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct StateConfig {
    /// Name of the top-level state the machine starts in (required, non-empty).
    pub initial: String,
    /// Top-level states keyed by name.
    #[serde(default)]
    pub states: HashMap<String, StateDefinition>,
    /// Transitions not attached to a particular state. Their targets (with a
    /// leading `^` stripped) are treated as reachable by `check_reachability`.
    /// NOTE(review): presumably evaluated from every state — confirm in the runtime.
    #[serde(default)]
    pub global_transitions: Vec<Transition>,
    /// Optional fallback state name; treated as reachable by `check_reachability`.
    #[serde(default)]
    pub fallback: Option<String>,
    /// NOTE(review): presumably a limit on consecutive turns without a state
    /// transition — confirm where it is consumed.
    #[serde(default)]
    pub max_no_transition: Option<u32>,

    /// Whether to re-generate a response after state transitions (default: true).
    #[serde(default = "default_true")]
    pub regenerate_on_transition: bool,
}
34
35#[derive(Debug, Clone, Serialize, Deserialize, Default)]
36pub struct StateDefinition {
37    #[serde(default)]
38    pub prompt: Option<String>,
39
40    #[serde(default)]
41    pub prompt_mode: PromptMode,
42
43    #[serde(default)]
44    pub llm: Option<String>,
45
46    #[serde(default)]
47    pub skills: Vec<String>,
48
49    /// Tool availability for this state.
50    /// - `None` (omitted in YAML): inherit from parent or agent-level tools
51    /// - `Some([])` (`tools: []` in YAML): explicitly no tools available
52    /// - `Some([...])`: only these tools available
53    #[serde(default, skip_serializing_if = "Option::is_none")]
54    pub tools: Option<Vec<ToolRef>>,
55
56    #[serde(default)]
57    pub transitions: Vec<Transition>,
58
59    #[serde(default)]
60    pub max_turns: Option<u32>,
61
62    #[serde(default)]
63    pub timeout_to: Option<String>,
64
65    #[serde(default)]
66    pub initial: Option<String>,
67
68    #[serde(default)]
69    pub states: Option<HashMap<String, StateDefinition>>,
70
71    #[serde(default = "default_inherit_parent")]
72    pub inherit_parent: bool,
73
74    #[serde(default)]
75    pub on_enter: Vec<StateAction>,
76
77    /// Actions on re-entering a previously visited state. Falls back to on_enter if empty.
78    #[serde(default)]
79    pub on_reenter: Vec<StateAction>,
80
81    #[serde(default)]
82    pub on_exit: Vec<StateAction>,
83
84    /// Per-state override: skip re-generation on entering this state.
85    #[serde(default, skip_serializing_if = "Option::is_none")]
86    pub regenerate_on_enter: Option<bool>,
87
88    /// Context extractors: pull structured data from user input into context.
89    #[serde(default)]
90    pub extract: Vec<ContextExtractor>,
91
92    #[serde(default, skip_serializing_if = "Option::is_none")]
93    pub reasoning: Option<ReasoningConfig>,
94
95    #[serde(default, skip_serializing_if = "Option::is_none")]
96    pub reflection: Option<ReflectionConfig>,
97
98    #[serde(default, skip_serializing_if = "Option::is_none")]
99    pub disambiguation: Option<StateDisambiguationOverride>,
100
101    /// Per-state process pipeline override (replaces agent-level pipeline for this state).
102    #[serde(default, skip_serializing_if = "Option::is_none")]
103    pub process: Option<ProcessConfig>,
104
105    /// Delegate state messages to a registry agent by ID.
106    #[serde(default, skip_serializing_if = "Option::is_none")]
107    pub delegate: Option<String>,
108
109    /// Context mode for delegated states.
110    #[serde(default, skip_serializing_if = "Option::is_none")]
111    pub delegate_context: Option<DelegateContextMode>,
112
113    /// Run multiple registry agents concurrently in this state.
114    #[serde(default, skip_serializing_if = "Option::is_none")]
115    pub concurrent: Option<ConcurrentStateConfig>,
116
117    /// Run a multi-agent group chat in this state.
118    #[serde(default, skip_serializing_if = "Option::is_none")]
119    pub group_chat: Option<GroupChatStateConfig>,
120
121    /// Run a sequential agent pipeline in this state.
122    #[serde(default, skip_serializing_if = "Option::is_none")]
123    pub pipeline: Option<PipelineStateConfig>,
124
125    /// Run an LLM-directed handoff chain in this state.
126    #[serde(default, skip_serializing_if = "Option::is_none")]
127    pub handoff: Option<HandoffStateConfig>,
128}
129
/// Serde default for `StateDefinition::inherit_parent`: sub-states inherit unless told otherwise.
const fn default_inherit_parent() -> bool {
    default_true()
}

/// Generic serde default for boolean flags that are on unless disabled.
const fn default_true() -> bool {
    true
}

/// Serde default for `ContextExtractor::llm`: the `router` LLM alias.
fn default_extractor_llm() -> String {
    String::from("router")
}
141
/// Reference to a tool made available in a state.
///
/// Untagged: a bare string (`- search`) deserializes as `Simple`, while a map
/// with `id` and `condition` keys deserializes as `Conditional`. Variant order
/// matters because serde tries untagged variants top to bottom.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(untagged)]
pub enum ToolRef {
    /// Tool ID with no condition attached.
    Simple(String),
    /// Tool ID gated by a [`ToolCondition`] (evaluated elsewhere).
    Conditional {
        id: String,
        condition: ToolCondition,
    },
}
151
152impl ToolRef {
153    pub fn id(&self) -> &str {
154        match self {
155            ToolRef::Simple(id) => id,
156            ToolRef::Conditional { id, .. } => id,
157        }
158    }
159
160    pub fn condition(&self) -> Option<&ToolCondition> {
161        match self {
162            ToolRef::Simple(_) => None,
163            ToolRef::Conditional { condition, .. } => Some(condition),
164        }
165    }
166}
167
/// Condition gating a [`ToolRef::Conditional`] tool.
///
/// Externally tagged with snake_case keys: `context`, `state`, `after_tool`,
/// `tool_result`, `semantic`, `time`, `all`, `any`, `not`. `all` / `any` / `not`
/// compose other conditions recursively. Evaluation is implemented elsewhere.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum ToolCondition {
    /// Match context keys against [`ContextMatcher`]s.
    Context(HashMap<String, ContextMatcher>),
    /// Match the current state (name / turn count / previous state).
    State(StateMatcher),
    /// Presumably: true after the named tool has been used — confirm.
    AfterTool(String),
    /// Match fields of a tool's result against expected JSON values.
    ToolResult {
        tool: String,
        result: HashMap<String, Value>,
    },
    /// LLM-judged condition described by `when`.
    Semantic {
        when: String,
        // Defaults to the "router" LLM alias.
        #[serde(default = "default_semantic_llm")]
        llm: String,
        // Defaults to 0.7.
        #[serde(default = "default_threshold")]
        threshold: f32,
    },
    /// Match on time of day / day of week.
    Time(TimeMatcher),
    /// Presumably: every sub-condition must hold.
    All(Vec<ToolCondition>),
    /// Presumably: at least one sub-condition must hold.
    Any(Vec<ToolCondition>),
    /// Presumably: negation of the inner condition.
    Not(Box<ToolCondition>),
}
190
/// Serde default for `ToolCondition::Semantic::llm`: the `router` LLM alias.
fn default_semantic_llm() -> String {
    "router".to_owned()
}

/// Serde default for `ToolCondition::Semantic::threshold`.
fn default_threshold() -> f32 {
    0.7_f32
}
198
/// Matcher applied to a single context value.
///
/// Untagged, so the YAML shape selects the variant: `{ exists: true }` becomes
/// `Exists`, a single-key operator map such as `{ eq: "admin" }` becomes `Compare`,
/// and anything else becomes `Exact`.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(untagged)]
pub enum ContextMatcher {
    // Order matters for serde untagged: structured variants must come before
    // Exact(Value) because Value matches any valid JSON — including objects
    // like `{ "exists": true }` or `{ "eq": "admin" }` that should be parsed
    // as Exists or Compare instead.
    Exists { exists: bool },
    Compare(CompareOp),
    Exact(Value),
}
210
/// Comparison operator used by context, state, and time matchers.
///
/// Externally tagged with snake_case keys, e.g. `{ eq: "admin" }`, `{ gte: 3 }`,
/// `{ in: [a, b] }`, `{ contains: "foo" }`. Numeric operators carry an `f64`.
/// Evaluation is implemented elsewhere.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum CompareOp {
    /// `eq`: equal to the given JSON value.
    Eq(Value),
    /// `neq`: not equal to the given JSON value.
    Neq(Value),
    /// `gt`: greater than.
    Gt(f64),
    /// `gte`: greater than or equal.
    Gte(f64),
    /// `lt`: less than.
    Lt(f64),
    /// `lte`: less than or equal.
    Lte(f64),
    /// `in`: one of the listed JSON values.
    In(Vec<Value>),
    /// `contains`: NOTE(review): presumably substring containment — confirm in the evaluator.
    Contains(String),
}
223
/// Matches properties of the current state machine position. All fields are optional.
/// NOTE(review): presumably every field that is set must match — confirm in the evaluator.
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
pub struct StateMatcher {
    /// Name of the current state.
    #[serde(default)]
    pub name: Option<String>,
    /// Comparison applied to the number of turns spent in the current state.
    #[serde(default)]
    pub turn_count: Option<CompareOp>,
    /// Name of the previous state.
    #[serde(default)]
    pub previous: Option<String>,
}
233
/// Matches on wall-clock time. All fields are optional.
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
pub struct TimeMatcher {
    /// Comparison applied to the hour of day.
    #[serde(default)]
    pub hours: Option<CompareOp>,
    /// Allowed days of week. NOTE(review): string format (e.g. "mon" vs "Monday") is
    /// decided by the evaluator — confirm.
    #[serde(default)]
    pub day_of_week: Option<Vec<String>>,
    /// Timezone for the comparison. NOTE(review): presumably an IANA name; confirm.
    #[serde(default)]
    pub timezone: Option<String>,
}
243
/// A transition from one state to another.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Transition {
    /// Target state. `^name` refers to a top-level state; a plain name is resolved
    /// as top-level, sibling, then child (see `StateConfig::resolve_full_path`).
    pub to: String,
    /// Free-form condition text (e.g. `"user needs help"`); empty when omitted.
    #[serde(default)]
    pub when: String,
    /// Optional guard that must pass for the transition to fire.
    #[serde(default)]
    pub guard: Option<TransitionGuard>,
    /// Intent label for deterministic routing after disambiguation.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub intent: Option<String>,
    /// Whether the transition may fire automatically (default: true).
    #[serde(default = "default_auto")]
    pub auto: bool,
    /// Ordering hint (default: 0). NOTE(review): presumably higher wins — confirm.
    #[serde(default)]
    pub priority: u8,

    /// Minimum turns before this transition can fire again after last use.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub cooldown_turns: Option<u32>,
}
263
/// Serde default for `Transition::auto`: transitions are automatic unless disabled.
const fn default_auto() -> bool {
    true
}
267
/// Guard attached to a [`Transition`].
///
/// Untagged: a plain string becomes `Expression`, and a map keyed by `all`,
/// `any`, or `context` becomes `Conditions`.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(untagged)]
pub enum TransitionGuard {
    /// Single guard expression (syntax defined by the evaluator).
    Expression(String),
    /// Structured guard.
    Conditions(GuardConditions),
}
274
/// Structured guard conditions, externally tagged with snake_case keys
/// (`all`, `any`, `context`).
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum GuardConditions {
    /// Presumably: every expression must hold.
    All(Vec<String>),
    /// Presumably: at least one expression must hold.
    Any(Vec<String>),
    /// Context keys matched with [`ContextMatcher`]s.
    Context(HashMap<String, ContextMatcher>),
}
282
/// Action listed in a state's `on_enter`, `on_reenter`, or `on_exit`.
///
/// Untagged: serde tries the variants in declaration order and keeps the first
/// that deserializes, so the variant is chosen by which key is present (`tool`,
/// `skill`, `prompt`, or `set_context`). Extra unknown keys are not rejected.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(untagged)]
pub enum StateAction {
    /// Invoke a tool with optional JSON arguments.
    Tool {
        tool: String,
        #[serde(default)]
        args: Option<Value>,
    },
    /// Invoke a skill by ID.
    Skill {
        skill: String,
    },
    /// Run a prompt, optionally on a specific LLM alias.
    /// NOTE(review): `store_as` is presumably the context key for the output — confirm.
    Prompt {
        prompt: String,
        #[serde(default)]
        llm: Option<String>,
        #[serde(default)]
        store_as: Option<String>,
    },
    /// Set context keys to the given JSON values.
    SetContext {
        set_context: HashMap<String, Value>,
    },
}
305
/// Extract structured data from conversation into context via LLM.
///
/// Listed under a state's `extract` field.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ContextExtractor {
    /// Context key to store the extracted value.
    pub key: String,

    /// Short description of what to extract (LLM-based).
    #[serde(default)]
    pub description: Option<String>,

    /// Custom LLM extraction prompt (takes precedence over `description`).
    #[serde(default)]
    pub llm_extract: Option<String>,

    /// LLM alias for extraction (default: "router").
    #[serde(default = "default_extractor_llm")]
    pub llm: String,

    /// If true, extraction failure is logged as a warning.
    #[serde(default)]
    pub required: bool,
}
328
329//
330// Multi-agent orchestration config types for state delegation, concurrent execution, and group chat.
331//
332
/// Context mode for delegated states.
///
/// Serialized as `input_only` (default), `summary`, or `full`. Also reused as the
/// `context_mode` of concurrent, group-chat, pipeline, and handoff states.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum DelegateContextMode {
    /// Delegated agent receives only the user's current message.
    #[default]
    InputOnly,
    /// Parent summarizes recent conversation via router LLM.
    Summary,
    /// Parent passes full recent message history.
    Full,
}
345
/// Config for running multiple registry agents concurrently.
///
/// Used by `StateDefinition::concurrent`. `agents` and `aggregation` are required.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ConcurrentStateConfig {
    /// Agent IDs in the registry (simple list or weighted entries).
    pub agents: Vec<ConcurrentAgentRef>,
    /// Jinja2 template for input sent to each agent.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub input: Option<String>,
    /// How to aggregate results from all agents.
    pub aggregation: AggregationConfig,
    /// Minimum agents that must succeed.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub min_required: Option<usize>,
    /// What to do when some agents fail (default: proceed with available results).
    #[serde(default)]
    pub on_partial_failure: PartialFailureAction,
    /// Per-agent timeout in milliseconds.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub timeout_ms: Option<u64>,
    /// Parent conversation context forwarded to each agent.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub context_mode: Option<DelegateContextMode>,
}
369
/// Either a plain agent ID string or a weighted entry.
///
/// Untagged: a bare string becomes `Id` (weight 1.0 via `weight()`); a map with
/// `id` and `weight` becomes `Weighted`.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(untagged)]
pub enum ConcurrentAgentRef {
    Id(String),
    Weighted { id: String, weight: f64 },
}
377
378impl ConcurrentAgentRef {
379    pub fn id(&self) -> &str {
380        match self {
381            Self::Id(id) => id,
382            Self::Weighted { id, .. } => id,
383        }
384    }
385
386    pub fn weight(&self) -> f64 {
387        match self {
388            Self::Id(_) => 1.0,
389            Self::Weighted { weight, .. } => *weight,
390        }
391    }
392}
393
/// How to aggregate results from concurrent agents.
///
/// Only `strategy` is required; the remaining fields are optional.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AggregationConfig {
    /// Aggregation strategy.
    pub strategy: AggregationStrategy,
    /// LLM alias for synthesis or vote extraction.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub synthesizer_llm: Option<String>,
    /// Custom prompt for LLM synthesis.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub synthesizer_prompt: Option<String>,
    /// Voting sub-config.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub vote: Option<VoteConfig>,
}
409
/// Aggregation strategy for concurrent agents.
///
/// Serialized as `voting`, `llm_synthesis`, `first_wins`, or `all`. There is no
/// default, so the strategy must be given explicitly.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum AggregationStrategy {
    Voting,
    LlmSynthesis,
    FirstWins,
    All,
}
418
/// Voting config for concurrent agent aggregation.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct VoteConfig {
    /// Vote counting method (default: majority).
    #[serde(default)]
    pub method: VoteMethod,
    /// How ties are broken (default: first).
    #[serde(default)]
    pub tiebreaker: TiebreakerStrategy,
    /// Custom prompt for extracting a vote from each agent's response.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub vote_prompt: Option<String>,
}
430
/// Vote counting method, serialized as `majority` (default), `weighted`, or `unanimous`.
/// `weighted` presumably uses `ConcurrentAgentRef::weight` — confirm in the aggregator.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum VoteMethod {
    #[default]
    Majority,
    Weighted,
    Unanimous,
}
439
/// Vote tie-breaking strategy, serialized as `first` (default), `random`, or `router_decides`.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum TiebreakerStrategy {
    #[default]
    First,
    Random,
    RouterDecides,
}
448
/// Behavior when some concurrent agents fail, serialized as
/// `proceed_with_available` (default) or `abort`.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum PartialFailureAction {
    #[default]
    ProceedWithAvailable,
    Abort,
}
456
/// Group chat state config for multi-agent conversation.
///
/// Used by `StateDefinition::group_chat`. Only `participants` is required.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GroupChatStateConfig {
    /// Participant agent IDs with optional roles.
    pub participants: Vec<ChatParticipant>,
    /// Conversation style (default: brainstorm).
    #[serde(default)]
    pub style: ChatStyle,
    /// Maximum conversation rounds (default: 5).
    #[serde(default = "default_max_rounds")]
    pub max_rounds: u32,
    /// Chat manager config.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub manager: Option<ChatManagerConfig>,
    /// When and how to terminate.
    #[serde(default)]
    pub termination: TerminationConfig,
    /// Debate-specific config.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub debate: Option<DebateStyleConfig>,
    /// Maker-checker-specific config.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub maker_checker: Option<MakerCheckerConfig>,
    /// Total timeout for the group chat in milliseconds.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub timeout_ms: Option<u64>,
    /// Jinja2 template for the topic sent to participants.
    /// `{{ user_input }}` is the user's message and `{{ context.<key> }}` accesses
    /// context values. When omitted, the raw user message is used as the topic.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub input: Option<String>,
    /// Parent conversation context included in the topic.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub context_mode: Option<DelegateContextMode>,
}
491
/// A participant in a group chat.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ChatParticipant {
    /// Agent ID in the registry.
    pub id: String,
    /// Role description visible to all participants.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub role: Option<String>,
}
501
/// Group chat conversation style, serialized as `brainstorm` (default), `debate`,
/// `maker_checker`, or `consensus`. `debate` and `maker_checker` have dedicated
/// sub-configs on `GroupChatStateConfig`.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum ChatStyle {
    #[default]
    Brainstorm,
    Debate,
    MakerChecker,
    Consensus,
}
511
/// Chat manager config for controlling turn order.
///
/// Both fields are optional. NOTE(review): the precedence between `agent` and
/// `method` when both are set is decided by the group-chat runtime — confirm.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ChatManagerConfig {
    /// Registry agent ID for chat management.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub agent: Option<String>,
    /// Built-in turn policy.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub method: Option<TurnMethod>,
}
522
/// Built-in turn policy, serialized as `round_robin`, `random`, or `llm_directed`.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum TurnMethod {
    RoundRobin,
    Random,
    LlmDirected,
}
530
/// Termination config for group chat.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TerminationConfig {
    /// Termination method (default: manager decides).
    #[serde(default)]
    pub method: TerminationMethod,
    /// Stall-round limit (default: 2). NOTE(review): presumably rounds without
    /// progress before stopping — confirm.
    #[serde(default = "default_stall_rounds")]
    pub max_stall_rounds: u32,
}

// Hand-written so that `max_stall_rounds` defaults to `default_stall_rounds()` (2)
// rather than 0, matching the serde field default when the whole `termination`
// section is omitted.
impl Default for TerminationConfig {
    fn default() -> Self {
        Self {
            method: TerminationMethod::default(),
            max_stall_rounds: default_stall_rounds(),
        }
    }
}
548
/// Group chat termination method, serialized as `manager_decides` (default),
/// `max_rounds`, or `consensus_reached`.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum TerminationMethod {
    #[default]
    ManagerDecides,
    MaxRounds,
    ConsensusReached,
}
557
/// Debate-specific config for group chat.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DebateStyleConfig {
    /// Number of debate rounds (default: 3).
    #[serde(default = "default_debate_rounds")]
    pub rounds: u32,
    /// Agent ID that synthesizes the final answer.
    pub synthesizer: String,
}
566
/// Maker-checker-specific config for group chat.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MakerCheckerConfig {
    /// Maximum maker/checker iterations (default: 3).
    #[serde(default = "default_maker_checker_iterations")]
    pub max_iterations: u32,
    /// LLM-evaluated acceptance criteria.
    pub acceptance_criteria: String,
    /// What to do when `max_iterations` is reached (default: accept the last output).
    #[serde(default)]
    pub on_max_iterations: MaxIterationsAction,
}
577
/// Outcome when maker-checker runs out of iterations, serialized as
/// `accept_last` (default), `escalate`, or `fail`.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum MaxIterationsAction {
    #[default]
    AcceptLast,
    Escalate,
    Fail,
}
586
/// Serde default for `GroupChatStateConfig::max_rounds`.
const fn default_max_rounds() -> u32 {
    5
}

/// Serde default for `TerminationConfig::max_stall_rounds` (also used by its `Default`).
const fn default_stall_rounds() -> u32 {
    2
}

/// Serde default for `DebateStyleConfig::rounds`.
const fn default_debate_rounds() -> u32 {
    3
}

/// Serde default for `MakerCheckerConfig::max_iterations`.
const fn default_maker_checker_iterations() -> u32 {
    3
}
599
/// Config for a pipeline state type.
///
/// Used by `StateDefinition::pipeline`. NOTE(review): stages presumably run in
/// list order, each receiving the previous stage's output — confirm in the runtime.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PipelineStateConfig {
    /// Ordered pipeline stages.
    pub stages: Vec<PipelineStageEntry>,

    /// Optional timeout in milliseconds.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub timeout_ms: Option<u64>,
    /// Parent conversation context forwarded to the first stage.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub context_mode: Option<DelegateContextMode>,
}
611
/// A single stage in a pipeline state.
///
/// Untagged: a bare string becomes `Id`, and a map with `id` (and optional
/// `input`) becomes `Config`.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(untagged)]
pub enum PipelineStageEntry {
    /// Simple agent ID string.
    Id(String),
    /// Agent with optional input template.
    Config {
        id: String,
        #[serde(default, skip_serializing_if = "Option::is_none")]
        input: Option<String>,
    },
}
625
626impl PipelineStageEntry {
627    pub fn id(&self) -> &str {
628        match self {
629            Self::Id(id) => id,
630            Self::Config { id, .. } => id,
631        }
632    }
633
634    pub fn input(&self) -> Option<&str> {
635        match self {
636            Self::Id(_) => None,
637            Self::Config { input, .. } => input.as_deref(),
638        }
639    }
640}
641
/// Config for a handoff state type.
///
/// Used by `StateDefinition::handoff`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct HandoffStateConfig {
    /// Registry agent ID that receives the input first.
    pub initial_agent: String,
    /// Registry agent IDs that may be handed off to.
    pub available_agents: Vec<String>,

    /// Maximum number of handoffs (default: 5).
    #[serde(default = "default_max_handoffs")]
    pub max_handoffs: u32,

    /// Jinja2 template for the input sent to the initial agent.
    /// `{{ user_input }}` is the user's message and `{{ context.<key> }}` accesses
    /// context values. When omitted, the raw user message is forwarded directly.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub input: Option<String>,
    /// Parent conversation context forwarded to the initial agent.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub context_mode: Option<DelegateContextMode>,
}
659
/// Serde default for `HandoffStateConfig::max_handoffs`.
const fn default_max_handoffs() -> u32 {
    5
}
663
664impl StateConfig {
665    pub fn validate(&self) -> Result<()> {
666        if self.initial.is_empty() {
667            return Err(AgentError::InvalidSpec(
668                "State machine initial state cannot be empty".into(),
669            ));
670        }
671        if !self.states.contains_key(&self.initial) {
672            return Err(AgentError::InvalidSpec(format!(
673                "Initial state '{}' not found in states",
674                self.initial
675            )));
676        }
677        self.validate_states(&self.states, &[])?;
678
679        // Warn about unreachable states (non-fatal)
680        for warning in self.check_reachability() {
681            tracing::warn!("{}", warning);
682        }
683
684        Ok(())
685    }
686
687    fn validate_states(
688        &self,
689        states: &HashMap<String, StateDefinition>,
690        parent_path: &[String],
691    ) -> Result<()> {
692        for (name, def) in states {
693            let current_path: Vec<String> = parent_path
694                .iter()
695                .cloned()
696                .chain(std::iter::once(name.clone()))
697                .collect();
698
699            for transition in &def.transitions {
700                if !self.is_valid_transition_target(&transition.to, &current_path, states) {
701                    return Err(AgentError::InvalidSpec(format!(
702                        "State '{}' has transition to unknown state '{}'",
703                        current_path.join("."),
704                        transition.to
705                    )));
706                }
707            }
708
709            if let Some(ref timeout_state) = def.timeout_to {
710                if !self.is_valid_transition_target(timeout_state, &current_path, states) {
711                    return Err(AgentError::InvalidSpec(format!(
712                        "State '{}' has timeout_to unknown state '{}'",
713                        current_path.join("."),
714                        timeout_state
715                    )));
716                }
717            }
718
719            if let Some(ref sub_states) = def.states {
720                if let Some(ref initial) = def.initial {
721                    if !sub_states.contains_key(initial) {
722                        return Err(AgentError::InvalidSpec(format!(
723                            "State '{}' has initial sub-state '{}' that doesn't exist",
724                            current_path.join("."),
725                            initial
726                        )));
727                    }
728                }
729                self.validate_states(sub_states, &current_path)?;
730            }
731        }
732        Ok(())
733    }
734
735    fn is_valid_transition_target(
736        &self,
737        target: &str,
738        current_path: &[String],
739        states: &HashMap<String, StateDefinition>,
740    ) -> bool {
741        if target.starts_with('^') {
742            let target_name = &target[1..];
743            return self.states.contains_key(target_name);
744        }
745
746        if states.contains_key(target) {
747            return true;
748        }
749
750        if current_path.len() > 1 {
751            let parent_path = &current_path[..current_path.len() - 1];
752            if let Some(parent_states) = self.get_states_at_path(parent_path) {
753                if parent_states.contains_key(target) {
754                    return true;
755                }
756            }
757        }
758
759        self.states.contains_key(target)
760    }
761
762    fn get_states_at_path(&self, path: &[String]) -> Option<&HashMap<String, StateDefinition>> {
763        let mut current = &self.states;
764        for segment in path {
765            if let Some(def) = current.get(segment) {
766                if let Some(ref sub_states) = def.states {
767                    current = sub_states;
768                } else {
769                    return None;
770                }
771            } else {
772                return None;
773            }
774        }
775        Some(current)
776    }
777
778    pub fn get_state(&self, path: &str) -> Option<&StateDefinition> {
779        let parts: Vec<&str> = path.split('.').collect();
780        self.get_state_by_path(&parts)
781    }
782
783    fn get_state_by_path(&self, path: &[&str]) -> Option<&StateDefinition> {
784        if path.is_empty() {
785            return None;
786        }
787
788        let mut current = self.states.get(path[0])?;
789        for segment in &path[1..] {
790            if let Some(ref sub_states) = current.states {
791                current = sub_states.get(*segment)?;
792            } else {
793                return None;
794            }
795        }
796        Some(current)
797    }
798
799    /// Resolve a transition target to a full dotted state path.
800    /// Order: `^prefix` (parent-level) → top-level → sibling → child → fallback literal.
801    pub fn resolve_full_path(&self, current_path: &str, target: &str) -> String {
802        if target.starts_with('^') {
803            return target[1..].to_string();
804        }
805
806        if self.states.contains_key(target) {
807            return target.to_string();
808        }
809
810        if !current_path.is_empty() {
811            let parts: Vec<&str> = current_path.split('.').collect();
812            if parts.len() > 1 {
813                let parent_path = parts[..parts.len() - 1].join(".");
814                let potential = format!("{}.{}", parent_path, target);
815                if self.get_state(&potential).is_some() {
816                    return potential;
817                }
818            }
819
820            let potential = format!("{}.{}", current_path, target);
821            if self.get_state(&potential).is_some() {
822                return potential;
823            }
824        }
825
826        target.to_string()
827    }
828
829    /// Check for unreachable states. Returns warning messages.
830    pub fn check_reachability(&self) -> Vec<String> {
831        let mut reachable: HashSet<String> = HashSet::new();
832        reachable.insert(self.initial.clone());
833
834        if let Some(ref fb) = self.fallback {
835            reachable.insert(fb.clone());
836        }
837        for gt in &self.global_transitions {
838            reachable.insert(self.normalize_target(&gt.to));
839        }
840
841        let mut queue: Vec<String> = reachable.iter().cloned().collect();
842        while let Some(state_path) = queue.pop() {
843            if let Some(def) = self.get_state(&state_path) {
844                for t in &def.transitions {
845                    let target = self.resolve_full_path(&state_path, &t.to);
846                    if reachable.insert(target.clone()) {
847                        queue.push(target);
848                    }
849                }
850                if let Some(ref timeout) = def.timeout_to {
851                    let target = self.resolve_full_path(&state_path, timeout);
852                    if reachable.insert(target.clone()) {
853                        queue.push(target);
854                    }
855                }
856                if let (Some(initial), Some(_sub)) = (&def.initial, &def.states) {
857                    let sub_path = format!("{}.{}", state_path, initial);
858                    if reachable.insert(sub_path.clone()) {
859                        queue.push(sub_path);
860                    }
861                }
862            }
863        }
864
865        let all_states = self.collect_all_state_paths(&self.states, &[]);
866        let mut warnings = Vec::new();
867        for state_path in &all_states {
868            if !reachable.contains(state_path) {
869                warnings.push(format!(
870                    "State '{}' appears unreachable — no transitions lead to it",
871                    state_path
872                ));
873            }
874        }
875        warnings
876    }
877
878    fn normalize_target(&self, target: &str) -> String {
879        if target.starts_with('^') {
880            target[1..].to_string()
881        } else {
882            target.to_string()
883        }
884    }
885
886    fn collect_all_state_paths(
887        &self,
888        states: &HashMap<String, StateDefinition>,
889        parent: &[String],
890    ) -> Vec<String> {
891        let mut paths = Vec::new();
892        for (name, def) in states {
893            let mut current: Vec<String> = parent.to_vec();
894            current.push(name.clone());
895            paths.push(current.join("."));
896            if let Some(ref sub) = def.states {
897                paths.extend(self.collect_all_state_paths(sub, &current));
898            }
899        }
900        paths
901    }
902}
903
904impl StateDefinition {
905    pub fn has_sub_states(&self) -> bool {
906        self.states.as_ref().map(|s| !s.is_empty()).unwrap_or(false)
907    }
908
909    pub fn get_effective_tools<'a>(
910        &'a self,
911        parent: Option<&'a StateDefinition>,
912    ) -> Option<Vec<&'a ToolRef>> {
913        match &self.tools {
914            // Explicitly set (including empty): use as-is, no inheritance
915            Some(tools) => Some(tools.iter().collect()),
916            // Not set: inherit from parent if available
917            None => {
918                if !self.inherit_parent {
919                    return None;
920                }
921                parent
922                    .and_then(|p| p.tools.as_ref())
923                    .map(|t| t.iter().collect())
924            }
925        }
926    }
927
928    pub fn get_effective_skills<'a>(
929        &'a self,
930        parent: Option<&'a StateDefinition>,
931    ) -> Vec<&'a String> {
932        if !self.inherit_parent || parent.is_none() {
933            return self.skills.iter().collect();
934        }
935
936        let parent = parent.unwrap();
937        let mut skills: Vec<&'a String> = parent.skills.iter().collect();
938        skills.extend(self.skills.iter());
939        skills
940    }
941}
942
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_state_config_deserialize() {
        let yaml = r#"
initial: greeting
states:
  greeting:
    prompt: "Welcome!"
    transitions:
      - to: support
        when: "user needs help"
        auto: true
  support:
    prompt: "How can I help?"
    llm: fast
    tools:
      - search
"#;
        let config: StateConfig = serde_yaml::from_str(yaml).unwrap();
        assert_eq!(config.initial, "greeting");
        assert_eq!(config.states.len(), 2);
        assert!(config.validate().is_ok());
    }

    #[test]
    fn test_prompt_mode_default() {
        let def = StateDefinition::default();
        assert_eq!(def.prompt_mode, PromptMode::Append);
    }

    #[test]
    fn test_invalid_initial_state() {
        let config = StateConfig {
            initial: "nonexistent".into(),
            states: HashMap::new(),
            global_transitions: vec![],
            fallback: None,
            max_no_transition: None,
            regenerate_on_transition: true,
        };
        assert!(config.validate().is_err());
    }

    #[test]
    fn test_invalid_transition_target() {
        let mut states = HashMap::new();
        states.insert(
            "start".into(),
            StateDefinition {
                transitions: vec![Transition {
                    to: "nonexistent".into(),
                    when: "always".into(),
                    guard: None,
                    intent: None,
                    auto: true,
                    priority: 0,
                    cooldown_turns: None,
                }],
                ..Default::default()
            },
        );
        let config = StateConfig {
            initial: "start".into(),
            states,
            global_transitions: vec![],
            fallback: None,
            max_no_transition: None,
            regenerate_on_transition: true,
        };
        assert!(config.validate().is_err());
    }

    #[test]
    fn test_hierarchical_states() {
        let yaml = r#"
initial: problem_solving
states:
  problem_solving:
    initial: gathering_info
    prompt: "Solving customer problem"
    states:
      gathering_info:
        prompt: "Ask questions"
        transitions:
          - to: proposing_solution
            when: "understood"
      proposing_solution:
        prompt: "Offer solution"
        transitions:
          - to: ^closing
            when: "resolved"
  closing:
    prompt: "Thank you"
"#;
        let config: StateConfig = serde_yaml::from_str(yaml).unwrap();
        assert!(config.validate().is_ok());
        assert!(
            config
                .states
                .get("problem_solving")
                .unwrap()
                .has_sub_states()
        );
    }

    #[test]
    fn test_tool_ref_simple() {
        let yaml = r#"
tools:
  - calculator
  - search
"#;
        #[derive(Deserialize)]
        struct Test {
            tools: Vec<ToolRef>,
        }
        let t: Test = serde_yaml::from_str(yaml).unwrap();
        assert_eq!(t.tools.len(), 2);
        assert_eq!(t.tools[0].id(), "calculator");
    }

    #[test]
    fn test_tool_ref_conditional() {
        let yaml = r#"
tools:
  - calculator
  - id: admin_tool
    condition:
      context:
        user.role: "admin"
"#;
        #[derive(Deserialize)]
        struct Test {
            tools: Vec<ToolRef>,
        }
        let t: Test = serde_yaml::from_str(yaml).unwrap();
        assert_eq!(t.tools.len(), 2);
        assert_eq!(t.tools[1].id(), "admin_tool");
        assert!(t.tools[1].condition().is_some());
    }

    #[test]
    fn test_transition_with_guard() {
        let yaml = r#"
to: next_state
when: "user wants to proceed"
guard: "{{ context.has_data }}"
auto: true
priority: 10
"#;
        let t: Transition = serde_yaml::from_str(yaml).unwrap();
        assert!(t.guard.is_some());
        assert_eq!(t.priority, 10);
    }

    #[test]
    fn test_state_action() {
        let yaml = r#"
- tool: log_event
  args:
    event: "entered"
- skill: greeting_skill
- set_context:
    entered: true
"#;
        let actions: Vec<StateAction> = serde_yaml::from_str(yaml).unwrap();
        assert_eq!(actions.len(), 3);
        match &actions[0] {
            StateAction::Tool { tool, .. } => assert_eq!(tool, "log_event"),
            _ => panic!("Expected Tool action"),
        }
        match &actions[1] {
            StateAction::Skill { skill } => assert_eq!(skill, "greeting_skill"),
            _ => panic!("Expected Skill action"),
        }
        match &actions[2] {
            StateAction::SetContext { set_context } => {
                assert!(set_context.contains_key("entered"));
            }
            _ => panic!("Expected SetContext action"),
        }
    }

    #[test]
    fn test_complex_tool_condition() {
        let yaml = r#"
id: refund_tool
condition:
  all:
    - context:
        user.verified: true
    - semantic:
        when: "user wants refund"
        threshold: 0.85
"#;
        let tool: ToolRef = serde_yaml::from_str(yaml).unwrap();
        assert_eq!(tool.id(), "refund_tool");
        match tool.condition().unwrap() {
            ToolCondition::All(conditions) => assert_eq!(conditions.len(), 2),
            _ => panic!("Expected All condition"),
        }
    }

    #[test]
    fn test_state_get_path() {
        let yaml = r#"
initial: problem_solving
states:
  problem_solving:
    initial: gathering_info
    states:
      gathering_info:
        prompt: "Ask"
      proposing:
        prompt: "Propose"
  closing:
    prompt: "Done"
"#;
        let config: StateConfig = serde_yaml::from_str(yaml).unwrap();
        assert!(config.get_state("problem_solving").is_some());
        assert!(config.get_state("problem_solving.gathering_info").is_some());
        assert!(config.get_state("closing").is_some());
        assert!(config.get_state("nonexistent").is_none());
    }

    #[test]
    fn test_resolve_full_path() {
        let yaml = r#"
initial: problem_solving
states:
  problem_solving:
    initial: gathering_info
    states:
      gathering_info:
        prompt: "Ask"
      proposing:
        prompt: "Propose"
  closing:
    prompt: "Done"
"#;
        let config: StateConfig = serde_yaml::from_str(yaml).unwrap();

        assert_eq!(
            config.resolve_full_path("problem_solving.gathering_info", "proposing"),
            "problem_solving.proposing"
        );
        assert_eq!(
            config.resolve_full_path("problem_solving.gathering_info", "^closing"),
            "closing"
        );
        assert_eq!(
            config.resolve_full_path("problem_solving", "closing"),
            "closing"
        );
    }

    #[test]
    fn test_inherit_parent() {
        let parent = StateDefinition {
            tools: Some(vec![ToolRef::Simple("parent_tool".into())]),
            skills: vec!["parent_skill".into()],
            ..Default::default()
        };

        let child = StateDefinition {
            tools: Some(vec![ToolRef::Simple("child_tool".into())]),
            skills: vec!["child_skill".into()],
            inherit_parent: true,
            ..Default::default()
        };

        let effective_tools = child.get_effective_tools(Some(&parent)).unwrap();
        assert_eq!(effective_tools.len(), 1); // explicit tools override, no merge
        assert_eq!(effective_tools[0].id(), "child_tool");

        let effective_skills = child.get_effective_skills(Some(&parent));
        assert_eq!(effective_skills.len(), 2);
        // Parent skills come first, followed by the child's own.
        assert_eq!(effective_skills[0].as_str(), "parent_skill");
        assert_eq!(effective_skills[1].as_str(), "child_skill");
    }

    #[test]
    fn test_no_inherit_parent() {
        // Explicit child tools: the result is the child's list regardless of inheritance.
        let parent = StateDefinition {
            tools: Some(vec![ToolRef::Simple("parent_tool".into())]),
            ..Default::default()
        };

        let child = StateDefinition {
            tools: Some(vec![ToolRef::Simple("child_tool".into())]),
            inherit_parent: false,
            ..Default::default()
        };

        let effective_tools = child.get_effective_tools(Some(&parent)).unwrap();
        assert_eq!(effective_tools.len(), 1);
        assert_eq!(effective_tools[0].id(), "child_tool");
    }

    #[test]
    fn test_tools_none_without_inherit_returns_none() {
        let parent = StateDefinition {
            tools: Some(vec![ToolRef::Simple("parent_tool".into())]),
            ..Default::default()
        };

        let child = StateDefinition {
            tools: None,
            inherit_parent: false,
            ..Default::default()
        };

        // Omitted tools and no inheritance: the parent's list must not leak through.
        assert!(child.get_effective_tools(Some(&parent)).is_none());
    }

    #[test]
    fn test_tools_none_inherit_without_parent_returns_none() {
        let child = StateDefinition {
            tools: None,
            inherit_parent: true,
            ..Default::default()
        };

        assert!(child.get_effective_tools(None).is_none());
    }

    #[test]
    fn test_tools_none_inherits() {
        let parent = StateDefinition {
            tools: Some(vec![ToolRef::Simple("parent_tool".into())]),
            ..Default::default()
        };

        let child = StateDefinition {
            tools: None, // not specified → inherit
            inherit_parent: true,
            ..Default::default()
        };

        let effective_tools = child.get_effective_tools(Some(&parent)).unwrap();
        assert_eq!(effective_tools.len(), 1);
        assert_eq!(effective_tools[0].id(), "parent_tool");
    }

    #[test]
    fn test_tools_empty_means_no_tools() {
        let parent = StateDefinition {
            tools: Some(vec![ToolRef::Simple("parent_tool".into())]),
            ..Default::default()
        };

        let child = StateDefinition {
            tools: Some(vec![]), // explicitly empty → no tools
            inherit_parent: true,
            ..Default::default()
        };

        let effective_tools = child.get_effective_tools(Some(&parent)).unwrap();
        assert!(effective_tools.is_empty());
    }

    #[test]
    fn test_skills_without_inherit_ignores_parent() {
        let parent = StateDefinition {
            skills: vec!["parent_skill".into()],
            ..Default::default()
        };

        let child = StateDefinition {
            skills: vec!["child_skill".into()],
            inherit_parent: false,
            ..Default::default()
        };

        let effective_skills = child.get_effective_skills(Some(&parent));
        assert_eq!(effective_skills.len(), 1);
        assert_eq!(effective_skills[0].as_str(), "child_skill");

        // Inheriting but with no parent supplied: only own skills.
        let orphan = StateDefinition {
            skills: vec!["child_skill".into()],
            inherit_parent: true,
            ..Default::default()
        };
        let orphan_skills = orphan.get_effective_skills(None);
        assert_eq!(orphan_skills.len(), 1);
        assert_eq!(orphan_skills[0].as_str(), "child_skill");
    }

    #[test]
    fn test_state_with_disambiguation_override() {
        let yaml = r#"
initial: greeting
states:
  greeting:
    prompt: "Hello"
    transitions:
      - to: payment
        when: "User wants to pay"
  payment:
    prompt: "Processing payment"
    disambiguation:
      threshold: 0.95
      require_confirmation: true
      required_clarity:
        - recipient
        - amount
"#;
        let config: StateConfig = serde_yaml::from_str(yaml).unwrap();
        let payment = config.get_state("payment").unwrap();
        let disambig = payment.disambiguation.as_ref().unwrap();
        assert_eq!(disambig.threshold, Some(0.95));
        assert!(disambig.require_confirmation);
        assert_eq!(disambig.required_clarity.len(), 2);
        assert!(disambig.required_clarity.contains(&"recipient".to_string()));

        let greeting = config.get_state("greeting").unwrap();
        assert!(greeting.disambiguation.is_none());
    }

    #[test]
    fn test_context_extractor_vec_deserialize() {
        let yaml = r#"
initial: a
states:
  a:
    extract:
      - key: user_email
        description: "The user's email address"
      - key: order_id
        llm_extract: "Extract the order ID"
        required: true
"#;
        let config: StateConfig = serde_yaml::from_str(yaml).unwrap();
        let state = config.get_state("a").unwrap();
        assert_eq!(state.extract.len(), 2);
        assert_eq!(state.extract[0].key, "user_email");
        assert_eq!(
            state.extract[0].description.as_deref(),
            Some("The user's email address")
        );
        assert!(!state.extract[0].required);
        assert_eq!(state.extract[0].llm, "router");
        assert_eq!(state.extract[1].key, "order_id");
        assert!(state.extract[1].required);
        assert!(state.extract[1].llm_extract.is_some());
    }

    #[test]
    fn test_context_extractor_default_empty() {
        let yaml = r#"
initial: a
states:
  a:
    prompt: "Hello"
"#;
        let config: StateConfig = serde_yaml::from_str(yaml).unwrap();
        let state = config.get_state("a").unwrap();
        assert!(state.extract.is_empty());
    }

    #[test]
    fn test_state_process_override_deserialize() {
        let yaml = r#"
initial: a
states:
  a:
    process:
      input:
        - type: normalize
          config:
            trim: true
"#;
        let config: StateConfig = serde_yaml::from_str(yaml).unwrap();
        let state = config.get_state("a").unwrap();
        assert!(state.process.is_some());
        assert_eq!(state.process.as_ref().unwrap().input.len(), 1);
    }

    #[test]
    fn test_state_process_default_none() {
        let yaml = r#"
initial: a
states:
  a:
    prompt: "Hello"
"#;
        let config: StateConfig = serde_yaml::from_str(yaml).unwrap();
        let state = config.get_state("a").unwrap();
        assert!(state.process.is_none());
    }
}