potato_workflow/workflow/
tasklist.rs1use crate::workflow::error::WorkflowError;
2
3pub use potato_agent::agents::{
4 agent::Agent,
5 task::{Task, TaskStatus, WorkflowTask},
6};
7use potato_agent::AgentResponse;
8use pyo3::prelude::*;
9use serde::Deserialize;
10use serde::Serialize;
11use serde_json::Value;
12use std::collections::{HashMap, HashSet};
13use std::sync::Arc;
14use std::sync::RwLock;
15use tracing::instrument;
16use tracing::{debug, warn};
17
/// Collection of workflow tasks keyed by task id, together with the order in which they
/// should be executed.
///
/// Each task is stored behind `Arc<RwLock<_>>`, so the derived `Clone` is shallow: a cloned
/// `TaskList` shares the same underlying task objects with the original. Use `deep_clone`
/// to obtain an independent copy.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
#[pyclass]
pub struct TaskList {
    /// Tasks keyed by their id.
    pub tasks: HashMap<String, Arc<RwLock<Task>>>,
    /// Task ids in execution order. When maintained by `add_task`, each task appears after
    /// its dependencies. The field is public, so callers can also overwrite it directly.
    pub execution_order: Vec<String>,
}
23}
24
25impl PartialEq for TaskList {
26 fn eq(&self, other: &Self) -> bool {
27 self.tasks.keys().eq(other.tasks.keys()) && self.execution_order == other.execution_order
29 }
30}
31
32#[pymethods]
33impl TaskList {
34 #[getter]
35 pub fn items(&self) -> HashMap<String, Task> {
36 self.tasks()
37 }
38 pub fn tasks(&self) -> HashMap<String, Task> {
40 self.tasks
41 .iter()
42 .map(|(id, task)| {
43 let cloned_task = task.read().unwrap().clone();
44 (id.clone(), cloned_task)
45 })
46 .collect()
47 }
48
49 pub fn deep_clone(&self) -> Result<Self, WorkflowError> {
51 let mut new_task_list = TaskList::new();
52
53 for (task_id, task_arc) in &self.tasks {
55 let task = task_arc.read().unwrap();
56 let cloned_task = task.clone(); new_task_list
58 .tasks
59 .insert(task_id.clone(), Arc::new(RwLock::new(cloned_task)));
60 }
61
62 new_task_list.execution_order = self.execution_order.clone();
64
65 Ok(new_task_list)
66 }
67}
68
69impl TaskList {
70 pub fn new() -> Self {
71 Self {
72 tasks: HashMap::new(),
73 execution_order: Vec::new(),
74 }
75 }
76
77 pub fn len(&self) -> usize {
78 self.tasks.len()
79 }
80
81 pub fn is_empty(&self) -> bool {
82 self.tasks.is_empty()
83 }
84
85 pub fn rebuild_task_validators(&mut self) -> Result<(), WorkflowError> {
86 for task_arc in self.tasks.values_mut() {
87 let mut task = task_arc.write().unwrap();
88 task.rebuild_validator()?
89 }
90 Ok(())
91 }
92
93 pub fn is_complete(&self) -> bool {
94 self.tasks.values().all(|task| {
95 task.read().unwrap().status == TaskStatus::Completed
96 || task.read().unwrap().status == TaskStatus::Failed
97 })
98 }
99
100 pub fn add_task(&mut self, task: Task) -> Result<(), WorkflowError> {
101 if self.tasks.contains_key(&task.id) {
103 return Err(WorkflowError::TaskAlreadyExists(task.id.clone()));
104 }
105
106 for dep_id in &task.dependencies {
108 if !self.tasks.contains_key(dep_id) {
109 return Err(WorkflowError::DependencyNotFound(dep_id.clone()));
110 }
111
112 if dep_id == &task.id {
114 return Err(WorkflowError::TaskDependsOnItself(task.id.clone()));
115 }
116 }
117
118 self.tasks
120 .insert(task.id.clone(), Arc::new(RwLock::new(task)));
121 self.rebuild_execution_order();
122 Ok(())
123 }
124
125 pub fn get_task(&self, task_id: &str) -> Option<Arc<RwLock<Task>>> {
126 self.tasks.get(task_id).cloned()
127 }
128
129 pub fn get_task_responses(&self) -> Result<HashMap<String, Value>, WorkflowError> {
131 let mut responses = HashMap::new();
132
133 for (task_id, task_arc) in &self.tasks {
134 let task = task_arc
135 .read()
136 .map_err(|_| WorkflowError::ReadLockAcquireError)?;
137 if let Some(result) = &task.result {
138 if let Some(value) = result.response_value() {
139 responses.insert(task_id.clone(), value);
140 } else {
141 responses.insert(task_id.clone(), Value::Null);
143 }
144 }
145 }
146
147 Ok(responses)
148 }
149
150 pub fn remove_task(&mut self, task_id: &str) {
151 self.tasks.remove(task_id);
152 }
153
154 pub fn pending_count(&self) -> usize {
155 self.tasks
156 .values()
157 .filter(|task| task.read().unwrap().status == TaskStatus::Pending)
158 .count()
159 }
160
161 #[instrument(skip_all)]
162 pub fn update_task_status(
163 &mut self,
164 task_id: &str,
165 status: TaskStatus,
166 result: Option<&AgentResponse>,
167 ) {
168 debug!(status=?status, result=?result, "Updating task status");
169 if let Some(task) = self.tasks.get_mut(task_id) {
170 let mut task = task.write().unwrap();
171 task.status = status;
172 task.result = result.cloned();
173 }
174 }
175
176 fn topological_sort(
177 &self,
178 task_id: &str,
179 visited: &mut HashSet<String>,
180 temp_visited: &mut HashSet<String>,
181 order: &mut Vec<String>,
182 ) {
183 if temp_visited.contains(task_id) {
184 return; }
186
187 if visited.contains(task_id) {
188 return;
189 }
190
191 temp_visited.insert(task_id.to_string());
192
193 if let Some(task) = self.tasks.get(task_id) {
194 for dep_id in &task.read().unwrap().dependencies {
195 self.topological_sort(dep_id, visited, temp_visited, order);
196 }
197 }
198
199 temp_visited.remove(task_id);
200 visited.insert(task_id.to_string());
201 order.push(task_id.to_string());
202 }
203
204 fn rebuild_execution_order(&mut self) {
205 let mut order = Vec::new();
206 let mut visited = HashSet::new();
207 let mut temp_visited = HashSet::new();
208
209 for task_id in self.tasks.keys() {
210 if !visited.contains(task_id) {
211 self.topological_sort(task_id, &mut visited, &mut temp_visited, &mut order);
212 }
213 }
214
215 self.execution_order = order;
216 }
217
218 pub fn get_ready_tasks(&self) -> Vec<Arc<RwLock<Task>>> {
223 self.tasks
224 .values()
225 .filter(|task_arc| {
226 let task = task_arc.read().unwrap();
227 task.status == TaskStatus::Pending
228 && task.dependencies.iter().all(|dep_id| {
229 self.tasks
230 .get(dep_id)
231 .map(|dep| dep.read().unwrap().status == TaskStatus::Completed)
232 .unwrap_or(false)
233 })
234 })
235 .cloned()
236 .collect()
237 }
238
239 pub fn reset_failed_tasks(&mut self) -> Result<(), WorkflowError> {
240 for task in self.tasks.values_mut() {
241 let mut task = task.write().unwrap();
242 if task.status == TaskStatus::Failed {
243 task.status = TaskStatus::Pending;
244 task.increment_retry();
245 if task.retry_count > task.max_retries {
246 return Err(WorkflowError::MaxRetriesExceeded(task.id.clone()));
247 }
248 }
249 }
250 Ok(())
251 }
252
253 pub fn get_last_task_id(&self) -> Option<String> {
254 self.execution_order.last().cloned()
255 }
256}