1use std::collections::HashMap;
27use std::fmt;
28use std::sync::Arc;
29use std::time::Instant;
30
31use chrono::{DateTime, Utc};
32use rust_decimal::Decimal;
33use serde_json::Value;
34use tokio::task::{Id, JoinSet};
35use tracing::{error, info};
36use uuid::Uuid;
37
38use ironflow_core::error::{AgentError, OperationError};
39use ironflow_core::provider::AgentProvider;
40use ironflow_store::models::{
41 NewRun, NewStep, NewStepDependency, RunStatus, RunUpdate, Step, StepKind, StepStatus,
42 StepUpdate, TriggerKind,
43};
44use ironflow_store::store::Store;
45
46use crate::config::{
47 AgentStepConfig, ApprovalConfig, HttpConfig, ShellConfig, StepConfig, WorkflowStepConfig,
48};
49use crate::error::EngineError;
50use crate::executor::{ParallelStepResult, StepOutput, execute_step_config};
51use crate::handler::WorkflowHandler;
52use crate::log_sender::{LogSender, StepLogSender};
53use crate::operation::Operation;
54
55pub(crate) type HandlerResolver =
57 Arc<dyn Fn(&str) -> Option<Arc<dyn WorkflowHandler>> + Send + Sync>;
58
59pub struct WorkflowContext {
78 run_id: Uuid,
79 store: Arc<dyn Store>,
80 provider: Arc<dyn AgentProvider>,
81 handler_resolver: Option<HandlerResolver>,
82 position: u32,
83 last_step_ids: Vec<Uuid>,
85 total_cost_usd: Decimal,
87 total_duration_ms: u64,
89 replay_steps: HashMap<u32, Step>,
92 log_sender: Option<LogSender>,
94}
95
96impl WorkflowContext {
97 pub fn new(run_id: Uuid, store: Arc<dyn Store>, provider: Arc<dyn AgentProvider>) -> Self {
102 Self {
103 run_id,
104 store,
105 provider,
106 handler_resolver: None,
107 position: 0,
108 last_step_ids: Vec::new(),
109 total_cost_usd: Decimal::ZERO,
110 total_duration_ms: 0,
111 replay_steps: HashMap::new(),
112 log_sender: None,
113 }
114 }
115
116 pub(crate) fn with_handler_resolver(
121 run_id: Uuid,
122 store: Arc<dyn Store>,
123 provider: Arc<dyn AgentProvider>,
124 resolver: HandlerResolver,
125 ) -> Self {
126 Self {
127 run_id,
128 store,
129 provider,
130 handler_resolver: Some(resolver),
131 position: 0,
132 last_step_ids: Vec::new(),
133 total_cost_usd: Decimal::ZERO,
134 total_duration_ms: 0,
135 replay_steps: HashMap::new(),
136 log_sender: None,
137 }
138 }
139
140 pub fn set_log_sender(&mut self, sender: LogSender) {
142 self.log_sender = Some(sender);
143 }
144
145 pub(crate) async fn load_replay_steps(&mut self) -> Result<(), EngineError> {
151 let steps = self.store.list_steps(self.run_id).await?;
152 for step in steps {
153 let dominated = matches!(
154 step.status.state,
155 StepStatus::Completed | StepStatus::Running | StepStatus::AwaitingApproval
156 );
157 if dominated {
158 self.replay_steps.insert(step.position, step);
159 }
160 }
161 Ok(())
162 }
163
164 pub fn run_id(&self) -> Uuid {
166 self.run_id
167 }
168
169 pub fn total_cost_usd(&self) -> Decimal {
171 self.total_cost_usd
172 }
173
174 pub fn total_duration_ms(&self) -> u64 {
176 self.total_duration_ms
177 }
178
179 pub async fn parallel(
216 &mut self,
217 steps: Vec<(&str, StepConfig)>,
218 fail_fast: bool,
219 ) -> Result<Vec<ParallelStepResult>, EngineError> {
220 if steps.is_empty() {
221 return Ok(Vec::new());
222 }
223
224 let wave_position = self.position;
225 self.position += 1;
226
227 let now = Utc::now();
228 let mut step_records: Vec<(Uuid, String, StepConfig)> = Vec::with_capacity(steps.len());
229
230 for (name, config) in &steps {
231 let kind = config.kind();
232 let step = self
233 .store
234 .create_step(NewStep {
235 run_id: self.run_id,
236 name: name.to_string(),
237 kind,
238 position: wave_position,
239 input: Some(serde_json::to_value(config)?),
240 })
241 .await?;
242
243 self.start_step(step.id, now).await?;
244
245 step_records.push((step.id, name.to_string(), config.clone()));
246 }
247
248 let mut join_set = JoinSet::new();
249 let mut task_index: HashMap<Id, usize> = HashMap::new();
250 for (idx, (step_id, step_name, config)) in step_records.iter().enumerate() {
251 let provider = self.provider.clone();
252 let config = config.clone();
253 let step_log_sender = self
254 .log_sender
255 .as_ref()
256 .map(|s| StepLogSender::new(s.clone(), self.run_id, *step_id, step_name.clone()));
257 let handle = join_set.spawn(async move {
258 (
259 idx,
260 execute_step_config(&config, &provider, step_log_sender).await,
261 )
262 });
263 task_index.insert(handle.id(), idx);
264 }
265
266 let mut indexed_results: Vec<Option<Result<StepOutput, String>>> =
268 vec![None; step_records.len()];
269 let mut first_error: Option<EngineError> = None;
270
271 while let Some(join_result) = join_set.join_next().await {
272 let (idx, step_result) = match join_result {
273 Ok(r) => r,
274 Err(e) => {
275 let error_msg = format!("join error: {e}");
276 if let Some(&idx) = task_index.get(&e.id()) {
277 let (step_id, step_name, _) = &step_records[idx];
278 let completed_at = Utc::now();
279 error!(
280 run_id = %self.run_id,
281 step = %step_name,
282 error = %error_msg,
283 "parallel step panicked or was cancelled"
284 );
285 if let Err(store_err) = self
286 .store
287 .update_step(
288 *step_id,
289 StepUpdate {
290 status: Some(StepStatus::Failed),
291 error: Some(error_msg.clone()),
292 completed_at: Some(completed_at),
293 ..StepUpdate::default()
294 },
295 )
296 .await
297 {
298 error!(
299 run_id = %self.run_id,
300 step_id = %step_id,
301 error = %store_err,
302 "failed to persist JoinError for step"
303 );
304 }
305 indexed_results[idx] = Some(Err(error_msg.clone()));
306 }
307 if first_error.is_none() {
308 first_error = Some(EngineError::StepConfig(error_msg));
309 }
310 if fail_fast {
311 join_set.abort_all();
312 }
313 continue;
314 }
315 };
316
317 let (step_id, step_name, _) = &step_records[idx];
318 let completed_at = Utc::now();
319
320 match step_result {
321 Ok(output) => {
322 self.total_cost_usd += output.cost_usd;
323 self.total_duration_ms += output.duration_ms;
324
325 let debug_messages_json = output.debug_messages_json();
326
327 self.store
328 .update_step(
329 *step_id,
330 StepUpdate {
331 status: Some(StepStatus::Completed),
332 output: Some(output.output.clone()),
333 duration_ms: Some(output.duration_ms),
334 cost_usd: Some(output.cost_usd),
335 input_tokens: output.input_tokens,
336 output_tokens: output.output_tokens,
337 completed_at: Some(completed_at),
338 debug_messages: debug_messages_json,
339 ..StepUpdate::default()
340 },
341 )
342 .await?;
343
344 info!(
345 run_id = %self.run_id,
346 step = %step_name,
347 duration_ms = output.duration_ms,
348 "parallel step completed"
349 );
350
351 indexed_results[idx] = Some(Ok(output));
352 }
353 Err(err) => {
354 let err_msg = err.to_string();
355 let debug_messages_json = extract_debug_messages_from_error(&err);
356 let partial = extract_partial_usage_from_error(&err);
357 let raw_response_output = extract_raw_response_from_error(&err);
358
359 if let Some(ref usage) = partial {
360 if let Some(cost) = usage.cost_usd {
361 self.total_cost_usd += cost;
362 }
363 if let Some(dur) = usage.duration_ms {
364 self.total_duration_ms += dur;
365 }
366 }
367
368 if let Err(store_err) = self
369 .store
370 .update_step(
371 *step_id,
372 StepUpdate {
373 status: Some(StepStatus::Failed),
374 error: Some(err_msg.clone()),
375 output: raw_response_output,
376 completed_at: Some(completed_at),
377 debug_messages: debug_messages_json,
378 duration_ms: partial.as_ref().and_then(|p| p.duration_ms),
379 cost_usd: partial.as_ref().and_then(|p| p.cost_usd),
380 input_tokens: partial.as_ref().and_then(|p| p.input_tokens),
381 output_tokens: partial.as_ref().and_then(|p| p.output_tokens),
382 ..StepUpdate::default()
383 },
384 )
385 .await
386 {
387 tracing::error!(
388 step_id = %step_id,
389 error = %store_err,
390 "failed to persist parallel step failure"
391 );
392 }
393
394 indexed_results[idx] = Some(Err(err_msg.clone()));
395
396 if first_error.is_none() {
397 first_error = Some(err);
398 }
399
400 if fail_fast {
401 join_set.abort_all();
402 }
403 }
404 }
405 }
406
407 if let Some(err) = first_error {
408 return Err(err);
409 }
410
411 self.last_step_ids = step_records.iter().map(|(id, _, _)| *id).collect();
412
413 let results: Vec<ParallelStepResult> = step_records
415 .iter()
416 .enumerate()
417 .map(|(idx, (step_id, name, _))| {
418 let output = match indexed_results[idx].take() {
419 Some(Ok(o)) => o,
420 _ => unreachable!("all steps succeeded if no error returned"),
421 };
422 ParallelStepResult {
423 name: name.clone(),
424 output,
425 step_id: *step_id,
426 }
427 })
428 .collect();
429
430 Ok(results)
431 }
432
433 pub async fn shell(
456 &mut self,
457 name: &str,
458 config: ShellConfig,
459 ) -> Result<StepOutput, EngineError> {
460 self.execute_step(name, StepKind::Shell, StepConfig::Shell(config))
461 .await
462 }
463
464 pub async fn http(
484 &mut self,
485 name: &str,
486 config: HttpConfig,
487 ) -> Result<StepOutput, EngineError> {
488 self.execute_step(name, StepKind::Http, StepConfig::Http(config))
489 .await
490 }
491
492 pub async fn agent(
512 &mut self,
513 name: &str,
514 config: impl Into<AgentStepConfig>,
515 ) -> Result<StepOutput, EngineError> {
516 self.execute_step(name, StepKind::Agent, StepConfig::Agent(config.into()))
517 .await
518 }
519
520 pub async fn approval(
551 &mut self,
552 name: &str,
553 config: ApprovalConfig,
554 ) -> Result<(), EngineError> {
555 let position = self.position;
556 self.position += 1;
557
558 if let Some(existing) = self.replay_steps.get(&position)
561 && existing.kind == StepKind::Approval
562 {
563 if existing.status.state == StepStatus::AwaitingApproval {
564 self.store
565 .update_step(
566 existing.id,
567 StepUpdate {
568 status: Some(StepStatus::Completed),
569 completed_at: Some(Utc::now()),
570 ..StepUpdate::default()
571 },
572 )
573 .await?;
574 }
575
576 self.last_step_ids = vec![existing.id];
577 info!(
578 run_id = %self.run_id,
579 step = %name,
580 position,
581 "approval step replayed (approved)"
582 );
583 return Ok(());
584 }
585
586 let step = self
588 .store
589 .create_step(NewStep {
590 run_id: self.run_id,
591 name: name.to_string(),
592 kind: StepKind::Approval,
593 position,
594 input: Some(serde_json::to_value(&config)?),
595 })
596 .await?;
597
598 self.start_step(step.id, Utc::now()).await?;
599
600 self.store
603 .update_step(
604 step.id,
605 StepUpdate {
606 status: Some(StepStatus::AwaitingApproval),
607 ..StepUpdate::default()
608 },
609 )
610 .await?;
611
612 self.last_step_ids = vec![step.id];
613
614 Err(EngineError::ApprovalRequired {
615 run_id: self.run_id,
616 step_id: step.id,
617 message: config.message().to_string(),
618 })
619 }
620
621 pub async fn skip(&mut self, name: &str, reason: &str) -> Result<(), EngineError> {
650 let position = self.position;
651 self.position += 1;
652
653 let step = self
654 .store
655 .create_step(NewStep {
656 run_id: self.run_id,
657 name: name.to_string(),
658 kind: StepKind::Custom("skip".to_string()),
659 position,
660 input: None,
661 })
662 .await?;
663
664 if !self.last_step_ids.is_empty() {
665 let deps: Vec<NewStepDependency> = self
666 .last_step_ids
667 .iter()
668 .map(|&depends_on| NewStepDependency {
669 step_id: step.id,
670 depends_on,
671 })
672 .collect();
673 self.store.create_step_dependencies(deps).await?;
674 }
675
676 let now = Utc::now();
677 self.store
678 .update_step(
679 step.id,
680 StepUpdate {
681 status: Some(StepStatus::Skipped),
682 output: Some(serde_json::json!({"reason": reason})),
683 completed_at: Some(now),
684 ..StepUpdate::default()
685 },
686 )
687 .await?;
688
689 self.last_step_ids = vec![step.id];
690
691 info!(
692 run_id = %self.run_id,
693 step = %name,
694 reason,
695 "step skipped"
696 );
697
698 Ok(())
699 }
700
701 pub async fn operation(
739 &mut self,
740 name: &str,
741 op: &dyn Operation,
742 ) -> Result<StepOutput, EngineError> {
743 let kind = StepKind::Custom(op.kind().to_string());
744 let position = self.position;
745 self.position += 1;
746
747 let step = self
748 .store
749 .create_step(NewStep {
750 run_id: self.run_id,
751 name: name.to_string(),
752 kind,
753 position,
754 input: op.input(),
755 })
756 .await?;
757
758 self.start_step(step.id, Utc::now()).await?;
759
760 let start = Instant::now();
761
762 match op.execute().await {
763 Ok(output_value) => {
764 let duration_ms = start.elapsed().as_millis() as u64;
765 self.total_duration_ms += duration_ms;
766
767 let completed_at = Utc::now();
768 self.store
769 .update_step(
770 step.id,
771 StepUpdate {
772 status: Some(StepStatus::Completed),
773 output: Some(output_value.clone()),
774 duration_ms: Some(duration_ms),
775 cost_usd: Some(Decimal::ZERO),
776 completed_at: Some(completed_at),
777 ..StepUpdate::default()
778 },
779 )
780 .await?;
781
782 info!(
783 run_id = %self.run_id,
784 step = %name,
785 kind = op.kind(),
786 duration_ms,
787 "operation step completed"
788 );
789
790 self.last_step_ids = vec![step.id];
791
792 Ok(StepOutput {
793 output: output_value,
794 duration_ms,
795 cost_usd: Decimal::ZERO,
796 input_tokens: None,
797 output_tokens: None,
798 model: None,
799 debug_messages: None,
800 })
801 }
802 Err(err) => {
803 let completed_at = Utc::now();
804 if let Err(store_err) = self
805 .store
806 .update_step(
807 step.id,
808 StepUpdate {
809 status: Some(StepStatus::Failed),
810 error: Some(err.to_string()),
811 completed_at: Some(completed_at),
812 ..StepUpdate::default()
813 },
814 )
815 .await
816 {
817 error!(step_id = %step.id, error = %store_err, "failed to persist step failure");
818 }
819
820 Err(err)
821 }
822 }
823 }
824
825 pub async fn workflow(
852 &mut self,
853 handler: &dyn WorkflowHandler,
854 payload: Value,
855 ) -> Result<StepOutput, EngineError> {
856 let config = WorkflowStepConfig::new(handler.name(), payload);
857 let position = self.position;
858 self.position += 1;
859
860 let step = self
861 .store
862 .create_step(NewStep {
863 run_id: self.run_id,
864 name: config.workflow_name.clone(),
865 kind: StepKind::Workflow,
866 position,
867 input: Some(serde_json::to_value(&config)?),
868 })
869 .await?;
870
871 self.start_step(step.id, Utc::now()).await?;
872
873 match self.execute_child_workflow(&config).await {
874 Ok(output) => {
875 self.total_cost_usd += output.cost_usd;
876 self.total_duration_ms += output.duration_ms;
877
878 let completed_at = Utc::now();
879 self.store
880 .update_step(
881 step.id,
882 StepUpdate {
883 status: Some(StepStatus::Completed),
884 output: Some(output.output.clone()),
885 duration_ms: Some(output.duration_ms),
886 cost_usd: Some(output.cost_usd),
887 completed_at: Some(completed_at),
888 ..StepUpdate::default()
889 },
890 )
891 .await?;
892
893 info!(
894 run_id = %self.run_id,
895 child_workflow = %config.workflow_name,
896 duration_ms = output.duration_ms,
897 "workflow step completed"
898 );
899
900 self.last_step_ids = vec![step.id];
901
902 Ok(output)
903 }
904 Err(err) => {
905 let completed_at = Utc::now();
906 if let Err(store_err) = self
907 .store
908 .update_step(
909 step.id,
910 StepUpdate {
911 status: Some(StepStatus::Failed),
912 error: Some(err.to_string()),
913 completed_at: Some(completed_at),
914 ..StepUpdate::default()
915 },
916 )
917 .await
918 {
919 error!(step_id = %step.id, error = %store_err, "failed to persist step failure");
920 }
921
922 Err(err)
923 }
924 }
925 }
926
927 async fn execute_child_workflow(
929 &self,
930 config: &WorkflowStepConfig,
931 ) -> Result<StepOutput, EngineError> {
932 let resolver = self.handler_resolver.as_ref().ok_or_else(|| {
933 EngineError::InvalidWorkflow(
934 "sub-workflow requires a handler resolver (use Engine to execute)".to_string(),
935 )
936 })?;
937
938 let handler = resolver(&config.workflow_name).ok_or_else(|| {
939 EngineError::InvalidWorkflow(format!("no handler registered: {}", config.workflow_name))
940 })?;
941
942 let parent_labels = self
943 .store
944 .get_run(self.run_id)
945 .await?
946 .map(|r| r.labels)
947 .unwrap_or_default();
948
949 let child_run = self
950 .store
951 .create_run(NewRun {
952 workflow_name: config.workflow_name.clone(),
953 trigger: TriggerKind::Workflow,
954 payload: config.payload.clone(),
955 max_retries: 0,
956 handler_version: None,
957 labels: parent_labels,
958 scheduled_at: None,
959 })
960 .await?;
961
962 let child_run_id = child_run.id;
963 info!(
964 parent_run_id = %self.run_id,
965 child_run_id = %child_run_id,
966 workflow = %config.workflow_name,
967 "child run created"
968 );
969
970 self.store
971 .update_run_status(child_run_id, RunStatus::Running)
972 .await?;
973
974 let run_start = Instant::now();
975 let mut child_ctx = WorkflowContext {
976 run_id: child_run_id,
977 store: self.store.clone(),
978 provider: self.provider.clone(),
979 handler_resolver: self.handler_resolver.clone(),
980 position: 0,
981 last_step_ids: Vec::new(),
982 total_cost_usd: Decimal::ZERO,
983 total_duration_ms: 0,
984 replay_steps: HashMap::new(),
985 log_sender: self.log_sender.clone(),
986 };
987
988 let result = handler.execute(&mut child_ctx).await;
989 let total_duration = run_start.elapsed().as_millis() as u64;
990 let completed_at = Utc::now();
991
992 match result {
993 Ok(()) => {
994 self.store
995 .update_run(
996 child_run_id,
997 RunUpdate {
998 status: Some(RunStatus::Completed),
999 cost_usd: Some(child_ctx.total_cost_usd),
1000 duration_ms: Some(total_duration),
1001 completed_at: Some(completed_at),
1002 ..RunUpdate::default()
1003 },
1004 )
1005 .await?;
1006
1007 Ok(StepOutput {
1008 output: serde_json::json!({
1009 "run_id": child_run_id,
1010 "workflow_name": config.workflow_name,
1011 "status": RunStatus::Completed,
1012 "cost_usd": child_ctx.total_cost_usd,
1013 "duration_ms": total_duration,
1014 }),
1015 duration_ms: total_duration,
1016 cost_usd: child_ctx.total_cost_usd,
1017 input_tokens: None,
1018 output_tokens: None,
1019 model: None,
1020 debug_messages: None,
1021 })
1022 }
1023 Err(err) => {
1024 if let Err(store_err) = self
1025 .store
1026 .update_run(
1027 child_run_id,
1028 RunUpdate {
1029 status: Some(RunStatus::Failed),
1030 error: Some(err.to_string()),
1031 cost_usd: Some(child_ctx.total_cost_usd),
1032 duration_ms: Some(total_duration),
1033 completed_at: Some(completed_at),
1034 ..RunUpdate::default()
1035 },
1036 )
1037 .await
1038 {
1039 error!(
1040 child_run_id = %child_run_id,
1041 store_error = %store_err,
1042 "failed to persist child run failure"
1043 );
1044 }
1045
1046 Err(err)
1047 }
1048 }
1049 }
1050
1051 fn try_replay_step(&mut self, position: u32) -> Option<StepOutput> {
1056 let step = self.replay_steps.get(&position)?;
1057 if step.status.state != StepStatus::Completed {
1058 return None;
1059 }
1060 let output = StepOutput {
1061 output: step.output.clone().unwrap_or(Value::Null),
1062 duration_ms: step.duration_ms,
1063 cost_usd: step.cost_usd,
1064 input_tokens: step.input_tokens,
1065 output_tokens: step.output_tokens,
1066 model: None,
1067 debug_messages: None,
1068 };
1069 self.total_cost_usd += output.cost_usd;
1070 self.total_duration_ms += output.duration_ms;
1071 self.last_step_ids = vec![step.id];
1072 info!(
1073 run_id = %self.run_id,
1074 step = %step.name,
1075 position,
1076 "step replayed from previous execution"
1077 );
1078 Some(output)
1079 }
1080
1081 async fn execute_step(
1083 &mut self,
1084 name: &str,
1085 kind: StepKind,
1086 config: StepConfig,
1087 ) -> Result<StepOutput, EngineError> {
1088 let position = self.position;
1089 self.position += 1;
1090
1091 if let Some(output) = self.try_replay_step(position) {
1093 return Ok(output);
1094 }
1095
1096 let step = self
1098 .store
1099 .create_step(NewStep {
1100 run_id: self.run_id,
1101 name: name.to_string(),
1102 kind,
1103 position,
1104 input: Some(serde_json::to_value(&config)?),
1105 })
1106 .await?;
1107
1108 self.start_step(step.id, Utc::now()).await?;
1109
1110 let step_log_sender = self
1111 .log_sender
1112 .as_ref()
1113 .map(|s| StepLogSender::new(s.clone(), self.run_id, step.id, name.to_string()));
1114
1115 match execute_step_config(&config, &self.provider, step_log_sender).await {
1116 Ok(output) => {
1117 self.total_cost_usd += output.cost_usd;
1118 self.total_duration_ms += output.duration_ms;
1119
1120 let debug_messages_json = output.debug_messages_json();
1121
1122 let completed_at = Utc::now();
1123 self.store
1124 .update_step(
1125 step.id,
1126 StepUpdate {
1127 status: Some(StepStatus::Completed),
1128 output: Some(output.output.clone()),
1129 duration_ms: Some(output.duration_ms),
1130 cost_usd: Some(output.cost_usd),
1131 input_tokens: output.input_tokens,
1132 output_tokens: output.output_tokens,
1133 completed_at: Some(completed_at),
1134 debug_messages: debug_messages_json,
1135 ..StepUpdate::default()
1136 },
1137 )
1138 .await?;
1139
1140 info!(
1141 run_id = %self.run_id,
1142 step = %name,
1143 duration_ms = output.duration_ms,
1144 "step completed"
1145 );
1146
1147 self.last_step_ids = vec![step.id];
1148
1149 Ok(output)
1150 }
1151 Err(err) => {
1152 let completed_at = Utc::now();
1153 let debug_messages_json = extract_debug_messages_from_error(&err);
1154 let partial = extract_partial_usage_from_error(&err);
1155 let raw_response_output = extract_raw_response_from_error(&err);
1156
1157 if let Some(ref usage) = partial {
1158 if let Some(cost) = usage.cost_usd {
1159 self.total_cost_usd += cost;
1160 }
1161 if let Some(dur) = usage.duration_ms {
1162 self.total_duration_ms += dur;
1163 }
1164 }
1165
1166 if let Err(store_err) = self
1167 .store
1168 .update_step(
1169 step.id,
1170 StepUpdate {
1171 status: Some(StepStatus::Failed),
1172 error: Some(err.to_string()),
1173 output: raw_response_output,
1174 completed_at: Some(completed_at),
1175 debug_messages: debug_messages_json,
1176 duration_ms: partial.as_ref().and_then(|p| p.duration_ms),
1177 cost_usd: partial.as_ref().and_then(|p| p.cost_usd),
1178 input_tokens: partial.as_ref().and_then(|p| p.input_tokens),
1179 output_tokens: partial.as_ref().and_then(|p| p.output_tokens),
1180 ..StepUpdate::default()
1181 },
1182 )
1183 .await
1184 {
1185 tracing::error!(step_id = %step.id, error = %store_err, "failed to persist step failure");
1186 }
1187
1188 Err(err)
1189 }
1190 }
1191 }
1192
1193 async fn start_step(&self, step_id: Uuid, now: DateTime<Utc>) -> Result<(), EngineError> {
1198 if !self.last_step_ids.is_empty() {
1199 let deps: Vec<NewStepDependency> = self
1200 .last_step_ids
1201 .iter()
1202 .map(|&depends_on| NewStepDependency {
1203 step_id,
1204 depends_on,
1205 })
1206 .collect();
1207 self.store.create_step_dependencies(deps).await?;
1208 }
1209
1210 self.store
1211 .update_step(
1212 step_id,
1213 StepUpdate {
1214 status: Some(StepStatus::Running),
1215 started_at: Some(now),
1216 ..StepUpdate::default()
1217 },
1218 )
1219 .await?;
1220
1221 Ok(())
1222 }
1223
1224 pub fn store(&self) -> &Arc<dyn Store> {
1226 &self.store
1227 }
1228
1229 pub async fn payload(&self) -> Result<Value, EngineError> {
1237 let run = self
1238 .store
1239 .get_run(self.run_id)
1240 .await?
1241 .ok_or(EngineError::Store(
1242 ironflow_store::error::StoreError::RunNotFound(self.run_id),
1243 ))?;
1244 Ok(run.payload)
1245 }
1246}
1247
1248impl fmt::Debug for WorkflowContext {
1249 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
1250 f.debug_struct("WorkflowContext")
1251 .field("run_id", &self.run_id)
1252 .field("position", &self.position)
1253 .field("total_cost_usd", &self.total_cost_usd)
1254 .finish_non_exhaustive()
1255 }
1256}
1257
1258fn extract_debug_messages_from_error(err: &EngineError) -> Option<Value> {
1261 if let EngineError::Operation(OperationError::Agent(AgentError::SchemaValidation {
1262 debug_messages,
1263 ..
1264 })) = err
1265 && !debug_messages.is_empty()
1266 {
1267 return serde_json::to_value(debug_messages).ok();
1268 }
1269 None
1270}
1271
1272struct StepPartialUsage {
1278 cost_usd: Option<Decimal>,
1279 duration_ms: Option<u64>,
1280 input_tokens: Option<u64>,
1281 output_tokens: Option<u64>,
1282}
1283
1284fn extract_raw_response_from_error(err: &EngineError) -> Option<Value> {
1290 if let EngineError::Operation(OperationError::Agent(AgentError::SchemaValidation {
1291 raw_response: Some(text),
1292 ..
1293 })) = err
1294 {
1295 return Some(Value::String(text.clone()));
1296 }
1297 None
1298}
1299
1300fn extract_partial_usage_from_error(err: &EngineError) -> Option<StepPartialUsage> {
1301 if let EngineError::Operation(OperationError::Agent(AgentError::SchemaValidation {
1302 partial_usage,
1303 ..
1304 })) = err
1305 && (partial_usage.cost_usd.is_some() || partial_usage.duration_ms.is_some())
1306 {
1307 return Some(StepPartialUsage {
1308 cost_usd: partial_usage
1309 .cost_usd
1310 .and_then(|c| Decimal::try_from(c).ok()),
1311 duration_ms: partial_usage.duration_ms,
1312 input_tokens: partial_usage.input_tokens,
1313 output_tokens: partial_usage.output_tokens,
1314 });
1315 }
1316 None
1317}
1318
1319#[cfg(test)]
1320mod tests {
1321 use super::*;
1322 use ironflow_core::providers::claude::ClaudeCodeProvider;
1323 use ironflow_core::providers::record_replay::RecordReplayProvider;
1324 use ironflow_store::memory::InMemoryStore;
1325 use ironflow_store::models::RunFilter;
1326 use ironflow_store::store::RunStore;
1327 use serde_json::json;
1328 use std::sync::Arc;
1329 use std::sync::atomic::{AtomicBool, Ordering};
1330 use uuid::Uuid;
1331
1332 fn create_test_provider() -> Arc<dyn ironflow_core::provider::AgentProvider> {
1334 let inner = ClaudeCodeProvider::new();
1335 Arc::new(RecordReplayProvider::replay(
1336 inner,
1337 "/tmp/ironflow-fixtures",
1338 ))
1339 }
1340
1341 fn create_test_context() -> WorkflowContext {
1343 let store = Arc::new(InMemoryStore::new());
1344 let provider = create_test_provider();
1345 let run_id = Uuid::now_v7();
1346 WorkflowContext::new(run_id, store, provider)
1347 }
1348
1349 #[test]
1350 fn context_new_initializes_correctly() {
1351 let ctx = create_test_context();
1352 assert_eq!(ctx.position, 0);
1353 assert_eq!(ctx.total_cost_usd, Decimal::ZERO);
1354 assert_eq!(ctx.total_duration_ms, 0);
1355 assert!(ctx.last_step_ids.is_empty());
1356 assert!(ctx.replay_steps.is_empty());
1357 assert!(ctx.log_sender.is_none());
1358 }
1359
1360 #[test]
1361 fn context_run_id_returns_correct_id() {
1362 let run_id = Uuid::now_v7();
1363 let store = Arc::new(InMemoryStore::new());
1364 let provider = create_test_provider();
1365 let ctx = WorkflowContext::new(run_id, store, provider);
1366 assert_eq!(ctx.run_id(), run_id);
1367 }
1368
1369 #[test]
1370 fn context_total_cost_usd_initially_zero() {
1371 let ctx = create_test_context();
1372 assert_eq!(ctx.total_cost_usd(), Decimal::ZERO);
1373 }
1374
1375 #[test]
1376 fn context_total_duration_ms_initially_zero() {
1377 let ctx = create_test_context();
1378 assert_eq!(ctx.total_duration_ms(), 0);
1379 }
1380
1381 #[test]
1382 fn context_with_handler_resolver_creates_context_with_resolver() {
1383 let store = Arc::new(InMemoryStore::new());
1384 let provider = create_test_provider();
1385 let run_id = Uuid::now_v7();
1386
1387 let called = Arc::new(AtomicBool::new(false));
1388 let called_clone = called.clone();
1389
1390 let resolver: HandlerResolver = Arc::new(move |_name: &str| {
1391 called_clone.store(true, Ordering::SeqCst);
1392 None
1393 });
1394
1395 let ctx = WorkflowContext::with_handler_resolver(run_id, store, provider, resolver);
1396
1397 assert_eq!(ctx.run_id(), run_id);
1398 assert!(ctx.handler_resolver.is_some());
1399 }
1400
1401 #[tokio::test]
1402 async fn context_set_log_sender_attaches_sender() {
1403 let mut ctx = create_test_context();
1404 let (sender, _receiver) = crate::log_sender::channel();
1405 ctx.set_log_sender(sender);
1406 assert!(ctx.log_sender.is_some());
1407 }
1408
1409 #[tokio::test]
1410 async fn context_skip_creates_skipped_step() {
1411 let store = Arc::new(InMemoryStore::new());
1412 let provider = create_test_provider();
1413
1414 store
1416 .create_run(NewRun {
1417 workflow_name: "test".to_string(),
1418 trigger: TriggerKind::Manual,
1419 payload: json!({}),
1420 max_retries: 0,
1421 handler_version: None,
1422 labels: Default::default(),
1423 scheduled_at: None,
1424 })
1425 .await
1426 .expect("failed to create run");
1427
1428 let runs = store
1430 .list_runs(RunFilter::default(), 1, 10)
1431 .await
1432 .expect("failed to list runs");
1433 let created_run_id = runs.items[0].id;
1434
1435 let mut ctx = WorkflowContext::new(created_run_id, store.clone(), provider);
1436 let initial_position = ctx.position;
1437
1438 ctx.skip("skip-step", "condition not met")
1439 .await
1440 .expect("skip failed");
1441
1442 assert_eq!(ctx.position, initial_position + 1);
1443 assert!(!ctx.last_step_ids.is_empty());
1444
1445 let steps = store
1447 .list_steps(created_run_id)
1448 .await
1449 .expect("failed to list steps");
1450 assert_eq!(steps.len(), 1);
1451 assert_eq!(steps[0].status.state, StepStatus::Skipped);
1452 }
1453
1454 #[tokio::test]
1455 async fn context_parallel_empty_steps_returns_empty_vec() {
1456 let mut ctx = create_test_context();
1457 let results = ctx
1458 .parallel(vec![], true)
1459 .await
1460 .expect("parallel should not fail on empty input");
1461 assert!(results.is_empty());
1462 }
1463
1464 #[tokio::test]
1465 async fn context_approval_first_execution_returns_error() {
1466 let store = Arc::new(InMemoryStore::new());
1467 let provider = create_test_provider();
1468
1469 store
1471 .create_run(NewRun {
1472 workflow_name: "test".to_string(),
1473 trigger: TriggerKind::Manual,
1474 payload: json!({}),
1475 max_retries: 0,
1476 handler_version: None,
1477 labels: Default::default(),
1478 scheduled_at: None,
1479 })
1480 .await
1481 .expect("failed to create run");
1482
1483 let runs = store
1485 .list_runs(RunFilter::default(), 1, 10)
1486 .await
1487 .expect("failed to list runs");
1488 let created_run_id = runs.items[0].id;
1489
1490 let mut ctx = WorkflowContext::new(created_run_id, store.clone(), provider);
1491
1492 let result = ctx
1493 .approval(
1494 "approve-step",
1495 crate::config::ApprovalConfig::new("Continue?"),
1496 )
1497 .await;
1498
1499 assert!(matches!(result, Err(EngineError::ApprovalRequired { .. })));
1501
1502 assert_eq!(ctx.position, 1);
1504
1505 let steps = store
1507 .list_steps(created_run_id)
1508 .await
1509 .expect("failed to list steps");
1510 assert_eq!(steps.len(), 1);
1511 assert_eq!(steps[0].status.state, StepStatus::AwaitingApproval);
1512 }
1513
1514 #[tokio::test]
1515 async fn context_approval_replay_returns_ok() {
1516 let store = Arc::new(InMemoryStore::new());
1517 let provider = create_test_provider();
1518
1519 store
1521 .create_run(NewRun {
1522 workflow_name: "test".to_string(),
1523 trigger: TriggerKind::Manual,
1524 payload: json!({}),
1525 max_retries: 0,
1526 handler_version: None,
1527 labels: Default::default(),
1528 scheduled_at: None,
1529 })
1530 .await
1531 .expect("failed to create run");
1532
1533 let runs = store
1535 .list_runs(RunFilter::default(), 1, 10)
1536 .await
1537 .expect("failed to list runs");
1538 let created_run_id = runs.items[0].id;
1539
1540 let step = store
1542 .create_step(NewStep {
1543 run_id: created_run_id,
1544 name: "approval".to_string(),
1545 kind: StepKind::Approval,
1546 position: 0,
1547 input: None,
1548 })
1549 .await
1550 .expect("failed to create step");
1551
1552 store
1554 .update_step(
1555 step.id,
1556 StepUpdate {
1557 status: Some(StepStatus::Running),
1558 started_at: Some(Utc::now()),
1559 ..StepUpdate::default()
1560 },
1561 )
1562 .await
1563 .expect("failed to update step to Running");
1564
1565 store
1566 .update_step(
1567 step.id,
1568 StepUpdate {
1569 status: Some(StepStatus::AwaitingApproval),
1570 ..StepUpdate::default()
1571 },
1572 )
1573 .await
1574 .expect("failed to update step to AwaitingApproval");
1575
1576 let mut ctx = WorkflowContext::new(created_run_id, store.clone(), provider);
1578 ctx.load_replay_steps()
1579 .await
1580 .expect("failed to load replay steps");
1581
1582 let result = ctx
1584 .approval("approval", crate::config::ApprovalConfig::new("Continue?"))
1585 .await;
1586
1587 assert!(result.is_ok());
1588
1589 let steps = store
1591 .list_steps(created_run_id)
1592 .await
1593 .expect("failed to list steps");
1594 assert_eq!(steps.len(), 1);
1595 assert_eq!(steps[0].status.state, StepStatus::Completed);
1596 }
1597
1598 #[tokio::test]
1599 async fn context_load_replay_steps_loads_completed_steps() {
1600 let store = Arc::new(InMemoryStore::new());
1601 let provider = create_test_provider();
1602
1603 store
1605 .create_run(NewRun {
1606 workflow_name: "test".to_string(),
1607 trigger: TriggerKind::Manual,
1608 payload: json!({}),
1609 max_retries: 0,
1610 handler_version: None,
1611 labels: Default::default(),
1612 scheduled_at: None,
1613 })
1614 .await
1615 .expect("failed to create run");
1616
1617 let runs = store
1619 .list_runs(RunFilter::default(), 1, 10)
1620 .await
1621 .expect("failed to list runs");
1622 let created_run_id = runs.items[0].id;
1623
1624 let completed_step = store
1626 .create_step(NewStep {
1627 run_id: created_run_id,
1628 name: "completed".to_string(),
1629 kind: StepKind::Shell,
1630 position: 0,
1631 input: None,
1632 })
1633 .await
1634 .expect("failed to create step");
1635
1636 store
1638 .update_step(
1639 completed_step.id,
1640 StepUpdate {
1641 status: Some(StepStatus::Running),
1642 started_at: Some(Utc::now()),
1643 ..StepUpdate::default()
1644 },
1645 )
1646 .await
1647 .expect("failed to update step to Running");
1648
1649 store
1650 .update_step(
1651 completed_step.id,
1652 StepUpdate {
1653 status: Some(StepStatus::Completed),
1654 completed_at: Some(Utc::now()),
1655 ..StepUpdate::default()
1656 },
1657 )
1658 .await
1659 .expect("failed to update step to Completed");
1660
1661 let _pending_step = store
1662 .create_step(NewStep {
1663 run_id: created_run_id,
1664 name: "pending".to_string(),
1665 kind: StepKind::Shell,
1666 position: 1,
1667 input: None,
1668 })
1669 .await
1670 .expect("failed to create step");
1671
1672 let mut ctx = WorkflowContext::new(created_run_id, store, provider);
1674 ctx.load_replay_steps()
1675 .await
1676 .expect("failed to load replay steps");
1677
1678 assert_eq!(ctx.replay_steps.len(), 1);
1680 assert!(ctx.replay_steps.contains_key(&0));
1681 assert!(!ctx.replay_steps.contains_key(&1));
1682 }
1683
1684 #[tokio::test]
1685 async fn context_payload_returns_run_payload() {
1686 let store = Arc::new(InMemoryStore::new());
1687 let provider = create_test_provider();
1688 let test_payload = json!({"key": "value", "number": 42});
1689
1690 store
1692 .create_run(NewRun {
1693 workflow_name: "test".to_string(),
1694 trigger: TriggerKind::Manual,
1695 payload: test_payload.clone(),
1696 max_retries: 0,
1697 handler_version: None,
1698 labels: Default::default(),
1699 scheduled_at: None,
1700 })
1701 .await
1702 .expect("failed to create run");
1703
1704 let runs = store
1706 .list_runs(RunFilter::default(), 1, 10)
1707 .await
1708 .expect("failed to list runs");
1709 let created_run_id = runs.items[0].id;
1710
1711 let ctx = WorkflowContext::new(created_run_id, store, provider);
1712 let payload = ctx.payload().await.expect("failed to get payload");
1713
1714 assert_eq!(payload, test_payload);
1715 }
1716
1717 #[tokio::test]
1718 async fn context_payload_returns_error_for_nonexistent_run() {
1719 let store = Arc::new(InMemoryStore::new());
1720 let provider = create_test_provider();
1721 let run_id = Uuid::now_v7();
1722
1723 let ctx = WorkflowContext::new(run_id, store, provider);
1724 let result = ctx.payload().await;
1725
1726 assert!(result.is_err());
1727 }
1728
/// Smoke test: `store()` can be called on a freshly built test context.
///
/// Nothing in this test is awaited, so it is a plain synchronous `#[test]`
/// rather than a `#[tokio::test]`; `context_debug_formatting` already builds
/// its context through `create_test_context()` without a runtime.
#[test]
fn context_store_returns_reference() {
    let ctx = create_test_context();
    // The value is bound only to show the accessor is callable; it is not inspected.
    let _store = ctx.store();
}
1735
/// The `Debug` rendering of a context must mention both the type name and
/// the `run_id` field.
#[test]
fn context_debug_formatting() {
    let ctx = create_test_context();
    let rendered = format!("{ctx:?}");

    for expected in ["WorkflowContext", "run_id"] {
        assert!(
            rendered.contains(expected),
            "debug output is missing `{expected}`: {rendered}"
        );
    }
}
1743
1744 #[tokio::test]
1745 async fn context_last_step_ids_tracks_executed_steps() {
1746 let store = Arc::new(InMemoryStore::new());
1747 let provider = create_test_provider();
1748
1749 store
1751 .create_run(NewRun {
1752 workflow_name: "test".to_string(),
1753 trigger: TriggerKind::Manual,
1754 payload: json!({}),
1755 max_retries: 0,
1756 handler_version: None,
1757 labels: Default::default(),
1758 scheduled_at: None,
1759 })
1760 .await
1761 .expect("failed to create run");
1762
1763 let runs = store
1765 .list_runs(RunFilter::default(), 1, 10)
1766 .await
1767 .expect("failed to list runs");
1768 let created_run_id = runs.items[0].id;
1769
1770 let mut ctx = WorkflowContext::new(created_run_id, store, provider);
1771 assert!(ctx.last_step_ids.is_empty());
1772
1773 ctx.skip("step1", "reason").await.expect("skip failed");
1774
1775 assert_eq!(ctx.last_step_ids.len(), 1);
1776
1777 ctx.skip("step2", "reason").await.expect("skip failed");
1778
1779 assert_eq!(ctx.last_step_ids.len(), 1);
1781 }
1782}