Skip to main content

a3s_code_core/planning/
llm_planner.rs

1//! LLM-powered planning logic
2//!
3//! Provides intelligent plan generation, goal extraction, and achievement
4//! evaluation by sending structured prompts to an LLM and parsing JSON responses.
5//! Falls back to heuristic logic when no LLM client is available.
6
7use crate::llm::{LlmClient, Message};
8use crate::planning::{AgentGoal, Complexity, ExecutionPlan, Task};
9use anyhow::{Context, Result};
10use async_trait::async_trait;
11use serde::{Deserialize, Serialize};
12use std::sync::Arc;
13
/// Result of evaluating goal achievement.
///
/// Produced either by the LLM path (`LlmPlanner::check_achievement`) or by
/// the heuristic fallback (`LlmPlanner::fallback_check_achievement`).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AchievementResult {
    /// Whether the goal has been achieved
    pub achieved: bool,
    /// Progress toward goal (0.0 - 1.0)
    pub progress: f32,
    /// Criteria that remain unmet
    pub remaining_criteria: Vec<String>,
}
24
/// Pre-analysis result — intent, goal, plan, and optimized input in one LLM call.
///
/// Built by `LlmPlanner::pre_analyze` from a single structured LLM response.
#[derive(Debug, Clone)]
pub struct PreAnalysis {
    /// Classified agent style/intent for the request.
    pub intent: crate::prompts::AgentStyle,
    /// Whether the request warrants multi-step planning.
    pub requires_planning: bool,
    /// Extracted goal with its success criteria.
    pub goal: AgentGoal,
    /// Proposed execution plan derived from the same response.
    pub execution_plan: ExecutionPlan,
    /// LLM-rewritten version of the user input with ambiguities resolved.
    pub optimized_input: String,
}
35
/// Trait for planning providers
///
/// Abstracts plan generation, goal extraction, and achievement evaluation.
/// Allows different implementations (LLM-based, heuristic, test mocks).
#[async_trait]
pub trait Planner: Send + Sync {
    /// Generate an execution plan from a prompt
    ///
    /// # Errors
    /// Returns an error if plan generation fails (e.g. LLM call or response
    /// parsing).
    async fn create_plan(&self, llm: &Arc<dyn LlmClient>, prompt: &str) -> Result<ExecutionPlan>;

    /// Extract a goal with success criteria from a prompt
    ///
    /// # Errors
    /// Returns an error if goal extraction fails.
    async fn extract_goal(&self, llm: &Arc<dyn LlmClient>, prompt: &str) -> Result<AgentGoal>;

    /// Evaluate whether a goal has been achieved given current state
    ///
    /// `current_state` is a free-form textual summary of progress so far.
    ///
    /// # Errors
    /// Returns an error if the evaluation fails.
    async fn check_achievement(
        &self,
        llm: &Arc<dyn LlmClient>,
        goal: &AgentGoal,
        current_state: &str,
    ) -> Result<AchievementResult>;
}
56
/// LLM-powered planner that generates plans, extracts goals, and evaluates achievement
///
/// Stateless: all functionality is exposed as associated functions that take
/// the LLM client as an argument; the `fallback_*` variants need no LLM.
pub struct LlmPlanner;
59
60// ============================================================================
61// JSON response schemas for LLM parsing
62// ============================================================================
63
/// JSON shape of an LLM plan-creation response.
#[derive(Debug, Deserialize)]
struct PlanResponse {
    goal: String,
    // Complexity tier name (e.g. "Simple"); unknown values map to Medium.
    complexity: String,
    steps: Vec<StepResponse>,
    // Absent in the JSON => empty vec.
    #[serde(default)]
    required_tools: Vec<String>,
}
72
/// JSON shape of a single plan step in an LLM response.
#[derive(Debug, Deserialize)]
struct StepResponse {
    id: String,
    description: String,
    // Optional tool the step should invoke; absent => None.
    #[serde(default)]
    tool: Option<String>,
    // IDs of steps this one depends on; absent => empty vec.
    #[serde(default)]
    dependencies: Vec<String>,
    // Human-readable completion check for the step; absent => None.
    #[serde(default)]
    success_criteria: Option<String>,
}
84
/// JSON shape of an LLM goal-extraction response.
#[derive(Debug, Deserialize)]
struct GoalResponse {
    description: String,
    success_criteria: Vec<String>,
}
90
/// JSON shape of an LLM achievement-evaluation response.
#[derive(Debug, Deserialize)]
struct AchievementResponse {
    achieved: bool,
    // Raw progress value from the LLM; clamped to [0.0, 1.0] during parsing.
    progress: f32,
    // Absent in the JSON => empty vec.
    #[serde(default)]
    remaining_criteria: Vec<String>,
}
98
/// JSON shape of the combined pre-analysis response.
#[derive(Debug, Deserialize)]
struct PreAnalysisResponse {
    // Intent label; matched case-insensitively against known agent styles.
    intent: String,
    requires_planning: bool,
    goal: GoalResponse,
    execution_plan: PreAnalysisPlan,
    // May be empty; the caller then keeps the original prompt.
    optimized_input: String,
}
107
/// JSON shape of the plan embedded in a pre-analysis response
/// (like `PlanResponse`, but without a separate goal field).
#[derive(Debug, Deserialize)]
struct PreAnalysisPlan {
    complexity: String,
    steps: Vec<StepResponse>,
    #[serde(default)]
    required_tools: Vec<String>,
}
115
116impl LlmPlanner {
117    /// Generate an execution plan from a prompt using LLM
118    pub async fn create_plan(llm: &Arc<dyn LlmClient>, prompt: &str) -> Result<ExecutionPlan> {
119        let system = crate::prompts::LLM_PLAN_SYSTEM;
120
121        let messages = vec![Message::user(prompt)];
122        let response = llm
123            .complete(&messages, Some(system), &[])
124            .await
125            .context("LLM call failed during plan creation")?;
126
127        let text = response.text();
128        Self::parse_plan_response(&text)
129    }
130
131    /// Extract a goal with success criteria from a prompt using LLM
132    pub async fn extract_goal(llm: &Arc<dyn LlmClient>, prompt: &str) -> Result<AgentGoal> {
133        let system = crate::prompts::LLM_GOAL_EXTRACT_SYSTEM;
134
135        let messages = vec![Message::user(prompt)];
136        let response = llm
137            .complete(&messages, Some(system), &[])
138            .await
139            .context("LLM call failed during goal extraction")?;
140
141        let text = response.text();
142        Self::parse_goal_response(&text)
143    }
144
145    /// Evaluate whether a goal has been achieved given current state
146    pub async fn check_achievement(
147        llm: &Arc<dyn LlmClient>,
148        goal: &AgentGoal,
149        current_state: &str,
150    ) -> Result<AchievementResult> {
151        let system = crate::prompts::LLM_GOAL_CHECK_SYSTEM;
152
153        let user_message = format!(
154            "Goal: {}\nSuccess Criteria: {}\nCurrent State: {}",
155            goal.description,
156            goal.success_criteria.join("; "),
157            current_state,
158        );
159
160        let messages = vec![Message::user(&user_message)];
161        let response = llm
162            .complete(&messages, Some(system), &[])
163            .await
164            .context("LLM call failed during achievement check")?;
165
166        let text = response.text();
167        Self::parse_achievement_response(&text)
168    }
169
170    /// Create a fallback plan using heuristic logic (no LLM required)
171    pub fn fallback_plan(prompt: &str) -> ExecutionPlan {
172        let complexity = if prompt.len() < 50 {
173            Complexity::Simple
174        } else if prompt.len() < 150 {
175            Complexity::Medium
176        } else if prompt.len() < 300 {
177            Complexity::Complex
178        } else {
179            Complexity::VeryComplex
180        };
181
182        let mut plan = ExecutionPlan::new(prompt, complexity);
183
184        let step_count = match complexity {
185            Complexity::Simple => 2,
186            Complexity::Medium => 4,
187            Complexity::Complex => 7,
188            Complexity::VeryComplex => 10,
189        };
190
191        for i in 0..step_count {
192            let step = Task::new(
193                format!("step-{}", i + 1),
194                crate::prompts::render(
195                    crate::prompts::PLAN_FALLBACK_STEP,
196                    &[("step_num", &(i + 1).to_string())],
197                ),
198            );
199            plan.add_step(step);
200        }
201
202        plan
203    }
204
205    /// Create a fallback goal using heuristic logic (no LLM required)
206    pub fn fallback_goal(prompt: &str) -> AgentGoal {
207        AgentGoal::new(prompt).with_criteria(vec![
208            "Task is completed successfully".to_string(),
209            "All requirements are met".to_string(),
210        ])
211    }
212
213    /// Create a fallback achievement result using heuristic logic (no LLM required)
214    pub fn fallback_check_achievement(goal: &AgentGoal, current_state: &str) -> AchievementResult {
215        let state_lower = current_state.to_lowercase();
216        let achieved = state_lower.contains("complete")
217            || state_lower.contains("done")
218            || state_lower.contains("finished");
219
220        let progress = if achieved { 1.0 } else { goal.progress };
221
222        let remaining_criteria = if achieved {
223            Vec::new()
224        } else {
225            goal.success_criteria.clone()
226        };
227
228        AchievementResult {
229            achieved,
230            progress,
231            remaining_criteria,
232        }
233    }
234
235    /// Perform pre-analysis in a single LLM call: intent classification, goal extraction,
236    /// execution plan, and input optimization. Falls back to heuristics on failure.
237    pub async fn pre_analyze(llm: &Arc<dyn LlmClient>, prompt: &str) -> Result<PreAnalysis> {
238        let system = crate::prompts::PRE_ANALYSIS_SYSTEM;
239
240        let messages = vec![Message::user(prompt)];
241        let response = llm
242            .complete(&messages, Some(system), &[])
243            .await
244            .context("LLM pre-analysis call failed")?;
245
246        let text = response.text();
247        Self::parse_pre_analysis_response(&text, prompt)
248    }
249
250    fn parse_pre_analysis_response(text: &str, original_prompt: &str) -> Result<PreAnalysis> {
251        let cleaned = Self::extract_json(text);
252        let parsed: PreAnalysisResponse = serde_json::from_str(cleaned)
253            .context("Failed to parse pre-analysis JSON from LLM response")?;
254
255        let intent = match parsed.intent.to_lowercase().as_str() {
256            "plan" => crate::prompts::AgentStyle::Plan,
257            "explore" => crate::prompts::AgentStyle::Explore,
258            "verification" => crate::prompts::AgentStyle::Verification,
259            "codereview" | "code review" => crate::prompts::AgentStyle::CodeReview,
260            _ => crate::prompts::AgentStyle::GeneralPurpose,
261        };
262
263        let goal_description = parsed.goal.description.clone();
264        let goal =
265            AgentGoal::new(goal_description.clone()).with_criteria(parsed.goal.success_criteria);
266
267        let complexity = match parsed.execution_plan.complexity.as_str() {
268            "Simple" => Complexity::Simple,
269            "Medium" => Complexity::Medium,
270            "Complex" => Complexity::Complex,
271            "VeryComplex" => Complexity::VeryComplex,
272            _ => Complexity::Medium,
273        };
274
275        let mut plan = ExecutionPlan::new(goal_description, complexity);
276        for step_resp in parsed.execution_plan.steps {
277            let mut task = Task::new(step_resp.id, step_resp.description);
278            if let Some(tool) = step_resp.tool {
279                task = task.with_tool(tool);
280            }
281            if !step_resp.dependencies.is_empty() {
282                task = task.with_dependencies(step_resp.dependencies);
283            }
284            if let Some(criteria) = step_resp.success_criteria {
285                task = task.with_success_criteria(criteria);
286            }
287            plan.add_step(task);
288        }
289        for tool in parsed.execution_plan.required_tools {
290            plan.add_required_tool(tool);
291        }
292
293        Ok(PreAnalysis {
294            intent,
295            requires_planning: parsed.requires_planning,
296            goal,
297            execution_plan: plan,
298            optimized_input: if parsed.optimized_input.is_empty() {
299                original_prompt.to_string()
300            } else {
301                parsed.optimized_input
302            },
303        })
304    }
305
306    // ========================================================================
307    // JSON parsing helpers
308    // ========================================================================
309
310    fn parse_plan_response(text: &str) -> Result<ExecutionPlan> {
311        let cleaned = Self::extract_json(text);
312        let parsed: PlanResponse =
313            serde_json::from_str(cleaned).context("Failed to parse plan JSON from LLM response")?;
314
315        let complexity = match parsed.complexity.as_str() {
316            "Simple" => Complexity::Simple,
317            "Medium" => Complexity::Medium,
318            "Complex" => Complexity::Complex,
319            "VeryComplex" => Complexity::VeryComplex,
320            _ => Complexity::Medium,
321        };
322
323        let mut plan = ExecutionPlan::new(parsed.goal, complexity);
324
325        for step_resp in parsed.steps {
326            let mut task = Task::new(step_resp.id, step_resp.description);
327            if let Some(tool) = step_resp.tool {
328                task = task.with_tool(tool);
329            }
330            if !step_resp.dependencies.is_empty() {
331                task = task.with_dependencies(step_resp.dependencies);
332            }
333            if let Some(criteria) = step_resp.success_criteria {
334                task = task.with_success_criteria(criteria);
335            }
336            plan.add_step(task);
337        }
338
339        for tool in parsed.required_tools {
340            plan.add_required_tool(tool);
341        }
342
343        Ok(plan)
344    }
345
346    fn parse_goal_response(text: &str) -> Result<AgentGoal> {
347        let cleaned = Self::extract_json(text);
348        let parsed: GoalResponse =
349            serde_json::from_str(cleaned).context("Failed to parse goal JSON from LLM response")?;
350
351        Ok(AgentGoal::new(parsed.description).with_criteria(parsed.success_criteria))
352    }
353
354    fn parse_achievement_response(text: &str) -> Result<AchievementResult> {
355        let cleaned = Self::extract_json(text);
356        let parsed: AchievementResponse = serde_json::from_str(cleaned)
357            .context("Failed to parse achievement JSON from LLM response")?;
358
359        Ok(AchievementResult {
360            achieved: parsed.achieved,
361            progress: parsed.progress.clamp(0.0, 1.0),
362            remaining_criteria: parsed.remaining_criteria,
363        })
364    }
365
366    /// Extract JSON from LLM text that may contain markdown fences
367    fn extract_json(text: &str) -> &str {
368        let trimmed = text.trim();
369
370        // Strip markdown code fences if present
371        if let Some(start) = trimmed.find('{') {
372            if let Some(end) = trimmed.rfind('}') {
373                if start <= end {
374                    return &trimmed[start..=end];
375                }
376            }
377        }
378
379        trimmed
380    }
381}
382
383// Implement Planner trait for LlmPlanner
384#[async_trait]
385impl Planner for LlmPlanner {
386    async fn create_plan(&self, llm: &Arc<dyn LlmClient>, prompt: &str) -> Result<ExecutionPlan> {
387        LlmPlanner::create_plan(llm, prompt).await
388    }
389
390    async fn extract_goal(&self, llm: &Arc<dyn LlmClient>, prompt: &str) -> Result<AgentGoal> {
391        LlmPlanner::extract_goal(llm, prompt).await
392    }
393
394    async fn check_achievement(
395        &self,
396        llm: &Arc<dyn LlmClient>,
397        goal: &AgentGoal,
398        current_state: &str,
399    ) -> Result<AchievementResult> {
400        LlmPlanner::check_achievement(llm, goal, current_state).await
401    }
402}
403
404// ============================================================================
405// Tests
406// ============================================================================
407
#[cfg(test)]
mod tests {
    use super::*;

    // ---- plan-response parsing ----

    #[test]
    fn test_parse_plan_response() {
        // Fully-populated plan: goal, complexity, two dependent steps,
        // and a required-tools list.
        let json = r#"{
            "goal": "Build a REST API",
            "complexity": "Complex",
            "steps": [
                {
                    "id": "step-1",
                    "description": "Set up project structure",
                    "tool": "bash",
                    "dependencies": [],
                    "success_criteria": "Project directory created"
                },
                {
                    "id": "step-2",
                    "description": "Implement endpoints",
                    "tool": "write",
                    "dependencies": ["step-1"],
                    "success_criteria": "Endpoints respond correctly"
                }
            ],
            "required_tools": ["bash", "write", "read"]
        }"#;

        let plan = LlmPlanner::parse_plan_response(json).unwrap();
        assert_eq!(plan.goal, "Build a REST API");
        assert_eq!(plan.complexity, Complexity::Complex);
        assert_eq!(plan.steps.len(), 2);
        assert_eq!(plan.steps[0].id, "step-1");
        assert_eq!(plan.steps[0].tool, Some("bash".to_string()));
        assert_eq!(plan.steps[1].dependencies, vec!["step-1".to_string()]);
        assert_eq!(plan.required_tools, vec!["bash", "write", "read"]);
    }

    #[test]
    fn test_parse_plan_response_with_markdown_fences() {
        // LLMs often wrap JSON in markdown fences; extract_json must cope.
        let json = "```json\n{\"goal\": \"Test\", \"complexity\": \"Simple\", \"steps\": [{\"id\": \"step-1\", \"description\": \"Do it\"}], \"required_tools\": []}\n```";

        let plan = LlmPlanner::parse_plan_response(json).unwrap();
        assert_eq!(plan.goal, "Test");
        assert_eq!(plan.complexity, Complexity::Simple);
        assert_eq!(plan.steps.len(), 1);
    }

    #[test]
    fn test_parse_plan_response_invalid() {
        // Non-JSON input must surface as an error, not a panic.
        let bad_json = "This is not JSON at all";
        let result = LlmPlanner::parse_plan_response(bad_json);
        assert!(result.is_err());
    }

    #[test]
    fn test_parse_plan_response_unknown_complexity() {
        let json =
            r#"{"goal": "Test", "complexity": "Unknown", "steps": [], "required_tools": []}"#;
        let plan = LlmPlanner::parse_plan_response(json).unwrap();
        assert_eq!(plan.complexity, Complexity::Medium); // falls back to Medium
    }

    // ---- goal-response parsing ----

    #[test]
    fn test_parse_goal_response() {
        let json = r#"{
            "description": "Deploy the application to production",
            "success_criteria": [
                "All tests pass",
                "Application is accessible at production URL",
                "Health check returns 200"
            ]
        }"#;

        let goal = LlmPlanner::parse_goal_response(json).unwrap();
        assert_eq!(goal.description, "Deploy the application to production");
        assert_eq!(goal.success_criteria.len(), 3);
        assert_eq!(goal.success_criteria[0], "All tests pass");
    }

    #[test]
    fn test_parse_goal_response_invalid() {
        let result = LlmPlanner::parse_goal_response("not json");
        assert!(result.is_err());
    }

    // ---- achievement-response parsing ----

    #[test]
    fn test_parse_achievement_response() {
        let json = r#"{
            "achieved": false,
            "progress": 0.65,
            "remaining_criteria": ["Health check not verified"]
        }"#;

        let result = LlmPlanner::parse_achievement_response(json).unwrap();
        assert!(!result.achieved);
        assert!((result.progress - 0.65).abs() < f32::EPSILON);
        assert_eq!(result.remaining_criteria, vec!["Health check not verified"]);
    }

    #[test]
    fn test_parse_achievement_response_achieved() {
        let json = r#"{"achieved": true, "progress": 1.0, "remaining_criteria": []}"#;
        let result = LlmPlanner::parse_achievement_response(json).unwrap();
        assert!(result.achieved);
        assert!((result.progress - 1.0).abs() < f32::EPSILON);
        assert!(result.remaining_criteria.is_empty());
    }

    #[test]
    fn test_parse_achievement_response_clamps_progress() {
        // Out-of-range progress (1.5) must be clamped to 1.0.
        let json = r#"{"achieved": false, "progress": 1.5, "remaining_criteria": []}"#;
        let result = LlmPlanner::parse_achievement_response(json).unwrap();
        assert!((result.progress - 1.0).abs() < f32::EPSILON);
    }

    // ---- heuristic fallbacks ----

    #[test]
    fn test_fallback_plan() {
        // Short prompt (< 50 chars) => Simple tier, 2 steps.
        let short_prompt = "Fix bug";
        let plan = LlmPlanner::fallback_plan(short_prompt);
        assert_eq!(plan.complexity, Complexity::Simple);
        assert_eq!(plan.steps.len(), 2);
        assert_eq!(plan.goal, short_prompt);

        // Very long prompt (>= 300 chars) => VeryComplex tier, 10 steps.
        let long_prompt = "Implement a comprehensive authentication system with OAuth2 support, JWT tokens, refresh token rotation, multi-factor authentication, and role-based access control across all API endpoints with proper audit logging and session management capabilities for both web and mobile clients, including password reset flows, account lockout policies, and integration with external identity providers such as Google, GitHub, and SAML-based enterprise SSO systems";
        let plan = LlmPlanner::fallback_plan(long_prompt);
        assert_eq!(plan.complexity, Complexity::VeryComplex);
        assert_eq!(plan.steps.len(), 10);
    }

    #[test]
    fn test_fallback_goal() {
        let goal = LlmPlanner::fallback_goal("Fix the login bug");
        assert_eq!(goal.description, "Fix the login bug");
        assert_eq!(goal.success_criteria.len(), 2);
        assert_eq!(goal.success_criteria[0], "Task is completed successfully");
    }

    #[test]
    fn test_fallback_check_achievement_done() {
        // "done" is one of the completion keywords the heuristic looks for.
        let goal = AgentGoal::new("Test task").with_criteria(vec!["Criterion 1".to_string()]);

        let result = LlmPlanner::fallback_check_achievement(&goal, "The task is done.");
        assert!(result.achieved);
        assert!((result.progress - 1.0).abs() < f32::EPSILON);
        assert!(result.remaining_criteria.is_empty());
    }

    #[test]
    fn test_fallback_check_achievement_not_done() {
        // No completion keyword => not achieved, all criteria remain.
        let goal = AgentGoal::new("Test task")
            .with_criteria(vec!["Criterion 1".to_string(), "Criterion 2".to_string()]);

        let result = LlmPlanner::fallback_check_achievement(&goal, "Work in progress");
        assert!(!result.achieved);
        assert_eq!(result.remaining_criteria.len(), 2);
    }

    // ---- extract_json ----

    #[test]
    fn test_extract_json_plain() {
        assert_eq!(LlmPlanner::extract_json("  {\"a\": 1}  "), "{\"a\": 1}");
    }

    #[test]
    fn test_extract_json_with_fences() {
        let text = "```json\n{\"a\": 1}\n```";
        assert_eq!(LlmPlanner::extract_json(text), "{\"a\": 1}");
    }

    #[test]
    fn test_extract_json_with_surrounding_text() {
        let text = "Here is the plan:\n{\"goal\": \"test\"}\nDone.";
        assert_eq!(LlmPlanner::extract_json(text), "{\"goal\": \"test\"}");
    }
}