potato_workflow/workflow/
tasklist.rs1use crate::workflow::error::WorkflowError;
2
3pub use potato_agents::agents::{
4 agent::Agent,
5 task::{PyTask, Task, TaskStatus},
6 types::ChatResponse,
7};
8
9use pyo3::prelude::*;
10use serde::Deserialize;
11use serde::Serialize;
12use std::collections::{HashMap, HashSet};
13use tracing::instrument;
14use tracing::{debug, warn};
15
16#[derive(Debug, Clone, Default, Serialize, Deserialize)]
17#[pyclass]
18pub struct TaskList {
19 #[pyo3(get)]
20 pub tasks: HashMap<String, Task>,
21 pub execution_order: Vec<String>,
22}
23
24impl TaskList {
25 pub fn new() -> Self {
26 Self {
27 tasks: HashMap::new(),
28 execution_order: Vec::new(),
29 }
30 }
31
32 pub fn is_complete(&self) -> bool {
33 self.tasks
34 .values()
35 .all(|task| task.status == TaskStatus::Completed || task.status == TaskStatus::Failed)
36 }
37
38 pub fn add_task(&mut self, task: Task) {
39 self.tasks.insert(task.id.clone(), task);
40 self.rebuild_execution_order();
41 }
42
43 pub fn get_task(&self, task_id: &str) -> Option<&Task> {
44 self.tasks.get(task_id)
45 }
46
47 pub fn remove_task(&mut self, task_id: &str) {
48 self.tasks.remove(task_id);
49 }
50
51 pub fn pending_count(&self) -> usize {
52 self.tasks
53 .values()
54 .filter(|task| task.status == TaskStatus::Pending)
55 .count()
56 }
57
58 #[instrument(skip_all)]
59 pub fn update_task_status(
60 &mut self,
61 task_id: &str,
62 status: TaskStatus,
63 result: Option<ChatResponse>,
64 ) {
65 debug!(status=?status, result=?result, "Updating task status");
66 if let Some(task) = self.tasks.get_mut(task_id) {
67 task.status = status;
68 task.result = result;
69 }
70 }
71
72 fn topological_sort(
73 &self,
74 task_id: &str,
75 visited: &mut HashSet<String>,
76 temp_visited: &mut HashSet<String>,
77 order: &mut Vec<String>,
78 ) {
79 if temp_visited.contains(task_id) {
80 return; }
82
83 if visited.contains(task_id) {
84 return;
85 }
86
87 temp_visited.insert(task_id.to_string());
88
89 if let Some(task) = self.tasks.get(task_id) {
90 for dep_id in &task.dependencies {
91 self.topological_sort(dep_id, visited, temp_visited, order);
92 }
93 }
94
95 temp_visited.remove(task_id);
96 visited.insert(task_id.to_string());
97 order.push(task_id.to_string());
98 }
99
100 fn rebuild_execution_order(&mut self) {
101 let mut order = Vec::new();
102 let mut visited = HashSet::new();
103 let mut temp_visited = HashSet::new();
104
105 for task_id in self.tasks.keys() {
106 if !visited.contains(task_id) {
107 self.topological_sort(task_id, &mut visited, &mut temp_visited, &mut order);
108 }
109 }
110
111 self.execution_order = order;
112 }
113
114 pub fn get_ready_tasks(&self) -> Vec<Task> {
119 self.tasks
120 .values()
121 .filter(|task| {
122 task.status == TaskStatus::Pending
123 && task.dependencies.iter().all(|dep_id| {
124 self.tasks
125 .get(dep_id)
126 .map(|dep| dep.status == TaskStatus::Completed)
127 .unwrap_or(false)
128 })
129 })
130 .cloned()
131 .collect()
132 }
133
134 pub fn reset_failed_tasks(&mut self) -> Result<(), WorkflowError> {
135 for task in self.tasks.values_mut() {
136 if task.status == TaskStatus::Failed {
137 task.status = TaskStatus::Pending;
138 task.increment_retry();
139 if task.retry_count > task.max_retries {
140 return Err(WorkflowError::MaxRetriesExceeded(task.id.clone()));
141 }
142 }
143 }
144 Ok(())
145 }
146}