1use crate::llm::{LlmClient, Message};
8use crate::planning::{AgentGoal, Complexity, ExecutionPlan, Task};
9use anyhow::{Context, Result};
10use async_trait::async_trait;
11use serde::{Deserialize, Serialize};
12use std::sync::Arc;
13
14#[derive(Debug, Clone, Serialize, Deserialize)]
16pub struct AchievementResult {
17 pub achieved: bool,
19 pub progress: f32,
21 pub remaining_criteria: Vec<String>,
23}
24
25#[async_trait]
30pub trait Planner: Send + Sync {
31 async fn create_plan(&self, llm: &Arc<dyn LlmClient>, prompt: &str) -> Result<ExecutionPlan>;
33
34 async fn extract_goal(&self, llm: &Arc<dyn LlmClient>, prompt: &str) -> Result<AgentGoal>;
36
37 async fn check_achievement(
39 &self,
40 llm: &Arc<dyn LlmClient>,
41 goal: &AgentGoal,
42 current_state: &str,
43 ) -> Result<AchievementResult>;
44}
45
/// [`Planner`] implementation that delegates plan creation, goal extraction
/// and achievement checks to an LLM.
///
/// Holds no state: every call is handed the [`LlmClient`] to use, so the type
/// is freely copyable and default-constructible.
#[derive(Debug, Clone, Copy, Default)]
pub struct LlmPlanner;
48
49#[derive(Debug, Deserialize)]
54struct PlanResponse {
55 goal: String,
56 complexity: String,
57 steps: Vec<StepResponse>,
58 #[serde(default)]
59 required_tools: Vec<String>,
60}
61
62#[derive(Debug, Deserialize)]
63struct StepResponse {
64 id: String,
65 description: String,
66 #[serde(default)]
67 tool: Option<String>,
68 #[serde(default)]
69 dependencies: Vec<String>,
70 #[serde(default)]
71 success_criteria: Option<String>,
72}
73
74#[derive(Debug, Deserialize)]
75struct GoalResponse {
76 description: String,
77 success_criteria: Vec<String>,
78}
79
80#[derive(Debug, Deserialize)]
81struct AchievementResponse {
82 achieved: bool,
83 progress: f32,
84 #[serde(default)]
85 remaining_criteria: Vec<String>,
86}
87
88impl LlmPlanner {
89 pub async fn create_plan(llm: &Arc<dyn LlmClient>, prompt: &str) -> Result<ExecutionPlan> {
91 let system = crate::prompts::LLM_PLAN_SYSTEM;
92
93 let messages = vec![Message::user(prompt)];
94 let response = llm
95 .complete(&messages, Some(system), &[])
96 .await
97 .context("LLM call failed during plan creation")?;
98
99 let text = response.text();
100 Self::parse_plan_response(&text)
101 }
102
103 pub async fn extract_goal(llm: &Arc<dyn LlmClient>, prompt: &str) -> Result<AgentGoal> {
105 let system = crate::prompts::LLM_GOAL_EXTRACT_SYSTEM;
106
107 let messages = vec![Message::user(prompt)];
108 let response = llm
109 .complete(&messages, Some(system), &[])
110 .await
111 .context("LLM call failed during goal extraction")?;
112
113 let text = response.text();
114 Self::parse_goal_response(&text)
115 }
116
117 pub async fn check_achievement(
119 llm: &Arc<dyn LlmClient>,
120 goal: &AgentGoal,
121 current_state: &str,
122 ) -> Result<AchievementResult> {
123 let system = crate::prompts::LLM_GOAL_CHECK_SYSTEM;
124
125 let user_message = format!(
126 "Goal: {}\nSuccess Criteria: {}\nCurrent State: {}",
127 goal.description,
128 goal.success_criteria.join("; "),
129 current_state,
130 );
131
132 let messages = vec![Message::user(&user_message)];
133 let response = llm
134 .complete(&messages, Some(system), &[])
135 .await
136 .context("LLM call failed during achievement check")?;
137
138 let text = response.text();
139 Self::parse_achievement_response(&text)
140 }
141
142 pub fn fallback_plan(prompt: &str) -> ExecutionPlan {
144 let complexity = if prompt.len() < 50 {
145 Complexity::Simple
146 } else if prompt.len() < 150 {
147 Complexity::Medium
148 } else if prompt.len() < 300 {
149 Complexity::Complex
150 } else {
151 Complexity::VeryComplex
152 };
153
154 let mut plan = ExecutionPlan::new(prompt, complexity);
155
156 let step_count = match complexity {
157 Complexity::Simple => 2,
158 Complexity::Medium => 4,
159 Complexity::Complex => 7,
160 Complexity::VeryComplex => 10,
161 };
162
163 for i in 0..step_count {
164 let step = Task::new(
165 format!("step-{}", i + 1),
166 crate::prompts::render(
167 crate::prompts::PLAN_FALLBACK_STEP,
168 &[("step_num", &(i + 1).to_string())],
169 ),
170 );
171 plan.add_step(step);
172 }
173
174 plan
175 }
176
177 pub fn fallback_goal(prompt: &str) -> AgentGoal {
179 AgentGoal::new(prompt).with_criteria(vec![
180 "Task is completed successfully".to_string(),
181 "All requirements are met".to_string(),
182 ])
183 }
184
185 pub fn fallback_check_achievement(goal: &AgentGoal, current_state: &str) -> AchievementResult {
187 let state_lower = current_state.to_lowercase();
188 let achieved = state_lower.contains("complete")
189 || state_lower.contains("done")
190 || state_lower.contains("finished");
191
192 let progress = if achieved { 1.0 } else { goal.progress };
193
194 let remaining_criteria = if achieved {
195 Vec::new()
196 } else {
197 goal.success_criteria.clone()
198 };
199
200 AchievementResult {
201 achieved,
202 progress,
203 remaining_criteria,
204 }
205 }
206
207 fn parse_plan_response(text: &str) -> Result<ExecutionPlan> {
212 let cleaned = Self::extract_json(text);
213 let parsed: PlanResponse =
214 serde_json::from_str(cleaned).context("Failed to parse plan JSON from LLM response")?;
215
216 let complexity = match parsed.complexity.as_str() {
217 "Simple" => Complexity::Simple,
218 "Medium" => Complexity::Medium,
219 "Complex" => Complexity::Complex,
220 "VeryComplex" => Complexity::VeryComplex,
221 _ => Complexity::Medium,
222 };
223
224 let mut plan = ExecutionPlan::new(parsed.goal, complexity);
225
226 for step_resp in parsed.steps {
227 let mut task = Task::new(step_resp.id, step_resp.description);
228 if let Some(tool) = step_resp.tool {
229 task = task.with_tool(tool);
230 }
231 if !step_resp.dependencies.is_empty() {
232 task = task.with_dependencies(step_resp.dependencies);
233 }
234 if let Some(criteria) = step_resp.success_criteria {
235 task = task.with_success_criteria(criteria);
236 }
237 plan.add_step(task);
238 }
239
240 for tool in parsed.required_tools {
241 plan.add_required_tool(tool);
242 }
243
244 Ok(plan)
245 }
246
247 fn parse_goal_response(text: &str) -> Result<AgentGoal> {
248 let cleaned = Self::extract_json(text);
249 let parsed: GoalResponse =
250 serde_json::from_str(cleaned).context("Failed to parse goal JSON from LLM response")?;
251
252 Ok(AgentGoal::new(parsed.description).with_criteria(parsed.success_criteria))
253 }
254
255 fn parse_achievement_response(text: &str) -> Result<AchievementResult> {
256 let cleaned = Self::extract_json(text);
257 let parsed: AchievementResponse = serde_json::from_str(cleaned)
258 .context("Failed to parse achievement JSON from LLM response")?;
259
260 Ok(AchievementResult {
261 achieved: parsed.achieved,
262 progress: parsed.progress.clamp(0.0, 1.0),
263 remaining_criteria: parsed.remaining_criteria,
264 })
265 }
266
267 fn extract_json(text: &str) -> &str {
269 let trimmed = text.trim();
270
271 if let Some(start) = trimmed.find('{') {
273 if let Some(end) = trimmed.rfind('}') {
274 if start <= end {
275 return &trimmed[start..=end];
276 }
277 }
278 }
279
280 trimmed
281 }
282}
283
284#[async_trait]
286impl Planner for LlmPlanner {
287 async fn create_plan(&self, llm: &Arc<dyn LlmClient>, prompt: &str) -> Result<ExecutionPlan> {
288 LlmPlanner::create_plan(llm, prompt).await
289 }
290
291 async fn extract_goal(&self, llm: &Arc<dyn LlmClient>, prompt: &str) -> Result<AgentGoal> {
292 LlmPlanner::extract_goal(llm, prompt).await
293 }
294
295 async fn check_achievement(
296 &self,
297 llm: &Arc<dyn LlmClient>,
298 goal: &AgentGoal,
299 current_state: &str,
300 ) -> Result<AchievementResult> {
301 LlmPlanner::check_achievement(llm, goal, current_state).await
302 }
303}
304
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_parse_plan_response() {
        let json = r#"{
            "goal": "Build a REST API",
            "complexity": "Complex",
            "steps": [
                {
                    "id": "step-1",
                    "description": "Set up project structure",
                    "tool": "bash",
                    "dependencies": [],
                    "success_criteria": "Project directory created"
                },
                {
                    "id": "step-2",
                    "description": "Implement endpoints",
                    "tool": "write",
                    "dependencies": ["step-1"],
                    "success_criteria": "Endpoints respond correctly"
                }
            ],
            "required_tools": ["bash", "write", "read"]
        }"#;

        let plan = LlmPlanner::parse_plan_response(json).unwrap();
        assert_eq!(plan.goal, "Build a REST API");
        assert_eq!(plan.complexity, Complexity::Complex);
        assert_eq!(plan.steps.len(), 2);
        assert_eq!(plan.steps[0].id, "step-1");
        assert_eq!(plan.steps[0].tool, Some("bash".to_string()));
        assert_eq!(plan.steps[1].dependencies, vec!["step-1".to_string()]);
        assert_eq!(plan.required_tools, vec!["bash", "write", "read"]);
    }

    #[test]
    fn test_parse_plan_response_with_markdown_fences() {
        let json = "```json\n{\"goal\": \"Test\", \"complexity\": \"Simple\", \"steps\": [{\"id\": \"step-1\", \"description\": \"Do it\"}], \"required_tools\": []}\n```";

        let plan = LlmPlanner::parse_plan_response(json).unwrap();
        assert_eq!(plan.goal, "Test");
        assert_eq!(plan.complexity, Complexity::Simple);
        assert_eq!(plan.steps.len(), 1);
    }

    #[test]
    fn test_parse_plan_response_invalid() {
        let bad_json = "This is not JSON at all";
        let result = LlmPlanner::parse_plan_response(bad_json);
        assert!(result.is_err());
    }

    #[test]
    fn test_parse_plan_response_unknown_complexity() {
        let json =
            r#"{"goal": "Test", "complexity": "Unknown", "steps": [], "required_tools": []}"#;
        let plan = LlmPlanner::parse_plan_response(json).unwrap();
        assert_eq!(plan.complexity, Complexity::Medium);
    }

    #[test]
    fn test_parse_goal_response() {
        let json = r#"{
            "description": "Deploy the application to production",
            "success_criteria": [
                "All tests pass",
                "Application is accessible at production URL",
                "Health check returns 200"
            ]
        }"#;

        let goal = LlmPlanner::parse_goal_response(json).unwrap();
        assert_eq!(goal.description, "Deploy the application to production");
        assert_eq!(goal.success_criteria.len(), 3);
        assert_eq!(goal.success_criteria[0], "All tests pass");
    }

    #[test]
    fn test_parse_goal_response_with_markdown_fences() {
        let json = "```json\n{\"description\": \"Ship it\", \"success_criteria\": [\"Released\"]}\n```";

        let goal = LlmPlanner::parse_goal_response(json).unwrap();
        assert_eq!(goal.description, "Ship it");
        assert_eq!(goal.success_criteria, vec!["Released".to_string()]);
    }

    #[test]
    fn test_parse_goal_response_invalid() {
        let result = LlmPlanner::parse_goal_response("not json");
        assert!(result.is_err());
    }

    #[test]
    fn test_parse_achievement_response() {
        let json = r#"{
            "achieved": false,
            "progress": 0.65,
            "remaining_criteria": ["Health check not verified"]
        }"#;

        let result = LlmPlanner::parse_achievement_response(json).unwrap();
        assert!(!result.achieved);
        assert!((result.progress - 0.65).abs() < f32::EPSILON);
        assert_eq!(result.remaining_criteria, vec!["Health check not verified"]);
    }

    #[test]
    fn test_parse_achievement_response_achieved() {
        let json = r#"{"achieved": true, "progress": 1.0, "remaining_criteria": []}"#;
        let result = LlmPlanner::parse_achievement_response(json).unwrap();
        assert!(result.achieved);
        assert!((result.progress - 1.0).abs() < f32::EPSILON);
        assert!(result.remaining_criteria.is_empty());
    }

    #[test]
    fn test_parse_achievement_response_clamps_progress() {
        let json = r#"{"achieved": false, "progress": 1.5, "remaining_criteria": []}"#;
        let result = LlmPlanner::parse_achievement_response(json).unwrap();
        assert!((result.progress - 1.0).abs() < f32::EPSILON);
    }

    #[test]
    fn test_parse_achievement_response_clamps_negative_progress() {
        let json = r#"{"achieved": false, "progress": -0.5}"#;
        let result = LlmPlanner::parse_achievement_response(json).unwrap();
        assert!(result.progress.abs() < f32::EPSILON);
        assert!(result.remaining_criteria.is_empty());
    }

    #[test]
    fn test_fallback_plan() {
        let short_prompt = "Fix bug";
        let plan = LlmPlanner::fallback_plan(short_prompt);
        assert_eq!(plan.complexity, Complexity::Simple);
        assert_eq!(plan.steps.len(), 2);
        assert_eq!(plan.goal, short_prompt);

        let long_prompt = "Implement a comprehensive authentication system with OAuth2 support, JWT tokens, refresh token rotation, multi-factor authentication, and role-based access control across all API endpoints with proper audit logging and session management capabilities for both web and mobile clients, including password reset flows, account lockout policies, and integration with external identity providers such as Google, GitHub, and SAML-based enterprise SSO systems";
        let plan = LlmPlanner::fallback_plan(long_prompt);
        assert_eq!(plan.complexity, Complexity::VeryComplex);
        assert_eq!(plan.steps.len(), 10);
    }

    #[test]
    fn test_fallback_goal() {
        let goal = LlmPlanner::fallback_goal("Fix the login bug");
        assert_eq!(goal.description, "Fix the login bug");
        assert_eq!(goal.success_criteria.len(), 2);
        assert_eq!(goal.success_criteria[0], "Task is completed successfully");
    }

    #[test]
    fn test_fallback_check_achievement_done() {
        let goal = AgentGoal::new("Test task").with_criteria(vec!["Criterion 1".to_string()]);

        let result = LlmPlanner::fallback_check_achievement(&goal, "The task is done.");
        assert!(result.achieved);
        assert!((result.progress - 1.0).abs() < f32::EPSILON);
        assert!(result.remaining_criteria.is_empty());
    }

    #[test]
    fn test_fallback_check_achievement_not_done() {
        let goal = AgentGoal::new("Test task")
            .with_criteria(vec!["Criterion 1".to_string(), "Criterion 2".to_string()]);

        let result = LlmPlanner::fallback_check_achievement(&goal, "Work in progress");
        assert!(!result.achieved);
        assert!((result.progress - goal.progress).abs() < f32::EPSILON);
        assert_eq!(result.remaining_criteria.len(), 2);
    }

    #[test]
    fn test_extract_json_plain() {
        assert_eq!(LlmPlanner::extract_json("  {\"a\": 1}  "), "{\"a\": 1}");
    }

    #[test]
    fn test_extract_json_with_fences() {
        let text = "```json\n{\"a\": 1}\n```";
        assert_eq!(LlmPlanner::extract_json(text), "{\"a\": 1}");
    }

    #[test]
    fn test_extract_json_with_surrounding_text() {
        let text = "Here is the plan:\n{\"goal\": \"test\"}\nDone.";
        assert_eq!(LlmPlanner::extract_json(text), "{\"goal\": \"test\"}");
    }

    #[test]
    fn test_extract_json_without_braces_returns_trimmed_text() {
        assert_eq!(LlmPlanner::extract_json("  no json here  "), "no json here");
    }
}
483}