1use cortexai_core::tool::ToolSchema;
16use cortexai_core::types::{ExecutionPlan, PlanStep};
17use cortexai_core::LLMMessage;
18use thiserror::Error;
19use tracing::{debug, info, warn};
20
/// Prompt template used to ask the LLM for an initial execution plan.
///
/// Placeholders filled by `PlanGenerator::create_planning_prompt`:
/// - `{goal}`: the user's goal text;
/// - `{tools}`: one `- name: description` line per available tool.
///
/// The requested reply shape (`reasoning` plus a `steps` array whose entries carry
/// `step_number`, `description`, `expected_result` and `tools`) is what
/// `PlanGenerator::parse_plan` reads back.
const PLANNING_PROMPT: &str = r#"You are a planning agent. Your task is to create a detailed execution plan for the following goal.

GOAL: {goal}

AVAILABLE TOOLS:
{tools}

Create a step-by-step plan to achieve this goal. For each step:
1. Describe what action to take
2. Specify which tool(s) to use (if any)
3. Describe the expected result

Respond ONLY with a JSON object in this exact format:
{
 "reasoning": "Brief explanation of your approach",
 "steps": [
 {
 "step_number": 1,
 "description": "What to do in this step",
 "expected_result": "What we expect to get from this step",
 "tools": ["tool_name_1", "tool_name_2"]
 }
 ]
}

Keep the plan concise but complete. Aim for 3-7 steps maximum."#;
48
/// Prompt template used to ask the LLM whether the plan should change after a step.
///
/// Placeholders filled by `PlanGenerator::create_replan_prompt`: `{goal}`,
/// `{current_plan}`, `{completed_steps}`, `{last_result}` and `{remaining_steps}`.
///
/// The example reply is kept strictly valid JSON: the response is parsed with
/// `serde_json`, which rejects comments and bare tokens, so a model that echoed an
/// annotated example (e.g. `// only if ...` or `"step_number": N`) would produce an
/// unparseable reply. The allowed `action` values are therefore stated in prose.
const REPLAN_PROMPT: &str = r#"You are a planning agent. Review the current plan progress and decide if re-planning is needed.

ORIGINAL GOAL: {goal}

CURRENT PLAN:
{current_plan}

COMPLETED STEPS:
{completed_steps}

LAST RESULT: {last_result}

REMAINING STEPS:
{remaining_steps}

Based on the results so far, should we:
1. Continue with the current plan
2. Modify the remaining steps
3. Add new steps

Respond ONLY with a valid JSON object (no comments, no surrounding text) in this format:
{
  "action": "continue",
  "reasoning": "Why this decision",
  "modified_steps": [
    {
      "step_number": 1,
      "description": "...",
      "expected_result": "...",
      "tools": []
    }
  ]
}

"action" must be exactly one of "continue", "modify" or "add". Put steps in "modified_steps" only when the action is "modify" or "add"; otherwise use an empty array."#;
84
85pub struct PlanGenerator;
87
88impl PlanGenerator {
89 pub fn create_planning_prompt(goal: &str, tools: &[ToolSchema]) -> String {
91 let tools_desc = tools
92 .iter()
93 .map(|t| format!("- {}: {}", t.name, t.description))
94 .collect::<Vec<_>>()
95 .join("\n");
96
97 PLANNING_PROMPT
98 .replace("{goal}", goal)
99 .replace("{tools}", &tools_desc)
100 }
101
102 pub fn create_planning_message(goal: &str, tools: &[ToolSchema]) -> LLMMessage {
104 LLMMessage::user(Self::create_planning_prompt(goal, tools))
105 }
106
107 pub fn parse_plan(goal: &str, response: &str) -> Result<ExecutionPlan, PlanParseError> {
109 let json_str = extract_json(response)?;
111
112 let parsed: serde_json::Value = serde_json::from_str(&json_str)
114 .map_err(|e| PlanParseError::InvalidJson(e.to_string()))?;
115
116 let reasoning = parsed
118 .get("reasoning")
119 .and_then(|v| v.as_str())
120 .unwrap_or("")
121 .to_string();
122
123 let steps_value = parsed
125 .get("steps")
126 .ok_or_else(|| PlanParseError::MissingField("steps".to_string()))?;
127
128 let steps_array = steps_value
129 .as_array()
130 .ok_or_else(|| PlanParseError::InvalidField("steps must be an array".to_string()))?;
131
132 let mut steps = Vec::with_capacity(steps_array.len());
133
134 for (idx, step_value) in steps_array.iter().enumerate() {
135 let step_number = step_value
136 .get("step_number")
137 .and_then(|v| v.as_u64())
138 .unwrap_or((idx + 1) as u64) as usize;
139
140 let description = step_value
141 .get("description")
142 .and_then(|v| v.as_str())
143 .ok_or_else(|| {
144 PlanParseError::InvalidField(format!("step {} missing description", idx + 1))
145 })?
146 .to_string();
147
148 let expected_result = step_value
149 .get("expected_result")
150 .and_then(|v| v.as_str())
151 .unwrap_or("")
152 .to_string();
153
154 let tools = step_value
155 .get("tools")
156 .and_then(|v| v.as_array())
157 .map(|arr| {
158 arr.iter()
159 .filter_map(|v| v.as_str().map(String::from))
160 .collect()
161 })
162 .unwrap_or_default();
163
164 steps.push(
165 PlanStep::new(step_number, description)
166 .with_expected_result(expected_result)
167 .with_tools(tools),
168 );
169 }
170
171 if steps.is_empty() {
172 return Err(PlanParseError::EmptyPlan);
173 }
174
175 Ok(ExecutionPlan::new(goal)
176 .with_steps(steps)
177 .with_reasoning(reasoning))
178 }
179
180 pub fn create_replan_prompt(plan: &ExecutionPlan, last_result: &str) -> String {
182 let current_plan = format!(
183 "Goal: {}\nTotal steps: {}\nProgress: {:.0}%",
184 plan.goal,
185 plan.steps.len(),
186 plan.progress() * 100.0
187 );
188
189 let completed_steps = plan
190 .steps
191 .iter()
192 .filter(|s| s.completed)
193 .map(|s| {
194 format!(
195 "Step {}: {} -> {}",
196 s.step_number,
197 s.description,
198 s.actual_result.as_deref().unwrap_or("(no result)")
199 )
200 })
201 .collect::<Vec<_>>()
202 .join("\n");
203
204 let remaining_steps = plan
205 .steps
206 .iter()
207 .filter(|s| !s.completed)
208 .map(|s| format!("Step {}: {}", s.step_number, s.description))
209 .collect::<Vec<_>>()
210 .join("\n");
211
212 REPLAN_PROMPT
213 .replace("{goal}", &plan.goal)
214 .replace("{current_plan}", ¤t_plan)
215 .replace("{completed_steps}", &completed_steps)
216 .replace("{last_result}", last_result)
217 .replace("{remaining_steps}", &remaining_steps)
218 }
219
220 pub fn apply_replan(plan: &mut ExecutionPlan, response: &str) -> Result<bool, PlanParseError> {
222 let json_str = extract_json(response)?;
223 let parsed: serde_json::Value = serde_json::from_str(&json_str)
224 .map_err(|e| PlanParseError::InvalidJson(e.to_string()))?;
225
226 let action = parsed
227 .get("action")
228 .and_then(|v| v.as_str())
229 .unwrap_or("continue");
230
231 match action {
232 "continue" => {
233 debug!("Plan continues unchanged");
234 Ok(false)
235 }
236 "modify" | "add" => {
237 if let Some(modified_steps) =
238 parsed.get("modified_steps").and_then(|v| v.as_array())
239 {
240 plan.steps.retain(|s| s.completed);
242
243 let base_number = plan.steps.len();
245 for (idx, step_value) in modified_steps.iter().enumerate() {
246 let description = step_value
247 .get("description")
248 .and_then(|v| v.as_str())
249 .unwrap_or("Unknown step")
250 .to_string();
251
252 let expected_result = step_value
253 .get("expected_result")
254 .and_then(|v| v.as_str())
255 .unwrap_or("")
256 .to_string();
257
258 let tools = step_value
259 .get("tools")
260 .and_then(|v| v.as_array())
261 .map(|arr| {
262 arr.iter()
263 .filter_map(|v| v.as_str().map(String::from))
264 .collect()
265 })
266 .unwrap_or_default();
267
268 plan.steps.push(
269 PlanStep::new(base_number + idx + 1, description)
270 .with_expected_result(expected_result)
271 .with_tools(tools),
272 );
273 }
274
275 info!("Plan modified: now {} steps", plan.steps.len());
276 Ok(true)
277 } else {
278 warn!("Replan response missing modified_steps");
279 Ok(false)
280 }
281 }
282 _ => {
283 warn!("Unknown replan action: {}", action);
284 Ok(false)
285 }
286 }
287 }
288}
289
/// Errors raised while extracting or interpreting the JSON an LLM returned for a
/// planning or re-planning prompt.
#[derive(Debug, Error)]
pub enum PlanParseError {
    /// No JSON candidate (fenced block or balanced `{ ... }` object) was found in the
    /// response text.
    #[error("No JSON found in response")]
    NoJson,
    /// A JSON candidate was found but `serde_json` could not parse it; carries the
    /// parser's error message.
    #[error("Invalid JSON: {0}")]
    InvalidJson(String),
    /// A required top-level field is absent; carries the field name.
    #[error("Missing required field: {0}")]
    MissingField(String),
    /// A field is present but has the wrong type or shape; carries a description.
    #[error("Invalid field: {0}")]
    InvalidField(String),
    /// The plan's `steps` array contained no steps.
    #[error("Plan has no steps")]
    EmptyPlan,
}
304
305fn extract_json(response: &str) -> Result<String, PlanParseError> {
307 if let Some(start) = response.find("```json") {
309 let content_start = start + 7;
310 if let Some(end) = response[content_start..].find("```") {
311 return Ok(response[content_start..content_start + end]
312 .trim()
313 .to_string());
314 }
315 }
316
317 if let Some(start) = response.find("```") {
319 let content_start = start + 3;
320 let content_start = response[content_start..]
322 .find('\n')
323 .map(|n| content_start + n + 1)
324 .unwrap_or(content_start);
325
326 if let Some(end) = response[content_start..].find("```") {
327 return Ok(response[content_start..content_start + end]
328 .trim()
329 .to_string());
330 }
331 }
332
333 if let Some(start) = response.find('{') {
335 let mut depth = 0;
337 let mut end = start;
338 for (i, c) in response[start..].char_indices() {
339 match c {
340 '{' => depth += 1,
341 '}' => {
342 depth -= 1;
343 if depth == 0 {
344 end = start + i + 1;
345 break;
346 }
347 }
348 _ => {}
349 }
350 }
351 if depth == 0 && end > start {
352 return Ok(response[start..end].to_string());
353 }
354 }
355
356 Err(PlanParseError::NoJson)
357}
358
359#[derive(Debug, Clone)]
361pub struct StepExecutionContext {
362 pub step: PlanStep,
364 pub prompt: String,
366}
367
368impl StepExecutionContext {
369 pub fn from_step(step: &PlanStep, plan: &ExecutionPlan) -> Self {
371 let mut prompt = format!(
372 "Execute step {} of the plan to: {}\n\n",
373 step.step_number, plan.goal
374 );
375
376 prompt.push_str(&format!("CURRENT STEP: {}\n", step.description));
377
378 if !step.expected_result.is_empty() {
379 prompt.push_str(&format!("EXPECTED RESULT: {}\n", step.expected_result));
380 }
381
382 if !step.tools.is_empty() {
383 prompt.push_str(&format!("SUGGESTED TOOLS: {}\n", step.tools.join(", ")));
384 }
385
386 let completed: Vec<_> = plan.steps.iter().filter(|s| s.completed).collect();
388 if !completed.is_empty() {
389 prompt.push_str("\nPREVIOUS RESULTS:\n");
390 for prev in completed {
391 if let Some(result) = &prev.actual_result {
392 prompt.push_str(&format!("- Step {}: {}\n", prev.step_number, result));
393 }
394 }
395 }
396
397 prompt.push_str("\nExecute this step and provide the result.");
398
399 Self {
400 step: step.clone(),
401 prompt,
402 }
403 }
404}
405
/// Returns the first stop word found in `text`, compared case-insensitively.
///
/// Matching is a plain substring test, so `"done"` also matches inside `"undone"`.
/// The returned value is the stop word exactly as written in `stop_words`. Empty or
/// whitespace-only entries are ignored: an empty pattern is a substring of every
/// text and would otherwise report a stop on any input.
pub fn check_stop_words(text: &str, stop_words: &[String]) -> Option<String> {
    let text_lower = text.to_lowercase();
    stop_words
        .iter()
        .filter(|word| !word.trim().is_empty())
        .find(|word| text_lower.contains(&word.to_lowercase()))
        .cloned()
}
416
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_extract_json_code_block() {
        let response = r#"Here's the plan:
```json
{"steps": [{"step_number": 1, "description": "Test"}]}
```
"#;
        let json = extract_json(response).unwrap();
        assert!(json.contains("steps"));
    }

    #[test]
    fn test_extract_json_raw() {
        let response = r#"{"steps": [{"step_number": 1, "description": "Test"}]}"#;
        let json = extract_json(response).unwrap();
        assert!(json.contains("steps"));
    }

    #[test]
    fn test_extract_json_plain_fence_skips_language_tag() {
        let response = "```javascript\n{\"steps\": []}\n```";
        assert_eq!(extract_json(response).unwrap(), "{\"steps\": []}");
    }

    #[test]
    fn test_extract_json_none() {
        assert!(matches!(extract_json("no braces here"), Err(PlanParseError::NoJson)));
        assert!(matches!(extract_json("{ never closed"), Err(PlanParseError::NoJson)));
    }

    #[test]
    fn test_parse_plan() {
        let response = r#"{"reasoning": "Simple test", "steps": [
            {"step_number": 1, "description": "First step", "expected_result": "Done", "tools": ["tool1"]}
        ]}"#;

        let plan = PlanGenerator::parse_plan("Test goal", response).unwrap();
        assert_eq!(plan.steps.len(), 1);
        assert_eq!(plan.steps[0].description, "First step");
        assert_eq!(plan.steps[0].tools, vec!["tool1"]);
    }

    #[test]
    fn test_parse_plan_defaults() {
        let response = r#"{"steps": [{"description": "a"}, {"description": "b"}]}"#;
        let plan = PlanGenerator::parse_plan("goal", response).unwrap();
        assert_eq!(plan.steps.len(), 2);
        assert_eq!(plan.steps[1].step_number, 2);
        assert!(plan.steps[0].expected_result.is_empty());
        assert!(plan.steps[0].tools.is_empty());
    }

    #[test]
    fn test_parse_plan_errors() {
        assert!(matches!(
            PlanGenerator::parse_plan("g", r#"{"reasoning": "x"}"#),
            Err(PlanParseError::MissingField(_))
        ));
        assert!(matches!(
            PlanGenerator::parse_plan("g", r#"{"steps": "not a list"}"#),
            Err(PlanParseError::InvalidField(_))
        ));
        assert!(matches!(
            PlanGenerator::parse_plan("g", r#"{"steps": []}"#),
            Err(PlanParseError::EmptyPlan)
        ));
        assert!(matches!(
            PlanGenerator::parse_plan("g", r#"{"steps": [{"tools": []}]}"#),
            Err(PlanParseError::InvalidField(_))
        ));
        assert!(matches!(
            PlanGenerator::parse_plan("g", "no json at all"),
            Err(PlanParseError::NoJson)
        ));
    }

    #[test]
    fn test_create_planning_prompt_substitutes_goal() {
        let prompt = PlanGenerator::create_planning_prompt("Find the answer", &[]);
        assert!(prompt.contains("GOAL: Find the answer"));
        assert!(!prompt.contains("{goal}"));
        assert!(!prompt.contains("{tools}"));
    }

    #[test]
    fn test_check_stop_words() {
        let stop_words = vec!["DONE".to_string(), "FINISHED".to_string()];

        assert!(check_stop_words("Task is DONE", &stop_words).is_some());
        assert!(check_stop_words("Task finished successfully", &stop_words).is_some());
        assert!(check_stop_words("Still working", &stop_words).is_none());
        // The matched word is reported as configured, not as it appeared in the text.
        assert_eq!(check_stop_words("all done", &stop_words), Some("DONE".to_string()));
    }
}