Skip to main content

agentlib_reasoning/
reflect.rs

1use crate::utils::{call_model, execute_tool_calls, extract_text, parse_json};
2use agentlib_core::{
3    ModelMessage, ModelRequest, ReasoningContext, ReasoningEngine, ReasoningStep, Role,
4};
5use anyhow::{Result, anyhow};
6use async_trait::async_trait;
7use serde::{Deserialize, Serialize};
8
9#[derive(Debug, Clone, Serialize, Deserialize)]
10struct CritiqueResult {
11    score: u32,
12    issues: Vec<String>,
13    suggestion: String,
14    needs_revision: bool,
15}
16
/// Reasoning engine that drafts an answer, has the model critique it, and revises the
/// draft until the critique accepts it or the reflection budget is used up.
#[derive(Debug, Clone)]
pub struct ReflectEngine {
    /// Maximum number of critique rounds run after the initial answer.
    max_reflections: usize,
    /// Critique score at or above which the answer is accepted without revision.
    acceptance_threshold: u32,
    /// Maximum number of model calls (each possibly followed by tool execution)
    /// allowed while producing the initial answer.
    max_answer_steps: usize,
    /// System prompt sent to the model for the critique pass.
    critique_prompt: String,
}
23
24impl ReflectEngine {
25    pub fn new(
26        max_reflections: usize,
27        acceptance_threshold: u32,
28        max_answer_steps: usize,
29        critique_prompt: Option<String>,
30    ) -> Self {
31        Self {
32            max_reflections,
33            acceptance_threshold,
34            max_answer_steps,
35            critique_prompt: critique_prompt.unwrap_or_else(|| {
36                "You are a critical evaluator. Review the answer below and assess its quality.\n\nRespond in this exact JSON format (no markdown):\n{\n  \"score\": <0-10>,\n  \"issues\": [\"<issue 1>\", \"<issue 2>\"],\n  \"suggestion\": \"<one-sentence improvement suggestion>\",\n  \"needs_revision\": <true|false>\n}\n\nBe strict. Score 10 only for perfect answers. Score < 8 if the answer is incomplete, incorrect, or could be substantially improved.".to_string()
37            }),
38        }
39    }
40
41    async fn generate_answer(
42        &self,
43        r_ctx: &mut ReasoningContext<'_>,
44        messages: &mut Vec<ModelMessage>,
45    ) -> Result<String> {
46        let mut steps = 0;
47        while steps < self.max_answer_steps {
48            let response = call_model(r_ctx, messages.clone()).await?;
49            messages.push(response.message.clone());
50
51            let tool_calls = response.message.tool_calls.as_ref();
52            if tool_calls.is_none() || tool_calls.unwrap().is_empty() {
53                return Ok(extract_text(&response.message.content));
54            }
55
56            execute_tool_calls(r_ctx, &response).await?;
57            steps += 1;
58        }
59        Err(anyhow!(
60            "[ReflectEngine] Max answer steps ({}) reached.",
61            self.max_answer_steps
62        ))
63    }
64
65    async fn critique(
66        &self,
67        r_ctx: &mut ReasoningContext<'_>,
68        question: &str,
69        answer: &str,
70    ) -> Result<CritiqueResult> {
71        let critique_messages = vec![
72            ModelMessage {
73                role: Role::System,
74                content: self.critique_prompt.clone(),
75                tool_call_id: None,
76                tool_calls: None,
77            },
78            ModelMessage {
79                role: Role::User,
80                content: format!("Question:\n{}\n\nAnswer to evaluate:\n{}", question, answer),
81                tool_call_id: None,
82                tool_calls: None,
83            },
84        ];
85
86        let request = ModelRequest {
87            messages: critique_messages,
88            tools: None, // no tools for critique
89        };
90
91        let response = r_ctx.model.complete(request).await?;
92
93        // Accumulate usage
94        if let Some(usage) = &response.usage {
95            r_ctx.ctx.usage.prompt_tokens += usage.prompt_tokens;
96            r_ctx.ctx.usage.completion_tokens += usage.completion_tokens;
97            r_ctx.ctx.usage.total_tokens += usage.total_tokens;
98        }
99
100        match parse_json::<CritiqueResult>(&response.message.content) {
101            Ok(res) => Ok(res),
102            Err(_) => Ok(CritiqueResult {
103                score: 9,
104                issues: vec![],
105                suggestion: "".to_string(),
106                needs_revision: false,
107            }),
108        }
109    }
110
111    async fn revise(
112        &self,
113        r_ctx: &mut ReasoningContext<'_>,
114        question: &str,
115        current_answer: &str,
116        critique: &CritiqueResult,
117    ) -> Result<String> {
118        let revision_messages = vec![
119            ModelMessage {
120                role: Role::System,
121                content: "You are revising your previous answer based on critique. Produce an improved, complete answer.".to_string(),
122                tool_call_id: None,
123                tool_calls: None,
124            },
125            ModelMessage {
126                role: Role::User,
127                content: format!("Original question: {}", question),
128                tool_call_id: None,
129                tool_calls: None,
130            },
131            ModelMessage {
132                role: Role::Assistant,
133                content: current_answer.to_string(),
134                tool_call_id: None,
135                tool_calls: None,
136            },
137            ModelMessage {
138                role: Role::User,
139                content: format!(
140                    "Critique of your answer:\n- Score: {}/10\n- Issues: {}\n- Suggestion: {}\n\nPlease revise your answer to address these issues.",
141                    critique.score,
142                    critique.issues.join(", "),
143                    critique.suggestion
144                ),
145                tool_call_id: None,
146                tool_calls: None,
147            }
148        ];
149
150        let response = call_model(r_ctx, revision_messages).await?;
151
152        if let Some(tool_calls) = &response.message.tool_calls {
153            if !tool_calls.is_empty() {
154                r_ctx.ctx.messages.push(response.message.clone());
155                execute_tool_calls(r_ctx, &response).await?;
156                // One more pass after tools
157                let request = ModelRequest {
158                    messages: r_ctx.ctx.messages.clone(),
159                    tools: None,
160                };
161                let final_res = r_ctx.model.complete(request).await?;
162                return Ok(extract_text(&final_res.message.content));
163            }
164        }
165
166        Ok(extract_text(&response.message.content))
167    }
168}
169
170impl Default for ReflectEngine {
171    fn default() -> Self {
172        Self::new(2, 8, 5, None)
173    }
174}
175
176#[async_trait]
177impl ReasoningEngine for ReflectEngine {
178    fn name(&self) -> &str {
179        "reflect"
180    }
181
182    async fn execute(&self, r_ctx: &mut ReasoningContext<'_>) -> Result<String> {
183        let mut messages = r_ctx.ctx.messages.clone();
184        let mut answer = self.generate_answer(r_ctx, &mut messages).await?;
185        r_ctx.push_step(ReasoningStep::Thought {
186            content: "Initial answer generated.".to_string(),
187            engine: self.name().to_string(),
188        });
189
190        for i in 0..self.max_reflections {
191            let question = r_ctx.ctx.input.clone();
192            let critique = self.critique(r_ctx, &question, &answer).await?;
193
194            r_ctx.push_step(ReasoningStep::Reflection {
195                assessment: format!(
196                    "Score: {}/10. Issues: {}. {}",
197                    critique.score,
198                    critique.issues.join("; "),
199                    critique.suggestion
200                ),
201                needs_revision: critique.needs_revision,
202                engine: self.name().to_string(),
203            });
204
205            if !critique.needs_revision || critique.score >= self.acceptance_threshold {
206                break;
207            }
208
209            r_ctx.push_step(ReasoningStep::Thought {
210                content: format!(
211                    "Revising answer (attempt {}/{})...",
212                    i + 1,
213                    self.max_reflections
214                ),
215                engine: self.name().to_string(),
216            });
217
218            answer = self.revise(r_ctx, &question, &answer, &critique).await?;
219        }
220
221        r_ctx.push_step(ReasoningStep::Response {
222            content: answer.clone(),
223            engine: self.name().to_string(),
224        });
225
226        Ok(answer)
227    }
228}