1use std::collections::HashMap;
27use std::fmt;
28use std::sync::Arc;
29use std::time::Instant;
30
31use chrono::{DateTime, Utc};
32use rust_decimal::Decimal;
33use serde_json::Value;
34use tokio::task::JoinSet;
35use tracing::{error, info};
36use uuid::Uuid;
37
38use ironflow_core::error::{AgentError, OperationError};
39use ironflow_core::provider::AgentProvider;
40use ironflow_store::models::{
41 NewRun, NewStep, NewStepDependency, RunStatus, RunUpdate, Step, StepKind, StepStatus,
42 StepUpdate, TriggerKind,
43};
44use ironflow_store::store::Store;
45
46use crate::config::{
47 AgentStepConfig, ApprovalConfig, HttpConfig, ShellConfig, StepConfig, WorkflowStepConfig,
48};
49use crate::error::EngineError;
50use crate::executor::{ParallelStepResult, StepOutput, execute_step_config};
51use crate::handler::WorkflowHandler;
52use crate::log_sender::{LogSender, StepLogSender};
53use crate::operation::Operation;
54
55pub(crate) type HandlerResolver =
57 Arc<dyn Fn(&str) -> Option<Arc<dyn WorkflowHandler>> + Send + Sync>;
58
59pub struct WorkflowContext {
78 run_id: Uuid,
79 store: Arc<dyn Store>,
80 provider: Arc<dyn AgentProvider>,
81 handler_resolver: Option<HandlerResolver>,
82 position: u32,
83 last_step_ids: Vec<Uuid>,
85 total_cost_usd: Decimal,
87 total_duration_ms: u64,
89 replay_steps: HashMap<u32, Step>,
92 log_sender: Option<LogSender>,
94}
95
96impl WorkflowContext {
97 pub fn new(run_id: Uuid, store: Arc<dyn Store>, provider: Arc<dyn AgentProvider>) -> Self {
102 Self {
103 run_id,
104 store,
105 provider,
106 handler_resolver: None,
107 position: 0,
108 last_step_ids: Vec::new(),
109 total_cost_usd: Decimal::ZERO,
110 total_duration_ms: 0,
111 replay_steps: HashMap::new(),
112 log_sender: None,
113 }
114 }
115
116 pub(crate) fn with_handler_resolver(
121 run_id: Uuid,
122 store: Arc<dyn Store>,
123 provider: Arc<dyn AgentProvider>,
124 resolver: HandlerResolver,
125 ) -> Self {
126 Self {
127 run_id,
128 store,
129 provider,
130 handler_resolver: Some(resolver),
131 position: 0,
132 last_step_ids: Vec::new(),
133 total_cost_usd: Decimal::ZERO,
134 total_duration_ms: 0,
135 replay_steps: HashMap::new(),
136 log_sender: None,
137 }
138 }
139
140 pub fn set_log_sender(&mut self, sender: LogSender) {
142 self.log_sender = Some(sender);
143 }
144
145 pub(crate) async fn load_replay_steps(&mut self) -> Result<(), EngineError> {
151 let steps = self.store.list_steps(self.run_id).await?;
152 for step in steps {
153 let dominated = matches!(
154 step.status.state,
155 StepStatus::Completed | StepStatus::Running | StepStatus::AwaitingApproval
156 );
157 if dominated {
158 self.replay_steps.insert(step.position, step);
159 }
160 }
161 Ok(())
162 }
163
164 pub fn run_id(&self) -> Uuid {
166 self.run_id
167 }
168
169 pub fn total_cost_usd(&self) -> Decimal {
171 self.total_cost_usd
172 }
173
174 pub fn total_duration_ms(&self) -> u64 {
176 self.total_duration_ms
177 }
178
179 pub async fn parallel(
216 &mut self,
217 steps: Vec<(&str, StepConfig)>,
218 fail_fast: bool,
219 ) -> Result<Vec<ParallelStepResult>, EngineError> {
220 if steps.is_empty() {
221 return Ok(Vec::new());
222 }
223
224 let wave_position = self.position;
225 self.position += 1;
226
227 let now = Utc::now();
228 let mut step_records: Vec<(Uuid, String, StepConfig)> = Vec::with_capacity(steps.len());
229
230 for (name, config) in &steps {
231 let kind = config.kind();
232 let step = self
233 .store
234 .create_step(NewStep {
235 run_id: self.run_id,
236 name: name.to_string(),
237 kind,
238 position: wave_position,
239 input: Some(serde_json::to_value(config)?),
240 })
241 .await?;
242
243 self.start_step(step.id, now).await?;
244
245 step_records.push((step.id, name.to_string(), config.clone()));
246 }
247
248 let mut join_set = JoinSet::new();
249 for (idx, (step_id, step_name, config)) in step_records.iter().enumerate() {
250 let provider = self.provider.clone();
251 let config = config.clone();
252 let step_log_sender = self
253 .log_sender
254 .as_ref()
255 .map(|s| StepLogSender::new(s.clone(), self.run_id, *step_id, step_name.clone()));
256 join_set.spawn(async move {
257 (
258 idx,
259 execute_step_config(&config, &provider, step_log_sender).await,
260 )
261 });
262 }
263
264 let mut indexed_results: Vec<Option<Result<StepOutput, String>>> =
266 vec![None; step_records.len()];
267 let mut first_error: Option<EngineError> = None;
268
269 while let Some(join_result) = join_set.join_next().await {
270 let (idx, step_result) = match join_result {
271 Ok(r) => r,
272 Err(e) => {
273 if first_error.is_none() {
274 first_error = Some(EngineError::StepConfig(format!("join error: {e}")));
275 }
276 if fail_fast {
277 join_set.abort_all();
278 }
279 continue;
280 }
281 };
282
283 let (step_id, step_name, _) = &step_records[idx];
284 let completed_at = Utc::now();
285
286 match step_result {
287 Ok(output) => {
288 self.total_cost_usd += output.cost_usd;
289 self.total_duration_ms += output.duration_ms;
290
291 let debug_messages_json = output.debug_messages_json();
292
293 self.store
294 .update_step(
295 *step_id,
296 StepUpdate {
297 status: Some(StepStatus::Completed),
298 output: Some(output.output.clone()),
299 duration_ms: Some(output.duration_ms),
300 cost_usd: Some(output.cost_usd),
301 input_tokens: output.input_tokens,
302 output_tokens: output.output_tokens,
303 completed_at: Some(completed_at),
304 debug_messages: debug_messages_json,
305 ..StepUpdate::default()
306 },
307 )
308 .await?;
309
310 info!(
311 run_id = %self.run_id,
312 step = %step_name,
313 duration_ms = output.duration_ms,
314 "parallel step completed"
315 );
316
317 indexed_results[idx] = Some(Ok(output));
318 }
319 Err(err) => {
320 let err_msg = err.to_string();
321 let debug_messages_json = extract_debug_messages_from_error(&err);
322 let partial = extract_partial_usage_from_error(&err);
323 let raw_response_output = extract_raw_response_from_error(&err);
324
325 if let Some(ref usage) = partial {
326 if let Some(cost) = usage.cost_usd {
327 self.total_cost_usd += cost;
328 }
329 if let Some(dur) = usage.duration_ms {
330 self.total_duration_ms += dur;
331 }
332 }
333
334 if let Err(store_err) = self
335 .store
336 .update_step(
337 *step_id,
338 StepUpdate {
339 status: Some(StepStatus::Failed),
340 error: Some(err_msg.clone()),
341 output: raw_response_output,
342 completed_at: Some(completed_at),
343 debug_messages: debug_messages_json,
344 duration_ms: partial.as_ref().and_then(|p| p.duration_ms),
345 cost_usd: partial.as_ref().and_then(|p| p.cost_usd),
346 input_tokens: partial.as_ref().and_then(|p| p.input_tokens),
347 output_tokens: partial.as_ref().and_then(|p| p.output_tokens),
348 ..StepUpdate::default()
349 },
350 )
351 .await
352 {
353 tracing::error!(
354 step_id = %step_id,
355 error = %store_err,
356 "failed to persist parallel step failure"
357 );
358 }
359
360 indexed_results[idx] = Some(Err(err_msg.clone()));
361
362 if first_error.is_none() {
363 first_error = Some(err);
364 }
365
366 if fail_fast {
367 join_set.abort_all();
368 }
369 }
370 }
371 }
372
373 if let Some(err) = first_error {
374 return Err(err);
375 }
376
377 self.last_step_ids = step_records.iter().map(|(id, _, _)| *id).collect();
378
379 let results: Vec<ParallelStepResult> = step_records
381 .iter()
382 .enumerate()
383 .map(|(idx, (step_id, name, _))| {
384 let output = match indexed_results[idx].take() {
385 Some(Ok(o)) => o,
386 _ => unreachable!("all steps succeeded if no error returned"),
387 };
388 ParallelStepResult {
389 name: name.clone(),
390 output,
391 step_id: *step_id,
392 }
393 })
394 .collect();
395
396 Ok(results)
397 }
398
399 pub async fn shell(
422 &mut self,
423 name: &str,
424 config: ShellConfig,
425 ) -> Result<StepOutput, EngineError> {
426 self.execute_step(name, StepKind::Shell, StepConfig::Shell(config))
427 .await
428 }
429
430 pub async fn http(
450 &mut self,
451 name: &str,
452 config: HttpConfig,
453 ) -> Result<StepOutput, EngineError> {
454 self.execute_step(name, StepKind::Http, StepConfig::Http(config))
455 .await
456 }
457
458 pub async fn agent(
478 &mut self,
479 name: &str,
480 config: impl Into<AgentStepConfig>,
481 ) -> Result<StepOutput, EngineError> {
482 self.execute_step(name, StepKind::Agent, StepConfig::Agent(config.into()))
483 .await
484 }
485
486 pub async fn approval(
517 &mut self,
518 name: &str,
519 config: ApprovalConfig,
520 ) -> Result<(), EngineError> {
521 let position = self.position;
522 self.position += 1;
523
524 if let Some(existing) = self.replay_steps.get(&position)
527 && existing.kind == StepKind::Approval
528 {
529 if existing.status.state == StepStatus::AwaitingApproval {
530 self.store
531 .update_step(
532 existing.id,
533 StepUpdate {
534 status: Some(StepStatus::Completed),
535 completed_at: Some(Utc::now()),
536 ..StepUpdate::default()
537 },
538 )
539 .await?;
540 }
541
542 self.last_step_ids = vec![existing.id];
543 info!(
544 run_id = %self.run_id,
545 step = %name,
546 position,
547 "approval step replayed (approved)"
548 );
549 return Ok(());
550 }
551
552 let step = self
554 .store
555 .create_step(NewStep {
556 run_id: self.run_id,
557 name: name.to_string(),
558 kind: StepKind::Approval,
559 position,
560 input: Some(serde_json::to_value(&config)?),
561 })
562 .await?;
563
564 self.start_step(step.id, Utc::now()).await?;
565
566 self.store
569 .update_step(
570 step.id,
571 StepUpdate {
572 status: Some(StepStatus::AwaitingApproval),
573 ..StepUpdate::default()
574 },
575 )
576 .await?;
577
578 self.last_step_ids = vec![step.id];
579
580 Err(EngineError::ApprovalRequired {
581 run_id: self.run_id,
582 step_id: step.id,
583 message: config.message().to_string(),
584 })
585 }
586
587 pub async fn skip(&mut self, name: &str, reason: &str) -> Result<(), EngineError> {
616 let position = self.position;
617 self.position += 1;
618
619 let step = self
620 .store
621 .create_step(NewStep {
622 run_id: self.run_id,
623 name: name.to_string(),
624 kind: StepKind::Custom("skip".to_string()),
625 position,
626 input: None,
627 })
628 .await?;
629
630 if !self.last_step_ids.is_empty() {
631 let deps: Vec<NewStepDependency> = self
632 .last_step_ids
633 .iter()
634 .map(|&depends_on| NewStepDependency {
635 step_id: step.id,
636 depends_on,
637 })
638 .collect();
639 self.store.create_step_dependencies(deps).await?;
640 }
641
642 let now = Utc::now();
643 self.store
644 .update_step(
645 step.id,
646 StepUpdate {
647 status: Some(StepStatus::Skipped),
648 output: Some(serde_json::json!({"reason": reason})),
649 completed_at: Some(now),
650 ..StepUpdate::default()
651 },
652 )
653 .await?;
654
655 self.last_step_ids = vec![step.id];
656
657 info!(
658 run_id = %self.run_id,
659 step = %name,
660 reason,
661 "step skipped"
662 );
663
664 Ok(())
665 }
666
667 pub async fn operation(
705 &mut self,
706 name: &str,
707 op: &dyn Operation,
708 ) -> Result<StepOutput, EngineError> {
709 let kind = StepKind::Custom(op.kind().to_string());
710 let position = self.position;
711 self.position += 1;
712
713 let step = self
714 .store
715 .create_step(NewStep {
716 run_id: self.run_id,
717 name: name.to_string(),
718 kind,
719 position,
720 input: op.input(),
721 })
722 .await?;
723
724 self.start_step(step.id, Utc::now()).await?;
725
726 let start = Instant::now();
727
728 match op.execute().await {
729 Ok(output_value) => {
730 let duration_ms = start.elapsed().as_millis() as u64;
731 self.total_duration_ms += duration_ms;
732
733 let completed_at = Utc::now();
734 self.store
735 .update_step(
736 step.id,
737 StepUpdate {
738 status: Some(StepStatus::Completed),
739 output: Some(output_value.clone()),
740 duration_ms: Some(duration_ms),
741 cost_usd: Some(Decimal::ZERO),
742 completed_at: Some(completed_at),
743 ..StepUpdate::default()
744 },
745 )
746 .await?;
747
748 info!(
749 run_id = %self.run_id,
750 step = %name,
751 kind = op.kind(),
752 duration_ms,
753 "operation step completed"
754 );
755
756 self.last_step_ids = vec![step.id];
757
758 Ok(StepOutput {
759 output: output_value,
760 duration_ms,
761 cost_usd: Decimal::ZERO,
762 input_tokens: None,
763 output_tokens: None,
764 model: None,
765 debug_messages: None,
766 })
767 }
768 Err(err) => {
769 let completed_at = Utc::now();
770 if let Err(store_err) = self
771 .store
772 .update_step(
773 step.id,
774 StepUpdate {
775 status: Some(StepStatus::Failed),
776 error: Some(err.to_string()),
777 completed_at: Some(completed_at),
778 ..StepUpdate::default()
779 },
780 )
781 .await
782 {
783 error!(step_id = %step.id, error = %store_err, "failed to persist step failure");
784 }
785
786 Err(err)
787 }
788 }
789 }
790
791 pub async fn workflow(
818 &mut self,
819 handler: &dyn WorkflowHandler,
820 payload: Value,
821 ) -> Result<StepOutput, EngineError> {
822 let config = WorkflowStepConfig::new(handler.name(), payload);
823 let position = self.position;
824 self.position += 1;
825
826 let step = self
827 .store
828 .create_step(NewStep {
829 run_id: self.run_id,
830 name: config.workflow_name.clone(),
831 kind: StepKind::Workflow,
832 position,
833 input: Some(serde_json::to_value(&config)?),
834 })
835 .await?;
836
837 self.start_step(step.id, Utc::now()).await?;
838
839 match self.execute_child_workflow(&config).await {
840 Ok(output) => {
841 self.total_cost_usd += output.cost_usd;
842 self.total_duration_ms += output.duration_ms;
843
844 let completed_at = Utc::now();
845 self.store
846 .update_step(
847 step.id,
848 StepUpdate {
849 status: Some(StepStatus::Completed),
850 output: Some(output.output.clone()),
851 duration_ms: Some(output.duration_ms),
852 cost_usd: Some(output.cost_usd),
853 completed_at: Some(completed_at),
854 ..StepUpdate::default()
855 },
856 )
857 .await?;
858
859 info!(
860 run_id = %self.run_id,
861 child_workflow = %config.workflow_name,
862 duration_ms = output.duration_ms,
863 "workflow step completed"
864 );
865
866 self.last_step_ids = vec![step.id];
867
868 Ok(output)
869 }
870 Err(err) => {
871 let completed_at = Utc::now();
872 if let Err(store_err) = self
873 .store
874 .update_step(
875 step.id,
876 StepUpdate {
877 status: Some(StepStatus::Failed),
878 error: Some(err.to_string()),
879 completed_at: Some(completed_at),
880 ..StepUpdate::default()
881 },
882 )
883 .await
884 {
885 error!(step_id = %step.id, error = %store_err, "failed to persist step failure");
886 }
887
888 Err(err)
889 }
890 }
891 }
892
893 async fn execute_child_workflow(
895 &self,
896 config: &WorkflowStepConfig,
897 ) -> Result<StepOutput, EngineError> {
898 let resolver = self.handler_resolver.as_ref().ok_or_else(|| {
899 EngineError::InvalidWorkflow(
900 "sub-workflow requires a handler resolver (use Engine to execute)".to_string(),
901 )
902 })?;
903
904 let handler = resolver(&config.workflow_name).ok_or_else(|| {
905 EngineError::InvalidWorkflow(format!("no handler registered: {}", config.workflow_name))
906 })?;
907
908 let parent_labels = self
909 .store
910 .get_run(self.run_id)
911 .await?
912 .map(|r| r.labels)
913 .unwrap_or_default();
914
915 let child_run = self
916 .store
917 .create_run(NewRun {
918 workflow_name: config.workflow_name.clone(),
919 trigger: TriggerKind::Workflow,
920 payload: config.payload.clone(),
921 max_retries: 0,
922 handler_version: None,
923 labels: parent_labels,
924 scheduled_at: None,
925 })
926 .await?;
927
928 let child_run_id = child_run.id;
929 info!(
930 parent_run_id = %self.run_id,
931 child_run_id = %child_run_id,
932 workflow = %config.workflow_name,
933 "child run created"
934 );
935
936 self.store
937 .update_run_status(child_run_id, RunStatus::Running)
938 .await?;
939
940 let run_start = Instant::now();
941 let mut child_ctx = WorkflowContext {
942 run_id: child_run_id,
943 store: self.store.clone(),
944 provider: self.provider.clone(),
945 handler_resolver: self.handler_resolver.clone(),
946 position: 0,
947 last_step_ids: Vec::new(),
948 total_cost_usd: Decimal::ZERO,
949 total_duration_ms: 0,
950 replay_steps: HashMap::new(),
951 log_sender: self.log_sender.clone(),
952 };
953
954 let result = handler.execute(&mut child_ctx).await;
955 let total_duration = run_start.elapsed().as_millis() as u64;
956 let completed_at = Utc::now();
957
958 match result {
959 Ok(()) => {
960 self.store
961 .update_run(
962 child_run_id,
963 RunUpdate {
964 status: Some(RunStatus::Completed),
965 cost_usd: Some(child_ctx.total_cost_usd),
966 duration_ms: Some(total_duration),
967 completed_at: Some(completed_at),
968 ..RunUpdate::default()
969 },
970 )
971 .await?;
972
973 Ok(StepOutput {
974 output: serde_json::json!({
975 "run_id": child_run_id,
976 "workflow_name": config.workflow_name,
977 "status": RunStatus::Completed,
978 "cost_usd": child_ctx.total_cost_usd,
979 "duration_ms": total_duration,
980 }),
981 duration_ms: total_duration,
982 cost_usd: child_ctx.total_cost_usd,
983 input_tokens: None,
984 output_tokens: None,
985 model: None,
986 debug_messages: None,
987 })
988 }
989 Err(err) => {
990 if let Err(store_err) = self
991 .store
992 .update_run(
993 child_run_id,
994 RunUpdate {
995 status: Some(RunStatus::Failed),
996 error: Some(err.to_string()),
997 cost_usd: Some(child_ctx.total_cost_usd),
998 duration_ms: Some(total_duration),
999 completed_at: Some(completed_at),
1000 ..RunUpdate::default()
1001 },
1002 )
1003 .await
1004 {
1005 error!(
1006 child_run_id = %child_run_id,
1007 store_error = %store_err,
1008 "failed to persist child run failure"
1009 );
1010 }
1011
1012 Err(err)
1013 }
1014 }
1015 }
1016
1017 fn try_replay_step(&mut self, position: u32) -> Option<StepOutput> {
1022 let step = self.replay_steps.get(&position)?;
1023 if step.status.state != StepStatus::Completed {
1024 return None;
1025 }
1026 let output = StepOutput {
1027 output: step.output.clone().unwrap_or(Value::Null),
1028 duration_ms: step.duration_ms,
1029 cost_usd: step.cost_usd,
1030 input_tokens: step.input_tokens,
1031 output_tokens: step.output_tokens,
1032 model: None,
1033 debug_messages: None,
1034 };
1035 self.total_cost_usd += output.cost_usd;
1036 self.total_duration_ms += output.duration_ms;
1037 self.last_step_ids = vec![step.id];
1038 info!(
1039 run_id = %self.run_id,
1040 step = %step.name,
1041 position,
1042 "step replayed from previous execution"
1043 );
1044 Some(output)
1045 }
1046
1047 async fn execute_step(
1049 &mut self,
1050 name: &str,
1051 kind: StepKind,
1052 config: StepConfig,
1053 ) -> Result<StepOutput, EngineError> {
1054 let position = self.position;
1055 self.position += 1;
1056
1057 if let Some(output) = self.try_replay_step(position) {
1059 return Ok(output);
1060 }
1061
1062 let step = self
1064 .store
1065 .create_step(NewStep {
1066 run_id: self.run_id,
1067 name: name.to_string(),
1068 kind,
1069 position,
1070 input: Some(serde_json::to_value(&config)?),
1071 })
1072 .await?;
1073
1074 self.start_step(step.id, Utc::now()).await?;
1075
1076 let step_log_sender = self
1077 .log_sender
1078 .as_ref()
1079 .map(|s| StepLogSender::new(s.clone(), self.run_id, step.id, name.to_string()));
1080
1081 match execute_step_config(&config, &self.provider, step_log_sender).await {
1082 Ok(output) => {
1083 self.total_cost_usd += output.cost_usd;
1084 self.total_duration_ms += output.duration_ms;
1085
1086 let debug_messages_json = output.debug_messages_json();
1087
1088 let completed_at = Utc::now();
1089 self.store
1090 .update_step(
1091 step.id,
1092 StepUpdate {
1093 status: Some(StepStatus::Completed),
1094 output: Some(output.output.clone()),
1095 duration_ms: Some(output.duration_ms),
1096 cost_usd: Some(output.cost_usd),
1097 input_tokens: output.input_tokens,
1098 output_tokens: output.output_tokens,
1099 completed_at: Some(completed_at),
1100 debug_messages: debug_messages_json,
1101 ..StepUpdate::default()
1102 },
1103 )
1104 .await?;
1105
1106 info!(
1107 run_id = %self.run_id,
1108 step = %name,
1109 duration_ms = output.duration_ms,
1110 "step completed"
1111 );
1112
1113 self.last_step_ids = vec![step.id];
1114
1115 Ok(output)
1116 }
1117 Err(err) => {
1118 let completed_at = Utc::now();
1119 let debug_messages_json = extract_debug_messages_from_error(&err);
1120 let partial = extract_partial_usage_from_error(&err);
1121 let raw_response_output = extract_raw_response_from_error(&err);
1122
1123 if let Some(ref usage) = partial {
1124 if let Some(cost) = usage.cost_usd {
1125 self.total_cost_usd += cost;
1126 }
1127 if let Some(dur) = usage.duration_ms {
1128 self.total_duration_ms += dur;
1129 }
1130 }
1131
1132 if let Err(store_err) = self
1133 .store
1134 .update_step(
1135 step.id,
1136 StepUpdate {
1137 status: Some(StepStatus::Failed),
1138 error: Some(err.to_string()),
1139 output: raw_response_output,
1140 completed_at: Some(completed_at),
1141 debug_messages: debug_messages_json,
1142 duration_ms: partial.as_ref().and_then(|p| p.duration_ms),
1143 cost_usd: partial.as_ref().and_then(|p| p.cost_usd),
1144 input_tokens: partial.as_ref().and_then(|p| p.input_tokens),
1145 output_tokens: partial.as_ref().and_then(|p| p.output_tokens),
1146 ..StepUpdate::default()
1147 },
1148 )
1149 .await
1150 {
1151 tracing::error!(step_id = %step.id, error = %store_err, "failed to persist step failure");
1152 }
1153
1154 Err(err)
1155 }
1156 }
1157 }
1158
1159 async fn start_step(&self, step_id: Uuid, now: DateTime<Utc>) -> Result<(), EngineError> {
1164 if !self.last_step_ids.is_empty() {
1165 let deps: Vec<NewStepDependency> = self
1166 .last_step_ids
1167 .iter()
1168 .map(|&depends_on| NewStepDependency {
1169 step_id,
1170 depends_on,
1171 })
1172 .collect();
1173 self.store.create_step_dependencies(deps).await?;
1174 }
1175
1176 self.store
1177 .update_step(
1178 step_id,
1179 StepUpdate {
1180 status: Some(StepStatus::Running),
1181 started_at: Some(now),
1182 ..StepUpdate::default()
1183 },
1184 )
1185 .await?;
1186
1187 Ok(())
1188 }
1189
1190 pub fn store(&self) -> &Arc<dyn Store> {
1192 &self.store
1193 }
1194
1195 pub async fn payload(&self) -> Result<Value, EngineError> {
1203 let run = self
1204 .store
1205 .get_run(self.run_id)
1206 .await?
1207 .ok_or(EngineError::Store(
1208 ironflow_store::error::StoreError::RunNotFound(self.run_id),
1209 ))?;
1210 Ok(run.payload)
1211 }
1212}
1213
1214impl fmt::Debug for WorkflowContext {
1215 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
1216 f.debug_struct("WorkflowContext")
1217 .field("run_id", &self.run_id)
1218 .field("position", &self.position)
1219 .field("total_cost_usd", &self.total_cost_usd)
1220 .finish_non_exhaustive()
1221 }
1222}
1223
1224fn extract_debug_messages_from_error(err: &EngineError) -> Option<Value> {
1227 if let EngineError::Operation(OperationError::Agent(AgentError::SchemaValidation {
1228 debug_messages,
1229 ..
1230 })) = err
1231 && !debug_messages.is_empty()
1232 {
1233 return serde_json::to_value(debug_messages).ok();
1234 }
1235 None
1236}
1237
1238struct StepPartialUsage {
1244 cost_usd: Option<Decimal>,
1245 duration_ms: Option<u64>,
1246 input_tokens: Option<u64>,
1247 output_tokens: Option<u64>,
1248}
1249
1250fn extract_raw_response_from_error(err: &EngineError) -> Option<Value> {
1256 if let EngineError::Operation(OperationError::Agent(AgentError::SchemaValidation {
1257 raw_response: Some(text),
1258 ..
1259 })) = err
1260 {
1261 return Some(Value::String(text.clone()));
1262 }
1263 None
1264}
1265
1266fn extract_partial_usage_from_error(err: &EngineError) -> Option<StepPartialUsage> {
1267 if let EngineError::Operation(OperationError::Agent(AgentError::SchemaValidation {
1268 partial_usage,
1269 ..
1270 })) = err
1271 && (partial_usage.cost_usd.is_some() || partial_usage.duration_ms.is_some())
1272 {
1273 return Some(StepPartialUsage {
1274 cost_usd: partial_usage
1275 .cost_usd
1276 .and_then(|c| Decimal::try_from(c).ok()),
1277 duration_ms: partial_usage.duration_ms,
1278 input_tokens: partial_usage.input_tokens,
1279 output_tokens: partial_usage.output_tokens,
1280 });
1281 }
1282 None
1283}
1284
1285#[cfg(test)]
1286mod tests {
1287 use super::*;
1288 use ironflow_core::providers::claude::ClaudeCodeProvider;
1289 use ironflow_core::providers::record_replay::RecordReplayProvider;
1290 use ironflow_store::memory::InMemoryStore;
1291 use ironflow_store::models::RunFilter;
1292 use ironflow_store::store::RunStore;
1293 use serde_json::json;
1294 use std::sync::Arc;
1295 use std::sync::atomic::{AtomicBool, Ordering};
1296 use uuid::Uuid;
1297
1298 fn create_test_provider() -> Arc<dyn ironflow_core::provider::AgentProvider> {
1300 let inner = ClaudeCodeProvider::new();
1301 Arc::new(RecordReplayProvider::replay(
1302 inner,
1303 "/tmp/ironflow-fixtures",
1304 ))
1305 }
1306
1307 fn create_test_context() -> WorkflowContext {
1309 let store = Arc::new(InMemoryStore::new());
1310 let provider = create_test_provider();
1311 let run_id = Uuid::now_v7();
1312 WorkflowContext::new(run_id, store, provider)
1313 }
1314
1315 #[test]
1316 fn context_new_initializes_correctly() {
1317 let ctx = create_test_context();
1318 assert_eq!(ctx.position, 0);
1319 assert_eq!(ctx.total_cost_usd, Decimal::ZERO);
1320 assert_eq!(ctx.total_duration_ms, 0);
1321 assert!(ctx.last_step_ids.is_empty());
1322 assert!(ctx.replay_steps.is_empty());
1323 assert!(ctx.log_sender.is_none());
1324 }
1325
1326 #[test]
1327 fn context_run_id_returns_correct_id() {
1328 let run_id = Uuid::now_v7();
1329 let store = Arc::new(InMemoryStore::new());
1330 let provider = create_test_provider();
1331 let ctx = WorkflowContext::new(run_id, store, provider);
1332 assert_eq!(ctx.run_id(), run_id);
1333 }
1334
1335 #[test]
1336 fn context_total_cost_usd_initially_zero() {
1337 let ctx = create_test_context();
1338 assert_eq!(ctx.total_cost_usd(), Decimal::ZERO);
1339 }
1340
1341 #[test]
1342 fn context_total_duration_ms_initially_zero() {
1343 let ctx = create_test_context();
1344 assert_eq!(ctx.total_duration_ms(), 0);
1345 }
1346
1347 #[test]
1348 fn context_with_handler_resolver_creates_context_with_resolver() {
1349 let store = Arc::new(InMemoryStore::new());
1350 let provider = create_test_provider();
1351 let run_id = Uuid::now_v7();
1352
1353 let called = Arc::new(AtomicBool::new(false));
1354 let called_clone = called.clone();
1355
1356 let resolver: HandlerResolver = Arc::new(move |_name: &str| {
1357 called_clone.store(true, Ordering::SeqCst);
1358 None
1359 });
1360
1361 let ctx = WorkflowContext::with_handler_resolver(run_id, store, provider, resolver);
1362
1363 assert_eq!(ctx.run_id(), run_id);
1364 assert!(ctx.handler_resolver.is_some());
1365 }
1366
1367 #[tokio::test]
1368 async fn context_set_log_sender_attaches_sender() {
1369 let mut ctx = create_test_context();
1370 let (sender, _receiver) = crate::log_sender::channel();
1371 ctx.set_log_sender(sender);
1372 assert!(ctx.log_sender.is_some());
1373 }
1374
1375 #[tokio::test]
1376 async fn context_skip_creates_skipped_step() {
1377 let store = Arc::new(InMemoryStore::new());
1378 let provider = create_test_provider();
1379
1380 store
1382 .create_run(NewRun {
1383 workflow_name: "test".to_string(),
1384 trigger: TriggerKind::Manual,
1385 payload: json!({}),
1386 max_retries: 0,
1387 handler_version: None,
1388 labels: Default::default(),
1389 scheduled_at: None,
1390 })
1391 .await
1392 .expect("failed to create run");
1393
1394 let runs = store
1396 .list_runs(RunFilter::default(), 1, 10)
1397 .await
1398 .expect("failed to list runs");
1399 let created_run_id = runs.items[0].id;
1400
1401 let mut ctx = WorkflowContext::new(created_run_id, store.clone(), provider);
1402 let initial_position = ctx.position;
1403
1404 ctx.skip("skip-step", "condition not met")
1405 .await
1406 .expect("skip failed");
1407
1408 assert_eq!(ctx.position, initial_position + 1);
1409 assert!(!ctx.last_step_ids.is_empty());
1410
1411 let steps = store
1413 .list_steps(created_run_id)
1414 .await
1415 .expect("failed to list steps");
1416 assert_eq!(steps.len(), 1);
1417 assert_eq!(steps[0].status.state, StepStatus::Skipped);
1418 }
1419
1420 #[tokio::test]
1421 async fn context_parallel_empty_steps_returns_empty_vec() {
1422 let mut ctx = create_test_context();
1423 let results = ctx
1424 .parallel(vec![], true)
1425 .await
1426 .expect("parallel should not fail on empty input");
1427 assert!(results.is_empty());
1428 }
1429
1430 #[tokio::test]
1431 async fn context_approval_first_execution_returns_error() {
1432 let store = Arc::new(InMemoryStore::new());
1433 let provider = create_test_provider();
1434
1435 store
1437 .create_run(NewRun {
1438 workflow_name: "test".to_string(),
1439 trigger: TriggerKind::Manual,
1440 payload: json!({}),
1441 max_retries: 0,
1442 handler_version: None,
1443 labels: Default::default(),
1444 scheduled_at: None,
1445 })
1446 .await
1447 .expect("failed to create run");
1448
1449 let runs = store
1451 .list_runs(RunFilter::default(), 1, 10)
1452 .await
1453 .expect("failed to list runs");
1454 let created_run_id = runs.items[0].id;
1455
1456 let mut ctx = WorkflowContext::new(created_run_id, store.clone(), provider);
1457
1458 let result = ctx
1459 .approval(
1460 "approve-step",
1461 crate::config::ApprovalConfig::new("Continue?"),
1462 )
1463 .await;
1464
1465 assert!(matches!(result, Err(EngineError::ApprovalRequired { .. })));
1467
1468 assert_eq!(ctx.position, 1);
1470
1471 let steps = store
1473 .list_steps(created_run_id)
1474 .await
1475 .expect("failed to list steps");
1476 assert_eq!(steps.len(), 1);
1477 assert_eq!(steps[0].status.state, StepStatus::AwaitingApproval);
1478 }
1479
1480 #[tokio::test]
1481 async fn context_approval_replay_returns_ok() {
1482 let store = Arc::new(InMemoryStore::new());
1483 let provider = create_test_provider();
1484
1485 store
1487 .create_run(NewRun {
1488 workflow_name: "test".to_string(),
1489 trigger: TriggerKind::Manual,
1490 payload: json!({}),
1491 max_retries: 0,
1492 handler_version: None,
1493 labels: Default::default(),
1494 scheduled_at: None,
1495 })
1496 .await
1497 .expect("failed to create run");
1498
1499 let runs = store
1501 .list_runs(RunFilter::default(), 1, 10)
1502 .await
1503 .expect("failed to list runs");
1504 let created_run_id = runs.items[0].id;
1505
1506 let step = store
1508 .create_step(NewStep {
1509 run_id: created_run_id,
1510 name: "approval".to_string(),
1511 kind: StepKind::Approval,
1512 position: 0,
1513 input: None,
1514 })
1515 .await
1516 .expect("failed to create step");
1517
1518 store
1520 .update_step(
1521 step.id,
1522 StepUpdate {
1523 status: Some(StepStatus::Running),
1524 started_at: Some(Utc::now()),
1525 ..StepUpdate::default()
1526 },
1527 )
1528 .await
1529 .expect("failed to update step to Running");
1530
1531 store
1532 .update_step(
1533 step.id,
1534 StepUpdate {
1535 status: Some(StepStatus::AwaitingApproval),
1536 ..StepUpdate::default()
1537 },
1538 )
1539 .await
1540 .expect("failed to update step to AwaitingApproval");
1541
1542 let mut ctx = WorkflowContext::new(created_run_id, store.clone(), provider);
1544 ctx.load_replay_steps()
1545 .await
1546 .expect("failed to load replay steps");
1547
1548 let result = ctx
1550 .approval("approval", crate::config::ApprovalConfig::new("Continue?"))
1551 .await;
1552
1553 assert!(result.is_ok());
1554
1555 let steps = store
1557 .list_steps(created_run_id)
1558 .await
1559 .expect("failed to list steps");
1560 assert_eq!(steps.len(), 1);
1561 assert_eq!(steps[0].status.state, StepStatus::Completed);
1562 }
1563
1564 #[tokio::test]
1565 async fn context_load_replay_steps_loads_completed_steps() {
1566 let store = Arc::new(InMemoryStore::new());
1567 let provider = create_test_provider();
1568
1569 store
1571 .create_run(NewRun {
1572 workflow_name: "test".to_string(),
1573 trigger: TriggerKind::Manual,
1574 payload: json!({}),
1575 max_retries: 0,
1576 handler_version: None,
1577 labels: Default::default(),
1578 scheduled_at: None,
1579 })
1580 .await
1581 .expect("failed to create run");
1582
1583 let runs = store
1585 .list_runs(RunFilter::default(), 1, 10)
1586 .await
1587 .expect("failed to list runs");
1588 let created_run_id = runs.items[0].id;
1589
1590 let completed_step = store
1592 .create_step(NewStep {
1593 run_id: created_run_id,
1594 name: "completed".to_string(),
1595 kind: StepKind::Shell,
1596 position: 0,
1597 input: None,
1598 })
1599 .await
1600 .expect("failed to create step");
1601
1602 store
1604 .update_step(
1605 completed_step.id,
1606 StepUpdate {
1607 status: Some(StepStatus::Running),
1608 started_at: Some(Utc::now()),
1609 ..StepUpdate::default()
1610 },
1611 )
1612 .await
1613 .expect("failed to update step to Running");
1614
1615 store
1616 .update_step(
1617 completed_step.id,
1618 StepUpdate {
1619 status: Some(StepStatus::Completed),
1620 completed_at: Some(Utc::now()),
1621 ..StepUpdate::default()
1622 },
1623 )
1624 .await
1625 .expect("failed to update step to Completed");
1626
1627 let _pending_step = store
1628 .create_step(NewStep {
1629 run_id: created_run_id,
1630 name: "pending".to_string(),
1631 kind: StepKind::Shell,
1632 position: 1,
1633 input: None,
1634 })
1635 .await
1636 .expect("failed to create step");
1637
1638 let mut ctx = WorkflowContext::new(created_run_id, store, provider);
1640 ctx.load_replay_steps()
1641 .await
1642 .expect("failed to load replay steps");
1643
1644 assert_eq!(ctx.replay_steps.len(), 1);
1646 assert!(ctx.replay_steps.contains_key(&0));
1647 assert!(!ctx.replay_steps.contains_key(&1));
1648 }
1649
1650 #[tokio::test]
1651 async fn context_payload_returns_run_payload() {
1652 let store = Arc::new(InMemoryStore::new());
1653 let provider = create_test_provider();
1654 let test_payload = json!({"key": "value", "number": 42});
1655
1656 store
1658 .create_run(NewRun {
1659 workflow_name: "test".to_string(),
1660 trigger: TriggerKind::Manual,
1661 payload: test_payload.clone(),
1662 max_retries: 0,
1663 handler_version: None,
1664 labels: Default::default(),
1665 scheduled_at: None,
1666 })
1667 .await
1668 .expect("failed to create run");
1669
1670 let runs = store
1672 .list_runs(RunFilter::default(), 1, 10)
1673 .await
1674 .expect("failed to list runs");
1675 let created_run_id = runs.items[0].id;
1676
1677 let ctx = WorkflowContext::new(created_run_id, store, provider);
1678 let payload = ctx.payload().await.expect("failed to get payload");
1679
1680 assert_eq!(payload, test_payload);
1681 }
1682
1683 #[tokio::test]
1684 async fn context_payload_returns_error_for_nonexistent_run() {
1685 let store = Arc::new(InMemoryStore::new());
1686 let provider = create_test_provider();
1687 let run_id = Uuid::now_v7();
1688
1689 let ctx = WorkflowContext::new(run_id, store, provider);
1690 let result = ctx.payload().await;
1691
1692 assert!(result.is_err());
1693 }
1694
1695 #[tokio::test]
1696 async fn context_store_returns_reference() {
1697 let ctx = create_test_context();
1698 let _store = ctx.store();
1699 }
1701
1702 #[test]
1703 fn context_debug_formatting() {
1704 let ctx = create_test_context();
1705 let debug_str = format!("{:?}", ctx);
1706 assert!(debug_str.contains("WorkflowContext"));
1707 assert!(debug_str.contains("run_id"));
1708 }
1709
1710 #[tokio::test]
1711 async fn context_last_step_ids_tracks_executed_steps() {
1712 let store = Arc::new(InMemoryStore::new());
1713 let provider = create_test_provider();
1714
1715 store
1717 .create_run(NewRun {
1718 workflow_name: "test".to_string(),
1719 trigger: TriggerKind::Manual,
1720 payload: json!({}),
1721 max_retries: 0,
1722 handler_version: None,
1723 labels: Default::default(),
1724 scheduled_at: None,
1725 })
1726 .await
1727 .expect("failed to create run");
1728
1729 let runs = store
1731 .list_runs(RunFilter::default(), 1, 10)
1732 .await
1733 .expect("failed to list runs");
1734 let created_run_id = runs.items[0].id;
1735
1736 let mut ctx = WorkflowContext::new(created_run_id, store, provider);
1737 assert!(ctx.last_step_ids.is_empty());
1738
1739 ctx.skip("step1", "reason").await.expect("skip failed");
1740
1741 assert_eq!(ctx.last_step_ids.len(), 1);
1742
1743 ctx.skip("step2", "reason").await.expect("skip failed");
1744
1745 assert_eq!(ctx.last_step_ids.len(), 1);
1747 }
1748}