1use crate::tasklist::TaskList;
2
3use crate::{
4 events::{EventTracker, TaskEvent},
5 workflow::error::WorkflowError,
6};
7pub use potato_agent::agents::{
8 agent::{Agent, PyAgent},
9 task::{Task, TaskStatus, WorkflowTask},
10};
11use potato_agent::{AgentError, PyAgentResponse};
12use potato_state::block_on;
13use potato_type::prompt::{parse_response_to_json, MessageNum};
14use potato_type::Provider;
15use potato_util::utils::depythonize_object_to_value;
16use potato_util::{create_uuid7, utils::update_serde_map_with, PyHelperFuncs};
17use pyo3::prelude::*;
18use pythonize::pythonize;
19use serde::{
20 de::{self, MapAccess, Visitor},
21 ser::SerializeStruct,
22 Deserialize, Deserializer, Serialize, Serializer,
23};
24use serde_json::Map;
25use serde_json::Value;
26use std::collections::{HashMap, HashSet};
27use std::sync::Arc;
28use std::sync::RwLock;
29use tracing::instrument;
30use tracing::{debug, error, info, warn};
31
32use pyo3::types::PyDict;
34
/// Inputs handed to an agent for one task execution, as built by `build_task_context`:
/// `(messages keyed by dependency task id, merged structured output of dependencies as a JSON
/// object, optional workflow-wide global context)`.
pub type Context = (HashMap<String, Vec<MessageNum>>, Value, Option<Arc<Value>>);
36
/// Outcome of a workflow run, exposed to Python.
#[derive(Debug)]
#[pyclass]
pub struct WorkflowResult {
    /// Python task objects (with their results, if any) keyed by task id.
    #[pyo3(get)]
    pub tasks: HashMap<String, Py<WorkflowTask>>,

    /// Events recorded by the run's event tracker.
    #[pyo3(get)]
    pub events: Vec<TaskEvent>,

    /// Id of the task whose `result` attribute the `result` getter returns, if any.
    last_task_id: Option<String>,
}
48
49impl WorkflowResult {
50 pub fn new(
51 py: Python,
52 tasks: HashMap<String, Task>,
53 output_types: &HashMap<String, Arc<Py<PyAny>>>,
54 events: Vec<TaskEvent>,
55 last_task_id: Option<String>,
56 ) -> Self {
57 let py_tasks = tasks
58 .into_iter()
59 .map(|(id, task)| {
60 let py_agent_response = if let Some(result) = task.result {
61 let output_type = output_types.get(&id).map(|arc| arc.as_ref().clone_ref(py));
62 Some(PyAgentResponse::new(result, output_type))
63 } else {
64 None
65 };
66 let py_task = WorkflowTask {
67 id: task.id.clone(),
68 prompt: task.prompt,
69 dependencies: task.dependencies,
70 status: task.status,
71 agent_id: task.agent_id,
72 result: py_agent_response,
73 max_retries: task.max_retries,
74 retry_count: task.retry_count,
75 };
76 (id, Py::new(py, py_task).unwrap())
77 })
78 .collect::<HashMap<_, _>>();
79
80 Self {
81 tasks: py_tasks,
82 events,
83 last_task_id,
84 }
85 }
86}
87
88#[pymethods]
89impl WorkflowResult {
90 pub fn __str__(&self) -> String {
91 let json = serde_json::json!({
93 "tasks": serde_json::to_value(&self.tasks).unwrap_or(Value::Null),
94 "events": serde_json::to_value(&self.events).unwrap_or(Value::Null)
95 });
96
97 PyHelperFuncs::__str__(&json)
98 }
99
100 #[getter]
102 pub fn result<'py>(&self, py: Python<'py>) -> Result<Bound<'py, PyAny>, AgentError> {
103 if let Some(last_task_id) = &self.last_task_id {
104 if let Some(task) = self.tasks.get(last_task_id) {
105 let result = task.bind(py).getattr("result")?;
106 return Ok(result);
107 }
108 }
109 Ok(py.None().bind(py).clone())
110 }
111}
112
/// A named set of tasks together with the agents that execute them.
#[derive(Debug, Clone)]
pub struct Workflow {
    /// Workflow id; `Workflow::new` and `get_new_workflow` fill it from `create_uuid7`.
    pub id: String,
    /// Human-readable workflow name.
    pub name: String,
    /// The tasks of the workflow and their dependencies.
    pub task_list: TaskList,
    /// Agents keyed by agent id; each task refers to one through its `agent_id`.
    pub agents: HashMap<String, Arc<Agent>>,
    /// Records task start/completion/failure events. Not serialized.
    pub event_tracker: Arc<RwLock<EventTracker>>,
    /// Optional JSON context passed to every task executed in a run. Not serialized.
    pub global_context: Option<Arc<Value>>,
}
123
124impl PartialEq for Workflow {
125 fn eq(&self, other: &Self) -> bool {
126 self.id == other.id && self.name == other.name
128 }
129}
130
131impl Workflow {
132 pub async fn reset_agents(&mut self) -> Result<(), WorkflowError> {
142 let mut agents_map = self.agents.clone();
143
144 for agent in self.agents.values_mut() {
145 agents_map.insert(agent.id.clone(), Arc::new(agent.rebuild_client().await?));
146 }
147 self.agents = agents_map;
148 Ok(())
149 }
150 pub fn new(name: &str) -> Self {
151 debug!("Creating new workflow: {}", name);
152 let id = create_uuid7();
153 Self {
154 id: id.clone(),
155 name: name.to_string(),
156 task_list: TaskList::new(),
157 agents: HashMap::new(),
158 event_tracker: Arc::new(RwLock::new(EventTracker::new(id))),
159 global_context: None, }
161 }
162 pub fn events(&self) -> Vec<TaskEvent> {
163 let tracker = self.event_tracker.read().unwrap();
164 let events = tracker.events.read().unwrap().clone();
165 events
166 }
167
168 pub fn total_duration(&self) -> i32 {
169 let tracker = self.event_tracker.read().unwrap();
170
171 if tracker.is_empty() {
172 0
173 } else {
174 let mut total_duration = chrono::Duration::zero();
176 for event in tracker.events.read().unwrap().iter() {
177 total_duration += event.details.duration.unwrap_or(chrono::Duration::zero());
178 }
179 total_duration.subsec_millis()
180 }
181 }
182
183 pub fn get_new_workflow(
184 &self,
185 global_context: Option<Arc<Value>>,
186 ) -> Result<Self, WorkflowError> {
187 let id = create_uuid7();
189
190 let task_list = self.task_list.deep_clone()?;
192
193 Ok(Workflow {
194 id: id.clone(),
195 name: self.name.clone(),
196 task_list,
197 agents: self.agents.clone(), event_tracker: Arc::new(RwLock::new(EventTracker::new(id))),
199 global_context,
200 })
201 }
202
203 pub async fn run(
204 &self,
205 global_context: Option<Value>,
206 ) -> Result<Arc<RwLock<Workflow>>, WorkflowError> {
207 debug!("Running workflow: {}", self.name);
208
209 let global_context = global_context.map(Arc::new);
210 let run_workflow = Arc::new(RwLock::new(self.get_new_workflow(global_context)?));
211
212 execute_workflow(&run_workflow).await?;
213
214 Ok(run_workflow)
215 }
216
217 pub fn is_complete(&self) -> bool {
218 self.task_list.is_complete()
219 }
220
221 pub fn pending_count(&self) -> usize {
222 self.task_list.pending_count()
223 }
224
225 pub fn add_task(&mut self, task: Task) -> Result<(), WorkflowError> {
226 self.task_list.add_task(task)
227 }
228
229 pub fn add_tasks(&mut self, tasks: Vec<Task>) -> Result<(), WorkflowError> {
230 for task in tasks {
231 self.task_list.add_task(task)?;
232 }
233 Ok(())
234 }
235
236 pub fn add_agent(&mut self, agent: &Agent) {
237 self.agents
238 .insert(agent.id.clone(), Arc::new(agent.clone()));
239 }
240
241 pub fn add_agents(&mut self, agents: &[&Agent]) {
242 for agent in agents {
243 self.add_agent(agent);
244 }
245 }
246
247 pub async fn execute_task(&self, task: &str, context: &Value) -> Result<Value, WorkflowError> {
248 let task = self
249 .task_list
250 .get_task(task)
251 .ok_or_else(|| WorkflowError::TaskNotFound(task.to_string()))?;
252
253 let agent = {
254 let task_guard = task.read().map_err(|_| WorkflowError::TaskLockError)?;
255 self.agents
256 .get(&task_guard.agent_id)
257 .ok_or_else(|| WorkflowError::AgentNotFound(task_guard.agent_id.clone()))?
258 .clone()
259 };
260
261 let max_retries = {
262 let task_guard = task.read().map_err(|_| WorkflowError::TaskLockError)?;
263 task_guard.retry_count
264 };
265
266 for attempt in 0..=max_retries {
267 match agent.execute_task_with_context(&task, context).await {
268 Ok(response) => {
269 let is_valid = validate_response_schema(&task, &response);
270
271 if !is_valid {
272 if attempt == max_retries {
273 let (task_id, expected_schema, received_response) = {
274 let task_guard = task.read().unwrap();
275 (
276 task_guard.id.clone(),
277 task_guard
278 .prompt
279 .response_json_schema()
280 .map(|s| s.to_string())
281 .unwrap_or_else(|| "No schema".to_string()),
282 response
283 .response_value()
284 .map(|v| v.to_string())
285 .unwrap_or_else(|| "No response".to_string()),
286 )
287 };
288
289 error!(
290 "Task {} response validation failed after {} attempts",
291 task_id,
292 max_retries + 1
293 );
294
295 return Err(WorkflowError::ResponseValidationFailed {
296 task_id,
297 expected_schema,
298 received_response,
299 });
300 }
301 warn!(
302 "Task validation failed (attempt {}/{}), retrying...",
303 attempt + 1,
304 max_retries + 1
305 );
306 continue;
307 }
308
309 return Ok(response.response_value().unwrap_or(Value::Null));
310 }
311 Err(e) => {
312 let task_id = { task.read().unwrap().id.clone() };
313 warn!(
314 "Task {} execution failed (attempt {}/{}): {}",
315 task_id,
316 attempt + 1,
317 max_retries + 1,
318 e
319 );
320
321 if attempt == max_retries {
322 error!("Task {} exceeded max retries ({})", task_id, max_retries);
323 return Err(WorkflowError::MaxRetriesExceeded(task_id));
324 }
325 }
326 }
327 }
328
329 unreachable!("Loop should always return via Ok or error")
330 }
331
332 pub fn execution_plan(&self) -> Result<HashMap<i32, HashSet<String>>, WorkflowError> {
333 let mut remaining: HashMap<String, HashSet<String>> = self
334 .task_list
335 .tasks
336 .iter()
337 .map(|(id, task)| {
338 (
339 id.clone(),
340 task.read().unwrap().dependencies.iter().cloned().collect(),
341 )
342 })
343 .collect();
344
345 let mut executed = HashSet::new();
346 let mut plan = HashMap::new();
347 let mut step = 1;
348
349 while !remaining.is_empty() {
350 let ready_keys: Vec<String> = remaining
352 .iter()
353 .filter(|(_, deps)| deps.is_subset(&executed))
354 .map(|(id, _)| id.to_string())
355 .collect();
356
357 if ready_keys.is_empty() {
358 break;
360 }
361
362 let mut ready_set = HashSet::with_capacity(ready_keys.len());
364
365 for key in ready_keys {
367 executed.insert(key.clone());
368 remaining.remove(&key);
369 ready_set.insert(key);
370 }
371
372 plan.insert(step, ready_set);
374
375 step += 1;
376 }
377
378 Ok(plan)
379 }
380
381 pub fn __str__(&self) -> String {
382 PyHelperFuncs::__str__(&self.task_list)
383 }
384
385 pub fn serialize(&self) -> Result<String, serde_json::Error> {
386 let json = serde_json::to_string(self).unwrap();
388 Ok(json)
390 }
391
392 pub fn from_json(json: &str) -> Result<Self, WorkflowError> {
393 Ok(serde_json::from_str(json)?)
395 }
396
397 pub fn task_names(&self) -> Vec<String> {
398 self.task_list
399 .tasks
400 .keys()
401 .cloned()
402 .collect::<Vec<String>>()
403 }
404
405 pub fn last_task_id(&self) -> Option<String> {
406 self.task_list.get_last_task_id()
407 }
408}
409
410fn is_workflow_complete(workflow: &Arc<RwLock<Workflow>>) -> bool {
415 workflow.read().unwrap().is_complete()
416}
417
418fn reset_failed_workflow_tasks(workflow: &Arc<RwLock<Workflow>>) -> Result<(), WorkflowError> {
423 match workflow.write().unwrap().task_list.reset_failed_tasks() {
424 Ok(_) => Ok(()),
425 Err(e) => {
426 warn!("Failed to reset failed tasks: {}", e);
427 Err(e)
428 }
429 }
430}
431
432fn get_ready_tasks(workflow: &Arc<RwLock<Workflow>>) -> Vec<Arc<RwLock<Task>>> {
437 workflow.read().unwrap().task_list.get_ready_tasks()
438}
439
440fn check_for_circular_dependencies(workflow: &Arc<RwLock<Workflow>>) -> bool {
445 let pending_count = workflow.read().unwrap().pending_count();
446
447 if pending_count > 0 {
448 warn!(
449 "No ready tasks found but {} pending tasks remain. Possible circular dependency.",
450 pending_count
451 );
452 return true;
453 }
454
455 false
456}
457
458fn mark_task_as_running(task: Arc<RwLock<Task>>, event_tracker: &Arc<RwLock<EventTracker>>) {
463 let mut task = task.write().unwrap();
464 task.set_status(TaskStatus::Running);
465 event_tracker.write().unwrap().record_task_started(&task.id);
466}
467
468fn get_agent_for_task(
473 workflow: &Arc<RwLock<Workflow>>,
474 agent_id: &str,
475) -> Result<Arc<Agent>, WorkflowError> {
476 let wf = workflow.read().unwrap();
477 match wf.agents.get(agent_id) {
478 Some(agent) => Ok(agent.clone()),
479 None => Err(WorkflowError::AgentNotFound(agent_id.to_string())),
480 }
481}
482
483#[instrument(skip_all)]
498fn build_task_context(
499 workflow: &Arc<RwLock<Workflow>>,
500 task_dependencies: &Vec<String>,
501 provider: &Provider,
502) -> Result<Context, WorkflowError> {
503 let wf = workflow.read().unwrap();
504 let mut ctx = HashMap::new();
505 let mut param_ctx: Value = Value::Object(Map::new());
506
507 for dep_id in task_dependencies {
508 debug!("Building context for task dependency: {}", dep_id);
509 if let Some(dep) = wf.task_list.get_task(dep_id) {
510 if let Some(result) = &dep.read().unwrap().result {
511 let msg_to_insert = result.response.to_message_num(provider);
512
513 match msg_to_insert {
514 Ok(message) => {
515 ctx.insert(dep_id.clone(), message);
516 }
517 Err(e) => {
518 warn!("Failed to convert response to message: {}", e);
519 }
520 }
521
522 if let Some(structure_output) = result.response.extract_structured_data() {
523 if structure_output.is_object() {
526 update_serde_map_with(&mut param_ctx, &structure_output)?;
528 }
529 }
530 }
531 }
532 }
533
534 debug!("Built context for task dependencies: {:?}", ctx);
535 let global_context = workflow
536 .read()
537 .unwrap()
538 .global_context
539 .as_ref()
540 .map(Arc::clone);
541
542 Ok((ctx, param_ctx, global_context))
543}
544
545fn validate_response_schema(
553 task: &Arc<RwLock<Task>>,
554 response: &potato_agent::AgentResponse,
555) -> bool {
556 task.read()
557 .ok()
558 .and_then(|t| {
559 response
560 .response_value()
561 .map(|value| t.validate_output(&value).is_ok())
562 })
563 .unwrap_or(true)
564}
565
566fn spawn_task_execution(
575 event_tracker: Arc<RwLock<EventTracker>>,
576 task: Arc<RwLock<Task>>,
577 task_id: String,
578 agent: Arc<Agent>,
579 context: HashMap<String, Vec<MessageNum>>,
580 parameter_context: Value,
581 global_context: Option<Arc<Value>>,
582) -> tokio::task::JoinHandle<()> {
583 tokio::spawn(async move {
584 let result = agent
585 .execute_task_with_context_message(&task, context, parameter_context, global_context)
586 .await;
587
588 match result {
589 Ok(response) => {
590 info!("Task {} completed successfully", task_id);
591
592 let is_valid = validate_response_schema(&task, &response); if !is_valid {
594 error!(
595 "Task {} response validation against JSON schema failed",
596 task_id
597 );
598
599 if let Ok(mut write_task) = task.write() {
600 write_task.set_status(TaskStatus::Failed);
601
602 if let Ok(tracker) = event_tracker.write() {
603 tracker.record_task_failed(
604 &write_task.id,
605 "Response JSON schema validation failed",
606 &write_task.prompt,
607 );
608 }
609 }
610 return;
611 }
612
613 if let Ok(mut write_task) = task.write() {
614 write_task.set_status(TaskStatus::Completed);
615 write_task.set_result(response.clone());
616
617 if let Ok(tracker) = event_tracker.write() {
618 tracker.record_task_completed(&write_task.id, &write_task.prompt, response);
619 }
620 }
621 }
622 Err(e) => {
623 error!("Task {} failed: {}", task_id, e);
624
625 if let Ok(mut write_task) = task.write() {
626 write_task.set_status(TaskStatus::Failed);
627
628 if let Ok(tracker) = event_tracker.write() {
629 tracker.record_task_failed(
630 &write_task.id,
631 &e.to_string(),
632 &write_task.prompt,
633 );
634 }
635 }
636 }
637 }
638 })
639}
640
641fn get_parameters_from_context(task: Arc<RwLock<Task>>) -> (String, Vec<String>, String, Provider) {
642 let (task_id, dependencies, agent_id, provider) = {
643 let task_guard = task.read().unwrap();
644 (
645 task_guard.id.clone(),
646 task_guard.dependencies.clone(),
647 task_guard.agent_id.clone(),
648 task_guard.prompt.provider.clone(),
649 )
650 };
651
652 (task_id, dependencies, agent_id, provider)
653}
654
655fn spawn_task_executions(
661 workflow: &Arc<RwLock<Workflow>>,
662 ready_tasks: Vec<Arc<RwLock<Task>>>,
663) -> Result<Vec<tokio::task::JoinHandle<()>>, WorkflowError> {
664 let mut handles = Vec::with_capacity(ready_tasks.len());
665
666 let event_tracker = workflow.read().unwrap().event_tracker.clone();
668
669 for task in ready_tasks {
670 let (task_id, dependencies, agent_id, provider) = get_parameters_from_context(task.clone());
672
673 mark_task_as_running(task.clone(), &event_tracker);
676
677 let (context, parameter_context, global_context) =
683 build_task_context(workflow, &dependencies, &provider)?;
684
685 let agent = get_agent_for_task(workflow, &agent_id)?;
687
688 let handle = spawn_task_execution(
690 event_tracker.clone(),
691 task.clone(),
692 task_id,
693 agent,
694 context,
695 parameter_context,
696 global_context,
697 );
698 handles.push(handle);
699 }
700
701 Ok(handles)
702}
703
704async fn await_task_completions(handles: Vec<tokio::task::JoinHandle<()>>) {
709 for handle in handles {
710 if let Err(e) = handle.await {
711 warn!("Task execution failed: {}", e);
712 }
713 }
714}
715
716#[instrument(skip_all)]
729pub async fn execute_workflow(workflow: &Arc<RwLock<Workflow>>) -> Result<(), WorkflowError> {
730 debug!("Starting workflow execution");
734
735 while !is_workflow_complete(workflow) {
737 reset_failed_workflow_tasks(workflow)?;
740
741 let ready_tasks = get_ready_tasks(workflow);
744 debug!("Found {} ready tasks for execution", ready_tasks.len());
745
746 if ready_tasks.is_empty() {
748 if check_for_circular_dependencies(workflow) {
749 break;
750 }
751 continue;
752 }
753
754 let handles = spawn_task_executions(workflow, ready_tasks)?;
756
757 await_task_completions(handles).await;
759 }
760
761 debug!("Workflow execution completed");
762 Ok(())
763}
764
765impl Serialize for Workflow {
766 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
767 where
768 S: Serializer,
769 {
770 let mut state = serializer.serialize_struct("Workflow", 4)?;
771
772 state.serialize_field("id", &self.id)?;
774 state.serialize_field("name", &self.name)?;
775 state.serialize_field("task_list", &self.task_list)?;
776 state.serialize_field("agents", &self.agents)?;
777
778 state.end()
779 }
780}
781
782impl<'de> Deserialize<'de> for Workflow {
783 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
784 where
785 D: Deserializer<'de>,
786 {
787 #[derive(Deserialize)]
788 #[serde(field_identifier, rename_all = "snake_case")]
789 enum Field {
790 Id,
791 Name,
792 TaskList,
793 Agents,
794 }
795
796 struct WorkflowVisitor;
797
798 impl<'de> Visitor<'de> for WorkflowVisitor {
799 type Value = Workflow;
800
801 fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result {
802 formatter.write_str("struct Workflow")
803 }
804
805 fn visit_map<V>(self, mut map: V) -> Result<Workflow, V::Error>
806 where
807 V: MapAccess<'de>,
808 {
809 let mut id = None;
810 let mut name = None;
811 let mut task_list_data = None;
812 let mut agents: Option<HashMap<String, Agent>> = None;
813
814 while let Some(key) = map.next_key()? {
815 match key {
816 Field::Id => {
817 let value: String = map.next_value().map_err(|e| {
818 error!("Failed to deserialize field 'id': {e}");
819 de::Error::custom(format!("Failed to deserialize field 'id': {e}"))
820 })?;
821 id = Some(value);
822 }
823 Field::TaskList => {
824 let value: TaskList = map.next_value().map_err(|e| {
826 error!("Failed to deserialize field 'task_list': {e}");
827 de::Error::custom(format!(
828 "Failed to deserialize field 'task_list': {e}",
829 ))
830 })?;
831
832 task_list_data = Some(value);
833 }
834 Field::Name => {
835 let value: String = map.next_value().map_err(|e| {
836 error!("Failed to deserialize field 'name': {e}");
837 de::Error::custom(format!(
838 "Failed to deserialize field 'name': {e}",
839 ))
840 })?;
841 name = Some(value);
842 }
843 Field::Agents => {
844 let value: HashMap<String, Agent> = map.next_value().map_err(|e| {
845 error!("Failed to deserialize field 'agents': {e}");
846 de::Error::custom(format!(
847 "Failed to deserialize field 'agents': {e}"
848 ))
849 })?;
850 agents = Some(value);
851 }
852 }
853 }
854
855 let id = id.ok_or_else(|| de::Error::missing_field("id"))?;
856 let name = name.ok_or_else(|| de::Error::missing_field("name"))?;
857 let task_list_data =
858 task_list_data.ok_or_else(|| de::Error::missing_field("task_list"))?;
859 let agents = agents.ok_or_else(|| de::Error::missing_field("agents"))?;
860
861 let event_tracker = Arc::new(RwLock::new(EventTracker::new(create_uuid7())));
862
863 let agents = agents
865 .into_iter()
866 .map(|(id, agent)| (id, Arc::new(agent)))
867 .collect();
868
869 Ok(Workflow {
870 id,
871 name,
872 task_list: task_list_data,
873 agents,
874 event_tracker,
875 global_context: None, })
877 }
878 }
879
880 const FIELDS: &[&str] = &["id", "name", "task_list", "agents"];
881 deserializer.deserialize_struct("Workflow", FIELDS, WorkflowVisitor)
882 }
883}
884
/// Python-facing wrapper (exposed to Python as `Workflow`) around a Rust [`Workflow`].
#[pyclass(name = "Workflow")]
#[derive(Debug, Clone)]
pub struct PyWorkflow {
    /// The wrapped workflow definition.
    workflow: Workflow,

    /// Python output types keyed by task id; passed to `WorkflowResult::new` so task results
    /// can carry their declared output type.
    output_types: HashMap<String, Arc<Py<PyAny>>>,
}
895
896#[pymethods]
897impl PyWorkflow {
898 #[new]
899 #[pyo3(signature = (name))]
900 pub fn new(name: &str) -> Result<Self, WorkflowError> {
901 debug!("Creating new workflow: {}", name);
902 Ok(Self {
903 workflow: Workflow::new(name),
904 output_types: HashMap::new(),
905 })
906 }
907
908 #[getter]
909 pub fn name(&self) -> String {
910 self.workflow.name.clone()
911 }
912
913 #[getter]
914 pub fn task_list(&self) -> TaskList {
915 self.workflow.task_list.clone()
916 }
917
918 #[getter]
919 pub fn is_workflow(&self) -> bool {
920 true
921 }
922
923 #[getter]
924 pub fn __workflow__(&self) -> Result<String, WorkflowError> {
925 self.model_dump_json()
926 }
927
928 #[getter]
929 pub fn agents(&self) -> Result<HashMap<String, PyAgent>, WorkflowError> {
930 self.workflow
931 .agents
932 .iter()
933 .map(|(id, agent)| {
934 Ok((
935 id.clone(),
936 PyAgent {
937 agent: agent.clone(),
938 },
939 ))
940 })
941 .collect::<Result<HashMap<_, _>, _>>()
942 }
943
944 #[pyo3(signature = (task_output_types))]
945 pub fn add_task_output_types<'py>(
946 &mut self,
947 task_output_types: Bound<'py, PyDict>,
948 ) -> PyResult<()> {
949 let converted: HashMap<String, Arc<Py<PyAny>>> = task_output_types
950 .iter()
951 .map(|(k, v)| -> PyResult<(String, Arc<Py<PyAny>>)> {
952 let key = k.extract::<String>()?;
954 let value = v.clone().unbind();
955 Ok((key, Arc::new(value)))
956 })
957 .collect::<PyResult<_>>()?;
958 self.output_types.extend(converted);
959 Ok(())
960 }
961
962 #[pyo3(signature = (task, output_type = None))]
963 pub fn add_task(
964 &mut self,
965 py: Python<'_>,
966 mut task: Task,
967 output_type: Option<Bound<'_, PyAny>>,
968 ) -> Result<(), WorkflowError> {
969 if let Some(output_type) = output_type {
970 let (response_type, response_json_schema) = parse_response_to_json(py, &output_type)
972 .map_err(|e| WorkflowError::InvalidOutputType(e.to_string()))?;
973
974 task.prompt
976 .set_response_json_schema(response_json_schema, response_type);
977
978 self.output_types
980 .insert(task.id.clone(), Arc::new(output_type.unbind()));
981 }
982
983 self.workflow.task_list.add_task(task)?;
984 Ok(())
985 }
986
987 pub fn add_tasks(&mut self, tasks: Vec<Task>) -> Result<(), WorkflowError> {
988 for task in tasks {
989 self.workflow.task_list.add_task(task)?;
990 }
991 Ok(())
992 }
993
994 pub fn add_agent(&mut self, agent: &Bound<'_, PyAgent>) {
995 let agent = agent.extract::<PyAgent>().unwrap().agent.clone();
997 self.workflow.agents.insert(agent.id.clone(), agent);
998 }
999
1000 pub fn add_agents(&mut self, agents: Vec<Bound<'_, PyAgent>>) {
1001 for agent in agents {
1002 self.add_agent(&agent);
1003 }
1004 }
1005
1006 pub fn is_complete(&self) -> bool {
1007 self.workflow.task_list.is_complete()
1008 }
1009
1010 pub fn pending_count(&self) -> usize {
1011 self.workflow.task_list.pending_count()
1012 }
1013
1014 pub fn execution_plan<'py>(&self, py: Python<'py>) -> Result<Bound<'py, PyAny>, WorkflowError> {
1015 let plan = self.workflow.execution_plan()?;
1016 debug!("Execution plan: {:?}", plan);
1017
1018 let json = serde_json::to_value(plan).map_err(|e| {
1020 error!("Failed to serialize execution plan to JSON: {}", e);
1021 e
1022 })?;
1023
1024 Ok(pythonize(py, &json)?)
1025 }
1026
1027 #[pyo3(signature = (global_context=None))]
1028 pub fn run(
1029 &self,
1030 py: Python,
1031 global_context: Option<Bound<'_, PyAny>>,
1032 ) -> Result<WorkflowResult, WorkflowError> {
1033 debug!("Running workflow: {}", self.workflow.name);
1034
1035 let global_context = if let Some(context) = global_context {
1037 let json_value = depythonize_object_to_value(py, &context)?;
1038 Some(json_value)
1039 } else {
1040 None
1041 };
1042
1043 let workflow: Arc<RwLock<Workflow>> =
1044 block_on(async { self.workflow.run(global_context).await })?;
1045
1046 let workflow_result = match Arc::try_unwrap(workflow) {
1048 Ok(rwlock) => {
1050 let workflow = rwlock
1052 .into_inner()
1053 .map_err(|_| WorkflowError::LockAcquireError)?;
1054
1055 let events = workflow
1057 .event_tracker
1058 .read()
1059 .unwrap()
1060 .events
1061 .read()
1062 .unwrap()
1063 .clone();
1064
1065 WorkflowResult::new(
1067 py,
1068 workflow.task_list.tasks(),
1069 &self.output_types,
1070 events,
1071 workflow.task_list.get_last_task_id(),
1072 )
1073 }
1074 Err(arc) => {
1076 error!("Workflow still has other references, reading instead of consuming.");
1078 let workflow = arc
1079 .read()
1080 .map_err(|_| WorkflowError::ReadLockAcquireError)?;
1081
1082 let events = workflow
1084 .event_tracker
1085 .read()
1086 .unwrap()
1087 .events
1088 .read()
1089 .unwrap()
1090 .clone();
1091
1092 WorkflowResult::new(
1093 py,
1094 workflow.task_list.tasks(),
1095 &self.output_types,
1096 events,
1097 workflow.task_list.get_last_task_id(),
1098 )
1099 }
1100 };
1101
1102 info!("Workflow execution completed successfully.");
1103 Ok(workflow_result)
1104 }
1105
1106 #[pyo3(signature = (task_id, context=None))]
1107 pub fn execute_task<'py>(
1108 &self,
1109 py: Python<'py>,
1110 task_id: String,
1111 context: Option<Bound<'py, PyAny>>,
1112 ) -> Result<Bound<'py, PyAny>, WorkflowError> {
1113 let context_value = if let Some(ctx) = context {
1114 depythonize_object_to_value(py, &ctx)?
1115 } else {
1116 Value::Null
1117 };
1118
1119 let response_value =
1120 block_on(async { self.workflow.execute_task(&task_id, &context_value).await })?;
1121 let py_response = pythonize(py, &response_value)?;
1122
1123 Ok(py_response)
1124 }
1125
1126 pub fn model_dump_json(&self) -> Result<String, WorkflowError> {
1127 Ok(self.workflow.serialize()?)
1128 }
1129
1130 #[staticmethod]
1131 #[pyo3(signature = (json_string, output_types=None))]
1132 pub fn model_validate_json(
1133 json_string: String,
1134 output_types: Option<Bound<'_, PyDict>>,
1135 ) -> Result<Self, WorkflowError> {
1136 let mut workflow: Workflow = Workflow::from_json(&json_string)?;
1137
1138 workflow.task_list.rebuild_task_validators()?;
1140
1141 block_on(async { workflow.reset_agents().await })?;
1144 let output_types = match output_types {
1145 Some(output_types) => output_types
1146 .iter()
1147 .map(|(k, v)| -> PyResult<(String, Arc<Py<PyAny>>)> {
1148 let key = k.extract::<String>()?;
1149 let value = v.clone().unbind();
1150 Ok((key, Arc::new(value)))
1151 })
1152 .collect::<PyResult<HashMap<String, Arc<Py<PyAny>>>>>()?,
1153 None => HashMap::new(),
1154 };
1155
1156 let py_workflow = PyWorkflow {
1157 workflow,
1158 output_types,
1159 };
1160
1161 Ok(py_workflow)
1162 }
1163}
1164
#[cfg(test)]
mod tests {
    use super::*;
    use potato_type::openai::v1::chat::request::{
        ChatMessage as OpenAIChatMessage, ContentPart, TextContentPart,
    };
    use potato_type::prompt::{Prompt, ResponseType};

    /// Builds a single user-role OpenAI chat message with one text part.
    fn create_openai_chat_message_num() -> MessageNum {
        let question = TextContentPart::new("What company is this logo from?".to_string());
        MessageNum::OpenAIMessageV1(OpenAIChatMessage {
            role: "user".to_string(),
            content: vec![ContentPart::Text(question)],
            name: None,
        })
    }

    #[test]
    fn test_workflow_creation() {
        let workflow = Workflow::new("Test Workflow");

        assert_eq!(workflow.name, "Test Workflow");
        // The generated id is expected in 36-character hyphenated form.
        assert_eq!(workflow.id.len(), 36);
    }

    #[test]
    fn test_task_list_add_and_get() {
        let prompt = Prompt::new_rs(
            vec![create_openai_chat_message_num()],
            "gpt-4o",
            potato_type::Provider::OpenAI,
            vec![],
            None,
            None,
            ResponseType::Null,
        )
        .unwrap();
        let task = Task::new("task1", prompt, "task1", None, None).unwrap();

        let mut task_list = TaskList::new();
        task_list.add_task(task.clone()).unwrap();

        let stored_id = task_list
            .get_task(&task.id)
            .unwrap()
            .read()
            .unwrap()
            .id
            .clone();
        assert_eq!(stored_id, task.id);

        task_list.reset_failed_tasks().unwrap();
    }
}