1use crate::llm::{LlmClient, Message};
8use crate::planning::{AgentGoal, Complexity, ExecutionPlan, Task};
9use anyhow::{Context, Result};
10use async_trait::async_trait;
11use serde::{Deserialize, Serialize};
12use std::sync::Arc;
13
14#[derive(Debug, Clone, Serialize, Deserialize)]
16pub struct AchievementResult {
17 pub achieved: bool,
19 pub progress: f32,
21 pub remaining_criteria: Vec<String>,
23}
24
25#[derive(Debug, Clone)]
27pub struct PreAnalysis {
28 pub intent: crate::prompts::AgentStyle,
29 pub requires_planning: bool,
30 pub goal: AgentGoal,
31 pub execution_plan: ExecutionPlan,
32 pub optimized_input: String,
34}
35
36#[async_trait]
41pub trait Planner: Send + Sync {
42 async fn create_plan(&self, llm: &Arc<dyn LlmClient>, prompt: &str) -> Result<ExecutionPlan>;
44
45 async fn extract_goal(&self, llm: &Arc<dyn LlmClient>, prompt: &str) -> Result<AgentGoal>;
47
48 async fn check_achievement(
50 &self,
51 llm: &Arc<dyn LlmClient>,
52 goal: &AgentGoal,
53 current_state: &str,
54 ) -> Result<AchievementResult>;
55}
56
/// Stateless planner that delegates goal extraction, planning and achievement checks to an
/// LLM, with keyword/length-based fallbacks for when the LLM is unavailable.
#[derive(Debug, Clone, Copy, Default)]
pub struct LlmPlanner;
59
60#[derive(Debug, Deserialize)]
65struct PlanResponse {
66 goal: String,
67 complexity: String,
68 steps: Vec<StepResponse>,
69 #[serde(default)]
70 required_tools: Vec<String>,
71}
72
73#[derive(Debug, Deserialize)]
74struct StepResponse {
75 id: String,
76 description: String,
77 #[serde(default)]
78 tool: Option<String>,
79 #[serde(default)]
80 dependencies: Vec<String>,
81 #[serde(default)]
82 success_criteria: Option<String>,
83}
84
85#[derive(Debug, Deserialize)]
86struct GoalResponse {
87 description: String,
88 success_criteria: Vec<String>,
89}
90
91#[derive(Debug, Deserialize)]
92struct AchievementResponse {
93 achieved: bool,
94 progress: f32,
95 #[serde(default)]
96 remaining_criteria: Vec<String>,
97}
98
99#[derive(Debug, Deserialize)]
100struct PreAnalysisResponse {
101 intent: String,
102 requires_planning: bool,
103 goal: GoalResponse,
104 execution_plan: PreAnalysisPlan,
105 optimized_input: String,
106}
107
108#[derive(Debug, Deserialize)]
109struct PreAnalysisPlan {
110 complexity: String,
111 steps: Vec<StepResponse>,
112 #[serde(default)]
113 required_tools: Vec<String>,
114}
115
116impl LlmPlanner {
117 pub async fn create_plan(llm: &Arc<dyn LlmClient>, prompt: &str) -> Result<ExecutionPlan> {
119 let system = crate::prompts::LLM_PLAN_SYSTEM;
120
121 let messages = vec![Message::user(prompt)];
122 let response = llm
123 .complete(&messages, Some(system), &[])
124 .await
125 .context("LLM call failed during plan creation")?;
126
127 let text = response.text();
128 Self::parse_plan_response(&text)
129 }
130
131 pub async fn extract_goal(llm: &Arc<dyn LlmClient>, prompt: &str) -> Result<AgentGoal> {
133 let system = crate::prompts::LLM_GOAL_EXTRACT_SYSTEM;
134
135 let messages = vec![Message::user(prompt)];
136 let response = llm
137 .complete(&messages, Some(system), &[])
138 .await
139 .context("LLM call failed during goal extraction")?;
140
141 let text = response.text();
142 Self::parse_goal_response(&text)
143 }
144
145 pub async fn check_achievement(
147 llm: &Arc<dyn LlmClient>,
148 goal: &AgentGoal,
149 current_state: &str,
150 ) -> Result<AchievementResult> {
151 let system = crate::prompts::LLM_GOAL_CHECK_SYSTEM;
152
153 let user_message = format!(
154 "Goal: {}\nSuccess Criteria: {}\nCurrent State: {}",
155 goal.description,
156 goal.success_criteria.join("; "),
157 current_state,
158 );
159
160 let messages = vec![Message::user(&user_message)];
161 let response = llm
162 .complete(&messages, Some(system), &[])
163 .await
164 .context("LLM call failed during achievement check")?;
165
166 let text = response.text();
167 Self::parse_achievement_response(&text)
168 }
169
170 pub fn fallback_plan(prompt: &str) -> ExecutionPlan {
172 let complexity = if prompt.len() < 50 {
173 Complexity::Simple
174 } else if prompt.len() < 150 {
175 Complexity::Medium
176 } else if prompt.len() < 300 {
177 Complexity::Complex
178 } else {
179 Complexity::VeryComplex
180 };
181
182 let mut plan = ExecutionPlan::new(prompt, complexity);
183
184 let step_count = match complexity {
185 Complexity::Simple => 2,
186 Complexity::Medium => 4,
187 Complexity::Complex => 7,
188 Complexity::VeryComplex => 10,
189 };
190
191 for i in 0..step_count {
192 let step = Task::new(
193 format!("step-{}", i + 1),
194 crate::prompts::render(
195 crate::prompts::PLAN_FALLBACK_STEP,
196 &[("step_num", &(i + 1).to_string())],
197 ),
198 );
199 plan.add_step(step);
200 }
201
202 plan
203 }
204
205 pub fn fallback_goal(prompt: &str) -> AgentGoal {
207 AgentGoal::new(prompt).with_criteria(vec![
208 "Task is completed successfully".to_string(),
209 "All requirements are met".to_string(),
210 ])
211 }
212
213 pub fn fallback_check_achievement(goal: &AgentGoal, current_state: &str) -> AchievementResult {
215 let state_lower = current_state.to_lowercase();
216 let achieved = state_lower.contains("complete")
217 || state_lower.contains("done")
218 || state_lower.contains("finished");
219
220 let progress = if achieved { 1.0 } else { goal.progress };
221
222 let remaining_criteria = if achieved {
223 Vec::new()
224 } else {
225 goal.success_criteria.clone()
226 };
227
228 AchievementResult {
229 achieved,
230 progress,
231 remaining_criteria,
232 }
233 }
234
235 pub async fn pre_analyze(llm: &Arc<dyn LlmClient>, prompt: &str) -> Result<PreAnalysis> {
238 let system = crate::prompts::PRE_ANALYSIS_SYSTEM;
239
240 let messages = vec![Message::user(prompt)];
241 let response = llm
242 .complete(&messages, Some(system), &[])
243 .await
244 .context("LLM pre-analysis call failed")?;
245
246 let text = response.text();
247 Self::parse_pre_analysis_response(&text, prompt)
248 }
249
250 fn parse_pre_analysis_response(text: &str, original_prompt: &str) -> Result<PreAnalysis> {
251 let cleaned = Self::extract_json(text);
252 let parsed: PreAnalysisResponse = serde_json::from_str(cleaned)
253 .context("Failed to parse pre-analysis JSON from LLM response")?;
254
255 let intent = match parsed.intent.to_lowercase().as_str() {
256 "plan" => crate::prompts::AgentStyle::Plan,
257 "explore" => crate::prompts::AgentStyle::Explore,
258 "verification" => crate::prompts::AgentStyle::Verification,
259 "codereview" | "code review" => crate::prompts::AgentStyle::CodeReview,
260 _ => crate::prompts::AgentStyle::GeneralPurpose,
261 };
262
263 let goal_description = parsed.goal.description.clone();
264 let goal =
265 AgentGoal::new(goal_description.clone()).with_criteria(parsed.goal.success_criteria);
266
267 let complexity = match parsed.execution_plan.complexity.as_str() {
268 "Simple" => Complexity::Simple,
269 "Medium" => Complexity::Medium,
270 "Complex" => Complexity::Complex,
271 "VeryComplex" => Complexity::VeryComplex,
272 _ => Complexity::Medium,
273 };
274
275 let mut plan = ExecutionPlan::new(goal_description, complexity);
276 for step_resp in parsed.execution_plan.steps {
277 let mut task = Task::new(step_resp.id, step_resp.description);
278 if let Some(tool) = step_resp.tool {
279 task = task.with_tool(tool);
280 }
281 if !step_resp.dependencies.is_empty() {
282 task = task.with_dependencies(step_resp.dependencies);
283 }
284 if let Some(criteria) = step_resp.success_criteria {
285 task = task.with_success_criteria(criteria);
286 }
287 plan.add_step(task);
288 }
289 for tool in parsed.execution_plan.required_tools {
290 plan.add_required_tool(tool);
291 }
292
293 Ok(PreAnalysis {
294 intent,
295 requires_planning: parsed.requires_planning,
296 goal,
297 execution_plan: plan,
298 optimized_input: if parsed.optimized_input.is_empty() {
299 original_prompt.to_string()
300 } else {
301 parsed.optimized_input
302 },
303 })
304 }
305
306 fn parse_plan_response(text: &str) -> Result<ExecutionPlan> {
311 let cleaned = Self::extract_json(text);
312 let parsed: PlanResponse =
313 serde_json::from_str(cleaned).context("Failed to parse plan JSON from LLM response")?;
314
315 let complexity = match parsed.complexity.as_str() {
316 "Simple" => Complexity::Simple,
317 "Medium" => Complexity::Medium,
318 "Complex" => Complexity::Complex,
319 "VeryComplex" => Complexity::VeryComplex,
320 _ => Complexity::Medium,
321 };
322
323 let mut plan = ExecutionPlan::new(parsed.goal, complexity);
324
325 for step_resp in parsed.steps {
326 let mut task = Task::new(step_resp.id, step_resp.description);
327 if let Some(tool) = step_resp.tool {
328 task = task.with_tool(tool);
329 }
330 if !step_resp.dependencies.is_empty() {
331 task = task.with_dependencies(step_resp.dependencies);
332 }
333 if let Some(criteria) = step_resp.success_criteria {
334 task = task.with_success_criteria(criteria);
335 }
336 plan.add_step(task);
337 }
338
339 for tool in parsed.required_tools {
340 plan.add_required_tool(tool);
341 }
342
343 Ok(plan)
344 }
345
346 fn parse_goal_response(text: &str) -> Result<AgentGoal> {
347 let cleaned = Self::extract_json(text);
348 let parsed: GoalResponse =
349 serde_json::from_str(cleaned).context("Failed to parse goal JSON from LLM response")?;
350
351 Ok(AgentGoal::new(parsed.description).with_criteria(parsed.success_criteria))
352 }
353
354 fn parse_achievement_response(text: &str) -> Result<AchievementResult> {
355 let cleaned = Self::extract_json(text);
356 let parsed: AchievementResponse = serde_json::from_str(cleaned)
357 .context("Failed to parse achievement JSON from LLM response")?;
358
359 Ok(AchievementResult {
360 achieved: parsed.achieved,
361 progress: parsed.progress.clamp(0.0, 1.0),
362 remaining_criteria: parsed.remaining_criteria,
363 })
364 }
365
366 fn extract_json(text: &str) -> &str {
368 let trimmed = text.trim();
369
370 if let Some(start) = trimmed.find('{') {
372 if let Some(end) = trimmed.rfind('}') {
373 if start <= end {
374 return &trimmed[start..=end];
375 }
376 }
377 }
378
379 trimmed
380 }
381}
382
383#[async_trait]
385impl Planner for LlmPlanner {
386 async fn create_plan(&self, llm: &Arc<dyn LlmClient>, prompt: &str) -> Result<ExecutionPlan> {
387 LlmPlanner::create_plan(llm, prompt).await
388 }
389
390 async fn extract_goal(&self, llm: &Arc<dyn LlmClient>, prompt: &str) -> Result<AgentGoal> {
391 LlmPlanner::extract_goal(llm, prompt).await
392 }
393
394 async fn check_achievement(
395 &self,
396 llm: &Arc<dyn LlmClient>,
397 goal: &AgentGoal,
398 current_state: &str,
399 ) -> Result<AchievementResult> {
400 LlmPlanner::check_achievement(llm, goal, current_state).await
401 }
402}
403
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_parse_plan_response() {
        let json = r#"{
            "goal": "Build a REST API",
            "complexity": "Complex",
            "steps": [
                {
                    "id": "step-1",
                    "description": "Set up project structure",
                    "tool": "bash",
                    "dependencies": [],
                    "success_criteria": "Project directory created"
                },
                {
                    "id": "step-2",
                    "description": "Implement endpoints",
                    "tool": "write",
                    "dependencies": ["step-1"],
                    "success_criteria": "Endpoints respond correctly"
                }
            ],
            "required_tools": ["bash", "write", "read"]
        }"#;

        let plan = LlmPlanner::parse_plan_response(json).unwrap();
        assert_eq!(plan.goal, "Build a REST API");
        assert_eq!(plan.complexity, Complexity::Complex);
        assert_eq!(plan.steps.len(), 2);
        assert_eq!(plan.steps[0].id, "step-1");
        assert_eq!(plan.steps[0].tool, Some("bash".to_string()));
        assert_eq!(plan.steps[1].dependencies, vec!["step-1".to_string()]);
        assert_eq!(plan.required_tools, vec!["bash", "write", "read"]);
    }

    #[test]
    fn test_parse_plan_response_with_markdown_fences() {
        let json = "```json\n{\"goal\": \"Test\", \"complexity\": \"Simple\", \"steps\": [{\"id\": \"step-1\", \"description\": \"Do it\"}], \"required_tools\": []}\n```";

        let plan = LlmPlanner::parse_plan_response(json).unwrap();
        assert_eq!(plan.goal, "Test");
        assert_eq!(plan.complexity, Complexity::Simple);
        assert_eq!(plan.steps.len(), 1);
    }

    #[test]
    fn test_parse_plan_response_invalid() {
        let bad_json = "This is not JSON at all";
        let result = LlmPlanner::parse_plan_response(bad_json);
        assert!(result.is_err());
    }

    #[test]
    fn test_parse_plan_response_unknown_complexity() {
        let json =
            r#"{"goal": "Test", "complexity": "Unknown", "steps": [], "required_tools": []}"#;
        let plan = LlmPlanner::parse_plan_response(json).unwrap();
        assert_eq!(plan.complexity, Complexity::Medium);
    }

    #[test]
    fn test_parse_goal_response() {
        let json = r#"{
            "description": "Deploy the application to production",
            "success_criteria": [
                "All tests pass",
                "Application is accessible at production URL",
                "Health check returns 200"
            ]
        }"#;

        let goal = LlmPlanner::parse_goal_response(json).unwrap();
        assert_eq!(goal.description, "Deploy the application to production");
        assert_eq!(goal.success_criteria.len(), 3);
        assert_eq!(goal.success_criteria[0], "All tests pass");
    }

    #[test]
    fn test_parse_goal_response_invalid() {
        let result = LlmPlanner::parse_goal_response("not json");
        assert!(result.is_err());
    }

    #[test]
    fn test_parse_achievement_response() {
        let json = r#"{
            "achieved": false,
            "progress": 0.65,
            "remaining_criteria": ["Health check not verified"]
        }"#;

        let result = LlmPlanner::parse_achievement_response(json).unwrap();
        assert!(!result.achieved);
        assert!((result.progress - 0.65).abs() < f32::EPSILON);
        assert_eq!(result.remaining_criteria, vec!["Health check not verified"]);
    }

    #[test]
    fn test_parse_achievement_response_achieved() {
        let json = r#"{"achieved": true, "progress": 1.0, "remaining_criteria": []}"#;
        let result = LlmPlanner::parse_achievement_response(json).unwrap();
        assert!(result.achieved);
        assert!((result.progress - 1.0).abs() < f32::EPSILON);
        assert!(result.remaining_criteria.is_empty());
    }

    #[test]
    fn test_parse_achievement_response_clamps_progress() {
        let json = r#"{"achieved": false, "progress": 1.5, "remaining_criteria": []}"#;
        let result = LlmPlanner::parse_achievement_response(json).unwrap();
        assert!((result.progress - 1.0).abs() < f32::EPSILON);
    }

    #[test]
    fn test_parse_pre_analysis_response() {
        let json = r#"{
            "intent": "Code Review",
            "requires_planning": true,
            "goal": {
                "description": "Review the auth module",
                "success_criteria": ["No security issues remain"]
            },
            "execution_plan": {
                "complexity": "Simple",
                "steps": [
                    {"id": "step-1", "description": "Read the sources", "tool": "read"}
                ],
                "required_tools": ["read"]
            },
            "optimized_input": ""
        }"#;

        let analysis = LlmPlanner::parse_pre_analysis_response(json, "review auth").unwrap();
        assert!(matches!(
            analysis.intent,
            crate::prompts::AgentStyle::CodeReview
        ));
        assert!(analysis.requires_planning);
        assert_eq!(analysis.goal.description, "Review the auth module");
        assert_eq!(
            analysis.goal.success_criteria,
            vec!["No security issues remain"]
        );
        assert_eq!(analysis.execution_plan.goal, "Review the auth module");
        assert_eq!(analysis.execution_plan.complexity, Complexity::Simple);
        assert_eq!(analysis.execution_plan.steps.len(), 1);
        assert_eq!(
            analysis.execution_plan.steps[0].tool,
            Some("read".to_string())
        );
        assert_eq!(analysis.execution_plan.required_tools, vec!["read"]);
        // An empty rewrite falls back to the caller's original prompt.
        assert_eq!(analysis.optimized_input, "review auth");
    }

    #[test]
    fn test_parse_pre_analysis_response_defaults() {
        let json = r#"{
            "intent": "something else",
            "requires_planning": false,
            "goal": {"description": "Answer a question", "success_criteria": []},
            "execution_plan": {"complexity": "VeryComplex", "steps": []},
            "optimized_input": "Refined prompt"
        }"#;

        let analysis = LlmPlanner::parse_pre_analysis_response(json, "original").unwrap();
        assert!(matches!(
            analysis.intent,
            crate::prompts::AgentStyle::GeneralPurpose
        ));
        assert!(!analysis.requires_planning);
        assert_eq!(analysis.execution_plan.complexity, Complexity::VeryComplex);
        assert!(analysis.execution_plan.steps.is_empty());
        assert!(analysis.execution_plan.required_tools.is_empty());
        assert_eq!(analysis.optimized_input, "Refined prompt");
    }

    #[test]
    fn test_parse_pre_analysis_response_invalid() {
        let result = LlmPlanner::parse_pre_analysis_response("not json", "prompt");
        assert!(result.is_err());
    }

    #[test]
    fn test_fallback_plan() {
        let short_prompt = "Fix bug";
        let plan = LlmPlanner::fallback_plan(short_prompt);
        assert_eq!(plan.complexity, Complexity::Simple);
        assert_eq!(plan.steps.len(), 2);
        assert_eq!(plan.goal, short_prompt);

        let long_prompt = "Implement a comprehensive authentication system with OAuth2 support, JWT tokens, refresh token rotation, multi-factor authentication, and role-based access control across all API endpoints with proper audit logging and session management capabilities for both web and mobile clients, including password reset flows, account lockout policies, and integration with external identity providers such as Google, GitHub, and SAML-based enterprise SSO systems";
        let plan = LlmPlanner::fallback_plan(long_prompt);
        assert_eq!(plan.complexity, Complexity::VeryComplex);
        assert_eq!(plan.steps.len(), 10);
    }

    #[test]
    fn test_fallback_goal() {
        let goal = LlmPlanner::fallback_goal("Fix the login bug");
        assert_eq!(goal.description, "Fix the login bug");
        assert_eq!(goal.success_criteria.len(), 2);
        assert_eq!(goal.success_criteria[0], "Task is completed successfully");
    }

    #[test]
    fn test_fallback_check_achievement_done() {
        let goal = AgentGoal::new("Test task").with_criteria(vec!["Criterion 1".to_string()]);

        let result = LlmPlanner::fallback_check_achievement(&goal, "The task is done.");
        assert!(result.achieved);
        assert!((result.progress - 1.0).abs() < f32::EPSILON);
        assert!(result.remaining_criteria.is_empty());
    }

    #[test]
    fn test_fallback_check_achievement_not_done() {
        let goal = AgentGoal::new("Test task")
            .with_criteria(vec!["Criterion 1".to_string(), "Criterion 2".to_string()]);

        let result = LlmPlanner::fallback_check_achievement(&goal, "Work in progress");
        assert!(!result.achieved);
        assert_eq!(result.remaining_criteria.len(), 2);
    }

    #[test]
    fn test_extract_json_plain() {
        assert_eq!(LlmPlanner::extract_json("  {\"a\": 1}  "), "{\"a\": 1}");
    }

    #[test]
    fn test_extract_json_with_fences() {
        let text = "```json\n{\"a\": 1}\n```";
        assert_eq!(LlmPlanner::extract_json(text), "{\"a\": 1}");
    }

    #[test]
    fn test_extract_json_with_surrounding_text() {
        let text = "Here is the plan:\n{\"goal\": \"test\"}\nDone.";
        assert_eq!(LlmPlanner::extract_json(text), "{\"goal\": \"test\"}");
    }

    #[test]
    fn test_extract_json_without_braces_returns_trimmed_text() {
        assert_eq!(LlmPlanner::extract_json("  plain text  "), "plain text");
    }
}