use thiserror::Error;
use crate::trace::{TraceEventKind, TraceTerminalStatus, WorkflowTrace};
/// Cache policy selector carried by [`ReplayOptions`].
///
/// NOTE(review): within this file all three variants are handled identically
/// by `replay_trace_with_options`; the intended semantic difference between
/// them is not visible here and should be confirmed against the callers.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum ReplayCachePolicy {
    /// Presumably "always use cached results" — TODO confirm.
    Always,
    /// Presumably "refresh cached results"; this is the default policy chosen
    /// by `ReplayOptions::default()`.
    Refresh,
    /// Presumably a combination of the two other policies — TODO confirm.
    Mixed,
}
/// Options accepted by `replay_trace_with_options`.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct ReplayOptions {
    /// Cache policy to replay under.
    pub cache_policy: ReplayCachePolicy,
}
impl Default for ReplayOptions {
fn default() -> Self {
Self {
cache_policy: ReplayCachePolicy::Refresh,
}
}
}
/// Summary returned when a trace passes replay validation.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct ReplayReport {
    /// Number of events contained in the validated trace.
    pub total_events: usize,
    /// Status carried by the trace's terminal event (the last one seen if the
    /// trace contains several).
    pub terminal_status: TraceTerminalStatus,
}
/// Machine-readable classification of a replay validation failure.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum ReplayViolationCode {
    /// An event's `seq` differed from the expected next sequence number
    /// (sequence numbers are expected to start at 0 and increase by one).
    NonMonotonicSequence,
    /// A node exit/error event named a node other than the most recently
    /// entered one, or arrived while no node was open.
    MismatchedNodeLifecycle,
    /// The trace contains no terminal event.
    MissingTerminalEvent,
    /// One or more entered nodes were never closed by the end of the trace.
    UnclosedNodeLifecycle,
}
/// A single problem found while validating a trace.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct ReplayViolation {
    /// Category of the violation.
    pub code: ReplayViolationCode,
    /// Human-readable description of what was expected and what was found.
    pub message: String,
    /// Position of the offending event within the trace's event list, or
    /// `None` when the violation concerns the trace as a whole.
    pub event_index: Option<usize>,
}
#[derive(Debug, Clone, PartialEq, Eq, Error)]
#[error("workflow trace replay validation failed")]
pub struct ReplayError {
pub violations: Vec<ReplayViolation>,
}
/// Validates `trace` using the default [`ReplayOptions`].
///
/// # Errors
///
/// Returns a [`ReplayError`] listing every violation found in the trace.
pub fn replay_trace(trace: &WorkflowTrace) -> Result<ReplayReport, ReplayError> {
    let default_options = ReplayOptions::default();
    replay_trace_with_options(trace, &default_options)
}
/// Validates `trace` under the supplied `options`.
///
/// Every cache policy currently leads to the same structural validation, so
/// the result does not depend on `options.cache_policy`.
///
/// # Errors
///
/// Returns a [`ReplayError`] listing every violation found in the trace.
pub fn replay_trace_with_options(
    trace: &WorkflowTrace,
    options: &ReplayOptions,
) -> Result<ReplayReport, ReplayError> {
    // The match is kept exhaustive (no wildcard arm) so that adding a policy
    // variant forces this function to be revisited.
    match options.cache_policy {
        ReplayCachePolicy::Always => {}
        ReplayCachePolicy::Refresh => {}
        ReplayCachePolicy::Mixed => {}
    }
    replay_trace_internal(trace)
}
fn replay_trace_internal(trace: &WorkflowTrace) -> Result<ReplayReport, ReplayError> {
let mut violations = Vec::new();
let mut expected_seq = 0u64;
let mut stack: Vec<&str> = Vec::new();
let mut terminal_status = None;
for (index, event) in trace.events.iter().enumerate() {
if event.seq != expected_seq {
violations.push(ReplayViolation {
code: ReplayViolationCode::NonMonotonicSequence,
message: format!(
"expected event seq {} at index {}, found {}",
expected_seq, index, event.seq
),
event_index: Some(index),
});
expected_seq = event.seq.saturating_add(1);
} else {
expected_seq = expected_seq.saturating_add(1);
}
match &event.kind {
TraceEventKind::NodeEnter { node_id } => {
stack.push(node_id.as_str());
}
TraceEventKind::NodeExit { node_id } | TraceEventKind::NodeError { node_id, .. } => {
match stack.pop() {
Some(active_node) if active_node == node_id => {}
Some(active_node) => violations.push(ReplayViolation {
code: ReplayViolationCode::MismatchedNodeLifecycle,
message: format!(
"expected node '{}' to close, found '{}'",
active_node, node_id
),
event_index: Some(index),
}),
None => violations.push(ReplayViolation {
code: ReplayViolationCode::MismatchedNodeLifecycle,
message: format!(
"node '{}' closed without a matching enter event",
node_id
),
event_index: Some(index),
}),
}
}
TraceEventKind::Terminal { status } => {
terminal_status = Some(*status);
}
}
}
if terminal_status.is_none() {
violations.push(ReplayViolation {
code: ReplayViolationCode::MissingTerminalEvent,
message: "trace does not contain a terminal event".to_string(),
event_index: None,
});
}
if !stack.is_empty() {
violations.push(ReplayViolation {
code: ReplayViolationCode::UnclosedNodeLifecycle,
message: format!("{} node(s) remain open at end of trace", stack.len()),
event_index: None,
});
}
if violations.is_empty() {
Ok(ReplayReport {
total_events: trace.events.len(),
terminal_status: terminal_status.expect("terminal status must exist"),
})
} else {
Err(ReplayError { violations })
}
}
#[cfg(test)]
mod tests {
    use crate::recorder::TraceRecorder;
    use crate::replay::{
        replay_trace, replay_trace_with_options, ReplayCachePolicy, ReplayOptions,
        ReplayViolationCode,
    };
    use crate::trace::{
        TraceEvent, TraceEventKind, TraceTerminalStatus, WorkflowTrace, WorkflowTraceMetadata,
    };

    /// Metadata shared by every trace built in these tests.
    fn metadata() -> WorkflowTraceMetadata {
        WorkflowTraceMetadata {
            trace_id: "trace-1".to_string(),
            workflow_name: "demo".to_string(),
            workflow_version: "v0".to_string(),
            started_at_unix_ms: 100,
            finished_at_unix_ms: None,
        }
    }

    /// Event kind for entering the node `node_id`.
    fn enter(node_id: &str) -> TraceEventKind {
        TraceEventKind::NodeEnter {
            node_id: node_id.to_string(),
        }
    }

    /// Event kind for exiting the node `node_id`.
    fn exit(node_id: &str) -> TraceEventKind {
        TraceEventKind::NodeExit {
            node_id: node_id.to_string(),
        }
    }

    /// Event kind for a terminal event with the given `status`.
    fn terminal(status: TraceTerminalStatus) -> TraceEventKind {
        TraceEventKind::Terminal { status }
    }

    /// Wraps hand-written `events` in a trace that uses the shared metadata.
    fn trace_of(events: Vec<TraceEvent>) -> WorkflowTrace {
        WorkflowTrace {
            metadata: metadata(),
            events,
        }
    }

    #[test]
    fn replays_valid_trace() {
        let recorder = TraceRecorder::new(metadata());
        recorder.record_node_enter(101, "start").unwrap();
        recorder.record_node_exit(102, "start").unwrap();
        recorder
            .record_terminal(103, TraceTerminalStatus::Completed)
            .unwrap();
        let trace = recorder.finalize(104).unwrap();

        let report = replay_trace(&trace).expect("valid trace should replay");

        assert_eq!(report.total_events, 3);
        assert_eq!(report.terminal_status, TraceTerminalStatus::Completed);
    }

    #[test]
    fn rejects_out_of_order_sequence() {
        // seq jumps from 0 straight to 2.
        let trace = trace_of(vec![
            TraceEvent {
                seq: 0,
                timestamp_unix_ms: 101,
                kind: enter("start"),
            },
            TraceEvent {
                seq: 2,
                timestamp_unix_ms: 102,
                kind: exit("start"),
            },
            TraceEvent {
                seq: 3,
                timestamp_unix_ms: 103,
                kind: terminal(TraceTerminalStatus::Completed),
            },
        ]);

        let err = replay_trace(&trace).expect_err("should reject non-monotonic sequence");

        let reported: Vec<ReplayViolationCode> = err.violations.iter().map(|v| v.code).collect();
        assert!(reported.contains(&ReplayViolationCode::NonMonotonicSequence));
    }

    #[test]
    fn rejects_missing_terminal_event() {
        let trace = trace_of(vec![
            TraceEvent {
                seq: 0,
                timestamp_unix_ms: 101,
                kind: enter("start"),
            },
            TraceEvent {
                seq: 1,
                timestamp_unix_ms: 102,
                kind: exit("start"),
            },
        ]);

        let err = replay_trace(&trace).expect_err("should reject missing terminal event");

        let reported: Vec<ReplayViolationCode> = err.violations.iter().map(|v| v.code).collect();
        assert!(reported.contains(&ReplayViolationCode::MissingTerminalEvent));
    }

    #[test]
    fn rejects_mismatched_enter_exit() {
        // Node "a" is entered but the exit names node "b".
        let trace = trace_of(vec![
            TraceEvent {
                seq: 0,
                timestamp_unix_ms: 101,
                kind: enter("a"),
            },
            TraceEvent {
                seq: 1,
                timestamp_unix_ms: 102,
                kind: exit("b"),
            },
            TraceEvent {
                seq: 2,
                timestamp_unix_ms: 103,
                kind: terminal(TraceTerminalStatus::Failed),
            },
        ]);

        let err = replay_trace(&trace).expect_err("should reject mismatched lifecycle");

        let reported: Vec<ReplayViolationCode> = err.violations.iter().map(|v| v.code).collect();
        assert!(reported.contains(&ReplayViolationCode::MismatchedNodeLifecycle));
    }

    #[test]
    fn supports_cache_policy_options() {
        let recorder = TraceRecorder::new(metadata());
        recorder.record_node_enter(101, "start").unwrap();
        recorder.record_node_exit(102, "start").unwrap();
        recorder
            .record_terminal(103, TraceTerminalStatus::Completed)
            .unwrap();
        let trace = recorder.finalize(104).unwrap();
        let options = ReplayOptions {
            cache_policy: ReplayCachePolicy::Mixed,
        };

        let report = replay_trace_with_options(&trace, &options)
            .expect("mixed policy should replay trace");

        assert_eq!(report.total_events, 3);
    }
}