1use crate::llm::{LlmClient, Message};
8use crate::planning::{AgentGoal, Complexity, ExecutionPlan, Task};
9use anyhow::{Context, Result};
10use serde::{Deserialize, Serialize};
11use std::sync::Arc;
12
/// Outcome of checking an [`AgentGoal`] against a description of the current state.
///
/// Returned by [`LlmPlanner::check_achievement`] (LLM-based) and
/// [`LlmPlanner::fallback_check_achievement`] (keyword heuristic).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AchievementResult {
    /// Whether the goal is judged to be met.
    pub achieved: bool,
    /// Estimated progress towards the goal; the LLM path clamps it to `0.0..=1.0`.
    pub progress: f32,
    /// Success criteria judged not yet satisfied.
    pub remaining_criteria: Vec<String>,
}
23
/// Combined result of one pre-analysis LLM call ([`LlmPlanner::pre_analyze`]):
/// intent, goal, plan and a rewritten prompt.
#[derive(Debug, Clone)]
pub struct PreAnalysis {
    /// Agent style derived from the LLM's intent label; unrecognised labels map to
    /// `GeneralPurpose`.
    pub intent: crate::prompts::AgentStyle,
    /// Whether the LLM judged that the request needs explicit planning.
    pub requires_planning: bool,
    /// Goal description and its success criteria.
    pub goal: AgentGoal,
    /// Execution plan; its goal text is the same description as `goal`.
    pub execution_plan: ExecutionPlan,
    /// Prompt as rewritten by the LLM, or the caller's original prompt when the
    /// rewrite came back empty.
    pub optimized_input: String,
}
34
/// Stateless namespace for LLM-backed planning, goal extraction and goal checking,
/// plus deterministic fallbacks that need no LLM.
///
/// All functionality lives in associated functions; the unit value carries no data,
/// so it is trivially `Copy` and `Default`.
#[derive(Debug, Clone, Copy, Default)]
pub struct LlmPlanner;
37
/// Raw JSON shape of the plan-creation reply (system prompt `LLM_PLAN_SYSTEM`).
#[derive(Debug, Deserialize)]
struct PlanResponse {
    /// Goal text for the resulting `ExecutionPlan`.
    goal: String,
    /// Complexity label; values the parser does not recognise become `Medium`.
    complexity: String,
    /// Plan steps, added to the plan in this order.
    steps: Vec<StepResponse>,
    /// Tools the plan needs; may be omitted from the JSON (defaults to empty).
    #[serde(default)]
    required_tools: Vec<String>,
}
50
/// One plan step as emitted by the LLM; converted into a [`Task`] by the parsers.
#[derive(Debug, Deserialize)]
struct StepResponse {
    /// Step identifier (e.g. `"step-1"`); other steps' `dependencies` refer to these IDs.
    id: String,
    /// Human-readable description of the step.
    description: String,
    /// Tool the step should use, if any.
    #[serde(default)]
    tool: Option<String>,
    /// IDs of steps this one depends on; may be omitted (defaults to empty).
    #[serde(default)]
    dependencies: Vec<String>,
    /// Free-text success criterion for the step, if provided.
    #[serde(default)]
    success_criteria: Option<String>,
}
62
/// Raw JSON shape of a goal: used by the goal-extraction reply and nested inside
/// the pre-analysis reply.
#[derive(Debug, Deserialize)]
struct GoalResponse {
    /// Goal description text.
    description: String,
    /// Criteria that define success; required in the JSON.
    success_criteria: Vec<String>,
}
68
/// Raw JSON shape of the goal-achievement reply (system prompt `LLM_GOAL_CHECK_SYSTEM`).
#[derive(Debug, Deserialize)]
struct AchievementResponse {
    /// Whether the LLM considers the goal met.
    achieved: bool,
    /// Progress exactly as reported; clamped to `0.0..=1.0` when converted.
    progress: f32,
    /// Outstanding criteria; may be omitted (defaults to empty).
    #[serde(default)]
    remaining_criteria: Vec<String>,
}
76
77#[derive(Debug, Deserialize)]
78struct PreAnalysisResponse {
79 intent: String,
80 requires_planning: bool,
81 goal: GoalResponse,
82 execution_plan: PreAnalysisPlan,
83 optimized_input: String,
84}
85
/// Plan portion of the pre-analysis reply; the goal text comes from the sibling
/// `goal` object instead of a field here.
#[derive(Debug, Deserialize)]
struct PreAnalysisPlan {
    /// Complexity label; values the parser does not recognise become `Medium`.
    complexity: String,
    /// Plan steps, added to the plan in this order.
    steps: Vec<StepResponse>,
    /// Tools the plan needs; may be omitted (defaults to empty).
    #[serde(default)]
    required_tools: Vec<String>,
}
93
94impl LlmPlanner {
95 pub async fn create_plan(llm: &Arc<dyn LlmClient>, prompt: &str) -> Result<ExecutionPlan> {
97 let system = crate::prompts::LLM_PLAN_SYSTEM;
98
99 let messages = vec![Message::user(prompt)];
100 let response = llm
101 .complete(&messages, Some(system), &[])
102 .await
103 .context("LLM call failed during plan creation")?;
104
105 let text = response.text();
106 Self::parse_plan_response(&text)
107 }
108
109 pub async fn extract_goal(llm: &Arc<dyn LlmClient>, prompt: &str) -> Result<AgentGoal> {
111 let system = crate::prompts::LLM_GOAL_EXTRACT_SYSTEM;
112
113 let messages = vec![Message::user(prompt)];
114 let response = llm
115 .complete(&messages, Some(system), &[])
116 .await
117 .context("LLM call failed during goal extraction")?;
118
119 let text = response.text();
120 Self::parse_goal_response(&text)
121 }
122
123 pub async fn check_achievement(
125 llm: &Arc<dyn LlmClient>,
126 goal: &AgentGoal,
127 current_state: &str,
128 ) -> Result<AchievementResult> {
129 let system = crate::prompts::LLM_GOAL_CHECK_SYSTEM;
130
131 let user_message = format!(
132 "Goal: {}\nSuccess Criteria: {}\nCurrent State: {}",
133 goal.description,
134 goal.success_criteria.join("; "),
135 current_state,
136 );
137
138 let messages = vec![Message::user(&user_message)];
139 let response = llm
140 .complete(&messages, Some(system), &[])
141 .await
142 .context("LLM call failed during achievement check")?;
143
144 let text = response.text();
145 Self::parse_achievement_response(&text)
146 }
147
148 pub fn fallback_plan(prompt: &str) -> ExecutionPlan {
150 let complexity = if prompt.len() < 50 {
151 Complexity::Simple
152 } else if prompt.len() < 150 {
153 Complexity::Medium
154 } else if prompt.len() < 300 {
155 Complexity::Complex
156 } else {
157 Complexity::VeryComplex
158 };
159
160 let mut plan = ExecutionPlan::new(prompt, complexity);
161
162 let step_count = match complexity {
163 Complexity::Simple => 2,
164 Complexity::Medium => 4,
165 Complexity::Complex => 7,
166 Complexity::VeryComplex => 10,
167 };
168
169 for i in 0..step_count {
170 let step = Task::new(
171 format!("step-{}", i + 1),
172 crate::prompts::render(
173 crate::prompts::PLAN_FALLBACK_STEP,
174 &[("step_num", &(i + 1).to_string())],
175 ),
176 );
177 plan.add_step(step);
178 }
179
180 plan
181 }
182
183 pub fn fallback_goal(prompt: &str) -> AgentGoal {
185 AgentGoal::new(prompt).with_criteria(vec![
186 "Task is completed successfully".to_string(),
187 "All requirements are met".to_string(),
188 ])
189 }
190
191 pub fn fallback_check_achievement(goal: &AgentGoal, current_state: &str) -> AchievementResult {
193 let state_lower = current_state.to_lowercase();
194 let achieved = state_lower.contains("complete")
195 || state_lower.contains("done")
196 || state_lower.contains("finished");
197
198 let progress = if achieved { 1.0 } else { goal.progress };
199
200 let remaining_criteria = if achieved {
201 Vec::new()
202 } else {
203 goal.success_criteria.clone()
204 };
205
206 AchievementResult {
207 achieved,
208 progress,
209 remaining_criteria,
210 }
211 }
212
213 pub async fn pre_analyze(llm: &Arc<dyn LlmClient>, prompt: &str) -> Result<PreAnalysis> {
216 let system = crate::prompts::PRE_ANALYSIS_SYSTEM;
217
218 let messages = vec![Message::user(prompt)];
219 let response = llm
220 .complete(&messages, Some(system), &[])
221 .await
222 .context("LLM pre-analysis call failed")?;
223
224 let text = response.text();
225 Self::parse_pre_analysis_response(&text, prompt)
226 }
227
228 fn parse_pre_analysis_response(text: &str, original_prompt: &str) -> Result<PreAnalysis> {
229 let cleaned = Self::extract_json(text);
230 let parsed: PreAnalysisResponse = serde_json::from_str(cleaned)
231 .context("Failed to parse pre-analysis JSON from LLM response")?;
232
233 let intent = match parsed.intent.to_lowercase().as_str() {
234 "plan" => crate::prompts::AgentStyle::Plan,
235 "explore" => crate::prompts::AgentStyle::Explore,
236 "verification" => crate::prompts::AgentStyle::Verification,
237 "codereview" | "code review" => crate::prompts::AgentStyle::CodeReview,
238 _ => crate::prompts::AgentStyle::GeneralPurpose,
239 };
240
241 let goal_description = parsed.goal.description.clone();
242 let goal =
243 AgentGoal::new(goal_description.clone()).with_criteria(parsed.goal.success_criteria);
244
245 let complexity = match parsed.execution_plan.complexity.as_str() {
246 "Simple" => Complexity::Simple,
247 "Medium" => Complexity::Medium,
248 "Complex" => Complexity::Complex,
249 "VeryComplex" => Complexity::VeryComplex,
250 _ => Complexity::Medium,
251 };
252
253 let mut plan = ExecutionPlan::new(goal_description, complexity);
254 for step_resp in parsed.execution_plan.steps {
255 let mut task = Task::new(step_resp.id, step_resp.description);
256 if let Some(tool) = step_resp.tool {
257 task = task.with_tool(tool);
258 }
259 if !step_resp.dependencies.is_empty() {
260 task = task.with_dependencies(step_resp.dependencies);
261 }
262 if let Some(criteria) = step_resp.success_criteria {
263 task = task.with_success_criteria(criteria);
264 }
265 plan.add_step(task);
266 }
267 for tool in parsed.execution_plan.required_tools {
268 plan.add_required_tool(tool);
269 }
270
271 Ok(PreAnalysis {
272 intent,
273 requires_planning: parsed.requires_planning,
274 goal,
275 execution_plan: plan,
276 optimized_input: if parsed.optimized_input.is_empty() {
277 original_prompt.to_string()
278 } else {
279 parsed.optimized_input
280 },
281 })
282 }
283
284 fn parse_plan_response(text: &str) -> Result<ExecutionPlan> {
289 let cleaned = Self::extract_json(text);
290 let parsed: PlanResponse =
291 serde_json::from_str(cleaned).context("Failed to parse plan JSON from LLM response")?;
292
293 let complexity = match parsed.complexity.as_str() {
294 "Simple" => Complexity::Simple,
295 "Medium" => Complexity::Medium,
296 "Complex" => Complexity::Complex,
297 "VeryComplex" => Complexity::VeryComplex,
298 _ => Complexity::Medium,
299 };
300
301 let mut plan = ExecutionPlan::new(parsed.goal, complexity);
302
303 for step_resp in parsed.steps {
304 let mut task = Task::new(step_resp.id, step_resp.description);
305 if let Some(tool) = step_resp.tool {
306 task = task.with_tool(tool);
307 }
308 if !step_resp.dependencies.is_empty() {
309 task = task.with_dependencies(step_resp.dependencies);
310 }
311 if let Some(criteria) = step_resp.success_criteria {
312 task = task.with_success_criteria(criteria);
313 }
314 plan.add_step(task);
315 }
316
317 for tool in parsed.required_tools {
318 plan.add_required_tool(tool);
319 }
320
321 Ok(plan)
322 }
323
324 fn parse_goal_response(text: &str) -> Result<AgentGoal> {
325 let cleaned = Self::extract_json(text);
326 let parsed: GoalResponse =
327 serde_json::from_str(cleaned).context("Failed to parse goal JSON from LLM response")?;
328
329 Ok(AgentGoal::new(parsed.description).with_criteria(parsed.success_criteria))
330 }
331
332 fn parse_achievement_response(text: &str) -> Result<AchievementResult> {
333 let cleaned = Self::extract_json(text);
334 let parsed: AchievementResponse = serde_json::from_str(cleaned)
335 .context("Failed to parse achievement JSON from LLM response")?;
336
337 Ok(AchievementResult {
338 achieved: parsed.achieved,
339 progress: parsed.progress.clamp(0.0, 1.0),
340 remaining_criteria: parsed.remaining_criteria,
341 })
342 }
343
344 fn extract_json(text: &str) -> &str {
346 let trimmed = text.trim();
347
348 if let Some(start) = trimmed.find('{') {
350 if let Some(end) = trimmed.rfind('}') {
351 if start <= end {
352 return &trimmed[start..=end];
353 }
354 }
355 }
356
357 trimmed
358 }
359}
360
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_parse_plan_response() {
        let json = r#"{
            "goal": "Build a REST API",
            "complexity": "Complex",
            "steps": [
                {
                    "id": "step-1",
                    "description": "Set up project structure",
                    "tool": "bash",
                    "dependencies": [],
                    "success_criteria": "Project directory created"
                },
                {
                    "id": "step-2",
                    "description": "Implement endpoints",
                    "tool": "write",
                    "dependencies": ["step-1"],
                    "success_criteria": "Endpoints respond correctly"
                }
            ],
            "required_tools": ["bash", "write", "read"]
        }"#;

        let plan = LlmPlanner::parse_plan_response(json).unwrap();
        assert_eq!(plan.goal, "Build a REST API");
        assert_eq!(plan.complexity, Complexity::Complex);
        assert_eq!(plan.steps.len(), 2);
        assert_eq!(plan.steps[0].id, "step-1");
        assert_eq!(plan.steps[0].tool, Some("bash".to_string()));
        assert_eq!(plan.steps[1].dependencies, vec!["step-1".to_string()]);
        assert_eq!(plan.required_tools, vec!["bash", "write", "read"]);
    }

    #[test]
    fn test_parse_plan_response_with_markdown_fences() {
        let json = "```json\n{\"goal\": \"Test\", \"complexity\": \"Simple\", \"steps\": [{\"id\": \"step-1\", \"description\": \"Do it\"}], \"required_tools\": []}\n```";

        let plan = LlmPlanner::parse_plan_response(json).unwrap();
        assert_eq!(plan.goal, "Test");
        assert_eq!(plan.complexity, Complexity::Simple);
        assert_eq!(plan.steps.len(), 1);
    }

    #[test]
    fn test_parse_plan_response_invalid() {
        let bad_json = "This is not JSON at all";
        let result = LlmPlanner::parse_plan_response(bad_json);
        assert!(result.is_err());
    }

    #[test]
    fn test_parse_plan_response_unknown_complexity() {
        let json =
            r#"{"goal": "Test", "complexity": "Unknown", "steps": [], "required_tools": []}"#;
        let plan = LlmPlanner::parse_plan_response(json).unwrap();
        assert_eq!(plan.complexity, Complexity::Medium);
    }

    #[test]
    fn test_parse_goal_response() {
        let json = r#"{
            "description": "Deploy the application to production",
            "success_criteria": [
                "All tests pass",
                "Application is accessible at production URL",
                "Health check returns 200"
            ]
        }"#;

        let goal = LlmPlanner::parse_goal_response(json).unwrap();
        assert_eq!(goal.description, "Deploy the application to production");
        assert_eq!(goal.success_criteria.len(), 3);
        assert_eq!(goal.success_criteria[0], "All tests pass");
    }

    #[test]
    fn test_parse_goal_response_invalid() {
        let result = LlmPlanner::parse_goal_response("not json");
        assert!(result.is_err());
    }

    #[test]
    fn test_parse_achievement_response() {
        let json = r#"{
            "achieved": false,
            "progress": 0.65,
            "remaining_criteria": ["Health check not verified"]
        }"#;

        let result = LlmPlanner::parse_achievement_response(json).unwrap();
        assert!(!result.achieved);
        assert!((result.progress - 0.65).abs() < f32::EPSILON);
        assert_eq!(result.remaining_criteria, vec!["Health check not verified"]);
    }

    #[test]
    fn test_parse_achievement_response_achieved() {
        let json = r#"{"achieved": true, "progress": 1.0, "remaining_criteria": []}"#;
        let result = LlmPlanner::parse_achievement_response(json).unwrap();
        assert!(result.achieved);
        assert!((result.progress - 1.0).abs() < f32::EPSILON);
        assert!(result.remaining_criteria.is_empty());
    }

    #[test]
    fn test_parse_achievement_response_clamps_progress() {
        let json = r#"{"achieved": false, "progress": 1.5, "remaining_criteria": []}"#;
        let result = LlmPlanner::parse_achievement_response(json).unwrap();
        assert!((result.progress - 1.0).abs() < f32::EPSILON);
    }

    #[test]
    fn test_parse_achievement_response_clamps_negative_progress() {
        let json = r#"{"achieved": false, "progress": -0.5, "remaining_criteria": ["x"]}"#;
        let result = LlmPlanner::parse_achievement_response(json).unwrap();
        assert!(result.progress.abs() < f32::EPSILON);
        assert_eq!(result.remaining_criteria, vec!["x"]);
    }

    #[test]
    fn test_parse_pre_analysis_response() {
        let json = r#"{
            "intent": "plan",
            "requires_planning": true,
            "goal": {
                "description": "Add response caching",
                "success_criteria": ["Cache hit rate above 80%"]
            },
            "execution_plan": {
                "complexity": "Medium",
                "steps": [
                    {"id": "step-1", "description": "Profile hot endpoints"},
                    {
                        "id": "step-2",
                        "description": "Add cache layer",
                        "tool": "write",
                        "dependencies": ["step-1"]
                    }
                ],
                "required_tools": ["write"]
            },
            "optimized_input": "Add an LRU cache in front of the hot endpoints"
        }"#;

        let analysis = LlmPlanner::parse_pre_analysis_response(json, "original prompt").unwrap();
        assert!(matches!(analysis.intent, crate::prompts::AgentStyle::Plan));
        assert!(analysis.requires_planning);
        assert_eq!(analysis.goal.description, "Add response caching");
        assert_eq!(analysis.goal.success_criteria.len(), 1);
        assert_eq!(analysis.goal.success_criteria[0], "Cache hit rate above 80%");
        assert_eq!(analysis.execution_plan.goal, "Add response caching");
        assert_eq!(analysis.execution_plan.complexity, Complexity::Medium);
        assert_eq!(analysis.execution_plan.steps.len(), 2);
        assert_eq!(analysis.execution_plan.steps[1].tool, Some("write".to_string()));
        assert_eq!(
            analysis.execution_plan.steps[1].dependencies,
            vec!["step-1".to_string()]
        );
        assert_eq!(analysis.execution_plan.required_tools, vec!["write"]);
        assert_eq!(
            analysis.optimized_input,
            "Add an LRU cache in front of the hot endpoints"
        );
    }

    #[test]
    fn test_parse_pre_analysis_response_empty_optimized_input_uses_prompt() {
        let json = r#"{
            "intent": "explore",
            "requires_planning": false,
            "goal": {"description": "Find config", "success_criteria": []},
            "execution_plan": {"complexity": "Simple", "steps": []},
            "optimized_input": ""
        }"#;

        let analysis =
            LlmPlanner::parse_pre_analysis_response(json, "where is the config?").unwrap();
        assert!(matches!(analysis.intent, crate::prompts::AgentStyle::Explore));
        assert!(!analysis.requires_planning);
        assert!(analysis.execution_plan.steps.is_empty());
        assert_eq!(analysis.optimized_input, "where is the config?");
    }

    #[test]
    fn test_parse_pre_analysis_response_unknown_intent_is_general_purpose() {
        let json = r#"{
            "intent": "something-else",
            "requires_planning": false,
            "goal": {"description": "Chat", "success_criteria": []},
            "execution_plan": {"complexity": "Simple", "steps": []},
            "optimized_input": "Chat"
        }"#;

        let analysis = LlmPlanner::parse_pre_analysis_response(json, "hi").unwrap();
        assert!(matches!(
            analysis.intent,
            crate::prompts::AgentStyle::GeneralPurpose
        ));
    }

    #[test]
    fn test_parse_pre_analysis_response_invalid() {
        assert!(LlmPlanner::parse_pre_analysis_response("no json", "prompt").is_err());
    }

    #[test]
    fn test_fallback_plan() {
        let short_prompt = "Fix bug";
        let plan = LlmPlanner::fallback_plan(short_prompt);
        assert_eq!(plan.complexity, Complexity::Simple);
        assert_eq!(plan.steps.len(), 2);
        assert_eq!(plan.goal, short_prompt);

        let long_prompt = "Implement a comprehensive authentication system with OAuth2 support, JWT tokens, refresh token rotation, multi-factor authentication, and role-based access control across all API endpoints with proper audit logging and session management capabilities for both web and mobile clients, including password reset flows, account lockout policies, and integration with external identity providers such as Google, GitHub, and SAML-based enterprise SSO systems";
        let plan = LlmPlanner::fallback_plan(long_prompt);
        assert_eq!(plan.complexity, Complexity::VeryComplex);
        assert_eq!(plan.steps.len(), 10);
    }

    #[test]
    fn test_fallback_goal() {
        let goal = LlmPlanner::fallback_goal("Fix the login bug");
        assert_eq!(goal.description, "Fix the login bug");
        assert_eq!(goal.success_criteria.len(), 2);
        assert_eq!(goal.success_criteria[0], "Task is completed successfully");
    }

    #[test]
    fn test_fallback_check_achievement_done() {
        let goal = AgentGoal::new("Test task").with_criteria(vec!["Criterion 1".to_string()]);

        let result = LlmPlanner::fallback_check_achievement(&goal, "The task is done.");
        assert!(result.achieved);
        assert!((result.progress - 1.0).abs() < f32::EPSILON);
        assert!(result.remaining_criteria.is_empty());
    }

    #[test]
    fn test_fallback_check_achievement_not_done() {
        let goal = AgentGoal::new("Test task")
            .with_criteria(vec!["Criterion 1".to_string(), "Criterion 2".to_string()]);

        let result = LlmPlanner::fallback_check_achievement(&goal, "Work in progress");
        assert!(!result.achieved);
        assert_eq!(result.remaining_criteria.len(), 2);
    }

    #[test]
    fn test_extract_json_plain() {
        assert_eq!(LlmPlanner::extract_json(" {\"a\": 1} "), "{\"a\": 1}");
    }

    #[test]
    fn test_extract_json_with_fences() {
        let text = "```json\n{\"a\": 1}\n```";
        assert_eq!(LlmPlanner::extract_json(text), "{\"a\": 1}");
    }

    #[test]
    fn test_extract_json_with_surrounding_text() {
        let text = "Here is the plan:\n{\"goal\": \"test\"}\nDone.";
        assert_eq!(LlmPlanner::extract_json(text), "{\"goal\": \"test\"}");
    }

    #[test]
    fn test_extract_json_without_braces_returns_trimmed_text() {
        assert_eq!(LlmPlanner::extract_json("  no json here  "), "no json here");
    }

    #[test]
    fn test_extract_json_with_reversed_braces_returns_trimmed_text() {
        assert_eq!(LlmPlanner::extract_json(" } then { "), "} then {");
    }
}
539}