potato_workflow/workflow/
flow.rs

1use crate::workflow::error::WorkflowError;
2pub use potato_agents::agents::{
3    agent::Agent,
4    task::{PyTask, Task, TaskStatus},
5    types::ChatResponse,
6};
7use potato_utils::{create_uuid7, PyHelperFuncs};
8
9use potato_prompts::prompt::types::Role;
10use potato_prompts::Message;
11use pyo3::prelude::*;
12
13use std::collections::{HashMap, HashSet};
14use std::sync::Arc;
15use std::sync::RwLock;
16use tracing::instrument;
17use tracing::{debug, error, info, warn};
18
19use serde::{
20    de::{self, MapAccess, Visitor},
21    ser::SerializeStruct,
22    Deserialize, Deserializer, Serialize, Serializer,
23};
24
/// Python-facing container for the tasks of a finished workflow run.
///
/// Each entry is a Python-owned [`PyTask`] built by `WorkflowResult::new`,
/// keyed by the same key the source task map used.
#[derive(Debug)]
#[pyclass]
pub struct WorkflowResult {
    /// Task key -> Python handle to the converted task (read-only from Python).
    #[pyo3(get)]
    pub tasks: HashMap<String, Py<PyTask>>,
}
31
32impl WorkflowResult {
33    pub fn new(py: Python, tasks: HashMap<String, Task>) -> Self {
34        let py_tasks = tasks
35            .into_iter()
36            .map(|(id, task)| {
37                let py_task = PyTask {
38                    id: task.id.clone(),
39                    prompt: task.prompt,
40                    dependencies: task.dependencies,
41                    status: task.status,
42                    agent_id: task.agent_id,
43                    result: task.result,
44                    max_retries: task.max_retries,
45                    retry_count: task.retry_count,
46                    response_type: None, // Response type is not serialized
47                };
48                (id, Py::new(py, py_task).unwrap())
49            })
50            .collect::<HashMap<_, _>>();
51
52        Self { tasks: py_tasks }
53    }
54}
55
56#[pymethods]
57impl WorkflowResult {
58    pub fn __str__(&self) -> String {
59        PyHelperFuncs::__str__(&self.tasks)
60    }
61}
62
/// The tasks of a workflow together with a cached dependency-respecting order.
///
/// Serializable with serde and exposed to Python as a `pyclass`.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
#[pyclass]
pub struct TaskList {
    /// Tasks keyed by their ID (`add_task` inserts under `task.id`); readable from Python.
    #[pyo3(get)]
    pub tasks: HashMap<String, Task>,
    /// Task IDs ordered so that dependencies come before dependents; recomputed by
    /// `rebuild_execution_order` whenever a task is added. Not exposed to Python.
    pub execution_order: Vec<String>,
}
70
71impl TaskList {
72    pub fn new() -> Self {
73        Self {
74            tasks: HashMap::new(),
75            execution_order: Vec::new(),
76        }
77    }
78
79    pub fn is_complete(&self) -> bool {
80        self.tasks
81            .values()
82            .all(|task| task.status == TaskStatus::Completed || task.status == TaskStatus::Failed)
83    }
84
85    pub fn add_task(&mut self, task: Task) -> Result<(), WorkflowError> {
86        // assert that task ID is unique
87        if self.tasks.contains_key(&task.id) {
88            return Err(WorkflowError::TaskAlreadyExists(task.id.clone()));
89        }
90
91        // if dependencies are not empty, check if they exist in the task list
92        for dep_id in &task.dependencies {
93            if !self.tasks.contains_key(dep_id) {
94                return Err(WorkflowError::DependencyNotFound(dep_id.clone()));
95            }
96
97            // also check that the dependency is not the task itself
98            if dep_id == &task.id {
99                return Err(WorkflowError::TaskDependsOnItself(task.id.clone()));
100            }
101        }
102
103        // if all checks pass, insert the task
104        self.tasks.insert(task.id.clone(), task);
105        self.rebuild_execution_order();
106        Ok(())
107    }
108
109    pub fn get_task(&self, task_id: &str) -> Option<&Task> {
110        self.tasks.get(task_id)
111    }
112
113    pub fn remove_task(&mut self, task_id: &str) {
114        self.tasks.remove(task_id);
115    }
116
117    pub fn pending_count(&self) -> usize {
118        self.tasks
119            .values()
120            .filter(|task| task.status == TaskStatus::Pending)
121            .count()
122    }
123
124    #[instrument(skip_all)]
125    pub fn update_task_status(
126        &mut self,
127        task_id: &str,
128        status: TaskStatus,
129        result: Option<ChatResponse>,
130    ) {
131        debug!(status=?status, result=?result, "Updating task status");
132        if let Some(task) = self.tasks.get_mut(task_id) {
133            task.status = status;
134            task.result = result;
135        }
136    }
137
138    fn topological_sort(
139        &self,
140        task_id: &str,
141        visited: &mut HashSet<String>,
142        temp_visited: &mut HashSet<String>,
143        order: &mut Vec<String>,
144    ) {
145        if temp_visited.contains(task_id) {
146            return; // Cycle detected, skip
147        }
148
149        if visited.contains(task_id) {
150            return;
151        }
152
153        temp_visited.insert(task_id.to_string());
154
155        if let Some(task) = self.tasks.get(task_id) {
156            for dep_id in &task.dependencies {
157                self.topological_sort(dep_id, visited, temp_visited, order);
158            }
159        }
160
161        temp_visited.remove(task_id);
162        visited.insert(task_id.to_string());
163        order.push(task_id.to_string());
164    }
165
166    fn rebuild_execution_order(&mut self) {
167        let mut order = Vec::new();
168        let mut visited = HashSet::new();
169        let mut temp_visited = HashSet::new();
170
171        for task_id in self.tasks.keys() {
172            if !visited.contains(task_id) {
173                self.topological_sort(task_id, &mut visited, &mut temp_visited, &mut order);
174            }
175        }
176
177        self.execution_order = order;
178    }
179
180    /// Iterate through all tasks and return those that are ready to be executed
181    /// This also checks if all dependencies of the task are completed
182    ///
183    /// # Returns a vector of references to tasks that are ready to be executed
184    pub fn get_ready_tasks(&self) -> Vec<Task> {
185        self.tasks
186            .values()
187            .filter(|task| {
188                task.status == TaskStatus::Pending
189                    && task.dependencies.iter().all(|dep_id| {
190                        self.tasks
191                            .get(dep_id)
192                            .map(|dep| dep.status == TaskStatus::Completed)
193                            .unwrap_or(false)
194                    })
195            })
196            .cloned()
197            .collect()
198    }
199
200    pub fn reset_failed_tasks(&mut self) -> Result<(), WorkflowError> {
201        for task in self.tasks.values_mut() {
202            if task.status == TaskStatus::Failed {
203                task.status = TaskStatus::Pending;
204                task.increment_retry();
205                if task.retry_count > task.max_retries {
206                    return Err(WorkflowError::MaxRetriesExceeded(task.id.clone()));
207                }
208            }
209        }
210        Ok(())
211    }
212}
213
/// Rust-specific implementation of a workflow.
///
/// Bundles the tasks to run with the agents that execute them. `Serialize` and
/// `Deserialize` are implemented by hand in this file so the `Arc`-wrapped
/// agents are (de)serialized as plain `Agent` values. Cloning a `Workflow`
/// clones the `Arc` handles, not the agents themselves.
#[derive(Debug, Clone)]
pub struct Workflow {
    /// Unique identifier; `Workflow::new` fills it with `create_uuid7()`.
    pub id: String,
    /// Human-readable workflow name.
    pub name: String,
    /// The tasks and their dependency graph.
    pub tasks: TaskList,
    /// Agents keyed by agent ID; a task is executed by the agent whose ID equals its `agent_id`.
    pub agents: HashMap<String, Arc<Agent>>,
}
222
223impl Workflow {
224    pub fn new(name: &str) -> Self {
225        info!("Creating new workflow: {}", name);
226        Self {
227            id: create_uuid7(),
228            name: name.to_string(),
229            tasks: TaskList::new(),
230            agents: HashMap::new(),
231        }
232    }
233    pub async fn run(&self) -> Result<(), WorkflowError> {
234        info!("Running workflow: {}", self.name);
235        let workflow = self.clone();
236        let workflow = Arc::new(RwLock::new(workflow));
237        execute_workflow(workflow).await
238    }
239
240    pub fn is_complete(&self) -> bool {
241        self.tasks.is_complete()
242    }
243
244    pub fn pending_count(&self) -> usize {
245        self.tasks.pending_count()
246    }
247
248    pub fn add_task(&mut self, task: Task) -> Result<(), WorkflowError> {
249        self.tasks.add_task(task)
250    }
251
252    pub fn add_tasks(&mut self, tasks: Vec<Task>) -> Result<(), WorkflowError> {
253        for task in tasks {
254            self.tasks.add_task(task)?;
255        }
256        Ok(())
257    }
258
259    pub fn add_agent(&mut self, agent: &Agent) {
260        self.agents
261            .insert(agent.id.clone(), Arc::new(agent.clone()));
262    }
263
264    pub fn execution_plan(&self) -> Result<HashMap<String, HashSet<String>>, WorkflowError> {
265        let mut remaining: HashMap<String, HashSet<String>> = self
266            .tasks
267            .tasks
268            .iter()
269            .map(|(id, task)| (id.clone(), task.dependencies.iter().cloned().collect()))
270            .collect();
271
272        let mut executed = HashSet::new();
273        let mut plan = HashMap::new();
274        let mut step = 1;
275
276        while !remaining.is_empty() {
277            // Find all tasks that can be executed in parallel - collect just the keys we need to remove
278            let ready_keys: Vec<String> = remaining
279                .iter()
280                .filter(|(_, deps)| deps.is_subset(&executed))
281                .map(|(id, _)| id.to_string())
282                .collect();
283
284            if ready_keys.is_empty() {
285                // Circular dependency detected
286                break;
287            }
288
289            // Create the set for the plan (reusing the already allocated Strings)
290            let mut ready_set = HashSet::with_capacity(ready_keys.len());
291
292            // Update tracking sets and build the ready set in one pass
293            for key in ready_keys {
294                executed.insert(key.clone());
295                remaining.remove(&key);
296                ready_set.insert(key);
297            }
298
299            // Add parallel tasks to the current step
300            plan.insert(format!("step{step}"), ready_set);
301
302            step += 1;
303        }
304
305        Ok(plan)
306    }
307}
308
309/// Check if the workflow is complete
310/// # Arguments
311/// * `workflow` - A reference to the workflow instance
312/// # Returns true if the workflow is complete, false otherwise
313fn is_workflow_complete(workflow: &Arc<RwLock<Workflow>>) -> bool {
314    workflow.read().unwrap().is_complete()
315}
316
317/// Reset failed tasks in the workflow
318/// # Arguments
319/// * `workflow` - A reference to the workflow instance
320/// # Returns Ok(()) if successful, or an error if the reset fails
321fn reset_failed_workflow_tasks(workflow: &Arc<RwLock<Workflow>>) -> Result<(), WorkflowError> {
322    match workflow.write().unwrap().tasks.reset_failed_tasks() {
323        Ok(_) => Ok(()),
324        Err(e) => {
325            warn!("Failed to reset failed tasks: {}", e);
326            Err(e)
327        }
328    }
329}
330
331/// Get all ready tasks in the workflow
332/// # Arguments
333/// * `workflow` - A reference to the workflow instance
334/// # Returns a vector of tasks that are ready to be executed
335fn get_ready_tasks(workflow: &Arc<RwLock<Workflow>>) -> Vec<Task> {
336    workflow.read().unwrap().tasks.get_ready_tasks()
337}
338
339/// Check for circular dependencies
340/// # Arguments
341/// * `workflow` - A reference to the workflow instance
342/// # Returns true if circular dependencies are detected, false otherwise
343fn check_for_circular_dependencies(workflow: &Arc<RwLock<Workflow>>) -> bool {
344    let pending_count = workflow.read().unwrap().pending_count();
345
346    if pending_count > 0 {
347        warn!(
348            "No ready tasks found but {} pending tasks remain. Possible circular dependency.",
349            pending_count
350        );
351        return true;
352    }
353
354    false
355}
356
357/// Mark a task as running
358/// # Arguments
359/// * `workflow` - A reference to the workflow instance
360/// # Returns nothing
361fn mark_task_as_running(workflow: &Arc<RwLock<Workflow>>, task_id: &str) {
362    let mut wf = workflow.write().unwrap();
363    wf.tasks
364        .update_task_status(task_id, TaskStatus::Running, None);
365}
366
367/// Get an agent for a task
368/// # Arguments
369/// * `workflow` - A reference to the workflow instance
370/// * `task` - A reference to the task for which the agent is needed
371fn get_agent_for_task(workflow: &Arc<RwLock<Workflow>>, task: &Task) -> Option<Arc<Agent>> {
372    let wf = workflow.read().unwrap();
373    wf.agents.get(&task.agent_id).cloned()
374}
375
376/// Builds the context for a task from its dependencies
377/// # Arguments
378/// * `workflow` - A reference to the workflow instance
379/// * `task` - A reference to the task for which the context is being built
380/// # Returns a HashMap containing the context messages for the task
381fn build_task_context(
382    workflow: &Arc<RwLock<Workflow>>,
383    task: &Task,
384) -> HashMap<String, Vec<Message>> {
385    let wf = workflow.read().unwrap();
386    let mut ctx = HashMap::new();
387
388    for dep_id in &task.dependencies {
389        if let Some(dep) = wf.tasks.get_task(dep_id) {
390            if let Some(result) = &dep.result {
391                if let Ok(message) = result.to_message(Role::Assistant) {
392                    ctx.insert(dep_id.clone(), message);
393                }
394            }
395        }
396    }
397
398    ctx
399}
400
401/// Spawns an individual task execution
402/// # Arguments
403/// * `workflow` - A reference to the workflow instance
404/// * `task` - The task to be executed
405/// * `task_id` - The ID of the task
406/// * `agent` - An optional reference to the agent that will execute the task
407/// * `context` - A HashMap containing the context messages for the task
408/// # Returns a JoinHandle for the spawned task
409fn spawn_task_execution(
410    workflow: Arc<RwLock<Workflow>>,
411    task: Task,
412    task_id: String,
413    agent: Option<Arc<Agent>>,
414    context: HashMap<String, Vec<Message>>,
415) -> tokio::task::JoinHandle<()> {
416    tokio::spawn(async move {
417        if let Some(agent) = agent {
418            match agent.execute_async_task_with_context(&task, context).await {
419                Ok(response) => {
420                    let mut wf = workflow.write().unwrap();
421                    wf.tasks.update_task_status(
422                        &task_id,
423                        TaskStatus::Completed,
424                        Some(response.response),
425                    );
426                }
427                Err(e) => {
428                    error!("Task {} failed: {}", task_id, e);
429                    let mut wf = workflow.write().unwrap();
430                    wf.tasks
431                        .update_task_status(&task_id, TaskStatus::Failed, None);
432                }
433            }
434        } else {
435            error!("No agent found for task {}", task_id);
436            let mut wf = workflow.write().unwrap();
437            wf.tasks
438                .update_task_status(&task_id, TaskStatus::Failed, None);
439        }
440    })
441}
442
443/// Helper for spawning a task execution
444/// # Arguments
445/// * `workflow` - A reference to the workflow instance
446/// * `tasks` - A vector of tasks to be executed
447/// # Returns a vector of JoinHandles for the spawned tasks
448fn spawn_task_executions(
449    workflow: &Arc<RwLock<Workflow>>,
450    tasks: Vec<Task>,
451) -> Vec<tokio::task::JoinHandle<()>> {
452    let mut handles = Vec::with_capacity(tasks.len());
453
454    for task in tasks {
455        let task_id = task.id.clone();
456        //let workflow_clone = workflow.clone();
457
458        // Mark task as running
459        mark_task_as_running(workflow, &task_id);
460
461        // Build the context
462        let context = build_task_context(workflow, &task);
463
464        // Get/clone agent ARC
465        let agent = get_agent_for_task(workflow, &task);
466
467        // Spawn task execution and push handle to the vector
468        let handle = spawn_task_execution(workflow.clone(), task, task_id, agent, context);
469        handles.push(handle);
470    }
471
472    handles
473}
474
475/// Wait for all spawned tasks to complete
476/// # Arguments
477/// * `handles` - A vector of JoinHandles for the spawned tasks
478/// # Returns nothing
479async fn await_task_completions(handles: Vec<tokio::task::JoinHandle<()>>) {
480    for handle in handles {
481        if let Err(e) = handle.await {
482            warn!("Task execution failed: {}", e);
483        }
484    }
485}
486
487/// Execute the workflow asynchronously
488/// This function will be called to start the workflow execution process and does the following:
489/// 1. Iterates over workflow tasks while the shared workflow is not complete.
490/// 2. Resets any failed tasks to allow them to be retried. This needs to happen before getting ready tasks.
491/// 3. Gets all ready tasks
492/// 4. For each ready task:
493/// ///    - Marks the task as running
494/// ///    - Checks previous tasks for injected context
495/// ///    - Gets the agent for the task  
496/// ///    - Spawn a new tokio task and execute task with agent
497/// ///    - Push task to the handles vector
498/// 4. Waits for all spawned tasks to complete
499#[instrument(skip_all)]
500pub async fn execute_workflow(workflow: Arc<RwLock<Workflow>>) -> Result<(), WorkflowError> {
501    info!("Starting workflow execution");
502
503    while !is_workflow_complete(&workflow) {
504        // Reset any failed tasks
505        // This will return and error if any task exceeds its max retries
506        reset_failed_workflow_tasks(&workflow)?;
507
508        // Get tasks ready for execution
509        let ready_tasks = get_ready_tasks(&workflow);
510        info!("Found {} ready tasks for execution", ready_tasks.len());
511
512        // Check for circular dependencies
513        if ready_tasks.is_empty() {
514            if check_for_circular_dependencies(&workflow) {
515                break;
516            }
517            continue;
518        }
519
520        // Execute tasks asynchronously
521        let handles = spawn_task_executions(&workflow, ready_tasks);
522
523        // Wait for all tasks to complete
524        await_task_completions(handles).await;
525    }
526
527    info!("Workflow execution completed");
528    Ok(())
529}
530
531impl Serialize for Workflow {
532    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
533    where
534        S: Serializer,
535    {
536        let mut state = serializer.serialize_struct("Workflow", 4)?;
537
538        // set session to none
539        state.serialize_field("id", &self.id)?;
540        state.serialize_field("name", &self.name)?;
541        state.serialize_field("tasks", &self.tasks)?;
542
543        // serialize agents by unwrapping the Arc
544        let agents: HashMap<String, Agent> = self
545            .agents
546            .iter()
547            .map(|(id, agent)| (id.clone(), (*agent.as_ref()).clone()))
548            .collect();
549
550        state.serialize_field("agents", &agents)?;
551        state.end()
552    }
553}
554
555impl<'de> Deserialize<'de> for Workflow {
556    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
557    where
558        D: Deserializer<'de>,
559    {
560        #[derive(Deserialize)]
561        #[serde(field_identifier, rename_all = "snake_case")]
562        enum Field {
563            Id,
564            Name,
565            Tasks,
566            Agents,
567        }
568
569        struct WorkflowVisitor;
570
571        impl<'de> Visitor<'de> for WorkflowVisitor {
572            type Value = Workflow;
573
574            fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result {
575                formatter.write_str("struct Workflow")
576            }
577
578            fn visit_map<V>(self, mut map: V) -> Result<Workflow, V::Error>
579            where
580                V: MapAccess<'de>,
581            {
582                let mut id = None;
583                let mut name = None;
584                let mut tasks = None;
585                let mut agents: Option<HashMap<String, Agent>> = None;
586
587                while let Some(key) = map.next_key()? {
588                    match key {
589                        Field::Id => {
590                            id = Some(map.next_value()?);
591                        }
592                        Field::Tasks => {
593                            tasks = Some(map.next_value()?);
594                        }
595                        Field::Agents => {
596                            agents = Some(map.next_value()?);
597                        }
598                        Field::Name => {
599                            name = Some(map.next_value()?);
600                        }
601                    }
602                }
603
604                let id = id.ok_or_else(|| de::Error::missing_field("id"))?;
605                let name = name.ok_or_else(|| de::Error::missing_field("name"))?;
606                let tasks = tasks.ok_or_else(|| de::Error::missing_field("tasks"))?;
607                let agents = agents.ok_or_else(|| de::Error::missing_field("agents"))?;
608
609                // convert agents to arc
610                let agents = agents
611                    .into_iter()
612                    .map(|(id, agent)| (id, Arc::new(agent)))
613                    .collect();
614
615                Ok(Workflow {
616                    id,
617                    name,
618                    tasks,
619                    agents,
620                })
621            }
622        }
623
624        const FIELDS: &[&str] = &["id", "name", "tasks", "agents"];
625        deserializer.deserialize_struct("Workflow", FIELDS, WorkflowVisitor)
626    }
627}
628
#[cfg(test)]
mod tests {
    use super::*;
    use potato_prompts::{prompt::types::PromptContent, Message, Prompt};

    /// Builds a minimal single-message `gpt-4o` / `openai` prompt used as a task fixture.
    fn sample_prompt() -> Prompt {
        let message = Message::new_rs(PromptContent::Str("Test prompt".to_string()));
        Prompt::new_rs(
            vec![message],
            Some("gpt-4o"),
            Some("openai"),
            vec![],
            None,
            None,
        )
        .unwrap()
    }

    #[test]
    fn test_workflow_creation() {
        let workflow = Workflow::new("Test Workflow");

        assert_eq!(workflow.name, "Test Workflow");
        // A hyphenated UUID string is 36 characters long.
        assert_eq!(workflow.id.len(), 36);
    }

    #[test]
    fn test_task_list_add_and_get() {
        let mut task_list = TaskList::new();
        let task = Task::new("task1", sample_prompt(), "task1", None, None);
        let expected_id = task.id.clone();

        task_list.add_task(task).unwrap();

        let fetched = task_list.get_task(&expected_id).unwrap();
        assert_eq!(fetched.id, expected_id);

        // Smoke-check that resetting failed tasks succeeds on this list.
        task_list.reset_failed_tasks().unwrap();
    }
}