#![allow(clippy::arc_with_non_send_sync)]
use crate::dag::InternalNode;
use crate::dag::NodeName;
use crate::rust_features::ExpectNone;
use crate::saga_action_error::ActionError;
use crate::saga_action_error::UndoActionError;
use crate::saga_action_generic::Action;
use crate::saga_action_generic::ActionConstant;
use crate::saga_action_generic::ActionData;
use crate::saga_action_generic::ActionInjectError;
use crate::saga_log::SagaNodeEventType;
use crate::saga_log::SagaNodeLoadStatus;
use crate::sec::RepeatInjected;
use crate::sec::SecExecClient;
use crate::ActionRegistry;
use crate::SagaCachedState;
use crate::SagaDag;
use crate::SagaId;
use crate::SagaLog;
use crate::SagaNodeEvent;
use crate::SagaNodeId;
use crate::SagaType;
use anyhow::anyhow;
use anyhow::ensure;
use anyhow::Context;
use futures::channel::mpsc;
use futures::future::BoxFuture;
use futures::lock::Mutex;
use futures::FutureExt;
use futures::StreamExt;
use futures::TryStreamExt;
use petgraph::algo::toposort;
use petgraph::graph::NodeIndex;
use petgraph::visit::Topo;
use petgraph::visit::Walker;
use petgraph::Direction;
use petgraph::Graph;
use petgraph::Incoming;
use petgraph::Outgoing;
use serde_json::json;
use std::collections::BTreeMap;
use std::collections::BTreeSet;
use std::convert::TryFrom;
use std::fmt;
use std::future::Future;
use std::sync::Arc;
use tokio::sync::broadcast;
use tokio::task::JoinHandle;
/// State of a saga node whose action completed successfully, with its output.
struct SgnsDone(Arc<serde_json::Value>);
/// State of a saga node whose action failed.
struct SgnsFailed(ActionError);
/// State of a saga node that has been undone (see `UndoMode` for how).
struct SgnsUndone(UndoMode);
/// State of a saga node whose undo action failed.
struct SgnsUndoFailed(UndoActionError);
/// A node in some terminal state `S`, identified by its index in the DAG.
struct SagaNode<S: SagaNodeStateType> {
node_id: NodeIndex,
state: S,
}
/// Marker trait for the terminal states a `SagaNode` can be in.
trait SagaNodeStateType {}
impl SagaNodeStateType for SgnsDone {}
impl SagaNodeStateType for SgnsFailed {}
impl SagaNodeStateType for SgnsUndone {}
impl SagaNodeStateType for SgnsUndoFailed {}
/// Behavior common to a `SagaNode` in any terminal state: how to update
/// executor state when the node reaches that state and which log event
/// records the transition.
trait SagaNodeRest<UserType: SagaType>: Send + Sync {
/// Updates `live_state` to reflect this node's new state and queues any
/// follow-on work (running children, undoing parents, finishing the saga).
fn propagate(
&self,
exec: &SagaExecutor<UserType>,
live_state: &mut SagaExecLiveState,
);
/// Returns the log event type that records this state transition.
fn log_event(&self) -> SagaNodeEventType;
}
// Handling for a node whose action just completed successfully.
impl<UserType: SagaType> SagaNodeRest<UserType> for SagaNode<SgnsDone> {
fn log_event(&self) -> SagaNodeEventType {
// Record the node's output with the "succeeded" event.
SagaNodeEventType::Succeeded(Arc::clone(&self.state.0))
}
fn propagate(
&self,
exec: &SagaExecutor<UserType>,
live_state: &mut SagaExecLiveState,
) {
let graph = &exec.dag.graph;
// A node cannot both fail and succeed.
assert!(!live_state.node_errors.contains_key(&self.node_id));
live_state
.node_outputs
.insert(self.node_id, Arc::clone(&self.state.0))
.expect_none("node finished twice (storing output)");
if self.node_id == exec.dag.end_node {
// The end node succeeded: every node has produced output, so
// the saga is done.
assert!(!live_state.stopping);
assert_eq!(live_state.exec_state, SagaCachedState::Running);
assert_eq!(graph.node_count(), live_state.node_outputs.len());
live_state.mark_saga_done();
return;
}
// If we're only waiting for outstanding tasks to finish, queue
// nothing new.
if live_state.stopping {
return;
}
if live_state.exec_state == SagaCachedState::Unwinding {
// Unwinding: this node becomes eligible for undo once all of its
// children have been undone.
if neighbors_all(graph, &self.node_id, Outgoing, |child| {
live_state.nodes_undone.contains_key(child)
}) {
live_state.queue_undo.push(self.node_id);
}
return;
}
// Running forward: queue each child whose parents all have outputs.
for child in graph.neighbors_directed(self.node_id, Outgoing) {
if neighbors_all(graph, &child, Incoming, |parent| {
live_state.node_outputs.contains_key(parent)
}) {
live_state.queue_todo.push(child);
}
}
}
}
// Handling for a node whose action just failed.
impl<UserType: SagaType> SagaNodeRest<UserType> for SagaNode<SgnsFailed> {
fn log_event(&self) -> SagaNodeEventType {
SagaNodeEventType::Failed(self.state.0.clone())
}
fn propagate(
&self,
exec: &SagaExecutor<UserType>,
live_state: &mut SagaExecLiveState,
) {
let graph = &exec.dag.graph;
// A node cannot both succeed and fail.
assert!(!live_state.node_outputs.contains_key(&self.node_id));
live_state
.node_errors
.insert(self.node_id, self.state.0.clone())
.expect_none("node finished twice (storing error)");
// If we're only waiting for outstanding tasks to finish, queue
// nothing new.
if live_state.stopping {
return;
}
if live_state.exec_state == SagaCachedState::Unwinding {
// Already unwinding.  A failed action has nothing to undo, so
// once all children are undone this node immediately counts as
// undone too.
if neighbors_all(graph, &self.node_id, Outgoing, |child| {
live_state.nodes_undone.contains_key(child)
}) {
let new_node = SagaNode {
node_id: self.node_id,
state: SgnsUndone(UndoMode::ActionFailed),
};
new_node.propagate(exec, live_state);
}
} else {
// First failure: switch to unwinding, starting from the end node
// (which never ran and so is trivially "undone").
live_state.exec_state = SagaCachedState::Unwinding;
assert_ne!(self.node_id, exec.dag.end_node);
let new_node = SagaNode {
node_id: exec.dag.end_node,
state: SgnsUndone(UndoMode::ActionNeverRan),
};
new_node.propagate(exec, live_state);
}
}
}
// Handling for a node that has just been undone (or determined not to need
// undoing).
impl<UserType: SagaType> SagaNodeRest<UserType> for SagaNode<SgnsUndone> {
fn log_event(&self) -> SagaNodeEventType {
SagaNodeEventType::UndoFinished
}
fn propagate(
&self,
exec: &SagaExecutor<UserType>,
live_state: &mut SagaExecLiveState,
) {
let graph = &exec.dag.graph;
live_state
.nodes_undone
.insert(self.node_id, self.state.0)
.expect_none("node already undone");
if self.node_id == exec.dag.start_node {
// The start node has been undone: unwinding is complete.
assert!(!live_state.stopping);
live_state.mark_saga_done();
return;
}
// If we're only waiting for outstanding tasks to finish, queue
// nothing new.
if live_state.stopping {
return;
}
assert_eq!(live_state.exec_state, SagaCachedState::Unwinding);
// Consider each parent whose children are now all undone.
for parent in graph.neighbors_directed(self.node_id, Incoming) {
if neighbors_all(graph, &parent, Outgoing, |child| {
live_state.nodes_undone.contains_key(child)
}) {
match live_state.node_exec_state(parent) {
NodeExecState::Blocked => {
// Never ran: undone without running any undo action.
let new_node = SagaNode {
node_id: parent,
state: SgnsUndone(UndoMode::ActionNeverRan),
};
new_node.propagate(exec, live_state);
continue;
}
NodeExecState::Failed => {
// Its action failed, so there is nothing to undo.
let new_node = SagaNode {
node_id: parent,
state: SgnsUndone(UndoMode::ActionFailed),
};
new_node.propagate(exec, live_state);
continue;
}
NodeExecState::QueuedToRun
| NodeExecState::TaskInProgress => {
// Still running (or about to): it will be handled when its
// task completes.
continue;
}
NodeExecState::Done => {
// Ran successfully: queue its undo action.
live_state.queue_undo.push(parent);
}
NodeExecState::QueuedToUndo
| NodeExecState::UndoInProgress
| NodeExecState::Undone(_)
| NodeExecState::UndoFailed => {
// A parent cannot be undoing/undone before all of its
// children are undone.
panic!(
"already undoing or undone node whose child was \
just now undone"
);
}
}
}
}
}
}
// Handling for a node whose undo action just failed.  This is permanent:
// the saga cannot make further progress past it.
impl<UserType: SagaType> SagaNodeRest<UserType> for SagaNode<SgnsUndoFailed> {
fn log_event(&self) -> SagaNodeEventType {
SagaNodeEventType::UndoFailed(self.state.0.clone())
}
fn propagate(
&self,
_exec: &SagaExecutor<UserType>,
live_state: &mut SagaExecLiveState,
) {
// Undo actions only run while unwinding.
assert!(live_state.exec_state == SagaCachedState::Unwinding);
live_state
.undo_errors
.insert(self.node_id, self.state.0.clone())
.expect_none("undo node failed twice (storing error)");
// Stop queueing new work; the saga is now stuck.
live_state.saga_stuck();
}
}
/// Message sent by a node task to the executor when the task completes.
struct TaskCompletion<UserType: SagaType> {
node_id: NodeIndex,
// The node in its terminal state, used to propagate the result.
node: Box<dyn SagaNodeRest<UserType>>,
}
/// Everything a spawned task needs to run (or undo) one node's action.
struct TaskParams<UserType: SagaType> {
dag: Arc<SagaDag>,
// Caller-provided context passed through to the action.
user_context: Arc<UserType::ExecContextType>,
// Shared mutable executor state (used to record log events).
live_state: Arc<Mutex<SagaExecLiveState>>,
node_id: NodeIndex,
// Channel on which the task reports completion to the executor loop.
done_tx: mpsc::Sender<TaskCompletion<UserType>>,
// Outputs of ancestor nodes, keyed by node name.
ancestor_tree: Arc<BTreeMap<NodeName, Arc<serde_json::Value>>>,
// Parameters of the saga (or subsaga) containing this node.
saga_params: Arc<serde_json::Value>,
action: Arc<dyn Action<UserType>>,
// If set, run the action/undo action multiple times (for testing).
injected_repeat: Option<RepeatInjected>,
}
/// Executes a saga: runs each node's action in DAG order and, on failure,
/// unwinds by running undo actions in reverse.
#[derive(Debug)]
pub struct SagaExecutor<UserType: SagaType> {
#[allow(dead_code)]
log: slog::Logger,
dag: Arc<SagaDag>,
action_registry: Arc<ActionRegistry<UserType>>,
// Notified when the saga finishes (see `run()`).
finish_tx: broadcast::Sender<()>,
saga_id: SagaId,
// Maps each node to the start node of the saga/subsaga containing it.
node_saga_start: BTreeMap<NodeIndex, NodeIndex>,
// Mutable state, shared with spawned node tasks.
live_state: Arc<Mutex<SagaExecLiveState>>,
user_context: Arc<UserType::ExecContextType>,
}
/// During recovery, which way execution was headed, plus whether the
/// adjacent nodes (parents when forward, children when unwinding) are ready.
#[derive(Debug)]
enum RecoveryDirection {
Forward(bool),
Unwind(bool),
}
impl<UserType: SagaType> SagaExecutor<UserType> {
    /// Creates an executor to run the given saga from scratch.
    ///
    /// Equivalent to recovering from an empty saga log.
    pub fn new(
        log: slog::Logger,
        saga_id: SagaId,
        dag: Arc<SagaDag>,
        registry: Arc<ActionRegistry<UserType>>,
        user_context: Arc<UserType::ExecContextType>,
        sec_hdl: SecExecClient,
    ) -> Result<SagaExecutor<UserType>, anyhow::Error> {
        // A brand-new saga is just recovery from an empty log.
        Self::new_recover(
            log,
            saga_id,
            dag,
            registry,
            user_context,
            sec_hdl,
            SagaLog::new_empty(saga_id),
        )
    }
pub fn new_recover(
log: slog::Logger,
saga_id: SagaId,
dag: Arc<SagaDag>,
registry: Arc<ActionRegistry<UserType>>,
user_context: Arc<UserType::ExecContextType>,
sec_hdl: SecExecClient,
sglog: SagaLog,
) -> Result<SagaExecutor<UserType>, anyhow::Error> {
Self::validate_saga(&dag, ®istry).with_context(|| {
format!("validating saga {:?}", dag.saga_name())
})?;
let forward = !sglog.unwinding();
let mut live_state = SagaExecLiveState {
stopping: false,
exec_state: if forward {
SagaCachedState::Running
} else {
SagaCachedState::Unwinding
},
queue_todo: Vec::new(),
queue_undo: Vec::new(),
node_tasks: BTreeMap::new(),
node_outputs: BTreeMap::new(),
nodes_undone: BTreeMap::new(),
node_errors: BTreeMap::new(),
undo_errors: BTreeMap::new(),
sglog,
injected_errors: BTreeSet::new(),
injected_undo_errors: BTreeSet::new(),
injected_repeats: BTreeMap::new(),
sec_hdl,
saga_id,
};
let mut loaded = BTreeSet::new();
let graph = &dag.graph;
let nodes_sorted = toposort(&graph, None).expect("saga DAG had cycles");
let node_saga_start = {
let mut node_saga_start = BTreeMap::new();
for node_index in &nodes_sorted {
let node = graph.node_weight(*node_index).unwrap();
let subsaga_start_index = match node {
InternalNode::Start { .. }
| InternalNode::SubsagaStart { .. } => {
*node_index
}
InternalNode::End
| InternalNode::Action { .. }
| InternalNode::Constant { .. }
| InternalNode::SubsagaEnd { .. } => {
let immed_ancestor = graph
.neighbors_directed(*node_index, petgraph::Incoming)
.next()
.unwrap();
let immed_ancestor_node =
dag.get(immed_ancestor).unwrap();
let ancestor = match immed_ancestor_node {
InternalNode::SubsagaEnd { .. } => {
let subsaga_start = *node_saga_start
.get(&immed_ancestor)
.unwrap();
graph
.neighbors_directed(
subsaga_start,
petgraph::Incoming,
)
.next()
.unwrap()
}
_ => immed_ancestor,
};
*node_saga_start.get(&ancestor).expect(
"expected to compute ancestor's subsaga start \
node first",
)
}
};
node_saga_start.insert(*node_index, subsaga_start_index);
}
node_saga_start
};
let graph_nodes = {
let mut nodes = nodes_sorted;
if !forward {
nodes.reverse();
}
nodes
};
for node_id in graph_nodes {
let node_status =
live_state.sglog.load_status_for_node(node_id.into());
for parent in graph.neighbors_directed(node_id, Incoming) {
let parent_status =
live_state.sglog.load_status_for_node(parent.into());
if !recovery_validate_parent(parent_status, node_status) {
return Err(anyhow!(
"recovery for saga {}: node {:?}: load status is \
\"{:?}\", which is illegal for parent load status \
\"{:?}\"",
saga_id,
node_id,
node_status,
parent_status,
));
}
}
let direction = if forward {
RecoveryDirection::Forward(neighbors_all(
graph,
&node_id,
Incoming,
|p| {
assert!(loaded.contains(p));
live_state.node_outputs.contains_key(p)
},
))
} else {
RecoveryDirection::Unwind(neighbors_all(
graph,
&node_id,
Outgoing,
|c| {
assert!(loaded.contains(c));
live_state.nodes_undone.contains_key(c)
},
))
};
match node_status {
SagaNodeLoadStatus::NeverStarted => {
match direction {
RecoveryDirection::Forward(true) => {
live_state.queue_todo.push(node_id);
}
RecoveryDirection::Unwind(true) => {
live_state
.nodes_undone
.insert(node_id, UndoMode::ActionNeverRan);
}
_ => (),
}
}
SagaNodeLoadStatus::Started => {
live_state.queue_todo.push(node_id);
}
SagaNodeLoadStatus::Succeeded(output) => {
assert!(!live_state.node_errors.contains_key(&node_id));
live_state
.node_outputs
.insert(node_id, Arc::clone(output))
.expect_none("recovered node twice (success case)");
if let RecoveryDirection::Unwind(true) = direction {
live_state.queue_undo.push(node_id);
}
}
SagaNodeLoadStatus::Failed(error) => {
assert!(!live_state.node_outputs.contains_key(&node_id));
live_state
.node_errors
.insert(node_id, error.clone())
.expect_none("recovered node twice (failure case)");
if let RecoveryDirection::Unwind(true) = direction {
live_state
.nodes_undone
.insert(node_id, UndoMode::ActionFailed);
}
}
SagaNodeLoadStatus::UndoStarted(output) => {
assert!(!forward);
live_state.queue_undo.push(node_id);
live_state
.node_outputs
.insert(node_id, Arc::clone(output))
.expect_none("recovered node twice (undo case)");
}
SagaNodeLoadStatus::UndoFinished => {
assert!(!forward);
live_state
.nodes_undone
.insert(node_id, UndoMode::ActionUndone);
}
SagaNodeLoadStatus::UndoFailed(error) => {
assert!(!forward);
live_state
.undo_errors
.insert(node_id, error.clone())
.expect_none(
"recovered node twice (undo failure case)",
);
live_state.saga_stuck();
}
}
assert!(loaded.insert(node_id));
}
if live_state.node_outputs.contains_key(&dag.end_node)
|| live_state.nodes_undone.contains_key(&dag.start_node)
{
live_state.mark_saga_done();
}
let (finish_tx, _) = broadcast::channel(1);
Ok(SagaExecutor {
log,
dag,
finish_tx,
saga_id,
user_context,
live_state: Arc::new(Mutex::new(live_state)),
action_registry: Arc::clone(®istry),
node_saga_start,
})
}
fn validate_saga(
saga: &SagaDag,
registry: &ActionRegistry<UserType>,
) -> Result<(), anyhow::Error> {
let mut nsubsaga_start = 0;
let mut nsubsaga_end = 0;
for node_index in saga.graph.node_indices() {
let node = &saga.graph[node_index];
match node {
InternalNode::Start { .. } => {
ensure!(
node_index == saga.start_node,
"found start node at unexpected index {:?}",
node_index
);
}
InternalNode::End => {
ensure!(
node_index == saga.end_node,
"found end node at unexpected index {:?}",
node_index
);
}
InternalNode::Action { name, action_name, .. } => {
let action = registry.get(&action_name);
ensure!(
action.is_ok(),
"action for node {:?} not registered: {:?}",
name,
action_name
);
}
InternalNode::Constant { .. } => (),
InternalNode::SubsagaStart { .. } => {
nsubsaga_start += 1;
}
InternalNode::SubsagaEnd { .. } => {
nsubsaga_end += 1;
}
}
}
ensure!(
saga.start_node.index() < saga.graph.node_count(),
"bad saga graph (missing start node)",
);
ensure!(
saga.end_node.index() < saga.graph.node_count(),
"bad saga graph (missing end node)",
);
ensure!(
nsubsaga_start == nsubsaga_end,
"bad saga graph (found {} subsaga start nodes but {} subsaga end \
nodes)",
nsubsaga_start,
nsubsaga_end
);
let nend_ancestors =
saga.graph.neighbors_directed(saga.end_node, Incoming).count();
ensure!(
nend_ancestors == 1,
"expected saga to end with exactly one node"
);
Ok(())
}
fn make_ancestor_tree(
&self,
tree: &mut BTreeMap<NodeName, Arc<serde_json::Value>>,
live_state: &SagaExecLiveState,
node_index: NodeIndex,
include_self: bool,
) {
if include_self {
self.make_ancestor_tree_node(tree, live_state, node_index);
return;
}
let ancestors = self.dag.graph.neighbors_directed(node_index, Incoming);
for ancestor in ancestors {
self.make_ancestor_tree_node(tree, live_state, ancestor);
}
}
/// Records `node_index`'s output (if it produces a named one) into `tree`
/// and continues walking toward the start of the containing saga/subsaga.
fn make_ancestor_tree_node(
&self,
tree: &mut BTreeMap<NodeName, Arc<serde_json::Value>>,
live_state: &SagaExecLiveState,
node_index: NodeIndex,
) {
let dag_node = self.dag.get(node_index).unwrap();
match dag_node {
// Named, output-producing nodes contribute to the tree.
InternalNode::Constant { name, .. }
| InternalNode::Action { name, .. }
| InternalNode::SubsagaEnd { name, .. } => {
let output = live_state.node_output(node_index);
tree.insert(name.clone(), output);
}
// Start/end/subsaga-start nodes produce no named output.
InternalNode::Start { .. }
| InternalNode::End
| InternalNode::SubsagaStart { .. } => (),
}
// Decide where (and whether) to continue the walk.
let resume_node = match dag_node {
InternalNode::SubsagaStart { .. } => {
// Stop: nodes outside a subsaga are not visible within it.
None
}
InternalNode::SubsagaEnd { .. } => {
// Skip the subsaga interior: resume from its SubsagaStart node,
// whose own ancestors lie outside the subsaga.
Some(*self.node_saga_start.get(&node_index).unwrap())
}
InternalNode::Constant { .. }
| InternalNode::Action { .. }
| InternalNode::Start { .. }
| InternalNode::End => {
Some(node_index)
}
};
if let Some(resume_node) = resume_node {
self.make_ancestor_tree(tree, live_state, resume_node, false);
}
}
/// Returns the parameters of the saga or subsaga containing `node_index`.
fn saga_params_for(
&self,
live_state: &SagaExecLiveState,
node_index: NodeIndex,
) -> Arc<serde_json::Value> {
let subsaga_start_index = self.node_saga_start[&node_index];
let subsaga_start_node = self.dag.get(subsaga_start_index).unwrap();
match subsaga_start_node {
// Top-level saga: parameters were provided at creation.
InternalNode::Start { params } => params.clone(),
// Subsaga: parameters are the output of the ancestor node named
// by `params_node_name`.
InternalNode::SubsagaStart { params_node_name, .. } => {
let mut tree = BTreeMap::new();
self.make_ancestor_tree(
&mut tree,
live_state,
subsaga_start_index,
false,
);
Arc::clone(tree.get(params_node_name).unwrap())
}
// `node_saga_start` only ever maps to Start/SubsagaStart nodes.
InternalNode::SubsagaEnd { .. }
| InternalNode::End
| InternalNode::Action { .. }
| InternalNode::Constant { .. } => {
panic!(
"containing saga cannot have started with {:?}",
subsaga_start_node
);
}
}
}
pub async fn inject_error(&self, node_id: NodeIndex) {
let mut live_state = self.live_state.lock().await;
live_state.injected_errors.insert(node_id);
}
pub async fn inject_error_undo(&self, node_id: NodeIndex) {
let mut live_state = self.live_state.lock().await;
live_state.injected_undo_errors.insert(node_id);
}
pub async fn inject_repeat(
&self,
node_id: NodeIndex,
repeat: RepeatInjected,
) {
let mut live_state = self.live_state.lock().await;
live_state.injected_repeats.insert(node_id, repeat);
}
async fn run_saga(&self) {
{
let live_state = self.live_state.lock().await;
if live_state.exec_state == SagaCachedState::Done {
self.finish_tx.send(()).expect("failed to send finish message");
live_state.sec_hdl.saga_update(live_state.exec_state).await;
return;
}
}
let (tx, mut rx) = mpsc::channel(2 * self.dag.graph.node_count());
loop {
self.kick_off_ready(&tx).await;
let message = rx.next().await.expect("broken tx");
let task = {
let mut live_state = self.live_state.lock().await;
live_state.node_task_done(message.node_id)
};
task.await.expect("node task failed unexpectedly");
let mut live_state = self.live_state.lock().await;
let prev_state = live_state.exec_state;
message.node.propagate(&self, &mut live_state);
if live_state.exec_state == SagaCachedState::Unwinding
&& prev_state != SagaCachedState::Unwinding
{
live_state
.sec_hdl
.saga_update(SagaCachedState::Unwinding)
.await;
}
if live_state.exec_state == SagaCachedState::Done {
break;
}
}
let live_state = self.live_state.try_lock().unwrap();
assert_eq!(live_state.exec_state, SagaCachedState::Done);
self.finish_tx.send(()).expect("failed to send finish message");
live_state.sec_hdl.saga_update(live_state.exec_state).await;
}
async fn kick_off_ready(
&self,
tx: &mpsc::Sender<TaskCompletion<UserType>>,
) {
let mut live_state = self.live_state.lock().await;
if live_state.stopping {
assert!(!live_state.node_tasks.is_empty());
return;
}
let todo_queue = live_state.queue_todo.clone();
live_state.queue_todo = Vec::new();
for node_id in todo_queue {
let mut ancestor_tree = BTreeMap::new();
self.make_ancestor_tree(
&mut ancestor_tree,
&live_state,
node_id,
false,
);
let saga_params = self.saga_params_for(&live_state, node_id);
let sgaction = if live_state.injected_errors.contains(&node_id) {
Arc::new(ActionInjectError {}) as Arc<dyn Action<UserType>>
} else {
self.node_action(&live_state, node_id)
};
let task_params = TaskParams {
dag: Arc::clone(&self.dag),
live_state: Arc::clone(&self.live_state),
node_id,
done_tx: tx.clone(),
ancestor_tree: Arc::new(ancestor_tree),
saga_params,
action: sgaction,
user_context: Arc::clone(&self.user_context),
injected_repeat: live_state
.injected_repeats
.get(&node_id)
.map(|r| *r),
};
let task = tokio::spawn(SagaExecutor::exec_node(task_params));
live_state.node_task(node_id, task);
}
if live_state.exec_state == SagaCachedState::Running {
assert!(live_state.queue_undo.is_empty());
return;
}
let undo_queue = live_state.queue_undo.clone();
live_state.queue_undo = Vec::new();
for node_id in undo_queue {
let mut ancestor_tree = BTreeMap::new();
self.make_ancestor_tree(
&mut ancestor_tree,
&live_state,
node_id,
true,
);
let saga_params = self.saga_params_for(&live_state, node_id);
let sgaction = if live_state.injected_undo_errors.contains(&node_id)
{
Arc::new(ActionInjectError {}) as Arc<dyn Action<UserType>>
} else {
self.node_action(&live_state, node_id)
};
let task_params = TaskParams {
dag: Arc::clone(&self.dag),
live_state: Arc::clone(&self.live_state),
node_id,
done_tx: tx.clone(),
ancestor_tree: Arc::new(ancestor_tree),
saga_params,
action: sgaction,
user_context: Arc::clone(&self.user_context),
injected_repeat: live_state
.injected_repeats
.get(&node_id)
.map(|r| *r),
};
let task = tokio::spawn(SagaExecutor::undo_node(task_params));
live_state.node_task(node_id, task);
}
}
/// Returns the action to execute for `node_index`.
fn node_action(
&self,
live_state: &SagaExecLiveState,
node_index: NodeIndex,
) -> Arc<dyn Action<UserType>> {
let registry = &self.action_registry;
let dag = &self.dag;
match dag.get(node_index).unwrap() {
// Action nodes run the registered action of that name.
InternalNode::Action { action_name: action, .. } => {
registry.get(action).expect("missing action for node")
}
// Constant nodes simply emit their value.
InternalNode::Constant { value, .. } => {
Arc::new(ActionConstant::new(Arc::clone(value)))
}
// Structural nodes produce a null output.
InternalNode::Start { .. }
| InternalNode::End
| InternalNode::SubsagaStart { .. } => {
Arc::new(ActionConstant::new(Arc::new(serde_json::Value::Null)))
}
// A subsaga-end node re-emits the output of its single ancestor
// (the last node inside the subsaga).
InternalNode::SubsagaEnd { .. } => {
let ancestors: Vec<_> = dag
.graph
.neighbors_directed(node_index, Incoming)
.collect();
assert_eq!(ancestors.len(), 1);
Arc::new(ActionConstant::new(
live_state.node_output(ancestors[0]),
))
}
}
}
/// Body of the task spawned to run one node's action.
///
/// Records "started" in the log (unless recovery already did), runs the
/// action (repeating it if a repeat was injected for testing), and reports
/// the resulting node state via `finish_task`.
async fn exec_node(task_params: TaskParams<UserType>) {
let node_id = task_params.node_id;
{
let mut live_state = task_params.live_state.lock().await;
let load_status =
live_state.sglog.load_status_for_node(node_id.into());
match load_status {
// First run: persist that the action is starting.
SagaNodeLoadStatus::NeverStarted => {
record_now(
&mut live_state,
node_id,
SagaNodeEventType::Started,
)
.await;
}
// Recovered mid-run: the start event is already logged.
SagaNodeLoadStatus::Started => (),
// Any other status means this task should not have been spawned.
SagaNodeLoadStatus::Succeeded(_)
| SagaNodeLoadStatus::Failed(_)
| SagaNodeLoadStatus::UndoStarted(_)
| SagaNodeLoadStatus::UndoFinished
| SagaNodeLoadStatus::UndoFailed(_) => {
panic!("starting node in bad state")
}
}
}
let make_action_context = || ActionContext {
ancestor_tree: Arc::clone(&task_params.ancestor_tree),
saga_params: Arc::clone(&task_params.saga_params),
node_id,
dag: Arc::clone(&task_params.dag),
user_context: Arc::clone(&task_params.user_context),
};
let mut result = task_params.action.do_it(make_action_context()).await;
// If a repeat was injected, run the action again; the last result wins.
if let Some(repeat) = task_params.injected_repeat {
for _ in 0..repeat.action.get() - 1 {
result = task_params.action.do_it(make_action_context()).await;
}
}
let node: Box<dyn SagaNodeRest<UserType>> = match result {
Ok(output) => {
Box::new(SagaNode { node_id, state: SgnsDone(output) })
}
Err(error) => {
Box::new(SagaNode { node_id, state: SgnsFailed(error) })
}
};
SagaExecutor::finish_task(task_params, node).await;
}
/// Body of the task spawned to undo one node's previously-successful
/// action.
///
/// Records "undo started" in the log (unless recovery already did), runs
/// the undo action (repeating it if a repeat was injected for testing),
/// and reports either `SgnsUndone` or `SgnsUndoFailed` via `finish_task`.
async fn undo_node(task_params: TaskParams<UserType>) {
let node_id = task_params.node_id;
{
let mut live_state = task_params.live_state.lock().await;
let load_status =
live_state.sglog.load_status_for_node(node_id.into());
match load_status {
// First undo attempt: persist that the undo is starting.
SagaNodeLoadStatus::Succeeded(_) => {
record_now(
&mut live_state,
node_id,
SagaNodeEventType::UndoStarted,
)
.await;
}
// Recovered mid-undo: the event is already logged.
SagaNodeLoadStatus::UndoStarted(_) => (),
// Any other status means this task should not have been spawned.
SagaNodeLoadStatus::NeverStarted
| SagaNodeLoadStatus::Started
| SagaNodeLoadStatus::Failed(_)
| SagaNodeLoadStatus::UndoFinished
| SagaNodeLoadStatus::UndoFailed(_) => {
panic!("undoing node in bad state")
}
}
}
let make_action_context = || ActionContext {
ancestor_tree: Arc::clone(&task_params.ancestor_tree),
saga_params: Arc::clone(&task_params.saga_params),
node_id,
dag: Arc::clone(&task_params.dag),
user_context: Arc::clone(&task_params.user_context),
};
// Run the undo action `count` times (more than once only when a repeat
// was injected), stopping at the first failure.
let count =
task_params.injected_repeat.map(|r| r.undo.get()).unwrap_or(1);
let action = &task_params.action;
let undo_error = futures::stream::iter(0..count)
.map(Ok::<u32, _>)
.try_for_each(|i| async move {
action
.undo_it(make_action_context())
.await
.with_context(|| format!("undo action attempt {}", i + 1))
})
.await;
if let Err(error) = undo_error {
// An undo failure is treated as permanent; it leaves the saga stuck.
let node = Box::new(SagaNode {
node_id,
state: SgnsUndoFailed(UndoActionError::PermanentFailure {
source_error: json!({ "message": format!("{:#}", error) }),
}),
});
SagaExecutor::finish_task(task_params, node).await;
} else {
let node = Box::new(SagaNode {
node_id,
state: SgnsUndone(UndoMode::ActionUndone),
});
SagaExecutor::finish_task(task_params, node).await;
};
}
/// Records the node's terminal state in the log and notifies the main
/// executor loop that this task is finished.
async fn finish_task(
mut task_params: TaskParams<UserType>,
node: Box<dyn SagaNodeRest<UserType>>,
) {
let node_id = task_params.node_id;
let event_type = node.log_event();
{
let mut live_state = task_params.live_state.lock().await;
record_now(&mut live_state, node_id, event_type).await;
}
// The channel is sized in run_saga so try_send should never fail here.
task_params
.done_tx
.try_send(TaskCompletion { node_id, node })
.expect("unexpected channel failure");
}
pub fn run(&self) -> impl Future<Output = ()> + '_ {
let mut rx = self.finish_tx.subscribe();
async move {
self.run_saga().await;
rx.recv().await.expect("failed to receive finish message")
}
}
/// Returns the final result of the saga.
///
/// # Panics
/// Panics if the saga has not finished running.
pub fn result(&self) -> SagaResult {
// The saga is done, so no task can be holding the lock.
let live_state = self
.live_state
.try_lock()
.expect("attempted to get result while saga still running?");
assert_eq!(live_state.exec_state, SagaCachedState::Done);
if !live_state.undo_errors.is_empty() {
// Stuck saga: an undo action failed.  Report the first action
// error (BTreeMap iteration: lowest node index) alongside the
// first undo error.
let (error_node_id, error_source) =
live_state.node_errors.iter().next().expect(
"expected an action to have failed if an \
undo action failed",
);
let (undo_error_node_id, undo_error_source) =
live_state.undo_errors.iter().next().unwrap();
let error_node_name = self
.dag
.get(*error_node_id)
.unwrap()
.node_name()
.expect("unexpected failure from unnamed node")
.clone();
let undo_error_node_name = self
.dag
.get(*undo_error_node_id)
.unwrap()
.node_name()
.expect("unexpected failure from unnamed undo node")
.clone();
return SagaResult {
saga_id: self.saga_id,
saga_log: live_state.sglog.clone(),
kind: Err(SagaResultErr {
error_node_name,
error_source: error_source.clone(),
undo_failure: Some((
undo_error_node_name,
undo_error_source.clone(),
)),
}),
};
}
if live_state.nodes_undone.contains_key(&self.dag.start_node) {
// The saga unwound completely.  Report the action error that
// triggered the unwinding.
assert!(live_state.nodes_undone.contains_key(&self.dag.end_node));
let (error_node_id, error_source) =
live_state.node_errors.iter().next().unwrap();
let error_node_name = self
.dag
.get(*error_node_id)
.unwrap()
.node_name()
.expect("unexpected failure from unnamed node")
.clone();
return SagaResult {
saga_id: self.saga_id,
saga_log: live_state.sglog.clone(),
kind: Err(SagaResultErr {
error_node_name,
error_source: error_source.clone(),
undo_failure: None,
}),
};
}
// Success: collect the outputs of all named nodes...
assert!(live_state.nodes_undone.is_empty());
let node_outputs = live_state
.node_outputs
.iter()
.filter_map(|(node_id, node_output)| {
self.dag.get(*node_id).unwrap().node_name().map(|node_name| {
(node_name.clone(), Arc::clone(node_output))
})
})
.collect();
// ...and take the saga's overall output from the end node's sole
// ancestor.
let output_node_index = self
.dag
.graph
.neighbors_directed(self.dag.end_node, Incoming)
.next()
.unwrap();
let saga_output = live_state.node_output(output_node_index);
SagaResult {
saga_id: self.saga_id,
saga_log: live_state.sglog.clone(),
kind: Ok(SagaResultOk { saga_output, node_outputs }),
}
}
pub fn status(&self) -> BoxFuture<'_, SagaExecStatus> {
async move {
let live_state = self.live_state.lock().await;
let mut node_exec_states = BTreeMap::new();
let graph = &self.dag.graph;
let topo_visitor = Topo::new(graph);
for node in topo_visitor.iter(graph) {
node_exec_states.insert(node, live_state.node_exec_state(node));
}
SagaExecStatus {
saga_id: self.saga_id,
dag: Arc::clone(&self.dag),
node_exec_states,
sglog: live_state.sglog.clone(),
}
}
.boxed()
}
}
/// Queues, outstanding tasks, and per-node results for a saga mid-flight.
/// Shared (behind a `Mutex`) between the executor and its node tasks.
#[derive(Debug)]
struct SagaExecLiveState {
saga_id: SagaId,
// Client used to persist log events and saga state updates.
sec_hdl: SecExecClient,
// Overall cached state: running, unwinding, or done.
exec_state: SagaCachedState,
// True once an undo action fails: stop queueing work and wind down.
stopping: bool,
// Nodes whose actions are ready to run.
queue_todo: Vec<NodeIndex>,
// Nodes whose undo actions are ready to run.
queue_undo: Vec<NodeIndex>,
// Outstanding tokio tasks, one per currently-executing node.
node_tasks: BTreeMap<NodeIndex, JoinHandle<()>>,
// Outputs of nodes whose actions succeeded.
node_outputs: BTreeMap<NodeIndex, Arc<serde_json::Value>>,
// Nodes that have been undone, and how.
nodes_undone: BTreeMap<NodeIndex, UndoMode>,
// Errors from failed actions.
node_errors: BTreeMap<NodeIndex, ActionError>,
// Errors from failed undo actions.
undo_errors: BTreeMap<NodeIndex, UndoActionError>,
// Persistent saga log (source of truth during recovery).
sglog: SagaLog,
// Test-only injected failures and repeats.
injected_errors: BTreeSet<NodeIndex>,
injected_undo_errors: BTreeSet<NodeIndex>,
injected_repeats: BTreeMap<NodeIndex, RepeatInjected>,
}
/// Externally-visible execution state of a single node, as computed by
/// `SagaExecLiveState::node_exec_state` and shown in status output.
#[derive(Clone, Copy, Debug, Eq, Ord, PartialEq, PartialOrd)]
enum NodeExecState {
Blocked,
QueuedToRun,
TaskInProgress,
Done,
Failed,
QueuedToUndo,
UndoInProgress,
Undone(UndoMode),
UndoFailed,
}
/// How a node came to be "undone".
#[derive(Clone, Copy, Debug, Eq, Ord, PartialEq, PartialOrd)]
enum UndoMode {
// The action never ran, so there was nothing to undo.
ActionNeverRan,
// The action succeeded and its undo action ran.
ActionUndone,
// The action failed, so there was nothing to undo.
ActionFailed,
}
impl fmt::Display for NodeExecState {
    /// Renders the state as the short lowercase token used in status output.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let text = match self {
            NodeExecState::Blocked => "blocked",
            NodeExecState::QueuedToRun => "queued-todo",
            NodeExecState::TaskInProgress => "working",
            NodeExecState::Done => "done",
            NodeExecState::Failed => "failed",
            NodeExecState::QueuedToUndo => "queued-undo",
            NodeExecState::UndoInProgress => "undo-working",
            NodeExecState::Undone(UndoMode::ActionNeverRan) => "abandoned",
            NodeExecState::Undone(UndoMode::ActionUndone) => "undone",
            // A failed action needed no undo; it still reads as "failed".
            NodeExecState::Undone(UndoMode::ActionFailed) => "failed",
            NodeExecState::UndoFailed => "undo-failed",
        };
        f.write_str(text)
    }
}
impl SagaExecLiveState {
    /// Computes the externally-visible execution state of `node_id` from
    /// the persistent log status plus the in-memory queues, tasks, and
    /// result maps.
    ///
    /// # Panics
    /// Panics if the combined state is ambiguous or impossible.
    fn node_exec_state(&self, node_id: NodeIndex) -> NodeExecState {
        // Collect every state that appears to apply; exactly one should.
        let mut set: BTreeSet<NodeExecState> = BTreeSet::new();
        let load_status = self.sglog.load_status_for_node(node_id.into());
        if let Some(undo_mode) = self.nodes_undone.get(&node_id) {
            set.insert(NodeExecState::Undone(*undo_mode));
        } else if self.queue_undo.contains(&node_id) {
            set.insert(NodeExecState::QueuedToUndo);
        } else if let SagaNodeLoadStatus::Failed(_) = load_status {
            assert!(self.node_errors.contains_key(&node_id));
            set.insert(NodeExecState::Failed);
        } else if let SagaNodeLoadStatus::UndoFailed(_) = load_status {
            assert!(self.undo_errors.contains_key(&node_id));
            set.insert(NodeExecState::UndoFailed);
        } else if self.node_outputs.contains_key(&node_id) {
            // Has an output: either done, or its undo task is running.
            if self.node_tasks.contains_key(&node_id) {
                set.insert(NodeExecState::UndoInProgress);
            } else {
                set.insert(NodeExecState::Done);
            }
        } else if self.node_tasks.contains_key(&node_id) {
            set.insert(NodeExecState::TaskInProgress);
        }
        if self.queue_todo.contains(&node_id) {
            set.insert(NodeExecState::QueuedToRun);
        }
        if set.is_empty() {
            // No evidence of activity: valid only if it never started.
            if let SagaNodeLoadStatus::NeverStarted = load_status {
                NodeExecState::Blocked
            } else {
                panic!("could not determine node state");
            }
        } else {
            assert_eq!(set.len(), 1);
            // Exactly one element, so `next()` yields it (the original's
            // `last()` plus a redundant binding did the same thing).
            set.into_iter().next().unwrap()
        }
    }
fn mark_saga_done(&mut self) {
assert!(!self.stopping);
assert!(self.queue_todo.is_empty());
assert!(self.queue_undo.is_empty());
assert!(
self.exec_state == SagaCachedState::Running
|| self.exec_state == SagaCachedState::Unwinding
);
self.exec_state = SagaCachedState::Done;
}
fn saga_stuck(&mut self) {
assert!(self.exec_state == SagaCachedState::Unwinding);
self.stopping = true;
if self.node_tasks.is_empty() {
self.exec_state = SagaCachedState::Done;
}
}
fn node_task(&mut self, node_id: NodeIndex, task: JoinHandle<()>) {
assert!(!self.stopping);
self.node_tasks.insert(node_id, task);
}
fn node_task_done(&mut self, node_id: NodeIndex) -> JoinHandle<()> {
let rv = self
.node_tasks
.remove(&node_id)
.expect("processing task completion with no task present");
if self.stopping && self.node_tasks.is_empty() {
self.exec_state = SagaCachedState::Done;
}
rv
}
fn node_output(&self, node_id: NodeIndex) -> Arc<serde_json::Value> {
let output =
self.node_outputs.get(&node_id).expect("node has no output");
Arc::clone(output)
}
}
/// Overall result of a saga: its id, final log, and success/failure detail.
#[derive(Clone, Debug)]
pub struct SagaResult {
pub saga_id: SagaId,
pub saga_log: SagaLog,
pub kind: Result<SagaResultOk, SagaResultErr>,
}
/// Success detail: the saga's overall output plus each named node's output.
#[derive(Clone, Debug)]
pub struct SagaResultOk {
saga_output: Arc<serde_json::Value>,
node_outputs: BTreeMap<NodeName, Arc<serde_json::Value>>,
}
impl SagaResultOk {
    /// Deserializes the saga's overall output.
    ///
    /// # Errors
    /// Returns a deserialization `ActionError` if the output does not
    /// match the requested type.
    pub fn saga_output<T: ActionData + 'static>(
        &self,
    ) -> Result<T, ActionError> {
        let value = (*self.saga_output).clone();
        serde_json::from_value(value)
            .context("final saga output")
            .map_err(ActionError::new_deserialize)
    }
    /// Deserializes the output of the node named `name`.
    ///
    /// # Panics
    /// Panics if no node with that name is part of this saga.
    ///
    /// # Errors
    /// Returns a deserialization `ActionError` if the output does not
    /// match the requested type.
    pub fn lookup_node_output<T: ActionData + 'static>(
        &self,
        name: &str,
    ) -> Result<T, ActionError> {
        let key = NodeName::new(name);
        let output_json = match self.node_outputs.get(&key) {
            Some(output) => output,
            None => panic!(
                "node with name \"{}\": not part of this saga",
                key.as_ref(),
            ),
        };
        serde_json::from_value((**output_json).clone())
            .context("final node output")
            .map_err(ActionError::new_deserialize)
    }
}
/// Failure detail: the failed node and its error, plus the undo failure if
/// the saga also got stuck while unwinding.
#[derive(Clone, Debug)]
pub struct SagaResultErr {
pub error_node_name: NodeName,
pub error_source: ActionError,
pub undo_failure: Option<(NodeName, UndoActionError)>,
}
/// Point-in-time snapshot of a saga's execution, printable as a tree.
#[derive(Clone, Debug)]
pub struct SagaExecStatus {
saga_id: SagaId,
node_exec_states: BTreeMap<NodeIndex, NodeExecState>,
dag: Arc<SagaDag>,
sglog: SagaLog,
}
impl fmt::Display for SagaExecStatus {
// Delegates to `print`, which writes the full multi-line rendering.
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
self.print(f)
}
}
impl SagaExecStatus {
/// Returns the underlying saga log for this status snapshot.
pub fn log(&self) -> &SagaLog {
&self.sglog
}
/// Writes a human-readable rendering: a header line followed by one line
/// per node (plus "(parallel actions)" markers), in the order computed
/// by `PrintOrderer`.
pub fn print(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
let orderer = PrintOrderer::new(&self.dag);
let output = orderer.print_order();
self.write_header(f)?;
for entry in output {
match entry {
PrintOrderEntry::Node { idx, indent_level } => {
self.print_node(f, idx, indent_level)?;
}
PrintOrderEntry::Parallel { indent_level } => {
Self::write_indented(
f,
indent_level,
"(parallel actions):\n",
)?;
}
}
}
Ok(())
}
fn print_node(
&self,
f: &mut fmt::Formatter<'_>,
idx: NodeIndex,
indent_level: usize,
) -> fmt::Result {
let node = self.dag.get(idx).unwrap();
let label = Self::mklabel(&node);
let state = &self.node_exec_states[&idx];
let msg = format!("{}: {}\n", state, label);
Self::write_indented(f, indent_level, &msg)?;
Ok(())
}
fn write_indented(
out: &mut fmt::Formatter<'_>,
indent_level: usize,
msg: &str,
) -> fmt::Result {
write!(
out,
"{:width$}+-- {}",
"",
msg,
width = Self::big_indent(indent_level)
)
}
fn mklabel(node: &InternalNode) -> String {
if let Some(name) = node.node_name() {
format!("{} (produces {:?})", node.label(), name)
} else {
node.label()
}
}
fn big_indent(indent_level: usize) -> usize {
indent_level * 8
}
fn write_header(&self, out: &mut fmt::Formatter<'_>) -> fmt::Result {
write!(
out,
"{:width$}+ saga execution: {}\n",
"",
self.saga_id,
width = 0
)
}
}
/// One entry in the linearized print order produced by `PrintOrderer`.
#[derive(Debug, PartialEq)]
enum PrintOrderEntry {
/// marker line introducing a group of parallel actions
Parallel { indent_level: usize },
/// a single DAG node, printed at the given indent level
Node { idx: NodeIndex, indent_level: usize },
}
impl PrintOrderEntry {
    /// Returns true if this entry refers to a DAG node (as opposed to a
    /// "parallel actions" marker).
    #[allow(unused)]
    fn is_node(&self) -> bool {
        // `matches!` replaces the manual `if let ... { true } else { false }`.
        matches!(self, PrintOrderEntry::Node { .. })
    }
}
/// Bookkeeping for `PrintOrderer`'s depth-first walk.
#[derive(Debug, PartialEq)]
enum StackEntry {
/// members of a parallel group that have not yet been visited
Parallel(Vec<NodeIndex>),
/// we are currently inside a subsaga (between SubsagaStart and SubsagaEnd)
Subsaga,
}
/// Walks a `SagaDag` to produce the linear sequence of `PrintOrderEntry`
/// items in the order they should be printed.
struct PrintOrderer<'a> {
/// the DAG being walked
dag: &'a SagaDag,
/// entries accumulated so far, in print order
output: Vec<PrintOrderEntry>,
/// parallel groups and subsagas currently being traversed
stack: Vec<StackEntry>,
/// index of the node currently being visited
idx: NodeIndex,
/// current indentation level
indent_level: usize,
}
impl<'a> PrintOrderer<'a> {
    /// Create an orderer positioned at the DAG's start node.
    pub fn new(dag: &'a SagaDag) -> PrintOrderer<'a> {
        let idx = dag.start_node;
        PrintOrderer {
            dag,
            output: Vec::new(),
            stack: Vec::new(),
            idx,
            indent_level: 0,
        }
    }

    /// Record the current node at the current indent level.
    fn output_current_node(&mut self) {
        self.output.push(PrintOrderEntry::Node {
            idx: self.idx,
            indent_level: self.indent_level,
        });
    }

    /// Record a "parallel actions" marker at the current indent level.
    fn output_parallel(&mut self) {
        self.output.push(PrintOrderEntry::Parallel {
            indent_level: self.indent_level,
        });
    }

    /// Walk the DAG from the start node and return its entries in print
    /// order, with indent levels reflecting subsaga nesting and parallel
    /// groups.  Consumes the orderer.
    fn print_order(mut self) -> Vec<PrintOrderEntry> {
        while self.idx != self.dag.end_node {
            let node = self.dag.get(self.idx).unwrap();
            match node {
                InternalNode::SubsagaStart { .. } => {
                    // Entering a subsaga: emit the node, then indent
                    // everything up to the matching SubsagaEnd.
                    self.output_current_node();
                    self.indent_level += 1;
                    self.stack.push(StackEntry::Subsaga);
                    self.descend();
                }
                InternalNode::SubsagaEnd { .. } => {
                    // Leaving a subsaga: dedent before emitting the node.
                    self.indent_level -= 1;
                    self.stack.pop();
                    self.output_current_node();
                    if !self.next_parallel_node() {
                        self.descend();
                    }
                }
                _ => {
                    self.output_current_node();
                    // Visit any remaining sibling of the current parallel
                    // group before descending further.
                    if !self.next_parallel_node() {
                        self.descend();
                    }
                }
            }
        }
        // Finally emit the end node itself.
        self.output_current_node();
        assert!(self.stack.is_empty());
        self.output
    }

    /// Move to a child of the current node.  When there are multiple
    /// children, emit a parallel marker, visit one child, and stash the rest
    /// on the stack for `next_parallel_node()`.
    fn descend(&mut self) {
        let mut children: Vec<NodeIndex> =
            self.dag.graph.neighbors_directed(self.idx, Outgoing).collect();
        if children.is_empty() {
            // Only the end node has no children, and it is never part of a
            // pending parallel group or subsaga.
            assert!(self.stack.is_empty());
            assert_eq!(self.dag.end_node, self.idx);
            return;
        }
        if children.len() == 1 {
            self.idx = children[0];
        } else {
            self.output_parallel();
            self.idx = children.pop().unwrap();
            self.indent_level += 1;
            self.stack.push(StackEntry::Parallel(children));
        }
    }

    /// If we're inside a parallel group with unvisited members, advance to
    /// the next member and return true.  Otherwise return false, popping the
    /// group (and dedenting) if it is exhausted.
    fn next_parallel_node(&mut self) -> bool {
        if let Some(StackEntry::Parallel(nodes)) = self.stack.last_mut() {
            if let Some(next_idx) = nodes.pop() {
                self.idx = next_idx;
                return true;
            } else {
                self.indent_level -= 1;
                self.stack.pop();
            }
        }
        false
    }
}
/// Returns true if every neighbor of `node_id` in the given `direction`
/// satisfies `test` (vacuously true when there are no neighbors).
fn neighbors_all<F>(
    graph: &Graph<InternalNode, ()>,
    node_id: &NodeIndex,
    direction: Direction,
    test: F,
) -> bool
where
    F: Fn(&NodeIndex) -> bool,
{
    // `Iterator::all` replaces the manual loop with an early `return false`.
    graph.neighbors_directed(*node_id, direction).all(|p| test(&p))
}
/// During recovery, check that a parent node's load status is consistent
/// with its child's load status (i.e., that the log we recovered could have
/// been produced by a legal execution).
fn recovery_validate_parent(
    parent_status: &SagaNodeLoadStatus,
    child_status: &SagaNodeLoadStatus,
) -> bool {
    match child_status {
        // A child can only have started (or finished, or begun undoing) if
        // its parent finished successfully.
        SagaNodeLoadStatus::Started
        | SagaNodeLoadStatus::Succeeded(_)
        | SagaNodeLoadStatus::UndoStarted(_) => {
            matches!(parent_status, SagaNodeLoadStatus::Succeeded(_))
        }
        // These two child states accept the same set of parent states, so
        // the arms are combined.
        SagaNodeLoadStatus::Failed(_) | SagaNodeLoadStatus::UndoFinished => {
            matches!(
                parent_status,
                SagaNodeLoadStatus::Succeeded(_)
                    | SagaNodeLoadStatus::UndoStarted(_)
                    | SagaNodeLoadStatus::UndoFinished
                    | SagaNodeLoadStatus::UndoFailed(_)
            )
        }
        SagaNodeLoadStatus::UndoFailed(_) => {
            matches!(parent_status, SagaNodeLoadStatus::Succeeded(_))
        }
        // A child that never started places the fewest constraints on its
        // parent: anything except an undo-related parent state is legal.
        SagaNodeLoadStatus::NeverStarted => matches!(
            parent_status,
            SagaNodeLoadStatus::NeverStarted
                | SagaNodeLoadStatus::Started
                | SagaNodeLoadStatus::Succeeded(_)
                | SagaNodeLoadStatus::Failed(_)
        ),
    }
}
/// Context available to a saga action while it runs (see `lookup()`,
/// `saga_params()`, `node_label()`, and `user_data()`).
pub struct ActionContext<UserType: SagaType> {
/// outputs of this node's ancestors, keyed by node name
ancestor_tree: Arc<BTreeMap<NodeName, Arc<serde_json::Value>>>,
/// graph index of the node whose action is running
node_id: NodeIndex,
/// the saga's DAG
dag: Arc<SagaDag>,
/// consumer-provided execution context
user_context: Arc<UserType::ExecContextType>,
/// JSON parameters available to this node via `saga_params()`
saga_params: Arc<serde_json::Value>,
}
impl<UserType: SagaType> ActionContext<UserType> {
    /// Retrieve and deserialize the output of an earlier (ancestor) node by
    /// name.
    ///
    /// # Panics
    /// Panics if no ancestor named `name` exists.
    ///
    /// # Errors
    /// Returns a deserialization `ActionError` if the stored JSON cannot be
    /// converted to `T`.
    pub fn lookup<T: ActionData + 'static>(
        &self,
        name: &str,
    ) -> Result<T, ActionError> {
        // Pass `name` directly: the intermediate `name.to_string()` was an
        // unnecessary allocation, and elsewhere in this file
        // `NodeName::new()` is already called with a `&str`.
        let item = self
            .ancestor_tree
            .get(&NodeName::new(name))
            .unwrap_or_else(|| panic!("no ancestor called \"{}\"", name));
        serde_json::from_value((**item).clone())
            .with_context(|| format!("output from earlier node {:?}", name))
            .map_err(ActionError::new_deserialize)
    }

    /// Retrieve and deserialize this saga's parameters.
    ///
    /// # Errors
    /// Returns a deserialization `ActionError` (including the raw params in
    /// the message) if the JSON cannot be converted to `T`.
    pub fn saga_params<T: ActionData + 'static>(
        &self,
    ) -> Result<T, ActionError> {
        serde_json::from_value((*self.saga_params).clone())
            .with_context(|| {
                let as_str = serde_json::to_string(&self.saga_params)
                    .unwrap_or_else(|_| format!("{:?}", self.saga_params));
                format!("saga params ({})", as_str)
            })
            .map_err(ActionError::new_deserialize)
    }

    /// Returns the label of the node whose action is running.
    pub fn node_label(&self) -> String {
        self.dag.get(self.node_id).unwrap().label()
    }

    /// Returns the consumer-provided execution context.
    pub fn user_data(&self) -> &UserType::ExecContextType {
        &self.user_context
    }
}
impl From<NodeIndex> for SagaNodeId {
    /// Convert a petgraph node index into a saga-log node id.
    fn from(node_id: NodeIndex) -> SagaNodeId {
        // Saga DAGs are assumed never to approach 2**32 nodes, so this
        // conversion is expected to be infallible in practice.
        // TODO(review): confirm this invariant is enforced at DAG build time.
        let index = u32::try_from(node_id.index())
            .expect("saga node index did not fit in u32");
        SagaNodeId::from(index)
    }
}
/// Build a saga-log event for `node` and record it, first in the in-memory
/// log and then with the SEC handle.
async fn record_now(
    live_state: &mut SagaExecLiveState,
    node: NodeIndex,
    event_type: SagaNodeEventType,
) {
    let event = SagaNodeEvent {
        saga_id: live_state.saga_id,
        node_id: node.into(),
        event_type,
    };
    // Recording into the in-memory log is not expected to fail here.
    // TODO(review): confirm `SagaLog::record` can only fail on invalid
    // transitions that this caller never produces.
    live_state.sglog.record(&event).unwrap();
    live_state.sec_hdl.record(event).await;
}
/// Object-safe interface for driving and inspecting a saga execution,
/// independent of the saga's concrete `SagaType`.
pub trait SagaExecManager: fmt::Debug + Send + Sync {
/// Returns a future that runs the saga.
fn run(&self) -> BoxFuture<'_, ()>;
/// Returns the result of the saga (see [`SagaResult`]).
fn result(&self) -> SagaResult;
/// Returns a snapshot of the saga's execution status.
fn status(&self) -> BoxFuture<'_, SagaExecStatus>;
/// Inject an error at the given node (testing facility).
fn inject_error(&self, node_id: NodeIndex) -> BoxFuture<'_, ()>;
/// Inject an error into the given node's undo action (testing facility).
fn inject_error_undo(&self, node_id: NodeIndex) -> BoxFuture<'_, ()>;
/// Cause the given node's action and/or undo action to be repeated, per
/// `repeat` (testing facility).
fn inject_repeat(
&self,
node_id: NodeIndex,
repeat: RepeatInjected,
) -> BoxFuture<'_, ()>;
}
impl<T> SagaExecManager for SagaExecutor<T>
where
T: SagaType + fmt::Debug,
{
// Every trait method simply delegates to the inherent method of the same
// name on `SagaExecutor`, boxing the returned future where the trait
// requires a `BoxFuture`.
fn run(&self) -> BoxFuture<'_, ()> {
self.run().boxed()
}
fn result(&self) -> SagaResult {
self.result()
}
fn status(&self) -> BoxFuture<'_, SagaExecStatus> {
self.status()
}
fn inject_error(&self, node_id: NodeIndex) -> BoxFuture<'_, ()> {
self.inject_error(node_id).boxed()
}
fn inject_error_undo(&self, node_id: NodeIndex) -> BoxFuture<'_, ()> {
self.inject_error_undo(node_id).boxed()
}
fn inject_repeat(
&self,
node_id: NodeIndex,
repeat: RepeatInjected,
) -> BoxFuture<'_, ()> {
self.inject_repeat(node_id, repeat).boxed()
}
}
// Unit tests for `PrintOrderer`: verify the linearized print order and the
// indent levels assigned to nodes, parallel groups, and subsagas.
#[cfg(test)]
mod test {
use super::*;
use crate::{DagBuilder, Node, SagaDag, SagaName};
use petgraph::graph::NodeIndex;
use std::fmt::Write;
// Shorthand: build a constant node named `name` with a JSON null value.
fn constant(name: &str) -> Node {
Node::constant(name, serde_json::Value::Null)
}
// Returns true if each index in `indexes` refers to a constant node whose
// name matches the corresponding entry in `names`.
fn constant_names_match(
names: &[&str],
indexes: &[NodeIndex],
dag: &SagaDag,
) -> bool {
assert_eq!(names.len(), indexes.len());
for i in 0..names.len() {
if !constant_name_matches(names[i], indexes[i], dag) {
return false;
}
}
true
}
// Returns true if the node at `idx` is a constant node named `name`.
fn constant_name_matches(
name: &str,
idx: NodeIndex,
dag: &SagaDag,
) -> bool {
let node = dag.get(idx).unwrap();
matches!(
node,
InternalNode::Constant { name: a, .. }
if a == &NodeName::new(name)
)
}
// Returns true if the node at `idx` is the saga's start node.
fn is_start_node(idx: NodeIndex, dag: &SagaDag) -> bool {
if let InternalNode::Start { .. } = dag.get(idx).unwrap() {
true
} else {
false
}
}
// Returns true if the node at `idx` is the saga's end node.
fn is_end_node(idx: NodeIndex, dag: &SagaDag) -> bool {
if let InternalNode::End = dag.get(idx).unwrap() {
true
} else {
false
}
}
// Render print-order entries as one Debug-formatted item per line,
// indented per level, for comparison against an expected string.
fn print_for_testing(
entries: &Vec<PrintOrderEntry>,
dag: &SagaDag,
) -> String {
let mut out = String::new();
for entry in entries {
match entry {
PrintOrderEntry::Node { idx, indent_level } => {
let node = dag.get(*idx).unwrap();
write!(&mut out, "{}{:?}\n", spaces(*indent_level), node)
.unwrap();
}
PrintOrderEntry::Parallel { indent_level } => {
write!(
&mut out,
"{}{:?}\n",
spaces(*indent_level),
"Parallel: "
)
.unwrap();
}
}
}
out
}
// Four spaces of indentation per indent level.
fn spaces(indent_level: usize) -> String {
let num_spaces = indent_level * 4;
(0..num_spaces).fold(String::new(), |mut acc, _| {
acc.push(' ');
acc
})
}
// Linear saga (start -> a -> b -> end): all entries at indent 0, no
// parallel markers.
#[test]
fn test_print_order_no_subsagas_no_parallel() {
let mut builder = DagBuilder::new(SagaName::new("test-saga"));
builder.append(constant("a"));
builder.append(constant("b"));
let dag = builder.build().unwrap();
let saga_dag = SagaDag::new(dag, serde_json::Value::Null);
let orderer = PrintOrderer::new(&saga_dag);
let entries = orderer.print_order();
assert_eq!(4, entries.len());
let mut indexes = Vec::new();
for entry in entries {
match entry {
PrintOrderEntry::Node { idx, indent_level } => {
indexes.push(idx);
assert_eq!(indent_level, 0);
}
_ => panic!("No parallel nodes should exist"),
}
}
assert!(is_start_node(indexes[0], &saga_dag));
assert!(constant_name_matches("a", indexes[1], &saga_dag));
assert!(constant_name_matches("b", indexes[2], &saga_dag));
assert!(is_end_node(indexes[3], &saga_dag));
}
// A single parallel group: expect one Parallel marker at indent 0 with
// the group's members at indent 1.
#[test]
fn test_print_order_parallel_nodes_no_subsagas() {
let mut builder = DagBuilder::new(SagaName::new("test-saga"));
builder.append(constant("a"));
builder.append_parallel(vec![constant("b"), constant("c")]);
builder.append(constant("d"));
let dag = builder.build().unwrap();
let saga_dag = SagaDag::new(dag, serde_json::Value::Null);
let orderer = PrintOrderer::new(&saga_dag);
let entries = orderer.print_order();
assert_eq!(7, entries.len());
let mut actual_indexes = Vec::new();
for i in 0..7 {
match entries[i] {
PrintOrderEntry::Node { idx, indent_level } => match i {
0 => {
assert_eq!(indent_level, 0);
assert!(is_start_node(idx, &saga_dag));
}
1 | 5 => {
assert_eq!(indent_level, 0);
actual_indexes.push(idx);
}
3..=4 => {
assert_eq!(indent_level, 1);
actual_indexes.push(idx);
}
6 => {
assert_eq!(indent_level, 0);
assert!(is_end_node(idx, &saga_dag));
}
_ => panic!("invalid entry"),
},
PrintOrderEntry::Parallel { indent_level } => {
assert_eq!(2, i);
assert_eq!(indent_level, 0);
}
}
}
let expected_names = vec!["a", "b", "c", "d"];
assert!(constant_names_match(
&expected_names,
&actual_indexes,
&saga_dag
));
}
// Full end-to-end check of nested subsagas and parallel groups (including
// a subsaga appended inside a parallel group) against a literal rendering
// of the expected print order.
#[test]
fn test_print_order_nested_parallel_nodes_and_subsagas() {
let mut nested_subsaga =
DagBuilder::new(SagaName::new("test-nested-subsaga"));
nested_subsaga.append(constant("a"));
nested_subsaga.append_parallel(vec![constant("b"), constant("c")]);
nested_subsaga.append(constant("d"));
let nested_subsaga_dag = nested_subsaga.build().unwrap();
let mut subsaga = DagBuilder::new(SagaName::new("test-subsaga"));
subsaga.append(constant("a"));
subsaga.append_parallel(vec![constant("b"), constant("c")]);
subsaga.append(constant("d"));
subsaga.append(Node::subsaga("e", nested_subsaga_dag.clone(), "d"));
subsaga.append_parallel(vec![
constant("f"),
Node::subsaga("g", nested_subsaga_dag, "e"),
constant("h"),
]);
subsaga.append(constant("i"));
let subsaga_dag = subsaga.build().unwrap();
let mut builder = DagBuilder::new(SagaName::new("test-saga"));
builder.append(constant("a"));
builder.append_parallel(vec![constant("b"), constant("c")]);
builder.append(constant("d"));
builder.append(Node::subsaga("e", subsaga_dag, "d"));
let dag = builder.build().unwrap();
let saga_dag = SagaDag::new(dag, serde_json::Value::Null);
let orderer = PrintOrderer::new(&saga_dag);
let entries = orderer.print_order();
let actual = print_for_testing(&entries, &saga_dag);
let expected = "\
Start { params: Null }
Constant { name: \"a\", value: Null }
\"Parallel: \"
Constant { name: \"b\", value: Null }
Constant { name: \"c\", value: Null }
Constant { name: \"d\", value: Null }
SubsagaStart { saga_name: \"test-subsaga\", params_node_name: \"d\" }
Constant { name: \"a\", value: Null }
\"Parallel: \"
Constant { name: \"b\", value: Null }
Constant { name: \"c\", value: Null }
Constant { name: \"d\", value: Null }
SubsagaStart { saga_name: \"test-nested-subsaga\", params_node_name: \"d\" \
}
Constant { name: \"a\", value: Null }
\"Parallel: \"
Constant { name: \"b\", value: Null }
Constant { name: \"c\", value: Null }
Constant { name: \"d\", value: Null }
SubsagaEnd { name: \"e\" }
\"Parallel: \"
Constant { name: \"f\", value: Null }
SubsagaStart { saga_name: \"test-nested-subsaga\", params_node_name: \
\"e\" }
Constant { name: \"a\", value: Null }
\"Parallel: \"
Constant { name: \"b\", value: Null }
Constant { name: \"c\", value: Null }
Constant { name: \"d\", value: Null }
SubsagaEnd { name: \"g\" }
Constant { name: \"h\", value: Null }
Constant { name: \"i\", value: Null }
SubsagaEnd { name: \"e\" }
End
";
assert_eq!(actual, expected);
}
}
// Property-based tests for `PrintOrderer`: generate random DAG shapes and
// check structural invariants of the resulting print order.
#[cfg(test)]
mod proptests {
use super::*;
use crate::{Dag, DagBuilder, Node, SagaDag, SagaName};
use petgraph::graph::NodeIndex;
use proptest::prelude::*;
// Abstract description of a DAG shape used to generate test sagas.
#[derive(Clone, Debug)]
enum NodeDesc {
Constant,
Parallel(Vec<NodeDesc>),
Subsaga(Vec<NodeDesc>),
}
impl NodeDesc {
// Returns true for the `Parallel` variant.
fn is_parallel(&self) -> bool {
if let NodeDesc::Parallel(_) = *self {
true
} else {
false
}
}
}
// Strategy producing recursive `NodeDesc` trees.  A generated collection
// containing a `Parallel` member becomes a `Subsaga` instead, since
// `new_dag` panics on a `Parallel` nested directly inside a `Parallel`.
fn arb_nodedesc() -> impl Strategy<Value = NodeDesc> {
let num_levels = 8;
let max_size = 256;
let items_per_collection = 10;
let leaf = prop_oneof![Just(NodeDesc::Constant)];
leaf.prop_recursive(
num_levels,
max_size,
items_per_collection,
|inner| {
prop_oneof![
prop::collection::vec(inner.clone(), 2..10).prop_map(|v| {
if v.iter().any(|node_desc| node_desc.is_parallel()) {
NodeDesc::Subsaga(v)
} else {
NodeDesc::Parallel(v)
}
}),
prop::collection::vec(inner, 1..10)
.prop_map(NodeDesc::Subsaga)
]
},
)
}
// Materialize a `Dag` from a list of `NodeDesc`s.  Node names are
// sequential integers; node "0" serves as the params node for subsagas.
fn new_dag(nodes: &Vec<NodeDesc>, depth: usize) -> Dag {
let name = SagaName::new(&format!("test-saga-{}", depth));
let mut dag = DagBuilder::new(name);
let params_node_name = "0";
dag.append(Node::constant(params_node_name, serde_json::Value::Null));
let mut node_name = 1;
for node in nodes {
match node {
NodeDesc::Constant => {
dag.append(Node::constant(
&node_name.to_string(),
serde_json::Value::Null,
));
node_name += 1;
}
NodeDesc::Parallel(parallel_nodes) => {
let mut output = Vec::with_capacity(parallel_nodes.len());
for node in parallel_nodes {
match node {
NodeDesc::Constant => {
output.push(Node::constant(
&node_name.to_string(),
serde_json::Value::Null,
));
node_name += 1;
}
NodeDesc::Subsaga(subsaga_nodes) => {
let subsaga_dag =
new_dag(subsaga_nodes, depth + 1);
output.push(Node::subsaga(
&node_name.to_string(),
subsaga_dag,
params_node_name,
));
node_name += 1;
}
NodeDesc::Parallel(_) => panic!(
"Strategy Generation Error: Nested \
`NodeDesc::Parallel` not allowed!"
),
}
}
dag.append_parallel(output);
}
NodeDesc::Subsaga(subsaga_nodes) => {
let subsaga_dag = new_dag(subsaga_nodes, depth + 1);
dag.append(Node::subsaga(
&node_name.to_string(),
subsaga_dag,
params_node_name,
));
node_name += 1;
}
}
}
dag.append(Node::constant(
&node_name.to_string(),
serde_json::Value::Null,
));
dag.build().unwrap()
}
// What kind of structure each indent level corresponds to while checking
// invariants below.
#[derive(Debug, Clone, PartialEq)]
enum IndentStackEntry {
Parallel,
Subsaga,
}
// Number of incoming edges (parents) of `idx`.
fn num_ancestors(dag: &SagaDag, idx: NodeIndex) -> usize {
dag.graph.edges_directed(idx, Direction::Incoming).count()
}
// Number of incoming edges of `idx`'s (single) child.
fn num_ancestors_of_child(dag: &SagaDag, idx: NodeIndex) -> usize {
let child = dag
.graph
.neighbors_directed(idx, Direction::Outgoing)
.next()
.unwrap();
dag.graph.edges_directed(child, Direction::Incoming).count()
}
// A node was appended in parallel iff its child joins multiple parents.
fn appended_in_parallel(dag: &SagaDag, idx: NodeIndex) -> bool {
num_ancestors_of_child(dag, idx) > 1
}
// Walk the print-order entries, maintaining a stack of the structures
// each indent level represents, and check that every entry's indent
// level is consistent with the DAG's shape.
fn property_indents_are_correct(
entries: &Vec<PrintOrderEntry>,
dag: &SagaDag,
) -> Result<(), TestCaseError> {
let mut indent_stack = Vec::new();
for entry in entries {
match entry {
PrintOrderEntry::Node { idx, indent_level } => {
let node = dag.get(*idx).unwrap();
match node {
InternalNode::Start { .. } => {
prop_assert_eq!(0, *indent_level);
prop_assert_eq!(indent_stack.len(), *indent_level);
}
InternalNode::End { .. } => {
prop_assert_eq!(0, *indent_level);
prop_assert_eq!(indent_stack.len(), *indent_level);
}
InternalNode::Action { .. } => {
panic!("No actions should exist!")
}
InternalNode::Constant { .. } => {
let parallel = appended_in_parallel(dag, *idx);
if *indent_level == 0 && indent_stack.is_empty() {
prop_assert!(!parallel);
continue;
}
if indent_stack.len() == *indent_level {
if let &IndentStackEntry::Subsaga =
indent_stack.last().unwrap()
{
prop_assert!(!parallel);
} else {
prop_assert!(parallel);
}
} else {
prop_assert_eq!(
indent_stack.len() - 1,
*indent_level
);
prop_assert_eq!(
&IndentStackEntry::Parallel,
indent_stack.last().unwrap()
);
prop_assert!(!parallel);
prop_assert!(num_ancestors(dag, *idx) > 1);
indent_stack.pop();
}
}
InternalNode::SubsagaStart { .. } => {
if indent_stack.len() != *indent_level {
prop_assert_eq!(
indent_stack.len() - 1,
*indent_level
);
prop_assert!(num_ancestors(dag, *idx) > 1);
prop_assert_eq!(
&IndentStackEntry::Parallel,
indent_stack.last().unwrap()
);
indent_stack.pop();
}
indent_stack.push(IndentStackEntry::Subsaga);
}
InternalNode::SubsagaEnd { .. } => {
prop_assert!(!indent_stack.is_empty());
prop_assert_eq!(
indent_stack.len() - 1,
*indent_level
);
prop_assert_eq!(
indent_stack.last().unwrap(),
&IndentStackEntry::Subsaga
);
indent_stack.pop();
}
}
}
PrintOrderEntry::Parallel { indent_level } => {
if indent_stack.len() == *indent_level {
if !indent_stack.is_empty() {
prop_assert_eq!(
&IndentStackEntry::Subsaga,
indent_stack.last().unwrap()
);
}
} else {
prop_assert_eq!(
&IndentStackEntry::Parallel,
indent_stack.last().unwrap()
);
prop_assert_eq!(indent_stack.len() - 1, *indent_level);
indent_stack.pop();
}
indent_stack.push(IndentStackEntry::Parallel);
}
}
}
Ok(())
}
proptest! {
// Every graph node appears exactly once in the print order, and all
// indent levels satisfy the structural invariants above.
#[test]
fn prints_correctly(nodes in prop::collection::vec(arb_nodedesc(), 1..10)) {
let dag = new_dag(&nodes, 0);
let saga_dag = SagaDag::new(dag, serde_json::Value::Null);
let orderer = PrintOrderer::new(&saga_dag);
let entries = orderer.print_order();
let num_nodes = entries.iter().filter(|e| e.is_node()).count();
prop_assert_eq!(num_nodes, saga_dag.graph.node_count());
property_indents_are_correct(&entries, &saga_dag)?;
}
}
}