1use std::collections::{HashMap, HashSet, VecDeque};
7use std::fmt;
8
9use serde::{Deserialize, Serialize};
10
11use crate::error::Error;
12
/// Scheduling priority of a task.
///
/// The derived `Ord` orders the variants `Low < Normal < High < Critical < Urgent`
/// (matching their explicit discriminants 1..=5).
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Serialize, Deserialize)]
pub enum TaskPriority {
    Low = 1,
    Normal = 2,
    High = 3,
    Critical = 4,
    Urgent = 5,
}
22
/// Lifecycle state of a task.
///
/// Newly constructed tasks start as `Pending`; `TaskNode::mark_running`,
/// `mark_completed` and `mark_failed` move a task to `Running`, `Completed`
/// and `Failed` respectively. `Skipped` and `Cancelled` are not set anywhere in
/// this file — NOTE(review): presumably assigned by an external scheduler.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum TaskStatus {
    Pending,
    Running,
    Completed,
    Failed,
    Skipped,
    Cancelled,
}
33
/// Category of work a task performs.
///
/// The variants are plain tags here; any behaviour tied to a specific type is
/// defined outside this file (NOTE(review): confirm against the executor).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub enum TaskType {
    General,
    ProtocolGeneration,
    CodeAnalysis,
    WebAutomation,
    MemoryManagement,
    EnterpriseWorkflow,
    MultiAgentCoordination,
}
45
/// A single unit of work: identity, scheduling metadata and execution timestamps.
#[derive(Debug, Clone)]
pub struct TaskNode {
    /// Unique identifier; `TaskGraph::add_node` keys the node by this value.
    id: String,
    /// Human-readable name.
    name: String,
    /// Category of work.
    task_type: TaskType,
    /// Scheduling priority; higher priorities are returned first by
    /// `TaskGraph::get_ready_tasks`.
    priority: TaskPriority,
    /// Current lifecycle state; `Pending` on construction.
    status: TaskStatus,
    /// Ids of prerequisite tasks declared through `with_dependency`.
    /// NOTE(review): `TaskGraph::add_node` does not read this set — graph edges
    /// come only from `TaskGraph::add_dependency`.
    dependencies: HashSet<String>,
    /// Free-form description.
    description: String,
    /// Execution configuration; `TaskConfig::default()` unless overridden.
    config: TaskConfig,
    /// Arbitrary JSON metadata; an empty object by default.
    metadata: serde_json::Value,
    /// Estimated run time in milliseconds (0 until set with `with_duration`).
    estimated_duration_ms: u64,
    /// Names of components the task needs.
    required_components: Vec<String>,
    /// Names of "thinktools" the task needs; non-empty means `requires_thinktool()`.
    required_thinktools: Vec<String>,
    /// Whether the task needs the "M2" capability.
    /// NOTE(review): the meaning of M2 is defined outside this file — confirm.
    requires_m2_capability: bool,
    /// Creation time, Unix timestamp in seconds.
    created_at: u64,
    /// Set (Unix seconds) by `mark_running`.
    started_at: Option<u64>,
    /// Set (Unix seconds) by `mark_completed` and `mark_failed`.
    completed_at: Option<u64>,
    /// Number of retries recorded so far via `increment_retry`.
    retry_count: u32,
    /// Retry budget checked by `can_retry`; defaults to 3.
    max_retries: u32,
}
68
69impl TaskNode {
70 pub fn new(
71 id: String,
72 name: String,
73 task_type: TaskType,
74 priority: TaskPriority,
75 description: String,
76 ) -> Self {
77 let now = chrono::Utc::now().timestamp() as u64;
78
79 Self {
80 id,
81 name,
82 task_type,
83 priority,
84 status: TaskStatus::Pending,
85 dependencies: HashSet::new(),
86 description,
87 config: TaskConfig::default(),
88 metadata: serde_json::json!({}),
89 estimated_duration_ms: 0,
90 required_components: Vec::new(),
91 required_thinktools: Vec::new(),
92 requires_m2_capability: false,
93 created_at: now,
94 started_at: None,
95 completed_at: None,
96 retry_count: 0,
97 max_retries: 3,
98 }
99 }
100
101 pub fn with_dependency(mut self, dependency_id: &str) -> Self {
102 self.dependencies.insert(dependency_id.to_string());
103 self
104 }
105
106 pub fn with_config(mut self, config: TaskConfig) -> Self {
107 self.config = config;
108 self
109 }
110
111 pub fn with_metadata(mut self, metadata: serde_json::Value) -> Self {
112 self.metadata = metadata;
113 self
114 }
115
116 pub fn with_duration(mut self, duration_ms: u64) -> Self {
117 self.estimated_duration_ms = duration_ms;
118 self
119 }
120
121 pub fn with_components(mut self, components: Vec<String>) -> Self {
122 self.required_components = components;
123 self
124 }
125
126 pub fn with_thinktools(mut self, thinktools: Vec<String>) -> Self {
127 self.required_thinktools = thinktools;
128 self
129 }
130
131 pub fn requires_m2(mut self, requires: bool) -> Self {
132 self.requires_m2_capability = requires;
133 self
134 }
135
136 pub fn with_max_retries(mut self, retries: u32) -> Self {
137 self.max_retries = retries;
138 self
139 }
140
141 pub fn id(&self) -> &str {
143 &self.id
144 }
145 pub fn name(&self) -> &str {
146 &self.name
147 }
148 pub fn task_type(&self) -> TaskType {
149 self.task_type
150 }
151 pub fn priority(&self) -> TaskPriority {
152 self.priority
153 }
154 pub fn status(&self) -> TaskStatus {
155 self.status
156 }
157 pub fn dependencies(&self) -> &HashSet<String> {
158 &self.dependencies
159 }
160 pub fn description(&self) -> &str {
161 &self.description
162 }
163 pub fn config(&self) -> &TaskConfig {
164 &self.config
165 }
166 pub fn metadata(&self) -> &serde_json::Value {
167 &self.metadata
168 }
169 pub fn estimated_duration_ms(&self) -> u64 {
170 self.estimated_duration_ms
171 }
172 pub fn required_components(&self) -> &[String] {
173 &self.required_components
174 }
175 pub fn required_thinktools(&self) -> &[String] {
176 &self.required_thinktools
177 }
178 pub fn requires_m2_capability(&self) -> bool {
179 self.requires_m2_capability
180 }
181 pub fn created_at(&self) -> u64 {
182 self.created_at
183 }
184 pub fn started_at(&self) -> Option<u64> {
185 self.started_at
186 }
187 pub fn completed_at(&self) -> Option<u64> {
188 self.completed_at
189 }
190 pub fn retry_count(&self) -> u32 {
191 self.retry_count
192 }
193 pub fn max_retries(&self) -> u32 {
194 self.max_retries
195 }
196 pub fn requires_thinktool(&self) -> bool {
197 !self.required_thinktools.is_empty()
198 }
199
200 pub fn mark_running(&mut self) {
202 self.status = TaskStatus::Running;
203 self.started_at = Some(chrono::Utc::now().timestamp() as u64);
204 }
205
206 pub fn mark_completed(&mut self) {
207 self.status = TaskStatus::Completed;
208 self.completed_at = Some(chrono::Utc::now().timestamp() as u64);
209 }
210
211 pub fn mark_failed(&mut self) {
212 self.status = TaskStatus::Failed;
213 self.completed_at = Some(chrono::Utc::now().timestamp() as u64);
214 }
215
216 pub fn increment_retry(&mut self) {
217 self.retry_count += 1;
218 }
219
220 pub fn can_retry(&self) -> bool {
221 self.retry_count < self.max_retries && self.status == TaskStatus::Failed
222 }
223}
224
/// Per-task execution settings.
///
/// Nothing in this file enforces these limits.
/// NOTE(review): presumably the executor reads them — confirm.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TaskConfig {
    /// Timeout in milliseconds (default 300_000, i.e. five minutes).
    pub timeout_ms: u64,
    /// Memory limit in megabytes (default 512).
    pub memory_limit_mb: u64,
    /// Whether the task may run in parallel (default `false`).
    pub parallel_execution: bool,
    /// Resource request; see [`ResourceRequirements`].
    pub resource_requirements: ResourceRequirements,
    /// Extra task-specific parameters as JSON values (empty by default).
    pub custom_parameters: HashMap<String, serde_json::Value>,
}
234
235impl Default for TaskConfig {
236 fn default() -> Self {
237 Self {
238 timeout_ms: 300_000, memory_limit_mb: 512,
240 parallel_execution: false,
241 resource_requirements: ResourceRequirements::default(),
242 custom_parameters: HashMap::new(),
243 }
244 }
245}
246
/// Resources a task asks for; the units are given by the field-name suffixes.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ResourceRequirements {
    /// CPU cores requested; fractional values are representable (default 1.0).
    pub cpu_cores: f64,
    /// Memory in megabytes (default 512).
    pub memory_mb: u64,
    /// Network bandwidth in megabits per second (default 10.0).
    pub network_bandwidth_mbps: f64,
    /// Disk I/O in megabytes (default 100).
    pub disk_io_mb: u64,
}
255
256impl Default for ResourceRequirements {
257 fn default() -> Self {
258 Self {
259 cpu_cores: 1.0,
260 memory_mb: 512,
261 network_bandwidth_mbps: 10.0,
262 disk_io_mb: 100,
263 }
264 }
265}
266
/// Alternative name for [`TaskGraph`].
pub type DependencyGraph = TaskGraph;

/// Directed dependency graph of tasks, keyed by task id.
///
/// Every node id has an entry (possibly empty) in both edge maps.
/// `add_dependency(a, b)` records that `b` depends on `a`.
#[derive(Debug)]
pub struct TaskGraph {
    /// All tasks, keyed by `TaskNode::id`.
    nodes: HashMap<String, TaskNode>,
    /// `edges[a]`: tasks that depend on `a` (its dependents).
    edges: HashMap<String, HashSet<String>>,
    /// `reverse_edges[b]`: tasks that `b` depends on (its dependencies).
    reverse_edges: HashMap<String, HashSet<String>>,
}
276
277impl Default for TaskGraph {
278 fn default() -> Self {
279 Self::new()
280 }
281}
282
283impl TaskGraph {
284 pub fn new() -> Self {
285 Self {
286 nodes: HashMap::new(),
287 edges: HashMap::new(),
288 reverse_edges: HashMap::new(),
289 }
290 }
291
292 pub fn add_node(&mut self, node: TaskNode) -> Result<(), Error> {
294 let node_id = node.id().to_string();
295
296 if self.nodes.contains_key(&node_id) {
297 return Err(Error::Validation(format!(
298 "Task node '{}' already exists",
299 node_id
300 )));
301 }
302
303 self.nodes.insert(node_id.clone(), node);
304 self.edges.insert(node_id.clone(), HashSet::new());
305 self.reverse_edges.insert(node_id.clone(), HashSet::new());
306
307 Ok(())
308 }
309
310 pub fn add_dependency(&mut self, from: &str, to: &str) -> Result<(), Error> {
312 if !self.nodes.contains_key(from) {
314 return Err(Error::Validation(format!(
315 "Source task '{}' does not exist",
316 from
317 )));
318 }
319
320 if !self.nodes.contains_key(to) {
321 return Err(Error::Validation(format!(
322 "Target task '{}' does not exist",
323 to
324 )));
325 }
326
327 self.edges
329 .entry(from.to_string())
330 .or_default()
331 .insert(to.to_string());
332
333 self.reverse_edges
334 .entry(to.to_string())
335 .or_default()
336 .insert(from.to_string());
337
338 Ok(())
339 }
340
341 pub fn remove_node(&mut self, node_id: &str) -> Result<(), Error> {
343 if !self.nodes.contains_key(node_id) {
344 return Err(Error::Validation(format!(
345 "Task node '{}' does not exist",
346 node_id
347 )));
348 }
349
350 let dependents = self.edges.remove(node_id).unwrap_or_default();
352 for dependent in dependents {
353 self.reverse_edges
354 .entry(dependent.clone())
355 .or_default()
356 .remove(node_id);
357 }
358
359 let dependencies = self.reverse_edges.remove(node_id).unwrap_or_default();
360 for dependency in dependencies {
361 self.edges
362 .entry(dependency.clone())
363 .or_default()
364 .remove(node_id);
365 }
366
367 self.reverse_edges.remove(node_id);
369
370 self.nodes.remove(node_id);
372
373 Ok(())
374 }
375
376 pub fn get_node(&self, node_id: &str) -> Option<&TaskNode> {
378 self.nodes.get(node_id)
379 }
380
381 pub fn get_node_mut(&mut self, node_id: &str) -> Option<&mut TaskNode> {
383 self.nodes.get_mut(node_id)
384 }
385
386 pub fn nodes(&self) -> Vec<&TaskNode> {
388 self.nodes.values().collect()
389 }
390
391 pub fn node_ids(&self) -> Vec<&String> {
393 self.nodes.keys().collect()
394 }
395
396 pub fn get_dependents(&self, node_id: &str) -> Option<&HashSet<String>> {
398 self.edges.get(node_id)
399 }
400
401 pub fn get_dependencies(&self, node_id: &str) -> Option<&HashSet<String>> {
403 self.reverse_edges.get(node_id)
404 }
405
406 pub fn has_cycles(&self) -> bool {
408 let mut visited = HashSet::new();
409 let mut recursion_stack = HashSet::new();
410
411 for node_id in self.nodes.keys() {
412 if !visited.contains(node_id)
413 && self.has_cycle_dfs(node_id, &mut visited, &mut recursion_stack)
414 {
415 return true;
416 }
417 }
418
419 false
420 }
421
422 fn has_cycle_dfs(
424 &self,
425 node_id: &str,
426 visited: &mut HashSet<String>,
427 recursion_stack: &mut HashSet<String>,
428 ) -> bool {
429 visited.insert(node_id.to_string());
430 recursion_stack.insert(node_id.to_string());
431
432 if let Some(dependents) = self.edges.get(node_id) {
433 for dependent in dependents {
434 if !visited.contains(dependent) {
435 if self.has_cycle_dfs(dependent, visited, recursion_stack) {
436 return true;
437 }
438 } else if recursion_stack.contains(dependent) {
439 return true;
440 }
441 }
442 }
443
444 recursion_stack.remove(node_id);
445 false
446 }
447
448 pub fn validate(&self) -> Result<(), Error> {
450 if self.has_cycles() {
452 return Err(Error::Validation(
453 "Task dependency graph contains cycles".to_string(),
454 ));
455 }
456
457 for (node_id, dependencies) in &self.reverse_edges {
459 for dependency in dependencies {
460 if !self.nodes.contains_key(dependency) {
461 return Err(Error::Validation(format!(
462 "Task '{}' depends on non-existent task '{}'",
463 node_id, dependency
464 )));
465 }
466 }
467 }
468
469 Ok(())
470 }
471
472 pub fn topological_sort(&self) -> Result<Vec<String>, Error> {
474 self.validate()?;
476
477 let mut in_degree: HashMap<String, usize> = HashMap::new();
478 let mut queue: VecDeque<String> = VecDeque::new();
479 let mut result: Vec<String> = Vec::new();
480
481 for node_id in self.nodes.keys() {
483 let in_degree_count = self
484 .reverse_edges
485 .get(node_id)
486 .map(|deps| deps.len())
487 .unwrap_or(0);
488 in_degree.insert(node_id.clone(), in_degree_count);
489 }
490
491 for (node_id, degree) in &in_degree {
493 if *degree == 0 {
494 queue.push_back(node_id.clone());
495 }
496 }
497
498 while let Some(node_id) = queue.pop_front() {
500 result.push(node_id.clone());
501
502 if let Some(dependents) = self.edges.get(&node_id) {
504 for dependent in dependents {
505 if let Some(degree) = in_degree.get_mut(dependent) {
506 *degree -= 1;
507 if *degree == 0 {
508 queue.push_back(dependent.clone());
509 }
510 }
511 }
512 }
513 }
514
515 if result.len() != self.nodes.len() {
517 return Err(Error::Validation(
518 "Unable to topologically sort - graph may have cycles".to_string(),
519 ));
520 }
521
522 Ok(result)
523 }
524
525 pub fn get_ready_tasks(&self) -> Vec<String> {
527 let mut ready_tasks = Vec::new();
528
529 for (node_id, node) in &self.nodes {
530 if node.status() == TaskStatus::Pending {
531 let mut all_deps_completed = true;
533 if let Some(dependencies) = self.reverse_edges.get(node_id) {
534 for dep_id in dependencies {
535 if let Some(dep_node) = self.nodes.get(dep_id) {
536 if dep_node.status() != TaskStatus::Completed {
537 all_deps_completed = false;
538 break;
539 }
540 }
541 }
542 }
543
544 if all_deps_completed {
545 ready_tasks.push(node_id.clone());
546 }
547 }
548 }
549
550 ready_tasks.sort_by(|a, b| {
552 let node_a = self.nodes.get(a).unwrap();
553 let node_b = self.nodes.get(b).unwrap();
554 node_b.priority().cmp(&node_a.priority())
555 });
556
557 ready_tasks
558 }
559
560 pub fn get_statistics(&self) -> TaskGraphStatistics {
562 let total_nodes = self.nodes.len();
563 let mut completed_nodes = 0;
564 let mut running_nodes = 0;
565 let mut pending_nodes = 0;
566 let mut failed_nodes = 0;
567
568 let mut total_estimated_duration = 0u64;
569 let mut total_critical_path_duration = 0u64;
570
571 for node in self.nodes.values() {
572 match node.status() {
573 TaskStatus::Completed => completed_nodes += 1,
574 TaskStatus::Running => running_nodes += 1,
575 TaskStatus::Pending => pending_nodes += 1,
576 TaskStatus::Failed => failed_nodes += 1,
577 _ => {}
578 }
579
580 total_estimated_duration += node.estimated_duration_ms();
581 }
582
583 if let Ok(topological_order) = self.topological_sort() {
585 let mut node_completion_times: HashMap<String, u64> = HashMap::new();
586
587 for node_id in topological_order {
588 let mut max_dep_time = 0u64;
589 if let Some(dependencies) = self.reverse_edges.get(&node_id) {
590 for dep_id in dependencies {
591 if let Some(dep_time) = node_completion_times.get(dep_id) {
592 max_dep_time = max_dep_time.max(*dep_time);
593 }
594 }
595 }
596
597 let node = self.nodes.get(&node_id).unwrap();
598 let completion_time = max_dep_time + node.estimated_duration_ms();
599 node_completion_times.insert(node_id.clone(), completion_time);
600 total_critical_path_duration = total_critical_path_duration.max(completion_time);
601 }
602 }
603
604 TaskGraphStatistics {
605 total_nodes,
606 completed_nodes,
607 running_nodes,
608 pending_nodes,
609 failed_nodes,
610 total_estimated_duration,
611 critical_path_duration: total_critical_path_duration,
612 completion_percentage: if total_nodes > 0 {
613 (completed_nodes as f64 / total_nodes as f64) * 100.0
614 } else {
615 0.0
616 },
617 }
618 }
619}
620
/// Snapshot of a [`TaskGraph`]'s progress, produced by `TaskGraph::get_statistics`.
#[derive(Debug)]
pub struct TaskGraphStatistics {
    /// All tasks, including `Skipped` and `Cancelled` ones, which have no
    /// dedicated counter.
    pub total_nodes: usize,
    pub completed_nodes: usize,
    pub running_nodes: usize,
    pub pending_nodes: usize,
    pub failed_nodes: usize,
    /// Sum of every task's estimated duration, in milliseconds.
    pub total_estimated_duration: u64,
    /// Longest chain of estimated durations through the dependency edges, in
    /// milliseconds (0 when the graph cannot be topologically sorted).
    pub critical_path_duration: u64,
    /// `completed_nodes / total_nodes * 100`, or 0.0 for an empty graph.
    pub completion_percentage: f64,
}
632
633impl fmt::Display for TaskGraphStatistics {
634 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
635 write!(
636 f,
637 "TaskGraph Statistics:\n\
638 - Total Tasks: {}\n\
639 - Completed: {} ({:.1}%)\n\
640 - Running: {}\n\
641 - Pending: {}\n\
642 - Failed: {}\n\
643 - Estimated Duration: {:.1} seconds\n\
644 - Critical Path: {:.1} seconds",
645 self.total_nodes,
646 self.completed_nodes,
647 self.completion_percentage,
648 self.running_nodes,
649 self.pending_nodes,
650 self.failed_nodes,
651 self.total_estimated_duration as f64 / 1000.0,
652 self.critical_path_duration as f64 / 1000.0
653 )
654 }
655}
656
657impl fmt::Display for TaskStatus {
658 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
659 match self {
660 TaskStatus::Pending => write!(f, "Pending"),
661 TaskStatus::Running => write!(f, "Running"),
662 TaskStatus::Completed => write!(f, "Completed"),
663 TaskStatus::Failed => write!(f, "Failed"),
664 TaskStatus::Skipped => write!(f, "Skipped"),
665 TaskStatus::Cancelled => write!(f, "Cancelled"),
666 }
667 }
668}
669
670impl fmt::Display for TaskPriority {
671 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
672 match self {
673 TaskPriority::Low => write!(f, "Low"),
674 TaskPriority::Normal => write!(f, "Normal"),
675 TaskPriority::High => write!(f, "High"),
676 TaskPriority::Critical => write!(f, "Critical"),
677 TaskPriority::Urgent => write!(f, "Urgent"),
678 }
679 }
680}
681
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a pending task from borrowed strings so each test stays short.
    fn make_task(
        id: &str,
        name: &str,
        task_type: TaskType,
        priority: TaskPriority,
        description: &str,
    ) -> TaskNode {
        TaskNode::new(
            id.to_owned(),
            name.to_owned(),
            task_type,
            priority,
            description.to_owned(),
        )
    }

    /// Builds a graph holding `tasks`; panics (failing the test) on a duplicate id.
    fn graph_with(tasks: Vec<TaskNode>) -> TaskGraph {
        let mut graph = TaskGraph::new();
        for task in tasks {
            graph.add_node(task).expect("task ids are unique");
        }
        graph
    }

    /// The two-task fixture shared by several graph tests.
    fn two_tasks() -> Vec<TaskNode> {
        vec![
            make_task(
                "task-1",
                "Task 1",
                TaskType::ProtocolGeneration,
                TaskPriority::Normal,
                "First task",
            ),
            make_task(
                "task-2",
                "Task 2",
                TaskType::CodeAnalysis,
                TaskPriority::High,
                "Second task",
            ),
        ]
    }

    #[test]
    fn test_task_node_creation() {
        let task = make_task(
            "test-1",
            "Test Task",
            TaskType::ProtocolGeneration,
            TaskPriority::High,
            "A test task",
        );

        assert_eq!(task.id(), "test-1");
        assert_eq!(task.name(), "Test Task");
        assert_eq!(task.task_type(), TaskType::ProtocolGeneration);
        assert_eq!(task.priority(), TaskPriority::High);
        assert_eq!(task.status(), TaskStatus::Pending);
        assert!(task.dependencies().is_empty());
    }

    #[test]
    fn test_task_node_with_dependencies() {
        let task = make_task(
            "test-1",
            "Test Task",
            TaskType::CodeAnalysis,
            TaskPriority::Normal,
            "A test task",
        )
        .with_dependency("dep-1")
        .with_dependency("dep-2");

        let deps = task.dependencies();
        assert_eq!(deps.len(), 2);
        assert!(["dep-1", "dep-2"].iter().all(|dep| deps.contains(*dep)));
    }

    #[test]
    fn test_task_graph_creation() {
        let mut graph = TaskGraph::new();

        for task in two_tasks() {
            assert!(graph.add_node(task).is_ok());
        }
        assert_eq!(graph.nodes().len(), 2);
    }

    #[test]
    fn test_task_graph_dependencies() {
        let mut graph = graph_with(two_tasks());

        assert!(graph.add_dependency("task-1", "task-2").is_ok());
    }

    #[test]
    fn test_cycle_detection() {
        let mut graph = graph_with(two_tasks());

        graph.add_dependency("task-1", "task-2").unwrap();
        graph.add_dependency("task-2", "task-1").unwrap();

        assert!(graph.has_cycles());
    }

    #[test]
    fn test_topological_sort() {
        let mut tasks = two_tasks();
        tasks.push(make_task(
            "task-3",
            "Task 3",
            TaskType::WebAutomation,
            TaskPriority::Normal,
            "Third task",
        ));
        let mut graph = graph_with(tasks);

        graph.add_dependency("task-1", "task-2").unwrap();
        graph.add_dependency("task-2", "task-3").unwrap();

        let order = graph.topological_sort().unwrap();
        assert_eq!(order.len(), 3);

        let position = |wanted: &str| {
            order
                .iter()
                .position(|id| id == wanted)
                .expect("every task appears in the order")
        };

        assert!(position("task-1") < position("task-2"));
        assert!(position("task-2") < position("task-3"));
    }
}