ai-agents-reasoning 1.0.0-rc.15

Reasoning and reflection capabilities for AI Agents framework
Documentation
use serde::{Deserialize, Serialize};

use crate::mode::{ReasoningMode, ReasoningOutput, ReflectionMode};
use crate::planning::PlanningConfig;

#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ReasoningConfig {
    #[serde(default)]
    pub mode: ReasoningMode,

    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub judge_llm: Option<String>,

    #[serde(default)]
    pub output: ReasoningOutput,

    #[serde(default = "default_max_iterations")]
    pub max_iterations: u32,

    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub planning: Option<PlanningConfig>,
}

impl Default for ReasoningConfig {
    fn default() -> Self {
        Self {
            mode: ReasoningMode::None,
            judge_llm: None,
            output: ReasoningOutput::Hidden,
            max_iterations: default_max_iterations(),
            planning: None,
        }
    }
}

fn default_max_iterations() -> u32 {
    5
}

impl ReasoningConfig {
    pub fn new(mode: ReasoningMode) -> Self {
        Self {
            mode,
            ..Default::default()
        }
    }

    pub fn with_judge_llm(mut self, llm: impl Into<String>) -> Self {
        self.judge_llm = Some(llm.into());
        self
    }

    pub fn with_output(mut self, output: ReasoningOutput) -> Self {
        self.output = output;
        self
    }

    pub fn with_max_iterations(mut self, max: u32) -> Self {
        self.max_iterations = max;
        self
    }

    pub fn with_planning(mut self, planning: PlanningConfig) -> Self {
        self.planning = Some(planning);
        self
    }

    pub fn is_enabled(&self) -> bool {
        !matches!(self.mode, ReasoningMode::None)
    }

    pub fn needs_planning(&self) -> bool {
        self.mode.uses_planning()
    }

    pub fn get_planning(&self) -> Option<&PlanningConfig> {
        self.planning.as_ref()
    }
}

#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ReflectionConfig {
    #[serde(default)]
    pub enabled: ReflectionMode,

    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub evaluator_llm: Option<String>,

    #[serde(default = "default_max_retries")]
    pub max_retries: u32,

    #[serde(default = "default_criteria")]
    pub criteria: Vec<String>,

    #[serde(default = "default_pass_threshold")]
    pub pass_threshold: f32,
}

impl Default for ReflectionConfig {
    fn default() -> Self {
        Self {
            enabled: ReflectionMode::Disabled,
            evaluator_llm: None,
            max_retries: default_max_retries(),
            criteria: default_criteria(),
            pass_threshold: default_pass_threshold(),
        }
    }
}

fn default_max_retries() -> u32 {
    2
}

fn default_pass_threshold() -> f32 {
    0.7
}

fn default_criteria() -> Vec<String> {
    vec![
        "Response directly addresses the user's question".to_string(),
        "Response is complete and not cut off".to_string(),
        "Response is accurate and helpful".to_string(),
    ]
}

impl ReflectionConfig {
    pub fn new(enabled: ReflectionMode) -> Self {
        Self {
            enabled,
            ..Default::default()
        }
    }

    pub fn enabled() -> Self {
        Self::new(ReflectionMode::Enabled)
    }

    pub fn disabled() -> Self {
        Self::new(ReflectionMode::Disabled)
    }

    pub fn auto() -> Self {
        Self::new(ReflectionMode::Auto)
    }

    pub fn with_evaluator_llm(mut self, llm: impl Into<String>) -> Self {
        self.evaluator_llm = Some(llm.into());
        self
    }

    pub fn with_max_retries(mut self, max: u32) -> Self {
        self.max_retries = max;
        self
    }

    pub fn with_criteria(mut self, criteria: Vec<String>) -> Self {
        self.criteria = criteria;
        self
    }

    pub fn add_criterion(mut self, criterion: impl Into<String>) -> Self {
        self.criteria.push(criterion.into());
        self
    }

    pub fn with_pass_threshold(mut self, threshold: f32) -> Self {
        self.pass_threshold = threshold.clamp(0.0, 1.0);
        self
    }

    pub fn is_enabled(&self) -> bool {
        self.enabled.is_enabled()
    }

    pub fn is_auto(&self) -> bool {
        self.enabled.is_auto()
    }

    pub fn requires_evaluation(&self) -> bool {
        self.enabled.requires_evaluation()
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_reasoning_config_default() {
        let config = ReasoningConfig::default();
        assert_eq!(config.mode, ReasoningMode::None);
        assert!(config.judge_llm.is_none());
        assert_eq!(config.output, ReasoningOutput::Hidden);
        assert_eq!(config.max_iterations, 5);
        assert!(config.planning.is_none());
        assert!(!config.is_enabled());
    }

    #[test]
    fn test_reasoning_config_new() {
        let config = ReasoningConfig::new(ReasoningMode::CoT);
        assert_eq!(config.mode, ReasoningMode::CoT);
        assert!(config.is_enabled());
    }

    #[test]
    fn test_reasoning_config_builder() {
        let config = ReasoningConfig::new(ReasoningMode::PlanAndExecute)
            .with_judge_llm("router")
            .with_output(ReasoningOutput::Tagged)
            .with_max_iterations(10)
            .with_planning(PlanningConfig::default());

        assert_eq!(config.judge_llm, Some("router".to_string()));
        assert_eq!(config.output, ReasoningOutput::Tagged);
        assert_eq!(config.max_iterations, 10);
        assert!(config.planning.is_some());
        assert!(config.needs_planning());
    }

    #[test]
    fn test_reasoning_config_serde() {
        let yaml = r#"
mode: cot
judge_llm: router
output: tagged
max_iterations: 8
"#;
        let config: ReasoningConfig = serde_yaml::from_str(yaml).unwrap();
        assert_eq!(config.mode, ReasoningMode::CoT);
        assert_eq!(config.judge_llm, Some("router".to_string()));
        assert_eq!(config.output, ReasoningOutput::Tagged);
        assert_eq!(config.max_iterations, 8);
    }

    #[test]
    fn test_reasoning_config_serde_with_planning() {
        let yaml = r#"
mode: plan_and_execute
planning:
  planner_llm: router
  max_steps: 15
  available:
    tools: all
    skills:
      - skill1
      - skill2
"#;
        let config: ReasoningConfig = serde_yaml::from_str(yaml).unwrap();
        assert_eq!(config.mode, ReasoningMode::PlanAndExecute);
        assert!(config.needs_planning());
        let planning = config.get_planning().unwrap();
        assert_eq!(planning.planner_llm, Some("router".to_string()));
        assert_eq!(planning.max_steps, 15);
    }

    #[test]
    fn test_reflection_config_default() {
        let config = ReflectionConfig::default();
        assert_eq!(config.enabled, ReflectionMode::Disabled);
        assert!(config.evaluator_llm.is_none());
        assert_eq!(config.max_retries, 2);
        assert_eq!(config.criteria.len(), 3);
        assert_eq!(config.pass_threshold, 0.7);
        assert!(!config.is_enabled());
        assert!(!config.requires_evaluation());
    }

    #[test]
    fn test_reflection_config_constructors() {
        let enabled = ReflectionConfig::enabled();
        assert!(enabled.is_enabled());

        let disabled = ReflectionConfig::disabled();
        assert!(!disabled.is_enabled());

        let auto = ReflectionConfig::auto();
        assert!(auto.is_auto());
        assert!(auto.requires_evaluation());
    }

    #[test]
    fn test_reflection_config_builder() {
        let config = ReflectionConfig::enabled()
            .with_evaluator_llm("evaluator")
            .with_max_retries(5)
            .with_criteria(vec!["Criterion 1".to_string()])
            .add_criterion("Criterion 2")
            .with_pass_threshold(0.85);

        assert_eq!(config.evaluator_llm, Some("evaluator".to_string()));
        assert_eq!(config.max_retries, 5);
        assert_eq!(config.criteria.len(), 2);
        assert_eq!(config.pass_threshold, 0.85);
    }

    #[test]
    fn test_reflection_config_threshold_clamping() {
        let config = ReflectionConfig::default().with_pass_threshold(1.5);
        assert_eq!(config.pass_threshold, 1.0);

        let config = ReflectionConfig::default().with_pass_threshold(-0.5);
        assert_eq!(config.pass_threshold, 0.0);
    }

    #[test]
    fn test_reflection_config_serde() {
        let yaml = r#"
enabled: auto
evaluator_llm: router
max_retries: 3
criteria:
  - "Response is clear"
  - "Response is accurate"
pass_threshold: 0.8
"#;
        let config: ReflectionConfig = serde_yaml::from_str(yaml).unwrap();
        assert!(config.is_auto());
        assert_eq!(config.evaluator_llm, Some("router".to_string()));
        assert_eq!(config.max_retries, 3);
        assert_eq!(config.criteria.len(), 2);
        assert_eq!(config.pass_threshold, 0.8);
    }

    #[test]
    fn test_reflection_config_serde_enabled_alias() {
        let yaml = r#"enabled: true"#;
        let config: ReflectionConfig = serde_yaml::from_str(yaml).unwrap();
        assert!(config.is_enabled());

        let yaml = r#"enabled: false"#;
        let config: ReflectionConfig = serde_yaml::from_str(yaml).unwrap();
        assert!(!config.is_enabled());
    }
}