Skip to main content

a3s_code_core/planning/
llm_planner.rs

1//! LLM-powered planning logic
2//!
3//! Provides intelligent plan generation, goal extraction, and achievement
4//! evaluation by sending structured prompts to an LLM and parsing JSON responses.
5//! Falls back to heuristic logic when no LLM client is available.
6
7use crate::llm::{LlmClient, Message};
8use crate::planning::{AgentGoal, Complexity, ExecutionPlan, Task};
9use anyhow::{Context, Result};
10use serde::{Deserialize, Serialize};
11use std::sync::Arc;
12
13/// Result of evaluating goal achievement
14#[derive(Debug, Clone, Serialize, Deserialize)]
15pub struct AchievementResult {
16    /// Whether the goal has been achieved
17    pub achieved: bool,
18    /// Progress toward goal (0.0 - 1.0)
19    pub progress: f32,
20    /// Criteria that remain unmet
21    pub remaining_criteria: Vec<String>,
22}
23
/// Stateless planner that generates plans, extracts goals, and evaluates
/// achievement, either through an LLM or through heuristic fallbacks.
///
/// The type carries no data; the operations in this file are associated
/// functions. It derives the common traits so it can be debug-printed,
/// copied, compared, and embedded in types that derive those traits themselves.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
pub struct LlmPlanner;
26
27// ============================================================================
28// JSON response schemas for LLM parsing
29// ============================================================================
30
31#[derive(Debug, Deserialize)]
32struct PlanResponse {
33    goal: String,
34    complexity: String,
35    steps: Vec<StepResponse>,
36    #[serde(default)]
37    required_tools: Vec<String>,
38}
39
40#[derive(Debug, Deserialize)]
41struct StepResponse {
42    id: String,
43    description: String,
44    #[serde(default)]
45    tool: Option<String>,
46    #[serde(default)]
47    dependencies: Vec<String>,
48    #[serde(default)]
49    success_criteria: Option<String>,
50}
51
52#[derive(Debug, Deserialize)]
53struct GoalResponse {
54    description: String,
55    success_criteria: Vec<String>,
56}
57
58#[derive(Debug, Deserialize)]
59struct AchievementResponse {
60    achieved: bool,
61    progress: f32,
62    #[serde(default)]
63    remaining_criteria: Vec<String>,
64}
65
66impl LlmPlanner {
67    /// Generate an execution plan from a prompt using LLM
68    pub async fn create_plan(llm: &Arc<dyn LlmClient>, prompt: &str) -> Result<ExecutionPlan> {
69        let system = crate::prompts::LLM_PLAN_SYSTEM;
70
71        let messages = vec![Message::user(prompt)];
72        let response = llm
73            .complete(&messages, Some(system), &[])
74            .await
75            .context("LLM call failed during plan creation")?;
76
77        let text = response.text();
78        Self::parse_plan_response(&text)
79    }
80
81    /// Extract a goal with success criteria from a prompt using LLM
82    pub async fn extract_goal(llm: &Arc<dyn LlmClient>, prompt: &str) -> Result<AgentGoal> {
83        let system = crate::prompts::LLM_GOAL_EXTRACT_SYSTEM;
84
85        let messages = vec![Message::user(prompt)];
86        let response = llm
87            .complete(&messages, Some(system), &[])
88            .await
89            .context("LLM call failed during goal extraction")?;
90
91        let text = response.text();
92        Self::parse_goal_response(&text)
93    }
94
95    /// Evaluate whether a goal has been achieved given current state
96    pub async fn check_achievement(
97        llm: &Arc<dyn LlmClient>,
98        goal: &AgentGoal,
99        current_state: &str,
100    ) -> Result<AchievementResult> {
101        let system = crate::prompts::LLM_GOAL_CHECK_SYSTEM;
102
103        let user_message = format!(
104            "Goal: {}\nSuccess Criteria: {}\nCurrent State: {}",
105            goal.description,
106            goal.success_criteria.join("; "),
107            current_state,
108        );
109
110        let messages = vec![Message::user(&user_message)];
111        let response = llm
112            .complete(&messages, Some(system), &[])
113            .await
114            .context("LLM call failed during achievement check")?;
115
116        let text = response.text();
117        Self::parse_achievement_response(&text)
118    }
119
120    /// Create a fallback plan using heuristic logic (no LLM required)
121    pub fn fallback_plan(prompt: &str) -> ExecutionPlan {
122        let complexity = if prompt.len() < 50 {
123            Complexity::Simple
124        } else if prompt.len() < 150 {
125            Complexity::Medium
126        } else if prompt.len() < 300 {
127            Complexity::Complex
128        } else {
129            Complexity::VeryComplex
130        };
131
132        let mut plan = ExecutionPlan::new(prompt, complexity);
133
134        let step_count = match complexity {
135            Complexity::Simple => 2,
136            Complexity::Medium => 4,
137            Complexity::Complex => 7,
138            Complexity::VeryComplex => 10,
139        };
140
141        for i in 0..step_count {
142            let step = Task::new(
143                format!("step-{}", i + 1),
144                crate::prompts::render(
145                    crate::prompts::PLAN_FALLBACK_STEP,
146                    &[("step_num", &(i + 1).to_string())],
147                ),
148            );
149            plan.add_step(step);
150        }
151
152        plan
153    }
154
155    /// Create a fallback goal using heuristic logic (no LLM required)
156    pub fn fallback_goal(prompt: &str) -> AgentGoal {
157        AgentGoal::new(prompt).with_criteria(vec![
158            "Task is completed successfully".to_string(),
159            "All requirements are met".to_string(),
160        ])
161    }
162
163    /// Create a fallback achievement result using heuristic logic (no LLM required)
164    pub fn fallback_check_achievement(goal: &AgentGoal, current_state: &str) -> AchievementResult {
165        let state_lower = current_state.to_lowercase();
166        let achieved = state_lower.contains("complete")
167            || state_lower.contains("done")
168            || state_lower.contains("finished");
169
170        let progress = if achieved { 1.0 } else { goal.progress };
171
172        let remaining_criteria = if achieved {
173            Vec::new()
174        } else {
175            goal.success_criteria.clone()
176        };
177
178        AchievementResult {
179            achieved,
180            progress,
181            remaining_criteria,
182        }
183    }
184
185    // ========================================================================
186    // JSON parsing helpers
187    // ========================================================================
188
189    fn parse_plan_response(text: &str) -> Result<ExecutionPlan> {
190        let cleaned = Self::extract_json(text);
191        let parsed: PlanResponse =
192            serde_json::from_str(cleaned).context("Failed to parse plan JSON from LLM response")?;
193
194        let complexity = match parsed.complexity.as_str() {
195            "Simple" => Complexity::Simple,
196            "Medium" => Complexity::Medium,
197            "Complex" => Complexity::Complex,
198            "VeryComplex" => Complexity::VeryComplex,
199            _ => Complexity::Medium,
200        };
201
202        let mut plan = ExecutionPlan::new(parsed.goal, complexity);
203
204        for step_resp in parsed.steps {
205            let mut task = Task::new(step_resp.id, step_resp.description);
206            if let Some(tool) = step_resp.tool {
207                task = task.with_tool(tool);
208            }
209            if !step_resp.dependencies.is_empty() {
210                task = task.with_dependencies(step_resp.dependencies);
211            }
212            if let Some(criteria) = step_resp.success_criteria {
213                task = task.with_success_criteria(criteria);
214            }
215            plan.add_step(task);
216        }
217
218        for tool in parsed.required_tools {
219            plan.add_required_tool(tool);
220        }
221
222        Ok(plan)
223    }
224
225    fn parse_goal_response(text: &str) -> Result<AgentGoal> {
226        let cleaned = Self::extract_json(text);
227        let parsed: GoalResponse =
228            serde_json::from_str(cleaned).context("Failed to parse goal JSON from LLM response")?;
229
230        Ok(AgentGoal::new(parsed.description).with_criteria(parsed.success_criteria))
231    }
232
233    fn parse_achievement_response(text: &str) -> Result<AchievementResult> {
234        let cleaned = Self::extract_json(text);
235        let parsed: AchievementResponse = serde_json::from_str(cleaned)
236            .context("Failed to parse achievement JSON from LLM response")?;
237
238        Ok(AchievementResult {
239            achieved: parsed.achieved,
240            progress: parsed.progress.clamp(0.0, 1.0),
241            remaining_criteria: parsed.remaining_criteria,
242        })
243    }
244
245    /// Extract JSON from LLM text that may contain markdown fences
246    fn extract_json(text: &str) -> &str {
247        let trimmed = text.trim();
248
249        // Strip markdown code fences if present
250        if let Some(start) = trimmed.find('{') {
251            if let Some(end) = trimmed.rfind('}') {
252                if start <= end {
253                    return &trimmed[start..=end];
254                }
255            }
256        }
257
258        trimmed
259    }
260}
261
262// ============================================================================
263// Tests
264// ============================================================================
265
#[cfg(test)]
mod tests {
    use super::*;

    // ------------------------------------------------------------------
    // Plan parsing
    // ------------------------------------------------------------------

    /// A fully populated plan payload maps onto every plan/task field.
    #[test]
    fn test_parse_plan_response() {
        let payload = r#"{
            "goal": "Build a REST API",
            "complexity": "Complex",
            "steps": [
                {
                    "id": "step-1",
                    "description": "Set up project structure",
                    "tool": "bash",
                    "dependencies": [],
                    "success_criteria": "Project directory created"
                },
                {
                    "id": "step-2",
                    "description": "Implement endpoints",
                    "tool": "write",
                    "dependencies": ["step-1"],
                    "success_criteria": "Endpoints respond correctly"
                }
            ],
            "required_tools": ["bash", "write", "read"]
        }"#;

        let parsed = LlmPlanner::parse_plan_response(payload).unwrap();

        assert_eq!(parsed.goal, "Build a REST API");
        assert_eq!(parsed.complexity, Complexity::Complex);
        assert_eq!(parsed.steps.len(), 2);

        let (first, second) = (&parsed.steps[0], &parsed.steps[1]);
        assert_eq!(first.id, "step-1");
        assert_eq!(first.tool, Some("bash".to_string()));
        assert_eq!(second.dependencies, vec!["step-1".to_string()]);

        assert_eq!(parsed.required_tools, vec!["bash", "write", "read"]);
    }

    /// Markdown code fences around the JSON are tolerated.
    #[test]
    fn test_parse_plan_response_with_markdown_fences() {
        let fenced = "```json\n{\"goal\": \"Test\", \"complexity\": \"Simple\", \"steps\": [{\"id\": \"step-1\", \"description\": \"Do it\"}], \"required_tools\": []}\n```";

        let parsed = LlmPlanner::parse_plan_response(fenced).unwrap();

        assert_eq!(parsed.goal, "Test");
        assert_eq!(parsed.complexity, Complexity::Simple);
        assert_eq!(parsed.steps.len(), 1);
    }

    /// Text without any JSON is rejected.
    #[test]
    fn test_parse_plan_response_invalid() {
        let outcome = LlmPlanner::parse_plan_response("This is not JSON at all");
        assert!(outcome.is_err());
    }

    /// An unrecognised complexity label does not fail parsing.
    #[test]
    fn test_parse_plan_response_unknown_complexity() {
        let payload =
            r#"{"goal": "Test", "complexity": "Unknown", "steps": [], "required_tools": []}"#;

        let parsed = LlmPlanner::parse_plan_response(payload).unwrap();

        assert_eq!(parsed.complexity, Complexity::Medium); // falls back to Medium
    }

    // ------------------------------------------------------------------
    // Goal parsing
    // ------------------------------------------------------------------

    /// Description and criteria are carried over from the payload.
    #[test]
    fn test_parse_goal_response() {
        let payload = r#"{
            "description": "Deploy the application to production",
            "success_criteria": [
                "All tests pass",
                "Application is accessible at production URL",
                "Health check returns 200"
            ]
        }"#;

        let parsed = LlmPlanner::parse_goal_response(payload).unwrap();

        assert_eq!(parsed.description, "Deploy the application to production");
        assert_eq!(parsed.success_criteria.len(), 3);
        assert_eq!(parsed.success_criteria[0], "All tests pass");
    }

    /// Text without any JSON is rejected.
    #[test]
    fn test_parse_goal_response_invalid() {
        let outcome = LlmPlanner::parse_goal_response("not json");
        assert!(outcome.is_err());
    }

    // ------------------------------------------------------------------
    // Achievement parsing
    // ------------------------------------------------------------------

    /// A partial-progress payload is passed through unchanged.
    #[test]
    fn test_parse_achievement_response() {
        let payload = r#"{
            "achieved": false,
            "progress": 0.65,
            "remaining_criteria": ["Health check not verified"]
        }"#;

        let outcome = LlmPlanner::parse_achievement_response(payload).unwrap();

        assert!(!outcome.achieved);
        assert!((outcome.progress - 0.65).abs() < f32::EPSILON);
        assert_eq!(
            outcome.remaining_criteria,
            vec!["Health check not verified"]
        );
    }

    /// A completed payload yields full progress and nothing remaining.
    #[test]
    fn test_parse_achievement_response_achieved() {
        let payload = r#"{"achieved": true, "progress": 1.0, "remaining_criteria": []}"#;

        let outcome = LlmPlanner::parse_achievement_response(payload).unwrap();

        assert!(outcome.achieved);
        assert!((outcome.progress - 1.0).abs() < f32::EPSILON);
        assert!(outcome.remaining_criteria.is_empty());
    }

    /// Progress above 1.0 is clamped down to 1.0.
    #[test]
    fn test_parse_achievement_response_clamps_progress() {
        let payload = r#"{"achieved": false, "progress": 1.5, "remaining_criteria": []}"#;

        let outcome = LlmPlanner::parse_achievement_response(payload).unwrap();

        assert!((outcome.progress - 1.0).abs() < f32::EPSILON);
    }

    // ------------------------------------------------------------------
    // Heuristic fallbacks
    // ------------------------------------------------------------------

    /// Short prompts yield a simple 2-step plan; very long ones a 10-step plan.
    #[test]
    fn test_fallback_plan() {
        let brief = "Fix bug";
        let brief_plan = LlmPlanner::fallback_plan(brief);
        assert_eq!(brief_plan.complexity, Complexity::Simple);
        assert_eq!(brief_plan.steps.len(), 2);
        assert_eq!(brief_plan.goal, brief);

        let lengthy = "Implement a comprehensive authentication system with OAuth2 support, JWT tokens, refresh token rotation, multi-factor authentication, and role-based access control across all API endpoints with proper audit logging and session management capabilities for both web and mobile clients, including password reset flows, account lockout policies, and integration with external identity providers such as Google, GitHub, and SAML-based enterprise SSO systems";
        let lengthy_plan = LlmPlanner::fallback_plan(lengthy);
        assert_eq!(lengthy_plan.complexity, Complexity::VeryComplex);
        assert_eq!(lengthy_plan.steps.len(), 10);
    }

    /// The fallback goal echoes the prompt and attaches two generic criteria.
    #[test]
    fn test_fallback_goal() {
        let fallback = LlmPlanner::fallback_goal("Fix the login bug");

        assert_eq!(fallback.description, "Fix the login bug");
        assert_eq!(fallback.success_criteria.len(), 2);
        assert_eq!(
            fallback.success_criteria[0],
            "Task is completed successfully"
        );
    }

    /// A state mentioning "done" counts as achieved.
    #[test]
    fn test_fallback_check_achievement_done() {
        let target = AgentGoal::new("Test task").with_criteria(vec!["Criterion 1".to_string()]);

        let outcome = LlmPlanner::fallback_check_achievement(&target, "The task is done.");

        assert!(outcome.achieved);
        assert!((outcome.progress - 1.0).abs() < f32::EPSILON);
        assert!(outcome.remaining_criteria.is_empty());
    }

    /// A state without completion keywords leaves every criterion outstanding.
    #[test]
    fn test_fallback_check_achievement_not_done() {
        let criteria = vec!["Criterion 1".to_string(), "Criterion 2".to_string()];
        let target = AgentGoal::new("Test task").with_criteria(criteria);

        let outcome = LlmPlanner::fallback_check_achievement(&target, "Work in progress");

        assert!(!outcome.achieved);
        assert_eq!(outcome.remaining_criteria.len(), 2);
    }

    // ------------------------------------------------------------------
    // JSON extraction
    // ------------------------------------------------------------------

    /// Surrounding whitespace is removed.
    #[test]
    fn test_extract_json_plain() {
        let extracted = LlmPlanner::extract_json("  {\"a\": 1}  ");
        assert_eq!(extracted, "{\"a\": 1}");
    }

    /// Markdown fences are dropped.
    #[test]
    fn test_extract_json_with_fences() {
        let fenced = "```json\n{\"a\": 1}\n```";
        assert_eq!(LlmPlanner::extract_json(fenced), "{\"a\": 1}");
    }

    /// Prose before and after the object is dropped.
    #[test]
    fn test_extract_json_with_surrounding_text() {
        let wrapped = "Here is the plan:\n{\"goal\": \"test\"}\nDone.";
        assert_eq!(LlmPlanner::extract_json(wrapped), "{\"goal\": \"test\"}");
    }
}