potato_workflow/workflow/
tasklist.rs1use crate::workflow::error::WorkflowError;
2
3pub use potato_agent::agents::{
4 agent::Agent,
5 task::{PyTask, Task, TaskStatus},
6};
7use potato_agent::AgentResponse;
8use pyo3::prelude::*;
9use serde::Deserialize;
10use serde::Serialize;
11use std::collections::{HashMap, HashSet};
12use std::sync::Arc;
13use std::sync::RwLock;
14use tracing::instrument;
15use tracing::{debug, warn};
16
/// A collection of workflow tasks keyed by task id, plus an execution order.
///
/// Each task is stored behind `Arc<RwLock<_>>`, so handles returned by
/// `get_task` share state with the list and status updates happen in place.
///
/// Note: the derived `Clone` copies the `Arc` handles, so a clone shares task
/// state with the original; use `deep_clone` for an independent copy.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
#[pyclass]
pub struct TaskList {
    /// Tasks keyed by their `id` (as inserted by `add_task`).
    pub tasks: HashMap<String, Arc<RwLock<Task>>>,
    /// Task ids in the order computed by `add_task`: a task's dependencies
    /// appear before it.
    pub execution_order: Vec<String>,
}
22}
23
24impl PartialEq for TaskList {
25 fn eq(&self, other: &Self) -> bool {
26 self.tasks.keys().eq(other.tasks.keys()) && self.execution_order == other.execution_order
28 }
29}
30
31#[pymethods]
32impl TaskList {
33 pub fn tasks(&self) -> HashMap<String, Task> {
35 self.tasks
36 .iter()
37 .map(|(id, task)| {
38 let cloned_task = task.read().unwrap().clone();
39 (id.clone(), cloned_task)
40 })
41 .collect()
42 }
43
44 pub fn deep_clone(&self) -> Result<Self, WorkflowError> {
46 let mut new_task_list = TaskList::new();
47
48 for (task_id, task_arc) in &self.tasks {
50 let task = task_arc.read().unwrap();
51 let cloned_task = task.clone(); new_task_list
53 .tasks
54 .insert(task_id.clone(), Arc::new(RwLock::new(cloned_task)));
55 }
56
57 new_task_list.execution_order = self.execution_order.clone();
59
60 Ok(new_task_list)
61 }
62}
63
64impl TaskList {
65 pub fn new() -> Self {
66 Self {
67 tasks: HashMap::new(),
68 execution_order: Vec::new(),
69 }
70 }
71
72 pub fn len(&self) -> usize {
73 self.tasks.len()
74 }
75
76 pub fn is_empty(&self) -> bool {
77 self.tasks.is_empty()
78 }
79
80 pub fn is_complete(&self) -> bool {
81 self.tasks.values().all(|task| {
82 task.read().unwrap().status == TaskStatus::Completed
83 || task.read().unwrap().status == TaskStatus::Failed
84 })
85 }
86
87 pub fn add_task(&mut self, task: Task) -> Result<(), WorkflowError> {
88 if self.tasks.contains_key(&task.id) {
90 return Err(WorkflowError::TaskAlreadyExists(task.id.clone()));
91 }
92
93 for dep_id in &task.dependencies {
95 if !self.tasks.contains_key(dep_id) {
96 return Err(WorkflowError::DependencyNotFound(dep_id.clone()));
97 }
98
99 if dep_id == &task.id {
101 return Err(WorkflowError::TaskDependsOnItself(task.id.clone()));
102 }
103 }
104
105 self.tasks
107 .insert(task.id.clone(), Arc::new(RwLock::new(task)));
108 self.rebuild_execution_order();
109 Ok(())
110 }
111
112 pub fn get_task(&self, task_id: &str) -> Option<Arc<RwLock<Task>>> {
113 self.tasks.get(task_id).cloned()
114 }
115
116 pub fn remove_task(&mut self, task_id: &str) {
117 self.tasks.remove(task_id);
118 }
119
120 pub fn pending_count(&self) -> usize {
121 self.tasks
122 .values()
123 .filter(|task| task.read().unwrap().status == TaskStatus::Pending)
124 .count()
125 }
126
127 #[instrument(skip_all)]
128 pub fn update_task_status(
129 &mut self,
130 task_id: &str,
131 status: TaskStatus,
132 result: Option<&AgentResponse>,
133 ) {
134 debug!(status=?status, result=?result, "Updating task status");
135 if let Some(task) = self.tasks.get_mut(task_id) {
136 let mut task = task.write().unwrap();
137 task.status = status;
138 task.result = result.cloned();
139 }
140 }
141
142 fn topological_sort(
143 &self,
144 task_id: &str,
145 visited: &mut HashSet<String>,
146 temp_visited: &mut HashSet<String>,
147 order: &mut Vec<String>,
148 ) {
149 if temp_visited.contains(task_id) {
150 return; }
152
153 if visited.contains(task_id) {
154 return;
155 }
156
157 temp_visited.insert(task_id.to_string());
158
159 if let Some(task) = self.tasks.get(task_id) {
160 for dep_id in &task.read().unwrap().dependencies {
161 self.topological_sort(dep_id, visited, temp_visited, order);
162 }
163 }
164
165 temp_visited.remove(task_id);
166 visited.insert(task_id.to_string());
167 order.push(task_id.to_string());
168 }
169
170 fn rebuild_execution_order(&mut self) {
171 let mut order = Vec::new();
172 let mut visited = HashSet::new();
173 let mut temp_visited = HashSet::new();
174
175 for task_id in self.tasks.keys() {
176 if !visited.contains(task_id) {
177 self.topological_sort(task_id, &mut visited, &mut temp_visited, &mut order);
178 }
179 }
180
181 self.execution_order = order;
182 }
183
184 pub fn get_ready_tasks(&self) -> Vec<Arc<RwLock<Task>>> {
189 self.tasks
190 .values()
191 .filter(|task_arc| {
192 let task = task_arc.read().unwrap();
193 task.status == TaskStatus::Pending
194 && task.dependencies.iter().all(|dep_id| {
195 self.tasks
196 .get(dep_id)
197 .map(|dep| dep.read().unwrap().status == TaskStatus::Completed)
198 .unwrap_or(false)
199 })
200 })
201 .cloned()
202 .collect()
203 }
204
205 pub fn reset_failed_tasks(&mut self) -> Result<(), WorkflowError> {
206 for task in self.tasks.values_mut() {
207 let mut task = task.write().unwrap();
208 if task.status == TaskStatus::Failed {
209 task.status = TaskStatus::Pending;
210 task.increment_retry();
211 if task.retry_count > task.max_retries {
212 return Err(WorkflowError::MaxRetriesExceeded(task.id.clone()));
213 }
214 }
215 }
216 Ok(())
217 }
218}