Skip to main content

cortexai_agents/
planning.rs

1//! Planning module for agent task execution
2//!
3//! This module provides planning capabilities for agents, allowing them to:
4//! - Generate execution plans before acting
5//! - Execute plans step by step
6//! - Re-plan adaptively based on results
7//!
8//! # Planning Modes
9//!
10//! - `Disabled`: No planning, direct execution (default)
11//! - `BeforeTask`: Generate a plan before each task
12//! - `FullPlan`: Generate complete plan upfront, then execute all steps
13//! - `Adaptive`: Re-plan after each step based on results
14
15use cortexai_core::tool::ToolSchema;
16use cortexai_core::types::{ExecutionPlan, PlanStep};
17use cortexai_core::LLMMessage;
18use thiserror::Error;
19use tracing::{debug, info, warn};
20
/// Prompt template for plan generation.
///
/// Contains two placeholders, `{goal}` and `{tools}`, which
/// `PlanGenerator::create_planning_prompt` fills in before the prompt is sent.
/// The JSON reply format requested here (`reasoning` plus a `steps` array whose
/// entries carry `step_number`, `description`, `expected_result` and `tools`)
/// is the shape that `PlanGenerator::parse_plan` reads back.
const PLANNING_PROMPT: &str = r#"You are a planning agent. Your task is to create a detailed execution plan for the following goal.

GOAL: {goal}

AVAILABLE TOOLS:
{tools}

Create a step-by-step plan to achieve this goal. For each step:
1. Describe what action to take
2. Specify which tool(s) to use (if any)
3. Describe the expected result

Respond ONLY with a JSON object in this exact format:
{
  "reasoning": "Brief explanation of your approach",
  "steps": [
    {
      "step_number": 1,
      "description": "What to do in this step",
      "expected_result": "What we expect to get from this step",
      "tools": ["tool_name_1", "tool_name_2"]
    }
  ]
}

Keep the plan concise but complete. Aim for 3-7 steps maximum."#;
48
/// Prompt template for adaptive re-planning.
///
/// Placeholders: `{goal}`, `{current_plan}`, `{completed_steps}`,
/// `{last_result}` and `{remaining_steps}`, all filled in by
/// `PlanGenerator::create_replan_prompt`. Of the reply fields requested here,
/// `PlanGenerator::apply_replan` reads `action` and `modified_steps` (and, per
/// step, `description`, `expected_result` and `tools`).
const REPLAN_PROMPT: &str = r#"You are a planning agent. Review the current plan progress and decide if re-planning is needed.

ORIGINAL GOAL: {goal}

CURRENT PLAN:
{current_plan}

COMPLETED STEPS:
{completed_steps}

LAST RESULT: {last_result}

REMAINING STEPS:
{remaining_steps}

Based on the results so far, should we:
1. Continue with the current plan
2. Modify the remaining steps
3. Add new steps

Respond ONLY with a JSON object:
{
  "action": "continue" | "modify" | "add",
  "reasoning": "Why this decision",
  "modified_steps": [
    // Only if action is "modify" or "add"
    {
      "step_number": N,
      "description": "...",
      "expected_result": "...",
      "tools": []
    }
  ]
}"#;
84
/// Plan generator for creating execution plans from LLM responses.
///
/// This is a unit struct: it holds no data, and everything it offers is exposed
/// as associated functions (prompt construction, plan parsing, re-planning).
pub struct PlanGenerator;
87
88impl PlanGenerator {
89    /// Generate a planning prompt for the given goal and tools
90    pub fn create_planning_prompt(goal: &str, tools: &[ToolSchema]) -> String {
91        let tools_desc = tools
92            .iter()
93            .map(|t| format!("- {}: {}", t.name, t.description))
94            .collect::<Vec<_>>()
95            .join("\n");
96
97        PLANNING_PROMPT
98            .replace("{goal}", goal)
99            .replace("{tools}", &tools_desc)
100    }
101
102    /// Create an LLM message for plan generation
103    pub fn create_planning_message(goal: &str, tools: &[ToolSchema]) -> LLMMessage {
104        LLMMessage::user(Self::create_planning_prompt(goal, tools))
105    }
106
107    /// Parse a plan from LLM JSON response
108    pub fn parse_plan(goal: &str, response: &str) -> Result<ExecutionPlan, PlanParseError> {
109        // Try to extract JSON from the response
110        let json_str = extract_json(response)?;
111
112        // Parse the JSON
113        let parsed: serde_json::Value = serde_json::from_str(&json_str)
114            .map_err(|e| PlanParseError::InvalidJson(e.to_string()))?;
115
116        // Extract reasoning
117        let reasoning = parsed
118            .get("reasoning")
119            .and_then(|v| v.as_str())
120            .unwrap_or("")
121            .to_string();
122
123        // Extract steps
124        let steps_value = parsed
125            .get("steps")
126            .ok_or_else(|| PlanParseError::MissingField("steps".to_string()))?;
127
128        let steps_array = steps_value
129            .as_array()
130            .ok_or_else(|| PlanParseError::InvalidField("steps must be an array".to_string()))?;
131
132        let mut steps = Vec::with_capacity(steps_array.len());
133
134        for (idx, step_value) in steps_array.iter().enumerate() {
135            let step_number = step_value
136                .get("step_number")
137                .and_then(|v| v.as_u64())
138                .unwrap_or((idx + 1) as u64) as usize;
139
140            let description = step_value
141                .get("description")
142                .and_then(|v| v.as_str())
143                .ok_or_else(|| {
144                    PlanParseError::InvalidField(format!("step {} missing description", idx + 1))
145                })?
146                .to_string();
147
148            let expected_result = step_value
149                .get("expected_result")
150                .and_then(|v| v.as_str())
151                .unwrap_or("")
152                .to_string();
153
154            let tools = step_value
155                .get("tools")
156                .and_then(|v| v.as_array())
157                .map(|arr| {
158                    arr.iter()
159                        .filter_map(|v| v.as_str().map(String::from))
160                        .collect()
161                })
162                .unwrap_or_default();
163
164            steps.push(
165                PlanStep::new(step_number, description)
166                    .with_expected_result(expected_result)
167                    .with_tools(tools),
168            );
169        }
170
171        if steps.is_empty() {
172            return Err(PlanParseError::EmptyPlan);
173        }
174
175        Ok(ExecutionPlan::new(goal)
176            .with_steps(steps)
177            .with_reasoning(reasoning))
178    }
179
180    /// Create a re-planning prompt
181    pub fn create_replan_prompt(plan: &ExecutionPlan, last_result: &str) -> String {
182        let current_plan = format!(
183            "Goal: {}\nTotal steps: {}\nProgress: {:.0}%",
184            plan.goal,
185            plan.steps.len(),
186            plan.progress() * 100.0
187        );
188
189        let completed_steps = plan
190            .steps
191            .iter()
192            .filter(|s| s.completed)
193            .map(|s| {
194                format!(
195                    "Step {}: {} -> {}",
196                    s.step_number,
197                    s.description,
198                    s.actual_result.as_deref().unwrap_or("(no result)")
199                )
200            })
201            .collect::<Vec<_>>()
202            .join("\n");
203
204        let remaining_steps = plan
205            .steps
206            .iter()
207            .filter(|s| !s.completed)
208            .map(|s| format!("Step {}: {}", s.step_number, s.description))
209            .collect::<Vec<_>>()
210            .join("\n");
211
212        REPLAN_PROMPT
213            .replace("{goal}", &plan.goal)
214            .replace("{current_plan}", &current_plan)
215            .replace("{completed_steps}", &completed_steps)
216            .replace("{last_result}", last_result)
217            .replace("{remaining_steps}", &remaining_steps)
218    }
219
220    /// Parse re-planning response and update plan if needed
221    pub fn apply_replan(plan: &mut ExecutionPlan, response: &str) -> Result<bool, PlanParseError> {
222        let json_str = extract_json(response)?;
223        let parsed: serde_json::Value = serde_json::from_str(&json_str)
224            .map_err(|e| PlanParseError::InvalidJson(e.to_string()))?;
225
226        let action = parsed
227            .get("action")
228            .and_then(|v| v.as_str())
229            .unwrap_or("continue");
230
231        match action {
232            "continue" => {
233                debug!("Plan continues unchanged");
234                Ok(false)
235            }
236            "modify" | "add" => {
237                if let Some(modified_steps) =
238                    parsed.get("modified_steps").and_then(|v| v.as_array())
239                {
240                    // Remove incomplete steps
241                    plan.steps.retain(|s| s.completed);
242
243                    // Add modified/new steps
244                    let base_number = plan.steps.len();
245                    for (idx, step_value) in modified_steps.iter().enumerate() {
246                        let description = step_value
247                            .get("description")
248                            .and_then(|v| v.as_str())
249                            .unwrap_or("Unknown step")
250                            .to_string();
251
252                        let expected_result = step_value
253                            .get("expected_result")
254                            .and_then(|v| v.as_str())
255                            .unwrap_or("")
256                            .to_string();
257
258                        let tools = step_value
259                            .get("tools")
260                            .and_then(|v| v.as_array())
261                            .map(|arr| {
262                                arr.iter()
263                                    .filter_map(|v| v.as_str().map(String::from))
264                                    .collect()
265                            })
266                            .unwrap_or_default();
267
268                        plan.steps.push(
269                            PlanStep::new(base_number + idx + 1, description)
270                                .with_expected_result(expected_result)
271                                .with_tools(tools),
272                        );
273                    }
274
275                    info!("Plan modified: now {} steps", plan.steps.len());
276                    Ok(true)
277                } else {
278                    warn!("Replan response missing modified_steps");
279                    Ok(false)
280                }
281            }
282            _ => {
283                warn!("Unknown replan action: {}", action);
284                Ok(false)
285            }
286        }
287    }
288}
289
/// Errors that can occur during plan parsing
#[derive(Debug, Error)]
pub enum PlanParseError {
    /// The response contained neither a fenced code block nor a balanced
    /// `{ ... }` object.
    #[error("No JSON found in response")]
    NoJson,
    /// Text was extracted from the response but failed to parse as JSON; the
    /// payload is the parser's error message.
    #[error("Invalid JSON: {0}")]
    InvalidJson(String),
    /// A required top-level field (named in the payload) is absent.
    #[error("Missing required field: {0}")]
    MissingField(String),
    /// A field is present but unusable, e.g. `steps` is not an array or a step
    /// has no `description`; the payload describes the problem.
    #[error("Invalid field: {0}")]
    InvalidField(String),
    /// The `steps` array was present but contained no entries.
    #[error("Plan has no steps")]
    EmptyPlan,
}
304
305/// Extract JSON from a response that might contain markdown code blocks or other text
306fn extract_json(response: &str) -> Result<String, PlanParseError> {
307    // First, try to find JSON in code blocks
308    if let Some(start) = response.find("```json") {
309        let content_start = start + 7;
310        if let Some(end) = response[content_start..].find("```") {
311            return Ok(response[content_start..content_start + end]
312                .trim()
313                .to_string());
314        }
315    }
316
317    // Try generic code block
318    if let Some(start) = response.find("```") {
319        let content_start = start + 3;
320        // Skip language identifier if present
321        let content_start = response[content_start..]
322            .find('\n')
323            .map(|n| content_start + n + 1)
324            .unwrap_or(content_start);
325
326        if let Some(end) = response[content_start..].find("```") {
327            return Ok(response[content_start..content_start + end]
328                .trim()
329                .to_string());
330        }
331    }
332
333    // Try to find raw JSON object
334    if let Some(start) = response.find('{') {
335        // Find matching closing brace
336        let mut depth = 0;
337        let mut end = start;
338        for (i, c) in response[start..].char_indices() {
339            match c {
340                '{' => depth += 1,
341                '}' => {
342                    depth -= 1;
343                    if depth == 0 {
344                        end = start + i + 1;
345                        break;
346                    }
347                }
348                _ => {}
349            }
350        }
351        if depth == 0 && end > start {
352            return Ok(response[start..end].to_string());
353        }
354    }
355
356    Err(PlanParseError::NoJson)
357}
358
/// Step execution context for tracking progress.
///
/// Pairs a copy of a plan step with the prompt text built for executing it.
#[derive(Debug, Clone)]
pub struct StepExecutionContext {
    /// Current step being executed (an owned clone of the plan's step)
    pub step: PlanStep,
    /// Prompt to send to the agent for this step
    pub prompt: String,
}
367
368impl StepExecutionContext {
369    /// Create execution context for a plan step
370    pub fn from_step(step: &PlanStep, plan: &ExecutionPlan) -> Self {
371        let mut prompt = format!(
372            "Execute step {} of the plan to: {}\n\n",
373            step.step_number, plan.goal
374        );
375
376        prompt.push_str(&format!("CURRENT STEP: {}\n", step.description));
377
378        if !step.expected_result.is_empty() {
379            prompt.push_str(&format!("EXPECTED RESULT: {}\n", step.expected_result));
380        }
381
382        if !step.tools.is_empty() {
383            prompt.push_str(&format!("SUGGESTED TOOLS: {}\n", step.tools.join(", ")));
384        }
385
386        // Add context from previous steps
387        let completed: Vec<_> = plan.steps.iter().filter(|s| s.completed).collect();
388        if !completed.is_empty() {
389            prompt.push_str("\nPREVIOUS RESULTS:\n");
390            for prev in completed {
391                if let Some(result) = &prev.actual_result {
392                    prompt.push_str(&format!("- Step {}: {}\n", prev.step_number, result));
393                }
394            }
395        }
396
397        prompt.push_str("\nExecute this step and provide the result.");
398
399        Self {
400            step: step.clone(),
401            prompt,
402        }
403    }
404}
405
/// Check if any stop words are present in text.
///
/// Matching is a case-insensitive substring search. The first stop word (in
/// slice order) that occurs in `text` is returned in its original spelling;
/// `None` means no stop word occurs.
///
/// Empty stop words are ignored: an empty string is a substring of every text,
/// so honouring it would make every input look like a stop condition.
pub fn check_stop_words(text: &str, stop_words: &[String]) -> Option<String> {
    let haystack = text.to_lowercase();

    stop_words
        .iter()
        .filter(|word| !word.is_empty())
        .find(|word| haystack.contains(&word.to_lowercase()))
        .cloned()
}
416
#[cfg(test)]
mod tests {
    use super::*;

    /// JSON wrapped in a ```json fence is located and returned.
    #[test]
    fn test_extract_json_code_block() {
        let response = "Here's the plan:\n```json\n{\"steps\": [{\"step_number\": 1, \"description\": \"Test\"}]}\n```\n";

        let extracted = extract_json(response).expect("fenced JSON should be extracted");

        assert!(extracted.contains("steps"));
    }

    /// A response that is nothing but a JSON object is returned as-is.
    #[test]
    fn test_extract_json_raw() {
        let response = "{\"steps\": [{\"step_number\": 1, \"description\": \"Test\"}]}";

        let extracted = extract_json(response).expect("raw JSON should be extracted");

        assert!(extracted.contains("steps"));
    }

    /// A single well-formed step is parsed with its description and tools.
    #[test]
    fn test_parse_plan() {
        let response = r#"{"reasoning": "Simple test", "steps": [
            {"step_number": 1, "description": "First step", "expected_result": "Done", "tools": ["tool1"]}
        ]}"#;

        let plan = PlanGenerator::parse_plan("Test goal", response)
            .expect("well-formed plan should parse");

        assert_eq!(plan.steps.len(), 1);
        let first = &plan.steps[0];
        assert_eq!(first.description, "First step");
        assert_eq!(first.tools, vec!["tool1"]);
    }

    /// Stop words match regardless of case; unrelated text does not match.
    #[test]
    fn test_check_stop_words() {
        let stop_words: Vec<String> = ["DONE", "FINISHED"]
            .iter()
            .map(|w| w.to_string())
            .collect();

        let hit_exact = check_stop_words("Task is DONE", &stop_words);
        let hit_other_case = check_stop_words("Task finished successfully", &stop_words);
        let miss = check_stop_words("Still working", &stop_words);

        assert!(hit_exact.is_some());
        assert!(hit_other_case.is_some());
        assert!(miss.is_none());
    }
}