//! Planner reasoning engine (agentlib_reasoning/planner.rs).
use std::collections::HashMap;

use agentlib_core::{
    ModelMessage, ModelRequest, PlanTask, ReasoningContext, ReasoningEngine, ReasoningStep, Role,
};
use anyhow::{Result, anyhow};
use async_trait::async_trait;
use serde::Deserialize;
use uuid::Uuid;

use crate::utils::{call_model, extract_text, parse_json};
10
/// Plan-and-execute reasoning engine: decomposes the user's request into an
/// ordered task list, executes each task (with tool access), then synthesizes
/// the results into a single answer.
pub struct PlannerEngine {
    // Upper bound on successfully executed tasks per `execute` run.
    max_execution_steps: usize,
    // When true, a failed task is marked "failed" and skipped instead of
    // aborting the whole run.
    allow_replan: bool,
    // System prompt for the planning phase.
    planner_prompt: String,
    // System prompt for the per-task execution phase.
    executor_prompt: String,
}
17
18impl PlannerEngine {
19    pub fn new(
20        max_execution_steps: usize,
21        allow_replan: bool,
22        planner_prompt: Option<String>,
23        executor_prompt: Option<String>,
24    ) -> Self {
25        Self {
26            max_execution_steps,
27            allow_replan,
28            planner_prompt: planner_prompt.unwrap_or_else(|| {
29                "You are a planning assistant. Break the user's request into a clear, ordered list of subtasks.\n\nRespond with ONLY a JSON array of tasks in this exact format (no markdown, no preamble):\n[\n  { \"id\": \"t1\", \"description\": \"...\", \"dependsOn\": [] },\n  { \"id\": \"t2\", \"description\": \"...\", \"dependsOn\": [\"t1\"] }\n]\n\nRules:\n- Each task must be atomic and independently executable\n- dependsOn lists task ids that must complete first\n- Order tasks so dependencies come first\n- Be specific — the executor will act on each description".to_string()
30            }),
31            executor_prompt: executor_prompt.unwrap_or_else(|| {
32                "You are an execution assistant. Complete the given subtask using available tools.\nFocus only on the current task. Be concise and direct.".to_string()
33            }),
34        }
35    }
36
37    async fn make_plan(&self, r_ctx: &mut ReasoningContext<'_>) -> Result<Vec<PlanTask>> {
38        let available_tools = r_ctx
39            .tools
40            .list()
41            .iter()
42            .map(|t| format!("- {}: {:?}", t.name, t.description))
43            .collect::<Vec<_>>()
44            .join("\n");
45        let tool_context = if available_tools.is_empty() {
46            "".to_string()
47        } else {
48            format!("\n\nAvailable Tools:\n{}", available_tools)
49        };
50
51        let plan_messages = vec![
52            ModelMessage {
53                role: Role::System,
54                content: format!("{}{}", self.planner_prompt, tool_context),
55                tool_call_id: None,
56                tool_calls: None,
57            },
58            ModelMessage {
59                role: Role::User,
60                content: r_ctx.ctx.input.clone(),
61                tool_call_id: None,
62                tool_calls: None,
63            },
64        ];
65
66        let request = ModelRequest {
67            messages: plan_messages,
68            tools: None,
69        };
70
71        let response = r_ctx.model.complete(request).await?;
72
73        // Accumulate usage
74        if let Some(usage) = &response.usage {
75            r_ctx.ctx.usage.prompt_tokens += usage.prompt_tokens;
76            r_ctx.ctx.usage.completion_tokens += usage.completion_tokens;
77            r_ctx.ctx.usage.total_tokens += usage.total_tokens;
78        }
79
80        #[derive(Deserialize)]
81        struct RawTask {
82            id: Option<String>,
83            description: String,
84            #[serde(rename = "dependsOn")]
85            depends_on: Option<Vec<String>>,
86        }
87
88        match parse_json::<Vec<RawTask>>(&response.message.content) {
89            Ok(raw_tasks) => Ok(raw_tasks
90                .into_iter()
91                .map(|t| PlanTask {
92                    id: t.id.unwrap_or_else(|| Uuid::new_v4().to_string()),
93                    description: t.description,
94                    depends_on: t.depends_on,
95                    status: "pending".to_string(),
96                    result: None,
97                })
98                .collect()),
99            Err(_) => Ok(vec![PlanTask {
100                id: "t1".to_string(),
101                description: r_ctx.ctx.input.clone(),
102                depends_on: None,
103                status: "pending".to_string(),
104                result: None,
105            }]),
106        }
107    }
108
109    async fn execute_task(
110        &self,
111        r_ctx: &mut ReasoningContext<'_>,
112        task: &PlanTask,
113        previous_results: &HashMap<String, String>,
114    ) -> Result<String> {
115        let dep_context = if let Some(deps) = &task.depends_on {
116            if deps.is_empty() {
117                "".to_string()
118            } else {
119                format!(
120                    "\n\nContext from previous tasks:\n{}",
121                    deps.iter()
122                        .map(|id| format!(
123                            "[{}]: {}",
124                            id,
125                            previous_results.get(id).unwrap_or(&"N/A".to_string())
126                        ))
127                        .collect::<Vec<_>>()
128                        .join("\n")
129                )
130            }
131        } else {
132            "".to_string()
133        };
134
135        let mut task_messages = vec![
136            ModelMessage {
137                role: Role::System,
138                content: self.executor_prompt.clone(),
139                tool_call_id: None,
140                tool_calls: None,
141            },
142            ModelMessage {
143                role: Role::User,
144                content: format!(
145                    "Original goal: {}\n\nCurrent task: {}{}",
146                    r_ctx.ctx.input, task.description, dep_context
147                ),
148                tool_call_id: None,
149                tool_calls: None,
150            },
151        ];
152
153        let mut steps = 0;
154        let max_task_steps = 5;
155
156        while steps < max_task_steps {
157            let response = call_model(r_ctx, task_messages.clone()).await?;
158            task_messages.push(response.message.clone());
159
160            if let Some(tool_calls) = &response.message.tool_calls {
161                if tool_calls.is_empty() {
162                    return Ok(extract_text(&response.message.content));
163                }
164
165                for tc in tool_calls {
166                    let result = r_ctx
167                        .call_tool(&tc.name, tc.arguments.clone(), tc.id.clone())
168                        .await?;
169                    task_messages.push(ModelMessage {
170                        role: Role::Tool,
171                        content: result.to_string(),
172                        tool_call_id: Some(tc.id.clone()),
173                        tool_calls: None,
174                    });
175                }
176            } else {
177                return Ok(extract_text(&response.message.content));
178            }
179
180            steps += 1;
181        }
182
183        Err(anyhow!(
184            "[PlannerEngine] Task \"{}\" exceeded max steps.",
185            task.id
186        ))
187    }
188
189    async fn synthesize(
190        &self,
191        r_ctx: &mut ReasoningContext<'_>,
192        plan: &[PlanTask],
193        results: &HashMap<String, String>,
194    ) -> Result<String> {
195        let summary_context = plan
196            .iter()
197            .filter(|t| t.status == "done")
198            .map(|t| {
199                format!(
200                    "[{}] {}:\n{}",
201                    t.id,
202                    t.description,
203                    results.get(&t.id).unwrap_or(&"no result".to_string())
204                )
205            })
206            .collect::<Vec<_>>()
207            .join("\n\n");
208
209        let synth_messages = vec![
210            ModelMessage {
211                role: Role::System,
212                content: "Synthesize the results of the completed tasks into a clear, direct answer to the original user request.".to_string(),
213                tool_call_id: None,
214                tool_calls: None,
215            },
216            ModelMessage {
217                role: Role::User,
218                content: format!("Original request: {}\n\nTask results:\n{}", r_ctx.ctx.input, summary_context),
219                tool_call_id: None,
220                tool_calls: None,
221            }
222        ];
223
224        let request = ModelRequest {
225            messages: synth_messages,
226            tools: None,
227        };
228
229        let response = r_ctx.model.complete(request).await?;
230
231        // Accumulate usage
232        if let Some(usage) = &response.usage {
233            r_ctx.ctx.usage.prompt_tokens += usage.prompt_tokens;
234            r_ctx.ctx.usage.completion_tokens += usage.completion_tokens;
235            r_ctx.ctx.usage.total_tokens += usage.total_tokens;
236        }
237
238        Ok(response.message.content)
239    }
240}
241
242impl Default for PlannerEngine {
243    fn default() -> Self {
244        Self::new(20, false, None, None)
245    }
246}
247
248#[async_trait]
249impl ReasoningEngine for PlannerEngine {
250    fn name(&self) -> &str {
251        "planner"
252    }
253
254    async fn execute(&self, r_ctx: &mut ReasoningContext<'_>) -> Result<String> {
255        // Phase 1: Planning
256        let mut plan = self.make_plan(r_ctx).await?;
257
258        r_ctx.push_step(ReasoningStep::Plan {
259            tasks: plan.clone(),
260            engine: self.name().to_string(),
261        });
262
263        // Phase 2: Execution
264        let mut task_results = HashMap::new();
265        let mut execution_steps = 0;
266
267        for i in 0..plan.len() {
268            if execution_steps >= self.max_execution_steps {
269                return Err(anyhow!(
270                    "[PlannerEngine] Max execution steps ({}) reached.",
271                    self.max_execution_steps
272                ));
273            }
274
275            let task = &mut plan[i];
276
277            // Check dependencies
278            if let Some(deps) = &task.depends_on {
279                let unmet_deps: Vec<_> = deps
280                    .iter()
281                    .filter(|dep| !task_results.contains_key(*dep))
282                    .collect();
283                if !unmet_deps.is_empty() {
284                    continue;
285                }
286            }
287
288            task.status = "in_progress".to_string();
289            r_ctx.push_step(ReasoningStep::Thought {
290                content: format!("Executing task [{}]: {}", task.id, task.description),
291                engine: self.name().to_string(),
292            });
293
294            match self.execute_task(r_ctx, task, &task_results).await {
295                Ok(result) => {
296                    task.status = "done".to_string();
297                    task.result = Some(serde_json::Value::String(result.clone()));
298                    task_results.insert(task.id.clone(), result);
299                    execution_steps += 1;
300                }
301                Err(err) => {
302                    task.status = "failed".to_string();
303                    if !self.allow_replan {
304                        return Err(anyhow!(
305                            "[PlannerEngine] Task \"{}\" failed: {}",
306                            task.id,
307                            err
308                        ));
309                    }
310                }
311            }
312        }
313
314        // Phase 3: Synthesize
315        let summary = self.synthesize(r_ctx, &plan, &task_results).await?;
316        r_ctx.push_step(ReasoningStep::Response {
317            content: summary.clone(),
318            engine: self.name().to_string(),
319        });
320
321        Ok(summary)
322    }
323}