Skip to main content

a3s_code_core/planning/
llm_planner.rs

1//! LLM-powered planning logic
2//!
3//! Provides intelligent plan generation, goal extraction, and achievement
4//! evaluation by sending structured prompts to an LLM and parsing JSON responses.
5//! Falls back to heuristic logic when no LLM client is available.
6
7use crate::llm::{LlmClient, Message};
8use crate::planning::{AgentGoal, Complexity, ExecutionPlan, Task};
9use anyhow::{Context, Result};
10use async_trait::async_trait;
11use serde::{Deserialize, Serialize};
12use std::sync::Arc;
13
14/// Result of evaluating goal achievement
15#[derive(Debug, Clone, Serialize, Deserialize)]
16pub struct AchievementResult {
17    /// Whether the goal has been achieved
18    pub achieved: bool,
19    /// Progress toward goal (0.0 - 1.0)
20    pub progress: f32,
21    /// Criteria that remain unmet
22    pub remaining_criteria: Vec<String>,
23}
24
25/// Trait for planning providers
26///
27/// Abstracts plan generation, goal extraction, and achievement evaluation.
28/// Allows different implementations (LLM-based, heuristic, test mocks).
29#[async_trait]
30pub trait Planner: Send + Sync {
31    /// Generate an execution plan from a prompt
32    async fn create_plan(&self, llm: &Arc<dyn LlmClient>, prompt: &str) -> Result<ExecutionPlan>;
33
34    /// Extract a goal with success criteria from a prompt
35    async fn extract_goal(&self, llm: &Arc<dyn LlmClient>, prompt: &str) -> Result<AgentGoal>;
36
37    /// Evaluate whether a goal has been achieved given current state
38    async fn check_achievement(
39        &self,
40        llm: &Arc<dyn LlmClient>,
41        goal: &AgentGoal,
42        current_state: &str,
43    ) -> Result<AchievementResult>;
44}
45
/// Stateless planner backed by an LLM.
///
/// Generates plans, extracts goals and evaluates goal achievement; the LLM client is
/// supplied per call, so the type itself carries no data.
pub struct LlmPlanner;
48
49// ============================================================================
50// JSON response schemas for LLM parsing
51// ============================================================================
52
53#[derive(Debug, Deserialize)]
54struct PlanResponse {
55    goal: String,
56    complexity: String,
57    steps: Vec<StepResponse>,
58    #[serde(default)]
59    required_tools: Vec<String>,
60}
61
62#[derive(Debug, Deserialize)]
63struct StepResponse {
64    id: String,
65    description: String,
66    #[serde(default)]
67    tool: Option<String>,
68    #[serde(default)]
69    dependencies: Vec<String>,
70    #[serde(default)]
71    success_criteria: Option<String>,
72}
73
74#[derive(Debug, Deserialize)]
75struct GoalResponse {
76    description: String,
77    success_criteria: Vec<String>,
78}
79
80#[derive(Debug, Deserialize)]
81struct AchievementResponse {
82    achieved: bool,
83    progress: f32,
84    #[serde(default)]
85    remaining_criteria: Vec<String>,
86}
87
88impl LlmPlanner {
89    /// Generate an execution plan from a prompt using LLM
90    pub async fn create_plan(llm: &Arc<dyn LlmClient>, prompt: &str) -> Result<ExecutionPlan> {
91        let system = crate::prompts::LLM_PLAN_SYSTEM;
92
93        let messages = vec![Message::user(prompt)];
94        let response = llm
95            .complete(&messages, Some(system), &[])
96            .await
97            .context("LLM call failed during plan creation")?;
98
99        let text = response.text();
100        Self::parse_plan_response(&text)
101    }
102
103    /// Extract a goal with success criteria from a prompt using LLM
104    pub async fn extract_goal(llm: &Arc<dyn LlmClient>, prompt: &str) -> Result<AgentGoal> {
105        let system = crate::prompts::LLM_GOAL_EXTRACT_SYSTEM;
106
107        let messages = vec![Message::user(prompt)];
108        let response = llm
109            .complete(&messages, Some(system), &[])
110            .await
111            .context("LLM call failed during goal extraction")?;
112
113        let text = response.text();
114        Self::parse_goal_response(&text)
115    }
116
117    /// Evaluate whether a goal has been achieved given current state
118    pub async fn check_achievement(
119        llm: &Arc<dyn LlmClient>,
120        goal: &AgentGoal,
121        current_state: &str,
122    ) -> Result<AchievementResult> {
123        let system = crate::prompts::LLM_GOAL_CHECK_SYSTEM;
124
125        let user_message = format!(
126            "Goal: {}\nSuccess Criteria: {}\nCurrent State: {}",
127            goal.description,
128            goal.success_criteria.join("; "),
129            current_state,
130        );
131
132        let messages = vec![Message::user(&user_message)];
133        let response = llm
134            .complete(&messages, Some(system), &[])
135            .await
136            .context("LLM call failed during achievement check")?;
137
138        let text = response.text();
139        Self::parse_achievement_response(&text)
140    }
141
142    /// Create a fallback plan using heuristic logic (no LLM required)
143    pub fn fallback_plan(prompt: &str) -> ExecutionPlan {
144        let complexity = if prompt.len() < 50 {
145            Complexity::Simple
146        } else if prompt.len() < 150 {
147            Complexity::Medium
148        } else if prompt.len() < 300 {
149            Complexity::Complex
150        } else {
151            Complexity::VeryComplex
152        };
153
154        let mut plan = ExecutionPlan::new(prompt, complexity);
155
156        let step_count = match complexity {
157            Complexity::Simple => 2,
158            Complexity::Medium => 4,
159            Complexity::Complex => 7,
160            Complexity::VeryComplex => 10,
161        };
162
163        for i in 0..step_count {
164            let step = Task::new(
165                format!("step-{}", i + 1),
166                crate::prompts::render(
167                    crate::prompts::PLAN_FALLBACK_STEP,
168                    &[("step_num", &(i + 1).to_string())],
169                ),
170            );
171            plan.add_step(step);
172        }
173
174        plan
175    }
176
177    /// Create a fallback goal using heuristic logic (no LLM required)
178    pub fn fallback_goal(prompt: &str) -> AgentGoal {
179        AgentGoal::new(prompt).with_criteria(vec![
180            "Task is completed successfully".to_string(),
181            "All requirements are met".to_string(),
182        ])
183    }
184
185    /// Create a fallback achievement result using heuristic logic (no LLM required)
186    pub fn fallback_check_achievement(goal: &AgentGoal, current_state: &str) -> AchievementResult {
187        let state_lower = current_state.to_lowercase();
188        let achieved = state_lower.contains("complete")
189            || state_lower.contains("done")
190            || state_lower.contains("finished");
191
192        let progress = if achieved { 1.0 } else { goal.progress };
193
194        let remaining_criteria = if achieved {
195            Vec::new()
196        } else {
197            goal.success_criteria.clone()
198        };
199
200        AchievementResult {
201            achieved,
202            progress,
203            remaining_criteria,
204        }
205    }
206
207    // ========================================================================
208    // JSON parsing helpers
209    // ========================================================================
210
211    fn parse_plan_response(text: &str) -> Result<ExecutionPlan> {
212        let cleaned = Self::extract_json(text);
213        let parsed: PlanResponse =
214            serde_json::from_str(cleaned).context("Failed to parse plan JSON from LLM response")?;
215
216        let complexity = match parsed.complexity.as_str() {
217            "Simple" => Complexity::Simple,
218            "Medium" => Complexity::Medium,
219            "Complex" => Complexity::Complex,
220            "VeryComplex" => Complexity::VeryComplex,
221            _ => Complexity::Medium,
222        };
223
224        let mut plan = ExecutionPlan::new(parsed.goal, complexity);
225
226        for step_resp in parsed.steps {
227            let mut task = Task::new(step_resp.id, step_resp.description);
228            if let Some(tool) = step_resp.tool {
229                task = task.with_tool(tool);
230            }
231            if !step_resp.dependencies.is_empty() {
232                task = task.with_dependencies(step_resp.dependencies);
233            }
234            if let Some(criteria) = step_resp.success_criteria {
235                task = task.with_success_criteria(criteria);
236            }
237            plan.add_step(task);
238        }
239
240        for tool in parsed.required_tools {
241            plan.add_required_tool(tool);
242        }
243
244        Ok(plan)
245    }
246
247    fn parse_goal_response(text: &str) -> Result<AgentGoal> {
248        let cleaned = Self::extract_json(text);
249        let parsed: GoalResponse =
250            serde_json::from_str(cleaned).context("Failed to parse goal JSON from LLM response")?;
251
252        Ok(AgentGoal::new(parsed.description).with_criteria(parsed.success_criteria))
253    }
254
255    fn parse_achievement_response(text: &str) -> Result<AchievementResult> {
256        let cleaned = Self::extract_json(text);
257        let parsed: AchievementResponse = serde_json::from_str(cleaned)
258            .context("Failed to parse achievement JSON from LLM response")?;
259
260        Ok(AchievementResult {
261            achieved: parsed.achieved,
262            progress: parsed.progress.clamp(0.0, 1.0),
263            remaining_criteria: parsed.remaining_criteria,
264        })
265    }
266
267    /// Extract JSON from LLM text that may contain markdown fences
268    fn extract_json(text: &str) -> &str {
269        let trimmed = text.trim();
270
271        // Strip markdown code fences if present
272        if let Some(start) = trimmed.find('{') {
273            if let Some(end) = trimmed.rfind('}') {
274                if start <= end {
275                    return &trimmed[start..=end];
276                }
277            }
278        }
279
280        trimmed
281    }
282}
283
284// Implement Planner trait for LlmPlanner
285#[async_trait]
286impl Planner for LlmPlanner {
287    async fn create_plan(&self, llm: &Arc<dyn LlmClient>, prompt: &str) -> Result<ExecutionPlan> {
288        LlmPlanner::create_plan(llm, prompt).await
289    }
290
291    async fn extract_goal(&self, llm: &Arc<dyn LlmClient>, prompt: &str) -> Result<AgentGoal> {
292        LlmPlanner::extract_goal(llm, prompt).await
293    }
294
295    async fn check_achievement(
296        &self,
297        llm: &Arc<dyn LlmClient>,
298        goal: &AgentGoal,
299        current_state: &str,
300    ) -> Result<AchievementResult> {
301        LlmPlanner::check_achievement(llm, goal, current_state).await
302    }
303}
304
305// ============================================================================
306// Tests
307// ============================================================================
308
#[cfg(test)]
mod tests {
    use super::*;

    // ---- plan parsing ------------------------------------------------------

    /// A fully populated plan reply is mapped field by field.
    #[test]
    fn test_parse_plan_response() {
        let reply = r#"{
            "goal": "Build a REST API",
            "complexity": "Complex",
            "steps": [
                {
                    "id": "step-1",
                    "description": "Set up project structure",
                    "tool": "bash",
                    "dependencies": [],
                    "success_criteria": "Project directory created"
                },
                {
                    "id": "step-2",
                    "description": "Implement endpoints",
                    "tool": "write",
                    "dependencies": ["step-1"],
                    "success_criteria": "Endpoints respond correctly"
                }
            ],
            "required_tools": ["bash", "write", "read"]
        }"#;

        let parsed = LlmPlanner::parse_plan_response(reply).unwrap();

        assert_eq!(parsed.goal, "Build a REST API");
        assert_eq!(parsed.complexity, Complexity::Complex);
        assert_eq!(parsed.required_tools, vec!["bash", "write", "read"]);

        assert_eq!(parsed.steps.len(), 2);
        let first = &parsed.steps[0];
        assert_eq!(first.id, "step-1");
        assert_eq!(first.tool, Some("bash".to_string()));
        let second = &parsed.steps[1];
        assert_eq!(second.dependencies, vec!["step-1".to_string()]);
    }

    /// Markdown code fences around the JSON do not break parsing.
    #[test]
    fn test_parse_plan_response_with_markdown_fences() {
        let fenced = "```json\n{\"goal\": \"Test\", \"complexity\": \"Simple\", \"steps\": [{\"id\": \"step-1\", \"description\": \"Do it\"}], \"required_tools\": []}\n```";

        let parsed = LlmPlanner::parse_plan_response(fenced).unwrap();

        assert_eq!(parsed.goal, "Test");
        assert_eq!(parsed.complexity, Complexity::Simple);
        assert_eq!(parsed.steps.len(), 1);
    }

    /// Text without any JSON is rejected.
    #[test]
    fn test_parse_plan_response_invalid() {
        let outcome = LlmPlanner::parse_plan_response("This is not JSON at all");
        assert!(outcome.is_err());
    }

    /// An unrecognized complexity label falls back to Medium.
    #[test]
    fn test_parse_plan_response_unknown_complexity() {
        let reply =
            r#"{"goal": "Test", "complexity": "Unknown", "steps": [], "required_tools": []}"#;

        let parsed = LlmPlanner::parse_plan_response(reply).unwrap();

        assert_eq!(parsed.complexity, Complexity::Medium);
    }

    // ---- goal parsing ------------------------------------------------------

    /// Description and criteria are taken over from the goal reply.
    #[test]
    fn test_parse_goal_response() {
        let reply = r#"{
            "description": "Deploy the application to production",
            "success_criteria": [
                "All tests pass",
                "Application is accessible at production URL",
                "Health check returns 200"
            ]
        }"#;

        let parsed = LlmPlanner::parse_goal_response(reply).unwrap();

        assert_eq!(parsed.description, "Deploy the application to production");
        assert_eq!(parsed.success_criteria.len(), 3);
        assert_eq!(parsed.success_criteria[0], "All tests pass");
    }

    /// Non-JSON goal replies are rejected.
    #[test]
    fn test_parse_goal_response_invalid() {
        assert!(LlmPlanner::parse_goal_response("not json").is_err());
    }

    // ---- achievement parsing -----------------------------------------------

    /// A partial-progress reply keeps its progress and remaining criteria.
    #[test]
    fn test_parse_achievement_response() {
        let reply = r#"{
            "achieved": false,
            "progress": 0.65,
            "remaining_criteria": ["Health check not verified"]
        }"#;

        let outcome = LlmPlanner::parse_achievement_response(reply).unwrap();

        assert!(!outcome.achieved);
        assert!((outcome.progress - 0.65).abs() < f32::EPSILON);
        assert_eq!(outcome.remaining_criteria, vec!["Health check not verified"]);
    }

    /// An achieved reply reports full progress and nothing remaining.
    #[test]
    fn test_parse_achievement_response_achieved() {
        let reply = r#"{"achieved": true, "progress": 1.0, "remaining_criteria": []}"#;

        let outcome = LlmPlanner::parse_achievement_response(reply).unwrap();

        assert!(outcome.achieved);
        assert!((outcome.progress - 1.0).abs() < f32::EPSILON);
        assert!(outcome.remaining_criteria.is_empty());
    }

    /// Progress above 1.0 is clamped down to 1.0.
    #[test]
    fn test_parse_achievement_response_clamps_progress() {
        let reply = r#"{"achieved": false, "progress": 1.5, "remaining_criteria": []}"#;

        let outcome = LlmPlanner::parse_achievement_response(reply).unwrap();

        assert!((outcome.progress - 1.0).abs() < f32::EPSILON);
    }

    // ---- heuristic fallbacks -----------------------------------------------

    /// Short prompts give a simple two-step plan, very long ones a ten-step plan.
    #[test]
    fn test_fallback_plan() {
        let brief = "Fix bug";
        let brief_plan = LlmPlanner::fallback_plan(brief);
        assert_eq!(brief_plan.complexity, Complexity::Simple);
        assert_eq!(brief_plan.steps.len(), 2);
        assert_eq!(brief_plan.goal, brief);

        let lengthy = "Implement a comprehensive authentication system with OAuth2 support, JWT tokens, refresh token rotation, multi-factor authentication, and role-based access control across all API endpoints with proper audit logging and session management capabilities for both web and mobile clients, including password reset flows, account lockout policies, and integration with external identity providers such as Google, GitHub, and SAML-based enterprise SSO systems";
        let lengthy_plan = LlmPlanner::fallback_plan(lengthy);
        assert_eq!(lengthy_plan.complexity, Complexity::VeryComplex);
        assert_eq!(lengthy_plan.steps.len(), 10);
    }

    /// The fallback goal reuses the prompt and attaches two generic criteria.
    #[test]
    fn test_fallback_goal() {
        let fallback = LlmPlanner::fallback_goal("Fix the login bug");

        assert_eq!(fallback.description, "Fix the login bug");
        assert_eq!(fallback.success_criteria.len(), 2);
        assert_eq!(fallback.success_criteria[0], "Task is completed successfully");
    }

    /// A state mentioning "done" counts as achieved with nothing remaining.
    #[test]
    fn test_fallback_check_achievement_done() {
        let goal = AgentGoal::new("Test task").with_criteria(vec!["Criterion 1".to_string()]);

        let outcome = LlmPlanner::fallback_check_achievement(&goal, "The task is done.");

        assert!(outcome.achieved);
        assert!((outcome.progress - 1.0).abs() < f32::EPSILON);
        assert!(outcome.remaining_criteria.is_empty());
    }

    /// A state without completion words leaves every criterion open.
    #[test]
    fn test_fallback_check_achievement_not_done() {
        let goal = AgentGoal::new("Test task")
            .with_criteria(vec!["Criterion 1".to_string(), "Criterion 2".to_string()]);

        let outcome = LlmPlanner::fallback_check_achievement(&goal, "Work in progress");

        assert!(!outcome.achieved);
        assert_eq!(outcome.remaining_criteria.len(), 2);
    }

    // ---- JSON extraction ---------------------------------------------------

    /// Surrounding whitespace is trimmed away.
    #[test]
    fn test_extract_json_plain() {
        let extracted = LlmPlanner::extract_json("  {\"a\": 1}  ");
        assert_eq!(extracted, "{\"a\": 1}");
    }

    /// Code fences are dropped, leaving only the object.
    #[test]
    fn test_extract_json_with_fences() {
        let fenced = "```json\n{\"a\": 1}\n```";
        assert_eq!(LlmPlanner::extract_json(fenced), "{\"a\": 1}");
    }

    /// Prose before and after the object is dropped.
    #[test]
    fn test_extract_json_with_surrounding_text() {
        let wrapped = "Here is the plan:\n{\"goal\": \"test\"}\nDone.";
        assert_eq!(LlmPlanner::extract_json(wrapped), "{\"goal\": \"test\"}");
    }
}
483}