Skip to main content

ai_agents_reasoning/
planning.rs

1use serde::{Deserialize, Serialize};
2
3#[derive(Debug, Clone, Serialize, Deserialize)]
4pub struct PlanningConfig {
5    #[serde(default)]
6    pub planner_llm: Option<String>,
7
8    #[serde(default = "default_max_steps")]
9    pub max_steps: u32,
10
11    #[serde(default)]
12    pub available: PlanAvailableActions,
13
14    #[serde(default)]
15    pub reflection: PlanReflectionConfig,
16}
17
18impl Default for PlanningConfig {
19    fn default() -> Self {
20        Self {
21            planner_llm: None,
22            max_steps: default_max_steps(),
23            available: PlanAvailableActions::default(),
24            reflection: PlanReflectionConfig::default(),
25        }
26    }
27}
28
/// Serde default for [`PlanningConfig::max_steps`]: a budget of ten steps.
fn default_max_steps() -> u32 {
    10_u32
}
32
33#[derive(Debug, Clone, Serialize, Deserialize, Default)]
34pub struct PlanAvailableActions {
35    #[serde(default = "default_all_string")]
36    pub tools: StringOrList,
37
38    #[serde(default = "default_all_string")]
39    pub skills: StringOrList,
40}
41
42#[derive(Debug, Clone, Serialize, Deserialize)]
43#[serde(untagged)]
44pub enum StringOrList {
45    All(String),
46    List(Vec<String>),
47}
48
49impl Default for StringOrList {
50    fn default() -> Self {
51        StringOrList::All("all".to_string())
52    }
53}
54
55fn default_all_string() -> StringOrList {
56    StringOrList::All("all".to_string())
57}
58
59impl StringOrList {
60    pub fn is_all(&self) -> bool {
61        matches!(self, StringOrList::All(s) if s == "all")
62    }
63
64    pub fn allows(&self, id: &str) -> bool {
65        match self {
66            StringOrList::All(s) if s == "all" => true,
67            StringOrList::All(_) => false,
68            StringOrList::List(list) => list.iter().any(|s| s == id),
69        }
70    }
71
72    pub fn as_list(&self) -> Option<&[String]> {
73        match self {
74            StringOrList::List(list) => Some(list),
75            StringOrList::All(_) => None,
76        }
77    }
78}
79
80#[derive(Debug, Clone, Serialize, Deserialize)]
81pub struct PlanReflectionConfig {
82    #[serde(default = "default_true")]
83    pub enabled: bool,
84
85    #[serde(default)]
86    pub on_step_failure: StepFailureAction,
87
88    #[serde(default = "default_max_replans")]
89    pub max_replans: u32,
90}
91
92impl Default for PlanReflectionConfig {
93    fn default() -> Self {
94        Self {
95            enabled: true,
96            on_step_failure: StepFailureAction::default(),
97            max_replans: default_max_replans(),
98        }
99    }
100}
101
/// Serde helper for boolean fields that should be on when omitted.
fn default_true() -> bool {
    let enabled = true;
    enabled
}
105
/// Serde default for [`PlanReflectionConfig::max_replans`]: two replans.
fn default_max_replans() -> u32 {
    2_u32
}
109
110#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize, Default)]
111#[serde(rename_all = "snake_case")]
112pub enum StepFailureAction {
113    #[default]
114    Replan,
115    Skip,
116    Abort,
117    Continue,
118}
119
120impl StepFailureAction {
121    pub fn should_stop(&self) -> bool {
122        matches!(self, StepFailureAction::Abort)
123    }
124
125    pub fn should_replan(&self) -> bool {
126        matches!(self, StepFailureAction::Replan)
127    }
128
129    pub fn should_skip(&self) -> bool {
130        matches!(self, StepFailureAction::Skip)
131    }
132}
133
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_planning_config_default() {
        let config = PlanningConfig::default();
        assert!(config.planner_llm.is_none());
        assert_eq!(config.max_steps, 10);
        assert!(config.available.tools.is_all());
        assert!(config.available.skills.is_all());
        assert!(config.reflection.enabled);
    }

    #[test]
    fn test_planning_config_serde() {
        let yaml = r#"
planner_llm: router
max_steps: 15
available:
  tools: all
  skills:
    - skill1
    - skill2
reflection:
  enabled: true
  on_step_failure: skip
  max_replans: 3
"#;
        let config: PlanningConfig = serde_yaml::from_str(yaml).unwrap();
        assert_eq!(config.planner_llm, Some("router".to_string()));
        assert_eq!(config.max_steps, 15);
        assert!(config.available.tools.is_all());
        assert!(!config.available.skills.is_all());
        assert!(config.available.skills.allows("skill1"));
        assert!(!config.available.skills.allows("skill3"));
        assert!(config.reflection.enabled);
        assert_eq!(config.reflection.on_step_failure, StepFailureAction::Skip);
        assert_eq!(config.reflection.max_replans, 3);
    }

    /// An empty mapping must fall back to every serde `default` helper, matching `Default`.
    #[test]
    fn test_planning_config_serde_defaults() {
        let config: PlanningConfig = serde_yaml::from_str("{}").unwrap();
        assert!(config.planner_llm.is_none());
        assert_eq!(config.max_steps, 10);
        assert!(config.available.tools.is_all());
        assert!(config.available.skills.is_all());
        assert!(config.reflection.enabled);
        assert_eq!(config.reflection.on_step_failure, StepFailureAction::Replan);
        assert_eq!(config.reflection.max_replans, 2);
    }

    /// Partially specified nested sections keep the defaults for the fields left out.
    #[test]
    fn test_partial_nested_sections_use_defaults() {
        let yaml = r#"
available:
  tools:
    - search
reflection:
  max_replans: 5
"#;
        let config: PlanningConfig = serde_yaml::from_str(yaml).unwrap();
        assert_eq!(config.max_steps, 10);
        assert!(config.available.tools.allows("search"));
        assert!(!config.available.tools.allows("other"));
        assert!(config.available.skills.is_all());
        assert!(config.reflection.enabled);
        assert_eq!(config.reflection.on_step_failure, StepFailureAction::Replan);
        assert_eq!(config.reflection.max_replans, 5);
    }

    /// A scalar other than the exact keyword `all` is accepted but permits nothing.
    #[test]
    fn test_non_all_scalar_denies_everything() {
        let yaml = "available:\n  tools: none\n";
        let config: PlanningConfig = serde_yaml::from_str(yaml).unwrap();
        assert!(!config.available.tools.is_all());
        assert!(!config.available.tools.allows("none"));
        assert!(!config.available.tools.allows("search"));
        assert!(config.available.skills.is_all());
    }

    #[test]
    fn test_string_or_list() {
        let all = StringOrList::All("all".to_string());
        assert!(all.is_all());
        assert!(all.allows("anything"));
        assert!(all.as_list().is_none());

        let list = StringOrList::List(vec!["a".to_string(), "b".to_string()]);
        assert!(!list.is_all());
        assert!(list.allows("a"));
        assert!(list.allows("b"));
        assert!(!list.allows("c"));
        assert_eq!(
            list.as_list(),
            Some(["a".to_string(), "b".to_string()].as_slice())
        );
    }

    #[test]
    fn test_step_failure_action() {
        assert!(StepFailureAction::Abort.should_stop());
        assert!(!StepFailureAction::Replan.should_stop());

        assert!(StepFailureAction::Replan.should_replan());
        assert!(!StepFailureAction::Skip.should_replan());

        assert!(StepFailureAction::Skip.should_skip());
        assert!(!StepFailureAction::Continue.should_skip());

        // `Continue` is reported by none of the predicates.
        assert!(!StepFailureAction::Continue.should_stop());
        assert!(!StepFailureAction::Continue.should_replan());
    }

    /// Variants are (de)serialized under their snake_case names.
    #[test]
    fn test_step_failure_action_serde_names() {
        for (text, expected) in [
            ("replan", StepFailureAction::Replan),
            ("skip", StepFailureAction::Skip),
            ("abort", StepFailureAction::Abort),
            ("continue", StepFailureAction::Continue),
        ] {
            let parsed: StepFailureAction = serde_yaml::from_str(text).unwrap();
            assert_eq!(parsed, expected, "parsing {text:?}");
        }
    }

    #[test]
    fn test_plan_reflection_config_default() {
        let config = PlanReflectionConfig::default();
        assert!(config.enabled);
        assert_eq!(config.on_step_failure, StepFailureAction::Replan);
        assert_eq!(config.max_replans, 2);
    }
}