Skip to main content

a3s_code_core/planning/
llm_planner.rs

1//! LLM-powered planning logic
2//!
3//! Provides intelligent plan generation, goal extraction, and achievement
4//! evaluation by sending structured prompts to an LLM and parsing JSON responses.
5//! Falls back to heuristic logic when no LLM client is available.
6
7use crate::llm::{LlmClient, Message};
8use crate::planning::{AgentGoal, Complexity, ExecutionPlan, Task};
9use anyhow::{Context, Result};
10use serde::{Deserialize, Serialize};
11use std::sync::Arc;
12
13/// Result of evaluating goal achievement
14#[derive(Debug, Clone, Serialize, Deserialize)]
15pub struct AchievementResult {
16    /// Whether the goal has been achieved
17    pub achieved: bool,
18    /// Progress toward goal (0.0 - 1.0)
19    pub progress: f32,
20    /// Criteria that remain unmet
21    pub remaining_criteria: Vec<String>,
22}
23
24/// Pre-analysis result — intent, goal, plan, and optimized input in one LLM call.
25#[derive(Debug, Clone)]
26pub struct PreAnalysis {
27    pub intent: crate::prompts::AgentStyle,
28    pub requires_planning: bool,
29    pub goal: AgentGoal,
30    pub execution_plan: ExecutionPlan,
31    /// LLM-rewritten version of the user input with ambiguities resolved.
32    pub optimized_input: String,
33}
34
/// Stateless entry point for LLM-backed planning.
///
/// All functionality is exposed as associated functions: plan generation,
/// goal extraction, achievement evaluation, and their heuristic fallbacks.
pub struct LlmPlanner;
37
38// ============================================================================
39// JSON response schemas for LLM parsing
40// ============================================================================
41
42#[derive(Debug, Deserialize)]
43struct PlanResponse {
44    goal: String,
45    complexity: String,
46    steps: Vec<StepResponse>,
47    #[serde(default)]
48    required_tools: Vec<String>,
49}
50
51#[derive(Debug, Deserialize)]
52struct StepResponse {
53    id: String,
54    description: String,
55    #[serde(default)]
56    tool: Option<String>,
57    #[serde(default)]
58    dependencies: Vec<String>,
59    #[serde(default)]
60    success_criteria: Option<String>,
61}
62
63#[derive(Debug, Deserialize)]
64struct GoalResponse {
65    description: String,
66    success_criteria: Vec<String>,
67}
68
69#[derive(Debug, Deserialize)]
70struct AchievementResponse {
71    achieved: bool,
72    progress: f32,
73    #[serde(default)]
74    remaining_criteria: Vec<String>,
75}
76
77#[derive(Debug, Deserialize)]
78struct PreAnalysisResponse {
79    intent: String,
80    requires_planning: bool,
81    goal: GoalResponse,
82    execution_plan: PreAnalysisPlan,
83    optimized_input: String,
84}
85
86#[derive(Debug, Deserialize)]
87struct PreAnalysisPlan {
88    complexity: String,
89    steps: Vec<StepResponse>,
90    #[serde(default)]
91    required_tools: Vec<String>,
92}
93
94impl LlmPlanner {
95    /// Generate an execution plan from a prompt using LLM
96    pub async fn create_plan(llm: &Arc<dyn LlmClient>, prompt: &str) -> Result<ExecutionPlan> {
97        let system = crate::prompts::LLM_PLAN_SYSTEM;
98
99        let messages = vec![Message::user(prompt)];
100        let response = llm
101            .complete(&messages, Some(system), &[])
102            .await
103            .context("LLM call failed during plan creation")?;
104
105        let text = response.text();
106        Self::parse_plan_response(&text)
107    }
108
109    /// Extract a goal with success criteria from a prompt using LLM
110    pub async fn extract_goal(llm: &Arc<dyn LlmClient>, prompt: &str) -> Result<AgentGoal> {
111        let system = crate::prompts::LLM_GOAL_EXTRACT_SYSTEM;
112
113        let messages = vec![Message::user(prompt)];
114        let response = llm
115            .complete(&messages, Some(system), &[])
116            .await
117            .context("LLM call failed during goal extraction")?;
118
119        let text = response.text();
120        Self::parse_goal_response(&text)
121    }
122
123    /// Evaluate whether a goal has been achieved given current state
124    pub async fn check_achievement(
125        llm: &Arc<dyn LlmClient>,
126        goal: &AgentGoal,
127        current_state: &str,
128    ) -> Result<AchievementResult> {
129        let system = crate::prompts::LLM_GOAL_CHECK_SYSTEM;
130
131        let user_message = format!(
132            "Goal: {}\nSuccess Criteria: {}\nCurrent State: {}",
133            goal.description,
134            goal.success_criteria.join("; "),
135            current_state,
136        );
137
138        let messages = vec![Message::user(&user_message)];
139        let response = llm
140            .complete(&messages, Some(system), &[])
141            .await
142            .context("LLM call failed during achievement check")?;
143
144        let text = response.text();
145        Self::parse_achievement_response(&text)
146    }
147
148    /// Create a fallback plan using heuristic logic (no LLM required)
149    pub fn fallback_plan(prompt: &str) -> ExecutionPlan {
150        let complexity = if prompt.len() < 50 {
151            Complexity::Simple
152        } else if prompt.len() < 150 {
153            Complexity::Medium
154        } else if prompt.len() < 300 {
155            Complexity::Complex
156        } else {
157            Complexity::VeryComplex
158        };
159
160        let mut plan = ExecutionPlan::new(prompt, complexity);
161
162        let step_count = match complexity {
163            Complexity::Simple => 2,
164            Complexity::Medium => 4,
165            Complexity::Complex => 7,
166            Complexity::VeryComplex => 10,
167        };
168
169        for i in 0..step_count {
170            let step = Task::new(
171                format!("step-{}", i + 1),
172                crate::prompts::render(
173                    crate::prompts::PLAN_FALLBACK_STEP,
174                    &[("step_num", &(i + 1).to_string())],
175                ),
176            );
177            plan.add_step(step);
178        }
179
180        plan
181    }
182
183    /// Create a fallback goal using heuristic logic (no LLM required)
184    pub fn fallback_goal(prompt: &str) -> AgentGoal {
185        AgentGoal::new(prompt).with_criteria(vec![
186            "Task is completed successfully".to_string(),
187            "All requirements are met".to_string(),
188        ])
189    }
190
191    /// Create a fallback achievement result using heuristic logic (no LLM required)
192    pub fn fallback_check_achievement(goal: &AgentGoal, current_state: &str) -> AchievementResult {
193        let state_lower = current_state.to_lowercase();
194        let achieved = state_lower.contains("complete")
195            || state_lower.contains("done")
196            || state_lower.contains("finished");
197
198        let progress = if achieved { 1.0 } else { goal.progress };
199
200        let remaining_criteria = if achieved {
201            Vec::new()
202        } else {
203            goal.success_criteria.clone()
204        };
205
206        AchievementResult {
207            achieved,
208            progress,
209            remaining_criteria,
210        }
211    }
212
213    /// Perform pre-analysis in a single LLM call: intent classification, goal extraction,
214    /// execution plan, and input optimization. Falls back to heuristics on failure.
215    pub async fn pre_analyze(llm: &Arc<dyn LlmClient>, prompt: &str) -> Result<PreAnalysis> {
216        let system = crate::prompts::PRE_ANALYSIS_SYSTEM;
217
218        let messages = vec![Message::user(prompt)];
219        let response = llm
220            .complete(&messages, Some(system), &[])
221            .await
222            .context("LLM pre-analysis call failed")?;
223
224        let text = response.text();
225        Self::parse_pre_analysis_response(&text, prompt)
226    }
227
228    fn parse_pre_analysis_response(text: &str, original_prompt: &str) -> Result<PreAnalysis> {
229        let cleaned = Self::extract_json(text);
230        let parsed: PreAnalysisResponse = serde_json::from_str(cleaned)
231            .context("Failed to parse pre-analysis JSON from LLM response")?;
232
233        let intent = match parsed.intent.to_lowercase().as_str() {
234            "plan" => crate::prompts::AgentStyle::Plan,
235            "explore" => crate::prompts::AgentStyle::Explore,
236            "verification" => crate::prompts::AgentStyle::Verification,
237            "codereview" | "code review" => crate::prompts::AgentStyle::CodeReview,
238            _ => crate::prompts::AgentStyle::GeneralPurpose,
239        };
240
241        let goal_description = parsed.goal.description.clone();
242        let goal =
243            AgentGoal::new(goal_description.clone()).with_criteria(parsed.goal.success_criteria);
244
245        let complexity = match parsed.execution_plan.complexity.as_str() {
246            "Simple" => Complexity::Simple,
247            "Medium" => Complexity::Medium,
248            "Complex" => Complexity::Complex,
249            "VeryComplex" => Complexity::VeryComplex,
250            _ => Complexity::Medium,
251        };
252
253        let mut plan = ExecutionPlan::new(goal_description, complexity);
254        for step_resp in parsed.execution_plan.steps {
255            let mut task = Task::new(step_resp.id, step_resp.description);
256            if let Some(tool) = step_resp.tool {
257                task = task.with_tool(tool);
258            }
259            if !step_resp.dependencies.is_empty() {
260                task = task.with_dependencies(step_resp.dependencies);
261            }
262            if let Some(criteria) = step_resp.success_criteria {
263                task = task.with_success_criteria(criteria);
264            }
265            plan.add_step(task);
266        }
267        for tool in parsed.execution_plan.required_tools {
268            plan.add_required_tool(tool);
269        }
270
271        Ok(PreAnalysis {
272            intent,
273            requires_planning: parsed.requires_planning,
274            goal,
275            execution_plan: plan,
276            optimized_input: if parsed.optimized_input.is_empty() {
277                original_prompt.to_string()
278            } else {
279                parsed.optimized_input
280            },
281        })
282    }
283
284    // ========================================================================
285    // JSON parsing helpers
286    // ========================================================================
287
288    fn parse_plan_response(text: &str) -> Result<ExecutionPlan> {
289        let cleaned = Self::extract_json(text);
290        let parsed: PlanResponse =
291            serde_json::from_str(cleaned).context("Failed to parse plan JSON from LLM response")?;
292
293        let complexity = match parsed.complexity.as_str() {
294            "Simple" => Complexity::Simple,
295            "Medium" => Complexity::Medium,
296            "Complex" => Complexity::Complex,
297            "VeryComplex" => Complexity::VeryComplex,
298            _ => Complexity::Medium,
299        };
300
301        let mut plan = ExecutionPlan::new(parsed.goal, complexity);
302
303        for step_resp in parsed.steps {
304            let mut task = Task::new(step_resp.id, step_resp.description);
305            if let Some(tool) = step_resp.tool {
306                task = task.with_tool(tool);
307            }
308            if !step_resp.dependencies.is_empty() {
309                task = task.with_dependencies(step_resp.dependencies);
310            }
311            if let Some(criteria) = step_resp.success_criteria {
312                task = task.with_success_criteria(criteria);
313            }
314            plan.add_step(task);
315        }
316
317        for tool in parsed.required_tools {
318            plan.add_required_tool(tool);
319        }
320
321        Ok(plan)
322    }
323
324    fn parse_goal_response(text: &str) -> Result<AgentGoal> {
325        let cleaned = Self::extract_json(text);
326        let parsed: GoalResponse =
327            serde_json::from_str(cleaned).context("Failed to parse goal JSON from LLM response")?;
328
329        Ok(AgentGoal::new(parsed.description).with_criteria(parsed.success_criteria))
330    }
331
332    fn parse_achievement_response(text: &str) -> Result<AchievementResult> {
333        let cleaned = Self::extract_json(text);
334        let parsed: AchievementResponse = serde_json::from_str(cleaned)
335            .context("Failed to parse achievement JSON from LLM response")?;
336
337        Ok(AchievementResult {
338            achieved: parsed.achieved,
339            progress: parsed.progress.clamp(0.0, 1.0),
340            remaining_criteria: parsed.remaining_criteria,
341        })
342    }
343
344    /// Extract JSON from LLM text that may contain markdown fences
345    fn extract_json(text: &str) -> &str {
346        let trimmed = text.trim();
347
348        // Strip markdown code fences if present
349        if let Some(start) = trimmed.find('{') {
350            if let Some(end) = trimmed.rfind('}') {
351                if start <= end {
352                    return &trimmed[start..=end];
353                }
354            }
355        }
356
357        trimmed
358    }
359}
360
361// ============================================================================
362// Tests
363// ============================================================================
364
#[cfg(test)]
mod tests {
    //! Unit tests for the JSON parsing helpers and the heuristic fallbacks.
    //! None of these tests call an LLM.

    use super::*;

    // ------------------------------------------------------------------
    // parse_plan_response
    // ------------------------------------------------------------------

    #[test]
    fn test_parse_plan_response() {
        let json = r#"{
            "goal": "Build a REST API",
            "complexity": "Complex",
            "steps": [
                {
                    "id": "step-1",
                    "description": "Set up project structure",
                    "tool": "bash",
                    "dependencies": [],
                    "success_criteria": "Project directory created"
                },
                {
                    "id": "step-2",
                    "description": "Implement endpoints",
                    "tool": "write",
                    "dependencies": ["step-1"],
                    "success_criteria": "Endpoints respond correctly"
                }
            ],
            "required_tools": ["bash", "write", "read"]
        }"#;

        let plan = LlmPlanner::parse_plan_response(json).unwrap();
        assert_eq!(plan.goal, "Build a REST API");
        assert_eq!(plan.complexity, Complexity::Complex);
        assert_eq!(plan.steps.len(), 2);
        assert_eq!(plan.steps[0].id, "step-1");
        assert_eq!(plan.steps[0].tool, Some("bash".to_string()));
        assert_eq!(plan.steps[1].dependencies, vec!["step-1".to_string()]);
        assert_eq!(plan.required_tools, vec!["bash", "write", "read"]);
    }

    #[test]
    fn test_parse_plan_response_with_markdown_fences() {
        let json = "```json\n{\"goal\": \"Test\", \"complexity\": \"Simple\", \"steps\": [{\"id\": \"step-1\", \"description\": \"Do it\"}], \"required_tools\": []}\n```";

        let plan = LlmPlanner::parse_plan_response(json).unwrap();
        assert_eq!(plan.goal, "Test");
        assert_eq!(plan.complexity, Complexity::Simple);
        assert_eq!(plan.steps.len(), 1);
    }

    #[test]
    fn test_parse_plan_response_invalid() {
        let bad_json = "This is not JSON at all";
        let result = LlmPlanner::parse_plan_response(bad_json);
        assert!(result.is_err());
    }

    #[test]
    fn test_parse_plan_response_unknown_complexity() {
        let json =
            r#"{"goal": "Test", "complexity": "Unknown", "steps": [], "required_tools": []}"#;
        let plan = LlmPlanner::parse_plan_response(json).unwrap();
        assert_eq!(plan.complexity, Complexity::Medium); // falls back to Medium
    }

    // ------------------------------------------------------------------
    // parse_goal_response
    // ------------------------------------------------------------------

    #[test]
    fn test_parse_goal_response() {
        let json = r#"{
            "description": "Deploy the application to production",
            "success_criteria": [
                "All tests pass",
                "Application is accessible at production URL",
                "Health check returns 200"
            ]
        }"#;

        let goal = LlmPlanner::parse_goal_response(json).unwrap();
        assert_eq!(goal.description, "Deploy the application to production");
        assert_eq!(goal.success_criteria.len(), 3);
        assert_eq!(goal.success_criteria[0], "All tests pass");
    }

    #[test]
    fn test_parse_goal_response_invalid() {
        let result = LlmPlanner::parse_goal_response("not json");
        assert!(result.is_err());
    }

    // ------------------------------------------------------------------
    // parse_achievement_response
    // ------------------------------------------------------------------

    #[test]
    fn test_parse_achievement_response() {
        let json = r#"{
            "achieved": false,
            "progress": 0.65,
            "remaining_criteria": ["Health check not verified"]
        }"#;

        let result = LlmPlanner::parse_achievement_response(json).unwrap();
        assert!(!result.achieved);
        assert!((result.progress - 0.65).abs() < f32::EPSILON);
        assert_eq!(result.remaining_criteria, vec!["Health check not verified"]);
    }

    #[test]
    fn test_parse_achievement_response_achieved() {
        let json = r#"{"achieved": true, "progress": 1.0, "remaining_criteria": []}"#;
        let result = LlmPlanner::parse_achievement_response(json).unwrap();
        assert!(result.achieved);
        assert!((result.progress - 1.0).abs() < f32::EPSILON);
        assert!(result.remaining_criteria.is_empty());
    }

    #[test]
    fn test_parse_achievement_response_clamps_progress() {
        let json = r#"{"achieved": false, "progress": 1.5, "remaining_criteria": []}"#;
        let result = LlmPlanner::parse_achievement_response(json).unwrap();
        assert!((result.progress - 1.0).abs() < f32::EPSILON);
    }

    // ------------------------------------------------------------------
    // parse_pre_analysis_response
    // ------------------------------------------------------------------

    #[test]
    fn test_parse_pre_analysis_response() {
        let json = r#"{
            "intent": "Explore",
            "requires_planning": true,
            "goal": {
                "description": "Understand the auth module",
                "success_criteria": ["Entry points identified"]
            },
            "execution_plan": {
                "complexity": "Simple",
                "steps": [
                    {"id": "step-1", "description": "Read the module", "tool": "read"}
                ],
                "required_tools": ["read"]
            },
            "optimized_input": "Explore the auth module and list its entry points"
        }"#;

        let analysis = LlmPlanner::parse_pre_analysis_response(json, "look at auth").unwrap();
        assert!(matches!(
            analysis.intent,
            crate::prompts::AgentStyle::Explore
        ));
        assert!(analysis.requires_planning);
        assert_eq!(analysis.goal.description, "Understand the auth module");
        assert_eq!(analysis.goal.success_criteria.len(), 1);
        // The plan reuses the goal description as its goal text.
        assert_eq!(analysis.execution_plan.goal, "Understand the auth module");
        assert_eq!(analysis.execution_plan.complexity, Complexity::Simple);
        assert_eq!(analysis.execution_plan.steps.len(), 1);
        assert_eq!(
            analysis.execution_plan.steps[0].tool,
            Some("read".to_string())
        );
        assert_eq!(analysis.execution_plan.required_tools, vec!["read"]);
        assert_eq!(
            analysis.optimized_input,
            "Explore the auth module and list its entry points"
        );
    }

    #[test]
    fn test_parse_pre_analysis_response_empty_optimized_input_uses_prompt() {
        let json = r#"{
            "intent": "plan",
            "requires_planning": false,
            "goal": {"description": "Do it", "success_criteria": []},
            "execution_plan": {"complexity": "Medium", "steps": []},
            "optimized_input": ""
        }"#;

        let analysis = LlmPlanner::parse_pre_analysis_response(json, "original prompt").unwrap();
        assert!(matches!(analysis.intent, crate::prompts::AgentStyle::Plan));
        assert!(!analysis.requires_planning);
        assert_eq!(analysis.optimized_input, "original prompt");
        assert!(analysis.execution_plan.steps.is_empty());
        assert!(analysis.execution_plan.required_tools.is_empty());
    }

    #[test]
    fn test_parse_pre_analysis_response_unknown_labels() {
        let json = r#"{
            "intent": "something-else",
            "requires_planning": false,
            "goal": {"description": "Do it", "success_criteria": []},
            "execution_plan": {"complexity": "Gigantic", "steps": []},
            "optimized_input": "Do it properly"
        }"#;

        let analysis = LlmPlanner::parse_pre_analysis_response(json, "do it").unwrap();
        assert!(matches!(
            analysis.intent,
            crate::prompts::AgentStyle::GeneralPurpose
        ));
        assert_eq!(analysis.execution_plan.complexity, Complexity::Medium);
    }

    #[test]
    fn test_parse_pre_analysis_response_invalid() {
        let result = LlmPlanner::parse_pre_analysis_response("no json here", "prompt");
        assert!(result.is_err());
    }

    // ------------------------------------------------------------------
    // Heuristic fallbacks
    // ------------------------------------------------------------------

    #[test]
    fn test_fallback_plan() {
        let short_prompt = "Fix bug";
        let plan = LlmPlanner::fallback_plan(short_prompt);
        assert_eq!(plan.complexity, Complexity::Simple);
        assert_eq!(plan.steps.len(), 2);
        assert_eq!(plan.goal, short_prompt);

        let long_prompt = "Implement a comprehensive authentication system with OAuth2 support, JWT tokens, refresh token rotation, multi-factor authentication, and role-based access control across all API endpoints with proper audit logging and session management capabilities for both web and mobile clients, including password reset flows, account lockout policies, and integration with external identity providers such as Google, GitHub, and SAML-based enterprise SSO systems";
        let plan = LlmPlanner::fallback_plan(long_prompt);
        assert_eq!(plan.complexity, Complexity::VeryComplex);
        assert_eq!(plan.steps.len(), 10);
    }

    #[test]
    fn test_fallback_goal() {
        let goal = LlmPlanner::fallback_goal("Fix the login bug");
        assert_eq!(goal.description, "Fix the login bug");
        assert_eq!(goal.success_criteria.len(), 2);
        assert_eq!(goal.success_criteria[0], "Task is completed successfully");
    }

    #[test]
    fn test_fallback_check_achievement_done() {
        let goal = AgentGoal::new("Test task").with_criteria(vec!["Criterion 1".to_string()]);

        let result = LlmPlanner::fallback_check_achievement(&goal, "The task is done.");
        assert!(result.achieved);
        assert!((result.progress - 1.0).abs() < f32::EPSILON);
        assert!(result.remaining_criteria.is_empty());
    }

    #[test]
    fn test_fallback_check_achievement_not_done() {
        let goal = AgentGoal::new("Test task")
            .with_criteria(vec!["Criterion 1".to_string(), "Criterion 2".to_string()]);

        let result = LlmPlanner::fallback_check_achievement(&goal, "Work in progress");
        assert!(!result.achieved);
        assert_eq!(result.remaining_criteria.len(), 2);
    }

    // ------------------------------------------------------------------
    // extract_json
    // ------------------------------------------------------------------

    #[test]
    fn test_extract_json_plain() {
        assert_eq!(LlmPlanner::extract_json("  {\"a\": 1}  "), "{\"a\": 1}");
    }

    #[test]
    fn test_extract_json_with_fences() {
        let text = "```json\n{\"a\": 1}\n```";
        assert_eq!(LlmPlanner::extract_json(text), "{\"a\": 1}");
    }

    #[test]
    fn test_extract_json_with_surrounding_text() {
        let text = "Here is the plan:\n{\"goal\": \"test\"}\nDone.";
        assert_eq!(LlmPlanner::extract_json(text), "{\"goal\": \"test\"}");
    }

    #[test]
    fn test_extract_json_without_braces_returns_trimmed_text() {
        assert_eq!(LlmPlanner::extract_json("  no json here  "), "no json here");
    }

    #[test]
    fn test_extract_json_with_reversed_braces_returns_trimmed_text() {
        // A closing brace before the only opening brace is not an object span.
        assert_eq!(LlmPlanner::extract_json(" } oops { "), "} oops {");
    }
}
539}