1use serde::Deserialize;
13use serde::Serialize;
14use std::collections::HashMap;
15use thiserror::Error;
16use uuid::Uuid;
17
18#[derive(Error, Debug)]
20pub enum PlanError {
21 #[error("circular dependency detected in task chain: {chain:?}")]
22 CircularDependency { chain: Vec<String> },
23
24 #[error("invalid goal format: {0}")]
25 InvalidGoal(String),
26
27 #[error("task decomposition failed: {0}")]
28 DecompositionFailed(String),
29
30 #[error("dependency resolution failed: {0}")]
31 DependencyResolutionFailed(String),
32
33 #[error("parallelization analysis failed: {0}")]
34 ParallelizationFailed(String),
35}
36
37pub type PlanResult<T> = std::result::Result<T, PlanError>;
39
40pub type TaskId = Uuid;
42
43#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
45pub enum Complexity {
46 Simple,
47 Medium,
48 Complex,
49}
50
51#[derive(Debug, Clone, Serialize, Deserialize)]
53pub struct Task {
54 pub id: TaskId,
56
57 pub description: String,
59
60 pub depends_on: Vec<TaskId>,
62
63 pub can_parallelize: bool,
65}
66
67impl Task {
68 pub fn new(description: String) -> Self {
70 Self {
71 id: Uuid::new_v4(),
72 description,
73 depends_on: Vec::new(),
74 can_parallelize: true,
75 }
76 }
77
78 pub fn depends_on(mut self, dependency: TaskId) -> Self {
80 self.depends_on.push(dependency);
81 self
82 }
83
84 pub const fn sequential(mut self) -> Self {
86 self.can_parallelize = false;
87 self
88 }
89}
90
91#[derive(Debug, Clone, Serialize, Deserialize)]
93pub struct Plan {
94 pub tasks: Vec<Task>,
96
97 pub dependency_graph: HashMap<TaskId, Vec<TaskId>>,
99
100 pub parallel_groups: Vec<Vec<TaskId>>,
102
103 pub estimated_complexity: Complexity,
105}
106
107impl Plan {
108 pub fn new(tasks: Vec<Task>, complexity: Complexity) -> Self {
110 let dependency_graph = Self::build_dependency_graph(&tasks);
111 let parallel_groups = Self::identify_parallel_groups(&tasks, &dependency_graph);
112
113 Self {
114 tasks,
115 dependency_graph,
116 parallel_groups,
117 estimated_complexity: complexity,
118 }
119 }
120
121 fn build_dependency_graph(tasks: &[Task]) -> HashMap<TaskId, Vec<TaskId>> {
123 tasks
124 .iter()
125 .map(|task| (task.id, task.depends_on.clone()))
126 .collect()
127 }
128
129 fn identify_parallel_groups(
131 tasks: &[Task],
132 _dependency_graph: &HashMap<TaskId, Vec<TaskId>>,
133 ) -> Vec<Vec<TaskId>> {
134 let mut groups = Vec::new();
135 let mut processed = std::collections::HashSet::new();
136
137 let independent_tasks: Vec<TaskId> = tasks
139 .iter()
140 .filter(|task| task.depends_on.is_empty() && task.can_parallelize)
141 .map(|task| task.id)
142 .collect();
143
144 if !independent_tasks.is_empty() {
145 groups.push(independent_tasks.clone());
146 processed.extend(independent_tasks);
147 }
148
149 while processed.len() < tasks.len() {
151 let current_level: Vec<TaskId> = tasks
152 .iter()
153 .filter(|task| {
154 !processed.contains(&task.id)
155 && task.depends_on.iter().all(|dep| processed.contains(dep))
156 && task.can_parallelize
157 })
158 .map(|task| task.id)
159 .collect();
160
161 if current_level.is_empty() {
162 let remaining: Vec<TaskId> = tasks
164 .iter()
165 .filter(|task| !processed.contains(&task.id))
166 .map(|task| task.id)
167 .collect();
168
169 for task_id in remaining {
170 groups.push(vec![task_id]);
171 processed.insert(task_id);
172 }
173 break;
174 }
175
176 if current_level.len() > 1 {
177 groups.push(current_level.clone());
178 } else {
179 groups.extend(current_level.iter().map(|&id| vec![id]));
181 }
182
183 processed.extend(current_level);
184 }
185
186 groups
187 }
188}
189
/// Heuristic planner that decomposes a natural-language goal into a [`Plan`].
///
/// The type holds no state, so it is freely copyable.
#[derive(Clone, Copy, Debug)]
pub struct PlanTool;
192
193impl PlanTool {
194 pub const fn new() -> Self {
196 Self
197 }
198
199 pub fn plan(&self, goal: &str) -> PlanResult<Plan> {
201 if goal.trim().is_empty() {
202 return Err(PlanError::InvalidGoal("Goal cannot be empty".to_string()));
203 }
204
205 let tasks = self.decompose_goal(goal)?;
206 let complexity = self.estimate_complexity(goal, &tasks);
207
208 Ok(Plan::new(tasks, complexity))
209 }
210
211 fn decompose_goal(&self, goal: &str) -> PlanResult<Vec<Task>> {
213 let goal_lower = goal.to_lowercase();
214 let mut tasks;
215
216 if goal_lower.contains("add")
218 && (goal_lower.contains("feature") || goal_lower.contains("component"))
219 {
220 tasks = self.decompose_add_feature(goal);
221 } else if goal_lower.contains("refactor") {
222 tasks = self.decompose_refactor(goal);
223 } else if goal_lower.contains("fix") || goal_lower.contains("bug") {
224 tasks = self.decompose_bug_fix(goal);
225 } else if goal_lower.contains("test") {
226 tasks = self.decompose_testing(goal);
227 } else if goal_lower.contains("optimize") || goal_lower.contains("performance") {
228 tasks = self.decompose_optimization(goal);
229 } else {
230 tasks = self.decompose_generic(goal);
231 }
232
233 self.setup_dependencies(&mut tasks)?;
234 Ok(tasks)
235 }
236
237 fn decompose_add_feature(&self, goal: &str) -> Vec<Task> {
239 let analyze = Task::new(format!("Analyze requirements for: {}", goal));
240 let implement =
241 Task::new(format!("Implement core functionality for: {}", goal)).depends_on(analyze.id);
242 let test = Task::new(format!("Write tests for: {}", goal)).depends_on(implement.id);
243 let document =
244 Task::new(format!("Document implementation for: {}", goal)).depends_on(implement.id);
245 let review = Task::new(format!("Review implementation for: {}", goal))
246 .depends_on(implement.id)
247 .depends_on(test.id)
248 .sequential();
249
250 vec![analyze, implement, test, document, review]
251 }
252
253 fn decompose_refactor(&self, goal: &str) -> Vec<Task> {
255 let analyze = Task::new(format!("Analyze current code for: {}", goal));
256 let refactor = Task::new(format!("Execute refactoring for: {}", goal))
257 .depends_on(analyze.id)
258 .sequential();
259 let validate = Task::new(format!("Validate refactoring with tests for: {}", goal))
260 .depends_on(refactor.id)
261 .sequential();
262
263 vec![analyze, refactor, validate]
264 }
265
266 fn decompose_bug_fix(&self, goal: &str) -> Vec<Task> {
268 let investigate = Task::new(format!("Investigate and diagnose: {}", goal));
269 let fix = Task::new(format!("Implement fix for: {}", goal))
270 .depends_on(investigate.id)
271 .sequential();
272 let test = Task::new(format!("Add regression tests for: {}", goal))
273 .depends_on(fix.id)
274 .sequential();
275
276 vec![investigate, fix, test]
277 }
278
279 fn decompose_testing(&self, goal: &str) -> Vec<Task> {
281 let analyze = Task::new(format!("Analyze testing requirements for: {}", goal));
282 let unit_tests =
283 Task::new(format!("Write unit tests for: {}", goal)).depends_on(analyze.id);
284 let integration_tests =
285 Task::new(format!("Write integration tests for: {}", goal)).depends_on(analyze.id);
286
287 vec![analyze, unit_tests, integration_tests]
288 }
289
290 fn decompose_optimization(&self, goal: &str) -> Vec<Task> {
292 let analyze = Task::new(format!("Analyze performance bottlenecks for: {}", goal));
293 let optimize = Task::new(format!("Implement optimizations for: {}", goal))
294 .depends_on(analyze.id)
295 .sequential();
296 let validate = Task::new(format!("Validate performance improvements for: {}", goal))
297 .depends_on(optimize.id)
298 .sequential();
299
300 vec![analyze, optimize, validate]
301 }
302
303 fn decompose_generic(&self, goal: &str) -> Vec<Task> {
305 let analyze = Task::new(format!("Analyze requirements for: {}", goal));
306 let implement =
307 Task::new(format!("Implement solution for: {}", goal)).depends_on(analyze.id);
308 let validate =
309 Task::new(format!("Validate and test solution for: {}", goal)).depends_on(implement.id);
310
311 vec![analyze, implement, validate]
312 }
313
314 fn setup_dependencies(&self, tasks: &mut [Task]) -> PlanResult<()> {
316 self.check_circular_dependencies(tasks)?;
318 Ok(())
319 }
320
321 fn check_circular_dependencies(&self, tasks: &[Task]) -> PlanResult<()> {
323 for task in tasks {
324 if self.has_circular_dependency(task, tasks, &mut std::collections::HashSet::new())? {
325 return Err(PlanError::CircularDependency {
326 chain: vec![format!("Task: {}", task.description)],
327 });
328 }
329 }
330 Ok(())
331 }
332
333 fn has_circular_dependency(
335 &self,
336 task: &Task,
337 all_tasks: &[Task],
338 visited: &mut std::collections::HashSet<TaskId>,
339 ) -> PlanResult<bool> {
340 if visited.contains(&task.id) {
341 return Ok(true);
342 }
343
344 visited.insert(task.id);
345
346 for &dep_id in &task.depends_on {
347 if let Some(dep_task) = all_tasks.iter().find(|t| t.id == dep_id)
348 && self.has_circular_dependency(dep_task, all_tasks, visited)?
349 {
350 return Ok(true);
351 }
352 }
353
354 visited.remove(&task.id);
355 Ok(false)
356 }
357
358 fn estimate_complexity(&self, goal: &str, tasks: &[Task]) -> Complexity {
360 let goal_lower = goal.to_lowercase();
361 let task_count = tasks.len();
362
363 if goal_lower.contains("refactor")
365 || goal_lower.contains("architecture")
366 || goal_lower.contains("performance")
367 || goal_lower.contains("security")
368 || task_count > 6
369 {
370 return Complexity::Complex;
371 }
372
373 if goal_lower.contains("feature")
375 || goal_lower.contains("component")
376 || goal_lower.contains("api")
377 || task_count > 3
378 {
379 return Complexity::Medium;
380 }
381
382 Complexity::Simple
384 }
385}
386
387impl Default for PlanTool {
388 fn default() -> Self {
389 Self::new()
390 }
391}
392
393pub type AgentType = String;
395pub type DependencyGraph = HashMap<TaskId, Vec<TaskId>>;
396
397#[derive(Debug, Clone, Serialize, Deserialize)]
398pub struct MetaTask {
399 pub name: String,
400 pub description: String,
401}
402
/// Stateless unit struct; no methods are defined for it in this file.
///
/// Derives `Debug` like other public types, plus `Clone`/`Copy` since it
/// carries no data.
#[derive(Clone, Copy, Debug)]
pub struct MetaTaskPlanner;
404
405#[derive(Debug, Clone, Serialize, Deserialize)]
406pub struct PlanContext {
407 pub goal: String,
408 pub constraints: Vec<String>,
409}
410
411#[derive(Debug, Clone, Serialize, Deserialize)]
412pub struct PlanExecutionPlan {
413 pub steps: Vec<PlanExecutionStep>,
414}
415
416#[derive(Debug, Clone, Serialize, Deserialize)]
417pub struct PlanExecutionStep {
418 pub description: String,
419 pub task_ids: Vec<TaskId>,
420}
421
422#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
423pub enum PlanIntelligenceLevel {
424 Basic,
425 Advanced,
426}
427
428#[derive(Debug, Clone, Serialize, Deserialize)]
429pub struct SubTask {
430 pub id: TaskId,
431 pub description: String,
432 pub parent_id: Option<TaskId>,
433}
434
/// Stateless unit struct; no methods are defined for it in this file.
///
/// Derives `Debug` like other public types, plus `Clone`/`Copy` since it
/// carries no data.
#[derive(Clone, Copy, Debug)]
pub struct SubTaskPlanner;
436
437#[derive(Debug, Clone, Serialize, Deserialize)]
438pub struct TaskGroup {
439 pub name: String,
440 pub tasks: Vec<TaskId>,
441}
442
443#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
444pub enum TaskPriority {
445 Low,
446 Medium,
447 High,
448 Critical,
449}
450
451#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
452pub enum TaskStatus {
453 Pending,
454 InProgress,
455 Complete,
456 Failed,
457}
458
459#[derive(Debug, Clone, Serialize, Deserialize)]
461pub struct ToolOutput<T> {
462 pub result: T,
463 pub summary: String,
464}
465
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_plan_creation() {
        let tool = PlanTool::new();
        let result = tool.plan("Add dark mode toggle to the UI").unwrap();

        assert!(!result.tasks.is_empty());
        assert!(!result.parallel_groups.is_empty());
        assert_eq!(result.estimated_complexity, Complexity::Simple);
    }

    #[test]
    fn test_empty_goal_is_rejected() {
        let tool = PlanTool::new();

        assert!(matches!(tool.plan(""), Err(PlanError::InvalidGoal(_))));
        assert!(matches!(tool.plan("   \t\n"), Err(PlanError::InvalidGoal(_))));
    }

    #[test]
    fn test_add_feature_decomposition() {
        let tool = PlanTool::new();
        let result = tool.plan("Add user authentication feature").unwrap();

        assert!(result.tasks.len() >= 3);
        assert!(result.tasks.iter().any(|t| t.description.contains("Analyze")));
        assert!(result.tasks.iter().any(|t| t.description.contains("Implement")));
    }

    #[test]
    fn test_refactor_decomposition() {
        let tool = PlanTool::new();
        let result = tool.plan("Refactor authentication system").unwrap();

        assert_eq!(result.estimated_complexity, Complexity::Complex);
        assert!(result.tasks.iter().any(|t| t.description.contains("Analyze")));
        assert!(result.tasks.iter().any(|t| t.description.contains("refactoring")));
    }

    #[test]
    fn test_parallel_group_identification() {
        let tool = PlanTool::new();
        let result = tool.plan("Add comprehensive test suite").unwrap();

        let has_parallel_group = result.parallel_groups.iter().any(|group| group.len() > 1);
        assert!(has_parallel_group);
    }

    #[test]
    fn test_acyclic_dependencies_pass_cycle_check() {
        let tool = PlanTool::new();

        // Task 3 reaches Task 1 by two routes; shared ancestry is not a cycle.
        let task1 = Task::new("Task 1".to_string());
        let task2 = Task::new("Task 2".to_string()).depends_on(task1.id);
        let task3 = Task::new("Task 3".to_string())
            .depends_on(task2.id)
            .depends_on(task1.id);
        let tasks = vec![task1, task2, task3];

        assert!(tool.check_circular_dependencies(&tasks).is_ok());
    }

    #[test]
    fn test_circular_dependency_detection() {
        let tool = PlanTool::new();

        let mut task1 = Task::new("Task 1".to_string());
        let task2 = Task::new("Task 2".to_string()).depends_on(task1.id);
        // Close the loop: Task 1 -> Task 2 -> Task 1.
        task1.depends_on.push(task2.id);
        let tasks = vec![task1, task2];

        assert!(matches!(
            tool.check_circular_dependencies(&tasks),
            Err(PlanError::CircularDependency { .. })
        ));
    }

    #[test]
    fn test_complexity_estimation() {
        let tool = PlanTool::new();

        let simple_result = tool.plan("Fix typo in readme").unwrap();
        assert_eq!(simple_result.estimated_complexity, Complexity::Simple);

        let complex_result = tool
            .plan("Refactor entire authentication architecture")
            .unwrap();
        assert_eq!(complex_result.estimated_complexity, Complexity::Complex);
    }
}