agentlib_reasoning/
cot.rs1use crate::utils::{call_model, execute_tool_calls, extract_text};
2use agentlib_core::{ModelMessage, ReasoningContext, ReasoningEngine, ReasoningStep, Role};
3use anyhow::{Result, anyhow};
4use async_trait::async_trait;
5
/// Reasoning engine that asks the model to reason step by step before answering.
///
/// When thinking tags are enabled, an instruction is added to the system prompt and the
/// model's `<thinking>…</thinking>` output is recorded as a reasoning step. Tool calls
/// requested by the model are executed for a bounded number of rounds.
#[derive(Debug, Clone)]
pub struct ChainOfThoughtEngine {
    /// Whether to inject the thinking instruction and capture `<thinking>` content.
    use_thinking_tags: bool,
    /// Number of tool-call rounds after which execution stops with an error.
    max_tool_steps: usize,
    /// Instruction text added to the system prompt when `use_thinking_tags` is set.
    thinking_instruction: String,
}
11
12impl ChainOfThoughtEngine {
13 pub fn new(
14 use_thinking_tags: bool,
15 max_tool_steps: usize,
16 thinking_instruction: Option<String>,
17 ) -> Self {
18 Self {
19 use_thinking_tags,
20 max_tool_steps,
21 thinking_instruction: thinking_instruction.unwrap_or_else(|| {
22 "Before answering, reason step by step inside <thinking> tags.\nWork through the problem carefully, considering all relevant information.\nThen provide your final answer outside the tags.".to_string()
23 }),
24 }
25 }
26
27 fn inject_instruction(&self, messages: Vec<ModelMessage>) -> Vec<ModelMessage> {
28 if !self.use_thinking_tags {
29 return messages;
30 }
31
32 let mut result = messages;
33 let system_idx = result.iter().position(|m| m.role == Role::System);
34
35 if let Some(idx) = system_idx {
36 let sys = &mut result[idx];
37 sys.content = format!("{}\n\n{}", sys.content, self.thinking_instruction);
38 } else {
39 result.insert(
40 0,
41 ModelMessage {
42 role: Role::System,
43 content: self.thinking_instruction.clone(),
44 tool_call_id: None,
45 tool_calls: None,
46 },
47 );
48 }
49
50 result
51 }
52
53 fn extract_thinking(&self, content: &str) -> Option<String> {
54 let re = regex::Regex::new(r"(?i)<thinking>([\s\S]*?)</thinking>").unwrap();
55 re.captures(content)
56 .map(|caps| caps.get(1).unwrap().as_str().trim().to_string())
57 }
58}
59
60impl Default for ChainOfThoughtEngine {
61 fn default() -> Self {
62 Self::new(true, 5, None)
63 }
64}
65
66#[async_trait]
67impl ReasoningEngine for ChainOfThoughtEngine {
68 fn name(&self) -> &str {
69 "cot"
70 }
71
72 async fn execute(&self, r_ctx: &mut ReasoningContext<'_>) -> Result<String> {
73 let messages = self.inject_instruction(r_ctx.ctx.messages.clone());
74
75 let response = call_model(r_ctx, messages).await?;
77 r_ctx.ctx.messages.push(response.message.clone());
78
79 if self.use_thinking_tags {
81 if let Some(thinking) = self.extract_thinking(&response.message.content) {
82 r_ctx.push_step(ReasoningStep::Thought {
83 content: thinking,
84 engine: self.name().to_string(),
85 });
86 }
87 }
88
89 if response
91 .message
92 .tool_calls
93 .as_ref()
94 .map_or(true, |tc| tc.is_empty())
95 {
96 let answer = extract_text(&response.message.content);
97 r_ctx.push_step(ReasoningStep::Response {
98 content: answer.clone(),
99 engine: self.name().to_string(),
100 });
101 return Ok(answer);
102 }
103
104 execute_tool_calls(r_ctx, &response).await?;
106 let mut tool_steps = 1;
107
108 while tool_steps < self.max_tool_steps {
109 let next = call_model(r_ctx, r_ctx.ctx.messages.clone()).await?;
110 r_ctx.ctx.messages.push(next.message.clone());
111
112 if next
113 .message
114 .tool_calls
115 .as_ref()
116 .map_or(true, |tc| tc.is_empty())
117 {
118 let answer = extract_text(&next.message.content);
119 r_ctx.push_step(ReasoningStep::Response {
120 content: answer.clone(),
121 engine: self.name().to_string(),
122 });
123 return Ok(answer);
124 }
125
126 execute_tool_calls(r_ctx, &next).await?;
127 tool_steps += 1;
128 }
129
130 Err(anyhow!(
131 "[ChainOfThoughtEngine] Max tool steps ({}) reached.",
132 self.max_tool_steps
133 ))
134 }
135}