Skip to main content

a3s_code_core/planning/
llm_planner.rs

1//! LLM-powered planning logic
2//!
3//! Provides intelligent plan generation, goal extraction, and achievement
4//! evaluation by sending structured prompts to an LLM and parsing JSON responses.
//! Also exposes heuristic `fallback_*` functions that callers can use when no LLM
//! client is available or an LLM call fails.
6
7use crate::llm::{LlmClient, Message};
8use crate::planning::{AgentGoal, Complexity, ExecutionPlan, Task};
9use anyhow::{Context, Result};
10use serde::{Deserialize, Serialize};
11use std::sync::Arc;
12
13/// Result of evaluating goal achievement
14#[derive(Debug, Clone, Serialize, Deserialize)]
15pub struct AchievementResult {
16    /// Whether the goal has been achieved
17    pub achieved: bool,
18    /// Progress toward goal (0.0 - 1.0)
19    pub progress: f32,
20    /// Criteria that remain unmet
21    pub remaining_criteria: Vec<String>,
22}
23
24/// Pre-analysis result — intent, goal, plan, and optimized input in one LLM call.
25#[derive(Debug, Clone)]
26pub struct PreAnalysis {
27    pub intent: crate::prompts::AgentStyle,
28    pub requires_planning: bool,
29    pub goal: AgentGoal,
30    pub execution_plan: ExecutionPlan,
31    /// LLM-rewritten version of the user input with ambiguities resolved.
32    pub optimized_input: String,
33}
34
/// Stateless namespace for planning operations.
///
/// Every operation is an associated function. The LLM-backed ones (plan
/// creation, goal extraction, achievement checks, pre-analysis) take the client
/// as an argument; the `fallback_*` heuristics need no client at all.
pub struct LlmPlanner;
37
38// ============================================================================
39// JSON response schemas for LLM parsing
40// ============================================================================
41
42#[derive(Debug, Deserialize)]
43struct PlanResponse {
44    goal: String,
45    complexity: String,
46    steps: Vec<StepResponse>,
47    #[serde(default)]
48    required_tools: Vec<String>,
49}
50
51#[derive(Debug, Deserialize)]
52struct StepResponse {
53    id: String,
54    description: String,
55    #[serde(default)]
56    tool: Option<String>,
57    #[serde(default)]
58    dependencies: Vec<String>,
59    #[serde(default)]
60    success_criteria: Option<String>,
61}
62
63#[derive(Debug, Deserialize)]
64struct GoalResponse {
65    description: String,
66    success_criteria: Vec<String>,
67}
68
69#[derive(Debug, Deserialize)]
70struct AchievementResponse {
71    achieved: bool,
72    progress: f32,
73    #[serde(default)]
74    remaining_criteria: Vec<String>,
75}
76
77#[derive(Debug, Deserialize)]
78struct PreAnalysisResponse {
79    intent: String,
80    requires_planning: bool,
81    goal: GoalResponse,
82    execution_plan: PreAnalysisPlan,
83    optimized_input: String,
84}
85
86#[derive(Debug, Deserialize)]
87struct PreAnalysisPlan {
88    complexity: String,
89    steps: Vec<StepResponse>,
90    #[serde(default)]
91    required_tools: Vec<String>,
92}
93
94impl LlmPlanner {
95    /// Generate an execution plan from a prompt using LLM
96    pub async fn create_plan(llm: &Arc<dyn LlmClient>, prompt: &str) -> Result<ExecutionPlan> {
97        let system = crate::prompts::LLM_PLAN_SYSTEM;
98
99        let messages = vec![Message::user(prompt)];
100        let response = llm
101            .complete(&messages, Some(system), &[])
102            .await
103            .context("LLM call failed during plan creation")?;
104
105        let text = response.text();
106        Self::parse_plan_response(&text)
107    }
108
109    /// Extract a goal with success criteria from a prompt using LLM
110    pub async fn extract_goal(llm: &Arc<dyn LlmClient>, prompt: &str) -> Result<AgentGoal> {
111        let system = crate::prompts::LLM_GOAL_EXTRACT_SYSTEM;
112
113        let messages = vec![Message::user(prompt)];
114        let response = llm
115            .complete(&messages, Some(system), &[])
116            .await
117            .context("LLM call failed during goal extraction")?;
118
119        let text = response.text();
120        Self::parse_goal_response(&text)
121    }
122
123    /// Evaluate whether a goal has been achieved given current state
124    pub async fn check_achievement(
125        llm: &Arc<dyn LlmClient>,
126        goal: &AgentGoal,
127        current_state: &str,
128    ) -> Result<AchievementResult> {
129        let system = crate::prompts::LLM_GOAL_CHECK_SYSTEM;
130
131        let user_message = format!(
132            "Goal: {}\nSuccess Criteria: {}\nCurrent State: {}",
133            goal.description,
134            goal.success_criteria.join("; "),
135            current_state,
136        );
137
138        let messages = vec![Message::user(&user_message)];
139        let response = llm
140            .complete(&messages, Some(system), &[])
141            .await
142            .context("LLM call failed during achievement check")?;
143
144        let text = response.text();
145        Self::parse_achievement_response(&text)
146    }
147
148    /// Create a fallback plan using heuristic logic (no LLM required)
149    pub fn fallback_plan(prompt: &str) -> ExecutionPlan {
150        let complexity = if prompt.len() < 50 {
151            Complexity::Simple
152        } else if prompt.len() < 150 {
153            Complexity::Medium
154        } else if prompt.len() < 300 {
155            Complexity::Complex
156        } else {
157            Complexity::VeryComplex
158        };
159
160        let mut plan = ExecutionPlan::new(prompt, complexity);
161
162        let step_count = match complexity {
163            Complexity::Simple => 2,
164            Complexity::Medium => 4,
165            Complexity::Complex => 7,
166            Complexity::VeryComplex => 10,
167        };
168
169        for i in 0..step_count {
170            let step = Task::new(
171                format!("step-{}", i + 1),
172                crate::prompts::render(
173                    crate::prompts::PLAN_FALLBACK_STEP,
174                    &[("step_num", &(i + 1).to_string())],
175                ),
176            );
177            plan.add_step(step);
178        }
179
180        plan
181    }
182
183    /// Create a fallback goal using heuristic logic (no LLM required)
184    pub fn fallback_goal(prompt: &str) -> AgentGoal {
185        AgentGoal::new(prompt).with_criteria(vec![
186            "Task is completed successfully".to_string(),
187            "All requirements are met".to_string(),
188        ])
189    }
190
191    /// Create a fallback achievement result using heuristic logic (no LLM required)
192    pub fn fallback_check_achievement(goal: &AgentGoal, current_state: &str) -> AchievementResult {
193        let state_lower = current_state.to_lowercase();
194        let achieved = state_lower.contains("complete")
195            || state_lower.contains("done")
196            || state_lower.contains("finished");
197
198        let progress = if achieved { 1.0 } else { goal.progress };
199
200        let remaining_criteria = if achieved {
201            Vec::new()
202        } else {
203            goal.success_criteria.clone()
204        };
205
206        AchievementResult {
207            achieved,
208            progress,
209            remaining_criteria,
210        }
211    }
212
213    /// Perform pre-analysis in a single LLM call: intent classification, goal extraction,
214    /// execution plan, and input optimization. Falls back to heuristics on failure.
215    pub async fn pre_analyze(llm: &Arc<dyn LlmClient>, prompt: &str) -> Result<PreAnalysis> {
216        let system = crate::prompts::PRE_ANALYSIS_SYSTEM;
217
218        // One initial attempt plus one repair round: if the model returns
219        // unparseable JSON, re-prompt it once to emit strictly valid JSON before
220        // giving up (callers fall back to heuristics on the returned error).
221        const MAX_ATTEMPTS: usize = 2;
222        let mut messages = vec![Message::user(prompt)];
223        let mut last_err: Option<anyhow::Error> = None;
224
225        for attempt in 0..MAX_ATTEMPTS {
226            let response = llm
227                .complete(&messages, Some(system), &[])
228                .await
229                .context("LLM pre-analysis call failed")?;
230
231            let text = response.text();
232            match Self::parse_pre_analysis_response(&text, prompt) {
233                Ok(analysis) => return Ok(analysis),
234                Err(e) => {
235                    last_err = Some(e);
236                    if attempt + 1 < MAX_ATTEMPTS {
237                        messages.push(response.message.clone());
238                        messages.push(Message::user(
239                            "Your previous response was not valid JSON matching the required \
240                             schema. Respond again with ONLY the JSON object — no markdown \
241                             fences, no prose, no explanation.",
242                        ));
243                    }
244                }
245            }
246        }
247
248        Err(last_err.unwrap_or_else(|| anyhow::anyhow!("pre-analysis produced no result")))
249    }
250
251    fn parse_pre_analysis_response(text: &str, original_prompt: &str) -> Result<PreAnalysis> {
252        let parsed: PreAnalysisResponse = Self::parse_json_lenient(text)
253            .context("Failed to parse pre-analysis JSON from LLM response")?;
254
255        let intent = match parsed.intent.to_lowercase().as_str() {
256            "plan" => crate::prompts::AgentStyle::Plan,
257            "explore" => crate::prompts::AgentStyle::Explore,
258            "verification" => crate::prompts::AgentStyle::Verification,
259            "codereview" | "code review" => crate::prompts::AgentStyle::CodeReview,
260            _ => crate::prompts::AgentStyle::GeneralPurpose,
261        };
262
263        let goal_description = parsed.goal.description.clone();
264        let goal =
265            AgentGoal::new(goal_description.clone()).with_criteria(parsed.goal.success_criteria);
266
267        let complexity = match parsed.execution_plan.complexity.as_str() {
268            "Simple" => Complexity::Simple,
269            "Medium" => Complexity::Medium,
270            "Complex" => Complexity::Complex,
271            "VeryComplex" => Complexity::VeryComplex,
272            _ => Complexity::Medium,
273        };
274
275        let mut plan = ExecutionPlan::new(goal_description, complexity);
276        for step_resp in parsed.execution_plan.steps {
277            let mut task = Task::new(step_resp.id, step_resp.description);
278            if let Some(tool) = step_resp.tool {
279                task = task.with_tool(tool);
280            }
281            if !step_resp.dependencies.is_empty() {
282                task = task.with_dependencies(step_resp.dependencies);
283            }
284            if let Some(criteria) = step_resp.success_criteria {
285                task = task.with_success_criteria(criteria);
286            }
287            plan.add_step(task);
288        }
289        for tool in parsed.execution_plan.required_tools {
290            plan.add_required_tool(tool);
291        }
292
293        Ok(PreAnalysis {
294            intent,
295            requires_planning: parsed.requires_planning,
296            goal,
297            execution_plan: plan,
298            optimized_input: if parsed.optimized_input.is_empty() {
299                original_prompt.to_string()
300            } else {
301                parsed.optimized_input
302            },
303        })
304    }
305
306    // ========================================================================
307    // JSON parsing helpers
308    // ========================================================================
309
310    fn parse_plan_response(text: &str) -> Result<ExecutionPlan> {
311        let parsed: PlanResponse = Self::parse_json_lenient(text)
312            .context("Failed to parse plan JSON from LLM response")?;
313
314        let complexity = match parsed.complexity.as_str() {
315            "Simple" => Complexity::Simple,
316            "Medium" => Complexity::Medium,
317            "Complex" => Complexity::Complex,
318            "VeryComplex" => Complexity::VeryComplex,
319            _ => Complexity::Medium,
320        };
321
322        let mut plan = ExecutionPlan::new(parsed.goal, complexity);
323
324        for step_resp in parsed.steps {
325            let mut task = Task::new(step_resp.id, step_resp.description);
326            if let Some(tool) = step_resp.tool {
327                task = task.with_tool(tool);
328            }
329            if !step_resp.dependencies.is_empty() {
330                task = task.with_dependencies(step_resp.dependencies);
331            }
332            if let Some(criteria) = step_resp.success_criteria {
333                task = task.with_success_criteria(criteria);
334            }
335            plan.add_step(task);
336        }
337
338        for tool in parsed.required_tools {
339            plan.add_required_tool(tool);
340        }
341
342        Ok(plan)
343    }
344
345    fn parse_goal_response(text: &str) -> Result<AgentGoal> {
346        let parsed: GoalResponse = Self::parse_json_lenient(text)
347            .context("Failed to parse goal JSON from LLM response")?;
348
349        Ok(AgentGoal::new(parsed.description).with_criteria(parsed.success_criteria))
350    }
351
352    fn parse_achievement_response(text: &str) -> Result<AchievementResult> {
353        let parsed: AchievementResponse = Self::parse_json_lenient(text)
354            .context("Failed to parse achievement JSON from LLM response")?;
355
356        Ok(AchievementResult {
357            achieved: parsed.achieved,
358            progress: parsed.progress.clamp(0.0, 1.0),
359            remaining_criteria: parsed.remaining_criteria,
360        })
361    }
362
363    /// Parse JSON from possibly-dirty LLM output into the target type.
364    ///
365    /// Uses the shared robust extractor ([`crate::llm::structured::extract_json_value`]),
366    /// which handles ```json fences, surrounding prose, and braces embedded in
367    /// strings — unlike the previous naive first-`{`/last-`}` slice, which broke
368    /// on fenced output or any `}` inside a string value.
369    fn parse_json_lenient<T: serde::de::DeserializeOwned>(text: &str) -> Result<T> {
370        let value = crate::llm::structured::extract_json_value(text)?;
371        Ok(serde_json::from_value(value)?)
372    }
373}
374
375// ============================================================================
376// Tests
377// ============================================================================
378
#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::VecDeque;
    use std::sync::Mutex;

    /// Asserts that `actual` lies within `f32::EPSILON` of `expected`.
    fn assert_approx(actual: f32, expected: f32) {
        assert!(
            (actual - expected).abs() < f32::EPSILON,
            "expected {}, got {}",
            expected,
            actual
        );
    }

    // ---------------------------------------------------------------------
    // Plan parsing
    // ---------------------------------------------------------------------

    #[test]
    fn test_parse_plan_response() {
        let raw = r#"{
            "goal": "Build a REST API",
            "complexity": "Complex",
            "steps": [
                {
                    "id": "step-1",
                    "description": "Set up project structure",
                    "tool": "bash",
                    "dependencies": [],
                    "success_criteria": "Project directory created"
                },
                {
                    "id": "step-2",
                    "description": "Implement endpoints",
                    "tool": "write",
                    "dependencies": ["step-1"],
                    "success_criteria": "Endpoints respond correctly"
                }
            ],
            "required_tools": ["bash", "write", "read"]
        }"#;

        let plan = LlmPlanner::parse_plan_response(raw).expect("valid plan JSON must parse");

        assert_eq!(plan.goal, "Build a REST API");
        assert_eq!(plan.complexity, Complexity::Complex);
        assert_eq!(plan.steps.len(), 2);

        let (setup, implement) = (&plan.steps[0], &plan.steps[1]);
        assert_eq!(setup.id, "step-1");
        assert_eq!(setup.tool.as_deref(), Some("bash"));
        assert_eq!(implement.dependencies, vec!["step-1".to_string()]);

        assert_eq!(plan.required_tools, vec!["bash", "write", "read"]);
    }

    #[test]
    fn test_parse_plan_response_with_markdown_fences() {
        let fenced = concat!(
            "```json\n",
            r#"{"goal": "Test", "complexity": "Simple", "steps": [{"id": "step-1", "description": "Do it"}], "required_tools": []}"#,
            "\n```"
        );

        let plan = LlmPlanner::parse_plan_response(fenced).expect("fenced plan JSON must parse");
        assert_eq!(plan.goal, "Test");
        assert_eq!(plan.complexity, Complexity::Simple);
        assert_eq!(plan.steps.len(), 1);
    }

    #[test]
    fn test_parse_plan_response_invalid() {
        assert!(LlmPlanner::parse_plan_response("This is not JSON at all").is_err());
    }

    #[test]
    fn test_parse_plan_response_unknown_complexity() {
        let raw = r#"{"goal": "Test", "complexity": "Unknown", "steps": [], "required_tools": []}"#;
        let plan = LlmPlanner::parse_plan_response(raw).expect("plan JSON must parse");
        // Unrecognized complexity labels degrade to Medium.
        assert_eq!(plan.complexity, Complexity::Medium);
    }

    // ---------------------------------------------------------------------
    // Goal parsing
    // ---------------------------------------------------------------------

    #[test]
    fn test_parse_goal_response() {
        let raw = r#"{
            "description": "Deploy the application to production",
            "success_criteria": [
                "All tests pass",
                "Application is accessible at production URL",
                "Health check returns 200"
            ]
        }"#;

        let goal = LlmPlanner::parse_goal_response(raw).expect("goal JSON must parse");
        assert_eq!(goal.description, "Deploy the application to production");
        assert_eq!(goal.success_criteria.len(), 3);
        assert_eq!(goal.success_criteria[0], "All tests pass");
    }

    #[test]
    fn test_parse_goal_response_invalid() {
        assert!(LlmPlanner::parse_goal_response("not json").is_err());
    }

    // ---------------------------------------------------------------------
    // Achievement parsing
    // ---------------------------------------------------------------------

    #[test]
    fn test_parse_achievement_response() {
        let raw = r#"{
            "achieved": false,
            "progress": 0.65,
            "remaining_criteria": ["Health check not verified"]
        }"#;

        let outcome = LlmPlanner::parse_achievement_response(raw).expect("achievement JSON must parse");
        assert!(!outcome.achieved);
        assert_approx(outcome.progress, 0.65);
        assert_eq!(outcome.remaining_criteria, vec!["Health check not verified"]);
    }

    #[test]
    fn test_parse_achievement_response_achieved() {
        let raw = r#"{"achieved": true, "progress": 1.0, "remaining_criteria": []}"#;
        let outcome = LlmPlanner::parse_achievement_response(raw).expect("achievement JSON must parse");
        assert!(outcome.achieved);
        assert_approx(outcome.progress, 1.0);
        assert!(outcome.remaining_criteria.is_empty());
    }

    #[test]
    fn test_parse_achievement_response_clamps_progress() {
        let raw = r#"{"achieved": false, "progress": 1.5, "remaining_criteria": []}"#;
        let outcome = LlmPlanner::parse_achievement_response(raw).expect("achievement JSON must parse");
        assert_approx(outcome.progress, 1.0);
    }

    // ---------------------------------------------------------------------
    // Heuristic fallbacks
    // ---------------------------------------------------------------------

    #[test]
    fn test_fallback_plan() {
        let terse = "Fix bug";
        let small_plan = LlmPlanner::fallback_plan(terse);
        assert_eq!(small_plan.complexity, Complexity::Simple);
        assert_eq!(small_plan.steps.len(), 2);
        assert_eq!(small_plan.goal, terse);

        let verbose = "Implement a comprehensive authentication system with OAuth2 support, JWT tokens, refresh token rotation, multi-factor authentication, and role-based access control across all API endpoints with proper audit logging and session management capabilities for both web and mobile clients, including password reset flows, account lockout policies, and integration with external identity providers such as Google, GitHub, and SAML-based enterprise SSO systems";
        let big_plan = LlmPlanner::fallback_plan(verbose);
        assert_eq!(big_plan.complexity, Complexity::VeryComplex);
        assert_eq!(big_plan.steps.len(), 10);
    }

    #[test]
    fn test_fallback_goal() {
        let goal = LlmPlanner::fallback_goal("Fix the login bug");
        assert_eq!(goal.description, "Fix the login bug");
        assert_eq!(goal.success_criteria.len(), 2);
        assert_eq!(goal.success_criteria[0], "Task is completed successfully");
    }

    #[test]
    fn test_fallback_check_achievement_done() {
        let goal = AgentGoal::new("Test task").with_criteria(vec!["Criterion 1".to_string()]);

        let outcome = LlmPlanner::fallback_check_achievement(&goal, "The task is done.");
        assert!(outcome.achieved);
        assert_approx(outcome.progress, 1.0);
        assert!(outcome.remaining_criteria.is_empty());
    }

    #[test]
    fn test_fallback_check_achievement_not_done() {
        let criteria = vec!["Criterion 1".to_string(), "Criterion 2".to_string()];
        let goal = AgentGoal::new("Test task").with_criteria(criteria);

        let outcome = LlmPlanner::fallback_check_achievement(&goal, "Work in progress");
        assert!(!outcome.achieved);
        assert_eq!(outcome.remaining_criteria.len(), 2);
    }

    // ---------------------------------------------------------------------
    // Lenient JSON extraction
    // ---------------------------------------------------------------------

    #[test]
    fn test_parse_json_lenient_plain() {
        let value: serde_json::Value =
            LlmPlanner::parse_json_lenient("  {\"a\": 1}  ").expect("plain JSON must parse");
        assert_eq!(value["a"], 1);
    }

    #[test]
    fn test_parse_json_lenient_with_fences() {
        let value: serde_json::Value = LlmPlanner::parse_json_lenient("```json\n{\"a\": 1}\n```")
            .expect("fenced JSON must parse");
        assert_eq!(value["a"], 1);
    }

    #[test]
    fn test_parse_json_lenient_with_surrounding_prose() {
        let value: serde_json::Value =
            LlmPlanner::parse_json_lenient("Here is the plan:\n{\"goal\": \"test\"}\nDone.")
                .expect("JSON surrounded by prose must parse");
        assert_eq!(value["goal"], "test");
    }

    #[test]
    fn test_parse_json_lenient_brace_inside_string_value() {
        // A `}` inside a string value followed by trailing prose used to trip
        // the old first-`{`/last-`}` slice; the extractor must respect string
        // boundaries while balancing braces.
        let noisy = "Result: {\"note\": \"use a closing brace } here\"} -- end.";
        let value: serde_json::Value =
            LlmPlanner::parse_json_lenient(noisy).expect("brace-in-string JSON must parse");
        assert_eq!(value["note"], "use a closing brace } here");
    }

    #[test]
    fn test_parse_json_lenient_fenced_with_trailing_prose() {
        // Fenced JSON followed by an explanation: a naive `rfind('}')` could
        // latch onto a brace in the trailing prose.
        let noisy = "```json\n{\"goal\": \"ship\"}\n```\nNote: revisit the `plan` later.";
        let value: serde_json::Value =
            LlmPlanner::parse_json_lenient(noisy).expect("fenced JSON with prose must parse");
        assert_eq!(value["goal"], "ship");
    }

    #[test]
    fn test_parse_json_lenient_rejects_non_json() {
        let outcome = LlmPlanner::parse_json_lenient::<serde_json::Value>("no json here at all");
        assert!(outcome.is_err());
    }

    // ---------------------------------------------------------------------
    // Pre-analysis with a scripted client
    // ---------------------------------------------------------------------

    /// Test double that hands back canned assistant replies in FIFO order and
    /// answers with an empty reply once the script runs out.
    struct ReplayClient {
        script: Mutex<VecDeque<String>>,
    }

    impl ReplayClient {
        fn new(responses: Vec<String>) -> Self {
            Self {
                script: Mutex::new(responses.into()),
            }
        }

        /// Pop the next scripted reply, or an empty string when exhausted.
        fn next_reply(&self) -> String {
            self.script.lock().unwrap().pop_front().unwrap_or_default()
        }
    }

    #[async_trait::async_trait]
    impl LlmClient for ReplayClient {
        async fn complete(
            &self,
            _messages: &[Message],
            _system: Option<&str>,
            _tools: &[crate::llm::ToolDefinition],
        ) -> anyhow::Result<crate::llm::LlmResponse> {
            let text = self.next_reply();
            let message = Message {
                role: "assistant".to_string(),
                content: vec![crate::llm::ContentBlock::Text { text }],
                reasoning_content: None,
            };
            Ok(crate::llm::LlmResponse {
                message,
                usage: crate::llm::TokenUsage::default(),
                stop_reason: None,
                meta: None,
            })
        }

        async fn complete_streaming(
            &self,
            _messages: &[Message],
            _system: Option<&str>,
            _tools: &[crate::llm::ToolDefinition],
            _cancel_token: tokio_util::sync::CancellationToken,
        ) -> anyhow::Result<tokio::sync::mpsc::Receiver<crate::llm::StreamEvent>> {
            anyhow::bail!("streaming not used in planner tests")
        }
    }

    #[tokio::test]
    async fn test_pre_analyze_repairs_invalid_json() {
        // The first reply is unparseable, so pre_analyze must issue its single
        // repair prompt and succeed on the second, valid reply.
        let valid = r#"{"intent":"explore","requires_planning":false,"goal":{"description":"Do x","success_criteria":["done"]},"execution_plan":{"complexity":"Simple","steps":[],"required_tools":[]},"optimized_input":"Do x carefully"}"#;
        let replies = vec![
            "Sorry — here's the plan, but not as JSON.".to_string(),
            valid.to_string(),
        ];
        let client: Arc<dyn LlmClient> = Arc::new(ReplayClient::new(replies));

        let analysis = LlmPlanner::pre_analyze(&client, "do x").await.unwrap();
        assert_eq!(analysis.optimized_input, "Do x carefully");
    }

    #[tokio::test]
    async fn test_pre_analyze_first_try_with_fenced_json() {
        // A single fenced reply must parse on the first attempt thanks to the
        // robust extractor — no repair round needed.
        let body = r#"{"intent":"plan","requires_planning":true,"goal":{"description":"g","success_criteria":[]},"execution_plan":{"complexity":"Medium","steps":[],"required_tools":[]},"optimized_input":"opt"}"#;
        let fenced = format!("```json\n{}\n```", body);
        let client: Arc<dyn LlmClient> = Arc::new(ReplayClient::new(vec![fenced]));

        let analysis = LlmPlanner::pre_analyze(&client, "do x").await.unwrap();
        assert_eq!(analysis.optimized_input, "opt");
    }
}