potato_workflow/workflow/
tasklist.rs1use crate::workflow::error::WorkflowError;
2
3pub use potato_agent::agents::{
4 agent::Agent,
5 task::{PyTask, Task, TaskStatus},
6 types::ChatResponse,
7};
8use potato_agent::AgentResponse;
9use pyo3::prelude::*;
10use serde::Deserialize;
11use serde::Serialize;
12use std::collections::{HashMap, HashSet};
13use std::sync::Arc;
14use std::sync::RwLock;
15use tracing::instrument;
16use tracing::{debug, warn};
17
18#[derive(Debug, Clone, Default, Serialize, Deserialize)]
19#[pyclass]
20pub struct TaskList {
21 pub tasks: HashMap<String, Arc<RwLock<Task>>>,
22 pub execution_order: Vec<String>,
23}
24
25impl PartialEq for TaskList {
26 fn eq(&self, other: &Self) -> bool {
27 self.tasks.keys().eq(other.tasks.keys()) && self.execution_order == other.execution_order
29 }
30}
31
32#[pymethods]
33impl TaskList {
34 pub fn tasks(&self) -> HashMap<String, Task> {
36 self.tasks
37 .iter()
38 .map(|(id, task)| {
39 let cloned_task = task.read().unwrap().clone();
40 (id.clone(), cloned_task)
41 })
42 .collect()
43 }
44
45 pub fn deep_clone(&self) -> Result<Self, WorkflowError> {
47 let mut new_task_list = TaskList::new();
48
49 for (task_id, task_arc) in &self.tasks {
51 let task = task_arc.read().unwrap();
52 let cloned_task = task.clone(); new_task_list
54 .tasks
55 .insert(task_id.clone(), Arc::new(RwLock::new(cloned_task)));
56 }
57
58 new_task_list.execution_order = self.execution_order.clone();
60
61 Ok(new_task_list)
62 }
63}
64
65impl TaskList {
66 pub fn new() -> Self {
67 Self {
68 tasks: HashMap::new(),
69 execution_order: Vec::new(),
70 }
71 }
72
73 pub fn len(&self) -> usize {
74 self.tasks.len()
75 }
76
77 pub fn is_empty(&self) -> bool {
78 self.tasks.is_empty()
79 }
80
81 pub fn is_complete(&self) -> bool {
82 self.tasks.values().all(|task| {
83 task.read().unwrap().status == TaskStatus::Completed
84 || task.read().unwrap().status == TaskStatus::Failed
85 })
86 }
87
88 pub fn add_task(&mut self, task: Task) -> Result<(), WorkflowError> {
89 if self.tasks.contains_key(&task.id) {
91 return Err(WorkflowError::TaskAlreadyExists(task.id.clone()));
92 }
93
94 for dep_id in &task.dependencies {
96 if !self.tasks.contains_key(dep_id) {
97 return Err(WorkflowError::DependencyNotFound(dep_id.clone()));
98 }
99
100 if dep_id == &task.id {
102 return Err(WorkflowError::TaskDependsOnItself(task.id.clone()));
103 }
104 }
105
106 self.tasks
108 .insert(task.id.clone(), Arc::new(RwLock::new(task)));
109 self.rebuild_execution_order();
110 Ok(())
111 }
112
113 pub fn get_task(&self, task_id: &str) -> Option<Arc<RwLock<Task>>> {
114 self.tasks.get(task_id).cloned()
115 }
116
117 pub fn remove_task(&mut self, task_id: &str) {
118 self.tasks.remove(task_id);
119 }
120
121 pub fn pending_count(&self) -> usize {
122 self.tasks
123 .values()
124 .filter(|task| task.read().unwrap().status == TaskStatus::Pending)
125 .count()
126 }
127
128 #[instrument(skip_all)]
129 pub fn update_task_status(
130 &mut self,
131 task_id: &str,
132 status: TaskStatus,
133 result: Option<&AgentResponse>,
134 ) {
135 debug!(status=?status, result=?result, "Updating task status");
136 if let Some(task) = self.tasks.get_mut(task_id) {
137 let mut task = task.write().unwrap();
138 task.status = status;
139 task.result = result.cloned();
140 }
141 }
142
143 fn topological_sort(
144 &self,
145 task_id: &str,
146 visited: &mut HashSet<String>,
147 temp_visited: &mut HashSet<String>,
148 order: &mut Vec<String>,
149 ) {
150 if temp_visited.contains(task_id) {
151 return; }
153
154 if visited.contains(task_id) {
155 return;
156 }
157
158 temp_visited.insert(task_id.to_string());
159
160 if let Some(task) = self.tasks.get(task_id) {
161 for dep_id in &task.read().unwrap().dependencies {
162 self.topological_sort(dep_id, visited, temp_visited, order);
163 }
164 }
165
166 temp_visited.remove(task_id);
167 visited.insert(task_id.to_string());
168 order.push(task_id.to_string());
169 }
170
171 fn rebuild_execution_order(&mut self) {
172 let mut order = Vec::new();
173 let mut visited = HashSet::new();
174 let mut temp_visited = HashSet::new();
175
176 for task_id in self.tasks.keys() {
177 if !visited.contains(task_id) {
178 self.topological_sort(task_id, &mut visited, &mut temp_visited, &mut order);
179 }
180 }
181
182 self.execution_order = order;
183 }
184
185 pub fn get_ready_tasks(&self) -> Vec<Arc<RwLock<Task>>> {
190 self.tasks
191 .values()
192 .filter(|task_arc| {
193 let task = task_arc.read().unwrap();
194 task.status == TaskStatus::Pending
195 && task.dependencies.iter().all(|dep_id| {
196 self.tasks
197 .get(dep_id)
198 .map(|dep| dep.read().unwrap().status == TaskStatus::Completed)
199 .unwrap_or(false)
200 })
201 })
202 .cloned()
203 .collect()
204 }
205
206 pub fn reset_failed_tasks(&mut self) -> Result<(), WorkflowError> {
207 for task in self.tasks.values_mut() {
208 let mut task = task.write().unwrap();
209 if task.status == TaskStatus::Failed {
210 task.status = TaskStatus::Pending;
211 task.increment_retry();
212 if task.retry_count > task.max_retries {
213 return Err(WorkflowError::MaxRetriesExceeded(task.id.clone()));
214 }
215 }
216 }
217 Ok(())
218 }
219}