use std::fmt;
use std::sync::Arc;
use std::time::Instant;
use chrono::{DateTime, Utc};
use rust_decimal::Decimal;
use serde_json::Value;
use tokio::task::JoinSet;
use tracing::{error, info};
use uuid::Uuid;
use ironflow_core::error::{AgentError, OperationError};
use ironflow_core::provider::AgentProvider;
use ironflow_store::models::{
NewRun, NewStep, NewStepDependency, RunStatus, RunUpdate, Step, StepKind, StepStatus,
StepUpdate, TriggerKind,
};
use ironflow_store::store::RunStore;
use crate::config::{
AgentStepConfig, ApprovalConfig, HttpConfig, ShellConfig, StepConfig, WorkflowStepConfig,
};
use crate::error::EngineError;
use crate::executor::{ParallelStepResult, StepOutput, execute_step_config};
use crate::handler::WorkflowHandler;
use crate::operation::Operation;
/// Resolves a registered [`WorkflowHandler`] from a workflow name.
///
/// Yields `None` when nothing is registered under the given name. The closure is shared
/// behind an [`Arc`] and must be `Send + Sync`.
pub(crate) type HandlerResolver =
    Arc<dyn Fn(&str) -> Option<Arc<dyn WorkflowHandler>> + Sync + Send + 'static>;
/// Execution context handed to a workflow handler for a single run.
///
/// Each step method persists a step record through the run store, executes it, and
/// accumulates cost and duration. Each step is linked to the step(s) that ran before it.
pub struct WorkflowContext {
    /// Run whose steps are recorded by this context.
    run_id: Uuid,
    /// Persistence backend for runs, steps and step dependencies.
    store: Arc<dyn RunStore>,
    /// Provider handed to `execute_step_config` for every executed step.
    provider: Arc<dyn AgentProvider>,
    /// Resolves child workflow handlers by name; `None` disables sub-workflows.
    handler_resolver: Option<HandlerResolver>,
    /// Position assigned to the next step; a parallel wave consumes a single position.
    position: u32,
    /// Steps the next step will be recorded as depending on.
    last_step_ids: Vec<Uuid>,
    /// Accumulated cost of executed and replayed steps.
    total_cost_usd: Decimal,
    /// Sum of the durations of executed and replayed steps, in milliseconds.
    total_duration_ms: u64,
    /// Previously persisted steps loaded by `load_replay_steps`, keyed by position.
    replay_steps: std::collections::HashMap<u32, Step>,
}
impl WorkflowContext {
pub fn new(run_id: Uuid, store: Arc<dyn RunStore>, provider: Arc<dyn AgentProvider>) -> Self {
Self {
run_id,
store,
provider,
handler_resolver: None,
position: 0,
last_step_ids: Vec::new(),
total_cost_usd: Decimal::ZERO,
total_duration_ms: 0,
replay_steps: std::collections::HashMap::new(),
}
}
pub(crate) fn with_handler_resolver(
run_id: Uuid,
store: Arc<dyn RunStore>,
provider: Arc<dyn AgentProvider>,
resolver: HandlerResolver,
) -> Self {
Self {
run_id,
store,
provider,
handler_resolver: Some(resolver),
position: 0,
last_step_ids: Vec::new(),
total_cost_usd: Decimal::ZERO,
total_duration_ms: 0,
replay_steps: std::collections::HashMap::new(),
}
}
pub(crate) async fn load_replay_steps(&mut self) -> Result<(), EngineError> {
let steps = self.store.list_steps(self.run_id).await?;
for step in steps {
let dominated = matches!(
step.status.state,
StepStatus::Completed | StepStatus::Running | StepStatus::AwaitingApproval
);
if dominated {
self.replay_steps.insert(step.position, step);
}
}
Ok(())
}
pub fn run_id(&self) -> Uuid {
self.run_id
}
pub fn total_cost_usd(&self) -> Decimal {
self.total_cost_usd
}
pub fn total_duration_ms(&self) -> u64 {
self.total_duration_ms
}
pub async fn parallel(
&mut self,
steps: Vec<(&str, StepConfig)>,
fail_fast: bool,
) -> Result<Vec<ParallelStepResult>, EngineError> {
if steps.is_empty() {
return Ok(Vec::new());
}
let wave_position = self.position;
self.position += 1;
let now = Utc::now();
let mut step_records: Vec<(Uuid, String, StepConfig)> = Vec::with_capacity(steps.len());
for (name, config) in &steps {
let kind = config.kind();
let step = self
.store
.create_step(NewStep {
run_id: self.run_id,
name: name.to_string(),
kind,
position: wave_position,
input: Some(serde_json::to_value(config)?),
})
.await?;
self.start_step(step.id, now).await?;
step_records.push((step.id, name.to_string(), config.clone()));
}
let mut join_set = JoinSet::new();
for (idx, (_id, _name, config)) in step_records.iter().enumerate() {
let provider = self.provider.clone();
let config = config.clone();
join_set.spawn(async move { (idx, execute_step_config(&config, &provider).await) });
}
let mut indexed_results: Vec<Option<Result<StepOutput, String>>> =
vec![None; step_records.len()];
let mut first_error: Option<EngineError> = None;
while let Some(join_result) = join_set.join_next().await {
let (idx, step_result) = match join_result {
Ok(r) => r,
Err(e) => {
if first_error.is_none() {
first_error = Some(EngineError::StepConfig(format!("join error: {e}")));
}
if fail_fast {
join_set.abort_all();
}
continue;
}
};
let (step_id, step_name, _) = &step_records[idx];
let completed_at = Utc::now();
match step_result {
Ok(output) => {
self.total_cost_usd += output.cost_usd;
self.total_duration_ms += output.duration_ms;
let debug_messages_json = output.debug_messages_json();
self.store
.update_step(
*step_id,
StepUpdate {
status: Some(StepStatus::Completed),
output: Some(output.output.clone()),
duration_ms: Some(output.duration_ms),
cost_usd: Some(output.cost_usd),
input_tokens: output.input_tokens,
output_tokens: output.output_tokens,
completed_at: Some(completed_at),
debug_messages: debug_messages_json,
..StepUpdate::default()
},
)
.await?;
info!(
run_id = %self.run_id,
step = %step_name,
duration_ms = output.duration_ms,
"parallel step completed"
);
indexed_results[idx] = Some(Ok(output));
}
Err(err) => {
let err_msg = err.to_string();
let debug_messages_json = extract_debug_messages_from_error(&err);
if let Err(store_err) = self
.store
.update_step(
*step_id,
StepUpdate {
status: Some(StepStatus::Failed),
error: Some(err_msg.clone()),
completed_at: Some(completed_at),
debug_messages: debug_messages_json,
..StepUpdate::default()
},
)
.await
{
tracing::error!(
step_id = %step_id,
error = %store_err,
"failed to persist parallel step failure"
);
}
indexed_results[idx] = Some(Err(err_msg.clone()));
if first_error.is_none() {
first_error = Some(err);
}
if fail_fast {
join_set.abort_all();
}
}
}
}
if let Some(err) = first_error {
return Err(err);
}
self.last_step_ids = step_records.iter().map(|(id, _, _)| *id).collect();
let results: Vec<ParallelStepResult> = step_records
.iter()
.enumerate()
.map(|(idx, (step_id, name, _))| {
let output = match indexed_results[idx].take() {
Some(Ok(o)) => o,
_ => unreachable!("all steps succeeded if no error returned"),
};
ParallelStepResult {
name: name.clone(),
output,
step_id: *step_id,
}
})
.collect();
Ok(results)
}
pub async fn shell(
&mut self,
name: &str,
config: ShellConfig,
) -> Result<StepOutput, EngineError> {
self.execute_step(name, StepKind::Shell, StepConfig::Shell(config))
.await
}
pub async fn http(
&mut self,
name: &str,
config: HttpConfig,
) -> Result<StepOutput, EngineError> {
self.execute_step(name, StepKind::Http, StepConfig::Http(config))
.await
}
pub async fn agent(
&mut self,
name: &str,
config: impl Into<AgentStepConfig>,
) -> Result<StepOutput, EngineError> {
self.execute_step(name, StepKind::Agent, StepConfig::Agent(config.into()))
.await
}
pub async fn approval(
&mut self,
name: &str,
config: ApprovalConfig,
) -> Result<(), EngineError> {
let position = self.position;
self.position += 1;
if let Some(existing) = self.replay_steps.get(&position)
&& existing.kind == StepKind::Approval
{
if existing.status.state == StepStatus::AwaitingApproval {
self.store
.update_step(
existing.id,
StepUpdate {
status: Some(StepStatus::Completed),
completed_at: Some(Utc::now()),
..StepUpdate::default()
},
)
.await?;
}
self.last_step_ids = vec![existing.id];
info!(
run_id = %self.run_id,
step = %name,
position,
"approval step replayed (approved)"
);
return Ok(());
}
let step = self
.store
.create_step(NewStep {
run_id: self.run_id,
name: name.to_string(),
kind: StepKind::Approval,
position,
input: Some(serde_json::to_value(&config)?),
})
.await?;
self.start_step(step.id, Utc::now()).await?;
self.store
.update_step(
step.id,
StepUpdate {
status: Some(StepStatus::AwaitingApproval),
..StepUpdate::default()
},
)
.await?;
self.last_step_ids = vec![step.id];
Err(EngineError::ApprovalRequired {
run_id: self.run_id,
step_id: step.id,
message: config.message().to_string(),
})
}
pub async fn operation(
&mut self,
name: &str,
op: &dyn Operation,
) -> Result<StepOutput, EngineError> {
let kind = StepKind::Custom(op.kind().to_string());
let position = self.position;
self.position += 1;
let step = self
.store
.create_step(NewStep {
run_id: self.run_id,
name: name.to_string(),
kind,
position,
input: op.input(),
})
.await?;
self.start_step(step.id, Utc::now()).await?;
let start = Instant::now();
match op.execute().await {
Ok(output_value) => {
let duration_ms = start.elapsed().as_millis() as u64;
self.total_duration_ms += duration_ms;
let completed_at = Utc::now();
self.store
.update_step(
step.id,
StepUpdate {
status: Some(StepStatus::Completed),
output: Some(output_value.clone()),
duration_ms: Some(duration_ms),
cost_usd: Some(Decimal::ZERO),
completed_at: Some(completed_at),
..StepUpdate::default()
},
)
.await?;
info!(
run_id = %self.run_id,
step = %name,
kind = op.kind(),
duration_ms,
"operation step completed"
);
self.last_step_ids = vec![step.id];
Ok(StepOutput {
output: output_value,
duration_ms,
cost_usd: Decimal::ZERO,
input_tokens: None,
output_tokens: None,
debug_messages: None,
})
}
Err(err) => {
let completed_at = Utc::now();
if let Err(store_err) = self
.store
.update_step(
step.id,
StepUpdate {
status: Some(StepStatus::Failed),
error: Some(err.to_string()),
completed_at: Some(completed_at),
..StepUpdate::default()
},
)
.await
{
error!(step_id = %step.id, error = %store_err, "failed to persist step failure");
}
Err(err)
}
}
}
pub async fn workflow(
&mut self,
handler: &dyn WorkflowHandler,
payload: Value,
) -> Result<StepOutput, EngineError> {
let config = WorkflowStepConfig::new(handler.name(), payload);
let position = self.position;
self.position += 1;
let step = self
.store
.create_step(NewStep {
run_id: self.run_id,
name: config.workflow_name.clone(),
kind: StepKind::Workflow,
position,
input: Some(serde_json::to_value(&config)?),
})
.await?;
self.start_step(step.id, Utc::now()).await?;
match self.execute_child_workflow(&config).await {
Ok(output) => {
self.total_cost_usd += output.cost_usd;
self.total_duration_ms += output.duration_ms;
let completed_at = Utc::now();
self.store
.update_step(
step.id,
StepUpdate {
status: Some(StepStatus::Completed),
output: Some(output.output.clone()),
duration_ms: Some(output.duration_ms),
cost_usd: Some(output.cost_usd),
completed_at: Some(completed_at),
..StepUpdate::default()
},
)
.await?;
info!(
run_id = %self.run_id,
child_workflow = %config.workflow_name,
duration_ms = output.duration_ms,
"workflow step completed"
);
self.last_step_ids = vec![step.id];
Ok(output)
}
Err(err) => {
let completed_at = Utc::now();
if let Err(store_err) = self
.store
.update_step(
step.id,
StepUpdate {
status: Some(StepStatus::Failed),
error: Some(err.to_string()),
completed_at: Some(completed_at),
..StepUpdate::default()
},
)
.await
{
error!(step_id = %step.id, error = %store_err, "failed to persist step failure");
}
Err(err)
}
}
}
async fn execute_child_workflow(
&self,
config: &WorkflowStepConfig,
) -> Result<StepOutput, EngineError> {
let resolver = self.handler_resolver.as_ref().ok_or_else(|| {
EngineError::InvalidWorkflow(
"sub-workflow requires a handler resolver (use Engine to execute)".to_string(),
)
})?;
let handler = resolver(&config.workflow_name).ok_or_else(|| {
EngineError::InvalidWorkflow(format!("no handler registered: {}", config.workflow_name))
})?;
let child_run = self
.store
.create_run(NewRun {
workflow_name: config.workflow_name.clone(),
trigger: TriggerKind::Workflow,
payload: config.payload.clone(),
max_retries: 0,
})
.await?;
let child_run_id = child_run.id;
info!(
parent_run_id = %self.run_id,
child_run_id = %child_run_id,
workflow = %config.workflow_name,
"child run created"
);
self.store
.update_run_status(child_run_id, RunStatus::Running)
.await?;
let run_start = Instant::now();
let mut child_ctx = WorkflowContext {
run_id: child_run_id,
store: self.store.clone(),
provider: self.provider.clone(),
handler_resolver: self.handler_resolver.clone(),
position: 0,
last_step_ids: Vec::new(),
total_cost_usd: Decimal::ZERO,
total_duration_ms: 0,
replay_steps: std::collections::HashMap::new(),
};
let result = handler.execute(&mut child_ctx).await;
let total_duration = run_start.elapsed().as_millis() as u64;
let completed_at = Utc::now();
match result {
Ok(()) => {
self.store
.update_run(
child_run_id,
RunUpdate {
status: Some(RunStatus::Completed),
cost_usd: Some(child_ctx.total_cost_usd),
duration_ms: Some(total_duration),
completed_at: Some(completed_at),
..RunUpdate::default()
},
)
.await?;
Ok(StepOutput {
output: serde_json::json!({
"run_id": child_run_id,
"workflow_name": config.workflow_name,
"status": RunStatus::Completed,
"cost_usd": child_ctx.total_cost_usd,
"duration_ms": total_duration,
}),
duration_ms: total_duration,
cost_usd: child_ctx.total_cost_usd,
input_tokens: None,
output_tokens: None,
debug_messages: None,
})
}
Err(err) => {
if let Err(store_err) = self
.store
.update_run(
child_run_id,
RunUpdate {
status: Some(RunStatus::Failed),
error: Some(err.to_string()),
cost_usd: Some(child_ctx.total_cost_usd),
duration_ms: Some(total_duration),
completed_at: Some(completed_at),
..RunUpdate::default()
},
)
.await
{
error!(
child_run_id = %child_run_id,
store_error = %store_err,
"failed to persist child run failure"
);
}
Err(err)
}
}
}
fn try_replay_step(&mut self, position: u32) -> Option<StepOutput> {
let step = self.replay_steps.get(&position)?;
if step.status.state != StepStatus::Completed {
return None;
}
let output = StepOutput {
output: step.output.clone().unwrap_or(Value::Null),
duration_ms: step.duration_ms,
cost_usd: step.cost_usd,
input_tokens: step.input_tokens,
output_tokens: step.output_tokens,
debug_messages: None,
};
self.total_cost_usd += output.cost_usd;
self.total_duration_ms += output.duration_ms;
self.last_step_ids = vec![step.id];
info!(
run_id = %self.run_id,
step = %step.name,
position,
"step replayed from previous execution"
);
Some(output)
}
async fn execute_step(
&mut self,
name: &str,
kind: StepKind,
config: StepConfig,
) -> Result<StepOutput, EngineError> {
let position = self.position;
self.position += 1;
if let Some(output) = self.try_replay_step(position) {
return Ok(output);
}
let step = self
.store
.create_step(NewStep {
run_id: self.run_id,
name: name.to_string(),
kind,
position,
input: Some(serde_json::to_value(&config)?),
})
.await?;
self.start_step(step.id, Utc::now()).await?;
match execute_step_config(&config, &self.provider).await {
Ok(output) => {
self.total_cost_usd += output.cost_usd;
self.total_duration_ms += output.duration_ms;
let debug_messages_json = output.debug_messages_json();
let completed_at = Utc::now();
self.store
.update_step(
step.id,
StepUpdate {
status: Some(StepStatus::Completed),
output: Some(output.output.clone()),
duration_ms: Some(output.duration_ms),
cost_usd: Some(output.cost_usd),
input_tokens: output.input_tokens,
output_tokens: output.output_tokens,
completed_at: Some(completed_at),
debug_messages: debug_messages_json,
..StepUpdate::default()
},
)
.await?;
info!(
run_id = %self.run_id,
step = %name,
duration_ms = output.duration_ms,
"step completed"
);
self.last_step_ids = vec![step.id];
Ok(output)
}
Err(err) => {
let completed_at = Utc::now();
let debug_messages_json = extract_debug_messages_from_error(&err);
if let Err(store_err) = self
.store
.update_step(
step.id,
StepUpdate {
status: Some(StepStatus::Failed),
error: Some(err.to_string()),
completed_at: Some(completed_at),
debug_messages: debug_messages_json,
..StepUpdate::default()
},
)
.await
{
tracing::error!(step_id = %step.id, error = %store_err, "failed to persist step failure");
}
Err(err)
}
}
}
async fn start_step(&self, step_id: Uuid, now: DateTime<Utc>) -> Result<(), EngineError> {
if !self.last_step_ids.is_empty() {
let deps: Vec<NewStepDependency> = self
.last_step_ids
.iter()
.map(|&depends_on| NewStepDependency {
step_id,
depends_on,
})
.collect();
self.store.create_step_dependencies(deps).await?;
}
self.store
.update_step(
step_id,
StepUpdate {
status: Some(StepStatus::Running),
started_at: Some(now),
..StepUpdate::default()
},
)
.await?;
Ok(())
}
pub fn store(&self) -> &Arc<dyn RunStore> {
&self.store
}
pub async fn payload(&self) -> Result<Value, EngineError> {
let run = self
.store
.get_run(self.run_id)
.await?
.ok_or(EngineError::Store(
ironflow_store::error::StoreError::RunNotFound(self.run_id),
))?;
Ok(run.payload)
}
}
impl fmt::Debug for WorkflowContext {
    /// Prints the run id, the next step position and the accumulated cost. The remaining
    /// fields (store, provider, resolver, replay state) are elided via `..`.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let mut out = f.debug_struct("WorkflowContext");
        out.field("run_id", &self.run_id);
        out.field("position", &self.position);
        out.field("total_cost_usd", &self.total_cost_usd);
        out.finish_non_exhaustive()
    }
}
/// Extracts the agent's debug messages from a schema-validation failure, as JSON.
///
/// Returns `None` for any other error, when the message list is empty, or when the
/// messages cannot be serialized.
fn extract_debug_messages_from_error(err: &EngineError) -> Option<Value> {
    match err {
        EngineError::Operation(OperationError::Agent(AgentError::SchemaValidation {
            debug_messages,
            ..
        })) if !debug_messages.is_empty() => serde_json::to_value(debug_messages).ok(),
        _ => None,
    }
}