1use crate::utils::{call_model, extract_text, parse_json};
2use agentlib_core::{
3 ModelMessage, ModelRequest, PlanTask, ReasoningContext, ReasoningEngine, ReasoningStep, Role,
4};
5use anyhow::{Result, anyhow};
6use async_trait::async_trait;
7use serde::Deserialize;
8use std::collections::HashMap;
9use uuid::Uuid;
10
/// Reasoning engine that plans before it acts: the model first breaks the input into
/// dependency-annotated subtasks, each subtask is executed with tool access, and the
/// collected results are synthesized into a final answer.
#[derive(Debug, Clone)]
pub struct PlannerEngine {
    /// Limit on subtask executions performed by a single `execute` call.
    max_execution_steps: usize,
    /// When `true`, a failed subtask does not abort the run; execution simply carries on
    /// with the remaining tasks. No new plan is generated.
    allow_replan: bool,
    /// System prompt used to ask the model for the JSON task list.
    planner_prompt: String,
    /// System prompt used while executing each individual subtask.
    executor_prompt: String,
}
17
18impl PlannerEngine {
19 pub fn new(
20 max_execution_steps: usize,
21 allow_replan: bool,
22 planner_prompt: Option<String>,
23 executor_prompt: Option<String>,
24 ) -> Self {
25 Self {
26 max_execution_steps,
27 allow_replan,
28 planner_prompt: planner_prompt.unwrap_or_else(|| {
29 "You are a planning assistant. Break the user's request into a clear, ordered list of subtasks.\n\nRespond with ONLY a JSON array of tasks in this exact format (no markdown, no preamble):\n[\n { \"id\": \"t1\", \"description\": \"...\", \"dependsOn\": [] },\n { \"id\": \"t2\", \"description\": \"...\", \"dependsOn\": [\"t1\"] }\n]\n\nRules:\n- Each task must be atomic and independently executable\n- dependsOn lists task ids that must complete first\n- Order tasks so dependencies come first\n- Be specific — the executor will act on each description".to_string()
30 }),
31 executor_prompt: executor_prompt.unwrap_or_else(|| {
32 "You are an execution assistant. Complete the given subtask using available tools.\nFocus only on the current task. Be concise and direct.".to_string()
33 }),
34 }
35 }
36
37 async fn make_plan(&self, r_ctx: &mut ReasoningContext<'_>) -> Result<Vec<PlanTask>> {
38 let available_tools = r_ctx
39 .tools
40 .list()
41 .iter()
42 .map(|t| format!("- {}: {:?}", t.name, t.description))
43 .collect::<Vec<_>>()
44 .join("\n");
45 let tool_context = if available_tools.is_empty() {
46 "".to_string()
47 } else {
48 format!("\n\nAvailable Tools:\n{}", available_tools)
49 };
50
51 let plan_messages = vec![
52 ModelMessage {
53 role: Role::System,
54 content: format!("{}{}", self.planner_prompt, tool_context),
55 tool_call_id: None,
56 tool_calls: None,
57 },
58 ModelMessage {
59 role: Role::User,
60 content: r_ctx.ctx.input.clone(),
61 tool_call_id: None,
62 tool_calls: None,
63 },
64 ];
65
66 let request = ModelRequest {
67 messages: plan_messages,
68 tools: None,
69 };
70
71 let response = r_ctx.model.complete(request).await?;
72
73 if let Some(usage) = &response.usage {
75 r_ctx.ctx.usage.prompt_tokens += usage.prompt_tokens;
76 r_ctx.ctx.usage.completion_tokens += usage.completion_tokens;
77 r_ctx.ctx.usage.total_tokens += usage.total_tokens;
78 }
79
80 #[derive(Deserialize)]
81 struct RawTask {
82 id: Option<String>,
83 description: String,
84 #[serde(rename = "dependsOn")]
85 depends_on: Option<Vec<String>>,
86 }
87
88 match parse_json::<Vec<RawTask>>(&response.message.content) {
89 Ok(raw_tasks) => Ok(raw_tasks
90 .into_iter()
91 .map(|t| PlanTask {
92 id: t.id.unwrap_or_else(|| Uuid::new_v4().to_string()),
93 description: t.description,
94 depends_on: t.depends_on,
95 status: "pending".to_string(),
96 result: None,
97 })
98 .collect()),
99 Err(_) => Ok(vec![PlanTask {
100 id: "t1".to_string(),
101 description: r_ctx.ctx.input.clone(),
102 depends_on: None,
103 status: "pending".to_string(),
104 result: None,
105 }]),
106 }
107 }
108
109 async fn execute_task(
110 &self,
111 r_ctx: &mut ReasoningContext<'_>,
112 task: &PlanTask,
113 previous_results: &HashMap<String, String>,
114 ) -> Result<String> {
115 let dep_context = if let Some(deps) = &task.depends_on {
116 if deps.is_empty() {
117 "".to_string()
118 } else {
119 format!(
120 "\n\nContext from previous tasks:\n{}",
121 deps.iter()
122 .map(|id| format!(
123 "[{}]: {}",
124 id,
125 previous_results.get(id).unwrap_or(&"N/A".to_string())
126 ))
127 .collect::<Vec<_>>()
128 .join("\n")
129 )
130 }
131 } else {
132 "".to_string()
133 };
134
135 let mut task_messages = vec![
136 ModelMessage {
137 role: Role::System,
138 content: self.executor_prompt.clone(),
139 tool_call_id: None,
140 tool_calls: None,
141 },
142 ModelMessage {
143 role: Role::User,
144 content: format!(
145 "Original goal: {}\n\nCurrent task: {}{}",
146 r_ctx.ctx.input, task.description, dep_context
147 ),
148 tool_call_id: None,
149 tool_calls: None,
150 },
151 ];
152
153 let mut steps = 0;
154 let max_task_steps = 5;
155
156 while steps < max_task_steps {
157 let response = call_model(r_ctx, task_messages.clone()).await?;
158 task_messages.push(response.message.clone());
159
160 if let Some(tool_calls) = &response.message.tool_calls {
161 if tool_calls.is_empty() {
162 return Ok(extract_text(&response.message.content));
163 }
164
165 for tc in tool_calls {
166 let result = r_ctx
167 .call_tool(&tc.name, tc.arguments.clone(), tc.id.clone())
168 .await?;
169 task_messages.push(ModelMessage {
170 role: Role::Tool,
171 content: result.to_string(),
172 tool_call_id: Some(tc.id.clone()),
173 tool_calls: None,
174 });
175 }
176 } else {
177 return Ok(extract_text(&response.message.content));
178 }
179
180 steps += 1;
181 }
182
183 Err(anyhow!(
184 "[PlannerEngine] Task \"{}\" exceeded max steps.",
185 task.id
186 ))
187 }
188
189 async fn synthesize(
190 &self,
191 r_ctx: &mut ReasoningContext<'_>,
192 plan: &[PlanTask],
193 results: &HashMap<String, String>,
194 ) -> Result<String> {
195 let summary_context = plan
196 .iter()
197 .filter(|t| t.status == "done")
198 .map(|t| {
199 format!(
200 "[{}] {}:\n{}",
201 t.id,
202 t.description,
203 results.get(&t.id).unwrap_or(&"no result".to_string())
204 )
205 })
206 .collect::<Vec<_>>()
207 .join("\n\n");
208
209 let synth_messages = vec![
210 ModelMessage {
211 role: Role::System,
212 content: "Synthesize the results of the completed tasks into a clear, direct answer to the original user request.".to_string(),
213 tool_call_id: None,
214 tool_calls: None,
215 },
216 ModelMessage {
217 role: Role::User,
218 content: format!("Original request: {}\n\nTask results:\n{}", r_ctx.ctx.input, summary_context),
219 tool_call_id: None,
220 tool_calls: None,
221 }
222 ];
223
224 let request = ModelRequest {
225 messages: synth_messages,
226 tools: None,
227 };
228
229 let response = r_ctx.model.complete(request).await?;
230
231 if let Some(usage) = &response.usage {
233 r_ctx.ctx.usage.prompt_tokens += usage.prompt_tokens;
234 r_ctx.ctx.usage.completion_tokens += usage.completion_tokens;
235 r_ctx.ctx.usage.total_tokens += usage.total_tokens;
236 }
237
238 Ok(response.message.content)
239 }
240}
241
242impl Default for PlannerEngine {
243 fn default() -> Self {
244 Self::new(20, false, None, None)
245 }
246}
247
248#[async_trait]
249impl ReasoningEngine for PlannerEngine {
250 fn name(&self) -> &str {
251 "planner"
252 }
253
254 async fn execute(&self, r_ctx: &mut ReasoningContext<'_>) -> Result<String> {
255 let mut plan = self.make_plan(r_ctx).await?;
257
258 r_ctx.push_step(ReasoningStep::Plan {
259 tasks: plan.clone(),
260 engine: self.name().to_string(),
261 });
262
263 let mut task_results = HashMap::new();
265 let mut execution_steps = 0;
266
267 for i in 0..plan.len() {
268 if execution_steps >= self.max_execution_steps {
269 return Err(anyhow!(
270 "[PlannerEngine] Max execution steps ({}) reached.",
271 self.max_execution_steps
272 ));
273 }
274
275 let task = &mut plan[i];
276
277 if let Some(deps) = &task.depends_on {
279 let unmet_deps: Vec<_> = deps
280 .iter()
281 .filter(|dep| !task_results.contains_key(*dep))
282 .collect();
283 if !unmet_deps.is_empty() {
284 continue;
285 }
286 }
287
288 task.status = "in_progress".to_string();
289 r_ctx.push_step(ReasoningStep::Thought {
290 content: format!("Executing task [{}]: {}", task.id, task.description),
291 engine: self.name().to_string(),
292 });
293
294 match self.execute_task(r_ctx, task, &task_results).await {
295 Ok(result) => {
296 task.status = "done".to_string();
297 task.result = Some(serde_json::Value::String(result.clone()));
298 task_results.insert(task.id.clone(), result);
299 execution_steps += 1;
300 }
301 Err(err) => {
302 task.status = "failed".to_string();
303 if !self.allow_replan {
304 return Err(anyhow!(
305 "[PlannerEngine] Task \"{}\" failed: {}",
306 task.id,
307 err
308 ));
309 }
310 }
311 }
312 }
313
314 let summary = self.synthesize(r_ctx, &plan, &task_results).await?;
316 r_ctx.push_step(ReasoningStep::Response {
317 content: summary.clone(),
318 engine: self.name().to_string(),
319 });
320
321 Ok(summary)
322 }
323}