1use crate::llm::{LlmClient, Message};
8use crate::planning::{AgentGoal, Complexity, ExecutionPlan, Task};
9use anyhow::{Context, Result};
10use serde::{Deserialize, Serialize};
11use std::sync::Arc;
12
13#[derive(Debug, Clone, Serialize, Deserialize)]
15pub struct AchievementResult {
16 pub achieved: bool,
18 pub progress: f32,
20 pub remaining_criteria: Vec<String>,
22}
23
/// Stateless namespace for LLM-driven planning: plan creation, goal extraction and
/// achievement checks, plus deterministic fallbacks for when the LLM is unavailable.
///
/// Derives the cheap standard traits so the type can be debug-printed, copied and
/// default-constructed like any other zero-sized marker.
#[derive(Debug, Clone, Copy, Default)]
pub struct LlmPlanner;
26
27#[derive(Debug, Deserialize)]
32struct PlanResponse {
33 goal: String,
34 complexity: String,
35 steps: Vec<StepResponse>,
36 #[serde(default)]
37 required_tools: Vec<String>,
38}
39
40#[derive(Debug, Deserialize)]
41struct StepResponse {
42 id: String,
43 description: String,
44 #[serde(default)]
45 tool: Option<String>,
46 #[serde(default)]
47 dependencies: Vec<String>,
48 #[serde(default)]
49 success_criteria: Option<String>,
50}
51
52#[derive(Debug, Deserialize)]
53struct GoalResponse {
54 description: String,
55 success_criteria: Vec<String>,
56}
57
58#[derive(Debug, Deserialize)]
59struct AchievementResponse {
60 achieved: bool,
61 progress: f32,
62 #[serde(default)]
63 remaining_criteria: Vec<String>,
64}
65
66impl LlmPlanner {
67 pub async fn create_plan(llm: &Arc<dyn LlmClient>, prompt: &str) -> Result<ExecutionPlan> {
69 let system = crate::prompts::LLM_PLAN_SYSTEM;
70
71 let messages = vec![Message::user(prompt)];
72 let response = llm
73 .complete(&messages, Some(system), &[])
74 .await
75 .context("LLM call failed during plan creation")?;
76
77 let text = response.text();
78 Self::parse_plan_response(&text)
79 }
80
81 pub async fn extract_goal(llm: &Arc<dyn LlmClient>, prompt: &str) -> Result<AgentGoal> {
83 let system = crate::prompts::LLM_GOAL_EXTRACT_SYSTEM;
84
85 let messages = vec![Message::user(prompt)];
86 let response = llm
87 .complete(&messages, Some(system), &[])
88 .await
89 .context("LLM call failed during goal extraction")?;
90
91 let text = response.text();
92 Self::parse_goal_response(&text)
93 }
94
95 pub async fn check_achievement(
97 llm: &Arc<dyn LlmClient>,
98 goal: &AgentGoal,
99 current_state: &str,
100 ) -> Result<AchievementResult> {
101 let system = crate::prompts::LLM_GOAL_CHECK_SYSTEM;
102
103 let user_message = format!(
104 "Goal: {}\nSuccess Criteria: {}\nCurrent State: {}",
105 goal.description,
106 goal.success_criteria.join("; "),
107 current_state,
108 );
109
110 let messages = vec![Message::user(&user_message)];
111 let response = llm
112 .complete(&messages, Some(system), &[])
113 .await
114 .context("LLM call failed during achievement check")?;
115
116 let text = response.text();
117 Self::parse_achievement_response(&text)
118 }
119
120 pub fn fallback_plan(prompt: &str) -> ExecutionPlan {
122 let complexity = if prompt.len() < 50 {
123 Complexity::Simple
124 } else if prompt.len() < 150 {
125 Complexity::Medium
126 } else if prompt.len() < 300 {
127 Complexity::Complex
128 } else {
129 Complexity::VeryComplex
130 };
131
132 let mut plan = ExecutionPlan::new(prompt, complexity);
133
134 let step_count = match complexity {
135 Complexity::Simple => 2,
136 Complexity::Medium => 4,
137 Complexity::Complex => 7,
138 Complexity::VeryComplex => 10,
139 };
140
141 for i in 0..step_count {
142 let step = Task::new(
143 format!("step-{}", i + 1),
144 crate::prompts::render(
145 crate::prompts::PLAN_FALLBACK_STEP,
146 &[("step_num", &(i + 1).to_string())],
147 ),
148 );
149 plan.add_step(step);
150 }
151
152 plan
153 }
154
155 pub fn fallback_goal(prompt: &str) -> AgentGoal {
157 AgentGoal::new(prompt).with_criteria(vec![
158 "Task is completed successfully".to_string(),
159 "All requirements are met".to_string(),
160 ])
161 }
162
163 pub fn fallback_check_achievement(goal: &AgentGoal, current_state: &str) -> AchievementResult {
165 let state_lower = current_state.to_lowercase();
166 let achieved = state_lower.contains("complete")
167 || state_lower.contains("done")
168 || state_lower.contains("finished");
169
170 let progress = if achieved { 1.0 } else { goal.progress };
171
172 let remaining_criteria = if achieved {
173 Vec::new()
174 } else {
175 goal.success_criteria.clone()
176 };
177
178 AchievementResult {
179 achieved,
180 progress,
181 remaining_criteria,
182 }
183 }
184
185 fn parse_plan_response(text: &str) -> Result<ExecutionPlan> {
190 let cleaned = Self::extract_json(text);
191 let parsed: PlanResponse =
192 serde_json::from_str(cleaned).context("Failed to parse plan JSON from LLM response")?;
193
194 let complexity = match parsed.complexity.as_str() {
195 "Simple" => Complexity::Simple,
196 "Medium" => Complexity::Medium,
197 "Complex" => Complexity::Complex,
198 "VeryComplex" => Complexity::VeryComplex,
199 _ => Complexity::Medium,
200 };
201
202 let mut plan = ExecutionPlan::new(parsed.goal, complexity);
203
204 for step_resp in parsed.steps {
205 let mut task = Task::new(step_resp.id, step_resp.description);
206 if let Some(tool) = step_resp.tool {
207 task = task.with_tool(tool);
208 }
209 if !step_resp.dependencies.is_empty() {
210 task = task.with_dependencies(step_resp.dependencies);
211 }
212 if let Some(criteria) = step_resp.success_criteria {
213 task = task.with_success_criteria(criteria);
214 }
215 plan.add_step(task);
216 }
217
218 for tool in parsed.required_tools {
219 plan.add_required_tool(tool);
220 }
221
222 Ok(plan)
223 }
224
225 fn parse_goal_response(text: &str) -> Result<AgentGoal> {
226 let cleaned = Self::extract_json(text);
227 let parsed: GoalResponse =
228 serde_json::from_str(cleaned).context("Failed to parse goal JSON from LLM response")?;
229
230 Ok(AgentGoal::new(parsed.description).with_criteria(parsed.success_criteria))
231 }
232
233 fn parse_achievement_response(text: &str) -> Result<AchievementResult> {
234 let cleaned = Self::extract_json(text);
235 let parsed: AchievementResponse = serde_json::from_str(cleaned)
236 .context("Failed to parse achievement JSON from LLM response")?;
237
238 Ok(AchievementResult {
239 achieved: parsed.achieved,
240 progress: parsed.progress.clamp(0.0, 1.0),
241 remaining_criteria: parsed.remaining_criteria,
242 })
243 }
244
245 fn extract_json(text: &str) -> &str {
247 let trimmed = text.trim();
248
249 if let Some(start) = trimmed.find('{') {
251 if let Some(end) = trimmed.rfind('}') {
252 if start <= end {
253 return &trimmed[start..=end];
254 }
255 }
256 }
257
258 trimmed
259 }
260}
261
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_parse_plan_response() {
        let json = r#"{
            "goal": "Build a REST API",
            "complexity": "Complex",
            "steps": [
                {
                    "id": "step-1",
                    "description": "Set up project structure",
                    "tool": "bash",
                    "dependencies": [],
                    "success_criteria": "Project directory created"
                },
                {
                    "id": "step-2",
                    "description": "Implement endpoints",
                    "tool": "write",
                    "dependencies": ["step-1"],
                    "success_criteria": "Endpoints respond correctly"
                }
            ],
            "required_tools": ["bash", "write", "read"]
        }"#;

        let plan = LlmPlanner::parse_plan_response(json).unwrap();
        assert_eq!(plan.goal, "Build a REST API");
        assert_eq!(plan.complexity, Complexity::Complex);
        assert_eq!(plan.steps.len(), 2);
        assert_eq!(plan.steps[0].id, "step-1");
        assert_eq!(plan.steps[0].tool, Some("bash".to_string()));
        assert_eq!(plan.steps[1].dependencies, vec!["step-1".to_string()]);
        assert_eq!(plan.required_tools, vec!["bash", "write", "read"]);
    }

    #[test]
    fn test_parse_plan_response_with_markdown_fences() {
        let json = "```json\n{\"goal\": \"Test\", \"complexity\": \"Simple\", \"steps\": [{\"id\": \"step-1\", \"description\": \"Do it\"}], \"required_tools\": []}\n```";

        let plan = LlmPlanner::parse_plan_response(json).unwrap();
        assert_eq!(plan.goal, "Test");
        assert_eq!(plan.complexity, Complexity::Simple);
        assert_eq!(plan.steps.len(), 1);
    }

    #[test]
    fn test_parse_plan_response_invalid() {
        let bad_json = "This is not JSON at all";
        let result = LlmPlanner::parse_plan_response(bad_json);
        assert!(result.is_err());
    }

    #[test]
    fn test_parse_plan_response_unknown_complexity() {
        let json =
            r#"{"goal": "Test", "complexity": "Unknown", "steps": [], "required_tools": []}"#;
        let plan = LlmPlanner::parse_plan_response(json).unwrap();
        assert_eq!(plan.complexity, Complexity::Medium);
    }

    #[test]
    fn test_parse_plan_response_missing_required_tools() {
        // `required_tools` is `#[serde(default)]`, so omitting it must still parse.
        let json = r#"{"goal": "G", "complexity": "Simple", "steps": [{"id": "s1", "description": "d"}]}"#;
        let plan = LlmPlanner::parse_plan_response(json).unwrap();
        assert_eq!(plan.goal, "G");
        assert_eq!(plan.steps.len(), 1);
        assert!(plan.required_tools.is_empty());
    }

    #[test]
    fn test_parse_goal_response() {
        let json = r#"{
            "description": "Deploy the application to production",
            "success_criteria": [
                "All tests pass",
                "Application is accessible at production URL",
                "Health check returns 200"
            ]
        }"#;

        let goal = LlmPlanner::parse_goal_response(json).unwrap();
        assert_eq!(goal.description, "Deploy the application to production");
        assert_eq!(goal.success_criteria.len(), 3);
        assert_eq!(goal.success_criteria[0], "All tests pass");
    }

    #[test]
    fn test_parse_goal_response_invalid() {
        let result = LlmPlanner::parse_goal_response("not json");
        assert!(result.is_err());
    }

    #[test]
    fn test_parse_achievement_response() {
        let json = r#"{
            "achieved": false,
            "progress": 0.65,
            "remaining_criteria": ["Health check not verified"]
        }"#;

        let result = LlmPlanner::parse_achievement_response(json).unwrap();
        assert!(!result.achieved);
        assert!((result.progress - 0.65).abs() < f32::EPSILON);
        assert_eq!(result.remaining_criteria, vec!["Health check not verified"]);
    }

    #[test]
    fn test_parse_achievement_response_achieved() {
        let json = r#"{"achieved": true, "progress": 1.0, "remaining_criteria": []}"#;
        let result = LlmPlanner::parse_achievement_response(json).unwrap();
        assert!(result.achieved);
        assert!((result.progress - 1.0).abs() < f32::EPSILON);
        assert!(result.remaining_criteria.is_empty());
    }

    #[test]
    fn test_parse_achievement_response_clamps_progress() {
        let json = r#"{"achieved": false, "progress": 1.5, "remaining_criteria": []}"#;
        let result = LlmPlanner::parse_achievement_response(json).unwrap();
        assert!((result.progress - 1.0).abs() < f32::EPSILON);
    }

    #[test]
    fn test_parse_achievement_response_clamps_negative_progress() {
        // `remaining_criteria` omitted: defaults to empty.
        let json = r#"{"achieved": false, "progress": -0.25}"#;
        let result = LlmPlanner::parse_achievement_response(json).unwrap();
        assert!(result.progress.abs() < f32::EPSILON);
        assert!(result.remaining_criteria.is_empty());
    }

    #[test]
    fn test_parse_achievement_response_invalid() {
        assert!(LlmPlanner::parse_achievement_response("not json").is_err());
    }

    #[test]
    fn test_fallback_plan() {
        let short_prompt = "Fix bug";
        let plan = LlmPlanner::fallback_plan(short_prompt);
        assert_eq!(plan.complexity, Complexity::Simple);
        assert_eq!(plan.steps.len(), 2);
        assert_eq!(plan.goal, short_prompt);

        let long_prompt = "Implement a comprehensive authentication system with OAuth2 support, JWT tokens, refresh token rotation, multi-factor authentication, and role-based access control across all API endpoints with proper audit logging and session management capabilities for both web and mobile clients, including password reset flows, account lockout policies, and integration with external identity providers such as Google, GitHub, and SAML-based enterprise SSO systems";
        let plan = LlmPlanner::fallback_plan(long_prompt);
        assert_eq!(plan.complexity, Complexity::VeryComplex);
        assert_eq!(plan.steps.len(), 10);
    }

    #[test]
    fn test_fallback_plan_length_boundaries() {
        let cases = [
            (49, Complexity::Simple, 2),
            (50, Complexity::Medium, 4),
            (149, Complexity::Medium, 4),
            (150, Complexity::Complex, 7),
            (299, Complexity::Complex, 7),
            (300, Complexity::VeryComplex, 10),
        ];

        for (len, expected_complexity, expected_steps) in cases.iter() {
            let prompt = "a".repeat(*len);
            let plan = LlmPlanner::fallback_plan(&prompt);
            assert_eq!(&plan.complexity, expected_complexity, "prompt length {}", len);
            assert_eq!(plan.steps.len(), *expected_steps, "prompt length {}", len);
            assert_eq!(plan.steps[0].id, "step-1");
            assert_eq!(
                plan.steps[expected_steps - 1].id,
                format!("step-{}", expected_steps)
            );
        }
    }

    #[test]
    fn test_fallback_goal() {
        let goal = LlmPlanner::fallback_goal("Fix the login bug");
        assert_eq!(goal.description, "Fix the login bug");
        assert_eq!(goal.success_criteria.len(), 2);
        assert_eq!(goal.success_criteria[0], "Task is completed successfully");
    }

    #[test]
    fn test_fallback_check_achievement_done() {
        let goal = AgentGoal::new("Test task").with_criteria(vec!["Criterion 1".to_string()]);

        let result = LlmPlanner::fallback_check_achievement(&goal, "The task is done.");
        assert!(result.achieved);
        assert!((result.progress - 1.0).abs() < f32::EPSILON);
        assert!(result.remaining_criteria.is_empty());
    }

    #[test]
    fn test_fallback_check_achievement_not_done() {
        let goal = AgentGoal::new("Test task")
            .with_criteria(vec!["Criterion 1".to_string(), "Criterion 2".to_string()]);

        let result = LlmPlanner::fallback_check_achievement(&goal, "Work in progress");
        assert!(!result.achieved);
        assert_eq!(result.remaining_criteria.len(), 2);
    }

    #[test]
    fn test_fallback_check_achievement_other_markers() {
        let goal = AgentGoal::new("Test task").with_criteria(vec!["Criterion 1".to_string()]);
        assert!(LlmPlanner::fallback_check_achievement(&goal, "Everything is finished").achieved);
        assert!(LlmPlanner::fallback_check_achievement(&goal, "Completed all steps").achieved);
    }

    #[test]
    fn test_extract_json_plain() {
        assert_eq!(LlmPlanner::extract_json(" {\"a\": 1} "), "{\"a\": 1}");
    }

    #[test]
    fn test_extract_json_with_fences() {
        let text = "```json\n{\"a\": 1}\n```";
        assert_eq!(LlmPlanner::extract_json(text), "{\"a\": 1}");
    }

    #[test]
    fn test_extract_json_with_surrounding_text() {
        let text = "Here is the plan:\n{\"goal\": \"test\"}\nDone.";
        assert_eq!(LlmPlanner::extract_json(text), "{\"goal\": \"test\"}");
    }

    #[test]
    fn test_extract_json_without_braces() {
        assert_eq!(LlmPlanner::extract_json("  no json here  "), "no json here");
    }

    #[test]
    fn test_extract_json_nested_object() {
        let text = "```\n{\"a\": {\"b\": [1, 2]}}\n```";
        assert_eq!(LlmPlanner::extract_json(text), "{\"a\": {\"b\": [1, 2]}}");
    }
}