1use crate::tasklist::TaskList;
2use crate::types::Context;
3use crate::{
4 events::{EventTracker, TaskEvent},
5 workflow::error::WorkflowError,
6};
7pub use potato_agent::agents::{
8 agent::{Agent, PyAgent},
9 task::{PyTask, Task, TaskStatus},
10 types::ChatResponse,
11};
12use potato_agent::PyAgentResponse;
13use potato_prompt::parse_response_to_json;
14use potato_prompt::prompt::types::Role;
15use potato_prompt::Message;
16use potato_util::{create_uuid7, utils::update_serde_map_with, PyHelperFuncs};
17use potato_util::{json_to_pydict, pyobject_to_json};
18use pyo3::prelude::*;
19use pyo3::IntoPyObjectExt;
20use serde::{
21 de::{self, MapAccess, Visitor},
22 ser::SerializeStruct,
23 Deserialize, Deserializer, Serialize, Serializer,
24};
25use serde_json::Map;
26use serde_json::Value;
27use std::collections::{HashMap, HashSet};
28use std::sync::Arc;
29use std::sync::RwLock;
30use tracing::instrument;
31use tracing::{debug, error, info, warn};
32
33use pyo3::types::PyDict;
35
36#[derive(Debug)]
37#[pyclass]
38pub struct WorkflowResult {
39 #[pyo3(get)]
40 pub tasks: HashMap<String, Py<PyTask>>,
41
42 #[pyo3(get)]
43 pub events: Vec<TaskEvent>,
44}
45
46impl WorkflowResult {
47 pub fn new(
48 py: Python,
49 tasks: HashMap<String, Task>,
50 output_types: &HashMap<String, Arc<PyObject>>,
51 events: Vec<TaskEvent>,
52 ) -> Self {
53 let py_tasks = tasks
54 .into_iter()
55 .map(|(id, task)| {
56 let py_agent_response = if let Some(result) = task.result {
57 let output_type = output_types.get(&id).map(|arc| arc.as_ref().clone_ref(py));
58 Some(PyAgentResponse::new(result, output_type))
59 } else {
60 None
61 };
62 let py_task = PyTask {
63 id: task.id.clone(),
64 prompt: task.prompt,
65 dependencies: task.dependencies,
66 status: task.status,
67 agent_id: task.agent_id,
68 result: py_agent_response,
69 max_retries: task.max_retries,
70 retry_count: task.retry_count,
71 };
72 (id, Py::new(py, py_task).unwrap())
73 })
74 .collect::<HashMap<_, _>>();
75
76 Self {
77 tasks: py_tasks,
78 events,
79 }
80 }
81}
82
83#[pymethods]
84impl WorkflowResult {
85 pub fn __str__(&self) -> String {
86 let json = serde_json::json!({
88 "tasks": serde_json::to_value(&self.tasks).unwrap_or(Value::Null),
89 "events": serde_json::to_value(&self.events).unwrap_or(Value::Null)
90 });
91
92 PyHelperFuncs::__str__(&json)
93 }
94}
95
96#[derive(Debug, Clone)]
98pub struct Workflow {
99 pub id: String,
100 pub name: String,
101 pub task_list: TaskList,
102 pub agents: HashMap<String, Arc<Agent>>,
103 pub event_tracker: Arc<RwLock<EventTracker>>,
104 pub global_context: Option<Value>,
105}
106
107impl PartialEq for Workflow {
108 fn eq(&self, other: &Self) -> bool {
109 self.id == other.id && self.name == other.name
111 }
112}
113
114impl Workflow {
115 pub fn new(name: &str) -> Self {
116 debug!("Creating new workflow: {}", name);
117 let id = create_uuid7();
118 Self {
119 id: id.clone(),
120 name: name.to_string(),
121 task_list: TaskList::new(),
122 agents: HashMap::new(),
123 event_tracker: Arc::new(RwLock::new(EventTracker::new(id))),
124 global_context: None, }
126 }
127 pub fn events(&self) -> Vec<TaskEvent> {
128 let tracker = self.event_tracker.read().unwrap();
129 let events = tracker.events.read().unwrap().clone();
130 events
131 }
132
133 pub fn total_duration(&self) -> i32 {
134 let tracker = self.event_tracker.read().unwrap();
135
136 if tracker.is_empty() {
137 0
138 } else {
139 let mut total_duration = chrono::Duration::zero();
141 for event in tracker.events.read().unwrap().iter() {
142 total_duration += event.details.duration.unwrap_or(chrono::Duration::zero());
143 }
144 total_duration.subsec_millis()
145 }
146 }
147
148 pub fn get_new_workflow(&self, global_context: Option<Value>) -> Result<Self, WorkflowError> {
149 let id = create_uuid7();
151
152 let task_list = self.task_list.deep_clone()?;
154
155 Ok(Workflow {
156 id: id.clone(),
157 name: self.name.clone(),
158 task_list,
159 agents: self.agents.clone(), event_tracker: Arc::new(RwLock::new(EventTracker::new(id))),
161 global_context, })
163 }
164
165 pub async fn run(
166 &self,
167 global_context: Option<Value>,
168 ) -> Result<Arc<RwLock<Workflow>>, WorkflowError> {
169 debug!("Running workflow: {}", self.name);
170
171 let run_workflow = Arc::new(RwLock::new(self.get_new_workflow(global_context)?));
172
173 execute_workflow(&run_workflow).await?;
174
175 Ok(run_workflow)
176 }
177
178 pub fn is_complete(&self) -> bool {
179 self.task_list.is_complete()
180 }
181
182 pub fn pending_count(&self) -> usize {
183 self.task_list.pending_count()
184 }
185
186 pub fn add_task(&mut self, task: Task) -> Result<(), WorkflowError> {
187 self.task_list.add_task(task)
188 }
189
190 pub fn add_tasks(&mut self, tasks: Vec<Task>) -> Result<(), WorkflowError> {
191 for task in tasks {
192 self.task_list.add_task(task)?;
193 }
194 Ok(())
195 }
196
197 pub fn add_agent(&mut self, agent: &Agent) {
198 self.agents
199 .insert(agent.id.clone(), Arc::new(agent.clone()));
200 }
201
202 pub fn execution_plan(&self) -> Result<HashMap<i32, HashSet<String>>, WorkflowError> {
203 let mut remaining: HashMap<String, HashSet<String>> = self
204 .task_list
205 .tasks
206 .iter()
207 .map(|(id, task)| {
208 (
209 id.clone(),
210 task.read().unwrap().dependencies.iter().cloned().collect(),
211 )
212 })
213 .collect();
214
215 let mut executed = HashSet::new();
216 let mut plan = HashMap::new();
217 let mut step = 1;
218
219 while !remaining.is_empty() {
220 let ready_keys: Vec<String> = remaining
222 .iter()
223 .filter(|(_, deps)| deps.is_subset(&executed))
224 .map(|(id, _)| id.to_string())
225 .collect();
226
227 if ready_keys.is_empty() {
228 break;
230 }
231
232 let mut ready_set = HashSet::with_capacity(ready_keys.len());
234
235 for key in ready_keys {
237 executed.insert(key.clone());
238 remaining.remove(&key);
239 ready_set.insert(key);
240 }
241
242 plan.insert(step, ready_set);
244
245 step += 1;
246 }
247
248 Ok(plan)
249 }
250
251 pub fn __str__(&self) -> String {
252 PyHelperFuncs::__str__(&self.task_list)
253 }
254
255 pub fn serialize(&self) -> Result<String, serde_json::Error> {
256 let json = serde_json::to_string(self).unwrap();
258 Ok(json)
260 }
261
262 pub fn from_json(json: &str) -> Result<Self, WorkflowError> {
263 Ok(serde_json::from_str(json)?)
265 }
266
267 pub fn task_names(&self) -> Vec<String> {
268 self.task_list
269 .tasks
270 .keys()
271 .cloned()
272 .collect::<Vec<String>>()
273 }
274}
275
276fn is_workflow_complete(workflow: &Arc<RwLock<Workflow>>) -> bool {
281 workflow.read().unwrap().is_complete()
282}
283
284fn reset_failed_workflow_tasks(workflow: &Arc<RwLock<Workflow>>) -> Result<(), WorkflowError> {
289 match workflow.write().unwrap().task_list.reset_failed_tasks() {
290 Ok(_) => Ok(()),
291 Err(e) => {
292 warn!("Failed to reset failed tasks: {}", e);
293 Err(e)
294 }
295 }
296}
297
298fn get_ready_tasks(workflow: &Arc<RwLock<Workflow>>) -> Vec<Arc<RwLock<Task>>> {
303 workflow.read().unwrap().task_list.get_ready_tasks()
304}
305
306fn check_for_circular_dependencies(workflow: &Arc<RwLock<Workflow>>) -> bool {
311 let pending_count = workflow.read().unwrap().pending_count();
312
313 if pending_count > 0 {
314 warn!(
315 "No ready tasks found but {} pending tasks remain. Possible circular dependency.",
316 pending_count
317 );
318 return true;
319 }
320
321 false
322}
323
324fn mark_task_as_running(task: Arc<RwLock<Task>>, event_tracker: &Arc<RwLock<EventTracker>>) {
329 let mut task = task.write().unwrap();
330 task.set_status(TaskStatus::Running);
331 event_tracker.write().unwrap().record_task_started(&task.id);
332}
333
334fn get_agent_for_task(workflow: &Arc<RwLock<Workflow>>, agent_id: &str) -> Option<Arc<Agent>> {
339 let wf = workflow.read().unwrap();
340 wf.agents.get(agent_id).cloned()
341}
342
343#[instrument(skip_all)]
349fn build_task_context(
350 workflow: &Arc<RwLock<Workflow>>,
351 task_dependencies: &Vec<String>,
352) -> Result<Context, WorkflowError> {
353 let wf = workflow.read().unwrap();
354 let mut ctx = HashMap::new();
355 let mut param_ctx: Value = Value::Object(Map::new());
356
357 for dep_id in task_dependencies {
358 debug!("Building context for task dependency: {}", dep_id);
359 if let Some(dep) = wf.task_list.get_task(dep_id) {
360 if let Some(result) = &dep.read().unwrap().result {
361 let msg_to_insert = result.response.to_message(Role::Assistant);
362
363 match msg_to_insert {
364 Ok(message) => {
365 ctx.insert(dep_id.clone(), message);
366 }
367 Err(e) => {
368 warn!("Failed to convert response to message: {}", e);
369 }
370 }
371
372 if let Some(structure_output) = result.response.extract_structured_data() {
373 if structure_output.is_object() {
376 update_serde_map_with(&mut param_ctx, &structure_output)?;
378 }
379 }
380 }
381 }
382 }
383
384 debug!("Built context for task dependencies: {:?}", ctx);
385 let global_context = workflow.read().unwrap().global_context.clone();
386
387 Ok((ctx, param_ctx, global_context))
388}
389
390fn spawn_task_execution(
399 event_tracker: Arc<RwLock<EventTracker>>,
400 task: Arc<RwLock<Task>>,
401 task_id: String,
402 agent: Option<Arc<Agent>>,
403 context: HashMap<String, Vec<Message>>,
404 parameter_context: Value,
405 global_context: Option<Value>,
406) -> tokio::task::JoinHandle<()> {
407 tokio::spawn(async move {
408 if let Some(agent) = agent {
409 let result = agent
413 .execute_task_with_context(&task, context, parameter_context, global_context)
414 .await;
415 match result {
416 Ok(response) => {
417 let mut write_task = task.write().unwrap();
418 write_task.set_status(TaskStatus::Completed);
419 write_task.set_result(response.clone());
420 event_tracker.write().unwrap().record_task_completed(
421 &write_task.id,
422 &write_task.prompt,
423 response,
424 );
425 }
426 Err(e) => {
427 error!("Task {} failed: {}", task_id, e);
428 let mut write_task = task.write().unwrap();
429 write_task.set_status(TaskStatus::Failed);
430 event_tracker.write().unwrap().record_task_failed(
431 &write_task.id,
432 &e.to_string(),
433 &write_task.prompt,
434 );
435 }
436 }
437 } else {
438 error!("No agent found for task {}", task_id);
439 let mut write_task = task.write().unwrap();
440 write_task.set_status(TaskStatus::Failed);
441 }
442 })
443}
444
445fn get_parameters_from_context(task: Arc<RwLock<Task>>) -> (String, Vec<String>, String) {
446 let (task_id, dependencies, agent_id) = {
447 let task_guard = task.read().unwrap();
448 (
449 task_guard.id.clone(),
450 task_guard.dependencies.clone(),
451 task_guard.agent_id.clone(),
452 )
453 };
454
455 (task_id, dependencies, agent_id)
456}
457
458fn spawn_task_executions(
464 workflow: &Arc<RwLock<Workflow>>,
465 ready_tasks: Vec<Arc<RwLock<Task>>>,
466) -> Result<Vec<tokio::task::JoinHandle<()>>, WorkflowError> {
467 let mut handles = Vec::with_capacity(ready_tasks.len());
468
469 let event_tracker = workflow.read().unwrap().event_tracker.clone();
471
472 for task in ready_tasks {
473 let (task_id, dependencies, agent_id) = get_parameters_from_context(task.clone());
475
476 mark_task_as_running(task.clone(), &event_tracker);
479
480 let (context, parameter_context, global_context) =
485 build_task_context(workflow, &dependencies)?;
486
487 let agent = get_agent_for_task(workflow, &agent_id);
489
490 let handle = spawn_task_execution(
492 event_tracker.clone(),
493 task.clone(),
494 task_id,
495 agent,
496 context,
497 parameter_context,
498 global_context,
499 );
500 handles.push(handle);
501 }
502
503 Ok(handles)
504}
505
506async fn await_task_completions(handles: Vec<tokio::task::JoinHandle<()>>) {
511 for handle in handles {
512 if let Err(e) = handle.await {
513 warn!("Task execution failed: {}", e);
514 }
515 }
516}
517
518#[instrument(skip_all)]
531pub async fn execute_workflow(workflow: &Arc<RwLock<Workflow>>) -> Result<(), WorkflowError> {
532 debug!("Starting workflow execution");
536
537 while !is_workflow_complete(workflow) {
539 reset_failed_workflow_tasks(workflow)?;
542
543 let ready_tasks = get_ready_tasks(workflow);
546 debug!("Found {} ready tasks for execution", ready_tasks.len());
547
548 if ready_tasks.is_empty() {
550 if check_for_circular_dependencies(workflow) {
551 break;
552 }
553 continue;
554 }
555
556 let handles = spawn_task_executions(workflow, ready_tasks)?;
558
559 await_task_completions(handles).await;
561 }
562
563 debug!("Workflow execution completed");
564 Ok(())
565}
566
567impl Serialize for Workflow {
568 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
569 where
570 S: Serializer,
571 {
572 let mut state = serializer.serialize_struct("Workflow", 4)?;
573
574 state.serialize_field("id", &self.id)?;
576 state.serialize_field("name", &self.name)?;
577 state.serialize_field("task_list", &self.task_list)?;
578 state.serialize_field("agents", &self.agents)?;
579
580 state.end()
581 }
582}
583
584impl<'de> Deserialize<'de> for Workflow {
585 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
586 where
587 D: Deserializer<'de>,
588 {
589 #[derive(Deserialize)]
590 #[serde(field_identifier, rename_all = "snake_case")]
591 enum Field {
592 Id,
593 Name,
594 TaskList,
595 Agents,
596 }
597
598 struct WorkflowVisitor;
599
600 impl<'de> Visitor<'de> for WorkflowVisitor {
601 type Value = Workflow;
602
603 fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result {
604 formatter.write_str("struct Workflow")
605 }
606
607 fn visit_map<V>(self, mut map: V) -> Result<Workflow, V::Error>
608 where
609 V: MapAccess<'de>,
610 {
611 let mut id = None;
612 let mut name = None;
613 let mut task_list_data = None;
614 let mut agents: Option<HashMap<String, Agent>> = None;
615
616 while let Some(key) = map.next_key()? {
617 match key {
618 Field::Id => {
619 let value: String = map.next_value().map_err(|e| {
620 error!("Failed to deserialize field 'id': {e}");
621 de::Error::custom(format!("Failed to deserialize field 'id': {e}"))
622 })?;
623 id = Some(value);
624 }
625 Field::TaskList => {
626 let value: TaskList = map.next_value().map_err(|e| {
628 error!("Failed to deserialize field 'task_list': {e}");
629 de::Error::custom(format!(
630 "Failed to deserialize field 'task_list': {e}",
631 ))
632 })?;
633
634 task_list_data = Some(value);
635 }
636 Field::Name => {
637 let value: String = map.next_value().map_err(|e| {
638 error!("Failed to deserialize field 'name': {e}");
639 de::Error::custom(format!(
640 "Failed to deserialize field 'name': {e}",
641 ))
642 })?;
643 name = Some(value);
644 }
645 Field::Agents => {
646 let value: HashMap<String, Agent> = map.next_value().map_err(|e| {
647 error!("Failed to deserialize field 'agents': {e}");
648 de::Error::custom(format!(
649 "Failed to deserialize field 'agents': {e}"
650 ))
651 })?;
652 agents = Some(value);
653 }
654 }
655 }
656
657 let id = id.ok_or_else(|| de::Error::missing_field("id"))?;
658 let name = name.ok_or_else(|| de::Error::missing_field("name"))?;
659 let task_list_data =
660 task_list_data.ok_or_else(|| de::Error::missing_field("task_list"))?;
661 let agents = agents.ok_or_else(|| de::Error::missing_field("agents"))?;
662
663 let event_tracker = Arc::new(RwLock::new(EventTracker::new(create_uuid7())));
664
665 let agents = agents
667 .into_iter()
668 .map(|(id, agent)| (id, Arc::new(agent)))
669 .collect();
670
671 Ok(Workflow {
672 id,
673 name,
674 task_list: task_list_data,
675 agents,
676 event_tracker,
677 global_context: None, })
679 }
680 }
681
682 const FIELDS: &[&str] = &["id", "name", "task_list", "agents"];
683 deserializer.deserialize_struct("Workflow", FIELDS, WorkflowVisitor)
684 }
685}
686
687#[pyclass(name = "Workflow")]
688#[derive(Debug, Clone)]
689pub struct PyWorkflow {
690 workflow: Workflow,
691
692 output_types: HashMap<String, Arc<PyObject>>,
696
697 runtime: Arc<tokio::runtime::Runtime>,
699}
700
701#[pymethods]
702impl PyWorkflow {
703 #[new]
704 #[pyo3(signature = (name))]
705 pub fn new(name: &str) -> Result<Self, WorkflowError> {
706 debug!("Creating new workflow: {}", name);
707 Ok(Self {
708 workflow: Workflow::new(name),
709 output_types: HashMap::new(),
710 runtime: Arc::new(
711 tokio::runtime::Runtime::new()
712 .map_err(|e| WorkflowError::RuntimeError(e.to_string()))?,
713 ),
714 })
715 }
716
717 #[getter]
718 pub fn name(&self) -> String {
719 self.workflow.name.clone()
720 }
721
722 #[getter]
723 pub fn task_list(&self) -> TaskList {
724 self.workflow.task_list.clone()
725 }
726
727 #[getter]
728 pub fn is_workflow(&self) -> bool {
729 true
730 }
731
732 #[getter]
733 pub fn __workflow__(&self) -> String {
734 self.model_dump_json()
735 }
736
737 #[getter]
738 pub fn agents(&self) -> Result<HashMap<String, PyAgent>, WorkflowError> {
739 self.workflow
740 .agents
741 .iter()
742 .map(|(id, agent)| {
743 Ok((
744 id.clone(),
745 PyAgent {
746 agent: agent.clone(),
747 runtime: self.runtime.clone(),
748 },
749 ))
750 })
751 .collect::<Result<HashMap<_, _>, _>>()
752 }
753
754 #[pyo3(signature = (task_output_types))]
755 pub fn add_task_output_types<'py>(
756 &mut self,
757 task_output_types: Bound<'py, PyDict>,
758 ) -> PyResult<()> {
759 let converted: HashMap<String, Arc<PyObject>> = task_output_types
760 .iter()
761 .map(|(k, v)| -> PyResult<(String, Arc<PyObject>)> {
762 let key = k.extract::<String>()?;
764 let value = v.clone().unbind();
765 Ok((key, Arc::new(value)))
766 })
767 .collect::<PyResult<_>>()?;
768 self.output_types.extend(converted);
769 Ok(())
770 }
771
772 #[pyo3(signature = (task, output_type = None))]
773 pub fn add_task(
774 &mut self,
775 py: Python<'_>,
776 mut task: Task,
777 output_type: Option<Bound<'_, PyAny>>,
778 ) -> Result<(), WorkflowError> {
779 if let Some(output_type) = output_type {
780 (task.prompt.response_type, task.prompt.response_json_schema) =
782 parse_response_to_json(py, &output_type)
783 .map_err(|e| WorkflowError::InvalidOutputType(e.to_string()))?;
784
785 self.output_types
787 .insert(task.id.clone(), Arc::new(output_type.unbind()));
788 }
789
790 self.workflow.task_list.add_task(task)?;
791 Ok(())
792 }
793
794 pub fn add_tasks(&mut self, tasks: Vec<Task>) -> Result<(), WorkflowError> {
795 for task in tasks {
796 self.workflow.task_list.add_task(task)?;
797 }
798 Ok(())
799 }
800
801 pub fn add_agent(&mut self, agent: &Bound<'_, PyAgent>) {
802 let agent = agent.extract::<PyAgent>().unwrap().agent.clone();
804 self.workflow.agents.insert(agent.id.clone(), agent);
805 }
806
807 pub fn is_complete(&self) -> bool {
808 self.workflow.task_list.is_complete()
809 }
810
811 pub fn pending_count(&self) -> usize {
812 self.workflow.task_list.pending_count()
813 }
814
815 pub fn execution_plan<'py>(
816 &self,
817 py: Python<'py>,
818 ) -> Result<Bound<'py, PyDict>, WorkflowError> {
819 let plan = self.workflow.execution_plan()?;
820 debug!("Execution plan: {:?}", plan);
821
822 let json = serde_json::to_value(plan).map_err(|e| {
824 error!("Failed to serialize execution plan to JSON: {}", e);
825 e
826 })?;
827
828 let pydict = PyDict::new(py);
829 json_to_pydict(py, &json, &pydict)?;
830
831 Ok(pydict)
832 }
833
834 #[pyo3(signature = (global_context=None))]
835 pub fn run(
836 &self,
837 py: Python,
838 global_context: Option<Bound<'_, PyDict>>,
839 ) -> Result<WorkflowResult, WorkflowError> {
840 debug!("Running workflow: {}", self.workflow.name);
841
842 let global_context = if let Some(context) = global_context {
844 let json_value = pyobject_to_json(&context.into_bound_py_any(py)?)?;
846 Some(json_value)
847 } else {
848 None
849 };
850
851 let workflow: Arc<RwLock<Workflow>> = self
852 .runtime
853 .block_on(async { self.workflow.run(global_context).await })?;
854
855 let workflow_result = match Arc::try_unwrap(workflow) {
857 Ok(rwlock) => {
859 let workflow = rwlock
861 .into_inner()
862 .map_err(|_| WorkflowError::LockAcquireError)?;
863
864 let events = workflow
866 .event_tracker
867 .read()
868 .unwrap()
869 .events
870 .read()
871 .unwrap()
872 .clone();
873
874 WorkflowResult::new(py, workflow.task_list.tasks(), &self.output_types, events)
876 }
877 Err(arc) => {
879 error!("Workflow still has other references, reading instead of consuming.");
881 let workflow = arc
882 .read()
883 .map_err(|_| WorkflowError::ReadLockAcquireError)?;
884
885 let events = workflow
887 .event_tracker
888 .read()
889 .unwrap()
890 .events
891 .read()
892 .unwrap()
893 .clone();
894
895 WorkflowResult::new(py, workflow.task_list.tasks(), &self.output_types, events)
896 }
897 };
898
899 info!("Workflow execution completed successfully.");
900 Ok(workflow_result)
901 }
902
903 pub fn model_dump_json(&self) -> String {
904 serde_json::to_string(&self.workflow).unwrap()
905 }
906
907 #[staticmethod]
908 #[pyo3(signature = (json_string, output_types=None))]
909 pub fn model_validate_json(
910 json_string: String,
911 output_types: Option<Bound<'_, PyDict>>,
912 ) -> Result<Self, WorkflowError> {
913 let workflow: Workflow = serde_json::from_str(&json_string)?;
914 let runtime = Arc::new(
915 tokio::runtime::Runtime::new()
916 .map_err(|e| WorkflowError::RuntimeError(e.to_string()))?,
917 );
918
919 let output_types = if let Some(output_types) = output_types {
920 output_types
921 .iter()
922 .map(|(k, v)| -> PyResult<(String, Arc<PyObject>)> {
923 let key = k.extract::<String>()?;
924 let value = v.clone().unbind();
925 Ok((key, Arc::new(value)))
926 })
927 .collect::<PyResult<HashMap<String, Arc<PyObject>>>>()?
928 } else {
929 HashMap::new()
930 };
931
932 let py_workflow = PyWorkflow {
933 workflow,
934 output_types,
935 runtime,
936 };
937
938 Ok(py_workflow)
939 }
940}
941
#[cfg(test)]
mod tests {
    use super::*;
    use potato_prompt::prompt::ResponseType;
    use potato_prompt::{prompt::types::PromptContent, Message, Prompt};

    /// Builds a minimal single-message `gpt-4o` / OpenAI prompt for test tasks.
    fn sample_prompt(text: &str) -> Prompt {
        Prompt::new_rs(
            vec![Message::new_rs(PromptContent::Str(text.to_string()))],
            "gpt-4o",
            potato_type::Provider::OpenAI,
            vec![],
            None,
            None,
            ResponseType::Null,
        )
        .unwrap()
    }

    #[test]
    fn test_workflow_creation() {
        let workflow = Workflow::new("Test Workflow");
        assert_eq!(workflow.name, "Test Workflow");
        // Ids from `create_uuid7` are expected in 36-character hyphenated form.
        assert_eq!(workflow.id.len(), 36);
    }

    #[test]
    fn test_task_list_add_and_get() {
        let mut task_list = TaskList::new();
        let task = Task::new("task1", sample_prompt("Test prompt"), "task1", None, None);
        task_list.add_task(task.clone()).unwrap();

        let stored = task_list.get_task(&task.id).unwrap();
        assert_eq!(stored.read().unwrap().id, task.id);

        task_list.reset_failed_tasks().unwrap();
    }
}