// potato_workflow/workflow/tasklist.rs

1use crate::workflow::error::WorkflowError;
2
3pub use potato_agent::agents::{
4    agent::Agent,
5    task::{PyTask, Task, TaskStatus},
6};
7use potato_agent::AgentResponse;
8use pyo3::prelude::*;
9use serde::Deserialize;
10use serde::Serialize;
11use std::collections::{HashMap, HashSet};
12use std::sync::Arc;
13use std::sync::RwLock;
14use tracing::instrument;
15use tracing::{debug, warn};
16
17#[derive(Debug, Clone, Default, Serialize, Deserialize)]
18#[pyclass]
19pub struct TaskList {
20    pub tasks: HashMap<String, Arc<RwLock<Task>>>,
21    pub execution_order: Vec<String>,
22}
23
24impl PartialEq for TaskList {
25    fn eq(&self, other: &Self) -> bool {
26        // Compare tasks by their IDs and execution order
27        self.tasks.keys().eq(other.tasks.keys()) && self.execution_order == other.execution_order
28    }
29}
30
31#[pymethods]
32impl TaskList {
33    /// This is mainly a utility function to help with python interoperability.
34    pub fn tasks(&self) -> HashMap<String, Task> {
35        self.tasks
36            .iter()
37            .map(|(id, task)| {
38                let cloned_task = task.read().unwrap().clone();
39                (id.clone(), cloned_task)
40            })
41            .collect()
42    }
43
44    /// Helper for creating a new TaskList by cloning each task in the current TaskList out of the Arc<RwLock<Task>> wrapper.
45    pub fn deep_clone(&self) -> Result<Self, WorkflowError> {
46        let mut new_task_list = TaskList::new();
47
48        // Clone each task individually to create new Arc<RwLock<Task>> instances
49        for (task_id, task_arc) in &self.tasks {
50            let task = task_arc.read().unwrap();
51            let cloned_task = task.clone(); // This should clone the Task struct itself
52            new_task_list
53                .tasks
54                .insert(task_id.clone(), Arc::new(RwLock::new(cloned_task)));
55        }
56
57        // Copy execution order
58        new_task_list.execution_order = self.execution_order.clone();
59
60        Ok(new_task_list)
61    }
62}
63
64impl TaskList {
65    pub fn new() -> Self {
66        Self {
67            tasks: HashMap::new(),
68            execution_order: Vec::new(),
69        }
70    }
71
72    pub fn len(&self) -> usize {
73        self.tasks.len()
74    }
75
76    pub fn is_empty(&self) -> bool {
77        self.tasks.is_empty()
78    }
79
80    pub fn is_complete(&self) -> bool {
81        self.tasks.values().all(|task| {
82            task.read().unwrap().status == TaskStatus::Completed
83                || task.read().unwrap().status == TaskStatus::Failed
84        })
85    }
86
87    pub fn add_task(&mut self, task: Task) -> Result<(), WorkflowError> {
88        // assert that task ID is unique
89        if self.tasks.contains_key(&task.id) {
90            return Err(WorkflowError::TaskAlreadyExists(task.id.clone()));
91        }
92
93        // if dependencies are not empty, check if they exist in the task list
94        for dep_id in &task.dependencies {
95            if !self.tasks.contains_key(dep_id) {
96                return Err(WorkflowError::DependencyNotFound(dep_id.clone()));
97            }
98
99            // also check that the dependency is not the task itself
100            if dep_id == &task.id {
101                return Err(WorkflowError::TaskDependsOnItself(task.id.clone()));
102            }
103        }
104
105        // if all checks pass, insert the task
106        self.tasks
107            .insert(task.id.clone(), Arc::new(RwLock::new(task)));
108        self.rebuild_execution_order();
109        Ok(())
110    }
111
112    pub fn get_task(&self, task_id: &str) -> Option<Arc<RwLock<Task>>> {
113        self.tasks.get(task_id).cloned()
114    }
115
116    pub fn remove_task(&mut self, task_id: &str) {
117        self.tasks.remove(task_id);
118    }
119
120    pub fn pending_count(&self) -> usize {
121        self.tasks
122            .values()
123            .filter(|task| task.read().unwrap().status == TaskStatus::Pending)
124            .count()
125    }
126
127    #[instrument(skip_all)]
128    pub fn update_task_status(
129        &mut self,
130        task_id: &str,
131        status: TaskStatus,
132        result: Option<&AgentResponse>,
133    ) {
134        debug!(status=?status, result=?result, "Updating task status");
135        if let Some(task) = self.tasks.get_mut(task_id) {
136            let mut task = task.write().unwrap();
137            task.status = status;
138            task.result = result.cloned();
139        }
140    }
141
142    fn topological_sort(
143        &self,
144        task_id: &str,
145        visited: &mut HashSet<String>,
146        temp_visited: &mut HashSet<String>,
147        order: &mut Vec<String>,
148    ) {
149        if temp_visited.contains(task_id) {
150            return; // Cycle detected, skip
151        }
152
153        if visited.contains(task_id) {
154            return;
155        }
156
157        temp_visited.insert(task_id.to_string());
158
159        if let Some(task) = self.tasks.get(task_id) {
160            for dep_id in &task.read().unwrap().dependencies {
161                self.topological_sort(dep_id, visited, temp_visited, order);
162            }
163        }
164
165        temp_visited.remove(task_id);
166        visited.insert(task_id.to_string());
167        order.push(task_id.to_string());
168    }
169
170    fn rebuild_execution_order(&mut self) {
171        let mut order = Vec::new();
172        let mut visited = HashSet::new();
173        let mut temp_visited = HashSet::new();
174
175        for task_id in self.tasks.keys() {
176            if !visited.contains(task_id) {
177                self.topological_sort(task_id, &mut visited, &mut temp_visited, &mut order);
178            }
179        }
180
181        self.execution_order = order;
182    }
183
184    /// Iterate through all tasks and return those that are ready to be executed
185    /// This also checks if all dependencies of the task are completed
186    ///
187    /// # Returns a vector of references to tasks that are ready to be executed
188    pub fn get_ready_tasks(&self) -> Vec<Arc<RwLock<Task>>> {
189        self.tasks
190            .values()
191            .filter(|task_arc| {
192                let task = task_arc.read().unwrap();
193                task.status == TaskStatus::Pending
194                    && task.dependencies.iter().all(|dep_id| {
195                        self.tasks
196                            .get(dep_id)
197                            .map(|dep| dep.read().unwrap().status == TaskStatus::Completed)
198                            .unwrap_or(false)
199                    })
200            })
201            .cloned()
202            .collect()
203    }
204
205    pub fn reset_failed_tasks(&mut self) -> Result<(), WorkflowError> {
206        for task in self.tasks.values_mut() {
207            let mut task = task.write().unwrap();
208            if task.status == TaskStatus::Failed {
209                task.status = TaskStatus::Pending;
210                task.increment_retry();
211                if task.retry_count > task.max_retries {
212                    return Err(WorkflowError::MaxRetriesExceeded(task.id.clone()));
213                }
214            }
215        }
216        Ok(())
217    }
218}