ai_agents_reasoning/
config.rs1use serde::{Deserialize, Serialize};
2
3use crate::mode::{ReasoningMode, ReasoningOutput, ReflectionMode};
4use crate::planning::PlanningConfig;
5
6#[derive(Debug, Clone, Serialize, Deserialize)]
7pub struct ReasoningConfig {
8 #[serde(default)]
9 pub mode: ReasoningMode,
10
11 #[serde(default, skip_serializing_if = "Option::is_none")]
12 pub judge_llm: Option<String>,
13
14 #[serde(default)]
15 pub output: ReasoningOutput,
16
17 #[serde(default = "default_max_iterations")]
18 pub max_iterations: u32,
19
20 #[serde(default, skip_serializing_if = "Option::is_none")]
21 pub planning: Option<PlanningConfig>,
22}
23
24impl Default for ReasoningConfig {
25 fn default() -> Self {
26 Self {
27 mode: ReasoningMode::None,
28 judge_llm: None,
29 output: ReasoningOutput::Hidden,
30 max_iterations: default_max_iterations(),
31 planning: None,
32 }
33 }
34}
35
/// Serde default for `ReasoningConfig::max_iterations`.
fn default_max_iterations() -> u32 {
    const MAX_ITERATIONS: u32 = 5;
    MAX_ITERATIONS
}
39
40impl ReasoningConfig {
41 pub fn new(mode: ReasoningMode) -> Self {
42 Self {
43 mode,
44 ..Default::default()
45 }
46 }
47
48 pub fn with_judge_llm(mut self, llm: impl Into<String>) -> Self {
49 self.judge_llm = Some(llm.into());
50 self
51 }
52
53 pub fn with_output(mut self, output: ReasoningOutput) -> Self {
54 self.output = output;
55 self
56 }
57
58 pub fn with_max_iterations(mut self, max: u32) -> Self {
59 self.max_iterations = max;
60 self
61 }
62
63 pub fn with_planning(mut self, planning: PlanningConfig) -> Self {
64 self.planning = Some(planning);
65 self
66 }
67
68 pub fn is_enabled(&self) -> bool {
69 !matches!(self.mode, ReasoningMode::None)
70 }
71
72 pub fn needs_planning(&self) -> bool {
73 self.mode.uses_planning()
74 }
75
76 pub fn get_planning(&self) -> Option<&PlanningConfig> {
77 self.planning.as_ref()
78 }
79}
80
81#[derive(Debug, Clone, Serialize, Deserialize)]
82pub struct ReflectionConfig {
83 #[serde(default)]
84 pub enabled: ReflectionMode,
85
86 #[serde(default, skip_serializing_if = "Option::is_none")]
87 pub evaluator_llm: Option<String>,
88
89 #[serde(default = "default_max_retries")]
90 pub max_retries: u32,
91
92 #[serde(default = "default_criteria")]
93 pub criteria: Vec<String>,
94
95 #[serde(default = "default_pass_threshold")]
96 pub pass_threshold: f32,
97}
98
99impl Default for ReflectionConfig {
100 fn default() -> Self {
101 Self {
102 enabled: ReflectionMode::Disabled,
103 evaluator_llm: None,
104 max_retries: default_max_retries(),
105 criteria: default_criteria(),
106 pass_threshold: default_pass_threshold(),
107 }
108 }
109}
110
/// Serde default for `ReflectionConfig::max_retries`.
fn default_max_retries() -> u32 {
    const MAX_RETRIES: u32 = 2;
    MAX_RETRIES
}
114
/// Serde default for `ReflectionConfig::pass_threshold`.
fn default_pass_threshold() -> f32 {
    const PASS_THRESHOLD: f32 = 0.7;
    PASS_THRESHOLD
}
118
/// Serde default for `ReflectionConfig::criteria`: three generic checks on
/// the response.
fn default_criteria() -> Vec<String> {
    const CRITERIA: [&str; 3] = [
        "Response directly addresses the user's question",
        "Response is complete and not cut off",
        "Response is accurate and helpful",
    ];
    CRITERIA.iter().copied().map(String::from).collect()
}
126
127impl ReflectionConfig {
128 pub fn new(enabled: ReflectionMode) -> Self {
129 Self {
130 enabled,
131 ..Default::default()
132 }
133 }
134
135 pub fn enabled() -> Self {
136 Self::new(ReflectionMode::Enabled)
137 }
138
139 pub fn disabled() -> Self {
140 Self::new(ReflectionMode::Disabled)
141 }
142
143 pub fn auto() -> Self {
144 Self::new(ReflectionMode::Auto)
145 }
146
147 pub fn with_evaluator_llm(mut self, llm: impl Into<String>) -> Self {
148 self.evaluator_llm = Some(llm.into());
149 self
150 }
151
152 pub fn with_max_retries(mut self, max: u32) -> Self {
153 self.max_retries = max;
154 self
155 }
156
157 pub fn with_criteria(mut self, criteria: Vec<String>) -> Self {
158 self.criteria = criteria;
159 self
160 }
161
162 pub fn add_criterion(mut self, criterion: impl Into<String>) -> Self {
163 self.criteria.push(criterion.into());
164 self
165 }
166
167 pub fn with_pass_threshold(mut self, threshold: f32) -> Self {
168 self.pass_threshold = threshold.clamp(0.0, 1.0);
169 self
170 }
171
172 pub fn is_enabled(&self) -> bool {
173 self.enabled.is_enabled()
174 }
175
176 pub fn is_auto(&self) -> bool {
177 self.enabled.is_auto()
178 }
179
180 pub fn requires_evaluation(&self) -> bool {
181 self.enabled.requires_evaluation()
182 }
183}
184
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_reasoning_config_default() {
        let cfg = ReasoningConfig::default();

        assert!(!cfg.is_enabled());
        assert_eq!(cfg.mode, ReasoningMode::None);
        assert_eq!(cfg.output, ReasoningOutput::Hidden);
        assert_eq!(cfg.max_iterations, 5);
        assert_eq!(cfg.judge_llm, None);
        assert!(cfg.planning.is_none());
    }

    #[test]
    fn test_reasoning_config_new() {
        let cfg = ReasoningConfig::new(ReasoningMode::CoT);

        assert!(cfg.is_enabled());
        assert_eq!(cfg.mode, ReasoningMode::CoT);
    }

    #[test]
    fn test_reasoning_config_builder() {
        let cfg = ReasoningConfig::new(ReasoningMode::PlanAndExecute)
            .with_judge_llm("router")
            .with_output(ReasoningOutput::Tagged)
            .with_max_iterations(10)
            .with_planning(PlanningConfig::default());

        assert_eq!(cfg.judge_llm.as_deref(), Some("router"));
        assert_eq!(cfg.output, ReasoningOutput::Tagged);
        assert_eq!(cfg.max_iterations, 10);
        assert!(cfg.planning.is_some());
        assert!(cfg.needs_planning());
    }

    #[test]
    fn test_reasoning_config_serde() {
        let yaml = r#"
mode: cot
judge_llm: router
output: tagged
max_iterations: 8
"#;
        let cfg = serde_yaml::from_str::<ReasoningConfig>(yaml).unwrap();

        assert_eq!(cfg.mode, ReasoningMode::CoT);
        assert_eq!(cfg.judge_llm.as_deref(), Some("router"));
        assert_eq!(cfg.output, ReasoningOutput::Tagged);
        assert_eq!(cfg.max_iterations, 8);
    }

    #[test]
    fn test_reasoning_config_serde_with_planning() {
        let yaml = r#"
mode: plan_and_execute
planning:
  planner_llm: router
  max_steps: 15
  available:
    tools: all
    skills:
      - skill1
      - skill2
"#;
        let cfg = serde_yaml::from_str::<ReasoningConfig>(yaml).unwrap();

        assert_eq!(cfg.mode, ReasoningMode::PlanAndExecute);
        assert!(cfg.needs_planning());

        let planning = cfg.get_planning().unwrap();
        assert_eq!(planning.planner_llm.as_deref(), Some("router"));
        assert_eq!(planning.max_steps, 15);
    }

    #[test]
    fn test_reflection_config_default() {
        let cfg = ReflectionConfig::default();

        assert!(!cfg.is_enabled());
        assert!(!cfg.requires_evaluation());
        assert_eq!(cfg.enabled, ReflectionMode::Disabled);
        assert_eq!(cfg.evaluator_llm, None);
        assert_eq!(cfg.max_retries, 2);
        assert_eq!(cfg.criteria.len(), 3);
        assert_eq!(cfg.pass_threshold, 0.7);
    }

    #[test]
    fn test_reflection_config_constructors() {
        assert!(ReflectionConfig::enabled().is_enabled());
        assert!(!ReflectionConfig::disabled().is_enabled());

        let auto = ReflectionConfig::auto();
        assert!(auto.is_auto());
        assert!(auto.requires_evaluation());
    }

    #[test]
    fn test_reflection_config_builder() {
        let cfg = ReflectionConfig::enabled()
            .with_evaluator_llm("evaluator")
            .with_max_retries(5)
            .with_criteria(vec![String::from("Criterion 1")])
            .add_criterion("Criterion 2")
            .with_pass_threshold(0.85);

        assert_eq!(cfg.evaluator_llm.as_deref(), Some("evaluator"));
        assert_eq!(cfg.max_retries, 5);
        assert_eq!(cfg.criteria.len(), 2);
        assert_eq!(cfg.pass_threshold, 0.85);
    }

    #[test]
    fn test_reflection_config_threshold_clamping() {
        for (input, expected) in [(1.5_f32, 1.0_f32), (-0.5_f32, 0.0_f32)] {
            let cfg = ReflectionConfig::default().with_pass_threshold(input);
            assert_eq!(cfg.pass_threshold, expected);
        }
    }

    #[test]
    fn test_reflection_config_serde() {
        let yaml = r#"
enabled: auto
evaluator_llm: router
max_retries: 3
criteria:
  - "Response is clear"
  - "Response is accurate"
pass_threshold: 0.8
"#;
        let cfg = serde_yaml::from_str::<ReflectionConfig>(yaml).unwrap();

        assert!(cfg.is_auto());
        assert_eq!(cfg.evaluator_llm.as_deref(), Some("router"));
        assert_eq!(cfg.max_retries, 3);
        assert_eq!(cfg.criteria.len(), 2);
        assert_eq!(cfg.pass_threshold, 0.8);
    }

    #[test]
    fn test_reflection_config_serde_enabled_alias() {
        let on = serde_yaml::from_str::<ReflectionConfig>(r#"enabled: true"#).unwrap();
        assert!(on.is_enabled());

        let off = serde_yaml::from_str::<ReflectionConfig>(r#"enabled: false"#).unwrap();
        assert!(!off.is_enabled());
    }
}
336}