// potato_workflow/workflow/flow.rs

1use crate::tasklist::TaskList;
2use crate::types::Context;
3use crate::{
4    events::{EventTracker, TaskEvent},
5    workflow::error::WorkflowError,
6};
7pub use potato_agent::agents::{
8    agent::{Agent, PyAgent},
9    task::{PyTask, Task, TaskStatus},
10    types::ChatResponse,
11};
12use potato_agent::PyAgentResponse;
13use potato_prompt::parse_response_to_json;
14use potato_prompt::prompt::types::Role;
15use potato_prompt::Message;
16use potato_util::{create_uuid7, utils::update_serde_map_with, PyHelperFuncs};
17use potato_util::{json_to_pydict, pyobject_to_json};
18use pyo3::prelude::*;
19use pyo3::IntoPyObjectExt;
20use serde::{
21    de::{self, MapAccess, Visitor},
22    ser::SerializeStruct,
23    Deserialize, Deserializer, Serialize, Serializer,
24};
25use serde_json::Map;
26use serde_json::Value;
27use std::collections::{HashMap, HashSet};
28use std::sync::Arc;
29use std::sync::RwLock;
30use tracing::instrument;
31use tracing::{debug, error, info, warn};
32
// Python workflows are a work in progress.
34use pyo3::types::PyDict;
35
/// Python-visible outcome of a workflow run, built by `WorkflowResult::new`.
#[derive(Debug)]
#[pyclass]
pub struct WorkflowResult {
    /// Finished tasks converted to Python `Task` objects, keyed the same way as the
    /// task list they were taken from (presumably the task id — keys are used to look up
    /// per-task output types).
    #[pyo3(get)]
    pub tasks: HashMap<String, Py<PyTask>>,

    /// Events recorded by the workflow's event tracker during the run.
    #[pyo3(get)]
    pub events: Vec<TaskEvent>,
}
45
46impl WorkflowResult {
47    pub fn new(
48        py: Python,
49        tasks: HashMap<String, Task>,
50        output_types: &HashMap<String, Arc<PyObject>>,
51        events: Vec<TaskEvent>,
52    ) -> Self {
53        let py_tasks = tasks
54            .into_iter()
55            .map(|(id, task)| {
56                let py_agent_response = if let Some(result) = task.result {
57                    let output_type = output_types.get(&id).map(|arc| arc.as_ref().clone_ref(py));
58                    Some(PyAgentResponse::new(result, output_type))
59                } else {
60                    None
61                };
62                let py_task = PyTask {
63                    id: task.id.clone(),
64                    prompt: task.prompt,
65                    dependencies: task.dependencies,
66                    status: task.status,
67                    agent_id: task.agent_id,
68                    result: py_agent_response,
69                    max_retries: task.max_retries,
70                    retry_count: task.retry_count,
71                };
72                (id, Py::new(py, py_task).unwrap())
73            })
74            .collect::<HashMap<_, _>>();
75
76        Self {
77            tasks: py_tasks,
78            events,
79        }
80    }
81}
82
83#[pymethods]
84impl WorkflowResult {
85    pub fn __str__(&self) -> String {
86        // serialize tasks to json
87        let json = serde_json::json!({
88            "tasks": serde_json::to_value(&self.tasks).unwrap_or(Value::Null),
89            "events": serde_json::to_value(&self.events).unwrap_or(Value::Null)
90        });
91
92        PyHelperFuncs::__str__(&json)
93    }
94}
95
/// Rust-specific implementation of a workflow.
///
/// `agents` and `event_tracker` are `Arc`s, so a `clone()` shares them with the original.
/// `get_new_workflow` produces a run copy with a deep-cloned task list and a fresh tracker.
#[derive(Debug, Clone)]
pub struct Workflow {
    /// Unique id (a UUIDv7 from `create_uuid7` when built via `new`).
    pub id: String,
    /// Human-readable workflow name.
    pub name: String,
    /// The tasks of this workflow and their dependencies/statuses.
    pub task_list: TaskList,
    /// Agents keyed by agent id; tasks reference them through their `agent_id`.
    pub agents: HashMap<String, Arc<Agent>>,
    /// Receives task started/completed/failed events during execution.
    pub event_tracker: Arc<RwLock<EventTracker>>,
    /// Optional run-wide context handed to every task execution.
    pub global_context: Option<Value>,
}
106
107impl PartialEq for Workflow {
108    fn eq(&self, other: &Self) -> bool {
109        // Compare by ID and name
110        self.id == other.id && self.name == other.name
111    }
112}
113
114impl Workflow {
115    pub fn new(name: &str) -> Self {
116        debug!("Creating new workflow: {}", name);
117        let id = create_uuid7();
118        Self {
119            id: id.clone(),
120            name: name.to_string(),
121            task_list: TaskList::new(),
122            agents: HashMap::new(),
123            event_tracker: Arc::new(RwLock::new(EventTracker::new(id))),
124            global_context: None, // Initialize with no global context
125        }
126    }
127    pub fn events(&self) -> Vec<TaskEvent> {
128        let tracker = self.event_tracker.read().unwrap();
129        let events = tracker.events.read().unwrap().clone();
130        events
131    }
132
133    pub fn total_duration(&self) -> i32 {
134        let tracker = self.event_tracker.read().unwrap();
135
136        if tracker.is_empty() {
137            0
138        } else {
139            //iter over each tracker event and get the details for each task event and get duration
140            let mut total_duration = chrono::Duration::zero();
141            for event in tracker.events.read().unwrap().iter() {
142                total_duration += event.details.duration.unwrap_or(chrono::Duration::zero());
143            }
144            total_duration.subsec_millis()
145        }
146    }
147
148    pub fn get_new_workflow(&self, global_context: Option<Value>) -> Result<Self, WorkflowError> {
149        // set new id for the new workflow
150        let id = create_uuid7();
151
152        // create deep copy of the tasklist so we don't clone the arc
153        let task_list = self.task_list.deep_clone()?;
154
155        Ok(Workflow {
156            id: id.clone(),
157            name: self.name.clone(),
158            task_list,
159            agents: self.agents.clone(), // Agents can be shared since they're read-only during execution
160            event_tracker: Arc::new(RwLock::new(EventTracker::new(id))),
161            global_context, // Use the provided global context or None
162        })
163    }
164
165    pub async fn run(
166        &self,
167        global_context: Option<Value>,
168    ) -> Result<Arc<RwLock<Workflow>>, WorkflowError> {
169        debug!("Running workflow: {}", self.name);
170
171        let run_workflow = Arc::new(RwLock::new(self.get_new_workflow(global_context)?));
172
173        execute_workflow(&run_workflow).await?;
174
175        Ok(run_workflow)
176    }
177
178    pub fn is_complete(&self) -> bool {
179        self.task_list.is_complete()
180    }
181
182    pub fn pending_count(&self) -> usize {
183        self.task_list.pending_count()
184    }
185
186    pub fn add_task(&mut self, task: Task) -> Result<(), WorkflowError> {
187        self.task_list.add_task(task)
188    }
189
190    pub fn add_tasks(&mut self, tasks: Vec<Task>) -> Result<(), WorkflowError> {
191        for task in tasks {
192            self.task_list.add_task(task)?;
193        }
194        Ok(())
195    }
196
197    pub fn add_agent(&mut self, agent: &Agent) {
198        self.agents
199            .insert(agent.id.clone(), Arc::new(agent.clone()));
200    }
201
202    pub fn execution_plan(&self) -> Result<HashMap<i32, HashSet<String>>, WorkflowError> {
203        let mut remaining: HashMap<String, HashSet<String>> = self
204            .task_list
205            .tasks
206            .iter()
207            .map(|(id, task)| {
208                (
209                    id.clone(),
210                    task.read().unwrap().dependencies.iter().cloned().collect(),
211                )
212            })
213            .collect();
214
215        let mut executed = HashSet::new();
216        let mut plan = HashMap::new();
217        let mut step = 1;
218
219        while !remaining.is_empty() {
220            // Find all tasks that can be executed in parallel - collect just the keys we need to remove
221            let ready_keys: Vec<String> = remaining
222                .iter()
223                .filter(|(_, deps)| deps.is_subset(&executed))
224                .map(|(id, _)| id.to_string())
225                .collect();
226
227            if ready_keys.is_empty() {
228                // Circular dependency detected
229                break;
230            }
231
232            // Create the set for the plan (reusing the already allocated Strings)
233            let mut ready_set = HashSet::with_capacity(ready_keys.len());
234
235            // Update tracking sets and build the ready set in one pass
236            for key in ready_keys {
237                executed.insert(key.clone());
238                remaining.remove(&key);
239                ready_set.insert(key);
240            }
241
242            // Add parallel tasks to the current step
243            plan.insert(step, ready_set);
244
245            step += 1;
246        }
247
248        Ok(plan)
249    }
250
251    pub fn __str__(&self) -> String {
252        PyHelperFuncs::__str__(&self.task_list)
253    }
254
255    pub fn serialize(&self) -> Result<String, serde_json::Error> {
256        // reset the workflow
257        let json = serde_json::to_string(self).unwrap();
258        // Add debug output to see what's being serialized
259        Ok(json)
260    }
261
262    pub fn from_json(json: &str) -> Result<Self, WorkflowError> {
263        // Deserialize the JSON string into a Workflow instance
264        Ok(serde_json::from_str(json)?)
265    }
266
267    pub fn task_names(&self) -> Vec<String> {
268        self.task_list
269            .tasks
270            .keys()
271            .cloned()
272            .collect::<Vec<String>>()
273    }
274}
275
276/// Check if the workflow is complete
277/// # Arguments
278/// * `workflow` - A reference to the workflow instance
279/// # Returns true if the workflow is complete, false otherwise
280fn is_workflow_complete(workflow: &Arc<RwLock<Workflow>>) -> bool {
281    workflow.read().unwrap().is_complete()
282}
283
284/// Reset failed tasks in the workflow
285/// # Arguments
286/// * `workflow` - A reference to the workflow instance
287/// # Returns Ok(()) if successful, or an error if the reset fails
288fn reset_failed_workflow_tasks(workflow: &Arc<RwLock<Workflow>>) -> Result<(), WorkflowError> {
289    match workflow.write().unwrap().task_list.reset_failed_tasks() {
290        Ok(_) => Ok(()),
291        Err(e) => {
292            warn!("Failed to reset failed tasks: {}", e);
293            Err(e)
294        }
295    }
296}
297
298/// Get all ready tasks in the workflow
299/// # Arguments
300/// * `workflow` - A reference to the workflow instance
301/// # Returns a vector of tasks that are ready to be executed
302fn get_ready_tasks(workflow: &Arc<RwLock<Workflow>>) -> Vec<Arc<RwLock<Task>>> {
303    workflow.read().unwrap().task_list.get_ready_tasks()
304}
305
306/// Check for circular dependencies
307/// # Arguments
308/// * `workflow` - A reference to the workflow instance
309/// # Returns true if circular dependencies are detected, false otherwise
310fn check_for_circular_dependencies(workflow: &Arc<RwLock<Workflow>>) -> bool {
311    let pending_count = workflow.read().unwrap().pending_count();
312
313    if pending_count > 0 {
314        warn!(
315            "No ready tasks found but {} pending tasks remain. Possible circular dependency.",
316            pending_count
317        );
318        return true;
319    }
320
321    false
322}
323
324/// Mark a task as running
325/// # Arguments
326/// * `workflow` - A reference to the workflow instance
327/// # Returns nothing
328fn mark_task_as_running(task: Arc<RwLock<Task>>, event_tracker: &Arc<RwLock<EventTracker>>) {
329    let mut task = task.write().unwrap();
330    task.set_status(TaskStatus::Running);
331    event_tracker.write().unwrap().record_task_started(&task.id);
332}
333
334/// Get an agent for a task
335/// # Arguments
336/// * `workflow` - A reference to the workflow instance
337/// * `task` - A reference to the task for which the agent is needed
338fn get_agent_for_task(workflow: &Arc<RwLock<Workflow>>, agent_id: &str) -> Option<Arc<Agent>> {
339    let wf = workflow.read().unwrap();
340    wf.agents.get(agent_id).cloned()
341}
342
343/// Builds the context for a task from its dependencies
344/// # Arguments
345/// * `workflow` - A reference to the workflow instance
346/// * `task` - A reference to the task for which the context is being built
347/// # Returns a HashMap containing the context messages for the task
348#[instrument(skip_all)]
349fn build_task_context(
350    workflow: &Arc<RwLock<Workflow>>,
351    task_dependencies: &Vec<String>,
352) -> Result<Context, WorkflowError> {
353    let wf = workflow.read().unwrap();
354    let mut ctx = HashMap::new();
355    let mut param_ctx: Value = Value::Object(Map::new());
356
357    for dep_id in task_dependencies {
358        debug!("Building context for task dependency: {}", dep_id);
359        if let Some(dep) = wf.task_list.get_task(dep_id) {
360            if let Some(result) = &dep.read().unwrap().result {
361                let msg_to_insert = result.response.to_message(Role::Assistant);
362
363                match msg_to_insert {
364                    Ok(message) => {
365                        ctx.insert(dep_id.clone(), message);
366                    }
367                    Err(e) => {
368                        warn!("Failed to convert response to message: {}", e);
369                    }
370                }
371
372                if let Some(structure_output) = result.response.extract_structured_data() {
373                    // Value should be a serde_json::Value Object type
374                    // validate that it's an object
375                    if structure_output.is_object() {
376                        // extract the Map from the Value
377                        update_serde_map_with(&mut param_ctx, &structure_output)?;
378                    }
379                }
380            }
381        }
382    }
383
384    debug!("Built context for task dependencies: {:?}", ctx);
385    let global_context = workflow.read().unwrap().global_context.clone();
386
387    Ok((ctx, param_ctx, global_context))
388}
389
390/// Spawns an individual task execution
391/// # Arguments
392/// * `workflow` - A reference to the workflow instance
393/// * `task` - The task to be executed
394/// * `task_id` - The ID of the task
395/// * `agent` - An optional reference to the agent that will execute the task
396/// * `context` - A HashMap containing the context messages for the task
397/// # Returns a JoinHandle for the spawned task
398fn spawn_task_execution(
399    event_tracker: Arc<RwLock<EventTracker>>,
400    task: Arc<RwLock<Task>>,
401    task_id: String,
402    agent: Option<Arc<Agent>>,
403    context: HashMap<String, Vec<Message>>,
404    parameter_context: Value,
405    global_context: Option<Value>,
406) -> tokio::task::JoinHandle<()> {
407    tokio::spawn(async move {
408        if let Some(agent) = agent {
409            // (1) Insert any context messages and/or parameters into the task prompt
410            // (2) Execute the task with the agent
411            // (3) Return the AgentResponse
412            let result = agent
413                .execute_task_with_context(&task, context, parameter_context, global_context)
414                .await;
415            match result {
416                Ok(response) => {
417                    let mut write_task = task.write().unwrap();
418                    write_task.set_status(TaskStatus::Completed);
419                    write_task.set_result(response.clone());
420                    event_tracker.write().unwrap().record_task_completed(
421                        &write_task.id,
422                        &write_task.prompt,
423                        response,
424                    );
425                }
426                Err(e) => {
427                    error!("Task {} failed: {}", task_id, e);
428                    let mut write_task = task.write().unwrap();
429                    write_task.set_status(TaskStatus::Failed);
430                    event_tracker.write().unwrap().record_task_failed(
431                        &write_task.id,
432                        &e.to_string(),
433                        &write_task.prompt,
434                    );
435                }
436            }
437        } else {
438            error!("No agent found for task {}", task_id);
439            let mut write_task = task.write().unwrap();
440            write_task.set_status(TaskStatus::Failed);
441        }
442    })
443}
444
445fn get_parameters_from_context(task: Arc<RwLock<Task>>) -> (String, Vec<String>, String) {
446    let (task_id, dependencies, agent_id) = {
447        let task_guard = task.read().unwrap();
448        (
449            task_guard.id.clone(),
450            task_guard.dependencies.clone(),
451            task_guard.agent_id.clone(),
452        )
453    };
454
455    (task_id, dependencies, agent_id)
456}
457
458/// Helper for spawning a task execution
459/// # Arguments
460/// * `workflow` - A reference to the workflow instance
461/// * `tasks` - A vector of tasks to be executed
462/// # Returns a vector of JoinHandles for the spawned tasks
463fn spawn_task_executions(
464    workflow: &Arc<RwLock<Workflow>>,
465    ready_tasks: Vec<Arc<RwLock<Task>>>,
466) -> Result<Vec<tokio::task::JoinHandle<()>>, WorkflowError> {
467    let mut handles = Vec::with_capacity(ready_tasks.len());
468
469    // Get the event tracker from the workflow
470    let event_tracker = workflow.read().unwrap().event_tracker.clone();
471
472    for task in ready_tasks {
473        // Get task parameters
474        let (task_id, dependencies, agent_id) = get_parameters_from_context(task.clone());
475
476        // Mark task as running
477        // This will also record the task started event
478        mark_task_as_running(task.clone(), &event_tracker);
479
480        // Build the context
481        // Here we:
482        // 1. Get the task dependencies and their results (these will be injected as assistant messages)
483        // 2. Parse dependent tasks for any structured outputs and return as a serde_json::Value (this will be task-level context)
484        let (context, parameter_context, global_context) =
485            build_task_context(workflow, &dependencies)?;
486
487        // Get/clone agent ARC
488        let agent = get_agent_for_task(workflow, &agent_id);
489
490        // Spawn task execution and push handle to future vector
491        let handle = spawn_task_execution(
492            event_tracker.clone(),
493            task.clone(),
494            task_id,
495            agent,
496            context,
497            parameter_context,
498            global_context,
499        );
500        handles.push(handle);
501    }
502
503    Ok(handles)
504}
505
506/// Wait for all spawned tasks to complete
507/// # Arguments
508/// * `handles` - A vector of JoinHandles for the spawned tasks
509/// # Returns nothing
510async fn await_task_completions(handles: Vec<tokio::task::JoinHandle<()>>) {
511    for handle in handles {
512        if let Err(e) = handle.await {
513            warn!("Task execution failed: {}", e);
514        }
515    }
516}
517
518/// Execute the workflow asynchronously
519/// This function will be called to start the workflow execution process and does the following:
520/// 1. Iterates over workflow tasks while the shared workflow is not complete.
521/// 2. Resets any failed tasks to allow them to be retried. This needs to happen before getting ready tasks.
522/// 3. Gets all ready tasks
523/// 4. For each ready task:
524/// ///    - Marks the task as running
525/// ///    - Checks previous tasks for injected context
526/// ///    - Gets the agent for the task  
527/// ///    - Spawn a new tokio task and execute task with agent
528/// ///    - Push task to the handles vector
529/// 4. Waits for all spawned tasks to complete
530#[instrument(skip_all)]
531pub async fn execute_workflow(workflow: &Arc<RwLock<Workflow>>) -> Result<(), WorkflowError> {
532    // Important to remember that the workflow is an Arc<RwLock<Workflow>> is a new clone of
533    // the loaded workflow. This allows us to mutate the workflow without affecting the original
534    // workflow instance.
535    debug!("Starting workflow execution");
536
537    // Run until workflow is complete
538    while !is_workflow_complete(workflow) {
539        // Reset any failed tasks
540        // This will return an error if any task exceeds its max retries (set at the task level)
541        reset_failed_workflow_tasks(workflow)?;
542
543        // Get tasks ready for execution
544        // This will return an Arc<RwLock<Task>>
545        let ready_tasks = get_ready_tasks(workflow);
546        debug!("Found {} ready tasks for execution", ready_tasks.len());
547
548        // Check for circular dependencies
549        if ready_tasks.is_empty() {
550            if check_for_circular_dependencies(workflow) {
551                break;
552            }
553            continue;
554        }
555
556        // Execute tasks asynchronously
557        let handles = spawn_task_executions(workflow, ready_tasks)?;
558
559        // Wait for all tasks to complete
560        await_task_completions(handles).await;
561    }
562
563    debug!("Workflow execution completed");
564    Ok(())
565}
566
567impl Serialize for Workflow {
568    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
569    where
570        S: Serializer,
571    {
572        let mut state = serializer.serialize_struct("Workflow", 4)?;
573
574        // set session to none
575        state.serialize_field("id", &self.id)?;
576        state.serialize_field("name", &self.name)?;
577        state.serialize_field("task_list", &self.task_list)?;
578        state.serialize_field("agents", &self.agents)?;
579
580        state.end()
581    }
582}
583
584impl<'de> Deserialize<'de> for Workflow {
585    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
586    where
587        D: Deserializer<'de>,
588    {
589        #[derive(Deserialize)]
590        #[serde(field_identifier, rename_all = "snake_case")]
591        enum Field {
592            Id,
593            Name,
594            TaskList,
595            Agents,
596        }
597
598        struct WorkflowVisitor;
599
600        impl<'de> Visitor<'de> for WorkflowVisitor {
601            type Value = Workflow;
602
603            fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result {
604                formatter.write_str("struct Workflow")
605            }
606
607            fn visit_map<V>(self, mut map: V) -> Result<Workflow, V::Error>
608            where
609                V: MapAccess<'de>,
610            {
611                let mut id = None;
612                let mut name = None;
613                let mut task_list_data = None;
614                let mut agents: Option<HashMap<String, Agent>> = None;
615
616                while let Some(key) = map.next_key()? {
617                    match key {
618                        Field::Id => {
619                            let value: String = map.next_value().map_err(|e| {
620                                error!("Failed to deserialize field 'id': {e}");
621                                de::Error::custom(format!("Failed to deserialize field 'id': {e}"))
622                            })?;
623                            id = Some(value);
624                        }
625                        Field::TaskList => {
626                            // Deserialize as a generic Value first
627                            let value: TaskList = map.next_value().map_err(|e| {
628                                error!("Failed to deserialize field 'task_list': {e}");
629                                de::Error::custom(format!(
630                                    "Failed to deserialize field 'task_list': {e}",
631                                ))
632                            })?;
633
634                            task_list_data = Some(value);
635                        }
636                        Field::Name => {
637                            let value: String = map.next_value().map_err(|e| {
638                                error!("Failed to deserialize field 'name': {e}");
639                                de::Error::custom(format!(
640                                    "Failed to deserialize field 'name': {e}",
641                                ))
642                            })?;
643                            name = Some(value);
644                        }
645                        Field::Agents => {
646                            let value: HashMap<String, Agent> = map.next_value().map_err(|e| {
647                                error!("Failed to deserialize field 'agents': {e}");
648                                de::Error::custom(format!(
649                                    "Failed to deserialize field 'agents': {e}"
650                                ))
651                            })?;
652                            agents = Some(value);
653                        }
654                    }
655                }
656
657                let id = id.ok_or_else(|| de::Error::missing_field("id"))?;
658                let name = name.ok_or_else(|| de::Error::missing_field("name"))?;
659                let task_list_data =
660                    task_list_data.ok_or_else(|| de::Error::missing_field("task_list"))?;
661                let agents = agents.ok_or_else(|| de::Error::missing_field("agents"))?;
662
663                let event_tracker = Arc::new(RwLock::new(EventTracker::new(create_uuid7())));
664
665                // convert agents to arc
666                let agents = agents
667                    .into_iter()
668                    .map(|(id, agent)| (id, Arc::new(agent)))
669                    .collect();
670
671                Ok(Workflow {
672                    id,
673                    name,
674                    task_list: task_list_data,
675                    agents,
676                    event_tracker,
677                    global_context: None, // Initialize with no global context
678                })
679            }
680        }
681
682        const FIELDS: &[&str] = &["id", "name", "task_list", "agents"];
683        deserializer.deserialize_struct("Workflow", FIELDS, WorkflowVisitor)
684    }
685}
686
/// Python-facing wrapper around `Workflow`, exposed to Python under the name `Workflow`.
#[pyclass(name = "Workflow")]
#[derive(Debug, Clone)]
pub struct PyWorkflow {
    /// The underlying Rust workflow definition.
    workflow: Workflow,

    // Output types for Python tasks (Python only), keyed by task id.
    // These are provided at runtime by the user and must match the response
    // format of the prompt the task is associated with.
    output_types: HashMap<String, Arc<PyObject>>,

    // Tokio runtime that drives `Workflow::run`; it is also handed to the agents returned
    // by the `agents` getter. Clones of this wrapper share the same runtime.
    runtime: Arc<tokio::runtime::Runtime>,
}
700
701#[pymethods]
702impl PyWorkflow {
703    #[new]
704    #[pyo3(signature = (name))]
705    pub fn new(name: &str) -> Result<Self, WorkflowError> {
706        debug!("Creating new workflow: {}", name);
707        Ok(Self {
708            workflow: Workflow::new(name),
709            output_types: HashMap::new(),
710            runtime: Arc::new(
711                tokio::runtime::Runtime::new()
712                    .map_err(|e| WorkflowError::RuntimeError(e.to_string()))?,
713            ),
714        })
715    }
716
717    #[getter]
718    pub fn name(&self) -> String {
719        self.workflow.name.clone()
720    }
721
722    #[getter]
723    pub fn task_list(&self) -> TaskList {
724        self.workflow.task_list.clone()
725    }
726
    /// Always `true`; presumably lets Python callers recognise workflow objects — confirm
    /// against the Python-side usage.
    #[getter]
    pub fn is_workflow(&self) -> bool {
        true
    }
731
732    #[getter]
733    pub fn __workflow__(&self) -> String {
734        self.model_dump_json()
735    }
736
737    #[getter]
738    pub fn agents(&self) -> Result<HashMap<String, PyAgent>, WorkflowError> {
739        self.workflow
740            .agents
741            .iter()
742            .map(|(id, agent)| {
743                Ok((
744                    id.clone(),
745                    PyAgent {
746                        agent: agent.clone(),
747                        runtime: self.runtime.clone(),
748                    },
749                ))
750            })
751            .collect::<Result<HashMap<_, _>, _>>()
752    }
753
754    #[pyo3(signature = (task_output_types))]
755    pub fn add_task_output_types<'py>(
756        &mut self,
757        task_output_types: Bound<'py, PyDict>,
758    ) -> PyResult<()> {
759        let converted: HashMap<String, Arc<PyObject>> = task_output_types
760            .iter()
761            .map(|(k, v)| -> PyResult<(String, Arc<PyObject>)> {
762                // Explicitly return a Result from the closure
763                let key = k.extract::<String>()?;
764                let value = v.clone().unbind();
765                Ok((key, Arc::new(value)))
766            })
767            .collect::<PyResult<_>>()?;
768        self.output_types.extend(converted);
769        Ok(())
770    }
771
772    #[pyo3(signature = (task, output_type = None))]
773    pub fn add_task(
774        &mut self,
775        py: Python<'_>,
776        mut task: Task,
777        output_type: Option<Bound<'_, PyAny>>,
778    ) -> Result<(), WorkflowError> {
779        if let Some(output_type) = output_type {
780            // Parse and set the response format
781            (task.prompt.response_type, task.prompt.response_json_schema) =
782                parse_response_to_json(py, &output_type)
783                    .map_err(|e| WorkflowError::InvalidOutputType(e.to_string()))?;
784
785            // Store the output type for later use
786            self.output_types
787                .insert(task.id.clone(), Arc::new(output_type.unbind()));
788        }
789
790        self.workflow.task_list.add_task(task)?;
791        Ok(())
792    }
793
794    pub fn add_tasks(&mut self, tasks: Vec<Task>) -> Result<(), WorkflowError> {
795        for task in tasks {
796            self.workflow.task_list.add_task(task)?;
797        }
798        Ok(())
799    }
800
801    pub fn add_agent(&mut self, agent: &Bound<'_, PyAgent>) {
802        // extract the arc rust agent from the python agent
803        let agent = agent.extract::<PyAgent>().unwrap().agent.clone();
804        self.workflow.agents.insert(agent.id.clone(), agent);
805    }
806
807    pub fn is_complete(&self) -> bool {
808        self.workflow.task_list.is_complete()
809    }
810
811    pub fn pending_count(&self) -> usize {
812        self.workflow.task_list.pending_count()
813    }
814
815    pub fn execution_plan<'py>(
816        &self,
817        py: Python<'py>,
818    ) -> Result<Bound<'py, PyDict>, WorkflowError> {
819        let plan = self.workflow.execution_plan()?;
820        debug!("Execution plan: {:?}", plan);
821
822        // turn hashmap into a to json
823        let json = serde_json::to_value(plan).map_err(|e| {
824            error!("Failed to serialize execution plan to JSON: {}", e);
825            e
826        })?;
827
828        let pydict = PyDict::new(py);
829        json_to_pydict(py, &json, &pydict)?;
830
831        Ok(pydict)
832    }
833
834    #[pyo3(signature = (global_context=None))]
835    pub fn run(
836        &self,
837        py: Python,
838        global_context: Option<Bound<'_, PyDict>>,
839    ) -> Result<WorkflowResult, WorkflowError> {
840        debug!("Running workflow: {}", self.workflow.name);
841
842        // Convert the global context from PyDict to serde_json::Value if provided
843        let global_context = if let Some(context) = global_context {
844            // Convert PyDict to serde_json::Value
845            let json_value = pyobject_to_json(&context.into_bound_py_any(py)?)?;
846            Some(json_value)
847        } else {
848            None
849        };
850
851        let workflow: Arc<RwLock<Workflow>> = self
852            .runtime
853            .block_on(async { self.workflow.run(global_context).await })?;
854
855        // Try to get exclusive ownership of the workflow by unwrapping the Arc if there's only one reference
856        let workflow_result = match Arc::try_unwrap(workflow) {
857            // If we have exclusive ownership, we can consume the RwLock
858            Ok(rwlock) => {
859                // Unwrap the RwLock to get the Workflow
860                let workflow = rwlock
861                    .into_inner()
862                    .map_err(|_| WorkflowError::LockAcquireError)?;
863
864                // Get the events before creating WorkflowResult
865                let events = workflow
866                    .event_tracker
867                    .read()
868                    .unwrap()
869                    .events
870                    .read()
871                    .unwrap()
872                    .clone();
873
874                // Move the tasks out of the workflow
875                WorkflowResult::new(py, workflow.task_list.tasks(), &self.output_types, events)
876            }
877            // If there are other references, we need to clone
878            Err(arc) => {
879                // Just read the workflow
880                error!("Workflow still has other references, reading instead of consuming.");
881                let workflow = arc
882                    .read()
883                    .map_err(|_| WorkflowError::ReadLockAcquireError)?;
884
885                // Get the events before creating WorkflowResult
886                let events = workflow
887                    .event_tracker
888                    .read()
889                    .unwrap()
890                    .events
891                    .read()
892                    .unwrap()
893                    .clone();
894
895                WorkflowResult::new(py, workflow.task_list.tasks(), &self.output_types, events)
896            }
897        };
898
899        info!("Workflow execution completed successfully.");
900        Ok(workflow_result)
901    }
902
903    pub fn model_dump_json(&self) -> String {
904        serde_json::to_string(&self.workflow).unwrap()
905    }
906
907    #[staticmethod]
908    #[pyo3(signature = (json_string, output_types=None))]
909    pub fn model_validate_json(
910        json_string: String,
911        output_types: Option<Bound<'_, PyDict>>,
912    ) -> Result<Self, WorkflowError> {
913        let workflow: Workflow = serde_json::from_str(&json_string)?;
914        let runtime = Arc::new(
915            tokio::runtime::Runtime::new()
916                .map_err(|e| WorkflowError::RuntimeError(e.to_string()))?,
917        );
918
919        let output_types = if let Some(output_types) = output_types {
920            output_types
921                .iter()
922                .map(|(k, v)| -> PyResult<(String, Arc<PyObject>)> {
923                    let key = k.extract::<String>()?;
924                    let value = v.clone().unbind();
925                    Ok((key, Arc::new(value)))
926                })
927                .collect::<PyResult<HashMap<String, Arc<PyObject>>>>()?
928        } else {
929            HashMap::new()
930        };
931
932        let py_workflow = PyWorkflow {
933            workflow,
934            output_types,
935            runtime,
936        };
937
938        Ok(py_workflow)
939    }
940}
941
#[cfg(test)]
mod tests {
    use super::*;
    use potato_prompt::prompt::ResponseType;
    use potato_prompt::{prompt::types::PromptContent, Message, Prompt};

    /// Builds a single-message `gpt-4o` / OpenAI prompt with a null response type.
    fn sample_prompt(text: &str) -> Prompt {
        let message = Message::new_rs(PromptContent::Str(text.to_string()));
        Prompt::new_rs(
            vec![message],
            "gpt-4o",
            potato_type::Provider::OpenAI,
            vec![],
            None,
            None,
            ResponseType::Null,
        )
        .unwrap()
    }

    #[test]
    fn test_workflow_creation() {
        let wf = Workflow::new("Test Workflow");
        assert_eq!(wf.name, "Test Workflow");
        // Hyphenated UUID strings (as produced for UUID7 ids) are 36 characters long.
        assert_eq!(wf.id.len(), 36);
    }

    #[test]
    fn test_task_list_add_and_get() {
        let mut tasks = TaskList::new();
        let task = Task::new("task1", sample_prompt("Test prompt"), "task1", None, None);
        let expected_id = task.id.clone();

        tasks.add_task(task).unwrap();
        assert_eq!(
            tasks.get_task(&expected_id).unwrap().read().unwrap().id,
            expected_id
        );

        tasks.reset_failed_tasks().unwrap();
    }
}