1use crate::utils::{call_model, execute_tool_calls, extract_text, parse_json};
2use agentlib_core::{
3 ModelMessage, ModelRequest, ReasoningContext, ReasoningEngine, ReasoningStep, Role,
4};
5use anyhow::{Result, anyhow};
6use async_trait::async_trait;
7use serde::{Deserialize, Serialize};
8
/// Parsed verdict returned by the critic model's JSON evaluation of an answer.
///
/// Deserialized from the model's raw reply in `ReflectEngine::critique`; the
/// expected shape is dictated by the critique prompt.
#[derive(Debug, Clone, Serialize, Deserialize)]
struct CritiqueResult {
    /// Quality score; the built-in prompt requests a 0-10 range.
    score: u32,
    /// Specific problems the critic found with the answer.
    issues: Vec<String>,
    /// One-sentence improvement suggestion from the critic.
    suggestion: String,
    /// Whether the critic believes the answer should be revised.
    needs_revision: bool,
}
16
/// Reasoning engine that generates an answer, has the model critique it, and
/// revises it until the critique passes or the reflection budget is exhausted.
pub struct ReflectEngine {
    /// Maximum critique/revise rounds performed after the initial answer.
    max_reflections: usize,
    /// Critique score at or above which the answer is accepted without revision.
    acceptance_threshold: u32,
    /// Cap on model/tool round-trips while generating the initial answer.
    max_answer_steps: usize,
    /// System prompt given to the critic model.
    critique_prompt: String,
}
23
24impl ReflectEngine {
25 pub fn new(
26 max_reflections: usize,
27 acceptance_threshold: u32,
28 max_answer_steps: usize,
29 critique_prompt: Option<String>,
30 ) -> Self {
31 Self {
32 max_reflections,
33 acceptance_threshold,
34 max_answer_steps,
35 critique_prompt: critique_prompt.unwrap_or_else(|| {
36 "You are a critical evaluator. Review the answer below and assess its quality.\n\nRespond in this exact JSON format (no markdown):\n{\n \"score\": <0-10>,\n \"issues\": [\"<issue 1>\", \"<issue 2>\"],\n \"suggestion\": \"<one-sentence improvement suggestion>\",\n \"needs_revision\": <true|false>\n}\n\nBe strict. Score 10 only for perfect answers. Score < 8 if the answer is incomplete, incorrect, or could be substantially improved.".to_string()
37 }),
38 }
39 }
40
41 async fn generate_answer(
42 &self,
43 r_ctx: &mut ReasoningContext<'_>,
44 messages: &mut Vec<ModelMessage>,
45 ) -> Result<String> {
46 let mut steps = 0;
47 while steps < self.max_answer_steps {
48 let response = call_model(r_ctx, messages.clone()).await?;
49 messages.push(response.message.clone());
50
51 let tool_calls = response.message.tool_calls.as_ref();
52 if tool_calls.is_none() || tool_calls.unwrap().is_empty() {
53 return Ok(extract_text(&response.message.content));
54 }
55
56 execute_tool_calls(r_ctx, &response).await?;
57 steps += 1;
58 }
59 Err(anyhow!(
60 "[ReflectEngine] Max answer steps ({}) reached.",
61 self.max_answer_steps
62 ))
63 }
64
65 async fn critique(
66 &self,
67 r_ctx: &mut ReasoningContext<'_>,
68 question: &str,
69 answer: &str,
70 ) -> Result<CritiqueResult> {
71 let critique_messages = vec![
72 ModelMessage {
73 role: Role::System,
74 content: self.critique_prompt.clone(),
75 tool_call_id: None,
76 tool_calls: None,
77 },
78 ModelMessage {
79 role: Role::User,
80 content: format!("Question:\n{}\n\nAnswer to evaluate:\n{}", question, answer),
81 tool_call_id: None,
82 tool_calls: None,
83 },
84 ];
85
86 let request = ModelRequest {
87 messages: critique_messages,
88 tools: None, };
90
91 let response = r_ctx.model.complete(request).await?;
92
93 if let Some(usage) = &response.usage {
95 r_ctx.ctx.usage.prompt_tokens += usage.prompt_tokens;
96 r_ctx.ctx.usage.completion_tokens += usage.completion_tokens;
97 r_ctx.ctx.usage.total_tokens += usage.total_tokens;
98 }
99
100 match parse_json::<CritiqueResult>(&response.message.content) {
101 Ok(res) => Ok(res),
102 Err(_) => Ok(CritiqueResult {
103 score: 9,
104 issues: vec![],
105 suggestion: "".to_string(),
106 needs_revision: false,
107 }),
108 }
109 }
110
111 async fn revise(
112 &self,
113 r_ctx: &mut ReasoningContext<'_>,
114 question: &str,
115 current_answer: &str,
116 critique: &CritiqueResult,
117 ) -> Result<String> {
118 let revision_messages = vec![
119 ModelMessage {
120 role: Role::System,
121 content: "You are revising your previous answer based on critique. Produce an improved, complete answer.".to_string(),
122 tool_call_id: None,
123 tool_calls: None,
124 },
125 ModelMessage {
126 role: Role::User,
127 content: format!("Original question: {}", question),
128 tool_call_id: None,
129 tool_calls: None,
130 },
131 ModelMessage {
132 role: Role::Assistant,
133 content: current_answer.to_string(),
134 tool_call_id: None,
135 tool_calls: None,
136 },
137 ModelMessage {
138 role: Role::User,
139 content: format!(
140 "Critique of your answer:\n- Score: {}/10\n- Issues: {}\n- Suggestion: {}\n\nPlease revise your answer to address these issues.",
141 critique.score,
142 critique.issues.join(", "),
143 critique.suggestion
144 ),
145 tool_call_id: None,
146 tool_calls: None,
147 }
148 ];
149
150 let response = call_model(r_ctx, revision_messages).await?;
151
152 if let Some(tool_calls) = &response.message.tool_calls {
153 if !tool_calls.is_empty() {
154 r_ctx.ctx.messages.push(response.message.clone());
155 execute_tool_calls(r_ctx, &response).await?;
156 let request = ModelRequest {
158 messages: r_ctx.ctx.messages.clone(),
159 tools: None,
160 };
161 let final_res = r_ctx.model.complete(request).await?;
162 return Ok(extract_text(&final_res.message.content));
163 }
164 }
165
166 Ok(extract_text(&response.message.content))
167 }
168}
169
170impl Default for ReflectEngine {
171 fn default() -> Self {
172 Self::new(2, 8, 5, None)
173 }
174}
175
176#[async_trait]
177impl ReasoningEngine for ReflectEngine {
178 fn name(&self) -> &str {
179 "reflect"
180 }
181
182 async fn execute(&self, r_ctx: &mut ReasoningContext<'_>) -> Result<String> {
183 let mut messages = r_ctx.ctx.messages.clone();
184 let mut answer = self.generate_answer(r_ctx, &mut messages).await?;
185 r_ctx.push_step(ReasoningStep::Thought {
186 content: "Initial answer generated.".to_string(),
187 engine: self.name().to_string(),
188 });
189
190 for i in 0..self.max_reflections {
191 let question = r_ctx.ctx.input.clone();
192 let critique = self.critique(r_ctx, &question, &answer).await?;
193
194 r_ctx.push_step(ReasoningStep::Reflection {
195 assessment: format!(
196 "Score: {}/10. Issues: {}. {}",
197 critique.score,
198 critique.issues.join("; "),
199 critique.suggestion
200 ),
201 needs_revision: critique.needs_revision,
202 engine: self.name().to_string(),
203 });
204
205 if !critique.needs_revision || critique.score >= self.acceptance_threshold {
206 break;
207 }
208
209 r_ctx.push_step(ReasoningStep::Thought {
210 content: format!(
211 "Revising answer (attempt {}/{})...",
212 i + 1,
213 self.max_reflections
214 ),
215 engine: self.name().to_string(),
216 });
217
218 answer = self.revise(r_ctx, &question, &answer, &critique).await?;
219 }
220
221 r_ctx.push_step(ReasoningStep::Response {
222 content: answer.clone(),
223 engine: self.name().to_string(),
224 });
225
226 Ok(answer)
227 }
228}