1use crate::workflow::error::WorkflowError;
2pub use potato_agent::agents::{
3 agent::Agent,
4 task::{PyTask, Task, TaskStatus},
5 types::ChatResponse,
6};
7use potato_util::{create_uuid7, PyHelperFuncs};
8
9use potato_prompt::prompt::types::Role;
10use potato_prompt::Message;
11use pyo3::prelude::*;
12
13use std::collections::{HashMap, HashSet};
14use std::sync::Arc;
15use std::sync::RwLock;
16use tracing::instrument;
17use tracing::{debug, error, info, warn};
18
19use serde::{
20 de::{self, MapAccess, Visitor},
21 ser::SerializeStruct,
22 Deserialize, Deserializer, Serialize, Serializer,
23};
24
/// Python-facing result of a workflow run: the final state of every task,
/// keyed by task id and converted to Python-owned `PyTask` objects.
#[derive(Debug)]
#[pyclass]
pub struct WorkflowResult {
    // Exposed to Python as a read-only attribute via `#[pyo3(get)]`.
    #[pyo3(get)]
    pub tasks: HashMap<String, Py<PyTask>>,
}
31
32impl WorkflowResult {
33 pub fn new(py: Python, tasks: HashMap<String, Task>) -> Self {
34 let py_tasks = tasks
35 .into_iter()
36 .map(|(id, task)| {
37 let py_task = PyTask {
38 id: task.id.clone(),
39 prompt: task.prompt,
40 dependencies: task.dependencies,
41 status: task.status,
42 agent_id: task.agent_id,
43 result: task.result,
44 max_retries: task.max_retries,
45 retry_count: task.retry_count,
46 response_type: None, };
48 (id, Py::new(py, py_task).unwrap())
49 })
50 .collect::<HashMap<_, _>>();
51
52 Self { tasks: py_tasks }
53 }
54}
55
#[pymethods]
impl WorkflowResult {
    /// Python `__str__`: renders the task map via the shared helper.
    pub fn __str__(&self) -> String {
        PyHelperFuncs::__str__(&self.tasks)
    }
}
62
/// A dependency-aware collection of tasks plus a cached topological
/// execution order (rebuilt on every insertion).
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
#[pyclass]
pub struct TaskList {
    // Tasks keyed by task id; read-only from Python.
    #[pyo3(get)]
    pub tasks: HashMap<String, Task>,
    // Topologically sorted task ids (dependencies first); not exposed to Python.
    pub execution_order: Vec<String>,
}
70
71impl TaskList {
72 pub fn new() -> Self {
73 Self {
74 tasks: HashMap::new(),
75 execution_order: Vec::new(),
76 }
77 }
78
79 pub fn is_complete(&self) -> bool {
80 self.tasks
81 .values()
82 .all(|task| task.status == TaskStatus::Completed || task.status == TaskStatus::Failed)
83 }
84
85 pub fn add_task(&mut self, task: Task) -> Result<(), WorkflowError> {
86 if self.tasks.contains_key(&task.id) {
88 return Err(WorkflowError::TaskAlreadyExists(task.id.clone()));
89 }
90
91 for dep_id in &task.dependencies {
93 if !self.tasks.contains_key(dep_id) {
94 return Err(WorkflowError::DependencyNotFound(dep_id.clone()));
95 }
96
97 if dep_id == &task.id {
99 return Err(WorkflowError::TaskDependsOnItself(task.id.clone()));
100 }
101 }
102
103 self.tasks.insert(task.id.clone(), task);
105 self.rebuild_execution_order();
106 Ok(())
107 }
108
109 pub fn get_task(&self, task_id: &str) -> Option<&Task> {
110 self.tasks.get(task_id)
111 }
112
113 pub fn remove_task(&mut self, task_id: &str) {
114 self.tasks.remove(task_id);
115 }
116
117 pub fn pending_count(&self) -> usize {
118 self.tasks
119 .values()
120 .filter(|task| task.status == TaskStatus::Pending)
121 .count()
122 }
123
124 #[instrument(skip_all)]
125 pub fn update_task_status(
126 &mut self,
127 task_id: &str,
128 status: TaskStatus,
129 result: Option<ChatResponse>,
130 ) {
131 debug!(status=?status, result=?result, "Updating task status");
132 if let Some(task) = self.tasks.get_mut(task_id) {
133 task.status = status;
134 task.result = result;
135 }
136 }
137
138 fn topological_sort(
139 &self,
140 task_id: &str,
141 visited: &mut HashSet<String>,
142 temp_visited: &mut HashSet<String>,
143 order: &mut Vec<String>,
144 ) {
145 if temp_visited.contains(task_id) {
146 return; }
148
149 if visited.contains(task_id) {
150 return;
151 }
152
153 temp_visited.insert(task_id.to_string());
154
155 if let Some(task) = self.tasks.get(task_id) {
156 for dep_id in &task.dependencies {
157 self.topological_sort(dep_id, visited, temp_visited, order);
158 }
159 }
160
161 temp_visited.remove(task_id);
162 visited.insert(task_id.to_string());
163 order.push(task_id.to_string());
164 }
165
166 fn rebuild_execution_order(&mut self) {
167 let mut order = Vec::new();
168 let mut visited = HashSet::new();
169 let mut temp_visited = HashSet::new();
170
171 for task_id in self.tasks.keys() {
172 if !visited.contains(task_id) {
173 self.topological_sort(task_id, &mut visited, &mut temp_visited, &mut order);
174 }
175 }
176
177 self.execution_order = order;
178 }
179
180 pub fn get_ready_tasks(&self) -> Vec<Task> {
185 self.tasks
186 .values()
187 .filter(|task| {
188 task.status == TaskStatus::Pending
189 && task.dependencies.iter().all(|dep_id| {
190 self.tasks
191 .get(dep_id)
192 .map(|dep| dep.status == TaskStatus::Completed)
193 .unwrap_or(false)
194 })
195 })
196 .cloned()
197 .collect()
198 }
199
200 pub fn reset_failed_tasks(&mut self) -> Result<(), WorkflowError> {
201 for task in self.tasks.values_mut() {
202 if task.status == TaskStatus::Failed {
203 task.status = TaskStatus::Pending;
204 task.increment_retry();
205 if task.retry_count > task.max_retries {
206 return Err(WorkflowError::MaxRetriesExceeded(task.id.clone()));
207 }
208 }
209 }
210 Ok(())
211 }
212}
213
/// A named DAG of tasks plus the agents that execute them.
/// Clones share agents via `Arc`; the task list is deep-copied.
#[derive(Debug, Clone)]
pub struct Workflow {
    // Identifier generated by `create_uuid7` at construction.
    pub id: String,
    pub name: String,
    // Tasks and their cached execution order.
    pub tasks: TaskList,
    // Agents keyed by agent id; tasks reference these via `Task::agent_id`.
    pub agents: HashMap<String, Arc<Agent>>,
}
222
223impl Workflow {
224 pub fn new(name: &str) -> Self {
225 info!("Creating new workflow: {}", name);
226 Self {
227 id: create_uuid7(),
228 name: name.to_string(),
229 tasks: TaskList::new(),
230 agents: HashMap::new(),
231 }
232 }
233 pub async fn run(&self) -> Result<(), WorkflowError> {
234 info!("Running workflow: {}", self.name);
235 let workflow = self.clone();
236 let workflow = Arc::new(RwLock::new(workflow));
237 execute_workflow(workflow).await
238 }
239
240 pub fn is_complete(&self) -> bool {
241 self.tasks.is_complete()
242 }
243
244 pub fn pending_count(&self) -> usize {
245 self.tasks.pending_count()
246 }
247
248 pub fn add_task(&mut self, task: Task) -> Result<(), WorkflowError> {
249 self.tasks.add_task(task)
250 }
251
252 pub fn add_tasks(&mut self, tasks: Vec<Task>) -> Result<(), WorkflowError> {
253 for task in tasks {
254 self.tasks.add_task(task)?;
255 }
256 Ok(())
257 }
258
259 pub fn add_agent(&mut self, agent: &Agent) {
260 self.agents
261 .insert(agent.id.clone(), Arc::new(agent.clone()));
262 }
263
264 pub fn execution_plan(&self) -> Result<HashMap<String, HashSet<String>>, WorkflowError> {
265 let mut remaining: HashMap<String, HashSet<String>> = self
266 .tasks
267 .tasks
268 .iter()
269 .map(|(id, task)| (id.clone(), task.dependencies.iter().cloned().collect()))
270 .collect();
271
272 let mut executed = HashSet::new();
273 let mut plan = HashMap::new();
274 let mut step = 1;
275
276 while !remaining.is_empty() {
277 let ready_keys: Vec<String> = remaining
279 .iter()
280 .filter(|(_, deps)| deps.is_subset(&executed))
281 .map(|(id, _)| id.to_string())
282 .collect();
283
284 if ready_keys.is_empty() {
285 break;
287 }
288
289 let mut ready_set = HashSet::with_capacity(ready_keys.len());
291
292 for key in ready_keys {
294 executed.insert(key.clone());
295 remaining.remove(&key);
296 ready_set.insert(key);
297 }
298
299 plan.insert(format!("step{step}"), ready_set);
301
302 step += 1;
303 }
304
305 Ok(plan)
306 }
307}
308
309fn is_workflow_complete(workflow: &Arc<RwLock<Workflow>>) -> bool {
314 workflow.read().unwrap().is_complete()
315}
316
317fn reset_failed_workflow_tasks(workflow: &Arc<RwLock<Workflow>>) -> Result<(), WorkflowError> {
322 match workflow.write().unwrap().tasks.reset_failed_tasks() {
323 Ok(_) => Ok(()),
324 Err(e) => {
325 warn!("Failed to reset failed tasks: {}", e);
326 Err(e)
327 }
328 }
329}
330
331fn get_ready_tasks(workflow: &Arc<RwLock<Workflow>>) -> Vec<Task> {
336 workflow.read().unwrap().tasks.get_ready_tasks()
337}
338
339fn check_for_circular_dependencies(workflow: &Arc<RwLock<Workflow>>) -> bool {
344 let pending_count = workflow.read().unwrap().pending_count();
345
346 if pending_count > 0 {
347 warn!(
348 "No ready tasks found but {} pending tasks remain. Possible circular dependency.",
349 pending_count
350 );
351 return true;
352 }
353
354 false
355}
356
357fn mark_task_as_running(workflow: &Arc<RwLock<Workflow>>, task_id: &str) {
362 let mut wf = workflow.write().unwrap();
363 wf.tasks
364 .update_task_status(task_id, TaskStatus::Running, None);
365}
366
367fn get_agent_for_task(workflow: &Arc<RwLock<Workflow>>, task: &Task) -> Option<Arc<Agent>> {
372 let wf = workflow.read().unwrap();
373 wf.agents.get(&task.agent_id).cloned()
374}
375
376fn build_task_context(
382 workflow: &Arc<RwLock<Workflow>>,
383 task: &Task,
384) -> HashMap<String, Vec<Message>> {
385 let wf = workflow.read().unwrap();
386 let mut ctx = HashMap::new();
387
388 for dep_id in &task.dependencies {
389 if let Some(dep) = wf.tasks.get_task(dep_id) {
390 if let Some(result) = &dep.result {
391 if let Ok(message) = result.to_message(Role::Assistant) {
392 ctx.insert(dep_id.clone(), message);
393 }
394 }
395 }
396 }
397
398 ctx
399}
400
401fn spawn_task_execution(
410 workflow: Arc<RwLock<Workflow>>,
411 task: Task,
412 task_id: String,
413 agent: Option<Arc<Agent>>,
414 context: HashMap<String, Vec<Message>>,
415) -> tokio::task::JoinHandle<()> {
416 tokio::spawn(async move {
417 if let Some(agent) = agent {
418 match agent.execute_async_task_with_context(&task, context).await {
419 Ok(response) => {
420 let mut wf = workflow.write().unwrap();
421 wf.tasks.update_task_status(
422 &task_id,
423 TaskStatus::Completed,
424 Some(response.response),
425 );
426 }
427 Err(e) => {
428 error!("Task {} failed: {}", task_id, e);
429 let mut wf = workflow.write().unwrap();
430 wf.tasks
431 .update_task_status(&task_id, TaskStatus::Failed, None);
432 }
433 }
434 } else {
435 error!("No agent found for task {}", task_id);
436 let mut wf = workflow.write().unwrap();
437 wf.tasks
438 .update_task_status(&task_id, TaskStatus::Failed, None);
439 }
440 })
441}
442
443fn spawn_task_executions(
449 workflow: &Arc<RwLock<Workflow>>,
450 tasks: Vec<Task>,
451) -> Vec<tokio::task::JoinHandle<()>> {
452 let mut handles = Vec::with_capacity(tasks.len());
453
454 for task in tasks {
455 let task_id = task.id.clone();
456 mark_task_as_running(workflow, &task_id);
460
461 let context = build_task_context(workflow, &task);
463
464 let agent = get_agent_for_task(workflow, &task);
466
467 let handle = spawn_task_execution(workflow.clone(), task, task_id, agent, context);
469 handles.push(handle);
470 }
471
472 handles
473}
474
475async fn await_task_completions(handles: Vec<tokio::task::JoinHandle<()>>) {
480 for handle in handles {
481 if let Err(e) = handle.await {
482 warn!("Task execution failed: {}", e);
483 }
484 }
485}
486
487#[instrument(skip_all)]
500pub async fn execute_workflow(workflow: Arc<RwLock<Workflow>>) -> Result<(), WorkflowError> {
501 info!("Starting workflow execution");
502
503 while !is_workflow_complete(&workflow) {
504 reset_failed_workflow_tasks(&workflow)?;
507
508 let ready_tasks = get_ready_tasks(&workflow);
510 info!("Found {} ready tasks for execution", ready_tasks.len());
511
512 if ready_tasks.is_empty() {
514 if check_for_circular_dependencies(&workflow) {
515 break;
516 }
517 continue;
518 }
519
520 let handles = spawn_task_executions(&workflow, ready_tasks);
522
523 await_task_completions(handles).await;
525 }
526
527 info!("Workflow execution completed");
528 Ok(())
529}
530
531impl Serialize for Workflow {
532 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
533 where
534 S: Serializer,
535 {
536 let mut state = serializer.serialize_struct("Workflow", 4)?;
537
538 state.serialize_field("id", &self.id)?;
540 state.serialize_field("name", &self.name)?;
541 state.serialize_field("tasks", &self.tasks)?;
542
543 let agents: HashMap<String, Agent> = self
545 .agents
546 .iter()
547 .map(|(id, agent)| (id.clone(), (*agent.as_ref()).clone()))
548 .collect();
549
550 state.serialize_field("agents", &agents)?;
551 state.end()
552 }
553}
554
impl<'de> Deserialize<'de> for Workflow {
    // Manual Deserialize: `agents` holds `Arc<Agent>`, which derive can't
    // produce directly, so plain Agents are deserialized and re-wrapped.
    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
    where
        D: Deserializer<'de>,
    {
        // Field tags of the serialized "Workflow" struct (see Serialize impl).
        #[derive(Deserialize)]
        #[serde(field_identifier, rename_all = "snake_case")]
        enum Field {
            Id,
            Name,
            Tasks,
            Agents,
        }

        struct WorkflowVisitor;

        impl<'de> Visitor<'de> for WorkflowVisitor {
            type Value = Workflow;

            fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result {
                formatter.write_str("struct Workflow")
            }

            // Accepts fields in any order; each is captured into an Option
            // and required afterwards.
            // NOTE(review): a duplicated field silently overwrites the
            // earlier value instead of erroring — confirm acceptable.
            fn visit_map<V>(self, mut map: V) -> Result<Workflow, V::Error>
            where
                V: MapAccess<'de>,
            {
                let mut id = None;
                let mut name = None;
                let mut tasks = None;
                // Deserialized as plain Agents first; wrapped in Arc below.
                let mut agents: Option<HashMap<String, Agent>> = None;

                while let Some(key) = map.next_key()? {
                    match key {
                        Field::Id => {
                            id = Some(map.next_value()?);
                        }
                        Field::Tasks => {
                            tasks = Some(map.next_value()?);
                        }
                        Field::Agents => {
                            agents = Some(map.next_value()?);
                        }
                        Field::Name => {
                            name = Some(map.next_value()?);
                        }
                    }
                }

                // All four fields are mandatory.
                let id = id.ok_or_else(|| de::Error::missing_field("id"))?;
                let name = name.ok_or_else(|| de::Error::missing_field("name"))?;
                let tasks = tasks.ok_or_else(|| de::Error::missing_field("tasks"))?;
                let agents = agents.ok_or_else(|| de::Error::missing_field("agents"))?;

                // Restore the shared-ownership form used at runtime.
                let agents = agents
                    .into_iter()
                    .map(|(id, agent)| (id, Arc::new(agent)))
                    .collect();

                Ok(Workflow {
                    id,
                    name,
                    tasks,
                    agents,
                })
            }
        }

        const FIELDS: &[&str] = &["id", "name", "tasks", "agents"];
        deserializer.deserialize_struct("Workflow", FIELDS, WorkflowVisitor)
    }
}
628
#[cfg(test)]
mod tests {
    use super::*;
    use potato_prompt::{prompt::types::PromptContent, Message, Prompt};

    #[test]
    fn test_workflow_creation() {
        // A fresh workflow carries the given name and a 36-char
        // (hyphenated UUID) id from create_uuid7.
        let workflow = Workflow::new("Test Workflow");
        assert_eq!(workflow.name, "Test Workflow");
        assert_eq!(workflow.id.len(), 36);
    }

    #[test]
    fn test_task_list_add_and_get() {
        let mut task_list = TaskList::new();
        // Minimal prompt fixture; only task bookkeeping is under test.
        let prompt_content = PromptContent::Str("Test prompt".to_string());
        let prompt = Prompt::new_rs(
            vec![Message::new_rs(prompt_content)],
            Some("gpt-4o"),
            Some("openai"),
            vec![],
            None,
            None,
        )
        .unwrap();

        // NOTE(review): the second "task1" appears to be the agent id —
        // confirm against Task::new's signature.
        let task = Task::new("task1", prompt, "task1", None, None);
        task_list.add_task(task.clone()).unwrap();
        assert_eq!(task_list.get_task(&task.id).unwrap().id, task.id);
        // No failed tasks exist, so this must be a successful no-op.
        task_list.reset_failed_tasks().unwrap();
    }
}
660}