potato_workflow/workflow/
tasklist.rs

1use crate::workflow::error::WorkflowError;
2
3pub use potato_agent::agents::{
4    agent::Agent,
5    task::{PyTask, Task, TaskStatus},
6    types::ChatResponse,
7};
8
9use pyo3::prelude::*;
10use serde::Deserialize;
11use serde::Serialize;
12use std::collections::{HashMap, HashSet};
13use tracing::instrument;
14use tracing::{debug, warn};
15
16#[derive(Debug, Clone, Default, Serialize, Deserialize)]
17#[pyclass]
18pub struct TaskList {
19    #[pyo3(get)]
20    pub tasks: HashMap<String, Task>,
21    pub execution_order: Vec<String>,
22}
23
24impl TaskList {
25    pub fn new() -> Self {
26        Self {
27            tasks: HashMap::new(),
28            execution_order: Vec::new(),
29        }
30    }
31
32    pub fn is_complete(&self) -> bool {
33        self.tasks
34            .values()
35            .all(|task| task.status == TaskStatus::Completed || task.status == TaskStatus::Failed)
36    }
37
38    pub fn add_task(&mut self, task: Task) {
39        self.tasks.insert(task.id.clone(), task);
40        self.rebuild_execution_order();
41    }
42
43    pub fn get_task(&self, task_id: &str) -> Option<&Task> {
44        self.tasks.get(task_id)
45    }
46
47    pub fn remove_task(&mut self, task_id: &str) {
48        self.tasks.remove(task_id);
49    }
50
51    pub fn pending_count(&self) -> usize {
52        self.tasks
53            .values()
54            .filter(|task| task.status == TaskStatus::Pending)
55            .count()
56    }
57
58    #[instrument(skip_all)]
59    pub fn update_task_status(
60        &mut self,
61        task_id: &str,
62        status: TaskStatus,
63        result: Option<ChatResponse>,
64    ) {
65        debug!(status=?status, result=?result, "Updating task status");
66        if let Some(task) = self.tasks.get_mut(task_id) {
67            task.status = status;
68            task.result = result;
69        }
70    }
71
72    fn topological_sort(
73        &self,
74        task_id: &str,
75        visited: &mut HashSet<String>,
76        temp_visited: &mut HashSet<String>,
77        order: &mut Vec<String>,
78    ) {
79        if temp_visited.contains(task_id) {
80            return; // Cycle detected, skip
81        }
82
83        if visited.contains(task_id) {
84            return;
85        }
86
87        temp_visited.insert(task_id.to_string());
88
89        if let Some(task) = self.tasks.get(task_id) {
90            for dep_id in &task.dependencies {
91                self.topological_sort(dep_id, visited, temp_visited, order);
92            }
93        }
94
95        temp_visited.remove(task_id);
96        visited.insert(task_id.to_string());
97        order.push(task_id.to_string());
98    }
99
100    fn rebuild_execution_order(&mut self) {
101        let mut order = Vec::new();
102        let mut visited = HashSet::new();
103        let mut temp_visited = HashSet::new();
104
105        for task_id in self.tasks.keys() {
106            if !visited.contains(task_id) {
107                self.topological_sort(task_id, &mut visited, &mut temp_visited, &mut order);
108            }
109        }
110
111        self.execution_order = order;
112    }
113
114    /// Iterate through all tasks and return those that are ready to be executed
115    /// This also checks if all dependencies of the task are completed
116    ///
117    /// # Returns a vector of references to tasks that are ready to be executed
118    pub fn get_ready_tasks(&self) -> Vec<Task> {
119        self.tasks
120            .values()
121            .filter(|task| {
122                task.status == TaskStatus::Pending
123                    && task.dependencies.iter().all(|dep_id| {
124                        self.tasks
125                            .get(dep_id)
126                            .map(|dep| dep.status == TaskStatus::Completed)
127                            .unwrap_or(false)
128                    })
129            })
130            .cloned()
131            .collect()
132    }
133
134    pub fn reset_failed_tasks(&mut self) -> Result<(), WorkflowError> {
135        for task in self.tasks.values_mut() {
136            if task.status == TaskStatus::Failed {
137                task.status = TaskStatus::Pending;
138                task.increment_retry();
139                if task.retry_count > task.max_retries {
140                    return Err(WorkflowError::MaxRetriesExceeded(task.id.clone()));
141                }
142            }
143        }
144        Ok(())
145    }
146}