ai_agents_reasoning/
planning.rs1use serde::{Deserialize, Serialize};
2
3#[derive(Debug, Clone, Serialize, Deserialize)]
4pub struct PlanningConfig {
5 #[serde(default)]
6 pub planner_llm: Option<String>,
7
8 #[serde(default = "default_max_steps")]
9 pub max_steps: u32,
10
11 #[serde(default)]
12 pub available: PlanAvailableActions,
13
14 #[serde(default)]
15 pub reflection: PlanReflectionConfig,
16}
17
18impl Default for PlanningConfig {
19 fn default() -> Self {
20 Self {
21 planner_llm: None,
22 max_steps: default_max_steps(),
23 available: PlanAvailableActions::default(),
24 reflection: PlanReflectionConfig::default(),
25 }
26 }
27}
28
/// Serde default for [`PlanningConfig::max_steps`] (also used by its `Default`).
fn default_max_steps() -> u32 {
    const MAX_STEPS: u32 = 10;
    MAX_STEPS
}
32
33#[derive(Debug, Clone, Serialize, Deserialize, Default)]
34pub struct PlanAvailableActions {
35 #[serde(default = "default_all_string")]
36 pub tools: StringOrList,
37
38 #[serde(default = "default_all_string")]
39 pub skills: StringOrList,
40}
41
42#[derive(Debug, Clone, Serialize, Deserialize)]
43#[serde(untagged)]
44pub enum StringOrList {
45 All(String),
46 List(Vec<String>),
47}
48
49impl Default for StringOrList {
50 fn default() -> Self {
51 StringOrList::All("all".to_string())
52 }
53}
54
55fn default_all_string() -> StringOrList {
56 StringOrList::All("all".to_string())
57}
58
59impl StringOrList {
60 pub fn is_all(&self) -> bool {
61 matches!(self, StringOrList::All(s) if s == "all")
62 }
63
64 pub fn allows(&self, id: &str) -> bool {
65 match self {
66 StringOrList::All(s) if s == "all" => true,
67 StringOrList::All(_) => false,
68 StringOrList::List(list) => list.iter().any(|s| s == id),
69 }
70 }
71
72 pub fn as_list(&self) -> Option<&[String]> {
73 match self {
74 StringOrList::List(list) => Some(list),
75 StringOrList::All(_) => None,
76 }
77 }
78}
79
80#[derive(Debug, Clone, Serialize, Deserialize)]
81pub struct PlanReflectionConfig {
82 #[serde(default = "default_true")]
83 pub enabled: bool,
84
85 #[serde(default)]
86 pub on_step_failure: StepFailureAction,
87
88 #[serde(default = "default_max_replans")]
89 pub max_replans: u32,
90}
91
92impl Default for PlanReflectionConfig {
93 fn default() -> Self {
94 Self {
95 enabled: true,
96 on_step_failure: StepFailureAction::default(),
97 max_replans: default_max_replans(),
98 }
99 }
100}
101
/// Serde default helper that yields `true`
/// (used for [`PlanReflectionConfig::enabled`]).
fn default_true() -> bool {
    let enabled = true;
    enabled
}
105
/// Serde default for [`PlanReflectionConfig::max_replans`] (also used by its `Default`).
fn default_max_replans() -> u32 {
    const MAX_REPLANS: u32 = 2;
    MAX_REPLANS
}
109
110#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize, Default)]
111#[serde(rename_all = "snake_case")]
112pub enum StepFailureAction {
113 #[default]
114 Replan,
115 Skip,
116 Abort,
117 Continue,
118}
119
120impl StepFailureAction {
121 pub fn should_stop(&self) -> bool {
122 matches!(self, StepFailureAction::Abort)
123 }
124
125 pub fn should_replan(&self) -> bool {
126 matches!(self, StepFailureAction::Replan)
127 }
128
129 pub fn should_skip(&self) -> bool {
130 matches!(self, StepFailureAction::Skip)
131 }
132}
133
/// Unit tests for the planning configuration types, including the serde
/// defaults applied when fields or whole sections are omitted.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_planning_config_default() {
        let config = PlanningConfig::default();
        assert!(config.planner_llm.is_none());
        assert_eq!(config.max_steps, 10);
        assert!(config.available.tools.is_all());
        assert!(config.available.skills.is_all());
        assert!(config.reflection.enabled);
    }

    #[test]
    fn test_planning_config_serde() {
        let yaml = r#"
planner_llm: router
max_steps: 15
available:
  tools: all
  skills:
    - skill1
    - skill2
reflection:
  enabled: true
  on_step_failure: skip
  max_replans: 3
"#;
        let config: PlanningConfig = serde_yaml::from_str(yaml).unwrap();
        assert_eq!(config.planner_llm, Some("router".to_string()));
        assert_eq!(config.max_steps, 15);
        assert!(config.available.tools.is_all());
        assert!(!config.available.skills.is_all());
        assert!(config.available.skills.allows("skill1"));
        assert!(!config.available.skills.allows("skill3"));
        assert!(config.reflection.enabled);
        assert_eq!(config.reflection.on_step_failure, StepFailureAction::Skip);
        assert_eq!(config.reflection.max_replans, 3);
    }

    // An empty mapping exercises every `#[serde(default ...)]` attribute at once;
    // the result must equal what `Default` produces.
    #[test]
    fn test_planning_config_serde_empty_uses_defaults() {
        let config: PlanningConfig = serde_yaml::from_str("{}").unwrap();
        assert!(config.planner_llm.is_none());
        assert_eq!(config.max_steps, 10);
        assert!(config.available.tools.is_all());
        assert!(config.available.skills.is_all());
        assert!(config.reflection.enabled);
        assert_eq!(config.reflection.on_step_failure, StepFailureAction::Replan);
        assert_eq!(config.reflection.max_replans, 2);
    }

    // Sections that are present but incomplete must still fill their missing fields.
    #[test]
    fn test_planning_config_serde_partial_sections() {
        let yaml = r#"
available:
  tools:
    - search
reflection:
  on_step_failure: continue
"#;
        let config: PlanningConfig = serde_yaml::from_str(yaml).unwrap();
        assert!(config.planner_llm.is_none());
        assert_eq!(config.max_steps, 10);
        assert_eq!(
            config.available.tools.as_list(),
            Some(["search".to_string()].as_slice())
        );
        assert!(config.available.skills.is_all());
        assert!(config.reflection.enabled);
        assert_eq!(config.reflection.on_step_failure, StepFailureAction::Continue);
        assert_eq!(config.reflection.max_replans, 2);
    }

    #[test]
    fn test_string_or_list() {
        let all = StringOrList::All("all".to_string());
        assert!(all.is_all());
        assert!(all.allows("anything"));
        assert!(all.as_list().is_none());

        let list = StringOrList::List(vec!["a".to_string(), "b".to_string()]);
        assert!(!list.is_all());
        assert!(list.allows("a"));
        assert!(list.allows("b"));
        assert!(!list.allows("c"));
        assert_eq!(
            list.as_list(),
            Some(["a".to_string(), "b".to_string()].as_slice())
        );
    }

    // Only the exact keyword "all" is special; any other single string permits nothing.
    #[test]
    fn test_string_or_list_non_all_string_allows_nothing() {
        let other = StringOrList::All("none".to_string());
        assert!(!other.is_all());
        assert!(!other.allows("none"));
        assert!(!other.allows("anything"));
        assert!(other.as_list().is_none());

        let parsed: PlanAvailableActions = serde_yaml::from_str("tools: none").unwrap();
        assert!(!parsed.tools.is_all());
        assert!(!parsed.tools.allows("search"));
        assert!(parsed.skills.is_all());
    }

    #[test]
    fn test_string_or_list_empty_list_allows_nothing() {
        let empty = StringOrList::List(Vec::new());
        assert!(!empty.is_all());
        assert!(!empty.allows("a"));
        assert_eq!(empty.as_list().map(|ids| ids.len()), Some(0));
    }

    #[test]
    fn test_step_failure_action() {
        assert!(StepFailureAction::Abort.should_stop());
        assert!(!StepFailureAction::Replan.should_stop());

        assert!(StepFailureAction::Replan.should_replan());
        assert!(!StepFailureAction::Skip.should_replan());

        assert!(StepFailureAction::Skip.should_skip());
        assert!(!StepFailureAction::Continue.should_skip());
    }

    // Pins every variant against every predicate, including `Continue`,
    // which must match none of them.
    #[test]
    fn test_step_failure_action_predicates_cover_every_variant() {
        // (action, should_stop, should_replan, should_skip)
        let cases = [
            (StepFailureAction::Replan, false, true, false),
            (StepFailureAction::Skip, false, false, true),
            (StepFailureAction::Abort, true, false, false),
            (StepFailureAction::Continue, false, false, false),
        ];
        for (action, stop, replan, skip) in cases {
            assert_eq!(action.should_stop(), stop, "should_stop for {action:?}");
            assert_eq!(action.should_replan(), replan, "should_replan for {action:?}");
            assert_eq!(action.should_skip(), skip, "should_skip for {action:?}");
        }
    }

    // Guards the `rename_all = "snake_case"` spelling that config files rely on.
    #[test]
    fn test_step_failure_action_snake_case_names() {
        let cases = [
            ("replan", StepFailureAction::Replan),
            ("skip", StepFailureAction::Skip),
            ("abort", StepFailureAction::Abort),
            ("continue", StepFailureAction::Continue),
        ];
        for (text, expected) in cases {
            let parsed: StepFailureAction = serde_yaml::from_str(text).unwrap();
            assert_eq!(parsed, expected);
        }
    }

    #[test]
    fn test_plan_reflection_config_default() {
        let config = PlanReflectionConfig::default();
        assert!(config.enabled);
        assert_eq!(config.on_step_failure, StepFailureAction::Replan);
        assert_eq!(config.max_replans, 2);
    }
}