1use crate::tasklist::TaskList;
2use crate::types::Context;
3use crate::{
4 events::{EventTracker, TaskEvent},
5 workflow::error::WorkflowError,
6};
7pub use potato_agent::agents::{
8 agent::{Agent, PyAgent},
9 task::{PyTask, Task, TaskStatus},
10};
11use potato_agent::PyAgentResponse;
12use potato_prompt::parse_response_to_json;
13use potato_prompt::prompt::types::Role;
14use potato_prompt::Message;
15use potato_util::{create_uuid7, utils::update_serde_map_with, PyHelperFuncs};
16use potato_util::{json_to_pydict, pyobject_to_json};
17use pyo3::prelude::*;
18use pyo3::IntoPyObjectExt;
19use serde::{
20 de::{self, MapAccess, Visitor},
21 ser::SerializeStruct,
22 Deserialize, Deserializer, Serialize, Serializer,
23};
24use serde_json::Map;
25use serde_json::Value;
26use std::collections::{HashMap, HashSet};
27use std::sync::Arc;
28use std::sync::RwLock;
29use tracing::instrument;
30use tracing::{debug, error, info, warn};
31
32use pyo3::types::PyDict;
34
/// Python-facing outcome of a workflow run: the final state of every task plus the
/// events recorded by the workflow's event tracker.
#[derive(Debug)]
#[pyclass]
pub struct WorkflowResult {
    /// Final task snapshots, keyed by the same ids as the workflow's task list,
    /// wrapped as Python `PyTask` objects.
    #[pyo3(get)]
    pub tasks: HashMap<String, Py<PyTask>>,

    /// Events copied out of the workflow's event tracker when the result was built.
    #[pyo3(get)]
    pub events: Vec<TaskEvent>,
}
44
45impl WorkflowResult {
46 pub fn new(
47 py: Python,
48 tasks: HashMap<String, Task>,
49 output_types: &HashMap<String, Arc<Py<PyAny>>>,
50 events: Vec<TaskEvent>,
51 ) -> Self {
52 let py_tasks = tasks
53 .into_iter()
54 .map(|(id, task)| {
55 let py_agent_response = if let Some(result) = task.result {
56 let output_type = output_types.get(&id).map(|arc| arc.as_ref().clone_ref(py));
57 Some(PyAgentResponse::new(result, output_type))
58 } else {
59 None
60 };
61 let py_task = PyTask {
62 id: task.id.clone(),
63 prompt: task.prompt,
64 dependencies: task.dependencies,
65 status: task.status,
66 agent_id: task.agent_id,
67 result: py_agent_response,
68 max_retries: task.max_retries,
69 retry_count: task.retry_count,
70 };
71 (id, Py::new(py, py_task).unwrap())
72 })
73 .collect::<HashMap<_, _>>();
74
75 Self {
76 tasks: py_tasks,
77 events,
78 }
79 }
80}
81
82#[pymethods]
83impl WorkflowResult {
84 pub fn __str__(&self) -> String {
85 let json = serde_json::json!({
87 "tasks": serde_json::to_value(&self.tasks).unwrap_or(Value::Null),
88 "events": serde_json::to_value(&self.events).unwrap_or(Value::Null)
89 });
90
91 PyHelperFuncs::__str__(&json)
92 }
93}
94
/// A named collection of tasks (which may depend on each other) plus the agents that
/// execute them.
#[derive(Debug, Clone)]
pub struct Workflow {
    /// Identifier generated with `create_uuid7` when the workflow is created or re-instantiated.
    pub id: String,
    /// Human-readable workflow name.
    pub name: String,
    /// The tasks to run, including their dependency lists and status.
    pub task_list: TaskList,
    /// Agents available to the tasks, keyed by agent id.
    pub agents: HashMap<String, Arc<Agent>>,
    /// Shared recorder of task lifecycle events (started / completed / failed).
    pub event_tracker: Arc<RwLock<EventTracker>>,
    /// Optional JSON handed to every task at run time. It is not part of the serialized form,
    /// because the `Serialize` impl writes only id, name, task_list and agents.
    pub global_context: Option<Value>,
}
105
106impl PartialEq for Workflow {
107 fn eq(&self, other: &Self) -> bool {
108 self.id == other.id && self.name == other.name
110 }
111}
112
113impl Workflow {
114 pub async fn reset_agents(&mut self) -> Result<(), WorkflowError> {
124 let mut agents_map = self.agents.clone();
125
126 for agent in self.agents.values_mut() {
127 agents_map.insert(agent.id.clone(), Arc::new(agent.rebuild_client().await?));
128 }
129 self.agents = agents_map;
130 Ok(())
131 }
132 pub fn new(name: &str) -> Self {
133 debug!("Creating new workflow: {}", name);
134 let id = create_uuid7();
135 Self {
136 id: id.clone(),
137 name: name.to_string(),
138 task_list: TaskList::new(),
139 agents: HashMap::new(),
140 event_tracker: Arc::new(RwLock::new(EventTracker::new(id))),
141 global_context: None, }
143 }
144 pub fn events(&self) -> Vec<TaskEvent> {
145 let tracker = self.event_tracker.read().unwrap();
146 let events = tracker.events.read().unwrap().clone();
147 events
148 }
149
150 pub fn total_duration(&self) -> i32 {
151 let tracker = self.event_tracker.read().unwrap();
152
153 if tracker.is_empty() {
154 0
155 } else {
156 let mut total_duration = chrono::Duration::zero();
158 for event in tracker.events.read().unwrap().iter() {
159 total_duration += event.details.duration.unwrap_or(chrono::Duration::zero());
160 }
161 total_duration.subsec_millis()
162 }
163 }
164
165 pub fn get_new_workflow(&self, global_context: Option<Value>) -> Result<Self, WorkflowError> {
166 let id = create_uuid7();
168
169 let task_list = self.task_list.deep_clone()?;
171
172 Ok(Workflow {
173 id: id.clone(),
174 name: self.name.clone(),
175 task_list,
176 agents: self.agents.clone(), event_tracker: Arc::new(RwLock::new(EventTracker::new(id))),
178 global_context, })
180 }
181
182 pub async fn run(
183 &self,
184 global_context: Option<Value>,
185 ) -> Result<Arc<RwLock<Workflow>>, WorkflowError> {
186 debug!("Running workflow: {}", self.name);
187
188 let run_workflow = Arc::new(RwLock::new(self.get_new_workflow(global_context)?));
189
190 execute_workflow(&run_workflow).await?;
191
192 Ok(run_workflow)
193 }
194
195 pub fn is_complete(&self) -> bool {
196 self.task_list.is_complete()
197 }
198
199 pub fn pending_count(&self) -> usize {
200 self.task_list.pending_count()
201 }
202
203 pub fn add_task(&mut self, task: Task) -> Result<(), WorkflowError> {
204 self.task_list.add_task(task)
205 }
206
207 pub fn add_tasks(&mut self, tasks: Vec<Task>) -> Result<(), WorkflowError> {
208 for task in tasks {
209 self.task_list.add_task(task)?;
210 }
211 Ok(())
212 }
213
214 pub fn add_agent(&mut self, agent: &Agent) {
215 self.agents
216 .insert(agent.id.clone(), Arc::new(agent.clone()));
217 }
218
219 pub fn execution_plan(&self) -> Result<HashMap<i32, HashSet<String>>, WorkflowError> {
220 let mut remaining: HashMap<String, HashSet<String>> = self
221 .task_list
222 .tasks
223 .iter()
224 .map(|(id, task)| {
225 (
226 id.clone(),
227 task.read().unwrap().dependencies.iter().cloned().collect(),
228 )
229 })
230 .collect();
231
232 let mut executed = HashSet::new();
233 let mut plan = HashMap::new();
234 let mut step = 1;
235
236 while !remaining.is_empty() {
237 let ready_keys: Vec<String> = remaining
239 .iter()
240 .filter(|(_, deps)| deps.is_subset(&executed))
241 .map(|(id, _)| id.to_string())
242 .collect();
243
244 if ready_keys.is_empty() {
245 break;
247 }
248
249 let mut ready_set = HashSet::with_capacity(ready_keys.len());
251
252 for key in ready_keys {
254 executed.insert(key.clone());
255 remaining.remove(&key);
256 ready_set.insert(key);
257 }
258
259 plan.insert(step, ready_set);
261
262 step += 1;
263 }
264
265 Ok(plan)
266 }
267
268 pub fn __str__(&self) -> String {
269 PyHelperFuncs::__str__(&self.task_list)
270 }
271
272 pub fn serialize(&self) -> Result<String, serde_json::Error> {
273 let json = serde_json::to_string(self).unwrap();
275 Ok(json)
277 }
278
279 pub fn from_json(json: &str) -> Result<Self, WorkflowError> {
280 Ok(serde_json::from_str(json)?)
282 }
283
284 pub fn task_names(&self) -> Vec<String> {
285 self.task_list
286 .tasks
287 .keys()
288 .cloned()
289 .collect::<Vec<String>>()
290 }
291}
292
293fn is_workflow_complete(workflow: &Arc<RwLock<Workflow>>) -> bool {
298 workflow.read().unwrap().is_complete()
299}
300
301fn reset_failed_workflow_tasks(workflow: &Arc<RwLock<Workflow>>) -> Result<(), WorkflowError> {
306 match workflow.write().unwrap().task_list.reset_failed_tasks() {
307 Ok(_) => Ok(()),
308 Err(e) => {
309 warn!("Failed to reset failed tasks: {}", e);
310 Err(e)
311 }
312 }
313}
314
315fn get_ready_tasks(workflow: &Arc<RwLock<Workflow>>) -> Vec<Arc<RwLock<Task>>> {
320 workflow.read().unwrap().task_list.get_ready_tasks()
321}
322
323fn check_for_circular_dependencies(workflow: &Arc<RwLock<Workflow>>) -> bool {
328 let pending_count = workflow.read().unwrap().pending_count();
329
330 if pending_count > 0 {
331 warn!(
332 "No ready tasks found but {} pending tasks remain. Possible circular dependency.",
333 pending_count
334 );
335 return true;
336 }
337
338 false
339}
340
341fn mark_task_as_running(task: Arc<RwLock<Task>>, event_tracker: &Arc<RwLock<EventTracker>>) {
346 let mut task = task.write().unwrap();
347 task.set_status(TaskStatus::Running);
348 event_tracker.write().unwrap().record_task_started(&task.id);
349}
350
351fn get_agent_for_task(workflow: &Arc<RwLock<Workflow>>, agent_id: &str) -> Option<Arc<Agent>> {
356 let wf = workflow.read().unwrap();
357 wf.agents.get(agent_id).cloned()
358}
359
360#[instrument(skip_all)]
366fn build_task_context(
367 workflow: &Arc<RwLock<Workflow>>,
368 task_dependencies: &Vec<String>,
369) -> Result<Context, WorkflowError> {
370 let wf = workflow.read().unwrap();
371 let mut ctx = HashMap::new();
372 let mut param_ctx: Value = Value::Object(Map::new());
373
374 for dep_id in task_dependencies {
375 debug!("Building context for task dependency: {}", dep_id);
376 if let Some(dep) = wf.task_list.get_task(dep_id) {
377 if let Some(result) = &dep.read().unwrap().result {
378 let msg_to_insert = result.response.to_message(Role::Assistant);
379
380 match msg_to_insert {
381 Ok(message) => {
382 ctx.insert(dep_id.clone(), message);
383 }
384 Err(e) => {
385 warn!("Failed to convert response to message: {}", e);
386 }
387 }
388
389 if let Some(structure_output) = result.response.extract_structured_data() {
390 if structure_output.is_object() {
393 update_serde_map_with(&mut param_ctx, &structure_output)?;
395 }
396 }
397 }
398 }
399 }
400
401 debug!("Built context for task dependencies: {:?}", ctx);
402 let global_context = workflow.read().unwrap().global_context.clone();
403
404 Ok((ctx, param_ctx, global_context))
405}
406
407fn spawn_task_execution(
416 event_tracker: Arc<RwLock<EventTracker>>,
417 task: Arc<RwLock<Task>>,
418 task_id: String,
419 agent: Option<Arc<Agent>>,
420 context: HashMap<String, Vec<Message>>,
421 parameter_context: Value,
422 global_context: Option<Value>,
423) -> tokio::task::JoinHandle<()> {
424 tokio::spawn(async move {
425 if let Some(agent) = agent {
426 let result = agent
430 .execute_task_with_context(&task, context, parameter_context, global_context)
431 .await;
432 match result {
433 Ok(response) => {
434 let mut write_task = task.write().unwrap();
435 write_task.set_status(TaskStatus::Completed);
436 write_task.set_result(response.clone());
437 event_tracker.write().unwrap().record_task_completed(
438 &write_task.id,
439 &write_task.prompt,
440 response,
441 );
442 }
443 Err(e) => {
444 error!("Task {} failed: {}", task_id, e);
445 let mut write_task = task.write().unwrap();
446 write_task.set_status(TaskStatus::Failed);
447 event_tracker.write().unwrap().record_task_failed(
448 &write_task.id,
449 &e.to_string(),
450 &write_task.prompt,
451 );
452 }
453 }
454 } else {
455 error!("No agent found for task {}", task_id);
456 let mut write_task = task.write().unwrap();
457 write_task.set_status(TaskStatus::Failed);
458 }
459 })
460}
461
462fn get_parameters_from_context(task: Arc<RwLock<Task>>) -> (String, Vec<String>, String) {
463 let (task_id, dependencies, agent_id) = {
464 let task_guard = task.read().unwrap();
465 (
466 task_guard.id.clone(),
467 task_guard.dependencies.clone(),
468 task_guard.agent_id.clone(),
469 )
470 };
471
472 (task_id, dependencies, agent_id)
473}
474
475fn spawn_task_executions(
481 workflow: &Arc<RwLock<Workflow>>,
482 ready_tasks: Vec<Arc<RwLock<Task>>>,
483) -> Result<Vec<tokio::task::JoinHandle<()>>, WorkflowError> {
484 let mut handles = Vec::with_capacity(ready_tasks.len());
485
486 let event_tracker = workflow.read().unwrap().event_tracker.clone();
488
489 for task in ready_tasks {
490 let (task_id, dependencies, agent_id) = get_parameters_from_context(task.clone());
492
493 mark_task_as_running(task.clone(), &event_tracker);
496
497 let (context, parameter_context, global_context) =
502 build_task_context(workflow, &dependencies)?;
503
504 let agent = get_agent_for_task(workflow, &agent_id);
506
507 let handle = spawn_task_execution(
509 event_tracker.clone(),
510 task.clone(),
511 task_id,
512 agent,
513 context,
514 parameter_context,
515 global_context,
516 );
517 handles.push(handle);
518 }
519
520 Ok(handles)
521}
522
523async fn await_task_completions(handles: Vec<tokio::task::JoinHandle<()>>) {
528 for handle in handles {
529 if let Err(e) = handle.await {
530 warn!("Task execution failed: {}", e);
531 }
532 }
533}
534
535#[instrument(skip_all)]
548pub async fn execute_workflow(workflow: &Arc<RwLock<Workflow>>) -> Result<(), WorkflowError> {
549 debug!("Starting workflow execution");
553
554 while !is_workflow_complete(workflow) {
556 reset_failed_workflow_tasks(workflow)?;
559
560 let ready_tasks = get_ready_tasks(workflow);
563 debug!("Found {} ready tasks for execution", ready_tasks.len());
564
565 if ready_tasks.is_empty() {
567 if check_for_circular_dependencies(workflow) {
568 break;
569 }
570 continue;
571 }
572
573 let handles = spawn_task_executions(workflow, ready_tasks)?;
575
576 await_task_completions(handles).await;
578 }
579
580 debug!("Workflow execution completed");
581 Ok(())
582}
583
584impl Serialize for Workflow {
585 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
586 where
587 S: Serializer,
588 {
589 let mut state = serializer.serialize_struct("Workflow", 4)?;
590
591 state.serialize_field("id", &self.id)?;
593 state.serialize_field("name", &self.name)?;
594 state.serialize_field("task_list", &self.task_list)?;
595 state.serialize_field("agents", &self.agents)?;
596
597 state.end()
598 }
599}
600
601impl<'de> Deserialize<'de> for Workflow {
602 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
603 where
604 D: Deserializer<'de>,
605 {
606 #[derive(Deserialize)]
607 #[serde(field_identifier, rename_all = "snake_case")]
608 enum Field {
609 Id,
610 Name,
611 TaskList,
612 Agents,
613 }
614
615 struct WorkflowVisitor;
616
617 impl<'de> Visitor<'de> for WorkflowVisitor {
618 type Value = Workflow;
619
620 fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result {
621 formatter.write_str("struct Workflow")
622 }
623
624 fn visit_map<V>(self, mut map: V) -> Result<Workflow, V::Error>
625 where
626 V: MapAccess<'de>,
627 {
628 let mut id = None;
629 let mut name = None;
630 let mut task_list_data = None;
631 let mut agents: Option<HashMap<String, Agent>> = None;
632
633 while let Some(key) = map.next_key()? {
634 match key {
635 Field::Id => {
636 let value: String = map.next_value().map_err(|e| {
637 error!("Failed to deserialize field 'id': {e}");
638 de::Error::custom(format!("Failed to deserialize field 'id': {e}"))
639 })?;
640 id = Some(value);
641 }
642 Field::TaskList => {
643 let value: TaskList = map.next_value().map_err(|e| {
645 error!("Failed to deserialize field 'task_list': {e}");
646 de::Error::custom(format!(
647 "Failed to deserialize field 'task_list': {e}",
648 ))
649 })?;
650
651 task_list_data = Some(value);
652 }
653 Field::Name => {
654 let value: String = map.next_value().map_err(|e| {
655 error!("Failed to deserialize field 'name': {e}");
656 de::Error::custom(format!(
657 "Failed to deserialize field 'name': {e}",
658 ))
659 })?;
660 name = Some(value);
661 }
662 Field::Agents => {
663 let value: HashMap<String, Agent> = map.next_value().map_err(|e| {
664 error!("Failed to deserialize field 'agents': {e}");
665 de::Error::custom(format!(
666 "Failed to deserialize field 'agents': {e}"
667 ))
668 })?;
669 agents = Some(value);
670 }
671 }
672 }
673
674 let id = id.ok_or_else(|| de::Error::missing_field("id"))?;
675 let name = name.ok_or_else(|| de::Error::missing_field("name"))?;
676 let task_list_data =
677 task_list_data.ok_or_else(|| de::Error::missing_field("task_list"))?;
678 let agents = agents.ok_or_else(|| de::Error::missing_field("agents"))?;
679
680 let event_tracker = Arc::new(RwLock::new(EventTracker::new(create_uuid7())));
681
682 let agents = agents
684 .into_iter()
685 .map(|(id, agent)| (id, Arc::new(agent)))
686 .collect();
687
688 Ok(Workflow {
689 id,
690 name,
691 task_list: task_list_data,
692 agents,
693 event_tracker,
694 global_context: None, })
696 }
697 }
698
699 const FIELDS: &[&str] = &["id", "name", "task_list", "agents"];
700 deserializer.deserialize_struct("Workflow", FIELDS, WorkflowVisitor)
701 }
702}
703
/// Python wrapper (exposed as `Workflow`) around a Rust [`Workflow`], with its own tokio
/// runtime for running the async execution from synchronous Python calls.
#[pyclass(name = "Workflow")]
#[derive(Debug, Clone)]
pub struct PyWorkflow {
    /// The wrapped workflow definition (tasks and agents).
    workflow: Workflow,

    /// Python output types keyed by task id. `WorkflowResult::new` passes the matching
    /// entry to `PyAgentResponse::new` when building a run's result.
    output_types: HashMap<String, Arc<Py<PyAny>>>,

    /// Runtime that `run` (and the `PyAgent` handles returned by `agents`) use to block on
    /// async work.
    runtime: Arc<tokio::runtime::Runtime>,
}
717
718#[pymethods]
719impl PyWorkflow {
720 #[new]
721 #[pyo3(signature = (name))]
722 pub fn new(name: &str) -> Result<Self, WorkflowError> {
723 debug!("Creating new workflow: {}", name);
724 Ok(Self {
725 workflow: Workflow::new(name),
726 output_types: HashMap::new(),
727 runtime: Arc::new(
728 tokio::runtime::Runtime::new()
729 .map_err(|e| WorkflowError::RuntimeError(e.to_string()))?,
730 ),
731 })
732 }
733
734 #[getter]
735 pub fn name(&self) -> String {
736 self.workflow.name.clone()
737 }
738
739 #[getter]
740 pub fn task_list(&self) -> TaskList {
741 self.workflow.task_list.clone()
742 }
743
744 #[getter]
745 pub fn is_workflow(&self) -> bool {
746 true
747 }
748
749 #[getter]
750 pub fn __workflow__(&self) -> Result<String, WorkflowError> {
751 self.model_dump_json()
752 }
753
754 #[getter]
755 pub fn agents(&self) -> Result<HashMap<String, PyAgent>, WorkflowError> {
756 self.workflow
757 .agents
758 .iter()
759 .map(|(id, agent)| {
760 Ok((
761 id.clone(),
762 PyAgent {
763 agent: agent.clone(),
764 runtime: self.runtime.clone(),
765 },
766 ))
767 })
768 .collect::<Result<HashMap<_, _>, _>>()
769 }
770
771 #[pyo3(signature = (task_output_types))]
772 pub fn add_task_output_types<'py>(
773 &mut self,
774 task_output_types: Bound<'py, PyDict>,
775 ) -> PyResult<()> {
776 let converted: HashMap<String, Arc<Py<PyAny>>> = task_output_types
777 .iter()
778 .map(|(k, v)| -> PyResult<(String, Arc<Py<PyAny>>)> {
779 let key = k.extract::<String>()?;
781 let value = v.clone().unbind();
782 Ok((key, Arc::new(value)))
783 })
784 .collect::<PyResult<_>>()?;
785 self.output_types.extend(converted);
786 Ok(())
787 }
788
789 #[pyo3(signature = (task, output_type = None))]
790 pub fn add_task(
791 &mut self,
792 py: Python<'_>,
793 mut task: Task,
794 output_type: Option<Bound<'_, PyAny>>,
795 ) -> Result<(), WorkflowError> {
796 if let Some(output_type) = output_type {
797 (task.prompt.response_type, task.prompt.response_json_schema) =
799 parse_response_to_json(py, &output_type)
800 .map_err(|e| WorkflowError::InvalidOutputType(e.to_string()))?;
801
802 self.output_types
804 .insert(task.id.clone(), Arc::new(output_type.unbind()));
805 }
806
807 self.workflow.task_list.add_task(task)?;
808 Ok(())
809 }
810
811 pub fn add_tasks(&mut self, tasks: Vec<Task>) -> Result<(), WorkflowError> {
812 for task in tasks {
813 self.workflow.task_list.add_task(task)?;
814 }
815 Ok(())
816 }
817
818 pub fn add_agent(&mut self, agent: &Bound<'_, PyAgent>) {
819 let agent = agent.extract::<PyAgent>().unwrap().agent.clone();
821 self.workflow.agents.insert(agent.id.clone(), agent);
822 }
823
824 pub fn is_complete(&self) -> bool {
825 self.workflow.task_list.is_complete()
826 }
827
828 pub fn pending_count(&self) -> usize {
829 self.workflow.task_list.pending_count()
830 }
831
832 pub fn execution_plan<'py>(
833 &self,
834 py: Python<'py>,
835 ) -> Result<Bound<'py, PyDict>, WorkflowError> {
836 let plan = self.workflow.execution_plan()?;
837 debug!("Execution plan: {:?}", plan);
838
839 let json = serde_json::to_value(plan).map_err(|e| {
841 error!("Failed to serialize execution plan to JSON: {}", e);
842 e
843 })?;
844
845 let pydict = PyDict::new(py);
846 json_to_pydict(py, &json, &pydict)?;
847
848 Ok(pydict)
849 }
850
851 #[pyo3(signature = (global_context=None))]
852 pub fn run(
853 &self,
854 py: Python,
855 global_context: Option<Bound<'_, PyDict>>,
856 ) -> Result<WorkflowResult, WorkflowError> {
857 debug!("Running workflow: {}", self.workflow.name);
858
859 let global_context = if let Some(context) = global_context {
861 let json_value = pyobject_to_json(&context.into_bound_py_any(py)?)?;
863 Some(json_value)
864 } else {
865 None
866 };
867
868 let workflow: Arc<RwLock<Workflow>> = self
869 .runtime
870 .block_on(async { self.workflow.run(global_context).await })?;
871
872 let workflow_result = match Arc::try_unwrap(workflow) {
874 Ok(rwlock) => {
876 let workflow = rwlock
878 .into_inner()
879 .map_err(|_| WorkflowError::LockAcquireError)?;
880
881 let events = workflow
883 .event_tracker
884 .read()
885 .unwrap()
886 .events
887 .read()
888 .unwrap()
889 .clone();
890
891 WorkflowResult::new(py, workflow.task_list.tasks(), &self.output_types, events)
893 }
894 Err(arc) => {
896 error!("Workflow still has other references, reading instead of consuming.");
898 let workflow = arc
899 .read()
900 .map_err(|_| WorkflowError::ReadLockAcquireError)?;
901
902 let events = workflow
904 .event_tracker
905 .read()
906 .unwrap()
907 .events
908 .read()
909 .unwrap()
910 .clone();
911
912 WorkflowResult::new(py, workflow.task_list.tasks(), &self.output_types, events)
913 }
914 };
915
916 info!("Workflow execution completed successfully.");
917 Ok(workflow_result)
918 }
919
920 pub fn model_dump_json(&self) -> Result<String, WorkflowError> {
921 Ok(self.workflow.serialize()?)
922 }
923
924 #[staticmethod]
925 #[pyo3(signature = (json_string, output_types=None))]
926 pub fn model_validate_json(
927 json_string: String,
928 output_types: Option<Bound<'_, PyDict>>,
929 ) -> Result<Self, WorkflowError> {
930 let runtime = Arc::new(
931 tokio::runtime::Runtime::new()
932 .map_err(|e| WorkflowError::RuntimeError(e.to_string()))?,
933 );
934 let mut workflow: Workflow = Workflow::from_json(&json_string)?;
935
936 runtime.block_on(async { workflow.reset_agents().await })?;
939
940 let output_types = match output_types {
941 Some(output_types) => output_types
942 .iter()
943 .map(|(k, v)| -> PyResult<(String, Arc<Py<PyAny>>)> {
944 let key = k.extract::<String>()?;
945 let value = v.clone().unbind();
946 Ok((key, Arc::new(value)))
947 })
948 .collect::<PyResult<HashMap<String, Arc<Py<PyAny>>>>>()?,
949 None => HashMap::new(),
950 };
951
952 let py_workflow = PyWorkflow {
953 workflow,
954 output_types,
955 runtime,
956 };
957
958 Ok(py_workflow)
959 }
960}
961
#[cfg(test)]
mod tests {
    use super::*;
    use potato_prompt::prompt::ResponseType;
    use potato_prompt::{prompt::types::PromptContent, Message, Prompt};

    /// A new workflow keeps the given name and receives a generated id.
    #[test]
    fn test_workflow_creation() {
        let workflow = Workflow::new("Test Workflow");
        assert_eq!(workflow.name, "Test Workflow");
        // 36 characters = canonical hyphenated UUID text form (id comes from create_uuid7).
        assert_eq!(workflow.id.len(), 36);
    }

    /// A task added to a `TaskList` can be fetched back by its id, and resetting failed
    /// tasks on a list without failures returns `Ok`.
    #[test]
    fn test_task_list_add_and_get() {
        let mut task_list = TaskList::new();
        let prompt_content = PromptContent::Str("Test prompt".to_string());
        let prompt = Prompt::new_rs(
            vec![Message::new_rs(prompt_content)],
            "gpt-4o",
            potato_type::Provider::OpenAI,
            vec![],
            None,
            None,
            ResponseType::Null,
        )
        .unwrap();

        let task = Task::new("task1", prompt, "task1", None, None);
        task_list.add_task(task.clone()).unwrap();
        assert_eq!(
            task_list.get_task(&task.id).unwrap().read().unwrap().id,
            task.id
        );
        task_list.reset_failed_tasks().unwrap();
    }
}