potato_workflow/workflow/
flow.rs

1use crate::tasklist::TaskList;
2use crate::types::Context;
3use crate::{
4    events::{EventTracker, TaskEvent},
5    workflow::error::WorkflowError,
6};
7pub use potato_agent::agents::{
8    agent::{Agent, PyAgent},
9    task::{PyTask, Task, TaskStatus},
10};
11use potato_agent::PyAgentResponse;
12use potato_prompt::parse_response_to_json;
13use potato_prompt::prompt::types::Role;
14use potato_prompt::Message;
15use potato_util::{create_uuid7, utils::update_serde_map_with, PyHelperFuncs};
16use potato_util::{json_to_pydict, pyobject_to_json};
17use pyo3::prelude::*;
18use pyo3::IntoPyObjectExt;
19use serde::{
20    de::{self, MapAccess, Visitor},
21    ser::SerializeStruct,
22    Deserialize, Deserializer, Serialize, Serializer,
23};
24use serde_json::Map;
25use serde_json::Value;
26use std::collections::{HashMap, HashSet};
27use std::sync::Arc;
28use std::sync::RwLock;
29use tracing::instrument;
30use tracing::{debug, error, info, warn};
31
32/// Python workflows are a work in progress
33use pyo3::types::PyDict;
34
/// Outcome of a workflow run as exposed to Python: the converted task objects
/// and the task events that accompany them.
#[derive(Debug)]
#[pyclass]
pub struct WorkflowResult {
    /// Python task objects, stored under the same keys as the task map that
    /// was passed to `WorkflowResult::new`.
    #[pyo3(get)]
    pub tasks: HashMap<String, Py<PyTask>>,

    /// Task events supplied when the result was built.
    #[pyo3(get)]
    pub events: Vec<TaskEvent>,
}
44
45impl WorkflowResult {
46    pub fn new(
47        py: Python,
48        tasks: HashMap<String, Task>,
49        output_types: &HashMap<String, Arc<Py<PyAny>>>,
50        events: Vec<TaskEvent>,
51    ) -> Self {
52        let py_tasks = tasks
53            .into_iter()
54            .map(|(id, task)| {
55                let py_agent_response = if let Some(result) = task.result {
56                    let output_type = output_types.get(&id).map(|arc| arc.as_ref().clone_ref(py));
57                    Some(PyAgentResponse::new(result, output_type))
58                } else {
59                    None
60                };
61                let py_task = PyTask {
62                    id: task.id.clone(),
63                    prompt: task.prompt,
64                    dependencies: task.dependencies,
65                    status: task.status,
66                    agent_id: task.agent_id,
67                    result: py_agent_response,
68                    max_retries: task.max_retries,
69                    retry_count: task.retry_count,
70                };
71                (id, Py::new(py, py_task).unwrap())
72            })
73            .collect::<HashMap<_, _>>();
74
75        Self {
76            tasks: py_tasks,
77            events,
78        }
79    }
80}
81
82#[pymethods]
83impl WorkflowResult {
84    pub fn __str__(&self) -> String {
85        // serialize tasks to json
86        let json = serde_json::json!({
87            "tasks": serde_json::to_value(&self.tasks).unwrap_or(Value::Null),
88            "events": serde_json::to_value(&self.events).unwrap_or(Value::Null)
89        });
90
91        PyHelperFuncs::__str__(&json)
92    }
93}
94
/// Rust-specific implementation of a workflow
///
/// Only `id`, `name`, `task_list` and `agents` take part in (de)serialization;
/// the event tracker and global context are rebuilt/reset when deserializing.
#[derive(Debug, Clone)]
pub struct Workflow {
    /// Identifier of this workflow instance (generated with `create_uuid7` in `new`).
    pub id: String,
    /// Human-readable workflow name.
    pub name: String,
    /// Tasks that make up the workflow.
    pub task_list: TaskList,
    /// Agents available to tasks; `add_agent` stores each one under its own `id`.
    pub agents: HashMap<String, Arc<Agent>>,
    /// Shared tracker that records task events.
    pub event_tracker: Arc<RwLock<EventTracker>>,
    /// Optional context value handed to every task execution.
    pub global_context: Option<Value>,
}
105
106impl PartialEq for Workflow {
107    fn eq(&self, other: &Self) -> bool {
108        // Compare by ID and name
109        self.id == other.id && self.name == other.name
110    }
111}
112
113impl Workflow {
114    /// Reload the agent clients by overwriting the existing clients
115    /// We are trying to solve the deserialization issue where GenAIClient requires
116    /// and async context to be created. During deserialization we don't have that context, so we default
117    /// to an Undefined client, but keep the other agent details. This means that if we try to run a workflow after deserialization
118    /// we will get an error when we try to execute a task with an Undefined client.
119    /// For this specific function, we can either make Arc<Agent> into RW compatible, or we can
120    /// rebuild the entire agents map with new Arcs. Given that the only mutation we need to do is to
121    /// rebuild the GenAIClient, we opt for the latter because we don't need to make everything else RW compatible.
122    /// This will incure a small startup cost, but it will be a one-time cost and 99% of the time unnoticed.
123    pub async fn reset_agents(&mut self) -> Result<(), WorkflowError> {
124        let mut agents_map = self.agents.clone();
125
126        for agent in self.agents.values_mut() {
127            agents_map.insert(agent.id.clone(), Arc::new(agent.rebuild_client().await?));
128        }
129        self.agents = agents_map;
130        Ok(())
131    }
132    pub fn new(name: &str) -> Self {
133        debug!("Creating new workflow: {}", name);
134        let id = create_uuid7();
135        Self {
136            id: id.clone(),
137            name: name.to_string(),
138            task_list: TaskList::new(),
139            agents: HashMap::new(),
140            event_tracker: Arc::new(RwLock::new(EventTracker::new(id))),
141            global_context: None, // Initialize with no global context
142        }
143    }
144    pub fn events(&self) -> Vec<TaskEvent> {
145        let tracker = self.event_tracker.read().unwrap();
146        let events = tracker.events.read().unwrap().clone();
147        events
148    }
149
150    pub fn total_duration(&self) -> i32 {
151        let tracker = self.event_tracker.read().unwrap();
152
153        if tracker.is_empty() {
154            0
155        } else {
156            //iter over each tracker event and get the details for each task event and get duration
157            let mut total_duration = chrono::Duration::zero();
158            for event in tracker.events.read().unwrap().iter() {
159                total_duration += event.details.duration.unwrap_or(chrono::Duration::zero());
160            }
161            total_duration.subsec_millis()
162        }
163    }
164
165    pub fn get_new_workflow(&self, global_context: Option<Value>) -> Result<Self, WorkflowError> {
166        // set new id for the new workflow
167        let id = create_uuid7();
168
169        // create deep copy of the tasklist so we don't clone the arc
170        let task_list = self.task_list.deep_clone()?;
171
172        Ok(Workflow {
173            id: id.clone(),
174            name: self.name.clone(),
175            task_list,
176            agents: self.agents.clone(), // Agents can be shared since they're read-only during execution
177            event_tracker: Arc::new(RwLock::new(EventTracker::new(id))),
178            global_context, // Use the provided global context or None
179        })
180    }
181
182    pub async fn run(
183        &self,
184        global_context: Option<Value>,
185    ) -> Result<Arc<RwLock<Workflow>>, WorkflowError> {
186        debug!("Running workflow: {}", self.name);
187
188        let run_workflow = Arc::new(RwLock::new(self.get_new_workflow(global_context)?));
189
190        execute_workflow(&run_workflow).await?;
191
192        Ok(run_workflow)
193    }
194
195    pub fn is_complete(&self) -> bool {
196        self.task_list.is_complete()
197    }
198
199    pub fn pending_count(&self) -> usize {
200        self.task_list.pending_count()
201    }
202
203    pub fn add_task(&mut self, task: Task) -> Result<(), WorkflowError> {
204        self.task_list.add_task(task)
205    }
206
207    pub fn add_tasks(&mut self, tasks: Vec<Task>) -> Result<(), WorkflowError> {
208        for task in tasks {
209            self.task_list.add_task(task)?;
210        }
211        Ok(())
212    }
213
214    pub fn add_agent(&mut self, agent: &Agent) {
215        self.agents
216            .insert(agent.id.clone(), Arc::new(agent.clone()));
217    }
218
219    pub fn execution_plan(&self) -> Result<HashMap<i32, HashSet<String>>, WorkflowError> {
220        let mut remaining: HashMap<String, HashSet<String>> = self
221            .task_list
222            .tasks
223            .iter()
224            .map(|(id, task)| {
225                (
226                    id.clone(),
227                    task.read().unwrap().dependencies.iter().cloned().collect(),
228                )
229            })
230            .collect();
231
232        let mut executed = HashSet::new();
233        let mut plan = HashMap::new();
234        let mut step = 1;
235
236        while !remaining.is_empty() {
237            // Find all tasks that can be executed in parallel - collect just the keys we need to remove
238            let ready_keys: Vec<String> = remaining
239                .iter()
240                .filter(|(_, deps)| deps.is_subset(&executed))
241                .map(|(id, _)| id.to_string())
242                .collect();
243
244            if ready_keys.is_empty() {
245                // Circular dependency detected
246                break;
247            }
248
249            // Create the set for the plan (reusing the already allocated Strings)
250            let mut ready_set = HashSet::with_capacity(ready_keys.len());
251
252            // Update tracking sets and build the ready set in one pass
253            for key in ready_keys {
254                executed.insert(key.clone());
255                remaining.remove(&key);
256                ready_set.insert(key);
257            }
258
259            // Add parallel tasks to the current step
260            plan.insert(step, ready_set);
261
262            step += 1;
263        }
264
265        Ok(plan)
266    }
267
268    pub fn __str__(&self) -> String {
269        PyHelperFuncs::__str__(&self.task_list)
270    }
271
272    pub fn serialize(&self) -> Result<String, serde_json::Error> {
273        // reset the workflow
274        let json = serde_json::to_string(self).unwrap();
275        // Add debug output to see what's being serialized
276        Ok(json)
277    }
278
279    pub fn from_json(json: &str) -> Result<Self, WorkflowError> {
280        // Deserialize the JSON string into a Workflow instance
281        Ok(serde_json::from_str(json)?)
282    }
283
284    pub fn task_names(&self) -> Vec<String> {
285        self.task_list
286            .tasks
287            .keys()
288            .cloned()
289            .collect::<Vec<String>>()
290    }
291}
292
293/// Check if the workflow is complete
294/// # Arguments
295/// * `workflow` - A reference to the workflow instance
296/// # Returns true if the workflow is complete, false otherwise
297fn is_workflow_complete(workflow: &Arc<RwLock<Workflow>>) -> bool {
298    workflow.read().unwrap().is_complete()
299}
300
301/// Reset failed tasks in the workflow
302/// # Arguments
303/// * `workflow` - A reference to the workflow instance
304/// # Returns Ok(()) if successful, or an error if the reset fails
305fn reset_failed_workflow_tasks(workflow: &Arc<RwLock<Workflow>>) -> Result<(), WorkflowError> {
306    match workflow.write().unwrap().task_list.reset_failed_tasks() {
307        Ok(_) => Ok(()),
308        Err(e) => {
309            warn!("Failed to reset failed tasks: {}", e);
310            Err(e)
311        }
312    }
313}
314
315/// Get all ready tasks in the workflow
316/// # Arguments
317/// * `workflow` - A reference to the workflow instance
318/// # Returns a vector of tasks that are ready to be executed
319fn get_ready_tasks(workflow: &Arc<RwLock<Workflow>>) -> Vec<Arc<RwLock<Task>>> {
320    workflow.read().unwrap().task_list.get_ready_tasks()
321}
322
323/// Check for circular dependencies
324/// # Arguments
325/// * `workflow` - A reference to the workflow instance
326/// # Returns true if circular dependencies are detected, false otherwise
327fn check_for_circular_dependencies(workflow: &Arc<RwLock<Workflow>>) -> bool {
328    let pending_count = workflow.read().unwrap().pending_count();
329
330    if pending_count > 0 {
331        warn!(
332            "No ready tasks found but {} pending tasks remain. Possible circular dependency.",
333            pending_count
334        );
335        return true;
336    }
337
338    false
339}
340
341/// Mark a task as running
342/// # Arguments
343/// * `workflow` - A reference to the workflow instance
344/// # Returns nothing
345fn mark_task_as_running(task: Arc<RwLock<Task>>, event_tracker: &Arc<RwLock<EventTracker>>) {
346    let mut task = task.write().unwrap();
347    task.set_status(TaskStatus::Running);
348    event_tracker.write().unwrap().record_task_started(&task.id);
349}
350
351/// Get an agent for a task
352/// # Arguments
353/// * `workflow` - A reference to the workflow instance
354/// * `task` - A reference to the task for which the agent is needed
355fn get_agent_for_task(workflow: &Arc<RwLock<Workflow>>, agent_id: &str) -> Option<Arc<Agent>> {
356    let wf = workflow.read().unwrap();
357    wf.agents.get(agent_id).cloned()
358}
359
360/// Builds the context for a task from its dependencies
361/// # Arguments
362/// * `workflow` - A reference to the workflow instance
363/// * `task` - A reference to the task for which the context is being built
364/// # Returns a HashMap containing the context messages for the task
365#[instrument(skip_all)]
366fn build_task_context(
367    workflow: &Arc<RwLock<Workflow>>,
368    task_dependencies: &Vec<String>,
369) -> Result<Context, WorkflowError> {
370    let wf = workflow.read().unwrap();
371    let mut ctx = HashMap::new();
372    let mut param_ctx: Value = Value::Object(Map::new());
373
374    for dep_id in task_dependencies {
375        debug!("Building context for task dependency: {}", dep_id);
376        if let Some(dep) = wf.task_list.get_task(dep_id) {
377            if let Some(result) = &dep.read().unwrap().result {
378                let msg_to_insert = result.response.to_message(Role::Assistant);
379
380                match msg_to_insert {
381                    Ok(message) => {
382                        ctx.insert(dep_id.clone(), message);
383                    }
384                    Err(e) => {
385                        warn!("Failed to convert response to message: {}", e);
386                    }
387                }
388
389                if let Some(structure_output) = result.response.extract_structured_data() {
390                    // Value should be a serde_json::Value Object type
391                    // validate that it's an object
392                    if structure_output.is_object() {
393                        // extract the Map from the Value
394                        update_serde_map_with(&mut param_ctx, &structure_output)?;
395                    }
396                }
397            }
398        }
399    }
400
401    debug!("Built context for task dependencies: {:?}", ctx);
402    let global_context = workflow.read().unwrap().global_context.clone();
403
404    Ok((ctx, param_ctx, global_context))
405}
406
407/// Spawns an individual task execution
408/// # Arguments
409/// * `workflow` - A reference to the workflow instance
410/// * `task` - The task to be executed
411/// * `task_id` - The ID of the task
412/// * `agent` - An optional reference to the agent that will execute the task
413/// * `context` - A HashMap containing the context messages for the task
414/// # Returns a JoinHandle for the spawned task
415fn spawn_task_execution(
416    event_tracker: Arc<RwLock<EventTracker>>,
417    task: Arc<RwLock<Task>>,
418    task_id: String,
419    agent: Option<Arc<Agent>>,
420    context: HashMap<String, Vec<Message>>,
421    parameter_context: Value,
422    global_context: Option<Value>,
423) -> tokio::task::JoinHandle<()> {
424    tokio::spawn(async move {
425        if let Some(agent) = agent {
426            // (1) Insert any context messages and/or parameters into the task prompt
427            // (2) Execute the task with the agent
428            // (3) Return the AgentResponse
429            let result = agent
430                .execute_task_with_context(&task, context, parameter_context, global_context)
431                .await;
432            match result {
433                Ok(response) => {
434                    let mut write_task = task.write().unwrap();
435                    write_task.set_status(TaskStatus::Completed);
436                    write_task.set_result(response.clone());
437                    event_tracker.write().unwrap().record_task_completed(
438                        &write_task.id,
439                        &write_task.prompt,
440                        response,
441                    );
442                }
443                Err(e) => {
444                    error!("Task {} failed: {}", task_id, e);
445                    let mut write_task = task.write().unwrap();
446                    write_task.set_status(TaskStatus::Failed);
447                    event_tracker.write().unwrap().record_task_failed(
448                        &write_task.id,
449                        &e.to_string(),
450                        &write_task.prompt,
451                    );
452                }
453            }
454        } else {
455            error!("No agent found for task {}", task_id);
456            let mut write_task = task.write().unwrap();
457            write_task.set_status(TaskStatus::Failed);
458        }
459    })
460}
461
462fn get_parameters_from_context(task: Arc<RwLock<Task>>) -> (String, Vec<String>, String) {
463    let (task_id, dependencies, agent_id) = {
464        let task_guard = task.read().unwrap();
465        (
466            task_guard.id.clone(),
467            task_guard.dependencies.clone(),
468            task_guard.agent_id.clone(),
469        )
470    };
471
472    (task_id, dependencies, agent_id)
473}
474
475/// Helper for spawning a task execution
476/// # Arguments
477/// * `workflow` - A reference to the workflow instance
478/// * `tasks` - A vector of tasks to be executed
479/// # Returns a vector of JoinHandles for the spawned tasks
480fn spawn_task_executions(
481    workflow: &Arc<RwLock<Workflow>>,
482    ready_tasks: Vec<Arc<RwLock<Task>>>,
483) -> Result<Vec<tokio::task::JoinHandle<()>>, WorkflowError> {
484    let mut handles = Vec::with_capacity(ready_tasks.len());
485
486    // Get the event tracker from the workflow
487    let event_tracker = workflow.read().unwrap().event_tracker.clone();
488
489    for task in ready_tasks {
490        // Get task parameters
491        let (task_id, dependencies, agent_id) = get_parameters_from_context(task.clone());
492
493        // Mark task as running
494        // This will also record the task started event
495        mark_task_as_running(task.clone(), &event_tracker);
496
497        // Build the context
498        // Here we:
499        // 1. Get the task dependencies and their results (these will be injected as assistant messages)
500        // 2. Parse dependent tasks for any structured outputs and return as a serde_json::Value (this will be task-level context)
501        let (context, parameter_context, global_context) =
502            build_task_context(workflow, &dependencies)?;
503
504        // Get/clone agent ARC
505        let agent = get_agent_for_task(workflow, &agent_id);
506
507        // Spawn task execution and push handle to future vector
508        let handle = spawn_task_execution(
509            event_tracker.clone(),
510            task.clone(),
511            task_id,
512            agent,
513            context,
514            parameter_context,
515            global_context,
516        );
517        handles.push(handle);
518    }
519
520    Ok(handles)
521}
522
523/// Wait for all spawned tasks to complete
524/// # Arguments
525/// * `handles` - A vector of JoinHandles for the spawned tasks
526/// # Returns nothing
527async fn await_task_completions(handles: Vec<tokio::task::JoinHandle<()>>) {
528    for handle in handles {
529        if let Err(e) = handle.await {
530            warn!("Task execution failed: {}", e);
531        }
532    }
533}
534
535/// Execute the workflow asynchronously
536/// This function will be called to start the workflow execution process and does the following:
537/// 1. Iterates over workflow tasks while the shared workflow is not complete.
538/// 2. Resets any failed tasks to allow them to be retried. This needs to happen before getting ready tasks.
539/// 3. Gets all ready tasks
540/// 4. For each ready task:
541/// ///    - Marks the task as running
542/// ///    - Checks previous tasks for injected context
543/// ///    - Gets the agent for the task  
544/// ///    - Spawn a new tokio task and execute task with agent
545/// ///    - Push task to the handles vector
546/// 4. Waits for all spawned tasks to complete
547#[instrument(skip_all)]
548pub async fn execute_workflow(workflow: &Arc<RwLock<Workflow>>) -> Result<(), WorkflowError> {
549    // Important to remember that the workflow is an Arc<RwLock<Workflow>> is a new clone of
550    // the loaded workflow. This allows us to mutate the workflow without affecting the original
551    // workflow instance.
552    debug!("Starting workflow execution");
553
554    // Run until workflow is complete
555    while !is_workflow_complete(workflow) {
556        // Reset any failed tasks
557        // This will return an error if any task exceeds its max retries (set at the task level)
558        reset_failed_workflow_tasks(workflow)?;
559
560        // Get tasks ready for execution
561        // This will return an Arc<RwLock<Task>>
562        let ready_tasks = get_ready_tasks(workflow);
563        debug!("Found {} ready tasks for execution", ready_tasks.len());
564
565        // Check for circular dependencies
566        if ready_tasks.is_empty() {
567            if check_for_circular_dependencies(workflow) {
568                break;
569            }
570            continue;
571        }
572
573        // Execute tasks asynchronously
574        let handles = spawn_task_executions(workflow, ready_tasks)?;
575
576        // Wait for all tasks to complete
577        await_task_completions(handles).await;
578    }
579
580    debug!("Workflow execution completed");
581    Ok(())
582}
583
584impl Serialize for Workflow {
585    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
586    where
587        S: Serializer,
588    {
589        let mut state = serializer.serialize_struct("Workflow", 4)?;
590
591        // set session to none
592        state.serialize_field("id", &self.id)?;
593        state.serialize_field("name", &self.name)?;
594        state.serialize_field("task_list", &self.task_list)?;
595        state.serialize_field("agents", &self.agents)?;
596
597        state.end()
598    }
599}
600
601impl<'de> Deserialize<'de> for Workflow {
602    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
603    where
604        D: Deserializer<'de>,
605    {
606        #[derive(Deserialize)]
607        #[serde(field_identifier, rename_all = "snake_case")]
608        enum Field {
609            Id,
610            Name,
611            TaskList,
612            Agents,
613        }
614
615        struct WorkflowVisitor;
616
617        impl<'de> Visitor<'de> for WorkflowVisitor {
618            type Value = Workflow;
619
620            fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result {
621                formatter.write_str("struct Workflow")
622            }
623
624            fn visit_map<V>(self, mut map: V) -> Result<Workflow, V::Error>
625            where
626                V: MapAccess<'de>,
627            {
628                let mut id = None;
629                let mut name = None;
630                let mut task_list_data = None;
631                let mut agents: Option<HashMap<String, Agent>> = None;
632
633                while let Some(key) = map.next_key()? {
634                    match key {
635                        Field::Id => {
636                            let value: String = map.next_value().map_err(|e| {
637                                error!("Failed to deserialize field 'id': {e}");
638                                de::Error::custom(format!("Failed to deserialize field 'id': {e}"))
639                            })?;
640                            id = Some(value);
641                        }
642                        Field::TaskList => {
643                            // Deserialize as a generic Value first
644                            let value: TaskList = map.next_value().map_err(|e| {
645                                error!("Failed to deserialize field 'task_list': {e}");
646                                de::Error::custom(format!(
647                                    "Failed to deserialize field 'task_list': {e}",
648                                ))
649                            })?;
650
651                            task_list_data = Some(value);
652                        }
653                        Field::Name => {
654                            let value: String = map.next_value().map_err(|e| {
655                                error!("Failed to deserialize field 'name': {e}");
656                                de::Error::custom(format!(
657                                    "Failed to deserialize field 'name': {e}",
658                                ))
659                            })?;
660                            name = Some(value);
661                        }
662                        Field::Agents => {
663                            let value: HashMap<String, Agent> = map.next_value().map_err(|e| {
664                                error!("Failed to deserialize field 'agents': {e}");
665                                de::Error::custom(format!(
666                                    "Failed to deserialize field 'agents': {e}"
667                                ))
668                            })?;
669                            agents = Some(value);
670                        }
671                    }
672                }
673
674                let id = id.ok_or_else(|| de::Error::missing_field("id"))?;
675                let name = name.ok_or_else(|| de::Error::missing_field("name"))?;
676                let task_list_data =
677                    task_list_data.ok_or_else(|| de::Error::missing_field("task_list"))?;
678                let agents = agents.ok_or_else(|| de::Error::missing_field("agents"))?;
679
680                let event_tracker = Arc::new(RwLock::new(EventTracker::new(create_uuid7())));
681
682                // convert agents to arc
683                let agents = agents
684                    .into_iter()
685                    .map(|(id, agent)| (id, Arc::new(agent)))
686                    .collect();
687
688                Ok(Workflow {
689                    id,
690                    name,
691                    task_list: task_list_data,
692                    agents,
693                    event_tracker,
694                    global_context: None, // Initialize with no global context
695                })
696            }
697        }
698
699        const FIELDS: &[&str] = &["id", "name", "task_list", "agents"];
700        deserializer.deserialize_struct("Workflow", FIELDS, WorkflowVisitor)
701    }
702}
703
/// Python-facing wrapper around the Rust `Workflow`, exposed to Python under
/// the class name `Workflow`.
#[pyclass(name = "Workflow")]
#[derive(Debug, Clone)]
pub struct PyWorkflow {
    // the wrapped Rust workflow that holds tasks and agents
    workflow: Workflow,

    // allow adding output types for python tasks (py only)
    // these are provided at runtime by the user and must match the response
    // format of the prompt the task is associated with
    // (keyed by task id, see `add_task`)
    output_types: HashMap<String, Arc<Py<PyAny>>>,

    // potatohead version holds a reference to the runtime
    // (used by `run` to block on the async workflow execution)
    runtime: Arc<tokio::runtime::Runtime>,
}
717
718#[pymethods]
719impl PyWorkflow {
720    #[new]
721    #[pyo3(signature = (name))]
722    pub fn new(name: &str) -> Result<Self, WorkflowError> {
723        debug!("Creating new workflow: {}", name);
724        Ok(Self {
725            workflow: Workflow::new(name),
726            output_types: HashMap::new(),
727            runtime: Arc::new(
728                tokio::runtime::Runtime::new()
729                    .map_err(|e| WorkflowError::RuntimeError(e.to_string()))?,
730            ),
731        })
732    }
733
734    #[getter]
735    pub fn name(&self) -> String {
736        self.workflow.name.clone()
737    }
738
739    #[getter]
740    pub fn task_list(&self) -> TaskList {
741        self.workflow.task_list.clone()
742    }
743
    /// Marker property that is always `true` for this class.
    // NOTE(review): presumably used on the Python side to detect workflow objects — confirm.
    #[getter]
    pub fn is_workflow(&self) -> bool {
        true
    }
748
    /// Returns the result of `model_dump_json` (defined outside this view;
    /// presumably the JSON serialization of the workflow — confirm).
    #[getter]
    pub fn __workflow__(&self) -> Result<String, WorkflowError> {
        self.model_dump_json()
    }
753
754    #[getter]
755    pub fn agents(&self) -> Result<HashMap<String, PyAgent>, WorkflowError> {
756        self.workflow
757            .agents
758            .iter()
759            .map(|(id, agent)| {
760                Ok((
761                    id.clone(),
762                    PyAgent {
763                        agent: agent.clone(),
764                        runtime: self.runtime.clone(),
765                    },
766                ))
767            })
768            .collect::<Result<HashMap<_, _>, _>>()
769    }
770
771    #[pyo3(signature = (task_output_types))]
772    pub fn add_task_output_types<'py>(
773        &mut self,
774        task_output_types: Bound<'py, PyDict>,
775    ) -> PyResult<()> {
776        let converted: HashMap<String, Arc<Py<PyAny>>> = task_output_types
777            .iter()
778            .map(|(k, v)| -> PyResult<(String, Arc<Py<PyAny>>)> {
779                // Explicitly return a Result from the closure
780                let key = k.extract::<String>()?;
781                let value = v.clone().unbind();
782                Ok((key, Arc::new(value)))
783            })
784            .collect::<PyResult<_>>()?;
785        self.output_types.extend(converted);
786        Ok(())
787    }
788
789    #[pyo3(signature = (task, output_type = None))]
790    pub fn add_task(
791        &mut self,
792        py: Python<'_>,
793        mut task: Task,
794        output_type: Option<Bound<'_, PyAny>>,
795    ) -> Result<(), WorkflowError> {
796        if let Some(output_type) = output_type {
797            // Parse and set the response format
798            (task.prompt.response_type, task.prompt.response_json_schema) =
799                parse_response_to_json(py, &output_type)
800                    .map_err(|e| WorkflowError::InvalidOutputType(e.to_string()))?;
801
802            // Store the output type for later use
803            self.output_types
804                .insert(task.id.clone(), Arc::new(output_type.unbind()));
805        }
806
807        self.workflow.task_list.add_task(task)?;
808        Ok(())
809    }
810
811    pub fn add_tasks(&mut self, tasks: Vec<Task>) -> Result<(), WorkflowError> {
812        for task in tasks {
813            self.workflow.task_list.add_task(task)?;
814        }
815        Ok(())
816    }
817
818    pub fn add_agent(&mut self, agent: &Bound<'_, PyAgent>) {
819        // extract the arc rust agent from the python agent
820        let agent = agent.extract::<PyAgent>().unwrap().agent.clone();
821        self.workflow.agents.insert(agent.id.clone(), agent);
822    }
823
824    pub fn is_complete(&self) -> bool {
825        self.workflow.task_list.is_complete()
826    }
827
828    pub fn pending_count(&self) -> usize {
829        self.workflow.task_list.pending_count()
830    }
831
832    pub fn execution_plan<'py>(
833        &self,
834        py: Python<'py>,
835    ) -> Result<Bound<'py, PyDict>, WorkflowError> {
836        let plan = self.workflow.execution_plan()?;
837        debug!("Execution plan: {:?}", plan);
838
839        // turn hashmap into a to json
840        let json = serde_json::to_value(plan).map_err(|e| {
841            error!("Failed to serialize execution plan to JSON: {}", e);
842            e
843        })?;
844
845        let pydict = PyDict::new(py);
846        json_to_pydict(py, &json, &pydict)?;
847
848        Ok(pydict)
849    }
850
851    #[pyo3(signature = (global_context=None))]
852    pub fn run(
853        &self,
854        py: Python,
855        global_context: Option<Bound<'_, PyDict>>,
856    ) -> Result<WorkflowResult, WorkflowError> {
857        debug!("Running workflow: {}", self.workflow.name);
858
859        // Convert the global context from PyDict to serde_json::Value if provided
860        let global_context = if let Some(context) = global_context {
861            // Convert PyDict to serde_json::Value
862            let json_value = pyobject_to_json(&context.into_bound_py_any(py)?)?;
863            Some(json_value)
864        } else {
865            None
866        };
867
868        let workflow: Arc<RwLock<Workflow>> = self
869            .runtime
870            .block_on(async { self.workflow.run(global_context).await })?;
871
872        // Try to get exclusive ownership of the workflow by unwrapping the Arc if there's only one reference
873        let workflow_result = match Arc::try_unwrap(workflow) {
874            // If we have exclusive ownership, we can consume the RwLock
875            Ok(rwlock) => {
876                // Unwrap the RwLock to get the Workflow
877                let workflow = rwlock
878                    .into_inner()
879                    .map_err(|_| WorkflowError::LockAcquireError)?;
880
881                // Get the events before creating WorkflowResult
882                let events = workflow
883                    .event_tracker
884                    .read()
885                    .unwrap()
886                    .events
887                    .read()
888                    .unwrap()
889                    .clone();
890
891                // Move the tasks out of the workflow
892                WorkflowResult::new(py, workflow.task_list.tasks(), &self.output_types, events)
893            }
894            // If there are other references, we need to clone
895            Err(arc) => {
896                // Just read the workflow
897                error!("Workflow still has other references, reading instead of consuming.");
898                let workflow = arc
899                    .read()
900                    .map_err(|_| WorkflowError::ReadLockAcquireError)?;
901
902                // Get the events before creating WorkflowResult
903                let events = workflow
904                    .event_tracker
905                    .read()
906                    .unwrap()
907                    .events
908                    .read()
909                    .unwrap()
910                    .clone();
911
912                WorkflowResult::new(py, workflow.task_list.tasks(), &self.output_types, events)
913            }
914        };
915
916        info!("Workflow execution completed successfully.");
917        Ok(workflow_result)
918    }
919
920    pub fn model_dump_json(&self) -> Result<String, WorkflowError> {
921        Ok(self.workflow.serialize()?)
922    }
923
924    #[staticmethod]
925    #[pyo3(signature = (json_string, output_types=None))]
926    pub fn model_validate_json(
927        json_string: String,
928        output_types: Option<Bound<'_, PyDict>>,
929    ) -> Result<Self, WorkflowError> {
930        let runtime = Arc::new(
931            tokio::runtime::Runtime::new()
932                .map_err(|e| WorkflowError::RuntimeError(e.to_string()))?,
933        );
934        let mut workflow: Workflow = Workflow::from_json(&json_string)?;
935
936        // reload agents to ensure clients are rebuilt
937        // This is necessary because during deserialization the GenAIClient
938        runtime.block_on(async { workflow.reset_agents().await })?;
939
940        let output_types = match output_types {
941            Some(output_types) => output_types
942                .iter()
943                .map(|(k, v)| -> PyResult<(String, Arc<Py<PyAny>>)> {
944                    let key = k.extract::<String>()?;
945                    let value = v.clone().unbind();
946                    Ok((key, Arc::new(value)))
947                })
948                .collect::<PyResult<HashMap<String, Arc<Py<PyAny>>>>>()?,
949            None => HashMap::new(),
950        };
951
952        let py_workflow = PyWorkflow {
953            workflow,
954            output_types,
955            runtime,
956        };
957
958        Ok(py_workflow)
959    }
960}
961
#[cfg(test)]
mod tests {
    use super::*;
    use potato_prompt::prompt::ResponseType;
    use potato_prompt::{prompt::types::PromptContent, Message, Prompt};

    // Builds the single-message OpenAI prompt used by the task-list test.
    fn sample_prompt() -> Prompt {
        let content = PromptContent::Str("Test prompt".to_string());
        Prompt::new_rs(
            vec![Message::new_rs(content)],
            "gpt-4o",
            potato_type::Provider::OpenAI,
            vec![],
            None,
            None,
            ResponseType::Null,
        )
        .unwrap()
    }

    // A new workflow keeps the given name and gets a 36-character id.
    #[test]
    fn test_workflow_creation() {
        let wf = Workflow::new("Test Workflow");
        assert_eq!(wf.name, "Test Workflow");
        assert_eq!(wf.id.len(), 36); // UUID7 length
    }

    // A task added to the list can be fetched back by id, and resetting failed
    // tasks on that list succeeds.
    #[test]
    fn test_task_list_add_and_get() {
        let mut task_list = TaskList::new();

        let task = Task::new("task1", sample_prompt(), "task1", None, None);
        task_list.add_task(task.clone()).unwrap();

        // Scope the read guard so it is released before the list is touched again.
        let stored_id = {
            let handle = task_list.get_task(&task.id).unwrap();
            let guard = handle.read().unwrap();
            guard.id.clone()
        };
        assert_eq!(stored_id, task.id);

        task_list.reset_failed_tasks().unwrap();
    }
}