potato_workflow/workflow/
tasklist.rs

1use crate::workflow::error::WorkflowError;
2
3pub use potato_agent::agents::{
4    agent::Agent,
5    task::{PyTask, Task, TaskStatus},
6    types::ChatResponse,
7};
8use potato_agent::AgentResponse;
9use pyo3::prelude::*;
10use serde::Deserialize;
11use serde::Serialize;
12use std::collections::{HashMap, HashSet};
13use std::sync::Arc;
14use std::sync::RwLock;
15use tracing::instrument;
16use tracing::{debug, warn};
17
18#[derive(Debug, Clone, Default, Serialize, Deserialize)]
19#[pyclass]
20pub struct TaskList {
21    pub tasks: HashMap<String, Arc<RwLock<Task>>>,
22    pub execution_order: Vec<String>,
23}
24
25impl PartialEq for TaskList {
26    fn eq(&self, other: &Self) -> bool {
27        // Compare tasks by their IDs and execution order
28        self.tasks.keys().eq(other.tasks.keys()) && self.execution_order == other.execution_order
29    }
30}
31
32#[pymethods]
33impl TaskList {
34    /// This is mainly a utility function to help with python interoperability.
35    pub fn tasks(&self) -> HashMap<String, Task> {
36        self.tasks
37            .iter()
38            .map(|(id, task)| {
39                let cloned_task = task.read().unwrap().clone();
40                (id.clone(), cloned_task)
41            })
42            .collect()
43    }
44
45    /// Helper for creating a new TaskList by cloning each task in the current TaskList out of the Arc<RwLock<Task>> wrapper.
46    pub fn deep_clone(&self) -> Result<Self, WorkflowError> {
47        let mut new_task_list = TaskList::new();
48
49        // Clone each task individually to create new Arc<RwLock<Task>> instances
50        for (task_id, task_arc) in &self.tasks {
51            let task = task_arc.read().unwrap();
52            let cloned_task = task.clone(); // This should clone the Task struct itself
53            new_task_list
54                .tasks
55                .insert(task_id.clone(), Arc::new(RwLock::new(cloned_task)));
56        }
57
58        // Copy execution order
59        new_task_list.execution_order = self.execution_order.clone();
60
61        Ok(new_task_list)
62    }
63}
64
65impl TaskList {
66    pub fn new() -> Self {
67        Self {
68            tasks: HashMap::new(),
69            execution_order: Vec::new(),
70        }
71    }
72
73    pub fn len(&self) -> usize {
74        self.tasks.len()
75    }
76
77    pub fn is_empty(&self) -> bool {
78        self.tasks.is_empty()
79    }
80
81    pub fn is_complete(&self) -> bool {
82        self.tasks.values().all(|task| {
83            task.read().unwrap().status == TaskStatus::Completed
84                || task.read().unwrap().status == TaskStatus::Failed
85        })
86    }
87
88    pub fn add_task(&mut self, task: Task) -> Result<(), WorkflowError> {
89        // assert that task ID is unique
90        if self.tasks.contains_key(&task.id) {
91            return Err(WorkflowError::TaskAlreadyExists(task.id.clone()));
92        }
93
94        // if dependencies are not empty, check if they exist in the task list
95        for dep_id in &task.dependencies {
96            if !self.tasks.contains_key(dep_id) {
97                return Err(WorkflowError::DependencyNotFound(dep_id.clone()));
98            }
99
100            // also check that the dependency is not the task itself
101            if dep_id == &task.id {
102                return Err(WorkflowError::TaskDependsOnItself(task.id.clone()));
103            }
104        }
105
106        // if all checks pass, insert the task
107        self.tasks
108            .insert(task.id.clone(), Arc::new(RwLock::new(task)));
109        self.rebuild_execution_order();
110        Ok(())
111    }
112
113    pub fn get_task(&self, task_id: &str) -> Option<Arc<RwLock<Task>>> {
114        self.tasks.get(task_id).cloned()
115    }
116
117    pub fn remove_task(&mut self, task_id: &str) {
118        self.tasks.remove(task_id);
119    }
120
121    pub fn pending_count(&self) -> usize {
122        self.tasks
123            .values()
124            .filter(|task| task.read().unwrap().status == TaskStatus::Pending)
125            .count()
126    }
127
128    #[instrument(skip_all)]
129    pub fn update_task_status(
130        &mut self,
131        task_id: &str,
132        status: TaskStatus,
133        result: Option<&AgentResponse>,
134    ) {
135        debug!(status=?status, result=?result, "Updating task status");
136        if let Some(task) = self.tasks.get_mut(task_id) {
137            let mut task = task.write().unwrap();
138            task.status = status;
139            task.result = result.cloned();
140        }
141    }
142
143    fn topological_sort(
144        &self,
145        task_id: &str,
146        visited: &mut HashSet<String>,
147        temp_visited: &mut HashSet<String>,
148        order: &mut Vec<String>,
149    ) {
150        if temp_visited.contains(task_id) {
151            return; // Cycle detected, skip
152        }
153
154        if visited.contains(task_id) {
155            return;
156        }
157
158        temp_visited.insert(task_id.to_string());
159
160        if let Some(task) = self.tasks.get(task_id) {
161            for dep_id in &task.read().unwrap().dependencies {
162                self.topological_sort(dep_id, visited, temp_visited, order);
163            }
164        }
165
166        temp_visited.remove(task_id);
167        visited.insert(task_id.to_string());
168        order.push(task_id.to_string());
169    }
170
171    fn rebuild_execution_order(&mut self) {
172        let mut order = Vec::new();
173        let mut visited = HashSet::new();
174        let mut temp_visited = HashSet::new();
175
176        for task_id in self.tasks.keys() {
177            if !visited.contains(task_id) {
178                self.topological_sort(task_id, &mut visited, &mut temp_visited, &mut order);
179            }
180        }
181
182        self.execution_order = order;
183    }
184
185    /// Iterate through all tasks and return those that are ready to be executed
186    /// This also checks if all dependencies of the task are completed
187    ///
188    /// # Returns a vector of references to tasks that are ready to be executed
189    pub fn get_ready_tasks(&self) -> Vec<Arc<RwLock<Task>>> {
190        self.tasks
191            .values()
192            .filter(|task_arc| {
193                let task = task_arc.read().unwrap();
194                task.status == TaskStatus::Pending
195                    && task.dependencies.iter().all(|dep_id| {
196                        self.tasks
197                            .get(dep_id)
198                            .map(|dep| dep.read().unwrap().status == TaskStatus::Completed)
199                            .unwrap_or(false)
200                    })
201            })
202            .cloned()
203            .collect()
204    }
205
206    pub fn reset_failed_tasks(&mut self) -> Result<(), WorkflowError> {
207        for task in self.tasks.values_mut() {
208            let mut task = task.write().unwrap();
209            if task.status == TaskStatus::Failed {
210                task.status = TaskStatus::Pending;
211                task.increment_retry();
212                if task.retry_count > task.max_retries {
213                    return Err(WorkflowError::MaxRetriesExceeded(task.id.clone()));
214                }
215            }
216        }
217        Ok(())
218    }
219}