Skip to main content

ai_agents_state/
config.rs

1use ai_agents_core::{AgentError, Result};
2use ai_agents_disambiguation::StateDisambiguationOverride;
3use ai_agents_process::ProcessConfig;
4use ai_agents_reasoning::{ReasoningConfig, ReflectionConfig};
5use serde::{Deserialize, Serialize};
6use serde_json::Value;
7use std::collections::{HashMap, HashSet};
8
9#[derive(Debug, Clone, Serialize, Deserialize, Default, PartialEq)]
10#[serde(rename_all = "lowercase")]
11pub enum PromptMode {
12    #[default]
13    Append,
14    Replace,
15    Prepend,
16}
17
18#[derive(Debug, Clone, Serialize, Deserialize)]
19pub struct StateConfig {
20    pub initial: String,
21    #[serde(default)]
22    pub states: HashMap<String, StateDefinition>,
23    #[serde(default)]
24    pub global_transitions: Vec<Transition>,
25    #[serde(default)]
26    pub fallback: Option<String>,
27    #[serde(default)]
28    pub max_no_transition: Option<u32>,
29
30    /// Whether to re-generate a response after state transitions (default: true).
31    #[serde(default = "default_true")]
32    pub regenerate_on_transition: bool,
33}
34
35#[derive(Debug, Clone, Serialize, Deserialize, Default)]
36pub struct StateDefinition {
37    #[serde(default)]
38    pub prompt: Option<String>,
39
40    #[serde(default)]
41    pub prompt_mode: PromptMode,
42
43    #[serde(default)]
44    pub llm: Option<String>,
45
46    #[serde(default)]
47    pub skills: Vec<String>,
48
49    /// Tool availability for this state.
50    /// - `None` (omitted in YAML): inherit from parent or agent-level tools
51    /// - `Some([])` (`tools: []` in YAML): explicitly no tools available
52    /// - `Some([...])`: only these tools available
53    #[serde(default, skip_serializing_if = "Option::is_none")]
54    pub tools: Option<Vec<ToolRef>>,
55
56    #[serde(default)]
57    pub transitions: Vec<Transition>,
58
59    #[serde(default)]
60    pub max_turns: Option<u32>,
61
62    #[serde(default)]
63    pub timeout_to: Option<String>,
64
65    #[serde(default)]
66    pub initial: Option<String>,
67
68    #[serde(default)]
69    pub states: Option<HashMap<String, StateDefinition>>,
70
71    #[serde(default = "default_inherit_parent")]
72    pub inherit_parent: bool,
73
74    #[serde(default)]
75    pub on_enter: Vec<StateAction>,
76
77    /// Actions on re-entering a previously visited state. Falls back to on_enter if empty.
78    #[serde(default)]
79    pub on_reenter: Vec<StateAction>,
80
81    #[serde(default)]
82    pub on_exit: Vec<StateAction>,
83
84    /// Per-state override: skip re-generation on entering this state.
85    #[serde(default, skip_serializing_if = "Option::is_none")]
86    pub regenerate_on_enter: Option<bool>,
87
88    /// Context extractors: pull structured data from user input into context.
89    #[serde(default)]
90    pub extract: Vec<ContextExtractor>,
91
92    #[serde(default, skip_serializing_if = "Option::is_none")]
93    pub reasoning: Option<ReasoningConfig>,
94
95    #[serde(default, skip_serializing_if = "Option::is_none")]
96    pub reflection: Option<ReflectionConfig>,
97
98    #[serde(default, skip_serializing_if = "Option::is_none")]
99    pub disambiguation: Option<StateDisambiguationOverride>,
100
101    /// Per-state process pipeline override (replaces agent-level pipeline for this state).
102    #[serde(default, skip_serializing_if = "Option::is_none")]
103    pub process: Option<ProcessConfig>,
104
105    /// Delegate state messages to a registry agent by ID.
106    #[serde(default, skip_serializing_if = "Option::is_none")]
107    pub delegate: Option<String>,
108
109    /// Context mode for delegated states.
110    #[serde(default, skip_serializing_if = "Option::is_none")]
111    pub delegate_context: Option<DelegateContextMode>,
112
113    /// Run multiple registry agents concurrently in this state.
114    #[serde(default, skip_serializing_if = "Option::is_none")]
115    pub concurrent: Option<ConcurrentStateConfig>,
116
117    /// Run a multi-agent group chat in this state.
118    #[serde(default, skip_serializing_if = "Option::is_none")]
119    pub group_chat: Option<GroupChatStateConfig>,
120
121    /// Run a sequential agent pipeline in this state.
122    #[serde(default, skip_serializing_if = "Option::is_none")]
123    pub pipeline: Option<PipelineStateConfig>,
124
125    /// Run an LLM-directed handoff chain in this state.
126    #[serde(default, skip_serializing_if = "Option::is_none")]
127    pub handoff: Option<HandoffStateConfig>,
128}
129
/// Serde default for `StateDefinition::inherit_parent`: states inherit unless told not to.
fn default_inherit_parent() -> bool {
    default_true()
}

/// Serde default for boolean fields that are enabled unless configured otherwise.
fn default_true() -> bool {
    !false
}

/// Serde default for `ContextExtractor::llm`: the `router` LLM alias.
fn default_extractor_llm() -> String {
    String::from("router")
}
141
142#[derive(Debug, Clone, Serialize, Deserialize)]
143#[serde(untagged)]
144pub enum ToolRef {
145    Simple(String),
146    Conditional {
147        id: String,
148        condition: ToolCondition,
149    },
150}
151
152impl ToolRef {
153    pub fn id(&self) -> &str {
154        match self {
155            ToolRef::Simple(id) => id,
156            ToolRef::Conditional { id, .. } => id,
157        }
158    }
159
160    pub fn condition(&self) -> Option<&ToolCondition> {
161        match self {
162            ToolRef::Simple(_) => None,
163            ToolRef::Conditional { condition, .. } => Some(condition),
164        }
165    }
166}
167
168#[derive(Debug, Clone, Serialize, Deserialize)]
169#[serde(rename_all = "snake_case")]
170pub enum ToolCondition {
171    Context(HashMap<String, ContextMatcher>),
172    State(StateMatcher),
173    AfterTool(String),
174    ToolResult {
175        tool: String,
176        result: HashMap<String, Value>,
177    },
178    Semantic {
179        when: String,
180        #[serde(default = "default_semantic_llm")]
181        llm: String,
182        #[serde(default = "default_threshold")]
183        threshold: f32,
184    },
185    Time(TimeMatcher),
186    All(Vec<ToolCondition>),
187    Any(Vec<ToolCondition>),
188    Not(Box<ToolCondition>),
189}
190
191fn default_semantic_llm() -> String {
192    "router".to_string()
193}
194
195fn default_threshold() -> f32 {
196    0.7
197}
198
199#[derive(Debug, Clone, Serialize, Deserialize)]
200#[serde(untagged)]
201pub enum ContextMatcher {
202    // Order matters for serde untagged: structured variants must come before
203    // Exact(Value) because Value matches any valid JSON — including objects
204    // like `{ "exists": true }` or `{ "eq": "admin" }` that should be parsed
205    // as Exists or Compare instead.
206    Exists { exists: bool },
207    Compare(CompareOp),
208    Exact(Value),
209}
210
211#[derive(Debug, Clone, Serialize, Deserialize)]
212#[serde(rename_all = "snake_case")]
213pub enum CompareOp {
214    Eq(Value),
215    Neq(Value),
216    Gt(f64),
217    Gte(f64),
218    Lt(f64),
219    Lte(f64),
220    In(Vec<Value>),
221    Contains(String),
222}
223
224#[derive(Debug, Clone, Serialize, Deserialize, Default)]
225pub struct StateMatcher {
226    #[serde(default)]
227    pub name: Option<String>,
228    #[serde(default)]
229    pub turn_count: Option<CompareOp>,
230    #[serde(default)]
231    pub previous: Option<String>,
232}
233
234#[derive(Debug, Clone, Serialize, Deserialize, Default)]
235pub struct TimeMatcher {
236    #[serde(default)]
237    pub hours: Option<CompareOp>,
238    #[serde(default)]
239    pub day_of_week: Option<Vec<String>>,
240    #[serde(default)]
241    pub timezone: Option<String>,
242}
243
244#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq)]
245#[serde(rename_all = "snake_case")]
246pub enum TransitionTiming {
247    /// Evaluate after the assistant response is available.
248    PostResponse,
249    /// Evaluate before main response generation when the route is response independent.
250    PreResponse,
251    /// Evaluate in parallel with a draft response when explicitly enabled.
252    Parallel,
253}
254
255impl Default for TransitionTiming {
256    fn default() -> Self {
257        Self::PostResponse
258    }
259}
260
261#[derive(Debug, Clone, Serialize, Deserialize)]
262pub struct Transition {
263    pub to: String,
264    #[serde(default)]
265    pub when: String,
266    #[serde(default)]
267    pub guard: Option<TransitionGuard>,
268    /// Intent label for deterministic routing after disambiguation.
269    #[serde(default, skip_serializing_if = "Option::is_none")]
270    pub intent: Option<String>,
271    #[serde(default = "default_auto")]
272    pub auto: bool,
273    #[serde(default)]
274    pub priority: u8,
275
276    /// Minimum turns before this transition can fire again after last use.
277    #[serde(default, skip_serializing_if = "Option::is_none")]
278    pub cooldown_turns: Option<u32>,
279
280    /// Controls whether the transition can be selected before a response exists.
281    #[serde(default)]
282    pub timing: TransitionTiming,
283
284    /// Marks transitions whose condition needs the assistant response text.
285    #[serde(default)]
286    pub requires_response: bool,
287
288    /// Allows this transition to run current-state extractors before pre-response selection.
289    #[serde(default)]
290    pub run_extractors: bool,
291}
292
293fn default_auto() -> bool {
294    true
295}
296
297#[derive(Debug, Clone, Serialize, Deserialize)]
298#[serde(untagged)]
299pub enum TransitionGuard {
300    Expression(String),
301    Conditions(GuardConditions),
302}
303
304#[derive(Debug, Clone, Serialize, Deserialize)]
305#[serde(rename_all = "snake_case")]
306pub enum GuardConditions {
307    All(Vec<String>),
308    Any(Vec<String>),
309    Context(HashMap<String, ContextMatcher>),
310}
311
312#[derive(Debug, Clone, Serialize, Deserialize)]
313#[serde(untagged)]
314pub enum StateAction {
315    Tool {
316        tool: String,
317        #[serde(default)]
318        args: Option<Value>,
319    },
320    Skill {
321        skill: String,
322    },
323    Prompt {
324        prompt: String,
325        #[serde(default)]
326        llm: Option<String>,
327        #[serde(default)]
328        store_as: Option<String>,
329    },
330    SetContext {
331        set_context: HashMap<String, Value>,
332    },
333}
334
335/// Extract structured data from conversation into context via LLM.
336#[derive(Debug, Clone, Serialize, Deserialize)]
337pub struct ContextExtractor {
338    /// Context key to store the extracted value.
339    pub key: String,
340
341    /// Short description of what to extract (LLM-based).
342    #[serde(default)]
343    pub description: Option<String>,
344
345    /// Custom LLM extraction prompt (takes precedence over `description`).
346    #[serde(default)]
347    pub llm_extract: Option<String>,
348
349    /// LLM alias for extraction (default: "router").
350    #[serde(default = "default_extractor_llm")]
351    pub llm: String,
352
353    /// If true, extraction failure is logged as a warning.
354    #[serde(default)]
355    pub required: bool,
356}
357
//
// Multi-agent orchestration config types: state delegation, concurrent execution, group chat,
// sequential pipelines, and handoff chains.
//
361
362/// Context mode for delegated states.
363#[derive(Debug, Clone, Default, Serialize, Deserialize)]
364#[serde(rename_all = "snake_case")]
365pub enum DelegateContextMode {
366    /// Delegated agent receives only the user's current message.
367    #[default]
368    InputOnly,
369    /// Parent summarizes recent conversation via router LLM.
370    Summary,
371    /// Parent passes full recent message history.
372    Full,
373}
374
375/// Config for running multiple registry agents concurrently.
376#[derive(Debug, Clone, Serialize, Deserialize)]
377pub struct ConcurrentStateConfig {
378    /// Agent IDs in the registry (simple list or weighted entries).
379    pub agents: Vec<ConcurrentAgentRef>,
380    /// Jinja2 template for input sent to each agent.
381    #[serde(default, skip_serializing_if = "Option::is_none")]
382    pub input: Option<String>,
383    /// How to aggregate results from all agents.
384    pub aggregation: AggregationConfig,
385    /// Minimum agents that must succeed.
386    #[serde(default, skip_serializing_if = "Option::is_none")]
387    pub min_required: Option<usize>,
388    /// What to do when some agents fail.
389    #[serde(default)]
390    pub on_partial_failure: PartialFailureAction,
391    /// Per-agent timeout in milliseconds.
392    #[serde(default, skip_serializing_if = "Option::is_none")]
393    pub timeout_ms: Option<u64>,
394    /// Parent conversation context forwarded to each agent.
395    #[serde(default, skip_serializing_if = "Option::is_none")]
396    pub context_mode: Option<DelegateContextMode>,
397}
398
399/// Either a plain agent ID string or a weighted entry.
400#[derive(Debug, Clone, Serialize, Deserialize)]
401#[serde(untagged)]
402pub enum ConcurrentAgentRef {
403    Id(String),
404    Weighted { id: String, weight: f64 },
405}
406
407impl ConcurrentAgentRef {
408    pub fn id(&self) -> &str {
409        match self {
410            Self::Id(id) => id,
411            Self::Weighted { id, .. } => id,
412        }
413    }
414
415    pub fn weight(&self) -> f64 {
416        match self {
417            Self::Id(_) => 1.0,
418            Self::Weighted { weight, .. } => *weight,
419        }
420    }
421}
422
423/// How to aggregate results from concurrent agents.
424#[derive(Debug, Clone, Serialize, Deserialize)]
425pub struct AggregationConfig {
426    /// Aggregation strategy.
427    pub strategy: AggregationStrategy,
428    /// LLM alias for synthesis or vote extraction.
429    #[serde(default, skip_serializing_if = "Option::is_none")]
430    pub synthesizer_llm: Option<String>,
431    /// Custom prompt for LLM synthesis.
432    #[serde(default, skip_serializing_if = "Option::is_none")]
433    pub synthesizer_prompt: Option<String>,
434    /// Voting sub-config.
435    #[serde(default, skip_serializing_if = "Option::is_none")]
436    pub vote: Option<VoteConfig>,
437}
438
439#[derive(Debug, Clone, Serialize, Deserialize)]
440#[serde(rename_all = "snake_case")]
441pub enum AggregationStrategy {
442    Voting,
443    LlmSynthesis,
444    FirstWins,
445    All,
446}
447
448/// Voting config for concurrent agent aggregation.
449#[derive(Debug, Clone, Serialize, Deserialize)]
450pub struct VoteConfig {
451    #[serde(default)]
452    pub method: VoteMethod,
453    #[serde(default)]
454    pub tiebreaker: TiebreakerStrategy,
455    /// Custom prompt for extracting a vote from each agent's response.
456    #[serde(default, skip_serializing_if = "Option::is_none")]
457    pub vote_prompt: Option<String>,
458}
459
460#[derive(Debug, Clone, Default, Serialize, Deserialize)]
461#[serde(rename_all = "snake_case")]
462pub enum VoteMethod {
463    #[default]
464    Majority,
465    Weighted,
466    Unanimous,
467}
468
469#[derive(Debug, Clone, Default, Serialize, Deserialize)]
470#[serde(rename_all = "snake_case")]
471pub enum TiebreakerStrategy {
472    #[default]
473    First,
474    Random,
475    RouterDecides,
476}
477
478#[derive(Debug, Clone, Default, Serialize, Deserialize)]
479#[serde(rename_all = "snake_case")]
480pub enum PartialFailureAction {
481    #[default]
482    ProceedWithAvailable,
483    Abort,
484}
485
486/// Group chat state config for multi-agent conversation.
487#[derive(Debug, Clone, Serialize, Deserialize)]
488pub struct GroupChatStateConfig {
489    /// Participant agent IDs with optional roles.
490    pub participants: Vec<ChatParticipant>,
491    /// Conversation style.
492    #[serde(default)]
493    pub style: ChatStyle,
494    /// Maximum conversation rounds.
495    #[serde(default = "default_max_rounds")]
496    pub max_rounds: u32,
497    /// Chat manager config.
498    #[serde(default, skip_serializing_if = "Option::is_none")]
499    pub manager: Option<ChatManagerConfig>,
500    /// When and how to terminate.
501    #[serde(default)]
502    pub termination: TerminationConfig,
503    /// Debate-specific config.
504    #[serde(default, skip_serializing_if = "Option::is_none")]
505    pub debate: Option<DebateStyleConfig>,
506    /// Maker-checker-specific config.
507    #[serde(default, skip_serializing_if = "Option::is_none")]
508    pub maker_checker: Option<MakerCheckerConfig>,
509    /// Total timeout for the group chat in milliseconds.
510    #[serde(default, skip_serializing_if = "Option::is_none")]
511    pub timeout_ms: Option<u64>,
512    /// Jinja2 template for the topic sent to participants.
513    /// {{ user_input }} is the user's message. {{ context.<key> }} accesses context values. When omitted, the raw user message is used as the topic.
514    #[serde(default, skip_serializing_if = "Option::is_none")]
515    pub input: Option<String>,
516    /// Parent conversation context included in the topic.
517    #[serde(default, skip_serializing_if = "Option::is_none")]
518    pub context_mode: Option<DelegateContextMode>,
519}
520
521/// A participant in a group chat.
522#[derive(Debug, Clone, Serialize, Deserialize)]
523pub struct ChatParticipant {
524    /// Agent ID in the registry.
525    pub id: String,
526    /// Role description visible to all participants.
527    #[serde(default, skip_serializing_if = "Option::is_none")]
528    pub role: Option<String>,
529}
530
531#[derive(Debug, Clone, Default, Serialize, Deserialize)]
532#[serde(rename_all = "snake_case")]
533pub enum ChatStyle {
534    #[default]
535    Brainstorm,
536    Debate,
537    MakerChecker,
538    Consensus,
539}
540
541/// Chat manager config for controlling turn order.
542#[derive(Debug, Clone, Serialize, Deserialize)]
543pub struct ChatManagerConfig {
544    /// Registry agent ID for chat management.
545    #[serde(default, skip_serializing_if = "Option::is_none")]
546    pub agent: Option<String>,
547    /// Built-in turn policy.
548    #[serde(default, skip_serializing_if = "Option::is_none")]
549    pub method: Option<TurnMethod>,
550}
551
552#[derive(Debug, Clone, Serialize, Deserialize)]
553#[serde(rename_all = "snake_case")]
554pub enum TurnMethod {
555    RoundRobin,
556    Random,
557    LlmDirected,
558}
559
560/// Termination config for group chat.
561#[derive(Debug, Clone, Serialize, Deserialize)]
562pub struct TerminationConfig {
563    #[serde(default)]
564    pub method: TerminationMethod,
565    #[serde(default = "default_stall_rounds")]
566    pub max_stall_rounds: u32,
567}
568
569impl Default for TerminationConfig {
570    fn default() -> Self {
571        Self {
572            method: TerminationMethod::default(),
573            max_stall_rounds: default_stall_rounds(),
574        }
575    }
576}
577
578#[derive(Debug, Clone, Default, Serialize, Deserialize)]
579#[serde(rename_all = "snake_case")]
580pub enum TerminationMethod {
581    #[default]
582    ManagerDecides,
583    MaxRounds,
584    ConsensusReached,
585}
586
587/// Debate-specific config for group chat.
588#[derive(Debug, Clone, Serialize, Deserialize)]
589pub struct DebateStyleConfig {
590    #[serde(default = "default_debate_rounds")]
591    pub rounds: u32,
592    /// Agent ID that synthesizes the final answer.
593    pub synthesizer: String,
594}
595
596/// Maker-checker-specific config for group chat.
597#[derive(Debug, Clone, Serialize, Deserialize)]
598pub struct MakerCheckerConfig {
599    #[serde(default = "default_maker_checker_iterations")]
600    pub max_iterations: u32,
601    /// LLM-evaluated acceptance criteria.
602    pub acceptance_criteria: String,
603    #[serde(default)]
604    pub on_max_iterations: MaxIterationsAction,
605}
606
607#[derive(Debug, Clone, Default, Serialize, Deserialize)]
608#[serde(rename_all = "snake_case")]
609pub enum MaxIterationsAction {
610    #[default]
611    AcceptLast,
612    Escalate,
613    Fail,
614}
615
/// Serde default for `GroupChatStateConfig::max_rounds`.
fn default_max_rounds() -> u32 {
    const MAX_ROUNDS: u32 = 5;
    MAX_ROUNDS
}

/// Serde default for `TerminationConfig::max_stall_rounds`.
fn default_stall_rounds() -> u32 {
    const STALL_ROUNDS: u32 = 2;
    STALL_ROUNDS
}

/// Serde default for `DebateStyleConfig::rounds`.
fn default_debate_rounds() -> u32 {
    const DEBATE_ROUNDS: u32 = 3;
    DEBATE_ROUNDS
}

/// Serde default for `MakerCheckerConfig::max_iterations`.
fn default_maker_checker_iterations() -> u32 {
    const MAKER_CHECKER_ITERATIONS: u32 = 3;
    MAKER_CHECKER_ITERATIONS
}
628
629/// Config for a pipeline state type.
630#[derive(Debug, Clone, Serialize, Deserialize)]
631pub struct PipelineStateConfig {
632    pub stages: Vec<PipelineStageEntry>,
633
634    #[serde(default, skip_serializing_if = "Option::is_none")]
635    pub timeout_ms: Option<u64>,
636    /// Parent conversation context forwarded to the first stage.
637    #[serde(default, skip_serializing_if = "Option::is_none")]
638    pub context_mode: Option<DelegateContextMode>,
639}
640
641/// A single stage in a pipeline state.
642#[derive(Debug, Clone, Serialize, Deserialize)]
643#[serde(untagged)]
644pub enum PipelineStageEntry {
645    /// Simple agent ID string.
646    Id(String),
647    /// Agent with optional input template.
648    Config {
649        id: String,
650        #[serde(default, skip_serializing_if = "Option::is_none")]
651        input: Option<String>,
652    },
653}
654
655impl PipelineStageEntry {
656    pub fn id(&self) -> &str {
657        match self {
658            Self::Id(id) => id,
659            Self::Config { id, .. } => id,
660        }
661    }
662
663    pub fn input(&self) -> Option<&str> {
664        match self {
665            Self::Id(_) => None,
666            Self::Config { input, .. } => input.as_deref(),
667        }
668    }
669}
670
671/// Config for a handoff state type.
672#[derive(Debug, Clone, Serialize, Deserialize)]
673pub struct HandoffStateConfig {
674    pub initial_agent: String,
675    pub available_agents: Vec<String>,
676
677    #[serde(default = "default_max_handoffs")]
678    pub max_handoffs: u32,
679
680    /// Jinja2 template for the input sent to the initial agent.
681    /// {{ user_input }} is the user's message. {{ context.<key> }} accesses context values. When omitted, the raw user message is forwarded directly.
682    #[serde(default, skip_serializing_if = "Option::is_none")]
683    pub input: Option<String>,
684    /// Parent conversation context forwarded to the initial agent.
685    #[serde(default, skip_serializing_if = "Option::is_none")]
686    pub context_mode: Option<DelegateContextMode>,
687}
688
/// Serde default for `HandoffStateConfig::max_handoffs`.
fn default_max_handoffs() -> u32 {
    const MAX_HANDOFFS: u32 = 5;
    MAX_HANDOFFS
}
692
693fn validate_transition_timing(
694    transition: &Transition,
695    scope: &str,
696    state_path: Option<&str>,
697) -> Result<()> {
698    if transition.requires_response && !matches!(transition.timing, TransitionTiming::PostResponse)
699    {
700        let location = state_path
701            .map(|path| format!("State '{}'", path))
702            .unwrap_or_else(|| "Global transition".to_string());
703        return Err(AgentError::InvalidSpec(format!(
704            "{} has response-dependent transition '{}' with non-post-response timing",
705            location, transition.to
706        )));
707    }
708    if matches!(transition.timing, TransitionTiming::Parallel)
709        && transition.guard.is_none()
710        && transition.intent.is_none()
711        && transition.when.trim().is_empty()
712    {
713        return Err(AgentError::InvalidSpec(format!(
714            "{} transition '{}' uses parallel timing without a guard, intent, or when condition",
715            scope, transition.to
716        )));
717    }
718    if matches!(transition.timing, TransitionTiming::PreResponse) {
719        if transition.guard.is_none() && transition.intent.is_none() {
720            return Err(AgentError::InvalidSpec(format!(
721                "{} transition '{}' uses pre-response timing without a guard or intent",
722                scope, transition.to
723            )));
724        }
725        if !transition.when.trim().is_empty() {
726            return Err(AgentError::InvalidSpec(format!(
727                "{} transition '{}' uses pre-response timing with response-dependent when text",
728                scope, transition.to
729            )));
730        }
731    }
732    Ok(())
733}
734
735impl StateConfig {
736    pub fn validate(&self) -> Result<()> {
737        if self.initial.is_empty() {
738            return Err(AgentError::InvalidSpec(
739                "State machine initial state cannot be empty".into(),
740            ));
741        }
742        if !self.states.contains_key(&self.initial) {
743            return Err(AgentError::InvalidSpec(format!(
744                "Initial state '{}' not found in states",
745                self.initial
746            )));
747        }
748        for transition in &self.global_transitions {
749            validate_transition_timing(transition, "Global", None)?;
750            if !self.is_valid_transition_target(&transition.to, &[], &self.states) {
751                return Err(AgentError::InvalidSpec(format!(
752                    "Global transition targets unknown state '{}'",
753                    transition.to
754                )));
755            }
756        }
757
758        self.validate_states(&self.states, &[])?;
759
760        // Warn about unreachable states (non-fatal)
761        for warning in self.check_reachability() {
762            tracing::warn!("{}", warning);
763        }
764
765        Ok(())
766    }
767
768    fn validate_states(
769        &self,
770        states: &HashMap<String, StateDefinition>,
771        parent_path: &[String],
772    ) -> Result<()> {
773        for (name, def) in states {
774            let current_path: Vec<String> = parent_path
775                .iter()
776                .cloned()
777                .chain(std::iter::once(name.clone()))
778                .collect();
779
780            for transition in &def.transitions {
781                let path = current_path.join(".");
782                validate_transition_timing(transition, "State", Some(&path))?;
783
784                if !self.is_valid_transition_target(&transition.to, &current_path, states) {
785                    return Err(AgentError::InvalidSpec(format!(
786                        "State '{}' has transition to unknown state '{}'",
787                        current_path.join("."),
788                        transition.to
789                    )));
790                }
791            }
792
793            if let Some(ref timeout_state) = def.timeout_to {
794                if !self.is_valid_transition_target(timeout_state, &current_path, states) {
795                    return Err(AgentError::InvalidSpec(format!(
796                        "State '{}' has timeout_to unknown state '{}'",
797                        current_path.join("."),
798                        timeout_state
799                    )));
800                }
801            }
802
803            if let Some(ref sub_states) = def.states {
804                if let Some(ref initial) = def.initial {
805                    if !sub_states.contains_key(initial) {
806                        return Err(AgentError::InvalidSpec(format!(
807                            "State '{}' has initial sub-state '{}' that doesn't exist",
808                            current_path.join("."),
809                            initial
810                        )));
811                    }
812                }
813                self.validate_states(sub_states, &current_path)?;
814            }
815        }
816        Ok(())
817    }
818
819    fn is_valid_transition_target(
820        &self,
821        target: &str,
822        current_path: &[String],
823        states: &HashMap<String, StateDefinition>,
824    ) -> bool {
825        if target.starts_with('^') {
826            let target_name = &target[1..];
827            return self.states.contains_key(target_name);
828        }
829
830        if states.contains_key(target) {
831            return true;
832        }
833
834        if current_path.len() > 1 {
835            let parent_path = &current_path[..current_path.len() - 1];
836            if let Some(parent_states) = self.get_states_at_path(parent_path) {
837                if parent_states.contains_key(target) {
838                    return true;
839                }
840            }
841        }
842
843        self.states.contains_key(target)
844    }
845
846    fn get_states_at_path(&self, path: &[String]) -> Option<&HashMap<String, StateDefinition>> {
847        let mut current = &self.states;
848        for segment in path {
849            if let Some(def) = current.get(segment) {
850                if let Some(ref sub_states) = def.states {
851                    current = sub_states;
852                } else {
853                    return None;
854                }
855            } else {
856                return None;
857            }
858        }
859        Some(current)
860    }
861
862    pub fn get_state(&self, path: &str) -> Option<&StateDefinition> {
863        let parts: Vec<&str> = path.split('.').collect();
864        self.get_state_by_path(&parts)
865    }
866
867    fn get_state_by_path(&self, path: &[&str]) -> Option<&StateDefinition> {
868        if path.is_empty() {
869            return None;
870        }
871
872        let mut current = self.states.get(path[0])?;
873        for segment in &path[1..] {
874            if let Some(ref sub_states) = current.states {
875                current = sub_states.get(*segment)?;
876            } else {
877                return None;
878            }
879        }
880        Some(current)
881    }
882
883    /// Resolve a transition target to a full dotted state path.
884    /// Order: `^prefix` (parent-level) → top-level → sibling → child → fallback literal.
885    pub fn resolve_full_path(&self, current_path: &str, target: &str) -> String {
886        if target.starts_with('^') {
887            return target[1..].to_string();
888        }
889
890        if self.states.contains_key(target) {
891            return target.to_string();
892        }
893
894        if !current_path.is_empty() {
895            let parts: Vec<&str> = current_path.split('.').collect();
896            if parts.len() > 1 {
897                let parent_path = parts[..parts.len() - 1].join(".");
898                let potential = format!("{}.{}", parent_path, target);
899                if self.get_state(&potential).is_some() {
900                    return potential;
901                }
902            }
903
904            let potential = format!("{}.{}", current_path, target);
905            if self.get_state(&potential).is_some() {
906                return potential;
907            }
908        }
909
910        target.to_string()
911    }
912
913    /// Check for unreachable states. Returns warning messages.
914    pub fn check_reachability(&self) -> Vec<String> {
915        let mut reachable: HashSet<String> = HashSet::new();
916        reachable.insert(self.initial.clone());
917
918        if let Some(ref fb) = self.fallback {
919            reachable.insert(fb.clone());
920        }
921        for gt in &self.global_transitions {
922            reachable.insert(self.normalize_target(&gt.to));
923        }
924
925        let mut queue: Vec<String> = reachable.iter().cloned().collect();
926        while let Some(state_path) = queue.pop() {
927            if let Some(def) = self.get_state(&state_path) {
928                for t in &def.transitions {
929                    let target = self.resolve_full_path(&state_path, &t.to);
930                    if reachable.insert(target.clone()) {
931                        queue.push(target);
932                    }
933                }
934                if let Some(ref timeout) = def.timeout_to {
935                    let target = self.resolve_full_path(&state_path, timeout);
936                    if reachable.insert(target.clone()) {
937                        queue.push(target);
938                    }
939                }
940                if let (Some(initial), Some(_sub)) = (&def.initial, &def.states) {
941                    let sub_path = format!("{}.{}", state_path, initial);
942                    if reachable.insert(sub_path.clone()) {
943                        queue.push(sub_path);
944                    }
945                }
946            }
947        }
948
949        let all_states = self.collect_all_state_paths(&self.states, &[]);
950        let mut warnings = Vec::new();
951        for state_path in &all_states {
952            if !reachable.contains(state_path) {
953                warnings.push(format!(
954                    "State '{}' appears unreachable — no transitions lead to it",
955                    state_path
956                ));
957            }
958        }
959        warnings
960    }
961
962    fn normalize_target(&self, target: &str) -> String {
963        if target.starts_with('^') {
964            target[1..].to_string()
965        } else {
966            target.to_string()
967        }
968    }
969
970    fn collect_all_state_paths(
971        &self,
972        states: &HashMap<String, StateDefinition>,
973        parent: &[String],
974    ) -> Vec<String> {
975        let mut paths = Vec::new();
976        for (name, def) in states {
977            let mut current: Vec<String> = parent.to_vec();
978            current.push(name.clone());
979            paths.push(current.join("."));
980            if let Some(ref sub) = def.states {
981                paths.extend(self.collect_all_state_paths(sub, &current));
982            }
983        }
984        paths
985    }
986}
987
988impl StateDefinition {
989    pub fn has_sub_states(&self) -> bool {
990        self.states.as_ref().map(|s| !s.is_empty()).unwrap_or(false)
991    }
992
993    pub fn get_effective_tools<'a>(
994        &'a self,
995        parent: Option<&'a StateDefinition>,
996    ) -> Option<Vec<&'a ToolRef>> {
997        match &self.tools {
998            // Explicitly set (including empty): use as-is, no inheritance
999            Some(tools) => Some(tools.iter().collect()),
1000            // Not set: inherit from parent if available
1001            None => {
1002                if !self.inherit_parent {
1003                    return None;
1004                }
1005                parent
1006                    .and_then(|p| p.tools.as_ref())
1007                    .map(|t| t.iter().collect())
1008            }
1009        }
1010    }
1011
1012    pub fn get_effective_skills<'a>(
1013        &'a self,
1014        parent: Option<&'a StateDefinition>,
1015    ) -> Vec<&'a String> {
1016        if !self.inherit_parent || parent.is_none() {
1017            return self.skills.iter().collect();
1018        }
1019
1020        let parent = parent.unwrap();
1021        let mut skills: Vec<&'a String> = parent.skills.iter().collect();
1022        skills.extend(self.skills.iter());
1023        skills
1024    }
1025}
1026
// Unit tests for state-config deserialization, validation, path lookup/resolution,
// and tool/skill inheritance. YAML fixtures are indentation-sensitive raw strings.
#[cfg(test)]
mod tests {
    use super::*;

    // A transition with only `to`/`when` deserializes with `PostResponse` timing and
    // both `requires_response` and `run_extractors` off.
    #[test]
    fn test_transition_timing_defaults_to_post_response() {
        let yaml = r#"
to: next
when: "ready"
"#;
        let transition: Transition = serde_yaml::from_str(yaml).unwrap();
        assert_eq!(transition.timing, TransitionTiming::PostResponse);
        assert!(!transition.requires_response);
        assert!(!transition.run_extractors);
    }

    // A flat two-state config (with `llm` and `tools` fields) parses and passes `validate()`.
    #[test]
    fn test_state_config_deserialize() {
        let yaml = r#"
initial: greeting
states:
  greeting:
    prompt: "Welcome!"
    transitions:
      - to: support
        when: "user needs help"
        auto: true
  support:
    prompt: "How can I help?"
    llm: fast
    tools:
      - search
"#;
        let config: StateConfig = serde_yaml::from_str(yaml).unwrap();
        assert_eq!(config.initial, "greeting");
        assert_eq!(config.states.len(), 2);
        assert!(config.validate().is_ok());
    }

    // `StateDefinition::default()` uses `PromptMode::Append`.
    #[test]
    fn test_prompt_mode_default() {
        let def = StateDefinition::default();
        assert_eq!(def.prompt_mode, PromptMode::Append);
    }

    // `validate()` rejects a `pre_response` transition that sets `requires_response: true`.
    #[test]
    fn test_response_dependent_pre_response_transition_is_invalid() {
        let yaml = r#"
initial: greeting
states:
  greeting:
    transitions:
      - to: done
        when: "after answer"
        timing: pre_response
        requires_response: true
  done:
    prompt: "Done"
"#;
        let config: StateConfig = serde_yaml::from_str(yaml).unwrap();
        assert!(config.validate().is_err());
    }

    // `parallel` timing with a `when` condition (and no `requires_response`) validates.
    #[test]
    fn test_parallel_transition_timing_accepts_response_independent_condition() {
        let yaml = r#"
initial: greeting
states:
  greeting:
    transitions:
      - to: done
        when: "ready"
        timing: parallel
  done:
    prompt: "Done"
"#;
        let config: StateConfig = serde_yaml::from_str(yaml).unwrap();
        assert!(config.validate().is_ok());
    }

    // `parallel` timing with no `when` fails validation, and the error message must
    // mention "parallel timing without".
    #[test]
    fn test_parallel_transition_without_condition_is_invalid() {
        let yaml = r#"
initial: greeting
states:
  greeting:
    transitions:
      - to: done
        timing: parallel
  done:
    prompt: "Done"
"#;
        let config: StateConfig = serde_yaml::from_str(yaml).unwrap();
        let err = config.validate().unwrap_err();
        assert!(err.to_string().contains("parallel timing without"));
    }

    // `pre_response` with only a `when` (no `guard` or `intent`) fails validation.
    #[test]
    fn test_pre_response_when_without_guard_or_intent_is_invalid() {
        let yaml = r#"
initial: greeting
states:
  greeting:
    transitions:
      - to: done
        when: "ready"
        timing: pre_response
  done:
    prompt: "Done"
"#;
        let config: StateConfig = serde_yaml::from_str(yaml).unwrap();
        assert!(config.validate().is_err());
    }

    // `pre_response` fails validation when `when` text is present, even alongside a `guard`.
    #[test]
    fn test_pre_response_with_when_text_is_invalid() {
        let yaml = r#"
initial: greeting
states:
  greeting:
    transitions:
      - to: done
        when: "ready"
        guard:
          context:
            ready:
              eq: true
        timing: pre_response
  done:
    prompt: "Done"
"#;
        let config: StateConfig = serde_yaml::from_str(yaml).unwrap();
        assert!(config.validate().is_err());
    }

    // `validate()` fails when `initial` names a state that does not exist.
    #[test]
    fn test_invalid_initial_state() {
        let config = StateConfig {
            initial: "nonexistent".into(),
            states: HashMap::new(),
            global_transitions: vec![],
            fallback: None,
            max_no_transition: None,
            regenerate_on_transition: true,
        };
        assert!(config.validate().is_err());
    }

    // `validate()` fails when a transition targets a state that does not exist.
    #[test]
    fn test_invalid_transition_target() {
        let mut states = HashMap::new();
        states.insert(
            "start".into(),
            StateDefinition {
                transitions: vec![Transition {
                    to: "nonexistent".into(),
                    when: "always".into(),
                    guard: None,
                    intent: None,
                    auto: true,
                    priority: 0,
                    cooldown_turns: None,
                    timing: TransitionTiming::PostResponse,
                    requires_response: false,
                    run_extractors: false,
                }],
                ..Default::default()
            },
        );
        let config = StateConfig {
            initial: "start".into(),
            states,
            global_transitions: vec![],
            fallback: None,
            max_no_transition: None,
            regenerate_on_transition: true,
        };
        assert!(config.validate().is_err());
    }

    // Nested `states` with a sub-`initial`, a sibling transition, and a `^closing` target
    // validate, and the composite state reports `has_sub_states()`.
    #[test]
    fn test_hierarchical_states() {
        let yaml = r#"
initial: problem_solving
states:
  problem_solving:
    initial: gathering_info
    prompt: "Solving customer problem"
    states:
      gathering_info:
        prompt: "Ask questions"
        transitions:
          - to: proposing_solution
            when: "understood"
      proposing_solution:
        prompt: "Offer solution"
        transitions:
          - to: ^closing
            when: "resolved"
  closing:
    prompt: "Thank you"
"#;
        let config: StateConfig = serde_yaml::from_str(yaml).unwrap();
        assert!(config.validate().is_ok());
        assert!(
            config
                .states
                .get("problem_solving")
                .unwrap()
                .has_sub_states()
        );
    }

    // Bare strings in a `tools` list deserialize as `ToolRef`s whose `id()` is the string.
    #[test]
    fn test_tool_ref_simple() {
        let yaml = r#"
tools:
  - calculator
  - search
"#;
        #[derive(Deserialize)]
        struct Test {
            tools: Vec<ToolRef>,
        }
        let t: Test = serde_yaml::from_str(yaml).unwrap();
        assert_eq!(t.tools.len(), 2);
        assert_eq!(t.tools[0].id(), "calculator");
    }

    // A `tools` list may mix bare strings with `id`/`condition` maps. The map form
    // exposes its condition via `condition()`.
    #[test]
    fn test_tool_ref_conditional() {
        let yaml = r#"
tools:
  - calculator
  - id: admin_tool
    condition:
      context:
        user.role: "admin"
"#;
        #[derive(Deserialize)]
        struct Test {
            tools: Vec<ToolRef>,
        }
        let t: Test = serde_yaml::from_str(yaml).unwrap();
        assert_eq!(t.tools.len(), 2);
        assert_eq!(t.tools[1].id(), "admin_tool");
        assert!(t.tools[1].condition().is_some());
    }

    // A transition accepts a string-form `guard` and an explicit `priority`.
    #[test]
    fn test_transition_with_guard() {
        let yaml = r#"
to: next_state
when: "user wants to proceed"
guard: "{{ context.has_data }}"
auto: true
priority: 10
"#;
        let t: Transition = serde_yaml::from_str(yaml).unwrap();
        assert!(t.guard.is_some());
        assert_eq!(t.priority, 10);
    }

    // `StateAction` deserializes from its `tool`/`args`, `skill`, and `set_context`
    // forms, preserving list order.
    #[test]
    fn test_state_action() {
        let yaml = r#"
- tool: log_event
  args:
    event: "entered"
- skill: greeting_skill
- set_context:
    entered: true
"#;
        let actions: Vec<StateAction> = serde_yaml::from_str(yaml).unwrap();
        assert_eq!(actions.len(), 3);
        match &actions[0] {
            StateAction::Tool { tool, .. } => assert_eq!(tool, "log_event"),
            _ => panic!("Expected Tool action"),
        }
        match &actions[1] {
            StateAction::Skill { skill } => assert_eq!(skill, "greeting_skill"),
            _ => panic!("Expected Skill action"),
        }
        match &actions[2] {
            StateAction::SetContext { set_context } => {
                assert!(set_context.contains_key("entered"));
            }
            _ => panic!("Expected SetContext action"),
        }
    }

    // An `all:` condition with `context` and `semantic` children deserializes into
    // `ToolCondition::All` holding two entries.
    #[test]
    fn test_complex_tool_condition() {
        let yaml = r#"
id: refund_tool
condition:
  all:
    - context:
        user.verified: true
    - semantic:
        when: "user wants refund"
        threshold: 0.85
"#;
        let tool: ToolRef = serde_yaml::from_str(yaml).unwrap();
        assert_eq!(tool.id(), "refund_tool");
        match tool.condition().unwrap() {
            ToolCondition::All(conditions) => assert_eq!(conditions.len(), 2),
            _ => panic!("Expected All condition"),
        }
    }

    // `get_state` resolves top-level names and dotted nested paths, and returns `None`
    // for unknown names.
    #[test]
    fn test_state_get_path() {
        let yaml = r#"
initial: problem_solving
states:
  problem_solving:
    initial: gathering_info
    states:
      gathering_info:
        prompt: "Ask"
      proposing:
        prompt: "Propose"
  closing:
    prompt: "Done"
"#;
        let config: StateConfig = serde_yaml::from_str(yaml).unwrap();
        assert!(config.get_state("problem_solving").is_some());
        assert!(config.get_state("problem_solving.gathering_info").is_some());
        assert!(config.get_state("closing").is_some());
        assert!(config.get_state("nonexistent").is_none());
    }

    // `resolve_full_path` covers three cases:
    // - a bare name from a nested state resolves to its sibling;
    // - `^name` resolves to the top-level state;
    // - a top-level name from a top-level state is returned as-is.
    #[test]
    fn test_resolve_full_path() {
        let yaml = r#"
initial: problem_solving
states:
  problem_solving:
    initial: gathering_info
    states:
      gathering_info:
        prompt: "Ask"
      proposing:
        prompt: "Propose"
  closing:
    prompt: "Done"
"#;
        let config: StateConfig = serde_yaml::from_str(yaml).unwrap();

        assert_eq!(
            config.resolve_full_path("problem_solving.gathering_info", "proposing"),
            "problem_solving.proposing"
        );
        assert_eq!(
            config.resolve_full_path("problem_solving.gathering_info", "^closing"),
            "closing"
        );
        assert_eq!(
            config.resolve_full_path("problem_solving", "closing"),
            "closing"
        );
    }

    // With `inherit_parent`, explicit child tools replace the parent's (no merge), while
    // skills are merged (parent + child).
    #[test]
    fn test_inherit_parent() {
        let parent = StateDefinition {
            tools: Some(vec![ToolRef::Simple("parent_tool".into())]),
            skills: vec!["parent_skill".into()],
            ..Default::default()
        };

        let child = StateDefinition {
            tools: Some(vec![ToolRef::Simple("child_tool".into())]),
            skills: vec!["child_skill".into()],
            inherit_parent: true,
            ..Default::default()
        };

        let effective_tools = child.get_effective_tools(Some(&parent)).unwrap();
        assert_eq!(effective_tools.len(), 1); // explicit tools override, no merge

        let effective_skills = child.get_effective_skills(Some(&parent));
        assert_eq!(effective_skills.len(), 2);
    }

    // Without `inherit_parent`, the child's own explicit tools are returned.
    #[test]
    fn test_no_inherit_parent() {
        let parent = StateDefinition {
            tools: Some(vec![ToolRef::Simple("parent_tool".into())]),
            ..Default::default()
        };

        let child = StateDefinition {
            tools: Some(vec![ToolRef::Simple("child_tool".into())]),
            inherit_parent: false,
            ..Default::default()
        };

        let effective_tools = child.get_effective_tools(Some(&parent)).unwrap();
        assert_eq!(effective_tools.len(), 1);
        assert_eq!(effective_tools[0].id(), "child_tool");
    }

    // `tools: None` together with `inherit_parent` falls back to the parent's tools.
    #[test]
    fn test_tools_none_inherits() {
        let parent = StateDefinition {
            tools: Some(vec![ToolRef::Simple("parent_tool".into())]),
            ..Default::default()
        };

        let child = StateDefinition {
            tools: None, // not specified → inherit
            inherit_parent: true,
            ..Default::default()
        };

        let effective_tools = child.get_effective_tools(Some(&parent)).unwrap();
        assert_eq!(effective_tools.len(), 1);
        assert_eq!(effective_tools[0].id(), "parent_tool");
    }

    // An explicitly empty `tools` list means "no tools" and is not replaced by the parent's.
    #[test]
    fn test_tools_empty_means_no_tools() {
        let parent = StateDefinition {
            tools: Some(vec![ToolRef::Simple("parent_tool".into())]),
            ..Default::default()
        };

        let child = StateDefinition {
            tools: Some(vec![]), // explicitly empty → no tools
            inherit_parent: true,
            ..Default::default()
        };

        let effective_tools = child.get_effective_tools(Some(&parent)).unwrap();
        assert!(effective_tools.is_empty());
    }

    // A per-state `disambiguation` block parses `threshold`, `require_confirmation`, and
    // `required_clarity`. States without one get `None`.
    #[test]
    fn test_state_with_disambiguation_override() {
        let yaml = r#"
initial: greeting
states:
  greeting:
    prompt: "Hello"
    transitions:
      - to: payment
        when: "User wants to pay"
  payment:
    prompt: "Processing payment"
    disambiguation:
      threshold: 0.95
      require_confirmation: true
      required_clarity:
        - recipient
        - amount
"#;
        let config: StateConfig = serde_yaml::from_str(yaml).unwrap();
        let payment = config.get_state("payment").unwrap();
        let disambig = payment.disambiguation.as_ref().unwrap();
        assert_eq!(disambig.threshold, Some(0.95));
        assert!(disambig.require_confirmation);
        assert_eq!(disambig.required_clarity.len(), 2);
        assert!(disambig.required_clarity.contains(&"recipient".to_string()));

        let greeting = config.get_state("greeting").unwrap();
        assert!(greeting.disambiguation.is_none());
    }

    // `extract` entries parse in order. When omitted, `required` defaults to false and
    // `llm` defaults to "router".
    #[test]
    fn test_context_extractor_vec_deserialize() {
        let yaml = r#"
initial: a
states:
  a:
    extract:
      - key: user_email
        description: "The user's email address"
      - key: order_id
        llm_extract: "Extract the order ID"
        required: true
"#;
        let config: StateConfig = serde_yaml::from_str(yaml).unwrap();
        let state = config.get_state("a").unwrap();
        assert_eq!(state.extract.len(), 2);
        assert_eq!(state.extract[0].key, "user_email");
        assert_eq!(
            state.extract[0].description.as_deref(),
            Some("The user's email address")
        );
        assert!(!state.extract[0].required);
        assert_eq!(state.extract[0].llm, "router");
        assert_eq!(state.extract[1].key, "order_id");
        assert!(state.extract[1].required);
        assert!(state.extract[1].llm_extract.is_some());
    }

    // Omitting `extract` yields an empty list.
    #[test]
    fn test_context_extractor_default_empty() {
        let yaml = r#"
initial: a
states:
  a:
    prompt: "Hello"
"#;
        let config: StateConfig = serde_yaml::from_str(yaml).unwrap();
        let state = config.get_state("a").unwrap();
        assert!(state.extract.is_empty());
    }

    // A per-state `process` block parses, here with a single `input` step.
    #[test]
    fn test_state_process_override_deserialize() {
        let yaml = r#"
initial: a
states:
  a:
    process:
      input:
        - type: normalize
          config:
            trim: true
"#;
        let config: StateConfig = serde_yaml::from_str(yaml).unwrap();
        let state = config.get_state("a").unwrap();
        assert!(state.process.is_some());
        assert_eq!(state.process.as_ref().unwrap().input.len(), 1);
    }

    // Omitting `process` yields `None`.
    #[test]
    fn test_state_process_default_none() {
        let yaml = r#"
initial: a
states:
  a:
    prompt: "Hello"
"#;
        let config: StateConfig = serde_yaml::from_str(yaml).unwrap();
        let state = config.get_state("a").unwrap();
        assert!(state.process.is_none());
    }
}