Skip to main content

ai_agents_reasoning/
config.rs

1use serde::{Deserialize, Serialize};
2
3use crate::mode::{ReasoningMode, ReasoningOutput, ReflectionMode};
4use crate::planning::PlanningConfig;
5
6#[derive(Debug, Clone, Serialize, Deserialize)]
7pub struct ReasoningConfig {
8    #[serde(default)]
9    pub mode: ReasoningMode,
10
11    #[serde(default, skip_serializing_if = "Option::is_none")]
12    pub judge_llm: Option<String>,
13
14    #[serde(default)]
15    pub output: ReasoningOutput,
16
17    #[serde(default = "default_max_iterations")]
18    pub max_iterations: u32,
19
20    #[serde(default, skip_serializing_if = "Option::is_none")]
21    pub planning: Option<PlanningConfig>,
22}
23
24impl Default for ReasoningConfig {
25    fn default() -> Self {
26        Self {
27            mode: ReasoningMode::None,
28            judge_llm: None,
29            output: ReasoningOutput::Hidden,
30            max_iterations: default_max_iterations(),
31            planning: None,
32        }
33    }
34}
35
/// Fallback for `ReasoningConfig::max_iterations`.
/// Used both by serde when the field is missing and by `Default`.
fn default_max_iterations() -> u32 {
    const DEFAULT_MAX_ITERATIONS: u32 = 5;
    DEFAULT_MAX_ITERATIONS
}
39
40impl ReasoningConfig {
41    pub fn new(mode: ReasoningMode) -> Self {
42        Self {
43            mode,
44            ..Default::default()
45        }
46    }
47
48    pub fn with_judge_llm(mut self, llm: impl Into<String>) -> Self {
49        self.judge_llm = Some(llm.into());
50        self
51    }
52
53    pub fn with_output(mut self, output: ReasoningOutput) -> Self {
54        self.output = output;
55        self
56    }
57
58    pub fn with_max_iterations(mut self, max: u32) -> Self {
59        self.max_iterations = max;
60        self
61    }
62
63    pub fn with_planning(mut self, planning: PlanningConfig) -> Self {
64        self.planning = Some(planning);
65        self
66    }
67
68    pub fn is_enabled(&self) -> bool {
69        !matches!(self.mode, ReasoningMode::None)
70    }
71
72    pub fn needs_planning(&self) -> bool {
73        self.mode.uses_planning()
74    }
75
76    pub fn get_planning(&self) -> Option<&PlanningConfig> {
77        self.planning.as_ref()
78    }
79}
80
81#[derive(Debug, Clone, Serialize, Deserialize)]
82pub struct ReflectionConfig {
83    #[serde(default)]
84    pub enabled: ReflectionMode,
85
86    #[serde(default, skip_serializing_if = "Option::is_none")]
87    pub evaluator_llm: Option<String>,
88
89    #[serde(default = "default_max_retries")]
90    pub max_retries: u32,
91
92    #[serde(default = "default_criteria")]
93    pub criteria: Vec<String>,
94
95    #[serde(default = "default_pass_threshold")]
96    pub pass_threshold: f32,
97}
98
99impl Default for ReflectionConfig {
100    fn default() -> Self {
101        Self {
102            enabled: ReflectionMode::Disabled,
103            evaluator_llm: None,
104            max_retries: default_max_retries(),
105            criteria: default_criteria(),
106            pass_threshold: default_pass_threshold(),
107        }
108    }
109}
110
111fn default_max_retries() -> u32 {
112    2
113}
114
115fn default_pass_threshold() -> f32 {
116    0.7
117}
118
119fn default_criteria() -> Vec<String> {
120    vec![
121        "Response directly addresses the user's question".to_string(),
122        "Response is complete and not cut off".to_string(),
123        "Response is accurate and helpful".to_string(),
124    ]
125}
126
127impl ReflectionConfig {
128    pub fn new(enabled: ReflectionMode) -> Self {
129        Self {
130            enabled,
131            ..Default::default()
132        }
133    }
134
135    pub fn enabled() -> Self {
136        Self::new(ReflectionMode::Enabled)
137    }
138
139    pub fn disabled() -> Self {
140        Self::new(ReflectionMode::Disabled)
141    }
142
143    pub fn auto() -> Self {
144        Self::new(ReflectionMode::Auto)
145    }
146
147    pub fn with_evaluator_llm(mut self, llm: impl Into<String>) -> Self {
148        self.evaluator_llm = Some(llm.into());
149        self
150    }
151
152    pub fn with_max_retries(mut self, max: u32) -> Self {
153        self.max_retries = max;
154        self
155    }
156
157    pub fn with_criteria(mut self, criteria: Vec<String>) -> Self {
158        self.criteria = criteria;
159        self
160    }
161
162    pub fn add_criterion(mut self, criterion: impl Into<String>) -> Self {
163        self.criteria.push(criterion.into());
164        self
165    }
166
167    pub fn with_pass_threshold(mut self, threshold: f32) -> Self {
168        self.pass_threshold = threshold.clamp(0.0, 1.0);
169        self
170    }
171
172    pub fn is_enabled(&self) -> bool {
173        self.enabled.is_enabled()
174    }
175
176    pub fn is_auto(&self) -> bool {
177        self.enabled.is_auto()
178    }
179
180    pub fn requires_evaluation(&self) -> bool {
181        self.enabled.requires_evaluation()
182    }
183}
184
185#[cfg(test)]
186mod tests {
187    use super::*;
188
189    #[test]
190    fn test_reasoning_config_default() {
191        let config = ReasoningConfig::default();
192        assert_eq!(config.mode, ReasoningMode::None);
193        assert!(config.judge_llm.is_none());
194        assert_eq!(config.output, ReasoningOutput::Hidden);
195        assert_eq!(config.max_iterations, 5);
196        assert!(config.planning.is_none());
197        assert!(!config.is_enabled());
198    }
199
200    #[test]
201    fn test_reasoning_config_new() {
202        let config = ReasoningConfig::new(ReasoningMode::CoT);
203        assert_eq!(config.mode, ReasoningMode::CoT);
204        assert!(config.is_enabled());
205    }
206
207    #[test]
208    fn test_reasoning_config_builder() {
209        let config = ReasoningConfig::new(ReasoningMode::PlanAndExecute)
210            .with_judge_llm("router")
211            .with_output(ReasoningOutput::Tagged)
212            .with_max_iterations(10)
213            .with_planning(PlanningConfig::default());
214
215        assert_eq!(config.judge_llm, Some("router".to_string()));
216        assert_eq!(config.output, ReasoningOutput::Tagged);
217        assert_eq!(config.max_iterations, 10);
218        assert!(config.planning.is_some());
219        assert!(config.needs_planning());
220    }
221
222    #[test]
223    fn test_reasoning_config_serde() {
224        let yaml = r#"
225mode: cot
226judge_llm: router
227output: tagged
228max_iterations: 8
229"#;
230        let config: ReasoningConfig = serde_yaml::from_str(yaml).unwrap();
231        assert_eq!(config.mode, ReasoningMode::CoT);
232        assert_eq!(config.judge_llm, Some("router".to_string()));
233        assert_eq!(config.output, ReasoningOutput::Tagged);
234        assert_eq!(config.max_iterations, 8);
235    }
236
237    #[test]
238    fn test_reasoning_config_serde_with_planning() {
239        let yaml = r#"
240mode: plan_and_execute
241planning:
242  planner_llm: router
243  max_steps: 15
244  available:
245    tools: all
246    skills:
247      - skill1
248      - skill2
249"#;
250        let config: ReasoningConfig = serde_yaml::from_str(yaml).unwrap();
251        assert_eq!(config.mode, ReasoningMode::PlanAndExecute);
252        assert!(config.needs_planning());
253        let planning = config.get_planning().unwrap();
254        assert_eq!(planning.planner_llm, Some("router".to_string()));
255        assert_eq!(planning.max_steps, 15);
256    }
257
258    #[test]
259    fn test_reflection_config_default() {
260        let config = ReflectionConfig::default();
261        assert_eq!(config.enabled, ReflectionMode::Disabled);
262        assert!(config.evaluator_llm.is_none());
263        assert_eq!(config.max_retries, 2);
264        assert_eq!(config.criteria.len(), 3);
265        assert_eq!(config.pass_threshold, 0.7);
266        assert!(!config.is_enabled());
267        assert!(!config.requires_evaluation());
268    }
269
270    #[test]
271    fn test_reflection_config_constructors() {
272        let enabled = ReflectionConfig::enabled();
273        assert!(enabled.is_enabled());
274
275        let disabled = ReflectionConfig::disabled();
276        assert!(!disabled.is_enabled());
277
278        let auto = ReflectionConfig::auto();
279        assert!(auto.is_auto());
280        assert!(auto.requires_evaluation());
281    }
282
283    #[test]
284    fn test_reflection_config_builder() {
285        let config = ReflectionConfig::enabled()
286            .with_evaluator_llm("evaluator")
287            .with_max_retries(5)
288            .with_criteria(vec!["Criterion 1".to_string()])
289            .add_criterion("Criterion 2")
290            .with_pass_threshold(0.85);
291
292        assert_eq!(config.evaluator_llm, Some("evaluator".to_string()));
293        assert_eq!(config.max_retries, 5);
294        assert_eq!(config.criteria.len(), 2);
295        assert_eq!(config.pass_threshold, 0.85);
296    }
297
298    #[test]
299    fn test_reflection_config_threshold_clamping() {
300        let config = ReflectionConfig::default().with_pass_threshold(1.5);
301        assert_eq!(config.pass_threshold, 1.0);
302
303        let config = ReflectionConfig::default().with_pass_threshold(-0.5);
304        assert_eq!(config.pass_threshold, 0.0);
305    }
306
307    #[test]
308    fn test_reflection_config_serde() {
309        let yaml = r#"
310enabled: auto
311evaluator_llm: router
312max_retries: 3
313criteria:
314  - "Response is clear"
315  - "Response is accurate"
316pass_threshold: 0.8
317"#;
318        let config: ReflectionConfig = serde_yaml::from_str(yaml).unwrap();
319        assert!(config.is_auto());
320        assert_eq!(config.evaluator_llm, Some("router".to_string()));
321        assert_eq!(config.max_retries, 3);
322        assert_eq!(config.criteria.len(), 2);
323        assert_eq!(config.pass_threshold, 0.8);
324    }
325
326    #[test]
327    fn test_reflection_config_serde_enabled_alias() {
328        let yaml = r#"enabled: true"#;
329        let config: ReflectionConfig = serde_yaml::from_str(yaml).unwrap();
330        assert!(config.is_enabled());
331
332        let yaml = r#"enabled: false"#;
333        let config: ReflectionConfig = serde_yaml::from_str(yaml).unwrap();
334        assert!(!config.is_enabled());
335    }
336}