1use chrono::{DateTime, Utc};
15use serde::{Deserialize, Serialize};
16use serde_json::Value;
17use std::collections::{HashMap, HashSet, VecDeque};
18use std::sync::Arc;
19use std::time::Duration;
20use thiserror::Error;
21use tokio::sync::Mutex;
22use tokio::time::timeout;
23
24pub type ExecutorResult<T> = Result<T, ExecutorError>;
26
27#[derive(Debug, Error, Clone)]
29pub enum ExecutorError {
30 #[error("Task not found: {0}")]
32 TaskNotFound(String),
33
34 #[error("Task timeout: {0}")]
36 TaskTimeout(String),
37
38 #[error("Task failed: {task_id}, error: {error}")]
40 TaskFailed { task_id: String, error: String },
41
42 #[error("Circular dependency detected: {0:?}")]
44 CircularDependency(Vec<String>),
45
46 #[error("Invalid dependency: task {task_id} depends on non-existent task {dependency}")]
48 InvalidDependency { task_id: String, dependency: String },
49
50 #[error("Execution cancelled")]
52 Cancelled,
53
54 #[error("All retries exhausted for task: {0}")]
56 RetriesExhausted(String),
57
58 #[error("Dependency failed: task {task_id} depends on failed task {dependency}")]
60 DependencyFailed { task_id: String, dependency: String },
61}
62
63#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
65#[serde(rename_all = "camelCase")]
66pub enum TaskStatus {
67 Pending,
69 WaitingForDependencies,
71 Running,
73 Completed,
75 Failed,
77 Cancelled,
79 Skipped,
81}
82
83#[derive(Debug, Clone, Serialize, Deserialize)]
85#[serde(rename_all = "camelCase")]
86pub struct ParallelAgentConfig {
87 pub max_concurrency: usize,
89 pub timeout: Duration,
91 pub retry_on_failure: bool,
93 pub stop_on_first_error: bool,
95 pub max_retries: usize,
97 pub retry_delay: Duration,
99}
100
101impl Default for ParallelAgentConfig {
102 fn default() -> Self {
103 Self {
104 max_concurrency: 4,
105 timeout: Duration::from_secs(300), retry_on_failure: true,
107 stop_on_first_error: false,
108 max_retries: 3,
109 retry_delay: Duration::from_secs(1),
110 }
111 }
112}
113
114#[derive(Debug, Clone, Serialize, Deserialize)]
116#[serde(rename_all = "camelCase")]
117pub struct AgentTask {
118 pub id: String,
120 pub task_type: String,
122 pub prompt: String,
124 pub description: Option<String>,
126 pub options: Option<HashMap<String, Value>>,
128 pub priority: Option<u8>,
130 pub dependencies: Option<Vec<String>>,
132 pub timeout: Option<Duration>,
134}
135
136impl AgentTask {
137 pub fn new(
139 id: impl Into<String>,
140 task_type: impl Into<String>,
141 prompt: impl Into<String>,
142 ) -> Self {
143 Self {
144 id: id.into(),
145 task_type: task_type.into(),
146 prompt: prompt.into(),
147 description: None,
148 options: None,
149 priority: None,
150 dependencies: None,
151 timeout: None,
152 }
153 }
154
155 pub fn with_description(mut self, description: impl Into<String>) -> Self {
157 self.description = Some(description.into());
158 self
159 }
160
161 pub fn with_options(mut self, options: HashMap<String, Value>) -> Self {
163 self.options = Some(options);
164 self
165 }
166
167 pub fn with_priority(mut self, priority: u8) -> Self {
169 self.priority = Some(priority);
170 self
171 }
172
173 pub fn with_dependencies(mut self, dependencies: Vec<String>) -> Self {
175 self.dependencies = Some(dependencies);
176 self
177 }
178
179 pub fn with_timeout(mut self, timeout: Duration) -> Self {
181 self.timeout = Some(timeout);
182 self
183 }
184
185 pub fn effective_priority(&self) -> u8 {
187 self.priority.unwrap_or(0)
188 }
189
190 pub fn has_dependencies(&self) -> bool {
192 self.dependencies
193 .as_ref()
194 .map(|d| !d.is_empty())
195 .unwrap_or(false)
196 }
197
198 pub fn get_dependencies(&self) -> Vec<String> {
200 self.dependencies.clone().unwrap_or_default()
201 }
202}
203
204#[derive(Debug, Clone, Serialize, Deserialize)]
206#[serde(rename_all = "camelCase")]
207pub struct AgentResult {
208 pub task_id: String,
210 pub success: bool,
212 pub result: Option<Value>,
214 pub error: Option<String>,
216 pub duration: Duration,
218 pub retries: usize,
220 pub started_at: DateTime<Utc>,
222 pub completed_at: DateTime<Utc>,
224}
225
226#[derive(Debug, Clone)]
228pub struct TaskExecutionInfo {
229 pub task: AgentTask,
231 pub status: TaskStatus,
233 pub retries: usize,
235 pub last_error: Option<String>,
237 pub started_at: Option<DateTime<Utc>>,
239 pub completed_at: Option<DateTime<Utc>>,
241 pub result: Option<Value>,
243}
244
245impl TaskExecutionInfo {
246 pub fn new(task: AgentTask) -> Self {
248 Self {
249 task,
250 status: TaskStatus::Pending,
251 retries: 0,
252 last_error: None,
253 started_at: None,
254 completed_at: None,
255 result: None,
256 }
257 }
258}
259
260#[derive(Debug, Clone, Serialize, Deserialize)]
262#[serde(rename_all = "camelCase")]
263pub struct ExecutionProgress {
264 pub total: usize,
266 pub completed: usize,
268 pub failed: usize,
270 pub running: usize,
272 pub pending: usize,
274 pub skipped: usize,
276 pub cancelled: bool,
278}
279
280#[derive(Debug, Clone, Serialize, Deserialize)]
282#[serde(rename_all = "camelCase")]
283pub struct ParallelExecutionResult {
284 pub success: bool,
286 pub results: Vec<AgentResult>,
288 pub total_duration: Duration,
290 pub successful_count: usize,
292 pub failed_count: usize,
294 pub skipped_count: usize,
296 pub merged_result: Option<MergedResult>,
298}
299
300#[derive(Debug, Clone, Serialize, Deserialize)]
302#[serde(rename_all = "camelCase")]
303pub struct MergedResult {
304 pub outputs: Vec<Value>,
306 pub summary: Option<String>,
308 pub metadata: HashMap<String, Value>,
310}
311
/// Directed dependency graph over task ids.
///
/// Edges are stored in both directions so that "what does X need" and
/// "who needs X" are both single map lookups.
#[derive(Clone, Debug, Default)]
pub struct DependencyGraph {
    /// task id -> ids it depends on.
    dependencies: HashMap<String, HashSet<String>>,
    /// task id -> ids that depend on it.
    dependents: HashMap<String, HashSet<String>>,
    /// Every id mentioned so far, whether as a task or as a dependency.
    task_ids: HashSet<String>,
}

impl DependencyGraph {
    /// Creates a graph with no tasks and no edges.
    pub fn new() -> Self {
        Self::default()
    }

    /// Registers `task_id`, giving it an (initially empty) dependency set.
    pub fn add_task(&mut self, task_id: impl Into<String>) {
        let id = task_id.into();
        self.dependencies.entry(id.clone()).or_default();
        self.task_ids.insert(id);
    }

    /// Records that `task_id` depends on `dependency_id`.
    ///
    /// Both ids become known to the graph, even if `dependency_id` was never
    /// passed to [`DependencyGraph::add_task`].
    pub fn add_dependency(&mut self, task_id: impl Into<String>, dependency_id: impl Into<String>) {
        let dependent = task_id.into();
        let prerequisite = dependency_id.into();

        self.task_ids.insert(dependent.clone());
        self.task_ids.insert(prerequisite.clone());

        // Reverse edge first; the forward edge then takes ownership of both strings.
        self.dependents
            .entry(prerequisite.clone())
            .or_default()
            .insert(dependent.clone());
        self.dependencies
            .entry(dependent)
            .or_default()
            .insert(prerequisite);
    }

    /// Owned copy of the ids `task_id` depends on; empty for unknown ids.
    pub fn get_dependencies(&self, task_id: &str) -> HashSet<String> {
        match self.dependencies.get(task_id) {
            Some(deps) => deps.clone(),
            None => HashSet::new(),
        }
    }

    /// Owned copy of the ids that depend on `task_id`; empty for unknown ids.
    pub fn get_dependents(&self, task_id: &str) -> HashSet<String> {
        match self.dependents.get(task_id) {
            Some(users) => users.clone(),
            None => HashSet::new(),
        }
    }

    /// `true` when at least one dependency of `task_id` is missing from `completed`.
    pub fn has_unmet_dependencies(&self, task_id: &str, completed: &HashSet<String>) -> bool {
        match self.dependencies.get(task_id) {
            Some(deps) => !deps.is_subset(completed),
            None => false,
        }
    }

    /// Ids that are neither completed nor running and whose dependencies are all completed.
    ///
    /// The order of the returned ids follows the internal set and is unspecified.
    pub fn get_ready_tasks(
        &self,
        completed: &HashSet<String>,
        running: &HashSet<String>,
    ) -> Vec<String> {
        let mut ready = Vec::new();
        for id in &self.task_ids {
            if completed.contains(id) || running.contains(id) {
                continue;
            }
            if self.has_unmet_dependencies(id, completed) {
                continue;
            }
            ready.push(id.clone());
        }
        ready
    }

    /// `true` when every known id appears in `completed`.
    pub fn all_completed(&self, completed: &HashSet<String>) -> bool {
        self.task_ids.is_subset(completed)
    }

    /// Every id the graph knows about.
    pub fn get_all_tasks(&self) -> &HashSet<String> {
        &self.task_ids
    }

    /// Whether `task_id` is known to the graph.
    pub fn contains(&self, task_id: &str) -> bool {
        self.task_ids.contains(task_id)
    }
}
416
/// Report produced by checking a batch of tasks for dependency problems.
#[derive(Clone, Debug)]
pub struct ValidationResult {
    /// `true` when no problem was found.
    pub valid: bool,
    /// Human-readable description of each problem.
    pub errors: Vec<String>,
    /// The cycle that was found, if any.
    pub circular_dependencies: Option<Vec<String>>,
    /// `(task id, missing dependency id)` pairs.
    pub missing_dependencies: Vec<(String, String)>,
}

impl ValidationResult {
    /// A clean report: valid, with no errors, cycle, or missing dependencies.
    pub fn valid() -> Self {
        Self {
            valid: true,
            errors: Vec::new(),
            circular_dependencies: None,
            missing_dependencies: Vec::new(),
        }
    }

    /// A failing report carrying `errors` and nothing else.
    pub fn invalid(errors: Vec<String>) -> Self {
        Self {
            valid: false,
            errors,
            ..Self::valid()
        }
    }
}
451
452pub fn create_dependency_graph(tasks: &[AgentTask]) -> DependencyGraph {
454 let mut graph = DependencyGraph::new();
455
456 for task in tasks {
457 graph.add_task(&task.id);
458 if let Some(deps) = &task.dependencies {
459 for dep in deps {
460 graph.add_dependency(&task.id, dep);
461 }
462 }
463 }
464
465 graph
466}
467
468pub fn validate_task_dependencies(tasks: &[AgentTask]) -> ValidationResult {
470 let task_ids: HashSet<String> = tasks.iter().map(|t| t.id.clone()).collect();
471 let mut errors = Vec::new();
472 let mut missing_deps = Vec::new();
473
474 for task in tasks {
476 if let Some(deps) = &task.dependencies {
477 for dep in deps {
478 if !task_ids.contains(dep) {
479 errors.push(format!(
480 "Task '{}' depends on non-existent task '{}'",
481 task.id, dep
482 ));
483 missing_deps.push((task.id.clone(), dep.clone()));
484 }
485 }
486 }
487 }
488
489 let graph = create_dependency_graph(tasks);
491 if let Some(cycle) = detect_cycle(&graph) {
492 errors.push(format!("Circular dependency detected: {:?}", cycle));
493 return ValidationResult {
494 valid: false,
495 errors,
496 circular_dependencies: Some(cycle),
497 missing_dependencies: missing_deps,
498 };
499 }
500
501 if errors.is_empty() {
502 ValidationResult::valid()
503 } else {
504 ValidationResult {
505 valid: false,
506 errors,
507 circular_dependencies: None,
508 missing_dependencies: missing_deps,
509 }
510 }
511}
512
513fn detect_cycle(graph: &DependencyGraph) -> Option<Vec<String>> {
515 let mut visited = HashSet::new();
516 let mut rec_stack = HashSet::new();
517 let mut path = Vec::new();
518
519 for task_id in graph.get_all_tasks() {
520 if !visited.contains(task_id) {
521 if let Some(cycle) =
522 dfs_detect_cycle(graph, task_id, &mut visited, &mut rec_stack, &mut path)
523 {
524 return Some(cycle);
525 }
526 }
527 }
528
529 None
530}
531
532fn dfs_detect_cycle(
534 graph: &DependencyGraph,
535 task_id: &str,
536 visited: &mut HashSet<String>,
537 rec_stack: &mut HashSet<String>,
538 path: &mut Vec<String>,
539) -> Option<Vec<String>> {
540 visited.insert(task_id.to_string());
541 rec_stack.insert(task_id.to_string());
542 path.push(task_id.to_string());
543
544 for dep in graph.get_dependencies(task_id) {
545 if !visited.contains(&dep) {
546 if let Some(cycle) = dfs_detect_cycle(graph, &dep, visited, rec_stack, path) {
547 return Some(cycle);
548 }
549 } else if rec_stack.contains(&dep) {
550 let cycle_start = path.iter().position(|x| x == &dep).unwrap();
552 let mut cycle: Vec<String> = path[cycle_start..].to_vec();
553 cycle.push(dep);
554 return Some(cycle);
555 }
556 }
557
558 path.pop();
559 rec_stack.remove(task_id);
560 None
561}
562
563pub fn merge_agent_results(results: Vec<AgentResult>) -> MergedResult {
565 let outputs: Vec<Value> = results
566 .iter()
567 .filter(|r| r.success && r.result.is_some())
568 .map(|r| r.result.clone().unwrap())
569 .collect();
570
571 let successful = results.iter().filter(|r| r.success).count();
572 let failed = results.iter().filter(|r| !r.success).count();
573
574 let mut metadata = HashMap::new();
575 metadata.insert("total_tasks".to_string(), Value::from(results.len()));
576 metadata.insert("successful_tasks".to_string(), Value::from(successful));
577 metadata.insert("failed_tasks".to_string(), Value::from(failed));
578
579 let summary = if failed == 0 {
580 Some(format!("All {} tasks completed successfully", successful))
581 } else {
582 Some(format!(
583 "{} tasks succeeded, {} tasks failed",
584 successful, failed
585 ))
586 };
587
588 MergedResult {
589 outputs,
590 summary,
591 metadata,
592 }
593}
594
595pub struct ParallelAgentExecutor {
600 config: ParallelAgentConfig,
602 tasks: Arc<Mutex<HashMap<String, TaskExecutionInfo>>>,
604 running: Arc<Mutex<bool>>,
606 cancelled: Arc<Mutex<bool>>,
608}
609
610impl ParallelAgentExecutor {
611 pub fn new(config: Option<ParallelAgentConfig>) -> Self {
613 Self {
614 config: config.unwrap_or_default(),
615 tasks: Arc::new(Mutex::new(HashMap::new())),
616 running: Arc::new(Mutex::new(false)),
617 cancelled: Arc::new(Mutex::new(false)),
618 }
619 }
620
621 pub fn with_config(config: ParallelAgentConfig) -> Self {
623 Self::new(Some(config))
624 }
625
626 pub fn config(&self) -> &ParallelAgentConfig {
628 &self.config
629 }
630
631 pub async fn execute(
633 &mut self,
634 tasks: Vec<AgentTask>,
635 ) -> ExecutorResult<ParallelExecutionResult> {
636 let graph = create_dependency_graph(&tasks);
638 self.execute_with_graph(tasks, graph).await
639 }
640
641 pub async fn execute_with_dependencies(
643 &mut self,
644 tasks: Vec<AgentTask>,
645 ) -> ExecutorResult<ParallelExecutionResult> {
646 let validation = validate_task_dependencies(&tasks);
648 if !validation.valid {
649 if let Some(cycle) = validation.circular_dependencies {
650 return Err(ExecutorError::CircularDependency(cycle));
651 }
652 if let Some((task_id, dep)) = validation.missing_dependencies.first() {
653 return Err(ExecutorError::InvalidDependency {
654 task_id: task_id.clone(),
655 dependency: dep.clone(),
656 });
657 }
658 }
659
660 let graph = create_dependency_graph(&tasks);
661 self.execute_with_graph(tasks, graph).await
662 }
663
664 async fn execute_with_graph(
666 &mut self,
667 tasks: Vec<AgentTask>,
668 graph: DependencyGraph,
669 ) -> ExecutorResult<ParallelExecutionResult> {
670 let start_time = Utc::now();
671
672 {
674 let mut task_map = self.tasks.lock().await;
675 task_map.clear();
676 for task in &tasks {
677 task_map.insert(task.id.clone(), TaskExecutionInfo::new(task.clone()));
678 }
679 }
680
681 {
683 *self.running.lock().await = true;
684 *self.cancelled.lock().await = false;
685 }
686
687 let completed = Arc::new(Mutex::new(HashSet::<String>::new()));
689 let failed = Arc::new(Mutex::new(HashSet::<String>::new()));
690 let results = Arc::new(Mutex::new(Vec::<AgentResult>::new()));
691
692 let mut sorted_tasks = tasks.clone();
694 sorted_tasks.sort_by_key(|b| std::cmp::Reverse(b.effective_priority()));
695
696 let execution_result = self
698 .execute_tasks_with_deps(
699 sorted_tasks,
700 graph,
701 completed.clone(),
702 failed.clone(),
703 results.clone(),
704 )
705 .await;
706
707 *self.running.lock().await = false;
709
710 if let Err(_e) = execution_result {
712 let results_vec = results.lock().await.clone();
714 let end_time = Utc::now();
715 let duration = (end_time - start_time).to_std().unwrap_or(Duration::ZERO);
716
717 return Ok(ParallelExecutionResult {
718 success: false,
719 results: results_vec.clone(),
720 total_duration: duration,
721 successful_count: results_vec.iter().filter(|r| r.success).count(),
722 failed_count: results_vec.iter().filter(|r| !r.success).count(),
723 skipped_count: 0,
724 merged_result: Some(merge_agent_results(results_vec)),
725 });
726 }
727
728 let results_vec = results.lock().await.clone();
730 let end_time = Utc::now();
731 let duration = (end_time - start_time).to_std().unwrap_or(Duration::ZERO);
732
733 let successful_count = results_vec.iter().filter(|r| r.success).count();
734 let failed_count = results_vec.iter().filter(|r| !r.success).count();
735 let skipped_count = {
736 let task_map = self.tasks.lock().await;
737 task_map
738 .values()
739 .filter(|t| t.status == TaskStatus::Skipped)
740 .count()
741 };
742
743 Ok(ParallelExecutionResult {
744 success: failed_count == 0 && skipped_count == 0,
745 results: results_vec.clone(),
746 total_duration: duration,
747 successful_count,
748 failed_count,
749 skipped_count,
750 merged_result: Some(merge_agent_results(results_vec)),
751 })
752 }
753
754 async fn execute_tasks_with_deps(
756 &self,
757 tasks: Vec<AgentTask>,
758 graph: DependencyGraph,
759 completed: Arc<Mutex<HashSet<String>>>,
760 failed: Arc<Mutex<HashSet<String>>>,
761 results: Arc<Mutex<Vec<AgentResult>>>,
762 ) -> ExecutorResult<()> {
763 let task_map: HashMap<String, AgentTask> =
764 tasks.iter().map(|t| (t.id.clone(), t.clone())).collect();
765 let pending: Arc<Mutex<VecDeque<String>>> =
767 Arc::new(Mutex::new(tasks.iter().map(|t| t.id.clone()).collect()));
768 let running: Arc<Mutex<HashSet<String>>> = Arc::new(Mutex::new(HashSet::new()));
769
770 loop {
771 if *self.cancelled.lock().await {
773 return Err(ExecutorError::Cancelled);
774 }
775
776 let ready_tasks: Vec<String> = {
778 let completed_guard = completed.lock().await;
779 let running_guard = running.lock().await;
780 let mut pending_guard = pending.lock().await;
781
782 let mut ready = Vec::new();
783 let mut still_pending = VecDeque::new();
784
785 while let Some(task_id) = pending_guard.pop_front() {
786 if !graph.has_unmet_dependencies(&task_id, &completed_guard)
787 && !running_guard.contains(&task_id)
788 {
789 let failed_guard = failed.lock().await;
791 let deps = graph.get_dependencies(&task_id);
792 let has_failed_dep = deps.iter().any(|d| failed_guard.contains(d));
793 drop(failed_guard);
794
795 if has_failed_dep && self.config.stop_on_first_error {
796 let mut task_info = self.tasks.lock().await;
798 if let Some(info) = task_info.get_mut(&task_id) {
799 info.status = TaskStatus::Skipped;
800 }
801 continue;
802 }
803
804 ready.push(task_id);
805 } else {
806 still_pending.push_back(task_id);
807 }
808 }
809
810 *pending_guard = still_pending;
811 ready
812 };
813
814 {
816 let _completed_guard = completed.lock().await;
817 let running_guard = running.lock().await;
818 let pending_guard = pending.lock().await;
819
820 if pending_guard.is_empty() && running_guard.is_empty() && ready_tasks.is_empty() {
821 break;
822 }
823
824 if ready_tasks.is_empty() && running_guard.is_empty() && !pending_guard.is_empty() {
826 break;
828 }
829 }
830
831 let mut tasks_to_spawn = Vec::new();
834 let mut tasks_to_defer = Vec::new();
835
836 for (i, task_id) in ready_tasks.into_iter().enumerate() {
837 if i < self.config.max_concurrency {
838 tasks_to_spawn.push(task_id);
839 } else {
840 tasks_to_defer.push(task_id);
841 }
842 }
843
844 {
846 let mut pending_guard = pending.lock().await;
847 for task_id in tasks_to_defer.into_iter().rev() {
848 pending_guard.push_front(task_id);
849 }
850 }
851
852 let mut handles = Vec::new();
853 for task_id in tasks_to_spawn {
854 let task = match task_map.get(&task_id) {
855 Some(t) => t.clone(),
856 None => continue,
857 };
858
859 {
861 running.lock().await.insert(task_id.clone());
862 let mut task_info = self.tasks.lock().await;
863 if let Some(info) = task_info.get_mut(&task_id) {
864 info.status = TaskStatus::Running;
865 info.started_at = Some(Utc::now());
866 }
867 }
868
869 let completed = completed.clone();
870 let failed = failed.clone();
871 let running = running.clone();
872 let results = results.clone();
873 let tasks_info = self.tasks.clone();
874 let config = self.config.clone();
875 let cancelled = self.cancelled.clone();
876
877 let handle = tokio::spawn(async move {
878 let result = execute_single_task(&task, &config, &cancelled).await;
880
881 let task_id = task.id.clone();
883 {
884 let mut task_info = tasks_info.lock().await;
885 if let Some(info) = task_info.get_mut(&task_id) {
886 info.completed_at = Some(Utc::now());
887 if result.success {
888 info.status = TaskStatus::Completed;
889 info.result = result.result.clone();
890 } else {
891 info.status = TaskStatus::Failed;
892 info.last_error = result.error.clone();
893 }
894 info.retries = result.retries;
895 }
896 }
897
898 if result.success {
900 completed.lock().await.insert(task_id.clone());
901 } else {
902 failed.lock().await.insert(task_id.clone());
903 }
904
905 running.lock().await.remove(&task_id);
907
908 results.lock().await.push(result);
910 });
911
912 handles.push(handle);
913 }
914
915 if !handles.is_empty() {
917 for handle in handles {
919 let _ = handle.await;
920 }
921 } else {
922 tokio::time::sleep(Duration::from_millis(10)).await;
924 }
925
926 if self.config.stop_on_first_error {
928 let failed_guard = failed.lock().await;
929 if !failed_guard.is_empty() {
930 *self.cancelled.lock().await = true;
932 break;
933 }
934 }
935 }
936
937 Ok(())
938 }
939
940 pub async fn cancel(&mut self, task_id: Option<&str>) {
942 if let Some(id) = task_id {
943 let mut task_info = self.tasks.lock().await;
945 if let Some(info) = task_info.get_mut(id) {
946 info.status = TaskStatus::Cancelled;
947 }
948 } else {
949 *self.cancelled.lock().await = true;
951 }
952 }
953
954 pub async fn get_progress(&self) -> ExecutionProgress {
956 let task_info = self.tasks.lock().await;
957 let cancelled = *self.cancelled.lock().await;
958
959 let mut completed = 0;
960 let mut failed = 0;
961 let mut running = 0;
962 let mut pending = 0;
963 let mut skipped = 0;
964
965 for info in task_info.values() {
966 match info.status {
967 TaskStatus::Completed => completed += 1,
968 TaskStatus::Failed => failed += 1,
969 TaskStatus::Running => running += 1,
970 TaskStatus::Pending | TaskStatus::WaitingForDependencies => pending += 1,
971 TaskStatus::Cancelled | TaskStatus::Skipped => skipped += 1,
972 }
973 }
974
975 ExecutionProgress {
976 total: task_info.len(),
977 completed,
978 failed,
979 running,
980 pending,
981 skipped,
982 cancelled,
983 }
984 }
985
986 pub async fn is_running(&self) -> bool {
988 *self.running.lock().await
989 }
990
991 pub async fn is_cancelled(&self) -> bool {
993 *self.cancelled.lock().await
994 }
995}
996
997async fn execute_single_task(
999 task: &AgentTask,
1000 config: &ParallelAgentConfig,
1001 cancelled: &Arc<Mutex<bool>>,
1002) -> AgentResult {
1003 let start_time = Utc::now();
1004 let task_timeout = task.timeout.unwrap_or(config.timeout);
1005 let max_retries = if config.retry_on_failure {
1006 config.max_retries
1007 } else {
1008 0
1009 };
1010
1011 let mut retries = 0;
1012 #[allow(unused_assignments)]
1013 let mut last_error = None;
1014
1015 loop {
1016 if *cancelled.lock().await {
1018 return AgentResult {
1019 task_id: task.id.clone(),
1020 success: false,
1021 result: None,
1022 error: Some("Cancelled".to_string()),
1023 duration: (Utc::now() - start_time).to_std().unwrap_or(Duration::ZERO),
1024 retries,
1025 started_at: start_time,
1026 completed_at: Utc::now(),
1027 };
1028 }
1029
1030 let execution = timeout(task_timeout, simulate_task_execution(task));
1032
1033 match execution.await {
1034 Ok(Ok(result)) => {
1035 return AgentResult {
1036 task_id: task.id.clone(),
1037 success: true,
1038 result: Some(result),
1039 error: None,
1040 duration: (Utc::now() - start_time).to_std().unwrap_or(Duration::ZERO),
1041 retries,
1042 started_at: start_time,
1043 completed_at: Utc::now(),
1044 };
1045 }
1046 Ok(Err(e)) => {
1047 last_error = Some(e.to_string());
1048 }
1049 Err(_) => {
1050 last_error = Some(format!("Task timed out after {:?}", task_timeout));
1051 }
1052 }
1053
1054 if retries >= max_retries {
1056 break;
1057 }
1058
1059 retries += 1;
1060 tokio::time::sleep(config.retry_delay).await;
1061 }
1062
1063 AgentResult {
1064 task_id: task.id.clone(),
1065 success: false,
1066 result: None,
1067 error: last_error,
1068 duration: (Utc::now() - start_time).to_std().unwrap_or(Duration::ZERO),
1069 retries,
1070 started_at: start_time,
1071 completed_at: Utc::now(),
1072 }
1073}
1074
1075async fn simulate_task_execution(task: &AgentTask) -> Result<Value, String> {
1077 tokio::time::sleep(Duration::from_millis(10)).await;
1082
1083 Ok(serde_json::json!({
1085 "task_id": task.id,
1086 "task_type": task.task_type,
1087 "status": "completed",
1088 "output": format!("Executed task: {}", task.prompt)
1089 }))
1090}
1091
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    /// Configuration shared by the executor tests: no retries, 10 s timeout.
    fn no_retry_config(max_concurrency: usize) -> ParallelAgentConfig {
        ParallelAgentConfig {
            max_concurrency,
            timeout: Duration::from_secs(10),
            retry_on_failure: false,
            stop_on_first_error: false,
            max_retries: 0,
            retry_delay: Duration::from_millis(100),
        }
    }

    /// Plain task of type "test" with the given id and prompt.
    fn task(id: &str, prompt: &str) -> AgentTask {
        AgentTask::new(id, "test", prompt)
    }

    /// Task of type "test" that depends on `deps`.
    fn task_after(id: &str, prompt: &str, deps: &[&str]) -> AgentTask {
        task(id, prompt).with_dependencies(deps.iter().map(|dep| dep.to_string()).collect())
    }

    /// Result fixture; `Ok` becomes a success, `Err` a failure.
    fn agent_result(
        task_id: &str,
        outcome: Result<Value, &str>,
        secs: u64,
        retries: usize,
    ) -> AgentResult {
        let success = outcome.is_ok();
        let (result, error) = match outcome {
            Ok(value) => (Some(value), None),
            Err(message) => (None, Some(message.to_string())),
        };

        AgentResult {
            task_id: task_id.to_string(),
            success,
            result,
            error,
            duration: Duration::from_secs(secs),
            retries,
            started_at: Utc::now(),
            completed_at: Utc::now(),
        }
    }

    #[test]
    fn test_agent_task_creation() {
        let created = AgentTask::new("task-1", "explore", "Find all Rust files");

        assert_eq!(created.id, "task-1");
        assert_eq!(created.task_type, "explore");
        assert_eq!(created.prompt, "Find all Rust files");
        assert!(created.description.is_none());
        assert!(created.options.is_none());
        assert!(created.priority.is_none());
        assert!(created.dependencies.is_none());
        assert!(created.timeout.is_none());
    }

    #[test]
    fn test_agent_task_builder() {
        let built = AgentTask::new("task-1", "plan", "Create implementation plan")
            .with_description("Detailed planning task")
            .with_priority(5)
            .with_dependencies(vec!["task-0".to_string()])
            .with_timeout(Duration::from_secs(60));

        assert_eq!(built.description, Some("Detailed planning task".to_string()));
        assert_eq!(built.priority, Some(5));
        assert_eq!(built.dependencies, Some(vec!["task-0".to_string()]));
        assert_eq!(built.timeout, Some(Duration::from_secs(60)));
    }

    #[test]
    fn test_task_effective_priority() {
        let unprioritised = task("t1", "test");
        assert_eq!(unprioritised.effective_priority(), 0);

        let prioritised = task("t2", "test").with_priority(10);
        assert_eq!(prioritised.effective_priority(), 10);
    }

    #[test]
    fn test_task_has_dependencies() {
        let without_list = task("t1", "test");
        assert!(!without_list.has_dependencies());

        let with_empty_list = task("t2", "test").with_dependencies(vec![]);
        assert!(!with_empty_list.has_dependencies());

        let with_one_dep = task_after("t3", "test", &["t1"]);
        assert!(with_one_dep.has_dependencies());
    }

    #[test]
    fn test_dependency_graph_creation() {
        let mut graph = DependencyGraph::new();
        graph.add_task("task-1");
        graph.add_task("task-2");
        graph.add_dependency("task-2", "task-1");

        assert!(graph.contains("task-1"));
        assert!(graph.contains("task-2"));
        assert!(!graph.contains("task-3"));

        assert!(graph.get_dependencies("task-2").contains("task-1"));
        assert!(graph.get_dependents("task-1").contains("task-2"));
    }

    #[test]
    fn test_dependency_graph_ready_tasks() {
        let mut graph = DependencyGraph::new();
        for id in ["task-1", "task-2", "task-3"] {
            graph.add_task(id);
        }
        graph.add_dependency("task-2", "task-1");
        graph.add_dependency("task-3", "task-2");

        let running: HashSet<String> = HashSet::new();
        let mut completed: HashSet<String> = HashSet::new();

        // Nothing done yet: only the root is ready.
        let ready = graph.get_ready_tasks(&completed, &running);
        assert_eq!(ready.len(), 1);
        assert!(ready.contains(&"task-1".to_string()));

        // Root done: the middle task unlocks.
        completed.insert("task-1".to_string());
        let ready = graph.get_ready_tasks(&completed, &running);
        assert_eq!(ready.len(), 1);
        assert!(ready.contains(&"task-2".to_string()));

        // Middle done: the leaf unlocks.
        completed.insert("task-2".to_string());
        let ready = graph.get_ready_tasks(&completed, &running);
        assert_eq!(ready.len(), 1);
        assert!(ready.contains(&"task-3".to_string()));
    }

    #[test]
    fn test_create_dependency_graph_from_tasks() {
        let tasks = vec![
            task("task-1", "test"),
            task_after("task-2", "test", &["task-1"]),
            task_after("task-3", "test", &["task-1", "task-2"]),
        ];

        let graph = create_dependency_graph(&tasks);

        for id in ["task-1", "task-2", "task-3"] {
            assert!(graph.contains(id));
        }

        assert!(graph.get_dependencies("task-1").is_empty());
        assert_eq!(graph.get_dependencies("task-2").len(), 1);
        assert_eq!(graph.get_dependencies("task-3").len(), 2);
    }

    #[test]
    fn test_validate_valid_dependencies() {
        let tasks = vec![
            task("task-1", "test"),
            task_after("task-2", "test", &["task-1"]),
        ];

        let report = validate_task_dependencies(&tasks);
        assert!(report.valid);
        assert!(report.errors.is_empty());
    }

    #[test]
    fn test_validate_missing_dependency() {
        let tasks = vec![task_after("task-1", "test", &["non-existent"])];

        let report = validate_task_dependencies(&tasks);
        assert!(!report.valid);
        assert!(!report.errors.is_empty());
        assert_eq!(report.missing_dependencies.len(), 1);
    }

    #[test]
    fn test_validate_circular_dependency() {
        let tasks = vec![
            task_after("task-1", "test", &["task-2"]),
            task_after("task-2", "test", &["task-1"]),
        ];

        let report = validate_task_dependencies(&tasks);
        assert!(!report.valid);
        assert!(report.circular_dependencies.is_some());
    }

    #[test]
    fn test_validate_self_dependency() {
        let tasks = vec![task_after("task-1", "test", &["task-1"])];

        let report = validate_task_dependencies(&tasks);
        assert!(!report.valid);
        assert!(report.circular_dependencies.is_some());
    }

    #[test]
    fn test_merge_agent_results() {
        let results = vec![
            agent_result("task-1", Ok(json!({"output": "result1"})), 1, 0),
            agent_result("task-2", Ok(json!({"output": "result2"})), 2, 0),
            agent_result("task-3", Err("Failed"), 1, 3),
        ];

        let merged = merge_agent_results(results);

        assert_eq!(merged.outputs.len(), 2);
        assert!(merged.summary.is_some());
        assert_eq!(merged.metadata.get("total_tasks"), Some(&json!(3)));
        assert_eq!(merged.metadata.get("successful_tasks"), Some(&json!(2)));
        assert_eq!(merged.metadata.get("failed_tasks"), Some(&json!(1)));
    }

    #[test]
    fn test_parallel_config_default() {
        let defaults = ParallelAgentConfig::default();

        assert_eq!(defaults.max_concurrency, 4);
        assert_eq!(defaults.timeout, Duration::from_secs(300));
        assert!(defaults.retry_on_failure);
        assert!(!defaults.stop_on_first_error);
        assert_eq!(defaults.max_retries, 3);
        assert_eq!(defaults.retry_delay, Duration::from_secs(1));
    }

    #[tokio::test]
    async fn test_executor_creation() {
        let executor = ParallelAgentExecutor::new(None);
        assert!(!executor.is_running().await);
        assert!(!executor.is_cancelled().await);
    }

    #[tokio::test]
    async fn test_executor_simple_execution() {
        let mut executor = ParallelAgentExecutor::new(Some(no_retry_config(2)));

        let tasks = vec![task("task-1", "Test task 1"), task("task-2", "Test task 2")];

        let outcome = executor.execute(tasks).await.unwrap();

        assert!(outcome.success);
        assert_eq!(outcome.results.len(), 2);
        assert_eq!(outcome.successful_count, 2);
        assert_eq!(outcome.failed_count, 0);
    }

    #[tokio::test]
    async fn test_executor_with_dependencies() {
        let mut executor = ParallelAgentExecutor::new(Some(no_retry_config(2)));

        let tasks = vec![
            task("task-1", "First task"),
            task_after("task-2", "Second task", &["task-1"]),
            task_after("task-3", "Third task", &["task-2"]),
        ];

        let outcome = executor.execute_with_dependencies(tasks).await.unwrap();

        assert!(outcome.success);
        assert_eq!(outcome.results.len(), 3);
        assert_eq!(outcome.successful_count, 3);
    }

    #[tokio::test]
    async fn test_executor_circular_dependency_error() {
        let mut executor = ParallelAgentExecutor::new(None);

        let tasks = vec![
            task_after("task-1", "test", &["task-2"]),
            task_after("task-2", "test", &["task-1"]),
        ];

        let outcome = executor.execute_with_dependencies(tasks).await;

        assert!(matches!(outcome, Err(ExecutorError::CircularDependency(_))));
    }

    #[tokio::test]
    async fn test_executor_invalid_dependency_error() {
        let mut executor = ParallelAgentExecutor::new(None);

        let tasks = vec![task_after("task-1", "test", &["non-existent"])];

        let outcome = executor.execute_with_dependencies(tasks).await;

        assert!(matches!(
            outcome,
            Err(ExecutorError::InvalidDependency { .. })
        ));
    }

    #[tokio::test]
    async fn test_executor_progress() {
        let executor = ParallelAgentExecutor::new(None);

        let snapshot = executor.get_progress().await;

        assert_eq!(snapshot.total, 0);
        assert_eq!(snapshot.completed, 0);
        assert_eq!(snapshot.failed, 0);
        assert_eq!(snapshot.running, 0);
        assert_eq!(snapshot.pending, 0);
        assert!(!snapshot.cancelled);
    }

    #[tokio::test]
    async fn test_executor_concurrency_limit() {
        // One task per round.
        let mut executor = ParallelAgentExecutor::new(Some(no_retry_config(1)));

        let tasks = vec![
            task("task-1", "Test 1"),
            task("task-2", "Test 2"),
            task("task-3", "Test 3"),
        ];

        let outcome = executor.execute(tasks).await.unwrap();

        assert!(outcome.success);
        assert_eq!(outcome.results.len(), 3);
    }

    #[tokio::test]
    async fn test_executor_priority_ordering() {
        // One task per round makes the completion order observable.
        let mut executor = ParallelAgentExecutor::new(Some(no_retry_config(1)));

        let tasks = vec![
            task("low", "Low priority").with_priority(1),
            task("high", "High priority").with_priority(10),
            task("medium", "Medium priority").with_priority(5),
        ];

        let outcome = executor.execute(tasks).await.unwrap();

        assert!(outcome.success);
        assert_eq!(outcome.results[0].task_id, "high");
        assert_eq!(outcome.results[1].task_id, "medium");
        assert_eq!(outcome.results[2].task_id, "low");
    }
}