Skip to main content

agentlib_reasoning/
cot.rs

1use crate::utils::{call_model, execute_tool_calls, extract_text};
2use agentlib_core::{ModelMessage, ReasoningContext, ReasoningEngine, ReasoningStep, Role};
3use anyhow::{Result, anyhow};
4use async_trait::async_trait;
5
/// Reasoning engine that asks the model to think step by step inside
/// `<thinking>` tags, optionally running tool calls between model turns.
#[derive(Debug, Clone)]
pub struct ChainOfThoughtEngine {
    /// When `true`, the thinking instruction is added to the prompt and
    /// `<thinking>` sections are extracted from model replies.
    use_thinking_tags: bool,
    /// Tool-round budget for a single `execute` call.
    max_tool_steps: usize,
    /// Instruction text added to the system prompt when `use_thinking_tags`
    /// is set.
    thinking_instruction: String,
}
11
12impl ChainOfThoughtEngine {
13    pub fn new(
14        use_thinking_tags: bool,
15        max_tool_steps: usize,
16        thinking_instruction: Option<String>,
17    ) -> Self {
18        Self {
19            use_thinking_tags,
20            max_tool_steps,
21            thinking_instruction: thinking_instruction.unwrap_or_else(|| {
22                "Before answering, reason step by step inside <thinking> tags.\nWork through the problem carefully, considering all relevant information.\nThen provide your final answer outside the tags.".to_string()
23            }),
24        }
25    }
26
27    fn inject_instruction(&self, messages: Vec<ModelMessage>) -> Vec<ModelMessage> {
28        if !self.use_thinking_tags {
29            return messages;
30        }
31
32        let mut result = messages;
33        let system_idx = result.iter().position(|m| m.role == Role::System);
34
35        if let Some(idx) = system_idx {
36            let sys = &mut result[idx];
37            sys.content = format!("{}\n\n{}", sys.content, self.thinking_instruction);
38        } else {
39            result.insert(
40                0,
41                ModelMessage {
42                    role: Role::System,
43                    content: self.thinking_instruction.clone(),
44                    tool_call_id: None,
45                    tool_calls: None,
46                },
47            );
48        }
49
50        result
51    }
52
53    fn extract_thinking(&self, content: &str) -> Option<String> {
54        let re = regex::Regex::new(r"(?i)<thinking>([\s\S]*?)</thinking>").unwrap();
55        re.captures(content)
56            .map(|caps| caps.get(1).unwrap().as_str().trim().to_string())
57    }
58}
59
60impl Default for ChainOfThoughtEngine {
61    fn default() -> Self {
62        Self::new(true, 5, None)
63    }
64}
65
66#[async_trait]
67impl ReasoningEngine for ChainOfThoughtEngine {
68    fn name(&self) -> &str {
69        "cot"
70    }
71
72    async fn execute(&self, r_ctx: &mut ReasoningContext<'_>) -> Result<String> {
73        let messages = self.inject_instruction(r_ctx.ctx.messages.clone());
74
75        // Pass 1: Initial reasoning + possible tool call
76        let response = call_model(r_ctx, messages).await?;
77        r_ctx.ctx.messages.push(response.message.clone());
78
79        // Extract and emit the thinking portion
80        if self.use_thinking_tags {
81            if let Some(thinking) = self.extract_thinking(&response.message.content) {
82                r_ctx.push_step(ReasoningStep::Thought {
83                    content: thinking,
84                    engine: self.name().to_string(),
85                });
86            }
87        }
88
89        // No tool calls -> extract clean answer
90        if response
91            .message
92            .tool_calls
93            .as_ref()
94            .map_or(true, |tc| tc.is_empty())
95        {
96            let answer = extract_text(&response.message.content);
97            r_ctx.push_step(ReasoningStep::Response {
98                content: answer.clone(),
99                engine: self.name().to_string(),
100            });
101            return Ok(answer);
102        }
103
104        // Tool execution loop
105        execute_tool_calls(r_ctx, &response).await?;
106        let mut tool_steps = 1;
107
108        while tool_steps < self.max_tool_steps {
109            let next = call_model(r_ctx, r_ctx.ctx.messages.clone()).await?;
110            r_ctx.ctx.messages.push(next.message.clone());
111
112            if next
113                .message
114                .tool_calls
115                .as_ref()
116                .map_or(true, |tc| tc.is_empty())
117            {
118                let answer = extract_text(&next.message.content);
119                r_ctx.push_step(ReasoningStep::Response {
120                    content: answer.clone(),
121                    engine: self.name().to_string(),
122                });
123                return Ok(answer);
124            }
125
126            execute_tool_calls(r_ctx, &next).await?;
127            tool_steps += 1;
128        }
129
130        Err(anyhow!(
131            "[ChainOfThoughtEngine] Max tool steps ({}) reached.",
132            self.max_tool_steps
133        ))
134    }
135}