use crate::time::Instant;
#[cfg(target_family = "wasm")]
use crate::tracing_wasm::WasmInstrument;
use crate::{
JunctureError, Node, State,
config::RunnableConfig,
graph::{RetryPolicy, execute_with_retry},
info_span,
interrupt::{InterruptContext, InterruptSignal, ResumeValue, Scratchpad},
pregel::context::TimeoutPolicy,
pregel::types::{PendingTask, SuperstepResult, TaskOutput},
runtime::Heartbeat,
};
use std::collections::HashMap;
use std::sync::Arc;
use tokio::sync::mpsc;
use tokio::task::JoinSet;
use tokio_util::sync::CancellationToken;
#[cfg(not(target_family = "wasm"))]
use tracing::Instrument;
#[cfg(feature = "otel")]
use tracing::{Level, event};
#[expect(
clippy::too_many_arguments,
reason = "mirrors execute_superstep's parameter list for consistency; all parameters are necessary for single-task execution"
)]
async fn try_execute_single_task_inline<S: State>(
pending_tasks: &[PendingTask<S>],
arc_state: &Arc<S>,
nodes: &indexmap::IndexMap<String, Arc<dyn Node<S>>>,
config: &RunnableConfig,
cancellation_token: &CancellationToken,
checkpointer: Option<&Arc<dyn crate::checkpoint::CheckpointSaver>>,
error_handler_map: &HashMap<String, String>,
retry_policies: &HashMap<String, RetryPolicy>,
timeout_policies: &HashMap<String, TimeoutPolicy>,
fallback_map: &HashMap<String, String>,
) -> Result<Option<SuperstepResult<S>>, JunctureError>
where
S::Update: serde::Serialize,
{
if pending_tasks.len() != 1 {
return Ok(None);
}
let task = &pending_tasks[0];
let has_error_handler = error_handler_map.contains_key(&task.node_name);
let has_retry = retry_policies.contains_key(&task.node_name);
let has_timeout = timeout_policies.contains_key(&task.node_name);
let has_fallback = fallback_map.contains_key(&task.node_name);
if has_error_handler || has_retry || has_timeout || has_fallback {
return Ok(None);
}
let node = nodes
.get(&task.node_name)
.ok_or_else(|| JunctureError::execution(format!("Node '{}' not found", task.node_name)))?;
let task_state: Arc<S> = task
.state_override
.clone()
.map_or_else(|| Arc::clone(arc_state), Arc::new);
let task_config = config.clone();
let task_id = task.id.clone();
let node_name = task.node_name.clone();
let task_trigger = task.trigger.clone();
if cancellation_token.is_cancelled() {
return Err(JunctureError::execution("Task cancelled"));
}
let start = Instant::now();
let result = if let Some(ref tracker) = task_config.budget_tracker {
let tracker_ref = Arc::clone(tracker);
crate::pregel::BUDGET_TRACKER
.scope(tracker_ref, node.call_arc(task_state, &task_config))
.await
} else {
node.call_arc(task_state, &task_config).await
};
let duration = start.elapsed();
if let Some(ref collector) = config.metrics_collector {
#[allow(
clippy::cast_precision_loss,
reason = "Milliseconds as f64 is sufficient for histogram metrics"
)]
collector.record_histogram("juncture.node.duration_ms", duration.as_millis() as f64);
}
#[cfg(feature = "otel")]
{
event!(
name: "juncture.node.execute.metrics",
Level::DEBUG,
node_name = %node_name,
duration_ms = duration.as_millis(),
success = result.is_ok(),
output_type = "inline",
);
};
if let Some(ref handler) = config.callback_handler {
match &result {
Ok(_) => {
let duration_ms = u64::try_from(duration.as_millis()).unwrap_or(u64::MAX);
handler.on_node_end(&node_name, &task_id, duration_ms);
}
Err(err) => {
handler.on_node_error(&node_name, err);
}
}
}
let command = result.inspect_err(|_e| {
cancellation_token.cancel();
})?;
let output = TaskOutput {
task_id,
node_name,
command,
duration,
trigger: task_trigger,
triggered_fields: Vec::new(),
error: None,
circuit_blocked: false,
};
if let Some(cp) = checkpointer
&& let Some(ref update) = output.command.update
{
let writes = serialize_pending_writes(&output.task_id, update);
if !writes.is_empty() {
let _ = cp.put_writes(config, writes, &output.task_id).await;
}
}
Ok(Some(SuperstepResult {
task_outputs: vec![output],
bubble_ups: Vec::new(),
}))
}
#[expect(
clippy::too_many_lines,
reason = "execute_superstep requires: early return, semaphore creation, interrupt context setup, task spawning with span creation, timeout/retry wrapping, and result collection. The length is justified by the complexity of parallel execution with proper error handling and observability."
)]
#[expect(
clippy::too_many_arguments,
reason = "execute_superstep requires: tasks, state, nodes, config, cancellation token, checkpointer, pending interrupts, scratchpad, error handler map, retry policies, timeout policies, and step. All are necessary for the multi-interrupt matching algorithm, error recovery, retry execution, and timeout enforcement."
)]
#[expect(
clippy::implicit_hasher,
reason = "error_handler_map, retry_policies, and timeout_policies use std::collections::HashMap as the canonical type matching the builder metadata extraction; no alternative hasher is needed."
)]
pub async fn execute_superstep<S: State>(
pending_tasks: &[PendingTask<S>],
state: &Arc<S>,
nodes: &indexmap::IndexMap<String, Arc<dyn Node<S>>>,
config: &RunnableConfig,
cancellation_token: &CancellationToken,
checkpointer: Option<&Arc<dyn crate::checkpoint::CheckpointSaver>>,
pending_interrupts: &[InterruptSignal],
scratchpad: &Scratchpad,
error_handler_map: &HashMap<String, String>,
retry_policies: &HashMap<String, RetryPolicy>,
timeout_policies: &HashMap<String, TimeoutPolicy>,
fallback_map: &HashMap<String, String>,
step: usize,
) -> Result<
(
SuperstepResult<S>,
mpsc::UnboundedReceiver<crate::interrupt::InterruptSignal>,
),
JunctureError,
>
where
S::Update: serde::Serialize,
{
if pending_tasks.is_empty() {
let (_interrupt_tx, interrupt_rx) = mpsc::unbounded_channel();
return Ok((SuperstepResult::empty(), interrupt_rx));
}
let resume_values =
match_resume_to_interrupts(&config.resume_value, pending_interrupts, scratchpad);
let (interrupt_tx, interrupt_rx) = mpsc::unbounded_channel();
let interrupt_context = Arc::new(InterruptContext::new(resume_values, interrupt_tx));
let arc_state: Arc<S> = Arc::clone(state);
if let Some(result) = try_execute_single_task_inline(
pending_tasks,
&arc_state,
nodes,
config,
cancellation_token,
checkpointer,
error_handler_map,
retry_policies,
timeout_policies,
fallback_map,
)
.await?
{
return Ok((result, interrupt_rx));
}
let semaphore = Arc::new(tokio::sync::Semaphore::new(config.max_parallel_tasks));
let mut join_set = JoinSet::new();
for task in pending_tasks {
let node = Arc::clone(nodes.get(&task.node_name).ok_or_else(|| {
JunctureError::execution(format!("Node '{}' not found", task.node_name))
})?);
let task_state: Arc<S> = task
.state_override
.clone()
.map_or_else(|| Arc::clone(&arc_state), Arc::new);
let mut task_config = config.clone();
let task_id = task.id.clone();
let node_name = task.node_name.clone();
let task_trigger = task.trigger.clone();
let permit = Arc::clone(&semaphore);
let token = cancellation_token.clone();
let ctx = Arc::clone(&interrupt_context);
let has_error_handler = error_handler_map.contains_key(&node_name);
let retry_policy = retry_policies.get(&task.node_name).cloned();
let timeout_policy = timeout_policies.get(&task.node_name).cloned();
let idle_watcher = timeout_policy.as_ref().and_then(|tp| {
tp.idle_timeout.map(|_| {
let (heartbeat, watcher) = Heartbeat::new_pair();
task_config.heartbeat = Some(heartbeat);
watcher
})
});
let callback_handler = task_config.callback_handler.clone();
let metrics_collector = task_config.metrics_collector.clone();
let span = info_span!(
"juncture.node.execute",
node_name = %node_name,
task_id = %task_id,
"juncture.step" = step,
"juncture.thread.id" = %config.thread_id.as_deref().unwrap_or(""),
"juncture.node.output_type" = tracing::field::Empty,
"juncture.node.duration_ms" = tracing::field::Empty,
"juncture.node.error" = tracing::field::Empty,
);
join_set.spawn(
async move {
let _permit = permit.acquire_owned().await.expect(
"Semaphore acquisition failed: semaphore should never be closed \
as it is owned by the PregelLoop and never dropped during execution",
);
if let Some(ref handler) = callback_handler {
handler.on_node_start(&node_name, &task_id);
}
let start = Instant::now();
let exec_node_name = node_name.clone();
let result = tokio::select! {
biased;
() = token.cancelled() => {
tracing::Span::current().record("juncture.node.error", "cancelled");
tracing::Span::current().record("otel.status_code", "ERROR");
let err = JunctureError::execution("Task cancelled");
if let Some(ref handler) = callback_handler {
handler.on_node_error(&node_name, &err);
}
return Err((node_name.clone(), err));
}
result = async {
let timeout_node_name = exec_node_name.clone();
let inner_future = async {
if let Some(ref policy) = retry_policy {
let ctx_ref = Arc::clone(&ctx);
crate::interrupt::INTERRUPT_CONTEXT.scope(ctx_ref, async move {
if let Some(ref tracker) = task_config.budget_tracker {
let tracker_ref = Arc::clone(tracker);
crate::pregel::BUDGET_TRACKER.scope(tracker_ref, async move {
execute_with_retry(
&exec_node_name,
policy,
|s, cfg| node.call(s, cfg),
&*task_state,
&task_config,
)
.await
}).await
} else {
execute_with_retry(
&exec_node_name,
policy,
|s, cfg| node.call(s, cfg),
&*task_state,
&task_config,
)
.await
}
}).await
} else {
crate::interrupt::INTERRUPT_CONTEXT.scope(ctx, async move {
if let Some(ref tracker) = task_config.budget_tracker {
let tracker_ref = Arc::clone(tracker);
crate::pregel::BUDGET_TRACKER.scope(tracker_ref, async move {
node.call_arc(Arc::clone(&task_state), &task_config).await
}).await
} else {
node.call_arc(Arc::clone(&task_state), &task_config).await
}
}).await
}
};
if let Some(ref tp) = timeout_policy {
let timeout_result = if let (Some(idle_to), Some(mut watcher)) = (
tp.idle_timeout,
idle_watcher,
) {
let to_name = timeout_node_name.clone();
tokio::time::timeout(tp.run_timeout, async move {
tokio::pin!(inner_future);
loop {
tokio::select! {
result = &mut inner_future => return result,
() = tokio::time::sleep(idle_to) => {
if !watcher.is_alive(idle_to) {
return Err(
crate::JunctureError::node_timeout(
crate::error::NodeTimeoutError::IdleTimeout {
node: to_name,
timeout: u64::try_from(
idle_to.as_millis(),
).unwrap_or(u64::MAX),
},
),
);
}
}
}
}
})
.await
} else {
tokio::time::timeout(tp.run_timeout, inner_future)
.await
};
timeout_result.map_or_else(
|_| {
Err(crate::JunctureError::node_timeout(
crate::error::NodeTimeoutError::RunTimeout {
node: timeout_node_name,
timeout: u64::try_from(tp.run_timeout.as_millis())
.unwrap_or(u64::MAX),
},
))
},
std::convert::identity,
)
} else {
inner_future.await
}
} => result,
};
let duration = start.elapsed();
let output_type = result.as_ref().map_or(
if has_error_handler {
"error_handler"
} else {
"error"
},
|command| {
if command.resume.is_some() {
"interrupt"
} else if matches!(command.goto, crate::command::Goto::Send(_)) {
"send"
} else if matches!(command.goto, crate::command::Goto::End) {
"end"
} else if !matches!(command.goto, crate::command::Goto::None) {
"goto"
} else if command.update.is_some() {
"update"
} else {
"none"
}
},
);
tracing::Span::current().record("juncture.node.output_type", output_type);
let duration_ms = u64::try_from(duration.as_millis()).unwrap_or(u64::MAX);
tracing::Span::current().record("juncture.node.duration_ms", duration_ms);
if let Err(ref e) = result {
tracing::Span::current()
.record("juncture.node.error", tracing::field::display(e));
tracing::Span::current().record("otel.status_code", "ERROR");
}
if let Some(ref collector) = metrics_collector {
#[allow(
clippy::cast_precision_loss,
reason = "Milliseconds as f64 is sufficient for histogram metrics; sub-millisecond precision is not required for node duration tracking"
)]
collector.record_histogram(
"juncture.node.duration_ms",
duration.as_millis() as f64,
);
}
#[cfg(feature = "otel")]
{
event!(
name: "juncture.node.execute.metrics",
Level::DEBUG,
node_name = %node_name,
duration_ms = duration.as_millis(),
success = result.is_ok(),
output_type = %output_type,
);
};
if let Some(ref handler) = callback_handler {
match &result {
Ok(_) => {
handler.on_node_end(&node_name, &task_id, duration_ms);
}
Err(err) => {
handler.on_node_error(&node_name, err);
}
}
}
result
.map(|command| TaskOutput {
task_id,
node_name: node_name.clone(),
command,
duration,
trigger: task_trigger,
triggered_fields: Vec::new(), error: None,
circuit_blocked: false,
})
.map_err(|e| (node_name, e))
}
.instrument(span),
);
}
let mut task_outputs = Vec::new();
while let Some(result) = join_set.join_next().await {
match result {
Ok(Ok(output)) => {
if let Some(cp) = checkpointer
&& let Some(ref update) = output.command.update
{
let writes = serialize_pending_writes(&output.task_id, update);
if !writes.is_empty() {
let _ = cp.put_writes(config, writes, &output.task_id).await;
}
}
task_outputs.push(output);
}
Ok(Err((failed_node_name, error))) => {
let has_error_handler = error_handler_map.contains_key(&failed_node_name);
let has_fallback = fallback_map.contains_key(&failed_node_name);
if has_error_handler || has_fallback {
let recovery_type = if has_fallback {
"fallback"
} else {
"error_handler"
};
tracing::warn!(
name: "juncture.node.error.recovery_scheduled",
node_name = %failed_node_name,
recovery_type = %recovery_type,
error = %error,
"Node failed with recovery registered, scheduling recovery"
);
task_outputs.push(TaskOutput {
task_id: uuid::Uuid::new_v4().to_string(),
node_name: failed_node_name,
command: crate::Command::default(),
duration: std::time::Duration::ZERO,
trigger: crate::pregel::types::TaskTrigger::Pull,
triggered_fields: Vec::new(),
error: Some(error),
circuit_blocked: false,
});
} else {
cancellation_token.cancel();
join_set.shutdown().await;
return Err(error);
}
}
Err(join_error) => {
cancellation_token.cancel();
join_set.shutdown().await;
return Err(JunctureError::execution(format!(
"Task panicked: {join_error}"
)));
}
}
}
Ok((
SuperstepResult {
task_outputs,
bubble_ups: Vec::new(),
},
interrupt_rx,
))
}
/// Maps the caller-supplied resume value onto positional per-interrupt slots.
///
/// The returned vector is handed to `InterruptContext::new`. Entry `i`
/// corresponds to `pending_interrupts[i]`, and for `ByNamespace` also to the
/// numeric key `"i"`. `None` means no resume value for that slot.
///
/// * `None` resume value: an empty vector.
/// * `Single(v)`:
///   - with no pending interrupts, `[Some(v)]`;
///   - otherwise one entry per pending interrupt, which is `Some(Null)` when
///     the scratchpad reports a null resume for that interrupt's id and
///     `Some(v)` in every other case.
/// * `ById(map)`: one entry per pending interrupt. The entry is the mapped
///   value for its id if present, else `Some(Null)` when the scratchpad
///   reports a null resume for that id, else `None`.
/// * `ByNamespace(map)`:
///   - numeric keys are treated as slot indices and non-numeric keys are
///     ignored;
///   - the vector is long enough for every pending interrupt and every
///     addressable key (at least one slot);
///   - pending-interrupt slots left empty fall back to the scratchpad's null
///     resume.
#[allow(
    clippy::ref_option,
    reason = "config stores Option<ResumeValue> and we need to pass it by reference"
)]
#[must_use]
fn match_resume_to_interrupts(
    resume_value: &Option<ResumeValue>,
    pending_interrupts: &[InterruptSignal],
    scratchpad: &Scratchpad,
) -> Vec<Option<serde_json::Value>> {
    let Some(rv) = resume_value else {
        return Vec::new();
    };
    match rv {
        ResumeValue::Single(value) => {
            if pending_interrupts.is_empty() {
                vec![Some(value.clone())]
            } else {
                pending_interrupts
                    .iter()
                    .map(|signal| {
                        if let Some(ref id) = signal.id
                            && scratchpad.get_null_resume(id)
                        {
                            return Some(serde_json::Value::Null);
                        }
                        Some(value.clone())
                    })
                    .collect()
            }
        }
        ResumeValue::ById(map) => pending_interrupts
            .iter()
            .map(|signal| {
                if let Some(ref id) = signal.id {
                    if let Some(value) = map.get(id) {
                        return Some(value.clone());
                    }
                    if scratchpad.get_null_resume(id) {
                        return Some(serde_json::Value::Null);
                    }
                }
                None
            })
            .collect(),
        ResumeValue::ByNamespace(map) => {
            // Slots implied by the numeric keys (largest index + 1). A key of
            // `usize::MAX` is dropped by `checked_add`, since no vector can
            // address it. The previous `max_index + 1` overflowed there: it
            // panicked in debug builds. In release builds it wrapped to 0,
            // which could shrink the vector and silently drop other values.
            // NOTE(review): a very large numeric key still allocates that many
            // slots; consider bounding it if resume values come from
            // untrusted clients.
            let keyed_len = map
                .keys()
                .filter_map(|k| k.parse::<usize>().ok()?.checked_add(1))
                .max()
                .unwrap_or(1);
            let mut values = vec![None; pending_interrupts.len().max(keyed_len)];
            for (key, value) in map {
                if let Ok(index) = key.parse::<usize>()
                    && let Some(slot) = values.get_mut(index)
                {
                    *slot = Some(value.clone());
                }
            }
            // `values` is at least as long as `pending_interrupts`, so the
            // zip visits every pending interrupt.
            for (slot, signal) in values.iter_mut().zip(pending_interrupts) {
                if slot.is_none()
                    && let Some(ref id) = signal.id
                    && scratchpad.get_null_resume(id)
                {
                    *slot = Some(serde_json::Value::Null);
                }
            }
            values
        }
    }
}
fn serialize_pending_writes<U>(task_id: &str, update: &U) -> Vec<crate::checkpoint::PendingWrite>
where
U: serde::Serialize,
{
let Ok(value) = serde_json::to_value(update) else {
return Vec::new();
};
let Some(obj) = value.as_object() else {
return Vec::new();
};
obj.iter()
.filter(|(_, v)| !v.is_null())
.map(|(channel, value)| crate::checkpoint::PendingWrite {
task_id: task_id.to_string(),
channel: channel.clone(),
value: value.clone(),
})
.collect()
}
#[cfg(test)]
mod tests {
use super::*;
use crate::node::{IntoNode, NodeFnCommand};
use crate::state::FieldVersions;
use chrono::Utc;
#[tokio::test]
async fn test_execute_superstep_empty() {
let state = TestState;
let nodes = indexmap::IndexMap::new();
let config = RunnableConfig::new();
let token = CancellationToken::new();
let pending_interrupts = vec![];
let scratchpad = Scratchpad::new();
let (result, _rx) = execute_superstep(
&[],
&Arc::new(state.clone()),
&nodes,
&config,
&token,
None,
&pending_interrupts,
&scratchpad,
&HashMap::new(),
&HashMap::new(),
&HashMap::new(),
&HashMap::new(),
0,
)
.await
.unwrap();
assert!(result.is_empty());
}
#[tokio::test]
async fn test_execute_superstep_single_task() {
let state = TestState;
let mut nodes = indexmap::IndexMap::new();
nodes.insert(
"test_node".to_string(),
NodeFnCommand(|_s: &TestState| async move { Ok(crate::Command::end()) })
.into_node("test_node"),
);
let config = RunnableConfig::new();
let token = CancellationToken::new();
let pending_interrupts = vec![];
let scratchpad = Scratchpad::new();
let tasks = vec![PendingTask::pull(
uuid::Uuid::new_v4().to_string(),
"test_node".to_string(),
)];
let (result, _rx) = execute_superstep(
&tasks,
&Arc::new(state.clone()),
&nodes,
&config,
&token,
None,
&pending_interrupts,
&scratchpad,
&HashMap::new(),
&HashMap::new(),
&HashMap::new(),
&HashMap::new(),
0,
)
.await
.unwrap();
assert_eq!(result.len(), 1);
assert_eq!(result.task_outputs[0].node_name, "test_node");
}
#[tokio::test]
async fn test_execute_superstep_parallel_tasks() {
let state = TestState;
let mut nodes = indexmap::IndexMap::new();
for i in 0..3 {
nodes.insert(
format!("node_{i}"),
NodeFnCommand(move |_s: &TestState| async move { Ok(crate::Command::end()) })
.into_node(format!("node_{i}").as_str()),
);
}
let config = RunnableConfig::new();
let token = CancellationToken::new();
let pending_interrupts = vec![];
let scratchpad = Scratchpad::new();
let tasks: Vec<PendingTask<TestState>> = (0..3)
.map(|i| PendingTask::pull(uuid::Uuid::new_v4().to_string(), format!("node_{i}")))
.collect();
let (result, _rx) = execute_superstep(
&tasks,
&Arc::new(state.clone()),
&nodes,
&config,
&token,
None,
&pending_interrupts,
&scratchpad,
&HashMap::new(),
&HashMap::new(),
&HashMap::new(),
&HashMap::new(),
0,
)
.await
.unwrap();
assert_eq!(result.len(), 3);
}
#[tokio::test]
async fn test_execute_superstep_cancellation() {
let state = TestState;
let mut nodes = indexmap::IndexMap::new();
nodes.insert(
"slow_node".to_string(),
NodeFnCommand(|_s: &TestState| async move {
tokio::time::sleep(std::time::Duration::from_secs(10)).await;
Ok(crate::Command::end())
})
.into_node("slow_node"),
);
let config = RunnableConfig::new();
let token = CancellationToken::new();
let pending_interrupts = vec![];
let scratchpad = Scratchpad::new();
let tasks = vec![PendingTask::pull(
uuid::Uuid::new_v4().to_string(),
"slow_node".to_string(),
)];
token.cancel();
let result = execute_superstep(
&tasks,
&Arc::new(state.clone()),
&nodes,
&config,
&token,
None,
&pending_interrupts,
&scratchpad,
&HashMap::new(),
&HashMap::new(),
&HashMap::new(),
&HashMap::new(),
0,
)
.await;
assert!(result.is_err());
assert!(result.unwrap_err().is_execution());
}
#[tokio::test]
async fn test_execute_superstep_node_not_found() {
let state = TestState;
let nodes = indexmap::IndexMap::new();
let config = RunnableConfig::new();
let token = CancellationToken::new();
let pending_interrupts = vec![];
let scratchpad = Scratchpad::new();
let tasks = vec![PendingTask::pull(
uuid::Uuid::new_v4().to_string(),
"nonexistent".to_string(),
)];
let result = execute_superstep(
&tasks,
&Arc::new(state.clone()),
&nodes,
&config,
&token,
None,
&pending_interrupts,
&scratchpad,
&HashMap::new(),
&HashMap::new(),
&HashMap::new(),
&HashMap::new(),
0,
)
.await;
assert!(result.is_err());
assert!(result.unwrap_err().is_execution());
}
#[test]
fn test_match_resume_none_returns_empty() {
let scratchpad = Scratchpad::new();
let result = match_resume_to_interrupts(&None, &[], &scratchpad);
assert!(result.is_empty());
}
#[test]
fn test_match_single_value_no_pending_interrupts() {
let scratchpad = Scratchpad::new();
let resume = Some(ResumeValue::Single(serde_json::json!("yes")));
let result = match_resume_to_interrupts(&resume, &[], &scratchpad);
assert_eq!(result, vec![Some(serde_json::json!("yes"))]);
}
#[test]
fn test_match_single_value_with_pending_interrupts() {
let scratchpad = Scratchpad::new();
let resume = Some(ResumeValue::Single(serde_json::json!("approve")));
let interrupts = vec![
InterruptSignal {
index: 0,
id: Some("id-0".to_string()),
payload: serde_json::Value::Null,
timestamp: Utc::now(),
},
InterruptSignal {
index: 1,
id: Some("id-1".to_string()),
payload: serde_json::Value::Null,
timestamp: Utc::now(),
},
];
let result = match_resume_to_interrupts(&resume, &interrupts, &scratchpad);
assert_eq!(
result,
vec![
Some(serde_json::json!("approve")),
Some(serde_json::json!("approve")),
]
);
}
#[test]
fn test_match_single_value_with_scratchpad_null_resume() {
let mut scratchpad = Scratchpad::new();
scratchpad.mark_interrupt_processed("id-0");
let resume = Some(ResumeValue::Single(serde_json::json!("approve")));
let interrupts = vec![
InterruptSignal {
index: 0,
id: Some("id-0".to_string()),
payload: serde_json::Value::Null,
timestamp: Utc::now(),
},
InterruptSignal {
index: 1,
id: Some("id-1".to_string()),
payload: serde_json::Value::Null,
timestamp: Utc::now(),
},
];
let result = match_resume_to_interrupts(&resume, &interrupts, &scratchpad);
assert_eq!(
result,
vec![
Some(serde_json::Value::Null), Some(serde_json::json!("approve")), ]
);
}
#[test]
fn test_match_single_null_value_with_scratchpad() {
let mut scratchpad = Scratchpad::new();
scratchpad.mark_interrupt_processed("id-1");
let resume = Some(ResumeValue::Single(serde_json::Value::Null));
let interrupts = vec![
InterruptSignal {
index: 0,
id: Some("id-0".to_string()),
payload: serde_json::Value::Null,
timestamp: Utc::now(),
},
InterruptSignal {
index: 1,
id: Some("id-1".to_string()),
payload: serde_json::Value::Null,
timestamp: Utc::now(),
},
];
let result = match_resume_to_interrupts(&resume, &interrupts, &scratchpad);
assert_eq!(
result,
vec![Some(serde_json::Value::Null), Some(serde_json::Value::Null),]
);
}
#[test]
fn test_match_by_id_with_matching_ids() {
let scratchpad = Scratchpad::new();
let mut map = std::collections::HashMap::new();
map.insert("id-0".to_string(), serde_json::json!("value-0"));
map.insert("id-1".to_string(), serde_json::json!("value-1"));
let resume = Some(ResumeValue::ById(map));
let interrupts = vec![
InterruptSignal {
index: 0,
id: Some("id-0".to_string()),
payload: serde_json::Value::Null,
timestamp: Utc::now(),
},
InterruptSignal {
index: 1,
id: Some("id-1".to_string()),
payload: serde_json::Value::Null,
timestamp: Utc::now(),
},
];
let result = match_resume_to_interrupts(&resume, &interrupts, &scratchpad);
assert_eq!(
result,
vec![
Some(serde_json::json!("value-0")),
Some(serde_json::json!("value-1")),
]
);
}
#[test]
fn test_match_by_id_with_scratchpad_null_resume() {
let mut scratchpad = Scratchpad::new();
scratchpad.mark_interrupt_processed("id-0");
let mut map = std::collections::HashMap::new();
map.insert("id-1".to_string(), serde_json::json!("value-1"));
let resume = Some(ResumeValue::ById(map));
let interrupts = vec![
InterruptSignal {
index: 0,
id: Some("id-0".to_string()),
payload: serde_json::Value::Null,
timestamp: Utc::now(),
},
InterruptSignal {
index: 1,
id: Some("id-1".to_string()),
payload: serde_json::Value::Null,
timestamp: Utc::now(),
},
InterruptSignal {
index: 2,
id: Some("id-2".to_string()),
payload: serde_json::Value::Null,
timestamp: Utc::now(),
},
];
let result = match_resume_to_interrupts(&resume, &interrupts, &scratchpad);
assert_eq!(
result,
vec![
Some(serde_json::Value::Null), Some(serde_json::json!("value-1")), None, ]
);
}
#[test]
fn test_match_by_id_no_match_returns_none() {
let scratchpad = Scratchpad::new();
let mut map = std::collections::HashMap::new();
map.insert("other-id".to_string(), serde_json::json!("value"));
let resume = Some(ResumeValue::ById(map));
let interrupts = vec![InterruptSignal {
timestamp: Utc::now(),
index: 0,
id: Some("id-0".to_string()),
payload: serde_json::Value::Null,
}];
let result = match_resume_to_interrupts(&resume, &interrupts, &scratchpad);
assert_eq!(result, vec![None]);
}
#[test]
fn test_match_by_namespace_index_mapping() {
let scratchpad = Scratchpad::new();
let mut map = std::collections::HashMap::new();
map.insert("0".to_string(), serde_json::json!("first"));
map.insert("2".to_string(), serde_json::json!("third"));
let resume = Some(ResumeValue::ByNamespace(map));
let interrupts = vec![
InterruptSignal {
index: 0,
id: Some("id-0".to_string()),
payload: serde_json::Value::Null,
timestamp: Utc::now(),
},
InterruptSignal {
index: 1,
id: Some("id-1".to_string()),
payload: serde_json::Value::Null,
timestamp: Utc::now(),
},
InterruptSignal {
index: 2,
id: Some("id-2".to_string()),
payload: serde_json::Value::Null,
timestamp: Utc::now(),
},
];
let result = match_resume_to_interrupts(&resume, &interrupts, &scratchpad);
assert_eq!(
result,
vec![
Some(serde_json::json!("first")),
None,
Some(serde_json::json!("third")),
]
);
}
#[test]
fn test_match_by_namespace_with_scratchpad_fill() {
let mut scratchpad = Scratchpad::new();
scratchpad.mark_interrupt_processed("id-1");
let mut map = std::collections::HashMap::new();
map.insert("0".to_string(), serde_json::json!("first"));
let resume = Some(ResumeValue::ByNamespace(map));
let interrupts = vec![
InterruptSignal {
index: 0,
id: Some("id-0".to_string()),
payload: serde_json::Value::Null,
timestamp: Utc::now(),
},
InterruptSignal {
index: 1,
id: Some("id-1".to_string()),
payload: serde_json::Value::Null,
timestamp: Utc::now(),
},
];
let result = match_resume_to_interrupts(&resume, &interrupts, &scratchpad);
assert_eq!(
result,
vec![
Some(serde_json::json!("first")),
Some(serde_json::Value::Null), ]
);
}
#[test]
fn test_match_by_namespace_no_pending_interrupts() {
let scratchpad = Scratchpad::new();
let mut map = std::collections::HashMap::new();
map.insert("0".to_string(), serde_json::json!("first"));
map.insert("2".to_string(), serde_json::json!("third"));
let resume = Some(ResumeValue::ByNamespace(map));
let result = match_resume_to_interrupts(&resume, &[], &scratchpad);
assert_eq!(
result,
vec![
Some(serde_json::json!("first")),
None,
Some(serde_json::json!("third")),
]
);
}
#[test]
fn test_match_by_id_signal_without_id() {
let scratchpad = Scratchpad::new();
let mut map = std::collections::HashMap::new();
map.insert("id-0".to_string(), serde_json::json!("value"));
let resume = Some(ResumeValue::ById(map));
let interrupts = vec![InterruptSignal {
index: 0,
id: None,
payload: serde_json::Value::Null,
timestamp: Utc::now(),
}];
let result = match_resume_to_interrupts(&resume, &interrupts, &scratchpad);
assert_eq!(result, vec![None]);
}
#[derive(Clone, Debug, Default, serde::Serialize, serde::Deserialize)]
struct TestState;
impl State for TestState {
type Update = TestUpdate;
type FieldVersions = FieldVersions;
fn apply(&mut self, _: Self::Update) -> crate::FieldsChanged {
crate::FieldsChanged(0)
}
fn reset_ephemeral(&mut self) {}
}
#[derive(Clone, Debug, Default, serde::Serialize)]
struct TestUpdate;
#[test]
fn test_serialize_pending_writes_unit_update() {
let update = TestUpdate;
let writes = serialize_pending_writes("task-1", &update);
assert!(writes.is_empty());
}
#[test]
fn test_serialize_pending_writes_with_fields() {
#[derive(serde::Serialize)]
struct SampleUpdate {
messages: Option<Vec<String>>,
count: Option<u64>,
untouched: Option<String>,
}
let update = SampleUpdate {
messages: Some(vec!["hello".to_string()]),
count: Some(42),
untouched: None,
};
let writes = serialize_pending_writes("task-99", &update);
assert_eq!(writes.len(), 2);
let channels: std::collections::HashSet<&str> =
writes.iter().map(|w| w.channel.as_str()).collect();
assert!(channels.contains("messages"));
assert!(channels.contains("count"));
assert!(!channels.contains("untouched"));
for w in &writes {
assert_eq!(w.task_id, "task-99");
}
let msg_write = writes
.iter()
.find(|w| w.channel == "messages")
.expect("messages write");
assert_eq!(msg_write.value, serde_json::json!(["hello"]));
let count_write = writes
.iter()
.find(|w| w.channel == "count")
.expect("count write");
assert_eq!(count_write.value, serde_json::json!(42));
}
#[test]
fn test_serialize_pending_writes_all_none() {
#[derive(serde::Serialize)]
struct EmptyUpdate {
a: Option<String>,
b: Option<u64>,
}
let update = EmptyUpdate { a: None, b: None };
let writes = serialize_pending_writes("task-x", &update);
assert!(writes.is_empty());
}
#[tokio::test]
async fn test_execute_superstep_with_retry_succeeds_after_failure() {
use std::sync::atomic::{AtomicU32, Ordering};
let state = TestState;
let attempt_count = Arc::new(AtomicU32::new(0));
let attempt_clone = Arc::clone(&attempt_count);
let mut nodes = indexmap::IndexMap::new();
nodes.insert(
"flaky_node".to_string(),
NodeFnCommand(move |_s: &TestState| {
let counter = Arc::clone(&attempt_clone);
async move {
let n = counter.fetch_add(1, Ordering::Relaxed);
if n == 0 {
Err(crate::JunctureError::execution("transient failure"))
} else {
Ok(crate::Command::end())
}
}
})
.into_node("flaky_node"),
);
let config = RunnableConfig::new();
let token = CancellationToken::new();
let pending_interrupts = vec![];
let scratchpad = Scratchpad::new();
let retry_policies = {
let mut map = HashMap::new();
map.insert(
"flaky_node".to_string(),
RetryPolicy {
max_attempts: 3,
initial_interval: std::time::Duration::from_millis(1),
backoff_factor: 2.0,
max_interval: std::time::Duration::from_secs(1),
jitter: false,
retry_on: None,
},
);
map
};
let tasks = vec![PendingTask::pull(
uuid::Uuid::new_v4().to_string(),
"flaky_node".to_string(),
)];
let (result, _rx) = execute_superstep(
&tasks,
&Arc::new(state.clone()),
&nodes,
&config,
&token,
None,
&pending_interrupts,
&scratchpad,
&HashMap::new(),
&retry_policies,
&HashMap::new(),
&HashMap::new(),
0,
)
.await
.unwrap();
assert_eq!(result.len(), 1);
assert!(result.task_outputs[0].error.is_none());
assert_eq!(
attempt_count.load(Ordering::Relaxed),
2,
"should succeed on second attempt"
);
}
#[tokio::test]
async fn test_execute_superstep_with_retry_exhausts_attempts() {
use std::sync::atomic::{AtomicU32, Ordering};
let state = TestState;
let attempt_count = Arc::new(AtomicU32::new(0));
let attempt_clone = Arc::clone(&attempt_count);
let mut nodes = indexmap::IndexMap::new();
nodes.insert(
"always_fail".to_string(),
NodeFnCommand(move |_s: &TestState| {
let counter = Arc::clone(&attempt_clone);
async move {
counter.fetch_add(1, Ordering::Relaxed);
Err(crate::JunctureError::execution("persistent failure"))
}
})
.into_node("always_fail"),
);
let config = RunnableConfig::new();
let token = CancellationToken::new();
let pending_interrupts = vec![];
let scratchpad = Scratchpad::new();
let retry_policies = {
let mut map = HashMap::new();
map.insert(
"always_fail".to_string(),
RetryPolicy {
max_attempts: 3,
initial_interval: std::time::Duration::from_millis(1),
backoff_factor: 2.0,
max_interval: std::time::Duration::from_secs(1),
jitter: false,
retry_on: None,
},
);
map
};
let tasks = vec![PendingTask::pull(
uuid::Uuid::new_v4().to_string(),
"always_fail".to_string(),
)];
let result = execute_superstep(
&tasks,
&Arc::new(state.clone()),
&nodes,
&config,
&token,
None,
&pending_interrupts,
&scratchpad,
&HashMap::new(),
&retry_policies,
&HashMap::new(),
&HashMap::new(),
0,
)
.await;
assert!(result.is_err());
assert!(result.unwrap_err().is_execution());
assert_eq!(
attempt_count.load(Ordering::Relaxed),
3,
"should attempt exactly max_attempts times"
);
}
/// A node that fails with a cancellation error must not be retried, even though
/// a 3-attempt retry policy is configured for it. The error must also reach the
/// caller still classified as cancelled.
#[tokio::test]
async fn test_execute_superstep_retry_does_not_retry_cancelled() {
    use std::sync::atomic::{AtomicU32, Ordering};
    let state = TestState;
    let attempt_count = Arc::new(AtomicU32::new(0));
    let attempt_clone = Arc::clone(&attempt_count);
    let mut nodes = indexmap::IndexMap::new();
    nodes.insert(
        "cancel_node".to_string(),
        NodeFnCommand(move |_s: &TestState| {
            let counter = Arc::clone(&attempt_clone);
            async move {
                counter.fetch_add(1, Ordering::Relaxed);
                Err(crate::JunctureError::cancelled())
            }
        })
        .into_node("cancel_node"),
    );
    let config = RunnableConfig::new();
    let token = CancellationToken::new();
    let pending_interrupts = vec![];
    let scratchpad = Scratchpad::new();
    let retry_policies = {
        let mut map = HashMap::new();
        map.insert(
            "cancel_node".to_string(),
            RetryPolicy {
                max_attempts: 3,
                initial_interval: std::time::Duration::from_millis(1),
                backoff_factor: 2.0,
                max_interval: std::time::Duration::from_secs(1),
                jitter: false,
                retry_on: None,
            },
        );
        map
    };
    let tasks = vec![PendingTask::pull(
        uuid::Uuid::new_v4().to_string(),
        "cancel_node".to_string(),
    )];
    let result = execute_superstep(
        &tasks,
        &Arc::new(state.clone()),
        &nodes,
        &config,
        &token,
        None,
        &pending_interrupts,
        &scratchpad,
        &HashMap::new(),
        &retry_policies,
        &HashMap::new(),
        &HashMap::new(),
        0,
    )
    .await;
    let err = result.expect_err("cancelled node should fail the superstep");
    // This checks the error *kind*; the retry count is checked separately below.
    assert!(
        err.is_cancelled(),
        "cancelled error should propagate as cancelled, got: {err}"
    );
    assert_eq!(
        attempt_count.load(Ordering::Relaxed),
        1,
        "cancelled errors should not be retried"
    );
}
/// Retry policies are applied per node. `node_a` has a 3-attempt policy and is
/// invoked three times. `node_b` has only an error handler and no retry policy,
/// so it runs exactly once. The superstep as a whole fails with an execution error.
#[tokio::test]
async fn test_execute_superstep_retry_only_applies_to_configured_node() {
    use std::sync::atomic::{AtomicU32, Ordering};

    let runs_a = Arc::new(AtomicU32::new(0));
    let runs_b = Arc::new(AtomicU32::new(0));
    let handle_a = Arc::clone(&runs_a);
    let handle_b = Arc::clone(&runs_b);

    let failing_a = NodeFnCommand(move |_s: &TestState| {
        let tally = Arc::clone(&handle_a);
        async move {
            tally.fetch_add(1, Ordering::Relaxed);
            Err(crate::JunctureError::execution("node_a fails"))
        }
    })
    .into_node("node_a");
    let failing_b = NodeFnCommand(move |_s: &TestState| {
        let tally = Arc::clone(&handle_b);
        async move {
            tally.fetch_add(1, Ordering::Relaxed);
            Err(crate::JunctureError::execution("node_b fails"))
        }
    })
    .into_node("node_b");

    // Insertion order (node_a, then node_b) is kept for the IndexMap.
    let mut nodes = indexmap::IndexMap::new();
    nodes.insert("node_a".to_string(), failing_a);
    nodes.insert("node_b".to_string(), failing_b);

    let retry_policies = HashMap::from([(
        "node_a".to_string(),
        RetryPolicy {
            max_attempts: 3,
            initial_interval: std::time::Duration::from_millis(1),
            backoff_factor: 2.0,
            max_interval: std::time::Duration::from_secs(1),
            jitter: false,
            retry_on: None,
        },
    )]);
    let error_handlers = HashMap::from([("node_b".to_string(), "handler".to_string())]);

    let tasks = vec![
        PendingTask::pull(uuid::Uuid::new_v4().to_string(), "node_a".to_string()),
        PendingTask::pull(uuid::Uuid::new_v4().to_string(), "node_b".to_string()),
    ];
    let shared_state = Arc::new(TestState);
    let config = RunnableConfig::new();
    let token = CancellationToken::new();
    let pending_interrupts = vec![];
    let scratchpad = Scratchpad::new();

    let outcome = execute_superstep(
        &tasks,
        &shared_state,
        &nodes,
        &config,
        &token,
        None,
        &pending_interrupts,
        &scratchpad,
        &error_handlers,
        &retry_policies,
        &HashMap::new(), // timeout policies
        &HashMap::new(),
        0,
    )
    .await;

    let err = outcome.unwrap_err();
    assert!(err.is_execution(), "expected execution error, got: {err}");
    assert_eq!(
        runs_a.load(Ordering::Relaxed),
        3,
        "node_a should retry max_attempts times"
    );
    assert_eq!(
        runs_b.load(Ordering::Relaxed),
        1,
        "node_b should execute only once (no retry policy)"
    );
}
#[tokio::test]
async fn test_execute_superstep_no_retry_policy_same_behavior() {
let state = TestState;
let mut nodes = indexmap::IndexMap::new();
nodes.insert(
"simple_node".to_string(),
NodeFnCommand(|_s: &TestState| async move { Ok(crate::Command::end()) })
.into_node("simple_node"),
);
let config = RunnableConfig::new();
let token = CancellationToken::new();
let pending_interrupts = vec![];
let scratchpad = Scratchpad::new();
let tasks = vec![PendingTask::pull(
uuid::Uuid::new_v4().to_string(),
"simple_node".to_string(),
)];
let (result, _rx) = execute_superstep(
&tasks,
&Arc::new(state.clone()),
&nodes,
&config,
&token,
None,
&pending_interrupts,
&scratchpad,
&HashMap::new(),
&HashMap::new(),
&HashMap::new(),
&HashMap::new(),
0,
)
.await
.unwrap();
assert_eq!(result.len(), 1);
assert!(result.task_outputs[0].error.is_none());
}
/// A node that returns immediately finishes well inside its 10-second run
/// timeout, so its task output carries no error.
#[tokio::test]
async fn test_execute_superstep_with_timeout_succeeds_within_limit() {
    let fast = NodeFnCommand(|_s: &TestState| async move { Ok(crate::Command::end()) })
        .into_node("fast_node");
    let mut nodes = indexmap::IndexMap::new();
    nodes.insert("fast_node".to_string(), fast);

    let timeout_policies = HashMap::from([(
        "fast_node".to_string(),
        crate::pregel::context::TimeoutPolicy::new()
            .with_run_timeout(std::time::Duration::from_secs(10)),
    )]);

    let tasks = vec![PendingTask::pull(
        uuid::Uuid::new_v4().to_string(),
        "fast_node".to_string(),
    )];
    let shared_state = Arc::new(TestState);
    let config = RunnableConfig::new();
    let token = CancellationToken::new();
    let pending_interrupts = vec![];
    let scratchpad = Scratchpad::new();

    let (outputs, _rx) = execute_superstep(
        &tasks,
        &shared_state,
        &nodes,
        &config,
        &token,
        None,
        &pending_interrupts,
        &scratchpad,
        &HashMap::new(), // error handlers
        &HashMap::new(), // retry policies
        &timeout_policies,
        &HashMap::new(),
        0,
    )
    .await
    .unwrap();

    assert_eq!(outputs.len(), 1);
    assert!(outputs.task_outputs[0].error.is_none());
}
/// A node that sleeps for 10 s is cut off by its 50 ms run timeout. With no
/// error handler configured, the superstep fails with a node-timeout error.
#[tokio::test]
async fn test_execute_superstep_with_timeout_exceeds_limit() {
    let slow = NodeFnCommand(|_s: &TestState| async move {
        tokio::time::sleep(std::time::Duration::from_secs(10)).await;
        Ok(crate::Command::end())
    })
    .into_node("slow_node");
    let mut nodes = indexmap::IndexMap::new();
    nodes.insert("slow_node".to_string(), slow);

    let timeout_policies = HashMap::from([(
        "slow_node".to_string(),
        crate::pregel::context::TimeoutPolicy::new()
            .with_run_timeout(std::time::Duration::from_millis(50)),
    )]);

    let tasks = vec![PendingTask::pull(
        uuid::Uuid::new_v4().to_string(),
        "slow_node".to_string(),
    )];
    let shared_state = Arc::new(TestState);
    let config = RunnableConfig::new();
    let token = CancellationToken::new();
    let pending_interrupts = vec![];
    let scratchpad = Scratchpad::new();

    let outcome = execute_superstep(
        &tasks,
        &shared_state,
        &nodes,
        &config,
        &token,
        None,
        &pending_interrupts,
        &scratchpad,
        &HashMap::new(), // error handlers
        &HashMap::new(), // retry policies
        &timeout_policies,
        &HashMap::new(),
        0,
    )
    .await;

    assert!(outcome.is_err());
    let err = outcome.unwrap_err();
    assert!(
        err.is_node_timeout(),
        "expected node timeout error, got: {err}"
    );
}
/// The run timeout bounds the whole retry sequence, not each individual attempt.
/// Each attempt takes 100 ms, the timeout is 200 ms, and up to 10 attempts are
/// allowed. The superstep must fail with a node-timeout error after at least
/// one attempt and fewer than ten.
///
/// NOTE(review): this test depends on wall-clock timing (it takes about 200 ms
/// of real time).
#[tokio::test]
async fn test_execute_superstep_timeout_wraps_retry_entire_sequence() {
    use std::sync::atomic::{AtomicU32, Ordering};

    let invocations = Arc::new(AtomicU32::new(0));
    let node_invocations = Arc::clone(&invocations);

    let slow_failing = NodeFnCommand(move |_s: &TestState| {
        let tally = Arc::clone(&node_invocations);
        async move {
            tally.fetch_add(1, Ordering::Relaxed);
            tokio::time::sleep(std::time::Duration::from_millis(100)).await;
            Err(crate::JunctureError::execution("transient failure"))
        }
    })
    .into_node("slow_retry_node");
    let mut nodes = indexmap::IndexMap::new();
    nodes.insert("slow_retry_node".to_string(), slow_failing);

    // Near-zero backoff so elapsed time is dominated by the 100 ms attempts.
    let retry_policies = HashMap::from([(
        "slow_retry_node".to_string(),
        RetryPolicy {
            max_attempts: 10,
            initial_interval: std::time::Duration::from_millis(1),
            backoff_factor: 1.0,
            max_interval: std::time::Duration::from_millis(1),
            jitter: false,
            retry_on: None,
        },
    )]);
    let timeout_policies = HashMap::from([(
        "slow_retry_node".to_string(),
        crate::pregel::context::TimeoutPolicy::new()
            .with_run_timeout(std::time::Duration::from_millis(200)),
    )]);

    let tasks = vec![PendingTask::pull(
        uuid::Uuid::new_v4().to_string(),
        "slow_retry_node".to_string(),
    )];
    let shared_state = Arc::new(TestState);
    let config = RunnableConfig::new();
    let token = CancellationToken::new();
    let pending_interrupts = vec![];
    let scratchpad = Scratchpad::new();

    let outcome = execute_superstep(
        &tasks,
        &shared_state,
        &nodes,
        &config,
        &token,
        None,
        &pending_interrupts,
        &scratchpad,
        &HashMap::new(), // error handlers
        &retry_policies,
        &timeout_policies,
        &HashMap::new(),
        0,
    )
    .await;

    assert!(outcome.is_err());
    let err = outcome.unwrap_err();
    assert!(
        err.is_node_timeout(),
        "timeout should fire before retries exhaust, got: {err}"
    );

    let attempts = invocations.load(Ordering::Relaxed);
    assert!(
        attempts >= 1,
        "should have attempted at least once before timeout, got {attempts}"
    );
    assert!(
        attempts < 10,
        "timeout should have prevented all 10 retry attempts, got {attempts}"
    );
}
/// Timeout policies are applied per node. `fast_node` has no policy and
/// succeeds. `slow_node` exceeds its 50 ms run timeout and, because it has an
/// error handler, its failure is recorded on its own task output instead of
/// failing the superstep.
#[tokio::test]
async fn test_execute_superstep_timeout_only_applies_to_configured_node() {
    let fast = NodeFnCommand(|_s: &TestState| async move { Ok(crate::Command::end()) })
        .into_node("fast_node");
    let slow = NodeFnCommand(|_s: &TestState| async move {
        tokio::time::sleep(std::time::Duration::from_secs(10)).await;
        Ok(crate::Command::end())
    })
    .into_node("slow_node");

    // Insertion order (fast_node, then slow_node) is kept for the IndexMap.
    let mut nodes = indexmap::IndexMap::new();
    nodes.insert("fast_node".to_string(), fast);
    nodes.insert("slow_node".to_string(), slow);

    let timeout_policies = HashMap::from([(
        "slow_node".to_string(),
        crate::pregel::context::TimeoutPolicy::new()
            .with_run_timeout(std::time::Duration::from_millis(50)),
    )]);
    let error_handlers = HashMap::from([("slow_node".to_string(), "handler".to_string())]);

    let tasks = vec![
        PendingTask::pull(uuid::Uuid::new_v4().to_string(), "fast_node".to_string()),
        PendingTask::pull(uuid::Uuid::new_v4().to_string(), "slow_node".to_string()),
    ];
    let shared_state = Arc::new(TestState);
    let config = RunnableConfig::new();
    let token = CancellationToken::new();
    let pending_interrupts = vec![];
    let scratchpad = Scratchpad::new();

    let (outputs, _rx) = execute_superstep(
        &tasks,
        &shared_state,
        &nodes,
        &config,
        &token,
        None,
        &pending_interrupts,
        &scratchpad,
        &error_handlers,
        &HashMap::new(), // retry policies
        &timeout_policies,
        &HashMap::new(),
        0,
    )
    .await
    .unwrap();

    assert_eq!(outputs.len(), 2);
    let output_for = |name: &str| outputs.task_outputs.iter().find(|o| o.node_name == name);

    let fast_output = output_for("fast_node").expect("fast_node output should exist");
    assert!(fast_output.error.is_none());

    let slow_output = output_for("slow_node").expect("slow_node output should exist");
    assert!(
        slow_output.error.is_some(),
        "slow_node should have timed out with error handler"
    );
}
#[tokio::test]
async fn test_execute_superstep_no_timeout_policy_same_behavior() {
let state = TestState;
let mut nodes = indexmap::IndexMap::new();
nodes.insert(
"simple_node".to_string(),
NodeFnCommand(|_s: &TestState| async move { Ok(crate::Command::end()) })
.into_node("simple_node"),
);
let config = RunnableConfig::new();
let token = CancellationToken::new();
let pending_interrupts = vec![];
let scratchpad = Scratchpad::new();
let tasks = vec![PendingTask::pull(
uuid::Uuid::new_v4().to_string(),
"simple_node".to_string(),
)];
let (result, _rx) = execute_superstep(
&tasks,
&Arc::new(state.clone()),
&nodes,
&config,
&token,
None,
&pending_interrupts,
&scratchpad,
&HashMap::new(),
&HashMap::new(),
&HashMap::new(),
&HashMap::new(),
0,
)
.await
.unwrap();
assert_eq!(result.len(), 1);
assert!(result.task_outputs[0].error.is_none());
}
}