// potato_workflow/workflow/flow.rs

1use crate::tasklist::TaskList;
2
3use crate::{
4    events::{EventTracker, TaskEvent},
5    workflow::error::WorkflowError,
6};
7pub use potato_agent::agents::{
8    agent::{Agent, PyAgent},
9    task::{Task, TaskStatus, WorkflowTask},
10};
11use potato_agent::{AgentError, PyAgentResponse};
12use potato_state::block_on;
13use potato_type::prompt::{parse_response_to_json, MessageNum};
14use potato_type::Provider;
15use potato_util::utils::depythonize_object_to_value;
16use potato_util::{create_uuid7, utils::update_serde_map_with, PyHelperFuncs};
17use pyo3::prelude::*;
18use pythonize::pythonize;
19use serde::{
20    de::{self, MapAccess, Visitor},
21    ser::SerializeStruct,
22    Deserialize, Deserializer, Serialize, Serializer,
23};
24use serde_json::Map;
25use serde_json::Value;
26use std::collections::{HashMap, HashSet};
27use std::sync::Arc;
28use std::sync::RwLock;
29use tracing::instrument;
30use tracing::{debug, error, info, warn};
31
32/// Python workflows are a work in progress
33use pyo3::types::PyDict;
34
/// Context bundle handed to a task: (dependency messages keyed by dependency
/// task id, merged structured-output parameters, optional workflow-global context).
pub type Context = (HashMap<String, Vec<MessageNum>>, Value, Option<Arc<Value>>);
36
/// Result of a workflow run exposed to Python: the finished tasks, the
/// recorded task events, and (internally) the id of the final task so the
/// `result` getter can return its output.
#[derive(Debug)]
#[pyclass]
pub struct WorkflowResult {
    #[pyo3(get)]
    pub tasks: HashMap<String, Py<WorkflowTask>>,

    #[pyo3(get)]
    pub events: Vec<TaskEvent>,

    // Id of the last task in the workflow; consumed by the `result` getter.
    last_task_id: Option<String>,
}
48
49impl WorkflowResult {
50    pub fn new(
51        py: Python,
52        tasks: HashMap<String, Task>,
53        output_types: &HashMap<String, Arc<Py<PyAny>>>,
54        events: Vec<TaskEvent>,
55        last_task_id: Option<String>,
56    ) -> Self {
57        let py_tasks = tasks
58            .into_iter()
59            .map(|(id, task)| {
60                let py_agent_response = if let Some(result) = task.result {
61                    let output_type = output_types.get(&id).map(|arc| arc.as_ref().clone_ref(py));
62                    Some(PyAgentResponse::new(result, output_type))
63                } else {
64                    None
65                };
66                let py_task = WorkflowTask {
67                    id: task.id.clone(),
68                    prompt: task.prompt,
69                    dependencies: task.dependencies,
70                    status: task.status,
71                    agent_id: task.agent_id,
72                    result: py_agent_response,
73                    max_retries: task.max_retries,
74                    retry_count: task.retry_count,
75                };
76                (id, Py::new(py, py_task).unwrap())
77            })
78            .collect::<HashMap<_, _>>();
79
80        Self {
81            tasks: py_tasks,
82            events,
83            last_task_id,
84        }
85    }
86}
87
88#[pymethods]
89impl WorkflowResult {
90    pub fn __str__(&self) -> String {
91        // serialize tasks to json
92        let json = serde_json::json!({
93            "tasks": serde_json::to_value(&self.tasks).unwrap_or(Value::Null),
94            "events": serde_json::to_value(&self.events).unwrap_or(Value::Null)
95        });
96
97        PyHelperFuncs::__str__(&json)
98    }
99
100    /// Get last task result
101    #[getter]
102    pub fn result<'py>(&self, py: Python<'py>) -> Result<Bound<'py, PyAny>, AgentError> {
103        if let Some(last_task_id) = &self.last_task_id {
104            if let Some(task) = self.tasks.get(last_task_id) {
105                let result = task.bind(py).getattr("result")?;
106                return Ok(result);
107            }
108        }
109        Ok(py.None().bind(py).clone())
110    }
111}
112
113/// Rust-specific implementation of a workflow
/// Rust-specific implementation of a workflow
#[derive(Debug, Clone)]
pub struct Workflow {
    // Unique workflow id (UUIDv7, assigned in `new`)
    pub id: String,
    // Human-readable workflow name
    pub name: String,
    // Tasks plus their dependency/status bookkeeping
    pub task_list: TaskList,
    // Agents available to execute tasks, keyed by agent id
    pub agents: HashMap<String, Arc<Agent>>,
    // Records task lifecycle events for this workflow run
    pub event_tracker: Arc<RwLock<EventTracker>>,
    // Optional context shared with every task in the run
    pub global_context: Option<Arc<Value>>,
}
123
124impl PartialEq for Workflow {
125    fn eq(&self, other: &Self) -> bool {
126        // Compare by ID and name
127        self.id == other.id && self.name == other.name
128    }
129}
130
131impl Workflow {
132    /// Reload the agent clients by overwriting the existing clients
133    /// We are trying to solve the deserialization issue where GenAIClient requires
134    /// and async context to be created. During deserialization we don't have that context, so we default
135    /// to an Undefined client, but keep the other agent details. This means that if we try to run a workflow after deserialization
136    /// we will get an error when we try to execute a task with an Undefined client.
137    /// For this specific function, we can either make Arc<Agent> into RW compatible, or we can
138    /// rebuild the entire agents map with new Arcs. Given that the only mutation we need to do is to
139    /// rebuild the GenAIClient, we opt for the latter because we don't need to make everything else RW compatible.
140    /// This will incure a small startup cost, but it will be a one-time cost and 99% of the time unnoticed.
141    pub async fn reset_agents(&mut self) -> Result<(), WorkflowError> {
142        let mut agents_map = self.agents.clone();
143
144        for agent in self.agents.values_mut() {
145            agents_map.insert(agent.id.clone(), Arc::new(agent.rebuild_client().await?));
146        }
147        self.agents = agents_map;
148        Ok(())
149    }
150    pub fn new(name: &str) -> Self {
151        debug!("Creating new workflow: {}", name);
152        let id = create_uuid7();
153        Self {
154            id: id.clone(),
155            name: name.to_string(),
156            task_list: TaskList::new(),
157            agents: HashMap::new(),
158            event_tracker: Arc::new(RwLock::new(EventTracker::new(id))),
159            global_context: None, // Initialize with no global context
160        }
161    }
162    pub fn events(&self) -> Vec<TaskEvent> {
163        let tracker = self.event_tracker.read().unwrap();
164        let events = tracker.events.read().unwrap().clone();
165        events
166    }
167
168    pub fn total_duration(&self) -> i32 {
169        let tracker = self.event_tracker.read().unwrap();
170
171        if tracker.is_empty() {
172            0
173        } else {
174            //iter over each tracker event and get the details for each task event and get duration
175            let mut total_duration = chrono::Duration::zero();
176            for event in tracker.events.read().unwrap().iter() {
177                total_duration += event.details.duration.unwrap_or(chrono::Duration::zero());
178            }
179            total_duration.subsec_millis()
180        }
181    }
182
183    pub fn get_new_workflow(
184        &self,
185        global_context: Option<Arc<Value>>,
186    ) -> Result<Self, WorkflowError> {
187        // set new id for the new workflow
188        let id = create_uuid7();
189
190        // create deep copy of the tasklist so we don't clone the arc
191        let task_list = self.task_list.deep_clone()?;
192
193        Ok(Workflow {
194            id: id.clone(),
195            name: self.name.clone(),
196            task_list,
197            agents: self.agents.clone(), // Agents can be shared since they're read-only during execution
198            event_tracker: Arc::new(RwLock::new(EventTracker::new(id))),
199            global_context,
200        })
201    }
202
203    pub async fn run(
204        &self,
205        global_context: Option<Value>,
206    ) -> Result<Arc<RwLock<Workflow>>, WorkflowError> {
207        debug!("Running workflow: {}", self.name);
208
209        let global_context = global_context.map(Arc::new);
210        let run_workflow = Arc::new(RwLock::new(self.get_new_workflow(global_context)?));
211
212        execute_workflow(&run_workflow).await?;
213
214        Ok(run_workflow)
215    }
216
    /// True when every task in the task list has finished.
    pub fn is_complete(&self) -> bool {
        self.task_list.is_complete()
    }
220
    /// Number of tasks that have not yet completed.
    pub fn pending_count(&self) -> usize {
        self.task_list.pending_count()
    }
224
    /// Add a single task to the workflow's task list.
    pub fn add_task(&mut self, task: Task) -> Result<(), WorkflowError> {
        self.task_list.add_task(task)
    }
228
229    pub fn add_tasks(&mut self, tasks: Vec<Task>) -> Result<(), WorkflowError> {
230        for task in tasks {
231            self.task_list.add_task(task)?;
232        }
233        Ok(())
234    }
235
    /// Register an agent (cloned into an Arc), keyed by its id.
    /// An existing agent with the same id is replaced.
    pub fn add_agent(&mut self, agent: &Agent) {
        self.agents
            .insert(agent.id.clone(), Arc::new(agent.clone()));
    }
240
241    pub fn add_agents(&mut self, agents: &[&Agent]) {
242        for agent in agents {
243            self.add_agent(agent);
244        }
245    }
246
247    pub async fn execute_task(&self, task: &str, context: &Value) -> Result<Value, WorkflowError> {
248        let task = self
249            .task_list
250            .get_task(task)
251            .ok_or_else(|| WorkflowError::TaskNotFound(task.to_string()))?;
252
253        let agent = {
254            let task_guard = task.read().map_err(|_| WorkflowError::TaskLockError)?;
255            self.agents
256                .get(&task_guard.agent_id)
257                .ok_or_else(|| WorkflowError::AgentNotFound(task_guard.agent_id.clone()))?
258                .clone()
259        };
260
261        let max_retries = {
262            let task_guard = task.read().map_err(|_| WorkflowError::TaskLockError)?;
263            task_guard.retry_count
264        };
265
266        for attempt in 0..=max_retries {
267            match agent.execute_task_with_context(&task, context).await {
268                Ok(response) => {
269                    let is_valid = validate_response_schema(&task, &response);
270
271                    if !is_valid {
272                        if attempt == max_retries {
273                            let (task_id, expected_schema, received_response) = {
274                                let task_guard = task.read().unwrap();
275                                (
276                                    task_guard.id.clone(),
277                                    task_guard
278                                        .prompt
279                                        .response_json_schema()
280                                        .map(|s| s.to_string())
281                                        .unwrap_or_else(|| "No schema".to_string()),
282                                    response
283                                        .response_value()
284                                        .map(|v| v.to_string())
285                                        .unwrap_or_else(|| "No response".to_string()),
286                                )
287                            };
288
289                            error!(
290                                "Task {} response validation failed after {} attempts",
291                                task_id,
292                                max_retries + 1
293                            );
294
295                            return Err(WorkflowError::ResponseValidationFailed {
296                                task_id,
297                                expected_schema,
298                                received_response,
299                            });
300                        }
301                        warn!(
302                            "Task validation failed (attempt {}/{}), retrying...",
303                            attempt + 1,
304                            max_retries + 1
305                        );
306                        continue;
307                    }
308
309                    return Ok(response.response_value().unwrap_or(Value::Null));
310                }
311                Err(e) => {
312                    let task_id = { task.read().unwrap().id.clone() };
313                    warn!(
314                        "Task {} execution failed (attempt {}/{}): {}",
315                        task_id,
316                        attempt + 1,
317                        max_retries + 1,
318                        e
319                    );
320
321                    if attempt == max_retries {
322                        error!("Task {} exceeded max retries ({})", task_id, max_retries);
323                        return Err(WorkflowError::MaxRetriesExceeded(task_id));
324                    }
325                }
326            }
327        }
328
329        unreachable!("Loop should always return via Ok or error")
330    }
331
332    pub fn execution_plan(&self) -> Result<HashMap<i32, HashSet<String>>, WorkflowError> {
333        let mut remaining: HashMap<String, HashSet<String>> = self
334            .task_list
335            .tasks
336            .iter()
337            .map(|(id, task)| {
338                (
339                    id.clone(),
340                    task.read().unwrap().dependencies.iter().cloned().collect(),
341                )
342            })
343            .collect();
344
345        let mut executed = HashSet::new();
346        let mut plan = HashMap::new();
347        let mut step = 1;
348
349        while !remaining.is_empty() {
350            // Find all tasks that can be executed in parallel - collect just the keys we need to remove
351            let ready_keys: Vec<String> = remaining
352                .iter()
353                .filter(|(_, deps)| deps.is_subset(&executed))
354                .map(|(id, _)| id.to_string())
355                .collect();
356
357            if ready_keys.is_empty() {
358                // Circular dependency detected
359                break;
360            }
361
362            // Create the set for the plan (reusing the already allocated Strings)
363            let mut ready_set = HashSet::with_capacity(ready_keys.len());
364
365            // Update tracking sets and build the ready set in one pass
366            for key in ready_keys {
367                executed.insert(key.clone());
368                remaining.remove(&key);
369                ready_set.insert(key);
370            }
371
372            // Add parallel tasks to the current step
373            plan.insert(step, ready_set);
374
375            step += 1;
376        }
377
378        Ok(plan)
379    }
380
    /// Pretty-print the task list (used as the Python `str()` of the workflow).
    pub fn __str__(&self) -> String {
        PyHelperFuncs::__str__(&self.task_list)
    }
384
385    pub fn serialize(&self) -> Result<String, serde_json::Error> {
386        // reset the workflow
387        let json = serde_json::to_string(self).unwrap();
388        // Add debug output to see what's being serialized
389        Ok(json)
390    }
391
392    pub fn from_json(json: &str) -> Result<Self, WorkflowError> {
393        // Deserialize the JSON string into a Workflow instance
394        Ok(serde_json::from_str(json)?)
395    }
396
397    pub fn task_names(&self) -> Vec<String> {
398        self.task_list
399            .tasks
400            .keys()
401            .cloned()
402            .collect::<Vec<String>>()
403    }
404
    /// Id of the final task in the task list, if any.
    pub fn last_task_id(&self) -> Option<String> {
        self.task_list.get_last_task_id()
    }
408}
409
/// Check if the workflow is complete.
/// # Arguments
/// * `workflow` - A reference to the shared workflow instance
/// # Returns true if every task has finished, false otherwise
/// # Panics if the workflow lock is poisoned
fn is_workflow_complete(workflow: &Arc<RwLock<Workflow>>) -> bool {
    workflow.read().unwrap().is_complete()
}
417
418/// Reset failed tasks in the workflow
419/// # Arguments
420/// * `workflow` - A reference to the workflow instance
421/// # Returns Ok(()) if successful, or an error if the reset fails
422fn reset_failed_workflow_tasks(workflow: &Arc<RwLock<Workflow>>) -> Result<(), WorkflowError> {
423    match workflow.write().unwrap().task_list.reset_failed_tasks() {
424        Ok(_) => Ok(()),
425        Err(e) => {
426            warn!("Failed to reset failed tasks: {}", e);
427            Err(e)
428        }
429    }
430}
431
/// Get all ready tasks in the workflow.
/// # Arguments
/// * `workflow` - A reference to the shared workflow instance
/// # Returns a vector of tasks that are ready to be executed
/// # Panics if the workflow lock is poisoned
fn get_ready_tasks(workflow: &Arc<RwLock<Workflow>>) -> Vec<Arc<RwLock<Task>>> {
    workflow.read().unwrap().task_list.get_ready_tasks()
}
439
440/// Check for circular dependencies
441/// # Arguments
442/// * `workflow` - A reference to the workflow instance
443/// # Returns true if circular dependencies are detected, false otherwise
444fn check_for_circular_dependencies(workflow: &Arc<RwLock<Workflow>>) -> bool {
445    let pending_count = workflow.read().unwrap().pending_count();
446
447    if pending_count > 0 {
448        warn!(
449            "No ready tasks found but {} pending tasks remain. Possible circular dependency.",
450            pending_count
451        );
452        return true;
453    }
454
455    false
456}
457
/// Mark a task as running and record the corresponding start event.
/// # Arguments
/// * `task` - The task to transition to `Running`
/// * `event_tracker` - Tracker that records the task-started event
/// # Returns nothing
fn mark_task_as_running(task: Arc<RwLock<Task>>, event_tracker: &Arc<RwLock<EventTracker>>) {
    let mut task = task.write().unwrap();
    task.set_status(TaskStatus::Running);
    event_tracker.write().unwrap().record_task_started(&task.id);
}
467
468/// Get an agent for a task
469/// # Arguments
470/// * `workflow` - A reference to the workflow instance
471/// * `task` - A reference to the task for which the agent is needed
472fn get_agent_for_task(
473    workflow: &Arc<RwLock<Workflow>>,
474    agent_id: &str,
475) -> Result<Arc<Agent>, WorkflowError> {
476    let wf = workflow.read().unwrap();
477    match wf.agents.get(agent_id) {
478        Some(agent) => Ok(agent.clone()),
479        None => Err(WorkflowError::AgentNotFound(agent_id.to_string())),
480    }
481}
482
483/// Builds the context for a task from its dependencies
484/// For a given list of dependency task IDs (tasks current task depends on),
485/// for each dependency:
486/// 1. Retrieve the task from the workflow's task list
487/// 2. If the task has a result, convert the result to MessageNum format (convert from response struct to request struct)
488/// 2a. Inside of to_message_num, convert each request message to the target provider format if needed
489/// 3. Insert the converted messages into the context hashmap
490/// 4. if the result has a structured output, extract it and merge into parameter context
491/// This is helpful for when a user wants to pass structured output variables that can be bound in subsequent tasks
492/// 5. Return the dependency context, parameter context, and global context
493/// # Arguments
494/// * `workflow` - A reference to the workflow instance
495/// * `task` - A reference to the task for which the context is being built
496/// # Returns a HashMap containing the context messages for the task
497#[instrument(skip_all)]
498fn build_task_context(
499    workflow: &Arc<RwLock<Workflow>>,
500    task_dependencies: &Vec<String>,
501    provider: &Provider,
502) -> Result<Context, WorkflowError> {
503    let wf = workflow.read().unwrap();
504    let mut ctx = HashMap::new();
505    let mut param_ctx: Value = Value::Object(Map::new());
506
507    for dep_id in task_dependencies {
508        debug!("Building context for task dependency: {}", dep_id);
509        if let Some(dep) = wf.task_list.get_task(dep_id) {
510            if let Some(result) = &dep.read().unwrap().result {
511                let msg_to_insert = result.response.to_message_num(provider);
512
513                match msg_to_insert {
514                    Ok(message) => {
515                        ctx.insert(dep_id.clone(), message);
516                    }
517                    Err(e) => {
518                        warn!("Failed to convert response to message: {}", e);
519                    }
520                }
521
522                if let Some(structure_output) = result.response.extract_structured_data() {
523                    // Value should be a serde_json::Value Object type
524                    // validate that it's an object
525                    if structure_output.is_object() {
526                        // extract the Map from the Value
527                        update_serde_map_with(&mut param_ctx, &structure_output)?;
528                    }
529                }
530            }
531        }
532    }
533
534    debug!("Built context for task dependencies: {:?}", ctx);
535    let global_context = workflow
536        .read()
537        .unwrap()
538        .global_context
539        .as_ref()
540        .map(Arc::clone);
541
542    Ok((ctx, param_ctx, global_context))
543}
544
/// Validate a task response against its JSON schema.
/// Only validates if a schema is defined for the task.
/// # Arguments
/// * `task` - A reference to the task
/// * `response` - A reference to the task response
/// # Returns
/// true if the response is VALID — or when no check could be performed
/// (no response value, or a poisoned task lock), which counts as passing;
/// false only when validation against the schema actually fails.
/// (The original doc said "true if the response is invalid", which
/// contradicted the code and every call site.)
fn validate_response_schema(
    task: &Arc<RwLock<Task>>,
    response: &potato_agent::AgentResponse,
) -> bool {
    task.read()
        .ok()
        .and_then(|t| {
            response
                .response_value()
                .map(|value| t.validate_output(&value).is_ok())
        })
        .unwrap_or(true)
}
565
/// Spawns an individual task execution on the tokio runtime.
/// # Arguments
/// * `event_tracker` - Tracker used to record completion/failure events
/// * `task` - The task to be executed
/// * `task_id` - The ID of the task (pre-cloned so the spawned future owns it)
/// * `agent` - The agent that will execute the task
/// * `context` - Dependency messages keyed by dependency task id
/// * `parameter_context` - Structured-output values merged from dependencies
/// * `global_context` - Optional workflow-level shared context
/// # Returns a JoinHandle for the spawned task
fn spawn_task_execution(
    event_tracker: Arc<RwLock<EventTracker>>,
    task: Arc<RwLock<Task>>,
    task_id: String,
    agent: Arc<Agent>,
    context: HashMap<String, Vec<MessageNum>>,
    parameter_context: Value,
    global_context: Option<Arc<Value>>,
) -> tokio::task::JoinHandle<()> {
    tokio::spawn(async move {
        let result = agent
            .execute_task_with_context_message(&task, context, parameter_context, global_context)
            .await;

        match result {
            Ok(response) => {
                info!("Task {} completed successfully", task_id);

                let is_valid = validate_response_schema(&task, &response); // No schema or no response_value = validation passes
                if !is_valid {
                    error!(
                        "Task {} response validation against JSON schema failed",
                        task_id
                    );

                    // A schema-validation failure is treated like a task failure
                    // so the workflow's retry machinery can pick it up.
                    if let Ok(mut write_task) = task.write() {
                        write_task.set_status(TaskStatus::Failed);

                        if let Ok(tracker) = event_tracker.write() {
                            tracker.record_task_failed(
                                &write_task.id,
                                "Response JSON schema validation failed",
                                &write_task.prompt,
                            );
                        }
                    }
                    return;
                }

                // Success: store the result on the task and record the event.
                if let Ok(mut write_task) = task.write() {
                    write_task.set_status(TaskStatus::Completed);
                    write_task.set_result(response.clone());

                    if let Ok(tracker) = event_tracker.write() {
                        tracker.record_task_completed(&write_task.id, &write_task.prompt, response);
                    }
                }
            }
            Err(e) => {
                error!("Task {} failed: {}", task_id, e);

                // Execution error: mark failed and record the error message.
                if let Ok(mut write_task) = task.write() {
                    write_task.set_status(TaskStatus::Failed);

                    if let Ok(tracker) = event_tracker.write() {
                        tracker.record_task_failed(
                            &write_task.id,
                            &e.to_string(),
                            &write_task.prompt,
                        );
                    }
                }
            }
        }
    })
}
640
641fn get_parameters_from_context(task: Arc<RwLock<Task>>) -> (String, Vec<String>, String, Provider) {
642    let (task_id, dependencies, agent_id, provider) = {
643        let task_guard = task.read().unwrap();
644        (
645            task_guard.id.clone(),
646            task_guard.dependencies.clone(),
647            task_guard.agent_id.clone(),
648            task_guard.prompt.provider.clone(),
649        )
650    };
651
652    (task_id, dependencies, agent_id, provider)
653}
654
/// Helper for spawning a batch of ready tasks.
/// # Arguments
/// * `workflow` - A reference to the workflow instance
/// * `ready_tasks` - The tasks that are ready to be executed
/// # Returns a vector of JoinHandles for the spawned tasks
fn spawn_task_executions(
    workflow: &Arc<RwLock<Workflow>>,
    ready_tasks: Vec<Arc<RwLock<Task>>>,
) -> Result<Vec<tokio::task::JoinHandle<()>>, WorkflowError> {
    let mut handles = Vec::with_capacity(ready_tasks.len());

    // Get the event tracker from the workflow
    let event_tracker = workflow.read().unwrap().event_tracker.clone();

    for task in ready_tasks {
        // Get task parameters for active task
        let (task_id, dependencies, agent_id, provider) = get_parameters_from_context(task.clone());

        // Mark task as running
        // This will also record the task started event
        mark_task_as_running(task.clone(), &event_tracker);

        // Build the context
        // Here we:
        // 1. Get the task dependencies and their results (these will be injected as assistant messages)
        //    We need the provider so we can convert messages to a different provider type if needed
        // 2. Parse dependent tasks for any structured outputs and return them as a serde_json::Value
        //    (this becomes task-level context)
        let (context, parameter_context, global_context) =
            build_task_context(workflow, &dependencies, &provider)?;

        // Get/clone agent ARC
        let agent = get_agent_for_task(workflow, &agent_id)?;

        // Spawn task execution and push handle to future vector
        let handle = spawn_task_execution(
            event_tracker.clone(),
            task.clone(),
            task_id,
            agent,
            context,
            parameter_context,
            global_context,
        );
        handles.push(handle);
    }

    Ok(handles)
}
703
704/// Wait for all spawned tasks to complete
705/// # Arguments
706/// * `handles` - A vector of JoinHandles for the spawned tasks
707/// # Returns nothing
708async fn await_task_completions(handles: Vec<tokio::task::JoinHandle<()>>) {
709    for handle in handles {
710        if let Err(e) = handle.await {
711            warn!("Task execution failed: {}", e);
712        }
713    }
714}
715
/// Execute the workflow asynchronously.
/// This function is called to start the workflow execution process and does the following:
/// 1. Iterates over workflow tasks while the shared workflow is not complete.
/// 2. Resets any failed tasks to allow them to be retried. This needs to happen before getting ready tasks.
/// 3. Gets all ready tasks.
/// 4. For each ready task:
///    - Marks the task as running
///    - Checks previous tasks for injected context
///    - Gets the agent for the task
///    - Spawns a new tokio task and executes the task with the agent
///    - Pushes the handle to the handles vector
/// 5. Waits for all spawned tasks to complete.
#[instrument(skip_all)]
pub async fn execute_workflow(workflow: &Arc<RwLock<Workflow>>) -> Result<(), WorkflowError> {
    // Important to remember that the workflow Arc<RwLock<Workflow>> is a new clone of
    // the loaded workflow. This allows us to mutate the workflow without affecting the original
    // workflow instance.
    debug!("Starting workflow execution");

    // Run until workflow is complete
    while !is_workflow_complete(workflow) {
        // Reset any failed tasks
        // This will return an error if any task exceeds its max retries (set at the task level)
        reset_failed_workflow_tasks(workflow)?;

        // Get tasks ready for execution
        // This will return an Arc<RwLock<Task>>
        let ready_tasks = get_ready_tasks(workflow);
        debug!("Found {} ready tasks for execution", ready_tasks.len());

        // Check for circular dependencies: no ready tasks but pending tasks
        // remaining means the graph cannot make progress.
        if ready_tasks.is_empty() {
            if check_for_circular_dependencies(workflow) {
                break;
            }
            continue;
        }

        // Execute tasks asynchronously
        let handles = spawn_task_executions(workflow, ready_tasks)?;

        // Wait for all tasks to complete
        await_task_completions(handles).await;
    }

    debug!("Workflow execution completed");
    Ok(())
}
764
impl Serialize for Workflow {
    /// Serialize only the durable fields (id, name, task_list, agents).
    /// The event tracker and global context are runtime-only state and are
    /// intentionally omitted; the `Deserialize` impl rebuilds a fresh tracker.
    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
    where
        S: Serializer,
    {
        let mut state = serializer.serialize_struct("Workflow", 4)?;

        state.serialize_field("id", &self.id)?;
        state.serialize_field("name", &self.name)?;
        state.serialize_field("task_list", &self.task_list)?;
        state.serialize_field("agents", &self.agents)?;

        state.end()
    }
}
781
782impl<'de> Deserialize<'de> for Workflow {
783    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
784    where
785        D: Deserializer<'de>,
786    {
787        #[derive(Deserialize)]
788        #[serde(field_identifier, rename_all = "snake_case")]
789        enum Field {
790            Id,
791            Name,
792            TaskList,
793            Agents,
794        }
795
796        struct WorkflowVisitor;
797
798        impl<'de> Visitor<'de> for WorkflowVisitor {
799            type Value = Workflow;
800
801            fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result {
802                formatter.write_str("struct Workflow")
803            }
804
805            fn visit_map<V>(self, mut map: V) -> Result<Workflow, V::Error>
806            where
807                V: MapAccess<'de>,
808            {
809                let mut id = None;
810                let mut name = None;
811                let mut task_list_data = None;
812                let mut agents: Option<HashMap<String, Agent>> = None;
813
814                while let Some(key) = map.next_key()? {
815                    match key {
816                        Field::Id => {
817                            let value: String = map.next_value().map_err(|e| {
818                                error!("Failed to deserialize field 'id': {e}");
819                                de::Error::custom(format!("Failed to deserialize field 'id': {e}"))
820                            })?;
821                            id = Some(value);
822                        }
823                        Field::TaskList => {
824                            // Deserialize as a generic Value first
825                            let value: TaskList = map.next_value().map_err(|e| {
826                                error!("Failed to deserialize field 'task_list': {e}");
827                                de::Error::custom(format!(
828                                    "Failed to deserialize field 'task_list': {e}",
829                                ))
830                            })?;
831
832                            task_list_data = Some(value);
833                        }
834                        Field::Name => {
835                            let value: String = map.next_value().map_err(|e| {
836                                error!("Failed to deserialize field 'name': {e}");
837                                de::Error::custom(format!(
838                                    "Failed to deserialize field 'name': {e}",
839                                ))
840                            })?;
841                            name = Some(value);
842                        }
843                        Field::Agents => {
844                            let value: HashMap<String, Agent> = map.next_value().map_err(|e| {
845                                error!("Failed to deserialize field 'agents': {e}");
846                                de::Error::custom(format!(
847                                    "Failed to deserialize field 'agents': {e}"
848                                ))
849                            })?;
850                            agents = Some(value);
851                        }
852                    }
853                }
854
855                let id = id.ok_or_else(|| de::Error::missing_field("id"))?;
856                let name = name.ok_or_else(|| de::Error::missing_field("name"))?;
857                let task_list_data =
858                    task_list_data.ok_or_else(|| de::Error::missing_field("task_list"))?;
859                let agents = agents.ok_or_else(|| de::Error::missing_field("agents"))?;
860
861                let event_tracker = Arc::new(RwLock::new(EventTracker::new(create_uuid7())));
862
863                // convert agents to arc
864                let agents = agents
865                    .into_iter()
866                    .map(|(id, agent)| (id, Arc::new(agent)))
867                    .collect();
868
869                Ok(Workflow {
870                    id,
871                    name,
872                    task_list: task_list_data,
873                    agents,
874                    event_tracker,
875                    global_context: None, // Initialize with no global context
876                })
877            }
878        }
879
880        const FIELDS: &[&str] = &["id", "name", "task_list", "agents"];
881        deserializer.deserialize_struct("Workflow", FIELDS, WorkflowVisitor)
882    }
883}
884
/// Python-facing wrapper around the Rust [`Workflow`]; exposed to Python
/// under the class name `Workflow`.
#[pyclass(name = "Workflow")]
#[derive(Debug, Clone)]
pub struct PyWorkflow {
    // The underlying Rust workflow all task/agent/execution calls delegate to.
    workflow: Workflow,

    // allow adding output types for python tasks (py only)
    // these are provided at runtime by the user and must match the response
    // format of the prompt the task is associated with
    output_types: HashMap<String, Arc<Py<PyAny>>>,
}
895
896#[pymethods]
897impl PyWorkflow {
898    #[new]
899    #[pyo3(signature = (name))]
900    pub fn new(name: &str) -> Result<Self, WorkflowError> {
901        debug!("Creating new workflow: {}", name);
902        Ok(Self {
903            workflow: Workflow::new(name),
904            output_types: HashMap::new(),
905        })
906    }
907
908    #[getter]
909    pub fn name(&self) -> String {
910        self.workflow.name.clone()
911    }
912
913    #[getter]
914    pub fn task_list(&self) -> TaskList {
915        self.workflow.task_list.clone()
916    }
917
918    #[getter]
919    pub fn is_workflow(&self) -> bool {
920        true
921    }
922
923    #[getter]
924    pub fn __workflow__(&self) -> Result<String, WorkflowError> {
925        self.model_dump_json()
926    }
927
928    #[getter]
929    pub fn agents(&self) -> Result<HashMap<String, PyAgent>, WorkflowError> {
930        self.workflow
931            .agents
932            .iter()
933            .map(|(id, agent)| {
934                Ok((
935                    id.clone(),
936                    PyAgent {
937                        agent: agent.clone(),
938                    },
939                ))
940            })
941            .collect::<Result<HashMap<_, _>, _>>()
942    }
943
944    #[pyo3(signature = (task_output_types))]
945    pub fn add_task_output_types<'py>(
946        &mut self,
947        task_output_types: Bound<'py, PyDict>,
948    ) -> PyResult<()> {
949        let converted: HashMap<String, Arc<Py<PyAny>>> = task_output_types
950            .iter()
951            .map(|(k, v)| -> PyResult<(String, Arc<Py<PyAny>>)> {
952                // Explicitly return a Result from the closure
953                let key = k.extract::<String>()?;
954                let value = v.clone().unbind();
955                Ok((key, Arc::new(value)))
956            })
957            .collect::<PyResult<_>>()?;
958        self.output_types.extend(converted);
959        Ok(())
960    }
961
962    #[pyo3(signature = (task, output_type = None))]
963    pub fn add_task(
964        &mut self,
965        py: Python<'_>,
966        mut task: Task,
967        output_type: Option<Bound<'_, PyAny>>,
968    ) -> Result<(), WorkflowError> {
969        if let Some(output_type) = output_type {
970            // Parse and set the response format
971            let (response_type, response_json_schema) = parse_response_to_json(py, &output_type)
972                .map_err(|e| WorkflowError::InvalidOutputType(e.to_string()))?;
973
974            // update the task prompt with the response schema
975            task.prompt
976                .set_response_json_schema(response_json_schema, response_type);
977
978            // Store the output type for later use
979            self.output_types
980                .insert(task.id.clone(), Arc::new(output_type.unbind()));
981        }
982
983        self.workflow.task_list.add_task(task)?;
984        Ok(())
985    }
986
987    pub fn add_tasks(&mut self, tasks: Vec<Task>) -> Result<(), WorkflowError> {
988        for task in tasks {
989            self.workflow.task_list.add_task(task)?;
990        }
991        Ok(())
992    }
993
994    pub fn add_agent(&mut self, agent: &Bound<'_, PyAgent>) {
995        // extract the arc rust agent from the python agent
996        let agent = agent.extract::<PyAgent>().unwrap().agent.clone();
997        self.workflow.agents.insert(agent.id.clone(), agent);
998    }
999
1000    pub fn add_agents(&mut self, agents: Vec<Bound<'_, PyAgent>>) {
1001        for agent in agents {
1002            self.add_agent(&agent);
1003        }
1004    }
1005
1006    pub fn is_complete(&self) -> bool {
1007        self.workflow.task_list.is_complete()
1008    }
1009
1010    pub fn pending_count(&self) -> usize {
1011        self.workflow.task_list.pending_count()
1012    }
1013
1014    pub fn execution_plan<'py>(&self, py: Python<'py>) -> Result<Bound<'py, PyAny>, WorkflowError> {
1015        let plan = self.workflow.execution_plan()?;
1016        debug!("Execution plan: {:?}", plan);
1017
1018        // turn hashmap into a to json
1019        let json = serde_json::to_value(plan).map_err(|e| {
1020            error!("Failed to serialize execution plan to JSON: {}", e);
1021            e
1022        })?;
1023
1024        Ok(pythonize(py, &json)?)
1025    }
1026
1027    #[pyo3(signature = (global_context=None))]
1028    pub fn run(
1029        &self,
1030        py: Python,
1031        global_context: Option<Bound<'_, PyAny>>,
1032    ) -> Result<WorkflowResult, WorkflowError> {
1033        debug!("Running workflow: {}", self.workflow.name);
1034
1035        // Convert the global context from PyDict to serde_json::Value if provided
1036        let global_context = if let Some(context) = global_context {
1037            let json_value = depythonize_object_to_value(py, &context)?;
1038            Some(json_value)
1039        } else {
1040            None
1041        };
1042
1043        let workflow: Arc<RwLock<Workflow>> =
1044            block_on(async { self.workflow.run(global_context).await })?;
1045
1046        // Try to get exclusive ownership of the workflow by unwrapping the Arc if there's only one reference
1047        let workflow_result = match Arc::try_unwrap(workflow) {
1048            // If we have exclusive ownership, we can consume the RwLock
1049            Ok(rwlock) => {
1050                // Unwrap the RwLock to get the Workflow
1051                let workflow = rwlock
1052                    .into_inner()
1053                    .map_err(|_| WorkflowError::LockAcquireError)?;
1054
1055                // Get the events before creating WorkflowResult
1056                let events = workflow
1057                    .event_tracker
1058                    .read()
1059                    .unwrap()
1060                    .events
1061                    .read()
1062                    .unwrap()
1063                    .clone();
1064
1065                // Move the tasks out of the workflow
1066                WorkflowResult::new(
1067                    py,
1068                    workflow.task_list.tasks(),
1069                    &self.output_types,
1070                    events,
1071                    workflow.task_list.get_last_task_id(),
1072                )
1073            }
1074            // If there are other references, we need to clone
1075            Err(arc) => {
1076                // Just read the workflow
1077                error!("Workflow still has other references, reading instead of consuming.");
1078                let workflow = arc
1079                    .read()
1080                    .map_err(|_| WorkflowError::ReadLockAcquireError)?;
1081
1082                // Get the events before creating WorkflowResult
1083                let events = workflow
1084                    .event_tracker
1085                    .read()
1086                    .unwrap()
1087                    .events
1088                    .read()
1089                    .unwrap()
1090                    .clone();
1091
1092                WorkflowResult::new(
1093                    py,
1094                    workflow.task_list.tasks(),
1095                    &self.output_types,
1096                    events,
1097                    workflow.task_list.get_last_task_id(),
1098                )
1099            }
1100        };
1101
1102        info!("Workflow execution completed successfully.");
1103        Ok(workflow_result)
1104    }
1105
1106    #[pyo3(signature = (task_id, context=None))]
1107    pub fn execute_task<'py>(
1108        &self,
1109        py: Python<'py>,
1110        task_id: String,
1111        context: Option<Bound<'py, PyAny>>,
1112    ) -> Result<Bound<'py, PyAny>, WorkflowError> {
1113        let context_value = if let Some(ctx) = context {
1114            depythonize_object_to_value(py, &ctx)?
1115        } else {
1116            Value::Null
1117        };
1118
1119        let response_value =
1120            block_on(async { self.workflow.execute_task(&task_id, &context_value).await })?;
1121        let py_response = pythonize(py, &response_value)?;
1122
1123        Ok(py_response)
1124    }
1125
1126    pub fn model_dump_json(&self) -> Result<String, WorkflowError> {
1127        Ok(self.workflow.serialize()?)
1128    }
1129
1130    #[staticmethod]
1131    #[pyo3(signature = (json_string, output_types=None))]
1132    pub fn model_validate_json(
1133        json_string: String,
1134        output_types: Option<Bound<'_, PyDict>>,
1135    ) -> Result<Self, WorkflowError> {
1136        let mut workflow: Workflow = Workflow::from_json(&json_string)?;
1137
1138        // rebuild tasklist schema validators
1139        workflow.task_list.rebuild_task_validators()?;
1140
1141        // reload agents to ensure clients are rebuilt
1142        // This is necessary because during deserialization the GenAIClient
1143        block_on(async { workflow.reset_agents().await })?;
1144        let output_types = match output_types {
1145            Some(output_types) => output_types
1146                .iter()
1147                .map(|(k, v)| -> PyResult<(String, Arc<Py<PyAny>>)> {
1148                    let key = k.extract::<String>()?;
1149                    let value = v.clone().unbind();
1150                    Ok((key, Arc::new(value)))
1151                })
1152                .collect::<PyResult<HashMap<String, Arc<Py<PyAny>>>>>()?,
1153            None => HashMap::new(),
1154        };
1155
1156        let py_workflow = PyWorkflow {
1157            workflow,
1158            output_types,
1159        };
1160
1161        Ok(py_workflow)
1162    }
1163}
1164
#[cfg(test)]
mod tests {
    use super::*;
    use potato_type::openai::v1::chat::request::{
        ChatMessage as OpenAIChatMessage, ContentPart, TextContentPart,
    };
    use potato_type::prompt::Prompt;
    use potato_type::prompt::ResponseType;

    /// Builds a single OpenAI-format user message for test prompts.
    fn create_openai_chat_message_num() -> MessageNum {
        let part = ContentPart::Text(TextContentPart::new(
            "What company is this logo from?".to_string(),
        ));
        MessageNum::OpenAIMessageV1(OpenAIChatMessage {
            role: "user".to_string(),
            content: vec![part],
            name: None,
        })
    }

    /// A freshly created workflow carries its name and a UUID7 id.
    #[test]
    fn test_workflow_creation() {
        let wf = Workflow::new("Test Workflow");
        assert_eq!(wf.name, "Test Workflow");
        assert_eq!(wf.id.len(), 36); // UUID7 length
    }

    /// A task added to a TaskList can be retrieved by its id, and failed
    /// tasks can be reset without error.
    #[test]
    fn test_task_list_add_and_get() {
        let prompt = Prompt::new_rs(
            vec![create_openai_chat_message_num()],
            "gpt-4o",
            potato_type::Provider::OpenAI,
            vec![],
            None,
            None,
            ResponseType::Null,
        )
        .unwrap();
        let task = Task::new("task1", prompt, "task1", None, None).unwrap();

        let mut task_list = TaskList::new();
        task_list.add_task(task.clone()).unwrap();

        let stored_id = task_list.get_task(&task.id).unwrap().read().unwrap().id.clone();
        assert_eq!(stored_id, task.id);

        task_list.reset_failed_tasks().unwrap();
    }
}