// potato_workflow/workflow/tasklist.rs

1use crate::workflow::error::WorkflowError;
2
3pub use potato_agent::agents::{
4    agent::Agent,
5    task::{Task, TaskStatus, WorkflowTask},
6};
7use potato_agent::AgentResponse;
8use pyo3::prelude::*;
9use serde::Deserialize;
10use serde::Serialize;
11use serde_json::Value;
12use std::collections::{HashMap, HashSet};
13use std::sync::Arc;
14use std::sync::RwLock;
15use tracing::instrument;
16use tracing::{debug, warn};
17
/// Collection of workflow tasks keyed by task ID, plus a dependency-respecting
/// execution order.
///
/// Each task is stored behind `Arc<RwLock<…>>` so it can be shared and mutated
/// concurrently; `execution_order` is rebuilt via topological sort whenever a
/// task is added (see `rebuild_execution_order`).
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
#[pyclass]
pub struct TaskList {
    // Task ID -> shared, lock-protected task.
    pub tasks: HashMap<String, Arc<RwLock<Task>>>,
    // Task IDs in topologically sorted (dependency-first) order.
    pub execution_order: Vec<String>,
}
24
25impl PartialEq for TaskList {
26    fn eq(&self, other: &Self) -> bool {
27        // Compare tasks by their IDs and execution order
28        self.tasks.keys().eq(other.tasks.keys()) && self.execution_order == other.execution_order
29    }
30}
31
32#[pymethods]
33impl TaskList {
34    #[getter]
35    pub fn items(&self) -> HashMap<String, Task> {
36        self.tasks()
37    }
38    /// This is mainly a utility function to help with python interoperability.
39    pub fn tasks(&self) -> HashMap<String, Task> {
40        self.tasks
41            .iter()
42            .map(|(id, task)| {
43                let cloned_task = task.read().unwrap().clone();
44                (id.clone(), cloned_task)
45            })
46            .collect()
47    }
48
49    /// Helper for creating a new TaskList by cloning each task in the current TaskList out of the Arc<RwLock<Task>> wrapper.
50    pub fn deep_clone(&self) -> Result<Self, WorkflowError> {
51        let mut new_task_list = TaskList::new();
52
53        // Clone each task individually to create new Arc<RwLock<Task>> instances
54        for (task_id, task_arc) in &self.tasks {
55            let task = task_arc.read().unwrap();
56            let cloned_task = task.clone(); // This should clone the Task struct itself
57            new_task_list
58                .tasks
59                .insert(task_id.clone(), Arc::new(RwLock::new(cloned_task)));
60        }
61
62        // Copy execution order
63        new_task_list.execution_order = self.execution_order.clone();
64
65        Ok(new_task_list)
66    }
67}
68
69impl TaskList {
70    pub fn new() -> Self {
71        Self {
72            tasks: HashMap::new(),
73            execution_order: Vec::new(),
74        }
75    }
76
77    pub fn len(&self) -> usize {
78        self.tasks.len()
79    }
80
81    pub fn is_empty(&self) -> bool {
82        self.tasks.is_empty()
83    }
84
85    pub fn rebuild_task_validators(&mut self) -> Result<(), WorkflowError> {
86        for task_arc in self.tasks.values_mut() {
87            let mut task = task_arc.write().unwrap();
88            task.rebuild_validator()?
89        }
90        Ok(())
91    }
92
93    pub fn is_complete(&self) -> bool {
94        self.tasks.values().all(|task| {
95            task.read().unwrap().status == TaskStatus::Completed
96                || task.read().unwrap().status == TaskStatus::Failed
97        })
98    }
99
100    pub fn add_task(&mut self, task: Task) -> Result<(), WorkflowError> {
101        // assert that task ID is unique
102        if self.tasks.contains_key(&task.id) {
103            return Err(WorkflowError::TaskAlreadyExists(task.id.clone()));
104        }
105
106        // if dependencies are not empty, check if they exist in the task list
107        for dep_id in &task.dependencies {
108            if !self.tasks.contains_key(dep_id) {
109                return Err(WorkflowError::DependencyNotFound(dep_id.clone()));
110            }
111
112            // also check that the dependency is not the task itself
113            if dep_id == &task.id {
114                return Err(WorkflowError::TaskDependsOnItself(task.id.clone()));
115            }
116        }
117
118        // if all checks pass, insert the task
119        self.tasks
120            .insert(task.id.clone(), Arc::new(RwLock::new(task)));
121        self.rebuild_execution_order();
122        Ok(())
123    }
124
125    pub fn get_task(&self, task_id: &str) -> Option<Arc<RwLock<Task>>> {
126        self.tasks.get(task_id).cloned()
127    }
128
129    /// returns all task responses as a HashMap of task ID to serde_json::Value
130    pub fn get_task_responses(&self) -> Result<HashMap<String, Value>, WorkflowError> {
131        let mut responses = HashMap::new();
132
133        for (task_id, task_arc) in &self.tasks {
134            let task = task_arc
135                .read()
136                .map_err(|_| WorkflowError::ReadLockAcquireError)?;
137            if let Some(result) = &task.result {
138                if let Some(value) = result.response_value() {
139                    responses.insert(task_id.clone(), value);
140                } else {
141                    // insert null if there is no response value
142                    responses.insert(task_id.clone(), Value::Null);
143                }
144            }
145        }
146
147        Ok(responses)
148    }
149
150    pub fn remove_task(&mut self, task_id: &str) {
151        self.tasks.remove(task_id);
152    }
153
154    pub fn pending_count(&self) -> usize {
155        self.tasks
156            .values()
157            .filter(|task| task.read().unwrap().status == TaskStatus::Pending)
158            .count()
159    }
160
161    #[instrument(skip_all)]
162    pub fn update_task_status(
163        &mut self,
164        task_id: &str,
165        status: TaskStatus,
166        result: Option<&AgentResponse>,
167    ) {
168        debug!(status=?status, result=?result, "Updating task status");
169        if let Some(task) = self.tasks.get_mut(task_id) {
170            let mut task = task.write().unwrap();
171            task.status = status;
172            task.result = result.cloned();
173        }
174    }
175
176    fn topological_sort(
177        &self,
178        task_id: &str,
179        visited: &mut HashSet<String>,
180        temp_visited: &mut HashSet<String>,
181        order: &mut Vec<String>,
182    ) {
183        if temp_visited.contains(task_id) {
184            return; // Cycle detected, skip
185        }
186
187        if visited.contains(task_id) {
188            return;
189        }
190
191        temp_visited.insert(task_id.to_string());
192
193        if let Some(task) = self.tasks.get(task_id) {
194            for dep_id in &task.read().unwrap().dependencies {
195                self.topological_sort(dep_id, visited, temp_visited, order);
196            }
197        }
198
199        temp_visited.remove(task_id);
200        visited.insert(task_id.to_string());
201        order.push(task_id.to_string());
202    }
203
204    fn rebuild_execution_order(&mut self) {
205        let mut order = Vec::new();
206        let mut visited = HashSet::new();
207        let mut temp_visited = HashSet::new();
208
209        for task_id in self.tasks.keys() {
210            if !visited.contains(task_id) {
211                self.topological_sort(task_id, &mut visited, &mut temp_visited, &mut order);
212            }
213        }
214
215        self.execution_order = order;
216    }
217
218    /// Iterate through all tasks and return those that are ready to be executed
219    /// This also checks if all dependencies of the task are completed
220    ///
221    /// # Returns a vector of references to tasks that are ready to be executed
222    pub fn get_ready_tasks(&self) -> Vec<Arc<RwLock<Task>>> {
223        self.tasks
224            .values()
225            .filter(|task_arc| {
226                let task = task_arc.read().unwrap();
227                task.status == TaskStatus::Pending
228                    && task.dependencies.iter().all(|dep_id| {
229                        self.tasks
230                            .get(dep_id)
231                            .map(|dep| dep.read().unwrap().status == TaskStatus::Completed)
232                            .unwrap_or(false)
233                    })
234            })
235            .cloned()
236            .collect()
237    }
238
239    pub fn reset_failed_tasks(&mut self) -> Result<(), WorkflowError> {
240        for task in self.tasks.values_mut() {
241            let mut task = task.write().unwrap();
242            if task.status == TaskStatus::Failed {
243                task.status = TaskStatus::Pending;
244                task.increment_retry();
245                if task.retry_count > task.max_retries {
246                    return Err(WorkflowError::MaxRetriesExceeded(task.id.clone()));
247                }
248            }
249        }
250        Ok(())
251    }
252
253    pub fn get_last_task_id(&self) -> Option<String> {
254        self.execution_order.last().cloned()
255    }
256}